Compare commits

...

24 Commits

Author SHA1 Message Date
Zamil Majdy
7a9b0827bc Merge branch 'dev' into fix/copilot-stream-errors-and-queue-bubbles 2026-04-30 12:48:13 +07:00
Zamil Majdy
7cc1edc61f dx(pr-polish): use --json bucket instead of awk text-column parsing (#12951)
## Why

`/pr-polish` was prematurely emitting `CLEAN-POLL` while CI was still
pending, because the polish-polling loop's CI gate parsed `gh pr checks
$PR` text columns with `awk '{print $2}'`. That works fine for plain job
names, but breaks on jobs with spaces or parens like `test (3.11)`,
`Analyze (python)`, where column 2 is the version `(3.11)` — so `grep -q
"pending"` matched on column 2 of OTHER rows but missed the actual
pending entries. Real symptom on PR #12948: the orchestrator reported
`ORCHESTRATOR:DONE` while `test (3.11/3.12/3.13)` and `Check PR Status`
were still running.

## What

Add a "Concrete CI fetch" subsection right after the polish-polling
pseudocode block, showing the `--json bucket` shape that bypasses the
column-parsing trap entirely. Also flag the `bucket` vs `conclusion`
gotcha (the REST API uses `conclusion`; `gh pr checks --json` only
exposes `bucket`).
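
For illustration, a minimal bucket-based gate along the lines of the new subsection (a Python stand-in for the skill's shell snippet; the function name is hypothetical):

```python
import json
import subprocess

def ci_still_pending(pr_number: int) -> bool:
    # Read check buckets as JSON instead of awk-parsing text columns, so job
    # names like "test (3.11)" can never shift the status into another column.
    result = subprocess.run(
        ["gh", "pr", "checks", str(pr_number), "--json", "bucket"],
        capture_output=True, text=True,
    )
    # gh can exit non-zero while checks are failing or still pending, so the
    # gate reads the JSON on stdout rather than the return code.
    rows = json.loads(result.stdout or "[]")
    return any(row.get("bucket") == "pending" for row in rows)
```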

## How

Surgical additive edit — the existing pseudocode + state machine is
preserved; the new subsection just translates the abstract
`fetch_check_runs(PR)` into a concrete one-liner so the next implementer
doesn't reach for `awk` again.

## Test plan

- [x] Verified the regression against PR #12948: bucket-based polling
correctly identified 4 pending checks the awk path missed
- [x] Confirmed `gh pr checks {N} --json conclusion` errors with
`Unknown JSON field: "conclusion"` (this gotcha is now noted in the
skill)
2026-04-30 12:44:25 +07:00
majdyz
9969b1ac07 fix(copilot): make CoPilotExecutor picklable for forkserver start
`self._active_tasks_lock = threading.Lock()` in `__init__` (added in #12877
to make `_cleanup_completed_tasks` thread-safe) holds a `_thread.lock`
that the forkserver/spawn start method cannot serialize. With it set
eagerly, `Process(target=self.execute_run_command).start()` from
`AppProcess.start()` raises `TypeError: cannot pickle '_thread.lock'
object` and `poetry run app` aborts at startup before the REST server
binds.

Move the lock to a lazy `@property _active_tasks_lock` so the parent
process never holds a real `threading.Lock` instance — the lock is
materialized inside the forked child the first time
`_cleanup_completed_tasks` runs, where pickling is no longer in play.
This mirrors the existing lazy-init pattern already used for the
ThreadPoolExecutor, RabbitMQ clients, and consumer threads in this
class.
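
A minimal sketch of that lazy-lock pattern (class trimmed to the relevant piece):

```python
import threading

class CoPilotExecutor:
    def __init__(self):
        # No threading.Lock here: the parent process stays picklable for the
        # forkserver/spawn Process(...).start() call.
        self._lazy_active_tasks_lock: threading.Lock | None = None

    @property
    def _active_tasks_lock(self) -> threading.Lock:
        # Materialized inside the child the first time _cleanup_completed_tasks
        # needs it, when pickling is no longer in play.
        if self._lazy_active_tasks_lock is None:
            self._lazy_active_tasks_lock = threading.Lock()
        return self._lazy_active_tasks_lock
```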
2026-04-30 12:31:02 +07:00
majdyz
d765715fbc fix(copilot): warn on path-traversal in delete_stale_cli_session_file
Self-review: the projects-base guard was returning silently. Mirror the
warn-shape from `_write_cli_session_to_disk` so an out-of-base resolve
surfaces as a Sentry-visible warning. Unreachable in normal operation
(server-generated UUID + deterministic `cli_session_path`), but a hit
would indicate a config or tampering issue worth seeing.
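
Roughly the guard shape being mirrored (illustrative signature; only the warn-instead-of-silent-return behavior is the point):

```python
import logging
from pathlib import Path

logger = logging.getLogger(__name__)

def delete_stale_cli_session_file(session_path: Path, projects_base: Path) -> None:
    resolved = session_path.resolve()
    if not resolved.is_relative_to(projects_base.resolve()):
        # Out-of-base resolve: unreachable in normal operation, but surface it
        # as a Sentry-visible warning instead of returning silently.
        logger.warning("CLI session path resolved outside projects base: %s", resolved.name)
        return
    try:
        resolved.unlink()
    except FileNotFoundError:
        pass  # already gone; nothing stale to delete
```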
2026-04-30 12:06:21 +07:00
majdyz
8e13d4cb27 fix(copilot/frontend): guard mid-turn poll against session switches
Mirror the request-time-sessionId pattern from usePeekOnBoundary into
useMidTurnDrainPromotion: capture sessionId at request time, compare to
a live ref on resolve, bail if the user switched sessions while the GET
was in flight. Without this, a slow peek for session A could promote
chips into session B's message list after a switch.

Cancellation-flag was tried first but is too broad — this effect re-runs
on every chip-append, which would wrongly invalidate an in-flight poll
for the same session. The sessionId comparison only invalidates on
actual session changes, preserving the chip-append-during-poll race
guarantee.

Added a regression test that holds a peek in flight, switches sessions
mid-resolve, and asserts no promotion fires on the old session's
setMessages.
2026-04-30 11:47:13 +07:00
majdyz
e59fe5af76 fix(copilot): address CodeRabbit review
- delete_stale_cli_session_file: drop exists() TOCTOU; catch
  FileNotFoundError; log basename + strerror only on unexpected OSError.
- useCopilotPendingChips: unify bubble id across auto-continue and
  mid-turn promotion paths via `bubbleIdFor(chip) = pending-chip-{uuid}`,
  so a poll resolving after auto-continue already promoted the same chip
  no longer renders it twice.
- usePeekOnBoundary: capture sessionId at request time and guard the
  .then() callback against a stale response that resolves after the user
  switched sessions (prevents old-session chips bleeding into the new
  session).
2026-04-30 11:17:32 +07:00
majdyz
9e8622c1d1 fixup(copilot): logger.info/warn (not debug); restore chip state diagram 2026-04-30 10:40:39 +07:00
Zamil Majdy
0dcd25f73f Merge branch 'dev' into fix/copilot-stream-errors-and-queue-bubbles 2026-04-30 10:28:57 +07:00
majdyz
78619ba090 fix(copilot): persist context-rescued retry, dedupe error UI, race-safe chips
Three user-reported regressions on dev (chat-mode-option flag), all in one
PR because they share the same surface area:

1. Disappearing queued messages — chips were stored as bare strings keyed
   by array index; the mid-turn poll captured a stale snapshot and used a
   slice-based ``setQueuedMessages(remaining)`` that overwrote any chip the
   user appended during the in-flight peek. Fixed by giving each chip a
   frontend-only UUID, promoting one bubble per chip, and using a functional
   ``setChips(prev => prev.filter(c => !drainedIds.has(c.id)))`` so newly
   appended chips survive the race.

2. Prompt-too-long recurring on the same session for days (SENTRY-1207,
   191 occurrences) — the T2+ context-error retry branch dropped session_id
   to dodge "Session ID already in use", so the recovery CLI wrote to a
   random path and the post-turn upload silently grabbed the stale
   pre-failure file. Next turn re-resumed from the same bloated GCS copy
   and re-tripped, ad infinitum. Fixed by clearing the local session file
   first via the new ``delete_stale_cli_session_file`` helper, then keeping
   ``session_id`` so the CLI's recovery write lands on the predictable
   path that ``upload_transcript`` reads (see the sketch after this list).

3. Double error UI — backend appends a persisted error marker to
   ``session.messages`` AND yields a ``StreamError`` SSE event on the same
   final-failure path. Frontend rendered the marker as an in-line ErrorCard
   bubble and ``error`` from useChat as a trailing red banner — same string,
   twice. Fixed by adding a top-level ``lastAssistantHasErrorMarker`` memo
   in ChatMessagesContainer and gating the banner on ``!lastIsErrorMarker``.
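
For item 2, a rough sketch of the recovery ordering (everything except ``delete_stale_cli_session_file``, ``session_id``, and ``upload_transcript`` is a hypothetical stand-in):

```python
def recover_from_context_error(session_id, delete_stale_cli_session_file, run_recovery_cli):
    # 1. Clear the stale local session file so the CLI won't reject the id as
    #    "Session ID already in use".
    delete_stale_cli_session_file(session_id)
    # 2. Keep the same session_id so the recovery write lands on the predictable
    #    cli_session_path that the post-turn upload_transcript reads.
    run_recovery_cli(session_id=session_id)
```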

Tests: 216 backend SDK tests pass, 29 focused frontend tests pass
(useCopilotPendingChips, ChatMessagesContainer error-banner-dedup,
makePromotedBubble, plus regression coverage for the new
delete_stale_cli_session_file helper and the retry session_id reuse).
2026-04-30 10:27:02 +07:00
Zamil Majdy
4a1741cc15 fix(platform): cancel-banner copy + clearer 422 on currency mismatch (#12947)
## Why

Two regressions surfaced after
[#12933](https://github.com/Significant-Gravitas/AutoGPT/pull/12933)
merged to `dev`:

1. **Cancel-pending banner shows wrong copy.** The merged PR moved
cancel-at-period-end from `BASIC` → `NO_TIER`, but
`PendingChangeBanner.isCancellation` was still keyed on `"BASIC"`. As a
result, a user who cancels their sub now sees *"Scheduled to downgrade
to No subscription on …"* instead of the intended *"Scheduled to cancel
your subscription on …"*. Caught by Sentry on the merged PR.

2. **Currency-mismatch downgrade returns 502 (looks like outage).** A
user with an existing GBP-active sub (Max Price has
`currency_options.gbp`) tried to downgrade to Pro and got 502. The
backend logs show:
   ```
   stripe._error.InvalidRequestError: The price specified only supports `usd`.
   This doesn't match the expected currency: `gbp`.
   ```
The Pro Price is USD-only; Stripe rejects `SubscriptionSchedule.modify`
because phases must share currency. Wrapping that in a generic 502 hid
the real cause and made it read like a Stripe outage.

## What

* Frontend: flip `PendingChangeBanner.isCancellation` from `pendingTier
=== "BASIC"` to `"NO_TIER"`. Update both component and page-level tests
that exercised the cancellation branch.
* Backend: catch `stripe.InvalidRequestError` whose message mentions
`currency` in `update_subscription_tier`, and return **422** with *"Tier
change unavailable for your current billing currency. Cancel your
subscription and re-subscribe at the target tier, or contact support."*
— so users see the actual reason, not a misleading outage message. Other
`StripeError` paths still return 502.
* New backend test asserts the currency-mismatch branch returns 422 with
the new copy.

## How

* `PendingChangeBanner.tsx` line 28: a one-line flip (`"BASIC"` →
`"NO_TIER"`).
* `subscription_routes_test.py` and `PendingChangeBanner.test.tsx`
updated to use `NO_TIER` for the cancellation fixture.
* `v1.py` `update_subscription_tier` adds a typed `except
stripe.InvalidRequestError` branch ahead of the generic `StripeError`;
only currency-mismatch messages get the special 422, everything else
falls through to the existing 502.
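
Roughly the branch ordering, as a sketch (helper name and 502 detail are illustrative; the 422 copy is quoted from above):

```python
import stripe
from fastapi import HTTPException

CURRENCY_MISMATCH_DETAIL = (
    "Tier change unavailable for your current billing currency. Cancel your "
    "subscription and re-subscribe at the target tier, or contact support."
)

def modify_schedule_or_raise(schedule_id: str, phases: list[dict]) -> None:
    try:
        stripe.SubscriptionSchedule.modify(schedule_id, phases=phases)
    except stripe.InvalidRequestError as exc:
        # Typed branch ahead of the generic StripeError handler: only currency
        # mismatches get the 422, everything else still reads as a 502.
        if "currency" in str(exc).lower():
            raise HTTPException(status_code=422, detail=CURRENCY_MISMATCH_DETAIL)
        raise HTTPException(status_code=502, detail="Stripe request failed")
    except stripe.StripeError:
        raise HTTPException(status_code=502, detail="Stripe request failed")
```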

## The real fix lives in Stripe config

The defensive 422 here is just a clearer error surface. To actually
unblock GBP/EUR users from changing tiers, the per-tier Stripe Prices
(Pro, and Basic if priced) need `currency_options` for GBP added — Max
already has this, which is why Max checkout shows the £/$ toggle. Stripe
locks `currency_options` after a Price has been transacted, so the
procedure is: create a new Price with USD + GBP from the start → update
the `stripe-price-ids` LD flag to the new Price ID. No further code
change required; same Price ID stays per tier, multiple currencies
inside it.

## Checklist

- [x] Component test for new banner copy
- [x] Backend test for 422 currency-mismatch branch
- [x] Format / lint / types pass
- [x] No protected route added — N/A
2026-04-30 10:25:02 +07:00
Krzysztof Czerwinski
c08b9774dc fix(backend/push): skip OS push for onboarding payloads (#12944)
## Why

[#12723](https://github.com/Significant-Gravitas/AutoGPT/pull/12723)
wired Web Push fanout into `AsyncRedisNotificationEventBus.publish()` so
copilot completion events reach users with the tab closed. But the bus
is also used by `data/onboarding.py` for in-page step toasts, and those
started firing OS-level system notifications (`increment_runs`,
`step_completed`, etc.) — unwanted noise.

## What

Smallest possible patch: skip the OS push fanout when `payload.type ==
"onboarding"`. WebSocket delivery is unchanged.

## How

```python
async def publish(self, event: NotificationEvent) -> None:
    await self.publish_event(event, event.user_id)
    # Skip OS push for onboarding step toasts — those are in-page only.
    # TODO: remove once the onboarding/wallet rework lands.
    if event.payload.model_dump().get("type") == "onboarding":
        return
    ...
```

Five-line addition in `backend/data/notification_bus.py`. Marked `TODO`
to remove once the upcoming onboarding/wallet rework decides per-event
whether a system notification is desired.

Tests: added `test_publish_skips_web_push_for_onboarding`; existing
fanout tests continue to validate the happy path with non-onboarding
payloads.

## Test plan

- [x] `poetry run format` (ruff + isort + black + pyright)
- [ ] CI: `poetry run pytest backend/data/notification_bus_test.py`
- [ ] Manual on dev: trigger onboarding step → confirm no OS
notification; finish copilot session → confirm OS notification still
fires.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-29 16:20:53 +00:00
Zamil Majdy
fe3d6fb118 feat(platform): subscription credit grants + paywall gate + dialog UX + cross-pod cache (#12933)
## Why

Started as a regression fix for admin-granted user downgrades hitting
Stripe Checkout, broadened to close the surrounding gaps in the Stripe
billing flow that surfaced during testing. Three concrete user-facing
problems the PR resolves:

1. **Admin-granted users couldn't change tier in-app** when their
current tier had no `stripe-price-id-*` LD configured — clicking
Downgrade silently routed to a paid-signup Stripe Checkout instead of
just changing the tier.
2. **Subscription payments granted nothing visible to users** — paying
£20–£320/mo gave higher rate-limit multipliers but no AutoPilot credits
in the user's balance, despite a dialog promising "credit to your next
Stripe invoice" (which users naturally read as AutoGPT credits).
3. **Tier oscillated across page refreshes** — `get_user_by_id` was
process-local cached, so dev's 4 server pods each held their own copy.
Tier could read MAX on one pod and BASIC on another for ~5 min after a
webhook update, depending on which pod the request landed on.

Plus three structural improvements caught during review:

4. **No paywall enforcement for paid-cohort users without subscription**
— non-beta users on `BASIC` (no Stripe sub) could freely use AutoPilot.
5. **Upgrade/downgrade dialog copy was misleading** — implied a Stripe
redirect that doesn't happen for existing-sub modifications, used
"credit" ambiguously, and didn't surface the next-invoice date.
6. **Top-up Checkout created an ephemeral Stripe Product per session** —
no canonical Product for dashboard reporting, no way to scope coupons to
top-ups.

## What

### 1. Admin-granted downgrades skip Checkout (price-id-pruning
regression)

`update_subscription_tier()` used to gate its modify-or-DB-flip block on
`current_tier_price_id is not None`. When a tier was pruned from
`stripe-price-ids` LD, that gate skipped the inner DB-flip branch and
the request fell through to Checkout — sending admin-granted users to a
paid-signup flow when they were trying to *reduce* their tier. Drop the
gate and call `modify_stripe_subscription_for_tier()` unconditionally —
the function self-reports `False` when there's no Stripe sub. One
uniform path for everyone now.

### 2. Subscription credit grant on every paid Stripe invoice

New `invoice.payment_succeeded` webhook handler at
[`credit.py:handle_subscription_payment_success`](autogpt_platform/backend/backend/data/credit.py)
adds a `GRANT` transaction equal to `invoice.amount_paid`, keyed by
`INVOICE-{id}` for idempotency (Stripe webhook retries cannot
double-grant). Initial signup, monthly renewal, and prorated upgrade
charges all surface as AutoGPT balance bumps the moment Stripe confirms
the charge. Skipped: non-subscription invoices, $0 invoices, ENTERPRISE
users.
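
A minimal sketch of the idempotency shape (an in-memory set stands in for the DB-level transaction key; everything except the `INVOICE-{id}` keying is an assumption):

```python
_granted_invoice_keys: set[str] = set()  # stands in for a unique transaction key in the DB

def handle_subscription_payment_success(invoice_id: str, subscription_id: str | None,
                                         amount_paid_cents: int) -> bool:
    # Non-subscription and $0 invoices grant nothing.
    if subscription_id is None or amount_paid_cents == 0:
        return False
    key = f"INVOICE-{invoice_id}"
    if key in _granted_invoice_keys:
        return False  # Stripe webhook retry: the grant already exists, so this is a no-op
    _granted_invoice_keys.add(key)
    # ...record a GRANT transaction equal to amount_paid on the user's balance...
    return True
```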

### 3. Cross-pod user cache

[`user.py:31`](autogpt_platform/backend/backend/data/user.py#L31)
`cache_user_lookup = cached(maxsize=1000, ttl_seconds=300,
shared_cache=True)`. Single line — moves the cache to Redis so all
server pods read/write the same key. The existing
`get_user_by_id.cache_delete(user_id)` invalidations now propagate
cross-pod.

### 4. PaywallGate

New
[`PaywallGate`](autogpt_platform/frontend/src/app/(platform)/PaywallGate/PaywallGate.tsx)
wraps the `(platform)/layout.tsx` route group. When
`ENABLE_PLATFORM_PAYMENT === true` (paid cohort) AND `subscription.tier
=== "BASIC"`, redirects to `/profile/credits` where the credits page
shows a "Pick a plan to continue using AutoGPT" banner above the tier
picker.

Notes:
- **Beta cohort skips entirely** (flag off → `useGetSubscriptionStatus`
query disabled, no redirect).
- **Gates on DB tier, not `has_active_stripe_subscription`** — Sentry
caught that a transient Stripe API error in
`get_active_subscription_period_end()` would set `has_active=false` for
paying users, locking them out. The DB tier is set by webhooks and
persists locally; Stripe API hiccups don't flip it.
- **Exempt routes**: `/profile`, `/admin`, `/auth`, `/login`, `/signup`,
`/reset-password`, `/error`, `/unauthorized`, `/health`. Onboarding
lives in the sibling `(no-navbar)` group, so this gate doesn't conflict
with the in-flight onboarding-paywall integration.

### 5. Upgrade/downgrade dialog clarity

`SubscriptionStatusResponse` now exposes
`has_active_stripe_subscription: bool` and `current_period_end: int |
None`, computed via a new
[`get_active_subscription_period_end`](autogpt_platform/backend/backend/data/credit.py)
helper. Frontend dialogs branch on those:

**Upgrade — modify-in-place** (existing sub):
> Your subscription is upgraded to MAX immediately. On your next invoice
on May 21, 2026, your saved card is charged for the upgrade proration
since today plus the next month at the new rate, with the unused portion
of your current plan automatically deducted. Credits matching the paid
amount are added to your AutoGPT balance once Stripe confirms the
charge.

**Upgrade — Checkout** (no sub):
> You'll be redirected to Stripe to enter payment details and start your
MAX subscription. The first invoice's amount is added to your AutoGPT
balance once Stripe confirms the charge.

**Downgrade (paid → paid)**:
> Switching to PRO takes effect at the end of your current billing
period on May 21, 2026 — no charge today. You keep your current plan
until then. From that date your saved card is billed at the PRO rate,
and matching credits are added to your AutoGPT balance with each paid
invoice.

Toast wording on success matches dialog. Tier labels run through
`getTierLabel()` so we render "Pro/Max/Business" not "PRO/MAX/BUSINESS"
(Sentry-flagged in review).

### 6. Top-up Stripe Product ID via LD flag

New `STRIPE_PRODUCT_ID_TOPUP` LD flag. **Unset (default)** → legacy
inline `product_data` (Stripe creates an ephemeral product per Checkout
— backward-compatible with current behavior). **Set to a Stripe Product
ID** → line item references that Product so all top-ups group under one
entity in Stripe Dashboard reporting; per-session amount stays dynamic
via `price_data.unit_amount`. The two paths are mutually exclusive
(Stripe rejects `product` + `product_data` together).
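
A sketch of the two mutually exclusive line-item shapes (field names follow Stripe's Checkout API; the helper itself is hypothetical):

```python
def build_topup_line_item(amount_cents: int, topup_product_id: str | None) -> dict:
    price_data: dict = {"currency": "usd", "unit_amount": amount_cents}
    if topup_product_id:
        # Flag set: reference the canonical Product so all top-ups group together
        # in Stripe Dashboard reporting; the amount stays per-session.
        price_data["product"] = topup_product_id
    else:
        # Flag unset: legacy inline product_data, Stripe creates an ephemeral Product.
        price_data["product_data"] = {"name": "AutoGPT credit top-up"}
    # Stripe rejects `product` + `product_data` together, hence the either/or above.
    return {"price_data": price_data, "quantity": 1}
```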

## How

- Backend changes confined to
[`v1.py`](autogpt_platform/backend/backend/api/features/v1.py),
[`credit.py`](autogpt_platform/backend/backend/data/credit.py),
[`user.py`](autogpt_platform/backend/backend/data/user.py),
[`feature_flag.py`](autogpt_platform/backend/backend/util/feature_flag.py).
- Frontend changes: new
[`PaywallGate`](autogpt_platform/frontend/src/app/(platform)/PaywallGate/PaywallGate.tsx)
component + small edits to
[`(platform)/layout.tsx`](autogpt_platform/frontend/src/app/(platform)/layout.tsx),
`SubscriptionTierSection.tsx`, `useSubscriptionTierSection.ts`,
`helpers.ts`.
- Both backend and frontend pass `user.id` to LD context (verified in
[`feature_flag.py:_fetch_user_context_data`](autogpt_platform/backend/backend/util/feature_flag.py)
and
[`feature-flag-provider.tsx`](autogpt_platform/frontend/src/services/feature-flags/feature-flag-provider.tsx))
for proper per-user targeting.

### Out of scope (follow-ups)

- Hard-paywall onboarding integration (Lluis's work — coordinated;
PaywallGate wraps `(platform)/layout.tsx` and onboarding lives in
`(no-navbar)`, so they don't conflict).
- Beta-users-as-Stripe-trial migration.
- Max-cap usage alerting + "Contact us" routing.
- "No Active Subscription" state rename.
- "Your credits" → "Automation Credits" rename + helper tooltip.
- BASIC tier resurface as a free / cancel-subscription option
(deliberately deferred per current product direction).

## Test plan

### Backend (all green in CI)

- [x] `poetry run pytest
backend/api/features/subscription_routes_test.py` — 41 passed.
- [x] `poetry run pytest backend/data/credit_subscription_test.py`
covering: `handle_subscription_payment_success` (grants credits, skips
non-sub/zero/missing-customer/unknown-user/ENTERPRISE, idempotent on
retry), `get_active_subscription_period_end` (happy path, no-customer
short-circuit, Stripe error swallow), top-up Product ID flag both
branches.
- [x] Type-check (3.11/3.12/3.13) — green after explicit
`list[stripe.checkout.Session.CreateParamsLineItem]` typing on top-up
`line_items`.
- [x] Codecov patch — both backend + frontend green.

### Frontend (all green in CI)

- [x] `pnpm test:unit` — 2154/2154 pass, including 5 new PaywallGate
tests (beta-cohort skip, paid-cohort BASIC redirect, no-redirect for
PRO/MAX/BUSINESS, exempt-prefix matrix, loading-state guard) and updated
`formatCost`/dialog-copy assertions.
- [x] `pnpm types`, `pnpm format`, `pnpm lint` — clean.

### Live verification on `dev-builder.agpt.co` (5/5 pass — see PR
comments)

- [x] Login + credits page renders correctly with Pro + Max cards, BASIC
+ BUSINESS hidden, no paywall banner for active subscriber.
- [x] Downgrade dialog shows new copy with concrete date + "no charge
today" + credit-grant explanation.
- [x] PaywallGate does NOT redirect paying users (MAX tier with active
sub).
- [x] PaywallGate REDIRECTS BASIC user (DB-flipped via `kubectl exec`
for testing, restored after) → `/build` redirects to `/profile/credits`,
violet "Pick a plan to continue using AutoGPT" banner displayed.
- [x] Upgrade dialog (modify-in-place) shows the corrected proration
phrasing.
- [ ] Manual: real production-like test of `invoice.payment_succeeded`
granting credits — fires on next billing cycle (2026-05-21 for the dev
test user); not testable today without manipulating Stripe webhook.
2026-04-29 23:15:29 +07:00
Ubbe
c6d31f8252 feat(frontend): gate onboarding SubscriptionStep behind ENABLE_PLATFORM_PAYMENT (#12943)
### Why / What / How

**Why:** The onboarding `SubscriptionStep` (added in #12935) is
currently shown to every new user, but the platform payment system is
rolled out behind the `ENABLE_PLATFORM_PAYMENT` LaunchDarkly flag. We
need the onboarding plan-selection step to honor the same flag so users
in flag-off cohorts don't hit a payment surface that the rest of the
product won't support.

**What:** Conditionally render the `SubscriptionStep` based on
`ENABLE_PLATFORM_PAYMENT`. When the flag is off the wizard runs `Welcome
→ Role → PainPoints → Preparing` (3 user-interactive steps +
transition); when on, behavior is unchanged (`Welcome → Role →
PainPoints → Subscription → Preparing`).

**How:**
- `page.tsx` reads the flag, computes `totalSteps` (3 vs. 4) and
`preparingStep` (4 vs. 5), and only renders `SubscriptionStep` when the
flag is on.
- `useOnboardingPage.ts` threads the same `preparingStep` into the URL
`parseStep` clamp and into the "submit profile when entering Preparing"
effect, so both adapt to the flag state.
- The Zustand store is left unchanged — its hard `Math.min(5, …)` clamp
is unreachable in flag-off flow because PainPointsStep advances 3 → 4
(Preparing) and that's the terminal step.
- `playwright/utils/onboarding.ts`: with `NEXT_PUBLIC_PW_TEST=true`
LaunchDarkly returns `defaultFlags` (`ENABLE_PLATFORM_PAYMENT: false`),
so the helper now waits up to 2s for the Subscription header and only
clicks a plan CTA if the step is actually rendered.

### Changes 🏗️

- `autogpt_platform/frontend/src/app/(no-navbar)/onboarding/page.tsx` —
gate `SubscriptionStep` on `ENABLE_PLATFORM_PAYMENT`; derive
`totalSteps`/`preparingStep` from the flag.
-
`autogpt_platform/frontend/src/app/(no-navbar)/onboarding/useOnboardingPage.ts`
— make `parseStep` and the profile-submission effect respect the
flag-derived `preparingStep`.
- `autogpt_platform/frontend/src/playwright/utils/onboarding.ts` — make
the Subscription step optional in `completeOnboardingWizard` so E2E
works in both flag states.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
- [x] Existing onboarding unit tests pass (`pnpm test:unit` — 2447
passed, including `PainPointsStep`, `RoleStep`, `SubscriptionStep`,
store)
  - [x] `pnpm format`, `pnpm lint`, `pnpm types` clean
- [ ] Manual: with flag **off**, walk onboarding and confirm wizard goes
Welcome → Role → PainPoints → Preparing → /copilot, progress bar shows 3
steps
- [ ] Manual: with flag **on** (LD or
`NEXT_PUBLIC_FORCE_FLAG_ENABLE_PLATFORM_PAYMENT=true`), walk onboarding
and confirm SubscriptionStep is present at step 4, progress bar shows 4
steps
- [ ] Manual: with flag **off**, hit `/onboarding?step=5` directly and
confirm it clamps back to step 1 (no orphan Subscription state)
- [ ] Playwright: `completeOnboardingWizard` E2E flow continues to pass
under default `NEXT_PUBLIC_PW_TEST=true` (flag off path)

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
(no config changes — flag already exists in LaunchDarkly +
`defaultFlags`)
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (none needed)

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-29 23:12:04 +07:00
John Ababseh
28ae7ebac8 feat(onboarding): add subscription plan selection step (#12935)
## Summary

Adds a new **Subscription Step** (Step 4) to the onboarding wizard,
allowing users to choose a plan (Pro, Max, or Team) before reaching the
"Preparing" step.

## Changes

### New files
- **`steps/SubscriptionStep.tsx`** — Full subscription UI with:
  - Three plan cards (Pro $50/mo, Max $320/mo, Team — coming soon)
- Monthly / yearly billing toggle (yearly shows annual total with 20%
discount, plus monthly equivalent)
- Country selector (28 Stripe-supported countries) that opens upward as
a search modal
  - Localized pricing using live exchange rates
- **`steps/countries.ts`** — Currency data module with exchange rates,
`formatPrice()` helper, and zero-decimal currency handling (JPY, KRW,
HUF, CLP)

### Modified files
- **`store.ts`** — Extended `Step` type to `1 | 2 | 3 | 4 | 5`, added
`selectedPlan` and `selectedBilling` state/actions
- **`page.tsx`** — Wired `SubscriptionStep` as Step 4, moved
`PreparingStep` to Step 5, adjusted progress bar and dot indicators
- **`useOnboardingPage.ts`** — Updated `parseStep` range to 1–5, profile
submission now triggers at Step 5

## Design decisions
- Follows existing component patterns: uses `FadeIn`, `Text`, `Button`
atoms, `cn()` utility, Phosphor icons
- Country selector opens **upward** to avoid clipping below the viewport
- Plan selection advances to Step 5 immediately (Stripe integration is
TODO)
- Exchange rates are hardcoded for now — should be fetched from an API
in production

## TODO
- [ ] Integrate with Stripe checkout / backend subscription API
- [ ] Fetch live exchange rates instead of hardcoded values
- [ ] Add responsive layout for mobile viewports

---------

Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
Co-authored-by: Lluis Agusti <hi@llu.lu>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Ubbe <hi@ubbe.dev>
2026-04-29 21:26:18 +07:00
Krzysztof Czerwinski
e0f9146d54 feat(platform): add Web Push notifications via VAPID for background delivery (#12723)
### Why / What / How

**Why:** When a user kicks off an AutoPilot task and leaves the platform
(closes the tab, switches to another page, or minimizes the browser),
they have no way of knowing when it completes unless they come back and
check. This breaks the "set it and forget it" promise of automation.

**What:** Adds Web Push notifications using the standard Push API
(VAPID). Push notifications are delivered through free browser vendor
services (Google FCM, Apple APNs, Mozilla Push) to a service worker —
even when all AutoGPT tabs are closed, as long as the browser process is
running. The system is generic and extensible to all notification types,
with copilot session completion as the first integration.

**How:**
- **Backend:** A new `PushSubscription` Prisma model stores per-user
push subscriptions. When a `NotificationEvent` is published to the Redis
notification bus, the existing `notification_worker` in `ws_api.py`
fires a tracked background `send_push_for_user()` task. This uses
`pywebpush` to call the browser push services with VAPID authentication.
Includes per-user TTL-bounded debounce (5s), per-user subscription cap
(20), 410/404 auto-cleanup, periodic scheduler-driven cleanup of
high-failure rows, and route-level SSRF rejection of untrusted
endpoints (a minimal `pywebpush` sketch follows this list).
- **Frontend:** A `push-sw.js` service worker handles `push` events and
shows OS notifications via `self.registration.showNotification()`, with
click-to-navigate. A `PushNotificationProvider` mounted at the platform
layout registers the SW and subscription on all pages, posts the current
URL to the SW on every Next.js navigation (since Chrome's
`WindowClient.url` is stale for SPA routing), forwards the user's
notifications-toggle setting to the SW, and tears down on logout. The
copilot in-page notification path defers to the SW when a push
subscription is active so users don't get duplicate alerts.
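
A minimal sketch of the per-subscription send referenced in the backend bullet, assuming `pywebpush`'s `webpush()` call (the surrounding function is illustrative):

```python
from pywebpush import WebPushException, webpush

def send_one_push(subscription_info: dict, payload: str,
                  vapid_private_key: str, claim_email: str) -> bool:
    # Returns False when the endpoint is gone so the caller can delete the stored row.
    try:
        webpush(
            subscription_info=subscription_info,
            data=payload,
            vapid_private_key=vapid_private_key,
            vapid_claims={"sub": f"mailto:{claim_email}"},
        )
        return True
    except WebPushException as exc:
        status = getattr(exc.response, "status_code", None)
        if status in (404, 410):
            return False  # stale subscription: browser unsubscribed or endpoint expired
        raise
```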

### Behavior — when does an OS notification fire?

| Where the user is focused | Notifications toggle | OS notification? |
|---|---|---|
| Any `/copilot` page (any session, tab visible + browser focused) | on | suppressed — sidebar green check + title badge handle it |
| `/library` (or any non-`/copilot` route) | on | **fires** |
| `/copilot` but tab hidden (Cmd-Tab away, minimized, different tab) | on | **fires** |
| All AutoGPT tabs closed (browser process still running) | on | **fires** |
| Any state | off | suppressed |
| Anywhere | permission not granted / no push subscription | falls back to in-page `Notification()` if user is away on `/copilot`; nothing otherwise |

Click any OS notification → focuses an existing tab and navigates it to
`/copilot?sessionId=<id>`, or opens a new window if no AutoGPT tab is
open.

### Test plan

#### Setup
- [ ] Generate VAPID keys via the snippet in `backend/.env.default` and
set `VAPID_PRIVATE_KEY`, `VAPID_PUBLIC_KEY`, `VAPID_CLAIM_EMAIL` in
`backend/.env`
- [ ] Leave `NEXT_PUBLIC_VAPID_PUBLIC_KEY` unset on the frontend (single
source of truth via `/api/push/vapid-key`)
- [ ] Start backend + frontend, grant notification permission on the
copilot page
- [ ] Verify `push-sw.js` is "activated and is running" in DevTools →
Application → Service Workers
- [ ] Verify `POST /api/push/subscribe` created exactly one DB row in
`PushSubscription` for your user

#### Notification show / suppress matrix
- [ ] Trigger completion **on `/copilot` viewing the same session**, tab
visible + focused → no OS notification (sidebar green check appears)
- [ ] Trigger completion **on `/copilot` viewing a different session**,
tab visible + focused → no OS notification (still considered "in the
feature")
- [ ] Trigger completion **on `/library`**, tab visible + focused → OS
notification fires
- [ ] Trigger completion **on `/copilot`** but with the tab hidden
(Cmd-Tab to another app) → OS notification fires
- [ ] Trigger completion with all AutoGPT tabs closed → OS notification
fires (browser must still be running)
- [ ] Toggle notifications **off** in the copilot UI → trigger
completion → no OS notification
- [ ] Toggle notifications **back on** → trigger completion → OS
notification fires

#### Click behavior
- [ ] OS notification → click → focuses an existing AutoGPT tab and
navigates to `/copilot?sessionId=<id>`
- [ ] OS notification with no AutoGPT tab open → click → opens a new tab
on `/copilot?sessionId=<id>`

#### Lifecycle
- [ ] Logout → DB row removed, browser unsubscribed; no further OS
notifications until login + re-subscribe
- [ ] Stale subscription (e.g. unsubscribed externally) → backend gets
410 from FCM → row auto-deleted; second push attempts no longer fan out
to it

### Changes 🏗️

**Backend — New files:**
- `backend/data/push_subscription.py` — CRUD for push subscriptions:
`upsert` (with `MAX_SUBSCRIPTIONS_PER_USER` cap), `find_many`, `delete`,
`increment_fail_count`, `cleanup_failed_subscriptions`,
`validate_push_endpoint` (HTTPS + push-service hostname allowlist for
SSRF prevention)
- `backend/data/push_sender.py` — Fire-and-forget push delivery with
`cachetools.TTLCache`-bounded debounce, defense-in-depth re-validation
at send time, 410/404 auto-cleanup with regex-based status extraction
(covers pywebpush versions where `e.response` is unset)
- `backend/api/features/push/routes.py` — 3 endpoints: `GET
/api/push/vapid-key`, `POST /api/push/subscribe`, `POST
/api/push/unsubscribe` (all with `requires_user` auth and 400 on invalid
endpoints)
- `backend/api/features/push/model.py` — Pydantic models with
`min_length`/`max_length` constraints on endpoint and crypto keys

**Backend — Modified files:**
- `schema.prisma` — Added `PushSubscription` model + `User` relation
- `pyproject.toml` — Added `pywebpush ^2.3` dependency
- `backend/util/settings.py` — VAPID key fields on `Secrets`;
`push_subscription_cleanup_interval_hours` config
- `backend/api/rest_api.py` — Registered push router at `/api/push`
- `backend/api/ws_api.py` — Notification worker now fires
`send_push_for_user()` as a tracked background task (strong-ref set +
done callback so asyncio doesn't GC it mid-run)
- `backend/data/db_manager.py` — Exposed push subscription RPC methods
on the DB manager async client
- `backend/executor/scheduler.py` — Periodic
`cleanup_failed_push_subscriptions` job (default 24h)
- `backend/.env.default` — VAPID env vars with key generation snippet

**Frontend — New files:**
- `public/push-sw.js` — Service worker: routes pushes via
`NOTIFICATION_MAP`, suppresses when user is on `/copilot`, accepts
`CLIENT_URL` and `NOTIFICATIONS_ENABLED` postMessages so SW logic stays
in sync with SPA navigation and the toggle, click handler with focus →
navigate → openWindow fallback, `pushsubscriptionchange` re-subscribe
with `credentials: include`
- `src/services/push-notifications/registration.ts`, `api.ts`,
`helpers.ts` — SW registration / Push API subscription / backend API
helpers
- `src/services/push-notifications/usePushNotifications.ts` — Hook that
auto-subscribes on login and tears down on logout
- `src/services/push-notifications/useReportClientUrl.ts` — Posts
current pathname+search to SW on every Next.js route change (works
around stale `WindowClient.url`)
- `src/services/push-notifications/useReportNotificationsEnabled.ts` —
Forwards the user's notifications toggle to the SW
- `src/services/push-notifications/PushNotificationProvider.tsx` —
Mounts all three hooks at the platform layout level

**Frontend — Modified files:**
- `src/app/(platform)/layout.tsx` — Mounted `<PushNotificationProvider
/>`
- `src/app/(platform)/copilot/useCopilotNotifications.ts` — Skips
in-page `Notification()` when a SW push subscription is active (avoids
duplicate alerts)
- `src/services/storage/local-storage.ts` — Added
`PUSH_SUBSCRIPTION_REGISTERED` key
- `frontend/.env.default` — Optional `NEXT_PUBLIC_VAPID_PUBLIC_KEY`
(left unset by default to keep `/api/push/vapid-key` as the single
source of truth)

**Configuration changes:**
- New env vars: `VAPID_PRIVATE_KEY`, `VAPID_PUBLIC_KEY`,
`VAPID_CLAIM_EMAIL` (backend); optional `NEXT_PUBLIC_VAPID_PUBLIC_KEY`
(frontend)
- New `push_subscription_cleanup_interval_hours` setting (default 24,
range 1–168)
- New DB migration: `PushSubscription` table
(`20260420120000_add_push_subscription`)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] All blockers and should-fixes from the autogpt-pr-reviewer review
have been addressed (see PR thread)
- [x] All inline review threads resolved (49 threads addressed)

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-29 13:28:21 +00:00
Bently
c3c2737c42 feat(platform): copilot-bot (Python / discord.py) (#12618)
## Why

AutoPilot needs to reach users on chat platforms — Discord first,
Telegram / Slack / Teams / WhatsApp next. This PR adds the bot service
that bridges those platforms to the AutoPilot backend via the
`PlatformLinkingManager` AppService introduced in #12615.

Two independent linking flows (see #12615 for the rationale):

- **SERVER links**: first person to run `/setup` in a guild claims it.
Anyone in the server can mention the bot; all usage bills to the owner.
- **USER links**: an individual DMs the bot, links their personal
account, DMs bill to their own AutoPilot. A server owner still has to
link their DMs separately.

## What

A Python service using `discord.py`, living alongside the rest of the
backend. Connects to the platform_linking service via cluster-internal
RPC (no shared bearer token) and subscribes to copilot streams directly
on Redis (no HTTP SSE proxy).

Originally prototyped in Node.js with Vercel's Chat SDK — rewritten in
Python after team feedback: the rest of the platform is Python,
`discord.py` was already a dependency, and the Chat SDK's streaming-UI
abstractions don't apply to a headless chat bot.

### Deployment

- **Shares the existing backend Docker image** — no separate Dockerfile,
no separate Artifact Registry. A `copilot-bot` poetry script entry lets
the same image run with `command: ["copilot-bot"]` in the Helm chart.
- **Auto-starts with `poetry run app`** when
`AUTOPILOT_BOT_DISCORD_TOKEN` is set, so the full local dev stack
includes the bot without extra setup.
- **Runs standalone** via `poetry run copilot-bot` for the production
pod.

Infra PR:
[AutoGPT_cloud_infrastructure#310](https://github.com/Significant-Gravitas/AutoGPT_cloud_infrastructure/pull/310).

### File layout

```
backend/copilot/bot/
├── app.py              # CoPilotChatBridge(AppService) + adapter factory + outbound @expose
├── config.py           # Shared (platform-agnostic) config
├── handler.py          # Core logic: routing, linking, batched streaming
├── platform_api.py     # Thin facade over PlatformLinkingManagerClient + stream_registry
├── platform_api_test.py
├── text.py             # split_at_boundary + format_batch
├── threads.py          # Redis-backed thread subscription tracking
├── README.md
└── adapters/
    ├── base.py         # PlatformAdapter ABC + MessageContext
    └── discord/
        ├── adapter.py  # Gateway connection, events, thread creation, buttons
        ├── commands.py # /setup, /help, /unlink
        └── config.py   # Discord token + message limits
```

**Locality rule:** anything platform-specific lives under
`adapters/<platform>/`. `app.py` is the only file that names specific
platforms — it's the factory that picks adapters based on which tokens
are set. Adding Telegram later = drop in `adapters/telegram/` with the
same shape.

### `CoPilotChatBridge` — now an `AppService`

Previously `AppProcess`. Now inherits `AppService`, runs its RPC server
on `Config.copilot_chat_bridge_port=8010`, and exposes two scaffolding
`@expose` methods for the backend→chat-platform direction:

- `send_message_to_channel(platform, channel_id, content)` — stub
- `send_dm(platform, platform_user_id, content)` — stub

Both currently raise `NotImplementedError` — they unlock the
architecture for future features (scheduled agent outputs piped to
Discord, etc.) without another structural change. A matching
`CoPilotChatBridgeClient` + `get_copilot_chat_bridge_client()` factory
lets other services call the bot by the same `AppServiceClient` pattern
used for `NotificationManager` and `PlatformLinkingManager`.

### Bot behaviour

- `/setup` — server only, ephemeral, returns a "Link Server" button.
Rejects DM invocations up front.
- `/help` — ephemeral usage info.
- `/unlink` — ephemeral, opens a "Settings" button pointing at
`AUTOGPT_FRONTEND_URL/profile/settings` (real unlinking needs JWT auth).
- **Thread per conversation**: @mentioning the bot in a channel creates
a thread and routes the reply there. Subsequent messages in that thread
don't need another @mention — thread subscriptions are tracked in Redis
with a 7-day TTL.
- **Batched follow-ups**: messages arriving mid-stream append to a
per-thread pending list; drained as a single follow-up turn when the
current stream ends.
- **Persistent typing indicator**: 8-second re-fire loop.
- **Per-user identity prefix**: every forwarded message tagged `[Message
sent by {name} (Discord user ID: ...)]`.
- **Platform-aware chunking**: long responses split at paragraph → line
→ sentence → word boundaries (1900 chars for Discord); see the sketch after this list.
- **Link buttons** for DM link prompts and `/setup` / `/unlink`
responses.
- **Duplicate message guard**: on `DuplicateChatMessageError` the bot
stays quiet — no double response.
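
An illustrative sketch of that boundary-fallback split (not the actual `text.py` implementation):

```python
def split_at_boundary(text: str, limit: int = 1900) -> list[str]:
    # Fall back through paragraph -> line -> sentence -> word boundaries; hard
    # split only when no natural boundary exists inside the limit.
    chunks: list[str] = []
    rest = text
    while len(rest) > limit:
        window = rest[:limit]
        cut = -1
        for sep in ("\n\n", "\n", ". ", " "):
            found = window.rfind(sep)
            if found > 0:
                cut = found + len(sep)
                break
        if cut <= 0:
            cut = limit
        chunks.append(rest[:cut].rstrip())
        rest = rest[cut:]
    chunks.append(rest)
    return chunks
```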

### Env vars

| Variable | Purpose |
|----------|---------|
| `AUTOPILOT_BOT_DISCORD_TOKEN` | Discord bot token — enables the Discord adapter |
| `AUTOGPT_FRONTEND_URL` | Frontend base URL for link confirmation pages |
| `REDIS_HOST` / `REDIS_PORT` | Shared with backend — session + thread-subscription state + direct copilot stream subscription |
| `PLATFORMLINKINGMANAGER_HOST` | Cluster DNS name of the `PlatformLinkingManager` service (RPC target) |

Gone vs. the previous REST design: `AUTOGPT_API_URL`,
`PLATFORM_BOT_API_KEY`, `SSE_IDLE_TIMEOUT`.

## How

- **Adapter pattern**: `PlatformAdapter` ABC defines `start`, `stop`,
`send_message`, `send_link`, `start_typing`, `create_thread`,
`max_message_length`, `chunk_flush_at`, etc. Each platform implements
the interface; the shared `MessageHandler` calls through it.
- **Control plane over RPC**: `PlatformAPI` (~180 lines) is a thin
facade over `PlatformLinkingManagerClient` — `resolve_server`,
`resolve_user`, `create_link_token`, `create_user_link_token`,
`stream_chat`. The bot never constructs HTTP requests or handles an API
key.
- **Streaming over Redis Streams**: `stream_chat` calls
`start_chat_turn` (backend `@expose`), receives a
`ChatTurnHandle(session_id, turn_id, user_id, subscribe_from="0-0")`,
then subscribes directly via
`stream_registry.subscribe_to_session(...)`. Yields text from
`StreamTextDelta`, terminates on `StreamFinish`, surfaces
`StreamError.errorText` to the user. No SSE parsing, no X-Session-Id
header dance.
- **Error model**: backend domain exceptions (`NotFoundError`,
`LinkAlreadyExistsError`, `DuplicateChatMessageError`) cross the RPC
boundary cleanly (all `ValueError`-based, registered in
`backend.util.exceptions`). The bot catches them by type instead of
inspecting HTTP status codes.
- **Cooperative batching**: `TargetState.processing` flag + per-target
`pending` list. Messages arriving while `processing=True` append; the
running stream's finally block loops to drain the list before releasing; see the sketch after this list.
- **Typing helper for `endpoint_to_async`**: added an `@overload` so
`async def` `@expose` methods on the server type-check correctly on the
client side (the scheduler pattern avoids this by using sync `@expose`,
but the new managers are async).
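
A rough sketch of that batching, with the drain in the `finally` block as described (names beyond `TargetState` are assumptions):

```python
from dataclasses import dataclass, field

@dataclass
class TargetState:
    processing: bool = False
    pending: list[str] = field(default_factory=list)

async def handle_incoming(state: TargetState, text: str, run_turn) -> None:
    if state.processing:
        state.pending.append(text)  # a stream is already running for this target: queue it
        return
    state.processing = True
    try:
        await run_turn(text)
    finally:
        # Drain whatever arrived mid-stream as batched follow-up turns, then release.
        while state.pending:
            batch, state.pending = state.pending, []
            await run_turn("\n\n".join(batch))
        state.processing = False
```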

## Tests

- `backend/copilot/bot/platform_api_test.py` — new. Covers resolve
(server + user), create link tokens (success + `LinkAlreadyExistsError`
propagation), stream chat (yields deltas, terminates on `StreamFinish`,
surfaces `StreamError`, propagates `DuplicateChatMessageError` and
`NotFoundError`, handles `subscribe_to_session` returning `None`).
- `poetry run pyright backend/copilot/bot/` — clean.
- `poetry run ruff check backend/copilot/bot/` — clean.
- `poetry run copilot-bot` starts and connects to Discord Gateway, syncs
slash commands.
- `/setup` in a guild → confirm on frontend → mention bot → AutoPilot
streams back in a created thread.
- Thread follow-ups work without re-mentioning.
- Spamming messages mid-stream produces one batched follow-up.
- Long responses chunk at natural boundaries.
- DM to unlinked user → "Link Account" button → confirm → DMs stream as
that user's AutoPilot.

## Stack

- Backend API: #12615 — merge first
- Frontend link page: #12624
- Infra:
[AutoGPT_cloud_infrastructure#310](https://github.com/Significant-Gravitas/AutoGPT_cloud_infrastructure/pull/310)

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: CodeRabbit <noreply@coderabbit.ai>
2026-04-29 08:12:15 +00:00
Abhimanyu Yadav
37f247c795 feat(frontend): creator dashboard page for settings v2 (SECRT-2281) (#12934)
### Why / What / How

**Why:** The creator dashboard route under settings v2 currently shows a
"Coming soon" placeholder. SECRT-2281 fills it in so creators can manage
their store submissions from one place.

**What:** Implements the full creator dashboard at
`/settings/creator-dashboard` — stats overview, desktop submissions
table, mobile submissions list, filtering/sorting, selection bar, edit
modal, and empty/loading/error states.

**How:** Page logic lives in `useCreatorDashboardPage.ts` (data fetch,
filter state, modal state, CRUD callbacks); pure transforms in
`helpers.ts`; UI broken into colocated `components/*` (one folder per
component, each ~200–400 lines). Reuses generated API hooks,
`ErrorCard`, and `EditAgentModal` from the design system. Mobile/desktop
split via Tailwind `md:` breakpoints rather than runtime detection.

### Changes 🏗️

- Replace placeholder `page.tsx` with the real dashboard, wired to
`useCreatorDashboardPage`
- Add `useCreatorDashboardPage.ts` (page-level state + handlers) and
`helpers.ts` (filter/sort/stat utilities)
- Add components: `DashboardHeader`, `DashboardSkeleton`, `EmptyState`,
`StatsOverview`, `SubmissionsList` (+ `columns/*`,
`useSubmissionSelection`), `SubmissionItem` (+ `useSubmissionItem`),
`SubmissionSelectionBar`, `MobileSubmissionsList` (+
`MobileSelectionBar`), `MobileSubmissionItem`, `ColumnFilter`
- Set document title to "Creator dashboard – AutoGPT Platform"
- Surface fetch errors via `ErrorCard` with retry; show
`DashboardSkeleton` while loading; show `EmptyState` when there are no
submissions

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  - [ ] Loading state renders skeleton until submissions load
  - [ ] Empty state renders when the creator has no submissions
  - [ ] Error state renders `ErrorCard` and retry refetches the list
  - [ ] Stats overview reflects approved/pending/rejected/draft counts
- [ ] Desktop list: sort/filter by status and other columns updates the
visible rows
- [ ] Desktop list: selection bar appears on row select and clears on
reset
- [ ] Mobile list (≤ md breakpoint): renders mobile items + selection
bar
- [ ] Edit modal opens for a submission, saves, and refreshes the list
on success
  - [ ] Delete action removes the submission and updates stats
  - [ ] View action navigates to the submission's public detail
  - [ ] Submit/publish entry point opens the publish modal
  - [ ] Document title shows "Creator dashboard – AutoGPT Platform"
2026-04-28 16:51:52 +00:00
Abhimanyu Yadav
ae4a421620 fix(platform): small fixes and stagger animations on settings pages (#12937)
## Why

The new Settings v2 surfaces (preferences, api-keys, integrations,
profile) shipped with a few rough edges spotted in self-review:

- **Timezone saves silently dropped on refresh.** Backend `GET
/auth/user/timezone` resolved the user via
`get_or_create_user(user_data)` (a 5-min in-process cache keyed by the
JWT-payload dict). `update_user_timezone` only invalidates
`get_user_by_id`'s cache, so the GET kept returning the pre-save tz
until TTL expired — looked exactly like "save did nothing."
- **Confusing "Looks like you're in X" CTA on the Time zone card** that
did nothing in the common case (server tz already matched the browser
tz, so clicking it produced no dirty state).
- **Save was disabled out of the gate when server tz was `"not-set"`** —
the hook substituted the browser tz into both `formState` and
`savedState`, so they were equal and `dirty` was false.
- **Lists felt static.** No motion when API keys / integrations mount,
and the loading skeletons popped in all at once instead of handing off
cleanly to the loaded rows.
- **Profile bio textarea** corner clipped against the rounded-3xl border
and the scrollbar overflowed the rounded container.

## What

### Bug fixes
- `GET /auth/user/timezone` now reads via `get_user_by_id(user_id)` —
the same cache `update_user_timezone` already invalidates — so a save
followed by refresh shows the new tz immediately.
- `usePreferencesPage` now treats the raw server tz (`"not-set"`
included) as the saved baseline, while `formState` uses the browser tz
only as a *display* fallback. Effect: when the user has never set a tz,
Save is enabled on first paint and a single click persists the detected
tz.
- Frontend save flow swapped `setQueryData` for `invalidateQueries`,
mirroring the older `/profile/(user)/settings` page so we always re-read
the persisted value.
- Removed the auto-detect "Looks like you're in X" button + its dead
helpers.

### Animations (per Emil Kowalski's guidelines)
Added orchestrated stagger animations that run on both the loaded list
**and** its skeleton, so the loading→loaded handoff is continuous
in-position:

- **API keys list + skeleton:** 280ms ease-out `cubic-bezier(0.16, 1,
0.3, 1)`, 40ms stagger, opacity + 6px translate.
- **Integrations list + skeleton:** 300ms ease-out, 80ms stagger,
opacity + 16px translate (rows are bigger / fewer).
- Both honor `prefers-reduced-motion` via `useReducedMotion`; only
`opacity` and `transform` are animated.

### Misc polish
- Profile bio textarea: `!rounded-tr-md` so the top-right corner doesn't
fight the surrounding `rounded-3xl`, plus a thin styled scrollbar
(`scrollbar-thin scrollbar-thumb-zinc-200
hover:scrollbar-thumb-zinc-300`) that lives inside the rounded container
instead of breaking out of it.

## How

| File | Change |
| --- | --- |
| `backend/api/features/v1.py` | `get_user_timezone_route` now uses `get_user_by_id` + `Security(get_user_id)` instead of `get_or_create_user(user_data)` |
| `frontend/.../preferences/usePreferencesPage.ts` | Split init into `initialFormState` (browser-fallback display) vs `initialSavedState` (raw server value); swap optimistic `setQueryData` for `invalidateQueries` after tz mutate |
| `frontend/.../preferences/components/TimezoneCard/TimezoneCard.tsx` | Drop `initialValue` prop, remove auto-detect button + unused imports |
| `frontend/.../preferences/page.tsx` | Drop `savedState`/`initialValue` wiring |
| `frontend/.../api-keys/components/APIKeyList/APIKeyList.tsx` | Wrap rows in container `motion.div` with `staggerChildren`; per-row `motion.div` with opacity + y variants |
| `frontend/.../api-keys/components/APIKeyListSkeleton/APIKeyListSkeleton.tsx` | Same stagger config so loading→loaded matches |
| `frontend/.../integrations/components/IntegrationsList/IntegrationsList.tsx` + `IntegrationsListSkeleton.tsx` | Same pattern for the providers list |
| `frontend/.../profile/components/ProfileForm/ProfileForm.tsx` | Tailwind classes only — `!rounded-tr-md` + `scrollbar-thin scrollbar-thumb-zinc-200 hover:scrollbar-thumb-zinc-300` |

## Test plan

- [ ] On `/settings/preferences`: pick a different tz → Save →
hard-refresh → new tz still selected.
- [ ] First-time user (server tz = `not-set`): land on page, Save button
should already be enabled; click Save → toast confirms; refresh → tz
persists.
- [ ] No "Looks like you're in X" button visible.
- [ ] On `/settings/api-keys`: rows fade/slide in staggered on first
mount; loading skeleton uses the same motion.
- [ ] On `/settings/integrations`: provider groups fade/slide in
staggered; skeleton matches.
- [ ] OS "Reduce motion" enabled → no transforms, content appears
instantly on all four surfaces.
- [ ] On `/settings/profile`: bio textarea top-right corner is no longer
hard-cornered against the card; scrollbar fits inside the rounded shape.
- [ ] Existing unit tests still pass: `pnpm test:unit
src/app/\(platform\)/settings/preferences` and `.../api-keys`.
2026-04-28 16:51:40 +00:00
Zamil Majdy
2879528308 feat(backend): Redis Cluster client support (#12900)
## Why

Pre-launch scaling. Redis is currently a single-master pod — a real
SPOF, and not scalable horizontally. To move it to a sharded Redis
Cluster (via KubeBlocks in GKE), the backend has to speak the cluster
protocol.

Keeping both "standalone" and "cluster" code paths would leave local dev
not reflecting prod, so the backend goes **cluster-only**.

## What

- `backend.data.redis_client` now always constructs `RedisCluster`
(sync) / `redis.asyncio.cluster.RedisCluster` (async). Type aliases
`RedisClient` / `AsyncRedisClient` point at the cluster classes.
- `RedisCluster` uses the existing `REDIS_HOST` / `REDIS_PORT` as a
startup node and auto-discovers peers via `CLUSTER SLOTS`.
- Classic Redis pub/sub is broadcast cluster-wide and redis-py's async
`RedisCluster` has no `.pubsub()`; dedicated `get_redis_pubsub[_async]`
helpers return plain `(Async)Redis` clients to the seed node. All
pub/sub callers (`event_bus`, `notification_bus`,
`copilot.pending_messages`) route through these helpers.
- `rate_limit.py` MULTI/EXEC pipelines are split per-counter — daily and
weekly counters hash to different slots, which `RedisCluster` correctly
rejects as `CrossSlotTransactionError`. Per-counter `INCRBY + EXPIRE`
atomicity is preserved; the counters are logically independent budgets (see the sketch after this list).
- `util/cache.py` shared-cache client is also `RedisCluster` now.
- Pre-existing mock-based unit tests updated; new `redis_client_test.py`
covers the swap.
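
A sketch of the per-counter split (key names are illustrative; assumes a redis-py build with cluster MULTI/EXEC support, which the smoke test below exercises):

```python
import redis

def bump(r: "redis.RedisCluster", key: str, ttl_seconds: int) -> int:
    # One transaction per counter: INCRBY + EXPIRE stay atomic on a single slot,
    # instead of one MULTI/EXEC spanning keys that hash to different slots.
    pipe = r.pipeline(transaction=True)
    pipe.incrby(key, 1)
    pipe.expire(key, ttl_seconds)
    count, _ = pipe.execute()
    return count

def record_usage(r: "redis.RedisCluster", user_id: str) -> tuple[int, int]:
    daily = bump(r, f"rate:{user_id}:daily", 24 * 3600)
    weekly = bump(r, f"rate:{user_id}:weekly", 7 * 24 * 3600)
    return daily, weekly
```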

## Local dev

`docker-compose.platform.yml` now runs **2-master Redis Cluster**
(`redis` + `redis-2`, 16384 slots split 0-8191 / 8192-16383). A one-shot
`redis-init` sidecar bootstraps it on first boot via raw `CLUSTER MEET`
+ `CLUSTER ADDSLOTSRANGE` (bundled `redis-cli --cluster create` enforces
a 3-node minimum).

This deliberately catches cross-slot bugs on a laptop rather than in
prod:

```
>>> ALL SMOKE TESTS PASS <<<
[sync] class: RedisCluster
[sync] 20 keys across slots: OK
[sync] colocated MULTI/EXEC: OK [5, 12, 1]
[sync] cross-slot MULTI/EXEC rejected as expected: CrossSlotTransactionError
[sync] EVAL single-key: OK
[sync] pub/sub (classic, broadcast): OK
[async] class: RedisCluster
[async] 15 keys across slots: OK
[async] colocated pipeline: OK
[async] pub/sub: OK
```

`rest_server` `/health` → 200, both shards have connected clients + keys
distributed 19/19 under the smoke run. `executor` boots + connects to
RabbitMQ + Redis cleanly.

For a 3-shard override (6 pods, with replicas) when you want to test
real KubeBlocks topology:
```
docker compose -f docker-compose.yml -f docker-compose.redis-cluster.yml up -d
```

## Deploy order (companion infra PR:
[cloud_infrastructure#312](https://github.com/Significant-Gravitas/AutoGPT_cloud_infrastructure/pull/312))

The existing `helm/redis` chart is updated in that PR to run as a
1-shard cluster (backwards-compatible toggle, default on). That rollout
must land before this PR's image goes live so the backend's
`RedisCluster` client has something to discover.

Sequence:
1. Infra: `helm upgrade redis` (1-shard cluster-enabled)
2. Infra: `helm upgrade rabbit-mq` (3-node cluster)
3. Backend: merge + deploy this PR
4. Follow-up: swap to KubeBlocks `redis-cluster` chart (3-shard sharded,
already staged in infra PR)

## Caveats / follow-ups

- Classic pub/sub via seed node means every node in the cluster sees
every message (broadcast). Fine at current volume; if it becomes hot,
migrate to `SPUBLISH`/`SSUBSCRIBE` (Redis 7+ sharded pub/sub).
- Per-user rate-limit counters (daily vs weekly) lost cross-counter
transactionality, but per-counter atomicity is preserved — the two
counters are independent budgets so no correctness regression.
- Local 2-master cluster crashes lose the cluster state; `redis-init`
idempotently rebootstraps.

## Checklist

- [x] Lint + format pass (`poetry run format` + `poetry run lint`)
- [x] Unit tests pass — `redis_client_test`, `redis_helpers_test`,
`event_bus_test`, `pending_messages_test`, `rate_limit_test`,
`cluster_lock_test`
- [x] Live smoke against 2-master cluster — sync + async; MULTI/EXEC;
EVAL; pub/sub; cross-slot rejection
- [x] Full stack smoke — `rest_server` /health, `executor` boot, keys
distributed across both shards
- [ ] Dev deploy (pending infra PR merge + manual validation)
2026-04-28 22:21:23 +07:00
Ubbe
1974ec6260 fix(frontend/copilot): fix streaming reconnect races, hydration ordering, and reasoning split (#12813)
## Summary

Improves Copilot/AutoPilot streaming reliability across frontend and
backend. The diff now covers the original streaming investigation issues
plus follow-up CI and review fixes from the latest merge with `dev`.

Addresses [SECRT-2240](https://linear.app/autogpt/issue/SECRT-2240),
[SECRT-2241](https://linear.app/autogpt/issue/SECRT-2241), and
[SECRT-2242](https://linear.app/autogpt/issue/SECRT-2242).

## Changes

- Fixes reasoning vs response rendering so action tools such as
`run_block` and `run_agent` do not cause assistant response text to be
hidden inside the collapsed reasoning section.
- Reworks Copilot session lifecycle handling: active-stream hydration,
resume ordering, reconnect timeout recovery, wake resync, session
deletion, title polling, stop handling, and session-switch stale
callback guards.
- Adds a per-session Copilot stream store/registry and transport helpers
to prevent duplicate resumes, duplicate sends, and cross-session
contamination during reconnect or reload flows.
- Adds pending follow-up message support and backend pending-message
safeguards, including sanitization of queued user content and
requeue-on-persist-failure behavior (sketched after this list).
- Improves backend stream and executor robustness: active stream
registry checks, bounded cancellation drain with sync fail-close
fallback, Redis helper coverage, and updated SDK response adapter
expectations for post-tool status events.
- Adds and polishes usage-limit UI, including reset gate behavior,
backdrop blending behind the usage-limit card, and usage panel/card
coverage.
- Fixes a chat input Enter-submit race where Playwright and fast users
could fill the textarea and press Enter before React had re-enabled the
submit button, causing the visible message not to send.
- Refactors the Copilot page into smaller hooks/components and adds
focused tests around stream recovery, hydration, pending queueing,
rate-limit gates, and message rendering.
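
A hedged sketch of the requeue-on-persist-failure safeguard; every helper
below is a placeholder passed in as a callable, not the actual
pending-message API:

```python
# Placeholder callables stand in for the real pending-message helpers; only
# the requeue-on-persist-failure shape is taken from this PR description.
from typing import Awaitable, Callable, Optional


async def drain_pending(
    session_id: str,
    pop_pending: Callable[[str], Awaitable[Optional[str]]],
    sanitize: Callable[[str], str],
    persist: Callable[[str, str], Awaitable[None]],
    requeue: Callable[[str, str], Awaitable[None]],
) -> None:
    while (queued := await pop_pending(session_id)) is not None:
        sanitized = sanitize(queued)
        try:
            await persist(session_id, sanitized)
        except Exception:
            # Put the original message back so a transient persist failure
            # doesn't silently drop the user's queued follow-up.
            await requeue(session_id, queued)
            raise
```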

## Test plan

- [x] `poetry run format`
- [x] `poetry run pytest backend/copilot/sdk/response_adapter_test.py
backend/copilot/executor/processor_test.py`
- [x] `pnpm prettier --write` on touched frontend files
- [x] `pnpm vitest run
src/app/(platform)/copilot/components/ChatInput/__tests__/useChatInput.test.ts`
- [x] `pnpm types`
- [x] `pnpm lint` (passes with existing unrelated `next/no-img-element`
warnings)
- [ ] Full GitHub CI after latest push

## Review notes

- The Sentry review thread about unbounded cancellation cleanup is
addressed in `375ec9d5f`: cancellation now waits for normal async
cleanup but exits after `_CANCEL_GRACE_SECONDS` and falls through to the
sync fail-close path (a minimal sketch follows these notes).
- The previous backend CI failures were stale test expectations around
the new `StreamStatus("Analyzing result…")` event after tool output;
tests now assert that event explicitly.
- The previous full-stack E2E failure was the Copilot input Enter race;
the input now submits from the live form value instead of depending on a
possibly stale disabled button state.
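
Minimal asyncio sketch of the bounded drain. Only the `_CANCEL_GRACE_SECONDS`
name comes from the note above; the callables and the 5-second value are
placeholders:

```python
# Bounded cancellation drain: wait for the normal async cleanup for a grace
# period, then fall through to the synchronous fail-close path.
import asyncio
import logging
from typing import Awaitable, Callable

logger = logging.getLogger(__name__)

_CANCEL_GRACE_SECONDS = 5.0  # illustrative value, not the real constant


async def drain_on_cancel(
    run_async_cleanup: Callable[[], Awaitable[None]],
    fail_close_sync: Callable[[], None],
) -> None:
    try:
        # Give the normal async cleanup a bounded window to finish.
        await asyncio.wait_for(run_async_cleanup(), timeout=_CANCEL_GRACE_SECONDS)
    except asyncio.TimeoutError:
        logger.warning("cancellation cleanup exceeded grace period; failing closed")
        # Ensure the turn is never left half-open when cleanup hangs.
        fail_close_sync()
```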

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2026-04-28 15:40:37 +07:00
Zamil Majdy
932ecd3a07 fix(backend/copilot): normalize model name based on actual transport, not config shape (#12932)
## Summary

When `CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true` is paired with a populated
`CHAT_BASE_URL=https://openrouter.ai/api/v1` (e.g. left over from an
earlier OpenRouter setup), the SDK was passing the OpenRouter slug
`anthropic/claude-opus-4.7` straight through to the Claude Code CLI
subprocess. The CLI uses OAuth and ignores
`CHAT_BASE_URL`/`CHAT_API_KEY`, so it rejects the slug:

> There's an issue with the selected model (anthropic/claude-opus-4.7).
It may not exist or you may not have access to it.

The bug was in `_normalize_model_name`, which gated on
`config.openrouter_active` (config-shape check) instead of the transport
the CLI actually uses for the turn.

## Changes

- Add `ChatConfig.effective_transport` property returning `subscription`
| `openrouter` | `direct_anthropic`, detected in that priority order.
Subscription wins over OpenRouter config because the CLI subprocess uses
OAuth and ignores the OpenRouter env vars (see `build_sdk_env` mode 1). A
hedged sketch of this gate follows the list.
- Switch `_normalize_model_name` to gate on `effective_transport`.
Subscription and direct-Anthropic transports both produce the
CLI-friendly hyphenated form (`claude-opus-4-7`) and reject
non-Anthropic vendors loudly.
- `_resolve_sdk_model_for_request` already routes any LD-served override
through `_normalize_model_name`, so a per-user advanced-tier override
under subscription now correctly becomes `claude-opus-4-7` instead of
the OpenRouter slug. The standard-tier "no LD override → return None"
behaviour is preserved.
- Update two existing service tests to assert the corrected behaviour
(Kimi LD override under subscription falls back to tier default
normalised for the CLI; Opus advanced override returns hyphenated form).
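
Sketch of the transport gate described above. The config field names and the
exact normalization rules are reconstructed from this description, not copied
from `config.py` / `service_helpers.py`:

```python
# Hedged sketch: subscription wins over OpenRouter config, OpenRouter keeps the
# vendor-prefixed slug, and the CLI transports get the hyphenated form.
from dataclasses import dataclass


@dataclass
class ChatConfigSketch:
    use_claude_code_subscription: bool = False
    chat_base_url: str = ""
    chat_api_key: str = ""

    @property
    def effective_transport(self) -> str:
        # The CLI subprocess uses OAuth and ignores the OpenRouter env vars,
        # so a leftover CHAT_BASE_URL must not flip the decision.
        if self.use_claude_code_subscription:
            return "subscription"
        if "openrouter" in self.chat_base_url and self.chat_api_key:
            return "openrouter"
        return "direct_anthropic"


def normalize_model_name(model: str, config: ChatConfigSketch) -> str:
    if config.effective_transport == "openrouter":
        return model  # OpenRouter accepts the vendor-prefixed slug as-is.
    # Subscription and direct Anthropic both want the CLI-friendly form and
    # only accept Anthropic models.
    vendor, sep, name = model.partition("/")
    if sep and vendor != "anthropic":
        raise ValueError(f"non-Anthropic model {model!r} not usable on this transport")
    return (name if sep else model).replace(".", "-")
```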

## Test plan

- [x] `poetry run pytest backend/copilot/sdk/service_helpers_test.py
backend/copilot/sdk/service_test.py backend/copilot/config_test.py -v` —
165 passed.
- [x] `poetry run pytest backend/copilot/sdk/env_test.py
backend/copilot/sdk/p0_guardrails_test.py` — 136 passed (other call
sites of `openrouter_active` unchanged).
- [x] `poetry run ruff format` + `ruff check` clean on touched files.

### New tests added (service_helpers_test.py)

- Subscription transport with OpenRouter base URL set + advanced-tier LD
override → returns `claude-opus-4-7` (not the OpenRouter slug, not
None).
- Subscription transport with OpenRouter base URL set + standard-tier no
override → returns None (existing behaviour preserved).
- Subscription transport rejects non-Anthropic vendor (`moonshotai/...`)
→ ValueError.
- `effective_transport` returns `subscription` when subscription is on
regardless of OpenRouter config; returns `openrouter` when subscription
is off and OpenRouter is fully configured; returns `direct_anthropic`
otherwise.
2026-04-28 11:40:31 +07:00
Zamil Majdy
4a567a55a4 fix(backend/copilot): pause idle timer during pending tools (#12927)
## Summary

Pause the SDK idle timer while a tool call is pending, with a 2-hour
hung-tool cap as backstop. Fixes SECRT-2239 — long-running tools (10+
min, e.g. sub-agent execution) were being silently aborted by the
10-minute idle timeout introduced in #12660.

## What changed (backend only)

- `_IDLE_TIMEOUT_SECONDS = 1800` (30 min) — soft cap when no tool
pending (raised from 10 min)
- `_HUNG_TOOL_CAP_SECONDS = 7200` (2 h) — hard cap when a tool is
pending; protects against truly hung tool calls without false-aborting
legitimate long-running ones
- `_idle_timeout_threshold(adapter)` — returns the appropriate threshold
based on whether any tool is currently pending in the adapter (sketched below)
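
Sketch of the threshold selection. The two constants come from this PR, but
how "a tool call is pending" is read off the adapter is an assumption:

```python
# Pick the idle cap for the current poll of the SDK stream: the soft cap when
# idle, the hung-tool hard cap while a tool call is still pending.
_IDLE_TIMEOUT_SECONDS = 1800      # 30 min soft cap when no tool is pending
_HUNG_TOOL_CAP_SECONDS = 7200     # 2 h hard cap while a tool call is pending


def _idle_timeout_threshold(adapter) -> int:
    has_pending_tool = bool(getattr(adapter, "pending_tool_calls", None))
    return _HUNG_TOOL_CAP_SECONDS if has_pending_tool else _IDLE_TIMEOUT_SECONDS
```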

Backed by 7 regression tests in
`service_test.py::TestIdleTimeoutThreshold`.

## Frontend coordination

The original cherry-pick batch included a `useStreamActivityWatchdog`
hook for client-side wire-silence detection. That hook is dropped from
this PR because it overlaps with Lluis's #12813, which ships the same
component as part of a comprehensive copilot streaming refactor. End
state on dev: his PR contributes the watchdog, this PR contributes the
backend pause + cap.

## Test plan

- 7/7 unit tests in
`backend/copilot/sdk/service_test.py::TestIdleTimeoutThreshold` pass
- pyright clean on `service.py` + `service_test.py`
- /pr-test --fix posted with native-stack run + screenshots:
https://github.com/Significant-Gravitas/AutoGPT/pull/12927#issuecomment-4328320714

## Linear

SECRT-2239
2026-04-28 09:07:16 +07:00
John Ababseh
2b28434786 feat(platform/backend): Filter store creators with approved agents (#10014)
Filtering store creators to only show profiles with an approved agent
keeps the marketplace focused on usable inventory and prevents empty
creator cards.
 
### Changes 🏗️
 
- add a `num_agents > 0` filter to `get_store_creators` (see the sketch after this list)
- add a regression test ensuring we only return creators with approved
agents
- keep the existing SQL injection regression tests intact after rebasing
onto `dev`
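
Hedged sketch of the filter using prisma-client-py; the model and argument
names are placeholders, and only the `num_agents > 0` condition comes from
this PR:

```python
# Placeholder model/helper names; only the num_agents > 0 filter is from the PR.
import prisma.models


async def get_store_creators(page: int = 1, page_size: int = 20):
    where = {"num_agents": {"gt": 0}}  # only creators with at least one approved agent
    creators = await prisma.models.Creator.prisma().find_many(
        where=where,
        skip=(page - 1) * page_size,
        take=page_size,
    )
    total = await prisma.models.Creator.prisma().count(where=where)
    return creators, total
```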
 
### Checklist 📋
 
#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
- [ ] python3 -m pytest
autogpt_platform/backend/backend/server/v2/store/db_test.py -k
get_store_creators_only_returns_approved *(blocked: repo environment
lacks pytest and related deps)*
 
<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Filter `get_store_creators` to creators with `num_agents > 0` and add
a test to validate the behavior.
> 
> - **Store backend**:
> - Update `get_store_creators` in `backend/server/v2/store/db.py` to
filter creators by `num_agents > 0`.
> - **Tests**:
> - Add `test_get_store_creators_only_returns_approved` in
`backend/server/v2/store/db_test.py` to verify filtering and pagination
count calls.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
c2fca584cce5a8c26dbdadd68696a0033642f193. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <ntindle@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ntindle <8845353+ntindle@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-04-27 22:01:24 +00:00
Zamil Majdy
5d1cdc2bad fix(backend/copilot): surface empty-success ResultMessage as stream error (SECRT-2252) (#12926)
## Summary

- Detect ghost-finished sessions where the SDK returns a `ResultMessage`
with `subtype="success"`, empty `result`, no produced content, and
`output_tokens == 0`.
- Emit `StreamError(code="empty_completion")` instead of silently
calling `StreamFinish`, so the caller (and the user) sees the failure.

## Background

Linear: [SECRT-2252](https://linear.app/agpt/issue/SECRT-2252) — SDK
silent empty completion not retried, leaving the user with a blank
stream (`start -> start-step -> finish-step -> finish`).

## Changes

- `response_adapter.py::convert_message`: in the `ResultMessage` branch,
check `_is_empty_completion()` before falling through to the existing
success path. When matched, close any open step, emit `StreamError`, and
skip `StreamFinish`.
- `response_adapter.py::_is_empty_completion`: new helper that returns
`True` only when `result` is falsy, no text/reasoning was emitted, no
tool calls were registered, no tool results were seen, and
`usage["output_tokens"]` is `0` (sketched after this list).
- `response_adapter_test.py`: 4 new unit tests covering empty-success
(None and empty-string variants), non-empty success, and the
non-empty-tokens-but-empty-result fallthrough.
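
Sketch of the guard. The adapter-state attribute names are assumptions; the
five conditions are taken directly from this description:

```python
# Ghost-finished turn: a "success" ResultMessage that produced nothing at all.
def is_empty_completion(result: str | None, state, usage: dict) -> bool:
    return (
        not result
        and not state.emitted_text
        and not state.emitted_reasoning
        and not state.tool_calls
        and not state.tool_results
        and usage.get("output_tokens", 0) == 0
    )
```

In the `ResultMessage` branch, a `True` here closes any open step, emits
`StreamError(code="empty_completion")`, and skips `StreamFinish`, as described
above.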

## Out of scope (per ticket)

- Retry-once behavior. This PR only surfaces the error; the caller
decides retry semantics. Follow-up work can wire automatic retry on
`code="empty_completion"`.

## Test plan

- [x] `poetry run pytest backend/copilot/sdk/response_adapter_test.py` —
all 58 tests pass (4 new + 54 existing).
- [x] `poetry run pyright backend/copilot/sdk/response_adapter.py
backend/copilot/sdk/response_adapter_test.py` — clean.

## Checklist

- [x] My code follows the style of this project.
- [x] I have added tests covering my changes.
- [x] I have updated the documentation accordingly. (N/A — internal
adapter behavior)
2026-04-27 17:24:57 +00:00
274 changed files with 27665 additions and 3478 deletions

View File

@@ -160,6 +160,24 @@ while clean_polls < required_clean:
Only after `clean_polls == 2` do you report `ORCHESTRATOR:DONE`.
### Concrete CI fetch (don't parse `gh pr checks` text columns)
The `fetch_check_runs(PR)` step above must use `--json`, not the default text output. Job names can contain spaces and parentheses (e.g. `test (3.11)`, `Analyze (python)`), so `gh pr checks $PR | awk '{print $2}'` extracts `(3.11)` instead of the status — leading to a clean-poll firing while jobs are still pending.
```bash
# Reliable: use --json so columns are unambiguous.
ci_json=$(gh pr checks $PR --repo Significant-Gravitas/AutoGPT --json name,state,bucket)
pending=$(echo "$ci_json" | jq '[.[] | select(.bucket == "pending")] | length')
failed=$(echo "$ci_json" | jq '[.[] | select(.bucket == "fail" or .bucket == "cancel")] | length')
# Buckets are: pass | fail | pending | cancel | skipping
# (NOTE: gh pr checks does NOT expose `conclusion` as a JSON field —
# only `bucket`. Don't confuse with the GitHub REST API's check_runs
# endpoint, which DOES use conclusion.)
```
Map back to the pseudocode above: `bucket == "pending"` is `ci.conclusion is None (still in_progress)`; `bucket in {"fail", "cancel"}` is `ci.conclusion in NON_SUCCESS_TERMINAL`; `bucket in {"pass", "skipping"}` is clean.
### Why 2 clean polls, not 1
A single green snapshot can be misleading — the final CI check often completes ~30s before a bot posts its delayed review. One quiet cycle does not prove the PR is stable; two consecutive cycles with no new threads, reviews, or issue comments arriving gives high confidence nothing else is incoming.
@@ -196,6 +214,18 @@ The child skill returning is a **loop iteration boundary**, not a conversation t
If the user needs to approve a risky action mid-loop (e.g., a force-push or a destructive git operation), pause there — but not at the routine "round N finished, round N+1 needed" boundary. Those are silent transitions.
### **Run /pr-polish in the foreground — never in a background agent**
Spawning `/pr-polish` inside an `Agent(subagent_type="general-purpose")` background task **does not work**. Background agents don't inherit the parent's slash-command registry, so `Skill(skill="pr-review")` and `Skill(skill="pr-address")` calls aren't available — the agent has to manually replicate the child skills' logic, which is fragile and tends to stall on the first network or rate-limit hiccup. Symptom: the background task reports `stalled: no progress for 600s` mid-review.
Run `/pr-polish` inline in the foreground conversation. If the user asks for "/pr-polish + /pr-test in parallel", split them: foreground `/pr-polish`, and ONLY then can the test step go to a background agent (because `/pr-test` doesn't itself need to invoke skills).
### **You MUST invoke `Skill(pr-review)` every round — even when bot reviews already exist**
A common failure mode: CodeRabbit / autogpt-reviewer / Sentry have already posted findings on the PR, and the orchestrator skips the `Skill(pr-review)` step on the assumption that "review has been done." That's wrong — the outer loop's purpose is to layer **the agent's own review** on top of the bot reviews, catching issues the bots miss (architecture, naming, cross-file invariants, hidden coupling). If the orchestrator only addresses bot findings without ever running its own review, the loop converges to "bot-clean" but not "agent-reviewed-clean," and the user reasonably asks "did /pr-polish even read the diff?"
**Self-check before reporting `ORCHESTRATOR:DONE`:** confirm at least one `Skill(skill="pr-review")` call appears in the current orchestration. If none, the loop is incomplete — go back and run one round.
## GitHub rate limits
This skill issues many GraphQL calls (one review-thread query per outer iteration plus per-poll queries inside polish polling). Expect the GraphQL budget to be tight on large PRs. When `gh api rate_limit --jq .resources.graphql.remaining` drops below ~200, back off:

View File

@@ -119,10 +119,12 @@ jobs:
runs-on: ubuntu-latest
services:
redis:
image: redis:latest
ports:
- 6379:6379
# Redis is provisioned as a real 3-shard cluster below via docker
# run (see the "Start Redis Cluster" step). GHA services can't
# override the image CMD or stand up multi-container clusters, so
# that setup is inlined — it mirrors the topology of the local dev
# compose stack (autogpt_platform/docker-compose.platform.yml) and
# prod helm chart.
rabbitmq:
image: rabbitmq:4.1.4
ports:
@@ -166,6 +168,68 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Start Redis Cluster (3 shards)
run: |
# 3-master Redis Cluster matching the local compose stack
# (autogpt_platform/docker-compose.platform.yml) and prod. Each
# shard runs in its own container on a dedicated bridge network,
# announces its compose-style hostname for intra-network clients,
# and publishes 1700N on the GHA host so tests can reach every
# shard via localhost. The backend's ``_address_remap`` rewrites
# every CLUSTER SLOTS reply to localhost:<announced-port>, which
# picks the right published port per shard.
#
# Not reusing docker-compose.platform.yml directly because compose
# validates the full file even when only some services are ``up``,
# and that file references services (db/kong/...) defined in a
# sibling compose file — pulling both in would needlessly couple
# CI to the full local-dev stack.
docker network create redis-cluster-ci
for i in 0 1 2; do
port=$((17000 + i))
bus=$((27000 + i))
docker run -d --name redis-$i --network redis-cluster-ci \
--network-alias redis-$i \
-p $port:$port \
redis:7 \
redis-server --port $port \
--cluster-enabled yes \
--cluster-config-file nodes.conf \
--cluster-node-timeout 5000 \
--cluster-require-full-coverage no \
--cluster-announce-hostname redis-$i \
--cluster-announce-port $port \
--cluster-announce-bus-port $bus \
--cluster-preferred-endpoint-type hostname
done
# Wait for each shard to accept commands.
for i in 0 1 2; do
port=$((17000 + i))
for _ in $(seq 1 30); do
docker exec redis-$i redis-cli -p $port ping 2>/dev/null | grep -q PONG && break
sleep 1
done
done
# Form the cluster from an init container on the same network so
# --cluster-preferred-endpoint-type hostname resolves redis-0/1/2.
docker run --rm --network redis-cluster-ci redis:7 \
redis-cli --cluster create \
redis-0:17000 redis-1:17001 redis-2:17002 \
--cluster-replicas 0 --cluster-yes
# Confirm convergence.
for _ in $(seq 1 30); do
state=$(docker exec redis-0 redis-cli -p 17000 cluster info | awk -F: '/^cluster_state:/ {print $2}' | tr -d '[:cntrl:]')
if [ "$state" = "ok" ]; then
echo "Redis Cluster ready (3 shards, state=ok)"
docker exec redis-0 redis-cli -p 17000 cluster nodes
exit 0
fi
sleep 1
done
echo "Redis Cluster failed to reach ok state" >&2
docker exec redis-0 redis-cli -p 17000 cluster info >&2 || true
exit 1
- name: Setup Supabase
uses: supabase/setup-cli@v1
with:
@@ -286,8 +350,13 @@ jobs:
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
REDIS_HOST: "localhost"
REDIS_PORT: "6379"
REDIS_PORT: "17000"
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
# Opt-in: lets backend/data/e2e_redis_restart_test.py spin up its
# own isolated 3-shard cluster (ports 27110-27112) and exercise
# ``docker restart <shard>`` mid-stream. Off locally so a
# contributor's ``poetry run test`` doesn't pay the ~15s cost.
E2E_RESTART_ISOLATED: "1"
- name: Upload coverage reports to Codecov
if: ${{ !cancelled() }}

4
.gitignore vendored
View File

@@ -196,3 +196,7 @@ test.db
plans/
.claude/worktrees/
test-results/
# Playwright MCP / local browser-testing artifacts
.playwright-mcp/
copilot-session-switch-qa/

View File

@@ -1,33 +0,0 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class RateLimitSettings(BaseSettings):
redis_host: str = Field(
default="redis://localhost:6379",
description="Redis host",
validation_alias="REDIS_HOST",
)
redis_port: str = Field(
default="6379", description="Redis port", validation_alias="REDIS_PORT"
)
redis_password: Optional[str] = Field(
default=None,
description="Redis password",
validation_alias="REDIS_PASSWORD",
)
requests_per_minute: int = Field(
default=60,
description="Maximum number of requests allowed per minute per API key",
validation_alias="RATE_LIMIT_REQUESTS_PER_MINUTE",
)
model_config = SettingsConfigDict(case_sensitive=True, extra="ignore")
RATE_LIMIT_SETTINGS = RateLimitSettings()

View File

@@ -1,51 +0,0 @@
import time
from typing import Tuple
from redis import Redis
from .config import RATE_LIMIT_SETTINGS
class RateLimiter:
def __init__(
self,
redis_host: str = RATE_LIMIT_SETTINGS.redis_host,
redis_port: str = RATE_LIMIT_SETTINGS.redis_port,
redis_password: str | None = RATE_LIMIT_SETTINGS.redis_password,
requests_per_minute: int = RATE_LIMIT_SETTINGS.requests_per_minute,
):
self.redis = Redis(
host=redis_host,
port=int(redis_port),
password=redis_password,
decode_responses=True,
)
self.window = 60
self.max_requests = requests_per_minute
async def check_rate_limit(self, api_key_id: str) -> Tuple[bool, int, int]:
"""
Check if request is within rate limits.
Args:
api_key_id: The API key identifier to check
Returns:
Tuple of (is_allowed, remaining_requests, reset_time)
"""
now = time.time()
window_start = now - self.window
key = f"ratelimit:{api_key_id}:1min"
pipe = self.redis.pipeline()
pipe.zremrangebyscore(key, 0, window_start)
pipe.zadd(key, {str(now): now})
pipe.zcount(key, window_start, now)
pipe.expire(key, self.window)
_, _, request_count, _ = pipe.execute()
remaining = max(0, self.max_requests - request_count)
reset_time = int(now + self.window)
return request_count <= self.max_requests, remaining, reset_time

View File

@@ -1,32 +0,0 @@
from fastapi import HTTPException, Request
from starlette.middleware.base import RequestResponseEndpoint
from .limiter import RateLimiter
async def rate_limit_middleware(request: Request, call_next: RequestResponseEndpoint):
"""FastAPI middleware for rate limiting API requests."""
limiter = RateLimiter()
if not request.url.path.startswith("/api"):
return await call_next(request)
api_key = request.headers.get("Authorization")
if not api_key:
return await call_next(request)
api_key = api_key.replace("Bearer ", "")
is_allowed, remaining, reset_time = await limiter.check_rate_limit(api_key)
if not is_allowed:
raise HTTPException(
status_code=429, detail="Rate limit exceeded. Please try again later."
)
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(limiter.max_requests)
response.headers["X-RateLimit-Remaining"] = str(remaining)
response.headers["X-RateLimit-Reset"] = str(reset_time)
return response

View File

@@ -1,13 +1,16 @@
import asyncio
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Union
from expiringdict import ExpiringDict
if TYPE_CHECKING:
from redis.asyncio import Redis as AsyncRedis
from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster
from redis.asyncio.lock import Lock as AsyncRedisLock
AsyncRedisLike = Union[AsyncRedis, AsyncRedisCluster]
class AsyncRedisKeyedMutex:
"""
@@ -17,7 +20,7 @@ class AsyncRedisKeyedMutex:
in case the key is not unlocked for a specified duration, to prevent memory leaks.
"""
def __init__(self, redis: "AsyncRedis", timeout: int | None = 60):
def __init__(self, redis: "AsyncRedisLike", timeout: int | None = 60):
self.redis = redis
self.timeout = timeout
self.locks: dict[Any, "AsyncRedisLock"] = ExpiringDict(

View File

@@ -37,6 +37,23 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
# Web Push (VAPID) — generate with: poetry run python -c "
# from py_vapid import Vapid; import base64
# from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
# v = Vapid(); v.generate_keys()
# raw_priv = v.private_key.private_numbers().private_value.to_bytes(32, 'big')
# print('VAPID_PRIVATE_KEY=' + base64.urlsafe_b64encode(raw_priv).rstrip(b'=').decode())
# raw_pub = v.public_key.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint)
# print('VAPID_PUBLIC_KEY=' + base64.urlsafe_b64encode(raw_pub).rstrip(b'=').decode())
# "
# Dev-only keypair below — DO NOT use in staging/production. Regenerate
# your own with the snippet above before any non-local deployment.
VAPID_PRIVATE_KEY=17hBPdSdn6TR_yAgQxA0TjTcvRj3Lf6znHnASZ4rOKc
VAPID_PUBLIC_KEY=BBg49iVTWthVbRYphwmZNvZyiSJDqtSO4nmLxDzLKe3Oo9jbtu0Usa14xX4HQQNLUeiEfzD42zWSlrvY1PR12bs
# Per RFC 8292 push services use this in 410 Gone reports; set to a real
# mailbox in production. Defaults to a placeholder for local dev.
VAPID_CLAIM_EMAIL=mailto:dev@example.com
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
# Platform URLs (set these for webhooks and OAuth to work)
PLATFORM_BASE_URL=http://localhost:8000
@@ -182,6 +199,10 @@ GOOGLE_MAPS_API_KEY=
# Platform Bot Linking
PLATFORM_LINK_BASE_URL=http://localhost:3000/link
# CoPilot chat-platform bridge (Discord/Telegram/Slack)
# Uses FRONTEND_BASE_URL (above) for link confirmation pages.
AUTOPILOT_BOT_DISCORD_TOKEN=
# Communication Services
DISCORD_BOT_TOKEN=
MEDIUM_API_KEY=

View File

@@ -1,14 +1,44 @@
import asyncio
from typing import Dict, Set
import json
import logging
import time
from typing import Awaitable, Callable, Dict, Optional, Set
from fastapi import WebSocket
from fastapi import WebSocket, WebSocketDisconnect
from redis.asyncio import Redis as AsyncRedis
from redis.asyncio.client import PubSub as AsyncPubSub
from redis.exceptions import MovedError, RedisError, ResponseError
from starlette.websockets import WebSocketState
from backend.api.model import NotificationPayload, WSMessage, WSMethod
from backend.api.model import WSMessage, WSMethod
from backend.data import redis_client as redis
from backend.data.event_bus import _assert_no_wildcard
from backend.data.execution import (
ExecutionEventType,
GraphExecutionEvent,
NodeExecutionEvent,
exec_channel,
get_graph_execution_meta,
graph_all_channel,
)
from backend.data.notification_bus import NotificationEvent
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
_settings = Settings()
def _is_ws_close_race(exc: BaseException, websocket: WebSocket) -> bool:
"""A SPUBLISH→WS send racing with WS close — benign, drop quietly."""
if isinstance(exc, WebSocketDisconnect):
return True
if (
getattr(websocket, "application_state", None) == WebSocketState.DISCONNECTED
or getattr(websocket, "client_state", None) == WebSocketState.DISCONNECTED
):
return True
if isinstance(exc, RuntimeError) and "close message has been sent" in str(exc):
return True
return False
_EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT,
@@ -16,128 +46,379 @@ _EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
}
def event_bus_channel(channel_key: str) -> str:
"""Prefix a channel key with the execution event bus name."""
return f"{_settings.config.execution_event_bus_name}/{channel_key}"
def _notification_bus_channel(user_id: str) -> str:
"""Return the full sharded channel name for a user's notifications."""
return f"{_settings.config.notification_event_bus_name}/{user_id}"
MessageHandler = Callable[[Optional[bytes | str]], Awaitable[None]]
def _is_moved_error(exc: BaseException) -> bool:
"""A MOVED redirect — slot migration mid-stream; pump should reconnect."""
if isinstance(exc, MovedError):
return True
if isinstance(exc, ResponseError) and str(exc).startswith("MOVED "):
return True
return False
# Reconnect tunables for shard-failover during pubsub.listen().
_PUMP_RECONNECT_DEADLINE_S = 60.0
_PUMP_RECONNECT_BACKOFF_INITIAL_S = 0.5
_PUMP_RECONNECT_BACKOFF_MAX_S = 8.0
class _Subscription:
"""One SSUBSCRIBE lifecycle bound to a WebSocket, pinned to the owning shard."""
def __init__(self, full_channel: str) -> None:
_assert_no_wildcard(full_channel)
self.full_channel = full_channel
self._client: AsyncRedis | None = None
self._pubsub: AsyncPubSub | None = None
self._task: asyncio.Task | None = None
async def start(self, on_message: MessageHandler) -> None:
await self._open_pubsub()
self._task = asyncio.create_task(self._pump(on_message))
async def _open_pubsub(self) -> None:
"""(Re)establish the sharded pubsub connection + SSUBSCRIBE."""
self._client = await redis.connect_sharded_pubsub_async(self.full_channel)
self._pubsub = self._client.pubsub()
await self._pubsub.execute_command("SSUBSCRIBE", self.full_channel)
# redis-py 6.x async PubSub.listen() exits when ``channels`` is
# empty; raw SSUBSCRIBE doesn't populate it, so do it ourselves.
self._pubsub.channels[self.full_channel] = None # type: ignore[index]
async def _close_pubsub_quietly(self) -> None:
"""Best-effort teardown before reconnect — never raises."""
if self._pubsub is not None:
try:
await self._pubsub.aclose()
except Exception:
pass
self._pubsub = None
if self._client is not None:
try:
await self._client.aclose()
except Exception:
pass
self._client = None
async def _pump(self, on_message: MessageHandler) -> None:
if self._pubsub is None:
return
backoff = _PUMP_RECONNECT_BACKOFF_INITIAL_S
deadline = time.monotonic() + _PUMP_RECONNECT_DEADLINE_S
while True:
pubsub = self._pubsub
if pubsub is None:
return
needs_reconnect = False
try:
async for message in pubsub.listen():
msg_type = message.get("type")
# Server-pushed sunsubscribe: slot ownership changed and
# Redis revoked our SSUBSCRIBE without dropping the TCP.
# Treat as a reconnect trigger so we re-resolve the shard.
if msg_type == "sunsubscribe":
needs_reconnect = True
break
if msg_type not in ("smessage", "message", "pmessage"):
continue
# Successful read resets the reconnect budget.
backoff = _PUMP_RECONNECT_BACKOFF_INITIAL_S
deadline = time.monotonic() + _PUMP_RECONNECT_DEADLINE_S
try:
await on_message(message.get("data"))
except Exception:
logger.exception(
"Websocket message-handler failed for channel %s",
self.full_channel,
)
if not needs_reconnect:
# listen() exited cleanly (channels emptied) — pump is done.
return
except asyncio.CancelledError:
raise
except (ConnectionError, RedisError) as exc:
if isinstance(exc, ResponseError) and not _is_moved_error(exc):
logger.exception(
"Pubsub pump crashed on non-retryable ResponseError for %s",
self.full_channel,
)
return
if time.monotonic() > deadline:
logger.exception(
"Pubsub pump giving up after reconnect deadline for %s",
self.full_channel,
)
return
logger.warning(
"Pubsub pump reconnecting for %s after %s: %s",
self.full_channel,
type(exc).__name__,
exc,
)
except Exception:
logger.exception("Pubsub pump crashed for %s", self.full_channel)
return
# Either a retryable error was raised, or the server pushed a
# sunsubscribe — close the stale pubsub and reopen against the
# (possibly migrated) shard.
await self._close_pubsub_quietly()
await asyncio.sleep(backoff)
backoff = min(backoff * 2, _PUMP_RECONNECT_BACKOFF_MAX_S)
try:
await self._open_pubsub()
except (ConnectionError, RedisError) as reopen_exc:
logger.warning(
"Pubsub pump reopen failed for %s: %s",
self.full_channel,
reopen_exc,
)
# Loop again — deadline check will eventually exit.
continue
async def stop(self) -> None:
if self._task is not None:
self._task.cancel()
try:
await self._task
except (asyncio.CancelledError, Exception):
pass
self._task = None
if self._pubsub is not None:
try:
await self._pubsub.execute_command("SUNSUBSCRIBE", self.full_channel)
except Exception:
logger.warning(
"SUNSUBSCRIBE failed for %s", self.full_channel, exc_info=True
)
try:
await self._pubsub.aclose()
except Exception:
pass
self._pubsub = None
if self._client is not None:
try:
await self._client.aclose()
except Exception:
pass
self._client = None
class ConnectionManager:
def __init__(self):
self.active_connections: Set[WebSocket] = set()
# channel_key → sockets subscribed (public channel keys, not raw Redis channels)
self.subscriptions: Dict[str, Set[WebSocket]] = {}
self.user_connections: Dict[str, Set[WebSocket]] = {}
# websocket → {channel_key: _Subscription}
self._ws_subs: Dict[WebSocket, Dict[str, _Subscription]] = {}
# websocket → notification subscription
self._ws_notifications: Dict[WebSocket, _Subscription] = {}
async def connect_socket(self, websocket: WebSocket, *, user_id: str):
await websocket.accept()
self.active_connections.add(websocket)
if user_id not in self.user_connections:
self.user_connections[user_id] = set()
self.user_connections[user_id].add(websocket)
self._ws_subs.setdefault(websocket, {})
await self._start_notification_subscription(websocket, user_id=user_id)
def disconnect_socket(self, websocket: WebSocket, *, user_id: str):
async def disconnect_socket(self, websocket: WebSocket, *, user_id: str):
self.active_connections.discard(websocket)
for subscribers in self.subscriptions.values():
# Stop SSUBSCRIBE pumps before dropping bookkeeping to avoid leaks.
subs = self._ws_subs.pop(websocket, {})
for sub in subs.values():
await sub.stop()
notif_sub = self._ws_notifications.pop(websocket, None)
if notif_sub is not None:
await notif_sub.stop()
for channel_key, subscribers in list(self.subscriptions.items()):
subscribers.discard(websocket)
user_conns = self.user_connections.get(user_id)
if user_conns is not None:
user_conns.discard(websocket)
if not user_conns:
self.user_connections.pop(user_id, None)
if not subscribers:
self.subscriptions.pop(channel_key, None)
async def subscribe_graph_exec(
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
) -> str:
return await self._subscribe(
_graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id), websocket
# Hash-tagged channel needs graph_id; resolve once per subscribe.
meta = await get_graph_execution_meta(user_id, graph_exec_id)
if meta is None:
raise ValueError(
f"graph_exec #{graph_exec_id} not found for user #{user_id}"
)
channel_key = graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id)
full_channel = event_bus_channel(
exec_channel(user_id, meta.graph_id, graph_exec_id)
)
await self._open_subscription(websocket, channel_key, full_channel)
return channel_key
async def subscribe_graph_execs(
self, *, user_id: str, graph_id: str, websocket: WebSocket
) -> str:
return await self._subscribe(
_graph_execs_channel_key(user_id, graph_id=graph_id), websocket
)
channel_key = _graph_execs_channel_key(user_id, graph_id=graph_id)
full_channel = event_bus_channel(graph_all_channel(user_id, graph_id))
await self._open_subscription(websocket, channel_key, full_channel)
return channel_key
async def unsubscribe_graph_exec(
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
) -> str | None:
return await self._unsubscribe(
_graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id), websocket
)
channel_key = graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id)
return await self._close_subscription(websocket, channel_key)
async def unsubscribe_graph_execs(
self, *, user_id: str, graph_id: str, websocket: WebSocket
) -> str | None:
return await self._unsubscribe(
_graph_execs_channel_key(user_id, graph_id=graph_id), websocket
)
channel_key = _graph_execs_channel_key(user_id, graph_id=graph_id)
return await self._close_subscription(websocket, channel_key)
async def send_execution_update(
self, exec_event: GraphExecutionEvent | NodeExecutionEvent
) -> int:
graph_exec_id = (
exec_event.id
if isinstance(exec_event, GraphExecutionEvent)
else exec_event.graph_exec_id
)
async def _open_subscription(
self, websocket: WebSocket, channel_key: str, full_channel: str
) -> None:
self.subscriptions.setdefault(channel_key, set()).add(websocket)
per_ws = self._ws_subs.setdefault(websocket, {})
if channel_key in per_ws:
return
sub = _Subscription(full_channel)
n_sent = 0
async def on_message(data: Optional[bytes | str]) -> None:
await self._forward_exec_event(websocket, channel_key, data)
channels: set[str] = {
# Send update to listeners for this graph execution
_graph_exec_channel_key(exec_event.user_id, graph_exec_id=graph_exec_id)
}
if isinstance(exec_event, GraphExecutionEvent):
# Send update to listeners for all executions of this graph
channels.add(
_graph_execs_channel_key(
exec_event.user_id, graph_id=exec_event.graph_id
)
)
await sub.start(on_message)
per_ws[channel_key] = sub
for channel in channels.intersection(self.subscriptions.keys()):
message = WSMessage(
method=_EVENT_TYPE_TO_METHOD_MAP[exec_event.event_type],
channel=channel,
data=exec_event.model_dump(),
).model_dump_json()
for connection in self.subscriptions[channel]:
await connection.send_text(message)
n_sent += 1
return n_sent
async def send_notification(
self, *, user_id: str, payload: NotificationPayload
) -> int:
"""Send a notification to all websocket connections belonging to a user."""
message = WSMessage(
method=WSMethod.NOTIFICATION,
data=payload.model_dump(),
).model_dump_json()
connections = tuple(self.user_connections.get(user_id, set()))
if not connections:
return 0
await asyncio.gather(
*(connection.send_text(message) for connection in connections),
return_exceptions=True,
)
return len(connections)
async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
if channel_key not in self.subscriptions:
self.subscriptions[channel_key] = set()
self.subscriptions[channel_key].add(websocket)
async def _close_subscription(
self, websocket: WebSocket, channel_key: str
) -> str | None:
subscribers = self.subscriptions.get(channel_key)
if subscribers is None:
return None
subscribers.discard(websocket)
if not subscribers:
self.subscriptions.pop(channel_key, None)
per_ws = self._ws_subs.get(websocket)
if per_ws and channel_key in per_ws:
sub = per_ws.pop(channel_key)
await sub.stop()
return channel_key
async def _unsubscribe(self, channel_key: str, websocket: WebSocket) -> str | None:
if channel_key in self.subscriptions:
self.subscriptions[channel_key].discard(websocket)
if not self.subscriptions[channel_key]:
del self.subscriptions[channel_key]
return channel_key
return None
async def _forward_exec_event(
self,
websocket: WebSocket,
channel_key: str,
raw_payload: Optional[bytes | str],
) -> None:
if raw_payload is None:
return
# Unwrap the `_EventPayloadWrapper` envelope, then re-wrap as a WS message.
try:
wrapper = (
raw_payload.decode()
if isinstance(raw_payload, (bytes, bytearray))
else raw_payload
)
except Exception:
logger.warning(
"Failed to decode pubsub payload on %s", channel_key, exc_info=True
)
return
try:
parsed = json.loads(wrapper)
event_data = parsed.get("payload")
if not isinstance(event_data, dict):
return
event_type = event_data.get("event_type")
method = _EVENT_TYPE_TO_METHOD_MAP.get(ExecutionEventType(event_type))
if method is None:
return
message = WSMessage(
method=method,
channel=channel_key,
data=event_data,
).model_dump_json()
await websocket.send_text(message)
except Exception as e:
if _is_ws_close_race(e, websocket):
logger.debug("Dropped exec event on closed WS for %s", channel_key)
return
logger.exception("Failed to forward exec event on %s", channel_key)
async def _start_notification_subscription(
self, websocket: WebSocket, *, user_id: str
) -> None:
full_channel = _notification_bus_channel(user_id)
sub = _Subscription(full_channel)
async def on_message(data: Optional[bytes | str]) -> None:
await self._forward_notification(websocket, user_id, data)
try:
await sub.start(on_message)
except Exception:
logger.exception(
"Failed to open notification SSUBSCRIBE for user=%s", user_id
)
return
self._ws_notifications[websocket] = sub
async def _forward_notification(
self,
websocket: WebSocket,
user_id: str,
raw_payload: Optional[bytes | str],
) -> None:
if raw_payload is None:
return
try:
wrapper_json = (
raw_payload.decode()
if isinstance(raw_payload, (bytes, bytearray))
else raw_payload
)
parsed = json.loads(wrapper_json)
inner = parsed.get("payload") if isinstance(parsed, dict) else None
if not isinstance(inner, dict):
return
event = NotificationEvent.model_validate(inner)
except Exception:
logger.warning(
"Failed to parse notification payload for user=%s",
user_id,
exc_info=True,
)
return
# Defense in depth against cross-user payloads.
if event.user_id != user_id:
return
message = WSMessage(
method=WSMethod.NOTIFICATION,
data=event.payload.model_dump(),
).model_dump_json()
try:
await websocket.send_text(message)
except Exception as e:
if _is_ws_close_race(e, websocket):
logger.debug("Dropped notification on closed WS for user=%s", user_id)
return
logger.warning(
"Failed to deliver notification to WS for user=%s",
user_id,
exc_info=True,
)
def _graph_exec_channel_key(user_id: str, *, graph_exec_id: str) -> str:
def graph_exec_channel_key(user_id: str, *, graph_exec_id: str) -> str:
return f"{user_id}|graph_exec#{graph_exec_id}"

View File

@@ -0,0 +1,386 @@
"""ConnectionManager integration over the live 3-shard Redis cluster:
SSUBSCRIBE → SPUBLISH → WebSocket forwarding with no Redis mocks. Skips
when the cluster is unreachable."""
import asyncio
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock
from uuid import uuid4
import pytest
from fastapi import WebSocket
import backend.data.redis_client as redis_client
from backend.api.conn_manager import (
ConnectionManager,
_graph_execs_channel_key,
event_bus_channel,
graph_exec_channel_key,
)
from backend.api.model import WSMethod
from backend.data.execution import (
ExecutionStatus,
GraphExecutionEvent,
GraphExecutionMeta,
NodeExecutionEvent,
exec_channel,
graph_all_channel,
)
def _has_live_cluster() -> bool:
try:
c = redis_client.connect()
except Exception: # noqa: BLE001 — any connect failure → skip
return False
try:
c.close()
except Exception:
pass
return True
pytestmark = pytest.mark.skipif(
not _has_live_cluster(),
reason="local redis cluster not reachable; skip conn_manager integration",
)
def _meta(user_id: str, graph_id: str, graph_exec_id: str) -> GraphExecutionMeta:
"""Build a minimal GraphExecutionMeta for ``subscribe_graph_exec`` to use."""
return GraphExecutionMeta(
id=graph_exec_id,
user_id=user_id,
graph_id=graph_id,
graph_version=1,
inputs=None,
credential_inputs=None,
nodes_input_masks=None,
preset_id=None,
status=ExecutionStatus.RUNNING,
started_at=datetime.now(tz=timezone.utc),
ended_at=None,
stats=GraphExecutionMeta.Stats(),
)
def _node_event_payload(
*, user_id: str, graph_id: str, graph_exec_id: str, marker: str
) -> bytes:
"""Wire-format a NodeExecutionEvent the way RedisExecutionEventBus would."""
inner = NodeExecutionEvent(
user_id=user_id,
graph_id=graph_id,
graph_version=1,
graph_exec_id=graph_exec_id,
node_exec_id=f"node-exec-{marker}",
node_id="node-1",
block_id="block-1",
status=ExecutionStatus.COMPLETED,
input_data={"in": marker},
output_data={"out": [marker]},
add_time=datetime.now(tz=timezone.utc),
queue_time=None,
start_time=datetime.now(tz=timezone.utc),
end_time=datetime.now(tz=timezone.utc),
).model_dump(mode="json")
return json.dumps({"payload": inner}).encode()
def _graph_event_payload(
*, user_id: str, graph_id: str, graph_exec_id: str, marker: str
) -> bytes:
inner = GraphExecutionEvent(
id=graph_exec_id,
user_id=user_id,
graph_id=graph_id,
graph_version=1,
preset_id=None,
status=ExecutionStatus.COMPLETED,
started_at=datetime.now(tz=timezone.utc),
ended_at=datetime.now(tz=timezone.utc),
stats=GraphExecutionEvent.Stats(
cost=0,
duration=1.0,
node_exec_time=0.5,
node_exec_count=1,
),
inputs={"x": marker},
credential_inputs=None,
nodes_input_masks=None,
outputs={"y": [marker]},
).model_dump(mode="json")
return json.dumps({"payload": inner}).encode()
async def _wait_until(predicate, timeout: float = 5.0, interval: float = 0.05) -> bool:
"""Poll ``predicate()`` until truthy or timeout — used to wait for pubsub."""
deadline = asyncio.get_event_loop().time() + timeout
while asyncio.get_event_loop().time() < deadline:
if predicate():
return True
await asyncio.sleep(interval)
return False
@pytest.mark.asyncio
async def test_two_clients_get_independent_ssubscribes_on_right_shards(
monkeypatch,
) -> None:
"""Two WS clients on different graph_exec_ids each receive ONLY their
own publish, even when the channels land on different shards."""
user_id = "user-conn-int-1"
graph_a = f"graph-a-{uuid4().hex[:8]}"
graph_b = f"graph-b-{uuid4().hex[:8]}"
exec_a = f"exec-a-{uuid4().hex[:8]}"
exec_b = f"exec-b-{uuid4().hex[:8]}"
# Stub Prisma lookup so tests don't need a DB.
async def _fake_meta(_uid, gex_id):
return _meta(user_id, graph_a if gex_id == exec_a else graph_b, gex_id)
monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)
cm = ConnectionManager()
ws_a: AsyncMock = AsyncMock(spec=WebSocket)
ws_b: AsyncMock = AsyncMock(spec=WebSocket)
sent_a: list[str] = []
sent_b: list[str] = []
ws_a.send_text = AsyncMock(side_effect=lambda m: sent_a.append(m))
ws_b.send_text = AsyncMock(side_effect=lambda m: sent_b.append(m))
redis_client.get_redis.cache_clear()
cluster = redis_client.get_redis()
try:
await cm.subscribe_graph_exec(
user_id=user_id, graph_exec_id=exec_a, websocket=ws_a
)
await cm.subscribe_graph_exec(
user_id=user_id, graph_exec_id=exec_b, websocket=ws_b
)
# Let SSUBSCRIBE settle on each shard.
await asyncio.sleep(0.2)
# Publish to each per-exec channel.
chan_a = event_bus_channel(exec_channel(user_id, graph_a, exec_a))
chan_b = event_bus_channel(exec_channel(user_id, graph_b, exec_b))
cluster.spublish(
chan_a,
_node_event_payload(
user_id=user_id,
graph_id=graph_a,
graph_exec_id=exec_a,
marker="A",
).decode(),
)
cluster.spublish(
chan_b,
_node_event_payload(
user_id=user_id,
graph_id=graph_b,
graph_exec_id=exec_b,
marker="B",
).decode(),
)
delivered = await _wait_until(lambda: sent_a and sent_b, timeout=5.0)
assert delivered, f"timeout: sent_a={sent_a!r} sent_b={sent_b!r}"
msg_a = json.loads(sent_a[0])
msg_b = json.loads(sent_b[0])
assert msg_a["channel"] == graph_exec_channel_key(user_id, graph_exec_id=exec_a)
assert msg_b["channel"] == graph_exec_channel_key(user_id, graph_exec_id=exec_b)
assert msg_a["data"]["graph_exec_id"] == exec_a
assert msg_b["data"]["graph_exec_id"] == exec_b
# No cross-talk: each socket got exactly one message.
assert len(sent_a) == 1 and len(sent_b) == 1
finally:
await cm.disconnect_socket(ws_a, user_id=user_id)
await cm.disconnect_socket(ws_b, user_id=user_id)
redis_client.disconnect()
@pytest.mark.asyncio
async def test_aggregate_channel_receives_per_exec_publishes(monkeypatch) -> None:
"""A subscriber on the ``graph_execs`` aggregate channel must receive the
GraphExecutionEvent published to the ``/all`` channel — even though
per-exec events go to a different channel."""
user_id = "user-conn-int-2"
graph_id = f"graph-{uuid4().hex[:8]}"
exec_id = f"exec-{uuid4().hex[:8]}"
async def _fake_meta(_uid, gex_id):
return _meta(user_id, graph_id, gex_id)
monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)
cm = ConnectionManager()
ws_agg: AsyncMock = AsyncMock(spec=WebSocket)
ws_per: AsyncMock = AsyncMock(spec=WebSocket)
sent_agg: list[str] = []
sent_per: list[str] = []
ws_agg.send_text = AsyncMock(side_effect=lambda m: sent_agg.append(m))
ws_per.send_text = AsyncMock(side_effect=lambda m: sent_per.append(m))
redis_client.get_redis.cache_clear()
cluster = redis_client.get_redis()
try:
await cm.subscribe_graph_execs(
user_id=user_id, graph_id=graph_id, websocket=ws_agg
)
await cm.subscribe_graph_exec(
user_id=user_id, graph_exec_id=exec_id, websocket=ws_per
)
await asyncio.sleep(0.2)
# The eventbus publishes the same event to both channels — replicate.
chan_per = event_bus_channel(exec_channel(user_id, graph_id, exec_id))
chan_all = event_bus_channel(graph_all_channel(user_id, graph_id))
payload = _graph_event_payload(
user_id=user_id,
graph_id=graph_id,
graph_exec_id=exec_id,
marker="agg",
).decode()
cluster.spublish(chan_per, payload)
cluster.spublish(chan_all, payload)
delivered = await _wait_until(lambda: sent_agg and sent_per, timeout=5.0)
assert delivered, f"sent_agg={sent_agg!r} sent_per={sent_per!r}"
agg_msg = json.loads(sent_agg[0])
per_msg = json.loads(sent_per[0])
# Aggregate subscriber's channel key is the per-graph executions key.
assert agg_msg["channel"] == _graph_execs_channel_key(
user_id, graph_id=graph_id
)
assert per_msg["channel"] == graph_exec_channel_key(
user_id, graph_exec_id=exec_id
)
assert agg_msg["method"] == WSMethod.GRAPH_EXECUTION_EVENT.value
finally:
await cm.disconnect_socket(ws_agg, user_id=user_id)
await cm.disconnect_socket(ws_per, user_id=user_id)
redis_client.disconnect()
@pytest.mark.asyncio
async def test_disconnect_unsubscribes_and_drops_future_publishes(monkeypatch) -> None:
"""After ``disconnect_socket`` runs, a subsequent SPUBLISH must NOT reach
the dead websocket — exercises the SUNSUBSCRIBE + bookkeeping cleanup."""
user_id = "user-conn-int-3"
graph_id = f"graph-{uuid4().hex[:8]}"
exec_id = f"exec-{uuid4().hex[:8]}"
async def _fake_meta(_uid, gex_id):
return _meta(user_id, graph_id, gex_id)
monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)
cm = ConnectionManager()
ws: AsyncMock = AsyncMock(spec=WebSocket)
sent: list[str] = []
ws.send_text = AsyncMock(side_effect=lambda m: sent.append(m))
redis_client.get_redis.cache_clear()
cluster = redis_client.get_redis()
chan = event_bus_channel(exec_channel(user_id, graph_id, exec_id))
payload = _node_event_payload(
user_id=user_id, graph_id=graph_id, graph_exec_id=exec_id, marker="live"
).decode()
try:
await cm.subscribe_graph_exec(
user_id=user_id, graph_exec_id=exec_id, websocket=ws
)
await asyncio.sleep(0.15)
# First publish — must reach the socket.
cluster.spublish(chan, payload)
delivered = await _wait_until(lambda: bool(sent), timeout=5.0)
assert delivered
assert len(sent) == 1
# Disconnect → SUNSUBSCRIBE + bookkeeping cleared.
await cm.disconnect_socket(ws, user_id=user_id)
# Pump cancellation may drain in-flight messages; wait for it.
await asyncio.sleep(0.2)
# Channel bookkeeping must be gone.
assert (
graph_exec_channel_key(user_id, graph_exec_id=exec_id)
not in cm.subscriptions
)
assert ws not in cm._ws_subs
# Second publish — must NOT reach the (already-disconnected) socket.
cluster.spublish(
chan,
_node_event_payload(
user_id=user_id,
graph_id=graph_id,
graph_exec_id=exec_id,
marker="post-disconnect",
).decode(),
)
await asyncio.sleep(0.5)
# Still only the one pre-disconnect message.
assert len(sent) == 1
finally:
redis_client.disconnect()
@pytest.mark.asyncio
async def test_slow_consumer_receives_all_events_without_loss(monkeypatch) -> None:
"""Burst-publish many SPUBLISHes; assert every one reaches the subscriber
in order — guards against drops/reorderings in the pubsub pump."""
user_id = "user-conn-int-4"
graph_id = f"graph-{uuid4().hex[:8]}"
exec_id = f"exec-{uuid4().hex[:8]}"
n_events = 100
async def _fake_meta(_uid, gex_id):
return _meta(user_id, graph_id, gex_id)
monkeypatch.setattr("backend.api.conn_manager.get_graph_execution_meta", _fake_meta)
cm = ConnectionManager()
ws: AsyncMock = AsyncMock(spec=WebSocket)
sent: list[str] = []
ws.send_text = AsyncMock(side_effect=lambda m: sent.append(m))
redis_client.get_redis.cache_clear()
cluster = redis_client.get_redis()
chan = event_bus_channel(exec_channel(user_id, graph_id, exec_id))
try:
await cm.subscribe_graph_exec(
user_id=user_id, graph_exec_id=exec_id, websocket=ws
)
await asyncio.sleep(0.2)
# Burst-publish n_events without yielding to the pump.
for i in range(n_events):
cluster.spublish(
chan,
_node_event_payload(
user_id=user_id,
graph_id=graph_id,
graph_exec_id=exec_id,
marker=f"m{i}",
).decode(),
)
delivered = await _wait_until(
lambda: len(sent) >= n_events, timeout=15.0, interval=0.1
)
assert delivered, f"only delivered {len(sent)}/{n_events}"
# Validate ordering — Redis pub/sub is FIFO per channel.
markers = [json.loads(m)["data"]["input_data"]["in"] for m in sent[:n_events]]
assert markers == [f"m{i}" for i in range(n_events)]
finally:
await cm.disconnect_socket(ws, user_id=user_id)
redis_client.disconnect()

File diff suppressed because it is too large

View File

@@ -8,7 +8,7 @@ from uuid import uuid4
from autogpt_libs import auth
from fastapi import APIRouter, HTTPException, Query, Response, Security
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict, Field, field_validator
from backend.copilot import service as chat_service
@@ -47,7 +47,14 @@ from backend.copilot.rate_limit import (
release_reset_lock,
reset_daily_usage,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.response_model import (
StreamError,
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamStart,
StreamStartStep,
)
from backend.copilot.service import strip_injected_context_for_display
from backend.copilot.tools.e2b_sandbox import kill_sandbox
from backend.copilot.tools.models import (
@@ -154,6 +161,14 @@ class StreamChatRequest(BaseModel):
)
class QueuePendingMessageRequest(BaseModel):
"""Request model for queueing a follow-up while a turn is running."""
message: str = Field(max_length=64_000)
context: dict[str, str] | None = None
file_ids: list[str] | None = Field(default=None, max_length=20)
class PeekPendingMessagesResponse(BaseModel):
"""Response for the pending-message peek (GET) endpoint.
@@ -209,6 +224,11 @@ class ActiveStreamInfo(BaseModel):
turn_id: str
last_message_id: str # Redis Stream message ID for resumption
# ISO-8601 timestamp (UTC) marking when the backend registered the turn
# as running. Lets the frontend seed its elapsed-time counter so restored
# turns show honest "time since turn started" instead of the misleading
# "time since this mount resumed the SSE".
started_at: str | None = None
class SessionDetailResponse(BaseModel):
@@ -300,8 +320,11 @@ async def list_sessions(
redis = await get_redis_async()
pipe = redis.pipeline(transaction=False)
for session in sessions:
# Use the canonical helper so the hash-tag braces match every
# other writer; building the key inline drops the braces and
# silently misses every running session on cluster mode.
pipe.hget(
f"{config.session_meta_prefix}{session.session_id}",
stream_registry.get_session_meta_key(session.session_id),
"status",
)
statuses = await pipe.execute()
@@ -529,6 +552,7 @@ async def get_session(
active_stream_info = ActiveStreamInfo(
turn_id=active_session.turn_id,
last_message_id=last_message_id,
started_at=active_session.created_at.isoformat(),
)
# Skip session metadata on "load more" — frontend only needs messages
@@ -816,17 +840,45 @@ async def cancel_session_task(
return CancelSessionResponse(cancelled=True)
def _ui_message_stream_headers() -> dict[str, str]:
return {
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
}
def _empty_ui_message_stream_response() -> StreamingResponse:
# Stable placeholder messageId for the empty queued-mid-turn stream.
# Real turns generate per-message UUIDs via the executor; this stream
# has no message to attach to, but the AI SDK parser still requires a
# non-empty ``messageId`` field on ``StreamStart``.
message_id = uuid4().hex
async def event_generator() -> AsyncGenerator[str, None]:
# Vercel AI SDK's UI-message-stream parser expects symmetric
# start/finish framing at both stream and step level — every
# non-empty turn emits the pair. Without an opener, today's parser
# tolerates the closer (no active parts to flush) but a future SDK
# tightening would silently break the queue-mid-turn UX. Emit the
# full empty pair so the contract stays correct.
yield StreamStart(messageId=message_id).to_sse()
yield StreamStartStep().to_sse()
yield StreamFinishStep().to_sse()
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers=_ui_message_stream_headers(),
)
@router.post(
"/sessions/{session_id}/stream",
responses={
202: {
"model": QueuePendingMessageResponse,
"description": (
"Session has a turn in flight — message queued into the pending "
"buffer and will be picked up between tool-call rounds by the "
"executor currently processing the turn."
),
},
404: {"description": "Session not found or access denied"},
429: {"description": "Cost rate-limit or call-frequency cap exceeded"},
},
@@ -836,19 +888,18 @@ async def stream_chat_post(
request: StreamChatRequest,
user_id: str = Security(auth.get_user_id),
):
"""Start a new turn OR queue a follow-up — decided server-side.
"""Start a new turn and return an AI SDK UI message stream.
- **Session idle**: starts a turn. Returns an SSE stream (``text/event-stream``)
with Vercel AI SDK chunks (text fragments, tool-call UI, tool results).
The generation runs in a background task that survives client disconnects;
reconnect via ``GET /sessions/{session_id}/stream`` to resume.
Returns an SSE stream (``text/event-stream``) with Vercel AI SDK chunks
(text fragments, tool-call UI, tool results). The generation runs in a
background task that survives client disconnects; reconnect via
``GET /sessions/{session_id}/stream`` to resume.
- **Session has a turn in flight**: pushes the message into the per-session
pending buffer and returns ``202 application/json`` with
``QueuePendingMessageResponse``. The executor running the current turn
drains the buffer between tool-call rounds (baseline) or at the start of
the next turn (SDK). Clients should detect the 202 and surface the
message as a queued-chip in the UI.
Follow-up messages typed while a turn is already running should use
``POST /sessions/{session_id}/messages/pending``. If an older client still
posts that follow-up here, we queue it defensively but still return a valid
empty UI-message stream so AI SDK transports never receive a JSON body from
the stream endpoint.
Args:
session_id: The chat session identifier.
@@ -872,26 +923,29 @@ async def stream_chat_post(
extra={"json_fields": log_meta},
)
session = await _validate_and_get_session(session_id, user_id)
builder_permissions = resolve_session_permissions(session)
# Self-defensive queue-fallback: if a turn is already running, don't race
# it on the cluster lock — drop the message into the pending buffer and
# return 202 so the caller can render a chip. Both UI chips and autopilot
# block follow-ups route through this path; keeping the decision on the
# server means every caller gets uniform behaviour.
if (
request.is_user_message
and request.message
and await is_turn_in_flight(session_id)
):
response = await queue_pending_for_http(
session_id=session_id,
user_id=user_id,
message=request.message,
context=request.context,
file_ids=request.file_ids,
)
return JSONResponse(status_code=202, content=response.model_dump())
try:
await queue_pending_for_http(
session_id=session_id,
user_id=user_id,
message=request.message,
context=request.context,
file_ids=request.file_ids,
)
return _empty_ui_message_stream_response()
except HTTPException as exc:
if exc.status_code != 409:
raise
# Permission resolution is only needed below for the actual turn — keep
# it after the queue fall-through so a queued mid-turn request returns
# without paying for that work.
builder_permissions = resolve_session_permissions(session)
logger.info(
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
@@ -1130,12 +1184,37 @@ async def stream_chat_post(
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Disable nginx buffering
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
},
headers=_ui_message_stream_headers(),
)
@router.post(
"/sessions/{session_id}/messages/pending",
response_model=QueuePendingMessageResponse,
responses={
404: {"description": "Session not found or access denied"},
409: {"description": "Session has no active turn to receive pending messages"},
429: {"description": "Call-frequency cap exceeded"},
},
)
async def queue_pending_message(
session_id: str,
request: QueuePendingMessageRequest,
user_id: str = Security(auth.get_user_id),
):
"""Queue a follow-up message while the session has an active turn."""
await _validate_and_get_session(session_id, user_id)
if not await is_turn_in_flight(session_id):
raise HTTPException(
status_code=409,
detail="Session has no active turn. Start a new turn with POST /stream.",
)
return await queue_pending_for_http(
session_id=session_id,
user_id=user_id,
message=request.message,
context=request.context,
file_ids=request.file_ids,
)
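A minimal client sketch for the new endpoint (httpx and the bearer-token header are illustrative assumptions, not part of this diff): queue the follow-up, render a chip on 200, fall back to starting a turn on 409.
import httpx  # illustrative client; any HTTP client works
def queue_follow_up(base_url: str, session_id: str, message: str, token: str) -> None:
    resp = httpx.post(
        f"{base_url}/sessions/{session_id}/messages/pending",
        json={"message": message},
        headers={"Authorization": f"Bearer {token}"},  # auth scheme assumed
    )
    if resp.status_code == 200:
        # QueuePendingMessageResponse fields (buffer_length, turn_in_flight) per the tests below
        print("queued; buffer length:", resp.json()["buffer_length"])
    elif resp.status_code == 409:
        print("no active turn; start one with POST /sessions/{id}/stream instead")
    else:
        resp.raise_for_status()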
@@ -1169,6 +1248,7 @@ async def get_pending_messages(
)
async def resume_session_stream(
session_id: str,
last_chunk_id: str | None = Query(default=None, include_in_schema=False),
user_id: str = Security(auth.get_user_id),
):
"""
@@ -1178,27 +1258,26 @@ async def resume_session_stream(
Checks for an active (in-progress) task on the session and either replays
the full SSE stream or returns 204 No Content if nothing is running.
Args:
session_id: The chat session identifier.
user_id: Optional authenticated user ID.
Returns:
StreamingResponse (SSE) when an active stream exists,
or 204 No Content when there is nothing to resume.
Always replays the active turn from ``0-0``. The AI SDK UI-message parser
keeps text/reasoning part state inside a single parser instance; resuming
from a Redis cursor can skip the ``*-start`` events required by later
``*-delta`` chunks.
"""
import asyncio
active_session, last_message_id = await stream_registry.get_active_session(
active_session, _latest_backend_id = await stream_registry.get_active_session(
session_id, user_id
)
if not active_session:
return Response(status_code=204)
# Always replay from the beginning ("0-0") on resume.
# We can't use last_message_id because it's the latest ID in the backend
# stream, not the latest the frontend received — the gap causes lost
# messages. The frontend deduplicates replayed content.
if last_chunk_id:
logger.info(
"Ignoring deprecated last_chunk_id on stream resume",
extra={"session_id": session_id, "last_chunk_id": last_chunk_id},
)
subscriber_queue = await stream_registry.subscribe_to_session(
session_id=session_id,
user_id=user_id,
@@ -1259,12 +1338,7 @@ async def resume_session_stream(
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
},
headers=_ui_message_stream_headers(),
)
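The "0-0" cursor leans on Redis stream semantics. A sketch of the assumption (the registry's real read loop and stream key name are not in this diff):
# Assumed mechanism: reading from "0-0" returns every entry in the stream, so the
# subscriber replays the whole active turn and the frontend deduplicates what it
# already rendered. The stream key name here is illustrative.
async def replay_turn(redis, session_id: str):
    return await redis.xrange(f"copilot:stream:{{{session_id}}}", min="0-0", max="+")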

View File

@@ -157,6 +157,11 @@ def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
"backend.api.features.chat.routes._validate_and_get_session",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.is_turn_in_flight",
new_callable=AsyncMock,
return_value=False,
)
mock_save = mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=MagicMock(), # non-None = message was saved (not a duplicate)
@@ -637,7 +642,7 @@ class TestStreamChatRequestModeValidation:
assert req.mode is None
# ─── POST /stream queue-fallback (when a turn is already in flight) ──
# ─── Pending message queue (when a turn is already in flight) ─────────
def _mock_stream_queue_internals(
@@ -646,11 +651,9 @@ def _mock_stream_queue_internals(
session_exists: bool = True,
turn_in_flight: bool = True,
call_count: int = 1,
push_length: int | None = 1,
):
"""Mock dependencies for the POST /stream queue-fallback path.
When ``turn_in_flight`` is True the handler takes the 202 queue branch.
"""
"""Mock dependencies for the pending-message queue path."""
if session_exists:
mock_session = mocker.MagicMock()
mock_session.id = "sess-1"
@@ -692,12 +695,10 @@ def _mock_stream_queue_internals(
return_value=call_count,
)
mocker.patch(
"backend.copilot.pending_message_helpers.push_pending_message",
"backend.copilot.pending_message_helpers.push_pending_message_if_session_running",
new_callable=AsyncMock,
return_value=1,
return_value=push_length,
)
# queue_user_message re-runs is_turn_in_flight via the helper module —
# stub that path out too so we don't need a fake stream_registry.
mocker.patch(
"backend.copilot.pending_message_helpers.get_active_session_meta",
new_callable=AsyncMock,
@@ -705,37 +706,65 @@ def _mock_stream_queue_internals(
)
def test_stream_queue_returns_202_when_turn_in_flight(
def test_queue_pending_message_returns_200_when_turn_in_flight(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Happy path: POST /stream to a session with a live turn → 202 queue."""
"""Happy path: POST /messages/pending to a live turn queues the message."""
_mock_stream_queue_internals(mocker)
response = client.post(
"/sessions/sess-1/stream",
json={"message": "follow-up", "is_user_message": True},
"/sessions/sess-1/messages/pending",
json={"message": "follow-up"},
)
assert response.status_code == 202
assert response.status_code == 200
data = response.json()
assert data["buffer_length"] == 1
assert "turn_in_flight" in data
def test_stream_queue_session_not_found_returns_404(
def test_queue_pending_message_session_not_found_returns_404(
mocker: pytest_mock.MockerFixture,
) -> None:
"""If the session doesn't exist or belong to the user, returns 404."""
_mock_stream_queue_internals(mocker, session_exists=False)
response = client.post(
"/sessions/bad-sess/stream",
json={"message": "hi", "is_user_message": True},
"/sessions/bad-sess/messages/pending",
json={"message": "hi"},
)
assert response.status_code == 404
def test_stream_queue_call_frequency_limit_returns_429(
def test_queue_pending_message_without_active_turn_returns_409(
mocker: pytest_mock.MockerFixture,
) -> None:
"""A pending-message push needs an active turn to consume it."""
_mock_stream_queue_internals(mocker, turn_in_flight=False)
response = client.post(
"/sessions/sess-1/messages/pending",
json={"message": "hi"},
)
assert response.status_code == 409
def test_queue_pending_message_race_after_active_check_returns_409(
mocker: pytest_mock.MockerFixture,
) -> None:
"""If the active turn ends before the atomic push, the message is not queued."""
_mock_stream_queue_internals(mocker, push_length=None)
response = client.post(
"/sessions/sess-1/messages/pending",
json={"message": "hi"},
)
assert response.status_code == 409
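The race above hinges on the push being atomic with the is-running check. A hypothetical sketch of that atomicity (push_pending_message_if_session_running's real mechanism and key names are not in this diff and are assumed here): check the active-turn marker and push in one server-side step, returning nil when the turn already ended.
# Hypothetical Lua sketch only; KEYS[1] = active-turn marker, KEYS[2] = pending buffer.
PUSH_IF_RUNNING = """
if redis.call('EXISTS', KEYS[1]) == 1 then
    return redis.call('RPUSH', KEYS[2], ARGV[1])
end
return false
"""
# A Lua false converts to a nil reply (None in redis-py), which the route turns
# into the 409 asserted above.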
def test_queue_pending_message_call_frequency_limit_returns_429(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Per-user call-frequency cap rejects rapid-fire queued pushes."""
@@ -744,14 +773,14 @@ def test_stream_queue_call_frequency_limit_returns_429(
_mock_stream_queue_internals(mocker, call_count=PENDING_CALL_LIMIT + 1)
response = client.post(
"/sessions/sess-1/stream",
json={"message": "hi", "is_user_message": True},
"/sessions/sess-1/messages/pending",
json={"message": "hi"},
)
assert response.status_code == 429
assert "Too many queued message requests this minute" in response.json()["detail"]
def test_stream_queue_converts_context_dict_to_pending_context(
def test_queue_pending_message_converts_context_dict_to_pending_context(
mocker: pytest_mock.MockerFixture,
) -> None:
"""StreamChatRequest.context is a raw dict; must be coerced to the
@@ -768,15 +797,14 @@ def test_stream_queue_converts_context_dict_to_pending_context(
)
response = client.post(
"/sessions/sess-1/stream",
"/sessions/sess-1/messages/pending",
json={
"message": "hi",
"is_user_message": True,
"context": {"url": "https://example.test", "content": "body"},
},
)
assert response.status_code == 202
assert response.status_code == 200
queue_spy.assert_awaited_once()
kwargs = queue_spy.await_args.kwargs
from backend.copilot.pending_messages import PendingMessageContext
@@ -786,7 +814,7 @@ def test_stream_queue_converts_context_dict_to_pending_context(
assert kwargs["context"].content == "body"
def test_stream_queue_passes_none_context_when_omitted(
def test_queue_pending_message_passes_none_context_when_omitted(
mocker: pytest_mock.MockerFixture,
) -> None:
"""When request.context is omitted, the queue call receives context=None."""
@@ -802,15 +830,31 @@ def test_stream_queue_passes_none_context_when_omitted(
)
response = client.post(
"/sessions/sess-1/stream",
json={"message": "hi", "is_user_message": True},
"/sessions/sess-1/messages/pending",
json={"message": "hi"},
)
assert response.status_code == 202
assert response.status_code == 200
queue_spy.assert_awaited_once()
assert queue_spy.await_args.kwargs["context"] is None
def test_stream_chat_queues_legacy_inflight_post_but_returns_sse(
mocker: pytest_mock.MockerFixture,
) -> None:
"""POST /stream must not return JSON to an AI SDK transport."""
_mock_stream_queue_internals(mocker)
response = client.post(
"/sessions/sess-1/stream",
json={"message": "follow-up", "is_user_message": True},
)
assert response.status_code == 200
assert response.headers["content-type"].startswith("text/event-stream")
assert '"type":"finish"' in response.text
# ─── get_pending_messages (GET /sessions/{session_id}/messages/pending) ─────
@@ -1581,9 +1625,14 @@ def test_resume_session_stream_no_subscriber_queue(
mock_registry.subscribe_to_session = AsyncMock(return_value=None)
mocker.patch("backend.api.features.chat.routes.stream_registry", mock_registry)
response = client.get("/sessions/sess-1/stream")
response = client.get("/sessions/sess-1/stream?last_chunk_id=9999-9")
assert response.status_code == 204
mock_registry.subscribe_to_session.assert_awaited_once_with(
session_id="sess-1",
user_id=TEST_USER_ID,
last_message_id="0-0",
)
# ─── DELETE /sessions/{id}/stream — disconnect listeners ──────────────

View File

@@ -0,0 +1,20 @@
import pydantic
class PushSubscriptionKeys(pydantic.BaseModel):
p256dh: str = pydantic.Field(min_length=1, max_length=512)
auth: str = pydantic.Field(min_length=1, max_length=512)
class PushSubscribeRequest(pydantic.BaseModel):
endpoint: str = pydantic.Field(min_length=1, max_length=2048)
keys: PushSubscriptionKeys
user_agent: str | None = pydantic.Field(default=None, max_length=512)
class PushUnsubscribeRequest(pydantic.BaseModel):
endpoint: str = pydantic.Field(min_length=1, max_length=2048)
class VapidPublicKeyResponse(pydantic.BaseModel):
public_key: str
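For reference, the browser-side PushSubscription.toJSON() payload maps onto these models directly; a small sketch (the endpoint and key values are placeholders):
# Sketch: endpoint / keys.p256dh / keys.auth follow the Web Push subscription JSON shape.
subscription_json = {
    "endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
    "keys": {"p256dh": "<base64url p256dh>", "auth": "<base64url auth>"},
}
request = PushSubscribeRequest(
    endpoint=subscription_json["endpoint"],
    keys=PushSubscriptionKeys(**subscription_json["keys"]),
    user_agent="Mozilla/5.0",
)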

View File

@@ -0,0 +1,64 @@
from typing import Annotated
from autogpt_libs.auth import get_user_id, requires_user
from fastapi import APIRouter, HTTPException, Security
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
from backend.api.features.push.model import (
PushSubscribeRequest,
PushUnsubscribeRequest,
VapidPublicKeyResponse,
)
from backend.data.push_subscription import (
delete_push_subscription,
upsert_push_subscription,
validate_push_endpoint,
)
from backend.util.settings import Settings
router = APIRouter()
_settings = Settings()
@router.get(
"/vapid-key",
summary="Get VAPID public key for push subscription",
)
async def get_vapid_public_key() -> VapidPublicKeyResponse:
return VapidPublicKeyResponse(public_key=_settings.secrets.vapid_public_key)
@router.post(
"/subscribe",
summary="Register a push subscription for the current user",
status_code=HTTP_204_NO_CONTENT,
dependencies=[Security(requires_user)],
)
async def subscribe_push(
user_id: Annotated[str, Security(get_user_id)],
body: PushSubscribeRequest,
) -> None:
try:
await validate_push_endpoint(body.endpoint)
await upsert_push_subscription(
user_id=user_id,
endpoint=body.endpoint,
p256dh=body.keys.p256dh,
auth=body.keys.auth,
user_agent=body.user_agent,
)
except ValueError as e:
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=str(e))
@router.post(
"/unsubscribe",
summary="Remove a push subscription",
status_code=HTTP_204_NO_CONTENT,
dependencies=[Security(requires_user)],
)
async def unsubscribe_push(
user_id: Annotated[str, Security(get_user_id)],
body: PushUnsubscribeRequest,
) -> None:
await delete_push_subscription(user_id, body.endpoint)

View File

@@ -0,0 +1,240 @@
"""Tests for push notification routes."""
from unittest.mock import AsyncMock, MagicMock
import fastapi
import fastapi.testclient
import pytest
from backend.api.features.push.routes import router
app = fastapi.FastAPI()
app.include_router(router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def test_get_vapid_public_key(mocker):
mock_settings = MagicMock()
mock_settings.secrets.vapid_public_key = "test-vapid-public-key-base64url"
mocker.patch(
"backend.api.features.push.routes._settings",
mock_settings,
)
response = client.get("/vapid-key")
assert response.status_code == 200
data = response.json()
assert data["public_key"] == "test-vapid-public-key-base64url"
def test_get_vapid_public_key_empty(mocker):
mock_settings = MagicMock()
mock_settings.secrets.vapid_public_key = ""
mocker.patch(
"backend.api.features.push.routes._settings",
mock_settings,
)
response = client.get("/vapid-key")
assert response.status_code == 200
data = response.json()
assert data["public_key"] == ""
def test_subscribe_push(mocker, test_user_id):
mock_upsert = mocker.patch(
"backend.api.features.push.routes.upsert_push_subscription",
new_callable=AsyncMock,
)
response = client.post(
"/subscribe",
json={
"endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
"keys": {
"p256dh": "test-p256dh-key",
"auth": "test-auth-key",
},
"user_agent": "Mozilla/5.0 Test",
},
)
assert response.status_code == 204
mock_upsert.assert_awaited_once_with(
user_id=test_user_id,
endpoint="https://fcm.googleapis.com/fcm/send/abc123",
p256dh="test-p256dh-key",
auth="test-auth-key",
user_agent="Mozilla/5.0 Test",
)
def test_subscribe_push_without_user_agent(mocker, test_user_id):
mock_upsert = mocker.patch(
"backend.api.features.push.routes.upsert_push_subscription",
new_callable=AsyncMock,
)
response = client.post(
"/subscribe",
json={
"endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
"keys": {
"p256dh": "test-p256dh-key",
"auth": "test-auth-key",
},
},
)
assert response.status_code == 204
mock_upsert.assert_awaited_once_with(
user_id=test_user_id,
endpoint="https://fcm.googleapis.com/fcm/send/abc123",
p256dh="test-p256dh-key",
auth="test-auth-key",
user_agent=None,
)
def test_subscribe_push_missing_keys():
response = client.post(
"/subscribe",
json={
"endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
},
)
assert response.status_code == 422
def test_subscribe_push_missing_endpoint():
response = client.post(
"/subscribe",
json={
"keys": {
"p256dh": "test-p256dh-key",
"auth": "test-auth-key",
},
},
)
assert response.status_code == 422
def test_subscribe_push_rejects_empty_crypto_keys():
response = client.post(
"/subscribe",
json={
"endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
"keys": {"p256dh": "", "auth": ""},
},
)
assert response.status_code == 422
def test_subscribe_push_rejects_oversized_endpoint():
response = client.post(
"/subscribe",
json={
"endpoint": "https://fcm.googleapis.com/fcm/send/" + "x" * 3000,
"keys": {"p256dh": "k", "auth": "a"},
},
)
assert response.status_code == 422
def test_unsubscribe_push(mocker, test_user_id):
mock_delete = mocker.patch(
"backend.api.features.push.routes.delete_push_subscription",
new_callable=AsyncMock,
)
response = client.post(
"/unsubscribe",
json={
"endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
},
)
assert response.status_code == 204
mock_delete.assert_awaited_once_with(
test_user_id,
"https://fcm.googleapis.com/fcm/send/abc123",
)
def test_unsubscribe_push_missing_endpoint():
response = client.post(
"/unsubscribe",
json={},
)
assert response.status_code == 422
@pytest.mark.parametrize(
"untrusted_endpoint",
[
"https://localhost/evil",
"https://127.0.0.1/evil",
"https://169.254.169.254/latest/meta-data/",
"https://internal-service.local/api",
"https://attacker.example.com/push",
"http://fcm.googleapis.com/fcm/send/abc",
"file:///etc/passwd",
],
)
def test_subscribe_push_rejects_untrusted_endpoints(mocker, untrusted_endpoint):
mock_upsert = mocker.patch(
"backend.api.features.push.routes.upsert_push_subscription",
new_callable=AsyncMock,
)
response = client.post(
"/subscribe",
json={
"endpoint": untrusted_endpoint,
"keys": {
"p256dh": "test-p256dh-key",
"auth": "test-auth-key",
},
},
)
assert response.status_code == 400
mock_upsert.assert_not_awaited()
def test_subscribe_push_surfaces_cap_as_400(mocker):
mocker.patch(
"backend.api.features.push.routes.upsert_push_subscription",
new_callable=AsyncMock,
side_effect=ValueError("Subscription limit of 20 per user reached"),
)
response = client.post(
"/subscribe",
json={
"endpoint": "https://fcm.googleapis.com/fcm/send/abc123",
"keys": {
"p256dh": "test-p256dh-key",
"auth": "test-auth-key",
},
},
)
assert response.status_code == 400
assert "Subscription limit" in response.json()["detail"]

View File

@@ -490,6 +490,9 @@ async def get_store_creators(
# Build where clause with sanitized inputs
where = {}
# Only return creators with approved agents
where["num_agents"] = {"gt": 0}
if featured:
where["is_featured"] = featured

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from unittest.mock import AsyncMock
import prisma.enums
import prisma.errors
@@ -50,8 +51,8 @@ async def test_get_store_agents(mocker):
# Mock prisma calls
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
mock_store_agent.return_value.find_many = mocker.AsyncMock(return_value=mock_agents)
mock_store_agent.return_value.count = mocker.AsyncMock(return_value=1)
mock_store_agent.return_value.find_many = AsyncMock(return_value=mock_agents)
mock_store_agent.return_value.count = AsyncMock(return_value=1)
# Call function
result = await db.get_store_agents()
@@ -94,7 +95,7 @@ async def test_get_store_agent_details(mocker):
# Mock StoreAgent prisma call
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
mock_store_agent.return_value.find_first = AsyncMock(return_value=mock_agent)
# Call function
result = await db.get_store_agent_details("creator", "test-agent")
@@ -133,7 +134,7 @@ async def test_get_store_creator(mocker):
# Mock prisma call
mock_creator = mocker.patch("prisma.models.Creator.prisma")
mock_creator.return_value.find_unique = mocker.AsyncMock()
mock_creator.return_value.find_unique = AsyncMock()
# Configure the mock to return values that will pass validation
mock_creator.return_value.find_unique.return_value = mock_creator_data
@@ -236,23 +237,23 @@ async def test_create_store_submission(mocker):
# Mock prisma calls
mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
mock_agent_graph.return_value.find_first = AsyncMock(return_value=mock_agent)
# Mock transaction context manager
mock_tx = mocker.MagicMock()
mocker.patch(
"backend.api.features.store.db.transaction",
return_value=mocker.AsyncMock(
__aenter__=mocker.AsyncMock(return_value=mock_tx),
__aexit__=mocker.AsyncMock(return_value=False),
return_value=AsyncMock(
__aenter__=AsyncMock(return_value=mock_tx),
__aexit__=AsyncMock(return_value=False),
),
)
mock_sl = mocker.patch("prisma.models.StoreListing.prisma")
mock_sl.return_value.find_unique = mocker.AsyncMock(return_value=None)
mock_sl.return_value.find_unique = AsyncMock(return_value=None)
mock_slv = mocker.patch("prisma.models.StoreListingVersion.prisma")
mock_slv.return_value.create = mocker.AsyncMock(return_value=mock_version)
mock_slv.return_value.create = AsyncMock(return_value=mock_version)
# Call function
result = await db.create_store_submission(
@@ -292,10 +293,8 @@ async def test_update_profile(mocker):
# Mock prisma calls
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
mock_profile_db.return_value.find_first = mocker.AsyncMock(
return_value=mock_profile
)
mock_profile_db.return_value.update = mocker.AsyncMock(return_value=mock_profile)
mock_profile_db.return_value.find_first = AsyncMock(return_value=mock_profile)
mock_profile_db.return_value.update = AsyncMock(return_value=mock_profile)
# Test data
profile = Profile(
@@ -336,9 +335,7 @@ async def test_get_user_profile(mocker):
# Mock prisma calls
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
mock_profile_db.return_value.find_first = mocker.AsyncMock(
return_value=mock_profile
)
mock_profile_db.return_value.find_first = AsyncMock(return_value=mock_profile)
# Call function
result = await db.get_user_profile("user-id")
@@ -396,3 +393,38 @@ async def test_get_store_agents_search_category_array_injection():
# Verify the query executed without error
# Category should be parameterized, preventing SQL injection
assert isinstance(result.agents, list)
@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_creators_only_returns_approved(mocker):
mock_creators = [
prisma.models.Creator(
name="Creator One",
username="creator1",
description="desc",
links=["link1"],
avatar_url="avatar.jpg",
num_agents=1,
agent_rating=4.5,
agent_runs=10,
top_categories=["test"],
is_featured=False,
)
]
mock_creator = mocker.patch("prisma.models.Creator.prisma")
mock_creator.return_value.find_many = AsyncMock(return_value=mock_creators)
mock_creator.return_value.count = AsyncMock(return_value=1)
result = await db.get_store_creators()
assert len(result.creators) == 1
assert result.creators[0].username == "creator1"
mock_creator.return_value.find_many.assert_called_once()
mock_creator.return_value.count.assert_called_once()
_, find_kwargs = mock_creator.return_value.find_many.call_args
_, count_kwargs = mock_creator.return_value.count.call_args
assert find_kwargs["where"]["num_agents"] == {"gt": 0}
assert count_kwargs["where"]["num_agents"] == {"gt": 0}

View File

@@ -245,11 +245,12 @@ def test_get_subscription_status_tier_multipliers_ld_override(
assert "BUSINESS" not in data["tier_multipliers"]
def test_get_subscription_status_defaults_to_basic(
def test_get_subscription_status_defaults_to_no_tier(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""When all LD price IDs are unset, tier_costs is empty and the caller sees cost=0."""
"""When user has no subscription_tier, defaults to NO_TIER (the explicit
no-active-subscription state)."""
mock_user = Mock()
mock_user.subscription_tier = None
@@ -273,7 +274,7 @@ def test_get_subscription_status_defaults_to_basic(
assert response.status_code == 200
data = response.json()
assert data["tier"] == SubscriptionTier.BASIC.value
assert data["tier"] == SubscriptionTier.NO_TIER.value
assert data["monthly_cost"] == 0
assert data["tier_costs"] == {}
assert data["proration_credit_cents"] == 0
@@ -326,11 +327,11 @@ def test_get_subscription_status_stripe_error_falls_back_to_zero(
assert data["tier_costs"]["PRO"] == 0
def test_update_subscription_tier_basic_no_payment(
def test_update_subscription_tier_no_tier_no_payment(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""POST /credits/subscription to BASIC tier when payment disabled skips Stripe."""
"""POST /credits/subscription to NO_TIER (cancel) when payment disabled skips Stripe."""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.PRO
@@ -351,7 +352,7 @@ def test_update_subscription_tier_basic_no_payment(
new_callable=AsyncMock,
)
response = client.post("/credits/subscription", json={"tier": "BASIC"})
response = client.post("/credits/subscription", json={"tier": "NO_TIER"})
assert response.status_code == 200
assert response.json()["url"] == ""
@@ -404,12 +405,109 @@ def test_update_subscription_tier_paid_requires_urls(
"backend.api.features.v1.is_feature_enabled",
side_effect=mock_feature_enabled,
)
mocker.patch(
"backend.api.features.v1.modify_stripe_subscription_for_tier",
new_callable=AsyncMock,
return_value=False,
)
response = client.post("/credits/subscription", json={"tier": "PRO"})
assert response.status_code == 422
def test_update_subscription_tier_currency_mismatch_returns_422(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""Stripe rejects a SubscriptionSchedule whose phases mix currencies (e.g.
GBP-checkout sub trying to schedule a USD-only target Price). The handler
must convert that into a specific 422 instead of the generic 502 so the
caller can tell the difference between a currency-config bug and a Stripe
outage."""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.MAX
async def mock_feature_enabled(*args, **kwargs):
return True
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
mocker.patch(
"backend.api.features.v1.is_feature_enabled",
side_effect=mock_feature_enabled,
)
mocker.patch(
"backend.api.features.v1.modify_stripe_subscription_for_tier",
side_effect=stripe.InvalidRequestError(
"The price specified only supports `usd`. This doesn't match the"
" expected currency: `gbp`.",
param="phases",
),
)
response = client.post(
"/credits/subscription",
json={
"tier": "PRO",
"success_url": f"{TEST_FRONTEND_ORIGIN}/success",
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
},
)
assert response.status_code == 422
detail = response.json()["detail"]
assert "billing currency" in detail.lower()
assert "contact support" in detail.lower()
def test_update_subscription_tier_non_currency_invalid_request_returns_502(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""Locks the contract that *only* currency-mismatch InvalidRequestErrors
translate to 422 — every other Stripe InvalidRequestError must still
surface as the generic 502 so that widening the conditional later is
caught by the suite."""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.MAX
async def mock_feature_enabled(*args, **kwargs):
return True
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
return_value=mock_user,
)
mocker.patch(
"backend.api.features.v1.is_feature_enabled",
side_effect=mock_feature_enabled,
)
mocker.patch(
"backend.api.features.v1.modify_stripe_subscription_for_tier",
side_effect=stripe.InvalidRequestError(
"No such price: 'price_does_not_exist'",
param="items[0][price]",
),
)
response = client.post(
"/credits/subscription",
json={
"tier": "PRO",
"success_url": f"{TEST_FRONTEND_ORIGIN}/success",
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
},
)
assert response.status_code == 502
assert "billing currency" not in response.json()["detail"].lower()
def test_update_subscription_tier_creates_checkout(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
@@ -430,6 +528,11 @@ def test_update_subscription_tier_creates_checkout(
"backend.api.features.v1.is_feature_enabled",
side_effect=mock_feature_enabled,
)
mocker.patch(
"backend.api.features.v1.modify_stripe_subscription_for_tier",
new_callable=AsyncMock,
return_value=False,
)
mocker.patch(
"backend.api.features.v1.create_subscription_checkout",
new_callable=AsyncMock,
@@ -469,6 +572,11 @@ def test_update_subscription_tier_rejects_open_redirect(
"backend.api.features.v1.is_feature_enabled",
side_effect=mock_feature_enabled,
)
mocker.patch(
"backend.api.features.v1.modify_stripe_subscription_for_tier",
new_callable=AsyncMock,
return_value=False,
)
checkout_mock = mocker.patch(
"backend.api.features.v1.create_subscription_checkout",
new_callable=AsyncMock,
@@ -649,14 +757,14 @@ def test_update_subscription_tier_same_tier_stripe_error_returns_502(
assert "contact support" in response.json()["detail"].lower()
def test_update_subscription_tier_basic_with_payment_schedules_cancel_and_does_not_update_db(
def test_update_subscription_tier_no_tier_with_payment_schedules_cancel_and_does_not_update_db(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""Downgrading to BASIC schedules Stripe cancellation at period end.
"""Cancelling to NO_TIER schedules Stripe cancellation at period end.
The DB tier must NOT be updated immediately — the customer.subscription.deleted
webhook fires at period end and downgrades to BASIC then.
webhook fires at period end and downgrades to NO_TIER then.
"""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.PRO
@@ -682,18 +790,18 @@ def test_update_subscription_tier_basic_with_payment_schedules_cancel_and_does_n
side_effect=mock_feature_enabled,
)
response = client.post("/credits/subscription", json={"tier": "BASIC"})
response = client.post("/credits/subscription", json={"tier": "NO_TIER"})
assert response.status_code == 200
mock_cancel.assert_awaited_once()
mock_set_tier.assert_not_awaited()
def test_update_subscription_tier_basic_cancel_failure_returns_502(
def test_update_subscription_tier_no_tier_cancel_failure_returns_502(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""Downgrading to BASIC returns 502 with a generic error (no Stripe detail leakage)."""
"""Cancelling to NO_TIER returns 502 with a generic error (no Stripe detail leakage)."""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.PRO
@@ -716,7 +824,7 @@ def test_update_subscription_tier_basic_cancel_failure_returns_502(
side_effect=mock_feature_enabled,
)
response = client.post("/credits/subscription", json={"tier": "BASIC"})
response = client.post("/credits/subscription", json={"tier": "NO_TIER"})
assert response.status_code == 502
detail = response.json()["detail"]
@@ -921,29 +1029,20 @@ def test_update_subscription_tier_max_checkout(
checkout_mock.assert_not_awaited()
def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly(
def test_update_subscription_tier_no_active_sub_falls_through_to_checkout(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""Admin-granted paid tier users are NOT sent to Stripe checkout for paid→paid changes.
"""Any tier change from a user with no active Stripe sub goes through Checkout.
When modify_stripe_subscription_for_tier returns False (no Stripe subscription
found — admin-granted tier), the endpoint must update the DB tier directly and
return 200 with url="", rather than falling through to Checkout Session creation.
Admin-granted users (no Stripe sub yet) and never-paid users follow the
exact same path: modify returns False → Checkout to set up payment. The
endpoint has no admin-specific branch — admin tier grants happen out-of-band
via the admin portal, not this user-facing route.
"""
mock_user = Mock()
mock_user.subscription_tier = SubscriptionTier.PRO
async def price_id_with_business(tier: SubscriptionTier) -> str | None:
return {
**_DEFAULT_TIER_PRICES,
SubscriptionTier.BUSINESS: "price_business",
}.get(tier)
mocker.patch(
"backend.api.features.v1.get_subscription_price_id",
side_effect=price_id_with_business,
)
mocker.patch(
"backend.api.features.v1.get_user_by_id",
new_callable=AsyncMock,
@@ -954,7 +1053,6 @@ def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly
new_callable=AsyncMock,
return_value=True,
)
# Return False = no Stripe subscription (admin-granted tier)
modify_mock = mocker.patch(
"backend.api.features.v1.modify_stripe_subscription_for_tier",
new_callable=AsyncMock,
@@ -967,23 +1065,24 @@ def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly
checkout_mock = mocker.patch(
"backend.api.features.v1.create_subscription_checkout",
new_callable=AsyncMock,
return_value="https://checkout.stripe.com/pay/cs_test_no_sub",
)
response = client.post(
"/credits/subscription",
json={
"tier": "BUSINESS",
"tier": "MAX",
"success_url": f"{TEST_FRONTEND_ORIGIN}/success",
"cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel",
},
)
assert response.status_code == 200
assert response.json()["url"] == ""
modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BUSINESS)
# DB tier updated directly — no Stripe Checkout Session created
set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BUSINESS)
checkout_mock.assert_not_awaited()
assert response.json()["url"] == "https://checkout.stripe.com/pay/cs_test_no_sub"
modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.MAX)
# No DB-flip — payment must be collected via Checkout regardless of direction.
set_tier_mock.assert_not_awaited()
checkout_mock.assert_awaited_once()
def test_update_subscription_tier_priced_basic_no_sub_falls_through_to_checkout(
@@ -1154,14 +1253,14 @@ def test_update_subscription_tier_paid_to_paid_stripe_error_returns_502(
assert response.status_code == 502
def test_update_subscription_tier_basic_no_stripe_subscription(
def test_update_subscription_tier_no_tier_no_stripe_subscription(
client: fastapi.testclient.TestClient,
mocker: pytest_mock.MockFixture,
) -> None:
"""Downgrading to BASIC when no Stripe subscription exists updates DB tier directly.
"""Cancelling to NO_TIER when no Stripe subscription exists updates DB tier directly.
Admin-granted paid tiers have no associated Stripe subscription. When such a
user requests a self-service downgrade, cancel_stripe_subscription returns False
user requests a self-service cancel, cancel_stripe_subscription returns False
(nothing to cancel), so the endpoint must immediately call set_subscription_tier
rather than waiting for a webhook that will never arrive.
"""
@@ -1189,13 +1288,13 @@ def test_update_subscription_tier_basic_no_stripe_subscription(
new_callable=AsyncMock,
)
response = client.post("/credits/subscription", json={"tier": "BASIC"})
response = client.post("/credits/subscription", json={"tier": "NO_TIER"})
assert response.status_code == 200
assert response.json()["url"] == ""
cancel_mock.assert_awaited_once_with(TEST_USER_ID)
# DB tier must be updated immediately — no webhook will fire for a missing sub
set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BASIC)
set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.NO_TIER)
def test_get_subscription_status_includes_pending_tier(

View File

@@ -57,12 +57,14 @@ from backend.data.credit import (
UserCredit,
cancel_stripe_subscription,
create_subscription_checkout,
get_active_subscription_period_end,
get_auto_top_up,
get_pending_subscription_change,
get_proration_credit_cents,
get_subscription_price_id,
get_user_credit_model,
handle_subscription_payment_failure,
handle_subscription_payment_success,
modify_stripe_subscription_for_tier,
release_pending_subscription_schedule,
set_auto_top_up,
@@ -700,13 +702,13 @@ async def get_user_auto_top_up(
class SubscriptionTierRequest(BaseModel):
tier: Literal["BASIC", "PRO", "MAX", "BUSINESS"]
tier: Literal["NO_TIER", "BASIC", "PRO", "MAX", "BUSINESS"]
success_url: str = ""
cancel_url: str = ""
class SubscriptionStatusResponse(BaseModel):
tier: Literal["BASIC", "PRO", "MAX", "BUSINESS", "ENTERPRISE"]
tier: Literal["NO_TIER", "BASIC", "PRO", "MAX", "BUSINESS", "ENTERPRISE"]
monthly_cost: int # amount in cents (Stripe convention)
tier_costs: dict[str, int] # tier name -> amount in cents
tier_multipliers: dict[str, float] = Field(
@@ -719,7 +721,23 @@ class SubscriptionStatusResponse(BaseModel):
),
)
proration_credit_cents: int # unused portion of current sub to convert on upgrade
pending_tier: Optional[Literal["BASIC", "PRO", "MAX", "BUSINESS"]] = None
has_active_stripe_subscription: bool = Field(
default=False,
description=(
"True when the user has an active/trialing Stripe subscription. The"
" frontend uses this to branch upgrade UX: modify-in-place + saved-card"
" auto-charge when True, redirect to Stripe Checkout when False."
),
)
current_period_end: Optional[int] = Field(
default=None,
description=(
"Unix timestamp of the active subscription's current_period_end. Used"
" to show the date Stripe will issue the next invoice (with prorated"
" upgrade charges, if any). None when no active sub."
),
)
pending_tier: Optional[Literal["NO_TIER", "BASIC", "PRO", "MAX", "BUSINESS"]] = None
pending_tier_effective_at: Optional[datetime] = None
url: str = Field(
default="",
@@ -804,8 +822,11 @@ async def get_subscription_status(
user_id: Annotated[str, Security(get_user_id)],
) -> SubscriptionStatusResponse:
user = await get_user_by_id(user_id)
tier = user.subscription_tier or SubscriptionTier.BASIC
tier = user.subscription_tier or SubscriptionTier.NO_TIER
# Tiers that *can* have a Stripe price configured (and therefore appear
# in the tier picker if the LD flag exposes a price-id). NO_TIER is not
# priceable — it's the implicit "no active subscription" state.
priceable_tiers = [
SubscriptionTier.BASIC,
SubscriptionTier.PRO,
@@ -839,7 +860,10 @@ async def get_subscription_status(
}
current_monthly_cost = tier_costs.get(tier.value, 0)
proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost)
proration_credit, current_period_end = await asyncio.gather(
get_proration_credit_cents(user_id, current_monthly_cost),
get_active_subscription_period_end(user_id),
)
try:
pending = await get_pending_subscription_change(user_id)
@@ -861,10 +885,13 @@ async def get_subscription_status(
tier_costs=tier_costs,
tier_multipliers=tier_multipliers,
proration_credit_cents=proration_credit,
has_active_stripe_subscription=current_period_end is not None,
current_period_end=current_period_end,
)
if pending is not None:
pending_tier_enum, pending_effective_at = pending
if pending_tier_enum in (
SubscriptionTier.NO_TIER,
SubscriptionTier.BASIC,
SubscriptionTier.PRO,
SubscriptionTier.MAX,
@@ -892,7 +919,7 @@ async def update_subscription_tier(
# ENTERPRISE tier is admin-managed — block self-service changes from ENTERPRISE users.
user = await get_user_by_id(user_id)
if (
user.subscription_tier or SubscriptionTier.BASIC
user.subscription_tier or SubscriptionTier.NO_TIER
) == SubscriptionTier.ENTERPRISE:
raise HTTPException(
status_code=403,
@@ -904,7 +931,7 @@ async def update_subscription_tier(
# collapsed behaviour that replaces the old /credits/subscription/cancel-pending
# route. Safe when no pending change exists: release_pending_subscription_schedule
# returns False and we simply return the current status.
if (user.subscription_tier or SubscriptionTier.BASIC) == tier:
if (user.subscription_tier or SubscriptionTier.NO_TIER) == tier:
try:
await release_pending_subscription_schedule(user_id)
except stripe.StripeError as e:
@@ -926,18 +953,14 @@ async def update_subscription_tier(
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
)
current_tier = user.subscription_tier or SubscriptionTier.BASIC
target_price_id, current_tier_price_id = await asyncio.gather(
get_subscription_price_id(tier),
get_subscription_price_id(current_tier),
)
target_price_id = await get_subscription_price_id(tier)
# Legacy cancel: target BASIC + stripe-price-id-basic unset. Schedule Stripe
# cancellation at period end; cancel_at_period_end=True lets the webhook flip
# the DB tier. No active sub (admin-granted) or payment disabled → DB flip.
# Once stripe-price-id-basic is configured, BASIC becomes a real sub and falls
# through to the modify/checkout flow below.
if tier == SubscriptionTier.BASIC and target_price_id is None:
# Cancel: target NO_TIER. Schedule Stripe cancellation at period end;
# cancel_at_period_end=True lets the webhook flip the DB tier. No active
# sub (admin-granted or never-paid) or payment disabled → DB flip.
# NO_TIER is never priceable, so this branch always fires for cancel
# requests regardless of LD config.
if tier == SubscriptionTier.NO_TIER:
if payment_enabled:
try:
had_subscription = await cancel_stripe_subscription(user_id)
@@ -973,32 +996,53 @@ async def update_subscription_tier(
detail=f"Subscription not available for tier {tier.value}",
)
# User has an active Stripe subscription (current tier has an LD price):
# modify it in-place. modify_stripe_subscription_for_tier returns False when no
# active sub exists — that's only a "DB-only flip is OK" signal for admin-granted
# paid tiers (PRO/BUSINESS with no Stripe record). Priced-BASIC users without a
# sub must still go through Checkout so they set up payment.
if current_tier_price_id is not None:
try:
modified = await modify_stripe_subscription_for_tier(user_id, tier)
if modified:
return await get_subscription_status(user_id)
if current_tier != SubscriptionTier.BASIC:
await set_subscription_tier(user_id, tier)
return await get_subscription_status(user_id)
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
except stripe.StripeError as e:
logger.exception(
"Stripe error modifying subscription for user %s: %s", user_id, e
# Modify in place if there's a sub; else fall through to Checkout below.
try:
modified = await modify_stripe_subscription_for_tier(user_id, tier)
if modified:
return await get_subscription_status(user_id)
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
except stripe.InvalidRequestError as e:
# Stripe rejects schedule modify when phases mix currencies, e.g. the
# active sub was checked out in GBP but the target tier's Price is
# USD-only. 502 reads as outage; surface a 422 with a specific message
# so the user/admin can see what to fix in Stripe.
msg = str(e)
if "currency" in msg.lower():
logger.warning(
"Currency mismatch on tier change for user %s: %s", user_id, msg
)
raise HTTPException(
status_code=502,
status_code=422,
detail=(
"Unable to update your subscription right now. "
"Please try again or contact support."
"Tier change unavailable for your current billing currency."
" Please contact support — the target tier needs to be"
" configured for your currency in Stripe before this"
" change can go through."
),
)
logger.exception(
"Stripe error modifying subscription for user %s: %s", user_id, e
)
raise HTTPException(
status_code=502,
detail=(
"Unable to update your subscription right now. "
"Please try again or contact support."
),
)
except stripe.StripeError as e:
logger.exception(
"Stripe error modifying subscription for user %s: %s", user_id, e
)
raise HTTPException(
status_code=502,
detail=(
"Unable to update your subscription right now. "
"Please try again or contact support."
),
)
# No active Stripe subscription → create Stripe Checkout Session.
if not request.success_url or not request.cancel_url:
@@ -1134,6 +1178,9 @@ async def stripe_webhook(request: Request):
):
await sync_subscription_schedule_from_stripe(data_object)
if event_type == "invoice.payment_succeeded":
await handle_subscription_payment_success(data_object)
if event_type == "invoice.payment_failed":
await handle_subscription_payment_failure(data_object)

View File

@@ -34,6 +34,7 @@ import backend.api.features.oauth
import backend.api.features.otto.routes
import backend.api.features.platform_linking.routes
import backend.api.features.postmark.postmark
import backend.api.features.push.routes as push_routes
import backend.api.features.store.model
import backend.api.features.store.routes
import backend.api.features.v1
@@ -41,6 +42,7 @@ import backend.api.features.workspace.routes as workspace_routes
import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.redis_client
import backend.data.user
import backend.integrations.webhooks.utils
import backend.util.service
@@ -95,6 +97,8 @@ async def lifespan_context(app: fastapi.FastAPI):
verify_auth_settings()
await backend.data.db.connect()
# Eager connect to fail-fast if Redis is unreachable.
await backend.data.redis_client.get_redis_async()
# Configure thread pool for FastAPI sync operation performance
# CRITICAL: FastAPI automatically runs ALL sync functions in this thread pool:
@@ -146,7 +150,18 @@ async def lifespan_context(app: fastapi.FastAPI):
except Exception as e:
logger.warning(f"Error shutting down workspace storage: {e}")
await backend.data.db.disconnect()
# Each cleanup is wrapped so one failure doesn't block the rest. The
# Redis close in particular silences asyncio's "Unclosed ClusterNode"
# GC warning at interpreter shutdown.
try:
await backend.data.redis_client.disconnect_async()
except Exception:
logger.warning("redis_client.disconnect_async failed", exc_info=True)
try:
await backend.data.db.disconnect()
except Exception:
logger.warning("db.disconnect failed", exc_info=True)
def custom_generate_unique_id(route: APIRoute):
@@ -379,6 +394,11 @@ app.include_router(
tags=["oauth"],
prefix="/api/oauth",
)
app.include_router(
push_routes.router,
tags=["push"],
prefix="/api/push",
)
app.include_router(
backend.api.features.platform_linking.routes.router,
tags=["platform-linking"],

View File

@@ -1,4 +1,3 @@
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import Protocol
@@ -17,14 +16,12 @@ from backend.api.model import (
WSSubscribeGraphExecutionsRequest,
)
from backend.api.utils.cors import build_cors_params
from backend.data.execution import AsyncRedisExecutionEventBus
from backend.data.notification_bus import AsyncRedisNotificationEventBus
from backend.data import db, redis_client
from backend.data.user import DEFAULT_USER_ID
from backend.monitoring.instrumentation import (
instrument_fastapi,
update_websocket_connections,
)
from backend.util.retry import continuous_retry
from backend.util.service import AppProcess
from backend.util.settings import AppEnvironment, Config, Settings
@@ -34,10 +31,24 @@ settings = Settings()
@asynccontextmanager
async def lifespan(app: FastAPI):
manager = get_connection_manager()
fut = asyncio.create_task(event_broadcaster(manager))
fut.add_done_callback(lambda _: logger.info("Event broadcaster stopped"))
yield
# Prisma is needed to resolve graph_id from graph_exec_id on subscribe.
await db.connect()
# Eager connect to fail-fast if Redis is unreachable.
await redis_client.get_redis_async()
try:
yield
finally:
# Each cleanup is wrapped so one failure doesn't block the rest. The
# Redis close silences asyncio's "Unclosed ClusterNode" GC warning at
# interpreter shutdown.
try:
await redis_client.disconnect_async()
except Exception:
logger.warning("redis_client.disconnect_async failed", exc_info=True)
try:
await db.disconnect()
except Exception:
logger.warning("db.disconnect failed", exc_info=True)
docs_url = "/docs" if settings.config.app_env == AppEnvironment.LOCAL else None
@@ -61,31 +72,6 @@ def get_connection_manager():
return _connection_manager
@continuous_retry()
async def event_broadcaster(manager: ConnectionManager):
execution_bus = AsyncRedisExecutionEventBus()
notification_bus = AsyncRedisNotificationEventBus()
try:
async def execution_worker():
async for event in execution_bus.listen("*"):
await manager.send_execution_update(event)
async def notification_worker():
async for notification in notification_bus.listen("*"):
await manager.send_notification(
user_id=notification.user_id,
payload=notification.payload,
)
await asyncio.gather(execution_worker(), notification_worker())
finally:
# Ensure PubSub connections are closed on any exit to prevent leaks
await execution_bus.close()
await notification_bus.close()
async def authenticate_websocket(websocket: WebSocket) -> str:
if not settings.config.enable_auth:
return DEFAULT_USER_ID
@@ -297,6 +283,21 @@ async def websocket_router(
).model_dump_json()
)
continue
except ValueError as e:
logger.warning(
"Subscription rejected for user #%s on '%s': %s",
user_id,
message.method.value,
e,
)
await websocket.send_text(
WSMessage(
method=WSMethod.ERROR,
success=False,
error=str(e),
).model_dump_json()
)
continue
except Exception as e:
logger.error(
f"Error while handling '{message.method.value}' message "
@@ -321,9 +322,13 @@ async def websocket_router(
)
except WebSocketDisconnect:
manager.disconnect_socket(websocket, user_id=user_id)
logger.debug("WebSocket client disconnected")
except Exception:
logger.exception(f"Unexpected error in websocket_router for user #{user_id}")
finally:
# Always release subscription pumps + Redis connections, regardless of how
# the loop exited — otherwise non-WebSocketDisconnect failures leak both.
await manager.disconnect_socket(websocket, user_id=user_id)
update_websocket_connections(user_id, -1)

View File

@@ -44,9 +44,12 @@ def test_websocket_server_uses_cors_helper(mocker) -> None:
"backend.api.ws_api.build_cors_params", return_value=cors_params
)
with override_config(
settings, "backend_cors_allow_origins", cors_params["allow_origins"]
), override_config(settings, "app_env", AppEnvironment.LOCAL):
with (
override_config(
settings, "backend_cors_allow_origins", cors_params["allow_origins"]
),
override_config(settings, "app_env", AppEnvironment.LOCAL),
):
WebsocketServer().run()
build_cors.assert_called_once_with(
@@ -65,9 +68,12 @@ def test_websocket_server_uses_cors_helper(mocker) -> None:
def test_websocket_server_blocks_localhost_in_production(mocker) -> None:
mocker.patch("backend.api.ws_api.uvicorn.run")
with override_config(
settings, "backend_cors_allow_origins", ["http://localhost:3000"]
), override_config(settings, "app_env", AppEnvironment.PRODUCTION):
with (
override_config(
settings, "backend_cors_allow_origins", ["http://localhost:3000"]
),
override_config(settings, "app_env", AppEnvironment.PRODUCTION),
):
with pytest.raises(ValueError):
WebsocketServer().run()
@@ -290,7 +296,232 @@ async def test_handle_unsubscribe_missing_data(
message=message,
)
mock_manager._unsubscribe.assert_not_called()
mock_manager.unsubscribe_graph_exec.assert_not_called()
mock_websocket.send_text.assert_called_once()
assert '"method":"error"' in mock_websocket.send_text.call_args[0][0]
assert '"success":false' in mock_websocket.send_text.call_args[0][0]
# ---------- Per-graph subscribe branch ----------
@pytest.mark.asyncio
async def test_handle_subscribe_graph_execs_branch(
mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
"""The SUBSCRIBE_GRAPH_EXECS branch must route to subscribe_graph_execs,
not subscribe_graph_exec — regression guard for the aggregate channel."""
message = WSMessage(
method=WSMethod.SUBSCRIBE_GRAPH_EXECS,
data={"graph_id": "graph-abc"},
)
mock_manager.subscribe_graph_execs.return_value = (
"user-1|graph#graph-abc|executions"
)
await handle_subscribe(
connection_manager=cast(ConnectionManager, mock_manager),
websocket=cast(WebSocket, mock_websocket),
user_id="user-1",
message=message,
)
mock_manager.subscribe_graph_execs.assert_called_once_with(
user_id="user-1",
graph_id="graph-abc",
websocket=mock_websocket,
)
mock_manager.subscribe_graph_exec.assert_not_called()
mock_websocket.send_text.assert_called_once()
assert (
'"method":"subscribe_graph_executions"'
in mock_websocket.send_text.call_args[0][0]
)
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
@pytest.mark.asyncio
async def test_handle_subscribe_rejects_unrelated_method(
mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
"""handle_subscribe must raise for methods that aren't SUBSCRIBE_*."""
import pytest as _pytest
message = WSMessage(
method=WSMethod.HEARTBEAT,
data={"graph_exec_id": "x"},
)
with _pytest.raises(ValueError):
await handle_subscribe(
connection_manager=cast(ConnectionManager, mock_manager),
websocket=cast(WebSocket, mock_websocket),
user_id="user-1",
message=message,
)
# ---------- authenticate_websocket branches ----------
@pytest.mark.asyncio
async def test_authenticate_websocket_missing_token_closes_4001(mocker) -> None:
from backend.api.ws_api import authenticate_websocket
mocker.patch.object(settings.config, "enable_auth", True)
ws = AsyncMock(spec=WebSocket)
ws.query_params = {}
user_id = await authenticate_websocket(ws)
ws.close.assert_awaited_once()
assert ws.close.call_args.kwargs["code"] == 4001
assert user_id == ""
@pytest.mark.asyncio
async def test_authenticate_websocket_invalid_token_closes_4003(mocker) -> None:
from backend.api.ws_api import authenticate_websocket
mocker.patch.object(settings.config, "enable_auth", True)
mocker.patch(
"backend.api.ws_api.parse_jwt_token", side_effect=ValueError("bad token")
)
ws = AsyncMock(spec=WebSocket)
ws.query_params = {"token": "abc"}
user_id = await authenticate_websocket(ws)
ws.close.assert_awaited_once()
assert ws.close.call_args.kwargs["code"] == 4003
assert user_id == ""
@pytest.mark.asyncio
async def test_authenticate_websocket_missing_sub_closes_4002(mocker) -> None:
from backend.api.ws_api import authenticate_websocket
mocker.patch.object(settings.config, "enable_auth", True)
mocker.patch("backend.api.ws_api.parse_jwt_token", return_value={"not_sub": "x"})
ws = AsyncMock(spec=WebSocket)
ws.query_params = {"token": "abc"}
user_id = await authenticate_websocket(ws)
ws.close.assert_awaited_once()
assert ws.close.call_args.kwargs["code"] == 4002
assert user_id == ""
@pytest.mark.asyncio
async def test_authenticate_websocket_happy_path_returns_sub(mocker) -> None:
from backend.api.ws_api import authenticate_websocket
mocker.patch.object(settings.config, "enable_auth", True)
mocker.patch("backend.api.ws_api.parse_jwt_token", return_value={"sub": "user-X"})
ws = AsyncMock(spec=WebSocket)
ws.query_params = {"token": "abc"}
user_id = await authenticate_websocket(ws)
assert user_id == "user-X"
@pytest.mark.asyncio
async def test_authenticate_websocket_auth_disabled_returns_default(mocker) -> None:
from backend.api.ws_api import authenticate_websocket
mocker.patch.object(settings.config, "enable_auth", False)
ws = AsyncMock(spec=WebSocket)
ws.query_params = {}
user_id = await authenticate_websocket(ws)
assert user_id == DEFAULT_USER_ID
# ---------- get_connection_manager singleton ----------
def test_get_connection_manager_singleton() -> None:
"""Repeated calls must return the same ConnectionManager — the WS router
depends on a single process-wide subscription table."""
import backend.api.ws_api as ws_api
ws_api._connection_manager = None
a = ws_api.get_connection_manager()
b = ws_api.get_connection_manager()
assert a is b
assert isinstance(a, ConnectionManager)
# ---------- Lifespan: Prisma connect/disconnect ----------
@pytest.mark.asyncio
async def test_lifespan_connects_and_disconnects_prisma(mocker) -> None:
"""Lifespan must both connect() and disconnect() db — the subscribe path
resolves graph_id via Prisma so a missing connect() is the regression bug."""
from fastapi import FastAPI
from backend.api.ws_api import lifespan
mock_db = mocker.patch("backend.api.ws_api.db")
mock_db.connect = AsyncMock()
mock_db.disconnect = AsyncMock()
dummy_app = FastAPI()
async with lifespan(dummy_app):
mock_db.connect.assert_awaited_once()
mock_db.disconnect.assert_not_called()
mock_db.disconnect.assert_awaited_once()
@pytest.mark.asyncio
async def test_lifespan_still_disconnects_on_exception(mocker) -> None:
"""If the app raises inside the yield, Prisma must still disconnect."""
from fastapi import FastAPI
from backend.api.ws_api import lifespan
mock_db = mocker.patch("backend.api.ws_api.db")
mock_db.connect = AsyncMock()
mock_db.disconnect = AsyncMock()
dummy_app = FastAPI()
class _Boom(Exception):
pass
with pytest.raises(_Boom):
async with lifespan(dummy_app):
raise _Boom()
mock_db.disconnect.assert_awaited_once()
# ---------- Health endpoint ----------
def test_health_endpoint_returns_ok() -> None:
# TestClient triggers lifespan — stub it out so Prisma isn't hit.
from contextlib import asynccontextmanager
from fastapi.testclient import TestClient
import backend.api.ws_api as ws_api
@asynccontextmanager
async def _noop_lifespan(app):
yield
# Replace the app-level lifespan temporarily.
real_router_lifespan = ws_api.app.router.lifespan_context
ws_api.app.router.lifespan_context = _noop_lifespan
try:
with TestClient(ws_api.app) as client:
r = client.get("/")
assert r.status_code == 200
assert r.json() == {"status": "healthy"}
finally:
ws_api.app.router.lifespan_context = real_router_lifespan

View File

@@ -38,6 +38,7 @@ def main(**kwargs):
from backend.api.rest_api import AgentServer
from backend.api.ws_api import WebsocketServer
from backend.copilot.bot.app import CoPilotChatBridge
from backend.copilot.executor.manager import CoPilotExecutor
from backend.data.db_manager import DatabaseManager
from backend.executor import ExecutionManager, Scheduler
@@ -53,6 +54,7 @@ def main(**kwargs):
AgentServer(),
ExecutionManager(),
CoPilotExecutor(),
CoPilotChatBridge(),
**kwargs,
)

View File

@@ -7,6 +7,7 @@ import logging
import uuid
from typing import TYPE_CHECKING, Any
from pydantic import field_validator
from typing_extensions import TypedDict # Needed for Python <3.12 compatibility
from backend.blocks._base import (
@@ -17,6 +18,7 @@ from backend.blocks._base import (
BlockSchemaOutput,
)
from backend.copilot.permissions import (
DISABLED_LEGACY_TOOL_NAMES,
CopilotPermissions,
ToolName,
all_known_tool_names,
@@ -198,6 +200,13 @@ class AutoPilotBlock(Block):
# timeouts internally; wrapping with asyncio.timeout corrupts the
# SDK's internal stream (see service.py CRITICAL comment).
@field_validator("tools", mode="before")
@classmethod
def strip_disabled_legacy_tools(cls, tools: Any) -> Any:
if not isinstance(tools, list):
return tools
return [tool for tool in tools if tool not in DISABLED_LEGACY_TOOL_NAMES]
class Output(BlockSchemaOutput):
"""Output schema for the AutoPilot block."""

View File

@@ -62,6 +62,14 @@ class TestBuildAndValidatePermissions:
with pytest.raises(ValidationError, match="not_a_real_tool"):
_make_input(tools=["not_a_real_tool"])
async def test_disabled_legacy_tool_is_accepted_and_removed(self):
inp = _make_input(tools=["ask_question", "run_block"])
result = await _build_and_validate_permissions(inp)
assert inp.tools == ["run_block"]
assert isinstance(result, CopilotPermissions)
assert result.tools == ["run_block"]
async def test_valid_block_name_accepted(self):
mock_block_cls = MagicMock()
mock_block_cls.return_value.name = "HTTP Request"

View File

@@ -0,0 +1 @@
@README.md

View File

@@ -0,0 +1 @@
@README.md

View File

@@ -0,0 +1,79 @@
# CoPilot Bot
Multi-platform chat bot that bridges AutoPilot to Discord (and later Telegram, Slack, etc.).
## Running
```bash
# As a standalone service
poetry run copilot-bot
# Or auto-start alongside the rest of the platform
poetry run app # starts the bot too if AUTOPILOT_BOT_DISCORD_TOKEN is set
```
## Required environment variables
See `backend/.env.default` for the full list with documentation. Minimum setup:
| Variable | Purpose |
|----------|---------|
| `AUTOPILOT_BOT_DISCORD_TOKEN` | Discord bot token — enables the Discord adapter |
| `FRONTEND_BASE_URL` | Frontend base URL for link confirmation pages (shared with the rest of the backend) |
| `REDIS_HOST` / `REDIS_PORT` | Redis connection used for session caching, thread subscription state, and copilot stream subscriptions (inherited from the shared backend config) |
| `PLATFORMLINKINGMANAGER_HOST` | DNS name of the `PlatformLinkingManager` service pod (cluster-internal RPC) |
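
Of these, `AUTOPILOT_BOT_DISCORD_TOKEN` is the one that actually switches an adapter on. A minimal sketch of how it is read (condensed from `adapters/discord/config.py` and `app.py::_build_adapters`; the real factory appends `DiscordAdapter(api)` when this returns True):

```python
# Minimal sketch of the token gate — not the real factory, which lives in
# app.py::_build_adapters and instantiates DiscordAdapter(api).
from backend.util.settings import Settings


def discord_enabled() -> bool:
    # AUTOPILOT_BOT_DISCORD_TOKEN surfaces here via the shared settings object;
    # when it is empty, the Discord adapter is simply never instantiated and
    # the bridge idles.
    return bool(Settings().secrets.autopilot_bot_discord_token)
```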
## Architecture
```
bot/
├── app.py # CoPilotChatBridge(AppService), adapter factory, outbound @expose RPC
├── config.py # Shared (platform-agnostic) config
├── handler.py # Core logic: routing, linking, batched streaming
├── bot_backend.py # Thin facade over PlatformLinkingManagerClient + stream_registry
├── text.py # Text splitting + batch formatting
├── threads.py # Redis-backed thread subscription tracking
└── adapters/
├── base.py # PlatformAdapter interface + MessageContext
└── discord/
├── adapter.py # Gateway connection, events, sends, thread creation
├── commands.py # Slash commands (/setup, /help, /unlink)
└── config.py # Discord token + platform limits
```
**Locality rule:** anything platform-specific lives under `adapters/<platform>/`.
The only file that names specific platforms is `app.py`, which is the factory
that decides which adapters to instantiate based on which tokens are set.
## How messaging works
1. User mentions the bot in a channel
2. Adapter's `on_message` handler fires, constructs a `MessageContext`, passes
it to the shared `MessageHandler`
3. Handler:
- Checks if the user/server is linked (via `bot_backend`)
- If not linked → sends a "Link Account" button prompt
- If linked → creates a thread (for channels) or uses the existing thread/DM
- Marks the thread as subscribed in Redis (7-day TTL)
- Streams the AutoPilot response back, chunked at the adapter's
`chunk_flush_at` boundary
4. Messages that arrive while a stream is running get batched and sent as a
single follow-up turn once the current stream ends (see the sketch below)
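
For orientation, the same flow condensed into one function (simplified from `handler.py`, whose real entry point is `MessageHandler.handle`; the typing indicator, session-ID caching, error handling, and the per-target batching in `TargetState` are all left out):

```python
# Condensed flow sketch — simplified from handler.py; not the full implementation.
from backend.copilot.bot.adapters.base import MessageContext, PlatformAdapter
from backend.copilot.bot.bot_backend import BotBackend
from backend.copilot.bot.text import split_at_boundary


async def handle_once(
    api: BotBackend, ctx: MessageContext, adapter: PlatformAdapter
) -> None:
    # Steps 1-2: the adapter already built `ctx`; check the link first.
    if ctx.is_dm:
        linked = (await api.resolve_user(ctx.platform, ctx.user_id)).linked
    else:
        linked = (await api.resolve_server(ctx.platform, ctx.server_id or "")).linked
    if not linked:
        return  # the real handler sends a "Link Account" button or a /setup nudge

    # Step 3: channels get a fresh thread; DMs (and subscribed threads) reply in place.
    if ctx.is_dm:
        target_id = ctx.channel_id
    else:
        thread_id = await adapter.create_thread(
            ctx.channel_id, ctx.message_id, f"{ctx.username} × AutoPilot"
        )
        target_id = thread_id or ctx.channel_id

    # Stream the AutoPilot response, flushing whole messages at the adapter's
    # chunk_flush_at boundary so no single send exceeds the platform cap.
    buffer = ""
    async for delta in api.stream_chat(
        platform=ctx.platform, platform_user_id=ctx.user_id, message=ctx.text
    ):
        buffer += delta
        if len(buffer) >= adapter.chunk_flush_at:
            post, buffer = split_at_boundary(buffer, adapter.chunk_flush_at)
            if post:
                await adapter.send_message(target_id, post)
    if buffer.strip():
        await adapter.send_message(target_id, buffer)
```

In the real handler this body runs inside `_enqueue_and_process`, which drains `TargetState.pending` so messages that arrive mid-stream become the next batched turn instead of interleaving with the current one.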
## Adding a new platform
1. Create `adapters/<platform>/` with `adapter.py`, `commands.py` (if the
platform has commands), and `config.py`
2. `adapter.py` subclasses `PlatformAdapter` and implements all its abstract
methods — `max_message_length`, `chunk_flush_at`, `send_message`,
`send_link`, `create_thread`, etc.
3. `config.py` declares the platform's env vars and any platform-specific
numbers (message limits, token name, etc.)
4. Add two lines to `app.py::_build_adapters`:
```python
if <platform>_config.BOT_TOKEN:
adapters.append(<Platform>Adapter(api))
```
The core handler, text utilities, thread tracking, and platform API all stay
untouched.
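
For a concrete picture of step 2, here is a hypothetical skeleton (Telegram is only a placeholder name, and the message limits are assumptions to check against the platform's docs) showing the full method set a new adapter has to cover:

```python
# Hypothetical skeleton — placeholder bodies, not a working Telegram integration.
# The method set mirrors PlatformAdapter in adapters/base.py.
from typing import Optional

from backend.copilot.bot.adapters.base import MessageCallback, PlatformAdapter
from backend.copilot.bot.bot_backend import BotBackend


class TelegramAdapter(PlatformAdapter):
    def __init__(self, api: BotBackend):
        self._api = api
        self._on_message_callback: Optional[MessageCallback] = None

    @property
    def platform_name(self) -> str:
        return "telegram"

    @property
    def max_message_length(self) -> int:
        return 4096  # assumed platform cap — verify against the platform docs

    @property
    def chunk_flush_at(self) -> int:
        return 4000  # keep headroom under max_message_length for the splitter

    def on_message(self, callback: MessageCallback) -> None:
        self._on_message_callback = callback

    async def start(self) -> None: ...   # connect to the platform gateway/API
    async def stop(self) -> None: ...    # disconnect cleanly

    async def send_message(self, channel_id: str, text: str) -> None: ...

    async def send_link(
        self, channel_id: str, text: str, link_label: str, link_url: str
    ) -> None:
        # No native buttons? Fall back to rendering the URL inline.
        await self.send_message(channel_id, f"{text}\n{link_label}: {link_url}")

    async def send_reply(
        self, channel_id: str, text: str, reply_to_message_id: str
    ) -> None: ...

    async def send_ephemeral(self, channel_id: str, user_id: str, text: str) -> None: ...
    async def start_typing(self, channel_id: str) -> None: ...
    async def stop_typing(self, channel_id: str) -> None: ...

    async def create_thread(
        self, channel_id: str, message_id: str, name: str
    ) -> Optional[str]:
        return None  # no thread support → handler falls back to the channel
```

Once the methods are filled in, the two-line `_build_adapters` change above is the only wiring left.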

View File

@@ -0,0 +1,19 @@
"""Entry point for running the CoPilot Chat Bridge service.
Usage:
poetry run copilot-bot
python -m backend.copilot.bot
"""
from backend.app import run_processes
from .app import CoPilotChatBridge
def main():
"""Run the CoPilot Chat Bridge service."""
run_processes(CoPilotChatBridge())
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,110 @@
"""Abstract base for platform adapters.
Each chat platform (Discord, Telegram, Slack, etc.) implements this interface.
The core bot logic in handler.py is platform-agnostic — it only speaks through
these methods.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Awaitable, Callable, Literal, Optional
# Callback signature: (ctx, adapter) -> awaitable None
MessageCallback = Callable[["MessageContext", "PlatformAdapter"], Awaitable[None]]
# Where the message came from:
# - "dm" — 1:1 conversation, reply in-place
# - "channel" — public channel, bot was @mentioned, create a thread to respond
# - "thread" — ongoing thread conversation, reply in-place
ChannelType = Literal["dm", "channel", "thread"]
@dataclass
class MessageContext:
"""Everything the core handler needs to know about an incoming message."""
platform: str
channel_type: ChannelType
server_id: Optional[str]
channel_id: str # DM channel ID / parent channel ID / thread ID
message_id: str # the incoming message itself — used to create threads from it
user_id: str
username: str
text: str # with bot mentions stripped
@property
def is_dm(self) -> bool:
return self.channel_type == "dm"
class PlatformAdapter(ABC):
"""Interface that each chat platform must implement."""
@property
@abstractmethod
def platform_name(self) -> str: ...
@abstractmethod
def on_message(self, callback: MessageCallback) -> None: ...
@abstractmethod
async def start(self) -> None: ...
@abstractmethod
async def stop(self) -> None: ...
@abstractmethod
async def send_message(self, channel_id: str, text: str) -> None: ...
@abstractmethod
async def send_link(
self, channel_id: str, text: str, link_label: str, link_url: str
) -> None:
"""Send a message with a clickable link presented as a button/CTA.
Platforms without native button support should fall back to rendering
the URL inline in the text.
"""
...
@abstractmethod
async def send_reply(
self, channel_id: str, text: str, reply_to_message_id: str
) -> None: ...
@abstractmethod
async def send_ephemeral(
self, channel_id: str, user_id: str, text: str
) -> None: ...
@abstractmethod
async def start_typing(self, channel_id: str) -> None: ...
@abstractmethod
async def stop_typing(self, channel_id: str) -> None: ...
@abstractmethod
async def create_thread(
self, channel_id: str, message_id: str, name: str
) -> Optional[str]:
"""Create a thread from a message. Returns the thread ID, or None if
the platform doesn't support threads or creation failed.
"""
...
@property
@abstractmethod
def max_message_length(self) -> int:
"""Hard platform cap on a single message's content length."""
...
@property
@abstractmethod
def chunk_flush_at(self) -> int:
"""Flush the streaming buffer once it reaches this length.
Should be slightly under max_message_length to leave headroom for
any trailing content that the splitter might pull into the current
chunk.
"""
...

View File

@@ -0,0 +1,209 @@
"""Discord adapter — connects to the Discord Gateway and forwards messages.
Platform-specific machinery only: Gateway connection, message event handling,
thread creation, typing, button rendering. All platform-agnostic logic lives
in the core handler. Slash commands live in commands.py.
"""
import logging
from typing import Optional
import discord
from discord import app_commands
from backend.copilot.bot.bot_backend import BotBackend
from ..base import ChannelType, MessageCallback, MessageContext, PlatformAdapter
from . import commands, config
logger = logging.getLogger(__name__)
class DiscordAdapter(PlatformAdapter):
def __init__(self, api: BotBackend):
intents = discord.Intents.default()
intents.message_content = True
# AutoPilot output is untrusted w.r.t. mentions — suppress @everyone,
# role, and user pings the LLM might produce. Client-level default
# applies to every send() + reply() below.
self._client = discord.Client(
intents=intents,
allowed_mentions=discord.AllowedMentions.none(),
)
self._tree = app_commands.CommandTree(self._client)
self._api = api
self._on_message_callback: Optional[MessageCallback] = None
self._commands_synced = False
self._register_events()
commands.register(self._tree, self._api)
@property
def platform_name(self) -> str:
return "discord"
@property
def max_message_length(self) -> int:
return config.MAX_MESSAGE_LENGTH
@property
def chunk_flush_at(self) -> int:
return config.CHUNK_FLUSH_AT
def on_message(self, callback: MessageCallback) -> None:
self._on_message_callback = callback
async def start(self) -> None:
await self._client.start(config.get_bot_token())
async def stop(self) -> None:
if not self._client.is_closed():
await self._client.close()
async def _resolve_channel(self, channel_id: str):
"""Return the channel for ``channel_id``, falling back to a REST fetch.
``Client.get_channel`` only reads the in-memory cache, so it misses
threads the bot hasn't seen since its last restart. Fall back to
``fetch_channel`` (REST) so long-lived threads keep working.
"""
channel = self._client.get_channel(int(channel_id))
if channel is not None:
return channel
try:
return await self._client.fetch_channel(int(channel_id))
except (discord.NotFound, discord.Forbidden, discord.HTTPException):
logger.warning("Channel %s not found or inaccessible", channel_id)
return None
async def send_message(self, channel_id: str, text: str) -> None:
channel = await self._resolve_channel(channel_id)
if channel and isinstance(channel, discord.abc.Messageable):
# tts=False is the default but we pin it explicitly — AutoPilot
# output is untrusted and should never blast through voice.
await channel.send(text, tts=False)
async def send_link(
self, channel_id: str, text: str, link_label: str, link_url: str
) -> None:
channel = await self._resolve_channel(channel_id)
if channel is None or not isinstance(channel, discord.abc.Messageable):
return
view = discord.ui.View()
view.add_item(
discord.ui.Button(
style=discord.ButtonStyle.link,
label=link_label[:80], # Discord button label max
url=link_url,
)
)
await channel.send(text, view=view, tts=False)
async def send_reply(
self, channel_id: str, text: str, reply_to_message_id: str
) -> None:
channel = await self._resolve_channel(channel_id)
if not channel or not isinstance(channel, discord.abc.Messageable):
return
try:
msg = await channel.fetch_message(int(reply_to_message_id))
await msg.reply(text, tts=False)
except discord.NotFound:
await channel.send(text, tts=False)
async def send_ephemeral(self, channel_id: str, user_id: str, text: str) -> None:
# Ephemeral messages are only possible via interaction responses.
# Fall back to a normal message for non-interaction contexts.
await self.send_message(channel_id, text)
async def start_typing(self, channel_id: str) -> None:
channel = await self._resolve_channel(channel_id)
if channel and isinstance(channel, discord.abc.Messageable):
await channel.typing()
async def stop_typing(self, channel_id: str) -> None:
pass # Discord typing auto-expires after ~10s
async def create_thread(
self, channel_id: str, message_id: str, name: str
) -> Optional[str]:
channel = await self._resolve_channel(channel_id)
if channel is None or not isinstance(channel, discord.TextChannel):
logger.warning("Cannot create thread in non-text channel %s", channel_id)
return None
try:
msg = await channel.fetch_message(int(message_id))
thread = await msg.create_thread(name=name[:100])
return str(thread.id)
except discord.HTTPException:
logger.exception("Failed to create thread in channel %s", channel_id)
return None
# -- Internal --
def _register_events(self) -> None:
@self._client.event
async def on_ready() -> None:
logger.info(f"Discord bot connected as {self._client.user}")
# Sync slash commands once per process — on_ready fires on every
# gateway reconnect, but the command tree only needs uploading once.
if self._commands_synced:
return
try:
synced = await self._tree.sync()
self._commands_synced = True
logger.info(f"Synced {len(synced)} slash commands")
except Exception:
logger.exception("Failed to sync slash commands")
@self._client.event
async def on_message(message: discord.Message) -> None:
if message.author.bot:
return
if self._on_message_callback is None:
return
channel_type = self._channel_type(message)
# Channels require an explicit @mention; DMs and threads always forward
# (handler checks thread subscription).
if channel_type == "channel" and not self._is_mentioned(message):
return
ctx = MessageContext(
platform="discord",
channel_type=channel_type,
server_id=str(message.guild.id) if message.guild else None,
channel_id=str(message.channel.id),
message_id=str(message.id),
user_id=str(message.author.id),
username=message.author.display_name,
text=self._strip_mentions(message),
)
await self._on_message_callback(ctx, self)
def _is_mentioned(self, message: discord.Message) -> bool:
if message.guild is None:
return True # DMs always count
return bool(self._client.user and self._client.user.mentioned_in(message))
@staticmethod
def _channel_type(message: discord.Message) -> ChannelType:
if message.guild is None:
return "dm"
if isinstance(message.channel, discord.Thread):
return "thread"
return "channel"
def _strip_mentions(self, message: discord.Message) -> str:
"""Strip the bot's own mention; replace other users' raw mention
tokens with `@displayname` so the LLM keeps the context.
"""
text = message.content
bot_id = self._client.user.id if self._client.user else None
for user in message.mentions:
raw_tokens = (f"<@{user.id}>", f"<@!{user.id}>")
replacement = "" if user.id == bot_id else f"@{user.display_name}"
for token in raw_tokens:
text = text.replace(token, replacement)
return text.strip()

View File

@@ -0,0 +1,259 @@
"""Tests for DiscordAdapter helpers that don't need a live gateway."""
from typing import cast
from unittest.mock import AsyncMock, MagicMock
import discord
import pytest
from backend.copilot.bot.adapters.discord.adapter import DiscordAdapter
def _bare_adapter(bot_id: int | None = 1000) -> tuple[DiscordAdapter, MagicMock]:
"""Build a DiscordAdapter without going through __init__ (which spins up
discord.py internals). Returns the adapter alongside the MagicMock that
stands in for ``_client`` — tests reach into the mock directly for
per-method stubbing.
"""
adapter = DiscordAdapter.__new__(DiscordAdapter)
client = MagicMock()
client.user = MagicMock(id=bot_id) if bot_id is not None else None
adapter._client = cast(discord.Client, client)
adapter._on_message_callback = None
adapter._commands_synced = False
return adapter, client
def _mention(user_id: int, display_name: str) -> MagicMock:
user = MagicMock()
user.id = user_id
user.display_name = display_name
return user
def _message(content: str, mentions: list[MagicMock]) -> MagicMock:
msg = MagicMock()
msg.content = content
msg.mentions = mentions
return msg
# ── _strip_mentions ────────────────────────────────────────────────────
class TestStripMentions:
def test_strips_only_bot_mention(self):
adapter, _ = _bare_adapter(bot_id=1000)
bot = _mention(1000, "AutoPilot")
alice = _mention(2000, "Alice")
msg = _message(
"<@1000> please summarise what <@2000> said",
mentions=[bot, alice],
)
assert adapter._strip_mentions(msg) == "please summarise what @Alice said"
def test_handles_nickname_style_tokens(self):
adapter, _ = _bare_adapter(bot_id=1000)
bot = _mention(1000, "AutoPilot")
alice = _mention(2000, "Alice")
msg = _message("<@!1000> ping <@!2000>", mentions=[bot, alice])
assert adapter._strip_mentions(msg) == "ping @Alice"
def test_no_bot_user_leaves_all_mentions_as_names(self):
adapter, _ = _bare_adapter(bot_id=None)
alice = _mention(2000, "Alice")
msg = _message("hi <@2000>", mentions=[alice])
assert adapter._strip_mentions(msg) == "hi @Alice"
def test_message_without_mentions_is_trimmed(self):
adapter, _ = _bare_adapter(bot_id=1000)
msg = _message(" hello world ", mentions=[])
assert adapter._strip_mentions(msg) == "hello world"
@pytest.mark.parametrize(
"content,expected",
[
("<@1000>", ""),
("<@!1000>", ""),
("<@1000> hi", "hi"),
("hi <@1000>", "hi"),
],
)
def test_bot_only_variants(self, content: str, expected: str):
adapter, _ = _bare_adapter(bot_id=1000)
bot = _mention(1000, "AutoPilot")
msg = _message(content, mentions=[bot])
assert adapter._strip_mentions(msg) == expected
# ── _channel_type ──────────────────────────────────────────────────────
class TestChannelType:
def test_dm_has_no_guild(self):
msg = MagicMock()
msg.guild = None
assert DiscordAdapter._channel_type(msg) == "dm"
def test_thread_inside_guild(self):
msg = MagicMock()
msg.guild = MagicMock()
msg.channel = MagicMock(spec=discord.Thread)
assert DiscordAdapter._channel_type(msg) == "thread"
def test_regular_channel_inside_guild(self):
msg = MagicMock()
msg.guild = MagicMock()
msg.channel = MagicMock()
assert DiscordAdapter._channel_type(msg) == "channel"
# ── _is_mentioned ──────────────────────────────────────────────────────
class TestIsMentioned:
def test_dm_always_counts_as_mentioned(self):
adapter, _ = _bare_adapter(bot_id=1000)
msg = MagicMock()
msg.guild = None
assert adapter._is_mentioned(msg) is True
def test_guild_requires_explicit_mention(self):
adapter, client = _bare_adapter(bot_id=1000)
msg = MagicMock()
msg.guild = MagicMock()
client.user.mentioned_in.return_value = False
assert adapter._is_mentioned(msg) is False
def test_guild_with_mention_passes(self):
adapter, client = _bare_adapter(bot_id=1000)
msg = MagicMock()
msg.guild = MagicMock()
client.user.mentioned_in.return_value = True
assert adapter._is_mentioned(msg) is True
def test_no_bot_user_treats_guild_mention_as_false(self):
adapter, _ = _bare_adapter(bot_id=None)
msg = MagicMock()
msg.guild = MagicMock()
assert adapter._is_mentioned(msg) is False
# ── _resolve_channel ───────────────────────────────────────────────────
class TestResolveChannel:
@pytest.mark.asyncio
async def test_cache_hit_skips_rest_fetch(self):
adapter, client = _bare_adapter()
cached = MagicMock()
client.get_channel.return_value = cached
client.fetch_channel = AsyncMock()
result = await adapter._resolve_channel("123")
assert result is cached
client.fetch_channel.assert_not_awaited()
@pytest.mark.asyncio
async def test_cache_miss_falls_back_to_rest(self):
adapter, client = _bare_adapter()
fetched = MagicMock()
client.get_channel.return_value = None
client.fetch_channel = AsyncMock(return_value=fetched)
result = await adapter._resolve_channel("123")
assert result is fetched
client.fetch_channel.assert_awaited_once_with(123)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"exc",
[
discord.NotFound(MagicMock(status=404), "gone"),
discord.Forbidden(MagicMock(status=403), "nope"),
discord.HTTPException(MagicMock(status=500), "boom"),
],
)
async def test_rest_failure_returns_none(self, exc: Exception):
adapter, client = _bare_adapter()
client.get_channel.return_value = None
client.fetch_channel = AsyncMock(side_effect=exc)
assert await adapter._resolve_channel("123") is None
# ── send_message / send_reply / send_link ──────────────────────────────
class TestSendMethods:
@pytest.mark.asyncio
async def test_send_message_pins_tts_false(self):
adapter, client = _bare_adapter()
channel = MagicMock(spec=discord.TextChannel)
channel.send = AsyncMock()
client.get_channel.return_value = channel
await adapter.send_message("123", "hi")
channel.send.assert_awaited_once_with("hi", tts=False)
@pytest.mark.asyncio
async def test_send_message_silently_drops_when_channel_missing(self):
adapter, client = _bare_adapter()
client.get_channel.return_value = None
client.fetch_channel = AsyncMock(
side_effect=discord.NotFound(MagicMock(status=404), "gone")
)
# Should not raise even though there's nothing to send to.
await adapter.send_message("123", "hi")
@pytest.mark.asyncio
async def test_send_link_attaches_button_and_pins_tts(self):
adapter, client = _bare_adapter()
channel = MagicMock(spec=discord.TextChannel)
channel.send = AsyncMock()
client.get_channel.return_value = channel
await adapter.send_link("123", "click me", "Open", "https://example.com")
assert channel.send.await_count == 1
kwargs = channel.send.await_args.kwargs
assert kwargs["tts"] is False
view = kwargs["view"]
assert any(
getattr(c, "url", None) == "https://example.com" for c in view.children
)
@pytest.mark.asyncio
async def test_send_reply_falls_back_to_send_when_message_missing(self):
adapter, client = _bare_adapter()
channel = MagicMock(spec=discord.TextChannel)
channel.send = AsyncMock()
channel.fetch_message = AsyncMock(
side_effect=discord.NotFound(MagicMock(status=404), "gone")
)
client.get_channel.return_value = channel
await adapter.send_reply("123", "hello", "999")
channel.send.assert_awaited_once_with("hello", tts=False)
# ── properties ─────────────────────────────────────────────────────────
class TestProperties:
def test_platform_name_is_discord(self):
adapter, _ = _bare_adapter()
assert adapter.platform_name == "discord"
def test_chunk_flush_at_is_under_message_limit(self):
adapter, _ = _bare_adapter()
assert adapter.chunk_flush_at < adapter.max_message_length

View File

@@ -0,0 +1,134 @@
"""Discord slash command handlers.
Registered on the bot's CommandTree at startup. All responses are ephemeral
(visible only to the invoking user) to keep channels clean and to keep link
URLs private.
"""
import logging
import discord
from discord import app_commands
from backend.copilot.bot.bot_backend import BotBackend
from backend.util.exceptions import LinkAlreadyExistsError
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
def register(tree: app_commands.CommandTree, api: BotBackend) -> None:
"""Register all slash commands on the given CommandTree."""
@tree.command(
name="setup",
description="Link this server to an AutoGPT account for AutoPilot",
)
async def setup_command(interaction: discord.Interaction) -> None:
await _handle_setup(interaction, api)
@tree.command(name="help", description="Show AutoPilot bot usage info")
async def help_command(interaction: discord.Interaction) -> None:
await _handle_help(interaction)
@tree.command(
name="unlink",
description="Manage linked servers from your AutoGPT settings",
)
async def unlink_command(interaction: discord.Interaction) -> None:
await _handle_unlink(interaction)
async def _handle_setup(interaction: discord.Interaction, api: BotBackend) -> None:
if interaction.guild is None:
await interaction.response.send_message(
"This command can only be used in a server. "
"To link your DMs, just send me a direct message.",
ephemeral=True,
)
return
await interaction.response.defer(ephemeral=True)
try:
result = await api.create_link_token(
platform="discord",
platform_server_id=str(interaction.guild.id),
platform_user_id=str(interaction.user.id),
platform_username=interaction.user.display_name,
server_name=interaction.guild.name,
channel_id=str(interaction.channel_id or ""),
)
except LinkAlreadyExistsError:
await interaction.followup.send(
"This server is already linked — just mention me!",
ephemeral=True,
)
return
except Exception:
logger.exception("Failed to create link token")
await interaction.followup.send(
"Something went wrong. Try again later.",
ephemeral=True,
)
return
view = discord.ui.View()
view.add_item(
discord.ui.Button(
style=discord.ButtonStyle.link,
label="Link Server",
url=result.link_url,
)
)
await interaction.followup.send(
f"**Set up AutoPilot for {interaction.guild.name}**\n\n"
"Click the button below to connect this server to your AutoGPT "
"account. Once confirmed, everyone here can mention me to use "
"AutoPilot.\n\n"
"All usage will be billed to your account.\n"
"This link expires in 30 minutes.",
ephemeral=True,
view=view,
)
async def _handle_help(interaction: discord.Interaction) -> None:
await interaction.response.send_message(
"**AutoPilot Bot**\n\n"
"Mention me in a server or DM me directly to chat.\n\n"
"**Commands:**\n"
"- `/setup` — Link this server to your AutoGPT account\n"
"- `/help` — Show this message\n"
"- `/unlink` — Manage linked accounts\n\n"
"**How it works:**\n"
"- In a server: the person who runs `/setup` pays for usage\n"
"- In DMs: you link and pay for your own usage\n",
ephemeral=True,
)
async def _handle_unlink(interaction: discord.Interaction) -> None:
config = Settings().config
base_url = config.frontend_base_url or config.platform_base_url
message = (
"Unlinking requires authentication, so it has to be done "
"from the web. Click below to manage your linked accounts."
)
if not base_url:
await interaction.response.send_message(
f"{message}\n\nOpen your AutoGPT settings and visit "
"Profile → Linked accounts.",
ephemeral=True,
)
return
view = discord.ui.View()
view.add_item(
discord.ui.Button(
style=discord.ButtonStyle.link,
label="Open Settings",
url=f"{base_url}/profile/settings",
)
)
await interaction.response.send_message(message, ephemeral=True, view=view)

View File

@@ -0,0 +1,165 @@
"""Tests for Discord slash command handlers.
Targets the ``_handle_*`` functions directly — sidesteps ``CommandTree``
registration since it requires a live ``discord.Client``.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.util.exceptions import LinkAlreadyExistsError
from ...bot_backend import LinkTokenResult
from .commands import _handle_help, _handle_setup, _handle_unlink
def _interaction(*, guild: bool = True) -> MagicMock:
interaction = MagicMock()
interaction.response.send_message = AsyncMock()
interaction.response.defer = AsyncMock()
interaction.followup.send = AsyncMock()
if guild:
# MagicMock treats `name` as a constructor kwarg for the mock's repr,
# not as an attribute — so set it after construction.
interaction.guild = MagicMock(id=123)
interaction.guild.name = "Test Guild"
interaction.user = MagicMock(id=456, display_name="Bently")
interaction.channel_id = 789
else:
interaction.guild = None
interaction.user = MagicMock(id=456, display_name="Bently")
interaction.channel_id = None
return interaction
def _api_with_token() -> MagicMock:
api = MagicMock()
api.create_link_token = AsyncMock(
return_value=LinkTokenResult(
token="abc",
link_url="https://example.com/link/abc",
expires_at="2099-01-01T00:00:00Z",
)
)
return api
class TestHandleSetup:
@pytest.mark.asyncio
async def test_dm_invocation_rejects_early(self):
interaction = _interaction(guild=False)
api = _api_with_token()
await _handle_setup(interaction, api)
interaction.response.send_message.assert_awaited_once()
api.create_link_token.assert_not_awaited()
@pytest.mark.asyncio
async def test_guild_invocation_creates_token_and_posts_button(self):
interaction = _interaction()
api = _api_with_token()
await _handle_setup(interaction, api)
interaction.response.defer.assert_awaited_once_with(ephemeral=True)
api.create_link_token.assert_awaited_once()
call_kwargs = api.create_link_token.await_args.kwargs
assert call_kwargs["platform"] == "discord"
assert call_kwargs["platform_server_id"] == "123"
assert call_kwargs["server_name"] == "Test Guild"
interaction.followup.send.assert_awaited_once()
sent = interaction.followup.send.await_args
assert "Set up AutoPilot for Test Guild" in sent.args[0]
assert sent.kwargs["view"] is not None
@pytest.mark.asyncio
async def test_already_linked_gets_friendly_message(self):
interaction = _interaction()
api = _api_with_token()
api.create_link_token = AsyncMock(side_effect=LinkAlreadyExistsError("already"))
await _handle_setup(interaction, api)
interaction.followup.send.assert_awaited_once()
msg = interaction.followup.send.await_args.args[0]
assert "already linked" in msg
@pytest.mark.asyncio
async def test_backend_error_surfaces_generic_message(self):
interaction = _interaction()
api = _api_with_token()
api.create_link_token = AsyncMock(side_effect=RuntimeError("boom"))
await _handle_setup(interaction, api)
interaction.followup.send.assert_awaited_once()
msg = interaction.followup.send.await_args.args[0]
assert "went wrong" in msg.lower()
class TestHandleHelp:
@pytest.mark.asyncio
async def test_help_sends_ephemeral_message(self):
interaction = _interaction()
await _handle_help(interaction)
interaction.response.send_message.assert_awaited_once()
assert interaction.response.send_message.await_args.kwargs["ephemeral"] is True
body = interaction.response.send_message.await_args.args[0]
assert "/setup" in body
assert "/help" in body
assert "/unlink" in body
class TestHandleUnlink:
@pytest.mark.asyncio
async def test_with_frontend_url_posts_button(self):
interaction = _interaction()
fake_settings = MagicMock()
fake_settings.config.frontend_base_url = "http://localhost:3000"
fake_settings.config.platform_base_url = ""
with patch(
"backend.copilot.bot.adapters.discord.commands.Settings",
return_value=fake_settings,
):
await _handle_unlink(interaction)
interaction.response.send_message.assert_awaited_once()
sent = interaction.response.send_message.await_args
assert sent.kwargs["view"] is not None
assert sent.kwargs["ephemeral"] is True
@pytest.mark.asyncio
async def test_falls_back_to_platform_base_url(self):
interaction = _interaction()
fake_settings = MagicMock()
fake_settings.config.frontend_base_url = ""
fake_settings.config.platform_base_url = "http://other"
with patch(
"backend.copilot.bot.adapters.discord.commands.Settings",
return_value=fake_settings,
):
await _handle_unlink(interaction)
# Button uses the fallback URL.
sent = interaction.response.send_message.await_args
view = sent.kwargs["view"]
assert any(
"http://other" in getattr(child, "url", "") for child in view.children
)
@pytest.mark.asyncio
async def test_no_urls_configured_sends_plain_text(self):
interaction = _interaction()
fake_settings = MagicMock()
fake_settings.config.frontend_base_url = ""
fake_settings.config.platform_base_url = ""
with patch(
"backend.copilot.bot.adapters.discord.commands.Settings",
return_value=fake_settings,
):
await _handle_unlink(interaction)
sent = interaction.response.send_message.await_args
assert "view" not in sent.kwargs or sent.kwargs.get("view") is None
assert "Profile" in sent.args[0]

View File

@@ -0,0 +1,15 @@
"""Discord-specific configuration."""
from backend.util.settings import Settings
def get_bot_token() -> str:
return Settings().secrets.autopilot_bot_discord_token
# Discord message content limit (hard platform cap)
MAX_MESSAGE_LENGTH = 2000
# Flush the streaming buffer at 1900 — leaves 100-char headroom under the
# 2000 cap so the boundary-splitter has room to reach a natural break point.
CHUNK_FLUSH_AT = 1900

View File

@@ -0,0 +1,156 @@
"""CoPilot Chat Bridge — AppService that runs the configured chat-platform
adapters (Discord, Telegram, Slack) and exposes outbound message RPC for
other services to push messages into chat platforms.
"""
import asyncio
import logging
from concurrent.futures import Future
from backend.platform_linking.models import Platform
from backend.util.service import (
AppService,
AppServiceClient,
UnhealthyServiceError,
endpoint_to_async,
expose,
)
from backend.util.settings import Settings
from .adapters.base import PlatformAdapter
from .adapters.discord import config as discord_config
from .adapters.discord.adapter import DiscordAdapter
from .bot_backend import BotBackend
from .handler import MessageHandler
logger = logging.getLogger(__name__)
# Stay up for health-checks and runtime reconfiguration when no adapter is
# configured (e.g. deployed without a Discord token).
_NO_ADAPTER_SLEEP_SECONDS = 3600
class CoPilotChatBridge(AppService):
"""Bridges AutoPilot to external chat platforms via per-platform adapters."""
def __init__(self):
super().__init__()
# Flipped to True once `_run_adapters` reaches its blocking gather
# (or the no-adapter idle loop), and back to False if the task exits
# for any reason. Consumed by `health_check` so orchestrators can
# restart the pod when the bridge is dead-but-listening.
self._adapters_healthy = False
@classmethod
def get_port(cls) -> int:
return Settings().config.copilot_chat_bridge_port
def run_service(self) -> None:
future = asyncio.run_coroutine_threadsafe(
self._run_adapters(), self.shared_event_loop
)
future.add_done_callback(self._on_adapters_exit)
super().run_service()
async def _run_adapters(self) -> None:
api = BotBackend()
adapters = _build_adapters(api)
if not adapters:
logger.info(
"CoPilotChatBridge: no platform adapters configured — idling. "
"Set AUTOPILOT_BOT_DISCORD_TOKEN (or another platform token) to "
"enable an adapter."
)
self._adapters_healthy = True
try:
while True:
await asyncio.sleep(_NO_ADAPTER_SLEEP_SECONDS)
finally:
await api.close()
handler = MessageHandler(api)
for adapter in adapters:
adapter.on_message(handler.handle)
self._adapters_healthy = True
try:
await asyncio.gather(*(a.start() for a in adapters))
finally:
await asyncio.gather(*(a.stop() for a in adapters), return_exceptions=True)
await api.close()
def _on_adapters_exit(self, future: "Future[None]") -> None:
"""Surface exceptions from `_run_adapters` and flip the health flag.
`run_coroutine_threadsafe` would otherwise swallow the exception
into the returned future, leaving the FastAPI health endpoint
cheerfully reporting OK on a dead bridge.
"""
self._adapters_healthy = False
exc = future.exception()
if exc is not None:
logger.error("CoPilotChatBridge adapters crashed: %r", exc, exc_info=exc)
else:
logger.warning("CoPilotChatBridge adapters exited without error")
async def health_check(self) -> str:
if not self._adapters_healthy:
raise UnhealthyServiceError("CoPilotChatBridge adapter task is not running")
return await super().health_check()
@expose
async def send_message_to_channel(
self,
platform: Platform,
channel_id: str,
content: str,
) -> bool:
"""Deliver a message to a channel on the given platform.
Stub — scaffolding for the inbound-RPC pattern (backend → chat
platform). Not yet wired to a concrete adapter. Callers must not use
``request_retry=True`` on the client until this is implemented, since
``ValueError`` crosses the RPC boundary as a client-side 4xx-ish error
rather than a transient 5xx.
"""
raise ValueError(f"send_message_to_channel not yet wired for {platform.value}")
@expose
async def send_dm(
self,
platform: Platform,
platform_user_id: str,
content: str,
) -> bool:
"""Deliver a DM to a user on the given platform.
Stub — scaffolding for the inbound-RPC pattern. See
:meth:`send_message_to_channel` for the retry caveat.
"""
raise ValueError(f"send_dm not yet wired for {platform.value}")
class CoPilotChatBridgeClient(AppServiceClient):
@classmethod
def get_service_type(cls):
return CoPilotChatBridge
send_message_to_channel = endpoint_to_async(
CoPilotChatBridge.send_message_to_channel
)
send_dm = endpoint_to_async(CoPilotChatBridge.send_dm)
def _build_adapters(api: BotBackend) -> list[PlatformAdapter]:
"""Instantiate adapters based on which platform tokens are configured."""
adapters: list[PlatformAdapter] = []
if discord_config.get_bot_token():
adapters.append(DiscordAdapter(api))
logger.info("Discord adapter enabled")
# Future:
# if telegram_config.get_bot_token():
# adapters.append(TelegramAdapter(api))
# if slack_config.get_bot_token():
# adapters.append(SlackAdapter(api))
return adapters

View File

@@ -0,0 +1,194 @@
"""Bot-side facade over `PlatformLinkingManagerClient` + `stream_registry`.
The `BotBackend` class is the bot's single entry point into the AutoGPT
backend — it wraps the linking RPC client and the copilot stream registry
behind plain string-typed methods. Adapters import this directly so the
discord/telegram/slack code never touches Pyro / Redis Streams plumbing.
"""
import asyncio
import logging
from dataclasses import dataclass
from typing import AsyncGenerator, Awaitable, Callable, Optional
from backend.copilot import stream_registry
from backend.copilot.response_model import StreamError, StreamFinish, StreamTextDelta
from backend.platform_linking.models import (
BotChatRequest,
CreateLinkTokenRequest,
CreateUserLinkTokenRequest,
Platform,
)
from backend.util.clients import get_platform_linking_manager_client
from backend.util.exceptions import (
DuplicateChatMessageError,
LinkAlreadyExistsError,
NotFoundError,
)
# How long to wait for a single chunk from the copilot stream before giving
# up. Covers the case where the backend crashes mid-stream and never sends
# ``StreamFinish`` — without this, the bot would hang forever on ``queue.get()``.
STREAM_CHUNK_TIMEOUT_SECONDS = 120
logger = logging.getLogger(__name__)
__all__ = [
"BotBackend",
"DuplicateChatMessageError",
"LinkAlreadyExistsError",
"LinkTokenResult",
"NotFoundError",
"ResolveResult",
]
@dataclass
class ResolveResult:
linked: bool
@dataclass
class LinkTokenResult:
token: str
link_url: str
expires_at: str
class BotBackend:
"""Bot-side linking + chat operations, routed over cluster-internal RPC."""
def __init__(self):
self._client = get_platform_linking_manager_client()
async def close(self) -> None:
# The client's lifecycle is owned by the thread-cached factory; nothing
# to close here. Kept for API compatibility with older bot code.
pass
async def resolve_server(
self, platform: str, platform_server_id: str
) -> ResolveResult:
resp = await self._client.resolve_server_link(
platform=Platform(platform.upper()),
platform_server_id=platform_server_id,
)
return ResolveResult(linked=resp.linked)
async def resolve_user(self, platform: str, platform_user_id: str) -> ResolveResult:
resp = await self._client.resolve_user_link(
platform=Platform(platform.upper()),
platform_user_id=platform_user_id,
)
return ResolveResult(linked=resp.linked)
async def create_link_token(
self,
platform: str,
platform_server_id: str,
platform_user_id: str,
platform_username: str,
server_name: str,
channel_id: str = "",
) -> LinkTokenResult:
resp = await self._client.create_server_link_token(
request=CreateLinkTokenRequest(
platform=Platform(platform.upper()),
platform_server_id=platform_server_id,
platform_user_id=platform_user_id,
platform_username=platform_username or None,
server_name=server_name or None,
channel_id=channel_id or None,
)
)
return LinkTokenResult(
token=resp.token,
link_url=resp.link_url,
expires_at=resp.expires_at.isoformat(),
)
async def create_user_link_token(
self,
platform: str,
platform_user_id: str,
platform_username: str,
) -> LinkTokenResult:
resp = await self._client.create_user_link_token(
request=CreateUserLinkTokenRequest(
platform=Platform(platform.upper()),
platform_user_id=platform_user_id,
platform_username=platform_username or None,
)
)
return LinkTokenResult(
token=resp.token,
link_url=resp.link_url,
expires_at=resp.expires_at.isoformat(),
)
async def stream_chat(
self,
platform: str,
platform_user_id: str,
message: str,
session_id: Optional[str] = None,
platform_server_id: Optional[str] = None,
on_session_id: Optional[Callable[[str], Awaitable[None]]] = None,
) -> AsyncGenerator[str, None]:
"""Start a copilot turn and yield text deltas from the stream.
Raises :class:`DuplicateChatMessageError` if the same message is
already in flight for this session.
"""
handle = await self._client.start_chat_turn(
request=BotChatRequest(
platform=Platform(platform.upper()),
platform_user_id=platform_user_id,
message=message,
session_id=session_id,
platform_server_id=platform_server_id,
)
)
if on_session_id:
await on_session_id(handle.session_id)
queue = await stream_registry.subscribe_to_session(
session_id=handle.session_id,
user_id=handle.user_id,
last_message_id=handle.subscribe_from,
)
if queue is None:
yield "\n[Error: failed to subscribe to response stream]"
return
try:
while True:
try:
chunk = await asyncio.wait_for(
queue.get(), timeout=STREAM_CHUNK_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
logger.warning(
"Stream idle timeout after %ss for session %s",
STREAM_CHUNK_TIMEOUT_SECONDS,
handle.session_id,
)
yield "\n[Error: response timed out]"
return
if isinstance(chunk, StreamTextDelta):
if chunk.delta:
yield chunk.delta
elif isinstance(chunk, StreamFinish):
return
elif isinstance(chunk, StreamError):
logger.error("Stream error from backend: %s", chunk.errorText)
yield f"\n[Error: {chunk.errorText}]"
return
# Other StreamX types (StreamStart, StreamTextStart, tool events,
# etc.) are emitted by the executor for the frontend UI and
# aren't useful for the plain-text bot transcript.
finally:
await stream_registry.unsubscribe_from_session(
session_id=handle.session_id,
subscriber_queue=queue,
)

View File

@@ -0,0 +1,216 @@
"""Tests for the bot's thin facade over PlatformLinkingManagerClient."""
import asyncio
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.response_model import StreamError, StreamFinish, StreamTextDelta
from backend.platform_linking.models import (
ChatTurnHandle,
LinkTokenResponse,
ResolveResponse,
)
from backend.util.exceptions import (
DuplicateChatMessageError,
LinkAlreadyExistsError,
NotFoundError,
)
from .bot_backend import BotBackend
@pytest.fixture
def api() -> BotBackend:
with patch("backend.copilot.bot.bot_backend.get_platform_linking_manager_client"):
instance = BotBackend()
# Swap in a MagicMock whose RPC methods are AsyncMocks — simpler than
# patching each call site.
instance._client = MagicMock()
return instance
class TestResolve:
@pytest.mark.asyncio
async def test_resolve_server(self, api: BotBackend):
api._client.resolve_server_link = AsyncMock(
return_value=ResolveResponse(linked=True)
)
result = await api.resolve_server("discord", "g1")
assert result.linked is True
api._client.resolve_server_link.assert_awaited_once()
@pytest.mark.asyncio
async def test_resolve_user(self, api: BotBackend):
api._client.resolve_user_link = AsyncMock(
return_value=ResolveResponse(linked=False)
)
result = await api.resolve_user("discord", "u1")
assert result.linked is False
class TestCreateLinkTokens:
@pytest.mark.asyncio
async def test_create_server_link_token(self, api: BotBackend):
api._client.create_server_link_token = AsyncMock(
return_value=LinkTokenResponse(
token="abc",
expires_at=datetime.now(timezone.utc),
link_url="https://example.com/link/abc",
)
)
result = await api.create_link_token(
platform="discord",
platform_server_id="g1",
platform_user_id="u1",
platform_username="Bently",
server_name="Test",
)
assert result.token == "abc"
assert result.link_url.endswith("/link/abc")
@pytest.mark.asyncio
async def test_create_server_link_token_propagates_already_exists(
self, api: BotBackend
):
api._client.create_server_link_token = AsyncMock(
side_effect=LinkAlreadyExistsError("already linked")
)
with pytest.raises(LinkAlreadyExistsError):
await api.create_link_token(
platform="discord",
platform_server_id="g1",
platform_user_id="u1",
platform_username="",
server_name="",
)
@pytest.mark.asyncio
async def test_create_user_link_token(self, api: BotBackend):
api._client.create_user_link_token = AsyncMock(
return_value=LinkTokenResponse(
token="xyz",
expires_at=datetime.now(timezone.utc),
link_url="https://example.com/link/xyz",
)
)
result = await api.create_user_link_token(
platform="discord", platform_user_id="u1", platform_username="Bently"
)
assert result.token == "xyz"
class TestStreamChat:
@pytest.mark.asyncio
async def test_yields_text_deltas_and_terminates_on_finish(self, api: BotBackend):
handle = ChatTurnHandle(session_id="sess", turn_id="turn", user_id="u1")
api._client.start_chat_turn = AsyncMock(return_value=handle)
queue: asyncio.Queue = asyncio.Queue()
await queue.put(StreamTextDelta(id="1", delta="Hello "))
await queue.put(StreamTextDelta(id="2", delta="world"))
await queue.put(StreamFinish())
captured_session_ids: list[str] = []
async def capture(sid: str) -> None:
captured_session_ids.append(sid)
with (
patch(
"backend.copilot.bot.bot_backend.stream_registry.subscribe_to_session",
new=AsyncMock(return_value=queue),
),
patch(
"backend.copilot.bot.bot_backend.stream_registry.unsubscribe_from_session",
new=AsyncMock(),
),
):
chunks: list[str] = []
async for chunk in api.stream_chat(
platform="discord",
platform_user_id="u1",
message="hi",
on_session_id=capture,
):
chunks.append(chunk)
assert "".join(chunks) == "Hello world"
assert captured_session_ids == ["sess"]
@pytest.mark.asyncio
async def test_surfaces_stream_error(self, api: BotBackend):
handle = ChatTurnHandle(session_id="sess", turn_id="turn", user_id="u1")
api._client.start_chat_turn = AsyncMock(return_value=handle)
queue: asyncio.Queue = asyncio.Queue()
await queue.put(StreamError(errorText="executor crashed"))
with (
patch(
"backend.copilot.bot.bot_backend.stream_registry.subscribe_to_session",
new=AsyncMock(return_value=queue),
),
patch(
"backend.copilot.bot.bot_backend.stream_registry.unsubscribe_from_session",
new=AsyncMock(),
),
):
chunks: list[str] = []
async for chunk in api.stream_chat(
platform="discord", platform_user_id="u1", message="hi"
):
chunks.append(chunk)
assert any("executor crashed" in c for c in chunks)
@pytest.mark.asyncio
async def test_duplicate_message_propagates(self, api: BotBackend):
api._client.start_chat_turn = AsyncMock(
side_effect=DuplicateChatMessageError("in flight")
)
with pytest.raises(DuplicateChatMessageError):
async for _ in api.stream_chat(
platform="discord", platform_user_id="u1", message="hi"
):
pass
@pytest.mark.asyncio
async def test_session_not_found_propagates(self, api: BotBackend):
api._client.start_chat_turn = AsyncMock(
side_effect=NotFoundError("session gone")
)
with pytest.raises(NotFoundError):
async for _ in api.stream_chat(
platform="discord",
platform_user_id="u1",
message="hi",
session_id="missing",
):
pass
@pytest.mark.asyncio
async def test_subscribe_returns_none_yields_error(self, api: BotBackend):
handle = ChatTurnHandle(session_id="sess", turn_id="turn", user_id="u1")
api._client.start_chat_turn = AsyncMock(return_value=handle)
with (
patch(
"backend.copilot.bot.bot_backend.stream_registry.subscribe_to_session",
new=AsyncMock(return_value=None),
),
patch(
"backend.copilot.bot.bot_backend.stream_registry.unsubscribe_from_session",
new=AsyncMock(),
),
):
chunks: list[str] = []
async for chunk in api.stream_chat(
platform="discord", platform_user_id="u1", message="hi"
):
chunks.append(chunk)
assert any("failed to subscribe" in c.lower() for c in chunks)

View File

@@ -0,0 +1,4 @@
"""Platform-agnostic bot config."""
# Cache TTL for AutoPilot session IDs (per channel/thread)
SESSION_TTL = 86400 # 24 hours

View File

@@ -0,0 +1,280 @@
"""Platform-agnostic message handler.
Receives a MessageContext from any adapter and drives the full AutoPilot
interaction: link resolution, thread routing, batched streaming with a
persistent typing indicator.
"""
import asyncio
import logging
from dataclasses import dataclass, field
from backend.data.redis_client import get_redis_async
from backend.util.exceptions import (
DuplicateChatMessageError,
LinkAlreadyExistsError,
NotFoundError,
)
from . import threads
from .adapters.base import MessageContext, PlatformAdapter
from .bot_backend import BotBackend
from .config import SESSION_TTL
from .text import format_batch, split_at_boundary
logger = logging.getLogger(__name__)
@dataclass
class TargetState:
"""Per-target streaming state.
A "target" is wherever the bot replies — a thread ID, a DM channel ID.
`pending` holds messages that arrived while a stream was running; they
get drained as a single batched follow-up turn when the stream ends.
"""
processing: bool = False
pending: list[tuple[str, str, str]] = field(default_factory=list)
# Each entry: (username, user_id, text)
class MessageHandler:
def __init__(self, api: BotBackend):
self._api = api
self._targets: dict[str, TargetState] = {}
async def handle(self, ctx: MessageContext, adapter: PlatformAdapter) -> None:
if not ctx.text.strip():
if ctx.channel_type == "channel":
await adapter.send_reply(
ctx.channel_id,
"You mentioned me but didn't say anything. How can I help?",
ctx.message_id,
)
return
if not await self._ensure_linked(ctx, adapter):
return
target_id = await self._resolve_target(ctx, adapter)
if not target_id:
return # Thread not subscribed, ignore silently
await self._enqueue_and_process(ctx, adapter, target_id)
# -- Target resolution --
async def _resolve_target(
self, ctx: MessageContext, adapter: PlatformAdapter
) -> str | None:
if ctx.channel_type == "dm":
return ctx.channel_id
if ctx.channel_type == "thread":
if await threads.is_subscribed(ctx.platform, ctx.channel_id):
return ctx.channel_id
return None
# channel_type == "channel" — create a thread and subscribe
thread_name = f"{ctx.username} × AutoPilot"
thread_id = await adapter.create_thread(
ctx.channel_id, ctx.message_id, thread_name
)
if not thread_id:
logger.warning("Thread creation failed, falling back to channel reply")
return ctx.channel_id
await threads.subscribe(ctx.platform, thread_id)
return thread_id
# -- Batched streaming --
async def _enqueue_and_process(
self, ctx: MessageContext, adapter: PlatformAdapter, target_id: str
) -> None:
state = self._targets.setdefault(target_id, TargetState())
state.pending.append((ctx.username, ctx.user_id, ctx.text))
if state.processing:
# Another invocation is streaming for this target — it will pick
# up the message we just appended when its current stream ends.
return
state.processing = True
try:
while state.pending:
batch = list(state.pending)
state.pending.clear()
await self._stream_batch(batch, ctx, adapter, target_id)
finally:
state.processing = False
# Drop the empty state so the dict doesn't grow unbounded across
# the bot's lifetime.
if not state.pending:
self._targets.pop(target_id, None)
async def _stream_batch(
self,
batch: list[tuple[str, str, str]],
ctx: MessageContext,
adapter: PlatformAdapter,
target_id: str,
) -> None:
prefixed = format_batch(batch, ctx.platform)
redis = await get_redis_async()
cache_key = f"copilot-bot:session:{ctx.platform}:{target_id}"
cached_session_id = await redis.get(cache_key)
async def _on_session_id(sid: str) -> None:
try:
await redis.set(cache_key, sid, ex=SESSION_TTL)
except Exception:
logger.warning("Failed to cache session id for target %s", target_id)
flush_at = adapter.chunk_flush_at
buffer = ""
sent_any_content = False
typing_task = asyncio.create_task(_keep_typing(adapter, target_id))
try:
async for chunk in self._api.stream_chat(
platform=ctx.platform,
platform_user_id=ctx.user_id,
message=prefixed,
session_id=cached_session_id,
platform_server_id=ctx.server_id,
on_session_id=_on_session_id,
):
buffer += chunk
if len(buffer) >= flush_at:
post, buffer = split_at_boundary(buffer, flush_at)
if post:
await adapter.send_message(target_id, post)
if post.strip():
sent_any_content = True
except DuplicateChatMessageError:
# Another in-flight turn is already processing this exact message —
# stay quiet so the user doesn't get a double response.
logger.info("Duplicate message dropped for target %s", target_id)
return
except NotFoundError:
logger.exception("Chat turn rejected")
await adapter.send_message(
target_id, "AutoPilot ran into an error. Try again later."
)
return
except Exception:
logger.exception(
"Unexpected error during streaming for target %s", target_id
)
await adapter.send_message(
target_id,
"Something went wrong. Try again in a moment.",
)
return
finally:
typing_task.cancel()
try:
await typing_task
except asyncio.CancelledError:
pass
await adapter.stop_typing(target_id)
if buffer.strip():
await adapter.send_message(target_id, buffer)
sent_any_content = True
if not sent_any_content:
await adapter.send_message(
target_id,
"AutoPilot didn't produce a response. Try rephrasing your question.",
)
# -- Linking --
async def _ensure_linked(
self, ctx: MessageContext, adapter: PlatformAdapter
) -> bool:
try:
if ctx.is_dm:
result = await self._api.resolve_user(ctx.platform, ctx.user_id)
if not result.linked:
await self._prompt_user_link(ctx, adapter)
return False
else:
if not ctx.server_id:
logger.error("Non-DM message missing server_id: %r", ctx)
return False
result = await self._api.resolve_server(ctx.platform, ctx.server_id)
if not result.linked:
await adapter.send_message(
ctx.channel_id,
"This server isn't linked to an AutoGPT account yet. "
"Ask a server admin to run `/setup` first.",
)
return False
except ValueError:
# ValueError-based domain exceptions (NotFoundError etc.) arrive
# over RPC with this base type.
logger.exception("Failed to check link status")
await adapter.send_message(
ctx.channel_id, "Something went wrong. Try again later."
)
return False
except Exception:
logger.exception("Unexpected error while checking link status")
await adapter.send_message(
ctx.channel_id,
"Something went wrong. Try again in a moment.",
)
return False
return True
async def _prompt_user_link(
self, ctx: MessageContext, adapter: PlatformAdapter
) -> None:
try:
result = await self._api.create_user_link_token(
platform=ctx.platform,
platform_user_id=ctx.user_id,
platform_username=ctx.username,
)
platform_display = ctx.platform.capitalize()
await adapter.send_link(
ctx.channel_id,
f"Your {platform_display} DMs aren't linked to an AutoGPT "
"account yet. Click below to connect — once linked, you can "
"chat with AutoPilot right here.",
link_label="Link Account",
link_url=result.link_url,
)
except LinkAlreadyExistsError:
# Race: user got linked between resolve_user and create. Re-check
# — if still not linked, the backend returned a stale error and
# we shouldn't spam the user.
re_check = await self._api.resolve_user(ctx.platform, ctx.user_id)
if re_check.linked:
return
logger.exception(
"create_user_link_token raised 'already exists' "
"but user isn't actually linked"
)
except Exception:
logger.exception("Failed to create user link token")
await adapter.send_message(
ctx.channel_id,
"Something went wrong setting up the link. Try again later.",
)
async def _keep_typing(adapter: PlatformAdapter, target_id: str) -> None:
"""Re-fire the typing indicator every 8s so it doesn't expire mid-stream."""
try:
while True:
await adapter.start_typing(target_id)
await asyncio.sleep(8)
except asyncio.CancelledError:
raise
except Exception:
logger.debug("Typing loop error", exc_info=True)

View File

@@ -0,0 +1,338 @@
"""Tests for the platform-agnostic message handler."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.util.exceptions import DuplicateChatMessageError, NotFoundError
from .adapters.base import ChannelType, MessageContext
from .bot_backend import LinkTokenResult, ResolveResult
from .handler import MessageHandler, TargetState
def _ctx(
*,
channel_type: ChannelType = "channel",
server_id: str | None = "guild-1",
channel_id: str = "chan-1",
message_id: str = "msg-1",
user_id: str = "user-1",
username: str = "Bently",
text: str = "hello bot",
) -> MessageContext:
return MessageContext(
platform="discord",
channel_type=channel_type,
server_id=server_id,
channel_id=channel_id,
message_id=message_id,
user_id=user_id,
username=username,
text=text,
)
def _adapter() -> MagicMock:
adapter = MagicMock()
adapter.chunk_flush_at = 1900
adapter.send_message = AsyncMock()
adapter.send_reply = AsyncMock()
adapter.send_link = AsyncMock()
adapter.start_typing = AsyncMock()
adapter.stop_typing = AsyncMock()
adapter.create_thread = AsyncMock(return_value="thread-new")
return adapter
def _api(*, server_linked: bool = True, user_linked: bool = True) -> MagicMock:
api = MagicMock()
api.resolve_server = AsyncMock(return_value=ResolveResult(linked=server_linked))
api.resolve_user = AsyncMock(return_value=ResolveResult(linked=user_linked))
api.create_user_link_token = AsyncMock(
return_value=LinkTokenResult(
token="t",
link_url="https://example.com/link/t",
expires_at="2099-01-01T00:00:00Z",
)
)
async def _empty_stream(*args, **kwargs):
if False:
yield ""
api.stream_chat = _empty_stream
return api
class TestEmptyMessage:
@pytest.mark.asyncio
async def test_channel_mention_without_text_gets_nudge(self):
handler = MessageHandler(_api())
adapter = _adapter()
await handler.handle(_ctx(text=" "), adapter)
adapter.send_reply.assert_awaited_once()
adapter.send_message.assert_not_awaited()
@pytest.mark.asyncio
async def test_empty_dm_is_silently_dropped(self):
handler = MessageHandler(_api())
adapter = _adapter()
await handler.handle(_ctx(channel_type="dm", text=""), adapter)
adapter.send_reply.assert_not_awaited()
adapter.send_message.assert_not_awaited()
class TestEnsureLinked:
@pytest.mark.asyncio
async def test_unlinked_server_tells_user_to_setup(self):
handler = MessageHandler(_api(server_linked=False))
adapter = _adapter()
await handler.handle(_ctx(), adapter)
call_args = adapter.send_message.await_args.args
assert "isn't linked" in call_args[1]
assert "/setup" in call_args[1]
@pytest.mark.asyncio
async def test_unlinked_dm_prompts_link_flow(self):
handler = MessageHandler(_api(user_linked=False))
adapter = _adapter()
await handler.handle(_ctx(channel_type="dm", server_id=None), adapter)
adapter.send_link.assert_awaited_once()
assert adapter.send_link.await_args.kwargs["link_url"].startswith(
"https://example.com/link/"
)
@pytest.mark.asyncio
async def test_non_dm_without_server_id_is_rejected(self):
handler = MessageHandler(_api())
adapter = _adapter()
await handler.handle(_ctx(server_id=None), adapter)
# Guard short-circuits before calling resolve_server.
handler._api.resolve_server.assert_not_awaited()
adapter.send_message.assert_not_awaited()
@pytest.mark.asyncio
async def test_backend_error_in_resolve_produces_message(self):
api = _api()
api.resolve_server = AsyncMock(side_effect=NotFoundError("boom"))
handler = MessageHandler(api)
adapter = _adapter()
await handler.handle(_ctx(), adapter)
adapter.send_message.assert_awaited_once()
assert "went wrong" in adapter.send_message.await_args.args[1].lower()
class TestResolveTarget:
@pytest.mark.asyncio
async def test_dm_reuses_channel_id(self):
handler = MessageHandler(_api())
adapter = _adapter()
ctx = _ctx(channel_type="dm", server_id=None, channel_id="dm-42")
result = await handler._resolve_target(ctx, adapter)
assert result == "dm-42"
@pytest.mark.asyncio
async def test_unsubscribed_thread_returns_none(self):
handler = MessageHandler(_api())
adapter = _adapter()
ctx = _ctx(channel_type="thread", channel_id="thread-old")
with patch(
"backend.copilot.bot.handler.threads.is_subscribed",
new=AsyncMock(return_value=False),
):
assert await handler._resolve_target(ctx, adapter) is None
@pytest.mark.asyncio
async def test_subscribed_thread_keeps_channel(self):
handler = MessageHandler(_api())
adapter = _adapter()
ctx = _ctx(channel_type="thread", channel_id="thread-ok")
with patch(
"backend.copilot.bot.handler.threads.is_subscribed",
new=AsyncMock(return_value=True),
):
assert await handler._resolve_target(ctx, adapter) == "thread-ok"
@pytest.mark.asyncio
async def test_channel_creates_and_subscribes_thread(self):
handler = MessageHandler(_api())
adapter = _adapter()
adapter.create_thread = AsyncMock(return_value="thread-created")
with patch(
"backend.copilot.bot.handler.threads.subscribe", new=AsyncMock()
) as subscribe:
result = await handler._resolve_target(_ctx(), adapter)
assert result == "thread-created"
subscribe.assert_awaited_once_with("discord", "thread-created")
@pytest.mark.asyncio
async def test_channel_falls_back_to_parent_when_thread_creation_fails(self):
handler = MessageHandler(_api())
adapter = _adapter()
adapter.create_thread = AsyncMock(return_value=None)
result = await handler._resolve_target(_ctx(channel_id="parent-chan"), adapter)
assert result == "parent-chan"
class TestBatching:
@pytest.mark.asyncio
async def test_concurrent_message_queues_when_processing(self):
"""Second caller with processing=True returns without starting a new stream."""
handler = MessageHandler(_api())
adapter = _adapter()
state = TargetState(processing=True)
handler._targets["target-1"] = state
await handler._enqueue_and_process(_ctx(text="second"), adapter, "target-1")
assert state.processing is True
assert state.pending == [("Bently", "user-1", "second")]
@pytest.mark.asyncio
async def test_target_state_cleared_after_drain(self):
handler = MessageHandler(_api())
adapter = _adapter()
stream_calls: list[list] = []
async def fake_stream_batch(batch, ctx, ad, tid):
stream_calls.append(list(batch))
handler._stream_batch = fake_stream_batch # type: ignore[method-assign]
await handler._enqueue_and_process(_ctx(text="hello"), adapter, "target-1")
assert stream_calls == [[("Bently", "user-1", "hello")]]
# Dict entry should be gone once processing finishes with empty pending.
assert "target-1" not in handler._targets
@pytest.mark.asyncio
async def test_drain_loop_picks_up_appended_messages(self):
"""Messages appended to pending mid-drain are processed in the next iter."""
handler = MessageHandler(_api())
adapter = _adapter()
state = TargetState()
handler._targets["target-1"] = state
seen: list[list] = []
async def fake_stream_batch(batch, ctx, ad, tid):
seen.append(list(batch))
if len(seen) == 1:
# Simulate another caller appending during the first stream.
state.pending.append(("Later", "u2", "follow-up"))
handler._stream_batch = fake_stream_batch # type: ignore[method-assign]
await handler._enqueue_and_process(_ctx(text="first"), adapter, "target-1")
assert seen == [
[("Bently", "user-1", "first")],
[("Later", "u2", "follow-up")],
]
assert "target-1" not in handler._targets
@pytest.mark.asyncio
async def test_duplicate_message_is_silently_dropped(self):
api = _api()
async def duplicate_stream(*args, **kwargs):
raise DuplicateChatMessageError("in flight")
yield "" # pragma: no cover
api.stream_chat = duplicate_stream
handler = MessageHandler(api)
adapter = _adapter()
with patch(
"backend.copilot.bot.handler.get_redis_async",
new=AsyncMock(return_value=AsyncMock(get=AsyncMock(return_value=None))),
):
await handler._stream_batch(
[("Bently", "u1", "hi")], _ctx(), adapter, "target-1"
)
adapter.send_message.assert_not_awaited()
class TestStreamFallback:
"""Covers the empty-response fallback, including the boundary-flush bug
where prior code posted 'AutoPilot didn't produce a response' even though
content had already been flushed mid-stream.
"""
@staticmethod
def _redis_patch():
return patch(
"backend.copilot.bot.handler.get_redis_async",
new=AsyncMock(return_value=AsyncMock(get=AsyncMock(return_value=None))),
)
@pytest.mark.asyncio
async def test_empty_stream_sends_fallback(self):
api = _api()
async def empty(*args, **kwargs):
if False:
yield ""
api.stream_chat = empty
handler = MessageHandler(api)
adapter = _adapter()
with TestStreamFallback._redis_patch():
await handler._stream_batch(
[("Bently", "u1", "hi")], _ctx(), adapter, "target-1"
)
msgs = [c.args[1] for c in adapter.send_message.await_args_list]
assert any("didn't produce a response" in m for m in msgs)
@pytest.mark.asyncio
async def test_whitespace_only_stream_sends_fallback(self):
api = _api()
async def whitespace(*args, **kwargs):
yield " "
yield "\n\n"
api.stream_chat = whitespace
handler = MessageHandler(api)
adapter = _adapter()
with TestStreamFallback._redis_patch():
await handler._stream_batch(
[("Bently", "u1", "hi")], _ctx(), adapter, "target-1"
)
msgs = [c.args[1] for c in adapter.send_message.await_args_list]
assert any("didn't produce a response" in m for m in msgs)
@pytest.mark.asyncio
async def test_content_flushed_mid_stream_does_not_trigger_fallback(self):
"""Regression: before the fix, a response that flushed exactly at a
boundary left buffer == "" and the fallback fired after real content
had already been posted.
"""
api = _api()
adapter = _adapter()
adapter.chunk_flush_at = 50
async def streaming_content(*args, **kwargs):
# Exactly flush_at chars → split_at_boundary returns the whole
# payload as the post and an empty remainder, so the stream ends
# with buffer == "". That USED to fall into the `elif not buffer`
# branch and send the "didn't produce a response" fallback.
yield "x" * 50
api.stream_chat = streaming_content
handler = MessageHandler(api)
with TestStreamFallback._redis_patch():
await handler._stream_batch(
[("Bently", "u1", "hi")], _ctx(), adapter, "target-1"
)
msgs = [c.args[1] for c in adapter.send_message.await_args_list]
assert not any("didn't produce a response" in m for m in msgs)
assert msgs == ["x" * 50]

View File

@@ -0,0 +1,80 @@
"""Text formatting helpers — message batching and chunk splitting."""
import re
# Matches a triple-backtick fence with an optional language tag. Used to tell
# whether a cut falls inside an open Markdown code block.
_CODE_FENCE = re.compile(r"```(\w*)")
def format_batch(batch: list[tuple[str, str, str]], platform: str) -> str:
"""Format one or more pending messages into a single prompt for AutoPilot.
Each batch entry is (username, user_id, text). When multiple messages are
batched together (because they arrived while the bot was streaming a prior
response), they're labelled individually so the LLM can address each.
"""
platform_display = platform.capitalize()
if len(batch) == 1:
username, user_id, text = batch[0]
return (
f"[Message sent by {username} ({platform_display} user ID: {user_id})]\n"
f"{text}"
)
lines = ["[Multiple messages — please address them together]"]
for username, user_id, text in batch:
lines.append(
f"\n[From {username} ({platform_display} user ID: {user_id})]\n{text}"
)
return "\n".join(lines)
def split_at_boundary(text: str, flush_at: int) -> tuple[str, str]:
"""Split text at a natural boundary to fit within a length limit.
Returns (postable_chunk, remaining_text).
Prefers: paragraph > newline > sentence end > space > hard cut.
If the cut lands inside a Markdown code fence (``\\`\\`\\``), the fence is
closed in the chunk and reopened at the start of the remainder so both
sides render correctly.
"""
if len(text) <= flush_at:
return text, ""
search_start = max(0, flush_at - 200)
search_region = text[search_start:flush_at]
for sep in ("\n\n", "\n"):
idx = search_region.rfind(sep)
if idx != -1:
cut = search_start + idx
return _balance_code_fences(text[:cut].rstrip(), text[cut:].lstrip("\n"))
for sep in (". ", "! ", "? "):
idx = search_region.rfind(sep)
if idx != -1:
cut = search_start + idx + len(sep)
return _balance_code_fences(text[:cut], text[cut:])
idx = search_region.rfind(" ")
if idx != -1:
cut = search_start + idx
return _balance_code_fences(text[:cut], text[cut:].lstrip())
return _balance_code_fences(text[:flush_at], text[flush_at:])
def _balance_code_fences(before: str, after: str) -> tuple[str, str]:
"""If ``before`` ends inside an open ``\\`\\`\\`` fence, close and reopen it.
Preserves the language tag from the opening fence so syntax highlighting
survives the split.
"""
fences = _CODE_FENCE.findall(before)
if len(fences) % 2 == 0:
return before, after
lang = fences[-1]
closed_before = f"{before.rstrip()}\n```"
reopened_after = f"```{lang}\n{after.lstrip()}"
return closed_before, reopened_after
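A minimal usage sketch (illustrative only, not part of the diff) of the fence rebalancing `split_at_boundary` documents above; the reply string and the `flush_at=1900` value are assumptions chosen so the cut lands inside the open fence.

```python
# Sketch: splitting a long reply that opens a ```python fence before the cut point.
reply = "Here is the patch:\n```python\n" + "print('x')\n" * 200 + "```\nDone."
chunk, rest = split_at_boundary(reply, flush_at=1900)

# The posted chunk closes the fence it opened; the remainder reopens it with the
# same language tag, so both halves render as highlighted code when sent separately.
assert chunk.count("```") % 2 == 0
assert rest.startswith("```python")
```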

View File

@@ -0,0 +1,105 @@
"""Tests for message batching + boundary splitting."""
from .text import _balance_code_fences, format_batch, split_at_boundary
class TestFormatBatch:
def test_single_message_has_header(self):
result = format_batch([("Bently", "123", "hello")], "discord")
assert result == "[Message sent by Bently (Discord user ID: 123)]\nhello"
def test_multi_message_labels_each_sender(self):
result = format_batch(
[
("Alice", "a1", "first"),
("Bob", "b2", "second"),
],
"discord",
)
assert "[Multiple messages" in result
assert "[From Alice (Discord user ID: a1)]\nfirst" in result
assert "[From Bob (Discord user ID: b2)]\nsecond" in result
def test_platform_name_is_capitalized(self):
result = format_batch([("u", "1", "x")], "telegram")
assert "Telegram user ID" in result
class TestSplitAtBoundary:
def test_short_text_returns_unchanged(self):
before, after = split_at_boundary("short", 100)
assert before == "short"
assert after == ""
def test_splits_at_paragraph_boundary(self):
text = "first paragraph.\n\nsecond paragraph that is long enough"
before, after = split_at_boundary(text, 20)
assert before == "first paragraph."
assert after == "second paragraph that is long enough"
def test_splits_at_newline_when_no_paragraph(self):
text = "line one\nline two line three line four line five"
before, after = split_at_boundary(text, 15)
assert before == "line one"
assert after == "line two line three line four line five"
def test_splits_at_sentence_when_no_newline(self):
text = "First sentence. Second sentence is quite a bit longer here."
before, after = split_at_boundary(text, 20)
assert before == "First sentence. "
assert after == "Second sentence is quite a bit longer here."
def test_falls_back_to_space_split(self):
text = "word " * 50
before, after = split_at_boundary(text, 30)
assert not before.endswith(" ")
# Rejoining drops one space at the cut, but no characters other
# than whitespace should be lost.
rejoined = (before + " " + after).replace("  ", " ").strip()
assert rejoined == text.strip()
def test_hard_cut_on_single_long_token(self):
text = "a" * 500
before, after = split_at_boundary(text, 100)
assert len(before) == 100
assert after == "a" * 400
class TestBalanceCodeFences:
def test_balanced_code_unchanged(self):
before = "prose\n```py\nprint('x')\n```\ntail"
after = "more"
b, a = _balance_code_fences(before, after)
assert b == before
assert a == after
def test_open_fence_gets_closed_and_reopened(self):
before = "prose\n```py\nprint('x')"
after = "print('y')\n```\ntail"
b, a = _balance_code_fences(before, after)
assert b.endswith("```")
assert a.startswith("```py\n")
def test_reopens_with_no_lang_when_opener_had_none(self):
before = "```\nsome code here"
after = "more code\n```"
b, a = _balance_code_fences(before, after)
assert b.endswith("\n```")
assert a.startswith("```\n")
def test_preserves_latest_language_when_multiple_fences(self):
before = "```py\nprint()\n```\nmiddle\n```ts\nconst x = 1"
after = "const y = 2\n```"
b, a = _balance_code_fences(before, after)
assert b.endswith("```")
assert a.startswith("```ts\n")
class TestSplitAtBoundaryWithCodeFences:
def test_split_inside_fence_rebalances(self):
code_block = "```python\n" + ("line\n" * 500) + "```\nafter"
before, after = split_at_boundary(code_block, 300)
# ``before`` must close the fence it opened.
assert before.count("```") % 2 == 0
# ``after`` must reopen with the same language tag.
assert after.lstrip().startswith("```python")

View File

@@ -0,0 +1,25 @@
"""Thread subscription tracking.
When the bot creates a thread in response to an @mention, we record the
thread ID so subsequent messages in it don't require another mention.
Subscriptions live in Redis with a 7-day TTL — stale threads age out
automatically.
"""
from backend.data.redis_client import get_redis_async
THREAD_SUBSCRIPTION_TTL = 7 * 86400 # 7 days
def _key(platform: str, thread_id: str) -> str:
return f"copilot-bot:thread:{platform}:{thread_id}"
async def is_subscribed(platform: str, thread_id: str) -> bool:
redis = await get_redis_async()
return bool(await redis.get(_key(platform, thread_id)))
async def subscribe(platform: str, thread_id: str) -> None:
redis = await get_redis_async()
await redis.set(_key(platform, thread_id), "1", ex=THREAD_SUBSCRIPTION_TTL)
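For orientation only: a hedged sketch of how the handler's `_resolve_target` (exercised by the handler tests earlier in this diff) might tie these helpers together. Only `subscribe(platform, thread_id)` is confirmed by the test assertions; the `adapter.create_thread` call signature below is an assumption.

```python
# Illustrative sketch; the real logic lives in handler._resolve_target.
# (The DM case, which reuses ctx.channel_id directly, is omitted here.)
async def resolve_thread_target(ctx, adapter) -> str | None:
    if ctx.channel_type == "thread":
        # Keep replying only in threads the bot created and is still subscribed to.
        if await is_subscribed(ctx.platform, ctx.channel_id):
            return ctx.channel_id
        return None
    # Channel mention: spin up a thread, remember it, fall back to the parent channel.
    thread_id = await adapter.create_thread(ctx.channel_id)  # signature assumed
    if thread_id is None:
        return ctx.channel_id
    await subscribe(ctx.platform, thread_id)
    return thread_id
```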

View File

@@ -0,0 +1,55 @@
"""Tests for Redis-backed thread subscription tracking."""
from unittest.mock import AsyncMock, patch
import pytest
from . import threads
@pytest.fixture
def redis_mock():
mock = AsyncMock()
mock.get = AsyncMock()
mock.set = AsyncMock()
with patch("backend.copilot.bot.threads.get_redis_async", return_value=mock):
yield mock
class TestSubscribe:
@pytest.mark.asyncio
async def test_writes_key_with_ttl(self, redis_mock):
await threads.subscribe("discord", "thread-123")
redis_mock.set.assert_awaited_once_with(
"copilot-bot:thread:discord:thread-123",
"1",
ex=threads.THREAD_SUBSCRIPTION_TTL,
)
@pytest.mark.asyncio
async def test_key_includes_platform(self, redis_mock):
await threads.subscribe("telegram", "t-1")
key = redis_mock.set.await_args.args[0]
assert "telegram" in key
assert "t-1" in key
class TestIsSubscribed:
@pytest.mark.asyncio
async def test_returns_true_when_present(self, redis_mock):
redis_mock.get.return_value = "1"
assert await threads.is_subscribed("discord", "thread-1") is True
@pytest.mark.asyncio
async def test_returns_false_when_missing(self, redis_mock):
redis_mock.get.return_value = None
assert await threads.is_subscribed("discord", "thread-1") is False
@pytest.mark.asyncio
async def test_uses_same_key_as_subscribe(self, redis_mock):
redis_mock.get.return_value = None
await threads.is_subscribed("discord", "thread-1")
await threads.subscribe("discord", "thread-1")
read_key = redis_mock.get.await_args.args[0]
write_key = redis_mock.set.await_args.args[0]
assert read_key == write_key

View File

@@ -395,11 +395,17 @@ class ChatConfig(BaseSettings):
@property
def openrouter_active(self) -> bool:
"""True when OpenRouter is enabled AND credentials are usable.
"""True when OpenRouter config is shape-valid (flag + credentials).
Single source of truth for "will the SDK route through OpenRouter?".
Checks the flag *and* that ``api_key`` + a valid ``base_url`` are
present — mirrors the fallback logic in ``build_sdk_env``.
Indicates whether OpenRouter settings are present and usable —
``use_openrouter`` set, plus ``api_key`` + a valid ``base_url``,
mirroring the fallback logic in ``build_sdk_env``.
Note: this is a **config-shape check only**. Runtime SDK routing
is governed by ``effective_transport`` — subscription mode
bypasses OpenRouter entirely even when these fields are set, so
callers asking "will the SDK actually route through OpenRouter
for this turn?" should use ``effective_transport`` instead.
"""
if not self.use_openrouter:
return False
@@ -408,6 +414,34 @@ class ChatConfig(BaseSettings):
base = base[:-3]
return bool(self.api_key and base and base.startswith("http"))
@property
def effective_transport(
self,
) -> Literal["subscription", "openrouter", "direct_anthropic"]:
"""The transport the SDK CLI subprocess actually uses for this turn.
Detection order:
1. ``subscription`` — when ``use_claude_code_subscription`` is True
the CLI uses OAuth from the keychain or
``CLAUDE_CODE_OAUTH_TOKEN`` and ignores ``CHAT_BASE_URL`` /
``CHAT_API_KEY`` entirely (see ``build_sdk_env`` mode 1).
2. ``openrouter`` — when ``openrouter_active`` (use_openrouter +
api_key + a valid base_url).
3. ``direct_anthropic`` — fallback (CLI talks to api.anthropic.com
with ``ANTHROPIC_API_KEY`` from parent env).
Use this when the question is "which model-name format will the
CLI accept?" — the OpenRouter slug ``anthropic/claude-opus-4.7``
works through the proxy but is rejected by the subscription /
direct-Anthropic transports.
"""
if self.use_claude_code_subscription:
return "subscription"
if self.openrouter_active:
return "openrouter"
return "direct_anthropic"
@property
def e2b_active(self) -> bool:
"""True when E2B is enabled and the API key is present.
@@ -532,9 +566,13 @@ class ChatConfig(BaseSettings):
(``claude_agent_fallback_model`` via ``_resolve_fallback_model``).
Skipped when ``use_claude_code_subscription=True`` because the
subscription path resolves the model to ``None`` (CLI default)
and never calls ``_normalize_model_name``. Empty fallback strings
are also skipped (no fallback configured).
subscription path normally resolves the static config to ``None``
(CLI default). An LD-served override under subscription does
flow through ``_normalize_model_name``; the runtime guard first
falls back to the tier default, and only avoids a request error
when that default is itself valid (otherwise the original LD
ValueError is re-raised — see ``_resolve_sdk_model_for_request``).
Empty fallback strings are also skipped (no fallback configured).
"""
if self.use_claude_code_subscription:
return self
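A hedged usage sketch (not part of the diff) of the `effective_transport` contract described above: picking a model-name format per transport. The prefixed OpenRouter slug appears in the diff; the bare direct-Anthropic name is a placeholder assumption.

```python
def pick_model_name(config: ChatConfig) -> str | None:
    transport = config.effective_transport
    if transport == "subscription":
        # OAuth path ignores CHAT_BASE_URL / CHAT_API_KEY; let the CLI use its default model.
        return None
    if transport == "openrouter":
        # The proxy accepts provider-prefixed slugs.
        return "anthropic/claude-opus-4.7"
    # direct_anthropic: prefixed slugs are rejected, so use a bare model name
    # (placeholder value, not taken from the diff).
    return "claude-opus-4.7"
```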

View File

@@ -85,7 +85,7 @@ class CoPilotExecutor(AppProcess):
self._run_client = None
self._task_locks: dict[str, ClusterLock] = {}
self._active_tasks_lock = threading.Lock()
self._active_tasks_lock_obj: threading.Lock | None = None
# ============ Main Entry Points (AppProcess interface) ============ #
@@ -502,6 +502,12 @@ class CoPilotExecutor(AppProcess):
# ============ Lazy-initialized Properties ============ #
@property
def _active_tasks_lock(self) -> threading.Lock:
if self._active_tasks_lock_obj is None:
self._active_tasks_lock_obj = threading.Lock()
return self._active_tasks_lock_obj
@property
def cancel_thread(self) -> threading.Thread:
if self._cancel_thread is None:

View File

@@ -35,12 +35,14 @@ SHUTDOWN_ERROR_MESSAGE = (
"Copilot executor shut down before this turn finished. Please retry."
)
# Max time execute() blocks after calling future.cancel() / when draining a
# soon-to-be-cancelled future. Gives _execute_async's own finally a chance to
# publish the accurate terminal state over the Redis CAS; long enough to let
# an in-flight Redis call settle, short enough that shutdown doesn't stall.
# Max time execute() blocks after requesting async turn cancellation. The worker
# waits for normal cleanup so late stream writes do not race the manager, but it
# must still escape to the sync fail-close safety net if cleanup wedges.
_CANCEL_GRACE_SECONDS = 5.0
# How long to wait before logging again that a cancelled turn is still draining.
_CANCEL_DRAIN_LOG_INTERVAL_SECONDS = 1.0
# Max time the sync safety net itself spends on a single Redis CAS. Without
# this bound the whole point of ``sync_fail_close_session`` is defeated —
# ``mark_session_completed`` would hang on the same broken Redis that caused
@@ -92,9 +94,11 @@ def sync_fail_close_session(
timeout=_FAIL_CLOSE_REDIS_TIMEOUT,
)
coro = _bounded()
try:
future = asyncio.run_coroutine_threadsafe(_bounded(), execution_loop)
future = asyncio.run_coroutine_threadsafe(coro, execution_loop)
except RuntimeError as e:
coro.close()
# execution_loop is closed — happens if cleanup() already ran the
# per-worker teardown. Nothing we can do; let the stale-session
# watchdog reap it.
@@ -336,8 +340,7 @@ class CoPilotProcessor:
Thin wrapper around :meth:`_execute`. The ``try/finally`` here
guarantees :func:`sync_fail_close_session` runs on every exit
path — normal completion, exception, or a wedged event loop
that escapes via :data:`_CANCEL_GRACE_SECONDS` timeout.
path — normal completion or exception.
``mark_session_completed`` is an atomic CAS on
``status == "running"``, so when the async path already wrote a
terminal state the sync call is a cheap no-op.
@@ -370,40 +373,92 @@ class CoPilotProcessor:
that lives in :func:`sync_fail_close_session` which the outer
:meth:`execute` always invokes on exit.
"""
task_ready: concurrent.futures.Future[asyncio.Task] = (
concurrent.futures.Future()
)
async def run_async_turn():
task = asyncio.current_task()
if task is not None and not task_ready.done():
task_ready.set_result(task)
return await self._execute_async(entry, cancel, cluster_lock, log)
future = asyncio.run_coroutine_threadsafe(
self._execute_async(entry, cancel, cluster_lock, log),
run_async_turn(),
self.execution_loop,
)
# Wait for completion, checking cancel periodically
while not future.done():
cancel_requested = False
cancel_started_at: float | None = None
last_cancel_log_at: float | None = None
def request_cancel() -> None:
nonlocal cancel_requested, cancel_started_at, last_cancel_log_at
log.info("Cancellation requested")
try:
task = task_ready.result(timeout=0)
except concurrent.futures.TimeoutError:
# Sub-millisecond race: ``run_coroutine_threadsafe`` returned
# before ``run_async_turn`` actually started, so
# ``task_ready.set_result`` has not run yet. ``future.cancel``
# on a ``concurrent.futures.Future`` whose underlying task may
# already be picked up by the loop is best-effort — frequently
# a no-op. The slow path is intentional: ``cancel.is_set()``
# is polled inside ``_execute_async`` and the bounded
# ``_CANCEL_GRACE_SECONDS`` drain below force-cancels and falls
# through to ``sync_fail_close_session``, so the worst-case
# observable behaviour is "cancel takes ~5s in this rare race"
# rather than a stuck session.
future.cancel()
else:
self.execution_loop.call_soon_threadsafe(task.cancel)
cancel_requested = True
cancel_started_at = time.monotonic()
last_cancel_log_at = cancel_started_at
def log_cancel_wait() -> None:
nonlocal last_cancel_log_at
if cancel_started_at is None or last_cancel_log_at is None:
return
now = time.monotonic()
if now - last_cancel_log_at < _CANCEL_DRAIN_LOG_INTERVAL_SECONDS:
return
elapsed = now - cancel_started_at
log.warning(f"Waiting for cancelled turn to drain ({elapsed:.1f}s elapsed)")
last_cancel_log_at = now
def cancel_drain_timed_out() -> bool:
if cancel_started_at is None:
return False
elapsed = time.monotonic() - cancel_started_at
if elapsed < _CANCEL_GRACE_SECONDS:
return False
log.warning(
f"Cancelled turn did not drain within {_CANCEL_GRACE_SECONDS:.1f}s; "
"falling through to sync fail-close"
)
future.cancel()
return True
# Wait for completion, checking cancel periodically. A cancellation
# request waits for normal async cleanup, but remains bounded so the
# worker does not refresh the per-session lock forever on a wedged turn.
while True:
try:
future.result(timeout=1.0)
except asyncio.TimeoutError:
if cancel.is_set():
log.info("Cancellation requested")
future.cancel()
# Give _execute_async's own finally a short window to
# publish its accurate terminal state before the outer
# sync safety net fires.
try:
future.result(timeout=_CANCEL_GRACE_SECONDS)
except BaseException:
pass
return
except concurrent.futures.CancelledError:
if cancel_requested or cancel.is_set():
return
cluster_lock.refresh()
if not future.cancelled():
# Bounded timeout so a wedged event loop can't trap us here —
# on timeout we escape to execute()'s finally and the sync
# safety net fires.
try:
future.result(timeout=_CANCEL_GRACE_SECONDS)
raise
except concurrent.futures.TimeoutError:
log.warning(
"Future did not complete within grace window; "
"falling through to sync fail-close"
)
if cancel.is_set() and not cancel_requested:
request_cancel()
elif cancel_requested and cancel_started_at is not None:
if cancel_drain_timed_out():
return
log_cancel_wait()
cluster_lock.refresh()
async def _execute_async(
self,

View File

@@ -496,3 +496,108 @@ class TestExecuteSafetyNet:
assert call_log == [
"sync-ok"
], f"expected sync_fail_close_session to run once, got {call_log!r}"
def test_cancel_waits_for_async_task_to_finish(self, exec_loop) -> None:
"""A cancel request must not let ``_execute`` return while the
underlying asyncio task is still cleaning up. Returning early would
make the manager release the session lock while late stream writes
are still possible."""
proc = CoPilotProcessor()
self._attach_exec_loop(proc, exec_loop)
started = threading.Event()
cancel_seen = threading.Event()
release_cleanup = threading.Event()
finished = threading.Event()
async def _stubborn_cancel(*_args, **_kwargs):
started.set()
try:
await asyncio.sleep(3600)
except asyncio.CancelledError:
cancel_seen.set()
while not release_cleanup.is_set():
await asyncio.sleep(0.01)
finally:
finished.set()
proc._execute_async = _stubborn_cancel # type: ignore[method-assign]
cancel = threading.Event()
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
try:
fut = pool.submit(
proc._execute,
_make_entry(),
cancel,
MagicMock(),
_make_log(),
)
assert started.wait(timeout=5)
cancel.set()
assert cancel_seen.wait(timeout=5)
assert not fut.done()
release_cleanup.set()
fut.result(timeout=5)
assert finished.is_set()
finally:
pool.shutdown(wait=True)
def test_cancel_wait_has_bounded_escape_hatch(self, exec_loop) -> None:
"""A wedged async cleanup must not keep the worker refreshing the
session lock forever; after the grace window, ``_execute`` returns
so ``execute`` can run the sync fail-close safety net."""
proc = CoPilotProcessor()
self._attach_exec_loop(proc, exec_loop)
started = threading.Event()
cancel_seen = threading.Event()
release_cleanup = threading.Event()
finished = threading.Event()
async def _wedged_cancel(*_args, **_kwargs):
started.set()
try:
await asyncio.sleep(3600)
except asyncio.CancelledError:
cancel_seen.set()
while not release_cleanup.is_set():
try:
await asyncio.sleep(0.01)
except asyncio.CancelledError:
pass
finally:
finished.set()
proc._execute_async = _wedged_cancel # type: ignore[method-assign]
cancel = threading.Event()
cluster_lock = MagicMock()
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
try:
with patch(
"backend.copilot.executor.processor._CANCEL_GRACE_SECONDS",
0.05,
):
fut = pool.submit(
proc._execute,
_make_entry(),
cancel,
cluster_lock,
_make_log(),
)
assert started.wait(timeout=5)
cancel.set()
assert cancel_seen.wait(timeout=5)
fut.result(timeout=5)
assert not finished.is_set()
assert cluster_lock.refresh.call_count < 10
release_cleanup.set()
assert finished.wait(timeout=5)
finally:
pool.shutdown(wait=True)

View File

@@ -71,7 +71,9 @@ COPILOT_EXECUTION_EXCHANGE = Exchange(
durable=True,
auto_delete=False,
)
COPILOT_EXECUTION_QUEUE_NAME = "copilot_execution_queue"
# ``_v2`` suffix marks the classic→quorum rollover; old-image consumers
# drain the unsuffixed queue. Orphans cleaned up in a follow-up PR.
COPILOT_EXECUTION_QUEUE_NAME = "copilot_execution_queue_v2"
COPILOT_EXECUTION_ROUTING_KEY = "copilot.run"
COPILOT_CANCEL_EXCHANGE = Exchange(
@@ -80,7 +82,7 @@ COPILOT_CANCEL_EXCHANGE = Exchange(
durable=True,
auto_delete=False,
)
COPILOT_CANCEL_QUEUE_NAME = "copilot_cancel_queue"
COPILOT_CANCEL_QUEUE_NAME = "copilot_cancel_queue_v2"
def get_session_lock_key(session_id: str) -> str:
@@ -118,6 +120,9 @@ def create_copilot_queue_config() -> RabbitMQConfig:
durable=True,
auto_delete=False,
arguments={
# Quorum (not classic mirrored) for leader election + stronger
# replication across RabbitMQ 4.x cluster nodes.
"x-queue-type": "quorum",
# Consumer timeout matches the pod graceful-shutdown window so a
# rolling deploy never forces redelivery of a turn that the pod
# is still legitimately finishing.
@@ -131,7 +136,7 @@ def create_copilot_queue_config() -> RabbitMQConfig:
# limit), apply a policy:
#
# rabbitmqctl set_policy copilot-consumer-timeout \
# "^copilot_execution_queue$" \
# "^copilot_execution_queue_v2$" \
# '{"consumer-timeout": 21600000}' \
# --apply-to queues
#
@@ -139,8 +144,7 @@ def create_copilot_queue_config() -> RabbitMQConfig:
# to match the code's value the policy is redundant for new
# pods and can be removed after a stable deploy if desired —
# but it's harmless to leave in place.
"x-consumer-timeout": COPILOT_CONSUMER_TIMEOUT_SECONDS
* 1000,
"x-consumer-timeout": COPILOT_CONSUMER_TIMEOUT_SECONDS * 1000,
},
)
cancel_queue = Queue(
@@ -149,6 +153,7 @@ def create_copilot_queue_config() -> RabbitMQConfig:
routing_key="", # not used for FANOUT
durable=True,
auto_delete=False,
arguments={"x-queue-type": "quorum"},
)
return RabbitMQConfig(
vhost="/",
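As a rough illustration of what the quorum-queue arguments above amount to at the AMQP level, here is a sketch assuming plain pika against a local broker; the project itself declares queues through its own RabbitMQConfig wrapper, so none of this is taken from the diff beyond the queue name and arguments.

```python
import pika

connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
channel = connection.channel()
channel.queue_declare(
    queue="copilot_execution_queue_v2",
    durable=True,  # quorum queues must be durable
    arguments={
        "x-queue-type": "quorum",
        # Milliseconds; 6 h matches the consumer-timeout policy shown above.
        "x-consumer-timeout": 6 * 60 * 60 * 1000,
    },
)
```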

View File

@@ -1,9 +1,22 @@
from unittest.mock import MagicMock
import pytest
from .falkordb_driver import AutoGPTFalkorDriver
def test_build_fulltext_query_uses_unquoted_group_ids_for_falkordb() -> None:
driver = AutoGPTFalkorDriver()
@pytest.fixture
def driver() -> AutoGPTFalkorDriver:
# ``build_fulltext_query`` is a pure string-builder that never touches
# the FalkorDB client; injecting a mock avoids the eager Redis probe
# that the upstream ``FalkorDriver.__init__`` runs against
# ``localhost:6379``.
return AutoGPTFalkorDriver(falkor_db=MagicMock())
def test_build_fulltext_query_uses_unquoted_group_ids_for_falkordb(
driver: AutoGPTFalkorDriver,
) -> None:
query = driver.build_fulltext_query(
"Sarah",
group_ids=["user_883cc9da-fe37-4863-839b-acba022bf3ef"],
@@ -13,18 +26,18 @@ def test_build_fulltext_query_uses_unquoted_group_ids_for_falkordb() -> None:
assert '"user_883cc9da-fe37-4863-839b-acba022bf3ef"' not in query
def test_build_fulltext_query_joins_multiple_group_ids_with_or() -> None:
driver = AutoGPTFalkorDriver()
def test_build_fulltext_query_joins_multiple_group_ids_with_or(
driver: AutoGPTFalkorDriver,
) -> None:
query = driver.build_fulltext_query("Sarah", group_ids=["user_a", "user_b"])
assert query == "(@group_id:user_a|user_b) (Sarah)"
def test_stopwords_only_query_returns_group_filter_only() -> None:
def test_stopwords_only_query_returns_group_filter_only(
driver: AutoGPTFalkorDriver,
) -> None:
"""Line 25: sanitized_query is empty (all stopwords) but group_ids present."""
driver = AutoGPTFalkorDriver()
# "the" is a common stopword — the query should reduce to just the group filter.
query = driver.build_fulltext_query(
"the",
@@ -34,10 +47,10 @@ def test_stopwords_only_query_returns_group_filter_only() -> None:
assert query == "(@group_id:user_abc)"
def test_query_without_group_ids_returns_parenthesized_query() -> None:
def test_query_without_group_ids_returns_parenthesized_query(
driver: AutoGPTFalkorDriver,
) -> None:
"""Line 27: sanitized_query has content but no group_ids provided."""
driver = AutoGPTFalkorDriver()
query = driver.build_fulltext_query("Sarah", group_ids=None)
assert query == "(Sarah)"

View File

@@ -21,8 +21,10 @@ from backend.copilot.pending_messages import (
drain_pending_messages,
format_pending_as_user_message,
push_pending_message,
push_pending_message_if_session_running,
)
from backend.copilot.stream_registry import get_session as get_active_session_meta
from backend.copilot.stream_registry import get_session_meta_key
from backend.data.redis_client import get_redis_async
from backend.data.redis_helpers import incr_with_ttl
from backend.data.workspace import resolve_workspace_files
@@ -44,8 +46,8 @@ _PENDING_CALL_KEY_PREFIX = "copilot:pending:calls:"
async def is_turn_in_flight(session_id: str) -> bool:
"""Return ``True`` when a copilot turn is actively running for *session_id*.
Used by the unified POST /stream entry point and the autopilot block so
a second message arriving while an earlier turn is still executing gets
Used by the HTTP pending-message endpoint and the autopilot block so a
second message arriving while an earlier turn is still executing gets
queued into the pending buffer instead of racing the in-flight turn on
the cluster lock.
"""
@@ -54,15 +56,14 @@ async def is_turn_in_flight(session_id: str) -> bool:
class QueuePendingMessageResponse(BaseModel):
"""Response returned by ``POST /stream`` with status 202 when a message
is queued because the session already has a turn in flight.
"""Response returned when a message is queued because the session already
has a turn in flight.
- ``buffer_length``: how many messages are now in the session's
pending buffer (after this push)
- ``max_buffer_length``: the per-session cap (server-side constant)
- ``turn_in_flight``: ``True`` if a copilot turn was running when
we checked — purely informational for UX feedback. Always ``True``
for responses from ``POST /stream`` with status 202.
we checked — purely informational for UX feedback.
"""
buffer_length: int
@@ -76,11 +77,12 @@ async def queue_user_message(
message: str,
context: PendingMessageContext | None = None,
file_ids: list[str] | None = None,
require_turn_in_flight: bool = False,
) -> QueuePendingMessageResponse:
"""Push *message* into the per-session pending buffer.
The shared primitive for "a message arrived while a turn is in flight"
called from the unified POST /stream handler and the autopilot block.
called from the HTTP pending-message path and the autopilot block.
Call-frequency rate limiting is the caller's responsibility (HTTP path
enforces it; internal block callers skip it).
"""
@@ -89,6 +91,18 @@ async def queue_user_message(
file_ids=file_ids or [],
context=context,
)
if require_turn_in_flight:
new_len = await push_pending_message_if_session_running(
session_id,
pending,
session_meta_key=get_session_meta_key(session_id),
)
return QueuePendingMessageResponse(
buffer_length=new_len or 0,
max_buffer_length=MAX_PENDING_MESSAGES,
turn_in_flight=new_len is not None,
)
new_len = await push_pending_message(session_id, pending)
return QueuePendingMessageResponse(
buffer_length=new_len,
@@ -107,7 +121,7 @@ async def queue_pending_for_http(
) -> QueuePendingMessageResponse:
"""HTTP-facing wrapper around :func:`queue_user_message`.
Owns the HTTP-only concerns that sat inline in ``stream_chat_post``:
Owns the HTTP-only concerns for the pending-message route:
1. Per-user call-rate cap (429 on overflow).
2. File-ID sanitisation against the user's own workspace.
@@ -116,19 +130,8 @@ async def queue_pending_for_http(
Raises :class:`HTTPException` with status 429 if the rate cap is hit;
otherwise returns the ``QueuePendingMessageResponse`` the handler can
serialise 1:1 into the 202 body.
serialise 1:1.
"""
call_count = await check_pending_call_rate(user_id)
if call_count > PENDING_CALL_LIMIT:
raise HTTPException(
status_code=429,
detail=(
f"Too many queued message requests this minute: limit is "
f"{PENDING_CALL_LIMIT} per {PENDING_CALL_WINDOW_SECONDS}s "
"across all sessions"
),
)
sanitized_file_ids: list[str] | None = None
if file_ids:
files = await resolve_workspace_files(user_id, file_ids)
@@ -141,12 +144,41 @@ async def queue_pending_for_http(
# typos, but the upstream ``StreamChatRequest.context: dict[str, str]``
# is already schemaless, so the strict mode adds no real safety.
queue_context = PendingMessageContext.model_validate(context) if context else None
return await queue_user_message(
# Push first via the Lua CAS gate. Bumping the per-user call-rate
# counter BEFORE the push would charge a budget tick on every TOCTOU
# loss against turn completion (status flips running→completed between
# the FE's is_turn_in_flight check and our gate), which both this
# endpoint and the POST /stream queue-fall-through can hit. Pushing
# first lets the gate own the no-op short-circuit.
response = await queue_user_message(
session_id=session_id,
message=message,
context=queue_context,
file_ids=sanitized_file_ids,
require_turn_in_flight=True,
)
if not response.turn_in_flight:
raise HTTPException(
status_code=409,
detail="Session has no active turn. Start a new turn with POST /stream.",
)
# Push landed — now charge the rate counter. If this tick crosses the
# limit we still keep the queued message (next drain will pick it up)
# but report 429 so the client backs off further pushes.
call_count = await check_pending_call_rate(user_id)
if call_count > PENDING_CALL_LIMIT:
raise HTTPException(
status_code=429,
detail=(
f"Too many queued message requests this minute: limit is "
f"{PENDING_CALL_LIMIT} per {PENDING_CALL_WINDOW_SECONDS}s "
"across all sessions"
),
)
return response
async def check_pending_call_rate(user_id: str) -> int:
@@ -366,14 +398,35 @@ async def persist_pending_as_user_rows(
transcript_builder.restore(transcript_snapshot)
if on_rollback is not None:
on_rollback(session_anchor)
# ``push_pending_message`` uses the bounded ``capped_rpush`` (LTRIM
# to ``MAX_PENDING_MESSAGES``). If ≥``MAX_PENDING_MESSAGES`` fresh
# follow-ups arrived between the original drain and this rollback
# (heavy typing across a tool boundary), the LTRIM drops oldest
# entries — which can include the ones we just re-pushed. The model
# already saw that content (mid-turn injection earlier in the
# turn), but no DB row lands so the user sees no UI bubble.
# Surface a warning so the bounded data-loss path is visible in
# prod (it is rare and would otherwise be observable only via the
# absence of a bubble).
rollback_buffer_at_cap = False
for pm in pending:
try:
await push_pending_message(session.session_id, pm)
new_length = await push_pending_message(session.session_id, pm)
if new_length >= MAX_PENDING_MESSAGES:
rollback_buffer_at_cap = True
except Exception:
logger.exception(
"%s Failed to re-queue mid-turn follow-up on rollback",
log_prefix,
)
if rollback_buffer_at_cap:
logger.warning(
"%s Rollback re-push hit pending-buffer cap (MAX=%d); a "
"previously queued follow-up may have been LTRIM-displaced "
"(silent UI-bubble drop). Investigate if observed.",
log_prefix,
MAX_PENDING_MESSAGES,
)
return False
logger.info(

View File

@@ -4,17 +4,20 @@ from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi import HTTPException
from backend.copilot import pending_message_helpers as helpers_module
from backend.copilot.pending_message_helpers import (
PENDING_CALL_LIMIT,
QueuePendingMessageResponse,
check_pending_call_rate,
combine_pending_with_current,
drain_pending_safe,
insert_pending_before_last,
persist_session_safe,
queue_pending_for_http,
)
from backend.copilot.pending_messages import PendingMessage
from backend.copilot.pending_messages import MAX_PENDING_MESSAGES, PendingMessage
# ── check_pending_call_rate ────────────────────────────────────────────
@@ -46,6 +49,112 @@ async def test_check_pending_call_rate_fails_open_on_redis_error(
assert result == 0
# ── queue_pending_for_http: gate-then-bump ordering ───────────────────
@pytest.mark.asyncio
async def test_queue_pending_does_not_charge_rate_on_toctou_409(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When the Lua gate refuses the push because the turn just completed,
the per-user call-rate counter must NOT have been incremented — bumping
it before the gate would charge a budget tick for every TOCTOU loss
against turn completion (race that both this endpoint and the POST
/stream queue-fall-through can trigger)."""
monkeypatch.setattr(
helpers_module,
"queue_user_message",
AsyncMock(
return_value=QueuePendingMessageResponse(
buffer_length=0,
max_buffer_length=MAX_PENDING_MESSAGES,
turn_in_flight=False,
)
),
)
rate_mock = AsyncMock(return_value=1)
monkeypatch.setattr(helpers_module, "check_pending_call_rate", rate_mock)
monkeypatch.setattr(
helpers_module, "resolve_workspace_files", AsyncMock(return_value=[])
)
with pytest.raises(HTTPException) as exc_info:
await queue_pending_for_http(
session_id="sess-1",
user_id="user-1",
message="hi",
context=None,
file_ids=None,
)
assert exc_info.value.status_code == 409
rate_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_queue_pending_charges_rate_only_after_successful_push(
monkeypatch: pytest.MonkeyPatch,
) -> None:
response = QueuePendingMessageResponse(
buffer_length=2,
max_buffer_length=MAX_PENDING_MESSAGES,
turn_in_flight=True,
)
queue_mock = AsyncMock(return_value=response)
monkeypatch.setattr(helpers_module, "queue_user_message", queue_mock)
rate_mock = AsyncMock(return_value=PENDING_CALL_LIMIT)
monkeypatch.setattr(helpers_module, "check_pending_call_rate", rate_mock)
monkeypatch.setattr(
helpers_module, "resolve_workspace_files", AsyncMock(return_value=[])
)
result = await queue_pending_for_http(
session_id="sess-1",
user_id="user-1",
message="hi",
context=None,
file_ids=None,
)
assert result is response
queue_mock.assert_awaited_once()
rate_mock.assert_awaited_once_with("user-1")
@pytest.mark.asyncio
async def test_queue_pending_429_after_push_when_limit_exceeded(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When the post-push rate counter crosses the limit, the message stays
in the buffer (next drain will pick it up) but the response is 429 so
the client backs off."""
response = QueuePendingMessageResponse(
buffer_length=3,
max_buffer_length=MAX_PENDING_MESSAGES,
turn_in_flight=True,
)
queue_mock = AsyncMock(return_value=response)
monkeypatch.setattr(helpers_module, "queue_user_message", queue_mock)
monkeypatch.setattr(
helpers_module,
"check_pending_call_rate",
AsyncMock(return_value=PENDING_CALL_LIMIT + 1),
)
monkeypatch.setattr(
helpers_module, "resolve_workspace_files", AsyncMock(return_value=[])
)
with pytest.raises(HTTPException) as exc_info:
await queue_pending_for_http(
session_id="sess-1",
user_id="user-1",
message="hi",
context=None,
file_ids=None,
)
assert exc_info.value.status_code == 429
queue_mock.assert_awaited_once()
@pytest.mark.asyncio
async def test_check_pending_call_rate_at_limit(
monkeypatch: pytest.MonkeyPatch,

View File

@@ -29,32 +29,21 @@ from typing import Any, cast
from pydantic import BaseModel, Field, ValidationError
from backend.data.redis_client import get_redis_async
from backend.data.redis_helpers import capped_rpush
from backend.data.redis_helpers import capped_rpush, capped_rpush_if_hash_field
logger = logging.getLogger(__name__)
# Per-session cap. Higher values risk a runaway consumer; lower values
# risk dropping user input under heavy typing. 10 was chosen as a
# reasonable ceiling — a user typing faster than the copilot can drain
# between tool rounds is already an unusual usage pattern.
# Per-session cap; typing faster than the copilot drains is already unusual.
MAX_PENDING_MESSAGES = 10
# Redis key + TTL. The buffer is ephemeral: if a turn completes or the
# executor dies, the pending messages should either have been drained
# already or are safe to drop (the user can resend).
# Ephemeral buffer: undrained messages are safe to drop at TTL expiry.
_PENDING_KEY_PREFIX = "copilot:pending:"
_PENDING_CHANNEL_PREFIX = "copilot:pending:notify:"
_PENDING_TTL_SECONDS = 3600 # 1 hour — matches stream_ttl default
# Secondary queue that carries drained-but-awaiting-persist PendingMessages
# from the MCP tool wrapper (which drains the primary buffer and injects
# into tool output for the LLM) to sdk/service.py's _dispatch_response
# handler for StreamToolOutputAvailable, which pops and persists them as a
# separate user row chronologically after the tool_result row. This is the
# hand-off between "Claude saw the follow-up mid-turn" (wrapper) and "UI
# renders a user bubble for it" (service). Rollback path re-queues into
# the PRIMARY buffer so the next turn-start drain picks them up if the
# user-row persist fails.
# Secondary queue: carries drained-but-awaiting-persist PendingMessages from
# the tool wrapper (which injects them into tool output) to sdk/service.py
# (which persists a user row after the tool_result row).
_PERSIST_QUEUE_KEY_PREFIX = "copilot:pending-persist:"
# Payload sent on the pub/sub notify channel. Subscribers treat any
@@ -65,13 +54,8 @@ _NOTIFY_PAYLOAD = "1"
class PendingMessageContext(BaseModel):
"""Structured page context attached to a pending message.
Default ``extra='ignore'`` (pydantic's default): unknown keys from
the loose HTTP-level ``StreamChatRequest.context: dict[str, str]``
are silently dropped rather than raising ``ValidationError`` on
forward-compat additions. The strict ``extra='forbid'`` mode was
removed after sentry r3105553772 — strict validation at this
boundary only added a 500 footgun; the upstream request model is
already schemaless so strict mode protects nothing.
Unknown keys are silently dropped: the upstream request model is
``dict[str, str]``, so strict validation here only adds 500 footguns.
"""
url: str | None = Field(default=None, max_length=2_000)
@@ -84,19 +68,16 @@ class PendingMessage(BaseModel):
content: str = Field(min_length=1, max_length=32_000)
file_ids: list[str] = Field(default_factory=list, max_length=20)
context: PendingMessageContext | None = None
# Wall-clock time (unix seconds, float) the message was queued by the
# user. Used by the turn-start drain to order pending relative to the
# turn's ``current`` message: items typed *before* the current's
# /stream arrival go ahead of it; items typed *after* (race path,
# queued while the /stream HTTP request was still processing) go
# after. Defaults to 0.0 for backward compatibility with entries
# written before this field existed — those sort as "before everything"
# which matches the pre-fix behaviour.
# Enqueue time (unix seconds) so the turn-start drain can order pending
# messages relative to the turn's ``current`` message.
enqueued_at: float = Field(default_factory=time.time)
def _buffer_key(session_id: str) -> str:
return f"{_PENDING_KEY_PREFIX}{session_id}"
# Hash-tag braces colocate this key with stream_registry's session-meta key
# on the same Redis Cluster slot, which the gated-rpush Lua script needs
# (multi-key scripts return CROSSSLOT when KEYS hash to different slots).
return f"{_PENDING_KEY_PREFIX}{{{session_id}}}"
def _notify_channel(session_id: str) -> str:
@@ -104,12 +85,7 @@ def _notify_channel(session_id: str) -> str:
def _decode_redis_item(item: Any) -> str:
"""Decode a redis-py list item to a str.
redis-py returns ``bytes`` when ``decode_responses=False`` and ``str``
when ``decode_responses=True``. This helper handles both so callers
don't have to repeat the isinstance guard.
"""
"""Decode a redis-py list item to str (handles ``bytes`` and ``str``)."""
return item.decode("utf-8") if isinstance(item, bytes) else str(item)
@@ -117,22 +93,11 @@ async def push_pending_message(
session_id: str,
message: PendingMessage,
) -> int:
"""Append a pending message to the session's buffer.
"""Append a pending message to the session's buffer, capped at
``MAX_PENDING_MESSAGES`` (oldest trimmed). Returns the new buffer length.
Returns the new buffer length. Enforces ``MAX_PENDING_MESSAGES`` by
trimming from the left (oldest) — the newest message always wins if
the user has been typing faster than the copilot can drain.
Delegates to :func:`backend.data.redis_helpers.capped_rpush` so RPUSH
+ LTRIM + EXPIRE + LLEN run atomically (MULTI/EXEC) in one round
trip; a concurrent drain (LPOP) can no longer observe the list
temporarily over ``MAX_PENDING_MESSAGES``.
Note on durability: if the executor turn crashes after a push but before
the drain window runs, the message remains in Redis until the TTL expires
(``_PENDING_TTL_SECONDS``, currently 1 hour). It is delivered on the
next turn that drains the buffer. If no turn runs within the TTL the
message is silently dropped; the user may resend it.
The buffer survives consumer crashes until ``_PENDING_TTL_SECONDS``
expires; messages not drained within that window are dropped.
"""
redis = await get_redis_async()
key = _buffer_key(session_id)
@@ -146,10 +111,14 @@ async def push_pending_message(
ttl_seconds=_PENDING_TTL_SECONDS,
)
# Fire-and-forget notify. Subscribers use this as a wake-up hint;
# the buffer itself is authoritative so a lost notify is harmless.
# Fire-and-forget wake-up hint via sharded pub/sub (SPUBLISH routes to
# one shard vs classic PUBLISH's cluster-bus broadcast). Use
# execute_command because redis-py 6.x AsyncRedisCluster has no
# spublish() wrapper.
try:
await redis.publish(_notify_channel(session_id), _NOTIFY_PAYLOAD)
await redis.execute_command(
"SPUBLISH", _notify_channel(session_id), _NOTIFY_PAYLOAD
)
except Exception as e: # pragma: no cover
logger.warning("pending_messages: publish failed for %s: %s", session_id, e)
@@ -161,6 +130,51 @@ async def push_pending_message(
return new_length
async def push_pending_message_if_session_running(
session_id: str,
message: PendingMessage,
*,
session_meta_key: str,
) -> int | None:
"""Append a pending message only while the stream meta is still running."""
redis = await get_redis_async()
key = _buffer_key(session_id)
payload = message.model_dump_json()
new_length = await capped_rpush_if_hash_field(
redis,
hash_key=session_meta_key,
hash_field="status",
expected="running",
list_key=key,
value=payload,
max_len=MAX_PENDING_MESSAGES,
ttl_seconds=_PENDING_TTL_SECONDS,
)
if new_length is None:
logger.info(
"pending_messages: skipped push to session=%s because no running turn exists",
session_id,
)
return None
# Match push_pending_message: SPUBLISH via execute_command so it works on
# both Redis and AsyncRedisCluster (the cluster client has no publish()).
try:
await redis.execute_command(
"SPUBLISH", _notify_channel(session_id), _NOTIFY_PAYLOAD
)
except Exception as e: # pragma: no cover
logger.warning("pending_messages: publish failed for %s: %s", session_id, e)
logger.info(
"pending_messages: pushed message to running session=%s (buffer_len=%d)",
session_id,
new_length,
)
return new_length
async def drain_pending_messages(session_id: str) -> list[PendingMessage]:
"""Atomically pop all pending messages for *session_id*.
@@ -171,13 +185,8 @@ async def drain_pending_messages(session_id: str) -> list[PendingMessage]:
redis = await get_redis_async()
key = _buffer_key(session_id)
# Redis LPOP with count (Redis 6.2+) returns None for missing key,
# empty list if we somehow race an empty key, or the popped items.
# Draining MAX_PENDING_MESSAGES at once is safe because the push side
# uses RPUSH + LTRIM(-MAX_PENDING_MESSAGES, -1) to cap the list to that
# same value, so the list can never hold more items than we drain here.
# If the cap is raised on the push side, raise the drain count here too
# (or switch to a loop drain).
# LPOP with count drains everything in one round-trip; the push side
# caps the list at MAX_PENDING_MESSAGES so nothing is left behind.
lpop_result = await redis.lpop(key, MAX_PENDING_MESSAGES) # type: ignore[assignment]
if not lpop_result:
return []
@@ -241,24 +250,17 @@ async def peek_pending_messages(session_id: str) -> list[PendingMessage]:
async def clear_pending_messages_unsafe(session_id: str) -> None:
"""Drop the session's pending buffer — **not** the normal turn cleanup.
"""Drop the session's pending buffer — operator/debug escape hatch.
The ``_unsafe`` suffix warns: reaching for this at turn end drops queued
follow-ups on the floor instead of running them (the bug fixed by commit
b64be73). The atomic ``LPOP`` drain at turn start is the primary consumer;
anything pushed after the drain window belongs to the next turn by
definition. Retained only as an operator/debug escape hatch for manually
clearing a stuck session and as a fixture in the unit tests.
The ``_unsafe`` suffix warns: normal turn cleanup uses the atomic LPOP
drain; this bypass drops queued follow-ups on the floor.
"""
redis = await get_redis_async()
await redis.delete(_buffer_key(session_id))
# Per-message and total-block caps for inline tool-boundary injection.
# Per-message keeps a single long paste from dominating; the total cap
# keeps the follow-up block small relative to the 100 KB MCP truncation
# boundary so tool output always stays the larger share of the wrapper
# return value.
# Per-message + total caps keep the follow-up block bounded relative to the
# 100 KB MCP tool-output truncation boundary.
_FOLLOWUP_CONTENT_MAX_CHARS = 2_000
_FOLLOWUP_TOTAL_MAX_CHARS = 6_000
@@ -273,17 +275,9 @@ async def stash_pending_for_persist(
) -> None:
"""Enqueue drained PendingMessages for UI-row persistence.
Writes each message as a JSON payload to
``copilot:pending-persist:{session_id}``. The SDK service's
tool-result dispatch handler LPOPs this queue right after appending
the tool_result row to ``session.messages``, so the resulting user
row lands at the correct chronological position (after the tool
output the follow-up was drained against).
Fire-and-forget on Redis failures: a stash failure means Claude
still saw the follow-up in tool output (the injection step ran
first), so the only consequence is a missing UI bubble. Logged
so it can be spotted.
The SDK service LPOPs this right after appending the tool_result row so
the user bubble lands after the tool output. Stash failures are logged
but not raised — the only consequence is a missing UI bubble.
"""
if not messages:
return
@@ -336,8 +330,7 @@ async def drain_pending_for_persist(session_id: str) -> list[PendingMessage]:
)
except (json.JSONDecodeError, ValidationError, TypeError, ValueError) as e:
logger.warning(
"pending_messages: dropping malformed persist-queue entry "
"for %s: %s",
"pending_messages: dropping malformed persist-queue entry for %s: %s",
session_id,
e,
)
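A hypothetical sketch of the gating that `push_pending_message_if_session_running` delegates to `capped_rpush_if_hash_field`. The real helper in `backend.data.redis_helpers` is not shown in this diff; the Lua body and redis-py wiring below are assumptions, included only to illustrate why the `{session_id}` hash-tag keeps both keys on one cluster slot.

```python
import redis.asyncio as redis

# Assumed shape of a status-gated, capped RPUSH; not the project's actual script.
_GATED_CAPPED_RPUSH = """
if redis.call('HGET', KEYS[1], 'status') ~= ARGV[1] then return false end
redis.call('RPUSH', KEYS[2], ARGV[2])
redis.call('LTRIM', KEYS[2], -tonumber(ARGV[3]), -1)
redis.call('EXPIRE', KEYS[2], tonumber(ARGV[4]))
return redis.call('LLEN', KEYS[2])
"""


async def gated_push(r: redis.Redis, meta_key: str, list_key: str,
                     payload: str, max_len: int, ttl: int) -> int | None:
    # KEYS[1] and KEYS[2] must hash to the same cluster slot, which is what the
    # {session_id} hash-tag in _buffer_key guarantees (otherwise: CROSSSLOT).
    script = r.register_script(_GATED_CAPPED_RPUSH)
    result = await script(keys=[meta_key, list_key],
                          args=["running", payload, max_len, ttl])
    return int(result) if result else None  # nil means no running turn
```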

View File

@@ -60,6 +60,16 @@ class _FakeRedis:
self.published.append((channel, payload))
return 1
async def execute_command(self, *args: Any) -> Any:
# Minimal handler for the sharded SPUBLISH call made by
# push_pending_message. Routing semantics are irrelevant here —
# we just record the publish for assertion.
if args and args[0] == "SPUBLISH":
_, channel, payload = args
self.published.append((channel, payload))
return 1
raise NotImplementedError(f"fake execute_command does not handle {args[0]!r}")
async def lpop(self, key: str, count: int) -> list[str | bytes] | None:
lst = self.lists.get(key)
if not lst:
@@ -326,7 +336,7 @@ async def test_drain_skips_malformed_entries(
fake_redis: _FakeRedis,
) -> None:
# Seed the fake with a mix of valid and malformed payloads
fake_redis.lists["copilot:pending:bad"] = [
fake_redis.lists[pm_module._buffer_key("bad")] = [
json.dumps({"content": "valid"}),
"{not valid json",
json.dumps({"content": "also valid", "file_ids": ["a"]}),
@@ -347,7 +357,7 @@ async def test_drain_decodes_bytes_payloads(
branch in ``drain_pending_messages`` so a regression there doesn't
slip past CI.
"""
fake_redis.lists["copilot:pending:bytes_sess"] = [
fake_redis.lists[pm_module._buffer_key("bytes_sess")] = [
json.dumps({"content": "from bytes"}).encode("utf-8"),
]
drained = await drain_pending_messages("bytes_sess")
@@ -362,14 +372,14 @@ async def test_peek_decodes_bytes_payloads(
"""``peek_pending_messages`` uses the same ``_decode_redis_item`` helper
as the drain path. Seed with bytes to guard against regression.
"""
fake_redis.lists["copilot:pending:peek_bytes_sess"] = [
fake_redis.lists[pm_module._buffer_key("peek_bytes_sess")] = [
json.dumps({"content": "peeked from bytes"}).encode("utf-8"),
]
peeked = await peek_pending_messages("peek_bytes_sess")
assert len(peeked) == 1
assert peeked[0].content == "peeked from bytes"
# peek must NOT consume the item
assert fake_redis.lists["copilot:pending:peek_bytes_sess"] != []
assert fake_redis.lists[pm_module._buffer_key("peek_bytes_sess")] != []
# ── Concurrency ─────────────────────────────────────────────────────
@@ -445,7 +455,7 @@ async def test_peek_pending_messages_decodes_bytes_payloads(
fake_redis: _FakeRedis,
) -> None:
"""peek_pending_messages decodes bytes entries the same way drain does."""
fake_redis.lists["copilot:pending:peek_bytes"] = [
fake_redis.lists[pm_module._buffer_key("peek_bytes")] = [
json.dumps({"content": "from bytes"}).encode("utf-8"),
]
peeked = await peek_pending_messages("peek_bytes")
@@ -458,7 +468,7 @@ async def test_peek_pending_messages_skips_malformed_entries(
fake_redis: _FakeRedis,
) -> None:
"""Malformed entries are skipped and valid ones are returned."""
fake_redis.lists["copilot:pending:peek_bad"] = [
fake_redis.lists[pm_module._buffer_key("peek_bad")] = [
json.dumps({"content": "valid peek"}),
"{bad json",
json.dumps({"content": "also valid peek"}),
@@ -486,7 +496,7 @@ async def test_stash_for_persist_enqueues_and_drain_pops_in_order(
# Stored under the distinct persist key, NOT the primary buffer.
assert "copilot:pending-persist:sess-persist" in fake_redis.lists
assert "copilot:pending:sess-persist" not in fake_redis.lists
assert pm_module._buffer_key("sess-persist") not in fake_redis.lists
drained = await drain_pending_for_persist("sess-persist")
assert len(drained) == 2
@@ -612,3 +622,59 @@ async def test_drain_and_format_for_injection_swallows_redis_error(
@pytest.mark.asyncio
async def test_drain_and_format_for_injection_missing_session_id() -> None:
assert await drain_and_format_for_injection("", log_prefix="[TEST]") == ""
# ── Cluster-slot colocation regression ───────────────────────────────
# The gated-rpush Lua script in `capped_rpush_if_hash_field` touches both
# the session-meta hash (`stream_registry._get_session_meta_key`) and the
# pending buffer list (`_buffer_key`) atomically. Redis Cluster requires
# every key referenced by a multi-key Lua script to hash to the same slot,
# so both keys must share a hash tag (the `{...}` substring Redis uses for
# slot calculation). Without this, the EVAL returns `CROSSSLOT keys in
# request` once cluster mode is active.
def _redis_keyslot(key: str) -> int:
"""Compute the Redis Cluster slot for ``key`` using CRC16-XMODEM mod 16384.
Mirrors the algorithm in redis-py's ``RedisCluster.keyslot`` and the
Redis spec — extracts the first ``{...}`` substring as the hash tag,
falls back to the whole key when no tag is present.
"""
start = key.find("{")
if start != -1:
end = key.find("}", start + 1)
if end > start + 1:
key = key[start + 1 : end]
crc = 0
poly = 0x1021
for byte in key.encode():
crc ^= byte << 8
for _ in range(8):
crc = ((crc << 1) ^ poly) & 0xFFFF if crc & 0x8000 else (crc << 1) & 0xFFFF
return crc % 16384
def test_buffer_and_session_meta_keys_share_cluster_slot() -> None:
"""Regression: pending-buffer key + session-meta key must hash to the
same Redis Cluster slot, otherwise the gated-rpush Lua script returns
CROSSSLOT once cluster mode is enabled."""
# Late import so the test doesn't pull stream_registry's heavy module
# graph (it transitively wires the AppService client) at file load.
from backend.copilot.stream_registry import _get_session_meta_key
for session_id in [
"sess-abcdef-123",
"0eb0aae8-6926-4b50-97af-72840841dc70",
"x",
]:
buf = pm_module._buffer_key(session_id)
meta = _get_session_meta_key(session_id)
assert "{" in buf and "}" in buf, f"_buffer_key missing hash tag: {buf!r}"
assert (
"{" in meta and "}" in meta
), f"_get_session_meta_key missing hash tag: {meta!r}"
assert _redis_keyslot(buf) == _redis_keyslot(meta), (
f"CROSSSLOT regression: {buf!r} (slot {_redis_keyslot(buf)}) "
f"!= {meta!r} (slot {_redis_keyslot(meta)})"
)
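As a quick worked check of the property the test pins (the key shapes here are made up for illustration; the real formats live in `_buffer_key` and `_get_session_meta_key`):

# Both keys carry the same "{abc}" hash tag, so only that substring feeds the
# CRC16 above and they land in the same cluster slot, which makes a multi-key
# Lua EVAL over them legal.
assert _redis_keyslot("copilot:pending:{abc}") == _redis_keyslot("copilot:session-meta:{abc}")
# Without a tag each full key hashes independently, which is the CROSSSLOT
# failure mode the regression test above guards against.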

View File

@@ -128,6 +128,14 @@ ToolName = Literal[
# Frozen set of all valid tool names — derived from the Literal.
ALL_TOOL_NAMES: frozenset[str] = frozenset(get_args(ToolName))
DISABLED_LEGACY_TOOL_NAMES: frozenset[str] = frozenset({"ask_question"})
"""Tool names accepted only for backwards compatibility with saved graphs.
These names are intentionally absent from ``ToolName`` and
``PLATFORM_TOOL_NAMES`` so they are not exposed in new block schemas or sent to
the model as available tools.
"""
# SDK built-in tool names — tools provided by the Claude Code CLI that our
# code does not implement directly. ``TodoWrite`` is DELIBERATELY excluded:
# baseline mode ships an MCP-wrapped platform version
@@ -304,7 +312,11 @@ def validate_tool_names(tools: list[str]) -> list[str]:
Returns:
List of invalid names (empty if all are valid).
"""
return [t for t in tools if t not in ALL_TOOL_NAMES]
return [
t
for t in tools
if t not in ALL_TOOL_NAMES and t not in DISABLED_LEGACY_TOOL_NAMES
]
_tool_names_checked = False
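Sketched against the tests in the next file, the behaviour the relaxed filter buys is:

# Legacy names are tolerated (not reported as invalid) while staying out of
# ToolName / PLATFORM_TOOL_NAMES; genuinely unknown names are still flagged.
assert validate_tool_names(["ask_question"]) == []
assert validate_tool_names(["nonexistent_tool"]) == ["nonexistent_tool"]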

View File

@@ -257,6 +257,9 @@ class TestValidateToolNames:
def test_valid_sdk_builtin(self):
assert validate_tool_names(["Read", "Task", "WebSearch"]) == []
def test_disabled_legacy_tool_name_is_accepted(self):
assert validate_tool_names(["ask_question"]) == []
def test_invalid_tool(self):
result = validate_tool_names(["nonexistent_tool"])
assert "nonexistent_tool" in result

View File

@@ -47,15 +47,14 @@ from pydantic import BaseModel, Field
from redis.exceptions import RedisError
from backend.data.db_accessors import user_db
from backend.data.redis_client import get_redis_async
from backend.data.redis_client import AsyncRedisClient, get_redis_async
from backend.data.user import get_user_by_id
from backend.util.cache import cached
logger = logging.getLogger(__name__)
# Redis key prefixes. Bumped from "copilot:usage" (token-based) to
# "copilot:cost" on the token→cost migration so stale counters do not
# get misinterpreted as microdollars (which would dramatically under-count).
# "copilot:cost" (not the legacy "copilot:usage") so stale token-based
# counters are not misread as microdollars.
_USAGE_KEY_PREFIX = "copilot:cost"
@@ -73,6 +72,7 @@ class SubscriptionTier(str, Enum):
from prisma.enums import SubscriptionTier
"""
NO_TIER = "NO_TIER"
BASIC = "BASIC"
PRO = "PRO"
MAX = "MAX"
@@ -88,6 +88,14 @@ class SubscriptionTier(str, Enum):
# eventual ``int(base * multiplier)`` in ``get_global_rate_limits`` keeps the
# downstream microdollar math integer.
_DEFAULT_TIER_MULTIPLIERS: dict[SubscriptionTier, float] = {
# NO_TIER is the explicit "no active Stripe subscription" state —
# multiplier 0.0 collapses the per-period limit to int(base * 0) = 0, so
# all rate-limited routes (CoPilot chat, AutoPilot) refuse with 429
# before any business logic runs. This is the backend half of the
# paywall (the frontend modal nudges UI users; this gate enforces
# server-side regardless of client). BASIC stays as a future paid-tier
# option; for now it falls back to the same baseline as paid tiers.
SubscriptionTier.NO_TIER: 0.0,
SubscriptionTier.BASIC: 1.0,
SubscriptionTier.PRO: 5.0,
SubscriptionTier.MAX: 20.0,
@@ -100,7 +108,7 @@ _DEFAULT_TIER_MULTIPLIERS: dict[SubscriptionTier, float] = {
# ``get_tier_multipliers`` so LD overrides are honoured.
TIER_MULTIPLIERS = _DEFAULT_TIER_MULTIPLIERS
DEFAULT_TIER = SubscriptionTier.BASIC
DEFAULT_TIER = SubscriptionTier.NO_TIER
@cached(ttl_seconds=60, maxsize=1, cache_none=False)
@@ -447,25 +455,16 @@ async def reset_daily_usage(user_id: str, daily_cost_limit: int = 0) -> bool:
try:
redis = await get_redis_async()
# Use a MULTI/EXEC transaction so that DELETE (daily) and DECRBY
# (weekly) either both execute or neither does. This prevents the
# scenario where the daily counter is cleared but the weekly
# counter is not decremented — which would let the caller refund
# credits even though the daily limit was already reset.
d_key = _daily_key(user_id, now=now)
w_key = _weekly_key(user_id, now=now) if daily_cost_limit > 0 else None
pipe = redis.pipeline(transaction=True)
pipe.delete(d_key)
# Daily and weekly keys hash to different cluster slots, so cross-key
# MULTI/EXEC is not available. Issue the writes sequentially — the
# failure mode (daily deleted, weekly not decremented) only skews a
# best-effort refund budget that the read path already tolerates.
await redis.delete(d_key)
if w_key is not None:
pipe.decrby(w_key, daily_cost_limit)
results = await pipe.execute()
# Clamp negative weekly counter to 0 (best-effort; not critical).
if w_key is not None:
new_val = results[1] # DECRBY result
if new_val < 0:
await redis.set(w_key, 0, keepttl=True)
await _decr_counter_floor_zero(redis, w_key, daily_cost_limit)
logger.info("Reset daily usage for user %s", user_id[:8])
return True
@@ -555,30 +554,18 @@ async def record_cost_usage(
logger.info("Recording copilot spend: %d microdollars", cost_microdollars)
now = datetime.now(UTC)
d_key = _daily_key(user_id, now=now)
w_key = _weekly_key(user_id, now=now)
daily_ttl = max(int((_daily_reset_time(now=now) - now).total_seconds()), 1)
weekly_ttl = max(int((_weekly_reset_time(now=now) - now).total_seconds()), 1)
try:
redis = await get_redis_async()
# Use MULTI/EXEC so each INCRBY/EXPIRE pair is atomic — guarantees
# the TTL is set even if the connection drops mid-pipeline, so
# counters can never survive past their date-based rotation window.
pipe = redis.pipeline(transaction=True)
# Daily counter (expires at next midnight UTC)
d_key = _daily_key(user_id, now=now)
pipe.incrby(d_key, cost_microdollars)
seconds_until_daily_reset = int(
(_daily_reset_time(now=now) - now).total_seconds()
)
pipe.expire(d_key, max(seconds_until_daily_reset, 1))
# Weekly counter (expires end of week)
w_key = _weekly_key(user_id, now=now)
pipe.incrby(w_key, cost_microdollars)
seconds_until_weekly_reset = int(
(_weekly_reset_time(now=now) - now).total_seconds()
)
pipe.expire(w_key, max(seconds_until_weekly_reset, 1))
await pipe.execute()
# Daily and weekly keys hash to different cluster slots — cross-slot
# MULTI/EXEC is not supported, so each counter gets its own
# single-key transaction. Per-counter INCRBY+EXPIRE atomicity is the
# invariant that matters; the two counters are independent budgets.
await _incr_counter_atomic(redis, d_key, cost_microdollars, daily_ttl)
await _incr_counter_atomic(redis, w_key, cost_microdollars, weekly_ttl)
except (RedisError, ConnectionError, OSError):
logger.warning(
"Redis unavailable for recording cost usage (microdollars=%d)",
@@ -586,30 +573,56 @@ async def record_cost_usage(
)
async def _incr_counter_atomic(
redis: AsyncRedisClient, key: str, delta: int, ttl_seconds: int
) -> None:
"""INCRBY + EXPIRE on a single key inside a MULTI/EXEC transaction."""
pipe = redis.pipeline(transaction=True)
pipe.incrby(key, delta)
pipe.expire(key, ttl_seconds)
await pipe.execute()
# Atomic DECRBY + floor-to-zero so a concurrent INCRBY from record_cost_usage
# cannot be lost. DELETE on underflow also avoids leaving a zero-valued key
# with no TTL, which the non-atomic set-with-keepttl variant could do.
_DECR_FLOOR_ZERO_SCRIPT = """
local value = redis.call("DECRBY", KEYS[1], ARGV[1])
if value < 0 then
redis.call("DEL", KEYS[1])
return 0
end
return value
"""
async def _decr_counter_floor_zero(
redis: AsyncRedisClient, key: str, delta: int
) -> None:
"""Atomically DECRBY ``delta`` on ``key`` and DEL on underflow.
DEL on underflow avoids leaving a zero-valued key without a TTL, so the
next INCRBY in ``record_cost_usage`` re-seeds both the value and the
expiry in one shot.
"""
await redis.eval(_DECR_FLOOR_ZERO_SCRIPT, 1, key, delta)
class _UserNotFoundError(Exception):
"""Raised when a user record is missing or has no subscription tier.
Used internally by ``_fetch_user_tier`` to signal a cache-miss condition:
by raising instead of returning ``DEFAULT_TIER``, we prevent the ``@cached``
decorator from storing the fallback value. This avoids a race condition
where a non-existent user's DEFAULT_TIER is cached, then the user is
created with a higher tier but receives the stale cached FREE tier for
up to 5 minutes.
Raising (rather than returning ``DEFAULT_TIER``) prevents ``@cached``
from persisting the fallback, which would otherwise keep serving the stale
default tier for up to the TTL after the user's real tier is set.
"""
@cached(maxsize=1000, ttl_seconds=300, shared_cache=True)
async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
"""Fetch the user's rate-limit tier from the database (cached via Redis).
"""Fetch the user's rate-limit tier, cached across pods.
Uses ``shared_cache=True`` so that tier changes propagate across all pods
immediately when the cache entry is invalidated (via ``cache_delete``).
Only successful DB lookups of existing users with a valid tier are cached.
Raises ``_UserNotFoundError`` when the user is missing or has no tier, so
the ``@cached`` decorator does **not** store a fallback value. This
prevents a race condition where a non-existent user's ``DEFAULT_TIER`` is
cached and then persists after the user is created with a higher tier.
Only successful lookups are cached. Missing users raise
``_UserNotFoundError`` so ``@cached`` never stores the fallback.
"""
try:
user = await user_db().get_user_by_id(user_id)
@@ -651,20 +664,10 @@ get_user_tier.cache_delete = _fetch_user_tier.cache_delete # type: ignore[attr-
async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None:
"""Persist the user's rate-limit tier to the database.
Invalidates every cache that keys off the user's subscription tier so the
change is visible immediately: this function's own ``get_user_tier``, the
shared ``get_user_by_id`` (which exposes ``user.subscription_tier``), and
``get_pending_subscription_change`` (since an admin override can invalidate
a cached ``cancel_at_period_end`` or schedule-based pending state).
If the user has an active Stripe subscription whose current price does not
match ``tier``, Stripe will keep billing the old price and the next
``customer.subscription.updated`` webhook will overwrite the DB tier back
to whatever Stripe has. Proper reconciliation (cancelling or modifying the
Stripe subscription when an admin overrides the tier) is out of scope for
this PR — it changes the admin contract and needs its own test coverage.
For now we emit a ``WARNING`` so drift surfaces via Sentry until that
follow-up lands.
Invalidates the caches that expose ``subscription_tier`` so the change
takes effect immediately. If the user has an active Stripe subscription
on a mismatched price, emits a WARNING; Stripe remains the billing
source of truth and the next webhook will reconcile the DB tier.
Raises:
prisma.errors.RecordNotFoundError: If the user does not exist.
@@ -674,21 +677,13 @@ async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None:
data={"subscriptionTier": tier.value},
)
get_user_tier.cache_delete(user_id) # type: ignore[attr-defined]
# Local import required: backend.data.credit imports backend.copilot.rate_limit
# (via get_user_tier in credit.py's _invalidate_user_tier_caches), so a
# top-level ``from backend.data.credit import ...`` here would create a
# circular import at module-load time.
# Local import: backend.data.credit imports from this module.
from backend.data.credit import get_pending_subscription_change
get_user_by_id.cache_delete(user_id) # type: ignore[attr-defined]
get_pending_subscription_change.cache_delete(user_id) # type: ignore[attr-defined]
# The DB write above is already committed; the drift check is best-effort
# diagnostic logging. Fire-and-forget so admin bulk ops don't wait on a
# Stripe roundtrip. The inner helper wraps its body in a timeout + broad
# except so background task errors still surface via logs rather than as
# "task exception never retrieved" warnings. Cancellation on request
# shutdown is acceptable — the drift warning is non-load-bearing.
# Fire-and-forget drift check so admin bulk ops don't wait on Stripe.
asyncio.ensure_future(_drift_check_background(user_id, tier))
@@ -711,8 +706,6 @@ async def _drift_check_background(user_id: str, tier: SubscriptionTier) -> None:
tier.value,
)
except asyncio.CancelledError:
# Request may have completed and the event loop is cancelling tasks —
# the drift log is non-critical, so accept cancellation silently.
raise
except Exception:
logger.exception(
@@ -726,19 +719,9 @@ async def _drift_check_background(user_id: str, tier: SubscriptionTier) -> None:
async def _warn_if_stripe_subscription_drifts(
user_id: str, new_tier: SubscriptionTier
) -> None:
"""Emit a WARNING when an admin tier override leaves an active Stripe sub on a
mismatched price.
The warning is diagnostic only: Stripe remains the billing source of truth,
so the next ``customer.subscription.updated`` webhook will reset the DB
tier. Surfacing the drift here lets ops catch admin overrides that bypass
the intended Checkout / Portal cancel flows before users notice surprise
charges.
"""
# Local imports: see note in ``set_user_tier`` about the credit <-> rate_limit
# circular. These helpers (``_get_active_subscription``,
# ``get_subscription_price_id``) live in credit.py alongside the rest of
# the Stripe billing code.
"""Emit a WARNING when an admin tier override leaves an active Stripe
subscription on a mismatched price."""
# Local import: breaks a credit <-> rate_limit circular at module load.
from backend.data.credit import _get_active_subscription, get_subscription_price_id
try:
@@ -753,10 +736,8 @@ async def _warn_if_stripe_subscription_drifts(
return
price = items[0].price
current_price_id = price if isinstance(price, str) else price.id
# The LaunchDarkly-backed price lookup must live inside this try/except:
# an LD SDK failure (network, token revoked) here would otherwise
# propagate past set_user_tier's already-committed DB write and turn a
# best-effort diagnostic into a 500 on admin tier writes.
# Inside the try/except: an LD SDK failure here must not turn a
# best-effort diagnostic into a 500 after the DB write committed.
expected_price_id = await get_subscription_price_id(new_tier)
except Exception:
logger.debug(
@@ -816,6 +797,16 @@ async def get_global_rate_limits(
tier = await get_user_tier(user_id)
multipliers = await get_tier_multipliers()
multiplier = multipliers.get(tier.value, 1.0)
# NO_TIER's 0.0 multiplier is the backend half of the paywall — it
# collapses limits to zero so unsubscribed users can't run the chat.
# Only enforce that gate when the platform-payment flag is on for this
# user; in the beta cohort (flag off) NO_TIER falls back to BASIC's
# baseline so the e2e suite and beta testers retain access.
if tier == SubscriptionTier.NO_TIER:
from backend.util.feature_flag import Flag, is_feature_enabled
if not await is_feature_enabled(Flag.ENABLE_PLATFORM_PAYMENT, user_id):
multiplier = multipliers.get(SubscriptionTier.BASIC.value, 1.0)
if multiplier != 1.0:
# Cast back to int to preserve the microdollar integer contract
# downstream — fractional LD multipliers (e.g. 8.5×) truncate at the
@@ -838,12 +829,15 @@ async def reset_user_usage(user_id: str, *, reset_weekly: bool = False) -> None:
the admin believing the counters were zeroed when they were not.
"""
now = datetime.now(UTC)
keys_to_delete = [_daily_key(user_id, now=now)]
if reset_weekly:
keys_to_delete.append(_weekly_key(user_id, now=now))
d_key = _daily_key(user_id, now=now)
w_key = _weekly_key(user_id, now=now) if reset_weekly else None
try:
redis = await get_redis_async()
await redis.delete(*keys_to_delete)
# Daily and weekly keys hash to different cluster slots — multi-key
# DELETE would raise CROSSSLOT, so issue separate calls.
await redis.delete(d_key)
if w_key is not None:
await redis.delete(w_key)
except (RedisError, ConnectionError, OSError):
logger.warning("Redis unavailable for resetting user usage")
raise

View File

@@ -359,6 +359,9 @@ class TestSubscriptionTier:
def test_tier_multipliers(self):
# Float-typed so LD-provided fractional multipliers compose naturally;
# equality against int literals still holds for the whole defaults.
# NO_TIER is 0.0 — explicit "no active subscription" state;
# rate-limited routes refuse with 429 (backend half of the paywall).
assert TIER_MULTIPLIERS[SubscriptionTier.NO_TIER] == 0.0
assert TIER_MULTIPLIERS[SubscriptionTier.BASIC] == 1.0
assert TIER_MULTIPLIERS[SubscriptionTier.PRO] == 5.0
assert TIER_MULTIPLIERS[SubscriptionTier.MAX] == 20.0
@@ -366,8 +369,8 @@ class TestSubscriptionTier:
assert TIER_MULTIPLIERS[SubscriptionTier.ENTERPRISE] == 60.0
assert TIER_MULTIPLIERS is _DEFAULT_TIER_MULTIPLIERS
def test_default_tier_is_basic(self):
assert DEFAULT_TIER == SubscriptionTier.BASIC
def test_default_tier_is_no_tier(self):
assert DEFAULT_TIER == SubscriptionTier.NO_TIER
def test_usage_status_includes_tier(self):
now = datetime.now(UTC)
@@ -375,7 +378,7 @@ class TestSubscriptionTier:
daily=UsageWindow(used=0, limit=100, resets_at=now + timedelta(hours=1)),
weekly=UsageWindow(used=0, limit=500, resets_at=now + timedelta(days=1)),
)
assert status.tier == SubscriptionTier.BASIC
assert status.tier == SubscriptionTier.NO_TIER
def test_usage_status_with_custom_tier(self):
now = datetime.now(UTC)
@@ -1243,18 +1246,9 @@ class TestTierLimitsRespected:
class TestResetDailyUsage:
@staticmethod
def _make_pipeline_mock(decrby_result: int = 0) -> MagicMock:
"""Create a pipeline mock that returns [delete_result, decrby_result]."""
pipe = MagicMock()
pipe.execute = AsyncMock(return_value=[1, decrby_result])
return pipe
@pytest.mark.asyncio
async def test_deletes_daily_key(self):
mock_pipe = self._make_pipeline_mock(decrby_result=0)
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
@@ -1263,14 +1257,12 @@ class TestResetDailyUsage:
result = await reset_daily_usage(_USER, daily_cost_limit=10000)
assert result is True
mock_pipe.delete.assert_called_once()
mock_redis.delete.assert_called_once()
@pytest.mark.asyncio
async def test_reduces_weekly_usage_via_decrby(self):
"""Weekly counter should be reduced via DECRBY in the pipeline."""
mock_pipe = self._make_pipeline_mock(decrby_result=35000)
async def test_reduces_weekly_usage_via_eval(self):
"""Weekly counter should be decremented via the atomic Lua script."""
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
@@ -1278,32 +1270,22 @@ class TestResetDailyUsage:
):
await reset_daily_usage(_USER, daily_cost_limit=10000)
mock_pipe.decrby.assert_called_once()
mock_redis.set.assert_not_called() # 35000 > 0, no clamp needed
@pytest.mark.asyncio
async def test_clamps_negative_weekly_to_zero(self):
"""If DECRBY goes negative, SET to 0 (outside the pipeline)."""
mock_pipe = self._make_pipeline_mock(decrby_result=-5000)
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await reset_daily_usage(_USER, daily_cost_limit=10000)
mock_pipe.decrby.assert_called_once()
mock_redis.set.assert_called_once()
# The Lua script handles both decrement and floor-to-zero in a single
# call — no separate SET is expected for the clamp branch any more.
# Pin the call shape so a regression that targets the wrong key or
# delta (e.g. the daily key, or a sign-flip) fails loudly.
mock_redis.eval.assert_called_once()
eval_args = mock_redis.eval.call_args.args
# eval(script, numkeys, KEYS[1], ARGV[1])
assert eval_args[1] == 1
assert eval_args[2] == _weekly_key(_USER)
assert int(eval_args[3]) == 10000
mock_redis.set.assert_not_called()
@pytest.mark.asyncio
async def test_no_weekly_reduction_when_daily_limit_zero(self):
"""When daily_cost_limit is 0, weekly counter should not be touched."""
mock_pipe = self._make_pipeline_mock()
mock_pipe.execute = AsyncMock(return_value=[1]) # only delete result
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
@@ -1311,8 +1293,8 @@ class TestResetDailyUsage:
):
await reset_daily_usage(_USER, daily_cost_limit=0)
mock_pipe.delete.assert_called_once()
mock_pipe.decrby.assert_not_called()
mock_redis.delete.assert_called_once()
mock_redis.eval.assert_not_called()
@pytest.mark.asyncio
async def test_returns_false_when_redis_unavailable(self):
@@ -1324,6 +1306,23 @@ class TestResetDailyUsage:
assert result is False
@pytest.mark.asyncio
async def test_decr_counter_floor_zero_invokes_lua_script(self):
"""The atomic DECRBY+floor helper routes through redis.eval with the
expected single-key, single-arg call shape."""
from backend.copilot.rate_limit import (
_DECR_FLOOR_ZERO_SCRIPT,
_decr_counter_floor_zero,
)
mock_redis = AsyncMock()
await _decr_counter_floor_zero(mock_redis, "weekly:user1", 42)
mock_redis.eval.assert_called_once_with(
_DECR_FLOOR_ZERO_SCRIPT, 1, "weekly:user1", 42
)
# ---------------------------------------------------------------------------
# Tier-limit enforcement (integration-style)
@@ -1781,8 +1780,9 @@ class TestResetUserUsage:
"backend.copilot.rate_limit.get_redis_async", return_value=mock_redis
):
await reset_user_usage("user-1", reset_weekly=True)
args = mock_redis.delete.call_args[0]
assert len(args) == 2 # both daily and weekly keys
# Daily and weekly keys hash to different cluster slots, so they are
# deleted via two separate DELETE calls (not a single multi-key one).
assert mock_redis.delete.call_count == 2
@pytest.mark.asyncio
async def test_raises_on_redis_failure(self):

View File

@@ -52,7 +52,8 @@ class ResponseType(str, Enum):
ERROR = "error"
USAGE = "usage"
HEARTBEAT = "heartbeat"
STATUS = "status"
STATUS = "data-status"
CURSOR = "data-cursor"
class StreamBaseResponse(BaseModel):
@@ -275,10 +276,18 @@ class StreamError(StreamBaseResponse):
The AI SDK uses z.strictObject({type, errorText}) which rejects
any extra fields like `code` or `details`.
When ``code`` is set we prefix ``errorText`` with ``[code:<id>]`` so
the frontend can still parse a machine-readable code out of the
otherwise opaque text. Idempotent: if the caller already embedded
the prefix, we don't double it.
"""
text = self.errorText
if self.code and not text.lstrip().startswith(f"[code:{self.code}]"):
text = f"[code:{self.code}] {text}"
data = {
"type": self.type.value,
"errorText": self.errorText,
"errorText": text,
}
return f"data: {json_dumps(data)}\n\n"
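The parse-back half lives in the frontend and is outside this diff; a minimal sketch of the inverse, with the helper name being an assumption of this note:

import re

_CODE_PREFIX_RE = re.compile(r"^\s*\[code:([A-Za-z0-9_\-]+)\]\s*")

def split_error_code_sketch(error_text: str) -> tuple[str | None, str]:
    # Illustrative inverse of the prefixing above: recover the machine-readable
    # code from an errorText like "[code:idle_timeout] Please try again."
    match = _CODE_PREFIX_RE.match(error_text)
    if not match:
        return None, error_text
    return match.group(1), error_text[match.end():]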
@@ -300,17 +309,46 @@ class StreamHeartbeat(StreamBaseResponse):
return ": heartbeat\n\n"
class StreamCursor(StreamBaseResponse):
"""Deprecated Redis-stream cursor data part.
Kept so older stored chunks or tests can still be reconstructed, but new
stream subscriptions no longer emit it. AI SDK resume needs a full replay
from ``0-0`` so every ``*-delta`` has its matching ``*-start`` event.
"""
type: ResponseType = ResponseType.CURSOR
chunkId: str = Field(..., description="Redis Stream message ID (XADD)")
def to_sse(self) -> str:
"""Emit as an AI SDK v5 data part."""
data = {
"type": self.type.value,
"data": {"chunkId": self.chunkId},
}
return f"data: {json.dumps(data)}\n\n"
class StreamStatus(StreamBaseResponse):
"""Transient status notification shown to the user during long operations.
Used to provide feedback when the backend performs behind-the-scenes work
(e.g., compacting conversation context on a retry) that would otherwise
leave the user staring at an unexplained pause.
Sent as a proper ``data:`` event so the frontend can display it to the
user. The AI SDK stream parser gracefully skips unknown chunk types
(logs a console warning), so this does not break the stream.
Emitted when the backend is about to enter a phase that would otherwise
leave the user staring at a silent "Thinking…" bubble — e.g. the first
LLM call, the continuation after a tool result, compacting conversation
context on retry, or activating a fallback model. The frontend reads
the latest `data-status` part on the current assistant message and uses
its `message` in place of the generic "Thinking…" copy.
"""
type: ResponseType = ResponseType.STATUS
message: str = Field(..., description="Human-readable status message")
def to_sse(self) -> str:
"""Emit as an AI SDK v5 data part so the client surfaces it as
`type="data-status"` on `message.parts` instead of dropping it as
an unknown chunk type."""
data = {
"type": self.type.value,
"data": {"message": self.message},
}
return f"data: {json.dumps(data)}\n\n"
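For reference, the wire line the serializer above produces for the adapter's post-tool status (key order and unicode escaping depend on the json module's defaults):

data: {"type": "data-status", "data": {"message": "Analyzing result…"}}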

View File

@@ -11,6 +11,33 @@ import pytest_asyncio
from backend.util import json
# ---------------------------------------------------------------------------
# Env vars that ``ChatConfig`` validators read — must be cleared so explicit
# constructor values are used. Centralised here so adding a new env-backed
# field only needs one update across the SDK test suite.
# ---------------------------------------------------------------------------
_CONFIG_ENV_VARS = (
"CHAT_USE_OPENROUTER",
"CHAT_API_KEY",
"OPEN_ROUTER_API_KEY",
"OPENAI_API_KEY",
"CHAT_BASE_URL",
"OPENROUTER_BASE_URL",
"OPENAI_BASE_URL",
"CHAT_USE_CLAUDE_CODE_SUBSCRIPTION",
"CHAT_USE_CLAUDE_AGENT_SDK",
"CHAT_CLAUDE_AGENT_CROSS_USER_PROMPT_CACHE",
"CHAT_CLAUDE_AGENT_CLI_PATH",
"CLAUDE_AGENT_CLI_PATH",
)
@pytest.fixture()
def _clean_config_env(monkeypatch: pytest.MonkeyPatch) -> None:
"""Clear env-backed CHAT_* settings so ChatConfig uses constructor values."""
for var in _CONFIG_ENV_VARS:
monkeypatch.delenv(var, raising=False)
@pytest_asyncio.fixture(scope="session", loop_scope="session", name="server")
async def _server_noop() -> None:

View File

@@ -22,6 +22,7 @@ from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.response_model import StreamToolOutputAvailable
from .service import (
_RETRYABLE_STREAM_ERROR_CODES,
_classify_final_failure,
_FinalFailure,
_flush_orphan_tool_uses_to_session,
@@ -320,3 +321,22 @@ class TestRetryRollbackContract:
"part-2",
f"{COPILOT_ERROR_PREFIX} Boom",
]
class TestRetryableStreamErrorCodes:
"""SECRT-2252: ``_dispatch_response`` consults this set to decide whether
the StreamError flowing through it should append a retryable marker (UI
shows a retry button) or a terminal one (UI shows ErrorCard only)."""
def test_transient_api_error_is_retryable(self):
assert "transient_api_error" in _RETRYABLE_STREAM_ERROR_CODES
def test_empty_completion_is_retryable(self):
# The adapter emits this for ghost-finished SDK turns. The user
# message ("The model returned an empty response.") only makes sense
# if the UI offers a retry — otherwise the user sees a dead error.
assert "empty_completion" in _RETRYABLE_STREAM_ERROR_CODES
def test_unknown_codes_are_not_retryable(self):
assert "sdk_error" not in _RETRYABLE_STREAM_ERROR_CODES
assert "all_attempts_exhausted" not in _RETRYABLE_STREAM_ERROR_CODES

View File

@@ -36,6 +36,7 @@ from backend.copilot.response_model import (
StreamReasoningStart,
StreamStart,
StreamStartStep,
StreamStatus,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
@@ -374,8 +375,41 @@ class SDKResponseAdapter:
responses.append(StreamFinishStep())
self.step_open = False
# Narrate the gap between "tool returned" and "model emits its
# next chunk". Usually sub-second, but with large tool outputs
# or complex continuations it can stretch long enough for the
# generic "Thinking…" copy to feel dead. The frontend replaces
# it with actual content as soon as the next chunk lands.
if resolved_in_blocks:
responses.append(StreamStatus(message="Analyzing result\u2026"))
elif isinstance(sdk_message, ResultMessage):
self.flush_unresolved_tool_calls(responses)
# SECRT-2252: surface ghost-finished sessions as errors instead of silent finishes.
if sdk_message.subtype == "success" and self._is_empty_completion(
sdk_message
):
if self.step_open:
responses.append(StreamFinishStep())
self.step_open = False
responses.append(
StreamError(
errorText="The model returned an empty response.",
code="empty_completion",
)
)
# Pair with StreamFinish so ``acc.stream_completed`` flips True
# in ``_dispatch_response`` — without it the service-layer
# post-stream branch mis-classifies the turn as "stopped by
# user" and appends a STOPPED_BY_USER_MARKER on top of the
# error marker.
responses.append(StreamFinish())
logger.warning(
"[SDK] [%s] Empty-success ResultMessage detected — "
"emitting stream error instead of silent finish",
(self.session_id or "?")[:12],
)
return responses
# Thinking-only final turn guard: when the model's last LLM
# call after a tool result produced only a ``ThinkingBlock``
# (no ``TextBlock``, no ``ToolUseBlock``) the UI has nothing
@@ -437,6 +471,25 @@ class SDKResponseAdapter:
return responses
def _is_empty_completion(self, msg: ResultMessage) -> bool:
"""True when a success ResultMessage carries no content at all.
Detects the SDK's ghost-finished session: empty ``result``, zero
``output_tokens``, and nothing emitted on the wire this turn (no
text, no reasoning, no tool calls).
"""
if msg.result:
return False
if self.has_started_text or self.has_started_reasoning:
return False
if self.current_tool_calls:
return False
if self._any_tool_results_seen:
return False
usage = msg.usage or {}
output_tokens = usage.get("output_tokens") or 0
return output_tokens == 0
def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
"""Start (or restart) a text block if needed."""
if not self.has_started_text or self.has_ended_text:

View File

@@ -25,6 +25,7 @@ from backend.copilot.response_model import (
StreamReasoningEnd,
StreamStart,
StreamStartStep,
StreamStatus,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
@@ -193,13 +194,15 @@ def test_tool_result_emits_output_and_finish_step():
content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")]
)
results = adapter.convert_message(result_msg)
assert len(results) == 2
assert len(results) == 3
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].toolCallId == "t1"
assert results[0].toolName == "find_agent" # prefix stripped
assert results[0].output == "found 3 agents"
assert results[0].success is True
assert isinstance(results[1], StreamFinishStep)
assert isinstance(results[2], StreamStatus)
assert results[2].message == "Analyzing result…"
def test_tool_result_error():
@@ -565,6 +568,105 @@ def test_result_success_does_not_synthesize_when_no_tools_ran():
assert text_deltas == []
def test_result_empty_success_emits_error_and_finish():
"""SECRT-2252: a ``subtype="success"`` ResultMessage with empty ``result``,
no produced content, and ``output_tokens == 0`` is the SDK's ghost-finish
bug. The adapter surfaces it as a ``StreamError`` *paired with*
``StreamFinish`` so the service-layer post-stream flow flips
``acc.stream_completed`` and skips the ``STOPPED_BY_USER_MARKER``
branch. ``SystemMessage(subtype="init")`` opened a step, so the
empty-completion branch must close it before emitting the error."""
adapter = _adapter()
adapter.convert_message(SystemMessage(subtype="init", data={}))
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
result=None,
usage={"input_tokens": 5, "output_tokens": 0},
)
results = adapter.convert_message(msg)
types = [type(r).__name__ for r in results]
assert "StreamFinishStep" in types
assert "StreamError" in types
assert "StreamFinish" in types
# Open step must be closed before the error, and the error must
# precede StreamFinish on the wire.
assert types.index("StreamFinishStep") < types.index("StreamError")
assert types.index("StreamError") < types.index("StreamFinish")
err = next(r for r in results if isinstance(r, StreamError))
assert err.code == "empty_completion"
def test_result_empty_success_with_empty_string_result_treated_as_empty():
"""An empty string (not just None) for ``result`` is also empty."""
adapter = _adapter()
adapter.convert_message(SystemMessage(subtype="init", data={}))
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
result="",
usage={"output_tokens": 0},
)
results = adapter.convert_message(msg)
err = next(r for r in results if isinstance(r, StreamError))
assert err.code == "empty_completion"
assert any(isinstance(r, StreamFinish) for r in results)
def test_result_success_with_text_emits_finish_not_error():
"""Non-empty success (text was produced) keeps the existing
``StreamFinish`` behaviour — no spurious error."""
adapter = _adapter()
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="hello")], model="test")
)
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
result="hello",
usage={"output_tokens": 5},
)
results = adapter.convert_message(msg)
types = [type(r).__name__ for r in results]
assert "StreamFinish" in types
assert "StreamError" not in types
def test_result_success_with_nonzero_output_tokens_not_empty():
"""If ``output_tokens > 0`` but ``result`` is empty, don't classify as
empty — fall through to the existing success path. No prior
AssistantMessage so the `output_tokens` guard is the only thing
keeping `_is_empty_completion()` from firing."""
adapter = _adapter()
adapter.convert_message(SystemMessage(subtype="init", data={}))
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
result="",
usage={"output_tokens": 50},
)
results = adapter.convert_message(msg)
types = [type(r).__name__ for r in results]
assert "StreamFinish" in types
assert "StreamError" not in types
def test_result_error_emits_error_and_finish():
adapter = _adapter()
msg = ResultMessage(
@@ -686,6 +788,7 @@ def test_full_conversation_flow():
"StreamToolInputAvailable",
"StreamToolOutputAvailable", # tool result
"StreamFinishStep", # step 1 closed after tool result
"StreamStatus", # user-facing status while continuation is generated
"StreamStartStep", # step 2: continuation text
"StreamTextStart", # new block after tool
"StreamTextDelta", # "I found 2"

View File

@@ -51,7 +51,6 @@ from ..constants import (
COPILOT_RETRYABLE_ERROR_PREFIX,
FRIENDLY_TRANSIENT_MSG,
STOPPED_BY_USER_MARKER,
STREAM_IDLE_TIMEOUT_SECONDS,
is_transient_api_error,
)
from ..session_cleanup import prune_orphan_tool_calls
@@ -185,13 +184,32 @@ _CIRCUIT_BREAKER_ERROR_MSG = (
"Try breaking your request into smaller parts."
)
# Idle timeout: abort the stream if no meaningful SDK message (only heartbeats)
# arrives for this many seconds. Derived from MAX_TOOL_WAIT_SECONDS so the
# invariant "no single tool blocks close to this long" holds by construction —
# long-running tools use the async "start + poll" pattern (initial tool returns
# with a handle, polling tool waits in ≤MAX_TOOL_WAIT_SECONDS chunks), so an
# idle of 2× that genuinely means the SDK itself is stuck.
_IDLE_TIMEOUT_SECONDS = STREAM_IDLE_TIMEOUT_SECONDS
# Two regimes: no tool pending → 30 min (SDK genuinely idle); tool pending →
# 2 h hard cap (lets long sub-AutoPilots run, still backstops a hung tool).
_IDLE_TIMEOUT_SECONDS = 30 * 60
_HUNG_TOOL_CAP_SECONDS = 2 * 60 * 60
def _idle_timeout_threshold(adapter: SDKResponseAdapter) -> int:
"""Pick the idle-timeout threshold for the current heartbeat.
Returns ``_HUNG_TOOL_CAP_SECONDS`` (longer) whenever any tool call is
still pending, so a legitimately long operation isn't killed. Returns
``_IDLE_TIMEOUT_SECONDS`` (shorter) when nothing is pending — the SDK
itself is idle with no work in flight.
"""
if adapter.has_unresolved_tool_calls:
return _HUNG_TOOL_CAP_SECONDS
return _IDLE_TIMEOUT_SECONDS
# StreamError codes that should render as a retryable error in the UI (retry
# button) rather than a terminal ErrorCard. Codes appended via
# ``_append_error_marker`` directly already pass ``retryable=True``; this set
# covers the codes that flow through the adapter -> ``_dispatch_response``.
_RETRYABLE_STREAM_ERROR_CODES: frozenset[str] = frozenset(
{"transient_api_error", "empty_completion"}
)
# Event types that are ephemeral / cosmetic and must NOT be counted toward
@@ -535,6 +553,26 @@ async def _reduce_context(
return ReducedContext(TranscriptBuilder(), False, None, True, True, retry_target)
def _humanise_tool_list(names: list[str]) -> str:
"""Format a list of tool names for user-facing messages.
``["WebSearch"]`` → ``"'WebSearch'"``
``["WebSearch", "run_block"]`` → ``"'WebSearch' and 'run_block'"``
Three or more items collapse to ``"'A', 'B', and 1 other"`` so the
toast stays readable.
"""
if not names:
return ""
quoted = [f"'{n}'" for n in names]
if len(quoted) == 1:
return quoted[0]
if len(quoted) == 2:
return f"{quoted[0]} and {quoted[1]}"
extras = len(quoted) - 2
suffix = "others" if extras > 1 else "other"
return f"{quoted[0]}, {quoted[1]}, and {extras} {suffix}"
def _append_error_marker(
session: ChatSession | None,
display_msg: str,
@@ -901,38 +939,46 @@ async def _iter_sdk_messages(
def _normalize_model_name(raw_model: str) -> str:
"""Normalize a model name for the current routing configuration.
"""Normalize a model name for the **actual** SDK CLI transport.
Two routing modes:
Three transports (see ``ChatConfig.effective_transport``):
1. **OpenRouter active** — the canonical OpenRouter slug is
``"<vendor>/<model>"`` (e.g. ``"anthropic/claude-opus-4.6"``,
``"moonshotai/kimi-k2.6"``). Pass the prefixed name through
1. **OpenRouter** — the canonical OpenRouter slug is
``"<vendor>/<model>"`` (e.g. ``"anthropic/claude-opus-4-6"``,
``"moonshotai/kimi-k2-6"``). Pass the prefixed name through
unchanged so OpenRouter can route to the correct provider. Anthropic
names happen to also resolve when stripped, but non-Anthropic vendors
(Moonshot, Google, etc.) do not — keeping the prefix is the only form
that works for every model in the catalog.
2. **Direct Anthropic** — strip the OpenRouter ``anthropic/`` prefix
and convert dots to hyphens (``"claude-opus-4.6"`` →
``"claude-opus-4-6"``) since the Anthropic Messages API rejects
both the prefix and dot-separated versions. Raises ``ValueError``
when a non-Anthropic vendor slug is paired with direct-Anthropic
mode — silently stripping ``moonshotai/`` would send ``kimi-k2.6``
to the Anthropic API and produce an opaque ``model_not_found``
error far from the misconfiguration source.
2. **Subscription / Direct Anthropic** — strip the OpenRouter
``anthropic/`` prefix and convert dots to hyphens
(``"claude-opus-4.6"`` → ``"claude-opus-4-6"``). The CLI subprocess
(subscription mode) and the Anthropic Messages API both reject the
prefix and dot-separated versions. Raises ``ValueError`` when a
non-Anthropic vendor slug is paired with these transports — silently
stripping ``moonshotai/`` would send ``kimi-k2-6`` to the Anthropic
API / CLI and produce an opaque ``model_not_found`` error far from
the misconfiguration source.
Gating on the **actual transport** (not just config shape) matters
because subscription mode and OpenRouter config can coexist —
``CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true`` paired with a populated
``CHAT_BASE_URL`` / ``CHAT_API_KEY`` (left over from an earlier
OpenRouter setup) used to incorrectly pass ``anthropic/claude-opus-4-7``
to the CLI subprocess, which the CLI rejects.
"""
if config.openrouter_active:
if config.effective_transport == "openrouter":
return raw_model
model = raw_model
if "/" in model:
vendor, model = model.split("/", 1)
if vendor != "anthropic":
raise ValueError(
f"Direct-Anthropic mode (use_openrouter=False or missing "
f"OpenRouter credentials) requires an Anthropic model, got "
f"vendor={vendor!r} from model={raw_model!r}. Set "
f"CHAT_THINKING_STANDARD_MODEL/CHAT_THINKING_ADVANCED_MODEL "
f"to an anthropic/* slug, or enable OpenRouter."
f"{config.effective_transport!r} transport requires an "
f"Anthropic model, got vendor={vendor!r} from "
f"model={raw_model!r}. Set CHAT_THINKING_STANDARD_MODEL/"
f"CHAT_THINKING_ADVANCED_MODEL to an anthropic/* slug, or "
f"enable OpenRouter."
)
return model.replace(".", "-")
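Collected from the docstring and the tests below, the expected mapping is roughly:

# OpenRouter transport: slug passes through untouched.
#   "anthropic/claude-opus-4.6" -> "anthropic/claude-opus-4.6"
#   "moonshotai/kimi-k2.6"      -> "moonshotai/kimi-k2.6"
# Subscription / direct-Anthropic transport: prefix stripped, dots hyphenated.
#   "anthropic/claude-opus-4.6" -> "claude-opus-4-6"
#   "moonshotai/kimi-k2.6"      -> ValueError (non-Anthropic vendor)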
@@ -1258,6 +1304,58 @@ def _write_cli_session_to_disk(
return False
def delete_stale_cli_session_file(
sdk_cwd: str,
session_id: str,
log_prefix: str,
) -> bool:
"""Delete the local CLI session file at the predictable path.
Used so a subsequent CLI invocation with ``--session-id`` (no ``--resume``)
doesn't trip ``"Session ID already in use"``. Path-traversal guard:
rejects paths outside the CLI projects base.
Returns True if a file was deleted, False otherwise (missing, traversal,
or unlink failure).
"""
real_path = os.path.realpath(cli_session_path(sdk_cwd, session_id))
if not real_path.startswith(projects_base() + os.sep):
# Mirror ``_write_cli_session_to_disk``'s defense-in-depth: log
# rather than fail silently when the resolved path escapes the
# projects base. In normal operation this is unreachable
# (session_id is a server-generated UUID and ``cli_session_path``
# is deterministic), so a hit indicates a config or tampering
# issue that's worth surfacing.
logger.warning(
"%s CLI session delete path outside projects base: %s",
log_prefix,
os.path.basename(real_path),
)
return False
# Direct unlink — no exists() check (avoids TOCTOU with the file being
# deleted by another process between check and unlink).
try:
Path(real_path).unlink()
logger.info(
"%s Removed stale local CLI session file at %s",
log_prefix,
os.path.basename(real_path),
)
return True
except FileNotFoundError:
return False
except OSError as unlink_err:
# Sanitise log: basename + strerror only (no full path / no raw
# exception which can echo absolute paths back in some libc errors).
logger.warning(
"%s Failed to remove stale local CLI session file %s: %s",
log_prefix,
os.path.basename(real_path),
unlink_err.strerror or type(unlink_err).__name__,
)
return False
def read_cli_session_from_disk(
sdk_cwd: str,
session_id: str,
@@ -2026,7 +2124,7 @@ def _dispatch_response(
_append_error_marker(
ctx.session,
response.errorText,
retryable=(response.code == "transient_api_error"),
retryable=response.code in _RETRYABLE_STREAM_ERROR_CODES,
)
if isinstance(response, StreamReasoningStart):
@@ -2354,6 +2452,13 @@ async def _run_stream_attempt(
for ev in ctx.compaction.emit_pre_query(ctx.session):
yield ev
# Narrate the silent gap between dispatching the query and the
# SDK's first real chunk — usually <1s but can stretch to several
# seconds on cold-starts or large contexts. The frontend prefers
# this over the generic "Thinking…" copy; fast turns replace it
# with content immediately.
yield StreamStatus(message="Contacting the model\u2026")
if ctx.attachments.image_blocks:
content_blocks: list[dict[str, Any]] = [
*ctx.attachments.image_blocks,
@@ -2388,21 +2493,41 @@ async def _run_stream_attempt(
yield ev
yield StreamHeartbeat()
# Idle timeout: abort if the SDK has been silent for too long.
# Long-running tools use the async "start + poll" pattern so
# the MCP handler never blocks longer than the poll cap (5 min)
# — a 10-min gap here means the SDK itself is stuck.
# Threshold flips to the long cap while a tool is pending; clock never resets.
idle_seconds = time.monotonic() - _last_real_msg_time
if idle_seconds >= _IDLE_TIMEOUT_SECONDS:
threshold = _idle_timeout_threshold(state.adapter)
if idle_seconds >= threshold:
unresolved_tool_names = sorted(
{
info.get("name", "unknown")
for tid, info in state.adapter.current_tool_calls.items()
if tid not in state.adapter.resolved_tool_calls
}
)
logger.error(
"%s Idle timeout after %.0fs — aborting stream",
"%s Idle timeout after %.0fs (threshold=%ds, "
"unresolved tools: %s) — aborting stream",
ctx.log_prefix,
idle_seconds,
threshold,
", ".join(unresolved_tool_names) or "none",
)
# The retryable marker written to the session omits
# the `[code:<id>]` prefix — the AI SDK serializer
# (`StreamError.to_sse`) attaches that automatically
# on the wire so the frontend can still parse a
# machine-readable code out of the otherwise opaque
# `{type, errorText}` schema.
stream_error_code = "idle_timeout"
tool_phrase = (
f" while running {_humanise_tool_list(unresolved_tool_names)}"
if unresolved_tool_names
else ""
)
stream_error_msg = (
"The session has been idle for too long. Please try again."
f"AutoPilot stopped responding{tool_phrase}. "
"This usually means a tool got stuck. Please try again."
)
stream_error_code = "idle_timeout"
_append_error_marker(ctx.session, stream_error_msg, retryable=True)
yield StreamError(
errorText=stream_error_msg,
@@ -3082,22 +3207,7 @@ async def _restore_cli_session_for_turn(
# session_id with "Session ID already in use". T1 may have
# left a valid file at this path; we clear it so the fallback
# path (session_id= without --resume) can create a new session.
_stale_path = os.path.realpath(cli_session_path(sdk_cwd, session_id))
if Path(_stale_path).exists() and _stale_path.startswith(
projects_base() + os.sep
):
try:
Path(_stale_path).unlink()
logger.debug(
"%s Removed stale local CLI session file for clean fallback",
log_prefix,
)
except OSError as _unlink_err:
logger.debug(
"%s Failed to remove stale local session file: %s",
log_prefix,
_unlink_err,
)
delete_stale_cli_session_file(sdk_cwd, session_id, log_prefix)
if cli_restore is not None:
result.transcript_content = stripped
@@ -3943,21 +4053,21 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
if ctx.use_resume and ctx.resume_file:
sdk_options_kwargs_retry["resume"] = ctx.resume_file
sdk_options_kwargs_retry.pop("session_id", None)
elif "session_id" in sdk_options_kwargs:
# Initial invocation used session_id (T1 or mode-switch
# T1): keep it so the CLI writes the session file to the
# predictable path for upload_transcript(). Storage is
# ephemeral per invocation, so no "Session ID already in
# use" conflict occurs — no prior file was restored.
else:
# No --resume on this retry. Whether we entered with
# ``session_id`` (T1, mode-switch) or with ``--resume`` (T2+),
# we want the recovery turn's CLI write to land on the
# predictable ``cli_session_path(.., session_id)`` so the
# post-turn ``upload_transcript`` actually picks up the
# rescued (compacted) content. Without this, a T2+ retry
# would drop session_id to dodge "Session ID already in use",
# write to a random path, and the upload would silently grab
# the stale pre-failure file — leaving GCS bloated and
# guaranteeing the next turn re-trips prompt-too-long.
if sdk_cwd:
delete_stale_cli_session_file(sdk_cwd, session_id, log_prefix)
sdk_options_kwargs_retry.pop("resume", None)
sdk_options_kwargs_retry["session_id"] = session_id
else:
# T2+ retry without --resume: initial invocation used
# --resume, which restored the T1 session file to local
# storage. Re-using session_id without --resume would
# fail with "Session ID already in use".
sdk_options_kwargs_retry.pop("resume", None)
sdk_options_kwargs_retry.pop("session_id", None)
# Recompute system_prompt for retry — the preset is safe on
# every turn (requires CLI ≥ 2.1.98, installed in the Docker
# image and configured via CHAT_CLAUDE_AGENT_CLI_PATH).

View File

@@ -13,6 +13,9 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from claude_agent_sdk import AssistantMessage, TextBlock, ToolUseBlock
from backend.copilot import config as cfg_mod
from backend.copilot.config import ChatConfig
from .conftest import build_test_transcript as _build_transcript
from .service import (
_RETRY_TARGET_TOKENS,
@@ -23,6 +26,7 @@ from .service import (
_iter_sdk_messages,
_normalize_model_name,
_reduce_context,
_resolve_sdk_model_for_request,
_restore_cli_session_for_turn,
_TokenUsage,
)
@@ -373,15 +377,15 @@ class TestNormalizeModelName:
"""
@pytest.fixture
def _direct_anthropic_config(self, monkeypatch: pytest.MonkeyPatch):
def _direct_anthropic_config(
self, monkeypatch: pytest.MonkeyPatch, _clean_config_env: None
):
"""Force ``config.openrouter_active = False`` for prefix-strip tests.
Pins the SDK model fields to anthropic/* so the new
``_validate_sdk_model_vendor_compatibility`` model_validator
permits ChatConfig construction.
"""
from backend.copilot import config as cfg_mod
cfg = cfg_mod.ChatConfig(
use_openrouter=False,
api_key=None,
@@ -393,10 +397,10 @@ class TestNormalizeModelName:
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
@pytest.fixture
def _openrouter_config(self, monkeypatch: pytest.MonkeyPatch):
def _openrouter_config(
self, monkeypatch: pytest.MonkeyPatch, _clean_config_env: None
):
"""Force ``config.openrouter_active = True`` for slug-preservation tests."""
from backend.copilot import config as cfg_mod
cfg = cfg_mod.ChatConfig(
use_openrouter=True,
api_key="or-key",
@@ -445,6 +449,172 @@ class TestNormalizeModelName:
"""Non-Anthropic vendors (Moonshot) require the prefix to route."""
assert _normalize_model_name("moonshotai/kimi-k2.6") == "moonshotai/kimi-k2.6"
@pytest.fixture
def _subscription_with_openrouter_config(
self, monkeypatch: pytest.MonkeyPatch, _clean_config_env: None
):
"""Subscription mode with leftover OpenRouter base_url + api_key.
Reproduces the bug: ``CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true`` plus
a populated ``CHAT_BASE_URL`` (e.g. left over from an earlier
OpenRouter setup) used to incorrectly preserve the OpenRouter slug
because the gate checked config shape (``openrouter_active``) not
actual transport. The CLI subprocess uses OAuth here and rejects
the OpenRouter format.
"""
cfg = cfg_mod.ChatConfig(
use_openrouter=True,
api_key="or-key",
base_url="https://openrouter.ai/api/v1",
use_claude_code_subscription=True,
)
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
def test_subscription_strips_anthropic_prefix_despite_openrouter_config(
self, _subscription_with_openrouter_config
):
"""Subscription transport must produce the CLI-friendly form even
when OpenRouter base_url + api_key are set — the CLI uses OAuth
and ignores those fields, so the OpenRouter slug would be rejected."""
assert _normalize_model_name("anthropic/claude-opus-4.7") == "claude-opus-4-7"
def test_subscription_rejects_non_anthropic_vendor(
self, _subscription_with_openrouter_config
):
"""The CLI subprocess can only talk to Anthropic models — Kimi via
Moonshot must raise so the resolver falls back to a tier default
instead of feeding an unroutable slug to the CLI."""
with pytest.raises(ValueError, match="requires an Anthropic model"):
_normalize_model_name("moonshotai/kimi-k2.6")
# ---------------------------------------------------------------------------
# ChatConfig.effective_transport — single source of truth for "which
# transport will the SDK CLI actually use?"
# ---------------------------------------------------------------------------
class TestEffectiveTransport:
"""Subscription mode wins over OpenRouter even when OpenRouter
base_url + api_key are set, because the CLI subprocess uses OAuth and
ignores ``CHAT_BASE_URL`` / ``CHAT_API_KEY`` (see ``build_sdk_env``
mode 1). Picking the right transport here is what lets
``_normalize_model_name`` produce the correct model-name format.
"""
def test_subscription_wins_over_openrouter_config(self, _clean_config_env):
cfg = ChatConfig(
use_openrouter=True,
api_key="or-key",
base_url="https://openrouter.ai/api/v1",
use_claude_code_subscription=True,
)
assert cfg.effective_transport == "subscription"
# ``openrouter_active`` is still True (config-shape check) but
# the actual transport is subscription.
assert cfg.openrouter_active is True
def test_openrouter_when_subscription_disabled(self, _clean_config_env):
cfg = ChatConfig(
use_openrouter=True,
api_key="or-key",
base_url="https://openrouter.ai/api/v1",
use_claude_code_subscription=False,
)
assert cfg.effective_transport == "openrouter"
def test_direct_anthropic_when_no_openrouter_no_subscription(
self, _clean_config_env
):
cfg = ChatConfig(
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=False,
thinking_standard_model="anthropic/claude-sonnet-4-6",
thinking_advanced_model="anthropic/claude-opus-4-7",
)
assert cfg.effective_transport == "direct_anthropic"
def test_subscription_alone_is_subscription(self, _clean_config_env):
cfg = ChatConfig(
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=True,
)
assert cfg.effective_transport == "subscription"
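# Illustration (not part of this change): a minimal sketch of the precedence
# the TestEffectiveTransport cases above pin down. Field and property names are
# borrowed from the test constructors; the real ChatConfig is a pydantic model
# with more fields and validators, so treat this as an assumption-laden sketch,
# not the shipped implementation.
from dataclasses import dataclass


@dataclass
class SketchChatConfig:
    use_openrouter: bool = False
    api_key: str | None = None
    base_url: str | None = None
    use_claude_code_subscription: bool = False

    @property
    def openrouter_active(self) -> bool:
        # Config-shape check only: the OpenRouter fields are populated.
        return self.use_openrouter and bool(self.api_key) and bool(self.base_url)

    @property
    def effective_transport(self) -> str:
        # Subscription wins even when OpenRouter config is populated, because
        # the CLI subprocess authenticates via OAuth and ignores
        # CHAT_BASE_URL / CHAT_API_KEY.
        if self.use_claude_code_subscription:
            return "subscription"
        if self.openrouter_active:
            return "openrouter"
        return "direct_anthropic"


assert SketchChatConfig(
    use_openrouter=True,
    api_key="or-key",
    base_url="https://openrouter.ai/api/v1",
    use_claude_code_subscription=True,
).effective_transport == "subscription"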
# ---------------------------------------------------------------------------
# _resolve_sdk_model_for_request — transport-aware LD-override normalisation
# ---------------------------------------------------------------------------
class TestResolveSdkModelForRequestTransportAware:
"""When subscription mode is on but the deployment also has OpenRouter
config populated (e.g. ``CHAT_BASE_URL`` left over from a previous
setup), an LD-served override must be normalised for the **subscription
CLI**, not passed through as the OpenRouter slug. The CLI subprocess
uses OAuth and rejects ``anthropic/claude-opus-4.7`` with the model
error reproduced in local debugging:
``There's an issue with the selected model
(anthropic/claude-opus-4.7). It may not exist or you may not have
access to it.``
"""
@pytest.mark.asyncio
async def test_subscription_advanced_override_normalised_for_cli(
self, monkeypatch: pytest.MonkeyPatch, _clean_config_env: None
):
cfg = cfg_mod.ChatConfig(
thinking_standard_model="anthropic/claude-sonnet-4-6",
thinking_advanced_model="anthropic/claude-opus-4.7",
claude_agent_model=None,
use_openrouter=True,
api_key="or-key",
base_url="https://openrouter.ai/api/v1",
use_claude_code_subscription=True,
)
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
with patch(
"backend.copilot.sdk.service._resolve_thinking_model_for_user",
new=AsyncMock(return_value="anthropic/claude-opus-4.7"),
):
resolved = await _resolve_sdk_model_for_request(
model="advanced", session_id="sess-adv", user_id="user-1"
)
# NOT the OpenRouter slug, NOT None — the CLI-friendly hyphenated form.
assert resolved == "claude-opus-4-7"
@pytest.mark.asyncio
async def test_subscription_standard_no_override_returns_none(
self, monkeypatch: pytest.MonkeyPatch, _clean_config_env: None
):
"""When LD agrees with the config default, subscription mode still
wins on the standard tier — returns ``None`` so the CLI picks the
subscription default model."""
cfg = cfg_mod.ChatConfig(
thinking_standard_model="anthropic/claude-sonnet-4-6",
claude_agent_model=None,
use_openrouter=True,
api_key="or-key",
base_url="https://openrouter.ai/api/v1",
use_claude_code_subscription=True,
)
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
with patch(
"backend.copilot.sdk.service._resolve_thinking_model_for_user",
new=AsyncMock(return_value="anthropic/claude-sonnet-4-6"),
):
resolved = await _resolve_sdk_model_for_request(
model="standard", session_id="sess-std", user_id="user-1"
)
assert resolved is None
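# Illustration (not part of this change): a self-contained sketch of the
# transport-aware normalisation these cases exercise. The real
# _normalize_model_name reads the module-level config instead of taking the
# transport as an argument, and the direct-Anthropic branch is not covered by
# the tests above, so everything here is an assumption-labelled approximation.
def sketch_normalize_model_name(model: str, transport: str) -> str:
    if transport == "openrouter":
        # OpenRouter routes on the full vendor-prefixed slug, e.g.
        # "moonshotai/kimi-k2.6" -- keep it untouched.
        return model
    if "/" in model:
        vendor, name = model.split("/", 1)
        if transport == "subscription" and vendor != "anthropic":
            # The CLI subprocess can only reach Anthropic models.
            raise ValueError(f"{model!r} requires an Anthropic model")
        model = name
    # CLI-friendly form: vendor prefix dropped, version dots hyphenated.
    return model.replace(".", "-")


assert sketch_normalize_model_name("anthropic/claude-opus-4.7", "subscription") == "claude-opus-4-7"
assert sketch_normalize_model_name("moonshotai/kimi-k2.6", "openrouter") == "moonshotai/kimi-k2.6"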
# ---------------------------------------------------------------------------
# _TokenUsage — null-safe accumulation (OpenRouter initial-stream-event bug)
@@ -566,17 +736,20 @@ def _build_retry_sdk_options(
ctx_resume_file: str | None,
session_id: str,
) -> dict:
"""Mirror the retry branch in stream_chat_completion_sdk."""
"""Mirror the retry branch in stream_chat_completion_sdk.
Production-side companion: ``delete_stale_cli_session_file`` is invoked
on every non-resume retry path so the CLI doesn't trip "Session ID
already in use" when we re-attach ``session_id``. This helper only
mirrors the kwarg shape (file-system side effect is tested separately).
"""
retry: dict = dict(initial_kwargs)
if ctx_use_resume and ctx_resume_file:
retry["resume"] = ctx_resume_file
retry.pop("session_id", None)
elif "session_id" in initial_kwargs:
retry.pop("resume", None)
retry["session_id"] = session_id
else:
retry.pop("resume", None)
retry.pop("session_id", None)
retry["session_id"] = session_id
return retry
@@ -648,12 +821,21 @@ class TestSdkSessionIdSelection:
assert retry.get("session_id") == self.SESSION_ID
assert "resume" not in retry
def test_retry_removes_session_id_for_t2_plus(self):
"""Retry for T2+ (initial used --resume) removes session_id to avoid conflict."""
def test_retry_keeps_session_id_for_t2_plus(self):
"""Retry for T2+ now keeps session_id so the recovery turn writes to
the predictable ``cli_session_path`` and gets uploaded. Production
clears the stale local file via ``delete_stale_cli_session_file``
before this branch runs to dodge "Session ID already in use".
Regression guard for SENTRY-1207: previously this branch dropped
session_id, the CLI wrote to a random path, and the post-turn
upload silently grabbed the stale pre-failure file — so GCS stayed
bloated and every subsequent turn re-tripped prompt-too-long.
"""
initial = _build_sdk_options(True, self.SESSION_ID, self.SESSION_ID)
# T2+ retry where context reduction dropped --resume
retry = _build_retry_sdk_options(initial, False, None, self.SESSION_ID)
assert "session_id" not in retry
assert retry.get("session_id") == self.SESSION_ID
assert "resume" not in retry
def test_retry_t2_with_resume_sets_resume(self):
@@ -1127,3 +1309,78 @@ class TestCompactionTargetTokens:
# Target derived from the RUNTIME model, not the compactor model.
assert captured["target_tokens"] == 12345
# ---------------------------------------------------------------------------
# delete_stale_cli_session_file — clears a leftover local session file so a
# subsequent --session-id (no --resume) invocation doesn't trip "Session ID
# already in use". Critical for the prompt-too-long retry path.
# ---------------------------------------------------------------------------
class TestDeleteStaleCliSessionFile:
def test_deletes_file_when_present(self, tmp_path) -> None:
from backend.copilot.sdk.service import delete_stale_cli_session_file
sdk_cwd = str(tmp_path / "cwd")
session_id = "sess-deadbeef"
with (
patch(
"backend.copilot.sdk.service.cli_session_path",
return_value=str(tmp_path / "session.jsonl"),
),
patch(
"backend.copilot.sdk.service.projects_base",
return_value=str(tmp_path),
),
):
target = tmp_path / "session.jsonl"
target.write_text("{}\n")
removed = delete_stale_cli_session_file(sdk_cwd, session_id, "[t]")
assert removed is True
assert not target.exists()
def test_returns_false_when_file_missing(self, tmp_path) -> None:
from backend.copilot.sdk.service import delete_stale_cli_session_file
with (
patch(
"backend.copilot.sdk.service.cli_session_path",
return_value=str(tmp_path / "missing.jsonl"),
),
patch(
"backend.copilot.sdk.service.projects_base",
return_value=str(tmp_path),
),
):
removed = delete_stale_cli_session_file("/cwd", "sess", "[t]")
assert removed is False
def test_path_traversal_guard_rejects_outside_projects_base(self, tmp_path) -> None:
"""Refuse to delete files outside the projects base, even if they exist."""
from backend.copilot.sdk.service import delete_stale_cli_session_file
outside = tmp_path / "outside.jsonl"
outside.write_text("data")
projects = tmp_path / "projects"
projects.mkdir()
with (
patch(
"backend.copilot.sdk.service.cli_session_path",
return_value=str(outside),
),
patch(
"backend.copilot.sdk.service.projects_base",
return_value=str(projects),
),
):
removed = delete_stale_cli_session_file("/cwd", "sess", "[t]")
# File was outside projects base — guard rejected, file untouched.
assert removed is False
assert outside.exists()
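# Illustration (not part of this change): the guard behaviour these tests pin,
# sketched with explicit path arguments. The real delete_stale_cli_session_file
# derives the target path from the sdk cwd and session id via cli_session_path,
# takes the base from projects_base(), and routes the warning through the
# service logger so it is Sentry-visible; the names below are placeholders for
# that wiring.
import logging
from pathlib import Path


def sketch_delete_stale_cli_session_file(
    session_file: str, base_dir: str, log_prefix: str
) -> bool:
    """Remove a leftover local CLI session file, refusing to touch anything
    that resolves outside the projects base (path-traversal guard)."""
    target = Path(session_file).resolve()
    base = Path(base_dir).resolve()
    if not target.is_relative_to(base):
        logging.warning("%s refusing to delete %s outside %s", log_prefix, target, base)
        return False
    if not target.exists():
        return False
    target.unlink()
    return True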

View File

@@ -12,8 +12,11 @@ import pytest
from backend.copilot import config as cfg_mod
from .service import (
_HUNG_TOOL_CAP_SECONDS,
_IDLE_TIMEOUT_SECONDS,
_build_system_prompt_value,
_humanise_tool_list,
_idle_timeout_threshold,
_is_sdk_disconnect_error,
_normalize_model_name,
_prepare_file_attachments,
@@ -323,27 +326,11 @@ class TestCleanupSdkToolResults:
# ---------------------------------------------------------------------------
# Env vars that ChatConfig validators read — must be cleared so explicit
# constructor values are used.
# Env-cleanup fixture is shared via ``conftest._clean_config_env``. This
# file exposes a re-export for callers that don't rely on conftest discovery
# (kept for backwards compatibility — pytest finds the conftest fixture
# automatically without an explicit import).
# ---------------------------------------------------------------------------
_CONFIG_ENV_VARS = (
"CHAT_USE_OPENROUTER",
"CHAT_API_KEY",
"OPEN_ROUTER_API_KEY",
"OPENAI_API_KEY",
"CHAT_BASE_URL",
"OPENROUTER_BASE_URL",
"OPENAI_BASE_URL",
"CHAT_USE_CLAUDE_CODE_SUBSCRIPTION",
"CHAT_USE_CLAUDE_AGENT_SDK",
"CHAT_CLAUDE_AGENT_CROSS_USER_PROMPT_CACHE",
)
@pytest.fixture()
def _clean_config_env(monkeypatch: pytest.MonkeyPatch) -> None:
for var in _CONFIG_ENV_VARS:
monkeypatch.delenv(var, raising=False)
class TestNormalizeModelName:
@@ -617,7 +604,13 @@ class TestResolveSdkModelForRequestLdFallback:
on ``copilot-model-routing[thinking][standard]`` returned
``None`` (CLI picked subscription default Opus), silently
ignoring the LD override. An LD value different from the
config default is an explicit admin decision and must win."""
config default is an explicit admin decision and must win.
Subscription transport rejects non-Anthropic vendors (the CLI
subprocess can't talk to Moonshot), so the resolver fails soft
to the tier default normalised for the subscription transport
(``claude-sonnet-4-6``) — not ``None``, which would silently
re-introduce the old subscription-default bypass."""
cfg = cfg_mod.ChatConfig(
thinking_standard_model="anthropic/claude-sonnet-4-6",
claude_agent_model=None,
@@ -635,8 +628,9 @@ class TestResolveSdkModelForRequestLdFallback:
resolved = await _resolve_sdk_model_for_request(
model="standard", session_id="sess-std-sub", user_id="user-1"
)
# Expect LD-served Kimi, NOT None (the old subscription-default bypass)
assert resolved == "moonshotai/kimi-k2.6"
# Kimi can't be served by the subscription CLI; fail-soft to
# the tier default normalised for the active transport.
assert resolved == "claude-sonnet-4-6"
@pytest.mark.asyncio
async def test_standard_subscription_survives_trailing_whitespace_in_env(
@@ -703,7 +697,10 @@ class TestResolveSdkModelForRequestLdFallback:
"""Subscription mode bypasses LD only on the standard tier —
the advanced tier always consults LD because the user explicitly
asked for the premium path. A subscription + advanced request
with LD-served Opus must return Opus (not ``None``)."""
with LD-served Opus must return Opus normalised for the
subscription CLI (``claude-opus-4-7``), not the OpenRouter slug
``anthropic/claude-opus-4.7`` which the CLI subprocess rejects
even when ``CHAT_BASE_URL`` is set to the OpenRouter proxy."""
cfg = cfg_mod.ChatConfig(
thinking_standard_model="anthropic/claude-sonnet-4-6",
thinking_advanced_model="anthropic/claude-opus-4.7",
@@ -722,7 +719,7 @@ class TestResolveSdkModelForRequestLdFallback:
resolved = await _resolve_sdk_model_for_request(
model="advanced", session_id="sess-adv-sub", user_id="user-1"
)
assert resolved == "anthropic/claude-opus-4.7"
assert resolved == "claude-opus-4-7"
# ---------------------------------------------------------------------------
@@ -907,14 +904,51 @@ class TestSystemPromptPreset:
assert cfg.claude_agent_cross_user_prompt_cache is False
class TestIdleTimeoutConstant:
"""SECRT-2247: long-running work now uses async start+poll pattern
(run_sub_session / run_agent), so no single MCP tool call ever blocks
the stream close to the idle limit. The plain 10-min cap from the
original code is restored."""
class TestStreamErrorCodePrefix:
"""StreamError.to_sse auto-prefixes errorText with `[code:<id>]` when a
code is set, so the frontend can parse a machine-readable code out of
the AI-SDK's strict `{type, errorText}` schema."""
def test_idle_timeout_is_10_min(self):
assert _IDLE_TIMEOUT_SECONDS == 10 * 60
def test_auto_prefix_when_code_set(self):
from backend.copilot.response_model import StreamError
sse = StreamError(errorText="Boom", code="idle_timeout").to_sse()
assert '"errorText":"[code:idle_timeout] Boom"' in sse
def test_no_prefix_when_code_missing(self):
from backend.copilot.response_model import StreamError
sse = StreamError(errorText="Boom").to_sse()
assert '"errorText":"Boom"' in sse
def test_does_not_double_prefix(self):
from backend.copilot.response_model import StreamError
sse = StreamError(errorText="[code:x] Boom", code="x").to_sse()
assert "[code:x] [code:x]" not in sse
assert '"errorText":"[code:x] Boom"' in sse
class TestHumaniseToolList:
"""Tool-name formatter used to build the idle-timeout error message."""
def test_empty_returns_empty_string(self):
assert _humanise_tool_list([]) == ""
def test_single_tool_is_quoted(self):
assert _humanise_tool_list(["WebSearch"]) == "'WebSearch'"
def test_two_tools_are_joined_with_and(self):
assert (
_humanise_tool_list(["WebSearch", "run_block"])
== "'WebSearch' and 'run_block'"
)
def test_three_uses_singular_other(self):
assert _humanise_tool_list(["a", "b", "c"]) == "'a', 'b', and 1 other"
def test_four_plus_uses_plural_others(self):
assert _humanise_tool_list(["a", "b", "c", "d"]) == "'a', 'b', and 2 others"
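# Illustration (not part of this change): the formatting rules the tests above
# pin, as a standalone sketch of _humanise_tool_list.
def sketch_humanise_tool_list(names: list[str]) -> str:
    if not names:
        return ""
    if len(names) == 1:
        return f"'{names[0]}'"
    if len(names) == 2:
        return f"'{names[0]}' and '{names[1]}'"
    extra = len(names) - 2
    return f"'{names[0]}', '{names[1]}', and {extra} other{'s' if extra > 1 else ''}"


assert sketch_humanise_tool_list(["a", "b", "c", "d"]) == "'a', 'b', and 2 others"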
# ---------------------------------------------------------------------------
@@ -1137,3 +1171,61 @@ class TestMoonshotHelperReexports:
from .service import _override_cost_for_moonshot
assert _override_cost_for_moonshot is canonical
class TestIdleTimeoutThreshold:
"""SECRT-2247: stream uses two idle thresholds. The shorter 30-min threshold
fires when the SDK is idle with no tool pending. The longer 2-hour cap
applies while any tool call is pending so a 45-min sub-AutoPilot isn't
killed, but a truly hung tool still eventually frees session resources."""
def _make_adapter(self, current: dict, resolved: set):
from backend.copilot.sdk.response_adapter import SDKResponseAdapter
adapter = SDKResponseAdapter(session_id="test")
adapter.current_tool_calls = current
adapter.resolved_tool_calls = resolved
return adapter
def test_threshold_uses_long_cap_with_unresolved_tool_call(self):
adapter = self._make_adapter(
current={"t1": {"name": "run_block"}},
resolved=set(),
)
assert _idle_timeout_threshold(adapter) == _HUNG_TOOL_CAP_SECONDS
def test_threshold_uses_short_cap_when_all_tools_resolved(self):
adapter = self._make_adapter(
current={"t1": {"name": "find_agent"}},
resolved={"t1"},
)
assert _idle_timeout_threshold(adapter) == _IDLE_TIMEOUT_SECONDS
def test_threshold_uses_short_cap_with_no_tool_calls(self):
adapter = self._make_adapter(current={}, resolved=set())
assert _idle_timeout_threshold(adapter) == _IDLE_TIMEOUT_SECONDS
def test_threshold_uses_long_cap_with_mixed_resolved_and_pending(self):
adapter = self._make_adapter(
current={
"t1": {"name": "find_agent"},
"t2": {"name": "run_block"},
},
resolved={"t1"},
)
assert _idle_timeout_threshold(adapter) == _HUNG_TOOL_CAP_SECONDS
def test_idle_timeout_is_30_min_not_the_old_10(self):
# Regression guard: the old 10-min value killed long tool calls
# (SECRT-2247). New idle-without-tools cap is 30 min.
assert _IDLE_TIMEOUT_SECONDS == 30 * 60
def test_hung_tool_cap_is_2_hours(self):
# Hard cap protects against a hung tool leaking resources forever.
# 2 hours is plenty for any legitimate sub-AutoPilot or graph run.
assert _HUNG_TOOL_CAP_SECONDS == 2 * 60 * 60
def test_long_cap_is_strictly_longer_than_short_cap(self):
# The whole point of the two-regime design: pending tools get more
# patience than pure idle.
assert _HUNG_TOOL_CAP_SECONDS > _IDLE_TIMEOUT_SECONDS
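# Illustration (not part of this change): the two-regime threshold decision
# the tests above pin. The adapter attribute names (current_tool_calls,
# resolved_tool_calls) come from the test helper; the real constants and
# _idle_timeout_threshold live in backend.copilot.sdk.service.
SKETCH_IDLE_TIMEOUT_SECONDS = 30 * 60        # idle with no pending tool call
SKETCH_HUNG_TOOL_CAP_SECONDS = 2 * 60 * 60   # hard cap while a tool is pending


def sketch_idle_timeout_threshold(adapter) -> int:
    pending = set(adapter.current_tool_calls) - set(adapter.resolved_tool_calls)
    # Stay patient while any tool call is outstanding (e.g. a long
    # sub-AutoPilot run); otherwise fall back to the short idle cap.
    return SKETCH_HUNG_TOOL_CAP_SECONDS if pending else SKETCH_IDLE_TIMEOUT_SECONDS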

View File

@@ -48,6 +48,7 @@ from .response_model import (
StreamReasoningStart,
StreamStart,
StreamStartStep,
StreamStatus,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
@@ -89,8 +90,19 @@ class ActiveSession:
def _get_session_meta_key(session_id: str) -> str:
"""Get Redis key for session metadata (keyed by session_id).
Hash-tag braces colocate this key with ``pending_messages._buffer_key``
on the same Redis Cluster slot — the gated-rpush Lua script touches both
keys atomically and would CROSSSLOT-fail if they hashed to different
shards.
"""
return f"{config.session_meta_prefix}{{{session_id}}}"
def get_session_meta_key(session_id: str) -> str:
"""Get Redis key for session metadata (keyed by session_id)."""
return f"{config.session_meta_prefix}{session_id}"
return _get_session_meta_key(session_id)
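# Illustration (not part of this change): why the {session_id} hash-tag braces
# matter. Redis Cluster hashes only the substring between the first "{" and the
# next "}", so any two keys sharing that tag map to the same slot, which is
# what lets the gated-rpush Lua script touch the session-meta key and the
# pending-messages buffer key in one atomic call without a CROSSSLOT error.
# The key prefixes below are placeholders, not the real config values.
def sketch_hash_tag(key: str) -> str:
    """Return the part of the key Redis Cluster actually hashes."""
    start = key.find("{")
    end = key.find("}", start + 1)
    return key[start + 1 : end] if start != -1 and end > start + 1 else key


assert (
    sketch_hash_tag("chat:meta:{sess-123}")
    == sketch_hash_tag("chat:pending:{sess-123}:buffer")
    == "sess-123"
)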
def _get_turn_stream_key(turn_id: str) -> str:
@@ -1093,6 +1105,7 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
ResponseType.ERROR.value: StreamError,
ResponseType.USAGE.value: StreamUsage,
ResponseType.HEARTBEAT.value: StreamHeartbeat,
ResponseType.STATUS.value: StreamStatus,
}
chunk_type = chunk_data.get("type")

View File

@@ -343,3 +343,73 @@ async def test_mark_session_completed_survives_lock_release_redis_error():
isinstance(call.args[1], stream_registry.StreamFinish)
for call in publish_mock.call_args_list
), "StreamFinish must still be published even if lock DELETE raises"
# ---------------------------------------------------------------------------
# Replays must contain protocol chunks only. Redis cursor data parts are not
# emitted because AI SDK resume needs the complete stream envelope from 0-0.
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_subscribe_to_session_replays_chunks_without_cursor_parts():
"""During replay, the subscriber queue contains chunks plus terminal finish."""
import orjson
from backend.copilot.response_model import (
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
)
# Three chunks recorded in Redis for a completed turn. Completed status
# means the listener branch is skipped and only the replay path runs,
# which keeps the test hermetic.
stream_key_msgs = [
(
"9999-0",
{"data": orjson.dumps(StreamTextStart(id="blk-1").model_dump()).decode()},
),
(
"9999-1",
{
"data": orjson.dumps(
StreamTextDelta(id="blk-1", delta="hi").model_dump()
).decode()
},
),
(
"9999-2",
{"data": orjson.dumps(StreamTextEnd(id="blk-1").model_dump()).decode()},
),
]
fake_redis = AsyncMock()
fake_redis.hgetall = AsyncMock(
return_value={
"user_id": "u1",
"session_id": "sess-1",
"turn_id": "turn-1",
"status": "completed", # finished → no listener task
}
)
fake_redis.xread = AsyncMock(return_value=[("stream-key", stream_key_msgs)])
with patch.object(
stream_registry, "get_redis_async", new=AsyncMock(return_value=fake_redis)
):
queue = await stream_registry.subscribe_to_session(
session_id="sess-1", user_id="u1", last_message_id="0-0"
)
assert queue is not None
delivered = []
while not queue.empty():
delivered.append(queue.get_nowait())
assert len(delivered) == 4
assert isinstance(delivered[0], StreamTextStart)
assert isinstance(delivered[1], StreamTextDelta)
assert isinstance(delivered[2], StreamTextEnd)
assert isinstance(delivered[3], stream_registry.StreamFinish)

View File

@@ -10,6 +10,7 @@ from backend.data.execution import (
ExecutionStatus,
GraphExecution,
GraphExecutionEvent,
exec_channel,
)
logger = logging.getLogger(__name__)
@@ -81,7 +82,7 @@ async def wait_for_execution(
)
event_bus = AsyncRedisExecutionEventBus()
channel_key = f"{user_id}/{graph_id}/{execution_id}"
channel_key = exec_channel(user_id, graph_id, execution_id)
# Mutable container so _subscribe_and_wait can surface the task even if
# asyncio.wait_for cancels the coroutine before it returns.

View File

@@ -949,24 +949,45 @@ class UserCredit(UserCreditBase):
f"Top up amount must be at least 500 credits and multiple of 100 but is {amount}"
)
# Resolve the Stripe Product ID from LD; when unset (default), keep the
# legacy inline product_data path (Stripe creates an ephemeral product
# per Checkout). When set, reference the canonical Product so all
# top-ups group under one entity in Stripe Dashboard reporting; the
# amount stays dynamic via unit_amount.
topup_product_id = await get_feature_flag_value(
Flag.STRIPE_PRODUCT_ID_TOPUP.value, user_id, default=None
)
line_items: list[stripe.checkout.Session.CreateParamsLineItem] = (
[
{
"price_data": {
"currency": "usd",
"product": topup_product_id,
"unit_amount": amount,
},
"quantity": 1,
}
]
if isinstance(topup_product_id, str) and topup_product_id
else [
{
"price_data": {
"currency": "usd",
"product_data": {"name": "AutoGPT Platform Credits"},
"unit_amount": amount,
},
"quantity": 1,
}
]
)
# Create checkout session
# https://docs.stripe.com/checkout/quickstart?client=react
# unit_amount param is always in the smallest currency unit (so cents for usd)
# which is equal to amount of credits
checkout_session = stripe.checkout.Session.create(
customer=await get_stripe_customer_id(user_id),
line_items=[
{
"price_data": {
"currency": "usd",
"product_data": {
"name": "AutoGPT Platform Credits",
},
"unit_amount": amount,
},
"quantity": 1,
}
],
line_items=line_items,
mode="payment",
ui_mode="hosted",
payment_intent_data={"setup_future_usage": "off_session"},
@@ -1442,6 +1463,7 @@ async def get_proration_credit_cents(user_id: str, monthly_cost_cents: int) -> i
# (move right) from downgrades (move left); ENTERPRISE is admin-managed and
# never reached via self-service flows.
_TIER_ORDER: tuple[SubscriptionTier, ...] = (
SubscriptionTier.NO_TIER,
SubscriptionTier.BASIC,
SubscriptionTier.PRO,
SubscriptionTier.MAX,
@@ -1479,6 +1501,30 @@ async def _get_active_subscription(customer_id: str) -> stripe.Subscription | No
return None
async def get_active_subscription_period_end(user_id: str) -> int | None:
"""Return the Unix timestamp of the active sub's current_period_end, or None.
Used to surface "next invoice on {date}" in upgrade dialog UX. Returns None
for users without a Stripe customer or active sub. Stripe failures swallow
to None — UX falls back to generic copy if the lookup misfires.
"""
user = await get_user_by_id(user_id)
if not user.stripe_customer_id:
return None
try:
sub = await _get_active_subscription(user.stripe_customer_id)
except stripe.StripeError:
logger.warning(
"get_active_subscription_period_end: Stripe lookup failed for user %s",
user_id,
)
return None
if sub is None:
return None
period_end = sub.current_period_end
return int(period_end) if period_end else None
# Substrings Stripe uses in InvalidRequestError messages when the schedule is
# already in a terminal state (released / completed / canceled) and therefore
# cannot be released again. We only swallow the error when one of these appears;
@@ -1670,7 +1716,7 @@ async def modify_stripe_subscription_for_tier(
user = await get_user_by_id(user_id)
if not user.stripe_customer_id:
return False
current_tier = user.subscription_tier or SubscriptionTier.BASIC
current_tier = user.subscription_tier or SubscriptionTier.NO_TIER
sub = await _get_active_subscription(user.stripe_customer_id)
if sub is None:
@@ -1891,7 +1937,7 @@ async def get_pending_subscription_change(
return None
effective_at = datetime.fromtimestamp(period_end, tz=timezone.utc)
if sub.cancel_at_period_end:
return SubscriptionTier.BASIC, effective_at
return SubscriptionTier.NO_TIER, effective_at
if not sub.schedule:
return None
schedule_id = sub.schedule if isinstance(sub.schedule, str) else sub.schedule.id
@@ -1986,6 +2032,7 @@ async def create_subscription_checkout(
success_url=success_url,
cancel_url=cancel_url,
subscription_data={"metadata": {"user_id": user_id, "tier": tier.value}},
allow_promotion_codes=True,
)
if not session.url:
# An empty checkout URL for a paid upgrade is always an error; surfacing it
@@ -2071,7 +2118,7 @@ async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
# ENTERPRISE user to a different tier — if a user on ENTERPRISE somehow has
# a self-service Stripe sub, it's a data-consistency issue for an operator,
# not something the webhook should automatically "fix".
current_tier = user.subscriptionTier or SubscriptionTier.BASIC
current_tier = user.subscriptionTier or SubscriptionTier.NO_TIER
if current_tier == SubscriptionTier.ENTERPRISE:
logger.warning(
"sync_subscription_from_stripe: refusing to overwrite ENTERPRISE tier"
@@ -2170,7 +2217,7 @@ async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
current_tier.value,
)
return
tier = SubscriptionTier.BASIC
tier = SubscriptionTier.NO_TIER
# Idempotency: Stripe retries webhooks on delivery failure, and several event
# types map to the same final tier. Skip the DB write + cache invalidation
# when the tier is already correct to avoid redundant writes on replay.
@@ -2269,7 +2316,7 @@ async def handle_subscription_payment_failure(invoice: dict) -> None:
)
return
current_tier = user.subscriptionTier or SubscriptionTier.BASIC
current_tier = user.subscriptionTier or SubscriptionTier.NO_TIER
if current_tier == SubscriptionTier.ENTERPRISE:
logger.warning(
"handle_subscription_payment_failure: skipping ENTERPRISE user %s"
@@ -2310,12 +2357,19 @@ async def handle_subscription_payment_failure(invoice: dict) -> None:
}
),
)
# Balance covered the invoice. Pay the Stripe invoice so Stripe's dunning
# system stops retrying it — without this call Stripe would retry automatically
# and re-trigger this webhook, causing double-deductions each retry cycle.
# Balance covered the invoice. Pay the Stripe invoice with
# ``paid_out_of_band=True`` so Stripe marks the invoice paid without
# retrying the card charge — the card already failed and the user is
# paying via their AutoGPT balance, so a card retry here would
# double-bill the user (card charge + balance debit). Stripe still
# fires ``invoice.payment_succeeded`` on the transition; the success
# handler reads ``paid_out_of_band`` and skips the credit grant so
# the balance debit isn't reversed.
if invoice_id:
try:
await run_in_threadpool(stripe.Invoice.pay, invoice_id)
await run_in_threadpool(
stripe.Invoice.pay, invoice_id, paid_out_of_band=True
)
except stripe.StripeError:
logger.warning(
"handle_subscription_payment_failure: balance deducted for user"
@@ -2355,7 +2409,95 @@ async def handle_subscription_payment_failure(invoice: dict) -> None:
customer_id,
)
return
await set_subscription_tier(user.id, SubscriptionTier.BASIC)
await set_subscription_tier(user.id, SubscriptionTier.NO_TIER)
async def handle_subscription_payment_success(invoice: dict) -> None:
"""Grant AutoGPT credits equal to the paid Stripe invoice amount.
Fires on every paid subscription invoice (initial signup, monthly renewal,
and prorated upgrade charges). Credits = ``invoice.amount_paid`` cents,
keyed by ``invoice_id`` for idempotency so Stripe retries don't double-grant.
Skipped:
- Non-subscription invoices (no ``subscription`` field).
- Zero-amount invoices (e.g. card-validation checks, $0 trials).
- ENTERPRISE users (admin-managed; they don't pay via self-service).
"""
customer_id = invoice.get("customer")
if not customer_id:
logger.warning(
"handle_subscription_payment_success: missing customer in invoice; skipping"
)
return
sub_id: str = invoice.get("subscription") or ""
if not sub_id:
# Non-subscription invoices (one-off invoices, etc.) — no credit grant.
return
user = await User.prisma().find_first(where={"stripeCustomerId": customer_id})
if not user:
logger.warning(
"handle_subscription_payment_success: no user for customer %s",
customer_id,
)
return
if (
user.subscriptionTier or SubscriptionTier.NO_TIER
) == SubscriptionTier.ENTERPRISE:
logger.warning(
"handle_subscription_payment_success: skipping ENTERPRISE user %s"
" (customer %s) — tier is admin-managed",
user.id,
customer_id,
)
return
amount_paid: int = invoice.get("amount_paid", 0)
invoice_id: str = invoice.get("id", "")
if amount_paid <= 0 or not invoice_id:
return
# Skip when ``handle_subscription_payment_failure`` already covered this
# invoice from the user's balance and marked it paid out of band — the
# balance was debited there, granting matching credits here would reverse
# the debit and give the user a free billing period.
if invoice.get("paid_out_of_band"):
logger.info(
"handle_subscription_payment_success: skipping invoice %s for user %s"
" (paid_out_of_band — covered by balance in failure handler)",
invoice_id,
user.id,
)
return
try:
await UserCredit()._add_transaction(
user_id=user.id,
amount=amount_paid,
transaction_type=CreditTransactionType.GRANT,
transaction_key=f"INVOICE-{invoice_id}",
metadata=SafeJson(
{
"stripe_customer_id": customer_id,
"stripe_subscription_id": sub_id,
"stripe_invoice_id": invoice_id,
"billing_reason": invoice.get("billing_reason", ""),
"reason": "subscription_invoice_paid",
}
),
)
logger.info(
"handle_subscription_payment_success: granted %d credits to user %s"
" for invoice %s (sub %s)",
amount_paid,
user.id,
invoice_id,
sub_id,
)
except UniqueViolationError:
# Idempotency key collision — Stripe retried this invoice's webhook and
# we already granted the credits. Safe to ignore.
return
async def admin_get_user_history(

View File

@@ -7,6 +7,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import stripe
from prisma.enums import SubscriptionTier
from prisma.errors import UniqueViolationError
from prisma.models import User
from backend.data.credit import (
@@ -15,6 +16,7 @@ from backend.data.credit import (
get_pending_subscription_change,
get_proration_credit_cents,
handle_subscription_payment_failure,
handle_subscription_payment_success,
is_tier_downgrade,
is_tier_upgrade,
modify_stripe_subscription_for_tier,
@@ -174,7 +176,7 @@ async def test_sync_subscription_from_stripe_enterprise_not_overwritten():
@pytest.mark.asyncio
async def test_sync_subscription_from_stripe_cancelled():
"""When the only active sub is cancelled, the user is downgraded to BASIC."""
"""When the only active sub is cancelled, the user is downgraded to NO_TIER."""
mock_user = _make_user(tier=SubscriptionTier.PRO)
stripe_sub = {
"id": "sub_old",
@@ -199,7 +201,7 @@ async def test_sync_subscription_from_stripe_cancelled():
) as mock_set,
):
await sync_subscription_from_stripe(stripe_sub)
mock_set.assert_awaited_once_with("user-1", SubscriptionTier.BASIC)
mock_set.assert_awaited_once_with("user-1", SubscriptionTier.NO_TIER)
@pytest.mark.asyncio
@@ -1284,7 +1286,10 @@ async def test_sync_subscription_from_stripe_no_metadata_user_id_skips_check():
@pytest.mark.asyncio
async def test_handle_subscription_payment_failure_balance_covers_pays_invoice():
"""When balance covers the invoice, Stripe Invoice.pay is called to stop retries."""
"""When balance covers the invoice, Stripe Invoice.pay is called with
paid_out_of_band=True so the card isn't double-charged on top of the
balance debit (the card already failed; retrying it would let the
success-handler webhook reverse the debit via the credit grant)."""
mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.PRO)
invoice = {
"id": "in_abc123",
@@ -1305,7 +1310,7 @@ async def test_handle_subscription_payment_failure_balance_covers_pays_invoice()
patch("backend.data.credit.stripe.Invoice.pay") as mock_pay,
):
await handle_subscription_payment_failure(invoice)
mock_pay.assert_called_once_with("in_abc123")
mock_pay.assert_called_once_with("in_abc123", paid_out_of_band=True)
@pytest.mark.asyncio
@@ -1367,6 +1372,356 @@ async def test_handle_subscription_payment_failure_passes_invoice_id_as_transact
assert kwargs.get("transaction_key") == "in_idempotency_test"
@pytest.mark.asyncio
async def test_handle_subscription_payment_success_grants_credits():
"""A paid subscription invoice grants credits = amount_paid, keyed by invoice_id."""
mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.PRO)
invoice = {
"id": "in_abc123",
"customer": "cus_123",
"subscription": "sub_abc123",
"amount_paid": 5000,
"billing_reason": "subscription_cycle",
}
add_tx_mock = AsyncMock()
with (
patch(
"backend.data.credit.User.prisma",
return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
),
patch(
"backend.data.credit.UserCredit._add_transaction",
new=add_tx_mock,
),
):
await handle_subscription_payment_success(invoice)
add_tx_mock.assert_awaited_once()
kwargs = add_tx_mock.await_args.kwargs
assert kwargs["amount"] == 5000
assert kwargs["transaction_key"] == "INVOICE-in_abc123"
@pytest.mark.asyncio
async def test_handle_subscription_payment_success_skips_non_subscription_invoice():
"""Invoices with no subscription field (one-off invoices) are no-ops."""
invoice = {
"id": "in_abc123",
"customer": "cus_123",
"amount_paid": 5000,
# No 'subscription' field
}
prisma_mock = MagicMock()
with patch("backend.data.credit.User.prisma", return_value=prisma_mock):
await handle_subscription_payment_success(invoice)
prisma_mock.find_first.assert_not_called()
@pytest.mark.asyncio
async def test_handle_subscription_payment_success_skips_paid_out_of_band():
"""When the failure handler covered the invoice from the user's balance and
marked it ``paid_out_of_band=True``, the success-handler webhook that
follows must NOT grant credits — doing so would reverse the balance debit
and effectively give the user a free billing period."""
mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.PRO)
invoice = {
"id": "in_oob_123",
"customer": "cus_123",
"subscription": "sub_abc123",
"amount_paid": 5000,
"billing_reason": "subscription_cycle",
"paid_out_of_band": True,
}
add_tx_mock = AsyncMock()
with (
patch(
"backend.data.credit.User.prisma",
return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
),
patch(
"backend.data.credit.UserCredit._add_transaction",
new=add_tx_mock,
),
):
await handle_subscription_payment_success(invoice)
add_tx_mock.assert_not_called()
@pytest.mark.asyncio
async def test_handle_subscription_payment_success_skips_zero_amount():
"""Zero-amount invoices (card validation, $0 trials) are no-ops."""
mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.PRO)
invoice = {
"id": "in_abc123",
"customer": "cus_123",
"subscription": "sub_abc123",
"amount_paid": 0,
}
add_tx_mock = AsyncMock()
with (
patch(
"backend.data.credit.User.prisma",
return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
),
patch(
"backend.data.credit.UserCredit._add_transaction",
new=add_tx_mock,
),
):
await handle_subscription_payment_success(invoice)
add_tx_mock.assert_not_called()
@pytest.mark.asyncio
async def test_handle_subscription_payment_success_skips_missing_customer():
"""Invoices missing the customer field are dropped with a warning."""
invoice = {
"id": "in_abc",
"subscription": "sub_abc",
"amount_paid": 1000,
}
prisma_mock = MagicMock()
with patch("backend.data.credit.User.prisma", return_value=prisma_mock):
await handle_subscription_payment_success(invoice)
prisma_mock.find_first.assert_not_called()
@pytest.mark.asyncio
async def test_handle_subscription_payment_success_skips_unknown_user():
"""Invoices for an unknown stripeCustomerId are dropped with a warning."""
invoice = {
"id": "in_abc",
"customer": "cus_unknown",
"subscription": "sub_abc",
"amount_paid": 1000,
}
add_tx_mock = AsyncMock()
with (
patch(
"backend.data.credit.User.prisma",
return_value=MagicMock(find_first=AsyncMock(return_value=None)),
),
patch(
"backend.data.credit.UserCredit._add_transaction",
new=add_tx_mock,
),
):
await handle_subscription_payment_success(invoice)
add_tx_mock.assert_not_called()
@pytest.mark.asyncio
async def test_handle_subscription_payment_success_skips_enterprise():
"""ENTERPRISE users don't get credit grants from Stripe invoices."""
mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.ENTERPRISE)
invoice = {
"id": "in_abc",
"customer": "cus_123",
"subscription": "sub_abc",
"amount_paid": 5000,
}
add_tx_mock = AsyncMock()
with (
patch(
"backend.data.credit.User.prisma",
return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
),
patch(
"backend.data.credit.UserCredit._add_transaction",
new=add_tx_mock,
),
):
await handle_subscription_payment_success(invoice)
add_tx_mock.assert_not_called()
@pytest.mark.asyncio
async def test_handle_subscription_payment_success_idempotent_on_unique_violation():
"""If the GRANT transaction key already exists (Stripe webhook retry),
UniqueViolationError is swallowed so the webhook returns 200 and Stripe
stops retrying."""
mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.PRO)
invoice = {
"id": "in_abc",
"customer": "cus_123",
"subscription": "sub_abc",
"amount_paid": 5000,
}
add_tx_mock = AsyncMock(side_effect=UniqueViolationError({"error": "dup"}))
with (
patch(
"backend.data.credit.User.prisma",
return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)),
),
patch(
"backend.data.credit.UserCredit._add_transaction",
new=add_tx_mock,
),
):
await handle_subscription_payment_success(invoice)
add_tx_mock.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_active_subscription_period_end_returns_unix_timestamp():
"""Happy path: returns int(current_period_end) for an active sub."""
mock_sub = stripe.Subscription.construct_from(
{"id": "sub_abc", "current_period_end": 1779340148}, "k"
)
mock_list = MagicMock()
mock_list.data = [mock_sub]
user = MagicMock(spec=User)
user.stripe_customer_id = "cus_abc"
with (
patch(
"backend.data.credit.get_user_by_id",
new_callable=AsyncMock,
return_value=user,
),
patch(
"backend.data.credit.stripe.Subscription.list_async",
new_callable=AsyncMock,
return_value=mock_list,
),
):
from backend.data.credit import get_active_subscription_period_end
result = await get_active_subscription_period_end("user-1")
assert result == 1779340148
@pytest.mark.asyncio
async def test_get_active_subscription_period_end_returns_none_without_customer():
"""Users without a Stripe customer ID return None — no Stripe API call."""
user = MagicMock(spec=User)
user.stripe_customer_id = None
list_mock = AsyncMock()
with (
patch(
"backend.data.credit.get_user_by_id",
new_callable=AsyncMock,
return_value=user,
),
patch(
"backend.data.credit.stripe.Subscription.list_async",
new=list_mock,
),
):
from backend.data.credit import get_active_subscription_period_end
result = await get_active_subscription_period_end("user-1")
assert result is None
list_mock.assert_not_called()
@pytest.mark.asyncio
async def test_get_active_subscription_period_end_swallows_stripe_errors():
"""A Stripe error during the lookup returns None instead of raising."""
user = MagicMock(spec=User)
user.stripe_customer_id = "cus_abc"
with (
patch(
"backend.data.credit.get_user_by_id",
new_callable=AsyncMock,
return_value=user,
),
patch(
"backend.data.credit.stripe.Subscription.list_async",
side_effect=stripe.StripeError("boom"),
),
):
from backend.data.credit import get_active_subscription_period_end
result = await get_active_subscription_period_end("user-1")
assert result is None
@pytest.mark.asyncio
async def test_top_up_intent_uses_inline_product_data_when_flag_unset():
"""When STRIPE_PRODUCT_ID_TOPUP flag is undefined (default), top-up Checkout
creates an ephemeral product per session via product_data."""
from backend.data.credit import UserCredit
mock_session = MagicMock()
mock_session.id = "cs_test_topup"
mock_session.url = "https://checkout.stripe.com/c/cs_test_topup"
create_mock = MagicMock(return_value=mock_session)
credit_system = UserCredit()
with (
patch(
"backend.data.credit.get_stripe_customer_id",
new_callable=AsyncMock,
return_value="cus_123",
),
patch(
"backend.data.credit.get_feature_flag_value",
new_callable=AsyncMock,
return_value=None,
),
patch(
"backend.data.credit.stripe.checkout.Session.create",
new=create_mock,
),
patch.object(credit_system, "_add_transaction", new_callable=AsyncMock),
):
await credit_system.top_up_intent(user_id="user-1", amount=500)
price_data = create_mock.call_args.kwargs["line_items"][0]["price_data"]
assert price_data == {
"currency": "usd",
"unit_amount": 500,
"product_data": {"name": "AutoGPT Platform Credits"},
}
@pytest.mark.asyncio
async def test_top_up_intent_references_product_id_when_flag_set():
"""When STRIPE_PRODUCT_ID_TOPUP flag returns a string, top-up Checkout
references the canonical Product ID and keeps the per-session amount via
unit_amount."""
from backend.data.credit import UserCredit
mock_session = MagicMock()
mock_session.id = "cs_test_topup"
mock_session.url = "https://checkout.stripe.com/c/cs_test_topup"
create_mock = MagicMock(return_value=mock_session)
credit_system = UserCredit()
with (
patch(
"backend.data.credit.get_stripe_customer_id",
new_callable=AsyncMock,
return_value="cus_123",
),
patch(
"backend.data.credit.get_feature_flag_value",
new_callable=AsyncMock,
return_value="prod_abc123",
),
patch(
"backend.data.credit.stripe.checkout.Session.create",
new=create_mock,
),
patch.object(credit_system, "_add_transaction", new_callable=AsyncMock),
):
await credit_system.top_up_intent(user_id="user-1", amount=2500)
price_data = create_mock.call_args.kwargs["line_items"][0]["price_data"]
assert price_data == {
"currency": "usd",
"unit_amount": 2500,
"product": "prod_abc123",
}
# No product_data — that path is mutually exclusive with product reference.
assert "product_data" not in price_data
@pytest.mark.asyncio
async def test_modify_stripe_subscription_for_tier_modifies_existing_sub():
"""modify_stripe_subscription_for_tier calls Subscription.modify and returns True."""
@@ -1845,7 +2200,7 @@ async def test_release_pending_subscription_schedule_no_stripe_customer_returns_
@pytest.mark.asyncio
async def test_get_pending_subscription_change_cancel_at_period_end():
"""cancel_at_period_end=True maps to pending BASIC at current_period_end."""
"""cancel_at_period_end=True maps to pending NO_TIER at current_period_end."""
import time as time_mod
get_pending_subscription_change.cache_clear() # type: ignore[attr-defined]
@@ -1894,7 +2249,7 @@ async def test_get_pending_subscription_change_cancel_at_period_end():
assert result is not None
pending_tier, effective_at = result
assert pending_tier == SubscriptionTier.BASIC
assert pending_tier == SubscriptionTier.NO_TIER
assert int(effective_at.timestamp()) == period_end

View File

@@ -98,6 +98,12 @@ from backend.data.notifications import (
)
from backend.data.onboarding import increment_onboarding_runs
from backend.data.platform_cost import log_platform_cost
from backend.data.push_subscription import (
cleanup_failed_subscriptions,
delete_push_subscription,
get_user_push_subscriptions,
increment_fail_count,
)
from backend.data.understanding import (
get_business_understanding,
upsert_business_understanding,
@@ -339,6 +345,16 @@ class DatabaseManager(AppService):
# ============ Platform Cost Tracking ============ #
log_platform_cost = _(log_platform_cost)
# ============ Push Notifications ============ #
get_user_push_subscriptions = _(get_user_push_subscriptions)
delete_push_subscription = _(delete_push_subscription)
increment_push_fail_count = _(
increment_fail_count, name="increment_push_fail_count"
)
cleanup_failed_push_subscriptions = _(
cleanup_failed_subscriptions, name="cleanup_failed_push_subscriptions"
)
# ============ Platform Linking ============ #
find_server_link_owner = _(platform_linking_db.find_server_link_owner)
find_user_link_owner = _(platform_linking_db.find_user_link_owner)
@@ -557,6 +573,12 @@ class DatabaseManagerAsyncClient(AppServiceClient):
# ============ Platform Cost Tracking ============ #
log_platform_cost = d.log_platform_cost
# ============ Push Notifications ============ #
get_user_push_subscriptions = d.get_user_push_subscriptions
delete_push_subscription = d.delete_push_subscription
increment_push_fail_count = d.increment_push_fail_count
cleanup_failed_push_subscriptions = d.cleanup_failed_push_subscriptions
# ============ Platform Linking ============ #
find_server_link_owner = d.find_server_link_owner
find_user_link_owner = d.find_user_link_owner

View File

@@ -0,0 +1,498 @@
"""End-to-end coverage of the data-layer APIs over the live 3-shard Redis
cluster + RabbitMQ broker. Tests skip when their infra is unreachable.
Container-restart scenarios live in `e2e_redis_restart_test.py`."""
from __future__ import annotations
import asyncio
import json
import time
from datetime import datetime, timezone
from uuid import uuid4
import pytest
import backend.data.redis_client as redis_client
from backend.api.model import NotificationPayload
from backend.data.execution import (
AsyncRedisExecutionEventBus,
ExecutionStatus,
NodeExecutionEvent,
exec_channel,
graph_all_channel,
)
from backend.data.notification_bus import (
AsyncRedisNotificationEventBus,
NotificationEvent,
)
from backend.data.rabbitmq import AsyncRabbitMQ
from backend.executor.utils import (
GRAPH_EXECUTION_EXCHANGE,
GRAPH_EXECUTION_QUEUE_NAME,
create_execution_queue_config,
)
def _has_live_cluster() -> bool:
try:
c = redis_client.connect()
except Exception: # noqa: BLE001 — any connect failure → skip
return False
try:
c.close()
except Exception:
pass
return True
def _has_live_rabbit() -> bool:
"""Probe the rabbitmq host:port from settings; skip if unreachable."""
import socket
from backend.util.settings import Settings
s = Settings()
try:
with socket.create_connection(
(s.config.rabbitmq_host, s.config.rabbitmq_port), timeout=1.0
):
return True
except Exception: # noqa: BLE001 - any connect failure → skip the test
return False
cluster_only = pytest.mark.skipif(
not _has_live_cluster(),
reason="local redis cluster not reachable; skip e2e integration",
)
rabbit_only = pytest.mark.skipif(
not _has_live_rabbit(),
reason="local rabbitmq not reachable; skip e2e integration",
)
def _make_node_event(*, user_id: str, graph_id: str, gex_id: str, marker: str):
return NodeExecutionEvent(
user_id=user_id,
graph_id=graph_id,
graph_version=1,
graph_exec_id=gex_id,
node_exec_id=f"node-exec-{marker}",
node_id="node-1",
block_id="block-1",
status=ExecutionStatus.COMPLETED,
input_data={"in": marker},
output_data={"out": [marker]},
add_time=datetime.now(tz=timezone.utc),
queue_time=None,
start_time=datetime.now(tz=timezone.utc),
end_time=datetime.now(tz=timezone.utc),
)
# ---------- Scenario 1: cluster cache round-trip across slots ----------
@cluster_only
def test_cluster_cache_roundtrip_across_three_slots() -> None:
"""A list-graphs-style cache flow: SET keys with hash tags that land on
different shards, GET them back. Validates the basic cluster-routing
contract end-to-end."""
redis_client.get_redis.cache_clear()
cluster = redis_client.get_redis()
keys = []
try:
# Pick keys that hash to different slots — try until 3 distinct shards.
seen: set[tuple[str, int]] = set()
for i in range(2000):
key = f"e2e:cache:{i}"
node = cluster.get_node_from_key(key)
owner = (node.host, node.port)
if owner in seen:
continue
seen.add(owner)
keys.append(key)
if len(seen) >= 3:
break
assert len(keys) >= 3
for i, k in enumerate(keys):
cluster.setex(k, 60, f"v-{i}")
for i, k in enumerate(keys):
assert cluster.get(k) == f"v-{i}"
finally:
for k in keys:
try:
cluster.delete(k)
except Exception:
pass
redis_client.disconnect()
# ---------- Scenarios 2 & 3: graph execution event streams ----------
@pytest.mark.asyncio
@cluster_only
async def test_graph_execution_events_complete_under_ten_seconds() -> None:
"""A listener subscribes to the per-exec channel; the producer publishes
one node event. The listener must observe it in under 10 seconds —
pins the latency contract end-to-end through SPUBLISH/SSUBSCRIBE."""
redis_client._async_clients.clear()
user_id = f"u-e2e-{uuid4().hex[:8]}"
graph_id = f"g-{uuid4().hex[:8]}"
gex_id = f"x-{uuid4().hex[:8]}"
publisher = AsyncRedisExecutionEventBus()
subscriber = AsyncRedisExecutionEventBus()
received: list[str] = []
async def _consume() -> None:
async for evt in subscriber.listen_events(
exec_channel(user_id, graph_id, gex_id)
):
received.append(getattr(evt, "node_exec_id", "graph"))
return
task = asyncio.create_task(_consume())
await asyncio.sleep(0.3)
start = time.monotonic()
try:
await publisher.publish_event(
_make_node_event(
user_id=user_id, graph_id=graph_id, gex_id=gex_id, marker="m1"
),
exec_channel(user_id, graph_id, gex_id),
)
await asyncio.wait_for(task, timeout=10.0)
finally:
if not task.done():
task.cancel()
await subscriber.close()
await redis_client.disconnect_async()
elapsed = time.monotonic() - start
assert elapsed < 10.0, f"event roundtrip took {elapsed:.2f}s, expected < 10s"
assert received == ["node-exec-m1"]
@pytest.mark.asyncio
@cluster_only
async def test_two_concurrent_graphs_no_cross_talk() -> None:
"""Two graphs execute in parallel; two listeners on different per-exec
channels each receive only their own events."""
redis_client._async_clients.clear()
user_id = f"u-e2e-{uuid4().hex[:8]}"
g1, g2 = f"g1-{uuid4().hex[:8]}", f"g2-{uuid4().hex[:8]}"
e1, e2 = f"e1-{uuid4().hex[:8]}", f"e2-{uuid4().hex[:8]}"
publisher = AsyncRedisExecutionEventBus()
sub_a = AsyncRedisExecutionEventBus()
sub_b = AsyncRedisExecutionEventBus()
async def _listen_one(bus, channel_key: str, sink: list, want: int) -> None:
async for evt in bus.listen_events(channel_key):
sink.append(getattr(evt, "node_exec_id", "graph"))
if len(sink) >= want:
return
sink_a: list[str] = []
sink_b: list[str] = []
t_a = asyncio.create_task(
_listen_one(sub_a, exec_channel(user_id, g1, e1), sink_a, want=3)
)
t_b = asyncio.create_task(
_listen_one(sub_b, exec_channel(user_id, g2, e2), sink_b, want=3)
)
await asyncio.sleep(0.3)
try:
for i in range(3):
await publisher.publish_event(
_make_node_event(
user_id=user_id, graph_id=g1, gex_id=e1, marker=f"a{i}"
),
exec_channel(user_id, g1, e1),
)
await publisher.publish_event(
_make_node_event(
user_id=user_id, graph_id=g2, gex_id=e2, marker=f"b{i}"
),
exec_channel(user_id, g2, e2),
)
await asyncio.wait_for(asyncio.gather(t_a, t_b), timeout=10.0)
assert sink_a == ["node-exec-a0", "node-exec-a1", "node-exec-a2"]
assert sink_b == ["node-exec-b0", "node-exec-b1", "node-exec-b2"]
finally:
await sub_a.close()
await sub_b.close()
await redis_client.disconnect_async()
# ---------- Scenario 4: aggregate /all channel for graph executions ----------
@pytest.mark.asyncio
@cluster_only
async def test_three_executions_land_on_aggregate_channel() -> None:
"""Subscribe to the aggregate ``/all`` channel; trigger 3 different
executions of the same graph; assert all 3 land on the aggregate."""
redis_client._async_clients.clear()
user_id = f"u-e2e-{uuid4().hex[:8]}"
graph_id = f"g-{uuid4().hex[:8]}"
exec_ids = [f"x{i}-{uuid4().hex[:6]}" for i in range(3)]
publisher = AsyncRedisExecutionEventBus()
subscriber = AsyncRedisExecutionEventBus()
received: list[str] = []
async def _listen_all() -> None:
async for evt in subscriber.listen_events(graph_all_channel(user_id, graph_id)):
received.append(getattr(evt, "graph_exec_id", "?"))
if len(received) >= 3:
return
task = asyncio.create_task(_listen_all())
await asyncio.sleep(0.3)
try:
for ex in exec_ids:
await publisher.publish_event(
_make_node_event(
user_id=user_id, graph_id=graph_id, gex_id=ex, marker=ex
),
graph_all_channel(user_id, graph_id),
)
await asyncio.wait_for(task, timeout=10.0)
# Order of receipt may vary slightly under load — check set membership.
assert set(received) == set(exec_ids)
finally:
await subscriber.close()
await redis_client.disconnect_async()
# ---------- Scenarios 5 & 6: copilot/notification per-user channels ----------
@pytest.mark.asyncio
@cluster_only
async def test_copilot_cancel_signal_via_sharded_pubsub() -> None:
"""A subscriber on a per-session channel receives an SPUBLISH cancel
signal — the primitive the copilot executor uses for graceful cancel."""
redis_client._async_clients.clear()
session_id = f"sess-{uuid4().hex[:8]}"
channel = "{copilot/" + session_id + "}/cancel"
client = await redis_client.connect_sharded_pubsub_async(channel)
pubsub = client.pubsub()
received: list[str] = []
try:
await pubsub.execute_command("SSUBSCRIBE", channel)
# Prime the channels map so listen() doesn't early-exit (see _Subscription).
pubsub.channels[channel] = None # type: ignore[index]
async def _pump() -> None:
async for msg in pubsub.listen():
if msg.get("type") == "smessage":
received.append(msg["data"])
return
listener = asyncio.create_task(_pump())
await asyncio.sleep(0.2)
cluster = await redis_client.get_redis_async()
await cluster.execute_command("SPUBLISH", channel, "cancel")
await asyncio.wait_for(listener, timeout=5.0)
assert received == ["cancel"]
finally:
try:
await pubsub.execute_command("SUNSUBSCRIBE", channel)
except Exception:
pass
await pubsub.aclose()
await client.aclose()
await redis_client.disconnect_async()
@pytest.mark.asyncio
@cluster_only
async def test_notification_fan_out_per_user_channel() -> None:
"""Per-user SSUBSCRIBE: a publish on the user's notification channel
reaches the user's listener and only that listener."""
redis_client._async_clients.clear()
user_id = f"u-notif-{uuid4().hex[:8]}"
other_user_id = f"u-other-{uuid4().hex[:8]}"
publisher = AsyncRedisNotificationEventBus()
listener_user = AsyncRedisNotificationEventBus()
listener_other = AsyncRedisNotificationEventBus()
user_received: list[str] = []
other_received: list[str] = []
notif_for_user = NotificationEvent(
user_id=user_id,
payload=NotificationPayload(type="info", event="balance-low"),
)
notif_for_other = NotificationEvent(
user_id=other_user_id,
payload=NotificationPayload(type="info", event="other"),
)
async def _listen_one(bus: AsyncRedisNotificationEventBus, uid: str, sink: list):
async for evt in bus.listen(uid):
sink.append(evt.user_id)
return
t_user = asyncio.create_task(_listen_one(listener_user, user_id, user_received))
t_other = asyncio.create_task(
_listen_one(listener_other, other_user_id, other_received)
)
await asyncio.sleep(0.3)
try:
await publisher.publish(notif_for_user)
await publisher.publish(notif_for_other)
await asyncio.wait_for(asyncio.gather(t_user, t_other), timeout=10.0)
assert user_received == [user_id]
assert other_received == [other_user_id]
finally:
await listener_user.close()
await listener_other.close()
await publisher.close()
await redis_client.disconnect_async()
# ---------- Scenario 7: idle WS connection 60s ----------
@pytest.mark.asyncio
@cluster_only
async def test_idle_subscriber_60s_then_receives_publish() -> None:
"""An SSUBSCRIBE that sits idle past one health-check interval must
still deliver a subsequent SPUBLISH (uses HEALTH_CHECK_INTERVAL+5s)."""
redis_client._async_clients.clear()
channel = "{idle-e2e}/exec/" + uuid4().hex[:8]
client = await redis_client.connect_sharded_pubsub_async(channel)
pubsub = client.pubsub()
try:
await pubsub.execute_command("SSUBSCRIBE", channel)
pubsub.channels[channel] = None # type: ignore[index]
# Drain ssubscribe confirm.
async for _msg in pubsub.listen():
break
idle_seconds = redis_client.HEALTH_CHECK_INTERVAL + 5
await asyncio.sleep(idle_seconds)
cluster = await redis_client.get_redis_async()
await cluster.execute_command("SPUBLISH", channel, "hello-after-idle")
async for msg in pubsub.listen():
if msg.get("type") == "smessage":
assert msg["data"] == "hello-after-idle"
return
finally:
try:
await pubsub.execute_command("SUNSUBSCRIBE", channel)
except Exception:
pass
await pubsub.aclose()
await client.aclose()
await redis_client.disconnect_async()
# ---------- Scenario 8: graph_execution_queue_v2 publish + consume ----------
@pytest.mark.asyncio
@rabbit_only
async def test_graph_execution_queue_publish_and_consume() -> None:
"""End-to-end on a test-scoped quorum queue: publish via AsyncRabbitMQ
→ consume → payload round-trips intact. Uses a unique routing key so
the live executor consumer (if any) doesn't race for the message."""
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
test_queue_name = f"e2e_test_{uuid4().hex[:8]}_v2"
test_routing_key = f"e2e.test.{uuid4().hex[:8]}"
test_exchange = Exchange(
name=GRAPH_EXECUTION_EXCHANGE.name,
type=ExchangeType.DIRECT,
durable=True,
)
test_queue = Queue(
name=test_queue_name,
durable=True,
# Quorum queues reject auto_delete; we delete the queue explicitly
# in the finally block instead.
auto_delete=False,
exchange=test_exchange,
routing_key=test_routing_key,
arguments={"x-queue-type": "quorum"},
)
cfg = RabbitMQConfig(vhost="/", exchanges=[test_exchange], queues=[test_queue])
publisher = AsyncRabbitMQ(cfg)
await publisher.connect()
consumer = AsyncRabbitMQ(cfg)
await consumer.connect()
payload = json.dumps(
{"graph_exec_id": f"e2e-{uuid4().hex[:8]}", "marker": "round-trip"}
)
try:
channel = await consumer.get_channel()
queue_obj = await channel.get_queue(test_queue_name)
await publisher.publish_message(
routing_key=test_routing_key,
message=payload,
exchange=test_exchange,
)
# Poll get() — quorum queue must surface the publish within 5s.
deadline = time.monotonic() + 5.0
msg = None
while time.monotonic() < deadline:
msg = await queue_obj.get(no_ack=True, fail=False)
if msg is not None:
break
await asyncio.sleep(0.05)
assert msg is not None, "publish never reached the quorum queue"
assert msg.body.decode() == payload
finally:
# Best-effort explicit delete; auto_delete is disabled for quorum queues.
try:
channel = await consumer.get_channel()
await channel.queue_delete(test_queue_name, if_unused=False, if_empty=False)
except Exception:
pass
await publisher.disconnect()
await consumer.disconnect()
@pytest.mark.asyncio
@rabbit_only
async def test_graph_execution_queue_uses_quorum_via_real_broker() -> None:
"""Live-broker check that `graph_execution_queue_v2` is declared as
quorum — passive re-declare with `x-queue-type=quorum` must not raise."""
cfg = create_execution_queue_config()
client = AsyncRabbitMQ(cfg)
await client.connect() # declares everything in cfg
try:
channel = await client.get_channel()
# Re-declare passively — must NOT raise PRECONDITION_FAILED if the
# type matches, would raise if quorum was lost.
q = await channel.declare_queue(
name=GRAPH_EXECUTION_QUEUE_NAME,
durable=True,
arguments={"x-queue-type": "quorum"},
passive=True,
)
assert q.name == GRAPH_EXECUTION_QUEUE_NAME
finally:
await client.disconnect()

View File

@@ -0,0 +1,313 @@
"""Sharded pubsub reconnect across a real `docker restart` of a shard,
against a private 3-shard cluster on isolated host ports. Gated on
`E2E_REDIS_CLUSTER_RESTART=1` + `docker` on PATH, marked `pytest.mark.slow`."""
from __future__ import annotations
import asyncio
import importlib
import os
import shutil
import socket
import subprocess
import time
from uuid import uuid4
import pytest
# Disjoint from the dev-compose ports (17000-17002) so both stacks coexist.
ISOLATED_PROJECT = "redis-restart-test"
ISOLATED_PORTS = (27110, 27111, 27112)
ISOLATED_BUS_PORTS = (37110, 37111, 37112)
def _docker_available() -> bool:
return shutil.which("docker") is not None
def _isolated_enabled() -> bool:
return os.getenv("E2E_REDIS_CLUSTER_RESTART", "").lower() in ("1", "true", "yes")
cluster_restart_only = pytest.mark.skipif(
not (_docker_available() and _isolated_enabled()),
reason=(
"isolated docker cluster restart e2e: requires docker + E2E_REDIS_CLUSTER_RESTART=1"
),
)
def _run(cmd: list[str], *, timeout: float = 60.0) -> subprocess.CompletedProcess[str]:
return subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=timeout,
check=False,
)
def _wait_port(port: int, *, deadline_s: float = 60.0) -> None:
deadline = time.monotonic() + deadline_s
while time.monotonic() < deadline:
try:
with socket.create_connection(("127.0.0.1", port), timeout=1.0):
return
except OSError:
time.sleep(0.5)
raise TimeoutError(f"port {port} never opened within {deadline_s:.0f}s")
def _start_isolated_cluster() -> None:
"""Spin up a private 3-shard cluster via raw `docker run` + one-shot
`redis-cli --cluster create`."""
network = f"{ISOLATED_PROJECT}-net"
_run(["docker", "network", "create", network]) # may exist; ignore exit
for i, (port, bus) in enumerate(zip(ISOLATED_PORTS, ISOLATED_BUS_PORTS)):
name = f"{ISOLATED_PROJECT}-redis-{i}"
_run(["docker", "rm", "-f", name])
rc = _run(
[
"docker",
"run",
"-d",
"--name",
name,
"--network",
network,
"--network-alias",
f"redis-{i}",
"-p",
f"{port}:{port}",
"redis:7",
"redis-server",
"--port",
str(port),
"--cluster-enabled",
"yes",
"--cluster-config-file",
"nodes.conf",
"--cluster-node-timeout",
"5000",
"--cluster-require-full-coverage",
"no",
"--cluster-announce-hostname",
f"redis-{i}",
"--cluster-announce-port",
str(port),
"--cluster-announce-bus-port",
str(bus),
"--cluster-preferred-endpoint-type",
"hostname",
]
)
if rc.returncode != 0:
raise RuntimeError(f"docker run redis-{i} failed: {rc.stderr}")
for port in ISOLATED_PORTS:
_wait_port(port)
rc = _run(
[
"docker",
"run",
"--rm",
"--network",
network,
"redis:7",
"redis-cli",
"--cluster",
"create",
f"redis-0:{ISOLATED_PORTS[0]}",
f"redis-1:{ISOLATED_PORTS[1]}",
f"redis-2:{ISOLATED_PORTS[2]}",
"--cluster-replicas",
"0",
"--cluster-yes",
]
)
if rc.returncode != 0:
raise RuntimeError(f"cluster create failed: {rc.stderr}")
deadline = time.monotonic() + 30
while time.monotonic() < deadline:
info = _run(
[
"docker",
"exec",
f"{ISOLATED_PROJECT}-redis-0",
"redis-cli",
"-p",
str(ISOLATED_PORTS[0]),
"cluster",
"info",
]
)
if "cluster_state:ok" in info.stdout:
return
time.sleep(0.5)
raise TimeoutError("isolated cluster never reached cluster_state:ok")
def _wait_cluster_ok(timeout_s: float = 30.0) -> bool:
deadline = time.monotonic() + timeout_s
while time.monotonic() < deadline:
info = _run(
[
"docker",
"exec",
f"{ISOLATED_PROJECT}-redis-0",
"redis-cli",
"-p",
str(ISOLATED_PORTS[0]),
"cluster",
"info",
]
)
if "cluster_state:ok" in info.stdout:
return True
time.sleep(0.5)
return False
def _teardown_isolated_cluster() -> None:
for i in range(3):
_run(["docker", "rm", "-f", f"{ISOLATED_PROJECT}-redis-{i}"])
_run(["docker", "network", "rm", f"{ISOLATED_PROJECT}-net"])
@pytest.fixture(scope="module")
def isolated_cluster():
"""Module-scoped: tests share one cluster lifecycle."""
_start_isolated_cluster()
try:
yield
finally:
_teardown_isolated_cluster()
@pytest.mark.asyncio
@pytest.mark.slow
@cluster_restart_only
async def test_subscriber_survives_shard_restart(isolated_cluster, monkeypatch) -> None:
"""Subscriber must receive a post-`docker restart` SPUBLISH after
reopening the sharded-pubsub client (the broker drops the socket on
restart; production's `with_pubsub` loop reconnects the same way)."""
# Must override REDIS_CLUSTER_HOST/PORT too — those take precedence
# over REDIS_HOST/PORT and a stray .env would point us at the dev cluster.
monkeypatch.setenv("REDIS_HOST", "127.0.0.1")
monkeypatch.setenv("REDIS_PORT", str(ISOLATED_PORTS[0]))
monkeypatch.setenv("REDIS_CLUSTER_HOST", "127.0.0.1")
monkeypatch.setenv("REDIS_CLUSTER_PORT", str(ISOLATED_PORTS[0]))
monkeypatch.setenv("REDIS_USE_ANNOUNCED_ADDRESS", "false")
monkeypatch.delenv("REDIS_PASSWORD", raising=False)
import backend.data.redis_client as rc
importlib.reload(rc)
# Restart whichever container owns the keyslot, not a guess.
cluster = rc.get_redis()
target_tag = f"restart-{uuid4().hex[:8]}"
channel = "{" + target_tag + "}/restart-test"
owner = cluster.get_node_from_key(channel)
port_to_idx = {p: i for i, p in enumerate(ISOLATED_PORTS)}
target_idx = port_to_idx.get(owner.port)
assert (
target_idx is not None
), f"owner port {owner.port} not in known set {ISOLATED_PORTS}"
target_container = f"{ISOLATED_PROJECT}-redis-{target_idx}"
client = await rc.connect_sharded_pubsub_async(channel)
pubsub = client.pubsub()
await pubsub.execute_command("SSUBSCRIBE", channel)
pubsub.channels[channel] = None # type: ignore[index]
received: list[str] = []
async def _drain_one() -> str | None:
try:
async for msg in pubsub.listen():
if msg.get("type") == "smessage":
return msg["data"]
except Exception:
return None
return None
try:
async_cluster = await rc.get_redis_async()
await async_cluster.execute_command("SPUBLISH", channel, "before-restart")
first = await asyncio.wait_for(_drain_one(), timeout=6.0)
received.append(first or "")
assert received == [
"before-restart"
], f"pre-restart publish did not arrive: {received}"
# Restart the shard that owns the slot.
rc_restart = _run(["docker", "restart", "--time", "1", target_container])
assert rc_restart.returncode == 0, rc_restart.stderr
assert _wait_cluster_ok(
timeout_s=30
), "isolated cluster never re-converged to state=ok after restart"
# Hold a small grace window for the shard's gossip to settle.
await asyncio.sleep(1.0)
# Old socket is dead — open a fresh sharded-pubsub connection.
try:
await pubsub.aclose()
except Exception:
pass
try:
await client.aclose()
except Exception:
pass
rc._async_clients.clear()
client2 = await rc.connect_sharded_pubsub_async(channel)
pubsub2 = client2.pubsub()
try:
await pubsub2.execute_command("SSUBSCRIBE", channel)
pubsub2.channels[channel] = None # type: ignore[index]
# Drain the SSUBSCRIBE confirm.
async for _msg in pubsub2.listen():
break
async def _drain_after() -> str | None:
async for msg in pubsub2.listen():
if msg.get("type") == "smessage":
return msg["data"]
return None
async_cluster_2 = await rc.get_redis_async()
await async_cluster_2.execute_command("SPUBLISH", channel, "after-restart")
data = await asyncio.wait_for(_drain_after(), timeout=15.0)
assert (
data == "after-restart"
), f"subscriber did not receive post-restart event (got {data!r})"
finally:
try:
await pubsub2.execute_command("SUNSUBSCRIBE", channel)
except Exception:
pass
try:
await pubsub2.aclose()
except Exception:
pass
await client2.aclose()
finally:
try:
await pubsub.aclose()
except Exception:
pass
try:
await client.aclose()
except Exception:
pass
await rc.disconnect_async()
# Undo monkeypatched env BEFORE reloading so subsequent tests see the
# original REDIS_HOST/PORT — otherwise the module captures the
# isolated cluster's port (27110) which is torn down right after this
# test, and any later test that touches redis hangs on conn_retry.
monkeypatch.undo()
importlib.reload(rc)
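# Illustrative consumer sketch (not part of this change): the shape of the
# loop a long-lived listener needs when the owning shard bounces. The broker
# drops the socket, so the loop reopens a shard-pinned client and re-issues
# SSUBSCRIBE. ``handle`` is a placeholder callback; the real reconnect logic
# lives in the production helper, this only mirrors what the test exercises.
async def _consume_with_reconnect(channel: str, handle) -> None:
    from backend.data import redis_client as rc

    while True:
        client = await rc.connect_sharded_pubsub_async(channel)
        pubsub = client.pubsub()
        try:
            await pubsub.execute_command("SSUBSCRIBE", channel)
            pubsub.channels[channel] = None  # keep redis-py 6.x listen() alive
            async for msg in pubsub.listen():
                if msg.get("type") == "smessage":
                    await handle(msg["data"])
        except Exception:
            await asyncio.sleep(1.0)  # back off before reopening the connection
        finally:
            for closer in (pubsub.aclose, client.aclose):
                try:
                    await closer()
                except Exception:
                    pass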

View File

@@ -1,7 +1,15 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Generator, Generic, Optional, TypeVar
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Generator,
Generic,
Optional,
TypeVar,
)
from pydantic import BaseModel
from redis.asyncio.client import PubSub as AsyncPubSub
@@ -11,6 +19,9 @@ from backend.data import redis_client as redis
from backend.util import json
from backend.util.settings import Settings
if TYPE_CHECKING:
from redis.asyncio import Redis as AsyncRedis
logger = logging.getLogger(__name__)
config = Settings().config
@@ -18,6 +29,15 @@ config = Settings().config
M = TypeVar("M", bound=BaseModel)
def _assert_no_wildcard(channel_key: str) -> None:
"""Sharded pub/sub has no pattern-subscribe; fail fast on wildcards."""
if "*" in channel_key:
raise ValueError(
f"channel_key {channel_key!r} contains a wildcard; sharded pub/sub "
"(SSUBSCRIBE) requires exact channel names."
)
class BaseRedisEventBus(Generic[M], ABC):
Model: type[M]
@@ -71,8 +91,8 @@ class BaseRedisEventBus(Generic[M], ABC):
return message, channel_name
def _deserialize_message(self, msg: Any, channel_key: str) -> M | None:
message_type = "pmessage" if "*" in channel_key else "message"
if msg["type"] != message_type:
# Accept sharded (smessage) and classic (message/pmessage) deliveries.
if msg["type"] not in ("smessage", "message", "pmessage"):
return None
try:
logger.debug(f"[{channel_key}] Consuming an event from Redis {msg['data']}")
@@ -80,12 +100,8 @@ class BaseRedisEventBus(Generic[M], ABC):
except Exception as e:
logger.error(f"Failed to parse event result from Redis {msg} {e}")
def _get_pubsub_channel(
self, connection: redis.Redis | redis.AsyncRedis, channel_key: str
) -> tuple[PubSub | AsyncPubSub, str]:
full_channel_name = f"{self.event_bus_name}/{channel_key}"
pubsub = connection.pubsub()
return pubsub, full_channel_name
def _build_channel_name(self, channel_key: str) -> str:
return f"{self.event_bus_name}/{channel_key}"
class _EventPayloadWrapper(BaseModel, Generic[M]):
@@ -98,88 +114,97 @@ class _EventPayloadWrapper(BaseModel, Generic[M]):
class RedisEventBus(BaseRedisEventBus[M], ABC):
@property
def connection(self) -> redis.Redis:
return redis.get_redis()
def publish_event(self, event: M, channel_key: str):
"""
Publish an event to Redis. Gracefully handles connection failures
by logging the error instead of raising exceptions.
"""
"""Publish via SPUBLISH; swallow failures so Redis blips don't crash callers."""
_assert_no_wildcard(channel_key)
try:
message, full_channel_name = self._serialize_message(event, channel_key)
self.connection.publish(full_channel_name, message)
cluster = redis.get_redis()
cluster.execute_command("SPUBLISH", full_channel_name, message)
except Exception:
logger.exception(
f"Failed to publish event to Redis channel {channel_key}. "
"Event bus operation will continue without Redis connectivity."
)
logger.exception(f"Failed to publish event to Redis channel {channel_key}")
def listen_events(self, channel_key: str) -> Generator[M, None, None]:
pubsub, full_channel_name = self._get_pubsub_channel(
self.connection, channel_key
)
assert isinstance(pubsub, PubSub)
_assert_no_wildcard(channel_key)
full_channel_name = self._build_channel_name(channel_key)
if "*" in channel_key:
pubsub.psubscribe(full_channel_name)
else:
pubsub.subscribe(full_channel_name)
for message in pubsub.listen():
if event := self._deserialize_message(message, channel_key):
yield event
cluster = redis.get_redis()
pubsub: PubSub = cluster.pubsub()
try:
pubsub.ssubscribe(full_channel_name)
for message in pubsub.listen():
if event := self._deserialize_message(message, channel_key):
yield event
finally:
try:
pubsub.sunsubscribe(full_channel_name)
except Exception:
logger.warning(
"Failed to SUNSUBSCRIBE from %s", full_channel_name, exc_info=True
)
try:
pubsub.close()
except Exception:
logger.warning(
"Failed to close sharded pubsub for %s",
full_channel_name,
exc_info=True,
)
class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
def __init__(self):
self._pubsub: AsyncPubSub | None = None
@property
async def connection(self) -> redis.AsyncRedis:
return await redis.get_redis_async()
async def close(self) -> None:
"""Close the PubSub connection if it exists."""
if self._pubsub is not None:
try:
await self._pubsub.close()
except Exception:
logger.warning("Failed to close PubSub connection", exc_info=True)
finally:
self._pubsub = None
"""No-op kept for backward compatibility.
Earlier revisions of this class stored the per-listen pubsub on the
instance, requiring an external close. ``listen_events`` now owns its
own client/pubsub locally so concurrent calls on a singleton (e.g.
``_webhook_event_bus``) cannot clobber each other's connection.
"""
return None
async def publish_event(self, event: M, channel_key: str):
"""
Publish an event to Redis. Gracefully handles connection failures
by logging the error instead of raising exceptions.
"""
"""Publish via SPUBLISH; swallow failures so Redis blips don't crash callers."""
_assert_no_wildcard(channel_key)
try:
message, full_channel_name = self._serialize_message(event, channel_key)
connection = await self.connection
await connection.publish(full_channel_name, message)
cluster = await redis.get_redis_async()
# redis-py 6.x async cluster has no spublish(); execute_command handles MOVED.
await cluster.execute_command("SPUBLISH", full_channel_name, message)
except Exception:
logger.exception(
f"Failed to publish event to Redis channel {channel_key}. "
"Event bus operation will continue without Redis connectivity."
)
logger.exception(f"Failed to publish event to Redis channel {channel_key}")
async def listen_events(self, channel_key: str) -> AsyncGenerator[M, None]:
pubsub, full_channel_name = self._get_pubsub_channel(
await self.connection, channel_key
_assert_no_wildcard(channel_key)
full_channel_name = self._build_channel_name(channel_key)
# Sharded pub/sub only delivers on the keyslot-owning shard, so pin
# a plain AsyncRedis to that node. Both client and pubsub stay
# generator-local — concurrent listen_events on the same instance
# (e.g. the singleton _webhook_event_bus) must not share state.
client: "AsyncRedis" = await redis.connect_sharded_pubsub_async(
full_channel_name
)
assert isinstance(pubsub, AsyncPubSub)
self._pubsub = pubsub
if "*" in channel_key:
await pubsub.psubscribe(full_channel_name)
else:
await pubsub.subscribe(full_channel_name)
async for message in pubsub.listen():
if event := self._deserialize_message(message, channel_key):
yield event
pubsub: AsyncPubSub = client.pubsub()
try:
await pubsub.execute_command("SSUBSCRIBE", full_channel_name)
# redis-py 6.x async PubSub.listen() exits when ``channels`` is
# empty; raw SSUBSCRIBE doesn't populate it, so do it ourselves.
pubsub.channels[full_channel_name] = None # type: ignore[index]
async for message in pubsub.listen():
if event := self._deserialize_message(message, channel_key):
yield event
finally:
try:
await pubsub.aclose()
except Exception:
logger.warning("Failed to close PubSub connection", exc_info=True)
try:
await client.aclose()
except Exception:
logger.warning(
"Failed to close shard-pinned Redis connection", exc_info=True
)
async def wait_for_event(
self, channel_key: str, timeout: Optional[float] = None

View File

@@ -1,25 +1,26 @@
"""
Tests for event_bus graceful degradation when Redis is unavailable.
"""
"""Tests for event_bus publish/listen paths."""
from unittest.mock import AsyncMock, patch
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import BaseModel
from backend.data.event_bus import AsyncRedisEventBus
from backend.data.event_bus import (
AsyncRedisEventBus,
RedisEventBus,
_assert_no_wildcard,
)
class TestEvent(BaseModel):
"""Test event model."""
class SampleEvent(BaseModel):
"""Minimal event model used by the tests below."""
message: str
class TestNotificationBus(AsyncRedisEventBus[TestEvent]):
"""Test implementation of AsyncRedisEventBus."""
Model = TestEvent
class _BusUnderTest(AsyncRedisEventBus[SampleEvent]):
Model = SampleEvent
@property
def event_bus_name(self) -> str:
@@ -28,11 +29,10 @@ class TestNotificationBus(AsyncRedisEventBus[TestEvent]):
@pytest.mark.asyncio
async def test_publish_event_handles_connection_failure_gracefully():
"""Test that publish_event logs exception instead of raising when Redis is unavailable."""
bus = TestNotificationBus()
event = TestEvent(message="test message")
"""publish_event must log and swallow when the cluster client is down."""
bus = _BusUnderTest()
event = SampleEvent(message="test message")
# Mock get_redis_async to raise connection error
with patch(
"backend.data.event_bus.redis.get_redis_async",
side_effect=ConnectionError("Authentication required."),
@@ -42,15 +42,487 @@ async def test_publish_event_handles_connection_failure_gracefully():
@pytest.mark.asyncio
async def test_publish_event_works_with_redis_available():
"""Test that publish_event works normally when Redis is available."""
bus = TestNotificationBus()
event = TestEvent(message="test message")
async def test_publish_event_spublishes_via_cluster_client():
"""publish_event routes a single SPUBLISH through the cluster client."""
bus = _BusUnderTest()
event = SampleEvent(message="test message")
# Mock successful Redis connection
mock_redis = AsyncMock()
mock_redis.publish = AsyncMock()
mock_cluster = MagicMock()
mock_cluster.execute_command = AsyncMock()
with patch("backend.data.event_bus.redis.get_redis_async", return_value=mock_redis):
with patch(
"backend.data.event_bus.redis.get_redis_async", return_value=mock_cluster
):
await bus.publish_event(event, "test_channel")
mock_redis.publish.assert_called_once()
mock_cluster.execute_command.assert_awaited_once()
assert mock_cluster.execute_command.await_args[0][0] == "SPUBLISH"
@pytest.mark.asyncio
async def test_publish_event_rejects_wildcard_channel():
"""A channel_key containing ``*`` must raise — no silent no-op."""
bus = _BusUnderTest()
with patch("backend.data.event_bus.redis.get_redis_async") as get_cluster:
with pytest.raises(ValueError):
await bus.publish_event(SampleEvent(message="m"), "user/*/exec")
# The cluster client must never be reached for a wildcard channel.
get_cluster.assert_not_called()
def test_assert_no_wildcard_guard():
"""The standalone guard must reject any ``*``-containing channel."""
with pytest.raises(ValueError):
_assert_no_wildcard("user/*/exec")
# Concrete channels must pass.
_assert_no_wildcard("execution_event/user-1/graph-1/exec-1")
# Live SSUBSCRIBE round-trip; skipped when no cluster is reachable.
def _has_live_cluster() -> bool:
from backend.data import redis_client
try:
c = redis_client.connect()
except Exception: # noqa: BLE001 - any connect failure → skip the test
return False
try:
c.close()
except Exception:
pass
return True
@pytest.mark.asyncio
@pytest.mark.skipif(
not _has_live_cluster(),
reason="local redis cluster not reachable; skip SSUBSCRIBE integration",
)
async def test_ssubscribe_end_to_end_async():
"""SPUBLISH on one AsyncRedisEventBus reaches SSUBSCRIBE on another."""
import asyncio
from backend.data import redis_client
redis_client.get_redis.cache_clear()
redis_client._async_clients.clear()
publisher = _BusUnderTest()
subscriber = _BusUnderTest()
channel_key = "pr12900:event_bus:integration"
received: list[SampleEvent] = []
async def consume() -> None:
async for ev in subscriber.listen_events(channel_key):
received.append(ev)
return
task = asyncio.create_task(consume())
# Let SSUBSCRIBE settle; races drop the publish otherwise.
await asyncio.sleep(0.3)
try:
await publisher.publish_event(SampleEvent(message="hello-ssub"), channel_key)
await asyncio.wait_for(task, timeout=5.0)
finally:
if not task.done():
task.cancel()
await subscriber.close()
await redis_client.disconnect_async()
assert received and received[0].message == "hello-ssub"
@pytest.mark.asyncio
@pytest.mark.skipif(
not _has_live_cluster(),
reason="local redis cluster not reachable; skip execution-bus integration",
)
async def test_execution_bus_listen_and_listen_graph_both_deliver():
"""Per-exec and per-graph channels both receive every execution event."""
import asyncio
from datetime import datetime, timezone
from backend.data import redis_client
from backend.data.execution import (
AsyncRedisExecutionEventBus,
ExecutionStatus,
GraphExecutionEvent,
)
redis_client.get_redis.cache_clear()
redis_client._async_clients.clear()
user_id = "user-it"
graph_id = "graph-it"
exec_id = "exec-it"
now = datetime.now(tz=timezone.utc)
event = GraphExecutionEvent(
id=exec_id,
user_id=user_id,
graph_id=graph_id,
graph_version=1,
preset_id=None,
status=ExecutionStatus.COMPLETED,
started_at=now,
ended_at=now,
stats=GraphExecutionEvent.Stats(
cost=0, duration=0.1, node_exec_time=0.1, node_exec_count=1
),
inputs={},
credential_inputs=None,
nodes_input_masks=None,
outputs={},
)
single = AsyncRedisExecutionEventBus()
all_execs = AsyncRedisExecutionEventBus()
publisher = AsyncRedisExecutionEventBus()
received_single: list = []
received_all: list = []
async def _listen_single() -> None:
async for ev in single.listen(user_id, graph_id, exec_id):
received_single.append(ev)
return
async def _listen_all() -> None:
async for ev in all_execs.listen_graph(user_id, graph_id):
received_all.append(ev)
return
t1 = asyncio.create_task(_listen_single())
t2 = asyncio.create_task(_listen_all())
await asyncio.sleep(0.3)
try:
await publisher.publish(event)
await asyncio.wait_for(asyncio.gather(t1, t2), timeout=5.0)
finally:
for t in (t1, t2):
if not t.done():
t.cancel()
await single.close()
await all_execs.close()
await publisher.close()
await redis_client.disconnect_async()
assert received_single and received_single[0].id == exec_id
assert received_all and received_all[0].id == exec_id
@pytest.mark.asyncio
async def test_listen_events_rejects_wildcard_channel():
"""listen_events on a wildcard channel must raise before touching Redis."""
bus = _BusUnderTest()
with pytest.raises(ValueError):
async for _ in bus.listen_events("user/*/exec"):
break
# ---------- Serialization + size guard ----------
def test_serialize_message_tags_full_channel_name():
"""_serialize_message returns the ``<bus>/<key>`` full channel name."""
bus = _BusUnderTest()
_, full = bus._serialize_message(SampleEvent(message="x"), "chan")
assert full == "test_event_bus/chan"
def test_serialize_message_truncates_oversized_payload(monkeypatch):
"""If the payload exceeds max_message_size_limit, it's replaced with an
``error_comms_update`` payload rather than crashing the cluster."""
import backend.data.event_bus as event_bus
bus = _BusUnderTest()
# Cap tiny to force truncation.
monkeypatch.setattr(event_bus.config, "max_message_size_limit", 50)
message, _ = bus._serialize_message(SampleEvent(message="x" * 1000), "chan")
assert "error_comms_update" in message
assert "Payload too large" in message
def test_deserialize_message_rejects_non_pubsub_types():
"""Non ``smessage|message|pmessage`` deliveries deserialize to None."""
bus = _BusUnderTest()
assert bus._deserialize_message({"type": "ssubscribe", "data": 1}, "c") is None
assert bus._deserialize_message({"type": "subscribe", "data": 1}, "c") is None
def test_deserialize_message_swallows_bad_json():
"""Corrupted payloads must not raise — they return None (logged elsewhere)."""
bus = _BusUnderTest()
assert (
bus._deserialize_message({"type": "smessage", "data": "not-json"}, "c") is None
)
def test_deserialize_message_parses_smessage():
"""Happy-path ``smessage`` yields the inner event model."""
bus = _BusUnderTest()
wrapped = '{"payload":{"message":"hi"}}'
parsed = bus._deserialize_message({"type": "smessage", "data": wrapped}, "chan")
assert parsed is not None and parsed.message == "hi"
# ---------- Sync RedisEventBus ----------
class _SyncBusUnderTest(RedisEventBus[SampleEvent]):
Model = SampleEvent
@property
def event_bus_name(self) -> str:
return "test_event_bus"
def test_sync_publish_event_spublish_only():
"""Sync publish_event must issue a single SPUBLISH (no classic fallback)."""
bus = _SyncBusUnderTest()
cluster = MagicMock()
cluster.execute_command = MagicMock()
with patch("backend.data.event_bus.redis.get_redis", return_value=cluster):
bus.publish_event(SampleEvent(message="m"), "chan")
cluster.execute_command.assert_called_once()
assert cluster.execute_command.call_args.args[0] == "SPUBLISH"
def test_sync_publish_event_rejects_wildcard():
bus = _SyncBusUnderTest()
with patch("backend.data.event_bus.redis.get_redis") as mock_get:
with pytest.raises(ValueError):
bus.publish_event(SampleEvent(message="m"), "user/*/exec")
mock_get.assert_not_called()
def test_sync_publish_event_swallows_connection_errors():
"""publish_event must never raise to callers — logs + drops on failure."""
bus = _SyncBusUnderTest()
with patch(
"backend.data.event_bus.redis.get_redis",
side_effect=ConnectionError("no redis"),
):
# Should NOT raise.
bus.publish_event(SampleEvent(message="m"), "chan")
def test_sync_listen_events_rejects_wildcard():
bus = _SyncBusUnderTest()
with pytest.raises(ValueError):
next(iter(bus.listen_events("user/*/exec")))
def test_sync_listen_events_ssubscribes_and_yields_decoded_events():
"""Sync listen_events: SSUBSCRIBE on the full channel, decode smessage payloads."""
bus = _SyncBusUnderTest()
fake_pubsub = MagicMock()
fake_pubsub.ssubscribe = MagicMock()
fake_pubsub.sunsubscribe = MagicMock()
fake_pubsub.close = MagicMock()
fake_pubsub.listen = MagicMock(
return_value=iter(
[
{"type": "ssubscribe", "data": 1},
{"type": "smessage", "data": '{"payload":{"message":"one"}}'},
]
)
)
cluster = MagicMock()
cluster.pubsub = MagicMock(return_value=fake_pubsub)
with patch("backend.data.event_bus.redis.get_redis", return_value=cluster):
gen = bus.listen_events("chan")
first = next(iter(gen))
assert first.message == "one"
fake_pubsub.ssubscribe.assert_called_once_with("test_event_bus/chan")
def test_sync_listen_events_teardown_swallows_sunsubscribe_errors():
"""Teardown must not propagate SUNSUBSCRIBE/close failures."""
bus = _SyncBusUnderTest()
fake_pubsub = MagicMock()
fake_pubsub.ssubscribe = MagicMock()
fake_pubsub.sunsubscribe = MagicMock(side_effect=RuntimeError("SUNSUB broke"))
fake_pubsub.close = MagicMock(side_effect=RuntimeError("close broke"))
fake_pubsub.listen = MagicMock(return_value=iter([]))
cluster = MagicMock()
cluster.pubsub = MagicMock(return_value=fake_pubsub)
with patch("backend.data.event_bus.redis.get_redis", return_value=cluster):
# Exhausting the generator runs the ``finally`` teardown.
list(bus.listen_events("chan"))
fake_pubsub.sunsubscribe.assert_called_once()
fake_pubsub.close.assert_called_once()
# ---------- Async close() teardown ----------
@pytest.mark.asyncio
async def test_async_close_is_noop():
"""close() is a backward-compat no-op now that listen_events owns its own state."""
bus = _BusUnderTest()
# Repeated calls must not crash; pubsub/client are generator-locals.
await bus.close()
await bus.close()
@pytest.mark.asyncio
async def test_async_listen_events_swallows_aclose_errors():
"""Broken pubsub.aclose / client.aclose must not propagate to the caller."""
bus = _BusUnderTest()
fake_pubsub = MagicMock()
fake_pubsub.execute_command = AsyncMock()
fake_pubsub.channels = {}
fake_pubsub.aclose = AsyncMock(side_effect=RuntimeError("pubsub broke"))
async def _listen():
return
yield # pragma: no cover — unreachable
fake_pubsub.listen = _listen
fake_client = MagicMock()
fake_client.pubsub = MagicMock(return_value=fake_pubsub)
fake_client.aclose = AsyncMock(side_effect=RuntimeError("client broke"))
with patch(
"backend.data.event_bus.redis.connect_sharded_pubsub_async",
AsyncMock(return_value=fake_client),
):
async for _ in bus.listen_events("chan"):
pass # pragma: no cover — never yields
# Both aclose attempts must have run despite raising.
fake_pubsub.aclose.assert_awaited_once()
fake_client.aclose.assert_awaited_once()
@pytest.mark.asyncio
async def test_async_listen_events_concurrent_does_not_share_state():
"""Two concurrent listens on the same bus must keep their pubsub/client local."""
bus = _BusUnderTest()
pubsubs: list[MagicMock] = []
clients: list[MagicMock] = []
started = asyncio.Event()
proceed = asyncio.Event()
def _make_pair() -> tuple[MagicMock, MagicMock]:
pubsub = MagicMock()
pubsub.execute_command = AsyncMock()
pubsub.channels = {}
pubsub.aclose = AsyncMock()
async def _listen():
started.set()
await proceed.wait()
return
yield # pragma: no cover — unreachable
pubsub.listen = _listen
client = MagicMock()
client.pubsub = MagicMock(return_value=pubsub)
client.aclose = AsyncMock()
pubsubs.append(pubsub)
clients.append(client)
return pubsub, client
async def _factory(_chan: str):
_, client = _make_pair()
return client
with patch(
"backend.data.event_bus.redis.connect_sharded_pubsub_async",
AsyncMock(side_effect=_factory),
):
async def _run():
async for _ in bus.listen_events("chan"):
pass # pragma: no cover — never yields
task_a = asyncio.create_task(_run())
task_b = asyncio.create_task(_run())
# Wait for both pumps to be parked inside listen() before unblocking.
await started.wait()
# Yield once more so the second task also enters listen().
await asyncio.sleep(0)
proceed.set()
await asyncio.gather(task_a, task_b)
# Each listen must have closed its OWN pubsub/client exactly once. If
# either was closed twice or zero times, the singleton race is back.
assert len(pubsubs) == 2
for pubsub in pubsubs:
pubsub.aclose.assert_awaited_once()
for client in clients:
client.aclose.assert_awaited_once()
@pytest.mark.asyncio
async def test_async_wait_for_event_returns_none_on_timeout():
"""wait_for_event must coerce asyncio.TimeoutError → None."""
import asyncio as _asyncio
bus = _BusUnderTest()
async def _never(self, channel_key):
await _asyncio.sleep(10)
yield # pragma: no cover — unreachable
with patch.object(_BusUnderTest, "listen_events", _never):
result = await bus.wait_for_event("chan", timeout=0.01)
assert result is None
# The listen_events async happy path is covered by the live-cluster integration
# test above; this one exercises the close-on-exception fallback.
@pytest.mark.asyncio
async def test_async_listen_events_closes_on_exception():
"""If the pump raises, close() must still run to release the shard-pinned client."""
bus = _BusUnderTest()
fake_pubsub = MagicMock()
fake_pubsub.execute_command = AsyncMock()
fake_pubsub.channels = {}
fake_pubsub.aclose = AsyncMock()
class _Boom(Exception):
pass
async def _listen():
raise _Boom()
yield # pragma: no cover — unreachable
fake_pubsub.listen = _listen
fake_client = MagicMock()
fake_client.pubsub = MagicMock(return_value=fake_pubsub)
fake_client.aclose = AsyncMock()
with patch(
"backend.data.event_bus.redis.connect_sharded_pubsub_async",
AsyncMock(return_value=fake_client),
):
with pytest.raises(_Boom):
async for _ in bus.listen_events("chan"):
pass
# close() must have fired (both aclose calls).
fake_pubsub.aclose.assert_awaited_once()
fake_client.aclose.assert_awaited_once()

View File

@@ -570,7 +570,7 @@ async def get_graph_executions(
# Build properly typed order clause
# Prisma wants specific typed dicts for each field, so we construct them explicitly
order_clause: AgentGraphExecutionOrderByInput
match (order_by):
match order_by:
case "startedAt":
order_clause = {
"startedAt": order_direction,
@@ -1337,6 +1337,22 @@ ExecutionEvent = Annotated[
]
# Hash-tagged channels keep per-exec and per-graph keys on the same shard,
# so one SSUBSCRIBE connection can watch both.
def _graph_scope_tag(user_id: str, graph_id: str) -> str:
return "{" + f"{user_id}/{graph_id}" + "}"
def exec_channel(user_id: str, graph_id: str, graph_exec_id: str) -> str:
return f"{_graph_scope_tag(user_id, graph_id)}/exec/{graph_exec_id}"
def graph_all_channel(user_id: str, graph_id: str) -> str:
return f"{_graph_scope_tag(user_id, graph_id)}/all"
class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):
Model = ExecutionEvent # type: ignore
@@ -1352,16 +1368,20 @@ class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):
def _publish_node_exec_update(self, res: NodeExecutionResult):
event = NodeExecutionEvent.model_validate(res.model_dump())
self._publish(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
self._publish(event, res.user_id, res.graph_id, res.graph_exec_id)
def _publish_graph_exec_update(self, res: GraphExecution):
event = GraphExecutionEvent.model_validate(res.model_dump())
self._publish(event, f"{res.user_id}/{res.graph_id}/{res.id}")
self._publish(event, res.user_id, res.graph_id, res.id)
def _publish(self, event: ExecutionEvent, channel: str):
"""
truncate inputs and outputs to avoid large payloads
"""
def _publish(
self,
event: ExecutionEvent,
user_id: str,
graph_id: str,
graph_exec_id: str,
):
"""Truncate oversized payloads, then publish to per-exec + per-graph channels."""
limit = config.max_message_size_limit // 2
if isinstance(event, GraphExecutionEvent):
event.inputs = truncate(event.inputs, limit)
@@ -1370,12 +1390,22 @@ class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):
event.input_data = truncate(event.input_data, limit)
event.output_data = truncate(event.output_data, limit)
super().publish_event(event, channel)
# Publisher fans out: per-exec and per-graph watchers.
super().publish_event(event, exec_channel(user_id, graph_id, graph_exec_id))
super().publish_event(event, graph_all_channel(user_id, graph_id))
def listen(
self, user_id: str, graph_id: str = "*", graph_exec_id: str = "*"
self, user_id: str, graph_id: str, graph_exec_id: str
) -> Generator[ExecutionEvent, None, None]:
for event in self.listen_events(f"{user_id}/{graph_id}/{graph_exec_id}"):
"""Stream events for a specific graph execution."""
for event in self.listen_events(exec_channel(user_id, graph_id, graph_exec_id)):
yield event
def listen_graph(
self, user_id: str, graph_id: str
) -> Generator[ExecutionEvent, None, None]:
"""Stream every event for every execution of ``graph_id``."""
for event in self.listen_events(graph_all_channel(user_id, graph_id)):
yield event
@@ -1395,7 +1425,7 @@ class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
async def _publish_node_exec_update(self, res: NodeExecutionResult):
event = NodeExecutionEvent.model_validate(res.model_dump())
await self._publish(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
await self._publish(event, res.user_id, res.graph_id, res.graph_exec_id)
async def _publish_graph_exec_update(self, res: GraphExecutionMeta):
# GraphExecutionEvent requires inputs and outputs fields that GraphExecutionMeta doesn't have
@@ -1404,12 +1434,16 @@ class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
event_data.setdefault("inputs", {})
event_data.setdefault("outputs", {})
event = GraphExecutionEvent.model_validate(event_data)
await self._publish(event, f"{res.user_id}/{res.graph_id}/{res.id}")
await self._publish(event, res.user_id, res.graph_id, res.id)
async def _publish(self, event: ExecutionEvent, channel: str):
"""
truncate inputs and outputs to avoid large payloads
"""
async def _publish(
self,
event: ExecutionEvent,
user_id: str,
graph_id: str,
graph_exec_id: str,
):
"""Truncate oversized payloads, then publish to per-exec + per-graph channels."""
limit = config.max_message_size_limit // 2
if isinstance(event, GraphExecutionEvent):
event.inputs = truncate(event.inputs, limit)
@@ -1418,12 +1452,25 @@ class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
event.input_data = truncate(event.input_data, limit)
event.output_data = truncate(event.output_data, limit)
await super().publish_event(event, channel)
await super().publish_event(
event, exec_channel(user_id, graph_id, graph_exec_id)
)
await super().publish_event(event, graph_all_channel(user_id, graph_id))
async def listen(
self, user_id: str, graph_id: str = "*", graph_exec_id: str = "*"
self, user_id: str, graph_id: str, graph_exec_id: str
) -> AsyncGenerator[ExecutionEvent, None]:
async for event in self.listen_events(f"{user_id}/{graph_id}/{graph_exec_id}"):
"""Stream events for a specific graph execution."""
async for event in self.listen_events(
exec_channel(user_id, graph_id, graph_exec_id)
):
yield event
async def listen_graph(
self, user_id: str, graph_id: str
) -> AsyncGenerator[ExecutionEvent, None]:
"""Stream every event for every execution of ``graph_id``."""
async for event in self.listen_events(graph_all_channel(user_id, graph_id)):
yield event
@@ -1682,11 +1729,11 @@ async def create_shared_execution_files(
created += 1
except UniqueViolationError:
logger.debug(
f"Skipping shared file record for {file_id}: " f"record already exists"
f"Skipping shared file record for {file_id}: record already exists"
)
except ForeignKeyViolationError:
logger.debug(
f"Skipping shared file record for {file_id}: " f"file does not exist"
f"Skipping shared file record for {file_id}: file does not exist"
)
return created

View File

@@ -0,0 +1,387 @@
"""Tests for the sharded channel builders + publish/listen paths on
``AsyncRedisExecutionEventBus`` / ``RedisExecutionEventBus``.
These tests are intentionally Prisma-free: they exercise only the in-process
event-routing layer, using mocks for the Redis cluster client. The live
SSUBSCRIBE round-trip is covered by the integration test in
``event_bus_test.py``.
"""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.data.execution import (
AsyncRedisExecutionEventBus,
ExecutionEventType,
ExecutionStatus,
GraphExecutionEvent,
NodeExecutionEvent,
RedisExecutionEventBus,
_graph_scope_tag,
exec_channel,
graph_all_channel,
)
# ---------- Hash-tagged channel builders ----------
def test_graph_scope_tag_uses_hash_tag_syntax():
"""Hash-tagged tag must look like ``{user/graph}`` so per-exec + per-graph
channels hash to the same Redis Cluster keyslot."""
assert _graph_scope_tag("u", "g") == "{u/g}"
def test_exec_channel_nests_scope_tag():
"""Per-exec channel: ``{user/graph}/exec/<exec_id>``."""
assert exec_channel("u", "g", "e") == "{u/g}/exec/e"
def test_graph_all_channel_nests_scope_tag():
"""Aggregate channel: ``{user/graph}/all`` — keyslot-aligned with per-exec."""
assert graph_all_channel("u", "g") == "{u/g}/all"
def test_exec_and_graph_channels_share_hash_tag():
"""Invariant: both channels *must* share the ``{user/graph}`` prefix.
If this breaks, SSUBSCRIBE for per-exec and aggregate routes to different
shards and the per-graph listener loses some events."""
exec_ch = exec_channel("u", "g", "e")
graph_ch = graph_all_channel("u", "g")
assert exec_ch.startswith("{u/g}")
assert graph_ch.startswith("{u/g}")
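# Illustrative extra check (a sketch, not part of this change): the same
# invariant can be asserted numerically with redis-py's keyslot helper, since
# only the "{user/graph}" hash-tag portion of each name is hashed. Assumes
# ``redis.crc.key_slot`` is available in the installed redis-py.
def test_exec_and_graph_channels_share_keyslot_sketch():
    from redis.crc import key_slot

    exec_slot = key_slot(exec_channel("u", "g", "e").encode())
    all_slot = key_slot(graph_all_channel("u", "g").encode())
    assert exec_slot == all_slot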
# ---------- NodeExecutionEvent publish → exec channel only ----------
def _sample_node_event() -> NodeExecutionEvent:
now = datetime.now(tz=timezone.utc)
return NodeExecutionEvent(
user_id="u",
graph_id="g",
graph_version=1,
graph_exec_id="e",
node_exec_id="ne",
node_id="nid",
block_id="bid",
status=ExecutionStatus.COMPLETED,
input_data={"a": 1},
output_data={"o": [1]},
add_time=now,
queue_time=None,
start_time=now,
end_time=now,
)
def _sample_graph_event() -> GraphExecutionEvent:
now = datetime.now(tz=timezone.utc)
return GraphExecutionEvent(
id="e",
user_id="u",
graph_id="g",
graph_version=1,
preset_id=None,
status=ExecutionStatus.COMPLETED,
started_at=now,
ended_at=now,
stats=GraphExecutionEvent.Stats(
cost=0, duration=0.1, node_exec_time=0.1, node_exec_count=1
),
inputs={},
credential_inputs=None,
nodes_input_masks=None,
outputs={},
)
@pytest.mark.asyncio
async def test_async_publish_node_sends_to_both_channels():
"""Node events fan out to BOTH per-exec and aggregate channels so the
per-graph WS subscriber sees every node update, not just graph-level ones.
"""
bus = AsyncRedisExecutionEventBus()
sent_channels: list[str] = []
async def _capture(self, event, channel_key):
sent_channels.append(channel_key)
with patch.object(
AsyncRedisExecutionEventBus.__mro__[1], "publish_event", _capture
):
await bus._publish_node_exec_update(_sample_node_event())
assert sent_channels == [
exec_channel("u", "g", "e"),
graph_all_channel("u", "g"),
]
@pytest.mark.asyncio
async def test_async_publish_graph_sends_to_both_channels():
bus = AsyncRedisExecutionEventBus()
sent_channels: list[str] = []
async def _capture(self, event, channel_key):
sent_channels.append(channel_key)
with patch.object(
AsyncRedisExecutionEventBus.__mro__[1], "publish_event", _capture
):
await bus._publish_graph_exec_update(_sample_graph_event())
assert sent_channels == [
exec_channel("u", "g", "e"),
graph_all_channel("u", "g"),
]
@pytest.mark.asyncio
async def test_async_publish_routes_via_type_dispatch():
"""publish() dispatches on the model type — not on status or event_type."""
bus = AsyncRedisExecutionEventBus()
with (
patch.object(bus, "_publish_graph_exec_update", AsyncMock()) as graph_pub,
patch.object(bus, "_publish_node_exec_update", AsyncMock()) as node_pub,
):
await bus.publish(_sample_graph_event())
await bus.publish(_sample_node_event())
graph_pub.assert_awaited_once()
node_pub.assert_awaited_once()
@pytest.mark.asyncio
async def test_async_publish_truncates_oversized_payload(monkeypatch):
"""Payload truncation applies before sending — size exceeded → replacement."""
import backend.data.execution as execution
bus = AsyncRedisExecutionEventBus()
# Force tiny limit so ``truncate`` rewrites the payload.
monkeypatch.setattr(execution.config, "max_message_size_limit", 10)
cluster = MagicMock()
cluster.execute_command = AsyncMock()
with patch("backend.data.event_bus.redis.get_redis_async", return_value=cluster):
await bus.publish(_sample_node_event())
# Called twice: per-exec and per-graph channel.
assert cluster.execute_command.await_count == 2
@pytest.mark.asyncio
async def test_async_listen_uses_exec_channel():
"""listen() must subscribe to the per-exec hash-tagged channel."""
bus = AsyncRedisExecutionEventBus()
captured: list[str] = []
async def _listen_events(self, channel_key):
captured.append(channel_key)
# Return an empty async-generator so the ``async for`` exits cleanly.
if False:
yield # pragma: no cover
with patch.object(AsyncRedisExecutionEventBus, "listen_events", _listen_events):
async for _ in bus.listen("u", "g", "e"):
break # pragma: no cover — generator is empty
assert captured == [exec_channel("u", "g", "e")]
@pytest.mark.asyncio
async def test_async_listen_graph_uses_all_channel():
"""listen_graph() must subscribe to the aggregate hash-tagged channel."""
bus = AsyncRedisExecutionEventBus()
captured: list[str] = []
async def _listen_events(self, channel_key):
captured.append(channel_key)
if False:
yield # pragma: no cover
with patch.object(AsyncRedisExecutionEventBus, "listen_events", _listen_events):
async for _ in bus.listen_graph("u", "g"):
break # pragma: no cover — generator is empty
assert captured == [graph_all_channel("u", "g")]
# ---------- Sync RedisExecutionEventBus (smaller surface; covers branching) ----------
def test_sync_listen_uses_exec_channel():
bus = RedisExecutionEventBus()
captured: list[str] = []
def _listen_events(self, channel_key):
captured.append(channel_key)
return iter([])
with patch.object(RedisExecutionEventBus, "listen_events", _listen_events):
list(bus.listen("u", "g", "e"))
assert captured == [exec_channel("u", "g", "e")]
def test_sync_listen_graph_uses_all_channel():
bus = RedisExecutionEventBus()
captured: list[str] = []
def _listen_events(self, channel_key):
captured.append(channel_key)
return iter([])
with patch.object(RedisExecutionEventBus, "listen_events", _listen_events):
list(bus.listen_graph("u", "g"))
assert captured == [graph_all_channel("u", "g")]
def test_sync_publish_node_sends_to_both_channels():
"""Sync publish path also fans out to per-exec + per-graph."""
bus = RedisExecutionEventBus()
sent: list[str] = []
def _capture(self, event, channel_key):
sent.append(channel_key)
with patch.object(RedisExecutionEventBus.__mro__[1], "publish_event", _capture):
bus._publish_node_exec_update(_sample_node_event().model_copy())
assert sent == [
exec_channel("u", "g", "e"),
graph_all_channel("u", "g"),
]
def test_event_type_is_literal_on_events():
"""event_type is a discriminator literal, not dynamic — the WS fan-out
relies on ``ExecutionEventType(event_type)`` being stable."""
node = _sample_node_event()
graph = _sample_graph_event()
assert node.event_type == ExecutionEventType.NODE_EXEC_UPDATE
assert graph.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE
# ---------- Sync publish dispatch + listen yields ----------
def test_sync_publish_dispatches_on_model_type():
"""Sync ``publish()`` routes GraphExecution and NodeExecutionResult to
their respective helpers — regression guard on the type-dispatch branch."""
from backend.data.execution import GraphExecution, NodeExecutionResult
bus = RedisExecutionEventBus()
graph_like = MagicMock(spec=GraphExecution)
node_like = MagicMock(spec=NodeExecutionResult)
with (
patch.object(bus, "_publish_graph_exec_update") as graph_pub,
patch.object(bus, "_publish_node_exec_update") as node_pub,
):
bus.publish(graph_like)
bus.publish(node_like)
graph_pub.assert_called_once_with(graph_like)
node_pub.assert_called_once_with(node_like)
def test_sync_publish_graph_exec_update_rebuilds_event():
"""Sync ``_publish_graph_exec_update`` validates the input into a
GraphExecutionEvent before delegating to ``_publish`` — don't let a raw
GraphExecution slip through the type-discriminated listener."""
bus = RedisExecutionEventBus()
graph_event = _sample_graph_event()
with patch.object(bus, "_publish") as mock_publish:
# Feed back the event itself (it's a GraphExecution subclass) to avoid
# needing a full Graph fixture.
bus._publish_graph_exec_update(graph_event)
mock_publish.assert_called_once()
args = mock_publish.call_args.args
# The first arg is a GraphExecutionEvent (validated copy).
assert args[0].event_type == ExecutionEventType.GRAPH_EXEC_UPDATE
# Channel-routing args match the input.
assert args[1:] == ("u", "g", "e")
def test_sync_publish_node_exec_update_rebuilds_event():
"""Sync ``_publish_node_exec_update`` validates to NodeExecutionEvent."""
bus = RedisExecutionEventBus()
node_event = _sample_node_event()
with patch.object(bus, "_publish") as mock_publish:
bus._publish_node_exec_update(node_event)
mock_publish.assert_called_once()
args = mock_publish.call_args.args
assert args[0].event_type == ExecutionEventType.NODE_EXEC_UPDATE
assert args[1:] == ("u", "g", "e")
def test_sync_publish_graph_truncates_inputs_and_outputs(monkeypatch):
"""Sync ``_publish`` must truncate GraphExecutionEvent.inputs/outputs when
the payload exceeds the cap — protects Redis from oversized frames."""
import backend.data.execution as execution
bus = RedisExecutionEventBus()
monkeypatch.setattr(execution.config, "max_message_size_limit", 4)
event = _sample_graph_event()
event.inputs = {"long": "x" * 10_000}
event.outputs = {"long": ["y" * 10_000]}
with patch("backend.data.event_bus.redis.get_redis", return_value=MagicMock()):
bus._publish(event, "u", "g", "e")
# After _publish runs, inputs/outputs have been truncated in-place.
import json as _json
assert len(_json.dumps(event.inputs)) < 1000
assert len(_json.dumps(event.outputs)) < 1000
def test_sync_listen_yields_events_from_generator():
"""Sync ``listen()`` must yield through every event produced by the
underlying ``listen_events`` generator."""
bus = RedisExecutionEventBus()
node_ev = _sample_node_event()
def _listen_events(self, channel_key):
yield node_ev
with patch.object(RedisExecutionEventBus, "listen_events", _listen_events):
got = list(bus.listen("u", "g", "e"))
assert got == [node_ev]
def test_sync_listen_graph_yields_events_from_generator():
bus = RedisExecutionEventBus()
graph_ev = _sample_graph_event()
def _listen_events(self, channel_key):
yield graph_ev
with patch.object(RedisExecutionEventBus, "listen_events", _listen_events):
got = list(bus.listen_graph("u", "g"))
assert got == [graph_ev]
def test_execution_bus_name_matches_settings():
"""Both sync and async buses must read the same configured bus name — the
WS subscriber depends on this for channel naming."""
assert (
RedisExecutionEventBus().event_bus_name
== AsyncRedisExecutionEventBus().event_bus_name
)

View File

@@ -1,5 +1,5 @@
import logging
from typing import AsyncGenerator, Literal, Optional, overload
from typing import Literal, Optional, overload
from prisma.models import AgentNode, AgentPreset, IntegrationWebhook
from prisma.types import (
@@ -354,18 +354,10 @@ async def publish_webhook_event(event: WebhookEvent):
)
async def listen_for_webhook_events(
webhook_id: str, event_type: Optional[str] = None
) -> AsyncGenerator[WebhookEvent, None]:
async for event in _webhook_event_bus.listen_events(
f"{webhook_id}/{event_type or '*'}"
):
yield event
async def wait_for_webhook_event(
webhook_id: str, event_type: Optional[str] = None, timeout: Optional[float] = None
webhook_id: str, event_type: str, timeout: Optional[float] = None
) -> WebhookEvent | None:
# Concrete event_type required: sharded pub/sub has no pattern support.
return await _webhook_event_bus.wait_for_event(
f"{webhook_id}/{event_type or '*'}", timeout
f"{webhook_id}/{event_type}", timeout
)
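# Hypothetical caller sketch (illustrative only; "pull_request" is a made-up
# event type): with sharded pub/sub there is no pattern subscribe, so the
# concrete event_type has to be known at call time.
async def _example_wait_for_webhook(webhook_id: str) -> WebhookEvent | None:
    return await wait_for_webhook_event(webhook_id, "pull_request", timeout=30)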

View File

@@ -1,15 +1,23 @@
from __future__ import annotations
import asyncio
import logging
from typing import AsyncGenerator
from pydantic import BaseModel, field_serializer
from backend.api.model import NotificationPayload
from backend.data.event_bus import AsyncRedisEventBus
from backend.data.push_sender import send_push_for_user
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
_settings = Settings()
# Strong refs for in-flight push fanout tasks. asyncio only keeps weak refs
# to tasks, so a fire-and-forget create_task can be GC'd mid-run.
_push_tasks: set[asyncio.Task] = set()
class NotificationEvent(BaseModel):
"""Generic notification event destined for websocket delivery."""
@@ -23,6 +31,14 @@ class NotificationEvent(BaseModel):
return payload.model_dump()
async def _safe_send_push(user_id: str, payload: NotificationPayload) -> None:
"""Deliver web push for a notification, swallowing errors."""
try:
await send_push_for_user(user_id, payload)
except Exception:
logger.exception("Failed to send web push for user %s", user_id)
class AsyncRedisNotificationEventBus(AsyncRedisEventBus[NotificationEvent]):
Model = NotificationEvent # type: ignore
@@ -32,9 +48,19 @@ class AsyncRedisNotificationEventBus(AsyncRedisEventBus[NotificationEvent]):
async def publish(self, event: NotificationEvent) -> None:
await self.publish_event(event, event.user_id)
# Skip OS push for onboarding step toasts — those are in-page only.
# TODO: remove once the onboarding/wallet rework lands and decides
# per-event whether a system notification is desired.
if event.payload.model_dump().get("type") == "onboarding":
return
# Fan out to web push subscriptions in parallel. Fire-and-forget so
# publishers never wait on the push service; held in _push_tasks so
# the task survives until completion.
task = asyncio.create_task(_safe_send_push(event.user_id, event.payload))
_push_tasks.add(task)
task.add_done_callback(_push_tasks.discard)
async def listen(
self, user_id: str = "*"
) -> AsyncGenerator[NotificationEvent, None]:
async def listen(self, user_id: str) -> AsyncGenerator[NotificationEvent, None]:
"""Stream notifications for a specific user."""
async for event in self.listen_events(user_id):
yield event
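# Hypothetical usage sketch (illustrative only; the event name is made up):
# a single publish() both delivers on the per-user SSUBSCRIBE channel and
# kicks off the fire-and-forget web-push fanout above.
async def _example_notify(user_id: str) -> None:
    bus = AsyncRedisNotificationEventBus()
    await bus.publish(
        NotificationEvent(
            user_id=user_id,
            payload=NotificationPayload(type="info", event="agent_run_completed"),
        )
    )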

View File

@@ -0,0 +1,145 @@
"""Tests for AsyncRedisNotificationEventBus.
Covers the tiny delegation surface: publish → publish_event(user_id), listen
→ listen_events(user_id), and the payload serializer that ensures extra
fields survive the Redis round-trip.
"""
from unittest.mock import AsyncMock, patch
import pytest
from backend.api.model import NotificationPayload
from backend.data.notification_bus import (
AsyncRedisNotificationEventBus,
NotificationEvent,
)
def test_notification_event_serializes_payload_including_extras():
"""``NotificationPayload`` allows extra fields; the bus serializer must
preserve them. Dropping extras breaks feature payloads like CopilotCompletion."""
payload = NotificationPayload(type="info", event="hey", extra_field="survive me")
event = NotificationEvent(user_id="u", payload=payload)
dumped = event.model_dump()
assert dumped["payload"]["type"] == "info"
assert dumped["payload"]["event"] == "hey"
assert dumped["payload"]["extra_field"] == "survive me"
@pytest.mark.asyncio
async def test_publish_calls_publish_event_with_user_id_channel():
"""publish(event) → publish_event(event, channel_key=event.user_id)."""
bus = AsyncRedisNotificationEventBus()
event = NotificationEvent(
user_id="user-42", payload=NotificationPayload(type="info", event="hi")
)
with patch.object(
AsyncRedisNotificationEventBus, "publish_event", AsyncMock()
) as mock_pub:
await bus.publish(event)
mock_pub.assert_awaited_once()
args = mock_pub.await_args.args
# The event may arrive as a positional arg; either way the channel key (user_id) is last.
assert args[-1] == "user-42"
@pytest.mark.asyncio
async def test_listen_delegates_to_listen_events_for_user():
"""listen(user_id) must subscribe on the per-user channel."""
bus = AsyncRedisNotificationEventBus()
captured: list[str] = []
async def _listen_events(self, channel_key):
captured.append(channel_key)
if False:
yield # pragma: no cover
with patch.object(AsyncRedisNotificationEventBus, "listen_events", _listen_events):
async for _ in bus.listen("user-42"):
break # pragma: no cover — generator empty
assert captured == ["user-42"]
def test_event_bus_name_is_configured() -> None:
"""The notification bus uses a distinct namespace from the execution bus,
so WS exec channels and notification channels never collide."""
bus = AsyncRedisNotificationEventBus()
assert bus.event_bus_name # non-empty, configured via Settings
@pytest.mark.asyncio
async def test_publish_fans_out_to_web_push():
"""publish() must also kick off web-push fanout for the user."""
bus = AsyncRedisNotificationEventBus()
event = NotificationEvent(
user_id="user-42", payload=NotificationPayload(type="info", event="hi")
)
with (
patch.object(AsyncRedisNotificationEventBus, "publish_event", AsyncMock()),
patch(
"backend.data.notification_bus.send_push_for_user",
new_callable=AsyncMock,
) as mock_push,
):
await bus.publish(event)
# create_task is fire-and-forget — let the event loop drain the task.
import asyncio
for _ in range(3):
await asyncio.sleep(0)
mock_push.assert_awaited_once_with("user-42", event.payload)
@pytest.mark.asyncio
async def test_publish_skips_web_push_for_onboarding():
"""Onboarding step toasts are in-page only and must NOT trigger OS push."""
bus = AsyncRedisNotificationEventBus()
event = NotificationEvent(
user_id="user-42",
payload=NotificationPayload(type="onboarding", event="step_completed"),
)
with (
patch.object(AsyncRedisNotificationEventBus, "publish_event", AsyncMock()),
patch(
"backend.data.notification_bus.send_push_for_user",
new_callable=AsyncMock,
) as mock_push,
):
await bus.publish(event)
import asyncio
for _ in range(3):
await asyncio.sleep(0)
mock_push.assert_not_awaited()
@pytest.mark.asyncio
async def test_publish_swallows_push_errors():
"""A failing push must not propagate or fail the publish."""
bus = AsyncRedisNotificationEventBus()
event = NotificationEvent(
user_id="user-42", payload=NotificationPayload(type="info", event="hi")
)
with (
patch.object(AsyncRedisNotificationEventBus, "publish_event", AsyncMock()),
patch(
"backend.data.notification_bus.send_push_for_user",
new_callable=AsyncMock,
side_effect=RuntimeError("push backend down"),
),
):
await bus.publish(event) # must not raise
import asyncio
for _ in range(3):
await asyncio.sleep(0)

View File

@@ -0,0 +1,139 @@
"""Fire-and-forget Web Push delivery for notification events."""
import asyncio
import json
import logging
import re
import time
import uuid
from cachetools import TTLCache
from pywebpush import WebPushException, webpush
from backend.api.model import NotificationPayload
from backend.data.push_subscription import PushSubscriptionDTO, validate_push_endpoint
from backend.util.clients import get_database_manager_async_client
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
_settings = Settings()
DEBOUNCE_SECONDS = 5.0
# Per-user debounce timestamps, bounded + auto-evicted so the process doesn't
# accumulate one entry per user forever. Process-local — ineffective across
# multiple WS replicas; acceptable since debounce is a best-effort UX nicety.
_user_last_push: TTLCache[str, float] = TTLCache(maxsize=10_000, ttl=DEBOUNCE_SECONDS)
# Fields to forward from the notification payload to the push message
_FORWARDED_FIELDS = ("session_id", "step", "status", "graph_id", "execution_id")
def _extract_status_code(e: WebPushException) -> int | None:
"""Extract HTTP status code from a pywebpush exception."""
if e.response is not None:
return e.response.status_code
# Fallback: parse "Push failed: <code> <reason>" out of the message in
# case a future pywebpush version raises without attaching the Response.
match = re.search(r"Push failed:\s*(\d{3})\b", str(e))
return int(match.group(1)) if match else None
def _build_push_payload(payload: NotificationPayload) -> str:
"""Build a compact JSON payload (<4KB) for the push message.
``id`` is a per-push UUID used by the service worker to build a unique
notification tag, so repeat pushes don't get coalesced by the OS.
"""
data = payload.model_dump(mode="json")
compact: dict[str, object] = {
"id": uuid.uuid4().hex,
"type": data.get("type", ""),
"event": data.get("event", ""),
}
for key in _FORWARDED_FIELDS:
if key in data:
compact[key] = data[key]
return json.dumps(compact)
async def send_push_for_user(user_id: str, payload: NotificationPayload) -> None:
"""Send push notifications to all of a user's subscriptions.
- Skips silently if VAPID keys are not configured.
- Debounces per-user (collapses pushes within DEBOUNCE_SECONDS).
- Cleans up stale subscriptions on 410/404 responses.
"""
vapid_private = _settings.secrets.vapid_private_key
vapid_public = _settings.secrets.vapid_public_key
vapid_claim_email = _settings.secrets.vapid_claim_email
if not vapid_private or not vapid_public or not vapid_claim_email:
logger.debug("VAPID keys not fully configured, skipping push")
return
# py_vapid rejects unprefixed strings deep in webpush(), where they'd
# surface once per subscription as an "Unexpected error". Catch the
# misconfiguration here and skip cleanly.
if not vapid_claim_email.startswith(("mailto:", "https://")):
logger.warning(
"VAPID_CLAIM_EMAIL must start with 'mailto:' or 'https://', got %r"
"skipping push",
vapid_claim_email[:40],
)
return
if user_id in _user_last_push:
logger.debug("Debouncing push for user %s", user_id)
return
_user_last_push[user_id] = time.monotonic()
db_client = get_database_manager_async_client()
subscriptions = await db_client.get_user_push_subscriptions(user_id)
if not subscriptions:
return
push_data = _build_push_payload(payload)
vapid_claims: dict[str, str | int] = {"sub": vapid_claim_email}
async def _send_one(sub: PushSubscriptionDTO) -> None:
try:
# Defense-in-depth: reject endpoints that somehow bypassed the
# subscribe-time check (rows written before the validator existed,
# direct DB writes, or DNS changes that shifted a trusted host to
# a blocked IP).
await validate_push_endpoint(sub.endpoint)
await asyncio.to_thread(
webpush,
subscription_info={
"endpoint": sub.endpoint,
"keys": {"p256dh": sub.p256dh, "auth": sub.auth},
},
data=push_data,
vapid_private_key=vapid_private,
vapid_claims=vapid_claims,
)
except ValueError as e:
logger.warning(
"Refusing push to untrusted endpoint %s: %s",
sub.endpoint[:60],
e,
)
await db_client.delete_push_subscription(sub.user_id, sub.endpoint)
return
except WebPushException as e:
status = _extract_status_code(e)
if status in (410, 404):
logger.info(
"Push subscription gone (%s), removing: %s",
status,
sub.endpoint[:60],
)
await db_client.delete_push_subscription(sub.user_id, sub.endpoint)
else:
logger.warning("Push failed for %s: %s", sub.endpoint[:60], e)
await db_client.increment_push_fail_count(sub.user_id, sub.endpoint)
except Exception:
logger.exception("Unexpected error sending push to %s", sub.endpoint[:60])
await asyncio.gather(
*[_send_one(sub) for sub in subscriptions], return_exceptions=True
)
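
The debounce in send_push_for_user relies only on key membership in the TTLCache; the stored timestamp is never read back. A small, deterministic illustration (hypothetical fake clock, not project code) of the expiry behaviour that the test file below simulates with pop():

from cachetools import TTLCache

clock = [0.0]  # injected timer so the example needs no real sleeps
cache: TTLCache[str, float] = TTLCache(maxsize=10_000, ttl=5.0, timer=lambda: clock[0])

cache["user-1"] = clock[0]
assert "user-1" in cache        # within DEBOUNCE_SECONDS: the second push is dropped

clock[0] += 5.1                 # advance past the TTL
assert "user-1" not in cache    # entry evicted: the next push goes through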

View File

@@ -0,0 +1,348 @@
"""Tests for fire-and-forget Web Push delivery."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from backend.api.model import NotificationPayload
from backend.data import push_sender
from backend.data.push_subscription import PushSubscriptionDTO
@pytest.fixture(autouse=True)
def clear_debounce():
"""Reset the per-user debounce state between tests."""
push_sender._user_last_push.clear()
yield
push_sender._user_last_push.clear()
@pytest.fixture
def mock_db_client(mocker):
"""Provides a mocked DatabaseManagerAsyncClient with stub async methods."""
client = MagicMock()
client.get_user_push_subscriptions = AsyncMock(return_value=[])
client.delete_push_subscription = AsyncMock()
client.increment_push_fail_count = AsyncMock()
mocker.patch(
"backend.data.push_sender.get_database_manager_async_client",
return_value=client,
)
return client
def _make_settings(
private: str = "vapid-private",
public: str = "vapid-public",
email: str = "mailto:push@agpt.co",
) -> MagicMock:
settings = MagicMock()
settings.secrets.vapid_private_key = private
settings.secrets.vapid_public_key = public
settings.secrets.vapid_claim_email = email
return settings
def _make_subscription(
user_id: str = "user-1",
endpoint: str = "https://fcm.googleapis.com/fcm/send/sub/1",
p256dh: str = "test-p256dh",
auth: str = "test-auth",
) -> PushSubscriptionDTO:
return PushSubscriptionDTO(
user_id=user_id, endpoint=endpoint, p256dh=p256dh, auth=auth
)
def _make_payload(**kwargs) -> NotificationPayload:
defaults = {"type": "agent_run", "event": "completed"}
defaults.update(kwargs)
return NotificationPayload(**defaults)
class TestBuildPushPayload:
def test_includes_type_and_event(self):
payload = _make_payload(type="agent_run", event="completed")
result = push_sender._build_push_payload(payload)
import json
parsed = json.loads(result)
assert parsed["type"] == "agent_run"
assert parsed["event"] == "completed"
def test_forwards_known_fields(self):
payload = _make_payload(
execution_id="exec-1",
graph_id="graph-1",
status="completed",
)
result = push_sender._build_push_payload(payload)
import json
parsed = json.loads(result)
assert parsed["execution_id"] == "exec-1"
assert parsed["graph_id"] == "graph-1"
assert parsed["status"] == "completed"
def test_excludes_unknown_fields(self):
payload = _make_payload(
custom_field="should-not-appear",
)
result = push_sender._build_push_payload(payload)
import json
parsed = json.loads(result)
assert "custom_field" not in parsed
def test_uses_model_dump_json_mode(self):
"""Ensure model_dump(mode='json') serializes enums to strings."""
payload = _make_payload(type="agent_run", event="completed")
result = push_sender._build_push_payload(payload)
import json
parsed = json.loads(result)
assert isinstance(parsed["type"], str)
assert isinstance(parsed["event"], str)
def test_includes_unique_id_per_call(self):
"""Each push gets a fresh UUID so repeats don't collapse under the same SW tag."""
import json
payload = _make_payload(type="agent_run", event="completed")
first = json.loads(push_sender._build_push_payload(payload))
second = json.loads(push_sender._build_push_payload(payload))
assert "id" in first and "id" in second
assert first["id"] != second["id"]
class TestSendPushForUser:
@pytest.mark.asyncio
async def test_skips_when_vapid_private_key_missing(self, mocker, mock_db_client):
mocker.patch.object(push_sender, "_settings", _make_settings(private=""))
await push_sender.send_push_for_user("user-1", _make_payload())
mock_db_client.get_user_push_subscriptions.assert_not_awaited()
@pytest.mark.asyncio
async def test_skips_when_vapid_public_key_missing(self, mocker, mock_db_client):
mocker.patch.object(push_sender, "_settings", _make_settings(public=""))
await push_sender.send_push_for_user("user-1", _make_payload())
mock_db_client.get_user_push_subscriptions.assert_not_awaited()
@pytest.mark.asyncio
async def test_skips_when_vapid_email_missing(self, mocker, mock_db_client):
mocker.patch.object(push_sender, "_settings", _make_settings(email=""))
await push_sender.send_push_for_user("user-1", _make_payload())
mock_db_client.get_user_push_subscriptions.assert_not_awaited()
@pytest.mark.asyncio
async def test_debounces_rapid_calls(self, mocker, mock_db_client):
mocker.patch.object(push_sender, "_settings", _make_settings())
await push_sender.send_push_for_user("user-1", _make_payload())
assert mock_db_client.get_user_push_subscriptions.await_count == 1
await push_sender.send_push_for_user("user-1", _make_payload())
assert mock_db_client.get_user_push_subscriptions.await_count == 1
@pytest.mark.asyncio
async def test_different_users_not_debounced(self, mocker, mock_db_client):
mocker.patch.object(push_sender, "_settings", _make_settings())
await push_sender.send_push_for_user("user-1", _make_payload())
await push_sender.send_push_for_user("user-2", _make_payload())
assert mock_db_client.get_user_push_subscriptions.await_count == 2
@pytest.mark.asyncio
async def test_returns_early_when_no_subscriptions(self, mocker, mock_db_client):
mocker.patch.object(push_sender, "_settings", _make_settings())
mock_webpush = mocker.patch("backend.data.push_sender.webpush")
await push_sender.send_push_for_user("user-1", _make_payload())
mock_webpush.assert_not_called()
@pytest.mark.asyncio
async def test_calls_webpush_for_each_subscription(self, mocker, mock_db_client):
mocker.patch.object(push_sender, "_settings", _make_settings())
sub1 = _make_subscription(endpoint="https://fcm.googleapis.com/fcm/send/sub/1")
sub2 = _make_subscription(endpoint="https://fcm.googleapis.com/fcm/send/sub/2")
mock_db_client.get_user_push_subscriptions.return_value = [sub1, sub2]
mock_webpush = mocker.patch("backend.data.push_sender.webpush")
await push_sender.send_push_for_user("user-1", _make_payload())
assert mock_webpush.call_count == 2
calls = mock_webpush.call_args_list
endpoints_called = [c.kwargs["subscription_info"]["endpoint"] for c in calls]
assert "https://fcm.googleapis.com/fcm/send/sub/1" in endpoints_called
assert "https://fcm.googleapis.com/fcm/send/sub/2" in endpoints_called
@pytest.mark.asyncio
async def test_webpush_called_with_correct_args(self, mocker, mock_db_client):
mocker.patch.object(push_sender, "_settings", _make_settings())
sub = _make_subscription(
user_id="user-1",
endpoint="https://fcm.googleapis.com/fcm/send/sub/1",
p256dh="key-p256dh",
auth="key-auth",
)
mock_db_client.get_user_push_subscriptions.return_value = [sub]
mock_webpush = mocker.patch("backend.data.push_sender.webpush")
await push_sender.send_push_for_user("user-1", _make_payload())
mock_webpush.assert_called_once()
call_kwargs = mock_webpush.call_args.kwargs
assert call_kwargs["subscription_info"] == {
"endpoint": "https://fcm.googleapis.com/fcm/send/sub/1",
"keys": {"p256dh": "key-p256dh", "auth": "key-auth"},
}
assert call_kwargs["vapid_private_key"] == "vapid-private"
assert call_kwargs["vapid_claims"] == {"sub": "mailto:push@agpt.co"}
assert isinstance(call_kwargs["data"], str)
@pytest.mark.asyncio
async def test_removes_subscription_on_410_gone(self, mocker, mock_db_client):
from pywebpush import WebPushException
mocker.patch.object(push_sender, "_settings", _make_settings())
sub = _make_subscription()
mock_db_client.get_user_push_subscriptions.return_value = [sub]
mock_response = MagicMock()
mock_response.status_code = 410
exc = WebPushException("Gone", response=mock_response)
mocker.patch("backend.data.push_sender.webpush", side_effect=exc)
await push_sender.send_push_for_user("user-1", _make_payload())
mock_db_client.delete_push_subscription.assert_awaited_once_with(
sub.user_id, sub.endpoint
)
@pytest.mark.asyncio
async def test_removes_subscription_on_404(self, mocker, mock_db_client):
from pywebpush import WebPushException
mocker.patch.object(push_sender, "_settings", _make_settings())
sub = _make_subscription()
mock_db_client.get_user_push_subscriptions.return_value = [sub]
mock_response = MagicMock()
mock_response.status_code = 404
exc = WebPushException("Not Found", response=mock_response)
mocker.patch("backend.data.push_sender.webpush", side_effect=exc)
await push_sender.send_push_for_user("user-1", _make_payload())
mock_db_client.delete_push_subscription.assert_awaited_once_with(
sub.user_id, sub.endpoint
)
@pytest.mark.asyncio
async def test_removes_subscription_when_status_only_in_message(
self, mocker, mock_db_client
):
"""Some pywebpush versions don't expose .response.status_code; the
sender must still detect 410 from the exception message and clean up."""
from pywebpush import WebPushException
mocker.patch.object(push_sender, "_settings", _make_settings())
sub = _make_subscription()
mock_db_client.get_user_push_subscriptions.return_value = [sub]
# No usable response object — only the message carries the status.
exc = WebPushException("Push failed: 410 Gone\nResponse body:gone")
exc.response = None # type: ignore[assignment]
mocker.patch("backend.data.push_sender.webpush", side_effect=exc)
await push_sender.send_push_for_user("user-1", _make_payload())
mock_db_client.delete_push_subscription.assert_awaited_once_with(
sub.user_id, sub.endpoint
)
@pytest.mark.asyncio
async def test_increments_fail_count_on_other_webpush_error(
self, mocker, mock_db_client
):
from pywebpush import WebPushException
mocker.patch.object(push_sender, "_settings", _make_settings())
sub = _make_subscription()
mock_db_client.get_user_push_subscriptions.return_value = [sub]
mock_response = MagicMock()
mock_response.status_code = 429
exc = WebPushException("Too Many Requests", response=mock_response)
mocker.patch("backend.data.push_sender.webpush", side_effect=exc)
await push_sender.send_push_for_user("user-1", _make_payload())
mock_db_client.increment_push_fail_count.assert_awaited_once_with(
sub.user_id, sub.endpoint
)
@pytest.mark.asyncio
async def test_increments_fail_count_when_no_response_object(
self, mocker, mock_db_client
):
from pywebpush import WebPushException
mocker.patch.object(push_sender, "_settings", _make_settings())
sub = _make_subscription()
mock_db_client.get_user_push_subscriptions.return_value = [sub]
exc = WebPushException("Connection error")
mocker.patch("backend.data.push_sender.webpush", side_effect=exc)
await push_sender.send_push_for_user("user-1", _make_payload())
mock_db_client.increment_push_fail_count.assert_awaited_once_with(
sub.user_id, sub.endpoint
)
@pytest.mark.asyncio
async def test_handles_unexpected_exception_gracefully(
self, mocker, mock_db_client
):
mocker.patch.object(push_sender, "_settings", _make_settings())
sub = _make_subscription()
mock_db_client.get_user_push_subscriptions.return_value = [sub]
mocker.patch(
"backend.data.push_sender.webpush",
side_effect=RuntimeError("network down"),
)
await push_sender.send_push_for_user("user-1", _make_payload())
@pytest.mark.asyncio
async def test_debounce_expires_after_threshold(self, mocker, mock_db_client):
mocker.patch.object(push_sender, "_settings", _make_settings())
await push_sender.send_push_for_user("user-1", _make_payload())
assert mock_db_client.get_user_push_subscriptions.await_count == 1
# Simulate TTL expiry (cachetools evicts on access after TTL elapses).
push_sender._user_last_push.pop("user-1", None)
await push_sender.send_push_for_user("user-1", _make_payload())
assert mock_db_client.get_user_push_subscriptions.await_count == 2

View File

@@ -0,0 +1,142 @@
"""CRUD operations for Web Push subscriptions (PushSubscription model)."""
import logging
from datetime import datetime, timezone
from prisma.models import PushSubscription
from pydantic import BaseModel
from backend.util.request import validate_url_host
logger = logging.getLogger(__name__)
# Hostnames of legitimate Web Push services. Endpoints submitted by
# clients must match one of these; everything else is rejected to prevent
# the backend (which POSTs to the stored URL via pywebpush) from being
# used as an SSRF primitive against internal infrastructure. Covers Chrome/
# Edge/Brave (FCM), Firefox (Autopush), and Safari/macOS (Apple Web Push).
_PUSH_SERVICE_HOSTNAMES: list[str] = [
"fcm.googleapis.com",
"updates.push.services.mozilla.com",
"web.push.apple.com",
]
# Cap on concurrent push subscriptions per user — one entry per device/browser
# is typical, so this comfortably covers real usage while preventing an
# authenticated user from registering unbounded endpoints to amplify outbound
# traffic from the backend.
MAX_SUBSCRIPTIONS_PER_USER = 20
# Delete subscriptions with this many failed push attempts during periodic
# cleanup. Web Push sends occasionally fail transiently; beyond this threshold
# the endpoint is effectively dead and should be removed.
MAX_PUSH_FAILURES = 5
async def validate_push_endpoint(endpoint: str) -> None:
"""Ensure a push-subscription endpoint is an HTTPS URL hosted on a known
Web Push provider. Raises ``ValueError`` otherwise.
Called at subscribe time and again before dispatch (defense-in-depth against
rows written before this check existed or via future codepaths).
"""
parsed, is_trusted, _ = await validate_url_host(
endpoint, trusted_hostnames=_PUSH_SERVICE_HOSTNAMES
)
if parsed.scheme != "https":
raise ValueError("Push endpoint must use https://")
if not is_trusted:
raise ValueError(
f"Push endpoint host '{parsed.hostname}' is not a recognised "
"Web Push service"
)
class PushSubscriptionDTO(BaseModel):
"""RPC-serializable projection of PushSubscription."""
user_id: str
endpoint: str
p256dh: str
auth: str
@staticmethod
def from_db(model: PushSubscription) -> "PushSubscriptionDTO":
return PushSubscriptionDTO(
user_id=model.userId,
endpoint=model.endpoint,
p256dh=model.p256dh,
auth=model.auth,
)
async def upsert_push_subscription(
user_id: str,
endpoint: str,
p256dh: str,
auth: str,
user_agent: str | None = None,
) -> PushSubscription:
existing = await PushSubscription.prisma().find_many(
where={"userId": user_id},
)
# Allow updates to an existing endpoint; only block when adding a *new* one
# past the cap.
has_this_endpoint = any(row.endpoint == endpoint for row in existing)
if len(existing) >= MAX_SUBSCRIPTIONS_PER_USER and not has_this_endpoint:
raise ValueError(
f"Subscription limit of {MAX_SUBSCRIPTIONS_PER_USER} per user reached"
)
return await PushSubscription.prisma().upsert(
where={"userId_endpoint": {"userId": user_id, "endpoint": endpoint}},
data={
"create": {
"userId": user_id,
"endpoint": endpoint,
"p256dh": p256dh,
"auth": auth,
"userAgent": user_agent,
},
"update": {
"p256dh": p256dh,
"auth": auth,
"userAgent": user_agent,
"failCount": 0,
"lastFailedAt": None,
},
},
)
async def get_user_push_subscriptions(user_id: str) -> list[PushSubscriptionDTO]:
rows = await PushSubscription.prisma().find_many(where={"userId": user_id})
return [PushSubscriptionDTO.from_db(row) for row in rows]
async def delete_push_subscription(user_id: str, endpoint: str) -> None:
await PushSubscription.prisma().delete_many(
where={"userId": user_id, "endpoint": endpoint}
)
async def increment_fail_count(user_id: str, endpoint: str) -> None:
await PushSubscription.prisma().update_many(
where={"userId": user_id, "endpoint": endpoint},
data={
"failCount": {"increment": 1},
"lastFailedAt": datetime.now(timezone.utc),
},
)
async def cleanup_failed_subscriptions(
max_failures: int = MAX_PUSH_FAILURES,
) -> int:
"""Delete subscriptions that have exceeded the failure threshold."""
result = await PushSubscription.prisma().delete_many(
where={"failCount": {"gte": max_failures}}
)
if result:
logger.info(f"Cleaned up {result} failed push subscriptions")
return result or 0
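
The HTTP layer that calls into this module is not shown in this diff. As a hedged sketch only (route path, request model, and auth handling are assumptions), a subscribe endpoint would chain the SSRF check and the capped upsert roughly like this:

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from backend.data.push_subscription import (
    upsert_push_subscription,
    validate_push_endpoint,
)

router = APIRouter()

class SubscribeRequest(BaseModel):  # assumed request shape
    endpoint: str
    p256dh: str
    auth: str

@router.post("/notifications/push/subscribe")  # illustrative path
async def subscribe(req: SubscribeRequest, user_id: str):  # auth dependency omitted
    try:
        # Reject endpoints not hosted on a known Web Push service (SSRF guard).
        await validate_push_endpoint(req.endpoint)
        await upsert_push_subscription(
            user_id=user_id,
            endpoint=req.endpoint,
            p256dh=req.p256dh,
            auth=req.auth,
        )
    except ValueError as e:
        # Covers both the endpoint check and the per-user subscription cap.
        raise HTTPException(status_code=400, detail=str(e))
    return {"ok": True}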

View File

@@ -0,0 +1,325 @@
"""Tests for Web Push subscription CRUD operations."""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
from backend.data import push_subscription
@pytest.fixture
def mock_prisma(mocker):
"""Mock PushSubscription.prisma() and return the mock client."""
mock_client = MagicMock()
mock_client.upsert = AsyncMock()
mock_client.find_many = AsyncMock(return_value=[])
mock_client.delete_many = AsyncMock()
mock_client.update_many = AsyncMock()
mocker.patch(
"backend.data.push_subscription.PushSubscription.prisma",
return_value=mock_client,
)
return mock_client
class TestUpsertPushSubscription:
@pytest.mark.asyncio
async def test_calls_prisma_upsert_with_correct_params(self, mock_prisma):
mock_prisma.upsert.return_value = MagicMock()
await push_subscription.upsert_push_subscription(
user_id="user-1",
endpoint="https://fcm.googleapis.com/fcm/send/sub/1",
p256dh="test-p256dh",
auth="test-auth",
user_agent="Mozilla/5.0",
)
mock_prisma.upsert.assert_awaited_once()
call_kwargs = mock_prisma.upsert.call_args.kwargs
assert call_kwargs["where"] == {
"userId_endpoint": {
"userId": "user-1",
"endpoint": "https://fcm.googleapis.com/fcm/send/sub/1",
}
}
assert call_kwargs["data"]["create"] == {
"userId": "user-1",
"endpoint": "https://fcm.googleapis.com/fcm/send/sub/1",
"p256dh": "test-p256dh",
"auth": "test-auth",
"userAgent": "Mozilla/5.0",
}
assert call_kwargs["data"]["update"] == {
"p256dh": "test-p256dh",
"auth": "test-auth",
"userAgent": "Mozilla/5.0",
"failCount": 0,
"lastFailedAt": None,
}
@pytest.mark.asyncio
async def test_upsert_without_user_agent(self, mock_prisma):
mock_prisma.upsert.return_value = MagicMock()
await push_subscription.upsert_push_subscription(
user_id="user-1",
endpoint="https://fcm.googleapis.com/fcm/send/sub/1",
p256dh="test-p256dh",
auth="test-auth",
)
call_kwargs = mock_prisma.upsert.call_args.kwargs
assert call_kwargs["data"]["create"]["userAgent"] is None
assert call_kwargs["data"]["update"]["userAgent"] is None
@pytest.mark.asyncio
async def test_upsert_returns_prisma_result(self, mock_prisma):
expected = MagicMock()
mock_prisma.upsert.return_value = expected
result = await push_subscription.upsert_push_subscription(
user_id="user-1",
endpoint="https://fcm.googleapis.com/fcm/send/sub/1",
p256dh="test-p256dh",
auth="test-auth",
)
assert result is expected
@pytest.mark.asyncio
async def test_upsert_resets_fail_count_on_update(self, mock_prisma):
mock_prisma.upsert.return_value = MagicMock()
await push_subscription.upsert_push_subscription(
user_id="user-1",
endpoint="https://fcm.googleapis.com/fcm/send/sub/1",
p256dh="test-p256dh",
auth="test-auth",
)
call_kwargs = mock_prisma.upsert.call_args.kwargs
assert call_kwargs["data"]["update"]["failCount"] == 0
assert call_kwargs["data"]["update"]["lastFailedAt"] is None
@pytest.mark.asyncio
async def test_rejects_new_endpoint_past_cap(self, mock_prisma):
existing = [
MagicMock(endpoint=f"https://fcm.googleapis.com/fcm/send/sub/{i}")
for i in range(push_subscription.MAX_SUBSCRIPTIONS_PER_USER)
]
mock_prisma.find_many.return_value = existing
with pytest.raises(ValueError, match="Subscription limit"):
await push_subscription.upsert_push_subscription(
user_id="user-1",
endpoint="https://fcm.googleapis.com/fcm/send/sub/NEW",
p256dh="test-p256dh",
auth="test-auth",
)
mock_prisma.upsert.assert_not_awaited()
@pytest.mark.asyncio
async def test_allows_update_of_existing_endpoint_at_cap(self, mock_prisma):
existing = [
MagicMock(endpoint=f"https://fcm.googleapis.com/fcm/send/sub/{i}")
for i in range(push_subscription.MAX_SUBSCRIPTIONS_PER_USER)
]
mock_prisma.find_many.return_value = existing
mock_prisma.upsert.return_value = MagicMock()
await push_subscription.upsert_push_subscription(
user_id="user-1",
endpoint="https://fcm.googleapis.com/fcm/send/sub/0",
p256dh="rotated-p256dh",
auth="rotated-auth",
)
mock_prisma.upsert.assert_awaited_once()
class TestGetUserPushSubscriptions:
@pytest.mark.asyncio
async def test_returns_list_of_subscription_dtos(self, mock_prisma):
sub1 = MagicMock(
userId="user-1",
endpoint="https://fcm.googleapis.com/fcm/send/sub/1",
p256dh="key1",
auth="auth1",
)
sub2 = MagicMock(
userId="user-1",
endpoint="https://fcm.googleapis.com/fcm/send/sub/2",
p256dh="key2",
auth="auth2",
)
mock_prisma.find_many.return_value = [sub1, sub2]
result = await push_subscription.get_user_push_subscriptions("user-1")
assert [r.endpoint for r in result] == [
"https://fcm.googleapis.com/fcm/send/sub/1",
"https://fcm.googleapis.com/fcm/send/sub/2",
]
assert all(r.user_id == "user-1" for r in result)
mock_prisma.find_many.assert_awaited_once_with(where={"userId": "user-1"})
@pytest.mark.asyncio
async def test_returns_empty_list_when_no_subscriptions(self, mock_prisma):
mock_prisma.find_many.return_value = []
result = await push_subscription.get_user_push_subscriptions("user-1")
assert result == []
class TestDeletePushSubscription:
@pytest.mark.asyncio
async def test_deletes_by_user_id_and_endpoint(self, mock_prisma):
await push_subscription.delete_push_subscription(
"user-1",
"https://fcm.googleapis.com/fcm/send/sub/1",
)
mock_prisma.delete_many.assert_awaited_once_with(
where={
"userId": "user-1",
"endpoint": "https://fcm.googleapis.com/fcm/send/sub/1",
}
)
class TestIncrementFailCount:
@pytest.mark.asyncio
async def test_includes_user_id_in_where(self, mock_prisma):
await push_subscription.increment_fail_count(
"user-1",
"https://fcm.googleapis.com/fcm/send/sub/1",
)
mock_prisma.update_many.assert_awaited_once()
call_kwargs = mock_prisma.update_many.call_args.kwargs
assert call_kwargs["where"] == {
"userId": "user-1",
"endpoint": "https://fcm.googleapis.com/fcm/send/sub/1",
}
@pytest.mark.asyncio
async def test_increments_fail_count_by_one(self, mock_prisma):
await push_subscription.increment_fail_count(
"user-1",
"https://fcm.googleapis.com/fcm/send/sub/1",
)
call_kwargs = mock_prisma.update_many.call_args.kwargs
assert call_kwargs["data"]["failCount"] == {"increment": 1}
@pytest.mark.asyncio
async def test_sets_last_failed_at_to_utc_now(self, mock_prisma):
await push_subscription.increment_fail_count(
"user-1",
"https://fcm.googleapis.com/fcm/send/sub/1",
)
call_kwargs = mock_prisma.update_many.call_args.kwargs
last_failed = call_kwargs["data"]["lastFailedAt"]
assert isinstance(last_failed, datetime)
assert last_failed.tzinfo is not None
class TestCleanupFailedSubscriptions:
@pytest.mark.asyncio
async def test_deletes_subscriptions_exceeding_threshold(self, mock_prisma):
mock_prisma.delete_many.return_value = 3
result = await push_subscription.cleanup_failed_subscriptions(
max_failures=5,
)
assert result == 3
mock_prisma.delete_many.assert_awaited_once_with(
where={"failCount": {"gte": 5}}
)
@pytest.mark.asyncio
async def test_uses_default_max_failures(self, mock_prisma):
mock_prisma.delete_many.return_value = 0
await push_subscription.cleanup_failed_subscriptions()
call_kwargs = mock_prisma.delete_many.call_args.kwargs
assert call_kwargs["where"]["failCount"]["gte"] == 5
@pytest.mark.asyncio
async def test_returns_zero_when_none_deleted(self, mock_prisma):
mock_prisma.delete_many.return_value = 0
result = await push_subscription.cleanup_failed_subscriptions()
assert result == 0
@pytest.mark.asyncio
async def test_returns_zero_when_result_is_none(self, mock_prisma):
mock_prisma.delete_many.return_value = None
result = await push_subscription.cleanup_failed_subscriptions()
assert result == 0
class TestValidatePushEndpoint:
"""Endpoints from clients must land on a known Web Push service — otherwise
the backend can be coerced into POSTing to internal hosts (SSRF)."""
@pytest.mark.asyncio
@pytest.mark.parametrize(
"endpoint",
[
"https://fcm.googleapis.com/fcm/send/abc",
"https://updates.push.services.mozilla.com/wpush/v2/xyz",
"https://web.push.apple.com/some-token",
],
)
async def test_allows_known_push_services(self, endpoint):
await push_subscription.validate_push_endpoint(endpoint)
@pytest.mark.asyncio
async def test_rejects_http_scheme(self):
with pytest.raises(ValueError):
await push_subscription.validate_push_endpoint(
"http://fcm.googleapis.com/fcm/send/abc"
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"endpoint",
[
"https://localhost/evil",
"https://127.0.0.1/evil",
"https://169.254.169.254/latest/meta-data/",
"https://internal-service.local/api",
"https://attacker.example.com/push",
],
)
async def test_rejects_untrusted_hosts(self, endpoint):
with pytest.raises(ValueError):
await push_subscription.validate_push_endpoint(endpoint)
@pytest.mark.asyncio
async def test_rejects_non_http_scheme(self):
with pytest.raises(ValueError):
await push_subscription.validate_push_endpoint("file:///etc/passwd")
@pytest.mark.asyncio
async def test_custom_max_failures_threshold(self, mock_prisma):
mock_prisma.delete_many.return_value = 1
result = await push_subscription.cleanup_failed_subscriptions(
max_failures=10,
)
assert result == 1
call_kwargs = mock_prisma.delete_many.call_args.kwargs
assert call_kwargs["where"]["failCount"]["gte"] == 10

View File

@@ -0,0 +1,278 @@
"""Quorum-queue config assertions + mock-driven publish behaviour for
`AsyncRabbitMQ`. Live-broker scenarios live in `e2e_redis_rabbit_test.py`."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import aio_pika
import pytest
from backend.copilot.executor.utils import (
COPILOT_EXECUTION_EXCHANGE,
COPILOT_EXECUTION_QUEUE_NAME,
COPILOT_EXECUTION_ROUTING_KEY,
create_copilot_queue_config,
)
from backend.data.rabbitmq import (
AsyncRabbitMQ,
Exchange,
ExchangeType,
Queue,
RabbitMQConfig,
)
from backend.executor.utils import (
GRAPH_EXECUTION_EXCHANGE,
GRAPH_EXECUTION_QUEUE_NAME,
GRAPH_EXECUTION_ROUTING_KEY,
create_execution_queue_config,
)
# ---------- Quorum queue config: classic→quorum rollover guard ----------
def test_graph_execution_queue_is_quorum() -> None:
"""Run queue must declare `x-queue-type=quorum` to survive a single
broker-node outage (AUTOGPT-SERVER-8ST/SV/SW)."""
cfg = create_execution_queue_config()
run = next(q for q in cfg.queues if q.name == GRAPH_EXECUTION_QUEUE_NAME)
assert run.arguments is not None
assert run.arguments.get("x-queue-type") == "quorum"
# _v2 suffix marks the rollover so the old-image consumer keeps draining
# the unsuffixed classic queue during a rolling deploy.
assert run.name.endswith("_v2")
assert run.durable is True
assert run.exchange is GRAPH_EXECUTION_EXCHANGE
def test_graph_execution_cancel_queue_is_quorum() -> None:
"""Cancel queue must also be quorum — losing cancellations on a node
flap is just as bad as losing runs."""
cfg = create_execution_queue_config()
cancel = next(q for q in cfg.queues if q.name.endswith("cancel_queue_v2"))
assert cancel.arguments == {"x-queue-type": "quorum"}
def test_copilot_execution_queue_is_quorum_with_consumer_timeout() -> None:
"""Copilot run queue must be quorum + carry a long consumer timeout
matching the pod's graceful-shutdown window."""
cfg = create_copilot_queue_config()
run = next(q for q in cfg.queues if q.name == COPILOT_EXECUTION_QUEUE_NAME)
assert run.arguments is not None
assert run.arguments.get("x-queue-type") == "quorum"
# Timeout must be in milliseconds and substantially larger than the
# default 30-minute timeout so a 6-hour copilot turn doesn't get
# cancelled by RabbitMQ mid-execution.
timeout_ms = run.arguments.get("x-consumer-timeout")
assert isinstance(timeout_ms, int)
assert timeout_ms >= 60 * 60 * 1000 # at least 1 hour
def test_copilot_cancel_queue_is_quorum() -> None:
cfg = create_copilot_queue_config()
cancel = next(q for q in cfg.queues if q.name.endswith("cancel_queue_v2"))
assert cancel.arguments == {"x-queue-type": "quorum"}
# ---------- AsyncRabbitMQ.publish_message: mock-driven behaviour ----------
def _make_async_client(
*, exchange_publish: AsyncMock | None = None
) -> tuple[AsyncRabbitMQ, MagicMock, MagicMock]:
"""Build an AsyncRabbitMQ wired to mock connection/channel/exchange.
Returns the client, the mock channel, and the mock exchange so tests can
assert on per-call arguments and tweak side_effects mid-flight.
"""
cfg = RabbitMQConfig(
vhost="/",
exchanges=[
Exchange(name="test_exchange", type=ExchangeType.DIRECT, durable=True)
],
queues=[
Queue(
name="test_queue",
exchange=Exchange(
name="test_exchange", type=ExchangeType.DIRECT, durable=True
),
routing_key="rk",
arguments={"x-queue-type": "quorum"},
)
],
)
client = AsyncRabbitMQ(cfg)
fake_exchange = MagicMock()
fake_exchange.publish = exchange_publish or AsyncMock()
fake_channel = MagicMock()
fake_channel.is_closed = False
fake_channel.get_exchange = AsyncMock(return_value=fake_exchange)
fake_channel.default_exchange = fake_exchange
fake_connection = MagicMock()
fake_connection.is_closed = False
client._connection = fake_connection
client._channel = fake_channel
return client, fake_channel, fake_exchange
@pytest.mark.asyncio
async def test_publish_100_messages_to_quorum_queue_all_confirmed() -> None:
"""A healthy quorum queue publish path must confirm 100/100 publishes
with no NACKs."""
client, _, fake_exchange = _make_async_client()
exchange = Exchange(name="test_exchange", type=ExchangeType.DIRECT)
for i in range(100):
await client.publish_message(
routing_key="rk", message=f"msg-{i}", exchange=exchange
)
assert fake_exchange.publish.await_count == 100
# Every call carried a persistent message — durable on the broker side.
for call in fake_exchange.publish.await_args_list:
msg = call.args[0]
assert isinstance(msg, aio_pika.Message)
assert msg.delivery_mode == aio_pika.DeliveryMode.PERSISTENT
@pytest.mark.asyncio
async def test_publish_retries_on_delivery_error_then_raises() -> None:
"""Broker-side NACK (DeliveryError) must trigger ``func_retry`` and then
raise gracefully if every retry fails — never crash the publisher loop."""
publish = AsyncMock(
side_effect=aio_pika.exceptions.DeliveryError(message=None, frame=None)
)
client, _, fake_exchange = _make_async_client(exchange_publish=publish)
exchange = Exchange(name="test_exchange", type=ExchangeType.DIRECT)
with pytest.raises(aio_pika.exceptions.DeliveryError):
await client.publish_message(
routing_key="rk", message="will-nack", exchange=exchange
)
# ``func_retry`` is configured for 5 attempts in retry.py — assert the
# publisher attempted at least once and that retries stayed bounded.
assert fake_exchange.publish.await_count >= 1
assert fake_exchange.publish.await_count <= 10 # generous upper bound
@pytest.mark.asyncio
async def test_publish_retries_after_one_transient_failure() -> None:
"""A single transient DeliveryError must NOT propagate — ``func_retry``
retries and the second call succeeds."""
publish = AsyncMock(
side_effect=[
aio_pika.exceptions.DeliveryError(message=None, frame=None),
None, # second attempt succeeds
]
)
client, _, fake_exchange = _make_async_client(exchange_publish=publish)
exchange = Exchange(name="test_exchange", type=ExchangeType.DIRECT)
await client.publish_message(
routing_key="rk", message="recovers", exchange=exchange
)
assert fake_exchange.publish.await_count == 2
@pytest.mark.asyncio
async def test_publish_reconnects_on_channel_invalid_state() -> None:
"""ChannelInvalidStateError must clear the channel and trigger a
reconnect-and-retry — the publish_message wrapper handles this
explicitly (see the except-clause in rabbitmq.py)."""
publish = AsyncMock(
side_effect=[
aio_pika.exceptions.ChannelInvalidStateError("channel dead"),
None,
]
)
client, fake_channel, fake_exchange = _make_async_client(exchange_publish=publish)
exchange = Exchange(name="test_exchange", type=ExchangeType.DIRECT)
# Patch connect() so the reconnect path doesn't try to hit a real broker.
async def _fake_connect():
# After reconnect the channel must be valid again.
client._channel = fake_channel
return None
with patch.object(client, "connect", side_effect=_fake_connect):
await client.publish_message(
routing_key="rk", message="reconnects", exchange=exchange
)
# Two publish attempts: the failing one + the post-reconnect retry.
assert fake_exchange.publish.await_count == 2
# ---------- Dual-deploy: legacy classic + new quorum publisher in parallel ----------
@pytest.mark.asyncio
async def test_dual_deploy_publishes_to_legacy_and_new_queues_in_parallel() -> None:
"""Rolling-deploy window: old-image producer publishes to classic queue,
new-image to `_v2` quorum queue — both must succeed independently."""
legacy_client, _, legacy_exchange = _make_async_client()
new_client, _, new_exchange = _make_async_client()
legacy_routing = "copilot.run" # legacy producers used the same routing key
new_routing = COPILOT_EXECUTION_ROUTING_KEY
legacy_exch = Exchange(name="copilot_execution", type=ExchangeType.DIRECT)
new_exch = Exchange(name=COPILOT_EXECUTION_EXCHANGE.name, type=ExchangeType.DIRECT)
# Interleave 10 publishes from each producer — order doesn't matter.
for i in range(10):
await legacy_client.publish_message(
routing_key=legacy_routing, message=f"legacy-{i}", exchange=legacy_exch
)
await new_client.publish_message(
routing_key=new_routing, message=f"new-{i}", exchange=new_exch
)
assert legacy_exchange.publish.await_count == 10
assert new_exchange.publish.await_count == 10
# Each publisher's routing key landed on its own exchange — no crosstalk.
for call in legacy_exchange.publish.await_args_list:
assert call.kwargs.get("routing_key") == legacy_routing
for call in new_exchange.publish.await_args_list:
assert call.kwargs.get("routing_key") == new_routing
@pytest.mark.asyncio
async def test_dual_deploy_legacy_failure_does_not_affect_new_queue() -> None:
"""Legacy classic queue NACKing (AUTOGPT-SERVER-8ST) must not break
publishes on the new `_v2` quorum queue."""
legacy_publish = AsyncMock(
side_effect=aio_pika.exceptions.DeliveryError(message=None, frame=None)
)
legacy_client, _, _ = _make_async_client(exchange_publish=legacy_publish)
new_client, _, new_exchange = _make_async_client()
legacy_exch = Exchange(name="copilot_execution", type=ExchangeType.DIRECT)
new_exch = Exchange(name=COPILOT_EXECUTION_EXCHANGE.name, type=ExchangeType.DIRECT)
# Legacy raises after retries — caller must catch it.
with pytest.raises(aio_pika.exceptions.DeliveryError):
await legacy_client.publish_message(
routing_key="copilot.run", message="legacy-fail", exchange=legacy_exch
)
# New publisher continues to work — 5 successful publishes.
for i in range(5):
await new_client.publish_message(
routing_key=COPILOT_EXECUTION_ROUTING_KEY,
message=f"new-ok-{i}",
exchange=new_exch,
)
assert new_exchange.publish.await_count == 5
# ---------- Configuration sanity for downstream queues ----------
def test_graph_execution_routing_key_constants() -> None:
"""Routing key + exchange wiring must stay aligned — guards against the
classic→quorum migration accidentally also changing the routing key."""
cfg = create_execution_queue_config()
run = next(q for q in cfg.queues if q.name == GRAPH_EXECUTION_QUEUE_NAME)
assert run.routing_key == GRAPH_EXECUTION_ROUTING_KEY
assert GRAPH_EXECUTION_EXCHANGE in cfg.exchanges
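
For orientation, a hedged sketch of the config shape these assertions pin down. The real create_copilot_queue_config lives in backend.copilot.executor.utils; the queue names, routing keys, and exact timeout value below are illustrative, and only the asserted invariants (quorum type, _v2 suffix, durable run queue, hour-plus consumer timeout) are taken from the tests above.

from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig

def create_copilot_queue_config_sketch() -> RabbitMQConfig:
    exchange = Exchange(name="copilot_execution", type=ExchangeType.DIRECT, durable=True)
    return RabbitMQConfig(
        vhost="/",
        exchanges=[exchange],
        queues=[
            Queue(
                name="copilot_execution_queue_v2",  # _v2: classic-to-quorum rollover
                exchange=exchange,
                routing_key="copilot.run",
                durable=True,
                arguments={
                    "x-queue-type": "quorum",
                    # must exceed RabbitMQ's default 30-minute consumer timeout
                    "x-consumer-timeout": 6 * 60 * 60 * 1000,  # illustrative: 6h in ms
                },
            ),
            Queue(
                name="copilot_cancel_queue_v2",
                exchange=exchange,
                routing_key="copilot.cancel",  # assumed routing key
                arguments={"x-queue-type": "quorum"},
            ),
        ],
    )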

View File

@@ -1,85 +1,205 @@
import asyncio
import logging
import os
from dotenv import load_dotenv
from redis import Redis
from redis.asyncio import Redis as AsyncRedis
from redis.asyncio.cluster import ClusterNode as AsyncClusterNode
from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster
from redis.cluster import ClusterNode, RedisCluster
from backend.util.cache import cached, thread_cached
from backend.util.cache import cached
from backend.util.retry import conn_retry
load_dotenv()
HOST = os.getenv("REDIS_HOST", "localhost")
PORT = int(os.getenv("REDIS_PORT", "6379"))
# Prefer the cluster env vars so the cluster-only image can co-exist with
# old-image pods still reading REDIS_HOST during a rollout.
HOST = os.getenv("REDIS_CLUSTER_HOST") or os.getenv("REDIS_HOST", "localhost")
PORT = int(os.getenv("REDIS_CLUSTER_PORT") or os.getenv("REDIS_PORT", "6379"))
PASSWORD = os.getenv("REDIS_PASSWORD", None)
# Default socket timeouts so a wedged Redis endpoint can't hang callers
# indefinitely — long-running code paths (cluster_lock refresh in particular)
# rely on these to fail-fast instead of blocking on no-response TCP. Override
# via env if a specific deployment needs a different budget.
#
# 30s matches the convention in ``backend.data.rabbitmq`` and leaves ~6x
# headroom over the largest ``xread(block=5000)`` wait in stream_registry.
# The connect timeout is shorter (5s) because initial connects should be
# fast; a slow connect usually means the endpoint is genuinely unreachable.
# Fail-fast on a wedged endpoint instead of blocking on no-response TCP.
SOCKET_TIMEOUT = float(os.getenv("REDIS_SOCKET_TIMEOUT", "30"))
SOCKET_CONNECT_TIMEOUT = float(os.getenv("REDIS_SOCKET_CONNECT_TIMEOUT", "5"))
# How often redis-py sends a PING on idle connections to detect half-open
# sockets; cheap and avoids waiting for the OS TCP keepalive (~2h default).
# PING on idle sockets to detect half-open connections without waiting for
# the OS TCP keepalive (~2h default).
HEALTH_CHECK_INTERVAL = int(os.getenv("REDIS_HEALTH_CHECK_INTERVAL", "30"))
# Skip the HOST-pinning remap when each shard's announced hostname resolves
# directly (e.g. compose DNS names redis-0/redis-1/redis-2).
USE_ANNOUNCED_ADDRESS = os.getenv("REDIS_USE_ANNOUNCED_ADDRESS", "").lower() in (
"1",
"true",
"yes",
)
logger = logging.getLogger(__name__)
# Aliases so call-sites don't care which class this is.
RedisClient = RedisCluster
AsyncRedisClient = AsyncRedisCluster
def _address_remap(addr: tuple[str, int]) -> tuple[str, int]:
"""Pin each shard to the seed `HOST`, keep its announced port.
Set `REDIS_USE_ANNOUNCED_ADDRESS=true` when the announced shard FQDNs
resolve directly (e.g. each pod has its own DNS).
"""
if USE_ANNOUNCED_ADDRESS:
return addr
_, port = addr
return HOST, port
@conn_retry("Redis", "Acquiring connection")
def connect() -> Redis:
c = Redis(
host=HOST,
port=PORT,
def connect() -> RedisClient:
c = RedisCluster(
startup_nodes=[ClusterNode(HOST, PORT)],
password=PASSWORD,
decode_responses=True,
socket_timeout=SOCKET_TIMEOUT,
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
socket_keepalive=True,
health_check_interval=HEALTH_CHECK_INTERVAL,
address_remap=_address_remap,
)
c.ping()
# Close on PING failure so retries don't leak ClusterNodes (AUTOGPT-SERVER-8T1).
try:
c.ping()
except Exception:
try:
c.close()
except Exception:
pass
raise
return c
@conn_retry("Redis", "Releasing connection")
def disconnect():
get_redis().close()
get_redis.cache_clear()
@cached(ttl_seconds=3600)
def get_redis() -> Redis:
def get_redis() -> RedisClient:
return connect()
@conn_retry("AsyncRedis", "Acquiring connection")
async def connect_async() -> AsyncRedis:
c = AsyncRedis(
host=HOST,
port=PORT,
async def connect_async() -> AsyncRedisClient:
c = AsyncRedisCluster(
startup_nodes=[AsyncClusterNode(HOST, PORT)],
password=PASSWORD,
decode_responses=True,
socket_timeout=SOCKET_TIMEOUT,
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
socket_keepalive=True,
health_check_interval=HEALTH_CHECK_INTERVAL,
address_remap=_address_remap,
)
await c.ping()
# Close on PING failure so retries don't leak ClusterNodes (AUTOGPT-SERVER-8V6/8V4/8V3).
try:
await c.ping()
except Exception:
try:
await c.close()
except Exception:
pass
raise
return c
# One AsyncRedisCluster per event loop: the client binds to the loop it was
# first awaited on, so a module-level singleton breaks across test loops.
_async_clients: dict[int, AsyncRedisCluster] = {}
@conn_retry("AsyncRedis", "Releasing connection")
async def disconnect_async():
c = await get_redis_async()
await c.close()
loop = asyncio.get_running_loop()
c = _async_clients.pop(id(loop), None)
if c is not None:
await c.close()
@thread_cached
async def get_redis_async() -> AsyncRedis:
return await connect_async()
async def get_redis_async() -> AsyncRedisClient:
loop = asyncio.get_running_loop()
client = _async_clients.get(id(loop))
if client is None:
client = await connect_async()
_async_clients[id(loop)] = client
return client
# Sharded pub/sub only delivers on the keyslot-owning shard; subscribers
# need a plain (Async)Redis connection pinned to that node.
def resolve_shard_for_channel(channel: str) -> tuple[str, int]:
"""Return the ``(host, port)`` of the shard that owns the channel's keyslot.
Applies the configured ``_address_remap`` so callers connect through the
same address the cluster client uses.
"""
cluster = get_redis()
node = cluster.get_node_from_key(channel)
if node is None:
raise RuntimeError(f"No cluster node owns the keyslot for channel {channel!r}")
return _address_remap((node.host, node.port))
@conn_retry("RedisShardedPubSub", "Acquiring connection")
def connect_sharded_pubsub(channel: str) -> Redis:
"""Open a plain ``Redis`` connection pinned to the channel's owning shard."""
host, port = resolve_shard_for_channel(channel)
# socket_timeout=None: pubsub reads block indefinitely; a spurious
# read timeout forces a reconnect whose PING races with subscribe-mode.
c = Redis(
host=host,
port=port,
password=PASSWORD,
decode_responses=True,
socket_timeout=None,
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
socket_keepalive=True,
health_check_interval=HEALTH_CHECK_INTERVAL,
)
try:
c.ping()
except Exception:
try:
c.close()
except Exception:
pass
raise
return c
@conn_retry("AsyncRedisShardedPubSub", "Acquiring connection")
async def connect_sharded_pubsub_async(channel: str) -> AsyncRedis:
"""Async variant of :func:`connect_sharded_pubsub`."""
host, port = resolve_shard_for_channel(channel)
# socket_timeout=None: see ``connect_sharded_pubsub``.
c = AsyncRedis(
host=host,
port=port,
password=PASSWORD,
decode_responses=True,
socket_timeout=None,
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
socket_keepalive=True,
health_check_interval=HEALTH_CHECK_INTERVAL,
)
try:
await c.ping()
except Exception:
try:
await c.close()
except Exception:
pass
raise
return c
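
Putting the helpers above together, a usage sketch (channel name illustrative; assumes redis-py's sharded pub/sub methods on a plain connection, which the test file below exercises per node): publishers SPUBLISH through the cluster client, while a subscriber SSUBSCRIBEs on a plain connection pinned to the keyslot-owning shard.

from backend.data.redis_client import connect_sharded_pubsub, get_redis

channel = "notifications/user-42"  # illustrative channel name

# Subscriber: plain Redis pinned to the shard that owns the channel's keyslot.
sub_conn = connect_sharded_pubsub(channel)
ps = sub_conn.pubsub()
ps.ssubscribe(channel)
ps.get_message(timeout=2.0)            # drain the ssubscribe confirmation

# Publisher: the cluster client routes SPUBLISH to the same keyslot owner.
assert get_redis().spublish(channel, "hello") >= 1

msg = ps.get_message(timeout=5.0)      # {'type': 'smessage', ..., 'data': 'hello'}
assert msg is not None and msg["data"] == "hello"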

View File

@@ -0,0 +1,599 @@
"""Unit tests for the cluster-only Redis client in ``redis_client``.
Patches the redis-py constructors + ``ping()`` so no real Redis is needed.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster
from redis.cluster import RedisCluster
import backend.data.redis_client as redis_client
@pytest.fixture(autouse=True)
def _reset_module_caches() -> None:
"""Flush cached singletons between tests so each test sees a fresh connect."""
redis_client.get_redis.cache_clear()
redis_client._async_clients.clear()
def test_connect_builds_redis_cluster() -> None:
with patch.object(redis_client, "RedisCluster", autospec=True) as mock_cluster:
mock_cluster.return_value = MagicMock(spec=RedisCluster)
client = redis_client.connect()
mock_cluster.assert_called_once()
kwargs = mock_cluster.call_args.kwargs
assert kwargs["password"] == redis_client.PASSWORD
assert kwargs["decode_responses"] is True
assert kwargs["socket_timeout"] == redis_client.SOCKET_TIMEOUT
assert kwargs["socket_connect_timeout"] == redis_client.SOCKET_CONNECT_TIMEOUT
assert kwargs["socket_keepalive"] is True
assert kwargs["health_check_interval"] == redis_client.HEALTH_CHECK_INTERVAL
assert kwargs["address_remap"] is redis_client._address_remap
startup = kwargs["startup_nodes"]
assert len(startup) == 1
# ClusterNode resolves "localhost" → "127.0.0.1" internally; both are
# valid representations of the configured host.
assert startup[0].host in {redis_client.HOST, "127.0.0.1"}
assert startup[0].port == redis_client.PORT
client.ping.assert_called_once()
def test_address_remap_pins_host_and_preserves_port() -> None:
"""Default remap rewrites announced shard host to the configured seed."""
with patch.object(redis_client, "USE_ANNOUNCED_ADDRESS", False):
assert redis_client._address_remap(("any-other-host", 6380)) == (
redis_client.HOST,
6380,
)
def test_address_remap_passthrough_when_use_announced_address() -> None:
"""When announced addresses resolve directly, remap leaves them alone."""
with patch.object(redis_client, "USE_ANNOUNCED_ADDRESS", True):
assert redis_client._address_remap(("redis-1", 17001)) == ("redis-1", 17001)
@pytest.mark.asyncio
async def test_connect_async_builds_async_redis_cluster() -> None:
with patch.object(redis_client, "AsyncRedisCluster", autospec=True) as mock_cluster:
fake = MagicMock(spec=AsyncRedisCluster)
fake.ping = AsyncMock()
mock_cluster.return_value = fake
client = await redis_client.connect_async()
mock_cluster.assert_called_once()
kwargs = mock_cluster.call_args.kwargs
assert kwargs["password"] == redis_client.PASSWORD
assert kwargs["decode_responses"] is True
assert kwargs["socket_timeout"] == redis_client.SOCKET_TIMEOUT
assert kwargs["socket_connect_timeout"] == redis_client.SOCKET_CONNECT_TIMEOUT
assert kwargs["socket_keepalive"] is True
assert kwargs["health_check_interval"] == redis_client.HEALTH_CHECK_INTERVAL
assert kwargs["address_remap"] is redis_client._address_remap
startup = kwargs["startup_nodes"]
assert len(startup) == 1
assert startup[0].host in {redis_client.HOST, "127.0.0.1"}
assert startup[0].port == redis_client.PORT
client.ping.assert_awaited_once()
def test_get_redis_caches_connect() -> None:
with patch.object(redis_client, "connect", autospec=True) as mock_connect:
mock_connect.return_value = MagicMock(spec=RedisCluster)
client_a = redis_client.get_redis()
client_b = redis_client.get_redis()
assert client_a is client_b
mock_connect.assert_called_once()
@pytest.mark.asyncio
async def test_get_redis_async_caches_connect() -> None:
with patch.object(redis_client, "connect_async", autospec=True) as mock_conn:
fake = MagicMock(spec=AsyncRedisCluster)
mock_conn.return_value = fake
a = await redis_client.get_redis_async()
b = await redis_client.get_redis_async()
assert a is b
mock_conn.assert_called_once()
def test_disconnect_closes_cached_client() -> None:
with patch.object(redis_client, "connect", autospec=True) as mock_connect:
fake = MagicMock(spec=RedisCluster)
mock_connect.return_value = fake
redis_client.get_redis()
redis_client.disconnect()
fake.close.assert_called_once()
@pytest.mark.asyncio
async def test_disconnect_async_closes_cached_client() -> None:
with patch.object(redis_client, "connect_async", autospec=True) as mock_connect:
fake = MagicMock(spec=AsyncRedisCluster)
fake.close = AsyncMock()
mock_connect.return_value = fake
await redis_client.get_redis_async()
await redis_client.disconnect_async()
fake.close.assert_awaited_once()
assert redis_client._async_clients == {}
@pytest.mark.asyncio
async def test_disconnect_async_no_cached_client_is_noop() -> None:
with patch.object(redis_client, "connect_async", autospec=True) as mock_connect:
await redis_client.disconnect_async()
mock_connect.assert_not_called()
# Sharded pub/sub end-to-end against the local 3-shard compose cluster.
# Skipped when no cluster is reachable so CI without docker doesn't flap.
def _has_live_cluster() -> bool:
try:
c = redis_client.connect()
except Exception: # noqa: BLE001 — any connect failure → skip the test
return False
try:
c.close()
except Exception:
pass
return True
@pytest.mark.skipif(
not _has_live_cluster(),
reason="local redis cluster not reachable; skip sharded pub/sub integration",
)
def test_sharded_pubsub_end_to_end_sync() -> None:
"""SPUBLISH → SSUBSCRIBE round-trip via the sync cluster client. Uses
per-node `get_message` because redis-py 6.x's
`ClusterPubSub.get_sharded_message(ignore_subscribe_messages=True)`
drops every message, not just the subscribe confirmation."""
redis_client.get_redis.cache_clear()
cluster = redis_client.get_redis()
channel = "pr12900:sharded-pubsub:integration"
ps = cluster.pubsub()
try:
ps.ssubscribe(channel)
assert cluster.spublish(channel, "hello") >= 1
# Exactly one node is subscribed (the keyslot owner); read from it.
assert len(ps.node_pubsub_mapping) == 1
node_ps = next(iter(ps.node_pubsub_mapping.values()))
# First message is the ssubscribe confirmation, second is our payload.
confirm = node_ps.get_message(timeout=2.0)
assert confirm is not None and confirm["type"] == "ssubscribe"
received = node_ps.get_message(timeout=5.0)
assert received is not None and received["type"] == "smessage"
assert received["data"] == "hello"
finally:
try:
ps.sunsubscribe(channel)
except Exception:
pass
ps.close()
redis_client.disconnect()
@pytest.mark.asyncio
@pytest.mark.skipif(
not _has_live_cluster(),
reason="local redis cluster not reachable; skip sharded pub/sub integration",
)
async def test_sharded_spublish_end_to_end_async() -> None:
"""Async cluster client routes SPUBLISH via ``execute_command``
because redis-py 6.x has no async ``spublish()`` wrapper."""
redis_client._async_clients.clear()
cluster = await redis_client.get_redis_async()
try:
res = await cluster.execute_command(
"SPUBLISH", "pr12900:sharded-pubsub:async", "ping"
)
# No subscribers — delivered count is 0, but the command must succeed
# (i.e. not raise MOVED/ASK or routing errors).
assert isinstance(res, int)
finally:
await redis_client.disconnect_async()
# ---------- Sharded pub/sub: unit tests with mocks ----------
def test_connect_sharded_pubsub_pins_host_and_disables_socket_timeout() -> None:
"""`socket_timeout=None` on the pubsub socket: a spurious read timeout
forces a reconnect whose PING races with subscribe-mode."""
with (
patch.object(
redis_client,
"resolve_shard_for_channel",
return_value=("shard-host", 7001),
),
patch.object(redis_client, "Redis", autospec=True) as mock_redis,
):
fake_client = MagicMock()
mock_redis.return_value = fake_client
client = redis_client.connect_sharded_pubsub("chan")
mock_redis.assert_called_once()
kwargs = mock_redis.call_args.kwargs
# Pinned to the shard's remapped address.
assert kwargs["host"] == "shard-host"
assert kwargs["port"] == 7001
# socket_timeout MUST be None for pubsub — see docstring in redis_client.py.
assert kwargs["socket_timeout"] is None
# Idle keepalive + health-check still intact.
assert kwargs["socket_keepalive"] is True
assert kwargs["health_check_interval"] == redis_client.HEALTH_CHECK_INTERVAL
# connect() must PING before returning.
client.ping.assert_called_once()
@pytest.mark.asyncio
async def test_connect_sharded_pubsub_async_disables_socket_timeout() -> None:
"""Async sibling of ``test_connect_sharded_pubsub_pins_host...``. Same
invariant: socket_timeout=None."""
with (
patch.object(
redis_client,
"resolve_shard_for_channel",
return_value=("shard-host", 7001),
),
patch.object(redis_client, "AsyncRedis", autospec=True) as mock_redis,
):
fake_client = MagicMock()
fake_client.ping = AsyncMock()
mock_redis.return_value = fake_client
client = await redis_client.connect_sharded_pubsub_async("chan")
kwargs = mock_redis.call_args.kwargs
assert kwargs["host"] == "shard-host"
assert kwargs["port"] == 7001
assert kwargs["socket_timeout"] is None
assert kwargs["socket_keepalive"] is True
assert kwargs["health_check_interval"] == redis_client.HEALTH_CHECK_INTERVAL
client.ping.assert_awaited_once()
def test_resolve_shard_for_channel_applies_address_remap() -> None:
"""The resolver must run ``_address_remap`` on the announced address so
callers connect through the same address the cluster client uses."""
cluster = MagicMock()
node = MagicMock()
node.host = "announced-host"
node.port = 17001
cluster.get_node_from_key.return_value = node
with (
patch.object(redis_client, "get_redis", return_value=cluster),
patch.object(redis_client, "USE_ANNOUNCED_ADDRESS", False),
):
host, port = redis_client.resolve_shard_for_channel("chan")
# Remap pins the host to the seed, keeps the announced port.
assert host == redis_client.HOST
assert port == 17001
def test_resolve_shard_for_channel_raises_when_no_node_owns_keyslot() -> None:
"""Missing cluster node → explicit RuntimeError, not a silent None deref."""
cluster = MagicMock()
cluster.get_node_from_key.return_value = None
with patch.object(redis_client, "get_redis", return_value=cluster):
with pytest.raises(RuntimeError, match="No cluster node"):
redis_client.resolve_shard_for_channel("chan")
def test_resolve_shard_for_channel_passthrough_with_announced_flag() -> None:
"""When REDIS_USE_ANNOUNCED_ADDRESS is on, resolver returns the announced
address verbatim — no HOST override."""
cluster = MagicMock()
node = MagicMock()
node.host = "redis-2"
node.port = 17002
cluster.get_node_from_key.return_value = node
with (
patch.object(redis_client, "get_redis", return_value=cluster),
patch.object(redis_client, "USE_ANNOUNCED_ADDRESS", True),
):
host, port = redis_client.resolve_shard_for_channel("chan")
assert (host, port) == ("redis-2", 17002)
def test_health_check_interval_is_30s_default() -> None:
"""Idle PING interval must be <=30s so half-open pubsub sockets don't
wait for the OS TCP keepalive (~2h)."""
assert redis_client.HEALTH_CHECK_INTERVAL <= 30
def test_connect_sets_health_check_interval() -> None:
"""The cluster client must propagate health_check_interval to each node
pool — otherwise idle cluster sockets go stale."""
with patch.object(redis_client, "RedisCluster", autospec=True) as mock_cluster:
mock_cluster.return_value = MagicMock(spec=RedisCluster)
redis_client.connect()
kwargs = mock_cluster.call_args.kwargs
assert kwargs["health_check_interval"] == redis_client.HEALTH_CHECK_INTERVAL
assert kwargs["health_check_interval"] > 0
# ---------- K8s same-port shard collapse regression (AUTOGPT-SERVER-8SX) ----------
def test_k8s_shard_collapse_with_announced_address_off_routes_all_to_seed() -> None:
"""In K8s every shard serves on port 6379 behind the seed service, so the
default `_address_remap` collapses all shards to `(HOST, 6379)` — the
AUTOGPT-SERVER-8SX bug. Fix: `REDIS_USE_ANNOUNCED_ADDRESS=true`."""
cluster = MagicMock()
# 3 shards, each owning a distinct hash slot, but every pod serves on
# 6379 in K8s — exactly the production topology.
nodes_by_channel = {
"{ch-a}/x": MagicMock(host="redis-cluster-redis-0", port=6379),
"{ch-b}/y": MagicMock(host="redis-cluster-redis-1", port=6379),
"{ch-c}/z": MagicMock(host="redis-cluster-redis-2", port=6379),
}
cluster.get_node_from_key.side_effect = lambda c: nodes_by_channel[c]
with (
patch.object(redis_client, "get_redis", return_value=cluster),
patch.object(redis_client, "USE_ANNOUNCED_ADDRESS", False),
patch.object(redis_client, "HOST", "redis-dev-seed"),
):
endpoints = {
channel: redis_client.resolve_shard_for_channel(channel)
for channel in nodes_by_channel
}
# The bug: every shard resolves to the same seed:port endpoint.
assert len(set(endpoints.values())) == 1, (
f"Expected the K8s shard-collapse bug, got {endpoints!r}. "
"If this test is failing it means _address_remap behaviour changed "
"and the AUTOGPT-SERVER-8SX regression note in this file needs review."
)
assert all(ep == ("redis-dev-seed", 6379) for ep in endpoints.values())
def test_k8s_shard_collapse_fixed_with_announced_address_on() -> None:
"""With `REDIS_USE_ANNOUNCED_ADDRESS=true`, each shard's announced FQDN
passes through, so distinct slots resolve to distinct endpoints."""
cluster = MagicMock()
nodes_by_channel = {
"{ch-a}/x": MagicMock(host="redis-cluster-redis-0", port=6379),
"{ch-b}/y": MagicMock(host="redis-cluster-redis-1", port=6379),
"{ch-c}/z": MagicMock(host="redis-cluster-redis-2", port=6379),
}
cluster.get_node_from_key.side_effect = lambda c: nodes_by_channel[c]
with (
patch.object(redis_client, "get_redis", return_value=cluster),
patch.object(redis_client, "USE_ANNOUNCED_ADDRESS", True),
patch.object(redis_client, "HOST", "redis-dev-seed"),
):
endpoints = {
channel: redis_client.resolve_shard_for_channel(channel)
for channel in nodes_by_channel
}
# Each shard maps to a distinct endpoint — sharded pubsub can route
# SSUBSCRIBE to the slot owner.
assert len(set(endpoints.values())) == 3
assert endpoints["{ch-a}/x"] == ("redis-cluster-redis-0", 6379)
assert endpoints["{ch-b}/y"] == ("redis-cluster-redis-1", 6379)
assert endpoints["{ch-c}/z"] == ("redis-cluster-redis-2", 6379)
def test_local_compose_remap_keeps_distinct_ports_per_shard() -> None:
"""Local docker-compose announces distinct ports per shard, so the
`(host, port)` tuple stays distinct even with `HOST` pinned to seed."""
cluster = MagicMock()
nodes_by_channel = {
"{ch-a}/x": MagicMock(host="redis-0", port=17000),
"{ch-b}/y": MagicMock(host="redis-1", port=17001),
"{ch-c}/z": MagicMock(host="redis-2", port=17002),
}
cluster.get_node_from_key.side_effect = lambda c: nodes_by_channel[c]
with (
patch.object(redis_client, "get_redis", return_value=cluster),
patch.object(redis_client, "USE_ANNOUNCED_ADDRESS", False),
patch.object(redis_client, "HOST", "localhost"),
):
endpoints = {
channel: redis_client.resolve_shard_for_channel(channel)
for channel in nodes_by_channel
}
# Distinct ports → distinct endpoints even after remap pins the host.
assert len(set(endpoints.values())) == 3
assert endpoints["{ch-a}/x"] == ("localhost", 17000)
assert endpoints["{ch-b}/y"] == ("localhost", 17001)
assert endpoints["{ch-c}/z"] == ("localhost", 17002)
# ---------- Sharded pub/sub: multi-shard integration on the live cluster ----------
def _channel_owner(channel: str) -> tuple[str, int]:
"""Resolve the slot owner for ``channel`` via the live cluster client."""
cluster = redis_client.get_redis()
node = cluster.get_node_from_key(channel)
assert node is not None, f"no slot owner for {channel!r}"
return node.host, node.port
def _channels_on_distinct_shards(n: int = 3) -> list[str]:
"""Build N hash-tagged channels each mapping to a distinct shard."""
seen: dict[tuple[str, int], str] = {}
for tag_id in range(2000):
chan = "{u" + str(tag_id) + "/g}/exec/e"
owner = _channel_owner(chan)
seen.setdefault(owner, chan)
if len(seen) >= n:
break
assert len(seen) >= n, f"could only cover {len(seen)} shards"
return list(seen.values())[:n]
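# Why the hash tag picks the shard: only the text inside the first "{...}" is
# hashed, so channels sharing a tag always map to the same slot, while distinct
# tags spread across the 16384 slots. A minimal sketch (assumes redis-py exposes
# the key_slot helper in redis.crc; the tests themselves don't use it):
#
#   from redis.crc import key_slot
#   assert key_slot(b"{u1/g}/exec/e") == key_slot(b"{u1/g}/exec/other")
#   # different tags usually land on different slots, which is what
#   # _channels_on_distinct_shards relies on when scanning tag_ids.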
@pytest.mark.skipif(
not _has_live_cluster(),
reason="local redis cluster not reachable; skip multi-shard integration",
)
def test_resolve_shard_for_channel_lands_on_distinct_shards() -> None:
"""3 hash-tagged channels resolve to 3 different shards (slot-distribution)."""
redis_client.get_redis.cache_clear()
try:
channels = _channels_on_distinct_shards(3)
endpoints = {ch: redis_client.resolve_shard_for_channel(ch) for ch in channels}
# Three channels → three distinct (host, port) endpoints.
assert len(set(endpoints.values())) == 3
finally:
redis_client.disconnect()
@pytest.mark.skipif(
not _has_live_cluster(),
reason="local redis cluster not reachable; skip multi-shard integration",
)
def test_sharded_pubsub_concurrent_subscribers_on_three_shards() -> None:
"""SSUBSCRIBE on three channels owned by three different shards, then
SPUBLISH to each — every payload must land on its subscriber."""
redis_client.get_redis.cache_clear()
cluster = redis_client.get_redis()
try:
channels = _channels_on_distinct_shards(3)
# Subscribe via the cluster client so redis-py's per-node pubsub
# mapping handles the sharded routing for us.
ps = cluster.pubsub()
try:
for ch in channels:
ps.ssubscribe(ch)
# The cluster client opens one node-pubsub per shard owner — three
# channels on three shards must produce three distinct node clients.
assert len(ps.node_pubsub_mapping) == 3, (
"Expected SSUBSCRIBE on 3 channels owned by 3 distinct shards "
f"to open 3 node-pubsubs, got {len(ps.node_pubsub_mapping)}"
)
# Publish to each channel and verify each reaches the right node.
for i, ch in enumerate(channels):
assert cluster.spublish(ch, f"payload-{i}") >= 1
# Drain ssubscribe confirmations + smessages from every node.
received: dict[str, str] = {}
for node_ps in ps.node_pubsub_mapping.values():
# First message per node is the ssubscribe confirm; subsequent
# smessages carry the test payloads.
for _ in range(4): # confirm + at most 1 payload per shard
msg = node_ps.get_message(timeout=2.0)
if msg is None:
break
if msg["type"] == "smessage":
received[msg["channel"]] = msg["data"]
for i, ch in enumerate(channels):
assert ch in received, f"channel {ch!r} got no message"
assert received[ch] == f"payload-{i}"
finally:
for ch in channels:
try:
ps.sunsubscribe(ch)
except Exception:
pass
ps.close()
finally:
redis_client.disconnect()
@pytest.mark.skipif(
not _has_live_cluster(),
reason="local redis cluster not reachable; skip multi-shard integration",
)
def test_sharded_pubsub_idle_subscriber_survives_health_check_window() -> None:
"""An SSUBSCRIBE connection must survive an idle window longer than
`HEALTH_CHECK_INTERVAL` — uses `+5s` to provoke at least one health check."""
import time as _time
redis_client.get_redis.cache_clear()
cluster = redis_client.get_redis()
channel = "{idle-test}/exec/e"
client = redis_client.connect_sharded_pubsub(channel)
ps = client.pubsub()
try:
ps.ssubscribe(channel)
confirm = ps.get_message(timeout=2.0)
assert confirm is not None and confirm["type"] == "ssubscribe"
# Idle window — must exceed health_check_interval at least once.
idle_seconds = redis_client.HEALTH_CHECK_INTERVAL + 5
_time.sleep(idle_seconds)
# After idling, publish + receive should still work.
assert cluster.spublish(channel, "post-idle") >= 1
msg = ps.get_message(timeout=5.0)
assert msg is not None and msg["type"] == "smessage"
assert msg["data"] == "post-idle"
finally:
try:
ps.sunsubscribe(channel)
except Exception:
pass
ps.close()
client.close()
redis_client.disconnect()
@pytest.mark.skipif(
not _has_live_cluster(),
reason="local redis cluster not reachable; skip multi-shard integration",
)
def test_sharded_pubsub_reconnect_after_forced_disconnect() -> None:
"""Subscriber reconnect after a forced disconnect — close socket, open
a fresh one, and verify new SPUBLISH events still arrive."""
redis_client.get_redis.cache_clear()
cluster = redis_client.get_redis()
channel = "{reconnect-test}/exec/e"
# Round 1: subscribe, receive one payload, then close everything.
client = redis_client.connect_sharded_pubsub(channel)
ps = client.pubsub()
try:
ps.ssubscribe(channel)
ps.get_message(timeout=2.0) # ssubscribe confirmation
assert cluster.spublish(channel, "before-restart") >= 1
msg = ps.get_message(timeout=5.0)
assert msg is not None and msg["data"] == "before-restart"
finally:
try:
ps.sunsubscribe(channel)
except Exception:
pass
ps.close()
client.close()
# Round 2: a fresh subscriber on the same channel — same routing,
# different socket. This exercises the reconnect-and-resubscribe path
# the conn_manager runs after a network blip.
client2 = redis_client.connect_sharded_pubsub(channel)
ps2 = client2.pubsub()
try:
ps2.ssubscribe(channel)
ps2.get_message(timeout=2.0)
assert cluster.spublish(channel, "after-restart") >= 1
msg = ps2.get_message(timeout=5.0)
assert msg is not None and msg["data"] == "after-restart"
finally:
try:
ps2.sunsubscribe(channel)
except Exception:
pass
ps2.close()
client2.close()
redis_client.disconnect()
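# How the pieces compose outside the tests: a minimal subscriber sketch built
# only from the helpers exercised above (handle() is a hypothetical callback;
# the real consumers live outside this file):
#
#   def listen(channel: str) -> None:
#       client = redis_client.connect_sharded_pubsub(channel)  # slot-owner conn
#       ps = client.pubsub()
#       ps.ssubscribe(channel)
#       try:
#           while True:
#               msg = ps.get_message(timeout=1.0)
#               if msg and msg["type"] == "smessage":
#                   handle(msg["data"])
#       finally:
#           ps.sunsubscribe(channel)
#           ps.close()
#           client.close()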

View File

@@ -22,8 +22,7 @@ this module can cover.
from typing import Any, cast
-from redis import Redis
-from redis.asyncio import Redis as AsyncRedis
+from backend.data.redis_client import AsyncRedisClient, RedisClient
# ---------------------------------------------------------------------------
# Lua scripts — registered centrally so there is exactly ONE authoritative
@@ -47,9 +46,30 @@ end
return 0
"""
# Push to a capped list only when a hash field currently matches the expected
# value. Returns the new list length, or -1 when the guard fails.
#
# KEYS[1] hash key
# KEYS[2] list key
# ARGV[1] hash field
# ARGV[2] expected current value
# ARGV[3] list value
# ARGV[4] max list length
# ARGV[5] list TTL seconds
_GATED_CAPPED_RPUSH_LUA = """
local current = redis.call('HGET', KEYS[1], ARGV[1])
if current ~= ARGV[2] then
return -1
end
redis.call('RPUSH', KEYS[2], ARGV[3])
redis.call('LTRIM', KEYS[2], -tonumber(ARGV[4]), -1)
redis.call('EXPIRE', KEYS[2], tonumber(ARGV[5]))
return redis.call('LLEN', KEYS[2])
"""
async def incr_with_ttl(
-redis: AsyncRedis,
+redis: AsyncRedisClient,
key: str,
ttl_seconds: int,
*,
@@ -85,7 +105,7 @@ async def incr_with_ttl(
def incr_with_ttl_sync(
-redis: Redis,
+redis: RedisClient,
key: str,
ttl_seconds: int,
*,
@@ -103,7 +123,7 @@ def incr_with_ttl_sync(
async def capped_rpush(
-redis: AsyncRedis,
+redis: AsyncRedisClient,
key: str,
value: str,
*,
@@ -129,8 +149,42 @@ async def capped_rpush(
return int(results[-1])
async def capped_rpush_if_hash_field(
redis: AsyncRedisClient,
*,
hash_key: str,
hash_field: str,
expected: str,
list_key: str,
value: str,
max_len: int,
ttl_seconds: int,
) -> int | None:
"""Atomically RPUSH to a bounded list iff a hash field matches.
Returns the new list length when the push happens, or ``None`` when the
hash field does not currently match ``expected``.
"""
result = await cast(
"Any",
redis.eval(
_GATED_CAPPED_RPUSH_LUA,
2,
hash_key,
list_key,
hash_field,
expected,
value,
str(max_len),
str(ttl_seconds),
),
)
length = int(result)
return None if length < 0 else length
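# Minimal usage sketch (hypothetical key names; note both keys share the same
# "{...}" hash tag so the two-key script stays single-slot on a cluster):
#
#   pushed = await capped_rpush_if_hash_field(
#       redis,
#       hash_key="{session:abc}:meta",
#       hash_field="status",
#       expected="running",
#       list_key="{session:abc}:buffer",
#       value=chunk_json,  # e.g. a serialized stream chunk
#       max_len=500,
#       ttl_seconds=3600,
#   )
#   if pushed is None:
#       pass  # guard failed: the hash field no longer says "running", chunk dropped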
async def hash_compare_and_set(
-redis: AsyncRedis,
+redis: AsyncRedisClient,
key: str,
field: str,
*,

View File

@@ -11,6 +11,7 @@ import pytest
from backend.data.redis_helpers import (
capped_rpush,
capped_rpush_if_hash_field,
hash_compare_and_set,
incr_with_ttl,
incr_with_ttl_sync,
@@ -56,7 +57,17 @@ class _Fake:
return len(self.lists.get(key, []))
async def eval(self, script: str, numkeys: int, *args: Any) -> int:
# Shim for the module's Lua scripts: numkeys == 2 emulates the gated
# capped rpush; the fallthrough below emulates the hash-CAS script.
if numkeys == 2:
hash_key, list_key = args[0], args[1]
field, expected, value, max_len, ttl_seconds = args[2:7]
h = self.hashes.setdefault(hash_key, {})
if h.get(field) != expected:
return -1
await self.rpush(list_key, value)
await self.ltrim(list_key, -int(max_len), -1)
await self.expire(list_key, int(ttl_seconds))
return await self.llen(list_key)
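# (The branch above mirrors _GATED_CAPPED_RPUSH_LUA step for step: guard,
#  RPUSH, LTRIM, EXPIRE, LLEN, so the wrapper can be tested without Redis.)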
key, field, expected, new = args[0], args[1], args[2], args[3]
h = self.hashes.setdefault(key, {})
if h.get(field) == expected:
@@ -198,6 +209,50 @@ async def test_capped_rpush_first_push_returns_one() -> None:
assert r.lists["buf"] == ["only"]
# ── capped_rpush_if_hash_field ────────────────────────────────────────
@pytest.mark.asyncio
async def test_capped_rpush_if_hash_field_pushes_when_expected_matches() -> None:
r = _Fake()
r.hashes["meta"] = {"status": "running"}
length = await capped_rpush_if_hash_field(
r, # type: ignore[arg-type]
hash_key="meta",
hash_field="status",
expected="running",
list_key="buf",
value="only",
max_len=10,
ttl_seconds=60,
)
assert length == 1
assert r.lists["buf"] == ["only"]
assert r.ttls["buf"] == 60
@pytest.mark.asyncio
async def test_capped_rpush_if_hash_field_skips_when_expected_differs() -> None:
r = _Fake()
r.hashes["meta"] = {"status": "completed"}
length = await capped_rpush_if_hash_field(
r, # type: ignore[arg-type]
hash_key="meta",
hash_field="status",
expected="running",
list_key="buf",
value="lost",
max_len=10,
ttl_seconds=60,
)
assert length is None
assert "buf" not in r.lists
# ── hash_compare_and_set ───────────────────────────────────────────────

View File

@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
settings = Settings()
# Cache decorator alias for consistent user lookup caching
-cache_user_lookup = cached(maxsize=1000, ttl_seconds=300)
+cache_user_lookup = cached(maxsize=1000, ttl_seconds=300, shared_cache=True)
@cache_user_lookup
@@ -509,8 +509,15 @@ async def update_user_timezone(user_id: str, timezone: str) -> User:
if not user:
raise ValueError(f"User not found with ID: {user_id}")
-# Invalidate cache for this user
+# Invalidate user caches so subsequent reads see the new timezone.
+# get_user_by_id and get_user_by_email are keyed by a single value
+# and can be deleted surgically; get_or_create_user is keyed by the
+# JWT-payload dict so we can't delete a single entry — clear it
+# entirely.
get_user_by_id.cache_delete(user_id)
if user.email:
get_user_by_email.cache_delete(user.email)
get_or_create_user.cache_clear()
return User.from_db(user)
except Exception as e:

View File

@@ -0,0 +1,66 @@
"""Unit tests for helpers in backend.data.user."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.data import user as user_module
from backend.data.user import update_user_timezone
from backend.util.exceptions import DatabaseError
class TestUpdateUserTimezone:
@pytest.mark.asyncio
async def test_invalidates_all_three_user_caches(self):
prisma_user = MagicMock(id="user-1", email="user@example.com")
sentinel_user = MagicMock()
with (
patch.object(user_module, "PrismaUser") as mock_prisma_user,
patch.object(user_module.User, "from_db", return_value=sentinel_user),
patch.object(user_module.get_user_by_id, "cache_delete") as by_id_del,
patch.object(user_module.get_user_by_email, "cache_delete") as by_email_del,
patch.object(user_module.get_or_create_user, "cache_clear") as goc_clear,
):
mock_prisma_user.prisma.return_value.update = AsyncMock(
return_value=prisma_user
)
result = await update_user_timezone("user-1", "Europe/London")
assert result is sentinel_user
by_id_del.assert_called_once_with("user-1")
by_email_del.assert_called_once_with("user@example.com")
goc_clear.assert_called_once_with()
@pytest.mark.asyncio
async def test_skips_email_cache_invalidation_when_email_missing(self):
prisma_user = MagicMock(id="user-1", email=None)
sentinel_user = MagicMock()
with (
patch.object(user_module, "PrismaUser") as mock_prisma_user,
patch.object(user_module.User, "from_db", return_value=sentinel_user),
patch.object(user_module.get_user_by_id, "cache_delete") as by_id_del,
patch.object(user_module.get_user_by_email, "cache_delete") as by_email_del,
patch.object(user_module.get_or_create_user, "cache_clear") as goc_clear,
):
mock_prisma_user.prisma.return_value.update = AsyncMock(
return_value=prisma_user
)
await update_user_timezone("user-1", "Europe/London")
by_id_del.assert_called_once_with("user-1")
by_email_del.assert_not_called()
goc_clear.assert_called_once_with()
@pytest.mark.asyncio
async def test_wraps_prisma_errors_in_database_error(self):
with patch.object(user_module, "PrismaUser") as mock_prisma_user:
mock_prisma_user.prisma.return_value.update = AsyncMock(
side_effect=RuntimeError("connection lost")
)
with pytest.raises(DatabaseError) as exc:
await update_user_timezone("user-1", "Europe/London")
assert "user-1" in str(exc.value)
assert "connection lost" in str(exc.value)

View File

@@ -64,9 +64,12 @@ async def clear_insufficient_funds_notifications(user_id: str) -> int:
redis_client = await redis.get_redis_async()
pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
keys = [key async for key in redis_client.scan_iter(match=pattern)]
-if keys:
-return await redis_client.delete(*keys)
-return 0
+# Keys here span multiple graph IDs and therefore multiple cluster
+# slots — a bulk DELETE would raise CROSSSLOT, so delete per key.
+deleted = 0
+for key in keys:
+deleted += await redis_client.delete(key)
+return deleted
except Exception as e:
logger.warning(
f"Failed to clear insufficient funds notification flags for user "

View File

@@ -7,16 +7,12 @@ import time
from typing import TYPE_CHECKING, Any, cast
if TYPE_CHECKING:
-from redis import Redis
-from redis.asyncio import Redis as AsyncRedis
+from backend.data.redis_client import AsyncRedisClient, RedisClient
logger = logging.getLogger(__name__)
-# Lua CAS release: only delete the key if the stored value still matches our
-# owner_id. Returns 1 on delete, 0 on no-op. This makes release() safe against
-# the race where an external caller (e.g. mark_session_completed's force-release)
-# deletes our key and a new owner acquires it before our release() fires — without
-# the CAS guard, release() would wipe the successor's valid lock.
+# CAS release: DEL only when the stored owner still matches — guards against
+# wiping a successor's lock after an external force-release.
_RELEASE_LUA = (
"if redis.call('get', KEYS[1]) == ARGV[1] then "
"return redis.call('del', KEYS[1]) "
@@ -27,7 +23,9 @@ _RELEASE_LUA = (
class ClusterLock:
"""Simple Redis-based distributed lock for preventing duplicate execution."""
-def __init__(self, redis: "Redis", key: str, owner_id: str, timeout: int = 300):
+def __init__(
+self, redis: "RedisClient", key: str, owner_id: str, timeout: int = 300
+):
self.redis = redis
self.key = key
self.owner_id = owner_id
@@ -150,7 +148,7 @@ class AsyncClusterLock:
"""Async Redis-based distributed lock for preventing duplicate execution."""
def __init__(
self, redis: "AsyncRedis", key: str, owner_id: str, timeout: int = 300
self, redis: "AsyncRedisClient", key: str, owner_id: str, timeout: int = 300
):
self.redis = redis
self.key = key

Some files were not shown because too many files have changed in this diff.