Compare commits

..

44 Commits

Author SHA1 Message Date
0xArty
e71f5958f1 fix(frontend): make library builder links versionless 2026-04-24 16:41:04 +01:00
Swifty
a8a4c2e56e add workflow.md 2026-04-24 15:37:50 +02:00
Abhimanyu Yadav
2cb52e5d19 feat(frontend): add Settings v2 page layout behind SETTINGS_V2 flag (SECRT-2272) (#12885)
### Why / What / How

**Why:** The Settings area is getting a redesign (per Figma
[Settings-Page](https://www.figma.com/design/YGck0Hb0GEgFzwbX47kSNs/Settings-Page?node-id=1-2)).
Ticket SECRT-2272 covers just the shell so content/forms for each
section can land in follow-up PRs without blocking on the nav
restructure. v1 at `/profile/settings` must stay intact for end users
during the rollout.

**What:** Adds a new parallel Settings hub at `/settings` (dedicated
sidebar + 7 placeholder sub-routes) behind a new `SETTINGS_V2`
LaunchDarkly flag. Default `false` so nothing changes for users until
the flag flips. Backend is untouched.


https://github.com/user-attachments/assets/dd680eaf-3d41-4a9a-87f3-d06d536a2503


**How:**
- New `Flag.SETTINGS_V2 = "settings-v2"` added to `use-get-flag.ts` with
`defaultFlags[Flag.SETTINGS_V2] = false`. Gate the whole route group at
`layout.tsx` via existing `FeatureFlagPage` HOC which redirects to
`/profile/settings` when the flag is off.
- `SettingsSidebar` replicates the Figma spec (237px, 7 items at 217×38,
`gap-[7px]`, rounded-[8px], active `bg-[#EFEFF0]` + text `#1F1F20` Geist
Medium, inactive text `#505057` Geist Regular, icon 16px Phosphor
light/regular at `#1F1F20`). Colors + typography use the canonical
tokens exported by Figma (zinc-50 `#F9F9FA`, zinc-200 `#DADADC` for the
right-border, etc.).
- `SettingsNavItem` is extracted as its own component and owns its
per-item entrance variant.
- Per-link loading indicator uses Next.js 15's `useLinkStatus()` hook —
spinner appears on the right of the clicked item and clears
automatically once the target page renders.
- `SettingsMobileNav` (< md breakpoint): sidebar hides; a pill trigger
with the current section's icon + label opens a Radix Popover listing
all 7 sections.
- Entrance animations via framer-motion, tuned to Emil Kowalski's
guidelines — `cubic-bezier(0, 0, 0.2, 1)` ease-out, all durations ≤
280ms, only `transform` and `opacity`, `useReducedMotion` disables
movement but keeps fade. Sidebar items stagger in (40ms offset). Main
content re-animates on every route change via `key={pathname}`.
- All 7 placeholder pages render the section title (Poppins Medium 22/28
via `variant="h4"`, `#1F1F20`) + "Coming soon" copy; they are
intentionally client components to avoid hook-order issues with the
client-side flag gate in the layout.

### Changes 🏗️

- `src/services/feature-flags/use-get-flag.ts`: register
`Flag.SETTINGS_V2` + default `false`
- `src/app/(platform)/settings/layout.tsx`: flag gate + responsive shell
+ route-keyed content animation
- `src/app/(platform)/settings/page.tsx`: client-side redirect to
`/settings/profile`
- `src/app/(platform)/settings/components/SettingsSidebar/`:
  - `SettingsSidebar.tsx` — aside with staggered entrance
- `SettingsNavItem.tsx` — per-item Link + icon + label + loader
(extracted)
- `useSettingsSidebar.ts` — hook mapping nav items with `isActive` from
`usePathname`
- `helpers.ts` — typed nav item config (label / href / Phosphor icon) ×
7
-
`src/app/(platform)/settings/components/SettingsMobileNav/SettingsMobileNav.tsx`:
mobile Popover trigger
- 7 placeholder pages: `profile`, `creator-dashboard`, `billing`,
`integrations`, `preferences`, `api-keys`, `oauth-apps`

**Follow-up PRs will migrate real content into each tab.** LaunchDarkly
flag key `settings-v2` must be created in the LD dashboard before
enabling for users.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] `NEXT_PUBLIC_FORCE_FLAG_SETTINGS_V2=true` → `/settings` redirects
to `/settings/profile`, sidebar renders 7 items with "Profile" active
- [x] Click each nav item → URL changes, active item highlights, content
pane re-animates, per-link spinner shows during navigation
- [x] Viewport < 768px → sidebar hides, mobile pill trigger opens
Popover with all 7 items; selecting one navigates and closes
- [x] Without the flag env override, `/settings` redirects to
`/profile/settings` (v1 unchanged)
  - [x] `pnpm types` clean; prettier clean on touched files
- [x] Manual a11y pass with `prefers-reduced-motion` enabled — fade
remains, translations disabled

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
*(no new env vars required; existing `NEXT_PUBLIC_FORCE_FLAG_*` pattern
covers local override)*
- [x] `docker-compose.yml` is updated or already compatible with my
changes *(no docker changes)*
- [x] I have included a list of my configuration changes in the PR
description *(LaunchDarkly dashboard must have `settings-v2` flag
created before enabling; no other config changes)*
2026-04-24 10:39:54 +00:00
Zamil Majdy
ab88d03b13 refactor(backend/integrations): clearer naming + docs for managed-cred sweep (#12908)
## Why

Review comments on #12883 (thanks @Pwuts) surfaced a few spots where the
managed-credential plumbing's names and docstrings didn't match what the
code actually does:

- `_read_or_create_profile_key` suggests "read from any source or create
new", but only migrates the legacy
`managed_credentials.ayrshare_profile_key` side-channel — it doesn't
read an existing managed credential. (That check lives in the outer
`_provision_under_lock`.)
- Docstrings refer to "the startup sweep" in several places — there's no
startup hook; the sweep runs on `/credentials` fetches.
- `is_available` / `auto_provision` relationship wasn't explicit;
readers couldn't tell whether `is_available` was a config check or a
liveness check, or which of the two gates the sweep checks first.

## What

Naming + docstring cleanup. **Zero behavior changes.**

- Rename `_read_or_create_profile_key` →
`_migrate_legacy_or_create_profile_key` with docstring explaining why it
doesn't re-check the managed cred.
- Replace "startup sweep" → "credentials sweep" everywhere.
- `ManagedCredentialProvider` class docstring now names the two gates:
1. `auto_provision` — does this provider participate in the sweep at
all?
  2. `is_available` — are the required env vars / secrets set?
- `is_available` docstring now spells out: what it checks (env vars),
what it does NOT check (upstream health), and that it's only consulted
when `auto_provision=True`.
- `ensure_managed_credentials` docstring defines "credentials sweep",
when it fires, how the per-user in-memory cache works.
- Module-level docstring drops the stale "non-blocking background task"
wording (#12883 made the sweep bounded-await).

## How

4 files, all backend:
- `backend/integrations/managed_credentials.py`
- `backend/integrations/managed_providers/ayrshare.py`
- `backend/integrations/managed_providers/ayrshare_test.py`
- `backend/api/features/integrations/router.py`

Tests: 13/13 Ayrshare tests pass against the rename.

## Checklist

- [x] Follows style guide
- [x] Existing tests still pass (no functional change)
- [x] No new tests needed — pure rename + docstring change
2026-04-24 16:22:09 +07:00
An Vy Le
3aa72b4245 feat(backend/copilot): inline picker-backed inputs via run_block + accept AgentInputBlock subclasses (#12880)
### Why / What / How

**Why:** Resolves #12875. CoPilot's agent-builder was hardcoding Google
Drive file IDs into consuming blocks' `input_default` instead of wiring
an `AgentGoogleDriveFileInputBlock`. A beta user hit this across **13
saved versions** of one agent. Root causes:

1. `validate_io_blocks` only accepted the literal base `AgentInputBlock`
/ `AgentOutputBlock` IDs, so even when CoPilot used a specialized
subclass like `AgentGoogleDriveFileInputBlock` as the only input, the
validator forced it to keep a throwaway base alongside — entrenching the
anti-pattern.
2. Running a Drive consumer directly via CoPilot's `run_block` silently
failed because the auto-credentials flow (picker attaches
`_credentials_id`) existed only in the graph executor, never in
CoPilot's direct-execution path.
3. Drive picker guidance lived in `agent_generation_guide.md` instead of
on the blocks themselves, so it duplicated and drifted from the code.
4. Observed in a live session: when asked to read a private sheet,
CoPilot refused with "share publicly or use the builder" instead of
calling `run_block` and letting the picker render — the prompt rule was
buried and the fallback path (omitted required picker field) returned a
generic schema preview.

**What:** Four coordinated platform + CoPilot improvements. No
block-specific validator rules, no Drive-specific code in UI or prompt.

**How:**

#### 1. `validate_io_blocks` subclass support

Accepts any block with `uiType == "Input"` / `"Output"` (populated from
`Block.block_type` at registration). `AgentGoogleDriveFileInputBlock`,
`AgentDropdownInputBlock`, `AgentTableInputBlock`, etc. stand alone.
Base-ID fallback preserved for call sites that pass a minimal blocks
list.

#### 2. Inline picker via `run_block`

- Extracted `_acquire_auto_credentials` from
`backend/executor/manager.py` into shared
`backend/executor/auto_credentials.py` (exports
`acquire_auto_credentials` + `MissingAutoCredentialsError`).
- Wired it into `backend/copilot/tools/helpers.py::execute_block`. When
`_credentials_id` is present, the block executes with creds injected
(chained flows work). When missing/null, `execute_block` returns the
existing `SetupRequirementsResponse` — frontend's `FormRenderer` renders
the picker inline via the existing
`GoogleDrivePickerField`/`GoogleDrivePickerInput`. On pick, the LLM
re-invokes `run_block` with the populated input — same continuation
pattern as OAuth-missing-credentials. No new response types, no new
continuation tool, no new frontend component.
- `run_block` now short-circuits to `SetupRequirementsResponse` when
missing required fields include a picker-backed field, skipping the
schema-preview round trip the LLM would otherwise take.
- `get_inputs_from_schema` spreads the full property schema (`**schema`)
instead of whitelisting — any `format` / `json_schema_extra` / custom
widget config flows through to the generic custom-field dispatch on the
frontend. Future picker formats (date pickers, file pickers, etc.) work
without backend changes.
- Frontend `SetupRequirementsCard/helpers.ts` uses index-signature
passthrough for arbitrary schema keys — no widget-specific code in that
layer.

#### 3. `validate_only` parameter on `run_block`

`run_block(id, {})` is not always a safe probe — for blocks with zero
required inputs, it executes. New `validate_only: true` parameter
returns `BlockDetailsResponse` (schema + missing-input list) without
executing, rendering picker cards, or charging credits. Same response
shape as the existing schema preview — no new branch, just an extra
condition on the existing one. LLM uses this for pre-flight when it's
unsure whether a block has required inputs.

#### 4. Block-local picker guidance

Agent-generation picker guidance relocated from the guide onto the
blocks themselves — surfaced at `find_block` time, exactly when the LLM
decides to wire a picker-backed consumer:

- `GoogleDriveFileField` (shared factory for every Drive field on
Sheets/Docs/etc.) appends a standard hint to the caller's description
covering: feed from the specialized input block, never hardcode (even
one parsed from a URL), picker is the only credential source.
- `AgentGoogleDriveFileInputBlock`'s block description now covers when
it's required, the `allowed_views` mapping, wiring direction, and a
concrete link-shape example.
- `agent_generation_guide.md` loses the dedicated 71-line Drive section.
The IO-blocks section now tells the LLM specialized subclasses satisfy
the requirement and carry their own usage guidance in block/field
descriptions — read them when `find_block` surfaces a match.
- New "Picker-backed inputs via `run_block`" section in the CoPilot
prompt, written generically (picker fields detected via `format` /
`auto_credentials` schema hints, no provider names hardcoded) — covers:
don't ask the user for URLs/IDs, don't refuse private-resource asks,
chained picker objects pass through as-is.
- Sharpened `MissingAutoCredentialsError` message so when a bare ID
reaches execution, the error explicitly tells the LLM the picker renders
inline (not "ask the user for something").

### Changes 🏗️

- `backend/copilot/tools/agent_generator/validator.py` —
`_collect_io_block_ids` + subclass-aware `validate_io_blocks`.
- `backend/executor/auto_credentials.py` (new) — shared
`acquire_auto_credentials` + `MissingAutoCredentialsError`.
- `backend/executor/manager.py` — imports from the shared module, drops
the local copy.
- `backend/copilot/tools/helpers.py` — `execute_block` calls
`acquire_auto_credentials`, merges kwargs, releases locks in `finally`,
returns `SetupRequirementsResponse` on missing creds.
`get_inputs_from_schema` spreads the full property schema.
- `backend/copilot/tools/run_block.py` — picker-field short-circuit +
`validate_only` parameter.
- `backend/copilot/prompting.py` — "Picker-backed inputs via
`run_block`" + "Pre-flight with `validate_only`" sections.
- `backend/blocks/google/_drive.py` — `GoogleDriveFileField` appends the
agent-builder hint to every Drive consumer's description.
- `backend/blocks/io.py` — `AgentGoogleDriveFileInputBlock` description
expanded.
- `backend/copilot/sdk/agent_generation_guide.md` — Drive section
removed, IO-blocks subclass note expanded.
- `frontend/.../SetupRequirementsCard/helpers.ts` — index-signature
passthrough for arbitrary schema keys; schema fields propagate into the
generated RJSF schema.
- Tests: new `TestExecuteBlockAutoCredentials` (4 cases) +
`validate_only` + picker-short-circuit cases in `run_block_test.py`;
`manager_auto_credentials_test.py` moved to new import path; 6 new
frontend cases in `SetupRequirementsCard/__tests__/helpers.test.ts`
covering schema passthrough.
- Also: one-line hoist of `import secrets` in
`backend/integrations/managed_providers/ayrshare.py` — ruff E402
introduced by #12883 was blocking our lint post-merge.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Backend unit suites: validator_test (48), helpers_test (40),
run_block_test (19), manager_auto_credentials_test (15) — **all green**
- [x] Frontend `SetupRequirementsCard` helpers — **75/75 pass**
(including 6 new passthrough cases)
- [x] `poetry run format` (ruff + isort + black) clean on touched files
(pre-existing pyright errors in unrelated `graphiti_core` /
`StreamEvent` / etc. files not introduced by this PR)
- [x] Live CoPilot chat on dev-builder confirmed the setup card renders
`custom/google_drive_picker_field` for a Drive consumer block called via
`run_block`
- [x] Live agent-generation confirmed CoPilot creates a subclass-only
agent (`AgentGoogleDriveFileInputBlock` → `GoogleSheetsReadBlock` →
`AgentOutputBlock`) with no throwaway base `AgentInputBlock`

#### For configuration changes:
- [x] N/A — no config changes

---------

Co-authored-by: majdyz <zamil.majdy@agpt.co>
2026-04-24 13:05:11 +07:00
Zamil Majdy
cc1f692fec feat(platform): add MAX tier + LD-configurable pricing + hide unconfigured tiers (#12903)
## What

Introduces a new `MAX` tier slot between `PRO` and `BUSINESS`
(self-service $320/mo at 20× capacity), routes every self-service tier's
Stripe price ID through LaunchDarkly, and hides tiers from the UI when
their price isn't configured. `BUSINESS` stays in the enum at 60× as a
reserved/future self-service slot (hidden by default until its LD price
flag is set). ENTERPRISE stays admin-managed.

## Tier shape after this PR

| Enum | UI label | Multiplier | LD price flag | Surfaced in UI by
default |
|---|---|---|---|---|
| `FREE` | Basic | 1× | `stripe-price-id-basic` | no (flag unset) |
| `PRO` | Pro | 5× | `stripe-price-id-pro` | yes (already live) |
| `MAX` **(new)** | Max | 20× | `stripe-price-id-max` | no (flag unset
until $320 price ready) |
| `BUSINESS` | Business | 60× | `stripe-price-id-business` | no
(reserved / future) |
| `ENTERPRISE` | — | 60× | — (admin-managed) | no (Contact-Us only) |

## Prisma

- Added `MAX` between `PRO` and `BUSINESS` in `SubscriptionTier`.
- Migration `add_subscription_tier_max/migration.sql` uses `ALTER TYPE
... ADD VALUE IF NOT EXISTS 'MAX' BEFORE 'BUSINESS'` (transactional
since PG 12). No data migration — no rows currently on BUSINESS via
self-service flows.

## Backend

- `get_subscription_price_id` flag map covers
`FREE`/`PRO`/`MAX`/`BUSINESS`. ENTERPRISE returns `None`.
- `GET /credits/subscription.tier_costs` only includes tiers whose LD
price ID is set. Current tier always present as a safety net.
- `POST /credits/subscription` routes by LD-resolved prices instead of
hard-coding `tier == FREE`:
- Target `FREE` + `stripe-price-id-basic` unset → legacy
cancel-at-period-end (unchanged behaviour).
- Target has LD price → modify in-place when user has an active sub,
else Checkout Session.
- Priced-FREE users with no sub fall through to Checkout (admin-granted
DB-flip shortcut gated on `current_tier != FREE`).
- `sync_subscription_from_stripe` + `get_pending_subscription_change`
cover FREE/PRO/MAX/BUSINESS in the price-to-tier map so every tier's
Stripe webhook reconciles cleanly.
- Pending-tier mapping collapsed into a single membership check.
- `TIER_MULTIPLIERS`: `FREE=1, PRO=5, MAX=20, BUSINESS=60,
ENTERPRISE=60`.

## Frontend

- UI labels: FREE→"Basic", MAX→"Max", BUSINESS→"Business" (PRO
unchanged). `TIER_ORDER` now `[FREE, PRO, MAX, BUSINESS, ENTERPRISE]`.
- `SubscriptionTierSection` filters by `tier_costs` — any tier without a
backend-provided price is hidden (current tier always visible).
- `formatCost` surfaces "Free" only when `FREE` is actually `$0`;
non-zero `stripe-price-id-basic` renders `$X.XX/mo`.
- Admin rate-limit display lists all five tiers with multiplier badges.

## LaunchDarkly flag actions (operator)

- **New:** `stripe-price-id-basic` → FREE tier. Set to `""` or a `$0`
Stripe price.
- **New:** `stripe-price-id-max` → MAX tier. Point at the `$320` Stripe
price when you launch the Max tier.
- **Unchanged:** `stripe-price-id-pro` (PRO), `stripe-price-id-business`
(BUSINESS — leave unset until you're ready for the 60× Business tier).
- Base rate limits stay on `copilot-daily-cost-limit-microdollars` /
`copilot-weekly-cost-limit-microdollars` (Basic's limit; everything else
= × tier multiplier).

## Out of scope

- Subscription-required onboarding screen / middleware gating (separate
PR).
- "Pricing available soon" vs Stripe-failure disambiguation in the UI
(follow-up).

## Testing

- Backend: 213 tests across `subscription_routes_test.py`,
`credit_subscription_test.py`, `rate_limit_test.py`,
`admin/rate_limit_admin_routes_test.py` — all passing.
- Frontend: 91 tests across `credits/` + `admin/rate-limits/` — all
passing.
- Fresh-backend manual E2E on the pre-MAX commit confirmed tier-hiding
works (`tier_costs` returns only the current tier when LD flags are
unset).

## Checklist

- [x] I have read the project's contributing guide.
- [x] I have clearly described what this PR changes and why.
- [x] My code follows the style guidelines of this project.
- [x] I have added tests that prove my fix is effective or that my
feature works.
- [ ] New and existing unit tests pass locally with my changes (CI will
confirm).
2026-04-24 11:11:33 +07:00
Zamil Majdy
be61dc4304 fix(backend): use {schema_prefix} in raw SQL migrations instead of hardcoded 'platform.' (#12905)
### Why / What / How

**Why.** Backend CI was failing at startup with `relation
"platform.AgentNode" does not exist`. Prisma's `migrate deploy` uses the
`schema.prisma` datasource, which doesn't declare a schema, so when
`DATABASE_URL` has no `?schema=platform` query param (as in CI / raw
Supabase), Prisma creates tables in `public` — but the lifespan
migration `backend.data.graph.migrate_llm_models` hardcoded
`platform."AgentNode"` in its raw SQL and crashed the boot.

**What.** Switched `migrate_llm_models` to use the
`execute_raw_with_schema` helper and the `{schema_prefix}` placeholder —
the same pattern already used by the sibling
`fix_llm_provider_credentials` migration in the same file. The helper in
`backend/data/db.py` reads the schema from `DATABASE_URL` at runtime and
substitutes `"platform".` or an empty prefix, so the query works in both
dev (schema=platform) and CI / raw Supabase (public).

**How.**
- Template change: `UPDATE platform."AgentNode"` → `UPDATE
{{schema_prefix}}"AgentNode"` (f-string double-brace escape so
`{schema_prefix}` survives to `.format()` inside
`execute_raw_with_schema`).
- Replace `db.execute_raw(...)` with `execute_raw_with_schema(...)`;
drop the now-unused `prisma as db` import.
- Regression test: mocks `execute_raw_with_schema` and asserts every
emitted query contains `{schema_prefix}` and no longer contains
`platform."AgentNode"`.

### Audit

Audited the other three lifespan migrations in
`backend/api/rest_api.py::lifespan_context`:
- `backend.data.user.migrate_and_encrypt_user_integrations` — uses
Prisma ORM, no raw SQL. OK.
- `backend.data.graph.fix_llm_provider_credentials` — already uses
`query_raw_with_schema` + `{schema_prefix}`. OK.
- `backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs`
— uses Prisma ORM, no raw SQL. OK.

Also grepped the whole backend for `platform."` in Python files —
`migrate_llm_models` was the only offender; the other hits were
unrelated string content (docstrings, error messages, test data).

### Changes

- `autogpt_platform/backend/backend/data/graph.py`: `migrate_llm_models`
now uses `execute_raw_with_schema` with the `{schema_prefix}`
placeholder; unused `prisma as db` import dropped.
- `autogpt_platform/backend/backend/data/graph_test.py`: added
`test_migrate_llm_models_uses_schema_prefix_placeholder` regression
test.

### Checklist

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Ran `migrate_llm_models` under mocked `execute_raw_with_schema` —
all 7 emitted UPDATE queries contain `{schema_prefix}` and none hardcode
`platform."AgentNode"`.
- [x] Verified the f-string double-brace escape by evaluating the
template and running `.format(schema_prefix=...)` — substitution is
correct for both `"platform".` and empty-prefix (public-schema) cases.
- [x] `poetry run pyright backend/data/graph.py` clean (pre-existing
pyright error on `backend/api/features/v1.py:834` on `origin/dev` is
unrelated).
- [x] Grepped the whole backend for other hardcoded `platform."..."`
raw-SQL occurrences — none found.

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
(N/A — no config changes)
- [x] `docker-compose.yml` is updated or already compatible with my
changes (N/A — no config changes)
2026-04-24 10:00:22 +07:00
Zamil Majdy
575f75edf4 refactor(platform): migrate Ayrshare to standard managed-credential flow (#12883)
## Why

Beta user report: AutoPilot told them to sign up for Ayrshare themselves
— which AutoGPT actually manages — because AutoPilot inferred the
requirement from the block description string rather than any structured
schema. Root cause: Ayrshare was the only block family whose
"credential" lived in a bespoke
`UserIntegrations.managed_credentials.ayrshare_profile_key` side channel
and whose blocks declared **no** `credentials` field. `find_block` /
`resolve_block_credentials` had nothing to show the LLM, so the LLM
guessed.

(An initial commit added a runtime `gh` CLI bootstrap for a separate "gh
isn't installed in the sandbox" report — that work was empirically
verified unnecessary and reverted; see the commit history for the bench
results.)

## What

**Ayrshare now goes through the standard managed-credential flow:**

- New `AyrshareManagedProvider` alongside the existing
`AgentMailManagedProvider`. Provisions the per-user profile as
`APIKeyCredentials(provider="ayrshare", is_managed=True)` via the shared
`add_managed_credential` path. Reuses any legacy
`managed_credentials.ayrshare_profile_key` value on first provision so
existing users keep their linked social accounts.
- `AyrshareManagedProvider.is_available()` returns `False` so the
`ensure_managed_credentials` startup sweep **never** auto-provisions
Ayrshare (profile quota is a real per-user subscription cost). New
public `ensure_managed_credential(user_id, store, provider)` helper lets
the `/api/integrations/ayrshare/sso_url` route provision on demand,
reusing the same distributed Redis lock + upsert path as AgentMail.
- New `ProviderBuilder.with_managed_api_key()` method registers
`api_key` as a supported auth type without the env-var-backed default
credential that `with_api_key()` creates — so the org-level Ayrshare
admin key cannot leak to blocks as a "profile key".
- `BaseAyrshareInput` gains a shared `credentials` field; all 13 social
blocks inherit it. Each `run()` now takes `credentials:
APIKeyCredentials`; the inline `get_profile_key` guard + "please link a
social account" error is gone. Standard `resolve_block_credentials`
pre-run check owns the "not connected" path, returning a normal
`SetupRequirementsResponse`.
- **Migration-ordering safety:** `post_provision` hook on
`ManagedCredentialProvider` clears the legacy `ayrshare_profile_key`
field **only after** `add_managed_credential` has durably stored the
managed credential. If persistence fails, the legacy key stays intact so
a retry can reuse it — covered by `TestMigrationOrderingSafety`.
- New public `IntegrationCredentialsStore.get_user_integrations()` —
reads no longer have to reach past the `_get_user_integrations` privacy
fence or abuse `edit_user_integrations` as a pseudo-read.
- `/api/integrations/ayrshare/sso_url` collapses from a 60-line
provision-then-sign dance to: pre-flight `settings_available()`,
`ensure_managed_credential`, fetch the credential, sign a JWT.
- `IntegrationCredentialsStore.set_ayrshare_profile_key` removed — the
managed credential is now the only write path.
- Legacy `UserIntegrations.ManagedCredentials.ayrshare_profile_key`
field is retained so the managed provider can migrate existing users on
first provision; removing the field is a follow-up once rollout has
propagated.

## How

After this PR, `find_block` returns Ayrshare blocks with a structured
`credentials_provider: ['ayrshare']`. AutoPilot sees the credential
requirement the same way it sees GitHub's or AgentMail's, calls
`run_block`, and gets a plain `SetupRequirementsResponse` when the
managed credential has not been provisioned yet. No more
description-string speculation; the whole Ayrshare flow is the normal
flow.

The Builder's `AyrshareConnectButton` (`BlockType.AYRSHARE`) still works
— it hits the same endpoint, now a thin wrapper over the managed
provider — so users still get the "Connect Social Accounts" popup for
OAuth'ing individual social networks.

## Test plan

- [x] `poetry run pytest backend/blocks/test/test_block.py -k "ayrshare
or PostTo"` — 26/26 pass.
- [x] `poetry run pytest
backend/integrations/managed_providers/ayrshare_test.py` — 10/10 pass.
- [x] `poetry run pytest
backend/api/features/integrations/router_test.py` — 21/21 pass.
- [x] `poetry run pyright` on all touched backend files — 0 errors.
- [x] Runtime sanity: `find_block` on `PostToXBlock` lists
`credentials_provider: ['ayrshare']` in the JSON schema.
- [ ] Manual QA in preview: connect social account via Builder's
"Connect Social Accounts" button → post to X via CoPilot end-to-end.
- [ ] Verify existing users with
`managed_credentials.ayrshare_profile_key` continue to work without
re-linking.
2026-04-24 09:37:38 +07:00
Zamil Majdy
0f6eea06c4 feat(platform/backend): dynamic BlockCostType (SECOND/ITEMS/COST_USD/TOKENS) + E2B/FAL migration (#12894)
## Why

PR #12893 shipped flat-floor credit charges so no provider sits
wallet-free. This PR is the next step: make dynamic pricing actually
dynamic. Blocks that scale with walltime, item count, provider-reported
USD, or token volume now get billed based on captured execution stats
instead of a fixed floor.

Before this PR `BlockCostType` only had `RUN` / `BYTE` / `SECOND`, and
`SECOND` was dead code — no caller ever passed `run_time > 0`, so every
per-second entry evaluated to 0. This PR wires the stats plumbing
through, adds the cost-type variants that cover the real billing models
our providers charge on, and migrates blocks across the codebase to use
them.

## What

### Machinery

- `BlockCostType` gains `ITEMS`, `COST_USD`, `TOKENS`. `BlockCost` gains
`cost_divisor: int = 1` so SECOND/ITEMS/TOKENS can express "1 credit per
N units" without fractional amounts.
- `block_usage_cost(..., stats: NodeExecutionStats | None = None)` —
pre-flight (no stats) dynamic types return 0 so the balance check isn't
blocked on unknown-future cost; post-flight (stats populated) they
consume captured execution stats.
- `TokenRate` model + `TOKEN_COST` table (~60 models: Claude family,
GPT-5 family, Gemini 2.5, Groq/Llama, Mistral, Cohere, DeepSeek, Grok,
Kimi, Perplexity Sonar). Rates are credits per 1M tokens with input /
output / cache-read / cache-creation split.
- `compute_token_credits(input_data, stats)` — reads
`stats.input_token_count / output_token_count / cache_read_token_count /
cache_creation_token_count`, multiplies by `TOKEN_COST[model]`, ceils to
integer credits. Falls back to flat `MODEL_COST[model]` for unmapped
models (no silent under-billing).
- `billing.charge_reconciled_usage(node_exec, stats)` — runs
post-flight, charges positive delta / refunds negative delta. RUN-only
blocks produce zero delta (no-op). Swallows `InsufficientBalanceError` +
unexpected errors so reconciliation never poisons the success path.
- Pre-flight balance guard — dynamic-cost blocks (0 pre-flight charge)
are blocked when the wallet is non-positive. Closes Sentry `r3132206798`
(HIGH).
- Reconciliation fires `handle_low_balance` on positive delta so users
still get alerted after post-flight reconciliation.

### Block migrations — cost-type changes

| Provider / block family | Old | New | Cost type |
|---|---|---|---|
| All LLM blocks (Anthropic / OpenAI / Groq / Open Router / Llama API /
v0 / AIML, via `LLM_COST` list) | RUN, flat per-model from `MODEL_COST`
| `TOKEN_COST` per-token rate table (input / output / cache-read /
cache-creation) | **TOKENS** |
| Jina `SearchTheWebBlock` | RUN, 1 cr | 100 cr / $ (≈ 1 cr per $0.01
call) | **COST_USD** |
| ZeroBounce `ValidateEmailsBlock` | RUN, 2 cr | 250 cr / $ (≈ 2 cr per
$0.008 validation) | **COST_USD** |
| Apollo `SearchOrganizationsBlock` | RUN, 2 cr flat | 1 cr / 2 orgs
(divisor=2) | **ITEMS** |
| Apollo `SearchPeopleBlock` (no enrich) | RUN, 10 cr flat | 1 cr /
person | **ITEMS** |
| Apollo `SearchPeopleBlock` (enrich_info=true) | RUN, 20 cr flat | 2 cr
/ person | **ITEMS** |
| Firecrawl (all blocks — Crawl, MapWebsite, Search, Extract, Scrape,
via `ProviderBuilder.with_base_cost`) | RUN, 1 cr | 1000 cr / $ (1 cr
per Firecrawl credit ≈ $0.001) | **COST_USD** |
| DataForSEO (KeywordSuggestions, RelatedKeywords, via `with_base_cost`)
| RUN, 1 cr | 1000 cr / $ | **COST_USD** |
| Exa (~45 blocks, via `with_base_cost`) | RUN, 1 cr | 100 cr / $ (Deep
Research $0.20 → 20 cr) | **COST_USD** |
| E2B `ExecuteCodeBlock` / `InstantiateCodeSandboxBlock` /
`ExecuteCodeStepBlock` | RUN, 2 cr flat | 1 cr / 10 s walltime
(divisor=10) | **SECOND** |
| FAL `AIVideoGeneratorBlock` | RUN, 10 cr flat | 3 cr / walltime s |
**SECOND** |

### Cost-leak fixes — interim values (flagged 🔴 CONSERVATIVE INTERIM in
Notion)

Separate from the type migrations above, these 3 providers had real API
costs but were under-billed (or wallet-free):

| Provider / block | Old | New | Cost type | Plan for proper fix |
|---|---|---|---|---|
| Stagehand (`StagehandObserve` / `Act` / `Extract`, via
`with_base_cost`) | RUN, 1 cr | 1 cr / 3 walltime s (divisor=3) |
**SECOND** | Have blocks emit `provider_cost` USD (session_seconds ×
$0.00028 + real LLM USD) → migrate to `COST_USD 100 cr/$`. |
| Meeting BaaS `BaasBotJoinMeetingBlock` (via `@cost` decorator
override) | RUN, 5 cr | RUN, 30 cr | RUN | Surface meeting duration on
`FetchMeetingData` response → migrate Join to `SECOND` or `COST_USD`
post-flight. |
| AgentMail (~37 blocks, via `with_base_cost`) | **0 cr (unbilled)** |
RUN, 1 cr | RUN | Revisit when AgentMail publishes paid-tier pricing
(currently beta). |

### UI

- `NodeCost.tsx` dynamic labels: RUN → `N /run`, SECOND → `~N /sec` (or
`~N / Xs` with divisor), ITEMS → `~N /item` (or `/ X items`), COST_USD →
`~N · by USD`, TOKENS → `~N · by tokens` (tooltip explains cache
discount).
- Floor amounts prefixed with `~` for dynamic types so users see an
estimate, not a hard guarantee.

## How

The resolver split is the key design decision. Instead of charging the
"true" cost entirely post-flight (which would let a user burn credits
they don't have), pre-flight returns a safe estimate:
- RUN: full `cost_amount` (same as before — backwards compatible).
- SECOND/ITEMS/COST_USD: `0` when stats aren't populated yet.
- TOKENS: `MODEL_COST[model]` as a flat floor from the existing rate
table.

Post-flight, the executor calls `charge_reconciled_usage`, which
evaluates the same resolver with stats and charges the positive delta
(or refunds the negative delta). RUN blocks get a 0-delta no-op; dynamic
blocks get their actual charge. Failure modes are bounded: insufficient
balance is logged (not raised; reconciliation must never poison a
success), unexpected errors are swallowed and alerted via Discord.

TOKENS routes through a dedicated `compute_token_credits` helper so the
rate table (`TOKEN_COST`) can grow organically without touching resolver
logic. Models not yet in `TOKEN_COST` fall back to the flat `MODEL_COST`
tier.

Migration for providers with a real USD spend (Exa, Firecrawl,
DataForSEO, Jina Search, ZeroBounce) is a one-line `_config.py` change
via the extended `ProviderBuilder.with_base_cost`. Each block's `run()`
populates `provider_cost` from the response (Exa's `cost_dollars.total`,
Firecrawl's `credits_used`, etc.) via `merge_stats`, and the post-flight
resolver multiplies by `cost_amount` credits/$.

## Test plan

- [x] 92/92 cost-pipeline tests pass — `block_usage_cost_test.py`,
`billing_reconciliation_test.py`, `manager_cost_tracking_test.py`,
`block_cost_config_test.py`.
- [x] Deep E2E against live stack (real DB, `database_manager` RPC): 8/8
scenarios pass — RUN pre-flight, dry-run no-charge, TOKENS refund, ITEMS
scaling, ITEMS zero-items short-circuit, COST_USD exact + ceil
semantics, pre-flight balance guard. Report:
https://github.com/Significant-Gravitas/AutoGPT/pull/12894#issuecomment-4307672357
- [x] `poetry run ruff check` / `ruff format` / `pnpm format` / `pnpm
lint` / `pnpm types` — clean.
- [x] Manual UI: `NodeCost.tsx` renders `~N · by tokens` for
AITextGeneratorBlock, `~N · by USD` for Jina/Exa/Firecrawl.

## Follow-ups (not in this PR)

- Stagehand / Meeting BaaS / Ayrshare: expose provider-side unit cost
(session-seconds, meeting duration, platform analytics credits) to
migrate from interim flat/walltime to fully dynamic `COST_USD`.
- Replicate / Revid: walltime-based billing once response cost is piped
through.
- AgentMail: final rate once paid tier is published.
2026-04-24 08:45:39 +07:00
Zamil Majdy
43b38f6989 fix(backend/copilot): surface non-zero E2B exits as real results, not sandbox errors (#12904)
## Why

`gh auth status` looked flaky in the E2B sandbox. Not actually flaky: it
fails deterministically when the user has not connected GitHub (or the
token is missing/expired), and our wrapper disguises that legitimate
exit-1 as a sandbox infrastructure failure.

Root cause: E2B's `sandbox.commands.run()` raises `CommandExitException`
for **any** non-zero exit. We caught it as a generic `Exception` and
returned an `ErrorResponse` with message:

```
E2B execution failed: Command exited with code 1 and error:
{stderr}
```

When the model runs `gh auth status 2>&1`, stderr is redirected to
stdout — so `exc.stderr` is empty **and** `exc.stdout` (which carries
the real info, e.g. "You are not logged into any GitHub hosts") is
discarded. The model sees a generic infra failure, can't tell it's an
auth-check signal, and prompts the user with broken-looking errors
instead of calling `connect_integration(provider="github")`.

Compare: the local bubblewrap path already handles non-zero exits
correctly by returning a `BashExecResponse` with `exit_code` set. The
E2B path was asymmetric.

## What

- Import `CommandExitException` and catch it explicitly in
`_execute_on_e2b` before the generic handler.
- Return a `BashExecResponse` with the real `exit_code`, `stdout`, and
`stderr` from the exception (scrubbed of injected secret values, same as
the success path).
- Extract shared scrub/build logic into `_build_response` to avoid
duplicating it across the success and exit-exception branches.
- Keep `TimeoutException` and the catch-all `except Exception` for real
infra failures.

## How

Result shape now matches bubblewrap: non-zero exit is a valid result,
not an error. The model sees:

```
message: "Command executed with status code 1"
exit_code: 1
stdout: "You are not logged into any GitHub hosts. ..."
stderr: ""
```

instead of the prior cryptic "E2B execution failed" message.

## Test plan

- [x] New unit test `test_nonzero_exit_returned_as_bash_exec_response`
in `bash_exec_test.py` — mocks `sandbox.commands.run` to raise
`CommandExitException`, asserts `BashExecResponse` with correct
`exit_code`, and verifies secret scrubbing on both `stdout` and
`stderr`.
- [x] `poetry run pytest backend/copilot/tools/bash_exec_test.py` — 5
passed.
- [x] `poetry run pyright` on changed files — 0 errors.
- [x] `poetry run ruff` — clean.
2026-04-24 07:49:57 +07:00
Nicholas Tindle
10e421cd3e fix(platform): resolve autopilot beta blockers (SECRT-2266/2267/2268/2269) (#12874)
### Why / What / How

**Why:** A beta user spent significant time trying to build and run
agents that read Google Sheets. Four separate failures compounded on
their session — all already open in Linear as SECRT-2266 through
SECRT-2269. Three in-flight PRs each addressed a piece but conflicted on
the same files (`backend/data/model.py`, `backend/blocks/_base.py`,
`autogpt_libs/.../types.py`), so landing them individually would have
been churn. One of the four reported issues (the credential-delete
crash) is also the top unresolved Sentry issue `AUTOGPT-SERVER-6HB` with
100+ events going back to 2025-10-20 — it was archived as "ignored" but
is a real regression. Bug #4 required new work; the others we got by
adopting the existing open PRs and addressing a pending review comment.

**What:** This PR consolidates the three in-flight PRs, adds the two
pieces of new work needed to fully close the beta blockers, and
addresses the pending review on one of the three PRs so it doesn't
require a second round.

- **Closes PR #12004** — Google Drive auto-credentials handling (merged
in)
- **Closes PR #12748** — Incremental OAuth for scope upgrades (merged
in)
- **Closes PR #12588** — superseded by the systemic None-guard here (see
"How" below)
- **Adds Bug 2 fix** — Google credential deletion no longer crashes on
`revoke_tokens`
- **Adds Bug 4 validator** — the agent builder can no longer save a
graph with a hardcoded Drive file ID

**How:**

1. **Adopt PR #12004 (Bug 1 — auto-credentials resolution).** Tags
Drive-file fields as `is_auto_credential` on `CredentialsFieldInfo`,
exposes `BlockSchema.get_auto_credentials_fields()` and
`Graph.regular_credentials_inputs` / `auto_credentials_inputs`, extracts
`_acquire_auto_credentials()` in the executor to resolve embedded
`_credentials_id` at run time, clears `_credentials_id` on agent fork so
cloned agents don't inherit the original author's credential, and fixes
the Firefox referrer policy on the Google Drive picker script load.

2. **Adopt PR #12748 (Bug 3 — credential accumulation).** OAuth callback
now merges scopes into an existing credential (explicit via
`credential_id` in OAuth state, or implicit via `provider + username`
match) instead of appending a new row on every reconnect. GitHub's
non-incremental OAuth path requests the union of existing + new scopes
at login so the upgrade path works there too.

3. **Replace PR #12588 with a systemic None-guard (addresses reviewer
feedback).** The original PR added a per-block `credentials:
GoogleCredentials | None = None` + early guard pattern that would need
to be repeated across 50+ blocks with `GoogleDriveFileField`. Per the
reviewer's ask, we moved the guard into `Block._execute()` once: after
the `setdefault` loop, if `kwargs[kwarg_name] is None` we raise
`BlockExecutionError` with a clean user-facing message. The per-block
change in `sheets.py` is dropped so `credentials: GoogleCredentials`
stays non-`Optional`. Dry-run path skips the guard (executor
intentionally runs blocks without resolved creds for schema validation).

4. **Fix Bug 2 — Google revoke_tokens (SECRT-2267,
AUTOGPT-SERVER-6HB).** `revoke_tokens()` was handing our Pydantic
`OAuth2Credentials` into google-auth's `AuthorizedSession`, which calls
`self.credentials.before_request(...)` on the object and crashes with
`AttributeError: 'OAuth2Credentials' object has no attribute
'before_request'`. Google's token revoke endpoint doesn't need any auth
header — just `token=<token>` in the form body per [Google's
docs](https://developers.google.com/identity/protocols/oauth2/web-server#tokenrevoke).
Switched to the platform's async `Requests` helper, matching how
`reddit.py` / `github.py` / `todoist.py` / other providers do
revocation. No google-auth objects involved.

5. **Fix Bug 4 — hardcoded Drive file IDs in agent graphs
(SECRT-2269).** Evidence from the beta user's session: CoPilot's
agent-builder produced 13 saved graph versions in one session where each
one stuffed either a bare string (`"1KAv…"`) or a partial object
(`{"id": "1KAv…"}`) into
`GoogleSheetsReadBlock.constantInput.spreadsheet`, never wiring an
`AgentGoogleDriveFileInputBlock` as the intended input. Bare-string
versions failed pydantic validation with `is not of type 'object'`;
object-with-only-`id` versions would have crashed at run time because
`_acquire_auto_credentials` has no `_credentials_id` to resolve. Added a
validator in `GraphModel._validate_graph_get_errors` that flags any
auto-credentials field whose `input_default.<field>` is a bare string OR
a dict missing `_credentials_id`, when there's no upstream link feeding
the field. Remediation text is format-aware: when
`field_schema["format"] == "google-drive-picker"` it names
`AgentGoogleDriveFileInputBlock` specifically; for any other future
auto-credentials format (OneDrive / Dropbox / etc.) the remediation is
generic, so we don't ship a stale Google-specific hint that doesn't
apply.

A companion handoff for the CoPilot agent-builder team is drafted at
`/tmp/agent-builder-ticket-drive-file-input.md` (to be filed in their
tracker). The validator here is a safety net so reviewers and the LLM
both get a clear error with the correct remediation; the agent-builder
itself still needs to learn the correct pattern so it stops trying to
hardcode Drive files in the first place.

### Changes 🏗️

**Backend**

- `backend/data/model.py` — merged `is_auto_credential` +
`input_field_name` (#12004) with `OAuthState.credential_id` (#12748);
kept HEAD's defensive `set()` copy on `discriminator_values`.
- `backend/blocks/_base.py` — `_execute()` runs the auto-credentials
setdefault loop + raises `BlockExecutionError` when a resolved value is
`None`.
- `backend/blocks/google/sheets_test.py` — 2 new tests (systemic
None-guard behaviour).
- `backend/blocks/google/_drive.py`, `_drive_test.py` — unchanged on
this branch (earlier bare-string validator was reverted after feedback;
see "Out of scope" below).
- `backend/data/graph.py` — auto-credentials anti-pattern validator in
`_validate_graph_get_errors`.
- `backend/data/graph_test.py` — 11 new tests for the validator.
- `backend/integrations/oauth/google.py` — `revoke_tokens` swapped to
`Requests().post`, removed `AuthorizedSession` misuse.
- `backend/integrations/oauth/google_test.py` — 3 new tests covering the
revoke happy path, no-access-token, and non-2xx-response.
- `backend/integrations/credentials_store.py` — from #12748.
- `backend/api/features/integrations/router.py` — incremental-OAuth
callback + scope upgrade helpers (from #12748).
- `backend/api/features/integrations/incremental_oauth_test.py` — 15
tests (from #12748).
- `backend/api/features/chat/tools/utils.py` → renamed to
`backend/copilot/tools/utils.py` during merge; now uses
`regular_credentials_inputs` for missing-creds + matching (from #12004).
- `backend/copilot/tools/utils_test.py` — moved from
`api/features/chat/tools/`, import paths updated.
- `backend/api/features/library/db.py` — library preset guard uses
`regular_credentials_inputs` (from #12004).
- `backend/data/graph.py` — `regular_credentials_inputs` /
`auto_credentials_inputs` properties + `_reassign_ids` clears
`_credentials_id` on fork (from #12004).
- `backend/executor/manager.py` — `_acquire_auto_credentials()`
extracted + validation (from #12004).
- `backend/executor/utils.py`, `utils_test.py`,
`manager_auto_credentials_test.py` — auto-credentials tests (from
#12004).

**Frontend**

- `frontend/src/components/contextual/GoogleDrivePicker/helpers.ts` —
Firefox referrer fix (from #12004).
-
`frontend/src/components/contextual/CredentialsInput/useCredentialsInput.ts`,
`src/hooks/useCredentials.ts`, `src/lib/autogpt-server-api/client.ts`,
`src/providers/agent-credentials/credentials-provider.tsx`,
`src/app/api/openapi.json` — incremental-OAuth scope upgrade UI (from
#12748).

**Shared libs**

- `autogpt_libs/supabase_integration_credentials_store/types.py` —
merged additions from both #12004 and #12748.

### Test plan 📋

- [x] `poetry run lint` — clean
- [x] `poetry run pytest backend/data/graph_test.py` — 55 passed
including 11 new validator tests
- [x] `poetry run pytest backend/integrations/oauth/google_test.py` — 3
new tests passing
- [x] `poetry run pytest backend/blocks/google/sheets_test.py` — 2 new
tests passing
- [x] `poetry run pytest backend/blocks/google/
backend/integrations/oauth/ backend/executor/ backend/data/graph_test.py
backend/api/features/integrations/ backend/copilot/tools/utils_test.py`
— 250 passed, 6 pre-existing failures that require the docker stack
(RabbitMQ/Redis/Postgres) and fail identically on `origin/dev`
- [x] `pnpm format` — clean
- [x] `pnpm lint` — 3 pre-existing `<img>` warnings on files I didn't
touch, no errors
- [x] `pnpm types` — pre-existing errors on `AgentActivityDropdown` that
also fail on `origin/dev` (unrelated to this PR; needs a separate fix on
dev)
- [x] Live repro on dev verified Bug 2 fires against current prod code —
two fresh Sentry events in `AUTOGPT-SERVER-6HB` at 2026-04-21T21:35:54Z
on `app:dev-behave:cloud` matching the exact `DELETE
/api/integrations/google/credentials/{cred_id}` path. Airtable OAuth2
delete as a control worked cleanly, confirming Google-specific.
- [x] Live repro on dev verified Bug 4 (CoPilot direct-run variant) —
`{"spreadsheet": {"id": "..."}}` → `Cannot use file 'None' (type: None)`
from `_validate_spreadsheet_file` mimeType check, as expected.

Reviewer post-merge verification:
- [ ] Delete a Google OAuth credential via the Integrations UI —
succeeds cleanly, no Sentry event fires
- [ ] Connect Google twice (same account, same scopes) — credential
count stays at 1 (dedup)
- [ ] Save an agent graph with
`GoogleSheetsReadBlock.constantInput.spreadsheet = "bare-id"` via API —
graph validator rejects with `AgentGoogleDriveFileInputBlock`
remediation
- [ ] Save an agent graph with `GoogleSheetsReadBlock` whose
`spreadsheet` is fed by an upstream
`AgentGoogleDriveFileInputBlock.result` — validator accepts, agent runs

### Out of scope (for follow-ups)

- **Bug 1 — "Failed to retrieve Google OAuth credentials"** in
`frontend/src/components/contextual/GoogleDrivePicker/useGoogleDrivePicker.ts:163`.
Zero hits for this string in the beta user's Langfuse traces and we
weren't able to reproduce it from a clean flow. Most likely a
stale-credential race condition (delete in another tab, picker queries a
stale React-Query cache). Tracked as a separate task; not blocking.
- **CoPilot first-attempt mimeType retry loop.** Observed on dev:
CoPilot's first call to `GoogleSheetsReadBlock` sends `{"spreadsheet":
{"id": "..."}}` without `mimeType`, hits `_validate_spreadsheet_file`,
retries with mimeType. Costs a round-trip. Two possible fixes (relax
`_validate_spreadsheet_file` to skip when mimeType is `None` and let
Google's API surface the real error; OR extend
`get_auto_credentials_fields` metadata so CoPilot's tool description
prompts it to always include mimeType). Deliberately deferred — fixing
only one of "API caller sends a bare string" or "CoPilot sends an
incomplete object" risked the same auth-ambiguity the bare-string commit
in this branch history hit.
- **CoPilot agent-builder prompt/guide update.** The validator here
produces the correct error message, but the agent-builder model still
needs to learn to use `AgentGoogleDriveFileInputBlock` upfront rather
than discover it through validator retries. Separate handoff ticket
filed.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Touches OAuth credential issuance/upgrade paths and introduces a new
endpoint that returns raw access tokens (scope-gated), plus broad
changes to execution-time credential resolution/validation; mistakes
could impact auth/security or break integrations.
> 
> **Overview**
> Fixes several Google/Drive agent-builder blockers by **supporting
incremental OAuth scope upgrades** and by hardening how
credential-bearing file inputs (“auto-credentials”) are validated,
resolved, and cleared on graph fork.
> 
> On the integrations API, `/{provider}/login` now accepts
`credential_id` and persists it in `OAuthState` to upgrade an existing
OAuth2 credential on callback (explicit upgrade), with an implicit merge
path for same `provider+username`. The callback path now merges
scopes/metadata, preserves ID/title, preserves existing
`refresh_token`/`username` when missing from incremental responses,
blocks upgrades for managed/system credentials, and adds a **new
`/{provider}/credentials/{cred_id}/picker-token` endpoint** to return a
short-lived access token for provider-hosted pickers (currently
allowlisted to Google Drive scopes).
> 
> For auto-credentials, `CredentialsFieldInfo` gains
`is_auto_credential` + `input_field_name`, graphs now expose
`regular_credentials_inputs` vs `auto_credentials_inputs`, and multiple
callers switch from `aggregate_credentials_inputs()` to
`regular_credentials_inputs` so embedded picker credentials aren’t
treated as user-mapped inputs. Execution-time auto-credential
acquisition is extracted into `_acquire_auto_credentials()` with clearer
error handling and lock cleanup; block execution adds a systemic guard
to surface a clean `Missing credentials` error when auto-credentials are
absent.
> 
> Separately fixes Google credential deletion by rewriting
`GoogleOAuthHandler.revoke_tokens()` to use the platform `Requests`
helper (bounded retries) instead of `AuthorizedSession`, and expands
test coverage across these flows (incremental OAuth, picker-token,
auto-credential validation/acquisition, graph validator, and frontend
diagnostics test stubs).
> 
> <sup>Reviewed by [Cursor Bugbot](https://cursor.com/bugbot) for commit
cac36eae9f. Bugbot is set up for automated
code reviews on this repo. Configure
[here](https://www.cursor.com/dashboard/bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-04-23 17:16:30 +00:00
Zamil Majdy
80bfde1ca6 feat(blocks): charge Ayrshare per-post + align Bannerbear/Jina floors (#12893)
## Why

The cost-tracking audit on 2026-04-23 ([Platform System
Credentials](https://www.notion.so/auto-gpt/4d251f343fe146bcb91b6a037d1bfc3c))
surfaced three gaps where the user wallet was silently subsidising
third-party spend:

1. **Ayrshare (13 blocks)** — zero charge on every social post. No
`BLOCK_COSTS` entry, no SDK `.with_base_cost` registration. Platform
absorbs the entire ~$149/mo Business plan.
2. **Bannerbear** — flat 1 credit/call below the ~$0.025/image unit cost
on the Starter tier ($49/mo / 2K images).
3. **JinaChunkingBlock** — wallet-free; siblings (`JinaEmbeddingBlock`,
`SearchTheWebBlock`) are charged.

## What

- New `backend/blocks/ayrshare/_cost.py` with two-tier
`AYRSHARE_POST_COSTS` (5 credits when `is_video=True`, 2 credits
otherwise — first-match wins in `block_usage_cost`).
- All 13 `PostTo*Block` classes decorated with
`@cost(*AYRSHARE_POST_COSTS)`.
- `BannerbearTextOverlayBlock` floor: 1 → 3 credits in
`bannerbear/_config.py`.
- `JinaChunkingBlock` added to `BLOCK_COSTS` with a flat 1-credit floor.
- `cost(...)` decorator generic-ized via `TypeVar`, so pyright retains
`PostToXBlock.Input/Output` narrowing.

## How

Ayrshare uses a decorator-based registration (not a direct `BLOCK_COSTS`
entry) because each `post_to_*.py` block imports from `backend.sdk`, and
`backend.sdk.cost_integration` imports `BLOCK_COSTS` — listing the
blocks in `block_cost_config.py` would create a circular import. The
`@cost` decorator defined in `sdk/cost_integration.py` was already the
approved escape hatch for this exact shape.

cost_filter in `block_usage_cost` already supports boolean-field
matching (see Apollo's `enrich_info` tier), so `{"is_video": True}` and
`{"is_video": False}` select the right tier at execution time.
`is_video` defaults to `False` on `BaseAyrshareInput`, so posts that
omit the field still land on the 2-credit default.

## Test plan

- [x] `poetry run pytest backend/data/block_cost_config_test.py` — new
6-test suite covers Ayrshare video/non-video/default tiers, the
Bannerbear floor, and the Jina chunking floor
- [x] `poetry run pytest backend/executor/manager_cost_tracking_test.py`
— no regressions (45 pre-existing tests still pass)
- [x] `poetry run ruff format` + `poetry run isort` + `poetry run ruff
check --fix`
- [x] `poetry run pyright` on touched files — 0 errors, 0 warnings
(pre-existing `LlmModel.KIMI_K2_*` errors are on dev and unrelated)
- [ ] Manual: run an Ayrshare post through the builder and confirm 2cr
(text/image) vs 5cr (video) charge
2026-04-23 20:39:35 +07:00
Zamil Majdy
81d6e91f37 feat(platform/copilot): message timestamps + accurate thought-for time (#12890)
## Why

The "Thought for 1m 46s" label under assistant replies has been
misleading
because the backend persists the whole-turn wall clock (from turn start
to
stream end) — which includes tool execution, browser sessions, graph
runs,
etc. Users also had no way to see when a message was actually sent /
received.

## What

- **Per-message timestamps** — `ChatMessage.created_at` (already on the
DB row)
is now serialised through the pydantic model and the
`SessionDetailResponse`,
then plumbed into the UI. Hovering the "Thought for X" label now shows
the
  absolute local date/time via a tooltip.
- **Accurate reasoning duration** — new
`ChatMessage.reasoningDurationMs`
  column. Backend accumulates time between `reasoning-start` and
`reasoning-end` SSE events inside `publish_chunk` (via the session meta
hash). `mark_session_completed` reads the total and persists it
alongside
the existing `durationMs`. Frontend prefers `reasoning_duration_ms` when
  present, falls back to `duration_ms` for legacy rows.

## How

- `schema.prisma` gains `reasoningDurationMs Int?`; migration
  `20260423120000_add_reasoning_duration_ms` adds the column.
- `publish_chunk` gains a side-effect that writes `reasoning_started_at`
/
`reasoning_ms_total` into the existing per-session Redis meta hash when
  reasoning events pass through. No extra IO path, no extra Redis key.
- `set_turn_duration` accepts an optional `reasoning_duration_ms` arg
and
  patches both the DB row and the cached session in place, mirroring the
  existing behaviour for `duration_ms`.
- Frontend: `convertChatSessionMessagesToUiMessages` now returns
`durations`, `reasoningDurations`, and `timestamps` maps. `TurnStatsBar`
picks the best available value and wraps the label in the design-system
  `BaseTooltip` so hover reveals the local timestamp.

## Test plan

- [x] `poetry run pytest
backend/copilot/db_test.py::test_set_turn_duration_*`
- [x] `poetry run pytest backend/copilot/stream_registry_test.py`
- [x] `pnpm format` / `pnpm lint` / `pnpm types` (copilot area)
- [x] `pnpm test:unit src/app/\(platform\)/copilot` — 705 tests pass (4
pre-existing `jszip` module resolution failures unrelated to this
change)
- [ ] Manual: open a session with a long tool run and confirm the new
"Thought for X" reflects only reasoning time (falls back for old rows)
      and the tooltip surfaces the local timestamp.
2026-04-23 18:55:34 +07:00
Zamil Majdy
39cdc0a5e0 fix(backend/copilot): tame Kimi compaction storm + tunable threshold + Langfuse cost backfill (#12889)
## Why

Investigation of two reported sessions
([85804387](https://dev-builder.agpt.co/copilot?sessionId=85804387-7708-4fdc-8ec9-64283cdd902d),
[19d69dec](https://dev-builder.agpt.co/copilot?sessionId=19d69dec-210f-4439-a94b-2d7d443b9909))
where Kimi K2.6 via OpenRouter was running ~30 min per turn with no
actions completed (Discord report from Toran). Langfuse traces showed:

- 31 generation calls per turn at p90 = 151s, max = 415s
- 2.57M uncached tokens, `cache_create=0`, ~4% cache_read — Moonshot's
OpenRouter endpoint silently drops Anthropic-style cache writes
- **3 SDK-internal compactions per turn** — each compaction is itself a
slow LLM round-trip
- Reconciled OpenRouter cost was being recorded to a DB row but never
surfaced on the Langfuse trace, leaving operators to grep pod logs

## What

Four commits, split by concern.

### 1. `fix(backend/copilot): skip CLAUDE_AUTOCOMPACT_PCT_OVERRIDE for
Moonshot/Kimi` (`5fd9c5aa`)

`env.py` was unconditionally setting
`CLAUDE_AUTOCOMPACT_PCT_OVERRIDE=50` (introduced in #12747 to cap
cache-creation cost on Anthropic where context >200K = 54% of total
cost). On Kimi where `cache_create=0` silently, the cache-cost rationale
doesn't apply — but the 50% threshold still made the bundled CLI
auto-compact at ~100K tokens, triggering 3+ compactions per turn against
Kimi's larger effective window. Each compaction added a slow LLM
round-trip (one in our test ran 166s and burned the budget cap before
the user got any output).

Threads the resolved `sdk_model` (and `fallback_model`) into
`build_sdk_env` and skips the env var when the model matches
`is_moonshot_model(...)`. The CLI then uses its default ~93% threshold,
cutting compaction passes to 0–1.

### 2. `feat(backend/copilot): backfill OpenRouter reconciled cost to
Langfuse trace` (`f3de3624` + follow-ups `5ce3d038`, `d2c1a2cd`,
`d8e08525`, `d243bf6c9`)

`record_turn_cost_from_openrouter` runs as a fire-and-forget task after
the OTel span closes, so the Langfuse trace UI showed the SDK CLI's
rate-card estimate only — for non-Anthropic OpenRouter routes that
estimate is Sonnet pricing on Kimi tokens (~5x too high).

The backfill captures `langfuse.get_current_trace_id()` and threads it
into the reconcile task, which emits an `openrouter-cost-reconcile`
child event with the authoritative cost + token usage. **Bug caught
during /pr-test:** `propagate_attributes` only annotates an existing
OTel span, it doesn't create one — by the time the `finally` block runs,
SDK-emitted spans have ended and `get_current_trace_id()` returns None.
Fixed in `d8e08525` by wrapping the turn in
`langfuse.start_as_current_span(name="copilot-sdk-turn")`. Also tags
fallback-path events with `cost_source` so operators can distinguish
reconciled vs estimated turns.

### 3. `feat(backend/copilot): expose CLAUDE_AUTOCOMPACT_PCT_OVERRIDE as
a config knob` (`72416f73`)

The previously-hardcoded `50` is now
`claude_agent_autocompact_pct_override` (default 50, env
`CHAT_CLAUDE_AGENT_AUTOCOMPACT_PCT_OVERRIDE`). Setting to 0 omits the
env var entirely so the CLI uses its native ~93% threshold — useful when
the post-compact floor (system prompt + tool defs ≈ 65–110K) sits close
to an aggressive trigger and operators see back-to-back compaction
cascades. Moonshot routes still skip the env var unconditionally
regardless of config.

### 4. `fix(backend/copilot): align SDK retry compaction target with CLI
autocompact threshold` (`730ad256`)

`_reduce_context` was calling `compact_transcript` without an explicit
`target_tokens`, so it fell back to `get_compression_target(model) =
context_window - 60K`. For Sonnet 200K that's 140K — well above the
CLI's PCT=50 trigger of 90K — and for Kimi 256K it's 196K, above the
CLI's default 167K trigger. Result: a successful retry compaction landed
at 140K/196K and the CLI immediately re-compacted on the next call →
**two compactions per recovered turn**.

New `_compaction_target_tokens(model)` mirrors the CLI's `i6_()` formula
(`min(window * pct/100, window - 13K)`) with a 20K safety buffer so the
post-compact context sits comfortably below the CLI's trigger.

## How — empirical validation against the actual long Kimi transcript

Replayed the 199-message transcript from session 85804387 through the
bundled CLI in two configurations:

| | Post-fix (no override) | Pre-fix (`PCT_OVERRIDE=50`) |
|---|---|---|
| `autocompact: tokens=` | 126,312 | 126,341 |
| `threshold=` | **167,000** | **90,000** |
| Decision | 126K < 167K → **skip** | 126K > 90K → **COMPACTION FIRES**
|
| Duration | 21s | **166s** (8x slower) |
| Cost | $0.34 | **$0.82** (2.4x more) |
| Output | PONG (success) | empty (hit $0.50 budget cap, exit 1) |

The pre-fix configuration burned $0.82 of compaction work over 166s and
never produced a user response — exactly the failure mode reported.

**Why cascade happens at 50%, not at 93%:** post-compaction context is
`summary (~5–10K) + system_prompt + tool_definitions + skills + active
TodoWrite + memory ≈ 65–110K floor`. With trigger at 90K, post-compact
floor sits AT or above the trigger → next assistant message tips over →
immediate re-compaction → cascade until the CLI's rapid-refill breaker
trips at 3 attempts. With trigger at 167K, the same floor sits
comfortably below trigger → no cascade.

## Considered but not done

- **Force `cache_control` markers to reach Moonshot**: bundled CLI sends
them by default; Moonshot silently drops them per their own docs (uses
`X-Msh-Context-Cache` headers, not body markers). Real fix needs
bypassing OpenRouter — out of scope.
- **Slim the system prompt + tool definitions** to lower the
post-compact floor: real win but separate refactor with tool-use
accuracy A/B.
- **LD-driven auto-fallback to Sonnet on Kimi degradation**:
`claude_agent_fallback_model` already wires `--fallback-model` for
overload (529); auto-flipping on slowness needs latency aggregation
infra that doesn't exist yet.

## Test plan

- [x] `poetry run pytest backend/copilot/sdk/env_test.py
backend/copilot/sdk/openrouter_cost_test.py
backend/copilot/sdk/service_helpers_test.py` — 111 passed (37 env + 23
cost + 51 helpers, including 6 new env tests, 3 backfill tests, 6 new
compaction-target tests)
- [x] `poetry run pytest backend/copilot/sdk/` — 970+ passed
- [x] `poetry run pyright .` — 0 errors
- [x] `poetry run format` — clean
- [x] /pr-test --fix end-to-end against dev — 5/5 scenarios PASS,
including Anthropic route ($0.0174 cost +0.0% delta) and Moonshot route
($0.028 vs $0.018 → +58.2% delta validates reconcile rationale)
- [x] Transcript replay validation: pre-fix vs post-fix on real
126K-token transcript → 8x slower / 2.4x more expensive / fails entirely
on pre-fix; clean PONG on post-fix
2026-04-23 18:46:35 +07:00
Zamil Majdy
4242da79f0 fix(backend/copilot): raise baseline tool-round limit to 100 + graceful finish hint (#12892)
## Why

On prod, longer copilot runs (complex feature implementations, multi-bug
fix chains) error out with `Exceeded 30 tool-call rounds without a final
response`, lose mid-stream assistant output, and the UI appears to
re-dispatch an older prompt. Reported by @itsababseh in #breakage for
session `661ba0cc-a905-4c66-bf11-61eb5423d775`.

Langfuse trace of that session shows 52 turns / 344 LLM calls; **two
turns hit exactly 30 rounds** (Turn 38: implementing kill-cam/headshot
juice pass; Turn 42: fixing multi-bug list). Both were legitimate,
non-looping work that simply needed more rounds to complete. Round 30
fired `bash_exec`, the loop cut off cold, no summary was ever produced,
and the stream surfaced `baseline_tool_round_limit`. Frontend
subsequently re-dispatched the same user message several times (turns
39–41 × 3, turns 43–47 × 5 with identical prompt), which is what the
user perceives as "falling back into acting on an older command."

Root cause: [`_MAX_TOOL_ROUNDS =
30`](https://github.com/Significant-Gravitas/AutoGPT/blob/cf6d7034f/autogpt_platform/backend/backend/copilot/baseline/service.py#L125)
has been unchanged since the baseline path was introduced (#12276).
Modern agent turns with Claude Code / Kimi / Sonnet routinely need more.

## What

- Raise `_MAX_TOOL_ROUNDS` from 30 → 100.
- Pass `last_iteration_message` to `tool_call_loop` so the final round
receives a "stop calling tools, wrap up" system hint. The model now
produces a graceful summary on the last round instead of being cut off
mid-tool.

## How

Two-line change in
[`backend/copilot/baseline/service.py`](https://github.com/Significant-Gravitas/AutoGPT/blob/fix/copilot-baseline-tool-round-limit/autogpt_platform/backend/backend/copilot/baseline/service.py):
- Bump the module-level constant.
- Define `_LAST_ITERATION_HINT` and wire it via the existing
`last_iteration_message` kwarg on
[`tool_call_loop`](https://github.com/Significant-Gravitas/AutoGPT/blob/cf6d7034f/autogpt_platform/backend/backend/util/tool_call_loop.py#L188).
The shared loop already handles appending it only on the final iteration
(see `tool_call_loop_test.py::test_last_iteration_message_appended`).

Frontend retry cascade on `baseline_tool_round_limit` is a separate UX
issue — logging it as a follow-up.

## Checklist

- [x] My code follows the project's style guidelines
- [x] I have performed a self-review
- [x] Existing `tool_call_loop_test.py` covers `last_iteration_message`
behavior (10/10 passing)
- [x] No new migrations
- [x] No breaking changes (constant/kwarg only)
2026-04-23 18:38:52 +07:00
Zamil Majdy
cf6d7034fa fix(backend/copilot): sync safety net for Redis-induced zombie sessions (#12886)
## Why

A 25-min-old copilot turn ended up a zombie in Redis (`status=running`
for 60+ min, queued user messages never drained) after a rolling deploy
of `autogpt-copilot-executor`. Root cause:

1. Cluster churn during the rollout broke a Redis call mid-turn.
2. `_execute_async`'s `finally` tried to publish the failure via
`mark_session_completed` on the same (now-broken) event loop +
thread-local Redis client.
3. That Redis call *also* failed; the exception was caught and logged
but never reached Redis — so the session meta stayed `running`.
4. `on_run_done` then completed the future normally, `active_tasks`
drained, the pod exited.
5. The zombie persisted until the 65-min stale-session watchdog reaped
it. While it was live, queued-message pushes succeeded (HTTP only checks
`status=running`), so the UI showed "Queued" bubbles that never drained.

## What

The fix is **one small addition** in the per-turn lifecycle:

### `sync_fail_close_session` — last line of defense in
`processor.execute`'s `finally`

Invoked from `CoPilotProcessor.execute()`'s `finally` on every turn
exit. Submits the CAS coroutine to the processor's long-lived
`self.execution_loop` via `asyncio.run_coroutine_threadsafe` — the same
pattern `ExecutionProcessor.on_graph_execution` uses at
[executor/manager.py:881-892](autogpt_platform/backend/backend/executor/manager.py#L881-L892)
to bridge sync→async through `node_execution_loop`.

- Calls `mark_session_completed(session_id,
error_message=SHUTDOWN_ERROR_MESSAGE)`, which is a CAS on `status ==
"running"`. If the async path already wrote a terminal state the CAS
no-ops; otherwise we mark `failed` and the UI transitions cleanly.
- Bounded by inner `asyncio.wait_for(timeout=10s)` and outer
`future.result(timeout=12s)` so a genuinely unreachable Redis can't hang
the safety net.
- Reuses the long-lived execution loop (no per-turn TCP connect, no
`@thread_cached` thrashing).

The outer `future.result()` in `_execute()` is bounded by
`_CANCEL_GRACE_SECONDS` (5s) so a wedged event loop can't trap the flow
before the safety net fires.

### `cleanup()` stays aligned with agent-executor

Mirrors the pattern from `backend.executor.manager.cleanup` — a single
method that:

1. Flags + tells the broker to stop consuming.
2. Passively waits for `active_tasks` to drain (up to
`GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS`).
3. Worker / executor / lock teardown.

No pre-emptive cancellation of healthy turns, no fail-close step for
stuck turns. Same proven shape agent-executor uses.

### Timeout alignment

Raised both `COPILOT_CONSUMER_TIMEOUT_SECONDS` and
`GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS` to 6h so a rolling deploy can let
the longest legitimate turn finish via its own lifecycle path. Matched
in infra at `terminationGracePeriodSeconds: 21600`
(Significant-Gravitas/AutoGPT_cloud_infrastructure#311).

### RabbitMQ policy — deploy prep

The `x-consumer-timeout` queue argument is changing from 1h → 6h. Tested
empirically on dev's RabbitMQ 4.1.4: `queue_declare` is tolerant of
`x-consumer-timeout` mismatches, so no queue delete is needed. To make
the new timeout **immediately effective for running consumers** (so pods
mid-shutdown don't have their consumer cancelled at the old 1h limit),
apply a policy before deploying:

```bash
rabbitmqctl set_policy copilot-consumer-timeout \
  "^copilot_execution_queue$" \
  '{"consumer-timeout": 21600000}' \
  --apply-to queues
```

Already applied on dev. Apply on prod before the PR's prod deploy.

### Incidental rename

- `_clear_pending_messages_unsafe` → `clear_pending_messages_unsafe`
(keeps the `_unsafe` warning suffix; importable without the
leading-underscore private marker).

## How

Before: transient Redis failure → async finally silently fails → zombie
session → queued messages never drain.
After: transient Redis failure → `execute()`'s sync finally runs
`mark_session_completed` on the processor's long-lived loop → session
correctly marked failed → UI sees terminal state immediately.

SIGTERM path unchanged from the "let in-flight work finish" design: old
pod stops taking new work, existing turns complete naturally.

## Test plan

- [x] `TestSyncFailCloseSession` unit tests — invokes
`mark_session_completed` with the shutdown error, swallows Redis
failures, bounded timeout fires when Redis hangs.
- [x] `TestExecuteSafetyNet` — verifies the `finally` always fires,
including SIGTERM-interrupted and zombie-Redis scenarios.
- [x] Existing `TestExecuteAsyncAclose` + pending_messages tests still
pass (18 passed).
- [x] `pyright` on touched files: 0 errors.
- [x] Manual E2E on native dev stack: sent a `sleep 300 && echo hewwo`
task, SIGTERMed mid-turn at +40s, observed:
   - `[CoPilotExecutor] [cleanup N] Starting graceful shutdown...`
   - Drain-wait ran for ~4.5 min ("1 tasks still active, waiting...")
- Turn finished with `result=Done! The command finished after 5 minutes
and printed: hewwo`
   - `Cleaned up completed session` → `Graceful shutdown completed`
   - No zombie.
- [x] `poetry run format` applied.
- [x] RabbitMQ policy verified on dev. Apply on prod before prod deploy.
- [ ] Verified behavior on next production rolling deploy.
2026-04-23 06:49:06 +07:00
Zamil Majdy
c56c1e5dd6 fix(backend/copilot): disable ask_question tool pending UX rework (#12887)
### Why / What / How

**Why:** The in-conversation Question GUI is unreliable in production —
users submitting answers can get their messages dropped and the agent
gets stuck on the auto-generated "please proceed" step with no way to
make progress. Discord report:
https://discord.com/channels/1126875755960336515/1496474512966029472/1496537943287005365
(see attached video). Pause/queue semantics still need a rework; until
then, the right call is to stop the model from reaching for this tool.

**What:** Removes `ask_question` from the copilot tool registry so the
model never sees or calls it. Historical sessions that already contain
`ask_question` tool calls still render (frontend renderers + response
model untouched), so this is non-destructive to existing chats.
Re-enabling once UX is reworked is a small revert.

**How:**
- Drop the `AskQuestionTool` import + registry entry from
`backend/copilot/tools/__init__.py`.
- Drop `"ask_question"` from the `ToolName` literal in
`backend/copilot/permissions.py` — required because a runtime
consistency check asserts the literal matches `TOOL_REGISTRY.keys()`.
- Delete the "Clarifying — Before or During Building" section from
`backend/copilot/sdk/agent_generation_guide.md` so the SDK-mode system
prompt no longer instructs the model to call `ask_question`.
- Drop the three `prompting_test.py` tests that asserted the guide
mentions that section.
- Keep `ask_question.py`, its unit test, `ClarificationNeededResponse`,
and the frontend `AskQuestion`/`ClarificationQuestionsCard` components
untouched so old sessions still render and re-enabling is a small
revert.

### Changes 🏗️

- `backend/copilot/tools/__init__.py` — remove `AskQuestionTool` import
and `"ask_question"` entry in `TOOL_REGISTRY`.
- `backend/copilot/permissions.py` — remove `"ask_question"` from the
`ToolName` literal.
- `backend/copilot/sdk/agent_generation_guide.md` — remove the
"Clarifying — Before or During Building" section.
- `backend/copilot/prompting_test.py` — remove
`TestAgentGenerationGuideContainsClarifySection` and the now-unused
`Path` import.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
- [x] `poetry run pytest backend/copilot/tools/
backend/copilot/permissions_test.py backend/copilot/prompting_test.py` —
805+78 tests pass, consistency check between `ToolName` literal and
`TOOL_REGISTRY` still holds.
- [ ] Smoke-test in dev: start a copilot session and confirm the model
no longer lists/calls `ask_question` (its OpenAI tool schema is gone
from `get_available_tools()` and from the SDK `allowed_tools`).
- [ ] Load a historical session that contains an `ask_question` tool
call in its transcript — confirm the frontend still renders the question
card (no regression on legacy sessions).
2026-04-22 23:34:04 +07:00
Bentlybro
6fcbe95645 Merge branch 'master' into dev 2026-04-22 15:36:37 +01:00
Zamil Majdy
9703da3dfd refactor(backend/copilot): Moonshot module + cache_control widening + partial-messages default-on + title cost (#12882)
## Why

Several loose ends from the Kimi SDK-default merge (#12878), plus
follow-ups surfaced during review + E2E testing:

1. **Kimi-specific pricing lived inline in `sdk/service.py`** alongside
unrelated SDK plumbing — any future non-Anthropic vendor would have
piled onto the same file.
2. **Moonshot's Anthropic-compat endpoint honours `cache_control: {type:
ephemeral}`**, but the baseline cache-marking gate
(`_is_anthropic_model`) was narrow enough to exclude it → Moonshot fell
back to automatic prefix caching, which drifts readily between turns.
3. **Kimi reasoning rendered AFTER the answer text** on dev because the
summary-walk hoist only reorders within one `AssistantMessage.content`
list, and Moonshot splits each turn into multiple sequential
AssistantMessages (text-only, then thinking-only).
4. **Title generation's LLM call bypassed cost tracking** — admin
dashboard under-reported total provider spend by the aggregate of those
per-session calls.
5. **Cost override** was using the requested primary model, not the
actually-executed model — when the SDK fallback activates the override
mis-routes pricing.

## What

### Moonshot module
New `backend/copilot/moonshot.py`:
- `is_moonshot_model(model)` — prefix check against `moonshotai/`
- `rate_card_usd(model)` — published Moonshot rates, default `(0.60,
2.80)` per MTok with per-slug override slot
- `override_cost_usd(...)` — moved from `sdk/service.py`, replaces CLI's
Sonnet-rate estimate with real rate card
- `moonshot_supports_cache_control(model)` — narrow gate for cache
markers

Rate card is **not canonical** — authoritative cost comes from the
OpenRouter `/generation` reconcile; this module only improves the
in-turn estimate and the reconcile's lookup-fail fallback. Signal
authority: reconcile >> rate card >> CLI.

### Baseline cache-control widened to Moonshot
- New `_supports_prompt_cache_markers` = `_is_anthropic_model OR
is_moonshot_model`
- Both call sites (system-message cache dict, last-tool cache marker)
switched to the wider gate
- OpenAI / Grok / Gemini still return `false` — those endpoints 400 on
the unknown field

**Measured impact in /pr-test:** baseline Kimi continuation turns jumped
to ~98% cache hit (334 uncached + 12.8K cache_read on a 13.1K prompt).

### SDK partial-messages default-on (fixes the reasoning-order bug)
- `CHAT_SDK_INCLUDE_PARTIAL_MESSAGES` flipped from `default=False` →
`default=True`
- Kimi stream now emits `reasoning-start → reasoning-delta* →
reasoning-end → text-start → text-delta*` in the correct order —
verified in /pr-test
- Kill-switch: set `CHAT_SDK_INCLUDE_PARTIAL_MESSAGES=false` to fall
back to summary-only emission

### SDK cost override scoped to Moonshot
- Call site now explicitly gates `if _is_moonshot_model(active_model)` —
Anthropic turns trust CLI's number directly
- Added `_RetryState.observed_model` populated from
`AssistantMessage.model`, preferred over `state.options.model` so
fallback-model turns bill correctly (addresses CodeRabbit review)

### Title cost capture
- `_generate_session_title` now returns `(title, ChatCompletion)` so the
caller controls cost persistence
- `_update_title_async` runs title-persist and cost-record as
independent best-effort steps
- `_title_usage_from_response` helper reads `prompt_tokens /
completion_tokens / cost_usd` (OR's `usage.cost` off `model_extra`)
- Provider label derived from `ChatConfig.base_url` (`open_router` /
`openai`)
- No exception suppressors — `isinstance(cost_raw, (int, float))` check
replaces the inner `float()` try/except

### Misc
- Kimi tool-name whitespace strip in the response adapter — Kimi
occasionally emits tool names with leading spaces the CLI dispatcher
can't resolve
- TODO marker on the rate-card for post-prod-soak removal

## How

- Detection is **prefix-based** (`moonshotai/`) — future Kimi SKUs
transparently inherit rate card + cache-control gate
- Baseline cache-marking was already structured; only the gate changes
- Partial-messages default-on relies on the adapter's diff-based
reconcile (shipped in #12878) which has soaked stable
- Title cost path mirrors `tools/web_search.py`'s pattern for reading
OR's `usage.cost`

## Test plan

- [x] `pytest backend/copilot/moonshot_test.py` — 21 tests
- [x] `pytest backend/copilot/baseline/service_unit_test.py` — updated
for widened gate
- [x] `pytest backend/copilot/sdk/*_test.py
backend/copilot/service_test.py` — no regressions
- [x] Full E2E on local native stack — 10/10 scenarios pass (see
test-report comment)
- [x] Measured: baseline Kimi ~98% cache hit on continuation, SDK Kimi
~62% (capped by Moonshot's prefix ceiling)

## Deferred

SDK-path Moonshot cache hit rate stays at ~62% on long prompts.
`native_tokens_cached=18432` regardless of turn/session suggests a
Moonshot-side cap on cached prefix size. Not fixable by our code —
requires proxy rewriting requests or upstream Moonshot change.
2026-04-22 20:42:47 +07:00
Zamil Majdy
ebb0d3b95b feat(backend/copilot): LaunchDarkly per-user model routing (#12881)
## Summary

Per-user model routing for the copilot via LaunchDarkly. Replaces the
pure-env-var pick on every `(mode, tier)` cell of the model matrix with
an LD-first resolver that falls back to the `ChatConfig` default. Lets
us roll out non-default routes (e.g. Kimi K2.6 on baseline standard) to
a user cohort without shipping a deploy.

| | standard | advanced |

|----------|------------------------------------|------------------------------------|
| fast | `copilot-fast-standard-model` | `copilot-fast-advanced-model` |
| thinking | `copilot-thinking-standard-model` |
`copilot-thinking-advanced-model` |

All four flags are **string-valued** — the value IS the model identifier
(e.g. `"anthropic/claude-sonnet-4-6"` or `"moonshotai/kimi-k2.6"`).

## What ships

- **New module `backend/copilot/model_router.py`** with a single
`resolve_model(mode, tier, user_id, *, config)` coroutine. That's the
one place both paths consult.
- **4 new `Flag` enum values** in `backend/util/feature_flag.py`
(reusing the existing `get_feature_flag_value` helper which already
supports arbitrary return types).
- **`baseline/service.py::_resolve_baseline_model`** → async, takes
`user_id`.
- **`sdk/service.py::_resolve_sdk_model_for_request`** → takes
`user_id`, consults LD for both standard and advanced thinking cells.
- **Default flip**: `fast_standard_model` default goes back to
`anthropic/claude-sonnet-4-6`. Non-Anthropic routes now ship via LD
targeting — safer rollback, per-user cohort control, no redeploy
required to flip.

## Behavior preserved

- `config.claude_agent_model` explicit override still wins
unconditionally (existing escape hatch for ops).
- `use_claude_code_subscription=true` on the standard thinking tier
still returns `None` so the CLI picks the model tied to the user's
Claude Code subscription.
- All legacy env var aliases (`CHAT_MODEL`, `CHAT_ADVANCED_MODEL`,
`CHAT_FAST_MODEL`) still bind to their cells.
- LD client exceptions / misconfigured (non-string) flag values fall
back silently to config default with a single warning log — never fails
the request.

## Files

| File | Change |
|---|---|
| `backend/copilot/model_router.py` | new — `resolve_model` +
`_config_default` + `_FLAG_BY_CELL` map |
| `backend/copilot/model_router_test.py` | new — 11 cases |
| `backend/util/feature_flag.py` | add 4 string-valued `Flag` entries |
| `backend/copilot/config.py` | flip `fast_standard_model` default to
Sonnet |
| `backend/copilot/baseline/service.py` | `_resolve_baseline_model` →
async + LD resolver |
| `backend/copilot/sdk/service.py` | `_resolve_sdk_model_for_request` →
LD resolver + user_id |
| `backend/copilot/baseline/transcript_integration_test.py` | update
tests for new signature + default |

## Test plan

- [x] `poetry run pytest backend/copilot/model_router_test.py
backend/copilot/baseline/transcript_integration_test.py
backend/copilot/sdk/service_test.py backend/copilot/config_test.py` —
**112 passing**
- [x] 11 resolver cases: missing user → fallback, LD string wins,
whitespace stripped, non-string value → fallback, empty string →
fallback, LD exception → fallback + warn, each of 4 cells routes to its
distinct flag
- [x] Legacy env aliases still bind to their new fields
- [ ] Manual dev-env smoke: flip `copilot-fast-standard-model` LD
targeting to `moonshotai/kimi-k2.6` for one user and confirm baseline
uses Kimi while other users stay on Sonnet
- [ ] Confirm SDK path still honors subscription mode (LD not consulted
when `use_claude_code_subscription=true` + standard tier)

## Rollout

1. Merge this PR → default stays Sonnet / Opus across the matrix, no
behavior change.
2. Create the 4 LD flags as string-typed in the LaunchDarkly console
(defaults matching config, so no drift if targeting empty).
3. Add per-user / per-cohort targeting in LD for the routes we want to
roll out (Kimi on baseline standard for a percentage, etc.).
2026-04-22 20:08:37 +07:00
Zamil Majdy
b98bcf31c8 feat(backend/copilot): SDK fast tier defaults to Kimi K2.6 via OpenRouter + vendor-aware cost + cross-model fix (#12878)
## Summary

Make Kimi K2.6 the default for the SDK (extended-thinking) copilot path,
mirroring the baseline default landed in #12871. The SDK already routes
through OpenRouter (see
[`build_sdk_env`](autogpt_platform/backend/backend/copilot/sdk/env.py) —
`ANTHROPIC_BASE_URL` is set to OpenRouter's Anthropic-compatible
`/v1/messages` endpoint), but the model resolver was unconditionally
stripping the vendor prefix, which prevented routing to anything except
Anthropic models. This PR unblocks Kimi (and any other non-Anthropic
OpenRouter vendor) on the SDK fast tier and flips the default to match
the baseline path.

## Why

After #12871 the baseline (`fast_*`) path runs Kimi K2.6 by default —
~5x cheaper than Sonnet at SWE-Bench parity — but the SDK (`thinking_*`)
path was still pinned to Sonnet because:

1. **Model name normalization stripped the vendor prefix.**
`_normalize_model_name("moonshotai/kimi-k2.6")` returned `"kimi-k2.6"`,
which OpenRouter cannot route — the unprefixed form only resolves for
Anthropic models. The docstring on `thinking_standard_model` claimed
"the Claude Agent SDK CLI only speaks to Anthropic endpoints", but the
env builder shows the CLI happily talks to OpenRouter's `/messages`
endpoint, which routes to any vendor in the catalog.
2. **The default was `anthropic/claude-sonnet-4-6`.** Same model on a
more expensive route.
3. **Cost label was hardcoded to `provider="anthropic"`** on the SDK
path's `persist_and_record_usage` call, making cost-analytics rows
misleading once Kimi runs.

## What

1. **`_normalize_model_name`**
([sdk/service.py](autogpt_platform/backend/backend/copilot/sdk/service.py))
— when `config.openrouter_active` is True, the canonical `vendor/model`
slug is preserved unchanged so OpenRouter can route to the correct
provider. Direct-Anthropic mode keeps the existing strip-prefix +
dot-to-hyphen conversion (Anthropic API requires both) and now **raises
`ValueError`** when paired with a non-Anthropic vendor slug — silent
strip would have sent `kimi-k2.6` to the Anthropic API and produced an
opaque `model_not_found`.
2. **`thinking_standard_model`**
([config.py](autogpt_platform/backend/backend/copilot/config.py)) —
default flipped from `anthropic/claude-sonnet-4-6` to
`moonshotai/kimi-k2.6`. Field description rewritten; rollback to Sonnet
is one env var
(`CHAT_THINKING_STANDARD_MODEL=anthropic/claude-sonnet-4.6`).
3. **`@model_validator(mode="after")`** on `ChatConfig`
([config.py:_validate_sdk_model_vendor_compatibility](autogpt_platform/backend/backend/copilot/config.py))
— fail at config load when `use_openrouter=False` is paired with a
non-Anthropic SDK slug. The runtime guard in `_normalize_model_name` is
kept as defence-in-depth, but the validator turns a per-request 500 into
a boot-time error message the operator sees once, before any traffic
lands. Covers `thinking_standard_model`, `thinking_advanced_model`, and
`claude_agent_fallback_model`. Subscription mode is exempt (resolver
returns `None` and never normalizes). The credential-missing case
(`use_openrouter=True` + no `api_key`) is intentionally NOT a boot-time
error so CI builds and OpenAPI-schema export jobs that construct
`ChatConfig()` without secrets keep working — the runtime guard still
catches it on the first SDK turn.
4. **Cost provider attribution**
([sdk/service.py:stream_chat_completion_sdk](autogpt_platform/backend/backend/copilot/sdk/service.py))
— `persist_and_record_usage` now passes `provider="open_router" if
config.openrouter_active else "anthropic"` instead of hardcoded
`"anthropic"`. The dollar value still comes from
`ResultMessage.total_cost_usd`; this just fixes the analytics label.
5. **Baseline rollback example** ([config.py:fast_standard_model
description](autogpt_platform/backend/backend/copilot/config.py)) — same
dot-vs-hyphen footgun fix (CodeRabbit catch).
6. **Tests** — `TestNormalizeModelName` (sdk/) monkeypatches a
deterministic config per case (the helper-test variants were passing
accidentally based on ambient env). New
`TestSdkModelVendorCompatibility` class in `config_test.py` covers all
five validator shapes (default-Kimi + direct-Anthropic raises, anthropic
override succeeds, openrouter mode succeeds, subscription mode skips
check, advanced+fallback tier also validated, empty fallback skipped).
`_ENV_VARS_TO_CLEAR` extended to all model/SDK/subscription env aliases
so a leftover dev `.env` value can't mask validator behaviour. New
`_make_direct_safe_config` helper for direct-Anthropic tests.

## Test plan

- [x] `poetry run pytest backend/copilot/config_test.py
backend/copilot/sdk/service_test.py
backend/copilot/sdk/service_helpers_test.py
backend/copilot/sdk/env_test.py
backend/copilot/sdk/p0_guardrails_test.py` — 238 pass
- [x] `poetry run pytest backend/copilot/` — 2560 pass + 5 pre-existing
integration failures (need real API keys / browser env, unrelated)
- [x] CI green on `feat/copilot-sdk-kimi-default` (35 pass / 0 fail / 1
neutral)
- [x] Manual: SDK extended_thinking turn against Kimi K2.6 via
OpenRouter on the native dev stack — request lands with
`model=moonshotai/kimi-k2.6`, response streams back, multi-turn
`--resume` recalls facts across turns. Backend log: `[SDK] Per-request
model override: standard (moonshotai/kimi-k2.6)`.
- [x] Manual: rollback path —
`CHAT_THINKING_STANDARD_MODEL=anthropic/claude-sonnet-4.6` resumes
Sonnet routing.

## Known follow-ups (not in this PR)

These surfaced during manual testing and will need separate PRs:

- **SDK CLI cost is wrong for non-Anthropic models.**
`ResultMessage.total_cost_usd` comes from a static Anthropic pricing
table baked into the CLI binary; for Kimi K2.6 it falls back to Sonnet
rates, **over-billing ~5x** ($0.089 vs the real ~$0.018 for ~30K prompt
+ ~80 completion). The `provider` label is now correct but the dollar
value isn't. Needs either a per-model rate card override on our side or
a CLI patch upstream.
- **Mid-session model switch (Kimi → Opus) breaks.** Kimi's
`ThinkingBlock`s have no Anthropic `signature` field; when the user
toggles standard → advanced after a Kimi turn, Opus rejects the replayed
transcript with `Invalid signature in thinking block`. Needs transcript
scrubbing on model switch (similar to existing
`TestStripStaleThinkingBlocks` pattern).
- **Reasoning UI ordering on Kimi.** Moonshot/OpenRouter places
`reasoning` AFTER text in the response; the SDK's
`AssistantMessage.content` reflects that order, and `response_adapter`
emits SSE events in the same order — so reasoning lands BELOW the answer
in the UI instead of above. Needs `ThinkingBlock` hoisting in
`response_adapter.py`.
2026-04-22 18:35:01 +07:00
Zamil Majdy
4f11867d92 feat(backend/copilot): TodoWrite for baseline copilot (#12879)
## Summary

Add `TodoWrite` to baseline copilot so the "task checklist" UI works on
non-Claude models (Kimi, GPT, Grok, etc.) the same way it works on the
SDK path. Baseline previously had no `TodoWrite` tool at all — only SDK
mode did via the Claude Code CLI's built-in — so models on baseline just
couldn't reach for a planning checklist.

This closes the last clear feature gap blocking baseline from being the
primary copilot path without giving up model flexibility.

## What ships

- **New MCP tool `TodoWrite`** in `TOOL_REGISTRY`, schema matching the
one the frontend's `GenericTool.helpers.ts` (`getToolCategory → "todo"`)
already renders as the **Steps** accordion. The tool is a stateless echo
— the canonical list lives in the model's latest tool-call args and
replays from transcript on subsequent turns.
- **Prompt guidance** in `SHARED_TOOL_NOTES` teaching the model when to
use it (3+ step tasks; always send the full list; exactly one
`in_progress` at a time).
- **Sharpened `run_sub_session` guidance** in the same prompt section —
framed explicitly as the context-isolation primitive for baseline.
Clearer for the model, no dual-primitive confusion.

## How the SDK path stays untouched

- SDK mode keeps using the CLI-native `TodoWrite` built-in.
- `BASELINE_ONLY_MCP_TOOLS = {"TodoWrite"}` in `sdk/tool_adapter.py`
filters the baseline MCP wrapper out of SDK's `allowed_tools` — no name
shadowing.
- `SDK_BUILTIN_TOOL_NAMES` is now an explicit allowlist (not
auto-derived from capitalization) so the classification stays coherent
when a capitalized tool is platform-owned.

## Files

| File | Change |
|---|---|
| `backend/copilot/tools/todo_write.py` | new — `TodoWriteTool` |
| `backend/copilot/tools/__init__.py` | register in `TOOL_REGISTRY` |
| `backend/copilot/tools/models.py` | add `TodoItem` +
`TodoWriteResponse` + `ResponseType.TODO_WRITE` |
| `backend/copilot/permissions.py` | explicit `SDK_BUILTIN_TOOL_NAMES`;
`apply_tool_permissions` maps baseline-only tools to CLI name for SDK |
| `backend/copilot/sdk/tool_adapter.py` | `BASELINE_ONLY_MCP_TOOLS`
filter |
| `backend/copilot/prompting.py` | `TodoWrite` + sharpened
`run_sub_session` guidance |
| `backend/api/features/chat/routes.py` | add `TodoWriteResponse` to
`ToolResponseUnion` |
| `backend/copilot/tools/todo_write_test.py` | new — schema + execute
tests |
| `frontend/src/app/api/openapi.json` | regenerated |
| `tools/tool_schema_test.py` | budget bumped `32_800 → 34_000` (actual
33_865, +1_065 headroom) |

## Test plan

- [x] `poetry run pytest backend/copilot/
backend/api/features/chat/routes_test.py` — **1010 passing**
- [x] Tool schema char budget regression gate passes
- [x] `_assert_tool_names_consistent` passes
- [x] **E2E on local native stack (Kimi K2.6 via OpenRouter,
`CHAT_USE_CLAUDE_AGENT_SDK=false`)**: baseline called `TodoWrite` on a
3-step prompt, SSE stream carried the exact `{content, activeForm,
status}` shape the UI expects, "Steps" dialog renders `Task list — 0/3
completed` with all three items (see test-report comment below).
- [x] Negative cases covered: two `in_progress` → rejected, missing
`activeForm` → rejected, non-list `todos` → rejected.
2026-04-22 17:28:15 +07:00
Zamil Majdy
33a608ec78 feat(platform/copilot): live baseline streaming + render flag + Sonar web_search + simulator cost tracking + reconnect fixes (#12873)
### Why / What / How

**Why.** Three problems on the baseline copilot path that compound:
extended-thinking turns froze the UI for minutes because Kimi K2.6
events were buffered in `state.pending_events: list` until the full
`tool_call_loop` iteration finished (reasoning arrived in one lump at
the end); the SSE stream replayed 1000 events on every reconnect and the
frontend opened multiple SSE streams in quick succession on tab-focus
thrash (reconnect storm → UI flickers, tab freezes); the `web_search`
tool hit Anthropic's server-side beta directly via a dispatch-model
round-trip that fed entire page contents back through the model for a
second inference pass (observed $0.072 on a 74K-token call); and the
simulator dry-run path ran on Gemini Flash without any cost tracking at
all, so every dry-run was free on the platform's microdollar ledger.

**What.** Grouped deltas, all targeting reliability, cost, and UX of the
copilot live-answer pipeline:

- **Live per-token baseline streaming.** `state.pending_events` is now
an `asyncio.Queue` drained concurrently by the outer async generator.
The tool-call loop runs as a background task; reasoning / text / tool
events reach the SSE wire during the upstream OpenRouter stream, not
after it. `None` is the close sentinel; inner-task exceptions are
re-raised via `await loop_task` once the sentinel arrives. An
`emitted_events: list` mirror preserves post-hoc test inspection.
Coalescing widened 32/40 → 64/50 ms to halve the React re-render rate on
extended-thinking turns while staying under the ~100 ms perceptual
threshold.
- **Reasoning render flag** — `ChatConfig.render_reasoning_in_ui: bool =
True` wired through both `BaselineReasoningEmitter` and
`SDKResponseAdapter`. When False the wire `StreamReasoning*` events are
suppressed while the persisted `ChatMessage(role='reasoning')` rows
always survive (decoupled from the render flag so audit/replay is
unaffected); the service-layer yield filter does the gating. Tokens are
still billed upstream; operator kill-switch for UI-level flicker
investigations.
- **Reconnect storm mitigations** — `ChatConfig.stream_replay_count: int
= 200` (was hard-coded 1000) caps `stream_registry.subscribe_to_session`
XREAD size. Frontend `useCopilotStream::handleReconnect` adds a 1500 ms
debounce via `lastReconnectResumeAtRef`, so tab-focus thrash doesn't fan
out into 5–6 parallel replays in the same second.
- **web_search rewritten to Perplexity Sonar via OpenRouter** — single
unified credential, real `usage.cost` flows through
`persist_and_record_usage(provider='open_router')`. Two tiers via a
`deep` param: `perplexity/sonar` (~$0.005/call quick) and
`perplexity/sonar-deep-research` (~$0.50–$1.30/call multi-step
research). Replaces the Anthropic-native + server-tool dispatches; drops
the hardcoded pricing constants entirely.
- **Synthesised answer surfaced end-to-end** — Sonar already writes a
web-grounded answer on the same call we pay for; the new
`WebSearchResponse.answer` field passes it through and the accordion UI
renders it above citations so the agent doesn't re-fetch URLs that are
usually bot-protected anyway.
- **Deep-tier cost warning + UI affordances** — `deep` param description
is explicit that it's ~100× pricier; UI labels read "Researching /
Researched / N research sources" when `deep=true` so users know what's
running.
- **Simulator cost tracking + cheaper default** —
`google/gemini-2.5-flash` → `google/gemini-2.5-flash-lite` (3× cheaper
tokens) and every dry-run now hits
`persist_and_record_usage(provider='open_router')` with real
`usage.cost`. Previously each sim was free against the user's
microdollar budget.
- **Typed access everywhere** — cost extractors now use
`openai.types.CompletionUsage.model_extra["cost"]` and
`openai.types.chat.ChatCompletion` / `Annotation` /
`AnnotationURLCitation` with no `getattr` / duck typing. Mirrors the
baseline service's `_extract_usage_cost` pattern; keep in sync.

**How.** Key file touches:

1. `copilot/config.py` — `render_reasoning_in_ui`,
`stream_replay_count`, `simulation_model` default.
2. `copilot/baseline/service.py` — `_BaselineStreamState.pending_events:
asyncio.Queue`, `_emit` / `_emit_all` helpers, outer generator runs
`tool_call_loop` as a background task + yields from queue concurrently.
3. `copilot/baseline/reasoning.py` —
`BaselineReasoningEmitter(render_in_ui=...)`, coalescing bumped to 64
chars / 50 ms.
4. `copilot/sdk/service.py` — `state.adapter.render_reasoning_in_ui`
threaded through every adapter construction.
5. `copilot/sdk/response_adapter.py` — `render_reasoning_in_ui` wiring +
service-layer yield filter gating for wire suppression while persistence
stays intact.
6. `copilot/stream_registry.py` — `count=config.stream_replay_count`.
7. `frontend/.../useCopilotStream.ts::handleReconnect` — 1500 ms
debounce.
8. `copilot/tools/web_search.py` + `models.py` — Sonar quick/deep paths,
`WebSearchResponse.answer` + typed extractors.
9. `frontend/.../GenericTool/*` — `answer` render + deep-aware labels /
accordion titles.
10. `executor/simulator.py` + `executor/manager.py` +
`copilot/config.py` — cost tracking + model swap + `user_id` threading.

### Changes

- `copilot/config.py` — new `render_reasoning_in_ui`,
`stream_replay_count`; `simulation_model` default flipped to Flash-Lite.
- `copilot/baseline/service.py` — `pending_events: asyncio.Queue`
refactor; outer gen runs loop as task, yields from queue live.
- `copilot/baseline/reasoning.py` —
`BaselineReasoningEmitter(render_in_ui=...)` + 64/50 coalesce.
- `copilot/sdk/service.py` + `response_adapter.py` —
`render_reasoning_in_ui` wire suppression (persistence preserved).
- `copilot/stream_registry.py` — replay cap from config.
- `copilot/tools/web_search.py` + `models.py` — Sonar quick/deep +
`answer` field + typed extractors.
- `copilot/tools/helpers.py` — tool description tightens `deep=true`
cost warning.
- `frontend/.../useCopilotStream.ts` — reconnect debounce.
- `frontend/.../GenericTool/GenericTool.tsx` + `helpers.ts` + tests —
render `answer`, deep-aware verbs / titles.
- `executor/simulator.py` + `simulator_test.py` + `executor/manager.py`
— cost tracking + model swap + user_id plumbing.

### Follow-up (deferred to a separate PR)

SDK per-token streaming via `include_partial_messages=True` was
attempted (commits `599e83543` + `530fa8f95`) and reverted here. The
two-signal model (StreamEvent partial deltas + AssistantMessage summary)
needs proper per-block diff tracking — when the partial stream delivers
a subset of the final block content, emit only
`summary.text[len(already_emitted):]` from the summary rather than
gating on a binary flag. Binary gating truncated replies in the field
when the partial stream delivered less than the summary (observed: "The
analysis template you" cut off mid-sentence because partial had streamed
that much and the rest only lived in the summary). SDK reasoning still
renders end-of-phase (as today); this PR's baseline per-token streaming
is unaffected.

### Checklist

For code changes:
- [x] Changes listed above
- [x] Test plan below
- [x] Tested according to the test plan:
- [x] `poetry run pytest backend/copilot/baseline/ backend/copilot/sdk/
backend/copilot/tools/web_search_test.py
backend/executor/simulator_test.py` — all pass (155 baseline + 927 SDK +
web_search + simulator)
- [x] `pnpm types && pnpm vitest run
src/app/(platform)/copilot/tools/GenericTool/` — pass
- [x] Manual: baseline live-streaming — Kimi K2.6 reasoning arrives
token-by-token, coalesced (no end-of-stream burst).
- [x] Manual: quick web_search via copilot UI — ~$0.005/call, answer +
citations rendered, cost logged as `provider=open_router`.
- [x] Manual: deep web_search — dispatched only on explicit research
phrasing; `sonar-deep-research` billed, UI labels say "Researched" / "N
research sources".
- [x] Manual: simulator dry-run — Gemini Flash-Lite, `[simulator] Turn
usage` log entry, PlatformCostLog row visible.
- [x] Manual: reconnect debounce — tab-focus thrash no longer produces
parallel XREADs in backend log.
- [ ] Manual: `CHAT_RENDER_REASONING_IN_UI=false` smoke-check —
reasoning collapse absent, no persisted reasoning row on reload.

For configuration changes:
- [x] `.env.default` — new config knobs fall back to pydantic defaults;
existing `CHAT_MODEL`/`CHAT_FAST_MODEL`/`CHAT_ADVANCED_MODEL` legacy
envs still honored upstream (unchanged by this PR).

### Companion PR

PR #12876 closes the `run_block`-via-copilot cost-leak gap (registers
`PerplexityBlock` / `FactCheckerBlock` in `BLOCK_COSTS`; documents the
credit/microdollar wallet boundary). Separate because the credit-wallet
side is orthogonal to the copilot microdollar / rate-limit surface this
PR ships.
2026-04-22 13:52:18 +07:00
Zamil Majdy
e3f6d36759 feat(backend/blocks): register 13 paid blocks + document credit/microdollar wallet boundary (#12876)
### Why / What / How

**Why.** Audit of `BLOCK_COSTS` against `credentials_store.py` system
credentials revealed **13 paid blocks** running for free from the credit
wallet's perspective — `BLOCK_COSTS.get(type(block))` returned `None`,
`cost = 0`, no `spend_credits` deduction. Users without their own API
key consumed system credentials with zero credit drain. Separately, the
credit wallet (user-facing prepaid balance) and the copilot microdollar
counter (operator-side meter that gates `daily_cost_limit_microdollars`)
were never documented as separate systems, so future readers kept
tripping on the "why isn't this block charging my limit?" question.

**What.** Three deltas, all credit-wallet-side:

- **Register the 13 paid blocks in `BLOCK_COSTS`** with reasonable
per-call credit prices (1 credit = $0.01). Pricing researched against
the providers' published rates with ~2-3x markup.
- **Document the credit/microdollar boundary** in
`copilot/rate_limit.py`: credits = user-facing prepaid wallet with
marketplace-creator charging; microdollars = operator-side meter that
only ticks on copilot LLM turns (baseline / SDK / web_search /
simulator). Block execution bills credits, not microdollars — explicit
contract.
- **Populate `provider_cost`** on PerplexityBlock so PlatformCostLog
rows carry the real OpenRouter `x-total-cost` value via the existing
`executor/cost_tracking.log_system_credential_cost` path (separate flow
from credit deduction).

### Block costs registered

| Provider | Block | Credits | Raw cost / markup |
|---|---|---|---|
| Perplexity (OpenRouter) | PerplexityBlock — Sonar | 1 | $0.001-0.005 /
call |
| | PerplexityBlock — Sonar Pro | 5 | $0.025 / call |
| | PerplexityBlock — Sonar Deep Research | 10 | up to $0.05 / call |
| Jina | FactCheckerBlock | 1 | $0.005 / call |
| Mem0 | AddMemoryBlock | 1 | $0.0004 / call (1c floor) |
| | SearchMemoryBlock | 1 | $0.004 / call |
| | GetAllMemoriesBlock | 1 | $0.004 / call |
| | GetLatestMemoryBlock | 1 | $0.004 / call |
| ScreenshotOne | ScreenshotWebPageBlock | 2 | $0.0085 / call (2.4x) |
| Nvidia | NvidiaDeepfakeDetectBlock | 2 | est $0.005 (no public SKU) |
| Smartlead | CreateCampaignBlock | 2 | $0.0065 send-equivalent (3x) |
| | AddLeadToCampaignBlock | 1 | $0.0065 (1.5x) |
| | SaveCampaignSequencesBlock | 1 | config-only |
| ZeroBounce | ValidateEmailsBlock | 2 | $0.008 / email (2.5x) |
| E2B + Anthropic | ClaudeCodeBlock | **100** | $0.50-$2 / typical
session (E2B sandbox + in-sandbox Claude) |

**Not in scope** — already covered via the SDK
`ProviderBuilder.with_base_cost()` pattern in their respective
`_config.py`: Exa, Linear, Airtable, Bannerbear, Wolfram, Firecrawl,
Wordpress, Baas, Stagehand, Dataforseo.

### How

1. `backend/data/block_cost_config.py` — 13 new `BlockCost` entries (3
Perplexity models + Fact Checker + 11 from this round).
2. `backend/copilot/rate_limit.py` — boundary docstring.
3. `backend/blocks/perplexity.py` — populate
`NodeExecutionStats.provider_cost` so PlatformCostLog rows carry the
real OpenRouter `x-total-cost` value.
4. Tests — `TestUnregisteredBlockRunsFree` regression +
`TestNewlyRegisteredBlockCosts` pinning every new entry by `cost_amount`
so a future refactor can't quietly drop one.

The companion Notion "Platform System Credentials" database has been
updated with a new `Platform Credit Cost` column populated across all 30
provider rows.

### Scope trim

An earlier revision piped block execution cost into the **copilot
microdollar counter** via `_record_block_microdollar_cost` in
`copilot/tools/helpers.py::execute_block`. That was reverted in
`16ae0f7b5` — the microdollar counter stays scoped to copilot LLM turns
only, credit wallet handles block execution. The pipe-through crossed a
boundary we explicitly want to keep.

### Changes

- `backend/data/block_cost_config.py` — 13 × `BlockCost` entries across
7 providers.
- `backend/blocks/perplexity.py` — populate `provider_cost` on the
execution stats (feeds PlatformCostLog).
- `backend/copilot/rate_limit.py` — boundary docstring only (no
behaviour change).
- `backend/copilot/tools/helpers_test.py` —
`TestUnregisteredBlockRunsFree` + `TestNewlyRegisteredBlockCosts` (8 new
regression tests).
- `backend/blocks/block_cost_tracking_test.py` — provider-cost
extraction pins.

### Checklist

For code changes:
- [x] Changes listed above
- [x] Test plan below
- [x] Tested according to the test plan:
- [x] `poetry run pytest backend/copilot/tools/helpers_test.py
backend/copilot/tools/run_block_test.py
backend/copilot/tools/continue_run_block_test.py
backend/blocks/block_cost_tracking_test.py
backend/blocks/test/test_perplexity.py` — passes
- [x] `poetry run pytest backend/executor/manager_cost_tracking_test.py
backend/copilot/rate_limit_test.py
backend/copilot/token_tracking_test.py` — passes (confirms docstring
edits didn't regress the LLM-turn microdollar path)
  - [x] Pyright clean on all touched files
- [ ] Manual: run PerplexityBlock via copilot `run_block` — credits
deduct, PlatformCostLog row visible with `provider_cost`, no
microdollar-counter tick.
- [ ] Manual: run an unregistered block via copilot — no error, no
credit drain, no silent billing.
- [ ] Manual: run ClaudeCodeBlock via builder — 100 credits deducted
from wallet.

### Companion PR

PR #12873 ships the copilot microdollar / rate-limit work (web_search
cost, simulator cost, reasoning / reconnect fixes). This PR is
credit-wallet only.
2026-04-22 12:03:02 +07:00
Nicholas Tindle
c1b9ed1f5e fix(backend/copilot): allow multiple compactions per turn (#12834)
### Why / What / How

**Why:** The old `CompactionTracker` set a `_done` flag after the first
completion and short-circuited every subsequent compaction in the same
turn. That blocked the SDK-internal compaction from running after a
pre-query compaction had already fired, so prompt-too-long errors
couldn't actually recover — retries saw the flag, bailed, and we re-hit
the context limit.

**What:** Drop the `_done` flag, track attempts and completions as
separate lists, and expose counters + an observability metadata builder
so callers can record compaction activity per turn.

**How:**
- Remove `_done` and `_compact_start` short-circuits.
- Track `_attempted_sources` / `_completed_sources` /
`_completed_count`.
- Expose `attempt_count`, `completed_count`, and
`get_observability_metadata()` / `get_log_summary()` for downstream
instrumentation (no caller change required in this PR).

### Changes 🏗️

- `backend/copilot/sdk/compaction.py` — rewritten `CompactionTracker`
internals; adds properties + observability helpers.
- `backend/copilot/sdk/compaction_test.py` — tests for multi-compaction
flow + new counters.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
- [ ] `poetry run pytest backend/copilot/sdk/compaction_test.py -xvs`
passes
- [ ] Local chat that hits prompt-too-long now recovers via SDK
compaction instead of failing the turn

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Changes core streaming compaction state transitions and persistence
timing, which could affect UI event sequencing or compaction completion
behavior under concurrency; coverage is improved with new
multi-compaction tests.
> 
> **Overview**
> Fixes `CompactionTracker` so compaction is no longer single-shot per
turn: removes the `_done`/event-gate behavior, queues multiple
`on_compact()` hook firings via a pending transcript-path deque, and
allows subsequent SDK-internal compactions after a pre-query compaction
within the same query.
> 
> Adds lightweight instrumentation by tracking attempt/completion
sources and counts, plus `get_observability_metadata()` and
`get_log_summary()` (including source summaries like `sdk_internal:2`).
Updates/expands tests to cover multi-compaction flows, transcript-path
handling, and the new counters/metadata.
> 
> <sup>Reviewed by [Cursor Bugbot](https://cursor.com/bugbot) for commit
9bf8cdd367. Bugbot is set up for automated
code reviews on this repo. Configure
[here](https://www.cursor.com/dashboard/bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: majdyz <zamil.majdy@agpt.co>
2026-04-22 02:02:03 +00:00
Zamil Majdy
45bc167184 feat(backend/copilot): Kimi K2.6 fast default + 4-config matrix + coalesced reasoning + web_search tool (#12871)
### Why / What / How

**Why.** Three unrelated but interlocking problems on the baseline
(OpenRouter) copilot path, all blocking us from making Kimi K2.6 the
default fast model:

1. **Cost / capability gap on the default.** Kimi K2.6 prices at $0.60 /
$2.80 per MTok — ~5x cheaper input and ~5.4x cheaper output than Sonnet
4.6 — while tying Opus on SWE-Bench Verified (80.2% vs 80.8%) and
beating it on SWE-Bench Pro (58.6% vs 53.4%). OpenRouter exposes the
same `reasoning` / `include_reasoning` extension on Moonshot endpoints
that #12870 plumbed for Anthropic, so the reasoning collapse lights up
end-to-end without per-provider code.
2. **Kimi reasoning deltas freeze the UI.** K2.6 emits ~4,700
reasoning-delta SSE events per turn vs ~28 on Sonnet — the AI SDK v6
Reasoning UIMessagePart can't keep up and the tab locks. Needs a
coalescing buffer upstream.
3. **Kimi loops on `require_guide_read`.** The guide-guard checks
`session.messages` for a prior `agent_building_guide` call, but tool
calls aren't flushed to `session.messages` until the end of the turn —
mid-turn the check keeps returning False and Kimi calls the guide-load
tool repeatedly in the same turn. Needs an in-flight tracker that lives
on `ChatSession`.
4. **No `web_search` tool on either path.** Kimi doesn't have a native
web-search equivalent and the SDK path's native `WebSearch` (the Claude
Code CLI's built-in) doesn't carry cost accounting. We need one
implementation that both paths share and that reports cost through the
same tracker as every other tool call.

**What.** Five grouped deltas on the baseline service, tool layer, and
config:

- **Kimi K2.6 default.** `fast_standard_model` defaults to
`moonshotai/kimi-k2.6`. Full 2×2 model matrix below. Rollback is one env
var.
- **4-config model matrix.** `fast_standard_model` /
`fast_advanced_model` / `thinking_standard_model` /
`thinking_advanced_model`. Each cell independent so baseline can run a
cheap provider at the standard tier without leaking into the SDK path
(which is Anthropic-only by CLI contract). Legacy env vars
(`CHAT_MODEL`, `CHAT_FAST_MODEL`, `CHAT_ADVANCED_MODEL`) stay aliased
via `validation_alias` so live deployments keep resolving to the same
effective cell.
- **Reasoning delta coalescing.** `BaselineReasoningEmitter` buffers
deltas and flushes on a char-count OR time-interval threshold (32 chars
/ 40 ms). ~4,700 → ~150 SSE events per turn on Kimi; no perceptible
change on Sonnet (which was already well under the threshold).
- **In-flight tool-call tracker.** `ChatSession._inflight_tool_calls`
PrivateAttr is populated when a tool-call block is emitted and cleared
at turn end. `session.has_tool_been_called_this_turn(name)` now returns
True mid-turn, not just after the tool-result lands in
`session.messages` — which is what `require_guide_read` needs to cut the
loop.
- **New `web_search` copilot tool.** Wraps Anthropic's server-side
`web_search_20250305` beta via `AsyncAnthropic` (direct — OpenRouter
can't proxy server-side tool execution). Dispatches through
`claude-haiku-4-5` with `max_uses=1`. Cost estimated from published
rates ($0.010 per search + Haiku tokens) since the Anthropic Messages
API doesn't report cost on the response; reported to
`persist_and_record_usage(provider='anthropic')` on both paths. SDK
native `WebSearch` moved from `_SDK_BUILTIN_ALWAYS` into
`SDK_DISALLOWED_TOOLS` so both paths now dispatch through
`mcp__copilot__web_search`.

**How.**

1. `copilot/config.py` — 2×2 model fields with `AliasChoices` preserving
legacy env var names. `populate_by_name = True` so
`ChatConfig(fast_standard_model=...)` works in tests.
2. `copilot/baseline/service.py::_resolve_baseline_model` — resolves the
active baseline cell from `mode` + `tier`, no longer delegates to the
SDK resolver.
3. `copilot/baseline/reasoning.py` — `BaselineReasoningEmitter` gains
`_pending_delta` / `_last_flush_monotonic` and flushes on
`len(_pending_delta) >= _COALESCE_MIN_CHARS` OR `monotonic() -
_last_flush_monotonic >= _COALESCE_MAX_INTERVAL_MS / 1000`.
`_is_reasoning_route` rewritten as an anchored prefix match covering
`anthropic/`, `anthropic.`, `moonshotai/`, and `openrouter/kimi-` —
split from the narrower `_is_anthropic_model` gate that still governs
`cache_control` markers (which Kimi doesn't support).
4. `copilot/model.py::ChatSession` — `_inflight_tool_calls: set[str] =
PrivateAttr(default_factory=set)` plus `announce_inflight_tool_call` /
`clear_inflight_tool_calls` / `has_tool_been_called_this_turn`.
5. `copilot/tools/helpers.py::require_guide_read` — check
`session.has_tool_been_called_this_turn(_AGENT_GUIDE_TOOL_NAME)` before
falling back to scanning `session.messages`.
6. `copilot/tools/web_search.py` — new `WebSearchTool` +
`_extract_results` + `_estimate_cost_usd`. `is_available` gated on
`Settings().secrets.anthropic_api_key` so the deployment can roll back
just by unsetting the key.
7. `copilot/tools/__init__.py` — registers `web_search` in
`TOOL_REGISTRY` so it becomes `mcp__copilot__web_search` in the SDK
path.
8. `copilot/sdk/tool_adapter.py` — `WebSearch` moves to
`SDK_DISALLOWED_TOOLS`.

### Changes

- `copilot/config.py` — 2×2 model matrix with legacy env alias
preservation; `populate_by_name=True`.
- `copilot/baseline/service.py::_resolve_baseline_model` — resolves
against the new matrix.
- `copilot/baseline/reasoning.py` — `BaselineReasoningEmitter`
coalescing buffer; `_is_reasoning_route` rewritten as anchored prefix
match (covers `anthropic/`, `anthropic.`, `moonshotai/`,
`openrouter/kimi-`).
- `copilot/model.py::ChatSession` — `_inflight_tool_calls` PrivateAttr +
helpers.
- `copilot/baseline/service.py::_baseline_tool_executor` — calls
`announce_inflight_tool_call` after emitting `StreamToolInputAvailable`;
`clear_inflight_tool_calls` in the outer `finally` before persist.
- `copilot/tools/helpers.py::require_guide_read` — reads the new tracker
first.
- `copilot/tools/web_search.py` (new) — Anthropic `web_search_20250305`
wrapper + cost estimator.
- `copilot/tools/web_search_test.py` (new) — extractor / cost / dispatch
/ registry tests (12 total).
- `copilot/tools/models.py` — `WebSearchResponse` + `WebSearchResult` +
`ResponseType.WEB_SEARCH`.
- `copilot/tools/__init__.py` — registers `web_search`.
- `copilot/sdk/tool_adapter.py` — moves native `WebSearch` to
`SDK_DISALLOWED_TOOLS`.

### Checklist

For code changes:
- [x] Changes listed above
- [x] Test plan below
- [ ] Tested according to the test plan:
  - [x] `poetry run pytest backend/copilot/baseline/` — all pass
- [x] `poetry run pytest backend/copilot/sdk/` — all pass (SDK resolver
untouched)
- [x] `poetry run pytest backend/copilot/tools/web_search_test.py` — 12
pass
- [ ] Manual: send a multi-step prompt on fast mode with default config;
confirm backend routes to `moonshotai/kimi-k2.6`, SSE stream carries
`reasoning-start/delta/end` (coalesced), Reasoning collapse renders +
survives hard reload.
- [ ] Manual: 43-tool payload reliability on Kimi — watch for malformed
tool-call JSON or wrong-tool selection.
- [ ] Manual: `CHAT_FAST_STANDARD_MODEL=anthropic/claude-sonnet-4-6`
restarts confirm Sonnet routing (rollback path works).
- [ ] Manual: SDK path (`CHAT_USE_CLAUDE_AGENT_SDK=true`) still selects
the SDK service and uses `thinking_standard_model` = Sonnet (no Kimi
leaked into extended thinking).
- [ ] Manual: prompt that forces `web_search` — confirm results render,
`persist_and_record_usage(provider='anthropic')` runs, cost lands in the
per-user ledger.
- [ ] Manual: ask Kimi a question that would require
`agent_building_guide` — confirm the guide loads exactly once per turn
(no loop).

For configuration changes:
- [x] `.env.default` — all four model fields fall back to the pydantic
defaults; legacy `CHAT_MODEL` / `CHAT_FAST_MODEL` /
`CHAT_ADVANCED_MODEL` remain honored via `AliasChoices`.
2026-04-22 08:47:08 +07:00
Nicholas Tindle
e4f291e54b feat(frontend): add AutoGPT logo to share page and zip download for outputs (#11741)
### Why / What / How

**Why:** The share page was unbranded (no logo/navigation) and images
from workspace files couldn't render because the proxy didn't handle
public share URLs. Zip downloads also had several gaps — no size limits,
no workspace file support, silent failures on data URLs, and single
files got wrapped in unnecessary zips.

**What:** Adds AutoGPT branding to the share page, secure public access
to workspace files via a SharedExecutionFile allowlist, and a hardened
zip download module.

**How:** Backend scans execution outputs for `workspace://` URIs on
share-enable and persists an allowlist in a new `SharedExecutionFile`
table. A new unauthenticated endpoint serves files validated against
this allowlist. Frontend proxy routing is extended (with UUID
validation) to handle the 7-segment public share download path as a
binary response. Download logic is consolidated into a shared module
with size limits, parallel fetches, filename sanitization, and
single-file direct download.

### Changes 🏗️

**Share page branding:**
- AutoGPT logo header centered at top, linking to `/`
- Dark/light mode variants with correct `priority` on visible variant
only

**Secure public workspace file access (backend):**
- New `SharedExecutionFile` Prisma model with `@@unique([shareToken,
fileId])` constraint
- `_extract_workspace_file_ids()` scans outputs for `workspace://` URIs
(handles nested dicts/lists)
- `create_shared_execution_files()` / `delete_shared_execution_files()`
manage allowlist lifecycle
- Re-share cleans up stale records before creating new ones (prevents
old token access)
- `GET /public/shared/{token}/files/{id}/download` — validates against
allowlist, uniform 404 for all failures
- `Content-Disposition: inline` for share page rendering
- Hand-written Prisma migration
(`20260417000000_add_shared_execution_file`)

**Frontend proxy fix:**
- `isWorkspaceDownloadRequest` extended to match public share path
(7-segment)
- UUID format validation on dynamic path segments (file IDs, share
tokens)
- 30+ adversarial security tests: path traversal, SQL injection, SSRF
payloads, unicode homoglyphs, null bytes, prototype pollution, etc.

**Download module (`download-outputs.ts`):**
- Consolidated from two divergent copies into single shared module
- `fetchFileAsBlob` with content-length pre-check before buffering
- `sanitizeFilename` strips path traversal, leading dots, falls back to
"file"
- `getUniqueFilename` deduplicates with counter suffix
- `fetchInParallel` with configurable concurrency (5)
- 50 MB per-file limit, 200 MB aggregate limit
- Data URL try-catch, relative URL support (`/api/proxy/...`)
- Single-file downloads skip zip, go directly to browser download
- Dynamic JSZip import for bundle optimization
- 26 unit tests

**Share page file rendering:**
- `WorkspaceFileRenderer` builds public share URLs when `shareToken` is
in metadata
- `RunOutputs` propagates `shareToken` to renderer metadata

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Share page renders with centered AutoGPT logo
  - [x] Logo links to `/` and shows correct dark/light variant
  - [x] Workspace images render inline on share page
  - [x] Download all produces zip with workspace images included
  - [x] Single-file download skips zip, downloads directly
- [x] Re-sharing generates new token and cleans up old allowlist records
  - [x] Public file download returns 404 for files not in allowlist
  - [x] All frontend tests pass (122 tests across 3 suites)
  - [x] Backend formatter + pyright pass
  - [x] Frontend format + lint + types pass

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

> Note: New Prisma migration required. No env/docker changes needed.

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Adds a new unauthenticated file download path gated by a database
allowlist plus a new Prisma model/migration; mistakes here could expose
workspace files or break sharing. Frontend download behavior also
changes significantly (zipping/fetching), which could impact
large-output performance and edge cases.
> 
> **Overview**
> Enables **public rendering and downloading of workspace files on
shared execution pages** by introducing a `SharedExecutionFile`
allowlist tied to the share token and populating it when sharing is
enabled (and clearing it on disable/re-share).
> 
> Adds `GET /public/shared/{share_token}/files/{file_id}/download` (no
auth) that validates the requested file against the allowlist and
returns a uniform 404 on failure; workspace download responses now
support `inline` `Content-Disposition` via the exported
`create_file_download_response` helper.
> 
> Frontend updates the share page to pass `shareToken` into output
renderers so `WorkspaceFileRenderer` can build public-share download
URLs; the proxy matcher is extended/strictly UUID-validated for both
workspace and public-share download paths with extensive adversarial
tests. Output downloading is consolidated into `download-outputs.ts`
using dynamic `jszip` import, filename sanitization/deduping,
concurrency + size limits, and a single-file non-zip fast path.
> 
> <sup>Reviewed by [Cursor Bugbot](https://cursor.com/bugbot) for commit
e2f5bd9b5a. Bugbot is set up for automated
code reviews on this repo. Configure
[here](https://www.cursor.com/dashboard/bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <ntindle@users.noreply.github.com>
Co-authored-by: Otto <otto@agpt.co>
2026-04-21 16:26:37 +00:00
Bently
6efbc59fd8 feat(backend): platform server linking API for multi-platform CoPilot (#12615)
## Why
AutoPilot (CoPilot) needs to reach users across chat platforms — Discord
first, Telegram / Slack / Teams / WhatsApp next. To make usage and
billing coherent, every conversation resolves to one AutoGPT account.
There are two independent linking flows:

- **SERVER links**: the first person to claim a server (Discord guild,
Telegram group, …) becomes its owner. Anyone in the server can chat with
the bot; all usage bills to the owner.
- **USER links**: an individual links their 1:1 DMs with the bot to
their own AutoGPT account. Independent from server links — a server
owner still has to link their DMs separately.

## What
Backend for platform linking, split cleanly by trust boundary:

- **Bot-facing operations** run over cluster-internal RPC via a new
`PlatformLinkingManager(AppService)`. No shared bearer token; trust is
the cluster network itself.
- **User-facing operations** stay on REST under JWT auth (the same
pattern as every other feature).

### REST endpoints (JWT auth)

- `GET /api/platform-linking/tokens/{token}/info` — non-sensitive
display info for the link page
- `POST /api/platform-linking/tokens/{token}/confirm` — confirm a SERVER
link
- `POST /api/platform-linking/user-tokens/{token}/confirm` — confirm a
USER link
- `GET /api/platform-linking/links` / `DELETE /links/{id}` — manage
server links
- `GET /api/platform-linking/user-links` / `DELETE /user-links/{id}` —
manage DM links

### `PlatformLinkingManager` `@expose` methods (internal RPC)

- `resolve_server_link(platform, platform_server_id) -> ResolveResponse`
- `resolve_user_link(platform, platform_user_id) -> ResolveResponse`
- `create_server_link_token(req) -> LinkTokenResponse`
- `create_user_link_token(req) -> LinkTokenResponse`
- `get_link_token_status(token) -> LinkTokenStatusResponse`
- `start_chat_turn(req) -> ChatTurnHandle` — resolves the owner,
persists the user message, creates the stream-registry session, enqueues
the turn; returns `(session_id, turn_id, user_id, subscribe_from="0-0")`
so the caller subscribes directly to the per-turn Redis stream.

### New DB models
- `PlatformLink` — `(platform, platformServerId)` → owner's AutoGPT
`userId`
- `PlatformUserLink` — `(platform, platformUserId)` → AutoGPT `userId`
(for DMs)
- `PlatformLinkToken` — one-time token with `linkType` discriminator
(SERVER | USER) and 30-min TTL

## How

- **New `backend/platform_linking/` package**: `models.py` (Pydantic
types), `links.py` (link CRUD helpers — pure business logic), `chat.py`
(`start_chat_turn` orchestration), `manager.py`
(`PlatformLinkingManager(AppService)` + `PlatformLinkingManagerClient`).
Pattern matches `backend/notifications/` + `backend/data/db_manager.py`.
- **Exception translation at the edge**. Helpers raise domain exceptions
(`NotFoundError`, `LinkAlreadyExistsError`, `LinkTokenExpiredError`,
`LinkFlowMismatchError`, `NotAuthorizedError` — all `ValueError`
subclasses in `backend.util.exceptions` so they auto-register with the
RPC exception-mapping). REST routes translate to HTTP codes via a 7-line
`_translate()` helper.
- **Independent scopes, no DM fallback**. `find_server_link()` and
`find_user_link()` each query their own table. A user who owns a linked
server does not leak that identity into their DMs.
- **Race-safe token consumption**. Confirm paths do atomic `update_many`
with `usedAt = None` + `expiresAt > now` in the WHERE clause;
`create_*_token` invalidates pending tokens before issuing a new one.
- **Bug fix**: `start_chat_turn` persists the user message via
`append_and_save_message` before enqueueing the executor turn — mirrors
`backend/api/features/chat/routes.py`. The previous `chat_proxy.py`
skipped this and ran the executor with no user message in history.
- **Streaming**. Copilot streaming lives on Redis Streams (persistent,
replayable). The bot subscribes directly with `subscribe_from="0-0"`, so
late subscribers replay the full stream; no HTTP SSE proxy needed.
- **No PII in logs**: logs reference `session_id`, `turn_id`,
`server_id`, and AutoGPT `user_id` (last 8 chars), but never raw
platform user IDs.
- **New pod**. `PlatformLinkingManager` runs as its own `AppProcess` on
port `8009`; client via `get_platform_linking_manager_client()`. The
infra chart lands in
[cloud-infrastructure#310](https://github.com/Significant-Gravitas/AutoGPT_cloud_infrastructure/pull/310).

## Tests
- **Models** (`models_test.py`) — Platform / LinkType enums, request
validation (CreateLinkToken / ResolveServer / BotChat), response
schemas.
- **Helpers** (`links_test.py`) — resolve, token create (both flows, 409
on already-linked), token status (pending / linked / expired /
superseded-with-no-link), token info (404 / 410), confirm (404 / wrong
flow / already used / expired / same-user / other-user), delete authz.
- **AppService wiring** (`manager_test.py`) — `@expose` methods delegate
to helpers; client surface covers bot-facing ops and excludes
user-facing ones.
- **Adversarial** (`manager_test.py`, `routes_test.py`):
- `asyncio.gather` double-confirm with same user and with two different
users — exactly one winner, other gets clean `LinkTokenExpiredError`, no
double `PlatformLink.create`.
  - Server- and user-link confirm races.
- `TokenPath` regex guard: rejects `%24`, URL-encoded path traversal,
>64 chars; accepts `secrets.token_urlsafe` shape.
- DELETE `link_id` with SQL-injection-style and path-traversal inputs
returns 404 via `NotFoundError`.

## Stack
- #12618 — bot service (rebased onto this so it can consume
`PlatformLinkingManagerClient`)
- #12624 — `/link/{token}` frontend page
-
[cloud-infrastructure#310](https://github.com/Significant-Gravitas/AutoGPT_cloud_infrastructure/pull/310)
— Helm chart for `copilot-bot` + new `platform-linking-manager`

Merge order: this → #12618#12624, infra whenever.

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: CodeRabbit <noreply@coderabbit.ai>
2026-04-21 16:01:03 +00:00
Nicholas Tindle
6924cf90a5 fix(frontend/copilot): artifact panel fixes (SECRT-2254/2223/2220/2255/2224/2256/2221) (#12856)
### Why / What / How


https://github.com/user-attachments/assets/ca26e0b0-d35d-4a5b-b95f-2421b9907742


**Why** — The Artifact & Side Task List project
(https://linear.app/autogpt/project/artifact-and-side-task-list-ef863c93da3c)
accumulated seven related bugs in the copilot artifact panel. The user
kept seeing panels stuck open, previews broken, clicks not registering —
each ticket was small but they all lived in the same small surface area,
so one review pass is easier than five.

Closes SECRT-2254, SECRT-2223, SECRT-2220, SECRT-2255, SECRT-2224,
SECRT-2256, SECRT-2221.

**What** — Five independent fixes, each in its own commit, shipped
together:

1. **Fragment-link interceptor + render error boundary** (SECRT-2255
crash when clicking `<a href="#x">` in HTML artifacts). Sandboxed srcdoc
iframes resolve fragment links against the parent's URL, so clicking
`#activation` in a Plotly TOC tried to navigate the copilot page into
the iframe. Inject a click-capture script into every artifact iframe;
also wrap the renderer in `ArtifactErrorBoundary` so any future render
throw surfaces with a copyable error instead of a blank panel.
2. **Close panel on copilot page unmount** (SECRT-2254 / 2223 / 2220 —
panel stays open, reopens on unrelated navigation, opens by default on
session switch). The Zustand store outlived page unmounts, so `isOpen:
true` survived `/profile` → `/home` → back. One `useEffect` cleanup in
`useAutoOpenArtifacts` calls `resetArtifactPanel()` on unmount.
3. **Sync loading flip on Try Again** (SECRT-2224 "try again doesn't do
anything"). Retry was correct but the loading-state flip was deferred to
an effect, so a retry that re-failed was visually indistinguishable from
a no-op. `retry()` now sets `isLoading: true` / `error: null`
synchronously with the click so the skeleton flashes every time.
4. **Pointer capture on resize drag** (SECRT-2256 "can't drag right when
expanded far left, click doesn't stop it"). The sandboxed iframe was
eating `pointermove`/`pointerup` events when the cursor drifted over it,
freezing the drag and never delivering the release. `setPointerCapture`
on the handle routes all subsequent pointer events through it regardless
of what's under the cursor.
5. **Stop size-gating natively-rendered artifacts + cache-bust retry**
(SECRT-2221 "broken hi-res PNG preview"). The blanket >10 MB size gate
pushed large images / videos / PDFs into `download-only`, so clicking a
hi-res PNG offered a download instead of a preview. Split the gate so it
only applies to content we actually render in JS (text/html/code/etc).
Image and video retries also append a cache-bust query so the browser
can't silently reuse a negative-cached failure.

**How** — Five commits, one concern each, preserved in the order they
were written. Every fix lands with a regression test that fails on the
unfixed code and passes after.

### Changes 🏗️

- `iframe-sandbox-csp.ts` + usage sites —
`FRAGMENT_LINK_INTERCEPTOR_SCRIPT` injected into all three srcdoc iframe
templates (HTML artifact, inline HTMLRenderer, React artifact).
- `ArtifactErrorBoundary.tsx` (new) — class error boundary local to the
artifact panel with a copyable error fallback.
- `useAutoOpenArtifacts.ts` — unmount cleanup calls
`resetArtifactPanel()`.
- `useArtifactContent.ts` — `retry()` flips loading state synchronously.
- `ArtifactDragHandle.tsx` — `setPointerCapture` /
`releasePointerCapture`; `touch-action: none`.
- `helpers.ts` — split classifier; `NATIVELY_RENDERED` exempts
image/video/pdf from the size gate.
- `ArtifactContent.tsx` — image/video carry a retry nonce that appends
`?_retry=N` on Try Again.
- Test files — new
`ArtifactErrorBoundary`/`ArtifactDragHandle`/`HTMLRenderer` tests, plus
regression cases added to `ArtifactContent.test.tsx`, `helpers.test.ts`,
`iframe-sandbox-csp.test.ts`, `reactArtifactPreview.test.ts`,
`useAutoOpenArtifacts.test.ts`.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] `pnpm vitest run src/app/\(platform\)/copilot
src/components/contextual/OutputRenderers
src/lib/__tests__/iframe-sandbox-csp.test.ts` — 247/247 pass
  - [x] `pnpm format && pnpm types` clean
- [x] Manual: open the Plotly-style TOC HTML artifact (SECRT-2255
repro), click each anchor — iframe scrolls internally, browser URL bar
stays put
- [x] Manual: open panel → navigate to /profile → navigate back → panel
closed (SECRT-2254)
- [x] Manual: panel open in session A → click different session → panel
closed (SECRT-2223)
- [ ] Manual: simulate a failed artifact fetch → click Try Again →
skeleton flashes before result (SECRT-2224)
- [x] Manual: expand panel to near-full width → drag back right,
crossing over the iframe → drag keeps working and release ends it
(SECRT-2256)
- [x] Manual: upload a ~25 MB PNG → clicking it previews in an `<img>`,
not a download button (SECRT-2221)

Replaces #12836, #12837, #12838, #12839, #12840 — same fixes, bundled
for review.


<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Touches artifact rendering and iframe `srcDoc` generation (including
injected scripts) plus panel state/drag interactions; regressions could
break previews or resizing, but changes are scoped to the copilot
artifact UI with broad test coverage.
> 
> **Overview**
> Improves Copilot’s artifact panel resilience and UX by **resetting
panel state on page unmount/session changes**, making content retries
immediately show the loading skeleton, and fixing resize drags via
pointer capture so iframes can’t “steal” pointer events.
> 
> Hardens artifact rendering by adding a local `ArtifactErrorBoundary`
that reports to Sentry and shows a copyable error fallback instead of a
blank/crashed panel.
> 
> Fixes iframe-based previews by injecting a
`FRAGMENT_LINK_INTERCEPTOR_SCRIPT` into HTML and React artifact `srcDoc`
so `#anchor` clicks scroll within the iframe rather than navigating the
parent URL, and adjusts artifact classification/retry behavior so large
images/videos/PDFs remain previewable and image/video retries cache-bust
failed URLs.
> 
> <sup>Reviewed by [Cursor Bugbot](https://cursor.com/bugbot) for commit
bde37a13fd. Bugbot is set up for automated
code reviews on this repo. Configure
[here](https://www.cursor.com/dashboard/bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-21 15:53:01 +00:00
Nicholas Tindle
07e5a6a9e4 [Snyk] Security upgrade next from 15.4.10 to 15.4.11 (#12715)
![snyk-top-banner](https://res.cloudinary.com/snyk/image/upload/r-d/scm-platform/snyk-pull-requests/pr-banner-default.svg)

### Snyk has created this PR to fix 1 vulnerabilities in the yarn
dependencies of this project.

#### Snyk changed the following file(s):

- `autogpt_platform/frontend/package.json`


#### Note for
[zero-installs](https://yarnpkg.com/features/zero-installs) users

If you are using the Yarn feature
[zero-installs](https://yarnpkg.com/features/zero-installs) that was
introduced in Yarn V2, note that this PR does not update the
`.yarn/cache/` directory meaning this code cannot be pulled and
immediately developed on as one would expect for a zero-install project
- you will need to run `yarn` to update the contents of the
`./yarn/cache` directory.
If you are not using zero-install you can ignore this as your flow
should likely be unchanged.



<details>
<summary>⚠️ <b>Warning</b></summary>

```
Failed to update the yarn.lock, please update manually before merging.
```

</details>



#### Vulnerabilities that will be fixed with an upgrade:

|  | Issue |  
:-------------------------:|:-------------------------
![high
severity](https://res.cloudinary.com/snyk/image/upload/w_20,h_20/v1561977819/icon/h.png
'high severity') | Allocation of Resources Without Limits or Throttling
<br/>[SNYK-JS-NEXT-15921797](https://snyk.io/vuln/SNYK-JS-NEXT-15921797)




---

> [!IMPORTANT]
>
> - Check the changes in this PR to ensure they won't cause issues with
your project.
> - Max score is 1000. Note that the real score may have changed since
the PR was raised.
> - This PR was automatically created by Snyk using the credentials of a
real user.

---

**Note:** _You are seeing this because you or someone else with access
to this repository has authorized Snyk to open fix PRs._

For more information: <img
src="https://api.segment.io/v1/pixel/track?data=eyJ3cml0ZUtleSI6InJyWmxZcEdHY2RyTHZsb0lYd0dUcVg4WkFRTnNCOUEwIiwiYW5vbnltb3VzSWQiOiJmM2NkN2NiMy1iYzU5LTRkMDMtOGExMi0xOTEwMDk4OGQwNmUiLCJldmVudCI6IlBSIHZpZXdlZCIsInByb3BlcnRpZXMiOnsicHJJZCI6ImYzY2Q3Y2IzLWJjNTktNGQwMy04YTEyLTE5MTAwOTg4ZDA2ZSJ9fQ=="
width="0" height="0"/>
🧐 [View latest project
report](https://app.snyk.io/org/significant-gravitas/project/3d924968-0cf3-4767-9609-501fa4962856?utm_source&#x3D;github&amp;utm_medium&#x3D;referral&amp;page&#x3D;fix-pr)
📜 [Customise PR
templates](https://docs.snyk.io/scan-using-snyk/pull-requests/snyk-fix-pull-or-merge-requests/customize-pr-templates?utm_source=github&utm_content=fix-pr-template)
🛠 [Adjust project
settings](https://app.snyk.io/org/significant-gravitas/project/3d924968-0cf3-4767-9609-501fa4962856?utm_source&#x3D;github&amp;utm_medium&#x3D;referral&amp;page&#x3D;fix-pr/settings)
📚 [Read about Snyk's upgrade
logic](https://docs.snyk.io/scan-with-snyk/snyk-open-source/manage-vulnerabilities/upgrade-package-versions-to-fix-vulnerabilities?utm_source=github&utm_content=fix-pr-template)

---

**Learn how to fix vulnerabilities with free interactive lessons:**

🦉 [Allocation of Resources Without Limits or
Throttling](https://learn.snyk.io/lesson/no-rate-limiting/?loc&#x3D;fix-pr)

[//]: #
'snyk:metadata:{"breakingChangeRiskLevel":null,"FF_showPullRequestBreakingChanges":false,"FF_showPullRequestBreakingChangesWebSearch":false,"customTemplate":{"variablesUsed":[],"fieldsUsed":[]},"dependencies":[{"name":"next","from":"15.4.10","to":"15.4.11"}],"env":"prod","issuesToFix":["SNYK-JS-NEXT-15921797"],"prId":"f3cd7cb3-bc59-4d03-8a12-19100988d06e","prPublicId":"f3cd7cb3-bc59-4d03-8a12-19100988d06e","packageManager":"yarn","priorityScoreList":[null],"projectPublicId":"3d924968-0cf3-4767-9609-501fa4962856","projectUrl":"https://app.snyk.io/org/significant-gravitas/project/3d924968-0cf3-4767-9609-501fa4962856?utm_source=github&utm_medium=referral&page=fix-pr","prType":"fix","templateFieldSources":{"branchName":"default","commitMessage":"default","description":"default","title":"default"},"templateVariants":["updated-fix-title","pr-warning-shown"],"type":"auto","upgrade":["SNYK-JS-NEXT-15921797"],"vulns":["SNYK-JS-NEXT-15921797"],"patch":[],"isBreakingChange":false,"remediationStrategy":"vuln"}'

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Patch-level upgrade of a core runtime/build dependency (Next.js) can
affect app rendering/build behavior despite being scoped to
dependency/lockfile changes.
> 
> **Overview**
> Upgrades the frontend framework dependency `next` from `15.4.10` to
`15.4.11` in `package.json`.
> 
> Updates `pnpm-lock.yaml` to reflect the new Next.js version (including
`@next/env`) and re-resolves dependent packages that pin `next` in their
peer/optional dependency graphs (e.g., `@sentry/nextjs`,
`@vercel/analytics`, Storybook Next integration).
> 
> <sup>Reviewed by [Cursor Bugbot](https://cursor.com/bugbot) for commit
dc19e1f178. Bugbot is set up for automated
code reviews on this repo. Configure
[here](https://www.cursor.com/dashboard/bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: snyk-bot <snyk-bot@snyk.io>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-21 15:44:47 +00:00
Zamil Majdy
a098f01bd2 feat(builder): AI chat panel for the flow builder (#12699)
### Why

The flow builder had no AI assistance. Users had to switch to a separate
Copilot session to ask about or modify the agent they were looking at,
and that session had no context on the graph — so the LLM guessed, or
the user had to describe the graph by hand.

### What

An AI chat panel anchored to the `/build` page. Opens with a chat-circle
button (bottom-right), binds to the currently-opened agent, and offers
**only** two tools: `edit_agent` and `run_agent`. Per-agent session is
persisted server-side, so a refresh resumes the same conversation. Gated
behind `Flag.BUILDER_CHAT_PANEL` (default off;
`NEXT_PUBLIC_FORCE_FLAG_BUILDER_CHAT_PANEL=true` to enable locally).

### How

**Frontend — new**:
- `(platform)/build/components/BuilderChatPanel/` — panel shell +
`useBuilderChatPanel.ts` coordinator. Renders the shared Copilot
`ChatMessagesContainer` + `ChatInput` (thought rendering, pulse chips,
fast-mode toggle — all reused, no parallel chat stack). Auto-creates a
blank agent when opened with no `flowID`. Listens for `edit_agent` /
`run_agent` tool outputs and wires them to the builder in-place: edit →
`flowVersion` URL param + canvas refetch; run → `flowExecutionID` URL
param → builder's existing execution-follow UI opens.

**Frontend — touched (minimal)**:
- `copilot/components/CopilotChatActionsProvider` — new `chatSurface:
"copilot" | "builder"` flag so cards can suppress "Open in library" /
"Open in builder" / "View Execution" buttons when the chat is the
builder panel (you're already there).
- `copilot/tools/RunAgent/components/ExecutionStartedCard` — title is
now status-aware (`QUEUED → "Execution started"`, `COMPLETED →
"Execution completed"`, `FAILED → "Execution failed"`, etc.).
- `build/components/FlowEditor/Flow/Flow.tsx` — mount the panel behind
the feature flag.

**Backend — new**:
- `copilot/builder_context.py` — the builder-session logic module. Holds
the tool whitelist (`edit_agent`, `run_agent`), the permissions
resolver, the session-long system-prompt suffix (graph id/name + full
agent-building guide — cacheable across turns), and the per-turn
`<builder_context>` prefix (live version + compact nodes/links
snapshot).
- `copilot/builder_context_test.py` — covers both builders, ownership
forwarding, and cap behavior.

**Backend — touched**:
- `api/features/chat/routes.py` — `CreateSessionRequest` gains
`builder_graph_id`. When set, the endpoint routes through
`get_or_create_builder_session` (keyed on `user_id`+`graph_id`, with a
graph-ownership check). No new route; the former `/sessions/builder` is
folded into `POST /sessions`.
- `copilot/model.py` — `ChatSessionMetadata.builder_graph_id`;
`get_or_create_builder_session` helper.
- `data/graph.py` — `GraphSettings.builder_chat_session_id` (new typed
field; stores the builder-chat session pointer per library agent).
- `api/features/library/db.py` —
`update_library_agent_version_and_settings` preserves
`builder_chat_session_id` across graph-version bumps.
- `copilot/tools/edit_agent.py`, `run_agent.py` — builder-bound guard:
default missing `agent_id` to the bound graph, reject any other id.
`run_agent` additionally inlines `node_executions` into dry-run
responses so the LLM can inspect per-node status in the same turn
instead of a follow-up `view_agent_output`. `wait_for_result` docs now
explain the two dispatch modes.
- `copilot/tools/helpers.py::require_guide_read` — bypassed for
builder-bound sessions (the guide is already in the system-prompt
suffix).
- `copilot/tools/agent_generator/pipeline.py` + `tools/models.py` —
`AgentSavedResponse.graph_version` so the frontend can flip
`flowVersion` to the newly-saved version.
- `copilot/baseline/service.py` + `sdk/service.py` — inject the builder
context suffix into the system prompt and the per-turn prefix into the
current user message.
- `blocks/_base.py` — `validate_data(..., exclude_fields=)` so dry-run
can bypass credential required-checks for blocks that need creds in
normal mode (OrchestratorBlock). `blocks/perplexity.py` override
signature matches.
- `executor/simulator.py` — OrchestratorBlock dry-run iteration cap `1 →
min(original, 10)` so multi-role patterns (Advocate/Critic) actually
close the loop; `manager.py` synthesizes placeholder creds in dry-run so
the block's schema validation passes.

### Session lookup

The builder-chat session pointer lives on
`LibraryAgent.settings.builder_chat_session_id` (typed via
`GraphSettings`). `get_or_create_builder_session` reads/writes it
through `library_db().get_library_agent_by_graph_id` +
`update_library_agent(settings=...)` — no raw SQL or JSON-path filter.
Ownership is enforced by the library-agent query's `userId` filter. The
per-session builder binding still lives on
`ChatSession.metadata.builder_graph_id` (used by
`edit_agent`/`run_agent` guards and the system-prompt injection).

### Scope footnotes

- Feature flag defaults **false**. Rollout gate lives in LaunchDarkly.
- No schema migration required: `builder_chat_session_id` slots into the
existing `LibraryAgent.settings` JSON column via the typed
`GraphSettings` model.
- Commits that address review / CI cycles are interleaved with feature
commits — see the commit log for the per-change rationale.

### Test plan

- [x] `pnpm test:unit` + backend `poetry run test` for new and touched
modules
- [x] Agent-browser pass: panel toggle / auto-create / real-time edit
re-render / real-time exec URL subscribe / queue-while-streaming /
cross-graph reset / hard-refresh session persist
- [x] Codecov patch ≥ 80% on diff

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-21 22:47:23 +07:00
Nicholas Tindle
59273fe6a0 fix(frontend): forward sentry-trace and baggage across API proxy (#12835)
### Why / What / How

**Why:** Every request that went through Next's rewrite proxy broke
distributed tracing. The browser Sentry SDK emitted `sentry-trace` and
`baggage`, but `createRequestHeaders` only forwarded impersonation + API
key, so the backend started a disconnected transaction. The frontend →
backend lineage never appeared in Sentry. Same gap on
direct-from-browser requests: the custom mutator never attached the
trace headers itself, so even non-proxied paths lost the link.

**What:**
- **Server side:** forward `sentry-trace` and `baggage` from
`originalRequest.headers` alongside the existing impersonation/API key
forwarding.
- **Client side:** the custom mutator pulls trace data via
`Sentry.getTraceData()` and attaches it to outgoing headers when running
on the client.

**How:** Inline additions — no new observability module, no new
dependencies beyond `@sentry/nextjs` which the frontend already uses for
Sentry init.

### Changes 🏗️

- `src/lib/autogpt-server-api/helpers.ts` — forward `sentry-trace` +
`baggage` in `createRequestHeaders`.
- `src/app/api/mutators/custom-mutator.ts` — import `@sentry/nextjs`,
attach `Sentry.getTraceData()` on client-side requests.
- `src/app/api/mutators/__tests__/custom-mutator.test.ts` — three new
tests: trace-data present, trace-data empty, server-side no-op.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
- [x] `pnpm vitest run
src/app/api/mutators/__tests__/custom-mutator.test.ts` passes (6/6
locally)
  - [x] `pnpm format && pnpm lint` clean
- [x] `pnpm types` clean for touched files (pre-existing unrelated type
errors on dev are untouched)
- [ ] In a local session with Sentry enabled, a `/copilot` chat turn
produces a distributed trace that spans frontend transaction → backend
transaction (single trace ID in Sentry)

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Low Risk**
> Low risk: header-only changes to request construction for
observability, with added tests; primary risk is unintended header
propagation affecting upstream/proxy behavior.
> 
> **Overview**
> Restores **Sentry distributed tracing continuity** for
frontend→backend calls by propagating `sentry-trace`/`baggage` headers.
> 
> On the client, `customMutator` now reads `Sentry.getTraceData()` and
attaches string trace headers to outgoing requests (guarded for
server-side and older Sentry builds). On the server/proxy path,
`createRequestHeaders` now forwards `sentry-trace` and `baggage` from
the incoming `originalRequest` alongside existing impersonation/API-key
forwarding, with new unit tests covering these cases.
> 
> <sup>Reviewed by [Cursor Bugbot](https://cursor.com/bugbot) for commit
0f6946b776. Bugbot is set up for automated
code reviews on this repo. Configure
[here](https://www.cursor.com/dashboard/bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-21 15:29:19 +00:00
Nicholas Tindle
38c2844b83 feat(admin): Add system diagnostics and execution management dashboard (#11235)
### Changes 🏗️
This PR adds a comprehensive admin diagnostics dashboard for monitoring
system health and managing running executions.


https://github.com/user-attachments/assets/f7afa3ed-63d8-4b5c-85e4-8756d9e3879e


#### Backend Changes:
- **New data layer** (backend/data/diagnostics.py): Created a dedicated
diagnostics module following the established data layer pattern
- get_execution_diagnostics() - Retrieves execution metrics (running,
queued, completed counts)
  - get_agent_diagnostics() - Fetches agent-related metrics
- get_running_executions_details() - Lists all running executions with
detailed info
- stop_execution() and stop_executions_bulk() - Admin controls for
stopping executions

- **Admin API endpoints**
(backend/server/v2/admin/diagnostics_admin_routes.py):
  - GET /admin/diagnostics/executions - Execution status metrics
  - GET /admin/diagnostics/agents - Agent utilization metrics
- GET /admin/diagnostics/executions/running - Paginated list of running
executions
  - POST /admin/diagnostics/executions/stop - Stop single execution
- POST /admin/diagnostics/executions/stop-bulk - Stop multiple
executions
  - All endpoints secured with admin-only access

#### Frontend Changes:
- **Diagnostics Dashboard**
(frontend/src/app/(platform)/admin/diagnostics/page.tsx):
- Real-time system metrics display (running, queued, completed
executions)
  - RabbitMQ queue depth monitoring
  - Agent utilization statistics
  - Auto-refresh every 30 seconds

- **Execution Management Table**
(frontend/src/app/(platform)/admin/diagnostics/components/ExecutionsTable.tsx):
- Displays running executions with: ID, Agent Name, Version, User
Email/ID, Status, Start Time
  - Multi-select functionality with checkboxes
  - Individual stop buttons for each execution
  - "Stop Selected" and "Stop All" bulk actions
  - Confirmation dialogs for safety
  - Pagination for handling large datasets
  - Toast notifications for user feedback

#### Security:
- All admin endpoints properly secured with requires_admin_user
decorator
- Frontend routes protected with role-based access controls
- Admin navigation link only visible to admin users

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  
  - [x] Verified admin-only access to diagnostics page
  - [x] Tested execution metrics display and auto-refresh
  - [x] Confirmed RabbitMQ queue depth monitoring works
  - [x] Tested stopping individual executions
  - [x] Tested bulk stop operations with multi-select
  - [x] Verified pagination works for large datasets
  - [x] Confirmed toast notifications appear for all actions

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
(no changes needed)
- [x] `docker-compose.yml` is updated or already compatible with my
changes (no changes needed)
- [x] I have included a list of my configuration changes in the PR
description (no config changes required)



<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Adds new admin-only endpoints that can stop, requeue, and bulk-mark
executions as `FAILED`, plus schedule deletion, which can directly
impact production workload and data integrity if misused or buggy.
> 
> **Overview**
> Introduces a **System Diagnostics** admin feature spanning backend +
frontend to monitor execution/schedule health and perform remediation
actions.
> 
> On the backend, adds a new `backend/data/diagnostics.py` data layer
and `diagnostics_admin_routes.py` with admin-secured endpoints to fetch
execution/agent/schedule metrics (including RabbitMQ queue depths and
invalid-state detection), list problem executions/schedules, and perform
bulk operations like `stop`, `requeue`, and `cleanup` (marking
orphaned/stuck items as `FAILED` or deleting orphaned schedules). It
also extends `get_graph_executions`/`get_graph_executions_count` with
`execution_ids` filtering, pagination, started/updated time filters, and
configurable ordering to support efficient bulk/admin queries.
> 
> On the frontend, adds an admin diagnostics page with summary cards and
tables for executions and schedules (tabs for
orphaned/failed/long-running/stuck-queued/invalid, plus confirmation
dialogs for destructive actions), wires it into admin navigation, and
adds comprehensive unit tests for both the new API routes and UI
behavior.
> 
> <sup>Reviewed by [Cursor Bugbot](https://cursor.com/bugbot) for commit
15b9ed26f9. Bugbot is set up for automated
code reviews on this repo. Configure
[here](https://www.cursor.com/dashboard/bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <ntindle@users.noreply.github.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
2026-04-21 15:28:44 +00:00
Zamil Majdy
24850e2a3e feat(backend/autopilot): stream extended_thinking on baseline via OpenRouter (#12870)
### Why / What / How

**Why:** Fast-mode autopilot never renders a Reasoning block. The
frontend already has `ReasoningCollapse` wired up and the wire protocol
already carries `StreamReasoning*` events (landed for SDK mode in
#12853), but the baseline (OpenRouter OpenAI-compat) path never asks
Anthropic for extended thinking and never parses reasoning deltas off
the stream. Result: users on fast/standard get a good answer with no
visible chain-of-thought, while SDK users see the full Reasoning
collapse.

**What:** Plumb reasoning end-to-end through the baseline path by opting
into OpenRouter's non-OpenAI `reasoning` extension, parsing the
reasoning delta fields off each chunk, and emitting the same
`StreamReasoningStart/Delta/End` events the SDK adapter already uses.

**How:**
- **New config:** `baseline_reasoning_max_tokens` (default 8192; 0
disables). Sent as `extra_body={"reasoning": {"max_tokens": N}}` only on
Anthropic routes — other providers drop the field, and
`is_anthropic_model()` already gates this.
- **Delta extraction:** `_extract_reasoning_delta()` handles all three
OpenRouter/provider variants in priority order — legacy
`delta.reasoning` (string), DeepSeek-style `delta.reasoning_content`,
and the structured `delta.reasoning_details` list (text/summary entries;
encrypted or unknown entries are skipped).
- **Event emission:** Reasoning uses the same state-machine rules the
SDK adapter uses — a text delta or tool_use delta arriving mid-stream
closes the open reasoning block first, so the AI SDK v5 transport keeps
reasoning / text / tool-use as distinct UI parts. On stream end, any
still-open reasoning block gets a matching `reasoning-end` so a
reasoning-only turn still finalises the frontend collapse.
- **Scope:** Live streaming only. Reasoning is not persisted to
`ChatMessage` rows or the transcript builder in this PR (SDK path does
so via `content_blocks=[{type: 'thinking', ...}]`, but that round-trip
requires Anthropic signature plumbing baseline doesn't have today).
Reload will still not show reasoning on baseline sessions — can follow
up if we decide it's worth the signature handling.

### Changes

- `backend/copilot/config.py` — new `baseline_reasoning_max_tokens`
field.
- `backend/copilot/baseline/service.py` — new
`_extract_reasoning_delta()` helper; reasoning block state on
`_BaselineStreamState`; `reasoning` gated into `extra_body`; chunk loop
emits `StreamReasoning*` events with text/tool_use transition rules;
stream-end closes any open reasoning block.
- `backend/copilot/baseline/service_unit_test.py` — 11 new tests
covering extractor variants (legacy string, deepseek alias, structured
list with text/summary aliases, encrypted-skip, empty), paired event
ordering (reasoning-end before text-start), reasoning-only streams, and
that the `reasoning` request param is correctly gated by model route
(Anthropic vs non-Anthropic) and by the config flag.

### Checklist

For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
- [x] `poetry run pytest backend/copilot/baseline/service_unit_test.py
backend/copilot/baseline/transcript_integration_test.py` — 103 passed
- [ ] Manual: with `CHAT_USE_CLAUDE_AGENT_SDK=false` and
`CHAT_MODEL=anthropic/claude-sonnet-4-6`, send a multi-step prompt on
fast mode and confirm a Reasoning collapse appears alongside the final
text
- [ ] Manual: flip `CHAT_BASELINE_REASONING_MAX_TOKENS=0` and confirm
baseline responses revert to text-only (no reasoning param, no reasoning
UI)
- [ ] Manual: with a non-Anthropic baseline model (`openai/gpt-4o`),
confirm the request does NOT include `reasoning` and nothing regresses

For configuration changes:
- [x] `.env.default` is compatible — new setting falls back to the
pydantic default
2026-04-21 21:05:00 +07:00
Zamil Majdy
e17e9f13c4 fix(backend/copilot): reduce SDK + baseline prompt cache waste (#12866)
## Summary

Four cost-reduction changes for the copilot feature. Consolidated into
one PR at user request; each commit is self-contained and bisectable.

### 1. SDK: full cross-user cache on every turn (CLI 2.1.116 bump)
Previous behavior: CLI 2.1.97 crashed when `excludeDynamicSections=True`
was combined with `--resume`, so the code fell back to a raw
`system_prompt` string on resume, losing Claude Code's default prompt
and all cache markers. Every Turn 2+ of an SDK session wrote ~33K tokens
to cache instead of reading.

Fix: install `@anthropic-ai/claude-code@2.1.116` in the backend Docker
image and point the SDK at it via
`CHAT_CLAUDE_AGENT_CLI_PATH=/usr/bin/claude`. CLI 2.1.98+ fixes the
crash, so we can use the preset with `exclude_dynamic_sections=True` on
every turn — Turn 1, 2, 3+ all share the same static prefix and hit the
**cross-user** prompt cache.

**Local dev requirement:** if `CHAT_CLAUDE_AGENT_CLI_PATH` is unset, the
bundled 2.1.97 fallback will crash on `--resume`. Install the CLI
globally (`npm install -g @anthropic-ai/claude-code@2.1.116`) or set the
env var.

### 2. Baseline: add `cache_control` markers (commit `756b3ecd9` +
follow-ups)
Baseline path had zero `cache_control` across `backend/copilot/**`.
Every turn was full uncached input (~18.6K tokens, ~$0.058). Two
ephemeral markers — on the system message (content-blocks form) and the
last tool schema — plus `anthropic-beta: prompt-caching-2024-07-31` via
`extra_headers` as defense-in-depth. Helpers split into `_mark_tools_*`
(precomputed once per session) and `_mark_system_*` (per-round, O(1)).
Repeat hellos: ~$0.058 → ~$0.006.

### 3. Drop `get_baseline_supplement()` (commit `6e6c4d791`)
`_generate_tool_documentation()` emitted ~4.3K tokens of `(tool_name,
description)` pairs that exactly duplicated the tools array already in
the same request. Deleted. `SHARED_TOOL_NOTES` (cross-tool workflow
rules) is preserved. Baseline "hello" input: ~18.7K → ~14.4K tokens.

### 4. Langfuse "CoPilot Prompt" v26 (published under `review` label)
Separate, out-of-repo change. v25 had three duplicate "Example Response"
blocks + a 10-step "Internal Reasoning Process" section. v26 collapses
to one example + bullet-form reasoning. Char count 20,481 → 7,075 (rough
4 chars/token → ~5,100 → ~1,770 tokens).

- v26 is published with label `review` (NOT `production`); v25 remains
active.
- Promote via `mcp__langfuse__updatePromptLabels(name="CoPilot Prompt",
version=26, newLabels=["production"])` after smoke-test.
- Rollback: relabel v25 `production`.

## Test plan
- [x] Unit tests for `_build_system_prompt_value` (fresh vs resumed
turns emit identical preset dict)
- [x] SDK compat tests pass including
`test_bundled_cli_version_is_known_good_against_openrouter`
- [x] `cli_openrouter_compat_test.py` passes against CLI 2.1.116
(locally verified with
`CHAT_CLAUDE_AGENT_CLI_PATH=/opt/homebrew/bin/claude`)
- [x] 8 new `_mark_*` unit tests + identity regression test for
`_fresh_*` helpers
- [x] `SHARED_TOOL_NOTES` public-constant test passes; 5 old tool-docs
tests removed
- [ ] **Manual cost verification (commit 1):** send two consecutive SDK
turns; Turn 2 and Turn 3 should both show `cacheReadTokens` ≈ 33K (full
cross-user cache hits).
- [ ] **Manual cost verification (commit 2):** send two "hello" turns on
baseline <5 min apart; Turn 2 reports `cacheReadTokens` ≈ 18K and cost ≈
$0.006.
- [ ] **Regression sweep for commit 3:** one turn per tool family —
`search_agents`, `run_agent`,
`add_memory`/`forget_memory`/`search_memory`, `search_docs`,
`read_workspace_file` — to verify no tool-selection regression from
dropping the prose tool docs.
- [ ] **Langfuse v26 smoke test:** 5-10 varied turns after relabelling
to `production`; compare responses vs v25 for regression on persona,
concision, capability-gap handling, credential security flows.

## Deployment notes
- Production Docker image now installs CLI 2.1.116 (~20 MB added).
- `CHAT_CLAUDE_AGENT_CLI_PATH=/usr/bin/claude` set in the Dockerfile;
runtime can override via env.
- First deploy after this merge needs a fresh image rebuild to pick up
the new CLI.
2026-04-21 16:34:10 +07:00
Zamil Majdy
f238c153a5 fix(backend/copilot): release session cluster lock on completion (#12867)
## Summary

Fixes a bug where a chat session gets silently stuck after the user
presses Stop mid-turn.

**Root cause:** the cancel endpoint marks the session `failed` after
polling 5s, but the cluster lock held by the still-running task is only
released by `on_run_done` when the task actually finishes. If the task
hangs past the 5s poll (slow LLM call, agent-browser step, etc.), the
lock lingers for up to 5 min — `stream_chat_post`'s `is_turn_in_flight`
check sees the flipped meta (`failed`) and enqueues a new turn, but the
run handler sees the stale lock and drops the user's message at
`manager.py:379` (`reject+requeue=False`). The new SSE stream hangs
until its 60s idle timeout.

### Fix

Two cooperating changes:

1. **`mark_session_completed` force-releases the cluster lock** in the
same transaction that flips status to `completed`/`failed`.
Unconditional delete — by the time we're declaring the session dead, we
don't care who the current lock holder is; the lock has to go so the
next enqueued turn can acquire. This is what closes the stuck-session
window.
2. **`ClusterLock.release()` is now owner-checked** (Lua CAS — `GET ==
token ? DEL : noop` atomically). Force-release means another pod may
legitimately own the key by the time the original task's `on_run_done`
eventually fires. Without the CAS, that late `release()` would wipe the
successor's lock. With it, the late `release()` is a safe no-op when the
owner has changed.

Together: prompt release on completion (via force-delete) + safe cleanup
when on_run_done catches up (via CAS). That re-syncs the API-level
`is_turn_in_flight` check with the actual lock state, so the contention
window disappears.

No changes to the worker-level contention handler: `stream_chat_post`
already queues incoming messages into the pending buffer when a turn is
in flight (via `queue_pending_for_http`). With these fixes, the worker
never sees contention in the common case; if it does (true multi-pod
race), the pre-existing `reject+requeue=False` behaviour still applies —
we'll revisit that path with its own PR if it becomes a production
symptom.

### Verification

- Reproduced the original stuck-session symptom locally (Stop mid-turn →
send new message → backend logs `Session … already running on pod …`,
user message silently lost, SSE stream idle 60s then closes).
- After the fix: cancel → new message → turn starts normally (lock
released by `mark_session_completed`).
- `poetry run pyright` — 0 errors on edited files.
- `pytest backend/copilot/stream_registry_test.py
backend/executor/cluster_lock_test.py` — 33 passed (includes the
successor-not-wiped test).

## Changes

- `autogpt_platform/backend/backend/copilot/executor/utils.py` — extract
`get_session_lock_key(session_id)` helper so the lock-key format has a
single source of truth.
- `autogpt_platform/backend/backend/copilot/executor/manager.py` — use
the helper where the cluster lock is created.
- `autogpt_platform/backend/backend/copilot/stream_registry.py` —
`mark_session_completed` deletes the lock key after the atomic status
swap (force-release).
- `autogpt_platform/backend/backend/executor/cluster_lock.py` —
`ClusterLock.release()` (sync + async) uses a Lua CAS to only delete
when `GET == token`, protecting against wiping a successor after a
force-release.

## Test plan

- [ ] Send a message in /copilot that triggers a long turn (e.g.
`run_agent`), press Stop before it finishes, then send another message.
Expect: new turn starts promptly (no 5-min wait for lock TTL).
- [ ] Happy path regression — send a normal message, verify turn
completes and the session lock key is deleted after completion.
- [ ] Successor protection — unit test
`test_release_does_not_wipe_successor_lock` covers: A acquires, external
DEL, B acquires, A.release() is a no-op, B's lock intact.
2026-04-21 16:27:01 +07:00
Zamil Majdy
01f1289aac feat(copilot): real OpenRouter cost + cost-based rate limits (percent-only public API) (#12864)
## Why

After d7653acd0 removed cost estimation, most baseline turns log with
`tracking_type="tokens"` and no authoritative USD figure (see: dashboard
flipped from `cost_usd` to `tokens` after 4/14/2026). Rate-limit
counters were also token-weighted with hand-rolled cache discounts
(cache_read @ 10%, cache_create @ 25%) and a 5× Opus multiplier — a
proxy for cost that drifts from real OpenRouter billing.

This PR wires real generation cost from OpenRouter into both the
cost-tracking log and the rate limiter, and hides raw spend figures from
the user-facing API so clients can't reverse-engineer per-turn cost or
platform margins.

## What

1. **Real cost from OpenRouter** — baseline passes `extra_body={"usage":
{"include": True}}` and reads `chunk.usage.cost` from the final
streaming chunk. `x-total-cost` header path removed. Missing cost logs
an error and skips the counter update (vs the old estimator that
silently under-counted).
2. **Cost-based rate limiting** — `record_token_usage(...)` →
`record_cost_usage(cost_microdollars)`. The weighted-token math, cache
discount factors, and `_OPUS_COST_MULTIPLIER` are gone; real USD already
reflects model + cache pricing.
3. **Redis key migration** — `copilot:usage:*` → `copilot:cost:*` so
stale token counters can't be misinterpreted as microdollars.
4. **LD flags + config** — renamed to
`copilot-daily-cost-limit-microdollars` /
`copilot-weekly-cost-limit-microdollars` (unit in the LD key so values
can't accidentally be set in dollars or cents).
5. **Public `/usage` hides raw $$** — new `CoPilotUsagePublic` /
`UsageWindowPublic` schemas expose only `percent_used` (0-100) +
`resets_at` + `tier` + `reset_cost`. Admin endpoint keeps raw
microdollars for debugging.
6. **Admin API contract** — `UserRateLimitResponse` fields renamed
`daily/weekly_token_limit` → `daily/weekly_cost_limit_microdollars`,
`daily/weekly_tokens_used` → `daily/weekly_cost_used_microdollars`.
Admin UI displays `$X.XX`.

## How

- `baseline/service.py` — pass `extra_body`, extract cost from
`chunk.usage.cost`, drop the `x-total-cost` header fallback entirely.
- `rate_limit.py` — rewritten around `record_cost_usage`,
`check_rate_limit(daily_cost_limit, weekly_cost_limit)`, new Redis key
prefix. Adds `CoPilotUsagePublic.from_status()` projector for the public
API.
- `token_tracking.py` — converts `cost_usd` → microdollars via
`usd_to_microdollars` and calls `record_cost_usage` only when cost is
present.
- `sdk/service.py` — deletes `_OPUS_COST_MULTIPLIER` and simplifies
`_resolve_model_and_multiplier` to `_resolve_sdk_model_for_request`.
- Chat routes: `/usage` and `/usage/reset` return `CoPilotUsagePublic`.
Internal server-side limit checks still use the raw microdollar
`CoPilotUsageStatus`.
- Admin routes: unchanged response shape (renamed fields only).
- Frontend: `UsagePanelContent`, `UsageLimits`, `CopilotPage`,
`BriefingTabContent`, `credits/page.tsx` consume the new public schema
and render "N% used" + progress bar. Admin `RateLimitDisplay` /
`UsageBar` keep `$X.XX`. Helper `formatMicrodollarsAsUsd` retained for
admin use.
- Tests + snapshots rewritten; new assertions explicitly check that raw
`used`/`limit` keys are absent from the public payload.

## Deploy notes

1. **Before rolling this out, create the new LD flags:**
`copilot-daily-cost-limit-microdollars` (default `500000`) and
`copilot-weekly-cost-limit-microdollars` (default `2500000`). Old
`copilot-*-token-limit` flags can stay in LD for rollback.
2. **One-time Redis cleanup (optional):** token-based counters under
`copilot:usage:*` are orphaned and will TTL out within 7 days. Safe to
ignore or delete manually.

## Test plan

- [x] `poetry run test` — all impacted backend tests pass (182/182 in
targeted scope)
- [x] `pnpm test:unit` — all 1628 integration tests pass
- [x] `poetry run format` / `pnpm format` / `pnpm types` clean
- [x] Manual sanity against dev env — Baseline turn logged $0.1221 for
40K/139 tokens on Sonnet 4 (matches expected pricing)
- [ ] `/pr-test --fix` end-to-end against local native stack
2026-04-21 14:34:43 +07:00
Zamil Majdy
343222ace1 feat(platform): defer paid-to-paid subscription downgrades + cancel-pending flow (#12865)
### Why / What / How

**Why:** Only downgrades to FREE were scheduled at period end; paid→paid
downgrades (e.g. BUSINESS→PRO) applied immediately via Stripe proration.
The asymmetry meant users lost their higher tier mid-cycle in exchange
for a Stripe credit voucher only redeemable on a future subscription — a
confusing pattern that produces negative-value paths for users actually
cancelling. There was also no way to cancel a pending downgrade or
paid→FREE cancellation once scheduled.

**What:** Standardize on "upgrade = immediate, downgrade = next cycle"
and let users cancel a pending change by clicking their current tier.
Harden the new code against conflicting subscription state, concurrent
tab races, flaky Stripe calls, and hot-path latency regressions.

**How:**

Subscription state machine:
- **Upgrade** (PRO→BUSINESS) — `stripe.Subscription.modify` with
immediate proration (unchanged). If a downgrade schedule is already
attached, release it first so the upgrade wins.
- **Paid→paid downgrade** (BUSINESS→PRO) — creates a
`stripe.SubscriptionSchedule` with two phases (current tier until
`current_period_end`, target tier after). No mid-cycle tier demotion.
Defensive pre-clear: existing schedule → release;
`cancel_at_period_end=True` → set to False.
- **Paid→FREE** — unchanged: `cancel_at_period_end=True`.
- **Same-tier update** — reuses the existing `POST
/credits/subscription` route. When `target_tier == current_tier`,
backend calls `release_pending_subscription_schedule` (idempotent) and
returns status. No dedicated cancel-pending endpoint — "Keep my current
tier" IS the cancel operation.
- `release_pending_subscription_schedule` is idempotent on
terminal-state schedules and clears both `schedule` and
`cancel_at_period_end` atomically per call.

API surface:
- New fields on `SubscriptionStatusResponse`: `pending_tier` +
`pending_tier_effective_at` (pulled from the schedule's next-phase
`start_date` so dashboard-authored schedules report the correct
timestamp).
- `POST /credits/subscription` now returns `SubscriptionStatusResponse`
(previously `SubscriptionCheckoutResponse`); the response still carries
`url` for checkout flows and adds the status fields inline.
- `get_pending_subscription_change` is cached with a 30s TTL — avoids
hammering Stripe on every home-page load.
- Webhook dispatches
`subscription_schedule.{released,completed,updated}` through the main
`sync_subscription_from_stripe` flow so both event sources converge to
the same DB state.

Implementation notes:
- New Stripe calls use native async (`stripe.Subscription.list_async`
etc.) and typed attribute access — no `run_in_threadpool` wrapping in
the new helpers.
- Shared `_get_active_subscription` helper collapses the "list
active/trialing subs, take first" pattern used by 4 callers.

Frontend:
- `PendingChangeBanner` sub-component above the tier grid with formatted
effective date + "Keep [CurrentTier]" button. `aria-live="polite"` for
screen readers; locale pinned to `en-US` to avoid SSR/CSR hydration
mismatch.
- "Keep [CurrentTier]" also available as a button on the current tier
card.
- Other tier buttons disabled while a change is pending — user must
resolve pending first to prevent stacked schedules.
- `cancelPendingChange` reuses `useUpdateSubscriptionTier` with `tier:
current_tier`; awaits `refetch()` on both success and error paths so the
UI reconciles even if the server succeeded but the client didn't receive
the response.

### Changes

**Backend (`credit.py`, `v1.py`)**
- Tier-ordering helpers (`is_tier_upgrade`/`is_tier_downgrade`).
- `modify_stripe_subscription_for_tier` routes downgrades through
`_schedule_downgrade_at_period_end`; upgrade path releases any pending
schedule first.
- `_schedule_downgrade_at_period_end` defensively releases pre-existing
schedules and clears `cancel_at_period_end` before creating the new
schedule.
- `release_pending_subscription_schedule` idempotent on terminal-state
schedules; logs partial-failure outcomes.
- `_next_phase_tier_and_start` returns both tier and phase-start
timestamp; warns on unknown prices.
- `get_pending_subscription_change` cached (30s TTL), narrow exception
handling.
- `sync_subscription_schedule_from_stripe` delegates to
`sync_subscription_from_stripe` for convergence with the main webhook
path.
- Shared `_get_active_subscription` +
`_release_schedule_ignoring_terminal` helpers.
- `POST /credits/subscription` absorbs the same-tier "cancel pending
change" branch.

**Frontend (`SubscriptionTierSection/*`)**
- `PendingChangeBanner` new sub-component (a11y, locale-pinned date,
paid→FREE vs paid→paid copy split, non-null effective-date assertion, no
`dark:` utilities).
- "Keep [CurrentTier]" button on current tier card.
- `useSubscriptionTierSection` — `cancelPendingChange` reuses the
update-tier mutation.
- Copy: downgrade dialog + status hint updated.
- `helpers.ts` extracted from the main component.

**Tests**
- Backend: +24 tests (95/95 passing): upgrade-releases-pending-schedule,
schedule-releases-existing-schedule, cancel-at-period-end collision,
terminal-state release idempotency, unknown-price logging, status
response population, same-tier-POST-with-pending, webhook delegation.
- Frontend: +5 integration tests (21/21 passing): banner render/hide,
Keep-button click from banner + current card, paid→paid dialog copy.

### Checklist

- [x] Backend unit tests: 95 pass
- [x] Frontend integration tests: 21 pass
- [x] `poetry run format` / `poetry run lint` clean
- [x] `pnpm format` / `pnpm lint` / `pnpm types` clean
- [ ] Manual E2E on live Stripe (dev env) — pending deploy: BUSINESS→PRO
creates schedule, DB tier unchanged until period end
- [ ] Manual E2E: "Keep BUSINESS" in banner releases schedule
- [ ] Manual E2E: cancel pending paid→FREE flips `cancel_at_period_end`
back to false
- [ ] Manual E2E: BUSINESS→PRO (scheduled) then attempt BUSINESS→FREE
clears the PRO schedule, sets cancel_at_period_end
- [ ] Manual E2E: BUSINESS→PRO (scheduled) then upgrade back to BUSINESS
releases the schedule
2026-04-21 14:01:09 +07:00
Zamil Majdy
a8226af725 fix(copilot): dedupe tool row, lift bash_exec timeout, Stop+resend recovery (#12862)
Closes #12861 · [OPEN-3096](https://linear.app/autogpt/issue/OPEN-3096)

## Why

Four related copilot UX / stability issues surfaced on dev once action
tools started rendering inline in the chat (see #12813):

### 1. Duplicate bash_exec row

`GenericTool` rendered two rows saying the same thing for every
completed tool call — a muted subtitle line ("Command exited with code
1" / "Ran: sleep 20") **and** a `ToolAccordion` with the command echoed
in its description. Previously hidden inside the "Show reasoning" /
"Show steps" collapse, now visibly duplicated.

### 2. `bash_exec` capped at 120s via advisory text

The tool schema said `"Max seconds (default 30, max 120)"`; the model
obeyed, so long-running scripts got clipped at 120s with a vague `Timed
out after 120s` even though the E2B sandbox has no such limit. Confirmed
via Langfuse traces — the model picks `120` for long scripts because
that's what the schema told it the max was. E2B path never had a
server-side clamp.

Originally added in #12103 (default 30) and tightened to "max 120"
advisory in #12398 (token-reduction pass).

### 3. 30s default was too aggressive

`pip install`, small data-processing scripts, etc. routinely cross 30s
and got killed before the model thought to retry with a bigger timeout.

### 4. Stop + edit + resend → "The assistant encountered an error"
([OPEN-3096](https://linear.app/autogpt/issue/OPEN-3096))

Two independent bugs both land on the same banner — fixing only one
leaves the other visible on the next action.

**4a. Stream lock never released on Stop** *(the error in the ticket
screenshot)*. The executor's `async for chunk in
stream_and_publish(...)` broke out on `cancel.is_set()` without calling
`aclose()` on the wrapper. `async for` does NOT auto-close iterators on
`break`, so `stream_chat_completion_sdk` stayed suspended at its current
`await` — still holding the per-session Redis lock (TTL 120s) until GC
eventually closed it. The next `POST /stream` hit `lock.try_acquire()`
at
[sdk/service.py](autogpt_platform/backend/backend/copilot/sdk/service.py)
and yielded `StreamError("Another stream is already active for this
session. Please wait or stop it.")`. The `except GeneratorExit →
lock.release()` handler written exactly for this case never fired
because nothing sent GeneratorExit.

**4b. Orphan `tool_use` after stop-mid-tool.** Even with the lock
released, the stop path persists the session ending on an assistant row
whose `tool_calls` have no matching `role="tool"` row. On the next turn,
`_session_messages_to_transcript` hands Claude CLI `--resume` a JSONL
with a `tool_use` and no paired `tool_result`, and the SDK raises a
vague error — same banner. The ticket's "Open questions" explicitly
flags this.

## What

**Frontend — `GenericTool.tsx`** split responsibilities between the two
rows so they don't duplicate:
- **Subtitle row** (always visible, muted): *what ran* — `Ran: sleep
120`. Never the exit code.
- **Accordion description**: *how it ended* — `completed` / `status code
127 · bash: missing-bin: command not found` / `Timed out after 120s` /
(fallback to command preview for legacy rows missing `exit_code` /
`timed_out`). Pulled from the first non-empty line of `stdout` /
`stderr` when available.
- **Expanded accordion**: full command + stdout + stderr code blocks
(unchanged).

**Backend — `bash_exec.py`**:
- Drop the "max 120" advisory from the schema description.
- Bump default `timeout: 30 → 120`.
- Clean up the result message — `"Command executed with status code 0"`
(no "on E2B", no parens).

**Backend — `executor/processor.py` + `stream_registry.py` (OPEN-3096
#4a)**: wrap the consumer `async for` in `try/finally: await
stream.aclose()`. Close now propagates through `stream_and_publish` into
`stream_chat_completion_sdk`, whose existing `except GeneratorExit →
lock.release()` releases the Redis lock immediately on cancel. Stream
types tightened to `AsyncGenerator[StreamBaseResponse, None]` so the
defensive `getattr(stream, "aclose", None)` goes away.

**Backend — `session_cleanup.py` (OPEN-3096 #4b)**: new
`prune_orphan_tool_calls()` helper walks the trailing session tail and
drops any trailing assistant row whose `tool_calls` have unresolved ids
(plus everything after it) and any trailing `STOPPED_BY_USER_MARKER`
system-stop row. Single backward pass — tolerates the marker being
present or absent. Called from the existing turn-start cleanup in both
`sdk/service.py` and `baseline/service.py`; takes an optional
`log_prefix` so both paths emit the same INFO log when something was
popped. In-memory only — the DB save path is append-only via
`start_sequence`.

## Test plan

- [x] `pnpm exec vitest run src/app/(platform)/copilot/tools/GenericTool
src/app/(platform)/copilot/components/ChatMessagesContainer` — 105 pass
(6 new for GenericTool subtitle/description variants + legacy-fallback
case).
- [x] `pnpm format` / `pnpm lint` / `pnpm types` — clean.
- [x] `poetry run pytest
backend/copilot/sdk/session_persistence_test.py` — 17 pass (6 + 3 new
covering the orphan-tool-call prune and its optional-log-prefix branch).
- [x] `poetry run pytest backend/copilot/stream_registry_test.py
backend/copilot/executor/processor_test.py` — 19 pass (2 for aclose
propagation on the `stream_and_publish` wrapper, 2 for `_execute_async`
aclose propagation on both exit paths, 1 for publish_chunk RedisError
warning ladder).
- [x] `poetry run ruff check` / `poetry run pyright` on touched files —
clean.
- [x] Manual: fire a `bash_exec` — one labelled row, accordion
description reads sensibly (`completed` / `status code 1 · …` / `Timed
out after 120s`).
- [x] Manual: script that needs >120s — no longer clipped.
- [x] Manual: Stop mid-tool + edit + resend — Autopilot resumes without
"Another stream is already active" and without the vague SDK error.

## Scope note

Does not touch `splitReasoningAndResponse` — re-collapsing action tools
back into "Show steps" is #12813's responsibility.
2026-04-21 10:18:52 +07:00
Ubbe
f06b5293de fix(frontend/library): compute monthly spend for AgentBriefingPanel (#12854)
### Why / What / How

<img width="900" alt="Screenshot 2026-04-20 at 19 52 22"
src="https://github.com/user-attachments/assets/c30d5f18-2842-4a8a-ac3d-5bfee18fcd56"
/>

**Why:** The "Spent this month" tile in the Agent Briefing Panel on the
Library page always showed `$0`, even for users with real execution
usage. The tile is meant to give a quick sense of monthly spend across
all agents.

**What:** Compute `monthlySpend` from actual execution data and format
it as currency.

**How:**
- `useLibraryFleetSummary` now sums `stats.cost` (cents) across every
execution whose `started_at` falls within the current calendar month.
Previously `monthlySpend` was hardcoded to `0`.
- `FleetSummary.monthlySpend` is documented as being in cents
(consistent with backend + `formatCents`).
- `StatsGrid` now uses `formatCents` from the copilot usage helpers to
render the tile (e.g. `$12.34` instead of the broken `$0`).

### Changes 🏗️

-
`autogpt_platform/frontend/src/app/(platform)/library/hooks/useLibraryFleetSummary.ts`:
aggregate `stats.cost` across executions started in the current calendar
month; add `toTimestamp` and `startOfCurrentMonth` helpers.
-
`autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/StatsGrid.tsx`:
format the "Spent this month" tile via shared `formatCents` helper.
- `autogpt_platform/frontend/src/app/(platform)/library/types.ts`:
document that `FleetSummary.monthlySpend` is in cents.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
- [ ] Load `/library` with the `AGENT_BRIEFING` flag enabled and at
least one completed execution in the current month — the "Spent this
month" tile shows the correct cumulative cost.
  - [ ] With no executions this month, the tile shows `$0.00`.
- [ ] Type-check (`pnpm types`), lint (`pnpm lint`), and integration
tests (`pnpm test:unit`) pass locally.

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-20 20:28:47 +07:00
Zamil Majdy
70b591d74f fix(copilot): persist reasoning, split steps/reasoning UX, fix mid-turn promote stream stall (#12853)
## Why

Four related issues that surfaced when queued follow-ups hit an
extended_thinking turn:

1. **Mid-turn promote stalled the SSE stream.** `pollBackendAndPromote`
used `setMessages((prev) => [...prev, bubble])` — Vercel AI SDK's
`useChat` streams SSE deltas into `messages[-1]`, so once a user bubble
ended up there, every subsequent chunk silently landed on the wrong
message. Chat sat frozen until a page refresh, even though the backend's
stream completed cleanly.
2. **Thinking-only final turn looked identical to a frozen UI.** When
Claude's last LLM call after a tool_result produced only a
`ThinkingBlock` (no `TextBlock`, no `ToolUseBlock`), the response
adapter silently dropped it and the UI hung on "Thought for Xs" with no
response text.
3. **Reasoning was invisible.** `ThinkingBlock` was dropped live and
never persisted in a way the frontend could render — sessions on reload
/ shared links showed no thinking, a confusing UX gap ("display for
nothing").
4. **Cross-pod Redis replay dropped reasoning events.** The
`stream_registry._reconstruct_chunk` type map had no entries for
`reasoning-*` types, so any client that subscribed mid-stream (share,
reload, cross-pod) silently dropped them with `Unknown chunk type:
reasoning-delta`.

## What

### Mid-turn promote — splice before the trailing assistant

In `useCopilotPendingChips.ts::pollBackendAndPromote`:

```ts
setMessages((prev) => {
  const bubble = makePromotedUserBubble(drained, "midturn", crypto.randomUUID());
  const lastIdx = prev.length - 1;
  if (lastIdx >= 0 && prev[lastIdx].role === "assistant") {
    return [...prev.slice(0, lastIdx), bubble, prev[lastIdx]];
  }
  return [...prev, bubble];
});
```

Streaming assistant stays at `messages[-1]`, AI SDK deltas keep routing
correctly. `useHydrateOnStreamEnd` snaps the bubble to the DB-canonical
position when the stream ends.

### Reasoning — end-to-end visibility (live + persisted)

- **Wire protocol**: new `StreamReasoningStart` / `StreamReasoningDelta`
/ `StreamReasoningEnd` events matching AI SDK v5's `reasoning-*` wire
names, so `useChat` accumulates them into a `type: 'reasoning'`
UIMessage part natively.
- **Response adapter**: every `ThinkingBlock` now emits reasoning
events; text/tool_use transitions close the open reasoning block so AI
SDK doesn't merge distinct parts.
- **Stream registry**: added `reasoning-*` types to
`_reconstruct_chunk`'s type_to_class map so Redis replay no longer drops
them on cross-pod / reload / share.
- **Persistence** (new): each `StreamReasoningStart` opens a
`ChatMessage(role="reasoning")` row in `session.messages`; deltas
accumulate into its content; `StreamReasoningEnd` closes it. No schema
migration — `ChatMessage.role` is already `String`.
`extract_context_messages` filters `role="reasoning"` out of LLM context
(the `--resume` CLI session already carries thinking separately) so the
model never re-ingests prior reasoning.
- **Frontend conversion**: `convertChatSessionMessagesToUiMessages` maps
`role="reasoning"` DB rows into `{type: "reasoning", text}` parts on the
surrounding assistant bubble, so reload / shared-link sessions render
reasoning identically to live stream.

### Steps / Reasoning UX — modal + accordion split

- **`StepsCollapse`** (new): a Dialog-backed "Show steps" modal wraps
the pre-final-answer group (tool timeline + per-block reasoning). Modal
keeps the steps visually grouped and out of the reading flow.
- **`ReasoningCollapse`** (rewritten): inline accordion with "Show
reasoning" / "Hide reasoning" toggle — no longer a modal, so it expands
*inside* the Steps modal without stacking two dialogs. Reasoning text
appears indented with a left border.
- **`splitReasoningAndResponse`**: reasoning parts now stay in the
reasoning group (instead of being pinned out), so they show up inside
the Steps modal alongside the tool-use timeline.

### Thinking-only final turn — synthesize a closing line
(belt-and-suspenders)

- **Prompt rule** (`_USER_FOLLOW_UP_NOTE`): "Every turn MUST end with at
least one short user-facing text sentence."
- **Adapter fallback**: tracks `_text_since_last_tool_result`; at
`ResultMessage success` with tools run + zero text since, opens a fresh
step (`UserMessage` already closed the previous one) and injects `"(Done
— no further commentary.)"` before `StreamFinish`. Only fires for the
pathological case — pure-text turns untouched.

## Test plan

- [x] `pnpm vitest run` on copilot files — all 638 prior tests pass;
**17 new tests** added covering:
- `convertChatSessionToUiMessages`: reasoning row alone / merged with
assistant text / multi-row / empty skip / duration capture
- `ReasoningCollapse`: initial collapsed, toggle, `rotate-90`,
`aria-expanded`
  - `StepsCollapse`: trigger + dialog open renders children
- `MessagePartRenderer`: reasoning → `<pre>` inside collapse,
whitespace/missing text → null
  - `splitReasoningAndResponse`: reasoning-stays-in-reasoning regression
- [x] `poetry run pytest backend/copilot/sdk/response_adapter_test.py` —
36 pass (7 new: 4 reasoning streaming, 3 thinking-only fallback)
- [x] Manual: reasoning streams live and persists across reload on a
fresh session
- [x] Manual: previously-created sessions (pre-persistence) don't have
`role="reasoning"` rows — behaves as a clean no-op (no reasoning shown,
no error), new sessions render reasoning inside Steps modal

## Notes

- No DB migration — `ChatMessage.role` is already an open `String`;
`role="reasoning"` is simply filtered out of LLM context builds but
rendered by the frontend.
- Addresses /pr-review blockers: (a) stream_registry missing reasoning
types in Redis round-trip, (b) fallback text emitted outside a step, (c)
dead `case "thinking"` in renderer (now uses the live `reasoning` type
uniformly).
2026-04-19 10:37:04 +07:00
Zamil Majdy
b1c043c2d8 feat(copilot): queue follow-up messages on busy sessions (UI + run_sub_session + AutoPilot block) (#12737)
## Why

Users and tools can target a copilot session that already has a turn
running. Before this PR there was no uniform behaviour for that case —
the UI manually routed to a separate queue endpoint, `run_sub_session`
and the AutoPilot block raced the cluster lock, and in-turn follow-ups
only reached the model at turn-end via auto-continue. Outcome: dropped
messages, duplicate tool rows, missed mid-turn intent, latent
correctness bugs in block execution.

## What

A single "message arrived → turn already running?" primitive, shared by
every caller:

1. **POST `/stream`** (UI chat): self-defensive. Session idle → SSE as
today; session busy → `202 application/json` with `{buffer_length,
max_buffer_length, turn_in_flight}`. The deprecated `POST
/messages/pending` endpoint is removed (`GET /messages/pending` peek
stays).
2. **`run_copilot_turn_via_queue`** (shared primitive from #12841, used
by `run_sub_session` + `AutoPilotBlock`): gains the same busy-check.
Busy session → push to pending buffer, return `("queued",
SessionResult(queued=True, pending_buffer_length=N))` without creating a
stream registry session or enqueueing a RabbitMQ job. All callers
inherit queueing.
3. **Mid-turn delivery**: drained follow-ups are attached to every
tool_result's `additionalContext` via the SDK's `PostToolUse` hook —
covers both MCP and built-in tools (WebSearch/Read/Agent/etc.), not just
`run_block`. Claude reads the queued text on the next LLM round of the
same turn.
4. **UI observability**: chips promote to a proper user bubble at the
correct chronological position (after the tool_result row that consumed
them). Auto-continue handles end-of-turn drainage; mid-turn backend poll
handles the tool-boundary drainage path.

## How

**Data plane**
- `backend/copilot/pending_messages.py` — Redis list per session
(LPOP-count for atomic drain), TTL, fire-and-forget pub/sub notify. MAX
10 per session.
- `backend/copilot/pending_message_helpers.py` — `is_turn_in_flight`,
`queue_user_message`, `drain_and_format_for_injection`,
`persist_pending_as_user_rows` (shared persist+rollback used by both
baseline and SDK paths).
- `backend/data/redis_helpers.py` — centralised `incr_with_ttl`,
`capped_rpush`, `hash_compare_and_set`; every Lua script and pipeline
atomicity lives in one place.

**Injection sites**
- `backend/copilot/sdk/security_hooks.py::post_tool_use_hook` — drains +
returns `additionalContext`. Single hook covers built-in + MCP tools.
- `backend/copilot/sdk/service.py` — `StreamToolOutputAvailable`
dispatch persists the drained follow-up as a real user row right after
the tool_result (UI bubble at the right index).
`state.midturn_user_rows` keeps the CLI upload watermark honest.
- `backend/copilot/baseline/service.py` — same drain at round
boundaries, uses the shared `persist_pending_as_user_rows` helper so
baseline + SDK code paths don't diverge.

**Dispatch**
- `backend/copilot/sdk/session_waiter.py::run_copilot_turn_via_queue` —
`is_turn_in_flight` short-circuit; `SessionResult` gains `queued` +
`pending_buffer_length`; `SessionOutcome` gains `"queued"`.
- `backend/api/features/chat/routes.py::stream_chat_post` — busy-check
returns 202 with `QueuePendingMessageResponse`; `POST /messages/pending`
deleted.
- `backend/copilot/tools/run_sub_session.py` / `models.py` —
`SubSessionStatusResponse.status` gains `"queued"`;
`response_from_outcome` renders a clear queued-state message with the
pending-buffer depth and a link to watch live.
- `backend/blocks/autopilot.py::execute_copilot` — surfaces queued state
as descriptive response text + empty `tool_calls`/history when
`result.queued`.

**Frontend**
- `src/app/(platform)/copilot/useCopilotPendingChips.ts` — hook owning
the chip lifecycle: backend peek on session load, auto-continue
promotion when a second assistant id appears, mid-turn poll that
promotes when the backend count drops.
- `src/app/(platform)/copilot/useHydrateOnStreamEnd.ts` —
force-hydrate-waits-for-fresh-reference dance extracted.
- `src/app/(platform)/copilot/helpers/stripReplayPrefix.ts` — pure
function with drop / strip / streaming-catch-up cases + helper
decomposition.
- `src/app/(platform)/copilot/helpers/makePromotedBubble.ts` — one-line
helper for the promoted bubble shape.
- `src/app/(platform)/copilot/helpers/queueFollowUpMessage.ts` — thin
`fetch` wrapper for the 202 path (AI SDK's `useChat` fetcher only
handles SSE, so we can't reuse `sendMessage` for the queued response).

## Test plan

Backend unit + integration (`poetry run pytest backend/copilot
backend/api/features/chat`):
- [x] 107 tests pass — pending buffer, drain helpers, routes,
session_waiter queue branch, run_sub_session outcome rendering,
autopilot block
- [x] New `session_waiter_test.py` proves the queue branch
short-circuits `stream_registry.create_session` + `enqueue_copilot_turn`
- [x] Mid-turn persist has a rollback-and-re-queue path tested for when
`session.messages` persist silently fails to back-fill sequences

Frontend unit (`pnpm vitest run`):
- [x] 630 tests pass incl. 22 new for extracted helpers + hooks
- [x] Frontend coverage on touched copilot files: 91%+ (patch 87.37%)

Manual (once merged):
- [ ] Queue two chips while a tool is running; Claude acknowledges both
on the next round, UI shows bubbles in typing order after the tool
output
- [ ] Hand AutoPilot block an existing session_id that has a live turn;
block returns queued status, in-flight turn drains the message on its
next round
- [ ] `run_sub_session` against a busy sub — status=`queued`,
`sub_autopilot_session_link` lets user watch live

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-19 00:48:59 +07:00
Zamil Majdy
fcaebd1bb7 refactor(backend/copilot): unified queue-backed copilot turns + async sub-AutoPilot + guide-read gate (#12841)
### Why / What / How

**Why:** the 10-min stream-level idle timeout was killing legitimate
long-running tool calls — notably sub-AutoPilot runs via
`run_block(AutoPilotBlock)`, which routinely take 15–45 min. The symptom
users saw was `"A tool call appears to be stuck"` even though AutoPilot
was actively working. A second long-standing rough edge was shipped
alongside: agents often skipped `get_agent_building_guide` when
generating agent JSON, producing schemas that failed validation and
burned turns on auto-fix loops.

**What:** three threaded pieces.

1. **Async sub-AutoPilot via `run_sub_session`.** New copilot tool that
delegates a task to a fresh (or resumed) sub-AutoPilot, and its
companion `get_sub_session_result` for polling/cancelling. The agent
starts with `run_sub_session(prompt, wait_for_result≤300s)` and, if the
sub isn't done inside the cap, receives a handle + polls via
`get_sub_session_result(wait_if_running≤300s)`. No single MCP call ever
blocks the stream for more than 5 min, so the 10-min stream-idle timer
stays simple and effective (derived as `MAX_TOOL_WAIT_SECONDS * 2`).

2. **Queue-backed copilot turn dispatch** — one code path for all three
callers.
- `run_sub_session` enqueues a `CoPilotExecutionEntry` on the existing
`copilot_execution` exchange instead of spawning an in-process
`asyncio.Task`.
- `AutoPilotBlock.execute_copilot` (graph block) now uses the **same
queue** instead of `collect_copilot_response` inline.
   - The HTTP SSE endpoint was already queue-backed.
- All three share a single primitive: `run_copilot_turn_via_queue` →
`create_session` → `enqueue_copilot_turn` → `wait_for_session_result`.
The event-aggregation logic (`EventAccumulator`/`process_event`) is a
shared module used by both the direct-stream path and the cross-process
waiter.
- Benefits: **deploy/crash resilience** (RabbitMQ redelivery survives
worker restarts), **natural load balancing** across copilot_executor
workers, **sessions as first-class resources** (UI users can
`/copilot?sessionId=<inner>` into any sub or AutoPilot block's session),
and every future stream-level feature (pending-messages drain #12737,
compaction policies, etc.) applies uniformly instead of bypassing
graph-block sessions.

3. **Guide-read gate on agent-generation tools.** `create_agent` /
`edit_agent` / `validate_agent_graph` / `fix_agent_graph` refuse until
the session has called `get_agent_building_guide`. The pre-existing soft
hint was routinely ignored; the gate makes the dependency enforceable.
All four tool descriptions advertise the requirement in one tightened
sentence ("Requires get_agent_building_guide first (refuses
otherwise).") that stays under the 32000-char schema budget.

**How:**

#### Queue-backed sub-AutoPilot + AutoPilotBlock

- `sdk/session_waiter.py` — new module. `SessionResult` dataclass
mirrors `CopilotResult`. `wait_for_session_result` subscribes to
`stream_registry`, drains events via shared `process_event`, returns
`(outcome, result)`. `wait_for_session_completion` is the cheaper
outcome-only variant. `run_copilot_turn_via_queue` is the canonical
three-step dispatch. Every exit path unsubscribes the listener.
- `sdk/stream_accumulator.py` — new module. `EventAccumulator`,
`ToolCallEntry`, `process_event` extracted from `collect.py`. Both the
direct-stream and cross-process paths now use the same fold logic.
- `tools/run_sub_session.py` / `tools/get_sub_session_result.py` —
rewritten around the shared primitive. `sub_session_id` is now the sub's
`ChatSession` id directly (no separate registry handle). Ownership
re-verified on every call via `get_chat_session`. Cancel via
`enqueue_cancel_task` on the existing `copilot_cancel` fan-out exchange.
- `blocks/autopilot.py` — `execute_copilot` replaced its inline
`collect_copilot_response` with `run_copilot_turn_via_queue`.
`SessionResult` carries response text, tool calls, and token usage back
from the worker so no DB round-trip is needed. The block's public I/O
contract (inputs, outputs, `ToolCallEntry` shape) is unchanged.
- `CoPilotExecutionEntry` gains a `permissions: CopilotPermissions |
None` field forwarded to the worker's `stream_fn` so the sub's
capability filter survives the queue hop. The processor passes it
through to `stream_chat_completion_sdk` /
`stream_chat_completion_baseline`.
- **Deleted**: `sdk/sub_session_registry.py` (module-level dict,
done-callback, abandoned-task cap, `notify_shutdown_and_cancel_all`,
`_reset_for_test`), plus the shutdown-notifier hook in
`copilot_executor.processor.cleanup` — redundant under queue-backed
execution.

#### Run_block single-tool cap (3)

- `tools/helpers.execute_block` caps block execution at
`MAX_TOOL_WAIT_SECONDS = 5 min` via `asyncio.wait_for` around the
generator consumption.
- On timeout: logs `copilot_tool_timeout tool=run_block block=…
block_id=… input_keys=… user=… session=… cap_s=…` (grep-friendly) and
returns an `ErrorResponse` that redirects the LLM to `run_agent` /
`run_sub_session`.
- Billing protection: `_charge_block_credits` is called in a `finally`
guarded by `asyncio.shield` and marked `charge_handled` **before** the
await so cancel-mid-charge doesn't double-bill and
cancel-mid-generator-before-charge still settles via the finally.

#### Guide-read gate

- `helpers.require_guide_read(session, tool_name)` scans
`session.messages` for any prior assistant tool call named
`get_agent_building_guide` (handles both OpenAI and flat shapes).
Applied at the top of `_execute` in `create_agent`, `edit_agent`,
`validate_agent_graph`, `fix_agent_graph`. Tool descriptions advertise
the requirement.

#### Shared timing constants

- `MAX_TOOL_WAIT_SECONDS = 5 * 60` + `STREAM_IDLE_TIMEOUT_SECONDS = 2 *
MAX_TOOL_WAIT_SECONDS` in `constants.py`. Every long-running tool
(`run_agent`, `view_agent_output`, `run_sub_session`,
`get_sub_session_result`, `run_block`) imports from one place; no more
hardcoded 300 / `10*60` literals drifting apart. Stream-idle invariant
("no single tool blocks close to the idle timeout") holds by
construction.

### Frontend

- Friendlier tool-card labels: `run_sub_session` → "Sub-AutoPilot",
`get_sub_session_result` → "Sub-AutoPilot result", `run_block` →
"Action" (matches the builder UI's own naming), `run_agent` → "Agent".
Fixes the double-verb "Running Run …" phrasing.
- `SubSessionStatusResponse.sub_autopilot_session_link` surfaces
`/copilot?sessionId=<inner>` so users can click into any sub's session
from the tool-call card — same pattern as `run_agent`'s
`library_agent_link`.

### Changes 🏗️

- **New modules**: `sdk/session_waiter.py`, `sdk/stream_accumulator.py`,
`tools/run_sub_session.py`, `tools/get_sub_session_result.py`,
`tools/sub_session_test.py`, `tools/agent_guide_gate_test.py`.
- **New response types**: `SubSessionStatusResponse`,
`SubSessionProgressSnapshot`, `SessionResult`.
- **New gate helper**: `require_guide_read` in `tools/helpers.py`.
- **Queue protocol**: `permissions` field on `CoPilotExecutionEntry`,
threaded through `processor.py` → `stream_fn`.
- **Hidden**: `AUTOPILOT_BLOCK_ID` in `COPILOT_EXCLUDED_BLOCK_IDS`
(run_block can't execute AutoPilotBlock; agents use `run_sub_session`
instead).
- **Deleted**: `sdk/sub_session_registry.py`, processor
shutdown-notifier hook.
- **Regenerated**: `openapi.json` for the new response types; block-docs
for the updated `ToolName` Literal.
- **Tool descriptions**: tightened the guide-gate hint across the four
agent-builder tools to stay under the 32000-char schema budget.
- **40+ tests** across sub_session, execute_block cap + billing races,
stream_accumulator, agent_guide_gate, frontend helpers.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Unit suite green on the full copilot tree; `poetry run format` +
`pyright` clean
- [x] Schema character budget test passes (tool descriptions trimmed to
stay under 32000)
- [x] Native UI E2E (`poetry run app` + `pnpm dev`):
`run_sub_session(wait_for_result=60)` returns `status="completed"` +
`sub_autopilot_session_link` inline;
`run_sub_session(wait_for_result=1)` returns `status="running"` +
handle, `get_sub_session_result(wait_if_running=60)` observes `running →
completed` transition
- [x] AutoPilotBlock (graph) goes through `copilot_executor` queue
end-to-end (verified via logs: ExecutionManager's AutoPilotBlock node
spawned session `f6de335b-…`, a different `CoPilotExecutor` worker
acquired its cluster lock and ran the SDK stream)
- [x] Guide gate: `create_agent` without a prior
`get_agent_building_guide` returns the refusal; agent reads the guide
and retries successfully
2026-04-18 23:11:41 +07:00
Toran Bruce Richards
1c0c7a6b44 fix(copilot): add gh auth status check to Tool Discovery Priority section (#12832)
## Problem

The CoPilot system prompt contains a `gh auth status` instruction in the
E2B-specific `GitHub CLI` section, but models pattern-match to
`connect_integration` from the **Tool Discovery Priority** section —
which is where the actual decision to call an external service is made.

Because the GitHub auth check lives in a separate, later section, it's
not salient at the point of decision-making. This causes the model to
call `connect_integration(provider='github')` even when `gh` is already
authenticated via `GH_TOKEN`, unnecessarily prompting the user.

## Fix

Add a 3-line callout directly inside the **Tool Discovery Priority**
section:

```
> 🔑 **GitHub exception:** Before calling `connect_integration` for GitHub,
> always run `gh auth status` first. If it shows `Logged in`, proceed
> directly with `gh`/`git` — no integration connection needed.
```

This places the rule at the exact location where the model decides which
tool path to take, preventing the miss.

## Why this works

- **Placement over repetition**: The existing instruction isn't wrong —
it's just in the wrong spot relative to where the decision is made
- **Negative framing**: Explicitly says "before calling
`connect_integration`" which directly intercepts the incorrect reflex
- **Minimal change**: 4 lines added, zero removed

Co-authored-by: Toran Bruce Richards <22963551+Torantulino@users.noreply.github.com>
2026-04-17 15:22:10 +00:00
447 changed files with 59004 additions and 8874 deletions

View File

@@ -25,6 +25,8 @@ Understand the **Why / What / How** before addressing comments — you need cont
gh pr view {N} --json body --jq '.body'
```
> If GraphQL is rate-limited, `gh pr view` fails. See [GitHub rate limits](#github-rate-limits) for REST fallbacks.
## Fetch comments (all sources)
### 1. Inline review threads — GraphQL (primary source of actionable items)
@@ -109,12 +111,16 @@ Only after this loop completes (all pages fetched, count confirmed) should you b
**Filter to unresolved threads only** — skip any thread where `isResolved: true`. `comments(last: 1)` returns the most recent comment in the thread — act on that; it reflects the reviewer's final ask. Use the thread `id` (Relay global ID) to track threads across polls.
> If GraphQL is rate-limited, see [GitHub rate limits](#github-rate-limits) for the REST fallback (flat comment list — no thread grouping or `isResolved`).
### 2. Top-level reviews — REST (MUST paginate)
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate
```
> **Already REST — unaffected by GraphQL rate limits or outages. Continue polling reviews normally even when GraphQL is exhausted.**
**CRITICAL — always `--paginate`.** Reviews default to 30 per page. PRs can have 80170+ reviews (mostly empty resolution events). Without pagination you miss reviews past position 30 — including `autogpt-reviewer`'s structured review which is typically posted after several CI runs and sits well beyond the first page.
Two things to extract:
@@ -133,6 +139,8 @@ Two things to extract:
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate
```
> **Already REST — unaffected by GraphQL rate limits.**
Mostly contains: bot summaries (`coderabbitai[bot]`), CI/conflict detection (`github-actions[bot]`), and author status updates. Scan for non-empty messages from non-bot human reviewers that aren't the PR author — those are the ones that need a response.
## For each unaddressed comment
@@ -327,18 +335,65 @@ git push
5. Restart the polling loop from the top — new commits reset CI status.
## GitHub abuse rate limits
## GitHub rate limits
Two distinct rate limits exist — they have different causes and recovery times:
Three distinct rate limits exist — they have different causes, error shapes, and recovery times:
| Error | HTTP code | Cause | Recovery |
|---|---|---|---|
| `{"code":"abuse"}` | 403 | Secondary rate limit — too many write operations (comments, mutations) in a short window | Wait **23 minutes**. 60s is often not enough. |
| `{"message":"API rate limit exceeded"}` | 429 | Primary rate limit — too many API calls per hour | Wait until `X-RateLimit-Reset` header timestamp |
| `{"message":"API rate limit exceeded"}` | 429 | Primary REST rate limit — 5000 calls/hr per user | Wait until `X-RateLimit-Reset` header timestamp |
| `GraphQL: API rate limit already exceeded for user ID ...` | 403 on stderr, `gh` exits 1 | **GraphQL-specific** per-user limit — distinct from REST's 5000/hr and from the abuse secondary limit. Trips faster than REST because point costs per query. | Wait until the GraphQL window resets (typically ~1 hour from the first call in the window). REST still works — use fallbacks below. |
**Prevention:** Add `sleep 3` between individual thread reply API calls. When posting >20 replies, increase to `sleep 5`.
**Recovery from secondary rate limit (403):**
### Detection
The `gh` CLI surfaces the GraphQL limit on stderr with the exact string `GraphQL: API rate limit already exceeded for user ID <id>` and exits 1 — any `gh api graphql ...` **or** `gh pr view ...` call fails. Check current quota and reset time via the REST endpoint that reports GraphQL quota (this call is REST and still works whether GraphQL is rate-limited OR fully down):
```bash
gh api rate_limit --jq '.resources.graphql' # { "limit": 5000, "used": 5000, "remaining": 0, "reset": 1729...}
# Human-readable reset:
gh api rate_limit --jq '.resources.graphql.reset' | xargs -I{} date -r {}
```
Retry when `remaining > 0`. If you need to proceed sooner, sleep 25 min and probe again — the limit is per user, not per machine, so other concurrent agents under the same token also consume it.
### What keeps working
When GraphQL is unavailable (rate-limited or outage):
- **Keeps working (REST):** top-level reviews fetch, conversation comments fetch, all inline-comment replies, CI status (`gh pr checks`), and the `gh api rate_limit` probe.
- **Degraded:** inline thread list — fall back to flat `/pulls/{N}/comments` REST, which drops thread grouping, `isResolved`, and Relay thread IDs. You still get comment bodies and the `databaseId` as `id`, enough to read and reply.
- **Blocked:** `gh pr view`, the `resolveReviewThread` mutation, and any new `gh api graphql` queries — wait for the quota to reset.
### Fall back to REST
**PR metadata reads** — `gh pr view` uses GraphQL under the hood; use the REST pulls endpoint instead, which returns the full PR object:
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.body' # == --json body
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.base.ref' # == --json baseRefName
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.mergeable' # == --json mergeable
```
Note: REST `mergeable` returns `true|false|null`; GraphQL returns `MERGEABLE|CONFLICTING|UNKNOWN`. The `null` case maps to `UNKNOWN` — treat it the same (still computing; poll again).
**Inline comments (flat list)** — no thread grouping or `isResolved`, but enough to read and reply:
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate \
| jq '[.[] | {id, path, line, user: .user.login, body: .body[:200], in_reply_to_id}]'
```
Use this degraded mode to make progress on the fix → reply loop, then return to GraphQL for `resolveReviewThread` once the rate limit resets.
**Replies** — already REST-native (`/pulls/{N}/comments/{ID}/replies`); no change needed, use the same command as the main flow.
**`resolveReviewThread`** — **no REST equivalent**; GitHub does not expose a REST endpoint for thread resolution. Queue the thread IDs needing resolution, wait for the GraphQL limit to reset, then run the resolve mutations in a batch (with `sleep 3` between calls, per the secondary-limit guidance).
### Recovery from secondary rate limit (403 abuse)
1. Stop all API writes immediately
2. Wait **2 minutes minimum** (not 60s — secondary limits are stricter)
3. Resume with `sleep 3` between each call
@@ -397,6 +452,8 @@ gh api graphql -f query='mutation { resolveReviewThread(input: {threadId: "THREA
**Never call this mutation before committing the fix.** The orchestrator will verify actual unresolved counts via GraphQL after you output `ORCHESTRATOR:DONE` — false resolutions will be caught and you will be re-briefed.
> `resolveReviewThread` is GraphQL-only — no REST equivalent. If GraphQL is rate-limited, see [GitHub rate limits](#github-rate-limits) for the queue-and-retry flow.
### Verify actual count before outputting ORCHESTRATOR:DONE
Before claiming "0 unresolved threads", always query GitHub directly — don't rely on your own bookkeeping. Paginate all pages — a single `first: 100` query misses threads beyond page 1:

View File

@@ -0,0 +1,245 @@
---
name: pr-polish
description: Alternate /pr-review and /pr-address on a PR until the PR is truly mergeable — no new review findings, zero unresolved inline threads, zero unaddressed top-level reviews or issue comments, all CI checks green, and two consecutive quiet polls after CI settles. Use when the user wants a PR polished to merge-ready without setting a fixed number of rounds.
user-invocable: true
argument-hint: "[PR number or URL] — if omitted, finds PR for current branch."
metadata:
author: autogpt-team
version: "1.0.0"
---
# PR Polish
**Goal.** Drive a PR to merge-ready by alternating `/pr-review` and `/pr-address` until **all** of the following hold:
1. The most recent `/pr-review` produces **zero new findings** (no new inline comments, no new top-level reviews with a non-empty body).
2. Every inline review thread reachable via GraphQL reports `isResolved: true`.
3. Every non-bot, non-author top-level review has been acknowledged (replied-to) OR resolved via a thread it spawned.
4. Every non-bot, non-author issue comment has been acknowledged (replied-to).
5. Every CI check is `conclusion: "success"` or `"skipped"` / `"neutral"` — none `"failure"` or still pending.
6. **Two consecutive post-CI polls** (≥60s apart) stay clean — no new threads, no new non-empty reviews, no new issue comments. Bots (coderabbitai, sentry, autogpt-reviewer) frequently post late after CI settles; a single green snapshot is not sufficient.
**Do not stop at a fixed number of rounds.** If round N introduces new comments, round N+1 is required. Cap at `_MAX_ROUNDS = 10` as a safety valve, but expect 25 in practice.
## TodoWrite
Before starting, write two todos so the user can see the loop progression:
- `Round {current}: /pr-review + /pr-address on PR #{N}` — current iteration.
- `Final polish polling: 2 consecutive clean polls, CI green, 0 unresolved` — runs after the last non-empty review round.
Update the `current` round counter at the start of each iteration; mark `completed` only when the round's address step finishes (all new threads addressed + resolved).
## Find the PR
```bash
ARG_PR="${ARG:-}"
# Normalize URL → numeric ID if the skill arg is a pull-request URL.
if [[ "$ARG_PR" =~ ^https?://github\.com/[^/]+/[^/]+/pull/([0-9]+) ]]; then
ARG_PR="${BASH_REMATCH[1]}"
fi
PR="${ARG_PR:-$(gh pr list --head "$(git branch --show-current)" --repo Significant-Gravitas/AutoGPT --json number --jq '.[0].number')}"
if [ -z "$PR" ] || [ "$PR" = "null" ]; then
echo "No PR found for current branch. Provide a PR number or URL as the skill arg."
exit 1
fi
echo "Polishing PR #$PR"
```
## The outer loop
```text
round = 0
while round < _MAX_ROUNDS:
round += 1
baseline = snapshot_state(PR) # see "Snapshotting state" below
invoke_skill("pr-review", PR) # posts findings as inline comments / top-level review
findings = diff_state(PR, baseline)
if findings.total == 0:
break # no new findings → go to polish polling
invoke_skill("pr-address", PR) # resolves every unresolved thread + CI failure
# Post-loop: polish polling (see below).
polish_polling(PR)
```
### Snapshotting state
Before each `/pr-review`, capture a baseline so the diff after the review reflects **only** what the review just added (not pre-existing threads):
```bash
# Inline threads — total count + latest databaseId per thread
gh api graphql -f query="
{
repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
pullRequest(number: ${PR}) {
reviewThreads(first: 100) {
totalCount
nodes {
id
isResolved
comments(last: 1) { nodes { databaseId } }
}
}
}
}
}" > /tmp/baseline_threads.json
# Top-level reviews — count + latest id per non-empty review
gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}/reviews" --paginate \
--jq '[.[] | select((.body // "") != "") | {id, user: .user.login, state, submitted_at}]' \
> /tmp/baseline_reviews.json
# Issue comments — count + latest id per non-bot, non-author comment.
# Bots are filtered by User.type == "Bot" (GitHub sets this for app/bot
# accounts like coderabbitai, github-actions, sentry-io). The author is
# filtered by comparing login to the PR author — export it so jq can see it.
AUTHOR=$(gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}" --jq '.user.login')
gh api "repos/Significant-Gravitas/AutoGPT/issues/${PR}/comments" --paginate \
--jq --arg author "$AUTHOR" \
'[.[] | select(.user.type != "Bot" and .user.login != $author)
| {id, user: .user.login, created_at}]' \
> /tmp/baseline_issue_comments.json
```
### Diffing after a review
After `/pr-review` runs, any of these counting as "new findings" means another address round is needed:
- New inline thread `id` not in the baseline.
- An existing thread whose latest comment `databaseId` is higher than the baseline's (new reply on an old thread).
- A new top-level review `id` with a non-empty body.
- A new issue comment `id` from a non-bot, non-author user.
If any of the four buckets is non-empty → not done; invoke `/pr-address` and loop.
## Polish polling
Once `/pr-review` produces zero new findings, do **not** exit yet. Bots (coderabbitai, sentry, autogpt-reviewer) commonly post late reviews after CI settles — 3090 seconds after the final push. Poll at 60-second intervals:
```text
NON_SUCCESS_TERMINAL = {"failure", "cancelled", "timed_out", "action_required", "startup_failure"}
clean_polls = 0
required_clean = 2
while clean_polls < required_clean:
# 1. CI gate — any terminal non-success conclusion (not just "failure")
# must trigger /pr-address. "success", "skipped", "neutral" are clean;
# anything else (including cancelled, timed_out, action_required) is a
# blocker that won't self-resolve.
ci = fetch_check_runs(PR)
if any ci.conclusion in NON_SUCCESS_TERMINAL:
invoke_skill("pr-address", PR) # address failures + any new comments
baseline = snapshot_state(PR) # reset — push during address invalidates old baseline
clean_polls = 0
continue
if any ci.conclusion is None (still in_progress):
sleep 60; continue # wait without counting this as clean
# 2. Comment / thread gate
threads = fetch_unresolved_threads(PR)
new_issue_comments = diff_against_baseline(issue_comments)
new_reviews = diff_against_baseline(reviews)
if threads or new_issue_comments or new_reviews:
invoke_skill("pr-address", PR)
baseline = snapshot_state(PR) # reset — the address loop just dealt with these,
# otherwise they stay "new" relative to the old baseline forever
clean_polls = 0
continue
# 3. Mergeability gate
mergeable = gh api repos/.../pulls/${PR} --jq '.mergeable'
if mergeable == false (CONFLICTING):
resolve_conflicts(PR) # see pr-address skill
clean_polls = 0
continue
if mergeable is null (UNKNOWN):
sleep 60; continue
clean_polls += 1
sleep 60
```
Only after `clean_polls == 2` do you report `ORCHESTRATOR:DONE`.
### Why 2 clean polls, not 1
A single green snapshot can be misleading — the final CI check often completes ~30s before a bot posts its delayed review. One quiet cycle does not prove the PR is stable; two consecutive cycles with no new threads, reviews, or issue comments arriving gives high confidence nothing else is incoming.
### Why checking every source each poll
`/pr-address` polling inside a single round already re-checks its own comments, but `/pr-polish` sits a level above and must also catch:
- New top-level reviews (autogpt-reviewer sometimes posts structured feedback only after several CI green cycles).
- Issue comments from human reviewers (not caught by inline thread polling).
- Sentry bug predictions that land on new line numbers post-push.
- Merge conflicts introduced by a race between your push and a merge to `dev`.
## Invocation pattern
Delegate to existing skills with the `Skill` tool; do not re-implement the review or address logic inline. This keeps the polish loop focused on orchestration and lets the child skills evolve independently.
```python
Skill(skill="pr-review", args=pr_url)
Skill(skill="pr-address", args=pr_url)
```
After each child invocation, re-query GitHub state directly — never trust a summary for the stop condition. The orchestrator's `ORCHESTRATOR:DONE` is verified against actual GraphQL / REST responses per the rules in `pr-address`'s "Verify actual count before outputting ORCHESTRATOR:DONE" section.
### **Auto-continue: do NOT end your response between child skills**
`/pr-polish` is a single orchestration task — one invocation drives the PR all the way to merge-ready. When a child `Skill()` call returns control to you:
- Do NOT summarize and stop.
- Do NOT wait for user confirmation to continue.
- Immediately, in the same response, perform the next loop step: state diff → decide next action → next `Skill()` call or polling sleep.
The child skill returning is a **loop iteration boundary**, not a conversation turn boundary. You are expected to keep going until one of the exit conditions in the opening section is met (2 consecutive clean polls, `_MAX_ROUNDS` hit, or an unrecoverable error).
If the user needs to approve a risky action mid-loop (e.g., a force-push or a destructive git operation), pause there — but not at the routine "round N finished, round N+1 needed" boundary. Those are silent transitions.
## GitHub rate limits
This skill issues many GraphQL calls (one review-thread query per outer iteration plus per-poll queries inside polish polling). Expect the GraphQL budget to be tight on large PRs. When `gh api rate_limit --jq .resources.graphql.remaining` drops below ~200, back off:
- Fall back to REST for reads (flat `/pulls/{N}/comments`, `/pulls/{N}/reviews`, `/issues/{N}/comments`) per the `pr-address` skill's GraphQL-fallback section.
- Queue thread resolutions (GraphQL-only) until the budget resets; keep making progress on fixes + REST replies meanwhile.
- `sleep 5` between any batch of ≥20 writes to avoid secondary rate limits.
## Safety valves
- `_MAX_ROUNDS = 10` — if review+address rounds exceed this, stop and escalate to the user with a summary of what's still unresolved. A PR that cannot converge in 10 rounds has systemic issues that need human judgment.
- After each commit, run `poetry run format` / `pnpm format && pnpm lint && pnpm types` per the target codebase's conventions. A failing format check is CI `failure` that will never self-resolve.
- Every `/pr-review` round checks for **duplicate** concerns first (via `pr-review`'s own "Fetch existing review comments" step) so the loop does not re-post the same finding that a prior round already resolved.
## Reporting
When the skill finishes (either via two clean polls or hitting `_MAX_ROUNDS`), produce a compact summary:
```
PR #{N} polish complete ({rounds_completed} rounds):
- {X} inline threads opened and resolved
- {Y} CI failures fixed
- {Z} new commits pushed
Final state: CI green, {total} threads all resolved, mergeable.
```
If exiting via `_MAX_ROUNDS`, flag explicitly:
```
PR #{N} polish stopped at {_MAX_ROUNDS} rounds — NOT merge-ready:
- {N} threads still unresolved: {titles}
- CI status: {summary}
Needs human review.
```
## When to use this skill
Use when the user says any of:
- "polish this PR"
- "keep reviewing and addressing until it's mergeable"
- "loop /pr-review + /pr-address until done"
- "make sure the PR is actually merge-ready"
Do **not** use when:
- User wants just one review pass (→ `/pr-review`).
- User wants to address already-posted comments without further self-review (→ `/pr-address`).
- A fixed round count is explicitly requested (e.g., "do 3 rounds") — honour the count instead of converging.

View File

@@ -5,7 +5,7 @@ user-invocable: true
argument-hint: "[worktree path or PR number] — tests the PR in the given worktree. Optional flags: --fix (auto-fix issues found)"
metadata:
author: autogpt-team
version: "2.0.0"
version: "2.1.0"
---
# Manual E2E Test
@@ -180,6 +180,120 @@ Based on the PR analysis, write a test plan to `$RESULTS_DIR/test-plan.md`:
**Be critical** — include edge cases, error paths, and security checks. Every scenario MUST specify what screenshots to take and what state to verify.
## Step 3.0: Claim the testing lock (coordinate parallel agents)
Multiple worktrees share the same host — Docker infra (postgres, redis, clamav), app ports (3000/8006/…), and the test user. Two agents running `/pr-test` concurrently will corrupt each other's state (connection-pool exhaustion, port binds failing silently, cross-test assertions). Use the root-worktree lock file to take turns.
### Lock file contract
Path (**always** the root worktree so all siblings see it): `/Users/majdyz/Code/AutoGPT/.ign.testing.lock`
Body (one `key=value` per line):
```
holder=<pr-XXXXX-purpose>
pid=<pid-or-"self">
started=<iso8601>
heartbeat=<iso8601, updated every ~2 min>
worktree=<full path>
branch=<branch name>
intent=<one-line description + rough duration>
```
### Claim
```bash
LOCK=/Users/majdyz/Code/AutoGPT/.ign.testing.lock
NOW=$(date -u +%Y-%m-%dT%H:%MZ)
STALE_AFTER_MIN=5
if [ -f "$LOCK" ]; then
HB=$(grep '^heartbeat=' "$LOCK" | cut -d= -f2)
HB_EPOCH=$(date -j -f '%Y-%m-%dT%H:%MZ' "$HB" +%s 2>/dev/null || date -d "$HB" +%s 2>/dev/null || echo 0)
AGE_MIN=$(( ( $(date -u +%s) - HB_EPOCH ) / 60 ))
if [ "$AGE_MIN" -gt "$STALE_AFTER_MIN" ]; then
echo "WARN: stale lock (${AGE_MIN}m old) — reclaiming"
cat "$LOCK" | sed 's/^/ stale: /'
else
echo "Another agent holds the lock:"; cat "$LOCK"
echo "Wait until released or resume after $((STALE_AFTER_MIN - AGE_MIN))m."
exit 1
fi
fi
cat > "$LOCK" <<EOF
holder=pr-${PR_NUMBER}-e2e
pid=self
started=$NOW
heartbeat=$NOW
worktree=$WORKTREE_PATH
branch=$(cd $WORKTREE_PATH && git branch --show-current)
intent=E2E test PR #${PR_NUMBER}, native mode, ~60min
EOF
echo "Lock claimed"
```
### Heartbeat (MUST run in background during the whole test)
Without a heartbeat a crashed agent keeps the lock forever. Run this as a background process right after claim:
```bash
(while true; do
sleep 120
[ -f "$LOCK" ] || exit 0 # lock released → exit heartbeat
perl -i -pe "s/^heartbeat=.*/heartbeat=$(date -u +%Y-%m-%dT%H:%MZ)/" "$LOCK"
done) &
HEARTBEAT_PID=$!
echo "$HEARTBEAT_PID" > /tmp/pr-test-heartbeat.pid
```
### Release (always — even on failure)
```bash
kill "$HEARTBEAT_PID" 2>/dev/null
rm -f "$LOCK" /tmp/pr-test-heartbeat.pid
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] released lock" \
>> /Users/majdyz/Code/AutoGPT/.ign.testing.log
```
Use a `trap` so release runs even on `exit 1`:
```bash
trap 'kill "$HEARTBEAT_PID" 2>/dev/null; rm -f "$LOCK"' EXIT INT TERM
```
### **Release the lock AS SOON AS the test run is done**
The lock guards **test execution**, not **app lifecycle**. Once Step 5 (record results) and Step 6 (post PR comment) are complete, release the lock IMMEDIATELY — even if:
- The native `poetry run app` / `pnpm dev` processes are still running so the user can keep poking at the app manually.
- You're leaving docker containers up.
- You're tailing logs for a minute or two.
Keeping the lock held past the test run is the single most common way `/pr-test` stalls other agents. **The app staying up is orthogonal to the lock; don't conflate them.** Sibling worktrees running their own `/pr-test` will kill the stray processes and free the ports themselves (Step 3c/3e-native handle that) — they just need the lock file gone.
Concretely, the sequence at the end of every `/pr-test` run (success or failure) is:
```bash
# 1. Write the final report + post PR comment — done above in Step 5/6.
# 2. Release the lock right now, even if the app is still up.
kill "$HEARTBEAT_PID" 2>/dev/null
rm -f "$LOCK" /tmp/pr-test-heartbeat.pid
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] released lock (app may still be running)" \
>> /Users/majdyz/Code/AutoGPT/.ign.testing.log
# 3. Optionally leave the app running and note it so the user knows:
echo "Native stack still running on :3000 / :8006 for manual poking. Kill with:"
echo " pkill -9 -f 'poetry run app'; pkill -9 -f 'next-server|next dev'"
```
If a sibling agent's `/pr-test` needs to take over, it'll do the kill+rebuild dance from Step 3c/3e-native on its own — your only job is to not hold the lock file past the end of your test.
### Shared status log
`/Users/majdyz/Code/AutoGPT/.ign.testing.log` is an append-only channel any agent can read/write. Use it for "I'm waiting", "I'm done, resources free", or post-run notes:
```bash
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] <message>" \
>> /Users/majdyz/Code/AutoGPT/.ign.testing.log
```
## Step 3: Environment setup
### 3a. Copy .env files from the root worktree
@@ -248,7 +362,87 @@ docker ps --format "{{.Names}}" | grep -E "rest_server|executor|copilot|websocke
done
```
### 3e. Build and start
**Native mode also:** when running the app natively (see 3e-native), kill any stray host processes and free the app ports before starting — otherwise `poetry run app` and `pnpm dev` will fail to bind.
```bash
# Kill stray native app processes from prior runs
pkill -9 -f "python.*backend" 2>/dev/null || true
pkill -9 -f "poetry run app" 2>/dev/null || true
pkill -9 -f "next-server|next dev" 2>/dev/null || true
# Free app ports (errors per port are ignored — port may simply be unused)
for port in 3000 8006 8001 8002 8005 8008; do
lsof -ti :$port -sTCP:LISTEN | xargs -r kill -9 2>/dev/null || true
done
```
### 3e-native. Run the app natively (PREFERRED for iterative dev)
Native mode runs infra (postgres, supabase, redis, rabbitmq, clamav) in docker but runs the backend and frontend directly on the host. This avoids the 3-8 minute `docker compose build` cycle on every backend change — code edits are picked up on process restart (seconds) instead of a full image rebuild.
**When to prefer native mode (default for this skill):**
- Iterative dev/debug loops where you're editing backend or frontend code between test runs
- Any PR that touches Python/TS source but not Dockerfiles, compose config, or infra images
- Fast repro of a failing scenario — restart `poetry run app` in a couple of seconds
**When to prefer docker mode (3e fallback):**
- Testing changes to `Dockerfile`, `docker-compose.yml`, or base images
- Production-parity smoke tests (exact container env, networking, volumes)
- CI-equivalent runs where you need the exact image that'll ship
**Note on 3b (copilot auth):** no npm install anywhere. `poetry install` pulls in `claude_agent_sdk`, which ships its own Claude CLI binary — available on `PATH` whenever you run commands via `poetry run` (native) OR whenever the copilot_executor container is built from its Poetry lockfile (docker). The OAuth token extraction still applies (same `refresh_claude_token.sh` call).
**Preamble:** before starting native, run the kill-stray + free-ports block from 3c's "Native mode also" subsection.
**1. Start infra only (one-time per session):**
```bash
cd $PLATFORM_DIR && docker compose --profile local up deps --detach --remove-orphans --build
```
This brings up postgres/supabase/redis/rabbitmq/clamav and skips all app services.
**2. Start the backend natively:**
```bash
cd $BACKEND_DIR && (poetry run app 2>&1 | tee .ign.application.logs) &
```
`poetry run app` spawns **all** app subprocesses — `rest_server`, `executor`, `copilot_executor`, `websocket`, `scheduler`, `notification_server`, `database_manager` — inside ONE parent process. No separate containers, no separate terminals. The `.ign.application.logs` prefix is already gitignored.
**3. Wait for the backend on :8006 BEFORE starting the frontend.** This ordering matters — the frontend's `pnpm dev` startup invokes `generate-api-queries`, which fetches `/openapi.json` from the backend. If the backend isn't listening yet, `pnpm dev` fails immediately.
```bash
for i in $(seq 1 60); do
if [ "$(curl -s -o /dev/null -w '%{http_code}' http://localhost:8006/docs 2>/dev/null)" = "200" ]; then
echo "Backend ready"
break
fi
sleep 2
done
```
**4. Start the frontend natively:**
```bash
cd $FRONTEND_DIR && (pnpm dev 2>&1 | tee .ign.frontend.logs) &
```
**5. Wait for the frontend on :3000:**
```bash
for i in $(seq 1 60); do
if [ "$(curl -s -o /dev/null -w '%{http_code}' http://localhost:3000 2>/dev/null)" = "200" ]; then
echo "Frontend ready"
break
fi
sleep 2
done
```
Once both are up, skip 3e/3f and go straight to **3g/3h** (feature flags / test user creation).
### 3e. Build and start (docker — fallback)
```bash
cd $PLATFORM_DIR && docker compose build --no-cache 2>&1 | tail -20
@@ -442,6 +636,22 @@ agent-browser --session-name pr-test snapshot | grep "text:"
### Checking logs
**Native mode:** when running via `poetry run app` + `pnpm dev`, all app logs stream to the `.ign.*.logs` files written by the `tee` pipes in 3e-native. `rest_server`, `executor`, `copilot_executor`, `websocket`, `scheduler`, `notification_server`, and `database_manager` are all subprocesses of the single `poetry run app` parent, so their output is interleaved in `.ign.application.logs`.
```bash
# Backend (all app subprocesses interleaved)
tail -f $BACKEND_DIR/.ign.application.logs
# Frontend (Next.js dev server)
tail -f $FRONTEND_DIR/.ign.frontend.logs
# Filter for errors across either log
grep -iE "error|exception|traceback" $BACKEND_DIR/.ign.application.logs | tail -20
grep -iE "error|exception|traceback" $FRONTEND_DIR/.ign.frontend.logs | tail -20
```
**Docker mode:**
```bash
# Backend REST server
docker logs autogpt_platform-rest_server-1 2>&1 | tail -30
@@ -571,6 +781,19 @@ Upload screenshots to the PR using the GitHub Git API (no local git operations
**CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.** Every screenshot MUST appear as `![name](raw_url)` inline in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes.
**CRITICAL — NEVER paste absolute local paths into the PR comment.** Strings like `/Users/…`, `/home/…`, `C:\…` are useless to every reviewer except you. Before posting, grep the final body for `/Users/`, `/home/`, `/tmp/`, `/private/`, `C:\`, `~/` and either drop those lines entirely or rewrite them as repo-relative paths (`autogpt_platform/backend/…`). The PR comment is an artifact reviewers on GitHub read — it must be self-contained on github.com. Keep local paths in `$RESULTS_DIR/test-report.md` for yourself; only copy the *content* they reference (excerpts, test names, log lines) into the PR comment, not the path.
**Pre-post sanity check** (paste after building the comment body, before `gh api ... comments`):
```bash
# Reject any local-looking absolute path or home-dir shortcut in the body
if grep -nE '(^|[^A-Za-z])(/Users/|/home/|/tmp/|/private/|C:\\|~/)[A-Za-z0-9]' "$COMMENT_FILE" ; then
echo "ABORT: local filesystem paths detected in PR comment body."
echo "Remove or rewrite as repo-relative (autogpt_platform/...) before posting."
exit 1
fi
```
```bash
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
REPO="Significant-Gravitas/AutoGPT"
@@ -876,9 +1099,15 @@ test scenario → find issue (bug OR UX problem) → screenshot broken state
### Problem: Frontend shows cookie banner blocking interaction
**Fix:** `agent-browser click 'text=Accept All'` before other interactions.
### Problem: Container loses npm packages after rebuild
**Cause:** `docker compose up --build` rebuilds the image, losing runtime installs.
**Fix:** Add packages to the Dockerfile instead of installing at runtime.
### Problem: Claude CLI not found in copilot_executor container
**Symptom:** Copilot logs say `claude: command not found` or similar when starting an SDK turn.
**Cause:** Image was built without `poetry install` (stale base layer, or Dockerfile bypass). The SDK CLI ships inside the `claude_agent_sdk` Poetry dep — it is NOT an npm package.
**Fix:** Rebuild the image cleanly: `docker compose build --no-cache copilot_executor && docker compose up -d copilot_executor`. Do NOT `docker exec ... npm install -g @anthropic-ai/claude-code` — that is outdated guidance and will pollute the container with a second CLI that the SDK won't use.
### Problem: agent-browser screenshot hangs / times out
**Symptom:** `agent-browser screenshot` exits with code 124 even on `about:blank`.
**Cause:** Stuck CDP connection or Chromium process tree. Seen on macOS when a prior `/pr-test` left a zombie Chrome for Testing.
**Fix:** `pkill -9 -f "agent-browser|chromium|Chrome for Testing" && sleep 2`, then reopen the browser with a fresh `--session-name`. If still failing, verify via `agent-browser eval` + `agent-browser snapshot` (DOM state) instead of relying on PNGs — the feature under test is the same.
### Problem: Services not starting after `docker compose up`
**Fix:** Wait and check health: `docker compose ps`. Common cause: migration hasn't finished. Check: `docker logs autogpt_platform-migrate-1 2>&1 | tail -5`. If supabase-db isn't healthy: `docker restart supabase-db && sleep 10`.

1
.gitignore vendored
View File

@@ -195,3 +195,4 @@ test.db
# Implementation plans (generated by AI agents)
plans/
.claude/worktrees/
test-results/

310
WORKFLOW.md Normal file
View File

@@ -0,0 +1,310 @@
---
hooks:
after_create: |
if command -v mise >/dev/null 2>&1; then
if [ -f mise.toml ]; then
mise trust
mise exec -- mix deps.get
elif [ -f elixir/mise.toml ]; then
cd elixir && mise trust && mise exec -- mix deps.get
fi
fi
before_remove: |
if [ -f elixir/mix.exs ]; then
cd elixir && mise exec -- mix workspace.before_remove
fi
agent:
default_effort: medium
max_turns: 20
---
You are working on a Linear ticket `{{ issue.identifier }}`
{% if attempt %}
Continuation context:
- This is retry attempt #{{ attempt }} because the ticket is still in an active state.
- Resume from the current workspace state instead of restarting from scratch.
- Do not repeat already-completed investigation or validation unless needed for new code changes.
- Do not end the turn while the issue remains in an active state unless you are blocked by missing required permissions/secrets.
{% endif %}
Issue context:
Identifier: {{ issue.identifier }}
Title: {{ issue.title }}
Current status: {{ issue.state }}
Labels: {{ issue.labels }}
URL: {{ issue.url }}
Description:
{% if issue.description %}
{{ issue.description }}
{% else %}
No description provided.
{% endif %}
Instructions:
1. This is an unattended orchestration session. Never ask a human to perform follow-up actions.
2. Only stop early for a true blocker (missing required auth/permissions/secrets). If blocked, record it in the workpad and move the issue according to workflow.
3. Final message must report completed actions and blockers only. Do not include "next steps for user".
Work only in the provided repository copy. Do not touch any other path.
## Prerequisite: Linear MCP or `linear_graphql` tool is available
The agent should be able to talk to Linear, either via a configured Linear MCP server or injected `linear_graphql` tool. If none are present, stop and ask the user to configure Linear.
## Default posture
- Start by determining the ticket's current status, then follow the matching flow for that status.
- Start every task by opening the tracking workpad comment and bringing it up to date before doing new implementation work.
- Spend extra effort up front on planning and verification design before implementation.
- Reproduce first: always confirm the current behavior/issue signal before changing code so the fix target is explicit.
- Keep ticket metadata current (state, checklist, acceptance criteria, links).
- Treat a single persistent Linear comment as the source of truth for progress.
- Use that single workpad comment for all progress and handoff notes; do not post separate "done"/summary comments.
- Treat any ticket-authored `Validation`, `Test Plan`, or `Testing` section as non-negotiable acceptance input: mirror it in the workpad and execute it before considering the work complete.
- When meaningful out-of-scope improvements are discovered during execution,
file a separate Linear issue instead of expanding scope. The follow-up issue
must include a clear title, description, and acceptance criteria, be placed in
`Backlog`, be assigned to the same project as the current issue, link the
current issue as `related`, and use `blockedBy` when the follow-up depends on
the current issue.
- Move status only when the matching quality bar is met.
- Operate autonomously end-to-end unless blocked by missing requirements, secrets, or permissions.
- Use the blocked-access escape hatch only for true external blockers (missing required tools/auth) after exhausting documented fallbacks.
## Related skills
- `linear`: interact with Linear.
- `commit`: produce clean, logical commits during implementation.
- `push`: keep remote branch current and publish updates.
- `pull`: keep branch updated with latest `origin/main` before handoff.
- `land`: when ticket reaches `Merging`, explicitly open and follow `.codex/skills/land/SKILL.md`, which includes the `land` loop.
## Status map
- `Backlog` -> out of scope for this workflow; do not modify.
- `Todo` -> queued; immediately transition to `In Progress` before active work.
- Special case: if a PR is already attached, treat as feedback/rework loop (run full PR feedback sweep, address or explicitly push back, revalidate, return to `Human Review`).
- `In Progress` -> implementation actively underway.
- `Human Review` -> PR is attached and validated; waiting on human approval.
- `Merging` -> approved by human; execute the `land` skill flow (do not call `gh pr merge` directly).
- `Rework` -> reviewer requested changes; planning + implementation required.
- `Done` -> terminal state; no further action required.
## Step 0: Determine current ticket state and route
1. Fetch the issue by explicit ticket ID.
2. Read the current state.
3. Route to the matching flow:
- `Backlog` -> do not modify issue content/state; stop and wait for human to move it to `Todo`.
- `Todo` -> immediately move to `In Progress`, then ensure bootstrap workpad comment exists (create if missing), then start execution flow.
- If PR is already attached, start by reviewing all open PR comments and deciding required changes vs explicit pushback responses.
- `In Progress` -> continue execution flow from current scratchpad comment.
- `Human Review` -> wait and poll for decision/review updates.
- `Merging` -> on entry, open and follow `.codex/skills/land/SKILL.md`; do not call `gh pr merge` directly.
- `Rework` -> run rework flow.
- `Done` -> do nothing and shut down.
4. Check whether a PR already exists for the current branch and whether it is closed.
- If a branch PR exists and is `CLOSED` or `MERGED`, treat prior branch work as non-reusable for this run.
- Create a fresh branch from `origin/main` and restart execution flow as a new attempt.
5. For `Todo` tickets, do startup sequencing in this exact order:
- `update_issue(..., state: "In Progress")`
- find/create `## Codex Workpad` bootstrap comment
- only then begin analysis/planning/implementation work.
6. Add a short comment if state and issue content are inconsistent, then proceed with the safest flow.
## Step 1: Start/continue execution (Todo or In Progress)
1. Find or create a single persistent scratchpad comment for the issue:
- Search existing comments for a marker header: `## Codex Workpad`.
- Ignore resolved comments while searching; only active/unresolved comments are eligible to be reused as the live workpad.
- If found, reuse that comment; do not create a new workpad comment.
- If not found, create one workpad comment and use it for all updates.
- Persist the workpad comment ID and only write progress updates to that ID.
2. If arriving from `Todo`, do not delay on additional status transitions: the issue should already be `In Progress` before this step begins.
3. Immediately reconcile the workpad before new edits:
- Check off items that are already done.
- Expand/fix the plan so it is comprehensive for current scope.
- Ensure `Acceptance Criteria` and `Validation` are current and still make sense for the task.
4. Start work by writing/updating a hierarchical plan in the workpad comment.
5. Ensure the workpad includes a compact environment stamp at the top as a code fence line:
- Format: `<host>:<abs-workdir>@<short-sha>`
- Example: `devbox-01:/home/dev-user/code/symphony-workspaces/MT-32@7bdde33bc`
- Do not include metadata already inferable from Linear issue fields (`issue ID`, `status`, `branch`, `PR link`).
6. Add explicit acceptance criteria and TODOs in checklist form in the same comment.
- If changes are user-facing, include a UI walkthrough acceptance criterion that describes the end-to-end user path to validate.
- If changes touch app files or app behavior, add explicit app-specific flow checks to `Acceptance Criteria` in the workpad (for example: launch path, changed interaction path, and expected result path).
- If the ticket description/comment context includes `Validation`, `Test Plan`, or `Testing` sections, copy those requirements into the workpad `Acceptance Criteria` and `Validation` sections as required checkboxes (no optional downgrade).
7. Run a principal-style self-review of the plan and refine it in the comment.
8. Before implementing, capture a concrete reproduction signal and record it in the workpad `Notes` section (command/output, screenshot, or deterministic UI behavior).
9. Run the `pull` skill to sync with latest `origin/main` before any code edits, then record the pull/sync result in the workpad `Notes`.
- Include a `pull skill evidence` note with:
- merge source(s),
- result (`clean` or `conflicts resolved`),
- resulting `HEAD` short SHA.
10. Compact context and proceed to execution.
## PR feedback sweep protocol (required)
When a ticket has an attached PR, run this protocol before moving to `Human Review`:
1. Identify the PR number from issue links/attachments.
2. Gather feedback from all channels:
- Top-level PR comments (`gh pr view --comments`).
- Inline review comments (`gh api repos/<owner>/<repo>/pulls/<pr>/comments`).
- Review summaries/states (`gh pr view --json reviews`).
3. Treat every actionable reviewer comment (human or bot), including inline review comments, as blocking until one of these is true:
- code/test/docs updated to address it, or
- explicit, justified pushback reply is posted on that thread.
4. Update the workpad plan/checklist to include each feedback item and its resolution status.
5. Re-run validation after feedback-driven changes and push updates.
6. Repeat this sweep until there are no outstanding actionable comments.
## Blocked-access escape hatch (required behavior)
Use this only when completion is blocked by missing required tools or missing auth/permissions that cannot be resolved in-session.
- GitHub is **not** a valid blocker by default. Always try fallback strategies first (alternate remote/auth mode, then continue publish/review flow).
- Do not move to `Human Review` for GitHub access/auth until all fallback strategies have been attempted and documented in the workpad.
- If a non-GitHub required tool is missing, or required non-GitHub auth is unavailable, move the ticket to `Human Review` with a short blocker brief in the workpad that includes:
- what is missing,
- why it blocks required acceptance/validation,
- exact human action needed to unblock.
- Keep the brief concise and action-oriented; do not add extra top-level comments outside the workpad.
## Step 2: Execution phase (Todo -> In Progress -> Human Review)
1. Determine current repo state (`branch`, `git status`, `HEAD`) and verify the kickoff `pull` sync result is already recorded in the workpad before implementation continues.
2. If current issue state is `Todo`, move it to `In Progress`; otherwise leave the current state unchanged.
3. Load the existing workpad comment and treat it as the active execution checklist.
- Edit it liberally whenever reality changes (scope, risks, validation approach, discovered tasks).
4. Implement against the hierarchical TODOs and keep the comment current:
- Check off completed items.
- Add newly discovered items in the appropriate section.
- Keep parent/child structure intact as scope evolves.
- Update the workpad immediately after each meaningful milestone (for example: reproduction complete, code change landed, validation run, review feedback addressed).
- Never leave completed work unchecked in the plan.
- For tickets that started as `Todo` with an attached PR, run the full PR feedback sweep protocol immediately after kickoff and before new feature work.
5. Run validation/tests required for the scope.
- Mandatory gate: execute all ticket-provided `Validation`/`Test Plan`/ `Testing` requirements when present; treat unmet items as incomplete work.
- Prefer a targeted proof that directly demonstrates the behavior you changed.
- You may make temporary local proof edits to validate assumptions (for example: tweak a local build input for `make`, or hardcode a UI account / response path) when this increases confidence.
- Revert every temporary proof edit before commit/push.
- Document these temporary proof steps and outcomes in the workpad `Validation`/`Notes` sections so reviewers can follow the evidence.
- If app-touching, run `launch-app` validation and capture/upload media via `github-pr-media` before handoff.
6. Re-check all acceptance criteria and close any gaps.
7. Before every `git push` attempt, run the required validation for your scope and confirm it passes; if it fails, address issues and rerun until green, then commit and push changes.
8. Attach PR URL to the issue (prefer attachment; use the workpad comment only if attachment is unavailable).
- Ensure the GitHub PR has label `symphony` (add it if missing).
9. Merge latest `origin/main` into branch, resolve conflicts, and rerun checks.
10. Update the workpad comment with final checklist status and validation notes.
- Mark completed plan/acceptance/validation checklist items as checked.
- Add final handoff notes (commit + validation summary) in the same workpad comment.
- Do not include PR URL in the workpad comment; keep PR linkage on the issue via attachment/link fields.
- Add a short `### Confusions` section at the bottom when any part of task execution was unclear/confusing, with concise bullets.
- Do not post any additional completion summary comment.
11. Before moving to `Human Review`, poll PR feedback and checks:
- Read the PR `Manual QA Plan` comment (when present) and use it to sharpen UI/runtime test coverage for the current change.
- Run the full PR feedback sweep protocol.
- Confirm PR checks are passing (green) after the latest changes.
- Confirm every required ticket-provided validation/test-plan item is explicitly marked complete in the workpad.
- Repeat this check-address-verify loop until no outstanding comments remain and checks are fully passing.
- Re-open and refresh the workpad before state transition so `Plan`, `Acceptance Criteria`, and `Validation` exactly match completed work.
12. Only then move issue to `Human Review`.
- Exception: if blocked by missing required non-GitHub tools/auth per the blocked-access escape hatch, move to `Human Review` with the blocker brief and explicit unblock actions.
13. For `Todo` tickets that already had a PR attached at kickoff:
- Ensure all existing PR feedback was reviewed and resolved, including inline review comments (code changes or explicit, justified pushback response).
- Ensure branch was pushed with any required updates.
- Then move to `Human Review`.
## Step 3: Human Review and merge handling
1. When the issue is in `Human Review`, do not code or change ticket content.
2. Poll for updates as needed, including GitHub PR review comments from humans and bots.
3. If review feedback requires changes, move the issue to `Rework` and follow the rework flow.
4. If approved, human moves the issue to `Merging`.
5. When the issue is in `Merging`, open and follow `.codex/skills/land/SKILL.md`, then run the `land` skill in a loop until the PR is merged. Do not call `gh pr merge` directly.
6. After merge is complete, move the issue to `Done`.
## Step 4: Rework handling
1. Treat `Rework` as a full approach reset, not incremental patching.
2. Re-read the full issue body and all human comments; explicitly identify what will be done differently this attempt.
3. Close the existing PR tied to the issue.
4. Remove the existing `## Codex Workpad` comment from the issue.
5. Create a fresh branch from `origin/main`.
6. Start over from the normal kickoff flow:
- If current issue state is `Todo`, move it to `In Progress`; otherwise keep the current state.
- Create a new bootstrap `## Codex Workpad` comment.
- Build a fresh plan/checklist and execute end-to-end.
## Completion bar before Human Review
- Step 1/2 checklist is fully complete and accurately reflected in the single workpad comment.
- Acceptance criteria and required ticket-provided validation items are complete.
- Validation/tests are green for the latest commit.
- PR feedback sweep is complete and no actionable comments remain.
- PR checks are green, branch is pushed, and PR is linked on the issue.
- Required PR metadata is present (`symphony` label).
- If app-touching, runtime validation/media requirements from `App runtime validation (required)` are complete.
## Guardrails
- If the branch PR is already closed/merged, do not reuse that branch or prior implementation state for continuation.
- For closed/merged branch PRs, create a new branch from `origin/main` and restart from reproduction/planning as if starting fresh.
- If issue state is `Backlog`, do not modify it; wait for human to move to `Todo`.
- Do not edit the issue body/description for planning or progress tracking.
- Use exactly one persistent workpad comment (`## Codex Workpad`) per issue.
- If comment editing is unavailable in-session, use the update script. Only report blocked if both MCP editing and script-based editing are unavailable.
- Temporary proof edits are allowed only for local verification and must be reverted before commit.
- If out-of-scope improvements are found, create a separate Backlog issue rather
than expanding current scope, and include a clear
title/description/acceptance criteria, same-project assignment, a `related`
link to the current issue, and `blockedBy` when the follow-up depends on the
current issue.
- Do not move to `Human Review` unless the `Completion bar before Human Review` is satisfied.
- In `Human Review`, do not make changes; wait and poll.
- If state is terminal (`Done`), do nothing and shut down.
- Keep issue text concise, specific, and reviewer-oriented.
- If blocked and no workpad exists yet, add one blocker comment describing blocker, impact, and next unblock action.
## Workpad template
Use this exact structure for the persistent workpad comment and keep it updated in place throughout execution:
````md
## Codex Workpad
```text
<hostname>:<abs-path>@<short-sha>
```
### Plan
- [ ] 1\. Parent task
- [ ] 1.1 Child task
- [ ] 1.2 Child task
- [ ] 2\. Parent task
### Acceptance Criteria
- [ ] Criterion 1
- [ ] Criterion 2
### Validation
- [ ] targeted tests: `<command>`
### Notes
- <short progress note with timestamp>
### Confusions
- <only include when something was confusing during execution>
````

View File

@@ -1,3 +1,6 @@
*.ignore.*
*.ign.*
.application.logs
# Claude Code local settings only — the rest of .claude/ is shared (skills etc.)
.claude/settings.local.json

View File

@@ -59,6 +59,8 @@ class OAuthState(BaseModel):
code_verifier: Optional[str] = None
scopes: list[str]
"""Unix timestamp (seconds) indicating when this OAuth state expires"""
credential_id: Optional[str] = None
"""If set, this OAuth flow upgrades an existing credential's scopes."""
class UserMetadata(BaseModel):

View File

@@ -179,6 +179,9 @@ MEM0_API_KEY=
OPENWEATHERMAP_API_KEY=
GOOGLE_MAPS_API_KEY=
# Platform Bot Linking
PLATFORM_LINK_BASE_URL=http://localhost:3000/link
# Communication Services
DISCORD_BOT_TOKEN=
MEDIUM_API_KEY=

View File

@@ -0,0 +1,932 @@
import asyncio
import logging
from typing import List
from autogpt_libs.auth import requires_admin_user
from autogpt_libs.auth.models import User as AuthUser
from fastapi import APIRouter, HTTPException, Security
from prisma.enums import AgentExecutionStatus
from pydantic import BaseModel
from backend.api.features.admin.model import (
AgentDiagnosticsResponse,
ExecutionDiagnosticsResponse,
)
from backend.data.diagnostics import (
FailedExecutionDetail,
OrphanedScheduleDetail,
RunningExecutionDetail,
ScheduleDetail,
ScheduleHealthMetrics,
cleanup_all_stuck_queued_executions,
cleanup_orphaned_executions_bulk,
cleanup_orphaned_schedules_bulk,
get_agent_diagnostics,
get_all_orphaned_execution_ids,
get_all_schedules_details,
get_all_stuck_queued_execution_ids,
get_execution_diagnostics,
get_failed_executions_count,
get_failed_executions_details,
get_invalid_executions_details,
get_long_running_executions_details,
get_orphaned_executions_details,
get_orphaned_schedules_details,
get_running_executions_details,
get_schedule_health_metrics,
get_stuck_queued_executions_details,
stop_all_long_running_executions,
)
from backend.data.execution import get_graph_executions
from backend.executor.utils import add_graph_execution, stop_graph_execution
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/admin",
tags=["diagnostics", "admin"],
dependencies=[Security(requires_admin_user)],
)
class RunningExecutionsListResponse(BaseModel):
"""Response model for list of running executions"""
executions: List[RunningExecutionDetail]
total: int
class FailedExecutionsListResponse(BaseModel):
"""Response model for list of failed executions"""
executions: List[FailedExecutionDetail]
total: int
class StopExecutionRequest(BaseModel):
"""Request model for stopping a single execution"""
execution_id: str
class StopExecutionsRequest(BaseModel):
"""Request model for stopping multiple executions"""
execution_ids: List[str]
class StopExecutionResponse(BaseModel):
"""Response model for stop execution operations"""
success: bool
stopped_count: int = 0
message: str
class RequeueExecutionResponse(BaseModel):
"""Response model for requeue execution operations"""
success: bool
requeued_count: int = 0
message: str
@router.get(
"/diagnostics/executions",
response_model=ExecutionDiagnosticsResponse,
summary="Get Execution Diagnostics",
)
async def get_execution_diagnostics_endpoint():
"""
Get comprehensive diagnostic information about execution status.
Returns all execution metrics including:
- Current state (running, queued)
- Orphaned executions (>24h old, likely not in executor)
- Failure metrics (1h, 24h, rate)
- Long-running detection (stuck >1h, >24h)
- Stuck queued detection
- Throughput metrics (completions/hour)
- RabbitMQ queue depths
"""
logger.info("Getting execution diagnostics")
diagnostics = await get_execution_diagnostics()
response = ExecutionDiagnosticsResponse(
running_executions=diagnostics.running_count,
queued_executions_db=diagnostics.queued_db_count,
queued_executions_rabbitmq=diagnostics.rabbitmq_queue_depth,
cancel_queue_depth=diagnostics.cancel_queue_depth,
orphaned_running=diagnostics.orphaned_running,
orphaned_queued=diagnostics.orphaned_queued,
failed_count_1h=diagnostics.failed_count_1h,
failed_count_24h=diagnostics.failed_count_24h,
failure_rate_24h=diagnostics.failure_rate_24h,
stuck_running_24h=diagnostics.stuck_running_24h,
stuck_running_1h=diagnostics.stuck_running_1h,
oldest_running_hours=diagnostics.oldest_running_hours,
stuck_queued_1h=diagnostics.stuck_queued_1h,
queued_never_started=diagnostics.queued_never_started,
invalid_queued_with_start=diagnostics.invalid_queued_with_start,
invalid_running_without_start=diagnostics.invalid_running_without_start,
completed_1h=diagnostics.completed_1h,
completed_24h=diagnostics.completed_24h,
throughput_per_hour=diagnostics.throughput_per_hour,
timestamp=diagnostics.timestamp,
)
logger.info(
f"Execution diagnostics: running={diagnostics.running_count}, "
f"queued_db={diagnostics.queued_db_count}, "
f"orphaned={diagnostics.orphaned_running + diagnostics.orphaned_queued}, "
f"failed_24h={diagnostics.failed_count_24h}"
)
return response
@router.get(
"/diagnostics/agents",
response_model=AgentDiagnosticsResponse,
summary="Get Agent Diagnostics",
)
async def get_agent_diagnostics_endpoint():
"""
Get diagnostic information about agents.
Returns:
- agents_with_active_executions: Number of unique agents with running/queued executions
- timestamp: Current timestamp
"""
logger.info("Getting agent diagnostics")
diagnostics = await get_agent_diagnostics()
response = AgentDiagnosticsResponse(
agents_with_active_executions=diagnostics.agents_with_active_executions,
timestamp=diagnostics.timestamp,
)
logger.info(
f"Agent diagnostics: with_active_executions={diagnostics.agents_with_active_executions}"
)
return response
@router.get(
"/diagnostics/executions/running",
response_model=RunningExecutionsListResponse,
summary="List Running Executions",
)
async def list_running_executions(
limit: int = 100,
offset: int = 0,
):
"""
Get detailed list of running and queued executions (recent, likely active).
Args:
limit: Maximum number of executions to return (default 100)
offset: Number of executions to skip (default 0)
Returns:
List of running executions with details
"""
logger.info(f"Listing running executions (limit={limit}, offset={offset})")
executions = await get_running_executions_details(limit=limit, offset=offset)
# Get total count for pagination
diagnostics = await get_execution_diagnostics()
total = diagnostics.running_count + diagnostics.queued_db_count
return RunningExecutionsListResponse(executions=executions, total=total)
@router.get(
"/diagnostics/executions/orphaned",
response_model=RunningExecutionsListResponse,
summary="List Orphaned Executions",
)
async def list_orphaned_executions(
limit: int = 100,
offset: int = 0,
):
"""
Get detailed list of orphaned executions (>24h old, likely not in executor).
Args:
limit: Maximum number of executions to return (default 100)
offset: Number of executions to skip (default 0)
Returns:
List of orphaned executions with details
"""
logger.info(f"Listing orphaned executions (limit={limit}, offset={offset})")
executions = await get_orphaned_executions_details(limit=limit, offset=offset)
# Get total count for pagination
diagnostics = await get_execution_diagnostics()
total = diagnostics.orphaned_running + diagnostics.orphaned_queued
return RunningExecutionsListResponse(executions=executions, total=total)
@router.get(
"/diagnostics/executions/failed",
response_model=FailedExecutionsListResponse,
summary="List Failed Executions",
)
async def list_failed_executions(
limit: int = 100,
offset: int = 0,
hours: int = 24,
):
"""
Get detailed list of failed executions.
Args:
limit: Maximum number of executions to return (default 100)
offset: Number of executions to skip (default 0)
hours: Number of hours to look back (default 24)
Returns:
List of failed executions with error details
"""
logger.info(
f"Listing failed executions (limit={limit}, offset={offset}, hours={hours})"
)
executions = await get_failed_executions_details(
limit=limit, offset=offset, hours=hours
)
# Get total count for pagination
# Always count actual total for given hours parameter
total = await get_failed_executions_count(hours=hours)
return FailedExecutionsListResponse(executions=executions, total=total)
@router.get(
"/diagnostics/executions/long-running",
response_model=RunningExecutionsListResponse,
summary="List Long-Running Executions",
)
async def list_long_running_executions(
limit: int = 100,
offset: int = 0,
):
"""
Get detailed list of long-running executions (RUNNING status >24h).
Args:
limit: Maximum number of executions to return (default 100)
offset: Number of executions to skip (default 0)
Returns:
List of long-running executions with details
"""
logger.info(f"Listing long-running executions (limit={limit}, offset={offset})")
executions = await get_long_running_executions_details(limit=limit, offset=offset)
# Get total count for pagination
diagnostics = await get_execution_diagnostics()
total = diagnostics.stuck_running_24h
return RunningExecutionsListResponse(executions=executions, total=total)
@router.get(
"/diagnostics/executions/stuck-queued",
response_model=RunningExecutionsListResponse,
summary="List Stuck Queued Executions",
)
async def list_stuck_queued_executions(
limit: int = 100,
offset: int = 0,
):
"""
Get detailed list of stuck queued executions (QUEUED >1h, never started).
Args:
limit: Maximum number of executions to return (default 100)
offset: Number of executions to skip (default 0)
Returns:
List of stuck queued executions with details
"""
logger.info(f"Listing stuck queued executions (limit={limit}, offset={offset})")
executions = await get_stuck_queued_executions_details(limit=limit, offset=offset)
# Get total count for pagination
diagnostics = await get_execution_diagnostics()
total = diagnostics.stuck_queued_1h
return RunningExecutionsListResponse(executions=executions, total=total)
@router.get(
"/diagnostics/executions/invalid",
response_model=RunningExecutionsListResponse,
summary="List Invalid Executions",
)
async def list_invalid_executions(
limit: int = 100,
offset: int = 0,
):
"""
Get detailed list of executions in invalid states (READ-ONLY).
Invalid states indicate data corruption and require manual investigation:
- QUEUED but has startedAt (impossible - can't start while queued)
- RUNNING but no startedAt (impossible - can't run without starting)
⚠️ NO BULK ACTIONS PROVIDED - These need case-by-case investigation.
Each invalid execution likely has a different root cause (crashes, race conditions,
DB corruption). Investigate the execution history and logs to determine appropriate
action (manual cleanup, status fix, or leave as-is if system recovered).
Args:
limit: Maximum number of executions to return (default 100)
offset: Number of executions to skip (default 0)
Returns:
List of invalid state executions with details
"""
logger.info(f"Listing invalid state executions (limit={limit}, offset={offset})")
executions = await get_invalid_executions_details(limit=limit, offset=offset)
# Get total count for pagination
diagnostics = await get_execution_diagnostics()
total = (
diagnostics.invalid_queued_with_start
+ diagnostics.invalid_running_without_start
)
return RunningExecutionsListResponse(executions=executions, total=total)
@router.post(
"/diagnostics/executions/requeue",
response_model=RequeueExecutionResponse,
summary="Requeue Stuck Execution",
)
async def requeue_single_execution(
request: StopExecutionRequest, # Reuse same request model (has execution_id)
user: AuthUser = Security(requires_admin_user),
):
"""
Requeue a stuck QUEUED execution (admin only).
Uses add_graph_execution with existing graph_exec_id to requeue.
⚠️ WARNING: Only use for stuck executions. This will re-execute and may cost credits.
Args:
request: Contains execution_id to requeue
Returns:
Success status and message
"""
logger.info(f"Admin {user.user_id} requeueing execution {request.execution_id}")
# Get the execution (validation - must be QUEUED)
executions = await get_graph_executions(
graph_exec_id=request.execution_id,
statuses=[AgentExecutionStatus.QUEUED],
)
if not executions:
raise HTTPException(
status_code=404,
detail="Execution not found or not in QUEUED status",
)
execution = executions[0]
# Use add_graph_execution in requeue mode
await add_graph_execution(
graph_id=execution.graph_id,
user_id=execution.user_id,
graph_version=execution.graph_version,
graph_exec_id=request.execution_id, # Requeue existing execution
)
return RequeueExecutionResponse(
success=True,
requeued_count=1,
message="Execution requeued successfully",
)
@router.post(
"/diagnostics/executions/requeue-bulk",
response_model=RequeueExecutionResponse,
summary="Requeue Multiple Stuck Executions",
)
async def requeue_multiple_executions(
request: StopExecutionsRequest, # Reuse same request model (has execution_ids)
user: AuthUser = Security(requires_admin_user),
):
"""
Requeue multiple stuck QUEUED executions (admin only).
Uses add_graph_execution with existing graph_exec_id to requeue.
⚠️ WARNING: Only use for stuck executions. This will re-execute and may cost credits.
Args:
request: Contains list of execution_ids to requeue
Returns:
Number of executions requeued and success message
"""
logger.info(
f"Admin {user.user_id} requeueing {len(request.execution_ids)} executions"
)
# Get executions by ID list (must be QUEUED)
executions = await get_graph_executions(
execution_ids=request.execution_ids,
statuses=[AgentExecutionStatus.QUEUED],
)
if not executions:
return RequeueExecutionResponse(
success=False,
requeued_count=0,
message="No QUEUED executions found to requeue",
)
# Requeue all executions in parallel using add_graph_execution
async def requeue_one(exec) -> bool:
try:
await add_graph_execution(
graph_id=exec.graph_id,
user_id=exec.user_id,
graph_version=exec.graph_version,
graph_exec_id=exec.id, # Requeue existing
)
return True
except Exception as e:
logger.error(f"Failed to requeue {exec.id}: {e}")
return False
results = await asyncio.gather(
*[requeue_one(exec) for exec in executions], return_exceptions=False
)
requeued_count = sum(1 for success in results if success)
return RequeueExecutionResponse(
success=requeued_count > 0,
requeued_count=requeued_count,
message=f"Requeued {requeued_count} of {len(request.execution_ids)} executions",
)
@router.post(
"/diagnostics/executions/stop",
response_model=StopExecutionResponse,
summary="Stop Single Execution",
)
async def stop_single_execution(
request: StopExecutionRequest,
user: AuthUser = Security(requires_admin_user),
):
"""
Stop a single execution (admin only).
Uses robust stop_graph_execution which cascades to children and waits for termination.
Args:
request: Contains execution_id to stop
Returns:
Success status and message
"""
logger.info(f"Admin {user.user_id} stopping execution {request.execution_id}")
# Get the execution to find its owner user_id (required by stop_graph_execution)
executions = await get_graph_executions(
graph_exec_id=request.execution_id,
)
if not executions:
raise HTTPException(status_code=404, detail="Execution not found")
execution = executions[0]
# Use robust stop_graph_execution (cascades to children, waits for termination)
await stop_graph_execution(
user_id=execution.user_id,
graph_exec_id=request.execution_id,
wait_timeout=15.0,
cascade=True,
)
return StopExecutionResponse(
success=True,
stopped_count=1,
message="Execution stopped successfully",
)
@router.post(
"/diagnostics/executions/stop-bulk",
response_model=StopExecutionResponse,
summary="Stop Multiple Executions",
)
async def stop_multiple_executions(
request: StopExecutionsRequest,
user: AuthUser = Security(requires_admin_user),
):
"""
Stop multiple active executions (admin only).
Uses robust stop_graph_execution which cascades to children and waits for termination.
Args:
request: Contains list of execution_ids to stop
Returns:
Number of executions stopped and success message
"""
logger.info(
f"Admin {user.user_id} stopping {len(request.execution_ids)} executions"
)
# Get executions by ID list
executions = await get_graph_executions(
execution_ids=request.execution_ids,
)
if not executions:
return StopExecutionResponse(
success=False,
stopped_count=0,
message="No executions found",
)
# Stop all executions in parallel using robust stop_graph_execution
async def stop_one(exec) -> bool:
try:
await stop_graph_execution(
user_id=exec.user_id,
graph_exec_id=exec.id,
wait_timeout=15.0,
cascade=True,
)
return True
except Exception as e:
logger.error(f"Failed to stop execution {exec.id}: {e}")
return False
results = await asyncio.gather(
*[stop_one(exec) for exec in executions], return_exceptions=False
)
stopped_count = sum(1 for success in results if success)
return StopExecutionResponse(
success=stopped_count > 0,
stopped_count=stopped_count,
message=f"Stopped {stopped_count} of {len(request.execution_ids)} executions",
)
@router.post(
"/diagnostics/executions/cleanup-orphaned",
response_model=StopExecutionResponse,
summary="Cleanup Orphaned Executions",
)
async def cleanup_orphaned_executions(
request: StopExecutionsRequest,
user: AuthUser = Security(requires_admin_user),
):
"""
Cleanup orphaned executions by directly updating DB status (admin only).
For executions in DB but not actually running in executor (old/stale records).
Args:
request: Contains list of execution_ids to cleanup
Returns:
Number of executions cleaned up and success message
"""
logger.info(
f"Admin {user.user_id} cleaning up {len(request.execution_ids)} orphaned executions"
)
cleaned_count = await cleanup_orphaned_executions_bulk(
request.execution_ids, user.user_id
)
return StopExecutionResponse(
success=cleaned_count > 0,
stopped_count=cleaned_count,
message=f"Cleaned up {cleaned_count} of {len(request.execution_ids)} orphaned executions",
)
# ============================================================================
# SCHEDULE DIAGNOSTICS ENDPOINTS
# ============================================================================
class SchedulesListResponse(BaseModel):
"""Response model for list of schedules"""
schedules: List[ScheduleDetail]
total: int
class OrphanedSchedulesListResponse(BaseModel):
"""Response model for list of orphaned schedules"""
schedules: List[OrphanedScheduleDetail]
total: int
class ScheduleCleanupRequest(BaseModel):
"""Request model for cleaning up schedules"""
schedule_ids: List[str]
class ScheduleCleanupResponse(BaseModel):
"""Response model for schedule cleanup operations"""
success: bool
deleted_count: int = 0
message: str
@router.get(
"/diagnostics/schedules",
response_model=ScheduleHealthMetrics,
summary="Get Schedule Diagnostics",
)
async def get_schedule_diagnostics_endpoint():
"""
Get comprehensive diagnostic information about schedule health.
Returns schedule metrics including:
- Total schedules (user vs system)
- Orphaned schedules by category
- Upcoming executions
"""
logger.info("Getting schedule diagnostics")
diagnostics = await get_schedule_health_metrics()
logger.info(
f"Schedule diagnostics: total={diagnostics.total_schedules}, "
f"user={diagnostics.user_schedules}, "
f"orphaned={diagnostics.total_orphaned}"
)
return diagnostics
@router.get(
"/diagnostics/schedules/all",
response_model=SchedulesListResponse,
summary="List All User Schedules",
)
async def list_all_schedules(
limit: int = 100,
offset: int = 0,
):
"""
Get detailed list of all user schedules (excludes system monitoring jobs).
Args:
limit: Maximum number of schedules to return (default 100)
offset: Number of schedules to skip (default 0)
Returns:
List of schedules with details
"""
logger.info(f"Listing all schedules (limit={limit}, offset={offset})")
schedules = await get_all_schedules_details(limit=limit, offset=offset)
# Get total count
diagnostics = await get_schedule_health_metrics()
total = diagnostics.user_schedules
return SchedulesListResponse(schedules=schedules, total=total)
@router.get(
"/diagnostics/schedules/orphaned",
response_model=OrphanedSchedulesListResponse,
summary="List Orphaned Schedules",
)
async def list_orphaned_schedules():
"""
Get detailed list of orphaned schedules with orphan reasons.
Returns:
List of orphaned schedules categorized by orphan type
"""
logger.info("Listing orphaned schedules")
schedules = await get_orphaned_schedules_details()
return OrphanedSchedulesListResponse(schedules=schedules, total=len(schedules))
@router.post(
"/diagnostics/schedules/cleanup-orphaned",
response_model=ScheduleCleanupResponse,
summary="Cleanup Orphaned Schedules",
)
async def cleanup_orphaned_schedules(
request: ScheduleCleanupRequest,
user: AuthUser = Security(requires_admin_user),
):
"""
Cleanup orphaned schedules by deleting from scheduler (admin only).
Args:
request: Contains list of schedule_ids to delete
Returns:
Number of schedules deleted and success message
"""
logger.info(
f"Admin {user.user_id} cleaning up {len(request.schedule_ids)} orphaned schedules"
)
deleted_count = await cleanup_orphaned_schedules_bulk(
request.schedule_ids, user.user_id
)
return ScheduleCleanupResponse(
success=deleted_count > 0,
deleted_count=deleted_count,
message=f"Deleted {deleted_count} of {len(request.schedule_ids)} orphaned schedules",
)
@router.post(
"/diagnostics/executions/stop-all-long-running",
response_model=StopExecutionResponse,
summary="Stop ALL Long-Running Executions",
)
async def stop_all_long_running_executions_endpoint(
user: AuthUser = Security(requires_admin_user),
):
"""
Stop ALL long-running executions (RUNNING >24h) by sending cancel signals (admin only).
Operates on entire dataset, not limited to pagination.
Returns:
Number of executions stopped and success message
"""
logger.info(f"Admin {user.user_id} stopping ALL long-running executions")
stopped_count = await stop_all_long_running_executions(user.user_id)
return StopExecutionResponse(
success=stopped_count > 0,
stopped_count=stopped_count,
message=f"Stopped {stopped_count} long-running executions",
)
@router.post(
"/diagnostics/executions/cleanup-all-orphaned",
response_model=StopExecutionResponse,
summary="Cleanup ALL Orphaned Executions",
)
async def cleanup_all_orphaned_executions(
user: AuthUser = Security(requires_admin_user),
):
"""
Cleanup ALL orphaned executions (>24h old) by directly updating DB status.
Operates on all executions, not just paginated results.
Returns:
Number of executions cleaned up and success message
"""
logger.info(f"Admin {user.user_id} cleaning up ALL orphaned executions")
# Fetch all orphaned execution IDs
execution_ids = await get_all_orphaned_execution_ids()
if not execution_ids:
return StopExecutionResponse(
success=True,
stopped_count=0,
message="No orphaned executions to cleanup",
)
cleaned_count = await cleanup_orphaned_executions_bulk(execution_ids, user.user_id)
return StopExecutionResponse(
success=cleaned_count > 0,
stopped_count=cleaned_count,
message=f"Cleaned up {cleaned_count} orphaned executions",
)
@router.post(
"/diagnostics/executions/cleanup-all-stuck-queued",
response_model=StopExecutionResponse,
summary="Cleanup ALL Stuck Queued Executions",
)
async def cleanup_all_stuck_queued_executions_endpoint(
user: AuthUser = Security(requires_admin_user),
):
"""
Cleanup ALL stuck queued executions (QUEUED >1h) by updating DB status (admin only).
Operates on entire dataset, not limited to pagination.
Returns:
Number of executions cleaned up and success message
"""
logger.info(f"Admin {user.user_id} cleaning up ALL stuck queued executions")
cleaned_count = await cleanup_all_stuck_queued_executions(user.user_id)
return StopExecutionResponse(
success=cleaned_count > 0,
stopped_count=cleaned_count,
message=f"Cleaned up {cleaned_count} stuck queued executions",
)
@router.post(
"/diagnostics/executions/requeue-all-stuck",
response_model=RequeueExecutionResponse,
summary="Requeue ALL Stuck Queued Executions",
)
async def requeue_all_stuck_executions(
user: AuthUser = Security(requires_admin_user),
):
"""
Requeue ALL stuck queued executions (QUEUED >1h) by publishing to RabbitMQ.
Operates on all executions, not just paginated results.
Uses add_graph_execution with existing graph_exec_id to requeue.
⚠️ WARNING: This will re-execute ALL stuck executions and may cost significant credits.
Returns:
Number of executions requeued and success message
"""
logger.info(f"Admin {user.user_id} requeueing ALL stuck queued executions")
# Fetch all stuck queued execution IDs
execution_ids = await get_all_stuck_queued_execution_ids()
if not execution_ids:
return RequeueExecutionResponse(
success=True,
requeued_count=0,
message="No stuck queued executions to requeue",
)
# Get stuck executions by ID list (must be QUEUED)
executions = await get_graph_executions(
execution_ids=execution_ids,
statuses=[AgentExecutionStatus.QUEUED],
)
# Requeue all in parallel using add_graph_execution
async def requeue_one(exec) -> bool:
try:
await add_graph_execution(
graph_id=exec.graph_id,
user_id=exec.user_id,
graph_version=exec.graph_version,
graph_exec_id=exec.id, # Requeue existing
)
return True
except Exception as e:
logger.error(f"Failed to requeue {exec.id}: {e}")
return False
results = await asyncio.gather(
*[requeue_one(exec) for exec in executions], return_exceptions=False
)
requeued_count = sum(1 for success in results if success)
return RequeueExecutionResponse(
success=requeued_count > 0,
requeued_count=requeued_count,
message=f"Requeued {requeued_count} stuck executions",
)

View File

@@ -0,0 +1,889 @@
from datetime import datetime, timezone
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from prisma.enums import AgentExecutionStatus
import backend.api.features.admin.diagnostics_admin_routes as diagnostics_admin_routes
from backend.data.diagnostics import (
AgentDiagnosticsSummary,
ExecutionDiagnosticsSummary,
FailedExecutionDetail,
OrphanedScheduleDetail,
RunningExecutionDetail,
ScheduleDetail,
ScheduleHealthMetrics,
)
from backend.data.execution import GraphExecutionMeta
app = fastapi.FastAPI()
app.include_router(diagnostics_admin_routes.router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
"""Setup admin auth overrides for all tests in this module"""
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def test_get_execution_diagnostics_success(
mocker: pytest_mock.MockFixture,
):
"""Test fetching execution diagnostics with invalid state detection"""
mock_diagnostics = ExecutionDiagnosticsSummary(
running_count=10,
queued_db_count=5,
rabbitmq_queue_depth=3,
cancel_queue_depth=0,
orphaned_running=2,
orphaned_queued=1,
failed_count_1h=5,
failed_count_24h=20,
failure_rate_24h=0.83,
stuck_running_24h=1,
stuck_running_1h=3,
oldest_running_hours=26.5,
stuck_queued_1h=2,
queued_never_started=1,
invalid_queued_with_start=1, # New invalid state
invalid_running_without_start=1, # New invalid state
completed_1h=50,
completed_24h=1200,
throughput_per_hour=50.0,
timestamp=datetime.now(timezone.utc).isoformat(),
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
return_value=mock_diagnostics,
)
response = client.get("/admin/diagnostics/executions")
assert response.status_code == 200
data = response.json()
# Verify new invalid state fields are included
assert data["invalid_queued_with_start"] == 1
assert data["invalid_running_without_start"] == 1
# Verify all expected fields present
assert "running_executions" in data
assert "orphaned_running" in data
assert "failed_count_24h" in data
def test_list_invalid_executions(
mocker: pytest_mock.MockFixture,
):
"""Test listing executions in invalid states (read-only endpoint)"""
mock_invalid_executions = [
RunningExecutionDetail(
execution_id="exec-invalid-1",
graph_id="graph-123",
graph_name="Test Graph",
graph_version=1,
user_id="user-123",
user_email="test@example.com",
status="QUEUED",
created_at=datetime.now(timezone.utc),
started_at=datetime.now(
timezone.utc
), # QUEUED but has startedAt - INVALID!
queue_status=None,
),
RunningExecutionDetail(
execution_id="exec-invalid-2",
graph_id="graph-456",
graph_name="Another Graph",
graph_version=2,
user_id="user-456",
user_email="user@example.com",
status="RUNNING",
created_at=datetime.now(timezone.utc),
started_at=None, # RUNNING but no startedAt - INVALID!
queue_status=None,
),
]
mock_diagnostics = ExecutionDiagnosticsSummary(
running_count=10,
queued_db_count=5,
rabbitmq_queue_depth=3,
cancel_queue_depth=0,
orphaned_running=0,
orphaned_queued=0,
failed_count_1h=0,
failed_count_24h=0,
failure_rate_24h=0.0,
stuck_running_24h=0,
stuck_running_1h=0,
oldest_running_hours=None,
stuck_queued_1h=0,
queued_never_started=0,
invalid_queued_with_start=1,
invalid_running_without_start=1,
completed_1h=0,
completed_24h=0,
throughput_per_hour=0.0,
timestamp=datetime.now(timezone.utc).isoformat(),
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_invalid_executions_details",
return_value=mock_invalid_executions,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
return_value=mock_diagnostics,
)
response = client.get("/admin/diagnostics/executions/invalid?limit=100&offset=0")
assert response.status_code == 200
data = response.json()
assert data["total"] == 2 # Sum of both invalid state types
assert len(data["executions"]) == 2
# Verify both types of invalid states are returned
assert data["executions"][0]["execution_id"] in [
"exec-invalid-1",
"exec-invalid-2",
]
assert data["executions"][1]["execution_id"] in [
"exec-invalid-1",
"exec-invalid-2",
]
def test_requeue_single_execution_with_add_graph_execution(
mocker: pytest_mock.MockFixture,
admin_user_id: str,
):
"""Test requeueing uses add_graph_execution in requeue mode"""
mock_exec_meta = GraphExecutionMeta(
id="exec-stuck-123",
user_id="user-123",
graph_id="graph-456",
graph_version=1,
inputs=None,
credential_inputs=None,
nodes_input_masks=None,
preset_id=None,
status=AgentExecutionStatus.QUEUED,
started_at=datetime.now(timezone.utc),
ended_at=datetime.now(timezone.utc),
stats=None,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
return_value=[mock_exec_meta],
)
mock_add_graph_execution = mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.add_graph_execution",
return_value=AsyncMock(),
)
response = client.post(
"/admin/diagnostics/executions/requeue",
json={"execution_id": "exec-stuck-123"},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["requeued_count"] == 1
# Verify it used add_graph_execution in requeue mode
mock_add_graph_execution.assert_called_once()
call_kwargs = mock_add_graph_execution.call_args.kwargs
assert call_kwargs["graph_exec_id"] == "exec-stuck-123" # Requeue mode!
assert call_kwargs["graph_id"] == "graph-456"
assert call_kwargs["user_id"] == "user-123"
def test_stop_single_execution_with_stop_graph_execution(
mocker: pytest_mock.MockFixture,
admin_user_id: str,
):
"""Test stopping uses robust stop_graph_execution"""
mock_exec_meta = GraphExecutionMeta(
id="exec-running-123",
user_id="user-789",
graph_id="graph-999",
graph_version=2,
inputs=None,
credential_inputs=None,
nodes_input_masks=None,
preset_id=None,
status=AgentExecutionStatus.RUNNING,
started_at=datetime.now(timezone.utc),
ended_at=datetime.now(timezone.utc),
stats=None,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
return_value=[mock_exec_meta],
)
mock_stop_graph_execution = mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.stop_graph_execution",
return_value=AsyncMock(),
)
response = client.post(
"/admin/diagnostics/executions/stop",
json={"execution_id": "exec-running-123"},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["stopped_count"] == 1
# Verify it used stop_graph_execution with cascade
mock_stop_graph_execution.assert_called_once()
call_kwargs = mock_stop_graph_execution.call_args.kwargs
assert call_kwargs["graph_exec_id"] == "exec-running-123"
assert call_kwargs["user_id"] == "user-789"
assert call_kwargs["cascade"] is True # Stops children too!
assert call_kwargs["wait_timeout"] == 15.0
def test_requeue_not_queued_execution_fails(
mocker: pytest_mock.MockFixture,
):
"""Test that requeue fails if execution is not in QUEUED status"""
# Mock an execution that's RUNNING (not QUEUED)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
return_value=[], # No QUEUED executions found
)
response = client.post(
"/admin/diagnostics/executions/requeue",
json={"execution_id": "exec-running-123"},
)
assert response.status_code == 404
assert "not found or not in QUEUED status" in response.json()["detail"]
def test_list_invalid_executions_no_bulk_actions(
mocker: pytest_mock.MockFixture,
):
"""Verify invalid executions endpoint is read-only (no bulk actions)"""
# This is a documentation test - the endpoint exists but should not
# have corresponding cleanup/stop/requeue endpoints
# These endpoints should NOT exist for invalid states:
invalid_bulk_endpoints = [
"/admin/diagnostics/executions/cleanup-invalid",
"/admin/diagnostics/executions/stop-invalid",
"/admin/diagnostics/executions/requeue-invalid",
]
for endpoint in invalid_bulk_endpoints:
response = client.post(endpoint, json={"execution_ids": ["test"]})
assert response.status_code == 404, f"{endpoint} should not exist (read-only)"
def test_execution_ids_filter_efficiency(
mocker: pytest_mock.MockFixture,
):
"""Test that bulk operations use efficient execution_ids filter"""
mock_exec_metas = [
GraphExecutionMeta(
id=f"exec-{i}",
user_id=f"user-{i}",
graph_id="graph-123",
graph_version=1,
inputs=None,
credential_inputs=None,
nodes_input_masks=None,
preset_id=None,
status=AgentExecutionStatus.QUEUED,
started_at=datetime.now(timezone.utc),
ended_at=datetime.now(timezone.utc),
stats=None,
)
for i in range(3)
]
mock_get_graph_executions = mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
return_value=mock_exec_metas,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.add_graph_execution",
return_value=AsyncMock(),
)
response = client.post(
"/admin/diagnostics/executions/requeue-bulk",
json={"execution_ids": ["exec-0", "exec-1", "exec-2"]},
)
assert response.status_code == 200
# Verify it used execution_ids filter (not fetching all queued)
mock_get_graph_executions.assert_called_once()
call_kwargs = mock_get_graph_executions.call_args.kwargs
assert "execution_ids" in call_kwargs
assert call_kwargs["execution_ids"] == ["exec-0", "exec-1", "exec-2"]
assert call_kwargs["statuses"] == [AgentExecutionStatus.QUEUED]
# ---------------------------------------------------------------------------
# Helper: reusable mock diagnostics summary
# ---------------------------------------------------------------------------
def _make_mock_diagnostics(**overrides) -> ExecutionDiagnosticsSummary:
defaults = dict(
running_count=10,
queued_db_count=5,
rabbitmq_queue_depth=3,
cancel_queue_depth=0,
orphaned_running=2,
orphaned_queued=1,
failed_count_1h=5,
failed_count_24h=20,
failure_rate_24h=0.83,
stuck_running_24h=3,
stuck_running_1h=5,
oldest_running_hours=26.5,
stuck_queued_1h=2,
queued_never_started=1,
invalid_queued_with_start=1,
invalid_running_without_start=1,
completed_1h=50,
completed_24h=1200,
throughput_per_hour=50.0,
timestamp=datetime.now(timezone.utc).isoformat(),
)
defaults.update(overrides)
return ExecutionDiagnosticsSummary(**defaults)
_SENTINEL = object()
def _make_mock_execution(
exec_id: str = "exec-1",
status: str = "RUNNING",
started_at: datetime | None | object = _SENTINEL,
) -> RunningExecutionDetail:
return RunningExecutionDetail(
execution_id=exec_id,
graph_id="graph-123",
graph_name="Test Graph",
graph_version=1,
user_id="user-123",
user_email="test@example.com",
status=status,
created_at=datetime.now(timezone.utc),
started_at=(
datetime.now(timezone.utc) if started_at is _SENTINEL else started_at
),
queue_status=None,
)
def _make_mock_failed_execution(
exec_id: str = "exec-fail-1",
) -> FailedExecutionDetail:
return FailedExecutionDetail(
execution_id=exec_id,
graph_id="graph-123",
graph_name="Test Graph",
graph_version=1,
user_id="user-123",
user_email="test@example.com",
status="FAILED",
created_at=datetime.now(timezone.utc),
started_at=datetime.now(timezone.utc),
failed_at=datetime.now(timezone.utc),
error_message="Something went wrong",
)
def _make_mock_schedule_health(**overrides) -> ScheduleHealthMetrics:
defaults = dict(
total_schedules=15,
user_schedules=10,
system_schedules=5,
orphaned_deleted_graph=2,
orphaned_no_library_access=1,
orphaned_invalid_credentials=0,
orphaned_validation_failed=0,
total_orphaned=3,
schedules_next_hour=4,
schedules_next_24h=8,
total_runs_next_hour=12,
total_runs_next_24h=48,
timestamp=datetime.now(timezone.utc).isoformat(),
)
defaults.update(overrides)
return ScheduleHealthMetrics(**defaults)
# ---------------------------------------------------------------------------
# GET endpoints: execution list variants
# ---------------------------------------------------------------------------
def test_list_running_executions(mocker: pytest_mock.MockFixture):
mock_execs = [
_make_mock_execution("exec-run-1"),
_make_mock_execution("exec-run-2"),
]
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_running_executions_details",
return_value=mock_execs,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
return_value=_make_mock_diagnostics(),
)
response = client.get("/admin/diagnostics/executions/running?limit=50&offset=0")
assert response.status_code == 200
data = response.json()
assert data["total"] == 15 # running_count(10) + queued_db_count(5)
assert len(data["executions"]) == 2
assert data["executions"][0]["execution_id"] == "exec-run-1"
def test_list_orphaned_executions(mocker: pytest_mock.MockFixture):
mock_execs = [_make_mock_execution("exec-orphan-1", status="RUNNING")]
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_orphaned_executions_details",
return_value=mock_execs,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
return_value=_make_mock_diagnostics(),
)
response = client.get("/admin/diagnostics/executions/orphaned?limit=50&offset=0")
assert response.status_code == 200
data = response.json()
assert data["total"] == 3 # orphaned_running(2) + orphaned_queued(1)
assert len(data["executions"]) == 1
def test_list_failed_executions(mocker: pytest_mock.MockFixture):
mock_execs = [_make_mock_failed_execution("exec-fail-1")]
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_failed_executions_details",
return_value=mock_execs,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_failed_executions_count",
return_value=42,
)
response = client.get(
"/admin/diagnostics/executions/failed?limit=50&offset=0&hours=24"
)
assert response.status_code == 200
data = response.json()
assert data["total"] == 42
assert len(data["executions"]) == 1
assert data["executions"][0]["error_message"] == "Something went wrong"
def test_list_long_running_executions(mocker: pytest_mock.MockFixture):
mock_execs = [_make_mock_execution("exec-long-1")]
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_long_running_executions_details",
return_value=mock_execs,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
return_value=_make_mock_diagnostics(),
)
response = client.get(
"/admin/diagnostics/executions/long-running?limit=50&offset=0"
)
assert response.status_code == 200
data = response.json()
assert data["total"] == 3 # stuck_running_24h
assert len(data["executions"]) == 1
def test_list_stuck_queued_executions(mocker: pytest_mock.MockFixture):
mock_execs = [
_make_mock_execution("exec-stuck-1", status="QUEUED", started_at=None)
]
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_stuck_queued_executions_details",
return_value=mock_execs,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics",
return_value=_make_mock_diagnostics(),
)
response = client.get(
"/admin/diagnostics/executions/stuck-queued?limit=50&offset=0"
)
assert response.status_code == 200
data = response.json()
assert data["total"] == 2 # stuck_queued_1h
assert len(data["executions"]) == 1
# ---------------------------------------------------------------------------
# GET endpoints: agent + schedule diagnostics
# ---------------------------------------------------------------------------
def test_get_agent_diagnostics(mocker: pytest_mock.MockFixture):
mock_diag = AgentDiagnosticsSummary(
agents_with_active_executions=7,
timestamp=datetime.now(timezone.utc).isoformat(),
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_agent_diagnostics",
return_value=mock_diag,
)
response = client.get("/admin/diagnostics/agents")
assert response.status_code == 200
data = response.json()
assert data["agents_with_active_executions"] == 7
def test_get_schedule_diagnostics(mocker: pytest_mock.MockFixture):
mock_metrics = _make_mock_schedule_health()
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_schedule_health_metrics",
return_value=mock_metrics,
)
response = client.get("/admin/diagnostics/schedules")
assert response.status_code == 200
data = response.json()
assert data["user_schedules"] == 10
assert data["total_orphaned"] == 3
assert data["total_runs_next_hour"] == 12
def test_list_all_schedules(mocker: pytest_mock.MockFixture):
mock_schedules = [
ScheduleDetail(
schedule_id="sched-1",
schedule_name="Daily Run",
graph_id="graph-1",
graph_name="My Agent",
graph_version=1,
user_id="user-1",
user_email="alice@example.com",
cron="0 9 * * *",
timezone="UTC",
next_run_time=datetime.now(timezone.utc).isoformat(),
),
]
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_all_schedules_details",
return_value=mock_schedules,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_schedule_health_metrics",
return_value=_make_mock_schedule_health(),
)
response = client.get("/admin/diagnostics/schedules/all?limit=50&offset=0")
assert response.status_code == 200
data = response.json()
assert data["total"] == 10
assert len(data["schedules"]) == 1
assert data["schedules"][0]["schedule_name"] == "Daily Run"
def test_list_orphaned_schedules(mocker: pytest_mock.MockFixture):
mock_orphans = [
OrphanedScheduleDetail(
schedule_id="sched-orphan-1",
schedule_name="Ghost Schedule",
graph_id="graph-deleted",
graph_version=1,
user_id="user-1",
orphan_reason="deleted_graph",
error_detail=None,
next_run_time=datetime.now(timezone.utc).isoformat(),
),
]
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_orphaned_schedules_details",
return_value=mock_orphans,
)
response = client.get("/admin/diagnostics/schedules/orphaned")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["schedules"][0]["orphan_reason"] == "deleted_graph"
# ---------------------------------------------------------------------------
# POST endpoints: bulk stop, cleanup, requeue
# ---------------------------------------------------------------------------
def test_stop_multiple_executions(mocker: pytest_mock.MockFixture):
mock_exec_metas = [
GraphExecutionMeta(
id=f"exec-{i}",
user_id=f"user-{i}",
graph_id="graph-123",
graph_version=1,
inputs=None,
credential_inputs=None,
nodes_input_masks=None,
preset_id=None,
status=AgentExecutionStatus.RUNNING,
started_at=datetime.now(timezone.utc),
ended_at=None,
stats=None,
)
for i in range(2)
]
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
return_value=mock_exec_metas,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.stop_graph_execution",
return_value=AsyncMock(),
)
response = client.post(
"/admin/diagnostics/executions/stop-bulk",
json={"execution_ids": ["exec-0", "exec-1"]},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["stopped_count"] == 2
def test_stop_multiple_executions_none_found(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
return_value=[],
)
response = client.post(
"/admin/diagnostics/executions/stop-bulk",
json={"execution_ids": ["nonexistent"]},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is False
assert data["stopped_count"] == 0
def test_cleanup_orphaned_executions(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_executions_bulk",
return_value=3,
)
response = client.post(
"/admin/diagnostics/executions/cleanup-orphaned",
json={"execution_ids": ["exec-1", "exec-2", "exec-3"]},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["stopped_count"] == 3
def test_cleanup_orphaned_schedules(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_schedules_bulk",
return_value=2,
)
response = client.post(
"/admin/diagnostics/schedules/cleanup-orphaned",
json={"schedule_ids": ["sched-1", "sched-2"]},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["deleted_count"] == 2
def test_stop_all_long_running_executions(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.stop_all_long_running_executions",
return_value=5,
)
response = client.post("/admin/diagnostics/executions/stop-all-long-running")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["stopped_count"] == 5
def test_cleanup_all_orphaned_executions(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_all_orphaned_execution_ids",
return_value=["exec-1", "exec-2"],
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_executions_bulk",
return_value=2,
)
response = client.post("/admin/diagnostics/executions/cleanup-all-orphaned")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["stopped_count"] == 2
def test_cleanup_all_orphaned_executions_none(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_all_orphaned_execution_ids",
return_value=[],
)
response = client.post("/admin/diagnostics/executions/cleanup-all-orphaned")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["stopped_count"] == 0
assert "No orphaned" in data["message"]
def test_cleanup_all_stuck_queued_executions(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.cleanup_all_stuck_queued_executions",
return_value=4,
)
response = client.post("/admin/diagnostics/executions/cleanup-all-stuck-queued")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["stopped_count"] == 4
def test_requeue_all_stuck_executions(mocker: pytest_mock.MockFixture):
mock_exec_metas = [
GraphExecutionMeta(
id=f"exec-stuck-{i}",
user_id=f"user-{i}",
graph_id="graph-123",
graph_version=1,
inputs=None,
credential_inputs=None,
nodes_input_masks=None,
preset_id=None,
status=AgentExecutionStatus.QUEUED,
started_at=None,
ended_at=None,
stats=None,
)
for i in range(3)
]
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_all_stuck_queued_execution_ids",
return_value=["exec-stuck-0", "exec-stuck-1", "exec-stuck-2"],
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
return_value=mock_exec_metas,
)
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.add_graph_execution",
return_value=AsyncMock(),
)
response = client.post("/admin/diagnostics/executions/requeue-all-stuck")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["requeued_count"] == 3
def test_requeue_all_stuck_executions_none(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_all_stuck_queued_execution_ids",
return_value=[],
)
response = client.post("/admin/diagnostics/executions/requeue-all-stuck")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["requeued_count"] == 0
assert "No stuck" in data["message"]
def test_requeue_bulk_none_found(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
return_value=[],
)
response = client.post(
"/admin/diagnostics/executions/requeue-bulk",
json={"execution_ids": ["nonexistent"]},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is False
assert data["requeued_count"] == 0
def test_stop_single_execution_not_found(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.admin.diagnostics_admin_routes.get_graph_executions",
return_value=[],
)
response = client.post(
"/admin/diagnostics/executions/stop",
json={"execution_id": "nonexistent"},
)
assert response.status_code == 404
assert "not found" in response.json()["detail"]

View File

@@ -14,3 +14,70 @@ class UserHistoryResponse(BaseModel):
class AddUserCreditsResponse(BaseModel):
new_balance: int
transaction_key: str
class ExecutionDiagnosticsResponse(BaseModel):
"""Response model for execution diagnostics"""
# Current execution state
running_executions: int
queued_executions_db: int
queued_executions_rabbitmq: int
cancel_queue_depth: int
# Orphaned execution detection
orphaned_running: int
orphaned_queued: int
# Failure metrics
failed_count_1h: int
failed_count_24h: int
failure_rate_24h: float
# Long-running detection
stuck_running_24h: int
stuck_running_1h: int
oldest_running_hours: float | None
# Stuck queued detection
stuck_queued_1h: int
queued_never_started: int
# Invalid state detection (data corruption - no auto-actions)
invalid_queued_with_start: int
invalid_running_without_start: int
# Throughput metrics
completed_1h: int
completed_24h: int
throughput_per_hour: float
timestamp: str
class AgentDiagnosticsResponse(BaseModel):
"""Response model for agent diagnostics"""
agents_with_active_executions: int
timestamp: str
class ScheduleHealthMetrics(BaseModel):
"""Response model for schedule diagnostics"""
total_schedules: int
user_schedules: int
system_schedules: int
# Orphan detection
orphaned_deleted_graph: int
orphaned_no_library_access: int
orphaned_invalid_credentials: int
orphaned_validation_failed: int
total_orphaned: int
# Upcoming
schedules_next_hour: int
schedules_next_24h: int
timestamp: str

View File

@@ -32,10 +32,10 @@ router = APIRouter(
class UserRateLimitResponse(BaseModel):
user_id: str
user_email: Optional[str] = None
daily_token_limit: int
weekly_token_limit: int
daily_tokens_used: int
weekly_tokens_used: int
daily_cost_limit_microdollars: int
weekly_cost_limit_microdollars: int
daily_cost_used_microdollars: int
weekly_cost_used_microdollars: int
tier: SubscriptionTier
@@ -101,17 +101,19 @@ async def get_user_rate_limit(
logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)
daily_limit, weekly_limit, tier = await get_global_rate_limits(
resolved_id, config.daily_token_limit, config.weekly_token_limit
resolved_id,
config.daily_cost_limit_microdollars,
config.weekly_cost_limit_microdollars,
)
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier)
return UserRateLimitResponse(
user_id=resolved_id,
user_email=resolved_email,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_tokens_used=usage.daily.used,
weekly_tokens_used=usage.weekly.used,
daily_cost_limit_microdollars=daily_limit,
weekly_cost_limit_microdollars=weekly_limit,
daily_cost_used_microdollars=usage.daily.used,
weekly_cost_used_microdollars=usage.weekly.used,
tier=tier,
)
@@ -141,7 +143,9 @@ async def reset_user_rate_limit(
raise HTTPException(status_code=500, detail="Failed to reset usage") from e
daily_limit, weekly_limit, tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
user_id,
config.daily_cost_limit_microdollars,
config.weekly_cost_limit_microdollars,
)
usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier)
@@ -154,10 +158,10 @@ async def reset_user_rate_limit(
return UserRateLimitResponse(
user_id=user_id,
user_email=resolved_email,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_tokens_used=usage.daily.used,
weekly_tokens_used=usage.weekly.used,
daily_cost_limit_microdollars=daily_limit,
weekly_cost_limit_microdollars=weekly_limit,
daily_cost_used_microdollars=usage.daily.used,
weekly_cost_used_microdollars=usage.weekly.used,
tier=tier,
)

View File

@@ -57,7 +57,7 @@ def _patch_rate_limit_deps(
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
return_value=(2_500_000, 12_500_000, SubscriptionTier.BASIC),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
@@ -85,11 +85,11 @@ def test_get_rate_limit(
data = response.json()
assert data["user_id"] == target_user_id
assert data["user_email"] == _TARGET_EMAIL
assert data["daily_token_limit"] == 2_500_000
assert data["weekly_token_limit"] == 12_500_000
assert data["daily_tokens_used"] == 500_000
assert data["weekly_tokens_used"] == 3_000_000
assert data["tier"] == "FREE"
assert data["daily_cost_limit_microdollars"] == 2_500_000
assert data["weekly_cost_limit_microdollars"] == 12_500_000
assert data["daily_cost_used_microdollars"] == 500_000
assert data["weekly_cost_used_microdollars"] == 3_000_000
assert data["tier"] == "BASIC"
configured_snapshot.assert_match(
json.dumps(data, indent=2, sort_keys=True) + "\n",
@@ -117,7 +117,7 @@ def test_get_rate_limit_by_email(
data = response.json()
assert data["user_id"] == target_user_id
assert data["user_email"] == _TARGET_EMAIL
assert data["daily_token_limit"] == 2_500_000
assert data["daily_cost_limit_microdollars"] == 2_500_000
def test_get_rate_limit_by_email_not_found(
@@ -160,10 +160,10 @@ def test_reset_user_usage_daily_only(
assert response.status_code == 200
data = response.json()
assert data["daily_tokens_used"] == 0
assert data["daily_cost_used_microdollars"] == 0
# Weekly is untouched
assert data["weekly_tokens_used"] == 3_000_000
assert data["tier"] == "FREE"
assert data["weekly_cost_used_microdollars"] == 3_000_000
assert data["tier"] == "BASIC"
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)
@@ -192,9 +192,9 @@ def test_reset_user_usage_daily_and_weekly(
assert response.status_code == 200
data = response.json()
assert data["daily_tokens_used"] == 0
assert data["weekly_tokens_used"] == 0
assert data["tier"] == "FREE"
assert data["daily_cost_used_microdollars"] == 0
assert data["weekly_cost_used_microdollars"] == 0
assert data["tier"] == "BASIC"
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)
@@ -231,7 +231,7 @@ def test_get_rate_limit_email_lookup_failure(
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
return_value=(2_500_000, 12_500_000, SubscriptionTier.BASIC),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
@@ -324,7 +324,7 @@ def test_set_user_tier(
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.FREE,
return_value=SubscriptionTier.BASIC,
)
mock_set = mocker.patch(
f"{_MOCK_MODULE}.set_user_tier",
@@ -347,7 +347,7 @@ def test_set_user_tier_downgrade(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test downgrading a user's tier from PRO to FREE."""
"""Test downgrading a user's tier from PRO to BASIC."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
@@ -365,14 +365,14 @@ def test_set_user_tier_downgrade(
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "FREE"},
json={"user_id": target_user_id, "tier": "BASIC"},
)
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["tier"] == "FREE"
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.FREE)
assert data["tier"] == "BASIC"
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.BASIC)
def test_set_user_tier_invalid_tier(
@@ -456,7 +456,7 @@ def test_set_user_tier_db_failure(
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.FREE,
return_value=SubscriptionTier.BASIC,
)
mocker.patch(
f"{_MOCK_MODULE}.set_user_tier",

View File

@@ -2,19 +2,18 @@
import asyncio
import logging
import re
from collections.abc import AsyncGenerator
from typing import Annotated
from uuid import uuid4
from autogpt_libs import auth
from fastapi import APIRouter, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ConfigDict, Field, field_validator
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.builder_context import resolve_session_permissions
from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode
from backend.copilot.db import get_chat_messages_paginated
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
@@ -26,11 +25,18 @@ from backend.copilot.model import (
create_chat_session,
delete_chat_session,
get_chat_session,
get_or_create_builder_session,
get_user_sessions,
update_session_title,
)
from backend.copilot.pending_message_helpers import (
QueuePendingMessageResponse,
is_turn_in_flight,
queue_pending_for_http,
)
from backend.copilot.pending_messages import peek_pending_messages
from backend.copilot.rate_limit import (
CoPilotUsageStatus,
CoPilotUsagePublic,
RateLimitExceeded,
acquire_reset_lock,
check_rate_limit,
@@ -69,13 +75,14 @@ from backend.copilot.tools.models import (
NoResultsResponse,
SetupRequirementsResponse,
SuggestedGoalResponse,
TodoWriteResponse,
UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.redis_client import get_redis_async
from backend.data.understanding import get_business_understanding
from backend.data.workspace import get_or_create_workspace
from backend.data.workspace import build_files_block, resolve_workspace_files
from backend.util.exceptions import InsufficientBalanceError, NotFoundError
from backend.util.settings import Settings
@@ -85,10 +92,6 @@ logger = logging.getLogger(__name__)
config = ChatConfig()
_UUID_RE = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
)
async def _validate_and_get_session(
session_id: str,
@@ -133,7 +136,7 @@ def _strip_injected_context(message: dict) -> dict:
class StreamChatRequest(BaseModel):
"""Request model for streaming chat with optional context."""
message: str
message: str = Field(max_length=64_000)
is_user_message: bool = True
context: dict[str, str] | None = None # {url: str, content: str}
file_ids: list[str] | None = Field(
@@ -151,16 +154,45 @@ class StreamChatRequest(BaseModel):
)
class CreateSessionRequest(BaseModel):
"""Request model for creating a new chat session.
class PeekPendingMessagesResponse(BaseModel):
"""Response for the pending-message peek (GET) endpoint.
Returns a read-only view of the pending buffer — messages are NOT
consumed. The frontend uses this to restore the queued-message
indicator after a page refresh and to decide when to clear it once
a turn has ended.
"""
messages: list[str]
count: int
class CreateSessionRequest(BaseModel):
"""Request model for creating (or get-or-creating) a chat session.
Two modes, selected by the body:
- Default: create a fresh session. ``dry_run`` is a **top-level**
field — do not nest it inside ``metadata``.
- Builder-bound: when ``builder_graph_id`` is set, the endpoint
switches to **get-or-create** keyed on
``(user_id, builder_graph_id)``. The builder panel calls this on
mount so the chat persists across refreshes. Graph ownership is
validated inside :func:`get_or_create_builder_session`. Write-side
scope is enforced per-tool (``edit_agent`` / ``run_agent`` reject
any ``agent_id`` other than the bound graph) and a small blacklist
hides tools that conflict with the panel's scope
(``create_agent`` / ``customize_agent`` / ``get_agent_building_guide``
— see :data:`BUILDER_BLOCKED_TOOLS`). Read-side lookups
(``find_block``, ``find_agent``, ``search_docs``, …) stay open.
``dry_run`` is a **top-level** field — do not nest it inside ``metadata``.
Extra/unknown fields are rejected (422) to prevent silent mis-use.
"""
model_config = ConfigDict(extra="forbid")
dry_run: bool = False
builder_graph_id: str | None = Field(default=None, max_length=128)
class CreateSessionResponse(BaseModel):
@@ -305,29 +337,43 @@ async def create_session(
user_id: Annotated[str, Security(auth.get_user_id)],
request: CreateSessionRequest | None = None,
) -> CreateSessionResponse:
"""
Create a new chat session.
"""Create (or get-or-create) a chat session.
Initiates a new chat session for the authenticated user.
Two modes, selected by the request body:
- Default: create a fresh session for the user. ``dry_run=True`` forces
run_block and run_agent calls to use dry-run simulation.
- Builder-bound: when ``builder_graph_id`` is set, get-or-create keyed
on ``(user_id, builder_graph_id)``. Returns the existing session for
that graph or creates one locked to it. Graph ownership is validated
inside :func:`get_or_create_builder_session`; raises 404 on
unauthorized access. Write-side scope is enforced per-tool
(``edit_agent`` / ``run_agent`` reject any ``agent_id`` other than
the bound graph) and a small blacklist hides tools that conflict
with the panel's scope (see :data:`BUILDER_BLOCKED_TOOLS`).
Args:
user_id: The authenticated user ID parsed from the JWT (required).
request: Optional request body. When provided, ``dry_run=True``
forces run_block and run_agent calls to use dry-run simulation.
request: Optional request body with ``dry_run`` and/or
``builder_graph_id``.
Returns:
CreateSessionResponse: Details of the created session.
CreateSessionResponse: Details of the resulting session.
"""
dry_run = request.dry_run if request else False
builder_graph_id = request.builder_graph_id if request else None
logger.info(
f"Creating session with user_id: "
f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
f"{', dry_run=True' if dry_run else ''}"
f"{f', builder_graph_id={builder_graph_id}' if builder_graph_id else ''}"
)
session = await create_chat_session(user_id, dry_run=dry_run)
if builder_graph_id:
session = await get_or_create_builder_session(user_id, builder_graph_id)
else:
session = await create_chat_session(user_id, dry_run=dry_run)
return CreateSessionResponse(
id=session.session_id,
@@ -523,23 +569,27 @@ async def get_session(
)
async def get_copilot_usage(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> CoPilotUsageStatus:
) -> CoPilotUsagePublic:
"""Get CoPilot usage status for the authenticated user.
Returns current token usage vs limits for daily and weekly windows.
Global defaults sourced from LaunchDarkly (falling back to config).
Includes the user's rate-limit tier.
Returns the percentage of the daily/weekly allowance used — not the
raw spend or cap — so clients cannot derive per-turn cost or platform
margins. Global defaults sourced from LaunchDarkly (falling back to
config). Includes the user's rate-limit tier.
"""
daily_limit, weekly_limit, tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
user_id,
config.daily_cost_limit_microdollars,
config.weekly_cost_limit_microdollars,
)
return await get_usage_status(
status = await get_usage_status(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_cost_limit=daily_limit,
weekly_cost_limit=weekly_limit,
rate_limit_reset_cost=config.rate_limit_reset_cost,
tier=tier,
)
return CoPilotUsagePublic.from_status(status)
class RateLimitResetResponse(BaseModel):
@@ -548,7 +598,9 @@ class RateLimitResetResponse(BaseModel):
success: bool
credits_charged: int = Field(description="Credits charged (in cents)")
remaining_balance: int = Field(description="Credit balance after charge (in cents)")
usage: CoPilotUsageStatus = Field(description="Updated usage status after reset")
usage: CoPilotUsagePublic = Field(
description="Updated usage status after reset (percentages only)"
)
@router.post(
@@ -572,7 +624,7 @@ async def reset_copilot_usage(
) -> RateLimitResetResponse:
"""Reset the daily CoPilot rate limit by spending credits.
Allows users who have hit their daily token limit to spend credits
Allows users who have hit their daily cost limit to spend credits
to reset their daily usage counter and continue working.
Returns 400 if the feature is disabled or the user is not over the limit.
Returns 402 if the user has insufficient credits.
@@ -591,7 +643,9 @@ async def reset_copilot_usage(
)
daily_limit, weekly_limit, tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
user_id,
config.daily_cost_limit_microdollars,
config.weekly_cost_limit_microdollars,
)
if daily_limit <= 0:
@@ -628,8 +682,8 @@ async def reset_copilot_usage(
# used for limit checks, not returned to the client.)
usage_status = await get_usage_status(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_cost_limit=daily_limit,
weekly_cost_limit=weekly_limit,
tier=tier,
)
if daily_limit > 0 and usage_status.daily.used < daily_limit:
@@ -664,7 +718,7 @@ async def reset_copilot_usage(
# Reset daily usage in Redis. If this fails, refund the credits
# so the user is not charged for a service they did not receive.
if not await reset_daily_usage(user_id, daily_token_limit=daily_limit):
if not await reset_daily_usage(user_id, daily_cost_limit=daily_limit):
# Compensate: refund the charged credits.
refunded = False
try:
@@ -700,11 +754,11 @@ async def reset_copilot_usage(
finally:
await release_reset_lock(user_id)
# Return updated usage status.
# Return updated usage status (public schema — percentages only).
updated_usage = await get_usage_status(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_cost_limit=daily_limit,
weekly_cost_limit=weekly_limit,
rate_limit_reset_cost=config.rate_limit_reset_cost,
tier=tier,
)
@@ -713,7 +767,7 @@ async def reset_copilot_usage(
success=True,
credits_charged=cost,
remaining_balance=remaining,
usage=updated_usage,
usage=CoPilotUsagePublic.from_status(updated_usage),
)
@@ -764,36 +818,52 @@ async def cancel_session_task(
@router.post(
"/sessions/{session_id}/stream",
responses={
202: {
"model": QueuePendingMessageResponse,
"description": (
"Session has a turn in flight — message queued into the pending "
"buffer and will be picked up between tool-call rounds by the "
"executor currently processing the turn."
),
},
404: {"description": "Session not found or access denied"},
429: {"description": "Cost rate-limit or call-frequency cap exceeded"},
},
)
async def stream_chat_post(
session_id: str,
request: StreamChatRequest,
user_id: str = Security(auth.get_user_id),
):
"""
Stream chat responses for a session (POST with context support).
"""Start a new turn OR queue a follow-up — decided server-side.
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
- Text fragments as they are generated
- Tool call UI elements (if invoked)
- Tool execution results
- **Session idle**: starts a turn. Returns an SSE stream (``text/event-stream``)
with Vercel AI SDK chunks (text fragments, tool-call UI, tool results).
The generation runs in a background task that survives client disconnects;
reconnect via ``GET /sessions/{session_id}/stream`` to resume.
The AI generation runs in a background task that continues even if the client disconnects.
All chunks are written to a per-turn Redis stream for reconnection support. If the client
disconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.
- **Session has a turn in flight**: pushes the message into the per-session
pending buffer and returns ``202 application/json`` with
``QueuePendingMessageResponse``. The executor running the current turn
drains the buffer between tool-call rounds (baseline) or at the start of
the next turn (SDK). Clients should detect the 202 and surface the
message as a queued-chip in the UI.
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
session_id: The chat session identifier.
request: Request body with message, is_user_message, and optional context.
user_id: Authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks.
"""
import asyncio
import time
stream_start_time = time.perf_counter()
# Wall-clock arrival time, propagated to the executor so the turn-start
# drain can order pending messages relative to this request (pending
# pushed BEFORE this instant were typed earlier; pending pushed AFTER
# are race-path follow-ups typed while /stream was still processing).
request_arrival_at = time.time()
log_meta = {"component": "ChatStream", "session_id": session_id, "user_id": user_id}
logger.info(
@@ -801,7 +871,28 @@ async def stream_chat_post(
f"user={user_id}, message_len={len(request.message)}",
extra={"json_fields": log_meta},
)
await _validate_and_get_session(session_id, user_id)
session = await _validate_and_get_session(session_id, user_id)
builder_permissions = resolve_session_permissions(session)
# Self-defensive queue-fallback: if a turn is already running, don't race
# it on the cluster lock — drop the message into the pending buffer and
# return 202 so the caller can render a chip. Both UI chips and autopilot
# block follow-ups route through this path; keeping the decision on the
# server means every caller gets uniform behaviour.
if (
request.is_user_message
and request.message
and await is_turn_in_flight(session_id)
):
response = await queue_pending_for_http(
session_id=session_id,
user_id=user_id,
message=request.message,
context=request.context,
file_ids=request.file_ids,
)
return JSONResponse(status_code=202, content=response.model_dump())
logger.info(
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
extra={
@@ -812,18 +903,20 @@ async def stream_chat_post(
},
)
# Pre-turn rate limit check (token-based).
# Pre-turn rate limit check (cost-based, microdollars).
# check_rate_limit short-circuits internally when both limits are 0.
# Global defaults sourced from LaunchDarkly, falling back to config.
if user_id:
try:
daily_limit, weekly_limit, _ = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
user_id,
config.daily_cost_limit_microdollars,
config.weekly_cost_limit_microdollars,
)
await check_rate_limit(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
daily_cost_limit=daily_limit,
weekly_cost_limit=weekly_limit,
)
except RateLimitExceeded as e:
raise HTTPException(status_code=429, detail=str(e)) from e
@@ -832,33 +925,10 @@ async def stream_chat_post(
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
sanitized_file_ids: list[str] | None = None
if request.file_ids and user_id:
# Filter to valid UUIDs only to prevent DB abuse
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
if valid_ids:
workspace = await get_or_create_workspace(user_id)
# Batch query instead of N+1
files = await UserWorkspaceFile.prisma().find_many(
where={
"id": {"in": valid_ids},
"workspaceId": workspace.id,
"isDeleted": False,
}
)
# Only keep IDs that actually exist in the user's workspace
sanitized_file_ids = [wf.id for wf in files] or None
file_lines: list[str] = [
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
for wf in files
]
if file_lines:
files_block = (
"\n\n[Attached files]\n"
+ "\n".join(file_lines)
+ "\nUse read_workspace_file with the file_id to access file contents."
)
request.message += files_block
if request.file_ids:
files = await resolve_workspace_files(user_id, request.file_ids)
sanitized_file_ids = [wf.id for wf in files] or None
request.message += build_files_block(files)
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
@@ -917,6 +987,8 @@ async def stream_chat_post(
file_ids=sanitized_file_ids,
mode=request.mode,
model=request.model,
permissions=builder_permissions,
request_arrival_at=request_arrival_at,
)
else:
logger.info(
@@ -1067,6 +1139,31 @@ async def stream_chat_post(
)
@router.get(
"/sessions/{session_id}/messages/pending",
response_model=PeekPendingMessagesResponse,
responses={
404: {"description": "Session not found or access denied"},
},
)
async def get_pending_messages(
session_id: str,
user_id: str = Security(auth.get_user_id),
):
"""Peek at the pending-message buffer without consuming it.
Returns the current contents of the session's pending message buffer
so the frontend can restore the queued-message indicator after a page
refresh and clear it correctly once a turn drains the buffer.
"""
await _validate_and_get_session(session_id, user_id)
pending = await peek_pending_messages(session_id)
return PeekPendingMessagesResponse(
messages=[m.content for m in pending],
count=len(pending),
)
@router.get(
"/sessions/{session_id}/stream",
)
@@ -1323,6 +1420,7 @@ ToolResponseUnion = (
| MemorySearchResponse
| MemoryForgetCandidatesResponse
| MemoryForgetConfirmResponse
| TodoWriteResponse
)

View File

@@ -14,7 +14,7 @@ from fastapi import (
Security,
status,
)
from pydantic import BaseModel, Field, SecretStr, model_validator
from pydantic import BaseModel, Field, model_validator
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
from backend.api.features.library.db import set_preset_webhook, update_preset
@@ -29,15 +29,14 @@ from backend.data.integrations import (
wait_for_webhook_event,
)
from backend.data.model import (
APIKeyCredentials,
Credentials,
CredentialsType,
HostScopedCredentials,
OAuth2Credentials,
UserIntegrations,
is_sdk_default,
)
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
from backend.integrations.credentials_store import (
@@ -48,7 +47,14 @@ from backend.integrations.creds_manager import (
IntegrationCredentialsManager,
create_mcp_oauth_handler,
)
from backend.integrations.managed_credentials import ensure_managed_credentials
from backend.integrations.managed_credentials import (
ensure_managed_credential,
ensure_managed_credentials,
)
from backend.integrations.managed_providers.ayrshare import AyrshareManagedProvider
from backend.integrations.managed_providers.ayrshare import (
settings_available as ayrshare_settings_available,
)
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
@@ -87,14 +93,23 @@ async def login(
scopes: Annotated[
str, Query(title="Comma-separated list of authorization scopes")
] = "",
credential_id: Annotated[
str | None,
Query(title="ID of existing credential to upgrade scopes for"),
] = None,
) -> LoginResponse:
handler = _get_provider_oauth_handler(request, provider)
requested_scopes = scopes.split(",") if scopes else []
if credential_id:
requested_scopes = await _prepare_scope_upgrade(
user_id, provider, credential_id, requested_scopes
)
# Generate and store a secure random state token along with the scopes
state_token, code_challenge = await creds_manager.store.store_state_token(
user_id, provider, requested_scopes
user_id, provider, requested_scopes, credential_id=credential_id
)
login_url = handler.get_login_url(
requested_scopes, state_token, code_challenge=code_challenge
@@ -216,7 +231,9 @@ async def callback(
)
# TODO: Allow specifying `title` to set on `credentials`
await creds_manager.create(user_id, credentials)
credentials = await _merge_or_create_credential(
user_id, provider, credentials, valid_state.credential_id
)
logger.debug(
f"Successfully processed OAuth callback for user {user_id} "
@@ -226,13 +243,38 @@ async def callback(
return to_meta_response(credentials)
# Bound the first-time sweep so a slow upstream (e.g. Ayrshare) can't hang
# the credential-list endpoint. On timeout we still kick off a fire-and-
# forget sweep so provisioning eventually completes; the user just won't
# see the managed cred until the next refresh.
_MANAGED_PROVISION_TIMEOUT_S = 10.0
async def _ensure_managed_credentials_bounded(user_id: str) -> None:
try:
await asyncio.wait_for(
ensure_managed_credentials(user_id, creds_manager.store),
timeout=_MANAGED_PROVISION_TIMEOUT_S,
)
except asyncio.TimeoutError:
logger.warning(
"Managed credential sweep exceeded %.1fs for user=%s; "
"continuing without it — provisioning will complete in background",
_MANAGED_PROVISION_TIMEOUT_S,
user_id,
)
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
@router.get("/credentials", summary="List Credentials")
async def list_credentials(
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
# Fire-and-forget: provision missing managed credentials in the background.
# The credential appears on the next page load; listing is never blocked.
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
# Block on provisioning so managed credentials appear on the first load
# instead of after a refresh, but with a timeout so a slow upstream
# can't hang the endpoint. `_provisioned_users` short-circuits on
# repeat calls.
await _ensure_managed_credentials_bounded(user_id)
credentials = await creds_manager.store.get_all_creds(user_id)
return [
@@ -247,7 +289,7 @@ async def list_credentials_by_provider(
],
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
await _ensure_managed_credentials_bounded(user_id)
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
return [
@@ -281,6 +323,115 @@ async def get_credential(
return to_meta_response(credential)
class PickerTokenResponse(BaseModel):
"""Short-lived OAuth access token shipped to the browser for rendering a
provider-hosted picker UI (e.g. Google Drive Picker). Deliberately narrow:
only the fields the client needs to initialize the picker widget. Issued
from the user's own stored credential so ownership and scope gating are
enforced by the credential lookup."""
access_token: str = Field(
description="OAuth access token suitable for the picker SDK call."
)
access_token_expires_at: int | None = Field(
default=None,
description="Unix timestamp at which the access token expires, if known.",
)
# Allowlist of (provider, scopes) tuples that may mint picker tokens. Only
# Drive-picker-capable scopes qualify so a caller can't use this endpoint to
# extract a GitHub / other-provider OAuth token for unrelated purposes. If a
# future provider integrates a hosted picker that needs a raw access token,
# add its specific picker-relevant scopes here.
_PICKER_TOKEN_ALLOWED_SCOPES: dict[ProviderName, frozenset[str]] = {
ProviderName.GOOGLE: frozenset(
[
"https://www.googleapis.com/auth/drive.file",
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive",
]
),
}
@router.post(
"/{provider}/credentials/{cred_id}/picker-token",
summary="Issue a short-lived access token for a provider-hosted picker",
operation_id="postV1GetPickerToken",
)
async def get_picker_token(
provider: Annotated[
ProviderName, Path(title="The provider that owns the credentials")
],
cred_id: Annotated[
str, Path(title="The ID of the OAuth2 credentials to mint a token from")
],
user_id: Annotated[str, Security(get_user_id)],
) -> PickerTokenResponse:
"""Return the raw access token for an OAuth2 credential so the frontend
can initialize a provider-hosted picker (e.g. Google Drive Picker).
`GET /{provider}/credentials/{cred_id}` deliberately strips secrets (see
`CredentialsMetaResponse` + `TestGetCredentialReturnsMetaOnly` in
`router_test.py`). That hardening broke the Drive picker, which needs the
raw access token to call `google.picker.Builder.setOAuthToken(...)`. This
endpoint carves a narrow, explicit hole: the caller must own the
credential, it must be OAuth2, and the endpoint returns only the access
token + its expiry — nothing else about the credential. SDK-default
credentials are excluded for the same reason as `get_credential`.
"""
if is_sdk_default(cred_id):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
credential = await creds_manager.get(user_id, cred_id)
if not credential:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if not provider_matches(credential.provider, provider):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if not isinstance(credential, OAuth2Credentials):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Picker tokens are only available for OAuth2 credentials",
)
if not credential.access_token:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Credential has no access token; reconnect the account",
)
# Gate on provider+scope: only credentials that actually grant access to
# a provider-hosted picker flow may mint a token through this endpoint.
# Prevents using this path to extract bearer tokens for unrelated OAuth
# integrations (e.g. GitHub) that happen to be stored under the same user.
allowed_scopes = _PICKER_TOKEN_ALLOWED_SCOPES.get(provider)
if not allowed_scopes:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(f"Picker tokens are not available for provider '{provider.value}'"),
)
cred_scopes = set(credential.scopes or [])
if cred_scopes.isdisjoint(allowed_scopes):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(
"Credential does not grant any scope eligible for the picker. "
"Reconnect with the appropriate scope."
),
)
return PickerTokenResponse(
access_token=credential.access_token.get_secret_value(),
access_token_expires_at=credential.access_token_expires_at,
)
@router.post("/{provider}/credentials", status_code=201, summary="Create Credentials")
async def create_credentials(
user_id: Annotated[str, Security(get_user_id)],
@@ -574,6 +725,186 @@ async def _execute_webhook_preset_trigger(
# Continue processing - webhook should be resilient to individual failures
# -------------------- INCREMENTAL AUTH HELPERS -------------------- #
async def _prepare_scope_upgrade(
user_id: str,
provider: ProviderName,
credential_id: str,
requested_scopes: list[str],
) -> list[str]:
"""Validate an existing credential for scope upgrade and compute scopes.
For providers without native incremental auth (e.g. GitHub), returns the
union of existing + requested scopes. For providers that handle merging
server-side (e.g. Google with ``include_granted_scopes``), returns the
requested scopes unchanged.
Raises HTTPException on validation failure.
"""
# Platform-owned system credentials must never be upgraded — scope
# changes here would leak across every user that shares them.
if is_system_credential(credential_id):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="System credentials cannot be upgraded",
)
existing = await creds_manager.store.get_creds_by_id(user_id, credential_id)
if not existing:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Credential to upgrade not found",
)
if not isinstance(existing, OAuth2Credentials):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only OAuth2 credentials can be upgraded",
)
if not provider_matches(existing.provider, provider.value):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Credential provider does not match the requested provider",
)
if existing.is_managed:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Managed credentials cannot be upgraded",
)
# Google handles scope merging via include_granted_scopes; others need
# the union of existing + new scopes in the login URL.
if provider != ProviderName.GOOGLE:
requested_scopes = list(set(requested_scopes) | set(existing.scopes))
return requested_scopes
async def _merge_or_create_credential(
user_id: str,
provider: ProviderName,
credentials: OAuth2Credentials,
credential_id: str | None,
) -> OAuth2Credentials:
"""Either upgrade an existing credential or create a new one.
When *credential_id* is set (explicit upgrade), merges scopes and updates
the existing credential. Otherwise, checks for an implicit merge (same
provider + username) before falling back to creating a new credential.
"""
if credential_id:
return await _upgrade_existing_credential(user_id, credential_id, credentials)
# Implicit merge: check for existing credential with same provider+username.
# Skip managed/system credentials and require a non-None username on both
# sides so we never accidentally merge unrelated credentials.
if credentials.username is None:
await creds_manager.create(user_id, credentials)
return credentials
existing_creds = await creds_manager.store.get_creds_by_provider(user_id, provider)
matching = next(
(
c
for c in existing_creds
if isinstance(c, OAuth2Credentials)
and not c.is_managed
and not is_system_credential(c.id)
and c.username is not None
and c.username == credentials.username
),
None,
)
if matching:
# Only merge into the existing credential when the new token
# already covers every scope we're about to advertise on it.
# Without this guard we'd overwrite ``matching.access_token`` with
# a narrower token while storing a wider ``scopes`` list — the
# record would claim authorizations the token does not grant, and
# blocks using the lost scopes would fail with opaque 401/403s
# until the user hits re-auth. On a narrowing login, keep the
# two credentials separate instead.
if set(credentials.scopes).issuperset(set(matching.scopes)):
return await _upgrade_existing_credential(user_id, matching.id, credentials)
await creds_manager.create(user_id, credentials)
return credentials
async def _upgrade_existing_credential(
user_id: str,
existing_cred_id: str,
new_credentials: OAuth2Credentials,
) -> OAuth2Credentials:
"""Merge scopes from *new_credentials* into an existing credential."""
# Defense-in-depth: re-check system and provider invariants right before
# the write. The login-time check in `_prepare_scope_upgrade` can go stale
# by the time the callback runs, and the implicit-merge path bypasses
# login-time validation entirely, so every write-path must enforce these
# on its own.
if is_system_credential(existing_cred_id):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="System credentials cannot be upgraded",
)
existing = await creds_manager.store.get_creds_by_id(user_id, existing_cred_id)
if not existing or not isinstance(existing, OAuth2Credentials):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Credential to upgrade not found",
)
if existing.is_managed:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Managed credentials cannot be upgraded",
)
if not provider_matches(existing.provider, new_credentials.provider):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Credential provider does not match the requested provider",
)
if (
existing.username
and new_credentials.username
and existing.username != new_credentials.username
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Username mismatch: authenticated as a different user",
)
# Operate on a copy so the caller's ``new_credentials`` object is not
# mutated out from under them. Every caller today immediately discards
# or replaces its reference, but the implicit-merge path in
# ``_merge_or_create_credential`` reads ``credentials.scopes`` before
# calling into us — a future reader after the call would otherwise
# silently see the overwritten values.
merged = new_credentials.model_copy(deep=True)
merged.id = existing.id
merged.title = existing.title
merged.scopes = list(set(existing.scopes) | set(new_credentials.scopes))
merged.metadata = {
**(existing.metadata or {}),
**(new_credentials.metadata or {}),
}
# Preserve the existing refresh_token and username if the incremental
# response doesn't carry them. Providers like Google only return a
# refresh_token on first authorization — dropping it here would orphan
# the credential on the next access-token expiry, forcing the user to
# re-auth from scratch. Username is similarly sticky: if we've already
# resolved it for this credential, keep it rather than silently
# blanking it on an incremental upgrade.
if not merged.refresh_token and existing.refresh_token:
merged.refresh_token = existing.refresh_token
merged.refresh_token_expires_at = existing.refresh_token_expires_at
if not merged.username and existing.username:
merged.username = existing.username
await creds_manager.update(user_id, merged)
return merged
# --------------------------- UTILITIES ---------------------------- #
@@ -784,12 +1115,21 @@ def _get_provider_oauth_handler(
async def get_ayrshare_sso_url(
user_id: Annotated[str, Security(get_user_id)],
) -> AyrshareSSOResponse:
"""
Generate an SSO URL for Ayrshare social media integration.
"""Generate a JWT SSO URL so the user can link their social accounts.
Returns:
dict: Contains the SSO URL for Ayrshare integration
The per-user Ayrshare profile key is provisioned and persisted as a
standard ``is_managed=True`` credential by
:class:`~backend.integrations.managed_providers.ayrshare.AyrshareManagedProvider`.
This endpoint only signs a short-lived JWT pointing at the Ayrshare-
hosted social-linking page; all profile lifecycle logic lives with the
managed provider.
"""
if not ayrshare_settings_available():
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="Ayrshare integration is not configured",
)
try:
client = AyrshareClient()
except MissingConfigError:
@@ -798,66 +1138,63 @@ async def get_ayrshare_sso_url(
detail="Ayrshare integration is not configured",
)
# Ayrshare profile key is stored in the credentials store
# It is generated when creating a new profile, if there is no profile key,
# we create a new profile and store the profile key in the credentials store
user_integrations: UserIntegrations = await get_user_integrations(user_id)
profile_key = user_integrations.managed_credentials.ayrshare_profile_key
if not profile_key:
logger.debug(f"Creating new Ayrshare profile for user {user_id}")
try:
profile = await client.create_profile(
title=f"User {user_id}", messaging_active=True
)
profile_key = profile.profileKey
await creds_manager.store.set_ayrshare_profile_key(user_id, profile_key)
except Exception as e:
logger.error(f"Error creating Ayrshare profile for user {user_id}: {e}")
raise HTTPException(
status_code=HTTP_502_BAD_GATEWAY,
detail="Failed to create Ayrshare profile",
)
else:
logger.debug(f"Using existing Ayrshare profile for user {user_id}")
profile_key_str = (
profile_key.get_secret_value()
if isinstance(profile_key, SecretStr)
else str(profile_key)
# On-demand provisioning: AyrshareManagedProvider opts out of the
# credentials sweep (profile quota is per-user subscription-bound). This
# endpoint is the only trigger that provisions a profile — one Ayrshare
# profile per user who actually opens the connect flow, not one per
# every authenticated user.
provisioned = await ensure_managed_credential(
user_id, creds_manager.store, AyrshareManagedProvider()
)
if not provisioned:
raise HTTPException(
status_code=HTTP_502_BAD_GATEWAY,
detail="Failed to provision Ayrshare profile",
)
ayrshare_creds = [
c
for c in await creds_manager.store.get_creds_by_provider(user_id, "ayrshare")
if c.is_managed and isinstance(c, APIKeyCredentials)
]
if not ayrshare_creds:
logger.error(
"Ayrshare credential provisioning did not produce a credential "
"for user %s",
user_id,
)
raise HTTPException(
status_code=HTTP_502_BAD_GATEWAY,
detail="Failed to provision Ayrshare profile",
)
profile_key_str = ayrshare_creds[0].api_key.get_secret_value()
private_key = settings.secrets.ayrshare_jwt_key
# Ayrshare JWT expiry is 2880 minutes (48 hours)
# Ayrshare JWT max lifetime is 2880 minutes (48 h).
max_expiry_minutes = 2880
try:
logger.debug(f"Generating Ayrshare JWT for user {user_id}")
jwt_response = await client.generate_jwt(
private_key=private_key,
profile_key=profile_key_str,
# `allowed_social` is the set of networks the Ayrshare-hosted
# social-linking page will *offer* the user to connect. Blocks
# exist for more platforms than are listed here; the list is
# deliberately narrower so the rollout can verify each network
# end-to-end before widening the user-visible surface. Keep
# in sync with tested platforms — extend as each is verified
# against the block + Ayrshare's network-specific quirks.
allowed_social=[
# NOTE: We are enabling platforms one at a time
# to speed up the development process
# SocialPlatform.FACEBOOK,
SocialPlatform.TWITTER,
SocialPlatform.LINKEDIN,
SocialPlatform.INSTAGRAM,
SocialPlatform.YOUTUBE,
# SocialPlatform.REDDIT,
# SocialPlatform.TELEGRAM,
# SocialPlatform.GOOGLE_MY_BUSINESS,
# SocialPlatform.PINTEREST,
SocialPlatform.TIKTOK,
# SocialPlatform.BLUESKY,
# SocialPlatform.SNAPCHAT,
# SocialPlatform.THREADS,
],
expires_in=max_expiry_minutes,
verify=True,
)
except Exception as e:
logger.error(f"Error generating Ayrshare JWT for user {user_id}: {e}")
except Exception as exc:
logger.error("Error generating Ayrshare JWT for user %s: %s", user_id, exc)
raise HTTPException(
status_code=HTTP_502_BAD_GATEWAY, detail="Failed to generate JWT"
)

View File

@@ -393,7 +393,7 @@ class TestEnsureManagedCredentials:
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_awaited_once_with("user-1")
provider.provision.assert_awaited_once_with("user-1", store)
store.add_managed_credential.assert_awaited_once_with("user-1", cred)
@pytest.mark.asyncio
@@ -568,3 +568,181 @@ class TestCleanupManagedCredentials:
_PROVIDERS.update(saved)
# No exception raised — cleanup failure is swallowed.
class TestGetPickerToken:
"""POST /{provider}/credentials/{cred_id}/picker-token must:
1. Return the access token for OAuth2 creds the caller owns.
2. 404 for non-owned, non-existent, or wrong-provider creds.
3. 400 for non-OAuth2 creds (API key, host-scoped, user/password).
4. 404 for SDK default creds (same hardening as get_credential).
5. Preserve the `TestGetCredentialReturnsMetaOnly` contract — the
existing meta-only endpoint must still strip secrets even after
this picker-token endpoint exists."""
def test_oauth2_owner_gets_access_token(self):
# Use a Google cred with a drive.file scope — only picker-eligible
# (provider, scope) pairs can mint a token. GitHub-style creds are
# explicitly rejected; see `test_non_picker_provider_rejected_as_400`.
cred = _make_oauth2_cred(
cred_id="cred-gdrive",
provider="google",
)
cred.scopes = ["https://www.googleapis.com/auth/drive.file"]
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.post("/google/credentials/cred-gdrive/picker-token")
assert resp.status_code == 200
data = resp.json()
# The whole point of this endpoint: the access token IS returned here.
assert data["access_token"] == "ghp_secret_token"
# Only the two declared fields come back — nothing else leaks.
assert set(data.keys()) <= {"access_token", "access_token_expires_at"}
def test_non_picker_provider_rejected_as_400(self):
"""Provider allowlist: even with a valid OAuth2 credential, a
non-picker provider (GitHub, etc.) cannot mint a picker token.
Stops this endpoint from being used as a generic bearer-token
extraction path for any stored OAuth cred under the same user."""
cred = _make_oauth2_cred(provider="github")
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.post("/github/credentials/cred-456/picker-token")
assert resp.status_code == 400
assert "not available for provider" in resp.json()["detail"]
assert "ghp_secret_token" not in str(resp.json())
def test_google_oauth_without_drive_scope_rejected(self):
"""Scope allowlist: a Google OAuth2 cred that only carries non-picker
scopes (e.g. gmail.readonly, calendar) cannot mint a picker token.
Forces the frontend to reconnect with a Drive scope before the
picker is available."""
cred = _make_oauth2_cred(provider="google")
cred.scopes = [
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/calendar",
]
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.post("/google/credentials/cred-456/picker-token")
assert resp.status_code == 400
assert "picker" in resp.json()["detail"].lower()
def test_api_key_credential_rejected_as_400(self):
cred = _make_api_key_cred()
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.post("/openai/credentials/cred-123/picker-token")
assert resp.status_code == 400
# API keys must not silently fall through to a 200 response of some
# other shape — the client should see a clear shape rejection.
body = str(resp.json())
assert "sk-secret-key-value" not in body
def test_user_password_credential_rejected_as_400(self):
cred = _make_user_password_cred()
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.post("/openai/credentials/cred-789/picker-token")
assert resp.status_code == 400
body = str(resp.json())
assert "s3cret-pass" not in body
assert "admin" not in body
def test_host_scoped_credential_rejected_as_400(self):
cred = _make_host_scoped_cred()
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.post("/openai/credentials/cred-host/picker-token")
assert resp.status_code == 400
assert "top-secret" not in str(resp.json())
def test_missing_credential_returns_404(self):
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=None)
resp = client.post("/github/credentials/nonexistent/picker-token")
assert resp.status_code == 404
assert resp.json()["detail"] == "Credentials not found"
def test_wrong_provider_returns_404(self):
"""Symmetric with get_credential: provider mismatch is a generic
404, not a 400, so we don't leak existence of a credential the
caller doesn't own on that provider."""
cred = _make_oauth2_cred(provider="github")
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.post("/google/credentials/cred-456/picker-token")
assert resp.status_code == 404
assert resp.json()["detail"] == "Credentials not found"
def test_sdk_default_returns_404(self):
"""SDK defaults are invisible to the user-facing API — picker-token
must not mint a token for them either."""
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock()
resp = client.post("/openai/credentials/openai-default/picker-token")
assert resp.status_code == 404
mock_mgr.get.assert_not_called()
def test_oauth2_without_access_token_returns_400(self):
"""A stored OAuth2 cred whose access_token is missing can't satisfy
a picker init. Surface a clear reconnect instruction rather than
returning an empty string."""
cred = _make_oauth2_cred()
# Simulate a cred that lost its access token
object.__setattr__(cred, "access_token", None)
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.post("/github/credentials/cred-456/picker-token")
assert resp.status_code == 400
assert "reconnect" in resp.json()["detail"].lower()
def test_meta_only_endpoint_still_strips_access_token(self):
"""Regression guard for the coexistence contract: the new
picker-token endpoint must NOT accidentally leak the token through
the meta-only GET endpoint. TestGetCredentialReturnsMetaOnly
covers this more broadly; this is a fast sanity check co-located
with the new endpoint's tests."""
cred = _make_oauth2_cred()
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.get = AsyncMock(return_value=cred)
resp = client.get("/github/credentials/cred-456")
assert resp.status_code == 200
body = resp.json()
assert "access_token" not in body
assert "refresh_token" not in body
assert "ghp_secret_token" not in str(body)

View File

@@ -743,6 +743,7 @@ async def update_library_agent_version_and_settings(
graph=agent_graph,
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
builder_chat_session_id=library.settings.builder_chat_session_id,
)
if updated_settings != library.settings:
library = await update_library_agent(
@@ -1803,7 +1804,7 @@ async def create_preset_from_graph_execution(
raise NotFoundError(
f"Graph #{graph_execution.graph_id} not found or accessible"
)
elif len(graph.aggregate_credentials_inputs()) > 0:
elif len(graph.regular_credentials_inputs) > 0:
raise ValueError(
f"Graph execution #{graph_exec_id} can't be turned into a preset "
"because it was run before this feature existed "

View File

@@ -0,0 +1 @@
"""Platform bot linking — user-facing REST routes."""

View File

@@ -0,0 +1,158 @@
"""User-facing platform_linking REST routes (JWT auth)."""
import logging
from typing import Annotated
from autogpt_libs import auth
from fastapi import APIRouter, HTTPException, Path, Security
from backend.data.db_accessors import platform_linking_db
from backend.platform_linking.models import (
ConfirmLinkResponse,
ConfirmUserLinkResponse,
DeleteLinkResponse,
LinkTokenInfoResponse,
PlatformLinkInfo,
PlatformUserLinkInfo,
)
from backend.util.exceptions import (
LinkAlreadyExistsError,
LinkFlowMismatchError,
LinkTokenExpiredError,
NotAuthorizedError,
NotFoundError,
)
logger = logging.getLogger(__name__)
router = APIRouter()
TokenPath = Annotated[
str,
Path(max_length=64, pattern=r"^[A-Za-z0-9_-]+$"),
]
def _translate(exc: Exception) -> HTTPException:
if isinstance(exc, NotFoundError):
return HTTPException(status_code=404, detail=str(exc))
if isinstance(exc, NotAuthorizedError):
return HTTPException(status_code=403, detail=str(exc))
if isinstance(exc, LinkAlreadyExistsError):
return HTTPException(status_code=409, detail=str(exc))
if isinstance(exc, LinkTokenExpiredError):
return HTTPException(status_code=410, detail=str(exc))
if isinstance(exc, LinkFlowMismatchError):
return HTTPException(status_code=400, detail=str(exc))
return HTTPException(status_code=500, detail="Internal error.")
@router.get(
"/tokens/{token}/info",
response_model=LinkTokenInfoResponse,
dependencies=[Security(auth.requires_user)],
summary="Get display info for a link token",
)
async def get_link_token_info_route(token: TokenPath) -> LinkTokenInfoResponse:
try:
return await platform_linking_db().get_link_token_info(token)
except (NotFoundError, LinkTokenExpiredError) as exc:
raise _translate(exc) from exc
@router.post(
"/tokens/{token}/confirm",
response_model=ConfirmLinkResponse,
dependencies=[Security(auth.requires_user)],
summary="Confirm a SERVER link token (user must be authenticated)",
)
async def confirm_link_token(
token: TokenPath,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> ConfirmLinkResponse:
try:
return await platform_linking_db().confirm_server_link(token, user_id)
except (
NotFoundError,
LinkFlowMismatchError,
LinkTokenExpiredError,
LinkAlreadyExistsError,
) as exc:
raise _translate(exc) from exc
@router.post(
"/user-tokens/{token}/confirm",
response_model=ConfirmUserLinkResponse,
dependencies=[Security(auth.requires_user)],
summary="Confirm a USER link token (user must be authenticated)",
)
async def confirm_user_link_token(
token: TokenPath,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> ConfirmUserLinkResponse:
try:
return await platform_linking_db().confirm_user_link(token, user_id)
except (
NotFoundError,
LinkFlowMismatchError,
LinkTokenExpiredError,
LinkAlreadyExistsError,
) as exc:
raise _translate(exc) from exc
@router.get(
"/links",
response_model=list[PlatformLinkInfo],
dependencies=[Security(auth.requires_user)],
summary="List all platform servers linked to the authenticated user",
)
async def list_my_links(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> list[PlatformLinkInfo]:
return await platform_linking_db().list_server_links(user_id)
@router.get(
"/user-links",
response_model=list[PlatformUserLinkInfo],
dependencies=[Security(auth.requires_user)],
summary="List all DM links for the authenticated user",
)
async def list_my_user_links(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> list[PlatformUserLinkInfo]:
return await platform_linking_db().list_user_links(user_id)
@router.delete(
"/links/{link_id}",
response_model=DeleteLinkResponse,
dependencies=[Security(auth.requires_user)],
summary="Unlink a platform server",
)
async def delete_link(
link_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> DeleteLinkResponse:
try:
return await platform_linking_db().delete_server_link(link_id, user_id)
except (NotFoundError, NotAuthorizedError) as exc:
raise _translate(exc) from exc
@router.delete(
"/user-links/{link_id}",
response_model=DeleteLinkResponse,
dependencies=[Security(auth.requires_user)],
summary="Unlink a DM / user link",
)
async def delete_user_link_route(
link_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> DeleteLinkResponse:
try:
return await platform_linking_db().delete_user_link(link_id, user_id)
except (NotFoundError, NotAuthorizedError) as exc:
raise _translate(exc) from exc

View File

@@ -0,0 +1,264 @@
"""Route tests: domain exceptions → HTTPException status codes."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException
from backend.util.exceptions import (
LinkAlreadyExistsError,
LinkFlowMismatchError,
LinkTokenExpiredError,
NotAuthorizedError,
NotFoundError,
)
def _db_mock(**method_configs):
"""Return a mock of the accessor's return value with the given AsyncMocks."""
db = MagicMock()
for name, mock in method_configs.items():
setattr(db, name, mock)
return db
class TestTokenInfoRouteTranslation:
@pytest.mark.asyncio
async def test_not_found_maps_to_404(self):
from backend.api.features.platform_linking.routes import (
get_link_token_info_route,
)
db = _db_mock(
get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as exc:
await get_link_token_info_route(token="abc")
assert exc.value.status_code == 404
@pytest.mark.asyncio
async def test_expired_maps_to_410(self):
from backend.api.features.platform_linking.routes import (
get_link_token_info_route,
)
db = _db_mock(
get_link_token_info=AsyncMock(side_effect=LinkTokenExpiredError("expired"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as exc:
await get_link_token_info_route(token="abc")
assert exc.value.status_code == 410
class TestConfirmLinkRouteTranslation:
@pytest.mark.asyncio
@pytest.mark.parametrize(
"exc,expected_status",
[
(NotFoundError("missing"), 404),
(LinkFlowMismatchError("wrong flow"), 400),
(LinkTokenExpiredError("expired"), 410),
(LinkAlreadyExistsError("already"), 409),
],
)
async def test_translation(self, exc: Exception, expected_status: int):
from backend.api.features.platform_linking.routes import confirm_link_token
db = _db_mock(confirm_server_link=AsyncMock(side_effect=exc))
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as ctx:
await confirm_link_token(token="abc", user_id="u1")
assert ctx.value.status_code == expected_status
class TestConfirmUserLinkRouteTranslation:
@pytest.mark.asyncio
@pytest.mark.parametrize(
"exc,expected_status",
[
(NotFoundError("missing"), 404),
(LinkFlowMismatchError("wrong flow"), 400),
(LinkTokenExpiredError("expired"), 410),
(LinkAlreadyExistsError("already"), 409),
],
)
async def test_translation(self, exc: Exception, expected_status: int):
from backend.api.features.platform_linking.routes import confirm_user_link_token
db = _db_mock(confirm_user_link=AsyncMock(side_effect=exc))
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as ctx:
await confirm_user_link_token(token="abc", user_id="u1")
assert ctx.value.status_code == expected_status
class TestDeleteLinkRouteTranslation:
@pytest.mark.asyncio
async def test_not_found_maps_to_404(self):
from backend.api.features.platform_linking.routes import delete_link
db = _db_mock(
delete_server_link=AsyncMock(side_effect=NotFoundError("missing"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as exc:
await delete_link(link_id="x", user_id="u1")
assert exc.value.status_code == 404
@pytest.mark.asyncio
async def test_not_owned_maps_to_403(self):
from backend.api.features.platform_linking.routes import delete_link
db = _db_mock(
delete_server_link=AsyncMock(side_effect=NotAuthorizedError("nope"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as exc:
await delete_link(link_id="x", user_id="u1")
assert exc.value.status_code == 403
class TestDeleteUserLinkRouteTranslation:
@pytest.mark.asyncio
async def test_not_found_maps_to_404(self):
from backend.api.features.platform_linking.routes import delete_user_link_route
db = _db_mock(delete_user_link=AsyncMock(side_effect=NotFoundError("missing")))
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as exc:
await delete_user_link_route(link_id="x", user_id="u1")
assert exc.value.status_code == 404
@pytest.mark.asyncio
async def test_not_owned_maps_to_403(self):
from backend.api.features.platform_linking.routes import delete_user_link_route
db = _db_mock(
delete_user_link=AsyncMock(side_effect=NotAuthorizedError("nope"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as exc:
await delete_user_link_route(link_id="x", user_id="u1")
assert exc.value.status_code == 403
# ── Adversarial: malformed token path params ──────────────────────────
class TestAdversarialTokenPath:
# TokenPath enforces `^[A-Za-z0-9_-]+$` + max_length=64.
@pytest.fixture
def client(self):
import fastapi
from autogpt_libs.auth import get_user_id, requires_user
from fastapi.testclient import TestClient
import backend.api.features.platform_linking.routes as routes_mod
app = fastapi.FastAPI()
app.dependency_overrides[requires_user] = lambda: None
app.dependency_overrides[get_user_id] = lambda: "caller-user"
app.include_router(routes_mod.router, prefix="/api/platform-linking")
return TestClient(app)
def test_rejects_token_with_special_chars(self, client):
response = client.get("/api/platform-linking/tokens/bad%24token/info")
assert response.status_code == 422
def test_rejects_token_with_path_traversal(self, client):
for probe in ("..%2F..", "foo..bar", "foo%2Fbar"):
response = client.get(f"/api/platform-linking/tokens/{probe}/info")
assert response.status_code in (
404,
422,
), f"path-traversal probe {probe!r} returned {response.status_code}"
def test_rejects_token_too_long(self, client):
long_token = "a" * 65
response = client.get(f"/api/platform-linking/tokens/{long_token}/info")
assert response.status_code == 422
def test_accepts_token_at_max_length(self, client):
token = "a" * 64
db = _db_mock(
get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
response = client.get(f"/api/platform-linking/tokens/{token}/info")
assert response.status_code == 404
def test_accepts_urlsafe_b64_token_shape(self, client):
db = _db_mock(
get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
response = client.get("/api/platform-linking/tokens/abc-_XYZ123-_abc/info")
assert response.status_code == 404
def test_confirm_rejects_malformed_token(self, client):
response = client.post("/api/platform-linking/tokens/bad%24token/confirm")
assert response.status_code == 422
class TestAdversarialDeleteLinkId:
"""DELETE link_id has no regex — ensure weird values are handled via
NotFoundError (no crash, no cross-user leak)."""
@pytest.fixture
def client(self):
import fastapi
from autogpt_libs.auth import get_user_id, requires_user
from fastapi.testclient import TestClient
import backend.api.features.platform_linking.routes as routes_mod
app = fastapi.FastAPI()
app.dependency_overrides[requires_user] = lambda: None
app.dependency_overrides[get_user_id] = lambda: "caller-user"
app.include_router(routes_mod.router, prefix="/api/platform-linking")
return TestClient(app)
def test_weird_link_id_returns_404(self, client):
db = _db_mock(
delete_server_link=AsyncMock(side_effect=NotFoundError("missing"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
for link_id in ("'; DROP TABLE links;--", "../../etc/passwd", ""):
response = client.delete(f"/api/platform-linking/links/{link_id}")
assert response.status_code in (404, 405)

View File

@@ -189,7 +189,7 @@ async def test_create_store_submission(mocker):
notifyOnAgentApproved=True,
notifyOnAgentRejected=True,
timezone="Europe/Delft",
subscriptionTier=prisma.enums.SubscriptionTier.FREE, # type: ignore[reportCallIssue,reportAttributeAccessIssue]
subscriptionTier=prisma.enums.SubscriptionTier.BASIC, # type: ignore[reportCallIssue,reportAttributeAccessIssue]
)
mock_agent = prisma.models.AgentGraph(
id="agent-id",

View File

@@ -47,6 +47,51 @@ def _configure_frontend_origin(mocker: pytest_mock.MockFixture) -> None:
)
@pytest.fixture(autouse=True)
def _stub_pending_subscription_change(mocker: pytest_mock.MockFixture) -> None:
"""Default pending-change lookup to None so tests don't hit Stripe/DB.
Individual tests can override via their own mocker.patch call.
"""
mocker.patch(
"backend.api.features.v1.get_pending_subscription_change",
new_callable=AsyncMock,
return_value=None,
)
_DEFAULT_TIER_PRICES: dict[SubscriptionTier, str | None] = {
SubscriptionTier.BASIC: None, # Legacy: stripe-price-id-basic unset by default.
SubscriptionTier.PRO: "price_pro",
SubscriptionTier.MAX: "price_max",
SubscriptionTier.BUSINESS: None, # Reserved: Business card hidden by default.
}
@pytest.fixture(autouse=True)
def _stub_subscription_status_lookups(mocker: pytest_mock.MockFixture) -> None:
"""Stub Stripe price + proration lookups used by get_subscription_status.
The POST /credits/subscription handler now returns the full subscription
status payload from every branch (same-tier, BASIC downgrade, paid→paid
modify, checkout creation), so every POST test implicitly hits these
helpers. Individual tests can override via their own mocker.patch call.
"""
async def default_price_id(tier: SubscriptionTier) -> str | None:
return _DEFAULT_TIER_PRICES.get(tier)
mocker.patch(
"backend.api.features.v1.get_subscription_price_id",
side_effect=default_price_id,
)
mocker.patch(
"backend.api.features.v1.get_proration_credit_cents",
new_callable=AsyncMock,
return_value=0,
)
@pytest.mark.parametrize(
"url,expected",
[
@@ -88,15 +133,28 @@ def test_get_subscription_status_pro(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""GET /credits/subscription returns PRO tier with Stripe price for a PRO user."""
"""GET /credits/subscription returns PRO tier with Stripe prices for all priced tiers."""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.PRO
prices = {
SubscriptionTier.BASIC: "price_basic",
SubscriptionTier.PRO: "price_pro",
SubscriptionTier.MAX: "price_max",
SubscriptionTier.BUSINESS: "price_business",
}
amounts = {
"price_basic": 0,
"price_pro": 1999,
"price_max": 4999,
"price_business": 14999,
}
async def mock_price_id(tier: SubscriptionTier) -> str | None:
return "price_pro" if tier == SubscriptionTier.PRO else None
return prices.get(tier)
async def mock_stripe_price_amount(price_id: str) -> int:
return 1999 if price_id == "price_pro" else 0
return amounts.get(price_id, 0)
mocker.patch(
"backend.api.features.v1.get_user_by_id",
@@ -124,16 +182,18 @@ def test_get_subscription_status_pro(
assert data["tier"] == "PRO"
assert data["monthly_cost"] == 1999
assert data["tier_costs"]["PRO"] == 1999
assert data["tier_costs"]["BUSINESS"] == 0
assert data["tier_costs"]["FREE"] == 0
assert data["tier_costs"]["MAX"] == 4999
assert data["tier_costs"]["BUSINESS"] == 14999
assert data["tier_costs"]["BASIC"] == 0
assert "ENTERPRISE" not in data["tier_costs"]
assert data["proration_credit_cents"] == 500
def test_get_subscription_status_defaults_to_free(
def test_get_subscription_status_defaults_to_basic(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""GET /credits/subscription when subscription_tier is None defaults to FREE."""
"""When all LD price IDs are unset, tier_costs is empty and the caller sees cost=0."""
mock_user = Mock()
mock_user.subscription_tier = None
@@ -157,14 +217,9 @@ def test_get_subscription_status_defaults_to_free(
assert response.status_code == 200
data = response.json()
assert data["tier"] == SubscriptionTier.FREE.value
assert data["tier"] == SubscriptionTier.BASIC.value
assert data["monthly_cost"] == 0
assert data["tier_costs"] == {
"FREE": 0,
"PRO": 0,
"BUSINESS": 0,
"ENTERPRISE": 0,
}
assert data["tier_costs"] == {}
assert data["proration_credit_cents"] == 0
@@ -215,11 +270,11 @@ def test_get_subscription_status_stripe_error_falls_back_to_zero(
assert data["tier_costs"]["PRO"] == 0
def test_update_subscription_tier_free_no_payment(
def test_update_subscription_tier_basic_no_payment(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""POST /credits/subscription to FREE tier when payment disabled skips Stripe."""
"""POST /credits/subscription to BASIC tier when payment disabled skips Stripe."""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.PRO
@@ -240,7 +295,7 @@ def test_update_subscription_tier_free_no_payment(
new_callable=AsyncMock,
)
response = client.post("/credits/subscription", json={"tier": "FREE"})
response = client.post("/credits/subscription", json={"tier": "BASIC"})
assert response.status_code == 200
assert response.json()["url"] == ""
@@ -252,7 +307,7 @@ def test_update_subscription_tier_paid_beta_user(
) -> None:
"""POST /credits/subscription for paid tier when payment disabled returns 422."""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.FREE
mock_user.subscription_tier = SubscriptionTier.BASIC
async def mock_feature_disabled(*args, **kwargs):
return False
@@ -279,7 +334,7 @@ def test_update_subscription_tier_paid_requires_urls(
) -> None:
"""POST /credits/subscription for paid tier without success/cancel URLs returns 422."""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.FREE
mock_user.subscription_tier = SubscriptionTier.BASIC
async def mock_feature_enabled(*args, **kwargs):
return True
@@ -305,7 +360,7 @@ def test_update_subscription_tier_creates_checkout(
) -> None:
"""POST /credits/subscription creates Stripe Checkout Session for paid upgrade."""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.FREE
mock_user.subscription_tier = SubscriptionTier.BASIC
async def mock_feature_enabled(*args, **kwargs):
return True
@@ -344,7 +399,7 @@ def test_update_subscription_tier_rejects_open_redirect(
) -> None:
"""POST /credits/subscription rejects success/cancel URLs outside the frontend origin."""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.FREE
mock_user.subscription_tier = SubscriptionTier.BASIC
async def mock_feature_enabled(*args, **kwargs):
return True
@@ -407,30 +462,77 @@ def test_update_subscription_tier_enterprise_blocked(
set_tier_mock.assert_not_awaited()
def test_update_subscription_tier_same_tier_is_noop(
def test_update_subscription_tier_same_tier_releases_pending_change(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""POST /credits/subscription for the user's current paid tier returns 200 with empty URL.
"""POST /credits/subscription for the user's current tier releases any pending change.
Without this guard a duplicate POST (double-click, browser retry, stale page) would
create a second Stripe Checkout Session for the same price, potentially billing the
user twice until the webhook reconciliation fires.
"Stay on my current tier" — the collapsed replacement for the old
/credits/subscription/cancel-pending route. Always calls
release_pending_subscription_schedule (idempotent when nothing is pending)
and returns the refreshed status with url="". Never creates a Checkout
Session — that would double-charge a user who double-clicks their own tier.
"""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.PRO
async def mock_feature_enabled(*args, **kwargs):
return True
mock_user.subscription_tier = SubscriptionTier.BUSINESS
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
mocker.patch(
release_mock = mocker.patch(
"backend.api.features.v1.release_pending_subscription_schedule",
new_callable=AsyncMock,
return_value=True,
)
checkout_mock = mocker.patch(
"backend.api.features.v1.create_subscription_checkout",
new_callable=AsyncMock,
)
feature_mock = mocker.patch(
"backend.api.features.v1.is_feature_enabled",
side_effect=mock_feature_enabled,
new_callable=AsyncMock,
return_value=True,
)
response = client.post(
"/credits/subscription",
json={
"tier": "BUSINESS",
"success_url": f"{TEST_FRONTEND_ORIGIN}/success",
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
},
)
assert response.status_code == 200
data = response.json()
assert data["tier"] == "BUSINESS"
assert data["url"] == ""
release_mock.assert_awaited_once_with(TEST_USER_ID)
checkout_mock.assert_not_awaited()
# Same-tier branch short-circuits before the payment-flag check.
feature_mock.assert_not_awaited()
def test_update_subscription_tier_same_tier_no_pending_change_returns_status(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""Same-tier request when nothing is pending still returns status with url=""."""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.PRO
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
release_mock = mocker.patch(
"backend.api.features.v1.release_pending_subscription_schedule",
new_callable=AsyncMock,
return_value=False,
)
checkout_mock = mocker.patch(
"backend.api.features.v1.create_subscription_checkout",
@@ -447,18 +549,58 @@ def test_update_subscription_tier_same_tier_is_noop(
)
assert response.status_code == 200
assert response.json()["url"] == ""
data = response.json()
assert data["tier"] == "PRO"
assert data["url"] == ""
assert data["pending_tier"] is None
release_mock.assert_awaited_once_with(TEST_USER_ID)
checkout_mock.assert_not_awaited()
def test_update_subscription_tier_free_with_payment_schedules_cancel_and_does_not_update_db(
def test_update_subscription_tier_same_tier_stripe_error_returns_502(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""Downgrading to FREE schedules Stripe cancellation at period end.
"""Same-tier request surfaces a 502 when Stripe release fails.
Carries forward the error contract from the removed
/credits/subscription/cancel-pending route so clients keep seeing 502 for
transient Stripe failures.
"""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.BUSINESS
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
mocker.patch(
"backend.api.features.v1.release_pending_subscription_schedule",
side_effect=stripe.StripeError("network"),
)
response = client.post(
"/credits/subscription",
json={
"tier": "BUSINESS",
"success_url": f"{TEST_FRONTEND_ORIGIN}/success",
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
},
)
assert response.status_code == 502
assert "contact support" in response.json()["detail"].lower()
def test_update_subscription_tier_basic_with_payment_schedules_cancel_and_does_not_update_db(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""Downgrading to BASIC schedules Stripe cancellation at period end.
The DB tier must NOT be updated immediately — the customer.subscription.deleted
webhook fires at period end and downgrades to FREE then.
webhook fires at period end and downgrades to BASIC then.
"""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.PRO
@@ -484,18 +626,18 @@ def test_update_subscription_tier_free_with_payment_schedules_cancel_and_does_no
side_effect=mock_feature_enabled,
)
response = client.post("/credits/subscription", json={"tier": "FREE"})
response = client.post("/credits/subscription", json={"tier": "BASIC"})
assert response.status_code == 200
mock_cancel.assert_awaited_once()
mock_set_tier.assert_not_awaited()
def test_update_subscription_tier_free_cancel_failure_returns_502(
def test_update_subscription_tier_basic_cancel_failure_returns_502(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""Downgrading to FREE returns 502 with a generic error (no Stripe detail leakage)."""
"""Downgrading to BASIC returns 502 with a generic error (no Stripe detail leakage)."""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.PRO
@@ -518,7 +660,7 @@ def test_update_subscription_tier_free_cancel_failure_returns_502(
side_effect=mock_feature_enabled,
)
response = client.post("/credits/subscription", json={"tier": "FREE"})
response = client.post("/credits/subscription", json={"tier": "BASIC"})
assert response.status_code == 502
detail = response.json()["detail"]
@@ -635,6 +777,16 @@ def test_update_subscription_tier_paid_to_paid_modifies_subscription(
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.PRO
async def price_id_with_business(tier: SubscriptionTier) -> str | None:
return {
**_DEFAULT_TIER_PRICES,
SubscriptionTier.BUSINESS: "price_business",
}.get(tier)
mocker.patch(
"backend.api.features.v1.get_subscription_price_id",
side_effect=price_id_with_business,
)
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
@@ -670,6 +822,49 @@ def test_update_subscription_tier_paid_to_paid_modifies_subscription(
checkout_mock.assert_not_awaited()
def test_update_subscription_tier_max_checkout(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""POST /credits/subscription from PRO→MAX modifies the existing subscription."""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.PRO
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
mocker.patch(
"backend.api.features.v1.is_feature_enabled",
new_callable=AsyncMock,
return_value=True,
)
modify_mock = mocker.patch(
"backend.api.features.v1.modify_stripe_subscription_for_tier",
new_callable=AsyncMock,
return_value=True,
)
checkout_mock = mocker.patch(
"backend.api.features.v1.create_subscription_checkout",
new_callable=AsyncMock,
)
response = client.post(
"/credits/subscription",
json={
"tier": "MAX",
"success_url": f"{TEST_FRONTEND_ORIGIN}/success",
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
},
)
assert response.status_code == 200
assert response.json()["url"] == ""
modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.MAX)
checkout_mock.assert_not_awaited()
def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
@@ -683,6 +878,16 @@ def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.PRO
async def price_id_with_business(tier: SubscriptionTier) -> str | None:
return {
**_DEFAULT_TIER_PRICES,
SubscriptionTier.BUSINESS: "price_business",
}.get(tier)
mocker.patch(
"backend.api.features.v1.get_subscription_price_id",
side_effect=price_id_with_business,
)
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
@@ -725,6 +930,128 @@ def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly
checkout_mock.assert_not_awaited()
def test_update_subscription_tier_priced_basic_no_sub_falls_through_to_checkout(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""Once stripe-price-id-basic is configured, a BASIC user without an active sub
must hit Stripe Checkout rather than being silently set_subscription_tier'd."""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.BASIC
async def mock_price_id(tier: SubscriptionTier) -> str | None:
return {
SubscriptionTier.BASIC: "price_basic",
SubscriptionTier.PRO: "price_pro",
SubscriptionTier.MAX: "price_max",
SubscriptionTier.BUSINESS: "price_business",
}.get(tier)
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
mocker.patch(
"backend.api.features.v1.get_subscription_price_id",
side_effect=mock_price_id,
)
mocker.patch(
"backend.api.features.v1.is_feature_enabled",
new_callable=AsyncMock,
return_value=True,
)
modify_mock = mocker.patch(
"backend.api.features.v1.modify_stripe_subscription_for_tier",
new_callable=AsyncMock,
return_value=False,
)
set_tier_mock = mocker.patch(
"backend.api.features.v1.set_subscription_tier",
new_callable=AsyncMock,
)
checkout_mock = mocker.patch(
"backend.api.features.v1.create_subscription_checkout",
new_callable=AsyncMock,
return_value="https://checkout.stripe.com/pay/cs_test_priced_basic",
)
response = client.post(
"/credits/subscription",
json={
"tier": "PRO",
"success_url": f"{TEST_FRONTEND_ORIGIN}/success",
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
},
)
assert response.status_code == 200
assert (
response.json()["url"] == "https://checkout.stripe.com/pay/cs_test_priced_basic"
)
# Priced-BASIC user without an active sub: must NOT silently flip DB tier —
# they need to set up payment via Checkout.
set_tier_mock.assert_not_awaited()
checkout_mock.assert_awaited_once()
# modify is still called first; returning False just means "no active sub".
modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.PRO)
def test_update_subscription_tier_target_without_ld_price_returns_422(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""Paid target with no LD-configured Stripe price must fail fast with 422.
Matches the UI hiding: if `stripe-price-id-pro` resolves to None we can't
start a Checkout Session anyway, and we don't want to surface an opaque
Stripe error mid-flow. The handler rejects the request before touching
Stripe at all.
"""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.BASIC
async def mock_price_id(tier: SubscriptionTier) -> str | None:
return None # Neither BASIC nor PRO have an LD price.
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
mocker.patch(
"backend.api.features.v1.get_subscription_price_id",
side_effect=mock_price_id,
)
mocker.patch(
"backend.api.features.v1.is_feature_enabled",
new_callable=AsyncMock,
return_value=True,
)
checkout_mock = mocker.patch(
"backend.api.features.v1.create_subscription_checkout",
new_callable=AsyncMock,
)
modify_mock = mocker.patch(
"backend.api.features.v1.modify_stripe_subscription_for_tier",
new_callable=AsyncMock,
)
response = client.post(
"/credits/subscription",
json={
"tier": "PRO",
"success_url": f"{TEST_FRONTEND_ORIGIN}/success",
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
},
)
assert response.status_code == 422
assert "not available" in response.json()["detail"].lower()
checkout_mock.assert_not_awaited()
modify_mock.assert_not_awaited()
def test_update_subscription_tier_paid_to_paid_stripe_error_returns_502(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
@@ -733,6 +1060,16 @@ def test_update_subscription_tier_paid_to_paid_stripe_error_returns_502(
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.PRO
async def price_id_with_business(tier: SubscriptionTier) -> str | None:
return {
**_DEFAULT_TIER_PRICES,
SubscriptionTier.BUSINESS: "price_business",
}.get(tier)
mocker.patch(
"backend.api.features.v1.get_subscription_price_id",
side_effect=price_id_with_business,
)
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
@@ -761,11 +1098,11 @@ def test_update_subscription_tier_paid_to_paid_stripe_error_returns_502(
assert response.status_code == 502
def test_update_subscription_tier_free_no_stripe_subscription(
def test_update_subscription_tier_basic_no_stripe_subscription(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""Downgrading to FREE when no Stripe subscription exists updates DB tier directly.
"""Downgrading to BASIC when no Stripe subscription exists updates DB tier directly.
Admin-granted paid tiers have no associated Stripe subscription. When such a
user requests a self-service downgrade, cancel_stripe_subscription returns False
@@ -796,10 +1133,214 @@ def test_update_subscription_tier_free_no_stripe_subscription(
new_callable=AsyncMock,
)
response = client.post("/credits/subscription", json={"tier": "FREE"})
response = client.post("/credits/subscription", json={"tier": "BASIC"})
assert response.status_code == 200
assert response.json()["url"] == ""
cancel_mock.assert_awaited_once_with(TEST_USER_ID)
# DB tier must be updated immediately — no webhook will fire for a missing sub
set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.FREE)
set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BASIC)
def test_get_subscription_status_includes_pending_tier(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""GET /credits/subscription exposes pending_tier and pending_tier_effective_at."""
import datetime as dt
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.BUSINESS
effective_at = dt.datetime(2030, 1, 1, tzinfo=dt.timezone.utc)
async def mock_price_id(tier: SubscriptionTier) -> str | None:
return None
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
mocker.patch(
"backend.api.features.v1.get_subscription_price_id",
side_effect=mock_price_id,
)
mocker.patch(
"backend.api.features.v1.get_proration_credit_cents",
new_callable=AsyncMock,
return_value=0,
)
mocker.patch(
"backend.api.features.v1.get_pending_subscription_change",
new_callable=AsyncMock,
return_value=(SubscriptionTier.PRO, effective_at),
)
response = client.get("/credits/subscription")
assert response.status_code == 200
data = response.json()
assert data["pending_tier"] == "PRO"
assert data["pending_tier_effective_at"] is not None
def test_get_subscription_status_no_pending_tier(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""When no pending change exists the response omits pending_tier."""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.PRO
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
mocker.patch(
"backend.api.features.v1.get_subscription_price_id",
new_callable=AsyncMock,
return_value=None,
)
mocker.patch(
"backend.api.features.v1.get_proration_credit_cents",
new_callable=AsyncMock,
return_value=0,
)
mocker.patch(
"backend.api.features.v1.get_pending_subscription_change",
new_callable=AsyncMock,
return_value=None,
)
response = client.get("/credits/subscription")
assert response.status_code == 200
data = response.json()
assert data["pending_tier"] is None
assert data["pending_tier_effective_at"] is None
def test_update_subscription_tier_downgrade_paid_to_paid_schedules(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""A BUSINESS→PRO downgrade request dispatches to modify_stripe_subscription_for_tier."""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.BUSINESS
async def price_id_with_business(tier: SubscriptionTier) -> str | None:
return {
**_DEFAULT_TIER_PRICES,
SubscriptionTier.BUSINESS: "price_business",
}.get(tier)
mocker.patch(
"backend.api.features.v1.get_subscription_price_id",
side_effect=price_id_with_business,
)
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
mocker.patch(
"backend.api.features.v1.is_feature_enabled",
new_callable=AsyncMock,
return_value=True,
)
modify_mock = mocker.patch(
"backend.api.features.v1.modify_stripe_subscription_for_tier",
new_callable=AsyncMock,
return_value=True,
)
checkout_mock = mocker.patch(
"backend.api.features.v1.create_subscription_checkout",
new_callable=AsyncMock,
)
response = client.post(
"/credits/subscription",
json={
"tier": "PRO",
"success_url": f"{TEST_FRONTEND_ORIGIN}/success",
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
},
)
assert response.status_code == 200
assert response.json()["url"] == ""
modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.PRO)
checkout_mock.assert_not_awaited()
def test_stripe_webhook_dispatches_subscription_schedule_released(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""subscription_schedule.released routes to sync_subscription_schedule_from_stripe."""
schedule_obj = {"id": "sub_sched_1", "subscription": "sub_pro"}
event = {
"type": "subscription_schedule.released",
"data": {"object": schedule_obj},
}
mocker.patch(
"backend.api.features.v1.settings.secrets.stripe_webhook_secret",
new="whsec_test",
)
mocker.patch(
"backend.api.features.v1.stripe.Webhook.construct_event",
return_value=event,
)
sync_mock = mocker.patch(
"backend.api.features.v1.sync_subscription_schedule_from_stripe",
new_callable=AsyncMock,
)
response = client.post(
"/credits/stripe_webhook",
content=b"{}",
headers={"stripe-signature": "t=1,v1=abc"},
)
assert response.status_code == 200
sync_mock.assert_awaited_once_with(schedule_obj)
def test_stripe_webhook_ignores_subscription_schedule_updated(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""subscription_schedule.updated must NOT dispatch: our own
SubscriptionSchedule.create/.modify calls fire this event and would
otherwise loop redundant traffic through the sync handler. State
transitions we care about surface via .released/.completed, and phase
advance to a new price is already covered by customer.subscription.updated.
"""
schedule_obj = {"id": "sub_sched_1", "subscription": "sub_pro"}
event = {
"type": "subscription_schedule.updated",
"data": {"object": schedule_obj},
}
mocker.patch(
"backend.api.features.v1.settings.secrets.stripe_webhook_secret",
new="whsec_test",
)
mocker.patch(
"backend.api.features.v1.stripe.Webhook.construct_event",
return_value=event,
)
sync_mock = mocker.patch(
"backend.api.features.v1.sync_subscription_schedule_from_stripe",
new_callable=AsyncMock,
)
response = client.post(
"/credits/stripe_webhook",
content=b"{}",
headers={"stripe-signature": "t=1,v1=abc"},
)
assert response.status_code == 200
sync_mock.assert_not_awaited()

View File

@@ -26,10 +26,11 @@ from fastapi import (
)
from fastapi.concurrency import run_in_threadpool
from prisma.enums import SubscriptionTier
from pydantic import BaseModel
from pydantic import BaseModel, Field
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict
from backend.api.features.workspace.routes import create_file_download_response
from backend.api.model import (
CreateAPIKeyRequest,
CreateAPIKeyResponse,
@@ -49,20 +50,24 @@ from backend.data.auth import api_key as api_key_db
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.credit import (
AutoTopUpConfig,
PendingChangeUnknown,
RefundRequest,
TransactionHistory,
UserCredit,
cancel_stripe_subscription,
create_subscription_checkout,
get_auto_top_up,
get_pending_subscription_change,
get_proration_credit_cents,
get_subscription_price_id,
get_user_credit_model,
handle_subscription_payment_failure,
modify_stripe_subscription_for_tier,
release_pending_subscription_schedule,
set_auto_top_up,
set_subscription_tier,
sync_subscription_from_stripe,
sync_subscription_schedule_from_stripe,
)
from backend.data.graph import GraphSettings
from backend.data.model import CredentialsMetaInput, UserOnboarding
@@ -92,6 +97,7 @@ from backend.data.user import (
update_user_notification_preference,
update_user_timezone,
)
from backend.data.workspace import get_workspace_file_by_id
from backend.executor import scheduler
from backend.executor import utils as execution_utils
from backend.integrations.webhooks.graph_lifecycle_hooks import (
@@ -693,20 +699,26 @@ async def get_user_auto_top_up(
class SubscriptionTierRequest(BaseModel):
tier: Literal["FREE", "PRO", "BUSINESS"]
tier: Literal["BASIC", "PRO", "MAX", "BUSINESS"]
success_url: str = ""
cancel_url: str = ""
class SubscriptionCheckoutResponse(BaseModel):
url: str
class SubscriptionStatusResponse(BaseModel):
tier: Literal["FREE", "PRO", "BUSINESS", "ENTERPRISE"]
tier: Literal["BASIC", "PRO", "MAX", "BUSINESS", "ENTERPRISE"]
monthly_cost: int # amount in cents (Stripe convention)
tier_costs: dict[str, int] # tier name -> amount in cents
proration_credit_cents: int # unused portion of current sub to convert on upgrade
pending_tier: Optional[Literal["BASIC", "PRO", "MAX", "BUSINESS"]] = None
pending_tier_effective_at: Optional[datetime] = None
url: str = Field(
default="",
description=(
"Populated only when POST /credits/subscription starts a Stripe Checkout"
" Session (BASIC → paid upgrade). Empty string in all other branches —"
" the client redirects to this URL when non-empty."
),
)
def _validate_checkout_redirect_url(url: str) -> bool:
@@ -782,39 +794,67 @@ async def get_subscription_status(
user_id: Annotated[str, Security(get_user_id)],
) -> SubscriptionStatusResponse:
user = await get_user_by_id(user_id)
tier = user.subscription_tier or SubscriptionTier.FREE
tier = user.subscription_tier or SubscriptionTier.BASIC
paid_tiers = [SubscriptionTier.PRO, SubscriptionTier.BUSINESS]
priceable_tiers = [
SubscriptionTier.BASIC,
SubscriptionTier.PRO,
SubscriptionTier.MAX,
SubscriptionTier.BUSINESS,
]
price_ids = await asyncio.gather(
*[get_subscription_price_id(t) for t in paid_tiers]
*[get_subscription_price_id(t) for t in priceable_tiers]
)
tier_costs: dict[str, int] = {
SubscriptionTier.FREE.value: 0,
SubscriptionTier.ENTERPRISE.value: 0,
}
async def _cost(pid: str | None) -> int:
return (await _get_stripe_price_amount(pid) or 0) if pid else 0
costs = await asyncio.gather(*[_cost(pid) for pid in price_ids])
for t, cost in zip(paid_tiers, costs):
tier_costs[t.value] = cost
tier_costs: dict[str, int] = {}
for t, pid, cost in zip(priceable_tiers, price_ids, costs):
if pid:
tier_costs[t.value] = cost
current_monthly_cost = tier_costs.get(tier.value, 0)
proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost)
return SubscriptionStatusResponse(
try:
pending = await get_pending_subscription_change(user_id)
except (stripe.StripeError, PendingChangeUnknown):
# Swallow Stripe-side failures (rate limits, transient network) AND
# PendingChangeUnknown (LaunchDarkly price-id lookup failed). Both
# propagate past the cache so the next request retries fresh instead
# of serving a stale None for the TTL window. Let real bugs (KeyError,
# AttributeError, etc.) propagate so they surface in Sentry.
logger.exception(
"get_subscription_status: failed to resolve pending change for user %s",
user_id,
)
pending = None
response = SubscriptionStatusResponse(
tier=tier.value,
monthly_cost=current_monthly_cost,
tier_costs=tier_costs,
proration_credit_cents=proration_credit,
)
if pending is not None:
pending_tier_enum, pending_effective_at = pending
if pending_tier_enum in (
SubscriptionTier.BASIC,
SubscriptionTier.PRO,
SubscriptionTier.MAX,
SubscriptionTier.BUSINESS,
):
response.pending_tier = pending_tier_enum.value
response.pending_tier_effective_at = pending_effective_at
return response
@v1_router.post(
path="/credits/subscription",
summary="Start a Stripe Checkout session to upgrade subscription tier",
summary="Update subscription tier or start a Stripe Checkout session",
operation_id="updateSubscriptionTier",
tags=["credits"],
dependencies=[Security(requires_user)],
@@ -822,38 +862,63 @@ async def get_subscription_status(
async def update_subscription_tier(
request: SubscriptionTierRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> SubscriptionCheckoutResponse:
# Pydantic validates tier is one of FREE/PRO/BUSINESS via Literal type.
) -> SubscriptionStatusResponse:
# Pydantic validates tier is one of BASIC/PRO/MAX/BUSINESS via Literal type.
tier = SubscriptionTier(request.tier)
# ENTERPRISE tier is admin-managed — block self-service changes from ENTERPRISE users.
user = await get_user_by_id(user_id)
if (user.subscription_tier or SubscriptionTier.FREE) == SubscriptionTier.ENTERPRISE:
if (
user.subscription_tier or SubscriptionTier.BASIC
) == SubscriptionTier.ENTERPRISE:
raise HTTPException(
status_code=403,
detail="ENTERPRISE subscription changes must be managed by an administrator",
)
# Same-tier request = "stay on my current tier" = cancel any pending
# scheduled change (paid→paid downgrade or paid→BASIC cancel). This is the
# collapsed behaviour that replaces the old /credits/subscription/cancel-pending
# route. Safe when no pending change exists: release_pending_subscription_schedule
# returns False and we simply return the current status.
if (user.subscription_tier or SubscriptionTier.BASIC) == tier:
try:
await release_pending_subscription_schedule(user_id)
except stripe.StripeError as e:
logger.exception(
"Stripe error releasing pending subscription change for user %s: %s",
user_id,
e,
)
raise HTTPException(
status_code=502,
detail=(
"Unable to cancel the pending subscription change right now. "
"Please try again or contact support."
),
)
return await get_subscription_status(user_id)
payment_enabled = await is_feature_enabled(
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
)
# Downgrade to FREE: schedule Stripe cancellation at period end so the user
# keeps their tier for the time they already paid for. The DB tier is NOT
# updated here when a subscription exists — the customer.subscription.deleted
# webhook fires at period end and downgrades to FREE then.
# Exception: if the user has no active Stripe subscription (e.g. admin-granted
# tier), cancel_stripe_subscription returns False and we update the DB tier
# immediately since no webhook will ever fire.
# When payment is disabled entirely, update the DB tier directly.
if tier == SubscriptionTier.FREE:
current_tier = user.subscription_tier or SubscriptionTier.BASIC
target_price_id, current_tier_price_id = await asyncio.gather(
get_subscription_price_id(tier),
get_subscription_price_id(current_tier),
)
# Legacy cancel: target BASIC + stripe-price-id-basic unset. Schedule Stripe
# cancellation at period end; cancel_at_period_end=True lets the webhook flip
# the DB tier. No active sub (admin-granted) or payment disabled → DB flip.
# Once stripe-price-id-basic is configured, BASIC becomes a real sub and falls
# through to the modify/checkout flow below.
if tier == SubscriptionTier.BASIC and target_price_id is None:
if payment_enabled:
try:
had_subscription = await cancel_stripe_subscription(user_id)
except stripe.StripeError as e:
# Log full Stripe error server-side but return a generic message
# to the client — raw Stripe errors can leak customer/sub IDs and
# infrastructure config details.
logger.exception(
"Stripe error cancelling subscription for user %s: %s",
user_id,
@@ -867,48 +932,37 @@ async def update_subscription_tier(
),
)
if not had_subscription:
# No active Stripe subscription found — the user was on an
# admin-granted tier. Update DB immediately since the
# subscription.deleted webhook will never fire.
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
return await get_subscription_status(user_id)
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
return await get_subscription_status(user_id)
# Paid tier changes require payment to be enabled — block self-service upgrades
# when the flag is off. Admins use the /api/admin/ routes to set tiers directly.
if not payment_enabled:
raise HTTPException(
status_code=422,
detail=f"Subscription not available for tier {tier}",
detail=f"Subscription not available for tier {tier.value}",
)
# No-op short-circuit: if the user is already on the requested paid tier,
# do NOT create a new Checkout Session. Without this guard, a duplicate
# request (double-click, retried POST, stale page) creates a second
# subscription for the same price; the user would be charged for both
# until `_cleanup_stale_subscriptions` runs from the resulting webhook —
# which only fires after the second charge has cleared.
if (user.subscription_tier or SubscriptionTier.FREE) == tier:
return SubscriptionCheckoutResponse(url="")
# Target has no LD price — not provisionable (matches the GET hiding).
if target_price_id is None:
raise HTTPException(
status_code=422,
detail=f"Subscription not available for tier {tier.value}",
)
# Paid→paid tier change: if the user already has a Stripe subscription,
# modify it in-place with proration instead of creating a new Checkout
# Session. This preserves remaining paid time and avoids double-charging.
# The customer.subscription.updated webhook fires and updates the DB tier.
current_tier = user.subscription_tier or SubscriptionTier.FREE
if current_tier in (SubscriptionTier.PRO, SubscriptionTier.BUSINESS):
# User has an active Stripe subscription (current tier has an LD price):
# modify it in-place. modify_stripe_subscription_for_tier returns False when no
# active sub exists — that's only a "DB-only flip is OK" signal for admin-granted
# paid tiers (PRO/BUSINESS with no Stripe record). Priced-BASIC users without a
# sub must still go through Checkout so they set up payment.
if current_tier_price_id is not None:
try:
modified = await modify_stripe_subscription_for_tier(user_id, tier)
if modified:
return SubscriptionCheckoutResponse(url="")
# modify_stripe_subscription_for_tier returns False when no active
# Stripe subscription exists — i.e. the user has an admin-granted
# paid tier with no Stripe record. In that case, update the DB
# tier directly (same as the FREE-downgrade path for admin-granted
# users) rather than sending them through a new Checkout Session.
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
return await get_subscription_status(user_id)
if current_tier != SubscriptionTier.BASIC:
await set_subscription_tier(user_id, tier)
return await get_subscription_status(user_id)
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
except stripe.StripeError as e:
@@ -923,7 +977,7 @@ async def update_subscription_tier(
),
)
# Paid upgrade from FREE → create Stripe Checkout Session.
# No active Stripe subscription → create Stripe Checkout Session.
if not request.success_url or not request.cancel_url:
raise HTTPException(
status_code=422,
@@ -978,7 +1032,9 @@ async def update_subscription_tier(
),
)
return SubscriptionCheckoutResponse(url=url)
status = await get_subscription_status(user_id)
status.url = url
return status
@v1_router.post(
@@ -1043,6 +1099,18 @@ async def stripe_webhook(request: Request):
):
await sync_subscription_from_stripe(data_object)
# `subscription_schedule.updated` is deliberately omitted: our own
# `SubscriptionSchedule.create` + `.modify` calls in
# `_schedule_downgrade_at_period_end` would fire that event right back at us
# and loop redundant traffic through this handler. We only care about state
# transitions (released / completed); phase advance to the new price is
# already covered by `customer.subscription.updated`.
if event_type in (
"subscription_schedule.released",
"subscription_schedule.completed",
):
await sync_subscription_schedule_from_stripe(data_object)
if event_type == "invoice.payment_failed":
await handle_subscription_payment_failure(data_object)
@@ -1640,6 +1708,10 @@ async def enable_execution_sharing(
# Generate a unique share token
share_token = str(uuid.uuid4())
# Remove stale allowlist records before updating the token — prevents a
# window where old records + new token could coexist.
await execution_db.delete_shared_execution_files(execution_id=graph_exec_id)
# Update the execution with share info
await execution_db.update_graph_execution_share_status(
execution_id=graph_exec_id,
@@ -1649,6 +1721,14 @@ async def enable_execution_sharing(
shared_at=datetime.now(timezone.utc),
)
# Create allowlist of workspace files referenced in outputs
await execution_db.create_shared_execution_files(
execution_id=graph_exec_id,
share_token=share_token,
user_id=user_id,
outputs=execution.outputs,
)
# Return the share URL
frontend_url = settings.config.frontend_base_url or "http://localhost:3000"
share_url = f"{frontend_url}/share/{share_token}"
@@ -1674,6 +1754,9 @@ async def disable_execution_sharing(
if not execution:
raise HTTPException(status_code=404, detail="Execution not found")
# Remove shared file allowlist records
await execution_db.delete_shared_execution_files(execution_id=graph_exec_id)
# Remove share info
await execution_db.update_graph_execution_share_status(
execution_id=graph_exec_id,
@@ -1699,6 +1782,43 @@ async def get_shared_execution(
return execution
@v1_router.get(
"/public/shared/{share_token}/files/{file_id}/download",
summary="Download a file from a shared execution",
operation_id="download_shared_file",
tags=["graphs"],
)
async def download_shared_file(
share_token: Annotated[
str,
Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"),
],
file_id: Annotated[
str,
Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"),
],
) -> Response:
"""Download a workspace file from a shared execution (no auth required).
Validates that the file was explicitly exposed when sharing was enabled.
Returns a uniform 404 for all failure modes to prevent enumeration attacks.
"""
# Single-query validation against the allowlist
execution_id = await execution_db.get_shared_execution_file(
share_token=share_token, file_id=file_id
)
if not execution_id:
raise HTTPException(status_code=404, detail="Not found")
# Look up the actual file (no workspace scoping needed — the allowlist
# already validated that this file belongs to the shared execution)
file = await get_workspace_file_by_id(file_id)
if not file:
raise HTTPException(status_code=404, detail="Not found")
return await create_file_download_response(file, inline=True)
########################################################
##################### Schedules ########################
########################################################

View File

@@ -0,0 +1,157 @@
"""Tests for the public shared file download endpoint."""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from starlette.responses import Response
from backend.api.features.v1 import v1_router
from backend.data.workspace import WorkspaceFile
app = FastAPI()
app.include_router(v1_router, prefix="/api")
VALID_TOKEN = "550e8400-e29b-41d4-a716-446655440000"
VALID_FILE_ID = "6ba7b810-9dad-11d1-80b4-00c04fd430c8"
def _make_workspace_file(**overrides) -> WorkspaceFile:
defaults = {
"id": VALID_FILE_ID,
"workspace_id": "ws-001",
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
"updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
"name": "image.png",
"path": "/image.png",
"storage_path": "local://uploads/image.png",
"mime_type": "image/png",
"size_bytes": 4,
"checksum": None,
"is_deleted": False,
"deleted_at": None,
"metadata": {},
}
defaults.update(overrides)
return WorkspaceFile(**defaults)
def _mock_download_response(**kwargs):
"""Return an AsyncMock that resolves to a Response with inline disposition."""
async def _handler(file, *, inline=False):
return Response(
content=b"\x89PNG",
media_type="image/png",
headers={
"Content-Disposition": (
'inline; filename="image.png"'
if inline
else 'attachment; filename="image.png"'
),
"Content-Length": "4",
},
)
return _handler
class TestDownloadSharedFile:
"""Tests for GET /api/public/shared/{token}/files/{id}/download."""
@pytest.fixture(autouse=True)
def _client(self):
self.client = TestClient(app, raise_server_exceptions=False)
def test_valid_token_and_file_returns_inline_content(self):
with (
patch(
"backend.api.features.v1.execution_db.get_shared_execution_file",
new_callable=AsyncMock,
return_value="exec-123",
),
patch(
"backend.api.features.v1.get_workspace_file_by_id",
new_callable=AsyncMock,
return_value=_make_workspace_file(),
),
patch(
"backend.api.features.v1.create_file_download_response",
side_effect=_mock_download_response(),
),
):
response = self.client.get(
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
)
assert response.status_code == 200
assert response.content == b"\x89PNG"
assert "inline" in response.headers["Content-Disposition"]
def test_invalid_token_format_returns_422(self):
response = self.client.get(
f"/api/public/shared/not-a-uuid/files/{VALID_FILE_ID}/download"
)
assert response.status_code == 422
def test_token_not_in_allowlist_returns_404(self):
with patch(
"backend.api.features.v1.execution_db.get_shared_execution_file",
new_callable=AsyncMock,
return_value=None,
):
response = self.client.get(
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
)
assert response.status_code == 404
def test_file_missing_from_workspace_returns_404(self):
with (
patch(
"backend.api.features.v1.execution_db.get_shared_execution_file",
new_callable=AsyncMock,
return_value="exec-123",
),
patch(
"backend.api.features.v1.get_workspace_file_by_id",
new_callable=AsyncMock,
return_value=None,
),
):
response = self.client.get(
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
)
assert response.status_code == 404
def test_uniform_404_prevents_enumeration(self):
"""Both failure modes produce identical 404 — no information leak."""
with patch(
"backend.api.features.v1.execution_db.get_shared_execution_file",
new_callable=AsyncMock,
return_value=None,
):
resp_no_allow = self.client.get(
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
)
with (
patch(
"backend.api.features.v1.execution_db.get_shared_execution_file",
new_callable=AsyncMock,
return_value="exec-123",
),
patch(
"backend.api.features.v1.get_workspace_file_by_id",
new_callable=AsyncMock,
return_value=None,
),
):
resp_no_file = self.client.get(
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
)
assert resp_no_allow.status_code == 404
assert resp_no_file.status_code == 404
assert resp_no_allow.json() == resp_no_file.json()

View File

@@ -29,7 +29,9 @@ from backend.util.workspace import WorkspaceManager
from backend.util.workspace_storage import get_workspace_storage
def _sanitize_filename_for_header(filename: str) -> str:
def _sanitize_filename_for_header(
filename: str, disposition: str = "attachment"
) -> str:
"""
Sanitize filename for Content-Disposition header to prevent header injection.
@@ -44,11 +46,11 @@ def _sanitize_filename_for_header(filename: str) -> str:
# Check if filename has non-ASCII characters
try:
sanitized.encode("ascii")
return f'attachment; filename="{sanitized}"'
return f'{disposition}; filename="{sanitized}"'
except UnicodeEncodeError:
# Use RFC5987 encoding for UTF-8 filenames
encoded = quote(sanitized, safe="")
return f"attachment; filename*=UTF-8''{encoded}"
return f"{disposition}; filename*=UTF-8''{encoded}"
logger = logging.getLogger(__name__)
@@ -58,19 +60,26 @@ router = fastapi.APIRouter(
)
def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
def _create_streaming_response(
content: bytes, file: WorkspaceFile, *, inline: bool = False
) -> Response:
"""Create a streaming response for file content."""
disposition = _sanitize_filename_for_header(
file.name, disposition="inline" if inline else "attachment"
)
return Response(
content=content,
media_type=file.mime_type,
headers={
"Content-Disposition": _sanitize_filename_for_header(file.name),
"Content-Disposition": disposition,
"Content-Length": str(len(content)),
},
)
async def _create_file_download_response(file: WorkspaceFile) -> Response:
async def create_file_download_response(
file: WorkspaceFile, *, inline: bool = False
) -> Response:
"""
Create a download response for a workspace file.
@@ -82,7 +91,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
# For local storage, stream the file directly
if file.storage_path.startswith("local://"):
content = await storage.retrieve(file.storage_path)
return _create_streaming_response(content, file)
return _create_streaming_response(content, file, inline=inline)
# For GCS, try to redirect to signed URL, fall back to streaming
try:
@@ -90,7 +99,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
# If we got back an API path (fallback), stream directly instead
if url.startswith("/api/"):
content = await storage.retrieve(file.storage_path)
return _create_streaming_response(content, file)
return _create_streaming_response(content, file, inline=inline)
return fastapi.responses.RedirectResponse(url=url, status_code=302)
except Exception as e:
# Log the signed URL failure with context
@@ -102,7 +111,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
# Fall back to streaming directly from GCS
try:
content = await storage.retrieve(file.storage_path)
return _create_streaming_response(content, file)
return _create_streaming_response(content, file, inline=inline)
except Exception as fallback_error:
logger.error(
f"Fallback streaming also failed for file {file.id} "
@@ -169,7 +178,7 @@ async def download_file(
if file is None:
raise fastapi.HTTPException(status_code=404, detail="File not found")
return await _create_file_download_response(file)
return await create_file_download_response(file)
@router.delete(

View File

@@ -600,3 +600,221 @@ def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace):
mock_instance.list_files.assert_called_once_with(
limit=11, offset=50, include_all_sessions=True
)
# -- _sanitize_filename_for_header tests --
class TestSanitizeFilenameForHeader:
def test_simple_ascii_attachment(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
assert _sanitize_filename_for_header("report.pdf") == (
'attachment; filename="report.pdf"'
)
def test_inline_disposition(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
assert _sanitize_filename_for_header("image.png", disposition="inline") == (
'inline; filename="image.png"'
)
def test_strips_cr_lf_null(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
result = _sanitize_filename_for_header("a\rb\nc\x00d.txt")
assert "\r" not in result
assert "\n" not in result
assert "\x00" not in result
assert 'filename="abcd.txt"' in result
def test_escapes_quotes(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
result = _sanitize_filename_for_header('file"name.txt')
assert 'filename="file\\"name.txt"' in result
def test_header_injection_blocked(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
result = _sanitize_filename_for_header("evil.txt\r\nX-Injected: true")
# CR/LF stripped — the remaining text is safely inside the quoted value
assert "\r" not in result
assert "\n" not in result
assert result == 'attachment; filename="evil.txtX-Injected: true"'
def test_unicode_uses_rfc5987(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
result = _sanitize_filename_for_header("日本語.pdf")
assert "filename*=UTF-8''" in result
assert "attachment" in result
def test_unicode_inline(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
result = _sanitize_filename_for_header("图片.png", disposition="inline")
assert result.startswith("inline; filename*=UTF-8''")
def test_empty_filename(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
result = _sanitize_filename_for_header("")
assert result == 'attachment; filename=""'
# -- _create_streaming_response tests --
class TestCreateStreamingResponse:
def test_attachment_disposition_by_default(self):
from backend.api.features.workspace.routes import _create_streaming_response
file = _make_file(name="data.bin", mime_type="application/octet-stream")
response = _create_streaming_response(b"binary-data", file)
assert (
response.headers["Content-Disposition"] == 'attachment; filename="data.bin"'
)
assert response.headers["Content-Type"] == "application/octet-stream"
assert response.headers["Content-Length"] == "11"
assert response.body == b"binary-data"
def test_inline_disposition(self):
from backend.api.features.workspace.routes import _create_streaming_response
file = _make_file(name="photo.png", mime_type="image/png")
response = _create_streaming_response(b"\x89PNG", file, inline=True)
assert response.headers["Content-Disposition"] == 'inline; filename="photo.png"'
assert response.headers["Content-Type"] == "image/png"
def test_inline_sanitizes_filename(self):
from backend.api.features.workspace.routes import _create_streaming_response
file = _make_file(name='evil"\r\n.txt', mime_type="text/plain")
response = _create_streaming_response(b"data", file, inline=True)
assert "\r" not in response.headers["Content-Disposition"]
assert "\n" not in response.headers["Content-Disposition"]
assert "inline" in response.headers["Content-Disposition"]
def test_content_length_matches_body(self):
from backend.api.features.workspace.routes import _create_streaming_response
content = b"x" * 1000
file = _make_file(name="big.bin", mime_type="application/octet-stream")
response = _create_streaming_response(content, file)
assert response.headers["Content-Length"] == "1000"
# -- create_file_download_response tests --
class TestCreateFileDownloadResponse:
@pytest.mark.asyncio
async def test_local_storage_returns_streaming_response(self, mocker):
from backend.api.features.workspace.routes import create_file_download_response
mock_storage = AsyncMock()
mock_storage.retrieve.return_value = b"file contents"
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage",
return_value=mock_storage,
)
file = _make_file(
storage_path="local://uploads/test.txt",
mime_type="text/plain",
)
response = await create_file_download_response(file)
assert response.status_code == 200
assert response.body == b"file contents"
assert "attachment" in response.headers["Content-Disposition"]
@pytest.mark.asyncio
async def test_local_storage_inline(self, mocker):
from backend.api.features.workspace.routes import create_file_download_response
mock_storage = AsyncMock()
mock_storage.retrieve.return_value = b"\x89PNG"
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage",
return_value=mock_storage,
)
file = _make_file(
storage_path="local://uploads/photo.png",
mime_type="image/png",
name="photo.png",
)
response = await create_file_download_response(file, inline=True)
assert "inline" in response.headers["Content-Disposition"]
@pytest.mark.asyncio
async def test_gcs_redirect(self, mocker):
from backend.api.features.workspace.routes import create_file_download_response
mock_storage = AsyncMock()
mock_storage.get_download_url.return_value = (
"https://storage.googleapis.com/signed-url"
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage",
return_value=mock_storage,
)
file = _make_file(storage_path="gcs://bucket/file.pdf")
response = await create_file_download_response(file)
assert response.status_code == 302
assert (
response.headers["location"] == "https://storage.googleapis.com/signed-url"
)
@pytest.mark.asyncio
async def test_gcs_api_fallback_streams_directly(self, mocker):
from backend.api.features.workspace.routes import create_file_download_response
mock_storage = AsyncMock()
mock_storage.get_download_url.return_value = "/api/fallback"
mock_storage.retrieve.return_value = b"fallback content"
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage",
return_value=mock_storage,
)
file = _make_file(storage_path="gcs://bucket/file.txt")
response = await create_file_download_response(file)
assert response.status_code == 200
assert response.body == b"fallback content"
@pytest.mark.asyncio
async def test_gcs_signed_url_failure_falls_back_to_streaming(self, mocker):
from backend.api.features.workspace.routes import create_file_download_response
mock_storage = AsyncMock()
mock_storage.get_download_url.side_effect = RuntimeError("GCS error")
mock_storage.retrieve.return_value = b"streamed"
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage",
return_value=mock_storage,
)
file = _make_file(storage_path="gcs://bucket/file.txt")
response = await create_file_download_response(file)
assert response.status_code == 200
assert response.body == b"streamed"
@pytest.mark.asyncio
async def test_gcs_total_failure_raises(self, mocker):
from backend.api.features.workspace.routes import create_file_download_response
mock_storage = AsyncMock()
mock_storage.get_download_url.side_effect = RuntimeError("GCS error")
mock_storage.retrieve.side_effect = RuntimeError("Also failed")
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage",
return_value=mock_storage,
)
file = _make_file(storage_path="gcs://bucket/file.txt")
with pytest.raises(RuntimeError, match="Also failed"):
await create_file_download_response(file)

View File

@@ -17,6 +17,7 @@ from fastapi.routing import APIRoute
from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.diagnostics_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.platform_cost_routes
import backend.api.features.admin.rate_limit_admin_routes
@@ -31,6 +32,7 @@ import backend.api.features.library.routes
import backend.api.features.mcp.routes as mcp_routes
import backend.api.features.oauth
import backend.api.features.otto.routes
import backend.api.features.platform_linking.routes
import backend.api.features.postmark.postmark
import backend.api.features.store.model
import backend.api.features.store.routes
@@ -320,6 +322,11 @@ app.include_router(
tags=["v2", "admin"],
prefix="/api/credits",
)
app.include_router(
backend.api.features.admin.diagnostics_admin_routes.router,
tags=["v2", "admin"],
prefix="/api",
)
app.include_router(
backend.api.features.admin.execution_analytics_routes.router,
tags=["v2", "admin"],
@@ -372,6 +379,11 @@ app.include_router(
tags=["oauth"],
prefix="/api/oauth",
)
app.include_router(
backend.api.features.platform_linking.routes.router,
tags=["platform-linking"],
prefix="/api/platform-linking",
)
app.mount("/external-api", external_api)

View File

@@ -42,11 +42,13 @@ def main(**kwargs):
from backend.data.db_manager import DatabaseManager
from backend.executor import ExecutionManager, Scheduler
from backend.notifications import NotificationManager
from backend.platform_linking.manager import PlatformLinkingManager
run_processes(
DatabaseManager().set_log_level("warning"),
Scheduler(),
NotificationManager(),
PlatformLinkingManager(),
WebsocketServer(),
AgentServer(),
ExecutionManager(),

View File

@@ -96,27 +96,64 @@ class BlockCategory(Enum):
class BlockCostType(str, Enum):
RUN = "run" # cost X credits per run
BYTE = "byte" # cost X credits per byte
SECOND = "second" # cost X credits per second
# RUN : cost_amount credits per run.
# BYTE : cost_amount credits per byte of input data.
# SECOND : cost_amount credits per cost_divisor walltime seconds.
# ITEMS : cost_amount credits per cost_divisor items (from stats).
# COST_USD : cost_amount credits per USD of stats.provider_cost.
# TOKENS : per-(model, provider) rate table lookup; see TOKEN_COST.
RUN = "run"
BYTE = "byte"
SECOND = "second"
ITEMS = "items"
COST_USD = "cost_usd"
TOKENS = "tokens"
@property
def is_dynamic(self) -> bool:
"""Real charge is computed post-flight from stats.
Dynamic types (SECOND/ITEMS/COST_USD/TOKENS) return 0 pre-flight and
settle against stats via charge_reconciled_usage once the block runs.
"""
return self in _DYNAMIC_COST_TYPES
_DYNAMIC_COST_TYPES: frozenset[BlockCostType] = frozenset(
{
BlockCostType.SECOND,
BlockCostType.ITEMS,
BlockCostType.COST_USD,
BlockCostType.TOKENS,
}
)
class BlockCost(BaseModel):
cost_amount: int
cost_filter: BlockInput
cost_type: BlockCostType
# cost_divisor: interpret cost_amount as "credits per cost_divisor units".
# Only meaningful for SECOND / ITEMS. TOKENS routes through TOKEN_COST
# rate tables (per-model input/output/cache pricing) and ignores
# cost_divisor entirely. Defaults to 1 so existing RUN/BYTE entries stay
# point-wise. Example: cost_amount=1, cost_divisor=10 under SECOND means
# "1 credit per 10 seconds of walltime".
cost_divisor: int = 1
def __init__(
self,
cost_amount: int,
cost_type: BlockCostType = BlockCostType.RUN,
cost_filter: Optional[BlockInput] = None,
cost_divisor: int = 1,
**data: Any,
) -> None:
super().__init__(
cost_amount=cost_amount,
cost_filter=cost_filter or {},
cost_type=cost_type,
cost_divisor=max(1, cost_divisor),
**data,
)
@@ -168,9 +205,31 @@ class BlockSchema(BaseModel):
return cls.cached_jsonschema
@classmethod
def validate_data(cls, data: BlockInput) -> str | None:
def validate_data(
cls,
data: BlockInput,
exclude_fields: set[str] | None = None,
) -> str | None:
schema = cls.jsonschema()
if exclude_fields:
# Drop the excluded fields from both the properties and the
# ``required`` list so jsonschema doesn't flag them as missing.
# Used by the dry-run path to skip credentials validation while
# still validating the remaining block inputs.
schema = {
**schema,
"properties": {
k: v
for k, v in schema.get("properties", {}).items()
if k not in exclude_fields
},
"required": [
r for r in schema.get("required", []) if r not in exclude_fields
],
}
data = {k: v for k, v in data.items() if k not in exclude_fields}
return json.validate_with_jsonschema(
schema=cls.jsonschema(),
schema=schema,
data={k: v for k, v in data.items() if v is not None},
)
@@ -311,6 +370,8 @@ class BlockSchema(BaseModel):
"credentials_provider": [config.get("provider", "google")],
"credentials_types": [config.get("type", "oauth2")],
"credentials_scopes": config.get("scopes"),
"is_auto_credential": True,
"input_field_name": info["field_name"],
}
result[kwarg_name] = CredentialsFieldInfo.model_validate(
auto_schema, by_alias=True
@@ -421,19 +482,6 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
_optimized_description: ClassVar[str | None] = None
def extra_runtime_cost(self, execution_stats: NodeExecutionStats) -> int:
"""Return extra runtime cost to charge after this block run completes.
Called by the executor after a block finishes with COMPLETED status.
The return value is the number of additional base-cost credits to
charge beyond the single credit already collected by charge_usage
at the start of execution. Defaults to 0 (no extra charges).
Override in blocks (e.g. OrchestratorBlock) that make multiple LLM
calls within one run and should be billed per call.
"""
return 0
def __init__(
self,
id: str = "",
@@ -717,11 +765,16 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
# (e.g. AgentExecutorBlock) get proper input validation.
is_dry_run = getattr(kwargs.get("execution_context"), "dry_run", False)
if is_dry_run:
# Credential fields may be absent (LLM-built agents often skip
# wiring them) or nullified earlier in the pipeline. Validate
# the non-credential inputs against a schema with those fields
# excluded — stripping only the data while keeping them in the
# ``required`` list would falsely report ``'credentials' is a
# required property``.
cred_field_names = set(self.input_schema.get_credentials_fields().keys())
non_cred_data = {
k: v for k, v in input_data.items() if k not in cred_field_names
}
if error := self.input_schema.validate_data(non_cred_data):
if error := self.input_schema.validate_data(
input_data, exclude_fields=cred_field_names
):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
block_name=self.name,
@@ -735,6 +788,61 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
block_id=self.id,
)
# Ensure auto-credential kwargs are present before we hand off to
# run(). A missing auto-credential means the upstream field (e.g.
# a Google Drive picker) didn't embed a _credentials_id, or the
# executor couldn't resolve it. Without this guard, run() would
# crash with a TypeError (missing required kwarg) or an opaque
# AttributeError deep inside the provider SDK.
#
# Only raise when the field is ALSO not populated in input_data.
# ``_acquire_auto_credentials`` intentionally skips setting the
# kwarg in two legitimate cases — ``_credentials_id`` is ``None``
# (chained from upstream) or the field is missing from
# ``input_data`` at prep time (connected from upstream block).
# In both cases the upstream block is expected to populate the
# field value by execute time; raising here would break the
# documented ``AgentGoogleDriveFileInputBlock`` chaining pattern.
# Dry-run skips because the executor intentionally runs blocks
# without resolved creds for schema validation.
if not is_dry_run:
for (
kwarg_name,
info,
) in self.input_schema.get_auto_credentials_fields().items():
kwargs.setdefault(kwarg_name, None)
if kwargs[kwarg_name] is not None:
continue
# Upstream-chained pattern: the field was populated by a
# prior node (e.g. AgentGoogleDriveFileInputBlock) whose
# output carries a resolved ``_credentials_id``.
# ``_acquire_auto_credentials`` deliberately doesn't set
# the kwarg in that case because the value isn't available
# at prep time; the executor fills it in before we reach
# ``_execute``. Trust it if the ``_credentials_id`` KEY
# is present — its value may be explicitly ``None`` in
# the chained case (see sentry thread
# PRRT_kwDOJKSTjM58sJfA). Checking truthiness here would
# falsely preempt run() for every valid chained graph
# that ships ``_credentials_id=None`` in the picker
# object. Mirror ``_acquire_auto_credentials``'s own
# skip rule, which treats ``cred_id is None`` as a
# chained-skip signal.
field_name = info["field_name"]
field_value = input_data.get(field_name)
if isinstance(field_value, dict) and "_credentials_id" in field_value:
continue
raise BlockExecutionError(
message=(
f"Missing credentials for '{kwarg_name}'. "
"Select a file via the picker (which carries "
"its credentials), or connect credentials for "
"this block."
),
block_name=self.name,
block_id=self.id,
)
# Use the validated input data
async for output_name, output_data in self.run(
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),

View File

@@ -171,7 +171,10 @@ class AgentExecutorBlock(Block):
)
self.merge_stats(
NodeExecutionStats(
extra_cost=event.stats.cost if event.stats else 0,
# Sub-graph already debited each of its own nodes; we
# roll up its total so graph_stats.cost reflects the
# full sub-graph spend.
reconciled_cost_delta=(event.stats.cost if event.stats else 0),
extra_steps=event.stats.node_exec_count if event.stats else 0,
)
)

View File

@@ -4,11 +4,16 @@ Shared configuration for all AgentMail blocks.
from agentmail import AsyncAgentMail
from backend.sdk import APIKeyCredentials, ProviderBuilder, SecretStr
from backend.sdk import APIKeyCredentials, BlockCostType, ProviderBuilder, SecretStr
# AgentMail is in beta with no published paid tier yet, but ~37 blocks
# without any BLOCK_COSTS entry means they currently execute wallet-free.
# 1 cr/call is a conservative interim floor so no AgentMail work leaks
# past billing. Revisit once AgentMail publishes usage-based pricing.
agent_mail = (
ProviderBuilder("agent_mail")
.with_api_key("AGENTMAIL_API_KEY", "AgentMail API Key")
.with_base_cost(1, BlockCostType.RUN)
.build()
)

View File

@@ -23,6 +23,7 @@ from backend.copilot.permissions import (
validate_block_identifiers,
)
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
if TYPE_CHECKING:
from backend.data.execution import ExecutionContext
@@ -32,9 +33,36 @@ logger = logging.getLogger(__name__)
# Block ID shared between autopilot.py and copilot prompting.py.
AUTOPILOT_BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
# Identifiers used when registering an AutoPilotBlock turn with the
# stream registry — distinguishes block-originated turns from sub-session
# or HTTP SSE turns in logs / observability.
_AUTOPILOT_TOOL_CALL_ID = "autopilot_block"
_AUTOPILOT_TOOL_NAME = "autopilot_block"
class SubAgentRecursionError(RuntimeError):
"""Raised when the sub-agent nesting depth limit is exceeded."""
# Ceiling on how long AutoPilotBlock.execute_copilot will wait for the
# enqueued turn's terminal event. Graph blocks run synchronously from
# the caller's perspective so we wait effectively as long as needed; 6h
# matches the previous abandoned-task cap and is much longer than any
# legitimate AutoPilot turn.
_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS = 6 * 60 * 60 # 6 hours
class SubAgentRecursionError(BlockExecutionError):
"""Raised when the AutoPilot sub-agent nesting depth limit is exceeded.
Inherits :class:`BlockExecutionError` — this is a known, handled
runtime failure at the block level (caller nested AutoPilotBlocks
beyond the configured limit). Surfaces with the block_name /
block_id the block framework expects, instead of being wrapped in
``BlockUnknownError``.
"""
def __init__(self, message: str) -> None:
super().__init__(
message=message,
block_name="AutoPilotBlock",
block_id=AUTOPILOT_BLOCK_ID,
)
class ToolCallEntry(TypedDict):
@@ -268,11 +296,15 @@ class AutoPilotBlock(Block):
user_id: str,
permissions: "CopilotPermissions | None" = None,
) -> tuple[str, list[ToolCallEntry], str, str, TokenUsage]:
"""Invoke the copilot and collect all stream results.
"""Invoke the copilot on the copilot_executor queue and aggregate the
result.
Delegates to :func:`collect_copilot_response` — the shared helper that
consumes ``stream_chat_completion_sdk`` without wrapping it in an
``asyncio.timeout`` (the SDK manages its own heartbeat-based timeouts).
Delegates to :func:`run_copilot_turn_via_queue` — the shared
primitive used by ``run_sub_session`` too — which creates the
stream_registry meta record, enqueues the job, and waits on the
Redis stream for the terminal event. Any available
copilot_executor worker picks up the job, so this call survives
the graph-executor worker dying mid-turn (RabbitMQ redelivers).
Args:
prompt: The user task/instruction.
@@ -285,8 +317,8 @@ class AutoPilotBlock(Block):
Returns:
A tuple of (response_text, tool_calls, history_json, session_id, usage).
"""
from backend.copilot.sdk.collect import (
collect_copilot_response, # avoid circular import
from backend.copilot.sdk.session_waiter import (
run_copilot_turn_via_queue, # avoid circular import
)
tokens = _check_recursion(max_recursion_depth)
@@ -299,14 +331,35 @@ class AutoPilotBlock(Block):
if system_context:
effective_prompt = f"[System Context: {system_context}]\n\n{prompt}"
result = await collect_copilot_response(
outcome, result = await run_copilot_turn_via_queue(
session_id=session_id,
message=effective_prompt,
user_id=user_id,
message=effective_prompt,
# Graph block execution is synchronous from the caller's
# perspective — wait effectively as long as needed. The
# SDK enforces its own idle-based timeout inside the
# stream_registry pipeline.
timeout=_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS,
permissions=effective_permissions,
tool_call_id=_AUTOPILOT_TOOL_CALL_ID,
tool_name=_AUTOPILOT_TOOL_NAME,
)
if outcome == "failed":
raise RuntimeError(
"AutoPilot turn failed — see the session's transcript"
)
if outcome == "running":
raise RuntimeError(
"AutoPilot turn did not complete within "
f"{_AUTOPILOT_BLOCK_MAX_WAIT_SECONDS}s — session "
f"{session_id}"
)
# Build a lightweight conversation summary from streamed data.
# Build a lightweight conversation summary from the aggregated data.
# When ``result.queued`` is True the prompt rode on an already-
# in-flight turn (``run_copilot_turn_via_queue`` queued it and
# waited on the existing turn's stream); the aggregated result
# is still valid, so the same rendering path applies.
turn_messages: list[dict[str, Any]] = [
{"role": "user", "content": effective_prompt},
]
@@ -315,7 +368,7 @@ class AutoPilotBlock(Block):
{
"role": "assistant",
"content": result.response_text,
"tool_calls": result.tool_calls,
"tool_calls": [tc.model_dump() for tc in result.tool_calls],
}
)
else:
@@ -326,11 +379,11 @@ class AutoPilotBlock(Block):
tool_calls: list[ToolCallEntry] = [
{
"tool_call_id": tc["tool_call_id"],
"tool_name": tc["tool_name"],
"input": tc["input"],
"output": tc["output"],
"success": tc["success"],
"tool_call_id": tc.tool_call_id,
"tool_name": tc.tool_name,
"input": tc.input,
"output": tc.output,
"success": tc.success,
}
for tc in result.tool_calls
]

View File

@@ -0,0 +1,21 @@
"""Shared provider config for Ayrshare social-media blocks.
The "credential" exposed to blocks is the **per-user Ayrshare profile key**,
not the org-level ``AYRSHARE_API_KEY``. Profile keys are provisioned per
user by :class:`~backend.integrations.managed_providers.ayrshare.AyrshareManagedProvider`
and stored in the normal credentials list with ``is_managed=True``, so every
Ayrshare block fits the standard credential flow:
credentials: CredentialsMetaInput = ayrshare.credentials_field(...)
``run_block`` / ``resolve_block_credentials`` take care of the rest.
``with_managed_api_key()`` registers ``api_key`` as a supported auth type
without the env-var-backed default credential that ``with_api_key()`` would
create — the org-level ``AYRSHARE_API_KEY`` is the admin key and must never
reach a block as a "profile key".
"""
from backend.sdk import ProviderBuilder
ayrshare = ProviderBuilder("ayrshare").with_managed_api_key().build()

View File

@@ -0,0 +1,18 @@
from backend.sdk import BlockCost, BlockCostType
# Ayrshare is a subscription proxy ($149/mo Business). Per-post credit charges
# prevent a single heavy user from absorbing the fixed cost and align with the
# upload cost of each post variant.
# cost_filter matches on input_data.is_video BEFORE run() executes, so the flag
# has to be correct at input-eval time. Video-only platforms (YouTube, Snapchat)
# override the base default to True; platforms that accept both (TikTok, etc.)
# rely on the caller setting is_video explicitly for accurate billing.
# First match wins in block_usage_cost, so list the video tier first.
AYRSHARE_POST_COSTS = (
BlockCost(
cost_amount=5, cost_type=BlockCostType.RUN, cost_filter={"is_video": True}
),
BlockCost(
cost_amount=2, cost_type=BlockCostType.RUN, cost_filter={"is_video": False}
),
)

View File

@@ -4,22 +4,25 @@ from typing import Optional
from pydantic import BaseModel, Field
from backend.blocks._base import BlockSchemaInput
from backend.data.model import SchemaField, UserIntegrations
from backend.data.model import CredentialsMetaInput, SchemaField
from backend.integrations.ayrshare import AyrshareClient
from backend.util.clients import get_database_manager_async_client
from backend.util.exceptions import MissingConfigError
async def get_profile_key(user_id: str):
user_integrations: UserIntegrations = (
await get_database_manager_async_client().get_user_integrations(user_id)
)
return user_integrations.managed_credentials.ayrshare_profile_key
from ._config import ayrshare
class BaseAyrshareInput(BlockSchemaInput):
"""Base input model for Ayrshare social media posts with common fields."""
credentials: CredentialsMetaInput = ayrshare.credentials_field(
description=(
"Ayrshare profile credential. AutoGPT provisions this managed "
"credential automatically — the user does not create it. After "
"it's in place, the user links each social account via the "
"Ayrshare SSO popup in the Builder."
),
)
post: str = SchemaField(
description="The post text to be published", default="", advanced=False
)
@@ -29,7 +32,9 @@ class BaseAyrshareInput(BlockSchemaInput):
advanced=False,
)
is_video: bool = SchemaField(
description="Whether the media is a video", default=False, advanced=True
description="Whether the media is a video. Set to True when uploading a video so billing applies the video tier.",
default=False,
advanced=True,
)
schedule_date: Optional[datetime] = SchemaField(
description="UTC datetime for scheduling (YYYY-MM-DDThh:mm:ssZ)",

View File

@@ -1,16 +1,20 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client
@cost(*AYRSHARE_POST_COSTS)
class PostToBlueskyBlock(Block):
"""Block for posting to Bluesky with Bluesky-specific options."""
@@ -57,16 +61,10 @@ class PostToBlueskyBlock(Block):
self,
input_data: "PostToBlueskyBlock.Input",
*,
user_id: str,
credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
"""Post to Bluesky with Bluesky-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -106,7 +104,7 @@ class PostToBlueskyBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
bluesky_options=bluesky_options if bluesky_options else None,
profile_key=profile_key.get_secret_value(),
profile_key=credentials.api_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

View File

@@ -1,21 +1,20 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._util import (
BaseAyrshareInput,
CarouselItem,
create_ayrshare_client,
get_profile_key,
)
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, CarouselItem, create_ayrshare_client
@cost(*AYRSHARE_POST_COSTS)
class PostToFacebookBlock(Block):
"""Block for posting to Facebook with Facebook-specific options."""
@@ -120,15 +119,10 @@ class PostToFacebookBlock(Block):
self,
input_data: "PostToFacebookBlock.Input",
*,
user_id: str,
credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
"""Post to Facebook with Facebook-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -204,7 +198,7 @@ class PostToFacebookBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
facebook_options=facebook_options if facebook_options else None,
profile_key=profile_key.get_secret_value(),
profile_key=credentials.api_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

View File

@@ -1,16 +1,20 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client
@cost(*AYRSHARE_POST_COSTS)
class PostToGMBBlock(Block):
"""Block for posting to Google My Business with GMB-specific options."""
@@ -110,14 +114,13 @@ class PostToGMBBlock(Block):
)
async def run(
self, input_data: "PostToGMBBlock.Input", *, user_id: str, **kwargs
self,
input_data: "PostToGMBBlock.Input",
*,
credentials: APIKeyCredentials,
**kwargs
) -> BlockOutput:
"""Post to Google My Business with GMB-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -202,7 +205,7 @@ class PostToGMBBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
gmb_options=gmb_options if gmb_options else None,
profile_key=profile_key.get_secret_value(),
profile_key=credentials.api_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

View File

@@ -2,22 +2,21 @@ from typing import Any
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._util import (
BaseAyrshareInput,
InstagramUserTag,
create_ayrshare_client,
get_profile_key,
)
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, InstagramUserTag, create_ayrshare_client
@cost(*AYRSHARE_POST_COSTS)
class PostToInstagramBlock(Block):
"""Block for posting to Instagram with Instagram-specific options."""
@@ -112,15 +111,10 @@ class PostToInstagramBlock(Block):
self,
input_data: "PostToInstagramBlock.Input",
*,
user_id: str,
credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
"""Post to Instagram with Instagram-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -241,7 +235,7 @@ class PostToInstagramBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
instagram_options=instagram_options if instagram_options else None,
profile_key=profile_key.get_secret_value(),
profile_key=credentials.api_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

View File

@@ -1,16 +1,20 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client
@cost(*AYRSHARE_POST_COSTS)
class PostToLinkedInBlock(Block):
"""Block for posting to LinkedIn with LinkedIn-specific options."""
@@ -112,15 +116,10 @@ class PostToLinkedInBlock(Block):
self,
input_data: "PostToLinkedInBlock.Input",
*,
user_id: str,
credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
"""Post to LinkedIn with LinkedIn-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -214,7 +213,7 @@ class PostToLinkedInBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
linkedin_options=linkedin_options if linkedin_options else None,
profile_key=profile_key.get_secret_value(),
profile_key=credentials.api_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

View File

@@ -1,21 +1,20 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._util import (
BaseAyrshareInput,
PinterestCarouselOption,
create_ayrshare_client,
get_profile_key,
)
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, PinterestCarouselOption, create_ayrshare_client
@cost(*AYRSHARE_POST_COSTS)
class PostToPinterestBlock(Block):
"""Block for posting to Pinterest with Pinterest-specific options."""
@@ -92,15 +91,10 @@ class PostToPinterestBlock(Block):
self,
input_data: "PostToPinterestBlock.Input",
*,
user_id: str,
credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
"""Post to Pinterest with Pinterest-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -206,7 +200,7 @@ class PostToPinterestBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
pinterest_options=pinterest_options if pinterest_options else None,
profile_key=profile_key.get_secret_value(),
profile_key=credentials.api_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

View File

@@ -1,16 +1,20 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client
@cost(*AYRSHARE_POST_COSTS)
class PostToRedditBlock(Block):
"""Block for posting to Reddit."""
@@ -35,12 +39,12 @@ class PostToRedditBlock(Block):
)
async def run(
self, input_data: "PostToRedditBlock.Input", *, user_id: str, **kwargs
self,
input_data: "PostToRedditBlock.Input",
*,
credentials: APIKeyCredentials,
**kwargs
) -> BlockOutput:
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured."
@@ -61,7 +65,7 @@ class PostToRedditBlock(Block):
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
profile_key=profile_key.get_secret_value(),
profile_key=credentials.api_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

View File

@@ -1,16 +1,20 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client
@cost(*AYRSHARE_POST_COSTS)
class PostToSnapchatBlock(Block):
"""Block for posting to Snapchat with Snapchat-specific options."""
@@ -31,6 +35,14 @@ class PostToSnapchatBlock(Block):
advanced=False,
)
# Snapchat is video-only; override the base default so the @cost filter
# selects the 5-credit video tier instead of the 2-credit image tier.
is_video: bool = SchemaField(
description="Whether the media is a video (always True for Snapchat)",
default=True,
advanced=True,
)
# Snapchat-specific options
story_type: str = SchemaField(
description="Type of Snapchat content: 'story' (24-hour Stories), 'saved_story' (Saved Stories), or 'spotlight' (Spotlight posts)",
@@ -62,15 +74,10 @@ class PostToSnapchatBlock(Block):
self,
input_data: "PostToSnapchatBlock.Input",
*,
user_id: str,
credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
"""Post to Snapchat with Snapchat-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -121,7 +128,7 @@ class PostToSnapchatBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
snapchat_options=snapchat_options if snapchat_options else None,
profile_key=profile_key.get_secret_value(),
profile_key=credentials.api_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

View File

@@ -1,16 +1,20 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client
@cost(*AYRSHARE_POST_COSTS)
class PostToTelegramBlock(Block):
"""Block for posting to Telegram with Telegram-specific options."""
@@ -57,15 +61,10 @@ class PostToTelegramBlock(Block):
self,
input_data: "PostToTelegramBlock.Input",
*,
user_id: str,
credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
"""Post to Telegram with Telegram-specific validation."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -108,7 +107,7 @@ class PostToTelegramBlock(Block):
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
profile_key=profile_key.get_secret_value(),
profile_key=credentials.api_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

View File

@@ -1,16 +1,20 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client
@cost(*AYRSHARE_POST_COSTS)
class PostToThreadsBlock(Block):
"""Block for posting to Threads with Threads-specific options."""
@@ -50,15 +54,10 @@ class PostToThreadsBlock(Block):
self,
input_data: "PostToThreadsBlock.Input",
*,
user_id: str,
credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
"""Post to Threads with Threads-specific validation."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -103,7 +102,7 @@ class PostToThreadsBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
threads_options=threads_options if threads_options else None,
profile_key=profile_key.get_secret_value(),
profile_key=credentials.api_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

View File

@@ -2,15 +2,18 @@ from enum import Enum
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client
class TikTokVisibility(str, Enum):
@@ -19,6 +22,7 @@ class TikTokVisibility(str, Enum):
FOLLOWERS = "followers"
@cost(*AYRSHARE_POST_COSTS)
class PostToTikTokBlock(Block):
"""Block for posting to TikTok with TikTok-specific options."""
@@ -113,14 +117,13 @@ class PostToTikTokBlock(Block):
)
async def run(
self, input_data: "PostToTikTokBlock.Input", *, user_id: str, **kwargs
self,
input_data: "PostToTikTokBlock.Input",
*,
credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
"""Post to TikTok with TikTok-specific validation and options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -235,7 +238,7 @@ class PostToTikTokBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
tiktok_options=tiktok_options if tiktok_options else None,
profile_key=profile_key.get_secret_value(),
profile_key=credentials.api_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

View File

@@ -1,16 +1,20 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client
@cost(*AYRSHARE_POST_COSTS)
class PostToXBlock(Block):
"""Block for posting to X / Twitter with Twitter-specific options."""
@@ -115,15 +119,10 @@ class PostToXBlock(Block):
self,
input_data: "PostToXBlock.Input",
*,
user_id: str,
credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
"""Post to X / Twitter with enhanced X-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -233,7 +232,7 @@ class PostToXBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
twitter_options=twitter_options if twitter_options else None,
profile_key=profile_key.get_secret_value(),
profile_key=credentials.api_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

View File

@@ -3,15 +3,18 @@ from typing import Any
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaOutput,
BlockType,
SchemaField,
cost,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._cost import AYRSHARE_POST_COSTS
from ._util import BaseAyrshareInput, create_ayrshare_client
class YouTubeVisibility(str, Enum):
@@ -20,6 +23,7 @@ class YouTubeVisibility(str, Enum):
UNLISTED = "unlisted"
@cost(*AYRSHARE_POST_COSTS)
class PostToYouTubeBlock(Block):
"""Block for posting to YouTube with YouTube-specific options."""
@@ -39,6 +43,14 @@ class PostToYouTubeBlock(Block):
advanced=False,
)
# YouTube is video-only; override the base default so the @cost filter
# selects the 5-credit video tier instead of the 2-credit image tier.
is_video: bool = SchemaField(
description="Whether the media is a video (always True for YouTube)",
default=True,
advanced=True,
)
# YouTube-specific required options
title: str = SchemaField(
description="Video title (max 100 chars, required). Cannot contain < or > characters.",
@@ -137,16 +149,10 @@ class PostToYouTubeBlock(Block):
self,
input_data: "PostToYouTubeBlock.Input",
*,
user_id: str,
credentials: APIKeyCredentials,
**kwargs,
) -> BlockOutput:
"""Post to YouTube with YouTube-specific validation and options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
@@ -302,7 +308,7 @@ class PostToYouTubeBlock(Block):
random_media_url=input_data.random_media_url,
notes=input_data.notes,
youtube_options=youtube_options,
profile_key=profile_key.get_secret_value(),
profile_key=credentials.api_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:

View File

@@ -8,17 +8,27 @@ from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockCost,
BlockCostType,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
CredentialsMetaInput,
SchemaField,
cost,
)
from ._api import MeetingBaasAPI
from ._config import baas
# Meeting BaaS charges $0.69/hour of recording. The Join block is the
# trigger that starts the recording session; the meeting itself runs out
# of band (we don't get duration back from the FetchMeetingData response
# we use). 30 cr ≈ $0.30 covers a median 30-minute meeting with margin.
# Interim until FetchMeetingData surfaces duration for post-flight
# reconciliation.
@cost(BlockCost(cost_type=BlockCostType.RUN, cost_amount=30))
class BaasBotJoinMeetingBlock(Block):
"""
Deploy a bot immediately or at a scheduled start_time to join and record a meeting.

View File

@@ -3,6 +3,6 @@ from backend.sdk import BlockCostType, ProviderBuilder
bannerbear = (
ProviderBuilder("bannerbear")
.with_api_key("BANNERBEAR_API_KEY", "Bannerbear API Key")
.with_base_cost(1, BlockCostType.RUN)
.with_base_cost(3, BlockCostType.RUN)
.build()
)

View File

@@ -19,6 +19,10 @@ class DataForSeoClient:
trusted_origins=["https://api.dataforseo.com"],
raise_for_status=False,
)
# USD cost reported by DataForSEO on the most recent successful call.
# Populated by keyword_suggestions / related_keywords so the caller
# can surface it via NodeExecutionStats.provider_cost for billing.
self.last_cost_usd: float = 0.0
def _get_headers(self) -> Dict[str, str]:
"""Generate the authorization header using Basic Auth."""
@@ -97,6 +101,9 @@ class DataForSeoClient:
if data.get("tasks") and len(data["tasks"]) > 0:
task = data["tasks"][0]
if task.get("status_code") == 20000: # Success code
# DataForSEO reports per-task USD cost; stash it so callers
# can populate NodeExecutionStats.provider_cost.
self.last_cost_usd = float(task.get("cost") or 0.0)
return task.get("result", [])
else:
error_msg = task.get("status_message", "Task failed")
@@ -174,6 +181,9 @@ class DataForSeoClient:
if data.get("tasks") and len(data["tasks"]) > 0:
task = data["tasks"][0]
if task.get("status_code") == 20000: # Success code
# DataForSEO reports per-task USD cost; stash it so callers
# can populate NodeExecutionStats.provider_cost.
self.last_cost_usd = float(task.get("cost") or 0.0)
return task.get("result", [])
else:
error_msg = task.get("status_message", "Task failed")

View File

@@ -12,6 +12,11 @@ dataforseo = (
password_env_var="DATAFORSEO_PASSWORD",
title="DataForSEO Credentials",
)
.with_base_cost(1, BlockCostType.RUN)
# DataForSEO reports USD cost per task (e.g. $0.001/keyword returned).
# DataForSeoClient stashes it on last_cost_usd; each block emits it via
# merge_stats so the COST_USD resolver bills against real spend.
# 1000 platform credits per USD → 1 credit per $0.001 (≈ 1 credit/
# returned keyword on the standard tier).
.with_base_cost(1000, BlockCostType.COST_USD)
.build()
)

View File

@@ -4,6 +4,7 @@ DataForSEO Google Keyword Suggestions block.
from typing import Any, Dict, List, Optional
from backend.data.model import NodeExecutionStats
from backend.sdk import (
Block,
BlockCategory,
@@ -110,8 +111,10 @@ class DataForSeoKeywordSuggestionsBlock(Block):
test_output=[
(
"suggestion",
lambda x: hasattr(x, "keyword")
and x.keyword == "digital marketing strategy",
lambda x: (
hasattr(x, "keyword")
and x.keyword == "digital marketing strategy"
),
),
("suggestions", lambda x: isinstance(x, list) and len(x) == 1),
("total_count", 1),
@@ -167,6 +170,16 @@ class DataForSeoKeywordSuggestionsBlock(Block):
results = await self._fetch_keyword_suggestions(client, input_data)
# DataForSEO reports per-task USD cost on the response. Feed it
# into NodeExecutionStats so the COST_USD resolver bills the
# real provider spend at reconciliation time.
self.merge_stats(
NodeExecutionStats(
provider_cost=client.last_cost_usd,
provider_cost_type="cost_usd",
)
)
# Process and format the results
suggestions = []
if results and len(results) > 0:

View File

@@ -4,6 +4,7 @@ DataForSEO Google Related Keywords block.
from typing import Any, Dict, List, Optional
from backend.data.model import NodeExecutionStats
from backend.sdk import (
Block,
BlockCategory,
@@ -177,6 +178,16 @@ class DataForSeoRelatedKeywordsBlock(Block):
results = await self._fetch_related_keywords(client, input_data)
# DataForSEO reports per-task USD cost on the response. Feed it
# into NodeExecutionStats so the COST_USD resolver bills the
# real provider spend at reconciliation time.
self.merge_stats(
NodeExecutionStats(
provider_cost=client.last_cost_usd,
provider_cost_type="cost_usd",
)
)
# Process and format the results
related_keywords = []
if results and len(results) > 0:

View File

@@ -11,6 +11,11 @@ exa = (
ProviderBuilder("exa")
.with_api_key("EXA_API_KEY", "Exa API Key")
.with_webhook_manager(ExaWebhookManager)
.with_base_cost(1, BlockCostType.RUN)
# Exa returns `cost_dollars.total` on every response and ExaSearchBlock
# (plus ~45 sibling blocks that share this provider config) already
# populates NodeExecutionStats.provider_cost with it. Bill 100 credits
# per USD (~$0.01/credit): cheap searches stay at 12 credits, a Deep
# Research run at $0.20 lands at 20 credits, matching provider spend.
.with_base_cost(100, BlockCostType.COST_USD)
.build()
)

View File

@@ -1,8 +1,14 @@
from backend.sdk import BlockCostType, ProviderBuilder
# Firecrawl bills in its own credits (1 credit ≈ $0.001). Each block's
# run() estimates USD spend from the operation (pages scraped, limit,
# credits_used on ExtractResponse) and merge_stats populates
# NodeExecutionStats.provider_cost before billing reconciliation. 1000
# platform credits per USD means 1 platform credit per Firecrawl credit
# — roughly matches our existing per-call tier for single-page scrape.
firecrawl = (
ProviderBuilder("firecrawl")
.with_api_key("FIRECRAWL_API_KEY", "Firecrawl API Key")
.with_base_cost(1, BlockCostType.RUN)
.with_base_cost(1000, BlockCostType.COST_USD)
.build()
)

View File

@@ -4,6 +4,7 @@ from firecrawl import FirecrawlApp
from firecrawl.v2.types import ScrapeOptions
from backend.blocks.firecrawl._api import ScrapeFormat
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -86,6 +87,14 @@ class FirecrawlCrawlBlock(Block):
wait_for=input_data.wait_for,
),
)
# Firecrawl bills 1 credit (~$0.001) per crawled page. crawl_result.data
# is the list of scraped pages actually returned.
pages = len(crawl_result.data) if crawl_result.data else 0
self.merge_stats(
NodeExecutionStats(
provider_cost=pages * 0.001, provider_cost_type="cost_usd"
)
)
yield "data", crawl_result.data
for data in crawl_result.data:

View File

@@ -2,25 +2,22 @@ from typing import Any
from firecrawl import FirecrawlApp
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockCost,
BlockCostType,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
CredentialsMetaInput,
SchemaField,
cost,
)
from backend.util.exceptions import BlockExecutionError
from ._config import firecrawl
@cost(BlockCost(2, BlockCostType.RUN))
class FirecrawlExtractBlock(Block):
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = firecrawl.credentials_field()
@@ -74,4 +71,13 @@ class FirecrawlExtractBlock(Block):
block_id=self.id,
) from e
# Firecrawl surfaces actual credit spend on extract responses
# (credits_used). 1 Firecrawl credit ≈ $0.001.
credits_used = getattr(extract_result, "credits_used", None) or 0
self.merge_stats(
NodeExecutionStats(
provider_cost=credits_used * 0.001,
provider_cost_type="cost_usd",
)
)
yield "data", extract_result.data

View File

@@ -2,6 +2,7 @@ from typing import Any
from firecrawl import FirecrawlApp
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -50,6 +51,10 @@ class FirecrawlMapWebsiteBlock(Block):
map_result = app.map(
url=input_data.url,
)
# Firecrawl bills 1 credit (~$0.001) per map request.
self.merge_stats(
NodeExecutionStats(provider_cost=0.001, provider_cost_type="cost_usd")
)
# Convert SearchResult objects to dicts
results_data = [

View File

@@ -3,6 +3,7 @@ from typing import Any
from firecrawl import FirecrawlApp
from backend.blocks.firecrawl._api import ScrapeFormat
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -81,6 +82,11 @@ class FirecrawlScrapeBlock(Block):
max_age=input_data.max_age,
wait_for=input_data.wait_for,
)
# Firecrawl bills 1 credit (~$0.001) per scraped page; scrape is a
# single-page operation.
self.merge_stats(
NodeExecutionStats(provider_cost=0.001, provider_cost_type="cost_usd")
)
yield "data", scrape_result
for f in input_data.formats:

View File

@@ -4,6 +4,7 @@ from firecrawl import FirecrawlApp
from firecrawl.v2.types import ScrapeOptions
from backend.blocks.firecrawl._api import ScrapeFormat
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -68,6 +69,17 @@ class FirecrawlSearchBlock(Block):
wait_for=input_data.wait_for,
),
)
# Firecrawl bills per returned web result (~1 credit each). The
# SearchResponse structure exposes `.web` when scrape_options was
# requested; fall back to `limit` as an upper bound estimate.
web_results = getattr(scrape_result, "web", None) or []
billed_units = max(len(web_results), 1)
self.merge_stats(
NodeExecutionStats(
provider_cost=billed_units * 0.001,
provider_cost_type="cost_usd",
)
)
yield "data", scrape_result
if hasattr(scrape_result, "web") and scrape_result.web:
for site in scrape_result.web:

View File

@@ -133,10 +133,21 @@ def GoogleDriveFileField(
if allowed_mime_types:
picker_config["allowed_mime_types"] = list(allowed_mime_types)
agent_builder_hint = (
"At runtime, feed this from an AgentGoogleDriveFileInputBlock with "
"matching allowed_views. NEVER hardcode a file ID in input_default "
"(including one parsed from a Drive URL the user pasted in chat) — "
"only the picker attaches the _credentials_id needed for auth."
)
return SchemaField(
default=None,
title=title,
description=description,
description=(
f"{description.rstrip('.')}. {agent_builder_hint}"
if description
else agent_builder_hint
),
placeholder=placeholder or "Select from Google Drive",
# Use google-drive-picker format so frontend renders existing component
format="google-drive-picker",

View File

@@ -0,0 +1,129 @@
"""Edge-case tests for Google Sheets block credential handling.
These pin the contract for the systemic auto-credential None-guard in
``Block._execute()``: any block with an auto-credential field (via
``GoogleDriveFileField`` etc.) that's called without resolved
credentials must surface a clean, user-facing ``BlockExecutionError``
— never a wrapped ``TypeError`` (missing required kwarg) or
``AttributeError`` deep in the provider SDK.
"""
import pytest
from backend.blocks.google.sheets import GoogleSheetsReadBlock
from backend.util.exceptions import BlockExecutionError
@pytest.mark.asyncio
async def test_sheets_read_missing_credentials_yields_clean_error():
"""Valid spreadsheet but no resolved credentials -> the systemic
None-guard in ``Block._execute()`` yields a ``Missing credentials``
error before ``run()`` is entered."""
block = GoogleSheetsReadBlock()
input_data = {
"spreadsheet": {
"id": "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms",
"name": "Test Spreadsheet",
"mimeType": "application/vnd.google-apps.spreadsheet",
},
"range": "Sheet1!A1:B2",
}
with pytest.raises(BlockExecutionError, match="Missing credentials"):
async for _ in block.execute(input_data):
pass
@pytest.mark.asyncio
async def test_sheets_read_no_spreadsheet_still_hits_credentials_guard():
"""When neither spreadsheet nor credentials are present, the
credentials guard fires first (it runs before we hand off to
``run()``). The user-facing message should still be the clean
``Missing credentials`` one, not an opaque ``TypeError``."""
block = GoogleSheetsReadBlock()
input_data = {"range": "Sheet1!A1:B2"} # no spreadsheet, no credentials
with pytest.raises(BlockExecutionError, match="Missing credentials"):
async for _ in block.execute(input_data):
pass
@pytest.mark.asyncio
async def test_sheets_read_upstream_chained_value_skips_guard(mocker):
"""A spreadsheet value chained in from an upstream input block (e.g.
``AgentGoogleDriveFileInputBlock``) carries a resolved
``_credentials_id`` that ``_acquire_auto_credentials`` didn't have
visibility into at prep time. The systemic None-guard must NOT
preempt run() in that case — otherwise every chained Drive-picker
pattern crashes with a bogus ``Missing credentials`` error.
We short-circuit past the guard by patching the Google API client
build; any error that escapes from run() is fine as long as the
``Missing credentials`` message never surfaces."""
# Patch out the real Google Sheets client build so we don't hit the
# network and can detect we reached the provider SDK.
mocker.patch(
"backend.blocks.google.sheets.build",
side_effect=RuntimeError("api-boundary-reached"),
)
block = GoogleSheetsReadBlock()
input_data = {
"spreadsheet": {
"_credentials_id": "upstream-chained-cred-id",
"id": "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms",
"name": "Upstream-chained sheet",
"mimeType": "application/vnd.google-apps.spreadsheet",
},
"range": "Sheet1!A1:B2",
}
with pytest.raises(Exception) as exc_info:
async for _ in block.execute(input_data):
pass
# The guard should skip (chained data present) and let us reach run(),
# which then hits the patched provider-SDK boundary. A "Missing
# credentials" error here would mean the None-guard broke the
# documented AgentGoogleDriveFileInputBlock chaining pattern.
assert "Missing credentials" not in str(exc_info.value)
@pytest.mark.asyncio
async def test_sheets_read_upstream_chained_with_explicit_none_cred_id_skips_guard(
mocker,
):
"""Sentry HIGH regression (thread PRRT_kwDOJKSTjM58sJfA): the
documented chained-upstream pattern ships the spreadsheet dict with
``_credentials_id=None`` — the executor fills in the resolved id
between prep time and ``run()``. The previous ``_base.py`` guard
used ``field_value.get("_credentials_id")`` and treated the falsy
``None`` value as "missing", raising ``BlockExecutionError`` on
every chained graph.
Pin the contract: the presence of the ``_credentials_id`` key — not
its truthiness — is what signals "trust the skip". A dict with
``_credentials_id: None`` must not preempt run()."""
mocker.patch(
"backend.blocks.google.sheets.build",
side_effect=RuntimeError("api-boundary-reached"),
)
block = GoogleSheetsReadBlock()
input_data = {
"spreadsheet": {
"_credentials_id": None, # explicit None — chained-upstream shape
"id": "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms",
"name": "Upstream-chained sheet (None cred_id)",
"mimeType": "application/vnd.google-apps.spreadsheet",
},
"range": "Sheet1!A1:B2",
}
with pytest.raises(Exception) as exc_info:
async for _ in block.execute(input_data):
pass
# The guard must not raise "Missing credentials" for this shape.
# We expect to reach run() and hit the patched provider-SDK boundary.
assert "Missing credentials" not in str(exc_info.value)

View File

@@ -737,7 +737,22 @@ class AgentGoogleDriveFileInputBlock(AgentInputBlock):
)
super().__init__(
id="d3b32f15-6fd7-40e3-be52-e083f51b19a2",
description="Block for selecting a file from Google Drive.",
description=(
"Agent-level input for a Google Drive file. REQUIRED for any "
"agent that reads or writes a Drive file (Sheets, Docs, "
"Slides, or generic Drive) — the picker is the only source "
"of the _credentials_id needed at runtime, so consuming "
"blocks cannot receive a hardcoded ID. Set allowed_views to "
'match the consumer: ["SPREADSHEETS"] for Sheets, '
'["DOCUMENTS"] for Docs, ["PRESENTATIONS"] for Slides '
"(leave default for generic Drive). Wire `result` to the "
"consumer block's Drive field and leave that field unset in "
"the consumer's input_default. Example link to a Google "
'Sheets block: {"source_name": "result", "sink_name": '
'"spreadsheet"} (use "document" for Docs, "presentation" '
"for Slides). Use one input block per distinct file; "
"multiple consumers of the same file share it."
),
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentGoogleDriveFileInputBlock.Input,
output_schema=AgentGoogleDriveFileInputBlock.Output,

View File

@@ -15,7 +15,7 @@ from backend.blocks.jina._auth import (
JinaCredentialsInput,
)
from backend.blocks.search import GetRequest
from backend.data.model import SchemaField
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host
@@ -70,6 +70,13 @@ class SearchTheWebBlock(Block, GetRequest):
block_id=self.id,
) from e
# Jina Reader Search: $0.01/query on the paid tier. Fixed per-query
# cost; routed through COST_USD so the platform cost log records
# real USD spend (costMicrodollars) alongside the credit charge.
self.merge_stats(
NodeExecutionStats(provider_cost=0.01, provider_cost_type="cost_usd")
)
# Output the search results
yield "results", results
@@ -128,10 +135,16 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
try:
content = await self.get_request(url, json=False, headers=headers)
except HTTPClientError as e:
yield "error", f"Client error ({e.status_code}) fetching {input_data.url}: {e}"
yield (
"error",
f"Client error ({e.status_code}) fetching {input_data.url}: {e}",
)
return
except HTTPServerError as e:
yield "error", f"Server error ({e.status_code}) fetching {input_data.url}: {e}"
yield (
"error",
f"Server error ({e.status_code}) fetching {input_data.url}: {e}",
)
return
except Exception as e:
yield "error", f"Failed to fetch {input_data.url}: {e}"

View File

@@ -206,6 +206,10 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
GROK_4_20_MULTI_AGENT = "x-ai/grok-4.20-multi-agent"
GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
KIMI_K2 = "moonshotai/kimi-k2"
KIMI_K2_0905 = "moonshotai/kimi-k2-0905"
KIMI_K2_5 = "moonshotai/kimi-k2.5"
KIMI_K2_6 = "moonshotai/kimi-k2.6"
KIMI_K2_THINKING = "moonshotai/kimi-k2-thinking"
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
QWEN3_CODER = "qwen/qwen3-coder"
# Z.ai (Zhipu) models
@@ -646,6 +650,24 @@ MODEL_METADATA = {
LlmModel.KIMI_K2: ModelMetadata(
"open_router", 131000, 131000, "Kimi K2", "OpenRouter", "Moonshot AI", 1
),
LlmModel.KIMI_K2_0905: ModelMetadata(
"open_router", 262144, 262144, "Kimi K2 0905", "OpenRouter", "Moonshot AI", 1
),
LlmModel.KIMI_K2_5: ModelMetadata(
"open_router", 262144, 262144, "Kimi K2.5", "OpenRouter", "Moonshot AI", 1
),
LlmModel.KIMI_K2_6: ModelMetadata(
"open_router", 262144, 262144, "Kimi K2.6", "OpenRouter", "Moonshot AI", 2
),
LlmModel.KIMI_K2_THINKING: ModelMetadata(
"open_router",
262144,
262144,
"Kimi K2 Thinking",
"OpenRouter",
"Moonshot AI",
2,
),
LlmModel.QWEN3_235B_A22B_THINKING: ModelMetadata(
"open_router",
262144,

View File

@@ -376,20 +376,12 @@ class OrchestratorBlock(Block):
re-raise carve-out for this reason.
"""
def extra_runtime_cost(self, execution_stats: NodeExecutionStats) -> int:
"""Charge one extra runtime cost per LLM call beyond the first.
In agent mode each iteration makes one LLM call. The first is already
covered by charge_usage(); this returns the number of additional
credits so the executor can bill the remaining calls post-completion.
SDK-mode exemption: when the block runs via _execute_tools_sdk_mode,
the SDK manages its own conversation loop and only exposes aggregate
usage. We hardcode llm_call_count=1 there (the SDK does not report a
per-turn call count), so this method always returns 0 for SDK-mode
executions. Per-iteration billing does not apply to SDK mode.
"""
return max(0, execution_stats.llm_call_count - 1)
# OrchestratorBlock bills via BlockCostType.TOKENS + compute_token_credits,
# which aggregates input_token_count / output_token_count / cache_read /
# cache_creation across every LLM iteration into one post-flight charge.
# The per-iteration flat-fee path (Block.extra_runtime_cost →
# charge_extra_runtime_cost) would double-bill the same tokens, so
# OrchestratorBlock deliberately inherits the base-class no-op default.
# MCP server name used by the Claude Code SDK execution mode. Keep in sync
# with _create_graph_mcp_server and the MCP_PREFIX derivation in _execute_tools_sdk_mode.
@@ -1189,10 +1181,14 @@ class OrchestratorBlock(Block):
not execution_params.execution_context.dry_run
and tool_node_stats.error is None
):
# Charge the sub-block for telemetry / wallet debit. The
# return value is intentionally discarded: on_node_execution
# above ran the sub-block against this graph's own
# graph_stats_pair (manager.py:659-668), so its cost already
# lands in graph_stats.cost on the sub-block's completion.
# Re-merging here would double-count in telemetry / UI / audit.
try:
tool_cost, _ = await execution_processor.charge_node_usage(
node_exec_entry,
)
await execution_processor.charge_node_usage(node_exec_entry)
except InsufficientBalanceError:
# IBE must propagate — see OrchestratorBlock class docstring.
# Log the billing failure here so the discarded tool result
@@ -1214,9 +1210,6 @@ class OrchestratorBlock(Block):
"tool execution was successful",
sink_node_id,
)
tool_cost = 0
if tool_cost > 0:
self.merge_stats(NodeExecutionStats(extra_cost=tool_cost))
# Get outputs from database after execution completes using database manager client
node_outputs = await db_client.get_execution_outputs_by_node_exec_id(

View File

@@ -13,6 +13,7 @@ from backend.blocks._base import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.llm import extract_openrouter_cost
from backend.data.block import BlockInput
from backend.data.model import (
APIKeyCredentials,
@@ -98,14 +99,23 @@ class PerplexityBlock(Block):
return _sanitize_perplexity_model(v)
@classmethod
def validate_data(cls, data: BlockInput) -> str | None:
def validate_data(
cls,
data: BlockInput,
exclude_fields: set[str] | None = None,
) -> str | None:
"""Sanitize the model field before JSON schema validation so that
invalid values are replaced with the default instead of raising a
BlockInputError."""
BlockInputError.
Signature matches ``BlockSchema.validate_data`` (including the
optional ``exclude_fields`` kwarg added for dry-run credential
bypass) so Pyright doesn't flag this as an incompatible override.
"""
model_value = data.get("model")
if model_value is not None:
data["model"] = _sanitize_perplexity_model(model_value).value
return super().validate_data(data)
return super().validate_data(data, exclude_fields=exclude_fields)
system_prompt: str = SchemaField(
title="System Prompt",
@@ -230,12 +240,24 @@ class PerplexityBlock(Block):
if "message" in choice and "annotations" in choice["message"]:
annotations = choice["message"]["annotations"]
# Update execution stats
# Update execution stats. ``execution_stats`` is instance state,
# so always reset token counters — a response without ``usage``
# must not leak a previous run's tokens into ``PlatformCostLog``.
self.execution_stats.input_token_count = 0
self.execution_stats.output_token_count = 0
if response.usage:
self.execution_stats.input_token_count = response.usage.prompt_tokens
self.execution_stats.output_token_count = (
response.usage.completion_tokens
)
# OpenRouter's ``x-total-cost`` response header carries the real
# per-request USD cost. Piping it into ``provider_cost`` lets the
# direct-run ``PlatformCostLog`` flow
# (``executor.cost_tracking::log_system_credential_cost``) record
# the actual operator-side spend instead of inferring from tokens.
# Always overwrite — ``execution_stats`` is instance state, so a
# response without the header must not reuse a previous run's cost.
self.execution_stats.provider_cost = extract_openrouter_cost(response)
return {"response": response_content, "annotations": annotations or []}

View File

@@ -1,8 +1,12 @@
from backend.sdk import BlockCostType, ProviderBuilder
# 1 credit per 3 walltime seconds. Block walltime proxies for the
# Browserbase session lifetime + the LLM call it issues. Interim until
# the block emits real provider_cost (USD) via merge_stats and migrates
# to COST_USD.
stagehand = (
ProviderBuilder("stagehand")
.with_api_key("STAGEHAND_API_KEY", "Stagehand API Key")
.with_base_cost(1, BlockCostType.RUN)
.with_base_cost(1, BlockCostType.SECOND, cost_divisor=3)
.build()
)

View File

@@ -21,7 +21,7 @@ from backend.blocks.zerobounce._auth import (
ZeroBounceCredentials,
ZeroBounceCredentialsInput,
)
from backend.data.model import CredentialsField, SchemaField
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
class Response(BaseModel):
@@ -140,20 +140,22 @@ class ValidateEmailsBlock(Block):
)
],
test_mock={
"validate_email": lambda email, ip_address, credentials: ZBValidateResponse(
data={
"address": email,
"status": ZBValidateStatus.valid,
"sub_status": ZBValidateSubStatus.allowed,
"account": "test",
"domain": "test.com",
"did_you_mean": None,
"domain_age_days": None,
"free_email": False,
"mx_found": False,
"mx_record": None,
"smtp_provider": None,
}
"validate_email": lambda email, ip_address, credentials: (
ZBValidateResponse(
data={
"address": email,
"status": ZBValidateStatus.valid,
"sub_status": ZBValidateSubStatus.allowed,
"account": "test",
"domain": "test.com",
"did_you_mean": None,
"domain_age_days": None,
"free_email": False,
"mx_found": False,
"mx_record": None,
"smtp_provider": None,
}
)
)
},
)
@@ -176,6 +178,13 @@ class ValidateEmailsBlock(Block):
input_data.email, input_data.ip_address, credentials
)
# ZeroBounce bills $0.008 per validated email on the paid tier.
# Routed through COST_USD so platform cost telemetry captures real
# USD spend; the resolver still bills 2 credits per call.
self.merge_stats(
NodeExecutionStats(provider_cost=0.008, provider_cost_type="cost_usd")
)
response_model = Response(**response.__dict__)
yield "response", response_model

View File

@@ -0,0 +1,364 @@
"""Extended-thinking wire support for the baseline (OpenRouter) path.
OpenRouter routes that support extended thinking (Anthropic Claude and
Moonshot Kimi today) expose reasoning through non-OpenAI extension fields
that the OpenAI Python SDK doesn't model:
* ``reasoning`` (legacy string) — enabled by ``include_reasoning: true``.
* ``reasoning_content`` — DeepSeek / some OpenRouter routes.
* ``reasoning_details`` — structured list shipped with the unified
``reasoning`` request param.
This module keeps the wire-level concerns in one place:
* :class:`OpenRouterDeltaExtension` validates the extension dict pulled off
``ChoiceDelta.model_extra`` into typed pydantic models — no ``getattr`` +
``isinstance`` duck-typing at the call site.
* :class:`BaselineReasoningEmitter` owns the reasoning block lifecycle for
one streaming round and emits ``StreamReasoning*`` events so the caller
only has to plumb the events into its pending queue.
* :func:`reasoning_extra_body` builds the ``extra_body`` fragment for the
OpenAI client call. Returns ``None`` for routes without reasoning
support (see :func:`_is_reasoning_route`).
"""
from __future__ import annotations
import logging
import time
import uuid
from typing import Any
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from pydantic import BaseModel, ConfigDict, Field, ValidationError
from backend.copilot.model import ChatMessage
from backend.copilot.response_model import (
StreamBaseResponse,
StreamReasoningDelta,
StreamReasoningEnd,
StreamReasoningStart,
)
logger = logging.getLogger(__name__)
_VISIBLE_REASONING_TYPES = frozenset({"reasoning.text", "reasoning.summary"})
# Coalescing thresholds for ``StreamReasoningDelta`` emission. OpenRouter's
# Kimi K2.6 endpoint tokenises reasoning at a much finer grain than Anthropic
# (~4,700 deltas per turn in one observed session, vs ~28 for Sonnet); without
# coalescing, every chunk is one Redis ``xadd`` + one SSE frame + one React
# re-render of the non-virtualised chat list, which paint-storms the browser
# main thread and freezes the UI. Batching into ~64-char / ~50 ms windows
# cuts the event rate ~150x while staying snappy enough that the Reasoning
# collapse still feels live (well under the ~100 ms perceptual threshold).
# Per-delta persistence to ``session.messages`` stays granular — we only
# coalesce the *wire* emission.
_COALESCE_MIN_CHARS = 64
_COALESCE_MAX_INTERVAL_MS = 50.0
class ReasoningDetail(BaseModel):
"""One entry in OpenRouter's ``reasoning_details`` list.
OpenRouter ships ``type: "reasoning.text"`` / ``"reasoning.summary"`` /
``"reasoning.encrypted"`` entries. Only the first two carry
user-visible text; encrypted entries are opaque and omitted from the
rendered collapse. Unknown future types are tolerated (``extra="ignore"``)
so an upstream addition doesn't crash the stream — but their ``text`` /
``summary`` fields are NOT surfaced because they may carry provider
metadata rather than user-visible reasoning (see
:attr:`visible_text`).
"""
model_config = ConfigDict(extra="ignore")
type: str | None = None
text: str | None = None
summary: str | None = None
@property
def visible_text(self) -> str:
"""Return the human-readable text for this entry, or ``""``.
Only entries with a recognised reasoning type (``reasoning.text`` /
``reasoning.summary``) surface text; unknown or encrypted types
return an empty string even if they carry a ``text`` /
``summary`` field, to guard against future provider metadata
being rendered as reasoning in the UI. Entries missing a
``type`` are treated as text (pre-``reasoning_details`` OpenRouter
payloads omit the field).
"""
if self.type is not None and self.type not in _VISIBLE_REASONING_TYPES:
return ""
return self.text or self.summary or ""
class OpenRouterDeltaExtension(BaseModel):
"""Non-OpenAI fields OpenRouter adds to streaming deltas.
Instantiate via :meth:`from_delta` which pulls the extension dict off
``ChoiceDelta.model_extra`` (where pydantic v2 stashes fields that
aren't part of the declared schema) and validates it through this
model. That keeps the parser honest — malformed entries surface as
validation errors rather than silent ``None``-coalesce bugs — and
avoids the ``getattr`` + ``isinstance`` duck-typing the earlier inline
extractor relied on.
"""
model_config = ConfigDict(extra="ignore")
reasoning: str | None = None
reasoning_content: str | None = None
reasoning_details: list[ReasoningDetail] = Field(default_factory=list)
@classmethod
def from_delta(cls, delta: ChoiceDelta) -> "OpenRouterDeltaExtension":
"""Build an extension view from ``delta.model_extra``.
Malformed provider payloads (e.g. ``reasoning_details`` shipped as
a string rather than a list) surface as a ``ValidationError`` which
is logged and swallowed — returning an empty extension so the rest
of the stream (valid text / tool calls) keeps flowing. An optional
feature's corrupted wire data must never abort the whole stream.
"""
try:
return cls.model_validate(delta.model_extra or {})
except ValidationError as exc:
logger.warning(
"[Baseline] Dropping malformed OpenRouter reasoning payload: %s",
exc,
)
return cls()
def visible_text(self) -> str:
"""Concatenated reasoning text, pulled from whichever channel is set.
Priority: the legacy ``reasoning`` string, then DeepSeek's
``reasoning_content``, then the concatenation of text-bearing
entries in ``reasoning_details``. Only one channel is set per
provider in practice; the priority order just makes the fallback
deterministic if a provider ever emits multiple.
"""
if self.reasoning:
return self.reasoning
if self.reasoning_content:
return self.reasoning_content
return "".join(d.visible_text for d in self.reasoning_details)
def _is_reasoning_route(model: str) -> bool:
"""Return True when the route supports OpenRouter's ``reasoning`` extension.
OpenRouter exposes reasoning tokens via a unified ``reasoning`` request
param that works on any provider that supports extended thinking —
currently Anthropic (Claude Opus / Sonnet) and Moonshot (Kimi K2.6 +
kimi-k2-thinking) advertise it in their ``supported_parameters``.
Other providers silently drop the field, but we skip it anyway to keep
the payload tight and avoid confusing cache diagnostics.
Kept separate from :func:`backend.copilot.baseline.service._is_anthropic_model`
because ``cache_control`` is strictly Anthropic-specific (Moonshot does
its own auto-caching), so the two gates must not conflate.
Both the Claude and Kimi matches are anchored to the provider
prefix (or to a bare model id with no prefix at all) to avoid
substring false positives — a custom ``some-other-provider/claude-mock``
or ``provider/hakimi-large`` configured via
``CHAT_FAST_STANDARD_MODEL`` must NOT inherit the reasoning
extra_body and take a 400 from its upstream. Recognised shapes:
* Claude — ``anthropic/`` or ``anthropic.`` provider prefix, or a
bare ``claude-`` model id with no provider prefix
(``claude-opus-4.7``, ``anthropic/claude-sonnet-4-6``,
``anthropic.claude-3-5-sonnet``). A non-Anthropic prefix like
``someprovider/claude-mock`` is rejected on purpose.
* Kimi — ``moonshotai/`` provider prefix, or a ``kimi-`` model id
with no provider prefix (``kimi-k2.6``,
``moonshotai/kimi-k2-thinking``). Like Claude, a non-Moonshot
prefix is rejected — exception: ``openrouter/kimi-k2.6`` stays
recognised because ``openrouter/`` is how we route to Moonshot
today and changing that would be a behaviour regression for
existing deployments.
"""
lowered = model.lower()
if lowered.startswith(("anthropic/", "anthropic.")):
return True
if lowered.startswith("moonshotai/"):
return True
# ``openrouter/`` historically routes to whatever the default
# upstream for the model is — for kimi that's Moonshot, so accept
# ``openrouter/kimi-...`` here. Other ``openrouter/`` models
# (e.g. ``openrouter/auto``) fall through to the no-prefix check
# below and are rejected unless they start with ``claude-`` /
# ``kimi-`` after the slash, which no real OpenRouter route does.
if lowered.startswith("openrouter/kimi-"):
return True
if "/" in lowered:
# Any other provider prefix is a custom / non-Anthropic /
# non-Moonshot route and must not opt into reasoning. This
# blocks substring false positives like
# ``some-provider/claude-mock-v1`` or ``other/kimi-pro``.
return False
# No provider prefix — accept bare ``claude-*`` and ``kimi-*`` ids
# so direct CLI configs (``claude-3-5-sonnet-20241022``,
# ``kimi-k2-instruct``) keep working.
return lowered.startswith("claude-") or lowered.startswith("kimi-")
def reasoning_extra_body(model: str, max_thinking_tokens: int) -> dict[str, Any] | None:
"""Build the ``extra_body["reasoning"]`` fragment for the OpenAI client.
Returns ``None`` for non-reasoning routes and for
``max_thinking_tokens <= 0`` (operator kill switch).
"""
if not _is_reasoning_route(model) or max_thinking_tokens <= 0:
return None
return {"reasoning": {"max_tokens": max_thinking_tokens}}
class BaselineReasoningEmitter:
"""Owns the reasoning block lifecycle for one streaming round.
Two concerns live here, both driven by the same state machine:
1. **Wire events.** The AI SDK v6 wire format pairs every
``reasoning-start`` with a matching ``reasoning-end`` and treats
reasoning / text / tool-use as distinct UI parts that must not
interleave.
2. **Session persistence.** ``ChatMessage(role="reasoning")`` rows in
``session.messages`` are what
``convertChatSessionToUiMessages.ts`` folds into the assistant
bubble as ``{type: "reasoning"}`` UI parts on reload and on
``useHydrateOnStreamEnd`` swaps. Without them the live-streamed
reasoning parts get overwritten by the hydrated (reasoning-less)
message list the moment the stream ends. Mirrors the SDK path's
``acc.reasoning_response`` pattern so both routes render the same
way on reload.
Pass ``session_messages`` to enable persistence; omit for pure
wire-emission (tests, scratch callers). On first reasoning delta a
fresh ``ChatMessage(role="reasoning")`` is appended and mutated
in-place as further deltas arrive; :meth:`close` drops the reference
but leaves the appended row intact.
``render_in_ui=False`` suppresses only the live wire events
(``StreamReasoning*``); the ``role='reasoning'`` persistence row is
still appended so ``convertChatSessionToUiMessages.ts`` can hydrate
the reasoning bubble on reload. The state machine advances
identically either way.
"""
def __init__(
self,
session_messages: list[ChatMessage] | None = None,
*,
coalesce_min_chars: int = _COALESCE_MIN_CHARS,
coalesce_max_interval_ms: float = _COALESCE_MAX_INTERVAL_MS,
render_in_ui: bool = True,
) -> None:
self._block_id: str = str(uuid.uuid4())
self._open: bool = False
self._session_messages = session_messages
self._current_row: ChatMessage | None = None
# Coalescing state — tests can disable (``=0``) for deterministic
# event assertions.
self._coalesce_min_chars = coalesce_min_chars
self._coalesce_max_interval_ms = coalesce_max_interval_ms
self._pending_delta: str = ""
self._last_flush_monotonic: float = 0.0
self._render_in_ui = render_in_ui
@property
def is_open(self) -> bool:
return self._open
def on_delta(self, delta: ChoiceDelta) -> list[StreamBaseResponse]:
"""Return events for the reasoning text carried by *delta*.
Empty list when the chunk carries no reasoning payload, so this is
safe to call on every chunk without guarding at the call site.
Persistence (when a session message list is attached) stays
per-delta so the DB row's content always equals the concatenation
of wire deltas at every chunk boundary, independent of the
coalescing window. Only the wire emission is batched.
"""
ext = OpenRouterDeltaExtension.from_delta(delta)
text = ext.visible_text()
if not text:
return []
events: list[StreamBaseResponse] = []
# First reasoning text in this block — emit Start + the first Delta
# atomically so the frontend Reasoning collapse renders immediately
# rather than waiting for the coalesce window to elapse. Subsequent
# chunks buffer into ``_pending_delta`` and only flush when the
# char/time thresholds trip.
# Sample the monotonic clock exactly once per chunk — at ~4,700
# chunks per turn, folding the two calls into one cuts ~4,700
# syscalls off the hot path without changing semantics.
now = time.monotonic()
if not self._open:
if self._render_in_ui:
events.append(StreamReasoningStart(id=self._block_id))
events.append(StreamReasoningDelta(id=self._block_id, delta=text))
self._open = True
self._last_flush_monotonic = now
if self._session_messages is not None:
self._current_row = ChatMessage(role="reasoning", content=text)
self._session_messages.append(self._current_row)
return events
if self._current_row is not None:
self._current_row.content = (self._current_row.content or "") + text
self._pending_delta += text
if self._should_flush_pending(now):
if self._render_in_ui:
events.append(
StreamReasoningDelta(id=self._block_id, delta=self._pending_delta)
)
self._pending_delta = ""
self._last_flush_monotonic = now
return events
def _should_flush_pending(self, now: float) -> bool:
"""Return True when the accumulated delta should be emitted now.
*now* is the monotonic timestamp sampled by the caller so the
clock is read at most once per chunk (the flush-timestamp update
reuses the same value).
"""
if not self._pending_delta:
return False
if len(self._pending_delta) >= self._coalesce_min_chars:
return True
elapsed_ms = (now - self._last_flush_monotonic) * 1000.0
return elapsed_ms >= self._coalesce_max_interval_ms
def close(self) -> list[StreamBaseResponse]:
"""Emit ``StreamReasoningEnd`` for the open block (if any) and rotate.
Idempotent — returns ``[]`` when no block is open. Drains any
still-buffered delta first so the frontend never loses tail text
from the coalesce window. The id rotation guarantees the next
reasoning block starts with a fresh id rather than reusing one
already closed on the wire. The persisted row is not removed —
it stays in ``session_messages`` as the durable record of what
was reasoned.
"""
if not self._open:
return []
events: list[StreamBaseResponse] = []
if self._render_in_ui:
if self._pending_delta:
events.append(
StreamReasoningDelta(id=self._block_id, delta=self._pending_delta)
)
events.append(StreamReasoningEnd(id=self._block_id))
self._pending_delta = ""
self._open = False
self._block_id = str(uuid.uuid4())
self._current_row = None
return events

View File

@@ -0,0 +1,514 @@
"""Tests for the baseline reasoning extension module.
Covers the typed OpenRouter delta parser, the stateful emitter, and the
``extra_body`` builder. The emitter is tested against real
``ChoiceDelta`` pydantic instances so the ``model_extra`` plumbing the
parser relies on is exercised end-to-end.
"""
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from backend.copilot.baseline.reasoning import (
BaselineReasoningEmitter,
OpenRouterDeltaExtension,
ReasoningDetail,
_is_reasoning_route,
reasoning_extra_body,
)
from backend.copilot.model import ChatMessage
from backend.copilot.response_model import (
StreamReasoningDelta,
StreamReasoningEnd,
StreamReasoningStart,
)
def _delta(**extra) -> ChoiceDelta:
"""Build a ChoiceDelta with the given extension fields on ``model_extra``."""
return ChoiceDelta.model_validate({"role": "assistant", **extra})
class TestReasoningDetail:
def test_visible_text_prefers_text(self):
d = ReasoningDetail(type="reasoning.text", text="hi", summary="ignored")
assert d.visible_text == "hi"
def test_visible_text_falls_back_to_summary(self):
d = ReasoningDetail(type="reasoning.summary", summary="tldr")
assert d.visible_text == "tldr"
def test_visible_text_empty_for_encrypted(self):
d = ReasoningDetail(type="reasoning.encrypted")
assert d.visible_text == ""
def test_unknown_fields_are_ignored(self):
# OpenRouter may add new fields in future payloads — they shouldn't
# cause validation errors.
d = ReasoningDetail.model_validate(
{"type": "reasoning.future", "text": "x", "signature": "opaque"}
)
assert d.text == "x"
def test_visible_text_empty_for_unknown_type(self):
# Unknown types may carry provider metadata that must not render as
# user-visible reasoning — regardless of whether a text/summary is
# present. Only ``reasoning.text`` / ``reasoning.summary`` surface.
d = ReasoningDetail(type="reasoning.future", text="leaked metadata")
assert d.visible_text == ""
def test_visible_text_surfaces_text_when_type_missing(self):
# Pre-``reasoning_details`` OpenRouter payloads omit ``type`` — treat
# them as text so we don't regress the legacy structured shape.
d = ReasoningDetail(text="plain")
assert d.visible_text == "plain"
class TestOpenRouterDeltaExtension:
def test_from_delta_reads_model_extra(self):
delta = _delta(reasoning="step one")
ext = OpenRouterDeltaExtension.from_delta(delta)
assert ext.reasoning == "step one"
def test_visible_text_legacy_string(self):
ext = OpenRouterDeltaExtension(reasoning="plain text")
assert ext.visible_text() == "plain text"
def test_visible_text_deepseek_alias(self):
ext = OpenRouterDeltaExtension(reasoning_content="alt channel")
assert ext.visible_text() == "alt channel"
def test_visible_text_structured_details_concat(self):
ext = OpenRouterDeltaExtension(
reasoning_details=[
ReasoningDetail(type="reasoning.text", text="hello "),
ReasoningDetail(type="reasoning.text", text="world"),
]
)
assert ext.visible_text() == "hello world"
def test_visible_text_skips_encrypted(self):
ext = OpenRouterDeltaExtension(
reasoning_details=[
ReasoningDetail(type="reasoning.encrypted"),
ReasoningDetail(type="reasoning.text", text="visible"),
]
)
assert ext.visible_text() == "visible"
def test_visible_text_empty_when_all_channels_blank(self):
ext = OpenRouterDeltaExtension()
assert ext.visible_text() == ""
def test_empty_delta_produces_empty_extension(self):
ext = OpenRouterDeltaExtension.from_delta(_delta())
assert ext.reasoning is None
assert ext.reasoning_content is None
assert ext.reasoning_details == []
def test_malformed_reasoning_payload_logged_and_swallowed(self, caplog):
# A malformed payload (e.g. reasoning_details shipped as a string
# rather than a list) must not abort the stream — log it and
# return an empty extension so valid text/tool events keep flowing.
# A plain mock is used here because ``from_delta`` only reads
# ``delta.model_extra`` — avoids reaching into pydantic internals
# (``__pydantic_extra__``) that could be renamed across versions.
from unittest.mock import MagicMock
delta = MagicMock(spec=ChoiceDelta)
delta.model_extra = {"reasoning_details": "not a list"}
with caplog.at_level("WARNING"):
ext = OpenRouterDeltaExtension.from_delta(delta)
assert ext.reasoning_details == []
assert ext.visible_text() == ""
assert any("malformed" in r.message.lower() for r in caplog.records)
def test_unknown_typed_entry_with_text_is_not_surfaced(self):
# Regression: the legacy extractor emitted any entry with a
# ``text`` or ``summary`` field. The typed parser now filters on
# the recognised types so future provider metadata can't leak
# into the reasoning collapse.
ext = OpenRouterDeltaExtension(
reasoning_details=[
ReasoningDetail(type="reasoning.future", text="provider metadata"),
ReasoningDetail(type="reasoning.text", text="real"),
]
)
assert ext.visible_text() == "real"
class TestIsReasoningRoute:
def test_anthropic_routes(self):
assert _is_reasoning_route("anthropic/claude-sonnet-4-6")
assert _is_reasoning_route("claude-3-5-sonnet-20241022")
assert _is_reasoning_route("anthropic.claude-3-5-sonnet")
assert _is_reasoning_route("ANTHROPIC/Claude-Opus") # case-insensitive
def test_moonshot_kimi_routes(self):
# OpenRouter advertises the ``reasoning`` extension on Moonshot
# endpoints — both K2.6 (the new baseline default) and the
# reasoning-native kimi-k2-thinking variant.
assert _is_reasoning_route("moonshotai/kimi-k2.6")
assert _is_reasoning_route("moonshotai/kimi-k2-thinking")
assert _is_reasoning_route("moonshotai/kimi-k2.5")
# Direct (non-OpenRouter) model ids also resolve via the ``kimi-``
# prefix so a future bare ``kimi-k3`` id would still match.
assert _is_reasoning_route("kimi-k2-instruct")
# Provider-prefixed bare kimi ids (without the ``moonshotai/``
# prefix) are also recognised — the match anchors on the final
# path segment.
assert _is_reasoning_route("openrouter/kimi-k2.6")
def test_other_providers_rejected(self):
assert not _is_reasoning_route("openai/gpt-4o")
assert not _is_reasoning_route("google/gemini-2.5-pro")
assert not _is_reasoning_route("xai/grok-4")
assert not _is_reasoning_route("meta-llama/llama-3.3-70b-instruct")
assert not _is_reasoning_route("deepseek/deepseek-r1")
def test_kimi_substring_false_positives_rejected(self):
# Regression: the previous implementation matched any model whose
# name contained the substring ``kimi`` — including unrelated model
# ids like ``hakimi``. The anchored match below rejects them.
assert not _is_reasoning_route("some-provider/hakimi-large")
assert not _is_reasoning_route("hakimi")
assert not _is_reasoning_route("akimi-7b")
def test_claude_substring_false_positives_rejected(self):
# Regression (Sentry review on #12871): ``'claude' in lowered``
# matched any substring — a custom
# ``someprovider/claude-mock-v1`` set via
# ``CHAT_FAST_STANDARD_MODEL`` would inherit the reasoning
# extra_body and take a 400 from its upstream. The anchored
# match requires either an ``anthropic`` / ``anthropic.`` /
# ``anthropic/`` prefix, or a bare ``claude-`` id with no
# provider prefix.
assert not _is_reasoning_route("someprovider/claude-mock-v1")
assert not _is_reasoning_route("custom/claude-like-model")
# Same principle for Kimi — a non-Moonshot provider prefix is
# rejected even when the model id starts with ``kimi-``.
assert not _is_reasoning_route("other/kimi-pro")
class TestReasoningExtraBody:
def test_anthropic_route_returns_fragment(self):
assert reasoning_extra_body("anthropic/claude-sonnet-4-6", 4096) == {
"reasoning": {"max_tokens": 4096}
}
def test_direct_claude_model_id_still_matches(self):
assert reasoning_extra_body("claude-3-5-sonnet-20241022", 2048) == {
"reasoning": {"max_tokens": 2048}
}
def test_kimi_routes_return_fragment(self):
# Kimi K2.6 ships the same OpenRouter ``reasoning`` extension as
# Anthropic, so the gate widened with this PR and the fragment
# must now materialise on Moonshot routes too.
assert reasoning_extra_body("moonshotai/kimi-k2.6", 8192) == {
"reasoning": {"max_tokens": 8192}
}
assert reasoning_extra_body("moonshotai/kimi-k2-thinking", 4096) == {
"reasoning": {"max_tokens": 4096}
}
def test_non_reasoning_route_returns_none(self):
assert reasoning_extra_body("openai/gpt-4o", 4096) is None
assert reasoning_extra_body("google/gemini-2.5-pro", 4096) is None
assert reasoning_extra_body("xai/grok-4", 4096) is None
def test_zero_max_tokens_kill_switch(self):
# Operator kill switch: ``max_thinking_tokens <= 0`` disables the
# ``reasoning`` extra_body fragment on ANY reasoning route (Anthropic
# or Kimi). Lets us silence reasoning without dropping the SDK
# path's budget.
assert reasoning_extra_body("anthropic/claude-sonnet-4-6", 0) is None
assert reasoning_extra_body("anthropic/claude-sonnet-4-6", -1) is None
assert reasoning_extra_body("moonshotai/kimi-k2.6", 0) is None
class TestBaselineReasoningEmitter:
def test_first_text_delta_emits_start_then_delta(self):
emitter = BaselineReasoningEmitter()
events = emitter.on_delta(_delta(reasoning="thinking"))
assert len(events) == 2
assert isinstance(events[0], StreamReasoningStart)
assert isinstance(events[1], StreamReasoningDelta)
assert events[0].id == events[1].id
assert events[1].delta == "thinking"
assert emitter.is_open is True
def test_subsequent_deltas_reuse_block_id_without_new_start(self):
# Disable coalescing so each chunk flushes immediately — this test
# is about the Start/Delta/block-id state machine, not the coalesce
# window. Coalescing behaviour is covered below.
emitter = BaselineReasoningEmitter(
coalesce_min_chars=0, coalesce_max_interval_ms=0
)
first = emitter.on_delta(_delta(reasoning="a"))
second = emitter.on_delta(_delta(reasoning="b"))
assert any(isinstance(e, StreamReasoningStart) for e in first)
assert all(not isinstance(e, StreamReasoningStart) for e in second)
assert len(second) == 1
assert isinstance(second[0], StreamReasoningDelta)
assert first[0].id == second[0].id
def test_empty_delta_emits_nothing(self):
emitter = BaselineReasoningEmitter()
assert emitter.on_delta(_delta(content="hello")) == []
assert emitter.is_open is False
def test_close_emits_end_and_rotates_id(self):
emitter = BaselineReasoningEmitter()
# Capture the block id from the wire event rather than reaching
# into emitter internals — the id on the emitted Start/Delta is
# what the frontend actually receives.
start_events = emitter.on_delta(_delta(reasoning="x"))
first_id = start_events[0].id
events = emitter.close()
assert len(events) == 1
assert isinstance(events[0], StreamReasoningEnd)
assert events[0].id == first_id
assert emitter.is_open is False
# Next reasoning uses a fresh id.
new_events = emitter.on_delta(_delta(reasoning="y"))
assert isinstance(new_events[0], StreamReasoningStart)
assert new_events[0].id != first_id
def test_close_is_idempotent(self):
emitter = BaselineReasoningEmitter()
assert emitter.close() == []
emitter.on_delta(_delta(reasoning="x"))
assert len(emitter.close()) == 1
assert emitter.close() == []
def test_structured_details_round_trip(self):
emitter = BaselineReasoningEmitter()
events = emitter.on_delta(
_delta(
reasoning_details=[
{"type": "reasoning.text", "text": "plan: "},
{"type": "reasoning.summary", "summary": "do the thing"},
]
)
)
deltas = [e for e in events if isinstance(e, StreamReasoningDelta)]
assert len(deltas) == 1
assert deltas[0].delta == "plan: do the thing"
class TestReasoningDeltaCoalescing:
"""Coalescing batches fine-grained provider chunks into bigger wire
frames. OpenRouter's Kimi K2.6 emits ~4,700 reasoning-delta chunks
per turn vs ~28 for Sonnet; without batching, every chunk becomes one
Redis ``xadd`` + one SSE event + one React re-render of the
non-virtualised chat list, which paint-storms the browser. These
tests pin the batching contract: small chunks buffer until the
char-size or time threshold trips, large chunks still flush
immediately, and ``close()`` never drops tail text."""
def test_small_chunks_after_first_buffer_until_threshold(self):
# Generous time threshold so size alone controls flush timing.
emitter = BaselineReasoningEmitter(
coalesce_min_chars=32, coalesce_max_interval_ms=60_000
)
# First chunk always flushes immediately (so UI renders without
# waiting).
first = emitter.on_delta(_delta(reasoning="hi "))
assert any(isinstance(e, StreamReasoningStart) for e in first)
assert sum(isinstance(e, StreamReasoningDelta) for e in first) == 1
# Subsequent small chunks buffer silently — 5 × 4 chars = 20 chars,
# still under the 32-char threshold.
for _ in range(5):
assert emitter.on_delta(_delta(reasoning="abcd")) == []
# Once the threshold is crossed, the accumulated buffer flushes
# as a single StreamReasoningDelta carrying every buffered chunk.
flush = emitter.on_delta(_delta(reasoning="efghijklmnop"))
assert len(flush) == 1
assert isinstance(flush[0], StreamReasoningDelta)
assert flush[0].delta == "abcd" * 5 + "efghijklmnop"
def test_time_based_flush_when_chars_stay_below_threshold(self, monkeypatch):
# Fake ``time.monotonic`` so we can drive the time-based branch
# deterministically without real sleeps.
from backend.copilot.baseline import reasoning as rmod
fake_now = [0.0]
monkeypatch.setattr(rmod.time, "monotonic", lambda: fake_now[0])
emitter = BaselineReasoningEmitter(
coalesce_min_chars=1000, coalesce_max_interval_ms=40
)
# t=0: first chunk flushes immediately.
first = emitter.on_delta(_delta(reasoning="a"))
assert sum(isinstance(e, StreamReasoningDelta) for e in first) == 1
# t=10 ms: still under 40 ms → buffer.
fake_now[0] = 0.010
assert emitter.on_delta(_delta(reasoning="b")) == []
# t=50 ms since last flush → time threshold trips, flush fires.
fake_now[0] = 0.060
flushed = emitter.on_delta(_delta(reasoning="c"))
assert len(flushed) == 1
assert isinstance(flushed[0], StreamReasoningDelta)
assert flushed[0].delta == "bc"
def test_close_flushes_tail_buffer_before_end(self):
emitter = BaselineReasoningEmitter(
coalesce_min_chars=1000, coalesce_max_interval_ms=60_000
)
emitter.on_delta(_delta(reasoning="first")) # flushes (first chunk)
emitter.on_delta(_delta(reasoning=" middle ")) # buffered
emitter.on_delta(_delta(reasoning="tail")) # buffered
events = emitter.close()
assert len(events) == 2
assert isinstance(events[0], StreamReasoningDelta)
assert events[0].delta == " middle tail"
assert isinstance(events[1], StreamReasoningEnd)
def test_coalesce_disabled_flushes_every_chunk(self):
emitter = BaselineReasoningEmitter(
coalesce_min_chars=0, coalesce_max_interval_ms=0
)
first = emitter.on_delta(_delta(reasoning="a"))
second = emitter.on_delta(_delta(reasoning="b"))
assert sum(isinstance(e, StreamReasoningDelta) for e in first) == 1
assert sum(isinstance(e, StreamReasoningDelta) for e in second) == 1
def test_persistence_stays_per_delta_even_when_wire_coalesces(self):
"""DB row content must track every chunk so a crash mid-turn
persists the full reasoning-so-far, even if the coalesce window
never flushed those chunks to the wire."""
session: list[ChatMessage] = []
emitter = BaselineReasoningEmitter(
session,
coalesce_min_chars=1000,
coalesce_max_interval_ms=60_000,
)
emitter.on_delta(_delta(reasoning="first "))
emitter.on_delta(_delta(reasoning="chunk "))
emitter.on_delta(_delta(reasoning="three"))
# No close; verify the persisted row already has everything.
assert len(session) == 1
assert session[0].content == "first chunk three"
class TestReasoningPersistence:
"""The persistence contract: without ``role="reasoning"`` rows in
session.messages, useHydrateOnStreamEnd overwrites the live-streamed
reasoning parts and the Reasoning collapse vanishes. Every delta
must be reflected in the persisted row the moment it's emitted."""
def test_session_row_appended_on_first_delta(self):
session: list[ChatMessage] = []
emitter = BaselineReasoningEmitter(session)
assert session == []
emitter.on_delta(_delta(reasoning="hi"))
assert len(session) == 1
assert session[0].role == "reasoning"
assert session[0].content == "hi"
def test_subsequent_deltas_mutate_same_row(self):
session: list[ChatMessage] = []
emitter = BaselineReasoningEmitter(session)
emitter.on_delta(_delta(reasoning="part one "))
emitter.on_delta(_delta(reasoning="part two"))
assert len(session) == 1
assert session[0].content == "part one part two"
def test_close_keeps_row_in_session(self):
session: list[ChatMessage] = []
emitter = BaselineReasoningEmitter(session)
emitter.on_delta(_delta(reasoning="thought"))
emitter.close()
assert len(session) == 1
assert session[0].content == "thought"
def test_second_reasoning_block_appends_new_row(self):
session: list[ChatMessage] = []
emitter = BaselineReasoningEmitter(session)
emitter.on_delta(_delta(reasoning="first"))
emitter.close()
emitter.on_delta(_delta(reasoning="second"))
assert len(session) == 2
assert [m.content for m in session] == ["first", "second"]
def test_no_session_means_no_persistence(self):
"""Emitter without attached session list emits wire events only."""
emitter = BaselineReasoningEmitter()
events = emitter.on_delta(_delta(reasoning="pure wire"))
assert len(events) == 2 # start + delta, no crash
# Nothing else to assert — just proves None session is supported.
class TestBaselineReasoningEmitterRenderFlag:
"""``render_in_ui=False`` must silence ``StreamReasoning*`` wire events
AND drop persistence of ``role="reasoning"`` rows — the operator hides
the collapse on both the live wire and on reload. Persistence is tied
to the wire events because the frontend's hydration path unconditionally
re-renders persisted reasoning rows; keeping them would make the flag a
no-op post-reload. These tests pin the contract in both directions so
future refactors can't flip only one half."""
def test_render_off_suppresses_start_and_delta(self):
emitter = BaselineReasoningEmitter(render_in_ui=False)
events = emitter.on_delta(_delta(reasoning="hidden"))
# No wire events, but state advanced (is_open == True) so close()
# below has something to rotate.
assert events == []
assert emitter.is_open is True
def test_render_off_suppresses_close_end(self):
emitter = BaselineReasoningEmitter(render_in_ui=False)
emitter.on_delta(_delta(reasoning="hidden"))
events = emitter.close()
assert events == []
assert emitter.is_open is False
def test_render_off_still_persists(self):
"""Persistence is decoupled from the render flag — session
transcript always keeps the ``role="reasoning"`` row so audit
and ``--resume``-equivalent replay never lose thinking text.
The frontend gates rendering separately."""
session: list[ChatMessage] = []
emitter = BaselineReasoningEmitter(session, render_in_ui=False)
emitter.on_delta(_delta(reasoning="part one "))
emitter.on_delta(_delta(reasoning="part two"))
emitter.close()
assert len(session) == 1
assert session[0].role == "reasoning"
assert session[0].content == "part one part two"
def test_render_off_rotates_block_id_between_sessions(self):
"""Even with wire events silenced the block id must rotate on close,
otherwise a hypothetical mid-session flip would reuse a stale id."""
emitter = BaselineReasoningEmitter(render_in_ui=False)
emitter.on_delta(_delta(reasoning="first"))
first_block_id = emitter._block_id
emitter.close()
emitter.on_delta(_delta(reasoning="second"))
assert emitter._block_id != first_block_id
def test_render_on_is_default(self):
"""Defaulting to True preserves backward compat — existing callers
that don't pass the kwarg keep emitting wire events as before."""
emitter = BaselineReasoningEmitter()
events = emitter.on_delta(_delta(reasoning="hello"))
assert len(events) == 2
assert isinstance(events[0], StreamReasoningStart)
assert isinstance(events[1], StreamReasoningDelta)

File diff suppressed because it is too large Load Diff

View File

@@ -63,21 +63,117 @@ def _make_session_messages(*roles: str) -> list[ChatMessage]:
class TestResolveBaselineModel:
"""Model selection honours the per-request mode."""
"""Baseline model resolution honours the per-request tier toggle.
def test_fast_mode_selects_fast_model(self):
assert _resolve_baseline_model("fast") == config.fast_model
Baseline reads the ``fast_*_model`` cells of the (path, tier) matrix
and never falls through to the SDK-side ``thinking_*_model`` cells.
Without a user_id (so no LD context) the resolver returns the
``ChatConfig`` static default; per-user overrides are exercised in
``copilot/model_router_test.py``.
"""
def test_extended_thinking_selects_default_model(self):
assert _resolve_baseline_model("extended_thinking") == config.model
@pytest.mark.asyncio
async def test_advanced_tier_selects_fast_advanced_model(self):
assert (
await _resolve_baseline_model("advanced", None)
== config.fast_advanced_model
)
def test_none_mode_selects_default_model(self):
"""Critical: baseline users without a mode MUST keep the default (opus)."""
assert _resolve_baseline_model(None) == config.model
@pytest.mark.asyncio
async def test_standard_tier_selects_fast_standard_model(self):
assert (
await _resolve_baseline_model("standard", None)
== config.fast_standard_model
)
def test_default_and_fast_models_same(self):
"""SDK defaults currently keep standard and fast on Sonnet 4.6."""
assert config.model == config.fast_model
@pytest.mark.asyncio
async def test_none_tier_selects_fast_standard_model(self):
"""Baseline users without a tier get the fast-standard default."""
assert await _resolve_baseline_model(None, None) == config.fast_standard_model
def test_fast_standard_default_is_sonnet(self):
"""Shipped default: Sonnet on the baseline standard cell — the
non-Anthropic routes ship via the LD flag instead of a config
change. Asserts the declared ``Field`` default so a deploy-time
``CHAT_FAST_STANDARD_MODEL`` override doesn't flake CI."""
from backend.copilot.config import ChatConfig
assert (
ChatConfig.model_fields["fast_standard_model"].default
== "anthropic/claude-sonnet-4-6"
)
def test_fast_advanced_default_is_opus(self):
"""Shipped default: Opus on the baseline advanced cell — mirrors
the SDK advanced cell so the advanced-tier A/B stays clean
(same model, different path)."""
from backend.copilot.config import ChatConfig
assert (
ChatConfig.model_fields["fast_advanced_model"].default
== "anthropic/claude-opus-4.7"
)
def test_standard_and_advanced_cells_differ_on_fast(self):
"""Advanced tier defaults to a different model than standard on
the baseline path. Checked against declared ``Field`` defaults
so operator env overrides don't flake the test."""
from backend.copilot.config import ChatConfig
assert (
ChatConfig.model_fields["fast_standard_model"].default
!= ChatConfig.model_fields["fast_advanced_model"].default
)
def test_legacy_env_aliases_route_to_new_fields(self, monkeypatch):
"""Backward compat: the pre-split env var names must still bind.
The four-field matrix was introduced with ``validation_alias``
entries so that existing deployments setting ``CHAT_MODEL`` /
``CHAT_ADVANCED_MODEL`` / ``CHAT_FAST_MODEL`` continue to override
the same effective cell without a rename. Construct a fresh
``ChatConfig`` with each legacy name set and confirm it lands on
the new field.
"""
from backend.copilot.config import ChatConfig
monkeypatch.setenv("CHAT_MODEL", "legacy/sonnet-via-chat-model")
monkeypatch.setenv("CHAT_ADVANCED_MODEL", "legacy/opus-via-advanced")
monkeypatch.setenv("CHAT_FAST_MODEL", "legacy/fast-via-fast-model")
cfg = ChatConfig()
assert cfg.thinking_standard_model == "legacy/sonnet-via-chat-model"
assert cfg.thinking_advanced_model == "legacy/opus-via-advanced"
assert cfg.fast_standard_model == "legacy/fast-via-fast-model"
def test_all_four_new_env_vars_bind_to_their_cells(self, monkeypatch):
"""Each of the four (path, tier) cells must be overridable via
its documented ``CHAT_*_*_MODEL`` env var — including
``CHAT_FAST_ADVANCED_MODEL`` which was missing a
``validation_alias`` in the original split and only bound
implicitly through ``env_prefix``. Pinning all four here so
that whenever someone touches the config shape, an accidental
unbinding fails CI instead of silently ignoring operator
overrides.
"""
from backend.copilot.config import ChatConfig
monkeypatch.setenv("CHAT_FAST_STANDARD_MODEL", "explicit/fast-std")
monkeypatch.setenv("CHAT_FAST_ADVANCED_MODEL", "explicit/fast-adv")
monkeypatch.setenv("CHAT_THINKING_STANDARD_MODEL", "explicit/think-std")
monkeypatch.setenv("CHAT_THINKING_ADVANCED_MODEL", "explicit/think-adv")
# Clear the legacy aliases so they don't win priority in
# ``AliasChoices`` (first match wins).
for legacy in ("CHAT_MODEL", "CHAT_ADVANCED_MODEL", "CHAT_FAST_MODEL"):
monkeypatch.delenv(legacy, raising=False)
cfg = ChatConfig()
assert cfg.fast_standard_model == "explicit/fast-std"
assert cfg.fast_advanced_model == "explicit/fast-adv"
assert cfg.thinking_standard_model == "explicit/think-std"
assert cfg.thinking_advanced_model == "explicit/think-adv"
class TestLoadPriorTranscript:

View File

@@ -0,0 +1,217 @@
"""Builder-session context helpers — split cacheable system prompt from
the volatile per-turn snapshot so Claude's prompt cache stays warm."""
from __future__ import annotations
import logging
from typing import Any
from backend.copilot.model import ChatSession
from backend.copilot.permissions import CopilotPermissions
from backend.copilot.tools.agent_generator import get_agent_as_json
from backend.copilot.tools.get_agent_building_guide import _load_guide
logger = logging.getLogger(__name__)
BUILDER_CONTEXT_TAG = "builder_context"
BUILDER_SESSION_TAG = "builder_session"
# Tools hidden from builder-bound sessions: ``create_agent`` /
# ``customize_agent`` would mint a new graph (panel is bound to one),
# and ``get_agent_building_guide`` duplicates bytes already in the
# system-prompt suffix. Everything else (find_block, find_agent, …)
# stays available so the LLM can look up ids instead of hallucinating.
BUILDER_BLOCKED_TOOLS: tuple[str, ...] = (
"create_agent",
"customize_agent",
"get_agent_building_guide",
)
def resolve_session_permissions(
session: ChatSession | None,
) -> CopilotPermissions | None:
"""Blacklist :data:`BUILDER_BLOCKED_TOOLS` for builder-bound sessions,
return ``None`` (unrestricted) otherwise."""
if session is None or not session.metadata.builder_graph_id:
return None
return CopilotPermissions(
tools=list(BUILDER_BLOCKED_TOOLS),
tools_exclude=True,
)
# Caps — mirror the frontend ``serializeGraphForChat`` defaults so the
# server-side block stays within a practical token budget for large graphs.
_MAX_NODES = 100
_MAX_LINKS = 200
_FETCH_FAILED_PREFIX = (
f"<{BUILDER_CONTEXT_TAG}>\n"
f"<status>fetch_failed</status>\n"
f"</{BUILDER_CONTEXT_TAG}>\n\n"
)
# Embedded in the cacheable suffix so the LLM picks the right run_agent
# dispatch mode without forcing the user to watch a long-blocking call.
_BUILDER_RUN_AGENT_GUIDANCE = (
"You are operating inside the builder panel, not the standalone "
"copilot page. The builder page already subscribes to agent "
"executions the moment you return an execution_id, so for REAL "
"(non-dry) runs prefer `run_agent(dry_run=False, wait_for_result=0)` "
"— the user will see the run stream in the builder's execution panel "
"in-place and your turn ends immediately with the id. For DRY-RUNS "
"keep `dry_run=True, wait_for_result=120`: blocking is required so "
"you can inspect `execution.node_executions` and report the verdict "
"in the same turn."
)
def _sanitize_for_xml(value: Any) -> str:
"""Escape XML special chars — mirrors ``sanitizeForXml`` in
``BuilderChatPanel/helpers.ts``."""
s = "" if value is None else str(value)
return (
s.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
.replace('"', "&quot;")
.replace("'", "&apos;")
)
def _node_display_name(node: dict[str, Any]) -> str:
"""Prefer the user-set label (``input_default.name`` / ``metadata.title``);
fall back to the block id."""
defaults = node.get("input_default") or {}
metadata = node.get("metadata") or {}
for key in ("name", "title", "label"):
value = defaults.get(key) or metadata.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
block_id = node.get("block_id") or ""
return block_id or "unknown"
def _format_nodes(nodes: list[dict[str, Any]]) -> str:
if not nodes:
return "<nodes>\n</nodes>"
visible = nodes[:_MAX_NODES]
lines = []
for node in visible:
node_id = _sanitize_for_xml(node.get("id") or "")
name = _sanitize_for_xml(_node_display_name(node))
block_id = _sanitize_for_xml(node.get("block_id") or "")
lines.append(f"- {node_id}: {name} ({block_id})")
extra = len(nodes) - len(visible)
if extra > 0:
lines.append(f"({extra} more not shown)")
body = "\n".join(lines)
return f"<nodes>\n{body}\n</nodes>"
def _format_links(
links: list[dict[str, Any]],
nodes: list[dict[str, Any]],
) -> str:
if not links:
return "<links>\n</links>"
name_by_id = {n.get("id"): _node_display_name(n) for n in nodes}
visible = links[:_MAX_LINKS]
lines = []
for link in visible:
src_id = link.get("source_id") or ""
dst_id = link.get("sink_id") or ""
src_name = name_by_id.get(src_id, src_id)
dst_name = name_by_id.get(dst_id, dst_id)
src_out = link.get("source_name") or ""
dst_in = link.get("sink_name") or ""
lines.append(
f"- {_sanitize_for_xml(src_name)}.{_sanitize_for_xml(src_out)} "
f"-> {_sanitize_for_xml(dst_name)}.{_sanitize_for_xml(dst_in)}"
)
extra = len(links) - len(visible)
if extra > 0:
lines.append(f"({extra} more not shown)")
body = "\n".join(lines)
return f"<links>\n{body}\n</links>"
async def build_builder_system_prompt_suffix(session: ChatSession) -> str:
"""Return the cacheable system-prompt suffix for a builder session.
Holds only static content (dispatch guidance + building guide) so the
bytes are identical across turns AND across sessions for different
graphs — the live id/name/version ride on the per-turn prefix.
"""
if not session.metadata.builder_graph_id:
return ""
try:
guide = _load_guide()
except Exception:
logger.exception("[builder_context] Failed to load agent-building guide")
return ""
# The guide is trusted server-side content (read from disk). We do NOT
# escape it — the LLM needs the raw markdown to make sense of block ids,
# code fences, and example JSON.
return (
f"\n\n<{BUILDER_SESSION_TAG}>\n"
f"<run_agent_dispatch_mode>\n"
f"{_BUILDER_RUN_AGENT_GUIDANCE}\n"
f"</run_agent_dispatch_mode>\n"
f"<building_guide>\n{guide}\n</building_guide>\n"
f"</{BUILDER_SESSION_TAG}>"
)
async def build_builder_context_turn_prefix(
session: ChatSession,
user_id: str | None,
) -> str:
"""Return the per-turn ``<builder_context>`` prefix with the live
graph snapshot (id/name/version/nodes/links). ``""`` for non-builder
sessions; fetch-failure marker if the graph cannot be read."""
graph_id = session.metadata.builder_graph_id
if not graph_id:
return ""
try:
agent_json = await get_agent_as_json(graph_id, user_id)
except Exception:
logger.exception(
"[builder_context] Failed to fetch graph %s for session %s",
graph_id,
session.session_id,
)
return _FETCH_FAILED_PREFIX
if not agent_json:
logger.warning(
"[builder_context] Graph %s not found for session %s",
graph_id,
session.session_id,
)
return _FETCH_FAILED_PREFIX
version = _sanitize_for_xml(agent_json.get("version") or "")
raw_name = agent_json.get("name")
graph_name = (
raw_name.strip() if isinstance(raw_name, str) and raw_name.strip() else None
)
nodes = agent_json.get("nodes") or []
links = agent_json.get("links") or []
name_attr = f' name="{_sanitize_for_xml(graph_name)}"' if graph_name else ""
graph_tag = (
f'<graph id="{_sanitize_for_xml(graph_id)}"'
f"{name_attr} "
f'version="{version}" '
f'node_count="{len(nodes)}" '
f'edge_count="{len(links)}"/>'
)
inner = f"{graph_tag}\n{_format_nodes(nodes)}\n{_format_links(links, nodes)}"
return f"<{BUILDER_CONTEXT_TAG}>\n{inner}\n</{BUILDER_CONTEXT_TAG}>\n\n"

View File

@@ -0,0 +1,329 @@
"""Tests for the split builder-context helpers.
Covers both halves of the public API:
- :func:`build_builder_system_prompt_suffix` — session-stable block
appended to the system prompt (contains the guide + graph id/name).
- :func:`build_builder_context_turn_prefix` — per-turn user-message
prefix (contains the live version + node/link snapshot).
"""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.builder_context import (
BUILDER_CONTEXT_TAG,
BUILDER_SESSION_TAG,
build_builder_context_turn_prefix,
build_builder_system_prompt_suffix,
)
from backend.copilot.model import ChatSession
def _session(
builder_graph_id: str | None,
*,
user_id: str = "test-user",
) -> ChatSession:
"""Minimal ``ChatSession`` with *builder_graph_id* on metadata."""
return ChatSession.new(
user_id,
dry_run=False,
builder_graph_id=builder_graph_id,
)
def _agent_json(
nodes: list[dict] | None = None,
links: list[dict] | None = None,
**overrides,
) -> dict:
base: dict = {
"id": "graph-1",
"name": "My Agent",
"description": "A test agent",
"version": 3,
"is_active": True,
"nodes": nodes if nodes is not None else [],
"links": links if links is not None else [],
}
base.update(overrides)
return base
# ---------------------------------------------------------------------------
# build_builder_system_prompt_suffix
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_system_prompt_suffix_empty_for_non_builder():
session = _session(None)
result = await build_builder_system_prompt_suffix(session)
assert result == ""
@pytest.mark.asyncio
async def test_system_prompt_suffix_contains_only_static_content():
session = _session("graph-1")
with patch(
"backend.copilot.builder_context._load_guide",
return_value="# Guide body",
):
suffix = await build_builder_system_prompt_suffix(session)
assert suffix.startswith("\n\n")
assert f"<{BUILDER_SESSION_TAG}>" in suffix
assert f"</{BUILDER_SESSION_TAG}>" in suffix
assert "<building_guide>" in suffix
assert "# Guide body" in suffix
# Dispatch-mode guidance must appear so the LLM knows to prefer
# wait_for_result=0 for real runs (builder UI subscribes live) and
# wait_for_result=120 for dry-runs (so it can inspect the node trace).
assert "<run_agent_dispatch_mode>" in suffix
assert "wait_for_result=0" in suffix
assert "wait_for_result=120" in suffix
# Regression: dynamic graph id/name must NOT leak into the cacheable
# suffix — they live in the per-turn prefix so renames and cross-graph
# sessions don't invalidate Claude's prompt cache.
assert "graph-1" not in suffix
assert "id=" not in suffix
assert "name=" not in suffix
@pytest.mark.asyncio
async def test_system_prompt_suffix_identical_across_graphs():
"""The suffix must be byte-identical regardless of which graph the
session is bound to — that's what keeps the cacheable prefix warm
across sessions."""
s1 = _session("graph-1")
s2 = _session("graph-2", user_id="different-owner")
with patch(
"backend.copilot.builder_context._load_guide",
return_value="# Guide body",
):
suffix_1 = await build_builder_system_prompt_suffix(s1)
suffix_2 = await build_builder_system_prompt_suffix(s2)
assert suffix_1 == suffix_2
@pytest.mark.asyncio
async def test_system_prompt_suffix_empty_when_guide_load_fails():
"""Guide load failure means we have nothing useful to add — emit an
empty suffix rather than a half-built block."""
session = _session("graph-1")
with patch(
"backend.copilot.builder_context._load_guide",
side_effect=OSError("missing"),
):
suffix = await build_builder_system_prompt_suffix(session)
assert suffix == ""
# ---------------------------------------------------------------------------
# build_builder_context_turn_prefix
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_turn_prefix_empty_for_non_builder():
session = _session(None)
result = await build_builder_context_turn_prefix(session, "user-1")
assert result == ""
@pytest.mark.asyncio
async def test_turn_prefix_contains_version_nodes_and_links():
session = _session("graph-1")
nodes = [
{
"id": "n1",
"block_id": "block-A",
"input_default": {"name": "Input"},
"metadata": {},
},
{
"id": "n2",
"block_id": "block-B",
"input_default": {},
"metadata": {},
},
]
links = [
{
"source_id": "n1",
"sink_id": "n2",
"source_name": "out",
"sink_name": "in",
}
]
agent = _agent_json(nodes=nodes, links=links)
with patch(
"backend.copilot.builder_context.get_agent_as_json",
new=AsyncMock(return_value=agent),
):
block = await build_builder_context_turn_prefix(session, "user-1")
assert block.startswith(f"<{BUILDER_CONTEXT_TAG}>\n")
assert block.endswith(f"</{BUILDER_CONTEXT_TAG}>\n\n")
assert 'id="graph-1"' in block
assert 'name="My Agent"' in block
assert 'version="3"' in block
assert 'node_count="2"' in block
assert 'edge_count="1"' in block
assert "n1: Input (block-A)" in block
assert "n2: block-B (block-B)" in block
assert "Input.out -> block-B.in" in block
@pytest.mark.asyncio
async def test_turn_prefix_does_not_include_guide():
"""The guide lives in the cacheable system prompt, not in the per-turn
prefix."""
session = _session("graph-1")
with (
patch(
"backend.copilot.builder_context.get_agent_as_json",
new=AsyncMock(return_value=_agent_json()),
),
# Sentinel guide text — if it leaks into the turn prefix the
# assertion below catches it.
patch(
"backend.copilot.builder_context._load_guide",
return_value="SENTINEL_GUIDE_BODY",
),
):
block = await build_builder_context_turn_prefix(session, "user-1")
assert "SENTINEL_GUIDE_BODY" not in block
assert "<building_guide>" not in block
@pytest.mark.asyncio
async def test_turn_prefix_escapes_graph_name():
session = _session("graph-1")
with patch(
"backend.copilot.builder_context.get_agent_as_json",
new=AsyncMock(return_value=_agent_json(name='<script>&"')),
):
block = await build_builder_context_turn_prefix(session, "user-1")
assert 'name="&lt;script&gt;&amp;&quot;"' in block
@pytest.mark.asyncio
async def test_turn_prefix_forwards_user_id_for_ownership():
"""The graph must be fetched with the caller's ``user_id`` so the
ownership check in ``get_graph`` is enforced — we never emit graph
metadata the session user is not entitled to see."""
session = _session("graph-1", user_id="owner-xyz")
agent_json_mock = AsyncMock(return_value=_agent_json())
with patch(
"backend.copilot.builder_context.get_agent_as_json",
new=agent_json_mock,
):
await build_builder_context_turn_prefix(session, "owner-xyz")
agent_json_mock.assert_awaited_once_with("graph-1", "owner-xyz")
@pytest.mark.asyncio
async def test_turn_prefix_fetch_failure_returns_marker():
session = _session("graph-1")
with patch(
"backend.copilot.builder_context.get_agent_as_json",
new=AsyncMock(side_effect=RuntimeError("boom")),
):
block = await build_builder_context_turn_prefix(session, "user-1")
assert block == (
f"<{BUILDER_CONTEXT_TAG}>\n"
"<status>fetch_failed</status>\n"
f"</{BUILDER_CONTEXT_TAG}>\n\n"
)
@pytest.mark.asyncio
async def test_turn_prefix_graph_not_found_returns_marker():
session = _session("graph-1")
with patch(
"backend.copilot.builder_context.get_agent_as_json",
new=AsyncMock(return_value=None),
):
block = await build_builder_context_turn_prefix(session, "user-1")
assert "<status>fetch_failed</status>" in block
@pytest.mark.asyncio
async def test_turn_prefix_node_cap_truncates_with_more_marker():
session = _session("graph-1")
nodes = [
{"id": f"n{i}", "block_id": "b", "input_default": {}, "metadata": {}}
for i in range(150)
]
agent = _agent_json(nodes=nodes)
with patch(
"backend.copilot.builder_context.get_agent_as_json",
new=AsyncMock(return_value=agent),
):
block = await build_builder_context_turn_prefix(session, "user-1")
assert 'node_count="150"' in block
# 50 nodes past the cap of 100.
assert "(50 more not shown)" in block
@pytest.mark.asyncio
async def test_turn_prefix_link_cap_truncates_with_more_marker():
session = _session("graph-1")
nodes = [
{"id": f"n{i}", "block_id": "b", "input_default": {}, "metadata": {}}
for i in range(5)
]
links = [
{
"source_id": "n0",
"sink_id": "n1",
"source_name": "out",
"sink_name": "in",
}
for _ in range(250)
]
agent = _agent_json(nodes=nodes, links=links)
with patch(
"backend.copilot.builder_context.get_agent_as_json",
new=AsyncMock(return_value=agent),
):
block = await build_builder_context_turn_prefix(session, "user-1")
assert 'edge_count="250"' in block
assert "(50 more not shown)" in block
@pytest.mark.asyncio
async def test_turn_prefix_xml_escaping_in_node_names():
session = _session("graph-1")
nodes = [
{
"id": "n1",
"block_id": "b",
"input_default": {"name": 'evil"</builder_context>"'},
"metadata": {},
}
]
agent = _agent_json(nodes=nodes)
with patch(
"backend.copilot.builder_context.get_agent_as_json",
new=AsyncMock(return_value=agent),
):
block = await build_builder_context_turn_prefix(session, "user-1")
# The raw closing tag must never appear inside the block content —
# escaping stops a user-controlled name from breaking out of the block.
assert "&lt;/builder_context&gt;" in block

View File

@@ -3,7 +3,7 @@
import os
from typing import Literal
from pydantic import Field, field_validator
from pydantic import AliasChoices, Field, field_validator, model_validator
from pydantic_settings import BaseSettings
from backend.util.clients import OPENROUTER_BASE_URL
@@ -17,8 +17,12 @@ from backend.util.clients import OPENROUTER_BASE_URL
CopilotMode = Literal["fast", "extended_thinking"]
# Per-request model tier set by the frontend model toggle.
# 'standard' uses the global config default (currently Sonnet).
# 'advanced' forces the highest-capability model (currently Opus).
# 'standard' picks the cheaper everyday model for the active path —
# ``fast_standard_model`` on the baseline path, ``thinking_standard_model``
# on the SDK path.
# 'advanced' picks the premium model for the active path — ``fast_advanced_model``
# on the baseline path, ``thinking_advanced_model`` on the SDK path (both
# default to Opus today).
# None means no preference — falls through to LD per-user targeting, then config.
# Using tier names instead of model names keeps the contract model-agnostic.
CopilotLlmModel = Literal["standard", "advanced"]
@@ -27,24 +31,60 @@ CopilotLlmModel = Literal["standard", "advanced"]
class ChatConfig(BaseSettings):
"""Configuration for the chat system."""
# OpenAI API Configuration
model: str = Field(
# Chat model tiers — a 2×2 of (path, tier). ``path`` = ``CopilotMode``
# (``"fast"`` → baseline OpenAI-compat / any OpenRouter model;
# ``"extended_thinking"`` → Claude Agent SDK, Anthropic-only CLI).
# ``tier`` = ``CopilotLlmModel`` (``"standard"`` / ``"advanced"``).
# Each cell has its own config so the two paths can evolve
# independently (cheap provider on baseline, Anthropic on SDK) at each
# tier without conflating one path's needs with the other's constraint.
#
# Historical env var names (``CHAT_MODEL`` / ``CHAT_ADVANCED_MODEL`` /
# ``CHAT_FAST_MODEL``) are preserved via ``validation_alias`` so
# existing deployments continue to override the same effective cell.
fast_standard_model: str = Field(
default="anthropic/claude-sonnet-4-6",
description="Default model for extended thinking mode. "
"Uses Sonnet 4.6 as the balanced default. "
"Override via CHAT_MODEL env var if you want a different default.",
validation_alias=AliasChoices(
"CHAT_FAST_STANDARD_MODEL",
"CHAT_FAST_MODEL",
),
description="Baseline path, 'standard' / ``None`` tier. Per-user "
"overrides flow through the ``copilot-fast-standard-model`` LD flag "
"(see ``copilot/model_router.py``); this value is the fallback.",
)
fast_model: str = Field(
fast_advanced_model: str = Field(
default="anthropic/claude-opus-4.7",
validation_alias=AliasChoices("CHAT_FAST_ADVANCED_MODEL"),
description="Baseline path, 'advanced' tier. LD override: "
"``copilot-fast-advanced-model``.",
)
thinking_standard_model: str = Field(
default="anthropic/claude-sonnet-4-6",
description="Model for fast mode (baseline path). Should be faster/cheaper than the default model.",
validation_alias=AliasChoices(
"CHAT_THINKING_STANDARD_MODEL",
"CHAT_MODEL",
),
description="SDK (extended-thinking) path, 'standard' / ``None`` "
"tier. LD override: ``copilot-thinking-standard-model``.",
)
thinking_advanced_model: str = Field(
default="anthropic/claude-opus-4.7",
validation_alias=AliasChoices(
"CHAT_THINKING_ADVANCED_MODEL",
"CHAT_ADVANCED_MODEL",
),
description="SDK (extended-thinking) path, 'advanced' tier. LD "
"override: ``copilot-thinking-advanced-model``.",
)
title_model: str = Field(
default="openai/gpt-4o-mini",
description="Model to use for generating session titles (should be fast/cheap)",
)
simulation_model: str = Field(
default="google/gemini-2.5-flash",
description="Model for dry-run block simulation (should be fast/cheap with good JSON output)",
default="google/gemini-2.5-flash-lite",
description="Model for dry-run block simulation (should be fast/cheap with good JSON output). "
"Gemini 2.5 Flash-Lite is ~3x cheaper than Flash ($0.10/$0.40 vs $0.30/$1.20 per MTok) "
"with JSON-mode reliability adequate for shape-matching block outputs.",
)
api_key: str | None = Field(default=None, description="OpenAI API key")
base_url: str | None = Field(
@@ -96,25 +136,31 @@ class ChatConfig(BaseSettings):
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
)
# Rate limiting — token-based limits per day and per week.
# Per-turn token cost varies with context size: ~10-15K for early turns,
# ~30-50K mid-session, up to ~100K pre-compaction. Average across a
# session with compaction cycles is ~25-35K tokens/turn, so 2.5M daily
# allows ~70-100 turns/day.
# Rate limiting — cost-based limits per day and per week, stored in
# microdollars (1 USD = 1_000_000). The counter tracks the real
# generation cost reported by the provider (OpenRouter ``usage.cost``
# or Claude Agent SDK ``total_cost_usd``), so cache discounts and
# cross-model price differences are already reflected — no token
# weighting or model multiplier is applied on top.
# Checked at the HTTP layer (routes.py) before each turn.
#
# These are base limits for the FREE tier. Higher tiers (PRO, BUSINESS,
# These are base limits for the FREE tier. Higher tiers (PRO, BUSINESS,
# ENTERPRISE) multiply these by their tier multiplier (see
# rate_limit.TIER_MULTIPLIERS). User tier is stored in the
# rate_limit.TIER_MULTIPLIERS). User tier is stored in the
# User.subscriptionTier DB column and resolved inside
# get_global_rate_limits().
daily_token_limit: int = Field(
default=2_500_000,
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
#
# These defaults act as the ceiling when LaunchDarkly is unreachable;
# the live per-tier values come from the COPILOT_*_COST_LIMIT flags.
daily_cost_limit_microdollars: int = Field(
default=1_000_000,
description="Max cost per day in microdollars, resets at midnight UTC "
"(0 = unlimited).",
)
weekly_token_limit: int = Field(
default=12_500_000,
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
weekly_cost_limit_microdollars: int = Field(
default=5_000_000,
description="Max cost per week in microdollars, resets Monday 00:00 UTC "
"(0 = unlimited).",
)
# Cost (in credits / cents) to reset the daily rate limit using credits.
@@ -139,7 +185,7 @@ class ChatConfig(BaseSettings):
claude_agent_model: str | None = Field(
default=None,
description="Model for the Claude Agent SDK path. If None, derives from "
"the `model` field by stripping the OpenRouter provider prefix.",
"`thinking_standard_model` by stripping the OpenRouter provider prefix.",
)
claude_agent_max_buffer_size: int = Field(
default=10 * 1024 * 1024, # 10MB (default SDK is 1MB)
@@ -161,14 +207,18 @@ class ChatConfig(BaseSettings):
"overloaded). The SDK automatically retries with this cheaper model. "
"Empty string disables the fallback (no --fallback-model flag passed to CLI).",
)
claude_agent_max_turns: int = Field(
default=50,
agent_max_turns: int = Field(
default=100,
ge=1,
le=10000,
description="Maximum number of agentic turns (tool-use loops) per query. "
"Prevents runaway tool loops from burning budget. "
"Changed from 1000 to 50 in SDK 0.1.58 upgrade — override via "
"CHAT_CLAUDE_AGENT_MAX_TURNS env var if your workflows need more.",
validation_alias=AliasChoices(
"CHAT_AGENT_MAX_TURNS",
"CHAT_CLAUDE_AGENT_MAX_TURNS",
),
description="Maximum number of tool-call rounds per turn — applies to "
"both the baseline and Claude Agent SDK paths. Prevents runaway tool "
"loops from burning budget. Override via CHAT_AGENT_MAX_TURNS env var "
"(legacy CHAT_CLAUDE_AGENT_MAX_TURNS still accepted).",
)
claude_agent_max_budget_usd: float = Field(
default=10.0,
@@ -179,22 +229,57 @@ class ChatConfig(BaseSettings):
"Set to $10 to allow most tasks to complete (p50=$5.37, p75=$13.07). "
"Override via CHAT_CLAUDE_AGENT_MAX_BUDGET_USD env var.",
)
claude_agent_autocompact_pct_override: int = Field(
default=50,
ge=0,
le=100,
description="Auto-compaction trigger threshold as a percentage of the "
"CLI's perceived window (sets ``CLAUDE_AUTOCOMPACT_PCT_OVERRIDE`` on the "
"SDK subprocess). The CLI caps at its default (~93% of window); values "
"above that have no effect. 50 (= 100K of a 200K window) keeps Anthropic "
"context creation costs down. Set to 0 to omit the env var entirely "
"and let the CLI use its default ~93% threshold — useful when the "
"post-compaction floor (system prompt + tool defs ≈ 65-110K) is close "
"to the trigger and a more aggressive value causes back-to-back "
"compaction cascades. Skipped unconditionally for Moonshot routes.",
)
claude_agent_max_thinking_tokens: int = Field(
default=8192,
ge=1024,
ge=0,
le=128000,
description="Maximum thinking/reasoning tokens per LLM call. "
"Extended thinking on Opus can generate 50k+ tokens at $75/M — "
"capping this is the single biggest cost lever. "
"8192 is sufficient for most tasks; increase for complex reasoning.",
description="Maximum thinking/reasoning tokens per LLM call. Applies "
"to both the Claude Agent SDK path (as ``max_thinking_tokens``) and "
"the baseline OpenRouter path (as ``extra_body.reasoning.max_tokens`` "
"on Anthropic routes). Extended thinking on Opus can generate 50k+ "
"tokens at $75/M — capping this is the single biggest cost lever. "
"8192 is sufficient for most tasks; increase for complex reasoning. "
"Set to 0 to disable extended thinking on both paths (kill switch): "
"baseline skips the ``reasoning`` extra_body; SDK omits the "
"``max_thinking_tokens`` kwarg so the CLI falls back to model default "
"(which, without the flag, leaves extended thinking off).",
)
render_reasoning_in_ui: bool = Field(
default=True,
description="Render reasoning as live UI parts "
"(``StreamReasoning*`` wire events). False suppresses the live "
"wire events only; ``role='reasoning'`` rows are always persisted "
"so the reasoning bubble hydrates on reload. Tokens are billed "
"upstream regardless.",
)
stream_replay_count: int = Field(
default=200,
ge=1,
le=10000,
description="Max Redis stream entries replayed on SSE reconnect.",
)
claude_agent_thinking_effort: Literal["low", "medium", "high", "max"] | None = (
Field(
default=None,
description="Thinking effort level: 'low', 'medium', 'high', 'max', or None. "
"Only applies to models with extended thinking (Opus). "
"Sonnet doesn't have extended thinking — setting effort on Sonnet "
"can cause <internal_reasoning> tag leaks. "
"Applies to models that emit a reasoning channel — Opus (extended "
"thinking) and Kimi K2.6 (OpenRouter ``reasoning`` extension lit "
"up by #12871). Sonnet does not have extended thinking — setting "
"effort on Sonnet can cause <internal_reasoning> tag leaks. "
"None = let the model decide. Override via CHAT_CLAUDE_AGENT_THINKING_EFFORT.",
)
)
@@ -214,6 +299,43 @@ class ChatConfig(BaseSettings):
"from the prefix. Set to False to fall back to passing the system "
"prompt as a raw string.",
)
baseline_prompt_cache_ttl: str = Field(
default="1h",
description="TTL for the ephemeral prompt-cache markers on the baseline "
"OpenRouter path. Anthropic supports only `5m` (default, 1.25x input "
"price for the write) or `1h` (2x input price for the write). 1h is "
"strictly cheaper overall when the static prefix gets >7 reads per "
"write-window; since the system prompt + tools array is identical "
"across all users in our workspace, 1h is the default so cross-user "
"reads amortise the higher write cost. Anthropic has no longer "
"(24h, permanent) TTL option — see "
"https://platform.claude.com/docs/en/build-with-claude/prompt-caching.",
)
sdk_include_partial_messages: bool = Field(
default=True,
description="Stream SDK responses token-by-token instead of in "
"one lump at the end. Set to False if the SDK path starts "
"double-writing text or dropping the tail of long messages.",
)
sdk_reconcile_openrouter_cost: bool = Field(
default=True,
description="Query OpenRouter's ``/api/v1/generation?id=`` after each "
"SDK turn and record the authoritative ``total_cost`` instead of the "
"Claude Agent SDK CLI's estimate. Covers every OpenRouter-routed "
"SDK turn regardless of vendor — the CLI's static Anthropic pricing "
"table is accurate for Anthropic models (Sonnet/Opus via OpenRouter "
"bill at Anthropic's own rates, penny-for-penny), but the reconcile "
"catches any future rate change the CLI hasn't picked up and makes "
"non-Anthropic cost (Kimi et al) correct — real billed amount, "
"matching the baseline path's ``usage.cost`` read since #12864. "
"Kill-switch for emergencies: set ``CHAT_SDK_RECONCILE_OPENROUTER_COST"
"=false`` to fall back to the CLI's ``total_cost_usd`` reported "
"synchronously (accurate-for-Anthropic / over-billed-for-Kimi). "
"Tradeoff: 0.5-2s window between turn end and cost write; rate-limit "
"counter briefly unaware, back-to-back turns in that window see "
"stale state. The alternative (writing an estimate sync then a "
"correction delta) would double-count the rate limit.",
)
claude_agent_cli_path: str | None = Field(
default=None,
description="Optional explicit path to a Claude Code CLI binary. "
@@ -384,6 +506,59 @@ class ChatConfig(BaseSettings):
)
return v
@model_validator(mode="after")
def _validate_sdk_model_vendor_compatibility(self) -> "ChatConfig":
"""Fail at config load when an SDK model slug is incompatible with
explicit direct-Anthropic mode.
The SDK path's ``_normalize_model_name`` raises ``ValueError`` when
a non-Anthropic vendor slug (e.g. ``moonshotai/kimi-k2.6``) is paired
with direct-Anthropic mode — but that fires inside the request loop,
so a misconfigured deployment would surface a 500 to every user
instead of failing visibly at boot.
Only the **explicit** opt-out (``use_openrouter=False``) is checked
here, not the credential-missing path. Build environments and
OpenAPI-schema export jobs construct ``ChatConfig()`` without any
OpenRouter credentials in the env — that's not a misconfiguration,
it's "config loads ok, but no SDK turn will succeed until creds are
wired". The runtime guard in ``_normalize_model_name`` still
catches the credential-missing path on the first SDK turn.
Covers all three SDK fields that flow through
``_normalize_model_name``: primary tier
(``thinking_standard_model``), advanced tier
(``thinking_advanced_model``), and fallback model
(``claude_agent_fallback_model`` via ``_resolve_fallback_model``).
Skipped when ``use_claude_code_subscription=True`` because the
subscription path resolves the model to ``None`` (CLI default)
and never calls ``_normalize_model_name``. Empty fallback strings
are also skipped (no fallback configured).
"""
if self.use_claude_code_subscription:
return self
if self.use_openrouter:
return self
for field_name in (
"thinking_standard_model",
"thinking_advanced_model",
"claude_agent_fallback_model",
):
value: str = getattr(self, field_name)
if not value or "/" not in value:
continue
if value.split("/", 1)[0] != "anthropic":
raise ValueError(
f"Direct-Anthropic mode (use_openrouter=False) "
f"requires an Anthropic model for {field_name}, got "
f"{value!r}. Set CHAT_THINKING_STANDARD_MODEL / "
f"CHAT_THINKING_ADVANCED_MODEL / "
f"CHAT_CLAUDE_AGENT_FALLBACK_MODEL to an anthropic/* "
f"slug, or set CHAT_USE_OPENROUTER=true."
)
return self
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",
@@ -397,3 +572,10 @@ class ChatConfig(BaseSettings):
env_file = ".env"
env_file_encoding = "utf-8"
extra = "ignore" # Ignore extra environment variables
# Accept both the Python attribute name and the validation_alias when
# constructing a ``ChatConfig`` directly (e.g. in tests passing
# ``thinking_standard_model=...``). Without this, pydantic only
# accepts the alias names (``CHAT_THINKING_STANDARD_MODEL`` env) and
# rejects field-name kwargs — breaking ``ChatConfig(field=...)`` in
# every test that constructs a config.
populate_by_name = True

View File

@@ -5,12 +5,17 @@ import pytest
from .config import ChatConfig
# Env vars that the ChatConfig validators read — must be cleared so they don't
# override the explicit constructor values we pass in each test.
# override the explicit constructor values we pass in each test. Includes the
# SDK/baseline model aliases so a leftover ``CHAT_MODEL=...`` in the developer
# or CI environment can't change whether
# ``_validate_sdk_model_vendor_compatibility`` raises.
_ENV_VARS_TO_CLEAR = (
"CHAT_USE_E2B_SANDBOX",
"CHAT_E2B_API_KEY",
"E2B_API_KEY",
"CHAT_USE_OPENROUTER",
"CHAT_USE_CLAUDE_AGENT_SDK",
"CHAT_USE_CLAUDE_CODE_SUBSCRIPTION",
"CHAT_API_KEY",
"OPEN_ROUTER_API_KEY",
"OPENAI_API_KEY",
@@ -19,6 +24,16 @@ _ENV_VARS_TO_CLEAR = (
"OPENAI_BASE_URL",
"CHAT_CLAUDE_AGENT_CLI_PATH",
"CLAUDE_AGENT_CLI_PATH",
"CHAT_FAST_STANDARD_MODEL",
"CHAT_FAST_MODEL",
"CHAT_FAST_ADVANCED_MODEL",
"CHAT_THINKING_STANDARD_MODEL",
"CHAT_THINKING_ADVANCED_MODEL",
"CHAT_MODEL",
"CHAT_ADVANCED_MODEL",
"CHAT_CLAUDE_AGENT_FALLBACK_MODEL",
"CHAT_RENDER_REASONING_IN_UI",
"CHAT_STREAM_REPLAY_COUNT",
)
@@ -28,6 +43,22 @@ def _clean_env(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.delenv(var, raising=False)
def _make_direct_safe_config(**kwargs) -> ChatConfig:
"""Build a ``ChatConfig`` for tests that pass ``use_openrouter=False``
but aren't exercising the SDK vendor-compatibility validator.
Pins ``thinking_standard_model``/``thinking_advanced_model`` to anthropic/*
so the construction passes ``_validate_sdk_model_vendor_compatibility``
without each test having to repeat the override.
"""
defaults: dict = {
"thinking_standard_model": "anthropic/claude-sonnet-4-6",
"thinking_advanced_model": "anthropic/claude-opus-4-7",
}
defaults.update(kwargs)
return ChatConfig(**defaults)
class TestOpenrouterActive:
"""Tests for the openrouter_active property."""
@@ -48,7 +79,7 @@ class TestOpenrouterActive:
assert cfg.openrouter_active is False
def test_disabled_returns_false_despite_credentials(self):
cfg = ChatConfig(
cfg = _make_direct_safe_config(
use_openrouter=False,
api_key="or-key",
base_url="https://openrouter.ai/api/v1",
@@ -164,3 +195,133 @@ class TestClaudeAgentCliPathEnvFallback:
monkeypatch.setenv("CLAUDE_AGENT_CLI_PATH", str(tmp_path))
with pytest.raises(Exception, match="not a regular file"):
ChatConfig()
class TestSdkModelVendorCompatibility:
"""``model_validator`` that fails fast on SDK model vs routing-mode
mismatch — see PR #12878 iteration-2 review. Mirrors the runtime
guard in ``_normalize_model_name`` so misconfig surfaces at boot
instead of as a 500 on the first SDK turn."""
def test_direct_anthropic_with_kimi_override_raises(self):
"""A non-Anthropic SDK model must fail at config load when the
deployment has no OpenRouter credentials."""
with pytest.raises(Exception, match="requires an Anthropic model"):
ChatConfig(
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=False,
thinking_standard_model="moonshotai/kimi-k2.6",
)
def test_direct_anthropic_with_anthropic_default_succeeds(self):
"""Direct-Anthropic mode is fine when both SDK slugs are anthropic/*
— which is the default after the LD-routed model rollout."""
cfg = ChatConfig(
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=False,
)
assert cfg.thinking_standard_model == "anthropic/claude-sonnet-4-6"
def test_openrouter_with_kimi_override_succeeds(self):
"""Kimi slug round-trips cleanly when OpenRouter is on — exercised
via the LD-flag override path in production."""
cfg = ChatConfig(
use_openrouter=True,
api_key="or-key",
base_url="https://openrouter.ai/api/v1",
use_claude_code_subscription=False,
thinking_standard_model="moonshotai/kimi-k2.6",
)
assert cfg.thinking_standard_model == "moonshotai/kimi-k2.6"
def test_subscription_mode_skips_check(self):
"""Subscription path resolves the model to None and bypasses
``_normalize_model_name``, so the slug check is skipped."""
cfg = ChatConfig(
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=True,
)
assert cfg.use_claude_code_subscription is True
def test_advanced_tier_also_validated(self):
"""Both standard and advanced SDK slugs are checked."""
with pytest.raises(Exception, match="thinking_advanced_model"):
ChatConfig(
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=False,
thinking_standard_model="anthropic/claude-sonnet-4-6",
thinking_advanced_model="moonshotai/kimi-k2.6",
)
def test_fallback_model_also_validated(self):
"""``claude_agent_fallback_model`` flows through
``_normalize_model_name`` via ``_resolve_fallback_model`` so the
same direct-Anthropic guard applies."""
with pytest.raises(Exception, match="claude_agent_fallback_model"):
ChatConfig(
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=False,
thinking_standard_model="anthropic/claude-sonnet-4-6",
thinking_advanced_model="anthropic/claude-opus-4-7",
claude_agent_fallback_model="moonshotai/kimi-k2.6",
)
def test_empty_fallback_skipped(self):
"""Empty ``claude_agent_fallback_model`` (no fallback configured)
must not trip the validator — the fallback-disabled state is
intentional and shouldn't require a placeholder anthropic/* slug."""
cfg = ChatConfig(
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=False,
thinking_standard_model="anthropic/claude-sonnet-4-6",
thinking_advanced_model="anthropic/claude-opus-4-7",
claude_agent_fallback_model="",
)
assert cfg.claude_agent_fallback_model == ""
class TestRenderReasoningInUi:
"""``render_reasoning_in_ui`` gates reasoning wire events globally."""
def test_defaults_to_true(self):
"""Default must stay True — flipping it silences the reasoning
collapse for every user, which is an opt-in operator decision."""
cfg = ChatConfig()
assert cfg.render_reasoning_in_ui is True
def test_env_override_false(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("CHAT_RENDER_REASONING_IN_UI", "false")
cfg = ChatConfig()
assert cfg.render_reasoning_in_ui is False
class TestStreamReplayCount:
"""``stream_replay_count`` caps the SSE reconnect replay batch size."""
def test_default_is_200(self):
"""200 covers a full Kimi turn after coalescing (~150 events) while
bounding the replay storm from 1000+ chunks."""
cfg = ChatConfig()
assert cfg.stream_replay_count == 200
def test_env_override(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("CHAT_STREAM_REPLAY_COUNT", "500")
cfg = ChatConfig()
assert cfg.stream_replay_count == 500
def test_zero_rejected(self):
"""count=0 would make XREAD replay nothing — rejected via ge=1."""
with pytest.raises(Exception):
ChatConfig(stream_replay_count=0)

View File

@@ -9,6 +9,11 @@ COPILOT_RETRYABLE_ERROR_PREFIX = (
)
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
# Canonical marker appended as an assistant ChatMessage when the SDK stream
# ends without a ResultMessage (user hit Stop). Checked by exact equality
# at turn start so the next turn's --resume transcript doesn't carry it.
STOPPED_BY_USER_MARKER = f"{COPILOT_SYSTEM_PREFIX} Execution stopped by user"
# Prefix for all synthetic IDs generated by CoPilot block execution.
# Used to distinguish CoPilot-generated records from real graph execution records
# in PendingHumanReview and other tables.
@@ -27,6 +32,24 @@ COMPACTION_DONE_MSG = "Earlier messages were summarized to fit within context li
COMPACTION_TOOL_NAME = "context_compaction"
# ---------------------------------------------------------------------------
# Tool / stream timing budget
# ---------------------------------------------------------------------------
# Max seconds any single MCP tool call may block the stream before returning
# a "still running" handle. Shared by run_agent (wait_for_result),
# view_agent_output (wait_if_running), run_sub_session (wait_for_result),
# get_sub_session_result (wait_if_running), and run_block (hard cap).
#
# Chosen so the stream idle timeout (2× this) always has headroom — a tool
# that returns right at the cap can't race the idle watchdog.
MAX_TOOL_WAIT_SECONDS = 5 * 60 # 5 minutes
# Idle-stream watchdog: abort the SDK stream if no meaningful event arrives
# for this long. Derived from MAX_TOOL_WAIT_SECONDS so the invariant
# "no tool blocks >= idle_timeout" holds by construction.
STREAM_IDLE_TIMEOUT_SECONDS = MAX_TOOL_WAIT_SECONDS * 2 # 10 minutes
def is_copilot_synthetic_id(id_value: str) -> bool:
"""Check if an ID is a CoPilot synthetic ID (not from a real graph execution)."""
return id_value.startswith(COPILOT_SYNTHETIC_ID_PREFIX)

View File

@@ -34,6 +34,7 @@ from .utils import (
CancelCoPilotEvent,
CoPilotExecutionEntry,
create_copilot_queue_config,
get_session_lock_key,
)
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
@@ -104,25 +105,46 @@ class CoPilotExecutor(AppProcess):
time.sleep(1e5)
def cleanup(self):
"""Graceful shutdown with active execution waiting."""
pid = os.getpid()
logger.info(f"[cleanup {pid}] Starting graceful shutdown...")
"""Graceful shutdown — mirrors ``backend.executor.manager`` pattern.
# Signal the consumer thread to stop
1. Stop consumer immediately (both the Python flag that gates
``_handle_run_message`` and ``channel.stop_consuming()`` at
the broker), so no new work enters.
2. Passively wait for ``active_tasks`` to drain — each turn's
own ``finally`` publishes its terminal state via
``mark_session_completed``. When a turn exits, ``on_run_done``
removes it from ``active_tasks`` and releases its cluster lock.
3. Shut down the thread-pool executor (cancels pending, leaves
running threads alone — process exit handles them).
4. Release any cluster locks still held (defensive — on_run_done's
finally should have already released them).
5. Stop message consumer threads + disconnect pika clients.
The zombie-session bug this PR targets is handled inside each
turn's own lifecycle by :func:`sync_fail_close_session`, NOT by
cleanup — so cleanup can stay as a simple "wait, then teardown"
and matches agent-executor's proven pattern.
"""
pid = os.getpid()
prefix = f"[cleanup {pid}]"
logger.info(f"{prefix} Starting graceful shutdown...")
# 1. Stop consumer — flag AND broker-side
try:
self.stop_consuming.set()
run_channel = self.run_client.get_channel()
run_channel.connection.add_callback_threadsafe(
lambda: run_channel.stop_consuming()
)
logger.info(f"[cleanup {pid}] Consumer has been signaled to stop")
logger.info(f"{prefix} Consumer has been signaled to stop")
except Exception as e:
logger.error(f"[cleanup {pid}] Error stopping consumer: {e}")
logger.error(f"{prefix} Error stopping consumer: {e}")
# Wait for active executions to complete
# 2. Wait for in-flight turns to finish naturally
if self.active_tasks:
logger.info(
f"[cleanup {pid}] Waiting for {len(self.active_tasks)} active tasks to complete (timeout: {GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s)..."
f"{prefix} Waiting for {len(self.active_tasks)} active tasks "
f"to complete (timeout: {GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s)..."
)
start_time = time.monotonic()
@@ -137,38 +159,42 @@ class CoPilotExecutor(AppProcess):
if not self.active_tasks:
break
# Refresh cluster locks periodically
current_time = time.monotonic()
if current_time - last_refresh >= lock_refresh_interval:
now = time.monotonic()
if now - last_refresh >= lock_refresh_interval:
for lock in list(self._task_locks.values()):
try:
lock.refresh()
except Exception as e:
logger.warning(
f"[cleanup {pid}] Failed to refresh lock: {e}"
)
last_refresh = current_time
logger.warning(f"{prefix} Failed to refresh lock: {e}")
last_refresh = now
logger.info(
f"[cleanup {pid}] {len(self.active_tasks)} tasks still active, waiting..."
f"{prefix} {len(self.active_tasks)} tasks still active, waiting..."
)
time.sleep(10.0)
# Stop message consumers
if self.active_tasks:
logger.warning(
f"{prefix} {len(self.active_tasks)} tasks still running after "
f"{GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s — process exit will "
f"abandon them; RabbitMQ redelivery handles the message."
)
# 3. Stop message consumer threads
if self._run_thread:
self._stop_message_consumers(
self._run_thread, self.run_client, "[cleanup][run]"
self._run_thread, self.run_client, f"{prefix} [run]"
)
if self._cancel_thread:
self._stop_message_consumers(
self._cancel_thread, self.cancel_client, "[cleanup][cancel]"
self._cancel_thread, self.cancel_client, f"{prefix} [cancel]"
)
# Clean up worker threads (closes per-loop workspace storage sessions)
# 4. Worker cleanup + executor shutdown
if self._executor:
from .processor import cleanup_worker
logger.info(f"[cleanup {pid}] Cleaning up workers...")
logger.info(f"{prefix} Cleaning up workers...")
futures = []
for _ in range(self._executor._max_workers):
futures.append(self._executor.submit(cleanup_worker))
@@ -176,22 +202,20 @@ class CoPilotExecutor(AppProcess):
try:
f.result(timeout=10)
except Exception as e:
logger.warning(f"[cleanup {pid}] Worker cleanup error: {e}")
logger.warning(f"{prefix} Worker cleanup error: {e}")
logger.info(f"[cleanup {pid}] Shutting down executor...")
logger.info(f"{prefix} Shutting down executor...")
self._executor.shutdown(wait=False)
# Release any remaining locks
# 5. Release any cluster locks still held
for session_id, lock in list(self._task_locks.items()):
try:
lock.release()
logger.info(f"[cleanup {pid}] Released lock for {session_id}")
logger.info(f"{prefix} Released lock for {session_id}")
except Exception as e:
logger.error(
f"[cleanup {pid}] Failed to release lock for {session_id}: {e}"
)
logger.error(f"{prefix} Failed to release lock for {session_id}: {e}")
logger.info(f"[cleanup {pid}] Graceful shutdown completed")
logger.info(f"{prefix} Graceful shutdown completed")
# ============ RabbitMQ Consumer Methods ============ #
@@ -366,7 +390,7 @@ class CoPilotExecutor(AppProcess):
# Try to acquire cluster-wide lock
cluster_lock = ClusterLock(
redis=redis.get_redis(),
key=f"copilot:session:{session_id}:lock",
key=get_session_lock_key(session_id),
owner_id=self.executor_id,
timeout=settings.config.cluster_lock_timeout,
)
@@ -386,13 +410,12 @@ class CoPilotExecutor(AppProcess):
# Execute the task
try:
self._task_locks[session_id] = cluster_lock
logger.info(
f"Acquired cluster lock for {session_id}, "
f"executor_id={self.executor_id}"
)
self._task_locks[session_id] = cluster_lock
cancel_event = threading.Event()
future = self.executor.submit(
execute_copilot_turn, entry, cancel_event, cluster_lock
@@ -424,7 +447,6 @@ class CoPilotExecutor(AppProcess):
error_msg = str(e) or type(e).__name__
logger.exception(f"Error in run completion callback: {error_msg}")
finally:
# Release the cluster lock
if session_id in self._task_locks:
logger.info(f"Releasing cluster lock for {session_id}")
self._task_locks[session_id].release()

View File

@@ -5,6 +5,7 @@ in a thread-local context, following the graph executor pattern.
"""
import asyncio
import concurrent.futures
import logging
import os
import subprocess
@@ -30,6 +31,87 @@ from .utils import CoPilotExecutionEntry, CoPilotLogMetadata
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
SHUTDOWN_ERROR_MESSAGE = (
"Copilot executor shut down before this turn finished. Please retry."
)
# Max time execute() blocks after calling future.cancel() / when draining a
# soon-to-be-cancelled future. Gives _execute_async's own finally a chance to
# publish the accurate terminal state over the Redis CAS; long enough to let
# an in-flight Redis call settle, short enough that shutdown doesn't stall.
_CANCEL_GRACE_SECONDS = 5.0
# Max time the sync safety net itself spends on a single Redis CAS. Without
# this bound the whole point of ``sync_fail_close_session`` is defeated —
# ``mark_session_completed`` would hang on the same broken Redis that caused
# the original failure. On timeout we give up silently; worst case the
# session stays ``running`` until the stale-session watchdog reaps it, but
# at least the pool worker thread isn't blocked forever.
_FAIL_CLOSE_REDIS_TIMEOUT = 10.0
# Module-level symbol preserved for backward-compat with callers that import
# ``sync_fail_close_session``; the real implementation now lives on
# ``CoPilotProcessor`` so it can reuse ``self.execution_loop`` (same
# pattern as ``backend.executor.manager``'s ``node_execution_loop`` bridge
# at :meth:`ExecutionProcessor.on_graph_execution`).
def sync_fail_close_session(
session_id: str,
log: "CoPilotLogMetadata | TruncatedLogger",
execution_loop: asyncio.AbstractEventLoop,
) -> None:
"""Synchronously mark *session_id* as failed from the pool worker thread.
Submits the CAS coroutine to the long-lived *execution_loop* via
``run_coroutine_threadsafe`` — the same shape agent-executor uses at
:meth:`backend.executor.manager.ExecutionProcessor.on_graph_execution`
to reach its ``node_execution_loop`` from the pool worker. Reusing the
persistent loop means:
* no fresh TCP connection per turn (the ``@thread_cached``
``AsyncRedis`` on the execution thread stays bound to the same loop
and is reused across every turn);
* no loop-teardown overhead;
* no ``clear_cache()`` gymnastics to dodge the "loop is closed" pitfall.
``mark_session_completed`` is an atomic CAS on ``status == "running"``,
so when the async path already wrote a terminal state the sync call is
a cheap no-op. The inner ``asyncio.wait_for`` bounds the Redis call so
a wedged Redis can't hang the safety net for the full redis-py default
TCP timeout; the outer ``.result(timeout=...)`` is a belt-and-braces
upper bound for the cross-thread wait.
"""
async def _bounded() -> None:
await asyncio.wait_for(
stream_registry.mark_session_completed(
session_id, error_message=SHUTDOWN_ERROR_MESSAGE
),
timeout=_FAIL_CLOSE_REDIS_TIMEOUT,
)
try:
future = asyncio.run_coroutine_threadsafe(_bounded(), execution_loop)
except RuntimeError as e:
# execution_loop is closed — happens if cleanup() already ran the
# per-worker teardown. Nothing we can do; let the stale-session
# watchdog reap it.
log.warning(f"sync fail-close skipped (execution_loop closed): {e}")
return
try:
future.result(timeout=_FAIL_CLOSE_REDIS_TIMEOUT + 2)
except concurrent.futures.TimeoutError:
log.warning(
f"sync fail-close timed out after {_FAIL_CLOSE_REDIS_TIMEOUT}s "
f"(session={session_id})"
)
future.cancel()
except Exception as e:
log.warning(f"sync fail-close mark_session_completed failed: {e}")
# ============ Mode Routing ============ #
@@ -222,6 +304,10 @@ class CoPilotProcessor:
Shuts down the workspace storage instance that belongs to this
worker's event loop, ensuring ``aiohttp.ClientSession.close()``
runs on the same loop that created the session.
Sub-AutoPilots are enqueued on the copilot_execution queue, so
rolling deploys survive via RabbitMQ redelivery — no bespoke
shutdown notifier needed.
"""
coro = shutdown_workspace_storage()
try:
@@ -248,12 +334,13 @@ class CoPilotProcessor:
):
"""Execute a CoPilot turn.
Runs the async logic in the worker's event loop and handles errors.
Args:
entry: The turn payload containing session and message info
cancel: Threading event to signal cancellation
cluster_lock: Distributed lock to prevent duplicate execution
Thin wrapper around :meth:`_execute`. The ``try/finally`` here
guarantees :func:`sync_fail_close_session` runs on every exit
path — normal completion, exception, or a wedged event loop
that escapes via :data:`_CANCEL_GRACE_SECONDS` timeout.
``mark_session_completed`` is an atomic CAS on
``status == "running"``, so when the async path already wrote a
terminal state the sync call is a cheap no-op.
"""
log = CoPilotLogMetadata(
logging.getLogger(__name__),
@@ -261,10 +348,28 @@ class CoPilotProcessor:
user_id=entry.user_id,
)
log.info("Starting execution")
start_time = time.monotonic()
try:
self._execute(entry, cancel, cluster_lock, log)
finally:
sync_fail_close_session(entry.session_id, log, self.execution_loop)
elapsed = time.monotonic() - start_time
log.info(f"Execution completed in {elapsed:.2f}s")
# Run the async execution in our event loop
def _execute(
self,
entry: CoPilotExecutionEntry,
cancel: threading.Event,
cluster_lock: ClusterLock,
log: CoPilotLogMetadata,
):
"""Submit the async turn to ``self.execution_loop`` and drive it.
Handles the sync/async boundary (cancel-event checks, cluster-lock
refresh, bounded waits) without any Redis-state cleanup logic —
that lives in :func:`sync_fail_close_session` which the outer
:meth:`execute` always invokes on exit.
"""
future = asyncio.run_coroutine_threadsafe(
self._execute_async(entry, cancel, cluster_lock, log),
self.execution_loop,
@@ -278,16 +383,27 @@ class CoPilotProcessor:
if cancel.is_set():
log.info("Cancellation requested")
future.cancel()
break
# Refresh cluster lock to maintain ownership
# Give _execute_async's own finally a short window to
# publish its accurate terminal state before the outer
# sync safety net fires.
try:
future.result(timeout=_CANCEL_GRACE_SECONDS)
except BaseException:
pass
return
cluster_lock.refresh()
if not future.cancelled():
# Get result to propagate any exceptions
future.result()
elapsed = time.monotonic() - start_time
log.info(f"Execution completed in {elapsed:.2f}s")
# Bounded timeout so a wedged event loop can't trap us here —
# on timeout we escape to execute()'s finally and the sync
# safety net fires.
try:
future.result(timeout=_CANCEL_GRACE_SECONDS)
except concurrent.futures.TimeoutError:
log.warning(
"Future did not complete within grace window; "
"falling through to sync fail-close"
)
async def _execute_async(
self,
@@ -342,7 +458,9 @@ class CoPilotProcessor:
# Stream chat completion and publish chunks to Redis.
# stream_and_publish wraps the raw stream with registry
# publishing (shared with collect_copilot_response).
# publishing so subscribers on the session Redis stream
# (e.g. wait_for_session_result, SSE clients) receive the
# same events as they are produced.
raw_stream = stream_fn(
session_id=entry.session_id,
message=entry.message if entry.message else None,
@@ -352,27 +470,37 @@ class CoPilotProcessor:
file_ids=entry.file_ids,
mode=effective_mode,
model=entry.model,
permissions=entry.permissions,
request_arrival_at=entry.request_arrival_at,
)
async for chunk in stream_registry.stream_and_publish(
published_stream = stream_registry.stream_and_publish(
session_id=entry.session_id,
turn_id=entry.turn_id,
stream=raw_stream,
):
if cancel.is_set():
log.info("Cancel requested, breaking stream")
break
)
# Explicit aclose() on early exit: ``async for … break`` does
# not close the generator, so GeneratorExit would never reach
# stream_chat_completion_sdk, leaving its stream lock held
# until GC eventually runs.
try:
async for chunk in published_stream:
if cancel.is_set():
log.info("Cancel requested, breaking stream")
break
# Capture StreamError so mark_session_completed receives
# the error message (stream_and_publish yields but does
# not publish StreamError — that's done by mark_session_completed).
if isinstance(chunk, StreamError):
error_msg = chunk.errorText
break
# Capture StreamError so mark_session_completed receives
# the error message (stream_and_publish yields but does
# not publish StreamError — that's done by mark_session_completed).
if isinstance(chunk, StreamError):
error_msg = chunk.errorText
break
current_time = time.monotonic()
if current_time - last_refresh >= refresh_interval:
cluster_lock.refresh()
last_refresh = current_time
current_time = time.monotonic()
if current_time - last_refresh >= refresh_interval:
cluster_lock.refresh()
last_refresh = current_time
finally:
await published_stream.aclose()
# Stream loop completed
if cancel.is_set():

View File

@@ -10,14 +10,21 @@ the real production helpers from ``processor.py`` so the routing logic
has meaningful coverage.
"""
from unittest.mock import AsyncMock, patch
import asyncio
import concurrent.futures
import logging
import threading
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.executor.processor import (
CoPilotProcessor,
resolve_effective_mode,
resolve_use_sdk_for_mode,
sync_fail_close_session,
)
from backend.copilot.executor.utils import CoPilotExecutionEntry, CoPilotLogMetadata
class TestResolveUseSdkForMode:
@@ -173,3 +180,319 @@ class TestResolveEffectiveMode:
) as flag_mock:
assert await resolve_effective_mode("fast", None) is None
flag_mock.assert_awaited_once()
# ---------------------------------------------------------------------------
# _execute_async aclose propagation
# ---------------------------------------------------------------------------
class _TrackedStream:
"""Minimal async-generator stand-in that records whether ``aclose``
was called, so tests can verify the processor forces explicit cleanup
of the published stream on every exit path (normal + break on cancel)."""
def __init__(self, events: list):
self._events = events
self.aclose_called = False
def __aiter__(self):
return self
async def __anext__(self):
if not self._events:
raise StopAsyncIteration
return self._events.pop(0)
async def aclose(self) -> None:
self.aclose_called = True
def _make_entry() -> CoPilotExecutionEntry:
return CoPilotExecutionEntry(
session_id="sess-1",
turn_id="turn-1",
user_id="user-1",
message="hi",
is_user_message=True,
request_arrival_at=0.0,
)
def _make_log() -> CoPilotLogMetadata:
return CoPilotLogMetadata(logger=logging.getLogger("test-copilot"))
class TestExecuteAsyncAclose:
"""``_execute_async`` must call ``aclose`` on the published stream both
when the loop exits naturally and when ``cancel`` is set mid-stream —
otherwise ``stream_chat_completion_sdk`` stays suspended and keeps
holding the per-session Redis lock until GC."""
def _patches(self, published_stream: _TrackedStream):
"""Shared mock context: patches every dependency ``_execute_async``
touches so the aclose path is the only behaviour under test."""
return [
patch(
"backend.copilot.executor.processor.ChatConfig",
return_value=MagicMock(test_mode=True, use_claude_agent_sdk=True),
),
patch(
"backend.copilot.executor.processor.stream_chat_completion_dummy",
return_value=MagicMock(),
),
patch(
"backend.copilot.executor.processor.stream_registry.stream_and_publish",
return_value=published_stream,
),
patch(
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
new=AsyncMock(),
),
]
@pytest.mark.asyncio
async def test_normal_exit_calls_aclose(self) -> None:
published = _TrackedStream(events=[MagicMock(), MagicMock()])
proc = CoPilotProcessor()
cancel = threading.Event()
cluster_lock = MagicMock()
patches = self._patches(published)
with patches[0], patches[1], patches[2], patches[3]:
await proc._execute_async(_make_entry(), cancel, cluster_lock, _make_log())
assert published.aclose_called is True
@pytest.mark.asyncio
async def test_cancel_break_calls_aclose(self) -> None:
events = [MagicMock()] # first chunk delivered, then cancel fires
published = _TrackedStream(events=events)
proc = CoPilotProcessor()
cancel = threading.Event()
cancel.set() # pre-set so the loop breaks on the first chunk
cluster_lock = MagicMock()
patches = self._patches(published)
with patches[0], patches[1], patches[2], patches[3]:
await proc._execute_async(_make_entry(), cancel, cluster_lock, _make_log())
assert published.aclose_called is True
@pytest.fixture
def exec_loop():
"""Long-lived asyncio loop on a daemon thread — mirrors the layout
``CoPilotProcessor`` sets up (``execution_loop`` + ``execution_thread``)
so ``sync_fail_close_session`` has a real cross-thread loop to submit
into via ``run_coroutine_threadsafe``."""
loop = asyncio.new_event_loop()
thread = threading.Thread(target=loop.run_forever, daemon=True)
thread.start()
try:
yield loop
finally:
loop.call_soon_threadsafe(loop.stop)
thread.join(timeout=5)
loop.close()
class TestSyncFailCloseSession:
"""``sync_fail_close_session`` is the last-line-of-defense invoked from
``CoPilotProcessor.execute``'s ``finally``. It must call
``mark_session_completed`` via the processor's long-lived
``execution_loop`` (cross-thread submit) and must swallow Redis
failures so a transient outage doesn't propagate out of the finally."""
def test_invokes_mark_session_completed_with_shutdown_message(
self, exec_loop
) -> None:
mock_mark = AsyncMock()
with patch(
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
new=mock_mark,
):
sync_fail_close_session("sess-1", _make_log(), exec_loop)
mock_mark.assert_awaited_once()
assert mock_mark.await_args is not None
assert mock_mark.await_args.args[0] == "sess-1"
assert "shut down" in mock_mark.await_args.kwargs["error_message"].lower()
def test_swallows_redis_error(self, exec_loop) -> None:
# Raising from the mock ensures the helper catches the exception
# instead of propagating it back into execute()'s finally block.
mock_mark = AsyncMock(side_effect=RuntimeError("redis down"))
with patch(
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
new=mock_mark,
):
sync_fail_close_session("sess-2", _make_log(), exec_loop) # must not raise
mock_mark.assert_awaited_once()
def test_closed_execution_loop_skipped_cleanly(self) -> None:
"""If cleanup_worker has already stopped the execution_loop by the
time the safety net fires, ``run_coroutine_threadsafe`` raises
RuntimeError. Expected behavior: log + return without propagating."""
dead_loop = asyncio.new_event_loop()
dead_loop.close()
mock_mark = AsyncMock()
with patch(
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
new=mock_mark,
):
# Must not raise even though the loop is closed
sync_fail_close_session("sess-closed-loop", _make_log(), dead_loop)
# mark_session_completed was never scheduled because the loop was dead
mock_mark.assert_not_awaited()
def test_bounded_timeout_when_redis_hangs(self, exec_loop) -> None:
"""Scenario D: Redis unreachable — the inner ``asyncio.wait_for``
must fire and the helper must return without blocking the worker.
Simulates a wedged Redis by sleeping past the 10s fail-close budget.
The helper must return within the configured grace (+ a small
scheduler margin) and must not re-raise.
"""
import time as _time
from backend.copilot.executor.processor import _FAIL_CLOSE_REDIS_TIMEOUT
async def _hang(*_args, **_kwargs):
await asyncio.sleep(_FAIL_CLOSE_REDIS_TIMEOUT + 5)
with patch(
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
new=_hang,
):
start = _time.monotonic()
sync_fail_close_session(
"sess-hang", _make_log(), exec_loop
) # must not raise
elapsed = _time.monotonic() - start
# wait_for fires at _FAIL_CLOSE_REDIS_TIMEOUT; outer future.result
# has +2s slack. If the timeout is missing/broken the helper would
# block the full sleep duration (~15s).
assert elapsed < _FAIL_CLOSE_REDIS_TIMEOUT + 4.0, (
f"sync_fail_close_session hung for {elapsed:.1f}s — bounded "
f"timeout did not fire"
)
# ---------------------------------------------------------------------------
# End-to-end execute() safety-net coverage — the PR's core invariant
# ---------------------------------------------------------------------------
class TestExecuteSafetyNet:
"""``CoPilotProcessor.execute`` must always invoke
``sync_fail_close_session`` in its ``finally`` so a session never stays
``status=running`` in Redis.
Validates the four deploy-time scenarios the PR targets:
* A — SIGTERM mid-turn: ``cancel`` event fires, ``_execute`` returns,
safety net still runs.
* B — happy path: normal completion, safety net runs (cheap CAS no-op).
* C — zombie Redis state: the async ``mark_session_completed`` in
``_execute_async`` blows up, but the outer safety net marks the
session failed anyway.
* D — covered by ``TestSyncFailCloseSession::test_bounded_timeout…``.
"""
def _attach_exec_loop(self, proc: CoPilotProcessor, loop) -> None:
"""``execute`` dispatches the safety net onto ``self.execution_loop``.
Tests don't call ``on_executor_start`` (which spawns the real
per-worker loop), so wire the shared fixture loop in directly."""
proc.execution_loop = loop
def _run_execute_in_thread(self, proc: CoPilotProcessor, cancel: threading.Event):
"""``CoPilotProcessor.execute`` expects to be called from a pool
worker thread that has *no* running event loop, so we always run
it off the main thread to preserve that invariant. Returns the
future so callers can inspect both result and exception paths."""
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
try:
fut = pool.submit(proc.execute, _make_entry(), cancel, MagicMock())
# Block until execute() returns (or raises) so the safety net
# has run by the time we inspect mocks.
try:
fut.result(timeout=30)
except BaseException:
pass
return fut
finally:
pool.shutdown(wait=True)
def test_happy_path_invokes_safety_net(self, exec_loop) -> None:
"""Scenario B: normal completion still runs the sync safety net.
Proves the ``finally`` always fires, even when nothing went wrong —
``mark_session_completed``'s atomic CAS makes this a cheap no-op
in production."""
mock_mark = AsyncMock()
proc = CoPilotProcessor()
self._attach_exec_loop(proc, exec_loop)
with patch.object(proc, "_execute"), patch(
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
new=mock_mark,
):
self._run_execute_in_thread(proc, threading.Event())
mock_mark.assert_awaited_once()
assert mock_mark.await_args is not None
assert mock_mark.await_args.args[0] == "sess-1"
def test_sigterm_mid_turn_invokes_safety_net(self, exec_loop) -> None:
"""Scenario A: worker raises (simulating future.cancel + grace
timeout escaping ``_execute``); ``execute`` must still reach the
safety net in its ``finally`` and mark the session failed."""
mock_mark = AsyncMock()
proc = CoPilotProcessor()
self._attach_exec_loop(proc, exec_loop)
with patch.object(
proc,
"_execute",
side_effect=concurrent.futures.TimeoutError("grace expired"),
), patch(
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
new=mock_mark,
):
self._run_execute_in_thread(proc, threading.Event())
mock_mark.assert_awaited_once()
def test_zombie_redis_async_path_still_marks_session_failed(
self, exec_loop
) -> None:
"""Scenario C: ``_execute_async``'s own ``mark_session_completed``
call is broken (simulating the exact async-Redis hiccup that caused
the original zombie sessions). The outer ``sync_fail_close_session``
runs on the processor's long-lived ``execution_loop`` and succeeds
where the async path failed."""
call_log: list[str] = []
async def _ok(*args, **kwargs):
call_log.append("sync-ok")
def _broken_execute(entry, cancel, cluster_lock, log):
# Simulate the async path raising because its Redis client is
# wedged (the pre-fix zombie-session scenario).
raise RuntimeError("async Redis client broken")
proc = CoPilotProcessor()
self._attach_exec_loop(proc, exec_loop)
with patch.object(proc, "_execute", side_effect=_broken_execute), patch(
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
new=_ok,
):
self._run_execute_in_thread(proc, threading.Event())
# The sync safety net must have fired despite the async path
# blowing up — this is the core guarantee of the PR.
assert call_log == [
"sync-ok"
], f"expected sync_fail_close_session to run once, got {call_log!r}"

View File

@@ -10,6 +10,7 @@ import logging
from pydantic import BaseModel
from backend.copilot.config import CopilotLlmModel, CopilotMode
from backend.copilot.permissions import CopilotPermissions
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
@@ -81,12 +82,23 @@ COPILOT_CANCEL_EXCHANGE = Exchange(
)
COPILOT_CANCEL_QUEUE_NAME = "copilot_cancel_queue"
# CoPilot operations can include extended thinking and agent generation
# which may take 30+ minutes to complete
COPILOT_CONSUMER_TIMEOUT_SECONDS = 60 * 60 # 1 hour
# Graceful shutdown timeout - allow in-flight operations to complete
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = 30 * 60 # 30 minutes
def get_session_lock_key(session_id: str) -> str:
"""Redis key for the per-session cluster lock held by the executing pod."""
return f"copilot:session:{session_id}:lock"
# CoPilot operations can include extended thinking and agent generation
# which may take several hours to complete. Matches the pod's
# terminationGracePeriodSeconds in the helm chart so a rolling deploy can let
# the longest legitimate turn finish. Also bounds the stale-session auto-
# complete watchdog in stream_registry (consumer_timeout + 5min buffer).
COPILOT_CONSUMER_TIMEOUT_SECONDS = 6 * 60 * 60 # 6 hours
# Graceful shutdown timeout - must match COPILOT_CONSUMER_TIMEOUT_SECONDS so
# cleanup can let the longest legitimate turn complete before the pod is
# SIGKILL'd by kubelet.
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = COPILOT_CONSUMER_TIMEOUT_SECONDS
def create_copilot_queue_config() -> RabbitMQConfig:
@@ -106,9 +118,27 @@ def create_copilot_queue_config() -> RabbitMQConfig:
durable=True,
auto_delete=False,
arguments={
# Extended consumer timeout for long-running LLM operations
# Default 30-minute timeout is insufficient for extended thinking
# and agent generation which can take 30+ minutes
# Consumer timeout matches the pod graceful-shutdown window so a
# rolling deploy never forces redelivery of a turn that the pod
# is still legitimately finishing.
#
# Deploy note: RabbitMQ (verified on 4.1.4) does NOT strictly
# compare ``x-consumer-timeout`` on queue redeclaration, so this
# value can change between deploys without triggering
# PRECONDITION_FAILED. To update the *effective* timeout on an
# already-running queue before the new code deploys (so pods
# mid-shutdown don't have their consumer cancelled at the old
# limit), apply a policy:
#
# rabbitmqctl set_policy copilot-consumer-timeout \
# "^copilot_execution_queue$" \
# '{"consumer-timeout": 21600000}' \
# --apply-to queues
#
# The policy takes effect immediately. Once the policy is set
# to match the code's value the policy is redundant for new
# pods and can be removed after a stable deploy if desired —
# but it's harmless to leave in place.
"x-consumer-timeout": COPILOT_CONSUMER_TIMEOUT_SECONDS
* 1000,
},
@@ -163,6 +193,20 @@ class CoPilotExecutionEntry(BaseModel):
model: CopilotLlmModel | None = None
"""Per-request model tier: 'standard' or 'advanced'. None = server default."""
permissions: CopilotPermissions | None = None
"""Capability filter inherited from a parent run (e.g. ``run_sub_session``
forwards its parent's permissions so the sub can't escalate). ``None``
means the worker applies no filter."""
request_arrival_at: float = 0.0
"""Unix-epoch seconds (server clock) when the originating HTTP
``/stream`` request arrived. The executor's turn-start drain uses
this to decide whether each pending message was typed BEFORE or AFTER
the turn's ``current`` message, and orders the combined user bubble
chronologically. Defaults to ``0.0`` for backward compatibility with
queue messages written before this field existed (they sort as "all
pending before current" — the pre-fix behaviour)."""
class CancelCoPilotEvent(BaseModel):
"""Event to cancel a CoPilot operation."""
@@ -184,6 +228,8 @@ async def enqueue_copilot_turn(
file_ids: list[str] | None = None,
mode: CopilotMode | None = None,
model: CopilotLlmModel | None = None,
permissions: CopilotPermissions | None = None,
request_arrival_at: float = 0.0,
) -> None:
"""Enqueue a CoPilot task for processing by the executor service.
@@ -197,6 +243,8 @@ async def enqueue_copilot_turn(
file_ids: Optional workspace file IDs attached to the user's message
mode: Autopilot mode override ('fast' or 'extended_thinking'). None = server default.
model: Per-request model tier ('standard' or 'advanced'). None = server default.
permissions: Capability filter inherited from a parent run (sub-AutoPilot).
None = no filter.
"""
from backend.util.clients import get_async_copilot_queue
@@ -210,6 +258,8 @@ async def enqueue_copilot_turn(
file_ids=file_ids,
mode=mode,
model=model,
permissions=permissions,
request_arrival_at=request_arrival_at,
)
queue_client = await get_async_copilot_queue()

View File

@@ -20,12 +20,13 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
)
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from pydantic import BaseModel
from pydantic import BaseModel, PrivateAttr
from backend.data.db_accessors import chat_db
from backend.data.db_accessors import chat_db, library_db
from backend.data.graph import GraphSettings
from backend.data.redis_client import get_redis_async
from backend.util import json
from backend.util.exceptions import DatabaseError, RedisError
from backend.util.exceptions import DatabaseError, NotFoundError, RedisError
from .config import ChatConfig
@@ -54,6 +55,12 @@ class ChatSessionMetadata(BaseModel):
dry_run: bool = False
# Builder-panel binding: when set, the session is locked to the given
# graph. ``edit_agent`` / ``run_agent`` default their ``agent_id`` to
# this graph and reject calls targeting a different agent. Also used
# as a lookup key so refreshing the builder resumes the same chat.
builder_graph_id: str | None = None
class ChatMessage(BaseModel):
role: str
@@ -65,6 +72,7 @@ class ChatMessage(BaseModel):
function_call: dict | None = None
sequence: int | None = None
duration_ms: int | None = None
created_at: datetime | None = None
@staticmethod
def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
@@ -79,6 +87,7 @@ class ChatMessage(BaseModel):
function_call=_parse_json_field(prisma_message.functionCall),
sequence=prisma_message.sequence,
duration_ms=prisma_message.durationMs,
created_at=prisma_message.createdAt,
)
@@ -198,9 +207,24 @@ class ChatSessionInfo(BaseModel):
class ChatSession(ChatSessionInfo):
messages: list[ChatMessage]
# In-flight tool-call names for the CURRENT turn. Not persisted to
# DB and not serialised on the wire — ``PrivateAttr`` keeps this a
# process-local scratch buffer that's invisible to ``model_dump`` /
# ``model_dump_json`` / the redis cache path. Populated by the
# baseline tool executor the moment a tool is dispatched so in-turn
# guards (e.g. ``require_guide_read``) can see the call before it
# lands in ``messages`` at turn-end. Cleared when the turn
# completes.
_inflight_tool_calls: set[str] = PrivateAttr(default_factory=set)
@classmethod
def new(cls, user_id: str, *, dry_run: bool) -> Self:
def new(
cls,
user_id: str,
*,
dry_run: bool,
builder_graph_id: str | None = None,
) -> Self:
return cls(
session_id=str(uuid.uuid4()),
user_id=user_id,
@@ -210,7 +234,10 @@ class ChatSession(ChatSessionInfo):
credentials={},
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
metadata=ChatSessionMetadata(dry_run=dry_run),
metadata=ChatSessionMetadata(
dry_run=dry_run,
builder_graph_id=builder_graph_id,
),
)
@classmethod
@@ -226,6 +253,56 @@ class ChatSession(ChatSessionInfo):
messages=[ChatMessage.from_db(m) for m in prisma_session.Messages],
)
def announce_inflight_tool_call(self, tool_name: str) -> None:
"""Record that *tool_name* is being dispatched in the current turn.
Called by the baseline tool executor **before** the tool actually
runs (the announcement is about dispatch, not success). If the
tool raises, the name stays in the buffer for the rest of the
turn — that matches the guide-read gate's contract ("was the tool
called?") but means any future gate wanting *successful*
dispatches would need its own tracking.
Lets in-turn guards (see
``copilot/tools/helpers.py::require_guide_read``) see a tool
call the moment it's issued, instead of waiting for the
``session.messages`` flush at turn end — fixing a loop where a
second tool in the same turn re-fires a guard despite the
guarding tool having already been called (seen on Kimi K2.6 in
particular because its aggressive tool-call chaining exercises
this path much more than Sonnet does). The buffer is cleared by
:meth:`clear_inflight_tool_calls` at turn end.
"""
self._inflight_tool_calls.add(tool_name)
def clear_inflight_tool_calls(self) -> None:
"""Reset the in-flight tool-call announcement buffer."""
self._inflight_tool_calls.clear()
def has_tool_been_called(self, tool_name: str) -> bool:
"""True when *tool_name* has been called in this session.
Checks the in-flight announcement buffer (for calls dispatched
in the *current* turn but not yet flushed into ``messages``) and
the durable ``messages`` history (for past turns + prior rounds
within this turn whose writes already landed). The durable
scan is session-wide, not turn-scoped: a matching tool call
anywhere in ``messages`` counts. This matches the guide-read
contract — once the guide has been read in the session, the
agent doesn't need to re-read it for later create/edit/fix
tools.
"""
if tool_name in self._inflight_tool_calls:
return True
for msg in reversed(self.messages):
if msg.role != "assistant" or not msg.tool_calls:
continue
for tc in msg.tool_calls:
name = tc.get("function", {}).get("name") or tc.get("name")
if name == tool_name:
return True
return False
def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
"""Attach a tool_call to the current turn's assistant message.
@@ -712,20 +789,32 @@ async def append_and_save_message(
return session
async def create_chat_session(user_id: str, *, dry_run: bool) -> ChatSession:
async def create_chat_session(
user_id: str,
*,
dry_run: bool,
builder_graph_id: str | None = None,
) -> ChatSession:
"""Create a new chat session and persist it.
Args:
user_id: The authenticated user ID.
dry_run: When True, run_block and run_agent tool calls in this
session are forced to use dry-run simulation mode.
builder_graph_id: When set, locks the session to the given graph.
The builder panel uses this to bind a chat to the currently-
opened agent and to resume the same session on refresh.
Raises:
DatabaseError: If the database write fails. We fail fast to ensure
callers never receive a non-persisted session that only exists
in cache (which would be lost when the cache expires).
"""
session = ChatSession.new(user_id, dry_run=dry_run)
session = ChatSession.new(
user_id,
dry_run=dry_run,
builder_graph_id=builder_graph_id,
)
# Create in database first - fail fast if this fails
try:
@@ -749,6 +838,58 @@ async def create_chat_session(user_id: str, *, dry_run: bool) -> ChatSession:
return session
async def get_or_create_builder_session(
user_id: str,
graph_id: str,
) -> ChatSession:
"""Return the user's builder session for *graph_id*, creating it if absent.
The session pointer is stored on
``LibraryAgent.settings.builder_chat_session_id``. Ownership is enforced
by ``get_library_agent_by_graph_id`` (filters on ``userId``); a miss
raises :class:`NotFoundError` (HTTP 404), which also blocks graph-id
probing by unauthorized callers.
"""
library_agent = await library_db().get_library_agent_by_graph_id(
user_id=user_id, graph_id=graph_id
)
if library_agent is None:
raise NotFoundError(f"Graph {graph_id} not found")
existing_sid = library_agent.settings.builder_chat_session_id
if existing_sid:
session = await get_chat_session(existing_sid, user_id)
if session is not None:
return session
# Serialise create-and-claim so concurrent callers for the same
# (user_id, graph_id) don't each mint a session and orphan one
# (double-click / two-tab race — sentry 13632535).
async with _get_session_lock(f"builder:{user_id}:{graph_id}"):
library_agent = await library_db().get_library_agent_by_graph_id(
user_id=user_id, graph_id=graph_id
)
if library_agent is None:
raise NotFoundError(f"Graph {graph_id} not found")
existing_sid = library_agent.settings.builder_chat_session_id
if existing_sid:
session = await get_chat_session(existing_sid, user_id)
if session is not None:
return session
session = await create_chat_session(
user_id,
dry_run=False,
builder_graph_id=graph_id,
)
await library_db().update_library_agent(
library_agent_id=library_agent.id,
user_id=user_id,
settings=GraphSettings(builder_chat_session_id=session.session_id),
)
return session
async def get_user_sessions(
user_id: str,
limit: int = 50,

View File

@@ -0,0 +1,104 @@
"""LaunchDarkly-aware model selection for the copilot.
Each cell of the ``(mode, tier)`` matrix has a static default baked into
``ChatConfig`` (see ``copilot/config.py``) and a matching LaunchDarkly
string-valued feature flag that can override it per-user. This module
centralises the lookup so both the baseline and SDK paths agree on the
selection rule and so A/B experiments can target a single cell without
shipping a config change.
Matrix:
+----------+-------------------------------------+-------------------------------------+
| | standard | advanced |
+----------+-------------------------------------+-------------------------------------+
| fast | copilot-fast-standard-model | copilot-fast-advanced-model |
| thinking | copilot-thinking-standard-model | copilot-thinking-advanced-model |
+----------+-------------------------------------+-------------------------------------+
LD flag values are arbitrary strings (model identifiers, e.g.
``"anthropic/claude-sonnet-4-6"`` or ``"moonshotai/kimi-k2.6"``). Empty
or non-string values fall back to the config default.
"""
from __future__ import annotations
import logging
from typing import Literal
from backend.copilot.config import ChatConfig
from backend.util.feature_flag import Flag, get_feature_flag_value
logger = logging.getLogger(__name__)
ModelMode = Literal["fast", "thinking"]
ModelTier = Literal["standard", "advanced"]
_FLAG_BY_CELL: dict[tuple[ModelMode, ModelTier], Flag] = {
("fast", "standard"): Flag.COPILOT_FAST_STANDARD_MODEL,
("fast", "advanced"): Flag.COPILOT_FAST_ADVANCED_MODEL,
("thinking", "standard"): Flag.COPILOT_THINKING_STANDARD_MODEL,
("thinking", "advanced"): Flag.COPILOT_THINKING_ADVANCED_MODEL,
}
def _config_default(config: ChatConfig, mode: ModelMode, tier: ModelTier) -> str:
if mode == "fast":
return (
config.fast_advanced_model
if tier == "advanced"
else config.fast_standard_model
)
return (
config.thinking_advanced_model
if tier == "advanced"
else config.thinking_standard_model
)
async def resolve_model(
mode: ModelMode,
tier: ModelTier,
user_id: str | None,
*,
config: ChatConfig,
) -> str:
"""Return the model identifier for a ``(mode, tier)`` cell.
Consults the matching LaunchDarkly flag for *user_id* first and
falls back to the ``ChatConfig`` default on missing user, missing
flag, or non-string flag value. Passing *config* explicitly keeps
the resolver cheap to unit-test.
"""
fallback = _config_default(config, mode, tier).strip()
if not user_id:
return fallback
flag = _FLAG_BY_CELL[(mode, tier)]
try:
value = await get_feature_flag_value(flag.value, user_id, default=fallback)
except Exception:
logger.warning(
"[model_router] LD lookup failed for %s — using config default %s",
flag.value,
fallback,
exc_info=True,
)
return fallback
if isinstance(value, str) and value.strip():
return value.strip()
if value != fallback:
reason = (
"empty string"
if isinstance(value, str)
else f"non-string ({type(value).__name__})"
)
logger.warning(
"[model_router] LD flag %s returned %s — using config default %s",
flag.value,
reason,
fallback,
)
return fallback

View File

@@ -0,0 +1,166 @@
"""Tests for the LD-aware model resolver."""
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.config import ChatConfig
from backend.copilot.model_router import _FLAG_BY_CELL, _config_default, resolve_model
def _make_config() -> ChatConfig:
"""Build a config with the canonical defaults so tests read naturally."""
return ChatConfig(
fast_standard_model="anthropic/claude-sonnet-4-6",
fast_advanced_model="anthropic/claude-opus-4.7",
thinking_standard_model="anthropic/claude-sonnet-4-6",
thinking_advanced_model="anthropic/claude-opus-4.7",
)
class TestConfigDefault:
def test_fast_standard(self):
cfg = _make_config()
assert _config_default(cfg, "fast", "standard") == cfg.fast_standard_model
def test_fast_advanced(self):
cfg = _make_config()
assert _config_default(cfg, "fast", "advanced") == cfg.fast_advanced_model
def test_thinking_standard(self):
cfg = _make_config()
assert (
_config_default(cfg, "thinking", "standard") == cfg.thinking_standard_model
)
def test_thinking_advanced(self):
cfg = _make_config()
assert (
_config_default(cfg, "thinking", "advanced") == cfg.thinking_advanced_model
)
class TestResolveModel:
@pytest.mark.asyncio
async def test_missing_user_returns_fallback(self):
"""Without user_id there's no LD context — skip the lookup entirely."""
cfg = _make_config()
with patch("backend.copilot.model_router.get_feature_flag_value") as mock_flag:
result = await resolve_model("fast", "standard", None, config=cfg)
assert result == cfg.fast_standard_model
mock_flag.assert_not_called()
@pytest.mark.asyncio
async def test_missing_user_strips_whitespace_from_fallback(self):
"""Sentry MEDIUM: the anonymous-user branch returned an unstripped
config value. If ``CHAT_*_MODEL`` env carries trailing whitespace
the downstream ``resolved == tier_default`` check in
``_resolve_sdk_model_for_request`` would diverge from the
whitespace-stripped LD side, bypassing subscription mode for
every anonymous request. Strip at the source."""
cfg = ChatConfig(
fast_standard_model="anthropic/claude-sonnet-4-6 ", # trailing ws
fast_advanced_model="anthropic/claude-opus-4.7",
thinking_standard_model="anthropic/claude-sonnet-4-6",
thinking_advanced_model="anthropic/claude-opus-4.7",
)
result = await resolve_model("fast", "standard", None, config=cfg)
assert result == "anthropic/claude-sonnet-4-6"
@pytest.mark.asyncio
async def test_ld_string_override_wins(self):
"""LD-returned model string replaces the config default."""
cfg = _make_config()
with patch(
"backend.copilot.model_router.get_feature_flag_value",
new=AsyncMock(return_value="moonshotai/kimi-k2.6"),
):
result = await resolve_model("fast", "standard", "user-1", config=cfg)
assert result == "moonshotai/kimi-k2.6"
@pytest.mark.asyncio
async def test_whitespace_is_stripped(self):
cfg = _make_config()
with patch(
"backend.copilot.model_router.get_feature_flag_value",
new=AsyncMock(return_value=" xai/grok-4 "),
):
result = await resolve_model("thinking", "advanced", "user-1", config=cfg)
assert result == "xai/grok-4"
@pytest.mark.asyncio
async def test_non_string_value_falls_back_with_type_in_warning(self, caplog):
"""LD misconfigured as a boolean flag — don't try to use ``True`` as a
model name; return the config default. Warning must say
'non-string' (not 'empty string') so the LD operator knows the
flag type is wrong, not just missing a value."""
import logging
cfg = _make_config()
with caplog.at_level(logging.WARNING, logger="backend.copilot.model_router"):
with patch(
"backend.copilot.model_router.get_feature_flag_value",
new=AsyncMock(return_value=True),
):
result = await resolve_model("fast", "advanced", "user-1", config=cfg)
assert result == cfg.fast_advanced_model
assert any("non-string" in r.message for r in caplog.records)
@pytest.mark.asyncio
async def test_empty_string_falls_back_with_empty_in_warning(self, caplog):
"""When LD returns ``""`` the warning must say 'empty string'
not 'non-string' — so the operator doesn't chase a type bug
when the flag is simply unset to an empty value."""
import logging
cfg = _make_config()
with caplog.at_level(logging.WARNING, logger="backend.copilot.model_router"):
with patch(
"backend.copilot.model_router.get_feature_flag_value",
new=AsyncMock(return_value=""),
):
result = await resolve_model("fast", "standard", "user-1", config=cfg)
assert result == cfg.fast_standard_model
messages = [r.message for r in caplog.records]
assert any("empty string" in m for m in messages)
assert not any("non-string" in m for m in messages)
@pytest.mark.asyncio
async def test_ld_exception_falls_back(self):
"""LD client throws (network blip, SDK init race) — serve the default
instead of failing the whole request."""
cfg = _make_config()
with patch(
"backend.copilot.model_router.get_feature_flag_value",
new=AsyncMock(side_effect=RuntimeError("LD down")),
):
result = await resolve_model("fast", "standard", "user-1", config=cfg)
assert result == cfg.fast_standard_model
@pytest.mark.asyncio
async def test_all_four_cells_hit_distinct_flags(self):
"""Each (mode, tier) cell must route to its own flag — regression
guard against copy-paste bugs in the _FLAG_BY_CELL map."""
cfg = _make_config()
calls: list[str] = []
async def _capture(flag_key, user_id, default):
calls.append(flag_key)
return default
with patch(
"backend.copilot.model_router.get_feature_flag_value",
new=AsyncMock(side_effect=_capture),
):
await resolve_model("fast", "standard", "u", config=cfg)
await resolve_model("fast", "advanced", "u", config=cfg)
await resolve_model("thinking", "standard", "u", config=cfg)
await resolve_model("thinking", "advanced", "u", config=cfg)
assert calls == [
_FLAG_BY_CELL[("fast", "standard")].value,
_FLAG_BY_CELL[("fast", "advanced")].value,
_FLAG_BY_CELL[("thinking", "standard")].value,
_FLAG_BY_CELL[("thinking", "advanced")].value,
]
assert len(set(calls)) == 4

View File

@@ -13,12 +13,15 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
)
from pytest_mock import MockerFixture
from backend.util.exceptions import NotFoundError
from .model import (
ChatMessage,
ChatSession,
Usage,
append_and_save_message,
get_chat_session,
get_or_create_builder_session,
is_message_duplicate,
maybe_append_user_message,
upsert_chat_session,
@@ -918,3 +921,178 @@ async def test_append_and_save_message_lock_release_failure_is_ignored(
new_msg = ChatMessage(role="user", content="new msg")
result = await append_and_save_message(session.session_id, new_msg)
assert result is not None
# ─── get_or_create_builder_session ─────────────────────────────────────
@pytest.mark.asyncio
async def test_get_or_create_builder_session_raises_when_graph_not_owned(
mocker: MockerFixture,
) -> None:
"""Regression: the helper must verify the caller owns the graph before
any session lookup/creation. ``library_db().get_library_agent_by_graph_id``
returns ``None`` when the user doesn't own *graph_id*, which must surface
as :class:`NotFoundError` (mapped to HTTP 404 by the REST layer)."""
library_db_mock = mocker.MagicMock(
get_library_agent_by_graph_id=mocker.AsyncMock(return_value=None),
update_library_agent=mocker.AsyncMock(),
)
mocker.patch("backend.copilot.model.library_db", return_value=library_db_mock)
create_mock = mocker.patch(
"backend.copilot.model.create_chat_session",
new_callable=mocker.AsyncMock,
)
with pytest.raises(NotFoundError):
await get_or_create_builder_session("u1", "graph-not-mine")
# Confirms the ownership check short-circuits before we hit
# create_chat_session, so no orphaned session rows can be created.
create_mock.assert_not_awaited()
library_db_mock.update_library_agent.assert_not_awaited()
@pytest.mark.asyncio
async def test_get_or_create_builder_session_returns_existing_when_owned(
mocker: MockerFixture,
) -> None:
"""When the caller owns the graph AND a session pointer on the library
agent resolves to a live chat session, return it unchanged without
creating a new one or re-writing the pointer."""
existing_session = ChatSession.new(
"u1", dry_run=False, builder_graph_id="graph-mine"
)
existing_session.session_id = "sess-existing"
library_agent = mocker.MagicMock(
id="lib-1",
settings=mocker.MagicMock(builder_chat_session_id="sess-existing"),
)
library_db_mock = mocker.MagicMock(
get_library_agent_by_graph_id=mocker.AsyncMock(return_value=library_agent),
update_library_agent=mocker.AsyncMock(),
)
mocker.patch("backend.copilot.model.library_db", return_value=library_db_mock)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=existing_session,
)
create_mock = mocker.patch(
"backend.copilot.model.create_chat_session",
new_callable=mocker.AsyncMock,
)
result = await get_or_create_builder_session("u1", "graph-mine")
assert result is existing_session
create_mock.assert_not_awaited()
library_db_mock.update_library_agent.assert_not_awaited()
@pytest.mark.asyncio
async def test_get_or_create_builder_session_writes_pointer_on_create(
mocker: MockerFixture,
) -> None:
"""When no session pointer exists yet, create a new ChatSession and
write its id back to ``library_agent.settings.builder_chat_session_id``
so the next call resumes the same chat."""
library_agent = mocker.MagicMock(
id="lib-1",
settings=mocker.MagicMock(builder_chat_session_id=None),
)
library_db_mock = mocker.MagicMock(
get_library_agent_by_graph_id=mocker.AsyncMock(return_value=library_agent),
update_library_agent=mocker.AsyncMock(),
)
mocker.patch("backend.copilot.model.library_db", return_value=library_db_mock)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=None,
)
new_session = ChatSession.new("u1", dry_run=False, builder_graph_id="graph-mine")
new_session.session_id = "sess-new"
create_mock = mocker.patch(
"backend.copilot.model.create_chat_session",
new_callable=mocker.AsyncMock,
return_value=new_session,
)
result = await get_or_create_builder_session("u1", "graph-mine")
assert result is new_session
create_mock.assert_awaited_once()
library_db_mock.update_library_agent.assert_awaited_once()
call_kwargs = library_db_mock.update_library_agent.call_args.kwargs
assert call_kwargs["library_agent_id"] == "lib-1"
assert call_kwargs["user_id"] == "u1"
assert call_kwargs["settings"].builder_chat_session_id == "sess-new"
@pytest.mark.asyncio
async def test_get_or_create_builder_session_recreates_when_pointer_stale(
mocker: MockerFixture,
) -> None:
"""When the stored pointer no longer resolves (session was deleted),
fall through to creating a fresh session and updating the pointer."""
library_agent = mocker.MagicMock(
id="lib-1",
settings=mocker.MagicMock(builder_chat_session_id="sess-gone"),
)
library_db_mock = mocker.MagicMock(
get_library_agent_by_graph_id=mocker.AsyncMock(return_value=library_agent),
update_library_agent=mocker.AsyncMock(),
)
mocker.patch("backend.copilot.model.library_db", return_value=library_db_mock)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=None,
)
new_session = ChatSession.new("u1", dry_run=False, builder_graph_id="graph-mine")
new_session.session_id = "sess-new"
create_mock = mocker.patch(
"backend.copilot.model.create_chat_session",
new_callable=mocker.AsyncMock,
return_value=new_session,
)
result = await get_or_create_builder_session("u1", "graph-mine")
assert result is new_session
create_mock.assert_awaited_once()
library_db_mock.update_library_agent.assert_awaited_once()
def test_chat_message_from_db_round_trips_created_at() -> None:
"""ChatMessage.from_db surfaces the DB row's createdAt on the pydantic
model so the API response carries it through to the frontend's TurnStats
map (powering the hover-reveal date on the copilot UI)."""
from datetime import datetime, timezone
from prisma.models import ChatMessage as PrismaChatMessage
created_at = datetime(2026, 4, 23, 10, 15, 30, tzinfo=timezone.utc)
row = PrismaChatMessage.model_construct(
id="m1",
sessionId="sess-1",
role="assistant",
content="hi",
name=None,
toolCallId=None,
refusal=None,
toolCalls=None,
functionCall=None,
sequence=3,
durationMs=4200,
createdAt=created_at,
)
msg = ChatMessage.from_db(row)
assert msg.role == "assistant"
assert msg.content == "hi"
assert msg.sequence == 3
assert msg.duration_ms == 4200
assert msg.created_at == created_at

View File

@@ -0,0 +1,147 @@
"""Moonshot-specific pricing and cache-control helpers.
Moonshot's Kimi K2.x family is routed through OpenRouter's Anthropic-compat
shim — it speaks Anthropic's API shape but its pricing and cache behaviour
diverge from Anthropic in ways the Claude Agent SDK CLI and our baseline
cache-control gating don't handle on their own:
* **Rate card** — NOT the canonical cost source. The authoritative number
for every OpenRouter-routed turn is the reconcile task
(:mod:`openrouter_cost`), which reads ``total_cost`` directly from
``/api/v1/generation`` post-turn. This module exists purely so the
CLI's in-turn ``ResultMessage.total_cost_usd`` (which silently bills
Moonshot at Sonnet rates, ~5x the real Moonshot price because the CLI
pricing table only knows Anthropic) isn't left wildly wrong before the
reconcile fires AND so the reconcile's lookup-fail fallback records a
plausible Moonshot estimate rather than a Sonnet-rate overcharge.
Signal authority: reconcile >> this module's rate card >> CLI.
* **Cache-control** — Anthropic and Moonshot both accept the
``cache_control: {type: ephemeral}`` breakpoint on message blocks, but
our baseline path currently gates cache markers on an
``anthropic/`` / ``claude`` name match because non-Anthropic providers
(OpenAI, Grok, Gemini) 400 on the unknown field. Moonshot's
Anthropic-compat endpoint silently accepts and honours the marker —
empirically boosts cache hit rate on continuation turns — but was
caught in the non-Anthropic branch of the original gate.
:func:`moonshot_supports_cache_control` lets callers widen the gate
to include Moonshot without weakening the ``false`` answer for
OpenAI et al. (The predicate is intentionally narrow — Moonshot-only
— so callers combine it with an explicit Anthropic check at the call
site; see ``baseline/service.py::_supports_prompt_cache_markers``.)
Detection is prefix-based (``moonshotai/``). Moonshot routes every Kimi
SKU through the same Anthropic-compat surface and currently prices them
identically, so a new ``moonshotai/kimi-k3.0`` slug transparently
inherits both the rate card and the cache-control gate without editing
this file. Per-slug overrides are in :data:`_RATE_OVERRIDES_USD_PER_MTOK`
for when Moonshot eventually splits prices.
"""
from __future__ import annotations
# All Moonshot slugs share these rates as of April 2026 — Moonshot prices
# every Kimi K2.x SKU at $0.60/$2.80 per million (input/output) via
# OpenRouter. Cache-read / cache-write discounts are NOT applied here:
# OpenRouter currently exposes only a single input price per Moonshot
# endpoint; the real billed amount (with cache savings) lands via the
# reconcile path. Keep in sync with https://platform.moonshot.ai/docs/pricing.
_DEFAULT_MOONSHOT_RATE_USD_PER_MTOK: tuple[float, float] = (0.60, 2.80)
# Per-slug overrides for when Moonshot splits pricing across SKUs. Empty
# today — every slug matching ``moonshotai/`` falls back to
# :data:`_DEFAULT_MOONSHOT_RATE_USD_PER_MTOK`.
_RATE_OVERRIDES_USD_PER_MTOK: dict[str, tuple[float, float]] = {}
# Vendor prefix — matches any OpenRouter slug Moonshot ships. Keep as a
# module constant so the prefix check stays in exactly one place.
_MOONSHOT_PREFIX = "moonshotai/"
def is_moonshot_model(model: str | None) -> bool:
"""True when *model* is a Moonshot OpenRouter slug.
Prefix match against ``moonshotai/`` covers every Kimi SKU Moonshot
ships today (``kimi-k2``, ``kimi-k2.5``, ``kimi-k2.6``,
``kimi-k2-thinking``) plus any future SKU Moonshot publishes under
the same namespace. Used by both pricing and cache-control gating.
"""
return isinstance(model, str) and model.startswith(_MOONSHOT_PREFIX)
def rate_card_usd(model: str | None) -> tuple[float, float] | None:
"""Return (input, output) $/Mtok for *model* or None if non-Moonshot.
Looks up a per-slug override first, falling back to the shared
default for anything under ``moonshotai/``. Returns None for
non-Moonshot slugs (including ``None``) so callers can skip the
override without a preflight guard.
"""
if not is_moonshot_model(model):
return None
# ``is_moonshot_model`` narrowed ``model`` to str; dict.get is
# type-safe here despite the wider param annotation above.
assert model is not None
return _RATE_OVERRIDES_USD_PER_MTOK.get(model, _DEFAULT_MOONSHOT_RATE_USD_PER_MTOK)
def override_cost_usd(
*,
model: str | None,
sdk_reported_usd: float,
prompt_tokens: int,
completion_tokens: int,
cache_read_tokens: int,
cache_creation_tokens: int,
) -> float:
"""Recompute SDK turn cost from the Moonshot rate card.
Not the canonical cost source — the OpenRouter ``/generation``
reconcile (:mod:`openrouter_cost`) lands the authoritative billed
amount post-turn. This helper exists only to improve the CLI's
in-turn ``ResultMessage.total_cost_usd``:
1. So the ``cost_usd`` the client sees before the reconcile completes
isn't wildly wrong (the CLI would otherwise ship a Sonnet-rate
estimate, ~5x the real Moonshot bill).
2. So the reconcile's own lookup-fail fallback records a plausible
Moonshot estimate rather than the CLI's Sonnet number.
For Moonshot slugs we compute cost from the reported token counts;
for anything else (including Anthropic) we return the SDK number
unchanged — Anthropic slugs are priced accurately by the CLI.
Cache read / creation tokens are folded into ``prompt_tokens`` at
the full input rate because Moonshot's rate card doesn't distinguish
them at the OpenRouter surface; the reconcile has the authoritative
discount accounting for turns where Moonshot's cache engaged.
"""
if model is None:
return sdk_reported_usd
rates = rate_card_usd(model)
if rates is None:
return sdk_reported_usd
input_rate, output_rate = rates
total_prompt = prompt_tokens + cache_read_tokens + cache_creation_tokens
return (total_prompt * input_rate + completion_tokens * output_rate) / 1_000_000
def moonshot_supports_cache_control(model: str | None) -> bool:
"""True when a Moonshot *model* accepts Anthropic-style ``cache_control``.
Narrow, Moonshot-specific predicate — callers that need the full
"does this route accept cache markers" answer combine this with an
Anthropic check (see ``baseline/service.py::_supports_prompt_cache_markers``).
Named ``moonshot_*`` deliberately so the call site can't mistake it
for a universal predicate that answers correctly for Anthropic
(which also supports cache_control — this function would return
False for Anthropic slugs).
Moonshot's Anthropic-compat endpoint honours the marker. Without
it Moonshot falls back to its own automatic prefix caching, which
drifts more readily between turns (internal testing saw 0/4 cache
hits across two continuation sessions). With explicit
``cache_control`` the upstream cache hit rate rises to the same
ballpark as Anthropic's ~60-95% on continuations.
"""
return is_moonshot_model(model)

View File

@@ -0,0 +1,173 @@
"""Unit tests for Moonshot pricing and cache-control helpers."""
from __future__ import annotations
import pytest
from backend.copilot.moonshot import (
is_moonshot_model,
moonshot_supports_cache_control,
override_cost_usd,
rate_card_usd,
)
class TestIsMoonshotModel:
"""Prefix detection covers every Moonshot SKU without a slug list."""
@pytest.mark.parametrize(
"model",
[
"moonshotai/kimi-k2.6",
"moonshotai/kimi-k2-thinking",
"moonshotai/kimi-k2.5",
"moonshotai/kimi-k2",
"moonshotai/kimi-k3.0", # Future SKU must match transparently.
],
)
def test_moonshot_slugs_match(self, model: str) -> None:
assert is_moonshot_model(model) is True
@pytest.mark.parametrize(
"model",
[
"anthropic/claude-sonnet-4.6",
"anthropic/claude-opus-4.7",
"openai/gpt-4o",
"google/gemini-2.5-flash",
"xai/grok-4",
"deepseek/deepseek-v3",
"", # Empty string — not Moonshot.
],
)
def test_non_moonshot_slugs_do_not_match(self, model: str) -> None:
assert is_moonshot_model(model) is False
@pytest.mark.parametrize("model", [None, 123, ["moonshotai/kimi-k2.6"]])
def test_non_string_returns_false(self, model) -> None:
# Type-robust: never raise on unexpected types; callers pass None.
assert is_moonshot_model(model) is False
class TestRateCardUsd:
"""Rate card defaults to the shared Moonshot price for every SKU."""
def test_moonshot_default_rate(self) -> None:
assert rate_card_usd("moonshotai/kimi-k2.6") == (0.60, 2.80)
def test_future_moonshot_sku_inherits_default(self) -> None:
# Verifies the prefix-based fallback — new SKUs don't need a code
# edit to get a reasonable rate card.
assert rate_card_usd("moonshotai/kimi-k3.0") == (0.60, 2.80)
def test_non_moonshot_returns_none(self) -> None:
assert rate_card_usd("anthropic/claude-sonnet-4.6") is None
assert rate_card_usd("openai/gpt-4o") is None
class TestOverrideCostUsd:
"""Rate-card override replaces the CLI's Sonnet-rate estimate for
Moonshot turns; Anthropic and unknown slugs pass through unchanged."""
def test_moonshot_recomputes_from_rate_card(self) -> None:
"""A 29.5K-prompt Kimi turn should land at ~$0.018 on the
Moonshot rate card, not the CLI's $0.09 Sonnet-rate estimate."""
recomputed = override_cost_usd(
model="moonshotai/kimi-k2.6",
sdk_reported_usd=0.089862, # What the CLI reported (Sonnet price).
prompt_tokens=29564,
completion_tokens=78,
cache_read_tokens=0,
cache_creation_tokens=0,
)
expected = (29564 * 0.60 + 78 * 2.80) / 1_000_000
assert recomputed == pytest.approx(expected, rel=1e-9)
assert 0.017 < recomputed < 0.019 # Sanity against Moonshot's rate card.
def test_anthropic_passes_through(self) -> None:
"""Anthropic slugs are priced accurately by the CLI already —
the override returns the SDK number unchanged."""
assert (
override_cost_usd(
model="anthropic/claude-sonnet-4.6",
sdk_reported_usd=0.089862,
prompt_tokens=29564,
completion_tokens=78,
cache_read_tokens=0,
cache_creation_tokens=0,
)
== 0.089862
)
def test_unknown_non_moonshot_passes_through(self) -> None:
"""A non-Moonshot, non-Anthropic slug falls back to the SDK value
— best-effort rather than leaking a zero or a wrong rate card."""
assert (
override_cost_usd(
model="deepseek/deepseek-v3",
sdk_reported_usd=0.05,
prompt_tokens=10_000,
completion_tokens=500,
cache_read_tokens=0,
cache_creation_tokens=0,
)
== 0.05
)
def test_none_model_passes_through(self) -> None:
"""Subscription mode sets model=None — return the SDK value."""
assert (
override_cost_usd(
model=None,
sdk_reported_usd=0.07,
prompt_tokens=100,
completion_tokens=10,
cache_read_tokens=0,
cache_creation_tokens=0,
)
== 0.07
)
def test_cache_tokens_priced_at_input_rate(self) -> None:
"""OpenRouter's Moonshot endpoints don't expose a discounted
cached-input price — cache_read / cache_creation tokens are
priced at the full input rate. The reconcile path has the
authoritative discount for turns where Moonshot's cache engaged."""
recomputed = override_cost_usd(
model="moonshotai/kimi-k2.6",
sdk_reported_usd=0.5,
prompt_tokens=1000,
completion_tokens=0,
cache_read_tokens=5000,
cache_creation_tokens=2000,
)
expected = (1000 + 5000 + 2000) * 0.60 / 1_000_000
assert recomputed == pytest.approx(expected, rel=1e-9)
class TestSupportsCacheControl:
"""Gate for emitting ``cache_control: {type: ephemeral}`` on message
blocks. True for Moonshot (Anthropic-compat endpoint accepts it)
and False for everything else this module knows about — Anthropic
callers use their own ``_is_anthropic_model`` check which is
combined with this one into a wider gate."""
def test_moonshot_supports_cache_control(self) -> None:
assert moonshot_supports_cache_control("moonshotai/kimi-k2.6") is True
def test_future_moonshot_sku_supports_cache_control(self) -> None:
assert moonshot_supports_cache_control("moonshotai/kimi-k3.0") is True
@pytest.mark.parametrize(
"model",
[
"openai/gpt-4o",
"google/gemini-2.5-flash",
"xai/grok-4",
"deepseek/deepseek-v3",
"",
None,
],
)
def test_non_moonshot_does_not_support_cache_control(self, model) -> None:
assert moonshot_supports_cache_control(model) is False

View File

@@ -0,0 +1,384 @@
"""Shared helpers for draining and injecting pending messages.
Used by both the baseline and SDK copilot paths to avoid duplicating
the try/except drain, format, insert, and persist patterns.
Also provides the call-rate-limit check for the queue endpoint so
routes.py stays free of Redis/Lua details.
"""
import logging
from typing import TYPE_CHECKING, Callable
from fastapi import HTTPException
from pydantic import BaseModel
from backend.copilot.model import ChatMessage, upsert_chat_session
from backend.copilot.pending_messages import (
MAX_PENDING_MESSAGES,
PendingMessage,
PendingMessageContext,
drain_pending_messages,
format_pending_as_user_message,
push_pending_message,
)
from backend.copilot.stream_registry import get_session as get_active_session_meta
from backend.data.redis_client import get_redis_async
from backend.data.redis_helpers import incr_with_ttl
from backend.data.workspace import resolve_workspace_files
if TYPE_CHECKING:
from backend.copilot.model import ChatSession
from backend.copilot.transcript_builder import TranscriptBuilder
logger = logging.getLogger(__name__)
# Call-frequency cap for the pending-message endpoint. The token-budget
# check guards against overspend but not rapid-fire pushes from a client
# with a large budget.
PENDING_CALL_LIMIT = 30
PENDING_CALL_WINDOW_SECONDS = 60
_PENDING_CALL_KEY_PREFIX = "copilot:pending:calls:"
async def is_turn_in_flight(session_id: str) -> bool:
"""Return ``True`` when a copilot turn is actively running for *session_id*.
Used by the unified POST /stream entry point and the autopilot block so
a second message arriving while an earlier turn is still executing gets
queued into the pending buffer instead of racing the in-flight turn on
the cluster lock.
"""
active = await get_active_session_meta(session_id)
return active is not None and active.status == "running"
class QueuePendingMessageResponse(BaseModel):
"""Response returned by ``POST /stream`` with status 202 when a message
is queued because the session already has a turn in flight.
- ``buffer_length``: how many messages are now in the session's
pending buffer (after this push)
- ``max_buffer_length``: the per-session cap (server-side constant)
- ``turn_in_flight``: ``True`` if a copilot turn was running when
we checked — purely informational for UX feedback. Always ``True``
for responses from ``POST /stream`` with status 202.
"""
buffer_length: int
max_buffer_length: int
turn_in_flight: bool
async def queue_user_message(
*,
session_id: str,
message: str,
context: PendingMessageContext | None = None,
file_ids: list[str] | None = None,
) -> QueuePendingMessageResponse:
"""Push *message* into the per-session pending buffer.
The shared primitive for "a message arrived while a turn is in flight"
called from the unified POST /stream handler and the autopilot block.
Call-frequency rate limiting is the caller's responsibility (HTTP path
enforces it; internal block callers skip it).
"""
pending = PendingMessage(
content=message,
file_ids=file_ids or [],
context=context,
)
new_len = await push_pending_message(session_id, pending)
return QueuePendingMessageResponse(
buffer_length=new_len,
max_buffer_length=MAX_PENDING_MESSAGES,
turn_in_flight=await is_turn_in_flight(session_id),
)
async def queue_pending_for_http(
*,
session_id: str,
user_id: str,
message: str,
context: dict[str, str] | None,
file_ids: list[str] | None,
) -> QueuePendingMessageResponse:
"""HTTP-facing wrapper around :func:`queue_user_message`.
Owns the HTTP-only concerns that sat inline in ``stream_chat_post``:
1. Per-user call-rate cap (429 on overflow).
2. File-ID sanitisation against the user's own workspace.
3. ``{url, content}`` dict → ``PendingMessageContext`` coercion.
4. Push via ``queue_user_message``.
Raises :class:`HTTPException` with status 429 if the rate cap is hit;
otherwise returns the ``QueuePendingMessageResponse`` the handler can
serialise 1:1 into the 202 body.
"""
call_count = await check_pending_call_rate(user_id)
if call_count > PENDING_CALL_LIMIT:
raise HTTPException(
status_code=429,
detail=(
f"Too many queued message requests this minute: limit is "
f"{PENDING_CALL_LIMIT} per {PENDING_CALL_WINDOW_SECONDS}s "
"across all sessions"
),
)
sanitized_file_ids: list[str] | None = None
if file_ids:
files = await resolve_workspace_files(user_id, file_ids)
sanitized_file_ids = [wf.id for wf in files] or None
# ``PendingMessageContext`` uses the default ``extra='ignore'`` so
# unknown keys in the loose HTTP-level ``context`` dict are silently
# dropped rather than raising ``ValidationError`` + 500ing (sentry
# r3105553772). The strict mode would only help protect against
# typos, but the upstream ``StreamChatRequest.context: dict[str, str]``
# is already schemaless, so the strict mode adds no real safety.
queue_context = PendingMessageContext.model_validate(context) if context else None
return await queue_user_message(
session_id=session_id,
message=message,
context=queue_context,
file_ids=sanitized_file_ids,
)
async def check_pending_call_rate(user_id: str) -> int:
"""Increment and return the per-user push counter for the current window.
The counter is **user-global**: it counts pushes across ALL sessions
belonging to the user, not per-session. This prevents a client from
bypassing the cap by spreading rapid pushes across many sessions.
Returns the new call count. Raises nothing — callers compare the
return value against ``PENDING_CALL_LIMIT`` and decide what to do.
Fails open (returns 0) if Redis is unavailable so the endpoint stays
usable during Redis hiccups.
"""
try:
redis = await get_redis_async()
key = f"{_PENDING_CALL_KEY_PREFIX}{user_id}"
return await incr_with_ttl(redis, key, PENDING_CALL_WINDOW_SECONDS)
except Exception:
logger.warning(
"pending_message_helpers: call-rate check failed for user=%s, failing open",
user_id,
)
return 0
async def drain_pending_safe(
session_id: str, log_prefix: str = ""
) -> list[PendingMessage]:
"""Drain the pending buffer and return the full ``PendingMessage`` objects.
Returns ``[]`` on any Redis error so callers can always treat the
result as a plain list. Callers that only need the rendered string
(turn-start injection, auto-continue combined prompt) wrap this with
:func:`pending_texts_from` — we return the structured objects so the
re-queue rollback path can preserve ``file_ids`` / ``context`` that
would otherwise be stripped by a text-only conversion.
"""
try:
return await drain_pending_messages(session_id)
except Exception:
logger.warning(
"%s drain_pending_messages failed, skipping",
log_prefix or "pending_messages",
exc_info=True,
)
return []
def pending_texts_from(pending: list[PendingMessage]) -> list[str]:
"""Render a list of ``PendingMessage`` objects into plain text strings.
Shared helper for the two callers that need the rendered form:
turn-start injection (bundles the pending block into the user prompt)
and the auto-continue combined-message path.
"""
return [format_pending_as_user_message(pm)["content"] for pm in pending]
def combine_pending_with_current(
pending: list[PendingMessage],
current_message: str | None,
*,
request_arrival_at: float,
) -> str:
"""Order pending messages around *current_message* by typing time.
Pending messages whose ``enqueued_at`` is strictly greater than
``request_arrival_at`` were typed AFTER the user hit enter to start
the current turn (the "race" path: queued into the pending buffer
while ``/stream`` was still processing on the server). They belong
chronologically AFTER the current message.
Pending messages whose ``enqueued_at`` is less than or equal to
``request_arrival_at`` were typed BEFORE the current turn — usually
from a prior in-flight window that auto-continue didn't consume.
They belong BEFORE the current message.
Stable-sort within each bucket preserves enqueue order for messages
typed in the same phase. Legacy ``PendingMessage`` objects with no
``enqueued_at`` (written by older workers, defaulted to 0.0) sort as
"before everything" — the pre-fix behaviour, which is a safe default
for the rare queue entries that outlived a deploy.
"""
before: list[PendingMessage] = []
after: list[PendingMessage] = []
for pm in pending:
if request_arrival_at > 0 and pm.enqueued_at > request_arrival_at:
after.append(pm)
else:
before.append(pm)
parts = pending_texts_from(before)
if current_message and current_message.strip():
parts.append(current_message)
parts.extend(pending_texts_from(after))
return "\n\n".join(parts)
def insert_pending_before_last(session: "ChatSession", texts: list[str]) -> None:
"""Insert pending messages into *session* just before the last message.
Pending messages were queued during the previous turn, so they belong
chronologically before the current user message that was already
appended via ``maybe_append_user_message``. Inserting at ``len-1``
preserves that order: [...history, pending_1, pending_2, current_msg].
The caller must have already appended the current user message before
calling this function. If ``session.messages`` is unexpectedly empty,
a warning is logged and the messages are appended at index 0 so they
are not silently lost.
"""
if not texts:
return
if not session.messages:
logger.warning(
"insert_pending_before_last: session.messages is empty — "
"current user message was not appended before drain; "
"inserting pending messages at index 0"
)
insert_idx = max(0, len(session.messages) - 1)
for i, content in enumerate(texts):
session.messages.insert(
insert_idx + i, ChatMessage(role="user", content=content)
)
async def persist_session_safe(
session: "ChatSession", log_prefix: str = ""
) -> "ChatSession":
"""Persist *session* to the DB, returning the (possibly updated) session.
Swallows transient DB errors so a failing persist doesn't discard
messages already popped from Redis — the turn continues from memory.
"""
try:
return await upsert_chat_session(session)
except Exception as err:
logger.warning(
"%s Failed to persist pending messages: %s",
log_prefix or "pending_messages",
err,
)
return session
async def persist_pending_as_user_rows(
session: "ChatSession",
transcript_builder: "TranscriptBuilder",
pending: list[PendingMessage],
*,
log_prefix: str,
content_of: Callable[[PendingMessage], str] = lambda pm: pm.content,
on_rollback: Callable[[int], None] | None = None,
) -> bool:
"""Append ``pending`` as user rows to *session* + *transcript_builder*,
persist, and roll back + re-queue if the persist silently failed.
This is the shared mid-turn follow-up persist used by both the baseline
and SDK paths — they differ only in (a) how they derive the displayed
string from a ``PendingMessage`` and (b) what extra per-path state
(e.g. ``openai_messages``) needs trimming on rollback. Those variance
points are exposed as ``content_of`` and ``on_rollback``.
Flow:
1. Snapshot transcript + record the session.messages length.
2. Append one user row per pending message to both stores.
3. ``persist_session_safe`` — swallowed errors mean no sequences get
back-filled, which we use as the failure signal.
4. If any newly-appended row has ``sequence is None`` → rollback:
delete the appended rows, restore the transcript snapshot, call
``on_rollback(anchor)`` for the caller's own state, then re-push
each ``PendingMessage`` into the primary pending buffer so the
next turn-start drain picks them up.
Returns ``True`` when the rows were persisted with sequences, ``False``
when the rollback path fired. Callers can use this to decide whether
to log success or continue a retry loop.
"""
if not pending:
return True
session_anchor = len(session.messages)
transcript_snapshot = transcript_builder.snapshot()
for pm in pending:
content = content_of(pm)
session.messages.append(ChatMessage(role="user", content=content))
transcript_builder.append_user(content=content)
# ``persist_session_safe`` may return a ``model_copy`` of *session* (e.g.
# when ``upsert_chat_session`` patches a concurrently-updated title).
# Do NOT reassign the caller's reference — the caller already pushed the
# rows into its own ``session.messages`` above, and rollback below MUST
# delete from that same list. Inspect the returned object only to learn
# whether sequences were back-filled; if so, copy them onto the caller's
# objects so the session stays internally consistent for downstream
# ``append_and_save_message`` calls.
persisted = await persist_session_safe(session, log_prefix)
persisted_tail = persisted.messages[session_anchor:]
if len(persisted_tail) == len(pending) and all(
m.sequence is not None for m in persisted_tail
):
for caller_msg, persisted_msg in zip(
session.messages[session_anchor:], persisted_tail
):
caller_msg.sequence = persisted_msg.sequence
newly_appended = session.messages[session_anchor:]
if any(m.sequence is None for m in newly_appended):
logger.warning(
"%s Mid-turn follow-up persist did not back-fill sequences; "
"rolling back %d row(s) and re-queueing into the primary buffer",
log_prefix,
len(pending),
)
del session.messages[session_anchor:]
transcript_builder.restore(transcript_snapshot)
if on_rollback is not None:
on_rollback(session_anchor)
for pm in pending:
try:
await push_pending_message(session.session_id, pm)
except Exception:
logger.exception(
"%s Failed to re-queue mid-turn follow-up on rollback",
log_prefix,
)
return False
logger.info(
"%s Persisted %d mid-turn follow-up user row(s)",
log_prefix,
len(pending),
)
return True

View File

@@ -0,0 +1,472 @@
"""Unit tests for pending_message_helpers."""
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from backend.copilot import pending_message_helpers as helpers_module
from backend.copilot.pending_message_helpers import (
PENDING_CALL_LIMIT,
check_pending_call_rate,
combine_pending_with_current,
drain_pending_safe,
insert_pending_before_last,
persist_session_safe,
)
from backend.copilot.pending_messages import PendingMessage
# ── check_pending_call_rate ────────────────────────────────────────────
@pytest.mark.asyncio
async def test_check_pending_call_rate_returns_count(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(
helpers_module, "get_redis_async", AsyncMock(return_value=MagicMock())
)
monkeypatch.setattr(helpers_module, "incr_with_ttl", AsyncMock(return_value=3))
result = await check_pending_call_rate("user-1")
assert result == 3
@pytest.mark.asyncio
async def test_check_pending_call_rate_fails_open_on_redis_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(
helpers_module,
"get_redis_async",
AsyncMock(side_effect=ConnectionError("down")),
)
result = await check_pending_call_rate("user-1")
assert result == 0
@pytest.mark.asyncio
async def test_check_pending_call_rate_at_limit(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(
helpers_module, "get_redis_async", AsyncMock(return_value=MagicMock())
)
monkeypatch.setattr(
helpers_module,
"incr_with_ttl",
AsyncMock(return_value=PENDING_CALL_LIMIT + 1),
)
result = await check_pending_call_rate("user-1")
assert result > PENDING_CALL_LIMIT
# ── drain_pending_safe ─────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_drain_pending_safe_returns_pending_messages(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""``drain_pending_safe`` now returns the structured ``PendingMessage``
objects (not pre-formatted strings) so the auto-continue re-queue path
can preserve ``file_ids`` / ``context`` on rollback."""
msgs = [
PendingMessage(content="hello", file_ids=["f1"]),
PendingMessage(content="world"),
]
monkeypatch.setattr(
helpers_module, "drain_pending_messages", AsyncMock(return_value=msgs)
)
result = await drain_pending_safe("sess-1")
assert result == msgs
# Structured metadata survives — the bug r3105523410 guard.
assert result[0].file_ids == ["f1"]
@pytest.mark.asyncio
async def test_drain_pending_safe_returns_empty_on_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(
helpers_module,
"drain_pending_messages",
AsyncMock(side_effect=RuntimeError("redis down")),
)
result = await drain_pending_safe("sess-1", "[Test]")
assert result == []
@pytest.mark.asyncio
async def test_drain_pending_safe_empty_buffer(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
helpers_module, "drain_pending_messages", AsyncMock(return_value=[])
)
result = await drain_pending_safe("sess-1")
assert result == []
# ── combine_pending_with_current ───────────────────────────────────────
def test_combine_before_current_when_pending_older() -> None:
"""Pending typed before the /stream request → goes ahead of current
(prior-turn / inter-turn case)."""
pending = [
PendingMessage(content="older_a", enqueued_at=100.0),
PendingMessage(content="older_b", enqueued_at=110.0),
]
result = combine_pending_with_current(
pending, "current_msg", request_arrival_at=120.0
)
assert result == "older_a\n\nolder_b\n\ncurrent_msg"
def test_combine_after_current_when_pending_newer() -> None:
"""Pending queued AFTER the /stream request arrived → goes after
current. This is the race path where user hits enter twice in quick
succession (second press goes through the queue endpoint while the
first /stream is still processing)."""
pending = [
PendingMessage(content="race_followup", enqueued_at=125.0),
]
result = combine_pending_with_current(
pending, "current_msg", request_arrival_at=120.0
)
assert result == "current_msg\n\nrace_followup"
def test_combine_mixed_before_and_after() -> None:
"""Mixed bucket: older items first, current, then newer race items."""
pending = [
PendingMessage(content="way_older", enqueued_at=50.0),
PendingMessage(content="race_fast_follow", enqueued_at=125.0),
PendingMessage(content="also_older", enqueued_at=80.0),
]
result = combine_pending_with_current(
pending, "current_msg", request_arrival_at=120.0
)
# Enqueue order preserved within each bucket (stable partition).
assert result == "way_older\n\nalso_older\n\ncurrent_msg\n\nrace_fast_follow"
def test_combine_no_current_joins_pending() -> None:
"""Auto-continue case: no current message, just drained pending."""
pending = [PendingMessage(content="a"), PendingMessage(content="b")]
result = combine_pending_with_current(pending, None, request_arrival_at=0.0)
assert result == "a\n\nb"
def test_combine_legacy_zero_timestamp_sorts_before() -> None:
"""A ``PendingMessage`` from before this field existed (default 0.0)
should sort as "before everything" — safe pre-fix behaviour."""
pending = [PendingMessage(content="legacy", enqueued_at=0.0)]
result = combine_pending_with_current(
pending, "current_msg", request_arrival_at=120.0
)
assert result == "legacy\n\ncurrent_msg"
def test_combine_missing_request_arrival_falls_back_to_before() -> None:
"""If the HTTP handler didn't stamp ``request_arrival_at`` (0.0
default — older queue entries) the combine degrades gracefully to
the pre-fix behaviour: all pending goes before current."""
pending = [
PendingMessage(content="a", enqueued_at=500.0),
PendingMessage(content="b", enqueued_at=1000.0),
]
result = combine_pending_with_current(pending, "current", request_arrival_at=0.0)
assert result == "a\n\nb\n\ncurrent"
# ── insert_pending_before_last ─────────────────────────────────────────
def _make_session(*contents: str) -> Any:
session = MagicMock()
session.messages = [MagicMock(role="user", content=c) for c in contents]
return session
def test_insert_pending_before_last_single_existing_message() -> None:
session = _make_session("current")
insert_pending_before_last(session, ["queued"])
assert session.messages[0].content == "queued"
assert session.messages[1].content == "current"
def test_insert_pending_before_last_multiple_pending() -> None:
session = _make_session("current")
insert_pending_before_last(session, ["p1", "p2"])
contents = [m.content for m in session.messages]
assert contents == ["p1", "p2", "current"]
def test_insert_pending_before_last_empty_session() -> None:
session = _make_session()
insert_pending_before_last(session, ["queued"])
assert session.messages[0].content == "queued"
def test_insert_pending_before_last_no_texts_is_noop() -> None:
session = _make_session("current")
insert_pending_before_last(session, [])
assert len(session.messages) == 1
# ── persist_session_safe ───────────────────────────────────────────────
@pytest.mark.asyncio
async def test_persist_session_safe_returns_updated_session(
monkeypatch: pytest.MonkeyPatch,
) -> None:
original = MagicMock()
updated = MagicMock()
monkeypatch.setattr(
helpers_module, "upsert_chat_session", AsyncMock(return_value=updated)
)
result = await persist_session_safe(original, "[Test]")
assert result is updated
@pytest.mark.asyncio
async def test_persist_session_safe_returns_original_on_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
original = MagicMock()
monkeypatch.setattr(
helpers_module,
"upsert_chat_session",
AsyncMock(side_effect=Exception("db error")),
)
result = await persist_session_safe(original, "[Test]")
assert result is original
# ── persist_pending_as_user_rows ───────────────────────────────────────
class _FakeTranscript:
"""Minimal TranscriptBuilder shim — records append_user + snapshot/restore."""
def __init__(self) -> None:
self.entries: list[str] = []
def append_user(self, content: str, uuid: str | None = None) -> None:
self.entries.append(content)
def snapshot(self) -> list[str]:
return list(self.entries)
def restore(self, snap: list[str]) -> None:
self.entries = list(snap)
def _make_chat_message_class(
monkeypatch: pytest.MonkeyPatch,
) -> Any:
"""Return a simple ChatMessage stand-in that tracks sequence."""
class _Msg:
def __init__(self, role: str, content: str) -> None:
self.role = role
self.content = content
self.sequence: int | None = None
monkeypatch.setattr(helpers_module, "ChatMessage", _Msg)
return _Msg
@pytest.mark.asyncio
async def test_persist_pending_empty_list_is_noop(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from backend.copilot.pending_message_helpers import persist_pending_as_user_rows
_make_chat_message_class(monkeypatch)
session = MagicMock()
session.messages = []
tb = _FakeTranscript()
monkeypatch.setattr(helpers_module, "upsert_chat_session", AsyncMock())
monkeypatch.setattr(helpers_module, "push_pending_message", AsyncMock())
ok = await persist_pending_as_user_rows(session, tb, [], log_prefix="[T]")
assert ok is True
assert session.messages == []
assert tb.entries == []
@pytest.mark.asyncio
async def test_persist_pending_happy_path_appends_and_returns_true(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from backend.copilot.pending_message_helpers import persist_pending_as_user_rows
from backend.copilot.pending_messages import PendingMessage as PM
_make_chat_message_class(monkeypatch)
session = MagicMock()
session.session_id = "sess"
session.messages = []
tb = _FakeTranscript()
async def _fake_upsert(sess: Any) -> Any:
# Simulate the DB back-filling sequence numbers on success.
for i, m in enumerate(sess.messages):
m.sequence = i
return sess
monkeypatch.setattr(helpers_module, "upsert_chat_session", _fake_upsert)
push_mock = AsyncMock()
monkeypatch.setattr(helpers_module, "push_pending_message", push_mock)
pending = [PM(content="a"), PM(content="b")]
ok = await persist_pending_as_user_rows(session, tb, pending, log_prefix="[T]")
assert ok is True
assert [m.content for m in session.messages] == ["a", "b"]
assert tb.entries == ["a", "b"]
push_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_persist_pending_rollback_when_sequence_missing(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from backend.copilot.pending_message_helpers import persist_pending_as_user_rows
from backend.copilot.pending_messages import PendingMessage as PM
_make_chat_message_class(monkeypatch)
session = MagicMock()
session.session_id = "sess"
# Prior state — anchor point is len(messages) before the helper runs.
session.messages = []
tb = _FakeTranscript()
tb.entries = ["earlier-entry"]
async def _fake_upsert_fails_silently(sess: Any) -> Any:
# Simulate the "persist swallowed the error" branch — sequences stay None.
return sess
monkeypatch.setattr(
helpers_module, "upsert_chat_session", _fake_upsert_fails_silently
)
push_mock = AsyncMock()
monkeypatch.setattr(helpers_module, "push_pending_message", push_mock)
pending = [PM(content="a"), PM(content="b")]
ok = await persist_pending_as_user_rows(session, tb, pending, log_prefix="[T]")
assert ok is False
# Rollback: session.messages trimmed to anchor, transcript restored.
assert session.messages == []
assert tb.entries == ["earlier-entry"]
# Both pending messages re-queued.
assert push_mock.await_count == 2
assert push_mock.await_args_list[0].args[1] is pending[0]
assert push_mock.await_args_list[1].args[1] is pending[1]
@pytest.mark.asyncio
async def test_persist_pending_rollback_calls_on_rollback_hook(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Baseline's openai_messages trim runs via the on_rollback hook."""
from backend.copilot.pending_message_helpers import persist_pending_as_user_rows
from backend.copilot.pending_messages import PendingMessage as PM
_make_chat_message_class(monkeypatch)
session = MagicMock()
session.session_id = "sess"
session.messages = []
tb = _FakeTranscript()
async def _fails(sess: Any) -> Any:
return sess
monkeypatch.setattr(helpers_module, "upsert_chat_session", _fails)
monkeypatch.setattr(helpers_module, "push_pending_message", AsyncMock())
on_rollback_calls: list[int] = []
def _on_rollback(anchor: int) -> None:
on_rollback_calls.append(anchor)
await persist_pending_as_user_rows(
session,
tb,
[PM(content="x")],
log_prefix="[T]",
on_rollback=_on_rollback,
)
assert on_rollback_calls == [0]
@pytest.mark.asyncio
async def test_persist_pending_uses_custom_content_of(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from backend.copilot.pending_message_helpers import persist_pending_as_user_rows
from backend.copilot.pending_messages import PendingMessage as PM
_make_chat_message_class(monkeypatch)
session = MagicMock()
session.session_id = "sess"
session.messages = []
tb = _FakeTranscript()
async def _ok(sess: Any) -> Any:
for i, m in enumerate(sess.messages):
m.sequence = i
return sess
monkeypatch.setattr(helpers_module, "upsert_chat_session", _ok)
monkeypatch.setattr(helpers_module, "push_pending_message", AsyncMock())
await persist_pending_as_user_rows(
session,
tb,
[PM(content="raw")],
log_prefix="[T]",
content_of=lambda pm: f"FORMATTED:{pm.content}",
)
assert session.messages[0].content == "FORMATTED:raw"
assert tb.entries == ["FORMATTED:raw"]
@pytest.mark.asyncio
async def test_persist_pending_swallows_requeue_errors(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A broken push_pending_message on rollback must not raise upward —
the rollback still needs to trim state even if re-queue fails."""
from backend.copilot.pending_message_helpers import persist_pending_as_user_rows
from backend.copilot.pending_messages import PendingMessage as PM
_make_chat_message_class(monkeypatch)
session = MagicMock()
session.session_id = "sess"
session.messages = []
tb = _FakeTranscript()
async def _fails(sess: Any) -> Any:
return sess
monkeypatch.setattr(helpers_module, "upsert_chat_session", _fails)
monkeypatch.setattr(
helpers_module,
"push_pending_message",
AsyncMock(side_effect=RuntimeError("redis down")),
)
ok = await persist_pending_as_user_rows(
session, tb, [PM(content="x")], log_prefix="[T]"
)
# Still returns False (rolled back) — exception was logged + swallowed.
assert ok is False

View File

@@ -0,0 +1,449 @@
"""Pending-message buffer for in-flight copilot turns.
When a user sends a new message while a copilot turn is already executing,
instead of blocking the frontend (or queueing a brand-new turn after the
current one finishes), we want the new message to be *injected into the
running turn* — appended between tool-call rounds so the model sees it
before its next LLM call.
This module provides the cross-process buffer that makes that possible:
- **Producer** (chat API route): pushes a pending message to Redis and
publishes a notification on a pub/sub channel.
- **Consumer** (executor running the turn): on each tool-call round,
drains the buffer and appends the pending messages to the conversation.
The Redis list is the durable store; the pub/sub channel is a fast
wake-up hint for long-idle consumers (not used by default, but available
for future blocking-wait semantics).
A hard cap of ``MAX_PENDING_MESSAGES`` per session prevents abuse. The
buffer is trimmed to the latest ``MAX_PENDING_MESSAGES`` on every push.
"""
import json
import logging
import time
from typing import Any, cast
from pydantic import BaseModel, Field, ValidationError
from backend.data.redis_client import get_redis_async
from backend.data.redis_helpers import capped_rpush
logger = logging.getLogger(__name__)
# Per-session cap. Higher values risk a runaway consumer; lower values
# risk dropping user input under heavy typing. 10 was chosen as a
# reasonable ceiling — a user typing faster than the copilot can drain
# between tool rounds is already an unusual usage pattern.
MAX_PENDING_MESSAGES = 10
# Redis key + TTL. The buffer is ephemeral: if a turn completes or the
# executor dies, the pending messages should either have been drained
# already or are safe to drop (the user can resend).
_PENDING_KEY_PREFIX = "copilot:pending:"
_PENDING_CHANNEL_PREFIX = "copilot:pending:notify:"
_PENDING_TTL_SECONDS = 3600 # 1 hour — matches stream_ttl default
# Secondary queue that carries drained-but-awaiting-persist PendingMessages
# from the MCP tool wrapper (which drains the primary buffer and injects
# into tool output for the LLM) to sdk/service.py's _dispatch_response
# handler for StreamToolOutputAvailable, which pops and persists them as a
# separate user row chronologically after the tool_result row. This is the
# hand-off between "Claude saw the follow-up mid-turn" (wrapper) and "UI
# renders a user bubble for it" (service). Rollback path re-queues into
# the PRIMARY buffer so the next turn-start drain picks them up if the
# user-row persist fails.
_PERSIST_QUEUE_KEY_PREFIX = "copilot:pending-persist:"
# Payload sent on the pub/sub notify channel. Subscribers treat any
# message as a wake-up hint; the value itself is not meaningful.
_NOTIFY_PAYLOAD = "1"
class PendingMessageContext(BaseModel):
"""Structured page context attached to a pending message.
Default ``extra='ignore'`` (pydantic's default): unknown keys from
the loose HTTP-level ``StreamChatRequest.context: dict[str, str]``
are silently dropped rather than raising ``ValidationError`` on
forward-compat additions. The strict ``extra='forbid'`` mode was
removed after sentry r3105553772 — strict validation at this
boundary only added a 500 footgun; the upstream request model is
already schemaless so strict mode protects nothing.
"""
url: str | None = Field(default=None, max_length=2_000)
content: str | None = Field(default=None, max_length=32_000)
class PendingMessage(BaseModel):
"""A user message queued for injection into an in-flight turn."""
content: str = Field(min_length=1, max_length=32_000)
file_ids: list[str] = Field(default_factory=list, max_length=20)
context: PendingMessageContext | None = None
# Wall-clock time (unix seconds, float) the message was queued by the
# user. Used by the turn-start drain to order pending relative to the
# turn's ``current`` message: items typed *before* the current's
# /stream arrival go ahead of it; items typed *after* (race path,
# queued while the /stream HTTP request was still processing) go
# after. Defaults to 0.0 for backward compatibility with entries
# written before this field existed — those sort as "before everything"
# which matches the pre-fix behaviour.
enqueued_at: float = Field(default_factory=time.time)
def _buffer_key(session_id: str) -> str:
return f"{_PENDING_KEY_PREFIX}{session_id}"
def _notify_channel(session_id: str) -> str:
return f"{_PENDING_CHANNEL_PREFIX}{session_id}"
def _decode_redis_item(item: Any) -> str:
"""Decode a redis-py list item to a str.
redis-py returns ``bytes`` when ``decode_responses=False`` and ``str``
when ``decode_responses=True``. This helper handles both so callers
don't have to repeat the isinstance guard.
"""
return item.decode("utf-8") if isinstance(item, bytes) else str(item)
async def push_pending_message(
session_id: str,
message: PendingMessage,
) -> int:
"""Append a pending message to the session's buffer.
Returns the new buffer length. Enforces ``MAX_PENDING_MESSAGES`` by
trimming from the left (oldest) — the newest message always wins if
the user has been typing faster than the copilot can drain.
Delegates to :func:`backend.data.redis_helpers.capped_rpush` so RPUSH
+ LTRIM + EXPIRE + LLEN run atomically (MULTI/EXEC) in one round
trip; a concurrent drain (LPOP) can no longer observe the list
temporarily over ``MAX_PENDING_MESSAGES``.
Note on durability: if the executor turn crashes after a push but before
the drain window runs, the message remains in Redis until the TTL expires
(``_PENDING_TTL_SECONDS``, currently 1 hour). It is delivered on the
next turn that drains the buffer. If no turn runs within the TTL the
message is silently dropped; the user may resend it.
"""
redis = await get_redis_async()
key = _buffer_key(session_id)
payload = message.model_dump_json()
new_length = await capped_rpush(
redis,
key,
payload,
max_len=MAX_PENDING_MESSAGES,
ttl_seconds=_PENDING_TTL_SECONDS,
)
# Fire-and-forget notify. Subscribers use this as a wake-up hint;
# the buffer itself is authoritative so a lost notify is harmless.
try:
await redis.publish(_notify_channel(session_id), _NOTIFY_PAYLOAD)
except Exception as e: # pragma: no cover
logger.warning("pending_messages: publish failed for %s: %s", session_id, e)
logger.info(
"pending_messages: pushed message to session=%s (buffer_len=%d)",
session_id,
new_length,
)
return new_length
async def drain_pending_messages(session_id: str) -> list[PendingMessage]:
"""Atomically pop all pending messages for *session_id*.
Returns them in enqueue order (oldest first). Uses ``LPOP`` with a
count so the read+delete is a single Redis round trip. If the list
is empty or missing, returns ``[]``.
"""
redis = await get_redis_async()
key = _buffer_key(session_id)
# Redis LPOP with count (Redis 6.2+) returns None for missing key,
# empty list if we somehow race an empty key, or the popped items.
# Draining MAX_PENDING_MESSAGES at once is safe because the push side
# uses RPUSH + LTRIM(-MAX_PENDING_MESSAGES, -1) to cap the list to that
# same value, so the list can never hold more items than we drain here.
# If the cap is raised on the push side, raise the drain count here too
# (or switch to a loop drain).
lpop_result = await redis.lpop(key, MAX_PENDING_MESSAGES) # type: ignore[assignment]
if not lpop_result:
return []
raw_popped: list[Any] = list(lpop_result)
# redis-py may return bytes or str depending on decode_responses.
decoded: list[str] = [_decode_redis_item(item) for item in raw_popped]
messages: list[PendingMessage] = []
for payload in decoded:
try:
messages.append(PendingMessage.model_validate(json.loads(payload)))
except (json.JSONDecodeError, ValidationError, TypeError, ValueError) as e:
logger.warning(
"pending_messages: dropping malformed entry for %s: %s",
session_id,
e,
)
if messages:
logger.info(
"pending_messages: drained %d messages for session=%s",
len(messages),
session_id,
)
return messages
async def peek_pending_count(session_id: str) -> int:
"""Return the current buffer length without consuming it."""
redis = await get_redis_async()
length = await cast("Any", redis.llen(_buffer_key(session_id)))
return int(length)
async def peek_pending_messages(session_id: str) -> list[PendingMessage]:
"""Return pending messages without consuming them.
Uses LRANGE 0 -1 to read all items in enqueue order (oldest first)
without removing them. Returns an empty list if the buffer is empty
or the session has no pending messages.
"""
redis = await get_redis_async()
key = _buffer_key(session_id)
items = await cast("Any", redis.lrange(key, 0, -1))
if not items:
return []
messages: list[PendingMessage] = []
for item in items:
try:
messages.append(
PendingMessage.model_validate(json.loads(_decode_redis_item(item)))
)
except (json.JSONDecodeError, ValidationError, TypeError, ValueError) as e:
logger.warning(
"pending_messages: dropping malformed peek entry for %s: %s",
session_id,
e,
)
return messages
async def clear_pending_messages_unsafe(session_id: str) -> None:
"""Drop the session's pending buffer — **not** the normal turn cleanup.
The ``_unsafe`` suffix warns: reaching for this at turn end drops queued
follow-ups on the floor instead of running them (the bug fixed by commit
b64be73). The atomic ``LPOP`` drain at turn start is the primary consumer;
anything pushed after the drain window belongs to the next turn by
definition. Retained only as an operator/debug escape hatch for manually
clearing a stuck session and as a fixture in the unit tests.
"""
redis = await get_redis_async()
await redis.delete(_buffer_key(session_id))
# Per-message and total-block caps for inline tool-boundary injection.
# Per-message keeps a single long paste from dominating; the total cap
# keeps the follow-up block small relative to the 100 KB MCP truncation
# boundary so tool output always stays the larger share of the wrapper
# return value.
_FOLLOWUP_CONTENT_MAX_CHARS = 2_000
_FOLLOWUP_TOTAL_MAX_CHARS = 6_000
def _persist_queue_key(session_id: str) -> str:
return f"{_PERSIST_QUEUE_KEY_PREFIX}{session_id}"
async def stash_pending_for_persist(
session_id: str,
messages: list[PendingMessage],
) -> None:
"""Enqueue drained PendingMessages for UI-row persistence.
Writes each message as a JSON payload to
``copilot:pending-persist:{session_id}``. The SDK service's
tool-result dispatch handler LPOPs this queue right after appending
the tool_result row to ``session.messages``, so the resulting user
row lands at the correct chronological position (after the tool
output the follow-up was drained against).
Fire-and-forget on Redis failures: a stash failure means Claude
still saw the follow-up in tool output (the injection step ran
first), so the only consequence is a missing UI bubble. Logged
so it can be spotted.
"""
if not messages:
return
try:
redis = await get_redis_async()
key = _persist_queue_key(session_id)
payloads = [m.model_dump_json() for m in messages]
await redis.rpush(key, *payloads) # type: ignore[misc]
await redis.expire(key, _PENDING_TTL_SECONDS) # type: ignore[misc]
except Exception:
logger.warning(
"pending_messages: failed to stash %d message(s) for persist "
"(session=%s); UI will miss the follow-up bubble but Claude "
"already saw the content in tool output",
len(messages),
session_id,
exc_info=True,
)
async def drain_pending_for_persist(session_id: str) -> list[PendingMessage]:
"""Atomically drain the persist queue for *session_id*.
Returns the queued ``PendingMessage`` objects in enqueue order (oldest
first). Returns ``[]`` on any error so the service-layer caller can
always treat the result as a plain list. Called by sdk/service.py
after appending a tool_result row to ``session.messages``.
"""
try:
redis = await get_redis_async()
key = _persist_queue_key(session_id)
lpop_result = await redis.lpop( # type: ignore[assignment]
key, MAX_PENDING_MESSAGES
)
except Exception:
logger.warning(
"pending_messages: drain_pending_for_persist failed for session=%s",
session_id,
exc_info=True,
)
return []
if not lpop_result:
return []
raw_popped: list[Any] = list(lpop_result)
messages: list[PendingMessage] = []
for item in raw_popped:
try:
messages.append(
PendingMessage.model_validate(json.loads(_decode_redis_item(item)))
)
except (json.JSONDecodeError, ValidationError, TypeError, ValueError) as e:
logger.warning(
"pending_messages: dropping malformed persist-queue entry "
"for %s: %s",
session_id,
e,
)
return messages
def format_pending_as_followup(pending: list[PendingMessage]) -> str:
"""Render drained pending messages as a ``<user_follow_up>`` block.
Used by the SDK tool-boundary injection path to surface queued user
text inside a tool result so the model reads it on the next LLM round,
without starting a separate turn. Wrapped in a stable XML-style tag so
the shared system-prompt supplement can teach the model to treat the
contents as the user's continuation of their request, not as tool
output. Each message is capped to keep the block bounded even if the
user pastes long content.
"""
if not pending:
return ""
rendered: list[str] = []
total_chars = 0
dropped = 0
for idx, pm in enumerate(pending, start=1):
text = pm.content
if len(text) > _FOLLOWUP_CONTENT_MAX_CHARS:
text = text[:_FOLLOWUP_CONTENT_MAX_CHARS] + "… [truncated]"
entry = f"Message {idx}:\n{text}"
if pm.context and pm.context.url:
entry += f"\n[Page URL: {pm.context.url}]"
if pm.file_ids:
entry += "\n[Attached files: " + ", ".join(pm.file_ids) + "]"
if total_chars + len(entry) > _FOLLOWUP_TOTAL_MAX_CHARS:
dropped = len(pending) - idx + 1
break
rendered.append(entry)
total_chars += len(entry)
if dropped:
rendered.append(f"… [{dropped} more message(s) truncated]")
body = "\n\n".join(rendered)
return (
"<user_follow_up>\n"
"The user sent the following message(s) while this tool was running. "
"Treat them as a continuation of their current request — acknowledge "
"and act on them in your next response. Do not echo these tags back.\n\n"
f"{body}\n"
"</user_follow_up>"
)
async def drain_and_format_for_injection(
session_id: str,
*,
log_prefix: str,
) -> str:
"""Drain the pending buffer and produce a ``<user_follow_up>`` block.
Shared entry point for every mid-turn injection site (``PostToolUse``
hook for MCP + built-in tools, baseline between-rounds drain, etc.).
Also stashes the drained messages on the persist queue so the service
layer appends a real user row after the tool_result it rode in on —
giving the UI a correctly-ordered bubble.
Returns an empty string if nothing was queued or Redis failed; callers
can pass the result straight to ``additionalContext``.
"""
if not session_id:
return ""
try:
pending = await drain_pending_messages(session_id)
except Exception:
logger.warning(
"%s drain_pending_messages failed (session=%s); skipping injection",
log_prefix,
session_id,
exc_info=True,
)
return ""
if not pending:
return ""
logger.info(
"%s Injected %d user follow-up(s) into tool output (session=%s)",
log_prefix,
len(pending),
session_id,
)
await stash_pending_for_persist(session_id, pending)
return format_pending_as_followup(pending)
def format_pending_as_user_message(message: PendingMessage) -> dict[str, Any]:
"""Shape a ``PendingMessage`` into the OpenAI-format user message dict.
Used by the baseline tool-call loop when injecting the buffered
message into the conversation. Context/file metadata (if any) is
embedded into the content so the model sees everything in one block.
"""
parts: list[str] = [message.content]
if message.context:
if message.context.url:
parts.append(f"\n\n[Page URL: {message.context.url}]")
if message.context.content:
parts.append(f"\n\n[Page content]\n{message.context.content}")
if message.file_ids:
parts.append(
"\n\n[Attached files]\n"
+ "\n".join(f"- file_id={fid}" for fid in message.file_ids)
+ "\nUse read_workspace_file with the file_id to access file contents."
)
return {"role": "user", "content": "".join(parts)}

View File

@@ -0,0 +1,614 @@
"""Tests for the copilot pending-messages buffer.
Uses a fake async Redis client so the tests don't require a real Redis
instance (the backend test suite's DB/Redis fixtures are heavyweight
and pull in the full app startup).
"""
import asyncio
import json
from typing import Any
import pytest
from backend.copilot import pending_messages as pm_module
from backend.copilot.pending_messages import (
MAX_PENDING_MESSAGES,
PendingMessage,
PendingMessageContext,
clear_pending_messages_unsafe,
drain_and_format_for_injection,
drain_pending_for_persist,
drain_pending_messages,
format_pending_as_followup,
format_pending_as_user_message,
peek_pending_count,
peek_pending_messages,
push_pending_message,
stash_pending_for_persist,
)
# ── Fake Redis ──────────────────────────────────────────────────────
class _FakeRedis:
def __init__(self) -> None:
# Values are ``str | bytes`` because real redis-py returns
# bytes when ``decode_responses=False``; the drain path must
# handle both and our tests exercise both.
self.lists: dict[str, list[str | bytes]] = {}
self.published: list[tuple[str, str]] = []
async def rpush(self, key: str, *values: Any) -> int:
lst = self.lists.setdefault(key, [])
lst.extend(values)
return len(lst)
async def ltrim(self, key: str, start: int, stop: int) -> None:
lst = self.lists.get(key, [])
# Redis LTRIM stop is inclusive; -1 means the last element.
if stop == -1:
self.lists[key] = lst[start:]
else:
self.lists[key] = lst[start : stop + 1]
async def expire(self, key: str, seconds: int) -> int:
# Fake doesn't enforce TTL — just acknowledge.
return 1
async def publish(self, channel: str, payload: str) -> int:
self.published.append((channel, payload))
return 1
async def lpop(self, key: str, count: int) -> list[str | bytes] | None:
lst = self.lists.get(key)
if not lst:
return None
popped = lst[:count]
self.lists[key] = lst[count:]
return popped
async def llen(self, key: str) -> int:
return len(self.lists.get(key, []))
async def lrange(self, key: str, start: int, stop: int) -> list[str | bytes]:
lst = self.lists.get(key, [])
# Redis LRANGE stop is inclusive; -1 means the last element.
if stop == -1:
return list(lst[start:])
return list(lst[start : stop + 1])
async def delete(self, key: str) -> int:
if key in self.lists:
del self.lists[key]
return 1
return 0
def pipeline(self, transaction: bool = True) -> "_FakePipeline":
# Returns a fake pipeline that records ops and replays them in
# order on ``execute()``. Used by ``capped_rpush`` (push_pending_message)
# and ``incr_with_ttl`` (call-rate check) via MULTI/EXEC.
return _FakePipeline(self)
async def incr(self, key: str) -> int:
# Used by incr_with_ttl's pipeline.
current = int(self.lists.get(key, [0])[0]) if self.lists.get(key) else 0
current += 1
# We abuse the same lists dict for simple counters — store [count].
self.lists[key] = [str(current)]
return current
class _FakePipeline:
"""Async pipeline shim matching the redis-py MULTI/EXEC surface."""
def __init__(self, parent: "_FakeRedis") -> None:
self._parent = parent
self._ops: list[tuple[str, tuple[Any, ...], dict[str, Any]]] = []
# Each method just records the op; dispatching happens in execute().
def rpush(self, key: str, *values: Any) -> "_FakePipeline":
self._ops.append(("rpush", (key, *values), {}))
return self
def ltrim(self, key: str, start: int, stop: int) -> "_FakePipeline":
self._ops.append(("ltrim", (key, start, stop), {}))
return self
def expire(self, key: str, seconds: int, **kw: Any) -> "_FakePipeline":
self._ops.append(("expire", (key, seconds), kw))
return self
def llen(self, key: str) -> "_FakePipeline":
self._ops.append(("llen", (key,), {}))
return self
def incr(self, key: str) -> "_FakePipeline":
self._ops.append(("incr", (key,), {}))
return self
async def execute(self) -> list[Any]:
results: list[Any] = []
for name, args, _kw in self._ops:
fn = getattr(self._parent, name)
results.append(await fn(*args))
return results
# Support `async with pipeline() as pipe:` too.
async def __aenter__(self) -> "_FakePipeline":
return self
async def __aexit__(self, *a: Any) -> None:
return None
@pytest.fixture()
def fake_redis(monkeypatch: pytest.MonkeyPatch) -> _FakeRedis:
redis = _FakeRedis()
async def _get_redis_async() -> _FakeRedis:
return redis
monkeypatch.setattr(pm_module, "get_redis_async", _get_redis_async)
return redis
# ── Basic push / drain ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_push_and_drain_single_message(fake_redis: _FakeRedis) -> None:
length = await push_pending_message("sess1", PendingMessage(content="hello"))
assert length == 1
assert await peek_pending_count("sess1") == 1
drained = await drain_pending_messages("sess1")
assert len(drained) == 1
assert drained[0].content == "hello"
assert await peek_pending_count("sess1") == 0
@pytest.mark.asyncio
async def test_push_and_drain_preserves_order(fake_redis: _FakeRedis) -> None:
for i in range(3):
await push_pending_message("sess2", PendingMessage(content=f"msg {i}"))
drained = await drain_pending_messages("sess2")
assert [m.content for m in drained] == ["msg 0", "msg 1", "msg 2"]
@pytest.mark.asyncio
async def test_drain_empty_returns_empty_list(fake_redis: _FakeRedis) -> None:
assert await drain_pending_messages("nope") == []
# ── Buffer cap ──────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_cap_drops_oldest_when_exceeded(fake_redis: _FakeRedis) -> None:
# Push MAX_PENDING_MESSAGES + 3 messages
for i in range(MAX_PENDING_MESSAGES + 3):
await push_pending_message("sess3", PendingMessage(content=f"m{i}"))
# Buffer should be clamped to MAX
assert await peek_pending_count("sess3") == MAX_PENDING_MESSAGES
drained = await drain_pending_messages("sess3")
assert len(drained) == MAX_PENDING_MESSAGES
# Oldest 3 dropped — we should only see m3..m(MAX+2)
assert drained[0].content == "m3"
assert drained[-1].content == f"m{MAX_PENDING_MESSAGES + 2}"
# ── Clear ───────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_clear_removes_buffer(fake_redis: _FakeRedis) -> None:
await push_pending_message("sess4", PendingMessage(content="x"))
await push_pending_message("sess4", PendingMessage(content="y"))
await clear_pending_messages_unsafe("sess4")
assert await peek_pending_count("sess4") == 0
@pytest.mark.asyncio
async def test_clear_is_idempotent(fake_redis: _FakeRedis) -> None:
# Clearing an already-empty buffer should not raise
await clear_pending_messages_unsafe("sess_empty")
await clear_pending_messages_unsafe("sess_empty")
# ── Publish hook ────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_push_publishes_notification(fake_redis: _FakeRedis) -> None:
await push_pending_message("sess5", PendingMessage(content="hi"))
assert ("copilot:pending:notify:sess5", "1") in fake_redis.published
# ── Format helper ───────────────────────────────────────────────────
def test_format_pending_plain_text() -> None:
msg = PendingMessage(content="just text")
out = format_pending_as_user_message(msg)
assert out == {"role": "user", "content": "just text"}
def test_format_pending_with_context_url() -> None:
msg = PendingMessage(
content="see this page",
context=PendingMessageContext(url="https://example.com"),
)
out = format_pending_as_user_message(msg)
content = out["content"]
assert out["role"] == "user"
assert "see this page" in content
# The URL should appear verbatim in the [Page URL: ...] block.
assert "[Page URL: https://example.com]" in content
def test_format_pending_with_file_ids() -> None:
msg = PendingMessage(content="look here", file_ids=["a", "b"])
out = format_pending_as_user_message(msg)
assert "file_id=a" in out["content"]
assert "file_id=b" in out["content"]
def test_format_pending_with_all_fields() -> None:
"""All fields (content + context url/content + file_ids) should all appear."""
msg = PendingMessage(
content="summarise this",
context=PendingMessageContext(
url="https://example.com/page",
content="headline text",
),
file_ids=["f1", "f2"],
)
out = format_pending_as_user_message(msg)
body = out["content"]
assert out["role"] == "user"
assert "summarise this" in body
assert "[Page URL: https://example.com/page]" in body
assert "[Page content]\nheadline text" in body
assert "file_id=f1" in body
assert "file_id=f2" in body
# ── Followup block caps ────────────────────────────────────────────
def test_format_followup_single_message() -> None:
out = format_pending_as_followup([PendingMessage(content="hello")])
assert "<user_follow_up>" in out
assert "</user_follow_up>" in out
assert "Message 1:\nhello" in out
def test_format_followup_total_cap_drops_overflow() -> None:
"""10 × 2 KB messages must truncate past the total cap (~6 KB) with a
marker indicating how many were dropped."""
messages = [PendingMessage(content="A" * 2_000) for _ in range(10)]
out = format_pending_as_followup(messages)
# Block stays within the total cap (plus a little wrapper overhead).
# The body alone is capped at 6 KB; we allow generous overhead for the
# <user_follow_up> wrapper + headers.
assert len(out) < 8_000
assert "more message(s) truncated" in out
# The first message at least must be present.
assert "Message 1:" in out
def test_format_followup_total_cap_marker_counts_dropped() -> None:
"""The marker should name the exact number of dropped messages."""
# Each 3 KB message gets capped to 2 KB first; with ~2 KB per entry and a
# 6 KB total cap, roughly two entries fit and the rest are dropped.
messages = [PendingMessage(content="X" * 3_000) for _ in range(5)]
out = format_pending_as_followup(messages)
assert "Message 1:" in out
assert "Message 2:" in out
# Message 3 would push total past 6 KB; marker should report exactly how
# many were left out (here: messages 3, 4, 5 → 3 dropped).
assert "[3 more message(s) truncated]" in out
def test_format_followup_empty_returns_empty_string() -> None:
assert format_pending_as_followup([]) == ""
# ── Malformed payload handling ──────────────────────────────────────
@pytest.mark.asyncio
async def test_drain_skips_malformed_entries(
fake_redis: _FakeRedis,
) -> None:
# Seed the fake with a mix of valid and malformed payloads
fake_redis.lists["copilot:pending:bad"] = [
json.dumps({"content": "valid"}),
"{not valid json",
json.dumps({"content": "also valid", "file_ids": ["a"]}),
]
drained = await drain_pending_messages("bad")
assert len(drained) == 2
assert drained[0].content == "valid"
assert drained[1].content == "also valid"
@pytest.mark.asyncio
async def test_drain_decodes_bytes_payloads(
fake_redis: _FakeRedis,
) -> None:
"""Real redis-py returns ``bytes`` when ``decode_responses=False``.
Seed the fake with bytes values to exercise the ``decode("utf-8")``
branch in ``drain_pending_messages`` so a regression there doesn't
slip past CI.
"""
fake_redis.lists["copilot:pending:bytes_sess"] = [
json.dumps({"content": "from bytes"}).encode("utf-8"),
]
drained = await drain_pending_messages("bytes_sess")
assert len(drained) == 1
assert drained[0].content == "from bytes"
@pytest.mark.asyncio
async def test_peek_decodes_bytes_payloads(
fake_redis: _FakeRedis,
) -> None:
"""``peek_pending_messages`` uses the same ``_decode_redis_item`` helper
as the drain path. Seed with bytes to guard against regression.
"""
fake_redis.lists["copilot:pending:peek_bytes_sess"] = [
json.dumps({"content": "peeked from bytes"}).encode("utf-8"),
]
peeked = await peek_pending_messages("peek_bytes_sess")
assert len(peeked) == 1
assert peeked[0].content == "peeked from bytes"
# peek must NOT consume the item
assert fake_redis.lists["copilot:pending:peek_bytes_sess"] != []
# ── Concurrency ─────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_concurrent_push_and_drain(fake_redis: _FakeRedis) -> None:
"""Two pushes fired concurrently should both land; a concurrent drain
should see at least one of them (the fake serialises, so it will
always see both, but we exercise the code path either way)."""
await asyncio.gather(
push_pending_message("sess_conc", PendingMessage(content="a")),
push_pending_message("sess_conc", PendingMessage(content="b")),
)
drained = await drain_pending_messages("sess_conc")
assert len(drained) >= 1
contents = {m.content for m in drained}
assert contents <= {"a", "b"}
# ── Publish error path ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_push_survives_publish_failure(
fake_redis: _FakeRedis, monkeypatch: pytest.MonkeyPatch
) -> None:
"""A publish error must not propagate — the buffer is still authoritative."""
async def _fail_publish(channel: str, payload: str) -> int:
raise RuntimeError("redis publish down")
monkeypatch.setattr(fake_redis, "publish", _fail_publish)
length = await push_pending_message("sess_pub_err", PendingMessage(content="ok"))
assert length == 1
drained = await drain_pending_messages("sess_pub_err")
assert len(drained) == 1
assert drained[0].content == "ok"
# ── peek_pending_messages ────────────────────────────────────────────
@pytest.mark.asyncio
async def test_peek_pending_messages_returns_all_without_consuming(
fake_redis: _FakeRedis,
) -> None:
"""Peek returns all queued messages and leaves the buffer intact."""
await push_pending_message("peek1", PendingMessage(content="first"))
await push_pending_message("peek1", PendingMessage(content="second"))
peeked = await peek_pending_messages("peek1")
assert len(peeked) == 2
assert peeked[0].content == "first"
assert peeked[1].content == "second"
# Buffer must not be consumed — count still 2
assert await peek_pending_count("peek1") == 2
drained = await drain_pending_messages("peek1")
assert len(drained) == 2
@pytest.mark.asyncio
async def test_peek_pending_messages_empty_buffer(fake_redis: _FakeRedis) -> None:
"""Peek on a missing key returns an empty list without raising."""
result = await peek_pending_messages("no_such_session")
assert result == []
@pytest.mark.asyncio
async def test_peek_pending_messages_decodes_bytes_payloads(
fake_redis: _FakeRedis,
) -> None:
"""peek_pending_messages decodes bytes entries the same way drain does."""
fake_redis.lists["copilot:pending:peek_bytes"] = [
json.dumps({"content": "from bytes"}).encode("utf-8"),
]
peeked = await peek_pending_messages("peek_bytes")
assert len(peeked) == 1
assert peeked[0].content == "from bytes"
@pytest.mark.asyncio
async def test_peek_pending_messages_skips_malformed_entries(
fake_redis: _FakeRedis,
) -> None:
"""Malformed entries are skipped and valid ones are returned."""
fake_redis.lists["copilot:pending:peek_bad"] = [
json.dumps({"content": "valid peek"}),
"{bad json",
json.dumps({"content": "also valid peek"}),
]
peeked = await peek_pending_messages("peek_bad")
assert len(peeked) == 2
assert peeked[0].content == "valid peek"
assert peeked[1].content == "also valid peek"
# ── Persist queue (mid-turn follow-up UI bubble hand-off) ───────────
@pytest.mark.asyncio
async def test_stash_for_persist_enqueues_and_drain_pops_in_order(
fake_redis: _FakeRedis,
) -> None:
"""stash_pending_for_persist writes messages under the persist key;
drain_pending_for_persist LPOPs them in enqueue order."""
msgs = [
PendingMessage(content="first mid-turn follow-up"),
PendingMessage(content="second"),
]
await stash_pending_for_persist("sess-persist", msgs)
# Stored under the distinct persist key, NOT the primary buffer.
assert "copilot:pending-persist:sess-persist" in fake_redis.lists
assert "copilot:pending:sess-persist" not in fake_redis.lists
drained = await drain_pending_for_persist("sess-persist")
assert len(drained) == 2
assert drained[0].content == "first mid-turn follow-up"
assert drained[1].content == "second"
# Queue is empty after drain.
assert await drain_pending_for_persist("sess-persist") == []
@pytest.mark.asyncio
async def test_stash_for_persist_empty_list_is_noop(
fake_redis: _FakeRedis,
) -> None:
"""Passing an empty list must NOT create a Redis key (would leak
empty persist entries and require a drain for no reason)."""
await stash_pending_for_persist("sess-noop", [])
assert "copilot:pending-persist:sess-noop" not in fake_redis.lists
@pytest.mark.asyncio
async def test_drain_pending_for_persist_missing_key_returns_empty(
fake_redis: _FakeRedis,
) -> None:
assert await drain_pending_for_persist("never-stashed") == []
@pytest.mark.asyncio
async def test_drain_pending_for_persist_skips_malformed(
fake_redis: _FakeRedis,
) -> None:
fake_redis.lists["copilot:pending-persist:bad"] = [
json.dumps({"content": "good one"}),
"not json",
json.dumps({"content": "another good one"}),
]
result = await drain_pending_for_persist("bad")
assert [m.content for m in result] == ["good one", "another good one"]
@pytest.mark.asyncio
async def test_persist_queue_isolated_from_primary_buffer(
fake_redis: _FakeRedis,
) -> None:
"""Draining the persist queue must NOT touch the primary pending
buffer (and vice versa) — they serve different lifecycles."""
# Seed the primary buffer with one entry.
await push_pending_message("sess-iso", PendingMessage(content="primary"))
# Stash a separate entry on the persist queue.
await stash_pending_for_persist("sess-iso", [PendingMessage(content="persist")])
drained_persist = await drain_pending_for_persist("sess-iso")
assert [m.content for m in drained_persist] == ["persist"]
# Primary buffer untouched.
assert await peek_pending_count("sess-iso") == 1
drained_primary = await drain_pending_messages("sess-iso")
assert [m.content for m in drained_primary] == ["primary"]
@pytest.mark.asyncio
async def test_stash_for_persist_swallows_redis_failure(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A broken Redis during stash must not raise — Claude has already
seen the follow-up via tool output; the only fallout is a missing
UI bubble, which we log and move on."""
async def _broken_redis() -> Any:
raise ConnectionError("redis down")
monkeypatch.setattr(pm_module, "get_redis_async", _broken_redis)
# Must NOT raise.
await stash_pending_for_persist("sess-broken", [PendingMessage(content="lost")])
# ── drain_and_format_for_injection: shared entry point ─────────────────
@pytest.mark.asyncio
async def test_drain_and_format_for_injection_happy_path(
fake_redis: _FakeRedis,
) -> None:
"""Queued messages drain into a ready-to-inject <user_follow_up> block
AND are stashed on the persist queue for UI row hand-off."""
await push_pending_message("sess-share", PendingMessage(content="do X also"))
result = await drain_and_format_for_injection("sess-share", log_prefix="[TEST]")
assert "<user_follow_up>" in result
assert "do X also" in result
# Primary buffer drained.
assert await peek_pending_count("sess-share") == 0
# Persist queue got a copy for the UI.
persisted = await drain_pending_for_persist("sess-share")
assert len(persisted) == 1
assert persisted[0].content == "do X also"
@pytest.mark.asyncio
async def test_drain_and_format_for_injection_empty_returns_empty(
fake_redis: _FakeRedis,
) -> None:
assert await drain_and_format_for_injection("sess-empty", log_prefix="[TEST]") == ""
@pytest.mark.asyncio
async def test_drain_and_format_for_injection_swallows_redis_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
async def _broken() -> Any:
raise ConnectionError("down")
monkeypatch.setattr(pm_module, "get_redis_async", _broken)
# Must NOT raise — broken Redis becomes "nothing to inject".
assert (
await drain_and_format_for_injection("sess-broken", log_prefix="[TEST]") == ""
)
@pytest.mark.asyncio
async def test_drain_and_format_for_injection_missing_session_id() -> None:
assert await drain_and_format_for_injection("", log_prefix="[TEST]") == ""

View File

@@ -52,10 +52,15 @@ is at most as permissive as the parent:
from __future__ import annotations
import re
from typing import Literal, get_args
from typing import TYPE_CHECKING, Literal, get_args
from pydantic import BaseModel, PrivateAttr
if TYPE_CHECKING:
from collections.abc import Iterable
from backend.copilot.tools import ToolGroup
# ---------------------------------------------------------------------------
# Constants — single source of truth for all accepted tool names
# ---------------------------------------------------------------------------
@@ -66,7 +71,6 @@ from pydantic import BaseModel, PrivateAttr
ToolName = Literal[
# Platform tools (must match keys in TOOL_REGISTRY)
"add_understanding",
"ask_question",
"bash_exec",
"browser_act",
"browser_navigate",
@@ -87,6 +91,7 @@ ToolName = Literal[
"get_agent_building_guide",
"get_doc_page",
"get_mcp_guide",
"get_sub_session_result",
"list_folders",
"list_workspace_files",
"memory_forget_confirm",
@@ -99,12 +104,14 @@ ToolName = Literal[
"run_agent",
"run_block",
"run_mcp_tool",
"run_sub_session",
"search_docs",
"search_feature_requests",
"update_folder",
"validate_agent_graph",
"view_agent_output",
"web_fetch",
"web_search",
"write_workspace_file",
# SDK built-ins
"Agent",
@@ -121,9 +128,16 @@ ToolName = Literal[
# Frozen set of all valid tool names — derived from the Literal.
ALL_TOOL_NAMES: frozenset[str] = frozenset(get_args(ToolName))
# SDK built-in tool names — uppercase-initial names are SDK built-ins.
# SDK built-in tool names — tools provided by the Claude Code CLI that our
# code does not implement directly. ``TodoWrite`` is DELIBERATELY excluded:
# baseline mode ships an MCP-wrapped platform version
# (``tools/todo_write.py``), while SDK mode still uses the CLI-native
# original via ``_SDK_BUILTIN_ALWAYS`` in ``sdk/tool_adapter.py`` — the
# MCP copy is filtered out there. ``Task`` remains an SDK-only built-in
# (for queue-backed context-isolation on baseline, use ``run_sub_session``
# instead).
SDK_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset(
n for n in ALL_TOOL_NAMES if n[0].isupper()
{"Agent", "Edit", "Glob", "Grep", "Read", "Task", "WebSearch", "Write"}
)
# Platform tool names — everything that isn't an SDK built-in.
@@ -360,13 +374,17 @@ def apply_tool_permissions(
permissions: CopilotPermissions,
*,
use_e2b: bool = False,
disabled_groups: Iterable[ToolGroup] = (),
) -> tuple[list[str], list[str]]:
"""Compute (allowed_tools, extra_disallowed) for :class:`ClaudeAgentOptions`.
Takes the base allowed/disallowed lists from
:func:`~backend.copilot.sdk.tool_adapter.get_copilot_tool_names` /
:func:`~backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools` and
applies *permissions* on top.
applies *permissions* on top. Tools belonging to any *disabled_groups*
are hidden from the base allowed list — use this to gate capability
groups (e.g. ``"graphiti"`` when the memory backend is off for the
current user).
Returns:
``(allowed_tools, extra_disallowed)`` where *allowed_tools* is the
@@ -376,13 +394,16 @@ def apply_tool_permissions(
"""
from backend.copilot.sdk.tool_adapter import (
_READ_TOOL_NAME,
BASELINE_ONLY_MCP_TOOLS,
MCP_TOOL_PREFIX,
get_copilot_tool_names,
get_sdk_disallowed_tools,
)
from backend.copilot.tools import TOOL_REGISTRY
base_allowed = get_copilot_tool_names(use_e2b=use_e2b)
base_allowed = get_copilot_tool_names(
use_e2b=use_e2b, disabled_groups=disabled_groups
)
base_disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
if permissions.is_empty():
@@ -416,7 +437,14 @@ def apply_tool_permissions(
# keeping only those present in the original base_allowed list.
def to_sdk_names(short: str) -> list[str]:
names: list[str] = []
if short in TOOL_REGISTRY:
if short in BASELINE_ONLY_MCP_TOOLS:
# Baseline ships MCP versions of these (Task/TodoWrite) for
# model-flexibility parity, but SDK mode uses the CLI-native
# originals. Permissions target the CLI built-in here so
# ``base_allowed`` (which excludes the MCP wrappers) still
# matches.
names.append(short)
elif short in TOOL_REGISTRY:
names.append(f"{MCP_TOOL_PREFIX}{short}")
elif short in _SDK_TO_MCP:
# Map SDK built-in file tool to its MCP equivalent.

View File

@@ -582,6 +582,11 @@ class TestApplyToolPermissions:
class TestSdkBuiltinToolNames:
def test_expected_builtins_present(self):
# ``TodoWrite`` is DELIBERATELY absent: baseline ships an MCP-wrapped
# platform version for model-flexibility parity, so it appears in
# PLATFORM_TOOL_NAMES / TOOL_REGISTRY instead. ``Task`` remains
# SDK-only — baseline uses ``run_sub_session`` for the equivalent
# context-isolation role.
expected = {
"Agent",
"Read",
@@ -591,9 +596,9 @@ class TestSdkBuiltinToolNames:
"Grep",
"Task",
"WebSearch",
"TodoWrite",
}
assert expected.issubset(SDK_BUILTIN_TOOL_NAMES)
assert "TodoWrite" not in SDK_BUILTIN_TOOL_NAMES
def test_platform_names_match_tool_registry(self):
"""PLATFORM_TOOL_NAMES (derived from ToolName Literal) must match TOOL_REGISTRY keys."""

Some files were not shown because too many files have changed in this diff Show More