Compare commits

..

36 Commits

Author SHA1 Message Date
majdyz
9a2373bf61 test: add E2E screenshots for PR #12870 2026-04-21 20:09:32 +07:00
majdyz
63c4229774 test(backend/copilot): cover reasoning persistence wiring end-to-end
Adds test_reasoning_persists_to_state_session_messages, which drives
reasoning deltas through _baseline_llm_caller and asserts a
role="reasoning" row lands on state.session_messages with the
concatenated delta content.  Catches regressions in
_BaselineStreamState.__post_init__ that silently pass the wrong list
reference to the emitter.
2026-04-21 20:02:55 +07:00
majdyz
c0a27ab878 refactor(backend/copilot): use mock delta in reasoning validation test
- Replace object.__setattr__(__pydantic_extra__) with MagicMock(spec=ChoiceDelta)
  so the test no longer depends on a pydantic-v2 internal attribute name.
- Document the mutate-in-place invariant on _BaselineStreamState.session_messages
  so future edits know the emitter shares the list reference.
2026-04-21 19:59:05 +07:00
majdyz
08b568021b fix(backend/copilot): harden reasoning delta parsing and restore kill switch
- Filter reasoning_details entries by recognised type (reasoning.text /
  reasoning.summary) so future provider metadata cannot leak into the UI
  collapse.
- Swallow + log pydantic ValidationError on malformed OpenRouter
  reasoning payloads instead of aborting the stream; valid text/tool
  events keep flowing.
- Restore the max_thinking_tokens<=0 kill switch on the baseline path so
  operators can silence reasoning without touching the SDK path.
- Drop the duplicate _is_anthropic_route helper; reuse _is_anthropic_model
  from service.py via a lazy import.
- Restore integration coverage for reasoning-only streams and the
  zero-tokens kill switch in service_unit_test.py.
2026-04-21 19:53:14 +07:00
majdyz
316b132a13 fix(backend/copilot): persist baseline reasoning as session rows
pr-test surfaced the headline feature broken: backend emitted a clean reasoning-start/delta/end stream but the frontend Reasoning collapse never rendered.

Root cause: useHydrateOnStreamEnd swaps in the DB-hydrated message list the moment the stream ends, and convertChatSessionToUiMessages.ts only emits {type:'reasoning'} UI parts from ChatMessage(role='reasoning') rows.  SDK persists these rows via acc.reasoning_response in _dispatch_response; baseline didn't, so the live-streamed reasoning parts got overwritten by a reasoning-less hydrate.

Fold persistence into the same BaselineReasoningEmitter that owns the wire events: when a session_messages list is attached, the first reasoning delta appends a ChatMessage(role='reasoning', content=''), every delta mutates .content in lockstep with the StreamReasoningDelta, and close() leaves the row intact.  _BaselineStreamState wires the emitter to its session_messages via __post_init__, so existing callsites don't change.

Mirrors the SDK contract exactly, including across tool-call continuations (each new reasoning block → fresh row). New tests in reasoning_test.py cover the persistence lifecycle (row appended, deltas mutate same row, close keeps row, second block appends new row, no-session works for pure wire emission).
2026-04-21 19:25:02 +07:00
majdyz
db25bbf47d refactor(backend/copilot): extract baseline reasoning into typed module
Address review feedback: the reasoning plumbing was spread across service.py as a mix of inline state, a dict-parsing helper, and a second private close helper, with its own duplicate config field alongside the SDK's thinking-token setting.

* New backend/copilot/baseline/reasoning.py encapsulates the whole concern: ReasoningDetail / OpenRouterDeltaExtension validate the extension fields via pydantic (no getattr / isinstance duck typing), BaselineReasoningEmitter owns the start/delta/end lifecycle, and reasoning_extra_body builds the request fragment.

* _BaselineStreamState drops reasoning_block_id + reasoning_started for a single reasoning_emitter: BaselineReasoningEmitter — three call sites in _baseline_llm_caller collapse to state.reasoning_emitter.on_delta / .close() calls.

* baseline_reasoning_max_tokens deleted; both SDK and baseline now read from the existing claude_agent_max_thinking_tokens, with its docstring updated to describe the shared contract. No reason to have two knobs for the same thing.

* Moved the wire-parser tests to a dedicated backend/copilot/baseline/reasoning_test.py that exercises the pydantic models directly. service_unit_test.py keeps four integration smoke tests that rebuild real ChoiceDelta pydantic chunks (so .model_extra plumbing is exercised end-to-end), and drops the obsolete 'config=0 disables' case.

Net: ~200 fewer lines across service.py + its unit test, behaviour unchanged, reasoning_test.py gives first-class coverage of the parser variants.
2026-04-21 19:07:09 +07:00
majdyz
2517dae85a refactor(backend/copilot): drop unnecessary forward-ref quotes on _BaselineStreamState
Review cycle 3 nit. _BaselineStreamState is defined earlier in the
module (L330) than _close_reasoning_block_if_open (L533), so the
annotation doesn't need to be stringified.
2026-04-21 18:43:14 +07:00
majdyz
080d42b9da fix(backend/copilot): close reasoning/text blocks on exception path
Review cycle 2 follow-up. CodeRabbit flagged that
`_close_reasoning_block_if_open` + thinking-stripper flush + StreamTextEnd
sat in the `_baseline_llm_caller` try block but not its finally, so an
exception mid-stream (network drop, provider 500, cancel) left the
reasoning block unclosed and the frontend collapse never finalised.

- Move close-reasoning + stripper flush + StreamTextEnd emission into the
  outer finally of `_baseline_llm_caller` so they run on both normal and
  exception paths, preserving the
  `...Reasoning/TextEnd -> StreamFinishStep` protocol ordering.
- Remove the now-redundant StreamTextEnd insert-before-StreamFinishStep
  patch in `stream_chat_completion_baseline`'s exception handler — the
  inner finally already closed the text block, so the flag was always
  False by the time the outer handler ran.
- Add `test_reasoning_closed_on_mid_stream_exception` covering the new
  invariant: a stream that yields a reasoning delta then raises must
  still emit StreamReasoningEnd before StreamFinishStep.
2026-04-21 18:39:36 +07:00
majdyz
3d7b381620 refactor(backend/copilot): DRY reasoning-end helper, widen extractor, cover tool_call transition
Review cycle 1 follow-ups.

- Extract `_close_reasoning_block_if_open(state)` helper and replace the
  three inline copies (text branch, tool_calls branch, stream-end) so
  future edits cannot desync the rotation rules.
- Support typed/pydantic entries in `reasoning_details` via attribute
  access fallback — guards against upstream OpenAI-SDK drift that would
  otherwise silently drop every entry.
- Add `test_reasoning_then_tool_call_closes_reasoning_first` covering
  the tool_calls branch (no prior coverage) and
  `test_structured_details_accept_typed_pydantic_entries` covering the
  non-dict fallback.
2026-04-21 18:34:02 +07:00
majdyz
02be5440fc feat(backend/copilot): stream extended_thinking on baseline via OpenRouter
Baseline route's OpenAI-compat call never requested Anthropic extended thinking, so reasoning deltas were invisible even though the frontend's Reasoning collapse was already wired for SDK mode. Fast-mode autopilot never showed a Reasoning block.

Wire the non-OpenAI extension through:

* New 'baseline_reasoning_max_tokens' config (default 8192, 0 disables). Sent as extra_body={'reasoning': {'max_tokens': N}} only on Anthropic routes; other providers ignore the field.

* Extract reasoning from delta via 'reasoning' (legacy string), 'reasoning_content' (DeepSeek), and structured 'reasoning_details'.

* Emit StreamReasoningStart / StreamReasoningDelta / StreamReasoningEnd through the same state machine the SDK adapter uses — reasoning closes on text/tool_use/stream-end so AI SDK v5 keeps the parts distinct.

* Unit tests cover the extractor variants, paired event ordering, reasoning-only streams, and that the reasoning request param is gated by model route and config.
2026-04-21 18:26:45 +07:00
Zamil Majdy
e17e9f13c4 fix(backend/copilot): reduce SDK + baseline prompt cache waste (#12866)
## Summary

Four cost-reduction changes for the copilot feature. Consolidated into
one PR at user request; each commit is self-contained and bisectable.

### 1. SDK: full cross-user cache on every turn (CLI 2.1.116 bump)
Previous behavior: CLI 2.1.97 crashed when `excludeDynamicSections=True`
was combined with `--resume`, so the code fell back to a raw
`system_prompt` string on resume, losing Claude Code's default prompt
and all cache markers. Every Turn 2+ of an SDK session wrote ~33K tokens
to cache instead of reading.

Fix: install `@anthropic-ai/claude-code@2.1.116` in the backend Docker
image and point the SDK at it via
`CHAT_CLAUDE_AGENT_CLI_PATH=/usr/bin/claude`. CLI 2.1.98+ fixes the
crash, so we can use the preset with `exclude_dynamic_sections=True` on
every turn — Turn 1, 2, 3+ all share the same static prefix and hit the
**cross-user** prompt cache.

**Local dev requirement:** if `CHAT_CLAUDE_AGENT_CLI_PATH` is unset, the
bundled 2.1.97 fallback will crash on `--resume`. Install the CLI
globally (`npm install -g @anthropic-ai/claude-code@2.1.116`) or set the
env var.

### 2. Baseline: add `cache_control` markers (commit `756b3ecd9` +
follow-ups)
Baseline path had zero `cache_control` across `backend/copilot/**`.
Every turn was full uncached input (~18.6K tokens, ~$0.058). Two
ephemeral markers — on the system message (content-blocks form) and the
last tool schema — plus `anthropic-beta: prompt-caching-2024-07-31` via
`extra_headers` as defense-in-depth. Helpers split into `_mark_tools_*`
(precomputed once per session) and `_mark_system_*` (per-round, O(1)).
Repeat hellos: ~$0.058 → ~$0.006.

### 3. Drop `get_baseline_supplement()` (commit `6e6c4d791`)
`_generate_tool_documentation()` emitted ~4.3K tokens of `(tool_name,
description)` pairs that exactly duplicated the tools array already in
the same request. Deleted. `SHARED_TOOL_NOTES` (cross-tool workflow
rules) is preserved. Baseline "hello" input: ~18.7K → ~14.4K tokens.

### 4. Langfuse "CoPilot Prompt" v26 (published under `review` label)
Separate, out-of-repo change. v25 had three duplicate "Example Response"
blocks + a 10-step "Internal Reasoning Process" section. v26 collapses
to one example + bullet-form reasoning. Char count 20,481 → 7,075 (rough
4 chars/token → ~5,100 → ~1,770 tokens).

- v26 is published with label `review` (NOT `production`); v25 remains
active.
- Promote via `mcp__langfuse__updatePromptLabels(name="CoPilot Prompt",
version=26, newLabels=["production"])` after smoke-test.
- Rollback: relabel v25 `production`.

## Test plan
- [x] Unit tests for `_build_system_prompt_value` (fresh vs resumed
turns emit identical preset dict)
- [x] SDK compat tests pass including
`test_bundled_cli_version_is_known_good_against_openrouter`
- [x] `cli_openrouter_compat_test.py` passes against CLI 2.1.116
(locally verified with
`CHAT_CLAUDE_AGENT_CLI_PATH=/opt/homebrew/bin/claude`)
- [x] 8 new `_mark_*` unit tests + identity regression test for
`_fresh_*` helpers
- [x] `SHARED_TOOL_NOTES` public-constant test passes; 5 old tool-docs
tests removed
- [ ] **Manual cost verification (commit 1):** send two consecutive SDK
turns; Turn 2 and Turn 3 should both show `cacheReadTokens` ≈ 33K (full
cross-user cache hits).
- [ ] **Manual cost verification (commit 2):** send two "hello" turns on
baseline <5 min apart; Turn 2 reports `cacheReadTokens` ≈ 18K and cost ≈
$0.006.
- [ ] **Regression sweep for commit 3:** one turn per tool family —
`search_agents`, `run_agent`,
`add_memory`/`forget_memory`/`search_memory`, `search_docs`,
`read_workspace_file` — to verify no tool-selection regression from
dropping the prose tool docs.
- [ ] **Langfuse v26 smoke test:** 5-10 varied turns after relabelling
to `production`; compare responses vs v25 for regression on persona,
concision, capability-gap handling, credential security flows.

## Deployment notes
- Production Docker image now installs CLI 2.1.116 (~20 MB added).
- `CHAT_CLAUDE_AGENT_CLI_PATH=/usr/bin/claude` set in the Dockerfile;
runtime can override via env.
- First deploy after this merge needs a fresh image rebuild to pick up
the new CLI.
2026-04-21 16:34:10 +07:00
Zamil Majdy
f238c153a5 fix(backend/copilot): release session cluster lock on completion (#12867)
## Summary

Fixes a bug where a chat session gets silently stuck after the user
presses Stop mid-turn.

**Root cause:** the cancel endpoint marks the session `failed` after
polling 5s, but the cluster lock held by the still-running task is only
released by `on_run_done` when the task actually finishes. If the task
hangs past the 5s poll (slow LLM call, agent-browser step, etc.), the
lock lingers for up to 5 min — `stream_chat_post`'s `is_turn_in_flight`
check sees the flipped meta (`failed`) and enqueues a new turn, but the
run handler sees the stale lock and drops the user's message at
`manager.py:379` (`reject+requeue=False`). The new SSE stream hangs
until its 60s idle timeout.

### Fix

Two cooperating changes:

1. **`mark_session_completed` force-releases the cluster lock** in the
same transaction that flips status to `completed`/`failed`.
Unconditional delete — by the time we're declaring the session dead, we
don't care who the current lock holder is; the lock has to go so the
next enqueued turn can acquire. This is what closes the stuck-session
window.
2. **`ClusterLock.release()` is now owner-checked** (Lua CAS — `GET ==
token ? DEL : noop` atomically). Force-release means another pod may
legitimately own the key by the time the original task's `on_run_done`
eventually fires. Without the CAS, that late `release()` would wipe the
successor's lock. With it, the late `release()` is a safe no-op when the
owner has changed.

Together: prompt release on completion (via force-delete) + safe cleanup
when on_run_done catches up (via CAS). That re-syncs the API-level
`is_turn_in_flight` check with the actual lock state, so the contention
window disappears.

No changes to the worker-level contention handler: `stream_chat_post`
already queues incoming messages into the pending buffer when a turn is
in flight (via `queue_pending_for_http`). With these fixes, the worker
never sees contention in the common case; if it does (true multi-pod
race), the pre-existing `reject+requeue=False` behaviour still applies —
we'll revisit that path with its own PR if it becomes a production
symptom.

### Verification

- Reproduced the original stuck-session symptom locally (Stop mid-turn →
send new message → backend logs `Session … already running on pod …`,
user message silently lost, SSE stream idle 60s then closes).
- After the fix: cancel → new message → turn starts normally (lock
released by `mark_session_completed`).
- `poetry run pyright` — 0 errors on edited files.
- `pytest backend/copilot/stream_registry_test.py
backend/executor/cluster_lock_test.py` — 33 passed (includes the
successor-not-wiped test).

## Changes

- `autogpt_platform/backend/backend/copilot/executor/utils.py` — extract
`get_session_lock_key(session_id)` helper so the lock-key format has a
single source of truth.
- `autogpt_platform/backend/backend/copilot/executor/manager.py` — use
the helper where the cluster lock is created.
- `autogpt_platform/backend/backend/copilot/stream_registry.py` —
`mark_session_completed` deletes the lock key after the atomic status
swap (force-release).
- `autogpt_platform/backend/backend/executor/cluster_lock.py` —
`ClusterLock.release()` (sync + async) uses a Lua CAS to only delete
when `GET == token`, protecting against wiping a successor after a
force-release.

## Test plan

- [ ] Send a message in /copilot that triggers a long turn (e.g.
`run_agent`), press Stop before it finishes, then send another message.
Expect: new turn starts promptly (no 5-min wait for lock TTL).
- [ ] Happy path regression — send a normal message, verify turn
completes and the session lock key is deleted after completion.
- [ ] Successor protection — unit test
`test_release_does_not_wipe_successor_lock` covers: A acquires, external
DEL, B acquires, A.release() is a no-op, B's lock intact.
2026-04-21 16:27:01 +07:00
Zamil Majdy
01f1289aac feat(copilot): real OpenRouter cost + cost-based rate limits (percent-only public API) (#12864)
## Why

After d7653acd0 removed cost estimation, most baseline turns log with
`tracking_type="tokens"` and no authoritative USD figure (see: dashboard
flipped from `cost_usd` to `tokens` after 4/14/2026). Rate-limit
counters were also token-weighted with hand-rolled cache discounts
(cache_read @ 10%, cache_create @ 25%) and a 5× Opus multiplier — a
proxy for cost that drifts from real OpenRouter billing.

This PR wires real generation cost from OpenRouter into both the
cost-tracking log and the rate limiter, and hides raw spend figures from
the user-facing API so clients can't reverse-engineer per-turn cost or
platform margins.

## What

1. **Real cost from OpenRouter** — baseline passes `extra_body={"usage":
{"include": True}}` and reads `chunk.usage.cost` from the final
streaming chunk. `x-total-cost` header path removed. Missing cost logs
an error and skips the counter update (vs the old estimator that
silently under-counted).
2. **Cost-based rate limiting** — `record_token_usage(...)` →
`record_cost_usage(cost_microdollars)`. The weighted-token math, cache
discount factors, and `_OPUS_COST_MULTIPLIER` are gone; real USD already
reflects model + cache pricing.
3. **Redis key migration** — `copilot:usage:*` → `copilot:cost:*` so
stale token counters can't be misinterpreted as microdollars.
4. **LD flags + config** — renamed to
`copilot-daily-cost-limit-microdollars` /
`copilot-weekly-cost-limit-microdollars` (unit in the LD key so values
can't accidentally be set in dollars or cents).
5. **Public `/usage` hides raw $$** — new `CoPilotUsagePublic` /
`UsageWindowPublic` schemas expose only `percent_used` (0-100) +
`resets_at` + `tier` + `reset_cost`. Admin endpoint keeps raw
microdollars for debugging.
6. **Admin API contract** — `UserRateLimitResponse` fields renamed
`daily/weekly_token_limit` → `daily/weekly_cost_limit_microdollars`,
`daily/weekly_tokens_used` → `daily/weekly_cost_used_microdollars`.
Admin UI displays `$X.XX`.

## How

- `baseline/service.py` — pass `extra_body`, extract cost from
`chunk.usage.cost`, drop the `x-total-cost` header fallback entirely.
- `rate_limit.py` — rewritten around `record_cost_usage`,
`check_rate_limit(daily_cost_limit, weekly_cost_limit)`, new Redis key
prefix. Adds `CoPilotUsagePublic.from_status()` projector for the public
API.
- `token_tracking.py` — converts `cost_usd` → microdollars via
`usd_to_microdollars` and calls `record_cost_usage` only when cost is
present.
- `sdk/service.py` — deletes `_OPUS_COST_MULTIPLIER` and simplifies
`_resolve_model_and_multiplier` to `_resolve_sdk_model_for_request`.
- Chat routes: `/usage` and `/usage/reset` return `CoPilotUsagePublic`.
Internal server-side limit checks still use the raw microdollar
`CoPilotUsageStatus`.
- Admin routes: unchanged response shape (renamed fields only).
- Frontend: `UsagePanelContent`, `UsageLimits`, `CopilotPage`,
`BriefingTabContent`, `credits/page.tsx` consume the new public schema
and render "N% used" + progress bar. Admin `RateLimitDisplay` /
`UsageBar` keep `$X.XX`. Helper `formatMicrodollarsAsUsd` retained for
admin use.
- Tests + snapshots rewritten; new assertions explicitly check that raw
`used`/`limit` keys are absent from the public payload.

## Deploy notes

1. **Before rolling this out, create the new LD flags:**
`copilot-daily-cost-limit-microdollars` (default `500000`) and
`copilot-weekly-cost-limit-microdollars` (default `2500000`). Old
`copilot-*-token-limit` flags can stay in LD for rollback.
2. **One-time Redis cleanup (optional):** token-based counters under
`copilot:usage:*` are orphaned and will TTL out within 7 days. Safe to
ignore or delete manually.

## Test plan

- [x] `poetry run test` — all impacted backend tests pass (182/182 in
targeted scope)
- [x] `pnpm test:unit` — all 1628 integration tests pass
- [x] `poetry run format` / `pnpm format` / `pnpm types` clean
- [x] Manual sanity against dev env — Baseline turn logged $0.1221 for
40K/139 tokens on Sonnet 4 (matches expected pricing)
- [ ] `/pr-test --fix` end-to-end against local native stack
2026-04-21 14:34:43 +07:00
Zamil Majdy
343222ace1 feat(platform): defer paid-to-paid subscription downgrades + cancel-pending flow (#12865)
### Why / What / How

**Why:** Only downgrades to FREE were scheduled at period end; paid→paid
downgrades (e.g. BUSINESS→PRO) applied immediately via Stripe proration.
The asymmetry meant users lost their higher tier mid-cycle in exchange
for a Stripe credit voucher only redeemable on a future subscription — a
confusing pattern that produces negative-value paths for users actually
cancelling. There was also no way to cancel a pending downgrade or
paid→FREE cancellation once scheduled.

**What:** Standardize on "upgrade = immediate, downgrade = next cycle"
and let users cancel a pending change by clicking their current tier.
Harden the new code against conflicting subscription state, concurrent
tab races, flaky Stripe calls, and hot-path latency regressions.

**How:**

Subscription state machine:
- **Upgrade** (PRO→BUSINESS) — `stripe.Subscription.modify` with
immediate proration (unchanged). If a downgrade schedule is already
attached, release it first so the upgrade wins.
- **Paid→paid downgrade** (BUSINESS→PRO) — creates a
`stripe.SubscriptionSchedule` with two phases (current tier until
`current_period_end`, target tier after). No mid-cycle tier demotion.
Defensive pre-clear: existing schedule → release;
`cancel_at_period_end=True` → set to False.
- **Paid→FREE** — unchanged: `cancel_at_period_end=True`.
- **Same-tier update** — reuses the existing `POST
/credits/subscription` route. When `target_tier == current_tier`,
backend calls `release_pending_subscription_schedule` (idempotent) and
returns status. No dedicated cancel-pending endpoint — "Keep my current
tier" IS the cancel operation.
- `release_pending_subscription_schedule` is idempotent on
terminal-state schedules and clears both `schedule` and
`cancel_at_period_end` atomically per call.

API surface:
- New fields on `SubscriptionStatusResponse`: `pending_tier` +
`pending_tier_effective_at` (pulled from the schedule's next-phase
`start_date` so dashboard-authored schedules report the correct
timestamp).
- `POST /credits/subscription` now returns `SubscriptionStatusResponse`
(previously `SubscriptionCheckoutResponse`); the response still carries
`url` for checkout flows and adds the status fields inline.
- `get_pending_subscription_change` is cached with a 30s TTL — avoids
hammering Stripe on every home-page load.
- Webhook dispatches
`subscription_schedule.{released,completed,updated}` through the main
`sync_subscription_from_stripe` flow so both event sources converge to
the same DB state.

Implementation notes:
- New Stripe calls use native async (`stripe.Subscription.list_async`
etc.) and typed attribute access — no `run_in_threadpool` wrapping in
the new helpers.
- Shared `_get_active_subscription` helper collapses the "list
active/trialing subs, take first" pattern used by 4 callers.

Frontend:
- `PendingChangeBanner` sub-component above the tier grid with formatted
effective date + "Keep [CurrentTier]" button. `aria-live="polite"` for
screen readers; locale pinned to `en-US` to avoid SSR/CSR hydration
mismatch.
- "Keep [CurrentTier]" also available as a button on the current tier
card.
- Other tier buttons disabled while a change is pending — user must
resolve pending first to prevent stacked schedules.
- `cancelPendingChange` reuses `useUpdateSubscriptionTier` with `tier:
current_tier`; awaits `refetch()` on both success and error paths so the
UI reconciles even if the server succeeded but the client didn't receive
the response.

### Changes

**Backend (`credit.py`, `v1.py`)**
- Tier-ordering helpers (`is_tier_upgrade`/`is_tier_downgrade`).
- `modify_stripe_subscription_for_tier` routes downgrades through
`_schedule_downgrade_at_period_end`; upgrade path releases any pending
schedule first.
- `_schedule_downgrade_at_period_end` defensively releases pre-existing
schedules and clears `cancel_at_period_end` before creating the new
schedule.
- `release_pending_subscription_schedule` idempotent on terminal-state
schedules; logs partial-failure outcomes.
- `_next_phase_tier_and_start` returns both tier and phase-start
timestamp; warns on unknown prices.
- `get_pending_subscription_change` cached (30s TTL), narrow exception
handling.
- `sync_subscription_schedule_from_stripe` delegates to
`sync_subscription_from_stripe` for convergence with the main webhook
path.
- Shared `_get_active_subscription` +
`_release_schedule_ignoring_terminal` helpers.
- `POST /credits/subscription` absorbs the same-tier "cancel pending
change" branch.

**Frontend (`SubscriptionTierSection/*`)**
- `PendingChangeBanner` new sub-component (a11y, locale-pinned date,
paid→FREE vs paid→paid copy split, non-null effective-date assertion, no
`dark:` utilities).
- "Keep [CurrentTier]" button on current tier card.
- `useSubscriptionTierSection` — `cancelPendingChange` reuses the
update-tier mutation.
- Copy: downgrade dialog + status hint updated.
- `helpers.ts` extracted from the main component.

**Tests**
- Backend: +24 tests (95/95 passing): upgrade-releases-pending-schedule,
schedule-releases-existing-schedule, cancel-at-period-end collision,
terminal-state release idempotency, unknown-price logging, status
response population, same-tier-POST-with-pending, webhook delegation.
- Frontend: +5 integration tests (21/21 passing): banner render/hide,
Keep-button click from banner + current card, paid→paid dialog copy.

### Checklist

- [x] Backend unit tests: 95 pass
- [x] Frontend integration tests: 21 pass
- [x] `poetry run format` / `poetry run lint` clean
- [x] `pnpm format` / `pnpm lint` / `pnpm types` clean
- [ ] Manual E2E on live Stripe (dev env) — pending deploy: BUSINESS→PRO
creates schedule, DB tier unchanged until period end
- [ ] Manual E2E: "Keep BUSINESS" in banner releases schedule
- [ ] Manual E2E: cancel pending paid→FREE flips `cancel_at_period_end`
back to false
- [ ] Manual E2E: BUSINESS→PRO (scheduled) then attempt BUSINESS→FREE
clears the PRO schedule, sets cancel_at_period_end
- [ ] Manual E2E: BUSINESS→PRO (scheduled) then upgrade back to BUSINESS
releases the schedule
2026-04-21 14:01:09 +07:00
Zamil Majdy
a8226af725 fix(copilot): dedupe tool row, lift bash_exec timeout, Stop+resend recovery (#12862)
Closes #12861 · [OPEN-3096](https://linear.app/autogpt/issue/OPEN-3096)

## Why

Four related copilot UX / stability issues surfaced on dev once action
tools started rendering inline in the chat (see #12813):

### 1. Duplicate bash_exec row

`GenericTool` rendered two rows saying the same thing for every
completed tool call — a muted subtitle line ("Command exited with code
1" / "Ran: sleep 20") **and** a `ToolAccordion` with the command echoed
in its description. Previously hidden inside the "Show reasoning" /
"Show steps" collapse, now visibly duplicated.

### 2. `bash_exec` capped at 120s via advisory text

The tool schema said `"Max seconds (default 30, max 120)"`; the model
obeyed, so long-running scripts got clipped at 120s with a vague `Timed
out after 120s` even though the E2B sandbox has no such limit. Confirmed
via Langfuse traces — the model picks `120` for long scripts because
that's what the schema told it the max was. E2B path never had a
server-side clamp.

Originally added in #12103 (default 30) and tightened to "max 120"
advisory in #12398 (token-reduction pass).

### 3. 30s default was too aggressive

`pip install`, small data-processing scripts, etc. routinely cross 30s
and got killed before the model thought to retry with a bigger timeout.

### 4. Stop + edit + resend → "The assistant encountered an error"
([OPEN-3096](https://linear.app/autogpt/issue/OPEN-3096))

Two independent bugs both land on the same banner — fixing only one
leaves the other visible on the next action.

**4a. Stream lock never released on Stop** *(the error in the ticket
screenshot)*. The executor's `async for chunk in
stream_and_publish(...)` broke out on `cancel.is_set()` without calling
`aclose()` on the wrapper. `async for` does NOT auto-close iterators on
`break`, so `stream_chat_completion_sdk` stayed suspended at its current
`await` — still holding the per-session Redis lock (TTL 120s) until GC
eventually closed it. The next `POST /stream` hit `lock.try_acquire()`
at
[sdk/service.py](autogpt_platform/backend/backend/copilot/sdk/service.py)
and yielded `StreamError("Another stream is already active for this
session. Please wait or stop it.")`. The `except GeneratorExit →
lock.release()` handler written exactly for this case never fired
because nothing sent GeneratorExit.

**4b. Orphan `tool_use` after stop-mid-tool.** Even with the lock
released, the stop path persists the session ending on an assistant row
whose `tool_calls` have no matching `role="tool"` row. On the next turn,
`_session_messages_to_transcript` hands Claude CLI `--resume` a JSONL
with a `tool_use` and no paired `tool_result`, and the SDK raises a
vague error — same banner. The ticket's "Open questions" explicitly
flags this.

## What

**Frontend — `GenericTool.tsx`** split responsibilities between the two
rows so they don't duplicate:
- **Subtitle row** (always visible, muted): *what ran* — `Ran: sleep
120`. Never the exit code.
- **Accordion description**: *how it ended* — `completed` / `status code
127 · bash: missing-bin: command not found` / `Timed out after 120s` /
(fallback to command preview for legacy rows missing `exit_code` /
`timed_out`). Pulled from the first non-empty line of `stdout` /
`stderr` when available.
- **Expanded accordion**: full command + stdout + stderr code blocks
(unchanged).

**Backend — `bash_exec.py`**:
- Drop the "max 120" advisory from the schema description.
- Bump default `timeout: 30 → 120`.
- Clean up the result message — `"Command executed with status code 0"`
(no "on E2B", no parens).

**Backend — `executor/processor.py` + `stream_registry.py` (OPEN-3096
#4a)**: wrap the consumer `async for` in `try/finally: await
stream.aclose()`. Close now propagates through `stream_and_publish` into
`stream_chat_completion_sdk`, whose existing `except GeneratorExit →
lock.release()` releases the Redis lock immediately on cancel. Stream
types tightened to `AsyncGenerator[StreamBaseResponse, None]` so the
defensive `getattr(stream, "aclose", None)` goes away.

**Backend — `session_cleanup.py` (OPEN-3096 #4b)**: new
`prune_orphan_tool_calls()` helper walks the trailing session tail and
drops any trailing assistant row whose `tool_calls` have unresolved ids
(plus everything after it) and any trailing `STOPPED_BY_USER_MARKER`
system-stop row. Single backward pass — tolerates the marker being
present or absent. Called from the existing turn-start cleanup in both
`sdk/service.py` and `baseline/service.py`; takes an optional
`log_prefix` so both paths emit the same INFO log when something was
popped. In-memory only — the DB save path is append-only via
`start_sequence`.

## Test plan

- [x] `pnpm exec vitest run src/app/(platform)/copilot/tools/GenericTool
src/app/(platform)/copilot/components/ChatMessagesContainer` — 105 pass
(6 new for GenericTool subtitle/description variants + legacy-fallback
case).
- [x] `pnpm format` / `pnpm lint` / `pnpm types` — clean.
- [x] `poetry run pytest
backend/copilot/sdk/session_persistence_test.py` — 17 pass (6 + 3 new
covering the orphan-tool-call prune and its optional-log-prefix branch).
- [x] `poetry run pytest backend/copilot/stream_registry_test.py
backend/copilot/executor/processor_test.py` — 19 pass (2 for aclose
propagation on the `stream_and_publish` wrapper, 2 for `_execute_async`
aclose propagation on both exit paths, 1 for publish_chunk RedisError
warning ladder).
- [x] `poetry run ruff check` / `poetry run pyright` on touched files —
clean.
- [x] Manual: fire a `bash_exec` — one labelled row, accordion
description reads sensibly (`completed` / `status code 1 · …` / `Timed
out after 120s`).
- [x] Manual: script that needs >120s — no longer clipped.
- [x] Manual: Stop mid-tool + edit + resend — Autopilot resumes without
"Another stream is already active" and without the vague SDK error.

## Scope note

Does not touch `splitReasoningAndResponse` — re-collapsing action tools
back into "Show steps" is #12813's responsibility.
2026-04-21 10:18:52 +07:00
Ubbe
f06b5293de fix(frontend/library): compute monthly spend for AgentBriefingPanel (#12854)
### Why / What / How

<img width="900" alt="Screenshot 2026-04-20 at 19 52 22"
src="https://github.com/user-attachments/assets/c30d5f18-2842-4a8a-ac3d-5bfee18fcd56"
/>

**Why:** The "Spent this month" tile in the Agent Briefing Panel on the
Library page always showed `$0`, even for users with real execution
usage. The tile is meant to give a quick sense of monthly spend across
all agents.

**What:** Compute `monthlySpend` from actual execution data and format
it as currency.

**How:**
- `useLibraryFleetSummary` now sums `stats.cost` (cents) across every
execution whose `started_at` falls within the current calendar month.
Previously `monthlySpend` was hardcoded to `0`.
- `FleetSummary.monthlySpend` is documented as being in cents
(consistent with backend + `formatCents`).
- `StatsGrid` now uses `formatCents` from the copilot usage helpers to
render the tile (e.g. `$12.34` instead of the broken `$0`).

### Changes 🏗️

-
`autogpt_platform/frontend/src/app/(platform)/library/hooks/useLibraryFleetSummary.ts`:
aggregate `stats.cost` across executions started in the current calendar
month; add `toTimestamp` and `startOfCurrentMonth` helpers.
-
`autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/StatsGrid.tsx`:
format the "Spent this month" tile via shared `formatCents` helper.
- `autogpt_platform/frontend/src/app/(platform)/library/types.ts`:
document that `FleetSummary.monthlySpend` is in cents.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
- [ ] Load `/library` with the `AGENT_BRIEFING` flag enabled and at
least one completed execution in the current month — the "Spent this
month" tile shows the correct cumulative cost.
  - [ ] With no executions this month, the tile shows `$0.00`.
- [ ] Type-check (`pnpm types`), lint (`pnpm lint`), and integration
tests (`pnpm test:unit`) pass locally.

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-20 20:28:47 +07:00
Zamil Majdy
70b591d74f fix(copilot): persist reasoning, split steps/reasoning UX, fix mid-turn promote stream stall (#12853)
## Why

Four related issues that surfaced when queued follow-ups hit an
extended_thinking turn:

1. **Mid-turn promote stalled the SSE stream.** `pollBackendAndPromote`
used `setMessages((prev) => [...prev, bubble])` — Vercel AI SDK's
`useChat` streams SSE deltas into `messages[-1]`, so once a user bubble
ended up there, every subsequent chunk silently landed on the wrong
message. Chat sat frozen until a page refresh, even though the backend's
stream completed cleanly.
2. **Thinking-only final turn looked identical to a frozen UI.** When
Claude's last LLM call after a tool_result produced only a
`ThinkingBlock` (no `TextBlock`, no `ToolUseBlock`), the response
adapter silently dropped it and the UI hung on "Thought for Xs" with no
response text.
3. **Reasoning was invisible.** `ThinkingBlock` was dropped live and
never persisted in a way the frontend could render — sessions on reload
/ shared links showed no thinking, a confusing UX gap ("display for
nothing").
4. **Cross-pod Redis replay dropped reasoning events.** The
`stream_registry._reconstruct_chunk` type map had no entries for
`reasoning-*` types, so any client that subscribed mid-stream (share,
reload, cross-pod) silently dropped them with `Unknown chunk type:
reasoning-delta`.

## What

### Mid-turn promote — splice before the trailing assistant

In `useCopilotPendingChips.ts::pollBackendAndPromote`:

```ts
setMessages((prev) => {
  const bubble = makePromotedUserBubble(drained, "midturn", crypto.randomUUID());
  const lastIdx = prev.length - 1;
  if (lastIdx >= 0 && prev[lastIdx].role === "assistant") {
    return [...prev.slice(0, lastIdx), bubble, prev[lastIdx]];
  }
  return [...prev, bubble];
});
```

Streaming assistant stays at `messages[-1]`, AI SDK deltas keep routing
correctly. `useHydrateOnStreamEnd` snaps the bubble to the DB-canonical
position when the stream ends.

### Reasoning — end-to-end visibility (live + persisted)

- **Wire protocol**: new `StreamReasoningStart` / `StreamReasoningDelta`
/ `StreamReasoningEnd` events matching AI SDK v5's `reasoning-*` wire
names, so `useChat` accumulates them into a `type: 'reasoning'`
UIMessage part natively.
- **Response adapter**: every `ThinkingBlock` now emits reasoning
events; text/tool_use transitions close the open reasoning block so AI
SDK doesn't merge distinct parts.
- **Stream registry**: added `reasoning-*` types to
`_reconstruct_chunk`'s type_to_class map so Redis replay no longer drops
them on cross-pod / reload / share.
- **Persistence** (new): each `StreamReasoningStart` opens a
`ChatMessage(role="reasoning")` row in `session.messages`; deltas
accumulate into its content; `StreamReasoningEnd` closes it. No schema
migration — `ChatMessage.role` is already `String`.
`extract_context_messages` filters `role="reasoning"` out of LLM context
(the `--resume` CLI session already carries thinking separately) so the
model never re-ingests prior reasoning.
- **Frontend conversion**: `convertChatSessionMessagesToUiMessages` maps
`role="reasoning"` DB rows into `{type: "reasoning", text}` parts on the
surrounding assistant bubble, so reload / shared-link sessions render
reasoning identically to live stream.

### Steps / Reasoning UX — modal + accordion split

- **`StepsCollapse`** (new): a Dialog-backed "Show steps" modal wraps
the pre-final-answer group (tool timeline + per-block reasoning). Modal
keeps the steps visually grouped and out of the reading flow.
- **`ReasoningCollapse`** (rewritten): inline accordion with "Show
reasoning" / "Hide reasoning" toggle — no longer a modal, so it expands
*inside* the Steps modal without stacking two dialogs. Reasoning text
appears indented with a left border.
- **`splitReasoningAndResponse`**: reasoning parts now stay in the
reasoning group (instead of being pinned out), so they show up inside
the Steps modal alongside the tool-use timeline.

### Thinking-only final turn — synthesize a closing line
(belt-and-suspenders)

- **Prompt rule** (`_USER_FOLLOW_UP_NOTE`): "Every turn MUST end with at
least one short user-facing text sentence."
- **Adapter fallback**: tracks `_text_since_last_tool_result`; at
`ResultMessage success` with tools run + zero text since, opens a fresh
step (`UserMessage` already closed the previous one) and injects `"(Done
— no further commentary.)"` before `StreamFinish`. Only fires for the
pathological case — pure-text turns untouched.

## Test plan

- [x] `pnpm vitest run` on copilot files — all 638 prior tests pass;
**17 new tests** added covering:
- `convertChatSessionToUiMessages`: reasoning row alone / merged with
assistant text / multi-row / empty skip / duration capture
- `ReasoningCollapse`: initial collapsed, toggle, `rotate-90`,
`aria-expanded`
  - `StepsCollapse`: trigger + dialog open renders children
- `MessagePartRenderer`: reasoning → `<pre>` inside collapse,
whitespace/missing text → null
  - `splitReasoningAndResponse`: reasoning-stays-in-reasoning regression
- [x] `poetry run pytest backend/copilot/sdk/response_adapter_test.py` —
36 pass (7 new: 4 reasoning streaming, 3 thinking-only fallback)
- [x] Manual: reasoning streams live and persists across reload on a
fresh session
- [x] Manual: previously-created sessions (pre-persistence) don't have
`role="reasoning"` rows — behaves as a clean no-op (no reasoning shown,
no error), new sessions render reasoning inside Steps modal

## Notes

- No DB migration — `ChatMessage.role` is already an open `String`;
`role="reasoning"` is simply filtered out of LLM context builds but
rendered by the frontend.
- Addresses /pr-review blockers: (a) stream_registry missing reasoning
types in Redis round-trip, (b) fallback text emitted outside a step, (c)
dead `case "thinking"` in renderer (now uses the live `reasoning` type
uniformly).
2026-04-19 10:37:04 +07:00
Zamil Majdy
b1c043c2d8 feat(copilot): queue follow-up messages on busy sessions (UI + run_sub_session + AutoPilot block) (#12737)
## Why

Users and tools can target a copilot session that already has a turn
running. Before this PR there was no uniform behaviour for that case —
the UI manually routed to a separate queue endpoint, `run_sub_session`
and the AutoPilot block raced the cluster lock, and in-turn follow-ups
only reached the model at turn-end via auto-continue. Outcome: dropped
messages, duplicate tool rows, missed mid-turn intent, latent
correctness bugs in block execution.

## What

A single "message arrived → turn already running?" primitive, shared by
every caller:

1. **POST `/stream`** (UI chat): self-defensive. Session idle → SSE as
today; session busy → `202 application/json` with `{buffer_length,
max_buffer_length, turn_in_flight}`. The deprecated `POST
/messages/pending` endpoint is removed (`GET /messages/pending` peek
stays).
2. **`run_copilot_turn_via_queue`** (shared primitive from #12841, used
by `run_sub_session` + `AutoPilotBlock`): gains the same busy-check.
Busy session → push to pending buffer, return `("queued",
SessionResult(queued=True, pending_buffer_length=N))` without creating a
stream registry session or enqueueing a RabbitMQ job. All callers
inherit queueing.
3. **Mid-turn delivery**: drained follow-ups are attached to every
tool_result's `additionalContext` via the SDK's `PostToolUse` hook —
covers both MCP and built-in tools (WebSearch/Read/Agent/etc.), not just
`run_block`. Claude reads the queued text on the next LLM round of the
same turn.
4. **UI observability**: chips promote to a proper user bubble at the
correct chronological position (after the tool_result row that consumed
them). Auto-continue handles end-of-turn drainage; mid-turn backend poll
handles the tool-boundary drainage path.

## How

**Data plane**
- `backend/copilot/pending_messages.py` — Redis list per session
(LPOP-count for atomic drain), TTL, fire-and-forget pub/sub notify. MAX
10 per session.
- `backend/copilot/pending_message_helpers.py` — `is_turn_in_flight`,
`queue_user_message`, `drain_and_format_for_injection`,
`persist_pending_as_user_rows` (shared persist+rollback used by both
baseline and SDK paths).
- `backend/data/redis_helpers.py` — centralised `incr_with_ttl`,
`capped_rpush`, `hash_compare_and_set`; every Lua script and pipeline
atomicity lives in one place.

**Injection sites**
- `backend/copilot/sdk/security_hooks.py::post_tool_use_hook` — drains +
returns `additionalContext`. Single hook covers built-in + MCP tools.
- `backend/copilot/sdk/service.py` — `StreamToolOutputAvailable`
dispatch persists the drained follow-up as a real user row right after
the tool_result (UI bubble at the right index).
`state.midturn_user_rows` keeps the CLI upload watermark honest.
- `backend/copilot/baseline/service.py` — same drain at round
boundaries, uses the shared `persist_pending_as_user_rows` helper so
baseline + SDK code paths don't diverge.

**Dispatch**
- `backend/copilot/sdk/session_waiter.py::run_copilot_turn_via_queue` —
`is_turn_in_flight` short-circuit; `SessionResult` gains `queued` +
`pending_buffer_length`; `SessionOutcome` gains `"queued"`.
- `backend/api/features/chat/routes.py::stream_chat_post` — busy-check
returns 202 with `QueuePendingMessageResponse`; `POST /messages/pending`
deleted.
- `backend/copilot/tools/run_sub_session.py` / `models.py` —
`SubSessionStatusResponse.status` gains `"queued"`;
`response_from_outcome` renders a clear queued-state message with the
pending-buffer depth and a link to watch live.
- `backend/blocks/autopilot.py::execute_copilot` — surfaces queued state
as descriptive response text + empty `tool_calls`/history when
`result.queued`.

**Frontend**
- `src/app/(platform)/copilot/useCopilotPendingChips.ts` — hook owning
the chip lifecycle: backend peek on session load, auto-continue
promotion when a second assistant id appears, mid-turn poll that
promotes when the backend count drops.
- `src/app/(platform)/copilot/useHydrateOnStreamEnd.ts` —
force-hydrate-waits-for-fresh-reference dance extracted.
- `src/app/(platform)/copilot/helpers/stripReplayPrefix.ts` — pure
function with drop / strip / streaming-catch-up cases + helper
decomposition.
- `src/app/(platform)/copilot/helpers/makePromotedBubble.ts` — one-line
helper for the promoted bubble shape.
- `src/app/(platform)/copilot/helpers/queueFollowUpMessage.ts` — thin
`fetch` wrapper for the 202 path (AI SDK's `useChat` fetcher only
handles SSE, so we can't reuse `sendMessage` for the queued response).

## Test plan

Backend unit + integration (`poetry run pytest backend/copilot
backend/api/features/chat`):
- [x] 107 tests pass — pending buffer, drain helpers, routes,
session_waiter queue branch, run_sub_session outcome rendering,
autopilot block
- [x] New `session_waiter_test.py` proves the queue branch
short-circuits `stream_registry.create_session` + `enqueue_copilot_turn`
- [x] Mid-turn persist has a rollback-and-re-queue path tested for when
`session.messages` persist silently fails to back-fill sequences

Frontend unit (`pnpm vitest run`):
- [x] 630 tests pass incl. 22 new for extracted helpers + hooks
- [x] Frontend coverage on touched copilot files: 91%+ (patch 87.37%)

Manual (once merged):
- [ ] Queue two chips while a tool is running; Claude acknowledges both
on the next round, UI shows bubbles in typing order after the tool
output
- [ ] Hand AutoPilot block an existing session_id that has a live turn;
block returns queued status, in-flight turn drains the message on its
next round
- [ ] `run_sub_session` against a busy sub — status=`queued`,
`sub_autopilot_session_link` lets user watch live

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-19 00:48:59 +07:00
Zamil Majdy
fcaebd1bb7 refactor(backend/copilot): unified queue-backed copilot turns + async sub-AutoPilot + guide-read gate (#12841)
### Why / What / How

**Why:** the 10-min stream-level idle timeout was killing legitimate
long-running tool calls — notably sub-AutoPilot runs via
`run_block(AutoPilotBlock)`, which routinely take 15–45 min. The symptom
users saw was `"A tool call appears to be stuck"` even though AutoPilot
was actively working. A second long-standing rough edge was shipped
alongside: agents often skipped `get_agent_building_guide` when
generating agent JSON, producing schemas that failed validation and
burned turns on auto-fix loops.

**What:** three threaded pieces.

1. **Async sub-AutoPilot via `run_sub_session`.** New copilot tool that
delegates a task to a fresh (or resumed) sub-AutoPilot, and its
companion `get_sub_session_result` for polling/cancelling. The agent
starts with `run_sub_session(prompt, wait_for_result≤300s)` and, if the
sub isn't done inside the cap, receives a handle + polls via
`get_sub_session_result(wait_if_running≤300s)`. No single MCP call ever
blocks the stream for more than 5 min, so the 10-min stream-idle timer
stays simple and effective (derived as `MAX_TOOL_WAIT_SECONDS * 2`).

2. **Queue-backed copilot turn dispatch** — one code path for all three
callers.
- `run_sub_session` enqueues a `CoPilotExecutionEntry` on the existing
`copilot_execution` exchange instead of spawning an in-process
`asyncio.Task`.
- `AutoPilotBlock.execute_copilot` (graph block) now uses the **same
queue** instead of `collect_copilot_response` inline.
   - The HTTP SSE endpoint was already queue-backed.
- All three share a single primitive: `run_copilot_turn_via_queue` →
`create_session` → `enqueue_copilot_turn` → `wait_for_session_result`.
The event-aggregation logic (`EventAccumulator`/`process_event`) is a
shared module used by both the direct-stream path and the cross-process
waiter.
- Benefits: **deploy/crash resilience** (RabbitMQ redelivery survives
worker restarts), **natural load balancing** across copilot_executor
workers, **sessions as first-class resources** (UI users can
`/copilot?sessionId=<inner>` into any sub or AutoPilot block's session),
and every future stream-level feature (pending-messages drain #12737,
compaction policies, etc.) applies uniformly instead of bypassing
graph-block sessions.

3. **Guide-read gate on agent-generation tools.** `create_agent` /
`edit_agent` / `validate_agent_graph` / `fix_agent_graph` refuse until
the session has called `get_agent_building_guide`. The pre-existing soft
hint was routinely ignored; the gate makes the dependency enforceable.
All four tool descriptions advertise the requirement in one tightened
sentence ("Requires get_agent_building_guide first (refuses
otherwise).") that stays under the 32000-char schema budget.

**How:**

#### Queue-backed sub-AutoPilot + AutoPilotBlock

- `sdk/session_waiter.py` — new module. `SessionResult` dataclass
mirrors `CopilotResult`. `wait_for_session_result` subscribes to
`stream_registry`, drains events via shared `process_event`, returns
`(outcome, result)`. `wait_for_session_completion` is the cheaper
outcome-only variant. `run_copilot_turn_via_queue` is the canonical
three-step dispatch. Every exit path unsubscribes the listener.
- `sdk/stream_accumulator.py` — new module. `EventAccumulator`,
`ToolCallEntry`, `process_event` extracted from `collect.py`. Both the
direct-stream and cross-process paths now use the same fold logic.
- `tools/run_sub_session.py` / `tools/get_sub_session_result.py` —
rewritten around the shared primitive. `sub_session_id` is now the sub's
`ChatSession` id directly (no separate registry handle). Ownership
re-verified on every call via `get_chat_session`. Cancel via
`enqueue_cancel_task` on the existing `copilot_cancel` fan-out exchange.
- `blocks/autopilot.py` — `execute_copilot` replaced its inline
`collect_copilot_response` with `run_copilot_turn_via_queue`.
`SessionResult` carries response text, tool calls, and token usage back
from the worker so no DB round-trip is needed. The block's public I/O
contract (inputs, outputs, `ToolCallEntry` shape) is unchanged.
- `CoPilotExecutionEntry` gains a `permissions: CopilotPermissions |
None` field forwarded to the worker's `stream_fn` so the sub's
capability filter survives the queue hop. The processor passes it
through to `stream_chat_completion_sdk` /
`stream_chat_completion_baseline`.
- **Deleted**: `sdk/sub_session_registry.py` (module-level dict,
done-callback, abandoned-task cap, `notify_shutdown_and_cancel_all`,
`_reset_for_test`), plus the shutdown-notifier hook in
`copilot_executor.processor.cleanup` — redundant under queue-backed
execution.

#### Run_block single-tool cap (3)

- `tools/helpers.execute_block` caps block execution at
`MAX_TOOL_WAIT_SECONDS = 5 min` via `asyncio.wait_for` around the
generator consumption.
- On timeout: logs `copilot_tool_timeout tool=run_block block=…
block_id=… input_keys=… user=… session=… cap_s=…` (grep-friendly) and
returns an `ErrorResponse` that redirects the LLM to `run_agent` /
`run_sub_session`.
- Billing protection: `_charge_block_credits` is called in a `finally`
guarded by `asyncio.shield` and marked `charge_handled` **before** the
await so cancel-mid-charge doesn't double-bill and
cancel-mid-generator-before-charge still settles via the finally.

#### Guide-read gate

- `helpers.require_guide_read(session, tool_name)` scans
`session.messages` for any prior assistant tool call named
`get_agent_building_guide` (handles both OpenAI and flat shapes).
Applied at the top of `_execute` in `create_agent`, `edit_agent`,
`validate_agent_graph`, `fix_agent_graph`. Tool descriptions advertise
the requirement.

#### Shared timing constants

- `MAX_TOOL_WAIT_SECONDS = 5 * 60` + `STREAM_IDLE_TIMEOUT_SECONDS = 2 *
MAX_TOOL_WAIT_SECONDS` in `constants.py`. Every long-running tool
(`run_agent`, `view_agent_output`, `run_sub_session`,
`get_sub_session_result`, `run_block`) imports from one place; no more
hardcoded 300 / `10*60` literals drifting apart. Stream-idle invariant
("no single tool blocks close to the idle timeout") holds by
construction.

### Frontend

- Friendlier tool-card labels: `run_sub_session` → "Sub-AutoPilot",
`get_sub_session_result` → "Sub-AutoPilot result", `run_block` →
"Action" (matches the builder UI's own naming), `run_agent` → "Agent".
Fixes the double-verb "Running Run …" phrasing.
- `SubSessionStatusResponse.sub_autopilot_session_link` surfaces
`/copilot?sessionId=<inner>` so users can click into any sub's session
from the tool-call card — same pattern as `run_agent`'s
`library_agent_link`.

### Changes 🏗️

- **New modules**: `sdk/session_waiter.py`, `sdk/stream_accumulator.py`,
`tools/run_sub_session.py`, `tools/get_sub_session_result.py`,
`tools/sub_session_test.py`, `tools/agent_guide_gate_test.py`.
- **New response types**: `SubSessionStatusResponse`,
`SubSessionProgressSnapshot`, `SessionResult`.
- **New gate helper**: `require_guide_read` in `tools/helpers.py`.
- **Queue protocol**: `permissions` field on `CoPilotExecutionEntry`,
threaded through `processor.py` → `stream_fn`.
- **Hidden**: `AUTOPILOT_BLOCK_ID` in `COPILOT_EXCLUDED_BLOCK_IDS`
(run_block can't execute AutoPilotBlock; agents use `run_sub_session`
instead).
- **Deleted**: `sdk/sub_session_registry.py`, processor
shutdown-notifier hook.
- **Regenerated**: `openapi.json` for the new response types; block-docs
for the updated `ToolName` Literal.
- **Tool descriptions**: tightened the guide-gate hint across the four
agent-builder tools to stay under the 32000-char schema budget.
- **40+ tests** across sub_session, execute_block cap + billing races,
stream_accumulator, agent_guide_gate, frontend helpers.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Unit suite green on the full copilot tree; `poetry run format` +
`pyright` clean
- [x] Schema character budget test passes (tool descriptions trimmed to
stay under 32000)
- [x] Native UI E2E (`poetry run app` + `pnpm dev`):
`run_sub_session(wait_for_result=60)` returns `status="completed"` +
`sub_autopilot_session_link` inline;
`run_sub_session(wait_for_result=1)` returns `status="running"` +
handle, `get_sub_session_result(wait_if_running=60)` observes `running →
completed` transition
- [x] AutoPilotBlock (graph) goes through `copilot_executor` queue
end-to-end (verified via logs: ExecutionManager's AutoPilotBlock node
spawned session `f6de335b-…`, a different `CoPilotExecutor` worker
acquired its cluster lock and ran the SDK stream)
- [x] Guide gate: `create_agent` without a prior
`get_agent_building_guide` returns the refusal; agent reads the guide
and retries successfully
2026-04-18 23:11:41 +07:00
Joe Munene
3a01874911 fix(frontend/builder): preserve agent name in AgentExecutor node title after reload (#12805)
## Summary

Fixes #11041

When an `AgentExecutorBlock` is placed in the builder, it initially
displays the agent's name (e.g., "Researcher v2"). After saving and
reloading the page, the title reverts to the generic "Agent Executor."

## Root Cause

The backend correctly persists `agent_name` and `graph_version` in
`hardcodedValues` (via `input_default` in `AgentExecutorBlock`).
However, `NodeHeader.tsx` always resolves the display title from
`data.title` (the generic block name), ignoring the persisted agent
name.

## Fix

Modified the title resolution chain in `NodeHeader.tsx` to check
`data.hardcodedValues.agent_name` between the user's custom name and the
generic block title:

1. `data.metadata.customized_name` (user's manual rename) — highest
priority
2. `agent_name` + ` v{graph_version}` from `hardcodedValues` — **new**
3. `data.title` (generic block name) — fallback

This is a frontend-only change. No backend modifications needed.

## Files Changed

-
`autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx`
(+11, -1)

## Test Plan

- [x] Place an AgentExecutorBlock, select an agent — title shows agent
name
- [x] Save graph, reload page — title still shows agent name (was "Agent
Executor" before)
- [x] Double-click to rename — custom name takes priority over agent
name
- [x] Clear custom name — falls back to agent name
- [x] Non-AgentExecutor blocks — unaffected, show generic title as
before

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-04-17 15:20:32 +00:00
Zamil Majdy
6d770d9917 fix(platform/copilot): revert forward pagination, add visibility guarantee for blank chat (#12831)
## Why / What / How

**Why:** PR #12796 changed completed copilot sessions to load messages
from sequence 0 forward (ascending), which broke the standard chat UX —
users now land at the beginning of the conversation instead of the most
recent messages. Reported in Discord.

**What:** Reverts the forward pagination approach and replaces it with a
visibility guarantee that ensures every page contains at least one
user/assistant message.

**How:**
- **Backend**: Removed after_sequence, from_start, forward_paginated,
newest_sequence — always use backward (newest-first) pagination. Added
_expand_for_visibility() helper: after fetching, if the entire page is
tool messages (invisible in UI), expand backward up to 200 messages
until a visible user/assistant message is found.
- **Frontend**: Removed all forwardPaginated/newestSequence plumbing
from hooks and components. Removed bottom LoadMoreSentinel. Simplified
message merge to always prepend paged messages.

### Changes
- routes.py: Reverted to simple backward pagination, removed TOCTOU
re-fetch logic
- db.py: Removed forward mode, extracted _expand_tool_boundary() and
added _expand_for_visibility()
- SessionDetailResponse: Removed newest_sequence and forward_paginated
fields
- openapi.json: Removed after_sequence param and forward pagination
response fields
- Frontend hooks/components: Removed forward pagination props and logic
(-1000 lines)
- Updated all tests (backend: 63 pass, frontend: 1517 pass)

### Checklist
- [x] I have clearly listed my changes in the PR description
- [x] Backend unit tests: 63 pass
- [x] Frontend unit tests: 1517 pass
- [x] Frontend lint + types: clean
- [x] Backend format + pyright: clean
2026-04-17 19:23:28 +07:00
slepybear
334ec18c31 docs: convert in-code comments to MkDocs admonitions in block-sdk-gui… (#12819)
### Why / What / How

<!-- Why: Why does this PR exist? What problem does it solve, or what's
broken/missing without it? -->
This PR converts inline Python comments in code examples within
`block-sdk-guide.md` into MkDocs `!!! note` admonitions. This makes code
examples cleaner and more copy-paste friendly while preserving all
explanatory content.

<!-- What: What does this PR change? Summarize the changes at a high
level. -->
Converts inline comments in code blocks to admonitions following the
pattern established in PR #12396 (new_blocks.md) and PR #12313.

<!-- How: How does it work? Describe the approach, key implementation
details, or architecture decisions. -->
- Wrapped code examples with `!!! note` admonitions
- Removed inline comments from code blocks for clean copy-paste
- Added explanatory admonitions after each code block

### Changes 🏗️

- Provider configuration examples (API key and OAuth)
- Block class Input/Output schema annotations
- Block initialization parameters
- Test configuration
- OAuth and webhook handler implementations
- Authentication types and file handling patterns

### Checklist 📋

#### For documentation changes:
- [x] Follows the admonition pattern from PR #12396
- [x] No code changes, documentation only
- [x] Admonition syntax verified correct

#### For configuration changes:
- [ ] `.env.default` is updated or already compatible with my changes
- [ ] `docker-compose.yml` is updated or already compatible with my
changes

---

**Related Issues**: Closes #8946

Co-authored-by: slepybear <slepybear@users.noreply.github.com>
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-04-17 07:47:52 +00:00
slepybear
ea5cfdfa2e fix(frontend): remove debug console.log statements (#12823)
## Why
Debug console.log statements were left in production code, which can
leak
sensitive information and pollute browser developer consoles.

## What
Removed console.log from 4 non-legacy frontend components:
- useNavbar.ts: isLoggedIn debug log
- WalletRefill.tsx: autoRefillForm debug log  
- EditAgentForm.tsx: category field debug log
- TimezoneForm.tsx: currentTimezone debug log

## How
Simply deleted the console.log lines as they served no purpose 
other than debugging during development.

## Checklist
- [x] Code follows project conventions
- [x] Only frontend changes (4 files, 6 lines removed)
- [x] No functionality changes

Co-authored-by: slepybear <slepybear@users.noreply.github.com>
2026-04-17 07:31:51 +00:00
Ubbe
d13a85bef7 feat(frontend): surface scheduled agents in library & copilot briefings (#12818)
## Why

Scheduled agents weren't well-surfaced in the Library and Copilot
briefings:

- The Library fleet summary didn't count agents that are scheduled
purely via the scheduler (only those with a `recommended_schedule_cron`
set at the agent level).
- Sitrep items didn't distinguish scheduled or listening (trigger-based)
agents, so they often fell back to a generic "idle" state.
- Scheduled chips showed a generic message with no indication of when
the next run would happen.
- The Copilot Agent Briefing surfaced every scheduled agent regardless
of how far out the next run was — an agent scheduled a month away would
take a slot from something actually happening soon.
- Long sitrep messages overflowed the row.

## What

- Add `is_scheduled` to `LibraryAgent` (sourced from the scheduler) so
the frontend can reliably detect schedule-only agents.
- Count scheduled agents in `useLibraryFleetSummary`.
- Include scheduled and listening agents in sitrep items, with a
priority ordering (error → running → stale → success → listening →
scheduled → idle).
- Show a relative next-run time on scheduled sitrep chips (e.g.
"Scheduled to run in 2h" / "in 3d").
- Filter the Copilot Agent Briefing to scheduled agents whose next run
is within the next 3 days.
- Truncate long sitrep messages to 1 line with `OverflowText` and show
the full text in a tooltip on hover.

## How

- Scheduler → `LibraryAgent` mapping populates `is_scheduled` /
`next_scheduled_run`.
- `useSitrepItems` gains an optional `scheduledWithinMs` parameter.
Copilot's `usePulseChips` passes `3 * 24 * 60 * 60 * 1000`; the Library
briefing omits it to keep its existing (unbounded) behavior.
- Scheduled config-based sitrep items are skipped when
`next_scheduled_run` is missing or outside the window.
- `SitrepItem` wraps the message in `OverflowText` so a single-line
ellipsis + hover tooltip replaces raw overflow.

## Test plan

- [ ] `/library` — scheduled and listening agents appear in the sitrep
with accurate copy; fleet summary counts scheduled agents correctly;
long messages truncate with a tooltip on hover.
- [ ] `/copilot` — on an empty session with the `AGENT_BRIEFING` flag
on, the briefing only shows scheduled agents whose next run is within 3
days; agents scheduled further out no longer appear as "scheduled"
chips.
- [ ] Scheduled chip text reads "Scheduled to run in {Nm|Nh|Nd}"
matching `next_scheduled_run`.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-17 14:36:15 +07:00
Zamil Majdy
60b85640e7 fix(backend/copilot): replace dedup lock with idempotent append_and_save_message (#12814)
## Why

The Redis dedup lock (`chat:msg_dedup:{session}:{content_hash}`, 30s
TTL) was solving the wrong problem:

- Its purpose: block infra/nginx retries from calling
`append_and_save_message` twice after a client disconnect, writing a
duplicate user message to the DB.
- The approach: deliberately hold the lock for 30s on `GeneratorExit`.
- Why unnecessary: the executor's cluster lock already prevents
duplicate *execution*. The only real gap was duplicate *DB writes* in
the ~1s before the executor picks up the turn.

## What

- **Deleted** `message_dedup.py` and `message_dedup_test.py` (~150 lines
removed).
- **Removed** all dedup lock code from `routes.py` (~40 lines removed).
- **`append_and_save_message`** is now idempotent and self-contained:
- Uses redis-py's built-in `Lock(timeout=10, blocking_timeout=2)` —
Lua-script atomic acquire/release, no manual poll/sleep loop.
- Lock context manager yields `bool` (`True` = acquired, `False` =
degraded). When degraded (Redis down or 2s timeout), reads from DB
directly instead of cache to avoid stale-state duplicates.
- Idempotency check: if `session.messages[-1]` already matches the
incoming role+content, returns `None` instead of the session.
- Lock released explicitly as soon as the write completes; `try/except`
in `finally` so a cleanup error after a successful write never surfaces
a false 500.
- On cache-write failure, the stale cache entry is invalidated so future
reads fall back to the authoritative DB.
- **`routes.py`** uses the `None` signal: `is_duplicate_message = (await
append_and_save_message(...)) is None`
- Skips `create_session` and `enqueue_copilot_turn` for duplicates —
client re-attaches to the existing turn's Redis stream.
- `track_user_message` and `turn_id` generation only happen when
`is_duplicate_message` is false.
- **`subscribe_to_session`** retry window increased from 1×50ms to
3×100ms — covers the window where a duplicate request subscribes before
the original's `create_session` hset completes.
- **Cleaned up** `routes_test.py`: removed 5 dedup-specific tests and
the `mock_redis` setup from `_mock_stream_internals`; added
duplicate-skips-enqueue test.

## How

The idempotency guard distinguishes legit same-text messages from
retries via the **assistant turn between them**: if the user said "yes",
got a response, and says "yes" again, `session.messages[-1]` is the
assistant reply, so the role check fails and the second message goes
through. A retry (no response yet) sees the user message as the last
entry and is blocked.

```python
if (
    session.messages
    and session.messages[-1].role == message.role
    and session.messages[-1].content == message.content
):
    return None  # duplicate — caller skips enqueue
```

The Redis lock ensures this check always sees authoritative state even
in multi-replica deployments. When the lock is unavailable (Redis down
or contention), reading from DB directly (bypassing potentially stale
cache) provides the same safety guarantee at the cost of a DB
round-trip.

## Checklist

- [x] PR targets `dev`
- [x] Conventional commit title with scope
- [x] Tests added/updated (duplicate detection, lock degradation, DB
error, cache invalidation paths)
- [x] `poetry run format` and `poetry run pyright` pass clean
- [x] No new linter suppressors
2026-04-16 22:12:30 +07:00
Zamil Majdy
87e4d42750 fix(backend/copilot): fix initial load missing messages + forward pagination for completed sessions (#12796)
### Why / What / How

**Why:** Completed copilot sessions with many messages showed a
completely empty chat view. A user reported a 158-message session that
appeared blank on reload.

**What:** Two bugs fixed:
1. **Backend** — initial page load always returned the newest 50
messages in DESC order. For sessions heavy in tool calls, the user's
original messages (seq 0–5) were never included; all 50 slots consumed
by mid-session tool outputs.
2. **Frontend** — convertChatSessionToUiMessages silently dropped user
messages with null/empty content.

**How:** For completed sessions (no active stream), the backend now
loads from sequence 0 in ASC order. Active/streaming sessions keep
newest-first for streaming context. A new after_sequence forward cursor
enables infinite-scroll for subsequent pages (sentinel moves to bottom).
The frontend wires forward_paginated + newest_sequence end-to-end.

### Changes 🏗️

- db.py: added from_start (ASC) and after_sequence (forward cursor)
modes; added newest_sequence to PaginatedMessages
- routes.py: detect completed vs active on initial load; pass
from_start=True for completed; expose newest_sequence +
forward_paginated; accept after_sequence param
- convertChatSessionToUiMessages.ts: never drop user messages with empty
content
- useLoadMoreMessages.ts: forward pagination via after_sequence; append
pages to end
- ChatMessagesContainer.tsx: LoadMoreSentinel at bottom for
forward-paginated sessions
- Wire newestSequence + forwardPaginated end-to-end through
useChatSession/useCopilotPage/ChatContainer
- openapi.json: add after_sequence + newest_sequence/forward_paginated;
regenerate types
- db_test.py: 9 new unit tests for from_start and after_sequence modes

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Open a completed session with many messages — first user message
visible on initial load
- [x] Scroll to bottom of completed session — load more appends next
page
- [x] Open active/streaming session — newest messages shown first,
streaming unaffected
  - [x] Backend unit tests: all 28 pass
  - [x] Frontend lint/format: clean, no new type errors

---------

Co-authored-by: chernistry <73943355+chernistry@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2026-04-16 14:16:54 +00:00
Ubbe
0339d95d12 fix(frontend): small UI fixes, sort menu bg, name update auth, stats grid overflow, pulse chips (#12815)
## Summary
- **LibrarySortMenu / AgentFilterMenu**: Force `!bg-transparent` and
neutralise legacy `SelectTrigger` styles (`m-0.5`, `ring-offset-white`,
`shadow-sm`) that caused a white background around the trigger
- **EditNameDialog**: Replace client-side `supabase.auth.updateUser()`
with server-side `PUT /api/auth/user` route — fixes "Auth session
missing!" error caused by `httpOnly` cookies being inaccessible to
browser JS
- **StatsGrid**: Swap label `Text` for `OverflowText` so tile labels
truncate with `…` and show a tooltip instead of wrapping when the grid
is squeezed
- **PulseChips**: Set fixed `15rem` chip width with `shrink-0`,
horizontal scroll, and styled thin scrollbar
- **Tests**: Updated `EditNameDialog` tests to use MSW instead of
mocking Supabase client; added 7 new `PulseChips` integration tests

## Test plan
- [x] `pnpm test:unit` — all 1495 tests pass (91 files)
- [x] `pnpm format && pnpm lint` — clean
- [x] `pnpm types` — no new errors (pre-existing only)
- [ ] QA `/library?sort=updatedAt` — sort menu trigger has no white bg
- [ ] QA `/library` — StatsGrid labels truncate with tooltip on narrow
viewports
- [ ] QA `/copilot` — PulseChips scroll horizontally at fixed width
- [ ] QA `/copilot` — Edit name dialog saves successfully (no "Auth
session missing!")

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-16 20:11:21 +07:00
Toran Bruce Richards
f410929560 feat(platform): Add xAI Grok 4.20 models from OpenRouter (#12620)
Requested by @Torantulino

Adds the 2 xAI Grok 4.20 models available on OpenRouter that are missing
from the platform.

## Why

`x-ai/grok-4.20` and `x-ai/grok-4.20-multi-agent` are xAI's current
flagship models (released March 2026) and are available via OpenRouter,
but weren't accessible from the platform's LLM blocks.

## Changes

**`autogpt_platform/backend/backend/blocks/llm.py`**
- Added `GROK_4_20` and `GROK_4_20_MULTI_AGENT` enum members
- Added corresponding `MODEL_METADATA` entries (open_router provider, 2M
context window, price tier 3)

**`autogpt_platform/backend/backend/data/block_cost_config.py`**
- Added `MODEL_COST` entries at 5 credits each (flagship tier, $2/M in)

**`docs/integrations/block-integrations/llm.md`**
- Added new model IDs to all LLM block tables

| Model | Pricing | Context |
|-------|---------|---------|
| `x-ai/grok-4.20` | $2/M in, $6/M out | 2M |
| `x-ai/grok-4.20-multi-agent` | $2/M in, $6/M out | 2M |

Both models use the standard OpenRouter chat completions API — no
special handling needed.

Resolves: SECRT-2196

---------

Co-authored-by: Torantulino <22963551+Torantulino@users.noreply.github.com>
Co-authored-by: Toran Bruce Richards <Torantulino@users.noreply.github.com>
Co-authored-by: Otto (AGPT) <otto@agpt.co>
2026-04-16 12:14:56 +00:00
Zamil Majdy
2bbec09e1a feat(platform): subscription tier billing via Stripe Checkout (#12727)
## Why

Introducing paid subscription tiers (PRO, BUSINESS) so we can charge for
AutoPilot capacity beyond the free tier. Without a billing integration,
all users share the same rate limits regardless of their willingness to
pay for additional capacity.

## What

End-to-end subscription billing system using Stripe Checkout Sessions:

**Backend:**
- `SubscriptionTier` enum (`FREE`, `PRO`, `BUSINESS`, `ENTERPRISE`) on
the `User` model
- `POST /credits/subscription` — creates a Stripe Checkout Session for
paid upgrades; for FREE tier or when `ENABLE_PLATFORM_PAYMENT` is off,
sets tier directly
- `GET /credits/subscription` — returns current tier, monthly cost
(cents), and all tier costs
- `POST /credits/stripe_webhook` — handles
`customer.subscription.created/updated/deleted`,
`checkout.session.completed`, `charge.dispute.*`, `refund.created`
- `sync_subscription_from_stripe()` — keeps `User.subscriptionTier` in
sync from webhook events; guards against out-of-order delivery
(cancelled event after new sub created), ENTERPRISE overwrite, and
duplicate webhook replay
- Open-redirect protection on `success_url`/`cancel_url` via
`_validate_checkout_redirect_url()`
- `_cancel_customer_subscriptions()` — cancels both active and trialing
subs; propagates errors so callers can avoid updating DB tier on Stripe
failure
- `_cleanup_stale_subscriptions()` — best-effort cancellation of old
subs when a new one becomes active (paid-to-paid upgrade), to prevent
double-billing
- `get_stripe_customer_id()` with idempotency key to prevent duplicate
Stripe customers on concurrent requests
- `cache_none=False` sentinel fix in `@cached` decorator so Stripe price
lookups retry on transient error instead of poisoning the cache with
`None`
- Stripe Price IDs read from LaunchDarkly (`stripe-price-id-pro`,
`stripe-price-id-business`). If not configured, upgrade returns 422.

**Frontend:**
- `SubscriptionTierSection` component on billing page: tier cards
(FREE/PRO/BUSINESS), upgrade/downgrade buttons, per-tier cost display,
Stripe redirect on upgrade
- Confirmation dialog for downgrades
- ENTERPRISE users see a read-only admin-managed banner
- Success toast on return from Stripe Checkout (`?subscription=success`)
- Uses generated `useGetSubscriptionStatus` /
`useUpdateSubscriptionTier` hooks

## How

- Paid upgrades use Stripe Checkout Sessions (not server-side
subscription creation) — Stripe handles PCI-compliant card collection
and the subscription lifecycle
- Tier is synced back via webhook on
`customer.subscription.created/updated/deleted`
- Downgrade to FREE cancels via Stripe API immediately; a
`stripe.StripeError` during cancellation returns 502 with a generic
message (no Stripe detail leakage)
- LaunchDarkly flags: `stripe-price-id-pro` (string),
`stripe-price-id-business` (string), `enable-platform-payment` (bool)
- `ENABLE_PLATFORM_PAYMENT=false` bypasses Stripe for beta/internal
access (sets tier directly)

## Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] `ENABLE_PLATFORM_PAYMENT=false` → tier change updates directly, no
Stripe redirect
- [x] `ENABLE_PLATFORM_PAYMENT=true` with price IDs configured → paid
upgrade redirects to Stripe Checkout
- [x] Stripe webhook `customer.subscription.created` →
`User.subscriptionTier` updated
  - [x] Unrecognised price ID in webhook → logs warning, tier unchanged
  - [x] ENTERPRISE user webhook event → tier not overwritten
  - [x] Empty `STRIPE_WEBHOOK_SECRET` → 503 (prevents HMAC bypass)
  - [x] Open-redirect attack on `success_url`/`cancel_url` → 422

#### For configuration changes:
- [x] No `.env` or `docker-compose.yml` changes
- [x] LaunchDarkly flags to create: `stripe-price-id-pro` (string),
`stripe-price-id-business` (string), `enable-platform-payment` (bool)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: majdyz <majdy.zamil@gmail.com>
2026-04-16 17:52:06 +07:00
Ubbe
31b88a6e56 feat(frontend): add Agent Briefing Panel (#12764)
## Summary

<img width="800" height="772" alt="Screenshot_2026-04-13_at_18 29 19"
src="https://github.com/user-attachments/assets/3da6eaf2-1485-4c08-9651-18f2f4220eba"
/>
<img width="800" height="285" alt="Screenshot_2026-04-13_at_18 29 24"
src="https://github.com/user-attachments/assets/6a5f981a-1e1d-4d22-a33d-9e1b0e7555a7"
/>
<img width="800" height="288" alt="Screenshot_2026-04-13_at_18 29 27"
src="https://github.com/user-attachments/assets/f97b4611-7c23-4fc9-a12d-edf6314a77ef"
/>
<img width="800" height="433" alt="Screenshot_2026-04-13_at_18 29 31"
src="https://github.com/user-attachments/assets/e6d7241d-84f3-4936-b8cd-e0b12df392bb"
/>
<img width="700" height="554" alt="Screenshot_2026-04-13_at_18 29 40"
src="https://github.com/user-attachments/assets/92c08f21-f950-45cd-8c1d-529905a6e85f"
/>


Implements the Agent Intelligence Layer — real-time agent awareness
across the Library and Copilot pages.

### Core Features
- **Agent Briefing Panel** — stats grid with fleet-wide counts (running,
recently completed, needs attention, scheduled, idle, monthly spend) and
tab-driven content below
- **Enhanced Library Cards** — StatusBadge, run counts, contextual
action buttons (See tasks, Start, Chat) with consistent icon-left
styling
- **Situation Report Items** — prioritized sitrep with error-first
ranking, "See task" deep-links for completed runs, and "Ask AutoPilot"
bridge
- **Home Pulse Chips** — agent status chips on Copilot empty state with
hover-reveal actions (slide-up animation + backdrop blur on desktop,
always visible on touch)
- **Edit Display Name** — pencil icon on Copilot greeting to update
Supabase user metadata inline

### Backend
- **Execution count API** — batch `COUNT(*)` query on
`AgentGraphExecution` grouped by `agentGraphId` for the current user,
avoiding loading full execution rows. Wired into `list_library_agents`
and `list_favorite_library_agents` via `execution_count_override` on
`LibraryAgent.from_db()`

### UI Polish
- Subtler gradient on AgentBriefingPanel (reduced opacity on background
+ animated border)
- Consistent button styles across all action buttons (icon-left, same
sizing)
- Removed duplicate "Open in builder" menu item (kept "Edit agent")
- "Recently completed" tab replaces "Listening" in briefing panel,
showing agents with completed runs in last 72h

## Changes

### Backend
- `backend/api/features/library/db.py` — added
`_fetch_execution_counts()` batch COUNT query, wired into list endpoints
- `backend/api/features/library/model.py` — added
`execution_count_override` param to `LibraryAgent.from_db()`

### Frontend — New files
- `EditNameDialog/EditNameDialog.tsx` — modal to update display name via
Supabase auth
- `PulseChips/PulseChips.module.css` — hover-reveal animation + glass
panel styles

### Frontend — Modified files
- `EmptySession.tsx` — added EditNameDialog and PulseChips
- `PulseChips.tsx` — redesigned with See/Ask buttons, hover overlay on
desktop
- `usePulseChips.ts` — added agentID for deep-linking
- `AgentBriefingPanel.tsx` — subtler gradient, adjusted padding
- `AgentBriefingPanel.module.css` — reduced conic gradient opacity
- `BriefingTabContent.tsx` — added "completed" tab routing
- `StatsGrid.tsx` — replaced Listening with Recently completed,
reordered tabs
- `SitrepItem.tsx` — consistent button styles, "See task" link for
completed items, updated copilot prompt
- `ContextualActionButton.tsx` — icon-left, smaller icon, renamed Run to
Start
- `LibraryAgentCard.tsx` — icon-left on all buttons, EyeIcon for See
tasks
- `AgentCardMenu.tsx` — removed duplicate "Open in builder"
- `useAgentStatus.ts` — added completed count to FleetSummary
- `useLibraryFleetSummary.ts` — added recent completion tracking
- `types.ts` — added `completed` to FleetSummary and AgentStatusFilter

## Test plan
- [ ] Library page renders Agent Briefing Panel with stats grid
- [ ] "Recently completed" tab shows agents with completed runs in last
72h
- [ ] Agent cards show real execution counts (not 0)
- [ ] Action buttons have consistent styling with icon on the left
- [ ] "See task" on completed items deep-links to agent page with
execution selected
- [ ] "Ask AutoPilot" generates last-run-specific prompt for completed
items
- [ ] Copilot empty state shows PulseChips with hover-reveal actions on
desktop
- [ ] PulseChips show See/Ask buttons always on touch screens
- [ ] Pencil icon on greeting opens edit name dialog
- [ ] Name update persists via Supabase and refreshes greeting
- [ ] `pnpm format && pnpm lint && pnpm types` pass
- [ ] `poetry run format` passes for backend changes

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: John Ababseh <jababseh7@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Bentlybro <Github@bentlybro.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: CodeRabbit <noreply@coderabbit.ai>
Co-authored-by: majdyz <zamil.majdy@agpt.co>
2026-04-16 17:32:17 +07:00
Zamil Majdy
d357956d98 refactor(backend/copilot): make session-file helper fns public to fix Pyright warnings (#12812)
## Why
After PR #12804 was squashed into dev, two module-level helper functions
in `backend/copilot/sdk/service.py` remained private (`_`-prefixed)
while being directly imported by name in `sdk/transcript_test.py`.
Pyright reports `reportAttributeAccessIssue` when tests (even those
excluded from CI lint) import private symbols from outside their
defining module.

## What
Rename two helpers to remove the underscore prefix:
- `_process_cli_restore` → `process_cli_restore`
- `_read_cli_session_from_disk` → `read_cli_session_from_disk`

Update call sites in `service.py` and imports/calls/docstrings in
`sdk/transcript_test.py`.

## How
Pure rename — no logic change. Both functions were already module-level
helpers with no reason to be private; the underscore was convention
carried over during the refactor but they are directly unit-tested and
should be public.

All 66 `sdk/transcript_test.py` tests pass after the rename.

## Checklist
- [x] Tests pass (`poetry run pytest
backend/copilot/sdk/transcript_test.py`)
- [x] No `_`-prefixed symbols imported across module boundaries
- [x] No linter suppressors added
2026-04-16 17:00:02 +07:00
Zamil Majdy
697ffa81f0 fix(backend/copilot): update transcript_test to use strip_for_upload after upload_cli_session removal 2026-04-16 16:17:02 +07:00
Zamil Majdy
2b4727e8b2 chore: merge master into dev, resolve baseline/transcript conflicts
Conflicts in baseline/service.py, baseline/transcript_integration_test.py,
and transcript.py arose because dev-only commit 0cd0a76305
(baseline upload fix) overlapped with the same fix in PR #12804 which
landed in master. Took master's version for all three files — it is the
complete, reviewed implementation.
2026-04-16 15:38:46 +07:00
Zamil Majdy
0d4b31e8a1 refactor(backend/copilot): unified transcript context — extract_context_messages, mode-gated --resume, compaction-aware gap-fill (#12804)
### Why / What / How

**Why:** The copilot had two separate GCS paths (`cli-sessions/` and
`chat-transcripts/`), redundant function names
(`upload_cli_session`/`restore_cli_session`), and no shared context
strategy between modes. When switching from baseline→SDK or
SDK→baseline, the receiving mode discarded the stored transcript and
fell back to full DB reconstruction — loading all raw messages instead
of the compacted form — causing inflated context, wasted tokens, and
loss of CLI compaction summaries.

**What:**
- Single GCS path (`cli-sessions/`) for both modes — `chat-transcripts/`
removed
- Unified public API: `upload_transcript` / `download_transcript` /
`TranscriptDownload`
- `TranscriptMode = Literal["sdk", "baseline"]` persisted in
`.meta.json` — SDK skips `--resume` when `mode != "sdk"`
(baseline-written JSONL has stripped fields / synthetic IDs)
- `extract_context_messages(download, session_messages)` — shared
context primitive used by **both SDK and baseline**: reads compacted
transcript content + fills only the DB gap (messages after watermark),
so CLI compaction summaries are preserved across mode switches
- Watermark fix: `_jsonl_covered = transcript_msg_count + 2` when a real
transcript is present, preventing false gap detection after `--resume`
- Baseline gap-fill: `_append_gap_to_builder` converts `ChatMessage` →
JSONL entries; no more silently discarded stale transcripts

**How:**

```
SDK turn (mode="sdk" transcript available):
  ──► --resume  [full CLI session restored natively]
  ──► inject gap prefix if DB has messages after watermark

SDK turn (mode="baseline" transcript available):
  ──► cannot --resume (synthetic CLI IDs)
  ──► extract_context_messages(download, session_messages):
        returns transcript JSONL (compacted, isCompactSummary preserved) + gap
        excludes session_messages[-1] (current turn — caller injects it separately)
  ──► format as <conversation_history> + "Now, the user says: {current}"

Baseline turn (any transcript):
  ──► _load_prior_transcript → TranscriptDownload
  ──► extract_context_messages(download, session_messages) + session_messages[-1]
        replaces full session.messages DB read
  ──► LLM messages: [compacted history + gap] + [current user turn]

Transcript unavailable — both SDK (use_resume=False) and baseline:
  ──► extract_context_messages(None, session_messages) returns session_messages[:-1]
        (all prior DB messages except the current user turn at [-1])
  ──► graceful fallback — no crash, no empty context
  ──► covers: first turn, GCS error, corrupt JSONL, missing .meta.json
  ──► next successful response uploads a fresh transcript
```

`extract_context_messages` is the shared primitive — both modes call the
same function, which handles:
- `download=None` (first turn, GCS unavailable) → falls back to
`session_messages[:-1]`
- Empty/corrupt content → falls back to `session_messages[:-1]`
- `bytes` content (raw GCS) or `str` content (pre-decoded baseline path)
- `isCompactSummary=True` entries → preserved so CLI compaction survives
mode switches
- Missing/corrupt `.meta.json` → `message_count` defaults to `0`, `mode`
defaults to `"sdk"`

**Why `[:-1]` and not all messages?** `session_messages[-1]` is always
the current user turn being handled right now. Both callers inject it
separately — SDK wraps it as `"Now, the user says: ..."`, baseline
appends it as the final message in the LLM array. Returning it inside
`extract_context_messages` would double-inject it.

### Changes 🏗️

- **`transcript.py`**: `CliSessionRestore` → `TranscriptDownload` +
`mode` field; `upload_cli_session` → `upload_transcript`;
`restore_cli_session` → `download_transcript`; add `TranscriptMode`,
`detect_gap`, `extract_context_messages`; import `ChatMessage` via
relative path to match `service.py` style
- **`sdk/service.py`**: mode-check before `--resume`; `_RestoreResult`
carries `baseline_download` + `context_messages` + `transcript_content`;
`_build_query_message` accepts `prior_messages` override;
`_restore_cli_session_for_turn` populates `context_messages` via
`extract_context_messages` and sets `transcript_content` to prevent
duplicate DB reconstruction; watermark fix (`_jsonl_covered =
transcript_msg_count + 2`)
- **`baseline/service.py`**: `_load_prior_transcript` returns `(bool,
TranscriptDownload | None)`; LLM context replaced with
`extract_context_messages(download, messages)`; `_append_gap_to_builder`
+ `detect_gap` call; `upload_transcript(mode="baseline")`
- **`sdk/transcript.py`**: updated re-exports, old aliases removed
- **`scripts/download_transcripts.py`**: updated for `bytes | str`
content type
- **Test files**: 179 tests total; `transcript_test.py`,
`baseline/transcript_integration_test.py`,
`sdk/service_helpers_test.py`, `sdk/test_transcript_watermark.py`,
`test/copilot/test_transcript_watermark.py` all updated/added

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] 179 unit tests pass — `transcript_test`,
`baseline/transcript_integration_test`, `sdk/service_helpers_test`,
`sdk/test_transcript_watermark`
  - [x] pyright 0 errors on all changed files
- [x] SDK `--resume` path still works when `mode="sdk"` transcript is
present
- [x] SDK fallback uses `extract_context_messages` (compacted baseline
content + gap) when `mode="baseline"` transcript is stored — no more
full DB reconstruction
- [x] Baseline uses `extract_context_messages` per turn instead of full
`session.messages` DB read
  - [x] `isCompactSummary=True` entries preserved across mode switches
- [x] Watermark (`_jsonl_covered`) fix prevents false gap detection
after `--resume`
- [x] Baseline gap detection no longer silently discards stale
transcripts
- [x] `TranscriptDownload.content` accepts `bytes | str` — backward
compatible
- [x] Transcript unavailable (GCS error, first turn, corrupt file)
gracefully falls back to `session_messages[:-1]` without crash — applies
to both SDK and baseline paths

---------

Co-authored-by: chernistry <73943355+chernistry@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2026-04-16 15:35:18 +07:00
Zamil Majdy
0cd0a76305 fix(backend/copilot): baseline always uploads when GCS has no transcript
_load_prior_transcript was returning False for missing/invalid transcripts,
which caused should_upload_transcript to suppress the upload. The original
intent was to protect against overwriting a *newer* GCS version — but a
missing or corrupt file is not 'newer'. Only stale (watermark ahead) and
download errors (unknown GCS state) should suppress upload.

Also renames transcript_covers_prefix → transcript_upload_safe throughout
to accurately describe what the flag means.
2026-04-16 14:58:42 +07:00
Toran Bruce Richards
d01a51be0e Add check for GitHub account connection status (#12807)
Added instruction to check GitHub authentication status before prompting
user. This prevents repeated, unnecessary asking of the user to add
their GitHub credentials when they're already added, which is currently
a prevalent bug.

### Changes 🏗️
- Added one line to
`autogpt_platform/backend/backend/copilot/prompting.py` instructing
AutoPilot to run `gh auth status` before prompting the user to connect
their GitHub account.

Co-authored-by: Toran Bruce Richards <22963551+Torantulino@users.noreply.github.com>
2026-04-16 12:09:00 +07:00
260 changed files with 27496 additions and 5122 deletions

View File

@@ -25,6 +25,8 @@ Understand the **Why / What / How** before addressing comments — you need cont
gh pr view {N} --json body --jq '.body'
```
> If GraphQL is rate-limited, `gh pr view` fails. See [GitHub rate limits](#github-rate-limits) for REST fallbacks.
## Fetch comments (all sources)
### 1. Inline review threads — GraphQL (primary source of actionable items)
@@ -109,12 +111,16 @@ Only after this loop completes (all pages fetched, count confirmed) should you b
**Filter to unresolved threads only** — skip any thread where `isResolved: true`. `comments(last: 1)` returns the most recent comment in the thread — act on that; it reflects the reviewer's final ask. Use the thread `id` (Relay global ID) to track threads across polls.
> If GraphQL is rate-limited, see [GitHub rate limits](#github-rate-limits) for the REST fallback (flat comment list — no thread grouping or `isResolved`).
### 2. Top-level reviews — REST (MUST paginate)
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate
```
> **Already REST — unaffected by GraphQL rate limits or outages. Continue polling reviews normally even when GraphQL is exhausted.**
**CRITICAL — always `--paginate`.** Reviews default to 30 per page. PRs can have 80170+ reviews (mostly empty resolution events). Without pagination you miss reviews past position 30 — including `autogpt-reviewer`'s structured review which is typically posted after several CI runs and sits well beyond the first page.
Two things to extract:
@@ -133,6 +139,8 @@ Two things to extract:
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate
```
> **Already REST — unaffected by GraphQL rate limits.**
Mostly contains: bot summaries (`coderabbitai[bot]`), CI/conflict detection (`github-actions[bot]`), and author status updates. Scan for non-empty messages from non-bot human reviewers that aren't the PR author — those are the ones that need a response.
## For each unaddressed comment
@@ -327,18 +335,65 @@ git push
5. Restart the polling loop from the top — new commits reset CI status.
## GitHub abuse rate limits
## GitHub rate limits
Two distinct rate limits exist — they have different causes and recovery times:
Three distinct rate limits exist — they have different causes, error shapes, and recovery times:
| Error | HTTP code | Cause | Recovery |
|---|---|---|---|
| `{"code":"abuse"}` | 403 | Secondary rate limit — too many write operations (comments, mutations) in a short window | Wait **23 minutes**. 60s is often not enough. |
| `{"message":"API rate limit exceeded"}` | 429 | Primary rate limit — too many API calls per hour | Wait until `X-RateLimit-Reset` header timestamp |
| `{"message":"API rate limit exceeded"}` | 429 | Primary REST rate limit — 5000 calls/hr per user | Wait until `X-RateLimit-Reset` header timestamp |
| `GraphQL: API rate limit already exceeded for user ID ...` | 403 on stderr, `gh` exits 1 | **GraphQL-specific** per-user limit — distinct from REST's 5000/hr and from the abuse secondary limit. Trips faster than REST because point costs per query. | Wait until the GraphQL window resets (typically ~1 hour from the first call in the window). REST still works — use fallbacks below. |
**Prevention:** Add `sleep 3` between individual thread reply API calls. When posting >20 replies, increase to `sleep 5`.
**Recovery from secondary rate limit (403):**
### Detection
The `gh` CLI surfaces the GraphQL limit on stderr with the exact string `GraphQL: API rate limit already exceeded for user ID <id>` and exits 1 — any `gh api graphql ...` **or** `gh pr view ...` call fails. Check current quota and reset time via the REST endpoint that reports GraphQL quota (this call is REST and still works whether GraphQL is rate-limited OR fully down):
```bash
gh api rate_limit --jq '.resources.graphql' # { "limit": 5000, "used": 5000, "remaining": 0, "reset": 1729...}
# Human-readable reset:
gh api rate_limit --jq '.resources.graphql.reset' | xargs -I{} date -r {}
```
Retry when `remaining > 0`. If you need to proceed sooner, sleep 25 min and probe again — the limit is per user, not per machine, so other concurrent agents under the same token also consume it.
### What keeps working
When GraphQL is unavailable (rate-limited or outage):
- **Keeps working (REST):** top-level reviews fetch, conversation comments fetch, all inline-comment replies, CI status (`gh pr checks`), and the `gh api rate_limit` probe.
- **Degraded:** inline thread list — fall back to flat `/pulls/{N}/comments` REST, which drops thread grouping, `isResolved`, and Relay thread IDs. You still get comment bodies and the `databaseId` as `id`, enough to read and reply.
- **Blocked:** `gh pr view`, the `resolveReviewThread` mutation, and any new `gh api graphql` queries — wait for the quota to reset.
### Fall back to REST
**PR metadata reads** — `gh pr view` uses GraphQL under the hood; use the REST pulls endpoint instead, which returns the full PR object:
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.body' # == --json body
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.base.ref' # == --json baseRefName
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.mergeable' # == --json mergeable
```
Note: REST `mergeable` returns `true|false|null`; GraphQL returns `MERGEABLE|CONFLICTING|UNKNOWN`. The `null` case maps to `UNKNOWN` — treat it the same (still computing; poll again).
**Inline comments (flat list)** — no thread grouping or `isResolved`, but enough to read and reply:
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate \
| jq '[.[] | {id, path, line, user: .user.login, body: .body[:200], in_reply_to_id}]'
```
Use this degraded mode to make progress on the fix → reply loop, then return to GraphQL for `resolveReviewThread` once the rate limit resets.
**Replies** — already REST-native (`/pulls/{N}/comments/{ID}/replies`); no change needed, use the same command as the main flow.
**`resolveReviewThread`** — **no REST equivalent**; GitHub does not expose a REST endpoint for thread resolution. Queue the thread IDs needing resolution, wait for the GraphQL limit to reset, then run the resolve mutations in a batch (with `sleep 3` between calls, per the secondary-limit guidance).
### Recovery from secondary rate limit (403 abuse)
1. Stop all API writes immediately
2. Wait **2 minutes minimum** (not 60s — secondary limits are stricter)
3. Resume with `sleep 3` between each call
@@ -397,6 +452,8 @@ gh api graphql -f query='mutation { resolveReviewThread(input: {threadId: "THREA
**Never call this mutation before committing the fix.** The orchestrator will verify actual unresolved counts via GraphQL after you output `ORCHESTRATOR:DONE` — false resolutions will be caught and you will be re-briefed.
> `resolveReviewThread` is GraphQL-only — no REST equivalent. If GraphQL is rate-limited, see [GitHub rate limits](#github-rate-limits) for the queue-and-retry flow.
### Verify actual count before outputting ORCHESTRATOR:DONE
Before claiming "0 unresolved threads", always query GitHub directly — don't rely on your own bookkeeping. Paginate all pages — a single `first: 100` query misses threads beyond page 1:

View File

@@ -5,7 +5,7 @@ user-invocable: true
argument-hint: "[worktree path or PR number] — tests the PR in the given worktree. Optional flags: --fix (auto-fix issues found)"
metadata:
author: autogpt-team
version: "2.0.0"
version: "2.1.0"
---
# Manual E2E Test
@@ -180,6 +180,94 @@ Based on the PR analysis, write a test plan to `$RESULTS_DIR/test-plan.md`:
**Be critical** — include edge cases, error paths, and security checks. Every scenario MUST specify what screenshots to take and what state to verify.
## Step 3.0: Claim the testing lock (coordinate parallel agents)
Multiple worktrees share the same host — Docker infra (postgres, redis, clamav), app ports (3000/8006/…), and the test user. Two agents running `/pr-test` concurrently will corrupt each other's state (connection-pool exhaustion, port binds failing silently, cross-test assertions). Use the root-worktree lock file to take turns.
### Lock file contract
Path (**always** the root worktree so all siblings see it): `/Users/majdyz/Code/AutoGPT/.ign.testing.lock`
Body (one `key=value` per line):
```
holder=<pr-XXXXX-purpose>
pid=<pid-or-"self">
started=<iso8601>
heartbeat=<iso8601, updated every ~2 min>
worktree=<full path>
branch=<branch name>
intent=<one-line description + rough duration>
```
### Claim
```bash
LOCK=/Users/majdyz/Code/AutoGPT/.ign.testing.lock
NOW=$(date -u +%Y-%m-%dT%H:%MZ)
STALE_AFTER_MIN=5
if [ -f "$LOCK" ]; then
HB=$(grep '^heartbeat=' "$LOCK" | cut -d= -f2)
HB_EPOCH=$(date -j -f '%Y-%m-%dT%H:%MZ' "$HB" +%s 2>/dev/null || date -d "$HB" +%s 2>/dev/null || echo 0)
AGE_MIN=$(( ( $(date -u +%s) - HB_EPOCH ) / 60 ))
if [ "$AGE_MIN" -gt "$STALE_AFTER_MIN" ]; then
echo "WARN: stale lock (${AGE_MIN}m old) — reclaiming"
cat "$LOCK" | sed 's/^/ stale: /'
else
echo "Another agent holds the lock:"; cat "$LOCK"
echo "Wait until released or resume after $((STALE_AFTER_MIN - AGE_MIN))m."
exit 1
fi
fi
cat > "$LOCK" <<EOF
holder=pr-${PR_NUMBER}-e2e
pid=self
started=$NOW
heartbeat=$NOW
worktree=$WORKTREE_PATH
branch=$(cd $WORKTREE_PATH && git branch --show-current)
intent=E2E test PR #${PR_NUMBER}, native mode, ~60min
EOF
echo "Lock claimed"
```
### Heartbeat (MUST run in background during the whole test)
Without a heartbeat a crashed agent keeps the lock forever. Run this as a background process right after claim:
```bash
(while true; do
sleep 120
[ -f "$LOCK" ] || exit 0 # lock released → exit heartbeat
perl -i -pe "s/^heartbeat=.*/heartbeat=$(date -u +%Y-%m-%dT%H:%MZ)/" "$LOCK"
done) &
HEARTBEAT_PID=$!
echo "$HEARTBEAT_PID" > /tmp/pr-test-heartbeat.pid
```
### Release (always — even on failure)
```bash
kill "$HEARTBEAT_PID" 2>/dev/null
rm -f "$LOCK" /tmp/pr-test-heartbeat.pid
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] released lock" \
>> /Users/majdyz/Code/AutoGPT/.ign.testing.log
```
Use a `trap` so release runs even on `exit 1`:
```bash
trap 'kill "$HEARTBEAT_PID" 2>/dev/null; rm -f "$LOCK"' EXIT INT TERM
```
### Shared status log
`/Users/majdyz/Code/AutoGPT/.ign.testing.log` is an append-only channel any agent can read/write. Use it for "I'm waiting", "I'm done, resources free", or post-run notes:
```bash
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] <message>" \
>> /Users/majdyz/Code/AutoGPT/.ign.testing.log
```
## Step 3: Environment setup
### 3a. Copy .env files from the root worktree
@@ -248,7 +336,87 @@ docker ps --format "{{.Names}}" | grep -E "rest_server|executor|copilot|websocke
done
```
### 3e. Build and start
**Native mode also:** when running the app natively (see 3e-native), kill any stray host processes and free the app ports before starting — otherwise `poetry run app` and `pnpm dev` will fail to bind.
```bash
# Kill stray native app processes from prior runs
pkill -9 -f "python.*backend" 2>/dev/null || true
pkill -9 -f "poetry run app" 2>/dev/null || true
pkill -9 -f "next-server|next dev" 2>/dev/null || true
# Free app ports (errors per port are ignored — port may simply be unused)
for port in 3000 8006 8001 8002 8005 8008; do
lsof -ti :$port -sTCP:LISTEN | xargs -r kill -9 2>/dev/null || true
done
```
### 3e-native. Run the app natively (PREFERRED for iterative dev)
Native mode runs infra (postgres, supabase, redis, rabbitmq, clamav) in docker but runs the backend and frontend directly on the host. This avoids the 3-8 minute `docker compose build` cycle on every backend change — code edits are picked up on process restart (seconds) instead of a full image rebuild.
**When to prefer native mode (default for this skill):**
- Iterative dev/debug loops where you're editing backend or frontend code between test runs
- Any PR that touches Python/TS source but not Dockerfiles, compose config, or infra images
- Fast repro of a failing scenario — restart `poetry run app` in a couple of seconds
**When to prefer docker mode (3e fallback):**
- Testing changes to `Dockerfile`, `docker-compose.yml`, or base images
- Production-parity smoke tests (exact container env, networking, volumes)
- CI-equivalent runs where you need the exact image that'll ship
**Note on 3b (copilot auth):** no npm install anywhere. `poetry install` pulls in `claude_agent_sdk`, which ships its own Claude CLI binary — available on `PATH` whenever you run commands via `poetry run` (native) OR whenever the copilot_executor container is built from its Poetry lockfile (docker). The OAuth token extraction still applies (same `refresh_claude_token.sh` call).
**Preamble:** before starting native, run the kill-stray + free-ports block from 3c's "Native mode also" subsection.
**1. Start infra only (one-time per session):**
```bash
cd $PLATFORM_DIR && docker compose --profile local up deps --detach --remove-orphans --build
```
This brings up postgres/supabase/redis/rabbitmq/clamav and skips all app services.
**2. Start the backend natively:**
```bash
cd $BACKEND_DIR && (poetry run app 2>&1 | tee .ign.application.logs) &
```
`poetry run app` spawns **all** app subprocesses — `rest_server`, `executor`, `copilot_executor`, `websocket`, `scheduler`, `notification_server`, `database_manager` — inside ONE parent process. No separate containers, no separate terminals. The `.ign.application.logs` prefix is already gitignored.
**3. Wait for the backend on :8006 BEFORE starting the frontend.** This ordering matters — the frontend's `pnpm dev` startup invokes `generate-api-queries`, which fetches `/openapi.json` from the backend. If the backend isn't listening yet, `pnpm dev` fails immediately.
```bash
for i in $(seq 1 60); do
if [ "$(curl -s -o /dev/null -w '%{http_code}' http://localhost:8006/docs 2>/dev/null)" = "200" ]; then
echo "Backend ready"
break
fi
sleep 2
done
```
**4. Start the frontend natively:**
```bash
cd $FRONTEND_DIR && (pnpm dev 2>&1 | tee .ign.frontend.logs) &
```
**5. Wait for the frontend on :3000:**
```bash
for i in $(seq 1 60); do
if [ "$(curl -s -o /dev/null -w '%{http_code}' http://localhost:3000 2>/dev/null)" = "200" ]; then
echo "Frontend ready"
break
fi
sleep 2
done
```
Once both are up, skip 3e/3f and go straight to **3g/3h** (feature flags / test user creation).
### 3e. Build and start (docker — fallback)
```bash
cd $PLATFORM_DIR && docker compose build --no-cache 2>&1 | tail -20
@@ -442,6 +610,22 @@ agent-browser --session-name pr-test snapshot | grep "text:"
### Checking logs
**Native mode:** when running via `poetry run app` + `pnpm dev`, all app logs stream to the `.ign.*.logs` files written by the `tee` pipes in 3e-native. `rest_server`, `executor`, `copilot_executor`, `websocket`, `scheduler`, `notification_server`, and `database_manager` are all subprocesses of the single `poetry run app` parent, so their output is interleaved in `.ign.application.logs`.
```bash
# Backend (all app subprocesses interleaved)
tail -f $BACKEND_DIR/.ign.application.logs
# Frontend (Next.js dev server)
tail -f $FRONTEND_DIR/.ign.frontend.logs
# Filter for errors across either log
grep -iE "error|exception|traceback" $BACKEND_DIR/.ign.application.logs | tail -20
grep -iE "error|exception|traceback" $FRONTEND_DIR/.ign.frontend.logs | tail -20
```
**Docker mode:**
```bash
# Backend REST server
docker logs autogpt_platform-rest_server-1 2>&1 | tail -30
@@ -876,9 +1060,15 @@ test scenario → find issue (bug OR UX problem) → screenshot broken state
### Problem: Frontend shows cookie banner blocking interaction
**Fix:** `agent-browser click 'text=Accept All'` before other interactions.
### Problem: Container loses npm packages after rebuild
**Cause:** `docker compose up --build` rebuilds the image, losing runtime installs.
**Fix:** Add packages to the Dockerfile instead of installing at runtime.
### Problem: Claude CLI not found in copilot_executor container
**Symptom:** Copilot logs say `claude: command not found` or similar when starting an SDK turn.
**Cause:** Image was built without `poetry install` (stale base layer, or Dockerfile bypass). The SDK CLI ships inside the `claude_agent_sdk` Poetry dep — it is NOT an npm package.
**Fix:** Rebuild the image cleanly: `docker compose build --no-cache copilot_executor && docker compose up -d copilot_executor`. Do NOT `docker exec ... npm install -g @anthropic-ai/claude-code` — that is outdated guidance and will pollute the container with a second CLI that the SDK won't use.
### Problem: agent-browser screenshot hangs / times out
**Symptom:** `agent-browser screenshot` exits with code 124 even on `about:blank`.
**Cause:** Stuck CDP connection or Chromium process tree. Seen on macOS when a prior `/pr-test` left a zombie Chrome for Testing.
**Fix:** `pkill -9 -f "agent-browser|chromium|Chrome for Testing" && sleep 2`, then reopen the browser with a fresh `--session-name`. If still failing, verify via `agent-browser eval` + `agent-browser snapshot` (DOM state) instead of relying on PNGs — the feature under test is the same.
### Problem: Services not starting after `docker compose up`
**Fix:** Wait and check health: `docker compose ps`. Common cause: migration hasn't finished. Check: `docker logs autogpt_platform-migrate-1 2>&1 | tail -5`. If supabase-db isn't healthy: `docker restart supabase-db && sleep 10`.

View File

@@ -1,430 +0,0 @@
name: CLA Label Sync
on:
# Real-time: when CLA status changes (CLA-assistant uses Status API)
status:
# When PRs are opened or updated
pull_request_target:
types: [opened, synchronize, reopened]
# Scheduled sweep - check stale PRs daily
schedule:
- cron: '0 9 * * *' # 9 AM UTC daily
# Manual trigger for testing
workflow_dispatch:
inputs:
pr_number:
description: 'Specific PR number to check (optional)'
required: false
permissions:
pull-requests: write
issues: write
contents: read
statuses: read
checks: read
env:
CLA_CHECK_NAME: 'license/cla'
LABEL_PENDING: 'cla: pending'
LABEL_SIGNED: 'cla: signed'
# Timing configuration (all independently configurable)
REMINDER_DAYS: 7 # Days before first reminder
CLOSE_WARNING_DAYS: 23 # Days before "closing soon" warning
CLOSE_DAYS: 30 # Days before auto-close
jobs:
sync-labels:
runs-on: ubuntu-latest
# Only run on status events if it's the CLA check
if: github.event_name != 'status' || github.event.context == 'license/cla'
steps:
- name: Ensure CLA labels exist
uses: actions/github-script@v7
with:
script: |
const labels = [
{ name: 'cla: pending', color: 'fbca04', description: 'CLA not yet signed by all contributors' },
{ name: 'cla: signed', color: '0e8a16', description: 'CLA signed by all contributors' }
];
for (const label of labels) {
try {
await github.rest.issues.getLabel({
owner: context.repo.owner,
repo: context.repo.repo,
name: label.name
});
} catch (e) {
if (e.status === 404) {
await github.rest.issues.createLabel({
owner: context.repo.owner,
repo: context.repo.repo,
name: label.name,
color: label.color,
description: label.description
});
console.log(`Created label: ${label.name}`);
} else {
throw e;
}
}
}
- name: Sync CLA labels and handle stale PRs
uses: actions/github-script@v7
with:
script: |
const CLA_CHECK_NAME = process.env.CLA_CHECK_NAME;
const LABEL_PENDING = process.env.LABEL_PENDING;
const LABEL_SIGNED = process.env.LABEL_SIGNED;
const REMINDER_DAYS = parseInt(process.env.REMINDER_DAYS);
const CLOSE_WARNING_DAYS = parseInt(process.env.CLOSE_WARNING_DAYS);
const CLOSE_DAYS = parseInt(process.env.CLOSE_DAYS);
// Validate timing configuration
if ([REMINDER_DAYS, CLOSE_WARNING_DAYS, CLOSE_DAYS].some(Number.isNaN)) {
core.setFailed('Invalid timing configuration — REMINDER_DAYS, CLOSE_WARNING_DAYS, and CLOSE_DAYS must be numeric.');
return;
}
if (!(REMINDER_DAYS < CLOSE_WARNING_DAYS && CLOSE_WARNING_DAYS < CLOSE_DAYS)) {
core.warning(`Timing order looks odd: REMINDER(${REMINDER_DAYS}) < WARNING(${CLOSE_WARNING_DAYS}) < CLOSE(${CLOSE_DAYS}) expected.`);
}
const CLA_SIGN_URL = `https://cla-assistant.io/${context.repo.owner}/${context.repo.repo}`;
// Helper: Get CLA status for a commit
async function getClaStatus(headSha) {
// CLA-assistant uses the commit status API (not checks API)
const { data: statuses } = await github.rest.repos.getCombinedStatusForRef({
owner: context.repo.owner,
repo: context.repo.repo,
ref: headSha
});
const claStatus = statuses.statuses.find(
s => s.context === CLA_CHECK_NAME
);
if (claStatus) {
return {
found: true,
passed: claStatus.state === 'success',
state: claStatus.state,
description: claStatus.description
};
}
// Fallback: check the Checks API too
const { data: checkRuns } = await github.rest.checks.listForRef({
owner: context.repo.owner,
repo: context.repo.repo,
ref: headSha
});
const claCheck = checkRuns.check_runs.find(
check => check.name === CLA_CHECK_NAME
);
if (claCheck) {
return {
found: true,
passed: claCheck.conclusion === 'success',
state: claCheck.conclusion,
description: claCheck.output?.summary || ''
};
}
return { found: false, passed: false, state: 'unknown' };
}
// Helper: Check if bot already commented with a specific marker (paginated)
async function hasCommentWithMarker(prNumber, marker) {
// Use paginate to fetch ALL comments, not just first 100
const comments = await github.paginate(
github.rest.issues.listComments,
{
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
per_page: 100
}
);
return comments.some(c =>
c.user?.type === 'Bot' &&
c.body?.includes(marker)
);
}
// Helper: Days since a date
function daysSince(dateString) {
const date = new Date(dateString);
const now = new Date();
return Math.floor((now - date) / (1000 * 60 * 60 * 24));
}
// Determine which PRs to check
let prsToCheck = [];
if (context.eventName === 'status') {
// Status event from CLA-assistant - find PRs with this commit
const sha = context.payload.sha;
console.log(`Status event for SHA: ${sha}, context: ${context.payload.context}`);
// Search for open PRs with this head SHA (paginated)
const allPRs = await github.paginate(
github.rest.pulls.list,
{
owner: context.repo.owner,
repo: context.repo.repo,
state: 'open',
per_page: 100
}
);
prsToCheck = allPRs.filter(pr => pr.head.sha === sha).map(pr => pr.number);
if (prsToCheck.length === 0) {
console.log('No open PRs found with this SHA');
return;
}
} else if (context.eventName === 'pull_request_target') {
prsToCheck = [context.payload.pull_request.number];
} else if (context.eventName === 'workflow_dispatch' && context.payload.inputs?.pr_number) {
prsToCheck = [parseInt(context.payload.inputs.pr_number)];
} else {
// Scheduled run: check all open PRs (paginated to handle >100 PRs)
const openPRs = await github.paginate(
github.rest.pulls.list,
{
owner: context.repo.owner,
repo: context.repo.repo,
state: 'open',
per_page: 100
}
);
prsToCheck = openPRs.map(pr => pr.number);
}
console.log(`Checking ${prsToCheck.length} PR(s): ${prsToCheck.join(', ')}`);
for (const prNumber of prsToCheck) {
try {
// Get PR details
const { data: pr } = await github.rest.pulls.get({
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: prNumber
});
// Skip if PR is from a bot
if (pr.user.type === 'Bot') {
console.log(`PR #${prNumber}: Skipping bot PR`);
continue;
}
// Skip if PR is not open (closed/merged)
if (pr.state !== 'open') {
console.log(`PR #${prNumber}: Skipping non-open PR (state=${pr.state})`);
continue;
}
// Skip if PR already has cla: signed label (optimization for scheduled sweeps)
const currentLabels = pr.labels.map(l => l.name);
const knownPlatformPR = currentLabels.includes(LABEL_SIGNED) || currentLabels.includes(LABEL_PENDING);
// Skip listFiles if we've already labelled this PR (a previous run verified it touches platform code)
if (!knownPlatformPR) {
const PLATFORM_PATH = 'autogpt_platform/';
const prFiles = await github.paginate(
github.rest.pulls.listFiles,
{
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: prNumber,
per_page: 100
}
);
const touchesPlatform = prFiles.some(f => f.filename.startsWith(PLATFORM_PATH));
if (!touchesPlatform) {
console.log(`PR #${prNumber}: Skipping - doesn't touch ${PLATFORM_PATH}`);
continue;
}
}
const claStatus = await getClaStatus(pr.head.sha);
const hasPending = currentLabels.includes(LABEL_PENDING);
const hasSigned = currentLabels.includes(LABEL_SIGNED);
const prAgeDays = daysSince(pr.created_at);
console.log(`PR #${prNumber}: CLA ${claStatus.passed ? 'passed' : 'pending'} (${claStatus.state}), age: ${prAgeDays} days`);
if (claStatus.passed) {
// ✅ CLA signed - add signed label, remove pending
if (!hasSigned) {
await github.rest.issues.addLabels({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
labels: [LABEL_SIGNED]
});
console.log(`Added '${LABEL_SIGNED}' to PR #${prNumber}`);
}
if (hasPending) {
await github.rest.issues.removeLabel({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
name: LABEL_PENDING
});
console.log(`Removed '${LABEL_PENDING}' from PR #${prNumber}`);
}
} else {
// ⏳ CLA pending
// Add pending label if not present
if (!hasPending) {
await github.rest.issues.addLabels({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
labels: [LABEL_PENDING]
});
console.log(`Added '${LABEL_PENDING}' to PR #${prNumber}`);
}
if (hasSigned) {
await github.rest.issues.removeLabel({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
name: LABEL_SIGNED
});
console.log(`Removed '${LABEL_SIGNED}' from PR #${prNumber}`);
}
// Check if we need to send reminder or close
const REMINDER_MARKER = '<!-- cla-reminder -->';
const CLOSE_WARNING_MARKER = '<!-- cla-close-warning -->';
// 📢 Reminder after REMINDER_DAYS (but before warning window)
if (prAgeDays >= REMINDER_DAYS && prAgeDays < CLOSE_WARNING_DAYS) {
const hasReminder = await hasCommentWithMarker(prNumber, REMINDER_MARKER);
if (!hasReminder) {
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
body: `${REMINDER_MARKER}
👋 **Friendly reminder:** This PR is waiting on a signed CLA.
All contributors need to sign our Contributor License Agreement before we can merge this PR.
**➡️ [Sign the CLA here](${CLA_SIGN_URL}?pullRequest=${prNumber})**
<details>
<summary>Why do we need a CLA?</summary>
The CLA protects both you and the project by clarifying the terms under which your contribution is made. It's a one-time process — once signed, it covers all your future contributions.
</details>
<details>
<summary>Common issues</summary>
- **Email mismatch:** Make sure your Git commit email matches your GitHub account email
- **Merge commits:** If you merged \`dev\` into your branch, try rebasing instead: \`git rebase origin/dev && git push --force-with-lease\`
- **Multiple authors:** All commit authors need to sign, not just the PR author
</details>
If you have questions, just ask! 🙂`
});
console.log(`Posted reminder on PR #${prNumber}`);
}
}
// ⚠️ Close warning at CLOSE_WARNING_DAYS
if (prAgeDays >= CLOSE_WARNING_DAYS && prAgeDays < CLOSE_DAYS) {
const hasCloseWarning = await hasCommentWithMarker(prNumber, CLOSE_WARNING_MARKER);
if (!hasCloseWarning) {
const daysRemaining = CLOSE_DAYS - prAgeDays;
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
body: `${CLOSE_WARNING_MARKER}
⚠️ **This PR will be automatically closed in ${daysRemaining} day${daysRemaining === 1 ? '' : 's'}** if the CLA is not signed.
We haven't received a signed CLA from all contributors yet. Please sign it to keep this PR open:
**➡️ [Sign the CLA here](${CLA_SIGN_URL}?pullRequest=${prNumber})**
If you're unable to sign or have questions, please let us know — we're happy to help!`
});
console.log(`Posted close warning on PR #${prNumber}`);
}
}
// 🚪 Auto-close after CLOSE_DAYS
if (prAgeDays >= CLOSE_DAYS) {
const CLOSE_MARKER = '<!-- cla-auto-closed -->';
const OVERRIDE_LABEL = 'cla: override';
// Check for override label (maintainer wants to keep PR open)
if (currentLabels.includes(OVERRIDE_LABEL)) {
console.log(`PR #${prNumber}: Skipping close due to '${OVERRIDE_LABEL}' label`);
} else {
// Check if we already posted a close comment
const hasCloseComment = await hasCommentWithMarker(prNumber, CLOSE_MARKER);
if (!hasCloseComment) {
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
body: `${CLOSE_MARKER}
👋 Closing this PR due to unsigned CLA after ${CLOSE_DAYS} days.
Thank you for your contribution! If you'd still like to contribute:
1. [Sign the CLA](${CLA_SIGN_URL})
2. Re-open this PR or create a new one
> **Maintainers:** To reopen and exempt from future auto-close, add the \`cla: override\` label before reopening. Without it, the PR will not be re-closed automatically (a reopened PR is treated as a maintainer decision).
We appreciate your interest in AutoGPT and hope to see you back! 🚀`
});
await github.rest.pulls.update({
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: prNumber,
state: 'closed'
});
console.log(`Closed PR #${prNumber} due to unsigned CLA`);
} else {
console.log(`PR #${prNumber}: Already auto-closed previously, skipping (maintainer may have reopened)`);
}
}
}
}
} catch (error) {
console.error(`Error processing PR #${prNumber}: ${error.message}`);
}
}
console.log('CLA label sync complete!');

View File

@@ -32,10 +32,10 @@ router = APIRouter(
class UserRateLimitResponse(BaseModel):
user_id: str
user_email: Optional[str] = None
daily_token_limit: int
weekly_token_limit: int
daily_tokens_used: int
weekly_tokens_used: int
daily_cost_limit_microdollars: int
weekly_cost_limit_microdollars: int
daily_cost_used_microdollars: int
weekly_cost_used_microdollars: int
tier: SubscriptionTier
@@ -101,17 +101,19 @@ async def get_user_rate_limit(
logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)
daily_limit, weekly_limit, tier = await get_global_rate_limits(
resolved_id, config.daily_token_limit, config.weekly_token_limit
resolved_id,
config.daily_cost_limit_microdollars,
config.weekly_cost_limit_microdollars,
)
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier)
return UserRateLimitResponse(
user_id=resolved_id,
user_email=resolved_email,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_tokens_used=usage.daily.used,
weekly_tokens_used=usage.weekly.used,
daily_cost_limit_microdollars=daily_limit,
weekly_cost_limit_microdollars=weekly_limit,
daily_cost_used_microdollars=usage.daily.used,
weekly_cost_used_microdollars=usage.weekly.used,
tier=tier,
)
@@ -141,7 +143,9 @@ async def reset_user_rate_limit(
raise HTTPException(status_code=500, detail="Failed to reset usage") from e
daily_limit, weekly_limit, tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
user_id,
config.daily_cost_limit_microdollars,
config.weekly_cost_limit_microdollars,
)
usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier)
@@ -154,10 +158,10 @@ async def reset_user_rate_limit(
return UserRateLimitResponse(
user_id=user_id,
user_email=resolved_email,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_tokens_used=usage.daily.used,
weekly_tokens_used=usage.weekly.used,
daily_cost_limit_microdollars=daily_limit,
weekly_cost_limit_microdollars=weekly_limit,
daily_cost_used_microdollars=usage.daily.used,
weekly_cost_used_microdollars=usage.weekly.used,
tier=tier,
)

View File

@@ -85,10 +85,10 @@ def test_get_rate_limit(
data = response.json()
assert data["user_id"] == target_user_id
assert data["user_email"] == _TARGET_EMAIL
assert data["daily_token_limit"] == 2_500_000
assert data["weekly_token_limit"] == 12_500_000
assert data["daily_tokens_used"] == 500_000
assert data["weekly_tokens_used"] == 3_000_000
assert data["daily_cost_limit_microdollars"] == 2_500_000
assert data["weekly_cost_limit_microdollars"] == 12_500_000
assert data["daily_cost_used_microdollars"] == 500_000
assert data["weekly_cost_used_microdollars"] == 3_000_000
assert data["tier"] == "FREE"
configured_snapshot.assert_match(
@@ -117,7 +117,7 @@ def test_get_rate_limit_by_email(
data = response.json()
assert data["user_id"] == target_user_id
assert data["user_email"] == _TARGET_EMAIL
assert data["daily_token_limit"] == 2_500_000
assert data["daily_cost_limit_microdollars"] == 2_500_000
def test_get_rate_limit_by_email_not_found(
@@ -160,9 +160,9 @@ def test_reset_user_usage_daily_only(
assert response.status_code == 200
data = response.json()
assert data["daily_tokens_used"] == 0
assert data["daily_cost_used_microdollars"] == 0
# Weekly is untouched
assert data["weekly_tokens_used"] == 3_000_000
assert data["weekly_cost_used_microdollars"] == 3_000_000
assert data["tier"] == "FREE"
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)
@@ -192,8 +192,8 @@ def test_reset_user_usage_daily_and_weekly(
assert response.status_code == 200
data = response.json()
assert data["daily_tokens_used"] == 0
assert data["weekly_tokens_used"] == 0
assert data["daily_cost_used_microdollars"] == 0
assert data["weekly_cost_used_microdollars"] == 0
assert data["tier"] == "FREE"
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)

View File

@@ -2,15 +2,13 @@
import asyncio
import logging
import re
from collections.abc import AsyncGenerator
from typing import Annotated
from uuid import uuid4
from autogpt_libs import auth
from fastapi import APIRouter, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ConfigDict, Field, field_validator
from backend.copilot import service as chat_service
@@ -18,7 +16,6 @@ from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode
from backend.copilot.db import get_chat_messages_paginated
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.message_dedup import acquire_dedup_lock
from backend.copilot.model import (
ChatMessage,
ChatSession,
@@ -30,8 +27,14 @@ from backend.copilot.model import (
get_user_sessions,
update_session_title,
)
from backend.copilot.pending_message_helpers import (
QueuePendingMessageResponse,
is_turn_in_flight,
queue_pending_for_http,
)
from backend.copilot.pending_messages import peek_pending_messages
from backend.copilot.rate_limit import (
CoPilotUsageStatus,
CoPilotUsagePublic,
RateLimitExceeded,
acquire_reset_lock,
check_rate_limit,
@@ -76,7 +79,7 @@ from backend.copilot.tracking import track_user_message
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.redis_client import get_redis_async
from backend.data.understanding import get_business_understanding
from backend.data.workspace import get_or_create_workspace
from backend.data.workspace import build_files_block, resolve_workspace_files
from backend.util.exceptions import InsufficientBalanceError, NotFoundError
from backend.util.settings import Settings
@@ -86,10 +89,6 @@ logger = logging.getLogger(__name__)
config = ChatConfig()
_UUID_RE = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
)
async def _validate_and_get_session(
session_id: str,
@@ -152,6 +151,19 @@ class StreamChatRequest(BaseModel):
)
class PeekPendingMessagesResponse(BaseModel):
"""Response for the pending-message peek (GET) endpoint.
Returns a read-only view of the pending buffer — messages are NOT
consumed. The frontend uses this to restore the queued-message
indicator after a page refresh and to decide when to clear it once
a turn has ended.
"""
messages: list[str]
count: int
class CreateSessionRequest(BaseModel):
"""Request model for creating a new chat session.
@@ -463,22 +475,13 @@ async def get_session(
Supports cursor-based pagination via ``limit`` and ``before_sequence``.
When no pagination params are provided, returns the most recent messages.
Args:
session_id: The unique identifier for the desired chat session.
user_id: The authenticated user's ID.
limit: Maximum number of messages to return (1-200, default 50).
before_sequence: Return messages with sequence < this value (cursor).
Returns:
SessionDetailResponse: Details for the requested session, including
active_stream info and pagination metadata.
"""
page = await get_chat_messages_paginated(
session_id, limit, before_sequence, user_id=user_id
)
if page is None:
raise NotFoundError(f"Session {session_id} not found.")
messages = [
_strip_injected_context(message.model_dump()) for message in page.messages
]
@@ -489,10 +492,6 @@ async def get_session(
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_session:
active_stream_info = ActiveStreamInfo(
turn_id=active_session.turn_id,
@@ -537,23 +536,27 @@ async def get_session(
)
async def get_copilot_usage(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> CoPilotUsageStatus:
) -> CoPilotUsagePublic:
"""Get CoPilot usage status for the authenticated user.
Returns current token usage vs limits for daily and weekly windows.
Global defaults sourced from LaunchDarkly (falling back to config).
Includes the user's rate-limit tier.
Returns the percentage of the daily/weekly allowance used — not the
raw spend or cap — so clients cannot derive per-turn cost or platform
margins. Global defaults sourced from LaunchDarkly (falling back to
config). Includes the user's rate-limit tier.
"""
daily_limit, weekly_limit, tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
user_id,
config.daily_cost_limit_microdollars,
config.weekly_cost_limit_microdollars,
)
return await get_usage_status(
status = await get_usage_status(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_cost_limit=daily_limit,
weekly_cost_limit=weekly_limit,
rate_limit_reset_cost=config.rate_limit_reset_cost,
tier=tier,
)
return CoPilotUsagePublic.from_status(status)
class RateLimitResetResponse(BaseModel):
@@ -562,7 +565,9 @@ class RateLimitResetResponse(BaseModel):
success: bool
credits_charged: int = Field(description="Credits charged (in cents)")
remaining_balance: int = Field(description="Credit balance after charge (in cents)")
usage: CoPilotUsageStatus = Field(description="Updated usage status after reset")
usage: CoPilotUsagePublic = Field(
description="Updated usage status after reset (percentages only)"
)
@router.post(
@@ -586,7 +591,7 @@ async def reset_copilot_usage(
) -> RateLimitResetResponse:
"""Reset the daily CoPilot rate limit by spending credits.
Allows users who have hit their daily token limit to spend credits
Allows users who have hit their daily cost limit to spend credits
to reset their daily usage counter and continue working.
Returns 400 if the feature is disabled or the user is not over the limit.
Returns 402 if the user has insufficient credits.
@@ -605,7 +610,9 @@ async def reset_copilot_usage(
)
daily_limit, weekly_limit, tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
user_id,
config.daily_cost_limit_microdollars,
config.weekly_cost_limit_microdollars,
)
if daily_limit <= 0:
@@ -642,8 +649,8 @@ async def reset_copilot_usage(
# used for limit checks, not returned to the client.)
usage_status = await get_usage_status(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_cost_limit=daily_limit,
weekly_cost_limit=weekly_limit,
tier=tier,
)
if daily_limit > 0 and usage_status.daily.used < daily_limit:
@@ -678,7 +685,7 @@ async def reset_copilot_usage(
# Reset daily usage in Redis. If this fails, refund the credits
# so the user is not charged for a service they did not receive.
if not await reset_daily_usage(user_id, daily_token_limit=daily_limit):
if not await reset_daily_usage(user_id, daily_cost_limit=daily_limit):
# Compensate: refund the charged credits.
refunded = False
try:
@@ -714,11 +721,11 @@ async def reset_copilot_usage(
finally:
await release_reset_lock(user_id)
# Return updated usage status.
# Return updated usage status (public schema — percentages only).
updated_usage = await get_usage_status(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_cost_limit=daily_limit,
weekly_cost_limit=weekly_limit,
rate_limit_reset_cost=config.rate_limit_reset_cost,
tier=tier,
)
@@ -727,7 +734,7 @@ async def reset_copilot_usage(
success=True,
credits_charged=cost,
remaining_balance=remaining,
usage=updated_usage,
usage=CoPilotUsagePublic.from_status(updated_usage),
)
@@ -778,36 +785,52 @@ async def cancel_session_task(
@router.post(
"/sessions/{session_id}/stream",
responses={
202: {
"model": QueuePendingMessageResponse,
"description": (
"Session has a turn in flight — message queued into the pending "
"buffer and will be picked up between tool-call rounds by the "
"executor currently processing the turn."
),
},
404: {"description": "Session not found or access denied"},
429: {"description": "Cost rate-limit or call-frequency cap exceeded"},
},
)
async def stream_chat_post(
session_id: str,
request: StreamChatRequest,
user_id: str = Security(auth.get_user_id),
):
"""
Stream chat responses for a session (POST with context support).
"""Start a new turn OR queue a follow-up — decided server-side.
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
- Text fragments as they are generated
- Tool call UI elements (if invoked)
- Tool execution results
- **Session idle**: starts a turn. Returns an SSE stream (``text/event-stream``)
with Vercel AI SDK chunks (text fragments, tool-call UI, tool results).
The generation runs in a background task that survives client disconnects;
reconnect via ``GET /sessions/{session_id}/stream`` to resume.
The AI generation runs in a background task that continues even if the client disconnects.
All chunks are written to a per-turn Redis stream for reconnection support. If the client
disconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.
- **Session has a turn in flight**: pushes the message into the per-session
pending buffer and returns ``202 application/json`` with
``QueuePendingMessageResponse``. The executor running the current turn
drains the buffer between tool-call rounds (baseline) or at the start of
the next turn (SDK). Clients should detect the 202 and surface the
message as a queued-chip in the UI.
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
session_id: The chat session identifier.
request: Request body with message, is_user_message, and optional context.
user_id: Authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks.
"""
import asyncio
import time
stream_start_time = time.perf_counter()
# Wall-clock arrival time, propagated to the executor so the turn-start
# drain can order pending messages relative to this request (pending
# pushed BEFORE this instant were typed earlier; pending pushed AFTER
# are race-path follow-ups typed while /stream was still processing).
request_arrival_at = time.time()
log_meta = {"component": "ChatStream", "session_id": session_id, "user_id": user_id}
logger.info(
@@ -816,6 +839,26 @@ async def stream_chat_post(
extra={"json_fields": log_meta},
)
await _validate_and_get_session(session_id, user_id)
# Self-defensive queue-fallback: if a turn is already running, don't race
# it on the cluster lock — drop the message into the pending buffer and
# return 202 so the caller can render a chip. Both UI chips and autopilot
# block follow-ups route through this path; keeping the decision on the
# server means every caller gets uniform behaviour.
if (
request.is_user_message
and request.message
and await is_turn_in_flight(session_id)
):
response = await queue_pending_for_http(
session_id=session_id,
user_id=user_id,
message=request.message,
context=request.context,
file_ids=request.file_ids,
)
return JSONResponse(status_code=202, content=response.model_dump())
logger.info(
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
extra={
@@ -826,18 +869,20 @@ async def stream_chat_post(
},
)
# Pre-turn rate limit check (token-based).
# Pre-turn rate limit check (cost-based, microdollars).
# check_rate_limit short-circuits internally when both limits are 0.
# Global defaults sourced from LaunchDarkly, falling back to config.
if user_id:
try:
daily_limit, weekly_limit, _ = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
user_id,
config.daily_cost_limit_microdollars,
config.weekly_cost_limit_microdollars,
)
await check_rate_limit(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_cost_limit=daily_limit,
weekly_cost_limit=weekly_limit,
)
except RateLimitExceeded as e:
raise HTTPException(status_code=429, detail=str(e)) from e
@@ -846,89 +891,41 @@ async def stream_chat_post(
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
sanitized_file_ids: list[str] | None = None
# Capture the original message text BEFORE any mutation (attachment enrichment)
# so the idempotency hash is stable across retries.
original_message = request.message
if request.file_ids and user_id:
# Filter to valid UUIDs only to prevent DB abuse
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
if valid_ids:
workspace = await get_or_create_workspace(user_id)
# Batch query instead of N+1
files = await UserWorkspaceFile.prisma().find_many(
where={
"id": {"in": valid_ids},
"workspaceId": workspace.id,
"isDeleted": False,
}
)
# Only keep IDs that actually exist in the user's workspace
sanitized_file_ids = [wf.id for wf in files] or None
file_lines: list[str] = [
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
for wf in files
]
if file_lines:
files_block = (
"\n\n[Attached files]\n"
+ "\n".join(file_lines)
+ "\nUse read_workspace_file with the file_id to access file contents."
)
request.message += files_block
# ── Idempotency guard ────────────────────────────────────────────────────
# Blocks duplicate executor tasks from concurrent/retried POSTs.
# See backend/copilot/message_dedup.py for the full lifecycle description.
dedup_lock = None
if request.is_user_message:
dedup_lock = await acquire_dedup_lock(
session_id, original_message, sanitized_file_ids
)
if dedup_lock is None and (original_message or sanitized_file_ids):
async def _empty_sse() -> AsyncGenerator[str, None]:
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return StreamingResponse(
_empty_sse(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
"Connection": "keep-alive",
"x-vercel-ai-ui-message-stream": "v1",
},
)
if request.file_ids:
files = await resolve_workspace_files(user_id, request.file_ids)
sanitized_file_ids = [wf.id for wf in files] or None
request.message += build_files_block(files)
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message re-fetches inside a lock to prevent
# message loss from concurrent requests.
#
# If any of these operations raises, release the dedup lock before propagating
# so subsequent retries are not blocked for 30 s.
try:
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
if request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
# saved yet. append_and_save_message returns None when a duplicate is
# detected — in that case skip enqueue to avoid processing the message twice.
is_duplicate_message = False
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
is_duplicate_message = (
await append_and_save_message(session_id, message)
logger.info(f"[STREAM] User message saved for session {session_id}")
) is None
logger.info(f"[STREAM] User message saved for session {session_id}")
if not is_duplicate_message and request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
# Create a task in the stream registry for reconnection support
# Create a task in the stream registry for reconnection support.
# For duplicate messages, skip create_session entirely so the infra-retry
# client subscribes to the *existing* turn's Redis stream and receives the
# in-progress executor output rather than an empty stream.
turn_id = ""
if not is_duplicate_message:
turn_id = str(uuid4())
log_meta["turn_id"] = turn_id
session_create_start = time.perf_counter()
await stream_registry.create_session(
session_id=session_id,
@@ -946,7 +943,6 @@ async def stream_chat_post(
}
},
)
await enqueue_copilot_turn(
session_id=session_id,
user_id=user_id,
@@ -957,11 +953,12 @@ async def stream_chat_post(
file_ids=sanitized_file_ids,
mode=request.mode,
model=request.model,
request_arrival_at=request_arrival_at,
)
else:
logger.info(
f"[STREAM] Duplicate message detected for session {session_id}, skipping enqueue"
)
except Exception:
if dedup_lock:
await dedup_lock.release()
raise
setup_time = (time.perf_counter() - stream_start_time) * 1000
logger.info(
@@ -985,12 +982,6 @@ async def stream_chat_post(
subscriber_queue = None
first_chunk_yielded = False
chunks_yielded = 0
# True for every exit path except GeneratorExit (client disconnect).
# On disconnect the backend turn is still running — releasing the lock
# there would reopen the infra-retry duplicate window. The 30 s TTL
# is the fallback. All other exits (normal finish, early return, error)
# should release so the user can re-send the same message.
release_dedup_lock_on_exit = True
try:
# Subscribe from the position we captured before enqueuing
# This avoids replaying old messages while catching all new ones
@@ -1002,7 +993,7 @@ async def stream_chat_post(
if subscriber_queue is None:
yield StreamFinish().to_sse()
return # finally releases dedup_lock
return
# Read from the subscriber queue and yield to SSE
logger.info(
@@ -1044,7 +1035,7 @@ async def stream_chat_post(
}
},
)
break # finally releases dedup_lock
break
except asyncio.TimeoutError:
yield StreamHeartbeat().to_sse()
@@ -1060,7 +1051,6 @@ async def stream_chat_post(
}
},
)
release_dedup_lock_on_exit = False
except Exception as e:
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
logger.error(
@@ -1075,10 +1065,7 @@ async def stream_chat_post(
code="stream_error",
).to_sse()
yield StreamFinish().to_sse()
# finally releases dedup_lock
finally:
if dedup_lock and release_dedup_lock_on_exit:
await dedup_lock.release()
# Unsubscribe when client disconnects or stream ends
if subscriber_queue is not None:
try:
@@ -1117,6 +1104,31 @@ async def stream_chat_post(
)
@router.get(
"/sessions/{session_id}/messages/pending",
response_model=PeekPendingMessagesResponse,
responses={
404: {"description": "Session not found or access denied"},
},
)
async def get_pending_messages(
session_id: str,
user_id: str = Security(auth.get_user_id),
):
"""Peek at the pending-message buffer without consuming it.
Returns the current contents of the session's pending message buffer
so the frontend can restore the queued-message indicator after a page
refresh and clear it correctly once a turn drains the buffer.
"""
await _validate_and_get_session(session_id, user_id)
pending = await peek_pending_messages(session_id)
return PeekPendingMessagesResponse(
messages=[m.content for m in pending],
count=len(pending),
)
@router.get(
"/sessions/{session_id}/stream",
)

View File

@@ -133,21 +133,12 @@ def test_stream_chat_rejects_too_many_file_ids():
assert response.status_code == 422
def _mock_stream_internals(
mocker: pytest_mock.MockerFixture,
*,
redis_set_returns: object = True,
):
def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
"""Mock the async internals of stream_chat_post so tests can exercise
validation and enrichment logic without needing Redis/RabbitMQ.
Args:
redis_set_returns: Value returned by the mocked Redis ``set`` call.
``True`` (default) simulates a fresh key (new message);
``None`` simulates a collision (duplicate blocked).
validation and enrichment logic without needing RabbitMQ.
Returns:
A namespace with ``redis``, ``save``, and ``enqueue`` mock objects so
A namespace with ``save`` and ``enqueue`` mock objects so
callers can make additional assertions about side-effects.
"""
import types
@@ -158,7 +149,7 @@ def _mock_stream_internals(
)
mock_save = mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
return_value=MagicMock(), # non-None = message was saved (not a duplicate)
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = mocker.AsyncMock(return_value=None)
@@ -174,15 +165,9 @@ def _mock_stream_internals(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
mock_redis = AsyncMock()
mock_redis.set = AsyncMock(return_value=redis_set_returns)
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=AsyncMock,
return_value=mock_redis,
return types.SimpleNamespace(
save=mock_save, enqueue=mock_enqueue, registry=mock_registry
)
ns = types.SimpleNamespace(redis=mock_redis, save=mock_save, enqueue=mock_enqueue)
return ns
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
@@ -190,7 +175,7 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
_mock_stream_internals(mocker)
# Patch workspace lookup as imported by the routes module
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
"backend.data.workspace.get_or_create_workspace",
return_value=type("W", (), {"id": "ws-1"})(),
)
mock_prisma = mocker.MagicMock()
@@ -211,6 +196,29 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
assert response.status_code == 200
# ─── Duplicate message dedup ──────────────────────────────────────────
def test_stream_chat_skips_enqueue_for_duplicate_message(
mocker: pytest_mock.MockerFixture,
):
"""When append_and_save_message returns None (duplicate detected),
enqueue_copilot_turn and stream_registry.create_session must NOT be called
to avoid double-processing and to prevent overwriting the active stream's
turn_id in Redis (which would cause reconnecting clients to miss the response)."""
mocks = _mock_stream_internals(mocker)
# Override save to return None — signalling a duplicate
mocks.save.return_value = None
response = client.post(
"/sessions/sess-1/stream",
json={"message": "hello"},
)
assert response.status_code == 200
mocks.enqueue.assert_not_called()
mocks.registry.create_session.assert_not_called()
# ─── UUID format filtering ─────────────────────────────────────────────
@@ -219,7 +227,7 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockerFixture):
and NOT passed to the database query."""
_mock_stream_internals(mocker)
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
"backend.data.workspace.get_or_create_workspace",
return_value=type("W", (), {"id": "ws-1"})(),
)
@@ -257,7 +265,7 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockerFixture):
"""The batch query should scope to the user's workspace."""
_mock_stream_internals(mocker)
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
"backend.data.workspace.get_or_create_workspace",
return_value=type("W", (), {"id": "my-workspace-id"})(),
)
@@ -288,8 +296,8 @@ def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockerF
_mock_stream_internals(mocker)
# Ensure the rate-limit branch is entered by setting a non-zero limit.
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000)
mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000)
mocker.patch(
"backend.api.features.chat.routes.check_rate_limit",
side_effect=RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1)),
@@ -310,8 +318,8 @@ def test_stream_chat_returns_429_on_weekly_rate_limit(
from backend.copilot.rate_limit import RateLimitExceeded
_mock_stream_internals(mocker)
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000)
mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000)
resets_at = datetime.now(UTC) + timedelta(days=3)
mocker.patch(
"backend.api.features.chat.routes.check_rate_limit",
@@ -333,8 +341,8 @@ def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockerFixture):
from backend.copilot.rate_limit import RateLimitExceeded
_mock_stream_internals(mocker)
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000)
mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000)
mocker.patch(
"backend.api.features.chat.routes.check_rate_limit",
side_effect=RateLimitExceeded(
@@ -394,23 +402,33 @@ def test_usage_returns_daily_and_weekly(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""GET /usage returns daily and weekly usage."""
"""GET /usage returns percentages for daily and weekly windows only.
The raw used/limit microdollar values MUST NOT leak — clients should not
be able to derive per-turn cost or platform margins from the public API.
"""
mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000)
mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000)
response = client.get("/usage")
assert response.status_code == 200
data = response.json()
assert data["daily"]["used"] == 500
assert data["weekly"]["used"] == 2000
# 500 / 10000 = 5%, 2000 / 50000 = 4%
assert data["daily"]["percent_used"] == 5.0
assert data["weekly"]["percent_used"] == 4.0
# Raw spend/limit must not be exposed.
assert "used" not in data["daily"]
assert "limit" not in data["daily"]
assert "used" not in data["weekly"]
assert "limit" not in data["weekly"]
mock_get.assert_called_once_with(
user_id=test_user_id,
daily_token_limit=10000,
weekly_token_limit=50000,
daily_cost_limit=10000,
weekly_cost_limit=50000,
rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost,
tier=SubscriptionTier.FREE,
)
@@ -430,8 +448,8 @@ def test_usage_uses_config_limits(
assert response.status_code == 200
mock_get.assert_called_once_with(
user_id=test_user_id,
daily_token_limit=99999,
weekly_token_limit=77777,
daily_cost_limit=99999,
weekly_cost_limit=77777,
rate_limit_reset_cost=500,
tier=SubscriptionTier.FREE,
)
@@ -609,6 +627,246 @@ class TestStreamChatRequestModeValidation:
assert req.mode is None
# ─── POST /stream queue-fallback (when a turn is already in flight) ──
def _mock_stream_queue_internals(
mocker: pytest_mock.MockerFixture,
*,
session_exists: bool = True,
turn_in_flight: bool = True,
call_count: int = 1,
):
"""Mock dependencies for the POST /stream queue-fallback path.
When ``turn_in_flight`` is True the handler takes the 202 queue branch.
"""
if session_exists:
mock_session = mocker.MagicMock()
mock_session.id = "sess-1"
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
new_callable=AsyncMock,
return_value=mock_session,
)
else:
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
side_effect=fastapi.HTTPException(
status_code=404, detail="Session not found."
),
)
mocker.patch(
"backend.api.features.chat.routes.is_turn_in_flight",
new_callable=AsyncMock,
return_value=turn_in_flight,
)
mocker.patch(
"backend.api.features.chat.routes.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(0, 0, None),
)
mocker.patch(
"backend.api.features.chat.routes.check_rate_limit",
new_callable=AsyncMock,
return_value=None,
)
mocker.patch(
"backend.copilot.pending_message_helpers.get_redis_async",
new_callable=AsyncMock,
return_value=mocker.MagicMock(),
)
mocker.patch(
"backend.copilot.pending_message_helpers.incr_with_ttl",
new_callable=AsyncMock,
return_value=call_count,
)
mocker.patch(
"backend.copilot.pending_message_helpers.push_pending_message",
new_callable=AsyncMock,
return_value=1,
)
# queue_user_message re-runs is_turn_in_flight via the helper module —
# stub that path out too so we don't need a fake stream_registry.
mocker.patch(
"backend.copilot.pending_message_helpers.get_active_session_meta",
new_callable=AsyncMock,
return_value=None,
)
def test_stream_queue_returns_202_when_turn_in_flight(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Happy path: POST /stream to a session with a live turn → 202 queue."""
_mock_stream_queue_internals(mocker)
response = client.post(
"/sessions/sess-1/stream",
json={"message": "follow-up", "is_user_message": True},
)
assert response.status_code == 202
data = response.json()
assert data["buffer_length"] == 1
assert "turn_in_flight" in data
def test_stream_queue_session_not_found_returns_404(
mocker: pytest_mock.MockerFixture,
) -> None:
"""If the session doesn't exist or belong to the user, returns 404."""
_mock_stream_queue_internals(mocker, session_exists=False)
response = client.post(
"/sessions/bad-sess/stream",
json={"message": "hi", "is_user_message": True},
)
assert response.status_code == 404
def test_stream_queue_call_frequency_limit_returns_429(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Per-user call-frequency cap rejects rapid-fire queued pushes."""
from backend.copilot.pending_message_helpers import PENDING_CALL_LIMIT
_mock_stream_queue_internals(mocker, call_count=PENDING_CALL_LIMIT + 1)
response = client.post(
"/sessions/sess-1/stream",
json={"message": "hi", "is_user_message": True},
)
assert response.status_code == 429
assert "Too many queued message requests this minute" in response.json()["detail"]
def test_stream_queue_converts_context_dict_to_pending_context(
mocker: pytest_mock.MockerFixture,
) -> None:
"""StreamChatRequest.context is a raw dict; must be coerced to the
typed PendingMessageContext before being pushed onto the buffer."""
_mock_stream_queue_internals(mocker)
queue_spy = mocker.patch(
"backend.copilot.pending_message_helpers.queue_user_message",
new_callable=AsyncMock,
)
from backend.copilot.pending_message_helpers import QueuePendingMessageResponse
queue_spy.return_value = QueuePendingMessageResponse(
buffer_length=1, max_buffer_length=10, turn_in_flight=True
)
response = client.post(
"/sessions/sess-1/stream",
json={
"message": "hi",
"is_user_message": True,
"context": {"url": "https://example.test", "content": "body"},
},
)
assert response.status_code == 202
queue_spy.assert_awaited_once()
kwargs = queue_spy.await_args.kwargs
from backend.copilot.pending_messages import PendingMessageContext
assert isinstance(kwargs["context"], PendingMessageContext)
assert kwargs["context"].url == "https://example.test"
assert kwargs["context"].content == "body"
def test_stream_queue_passes_none_context_when_omitted(
mocker: pytest_mock.MockerFixture,
) -> None:
"""When request.context is omitted, the queue call receives context=None."""
_mock_stream_queue_internals(mocker)
queue_spy = mocker.patch(
"backend.copilot.pending_message_helpers.queue_user_message",
new_callable=AsyncMock,
)
from backend.copilot.pending_message_helpers import QueuePendingMessageResponse
queue_spy.return_value = QueuePendingMessageResponse(
buffer_length=1, max_buffer_length=10, turn_in_flight=True
)
response = client.post(
"/sessions/sess-1/stream",
json={"message": "hi", "is_user_message": True},
)
assert response.status_code == 202
queue_spy.assert_awaited_once()
assert queue_spy.await_args.kwargs["context"] is None
# ─── get_pending_messages (GET /sessions/{session_id}/messages/pending) ─────
def test_get_pending_messages_returns_200_with_empty_buffer(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Happy path: no pending messages returns 200 with empty list."""
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
new_callable=AsyncMock,
return_value=mocker.MagicMock(),
)
mocker.patch(
"backend.api.features.chat.routes.peek_pending_messages",
new_callable=AsyncMock,
return_value=[],
)
response = client.get("/sessions/sess-1/messages/pending")
assert response.status_code == 200
data = response.json()
assert data["messages"] == []
assert data["count"] == 0
def test_get_pending_messages_returns_queued_messages(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Returns pending messages from buffer without consuming them."""
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
new_callable=AsyncMock,
return_value=mocker.MagicMock(),
)
mocker.patch(
"backend.api.features.chat.routes.peek_pending_messages",
new_callable=AsyncMock,
return_value=[
MagicMock(content="first message"),
MagicMock(content="second message"),
],
)
response = client.get("/sessions/sess-1/messages/pending")
assert response.status_code == 200
data = response.json()
assert data["count"] == 2
assert data["messages"] == ["first message", "second message"]
def test_get_pending_messages_session_not_found_returns_404(
mocker: pytest_mock.MockerFixture,
) -> None:
"""If session does not exist or belongs to another user, returns 404."""
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
side_effect=fastapi.HTTPException(status_code=404, detail="Session not found."),
)
response = client.get("/sessions/bad-sess/messages/pending")
assert response.status_code == 404
class TestStripInjectedContext:
"""Unit tests for `_strip_injected_context` — the GET-side helper that
hides the server-injected `<user_context>` block from API responses.
@@ -706,237 +964,6 @@ class TestStripInjectedContext:
assert result["content"] == "hello"
# ─── Idempotency / duplicate-POST guard ──────────────────────────────
def test_stream_chat_blocks_duplicate_post_returns_empty_sse(
mocker: pytest_mock.MockerFixture,
) -> None:
"""A second POST with the same message within the 30-s window must return
an empty SSE stream (StreamFinish + [DONE]) so the frontend marks the
turn complete without creating a ghost response."""
# redis_set_returns=None simulates a collision: the NX key already exists.
ns = _mock_stream_internals(mocker, redis_set_returns=None)
response = client.post(
"/sessions/sess-dup/stream",
json={"message": "duplicate message", "is_user_message": True},
)
assert response.status_code == 200
body = response.text
# The response must contain StreamFinish (type=finish) and the SSE [DONE] terminator.
assert '"finish"' in body
assert "[DONE]" in body
# The empty SSE response must include the AI SDK protocol header so the
# frontend treats it as a valid stream and marks the turn complete.
assert response.headers.get("x-vercel-ai-ui-message-stream") == "v1"
# The duplicate guard must prevent save/enqueue side effects.
ns.save.assert_not_called()
ns.enqueue.assert_not_called()
def test_stream_chat_first_post_proceeds_normally(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The first POST (Redis NX key set successfully) must proceed through the
normal streaming path — no early return."""
ns = _mock_stream_internals(mocker, redis_set_returns=True)
response = client.post(
"/sessions/sess-new/stream",
json={"message": "first message", "is_user_message": True},
)
assert response.status_code == 200
# Redis set must have been called once with the NX flag.
ns.redis.set.assert_called_once()
call_kwargs = ns.redis.set.call_args
assert call_kwargs.kwargs.get("nx") is True
def test_stream_chat_dedup_skipped_for_non_user_messages(
mocker: pytest_mock.MockerFixture,
) -> None:
"""System/assistant messages (is_user_message=False) bypass the dedup
guard — they are injected programmatically and must always be processed."""
ns = _mock_stream_internals(mocker, redis_set_returns=None)
response = client.post(
"/sessions/sess-sys/stream",
json={"message": "system context", "is_user_message": False},
)
# Even though redis_set_returns=None (would block a user message),
# the endpoint must proceed because is_user_message=False.
assert response.status_code == 200
ns.redis.set.assert_not_called()
def test_stream_chat_dedup_hash_uses_original_message_not_mutated(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The dedup hash must be computed from the original request message,
not the mutated version that has the [Attached files] block appended.
A file_id is sent so the route actually appends the [Attached files] block,
exercising the mutation path — the hash must still match the original text."""
import hashlib
ns = _mock_stream_internals(mocker, redis_set_returns=True)
file_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
# Mock workspace + prisma so the attachment block is actually appended.
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
return_value=type("W", (), {"id": "ws-1"})(),
)
fake_file = type(
"F",
(),
{
"id": file_id,
"name": "doc.pdf",
"mimeType": "application/pdf",
"sizeBytes": 1024,
},
)()
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[fake_file])
mocker.patch(
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
response = client.post(
"/sessions/sess-hash/stream",
json={
"message": "plain message",
"is_user_message": True,
"file_ids": [file_id],
},
)
assert response.status_code == 200
ns.redis.set.assert_called_once()
call_args = ns.redis.set.call_args
dedup_key = call_args.args[0]
# Hash must use the original message + sorted file IDs, not the mutated text.
expected_hash = hashlib.sha256(
f"sess-hash:plain message:{file_id}".encode()
).hexdigest()[:16]
expected_key = f"chat:msg_dedup:sess-hash:{expected_hash}"
assert dedup_key == expected_key, (
f"Dedup key {dedup_key!r} does not match expected {expected_key!r}"
"hash may be using mutated message or wrong inputs"
)
def test_stream_chat_dedup_key_released_after_stream_finish(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The dedup Redis key must be deleted after the turn completes (when
subscriber_queue is None the route yields StreamFinish immediately and
should release the key so the user can re-send the same message)."""
from unittest.mock import AsyncMock as _AsyncMock
# Set up all internals manually so we can control subscribe_to_session.
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.enqueue_copilot_turn",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = _AsyncMock(return_value=None)
# None → early-finish path: StreamFinish yielded immediately, dedup key released.
mock_registry.subscribe_to_session = _AsyncMock(return_value=None)
mocker.patch(
"backend.api.features.chat.routes.stream_registry",
mock_registry,
)
mock_redis = mocker.AsyncMock()
mock_redis.set = _AsyncMock(return_value=True)
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=_AsyncMock,
return_value=mock_redis,
)
response = client.post(
"/sessions/sess-finish/stream",
json={"message": "hello", "is_user_message": True},
)
assert response.status_code == 200
body = response.text
assert '"finish"' in body
# The dedup key must be released so intentional re-sends are allowed.
mock_redis.delete.assert_called_once()
def test_stream_chat_dedup_key_released_even_when_redis_delete_raises(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The route must not crash when the dedup Redis delete fails on the
subscriber_queue-is-None early-finish path (except Exception: pass)."""
from unittest.mock import AsyncMock as _AsyncMock
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.enqueue_copilot_turn",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = _AsyncMock(return_value=None)
mock_registry.subscribe_to_session = _AsyncMock(return_value=None)
mocker.patch(
"backend.api.features.chat.routes.stream_registry",
mock_registry,
)
mock_redis = mocker.AsyncMock()
mock_redis.set = _AsyncMock(return_value=True)
# Make the delete raise so the except-pass branch is exercised.
mock_redis.delete = _AsyncMock(side_effect=RuntimeError("redis gone"))
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=_AsyncMock,
return_value=mock_redis,
)
# Should not raise even though delete fails.
response = client.post(
"/sessions/sess-finish-err/stream",
json={"message": "hello", "is_user_message": True},
)
assert response.status_code == 200
assert '"finish"' in response.text
# delete must have been attempted — the except-pass branch silenced the error.
mock_redis.delete.assert_called_once()
# ─── DELETE /sessions/{id}/stream — disconnect listeners ──────────────
@@ -980,3 +1007,59 @@ def test_disconnect_stream_returns_404_when_session_missing(
assert response.status_code == 404
mock_disconnect.assert_not_awaited()
# ─── GET /sessions/{session_id} — backward pagination ─────────────────────────
def _make_paginated_messages(
mocker: pytest_mock.MockerFixture, *, has_more: bool = False
):
"""Return a mock PaginatedMessages and configure the DB patch."""
from datetime import UTC, datetime
from backend.copilot.db import PaginatedMessages
from backend.copilot.model import ChatMessage, ChatSessionInfo, ChatSessionMetadata
now = datetime.now(UTC)
session_info = ChatSessionInfo(
session_id="sess-1",
user_id=TEST_USER_ID,
usage=[],
started_at=now,
updated_at=now,
metadata=ChatSessionMetadata(),
)
page = PaginatedMessages(
messages=[ChatMessage(role="user", content="hello", sequence=0)],
has_more=has_more,
oldest_sequence=0,
session=session_info,
)
mock_paginate = mocker.patch(
"backend.api.features.chat.routes.get_chat_messages_paginated",
new_callable=AsyncMock,
return_value=page,
)
return page, mock_paginate
def test_get_session_returns_backward_paginated(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""All sessions use backward (newest-first) pagination."""
_make_paginated_messages(mocker)
mocker.patch(
"backend.api.features.chat.routes.stream_registry.get_active_session",
new_callable=AsyncMock,
return_value=(None, None),
)
response = client.get("/sessions/sess-1")
assert response.status_code == 200
data = response.json()
assert data["oldest_sequence"] == 0
assert "forward_paginated" not in data
assert "newest_sequence" not in data

View File

@@ -12,6 +12,7 @@ import prisma.models
import backend.api.features.library.model as library_model
import backend.data.graph as graph_db
from backend.api.features.library.db import _fetch_schedule_info
from backend.data.graph import GraphModel, GraphSettings
from backend.data.includes import library_agent_include
from backend.util.exceptions import NotFoundError
@@ -117,4 +118,5 @@ async def add_graph_to_library(
f"for store listing version #{store_listing_version_id} "
f"to library for user #{user_id}"
)
return library_model.LibraryAgent.from_db(added_agent)
schedule_info = await _fetch_schedule_info(user_id, graph_id=graph_model.id)
return library_model.LibraryAgent.from_db(added_agent, schedule_info=schedule_info)

View File

@@ -21,13 +21,17 @@ async def test_add_graph_to_library_create_new_agent() -> None:
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
return_value=converted_agent,
) as mock_from_db,
patch(
"backend.api.features.library._add_to_library._fetch_schedule_info",
new=AsyncMock(return_value={}),
),
):
mock_prisma.return_value.create = AsyncMock(return_value=created_agent)
result = await add_graph_to_library("slv-id", graph_model, "user-id")
assert result is converted_agent
mock_from_db.assert_called_once_with(created_agent)
mock_from_db.assert_called_once_with(created_agent, schedule_info={})
# Verify create was called with correct data
create_call = mock_prisma.return_value.create.call_args
create_data = create_call.kwargs["data"]
@@ -54,6 +58,10 @@ async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
return_value=converted_agent,
) as mock_from_db,
patch(
"backend.api.features.library._add_to_library._fetch_schedule_info",
new=AsyncMock(return_value={}),
),
):
mock_prisma.return_value.create = AsyncMock(
side_effect=prisma.errors.UniqueViolationError(
@@ -65,7 +73,7 @@ async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
result = await add_graph_to_library("slv-id", graph_model, "user-id")
assert result is converted_agent
mock_from_db.assert_called_once_with(updated_agent)
mock_from_db.assert_called_once_with(updated_agent, schedule_info={})
# Verify update was called with correct where and data
update_call = mock_prisma.return_value.update.call_args
assert update_call.kwargs["where"] == {

View File

@@ -1,6 +1,7 @@
import asyncio
import itertools
import logging
from datetime import datetime, timezone
from typing import Literal, Optional
import fastapi
@@ -43,6 +44,65 @@ config = Config()
integration_creds_manager = IntegrationCredentialsManager()
async def _fetch_execution_counts(user_id: str, graph_ids: list[str]) -> dict[str, int]:
"""Fetch execution counts per graph in a single batched query."""
if not graph_ids:
return {}
rows = await prisma.models.AgentGraphExecution.prisma().group_by(
by=["agentGraphId"],
where={
"userId": user_id,
"agentGraphId": {"in": graph_ids},
"isDeleted": False,
},
count=True,
)
return {
row["agentGraphId"]: int((row.get("_count") or {}).get("_all") or 0)
for row in rows
}
async def _fetch_schedule_info(
user_id: str, graph_id: Optional[str] = None
) -> dict[str, str]:
"""Fetch a map of graph_id → earliest next_run_time ISO string.
When `graph_id` is provided, the scheduler query is narrowed to that graph,
which is cheaper for single-agent lookups (detail page, post-update, etc.).
"""
try:
scheduler_client = get_scheduler_client()
schedules = await scheduler_client.get_execution_schedules(
graph_id=graph_id,
user_id=user_id,
)
earliest: dict[str, tuple[datetime, str]] = {}
for s in schedules:
parsed = _parse_iso_datetime(s.next_run_time)
if parsed is None:
continue
current = earliest.get(s.graph_id)
if current is None or parsed < current[0]:
earliest[s.graph_id] = (parsed, s.next_run_time)
return {graph_id: iso for graph_id, (_, iso) in earliest.items()}
except Exception:
logger.warning("Failed to fetch schedules for library agents", exc_info=True)
return {}
def _parse_iso_datetime(value: str) -> Optional[datetime]:
"""Parse an ISO 8601 datetime, tolerating `Z` and naive forms (assumed UTC)."""
try:
parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError:
logger.warning("Failed to parse schedule next_run_time: %s", value)
return None
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
return parsed
async def list_library_agents(
user_id: str,
search_term: Optional[str] = None,
@@ -137,12 +197,22 @@ async def list_library_agents(
logger.debug(f"Retrieved {len(library_agents)} library agents for user #{user_id}")
graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
execution_counts, schedule_info = await asyncio.gather(
_fetch_execution_counts(user_id, graph_ids),
_fetch_schedule_info(user_id),
)
# Only pass valid agents to the response
valid_library_agents: list[library_model.LibraryAgent] = []
for agent in library_agents:
try:
library_agent = library_model.LibraryAgent.from_db(agent)
library_agent = library_model.LibraryAgent.from_db(
agent,
execution_count_override=execution_counts.get(agent.agentGraphId),
schedule_info=schedule_info,
)
valid_library_agents.append(library_agent)
except Exception as e:
# Skip this agent if there was an error
@@ -214,12 +284,22 @@ async def list_favorite_library_agents(
f"Retrieved {len(library_agents)} favorite library agents for user #{user_id}"
)
graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
execution_counts, schedule_info = await asyncio.gather(
_fetch_execution_counts(user_id, graph_ids),
_fetch_schedule_info(user_id),
)
# Only pass valid agents to the response
valid_library_agents: list[library_model.LibraryAgent] = []
for agent in library_agents:
try:
library_agent = library_model.LibraryAgent.from_db(agent)
library_agent = library_model.LibraryAgent.from_db(
agent,
execution_count_override=execution_counts.get(agent.agentGraphId),
schedule_info=schedule_info,
)
valid_library_agents.append(library_agent)
except Exception as e:
# Skip this agent if there was an error
@@ -285,6 +365,12 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
where={"userId": store_listing.owningUserId}
)
schedule_info = (
await _fetch_schedule_info(user_id, graph_id=library_agent.AgentGraph.id)
if library_agent.AgentGraph
else {}
)
return library_model.LibraryAgent.from_db(
library_agent,
sub_graphs=(
@@ -294,6 +380,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
),
store_listing=store_listing,
profile=profile,
schedule_info=schedule_info,
)
@@ -329,7 +416,10 @@ async def get_library_agent_by_store_version_id(
},
include=library_agent_include(user_id),
)
return library_model.LibraryAgent.from_db(agent) if agent else None
if not agent:
return None
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent.agentGraphId)
return library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
async def get_library_agent_by_graph_id(
@@ -358,7 +448,10 @@ async def get_library_agent_by_graph_id(
assert agent.AgentGraph # make type checker happy
# Include sub-graphs so we can make a full credentials input schema
sub_graphs = await graph_db.get_sub_graphs(agent.AgentGraph)
return library_model.LibraryAgent.from_db(agent, sub_graphs=sub_graphs)
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent.agentGraphId)
return library_model.LibraryAgent.from_db(
agent, sub_graphs=sub_graphs, schedule_info=schedule_info
)
async def add_generated_agent_image(
@@ -500,7 +593,11 @@ async def create_library_agent(
for agent, graph in zip(library_agents, graph_entries):
asyncio.create_task(add_generated_agent_image(graph, user_id, agent.id))
return [library_model.LibraryAgent.from_db(agent) for agent in library_agents]
schedule_info = await _fetch_schedule_info(user_id)
return [
library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
for agent in library_agents
]
async def update_agent_version_in_library(
@@ -562,7 +659,8 @@ async def update_agent_version_in_library(
f"Failed to update library agent for {agent_graph_id} v{agent_graph_version}"
)
return library_model.LibraryAgent.from_db(lib)
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent_graph_id)
return library_model.LibraryAgent.from_db(lib, schedule_info=schedule_info)
async def create_graph_in_library(
@@ -1467,7 +1565,11 @@ async def bulk_move_agents_to_folder(
),
)
return [library_model.LibraryAgent.from_db(agent) for agent in agents]
schedule_info = await _fetch_schedule_info(user_id)
return [
library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
for agent in agents
]
def collect_tree_ids(

View File

@@ -65,6 +65,11 @@ async def test_get_library_agents(mocker):
)
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
mocker.patch(
"backend.api.features.library.db._fetch_execution_counts",
new=mocker.AsyncMock(return_value={}),
)
# Call function
result = await db.list_library_agents("test-user")
@@ -353,3 +358,136 @@ async def test_create_library_agent_uses_upsert():
# Verify update branch restores soft-deleted/archived agents
assert data["update"]["isDeleted"] is False
assert data["update"]["isArchived"] is False
@pytest.mark.asyncio
async def test_list_favorite_library_agents(mocker):
mock_library_agents = [
prisma.models.LibraryAgent(
id="fav1",
userId="test-user",
agentGraphId="agent-fav",
settings="{}", # type: ignore
agentGraphVersion=1,
isCreatedByUser=False,
isDeleted=False,
isArchived=False,
createdAt=datetime.now(),
updatedAt=datetime.now(),
isFavorite=True,
useGraphIsActiveVersion=True,
AgentGraph=prisma.models.AgentGraph(
id="agent-fav",
version=1,
name="Favorite Agent",
description="My Favorite",
userId="other-user",
isActive=True,
createdAt=datetime.now(),
),
)
]
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_many = mocker.AsyncMock(
return_value=mock_library_agents
)
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
mocker.patch(
"backend.api.features.library.db._fetch_execution_counts",
new=mocker.AsyncMock(return_value={"agent-fav": 7}),
)
result = await db.list_favorite_library_agents("test-user")
assert len(result.agents) == 1
assert result.agents[0].id == "fav1"
assert result.agents[0].name == "Favorite Agent"
assert result.agents[0].graph_id == "agent-fav"
assert result.pagination.total_items == 1
assert result.pagination.total_pages == 1
assert result.pagination.current_page == 1
assert result.pagination.page_size == 50
@pytest.mark.asyncio
async def test_list_library_agents_skips_failed_agent(mocker):
"""Agents that fail parsing should be skipped — covers the except branch."""
mock_library_agents = [
prisma.models.LibraryAgent(
id="ua-bad",
userId="test-user",
agentGraphId="agent-bad",
settings="{}", # type: ignore
agentGraphVersion=1,
isCreatedByUser=False,
isDeleted=False,
isArchived=False,
createdAt=datetime.now(),
updatedAt=datetime.now(),
isFavorite=False,
useGraphIsActiveVersion=True,
AgentGraph=prisma.models.AgentGraph(
id="agent-bad",
version=1,
name="Bad Agent",
description="",
userId="other-user",
isActive=True,
createdAt=datetime.now(),
),
)
]
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_many = mocker.AsyncMock(
return_value=mock_library_agents
)
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
mocker.patch(
"backend.api.features.library.db._fetch_execution_counts",
new=mocker.AsyncMock(return_value={}),
)
mocker.patch(
"backend.api.features.library.model.LibraryAgent.from_db",
side_effect=Exception("parse error"),
)
result = await db.list_library_agents("test-user")
assert len(result.agents) == 0
assert result.pagination.total_items == 1
@pytest.mark.asyncio
async def test_fetch_execution_counts_empty_graph_ids():
result = await db._fetch_execution_counts("user-1", [])
assert result == {}
@pytest.mark.asyncio
async def test_fetch_execution_counts_uses_group_by(mocker):
mock_prisma = mocker.patch("prisma.models.AgentGraphExecution.prisma")
mock_prisma.return_value.group_by = mocker.AsyncMock(
return_value=[
{"agentGraphId": "graph-1", "_count": {"_all": 5}},
{"agentGraphId": "graph-2", "_count": {"_all": 2}},
]
)
result = await db._fetch_execution_counts(
"user-1", ["graph-1", "graph-2", "graph-3"]
)
assert result == {"graph-1": 5, "graph-2": 2}
mock_prisma.return_value.group_by.assert_called_once_with(
by=["agentGraphId"],
where={
"userId": "user-1",
"agentGraphId": {"in": ["graph-1", "graph-2", "graph-3"]},
"isDeleted": False,
},
count=True,
)

View File

@@ -214,6 +214,14 @@ class LibraryAgent(pydantic.BaseModel):
folder_name: str | None = None # Denormalized for display
recommended_schedule_cron: str | None = None
is_scheduled: bool = pydantic.Field(
default=False,
description="Whether this agent has active execution schedules",
)
next_scheduled_run: str | None = pydantic.Field(
default=None,
description="ISO 8601 timestamp of the next scheduled run, if any",
)
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
marketplace_listing: Optional["MarketplaceListing"] = None
@@ -223,6 +231,8 @@ class LibraryAgent(pydantic.BaseModel):
sub_graphs: Optional[list[prisma.models.AgentGraph]] = None,
store_listing: Optional[prisma.models.StoreListing] = None,
profile: Optional[prisma.models.Profile] = None,
execution_count_override: Optional[int] = None,
schedule_info: Optional[dict[str, str]] = None,
) -> "LibraryAgent":
"""
Factory method that constructs a LibraryAgent from a Prisma LibraryAgent
@@ -258,10 +268,14 @@ class LibraryAgent(pydantic.BaseModel):
status = status_result.status
new_output = status_result.new_output
execution_count = len(executions)
execution_count = (
execution_count_override
if execution_count_override is not None
else len(executions)
)
success_rate: float | None = None
avg_correctness_score: float | None = None
if execution_count > 0:
if executions and execution_count > 0:
success_count = sum(
1
for e in executions
@@ -354,6 +368,10 @@ class LibraryAgent(pydantic.BaseModel):
folder_id=agent.folderId,
folder_name=agent.Folder.name if agent.Folder else None,
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
is_scheduled=bool(schedule_info and agent.agentGraphId in schedule_info),
next_scheduled_run=(
schedule_info.get(agent.agentGraphId) if schedule_info else None
),
settings=_parse_settings(agent.settings),
marketplace_listing=marketplace_listing_data,
)

View File

@@ -1,11 +1,66 @@
import datetime
import prisma.enums
import prisma.models
import pytest
from . import model as library_model
def _make_library_agent(
*,
graph_id: str = "g1",
executions: list | None = None,
) -> prisma.models.LibraryAgent:
return prisma.models.LibraryAgent(
id="la1",
userId="u1",
agentGraphId=graph_id,
settings="{}", # type: ignore
agentGraphVersion=1,
isCreatedByUser=True,
isDeleted=False,
isArchived=False,
createdAt=datetime.datetime.now(),
updatedAt=datetime.datetime.now(),
isFavorite=False,
useGraphIsActiveVersion=True,
AgentGraph=prisma.models.AgentGraph(
id=graph_id,
version=1,
name="Agent",
description="Desc",
userId="u1",
isActive=True,
createdAt=datetime.datetime.now(),
Executions=executions,
),
)
def test_from_db_execution_count_override_covers_success_rate():
"""Covers execution_count_override is not None branch and executions/count > 0 block."""
now = datetime.datetime.now(datetime.timezone.utc)
exec1 = prisma.models.AgentGraphExecution(
id="exec-1",
agentGraphId="g1",
agentGraphVersion=1,
userId="u1",
executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED,
createdAt=now,
updatedAt=now,
isDeleted=False,
isShared=False,
)
agent = _make_library_agent(executions=[exec1])
result = library_model.LibraryAgent.from_db(agent, execution_count_override=1)
assert result.execution_count == 1
assert result.success_rate is not None
assert result.success_rate == 100.0
@pytest.mark.asyncio
async def test_agent_preset_from_db(test_user_id: str):
# Create mock DB agent

View File

@@ -5,7 +5,8 @@ import time
import uuid
from collections import defaultdict
from datetime import datetime, timezone
from typing import Annotated, Any, Literal, Sequence, get_args
from typing import Annotated, Any, Literal, Sequence, cast, get_args
from urllib.parse import urlparse
import pydantic
import stripe
@@ -25,7 +26,7 @@ from fastapi import (
)
from fastapi.concurrency import run_in_threadpool
from prisma.enums import SubscriptionTier
from pydantic import BaseModel
from pydantic import BaseModel, Field
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict
@@ -48,17 +49,24 @@ from backend.data.auth import api_key as api_key_db
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.credit import (
AutoTopUpConfig,
PendingChangeUnknown,
RefundRequest,
TransactionHistory,
UserCredit,
cancel_stripe_subscription,
create_subscription_checkout,
get_auto_top_up,
get_pending_subscription_change,
get_proration_credit_cents,
get_subscription_price_id,
get_user_credit_model,
handle_subscription_payment_failure,
modify_stripe_subscription_for_tier,
release_pending_subscription_schedule,
set_auto_top_up,
set_subscription_tier,
sync_subscription_from_stripe,
sync_subscription_schedule_from_stripe,
)
from backend.data.graph import GraphSettings
from backend.data.model import CredentialsMetaInput, UserOnboarding
@@ -694,14 +702,83 @@ class SubscriptionTierRequest(BaseModel):
cancel_url: str = ""
class SubscriptionCheckoutResponse(BaseModel):
url: str
class SubscriptionStatusResponse(BaseModel):
tier: str
monthly_cost: int
tier_costs: dict[str, int]
tier: Literal["FREE", "PRO", "BUSINESS", "ENTERPRISE"]
monthly_cost: int # amount in cents (Stripe convention)
tier_costs: dict[str, int] # tier name -> amount in cents
proration_credit_cents: int # unused portion of current sub to convert on upgrade
pending_tier: Optional[Literal["FREE", "PRO", "BUSINESS"]] = None
pending_tier_effective_at: Optional[datetime] = None
url: str = Field(
default="",
description=(
"Populated only when POST /credits/subscription starts a Stripe Checkout"
" Session (FREE → paid upgrade). Empty string in all other branches —"
" the client redirects to this URL when non-empty."
),
)
def _validate_checkout_redirect_url(url: str) -> bool:
"""Return True if `url` matches the configured frontend origin.
Prevents open-redirect: attackers must not be able to supply arbitrary
success_url/cancel_url that Stripe will redirect users to after checkout.
Pre-parse rejection rules (applied before urlparse):
- Backslashes (``\\``) are normalised differently across parsers/browsers.
- Control characters (U+0000U+001F) are not valid in URLs and may confuse
some URL-parsing implementations.
"""
# Reject characters that can confuse URL parsers before any parsing.
if "\\" in url:
return False
if any(ord(c) < 0x20 for c in url):
return False
allowed = settings.config.frontend_base_url or settings.config.platform_base_url
if not allowed:
# No configured origin — refuse to validate rather than allow arbitrary URLs.
return False
try:
parsed = urlparse(url)
allowed_parsed = urlparse(allowed)
except ValueError:
return False
if parsed.scheme not in ("http", "https"):
return False
# Reject ``user:pass@host`` authority tricks — ``@`` in the netloc component
# can trick browsers into connecting to a different host than displayed.
# ``@`` in query/fragment is harmless and must be allowed.
if "@" in parsed.netloc:
return False
return (
parsed.scheme == allowed_parsed.scheme
and parsed.netloc == allowed_parsed.netloc
)
@cached(ttl_seconds=300, maxsize=32, cache_none=False)
async def _get_stripe_price_amount(price_id: str) -> int | None:
"""Return the unit_amount (cents) for a Stripe Price ID, cached for 5 minutes.
Returns ``None`` on transient Stripe errors. ``cache_none=False`` opts out
of caching the ``None`` sentinel so the next request retries Stripe instead
of being served a stale "no price" for the rest of the TTL window. Callers
should treat ``None`` as an unknown price and fall back to 0.
Stripe prices rarely change; caching avoids a ~200-600 ms Stripe round-trip on
every GET /credits/subscription page load and reduces quota consumption.
"""
try:
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
return price.unit_amount or 0
except stripe.StripeError:
logger.warning(
"Failed to retrieve Stripe price %s — returning None (not cached)",
price_id,
)
return None
@v1_router.get(
@@ -722,27 +799,57 @@ async def get_subscription_status(
*[get_subscription_price_id(t) for t in paid_tiers]
)
tier_costs: dict[str, int] = {"FREE": 0, "ENTERPRISE": 0}
for t, price_id in zip(paid_tiers, price_ids):
cost = 0
if price_id:
try:
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
cost = price.unit_amount or 0
except stripe.StripeError:
pass
tier_costs: dict[str, int] = {
SubscriptionTier.FREE.value: 0,
SubscriptionTier.ENTERPRISE.value: 0,
}
async def _cost(pid: str | None) -> int:
return (await _get_stripe_price_amount(pid) or 0) if pid else 0
costs = await asyncio.gather(*[_cost(pid) for pid in price_ids])
for t, cost in zip(paid_tiers, costs):
tier_costs[t.value] = cost
return SubscriptionStatusResponse(
current_monthly_cost = tier_costs.get(tier.value, 0)
proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost)
try:
pending = await get_pending_subscription_change(user_id)
except (stripe.StripeError, PendingChangeUnknown):
# Swallow Stripe-side failures (rate limits, transient network) AND
# PendingChangeUnknown (LaunchDarkly price-id lookup failed). Both
# propagate past the cache so the next request retries fresh instead
# of serving a stale None for the TTL window. Let real bugs (KeyError,
# AttributeError, etc.) propagate so they surface in Sentry.
logger.exception(
"get_subscription_status: failed to resolve pending change for user %s",
user_id,
)
pending = None
response = SubscriptionStatusResponse(
tier=tier.value,
monthly_cost=tier_costs.get(tier.value, 0),
monthly_cost=current_monthly_cost,
tier_costs=tier_costs,
proration_credit_cents=proration_credit,
)
if pending is not None:
pending_tier_enum, pending_effective_at = pending
if pending_tier_enum == SubscriptionTier.FREE:
response.pending_tier = "FREE"
elif pending_tier_enum == SubscriptionTier.PRO:
response.pending_tier = "PRO"
elif pending_tier_enum == SubscriptionTier.BUSINESS:
response.pending_tier = "BUSINESS"
if response.pending_tier is not None:
response.pending_tier_effective_at = pending_effective_at
return response
@v1_router.post(
path="/credits/subscription",
summary="Start a Stripe Checkout session to upgrade subscription tier",
summary="Update subscription tier or start a Stripe Checkout session",
operation_id="updateSubscriptionTier",
tags=["credits"],
dependencies=[Security(requires_user)],
@@ -750,7 +857,7 @@ async def get_subscription_status(
async def update_subscription_tier(
request: SubscriptionTierRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> SubscriptionCheckoutResponse:
) -> SubscriptionStatusResponse:
# Pydantic validates tier is one of FREE/PRO/BUSINESS via Literal type.
tier = SubscriptionTier(request.tier)
@@ -762,28 +869,143 @@ async def update_subscription_tier(
detail="ENTERPRISE subscription changes must be managed by an administrator",
)
# Same-tier request = "stay on my current tier" = cancel any pending
# scheduled change (paid→paid downgrade or paid→FREE cancel). This is the
# collapsed behaviour that replaces the old /credits/subscription/cancel-pending
# route. Safe when no pending change exists: release_pending_subscription_schedule
# returns False and we simply return the current status.
if (user.subscription_tier or SubscriptionTier.FREE) == tier:
try:
await release_pending_subscription_schedule(user_id)
except stripe.StripeError as e:
logger.exception(
"Stripe error releasing pending subscription change for user %s: %s",
user_id,
e,
)
raise HTTPException(
status_code=502,
detail=(
"Unable to cancel the pending subscription change right now. "
"Please try again or contact support."
),
)
return await get_subscription_status(user_id)
payment_enabled = await is_feature_enabled(
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
)
# Downgrade to FREE: cancel active Stripe subscription, then update the DB tier.
# Downgrade to FREE: schedule Stripe cancellation at period end so the user
# keeps their tier for the time they already paid for. The DB tier is NOT
# updated here when a subscription exists — the customer.subscription.deleted
# webhook fires at period end and downgrades to FREE then.
# Exception: if the user has no active Stripe subscription (e.g. admin-granted
# tier), cancel_stripe_subscription returns False and we update the DB tier
# immediately since no webhook will ever fire.
# When payment is disabled entirely, update the DB tier directly.
if tier == SubscriptionTier.FREE:
if payment_enabled:
await cancel_stripe_subscription(user_id)
try:
had_subscription = await cancel_stripe_subscription(user_id)
except stripe.StripeError as e:
# Log full Stripe error server-side but return a generic message
# to the client — raw Stripe errors can leak customer/sub IDs and
# infrastructure config details.
logger.exception(
"Stripe error cancelling subscription for user %s: %s",
user_id,
e,
)
raise HTTPException(
status_code=502,
detail=(
"Unable to cancel your subscription right now. "
"Please try again or contact support."
),
)
if not had_subscription:
# No active Stripe subscription found — the user was on an
# admin-granted tier. Update DB immediately since the
# subscription.deleted webhook will never fire.
await set_subscription_tier(user_id, tier)
return await get_subscription_status(user_id)
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
return await get_subscription_status(user_id)
# Beta users (payment not enabled) → update tier directly without Stripe.
# Paid tier changes require payment to be enabled — block self-service upgrades
# when the flag is off. Admins use the /api/admin/ routes to set tiers directly.
if not payment_enabled:
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
raise HTTPException(
status_code=422,
detail=f"Subscription not available for tier {tier}",
)
# Paid upgrade → create Stripe Checkout Session.
# Paid→paid tier change: if the user already has a Stripe subscription,
# modify it in-place with proration instead of creating a new Checkout
# Session. This preserves remaining paid time and avoids double-charging.
# The customer.subscription.updated webhook fires and updates the DB tier.
current_tier = user.subscription_tier or SubscriptionTier.FREE
if current_tier in (SubscriptionTier.PRO, SubscriptionTier.BUSINESS):
try:
modified = await modify_stripe_subscription_for_tier(user_id, tier)
if modified:
return await get_subscription_status(user_id)
# modify_stripe_subscription_for_tier returns False when no active
# Stripe subscription exists — i.e. the user has an admin-granted
# paid tier with no Stripe record. In that case, update the DB
# tier directly (same as the FREE-downgrade path for admin-granted
# users) rather than sending them through a new Checkout Session.
await set_subscription_tier(user_id, tier)
return await get_subscription_status(user_id)
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
except stripe.StripeError as e:
logger.exception(
"Stripe error modifying subscription for user %s: %s", user_id, e
)
raise HTTPException(
status_code=502,
detail=(
"Unable to update your subscription right now. "
"Please try again or contact support."
),
)
# Paid upgrade from FREE → create Stripe Checkout Session.
if not request.success_url or not request.cancel_url:
raise HTTPException(
status_code=422,
detail="success_url and cancel_url are required for paid tier upgrades",
)
# Open-redirect protection: both URLs must point to the configured frontend
# origin, otherwise an attacker could use our Stripe integration as a
# redirector to arbitrary phishing sites.
#
# Fail early with a clear 503 if the server is misconfigured (neither
# frontend_base_url nor platform_base_url set), so operators get an
# actionable error instead of the misleading "must match the platform
# frontend origin" 422 that _validate_checkout_redirect_url would otherwise
# produce when `allowed` is empty.
if not (settings.config.frontend_base_url or settings.config.platform_base_url):
logger.error(
"update_subscription_tier: neither frontend_base_url nor "
"platform_base_url is configured; cannot validate checkout redirect URLs"
)
raise HTTPException(
status_code=503,
detail=(
"Payment redirect URLs cannot be validated: "
"frontend_base_url or platform_base_url must be set on the server."
),
)
if not _validate_checkout_redirect_url(
request.success_url
) or not _validate_checkout_redirect_url(request.cancel_url):
raise HTTPException(
status_code=422,
detail="success_url and cancel_url must match the platform frontend origin",
)
try:
url = await create_subscription_checkout(
user_id=user_id,
@@ -791,54 +1013,113 @@ async def update_subscription_tier(
success_url=request.success_url,
cancel_url=request.cancel_url,
)
except (ValueError, stripe.StripeError) as e:
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
except stripe.StripeError as e:
logger.exception(
"Stripe error creating checkout session for user %s: %s", user_id, e
)
raise HTTPException(
status_code=502,
detail=(
"Unable to start checkout right now. "
"Please try again or contact support."
),
)
return SubscriptionCheckoutResponse(url=url)
status = await get_subscription_status(user_id)
status.url = url
return status
@v1_router.post(
path="/credits/stripe_webhook", summary="Handle Stripe webhooks", tags=["credits"]
)
async def stripe_webhook(request: Request):
webhook_secret = settings.secrets.stripe_webhook_secret
if not webhook_secret:
# Guard: an empty secret allows HMAC forgery (attacker can compute a valid
# signature over the same empty key). Reject all webhook calls when unconfigured.
logger.error(
"stripe_webhook: STRIPE_WEBHOOK_SECRET is not configured — "
"rejecting request to prevent signature bypass"
)
raise HTTPException(status_code=503, detail="Webhook not configured")
# Get the raw request body
payload = await request.body()
# Get the signature header
sig_header = request.headers.get("stripe-signature")
try:
event = stripe.Webhook.construct_event(
payload, sig_header, settings.secrets.stripe_webhook_secret
)
except ValueError as e:
event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
except ValueError:
# Invalid payload
raise HTTPException(
status_code=400, detail=f"Invalid payload: {str(e) or type(e).__name__}"
)
except stripe.SignatureVerificationError as e:
raise HTTPException(status_code=400, detail="Invalid payload")
except stripe.SignatureVerificationError:
# Invalid signature
raise HTTPException(
status_code=400, detail=f"Invalid signature: {str(e) or type(e).__name__}"
raise HTTPException(status_code=400, detail="Invalid signature")
# Defensive payload extraction. A malformed payload (missing/non-dict
# `data.object`, missing `id`) would otherwise raise KeyError/TypeError
# AFTER signature verification — which Stripe interprets as a delivery
# failure and retries forever, while spamming Sentry with no useful info.
# Acknowledge with 200 and a warning so Stripe stops retrying.
event_type = event.get("type", "")
event_data = event.get("data") or {}
data_object = event_data.get("object") if isinstance(event_data, dict) else None
if not isinstance(data_object, dict):
logger.warning(
"stripe_webhook: %s missing or non-dict data.object; ignoring",
event_type,
)
return Response(status_code=200)
if (
event["type"] == "checkout.session.completed"
or event["type"] == "checkout.session.async_payment_succeeded"
if event_type in (
"checkout.session.completed",
"checkout.session.async_payment_succeeded",
):
await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"])
session_id = data_object.get("id")
if not session_id:
logger.warning(
"stripe_webhook: %s missing data.object.id; ignoring", event_type
)
return Response(status_code=200)
await UserCredit().fulfill_checkout(session_id=session_id)
if event["type"] in (
if event_type in (
"customer.subscription.created",
"customer.subscription.updated",
"customer.subscription.deleted",
):
await sync_subscription_from_stripe(event["data"]["object"])
await sync_subscription_from_stripe(data_object)
if event["type"] == "charge.dispute.created":
await UserCredit().handle_dispute(event["data"]["object"])
# `subscription_schedule.updated` is deliberately omitted: our own
# `SubscriptionSchedule.create` + `.modify` calls in
# `_schedule_downgrade_at_period_end` would fire that event right back at us
# and loop redundant traffic through this handler. We only care about state
# transitions (released / completed); phase advance to the new price is
# already covered by `customer.subscription.updated`.
if event_type in (
"subscription_schedule.released",
"subscription_schedule.completed",
):
await sync_subscription_schedule_from_stripe(data_object)
if event["type"] == "refund.created" or event["type"] == "charge.dispute.closed":
await UserCredit().deduct_credits(event["data"]["object"])
if event_type == "invoice.payment_failed":
await handle_subscription_payment_failure(data_object)
# `handle_dispute` and `deduct_credits` expect Stripe SDK typed objects
# (Dispute/Refund). The Stripe webhook payload's `data.object` is a
# StripeObject (a dict subclass) carrying that runtime shape, so we cast
# to satisfy the type checker without changing runtime behaviour.
if event_type == "charge.dispute.created":
await UserCredit().handle_dispute(cast(stripe.Dispute, data_object))
if event_type == "refund.created" or event_type == "charge.dispute.closed":
await UserCredit().deduct_credits(
cast("stripe.Refund | stripe.Dispute", data_object)
)
return Response(status_code=200)

View File

@@ -23,6 +23,7 @@ from backend.copilot.permissions import (
validate_block_identifiers,
)
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
if TYPE_CHECKING:
from backend.data.execution import ExecutionContext
@@ -32,9 +33,36 @@ logger = logging.getLogger(__name__)
# Block ID shared between autopilot.py and copilot prompting.py.
AUTOPILOT_BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
# Identifiers used when registering an AutoPilotBlock turn with the
# stream registry — distinguishes block-originated turns from sub-session
# or HTTP SSE turns in logs / observability.
_AUTOPILOT_TOOL_CALL_ID = "autopilot_block"
_AUTOPILOT_TOOL_NAME = "autopilot_block"
class SubAgentRecursionError(RuntimeError):
"""Raised when the sub-agent nesting depth limit is exceeded."""
# Ceiling on how long AutoPilotBlock.execute_copilot will wait for the
# enqueued turn's terminal event. Graph blocks run synchronously from
# the caller's perspective so we wait effectively as long as needed; 6h
# matches the previous abandoned-task cap and is much longer than any
# legitimate AutoPilot turn.
_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS = 6 * 60 * 60 # 6 hours
class SubAgentRecursionError(BlockExecutionError):
"""Raised when the AutoPilot sub-agent nesting depth limit is exceeded.
Inherits :class:`BlockExecutionError` — this is a known, handled
runtime failure at the block level (caller nested AutoPilotBlocks
beyond the configured limit). Surfaces with the block_name /
block_id the block framework expects, instead of being wrapped in
``BlockUnknownError``.
"""
def __init__(self, message: str) -> None:
super().__init__(
message=message,
block_name="AutoPilotBlock",
block_id=AUTOPILOT_BLOCK_ID,
)
class ToolCallEntry(TypedDict):
@@ -268,11 +296,15 @@ class AutoPilotBlock(Block):
user_id: str,
permissions: "CopilotPermissions | None" = None,
) -> tuple[str, list[ToolCallEntry], str, str, TokenUsage]:
"""Invoke the copilot and collect all stream results.
"""Invoke the copilot on the copilot_executor queue and aggregate the
result.
Delegates to :func:`collect_copilot_response` — the shared helper that
consumes ``stream_chat_completion_sdk`` without wrapping it in an
``asyncio.timeout`` (the SDK manages its own heartbeat-based timeouts).
Delegates to :func:`run_copilot_turn_via_queue` — the shared
primitive used by ``run_sub_session`` too — which creates the
stream_registry meta record, enqueues the job, and waits on the
Redis stream for the terminal event. Any available
copilot_executor worker picks up the job, so this call survives
the graph-executor worker dying mid-turn (RabbitMQ redelivers).
Args:
prompt: The user task/instruction.
@@ -285,8 +317,8 @@ class AutoPilotBlock(Block):
Returns:
A tuple of (response_text, tool_calls, history_json, session_id, usage).
"""
from backend.copilot.sdk.collect import (
collect_copilot_response, # avoid circular import
from backend.copilot.sdk.session_waiter import (
run_copilot_turn_via_queue, # avoid circular import
)
tokens = _check_recursion(max_recursion_depth)
@@ -299,14 +331,35 @@ class AutoPilotBlock(Block):
if system_context:
effective_prompt = f"[System Context: {system_context}]\n\n{prompt}"
result = await collect_copilot_response(
outcome, result = await run_copilot_turn_via_queue(
session_id=session_id,
message=effective_prompt,
user_id=user_id,
message=effective_prompt,
# Graph block execution is synchronous from the caller's
# perspective — wait effectively as long as needed. The
# SDK enforces its own idle-based timeout inside the
# stream_registry pipeline.
timeout=_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS,
permissions=effective_permissions,
tool_call_id=_AUTOPILOT_TOOL_CALL_ID,
tool_name=_AUTOPILOT_TOOL_NAME,
)
if outcome == "failed":
raise RuntimeError(
"AutoPilot turn failed — see the session's transcript"
)
if outcome == "running":
raise RuntimeError(
"AutoPilot turn did not complete within "
f"{_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS}s — session "
f"{session_id}"
)
# Build a lightweight conversation summary from streamed data.
# Build a lightweight conversation summary from the aggregated data.
# When ``result.queued`` is True the prompt rode on an already-
# in-flight turn (``run_copilot_turn_via_queue`` queued it and
# waited on the existing turn's stream); the aggregated result
# is still valid, so the same rendering path applies.
turn_messages: list[dict[str, Any]] = [
{"role": "user", "content": effective_prompt},
]
@@ -315,7 +368,7 @@ class AutoPilotBlock(Block):
{
"role": "assistant",
"content": result.response_text,
"tool_calls": result.tool_calls,
"tool_calls": [tc.model_dump() for tc in result.tool_calls],
}
)
else:
@@ -326,11 +379,11 @@ class AutoPilotBlock(Block):
tool_calls: list[ToolCallEntry] = [
{
"tool_call_id": tc["tool_call_id"],
"tool_name": tc["tool_name"],
"input": tc["input"],
"output": tc["output"],
"success": tc["success"],
"tool_call_id": tc.tool_call_id,
"tool_name": tc.tool_name,
"input": tc.input,
"output": tc.output,
"success": tc.success,
}
for tc in result.tool_calls
]

View File

@@ -106,7 +106,6 @@ class LlmModelMeta(EnumMeta):
class LlmModel(str, Enum, metaclass=LlmModelMeta):
@classmethod
def _missing_(cls, value: object) -> "LlmModel | None":
"""Handle provider-prefixed model names like 'anthropic/claude-sonnet-4-6'."""
@@ -203,6 +202,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
GROK_4 = "x-ai/grok-4"
GROK_4_FAST = "x-ai/grok-4-fast"
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
GROK_4_20 = "x-ai/grok-4.20"
GROK_4_20_MULTI_AGENT = "x-ai/grok-4.20-multi-agent"
GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
KIMI_K2 = "moonshotai/kimi-k2"
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
@@ -627,6 +628,18 @@ MODEL_METADATA = {
LlmModel.GROK_4_1_FAST: ModelMetadata(
"open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1
),
LlmModel.GROK_4_20: ModelMetadata(
"open_router", 2000000, 100000, "Grok 4.20", "OpenRouter", "xAI", 3
),
LlmModel.GROK_4_20_MULTI_AGENT: ModelMetadata(
"open_router",
2000000,
100000,
"Grok 4.20 Multi-Agent",
"OpenRouter",
"xAI",
3,
),
LlmModel.GROK_CODE_FAST_1: ModelMetadata(
"open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1
),
@@ -987,7 +1000,6 @@ async def llm_call(
reasoning=reasoning,
)
elif provider == "anthropic":
an_tools = convert_openai_tool_fmt_to_anthropic(tools)
# Cache tool definitions alongside the system prompt.
# Placing cache_control on the last tool caches all tool schemas as a

View File

@@ -0,0 +1,230 @@
"""Extended-thinking wire support for the baseline (OpenRouter) path.
Anthropic routes on OpenRouter expose extended thinking through
non-OpenAI extension fields that the OpenAI Python SDK doesn't model:
* ``reasoning`` (legacy string) — enabled by ``include_reasoning: true``.
* ``reasoning_content`` — DeepSeek / some OpenRouter routes.
* ``reasoning_details`` — structured list shipped with the unified
``reasoning`` request param.
This module keeps the wire-level concerns in one place:
* :class:`OpenRouterDeltaExtension` validates the extension dict pulled off
``ChoiceDelta.model_extra`` into typed pydantic models — no ``getattr`` +
``isinstance`` duck-typing at the call site.
* :class:`BaselineReasoningEmitter` owns the reasoning block lifecycle for
one streaming round and emits ``StreamReasoning*`` events so the caller
only has to plumb the events into its pending queue.
* :func:`reasoning_extra_body` builds the ``extra_body`` fragment for the
OpenAI client call. Returns ``None`` on non-Anthropic routes.
"""
from __future__ import annotations
import logging
import uuid
from typing import Any
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from pydantic import BaseModel, ConfigDict, Field, ValidationError
from backend.copilot.model import ChatMessage
from backend.copilot.response_model import (
StreamBaseResponse,
StreamReasoningDelta,
StreamReasoningEnd,
StreamReasoningStart,
)
logger = logging.getLogger(__name__)
_VISIBLE_REASONING_TYPES = frozenset({"reasoning.text", "reasoning.summary"})
class ReasoningDetail(BaseModel):
"""One entry in OpenRouter's ``reasoning_details`` list.
OpenRouter ships ``type: "reasoning.text"`` / ``"reasoning.summary"`` /
``"reasoning.encrypted"`` entries. Only the first two carry
user-visible text; encrypted entries are opaque and omitted from the
rendered collapse. Unknown future types are tolerated (``extra="ignore"``)
so an upstream addition doesn't crash the stream — but their ``text`` /
``summary`` fields are NOT surfaced because they may carry provider
metadata rather than user-visible reasoning (see
:attr:`visible_text`).
"""
model_config = ConfigDict(extra="ignore")
type: str | None = None
text: str | None = None
summary: str | None = None
@property
def visible_text(self) -> str:
"""Return the human-readable text for this entry, or ``""``.
Only entries with a recognised reasoning type (``reasoning.text`` /
``reasoning.summary``) surface text; unknown or encrypted types
return an empty string even if they carry a ``text`` /
``summary`` field, to guard against future provider metadata
being rendered as reasoning in the UI. Entries missing a
``type`` are treated as text (pre-``reasoning_details`` OpenRouter
payloads omit the field).
"""
if self.type is not None and self.type not in _VISIBLE_REASONING_TYPES:
return ""
return self.text or self.summary or ""
class OpenRouterDeltaExtension(BaseModel):
"""Non-OpenAI fields OpenRouter adds to streaming deltas.
Instantiate via :meth:`from_delta` which pulls the extension dict off
``ChoiceDelta.model_extra`` (where pydantic v2 stashes fields that
aren't part of the declared schema) and validates it through this
model. That keeps the parser honest — malformed entries surface as
validation errors rather than silent ``None``-coalesce bugs — and
avoids the ``getattr`` + ``isinstance`` duck-typing the earlier inline
extractor relied on.
"""
model_config = ConfigDict(extra="ignore")
reasoning: str | None = None
reasoning_content: str | None = None
reasoning_details: list[ReasoningDetail] = Field(default_factory=list)
@classmethod
def from_delta(cls, delta: ChoiceDelta) -> "OpenRouterDeltaExtension":
"""Build an extension view from ``delta.model_extra``.
Malformed provider payloads (e.g. ``reasoning_details`` shipped as
a string rather than a list) surface as a ``ValidationError`` which
is logged and swallowed — returning an empty extension so the rest
of the stream (valid text / tool calls) keeps flowing. An optional
feature's corrupted wire data must never abort the whole stream.
"""
try:
return cls.model_validate(delta.model_extra or {})
except ValidationError as exc:
logger.warning(
"[Baseline] Dropping malformed OpenRouter reasoning payload: %s",
exc,
)
return cls()
def visible_text(self) -> str:
"""Concatenated reasoning text, pulled from whichever channel is set.
Priority: the legacy ``reasoning`` string, then DeepSeek's
``reasoning_content``, then the concatenation of text-bearing
entries in ``reasoning_details``. Only one channel is set per
provider in practice; the priority order just makes the fallback
deterministic if a provider ever emits multiple.
"""
if self.reasoning:
return self.reasoning
if self.reasoning_content:
return self.reasoning_content
return "".join(d.visible_text for d in self.reasoning_details)
def reasoning_extra_body(model: str, max_thinking_tokens: int) -> dict[str, Any] | None:
"""Build the ``extra_body["reasoning"]`` fragment for the OpenAI client.
Returns ``None`` for non-Anthropic routes (other OpenRouter providers
ignore the field but we skip it anyway to keep the payload minimal)
and for ``max_thinking_tokens <= 0`` (operator kill switch).
"""
# Imported lazily to avoid pulling service.py at module load — service.py
# imports this module, and the lazy import keeps the dependency one-way.
from backend.copilot.baseline.service import _is_anthropic_model
if not _is_anthropic_model(model) or max_thinking_tokens <= 0:
return None
return {"reasoning": {"max_tokens": max_thinking_tokens}}
class BaselineReasoningEmitter:
"""Owns the reasoning block lifecycle for one streaming round.
Two concerns live here, both driven by the same state machine:
1. **Wire events.** The AI SDK v6 wire format pairs every
``reasoning-start`` with a matching ``reasoning-end`` and treats
reasoning / text / tool-use as distinct UI parts that must not
interleave.
2. **Session persistence.** ``ChatMessage(role="reasoning")`` rows in
``session.messages`` are what
``convertChatSessionToUiMessages.ts`` folds into the assistant
bubble as ``{type: "reasoning"}`` UI parts on reload and on
``useHydrateOnStreamEnd`` swaps. Without them the live-streamed
reasoning parts get overwritten by the hydrated (reasoning-less)
message list the moment the stream ends. Mirrors the SDK path's
``acc.reasoning_response`` pattern so both routes render the same
way on reload.
Pass ``session_messages`` to enable persistence; omit for pure
wire-emission (tests, scratch callers). On first reasoning delta a
fresh ``ChatMessage(role="reasoning")`` is appended and mutated
in-place as further deltas arrive; :meth:`close` drops the reference
but leaves the appended row intact.
"""
def __init__(
self,
session_messages: list[ChatMessage] | None = None,
) -> None:
self._block_id: str = str(uuid.uuid4())
self._open: bool = False
self._session_messages = session_messages
self._current_row: ChatMessage | None = None
@property
def is_open(self) -> bool:
return self._open
def on_delta(self, delta: ChoiceDelta) -> list[StreamBaseResponse]:
"""Return events for the reasoning text carried by *delta*.
Empty list when the chunk carries no reasoning payload, so this is
safe to call on every chunk without guarding at the call site.
Persistence (when a session message list is attached) happens in
lockstep with emission so the row's content stays equal to the
concatenated deltas at every delta boundary.
"""
ext = OpenRouterDeltaExtension.from_delta(delta)
text = ext.visible_text()
if not text:
return []
events: list[StreamBaseResponse] = []
if not self._open:
events.append(StreamReasoningStart(id=self._block_id))
self._open = True
if self._session_messages is not None:
self._current_row = ChatMessage(role="reasoning", content="")
self._session_messages.append(self._current_row)
events.append(StreamReasoningDelta(id=self._block_id, delta=text))
if self._current_row is not None:
self._current_row.content = (self._current_row.content or "") + text
return events
def close(self) -> list[StreamBaseResponse]:
"""Emit ``StreamReasoningEnd`` for the open block (if any) and rotate.
Idempotent — returns ``[]`` when no block is open. The id rotation
guarantees the next reasoning block starts with a fresh id rather
than reusing one already closed on the wire. The persisted row is
not removed — it stays in ``session_messages`` as the durable
record of what was reasoned.
"""
if not self._open:
return []
event = StreamReasoningEnd(id=self._block_id)
self._open = False
self._block_id = str(uuid.uuid4())
self._current_row = None
return [event]

View File

@@ -0,0 +1,281 @@
"""Tests for the baseline reasoning extension module.
Covers the typed OpenRouter delta parser, the stateful emitter, and the
``extra_body`` builder. The emitter is tested against real
``ChoiceDelta`` pydantic instances so the ``model_extra`` plumbing the
parser relies on is exercised end-to-end.
"""
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from backend.copilot.baseline.reasoning import (
BaselineReasoningEmitter,
OpenRouterDeltaExtension,
ReasoningDetail,
reasoning_extra_body,
)
from backend.copilot.model import ChatMessage
from backend.copilot.response_model import (
StreamReasoningDelta,
StreamReasoningEnd,
StreamReasoningStart,
)
def _delta(**extra) -> ChoiceDelta:
"""Build a ChoiceDelta with the given extension fields on ``model_extra``."""
return ChoiceDelta.model_validate({"role": "assistant", **extra})
class TestReasoningDetail:
def test_visible_text_prefers_text(self):
d = ReasoningDetail(type="reasoning.text", text="hi", summary="ignored")
assert d.visible_text == "hi"
def test_visible_text_falls_back_to_summary(self):
d = ReasoningDetail(type="reasoning.summary", summary="tldr")
assert d.visible_text == "tldr"
def test_visible_text_empty_for_encrypted(self):
d = ReasoningDetail(type="reasoning.encrypted")
assert d.visible_text == ""
def test_unknown_fields_are_ignored(self):
# OpenRouter may add new fields in future payloads — they shouldn't
# cause validation errors.
d = ReasoningDetail.model_validate(
{"type": "reasoning.future", "text": "x", "signature": "opaque"}
)
assert d.text == "x"
def test_visible_text_empty_for_unknown_type(self):
# Unknown types may carry provider metadata that must not render as
# user-visible reasoning — regardless of whether a text/summary is
# present. Only ``reasoning.text`` / ``reasoning.summary`` surface.
d = ReasoningDetail(type="reasoning.future", text="leaked metadata")
assert d.visible_text == ""
def test_visible_text_surfaces_text_when_type_missing(self):
# Pre-``reasoning_details`` OpenRouter payloads omit ``type`` — treat
# them as text so we don't regress the legacy structured shape.
d = ReasoningDetail(text="plain")
assert d.visible_text == "plain"
class TestOpenRouterDeltaExtension:
def test_from_delta_reads_model_extra(self):
delta = _delta(reasoning="step one")
ext = OpenRouterDeltaExtension.from_delta(delta)
assert ext.reasoning == "step one"
def test_visible_text_legacy_string(self):
ext = OpenRouterDeltaExtension(reasoning="plain text")
assert ext.visible_text() == "plain text"
def test_visible_text_deepseek_alias(self):
ext = OpenRouterDeltaExtension(reasoning_content="alt channel")
assert ext.visible_text() == "alt channel"
def test_visible_text_structured_details_concat(self):
ext = OpenRouterDeltaExtension(
reasoning_details=[
ReasoningDetail(type="reasoning.text", text="hello "),
ReasoningDetail(type="reasoning.text", text="world"),
]
)
assert ext.visible_text() == "hello world"
def test_visible_text_skips_encrypted(self):
ext = OpenRouterDeltaExtension(
reasoning_details=[
ReasoningDetail(type="reasoning.encrypted"),
ReasoningDetail(type="reasoning.text", text="visible"),
]
)
assert ext.visible_text() == "visible"
def test_visible_text_empty_when_all_channels_blank(self):
ext = OpenRouterDeltaExtension()
assert ext.visible_text() == ""
def test_empty_delta_produces_empty_extension(self):
ext = OpenRouterDeltaExtension.from_delta(_delta())
assert ext.reasoning is None
assert ext.reasoning_content is None
assert ext.reasoning_details == []
def test_malformed_reasoning_payload_logged_and_swallowed(self, caplog):
# A malformed payload (e.g. reasoning_details shipped as a string
# rather than a list) must not abort the stream — log it and
# return an empty extension so valid text/tool events keep flowing.
# A plain mock is used here because ``from_delta`` only reads
# ``delta.model_extra`` — avoids reaching into pydantic internals
# (``__pydantic_extra__``) that could be renamed across versions.
from unittest.mock import MagicMock
delta = MagicMock(spec=ChoiceDelta)
delta.model_extra = {"reasoning_details": "not a list"}
with caplog.at_level("WARNING"):
ext = OpenRouterDeltaExtension.from_delta(delta)
assert ext.reasoning_details == []
assert ext.visible_text() == ""
assert any("malformed" in r.message.lower() for r in caplog.records)
def test_unknown_typed_entry_with_text_is_not_surfaced(self):
# Regression: the legacy extractor emitted any entry with a
# ``text`` or ``summary`` field. The typed parser now filters on
# the recognised types so future provider metadata can't leak
# into the reasoning collapse.
ext = OpenRouterDeltaExtension(
reasoning_details=[
ReasoningDetail(type="reasoning.future", text="provider metadata"),
ReasoningDetail(type="reasoning.text", text="real"),
]
)
assert ext.visible_text() == "real"
class TestReasoningExtraBody:
def test_anthropic_route_returns_fragment(self):
assert reasoning_extra_body("anthropic/claude-sonnet-4-6", 4096) == {
"reasoning": {"max_tokens": 4096}
}
def test_direct_claude_model_id_still_matches(self):
assert reasoning_extra_body("claude-3-5-sonnet-20241022", 2048) == {
"reasoning": {"max_tokens": 2048}
}
def test_non_anthropic_route_returns_none(self):
assert reasoning_extra_body("openai/gpt-4o", 4096) is None
assert reasoning_extra_body("google/gemini-2.5-pro", 4096) is None
def test_zero_max_tokens_kill_switch(self):
# Operator kill switch: ``max_thinking_tokens <= 0`` disables the
# ``reasoning`` extra_body fragment even on an Anthropic route.
# Lets us silence reasoning without dropping the SDK path's budget.
assert reasoning_extra_body("anthropic/claude-sonnet-4-6", 0) is None
assert reasoning_extra_body("anthropic/claude-sonnet-4-6", -1) is None
class TestBaselineReasoningEmitter:
def test_first_text_delta_emits_start_then_delta(self):
emitter = BaselineReasoningEmitter()
events = emitter.on_delta(_delta(reasoning="thinking"))
assert len(events) == 2
assert isinstance(events[0], StreamReasoningStart)
assert isinstance(events[1], StreamReasoningDelta)
assert events[0].id == events[1].id
assert events[1].delta == "thinking"
assert emitter.is_open is True
def test_subsequent_deltas_reuse_block_id_without_new_start(self):
emitter = BaselineReasoningEmitter()
first = emitter.on_delta(_delta(reasoning="a"))
second = emitter.on_delta(_delta(reasoning="b"))
assert any(isinstance(e, StreamReasoningStart) for e in first)
assert all(not isinstance(e, StreamReasoningStart) for e in second)
assert len(second) == 1
assert isinstance(second[0], StreamReasoningDelta)
assert first[0].id == second[0].id
def test_empty_delta_emits_nothing(self):
emitter = BaselineReasoningEmitter()
assert emitter.on_delta(_delta(content="hello")) == []
assert emitter.is_open is False
def test_close_emits_end_and_rotates_id(self):
emitter = BaselineReasoningEmitter()
# Capture the block id from the wire event rather than reaching
# into emitter internals — the id on the emitted Start/Delta is
# what the frontend actually receives.
start_events = emitter.on_delta(_delta(reasoning="x"))
first_id = start_events[0].id
events = emitter.close()
assert len(events) == 1
assert isinstance(events[0], StreamReasoningEnd)
assert events[0].id == first_id
assert emitter.is_open is False
# Next reasoning uses a fresh id.
new_events = emitter.on_delta(_delta(reasoning="y"))
assert isinstance(new_events[0], StreamReasoningStart)
assert new_events[0].id != first_id
def test_close_is_idempotent(self):
emitter = BaselineReasoningEmitter()
assert emitter.close() == []
emitter.on_delta(_delta(reasoning="x"))
assert len(emitter.close()) == 1
assert emitter.close() == []
def test_structured_details_round_trip(self):
emitter = BaselineReasoningEmitter()
events = emitter.on_delta(
_delta(
reasoning_details=[
{"type": "reasoning.text", "text": "plan: "},
{"type": "reasoning.summary", "summary": "do the thing"},
]
)
)
deltas = [e for e in events if isinstance(e, StreamReasoningDelta)]
assert len(deltas) == 1
assert deltas[0].delta == "plan: do the thing"
class TestReasoningPersistence:
"""The persistence contract: without ``role="reasoning"`` rows in
session.messages, useHydrateOnStreamEnd overwrites the live-streamed
reasoning parts and the Reasoning collapse vanishes. Every delta
must be reflected in the persisted row the moment it's emitted."""
def test_session_row_appended_on_first_delta(self):
session: list[ChatMessage] = []
emitter = BaselineReasoningEmitter(session)
assert session == []
emitter.on_delta(_delta(reasoning="hi"))
assert len(session) == 1
assert session[0].role == "reasoning"
assert session[0].content == "hi"
def test_subsequent_deltas_mutate_same_row(self):
session: list[ChatMessage] = []
emitter = BaselineReasoningEmitter(session)
emitter.on_delta(_delta(reasoning="part one "))
emitter.on_delta(_delta(reasoning="part two"))
assert len(session) == 1
assert session[0].content == "part one part two"
def test_close_keeps_row_in_session(self):
session: list[ChatMessage] = []
emitter = BaselineReasoningEmitter(session)
emitter.on_delta(_delta(reasoning="thought"))
emitter.close()
assert len(session) == 1
assert session[0].content == "thought"
def test_second_reasoning_block_appends_new_row(self):
session: list[ChatMessage] = []
emitter = BaselineReasoningEmitter(session)
emitter.on_delta(_delta(reasoning="first"))
emitter.close()
emitter.on_delta(_delta(reasoning="second"))
assert len(session) == 2
assert [m.content for m in session] == ["first", "second"]
def test_no_session_means_no_persistence(self):
"""Emitter without attached session list emits wire events only."""
emitter = BaselineReasoningEmitter()
events = emitter.on_delta(_delta(reasoning="pure wire"))
assert len(events) == 2 # start + delta, no crash
# Nothing else to assert — just proves None session is supported.

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,7 @@
"""Integration tests for baseline transcript flow.
Exercises the real helpers in ``baseline/service.py`` that download,
validate, load, append to, backfill, and upload the transcript.
Exercises the real helpers in ``baseline/service.py`` that restore,
validate, load, append to, backfill, and upload the CLI session.
Storage is mocked via ``download_transcript`` / ``upload_transcript``
patches; no network access is required.
"""
@@ -12,13 +12,14 @@ from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.baseline.service import (
_append_gap_to_builder,
_load_prior_transcript,
_record_turn_to_transcript,
_resolve_baseline_model,
_upload_final_transcript,
is_transcript_stale,
should_upload_transcript,
)
from backend.copilot.model import ChatMessage
from backend.copilot.service import config
from backend.copilot.transcript import (
STOP_REASON_END_TURN,
@@ -54,106 +55,128 @@ def _make_transcript_content(*roles: str) -> str:
return "\n".join(lines) + "\n"
def _make_session_messages(*roles: str) -> list[ChatMessage]:
"""Build a list of ChatMessage objects matching the given roles."""
return [
ChatMessage(role=r, content=f"{r} message {i}") for i, r in enumerate(roles)
]
class TestResolveBaselineModel:
"""Model selection honours the per-request mode."""
"""Baseline model resolution honours the per-request tier toggle."""
def test_fast_mode_selects_fast_model(self):
assert _resolve_baseline_model("fast") == config.fast_model
def test_advanced_tier_selects_advanced_model(self):
assert _resolve_baseline_model("advanced") == config.advanced_model
def test_extended_thinking_selects_default_model(self):
assert _resolve_baseline_model("extended_thinking") == config.model
def test_standard_tier_selects_default_model(self):
assert _resolve_baseline_model("standard") == config.model
def test_none_mode_selects_default_model(self):
"""Critical: baseline users without a mode MUST keep the default (opus)."""
def test_none_tier_selects_default_model(self):
"""Baseline users without a tier MUST keep the default (standard)."""
assert _resolve_baseline_model(None) == config.model
def test_default_and_fast_models_same(self):
"""SDK defaults currently keep standard and fast on Sonnet 4.6."""
assert config.model == config.fast_model
def test_standard_and_advanced_models_differ(self):
"""Advanced tier defaults to a different (Opus) model than standard."""
assert config.model != config.advanced_model
class TestLoadPriorTranscript:
"""``_load_prior_transcript`` wraps the download + validate + load flow."""
"""``_load_prior_transcript`` wraps the CLI session restore + validate + load flow."""
@pytest.mark.asyncio
async def test_loads_fresh_transcript(self):
builder = TranscriptBuilder()
content = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=content, message_count=2)
restore = TranscriptDownload(
content=content.encode("utf-8"), message_count=2, mode="sdk"
)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
session_messages=_make_session_messages("user", "assistant", "user"),
transcript_builder=builder,
)
assert covers is True
assert dl is not None
assert dl.message_count == 2
assert builder.entry_count == 2
assert builder.last_entry_type == "assistant"
@pytest.mark.asyncio
async def test_rejects_stale_transcript(self):
"""msg_count strictly less than session-1 is treated as stale."""
async def test_fills_gap_when_transcript_is_behind(self):
"""When transcript covers fewer messages than session, gap is filled from DB."""
builder = TranscriptBuilder()
content = _make_transcript_content("user", "assistant")
# session has 6 messages, transcript only covers 2 → stale.
download = TranscriptDownload(content=content, message_count=2)
# transcript covers 2 messages, session has 4 (plus current user turn = 5)
restore = TranscriptDownload(
content=content.encode("utf-8"), message_count=2, mode="baseline"
)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=6,
session_messages=_make_session_messages(
"user", "assistant", "user", "assistant", "user"
),
transcript_builder=builder,
)
assert covers is False
assert builder.is_empty
assert covers is True
assert dl is not None
# 2 from transcript + 2 gap messages (user+assistant at positions 2,3)
assert builder.entry_count == 4
@pytest.mark.asyncio
async def test_missing_transcript_returns_false(self):
async def test_missing_transcript_allows_upload(self):
"""Nothing in GCS → upload is safe; the turn writes the first snapshot."""
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=None),
):
covers = await _load_prior_transcript(
upload_safe, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
session_messages=_make_session_messages("user", "assistant"),
transcript_builder=builder,
)
assert covers is False
assert upload_safe is True
assert dl is None
assert builder.is_empty
@pytest.mark.asyncio
async def test_invalid_transcript_returns_false(self):
async def test_invalid_transcript_allows_upload(self):
"""Corrupt file in GCS → overwriting with a valid one is better."""
builder = TranscriptBuilder()
download = TranscriptDownload(
content='{"type":"progress","uuid":"a"}\n',
restore = TranscriptDownload(
content=b'{"type":"progress","uuid":"a"}\n',
message_count=1,
mode="sdk",
)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
upload_safe, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
session_messages=_make_session_messages("user", "assistant"),
transcript_builder=builder,
)
assert covers is False
assert upload_safe is True
assert dl is None
assert builder.is_empty
@pytest.mark.asyncio
@@ -163,36 +186,39 @@ class TestLoadPriorTranscript:
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(side_effect=RuntimeError("boom")),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
session_messages=_make_session_messages("user", "assistant"),
transcript_builder=builder,
)
assert covers is False
assert dl is None
assert builder.is_empty
@pytest.mark.asyncio
async def test_zero_message_count_not_stale(self):
"""When msg_count is 0 (unknown), staleness check is skipped."""
"""When msg_count is 0 (unknown), gap detection is skipped."""
builder = TranscriptBuilder()
download = TranscriptDownload(
content=_make_transcript_content("user", "assistant"),
restore = TranscriptDownload(
content=_make_transcript_content("user", "assistant").encode("utf-8"),
message_count=0,
mode="sdk",
)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=20,
session_messages=_make_session_messages(*["user"] * 20),
transcript_builder=builder,
)
assert covers is True
assert dl is not None
assert builder.entry_count == 2
@@ -227,7 +253,7 @@ class TestUploadFinalTranscript:
assert call_kwargs["user_id"] == "user-1"
assert call_kwargs["session_id"] == "session-1"
assert call_kwargs["message_count"] == 2
assert "hello" in call_kwargs["content"]
assert b"hello" in call_kwargs["content"]
@pytest.mark.asyncio
async def test_skips_upload_when_builder_empty(self):
@@ -374,17 +400,19 @@ class TestRoundTrip:
@pytest.mark.asyncio
async def test_full_round_trip(self):
prior = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=prior, message_count=2)
restore = TranscriptDownload(
content=prior.encode("utf-8"), message_count=2, mode="sdk"
)
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, _ = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
session_messages=_make_session_messages("user", "assistant", "user"),
transcript_builder=builder,
)
assert covers is True
@@ -424,11 +452,11 @@ class TestRoundTrip:
upload_mock.assert_awaited_once()
assert upload_mock.await_args is not None
uploaded = upload_mock.await_args.kwargs["content"]
assert "new question" in uploaded
assert "new answer" in uploaded
assert b"new question" in uploaded
assert b"new answer" in uploaded
# Original content preserved in the round trip.
assert "user message 0" in uploaded
assert "assistant message 1" in uploaded
assert b"user message 0" in uploaded
assert b"assistant message 1" in uploaded
@pytest.mark.asyncio
async def test_backfill_append_guard(self):
@@ -459,36 +487,6 @@ class TestRoundTrip:
assert builder.entry_count == initial_count
class TestIsTranscriptStale:
"""``is_transcript_stale`` gates prior-transcript loading."""
def test_none_download_is_not_stale(self):
assert is_transcript_stale(None, session_msg_count=5) is False
def test_zero_message_count_is_not_stale(self):
"""Legacy transcripts without msg_count tracking must remain usable."""
dl = TranscriptDownload(content="", message_count=0)
assert is_transcript_stale(dl, session_msg_count=20) is False
def test_stale_when_covers_less_than_prefix(self):
dl = TranscriptDownload(content="", message_count=2)
# session has 6 messages; transcript must cover at least 5 (6-1).
assert is_transcript_stale(dl, session_msg_count=6) is True
def test_fresh_when_covers_full_prefix(self):
dl = TranscriptDownload(content="", message_count=5)
assert is_transcript_stale(dl, session_msg_count=6) is False
def test_fresh_when_exceeds_prefix(self):
"""Race: transcript ahead of session count is still acceptable."""
dl = TranscriptDownload(content="", message_count=10)
assert is_transcript_stale(dl, session_msg_count=6) is False
def test_boundary_equal_to_prefix_minus_one(self):
dl = TranscriptDownload(content="", message_count=5)
assert is_transcript_stale(dl, session_msg_count=6) is False
class TestShouldUploadTranscript:
"""``should_upload_transcript`` gates the final upload."""
@@ -510,7 +508,7 @@ class TestShouldUploadTranscript:
class TestTranscriptLifecycle:
"""End-to-end: download → validate → build → upload.
"""End-to-end: restore → validate → build → upload.
Simulates the full transcript lifecycle inside
``stream_chat_completion_baseline`` by mocking the storage layer and
@@ -519,27 +517,29 @@ class TestTranscriptLifecycle:
@pytest.mark.asyncio
async def test_full_lifecycle_happy_path(self):
"""Fresh download, append a turn, upload covers the session."""
"""Fresh restore, append a turn, upload covers the session."""
builder = TranscriptBuilder()
prior = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=prior, message_count=2)
restore = TranscriptDownload(
content=prior.encode("utf-8"), message_count=2, mode="sdk"
)
upload_mock = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
),
patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
),
):
# --- 1. Download & load prior transcript ---
covers = await _load_prior_transcript(
# --- 1. Restore & load prior session ---
covers, _ = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
session_messages=_make_session_messages("user", "assistant", "user"),
transcript_builder=builder,
)
assert covers is True
@@ -559,10 +559,7 @@ class TestTranscriptLifecycle:
# --- 3. Gate + upload ---
assert (
should_upload_transcript(
user_id="user-1", transcript_covers_prefix=covers
)
is True
should_upload_transcript(user_id="user-1", upload_safe=covers) is True
)
await _upload_final_transcript(
user_id="user-1",
@@ -574,20 +571,21 @@ class TestTranscriptLifecycle:
upload_mock.assert_awaited_once()
assert upload_mock.await_args is not None
uploaded = upload_mock.await_args.kwargs["content"]
assert "follow-up question" in uploaded
assert "follow-up answer" in uploaded
assert b"follow-up question" in uploaded
assert b"follow-up answer" in uploaded
# Original prior-turn content preserved.
assert "user message 0" in uploaded
assert "assistant message 1" in uploaded
assert b"user message 0" in uploaded
assert b"assistant message 1" in uploaded
@pytest.mark.asyncio
async def test_lifecycle_stale_download_suppresses_upload(self):
"""Stale download → covers=False → upload must be skipped."""
async def test_lifecycle_stale_download_fills_gap(self):
"""When transcript covers fewer messages, gap is filled rather than rejected."""
builder = TranscriptBuilder()
# session has 10 msgs but stored transcript only covers 2 → stale.
# session has 5 msgs but stored transcript only covers 2 → gap filled.
stale = TranscriptDownload(
content=_make_transcript_content("user", "assistant"),
content=_make_transcript_content("user", "assistant").encode("utf-8"),
message_count=2,
mode="baseline",
)
upload_mock = AsyncMock(return_value=None)
@@ -601,20 +599,18 @@ class TestTranscriptLifecycle:
new=upload_mock,
),
):
covers = await _load_prior_transcript(
covers, _ = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=10,
session_messages=_make_session_messages(
"user", "assistant", "user", "assistant", "user"
),
transcript_builder=builder,
)
assert covers is False
# The caller's gate mirrors the production path.
assert (
should_upload_transcript(user_id="user-1", transcript_covers_prefix=covers)
is False
)
upload_mock.assert_not_awaited()
assert covers is True
# Gap was filled: 2 from transcript + 2 gap messages
assert builder.entry_count == 4
@pytest.mark.asyncio
async def test_lifecycle_anonymous_user_skips_upload(self):
@@ -627,15 +623,11 @@ class TestTranscriptLifecycle:
stop_reason=STOP_REASON_END_TURN,
)
assert (
should_upload_transcript(user_id=None, transcript_covers_prefix=True)
is False
)
assert should_upload_transcript(user_id=None, upload_safe=True) is False
@pytest.mark.asyncio
async def test_lifecycle_missing_download_still_uploads_new_content(self):
"""No prior transcript → covers defaults to True in the service,
new turn should upload cleanly."""
"""No prior session → upload is safe; the turn writes the first snapshot."""
builder = TranscriptBuilder()
upload_mock = AsyncMock(return_value=None)
with (
@@ -648,20 +640,117 @@ class TestTranscriptLifecycle:
new=upload_mock,
),
):
covers = await _load_prior_transcript(
upload_safe, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=1,
session_messages=_make_session_messages("user"),
transcript_builder=builder,
)
# No download: covers is False, so the production path would
# skip upload. This protects against overwriting a future
# more-complete transcript with a single-turn snapshot.
assert covers is False
# Nothing in GCS → upload is safe so the first baseline turn
# can write the initial transcript snapshot.
assert upload_safe is True
assert dl is None
assert (
should_upload_transcript(
user_id="user-1", transcript_covers_prefix=covers
)
is False
should_upload_transcript(user_id="user-1", upload_safe=upload_safe)
is True
)
upload_mock.assert_not_awaited()
# ---------------------------------------------------------------------------
# _append_gap_to_builder
# ---------------------------------------------------------------------------
class TestAppendGapToBuilder:
"""``_append_gap_to_builder`` converts ChatMessage objects to TranscriptBuilder entries."""
def test_user_message_appended(self):
builder = TranscriptBuilder()
msgs = [ChatMessage(role="user", content="hello")]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 1
assert builder.last_entry_type == "user"
def test_assistant_text_message_appended(self):
builder = TranscriptBuilder()
msgs = [
ChatMessage(role="user", content="q"),
ChatMessage(role="assistant", content="answer"),
]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 2
assert builder.last_entry_type == "assistant"
assert "answer" in builder.to_jsonl()
def test_assistant_with_tool_calls_appended(self):
"""Assistant tool_calls are recorded as tool_use blocks in the transcript."""
builder = TranscriptBuilder()
tool_call = {
"id": "tc-1",
"type": "function",
"function": {"name": "my_tool", "arguments": '{"key":"val"}'},
}
msgs = [ChatMessage(role="assistant", content=None, tool_calls=[tool_call])]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 1
jsonl = builder.to_jsonl()
assert "tool_use" in jsonl
assert "my_tool" in jsonl
assert "tc-1" in jsonl
def test_assistant_invalid_json_args_uses_empty_dict(self):
"""Malformed JSON in tool_call arguments falls back to {}."""
builder = TranscriptBuilder()
tool_call = {
"id": "tc-bad",
"type": "function",
"function": {"name": "bad_tool", "arguments": "not-json"},
}
msgs = [ChatMessage(role="assistant", content=None, tool_calls=[tool_call])]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 1
jsonl = builder.to_jsonl()
assert '"input":{}' in jsonl
def test_assistant_empty_content_and_no_tools_uses_fallback(self):
"""Assistant with no content and no tool_calls gets a fallback empty text block."""
builder = TranscriptBuilder()
msgs = [ChatMessage(role="assistant", content=None)]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 1
jsonl = builder.to_jsonl()
assert "text" in jsonl
def test_tool_role_with_tool_call_id_appended(self):
"""Tool result messages are appended when tool_call_id is set."""
builder = TranscriptBuilder()
# Need a preceding assistant tool_use entry
builder.append_user("use tool")
builder.append_assistant(
content_blocks=[
{"type": "tool_use", "id": "tc-1", "name": "my_tool", "input": {}}
]
)
msgs = [ChatMessage(role="tool", tool_call_id="tc-1", content="result")]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 3
assert "tool_result" in builder.to_jsonl()
def test_tool_role_without_tool_call_id_skipped(self):
"""Tool messages without tool_call_id are silently skipped."""
builder = TranscriptBuilder()
msgs = [ChatMessage(role="tool", tool_call_id=None, content="orphan")]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 0
def test_tool_call_missing_function_key_uses_unknown_name(self):
"""A tool_call dict with no 'function' key uses 'unknown' as the tool name."""
builder = TranscriptBuilder()
# Tool call dict exists but 'function' sub-dict is missing entirely
msgs = [
ChatMessage(role="assistant", content=None, tool_calls=[{"id": "tc-x"}])
]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 1
jsonl = builder.to_jsonl()
assert "unknown" in jsonl

View File

@@ -17,8 +17,8 @@ from backend.util.clients import OPENROUTER_BASE_URL
CopilotMode = Literal["fast", "extended_thinking"]
# Per-request model tier set by the frontend model toggle.
# 'standard' uses the global config default (currently Sonnet).
# 'advanced' forces the highest-capability model (currently Opus).
# 'standard' uses ``ChatConfig.model`` (Sonnet by default).
# 'advanced' uses ``ChatConfig.advanced_model`` (Opus by default).
# None means no preference — falls through to LD per-user targeting, then config.
# Using tier names instead of model names keeps the contract model-agnostic.
CopilotLlmModel = Literal["standard", "advanced"]
@@ -27,16 +27,21 @@ CopilotLlmModel = Literal["standard", "advanced"]
class ChatConfig(BaseSettings):
"""Configuration for the chat system."""
# OpenAI API Configuration
# Chat model tiers — applied orthogonally to the path (fast=baseline vs
# extended_thinking=SDK). The "fast" vs "extended_thinking" toggle picks
# which code path runs (no reasoning / heavy SDK); "standard" vs
# "advanced" picks the model inside that path.
model: str = Field(
default="anthropic/claude-sonnet-4-6",
description="Default model for extended thinking mode. "
"Uses Sonnet 4.6 as the balanced default. "
"Override via CHAT_MODEL env var if you want a different default.",
description="Model used for the 'standard' tier (Sonnet by default). "
"Applies to both baseline (fast) and SDK (extended thinking) paths. "
"Override via CHAT_MODEL env var.",
)
fast_model: str = Field(
default="anthropic/claude-sonnet-4-6",
description="Model for fast mode (baseline path). Should be faster/cheaper than the default model.",
advanced_model: str = Field(
default="anthropic/claude-opus-4-7",
description="Model used for the 'advanced' tier (Opus by default). "
"Applies to both baseline (fast) and SDK (extended thinking) paths. "
"Override via CHAT_ADVANCED_MODEL env var.",
)
title_model: str = Field(
default="openai/gpt-4o-mini",
@@ -96,25 +101,31 @@ class ChatConfig(BaseSettings):
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
)
# Rate limiting — token-based limits per day and per week.
# Per-turn token cost varies with context size: ~10-15K for early turns,
# ~30-50K mid-session, up to ~100K pre-compaction. Average across a
# session with compaction cycles is ~25-35K tokens/turn, so 2.5M daily
# allows ~70-100 turns/day.
# Rate limiting — cost-based limits per day and per week, stored in
# microdollars (1 USD = 1_000_000). The counter tracks the real
# generation cost reported by the provider (OpenRouter ``usage.cost``
# or Claude Agent SDK ``total_cost_usd``), so cache discounts and
# cross-model price differences are already reflected — no token
# weighting or model multiplier is applied on top.
# Checked at the HTTP layer (routes.py) before each turn.
#
# These are base limits for the FREE tier. Higher tiers (PRO, BUSINESS,
# These are base limits for the FREE tier. Higher tiers (PRO, BUSINESS,
# ENTERPRISE) multiply these by their tier multiplier (see
# rate_limit.TIER_MULTIPLIERS). User tier is stored in the
# rate_limit.TIER_MULTIPLIERS). User tier is stored in the
# User.subscriptionTier DB column and resolved inside
# get_global_rate_limits().
daily_token_limit: int = Field(
default=2_500_000,
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
#
# These defaults act as the ceiling when LaunchDarkly is unreachable;
# the live per-tier values come from the COPILOT_*_COST_LIMIT flags.
daily_cost_limit_microdollars: int = Field(
default=1_000_000,
description="Max cost per day in microdollars, resets at midnight UTC "
"(0 = unlimited).",
)
weekly_token_limit: int = Field(
default=12_500_000,
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
weekly_cost_limit_microdollars: int = Field(
default=5_000_000,
description="Max cost per week in microdollars, resets Monday 00:00 UTC "
"(0 = unlimited).",
)
# Cost (in credits / cents) to reset the daily rate limit using credits.
@@ -183,9 +194,11 @@ class ChatConfig(BaseSettings):
default=8192,
ge=1024,
le=128000,
description="Maximum thinking/reasoning tokens per LLM call. "
"Extended thinking on Opus can generate 50k+ tokens at $75/M — "
"capping this is the single biggest cost lever. "
description="Maximum thinking/reasoning tokens per LLM call. Applies "
"to both the Claude Agent SDK path (as ``max_thinking_tokens``) and "
"the baseline OpenRouter path (as ``extra_body.reasoning.max_tokens`` "
"on Anthropic routes). Extended thinking on Opus can generate 50k+ "
"tokens at $75/M — capping this is the single biggest cost lever. "
"8192 is sufficient for most tasks; increase for complex reasoning.",
)
claude_agent_thinking_effort: Literal["low", "medium", "high", "max"] | None = (
@@ -214,6 +227,18 @@ class ChatConfig(BaseSettings):
"from the prefix. Set to False to fall back to passing the system "
"prompt as a raw string.",
)
baseline_prompt_cache_ttl: str = Field(
default="1h",
description="TTL for the ephemeral prompt-cache markers on the baseline "
"OpenRouter path. Anthropic supports only `5m` (default, 1.25x input "
"price for the write) or `1h` (2x input price for the write). 1h is "
"strictly cheaper overall when the static prefix gets >7 reads per "
"write-window; since the system prompt + tools array is identical "
"across all users in our workspace, 1h is the default so cross-user "
"reads amortise the higher write cost. Anthropic has no longer "
"(24h, permanent) TTL option — see "
"https://platform.claude.com/docs/en/build-with-claude/prompt-caching.",
)
claude_agent_cli_path: str | None = Field(
default=None,
description="Optional explicit path to a Claude Code CLI binary. "

View File

@@ -9,6 +9,11 @@ COPILOT_RETRYABLE_ERROR_PREFIX = (
)
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
# Canonical marker appended as an assistant ChatMessage when the SDK stream
# ends without a ResultMessage (user hit Stop). Checked by exact equality
# at turn start so the next turn's --resume transcript doesn't carry it.
STOPPED_BY_USER_MARKER = f"{COPILOT_SYSTEM_PREFIX} Execution stopped by user"
# Prefix for all synthetic IDs generated by CoPilot block execution.
# Used to distinguish CoPilot-generated records from real graph execution records
# in PendingHumanReview and other tables.
@@ -27,6 +32,24 @@ COMPACTION_DONE_MSG = "Earlier messages were summarized to fit within context li
COMPACTION_TOOL_NAME = "context_compaction"
# ---------------------------------------------------------------------------
# Tool / stream timing budget
# ---------------------------------------------------------------------------
# Max seconds any single MCP tool call may block the stream before returning
# a "still running" handle. Shared by run_agent (wait_for_result),
# view_agent_output (wait_if_running), run_sub_session (wait_for_result),
# get_sub_session_result (wait_if_running), and run_block (hard cap).
#
# Chosen so the stream idle timeout (2× this) always has headroom — a tool
# that returns right at the cap can't race the idle watchdog.
MAX_TOOL_WAIT_SECONDS = 5 * 60 # 5 minutes
# Idle-stream watchdog: abort the SDK stream if no meaningful event arrives
# for this long. Derived from MAX_TOOL_WAIT_SECONDS so the invariant
# "no tool blocks >= idle_timeout" holds by construction.
STREAM_IDLE_TIMEOUT_SECONDS = MAX_TOOL_WAIT_SECONDS * 2 # 10 minutes
def is_copilot_synthetic_id(id_value: str) -> bool:
"""Check if an ID is a CoPilot synthetic ID (not from a real graph execution)."""
return id_value.startswith(COPILOT_SYNTHETIC_ID_PREFIX)

View File

@@ -23,7 +23,7 @@ if TYPE_CHECKING:
# Allowed base directory for the Read tool. Public so service.py can use it
# for sweep operations without depending on a private implementation detail.
# Respects CLAUDE_CONFIG_DIR env var, consistent with transcript.py's
# _projects_base() function.
# projects_base() function.
_config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
SDK_PROJECTS_DIR = os.path.realpath(os.path.join(_config_dir, "projects"))

View File

@@ -10,9 +10,11 @@ from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from prisma.types import (
ChatMessageCreateInput,
ChatMessageWhereInput,
ChatSessionCreateInput,
ChatSessionUpdateInput,
ChatSessionWhereInput,
FindManyChatMessageArgsFromChatSession,
)
from pydantic import BaseModel
@@ -30,6 +32,8 @@ from .model import get_chat_session as get_chat_session_cached
logger = logging.getLogger(__name__)
_BOUNDARY_SCAN_LIMIT = 10
class PaginatedMessages(BaseModel):
"""Result of a paginated message query."""
@@ -69,12 +73,10 @@ async def get_chat_messages_paginated(
in parallel with the message query. Returns ``None`` when the session
is not found or does not belong to the user.
Args:
session_id: The chat session ID.
limit: Max messages to return.
before_sequence: Cursor — return messages with sequence < this value.
user_id: If provided, filters via ``Session.userId`` so only the
session owner's messages are returned (acts as an ownership guard).
After fetching, a visibility guarantee ensures the page contains at least
one user or assistant message. If the entire page is tool messages (which
are hidden in the UI), it expands backward until a visible message is found
so the chat never appears blank.
"""
# Build session-existence / ownership check
session_where: ChatSessionWhereInput = {"id": session_id}
@@ -82,7 +84,7 @@ async def get_chat_messages_paginated(
session_where["userId"] = user_id
# Build message include — fetch paginated messages in the same query
msg_include: dict[str, Any] = {
msg_include: FindManyChatMessageArgsFromChatSession = {
"order_by": {"sequence": "desc"},
"take": limit + 1,
}
@@ -111,42 +113,18 @@ async def get_chat_messages_paginated(
# expand backward to include the preceding assistant message that
# owns the tool_calls, so convertChatSessionMessagesToUiMessages
# can pair them correctly.
_BOUNDARY_SCAN_LIMIT = 10
if results and results[0].role == "tool":
boundary_where: dict[str, Any] = {
"sessionId": session_id,
"sequence": {"lt": results[0].sequence},
}
if user_id is not None:
boundary_where["Session"] = {"is": {"userId": user_id}}
extra = await PrismaChatMessage.prisma().find_many(
where=boundary_where,
order={"sequence": "desc"},
take=_BOUNDARY_SCAN_LIMIT,
results, has_more = await _expand_tool_boundary(
session_id, results, has_more, user_id
)
# Visibility guarantee: if the entire page has no user/assistant messages
# (all tool messages), the chat would appear blank. Expand backward
# until we find at least one visible message.
if results and not any(m.role in ("user", "assistant") for m in results):
results, has_more = await _expand_for_visibility(
session_id, results, has_more, user_id
)
# Find the first non-tool message (should be the assistant)
boundary_msgs = []
found_owner = False
for msg in extra:
boundary_msgs.append(msg)
if msg.role != "tool":
found_owner = True
break
boundary_msgs.reverse()
if not found_owner:
logger.warning(
"Boundary expansion did not find owning assistant message "
"for session=%s before sequence=%s (%d msgs scanned)",
session_id,
results[0].sequence,
len(extra),
)
if boundary_msgs:
results = boundary_msgs + results
# Only mark has_more if the expanded boundary isn't the
# very start of the conversation (sequence 0).
if boundary_msgs[0].sequence > 0:
has_more = True
messages = [ChatMessage.from_db(m) for m in results]
oldest_sequence = messages[0].sequence if messages else None
@@ -159,6 +137,98 @@ async def get_chat_messages_paginated(
)
async def _expand_tool_boundary(
session_id: str,
results: list[Any],
has_more: bool,
user_id: str | None,
) -> tuple[list[Any], bool]:
"""Expand backward from the oldest message to include the owning assistant
message when the page starts mid-tool-group."""
boundary_where: ChatMessageWhereInput = {
"sessionId": session_id,
"sequence": {"lt": results[0].sequence},
}
if user_id is not None:
boundary_where["Session"] = {"is": {"userId": user_id}}
extra = await PrismaChatMessage.prisma().find_many(
where=boundary_where,
order={"sequence": "desc"},
take=_BOUNDARY_SCAN_LIMIT,
)
# Find the first non-tool message (should be the assistant)
boundary_msgs = []
found_owner = False
for msg in extra:
boundary_msgs.append(msg)
if msg.role != "tool":
found_owner = True
break
boundary_msgs.reverse()
if not found_owner:
logger.warning(
"Boundary expansion did not find owning assistant message "
"for session=%s before sequence=%s (%d msgs scanned)",
session_id,
results[0].sequence,
len(extra),
)
if boundary_msgs:
results = boundary_msgs + results
has_more = boundary_msgs[0].sequence > 0
return results, has_more
_VISIBILITY_EXPAND_LIMIT = 200
async def _expand_for_visibility(
session_id: str,
results: list[Any],
has_more: bool,
user_id: str | None,
) -> tuple[list[Any], bool]:
"""Expand backward until the page contains at least one user or assistant
message, so the chat is never blank."""
expand_where: ChatMessageWhereInput = {
"sessionId": session_id,
"sequence": {"lt": results[0].sequence},
}
if user_id is not None:
expand_where["Session"] = {"is": {"userId": user_id}}
extra = await PrismaChatMessage.prisma().find_many(
where=expand_where,
order={"sequence": "desc"},
take=_VISIBILITY_EXPAND_LIMIT,
)
if not extra:
return results, has_more
# Collect messages until we find a visible one (user/assistant)
prepend = []
found_visible = False
for msg in extra:
prepend.append(msg)
if msg.role in ("user", "assistant"):
found_visible = True
break
if not found_visible:
logger.warning(
"Visibility expansion did not find any user/assistant message "
"for session=%s before sequence=%s (%d msgs scanned)",
session_id,
results[0].sequence,
len(extra),
)
prepend.reverse()
if prepend:
results = prepend + results
has_more = prepend[0].sequence > 0
return results, has_more
async def create_chat_session(
session_id: str,
user_id: str,

View File

@@ -175,6 +175,138 @@ async def test_no_where_on_messages_without_before_sequence(
assert "where" not in include["Messages"]
# ---------- Visibility guarantee ----------
@pytest.mark.asyncio
async def test_visibility_expands_when_all_tool_messages(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""When the entire page is tool messages, expand backward to find
at least one visible (user/assistant) message so the chat isn't blank."""
find_first, find_many = mock_db
# Newest 3 messages are all tool messages (DESC → reversed to ASC)
find_first.return_value = _make_session(
messages=[
_make_msg(12, role="tool"),
_make_msg(11, role="tool"),
_make_msg(10, role="tool"),
],
)
# Boundary expansion finds the owning assistant first (boundary fix),
# then visibility expansion finds a user message further back
find_many.side_effect = [
# First call: boundary fix (oldest msg is tool → find owner)
[_make_msg(9, role="tool"), _make_msg(8, role="tool")],
# Second call: visibility expansion (still all tool → find visible)
[_make_msg(7, role="tool"), _make_msg(6, role="assistant")],
]
page = await get_chat_messages_paginated(SESSION_ID, limit=3)
assert page is not None
# Should include the expanded messages + original tool messages
roles = [m.role for m in page.messages]
assert "assistant" in roles or "user" in roles
assert page.has_more is True
@pytest.mark.asyncio
async def test_no_visibility_expansion_when_visible_messages_present(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""No visibility expansion needed when page already has visible messages."""
find_first, find_many = mock_db
# Page has an assistant message among tool messages
find_first.return_value = _make_session(
messages=[
_make_msg(5, role="tool"),
_make_msg(4, role="assistant"),
_make_msg(3, role="user"),
],
)
page = await get_chat_messages_paginated(SESSION_ID, limit=3)
assert page is not None
# Boundary expansion might fire (oldest is tool), but NOT visibility
assert [m.sequence for m in page.messages][0] <= 3
@pytest.mark.asyncio
async def test_visibility_no_expansion_when_no_earlier_messages(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""When the page is all tool messages but there are no earlier messages
in the DB, visibility expansion returns early without changes."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(1, role="tool"), _make_msg(0, role="tool")],
)
# Boundary expansion: no earlier messages
# Visibility expansion: no earlier messages
find_many.side_effect = [[], []]
page = await get_chat_messages_paginated(SESSION_ID, limit=2)
assert page is not None
assert all(m.role == "tool" for m in page.messages)
@pytest.mark.asyncio
async def test_visibility_expansion_reaches_seq_zero(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""When visibility expansion finds a visible message at sequence 0,
has_more should be False."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(5, role="tool"), _make_msg(4, role="tool")],
)
find_many.side_effect = [
# Boundary expansion
[_make_msg(3, role="tool")],
# Visibility expansion — finds user at seq 0
[
_make_msg(2, role="tool"),
_make_msg(1, role="tool"),
_make_msg(0, role="user"),
],
]
page = await get_chat_messages_paginated(SESSION_ID, limit=2)
assert page is not None
assert page.messages[0].role == "user"
assert page.messages[0].sequence == 0
assert page.has_more is False
@pytest.mark.asyncio
async def test_visibility_expansion_with_user_id(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""Visibility expansion passes user_id filter to the boundary query."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(10, role="tool")],
)
find_many.side_effect = [
# Boundary expansion
[_make_msg(9, role="tool")],
# Visibility expansion
[_make_msg(8, role="assistant")],
]
await get_chat_messages_paginated(SESSION_ID, limit=1, user_id="user-abc")
# Both find_many calls should include the user_id session filter
for call in find_many.call_args_list:
where = call.kwargs.get("where") or call[1].get("where")
assert "Session" in where
assert where["Session"] == {"is": {"userId": "user-abc"}}
@pytest.mark.asyncio
async def test_user_id_filter_applied_to_session_where(
mock_db: tuple[AsyncMock, AsyncMock],
@@ -329,7 +461,8 @@ async def test_boundary_expansion_warns_when_no_owner_found(
with patch("backend.copilot.db.logger") as mock_logger:
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
mock_logger.warning.assert_called_once()
# Two warnings: boundary expansion + visibility expansion (all tool msgs)
assert mock_logger.warning.call_count == 2
assert page is not None
assert page.messages[0].role == "tool"

View File

@@ -34,6 +34,7 @@ from .utils import (
CancelCoPilotEvent,
CoPilotExecutionEntry,
create_copilot_queue_config,
get_session_lock_key,
)
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
@@ -366,7 +367,7 @@ class CoPilotExecutor(AppProcess):
# Try to acquire cluster-wide lock
cluster_lock = ClusterLock(
redis=redis.get_redis(),
key=f"copilot:session:{session_id}:lock",
key=get_session_lock_key(session_id),
owner_id=self.executor_id,
timeout=settings.config.cluster_lock_timeout,
)

View File

@@ -222,6 +222,10 @@ class CoPilotProcessor:
Shuts down the workspace storage instance that belongs to this
worker's event loop, ensuring ``aiohttp.ClientSession.close()``
runs on the same loop that created the session.
Sub-AutoPilots are enqueued on the copilot_execution queue, so
rolling deploys survive via RabbitMQ redelivery — no bespoke
shutdown notifier needed.
"""
coro = shutdown_workspace_storage()
try:
@@ -342,7 +346,9 @@ class CoPilotProcessor:
# Stream chat completion and publish chunks to Redis.
# stream_and_publish wraps the raw stream with registry
# publishing (shared with collect_copilot_response).
# publishing so subscribers on the session Redis stream
# (e.g. wait_for_session_result, SSE clients) receive the
# same events as they are produced.
raw_stream = stream_fn(
session_id=entry.session_id,
message=entry.message if entry.message else None,
@@ -352,27 +358,37 @@ class CoPilotProcessor:
file_ids=entry.file_ids,
mode=effective_mode,
model=entry.model,
permissions=entry.permissions,
request_arrival_at=entry.request_arrival_at,
)
async for chunk in stream_registry.stream_and_publish(
published_stream = stream_registry.stream_and_publish(
session_id=entry.session_id,
turn_id=entry.turn_id,
stream=raw_stream,
):
if cancel.is_set():
log.info("Cancel requested, breaking stream")
break
)
# Explicit aclose() on early exit: ``async for … break`` does
# not close the generator, so GeneratorExit would never reach
# stream_chat_completion_sdk, leaving its stream lock held
# until GC eventually runs.
try:
async for chunk in published_stream:
if cancel.is_set():
log.info("Cancel requested, breaking stream")
break
# Capture StreamError so mark_session_completed receives
# the error message (stream_and_publish yields but does
# not publish StreamError — that's done by mark_session_completed).
if isinstance(chunk, StreamError):
error_msg = chunk.errorText
break
# Capture StreamError so mark_session_completed receives
# the error message (stream_and_publish yields but does
# not publish StreamError — that's done by mark_session_completed).
if isinstance(chunk, StreamError):
error_msg = chunk.errorText
break
current_time = time.monotonic()
if current_time - last_refresh >= refresh_interval:
cluster_lock.refresh()
last_refresh = current_time
current_time = time.monotonic()
if current_time - last_refresh >= refresh_interval:
cluster_lock.refresh()
last_refresh = current_time
finally:
await published_stream.aclose()
# Stream loop completed
if cancel.is_set():

View File

@@ -10,14 +10,18 @@ the real production helpers from ``processor.py`` so the routing logic
has meaningful coverage.
"""
from unittest.mock import AsyncMock, patch
import logging
import threading
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.executor.processor import (
CoPilotProcessor,
resolve_effective_mode,
resolve_use_sdk_for_mode,
)
from backend.copilot.executor.utils import CoPilotExecutionEntry, CoPilotLogMetadata
class TestResolveUseSdkForMode:
@@ -173,3 +177,101 @@ class TestResolveEffectiveMode:
) as flag_mock:
assert await resolve_effective_mode("fast", None) is None
flag_mock.assert_awaited_once()
# ---------------------------------------------------------------------------
# _execute_async aclose propagation
# ---------------------------------------------------------------------------
class _TrackedStream:
"""Minimal async-generator stand-in that records whether ``aclose``
was called, so tests can verify the processor forces explicit cleanup
of the published stream on every exit path (normal + break on cancel)."""
def __init__(self, events: list):
self._events = events
self.aclose_called = False
def __aiter__(self):
return self
async def __anext__(self):
if not self._events:
raise StopAsyncIteration
return self._events.pop(0)
async def aclose(self) -> None:
self.aclose_called = True
def _make_entry() -> CoPilotExecutionEntry:
return CoPilotExecutionEntry(
session_id="sess-1",
turn_id="turn-1",
user_id="user-1",
message="hi",
is_user_message=True,
request_arrival_at=0.0,
)
def _make_log() -> CoPilotLogMetadata:
return CoPilotLogMetadata(logger=logging.getLogger("test-copilot"))
class TestExecuteAsyncAclose:
"""``_execute_async`` must call ``aclose`` on the published stream both
when the loop exits naturally and when ``cancel`` is set mid-stream —
otherwise ``stream_chat_completion_sdk`` stays suspended and keeps
holding the per-session Redis lock until GC."""
def _patches(self, published_stream: _TrackedStream):
"""Shared mock context: patches every dependency ``_execute_async``
touches so the aclose path is the only behaviour under test."""
return [
patch(
"backend.copilot.executor.processor.ChatConfig",
return_value=MagicMock(test_mode=True, use_claude_agent_sdk=True),
),
patch(
"backend.copilot.executor.processor.stream_chat_completion_dummy",
return_value=MagicMock(),
),
patch(
"backend.copilot.executor.processor.stream_registry.stream_and_publish",
return_value=published_stream,
),
patch(
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
new=AsyncMock(),
),
]
@pytest.mark.asyncio
async def test_normal_exit_calls_aclose(self) -> None:
published = _TrackedStream(events=[MagicMock(), MagicMock()])
proc = CoPilotProcessor()
cancel = threading.Event()
cluster_lock = MagicMock()
patches = self._patches(published)
with patches[0], patches[1], patches[2], patches[3]:
await proc._execute_async(_make_entry(), cancel, cluster_lock, _make_log())
assert published.aclose_called is True
@pytest.mark.asyncio
async def test_cancel_break_calls_aclose(self) -> None:
events = [MagicMock()] # first chunk delivered, then cancel fires
published = _TrackedStream(events=events)
proc = CoPilotProcessor()
cancel = threading.Event()
cancel.set() # pre-set so the loop breaks on the first chunk
cluster_lock = MagicMock()
patches = self._patches(published)
with patches[0], patches[1], patches[2], patches[3]:
await proc._execute_async(_make_entry(), cancel, cluster_lock, _make_log())
assert published.aclose_called is True

View File

@@ -10,6 +10,7 @@ import logging
from pydantic import BaseModel
from backend.copilot.config import CopilotLlmModel, CopilotMode
from backend.copilot.permissions import CopilotPermissions
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
@@ -81,6 +82,12 @@ COPILOT_CANCEL_EXCHANGE = Exchange(
)
COPILOT_CANCEL_QUEUE_NAME = "copilot_cancel_queue"
def get_session_lock_key(session_id: str) -> str:
"""Redis key for the per-session cluster lock held by the executing pod."""
return f"copilot:session:{session_id}:lock"
# CoPilot operations can include extended thinking and agent generation
# which may take 30+ minutes to complete
COPILOT_CONSUMER_TIMEOUT_SECONDS = 60 * 60 # 1 hour
@@ -163,6 +170,20 @@ class CoPilotExecutionEntry(BaseModel):
model: CopilotLlmModel | None = None
"""Per-request model tier: 'standard' or 'advanced'. None = server default."""
permissions: CopilotPermissions | None = None
"""Capability filter inherited from a parent run (e.g. ``run_sub_session``
forwards its parent's permissions so the sub can't escalate). ``None``
means the worker applies no filter."""
request_arrival_at: float = 0.0
"""Unix-epoch seconds (server clock) when the originating HTTP
``/stream`` request arrived. The executor's turn-start drain uses
this to decide whether each pending message was typed BEFORE or AFTER
the turn's ``current`` message, and orders the combined user bubble
chronologically. Defaults to ``0.0`` for backward compatibility with
queue messages written before this field existed (they sort as "all
pending before current" — the pre-fix behaviour)."""
class CancelCoPilotEvent(BaseModel):
"""Event to cancel a CoPilot operation."""
@@ -184,6 +205,8 @@ async def enqueue_copilot_turn(
file_ids: list[str] | None = None,
mode: CopilotMode | None = None,
model: CopilotLlmModel | None = None,
permissions: CopilotPermissions | None = None,
request_arrival_at: float = 0.0,
) -> None:
"""Enqueue a CoPilot task for processing by the executor service.
@@ -197,6 +220,8 @@ async def enqueue_copilot_turn(
file_ids: Optional workspace file IDs attached to the user's message
mode: Autopilot mode override ('fast' or 'extended_thinking'). None = server default.
model: Per-request model tier ('standard' or 'advanced'). None = server default.
permissions: Capability filter inherited from a parent run (sub-AutoPilot).
None = no filter.
"""
from backend.util.clients import get_async_copilot_queue
@@ -210,6 +235,8 @@ async def enqueue_copilot_turn(
file_ids=file_ids,
mode=mode,
model=model,
permissions=permissions,
request_arrival_at=request_arrival_at,
)
queue_client = await get_async_copilot_queue()

View File

@@ -1,71 +0,0 @@
"""Per-request idempotency lock for the /stream endpoint.
Prevents duplicate executor tasks from concurrent or retried POSTs (e.g. k8s
rolling-deploy retries, nginx upstream retries, rapid double-clicks).
Lifecycle
---------
1. ``acquire()`` — computes a stable hash of (session_id, message, file_ids)
and atomically sets a Redis NX key. Returns a ``_DedupLock`` on success or
``None`` when the key already exists (duplicate request).
2. ``release()`` — deletes the key. Must be called on turn completion or turn
error so the next legitimate send is never blocked.
3. On client disconnect (``GeneratorExit``) the lock must NOT be released —
the backend turn is still running, and releasing would reopen the duplicate
window for infra-level retries. The 30 s TTL is the safety net.
"""
import hashlib
import logging
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
_KEY_PREFIX = "chat:msg_dedup"
_TTL_SECONDS = 30
class _DedupLock:
def __init__(self, key: str, redis) -> None:
self._key = key
self._redis = redis
async def release(self) -> None:
"""Best-effort key deletion. The TTL handles failures silently."""
try:
await self._redis.delete(self._key)
except Exception:
pass
async def acquire_dedup_lock(
session_id: str,
message: str | None,
file_ids: list[str] | None,
) -> _DedupLock | None:
"""Acquire the idempotency lock for this (session, message, files) tuple.
Returns a ``_DedupLock`` when the lock is freshly acquired (first request).
Returns ``None`` when a duplicate is detected (lock already held).
Returns ``None`` when there is nothing to deduplicate (no message, no files).
"""
if not message and not file_ids:
return None
sorted_ids = ":".join(sorted(file_ids or []))
content_hash = hashlib.sha256(
f"{session_id}:{message or ''}:{sorted_ids}".encode()
).hexdigest()[:16]
key = f"{_KEY_PREFIX}:{session_id}:{content_hash}"
redis = await get_redis_async()
acquired = await redis.set(key, "1", ex=_TTL_SECONDS, nx=True)
if not acquired:
logger.warning(
f"[STREAM] Duplicate user message blocked for session {session_id}, "
f"hash={content_hash} — returning empty SSE",
)
return None
return _DedupLock(key, redis)

View File

@@ -1,94 +0,0 @@
"""Unit tests for backend.copilot.message_dedup."""
from unittest.mock import AsyncMock
import pytest
import pytest_mock
from backend.copilot.message_dedup import _KEY_PREFIX, acquire_dedup_lock
def _patch_redis(mocker: pytest_mock.MockerFixture, *, set_returns):
mock_redis = AsyncMock()
mock_redis.set = AsyncMock(return_value=set_returns)
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=AsyncMock,
return_value=mock_redis,
)
return mock_redis
@pytest.mark.asyncio
async def test_acquire_returns_none_when_no_message_no_files(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Nothing to deduplicate — no Redis call made, None returned."""
mock_redis = _patch_redis(mocker, set_returns=True)
result = await acquire_dedup_lock("sess-1", None, None)
assert result is None
mock_redis.set.assert_not_called()
@pytest.mark.asyncio
async def test_acquire_returns_lock_on_first_request(
mocker: pytest_mock.MockerFixture,
) -> None:
"""First request acquires the lock and returns a _DedupLock."""
mock_redis = _patch_redis(mocker, set_returns=True)
lock = await acquire_dedup_lock("sess-1", "hello", None)
assert lock is not None
mock_redis.set.assert_called_once()
key_arg = mock_redis.set.call_args.args[0]
assert key_arg.startswith(f"{_KEY_PREFIX}:sess-1:")
@pytest.mark.asyncio
async def test_acquire_returns_none_on_duplicate(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Duplicate request (NX fails) returns None to signal the caller."""
_patch_redis(mocker, set_returns=None)
result = await acquire_dedup_lock("sess-1", "hello", None)
assert result is None
@pytest.mark.asyncio
async def test_acquire_key_stable_across_file_order(
mocker: pytest_mock.MockerFixture,
) -> None:
"""File IDs are sorted before hashing so order doesn't affect the key."""
mock_redis_1 = _patch_redis(mocker, set_returns=True)
await acquire_dedup_lock("sess-1", "msg", ["b", "a"])
key_ab = mock_redis_1.set.call_args.args[0]
mock_redis_2 = _patch_redis(mocker, set_returns=True)
await acquire_dedup_lock("sess-1", "msg", ["a", "b"])
key_ba = mock_redis_2.set.call_args.args[0]
assert key_ab == key_ba
@pytest.mark.asyncio
async def test_release_deletes_key(
mocker: pytest_mock.MockerFixture,
) -> None:
"""release() calls Redis delete exactly once."""
mock_redis = _patch_redis(mocker, set_returns=True)
lock = await acquire_dedup_lock("sess-1", "hello", None)
assert lock is not None
await lock.release()
mock_redis.delete.assert_called_once()
@pytest.mark.asyncio
async def test_release_swallows_redis_error(
mocker: pytest_mock.MockerFixture,
) -> None:
"""release() must not raise even when Redis delete fails."""
mock_redis = _patch_redis(mocker, set_returns=True)
mock_redis.delete = AsyncMock(side_effect=RuntimeError("redis down"))
lock = await acquire_dedup_lock("sess-1", "hello", None)
assert lock is not None
await lock.release() # must not raise
mock_redis.delete.assert_called_once()

View File

@@ -1,9 +1,8 @@
import asyncio
import logging
import uuid
from contextlib import asynccontextmanager
from datetime import UTC, datetime
from typing import Any, Self, cast
from weakref import WeakValueDictionary
from typing import Any, AsyncIterator, Self, cast
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
@@ -522,10 +521,7 @@ async def upsert_chat_session(
callers are aware of the persistence failure.
RedisError: If the cache write fails (after successful DB write).
"""
# Acquire session-specific lock to prevent concurrent upserts
lock = await _get_session_lock(session.session_id)
async with lock:
async with _get_session_lock(session.session_id) as _:
# Always query DB for existing message count to ensure consistency
existing_message_count = await chat_db().get_next_sequence(session.session_id)
@@ -651,20 +647,50 @@ async def _save_session_to_db(
msg.sequence = existing_message_count + i
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
async def append_and_save_message(
session_id: str, message: ChatMessage
) -> ChatSession | None:
"""Atomically append a message to a session and persist it.
Acquires the session lock, re-fetches the latest session state,
appends the message, and saves — preventing message loss when
concurrent requests modify the same session.
"""
lock = await _get_session_lock(session_id)
Returns the updated session, or None if the message was detected as a
duplicate (idempotency guard). Callers must check for None and skip any
downstream work (e.g. enqueuing a new LLM turn) when a duplicate is detected.
async with lock:
session = await get_chat_session(session_id)
Uses _get_session_lock (Redis NX) to serialise concurrent writers across replicas.
The idempotency check below provides a last-resort guard when the lock degrades.
"""
async with _get_session_lock(session_id) as lock_acquired:
# When the lock degraded (Redis down or 2s timeout), bypass cache for
# the idempotency check. Stale cache could let two concurrent writers
# both see the old state, pass the check, and write the same message.
if lock_acquired:
session = await get_chat_session(session_id)
else:
session = await _get_session_from_db(session_id)
if session is None:
raise ValueError(f"Session {session_id} not found")
# Idempotency: skip if the trailing block of same-role messages already
# contains this content. Uses is_message_duplicate which checks all
# consecutive trailing messages of the same role, not just [-1].
#
# This collapses infra/nginx retries whether they land on the same pod
# (serialised by the Redis lock) or a different pod.
#
# Legit same-text messages are distinguished by the assistant turn
# between them: if the user said "yes", got a response, and says
# "yes" again, session.messages[-1] is the assistant reply, so the
# role check fails and the second message goes through normally.
#
# Edge case: if a turn dies without writing any assistant message,
# the user's next send of the same text is blocked here permanently.
# The fix is to ensure failed turns always write an error/timeout
# assistant message so the session always ends on an assistant turn.
if message.content is not None and is_message_duplicate(
session.messages, message.role, message.content
):
return None # duplicate — caller should skip enqueue
session.messages.append(message)
existing_message_count = await chat_db().get_next_sequence(session_id)
@@ -679,6 +705,9 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
await cache_chat_session(session)
except Exception as e:
logger.warning(f"Cache write failed for session {session_id}: {e}")
# Invalidate the stale entry so future reads fall back to DB,
# preventing a retry from bypassing the idempotency check above.
await invalidate_session_cache(session_id)
return session
@@ -764,10 +793,6 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
except Exception as e:
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
# Clean up session lock (belt-and-suspenders with WeakValueDictionary)
async with _session_locks_mutex:
_session_locks.pop(session_id, None)
# Shut down any local browser daemon for this session (best-effort).
# Inline import required: all tool modules import ChatSession from this
# module, so any top-level import from tools.* would create a cycle.
@@ -832,25 +857,38 @@ async def update_session_title(
# ==================== Chat session locks ==================== #
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()
@asynccontextmanager
async def _get_session_lock(session_id: str) -> AsyncIterator[bool]:
"""Distributed Redis lock for a session, usable as an async context manager.
async def _get_session_lock(session_id: str) -> asyncio.Lock:
"""Get or create a lock for a specific session to prevent concurrent upserts.
Yields True if the lock was acquired, False if it timed out or Redis was
unavailable. Callers should treat False as a degraded mode and prefer fresh
DB reads over cache to avoid acting on stale state.
This was originally added to solve the specific problem of race conditions between
the session title thread and the conversation thread, which always occurs on the
same instance as we prevent rapid request sends on the frontend.
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
when no coroutine holds a reference to them, preventing memory leaks from
unbounded growth of session locks. Explicit cleanup also occurs
in `delete_chat_session()`.
Uses redis-py's built-in Lock (Lua-script acquire/release) so lock acquisition
is atomic and release is owner-verified. Blocks up to 2s for a concurrent
writer to finish; the 10s TTL ensures a dead pod never holds the lock forever.
"""
async with _session_locks_mutex:
lock = _session_locks.get(session_id)
if lock is None:
lock = asyncio.Lock()
_session_locks[session_id] = lock
return lock
_lock_key = f"copilot:session_lock:{session_id}"
lock = None
acquired = False
try:
_redis = await get_redis_async()
lock = _redis.lock(_lock_key, timeout=10, blocking_timeout=2)
acquired = await lock.acquire(blocking=True)
if not acquired:
logger.warning(
"Could not acquire session lock for %s within 2s", session_id
)
except Exception as e:
logger.warning("Redis unavailable for session lock on %s: %s", session_id, e)
try:
yield acquired
finally:
if acquired and lock is not None:
try:
await lock.release()
except Exception:
pass # TTL will expire the key

View File

@@ -11,11 +11,13 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
ChatCompletionMessageToolCallParam,
Function,
)
from pytest_mock import MockerFixture
from .model import (
ChatMessage,
ChatSession,
Usage,
append_and_save_message,
get_chat_session,
is_message_duplicate,
maybe_append_user_message,
@@ -574,3 +576,345 @@ def test_maybe_append_assistant_skips_duplicate():
result = maybe_append_user_message(session, "dup", is_user_message=False)
assert result is False
assert len(session.messages) == 2
# --------------------------------------------------------------------------- #
# append_and_save_message #
# --------------------------------------------------------------------------- #
def _make_session_with_messages(*msgs: ChatMessage) -> ChatSession:
s = ChatSession.new(user_id="u1", dry_run=False)
s.messages = list(msgs)
return s
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_returns_none_for_duplicate(
mocker: MockerFixture,
) -> None:
"""append_and_save_message returns None when the trailing message is a duplicate."""
session = _make_session_with_messages(
ChatMessage(role="user", content="hello"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock()
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=session,
)
result = await append_and_save_message(
session.session_id, ChatMessage(role="user", content="hello")
)
assert result is None
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_appends_new_message(
mocker: MockerFixture,
) -> None:
"""append_and_save_message appends a non-duplicate message and returns the session."""
session = _make_session_with_messages(
ChatMessage(role="user", content="hello"),
ChatMessage(role="assistant", content="hi"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock()
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=2)
),
)
mocker.patch(
"backend.copilot.model.cache_chat_session",
new_callable=mocker.AsyncMock,
)
new_msg = ChatMessage(role="user", content="second message")
result = await append_and_save_message(session.session_id, new_msg)
assert result is not None
assert result.messages[-1].content == "second message"
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_raises_when_session_not_found(
mocker: MockerFixture,
) -> None:
"""append_and_save_message raises ValueError when the session does not exist."""
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock()
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=None,
)
with pytest.raises(ValueError, match="not found"):
await append_and_save_message(
"missing-session-id", ChatMessage(role="user", content="hi")
)
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_uses_db_when_lock_degraded(
mocker: MockerFixture,
) -> None:
"""When the Redis lock times out (acquired=False), the fallback reads from DB."""
session = _make_session_with_messages(
ChatMessage(role="assistant", content="hi"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=False)
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mock_get_from_db = mocker.patch(
"backend.copilot.model._get_session_from_db",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=1)
),
)
mocker.patch(
"backend.copilot.model.cache_chat_session",
new_callable=mocker.AsyncMock,
)
new_msg = ChatMessage(role="user", content="new msg")
result = await append_and_save_message(session.session_id, new_msg)
# DB path was used (not cache-first)
mock_get_from_db.assert_called_once_with(session.session_id)
assert result is not None
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_raises_database_error_on_save_failure(
mocker: MockerFixture,
) -> None:
"""When _save_session_to_db fails, append_and_save_message raises DatabaseError."""
from backend.util.exceptions import DatabaseError
session = _make_session_with_messages(
ChatMessage(role="assistant", content="hi"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock()
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
side_effect=RuntimeError("db down"),
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=1)
),
)
with pytest.raises(DatabaseError):
await append_and_save_message(
session.session_id, ChatMessage(role="user", content="new msg")
)
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_invalidates_cache_on_cache_failure(
mocker: MockerFixture,
) -> None:
"""When cache_chat_session fails, invalidate_session_cache is called to avoid stale reads."""
session = _make_session_with_messages(
ChatMessage(role="assistant", content="hi"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock()
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=1)
),
)
mocker.patch(
"backend.copilot.model.cache_chat_session",
new_callable=mocker.AsyncMock,
side_effect=RuntimeError("redis write failed"),
)
mock_invalidate = mocker.patch(
"backend.copilot.model.invalidate_session_cache",
new_callable=mocker.AsyncMock,
)
result = await append_and_save_message(
session.session_id, ChatMessage(role="user", content="new msg")
)
# DB write succeeded, cache invalidation was called
mock_invalidate.assert_called_once_with(session.session_id)
assert result is not None
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_uses_db_when_redis_unavailable(
mocker: MockerFixture,
) -> None:
"""When get_redis_async raises, _get_session_lock yields False (degraded) and DB is read."""
session = _make_session_with_messages(
ChatMessage(role="assistant", content="hi"),
)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
side_effect=ConnectionError("redis down"),
)
mock_get_from_db = mocker.patch(
"backend.copilot.model._get_session_from_db",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=1)
),
)
mocker.patch(
"backend.copilot.model.cache_chat_session",
new_callable=mocker.AsyncMock,
)
new_msg = ChatMessage(role="user", content="new msg")
result = await append_and_save_message(session.session_id, new_msg)
mock_get_from_db.assert_called_once_with(session.session_id)
assert result is not None
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_lock_release_failure_is_ignored(
mocker: MockerFixture,
) -> None:
"""If lock.release() raises, the exception is swallowed (TTL will clean up)."""
session = _make_session_with_messages(
ChatMessage(role="assistant", content="hi"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock(
side_effect=RuntimeError("release failed")
)
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=1)
),
)
mocker.patch(
"backend.copilot.model.cache_chat_session",
new_callable=mocker.AsyncMock,
)
new_msg = ChatMessage(role="user", content="new msg")
result = await append_and_save_message(session.session_id, new_msg)
assert result is not None

View File

@@ -0,0 +1,384 @@
"""Shared helpers for draining and injecting pending messages.
Used by both the baseline and SDK copilot paths to avoid duplicating
the try/except drain, format, insert, and persist patterns.
Also provides the call-rate-limit check for the queue endpoint so
routes.py stays free of Redis/Lua details.
"""
import logging
from typing import TYPE_CHECKING, Callable
from fastapi import HTTPException
from pydantic import BaseModel
from backend.copilot.model import ChatMessage, upsert_chat_session
from backend.copilot.pending_messages import (
MAX_PENDING_MESSAGES,
PendingMessage,
PendingMessageContext,
drain_pending_messages,
format_pending_as_user_message,
push_pending_message,
)
from backend.copilot.stream_registry import get_session as get_active_session_meta
from backend.data.redis_client import get_redis_async
from backend.data.redis_helpers import incr_with_ttl
from backend.data.workspace import resolve_workspace_files
if TYPE_CHECKING:
from backend.copilot.model import ChatSession
from backend.copilot.transcript_builder import TranscriptBuilder
logger = logging.getLogger(__name__)
# Call-frequency cap for the pending-message endpoint. The token-budget
# check guards against overspend but not rapid-fire pushes from a client
# with a large budget.
PENDING_CALL_LIMIT = 30
PENDING_CALL_WINDOW_SECONDS = 60
_PENDING_CALL_KEY_PREFIX = "copilot:pending:calls:"
async def is_turn_in_flight(session_id: str) -> bool:
"""Return ``True`` when a copilot turn is actively running for *session_id*.
Used by the unified POST /stream entry point and the autopilot block so
a second message arriving while an earlier turn is still executing gets
queued into the pending buffer instead of racing the in-flight turn on
the cluster lock.
"""
active = await get_active_session_meta(session_id)
return active is not None and active.status == "running"
class QueuePendingMessageResponse(BaseModel):
"""Response returned by ``POST /stream`` with status 202 when a message
is queued because the session already has a turn in flight.
- ``buffer_length``: how many messages are now in the session's
pending buffer (after this push)
- ``max_buffer_length``: the per-session cap (server-side constant)
- ``turn_in_flight``: ``True`` if a copilot turn was running when
we checked — purely informational for UX feedback. Always ``True``
for responses from ``POST /stream`` with status 202.
"""
buffer_length: int
max_buffer_length: int
turn_in_flight: bool
async def queue_user_message(
*,
session_id: str,
message: str,
context: PendingMessageContext | None = None,
file_ids: list[str] | None = None,
) -> QueuePendingMessageResponse:
"""Push *message* into the per-session pending buffer.
The shared primitive for "a message arrived while a turn is in flight"
called from the unified POST /stream handler and the autopilot block.
Call-frequency rate limiting is the caller's responsibility (HTTP path
enforces it; internal block callers skip it).
"""
pending = PendingMessage(
content=message,
file_ids=file_ids or [],
context=context,
)
new_len = await push_pending_message(session_id, pending)
return QueuePendingMessageResponse(
buffer_length=new_len,
max_buffer_length=MAX_PENDING_MESSAGES,
turn_in_flight=await is_turn_in_flight(session_id),
)
async def queue_pending_for_http(
*,
session_id: str,
user_id: str,
message: str,
context: dict[str, str] | None,
file_ids: list[str] | None,
) -> QueuePendingMessageResponse:
"""HTTP-facing wrapper around :func:`queue_user_message`.
Owns the HTTP-only concerns that sat inline in ``stream_chat_post``:
1. Per-user call-rate cap (429 on overflow).
2. File-ID sanitisation against the user's own workspace.
3. ``{url, content}`` dict → ``PendingMessageContext`` coercion.
4. Push via ``queue_user_message``.
Raises :class:`HTTPException` with status 429 if the rate cap is hit;
otherwise returns the ``QueuePendingMessageResponse`` the handler can
serialise 1:1 into the 202 body.
"""
call_count = await check_pending_call_rate(user_id)
if call_count > PENDING_CALL_LIMIT:
raise HTTPException(
status_code=429,
detail=(
f"Too many queued message requests this minute: limit is "
f"{PENDING_CALL_LIMIT} per {PENDING_CALL_WINDOW_SECONDS}s "
"across all sessions"
),
)
sanitized_file_ids: list[str] | None = None
if file_ids:
files = await resolve_workspace_files(user_id, file_ids)
sanitized_file_ids = [wf.id for wf in files] or None
# ``PendingMessageContext`` uses the default ``extra='ignore'`` so
# unknown keys in the loose HTTP-level ``context`` dict are silently
# dropped rather than raising ``ValidationError`` + 500ing (sentry
# r3105553772). The strict mode would only help protect against
# typos, but the upstream ``StreamChatRequest.context: dict[str, str]``
# is already schemaless, so the strict mode adds no real safety.
queue_context = PendingMessageContext.model_validate(context) if context else None
return await queue_user_message(
session_id=session_id,
message=message,
context=queue_context,
file_ids=sanitized_file_ids,
)
async def check_pending_call_rate(user_id: str) -> int:
"""Increment and return the per-user push counter for the current window.
The counter is **user-global**: it counts pushes across ALL sessions
belonging to the user, not per-session. This prevents a client from
bypassing the cap by spreading rapid pushes across many sessions.
Returns the new call count. Raises nothing — callers compare the
return value against ``PENDING_CALL_LIMIT`` and decide what to do.
Fails open (returns 0) if Redis is unavailable so the endpoint stays
usable during Redis hiccups.
"""
try:
redis = await get_redis_async()
key = f"{_PENDING_CALL_KEY_PREFIX}{user_id}"
return await incr_with_ttl(redis, key, PENDING_CALL_WINDOW_SECONDS)
except Exception:
logger.warning(
"pending_message_helpers: call-rate check failed for user=%s, failing open",
user_id,
)
return 0
async def drain_pending_safe(
session_id: str, log_prefix: str = ""
) -> list[PendingMessage]:
"""Drain the pending buffer and return the full ``PendingMessage`` objects.
Returns ``[]`` on any Redis error so callers can always treat the
result as a plain list. Callers that only need the rendered string
(turn-start injection, auto-continue combined prompt) wrap this with
:func:`pending_texts_from` — we return the structured objects so the
re-queue rollback path can preserve ``file_ids`` / ``context`` that
would otherwise be stripped by a text-only conversion.
"""
try:
return await drain_pending_messages(session_id)
except Exception:
logger.warning(
"%s drain_pending_messages failed, skipping",
log_prefix or "pending_messages",
exc_info=True,
)
return []
def pending_texts_from(pending: list[PendingMessage]) -> list[str]:
"""Render a list of ``PendingMessage`` objects into plain text strings.
Shared helper for the two callers that need the rendered form:
turn-start injection (bundles the pending block into the user prompt)
and the auto-continue combined-message path.
"""
return [format_pending_as_user_message(pm)["content"] for pm in pending]
def combine_pending_with_current(
pending: list[PendingMessage],
current_message: str | None,
*,
request_arrival_at: float,
) -> str:
"""Order pending messages around *current_message* by typing time.
Pending messages whose ``enqueued_at`` is strictly greater than
``request_arrival_at`` were typed AFTER the user hit enter to start
the current turn (the "race" path: queued into the pending buffer
while ``/stream`` was still processing on the server). They belong
chronologically AFTER the current message.
Pending messages whose ``enqueued_at`` is less than or equal to
``request_arrival_at`` were typed BEFORE the current turn — usually
from a prior in-flight window that auto-continue didn't consume.
They belong BEFORE the current message.
Stable-sort within each bucket preserves enqueue order for messages
typed in the same phase. Legacy ``PendingMessage`` objects with no
``enqueued_at`` (written by older workers, defaulted to 0.0) sort as
"before everything" — the pre-fix behaviour, which is a safe default
for the rare queue entries that outlived a deploy.
"""
before: list[PendingMessage] = []
after: list[PendingMessage] = []
for pm in pending:
if request_arrival_at > 0 and pm.enqueued_at > request_arrival_at:
after.append(pm)
else:
before.append(pm)
parts = pending_texts_from(before)
if current_message and current_message.strip():
parts.append(current_message)
parts.extend(pending_texts_from(after))
return "\n\n".join(parts)
def insert_pending_before_last(session: "ChatSession", texts: list[str]) -> None:
"""Insert pending messages into *session* just before the last message.
Pending messages were queued during the previous turn, so they belong
chronologically before the current user message that was already
appended via ``maybe_append_user_message``. Inserting at ``len-1``
preserves that order: [...history, pending_1, pending_2, current_msg].
The caller must have already appended the current user message before
calling this function. If ``session.messages`` is unexpectedly empty,
a warning is logged and the messages are appended at index 0 so they
are not silently lost.
"""
if not texts:
return
if not session.messages:
logger.warning(
"insert_pending_before_last: session.messages is empty — "
"current user message was not appended before drain; "
"inserting pending messages at index 0"
)
insert_idx = max(0, len(session.messages) - 1)
for i, content in enumerate(texts):
session.messages.insert(
insert_idx + i, ChatMessage(role="user", content=content)
)
async def persist_session_safe(
session: "ChatSession", log_prefix: str = ""
) -> "ChatSession":
"""Persist *session* to the DB, returning the (possibly updated) session.
Swallows transient DB errors so a failing persist doesn't discard
messages already popped from Redis — the turn continues from memory.
"""
try:
return await upsert_chat_session(session)
except Exception as err:
logger.warning(
"%s Failed to persist pending messages: %s",
log_prefix or "pending_messages",
err,
)
return session
async def persist_pending_as_user_rows(
session: "ChatSession",
transcript_builder: "TranscriptBuilder",
pending: list[PendingMessage],
*,
log_prefix: str,
content_of: Callable[[PendingMessage], str] = lambda pm: pm.content,
on_rollback: Callable[[int], None] | None = None,
) -> bool:
"""Append ``pending`` as user rows to *session* + *transcript_builder*,
persist, and roll back + re-queue if the persist silently failed.
This is the shared mid-turn follow-up persist used by both the baseline
and SDK paths — they differ only in (a) how they derive the displayed
string from a ``PendingMessage`` and (b) what extra per-path state
(e.g. ``openai_messages``) needs trimming on rollback. Those variance
points are exposed as ``content_of`` and ``on_rollback``.
Flow:
1. Snapshot transcript + record the session.messages length.
2. Append one user row per pending message to both stores.
3. ``persist_session_safe`` — swallowed errors mean no sequences get
back-filled, which we use as the failure signal.
4. If any newly-appended row has ``sequence is None`` → rollback:
delete the appended rows, restore the transcript snapshot, call
``on_rollback(anchor)`` for the caller's own state, then re-push
each ``PendingMessage`` into the primary pending buffer so the
next turn-start drain picks them up.
Returns ``True`` when the rows were persisted with sequences, ``False``
when the rollback path fired. Callers can use this to decide whether
to log success or continue a retry loop.
"""
if not pending:
return True
session_anchor = len(session.messages)
transcript_snapshot = transcript_builder.snapshot()
for pm in pending:
content = content_of(pm)
session.messages.append(ChatMessage(role="user", content=content))
transcript_builder.append_user(content=content)
# ``persist_session_safe`` may return a ``model_copy`` of *session* (e.g.
# when ``upsert_chat_session`` patches a concurrently-updated title).
# Do NOT reassign the caller's reference — the caller already pushed the
# rows into its own ``session.messages`` above, and rollback below MUST
# delete from that same list. Inspect the returned object only to learn
# whether sequences were back-filled; if so, copy them onto the caller's
# objects so the session stays internally consistent for downstream
# ``append_and_save_message`` calls.
persisted = await persist_session_safe(session, log_prefix)
persisted_tail = persisted.messages[session_anchor:]
if len(persisted_tail) == len(pending) and all(
m.sequence is not None for m in persisted_tail
):
for caller_msg, persisted_msg in zip(
session.messages[session_anchor:], persisted_tail
):
caller_msg.sequence = persisted_msg.sequence
newly_appended = session.messages[session_anchor:]
if any(m.sequence is None for m in newly_appended):
logger.warning(
"%s Mid-turn follow-up persist did not back-fill sequences; "
"rolling back %d row(s) and re-queueing into the primary buffer",
log_prefix,
len(pending),
)
del session.messages[session_anchor:]
transcript_builder.restore(transcript_snapshot)
if on_rollback is not None:
on_rollback(session_anchor)
for pm in pending:
try:
await push_pending_message(session.session_id, pm)
except Exception:
logger.exception(
"%s Failed to re-queue mid-turn follow-up on rollback",
log_prefix,
)
return False
logger.info(
"%s Persisted %d mid-turn follow-up user row(s)",
log_prefix,
len(pending),
)
return True

View File

@@ -0,0 +1,472 @@
"""Unit tests for pending_message_helpers."""
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from backend.copilot import pending_message_helpers as helpers_module
from backend.copilot.pending_message_helpers import (
PENDING_CALL_LIMIT,
check_pending_call_rate,
combine_pending_with_current,
drain_pending_safe,
insert_pending_before_last,
persist_session_safe,
)
from backend.copilot.pending_messages import PendingMessage
# ── check_pending_call_rate ────────────────────────────────────────────
@pytest.mark.asyncio
async def test_check_pending_call_rate_returns_count(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(
helpers_module, "get_redis_async", AsyncMock(return_value=MagicMock())
)
monkeypatch.setattr(helpers_module, "incr_with_ttl", AsyncMock(return_value=3))
result = await check_pending_call_rate("user-1")
assert result == 3
@pytest.mark.asyncio
async def test_check_pending_call_rate_fails_open_on_redis_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(
helpers_module,
"get_redis_async",
AsyncMock(side_effect=ConnectionError("down")),
)
result = await check_pending_call_rate("user-1")
assert result == 0
@pytest.mark.asyncio
async def test_check_pending_call_rate_at_limit(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(
helpers_module, "get_redis_async", AsyncMock(return_value=MagicMock())
)
monkeypatch.setattr(
helpers_module,
"incr_with_ttl",
AsyncMock(return_value=PENDING_CALL_LIMIT + 1),
)
result = await check_pending_call_rate("user-1")
assert result > PENDING_CALL_LIMIT
# ── drain_pending_safe ─────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_drain_pending_safe_returns_pending_messages(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""``drain_pending_safe`` now returns the structured ``PendingMessage``
objects (not pre-formatted strings) so the auto-continue re-queue path
can preserve ``file_ids`` / ``context`` on rollback."""
msgs = [
PendingMessage(content="hello", file_ids=["f1"]),
PendingMessage(content="world"),
]
monkeypatch.setattr(
helpers_module, "drain_pending_messages", AsyncMock(return_value=msgs)
)
result = await drain_pending_safe("sess-1")
assert result == msgs
# Structured metadata survives — the bug r3105523410 guard.
assert result[0].file_ids == ["f1"]
@pytest.mark.asyncio
async def test_drain_pending_safe_returns_empty_on_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(
helpers_module,
"drain_pending_messages",
AsyncMock(side_effect=RuntimeError("redis down")),
)
result = await drain_pending_safe("sess-1", "[Test]")
assert result == []
@pytest.mark.asyncio
async def test_drain_pending_safe_empty_buffer(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
helpers_module, "drain_pending_messages", AsyncMock(return_value=[])
)
result = await drain_pending_safe("sess-1")
assert result == []
# ── combine_pending_with_current ───────────────────────────────────────
def test_combine_before_current_when_pending_older() -> None:
"""Pending typed before the /stream request → goes ahead of current
(prior-turn / inter-turn case)."""
pending = [
PendingMessage(content="older_a", enqueued_at=100.0),
PendingMessage(content="older_b", enqueued_at=110.0),
]
result = combine_pending_with_current(
pending, "current_msg", request_arrival_at=120.0
)
assert result == "older_a\n\nolder_b\n\ncurrent_msg"
def test_combine_after_current_when_pending_newer() -> None:
"""Pending queued AFTER the /stream request arrived → goes after
current. This is the race path where user hits enter twice in quick
succession (second press goes through the queue endpoint while the
first /stream is still processing)."""
pending = [
PendingMessage(content="race_followup", enqueued_at=125.0),
]
result = combine_pending_with_current(
pending, "current_msg", request_arrival_at=120.0
)
assert result == "current_msg\n\nrace_followup"
def test_combine_mixed_before_and_after() -> None:
"""Mixed bucket: older items first, current, then newer race items."""
pending = [
PendingMessage(content="way_older", enqueued_at=50.0),
PendingMessage(content="race_fast_follow", enqueued_at=125.0),
PendingMessage(content="also_older", enqueued_at=80.0),
]
result = combine_pending_with_current(
pending, "current_msg", request_arrival_at=120.0
)
# Enqueue order preserved within each bucket (stable partition).
assert result == "way_older\n\nalso_older\n\ncurrent_msg\n\nrace_fast_follow"
def test_combine_no_current_joins_pending() -> None:
"""Auto-continue case: no current message, just drained pending."""
pending = [PendingMessage(content="a"), PendingMessage(content="b")]
result = combine_pending_with_current(pending, None, request_arrival_at=0.0)
assert result == "a\n\nb"
def test_combine_legacy_zero_timestamp_sorts_before() -> None:
"""A ``PendingMessage`` from before this field existed (default 0.0)
should sort as "before everything" — safe pre-fix behaviour."""
pending = [PendingMessage(content="legacy", enqueued_at=0.0)]
result = combine_pending_with_current(
pending, "current_msg", request_arrival_at=120.0
)
assert result == "legacy\n\ncurrent_msg"
def test_combine_missing_request_arrival_falls_back_to_before() -> None:
"""If the HTTP handler didn't stamp ``request_arrival_at`` (0.0
default — older queue entries) the combine degrades gracefully to
the pre-fix behaviour: all pending goes before current."""
pending = [
PendingMessage(content="a", enqueued_at=500.0),
PendingMessage(content="b", enqueued_at=1000.0),
]
result = combine_pending_with_current(pending, "current", request_arrival_at=0.0)
assert result == "a\n\nb\n\ncurrent"
# ── insert_pending_before_last ─────────────────────────────────────────
def _make_session(*contents: str) -> Any:
session = MagicMock()
session.messages = [MagicMock(role="user", content=c) for c in contents]
return session
def test_insert_pending_before_last_single_existing_message() -> None:
session = _make_session("current")
insert_pending_before_last(session, ["queued"])
assert session.messages[0].content == "queued"
assert session.messages[1].content == "current"
def test_insert_pending_before_last_multiple_pending() -> None:
session = _make_session("current")
insert_pending_before_last(session, ["p1", "p2"])
contents = [m.content for m in session.messages]
assert contents == ["p1", "p2", "current"]
def test_insert_pending_before_last_empty_session() -> None:
session = _make_session()
insert_pending_before_last(session, ["queued"])
assert session.messages[0].content == "queued"
def test_insert_pending_before_last_no_texts_is_noop() -> None:
session = _make_session("current")
insert_pending_before_last(session, [])
assert len(session.messages) == 1
# ── persist_session_safe ───────────────────────────────────────────────
@pytest.mark.asyncio
async def test_persist_session_safe_returns_updated_session(
monkeypatch: pytest.MonkeyPatch,
) -> None:
original = MagicMock()
updated = MagicMock()
monkeypatch.setattr(
helpers_module, "upsert_chat_session", AsyncMock(return_value=updated)
)
result = await persist_session_safe(original, "[Test]")
assert result is updated
@pytest.mark.asyncio
async def test_persist_session_safe_returns_original_on_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
original = MagicMock()
monkeypatch.setattr(
helpers_module,
"upsert_chat_session",
AsyncMock(side_effect=Exception("db error")),
)
result = await persist_session_safe(original, "[Test]")
assert result is original
# ── persist_pending_as_user_rows ───────────────────────────────────────
class _FakeTranscript:
"""Minimal TranscriptBuilder shim — records append_user + snapshot/restore."""
def __init__(self) -> None:
self.entries: list[str] = []
def append_user(self, content: str, uuid: str | None = None) -> None:
self.entries.append(content)
def snapshot(self) -> list[str]:
return list(self.entries)
def restore(self, snap: list[str]) -> None:
self.entries = list(snap)
def _make_chat_message_class(
monkeypatch: pytest.MonkeyPatch,
) -> Any:
"""Return a simple ChatMessage stand-in that tracks sequence."""
class _Msg:
def __init__(self, role: str, content: str) -> None:
self.role = role
self.content = content
self.sequence: int | None = None
monkeypatch.setattr(helpers_module, "ChatMessage", _Msg)
return _Msg
@pytest.mark.asyncio
async def test_persist_pending_empty_list_is_noop(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from backend.copilot.pending_message_helpers import persist_pending_as_user_rows
_make_chat_message_class(monkeypatch)
session = MagicMock()
session.messages = []
tb = _FakeTranscript()
monkeypatch.setattr(helpers_module, "upsert_chat_session", AsyncMock())
monkeypatch.setattr(helpers_module, "push_pending_message", AsyncMock())
ok = await persist_pending_as_user_rows(session, tb, [], log_prefix="[T]")
assert ok is True
assert session.messages == []
assert tb.entries == []
@pytest.mark.asyncio
async def test_persist_pending_happy_path_appends_and_returns_true(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from backend.copilot.pending_message_helpers import persist_pending_as_user_rows
from backend.copilot.pending_messages import PendingMessage as PM
_make_chat_message_class(monkeypatch)
session = MagicMock()
session.session_id = "sess"
session.messages = []
tb = _FakeTranscript()
async def _fake_upsert(sess: Any) -> Any:
# Simulate the DB back-filling sequence numbers on success.
for i, m in enumerate(sess.messages):
m.sequence = i
return sess
monkeypatch.setattr(helpers_module, "upsert_chat_session", _fake_upsert)
push_mock = AsyncMock()
monkeypatch.setattr(helpers_module, "push_pending_message", push_mock)
pending = [PM(content="a"), PM(content="b")]
ok = await persist_pending_as_user_rows(session, tb, pending, log_prefix="[T]")
assert ok is True
assert [m.content for m in session.messages] == ["a", "b"]
assert tb.entries == ["a", "b"]
push_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_persist_pending_rollback_when_sequence_missing(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from backend.copilot.pending_message_helpers import persist_pending_as_user_rows
from backend.copilot.pending_messages import PendingMessage as PM
_make_chat_message_class(monkeypatch)
session = MagicMock()
session.session_id = "sess"
# Prior state — anchor point is len(messages) before the helper runs.
session.messages = []
tb = _FakeTranscript()
tb.entries = ["earlier-entry"]
async def _fake_upsert_fails_silently(sess: Any) -> Any:
# Simulate the "persist swallowed the error" branch — sequences stay None.
return sess
monkeypatch.setattr(
helpers_module, "upsert_chat_session", _fake_upsert_fails_silently
)
push_mock = AsyncMock()
monkeypatch.setattr(helpers_module, "push_pending_message", push_mock)
pending = [PM(content="a"), PM(content="b")]
ok = await persist_pending_as_user_rows(session, tb, pending, log_prefix="[T]")
assert ok is False
# Rollback: session.messages trimmed to anchor, transcript restored.
assert session.messages == []
assert tb.entries == ["earlier-entry"]
# Both pending messages re-queued.
assert push_mock.await_count == 2
assert push_mock.await_args_list[0].args[1] is pending[0]
assert push_mock.await_args_list[1].args[1] is pending[1]
@pytest.mark.asyncio
async def test_persist_pending_rollback_calls_on_rollback_hook(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Baseline's openai_messages trim runs via the on_rollback hook."""
from backend.copilot.pending_message_helpers import persist_pending_as_user_rows
from backend.copilot.pending_messages import PendingMessage as PM
_make_chat_message_class(monkeypatch)
session = MagicMock()
session.session_id = "sess"
session.messages = []
tb = _FakeTranscript()
async def _fails(sess: Any) -> Any:
return sess
monkeypatch.setattr(helpers_module, "upsert_chat_session", _fails)
monkeypatch.setattr(helpers_module, "push_pending_message", AsyncMock())
on_rollback_calls: list[int] = []
def _on_rollback(anchor: int) -> None:
on_rollback_calls.append(anchor)
await persist_pending_as_user_rows(
session,
tb,
[PM(content="x")],
log_prefix="[T]",
on_rollback=_on_rollback,
)
assert on_rollback_calls == [0]
@pytest.mark.asyncio
async def test_persist_pending_uses_custom_content_of(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from backend.copilot.pending_message_helpers import persist_pending_as_user_rows
from backend.copilot.pending_messages import PendingMessage as PM
_make_chat_message_class(monkeypatch)
session = MagicMock()
session.session_id = "sess"
session.messages = []
tb = _FakeTranscript()
async def _ok(sess: Any) -> Any:
for i, m in enumerate(sess.messages):
m.sequence = i
return sess
monkeypatch.setattr(helpers_module, "upsert_chat_session", _ok)
monkeypatch.setattr(helpers_module, "push_pending_message", AsyncMock())
await persist_pending_as_user_rows(
session,
tb,
[PM(content="raw")],
log_prefix="[T]",
content_of=lambda pm: f"FORMATTED:{pm.content}",
)
assert session.messages[0].content == "FORMATTED:raw"
assert tb.entries == ["FORMATTED:raw"]
@pytest.mark.asyncio
async def test_persist_pending_swallows_requeue_errors(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A broken push_pending_message on rollback must not raise upward —
the rollback still needs to trim state even if re-queue fails."""
from backend.copilot.pending_message_helpers import persist_pending_as_user_rows
from backend.copilot.pending_messages import PendingMessage as PM
_make_chat_message_class(monkeypatch)
session = MagicMock()
session.session_id = "sess"
session.messages = []
tb = _FakeTranscript()
async def _fails(sess: Any) -> Any:
return sess
monkeypatch.setattr(helpers_module, "upsert_chat_session", _fails)
monkeypatch.setattr(
helpers_module,
"push_pending_message",
AsyncMock(side_effect=RuntimeError("redis down")),
)
ok = await persist_pending_as_user_rows(
session, tb, [PM(content="x")], log_prefix="[T]"
)
# Still returns False (rolled back) — exception was logged + swallowed.
assert ok is False

View File

@@ -0,0 +1,450 @@
"""Pending-message buffer for in-flight copilot turns.
When a user sends a new message while a copilot turn is already executing,
instead of blocking the frontend (or queueing a brand-new turn after the
current one finishes), we want the new message to be *injected into the
running turn* — appended between tool-call rounds so the model sees it
before its next LLM call.
This module provides the cross-process buffer that makes that possible:
- **Producer** (chat API route): pushes a pending message to Redis and
publishes a notification on a pub/sub channel.
- **Consumer** (executor running the turn): on each tool-call round,
drains the buffer and appends the pending messages to the conversation.
The Redis list is the durable store; the pub/sub channel is a fast
wake-up hint for long-idle consumers (not used by default, but available
for future blocking-wait semantics).
A hard cap of ``MAX_PENDING_MESSAGES`` per session prevents abuse. The
buffer is trimmed to the latest ``MAX_PENDING_MESSAGES`` on every push.
"""
import json
import logging
import time
from typing import Any, cast
from pydantic import BaseModel, Field, ValidationError
from backend.data.redis_client import get_redis_async
from backend.data.redis_helpers import capped_rpush
logger = logging.getLogger(__name__)
# Per-session cap. Higher values risk a runaway consumer; lower values
# risk dropping user input under heavy typing. 10 was chosen as a
# reasonable ceiling — a user typing faster than the copilot can drain
# between tool rounds is already an unusual usage pattern.
MAX_PENDING_MESSAGES = 10
# Redis key + TTL. The buffer is ephemeral: if a turn completes or the
# executor dies, the pending messages should either have been drained
# already or are safe to drop (the user can resend).
_PENDING_KEY_PREFIX = "copilot:pending:"
_PENDING_CHANNEL_PREFIX = "copilot:pending:notify:"
_PENDING_TTL_SECONDS = 3600 # 1 hour — matches stream_ttl default
# Secondary queue that carries drained-but-awaiting-persist PendingMessages
# from the MCP tool wrapper (which drains the primary buffer and injects
# into tool output for the LLM) to sdk/service.py's _dispatch_response
# handler for StreamToolOutputAvailable, which pops and persists them as a
# separate user row chronologically after the tool_result row. This is the
# hand-off between "Claude saw the follow-up mid-turn" (wrapper) and "UI
# renders a user bubble for it" (service). Rollback path re-queues into
# the PRIMARY buffer so the next turn-start drain picks them up if the
# user-row persist fails.
_PERSIST_QUEUE_KEY_PREFIX = "copilot:pending-persist:"
# Payload sent on the pub/sub notify channel. Subscribers treat any
# message as a wake-up hint; the value itself is not meaningful.
_NOTIFY_PAYLOAD = "1"
class PendingMessageContext(BaseModel):
"""Structured page context attached to a pending message.
Default ``extra='ignore'`` (pydantic's default): unknown keys from
the loose HTTP-level ``StreamChatRequest.context: dict[str, str]``
are silently dropped rather than raising ``ValidationError`` on
forward-compat additions. The strict ``extra='forbid'`` mode was
removed after sentry r3105553772 — strict validation at this
boundary only added a 500 footgun; the upstream request model is
already schemaless so strict mode protects nothing.
"""
url: str | None = Field(default=None, max_length=2_000)
content: str | None = Field(default=None, max_length=32_000)
class PendingMessage(BaseModel):
"""A user message queued for injection into an in-flight turn."""
content: str = Field(min_length=1, max_length=32_000)
file_ids: list[str] = Field(default_factory=list, max_length=20)
context: PendingMessageContext | None = None
# Wall-clock time (unix seconds, float) the message was queued by the
# user. Used by the turn-start drain to order pending relative to the
# turn's ``current`` message: items typed *before* the current's
# /stream arrival go ahead of it; items typed *after* (race path,
# queued while the /stream HTTP request was still processing) go
# after. Defaults to 0.0 for backward compatibility with entries
# written before this field existed — those sort as "before everything"
# which matches the pre-fix behaviour.
enqueued_at: float = Field(default_factory=time.time)
def _buffer_key(session_id: str) -> str:
return f"{_PENDING_KEY_PREFIX}{session_id}"
def _notify_channel(session_id: str) -> str:
return f"{_PENDING_CHANNEL_PREFIX}{session_id}"
def _decode_redis_item(item: Any) -> str:
"""Decode a redis-py list item to a str.
redis-py returns ``bytes`` when ``decode_responses=False`` and ``str``
when ``decode_responses=True``. This helper handles both so callers
don't have to repeat the isinstance guard.
"""
return item.decode("utf-8") if isinstance(item, bytes) else str(item)
async def push_pending_message(
session_id: str,
message: PendingMessage,
) -> int:
"""Append a pending message to the session's buffer.
Returns the new buffer length. Enforces ``MAX_PENDING_MESSAGES`` by
trimming from the left (oldest) — the newest message always wins if
the user has been typing faster than the copilot can drain.
Delegates to :func:`backend.data.redis_helpers.capped_rpush` so RPUSH
+ LTRIM + EXPIRE + LLEN run atomically (MULTI/EXEC) in one round
trip; a concurrent drain (LPOP) can no longer observe the list
temporarily over ``MAX_PENDING_MESSAGES``.
Note on durability: if the executor turn crashes after a push but before
the drain window runs, the message remains in Redis until the TTL expires
(``_PENDING_TTL_SECONDS``, currently 1 hour). It is delivered on the
next turn that drains the buffer. If no turn runs within the TTL the
message is silently dropped; the user may resend it.
"""
redis = await get_redis_async()
key = _buffer_key(session_id)
payload = message.model_dump_json()
new_length = await capped_rpush(
redis,
key,
payload,
max_len=MAX_PENDING_MESSAGES,
ttl_seconds=_PENDING_TTL_SECONDS,
)
# Fire-and-forget notify. Subscribers use this as a wake-up hint;
# the buffer itself is authoritative so a lost notify is harmless.
try:
await redis.publish(_notify_channel(session_id), _NOTIFY_PAYLOAD)
except Exception as e: # pragma: no cover
logger.warning("pending_messages: publish failed for %s: %s", session_id, e)
logger.info(
"pending_messages: pushed message to session=%s (buffer_len=%d)",
session_id,
new_length,
)
return new_length
async def drain_pending_messages(session_id: str) -> list[PendingMessage]:
"""Atomically pop all pending messages for *session_id*.
Returns them in enqueue order (oldest first). Uses ``LPOP`` with a
count so the read+delete is a single Redis round trip. If the list
is empty or missing, returns ``[]``.
"""
redis = await get_redis_async()
key = _buffer_key(session_id)
# Redis LPOP with count (Redis 6.2+) returns None for missing key,
# empty list if we somehow race an empty key, or the popped items.
# Draining MAX_PENDING_MESSAGES at once is safe because the push side
# uses RPUSH + LTRIM(-MAX_PENDING_MESSAGES, -1) to cap the list to that
# same value, so the list can never hold more items than we drain here.
# If the cap is raised on the push side, raise the drain count here too
# (or switch to a loop drain).
lpop_result = await redis.lpop(key, MAX_PENDING_MESSAGES) # type: ignore[assignment]
if not lpop_result:
return []
raw_popped: list[Any] = list(lpop_result)
# redis-py may return bytes or str depending on decode_responses.
decoded: list[str] = [_decode_redis_item(item) for item in raw_popped]
messages: list[PendingMessage] = []
for payload in decoded:
try:
messages.append(PendingMessage.model_validate(json.loads(payload)))
except (json.JSONDecodeError, ValidationError, TypeError, ValueError) as e:
logger.warning(
"pending_messages: dropping malformed entry for %s: %s",
session_id,
e,
)
if messages:
logger.info(
"pending_messages: drained %d messages for session=%s",
len(messages),
session_id,
)
return messages
async def peek_pending_count(session_id: str) -> int:
"""Return the current buffer length without consuming it."""
redis = await get_redis_async()
length = await cast("Any", redis.llen(_buffer_key(session_id)))
return int(length)
async def peek_pending_messages(session_id: str) -> list[PendingMessage]:
"""Return pending messages without consuming them.
Uses LRANGE 0 -1 to read all items in enqueue order (oldest first)
without removing them. Returns an empty list if the buffer is empty
or the session has no pending messages.
"""
redis = await get_redis_async()
key = _buffer_key(session_id)
items = await cast("Any", redis.lrange(key, 0, -1))
if not items:
return []
messages: list[PendingMessage] = []
for item in items:
try:
messages.append(
PendingMessage.model_validate(json.loads(_decode_redis_item(item)))
)
except (json.JSONDecodeError, ValidationError, TypeError, ValueError) as e:
logger.warning(
"pending_messages: dropping malformed peek entry for %s: %s",
session_id,
e,
)
return messages
async def _clear_pending_messages_unsafe(session_id: str) -> None:
"""Drop the session's pending buffer — **not** the normal turn cleanup.
Named ``_unsafe`` because reaching for this at turn end drops queued
follow-ups on the floor instead of running them (the bug fixed by
commit b64be73). The atomic ``LPOP`` drain at turn start is the
primary consumer; anything pushed after the drain window belongs to
the next turn by definition. Retained only as an operator/debug
escape hatch for manually clearing a stuck session and as a fixture
in the unit tests.
"""
redis = await get_redis_async()
await redis.delete(_buffer_key(session_id))
# Per-message and total-block caps for inline tool-boundary injection.
# Per-message keeps a single long paste from dominating; the total cap
# keeps the follow-up block small relative to the 100 KB MCP truncation
# boundary so tool output always stays the larger share of the wrapper
# return value.
_FOLLOWUP_CONTENT_MAX_CHARS = 2_000
_FOLLOWUP_TOTAL_MAX_CHARS = 6_000
def _persist_queue_key(session_id: str) -> str:
return f"{_PERSIST_QUEUE_KEY_PREFIX}{session_id}"
async def stash_pending_for_persist(
session_id: str,
messages: list[PendingMessage],
) -> None:
"""Enqueue drained PendingMessages for UI-row persistence.
Writes each message as a JSON payload to
``copilot:pending-persist:{session_id}``. The SDK service's
tool-result dispatch handler LPOPs this queue right after appending
the tool_result row to ``session.messages``, so the resulting user
row lands at the correct chronological position (after the tool
output the follow-up was drained against).
Fire-and-forget on Redis failures: a stash failure means Claude
still saw the follow-up in tool output (the injection step ran
first), so the only consequence is a missing UI bubble. Logged
so it can be spotted.
"""
if not messages:
return
try:
redis = await get_redis_async()
key = _persist_queue_key(session_id)
payloads = [m.model_dump_json() for m in messages]
await redis.rpush(key, *payloads) # type: ignore[misc]
await redis.expire(key, _PENDING_TTL_SECONDS) # type: ignore[misc]
except Exception:
logger.warning(
"pending_messages: failed to stash %d message(s) for persist "
"(session=%s); UI will miss the follow-up bubble but Claude "
"already saw the content in tool output",
len(messages),
session_id,
exc_info=True,
)
async def drain_pending_for_persist(session_id: str) -> list[PendingMessage]:
"""Atomically drain the persist queue for *session_id*.
Returns the queued ``PendingMessage`` objects in enqueue order (oldest
first). Returns ``[]`` on any error so the service-layer caller can
always treat the result as a plain list. Called by sdk/service.py
after appending a tool_result row to ``session.messages``.
"""
try:
redis = await get_redis_async()
key = _persist_queue_key(session_id)
lpop_result = await redis.lpop( # type: ignore[assignment]
key, MAX_PENDING_MESSAGES
)
except Exception:
logger.warning(
"pending_messages: drain_pending_for_persist failed for session=%s",
session_id,
exc_info=True,
)
return []
if not lpop_result:
return []
raw_popped: list[Any] = list(lpop_result)
messages: list[PendingMessage] = []
for item in raw_popped:
try:
messages.append(
PendingMessage.model_validate(json.loads(_decode_redis_item(item)))
)
except (json.JSONDecodeError, ValidationError, TypeError, ValueError) as e:
logger.warning(
"pending_messages: dropping malformed persist-queue entry "
"for %s: %s",
session_id,
e,
)
return messages
def format_pending_as_followup(pending: list[PendingMessage]) -> str:
"""Render drained pending messages as a ``<user_follow_up>`` block.
Used by the SDK tool-boundary injection path to surface queued user
text inside a tool result so the model reads it on the next LLM round,
without starting a separate turn. Wrapped in a stable XML-style tag so
the shared system-prompt supplement can teach the model to treat the
contents as the user's continuation of their request, not as tool
output. Each message is capped to keep the block bounded even if the
user pastes long content.
"""
if not pending:
return ""
rendered: list[str] = []
total_chars = 0
dropped = 0
for idx, pm in enumerate(pending, start=1):
text = pm.content
if len(text) > _FOLLOWUP_CONTENT_MAX_CHARS:
text = text[:_FOLLOWUP_CONTENT_MAX_CHARS] + "… [truncated]"
entry = f"Message {idx}:\n{text}"
if pm.context and pm.context.url:
entry += f"\n[Page URL: {pm.context.url}]"
if pm.file_ids:
entry += "\n[Attached files: " + ", ".join(pm.file_ids) + "]"
if total_chars + len(entry) > _FOLLOWUP_TOTAL_MAX_CHARS:
dropped = len(pending) - idx + 1
break
rendered.append(entry)
total_chars += len(entry)
if dropped:
rendered.append(f"… [{dropped} more message(s) truncated]")
body = "\n\n".join(rendered)
return (
"<user_follow_up>\n"
"The user sent the following message(s) while this tool was running. "
"Treat them as a continuation of their current request — acknowledge "
"and act on them in your next response. Do not echo these tags back.\n\n"
f"{body}\n"
"</user_follow_up>"
)
async def drain_and_format_for_injection(
session_id: str,
*,
log_prefix: str,
) -> str:
"""Drain the pending buffer and produce a ``<user_follow_up>`` block.
Shared entry point for every mid-turn injection site (``PostToolUse``
hook for MCP + built-in tools, baseline between-rounds drain, etc.).
Also stashes the drained messages on the persist queue so the service
layer appends a real user row after the tool_result it rode in on —
giving the UI a correctly-ordered bubble.
Returns an empty string if nothing was queued or Redis failed; callers
can pass the result straight to ``additionalContext``.
"""
if not session_id:
return ""
try:
pending = await drain_pending_messages(session_id)
except Exception:
logger.warning(
"%s drain_pending_messages failed (session=%s); skipping injection",
log_prefix,
session_id,
exc_info=True,
)
return ""
if not pending:
return ""
logger.info(
"%s Injected %d user follow-up(s) into tool output (session=%s)",
log_prefix,
len(pending),
session_id,
)
await stash_pending_for_persist(session_id, pending)
return format_pending_as_followup(pending)
def format_pending_as_user_message(message: PendingMessage) -> dict[str, Any]:
"""Shape a ``PendingMessage`` into the OpenAI-format user message dict.
Used by the baseline tool-call loop when injecting the buffered
message into the conversation. Context/file metadata (if any) is
embedded into the content so the model sees everything in one block.
"""
parts: list[str] = [message.content]
if message.context:
if message.context.url:
parts.append(f"\n\n[Page URL: {message.context.url}]")
if message.context.content:
parts.append(f"\n\n[Page content]\n{message.context.content}")
if message.file_ids:
parts.append(
"\n\n[Attached files]\n"
+ "\n".join(f"- file_id={fid}" for fid in message.file_ids)
+ "\nUse read_workspace_file with the file_id to access file contents."
)
return {"role": "user", "content": "".join(parts)}

View File

@@ -0,0 +1,614 @@
"""Tests for the copilot pending-messages buffer.
Uses a fake async Redis client so the tests don't require a real Redis
instance (the backend test suite's DB/Redis fixtures are heavyweight
and pull in the full app startup).
"""
import asyncio
import json
from typing import Any
import pytest
from backend.copilot import pending_messages as pm_module
from backend.copilot.pending_messages import (
MAX_PENDING_MESSAGES,
PendingMessage,
PendingMessageContext,
_clear_pending_messages_unsafe,
drain_and_format_for_injection,
drain_pending_for_persist,
drain_pending_messages,
format_pending_as_followup,
format_pending_as_user_message,
peek_pending_count,
peek_pending_messages,
push_pending_message,
stash_pending_for_persist,
)
# ── Fake Redis ──────────────────────────────────────────────────────
class _FakeRedis:
def __init__(self) -> None:
# Values are ``str | bytes`` because real redis-py returns
# bytes when ``decode_responses=False``; the drain path must
# handle both and our tests exercise both.
self.lists: dict[str, list[str | bytes]] = {}
self.published: list[tuple[str, str]] = []
async def rpush(self, key: str, *values: Any) -> int:
lst = self.lists.setdefault(key, [])
lst.extend(values)
return len(lst)
async def ltrim(self, key: str, start: int, stop: int) -> None:
lst = self.lists.get(key, [])
# Redis LTRIM stop is inclusive; -1 means the last element.
if stop == -1:
self.lists[key] = lst[start:]
else:
self.lists[key] = lst[start : stop + 1]
async def expire(self, key: str, seconds: int) -> int:
# Fake doesn't enforce TTL — just acknowledge.
return 1
async def publish(self, channel: str, payload: str) -> int:
self.published.append((channel, payload))
return 1
async def lpop(self, key: str, count: int) -> list[str | bytes] | None:
lst = self.lists.get(key)
if not lst:
return None
popped = lst[:count]
self.lists[key] = lst[count:]
return popped
async def llen(self, key: str) -> int:
return len(self.lists.get(key, []))
async def lrange(self, key: str, start: int, stop: int) -> list[str | bytes]:
lst = self.lists.get(key, [])
# Redis LRANGE stop is inclusive; -1 means the last element.
if stop == -1:
return list(lst[start:])
return list(lst[start : stop + 1])
async def delete(self, key: str) -> int:
if key in self.lists:
del self.lists[key]
return 1
return 0
def pipeline(self, transaction: bool = True) -> "_FakePipeline":
# Returns a fake pipeline that records ops and replays them in
# order on ``execute()``. Used by ``capped_rpush`` (push_pending_message)
# and ``incr_with_ttl`` (call-rate check) via MULTI/EXEC.
return _FakePipeline(self)
async def incr(self, key: str) -> int:
# Used by incr_with_ttl's pipeline.
current = int(self.lists.get(key, [0])[0]) if self.lists.get(key) else 0
current += 1
# We abuse the same lists dict for simple counters — store [count].
self.lists[key] = [str(current)]
return current
class _FakePipeline:
"""Async pipeline shim matching the redis-py MULTI/EXEC surface."""
def __init__(self, parent: "_FakeRedis") -> None:
self._parent = parent
self._ops: list[tuple[str, tuple[Any, ...], dict[str, Any]]] = []
# Each method just records the op; dispatching happens in execute().
def rpush(self, key: str, *values: Any) -> "_FakePipeline":
self._ops.append(("rpush", (key, *values), {}))
return self
def ltrim(self, key: str, start: int, stop: int) -> "_FakePipeline":
self._ops.append(("ltrim", (key, start, stop), {}))
return self
def expire(self, key: str, seconds: int, **kw: Any) -> "_FakePipeline":
self._ops.append(("expire", (key, seconds), kw))
return self
def llen(self, key: str) -> "_FakePipeline":
self._ops.append(("llen", (key,), {}))
return self
def incr(self, key: str) -> "_FakePipeline":
self._ops.append(("incr", (key,), {}))
return self
async def execute(self) -> list[Any]:
results: list[Any] = []
for name, args, _kw in self._ops:
fn = getattr(self._parent, name)
results.append(await fn(*args))
return results
# Support `async with pipeline() as pipe:` too.
async def __aenter__(self) -> "_FakePipeline":
return self
async def __aexit__(self, *a: Any) -> None:
return None
@pytest.fixture()
def fake_redis(monkeypatch: pytest.MonkeyPatch) -> _FakeRedis:
redis = _FakeRedis()
async def _get_redis_async() -> _FakeRedis:
return redis
monkeypatch.setattr(pm_module, "get_redis_async", _get_redis_async)
return redis
# ── Basic push / drain ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_push_and_drain_single_message(fake_redis: _FakeRedis) -> None:
length = await push_pending_message("sess1", PendingMessage(content="hello"))
assert length == 1
assert await peek_pending_count("sess1") == 1
drained = await drain_pending_messages("sess1")
assert len(drained) == 1
assert drained[0].content == "hello"
assert await peek_pending_count("sess1") == 0
@pytest.mark.asyncio
async def test_push_and_drain_preserves_order(fake_redis: _FakeRedis) -> None:
for i in range(3):
await push_pending_message("sess2", PendingMessage(content=f"msg {i}"))
drained = await drain_pending_messages("sess2")
assert [m.content for m in drained] == ["msg 0", "msg 1", "msg 2"]
@pytest.mark.asyncio
async def test_drain_empty_returns_empty_list(fake_redis: _FakeRedis) -> None:
assert await drain_pending_messages("nope") == []
# ── Buffer cap ──────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_cap_drops_oldest_when_exceeded(fake_redis: _FakeRedis) -> None:
# Push MAX_PENDING_MESSAGES + 3 messages
for i in range(MAX_PENDING_MESSAGES + 3):
await push_pending_message("sess3", PendingMessage(content=f"m{i}"))
# Buffer should be clamped to MAX
assert await peek_pending_count("sess3") == MAX_PENDING_MESSAGES
drained = await drain_pending_messages("sess3")
assert len(drained) == MAX_PENDING_MESSAGES
# Oldest 3 dropped — we should only see m3..m(MAX+2)
assert drained[0].content == "m3"
assert drained[-1].content == f"m{MAX_PENDING_MESSAGES + 2}"
# ── Clear ───────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_clear_removes_buffer(fake_redis: _FakeRedis) -> None:
await push_pending_message("sess4", PendingMessage(content="x"))
await push_pending_message("sess4", PendingMessage(content="y"))
await _clear_pending_messages_unsafe("sess4")
assert await peek_pending_count("sess4") == 0
@pytest.mark.asyncio
async def test_clear_is_idempotent(fake_redis: _FakeRedis) -> None:
# Clearing an already-empty buffer should not raise
await _clear_pending_messages_unsafe("sess_empty")
await _clear_pending_messages_unsafe("sess_empty")
# ── Publish hook ────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_push_publishes_notification(fake_redis: _FakeRedis) -> None:
await push_pending_message("sess5", PendingMessage(content="hi"))
assert ("copilot:pending:notify:sess5", "1") in fake_redis.published
# ── Format helper ───────────────────────────────────────────────────
def test_format_pending_plain_text() -> None:
msg = PendingMessage(content="just text")
out = format_pending_as_user_message(msg)
assert out == {"role": "user", "content": "just text"}
def test_format_pending_with_context_url() -> None:
msg = PendingMessage(
content="see this page",
context=PendingMessageContext(url="https://example.com"),
)
out = format_pending_as_user_message(msg)
content = out["content"]
assert out["role"] == "user"
assert "see this page" in content
# The URL should appear verbatim in the [Page URL: ...] block.
assert "[Page URL: https://example.com]" in content
def test_format_pending_with_file_ids() -> None:
msg = PendingMessage(content="look here", file_ids=["a", "b"])
out = format_pending_as_user_message(msg)
assert "file_id=a" in out["content"]
assert "file_id=b" in out["content"]
def test_format_pending_with_all_fields() -> None:
"""All fields (content + context url/content + file_ids) should all appear."""
msg = PendingMessage(
content="summarise this",
context=PendingMessageContext(
url="https://example.com/page",
content="headline text",
),
file_ids=["f1", "f2"],
)
out = format_pending_as_user_message(msg)
body = out["content"]
assert out["role"] == "user"
assert "summarise this" in body
assert "[Page URL: https://example.com/page]" in body
assert "[Page content]\nheadline text" in body
assert "file_id=f1" in body
assert "file_id=f2" in body
# ── Followup block caps ────────────────────────────────────────────
def test_format_followup_single_message() -> None:
out = format_pending_as_followup([PendingMessage(content="hello")])
assert "<user_follow_up>" in out
assert "</user_follow_up>" in out
assert "Message 1:\nhello" in out
def test_format_followup_total_cap_drops_overflow() -> None:
"""10 × 2 KB messages must truncate past the total cap (~6 KB) with a
marker indicating how many were dropped."""
messages = [PendingMessage(content="A" * 2_000) for _ in range(10)]
out = format_pending_as_followup(messages)
# Block stays within the total cap (plus a little wrapper overhead).
# The body alone is capped at 6 KB; we allow generous overhead for the
# <user_follow_up> wrapper + headers.
assert len(out) < 8_000
assert "more message(s) truncated" in out
# The first message at least must be present.
assert "Message 1:" in out
def test_format_followup_total_cap_marker_counts_dropped() -> None:
"""The marker should name the exact number of dropped messages."""
# Each 3 KB message gets capped to 2 KB first; with ~2 KB per entry and a
# 6 KB total cap, roughly two entries fit and the rest are dropped.
messages = [PendingMessage(content="X" * 3_000) for _ in range(5)]
out = format_pending_as_followup(messages)
assert "Message 1:" in out
assert "Message 2:" in out
# Message 3 would push total past 6 KB; marker should report exactly how
# many were left out (here: messages 3, 4, 5 → 3 dropped).
assert "[3 more message(s) truncated]" in out
def test_format_followup_empty_returns_empty_string() -> None:
assert format_pending_as_followup([]) == ""
# ── Malformed payload handling ──────────────────────────────────────
@pytest.mark.asyncio
async def test_drain_skips_malformed_entries(
fake_redis: _FakeRedis,
) -> None:
# Seed the fake with a mix of valid and malformed payloads
fake_redis.lists["copilot:pending:bad"] = [
json.dumps({"content": "valid"}),
"{not valid json",
json.dumps({"content": "also valid", "file_ids": ["a"]}),
]
drained = await drain_pending_messages("bad")
assert len(drained) == 2
assert drained[0].content == "valid"
assert drained[1].content == "also valid"
@pytest.mark.asyncio
async def test_drain_decodes_bytes_payloads(
fake_redis: _FakeRedis,
) -> None:
"""Real redis-py returns ``bytes`` when ``decode_responses=False``.
Seed the fake with bytes values to exercise the ``decode("utf-8")``
branch in ``drain_pending_messages`` so a regression there doesn't
slip past CI.
"""
fake_redis.lists["copilot:pending:bytes_sess"] = [
json.dumps({"content": "from bytes"}).encode("utf-8"),
]
drained = await drain_pending_messages("bytes_sess")
assert len(drained) == 1
assert drained[0].content == "from bytes"
@pytest.mark.asyncio
async def test_peek_decodes_bytes_payloads(
fake_redis: _FakeRedis,
) -> None:
"""``peek_pending_messages`` uses the same ``_decode_redis_item`` helper
as the drain path. Seed with bytes to guard against regression.
"""
fake_redis.lists["copilot:pending:peek_bytes_sess"] = [
json.dumps({"content": "peeked from bytes"}).encode("utf-8"),
]
peeked = await peek_pending_messages("peek_bytes_sess")
assert len(peeked) == 1
assert peeked[0].content == "peeked from bytes"
# peek must NOT consume the item
assert fake_redis.lists["copilot:pending:peek_bytes_sess"] != []
# ── Concurrency ─────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_concurrent_push_and_drain(fake_redis: _FakeRedis) -> None:
"""Two pushes fired concurrently should both land; a concurrent drain
should see at least one of them (the fake serialises, so it will
always see both, but we exercise the code path either way)."""
await asyncio.gather(
push_pending_message("sess_conc", PendingMessage(content="a")),
push_pending_message("sess_conc", PendingMessage(content="b")),
)
drained = await drain_pending_messages("sess_conc")
assert len(drained) >= 1
contents = {m.content for m in drained}
assert contents <= {"a", "b"}
# ── Publish error path ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_push_survives_publish_failure(
fake_redis: _FakeRedis, monkeypatch: pytest.MonkeyPatch
) -> None:
"""A publish error must not propagate — the buffer is still authoritative."""
async def _fail_publish(channel: str, payload: str) -> int:
raise RuntimeError("redis publish down")
monkeypatch.setattr(fake_redis, "publish", _fail_publish)
length = await push_pending_message("sess_pub_err", PendingMessage(content="ok"))
assert length == 1
drained = await drain_pending_messages("sess_pub_err")
assert len(drained) == 1
assert drained[0].content == "ok"
# ── peek_pending_messages ────────────────────────────────────────────
@pytest.mark.asyncio
async def test_peek_pending_messages_returns_all_without_consuming(
fake_redis: _FakeRedis,
) -> None:
"""Peek returns all queued messages and leaves the buffer intact."""
await push_pending_message("peek1", PendingMessage(content="first"))
await push_pending_message("peek1", PendingMessage(content="second"))
peeked = await peek_pending_messages("peek1")
assert len(peeked) == 2
assert peeked[0].content == "first"
assert peeked[1].content == "second"
# Buffer must not be consumed — count still 2
assert await peek_pending_count("peek1") == 2
drained = await drain_pending_messages("peek1")
assert len(drained) == 2
@pytest.mark.asyncio
async def test_peek_pending_messages_empty_buffer(fake_redis: _FakeRedis) -> None:
"""Peek on a missing key returns an empty list without raising."""
result = await peek_pending_messages("no_such_session")
assert result == []
@pytest.mark.asyncio
async def test_peek_pending_messages_decodes_bytes_payloads(
fake_redis: _FakeRedis,
) -> None:
"""peek_pending_messages decodes bytes entries the same way drain does."""
fake_redis.lists["copilot:pending:peek_bytes"] = [
json.dumps({"content": "from bytes"}).encode("utf-8"),
]
peeked = await peek_pending_messages("peek_bytes")
assert len(peeked) == 1
assert peeked[0].content == "from bytes"
@pytest.mark.asyncio
async def test_peek_pending_messages_skips_malformed_entries(
fake_redis: _FakeRedis,
) -> None:
"""Malformed entries are skipped and valid ones are returned."""
fake_redis.lists["copilot:pending:peek_bad"] = [
json.dumps({"content": "valid peek"}),
"{bad json",
json.dumps({"content": "also valid peek"}),
]
peeked = await peek_pending_messages("peek_bad")
assert len(peeked) == 2
assert peeked[0].content == "valid peek"
assert peeked[1].content == "also valid peek"
# ── Persist queue (mid-turn follow-up UI bubble hand-off) ───────────
@pytest.mark.asyncio
async def test_stash_for_persist_enqueues_and_drain_pops_in_order(
fake_redis: _FakeRedis,
) -> None:
"""stash_pending_for_persist writes messages under the persist key;
drain_pending_for_persist LPOPs them in enqueue order."""
msgs = [
PendingMessage(content="first mid-turn follow-up"),
PendingMessage(content="second"),
]
await stash_pending_for_persist("sess-persist", msgs)
# Stored under the distinct persist key, NOT the primary buffer.
assert "copilot:pending-persist:sess-persist" in fake_redis.lists
assert "copilot:pending:sess-persist" not in fake_redis.lists
drained = await drain_pending_for_persist("sess-persist")
assert len(drained) == 2
assert drained[0].content == "first mid-turn follow-up"
assert drained[1].content == "second"
# Queue is empty after drain.
assert await drain_pending_for_persist("sess-persist") == []
@pytest.mark.asyncio
async def test_stash_for_persist_empty_list_is_noop(
fake_redis: _FakeRedis,
) -> None:
"""Passing an empty list must NOT create a Redis key (would leak
empty persist entries and require a drain for no reason)."""
await stash_pending_for_persist("sess-noop", [])
assert "copilot:pending-persist:sess-noop" not in fake_redis.lists
@pytest.mark.asyncio
async def test_drain_pending_for_persist_missing_key_returns_empty(
fake_redis: _FakeRedis,
) -> None:
assert await drain_pending_for_persist("never-stashed") == []
@pytest.mark.asyncio
async def test_drain_pending_for_persist_skips_malformed(
fake_redis: _FakeRedis,
) -> None:
fake_redis.lists["copilot:pending-persist:bad"] = [
json.dumps({"content": "good one"}),
"not json",
json.dumps({"content": "another good one"}),
]
result = await drain_pending_for_persist("bad")
assert [m.content for m in result] == ["good one", "another good one"]
@pytest.mark.asyncio
async def test_persist_queue_isolated_from_primary_buffer(
fake_redis: _FakeRedis,
) -> None:
"""Draining the persist queue must NOT touch the primary pending
buffer (and vice versa) — they serve different lifecycles."""
# Seed the primary buffer with one entry.
await push_pending_message("sess-iso", PendingMessage(content="primary"))
# Stash a separate entry on the persist queue.
await stash_pending_for_persist("sess-iso", [PendingMessage(content="persist")])
drained_persist = await drain_pending_for_persist("sess-iso")
assert [m.content for m in drained_persist] == ["persist"]
# Primary buffer untouched.
assert await peek_pending_count("sess-iso") == 1
drained_primary = await drain_pending_messages("sess-iso")
assert [m.content for m in drained_primary] == ["primary"]
@pytest.mark.asyncio
async def test_stash_for_persist_swallows_redis_failure(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A broken Redis during stash must not raise — Claude has already
seen the follow-up via tool output; the only fallout is a missing
UI bubble, which we log and move on."""
async def _broken_redis() -> Any:
raise ConnectionError("redis down")
monkeypatch.setattr(pm_module, "get_redis_async", _broken_redis)
# Must NOT raise.
await stash_pending_for_persist("sess-broken", [PendingMessage(content="lost")])
# ── drain_and_format_for_injection: shared entry point ─────────────────
@pytest.mark.asyncio
async def test_drain_and_format_for_injection_happy_path(
fake_redis: _FakeRedis,
) -> None:
"""Queued messages drain into a ready-to-inject <user_follow_up> block
AND are stashed on the persist queue for UI row hand-off."""
await push_pending_message("sess-share", PendingMessage(content="do X also"))
result = await drain_and_format_for_injection("sess-share", log_prefix="[TEST]")
assert "<user_follow_up>" in result
assert "do X also" in result
# Primary buffer drained.
assert await peek_pending_count("sess-share") == 0
# Persist queue got a copy for the UI.
persisted = await drain_pending_for_persist("sess-share")
assert len(persisted) == 1
assert persisted[0].content == "do X also"
@pytest.mark.asyncio
async def test_drain_and_format_for_injection_empty_returns_empty(
fake_redis: _FakeRedis,
) -> None:
assert await drain_and_format_for_injection("sess-empty", log_prefix="[TEST]") == ""
@pytest.mark.asyncio
async def test_drain_and_format_for_injection_swallows_redis_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
async def _broken() -> Any:
raise ConnectionError("down")
monkeypatch.setattr(pm_module, "get_redis_async", _broken)
# Must NOT raise — broken Redis becomes "nothing to inject".
assert (
await drain_and_format_for_injection("sess-broken", log_prefix="[TEST]") == ""
)
@pytest.mark.asyncio
async def test_drain_and_format_for_injection_missing_session_id() -> None:
assert await drain_and_format_for_injection("", log_prefix="[TEST]") == ""

View File

@@ -87,6 +87,7 @@ ToolName = Literal[
"get_agent_building_guide",
"get_doc_page",
"get_mcp_guide",
"get_sub_session_result",
"list_folders",
"list_workspace_files",
"memory_forget_confirm",
@@ -99,6 +100,7 @@ ToolName = Literal[
"run_agent",
"run_block",
"run_mcp_tool",
"run_sub_session",
"search_docs",
"search_feature_requests",
"update_folder",

View File

@@ -8,11 +8,12 @@ handling the distinction between:
from functools import cache
from backend.blocks.autopilot import AUTOPILOT_BLOCK_ID
from backend.copilot.tools import TOOL_REGISTRY
# Shared technical notes that apply to both SDK and baseline modes
_SHARED_TOOL_NOTES = f"""\
# Workflow rules appended to the system prompt on every copilot turn
# (baseline appends directly; SDK appends via the storage-supplement
# template). These are cross-tool rules (file sharing, @@agptfile: refs,
# tool-discovery priority, sub-agent etiquette) that don't belong on any
# individual tool schema.
SHARED_TOOL_NOTES = """\
### Sharing files
After `write_workspace_file`, embed the `download_url` in Markdown:
@@ -68,13 +69,13 @@ that would be corrupted by text encoding.
Example — committing an image file to GitHub:
```json
{{
"files": [{{
{
"files": [{
"path": "docs/hero.png",
"content": "workspace://abc123#image/png",
"operation": "upsert"
}}]
}}
}]
}
```
### Writing large files — CRITICAL (causes production failures)
@@ -149,20 +150,27 @@ When the user asks to interact with a service or API, follow this order:
All tasks must run in the foreground.
### Delegating to another autopilot (sub-autopilot pattern)
Use the **AutoPilotBlock** (`run_block` with block_id
`{AUTOPILOT_BLOCK_ID}`) to delegate a task to a fresh
autopilot instance. The sub-autopilot has its own full tool set and can
perform multi-step work autonomously.
Use the **`run_sub_session`** tool to delegate a task to a fresh
sub-AutoPilot. The sub has its own full tool set and can perform
multi-step work autonomously.
- **Input**: `prompt` (required) the task description.
Optional: `system_context` to constrain behavior, `session_id` to
continue a previous conversation, `max_recursion_depth` (default 3).
- **Output**: `response` (text), `tool_calls` (list), `session_id`
(for continuation), `conversation_history`, `token_usage`.
- `prompt` (required): the task description.
- `system_context` (optional): extra context prepended to the prompt.
- `sub_autopilot_session_id` (optional): continue an existing
sub-AutoPilot — pass the `sub_autopilot_session_id` returned by a
previous completed run.
- `wait_for_result` (default 60, max 300): seconds to wait inline. If
the sub isn't done by then you get `status="running"` + a
`sub_session_id` — call **`get_sub_session_result`** with that id
(wait up to 300s more per call) until it returns `completed` or
`error`. Works across turns — safe to reconnect in a later message.
Use this when a task is complex enough to benefit from a separate
autopilot context, e.g. "research X and write a report" while the
parent autopilot handles orchestration.
parent autopilot handles orchestration. Do NOT invoke `AutoPilotBlock`
via `run_block` — it's hidden from `run_block` by design because the
dedicated tool handles the async lifecycle correctly.
"""
# E2B-only notes — E2B has full internet access so gh CLI works there.
@@ -174,6 +182,7 @@ sandbox so `bash_exec` can access it for further processing.
The exact sandbox path is shown in the `[Sandbox copy available at ...]` note.
### GitHub CLI (`gh`) and git
- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before asking them to connect it.
- If the user has connected their GitHub account, both `gh` and `git` are
pre-authenticated — use them directly without any manual login step.
`git` HTTPS operations (clone, push, pull) work automatically.
@@ -254,7 +263,7 @@ When a tool output contains `<tool-output-truncated workspace_path="...">`, the
full output is in workspace storage (NOT on the local filesystem). To access it:
- Use `read_workspace_file(path="...", offset=..., length=50000)` for reading sections.
- To process in the sandbox, use `read_workspace_file(path="...", save_to_path="{working_dir}/file.json")` first, then use `bash_exec` on the local copy.
{_SHARED_TOOL_NOTES}{extra_notes}"""
{SHARED_TOOL_NOTES}{extra_notes}"""
# Pre-built supplements for common environments
@@ -305,33 +314,37 @@ def _get_cloud_sandbox_supplement() -> str:
)
def _generate_tool_documentation() -> str:
"""Auto-generate tool documentation from TOOL_REGISTRY.
_USER_FOLLOW_UP_NOTE = """
# `<user_follow_up>` blocks in tool output
NOTE: This is ONLY used in baseline mode (direct OpenAI API).
SDK mode doesn't need it since Claude gets tool schemas automatically.
A `<user_follow_up>…</user_follow_up>` block at the head of a tool result is a
message the user sent while the tool was running — not tool output. The user is
watching the chat live and waiting for confirmation their message landed.
This generates a complete list of available tools with their descriptions,
ensuring the documentation stays in sync with the actual tool implementations.
All workflow guidance is now embedded in individual tool descriptions.
Every time you see one:
Only documents tools that are available in the current environment
(checked via tool.is_available property).
"""
docs = "\n## AVAILABLE TOOLS\n\n"
1. **Ack immediately.** Your very next emission must be a short visible line,
before any more tool calls:
*"Got your follow-up: {paraphrase}. {what I'll do}."*
# Sort tools alphabetically for consistent output
# Filter by is_available to match get_available_tools() behavior
for name in sorted(TOOL_REGISTRY.keys()):
tool = TOOL_REGISTRY[name]
if not tool.is_available:
continue
schema = tool.as_openai_tool()
desc = schema["function"].get("description", "No description available")
# Format as bullet list with tool name in code style
docs += f"- **`{name}`**: {desc}\n"
2. **Then act on it:**
- Question/input request → stop the tool chain and answer/ask back.
- New requirement → fold into the current plan.
- Correction → update the plan and continue with the revised target.
return docs
Never echo the `<user_follow_up>` tags back. The block holds only the user's
words — the rest of the tool result is the real data.
# Always close the turn with visible text
Every turn MUST end with at least one short user-facing text sentence —
even if it is only "Done." or "I'm stopping here because X." Never end a
turn with only tool calls or only thinking. The user's UI renders text
messages; a turn that emits only thinking blocks or only tool calls shows
up as a frozen screen with no response. If your plan was to stop after
the last tool result, still produce one closing sentence summarising
what happened so the user knows the turn is complete.
"""
@cache
@@ -356,9 +369,12 @@ def get_sdk_supplement(use_e2b: bool) -> str:
Returns:
The supplement string to append to the system prompt
"""
if use_e2b:
return _get_cloud_sandbox_supplement()
return _get_local_storage_supplement("/tmp/copilot-<session-id>")
base = (
_get_cloud_sandbox_supplement()
if use_e2b
else _get_local_storage_supplement("/tmp/copilot-<session-id>")
)
return base + _USER_FOLLOW_UP_NOTE
def get_graphiti_supplement() -> str:
@@ -395,17 +411,3 @@ You have access to persistent temporal memory tools that remember facts across s
- group_id is handled automatically by the system — never set it yourself.
- When storing, be specific about operational rules and instructions (e.g., "CC Sarah on client communications" not just "Sarah is the assistant").
"""
def get_baseline_supplement() -> str:
"""Get the supplement for baseline mode (direct OpenAI API).
Baseline mode INCLUDES auto-generated tool documentation because the
direct API doesn't automatically provide tool schemas to Claude.
Also includes shared technical notes (but NOT SDK-specific environment details).
Returns:
The supplement string to append to the system prompt
"""
tool_docs = _generate_tool_documentation()
return tool_docs + _SHARED_TOOL_NOTES

View File

@@ -1,9 +1,16 @@
"""CoPilot rate limiting based on token usage.
"""CoPilot rate limiting based on generation cost.
Uses Redis fixed-window counters to track per-user token consumption
with configurable daily and weekly limits. Daily windows reset at
midnight UTC; weekly windows reset at ISO week boundary (Monday 00:00
UTC). Fails open when Redis is unavailable to avoid blocking users.
Uses Redis fixed-window counters to track per-user USD spend (stored as
microdollars, matching ``PlatformCostLog.cost_microdollars``) with
configurable daily and weekly limits. Daily windows reset at midnight UTC;
weekly windows reset at ISO week boundary (Monday 00:00 UTC). Fails open
when Redis is unavailable to avoid blocking users.
Storing microdollars rather than tokens means the counter already reflects
real model pricing (including cache discounts and provider surcharges), so
this module carries no pricing table — the cost comes from OpenRouter's
``usage.cost`` field (baseline) or the Claude Agent SDK's reported total
cost (SDK path).
"""
import asyncio
@@ -17,12 +24,15 @@ from redis.exceptions import RedisError
from backend.data.db_accessors import user_db
from backend.data.redis_client import get_redis_async
from backend.data.user import get_user_by_id
from backend.util.cache import cached
logger = logging.getLogger(__name__)
# Redis key prefixes
_USAGE_KEY_PREFIX = "copilot:usage"
# Redis key prefixes. Bumped from "copilot:usage" (token-based) to
# "copilot:cost" on the token→cost migration so stale counters do not
# get misinterpreted as microdollars (which would dramatically under-count).
_USAGE_KEY_PREFIX = "copilot:cost"
# ---------------------------------------------------------------------------
@@ -31,7 +41,7 @@ _USAGE_KEY_PREFIX = "copilot:usage"
class SubscriptionTier(str, Enum):
"""Subscription tiers with increasing token allowances.
"""Subscription tiers with increasing cost allowances.
Mirrors the ``SubscriptionTier`` enum in ``schema.prisma``.
Once ``prisma generate`` is run, this can be replaced with::
@@ -45,9 +55,9 @@ class SubscriptionTier(str, Enum):
ENTERPRISE = "ENTERPRISE"
# Multiplier applied to the base limits (from LD / config) for each tier.
# Intentionally int (not float): keeps limits as whole token counts and avoids
# floating-point rounding. If fractional multipliers are ever needed, change
# Multiplier applied to the base cost limits (from LD / config) for each tier.
# Intentionally int (not float): keeps limits as whole microdollars and avoids
# floating-point rounding. If fractional multipliers are ever needed, change
# the type and round the result in get_global_rate_limits().
TIER_MULTIPLIERS: dict[SubscriptionTier, int] = {
SubscriptionTier.FREE: 1,
@@ -60,17 +70,27 @@ DEFAULT_TIER = SubscriptionTier.FREE
class UsageWindow(BaseModel):
"""Usage within a single time window."""
"""Usage within a single time window.
``used`` and ``limit`` are in microdollars (1 USD = 1_000_000).
"""
used: int
limit: int = Field(
description="Maximum tokens allowed in this window. 0 means unlimited."
description="Maximum microdollars of spend allowed in this window. "
"0 means unlimited."
)
resets_at: datetime
class CoPilotUsageStatus(BaseModel):
"""Current usage status for a user across all windows."""
"""Current usage status for a user across all windows.
Internal representation used by server-side code that needs to compare
usage against limits (e.g. the reset-credits endpoint). The public API
returns ``CoPilotUsagePublic`` instead so that raw spend and limit
figures never leak to clients.
"""
daily: UsageWindow
weekly: UsageWindow
@@ -81,6 +101,68 @@ class CoPilotUsageStatus(BaseModel):
)
class UsageWindowPublic(BaseModel):
"""Public view of a usage window — only the percentage and reset time.
Hides the raw spend and the cap so clients cannot derive per-turn cost
or reverse-engineer platform margins. ``percent_used`` is capped at 100.
"""
percent_used: float = Field(
ge=0.0,
le=100.0,
description="Percentage of the window's allowance used (0-100). "
"Clamped at 100 when over the cap.",
)
resets_at: datetime
class CoPilotUsagePublic(BaseModel):
"""Current usage status for a user — public (client-safe) shape."""
daily: UsageWindowPublic | None = Field(
default=None,
description="Null when no daily cap is configured (unlimited).",
)
weekly: UsageWindowPublic | None = Field(
default=None,
description="Null when no weekly cap is configured (unlimited).",
)
tier: SubscriptionTier = DEFAULT_TIER
reset_cost: int = Field(
default=0,
description="Credit cost (in cents) to reset the daily limit. 0 = feature disabled.",
)
@classmethod
def from_status(cls, status: CoPilotUsageStatus) -> "CoPilotUsagePublic":
"""Project the internal status onto the client-safe schema."""
def window(w: UsageWindow) -> UsageWindowPublic | None:
if w.limit <= 0:
return None
# When at/over the cap, snap to exactly 100.0 so the UI's
# rounded display and its exhaustion check (`percent_used >= 100`)
# agree. Without this, e.g. 99.95% would render as "100% used"
# via Math.round but fail the exhaustion check, leaving the
# reset button hidden while the bar appears full.
if w.used >= w.limit:
pct = 100.0
else:
pct = round(100.0 * w.used / w.limit, 1)
return UsageWindowPublic(
percent_used=pct,
resets_at=w.resets_at,
)
return cls(
daily=window(status.daily),
weekly=window(status.weekly),
tier=status.tier,
reset_cost=status.reset_cost,
)
class RateLimitExceeded(Exception):
"""Raised when a user exceeds their CoPilot usage limit."""
@@ -102,8 +184,8 @@ class RateLimitExceeded(Exception):
async def get_usage_status(
user_id: str,
daily_token_limit: int,
weekly_token_limit: int,
daily_cost_limit: int,
weekly_cost_limit: int,
rate_limit_reset_cost: int = 0,
tier: SubscriptionTier = DEFAULT_TIER,
) -> CoPilotUsageStatus:
@@ -111,13 +193,13 @@ async def get_usage_status(
Args:
user_id: The user's ID.
daily_token_limit: Max tokens per day (0 = unlimited).
weekly_token_limit: Max tokens per week (0 = unlimited).
daily_cost_limit: Max microdollars of spend per day (0 = unlimited).
weekly_cost_limit: Max microdollars of spend per week (0 = unlimited).
rate_limit_reset_cost: Credit cost (cents) to reset daily limit (0 = disabled).
tier: The user's rate-limit tier (included in the response).
Returns:
CoPilotUsageStatus with current usage and limits.
CoPilotUsageStatus with current usage and limits in microdollars.
"""
now = datetime.now(UTC)
daily_used = 0
@@ -136,12 +218,12 @@ async def get_usage_status(
return CoPilotUsageStatus(
daily=UsageWindow(
used=daily_used,
limit=daily_token_limit,
limit=daily_cost_limit,
resets_at=_daily_reset_time(now=now),
),
weekly=UsageWindow(
used=weekly_used,
limit=weekly_token_limit,
limit=weekly_cost_limit,
resets_at=_weekly_reset_time(now=now),
),
tier=tier,
@@ -151,22 +233,22 @@ async def get_usage_status(
async def check_rate_limit(
user_id: str,
daily_token_limit: int,
weekly_token_limit: int,
daily_cost_limit: int,
weekly_cost_limit: int,
) -> None:
"""Check if user is within rate limits. Raises RateLimitExceeded if not.
This is a pre-turn soft check. The authoritative usage counter is updated
by ``record_token_usage()`` after the turn completes. Under concurrency,
by ``record_cost_usage()`` after the turn completes. Under concurrency,
two parallel turns may both pass this check against the same snapshot.
This is acceptable because token-based limits are approximate by nature
(the exact token count is unknown until after generation).
This is acceptable because cost-based limits are approximate by nature
(the exact cost is unknown until after generation).
Fails open: if Redis is unavailable, allows the request.
"""
# Short-circuit: when both limits are 0 (unlimited) skip the Redis
# round-trip entirely.
if daily_token_limit <= 0 and weekly_token_limit <= 0:
if daily_cost_limit <= 0 and weekly_cost_limit <= 0:
return
now = datetime.now(UTC)
@@ -182,26 +264,25 @@ async def check_rate_limit(
logger.warning("Redis unavailable for rate limit check, allowing request")
return
# Worst-case overshoot: N concurrent requests × ~15K tokens each.
if daily_token_limit > 0 and daily_used >= daily_token_limit:
if daily_cost_limit > 0 and daily_used >= daily_cost_limit:
raise RateLimitExceeded("daily", _daily_reset_time(now=now))
if weekly_token_limit > 0 and weekly_used >= weekly_token_limit:
if weekly_cost_limit > 0 and weekly_used >= weekly_cost_limit:
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
async def reset_daily_usage(user_id: str, daily_token_limit: int = 0) -> bool:
"""Reset a user's daily token usage counter in Redis.
async def reset_daily_usage(user_id: str, daily_cost_limit: int = 0) -> bool:
"""Reset a user's daily cost usage counter in Redis.
Called after a user pays credits to extend their daily limit.
Also reduces the weekly usage counter by ``daily_token_limit`` tokens
Also reduces the weekly usage counter by ``daily_cost_limit`` microdollars
(clamped to 0) so the user effectively gets one extra day's worth of
weekly capacity.
Args:
user_id: The user's ID.
daily_token_limit: The configured daily token limit. When positive,
the weekly counter is reduced by this amount.
daily_cost_limit: The configured daily cost limit in microdollars.
When positive, the weekly counter is reduced by this amount.
Returns False if Redis is unavailable so the caller can handle
compensation (fail-closed for billed operations, unlike the read-only
@@ -217,12 +298,12 @@ async def reset_daily_usage(user_id: str, daily_token_limit: int = 0) -> bool:
# counter is not decremented — which would let the caller refund
# credits even though the daily limit was already reset.
d_key = _daily_key(user_id, now=now)
w_key = _weekly_key(user_id, now=now) if daily_token_limit > 0 else None
w_key = _weekly_key(user_id, now=now) if daily_cost_limit > 0 else None
pipe = redis.pipeline(transaction=True)
pipe.delete(d_key)
if w_key is not None:
pipe.decrby(w_key, daily_token_limit)
pipe.decrby(w_key, daily_cost_limit)
results = await pipe.execute()
# Clamp negative weekly counter to 0 (best-effort; not critical).
@@ -295,84 +376,40 @@ async def increment_daily_reset_count(user_id: str) -> None:
logger.warning("Redis unavailable for tracking reset count")
async def record_token_usage(
async def record_cost_usage(
user_id: str,
prompt_tokens: int,
completion_tokens: int,
*,
cache_read_tokens: int = 0,
cache_creation_tokens: int = 0,
model_cost_multiplier: float = 1.0,
cost_microdollars: int,
) -> None:
"""Record token usage for a user across all windows.
"""Record a user's generation spend against daily and weekly counters.
Uses cost-weighted counting so cached tokens don't unfairly penalise
multi-turn conversations. Anthropic's pricing:
- uncached input: 100%
- cache creation: 25%
- cache read: 10%
- output: 100%
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
from the API response). Cache counts are passed separately.
``model_cost_multiplier`` scales the final weighted total to reflect
relative model cost. Use 5.0 for Opus (5× more expensive than Sonnet)
so that Opus turns deplete the rate limit faster, proportional to cost.
``cost_microdollars`` is the real generation cost reported by the
provider (OpenRouter's ``usage.cost`` or the Claude Agent SDK's
``total_cost_usd`` converted to microdollars). Because the provider
cost already reflects model pricing and cache discounts, this function
carries no pricing table or weighting — it just increments counters.
Args:
user_id: The user's ID.
prompt_tokens: Uncached input tokens.
completion_tokens: Output tokens.
cache_read_tokens: Tokens served from prompt cache (10% cost).
cache_creation_tokens: Tokens written to prompt cache (25% cost).
model_cost_multiplier: Relative model cost factor (1.0 = Sonnet, 5.0 = Opus).
cost_microdollars: Spend to record in microdollars (1 USD = 1_000_000).
Non-positive values are ignored.
"""
prompt_tokens = max(0, prompt_tokens)
completion_tokens = max(0, completion_tokens)
cache_read_tokens = max(0, cache_read_tokens)
cache_creation_tokens = max(0, cache_creation_tokens)
weighted_input = (
prompt_tokens
+ round(cache_creation_tokens * 0.25)
+ round(cache_read_tokens * 0.1)
)
total = round(
(weighted_input + completion_tokens) * max(1.0, model_cost_multiplier)
)
if total <= 0:
cost_microdollars = max(0, cost_microdollars)
if cost_microdollars <= 0:
return
raw_total = (
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
)
logger.info(
"Recording token usage for %s: raw=%d, weighted=%d, multiplier=%.1fx "
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
user_id[:8],
raw_total,
total,
model_cost_multiplier,
prompt_tokens,
cache_read_tokens,
cache_creation_tokens,
completion_tokens,
)
logger.info("Recording copilot spend: %d microdollars", cost_microdollars)
now = datetime.now(UTC)
try:
redis = await get_redis_async()
# transaction=False: these are independent INCRBY+EXPIRE pairs on
# separate keys — no cross-key atomicity needed. Skipping
# MULTI/EXEC avoids the overhead. If the connection drops between
# INCRBY and EXPIRE the key survives until the next date-based key
# rotation (daily/weekly), so the memory-leak risk is negligible.
pipe = redis.pipeline(transaction=False)
# Use MULTI/EXEC so each INCRBY/EXPIRE pair is atomic — guarantees
# the TTL is set even if the connection drops mid-pipeline, so
# counters can never survive past their date-based rotation window.
pipe = redis.pipeline(transaction=True)
# Daily counter (expires at next midnight UTC)
d_key = _daily_key(user_id, now=now)
pipe.incrby(d_key, total)
pipe.incrby(d_key, cost_microdollars)
seconds_until_daily_reset = int(
(_daily_reset_time(now=now) - now).total_seconds()
)
@@ -380,7 +417,7 @@ async def record_token_usage(
# Weekly counter (expires end of week)
w_key = _weekly_key(user_id, now=now)
pipe.incrby(w_key, total)
pipe.incrby(w_key, cost_microdollars)
seconds_until_weekly_reset = int(
(_weekly_reset_time(now=now) - now).total_seconds()
)
@@ -389,8 +426,8 @@ async def record_token_usage(
await pipe.execute()
except (RedisError, ConnectionError, OSError):
logger.warning(
"Redis unavailable for recording token usage (tokens=%d)",
total,
"Redis unavailable for recording cost usage (microdollars=%d)",
cost_microdollars,
)
@@ -459,8 +496,20 @@ get_user_tier.cache_delete = _fetch_user_tier.cache_delete # type: ignore[attr-
async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None:
"""Persist the user's rate-limit tier to the database.
Also invalidates the ``get_user_tier`` cache for this user so that
subsequent rate-limit checks immediately see the new tier.
Invalidates every cache that keys off the user's subscription tier so the
change is visible immediately: this function's own ``get_user_tier``, the
shared ``get_user_by_id`` (which exposes ``user.subscription_tier``), and
``get_pending_subscription_change`` (since an admin override can invalidate
a cached ``cancel_at_period_end`` or schedule-based pending state).
If the user has an active Stripe subscription whose current price does not
match ``tier``, Stripe will keep billing the old price and the next
``customer.subscription.updated`` webhook will overwrite the DB tier back
to whatever Stripe has. Proper reconciliation (cancelling or modifying the
Stripe subscription when an admin overrides the tier) is out of scope for
this PR — it changes the admin contract and needs its own test coverage.
For now we emit a ``WARNING`` so drift surfaces via Sentry until that
follow-up lands.
Raises:
prisma.errors.RecordNotFoundError: If the user does not exist.
@@ -469,8 +518,113 @@ async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None:
where={"id": user_id},
data={"subscriptionTier": tier.value},
)
# Invalidate cached tier so rate-limit checks pick up the change immediately.
get_user_tier.cache_delete(user_id) # type: ignore[attr-defined]
# Local import required: backend.data.credit imports backend.copilot.rate_limit
# (via get_user_tier in credit.py's _invalidate_user_tier_caches), so a
# top-level ``from backend.data.credit import ...`` here would create a
# circular import at module-load time.
from backend.data.credit import get_pending_subscription_change
get_user_by_id.cache_delete(user_id) # type: ignore[attr-defined]
get_pending_subscription_change.cache_delete(user_id) # type: ignore[attr-defined]
# The DB write above is already committed; the drift check is best-effort
# diagnostic logging. Fire-and-forget so admin bulk ops don't wait on a
# Stripe roundtrip. The inner helper wraps its body in a timeout + broad
# except so background task errors still surface via logs rather than as
# "task exception never retrieved" warnings. Cancellation on request
# shutdown is acceptable — the drift warning is non-load-bearing.
asyncio.ensure_future(_drift_check_background(user_id, tier))
async def _drift_check_background(user_id: str, tier: SubscriptionTier) -> None:
"""Run the Stripe drift check in the background, logging rather than raising."""
try:
await asyncio.wait_for(
_warn_if_stripe_subscription_drifts(user_id, tier),
timeout=5.0,
)
logger.debug(
"set_user_tier: drift check completed for user=%s admin_tier=%s",
user_id,
tier.value,
)
except asyncio.TimeoutError:
logger.warning(
"set_user_tier: drift check timed out for user=%s admin_tier=%s",
user_id,
tier.value,
)
except asyncio.CancelledError:
# Request may have completed and the event loop is cancelling tasks —
# the drift log is non-critical, so accept cancellation silently.
raise
except Exception:
logger.exception(
"set_user_tier: drift check background task failed for"
" user=%s admin_tier=%s",
user_id,
tier.value,
)
async def _warn_if_stripe_subscription_drifts(
user_id: str, new_tier: SubscriptionTier
) -> None:
"""Emit a WARNING when an admin tier override leaves an active Stripe sub on a
mismatched price.
The warning is diagnostic only: Stripe remains the billing source of truth,
so the next ``customer.subscription.updated`` webhook will reset the DB
tier. Surfacing the drift here lets ops catch admin overrides that bypass
the intended Checkout / Portal cancel flows before users notice surprise
charges.
"""
# Local imports: see note in ``set_user_tier`` about the credit <-> rate_limit
# circular. These helpers (``_get_active_subscription``,
# ``get_subscription_price_id``) live in credit.py alongside the rest of
# the Stripe billing code.
from backend.data.credit import _get_active_subscription, get_subscription_price_id
try:
user = await get_user_by_id(user_id)
if not getattr(user, "stripe_customer_id", None):
return
sub = await _get_active_subscription(user.stripe_customer_id)
if sub is None:
return
items = sub["items"].data
if not items:
return
price = items[0].price
current_price_id = price if isinstance(price, str) else price.id
# The LaunchDarkly-backed price lookup must live inside this try/except:
# an LD SDK failure (network, token revoked) here would otherwise
# propagate past set_user_tier's already-committed DB write and turn a
# best-effort diagnostic into a 500 on admin tier writes.
expected_price_id = await get_subscription_price_id(new_tier)
except Exception:
logger.debug(
"_warn_if_stripe_subscription_drifts: drift lookup failed for"
" user=%s; skipping drift warning",
user_id,
exc_info=True,
)
return
if expected_price_id is not None and expected_price_id == current_price_id:
return
logger.warning(
"Admin tier override will drift from Stripe: user=%s admin_tier=%s"
" stripe_sub=%s stripe_price=%s expected_price=%s — the next"
" customer.subscription.updated webhook will reconcile the DB tier"
" back to whatever Stripe has; cancel or modify the Stripe subscription"
" if you intended the admin override to stick.",
user_id,
new_tier.value,
sub.id,
current_price_id,
expected_price_id,
)
async def get_global_rate_limits(
@@ -480,37 +634,41 @@ async def get_global_rate_limits(
) -> tuple[int, int, SubscriptionTier]:
"""Resolve global rate limits from LaunchDarkly, falling back to config.
The base limits (from LD or config) are multiplied by the user's
tier multiplier so that higher tiers receive proportionally larger
allowances.
Values are microdollars. The base limits (from LD or config) are
multiplied by the user's tier multiplier so that higher tiers receive
proportionally larger allowances.
Args:
user_id: User ID for LD flag evaluation context.
config_daily: Fallback daily limit from ChatConfig.
config_weekly: Fallback weekly limit from ChatConfig.
config_daily: Fallback daily cost limit (microdollars) from ChatConfig.
config_weekly: Fallback weekly cost limit (microdollars) from ChatConfig.
Returns:
(daily_token_limit, weekly_token_limit, tier) 3-tuple.
(daily_cost_limit, weekly_cost_limit, tier) — limits in microdollars.
"""
# Lazy import to avoid circular dependency:
# rate_limit -> feature_flag -> settings -> ... -> rate_limit
from backend.util.feature_flag import Flag, get_feature_flag_value
daily_raw = await get_feature_flag_value(
Flag.COPILOT_DAILY_TOKEN_LIMIT.value, user_id, config_daily
)
weekly_raw = await get_feature_flag_value(
Flag.COPILOT_WEEKLY_TOKEN_LIMIT.value, user_id, config_weekly
# Fetch daily + weekly flags in parallel — each LD evaluation is an
# independent network round-trip, so gather cuts latency roughly in half.
daily_raw, weekly_raw = await asyncio.gather(
get_feature_flag_value(
Flag.COPILOT_DAILY_COST_LIMIT.value, user_id, config_daily
),
get_feature_flag_value(
Flag.COPILOT_WEEKLY_COST_LIMIT.value, user_id, config_weekly
),
)
try:
daily = max(0, int(daily_raw))
except (TypeError, ValueError):
logger.warning("Invalid LD value for daily token limit: %r", daily_raw)
logger.warning("Invalid LD value for daily cost limit: %r", daily_raw)
daily = config_daily
try:
weekly = max(0, int(weekly_raw))
except (TypeError, ValueError):
logger.warning("Invalid LD value for weekly token limit: %r", weekly_raw)
logger.warning("Invalid LD value for weekly cost limit: %r", weekly_raw)
weekly = config_weekly
# Apply tier multiplier

View File

@@ -24,7 +24,7 @@ from .rate_limit import (
get_usage_status,
get_user_tier,
increment_daily_reset_count,
record_token_usage,
record_cost_usage,
release_reset_lock,
reset_daily_usage,
reset_user_usage,
@@ -82,7 +82,7 @@ class TestGetUsageStatus:
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
_USER, daily_cost_limit=10000, weekly_cost_limit=50000
)
assert isinstance(status, CoPilotUsageStatus)
@@ -98,7 +98,7 @@ class TestGetUsageStatus:
side_effect=ConnectionError("Redis down"),
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
_USER, daily_cost_limit=10000, weekly_cost_limit=50000
)
assert status.daily.used == 0
@@ -115,7 +115,7 @@ class TestGetUsageStatus:
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
_USER, daily_cost_limit=10000, weekly_cost_limit=50000
)
assert status.daily.used == 0
@@ -132,7 +132,7 @@ class TestGetUsageStatus:
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
_USER, daily_cost_limit=10000, weekly_cost_limit=50000
)
assert status.daily.used == 500
@@ -148,7 +148,7 @@ class TestGetUsageStatus:
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
_USER, daily_cost_limit=10000, weekly_cost_limit=50000
)
now = datetime.now(UTC)
@@ -174,7 +174,7 @@ class TestCheckRateLimit:
):
# Should not raise
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
_USER, daily_cost_limit=10000, weekly_cost_limit=50000
)
@pytest.mark.asyncio
@@ -188,7 +188,7 @@ class TestCheckRateLimit:
):
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
_USER, daily_cost_limit=10000, weekly_cost_limit=50000
)
assert exc_info.value.window == "daily"
@@ -203,7 +203,7 @@ class TestCheckRateLimit:
):
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
_USER, daily_cost_limit=10000, weekly_cost_limit=50000
)
assert exc_info.value.window == "weekly"
@@ -216,7 +216,7 @@ class TestCheckRateLimit:
):
# Should not raise
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
_USER, daily_cost_limit=10000, weekly_cost_limit=50000
)
@pytest.mark.asyncio
@@ -229,15 +229,15 @@ class TestCheckRateLimit:
return_value=mock_redis,
):
# Should not raise — limits of 0 mean unlimited
await check_rate_limit(_USER, daily_token_limit=0, weekly_token_limit=0)
await check_rate_limit(_USER, daily_cost_limit=0, weekly_cost_limit=0)
# ---------------------------------------------------------------------------
# record_token_usage
# record_cost_usage
# ---------------------------------------------------------------------------
class TestRecordTokenUsage:
class TestRecordCostUsage:
@staticmethod
def _make_pipeline_mock() -> MagicMock:
"""Create a pipeline mock with sync methods and async execute."""
@@ -255,27 +255,40 @@ class TestRecordTokenUsage:
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
await record_cost_usage(_USER, cost_microdollars=123_456)
# Should call incrby twice (daily + weekly) with total=150
# Should call incrby twice (daily + weekly) with the same cost
incrby_calls = mock_pipe.incrby.call_args_list
assert len(incrby_calls) == 2
assert incrby_calls[0].args[1] == 150 # daily
assert incrby_calls[1].args[1] == 150 # weekly
assert incrby_calls[0].args[1] == 123_456 # daily
assert incrby_calls[1].args[1] == 123_456 # weekly
@pytest.mark.asyncio
async def test_skips_when_zero_tokens(self):
async def test_skips_when_cost_is_zero(self):
mock_redis = AsyncMock()
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=0, completion_tokens=0)
await record_cost_usage(_USER, cost_microdollars=0)
# Should not call pipeline at all
mock_redis.pipeline.assert_not_called()
@pytest.mark.asyncio
async def test_skips_when_cost_is_negative(self):
"""Negative costs are clamped to zero and skip the pipeline."""
mock_redis = AsyncMock()
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_cost_usage(_USER, cost_microdollars=-10)
mock_redis.pipeline.assert_not_called()
@pytest.mark.asyncio
async def test_sets_expire_on_both_keys(self):
"""Pipeline should call expire for both daily and weekly keys."""
@@ -287,7 +300,7 @@ class TestRecordTokenUsage:
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
await record_cost_usage(_USER, cost_microdollars=5_000)
expire_calls = mock_pipe.expire.call_args_list
assert len(expire_calls) == 2
@@ -308,32 +321,7 @@ class TestRecordTokenUsage:
side_effect=ConnectionError("Redis down"),
):
# Should not raise
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
@pytest.mark.asyncio
async def test_cost_weighted_counting(self):
"""Cached tokens should be weighted: cache_read=10%, cache_create=25%."""
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(
_USER,
prompt_tokens=100, # uncached → 100
completion_tokens=50, # output → 50
cache_read_tokens=10000, # 10% → 1000
cache_creation_tokens=400, # 25% → 100
)
# Expected weighted total: 100 + 1000 + 100 + 50 = 1250
incrby_calls = mock_pipe.incrby.call_args_list
assert len(incrby_calls) == 2
assert incrby_calls[0].args[1] == 1250 # daily
assert incrby_calls[1].args[1] == 1250 # weekly
await record_cost_usage(_USER, cost_microdollars=5_000)
@pytest.mark.asyncio
async def test_handles_redis_error_during_pipeline_execute(self):
@@ -348,7 +336,7 @@ class TestRecordTokenUsage:
return_value=mock_redis,
):
# Should not raise — fail-open
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
await record_cost_usage(_USER, cost_microdollars=5_000)
# ---------------------------------------------------------------------------
@@ -581,6 +569,80 @@ class TestSetUserTier:
assert tier_after == SubscriptionTier.ENTERPRISE
@pytest.mark.asyncio
async def test_drift_check_swallows_launchdarkly_failure(self):
"""LaunchDarkly price-id lookup failures inside the drift check must
never bubble up and 500 the admin tier write — the DB update is
already committed by the time we check drift."""
mock_prisma = AsyncMock()
mock_prisma.update = AsyncMock(return_value=None)
mock_user = MagicMock()
mock_user.stripe_customer_id = "cus_abc"
mock_sub = MagicMock()
mock_sub.id = "sub_abc"
mock_sub["items"].data = [MagicMock(price=MagicMock(id="price_mismatch"))]
with (
patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=mock_prisma,
),
patch(
"backend.copilot.rate_limit.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
),
patch(
"backend.data.credit._get_active_subscription",
new_callable=AsyncMock,
return_value=mock_sub,
),
patch(
"backend.data.credit.get_subscription_price_id",
new_callable=AsyncMock,
side_effect=RuntimeError("LD SDK not initialized"),
),
):
# Must NOT raise — drift check is best-effort diagnostic only.
await set_user_tier(_USER, SubscriptionTier.PRO)
mock_prisma.update.assert_awaited_once()
@pytest.mark.asyncio
async def test_drift_check_timeout_is_bounded(self):
"""A Stripe call that stalls on the 80s SDK default must not block the
admin tier write — set_user_tier wraps the drift check in a 5s timeout
and logs + returns on TimeoutError."""
import asyncio as _asyncio
mock_prisma = AsyncMock()
mock_prisma.update = AsyncMock(return_value=None)
async def _never_returns(_user_id: str, _tier):
await _asyncio.sleep(60)
with (
patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=mock_prisma,
),
patch(
"backend.copilot.rate_limit._warn_if_stripe_subscription_drifts",
side_effect=_never_returns,
),
patch(
"backend.copilot.rate_limit.asyncio.wait_for",
new_callable=AsyncMock,
side_effect=_asyncio.TimeoutError,
),
):
await set_user_tier(_USER, SubscriptionTier.PRO)
# Set_user_tier still completed — the drift timeout did not propagate.
mock_prisma.update.assert_awaited_once()
# ---------------------------------------------------------------------------
# get_global_rate_limits with tiers
@@ -745,7 +807,7 @@ class TestTierLimitsRespected:
assert tier == SubscriptionTier.PRO
# Should NOT raise — 3M < 12.5M
await check_rate_limit(
_USER, daily_token_limit=daily, weekly_token_limit=weekly
_USER, daily_cost_limit=daily, weekly_cost_limit=weekly
)
@pytest.mark.asyncio
@@ -779,7 +841,7 @@ class TestTierLimitsRespected:
# Should raise — 2.5M >= 2.5M
with pytest.raises(RateLimitExceeded):
await check_rate_limit(
_USER, daily_token_limit=daily, weekly_token_limit=weekly
_USER, daily_cost_limit=daily, weekly_cost_limit=weekly
)
@pytest.mark.asyncio
@@ -811,7 +873,7 @@ class TestTierLimitsRespected:
assert tier == SubscriptionTier.ENTERPRISE
# Should NOT raise — 100M < 150M
await check_rate_limit(
_USER, daily_token_limit=daily, weekly_token_limit=weekly
_USER, daily_cost_limit=daily, weekly_cost_limit=weekly
)
@@ -838,7 +900,7 @@ class TestResetDailyUsage:
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
result = await reset_daily_usage(_USER, daily_token_limit=10000)
result = await reset_daily_usage(_USER, daily_cost_limit=10000)
assert result is True
mock_pipe.delete.assert_called_once()
@@ -854,7 +916,7 @@ class TestResetDailyUsage:
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await reset_daily_usage(_USER, daily_token_limit=10000)
await reset_daily_usage(_USER, daily_cost_limit=10000)
mock_pipe.decrby.assert_called_once()
mock_redis.set.assert_not_called() # 35000 > 0, no clamp needed
@@ -870,14 +932,14 @@ class TestResetDailyUsage:
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await reset_daily_usage(_USER, daily_token_limit=10000)
await reset_daily_usage(_USER, daily_cost_limit=10000)
mock_pipe.decrby.assert_called_once()
mock_redis.set.assert_called_once()
@pytest.mark.asyncio
async def test_no_weekly_reduction_when_daily_limit_zero(self):
"""When daily_token_limit is 0, weekly counter should not be touched."""
"""When daily_cost_limit is 0, weekly counter should not be touched."""
mock_pipe = self._make_pipeline_mock()
mock_pipe.execute = AsyncMock(return_value=[1]) # only delete result
mock_redis = AsyncMock()
@@ -887,7 +949,7 @@ class TestResetDailyUsage:
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await reset_daily_usage(_USER, daily_token_limit=0)
await reset_daily_usage(_USER, daily_cost_limit=0)
mock_pipe.delete.assert_called_once()
mock_pipe.decrby.assert_not_called()
@@ -898,7 +960,7 @@ class TestResetDailyUsage:
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
result = await reset_daily_usage(_USER, daily_token_limit=10000)
result = await reset_daily_usage(_USER, daily_cost_limit=10000)
assert result is False

View File

@@ -16,14 +16,14 @@ from backend.util.exceptions import InsufficientBalanceError
# Minimal config mock matching ChatConfig fields used by the endpoint.
def _make_config(
rate_limit_reset_cost: int = 500,
daily_token_limit: int = 2_500_000,
weekly_token_limit: int = 12_500_000,
daily_cost_limit_microdollars: int = 10_000_000,
weekly_cost_limit_microdollars: int = 50_000_000,
max_daily_resets: int = 5,
):
cfg = MagicMock()
cfg.rate_limit_reset_cost = rate_limit_reset_cost
cfg.daily_token_limit = daily_token_limit
cfg.weekly_token_limit = weekly_token_limit
cfg.daily_cost_limit_microdollars = daily_cost_limit_microdollars
cfg.weekly_cost_limit_microdollars = weekly_cost_limit_microdollars
cfg.max_daily_resets = max_daily_resets
return cfg
@@ -77,10 +77,10 @@ class TestResetCopilotUsage:
assert "not available" in exc_info.value.detail
async def test_no_daily_limit_returns_400(self):
"""When daily_token_limit=0 (unlimited), endpoint returns 400."""
"""When daily_cost_limit=0 (unlimited), endpoint returns 400."""
with (
patch(f"{_MODULE}.config", _make_config(daily_token_limit=0)),
patch(f"{_MODULE}.config", _make_config(daily_cost_limit_microdollars=0)),
patch(f"{_MODULE}.settings", _mock_settings()),
_mock_rate_limits(daily=0),
):

View File

@@ -34,6 +34,15 @@ class ResponseType(str, Enum):
TEXT_DELTA = "text-delta"
TEXT_END = "text-end"
# Reasoning streaming (extended_thinking content blocks). Matches
# the Vercel AI SDK v5 wire names so the client's ``useChat``
# transport accumulates these into a ``type: 'reasoning'`` UIMessage
# part that the ``ReasoningCollapse`` component renders collapsed by
# default.
REASONING_START = "reasoning-start"
REASONING_DELTA = "reasoning-delta"
REASONING_END = "reasoning-end"
# Tool interaction
TOOL_INPUT_START = "tool-input-start"
TOOL_INPUT_AVAILABLE = "tool-input-available"
@@ -130,6 +139,31 @@ class StreamTextEnd(StreamBaseResponse):
id: str = Field(..., description="Text block ID")
# ========== Reasoning Streaming ==========
class StreamReasoningStart(StreamBaseResponse):
"""Start of a reasoning block (extended_thinking content)."""
type: ResponseType = ResponseType.REASONING_START
id: str = Field(..., description="Reasoning block ID")
class StreamReasoningDelta(StreamBaseResponse):
"""Streaming reasoning content delta."""
type: ResponseType = ResponseType.REASONING_DELTA
id: str = Field(..., description="Reasoning block ID")
delta: str = Field(..., description="Reasoning content delta")
class StreamReasoningEnd(StreamBaseResponse):
"""End of a reasoning block."""
type: ResponseType = ResponseType.REASONING_END
id: str = Field(..., description="Reasoning block ID")
# ========== Tool Interaction ==========

View File

@@ -24,14 +24,10 @@ from typing import TYPE_CHECKING, Any
# Static imports for type checkers so they can resolve __all__ entries
# without executing the lazy-import machinery at runtime.
if TYPE_CHECKING:
from .collect import CopilotResult as CopilotResult
from .collect import collect_copilot_response as collect_copilot_response
from .service import stream_chat_completion_sdk as stream_chat_completion_sdk
from .tool_adapter import create_copilot_mcp_server as create_copilot_mcp_server
__all__ = [
"CopilotResult",
"collect_copilot_response",
"stream_chat_completion_sdk",
"create_copilot_mcp_server",
]
@@ -39,8 +35,6 @@ __all__ = [
# Dispatch table for PEP 562 lazy imports. Each entry is a (module, attr)
# pair so new exports can be added without touching __getattr__ itself.
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
"CopilotResult": (".collect", "CopilotResult"),
"collect_copilot_response": (".collect", "collect_copilot_response"),
"stream_chat_completion_sdk": (".service", "stream_chat_completion_sdk"),
"create_copilot_mcp_server": (".tool_adapter", "create_copilot_mcp_server"),
}

View File

@@ -1,232 +0,0 @@
"""Public helpers for consuming a copilot stream as a simple request-response.
This module exposes :class:`CopilotResult` and :func:`collect_copilot_response`
so that callers (e.g. the AutoPilot block) can consume the copilot stream
without implementing their own event loop.
"""
from __future__ import annotations
import logging
import uuid
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from backend.copilot.permissions import CopilotPermissions
from pydantic import BaseModel, Field
from redis.exceptions import RedisError
from .. import stream_registry
from ..response_model import (
StreamError,
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
StreamUsage,
)
from .service import stream_chat_completion_sdk
logger = logging.getLogger(__name__)
# Identifiers used when registering AutoPilot-originated streams in the
# stream registry. Distinct from "chat_stream"/"chat" used by the HTTP SSE
# endpoint, making it easy to filter AutoPilot streams in logs/observability.
AUTOPILOT_TOOL_CALL_ID = "autopilot_stream"
AUTOPILOT_TOOL_NAME = "autopilot"
class CopilotResult:
"""Aggregated result from consuming a copilot stream.
Returned by :func:`collect_copilot_response` so callers don't need to
implement their own event-loop over the raw stream events.
"""
__slots__ = (
"response_text",
"tool_calls",
"prompt_tokens",
"completion_tokens",
"total_tokens",
)
def __init__(self) -> None:
self.response_text: str = ""
self.tool_calls: list[dict[str, Any]] = []
self.prompt_tokens: int = 0
self.completion_tokens: int = 0
self.total_tokens: int = 0
class _RegistryHandle(BaseModel):
"""Tracks stream registry session state for cleanup."""
publish_turn_id: str = ""
error_msg: str | None = None
error_already_published: bool = False
@asynccontextmanager
async def _registry_session(
session_id: str, user_id: str, turn_id: str
) -> AsyncIterator[_RegistryHandle]:
"""Create a stream registry session and ensure it is finalized."""
handle = _RegistryHandle(publish_turn_id=turn_id)
try:
await stream_registry.create_session(
session_id=session_id,
user_id=user_id,
tool_call_id=AUTOPILOT_TOOL_CALL_ID,
tool_name=AUTOPILOT_TOOL_NAME,
turn_id=turn_id,
)
except (RedisError, ConnectionError, OSError):
logger.warning(
"[collect] Failed to create stream registry session for %s, "
"frontend will not receive real-time updates",
session_id[:12],
exc_info=True,
)
# Disable chunk publishing but keep finalization enabled so
# mark_session_completed can clean up any partial registry state.
handle.publish_turn_id = ""
try:
yield handle
finally:
try:
await stream_registry.mark_session_completed(
session_id,
error_message=handle.error_msg,
skip_error_publish=handle.error_already_published,
)
except (RedisError, ConnectionError, OSError):
logger.warning(
"[collect] Failed to mark stream completed for %s",
session_id[:12],
exc_info=True,
)
class _ToolCallEntry(BaseModel):
"""A single tool call observed during stream consumption."""
tool_call_id: str
tool_name: str
input: Any
output: Any = None
success: bool | None = None
class _EventAccumulator(BaseModel):
"""Mutable accumulator for stream events."""
response_parts: list[str] = Field(default_factory=list)
tool_calls: list[_ToolCallEntry] = Field(default_factory=list)
tool_calls_by_id: dict[str, _ToolCallEntry] = Field(default_factory=dict)
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
def _process_event(event: object, acc: _EventAccumulator) -> str | None:
"""Process a single stream event and return error_msg if StreamError.
Uses structural pattern matching for dispatch per project guidelines.
"""
match event:
case StreamTextDelta(delta=delta):
acc.response_parts.append(delta)
case StreamToolInputAvailable() as e:
entry = _ToolCallEntry(
tool_call_id=e.toolCallId,
tool_name=e.toolName,
input=e.input,
)
acc.tool_calls.append(entry)
acc.tool_calls_by_id[e.toolCallId] = entry
case StreamToolOutputAvailable() as e:
if tc := acc.tool_calls_by_id.get(e.toolCallId):
tc.output = e.output
tc.success = e.success
else:
logger.debug(
"Received tool output for unknown tool_call_id: %s",
e.toolCallId,
)
case StreamUsage() as e:
acc.prompt_tokens += e.prompt_tokens
acc.completion_tokens += e.completion_tokens
acc.total_tokens += e.total_tokens
case StreamError(errorText=err):
return err
return None
async def collect_copilot_response(
*,
session_id: str,
message: str,
user_id: str,
is_user_message: bool = True,
permissions: "CopilotPermissions | None" = None,
) -> CopilotResult:
"""Consume :func:`stream_chat_completion_sdk` and return aggregated results.
Registers with the stream registry so the frontend can connect via SSE
and receive real-time updates while the AutoPilot block is executing.
Args:
session_id: Chat session to use.
message: The user message / prompt.
user_id: Authenticated user ID.
is_user_message: Whether this is a user-initiated message.
permissions: Optional capability filter. When provided, restricts
which tools and blocks the copilot may use during this execution.
Returns:
A :class:`CopilotResult` with the aggregated response text,
tool calls, and token usage.
Raises:
RuntimeError: If the stream yields a ``StreamError`` event.
"""
turn_id = str(uuid.uuid4())
async with _registry_session(session_id, user_id, turn_id) as handle:
try:
raw_stream = stream_chat_completion_sdk(
session_id=session_id,
message=message,
is_user_message=is_user_message,
user_id=user_id,
permissions=permissions,
)
published_stream = stream_registry.stream_and_publish(
session_id=session_id,
turn_id=handle.publish_turn_id,
stream=raw_stream,
)
acc = _EventAccumulator()
async for event in published_stream:
if err := _process_event(event, acc):
handle.error_msg = err
# stream_and_publish skips StreamError events, so
# mark_session_completed must publish the error to Redis.
handle.error_already_published = False
raise RuntimeError(f"Copilot error: {err}")
except Exception:
if handle.error_msg is None:
handle.error_msg = "AutoPilot execution failed"
raise
result = CopilotResult()
result.response_text = "".join(acc.response_parts)
result.tool_calls = [tc.model_dump() for tc in acc.tool_calls]
result.prompt_tokens = acc.prompt_tokens
result.completion_tokens = acc.completion_tokens
result.total_tokens = acc.total_tokens
return result

View File

@@ -1,177 +0,0 @@
"""Tests for collect_copilot_response stream registry integration."""
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.response_model import (
StreamError,
StreamFinish,
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
StreamUsage,
)
from backend.copilot.sdk.collect import collect_copilot_response
def _mock_stream_fn(*events):
"""Return a callable that returns an async generator."""
async def _gen(**_kwargs):
for e in events:
yield e
return _gen
@pytest.fixture
def mock_registry():
"""Patch stream_registry module used by collect."""
with patch("backend.copilot.sdk.collect.stream_registry") as m:
m.create_session = AsyncMock()
m.publish_chunk = AsyncMock()
m.mark_session_completed = AsyncMock()
# stream_and_publish: pass-through that also publishes (real logic)
# We re-implement the pass-through here so the event loop works,
# but still track publish_chunk calls via the mock.
async def _stream_and_publish(session_id, turn_id, stream):
async for event in stream:
if turn_id and not isinstance(event, (StreamFinish, StreamError)):
await m.publish_chunk(turn_id, event)
yield event
m.stream_and_publish = _stream_and_publish
yield m
@pytest.fixture
def stream_fn_patch():
"""Helper to patch stream_chat_completion_sdk."""
def _patch(events):
return patch(
"backend.copilot.sdk.collect.stream_chat_completion_sdk",
new=_mock_stream_fn(*events),
)
return _patch
@pytest.mark.asyncio
async def test_stream_registry_called_on_success(mock_registry, stream_fn_patch):
"""Stream registry create/publish/complete are called correctly on success."""
events = [
StreamTextDelta(id="t1", delta="Hello "),
StreamTextDelta(id="t1", delta="world"),
StreamUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
StreamFinish(),
]
with stream_fn_patch(events):
result = await collect_copilot_response(
session_id="test-session",
message="hi",
user_id="user-1",
)
assert result.response_text == "Hello world"
assert result.total_tokens == 15
mock_registry.create_session.assert_awaited_once()
# StreamFinish should NOT be published (mark_session_completed does it)
published_types = [
type(call.args[1]).__name__
for call in mock_registry.publish_chunk.call_args_list
]
assert "StreamFinish" not in published_types
assert "StreamTextDelta" in published_types
mock_registry.mark_session_completed.assert_awaited_once()
_, kwargs = mock_registry.mark_session_completed.call_args
assert kwargs.get("error_message") is None
@pytest.mark.asyncio
async def test_stream_registry_error_on_stream_error(mock_registry, stream_fn_patch):
"""mark_session_completed receives error message when StreamError occurs."""
events = [
StreamTextDelta(id="t1", delta="partial"),
StreamError(errorText="something broke"),
]
with stream_fn_patch(events):
with pytest.raises(RuntimeError, match="something broke"):
await collect_copilot_response(
session_id="test-session",
message="hi",
user_id="user-1",
)
_, kwargs = mock_registry.mark_session_completed.call_args
assert kwargs.get("error_message") == "something broke"
# stream_and_publish skips StreamError, so mark_session_completed must
# publish it (skip_error_publish=False).
assert kwargs.get("skip_error_publish") is False
# StreamError should NOT be published via publish_chunk — mark_session_completed
# handles it to avoid double-publication.
published_types = [
type(call.args[1]).__name__
for call in mock_registry.publish_chunk.call_args_list
]
assert "StreamError" not in published_types
@pytest.mark.asyncio
async def test_graceful_degradation_when_create_session_fails(
mock_registry, stream_fn_patch
):
"""AutoPilot still works when stream registry create_session raises."""
events = [
StreamTextDelta(id="t1", delta="works"),
StreamFinish(),
]
mock_registry.create_session = AsyncMock(side_effect=ConnectionError("Redis down"))
with stream_fn_patch(events):
result = await collect_copilot_response(
session_id="test-session",
message="hi",
user_id="user-1",
)
assert result.response_text == "works"
# publish_chunk should NOT be called because turn_id was cleared
mock_registry.publish_chunk.assert_not_awaited()
# mark_session_completed IS still called to clean up any partial state
mock_registry.mark_session_completed.assert_awaited_once()
@pytest.mark.asyncio
async def test_tool_calls_published_and_collected(mock_registry, stream_fn_patch):
"""Tool call events are both published to registry and collected in result."""
events = [
StreamToolInputAvailable(
toolCallId="tc-1", toolName="read_file", input={"path": "/tmp"}
),
StreamToolOutputAvailable(
toolCallId="tc-1", output="file contents", success=True
),
StreamTextDelta(id="t1", delta="done"),
StreamFinish(),
]
with stream_fn_patch(events):
result = await collect_copilot_response(
session_id="test-session",
message="hi",
user_id="user-1",
)
assert len(result.tool_calls) == 1
assert result.tool_calls[0]["tool_name"] == "read_file"
assert result.tool_calls[0]["output"] == "file contents"
assert result.tool_calls[0]["success"] is True
assert result.response_text == "done"

View File

@@ -84,9 +84,10 @@ async def test_resolve_file_ref_local_path_with_line_range():
async def test_resolve_file_ref_rejects_path_outside_sdk_cwd():
"""resolve_file_ref raises ValueError for paths outside sdk_cwd."""
with tempfile.TemporaryDirectory() as sdk_cwd:
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var:
with (
patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var,
patch("backend.copilot.context._current_sandbox") as mock_sandbox_var,
):
mock_cwd_var.get.return_value = sdk_cwd
mock_sandbox_var.get.return_value = None
@@ -387,11 +388,13 @@ async def test_read_file_handler_local_file():
with open(test_file, "w") as f:
f.writelines(lines)
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj_var, patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", _make_session()),
with (
patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var,
patch("backend.copilot.context._current_project_dir") as mock_proj_var,
patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", _make_session()),
),
):
mock_cwd_var.get.return_value = sdk_cwd
# No project_dir set — so is_sdk_tool_path returns False for sdk_cwd paths
@@ -413,12 +416,15 @@ async def test_read_file_handler_workspace_uri():
mock_manager = AsyncMock()
mock_manager.read_file_by_id.return_value = b"workspace file content\nline two\n"
with patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", mock_session),
), patch(
"backend.copilot.sdk.file_ref.get_workspace_manager",
new=AsyncMock(return_value=mock_manager),
with (
patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", mock_session),
),
patch(
"backend.copilot.sdk.file_ref.get_workspace_manager",
new=AsyncMock(return_value=mock_manager),
),
):
result = await _read_file_handler(
{"file_path": "workspace://file-id-abc", "offset": 0, "limit": 10}
@@ -446,11 +452,13 @@ async def test_read_file_handler_workspace_uri_no_session():
@pytest.mark.asyncio
async def test_read_file_handler_access_denied():
"""_read_file_handler rejects paths outside allowed locations."""
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox, patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", _make_session()),
with (
patch("backend.copilot.context._current_sdk_cwd") as mock_cwd,
patch("backend.copilot.context._current_sandbox") as mock_sandbox,
patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", _make_session()),
),
):
mock_cwd.get.return_value = "/tmp/safe-dir"
mock_sandbox.get.return_value = None
@@ -490,11 +498,11 @@ async def test_read_file_bytes_e2b_sandbox_branch():
mock_sandbox = AsyncMock()
mock_sandbox.files.read.return_value = bytearray(b"sandbox content")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj:
with (
patch("backend.copilot.context._current_sdk_cwd") as mock_cwd,
patch("backend.copilot.context._current_sandbox") as mock_sandbox_var,
patch("backend.copilot.context._current_project_dir") as mock_proj,
):
mock_cwd.get.return_value = ""
mock_sandbox_var.get.return_value = mock_sandbox
mock_proj.get.return_value = ""
@@ -513,11 +521,11 @@ async def test_read_file_bytes_e2b_path_escapes_sandbox_raises():
session = _make_session()
mock_sandbox = AsyncMock()
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj:
with (
patch("backend.copilot.context._current_sdk_cwd") as mock_cwd,
patch("backend.copilot.context._current_sandbox") as mock_sandbox_var,
patch("backend.copilot.context._current_project_dir") as mock_proj,
):
mock_cwd.get.return_value = ""
mock_sandbox_var.get.return_value = mock_sandbox
mock_proj.get.return_value = ""

View File

@@ -1394,11 +1394,7 @@ async def test_e2e_toml_dict_with_list_value_to_concat_block():
"""TOML dict with a list value → List[List[Any]] block: extracts list
values, ignoring scalar values like 'title'."""
toml_content = (
'title = "Fruits"\n'
"[[fruits]]\n"
'name = "apple"\n'
"[[fruits]]\n"
'name = "banana"\n'
'title = "Fruits"\n[[fruits]]\nname = "apple"\n[[fruits]]\nname = "banana"\n'
)
async def _resolve(ref, *a, **kw): # noqa: ARG001
@@ -1692,12 +1688,15 @@ async def test_media_file_field_passthrough_workspace_uri():
},
}
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=AssertionError("should not read file content")),
), patch(
"backend.copilot.sdk.file_ref.read_file_bytes",
new=AsyncMock(side_effect=AssertionError("should not read file bytes")),
with (
patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=AssertionError("should not read file content")),
),
patch(
"backend.copilot.sdk.file_ref.read_file_bytes",
new=AsyncMock(side_effect=AssertionError("should not read file bytes")),
),
):
result = await expand_file_refs_in_args(
{"image": "@@agptfile:workspace://img123"},

View File

@@ -8,7 +8,7 @@ Cross-mode transcript flow
==========================
Both ``baseline/service.py`` (fast mode) and ``sdk/service.py`` (extended_thinking
mode) read and write the same JSONL transcript store via
mode) read and write the same CLI session store via
``backend.copilot.transcript.upload_transcript`` /
``download_transcript``.
@@ -250,8 +250,9 @@ class TestSdkToFastModeSwitch:
@pytest.mark.asyncio
async def test_scenario_s_baseline_loads_sdk_transcript(self):
"""Scenario S: SDK-written transcript is accepted by baseline's load helper."""
"""Scenario S: SDK-written CLI session is accepted by baseline's load helper."""
from backend.copilot.baseline.service import _load_prior_transcript
from backend.copilot.model import ChatMessage
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
from backend.copilot.transcript_builder import TranscriptBuilder
@@ -267,33 +268,41 @@ class TestSdkToFastModeSwitch:
sdk_transcript = builder_sdk.to_jsonl()
# Baseline session now has those 2 SDK messages + 1 new baseline message.
download = TranscriptDownload(content=sdk_transcript, message_count=2)
restore = TranscriptDownload(
content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk"
)
baseline_builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3, # 2 SDK + 1 new baseline
session_messages=[
ChatMessage(role="user", content="sdk-question"),
ChatMessage(role="assistant", content="sdk-answer"),
ChatMessage(role="user", content="baseline-question"),
],
transcript_builder=baseline_builder,
)
# Transcript is valid and covers the prefix.
# CLI session is valid and covers the prefix.
assert covers is True
assert dl is not None
assert baseline_builder.entry_count == 2
@pytest.mark.asyncio
async def test_scenario_s_stale_sdk_transcript_not_loaded(self):
"""Scenario S (stale): SDK transcript is stale — baseline does not load it.
"""Scenario S (stale): SDK CLI session is stale — baseline does not load it.
If SDK mode produced more turns than the transcript captured (e.g.
upload failed on one turn), the baseline rejects the stale transcript
If SDK mode produced more turns than the session captured (e.g.
upload failed on one turn), the baseline rejects the stale session
to avoid injecting an incomplete history.
"""
from backend.copilot.baseline.service import _load_prior_transcript
from backend.copilot.model import ChatMessage
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
from backend.copilot.transcript_builder import TranscriptBuilder
@@ -306,21 +315,33 @@ class TestSdkToFastModeSwitch:
)
sdk_transcript = builder_sdk.to_jsonl()
# Transcript covers only 2 messages but session has 10 (many SDK turns).
download = TranscriptDownload(content=sdk_transcript, message_count=2)
# Session covers only 2 messages but session has 10 (many SDK turns).
# With watermark=2 and 10 total messages, detect_gap will fill the gap
# by appending messages 2..8 (positions 2 to total-2).
restore = TranscriptDownload(
content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk"
)
# Build a session with 10 alternating user/assistant messages + current user
session_messages = [
ChatMessage(role="user" if i % 2 == 0 else "assistant", content=f"msg-{i}")
for i in range(10)
]
baseline_builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=10,
session_messages=session_messages,
transcript_builder=baseline_builder,
)
# Stale transcript must be rejected.
assert covers is False
assert baseline_builder.is_empty
# With gap filling, covers is True and gap messages are appended.
assert covers is True
assert dl is not None
# 2 from transcript + 7 gap messages (positions 2..8, excluding last user turn)
assert baseline_builder.entry_count == 9

View File

@@ -255,6 +255,111 @@ async def test_build_query_no_resume_multi_message(monkeypatch):
assert was_compacted is False # mock returns False
@pytest.mark.asyncio
async def test_build_query_session_msg_ceiling_prevents_pending_duplication():
"""session_msg_ceiling stops pending messages from leaking into the gap.
Scenario: transcript covers 2 messages, session has 2 historical + 1 current
+ 2 pending drained at turn start. Without the ceiling the gap would include
the pending messages AND current_message already has them → duplication.
With session_msg_ceiling=3 (pre-drain count) the gap slice is empty and
only current_message carries the pending content.
"""
# session.messages after drain: [hist1, hist2, current_msg, pending1, pending2]
session = _make_session(
[
ChatMessage(role="user", content="hist1"),
ChatMessage(role="assistant", content="hist2"),
ChatMessage(role="user", content="current msg with pending1 pending2"),
ChatMessage(role="user", content="pending1"),
ChatMessage(role="user", content="pending2"),
]
)
# transcript covers hist1+hist2 (2 messages); pre-drain count was 3 (includes current_msg)
result, was_compacted = await _build_query_message(
"current msg with pending1 pending2",
session,
use_resume=True,
transcript_msg_count=2,
session_id="test-session",
session_msg_ceiling=3, # len(session.messages) before drain
)
# Gap should be empty (transcript_msg_count == ceiling - 1), so no history prepended
assert result == "current msg with pending1 pending2"
assert was_compacted is False
# Pending messages must NOT appear in gap context
assert "pending1" not in result.split("current msg")[0]
@pytest.mark.asyncio
async def test_build_query_session_msg_ceiling_preserves_real_gap():
"""session_msg_ceiling still surfaces a genuine stale-transcript gap.
Scenario: transcript covers 2 messages, session has 4 historical + 1 current
+ 2 pending. Ceiling = 5 (pre-drain). Real gap = messages 2-3 (hist3, hist4).
"""
session = _make_session(
[
ChatMessage(role="user", content="hist1"),
ChatMessage(role="assistant", content="hist2"),
ChatMessage(role="user", content="hist3"),
ChatMessage(role="assistant", content="hist4"),
ChatMessage(role="user", content="current"),
ChatMessage(role="user", content="pending1"),
ChatMessage(role="user", content="pending2"),
]
)
result, was_compacted = await _build_query_message(
"current",
session,
use_resume=True,
transcript_msg_count=2,
session_id="test-session",
session_msg_ceiling=5, # pre-drain: [hist1..hist4, current]
)
# Gap = session.messages[2:4] = [hist3, hist4]
assert "<conversation_history>" in result
assert "hist3" in result
assert "hist4" in result
assert "Now, the user says:\ncurrent" in result
# Pending messages must NOT appear in gap
assert "pending1" not in result
assert "pending2" not in result
@pytest.mark.asyncio
async def test_build_query_session_msg_ceiling_suppresses_spurious_no_resume_fallback():
"""session_msg_ceiling prevents the no-resume compression fallback from
firing on the first turn of a session when pending messages inflate msg_count.
Scenario: fresh session (1 message) + 1 pending message drained at turn start.
Without the ceiling: msg_count=2 > 1 → fallback triggers → pending message
leaked into history → wrong context sent to model.
With session_msg_ceiling=1 (pre-drain count): effective_count=1, 1 > 1 is False
→ fallback does not trigger → current_message returned as-is.
"""
# session.messages after drain: [current_msg, pending_msg]
session = _make_session(
[
ChatMessage(role="user", content="What is 2 plus 2?"),
ChatMessage(role="user", content="What is 7 plus 7?"), # pending
]
)
result, was_compacted = await _build_query_message(
"What is 2 plus 2?\n\nWhat is 7 plus 7?",
session,
use_resume=False,
transcript_msg_count=0,
session_id="test-session",
session_msg_ceiling=1, # pre-drain: only 1 message existed
)
# Should return current_message directly without wrapping in history context
assert result == "What is 2 plus 2?\n\nWhat is 7 plus 7?"
assert was_compacted is False
# Pending question must NOT appear in a spurious history section
assert "<conversation_history>" not in result
@pytest.mark.asyncio
async def test_build_query_no_resume_multi_message_compacted(monkeypatch):
"""When compression actually compacts, was_compacted should be True."""

View File

@@ -28,6 +28,9 @@ from backend.copilot.response_model import (
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamReasoningDelta,
StreamReasoningEnd,
StreamReasoningStart,
StreamStart,
StreamStartStep,
StreamTextDelta,
@@ -56,9 +59,21 @@ class SDKResponseAdapter:
self.text_block_id = str(uuid.uuid4())
self.has_started_text = False
self.has_ended_text = False
self.reasoning_block_id = str(uuid.uuid4())
self.has_started_reasoning = False
self.has_ended_reasoning = True
self.current_tool_calls: dict[str, dict[str, str]] = {}
self.resolved_tool_calls: set[str] = set()
self.step_open = False
# Track whether any ``TextBlock`` was emitted after the most recent
# tool_result. Used at ``ResultMessage`` time to detect the
# "thinking-only final turn" case — when Claude's last LLM call
# produced only a ``ThinkingBlock`` (no text, no tool_use) the UI
# hangs on the last tool result with a "Thought for Xs" label and
# no response text. We synthesize a short closing line in that
# case so the turn renders as cleanly complete.
self._text_since_last_tool_result = False
self._any_tool_results_seen = False
@property
def has_unresolved_tool_calls(self) -> bool:
@@ -103,18 +118,43 @@ class SDKResponseAdapter:
for block in sdk_message.content:
if isinstance(block, TextBlock):
if block.text:
# Reasoning and text are distinct UI parts; close
# any open reasoning block before opening text so
# the AI SDK transport doesn't merge them.
self._end_reasoning_if_open(responses)
self._ensure_text_started(responses)
responses.append(
StreamTextDelta(id=self.text_block_id, delta=block.text)
)
self._text_since_last_tool_result = True
elif isinstance(block, ThinkingBlock):
# Thinking blocks are preserved in the transcript but
# not streamed to the frontend — skip silently.
pass
# Stream extended_thinking content as a reasoning
# block. The Vercel AI SDK's ``useChat`` transport
# recognises ``reasoning-start`` / ``reasoning-delta``
# / ``reasoning-end`` events and accumulates them into
# a ``type: 'reasoning'`` UIMessage part the frontend
# renders via ``ReasoningCollapse`` (collapsed by
# default). We also persist the text as a
# ``type: 'thinking'`` part in ``session.messages`` via
# ``_format_sdk_content_blocks``, so shared / reloaded
# sessions see the same reasoning. Without streaming
# it live, extended_thinking turns that end
# thinking-only left the UI stuck on "Thought for Xs"
# with nothing rendered until a page refresh.
if block.thinking:
self._end_text_if_open(responses)
self._ensure_reasoning_started(responses)
responses.append(
StreamReasoningDelta(
id=self.reasoning_block_id,
delta=block.thinking,
)
)
elif isinstance(block, ToolUseBlock):
self._end_text_if_open(responses)
self._end_reasoning_if_open(responses)
# Strip MCP prefix so frontend sees "find_block"
# instead of "mcp__copilot__find_block".
@@ -210,16 +250,58 @@ class SDKResponseAdapter:
resolved_in_blocks.add(parent_id)
self.resolved_tool_calls.update(resolved_in_blocks)
if resolved_in_blocks:
# A new tool_result just landed — reset the
# "has the model emitted text since the last tool result?"
# tracker so the thinking-only-final-turn guard at
# ``ResultMessage`` time stays accurate.
self._text_since_last_tool_result = False
self._any_tool_results_seen = True
# Close the current step after tool results — the next
# AssistantMessage will open a new step for the continuation.
if self.step_open:
self._end_reasoning_if_open(responses)
responses.append(StreamFinishStep())
self.step_open = False
elif isinstance(sdk_message, ResultMessage):
self._flush_unresolved_tool_calls(responses)
# Thinking-only final turn guard: when the model's last LLM
# call after a tool result produced only a ``ThinkingBlock``
# (no ``TextBlock``, no ``ToolUseBlock``) the UI has nothing
# to render after the tool output — it hangs on "Thought for
# Xs" with no response text. Synthesise a short closing line
# so the turn visibly completes. Condition: we've seen at
# least one tool_result AND zero TextBlocks since. The
# prompt rule (``_USER_FOLLOW_UP_NOTE``'s closing clause)
# asks the model to always end with text, but we can't rely
# on it for extended_thinking / edge cases.
if (
self._any_tool_results_seen
and not self._text_since_last_tool_result
and sdk_message.subtype == "success"
):
# UserMessage (tool_result) closed the last step, so we must
# open a fresh one before emitting any text — the AI SDK v5
# transport rejects text-delta chunks that aren't wrapped in
# start-step / finish-step.
if not self.step_open:
responses.append(StreamStartStep())
self.step_open = True
# Close any open reasoning block first — text and reasoning
# must not interleave on the wire (AI SDK v5 maps distinct
# start/end events to distinct UI parts).
self._end_reasoning_if_open(responses)
self._ensure_text_started(responses)
responses.append(
StreamTextDelta(
id=self.text_block_id,
delta="(Done — no further commentary.)",
)
)
self._end_text_if_open(responses)
self._end_reasoning_if_open(responses)
# Close the step before finishing.
if self.step_open:
responses.append(StreamFinishStep())
@@ -261,6 +343,26 @@ class SDKResponseAdapter:
responses.append(StreamTextEnd(id=self.text_block_id))
self.has_ended_text = True
def _ensure_reasoning_started(self, responses: list[StreamBaseResponse]) -> None:
"""Start (or restart) a reasoning block if needed.
Each ``ThinkingBlock`` the SDK emits gets its own streaming block
on the wire so the frontend can render a new ``Reasoning`` part
per LLM turn (rather than concatenating across the whole session).
"""
if not self.has_started_reasoning or self.has_ended_reasoning:
if self.has_ended_reasoning:
self.reasoning_block_id = str(uuid.uuid4())
self.has_ended_reasoning = False
responses.append(StreamReasoningStart(id=self.reasoning_block_id))
self.has_started_reasoning = True
def _end_reasoning_if_open(self, responses: list[StreamBaseResponse]) -> None:
"""End the current reasoning block if one is open."""
if self.has_started_reasoning and not self.has_ended_reasoning:
responses.append(StreamReasoningEnd(id=self.reasoning_block_id))
self.has_ended_reasoning = True
def _flush_unresolved_tool_calls(self, responses: list[StreamBaseResponse]) -> None:
"""Emit outputs for tool calls that didn't receive a UserMessage result.
@@ -305,7 +407,7 @@ class SDKResponseAdapter:
self.resolved_tool_calls.add(tool_id)
flushed = True
logger.info(
"[SDK] [%s] Flushed stashed output for %s " "(call %s, %d chars)",
"[SDK] [%s] Flushed stashed output for %s (call %s, %d chars)",
sid,
tool_name,
tool_id[:12],
@@ -335,9 +437,17 @@ class SDKResponseAdapter:
tool_id[:12],
)
if flushed and self.step_open:
responses.append(StreamFinishStep())
self.step_open = False
if flushed:
# Mirror the UserMessage tool_result path: a flushed tool output is
# still a tool_result as far as the thinking-only-final-turn guard
# is concerned. Without this, a turn whose ONLY tool outputs come
# from the flush path (SDK built-ins like WebSearch) would miss
# the fallback synthesis if the model then produced no text.
self._text_since_last_tool_result = False
self._any_tool_results_seen = True
if self.step_open:
responses.append(StreamFinishStep())
self.step_open = False
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:

View File

@@ -8,6 +8,7 @@ from claude_agent_sdk import (
ResultMessage,
SystemMessage,
TextBlock,
ThinkingBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
@@ -19,6 +20,7 @@ from backend.copilot.response_model import (
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamReasoningDelta,
StreamStart,
StreamStartStep,
StreamTextDelta,
@@ -251,6 +253,200 @@ def test_result_success_emits_finish_step_and_finish():
assert isinstance(results[2], StreamFinish)
# -- Reasoning streaming -----------------------------------------------------
def test_thinking_block_streams_as_reasoning():
"""ThinkingBlock content streams as StreamReasoningDelta so the
frontend renders it via the ``Reasoning`` part (collapsed by
default) instead of dropping it silently."""
adapter = _adapter()
msg = AssistantMessage(
content=[
ThinkingBlock(thinking="planning step 1", signature="sig"),
],
model="test",
)
results = adapter.convert_message(msg)
# Step + ReasoningStart + ReasoningDelta
types = [type(r).__name__ for r in results]
assert "StreamReasoningStart" in types
assert any(
isinstance(r, StreamReasoningDelta) and r.delta == "planning step 1"
for r in results
)
def test_text_after_thinking_closes_reasoning_and_opens_text():
"""Reasoning and text are distinct UI parts — opening text must
emit ``ReasoningEnd`` first so the AI SDK transport doesn't merge
them into the same ``Reasoning`` part."""
adapter = _adapter()
adapter.convert_message(
AssistantMessage(
content=[ThinkingBlock(thinking="warming up", signature="sig")],
model="test",
)
)
results = adapter.convert_message(
AssistantMessage(content=[TextBlock(text="hello")], model="test")
)
types = [type(r).__name__ for r in results]
# ReasoningEnd must come before TextStart
re_idx = types.index("StreamReasoningEnd")
ts_idx = types.index("StreamTextStart")
assert re_idx < ts_idx
def test_tool_use_after_thinking_closes_reasoning():
"""Opening a tool also closes an open reasoning block."""
adapter = _adapter()
adapter.convert_message(
AssistantMessage(
content=[ThinkingBlock(thinking="let me search", signature="sig")],
model="test",
)
)
results = adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_block", input={})
],
model="test",
)
)
types = [type(r).__name__ for r in results]
assert types.index("StreamReasoningEnd") < types.index("StreamToolInputStart")
def test_empty_thinking_block_is_ignored():
"""A ThinkingBlock with empty content shouldn't emit anything."""
adapter = _adapter()
msg = AssistantMessage(
content=[ThinkingBlock(thinking="", signature="sig")],
model="test",
)
results = adapter.convert_message(msg)
# Only the StepStart fires — no reasoning events.
assert [type(r).__name__ for r in results] == ["StreamStartStep"]
def test_result_success_synthesizes_fallback_text_when_final_turn_is_thinking_only():
"""If the model's last LLM call after a tool_result produced only a
ThinkingBlock (no TextBlock), the UI would hang on the tool output
with no response text. The adapter should inject a short closing
line before ``StreamFinish`` so the turn visibly completes."""
adapter = _adapter()
# Tool use + tool_result (simulates the tool round).
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_block", input={}),
],
model="test",
)
)
adapter.convert_message(
UserMessage(
content=[
ToolResultBlock(tool_use_id="t1", content="result", is_error=False)
],
parent_tool_use_id=None,
)
)
# Model's "final turn" after tool_result is thinking-only. This test
# simulates the *degenerate* case where the SDK never surfaces an
# AssistantMessage carrying the ThinkingBlock at all (not even the
# streamed reasoning events) before ResultMessage — only the tool_result
# has arrived. The fallback guard should still synthesize closing text.
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=4,
session_id="s1",
result="",
)
results = adapter.convert_message(msg)
# Fallback text should be injected before the finish events.
text_deltas = [r for r in results if isinstance(r, StreamTextDelta)]
assert len(text_deltas) == 1, "should synthesize exactly one fallback text"
assert text_deltas[0].delta.strip() # non-empty
assert isinstance(results[-1], StreamFinish)
def test_result_success_does_not_synthesize_when_text_already_emitted():
"""Guard: do NOT synthesize when the model DID emit closing text
after the last tool result — the fallback is only for the silent
thinking-only case."""
adapter = _adapter()
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_block", input={})
],
model="test",
)
)
adapter.convert_message(
UserMessage(
content=[
ToolResultBlock(tool_use_id="t1", content="result", is_error=False)
],
parent_tool_use_id=None,
)
)
# Model responds with actual text after the tool result.
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="all done")], model="test")
)
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=4,
session_id="s1",
result="all done",
)
results = adapter.convert_message(msg)
# No fallback — the only TextDelta came from the previous
# AssistantMessage call, not from ResultMessage's synthesis.
text_deltas = [r for r in results if isinstance(r, StreamTextDelta)]
assert text_deltas == []
def test_result_success_does_not_synthesize_when_no_tools_ran():
"""Guard: no tool_results seen ⇒ no fallback. Pure-text turns with
no tools legitimately produce text-only responses through normal
AssistantMessage events; we don't need a fallback there."""
adapter = _adapter()
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="hello")], model="test")
)
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
result="hello",
)
results = adapter.convert_message(msg)
text_deltas = [r for r in results if isinstance(r, StreamTextDelta)]
assert text_deltas == []
def test_result_error_emits_error_and_finish():
adapter = _adapter()
msg = ResultMessage(
@@ -426,6 +622,13 @@ def test_flush_unresolved_at_result_message():
"StreamToolInputAvailable",
"StreamToolOutputAvailable", # flushed with empty output
"StreamFinishStep", # step closed by flush
# Flush marks a tool_result as seen, so the thinking-only-final-turn
# guard at ResultMessage time synthesizes a closing text delta.
"StreamStartStep",
"StreamTextStart",
"StreamTextDelta",
"StreamTextEnd",
"StreamFinishStep",
"StreamFinish",
]
# The flushed output should be empty (no stash available)

View File

@@ -27,6 +27,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.transcript import (
TranscriptDownload,
_flatten_assistant_content,
_flatten_tool_result_content,
_messages_to_transcript,
@@ -999,14 +1000,15 @@ def _make_sdk_patches(
f"{_SVC}.download_transcript",
dict(
new_callable=AsyncMock,
return_value=MagicMock(content=original_transcript, message_count=2),
return_value=TranscriptDownload(
content=original_transcript.encode("utf-8"),
message_count=2,
mode="sdk",
),
),
),
(
f"{_SVC}.restore_cli_session",
dict(new_callable=AsyncMock, return_value=True),
),
(f"{_SVC}.upload_cli_session", dict(new_callable=AsyncMock)),
(f"{_SVC}.strip_for_upload", dict(return_value=original_transcript)),
(f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)),
(f"{_SVC}.validate_transcript", dict(return_value=True)),
(
f"{_SVC}.compact_transcript",
@@ -1037,8 +1039,13 @@ def _make_sdk_patches(
claude_agent_fallback_model=None,
),
),
(f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)),
(f"{_SVC}.get_user_tier", dict(new_callable=AsyncMock, return_value=None)),
# Stub pending-message drain so retry tests don't hit Redis.
# Returns an empty list → no mid-turn injection happens.
(
f"{_SVC}.drain_pending_safe",
dict(new_callable=AsyncMock, return_value=[]),
),
]
@@ -1914,14 +1921,14 @@ class TestStreamChatCompletionRetryIntegration:
compacted_transcript=None,
client_side_effect=_client_factory,
)
# Override restore_cli_session to return False (CLI native session unavailable)
# Override download_transcript to return None (CLI native session unavailable)
patches = [
(
(
f"{_SVC}.restore_cli_session",
dict(new_callable=AsyncMock, return_value=False),
f"{_SVC}.download_transcript",
dict(new_callable=AsyncMock, return_value=None),
)
if p[0] == f"{_SVC}.restore_cli_session"
if p[0] == f"{_SVC}.download_transcript"
else p
)
for p in patches
@@ -1944,7 +1951,7 @@ class TestStreamChatCompletionRetryIntegration:
# captured_options holds {"options": ClaudeAgentOptions}, so check
# the attribute directly rather than dict keys.
assert not getattr(captured_options.get("options"), "resume", None), (
f"--resume was set even though restore_cli_session returned False: "
f"--resume was set even though download_transcript returned None: "
f"{captured_options}"
)
assert any(isinstance(e, StreamStart) for e in events)

View File

@@ -94,21 +94,23 @@ def test_agent_options_accepts_required_fields():
def test_agent_options_accepts_system_prompt_preset_with_exclude_dynamic_sections():
"""Verify ClaudeAgentOptions accepts the exact preset dict _build_system_prompt_value produces.
The production code always includes ``exclude_dynamic_sections=True`` in the preset
dict. This compat test mirrors that exact shape so any SDK version that starts
rejecting unknown keys will be caught here rather than at runtime.
The Turn 1 (non-resume) code path includes ``exclude_dynamic_sections=True`` in
the preset dict for cross-user caching. This compat test mirrors that exact
shape so any SDK version that starts rejecting unknown keys will be caught
here rather than at runtime.
"""
from claude_agent_sdk import ClaudeAgentOptions
from claude_agent_sdk.types import SystemPromptPreset
from .service import _build_system_prompt_value
# Call the production helper directly so this test is tied to the real
# dict shape rather than a hand-rolled copy.
preset = _build_system_prompt_value("custom system prompt", cross_user_cache=True)
assert isinstance(
preset, dict
), "_build_system_prompt_value must return a dict when caching is on"
assert preset.get("exclude_dynamic_sections") is True, (
"Turn 1 must strip dynamic sections to keep the prefix cacheable " "cross-user"
)
sdk_preset = cast(SystemPromptPreset, preset)
opts = ClaudeAgentOptions(system_prompt=sdk_preset)
@@ -116,8 +118,9 @@ def test_agent_options_accepts_system_prompt_preset_with_exclude_dynamic_section
def test_build_system_prompt_value_returns_plain_string_when_cross_user_cache_off():
"""When cross_user_cache=False (e.g. on --resume turns), the helper must return
a plain string so the preset+resume crash is avoided."""
"""When cross_user_cache=False (feature flag disabled globally), the
helper returns a plain string; the CLI will receive --system-prompt
(replace-mode) and skip the preset entirely."""
from .service import _build_system_prompt_value
result = _build_system_prompt_value("my prompt", cross_user_cache=False)
@@ -262,6 +265,12 @@ _KNOWN_GOOD_BUNDLED_CLI_VERSIONS: frozenset[str] = frozenset(
"2.1.97", # claude-agent-sdk 0.1.58 -- OpenRouter-safe only with
# CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS=1 (injected by
# build_sdk_env() in env.py).
"2.1.116", # claude-agent-sdk 0.1.64 -- first bundled version that
# fixes the --resume + excludeDynamicSections=True crash
# (introduced in 2.1.98), unlocking cross-user prompt
# cache reads on every resumed SDK turn. Still requires
# CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS=1. Verified
# OpenRouter-safe via cli_openrouter_compat_test.py.
}
)

View File

@@ -10,7 +10,12 @@ import re
from collections.abc import Callable
from typing import Any, cast
from backend.copilot.context import is_allowed_local_path, is_sdk_tool_path
from backend.copilot.context import (
get_execution_context,
is_allowed_local_path,
is_sdk_tool_path,
)
from backend.copilot.pending_messages import drain_and_format_for_injection
from .tool_adapter import (
BLOCKED_TOOLS,
@@ -327,6 +332,30 @@ def create_security_hooks(
tool_name,
)
# Mid-turn drain: after ANY tool finishes (MCP or built-in), pull
# any queued user follow-up messages and attach them to the
# tool_result as ``additionalContext``. This is the
# protocol-legal mid-turn injection slot — Claude reads the
# follow-up on the next LLM round without starting a new turn.
# The drain helper also stashes a persist-queue copy so
# ``sdk/service.py`` can append a matching user row to the UI.
_, session = get_execution_context()
followup = ""
if session is not None and session.session_id:
followup = await drain_and_format_for_injection(
session.session_id,
log_prefix="[SDK][PostToolUse]",
)
if followup:
return cast(
SyncHookJSONOutput,
{
"hookSpecificOutput": {
"hookEventName": "PostToolUse",
"additionalContext": followup,
}
},
)
return cast(SyncHookJSONOutput, {})
async def post_tool_failure_hook(
@@ -365,7 +394,7 @@ def create_security_hooks(
trigger = _sanitize(str(input_data.get("trigger", "auto")), max_len=50)
# Sanitize untrusted input: strip control chars for logging AND
# for the value passed downstream. read_compacted_entries()
# validates against _projects_base() as defence-in-depth, but
# validates against projects_base() as defence-in-depth, but
# sanitizing here prevents log injection and rejects obviously
# malformed paths early.
transcript_path = _sanitize(

View File

@@ -699,3 +699,160 @@ async def test_subagent_hooks_sanitize_inputs(_subagent_hooks, caplog):
assert "\u202a" not in record.message
assert "\u200b" not in record.message
assert "/tmp/maliciouspath" in caplog.text
# -- PostToolUse: mid-turn pending-message drain ------------------------------
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_post_tool_use_injects_followup_additional_context(
monkeypatch,
):
"""Queued messages drain into ``additionalContext`` for any tool."""
from unittest.mock import MagicMock
from backend.copilot import context as ctx_mod
from backend.copilot import pending_messages as pm_module
session = MagicMock()
session.session_id = "sess-post-inject"
ctx_mod.set_execution_context(
user_id="u1",
session=session,
sandbox=None,
sdk_cwd=SDK_CWD,
)
async def fake_drain(_session_id: str):
assert _session_id == "sess-post-inject"
return [pm_module.PendingMessage(content="please also do X")]
async def fake_stash(_session_id, _messages):
return None
monkeypatch.setattr(
"backend.copilot.pending_messages.drain_pending_messages", fake_drain
)
monkeypatch.setattr(
"backend.copilot.pending_messages.stash_pending_for_persist", fake_stash
)
hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
post = hooks["PostToolUse"][0].hooks[0]
result = await post(
{
"tool_name": "WebSearch", # built-in — the path the old wrapper missed
"tool_response": "search results here",
},
tool_use_id="tu-web-1",
context={},
)
injected = result.get("hookSpecificOutput", {})
assert injected.get("hookEventName") == "PostToolUse"
assert "<user_follow_up>" in injected.get("additionalContext", "")
assert "please also do X" in injected.get("additionalContext", "")
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_post_tool_use_no_pending_returns_empty(monkeypatch):
from unittest.mock import MagicMock
from backend.copilot import context as ctx_mod
session = MagicMock()
session.session_id = "sess-post-empty"
ctx_mod.set_execution_context(
user_id="u1", session=session, sandbox=None, sdk_cwd=SDK_CWD
)
async def fake_drain(_session_id: str):
return []
monkeypatch.setattr(
"backend.copilot.pending_messages.drain_pending_messages", fake_drain
)
hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
post = hooks["PostToolUse"][0].hooks[0]
result = await post(
{"tool_name": "mcp__copilot__run_block", "tool_response": "ok"},
tool_use_id="tu-mcp-1",
context={},
)
# No additionalContext means Claude gets the tool_result verbatim.
assert "hookSpecificOutput" not in result
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_post_tool_use_drain_failure_returns_empty(monkeypatch):
"""A Redis blip must not corrupt the hook response."""
from unittest.mock import MagicMock
from backend.copilot import context as ctx_mod
session = MagicMock()
session.session_id = "sess-post-fail"
ctx_mod.set_execution_context(
user_id="u1", session=session, sandbox=None, sdk_cwd=SDK_CWD
)
async def failing_drain(_session_id: str):
raise RuntimeError("redis down")
monkeypatch.setattr(
"backend.copilot.pending_messages.drain_pending_messages", failing_drain
)
hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
post = hooks["PostToolUse"][0].hooks[0]
result = await post(
{"tool_name": "Read", "tool_response": "file body"},
tool_use_id="tu-read-1",
context={},
)
assert "hookSpecificOutput" not in result
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_post_tool_use_no_session_skips_drain(monkeypatch):
from backend.copilot import context as ctx_mod
ctx_mod.set_execution_context(
user_id=None,
session=None, # type: ignore[arg-type]
sandbox=None,
sdk_cwd=SDK_CWD,
)
drain_called = False
async def fake_drain(_session_id: str):
nonlocal drain_called
drain_called = True
return []
monkeypatch.setattr(
"backend.copilot.pending_messages.drain_pending_messages", fake_drain
)
hooks = create_security_hooks(user_id=None, sdk_cwd=SDK_CWD, max_subtasks=2)
post = hooks["PostToolUse"][0].hooks[0]
result = await post(
{"tool_name": "WebSearch", "tool_response": "x"},
tool_use_id="tu-x",
context={},
)
assert drain_called is False
assert "hookSpecificOutput" not in result

File diff suppressed because it is too large Load Diff

View File

@@ -22,6 +22,7 @@ from .service import (
_iter_sdk_messages,
_normalize_model_name,
_reduce_context,
_restore_cli_session_for_turn,
_TokenUsage,
)
@@ -615,3 +616,340 @@ class TestSdkSessionIdSelection:
)
assert retry.get("resume") == self.SESSION_ID
assert "session_id" not in retry
# ---------------------------------------------------------------------------
# _restore_cli_session_for_turn — mode check
# ---------------------------------------------------------------------------
class TestRestoreCliSessionModeCheck:
"""SDK skips --resume when the transcript was written by the baseline mode."""
@pytest.mark.asyncio
async def test_baseline_mode_transcript_skips_gcs_content(self, tmp_path):
"""A transcript with mode='baseline' must not be used as the --resume source.
The mode check discards the GCS baseline content and falls back to DB
reconstruction from session.messages instead.
"""
from datetime import UTC, datetime
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.transcript import TranscriptDownload
from backend.copilot.transcript_builder import TranscriptBuilder
session = ChatSession(
session_id="test-session",
user_id="user-1",
messages=[
ChatMessage(role="user", content="hello-unique-marker"),
ChatMessage(role="assistant", content="world-unique-marker"),
ChatMessage(role="user", content="follow up"),
],
title="test",
usage=[],
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
builder = TranscriptBuilder()
# Baseline content with a sentinel that must NOT appear in the final transcript
baseline_restore = TranscriptDownload(
content=b'{"type":"user","uuid":"bad-uuid","message":{"role":"user","content":"BASELINE_SENTINEL"}}\n',
message_count=1,
mode="baseline",
)
import backend.copilot.sdk.service as _svc_mod
download_mock = AsyncMock(return_value=baseline_restore)
with (
patch(
"backend.copilot.sdk.service.download_transcript",
new=download_mock,
),
patch.object(_svc_mod.config, "claude_agent_use_resume", True),
):
result = await _restore_cli_session_for_turn(
user_id="user-1",
session_id="test-session",
session=session,
sdk_cwd=str(tmp_path),
transcript_builder=builder,
log_prefix="[Test]",
)
# download_transcript was called (attempted GCS restore)
download_mock.assert_awaited_once()
# use_resume must be False — baseline transcripts cannot be used with --resume
assert result.use_resume is False
# context_messages must be populated — new behaviour uses transcript content + gap
# instead of full DB reconstruction.
assert result.context_messages is not None
# The baseline transcript has 1 user message (BASELINE_SENTINEL).
# Watermark=1 but position 0 is 'user', not 'assistant', so detect_gap returns [].
# Result: 1 message from transcript, no gap.
assert len(result.context_messages) == 1
assert "BASELINE_SENTINEL" in (result.context_messages[0].content or "")
@pytest.mark.asyncio
async def test_sdk_mode_transcript_allows_resume(self, tmp_path):
"""A valid SDK-written transcript is accepted for --resume."""
import json as stdlib_json
from datetime import UTC, datetime
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
from backend.copilot.transcript_builder import TranscriptBuilder
lines = [
stdlib_json.dumps(
{
"type": "user",
"uuid": "uid-0",
"parentUuid": "",
"message": {"role": "user", "content": "hi"},
}
),
stdlib_json.dumps(
{
"type": "assistant",
"uuid": "uid-1",
"parentUuid": "uid-0",
"message": {
"role": "assistant",
"id": "msg_1",
"model": "test",
"type": "message",
"stop_reason": STOP_REASON_END_TURN,
"content": [{"type": "text", "text": "hello"}],
},
}
),
]
content = ("\n".join(lines) + "\n").encode("utf-8")
session = ChatSession(
session_id="test-session",
user_id="user-1",
messages=[
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="hello"),
ChatMessage(role="user", content="follow up"),
],
title="test",
usage=[],
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
builder = TranscriptBuilder()
sdk_restore = TranscriptDownload(
content=content,
message_count=2,
mode="sdk",
)
import backend.copilot.sdk.service as _svc_mod
with (
patch(
"backend.copilot.sdk.service.download_transcript",
new=AsyncMock(return_value=sdk_restore),
),
patch.object(_svc_mod.config, "claude_agent_use_resume", True),
):
result = await _restore_cli_session_for_turn(
user_id="user-1",
session_id="test-session",
session=session,
sdk_cwd=str(tmp_path),
transcript_builder=builder,
log_prefix="[Test]",
)
assert result.use_resume is True
@pytest.mark.asyncio
async def test_baseline_mode_context_messages_from_transcript_content(
self, tmp_path
):
"""mode='baseline' → context_messages populated from transcript content + gap.
When a baseline-mode transcript exists, extract_context_messages converts
the JSONL content to ChatMessage objects and returns them in context_messages.
use_resume must remain False.
"""
import json as stdlib_json
from datetime import UTC, datetime
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
from backend.copilot.transcript_builder import TranscriptBuilder
# Build a minimal valid JSONL transcript with 2 messages
lines = [
stdlib_json.dumps(
{
"type": "user",
"uuid": "uid-0",
"parentUuid": "",
"message": {"role": "user", "content": "TRANSCRIPT_USER"},
}
),
stdlib_json.dumps(
{
"type": "assistant",
"uuid": "uid-1",
"parentUuid": "uid-0",
"message": {
"role": "assistant",
"id": "msg_1",
"model": "test",
"type": "message",
"stop_reason": STOP_REASON_END_TURN,
"content": [{"type": "text", "text": "TRANSCRIPT_ASSISTANT"}],
},
}
),
]
content = ("\n".join(lines) + "\n").encode("utf-8")
session = ChatSession(
session_id="test-session",
user_id="user-1",
messages=[
ChatMessage(role="user", content="DB_USER"),
ChatMessage(role="assistant", content="DB_ASSISTANT"),
ChatMessage(role="user", content="current turn"),
],
title="test",
usage=[],
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
builder = TranscriptBuilder()
baseline_restore = TranscriptDownload(
content=content,
message_count=2,
mode="baseline",
)
import backend.copilot.sdk.service as _svc_mod
with (
patch(
"backend.copilot.sdk.service.download_transcript",
new=AsyncMock(return_value=baseline_restore),
),
patch.object(_svc_mod.config, "claude_agent_use_resume", True),
):
result = await _restore_cli_session_for_turn(
user_id="user-1",
session_id="test-session",
session=session,
sdk_cwd=str(tmp_path),
transcript_builder=builder,
log_prefix="[Test]",
)
assert result.use_resume is False
assert result.context_messages is not None
# Transcript content has 2 messages, no gap (watermark=2, session prior=2)
assert len(result.context_messages) == 2
assert result.context_messages[0].role == "user"
assert result.context_messages[1].role == "assistant"
assert "TRANSCRIPT_ASSISTANT" in (result.context_messages[1].content or "")
# transcript_content must be non-empty so the _seed_transcript guard in
# stream_chat_completion_sdk skips DB reconstruction (which would duplicate
# builder entries since load_previous appends).
assert result.transcript_content != ""
@pytest.mark.asyncio
async def test_baseline_mode_gap_present_context_includes_gap(self, tmp_path):
"""mode='baseline' + gap → context_messages includes transcript msgs and gap."""
import json as stdlib_json
from datetime import UTC, datetime
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
from backend.copilot.transcript_builder import TranscriptBuilder
# Transcript covers only 2 messages; session has 4 prior + current turn
lines = [
stdlib_json.dumps(
{
"type": "user",
"uuid": "uid-0",
"parentUuid": "",
"message": {"role": "user", "content": "TRANSCRIPT_USER_0"},
}
),
stdlib_json.dumps(
{
"type": "assistant",
"uuid": "uid-1",
"parentUuid": "uid-0",
"message": {
"role": "assistant",
"id": "msg_1",
"model": "test",
"type": "message",
"stop_reason": STOP_REASON_END_TURN,
"content": [{"type": "text", "text": "TRANSCRIPT_ASSISTANT_1"}],
},
}
),
]
content = ("\n".join(lines) + "\n").encode("utf-8")
session = ChatSession(
session_id="test-session",
user_id="user-1",
messages=[
ChatMessage(role="user", content="DB_USER_0"),
ChatMessage(role="assistant", content="DB_ASSISTANT_1"),
ChatMessage(role="user", content="GAP_USER_2"),
ChatMessage(role="assistant", content="GAP_ASSISTANT_3"),
ChatMessage(role="user", content="current turn"),
],
title="test",
usage=[],
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
builder = TranscriptBuilder()
baseline_restore = TranscriptDownload(
content=content,
message_count=2, # watermark=2; session has 4 prior → gap of 2
mode="baseline",
)
import backend.copilot.sdk.service as _svc_mod
with (
patch(
"backend.copilot.sdk.service.download_transcript",
new=AsyncMock(return_value=baseline_restore),
),
patch.object(_svc_mod.config, "claude_agent_use_resume", True),
):
result = await _restore_cli_session_for_turn(
user_id="user-1",
session_id="test-session",
session=session,
sdk_cwd=str(tmp_path),
transcript_builder=builder,
log_prefix="[Test]",
)
assert result.use_resume is False
assert result.context_messages is not None
# 2 from transcript + 2 gap messages = 4 total
assert len(result.context_messages) == 4
roles = [m.role for m in result.context_messages]
assert roles == ["user", "assistant", "user", "assistant"]
# Gap messages come from DB (ChatMessage objects)
gap_user = result.context_messages[2]
gap_asst = result.context_messages[3]
assert gap_user.content == "GAP_USER_2"
assert gap_asst.content == "GAP_ASSISTANT_3"

View File

@@ -11,6 +11,7 @@ import pytest
from backend.copilot import config as cfg_mod
from .service import (
_IDLE_TIMEOUT_SECONDS,
_build_system_prompt_value,
_is_sdk_disconnect_error,
_normalize_model_name,
@@ -176,70 +177,18 @@ class TestPromptSupplement:
assert "## Tool notes" in local_supplement
assert "## Tool notes" in e2b_supplement
def test_baseline_supplement_includes_tool_docs(self):
"""Baseline mode MUST include tool documentation (direct API needs it)."""
from backend.copilot.prompting import get_baseline_supplement
def test_baseline_supplement_has_shared_notes_no_tool_list(self):
"""Baseline now relies on the OpenAI tools array for schemas and only
appends SHARED_TOOL_NOTES (workflow rules not present in any schema).
The old auto-generated ``## AVAILABLE TOOLS`` list is gone — it was
~4.3K tokens of pure duplication of the tools array."""
from backend.copilot.prompting import SHARED_TOOL_NOTES
supplement = get_baseline_supplement()
# MUST have tool list section
assert "## AVAILABLE TOOLS" in supplement
# Should NOT have environment-specific notes (SDK-only)
assert "## Tool notes" not in supplement
def test_baseline_supplement_includes_key_tools(self):
"""Baseline supplement should document all essential tools."""
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.tools import TOOL_REGISTRY
docs = get_baseline_supplement()
# Core agent workflow tools (always available)
assert "`create_agent`" in docs
assert "`run_agent`" in docs
assert "`find_library_agent`" in docs
assert "`edit_agent`" in docs
# MCP integration (always available)
assert "`run_mcp_tool`" in docs
# Folder management (always available)
assert "`create_folder`" in docs
# Browser tools only if available (Playwright may not be installed in CI)
if (
TOOL_REGISTRY.get("browser_navigate")
and TOOL_REGISTRY["browser_navigate"].is_available
):
assert "`browser_navigate`" in docs
def test_baseline_supplement_includes_workflows(self):
"""Baseline supplement should include workflow guidance in tool descriptions."""
from backend.copilot.prompting import get_baseline_supplement
docs = get_baseline_supplement()
# Workflows are now in individual tool descriptions (not separate sections)
# Check that key workflow concepts appear in tool descriptions
assert "agent_json" in docs or "find_block" in docs
assert "run_mcp_tool" in docs
def test_baseline_supplement_completeness(self):
"""All available tools from TOOL_REGISTRY should appear in baseline supplement."""
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.tools import TOOL_REGISTRY
docs = get_baseline_supplement()
# Verify each available registered tool is documented
# (matches _generate_tool_documentation which filters by is_available)
for tool_name, tool in TOOL_REGISTRY.items():
if not tool.is_available:
continue
assert (
f"`{tool_name}`" in docs
), f"Tool '{tool_name}' missing from baseline supplement"
assert "## AVAILABLE TOOLS" not in SHARED_TOOL_NOTES
# Keep the high-value workflow rules that are NOT in any tool schema.
assert "@@agptfile:" in SHARED_TOOL_NOTES
assert "Tool Discovery Priority" in SHARED_TOOL_NOTES
assert "run_sub_session" in SHARED_TOOL_NOTES
def test_pause_task_scheduled_before_transcript_upload(self):
"""Pause is scheduled as a background task before transcript upload begins.
@@ -283,21 +232,6 @@ class TestPromptSupplement:
# concurrently during upload's first yield. The ordering guarantee is
# that create_task is CALLED before upload is AWAITED (see source order).
def test_baseline_supplement_no_duplicate_tools(self):
"""No tool should appear multiple times in baseline supplement."""
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.tools import TOOL_REGISTRY
docs = get_baseline_supplement()
# Count occurrences of each available tool in the entire supplement
for tool_name, tool in TOOL_REGISTRY.items():
if not tool.is_available:
continue
# Count how many times this tool appears as a bullet point
count = docs.count(f"- **`{tool_name}`**")
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"
# ---------------------------------------------------------------------------
# _cleanup_sdk_tool_results — orchestration + rate-limiting
@@ -699,6 +633,17 @@ class TestSystemPromptPreset:
assert result["append"] == ""
assert result["exclude_dynamic_sections"] is True
def test_resume_and_fresh_share_the_same_static_prefix(self):
"""Every turn (fresh + --resume) must emit the same preset dict
so the cross-user cache prefix match works on all turns. This
relies on CLI ≥ 2.1.98 (installed in the Docker image); older
CLIs would crash on --resume + excludeDynamicSections=True."""
fresh = _build_system_prompt_value("sys", cross_user_cache=True)
resumed = _build_system_prompt_value("sys", cross_user_cache=True)
assert fresh == resumed
assert isinstance(fresh, dict)
assert fresh.get("exclude_dynamic_sections") is True
def test_default_config_is_enabled(self, _clean_config_env):
"""The default value for claude_agent_cross_user_prompt_cache is True."""
cfg = cfg_mod.ChatConfig(
@@ -719,3 +664,13 @@ class TestSystemPromptPreset:
use_claude_code_subscription=False,
)
assert cfg.claude_agent_cross_user_prompt_cache is False
class TestIdleTimeoutConstant:
"""SECRT-2247: long-running work now uses async start+poll pattern
(run_sub_session / run_agent), so no single MCP tool call ever blocks
the stream close to the idle limit. The plain 10-min cap from the
original code is restored."""
def test_idle_timeout_is_10_min(self):
assert _IDLE_TIMEOUT_SECONDS == 10 * 60

View File

@@ -19,9 +19,11 @@ from __future__ import annotations
from datetime import datetime, timezone
from unittest.mock import MagicMock
from backend.copilot.constants import STOPPED_BY_USER_MARKER
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.response_model import StreamStartStep, StreamTextDelta
from backend.copilot.sdk.service import _dispatch_response, _StreamAccumulator
from backend.copilot.session_cleanup import prune_orphan_tool_calls
_NOW = datetime(2024, 1, 1, tzinfo=timezone.utc)
@@ -215,3 +217,183 @@ class TestPreCreateAssistantMessage:
_simulate_pre_create(acc, ctx)
assert len(ctx.session.messages) == 0
class TestPruneOrphanToolCalls:
"""A Stop mid-tool-call leaves the session ending on an assistant row whose
``tool_calls`` have no matching ``role="tool"`` row. Unless pruned before
the next turn, the ``--resume`` transcript would hand Claude CLI a
``tool_use`` without a paired ``tool_result`` and the SDK would fail.
"""
@staticmethod
def _tool_call(call_id: str, name: str = "bash_exec") -> dict:
return {
"id": call_id,
"type": "function",
"function": {"name": name, "arguments": "{}"},
}
def test_stop_mid_tool_leaves_orphan_assistant(self) -> None:
"""Stop between StreamToolInputAvailable and StreamToolOutputAvailable:
the assistant row has ``tool_calls`` but no matching tool row."""
messages: list[ChatMessage] = [
ChatMessage(role="user", content="do something"),
ChatMessage(
role="assistant",
content="",
tool_calls=[self._tool_call("tc_abc")],
),
]
removed = prune_orphan_tool_calls(messages)
assert removed == 1
assert len(messages) == 1
assert messages[-1].role == "user"
def test_stop_strips_stopped_by_user_marker_and_orphan(self) -> None:
"""The service also appends a ``STOPPED_BY_USER_MARKER`` after a
user stop when the stream loop exits cleanly; both tail rows must go."""
messages: list[ChatMessage] = [
ChatMessage(role="user", content="do something"),
ChatMessage(
role="assistant",
content="",
tool_calls=[self._tool_call("tc_abc")],
),
ChatMessage(role="assistant", content=STOPPED_BY_USER_MARKER),
]
removed = prune_orphan_tool_calls(messages)
assert removed == 2
assert len(messages) == 1
assert messages[-1].role == "user"
def test_completed_tool_call_is_preserved(self) -> None:
"""An assistant row whose tool_calls are all resolved is a healthy
trailing state and must not be popped."""
messages: list[ChatMessage] = [
ChatMessage(role="user", content="do something"),
ChatMessage(
role="assistant",
content="",
tool_calls=[self._tool_call("tc_abc")],
),
ChatMessage(
role="tool",
content="ok",
tool_call_id="tc_abc",
),
]
removed = prune_orphan_tool_calls(messages)
assert removed == 0
assert len(messages) == 3
def test_partial_resolution_still_pops(self) -> None:
"""If an assistant emits multiple tool_calls and only some are
resolved, the assistant row is still unsafe for ``--resume``."""
messages: list[ChatMessage] = [
ChatMessage(role="user", content="do something"),
ChatMessage(
role="assistant",
content="",
tool_calls=[
self._tool_call("tc_1"),
self._tool_call("tc_2"),
],
),
ChatMessage(
role="tool",
content="ok",
tool_call_id="tc_1",
),
]
removed = prune_orphan_tool_calls(messages)
# Both the orphan assistant and its partial tool row are dropped.
assert removed == 2
assert len(messages) == 1
assert messages[-1].role == "user"
def test_plain_assistant_text_preserved(self) -> None:
"""A regular text-only assistant tail is healthy and must be kept."""
messages: list[ChatMessage] = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="hello"),
]
removed = prune_orphan_tool_calls(messages)
assert removed == 0
assert len(messages) == 2
def test_empty_session_is_noop(self) -> None:
messages: list[ChatMessage] = []
assert prune_orphan_tool_calls(messages) == 0
class TestPruneOrphanToolCallsLogging:
"""``prune_orphan_tool_calls`` emits an INFO log when the caller passes
``log_prefix`` and something was actually popped. Shared by the SDK
and baseline turn-start cleanup so both paths log in the same shape."""
def _tool_call(self, call_id: str) -> dict:
return {"id": call_id, "type": "function", "function": {"name": "bash"}}
def test_logs_when_something_was_pruned(self, caplog) -> None:
import backend.copilot.session_cleanup as sc
messages: list[ChatMessage] = [
ChatMessage(role="user", content="hi"),
ChatMessage(
role="assistant", content="", tool_calls=[self._tool_call("tc_1")]
),
]
sc.logger.propagate = True
caplog.set_level("INFO", logger=sc.logger.name)
removed = prune_orphan_tool_calls(messages, log_prefix="[TEST] [abc123]")
assert removed == 1
assert any(
"[TEST] [abc123]" in r.message and "Dropped 1" in r.message
for r in caplog.records
), caplog.text
def test_no_log_when_nothing_to_prune(self, caplog) -> None:
import backend.copilot.session_cleanup as sc
messages: list[ChatMessage] = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="hello"),
]
sc.logger.propagate = True
caplog.set_level("INFO", logger=sc.logger.name)
removed = prune_orphan_tool_calls(messages, log_prefix="[TEST] [xyz]")
assert removed == 0
assert not any("[TEST] [xyz]" in r.message for r in caplog.records), caplog.text
def test_no_log_when_log_prefix_is_none(self, caplog) -> None:
"""Without ``log_prefix``, ``prune_orphan_tool_calls`` is silent."""
import backend.copilot.session_cleanup as sc
messages: list[ChatMessage] = [
ChatMessage(role="user", content="hi"),
ChatMessage(
role="assistant", content="", tool_calls=[self._tool_call("tc_1")]
),
]
sc.logger.propagate = True
caplog.set_level("INFO", logger=sc.logger.name)
removed = prune_orphan_tool_calls(messages)
assert removed == 1
assert caplog.text == ""

View File

@@ -0,0 +1,217 @@
"""Cross-process helpers: dispatch + await a copilot session turn.
The sub-AutoPilot tools (``run_sub_session``, ``get_sub_session_result``)
and ``AutoPilotBlock`` all delegate a copilot turn to the
``copilot_executor`` queue and then wait on the shared
``stream_registry`` for the terminal event. This module is the
centralised primitive so every caller agrees on the dispatch shape,
the event aggregation, and the cleanup contract.
:func:`wait_for_session_result` accumulates stream events into an
:class:`EventAccumulator` so callers get back ``response_text`` /
``tool_calls`` / token usage in memory without an extra DB round-trip.
:func:`run_copilot_turn_via_queue` is the one-shot "create session meta
→ enqueue → wait for result" sequence every caller uses.
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Literal
from backend.copilot import stream_registry
from backend.copilot.executor.utils import enqueue_copilot_turn
from backend.copilot.pending_message_helpers import (
is_turn_in_flight,
queue_user_message,
)
from backend.copilot.response_model import StreamError, StreamFinish
from .stream_accumulator import EventAccumulator, ToolCallEntry, process_event
if TYPE_CHECKING:
from backend.copilot.permissions import CopilotPermissions
logger = logging.getLogger(__name__)
SessionOutcome = Literal["completed", "failed", "running", "queued"]
@dataclass
class SessionResult:
"""Aggregated result from a copilot session turn observed via
``stream_registry``.
When ``queued`` is set, :func:`run_copilot_turn_via_queue` detected an
in-flight turn on the target session and pushed the message onto the
pending buffer instead of starting a new turn. ``response_text`` is
empty and the aggregate counts are zero in that case; the executor
running the earlier turn drains the buffer on its next round.
"""
response_text: str = ""
tool_calls: list[ToolCallEntry] = field(default_factory=list)
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
queued: bool = False
pending_buffer_length: int = 0
async def wait_for_session_result(
*,
session_id: str,
user_id: str | None,
timeout: float,
) -> tuple[SessionOutcome, SessionResult]:
"""Drain the session's stream events and aggregate them into a result.
Returns whatever has been observed at the cap (``running`` + partial
result) or at the terminal event (``completed`` / ``failed`` + full
result). Cleans up the subscriber listener on every exit path so
long-running polls don't leak listeners (sentry r3105348640).
"""
queue = await stream_registry.subscribe_to_session(
session_id=session_id,
user_id=user_id,
)
result = SessionResult()
if queue is None:
# Session meta not in Redis yet, or the caller doesn't own it.
# ``subscribe_to_session`` already retried with backoff before
# returning None.
return "running", result
acc = EventAccumulator()
outcome: SessionOutcome = "running"
try:
loop = asyncio.get_event_loop()
deadline = loop.time() + max(timeout, 0)
while True:
remaining = deadline - loop.time()
if remaining <= 0:
break
event = await asyncio.wait_for(queue.get(), timeout=remaining)
process_event(event, acc)
if isinstance(event, StreamFinish):
outcome = "completed"
break
if isinstance(event, StreamError):
outcome = "failed"
break
except asyncio.TimeoutError:
pass
finally:
await stream_registry.unsubscribe_from_session(
session_id=session_id,
subscriber_queue=queue,
)
result.response_text = "".join(acc.response_parts)
result.tool_calls = list(acc.tool_calls)
result.prompt_tokens = acc.prompt_tokens
result.completion_tokens = acc.completion_tokens
result.total_tokens = acc.total_tokens
return outcome, result
async def run_copilot_turn_via_queue(
*,
session_id: str,
user_id: str,
message: str,
timeout: float,
permissions: "CopilotPermissions | None" = None,
tool_call_id: str,
tool_name: str,
) -> tuple[SessionOutcome, SessionResult]:
"""Dispatch a copilot turn onto the queue and wait for its result.
The canonical invocation path shared by ``run_sub_session`` (the
copilot tool), ``AutoPilotBlock`` (the graph block), and any future
caller that needs to run a copilot turn without occupying its own
worker with the SDK stream:
1. Create a ``stream_registry`` session meta record for the turn.
2. Enqueue a ``CoPilotExecutionEntry`` on the copilot_execution
exchange. Any idle copilot_executor worker claims it.
3. Subscribe to the session's Redis stream and drain events until
``StreamFinish`` / ``StreamError`` or the cap fires.
``tool_call_id`` / ``tool_name`` disambiguate who originated the
turn in observability / replay (e.g. ``"sub:<parent>"`` for a
sub-session, ``"autopilot_block"`` for an AutoPilotBlock run).
Self-defensive queue-fallback: if the target session already has a
turn running (another ``run_sub_session`` / AutoPilot block / UI
chat), don't race it on the cluster lock. Push the message onto the
pending buffer so the existing turn drains it at its next round
boundary, then:
* ``timeout == 0`` — return immediately with
``("queued", SessionResult(queued=True, ...))``. Callers that
explicitly opted into fire-and-forget (``run_sub_session`` with
``wait_for_result=0``) use this to bail without waiting.
* ``timeout > 0`` — **subscribe to the in-flight turn's stream and
return its aggregated result** (exactly the same shape as a
normally-dispatched turn, but with ``result.queued=True`` so
callers can tell we rode on someone else's turn). Semantically
identical to "I asked the session to do something and here is
what happened next"; no separate deferred-state branch needed in
``run_sub_session`` / ``AutoPilotBlock``.
"""
if await is_turn_in_flight(session_id):
logger.info(
"[queue] session=%s has a turn in flight; queueing message "
"(tool=%s) into pending buffer instead of starting a new turn",
session_id[:12],
tool_name,
)
state = await queue_user_message(session_id=session_id, message=message)
if timeout <= 0:
# Fire-and-forget: caller explicitly asked not to wait.
return "queued", SessionResult(
queued=True, pending_buffer_length=state.buffer_length
)
# Ride the in-flight turn: subscribe to its stream and return the
# same aggregated result shape as a fresh dispatch. The model
# drains the pending buffer between tool rounds (baseline) or at
# the next tool boundary via the PostToolUse hook (SDK), so the
# response we observe will reflect our queued follow-up (or be
# the terminal result if the in-flight turn finishes before the
# buffer drains — in that case ``result.queued=True`` is still
# the correct signal for the caller).
outcome, observed = await wait_for_session_result(
session_id=session_id,
user_id=user_id,
timeout=timeout,
)
observed.queued = True
observed.pending_buffer_length = state.buffer_length
return outcome, observed
turn_id = str(uuid.uuid4())
await stream_registry.create_session(
session_id=session_id,
user_id=user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
turn_id=turn_id,
)
await enqueue_copilot_turn(
session_id=session_id,
user_id=user_id,
message=message,
turn_id=turn_id,
permissions=permissions,
)
return await wait_for_session_result(
session_id=session_id,
user_id=user_id,
timeout=timeout,
)

View File

@@ -0,0 +1,169 @@
"""Tests for the shared queue primitive in ``session_waiter``.
Focuses on the queue-on-busy fallback:
* ``timeout == 0`` — push into the buffer and return immediately with
``("queued", SessionResult(queued=True, ...))``; skip registry +
RabbitMQ entirely.
* ``timeout > 0`` — push into the buffer, then subscribe to the
in-flight turn's stream and return its aggregated result (with
``queued=True`` annotation) so callers get the same shape as a
fresh dispatch.
"""
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.sdk.session_waiter import SessionResult, run_copilot_turn_via_queue
_QR = type(
"QR",
(),
{"buffer_length": 4, "max_buffer_length": 10, "turn_in_flight": True},
)
@pytest.mark.asyncio
async def test_queue_branch_timeout_zero_returns_immediately():
"""Busy + timeout=0 → no registry, no enqueue, no wait, queued result."""
queue_mock = AsyncMock(return_value=_QR())
create_session = AsyncMock()
enqueue = AsyncMock()
wait_result = AsyncMock()
with (
patch(
"backend.copilot.sdk.session_waiter.is_turn_in_flight",
new=AsyncMock(return_value=True),
),
patch(
"backend.copilot.sdk.session_waiter.queue_user_message",
new=queue_mock,
),
patch(
"backend.copilot.sdk.session_waiter.stream_registry.create_session",
new=create_session,
),
patch(
"backend.copilot.sdk.session_waiter.enqueue_copilot_turn",
new=enqueue,
),
patch(
"backend.copilot.sdk.session_waiter.wait_for_session_result",
new=wait_result,
),
):
outcome, result = await run_copilot_turn_via_queue(
session_id="sess-busy",
user_id="u1",
message="follow-up",
timeout=0,
tool_call_id="sub:parent",
tool_name="run_sub_session",
)
assert outcome == "queued"
assert isinstance(result, SessionResult)
assert result.queued is True
assert result.pending_buffer_length == 4
create_session.assert_not_awaited()
enqueue.assert_not_awaited()
wait_result.assert_not_awaited()
queue_mock.assert_awaited_once_with(session_id="sess-busy", message="follow-up")
@pytest.mark.asyncio
async def test_queue_branch_positive_timeout_rides_inflight_turn():
"""Busy + timeout>0 → push buffer, subscribe to in-flight turn, return
its aggregated result with ``queued=True`` annotation."""
queue_mock = AsyncMock(return_value=_QR())
create_session = AsyncMock()
enqueue = AsyncMock()
observed = SessionResult()
observed.response_text = "final answer from in-flight turn"
wait_result = AsyncMock(return_value=("completed", observed))
with (
patch(
"backend.copilot.sdk.session_waiter.is_turn_in_flight",
new=AsyncMock(return_value=True),
),
patch(
"backend.copilot.sdk.session_waiter.queue_user_message",
new=queue_mock,
),
patch(
"backend.copilot.sdk.session_waiter.stream_registry.create_session",
new=create_session,
),
patch(
"backend.copilot.sdk.session_waiter.enqueue_copilot_turn",
new=enqueue,
),
patch(
"backend.copilot.sdk.session_waiter.wait_for_session_result",
new=wait_result,
),
):
outcome, result = await run_copilot_turn_via_queue(
session_id="sess-busy",
user_id="u1",
message="follow-up",
timeout=30.0,
tool_call_id="autopilot_block",
tool_name="autopilot_block",
)
# We rode on the existing turn — its outcome + aggregate propagate up.
assert outcome == "completed"
assert result.response_text == "final answer from in-flight turn"
# Marker so callers can tell we didn't start a fresh turn.
assert result.queued is True
assert result.pending_buffer_length == 4
# Still no new registry entry / no new RabbitMQ job — that was the point.
create_session.assert_not_awaited()
enqueue.assert_not_awaited()
# Subscribed to the session stream (not a new turn_id).
wait_result.assert_awaited_once()
assert wait_result.await_args.kwargs["session_id"] == "sess-busy"
@pytest.mark.asyncio
async def test_idle_session_enqueues_normally():
"""Idle session → registry session created, enqueued, drain waits."""
create_session = AsyncMock()
enqueue = AsyncMock()
wait_result = AsyncMock(return_value=("completed", SessionResult()))
with (
patch(
"backend.copilot.sdk.session_waiter.is_turn_in_flight",
new=AsyncMock(return_value=False),
),
patch(
"backend.copilot.sdk.session_waiter.stream_registry.create_session",
new=create_session,
),
patch(
"backend.copilot.sdk.session_waiter.enqueue_copilot_turn",
new=enqueue,
),
patch(
"backend.copilot.sdk.session_waiter.wait_for_session_result",
new=wait_result,
),
):
outcome, result = await run_copilot_turn_via_queue(
session_id="sess-idle",
user_id="u1",
message="kick off",
timeout=0.1,
tool_call_id="autopilot_block",
tool_name="autopilot_block",
)
assert outcome == "completed"
assert result.queued is False
create_session.assert_awaited_once()
enqueue.assert_awaited_once()

View File

@@ -0,0 +1,85 @@
"""Stream event → aggregated result accumulator.
Consumes the same ``StreamBaseResponse`` events that fly over
``stream_registry`` (text deltas, tool i/o, usage, errors) and folds
them into a single :class:`EventAccumulator` state. Used by
:func:`session_waiter.wait_for_session_result` to read events from a
Redis Stream subscription so a different process can obtain the
aggregated result for a session it didn't run.
Keeping the dispatch in one place means new event types can be added
without drifting callers apart on what "response_text", "tool_calls",
or token counts mean.
"""
from __future__ import annotations
import logging
from typing import Any
from pydantic import BaseModel, Field
from ..response_model import (
StreamError,
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
StreamUsage,
)
logger = logging.getLogger(__name__)
class ToolCallEntry(BaseModel):
"""A single tool call observed during stream consumption."""
tool_call_id: str
tool_name: str
input: Any
output: Any = None
success: bool | None = None
class EventAccumulator(BaseModel):
"""Mutable accumulator fed by :func:`process_event`."""
response_parts: list[str] = Field(default_factory=list)
tool_calls: list[ToolCallEntry] = Field(default_factory=list)
tool_calls_by_id: dict[str, ToolCallEntry] = Field(default_factory=dict)
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
def process_event(event: object, acc: EventAccumulator) -> str | None:
"""Fold *event* into *acc*. Returns the error text on ``StreamError``.
Uses structural pattern matching for dispatch per project guidelines.
"""
match event:
case StreamTextDelta(delta=delta):
acc.response_parts.append(delta)
case StreamToolInputAvailable() as e:
entry = ToolCallEntry(
tool_call_id=e.toolCallId,
tool_name=e.toolName,
input=e.input,
)
acc.tool_calls.append(entry)
acc.tool_calls_by_id[e.toolCallId] = entry
case StreamToolOutputAvailable() as e:
if tc := acc.tool_calls_by_id.get(e.toolCallId):
tc.output = e.output
tc.success = e.success
else:
logger.debug(
"Received tool output for unknown tool_call_id: %s",
e.toolCallId,
)
case StreamUsage() as e:
acc.prompt_tokens += e.prompt_tokens
acc.completion_tokens += e.completion_tokens
acc.total_tokens += e.total_tokens
case StreamError(errorText=err):
return err
return None

View File

@@ -0,0 +1,95 @@
"""Unit tests for the watermark-fix logic in stream_chat_completion_sdk.
The fix is at the upload step: when use_resume=True and transcript_msg_count>0
we set the JSONL coverage watermark to transcript_msg_count + 2 (the pair just
recorded) instead of len(session.messages). This prevents the "inflated
watermark" bug where a stale JSONL in GCS could hide missing context from
future gap-fill checks.
"""
from __future__ import annotations
def _compute_jsonl_covered(
use_resume: bool,
transcript_msg_count: int,
session_msg_count: int,
) -> int:
"""Mirror the watermark computation from ``stream_chat_completion_sdk``.
Extracted here so we can unit-test it independently without invoking the
full streaming stack.
"""
if use_resume and transcript_msg_count > 0:
return transcript_msg_count + 2
return session_msg_count
class TestWatermarkFix:
"""Watermark computation logic — mirrors the finally-block in SDK service."""
def test_inflated_watermark_triggers_gap_fill(self):
"""Stale JSONL (T12) with high watermark (46) → after fix, watermark=14.
Before fix: watermark=46 → next turn's gap check (transcript_msg_count < db-1)
never fires because 46 >= 47-1=46, so context loss is silent.
After fix: watermark = 12 + 2 = 14 → gap check fires (14 < 46) and
the model receives the missing turns.
"""
# Simulate: use_resume=True, transcript covered T12 (12 msgs), DB now has 47
use_resume = True
transcript_msg_count = 12
session_msg_count = 47 # DB count (what old code used to set watermark)
watermark = _compute_jsonl_covered(
use_resume, transcript_msg_count, session_msg_count
)
assert watermark == 14 # 12 + 2, NOT 47
# Verify: the gap check would fire on next turn
# next-turn check: transcript_msg_count < msg_count - 1 → 14 < 47-1=46 → True
assert watermark < session_msg_count - 1
def test_no_false_positive_when_transcript_current(self):
"""Transcript current (watermark=46, DB=47) → gap stays 0.
When the JSONL actually covers T46 (the most recent assistant turn),
uploading watermark=46+2=48 means next turn's gap check sees
48 >= 48-1=47 → no gap. Correct.
"""
use_resume = True
transcript_msg_count = 46
session_msg_count = 47
watermark = _compute_jsonl_covered(
use_resume, transcript_msg_count, session_msg_count
)
assert watermark == 48 # 46 + 2
# Next turn: session has 48 msgs, watermark=48 → 48 >= 48-1=47 → no gap
next_turn_session = 48
assert watermark >= next_turn_session - 1
def test_fresh_session_falls_back_to_db_count(self):
"""use_resume=False → watermark = len(session.messages) (original behaviour)."""
use_resume = False
transcript_msg_count = 0
session_msg_count = 3
watermark = _compute_jsonl_covered(
use_resume, transcript_msg_count, session_msg_count
)
assert watermark == session_msg_count
def test_old_format_meta_zero_count_falls_back_to_db(self):
"""transcript_msg_count=0 (old-format meta with no count field) → DB fallback."""
use_resume = True
transcript_msg_count = 0 # old-format meta or not-yet-set
session_msg_count = 10
watermark = _compute_jsonl_covered(
use_resume, transcript_msg_count, session_msg_count
)
assert watermark == session_msg_count

View File

@@ -62,11 +62,24 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Max MCP response size in chars. 100K chars ≈ 25K tokens. The SDK writes oversized results to tool-results/ files.
# Set to 100K (down from a previous 500K) because the SDK already reads back large results from disk via
# tool-results/ — sending 500K chars inline bloated the context window and caused cache-miss thrashing.
# 100K keeps the common case (block output, API responses) in-band without punishing the context budget.
_MCP_MAX_CHARS = 100_000
# Max MCP response size in chars — sized to the Claude CLI's internal cap.
#
# The CLI has a default ``maxResultSizeChars = 1e5`` (100K chars) annotation
# for MCP tool results, but the actual trigger is TOKEN-based (see
# ``sizeEstimateTokens`` in the bundled CLI at ``tengu_mcp_large_result_handled``)
# and fires around 2025K tokens. For JSON-heavy tool output (~34 chars/token)
# that lands anywhere from ~60K to ~100K chars in practice; we've observed the
# error path at 81K chars in production. When it fires, the CLI persists the
# full output to disk and REPLACES the returned content with a synthetic
# ``"Error: result (N characters) exceeds maximum allowed tokens. Output has
# been saved to …"`` message — which destroys any `<user_follow_up>` block
# we injected.
#
# 70K gives us headroom below the observed 81K trigger and leaves ~6K for the
# follow-up injection plus CLI wire overhead. Oversized content is still
# reachable via ``read_tool_result`` against the persisted disk file; only
# the inline reply to this specific call is truncated.
_MCP_MAX_CHARS = 70_000
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
MCP_SERVER_NAME = "copilot"
@@ -248,7 +261,14 @@ async def _execute_tool_sync(
session: ChatSession,
args: dict[str, Any],
) -> dict[str, Any]:
"""Execute a tool synchronously and return MCP-formatted response."""
"""Execute a tool inline and return an MCP-formatted response.
The call runs to completion — no per-handler timeout, no parking. The
stream-level idle timer in ``_run_stream_attempt`` pauses while a tool
is pending, so a long sub-AutoPilot / graph execution doesn't trip the
30-min idle safety net (SECRT-2247). A genuine hang is handled by the
broader session lifecycle (user closes the tab / cancel endpoint).
"""
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
result = await base_tool.execute(
user_id=user_id,
@@ -612,8 +632,12 @@ def _make_truncating_wrapper(
else:
_clear_tool_failures(tool_name)
# Stash BEFORE stripping so the frontend SSE stream receives
# the full output including _STRIP_FROM_LLM fields (e.g. is_dry_run).
# Stash the raw tool output for the frontend SSE stream so widgets
# (bash, tool viewers) receive clean JSON. Mid-turn user follow-up
# injection for MCP + built-in tools is now handled uniformly by
# the ``PostToolUse`` hook via ``additionalContext`` so Claude sees
# the follow-up attached to the tool_result without mutating the
# frontend-facing payload.
if not truncated.get("isError"):
text = _text_from_mcp_result(truncated)
if text:

View File

@@ -251,7 +251,10 @@ class TestTruncationAndStashIntegration:
# ---------------------------------------------------------------------------
def _make_mock_tool(name: str, output: str = "result") -> MagicMock:
def _make_mock_tool(
name: str,
output: str = "result",
) -> MagicMock:
"""Return a BaseTool mock that returns a successful StreamToolOutputAvailable."""
tool = MagicMock()
tool.name = name
@@ -336,6 +339,38 @@ class TestCreateToolHandler:
assert mock_tool.execute.await_count == 2
class TestToolInlineExecution:
"""Tools run inline to completion — no per-handler timeout, no parking."""
@pytest.fixture(autouse=True)
def _init(self):
_init_ctx(session=_make_mock_session())
@pytest.mark.asyncio
async def test_tool_runs_to_completion_regardless_of_duration(self):
"""A tool that takes a while still runs inline; the handler does not
park, cancel, or wrap it in a timeout. The stream-level idle timer
(in _run_stream_attempt) is what pauses while tool calls are pending."""
async def slow_but_completes(*_args, **_kwargs):
await asyncio.sleep(0.1)
return StreamToolOutputAvailable(
toolCallId="t1",
output="final-result",
toolName="slow_tool",
success=True,
)
mock_tool = _make_mock_tool("slow_tool")
mock_tool.execute = AsyncMock(side_effect=slow_but_completes)
handler = create_tool_handler(mock_tool)
result = await handler({})
assert result["isError"] is False
assert "final-result" in result["content"][0]["text"]
# ---------------------------------------------------------------------------
# Regression tests: bugs fixed by removing pre-launch mechanism
#
@@ -873,7 +908,9 @@ class TestStripLlmFields:
"""
dry_run_session = MagicMock()
dry_run_session.dry_run = True
set_execution_context(user_id="test", session=dry_run_session, sandbox=None, sdk_cwd="/tmp/test") # type: ignore[arg-type]
set_execution_context(
user_id="test", session=dry_run_session, sandbox=None, sdk_cwd="/tmp/test"
) # type: ignore[arg-type]
full_payload = '{"message": "done", "is_dry_run": true}'
@@ -906,7 +943,9 @@ class TestStripLlmFields:
"""
normal_session = MagicMock()
normal_session.dry_run = False
set_execution_context(user_id="test", session=normal_session, sandbox=None, sdk_cwd="/tmp/test") # type: ignore[arg-type]
set_execution_context(
user_id="test", session=normal_session, sandbox=None, sdk_cwd="/tmp/test"
) # type: ignore[arg-type]
full_payload = '{"message": "simulated", "is_dry_run": true}'
@@ -929,3 +968,53 @@ class TestStripLlmFields:
stashed = pop_pending_tool_output("fake_tool_normal")
assert stashed is not None
assert '"is_dry_run": true' in stashed
class TestTruncatingWrapperLeavesOutputUntouched:
"""Mid-turn drain moved to the shared ``PostToolUse`` hook path so every
tool (MCP + built-in) is covered uniformly. The wrapper must therefore
forward tool output verbatim and never touch ``<user_follow_up>``."""
@pytest.mark.asyncio
async def test_wrapper_does_not_inject_followup(self):
session = MagicMock()
session.dry_run = False
session.session_id = "sess-no-inject"
set_execution_context(user_id="u", session=session, sandbox=None, sdk_cwd="/tmp/test") # type: ignore[arg-type]
async def fake_tool_fn(_args: dict) -> dict:
return {
"content": [{"type": "text", "text": "CLEAN_OUTPUT"}],
"isError": False,
}
wrapper = _make_truncating_wrapper(fake_tool_fn, "fake_tool_clean")
result = await wrapper({})
text = result["content"][0]["text"]
assert text == "CLEAN_OUTPUT"
assert "<user_follow_up>" not in text
@pytest.mark.asyncio
async def test_stash_stays_clean(self):
"""The frontend-facing stash must be a byte-for-byte copy of the
raw tool output (needed for JSON.parse in the bash widget)."""
session = MagicMock()
session.dry_run = False
session.session_id = "sess-stash"
set_execution_context(user_id="u", session=session, sandbox=None, sdk_cwd="/tmp/test") # type: ignore[arg-type]
clean_json = '{"stdout": "hello\\n", "exit_code": 0}'
async def fake_tool_fn(_args: dict) -> dict:
return {
"content": [{"type": "text", "text": clean_json}],
"isError": False,
}
wrapper = _make_truncating_wrapper(fake_tool_fn, "fake_tool_stash_pure")
await wrapper({})
stashed = pop_pending_tool_output("fake_tool_stash_pure")
assert stashed == clean_json
assert "<user_follow_up>" not in (stashed or "")

View File

@@ -12,18 +12,20 @@ from backend.copilot.transcript import (
ENTRY_TYPE_MESSAGE,
STOP_REASON_END_TURN,
STRIPPABLE_TYPES,
TRANSCRIPT_STORAGE_PREFIX,
TranscriptDownload,
TranscriptMode,
cleanup_stale_project_dirs,
cli_session_path,
compact_transcript,
delete_transcript,
detect_gap,
download_transcript,
extract_context_messages,
projects_base,
read_compacted_entries,
restore_cli_session,
strip_for_upload,
strip_progress_entries,
strip_stale_thinking_blocks,
upload_cli_session,
upload_transcript,
validate_transcript,
write_transcript_to_tempfile,
@@ -34,18 +36,20 @@ __all__ = [
"ENTRY_TYPE_MESSAGE",
"STOP_REASON_END_TURN",
"STRIPPABLE_TYPES",
"TRANSCRIPT_STORAGE_PREFIX",
"TranscriptDownload",
"TranscriptMode",
"cleanup_stale_project_dirs",
"cli_session_path",
"compact_transcript",
"delete_transcript",
"detect_gap",
"download_transcript",
"extract_context_messages",
"projects_base",
"read_compacted_entries",
"restore_cli_session",
"strip_for_upload",
"strip_progress_entries",
"strip_stale_thinking_blocks",
"upload_cli_session",
"upload_transcript",
"validate_transcript",
"write_transcript_to_tempfile",

View File

@@ -297,8 +297,8 @@ class TestStripProgressEntries:
class TestDeleteTranscript:
@pytest.mark.asyncio
async def test_deletes_both_jsonl_and_meta(self):
"""delete_transcript removes both the .jsonl and .meta.json files."""
async def test_deletes_cli_session_and_meta(self):
"""delete_transcript removes the CLI session .jsonl and .meta.json."""
mock_storage = AsyncMock()
mock_storage.delete = AsyncMock()
@@ -309,7 +309,7 @@ class TestDeleteTranscript:
):
await delete_transcript("user-123", "session-456")
assert mock_storage.delete.call_count == 3
assert mock_storage.delete.call_count == 2
paths = [call.args[0] for call in mock_storage.delete.call_args_list]
assert any(p.endswith(".jsonl") for p in paths)
assert any(p.endswith(".meta.json") for p in paths)
@@ -319,7 +319,7 @@ class TestDeleteTranscript:
"""If .jsonl delete fails, .meta.json delete is still attempted."""
mock_storage = AsyncMock()
mock_storage.delete = AsyncMock(
side_effect=[Exception("jsonl delete failed"), None, None]
side_effect=[Exception("jsonl delete failed"), None]
)
with patch(
@@ -330,14 +330,14 @@ class TestDeleteTranscript:
# Should not raise
await delete_transcript("user-123", "session-456")
assert mock_storage.delete.call_count == 3
assert mock_storage.delete.call_count == 2
@pytest.mark.asyncio
async def test_handles_meta_delete_failure(self):
"""If .meta.json delete fails, no exception propagates."""
mock_storage = AsyncMock()
mock_storage.delete = AsyncMock(
side_effect=[None, Exception("meta delete failed"), None]
side_effect=[None, Exception("meta delete failed")]
)
with patch(
@@ -1015,7 +1015,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1044,7 +1044,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1070,7 +1070,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1096,7 +1096,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1118,7 +1118,7 @@ class TestCleanupStaleProjectDirs:
nonexistent = str(tmp_path / "does-not-exist" / "projects")
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: nonexistent,
)
@@ -1137,7 +1137,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1165,7 +1165,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1189,7 +1189,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1368,3 +1368,172 @@ class TestStripStaleThinkingBlocks:
# Both entries of last turn (msg_last) preserved
assert lines[1]["message"]["content"][0]["type"] == "thinking"
assert lines[2]["message"]["content"][0]["type"] == "text"
class TestProcessCliRestore:
"""``process_cli_restore`` validates, strips, and writes CLI session to disk."""
def test_writes_stripped_bytes_not_raw(self, tmp_path):
"""Stripped bytes (not raw bytes) must be written to disk for --resume."""
import os
import re
from pathlib import Path
from unittest.mock import patch
from backend.copilot.sdk.service import process_cli_restore
from backend.copilot.transcript import TranscriptDownload
session_id = "12345678-0000-0000-0000-abcdef000001"
sdk_cwd = str(tmp_path)
projects_base_dir = str(tmp_path)
# Build raw content with a strippable progress entry + a valid user/assistant pair
raw_content = (
'{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n'
'{"type":"user","uuid":"u1","parentUuid":null,"message":{"role":"user","content":"hi"}}\n'
'{"type":"assistant","uuid":"a1","parentUuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello"}]}}\n'
)
raw_bytes = raw_content.encode("utf-8")
restore = TranscriptDownload(content=raw_bytes, message_count=2, mode="sdk")
with (
patch(
"backend.copilot.sdk.service.projects_base",
return_value=projects_base_dir,
),
patch(
"backend.copilot.transcript.projects_base",
return_value=projects_base_dir,
),
):
stripped_str, ok = process_cli_restore(
restore, sdk_cwd, session_id, "[Test]"
)
assert ok, "Expected successful restore"
# Find the written session file
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
session_file = Path(projects_base_dir) / encoded_cwd / f"{session_id}.jsonl"
assert session_file.exists(), "Session file should have been written"
written_bytes = session_file.read_bytes()
# The written bytes must be the stripped version (no progress entry)
assert (
b"progress" not in written_bytes
), "Raw bytes with progress entry should not have been written"
assert (
b"hello" in written_bytes
), "Stripped content should still contain assistant turn"
# Written bytes must equal the stripped string re-encoded
assert written_bytes == stripped_str.encode(
"utf-8"
), "Written bytes must equal stripped content"
def test_invalid_content_returns_false(self):
"""Content that fails validation after strip returns (empty, False)."""
from backend.copilot.sdk.service import process_cli_restore
from backend.copilot.transcript import TranscriptDownload
# A single progress-only entry — stripped result will be empty/invalid
raw_content = '{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n'
restore = TranscriptDownload(
content=raw_content.encode("utf-8"), message_count=1, mode="sdk"
)
stripped_str, ok = process_cli_restore(
restore,
"/tmp/nonexistent-sdk-cwd",
"12345678-0000-0000-0000-000000000099",
"[Test]",
)
assert not ok
assert stripped_str == ""
class TestReadCliSessionFromDisk:
"""``read_cli_session_from_disk`` reads, strips, and optionally writes back the session."""
def _build_session_file(self, tmp_path, session_id: str):
"""Build the session file path inside tmp_path using the same encoding as cli_session_path."""
import os
import re
from pathlib import Path
sdk_cwd = str(tmp_path)
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
session_dir = Path(str(tmp_path)) / encoded_cwd
session_dir.mkdir(parents=True, exist_ok=True)
return sdk_cwd, session_dir / f"{session_id}.jsonl"
def test_returns_raw_bytes_for_invalid_utf8(self, tmp_path):
"""Non-UTF-8 bytes trigger UnicodeDecodeError — returns raw bytes (upload-raw fallback)."""
from unittest.mock import patch
from backend.copilot.sdk.service import read_cli_session_from_disk
session_id = "12345678-0000-0000-0000-aabbccdd0001"
projects_base_dir = str(tmp_path)
sdk_cwd, session_file = self._build_session_file(tmp_path, session_id)
# Write raw invalid UTF-8 bytes
session_file.write_bytes(b"\xff\xfe invalid utf-8\n")
with (
patch(
"backend.copilot.sdk.service.projects_base",
return_value=projects_base_dir,
),
patch(
"backend.copilot.transcript.projects_base",
return_value=projects_base_dir,
),
):
result = read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
# UnicodeDecodeError path returns the raw bytes (upload-raw fallback)
assert result == b"\xff\xfe invalid utf-8\n"
def test_write_back_oserror_still_returns_stripped_bytes(self, tmp_path):
"""OSError on write-back returns stripped bytes for GCS upload (not raw)."""
from unittest.mock import patch
from backend.copilot.sdk.service import read_cli_session_from_disk
session_id = "12345678-0000-0000-0000-aabbccdd0002"
projects_base_dir = str(tmp_path)
sdk_cwd, session_file = self._build_session_file(tmp_path, session_id)
# Content with a strippable progress entry so stripped_bytes < raw_bytes
raw_content = (
'{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n'
'{"type":"user","uuid":"u1","parentUuid":null,"message":{"role":"user","content":"hi"}}\n'
'{"type":"assistant","uuid":"a1","parentUuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello"}]}}\n'
)
session_file.write_bytes(raw_content.encode("utf-8"))
# Make the file read-only so write_bytes raises OSError on the write-back
session_file.chmod(0o444)
try:
with (
patch(
"backend.copilot.sdk.service.projects_base",
return_value=projects_base_dir,
),
patch(
"backend.copilot.transcript.projects_base",
return_value=projects_base_dir,
),
):
result = read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
finally:
session_file.chmod(0o644)
# Must return stripped bytes (not raw, not None) so GCS gets the clean version
assert result is not None
assert (
b"progress" not in result
), "Stripped bytes must not contain progress entry"
assert b"hello" in result, "Stripped bytes should contain assistant turn"

View File

@@ -26,7 +26,7 @@ from backend.data.understanding import (
from backend.util.exceptions import NotAuthorizedError, NotFoundError
from backend.util.settings import AppEnvironment, Settings
from .config import ChatConfig
from .config import ChatConfig, CopilotLlmModel
from .model import (
ChatMessage,
ChatSessionInfo,
@@ -40,6 +40,21 @@ logger = logging.getLogger(__name__)
config = ChatConfig()
settings = Settings()
def resolve_chat_model(tier: CopilotLlmModel | None) -> str:
"""Return the configured OpenRouter model string for the given tier.
Shared by the baseline (fast) and SDK (extended thinking) paths so
both honor the same standard/advanced env-var configuration. ``None``
and ``'standard'`` fall through to ``config.model``; ``'advanced'``
uses ``config.advanced_model``. Keep this flat — if a third tier
shows up later, extend here and both paths pick it up for free.
"""
if tier == "advanced":
return config.advanced_model
return config.model
_client: LangfuseAsyncOpenAI | None = None
_langfuse = None
@@ -446,7 +461,9 @@ async def inject_user_context(
+ final_message
)
for session_msg in session_messages:
# Scan in reverse so we target the current turn's user message, not
# an older one that may exist when pending messages have been drained.
for session_msg in reversed(session_messages):
if session_msg.role == "user":
# Only touch the DB / in-memory state when the content actually
# needs to change — avoids an unnecessary write on the common

View File

@@ -61,18 +61,23 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
# (CLI version, platform). When that happens, multi-turn still works
# via conversation compression (non-resume path), but we can't test
# the --resume round-trip.
transcript = None
cli_session = None
for _ in range(10):
await asyncio.sleep(0.5)
transcript = await download_transcript(test_user_id, session.session_id)
if transcript:
cli_session = await download_transcript(test_user_id, session.session_id)
# Wait until both the session bytes AND the message_count watermark are
# present — a session with message_count=0 means the .meta.json hasn't
# been uploaded yet, so --resume on the next turn would skip gap-fill.
if cli_session and cli_session.message_count > 0:
break
if not transcript:
if not cli_session:
return pytest.skip(
"CLI did not produce a usable transcript — "
"cannot test --resume round-trip in this environment"
)
logger.info(f"Turn 1 transcript uploaded: {len(transcript.content)} bytes")
logger.info(
f"Turn 1 CLI session uploaded: {len(cli_session.content)} bytes, msg_count={cli_session.message_count}"
)
# Reload session for turn 2
session = await get_chat_session(session.session_id, test_user_id)

View File

@@ -0,0 +1,77 @@
"""Pre-turn cleanup of transient markers left on ``session.messages`` by
prior turns (user-initiated Stop, cancelled tool calls, etc.).
Shared by both the SDK and baseline chat entry points so both code paths
start every new turn from a well-formed message list.
"""
import logging
from backend.copilot.constants import STOPPED_BY_USER_MARKER
from backend.copilot.model import ChatMessage
logger = logging.getLogger(__name__)
def prune_orphan_tool_calls(
messages: list[ChatMessage],
log_prefix: str | None = None,
) -> int:
"""Pop trailing orphan tool-use blocks from *messages* in place.
A Stop mid-tool-call leaves the session ending on an assistant message
whose ``tool_calls`` have no matching ``role="tool"`` row — the tool
never produced output because the executor was cancelled. Feeding that
tail to the next ``--resume`` turn would hand the Claude CLI a
``tool_use`` with no paired ``tool_result`` and the SDK raises a
generic error.
Also strips trailing ``STOPPED_BY_USER_MARKER`` assistant rows emitted
by the same Stop path so the next turn's transcript starts clean.
If *log_prefix* is given, emits an INFO log with the prefix whenever
something was actually popped so the turn-start cleanup is visible.
In-memory only — the DB write path is append-only via
``start_sequence`` so no delete is needed; the same rows are popped
again on the next session load.
"""
cut_index: int | None = None
resolved_ids: set[str] = set()
for i in range(len(messages) - 1, -1, -1):
msg = messages[i]
if msg.role == "tool" and msg.tool_call_id:
resolved_ids.add(msg.tool_call_id)
continue
if msg.role == "assistant" and msg.content == STOPPED_BY_USER_MARKER:
cut_index = i
continue
if msg.role == "assistant" and msg.tool_calls:
pending_ids = {
tc.get("id")
for tc in msg.tool_calls
if isinstance(tc, dict) and tc.get("id")
}
if pending_ids and not pending_ids.issubset(resolved_ids):
cut_index = i
break
break
if cut_index is None:
return 0
removed = len(messages) - cut_index
del messages[cut_index:]
if log_prefix:
logger.info(
"%s Dropped %d trailing orphan tool-use/stop row(s) "
"before starting new turn",
log_prefix,
removed,
)
return removed

View File

@@ -17,7 +17,7 @@ Subscribers:
import asyncio
import logging
import time
from collections.abc import AsyncIterator
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal
@@ -32,9 +32,10 @@ from backend.data.notification_bus import (
NotificationEvent,
)
from backend.data.redis_client import get_redis_async
from backend.data.redis_helpers import hash_compare_and_set
from .config import ChatConfig
from .executor.utils import COPILOT_CONSUMER_TIMEOUT_SECONDS
from .executor.utils import COPILOT_CONSUMER_TIMEOUT_SECONDS, get_session_lock_key
from .response_model import (
ResponseType,
StreamBaseResponse,
@@ -42,6 +43,9 @@ from .response_model import (
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamReasoningDelta,
StreamReasoningEnd,
StreamReasoningStart,
StreamStart,
StreamStartStep,
StreamTextDelta,
@@ -68,17 +72,6 @@ _listener_sessions: dict[int, tuple[str, asyncio.Task]] = {}
# If the queue is full and doesn't drain within this time, send an overflow error
QUEUE_PUT_TIMEOUT = 5.0
# Lua script for atomic compare-and-swap status update (idempotent completion)
# Returns 1 if status was updated, 0 if already completed/failed
COMPLETE_SESSION_SCRIPT = """
local current = redis.call("HGET", KEYS[1], "status")
if current == "running" then
redis.call("HSET", KEYS[1], "status", ARGV[1])
return 1
end
return 0
"""
@dataclass
class ActiveSession:
@@ -336,8 +329,8 @@ async def publish_chunk(
async def stream_and_publish(
session_id: str,
turn_id: str,
stream: AsyncIterator[StreamBaseResponse],
) -> AsyncIterator[StreamBaseResponse]:
stream: AsyncGenerator[StreamBaseResponse, None],
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Wrap an async stream iterator with registry publishing.
Publishes each chunk to the stream registry for frontend SSE consumption,
@@ -360,27 +353,35 @@ async def stream_and_publish(
"""
publish_failed_once = False
async for event in stream:
if turn_id and not isinstance(event, (StreamFinish, StreamError)):
try:
await publish_chunk(turn_id, event, session_id=session_id)
except (RedisError, ConnectionError, OSError):
if not publish_failed_once:
publish_failed_once = True
logger.warning(
"[stream_and_publish] Failed to publish chunk %s for %s "
"(further failures logged at DEBUG)",
type(event).__name__,
session_id[:12],
exc_info=True,
)
else:
logger.debug(
"[stream_and_publish] Failed to publish chunk %s",
type(event).__name__,
exc_info=True,
)
yield event
# async-for does not close an iterator on GeneratorExit; forward close
# to ``stream`` explicitly so its own cleanup (stream lock, persist)
# runs deterministically instead of waiting for GC.
try:
async for event in stream:
if turn_id and not isinstance(event, (StreamFinish, StreamError)):
try:
await publish_chunk(turn_id, event, session_id=session_id)
except (RedisError, ConnectionError, OSError):
# Full stack trace on the first failure; terser lines
# for the rest so subsequent failures don't flood logs
# while still being visible at WARNING.
if not publish_failed_once:
publish_failed_once = True
logger.warning(
"[stream_and_publish] Failed to publish chunk %s for %s",
type(event).__name__,
session_id[:12],
exc_info=True,
)
else:
logger.warning(
"[stream_and_publish] Failed to publish chunk %s for %s",
type(event).__name__,
session_id[:12],
)
yield event
finally:
await stream.aclose()
async def subscribe_to_session(
@@ -423,20 +424,33 @@ async def subscribe_to_session(
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
)
# RACE CONDITION FIX: If session not found, retry once after small delay
# This handles the case where subscribe_to_session is called immediately
# after create_session but before Redis propagates the write
# RACE CONDITION FIX: If session not found, retry with backoff.
# Duplicate requests skip create_session and subscribe immediately; the
# original request's create_session (a Redis hset) may not have completed
# yet. 3 × 100ms gives a 300ms window which covers DB-write latency on the
# original request before the hset even starts.
if not meta:
logger.warning(
"[TIMING] Session not found on first attempt, retrying after 50ms delay",
extra={"json_fields": {**log_meta}},
)
await asyncio.sleep(0.05) # 50ms
meta = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
_max_retries = 3
_retry_delay = 0.1 # 100ms per attempt
for attempt in range(_max_retries):
logger.warning(
f"[TIMING] Session not found (attempt {attempt + 1}/{_max_retries}), "
f"retrying after {int(_retry_delay * 1000)}ms",
extra={"json_fields": {**log_meta, "attempt": attempt + 1}},
)
await asyncio.sleep(_retry_delay)
meta = await redis.hgetall(meta_key) # type: ignore[misc]
if meta:
logger.info(
f"[TIMING] Session found after {attempt + 1} retries",
extra={"json_fields": {**log_meta, "attempts": attempt + 1}},
)
break
else:
elapsed = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] Session still not found in Redis after retry ({elapsed:.1f}ms total)",
f"[TIMING] Session still not found in Redis after {_max_retries} retries "
f"({elapsed:.1f}ms total)",
extra={
"json_fields": {
**log_meta,
@@ -446,10 +460,6 @@ async def subscribe_to_session(
},
)
return None
logger.info(
"[TIMING] Session found after retry",
extra={"json_fields": {**log_meta}},
)
# Note: Redis client uses decode_responses=True, so keys are strings
session_status = meta.get("status", "")
@@ -830,15 +840,26 @@ async def mark_session_completed(
turn_id = _parse_session_meta(meta, session_id).turn_id if meta else session_id
# Atomic compare-and-swap: only update if status is "running"
result = await redis.eval(COMPLETE_SESSION_SCRIPT, 1, meta_key, status) # type: ignore[misc]
swapped = await hash_compare_and_set(
redis, meta_key, "status", expected="running", new=status
)
# Clean up the in-memory TTL refresh tracker to prevent unbounded growth.
_meta_ttl_refresh_at.pop(session_id, None)
if result == 0:
if not swapped:
logger.debug(f"Session {session_id} already completed/failed, skipping")
return False
# Force-release the executor's cluster lock so the next enqueued turn can
# acquire it immediately. The lock holder's on_run_done will also release
# (idempotent delete); doing it here unblocks cases where the task hangs
# past the cancel timeout or a pod crash leaves the lock orphaned.
try:
await redis.delete(get_session_lock_key(session_id))
except RedisError as e:
logger.warning(f"Failed to release cluster lock for session {session_id}: {e}")
if error_message and not skip_error_publish:
try:
await publish_chunk(turn_id, StreamError(errorText=error_message))
@@ -1061,6 +1082,9 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
ResponseType.TEXT_START.value: StreamTextStart,
ResponseType.TEXT_DELTA.value: StreamTextDelta,
ResponseType.TEXT_END.value: StreamTextEnd,
ResponseType.REASONING_START.value: StreamReasoningStart,
ResponseType.REASONING_DELTA.value: StreamReasoningDelta,
ResponseType.REASONING_END.value: StreamReasoningEnd,
ResponseType.TOOL_INPUT_START.value: StreamToolInputStart,
ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable,
ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable,

View File

@@ -4,8 +4,10 @@ import asyncio
from unittest.mock import AsyncMock, patch
import pytest
from redis.exceptions import RedisError
from backend.copilot import stream_registry
from backend.copilot.executor.utils import get_session_lock_key
@pytest.fixture(autouse=True)
@@ -108,3 +110,228 @@ async def test_disconnect_all_listeners_timeout_not_counted():
await task
except asyncio.CancelledError:
pass
# ---------------------------------------------------------------------------
# stream_and_publish: closing the wrapper forwards GeneratorExit into the
# inner stream so its finally (stream lock release, etc.) runs deterministically.
# ---------------------------------------------------------------------------
class _FakeEvent:
"""Minimal stand-in for a StreamBaseResponse so publish_chunk is a no-op."""
def __init__(self, idx: int):
self.idx = idx
@pytest.mark.asyncio
async def test_stream_and_publish_aclose_propagates_to_inner_stream():
"""Closing the wrapper MUST run the inner generator's finally block."""
inner_finally_ran = asyncio.Event()
async def _inner():
try:
yield _FakeEvent(0)
yield _FakeEvent(1)
yield _FakeEvent(2)
finally:
inner_finally_ran.set()
inner = _inner()
# Empty turn_id skips publish_chunk — keeps the test hermetic (no Redis).
wrapper = stream_registry.stream_and_publish(
session_id="sess-test", turn_id="", stream=inner
)
# Consume one event, then close the wrapper early.
first = await wrapper.__anext__()
assert isinstance(first, _FakeEvent)
await wrapper.aclose()
# The inner generator's finally must have run deterministically
# (not deferred to GC) so the caller's cleanup (lock release, etc.)
# is observable right after aclose returns.
assert inner_finally_ran.is_set()
@pytest.mark.asyncio
async def test_stream_and_publish_logs_warning_on_publish_chunk_failure():
"""``stream_and_publish`` must not propagate a Redis publish failure —
it warns once with full stack trace, keeps yielding, and logs
subsequent failures at WARNING (terser, no exc_info) so repeated
errors stay visible without flooding the trace."""
from redis.exceptions import RedisError
async def _inner():
yield _FakeEvent(0)
yield _FakeEvent(1)
yield _FakeEvent(2)
async def _raising_publish(turn_id, event, session_id=None):
raise RedisError("boom")
warning_mock = patch.object(
stream_registry.logger, "warning", autospec=True
).start()
try:
with patch.object(stream_registry, "publish_chunk", new=_raising_publish):
wrapper = stream_registry.stream_and_publish(
session_id="sess-test", turn_id="turn-1", stream=_inner()
)
received = [evt async for evt in wrapper]
finally:
patch.stopall()
# Every event still yields through — publish failures don't break the stream.
assert len(received) == 3
# One warning per failed publish (3 total). First call carries a
# stack trace (``exc_info=True``); subsequent calls are terser.
assert warning_mock.call_count == 3
assert warning_mock.call_args_list[0].kwargs.get("exc_info") is True
assert warning_mock.call_args_list[1].kwargs.get("exc_info") is not True
@pytest.mark.asyncio
async def test_stream_and_publish_consumer_break_then_aclose_releases_inner():
"""The processor pattern — break on cancel, then aclose — must release."""
inner_finally_ran = asyncio.Event()
async def _inner():
try:
for idx in range(100):
yield _FakeEvent(idx)
finally:
inner_finally_ran.set()
inner = _inner()
wrapper = stream_registry.stream_and_publish(
session_id="sess-test", turn_id="", stream=inner
)
# Mimic the processor: consume a few events, simulate Stop by breaking,
# then aclose the wrapper (as processor._execute_async now does in the
# try/finally around the async for).
try:
count = 0
async for _ in wrapper:
count += 1
if count >= 2:
break
finally:
await wrapper.aclose()
assert inner_finally_ran.is_set()
# ---------------------------------------------------------------------------
# mark_session_completed: the atomic meta flip to completed/failed must also
# release the per-session cluster lock, so the next enqueued turn's run
# handler can acquire it without waiting for the TTL (5 min default).
# ---------------------------------------------------------------------------
class _FakeRedis:
"""Minimal async-Redis fake: only the calls mark_session_completed makes."""
def __init__(self, meta: dict[str, str]):
self._meta = dict(meta)
self.deleted_keys: list[str] = []
self.delete = AsyncMock(side_effect=self._record_delete)
async def _record_delete(self, *keys: str):
self.deleted_keys.extend(keys)
for k in keys:
self._meta.pop(k, None)
return len(keys)
async def hgetall(self, _key: str):
return dict(self._meta)
@pytest.mark.asyncio
async def test_mark_session_completed_releases_cluster_lock_on_success():
"""CAS swap must be followed by a DELETE on the session's lock key so a
stuck-because-of-stale-lock session becomes immediately claimable."""
fake_redis = _FakeRedis({"status": "running", "turn_id": "turn-1"})
with (
patch.object(
stream_registry, "get_redis_async", new=AsyncMock(return_value=fake_redis)
),
patch.object(
stream_registry, "hash_compare_and_set", new=AsyncMock(return_value=True)
),
patch.object(stream_registry, "publish_chunk", new=AsyncMock()),
patch.object(
stream_registry.chat_db(),
"set_turn_duration",
new=AsyncMock(),
create=True,
),
):
result = await stream_registry.mark_session_completed("sess-1")
assert result is True
assert get_session_lock_key("sess-1") in fake_redis.deleted_keys
@pytest.mark.asyncio
async def test_mark_session_completed_skips_lock_release_when_already_completed():
"""CAS failure = someone else completed the session first; we must not
delete their already-released lock, and we must NOT publish StreamFinish
twice (the winning caller already published it)."""
fake_redis = _FakeRedis({"status": "completed", "turn_id": "turn-1"})
publish_mock = AsyncMock()
with (
patch.object(
stream_registry, "get_redis_async", new=AsyncMock(return_value=fake_redis)
),
patch.object(
stream_registry, "hash_compare_and_set", new=AsyncMock(return_value=False)
),
patch.object(stream_registry, "publish_chunk", new=publish_mock),
):
result = await stream_registry.mark_session_completed("sess-1")
assert result is False
assert get_session_lock_key("sess-1") not in fake_redis.deleted_keys
assert not any(
isinstance(call.args[1], stream_registry.StreamFinish)
for call in publish_mock.call_args_list
), "StreamFinish must NOT be re-published on the CAS-no-op branch"
@pytest.mark.asyncio
async def test_mark_session_completed_survives_lock_release_redis_error():
"""A Redis hiccup during lock DELETE must not prevent the StreamFinish
publish — the client's SSE stream would otherwise hang on the stale meta
status while Redis recovers."""
fake_redis = _FakeRedis({"status": "running", "turn_id": "turn-1"})
fake_redis.delete = AsyncMock(side_effect=RedisError("boom"))
publish_mock = AsyncMock()
with (
patch.object(
stream_registry, "get_redis_async", new=AsyncMock(return_value=fake_redis)
),
patch.object(
stream_registry, "hash_compare_and_set", new=AsyncMock(return_value=True)
),
patch.object(stream_registry, "publish_chunk", new=publish_mock),
patch.object(
stream_registry.chat_db(),
"set_turn_duration",
new=AsyncMock(),
create=True,
),
):
result = await stream_registry.mark_session_completed("sess-1")
assert result is True
assert any(
isinstance(call.args[1], stream_registry.StreamFinish)
for call in publish_mock.call_args_list
), "StreamFinish must still be published even if lock DELETE raises"

View File

@@ -1,9 +1,9 @@
"""Shared token-usage persistence and rate-limit recording.
"""Shared usage persistence and rate-limit recording.
Both the baseline (OpenRouter) and SDK (Anthropic) service layers need to:
1. Append a ``Usage`` record to the session.
2. Log the turn's token counts.
3. Record weighted usage in Redis for rate-limiting.
2. Log the turn's token counts and cost.
3. Record the real generation cost in Redis for rate-limiting.
4. Write a PlatformCostLog entry for admin cost tracking.
This module extracts that common logic so both paths stay in sync.
@@ -19,7 +19,7 @@ from backend.data.db_accessors import platform_cost_db
from backend.data.platform_cost import PlatformCostEntry, usd_to_microdollars
from .model import ChatSession, Usage
from .rate_limit import record_token_usage
from .rate_limit import record_cost_usage
logger = logging.getLogger(__name__)
@@ -96,9 +96,14 @@ async def persist_and_record_usage(
cost_usd: float | str | None = None,
model: str | None = None,
provider: str = "open_router",
model_cost_multiplier: float = 1.0,
) -> int:
"""Persist token usage to session and record for rate limiting.
"""Persist token usage to session and record generation cost for rate limiting.
Rate-limit counters are charged in microdollars against the provider's
reported cost (``cost_usd``), so cache discounts and cross-model pricing
differences are already reflected. When cost is unknown the turn is
logged but the rate-limit counter is left alone — the caller logs an
error at the point the absence is detected.
Args:
session: The chat session to append usage to (may be None on error).
@@ -108,11 +113,11 @@ async def persist_and_record_usage(
cache_read_tokens: Tokens served from prompt cache (Anthropic only).
cache_creation_tokens: Tokens written to prompt cache (Anthropic only).
log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]").
cost_usd: Optional cost for logging (float from SDK, str otherwise).
cost_usd: Real generation cost for the turn (float from SDK or parsed
from OpenRouter usage.cost). ``None`` means the provider did not
report a cost and rate limiting is skipped for this turn.
model: Model identifier for cost log attribution.
provider: Cost provider name (e.g. "anthropic", "open_router").
model_cost_multiplier: Relative model cost factor for rate limiting
(1.0 = Sonnet/default, 5.0 = Opus). Scales the token counter so
more expensive models deplete the rate limit proportionally faster.
Returns:
The computed total_tokens (prompt + completion; cache excluded).
@@ -156,37 +161,51 @@ async def persist_and_record_usage(
else:
logger.info(
f"{log_prefix} Turn usage: prompt={prompt_tokens}, completion={completion_tokens},"
f" total={total_tokens}"
f" total={total_tokens}, cost_usd={cost_usd}"
)
if user_id:
cost_float: float | None = None
if cost_usd is not None:
try:
await record_token_usage(
user_id=user_id,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cache_read_tokens=cache_read_tokens,
cache_creation_tokens=cache_creation_tokens,
model_cost_multiplier=model_cost_multiplier,
val = float(cost_usd)
except (ValueError, TypeError):
logger.error(
"%s cost_usd is not numeric: %r — rate limit skipped",
log_prefix,
cost_usd,
)
except Exception as usage_err:
logger.warning("%s Failed to record token usage: %s", log_prefix, usage_err)
else:
if not math.isfinite(val):
logger.error(
"%s cost_usd is non-finite: %r — rate limit skipped",
log_prefix,
val,
)
elif val < 0:
logger.warning(
"%s cost_usd %s is negative — skipping rate-limit + cost log",
log_prefix,
val,
)
else:
cost_float = val
cost_microdollars = usd_to_microdollars(cost_float)
if user_id and cost_microdollars is not None and cost_microdollars > 0:
# record_cost_usage() owns its fail-open handling for Redis/network
# errors. Don't wrap with a broad except here — unexpected accounting
# bugs should surface instead of being silently logged as warnings.
await record_cost_usage(
user_id=user_id,
cost_microdollars=cost_microdollars,
)
# Log to PlatformCostLog for admin cost dashboard.
# Include entries where cost_usd is set even if token count is 0
# (e.g. fully-cached Anthropic responses where only cache tokens
# accumulate a charge without incrementing total_tokens).
if user_id and (total_tokens > 0 or cost_usd is not None):
cost_float = None
if cost_usd is not None:
try:
val = float(cost_usd)
if math.isfinite(val) and val >= 0:
cost_float = val
except (ValueError, TypeError):
pass
cost_microdollars = usd_to_microdollars(cost_float)
if user_id and (total_tokens > 0 or cost_float is not None):
session_id = session.session_id if session else None
if cost_float is not None:

View File

@@ -37,7 +37,7 @@ class TestTotalTokens:
async def test_returns_prompt_plus_completion(self):
"""total_tokens = prompt + completion (cache excluded from total)."""
with patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new_callable=AsyncMock,
):
total = await persist_and_record_usage(
@@ -63,7 +63,7 @@ class TestTotalTokens:
async def test_cache_tokens_excluded_from_total(self):
"""Cache tokens are stored separately and not added to total_tokens."""
with patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new_callable=AsyncMock,
):
total = await persist_and_record_usage(
@@ -81,7 +81,7 @@ class TestTotalTokens:
async def test_baseline_path_no_cache(self):
"""Baseline (OpenRouter) path passes no cache tokens; total = prompt + completion."""
with patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new_callable=AsyncMock,
):
total = await persist_and_record_usage(
@@ -97,7 +97,7 @@ class TestTotalTokens:
async def test_sdk_path_with_cache(self):
"""SDK (Anthropic) path passes cache tokens; total still = prompt + completion."""
with patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new_callable=AsyncMock,
):
total = await persist_and_record_usage(
@@ -123,7 +123,7 @@ class TestSessionPersistence:
async def test_appends_usage_to_session(self):
session = _make_session()
with patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new_callable=AsyncMock,
):
await persist_and_record_usage(
@@ -144,7 +144,7 @@ class TestSessionPersistence:
async def test_appends_cache_breakdown_to_session(self):
session = _make_session()
with patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new_callable=AsyncMock,
):
await persist_and_record_usage(
@@ -163,7 +163,7 @@ class TestSessionPersistence:
async def test_multiple_turns_append_multiple_records(self):
session = _make_session()
with patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new_callable=AsyncMock,
):
await persist_and_record_usage(
@@ -178,7 +178,7 @@ class TestSessionPersistence:
async def test_none_session_does_not_raise(self):
"""When session is None (e.g. error path), no exception should be raised."""
with patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new_callable=AsyncMock,
):
total = await persist_and_record_usage(
@@ -210,10 +210,11 @@ class TestSessionPersistence:
class TestRateLimitRecording:
@pytest.mark.asyncio
async def test_calls_record_token_usage_when_user_id_present(self):
async def test_calls_record_cost_usage_when_cost_and_user_id_present(self):
"""Rate-limit counter is charged with the real provider cost (microdollars)."""
mock_record = AsyncMock()
with patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new=mock_record,
):
await persist_and_record_usage(
@@ -223,22 +224,35 @@ class TestRateLimitRecording:
completion_tokens=50,
cache_read_tokens=1000,
cache_creation_tokens=200,
cost_usd=0.0123,
)
mock_record.assert_awaited_once_with(
user_id="user-abc",
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=1000,
cache_creation_tokens=200,
model_cost_multiplier=1.0,
cost_microdollars=12_300,
)
@pytest.mark.asyncio
async def test_skips_record_when_cost_is_missing(self):
"""Without a provider cost we have no authoritative figure to charge."""
mock_record = AsyncMock()
with patch(
"backend.copilot.token_tracking.record_cost_usage",
new=mock_record,
):
await persist_and_record_usage(
session=None,
user_id="user-abc",
prompt_tokens=100,
completion_tokens=50,
)
mock_record.assert_not_awaited()
@pytest.mark.asyncio
async def test_skips_record_when_user_id_is_none(self):
"""Anonymous sessions should not create Redis keys."""
mock_record = AsyncMock()
with patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new=mock_record,
):
await persist_and_record_usage(
@@ -246,32 +260,38 @@ class TestRateLimitRecording:
user_id=None,
prompt_tokens=100,
completion_tokens=50,
cost_usd=0.001,
)
mock_record.assert_not_awaited()
@pytest.mark.asyncio
async def test_record_failure_does_not_raise(self):
"""A Redis error in record_token_usage should be swallowed (fail-open)."""
mock_record = AsyncMock(side_effect=ConnectionError("Redis down"))
async def test_record_usage_bubbles_unexpected_error(self):
"""Unexpected errors from record_cost_usage must propagate.
record_cost_usage() owns its own (RedisError, ConnectionError, OSError)
fail-open handling. Anything else is a real accounting bug and
should not be silently swallowed at this layer.
"""
mock_record = AsyncMock(side_effect=RuntimeError("boom"))
with patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new=mock_record,
):
# Should not raise
total = await persist_and_record_usage(
session=None,
user_id="user-xyz",
prompt_tokens=100,
completion_tokens=50,
)
assert total == 150
with pytest.raises(RuntimeError, match="boom"):
await persist_and_record_usage(
session=None,
user_id="user-xyz",
prompt_tokens=100,
completion_tokens=50,
cost_usd=0.002,
)
@pytest.mark.asyncio
async def test_skips_record_when_zero_tokens(self):
"""Returns 0 before calling record_token_usage when tokens are zero."""
async def test_skips_record_when_zero_tokens_and_no_cost(self):
"""Returns 0 before calling record_cost_usage when there is nothing to record."""
mock_record = AsyncMock()
with patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new=mock_record,
):
await persist_and_record_usage(
@@ -295,7 +315,7 @@ class TestPlatformCostLogging:
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new_callable=AsyncMock,
),
patch(
@@ -336,7 +356,7 @@ class TestPlatformCostLogging:
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new_callable=AsyncMock,
),
patch(
@@ -369,7 +389,7 @@ class TestPlatformCostLogging:
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new_callable=AsyncMock,
),
patch(
@@ -394,7 +414,7 @@ class TestPlatformCostLogging:
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new_callable=AsyncMock,
),
patch(
@@ -423,7 +443,7 @@ class TestPlatformCostLogging:
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new_callable=AsyncMock,
),
patch(
@@ -452,7 +472,7 @@ class TestPlatformCostLogging:
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new_callable=AsyncMock,
),
patch(
@@ -479,7 +499,7 @@ class TestPlatformCostLogging:
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new_callable=AsyncMock,
),
patch(
@@ -509,7 +529,7 @@ class TestPlatformCostLogging:
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new_callable=AsyncMock,
),
patch(
@@ -545,7 +565,7 @@ class TestPlatformCostLogging:
mock_log = AsyncMock()
with (
patch(
"backend.copilot.token_tracking.record_token_usage",
"backend.copilot.token_tracking.record_cost_usage",
new_callable=AsyncMock,
),
patch(

View File

@@ -26,6 +26,7 @@ from .fix_agent import FixAgentGraphTool
from .get_agent_building_guide import GetAgentBuildingGuideTool
from .get_doc_page import GetDocPageTool
from .get_mcp_guide import GetMCPGuideTool
from .get_sub_session_result import GetSubSessionResultTool
from .graphiti_forget import MemoryForgetConfirmTool, MemoryForgetSearchTool
from .graphiti_search import MemorySearchTool
from .graphiti_store import MemoryStoreTool
@@ -40,6 +41,7 @@ from .manage_folders import (
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .run_mcp_tool import RunMCPToolTool
from .run_sub_session import RunSubSessionTool
from .search_docs import SearchDocsTool
from .validate_agent import ValidateAgentGraphTool
from .web_fetch import WebFetchTool
@@ -81,6 +83,8 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"continue_run_block": ContinueRunBlockTool(),
"run_sub_session": RunSubSessionTool(),
"get_sub_session_result": GetSubSessionResultTool(),
"run_mcp_tool": RunMCPToolTool(),
"get_mcp_guide": GetMCPGuideTool(),
"view_agent_output": AgentOutputTool(),

View File

@@ -12,7 +12,7 @@ from backend.api.features.store import db as store_db
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import AITextGeneratorBlock
from backend.copilot.model import ChatSession
from backend.copilot.model import ChatMessage, ChatSession
from backend.data import db as db_module
from backend.data.db import prisma
from backend.data.graph import Graph, Link, Node, create_graph
@@ -42,11 +42,28 @@ async def _ensure_db_connected() -> None:
await db_module.connect()
def make_session(user_id: str):
def make_session(user_id: str, *, guide_read: bool = True):
"""Build a fake ChatSession for tool tests.
``guide_read=True`` (default) pre-populates the session with a
``get_agent_building_guide`` tool-call history entry so the agent-
generation gate (see ``helpers.require_guide_read``) lets through any
subsequent ``create_agent`` / ``edit_agent`` / ``validate_agent_graph``
/ ``fix_agent_graph`` call.
"""
messages: list[ChatMessage] = []
if guide_read:
messages.append(
ChatMessage(
role="assistant",
content="",
tool_calls=[{"function": {"name": "get_agent_building_guide"}}],
)
)
return ChatSession(
session_id=str(uuid.uuid4()),
user_id=user_id,
messages=[],
messages=messages,
usage=[],
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),

View File

@@ -1325,7 +1325,7 @@ class AgentFixer:
"""
if not library_agents:
logger.debug(
"fix_agent_executor_blocks: No library_agents provided, " "skipping"
"fix_agent_executor_blocks: No library_agents provided, skipping"
)
return agent
@@ -1390,7 +1390,7 @@ class AgentFixer:
if "user_id" not in input_default:
input_default["user_id"] = ""
self.add_fix_log(
f"Fixed AgentExecutorBlock {node_id}: Added missing " f"user_id"
f"Fixed AgentExecutorBlock {node_id}: Added missing user_id"
)
# Ensure inputs is present
@@ -1689,8 +1689,7 @@ class AgentFixer:
if field not in input_default or input_default[field] is None:
input_default[field] = default_value
self.add_fix_log(
f"OrchestratorBlock {node_id}: "
f"Set {field}={default_value!r}"
f"OrchestratorBlock {node_id}: Set {field}={default_value!r}"
)
return agent

View File

@@ -0,0 +1,119 @@
"""Tests for the ``require_guide_read`` gate on agent-generation tools.
The agent-building guide carries block ids, link semantics, and
AgentExecutorBlock / MCPToolBlock conventions that the agent needs before
producing agent JSON. Without the gate, agents often skip the guide to save
tokens and then produce JSON that fails validation — wasting turns on
auto-fix loops.
"""
from unittest.mock import MagicMock
import pytest
from backend.copilot.model import ChatMessage, ChatSession
from .helpers import require_guide_read
from .models import ErrorResponse
def _session_with_messages(messages: list[ChatMessage]) -> ChatSession:
"""Build a minimal ChatSession whose ``messages`` matches *messages*."""
session = MagicMock(spec=ChatSession)
session.session_id = "test-session"
session.messages = messages
return session
def test_no_messages_gate_fires():
session = _session_with_messages([])
result = require_guide_read(session, "create_agent")
assert isinstance(result, ErrorResponse)
assert "get_agent_building_guide" in result.message
assert "create_agent" in result.message
def test_user_message_only_gate_fires():
session = _session_with_messages(
[ChatMessage(role="user", content="build an agent")]
)
assert isinstance(require_guide_read(session, "create_agent"), ErrorResponse)
def test_assistant_without_tool_calls_gate_fires():
session = _session_with_messages(
[ChatMessage(role="assistant", content="sure!", tool_calls=None)]
)
assert isinstance(require_guide_read(session, "create_agent"), ErrorResponse)
def test_unrelated_tool_call_gate_fires():
session = _session_with_messages(
[
ChatMessage(
role="assistant",
content="",
tool_calls=[{"function": {"name": "find_block"}}],
)
]
)
assert isinstance(require_guide_read(session, "create_agent"), ErrorResponse)
def test_guide_called_via_openai_shape_gate_passes():
"""OpenAI/Anthropic wrap names under 'function': {'name': ...}."""
session = _session_with_messages(
[
ChatMessage(
role="assistant",
content="",
tool_calls=[
{"function": {"name": "get_agent_building_guide"}},
],
)
]
)
assert require_guide_read(session, "create_agent") is None
def test_guide_called_via_flat_shape_gate_passes():
"""Some callers log tool calls with a flat {'name': ...} shape."""
session = _session_with_messages(
[
ChatMessage(
role="assistant",
content="",
tool_calls=[{"name": "get_agent_building_guide"}],
)
]
)
assert require_guide_read(session, "create_agent") is None
def test_guide_earlier_in_history_still_passes():
"""A guide call earlier in the session keeps the gate open for subsequent
create/edit/validate/fix calls — the agent doesn't need to re-read it."""
session = _session_with_messages(
[
ChatMessage(role="user", content="build X"),
ChatMessage(
role="assistant",
content="",
tool_calls=[{"function": {"name": "get_agent_building_guide"}}],
),
ChatMessage(role="user", content="also Y"),
ChatMessage(role="assistant", content="working on it"),
]
)
assert require_guide_read(session, "edit_agent") is None
@pytest.mark.parametrize(
"tool_name",
["create_agent", "edit_agent", "validate_agent_graph", "fix_agent_graph"],
)
def test_tool_name_surfaced_in_error(tool_name: str):
session = _session_with_messages([])
result = require_guide_read(session, tool_name)
assert isinstance(result, ErrorResponse)
assert tool_name in result.message

View File

@@ -8,6 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field, field_validator
from backend.api.features.library.model import LibraryAgent
from backend.copilot.constants import MAX_TOOL_WAIT_SECONDS
from backend.copilot.model import ChatSession
from backend.data.db_accessors import execution_db, library_db
from backend.data.execution import (
@@ -39,7 +40,7 @@ class AgentOutputInput(BaseModel):
store_slug: str = ""
execution_id: str = ""
run_time: str = "latest"
wait_if_running: int = Field(default=0, ge=0, le=300)
wait_if_running: int = Field(default=0, ge=0, le=MAX_TOOL_WAIT_SECONDS)
show_execution_details: bool = False
@field_validator(
@@ -148,9 +149,13 @@ class AgentOutputTool(BaseTool):
},
"wait_if_running": {
"type": "integer",
"description": "Max seconds to wait if still running (0-300). Returns current state on timeout.",
"description": (
"Max seconds to wait if still running "
f"(0-{MAX_TOOL_WAIT_SECONDS}). "
"Returns current state on timeout."
),
"minimum": 0,
"maximum": 300,
"maximum": MAX_TOOL_WAIT_SECONDS,
},
"show_execution_details": {
"type": "boolean",

View File

@@ -47,7 +47,7 @@ class BashExecTool(BaseTool):
return (
"Execute a Bash command or script. Shares filesystem with SDK file tools. "
"Useful for scripts, data processing, and package installation. "
"Killed after timeout (default 30s, max 120s)."
"Killed after `timeout` seconds."
)
@property
@@ -61,8 +61,8 @@ class BashExecTool(BaseTool):
},
"timeout": {
"type": "integer",
"description": "Max seconds (default 30, max 120).",
"default": 30,
"description": "Timeout in seconds; raise for long-running commands.",
"default": 120,
},
},
"required": ["command"],
@@ -80,7 +80,7 @@ class BashExecTool(BaseTool):
user_id: str | None,
session: ChatSession,
command: str = "",
timeout: int = 30,
timeout: int = 120,
**kwargs: Any,
) -> ToolResponseBase:
"""Run a bash command on E2B (if available) or in a bubblewrap sandbox.
@@ -129,7 +129,7 @@ class BashExecTool(BaseTool):
message=(
"Execution timed out"
if timed_out
else f"Command executed (exit {exit_code})"
else f"Command executed with status code {exit_code}"
),
stdout=stdout,
stderr=stderr,
@@ -183,7 +183,7 @@ class BashExecTool(BaseTool):
stdout = stdout.replace(secret, "[REDACTED]")
stderr = stderr.replace(secret, "[REDACTED]")
return BashExecResponse(
message=f"Command executed on E2B (exit {result.exit_code})",
message=f"Command executed with status code {result.exit_code}",
stdout=stdout,
stderr=stderr,
exit_code=result.exit_code,

View File

@@ -35,12 +35,15 @@ class TestBashExecE2BTokenInjection:
sandbox = _make_sandbox(stdout="ok")
env_vars = {"GH_TOKEN": "gh-secret", "GITHUB_TOKEN": "gh-secret"}
with patch(
"backend.copilot.tools.bash_exec.get_integration_env_vars",
new=AsyncMock(return_value=env_vars),
) as mock_get_env, patch(
"backend.copilot.tools.bash_exec.get_github_user_git_identity",
new=AsyncMock(return_value=None),
with (
patch(
"backend.copilot.tools.bash_exec.get_integration_env_vars",
new=AsyncMock(return_value=env_vars),
) as mock_get_env,
patch(
"backend.copilot.tools.bash_exec.get_github_user_git_identity",
new=AsyncMock(return_value=None),
),
):
result = await tool._execute_on_e2b(
sandbox=sandbox,
@@ -69,12 +72,15 @@ class TestBashExecE2BTokenInjection:
"GIT_COMMITTER_EMAIL": "test@example.com",
}
with patch(
"backend.copilot.tools.bash_exec.get_integration_env_vars",
new=AsyncMock(return_value={}),
), patch(
"backend.copilot.tools.bash_exec.get_github_user_git_identity",
new=AsyncMock(return_value=identity),
with (
patch(
"backend.copilot.tools.bash_exec.get_integration_env_vars",
new=AsyncMock(return_value={}),
),
patch(
"backend.copilot.tools.bash_exec.get_github_user_git_identity",
new=AsyncMock(return_value=identity),
),
):
await tool._execute_on_e2b(
sandbox=sandbox,
@@ -97,12 +103,15 @@ class TestBashExecE2BTokenInjection:
session = make_session(user_id=_USER)
sandbox = _make_sandbox(stdout="ok")
with patch(
"backend.copilot.tools.bash_exec.get_integration_env_vars",
new=AsyncMock(return_value={}),
), patch(
"backend.copilot.tools.bash_exec.get_github_user_git_identity",
new=AsyncMock(return_value=None),
with (
patch(
"backend.copilot.tools.bash_exec.get_integration_env_vars",
new=AsyncMock(return_value={}),
),
patch(
"backend.copilot.tools.bash_exec.get_github_user_git_identity",
new=AsyncMock(return_value=None),
),
):
await tool._execute_on_e2b(
sandbox=sandbox,
@@ -123,13 +132,16 @@ class TestBashExecE2BTokenInjection:
session = make_session(user_id=_USER)
sandbox = _make_sandbox(stdout="ok")
with patch(
"backend.copilot.tools.bash_exec.get_integration_env_vars",
new=AsyncMock(return_value={"GH_TOKEN": "should-not-appear"}),
) as mock_get_env, patch(
"backend.copilot.tools.bash_exec.get_github_user_git_identity",
new=AsyncMock(return_value=None),
) as mock_get_identity:
with (
patch(
"backend.copilot.tools.bash_exec.get_integration_env_vars",
new=AsyncMock(return_value={"GH_TOKEN": "should-not-appear"}),
) as mock_get_env,
patch(
"backend.copilot.tools.bash_exec.get_github_user_git_identity",
new=AsyncMock(return_value=None),
) as mock_get_identity,
):
result = await tool._execute_on_e2b(
sandbox=sandbox,
command="echo hi",

View File

@@ -8,6 +8,7 @@ from backend.copilot.model import ChatSession
from .agent_generator.pipeline import fetch_library_agents, fix_validate_and_save
from .base import BaseTool
from .helpers import require_guide_read
from .models import ErrorResponse, ToolResponseBase
logger = logging.getLogger(__name__)
@@ -23,8 +24,9 @@ class CreateAgentTool(BaseTool):
@property
def description(self) -> str:
return (
"Create a new agent from JSON (nodes + links). Validates, auto-fixes, and saves. "
"If you haven't already, call get_agent_building_guide first."
"Create a new agent from JSON (nodes + links). Validates, "
"auto-fixes, and saves. "
"Requires get_agent_building_guide first (refuses otherwise)."
)
@property
@@ -70,6 +72,10 @@ class CreateAgentTool(BaseTool):
) -> ToolResponseBase:
session_id = session.session_id if session else None
guide_gate = require_guide_read(session, "create_agent")
if guide_gate is not None:
return guide_gate
if not agent_json:
return ErrorResponse(
message=(

View File

@@ -8,6 +8,7 @@ from backend.copilot.model import ChatSession
from .agent_generator import get_agent_as_json
from .agent_generator.pipeline import fetch_library_agents, fix_validate_and_save
from .base import BaseTool
from .helpers import require_guide_read
from .models import ErrorResponse, ToolResponseBase
logger = logging.getLogger(__name__)
@@ -24,7 +25,7 @@ class EditAgentTool(BaseTool):
def description(self) -> str:
return (
"Edit an existing agent. Validates, auto-fixes, and saves. "
"If you haven't already, call get_agent_building_guide first."
"Requires get_agent_building_guide first (refuses otherwise)."
)
@property
@@ -73,6 +74,10 @@ class EditAgentTool(BaseTool):
library_agent_ids = []
session_id = session.session_id if session else None
guide_gate = require_guide_read(session, "edit_agent")
if guide_gate is not None:
return guide_gate
if not agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",

View File

@@ -42,6 +42,10 @@ COPILOT_EXCLUDED_BLOCK_IDS = {
# OrchestratorBlock - dynamically discovers downstream blocks via graph topology;
# usable in agent graphs (guide hardcodes its ID) but cannot run standalone.
"3b191d9f-356f-482d-8238-ba04b6d18381",
# AutoPilotBlock - has dedicated run_sub_session tool with async start +
# poll lifecycle. Calling it via run_block would block the parent stream
# for the sub-AutoPilot's entire runtime (15-45+ min typical).
"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6",
}

View File

@@ -7,6 +7,7 @@ from backend.copilot.model import ChatSession
from .agent_generator.validation import AgentFixer, AgentValidator, get_blocks_as_dicts
from .base import BaseTool
from .helpers import require_guide_read
from .models import ErrorResponse, FixResultResponse, ToolResponseBase
logger = logging.getLogger(__name__)
@@ -25,7 +26,8 @@ class FixAgentGraphTool(BaseTool):
"Auto-fix common agent JSON issues: missing/invalid UUIDs, StoreValueBlock prerequisites, "
"double curly brace escaping, AddToList/AddToDictionary prerequisites, credentials, "
"node spacing, AI model defaults, link static properties, and type mismatches. "
"Returns fixed JSON and list of fixes applied."
"Returns fixed JSON and list of fixes applied. "
"Requires get_agent_building_guide first (refuses otherwise)."
)
@property
@@ -56,6 +58,10 @@ class FixAgentGraphTool(BaseTool):
) -> ToolResponseBase:
session_id = session.session_id if session else None
guide_gate = require_guide_read(session, "fix_agent_graph")
if guide_gate is not None:
return guide_gate
if not agent_json or not isinstance(agent_json, dict):
return ErrorResponse(
message="Please provide a valid agent JSON object.",

View File

@@ -43,8 +43,10 @@ class GetAgentBuildingGuideTool(BaseTool):
@property
def description(self) -> str:
return (
"Get the agent JSON building guide (nodes, links, AgentExecutorBlock, MCPToolBlock usage, "
"and the create->dry-run->fix iterative workflow). Call before generating agent JSON."
"Agent JSON building guide (nodes, links, AgentExecutorBlock, "
"MCPToolBlock, iterative create->dry-run->fix flow). REQUIRED "
"before create_agent / edit_agent / validate_agent_graph / "
"fix_agent_graph — they refuse until called once per session."
)
@property

View File

@@ -0,0 +1,305 @@
"""Poll / wait on / cancel a sub-AutoPilot started by ``run_sub_session``.
Companion to :mod:`run_sub_session`. Operates on the sub's
``ChatSession`` directly — there is no separate registry. Ownership is
re-verified on every call by loading the ChatSession and comparing its
``user_id`` against the authenticated caller.
* **Wait** — subscribe to ``stream_registry`` for the session and drain
until ``StreamFinish`` / ``StreamError`` (terminal) or the per-call
cap fires. On terminal, the aggregated :class:`SessionResult` comes
back in memory — no DB round-trip for the response content.
* **Just check** — ``wait_if_running=0`` skips the subscription. If the
sub's last assistant message already looks terminal, returns
``completed`` with that content.
* **Cancel** — fan out a ``CancelCoPilotEvent`` on the shared cancel
exchange. Whichever worker is running the sub breaks out of its
stream and finalises the session as ``failed``.
"""
import json
import logging
import time
from typing import Any
from backend.copilot import stream_registry
from backend.copilot.executor.utils import enqueue_cancel_task
from backend.copilot.model import ChatSession, get_chat_session
from backend.copilot.sdk.session_waiter import (
SessionOutcome,
SessionResult,
wait_for_session_result,
)
from backend.copilot.sdk.stream_accumulator import ToolCallEntry
from .base import BaseTool
from .models import (
ErrorResponse,
SubSessionProgressSnapshot,
SubSessionStatusResponse,
ToolResponseBase,
)
from .run_sub_session import (
MAX_SUB_SESSION_WAIT_SECONDS,
_sub_session_link,
response_from_outcome,
)
logger = logging.getLogger(__name__)
# Cap on how many recent messages we echo back in a progress snapshot.
_PROGRESS_MESSAGE_LIMIT = 5
_PROGRESS_CONTENT_PREVIEW_CHARS = 400
class GetSubSessionResultTool(BaseTool):
"""Wait for, inspect, or cancel a sub-AutoPilot."""
@property
def name(self) -> str:
return "get_sub_session_result"
@property
def requires_auth(self) -> bool:
return True
@property
def description(self) -> str:
return (
"Poll / wait / cancel a sub-AutoPilot from run_sub_session. "
f"Waits up to wait_if_running sec (max {MAX_SUB_SESSION_WAIT_SECONDS}); "
"cancel=true aborts; include_progress=true returns recent messages "
"from the still-running sub. Works across turns."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"sub_session_id": {
"type": "string",
"description": (
"The sub's session_id returned by run_sub_session "
"(also accepted: sub_autopilot_session_id — same value)."
),
},
"wait_if_running": {
"type": "integer",
"description": (
f"Seconds to wait. 0 = just check. Clamped to "
f"{MAX_SUB_SESSION_WAIT_SECONDS}."
),
"default": 60,
},
"cancel": {
"type": "boolean",
"description": (
"Cancel the sub; takes precedence over wait_if_running."
),
"default": False,
},
"include_progress": {
"type": "boolean",
"description": (
"Populate progress.last_messages when status=running."
),
"default": False,
},
},
"required": ["sub_session_id"],
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
*,
sub_session_id: str = "",
wait_if_running: int = 60,
cancel: bool = False,
include_progress: bool = False,
**kwargs,
) -> ToolResponseBase:
inner_session_id = sub_session_id.strip()
if not inner_session_id:
return ErrorResponse(
message="sub_session_id is required",
session_id=session.session_id,
)
if user_id is None:
return ErrorResponse(
message="Authentication required",
session_id=session.session_id,
)
# Ownership check on every call — loads the ChatSession and
# confirms the caller owns it. Returning the same "not found"
# shape for "doesn't exist" and "belongs to someone else" avoids
# leaking session existence.
sub = await get_chat_session(inner_session_id)
if sub is None or sub.user_id != user_id:
return ErrorResponse(
message=(
f"No sub-session with id {inner_session_id}. It may have "
"never existed or belongs to another user."
),
session_id=session.session_id,
)
started_at = time.monotonic()
if cancel:
# Fan out the cancel event. Whichever worker is running the
# sub will break out of its stream and finalise the session
# as failed. Return "cancelled" immediately; the sub may
# still emit a little more output before the worker notices,
# but the agent doesn't need to wait for that.
await enqueue_cancel_task(inner_session_id)
return SubSessionStatusResponse(
message="Sub-AutoPilot cancel requested.",
session_id=session.session_id,
status="cancelled",
sub_session_id=inner_session_id,
sub_autopilot_session_id=inner_session_id,
sub_autopilot_session_link=_sub_session_link(inner_session_id),
elapsed_seconds=0.0,
)
# If a turn is currently running for this session (stream registry
# meta shows status=running), we can NOT short-circuit on the
# persisted last assistant message — that message belongs to a
# PRIOR turn, and surfacing it here would hand the caller stale
# data while the new turn is mid-flight (sentry r3105409601).
# Only short-circuit when there's no active turn AND the last
# persisted message already looks terminal.
effective_wait = max(0, min(wait_if_running, MAX_SUB_SESSION_WAIT_SECONDS))
registry_session = await stream_registry.get_session(inner_session_id)
turn_in_flight = registry_session is not None and (
getattr(registry_session, "status", "") == "running"
)
terminal_result = None if turn_in_flight else _already_terminal_result(sub)
outcome: SessionOutcome
result: SessionResult
if terminal_result is not None:
outcome, result = "completed", terminal_result
elif effective_wait > 0:
outcome, result = await wait_for_session_result(
session_id=inner_session_id,
user_id=user_id,
timeout=effective_wait,
)
else:
outcome, result = "running", SessionResult()
elapsed = time.monotonic() - started_at
if outcome == "running" and include_progress:
# Running + caller wants progress — hand-assemble the response
# with the progress snapshot attached. response_from_outcome
# doesn't carry progress, so we build the response here.
progress = await _build_progress_snapshot(inner_session_id)
link = _sub_session_link(inner_session_id)
return SubSessionStatusResponse(
message=(
f"Sub-AutoPilot still running after {elapsed:.0f}s."
f"{f' Watch live at {link}.' if link else ''} "
"Call again to keep waiting, or cancel=true to abort."
),
session_id=session.session_id,
status="running",
sub_session_id=inner_session_id,
sub_autopilot_session_id=inner_session_id,
sub_autopilot_session_link=link,
elapsed_seconds=round(elapsed, 2),
progress=progress,
)
return response_from_outcome(
outcome=outcome,
result=result,
inner_session_id=inner_session_id,
parent_session_id=session.session_id,
elapsed=elapsed,
)
def _already_terminal_result(sub: ChatSession) -> SessionResult | None:
"""Rebuild the aggregated result from the sub's persisted last turn,
when the last message is a terminal assistant message.
Lets ``get_sub_session_result`` short-circuit the subscribe+wait
when the agent polls well after the sub actually finished (a common
case when the user pauses and later asks "what's the result?").
Returns ``None`` if the last message isn't terminal.
"""
if not sub.messages:
return None
last = sub.messages[-1]
if last.role != "assistant":
return None
if not last.content and not last.tool_calls:
return None
result = SessionResult()
result.response_text = last.content or ""
# Persisted tool calls are OpenAI-shape dicts; translate to
# ToolCallEntry so the downstream ``response_from_outcome`` can
# ``.model_dump()`` them uniformly with the live-drain path.
for tc in last.tool_calls or []:
fn = tc.get("function") or {}
result.tool_calls.append(
ToolCallEntry(
tool_call_id=tc.get("id", ""),
tool_name=fn.get("name") or tc.get("name") or "",
input=fn.get("arguments") or tc.get("arguments") or tc.get("input"),
output=tc.get("output"),
success=tc.get("success"),
)
)
return result
async def _build_progress_snapshot(
inner_session_id: str | None,
) -> SubSessionProgressSnapshot | None:
"""Read the sub's ChatSession and return a preview of recent messages.
Returns ``None`` silently on lookup failure — progress is best-effort;
missing progress shouldn't abort the normal ``still running`` response.
"""
if not inner_session_id:
return None
try:
sub = await get_chat_session(inner_session_id)
if sub is None:
return None
messages = list(sub.messages)
except Exception as exc: # best-effort peek
logger.debug(
"Progress snapshot unavailable for sub %s: %s",
inner_session_id,
exc,
)
return None
tail = messages[-_PROGRESS_MESSAGE_LIMIT:]
previews: list[dict[str, Any]] = []
for msg in tail:
content = getattr(msg, "content", "") or ""
if not isinstance(content, str):
try:
content = json.dumps(content, default=str)
except (TypeError, ValueError):
content = str(content)
if len(content) > _PROGRESS_CONTENT_PREVIEW_CHARS:
content = content[:_PROGRESS_CONTENT_PREVIEW_CHARS] + ""
previews.append(
{
"role": getattr(msg, "role", "unknown"),
"content": content,
}
)
return SubSessionProgressSnapshot(
message_count=len(messages),
last_messages=previews,
)

View File

@@ -1,5 +1,6 @@
"""Shared helpers for chat tools."""
import asyncio
import logging
import uuid
from collections import defaultdict
@@ -14,6 +15,7 @@ from backend.copilot.constants import (
COPILOT_NODE_EXEC_ID_SEPARATOR,
COPILOT_NODE_PREFIX,
COPILOT_SESSION_PREFIX,
MAX_TOOL_WAIT_SECONDS,
)
from backend.copilot.model import ChatSession
from backend.copilot.sdk.file_ref import FileRefExpansionError, expand_file_refs_in_args
@@ -85,6 +87,71 @@ def get_inputs_from_schema(
return results
async def _charge_block_credits(
_credit_db: Any,
*,
user_id: str,
block_name: str,
block_id: str,
node_exec_id: str,
cost: int,
cost_filter: dict[str, Any],
synthetic_graph_id: str,
synthetic_node_id: str,
) -> None:
"""Charge credits for a block execution and log any billing leak.
Centralised so the normal-path charge and the cancellation-recovery charge
(see ``execute_block``'s finally) use the same metadata and the same
leak-logging contract.
"""
try:
await _credit_db.spend_credits(
user_id=user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=synthetic_graph_id,
graph_id=synthetic_graph_id,
node_id=synthetic_node_id,
node_exec_id=node_exec_id,
block_id=block_id,
block=block_name,
input=cost_filter,
reason="copilot_block_execution",
),
)
except Exception as e:
# Block already executed (with possible side effects). Never
# return ErrorResponse here — the user received output and
# deserves it. Log the billing failure for reconciliation.
leak_type = (
"INSUFFICIENT_BALANCE"
if isinstance(e, InsufficientBalanceError)
else "UNEXPECTED_ERROR"
)
logger.error(
"BILLING_LEAK[%s]: block executed but credit charge failed — "
"user_id=%s, block_id=%s, node_exec_id=%s, cost=%s: %s",
leak_type,
user_id,
block_id,
node_exec_id,
cost,
e,
extra={
"json_fields": {
"billing_leak": True,
"leak_type": leak_type,
"user_id": user_id,
"cost": str(cost),
}
},
)
# Intentionally swallow. Block already executed with possible side
# effects; the caller must still return BlockOutputResponse. The
# BILLING_LEAK log above is the signal for reconciliation.
async def execute_block(
*,
block: AnyBlockSchema,
@@ -210,67 +277,97 @@ async def execute_block(
session_id=session_id,
)
# Execute the block and collect outputs
# Execute the block under the shared MCP wait cap. A block is
# expected to finish in MAX_TOOL_WAIT_SECONDS; if it doesn't, the
# MCP handler would block the stream close to the idle timeout.
# wait_for cancels the generator on timeout, but the finally below
# still settles billing via asyncio.shield — external side effects
# may already have landed and the user should be charged for them.
outputs: dict[str, list[Any]] = defaultdict(list)
async for output_name, output_data in block.execute(
input_data,
**exec_kwargs,
):
outputs[output_name].append(output_data)
charge_handled = False
try:
await asyncio.wait_for(
_collect_block_outputs(block, input_data, exec_kwargs, outputs),
timeout=MAX_TOOL_WAIT_SECONDS,
)
# Charge credits for block execution
if has_cost:
try:
await _credit_db.spend_credits(
user_id=user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=synthetic_graph_id,
graph_id=synthetic_graph_id,
node_id=synthetic_node_id,
node_exec_id=node_exec_id,
# Normal (non-cancelled) path. Mark charge_handled BEFORE the
# await so an outer cancellation landing mid-charge can't race
# the finally block into a double-charge. asyncio.shield keeps
# the spend running to completion even if the outer awaitable
# is cancelled.
if has_cost:
charge_handled = True
await asyncio.shield(
_charge_block_credits(
_credit_db,
user_id=user_id,
block_name=block.name,
block_id=block_id,
block=block.name,
input=cost_filter,
reason="copilot_block_execution",
),
)
except Exception as e:
# Block already executed (with possible side effects). Never
# return ErrorResponse here — the user received output and
# deserves it. Log the billing failure for reconciliation.
leak_type = (
"INSUFFICIENT_BALANCE"
if isinstance(e, InsufficientBalanceError)
else "UNEXPECTED_ERROR"
)
logger.error(
"BILLING_LEAK[%s]: block executed but credit charge failed — "
"user_id=%s, block_id=%s, node_exec_id=%s, cost=%s: %s",
leak_type,
user_id,
block_id,
node_exec_id,
cost,
e,
extra={
"json_fields": {
"billing_leak": True,
"leak_type": leak_type,
"user_id": user_id,
"cost": str(cost),
}
},
node_exec_id=node_exec_id,
cost=cost,
cost_filter=cost_filter,
synthetic_graph_id=synthetic_graph_id,
synthetic_node_id=synthetic_node_id,
)
)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
block_id=block_id,
block_name=block.name,
outputs=dict(outputs),
success=True,
session_id=session_id,
)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
block_id=block_id,
block_name=block.name,
outputs=dict(outputs),
success=True,
session_id=session_id,
)
except asyncio.TimeoutError:
# Structured record of tool-call timeouts (SECRT-2247 part 3).
# Grep prod logs for `copilot_tool_timeout` to find tools that
# keep hitting the cap — candidates for prompt tuning or
# escalation to the async start+poll pattern.
logger.warning(
"copilot_tool_timeout tool=run_block block=%s block_id=%s "
"input_keys=%s user=%s session=%s cap_s=%d",
block.name,
block_id,
sorted(input_data.keys()),
user_id,
session_id,
MAX_TOOL_WAIT_SECONDS,
)
return ErrorResponse(
message=(
f"Block '{block.name}' exceeded the "
f"{MAX_TOOL_WAIT_SECONDS}s single-tool wait cap and was "
"cancelled. Long-running work should go through run_agent "
"(graph executions) or run_sub_session (sub-AutoPilot "
"tasks) — those use async start+poll so nothing blocks "
"the chat stream."
),
session_id=session_id,
)
finally:
# Sentry r3105079148: asyncio.wait_for raises CancelledError into
# the generator. Normal `except Exception` doesn't catch it, so
# without this finally a cancelled block would skip credit
# charging entirely while external side effects still landed.
# Only run when the normal-path charge was NOT reached (the flag
# is set before the await, so any cancellation during charge still
# sets it and avoids double-billing — r3105216985).
if has_cost and outputs and not charge_handled:
await asyncio.shield(
_charge_block_credits(
_credit_db,
user_id=user_id,
block_name=block.name,
block_id=block_id,
node_exec_id=node_exec_id,
cost=cost,
cost_filter=cost_filter,
synthetic_graph_id=synthetic_graph_id,
synthetic_node_id=synthetic_node_id,
)
)
except BlockError as e:
logger.warning("Block execution failed: %s", e)
@@ -288,6 +385,23 @@ async def execute_block(
)
async def _collect_block_outputs(
block: AnyBlockSchema,
input_data: dict[str, Any],
exec_kwargs: dict[str, Any],
outputs: dict[str, list[Any]],
) -> None:
"""Drive ``block.execute`` and append each emitted pair to *outputs*.
Extracted so ``asyncio.wait_for`` can wrap exactly the generator-
consumption step; callers read ``outputs`` afterwards (including from
the cancellation path) to decide whether the block produced enough
side-effects to warrant billing.
"""
async for output_name, output_data in block.execute(input_data, **exec_kwargs):
outputs[output_name].append(output_data)
async def resolve_block_credentials(
user_id: str,
block: AnyBlockSchema,
@@ -655,3 +769,51 @@ def _resolve_discriminated_credentials(
resolved[field_name] = effective_field_info
return resolved
# ---------------------------------------------------------------------------
# Agent-generation gate
# ---------------------------------------------------------------------------
#
# Tools that produce or modify agent JSON (create_agent, edit_agent,
# validate_agent_graph, fix_agent_graph) require the parent agent to have
# read the agent-building guide first — otherwise it tends to generate
# JSON that doesn't match the current block schemas, link semantics, or
# AgentExecutorBlock conventions, then waste turns fixing validation
# errors. ``require_guide_read`` returns an ``ErrorResponse`` the caller
# should short-circuit with, or ``None`` when the guide has been read.
_AGENT_GUIDE_TOOL_NAME = "get_agent_building_guide"
def _guide_read_in_session(session: ChatSession) -> bool:
"""True if this session's assistant messages include a guide tool call."""
for msg in reversed(session.messages):
if msg.role != "assistant" or not msg.tool_calls:
continue
for tc in msg.tool_calls:
name = tc.get("function", {}).get("name") or tc.get("name")
if name == _AGENT_GUIDE_TOOL_NAME:
return True
return False
def require_guide_read(session: ChatSession, tool_name: str):
"""Return an ErrorResponse if the guide hasn't been loaded this session.
Import inline to keep ``helpers.py`` free of tool-response imports.
"""
from .models import ErrorResponse # noqa: PLC0415 — avoid circular import
if _guide_read_in_session(session):
return None
return ErrorResponse(
message=(
f"Call get_agent_building_guide first, then retry {tool_name}. "
"The guide documents required block ids, input/output schemas, "
"link semantics, and AgentExecutorBlock / MCPToolBlock usage — "
"generating agent JSON without it produces schema mismatches."
),
session_id=session.session_id,
)

View File

@@ -259,6 +259,90 @@ class ErrorResponse(ToolResponseBase):
details: dict[str, Any] | None = None
class SubSessionProgressSnapshot(BaseModel):
"""Mid-flight snapshot of a running sub-AutoPilot.
Returned under ``progress`` on :class:`SubSessionStatusResponse` when the
caller passes ``include_progress=true`` while the sub is still running.
"""
message_count: int = Field(
description="Total messages in the sub's ChatSession so far.",
)
last_messages: list[dict[str, Any]] = Field(
default_factory=list,
description=(
"Up to the last 5 messages (role + truncated content) from the "
"sub's ChatSession — lets the agent report intermediate progress."
),
)
class SubSessionStatusResponse(ToolResponseBase):
"""Status / result of a sub-AutoPilot run started by ``run_sub_session``.
Returned by both ``run_sub_session`` (synchronously when the sub finishes
within ``wait_for_result``, else with ``status='running'``) and
``get_sub_session_result`` when the agent polls.
"""
type: ResponseType = ResponseType.MCP_TOOL_OUTPUT
status: Literal["running", "completed", "cancelled", "error", "queued"] = Field(
description=(
"Current state of the sub-AutoPilot run. ``queued`` means the "
"target session already had a turn in flight, so the message was "
"pushed onto its pending buffer and will be picked up by the "
"existing turn on its next drain."
),
)
sub_session_id: str = Field(
description=(
"Opaque id for this run. Pass to ``get_sub_session_result`` or "
"``run_sub_session(cancel=true, ...)`` to interact with it."
),
)
response: str | None = Field(
default=None,
description="Assistant response text when status=completed.",
)
sub_autopilot_session_id: str | None = Field(
default=None,
description=(
"The session_id of the sub-AutoPilot conversation. Use with "
"``run_sub_session(..., sub_autopilot_session_id=<this>)`` "
"to continue it."
),
)
sub_autopilot_session_link: str | None = Field(
default=None,
description=(
"Relative URL the user can click to open the sub-AutoPilot "
"conversation in the CoPilot UI. Always set when "
"``sub_autopilot_session_id`` is set."
),
)
tool_calls: list[dict[str, Any]] | None = Field(
default=None,
description="Tool calls made during the sub-AutoPilot run.",
)
error: str | None = Field(
default=None,
description="Error message when status=error.",
)
elapsed_seconds: float | None = Field(
default=None,
description="How long the sub-AutoPilot has been running (or took).",
)
progress: SubSessionProgressSnapshot | None = Field(
default=None,
description=(
"Mid-flight progress snapshot. Populated only when "
"get_sub_session_result is called with include_progress=true "
"and the sub is still running."
),
)
class InputValidationErrorResponse(ToolResponseBase):
"""Response when run_agent receives unknown input fields."""

View File

@@ -6,6 +6,7 @@ from typing import Any
from pydantic import BaseModel, Field, field_validator
from backend.copilot.config import ChatConfig
from backend.copilot.constants import MAX_TOOL_WAIT_SECONDS
from backend.copilot.model import ChatSession
from backend.copilot.tracking import track_agent_run_success, track_agent_scheduled
from backend.data.db_accessors import graph_db, library_db, user_db
@@ -71,7 +72,7 @@ class RunAgentInput(BaseModel):
schedule_name: str = ""
cron: str = ""
timezone: str = "UTC"
wait_for_result: int = Field(default=0, ge=0, le=300)
wait_for_result: int = Field(default=0, ge=0, le=MAX_TOOL_WAIT_SECONDS)
dry_run: bool = Field(default=False)
@field_validator(
@@ -150,9 +151,12 @@ class RunAgentTool(BaseTool):
},
"wait_for_result": {
"type": "integer",
"description": "Max seconds to wait for completion (0-300).",
"description": (
"Max seconds to wait for completion "
f"(0-{MAX_TOOL_WAIT_SECONDS})."
),
"minimum": 0,
"maximum": 300,
"maximum": MAX_TOOL_WAIT_SECONDS,
},
"dry_run": {
"type": "boolean",

View File

@@ -140,7 +140,9 @@ class TestRunBlockFiltering:
async def test_block_denied_by_permissions_returns_error(self):
"""A block denied by CopilotPermissions returns an ErrorResponse."""
session = make_session(user_id=_TEST_USER_ID)
block_id = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
# NB: must not match any id in COPILOT_EXCLUDED_BLOCK_IDS — we want
# the permissions guard to fire, not the exclusion guard.
block_id = "11111111-2222-3333-4444-555555555555"
standard_block = make_mock_block(block_id, "HTTP Request", BlockType.STANDARD)
perms = CopilotPermissions(blocks=[block_id], blocks_exclude=True)
@@ -645,3 +647,230 @@ class TestRunBlockSensitiveAction:
assert isinstance(response, BlockOutputResponse)
assert response.success is True
class TestExecuteBlockTimeout:
"""``execute_block`` caps the block's generator consumption at
MAX_TOOL_WAIT_SECONDS and must:
1. Return an actionable ErrorResponse pointing at run_agent / run_sub_session.
2. Log a ``copilot_tool_timeout`` warning (SECRT-2247 part 3).
3. Still charge credits when outputs were produced before the timeout
(sentry r3105079148 — cancellation must not leak billing)."""
@pytest.mark.asyncio(loop_scope="session")
async def test_timeout_returns_error_and_logs(self, caplog):
import asyncio
import logging
from backend.copilot.tools.helpers import execute_block
mock_block = MagicMock()
mock_block.name = "SlowBlock"
mock_block.id = "slow-block-id"
mock_block.input_schema = MagicMock()
mock_block.input_schema.jsonschema.return_value = {
"properties": {},
"required": [],
}
mock_block.input_schema.get_credentials_fields.return_value = {}
async def _hang(_input, **_kwargs):
await asyncio.sleep(10)
yield "never", "never"
mock_block.execute = _hang
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="ws-1")
)
with (
patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(0, {}),
),
patch(
"backend.copilot.tools.helpers.MAX_TOOL_WAIT_SECONDS",
0.05,
),
caplog.at_level(logging.WARNING, logger="backend.copilot.tools.helpers"),
):
response = await execute_block(
block=mock_block,
block_id="slow-block-id",
input_data={"x": 1},
user_id="u-1",
session_id="s-1",
node_exec_id="n-1",
matched_credentials={},
dry_run=False,
)
assert isinstance(response, ErrorResponse)
assert "single-tool wait cap" in response.message
assert "run_agent" in response.message
assert any(
"copilot_tool_timeout" in record.getMessage() for record in caplog.records
), "timeout must emit a grep-friendly log line for SECRT-2247 part 3"
@pytest.mark.asyncio(loop_scope="session")
async def test_cancellation_after_output_still_charges_credits(self):
"""Regression for sentry r3105079148 — wait_for's CancelledError
bypassed credit charging; fix uses a shielded finally. One output
emitted, then timeout: spend_credits must still be called once."""
import asyncio
from backend.copilot.tools.helpers import execute_block
mock_block = MagicMock()
mock_block.name = "CostlyBlock"
mock_block.id = "costly-block-id"
mock_block.input_schema = MagicMock()
mock_block.input_schema.jsonschema.return_value = {
"properties": {},
"required": [],
}
mock_block.input_schema.get_credentials_fields.return_value = {}
# Generator: emit ONE output (simulating a side-effectful API call),
# then hang — execute_block's internal wait_for cancels us.
async def _one_output_then_hang(_input, **_kw):
yield "result", "side effect happened"
await asyncio.sleep(10)
yield "extra", "should never arrive"
mock_block.execute = _one_output_then_hang
charged: dict[str, object] = {}
class _FakeCreditDB:
async def get_credits(self, _user_id: str) -> int:
return 10_000
async def spend_credits(self, **kwargs):
charged["last"] = kwargs
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="ws-1")
)
with (
patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
),
patch(
"backend.copilot.tools.helpers.credit_db",
return_value=_FakeCreditDB(),
),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(5, {}),
),
patch(
"backend.copilot.tools.helpers.MAX_TOOL_WAIT_SECONDS",
0.2,
),
):
response = await execute_block(
block=mock_block,
block_id="costly-block-id",
input_data={},
user_id="u-42",
session_id="s-42",
node_exec_id="n-42",
matched_credentials={},
dry_run=False,
)
# Cap fired → response is the timeout ErrorResponse
assert isinstance(response, ErrorResponse)
assert "single-tool wait cap" in response.message
# Critical: billing ran via the shielded finally despite the cancellation
assert charged.get("last") is not None, (
"Credits were NOT charged after cancellation — billing leak "
"(sentry r3105079148)"
)
assert charged["last"]["user_id"] == "u-42"
assert charged["last"]["cost"] == 5
@pytest.mark.asyncio(loop_scope="session")
async def test_no_double_charge_on_cancellation_during_charge(self):
"""Regression for sentry r3105216985 — if the caller cancels during
the normal-path credit charge, the finally must NOT charge a second
time. The fix marks charge_handled BEFORE awaiting spend_credits."""
import asyncio
from backend.copilot.tools.helpers import execute_block
mock_block = MagicMock()
mock_block.name = "OnceOnlyBlock"
mock_block.id = "once-only-id"
mock_block.input_schema = MagicMock()
mock_block.input_schema.jsonschema.return_value = {
"properties": {},
"required": [],
}
mock_block.input_schema.get_credentials_fields.return_value = {}
async def _one_then_done(_input, **_kw):
yield "result", "done"
mock_block.execute = _one_then_done
spend_calls: list[dict] = []
class _CountingCreditDB:
async def get_credits(self, _user_id: str) -> int:
return 10_000
async def spend_credits(self, **kwargs):
# Cooperative suspension so an outer cancellation can
# theoretically interleave — shield should still make this
# complete exactly once.
await asyncio.sleep(0)
spend_calls.append(kwargs)
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="ws-1")
)
with (
patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
),
patch(
"backend.copilot.tools.helpers.credit_db",
return_value=_CountingCreditDB(),
),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(7, {}),
),
):
response = await execute_block(
block=mock_block,
block_id="once-only-id",
input_data={},
user_id="u-single",
session_id="s-single",
node_exec_id="n-single",
matched_credentials={},
dry_run=False,
)
assert isinstance(response, BlockOutputResponse)
assert response.success is True
assert len(spend_calls) == 1, (
f"spend_credits must be called exactly once, got {len(spend_calls)} "
"(double-charge — sentry r3105216985)"
)

View File

@@ -0,0 +1,258 @@
"""Start a sub-AutoPilot conversation via the copilot_executor queue.
Mirror-image of ``run_agent`` + ``view_agent_output`` for copilot turns:
1. The tool creates (or validates ownership of) an inner ``ChatSession``
and calls :func:`run_copilot_turn_via_queue` — the shared primitive
that creates the stream-registry session meta, enqueues a
``CoPilotExecutionEntry``, and waits on the Redis stream until the
terminal event arrives or the cap fires.
2. Any available ``copilot_executor`` worker claims the job, runs
the SDK stream to completion, and publishes the final
``StreamFinish`` event on the session's Redis stream.
3. If the terminal event arrives in the wait window, the aggregated
:class:`SessionResult` (response text, tool calls, usage) comes back
in memory — no DB round-trip. Otherwise the tool returns
``status="running"`` + the sub's ``session_id`` and the agent polls
via :mod:`get_sub_session_result`.
Compared to the prior in-process ``asyncio.Task`` implementation this
gives us deploy/crash resilience, natural load balancing across
workers, and a uniform conversation model — a sub is just another
copilot turn routed through the same queue and event bus as every
other turn.
"""
import logging
import time
from typing import Any
from backend.copilot.constants import MAX_TOOL_WAIT_SECONDS
from backend.copilot.context import get_current_permissions
from backend.copilot.model import ChatSession, create_chat_session, get_chat_session
from backend.copilot.sdk.session_waiter import (
SessionOutcome,
SessionResult,
run_copilot_turn_via_queue,
)
from .base import BaseTool
from .models import ErrorResponse, SubSessionStatusResponse, ToolResponseBase
logger = logging.getLogger(__name__)
# Max wait for a single run_sub_session / get_sub_session_result call.
# Shared with every other long-running tool so the stream idle timeout's
# 2x headroom holds uniformly.
MAX_SUB_SESSION_WAIT_SECONDS = MAX_TOOL_WAIT_SECONDS
class RunSubSessionTool(BaseTool):
"""Delegate a task to a fresh sub-AutoPilot via the copilot_executor queue."""
@property
def name(self) -> str:
return "run_sub_session"
@property
def requires_auth(self) -> bool:
return True
@property
def description(self) -> str:
return (
"Delegate a task to a fresh sub-AutoPilot. Runs on the copilot "
"executor queue — survives tab-close AND worker restarts. Waits "
f"up to wait_for_result sec (max {MAX_SUB_SESSION_WAIT_SECONDS}). "
"If not done, returns status=running + sub_session_id — poll via "
"get_sub_session_result."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "The task for the sub-AutoPilot to execute.",
},
"system_context": {
"type": "string",
"description": "Optional context prepended to the prompt.",
"default": "",
},
"sub_autopilot_session_id": {
"type": "string",
"description": ("Continue/queue-into a prior sub; empty = new."),
"default": "",
},
"wait_for_result": {
"type": "integer",
"description": (
"Seconds to wait inline. 0 = return immediately. "
f"Clamped to {MAX_SUB_SESSION_WAIT_SECONDS}."
),
"default": 60,
},
},
"required": ["prompt"],
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
*,
prompt: str = "",
system_context: str = "",
sub_autopilot_session_id: str = "",
wait_for_result: int = 60,
**kwargs,
) -> ToolResponseBase:
if not prompt.strip():
return ErrorResponse(
message="prompt is required",
session_id=session.session_id,
)
if user_id is None:
return ErrorResponse(
message="Authentication required",
session_id=session.session_id,
)
# Resolve the sub's ChatSession id — either resume an owned one or
# create a fresh session that inherits the parent's dry_run so a
# sub spawned inside a dry-run conversation doesn't silently
# escalate to a live run.
sub_session_param = sub_autopilot_session_id.strip()
if sub_session_param:
owned = await get_chat_session(sub_session_param)
if owned is None or owned.user_id != user_id:
return ErrorResponse(
message=(
f"sub_autopilot_session_id {sub_session_param} is not "
"a session you own. Leave empty to start a fresh sub, "
"or pass a session_id returned by a previous "
"run_sub_session call of yours."
),
session_id=session.session_id,
)
inner_session_id = sub_session_param
else:
new_session = await create_chat_session(user_id, dry_run=session.dry_run)
inner_session_id = new_session.session_id
effective_prompt = prompt
if system_context.strip():
effective_prompt = f"[System Context: {system_context.strip()}]\n\n{prompt}"
cap = max(0, min(wait_for_result, MAX_SUB_SESSION_WAIT_SECONDS))
started_at = time.monotonic()
outcome, result = await run_copilot_turn_via_queue(
session_id=inner_session_id,
user_id=user_id,
message=effective_prompt,
timeout=cap,
permissions=get_current_permissions(),
tool_call_id=(f"sub:{session.session_id}" if session.session_id else "sub"),
tool_name="run_sub_session",
)
elapsed = time.monotonic() - started_at
return response_from_outcome(
outcome=outcome,
result=result,
inner_session_id=inner_session_id,
parent_session_id=session.session_id,
elapsed=elapsed,
)
def _sub_session_link(inner_session_id: str | None) -> str | None:
"""Build the CoPilot UI URL for a sub-AutoPilot session.
Kept in one place so the format stays consistent across the
running/completed/error paths, and so the frontend only has one
contract to honour.
"""
if not inner_session_id:
return None
return f"/copilot?sessionId={inner_session_id}"
def response_from_outcome(
*,
outcome: SessionOutcome,
result: SessionResult,
inner_session_id: str,
parent_session_id: str | None,
elapsed: float,
) -> SubSessionStatusResponse:
"""Translate a ``(SessionOutcome, SessionResult)`` tuple into the
``SubSessionStatusResponse`` contract the LLM sees.
``completed`` surfaces the aggregated response text + tool calls.
``failed`` returns the error marker with the same handles.
``running`` returns just the polling handles so the agent can resume.
``queued`` means the target session already had a turn in flight; the
message was appended to its pending buffer and will be processed by
the existing turn on its next drain.
"""
link = _sub_session_link(inner_session_id)
if outcome == "queued":
return SubSessionStatusResponse(
message=(
f"Target session already had a turn in flight; the message "
f"was queued ({result.pending_buffer_length} now pending) and "
"will be processed by the existing turn on its next drain. "
f"Call get_sub_session_result to poll progress"
f"{f' or watch live at {link}' if link else ''}."
),
session_id=parent_session_id,
status="queued",
sub_session_id=inner_session_id,
sub_autopilot_session_id=inner_session_id,
sub_autopilot_session_link=link,
elapsed_seconds=round(elapsed, 2),
)
if outcome == "running":
return SubSessionStatusResponse(
message=(
f"Sub-AutoPilot is still running after {elapsed:.0f}s."
f"{f' Watch live at {link}.' if link else ''} "
"Call get_sub_session_result (optionally with "
"include_progress=true) to wait, poll, or inspect progress."
),
session_id=parent_session_id,
status="running",
sub_session_id=inner_session_id,
sub_autopilot_session_id=inner_session_id,
sub_autopilot_session_link=link,
elapsed_seconds=round(elapsed, 2),
)
if outcome == "failed":
return SubSessionStatusResponse(
message="Sub-AutoPilot failed. See the sub's transcript for details.",
session_id=parent_session_id,
status="error",
sub_session_id=inner_session_id,
sub_autopilot_session_id=inner_session_id,
sub_autopilot_session_link=link,
elapsed_seconds=round(elapsed, 2),
)
# completed
return SubSessionStatusResponse(
message=f"Sub-AutoPilot completed.{f' View at {link}.' if link else ''}",
session_id=parent_session_id,
status="completed",
sub_session_id=inner_session_id,
sub_autopilot_session_id=inner_session_id,
sub_autopilot_session_link=link,
response=result.response_text,
tool_calls=[tc.model_dump() for tc in result.tool_calls],
elapsed_seconds=round(elapsed, 2),
)

View File

@@ -0,0 +1,523 @@
"""Tests for run_sub_session + get_sub_session_result (queue-backed flow).
Sub-AutoPilots are enqueued on the copilot_execution RabbitMQ queue and
executed by any copilot_executor worker. The tools wait for completion
by subscribing to ``stream_registry`` for the sub's ChatSession. These
tests patch the three integration seams — ``enqueue_copilot_turn``,
``wait_for_session_result``, and ``stream_registry.create_session``
— to exercise the tool logic without needing RabbitMQ or Redis.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from .get_sub_session_result import GetSubSessionResultTool
from .models import ErrorResponse, SubSessionStatusResponse
from .run_sub_session import MAX_SUB_SESSION_WAIT_SECONDS, RunSubSessionTool
def _session(user_id: str = "u", session_id: str = "s1") -> MagicMock:
sess = MagicMock()
sess.session_id = session_id
sess.dry_run = False
return sess
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_queue(monkeypatch):
"""Patch the enqueue helpers + the stream-registry session creator at
the source modules (session_waiter / get_sub_session_result) so tests
don't need RabbitMQ or Redis. Returns a dict of the mocks so
individual tests can assert on them.
"""
enqueue_turn = AsyncMock()
enqueue_cancel = AsyncMock()
create_session = AsyncMock()
# run_sub_session calls enqueue_copilot_turn via session_waiter's
# run_copilot_turn_via_queue helper — patch at the helper's source.
monkeypatch.setattr(
"backend.copilot.sdk.session_waiter.enqueue_copilot_turn",
enqueue_turn,
)
monkeypatch.setattr(
"backend.copilot.tools.get_sub_session_result.enqueue_cancel_task",
enqueue_cancel,
)
monkeypatch.setattr(
"backend.copilot.sdk.session_waiter.stream_registry.create_session",
create_session,
)
return {
"enqueue_turn": enqueue_turn,
"enqueue_cancel": enqueue_cancel,
"create_session": create_session,
}
@pytest.fixture
def mock_waiter(monkeypatch):
"""Patch the queue-backed primitive and the lightweight waiter so
tests can drive outcome + result deterministically. Returns the
``run_copilot_turn_via_queue`` mock (used by run_sub_session) and
the ``wait_for_session_result`` mock (used by get_sub_session_result)
wired to return ``("running", SessionResult())`` by default."""
from backend.copilot.sdk.session_waiter import SessionResult
turn_mock = AsyncMock(return_value=("running", SessionResult()))
result_mock = AsyncMock(return_value=("running", SessionResult()))
monkeypatch.setattr(
"backend.copilot.tools.run_sub_session.run_copilot_turn_via_queue",
turn_mock,
)
monkeypatch.setattr(
"backend.copilot.tools.get_sub_session_result.wait_for_session_result",
result_mock,
)
# Single handle with both attrs for tests that only care about one.
turn_mock.result_mock = result_mock
return turn_mock
@pytest.fixture
def mock_model(monkeypatch):
"""Patch the model-layer helpers the tools call for session CRUD +
ownership checks. The create side returns a fake ChatSession with a
fresh uuid each call."""
created: list[MagicMock] = []
async def fake_create(user_id: str, *, dry_run: bool):
sess = MagicMock()
sess.session_id = f"inner-{len(created) + 1}"
sess.user_id = user_id
sess.dry_run = dry_run
sess.messages = []
created.append(sess)
return sess
async def fake_get(session_id: str):
for s in created:
if s.session_id == session_id:
return s
return None
# The tool modules bind these names at import time, so patch the
# local module bindings (not the source in backend.copilot.model).
monkeypatch.setattr(
"backend.copilot.tools.run_sub_session.create_chat_session", fake_create
)
monkeypatch.setattr(
"backend.copilot.tools.run_sub_session.get_chat_session", fake_get
)
monkeypatch.setattr(
"backend.copilot.tools.get_sub_session_result.get_chat_session", fake_get
)
return {"created": created, "get": fake_get}
# ---------------------------------------------------------------------------
# RunSubSessionTool
# ---------------------------------------------------------------------------
class TestRunSubSession:
@pytest.mark.asyncio
async def test_missing_prompt_returns_error(self):
r = await RunSubSessionTool()._execute(
user_id="u", session=_session(), prompt=""
)
assert isinstance(r, ErrorResponse)
@pytest.mark.asyncio
async def test_no_user_returns_error(self):
r = await RunSubSessionTool()._execute(
user_id=None, session=_session(), prompt="hi"
)
assert isinstance(r, ErrorResponse)
@pytest.mark.asyncio
async def test_resume_with_other_users_session_id_rejected(
self, monkeypatch, mock_queue, mock_waiter
):
"""Ownership must be re-verified when the caller passes a resume id."""
foreign = MagicMock(session_id="alien-sess", user_id="not-caller", messages=[])
async def fake_get(session_id: str):
if session_id == "alien-sess":
return foreign
return None
monkeypatch.setattr(
"backend.copilot.tools.run_sub_session.get_chat_session", fake_get
)
r = await RunSubSessionTool()._execute(
user_id="alice",
session=_session("alice"),
prompt="continue",
sub_autopilot_session_id="alien-sess",
)
assert isinstance(r, ErrorResponse)
assert "is not a session you own" in r.message
mock_queue["enqueue_turn"].assert_not_awaited()
@pytest.mark.asyncio
async def test_propagates_dry_run_to_sub(self, mock_queue, mock_waiter, mock_model):
"""Fresh sub-session must inherit the parent's dry_run flag."""
parent = _session("alice")
parent.dry_run = True
await RunSubSessionTool()._execute(
user_id="alice",
session=parent,
prompt="hi",
wait_for_result=0, # skip the wait helper for this assertion
)
assert mock_model["created"], "create_chat_session was never awaited"
assert mock_model["created"][0].dry_run is True
@pytest.mark.asyncio
async def test_forwards_parent_permissions_to_queue(
self, monkeypatch, mock_queue, mock_waiter, mock_model
):
"""The parent's CopilotPermissions must be passed through to the
queue primitive so the worker applies the same filter."""
from backend.copilot.permissions import CopilotPermissions
perms = CopilotPermissions(tools=["run_block"], tools_exclude=False)
monkeypatch.setattr(
"backend.copilot.tools.run_sub_session.get_current_permissions",
lambda: perms,
)
await RunSubSessionTool()._execute(
user_id="alice",
session=_session("alice"),
prompt="hi",
wait_for_result=0,
)
mock_waiter.assert_awaited_once()
assert mock_waiter.await_args.kwargs["permissions"] is perms
@pytest.mark.asyncio
async def test_wait_for_result_zero_returns_running(
self, mock_queue, mock_waiter, mock_model
):
"""wait_for_result=0 still dispatches the job (so the sub starts)
but the primitive returns 'running' immediately because timeout=0,
and the tool surfaces that to the caller."""
r = await RunSubSessionTool()._execute(
user_id="alice",
session=_session("alice"),
prompt="hi",
wait_for_result=0,
)
assert isinstance(r, SubSessionStatusResponse)
assert r.status == "running"
assert r.sub_session_id == r.sub_autopilot_session_id == "inner-1"
assert r.sub_autopilot_session_link == "/copilot?sessionId=inner-1"
mock_waiter.assert_awaited_once()
assert mock_waiter.await_args.kwargs["timeout"] == 0
@pytest.mark.asyncio
async def test_wait_for_result_completed_returns_final_response(
self, mock_queue, mock_waiter, mock_model
):
"""When the queue primitive returns 'completed' + a SessionResult,
the tool surfaces response_text + tool_calls directly — no DB
round-trip needed for the content."""
from backend.copilot.sdk.session_waiter import SessionResult
from backend.copilot.sdk.stream_accumulator import ToolCallEntry
res = SessionResult()
res.response_text = "the answer"
res.tool_calls = [
ToolCallEntry(
tool_call_id="tc-1",
tool_name="foo",
input={"x": 1},
output="ok",
success=True,
)
]
mock_waiter.return_value = ("completed", res)
r = await RunSubSessionTool()._execute(
user_id="alice",
session=_session("alice"),
prompt="hi",
wait_for_result=60,
)
assert isinstance(r, SubSessionStatusResponse)
assert r.status == "completed"
assert r.response == "the answer"
assert r.tool_calls is not None and len(r.tool_calls) == 1
assert r.tool_calls[0]["tool_name"] == "foo"
mock_waiter.assert_awaited_once()
@pytest.mark.asyncio
async def test_queued_outcome_surfaces_queued_status(
self, mock_queue, mock_waiter, mock_model
):
"""When the shared primitive reports the target session already has
a turn running, the tool surfaces ``status='queued'`` so the LLM can
decide whether to poll or move on."""
from backend.copilot.sdk.session_waiter import SessionResult
queued_res = SessionResult(queued=True, pending_buffer_length=2)
mock_waiter.return_value = ("queued", queued_res)
r = await RunSubSessionTool()._execute(
user_id="alice",
session=_session("alice"),
prompt="please do another thing",
wait_for_result=0,
)
assert isinstance(r, SubSessionStatusResponse)
assert r.status == "queued"
assert r.sub_session_id == "inner-1"
assert "queued" in (r.message or "").lower()
@pytest.mark.asyncio
async def test_wait_clamps_above_maximum(self, mock_queue, mock_waiter, mock_model):
"""wait_for_result values above the cap are clamped before being
passed to the queue primitive."""
await RunSubSessionTool()._execute(
user_id="alice",
session=_session("alice"),
prompt="hi",
wait_for_result=MAX_SUB_SESSION_WAIT_SECONDS + 999,
)
mock_waiter.assert_awaited_once()
assert mock_waiter.await_args.kwargs["timeout"] == MAX_SUB_SESSION_WAIT_SECONDS
# ---------------------------------------------------------------------------
# GetSubSessionResultTool
# ---------------------------------------------------------------------------
class TestGetSubSessionResult:
@pytest.mark.asyncio
async def test_missing_id_returns_error(self):
r = await GetSubSessionResultTool()._execute(
user_id="u", session=_session(), sub_session_id=""
)
assert isinstance(r, ErrorResponse)
@pytest.mark.asyncio
async def test_unknown_id_returns_error(self, monkeypatch):
async def none_get(_sid):
return None
monkeypatch.setattr(
"backend.copilot.tools.get_sub_session_result.get_chat_session",
none_get,
)
r = await GetSubSessionResultTool()._execute(
user_id="u", session=_session(), sub_session_id="missing"
)
assert isinstance(r, ErrorResponse)
assert "No sub-session with id missing" in r.message
@pytest.mark.asyncio
async def test_other_user_cannot_access(self, monkeypatch):
"""Cross-user lookups are indistinguishable from 'not found'."""
foreign = MagicMock(user_id="bob", messages=[])
async def foreign_get(_sid):
return foreign
monkeypatch.setattr(
"backend.copilot.tools.get_sub_session_result.get_chat_session",
foreign_get,
)
r = await GetSubSessionResultTool()._execute(
user_id="alice", session=_session("alice"), sub_session_id="bobs-sess"
)
assert isinstance(r, ErrorResponse)
assert "No sub-session" in r.message
@pytest.mark.asyncio
async def test_wait_returns_running(self, monkeypatch, mock_waiter):
sub = MagicMock(user_id="alice", messages=[])
async def fake_get(_sid):
return sub
async def no_active_session(_sid):
return None
monkeypatch.setattr(
"backend.copilot.tools.get_sub_session_result.get_chat_session",
fake_get,
)
monkeypatch.setattr(
"backend.copilot.tools.get_sub_session_result.stream_registry.get_session",
no_active_session,
)
r = await GetSubSessionResultTool()._execute(
user_id="alice",
session=_session("alice"),
sub_session_id="inner-7",
wait_if_running=30,
)
assert isinstance(r, SubSessionStatusResponse)
assert r.status == "running"
assert r.sub_session_id == "inner-7"
mock_waiter.result_mock.assert_awaited_once()
@pytest.mark.asyncio
async def test_wait_returns_completed_with_response(self, monkeypatch, mock_waiter):
"""'completed' outcome surfaces the SessionResult directly."""
from backend.copilot.sdk.session_waiter import SessionResult
sub = MagicMock(user_id="alice", messages=[]) # not terminal-looking
async def fake_get(_sid):
return sub
async def no_active_session(_sid):
return None
monkeypatch.setattr(
"backend.copilot.tools.get_sub_session_result.get_chat_session",
fake_get,
)
monkeypatch.setattr(
"backend.copilot.tools.get_sub_session_result.stream_registry.get_session",
no_active_session,
)
res = SessionResult()
res.response_text = "done"
mock_waiter.result_mock.return_value = ("completed", res)
r = await GetSubSessionResultTool()._execute(
user_id="alice",
session=_session("alice"),
sub_session_id="inner-3",
wait_if_running=30,
)
assert isinstance(r, SubSessionStatusResponse)
assert r.status == "completed"
assert r.response == "done"
@pytest.mark.asyncio
async def test_already_terminal_skips_waiter(self, monkeypatch, mock_waiter):
"""If the sub's last message is already terminal AND no turn is
in flight, the tool returns 'completed' without ever calling
wait_for_session_result — it rebuilds the response from the
persisted message instead."""
sub = MagicMock(user_id="alice")
assistant = MagicMock()
assistant.role = "assistant"
assistant.content = "already done"
assistant.tool_calls = None
sub.messages = [assistant]
async def fake_get(_sid):
return sub
async def no_active_session(_sid):
return None
monkeypatch.setattr(
"backend.copilot.tools.get_sub_session_result.get_chat_session",
fake_get,
)
monkeypatch.setattr(
"backend.copilot.tools.get_sub_session_result.stream_registry.get_session",
no_active_session,
)
r = await GetSubSessionResultTool()._execute(
user_id="alice",
session=_session("alice"),
sub_session_id="inner-9",
wait_if_running=30,
)
assert isinstance(r, SubSessionStatusResponse)
assert r.status == "completed"
assert r.response == "already done"
mock_waiter.result_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_resume_turn_in_flight_does_not_return_stale(
self, monkeypatch, mock_waiter
):
"""Regression for sentry r3105409601: on a resumed session whose
stream_registry status is 'running' (new turn is mid-flight) the
tool must NOT short-circuit to the prior turn's terminal message.
It subscribes to the stream like a normal running-session poll."""
# DB state reflects the PREVIOUS turn's terminal assistant message.
prior = MagicMock()
prior.role = "assistant"
prior.content = "OLD stale result"
prior.tool_calls = None
sub = MagicMock(user_id="alice", messages=[prior])
async def fake_get(_sid):
return sub
running_meta = MagicMock(status="running")
async def active_registry(_sid):
return running_meta
monkeypatch.setattr(
"backend.copilot.tools.get_sub_session_result.get_chat_session",
fake_get,
)
monkeypatch.setattr(
"backend.copilot.tools.get_sub_session_result.stream_registry.get_session",
active_registry,
)
r = await GetSubSessionResultTool()._execute(
user_id="alice",
session=_session("alice"),
sub_session_id="inner-11",
wait_if_running=30,
)
# The waiter must have been awaited — stale short-circuit was skipped.
mock_waiter.result_mock.assert_awaited_once()
assert isinstance(r, SubSessionStatusResponse)
# Default mock_waiter.result_mock.return_value = ("running", SessionResult())
assert r.status == "running"
# And crucially NOT the stale content.
assert r.response is None or r.response == ""
@pytest.mark.asyncio
async def test_cancel_publishes_cancel_event(
self, monkeypatch, mock_queue, mock_waiter
):
"""cancel=true fans out a CancelCoPilotEvent and returns 'cancelled'
without waiting for the sub to finish (the worker will finalise)."""
sub = MagicMock(user_id="alice", messages=[])
async def fake_get(_sid):
return sub
monkeypatch.setattr(
"backend.copilot.tools.get_sub_session_result.get_chat_session",
fake_get,
)
r = await GetSubSessionResultTool()._execute(
user_id="alice",
session=_session("alice"),
sub_session_id="inner-5",
cancel=True,
)
assert isinstance(r, SubSessionStatusResponse)
assert r.status == "cancelled"
mock_queue["enqueue_cancel"].assert_awaited_once_with("inner-5")
mock_waiter.result_mock.assert_not_awaited()

View File

@@ -754,15 +754,15 @@ async def test_run_agent_session_dry_run_overrides_kwargs():
captured_params["dry_run"] = params.dry_run
return {}, None
with patch(
"backend.copilot.tools.run_agent.fetch_graph_from_store_slug",
new_callable=AsyncMock,
return_value=(graph, None),
), patch.object(
tool, "_check_prerequisites", side_effect=capture_prerequisites
), patch.object(
tool, "_run_agent", new_callable=AsyncMock
) as mock_run_agent:
with (
patch(
"backend.copilot.tools.run_agent.fetch_graph_from_store_slug",
new_callable=AsyncMock,
return_value=(graph, None),
),
patch.object(tool, "_check_prerequisites", side_effect=capture_prerequisites),
patch.object(tool, "_run_agent", new_callable=AsyncMock) as mock_run_agent,
):
mock_run_agent.return_value = MagicMock()
# Pass dry_run=False in kwargs — session.dry_run=True should win.
@@ -796,15 +796,15 @@ async def test_run_agent_session_dry_run_false_allows_scheduling():
captured_params["dry_run"] = params.dry_run
return {}, None
with patch(
"backend.copilot.tools.run_agent.fetch_graph_from_store_slug",
new_callable=AsyncMock,
return_value=(graph, None),
), patch.object(
tool, "_check_prerequisites", side_effect=capture_prerequisites
), patch.object(
tool, "_schedule_agent", new_callable=AsyncMock
) as mock_schedule:
with (
patch(
"backend.copilot.tools.run_agent.fetch_graph_from_store_slug",
new_callable=AsyncMock,
return_value=(graph, None),
),
patch.object(tool, "_check_prerequisites", side_effect=capture_prerequisites),
patch.object(tool, "_schedule_agent", new_callable=AsyncMock) as mock_schedule,
):
mock_schedule.return_value = MagicMock()
await tool._execute(
@@ -840,15 +840,15 @@ async def test_run_agent_session_dry_run_false_allows_llm_dry_run_true():
captured_params["dry_run"] = params.dry_run
return {}, None
with patch(
"backend.copilot.tools.run_agent.fetch_graph_from_store_slug",
new_callable=AsyncMock,
return_value=(graph, None),
), patch.object(
tool, "_check_prerequisites", side_effect=capture_prerequisites
), patch.object(
tool, "_run_agent", new_callable=AsyncMock
) as mock_run_agent:
with (
patch(
"backend.copilot.tools.run_agent.fetch_graph_from_store_slug",
new_callable=AsyncMock,
return_value=(graph, None),
),
patch.object(tool, "_check_prerequisites", side_effect=capture_prerequisites),
patch.object(tool, "_run_agent", new_callable=AsyncMock) as mock_run_agent,
):
mock_run_agent.return_value = MagicMock()
# LLM passes dry_run=True; normal session must NOT override it to False

View File

@@ -7,6 +7,7 @@ from backend.copilot.model import ChatSession
from .agent_generator.validation import AgentValidator, get_blocks_as_dicts
from .base import BaseTool
from .helpers import require_guide_read
from .models import ErrorResponse, ToolResponseBase, ValidationResultResponse
logger = logging.getLogger(__name__)
@@ -24,7 +25,8 @@ class ValidateAgentGraphTool(BaseTool):
return (
"Validate agent JSON for correctness: block_ids, links, required fields, "
"type compatibility, nested sink notation, prompt brace escaping, "
"and AgentExecutorBlock configs. On failure, use fix_agent_graph to auto-fix."
"and AgentExecutorBlock configs. On failure, use fix_agent_graph to auto-fix. "
"Requires get_agent_building_guide first (refuses otherwise)."
)
@property
@@ -53,6 +55,10 @@ class ValidateAgentGraphTool(BaseTool):
) -> ToolResponseBase:
session_id = session.session_id if session else None
guide_gate = require_guide_read(session, "validate_agent_graph")
if guide_gate is not None:
return guide_gate
if not agent_json or not isinstance(agent_json, dict):
return ErrorResponse(
message="Please provide a valid agent JSON object.",

View File

@@ -1,10 +1,10 @@
"""JSONL transcript management for stateless multi-turn resume.
The Claude Code CLI persists conversations as JSONL files (one JSON object per
line). When the SDK's ``Stop`` hook fires we read this file, strip bloat
(progress entries, metadata), and upload the result to bucket storage. On the
next turn we download the transcript, write it to a temp file, and pass
``--resume`` so the CLI can reconstruct the full conversation.
line). When the SDK's ``Stop`` hook fires the caller reads this file, strips
bloat (progress entries, metadata), and uploads the result to bucket storage.
On the next turn the caller downloads the bytes and writes them to disk before
passing ``--resume`` so the CLI can reconstruct the full conversation.
Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
filesystem for self-hosted) — no DB column needed.
@@ -20,6 +20,7 @@ import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from uuid import uuid4
from backend.util import json
@@ -27,6 +28,9 @@ from backend.util.clients import get_openai_client
from backend.util.prompt import CompressResult, compress_context
from backend.util.workspace_storage import GCSWorkspaceStorage, get_workspace_storage
if TYPE_CHECKING:
from .model import ChatMessage
logger = logging.getLogger(__name__)
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
@@ -44,17 +48,17 @@ STRIPPABLE_TYPES = frozenset(
)
TranscriptMode = Literal["sdk", "baseline"]
@dataclass
class TranscriptDownload:
"""Result of downloading a transcript with its metadata."""
content: str
message_count: int = 0 # session.messages length when uploaded
uploaded_at: float = 0.0 # epoch timestamp of upload
content: bytes | str
message_count: int = 0
# "sdk" = Claude CLI native, "baseline" = TranscriptBuilder
mode: TranscriptMode = "sdk"
# Workspace storage constants — deterministic path from session_id.
TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"
# Storage prefix for the CLI's native session JSONL files (for cross-pod --resume).
_CLI_SESSION_STORAGE_PREFIX = "cli-sessions"
@@ -363,7 +367,7 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
def _projects_base() -> str:
def projects_base() -> str:
"""Return the resolved path to the CLI's projects directory."""
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
return os.path.realpath(os.path.join(config_dir, "projects"))
@@ -390,8 +394,8 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
Returns the number of directories removed.
"""
projects_base = _projects_base()
if not os.path.isdir(projects_base):
_pbase = projects_base()
if not os.path.isdir(_pbase):
return 0
now = time.time()
@@ -399,7 +403,7 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
# Scoped mode: only clean up the one directory for the current session.
if encoded_cwd:
target = Path(projects_base) / encoded_cwd
target = Path(_pbase) / encoded_cwd
if not target.is_dir():
return 0
# Guard: only sweep copilot-generated dirs.
@@ -437,7 +441,7 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
# Only safe for single-tenant deployments; callers should prefer the
# scoped variant by passing encoded_cwd.
try:
entries = Path(projects_base).iterdir()
entries = Path(_pbase).iterdir()
except OSError as e:
logger.warning("[Transcript] Failed to list projects dir: %s", e)
return 0
@@ -490,9 +494,9 @@ def read_compacted_entries(transcript_path: str) -> list[dict] | None:
if not transcript_path:
return None
projects_base = _projects_base()
_pbase = projects_base()
real_path = os.path.realpath(transcript_path)
if not real_path.startswith(projects_base + os.sep):
if not real_path.startswith(_pbase + os.sep):
logger.warning(
"[Transcript] transcript_path outside projects base: %s", transcript_path
)
@@ -611,28 +615,6 @@ def validate_transcript(content: str | None) -> bool:
# ---------------------------------------------------------------------------
def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
"""Return (workspace_id, file_id, filename) for a session's transcript.
Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl``
IDs are sanitized to hex+hyphen to prevent path traversal.
"""
return (
TRANSCRIPT_STORAGE_PREFIX,
_sanitize_id(user_id),
f"{_sanitize_id(session_id)}.jsonl",
)
def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
"""Return (workspace_id, file_id, filename) for a session's transcript metadata."""
return (
TRANSCRIPT_STORAGE_PREFIX,
_sanitize_id(user_id),
f"{_sanitize_id(session_id)}.meta.json",
)
def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
"""Build a full storage path from (workspace_id, file_id, filename) parts."""
wid, fid, fname = parts
@@ -642,24 +624,12 @@ def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
return f"local://{wid}/{fid}/{fname}"
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
"""Build the full storage path string that ``retrieve()`` expects."""
return _build_path_from_parts(_storage_path_parts(user_id, session_id), backend)
def _build_meta_storage_path(user_id: str, session_id: str, backend: object) -> str:
"""Build the full storage path for the companion .meta.json file."""
return _build_path_from_parts(
_meta_storage_path_parts(user_id, session_id), backend
)
# ---------------------------------------------------------------------------
# CLI native session file — cross-pod --resume support
# ---------------------------------------------------------------------------
def _cli_session_path(sdk_cwd: str, session_id: str) -> str:
def cli_session_path(sdk_cwd: str, session_id: str) -> str:
"""Expected path of the CLI's native session JSONL file.
The CLI resolves the working directory via ``os.path.realpath``, then
@@ -675,7 +645,7 @@ def _cli_session_path(sdk_cwd: str, session_id: str) -> str:
"""
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
safe_id = _sanitize_id(session_id)
return os.path.join(_projects_base(), encoded_cwd, f"{safe_id}.jsonl")
return os.path.join(projects_base(), encoded_cwd, f"{safe_id}.jsonl")
def _cli_session_storage_path_parts(
@@ -689,235 +659,82 @@ def _cli_session_storage_path_parts(
)
async def upload_cli_session(
user_id: str,
session_id: str,
sdk_cwd: str,
log_prefix: str = "[Transcript]",
) -> None:
"""Upload the CLI's native session JSONL file to remote storage.
Called after each turn so the next turn can restore the file on any pod
(eliminating the pod-affinity requirement for --resume).
The CLI only writes the session file after the turn completes, so this
must run in the finally block, AFTER the SDK stream has finished.
"""
session_file = _cli_session_path(sdk_cwd, session_id)
real_path = os.path.realpath(session_file)
projects_base = _projects_base()
if not real_path.startswith(projects_base + os.sep):
logger.warning(
"%s CLI session file outside projects base, skipping upload: %s",
log_prefix,
os.path.basename(real_path),
)
return
try:
raw_bytes = Path(real_path).read_bytes()
except FileNotFoundError:
logger.debug(
"%s CLI session file not found, skipping upload: %s",
log_prefix,
session_file,
)
return
except OSError as e:
logger.warning("%s Failed to read CLI session file: %s", log_prefix, e)
return
# Strip stale thinking blocks and metadata entries (progress, file-history-snapshot,
# queue-operation) from the CLI session before writing it back locally and uploading
# to GCS. Thinking blocks from non-last assistant turns are not needed for --resume
# but can be massive (tens of thousands of tokens each), causing the CLI to auto-compact
# its session when the context window fills up. Stripping keeps the session well below
# the ~200K-token compaction threshold and prevents silent context loss.
try:
raw_text = raw_bytes.decode("utf-8")
stripped_text = strip_for_upload(raw_text)
stripped_bytes = stripped_text.encode("utf-8")
if len(stripped_bytes) < len(raw_bytes):
# Write the stripped version back locally so same-pod turns also benefit.
Path(real_path).write_bytes(stripped_bytes)
logger.info(
"%s Stripped CLI session file: %dB → %dB",
log_prefix,
len(raw_bytes),
len(stripped_bytes),
)
content = stripped_bytes
except Exception as e:
logger.warning(
"%s Failed to strip CLI session file, uploading raw: %s", log_prefix, e
)
content = raw_bytes
storage = await get_workspace_storage()
wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id)
try:
await storage.store(
workspace_id=wid, file_id=fid, filename=fname, content=content
)
logger.info(
"%s Uploaded CLI session file (%dB) for cross-pod --resume",
log_prefix,
len(content),
)
except Exception as e:
logger.warning("%s Failed to upload CLI session file: %s", log_prefix, e)
async def restore_cli_session(
user_id: str,
session_id: str,
sdk_cwd: str,
log_prefix: str = "[Transcript]",
) -> bool:
"""Download and restore the CLI's native session file for --resume.
Returns True if the file was successfully restored and --resume can be
used with the session UUID. Returns False if not available (first turn
or upload failed), in which case the caller should not set --resume.
"""
session_file = _cli_session_path(sdk_cwd, session_id)
real_path = os.path.realpath(session_file)
projects_base = _projects_base()
if not real_path.startswith(projects_base + os.sep):
logger.warning(
"%s CLI session restore path outside projects base: %s",
log_prefix,
os.path.basename(session_file),
)
return False
# If the session file already exists locally (same-pod reuse), use it directly.
# Downloading from storage could overwrite a newer local version when a previous
# turn's upload failed: stored content is stale while the local file already
# contains extended history from that turn.
if Path(real_path).exists():
logger.debug(
"%s CLI session file already exists locally — using it for --resume",
log_prefix,
)
return True
storage = await get_workspace_storage()
path = _build_path_from_parts(
_cli_session_storage_path_parts(user_id, session_id), storage
def _cli_session_meta_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
"""Return (workspace_id, file_id, filename) for the CLI session meta file."""
return (
_CLI_SESSION_STORAGE_PREFIX,
_sanitize_id(user_id),
f"{_sanitize_id(session_id)}.meta.json",
)
try:
content = await storage.retrieve(path)
except FileNotFoundError:
logger.debug("%s No CLI session in storage (first turn or missing)", log_prefix)
return False
except Exception as e:
logger.warning("%s Failed to download CLI session: %s", log_prefix, e)
return False
try:
os.makedirs(os.path.dirname(real_path), exist_ok=True)
Path(real_path).write_bytes(content)
logger.info(
"%s Restored CLI session file (%dB) for --resume",
log_prefix,
len(content),
)
return True
except OSError as e:
logger.warning("%s Failed to write CLI session file: %s", log_prefix, e)
return False
async def upload_transcript(
user_id: str,
session_id: str,
content: str,
content: bytes,
message_count: int = 0,
mode: TranscriptMode = "sdk",
log_prefix: str = "[Transcript]",
skip_strip: bool = False,
) -> None:
"""Strip progress entries and stale thinking blocks, then upload transcript.
"""Upload CLI session content to GCS with companion meta.json.
The transcript represents the FULL active context (atomic).
Each upload REPLACES the previous transcript entirely.
Pure GCS operation — no disk I/O. The caller is responsible for reading
the session file from disk before calling this function.
The executor holds a cluster lock per session, so concurrent uploads for
the same session cannot happen.
Also uploads a companion .meta.json with the message_count watermark so
download_transcript can return it without a separate fetch.
Args:
content: Complete JSONL transcript (from TranscriptBuilder).
message_count: ``len(session.messages)`` at upload time.
skip_strip: When ``True``, skip the strip + re-validate pass.
Safe for builder-generated content (baseline path) which
never emits progress entries or stale thinking blocks.
Called after each turn so the next turn can restore the file on any pod
(eliminating the pod-affinity requirement for --resume).
"""
if skip_strip:
# Caller guarantees the content is already clean and valid.
stripped = content
else:
# Strip metadata entries and stale thinking blocks in a single parse.
# SDK-built transcripts may have progress entries; strip for safety.
stripped = strip_for_upload(content)
if not skip_strip and not validate_transcript(stripped):
# Log entry types for debugging — helps identify why validation failed
entry_types = [
json.loads(line, fallback={"type": "INVALID_JSON"}).get("type", "?")
for line in stripped.strip().split("\n")
]
logger.warning(
"%s Skipping upload — stripped content not valid "
"(types=%s, stripped_len=%d, raw_len=%d)",
log_prefix,
entry_types,
len(stripped),
len(content),
)
logger.debug("%s Raw content preview: %s", log_prefix, content[:500])
logger.debug("%s Stripped content: %s", log_prefix, stripped[:500])
return
storage = await get_workspace_storage()
wid, fid, fname = _storage_path_parts(user_id, session_id)
encoded = stripped.encode("utf-8")
meta = {"message_count": message_count, "uploaded_at": time.time()}
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id)
mwid, mfid, mfname = _cli_session_meta_path_parts(user_id, session_id)
meta = {"message_count": message_count, "mode": mode, "uploaded_at": time.time()}
meta_encoded = json.dumps(meta).encode("utf-8")
# Transcript + metadata are independent objects at different keys, so
# write them concurrently. ``return_exceptions`` keeps a metadata
# failure from sinking the transcript write.
transcript_result, metadata_result = await asyncio.gather(
storage.store(
workspace_id=wid,
file_id=fid,
filename=fname,
content=encoded,
),
storage.store(
workspace_id=mwid,
file_id=mfid,
filename=mfname,
content=meta_encoded,
),
return_exceptions=True,
)
if isinstance(transcript_result, BaseException):
raise transcript_result
if isinstance(metadata_result, BaseException):
# Metadata is best-effort — the gap-fill logic in
# _build_query_message tolerates a missing metadata file.
logger.warning("%s Failed to write metadata: %s", log_prefix, metadata_result)
# Write JSONL first, meta second — sequential so a crash between the two
# leaves an orphaned JSONL (no meta) rather than an orphaned meta (wrong
# watermark / mode paired with stale or absent content).
# On any failure we roll back the other file so the pair is always absent
# together; download_transcript returns None when either file is missing.
try:
await storage.store(
workspace_id=wid, file_id=fid, filename=fname, content=content
)
except Exception as session_err:
logger.warning(
"%s Failed to upload CLI session file: %s", log_prefix, session_err
)
return
try:
await storage.store(
workspace_id=mwid, file_id=mfid, filename=mfname, content=meta_encoded
)
except Exception as meta_err:
logger.warning("%s Failed to upload CLI session meta: %s", log_prefix, meta_err)
# Roll back the JSONL so neither file exists — avoids orphaned JSONL being
# used with wrong mode/watermark defaults on the next restore.
try:
session_path = _build_path_from_parts(
_cli_session_storage_path_parts(user_id, session_id), storage
)
await storage.delete(session_path)
except Exception as rollback_err:
logger.debug(
"%s Session rollback failed (harmless — download will return None): %s",
log_prefix,
rollback_err,
)
return
logger.info(
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
"%s Uploaded CLI session (%dB, msg_count=%d, mode=%s)",
log_prefix,
len(encoded),
len(content),
message_count,
mode,
)
@@ -926,83 +743,181 @@ async def download_transcript(
session_id: str,
log_prefix: str = "[Transcript]",
) -> TranscriptDownload | None:
"""Download transcript and metadata from bucket storage.
"""Download CLI session from GCS. Returns content + message_count + mode, or None if not found.
Returns a ``TranscriptDownload`` with the JSONL content and the
``message_count`` watermark from the upload, or ``None`` if not found.
Pure GCS operation — no disk I/O. The caller is responsible for writing
content to disk if --resume is needed.
The content and metadata fetches run concurrently since they are
independent objects in the bucket.
Returns a TranscriptDownload with the raw content, message_count watermark,
and mode on success, or None if not available (first turn or upload failed).
"""
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
meta_path = _build_meta_storage_path(user_id, session_id, storage)
path = _build_path_from_parts(
_cli_session_storage_path_parts(user_id, session_id), storage
)
meta_path = _build_path_from_parts(
_cli_session_meta_path_parts(user_id, session_id), storage
)
content_task = asyncio.create_task(storage.retrieve(path))
meta_task = asyncio.create_task(storage.retrieve(meta_path))
content_result, meta_result = await asyncio.gather(
content_task, meta_task, return_exceptions=True
storage.retrieve(path),
storage.retrieve(meta_path),
return_exceptions=True,
)
if isinstance(content_result, FileNotFoundError):
logger.debug("%s No transcript in storage", log_prefix)
logger.debug("%s No CLI session in storage (first turn or missing)", log_prefix)
return None
if isinstance(content_result, BaseException):
logger.warning(
"%s Failed to download transcript: %s", log_prefix, content_result
"%s Failed to download CLI session: %s", log_prefix, content_result
)
return None
content = content_result.decode("utf-8")
content: bytes = content_result
# Metadata is best-effort — old transcripts won't have it.
# Parse message_count and mode from companion meta best-effort, defaults.
message_count = 0
uploaded_at = 0.0
mode: TranscriptMode = "sdk"
if isinstance(meta_result, FileNotFoundError):
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
pass # No meta — old upload; default to "sdk"
elif isinstance(meta_result, BaseException):
logger.debug(
"%s Failed to load transcript metadata: %s", log_prefix, meta_result
)
logger.debug("%s Failed to load CLI session meta: %s", log_prefix, meta_result)
else:
meta = json.loads(meta_result.decode("utf-8"), fallback={})
message_count = meta.get("message_count", 0)
uploaded_at = meta.get("uploaded_at", 0.0)
try:
meta_str = meta_result.decode("utf-8")
except UnicodeDecodeError:
logger.debug("%s CLI session meta is not valid UTF-8, ignoring", log_prefix)
meta_str = None
if meta_str is not None:
meta = json.loads(meta_str, fallback={})
if isinstance(meta, dict):
raw_count = meta.get("message_count", 0)
message_count = (
raw_count if isinstance(raw_count, int) and raw_count >= 0 else 0
)
raw_mode = meta.get("mode", "sdk")
mode = raw_mode if raw_mode in ("sdk", "baseline") else "sdk"
logger.info(
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
)
return TranscriptDownload(
content=content,
message_count=message_count,
uploaded_at=uploaded_at,
"%s Downloaded CLI session (%dB, msg_count=%d, mode=%s)",
log_prefix,
len(content),
message_count,
mode,
)
return TranscriptDownload(content=content, message_count=message_count, mode=mode)
def detect_gap(
download: TranscriptDownload,
session_messages: list[ChatMessage],
) -> list[ChatMessage]:
"""Return chat-db messages after the transcript watermark (excluding current user turn).
Returns [] if transcript is current, watermark is zero, or the watermark
position doesn't end on an assistant turn (misaligned watermark).
"""
if download.message_count == 0:
return []
wm = download.message_count
total = len(session_messages)
if wm >= total - 1:
return []
# Sanity: position wm-1 should be an assistant turn; misaligned watermark
# means the DB messages shifted (e.g. deletion) — skip gap to avoid wrong context.
# In normal operation ``message_count`` is always written after a complete
# user→assistant exchange (never mid-turn), so the last covered position is
# always assistant. This guard fires only on data corruption or message deletion.
if session_messages[wm - 1].role != "assistant":
return []
return list(session_messages[wm : total - 1])
def extract_context_messages(
download: TranscriptDownload | None,
session_messages: "list[ChatMessage]",
) -> "list[ChatMessage]":
"""Return context messages for the current turn: transcript content + gap.
This is the shared context primitive used by both the SDK path
(``use_resume=False`` → ``<conversation_history>`` injection) and the
baseline path (OpenAI messages array).
How it works:
- When a transcript exists, ``TranscriptBuilder.load_previous`` preserves
``isCompactSummary=True`` compaction entries, so the returned messages
mirror the compacted context the CLI would see via ``--resume``.
- The gap (DB messages after the transcript watermark) is always small in
normal operation; it only grows during mode switches or when an upload
was missed.
- Falls back to full DB messages when no transcript exists (first turn,
upload failure, or GCS unavailable).
- Returns *prior* messages only (excluding the current user turn at
``session_messages[-1]``). Callers that need the current turn append
``session_messages[-1]`` themselves.
- **Tool calls from transcript entries are flattened to text**: assistant
messages derived from the JSONL use ``_flatten_assistant_content``, which
serialises ``tool_use`` blocks as human-readable text rather than
structured ``tool_calls``. Gap messages (from DB) preserve their
original ``tool_calls`` field. This is the same trade-off as the old
``_compress_session_messages(session.messages)`` approach — no regression.
Args:
download: The ``TranscriptDownload`` from GCS, or ``None`` when no
transcript is available. ``content`` may be either ``bytes`` or
``str`` (the baseline path decodes + strips before returning).
session_messages: All messages in the session, with the current user
turn as the last element.
Returns:
A list of ``ChatMessage`` objects covering the prior conversation
context, suitable for injection as conversation history.
"""
from .model import ChatMessage as _ChatMessage # runtime import
# ``role="reasoning"`` rows are persisted for frontend replay of
# extended_thinking content but are NOT conversation context — the
# transcript-based --resume path already carries thinking separately,
# and sending them back to the model as user/assistant turns would be
# both redundant and malformed. Drop them before any gap detection
# or transcript comparison so ordering invariants still hold.
session_messages = [m for m in session_messages if m.role != "reasoning"]
prior = session_messages[:-1]
if download is None:
return prior
raw_content = download.content
if not raw_content:
return prior
# Handle both bytes (raw GCS download) and str (pre-decoded baseline path).
if isinstance(raw_content, bytes):
try:
content_str: str = raw_content.decode("utf-8")
except UnicodeDecodeError:
return prior
else:
content_str = raw_content
raw = _transcript_to_messages(content_str)
if not raw:
return prior
transcript_msgs = [
_ChatMessage(role=m["role"], content=m.get("content") or "") for m in raw
]
gap = detect_gap(download, session_messages)
return transcript_msgs + gap
async def delete_transcript(user_id: str, session_id: str) -> None:
"""Delete transcript and its metadata from bucket storage.
Removes both the ``.jsonl`` transcript and the companion ``.meta.json``
so stale ``message_count`` watermarks cannot corrupt gap-fill logic.
"""
"""Delete CLI session JSONL and its companion .meta.json from bucket storage."""
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
try:
await storage.delete(path)
logger.info("[Transcript] Deleted transcript for session %s", session_id)
except Exception as e:
logger.warning("[Transcript] Failed to delete transcript: %s", e)
# Also delete the companion .meta.json to avoid orphaned metadata.
try:
meta_path = _build_meta_storage_path(user_id, session_id, storage)
await storage.delete(meta_path)
logger.info("[Transcript] Deleted metadata for session %s", session_id)
except Exception as e:
logger.warning("[Transcript] Failed to delete metadata: %s", e)
# Also delete the CLI native session file to prevent storage growth.
try:
cli_path = _build_path_from_parts(
_cli_session_storage_path_parts(user_id, session_id), storage
@@ -1012,6 +927,15 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
except Exception as e:
logger.warning("[Transcript] Failed to delete CLI session: %s", e)
try:
cli_meta_path = _build_path_from_parts(
_cli_session_meta_path_parts(user_id, session_id), storage
)
await storage.delete(cli_meta_path)
logger.info("[Transcript] Deleted CLI session meta for session %s", session_id)
except Exception as e:
logger.warning("[Transcript] Failed to delete CLI session meta: %s", e)
# ---------------------------------------------------------------------------
# Transcript compaction — LLM summarization for prompt-too-long recovery

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More