Compare commits

...

14 Commits

Author SHA1 Message Date
Nicholas Tindle
351001fdca fix(frontend/copilot): keep artifact sidebar alive on bad HTML artifacts
Two defensive fixes so one misbehaving artifact can't take down the chat
sidebar:

1. Intercept fragment-link clicks inside artifact iframes. srcdoc iframes
   with `sandbox="allow-scripts"` (no `allow-same-origin`) resolve
   `<a href="#x">` against the parent's URL, so clicking a TOC anchor in
   AI-generated HTML was navigating the copilot page itself to
   `/copilot?sessionId=...#activation` and crashing it. A small click-capture
   script injected alongside the Tailwind CDN now preventDefaults fragment
   clicks and scrolls the local target into view.

2. Wrap the artifact renderer in an ArtifactErrorBoundary so any future
   render-time throw surfaces as a visible, copyable error instead of
   tearing down the whole panel. The fallback exposes a "Copy error
   details" button that puts the artifact title, type, and stack on the
   clipboard for the user to paste back to the agent.

Regression coverage at every injection site: the srcdoc for HTML artifacts
(ArtifactContent), the inline HTMLRenderer, and React artifacts
(buildReactArtifactSrcDoc). The interceptor logic itself is exercised
against real DOM in iframe-sandbox-csp.test.ts.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-17 15:47:17 -05:00
Joe Munene
3a01874911 fix(frontend/builder): preserve agent name in AgentExecutor node title after reload (#12805)
## Summary

Fixes #11041

When an `AgentExecutorBlock` is placed in the builder, it initially
displays the agent's name (e.g., "Researcher v2"). After saving and
reloading the page, the title reverts to the generic "Agent Executor."

## Root Cause

The backend correctly persists `agent_name` and `graph_version` in
`hardcodedValues` (via `input_default` in `AgentExecutorBlock`).
However, `NodeHeader.tsx` always resolves the display title from
`data.title` (the generic block name), ignoring the persisted agent
name.

## Fix

Modified the title resolution chain in `NodeHeader.tsx` to check
`data.hardcodedValues.agent_name` between the user's custom name and the
generic block title:

1. `data.metadata.customized_name` (user's manual rename) — highest
priority
2. `agent_name` + ` v{graph_version}` from `hardcodedValues` — **new**
3. `data.title` (generic block name) — fallback

This is a frontend-only change. No backend modifications needed.
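The resolution chain reduces to a plain three-step fallback. A minimal sketch (in Python for illustration — the actual change is in `NodeHeader.tsx`, and the parameter names here are assumptions, not the real props):

```python
def resolve_node_title(customized_name, agent_name, graph_version, block_title):
    """Priority chain for the node display title, per the fix description."""
    # 1. User's manual rename (data.metadata.customized_name) wins outright.
    if customized_name:
        return customized_name
    # 2. NEW: persisted agent name + version from hardcodedValues.
    if agent_name:
        return f"{agent_name} v{graph_version}"
    # 3. Generic block name (data.title) as the final fallback.
    return block_title
```

For example, an AgentExecutorBlock persisted with `agent_name="Researcher"` and `graph_version=2` would reload as "Researcher v2" instead of "Agent Executor".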

## Files Changed

- `autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx` (+11, -1)

## Test Plan

- [x] Place an AgentExecutorBlock, select an agent — title shows agent
name
- [x] Save graph, reload page — title still shows agent name (was "Agent
Executor" before)
- [x] Double-click to rename — custom name takes priority over agent
name
- [x] Clear custom name — falls back to agent name
- [x] Non-AgentExecutor blocks — unaffected, show generic title as
before

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-04-17 15:20:32 +00:00
Zamil Majdy
6d770d9917 fix(platform/copilot): revert forward pagination, add visibility guarantee for blank chat (#12831)
## Why / What / How

**Why:** PR #12796 changed completed copilot sessions to load messages
from sequence 0 forward (ascending), which broke the standard chat UX —
users now land at the beginning of the conversation instead of the most
recent messages. Reported in Discord.

**What:** Reverts the forward pagination approach and replaces it with a
visibility guarantee that ensures every page contains at least one
user/assistant message.

**How:**
- **Backend**: Removed after_sequence, from_start, forward_paginated,
newest_sequence — always use backward (newest-first) pagination. Added
_expand_for_visibility() helper: after fetching, if the entire page is
tool messages (invisible in UI), expand backward up to 200 messages
until a visible user/assistant message is found.
- **Frontend**: Removed all forwardPaginated/newestSequence plumbing
from hooks and components. Removed bottom LoadMoreSentinel. Simplified
message merge to always prepend paged messages.
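The `_expand_for_visibility()` behavior described above can be sketched roughly as follows (the message shape and the `fetch_older` callback are assumptions standing in for the real DB page fetch):

```python
VISIBLE_ROLES = {"user", "assistant"}
MAX_EXPANSION = 200  # cap on extra messages pulled in, per the description

def expand_for_visibility(page, fetch_older, max_expansion=MAX_EXPANSION):
    """Expand a page backward until it contains at least one visible message.

    `page` is a list of message dicts ordered oldest-to-newest;
    `fetch_older(before_sequence, limit)` returns up to `limit` messages
    preceding `before_sequence`, also oldest-to-newest.
    """
    messages = list(page)
    fetched = 0
    # A page of only tool messages renders as blank in the UI, so keep
    # pulling older messages until a user/assistant message appears.
    while not any(m["role"] in VISIBLE_ROLES for m in messages):
        if fetched >= max_expansion or not messages:
            break
        older = fetch_older(messages[0]["sequence"], limit=50)
        if not older:
            break  # reached the start of the session
        fetched += len(older)
        messages = older + messages
    return messages
```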

### Changes
- routes.py: Reverted to simple backward pagination, removed TOCTOU
re-fetch logic
- db.py: Removed forward mode, extracted _expand_tool_boundary() and
added _expand_for_visibility()
- SessionDetailResponse: Removed newest_sequence and forward_paginated
fields
- openapi.json: Removed after_sequence param and forward pagination
response fields
- Frontend hooks/components: Removed forward pagination props and logic
(-1000 lines)
- Updated all tests (backend: 63 pass, frontend: 1517 pass)

### Checklist
- [x] I have clearly listed my changes in the PR description
- [x] Backend unit tests: 63 pass
- [x] Frontend unit tests: 1517 pass
- [x] Frontend lint + types: clean
- [x] Backend format + pyright: clean
2026-04-17 19:23:28 +07:00
slepybear
334ec18c31 docs: convert in-code comments to MkDocs admonitions in block-sdk-gui… (#12819)
### Why / What / How

This PR converts inline Python comments in code examples within
`block-sdk-guide.md` into MkDocs `!!! note` admonitions. This makes code
examples cleaner and more copy-paste friendly while preserving all
explanatory content.

Converts inline comments in code blocks to admonitions following the
pattern established in PR #12396 (new_blocks.md) and PR #12313.

- Wrapped code examples with `!!! note` admonitions
- Removed inline comments from code blocks for clean copy-paste
- Added explanatory admonitions after each code block

### Changes 🏗️

- Provider configuration examples (API key and OAuth)
- Block class Input/Output schema annotations
- Block initialization parameters
- Test configuration
- OAuth and webhook handler implementations
- Authentication types and file handling patterns

### Checklist 📋

#### For documentation changes:
- [x] Follows the admonition pattern from PR #12396
- [x] No code changes, documentation only
- [x] Admonition syntax verified correct

#### For configuration changes:
- [ ] `.env.default` is updated or already compatible with my changes
- [ ] `docker-compose.yml` is updated or already compatible with my
changes

---

**Related Issues**: Closes #8946

Co-authored-by: slepybear <slepybear@users.noreply.github.com>
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-04-17 07:47:52 +00:00
slepybear
ea5cfdfa2e fix(frontend): remove debug console.log statements (#12823)
## Why
Debug console.log statements were left in production code, which can
leak
sensitive information and pollute browser developer consoles.

## What
Removed console.log from 4 non-legacy frontend components:
- useNavbar.ts: isLoggedIn debug log
- WalletRefill.tsx: autoRefillForm debug log  
- EditAgentForm.tsx: category field debug log
- TimezoneForm.tsx: currentTimezone debug log

## How
Simply deleted the console.log lines as they served no purpose 
other than debugging during development.

## Checklist
- [x] Code follows project conventions
- [x] Only frontend changes (4 files, 6 lines removed)
- [x] No functionality changes

Co-authored-by: slepybear <slepybear@users.noreply.github.com>
2026-04-17 07:31:51 +00:00
Ubbe
d13a85bef7 feat(frontend): surface scheduled agents in library & copilot briefings (#12818)
## Why

Scheduled agents weren't well-surfaced in the Library and Copilot
briefings:

- The Library fleet summary didn't count agents that are scheduled
purely via the scheduler (only those with a `recommended_schedule_cron`
set at the agent level).
- Sitrep items didn't distinguish scheduled or listening (trigger-based)
agents, so they often fell back to a generic "idle" state.
- Scheduled chips showed a generic message with no indication of when
the next run would happen.
- The Copilot Agent Briefing surfaced every scheduled agent regardless
of how far out the next run was — an agent scheduled a month away would
take a slot from something actually happening soon.
- Long sitrep messages overflowed the row.

## What

- Add `is_scheduled` to `LibraryAgent` (sourced from the scheduler) so
the frontend can reliably detect schedule-only agents.
- Count scheduled agents in `useLibraryFleetSummary`.
- Include scheduled and listening agents in sitrep items, with a
priority ordering (error → running → stale → success → listening →
scheduled → idle).
- Show a relative next-run time on scheduled sitrep chips (e.g.
"Scheduled to run in 2h" / "in 3d").
- Filter the Copilot Agent Briefing to scheduled agents whose next run
is within the next 3 days.
- Truncate long sitrep messages to 1 line with `OverflowText` and show
the full text in a tooltip on hover.

## How

- Scheduler → `LibraryAgent` mapping populates `is_scheduled` /
`next_scheduled_run`.
- `useSitrepItems` gains an optional `scheduledWithinMs` parameter.
Copilot's `usePulseChips` passes `3 * 24 * 60 * 60 * 1000`; the Library
briefing omits it to keep its existing (unbounded) behavior.
- Scheduled config-based sitrep items are skipped when
`next_scheduled_run` is missing or outside the window.
- `SitrepItem` wraps the message in `OverflowText` so a single-line
ellipsis + hover tooltip replaces raw overflow.

## Test plan

- [ ] `/library` — scheduled and listening agents appear in the sitrep
with accurate copy; fleet summary counts scheduled agents correctly;
long messages truncate with a tooltip on hover.
- [ ] `/copilot` — on an empty session with the `AGENT_BRIEFING` flag
on, the briefing only shows scheduled agents whose next run is within 3
days; agents scheduled further out no longer appear as "scheduled"
chips.
- [ ] Scheduled chip text reads "Scheduled to run in {Nm|Nh|Nd}"
matching `next_scheduled_run`.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-17 14:36:15 +07:00
Zamil Majdy
60b85640e7 fix(backend/copilot): replace dedup lock with idempotent append_and_save_message (#12814)
## Why

The Redis dedup lock (`chat:msg_dedup:{session}:{content_hash}`, 30s
TTL) was solving the wrong problem:

- Its purpose: block infra/nginx retries from calling
`append_and_save_message` twice after a client disconnect, writing a
duplicate user message to the DB.
- The approach: deliberately hold the lock for 30s on `GeneratorExit`.
- Why unnecessary: the executor's cluster lock already prevents
duplicate *execution*. The only real gap was duplicate *DB writes* in
the ~1s before the executor picks up the turn.

## What

- **Deleted** `message_dedup.py` and `message_dedup_test.py` (~150 lines
removed).
- **Removed** all dedup lock code from `routes.py` (~40 lines removed).
- **`append_and_save_message`** is now idempotent and self-contained:
- Uses redis-py's built-in `Lock(timeout=10, blocking_timeout=2)` —
Lua-script atomic acquire/release, no manual poll/sleep loop.
- Lock context manager yields `bool` (`True` = acquired, `False` =
degraded). When degraded (Redis down or 2s timeout), reads from DB
directly instead of cache to avoid stale-state duplicates.
- Idempotency check: if `session.messages[-1]` already matches the
incoming role+content, returns `None` instead of the session.
- Lock released explicitly as soon as the write completes; `try/except`
in `finally` so a cleanup error after a successful write never surfaces
a false 500.
- On cache-write failure, the stale cache entry is invalidated so future
reads fall back to the authoritative DB.
- **`routes.py`** uses the `None` signal: `is_duplicate_message = (await
append_and_save_message(...)) is None`
- Skips `create_session` and `enqueue_copilot_turn` for duplicates —
client re-attaches to the existing turn's Redis stream.
- `track_user_message` and `turn_id` generation only happen when
`is_duplicate_message` is false.
- **`subscribe_to_session`** retry window increased from 1×50ms to
3×100ms — covers the window where a duplicate request subscribes before
the original's `create_session` hset completes.
- **Cleaned up** `routes_test.py`: removed 5 dedup-specific tests and
the `mock_redis` setup from `_mock_stream_internals`; added
duplicate-skips-enqueue test.

## How

The idempotency guard distinguishes legit same-text messages from
retries via the **assistant turn between them**: if the user said "yes",
got a response, and says "yes" again, `session.messages[-1]` is the
assistant reply, so the role check fails and the second message goes
through. A retry (no response yet) sees the user message as the last
entry and is blocked.

```python
if (
    session.messages
    and session.messages[-1].role == message.role
    and session.messages[-1].content == message.content
):
    return None  # duplicate — caller skips enqueue
```

The Redis lock ensures this check always sees authoritative state even
in multi-replica deployments. When the lock is unavailable (Redis down
or contention), reading from DB directly (bypassing potentially stale
cache) provides the same safety guarantee at the cost of a DB
round-trip.
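A lock context manager that yields a bool instead of raising might look like the following sketch (assumes redis-py's asyncio client; names and error handling are illustrative, not the actual implementation):

```python
import contextlib

@contextlib.asynccontextmanager
async def session_lock(redis_client, session_id):
    """Yield True when the lock is held, False when degraded.

    True  -> cached session state is trustworthy.
    False -> Redis down or 2s contention timeout; caller should read
             from the DB directly to avoid stale-state duplicates.
    """
    lock = redis_client.lock(
        f"chat:session:{session_id}", timeout=10, blocking_timeout=2
    )
    acquired = False
    try:
        try:
            acquired = await lock.acquire()
        except Exception:
            acquired = False  # Redis unavailable -> degraded mode
        yield acquired
    finally:
        if acquired:
            try:
                await lock.release()
            except Exception:
                pass  # never surface a cleanup error after a successful write
```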

## Checklist

- [x] PR targets `dev`
- [x] Conventional commit title with scope
- [x] Tests added/updated (duplicate detection, lock degradation, DB
error, cache invalidation paths)
- [x] `poetry run format` and `poetry run pyright` pass clean
- [x] No new linter suppressors
2026-04-16 22:12:30 +07:00
Zamil Majdy
87e4d42750 fix(backend/copilot): fix initial load missing messages + forward pagination for completed sessions (#12796)
### Why / What / How

**Why:** Completed copilot sessions with many messages showed a
completely empty chat view. A user reported a 158-message session that
appeared blank on reload.

**What:** Two bugs fixed:
1. **Backend** — initial page load always returned the newest 50
messages in DESC order. For sessions heavy in tool calls, the user's
original messages (seq 0–5) were never included; all 50 slots consumed
by mid-session tool outputs.
2. **Frontend** — convertChatSessionToUiMessages silently dropped user
messages with null/empty content.

**How:** For completed sessions (no active stream), the backend now
loads from sequence 0 in ASC order. Active/streaming sessions keep
newest-first for streaming context. A new after_sequence forward cursor
enables infinite-scroll for subsequent pages (sentinel moves to bottom).
The frontend wires forward_paginated + newest_sequence end-to-end.

### Changes 🏗️

- db.py: added from_start (ASC) and after_sequence (forward cursor)
modes; added newest_sequence to PaginatedMessages
- routes.py: detect completed vs active on initial load; pass
from_start=True for completed; expose newest_sequence +
forward_paginated; accept after_sequence param
- convertChatSessionToUiMessages.ts: never drop user messages with empty
content
- useLoadMoreMessages.ts: forward pagination via after_sequence; append
pages to end
- ChatMessagesContainer.tsx: LoadMoreSentinel at bottom for
forward-paginated sessions
- Wire newestSequence + forwardPaginated end-to-end through
useChatSession/useCopilotPage/ChatContainer
- openapi.json: add after_sequence + newest_sequence/forward_paginated;
regenerate types
- db_test.py: 9 new unit tests for from_start and after_sequence modes

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Open a completed session with many messages — first user message
visible on initial load
- [x] Scroll to bottom of completed session — load more appends next
page
- [x] Open active/streaming session — newest messages shown first,
streaming unaffected
  - [x] Backend unit tests: all 28 pass
  - [x] Frontend lint/format: clean, no new type errors

---------

Co-authored-by: chernistry <73943355+chernistry@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2026-04-16 14:16:54 +00:00
Ubbe
0339d95d12 fix(frontend): small UI fixes, sort menu bg, name update auth, stats grid overflow, pulse chips (#12815)
## Summary
- **LibrarySortMenu / AgentFilterMenu**: Force `!bg-transparent` and
neutralise legacy `SelectTrigger` styles (`m-0.5`, `ring-offset-white`,
`shadow-sm`) that caused a white background around the trigger
- **EditNameDialog**: Replace client-side `supabase.auth.updateUser()`
with server-side `PUT /api/auth/user` route — fixes "Auth session
missing!" error caused by `httpOnly` cookies being inaccessible to
browser JS
- **StatsGrid**: Swap label `Text` for `OverflowText` so tile labels
truncate with `…` and show a tooltip instead of wrapping when the grid
is squeezed
- **PulseChips**: Set fixed `15rem` chip width with `shrink-0`,
horizontal scroll, and styled thin scrollbar
- **Tests**: Updated `EditNameDialog` tests to use MSW instead of
mocking Supabase client; added 7 new `PulseChips` integration tests

## Test plan
- [x] `pnpm test:unit` — all 1495 tests pass (91 files)
- [x] `pnpm format && pnpm lint` — clean
- [x] `pnpm types` — no new errors (pre-existing only)
- [ ] QA `/library?sort=updatedAt` — sort menu trigger has no white bg
- [ ] QA `/library` — StatsGrid labels truncate with tooltip on narrow
viewports
- [ ] QA `/copilot` — PulseChips scroll horizontally at fixed width
- [ ] QA `/copilot` — Edit name dialog saves successfully (no "Auth
session missing!")

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-16 20:11:21 +07:00
Toran Bruce Richards
f410929560 feat(platform): Add xAI Grok 4.20 models from OpenRouter (#12620)
Requested by @Torantulino

Adds the 2 xAI Grok 4.20 models available on OpenRouter that are missing
from the platform.

## Why

`x-ai/grok-4.20` and `x-ai/grok-4.20-multi-agent` are xAI's current
flagship models (released March 2026) and are available via OpenRouter,
but weren't accessible from the platform's LLM blocks.

## Changes

**`autogpt_platform/backend/backend/blocks/llm.py`**
- Added `GROK_4_20` and `GROK_4_20_MULTI_AGENT` enum members
- Added corresponding `MODEL_METADATA` entries (open_router provider, 2M
context window, price tier 3)

**`autogpt_platform/backend/backend/data/block_cost_config.py`**
- Added `MODEL_COST` entries at 5 credits each (flagship tier, $2/M in)

**`docs/integrations/block-integrations/llm.md`**
- Added new model IDs to all LLM block tables

| Model | Pricing | Context |
|-------|---------|---------|
| `x-ai/grok-4.20` | $2/M in, $6/M out | 2M |
| `x-ai/grok-4.20-multi-agent` | $2/M in, $6/M out | 2M |

Both models use the standard OpenRouter chat completions API — no
special handling needed.
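The `llm.py` additions plausibly take this shape (a sketch based on the description above — the enum and metadata structures here are assumptions, not the actual file contents):

```python
from enum import Enum

class LlmModel(str, Enum):
    # ...existing members elided...
    GROK_4_20 = "x-ai/grok-4.20"
    GROK_4_20_MULTI_AGENT = "x-ai/grok-4.20-multi-agent"

# Illustrative shape: provider + 2M context window per the PR description.
MODEL_METADATA = {
    LlmModel.GROK_4_20: {"provider": "open_router", "context_window": 2_000_000},
    LlmModel.GROK_4_20_MULTI_AGENT: {
        "provider": "open_router",
        "context_window": 2_000_000,
    },
}
```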

Resolves: SECRT-2196

---------

Co-authored-by: Torantulino <22963551+Torantulino@users.noreply.github.com>
Co-authored-by: Toran Bruce Richards <Torantulino@users.noreply.github.com>
Co-authored-by: Otto (AGPT) <otto@agpt.co>
2026-04-16 12:14:56 +00:00
Zamil Majdy
2bbec09e1a feat(platform): subscription tier billing via Stripe Checkout (#12727)
## Why

Introducing paid subscription tiers (PRO, BUSINESS) so we can charge for
AutoPilot capacity beyond the free tier. Without a billing integration,
all users share the same rate limits regardless of their willingness to
pay for additional capacity.

## What

End-to-end subscription billing system using Stripe Checkout Sessions:

**Backend:**
- `SubscriptionTier` enum (`FREE`, `PRO`, `BUSINESS`, `ENTERPRISE`) on
the `User` model
- `POST /credits/subscription` — creates a Stripe Checkout Session for
paid upgrades; for FREE tier or when `ENABLE_PLATFORM_PAYMENT` is off,
sets tier directly
- `GET /credits/subscription` — returns current tier, monthly cost
(cents), and all tier costs
- `POST /credits/stripe_webhook` — handles
`customer.subscription.created/updated/deleted`,
`checkout.session.completed`, `charge.dispute.*`, `refund.created`
- `sync_subscription_from_stripe()` — keeps `User.subscriptionTier` in
sync from webhook events; guards against out-of-order delivery
(cancelled event after new sub created), ENTERPRISE overwrite, and
duplicate webhook replay
- Open-redirect protection on `success_url`/`cancel_url` via
`_validate_checkout_redirect_url()`
- `_cancel_customer_subscriptions()` — cancels both active and trialing
subs; propagates errors so callers can avoid updating DB tier on Stripe
failure
- `_cleanup_stale_subscriptions()` — best-effort cancellation of old
subs when a new one becomes active (paid-to-paid upgrade), to prevent
double-billing
- `get_stripe_customer_id()` with idempotency key to prevent duplicate
Stripe customers on concurrent requests
- `cache_none=False` sentinel fix in `@cached` decorator so Stripe price
lookups retry on transient error instead of poisoning the cache with
`None`
- Stripe Price IDs read from LaunchDarkly (`stripe-price-id-pro`,
`stripe-price-id-business`). If not configured, upgrade returns 422.
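An open-redirect guard like `_validate_checkout_redirect_url()` can be sketched as an allow-list check on scheme and host (the allowed host here is an assumption; the real list would come from config):

```python
from urllib.parse import urlparse

ALLOWED_HOSTS = {"platform.agpt.co"}  # assumption — real allow-list is config-driven

def validate_checkout_redirect_url(url: str) -> bool:
    """Reject anything that is not an absolute https URL on an allowed host."""
    parsed = urlparse(url)
    if parsed.scheme != "https":
        return False  # blocks javascript:, data:, scheme-relative //evil.com
    if not parsed.netloc or parsed.netloc not in ALLOWED_HOSTS:
        return False  # blocks attacker-controlled hosts
    return True
```

The route would return 422 when this check fails, so Stripe never redirects the user to an attacker-chosen URL after checkout.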

**Frontend:**
- `SubscriptionTierSection` component on billing page: tier cards
(FREE/PRO/BUSINESS), upgrade/downgrade buttons, per-tier cost display,
Stripe redirect on upgrade
- Confirmation dialog for downgrades
- ENTERPRISE users see a read-only admin-managed banner
- Success toast on return from Stripe Checkout (`?subscription=success`)
- Uses generated `useGetSubscriptionStatus` /
`useUpdateSubscriptionTier` hooks

## How

- Paid upgrades use Stripe Checkout Sessions (not server-side
subscription creation) — Stripe handles PCI-compliant card collection
and the subscription lifecycle
- Tier is synced back via webhook on
`customer.subscription.created/updated/deleted`
- Downgrade to FREE cancels via Stripe API immediately; a
`stripe.StripeError` during cancellation returns 502 with a generic
message (no Stripe detail leakage)
- LaunchDarkly flags: `stripe-price-id-pro` (string),
`stripe-price-id-business` (string), `enable-platform-payment` (bool)
- `ENABLE_PLATFORM_PAYMENT=false` bypasses Stripe for beta/internal
access (sets tier directly)

## Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] `ENABLE_PLATFORM_PAYMENT=false` → tier change updates directly, no
Stripe redirect
- [x] `ENABLE_PLATFORM_PAYMENT=true` with price IDs configured → paid
upgrade redirects to Stripe Checkout
- [x] Stripe webhook `customer.subscription.created` →
`User.subscriptionTier` updated
  - [x] Unrecognised price ID in webhook → logs warning, tier unchanged
  - [x] ENTERPRISE user webhook event → tier not overwritten
  - [x] Empty `STRIPE_WEBHOOK_SECRET` → 503 (prevents HMAC bypass)
  - [x] Open-redirect attack on `success_url`/`cancel_url` → 422

#### For configuration changes:
- [x] No `.env` or `docker-compose.yml` changes
- [x] LaunchDarkly flags to create: `stripe-price-id-pro` (string),
`stripe-price-id-business` (string), `enable-platform-payment` (bool)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: majdyz <majdy.zamil@gmail.com>
2026-04-16 17:52:06 +07:00
Ubbe
31b88a6e56 feat(frontend): add Agent Briefing Panel (#12764)
## Summary

<img width="800" height="772" alt="Screenshot_2026-04-13_at_18 29 19"
src="https://github.com/user-attachments/assets/3da6eaf2-1485-4c08-9651-18f2f4220eba"
/>
<img width="800" height="285" alt="Screenshot_2026-04-13_at_18 29 24"
src="https://github.com/user-attachments/assets/6a5f981a-1e1d-4d22-a33d-9e1b0e7555a7"
/>
<img width="800" height="288" alt="Screenshot_2026-04-13_at_18 29 27"
src="https://github.com/user-attachments/assets/f97b4611-7c23-4fc9-a12d-edf6314a77ef"
/>
<img width="800" height="433" alt="Screenshot_2026-04-13_at_18 29 31"
src="https://github.com/user-attachments/assets/e6d7241d-84f3-4936-b8cd-e0b12df392bb"
/>
<img width="700" height="554" alt="Screenshot_2026-04-13_at_18 29 40"
src="https://github.com/user-attachments/assets/92c08f21-f950-45cd-8c1d-529905a6e85f"
/>


Implements the Agent Intelligence Layer — real-time agent awareness
across the Library and Copilot pages.

### Core Features
- **Agent Briefing Panel** — stats grid with fleet-wide counts (running,
recently completed, needs attention, scheduled, idle, monthly spend) and
tab-driven content below
- **Enhanced Library Cards** — StatusBadge, run counts, contextual
action buttons (See tasks, Start, Chat) with consistent icon-left
styling
- **Situation Report Items** — prioritized sitrep with error-first
ranking, "See task" deep-links for completed runs, and "Ask AutoPilot"
bridge
- **Home Pulse Chips** — agent status chips on Copilot empty state with
hover-reveal actions (slide-up animation + backdrop blur on desktop,
always visible on touch)
- **Edit Display Name** — pencil icon on Copilot greeting to update
Supabase user metadata inline

### Backend
- **Execution count API** — batch `COUNT(*)` query on
`AgentGraphExecution` grouped by `agentGraphId` for the current user,
avoiding loading full execution rows. Wired into `list_library_agents`
and `list_favorite_library_agents` via `execution_count_override` on
`LibraryAgent.from_db()`
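The batch-count idea is one grouped query instead of loading execution rows per agent, then a merge with a zero default. A sketch (table/column names and the merge helper are assumptions):

```python
# One grouped COUNT(*) over the user's executions, restricted to the
# graph IDs on the current library page.
BATCH_COUNT_SQL = """
SELECT "agentGraphId", COUNT(*) AS execution_count
FROM "AgentGraphExecution"
WHERE "userId" = $1 AND "agentGraphId" = ANY($2)
GROUP BY "agentGraphId"
"""

def merge_execution_counts(agents, count_rows):
    """Attach counts to library agents, defaulting to 0 when absent."""
    counts = {row["agentGraphId"]: row["execution_count"] for row in count_rows}
    for agent in agents:
        agent["execution_count"] = counts.get(agent["graph_id"], 0)
    return agents
```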

### UI Polish
- Subtler gradient on AgentBriefingPanel (reduced opacity on background
+ animated border)
- Consistent button styles across all action buttons (icon-left, same
sizing)
- Removed duplicate "Open in builder" menu item (kept "Edit agent")
- "Recently completed" tab replaces "Listening" in briefing panel,
showing agents with completed runs in last 72h

## Changes

### Backend
- `backend/api/features/library/db.py` — added
`_fetch_execution_counts()` batch COUNT query, wired into list endpoints
- `backend/api/features/library/model.py` — added
`execution_count_override` param to `LibraryAgent.from_db()`

### Frontend — New files
- `EditNameDialog/EditNameDialog.tsx` — modal to update display name via
Supabase auth
- `PulseChips/PulseChips.module.css` — hover-reveal animation + glass
panel styles

### Frontend — Modified files
- `EmptySession.tsx` — added EditNameDialog and PulseChips
- `PulseChips.tsx` — redesigned with See/Ask buttons, hover overlay on
desktop
- `usePulseChips.ts` — added agentID for deep-linking
- `AgentBriefingPanel.tsx` — subtler gradient, adjusted padding
- `AgentBriefingPanel.module.css` — reduced conic gradient opacity
- `BriefingTabContent.tsx` — added "completed" tab routing
- `StatsGrid.tsx` — replaced Listening with Recently completed,
reordered tabs
- `SitrepItem.tsx` — consistent button styles, "See task" link for
completed items, updated copilot prompt
- `ContextualActionButton.tsx` — icon-left, smaller icon, renamed Run to
Start
- `LibraryAgentCard.tsx` — icon-left on all buttons, EyeIcon for See
tasks
- `AgentCardMenu.tsx` — removed duplicate "Open in builder"
- `useAgentStatus.ts` — added completed count to FleetSummary
- `useLibraryFleetSummary.ts` — added recent completion tracking
- `types.ts` — added `completed` to FleetSummary and AgentStatusFilter

## Test plan
- [ ] Library page renders Agent Briefing Panel with stats grid
- [ ] "Recently completed" tab shows agents with completed runs in last
72h
- [ ] Agent cards show real execution counts (not 0)
- [ ] Action buttons have consistent styling with icon on the left
- [ ] "See task" on completed items deep-links to agent page with
execution selected
- [ ] "Ask AutoPilot" generates last-run-specific prompt for completed
items
- [ ] Copilot empty state shows PulseChips with hover-reveal actions on
desktop
- [ ] PulseChips show See/Ask buttons always on touch screens
- [ ] Pencil icon on greeting opens edit name dialog
- [ ] Name update persists via Supabase and refreshes greeting
- [ ] `pnpm format && pnpm lint && pnpm types` pass
- [ ] `poetry run format` passes for backend changes

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: John Ababseh <jababseh7@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Bentlybro <Github@bentlybro.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: CodeRabbit <noreply@coderabbit.ai>
Co-authored-by: majdyz <zamil.majdy@agpt.co>
2026-04-16 17:32:17 +07:00
Zamil Majdy
d357956d98 refactor(backend/copilot): make session-file helper fns public to fix Pyright warnings (#12812)
## Why
After PR #12804 was squashed into dev, two module-level helper functions
in `backend/copilot/sdk/service.py` remained private (`_`-prefixed)
while being directly imported by name in `sdk/transcript_test.py`.
Pyright reports `reportAttributeAccessIssue` when tests (even those
excluded from CI lint) import private symbols from outside their
defining module.

## What
Rename two helpers to remove the underscore prefix:
- `_process_cli_restore` → `process_cli_restore`
- `_read_cli_session_from_disk` → `read_cli_session_from_disk`

Update call sites in `service.py` and imports/calls/docstrings in
`sdk/transcript_test.py`.

## How
Pure rename — no logic change. Both functions were already module-level
helpers with no reason to be private; the underscore was convention
carried over during the refactor but they are directly unit-tested and
should be public.

All 66 `sdk/transcript_test.py` tests pass after the rename.

## Checklist
- [x] Tests pass (`poetry run pytest
backend/copilot/sdk/transcript_test.py`)
- [x] No `_`-prefixed symbols imported across module boundaries
- [x] No linter suppressors added
2026-04-16 17:00:02 +07:00
Zamil Majdy
697ffa81f0 fix(backend/copilot): update transcript_test to use strip_for_upload after upload_cli_session removal 2026-04-16 16:17:02 +07:00
122 changed files with 8985 additions and 1649 deletions


@@ -18,7 +18,6 @@ from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode
from backend.copilot.db import get_chat_messages_paginated
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.message_dedup import acquire_dedup_lock
from backend.copilot.model import (
ChatMessage,
ChatSession,
@@ -463,22 +462,13 @@ async def get_session(
Supports cursor-based pagination via ``limit`` and ``before_sequence``.
When no pagination params are provided, returns the most recent messages.
Args:
session_id: The unique identifier for the desired chat session.
user_id: The authenticated user's ID.
limit: Maximum number of messages to return (1-200, default 50).
before_sequence: Return messages with sequence < this value (cursor).
Returns:
SessionDetailResponse: Details for the requested session, including
active_stream info and pagination metadata.
"""
page = await get_chat_messages_paginated(
session_id, limit, before_sequence, user_id=user_id
)
if page is None:
raise NotFoundError(f"Session {session_id} not found.")
messages = [
_strip_injected_context(message.model_dump()) for message in page.messages
]
@@ -489,10 +479,6 @@ async def get_session(
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_session:
active_stream_info = ActiveStreamInfo(
turn_id=active_session.turn_id,
@@ -846,9 +832,6 @@ async def stream_chat_post(
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
sanitized_file_ids: list[str] | None = None
# Capture the original message text BEFORE any mutation (attachment enrichment)
# so the idempotency hash is stable across retries.
original_message = request.message
if request.file_ids and user_id:
# Filter to valid UUIDs only to prevent DB abuse
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
@@ -877,58 +860,36 @@ async def stream_chat_post(
)
request.message += files_block
# ── Idempotency guard ────────────────────────────────────────────────────
# Blocks duplicate executor tasks from concurrent/retried POSTs.
# See backend/copilot/message_dedup.py for the full lifecycle description.
dedup_lock = None
if request.is_user_message:
dedup_lock = await acquire_dedup_lock(
session_id, original_message, sanitized_file_ids
)
if dedup_lock is None and (original_message or sanitized_file_ids):
async def _empty_sse() -> AsyncGenerator[str, None]:
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return StreamingResponse(
_empty_sse(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
"Connection": "keep-alive",
"x-vercel-ai-ui-message-stream": "v1",
},
)
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message re-fetches inside a lock to prevent
# message loss from concurrent requests.
#
# If any of these operations raises, release the dedup lock before propagating
# so subsequent retries are not blocked for 30 s.
try:
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
if request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
# saved yet. append_and_save_message returns None when a duplicate is
# detected — in that case skip enqueue to avoid processing the message twice.
is_duplicate_message = False
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
is_duplicate_message = (
await append_and_save_message(session_id, message)
) is None
logger.info(f"[STREAM] User message saved for session {session_id}")
if not is_duplicate_message and request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
# Create a task in the stream registry for reconnection support
# Create a task in the stream registry for reconnection support.
# For duplicate messages, skip create_session entirely so the infra-retry
# client subscribes to the *existing* turn's Redis stream and receives the
# in-progress executor output rather than an empty stream.
turn_id = ""
if not is_duplicate_message:
turn_id = str(uuid4())
log_meta["turn_id"] = turn_id
session_create_start = time.perf_counter()
await stream_registry.create_session(
session_id=session_id,
@@ -946,7 +907,6 @@ async def stream_chat_post(
}
},
)
await enqueue_copilot_turn(
session_id=session_id,
user_id=user_id,
@@ -958,10 +918,10 @@ async def stream_chat_post(
mode=request.mode,
model=request.model,
)
except Exception:
if dedup_lock:
await dedup_lock.release()
raise
else:
logger.info(
f"[STREAM] Duplicate message detected for session {session_id}, skipping enqueue"
)
setup_time = (time.perf_counter() - stream_start_time) * 1000
logger.info(
@@ -985,12 +945,6 @@ async def stream_chat_post(
subscriber_queue = None
first_chunk_yielded = False
chunks_yielded = 0
# True for every exit path except GeneratorExit (client disconnect).
# On disconnect the backend turn is still running — releasing the lock
# there would reopen the infra-retry duplicate window. The 30 s TTL
# is the fallback. All other exits (normal finish, early return, error)
# should release so the user can re-send the same message.
release_dedup_lock_on_exit = True
try:
# Subscribe from the position we captured before enqueuing
# This avoids replaying old messages while catching all new ones
@@ -1002,7 +956,7 @@ async def stream_chat_post(
if subscriber_queue is None:
yield StreamFinish().to_sse()
return # finally releases dedup_lock
return
# Read from the subscriber queue and yield to SSE
logger.info(
@@ -1044,7 +998,7 @@ async def stream_chat_post(
}
},
)
break # finally releases dedup_lock
break
except asyncio.TimeoutError:
yield StreamHeartbeat().to_sse()
@@ -1060,7 +1014,6 @@ async def stream_chat_post(
}
},
)
release_dedup_lock_on_exit = False
except Exception as e:
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
logger.error(
@@ -1075,10 +1028,7 @@ async def stream_chat_post(
code="stream_error",
).to_sse()
yield StreamFinish().to_sse()
# finally releases dedup_lock
finally:
if dedup_lock and release_dedup_lock_on_exit:
await dedup_lock.release()
# Unsubscribe when client disconnects or stream ends
if subscriber_queue is not None:
try:

View File

@@ -133,21 +133,12 @@ def test_stream_chat_rejects_too_many_file_ids():
assert response.status_code == 422
def _mock_stream_internals(
mocker: pytest_mock.MockerFixture,
*,
redis_set_returns: object = True,
):
def _mock_stream_internals(mocker: pytest_mock.MockerFixture):
"""Mock the async internals of stream_chat_post so tests can exercise
validation and enrichment logic without needing Redis/RabbitMQ.
Args:
redis_set_returns: Value returned by the mocked Redis ``set`` call.
``True`` (default) simulates a fresh key (new message);
``None`` simulates a collision (duplicate blocked).
validation and enrichment logic without needing RabbitMQ.
Returns:
A namespace with ``redis``, ``save``, and ``enqueue`` mock objects so
A namespace with ``save`` and ``enqueue`` mock objects so
callers can make additional assertions about side-effects.
"""
import types
@@ -158,7 +149,7 @@ def _mock_stream_internals(
)
mock_save = mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
return_value=MagicMock(), # non-None = message was saved (not a duplicate)
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = mocker.AsyncMock(return_value=None)
@@ -174,15 +165,9 @@ def _mock_stream_internals(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
mock_redis = AsyncMock()
mock_redis.set = AsyncMock(return_value=redis_set_returns)
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=AsyncMock,
return_value=mock_redis,
)
ns = types.SimpleNamespace(redis=mock_redis, save=mock_save, enqueue=mock_enqueue)
return ns
return types.SimpleNamespace(
save=mock_save, enqueue=mock_enqueue, registry=mock_registry
)
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
@@ -211,6 +196,29 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
assert response.status_code == 200
# ─── Duplicate message dedup ──────────────────────────────────────────
def test_stream_chat_skips_enqueue_for_duplicate_message(
mocker: pytest_mock.MockerFixture,
):
"""When append_and_save_message returns None (duplicate detected),
enqueue_copilot_turn and stream_registry.create_session must NOT be called
to avoid double-processing and to prevent overwriting the active stream's
turn_id in Redis (which would cause reconnecting clients to miss the response)."""
mocks = _mock_stream_internals(mocker)
# Override save to return None — signalling a duplicate
mocks.save.return_value = None
response = client.post(
"/sessions/sess-1/stream",
json={"message": "hello"},
)
assert response.status_code == 200
mocks.enqueue.assert_not_called()
mocks.registry.create_session.assert_not_called()
# ─── UUID format filtering ─────────────────────────────────────────────
@@ -706,237 +714,6 @@ class TestStripInjectedContext:
assert result["content"] == "hello"
# ─── Idempotency / duplicate-POST guard ──────────────────────────────
def test_stream_chat_blocks_duplicate_post_returns_empty_sse(
mocker: pytest_mock.MockerFixture,
) -> None:
"""A second POST with the same message within the 30-s window must return
an empty SSE stream (StreamFinish + [DONE]) so the frontend marks the
turn complete without creating a ghost response."""
# redis_set_returns=None simulates a collision: the NX key already exists.
ns = _mock_stream_internals(mocker, redis_set_returns=None)
response = client.post(
"/sessions/sess-dup/stream",
json={"message": "duplicate message", "is_user_message": True},
)
assert response.status_code == 200
body = response.text
# The response must contain StreamFinish (type=finish) and the SSE [DONE] terminator.
assert '"finish"' in body
assert "[DONE]" in body
# The empty SSE response must include the AI SDK protocol header so the
# frontend treats it as a valid stream and marks the turn complete.
assert response.headers.get("x-vercel-ai-ui-message-stream") == "v1"
# The duplicate guard must prevent save/enqueue side effects.
ns.save.assert_not_called()
ns.enqueue.assert_not_called()
def test_stream_chat_first_post_proceeds_normally(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The first POST (Redis NX key set successfully) must proceed through the
normal streaming path — no early return."""
ns = _mock_stream_internals(mocker, redis_set_returns=True)
response = client.post(
"/sessions/sess-new/stream",
json={"message": "first message", "is_user_message": True},
)
assert response.status_code == 200
# Redis set must have been called once with the NX flag.
ns.redis.set.assert_called_once()
call_kwargs = ns.redis.set.call_args
assert call_kwargs.kwargs.get("nx") is True
def test_stream_chat_dedup_skipped_for_non_user_messages(
mocker: pytest_mock.MockerFixture,
) -> None:
"""System/assistant messages (is_user_message=False) bypass the dedup
guard — they are injected programmatically and must always be processed."""
ns = _mock_stream_internals(mocker, redis_set_returns=None)
response = client.post(
"/sessions/sess-sys/stream",
json={"message": "system context", "is_user_message": False},
)
# Even though redis_set_returns=None (would block a user message),
# the endpoint must proceed because is_user_message=False.
assert response.status_code == 200
ns.redis.set.assert_not_called()
def test_stream_chat_dedup_hash_uses_original_message_not_mutated(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The dedup hash must be computed from the original request message,
not the mutated version that has the [Attached files] block appended.
A file_id is sent so the route actually appends the [Attached files] block,
exercising the mutation path — the hash must still match the original text."""
import hashlib
ns = _mock_stream_internals(mocker, redis_set_returns=True)
file_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
# Mock workspace + prisma so the attachment block is actually appended.
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
return_value=type("W", (), {"id": "ws-1"})(),
)
fake_file = type(
"F",
(),
{
"id": file_id,
"name": "doc.pdf",
"mimeType": "application/pdf",
"sizeBytes": 1024,
},
)()
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[fake_file])
mocker.patch(
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
response = client.post(
"/sessions/sess-hash/stream",
json={
"message": "plain message",
"is_user_message": True,
"file_ids": [file_id],
},
)
assert response.status_code == 200
ns.redis.set.assert_called_once()
call_args = ns.redis.set.call_args
dedup_key = call_args.args[0]
# Hash must use the original message + sorted file IDs, not the mutated text.
expected_hash = hashlib.sha256(
f"sess-hash:plain message:{file_id}".encode()
).hexdigest()[:16]
expected_key = f"chat:msg_dedup:sess-hash:{expected_hash}"
assert dedup_key == expected_key, (
f"Dedup key {dedup_key!r} does not match expected {expected_key!r}; "
"hash may be using mutated message or wrong inputs"
)
def test_stream_chat_dedup_key_released_after_stream_finish(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The dedup Redis key must be deleted after the turn completes (when
subscriber_queue is None the route yields StreamFinish immediately and
should release the key so the user can re-send the same message)."""
from unittest.mock import AsyncMock as _AsyncMock
# Set up all internals manually so we can control subscribe_to_session.
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.enqueue_copilot_turn",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = _AsyncMock(return_value=None)
# None → early-finish path: StreamFinish yielded immediately, dedup key released.
mock_registry.subscribe_to_session = _AsyncMock(return_value=None)
mocker.patch(
"backend.api.features.chat.routes.stream_registry",
mock_registry,
)
mock_redis = mocker.AsyncMock()
mock_redis.set = _AsyncMock(return_value=True)
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=_AsyncMock,
return_value=mock_redis,
)
response = client.post(
"/sessions/sess-finish/stream",
json={"message": "hello", "is_user_message": True},
)
assert response.status_code == 200
body = response.text
assert '"finish"' in body
# The dedup key must be released so intentional re-sends are allowed.
mock_redis.delete.assert_called_once()
def test_stream_chat_dedup_key_released_even_when_redis_delete_raises(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The route must not crash when the dedup Redis delete fails on the
subscriber_queue-is-None early-finish path (except Exception: pass)."""
from unittest.mock import AsyncMock as _AsyncMock
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.enqueue_copilot_turn",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = _AsyncMock(return_value=None)
mock_registry.subscribe_to_session = _AsyncMock(return_value=None)
mocker.patch(
"backend.api.features.chat.routes.stream_registry",
mock_registry,
)
mock_redis = mocker.AsyncMock()
mock_redis.set = _AsyncMock(return_value=True)
# Make the delete raise so the except-pass branch is exercised.
mock_redis.delete = _AsyncMock(side_effect=RuntimeError("redis gone"))
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=_AsyncMock,
return_value=mock_redis,
)
# Should not raise even though delete fails.
response = client.post(
"/sessions/sess-finish-err/stream",
json={"message": "hello", "is_user_message": True},
)
assert response.status_code == 200
assert '"finish"' in response.text
# delete must have been attempted — the except-pass branch silenced the error.
mock_redis.delete.assert_called_once()
# ─── DELETE /sessions/{id}/stream — disconnect listeners ──────────────
@@ -980,3 +757,59 @@ def test_disconnect_stream_returns_404_when_session_missing(
assert response.status_code == 404
mock_disconnect.assert_not_awaited()
# ─── GET /sessions/{session_id} — backward pagination ─────────────────────────
def _make_paginated_messages(
mocker: pytest_mock.MockerFixture, *, has_more: bool = False
):
"""Return a mock PaginatedMessages and configure the DB patch."""
from datetime import UTC, datetime
from backend.copilot.db import PaginatedMessages
from backend.copilot.model import ChatMessage, ChatSessionInfo, ChatSessionMetadata
now = datetime.now(UTC)
session_info = ChatSessionInfo(
session_id="sess-1",
user_id=TEST_USER_ID,
usage=[],
started_at=now,
updated_at=now,
metadata=ChatSessionMetadata(),
)
page = PaginatedMessages(
messages=[ChatMessage(role="user", content="hello", sequence=0)],
has_more=has_more,
oldest_sequence=0,
session=session_info,
)
mock_paginate = mocker.patch(
"backend.api.features.chat.routes.get_chat_messages_paginated",
new_callable=AsyncMock,
return_value=page,
)
return page, mock_paginate
def test_get_session_returns_backward_paginated(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""All sessions use backward (newest-first) pagination."""
_make_paginated_messages(mocker)
mocker.patch(
"backend.api.features.chat.routes.stream_registry.get_active_session",
new_callable=AsyncMock,
return_value=(None, None),
)
response = client.get("/sessions/sess-1")
assert response.status_code == 200
data = response.json()
assert data["oldest_sequence"] == 0
assert "forward_paginated" not in data
assert "newest_sequence" not in data

View File

@@ -12,6 +12,7 @@ import prisma.models
import backend.api.features.library.model as library_model
import backend.data.graph as graph_db
from backend.api.features.library.db import _fetch_schedule_info
from backend.data.graph import GraphModel, GraphSettings
from backend.data.includes import library_agent_include
from backend.util.exceptions import NotFoundError
@@ -117,4 +118,5 @@ async def add_graph_to_library(
f"for store listing version #{store_listing_version_id} "
f"to library for user #{user_id}"
)
return library_model.LibraryAgent.from_db(added_agent)
schedule_info = await _fetch_schedule_info(user_id, graph_id=graph_model.id)
return library_model.LibraryAgent.from_db(added_agent, schedule_info=schedule_info)

View File

@@ -21,13 +21,17 @@ async def test_add_graph_to_library_create_new_agent() -> None:
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
return_value=converted_agent,
) as mock_from_db,
patch(
"backend.api.features.library._add_to_library._fetch_schedule_info",
new=AsyncMock(return_value={}),
),
):
mock_prisma.return_value.create = AsyncMock(return_value=created_agent)
result = await add_graph_to_library("slv-id", graph_model, "user-id")
assert result is converted_agent
mock_from_db.assert_called_once_with(created_agent)
mock_from_db.assert_called_once_with(created_agent, schedule_info={})
# Verify create was called with correct data
create_call = mock_prisma.return_value.create.call_args
create_data = create_call.kwargs["data"]
@@ -54,6 +58,10 @@ async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
return_value=converted_agent,
) as mock_from_db,
patch(
"backend.api.features.library._add_to_library._fetch_schedule_info",
new=AsyncMock(return_value={}),
),
):
mock_prisma.return_value.create = AsyncMock(
side_effect=prisma.errors.UniqueViolationError(
@@ -65,7 +73,7 @@ async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
result = await add_graph_to_library("slv-id", graph_model, "user-id")
assert result is converted_agent
mock_from_db.assert_called_once_with(updated_agent)
mock_from_db.assert_called_once_with(updated_agent, schedule_info={})
# Verify update was called with correct where and data
update_call = mock_prisma.return_value.update.call_args
assert update_call.kwargs["where"] == {

View File

@@ -1,6 +1,7 @@
import asyncio
import itertools
import logging
from datetime import datetime, timezone
from typing import Literal, Optional
import fastapi
@@ -43,6 +44,65 @@ config = Config()
integration_creds_manager = IntegrationCredentialsManager()
async def _fetch_execution_counts(user_id: str, graph_ids: list[str]) -> dict[str, int]:
"""Fetch execution counts per graph in a single batched query."""
if not graph_ids:
return {}
rows = await prisma.models.AgentGraphExecution.prisma().group_by(
by=["agentGraphId"],
where={
"userId": user_id,
"agentGraphId": {"in": graph_ids},
"isDeleted": False,
},
count=True,
)
return {
row["agentGraphId"]: int((row.get("_count") or {}).get("_all") or 0)
for row in rows
}
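The `group_by` reduction in `_fetch_execution_counts` is pure and can be exercised in isolation. A standalone sketch (the helper name is hypothetical; the row shape matches the `test_fetch_execution_counts_uses_group_by` fixture below):

```python
def counts_from_group_by(rows: list[dict]) -> dict[str, int]:
    # Missing or null "_count" buckets collapse to 0, mirroring
    # (row.get("_count") or {}).get("_all") or 0 in _fetch_execution_counts.
    return {
        row["agentGraphId"]: int((row.get("_count") or {}).get("_all") or 0)
        for row in rows
    }
```

Graphs with no execution rows simply never appear in the result, which is why callers use `execution_counts.get(agent.agentGraphId)` rather than indexing.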
async def _fetch_schedule_info(
user_id: str, graph_id: Optional[str] = None
) -> dict[str, str]:
"""Fetch a map of graph_id → earliest next_run_time ISO string.
When `graph_id` is provided, the scheduler query is narrowed to that graph,
which is cheaper for single-agent lookups (detail page, post-update, etc.).
"""
try:
scheduler_client = get_scheduler_client()
schedules = await scheduler_client.get_execution_schedules(
graph_id=graph_id,
user_id=user_id,
)
earliest: dict[str, tuple[datetime, str]] = {}
for s in schedules:
parsed = _parse_iso_datetime(s.next_run_time)
if parsed is None:
continue
current = earliest.get(s.graph_id)
if current is None or parsed < current[0]:
earliest[s.graph_id] = (parsed, s.next_run_time)
return {graph_id: iso for graph_id, (_, iso) in earliest.items()}
except Exception:
logger.warning("Failed to fetch schedules for library agents", exc_info=True)
return {}
def _parse_iso_datetime(value: str) -> Optional[datetime]:
"""Parse an ISO 8601 datetime, tolerating `Z` and naive forms (assumed UTC)."""
try:
parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError:
logger.warning("Failed to parse schedule next_run_time: %s", value)
return None
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
return parsed
async def list_library_agents(
user_id: str,
search_term: Optional[str] = None,
@@ -137,12 +197,22 @@ async def list_library_agents(
logger.debug(f"Retrieved {len(library_agents)} library agents for user #{user_id}")
graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
execution_counts, schedule_info = await asyncio.gather(
_fetch_execution_counts(user_id, graph_ids),
_fetch_schedule_info(user_id),
)
# Only pass valid agents to the response
valid_library_agents: list[library_model.LibraryAgent] = []
for agent in library_agents:
try:
library_agent = library_model.LibraryAgent.from_db(agent)
library_agent = library_model.LibraryAgent.from_db(
agent,
execution_count_override=execution_counts.get(agent.agentGraphId),
schedule_info=schedule_info,
)
valid_library_agents.append(library_agent)
except Exception as e:
# Skip this agent if there was an error
@@ -214,12 +284,22 @@ async def list_favorite_library_agents(
f"Retrieved {len(library_agents)} favorite library agents for user #{user_id}"
)
graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId]
execution_counts, schedule_info = await asyncio.gather(
_fetch_execution_counts(user_id, graph_ids),
_fetch_schedule_info(user_id),
)
# Only pass valid agents to the response
valid_library_agents: list[library_model.LibraryAgent] = []
for agent in library_agents:
try:
library_agent = library_model.LibraryAgent.from_db(agent)
library_agent = library_model.LibraryAgent.from_db(
agent,
execution_count_override=execution_counts.get(agent.agentGraphId),
schedule_info=schedule_info,
)
valid_library_agents.append(library_agent)
except Exception as e:
# Skip this agent if there was an error
@@ -285,6 +365,12 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
where={"userId": store_listing.owningUserId}
)
schedule_info = (
await _fetch_schedule_info(user_id, graph_id=library_agent.AgentGraph.id)
if library_agent.AgentGraph
else {}
)
return library_model.LibraryAgent.from_db(
library_agent,
sub_graphs=(
@@ -294,6 +380,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
),
store_listing=store_listing,
profile=profile,
schedule_info=schedule_info,
)
@@ -329,7 +416,10 @@ async def get_library_agent_by_store_version_id(
},
include=library_agent_include(user_id),
)
return library_model.LibraryAgent.from_db(agent) if agent else None
if not agent:
return None
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent.agentGraphId)
return library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
async def get_library_agent_by_graph_id(
@@ -358,7 +448,10 @@ async def get_library_agent_by_graph_id(
assert agent.AgentGraph # make type checker happy
# Include sub-graphs so we can make a full credentials input schema
sub_graphs = await graph_db.get_sub_graphs(agent.AgentGraph)
return library_model.LibraryAgent.from_db(agent, sub_graphs=sub_graphs)
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent.agentGraphId)
return library_model.LibraryAgent.from_db(
agent, sub_graphs=sub_graphs, schedule_info=schedule_info
)
async def add_generated_agent_image(
@@ -500,7 +593,11 @@ async def create_library_agent(
for agent, graph in zip(library_agents, graph_entries):
asyncio.create_task(add_generated_agent_image(graph, user_id, agent.id))
return [library_model.LibraryAgent.from_db(agent) for agent in library_agents]
schedule_info = await _fetch_schedule_info(user_id)
return [
library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
for agent in library_agents
]
async def update_agent_version_in_library(
@@ -562,7 +659,8 @@ async def update_agent_version_in_library(
f"Failed to update library agent for {agent_graph_id} v{agent_graph_version}"
)
return library_model.LibraryAgent.from_db(lib)
schedule_info = await _fetch_schedule_info(user_id, graph_id=agent_graph_id)
return library_model.LibraryAgent.from_db(lib, schedule_info=schedule_info)
async def create_graph_in_library(
@@ -1467,7 +1565,11 @@ async def bulk_move_agents_to_folder(
),
)
return [library_model.LibraryAgent.from_db(agent) for agent in agents]
schedule_info = await _fetch_schedule_info(user_id)
return [
library_model.LibraryAgent.from_db(agent, schedule_info=schedule_info)
for agent in agents
]
def collect_tree_ids(

View File

@@ -65,6 +65,11 @@ async def test_get_library_agents(mocker):
)
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
mocker.patch(
"backend.api.features.library.db._fetch_execution_counts",
new=mocker.AsyncMock(return_value={}),
)
# Call function
result = await db.list_library_agents("test-user")
@@ -353,3 +358,136 @@ async def test_create_library_agent_uses_upsert():
# Verify update branch restores soft-deleted/archived agents
assert data["update"]["isDeleted"] is False
assert data["update"]["isArchived"] is False
@pytest.mark.asyncio
async def test_list_favorite_library_agents(mocker):
mock_library_agents = [
prisma.models.LibraryAgent(
id="fav1",
userId="test-user",
agentGraphId="agent-fav",
settings="{}", # type: ignore
agentGraphVersion=1,
isCreatedByUser=False,
isDeleted=False,
isArchived=False,
createdAt=datetime.now(),
updatedAt=datetime.now(),
isFavorite=True,
useGraphIsActiveVersion=True,
AgentGraph=prisma.models.AgentGraph(
id="agent-fav",
version=1,
name="Favorite Agent",
description="My Favorite",
userId="other-user",
isActive=True,
createdAt=datetime.now(),
),
)
]
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_many = mocker.AsyncMock(
return_value=mock_library_agents
)
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
mocker.patch(
"backend.api.features.library.db._fetch_execution_counts",
new=mocker.AsyncMock(return_value={"agent-fav": 7}),
)
result = await db.list_favorite_library_agents("test-user")
assert len(result.agents) == 1
assert result.agents[0].id == "fav1"
assert result.agents[0].name == "Favorite Agent"
assert result.agents[0].graph_id == "agent-fav"
assert result.pagination.total_items == 1
assert result.pagination.total_pages == 1
assert result.pagination.current_page == 1
assert result.pagination.page_size == 50
@pytest.mark.asyncio
async def test_list_library_agents_skips_failed_agent(mocker):
"""Agents that fail parsing should be skipped — covers the except branch."""
mock_library_agents = [
prisma.models.LibraryAgent(
id="ua-bad",
userId="test-user",
agentGraphId="agent-bad",
settings="{}", # type: ignore
agentGraphVersion=1,
isCreatedByUser=False,
isDeleted=False,
isArchived=False,
createdAt=datetime.now(),
updatedAt=datetime.now(),
isFavorite=False,
useGraphIsActiveVersion=True,
AgentGraph=prisma.models.AgentGraph(
id="agent-bad",
version=1,
name="Bad Agent",
description="",
userId="other-user",
isActive=True,
createdAt=datetime.now(),
),
)
]
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_many = mocker.AsyncMock(
return_value=mock_library_agents
)
mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1)
mocker.patch(
"backend.api.features.library.db._fetch_execution_counts",
new=mocker.AsyncMock(return_value={}),
)
mocker.patch(
"backend.api.features.library.model.LibraryAgent.from_db",
side_effect=Exception("parse error"),
)
result = await db.list_library_agents("test-user")
assert len(result.agents) == 0
assert result.pagination.total_items == 1
@pytest.mark.asyncio
async def test_fetch_execution_counts_empty_graph_ids():
result = await db._fetch_execution_counts("user-1", [])
assert result == {}
@pytest.mark.asyncio
async def test_fetch_execution_counts_uses_group_by(mocker):
mock_prisma = mocker.patch("prisma.models.AgentGraphExecution.prisma")
mock_prisma.return_value.group_by = mocker.AsyncMock(
return_value=[
{"agentGraphId": "graph-1", "_count": {"_all": 5}},
{"agentGraphId": "graph-2", "_count": {"_all": 2}},
]
)
result = await db._fetch_execution_counts(
"user-1", ["graph-1", "graph-2", "graph-3"]
)
assert result == {"graph-1": 5, "graph-2": 2}
mock_prisma.return_value.group_by.assert_called_once_with(
by=["agentGraphId"],
where={
"userId": "user-1",
"agentGraphId": {"in": ["graph-1", "graph-2", "graph-3"]},
"isDeleted": False,
},
count=True,
)

View File

@@ -214,6 +214,14 @@ class LibraryAgent(pydantic.BaseModel):
folder_name: str | None = None # Denormalized for display
recommended_schedule_cron: str | None = None
is_scheduled: bool = pydantic.Field(
default=False,
description="Whether this agent has active execution schedules",
)
next_scheduled_run: str | None = pydantic.Field(
default=None,
description="ISO 8601 timestamp of the next scheduled run, if any",
)
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
marketplace_listing: Optional["MarketplaceListing"] = None
@@ -223,6 +231,8 @@ class LibraryAgent(pydantic.BaseModel):
sub_graphs: Optional[list[prisma.models.AgentGraph]] = None,
store_listing: Optional[prisma.models.StoreListing] = None,
profile: Optional[prisma.models.Profile] = None,
execution_count_override: Optional[int] = None,
schedule_info: Optional[dict[str, str]] = None,
) -> "LibraryAgent":
"""
Factory method that constructs a LibraryAgent from a Prisma LibraryAgent
@@ -258,10 +268,14 @@ class LibraryAgent(pydantic.BaseModel):
status = status_result.status
new_output = status_result.new_output
execution_count = len(executions)
execution_count = (
execution_count_override
if execution_count_override is not None
else len(executions)
)
success_rate: float | None = None
avg_correctness_score: float | None = None
if execution_count > 0:
if executions and execution_count > 0:
success_count = sum(
1
for e in executions
@@ -354,6 +368,10 @@ class LibraryAgent(pydantic.BaseModel):
folder_id=agent.folderId,
folder_name=agent.Folder.name if agent.Folder else None,
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
is_scheduled=bool(schedule_info and agent.agentGraphId in schedule_info),
next_scheduled_run=(
schedule_info.get(agent.agentGraphId) if schedule_info else None
),
settings=_parse_settings(agent.settings),
marketplace_listing=marketplace_listing_data,
)

View File

@@ -1,11 +1,66 @@
import datetime
import prisma.enums
import prisma.models
import pytest
from . import model as library_model
def _make_library_agent(
*,
graph_id: str = "g1",
executions: list | None = None,
) -> prisma.models.LibraryAgent:
return prisma.models.LibraryAgent(
id="la1",
userId="u1",
agentGraphId=graph_id,
settings="{}", # type: ignore
agentGraphVersion=1,
isCreatedByUser=True,
isDeleted=False,
isArchived=False,
createdAt=datetime.datetime.now(),
updatedAt=datetime.datetime.now(),
isFavorite=False,
useGraphIsActiveVersion=True,
AgentGraph=prisma.models.AgentGraph(
id=graph_id,
version=1,
name="Agent",
description="Desc",
userId="u1",
isActive=True,
createdAt=datetime.datetime.now(),
Executions=executions,
),
)
def test_from_db_execution_count_override_covers_success_rate():
"""Covers execution_count_override is not None branch and executions/count > 0 block."""
now = datetime.datetime.now(datetime.timezone.utc)
exec1 = prisma.models.AgentGraphExecution(
id="exec-1",
agentGraphId="g1",
agentGraphVersion=1,
userId="u1",
executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED,
createdAt=now,
updatedAt=now,
isDeleted=False,
isShared=False,
)
agent = _make_library_agent(executions=[exec1])
result = library_model.LibraryAgent.from_db(agent, execution_count_override=1)
assert result.execution_count == 1
assert result.success_rate is not None
assert result.success_rate == 100.0
@pytest.mark.asyncio
async def test_agent_preset_from_db(test_user_id: str):
# Create mock DB agent

View File

@@ -5,7 +5,8 @@ import time
import uuid
from collections import defaultdict
from datetime import datetime, timezone
from typing import Annotated, Any, Literal, Sequence, get_args
from typing import Annotated, Any, Literal, Sequence, cast, get_args
from urllib.parse import urlparse
import pydantic
import stripe
@@ -54,8 +55,11 @@ from backend.data.credit import (
cancel_stripe_subscription,
create_subscription_checkout,
get_auto_top_up,
get_proration_credit_cents,
get_subscription_price_id,
get_user_credit_model,
handle_subscription_payment_failure,
modify_stripe_subscription_for_tier,
set_auto_top_up,
set_subscription_tier,
sync_subscription_from_stripe,
@@ -699,9 +703,72 @@ class SubscriptionCheckoutResponse(BaseModel):
class SubscriptionStatusResponse(BaseModel):
tier: str
monthly_cost: int
tier_costs: dict[str, int]
tier: Literal["FREE", "PRO", "BUSINESS", "ENTERPRISE"]
monthly_cost: int # amount in cents (Stripe convention)
tier_costs: dict[str, int] # tier name -> amount in cents
proration_credit_cents: int # unused portion of current sub to convert on upgrade
def _validate_checkout_redirect_url(url: str) -> bool:
"""Return True if `url` matches the configured frontend origin.
Prevents open-redirect: attackers must not be able to supply arbitrary
success_url/cancel_url that Stripe will redirect users to after checkout.
Pre-parse rejection rules (applied before urlparse):
- Backslashes (``\\``) are normalised differently across parsers/browsers.
- Control characters (U+0000–U+001F) are not valid in URLs and may confuse
some URL-parsing implementations.
"""
# Reject characters that can confuse URL parsers before any parsing.
if "\\" in url:
return False
if any(ord(c) < 0x20 for c in url):
return False
allowed = settings.config.frontend_base_url or settings.config.platform_base_url
if not allowed:
# No configured origin — refuse to validate rather than allow arbitrary URLs.
return False
try:
parsed = urlparse(url)
allowed_parsed = urlparse(allowed)
except ValueError:
return False
if parsed.scheme not in ("http", "https"):
return False
# Reject ``user:pass@host`` authority tricks — ``@`` in the netloc component
# can trick browsers into connecting to a different host than displayed.
# ``@`` in query/fragment is harmless and must be allowed.
if "@" in parsed.netloc:
return False
return (
parsed.scheme == allowed_parsed.scheme
and parsed.netloc == allowed_parsed.netloc
)
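The validator composes in order: character-level rejection, parse, scheme allowlist, authority check, exact origin match. A standalone sketch of the same checks (the `ALLOWED` origin is illustrative, standing in for `frontend_base_url`):

```python
from urllib.parse import urlparse

ALLOWED = "https://platform.example.com"  # stand-in for the configured origin

def validate_redirect(url: str) -> bool:
    # Reject characters that confuse URL parsers before any parsing.
    if "\\" in url or any(ord(c) < 0x20 for c in url):
        return False
    try:
        parsed, allowed = urlparse(url), urlparse(ALLOWED)
    except ValueError:
        return False
    if parsed.scheme not in ("http", "https"):
        return False
    # user:pass@host authority tricks land in netloc, never in path/query.
    if "@" in parsed.netloc:
        return False
    return parsed.scheme == allowed.scheme and parsed.netloc == allowed.netloc

# Same-origin paths pass; host tricks and scheme swaps do not.
assert validate_redirect("https://platform.example.com/checkout/done")
assert not validate_redirect("https://evil.com@platform.example.com/")
assert not validate_redirect("https://platform.example.com.evil.com/")
assert not validate_redirect("javascript:alert(1)")
```

Note the exact `netloc` comparison: suffix or prefix matching would let `platform.example.com.evil.com` through.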
@cached(ttl_seconds=300, maxsize=32, cache_none=False)
async def _get_stripe_price_amount(price_id: str) -> int | None:
"""Return the unit_amount (cents) for a Stripe Price ID, cached for 5 minutes.
Returns ``None`` on transient Stripe errors. ``cache_none=False`` opts out
of caching the ``None`` sentinel so the next request retries Stripe instead
of being served a stale "no price" for the rest of the TTL window. Callers
should treat ``None`` as an unknown price and fall back to 0.
Stripe prices rarely change; caching avoids a ~200-600 ms Stripe round-trip on
every GET /credits/subscription page load and reduces quota consumption.
"""
try:
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
return price.unit_amount or 0
except stripe.StripeError:
logger.warning(
"Failed to retrieve Stripe price %s — returning None (not cached)",
price_id,
)
return None
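The `@cached` decorator here is project-internal; the `cache_none=False` behaviour it provides can be sketched with a hand-rolled synchronous TTL cache (names illustrative) — a `None` result is returned but never stored, so the next call retries the backend:

```python
import time

def cached_skip_none(ttl_seconds: float):
    """Memoise a function, but never cache a None result."""
    def deco(fn):
        store: dict = {}  # args -> (expires_at, value)
        def wrapper(*args):
            now = time.monotonic()
            hit = store.get(args)
            if hit and hit[0] > now:
                return hit[1]
            value = fn(*args)
            if value is not None:  # cache_none=False: let the next call retry
                store[args] = (now + ttl_seconds, value)
            return value
        return wrapper
    return deco

calls = []

@cached_skip_none(ttl_seconds=300)
def lookup(price_id):
    calls.append(price_id)
    return None if price_id == "bad" else 999

lookup("good"); lookup("good")  # second call served from cache
lookup("bad"); lookup("bad")    # None is never cached -> both hit the backend
assert calls == ["good", "bad", "bad"]
```

Caching the `None` sentinel would pin a transient Stripe failure in place for the full TTL window, which is exactly what the docstring above warns against.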
@v1_router.get(
@@ -722,21 +789,26 @@ async def get_subscription_status(
*[get_subscription_price_id(t) for t in paid_tiers]
)
tier_costs: dict[str, int] = {"FREE": 0, "ENTERPRISE": 0}
for t, price_id in zip(paid_tiers, price_ids):
cost = 0
if price_id:
try:
price = await run_in_threadpool(stripe.Price.retrieve, price_id)
cost = price.unit_amount or 0
except stripe.StripeError:
pass
tier_costs: dict[str, int] = {
SubscriptionTier.FREE.value: 0,
SubscriptionTier.ENTERPRISE.value: 0,
}
async def _cost(pid: str | None) -> int:
return (await _get_stripe_price_amount(pid) or 0) if pid else 0
costs = await asyncio.gather(*[_cost(pid) for pid in price_ids])
for t, cost in zip(paid_tiers, costs):
tier_costs[t.value] = cost
current_monthly_cost = tier_costs.get(tier.value, 0)
proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost)
return SubscriptionStatusResponse(
tier=tier.value,
monthly_cost=tier_costs.get(tier.value, 0),
monthly_cost=current_monthly_cost,
tier_costs=tier_costs,
proration_credit_cents=proration_credit,
)
@@ -766,24 +838,125 @@ async def update_subscription_tier(
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
)
# Downgrade to FREE: cancel active Stripe subscription, then update the DB tier.
# Downgrade to FREE: schedule Stripe cancellation at period end so the user
# keeps their tier for the time they already paid for. The DB tier is NOT
# updated here when a subscription exists — the customer.subscription.deleted
# webhook fires at period end and downgrades to FREE then.
# Exception: if the user has no active Stripe subscription (e.g. admin-granted
# tier), cancel_stripe_subscription returns False and we update the DB tier
# immediately since no webhook will ever fire.
# When payment is disabled entirely, update the DB tier directly.
if tier == SubscriptionTier.FREE:
if payment_enabled:
await cancel_stripe_subscription(user_id)
try:
had_subscription = await cancel_stripe_subscription(user_id)
except stripe.StripeError as e:
# Log full Stripe error server-side but return a generic message
# to the client — raw Stripe errors can leak customer/sub IDs and
# infrastructure config details.
logger.exception(
"Stripe error cancelling subscription for user %s: %s",
user_id,
e,
)
raise HTTPException(
status_code=502,
detail=(
"Unable to cancel your subscription right now. "
"Please try again or contact support."
),
)
if not had_subscription:
# No active Stripe subscription found — the user was on an
# admin-granted tier. Update DB immediately since the
# subscription.deleted webhook will never fire.
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
# Beta users (payment not enabled) → update tier directly without Stripe.
# Paid tier changes require payment to be enabled — block self-service upgrades
# when the flag is off. Admins use the /api/admin/ routes to set tiers directly.
if not payment_enabled:
await set_subscription_tier(user_id, tier)
raise HTTPException(
status_code=422,
detail=f"Subscription not available for tier {tier}",
)
# No-op short-circuit: if the user is already on the requested paid tier,
# do NOT create a new Checkout Session. Without this guard, a duplicate
# request (double-click, retried POST, stale page) creates a second
# subscription for the same price; the user would be charged for both
# until `_cleanup_stale_subscriptions` runs from the resulting webhook —
# which only fires after the second charge has cleared.
if (user.subscription_tier or SubscriptionTier.FREE) == tier:
return SubscriptionCheckoutResponse(url="")
# Paid upgrade → create Stripe Checkout Session.
# Paid→paid tier change: if the user already has a Stripe subscription,
# modify it in-place with proration instead of creating a new Checkout
# Session. This preserves remaining paid time and avoids double-charging.
# The customer.subscription.updated webhook fires and updates the DB tier.
current_tier = user.subscription_tier or SubscriptionTier.FREE
if current_tier in (SubscriptionTier.PRO, SubscriptionTier.BUSINESS):
try:
modified = await modify_stripe_subscription_for_tier(user_id, tier)
if modified:
return SubscriptionCheckoutResponse(url="")
# modify_stripe_subscription_for_tier returns False when no active
# Stripe subscription exists — i.e. the user has an admin-granted
# paid tier with no Stripe record. In that case, update the DB
# tier directly (same as the FREE-downgrade path for admin-granted
# users) rather than sending them through a new Checkout Session.
await set_subscription_tier(user_id, tier)
return SubscriptionCheckoutResponse(url="")
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
except stripe.StripeError as e:
logger.exception(
"Stripe error modifying subscription for user %s: %s", user_id, e
)
raise HTTPException(
status_code=502,
detail=(
"Unable to update your subscription right now. "
"Please try again or contact support."
),
)
# Paid upgrade from FREE → create Stripe Checkout Session.
if not request.success_url or not request.cancel_url:
raise HTTPException(
status_code=422,
detail="success_url and cancel_url are required for paid tier upgrades",
)
# Open-redirect protection: both URLs must point to the configured frontend
# origin, otherwise an attacker could use our Stripe integration as a
# redirector to arbitrary phishing sites.
#
# Fail early with a clear 503 if the server is misconfigured (neither
# frontend_base_url nor platform_base_url set), so operators get an
# actionable error instead of the misleading "must match the platform
# frontend origin" 422 that _validate_checkout_redirect_url would otherwise
# produce when `allowed` is empty.
if not (settings.config.frontend_base_url or settings.config.platform_base_url):
logger.error(
"update_subscription_tier: neither frontend_base_url nor "
"platform_base_url is configured; cannot validate checkout redirect URLs"
)
raise HTTPException(
status_code=503,
detail=(
"Payment redirect URLs cannot be validated: "
"frontend_base_url or platform_base_url must be set on the server."
),
)
if not _validate_checkout_redirect_url(
request.success_url
) or not _validate_checkout_redirect_url(request.cancel_url):
raise HTTPException(
status_code=422,
detail="success_url and cancel_url must match the platform frontend origin",
)
try:
url = await create_subscription_checkout(
user_id=user_id,
@@ -791,8 +964,19 @@ async def update_subscription_tier(
success_url=request.success_url,
cancel_url=request.cancel_url,
)
except (ValueError, stripe.StripeError) as e:
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
except stripe.StripeError as e:
logger.exception(
"Stripe error creating checkout session for user %s: %s", user_id, e
)
raise HTTPException(
status_code=502,
detail=(
"Unable to start checkout right now. "
"Please try again or contact support."
),
)
return SubscriptionCheckoutResponse(url=url)
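The branching in `update_subscription_tier` above can be condensed into a small decision function. This is a sketch only — the string tiers, the `has_stripe_sub` flag, and the action names are illustrative, not the endpoint's actual API:

```python
def plan_tier_change(current: str, target: str,
                     has_stripe_sub: bool, payment_enabled: bool) -> str:
    """Return which action the endpoint would take (illustrative sketch)."""
    if target == "FREE":
        if not payment_enabled or not has_stripe_sub:
            return "set_db_tier"          # no webhook will ever fire
        return "cancel_at_period_end"     # webhook downgrades the DB later
    if not payment_enabled:
        return "reject_422"               # self-service upgrades blocked
    if current == target:
        return "noop"                     # duplicate-request guard
    if current in ("PRO", "BUSINESS"):
        # In-place modification with proration; admin-granted tiers
        # (no Stripe record) fall back to a direct DB update.
        return "modify_with_proration" if has_stripe_sub else "set_db_tier"
    return "checkout_session"             # FREE -> paid upgrade

assert plan_tier_change("FREE", "PRO", False, True) == "checkout_session"
assert plan_tier_change("PRO", "BUSINESS", True, True) == "modify_with_proration"
assert plan_tier_change("PRO", "FREE", True, True) == "cancel_at_period_end"
assert plan_tier_change("BUSINESS", "BUSINESS", True, True) == "noop"
```

The no-op case matters most for safety: without it, a retried POST would create a second Checkout Session and a second subscription for the same price.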
@@ -801,44 +985,78 @@ async def update_subscription_tier(
path="/credits/stripe_webhook", summary="Handle Stripe webhooks", tags=["credits"]
)
async def stripe_webhook(request: Request):
webhook_secret = settings.secrets.stripe_webhook_secret
if not webhook_secret:
# Guard: an empty secret allows HMAC forgery (attacker can compute a valid
# signature over the same empty key). Reject all webhook calls when unconfigured.
logger.error(
"stripe_webhook: STRIPE_WEBHOOK_SECRET is not configured — "
"rejecting request to prevent signature bypass"
)
raise HTTPException(status_code=503, detail="Webhook not configured")
# Get the raw request body
payload = await request.body()
# Get the signature header
sig_header = request.headers.get("stripe-signature")
try:
event = stripe.Webhook.construct_event(
payload, sig_header, settings.secrets.stripe_webhook_secret
)
except ValueError as e:
event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
except ValueError:
# Invalid payload
raise HTTPException(
status_code=400, detail=f"Invalid payload: {str(e) or type(e).__name__}"
)
except stripe.SignatureVerificationError as e:
raise HTTPException(status_code=400, detail="Invalid payload")
except stripe.SignatureVerificationError:
# Invalid signature
raise HTTPException(
status_code=400, detail=f"Invalid signature: {str(e) or type(e).__name__}"
raise HTTPException(status_code=400, detail="Invalid signature")
# Defensive payload extraction. A malformed payload (missing/non-dict
# `data.object`, missing `id`) would otherwise raise KeyError/TypeError
# AFTER signature verification — which Stripe interprets as a delivery
# failure and retries forever, while spamming Sentry with no useful info.
# Acknowledge with 200 and a warning so Stripe stops retrying.
event_type = event.get("type", "")
event_data = event.get("data") or {}
data_object = event_data.get("object") if isinstance(event_data, dict) else None
if not isinstance(data_object, dict):
logger.warning(
"stripe_webhook: %s missing or non-dict data.object; ignoring",
event_type,
)
return Response(status_code=200)
if (
event["type"] == "checkout.session.completed"
or event["type"] == "checkout.session.async_payment_succeeded"
if event_type in (
"checkout.session.completed",
"checkout.session.async_payment_succeeded",
):
await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"])
session_id = data_object.get("id")
if not session_id:
logger.warning(
"stripe_webhook: %s missing data.object.id; ignoring", event_type
)
return Response(status_code=200)
await UserCredit().fulfill_checkout(session_id=session_id)
if event["type"] in (
if event_type in (
"customer.subscription.created",
"customer.subscription.updated",
"customer.subscription.deleted",
):
await sync_subscription_from_stripe(event["data"]["object"])
await sync_subscription_from_stripe(data_object)
if event["type"] == "charge.dispute.created":
await UserCredit().handle_dispute(event["data"]["object"])
if event_type == "invoice.payment_failed":
await handle_subscription_payment_failure(data_object)
if event["type"] == "refund.created" or event["type"] == "charge.dispute.closed":
await UserCredit().deduct_credits(event["data"]["object"])
# `handle_dispute` and `deduct_credits` expect Stripe SDK typed objects
# (Dispute/Refund). The Stripe webhook payload's `data.object` is a
# StripeObject (a dict subclass) carrying that runtime shape, so we cast
# to satisfy the type checker without changing runtime behaviour.
if event_type == "charge.dispute.created":
await UserCredit().handle_dispute(cast(stripe.Dispute, data_object))
if event_type == "refund.created" or event_type == "charge.dispute.closed":
await UserCredit().deduct_credits(
cast("stripe.Refund | stripe.Dispute", data_object)
)
return Response(status_code=200)
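The empty-secret guard at the top of `stripe_webhook` exists because Stripe's `v1` signature is an HMAC-SHA256 over `f"{timestamp}.{body}"` keyed by the endpoint secret. With an empty key, an attacker can compute exactly the signature the server would verify against. A sketch of that forgery (the timestamp and payload are arbitrary examples):

```python
import hashlib
import hmac

def stripe_v1_signature(secret: str, timestamp: int, payload: bytes) -> str:
    """Compute the v1 digest Stripe puts in the Stripe-Signature header."""
    signed = f"{timestamp}.".encode() + payload
    return hmac.new(secret.encode(), signed, hashlib.sha256).hexdigest()

body = b'{"type":"checkout.session.completed"}'
# With an unconfigured (empty) secret, attacker and server derive the
# same digest from public inputs -- forgery is trivial.
forged = stripe_v1_signature("", 1700000000, body)
server = stripe_v1_signature("", 1700000000, body)
assert forged == server
```

Rejecting with 503 when the secret is unset is therefore strictly safer than attempting verification with an empty key.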

View File

@@ -106,7 +106,6 @@ class LlmModelMeta(EnumMeta):
class LlmModel(str, Enum, metaclass=LlmModelMeta):
@classmethod
def _missing_(cls, value: object) -> "LlmModel | None":
"""Handle provider-prefixed model names like 'anthropic/claude-sonnet-4-6'."""
@@ -203,6 +202,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
GROK_4 = "x-ai/grok-4"
GROK_4_FAST = "x-ai/grok-4-fast"
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
GROK_4_20 = "x-ai/grok-4.20"
GROK_4_20_MULTI_AGENT = "x-ai/grok-4.20-multi-agent"
GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
KIMI_K2 = "moonshotai/kimi-k2"
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
@@ -627,6 +628,18 @@ MODEL_METADATA = {
LlmModel.GROK_4_1_FAST: ModelMetadata(
"open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1
),
LlmModel.GROK_4_20: ModelMetadata(
"open_router", 2000000, 100000, "Grok 4.20", "OpenRouter", "xAI", 3
),
LlmModel.GROK_4_20_MULTI_AGENT: ModelMetadata(
"open_router",
2000000,
100000,
"Grok 4.20 Multi-Agent",
"OpenRouter",
"xAI",
3,
),
LlmModel.GROK_CODE_FAST_1: ModelMetadata(
"open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1
),
@@ -987,7 +1000,6 @@ async def llm_call(
reasoning=reasoning,
)
elif provider == "anthropic":
an_tools = convert_openai_tool_fmt_to_anthropic(tools)
# Cache tool definitions alongside the system prompt.
# Placing cache_control on the last tool caches all tool schemas as a

View File

@@ -10,9 +10,11 @@ from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from prisma.types import (
ChatMessageCreateInput,
ChatMessageWhereInput,
ChatSessionCreateInput,
ChatSessionUpdateInput,
ChatSessionWhereInput,
FindManyChatMessageArgsFromChatSession,
)
from pydantic import BaseModel
@@ -30,6 +32,8 @@ from .model import get_chat_session as get_chat_session_cached
logger = logging.getLogger(__name__)
_BOUNDARY_SCAN_LIMIT = 10
class PaginatedMessages(BaseModel):
"""Result of a paginated message query."""
@@ -69,12 +73,10 @@ async def get_chat_messages_paginated(
in parallel with the message query. Returns ``None`` when the session
is not found or does not belong to the user.
Args:
session_id: The chat session ID.
limit: Max messages to return.
before_sequence: Cursor — return messages with sequence < this value.
user_id: If provided, filters via ``Session.userId`` so only the
session owner's messages are returned (acts as an ownership guard).
After fetching, a visibility guarantee ensures the page contains at least
one user or assistant message. If the entire page is tool messages (which
are hidden in the UI), it expands backward until a visible message is found
so the chat never appears blank.
"""
# Build session-existence / ownership check
session_where: ChatSessionWhereInput = {"id": session_id}
@@ -82,7 +84,7 @@ async def get_chat_messages_paginated(
session_where["userId"] = user_id
# Build message include — fetch paginated messages in the same query
msg_include: dict[str, Any] = {
msg_include: FindManyChatMessageArgsFromChatSession = {
"order_by": {"sequence": "desc"},
"take": limit + 1,
}
@@ -111,42 +113,18 @@ async def get_chat_messages_paginated(
# expand backward to include the preceding assistant message that
# owns the tool_calls, so convertChatSessionMessagesToUiMessages
# can pair them correctly.
_BOUNDARY_SCAN_LIMIT = 10
if results and results[0].role == "tool":
boundary_where: dict[str, Any] = {
"sessionId": session_id,
"sequence": {"lt": results[0].sequence},
}
if user_id is not None:
boundary_where["Session"] = {"is": {"userId": user_id}}
extra = await PrismaChatMessage.prisma().find_many(
where=boundary_where,
order={"sequence": "desc"},
take=_BOUNDARY_SCAN_LIMIT,
results, has_more = await _expand_tool_boundary(
session_id, results, has_more, user_id
)
# Visibility guarantee: if the entire page has no user/assistant messages
# (all tool messages), the chat would appear blank. Expand backward
# until we find at least one visible message.
if results and not any(m.role in ("user", "assistant") for m in results):
results, has_more = await _expand_for_visibility(
session_id, results, has_more, user_id
)
# Find the first non-tool message (should be the assistant)
boundary_msgs = []
found_owner = False
for msg in extra:
boundary_msgs.append(msg)
if msg.role != "tool":
found_owner = True
break
boundary_msgs.reverse()
if not found_owner:
logger.warning(
"Boundary expansion did not find owning assistant message "
"for session=%s before sequence=%s (%d msgs scanned)",
session_id,
results[0].sequence,
len(extra),
)
if boundary_msgs:
results = boundary_msgs + results
# Only mark has_more if the expanded boundary isn't the
# very start of the conversation (sequence 0).
if boundary_msgs[0].sequence > 0:
has_more = True
messages = [ChatMessage.from_db(m) for m in results]
oldest_sequence = messages[0].sequence if messages else None
@@ -159,6 +137,98 @@ async def get_chat_messages_paginated(
)
async def _expand_tool_boundary(
session_id: str,
results: list[Any],
has_more: bool,
user_id: str | None,
) -> tuple[list[Any], bool]:
"""Expand backward from the oldest message to include the owning assistant
message when the page starts mid-tool-group."""
boundary_where: ChatMessageWhereInput = {
"sessionId": session_id,
"sequence": {"lt": results[0].sequence},
}
if user_id is not None:
boundary_where["Session"] = {"is": {"userId": user_id}}
extra = await PrismaChatMessage.prisma().find_many(
where=boundary_where,
order={"sequence": "desc"},
take=_BOUNDARY_SCAN_LIMIT,
)
# Find the first non-tool message (should be the assistant)
boundary_msgs = []
found_owner = False
for msg in extra:
boundary_msgs.append(msg)
if msg.role != "tool":
found_owner = True
break
boundary_msgs.reverse()
if not found_owner:
logger.warning(
"Boundary expansion did not find owning assistant message "
"for session=%s before sequence=%s (%d msgs scanned)",
session_id,
results[0].sequence,
len(extra),
)
if boundary_msgs:
results = boundary_msgs + results
has_more = boundary_msgs[0].sequence > 0
return results, has_more
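The scan inside `_expand_tool_boundary` — walk earlier messages newest-first, collecting until the first non-tool message, then prepend in ascending order — can be sketched in isolation with plain `(sequence, role)` tuples (a simplification; the real code operates on Prisma rows):

```python
def expand_boundary(page, earlier_desc, scan_limit=10):
    """If the oldest message on the page is a tool message, prepend earlier
    messages (given newest-first) up to and including the owning non-tool
    message, restoring ascending order."""
    if not page or page[0][1] != "tool":
        return page
    prepend = []
    for seq, role in earlier_desc[:scan_limit]:
        prepend.append((seq, role))
        if role != "tool":
            break  # found the assistant message that owns the tool group
    prepend.reverse()
    return prepend + page

page = [(10, "tool"), (11, "tool")]
earlier = [(9, "tool"), (8, "assistant"), (7, "user")]
assert expand_boundary(page, earlier) == [
    (8, "assistant"), (9, "tool"), (10, "tool"), (11, "tool"),
]
```

The `scan_limit` cap mirrors `_BOUNDARY_SCAN_LIMIT`: an unbounded walk could drag in an entire long tool run just to find its owner.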
_VISIBILITY_EXPAND_LIMIT = 200
async def _expand_for_visibility(
session_id: str,
results: list[Any],
has_more: bool,
user_id: str | None,
) -> tuple[list[Any], bool]:
"""Expand backward until the page contains at least one user or assistant
message, so the chat is never blank."""
expand_where: ChatMessageWhereInput = {
"sessionId": session_id,
"sequence": {"lt": results[0].sequence},
}
if user_id is not None:
expand_where["Session"] = {"is": {"userId": user_id}}
extra = await PrismaChatMessage.prisma().find_many(
where=expand_where,
order={"sequence": "desc"},
take=_VISIBILITY_EXPAND_LIMIT,
)
if not extra:
return results, has_more
# Collect messages until we find a visible one (user/assistant)
prepend = []
found_visible = False
for msg in extra:
prepend.append(msg)
if msg.role in ("user", "assistant"):
found_visible = True
break
if not found_visible:
logger.warning(
"Visibility expansion did not find any user/assistant message "
"for session=%s before sequence=%s (%d msgs scanned)",
session_id,
results[0].sequence,
len(extra),
)
prepend.reverse()
if prepend:
results = prepend + results
has_more = prepend[0].sequence > 0
return results, has_more
async def create_chat_session(
session_id: str,
user_id: str,

View File

@@ -175,6 +175,138 @@ async def test_no_where_on_messages_without_before_sequence(
assert "where" not in include["Messages"]
# ---------- Visibility guarantee ----------
@pytest.mark.asyncio
async def test_visibility_expands_when_all_tool_messages(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""When the entire page is tool messages, expand backward to find
at least one visible (user/assistant) message so the chat isn't blank."""
find_first, find_many = mock_db
# Newest 3 messages are all tool messages (DESC → reversed to ASC)
find_first.return_value = _make_session(
messages=[
_make_msg(12, role="tool"),
_make_msg(11, role="tool"),
_make_msg(10, role="tool"),
],
)
# Boundary expansion finds the owning assistant first (boundary fix),
# then visibility expansion finds a user message further back
find_many.side_effect = [
# First call: boundary fix (oldest msg is tool → find owner)
[_make_msg(9, role="tool"), _make_msg(8, role="tool")],
# Second call: visibility expansion (still all tool → find visible)
[_make_msg(7, role="tool"), _make_msg(6, role="assistant")],
]
page = await get_chat_messages_paginated(SESSION_ID, limit=3)
assert page is not None
# Should include the expanded messages + original tool messages
roles = [m.role for m in page.messages]
assert "assistant" in roles or "user" in roles
assert page.has_more is True
@pytest.mark.asyncio
async def test_no_visibility_expansion_when_visible_messages_present(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""No visibility expansion needed when page already has visible messages."""
find_first, find_many = mock_db
# Page has an assistant message among tool messages
find_first.return_value = _make_session(
messages=[
_make_msg(5, role="tool"),
_make_msg(4, role="assistant"),
_make_msg(3, role="user"),
],
)
page = await get_chat_messages_paginated(SESSION_ID, limit=3)
assert page is not None
# Boundary expansion might fire (oldest is tool), but NOT visibility
assert [m.sequence for m in page.messages][0] <= 3
@pytest.mark.asyncio
async def test_visibility_no_expansion_when_no_earlier_messages(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""When the page is all tool messages but there are no earlier messages
in the DB, visibility expansion returns early without changes."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(1, role="tool"), _make_msg(0, role="tool")],
)
# Boundary expansion: no earlier messages
# Visibility expansion: no earlier messages
find_many.side_effect = [[], []]
page = await get_chat_messages_paginated(SESSION_ID, limit=2)
assert page is not None
assert all(m.role == "tool" for m in page.messages)
@pytest.mark.asyncio
async def test_visibility_expansion_reaches_seq_zero(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""When visibility expansion finds a visible message at sequence 0,
has_more should be False."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(5, role="tool"), _make_msg(4, role="tool")],
)
find_many.side_effect = [
# Boundary expansion
[_make_msg(3, role="tool")],
# Visibility expansion — finds user at seq 0
[
_make_msg(2, role="tool"),
_make_msg(1, role="tool"),
_make_msg(0, role="user"),
],
]
page = await get_chat_messages_paginated(SESSION_ID, limit=2)
assert page is not None
assert page.messages[0].role == "user"
assert page.messages[0].sequence == 0
assert page.has_more is False
@pytest.mark.asyncio
async def test_visibility_expansion_with_user_id(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""Visibility expansion passes user_id filter to the boundary query."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(10, role="tool")],
)
find_many.side_effect = [
# Boundary expansion
[_make_msg(9, role="tool")],
# Visibility expansion
[_make_msg(8, role="assistant")],
]
await get_chat_messages_paginated(SESSION_ID, limit=1, user_id="user-abc")
# Both find_many calls should include the user_id session filter
for call in find_many.call_args_list:
where = call.kwargs.get("where") or call[1].get("where")
assert "Session" in where
assert where["Session"] == {"is": {"userId": "user-abc"}}
@pytest.mark.asyncio
async def test_user_id_filter_applied_to_session_where(
mock_db: tuple[AsyncMock, AsyncMock],
@@ -329,7 +461,8 @@ async def test_boundary_expansion_warns_when_no_owner_found(
with patch("backend.copilot.db.logger") as mock_logger:
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
mock_logger.warning.assert_called_once()
# Two warnings: boundary expansion + visibility expansion (all tool msgs)
assert mock_logger.warning.call_count == 2
assert page is not None
assert page.messages[0].role == "tool"

View File

@@ -1,71 +0,0 @@
"""Per-request idempotency lock for the /stream endpoint.
Prevents duplicate executor tasks from concurrent or retried POSTs (e.g. k8s
rolling-deploy retries, nginx upstream retries, rapid double-clicks).
Lifecycle
---------
1. ``acquire()`` — computes a stable hash of (session_id, message, file_ids)
and atomically sets a Redis NX key. Returns a ``_DedupLock`` on success or
``None`` when the key already exists (duplicate request).
2. ``release()`` — deletes the key. Must be called on turn completion or turn
error so the next legitimate send is never blocked.
3. On client disconnect (``GeneratorExit``) the lock must NOT be released —
the backend turn is still running, and releasing would reopen the duplicate
window for infra-level retries. The 30 s TTL is the safety net.
"""
import hashlib
import logging
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
_KEY_PREFIX = "chat:msg_dedup"
_TTL_SECONDS = 30
class _DedupLock:
def __init__(self, key: str, redis) -> None:
self._key = key
self._redis = redis
async def release(self) -> None:
"""Best-effort key deletion. The TTL handles failures silently."""
try:
await self._redis.delete(self._key)
except Exception:
pass
async def acquire_dedup_lock(
session_id: str,
message: str | None,
file_ids: list[str] | None,
) -> _DedupLock | None:
"""Acquire the idempotency lock for this (session, message, files) tuple.
Returns a ``_DedupLock`` when the lock is freshly acquired (first request).
Returns ``None`` when a duplicate is detected (lock already held).
Returns ``None`` when there is nothing to deduplicate (no message, no files).
"""
if not message and not file_ids:
return None
sorted_ids = ":".join(sorted(file_ids or []))
content_hash = hashlib.sha256(
f"{session_id}:{message or ''}:{sorted_ids}".encode()
).hexdigest()[:16]
key = f"{_KEY_PREFIX}:{session_id}:{content_hash}"
redis = await get_redis_async()
acquired = await redis.set(key, "1", ex=_TTL_SECONDS, nx=True)
if not acquired:
logger.warning(
f"[STREAM] Duplicate user message blocked for session {session_id}, "
f"hash={content_hash} — returning empty SSE",
)
return None
return _DedupLock(key, redis)

View File

@@ -1,94 +0,0 @@
"""Unit tests for backend.copilot.message_dedup."""
from unittest.mock import AsyncMock
import pytest
import pytest_mock
from backend.copilot.message_dedup import _KEY_PREFIX, acquire_dedup_lock
def _patch_redis(mocker: pytest_mock.MockerFixture, *, set_returns):
mock_redis = AsyncMock()
mock_redis.set = AsyncMock(return_value=set_returns)
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=AsyncMock,
return_value=mock_redis,
)
return mock_redis
@pytest.mark.asyncio
async def test_acquire_returns_none_when_no_message_no_files(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Nothing to deduplicate — no Redis call made, None returned."""
mock_redis = _patch_redis(mocker, set_returns=True)
result = await acquire_dedup_lock("sess-1", None, None)
assert result is None
mock_redis.set.assert_not_called()
@pytest.mark.asyncio
async def test_acquire_returns_lock_on_first_request(
mocker: pytest_mock.MockerFixture,
) -> None:
"""First request acquires the lock and returns a _DedupLock."""
mock_redis = _patch_redis(mocker, set_returns=True)
lock = await acquire_dedup_lock("sess-1", "hello", None)
assert lock is not None
mock_redis.set.assert_called_once()
key_arg = mock_redis.set.call_args.args[0]
assert key_arg.startswith(f"{_KEY_PREFIX}:sess-1:")
@pytest.mark.asyncio
async def test_acquire_returns_none_on_duplicate(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Duplicate request (NX fails) returns None to signal the caller."""
_patch_redis(mocker, set_returns=None)
result = await acquire_dedup_lock("sess-1", "hello", None)
assert result is None
@pytest.mark.asyncio
async def test_acquire_key_stable_across_file_order(
mocker: pytest_mock.MockerFixture,
) -> None:
"""File IDs are sorted before hashing so order doesn't affect the key."""
mock_redis_1 = _patch_redis(mocker, set_returns=True)
await acquire_dedup_lock("sess-1", "msg", ["b", "a"])
key_ab = mock_redis_1.set.call_args.args[0]
mock_redis_2 = _patch_redis(mocker, set_returns=True)
await acquire_dedup_lock("sess-1", "msg", ["a", "b"])
key_ba = mock_redis_2.set.call_args.args[0]
assert key_ab == key_ba
@pytest.mark.asyncio
async def test_release_deletes_key(
mocker: pytest_mock.MockerFixture,
) -> None:
"""release() calls Redis delete exactly once."""
mock_redis = _patch_redis(mocker, set_returns=True)
lock = await acquire_dedup_lock("sess-1", "hello", None)
assert lock is not None
await lock.release()
mock_redis.delete.assert_called_once()
@pytest.mark.asyncio
async def test_release_swallows_redis_error(
mocker: pytest_mock.MockerFixture,
) -> None:
"""release() must not raise even when Redis delete fails."""
mock_redis = _patch_redis(mocker, set_returns=True)
mock_redis.delete = AsyncMock(side_effect=RuntimeError("redis down"))
lock = await acquire_dedup_lock("sess-1", "hello", None)
assert lock is not None
await lock.release() # must not raise
mock_redis.delete.assert_called_once()

View File

@@ -1,9 +1,8 @@
import asyncio
import logging
import uuid
from contextlib import asynccontextmanager
from datetime import UTC, datetime
from typing import Any, Self, cast
from weakref import WeakValueDictionary
from typing import Any, AsyncIterator, Self, cast
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
@@ -522,10 +521,7 @@ async def upsert_chat_session(
callers are aware of the persistence failure.
RedisError: If the cache write fails (after successful DB write).
"""
# Acquire session-specific lock to prevent concurrent upserts
lock = await _get_session_lock(session.session_id)
async with lock:
async with _get_session_lock(session.session_id) as _:
# Always query DB for existing message count to ensure consistency
existing_message_count = await chat_db().get_next_sequence(session.session_id)
@@ -651,20 +647,50 @@ async def _save_session_to_db(
msg.sequence = existing_message_count + i
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
async def append_and_save_message(
session_id: str, message: ChatMessage
) -> ChatSession | None:
"""Atomically append a message to a session and persist it.
Acquires the session lock, re-fetches the latest session state,
appends the message, and saves — preventing message loss when
concurrent requests modify the same session.
"""
lock = await _get_session_lock(session_id)
Returns the updated session, or None if the message was detected as a
duplicate (idempotency guard). Callers must check for None and skip any
downstream work (e.g. enqueuing a new LLM turn) when a duplicate is detected.
async with lock:
session = await get_chat_session(session_id)
Uses _get_session_lock (Redis NX) to serialise concurrent writers across replicas.
The idempotency check below provides a last-resort guard when the lock degrades.
"""
async with _get_session_lock(session_id) as lock_acquired:
# When the lock degraded (Redis down or 2s timeout), bypass cache for
# the idempotency check. Stale cache could let two concurrent writers
# both see the old state, pass the check, and write the same message.
if lock_acquired:
session = await get_chat_session(session_id)
else:
session = await _get_session_from_db(session_id)
if session is None:
raise ValueError(f"Session {session_id} not found")
# Idempotency: skip if the trailing block of same-role messages already
# contains this content. Uses is_message_duplicate which checks all
# consecutive trailing messages of the same role, not just [-1].
#
# This collapses infra/nginx retries whether they land on the same pod
# (serialised by the Redis lock) or a different pod.
#
# Legit same-text messages are distinguished by the assistant turn
# between them: if the user said "yes", got a response, and says
# "yes" again, session.messages[-1] is the assistant reply, so the
# role check fails and the second message goes through normally.
#
# Edge case: if a turn dies without writing any assistant message,
# the user's next send of the same text is blocked here permanently.
# The fix is to ensure failed turns always write an error/timeout
# assistant message so the session always ends on an assistant turn.
if message.content is not None and is_message_duplicate(
session.messages, message.role, message.content
):
return None # duplicate — caller should skip enqueue
session.messages.append(message)
existing_message_count = await chat_db().get_next_sequence(session_id)
@@ -679,6 +705,9 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
await cache_chat_session(session)
except Exception as e:
logger.warning(f"Cache write failed for session {session_id}: {e}")
# Invalidate the stale entry so future reads fall back to DB,
# preventing a retry from bypassing the idempotency check above.
await invalidate_session_cache(session_id)
return session
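The idempotency comment above describes checking the whole trailing run of same-role messages, with an intervening assistant reply breaking the run so a legitimate repeated message still goes through. A self-contained sketch of that check (the real `is_message_duplicate` is not shown in this diff; this mirrors the behaviour the comment describes):

```python
from dataclasses import dataclass

@dataclass
class Msg:
    role: str
    content: str

def is_trailing_duplicate(messages: list[Msg], role: str, content: str) -> bool:
    # Walk backwards through the trailing run of same-role messages.
    # The first message of a different role ends the run, so a repeated
    # "yes" separated by an assistant reply is NOT flagged as a duplicate.
    for msg in reversed(messages):
        if msg.role != role:
            return False
        if msg.content == content:
            return True
    return False

history = [Msg("user", "yes"), Msg("assistant", "ok"), Msg("user", "yes")]
```

With `history` as above, a third identical "yes" would be caught (the trailing run is the last user message), while the second "yes" was allowed because the assistant's "ok" broke the run.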
@@ -764,10 +793,6 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
except Exception as e:
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
# Clean up session lock (belt-and-suspenders with WeakValueDictionary)
async with _session_locks_mutex:
_session_locks.pop(session_id, None)
# Shut down any local browser daemon for this session (best-effort).
# Inline import required: all tool modules import ChatSession from this
# module, so any top-level import from tools.* would create a cycle.
@@ -832,25 +857,38 @@ async def update_session_title(
# ==================== Chat session locks ==================== #
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()
@asynccontextmanager
async def _get_session_lock(session_id: str) -> AsyncIterator[bool]:
"""Distributed Redis lock for a session, usable as an async context manager.
async def _get_session_lock(session_id: str) -> asyncio.Lock:
"""Get or create a lock for a specific session to prevent concurrent upserts.
Yields True if the lock was acquired, False if it timed out or Redis was
unavailable. Callers should treat False as a degraded mode and prefer fresh
DB reads over cache to avoid acting on stale state.
This was originally added to solve a specific race between the session-title thread and
the conversation thread, which always occurs on the same instance because the frontend
prevents rapid request sends.
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
when no coroutine holds a reference to them, preventing memory leaks from
unbounded growth of session locks. Explicit cleanup also occurs
in `delete_chat_session()`.
Uses redis-py's built-in Lock (Lua-script acquire/release) so lock acquisition
is atomic and release is owner-verified. Blocks up to 2s for a concurrent
writer to finish; the 10s TTL ensures a dead pod never holds the lock forever.
"""
async with _session_locks_mutex:
lock = _session_locks.get(session_id)
if lock is None:
lock = asyncio.Lock()
_session_locks[session_id] = lock
return lock
_lock_key = f"copilot:session_lock:{session_id}"
lock = None
acquired = False
try:
_redis = await get_redis_async()
lock = _redis.lock(_lock_key, timeout=10, blocking_timeout=2)
acquired = await lock.acquire(blocking=True)
if not acquired:
logger.warning(
"Could not acquire session lock for %s within 2s", session_id
)
except Exception as e:
logger.warning("Redis unavailable for session lock on %s: %s", session_id, e)
try:
yield acquired
finally:
if acquired and lock is not None:
try:
await lock.release()
except Exception:
pass # TTL will expire the key
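The context-manager shape above — yield `True` on acquisition, `False` on timeout or backend failure, never raise, and swallow release errors because a TTL backstops cleanup — can be distilled into a backend-agnostic sketch (the callables here are stand-ins, not the redis-py API):

```python
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator, Awaitable, Callable

@asynccontextmanager
async def best_effort_lock(
    acquire: Callable[[], Awaitable[bool]],
    release: Callable[[], Awaitable[None]],
) -> AsyncIterator[bool]:
    """Yield True if the lock was taken, False in degraded mode; never raise."""
    held = False
    try:
        held = await acquire()
    except Exception:
        pass  # lock backend down -> proceed in degraded mode
    try:
        yield held
    finally:
        if held:
            try:
                await release()
            except Exception:
                pass  # the backend's TTL will expire the key

async def demo() -> tuple[bool, bool]:
    async def ok() -> bool:
        return True

    async def boom() -> bool:
        raise ConnectionError("redis down")

    async def noop() -> None:
        return None

    async with best_effort_lock(ok, noop) as a, best_effort_lock(boom, noop) as b:
        return a, b

result = asyncio.run(demo())
```

Callers branch on the yielded flag exactly as `append_and_save_message` does: lock held means cache reads are safe, degraded means prefer a fresh DB read.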

View File

@@ -11,11 +11,13 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
ChatCompletionMessageToolCallParam,
Function,
)
from pytest_mock import MockerFixture
from .model import (
ChatMessage,
ChatSession,
Usage,
append_and_save_message,
get_chat_session,
is_message_duplicate,
maybe_append_user_message,
@@ -574,3 +576,345 @@ def test_maybe_append_assistant_skips_duplicate():
result = maybe_append_user_message(session, "dup", is_user_message=False)
assert result is False
assert len(session.messages) == 2
# --------------------------------------------------------------------------- #
# append_and_save_message #
# --------------------------------------------------------------------------- #
def _make_session_with_messages(*msgs: ChatMessage) -> ChatSession:
s = ChatSession.new(user_id="u1", dry_run=False)
s.messages = list(msgs)
return s
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_returns_none_for_duplicate(
mocker: MockerFixture,
) -> None:
"""append_and_save_message returns None when the trailing message is a duplicate."""
session = _make_session_with_messages(
ChatMessage(role="user", content="hello"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock()
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=session,
)
result = await append_and_save_message(
session.session_id, ChatMessage(role="user", content="hello")
)
assert result is None
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_appends_new_message(
mocker: MockerFixture,
) -> None:
"""append_and_save_message appends a non-duplicate message and returns the session."""
session = _make_session_with_messages(
ChatMessage(role="user", content="hello"),
ChatMessage(role="assistant", content="hi"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock()
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=2)
),
)
mocker.patch(
"backend.copilot.model.cache_chat_session",
new_callable=mocker.AsyncMock,
)
new_msg = ChatMessage(role="user", content="second message")
result = await append_and_save_message(session.session_id, new_msg)
assert result is not None
assert result.messages[-1].content == "second message"
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_raises_when_session_not_found(
mocker: MockerFixture,
) -> None:
"""append_and_save_message raises ValueError when the session does not exist."""
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock()
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=None,
)
with pytest.raises(ValueError, match="not found"):
await append_and_save_message(
"missing-session-id", ChatMessage(role="user", content="hi")
)
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_uses_db_when_lock_degraded(
mocker: MockerFixture,
) -> None:
"""When the Redis lock times out (acquired=False), the fallback reads from DB."""
session = _make_session_with_messages(
ChatMessage(role="assistant", content="hi"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=False)
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mock_get_from_db = mocker.patch(
"backend.copilot.model._get_session_from_db",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=1)
),
)
mocker.patch(
"backend.copilot.model.cache_chat_session",
new_callable=mocker.AsyncMock,
)
new_msg = ChatMessage(role="user", content="new msg")
result = await append_and_save_message(session.session_id, new_msg)
# DB path was used (not cache-first)
mock_get_from_db.assert_called_once_with(session.session_id)
assert result is not None
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_raises_database_error_on_save_failure(
mocker: MockerFixture,
) -> None:
"""When _save_session_to_db fails, append_and_save_message raises DatabaseError."""
from backend.util.exceptions import DatabaseError
session = _make_session_with_messages(
ChatMessage(role="assistant", content="hi"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock()
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
side_effect=RuntimeError("db down"),
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=1)
),
)
with pytest.raises(DatabaseError):
await append_and_save_message(
session.session_id, ChatMessage(role="user", content="new msg")
)
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_invalidates_cache_on_cache_failure(
mocker: MockerFixture,
) -> None:
"""When cache_chat_session fails, invalidate_session_cache is called to avoid stale reads."""
session = _make_session_with_messages(
ChatMessage(role="assistant", content="hi"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock()
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=1)
),
)
mocker.patch(
"backend.copilot.model.cache_chat_session",
new_callable=mocker.AsyncMock,
side_effect=RuntimeError("redis write failed"),
)
mock_invalidate = mocker.patch(
"backend.copilot.model.invalidate_session_cache",
new_callable=mocker.AsyncMock,
)
result = await append_and_save_message(
session.session_id, ChatMessage(role="user", content="new msg")
)
# DB write succeeded, cache invalidation was called
mock_invalidate.assert_called_once_with(session.session_id)
assert result is not None
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_uses_db_when_redis_unavailable(
mocker: MockerFixture,
) -> None:
"""When get_redis_async raises, _get_session_lock yields False (degraded) and DB is read."""
session = _make_session_with_messages(
ChatMessage(role="assistant", content="hi"),
)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
side_effect=ConnectionError("redis down"),
)
mock_get_from_db = mocker.patch(
"backend.copilot.model._get_session_from_db",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=1)
),
)
mocker.patch(
"backend.copilot.model.cache_chat_session",
new_callable=mocker.AsyncMock,
)
new_msg = ChatMessage(role="user", content="new msg")
result = await append_and_save_message(session.session_id, new_msg)
mock_get_from_db.assert_called_once_with(session.session_id)
assert result is not None
@pytest.mark.asyncio(loop_scope="session")
async def test_append_and_save_message_lock_release_failure_is_ignored(
mocker: MockerFixture,
) -> None:
"""If lock.release() raises, the exception is swallowed (TTL will clean up)."""
session = _make_session_with_messages(
ChatMessage(role="assistant", content="hi"),
)
mock_redis_lock = mocker.AsyncMock()
mock_redis_lock.acquire = mocker.AsyncMock(return_value=True)
mock_redis_lock.release = mocker.AsyncMock(
side_effect=RuntimeError("release failed")
)
mock_redis_client = mocker.MagicMock()
mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock)
mocker.patch(
"backend.copilot.model.get_redis_async",
new_callable=mocker.AsyncMock,
return_value=mock_redis_client,
)
mocker.patch(
"backend.copilot.model.get_chat_session",
new_callable=mocker.AsyncMock,
return_value=session,
)
mocker.patch(
"backend.copilot.model._save_session_to_db",
new_callable=mocker.AsyncMock,
)
mocker.patch(
"backend.copilot.model.chat_db",
return_value=mocker.MagicMock(
get_next_sequence=mocker.AsyncMock(return_value=1)
),
)
mocker.patch(
"backend.copilot.model.cache_chat_session",
new_callable=mocker.AsyncMock,
)
new_msg = ChatMessage(role="user", content="new msg")
result = await append_and_save_message(session.session_id, new_msg)
assert result is not None

View File

@@ -125,7 +125,12 @@ config = ChatConfig()
class _SystemPromptPreset(SystemPromptPreset, total=False):
"""Extends SystemPromptPreset with fields added in claude-agent-sdk 0.1.59."""
"""Extends :class:`SystemPromptPreset` with ``exclude_dynamic_sections``.
The field was added to the upstream TypedDict in claude-agent-sdk 0.1.59.
Until the package is pinned to that version we declare it locally so Pyright
accepts the kwarg without a ``# type: ignore`` comment.
"""
exclude_dynamic_sections: NotRequired[bool]
@@ -893,7 +898,7 @@ def _write_cli_session_to_disk(
return False
def _read_cli_session_from_disk(
def read_cli_session_from_disk(
sdk_cwd: str,
session_id: str,
log_prefix: str,
@@ -973,7 +978,7 @@ def _read_cli_session_from_disk(
return stripped_bytes
def _process_cli_restore(
def process_cli_restore(
cli_restore: TranscriptDownload,
sdk_cwd: str,
session_id: str,
@@ -2489,9 +2494,7 @@ async def _restore_cli_session_for_turn(
# session path, so we validate BEFORE any disk write.
stripped = ""
if cli_restore is not None and sdk_cwd:
stripped, ok = _process_cli_restore(
cli_restore, sdk_cwd, session_id, log_prefix
)
stripped, ok = process_cli_restore(cli_restore, sdk_cwd, session_id, log_prefix)
if not ok:
result.transcript_covers_prefix = False
cli_restore = None
@@ -3636,7 +3639,7 @@ async def stream_chat_completion_sdk(
# this turn ran without --resume (restore failed or first T2+ on a new
# pod), the T1 session file at the expected path may still be present
# and should be re-uploaded so the next turn can resume from it.
# _read_cli_session_from_disk returns None when the file is absent, so
# read_cli_session_from_disk returns None when the file is absent, so
# this is always safe.
#
# Intentionally NOT gated on skip_transcript_upload: that flag is set
@@ -3665,7 +3668,7 @@ async def stream_chat_completion_sdk(
try:
# Read the CLI's native session file from disk (written by the CLI
# after the turn), then upload the bytes to GCS.
_cli_content = _read_cli_session_from_disk(
_cli_content = read_cli_session_from_disk(
sdk_cwd, session_id, log_prefix
)
if _cli_content:

View File

@@ -1371,7 +1371,7 @@ class TestStripStaleThinkingBlocks:
class TestProcessCliRestore:
"""``_process_cli_restore`` validates, strips, and writes CLI session to disk."""
"""``process_cli_restore`` validates, strips, and writes CLI session to disk."""
def test_writes_stripped_bytes_not_raw(self, tmp_path):
"""Stripped bytes (not raw bytes) must be written to disk for --resume."""
@@ -1380,7 +1380,7 @@ class TestProcessCliRestore:
from pathlib import Path
from unittest.mock import patch
from backend.copilot.sdk.service import _process_cli_restore
from backend.copilot.sdk.service import process_cli_restore
from backend.copilot.transcript import TranscriptDownload
session_id = "12345678-0000-0000-0000-abcdef000001"
@@ -1406,7 +1406,7 @@ class TestProcessCliRestore:
return_value=projects_base_dir,
),
):
stripped_str, ok = _process_cli_restore(
stripped_str, ok = process_cli_restore(
restore, sdk_cwd, session_id, "[Test]"
)
@@ -1433,7 +1433,7 @@ class TestProcessCliRestore:
def test_invalid_content_returns_false(self):
"""Content that fails validation after strip returns (empty, False)."""
from backend.copilot.sdk.service import _process_cli_restore
from backend.copilot.sdk.service import process_cli_restore
from backend.copilot.transcript import TranscriptDownload
# A single progress-only entry — stripped result will be empty/invalid
@@ -1442,7 +1442,7 @@ class TestProcessCliRestore:
content=raw_content.encode("utf-8"), message_count=1, mode="sdk"
)
stripped_str, ok = _process_cli_restore(
stripped_str, ok = process_cli_restore(
restore,
"/tmp/nonexistent-sdk-cwd",
"12345678-0000-0000-0000-000000000099",
@@ -1454,7 +1454,7 @@ class TestProcessCliRestore:
class TestReadCliSessionFromDisk:
"""``_read_cli_session_from_disk`` reads, strips, and optionally writes back the session."""
"""``read_cli_session_from_disk`` reads, strips, and optionally writes back the session."""
def _build_session_file(self, tmp_path, session_id: str):
"""Build the session file path inside tmp_path using the same encoding as cli_session_path."""
@@ -1472,7 +1472,7 @@ class TestReadCliSessionFromDisk:
"""Non-UTF-8 bytes trigger UnicodeDecodeError — returns raw bytes (upload-raw fallback)."""
from unittest.mock import patch
from backend.copilot.sdk.service import _read_cli_session_from_disk
from backend.copilot.sdk.service import read_cli_session_from_disk
session_id = "12345678-0000-0000-0000-aabbccdd0001"
projects_base_dir = str(tmp_path)
@@ -1491,7 +1491,7 @@ class TestReadCliSessionFromDisk:
return_value=projects_base_dir,
),
):
result = _read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
result = read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
# UnicodeDecodeError path returns the raw bytes (upload-raw fallback)
assert result == b"\xff\xfe invalid utf-8\n"
@@ -1500,7 +1500,7 @@ class TestReadCliSessionFromDisk:
"""OSError on write-back returns stripped bytes for GCS upload (not raw)."""
from unittest.mock import patch
from backend.copilot.sdk.service import _read_cli_session_from_disk
from backend.copilot.sdk.service import read_cli_session_from_disk
session_id = "12345678-0000-0000-0000-aabbccdd0002"
projects_base_dir = str(tmp_path)
@@ -1527,7 +1527,7 @@ class TestReadCliSessionFromDisk:
return_value=projects_base_dir,
),
):
result = _read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
result = read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
finally:
session_file.chmod(0o644)

View File

@@ -423,20 +423,33 @@ async def subscribe_to_session(
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
)
# RACE CONDITION FIX: If session not found, retry once after small delay
# This handles the case where subscribe_to_session is called immediately
# after create_session but before Redis propagates the write
# RACE CONDITION FIX: If session not found, retry with backoff.
# Duplicate requests skip create_session and subscribe immediately; the
# original request's create_session (a Redis hset) may not have completed
# yet. 3 × 100ms gives a 300ms window which covers DB-write latency on the
# original request before the hset even starts.
if not meta:
logger.warning(
"[TIMING] Session not found on first attempt, retrying after 50ms delay",
extra={"json_fields": {**log_meta}},
)
await asyncio.sleep(0.05) # 50ms
meta = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
_max_retries = 3
_retry_delay = 0.1 # 100ms per attempt
for attempt in range(_max_retries):
logger.warning(
f"[TIMING] Session not found (attempt {attempt + 1}/{_max_retries}), "
f"retrying after {int(_retry_delay * 1000)}ms",
extra={"json_fields": {**log_meta, "attempt": attempt + 1}},
)
await asyncio.sleep(_retry_delay)
meta = await redis.hgetall(meta_key) # type: ignore[misc]
if meta:
logger.info(
f"[TIMING] Session found after {attempt + 1} retries",
extra={"json_fields": {**log_meta, "attempts": attempt + 1}},
)
break
else:
elapsed = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] Session still not found in Redis after retry ({elapsed:.1f}ms total)",
f"[TIMING] Session still not found in Redis after {_max_retries} retries "
f"({elapsed:.1f}ms total)",
extra={
"json_fields": {
**log_meta,
@@ -446,10 +459,6 @@ async def subscribe_to_session(
},
)
return None
logger.info(
"[TIMING] Session found after retry",
extra={"json_fields": {**log_meta}},
)
# Note: Redis client uses decode_responses=True, so keys are strings
session_status = meta.get("status", "")
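The retry logic above polls Redis up to 3 times at 100ms intervals to ride out the window between `create_session`'s hset and a duplicate request's subscribe. The same fixed-delay retry shape, reduced to a standalone sketch (the fetcher below is a hypothetical stand-in for the `hgetall` call):

```python
import asyncio
from typing import Awaitable, Callable, Optional, TypeVar

T = TypeVar("T")

async def retry_fixed(
    fetch: Callable[[], Awaitable[Optional[T]]],
    attempts: int = 3,
    delay: float = 0.1,
) -> Optional[T]:
    """Retry a read until it returns a truthy value, sleeping `delay` between tries."""
    result = await fetch()
    for _ in range(attempts):
        if result:
            return result
        await asyncio.sleep(delay)
        result = await fetch()
    return result if result else None

# Simulate a replica that only sees the hset on the third poll:
calls = {"n": 0}

async def fetch_meta() -> Optional[dict]:
    calls["n"] += 1
    return {"status": "active"} if calls["n"] >= 3 else None

meta = asyncio.run(retry_fixed(fetch_meta, attempts=3, delay=0.001))
```

With 3 attempts at 100ms this yields the ~300ms tolerance window described in the comment; tuning either knob trades latency against how much write-propagation lag is absorbed.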

View File

@@ -880,31 +880,12 @@ class TestUploadCliSession:
assert meta_content["mode"] == "baseline"
assert meta_content["message_count"] == 4
def test_strips_session_before_upload_and_writes_back(self, tmp_path):
"""Strippable entries (progress, thinking blocks) are removed before upload.
The stripped content is written back to disk (so same-pod turns benefit)
and the smaller bytes are uploaded to GCS.
"""
import asyncio
import os
import re
from unittest.mock import AsyncMock, patch
from .transcript import _sanitize_id, upload_cli_session
projects_base = str(tmp_path)
session_id = "12345678-0000-0000-0000-000000000010"
sdk_cwd = str(tmp_path)
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
session_dir = tmp_path / encoded_cwd
session_dir.mkdir(parents=True, exist_ok=True)
session_file = session_dir / f"{_sanitize_id(session_id)}.jsonl"
# A CLI session with a progress entry (strippable) and a real assistant message.
def test_strips_session_before_upload_and_writes_back(self):
"""strip_for_upload removes progress entries and returns smaller content."""
import json
from .transcript import strip_for_upload
progress_entry = {
"type": "progress",
"uuid": "p1",
@@ -930,64 +911,22 @@ class TestUploadCliSession:
+ json.dumps(asst_entry)
+ "\n"
)
raw_bytes = raw_content.encode("utf-8")
session_file.write_bytes(raw_bytes)
mock_storage = AsyncMock()
stripped = strip_for_upload(raw_content)
with (
patch(
"backend.copilot.transcript._projects_base",
return_value=projects_base,
),
patch(
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
),
):
asyncio.run(
upload_cli_session(
user_id="user-1",
session_id=session_id,
sdk_cwd=sdk_cwd,
)
)
# Upload should have been called with stripped bytes (no progress entry).
mock_storage.store.assert_called_once()
stored_content: bytes = mock_storage.store.call_args.kwargs["content"]
stored_lines = stored_content.decode("utf-8").strip().split("\n")
stored_lines = stripped.strip().split("\n")
stored_types = [json.loads(line).get("type") for line in stored_lines]
assert "progress" not in stored_types
assert "user" in stored_types
assert "assistant" in stored_types
# Stripped bytes should be smaller than raw.
assert len(stored_content) < len(raw_bytes)
# File on disk should also be the stripped version.
disk_content = session_file.read_bytes()
assert disk_content == stored_content
assert len(stripped.encode()) < len(raw_content.encode())
def test_strips_stale_thinking_blocks_before_upload(self, tmp_path):
"""Thinking blocks in non-last assistant turns are stripped to reduce size."""
import asyncio
def test_strips_stale_thinking_blocks_before_upload(self):
"""strip_for_upload removes thinking blocks from non-last assistant turns."""
import json
import os
import re
from unittest.mock import AsyncMock, patch
from .transcript import _sanitize_id, upload_cli_session
from .transcript import strip_for_upload
projects_base = str(tmp_path)
session_id = "12345678-0000-0000-0000-000000000011"
sdk_cwd = str(tmp_path)
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
session_dir = tmp_path / encoded_cwd
session_dir.mkdir(parents=True, exist_ok=True)
session_file = session_dir / f"{_sanitize_id(session_id)}.jsonl"
# Two turns: first assistant has thinking block (stale), second doesn't.
u1 = {
"type": "user",
"uuid": "u1",
@@ -1032,32 +971,10 @@ class TestUploadCliSession:
+ json.dumps(a2_no_thinking)
+ "\n"
)
raw_bytes = raw_content.encode("utf-8")
session_file.write_bytes(raw_bytes)
mock_storage = AsyncMock()
stripped = strip_for_upload(raw_content)
with (
patch(
"backend.copilot.transcript._projects_base",
return_value=projects_base,
),
patch(
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
),
):
asyncio.run(
upload_cli_session(
user_id="user-1",
session_id=session_id,
sdk_cwd=sdk_cwd,
)
)
stored_content: bytes = mock_storage.store.call_args.kwargs["content"]
stored_lines = stored_content.decode("utf-8").strip().split("\n")
stored_lines = stripped.strip().split("\n")
# a1 should have its thinking block stripped (it's not the last assistant turn).
a1_stored = json.loads(stored_lines[1])
@@ -1073,9 +990,6 @@ class TestUploadCliSession:
a2_stored = json.loads(stored_lines[3])
assert a2_stored["message"]["content"] == [{"type": "text", "text": "answer2"}]
# Stripped bytes smaller than raw.
assert len(stored_content) < len(raw_bytes)
class TestRestoreCliSession:
def test_returns_none_when_file_not_found_in_storage(self):

View File

@@ -143,6 +143,8 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.GROK_4: 9,
LlmModel.GROK_4_FAST: 1,
LlmModel.GROK_4_1_FAST: 1,
LlmModel.GROK_4_20: 5,
LlmModel.GROK_4_20_MULTI_AGENT: 5,
LlmModel.GROK_CODE_FAST_1: 1,
LlmModel.KIMI_K2: 1,
LlmModel.QWEN3_235B_A22B_THINKING: 1,

View File

@@ -1,10 +1,13 @@
import asyncio
import logging
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, cast
import stripe
from fastapi.concurrency import run_in_threadpool
from prisma.enums import (
CreditRefundRequestStatus,
CreditTransactionType,
@@ -31,6 +34,7 @@ from backend.data.model import (
from backend.data.notifications import NotificationEventModel, RefundRequestData
from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.notifications.notifications import queue_notification_async
from backend.util.cache import cached
from backend.util.exceptions import InsufficientBalanceError
from backend.util.feature_flag import Flag, get_feature_flag_value, is_feature_enabled
from backend.util.json import SafeJson, dumps
@@ -432,7 +436,7 @@ class UserCreditBase(ABC):
current_balance, _ = await self._get_credits(user_id)
if current_balance >= ceiling_balance:
raise ValueError(
f"You already have enough balance of ${current_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
f"You already have enough balance of ${current_balance / 100}, top-up is not required when you already have at least ${ceiling_balance / 100}"
)
# Single unified atomic operation for all transaction types using UserBalance
@@ -571,7 +575,7 @@ class UserCreditBase(ABC):
if amount < 0 and fail_insufficient_credits:
current_balance, _ = await self._get_credits(user_id)
raise InsufficientBalanceError(
message=f"Insufficient balance of ${current_balance/100}, where this will cost ${abs(amount)/100}",
message=f"Insufficient balance of ${current_balance / 100}, where this will cost ${abs(amount) / 100}",
user_id=user_id,
balance=current_balance,
amount=amount,
@@ -582,7 +586,6 @@ class UserCreditBase(ABC):
class UserCredit(UserCreditBase):
async def _send_refund_notification(
self,
notification_request: RefundRequestData,
@@ -734,7 +737,7 @@ class UserCredit(UserCreditBase):
)
if request.amount <= 0 or request.amount > transaction.amount:
raise AssertionError(
f"Invalid amount to deduct ${request.amount/100} from ${transaction.amount/100} top-up"
f"Invalid amount to deduct ${request.amount / 100} from ${transaction.amount / 100} top-up"
)
balance, _ = await self._add_transaction(
@@ -788,12 +791,12 @@ class UserCredit(UserCreditBase):
# If the user has enough balance, just let them win the dispute.
if balance - amount >= settings.config.refund_credit_tolerance_threshold:
logger.warning(f"Accepting dispute from {user_id} for ${amount/100}")
logger.warning(f"Accepting dispute from {user_id} for ${amount / 100}")
dispute.close()
return
logger.warning(
f"Adding extra info for dispute from {user_id} for ${amount/100}"
f"Adding extra info for dispute from {user_id} for ${amount / 100}"
)
# Retrieve recent transaction history to support our evidence.
# This provides a concise timeline that shows service usage and proper credit application.
@@ -1237,14 +1240,23 @@ async def get_stripe_customer_id(user_id: str) -> str:
if user.stripe_customer_id:
return user.stripe_customer_id
customer = stripe.Customer.create(
# Race protection: two concurrent calls (e.g. user double-clicks "Upgrade",
# or any retried request) would each pass the check above and create their
# own Stripe Customer, leaving an orphaned billable customer in Stripe.
# Pass an idempotency_key so Stripe collapses concurrent + retried calls
# into the same Customer object server-side. The 24h Stripe idempotency
# window comfortably covers any realistic in-flight retry scenario.
customer = await run_in_threadpool(
stripe.Customer.create,
name=user.name or "",
email=user.email,
metadata={"user_id": user_id},
idempotency_key=f"customer-create-{user_id}",
)
await User.prisma().update(
where={"id": user_id}, data={"stripeCustomerId": customer.id}
)
get_user_by_id.cache_delete(user_id)
return customer.id
@@ -1263,23 +1275,203 @@ async def set_subscription_tier(user_id: str, tier: SubscriptionTier) -> None:
data={"subscriptionTier": tier},
)
get_user_by_id.cache_delete(user_id)
# Also invalidate the rate-limit tier cache so CoPilot picks up the new
# tier immediately rather than waiting up to 5 minutes for the TTL to expire.
from backend.copilot.rate_limit import get_user_tier  # local import avoids a circular import
get_user_tier.cache_delete(user_id)  # type: ignore[attr-defined]
async def cancel_stripe_subscription(user_id: str) -> None:
"""Cancel all active Stripe subscriptions for a user (called on downgrade to FREE)."""
customer_id = await get_stripe_customer_id(user_id)
subscriptions = stripe.Subscription.list(
customer=customer_id, status="active", limit=10
)
for sub in subscriptions.auto_paging_iter():
try:
stripe.Subscription.cancel(sub["id"])
except stripe.StripeError:
logger.warning(
"cancel_stripe_subscription: failed to cancel sub %s for user %s",
sub["id"],
user_id,
async def _cancel_customer_subscriptions(
customer_id: str,
exclude_sub_id: str | None = None,
at_period_end: bool = False,
) -> int:
"""Cancel all billable Stripe subscriptions for a customer, optionally excluding one.
Cancels both ``active`` and ``trialing`` subscriptions, since trialing subs will
start billing once the trial ends and must be cleaned up on downgrade/upgrade to
avoid double-charging or charging users who intended to cancel.
When ``at_period_end=True``, schedules cancellation at the end of the current
billing period instead of cancelling immediately — the user keeps their tier
until the period ends, then ``customer.subscription.deleted`` fires and the
webhook downgrades them to FREE.
Wraps every synchronous Stripe SDK call with run_in_threadpool so the async event
loop is never blocked. Raises stripe.StripeError on list/cancel failure so callers
that need strict consistency can react; cleanup callers can catch and log instead.
Returns the number of subscriptions cancelled/scheduled for cancellation.
"""
# Query active and trialing separately; Stripe's list API accepts a single status
# filter at a time (no OR), and we explicitly want to skip canceled/incomplete/
# past_due subs rather than filter them out client-side via status="all".
seen_ids: set[str] = set()
for status in ("active", "trialing"):
subscriptions = await run_in_threadpool(
stripe.Subscription.list, customer=customer_id, status=status, limit=10
)
# Iterate only the first page (up to 10); avoid auto_paging_iter which would
# trigger additional sync HTTP calls inside the event loop.
if subscriptions.has_more:
logger.error(
"_cancel_customer_subscriptions: customer %s has more than 10 %s"
" subscriptions — only the first page was processed; remaining"
" subscriptions were NOT cancelled",
customer_id,
status,
)
for sub in subscriptions.data:
sub_id = sub["id"]
if exclude_sub_id and sub_id == exclude_sub_id:
continue
if sub_id in seen_ids:
continue
seen_ids.add(sub_id)
if at_period_end:
await run_in_threadpool(
stripe.Subscription.modify, sub_id, cancel_at_period_end=True
)
else:
await run_in_threadpool(stripe.Subscription.cancel, sub_id)
return len(seen_ids)
async def cancel_stripe_subscription(user_id: str) -> bool:
"""Schedule cancellation of all active/trialing Stripe subscriptions at period end.
The subscription stays active until the end of the billing period so the user
keeps their tier for the time they already paid for. The ``customer.subscription.deleted``
webhook fires at period end and downgrades the DB tier to FREE.
Returns True if at least one subscription was found and scheduled for cancellation,
False if the customer had no active/trialing subscriptions (e.g., admin-granted tier
with no associated Stripe subscription). When False, the caller should update the
DB tier directly since no webhook will fire to do it.
Raises stripe.StripeError if any modification fails, so the caller can avoid
updating the DB tier when Stripe is inconsistent.
"""
# Guard: only proceed if the user already has a Stripe customer ID. Calling
# get_stripe_customer_id for a user who has never had a paid subscription would
# create an orphaned, potentially-billable Stripe Customer object — we avoid that
# by returning False early so the caller can downgrade the DB tier directly.
user = await get_user_by_id(user_id)
if not user.stripe_customer_id:
return False
customer_id = user.stripe_customer_id
try:
cancelled_count = await _cancel_customer_subscriptions(
customer_id, at_period_end=True
)
return cancelled_count > 0
except stripe.StripeError:
logger.warning(
"cancel_stripe_subscription: Stripe error while cancelling subs for user %s",
user_id,
)
raise
async def get_proration_credit_cents(user_id: str, monthly_cost_cents: int) -> int:
"""Return the prorated credit (in cents) the user would receive if they upgraded now.
Fetches the user's active Stripe subscription to determine how many seconds
remain in the current billing period, then calculates the unused portion of
the monthly cost. Returns 0 for FREE/ENTERPRISE users or when no active sub
is found.
"""
if monthly_cost_cents <= 0:
return 0
# Guard: only query Stripe if the user already has a customer ID. Admin-granted
# paid tiers have no Stripe record; calling get_stripe_customer_id would create an
# orphaned customer on every billing-page load for those users.
user = await get_user_by_id(user_id)
if not user.stripe_customer_id:
return 0
try:
customer_id = user.stripe_customer_id
subscriptions = await run_in_threadpool(
stripe.Subscription.list, customer=customer_id, status="active", limit=1
)
if not subscriptions.data:
return 0
sub = subscriptions.data[0]
period_start: int = sub["current_period_start"]
period_end: int = sub["current_period_end"]
now = int(time.time())
total_seconds = period_end - period_start
remaining_seconds = max(period_end - now, 0)
if total_seconds <= 0:
return 0
return int(monthly_cost_cents * remaining_seconds / total_seconds)
except Exception:
logger.warning(
"get_proration_credit_cents: failed to compute proration for user %s",
user_id,
)
return 0
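The proration arithmetic used above can be checked in isolation. This is a standalone sketch of the same formula with a hypothetical helper name, using raw Unix timestamps in place of the Stripe subscription fields:

```python
def prorated_credit_cents(
    monthly_cost_cents: int, period_start: int, period_end: int, now: int
) -> int:
    # Credit = monthly cost * (unused fraction of the billing period),
    # floored to whole cents; degenerate inputs yield no credit.
    total_seconds = period_end - period_start
    remaining_seconds = max(period_end - now, 0)
    if monthly_cost_cents <= 0 or total_seconds <= 0:
        return 0
    return int(monthly_cost_cents * remaining_seconds / total_seconds)

# Halfway through a 30-day period on a $20.00/month plan -> $10.00 credit.
print(prorated_credit_cents(2000, 0, 30 * 86400, 15 * 86400))  # 1000
```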
async def modify_stripe_subscription_for_tier(
user_id: str, tier: SubscriptionTier
) -> bool:
"""Modify an existing Stripe subscription to a new paid tier using proration.
For paid→paid tier changes (e.g. PRO↔BUSINESS), modifying the existing
subscription is preferable to cancelling + creating a new one via Checkout:
Stripe handles proration automatically, crediting unused time on the old plan
and charging the pro-rated amount for the new plan in the same billing cycle.
Returns:
True — a subscription was found and modified successfully.
False — no active/trialing subscription exists (e.g. admin-granted tier or
first-time paid signup); caller should fall back to Checkout.
Raises stripe.StripeError on API failures so callers can propagate a 502.
Raises ValueError when no Stripe price ID is configured for the tier.
"""
price_id = await get_subscription_price_id(tier)
if not price_id:
raise ValueError(f"No Stripe price ID configured for tier {tier}")
# Guard: only proceed if the user already has a Stripe customer ID. Calling
# get_stripe_customer_id for a user with no Stripe record (e.g. admin-granted tier)
# would create an orphaned customer object if the subsequent Subscription.list call
# fails. Return False early so the API layer falls back to Checkout instead.
user = await get_user_by_id(user_id)
if not user.stripe_customer_id:
return False
customer_id = user.stripe_customer_id
for status in ("active", "trialing"):
subscriptions = await run_in_threadpool(
stripe.Subscription.list, customer=customer_id, status=status, limit=1
)
if not subscriptions.data:
continue
sub = subscriptions.data[0]
sub_id = sub["id"]
items = sub.get("items", {}).get("data", [])
if not items:
continue
item_id = items[0]["id"]
await run_in_threadpool(
stripe.Subscription.modify,
sub_id,
items=[{"id": item_id, "price": price_id}],
proration_behavior="create_prorations",
)
logger.info(
"modify_stripe_subscription_for_tier: modified sub %s for user %s"
" to tier %s",
sub_id,
user_id,
tier,
)
return True
return False
async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
@@ -1291,8 +1483,19 @@ async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
return AutoTopUpConfig.model_validate(user.top_up_config)
@cached(ttl_seconds=60, maxsize=8, cache_none=False)
async def get_subscription_price_id(tier: SubscriptionTier) -> str | None:
"""Return Stripe Price ID for a tier from LaunchDarkly. None = not configured."""
"""Return Stripe Price ID for a tier from LaunchDarkly, cached for 60 seconds.
Price IDs are LaunchDarkly flag values that change only at deploy time.
Caching for 60 seconds avoids hitting the LD SDK on every webhook delivery
and every GET /credits/subscription page load (called 2x per request).
``cache_none=False`` prevents a transient LD failure from caching ``None``
and blocking subscription upgrades for the full 60-second TTL window.
A tier with no configured flag (FREE, ENTERPRISE) returns ``None`` from an
O(1) dict lookup before hitting LD, so the extra LD call is never made.
"""
flag_map = {
SubscriptionTier.PRO: Flag.STRIPE_PRICE_PRO,
SubscriptionTier.BUSINESS: Flag.STRIPE_PRICE_BUSINESS,
@@ -1300,7 +1503,7 @@ async def get_subscription_price_id(tier: SubscriptionTier) -> str | None:
flag = flag_map.get(tier)
if flag is None:
return None
price_id = await get_feature_flag_value(flag.value, user_id="", default="")
price_id = await get_feature_flag_value(flag.value, user_id="system", default="")
return price_id if isinstance(price_id, str) and price_id else None
@@ -1315,7 +1518,8 @@ async def create_subscription_checkout(
if not price_id:
raise ValueError(f"Subscription not available for tier {tier.value}")
customer_id = await get_stripe_customer_id(user_id)
session = stripe.checkout.Session.create(
session = await run_in_threadpool(
stripe.checkout.Session.create,
customer=customer_id,
mode="subscription",
line_items=[{"price": price_id, "quantity": 1}],
@@ -1323,26 +1527,111 @@ async def create_subscription_checkout(
cancel_url=cancel_url,
subscription_data={"metadata": {"user_id": user_id, "tier": tier.value}},
)
return session.url or ""
if not session.url:
# An empty checkout URL for a paid upgrade is always an error; surfacing it
# as ValueError means the API handler returns 422 instead of silently
# redirecting the client to an empty URL.
raise ValueError("Stripe did not return a checkout session URL")
return session.url
async def _cleanup_stale_subscriptions(customer_id: str, new_sub_id: str) -> None:
"""Best-effort cancel of any active subs for the customer other than new_sub_id.
Called from the webhook handler after a new subscription becomes active. Failures
are logged but not raised so a transient Stripe error doesn't crash the webhook —
a periodic reconciliation job is the intended backstop for persistent drift.
NOTE: until that reconcile job lands, a failure here means the user is silently
billed for two simultaneous subscriptions. The error log below is intentionally
`logger.exception` so it surfaces in Sentry with the customer/sub IDs needed to
manually reconcile, and the metric `stripe_stale_subscription_cleanup_failed`
is bumped so on-call can alert on persistent drift.
TODO(#stripe-reconcile-job): replace this best-effort cleanup with a periodic
reconciliation job that queries Stripe for customers with >1 active sub.
"""
try:
await _cancel_customer_subscriptions(customer_id, exclude_sub_id=new_sub_id)
except stripe.StripeError:
# Use exception() (not warning) so this surfaces as an error in Sentry —
# any failure here means a paid-to-paid upgrade may have left the user
# with two simultaneous active subscriptions.
logger.exception(
"stripe_stale_subscription_cleanup_failed: customer=%s new_sub=%s"
" user may be billed for two simultaneous subscriptions; manual"
" reconciliation required",
customer_id,
new_sub_id,
)
async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
"""Update User.subscriptionTier from a Stripe subscription object."""
customer_id = stripe_subscription["customer"]
"""Update User.subscriptionTier from a Stripe subscription object.
Expected shape of stripe_subscription (subset of Stripe's Subscription object):
customer: str — Stripe customer ID
status: str — "active" | "trialing" | "canceled" | ...
id: str — Stripe subscription ID
items.data[].price.id: str — Stripe price ID identifying the tier
"""
customer_id = stripe_subscription.get("customer")
if not customer_id:
logger.warning(
"sync_subscription_from_stripe: missing 'customer' field in event, "
"skipping (keys: %s)",
list(stripe_subscription.keys()),
)
return
user = await User.prisma().find_first(where={"stripeCustomerId": customer_id})
if not user:
logger.warning(
"sync_subscription_from_stripe: no user for customer %s", customer_id
)
return
# Cross-check: if the subscription carries a metadata.user_id (set during
# Checkout Session creation), verify it matches the user we found via
# stripeCustomerId. A mismatch indicates a customer↔user mapping
# inconsistency — updating the wrong user's tier would be a data-corruption
# bug, so we log loudly and bail out. Absence of metadata.user_id (e.g.
# subscriptions created outside the Checkout flow) is not an error — we
# simply skip the check and proceed with the customer-ID-based lookup.
metadata = stripe_subscription.get("metadata") or {}
metadata_user_id = metadata.get("user_id") if isinstance(metadata, dict) else None
if metadata_user_id and metadata_user_id != user.id:
logger.error(
"sync_subscription_from_stripe: metadata.user_id=%s does not match"
" user.id=%s found via stripeCustomerId=%s — refusing to update tier"
" to avoid corrupting the wrong user's subscription state",
metadata_user_id,
user.id,
customer_id,
)
return
# ENTERPRISE tiers are admin-managed. Never let a Stripe webhook flip an
# ENTERPRISE user to a different tier — if a user on ENTERPRISE somehow has
# a self-service Stripe sub, it's a data-consistency issue for an operator,
# not something the webhook should automatically "fix".
current_tier = user.subscriptionTier or SubscriptionTier.FREE
if current_tier == SubscriptionTier.ENTERPRISE:
logger.warning(
"sync_subscription_from_stripe: refusing to overwrite ENTERPRISE tier"
" for user %s (customer %s); event status=%s",
user.id,
customer_id,
stripe_subscription.get("status", ""),
)
return
status = stripe_subscription.get("status", "")
new_sub_id = stripe_subscription.get("id", "")
if status in ("active", "trialing"):
price_id = ""
items = stripe_subscription.get("items", {}).get("data", [])
if items:
price_id = items[0].get("price", {}).get("id", "")
pro_price = await get_subscription_price_id(SubscriptionTier.PRO)
biz_price = await get_subscription_price_id(SubscriptionTier.BUSINESS)
pro_price, biz_price = await asyncio.gather(
get_subscription_price_id(SubscriptionTier.PRO),
get_subscription_price_id(SubscriptionTier.BUSINESS),
)
if price_id and pro_price and price_id == pro_price:
tier = SubscriptionTier.PRO
elif price_id and biz_price and price_id == biz_price:
@@ -1359,10 +1648,206 @@ async def sync_subscription_from_stripe(stripe_subscription: dict) -> None:
)
return
else:
# A subscription was cancelled or ended. DO NOT unconditionally downgrade
# to FREE — Stripe does not guarantee webhook delivery order, so a
# `customer.subscription.deleted` for the OLD sub can arrive after we've
# already processed `customer.subscription.created` for a new paid sub.
# Ask Stripe whether any OTHER active/trialing subs exist for this
# customer; if they do, keep the user's current tier (the other sub's
# own event will/has already set the correct tier).
try:
other_subs_active, other_subs_trialing = await asyncio.gather(
run_in_threadpool(
stripe.Subscription.list,
customer=customer_id,
status="active",
limit=10,
),
run_in_threadpool(
stripe.Subscription.list,
customer=customer_id,
status="trialing",
limit=10,
),
)
except stripe.StripeError:
logger.warning(
"sync_subscription_from_stripe: could not verify other active"
" subs for customer %s on cancel event %s; preserving current"
" tier to avoid an unsafe downgrade",
customer_id,
new_sub_id,
)
return
# Filter out the cancelled subscription to check if other active subs
# exist. When new_sub_id is empty (malformed event with no 'id' field),
# we cannot safely exclude any sub — preserve current tier to avoid
# an unsafe downgrade on a malformed webhook payload.
if not new_sub_id:
logger.warning(
"sync_subscription_from_stripe: cancel event missing 'id' field"
" for customer %s; preserving current tier",
customer_id,
)
return
other_active_ids = {sub["id"] for sub in other_subs_active.data} - {new_sub_id}
other_trialing_ids = {sub["id"] for sub in other_subs_trialing.data} - {
new_sub_id
}
still_has_active_sub = bool(other_active_ids or other_trialing_ids)
if still_has_active_sub:
logger.info(
"sync_subscription_from_stripe: sub %s cancelled but customer %s"
" still has another active sub; keeping tier %s",
new_sub_id,
customer_id,
current_tier.value,
)
return
tier = SubscriptionTier.FREE
# Idempotency: Stripe retries webhooks on delivery failure, and several event
# types map to the same final tier. Skip the DB write + cache invalidation
# when the tier is already correct to avoid redundant writes on replay.
if current_tier == tier:
return
# When a new subscription becomes active (e.g. paid-to-paid tier upgrade
# via a fresh Checkout Session), cancel any OTHER active subscriptions for
# the same customer so the user isn't billed twice. We do this in the
# webhook rather than the API handler so that abandoning the checkout
# doesn't leave the user without a subscription.
# IMPORTANT: this runs AFTER the idempotency check above so that webhook
# replays for an already-applied event do NOT trigger another cleanup round
# (which could otherwise cancel a legitimately new subscription the user
# signed up for between the original event and its replay).
if status in ("active", "trialing") and new_sub_id:
# NOTE: paid-to-paid upgrade race (e.g. PRO → BUSINESS):
# _cleanup_stale_subscriptions cancels the old PRO sub before
# set_subscription_tier writes BUSINESS to the DB. If Stripe delivers
# the PRO `customer.subscription.deleted` event concurrently and it
# processes after the PRO cancel but before set_subscription_tier
# commits, the user could momentarily appear as FREE in the DB.
# This window is very short in practice (two sequential awaits),
# but is a known limitation of the current webhook-driven approach.
# A future improvement would be to write the new tier first, then
# cancel the old sub.
await _cleanup_stale_subscriptions(customer_id, new_sub_id)
await set_subscription_tier(user.id, tier)
async def handle_subscription_payment_failure(invoice: dict) -> None:
"""Handle a failed Stripe subscription payment.
Tries to cover the invoice amount from the user's credit balance.
- Balance sufficient → deduct from balance, then pay the Stripe invoice so
Stripe stops retrying it. The sub stays intact and the user keeps their tier.
- Balance insufficient → cancel Stripe sub immediately, downgrade to FREE.
Cancelling here avoids further Stripe retries on an invoice we cannot cover.
"""
customer_id = invoice.get("customer")
if not customer_id:
logger.warning(
"handle_subscription_payment_failure: missing customer in invoice; skipping"
)
return
user = await User.prisma().find_first(where={"stripeCustomerId": customer_id})
if not user:
logger.warning(
"handle_subscription_payment_failure: no user found for customer %s",
customer_id,
)
return
current_tier = user.subscriptionTier or SubscriptionTier.FREE
if current_tier == SubscriptionTier.ENTERPRISE:
logger.warning(
"handle_subscription_payment_failure: skipping ENTERPRISE user %s"
" (customer %s) — tier is admin-managed",
user.id,
customer_id,
)
return
amount_due: int = invoice.get("amount_due", 0)
sub_id: str = invoice.get("subscription", "")
invoice_id: str = invoice.get("id", "")
if amount_due <= 0:
logger.info(
"handle_subscription_payment_failure: amount_due=%d for user %s;"
" nothing to deduct",
amount_due,
user.id,
)
return
credit_model = UserCredit()
try:
await credit_model._add_transaction(
user_id=user.id,
amount=-amount_due,
transaction_type=CreditTransactionType.SUBSCRIPTION,
fail_insufficient_credits=True,
# Use invoice_id as the idempotency key so that Stripe webhook retries
# (e.g. on a transient stripe.Invoice.pay failure) do not double-charge.
transaction_key=invoice_id or None,
metadata=SafeJson(
{
"stripe_customer_id": customer_id,
"stripe_subscription_id": sub_id,
"reason": "subscription_payment_failure_covered_by_balance",
}
),
)
# Balance covered the invoice. Pay the Stripe invoice so Stripe's dunning
# system stops retrying it — without this call Stripe would retry automatically
# and re-trigger this webhook, causing double-deductions each retry cycle.
if invoice_id:
try:
await run_in_threadpool(stripe.Invoice.pay, invoice_id)
except stripe.StripeError:
logger.warning(
"handle_subscription_payment_failure: balance deducted for user"
" %s but failed to mark invoice %s as paid; Stripe may retry",
user.id,
invoice_id,
)
logger.info(
"handle_subscription_payment_failure: deducted %d cents from balance"
" for user %s; Stripe invoice %s paid, sub %s intact, tier preserved",
amount_due,
user.id,
invoice_id,
sub_id,
)
except InsufficientBalanceError:
# Balance insufficient — cancel Stripe subscription first, then downgrade DB.
# Order matters: if we downgrade the DB first and the Stripe cancel fails, the
# user is permanently stuck on FREE while Stripe continues billing them.
# Cancelling Stripe first is safe: if the DB write then fails, the webhook
# customer.subscription.deleted will fire and correct the tier eventually.
logger.info(
"handle_subscription_payment_failure: insufficient balance for user %s;"
" cancelling Stripe sub %s then downgrading to FREE",
user.id,
sub_id,
)
try:
await _cancel_customer_subscriptions(customer_id)
except stripe.StripeError:
logger.warning(
"handle_subscription_payment_failure: failed to cancel Stripe sub %s"
" for user %s (customer %s); skipping tier downgrade to avoid"
" inconsistency — Stripe may continue retrying the invoice",
sub_id,
user.id,
customer_id,
)
return
await set_subscription_tier(user.id, SubscriptionTier.FREE)
async def admin_get_user_history(
page: int = 1,
page_size: int = 20,


@@ -73,6 +73,31 @@ def _get_redis() -> Redis:
return r
class _MissingType:
"""Singleton sentinel type — distinct from ``None`` (a valid cached value).
Using a dedicated class (instead of ``Any = object()``) lets mypy prove
that comparisons ``result is _MISSING`` narrow the type correctly and
prevents accidental use of the sentinel where a real value is expected.
"""
_instance: "_MissingType | None" = None
def __new__(cls) -> "_MissingType":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __repr__(self) -> str:
return "<MISSING>"
# Sentinel returned by ``_get_from_memory`` / ``_get_from_redis`` to mean
# "no entry exists" — distinct from a cached ``None`` value, which is a
# valid result for callers that opt into caching it.
_MISSING = _MissingType()
@dataclass
class CachedValue:
"""Wrapper for cached values with timestamp to avoid tuple ambiguity."""
@@ -160,6 +185,7 @@ def cached(
ttl_seconds: int,
shared_cache: bool = False,
refresh_ttl_on_get: bool = False,
cache_none: bool = True,
) -> Callable[[Callable[P, R]], CachedFunction[P, R]]:
"""
Thundering herd safe cache decorator for both sync and async functions.
@@ -172,6 +198,10 @@ def cached(
ttl_seconds: Time to live in seconds. Required - entries must expire.
shared_cache: If True, use Redis for cross-process caching
refresh_ttl_on_get: If True, refresh TTL when cache entry is accessed (LRU behavior)
cache_none: If True (default) ``None`` is cached like any other value.
Set to ``False`` for functions that return ``None`` to signal a
transient error and should be re-tried on the next call without
poisoning the cache (e.g. external API calls that may fail).
Returns:
Decorated function with caching capabilities
@@ -184,6 +214,12 @@ def cached(
@cached(ttl_seconds=600, shared_cache=True, refresh_ttl_on_get=True)
async def expensive_async_operation(param: str) -> dict:
return {"result": param}
@cached(ttl_seconds=300, cache_none=False)
async def fetch_external(id: str) -> dict | None:
# Returns None on transient error — won't be stored,
# next call retries instead of returning the stale None.
...
"""
def decorator(target_func: Callable[P, R]) -> CachedFunction[P, R]:
@@ -191,9 +227,14 @@ def cached(
cache_storage: dict[tuple, CachedValue] = {}
_event_loop_locks: dict[Any, asyncio.Lock] = {}
def _get_from_redis(redis_key: str) -> Any | None:
def _get_from_redis(redis_key: str) -> Any:
"""Get value from Redis, optionally refreshing TTL.
Returns the cached value (which may be ``None``) on a hit, or the
module-level ``_MISSING`` sentinel on a miss / corrupt entry.
Callers must compare with ``is _MISSING`` so cached ``None`` values
are not mistaken for misses.
Values are expected to carry an HMAC-SHA256 prefix for integrity
verification. Unsigned (legacy) or tampered entries are silently
discarded and treated as cache misses, so the caller recomputes and
@@ -213,11 +254,11 @@ def cached(
f"for {func_name}, discarding entry: "
"possible tampering or legacy unsigned value"
)
return None
return _MISSING
return pickle.loads(payload)
except Exception as e:
logger.error(f"Redis error during cache check for {func_name}: {e}")
return None
return _MISSING
def _set_to_redis(redis_key: str, value: Any) -> None:
"""Set HMAC-signed pickled value in Redis with TTL."""
@@ -227,8 +268,13 @@ def cached(
except Exception as e:
logger.error(f"Redis error storing cache for {func_name}: {e}")
def _get_from_memory(key: tuple) -> Any | None:
"""Get value from in-memory cache, checking TTL."""
def _get_from_memory(key: tuple) -> Any:
"""Get value from in-memory cache, checking TTL.
Returns the cached value (which may be ``None``) on a hit, or the
``_MISSING`` sentinel on a miss / TTL expiry. See
``_get_from_redis`` for the rationale.
"""
if key in cache_storage:
cached_data = cache_storage[key]
if time.time() - cached_data.timestamp < ttl_seconds:
@@ -236,7 +282,7 @@ def cached(
f"Cache hit for {func_name} args: {key[0]} kwargs: {key[1]}"
)
return cached_data.result
return None
return _MISSING
def _set_to_memory(key: tuple, value: Any) -> None:
"""Set value in in-memory cache with timestamp."""
@@ -270,11 +316,11 @@ def cached(
# Fast path: check cache without lock
if shared_cache:
result = _get_from_redis(redis_key)
if result is not None:
if result is not _MISSING:
return result
else:
result = _get_from_memory(key)
if result is not None:
if result is not _MISSING:
return result
# Slow path: acquire lock for cache miss/expiry
@@ -282,22 +328,24 @@ def cached(
# Double-check: another coroutine might have populated cache
if shared_cache:
result = _get_from_redis(redis_key)
if result is not None:
if result is not _MISSING:
return result
else:
result = _get_from_memory(key)
if result is not None:
if result is not _MISSING:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {func_name}")
result = await target_func(*args, **kwargs)
# Store result
if shared_cache:
_set_to_redis(redis_key, result)
else:
_set_to_memory(key, result)
# Store result (skip ``None`` if the caller opted out of
# caching it — used for transient-error sentinels).
if cache_none or result is not None:
if shared_cache:
_set_to_redis(redis_key, result)
else:
_set_to_memory(key, result)
return result
@@ -315,11 +363,11 @@ def cached(
# Fast path: check cache without lock
if shared_cache:
result = _get_from_redis(redis_key)
if result is not None:
if result is not _MISSING:
return result
else:
result = _get_from_memory(key)
if result is not None:
if result is not _MISSING:
return result
# Slow path: acquire lock for cache miss/expiry
@@ -327,22 +375,24 @@ def cached(
# Double-check: another thread might have populated cache
if shared_cache:
result = _get_from_redis(redis_key)
if result is not None:
if result is not _MISSING:
return result
else:
result = _get_from_memory(key)
if result is not None:
if result is not _MISSING:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {func_name}")
result = target_func(*args, **kwargs)
# Store result
if shared_cache:
_set_to_redis(redis_key, result)
else:
_set_to_memory(key, result)
# Store result (skip ``None`` if the caller opted out of
# caching it — used for transient-error sentinels).
if cache_none or result is not None:
if shared_cache:
_set_to_redis(redis_key, result)
else:
_set_to_memory(key, result)
return result


@@ -1223,3 +1223,123 @@ class TestCacheHMAC:
assert call_count == 2
legacy_test_fn.cache_clear()
class TestCacheNoneHandling:
"""Tests for the ``cache_none`` parameter on the @cached decorator.
Sentry bug PRRT_kwDOJKSTjM56RTEu (HIGH): the cache previously could not
distinguish "no entry" from "entry is None", so any function returning
``None`` was effectively re-executed on every call. The fix is a
sentinel-based check inside the wrappers, plus an opt-out
``cache_none=False`` flag for callers that *want* errors to retry.
"""
@pytest.mark.asyncio
async def test_async_none_is_cached_by_default(self):
"""With ``cache_none=True`` (default), cached ``None`` is returned
from the cache instead of triggering re-execution."""
call_count = 0
@cached(ttl_seconds=300)
async def maybe_none(x: int) -> int | None:
nonlocal call_count
call_count += 1
return None
assert await maybe_none(1) is None
assert call_count == 1
# Second call should hit the cache, not re-execute.
assert await maybe_none(1) is None
assert call_count == 1
# Different argument is a different cache key — re-executes.
assert await maybe_none(2) is None
assert call_count == 2
def test_sync_none_is_cached_by_default(self):
call_count = 0
@cached(ttl_seconds=300)
def maybe_none(x: int) -> int | None:
nonlocal call_count
call_count += 1
return None
assert maybe_none(1) is None
assert maybe_none(1) is None
assert call_count == 1
@pytest.mark.asyncio
async def test_async_cache_none_false_skips_storing_none(self):
"""``cache_none=False`` skips storing ``None`` so transient errors
are retried on the next call instead of poisoning the cache."""
call_count = 0
results: list[int | None] = [None, None, 42]
@cached(ttl_seconds=300, cache_none=False)
async def maybe_none(x: int) -> int | None:
nonlocal call_count
result = results[call_count]
call_count += 1
return result
# First call: returns None, NOT stored.
assert await maybe_none(1) is None
assert call_count == 1
# Second call with same key: re-executes (None wasn't cached).
assert await maybe_none(1) is None
assert call_count == 2
# Third call: returns 42, this time it IS stored.
assert await maybe_none(1) == 42
assert call_count == 3
# Fourth call: cache hit on the stored 42.
assert await maybe_none(1) == 42
assert call_count == 3
def test_sync_cache_none_false_skips_storing_none(self):
call_count = 0
results: list[int | None] = [None, 99]
@cached(ttl_seconds=300, cache_none=False)
def maybe_none(x: int) -> int | None:
nonlocal call_count
result = results[call_count]
call_count += 1
return result
assert maybe_none(1) is None
assert call_count == 1
# None was not stored — re-executes.
assert maybe_none(1) == 99
assert call_count == 2
# 99 IS stored — no re-execution.
assert maybe_none(1) == 99
assert call_count == 2
@pytest.mark.asyncio
async def test_async_shared_cache_none_is_cached_by_default(self):
"""Shared (Redis) cache also properly returns cached ``None`` values."""
call_count = 0
@cached(ttl_seconds=30, shared_cache=True)
async def maybe_none_redis(x: int) -> int | None:
nonlocal call_count
call_count += 1
return None
maybe_none_redis.cache_clear()
assert await maybe_none_redis(1) is None
assert call_count == 1
assert await maybe_none_redis(1) is None
assert call_count == 1
maybe_none_redis.cache_clear()
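
The sentinel mechanism the tests above exercise can be sketched as follows. This is a minimal in-memory sketch under assumed names — the real decorator also handles TTL expiry, async wrappers, and the shared Redis backend. The key idea is that a unique sentinel object distinguishes "no cache entry" from "cached None", and `cache_none=False` opts out of storing `None` results:

```python
import functools

# Sentinel: "no entry", distinct from a legitimately cached None.
_MISSING = object()

def cached(ttl_seconds: int = 300, cache_none: bool = True):
    def decorator(fn):
        store: dict = {}

        @functools.wraps(fn)
        def wrapper(*args):
            hit = store.get(args, _MISSING)
            if hit is not _MISSING:  # a stored None is a genuine hit
                return hit
            result = fn(*args)
            # Skip storing None when the caller opted out, so transient
            # failures are retried instead of poisoning the cache.
            if cache_none or result is not None:
                store[args] = result
            return result

        wrapper.cache_clear = store.clear
        return wrapper

    return decorator
```

With the default `cache_none=True`, a function returning `None` executes once per key; with `cache_none=False`, the `None` is never stored, so every call re-executes until a non-`None` result lands in the cache.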

View File

@@ -1,6 +1,7 @@
import contextlib
import logging
import os
import uuid
from enum import Enum
from functools import wraps
from typing import Any, Awaitable, Callable, TypeVar
@@ -101,6 +102,12 @@ async def _fetch_user_context_data(user_id: str) -> Context:
"""
builder = Context.builder(user_id).kind("user").anonymous(True)
try:
uuid.UUID(user_id)
except ValueError:
# Non-UUID key (e.g. "system") — skip Supabase lookup, return anonymous context.
return builder.build()
try:
from backend.util.clients import get_supabase

View File

@@ -40,6 +40,8 @@
"folder_id": null,
"folder_name": null,
"recommended_schedule_cron": null,
"is_scheduled": false,
"next_scheduled_run": null,
"settings": {
"human_in_the_loop_safe_mode": true,
"sensitive_action_safe_mode": false
@@ -86,6 +88,8 @@
"folder_id": null,
"folder_name": null,
"recommended_schedule_cron": null,
"is_scheduled": false,
"next_scheduled_run": null,
"settings": {
"human_in_the_loop_safe_mode": true,
"sensitive_action_safe_mode": false

View File

@@ -155,6 +155,7 @@
"@types/twemoji": "13.1.2",
"@vitejs/plugin-react": "5.1.2",
"@vitest/coverage-v8": "4.0.17",
"agentation": "3.0.2",
"axe-playwright": "2.2.2",
"chromatic": "13.3.3",
"concurrently": "9.2.1",

View File

@@ -376,6 +376,9 @@ importers:
'@vitest/coverage-v8':
specifier: 4.0.17
version: 4.0.17(vitest@4.0.17(@opentelemetry/api@1.9.0)(@types/node@24.10.0)(happy-dom@20.3.4)(jiti@2.6.1)(jsdom@27.4.0)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(terser@5.44.1)(yaml@2.8.2))
agentation:
specifier: 3.0.2
version: 3.0.2(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
axe-playwright:
specifier: 2.2.2
version: 2.2.2(playwright@1.56.1)
@@ -4119,6 +4122,17 @@ packages:
resolution: {integrity: sha512-MnA+YT8fwfJPgBx3m60MNqakm30XOkyIoH1y6huTQvC0PwZG7ki8NacLBcrPbNoo8vEZy7Jpuk7+jMO+CUovTQ==}
engines: {node: '>= 14'}
agentation@3.0.2:
resolution: {integrity: sha512-iGzBxFVTuZEIKzLY6AExSLAQH6i6SwxV4pAu7v7m3X6bInZ7qlZXAwrEqyc4+EfP4gM7z2RXBF6SF4DeH0f2lA==}
peerDependencies:
react: '>=18.0.0'
react-dom: '>=18.0.0'
peerDependenciesMeta:
react:
optional: true
react-dom:
optional: true
ai@6.0.134:
resolution: {integrity: sha512-YalNEaavld/kE444gOcsMKXdVVRGEe0SK77fAFcWYcqLg+a7xKnEet8bdfrEAJTfnMjj01rhgrIL10903w1a5Q==}
engines: {node: '>=18'}
@@ -13119,6 +13133,11 @@ snapshots:
agent-base@7.1.4:
optional: true
agentation@3.0.2(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
optionalDependencies:
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
ai@6.0.134(zod@3.25.76):
dependencies:
'@ai-sdk/gateway': 3.0.77(zod@3.25.76)

View File

@@ -1,5 +1,5 @@
import { describe, expect, it } from "vitest";
import { getNodeDisplayName, serializeGraphForChat } from "../helpers";
import type { CustomNode } from "../../FlowEditor/nodes/CustomNode/CustomNode";
describe("serializeGraphForChat XML injection prevention", () => {
@@ -53,3 +53,53 @@ describe("serializeGraphForChat XML injection prevention", () => {
expect(result).toContain("&lt;injection&gt;");
});
});
function makeNode(overrides: Partial<CustomNode["data"]> = {}): CustomNode {
return {
id: "node-1",
data: {
title: "AgentExecutorBlock",
description: "",
hardcodedValues: {},
inputSchema: {},
outputSchema: {},
uiType: "agent",
block_id: "b1",
costs: [],
categories: [],
...overrides,
},
type: "custom" as const,
position: { x: 0, y: 0 },
} as unknown as CustomNode;
}
describe("getNodeDisplayName", () => {
it("returns fallback when node is undefined", () => {
expect(getNodeDisplayName(undefined, "fallback-id")).toBe("fallback-id");
});
it("returns customized_name when set", () => {
const node = makeNode({
metadata: { customized_name: "My Agent" } as any,
});
expect(getNodeDisplayName(node, "fallback")).toBe("My Agent");
});
it("returns agent_name with version via getNodeDisplayTitle delegation", () => {
const node = makeNode({
hardcodedValues: { agent_name: "Researcher", graph_version: 3 },
});
expect(getNodeDisplayName(node, "fallback")).toBe("Researcher v3");
});
it("returns block title when no custom or agent name", () => {
const node = makeNode({ title: "SomeBlock" });
expect(getNodeDisplayName(node, "fallback")).toBe("SomeBlock");
});
it("returns fallback when title is empty", () => {
const node = makeNode({ title: "" });
expect(getNodeDisplayName(node, "fallback")).toBe("fallback");
});
});

View File

@@ -1,5 +1,6 @@
import type { CustomNode } from "../FlowEditor/nodes/CustomNode/CustomNode";
import type { CustomEdge } from "../FlowEditor/edges/CustomEdge";
import { getNodeDisplayTitle } from "../FlowEditor/nodes/CustomNode/helpers";
/** Maximum nodes serialized into the AI context to prevent token overruns. */
const MAX_NODES = 100;
@@ -144,18 +145,16 @@ export function getActionKey(action: GraphAction): string {
/**
* Resolves the display name for a node: prefers the user-customized name,
* then agent name from hardcodedValues, then block title, then fallback ID.
* Delegates to `getNodeDisplayTitle` for the 3-tier resolution logic.
* Shared between `serializeGraphForChat` and `ActionItem` to avoid duplication.
*/
export function getNodeDisplayName(
node: CustomNode | undefined,
fallback: string,
): string {
if (!node) return fallback;
return getNodeDisplayTitle(node.data) || fallback;
}
/**

View File

@@ -0,0 +1,92 @@
import { describe, it, expect } from "vitest";
import { getNodeDisplayTitle, formatNodeDisplayTitle } from "../helpers";
import { CustomNodeData } from "../CustomNode";
function makeNodeData(overrides: Partial<CustomNodeData> = {}): CustomNodeData {
return {
title: "AgentExecutorBlock",
description: "",
hardcodedValues: {},
inputSchema: {},
outputSchema: {},
uiType: "agent",
block_id: "block-1",
costs: [],
categories: [],
...overrides,
} as CustomNodeData;
}
describe("getNodeDisplayTitle", () => {
it("returns customized_name when set (tier 1)", () => {
const data = makeNodeData({
metadata: { customized_name: "My Custom Agent" } as any,
hardcodedValues: { agent_name: "Researcher", graph_version: 2 },
});
expect(getNodeDisplayTitle(data)).toBe("My Custom Agent");
});
it("returns agent_name with version when no customized_name (tier 2)", () => {
const data = makeNodeData({
hardcodedValues: { agent_name: "Researcher", graph_version: 2 },
});
expect(getNodeDisplayTitle(data)).toBe("Researcher v2");
});
it("returns agent_name without version when graph_version is undefined (tier 2)", () => {
const data = makeNodeData({
hardcodedValues: { agent_name: "Researcher" },
});
expect(getNodeDisplayTitle(data)).toBe("Researcher");
});
it("returns agent_name with version 0 (tier 2)", () => {
const data = makeNodeData({
hardcodedValues: { agent_name: "Researcher", graph_version: 0 },
});
expect(getNodeDisplayTitle(data)).toBe("Researcher v0");
});
it("returns generic block title when no custom or agent name (tier 3)", () => {
const data = makeNodeData({ title: "AgentExecutorBlock" });
expect(getNodeDisplayTitle(data)).toBe("AgentExecutorBlock");
});
it("prioritizes customized_name over agent_name", () => {
const data = makeNodeData({
metadata: { customized_name: "Renamed" } as any,
hardcodedValues: { agent_name: "Original Agent", graph_version: 1 },
});
expect(getNodeDisplayTitle(data)).toBe("Renamed");
});
});
describe("formatNodeDisplayTitle", () => {
it("returns custom name as-is without beautifying", () => {
const data = makeNodeData({
metadata: { customized_name: "my_custom_name" } as any,
});
expect(formatNodeDisplayTitle(data)).toBe("my_custom_name");
});
it("returns agent name as-is without beautifying", () => {
const data = makeNodeData({
hardcodedValues: { agent_name: "Blockchain Agent", graph_version: 1 },
});
expect(formatNodeDisplayTitle(data)).toBe("Blockchain Agent v1");
});
it("beautifies generic block title and strips Block suffix", () => {
const data = makeNodeData({ title: "AgentExecutorBlock" });
const result = formatNodeDisplayTitle(data);
expect(result).not.toContain("Block");
expect(result).toBe("Agent Executor");
});
it("does not corrupt agent names containing 'Block'", () => {
const data = makeNodeData({
hardcodedValues: { agent_name: "Blockchain Agent", graph_version: 2 },
});
expect(formatNodeDisplayTitle(data)).toBe("Blockchain Agent v2");
});
});

View File

@@ -6,9 +6,10 @@ import {
TooltipProvider,
TooltipTrigger,
} from "@/components/atoms/Tooltip/BaseTooltip";
import { cn } from "@/lib/utils";
import { useEffect, useState } from "react";
import { CustomNodeData } from "../CustomNode";
import { formatNodeDisplayTitle, getNodeDisplayTitle } from "../helpers";
import { NodeBadges } from "./NodeBadges";
import { NodeContextMenu } from "./NodeContextMenu";
import { NodeCost } from "./NodeCost";
@@ -21,15 +22,24 @@ type Props = {
export const NodeHeader = ({ data, nodeId }: Props) => {
const updateNodeData = useNodeStore((state) => state.updateNodeData);
const title = getNodeDisplayTitle(data);
const displayTitle = formatNodeDisplayTitle(data);
const [isEditingTitle, setIsEditingTitle] = useState(false);
const [editedTitle, setEditedTitle] = useState(title);
useEffect(() => {
if (!isEditingTitle) {
setEditedTitle(title);
}
}, [title, isEditingTitle]);
const handleTitleEdit = () => {
if (editedTitle !== title) {
updateNodeData(nodeId, {
metadata: { ...data.metadata, customized_name: editedTitle },
});
}
setIsEditingTitle(false);
};
@@ -72,12 +82,12 @@ export const NodeHeader = ({ data, nodeId }: Props) => {
variant="large-semibold"
className="line-clamp-1 hover:cursor-text"
>
{displayTitle}
</Text>
</div>
</TooltipTrigger>
<TooltipContent>
<p>{displayTitle}</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>

View File

@@ -0,0 +1,121 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { render, screen, fireEvent } from "@/tests/integrations/test-utils";
import { NodeHeader } from "../NodeHeader";
import { CustomNodeData } from "../../CustomNode";
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
vi.mock("../NodeCost", () => ({
NodeCost: () => <div data-testid="node-cost" />,
}));
vi.mock("../NodeContextMenu", () => ({
NodeContextMenu: () => <div data-testid="node-context-menu" />,
}));
vi.mock("../NodeBadges", () => ({
NodeBadges: () => <div data-testid="node-badges" />,
}));
function makeData(overrides: Partial<CustomNodeData> = {}): CustomNodeData {
return {
title: "AgentExecutorBlock",
description: "",
hardcodedValues: {},
inputSchema: {},
outputSchema: {},
uiType: "agent",
block_id: "block-1",
costs: [],
categories: [],
...overrides,
} as CustomNodeData;
}
describe("NodeHeader", () => {
const mockUpdateNodeData = vi.fn();
beforeEach(() => {
vi.clearAllMocks();
useNodeStore.setState({ updateNodeData: mockUpdateNodeData } as any);
});
it("renders beautified generic block title", () => {
render(<NodeHeader data={makeData()} nodeId="abc-123" />);
expect(screen.getByText("Agent Executor")).toBeTruthy();
});
it("renders agent name with version from hardcodedValues", () => {
const data = makeData({
hardcodedValues: { agent_name: "Researcher", graph_version: 2 },
});
render(<NodeHeader data={data} nodeId="abc-123" />);
expect(screen.getByText("Researcher v2")).toBeTruthy();
});
it("renders customized_name over agent name", () => {
const data = makeData({
metadata: { customized_name: "My Custom Node" } as any,
hardcodedValues: { agent_name: "Researcher", graph_version: 1 },
});
render(<NodeHeader data={data} nodeId="abc-123" />);
expect(screen.getByText("My Custom Node")).toBeTruthy();
});
it("shows node ID prefix", () => {
render(<NodeHeader data={makeData()} nodeId="abc-123" />);
expect(screen.getByText("#abc")).toBeTruthy();
});
it("enters edit mode on double-click and saves on blur", () => {
render(<NodeHeader data={makeData()} nodeId="node-1" />);
const titleEl = screen.getByText("Agent Executor");
fireEvent.doubleClick(titleEl);
const input = screen.getByDisplayValue("AgentExecutorBlock");
fireEvent.change(input, { target: { value: "New Name" } });
fireEvent.blur(input);
expect(mockUpdateNodeData).toHaveBeenCalledWith("node-1", {
metadata: { customized_name: "New Name" },
});
});
it("does not save when title is unchanged on blur", () => {
const data = makeData({
hardcodedValues: { agent_name: "Researcher", graph_version: 2 },
});
render(<NodeHeader data={data} nodeId="node-1" />);
const titleEl = screen.getByText("Researcher v2");
fireEvent.doubleClick(titleEl);
const input = screen.getByDisplayValue("Researcher v2");
fireEvent.blur(input);
expect(mockUpdateNodeData).not.toHaveBeenCalled();
});
it("saves on Enter key", () => {
render(<NodeHeader data={makeData()} nodeId="node-1" />);
fireEvent.doubleClick(screen.getByText("Agent Executor"));
const input = screen.getByDisplayValue("AgentExecutorBlock");
fireEvent.change(input, { target: { value: "Renamed" } });
fireEvent.keyDown(input, { key: "Enter" });
expect(mockUpdateNodeData).toHaveBeenCalledWith("node-1", {
metadata: { customized_name: "Renamed" },
});
});
it("cancels edit on Escape key", () => {
render(<NodeHeader data={makeData()} nodeId="node-1" />);
fireEvent.doubleClick(screen.getByText("Agent Executor"));
const input = screen.getByDisplayValue("AgentExecutorBlock");
fireEvent.change(input, { target: { value: "Changed" } });
fireEvent.keyDown(input, { key: "Escape" });
expect(mockUpdateNodeData).not.toHaveBeenCalled();
expect(screen.getByText("Agent Executor")).toBeTruthy();
});
});

View File

@@ -1,6 +1,55 @@
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import { NodeResolutionData } from "@/app/(platform)/build/stores/types";
import { beautifyString } from "@/lib/utils";
import { RJSFSchema } from "@rjsf/utils";
import { CustomNodeData } from "./CustomNode";
/**
* Resolves the display title for a node using a 3-tier fallback:
*
* 1. `customized_name` — the user's manual rename (highest priority)
* 2. `agent_name` (+ version) from `hardcodedValues` — the selected agent's
* display name, persisted by blocks like AgentExecutorBlock
* 3. `data.title` — the generic block name (e.g. "Agent Executor")
*
* `customized_name` is the user's explicit rename via double-click; it lives in
* node metadata. `agent_name` is the programmatic name of the agent graph
* selected in the block's input form; it lives in `hardcodedValues` alongside
* `graph_version`. These are distinct sources of truth — customized_name always
* wins because it reflects deliberate user intent.
*/
export function getNodeDisplayTitle(data: CustomNodeData): string {
if (data.metadata?.customized_name) {
return data.metadata.customized_name as string;
}
const agentName = data.hardcodedValues?.agent_name as string | undefined;
const graphVersion = data.hardcodedValues?.graph_version as
| number
| undefined;
if (agentName) {
return graphVersion != null ? `${agentName} v${graphVersion}` : agentName;
}
return data.title;
}
/**
* Returns the formatted display title for rendering.
* Agent names and custom names are shown as-is; generic block names get
* beautified and have the trailing " Block" suffix stripped.
*/
export function formatNodeDisplayTitle(data: CustomNodeData): string {
const title = getNodeDisplayTitle(data);
const isAgentOrCustom = !!(
data.metadata?.customized_name || data.hardcodedValues?.agent_name
);
return isAgentOrCustom
? title
: beautifyString(title)
.replace(/ Block$/, "")
.trim();
}
export const nodeStyleBasedOnStatus: Record<AgentExecutionStatus, string> = {
INCOMPLETE: "ring-slate-300 bg-slate-300",

View File

@@ -1,3 +1,4 @@
import { formatNodeDisplayTitle } from "@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/helpers";
import { Separator } from "@/components/ui/separator";
import { ScrollArea } from "@/components/ui/scroll-area";
import { beautifyString, cn } from "@/lib/utils";
@@ -58,9 +59,7 @@ export function GraphSearchContent({
filteredNodes.map((node, index) => {
if (!node?.data) return null;
const nodeTitle = formatNodeDisplayTitle(node.data);
const nodeType = beautifyString(node.data.title || "").replace(
/ Block$/,
"",
@@ -70,7 +69,10 @@ export function GraphSearchContent({
node.data.description ||
"";
const hasCustomName = !!(
node.data.metadata?.customized_name ||
node.data.hardcodedValues?.agent_name
);
return (
<div

View File

@@ -69,6 +69,9 @@ function calculateNodeScore(
const customizedName = String(
node.data?.metadata?.customized_name || "",
).toLowerCase();
const agentName = String(
node.data?.hardcodedValues?.agent_name || "",
).toLowerCase();
// Get input and output names with defensive checks
const inputNames = Object.keys(node.data?.inputSchema?.properties || {}).map(
@@ -81,6 +84,7 @@ function calculateNodeScore(
// 1. Check exact match in customized name, title (includes ID), node ID, or block type (highest priority)
if (
customizedName.includes(query) ||
agentName.includes(query) ||
nodeTitle.includes(query) ||
nodeID.includes(query) ||
blockType.includes(query) ||
@@ -95,6 +99,7 @@ function calculateNodeScore(
queryWords.every(
(word) =>
customizedName.includes(word) ||
agentName.includes(word) ||
nodeTitle.includes(word) ||
beautifiedBlockType.includes(word),
)

View File

@@ -0,0 +1,87 @@
import { renderHook } from "@testing-library/react";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { useChatSession } from "../useChatSession";
const mockUseGetV2GetSession = vi.fn();
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
useGetV2GetSession: (...args: unknown[]) => mockUseGetV2GetSession(...args),
usePostV2CreateSession: () => ({ mutateAsync: vi.fn(), isPending: false }),
getGetV2GetSessionQueryKey: (id: string) => ["session", id],
getGetV2ListSessionsQueryKey: () => ["sessions"],
}));
vi.mock("@tanstack/react-query", () => ({
useQueryClient: () => ({
invalidateQueries: vi.fn(),
setQueryData: vi.fn(),
}),
}));
vi.mock("nuqs", () => ({
parseAsString: { withDefault: (v: unknown) => v },
useQueryState: () => ["sess-1", vi.fn()],
}));
vi.mock("../helpers/convertChatSessionToUiMessages", () => ({
convertChatSessionMessagesToUiMessages: vi.fn(() => ({
messages: [],
historicalDurations: new Map(),
})),
}));
vi.mock("../helpers", () => ({
resolveSessionDryRun: vi.fn(() => false),
}));
vi.mock("@sentry/nextjs", () => ({
captureException: vi.fn(),
}));
function makeQueryResult(data: object | null) {
return {
data: data ? { status: 200, data } : undefined,
isLoading: false,
isError: false,
isFetching: false,
refetch: vi.fn(),
};
}
describe("useChatSession — pagination metadata", () => {
beforeEach(() => {
vi.clearAllMocks();
});
it("returns null for oldestSequence when no session data", () => {
mockUseGetV2GetSession.mockReturnValue(makeQueryResult(null));
const { result } = renderHook(() => useChatSession());
expect(result.current.oldestSequence).toBeNull();
});
it("returns oldestSequence from session data", () => {
mockUseGetV2GetSession.mockReturnValue(
makeQueryResult({
messages: [],
has_more_messages: true,
oldest_sequence: 50,
active_stream: null,
}),
);
const { result } = renderHook(() => useChatSession());
expect(result.current.oldestSequence).toBe(50);
});
it("returns hasMoreMessages from session data", () => {
mockUseGetV2GetSession.mockReturnValue(
makeQueryResult({
messages: [],
has_more_messages: true,
oldest_sequence: 0,
active_stream: null,
}),
);
const { result } = renderHook(() => useChatSession());
expect(result.current.hasMoreMessages).toBe(true);
});
});

View File

@@ -0,0 +1,131 @@
import { renderHook } from "@testing-library/react";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { useCopilotPage } from "../useCopilotPage";
const mockUseChatSession = vi.fn();
const mockUseCopilotStream = vi.fn();
const mockUseLoadMoreMessages = vi.fn();
vi.mock("../useChatSession", () => ({
useChatSession: (...args: unknown[]) => mockUseChatSession(...args),
}));
vi.mock("../useCopilotStream", () => ({
useCopilotStream: (...args: unknown[]) => mockUseCopilotStream(...args),
}));
vi.mock("../useLoadMoreMessages", () => ({
useLoadMoreMessages: (...args: unknown[]) => mockUseLoadMoreMessages(...args),
}));
vi.mock("../useCopilotNotifications", () => ({
useCopilotNotifications: () => undefined,
}));
vi.mock("../useWorkflowImportAutoSubmit", () => ({
useWorkflowImportAutoSubmit: () => undefined,
}));
vi.mock("../store", () => ({
useCopilotUIStore: () => ({
sessionToDelete: null,
setSessionToDelete: vi.fn(),
isDrawerOpen: false,
setDrawerOpen: vi.fn(),
copilotChatMode: "chat",
copilotLlmModel: null,
isDryRun: false,
}),
}));
vi.mock("../helpers/convertChatSessionToUiMessages", () => ({
concatWithAssistantMerge: (a: unknown[], b: unknown[]) => [...a, ...b],
}));
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
useDeleteV2DeleteSession: () => ({ mutate: vi.fn(), isPending: false }),
useGetV2ListSessions: () => ({ data: undefined, isLoading: false }),
getGetV2ListSessionsQueryKey: () => ["sessions"],
}));
vi.mock("@/components/molecules/Toast/use-toast", () => ({
toast: vi.fn(),
}));
vi.mock("@/lib/direct-upload", () => ({
uploadFileDirect: vi.fn(),
}));
vi.mock("@/lib/hooks/useBreakpoint", () => ({
useBreakpoint: () => "lg",
}));
vi.mock("@/lib/supabase/hooks/useSupabase", () => ({
useSupabase: () => ({ isUserLoading: false, isLoggedIn: true }),
}));
vi.mock("@tanstack/react-query", () => ({
useQueryClient: () => ({ invalidateQueries: vi.fn() }),
}));
vi.mock("@/services/feature-flags/use-get-flag", () => ({
Flag: { CHAT_MODE_OPTION: "CHAT_MODE_OPTION" },
useGetFlag: () => false,
}));
function makeBaseChatSession(overrides: Record<string, unknown> = {}) {
return {
sessionId: "sess-1",
setSessionId: vi.fn(),
hydratedMessages: [],
rawSessionMessages: [],
historicalDurations: new Map(),
hasActiveStream: false,
hasMoreMessages: false,
oldestSequence: null,
isLoadingSession: false,
isSessionError: false,
createSession: vi.fn(),
isCreatingSession: false,
refetchSession: vi.fn(),
sessionDryRun: false,
...overrides,
};
}
function makeBaseCopilotStream(overrides: Record<string, unknown> = {}) {
return {
messages: [],
sendMessage: vi.fn(),
stop: vi.fn(),
status: "ready",
error: undefined,
isReconnecting: false,
isSyncing: false,
isUserStoppingRef: { current: false },
rateLimitMessage: null,
dismissRateLimit: vi.fn(),
...overrides,
};
}
function makeBaseLoadMore(overrides: Record<string, unknown> = {}) {
return {
pagedMessages: [],
hasMore: false,
isLoadingMore: false,
loadMore: vi.fn(),
...overrides,
};
}
describe("useCopilotPage — backward pagination message ordering", () => {
beforeEach(() => {
vi.clearAllMocks();
});
it("prepends pagedMessages before currentMessages", () => {
const pagedMsg = { id: "paged", role: "user" };
const currentMsg = { id: "current", role: "assistant" };
mockUseChatSession.mockReturnValue(makeBaseChatSession());
mockUseCopilotStream.mockReturnValue(
makeBaseCopilotStream({ messages: [currentMsg] }),
);
mockUseLoadMoreMessages.mockReturnValue(
makeBaseLoadMore({ pagedMessages: [pagedMsg] }),
);
const { result } = renderHook(() => useCopilotPage());
// Backward: pagedMessages (older) come first
expect(result.current.messages[0]).toEqual(pagedMsg);
expect(result.current.messages[1]).toEqual(currentMsg);
});
});

View File

@@ -0,0 +1,212 @@
import { act, renderHook, waitFor } from "@testing-library/react";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { useLoadMoreMessages } from "../useLoadMoreMessages";
const mockGetV2GetSession = vi.fn();
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
getV2GetSession: (...args: unknown[]) => mockGetV2GetSession(...args),
}));
vi.mock("../helpers/convertChatSessionToUiMessages", () => ({
convertChatSessionMessagesToUiMessages: vi.fn(() => ({ messages: [] })),
extractToolOutputsFromRaw: vi.fn(() => []),
}));
const BASE_ARGS = {
sessionId: "sess-1",
initialOldestSequence: 50,
initialHasMore: true,
initialPageRawMessages: [],
};
function makeSuccessResponse(overrides: {
messages?: unknown[];
has_more_messages?: boolean;
oldest_sequence?: number;
}) {
return {
status: 200,
data: {
messages: overrides.messages ?? [],
has_more_messages: overrides.has_more_messages ?? false,
oldest_sequence: overrides.oldest_sequence ?? 0,
},
};
}
describe("useLoadMoreMessages", () => {
beforeEach(() => {
vi.clearAllMocks();
});
it("initialises with empty pagedMessages and correct cursors", () => {
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
expect(result.current.pagedMessages).toHaveLength(0);
expect(result.current.hasMore).toBe(true);
expect(result.current.isLoadingMore).toBe(false);
});
it("resets all state on sessionId change", () => {
const { result, rerender } = renderHook(
(props) => useLoadMoreMessages(props),
{ initialProps: BASE_ARGS },
);
rerender({
...BASE_ARGS,
sessionId: "sess-2",
initialOldestSequence: 10,
initialHasMore: false,
});
expect(result.current.pagedMessages).toHaveLength(0);
expect(result.current.hasMore).toBe(false);
expect(result.current.isLoadingMore).toBe(false);
});
describe("loadMore — backward pagination", () => {
it("calls getV2GetSession with before_sequence", async () => {
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({
messages: [{ role: "user", content: "old", sequence: 0 }],
has_more_messages: false,
oldest_sequence: 0,
}),
);
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
await act(async () => {
await result.current.loadMore();
});
expect(mockGetV2GetSession).toHaveBeenCalledWith(
"sess-1",
expect.objectContaining({ before_sequence: 50 }),
);
expect(result.current.hasMore).toBe(false);
});
it("is a no-op when hasMore is false", async () => {
const { result } = renderHook(() =>
useLoadMoreMessages({ ...BASE_ARGS, initialHasMore: false }),
);
await act(async () => {
await result.current.loadMore();
});
expect(mockGetV2GetSession).not.toHaveBeenCalled();
});
it("is a no-op when oldestSequence is null", async () => {
const { result } = renderHook(() =>
useLoadMoreMessages({ ...BASE_ARGS, initialOldestSequence: null }),
);
await act(async () => {
await result.current.loadMore();
});
expect(mockGetV2GetSession).not.toHaveBeenCalled();
});
});
describe("loadMore — error handling", () => {
it("does not set hasMore=false on first error", async () => {
mockGetV2GetSession.mockRejectedValueOnce(new Error("network error"));
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
await act(async () => {
await result.current.loadMore();
});
expect(result.current.hasMore).toBe(true);
expect(result.current.isLoadingMore).toBe(false);
});
it("sets hasMore=false after MAX_CONSECUTIVE_ERRORS (3) errors", async () => {
mockGetV2GetSession.mockRejectedValue(new Error("network error"));
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
for (let i = 0; i < 3; i++) {
await act(async () => {
await result.current.loadMore();
});
await waitFor(() => expect(result.current.isLoadingMore).toBe(false));
}
expect(result.current.hasMore).toBe(false);
});
it("ignores non-200 response and increments error count", async () => {
mockGetV2GetSession.mockResolvedValueOnce({ status: 500, data: {} });
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
await act(async () => {
await result.current.loadMore();
});
expect(result.current.hasMore).toBe(true);
expect(result.current.isLoadingMore).toBe(false);
});
});
describe("loadMore — MAX_OLDER_MESSAGES truncation", () => {
it("truncates accumulated messages at MAX_OLDER_MESSAGES (2000)", async () => {
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({
messages: Array.from({ length: 2001 }, (_, i) => ({
role: "user",
content: `msg ${i}`,
sequence: i,
})),
has_more_messages: true,
oldest_sequence: 0,
}),
);
const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS));
await act(async () => {
await result.current.loadMore();
});
expect(result.current.hasMore).toBe(false);
});
});
describe("pagedMessages — initialPageRawMessages extraToolOutputs", () => {
it("calls extractToolOutputsFromRaw with non-empty initialPageRawMessages", async () => {
const { extractToolOutputsFromRaw } = await import(
"../helpers/convertChatSessionToUiMessages"
);
const rawMsg = { role: "user", content: "old", sequence: 0 };
mockGetV2GetSession.mockResolvedValueOnce(
makeSuccessResponse({
messages: [rawMsg],
has_more_messages: false,
oldest_sequence: 0,
}),
);
const { result } = renderHook(() =>
useLoadMoreMessages({
...BASE_ARGS,
initialPageRawMessages: [{ role: "assistant", content: "response" }],
}),
);
await act(async () => {
await result.current.loadMore();
});
expect(extractToolOutputsFromRaw).toHaveBeenCalled();
});
});
});

View File

@@ -6,9 +6,11 @@ import { Suspense, useState } from "react";
import { Skeleton } from "@/components/ui/skeleton";
import type { ArtifactRef } from "../../../store";
import type { ArtifactClassification } from "../helpers";
import { ArtifactErrorBoundary } from "./ArtifactErrorBoundary";
import { ArtifactReactPreview } from "./ArtifactReactPreview";
import { ArtifactSkeleton } from "./ArtifactSkeleton";
import {
FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
TAILWIND_CDN_URL,
wrapWithHeadInjection,
} from "@/lib/iframe-sandbox-csp";
@@ -53,13 +55,18 @@ function ArtifactContentLoader({
return (
<div ref={scrollRef} className="flex-1 overflow-y-auto">
<ArtifactErrorBoundary
artifactTitle={artifact.title}
artifactType={classification.type}
>
<ArtifactRenderer
artifact={artifact}
content={content}
pdfUrl={pdfUrl}
isSourceView={isSourceView}
classification={classification}
/>
</ArtifactErrorBoundary>
</div>
);
}
@@ -200,7 +207,10 @@ function ArtifactRenderer({
if (classification.type === "html") {
// Inject Tailwind CDN — no CSP (see iframe-sandbox-csp.ts for why)
const tailwindScript = `<script src="${TAILWIND_CDN_URL}"></script>`;
const wrapped = wrapWithHeadInjection(content, tailwindScript);
const wrapped = wrapWithHeadInjection(
content,
tailwindScript + FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
);
return (
<iframe
sandbox="allow-scripts"

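The hunk above appends `FRAGMENT_LINK_INTERCEPTOR_SCRIPT` to the srcdoc head injection. As a rough sketch (the names and exact listener below are illustrative assumptions, not the real implementation in `iframe-sandbox-csp.ts`), the interceptor boils down to a capture-phase click handler plus a tiny predicate:

```typescript
// Illustrative sketch only: the kind of logic FRAGMENT_LINK_INTERCEPTOR_SCRIPT
// plausibly inlines into the sandboxed srcdoc iframe. Without allow-same-origin,
// "#x" resolves against the PARENT page URL, so clicks must be intercepted.

// Pure helper: does this anchor href target an in-document fragment?
export function isFragmentLink(href: string | null): boolean {
  return href !== null && href.startsWith("#");
}

// Inside the iframe, a capture-phase listener would preventDefault such
// clicks and scroll the local target into view instead:
//
//   document.addEventListener("click", (e) => {
//     const a = (e.target as Element).closest?.('a[href^="#"]');
//     if (!a || !isFragmentLink(a.getAttribute("href"))) return;
//     e.preventDefault();
//     const id = a.getAttribute("href")!.slice(1);
//     document.getElementById(id)?.scrollIntoView({ behavior: "smooth" });
//   }, true);
```

This matches the markers the regression test asserts on (`a[href^="#"]`, `scrollIntoView`), but the real script may differ in detail.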
View File

@@ -0,0 +1,96 @@
"use client";
import * as Sentry from "@sentry/nextjs";
import { Component, type ErrorInfo, type ReactNode } from "react";
interface Props {
children: ReactNode;
artifactTitle: string;
artifactType: string;
}
interface State {
error: Error | null;
}
export class ArtifactErrorBoundary extends Component<Props, State> {
state: State = { error: null };
static getDerivedStateFromError(error: Error): State {
return { error };
}
componentDidCatch(error: Error, errorInfo: ErrorInfo) {
Sentry.captureException(error, {
contexts: {
react: { componentStack: errorInfo.componentStack },
},
tags: { errorBoundary: "true", context: "copilot-artifact" },
extra: {
artifactTitle: this.props.artifactTitle,
artifactType: this.props.artifactType,
},
});
}
componentDidUpdate(prevProps: Props) {
if (
this.state.error &&
(prevProps.artifactTitle !== this.props.artifactTitle ||
prevProps.artifactType !== this.props.artifactType)
) {
this.setState({ error: null });
}
}
handleCopy = () => {
const { error } = this.state;
if (!error) return;
const details = [
`Artifact: ${this.props.artifactTitle}`,
`Type: ${this.props.artifactType}`,
`Error: ${error.message}`,
error.stack ? `Stack:\n${error.stack}` : "",
]
.filter(Boolean)
.join("\n");
navigator.clipboard?.writeText(details).catch(() => {});
};
render() {
const { error } = this.state;
if (!error) return this.props.children;
const message = error.message || "Unknown rendering error";
return (
<div
role="alert"
className="flex h-full flex-col items-center justify-center gap-3 p-8 text-center"
>
<p className="text-sm font-medium text-zinc-700">
This artifact couldn&apos;t be rendered
</p>
<p className="max-w-md break-words text-xs text-zinc-500">
Something in{" "}
<span className="font-mono">{this.props.artifactTitle}</span> threw an
error while rendering. The chat and sidebar are still working.
</p>
<pre className="max-h-32 max-w-md overflow-auto whitespace-pre-wrap break-words rounded-md bg-zinc-100 px-3 py-2 text-left text-xs text-zinc-700">
{message}
</pre>
<button
type="button"
onClick={this.handleCopy}
className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
>
Copy error details
</button>
<p className="max-w-md text-xs text-zinc-400">
Paste this into the chat so the agent can regenerate a working
version.
</p>
</div>
);
}
}
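The `componentDidUpdate` block above clears a captured error whenever the artifact identity changes, so a crash on one artifact doesn't poison the next one the user opens. Extracted as a pure predicate (a hypothetical helper for illustration — the component inlines this logic):

```typescript
// Sketch of the reset rule ArtifactErrorBoundary.componentDidUpdate applies:
// clear a captured error as soon as the user switches artifacts.
interface ArtifactIdentity {
  artifactTitle: string;
  artifactType: string;
}

export function shouldResetBoundary(
  prev: ArtifactIdentity,
  next: ArtifactIdentity,
  hasError: boolean,
): boolean {
  // Only reset when there IS an error and the artifact identity changed;
  // resetting on every update would retry the same crashing render.
  return (
    hasError &&
    (prev.artifactTitle !== next.artifactTitle ||
      prev.artifactType !== next.artifactType)
  );
}
```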

View File

@@ -412,6 +412,41 @@ describe("ArtifactContent", () => {
expect(iframe?.getAttribute("sandbox")).toBe("allow-scripts");
});
it("injects the fragment-link interceptor into HTML artifact iframes (regression)", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () =>
Promise.resolve(
'<html><head></head><body><a href="#x">x</a><div id="x">x</div></body></html>',
),
}),
);
const { container } = render(
<ArtifactContent
artifact={makeArtifact({
id: "html-frag",
title: "page.html",
mimeType: "text/html",
})}
isSourceView={false}
classification={makeClassification({ type: "html" })}
/>,
);
await screen.findByTitle("page.html");
const srcdoc = container.querySelector("iframe")?.getAttribute("srcdoc");
expect(srcdoc).toBeTruthy();
// Markers unique to FRAGMENT_LINK_INTERCEPTOR_SCRIPT — if any of these
// disappear, the interceptor is no longer being injected and fragment
// links will navigate the parent URL again.
expect(srcdoc).toContain("__fragmentLinkInterceptor");
expect(srcdoc).toContain('a[href^="#"]');
expect(srcdoc).toContain("scrollIntoView");
});
// ── Source view ───────────────────────────────────────────────────
it("renders source view as pre tag", async () => {
@@ -923,6 +958,164 @@ describe("ArtifactContent", () => {
},
);
// ── Error boundary ────────────────────────────────────────────────
it("shows a visible error instead of crashing when the renderer throws", async () => {
const consoleErr = vi.spyOn(console, "error").mockImplementation(() => {});
const originalImpl = vi
.mocked(ArtifactReactPreview)
.getMockImplementation();
vi.mocked(ArtifactReactPreview).mockImplementation(() => {
throw new Error("boom in renderer");
});
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve("source"),
}),
);
const artifact = makeArtifact({
id: "crash-001",
title: "broken.tsx",
mimeType: "text/tsx",
});
const classification = makeClassification({ type: "react" });
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
expect(
await screen.findByText(/This artifact couldn't be rendered/i),
).toBeTruthy();
expect(screen.getByText(/boom in renderer/)).toBeTruthy();
expect(
screen.getByRole("button", { name: /copy error details/i }),
).toBeTruthy();
if (originalImpl) {
vi.mocked(ArtifactReactPreview).mockImplementation(originalImpl);
}
consoleErr.mockRestore();
});
it("copies artifact title, type, and error to the clipboard", async () => {
const consoleErr = vi.spyOn(console, "error").mockImplementation(() => {});
const writeText = vi.fn().mockResolvedValue(undefined);
Object.defineProperty(navigator, "clipboard", {
value: { writeText },
writable: true,
configurable: true,
});
const originalImpl = vi
.mocked(ArtifactReactPreview)
.getMockImplementation();
vi.mocked(ArtifactReactPreview).mockImplementation(() => {
throw new Error("jsx parse failed at line 42");
});
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve("source"),
}),
);
render(
<ArtifactContent
artifact={makeArtifact({
id: "crash-002",
title: "report.tsx",
mimeType: "text/tsx",
})}
isSourceView={false}
classification={makeClassification({ type: "react" })}
/>,
);
fireEvent.click(
await screen.findByRole("button", { name: /copy error details/i }),
);
await waitFor(() => {
expect(writeText).toHaveBeenCalled();
});
const payload = writeText.mock.calls[0]![0] as string;
expect(payload).toContain("report.tsx");
expect(payload).toContain("react");
expect(payload).toContain("jsx parse failed at line 42");
if (originalImpl) {
vi.mocked(ArtifactReactPreview).mockImplementation(originalImpl);
}
consoleErr.mockRestore();
});
it("renders the user-reported plotly HTML artifact into a sandboxed iframe", async () => {
const html = `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>AutoGPT Beta Launch Interactive Report</title>
<script src="https://cdn.plot.ly/plotly-2.27.0.min.js"></script>
<style>
:root { --bg: #f8f9fa; --primary: #6c5ce7; }
* { margin: 0; padding: 0; box-sizing: border-box; }
body { font-family: 'Segoe UI', system-ui, sans-serif; }
</style>
</head>
<body>
<header><h1>\u{1F4CA} AutoGPT Beta Launch Interactive Report</h1></header>
<div class="chart-container" id="globalActivationChart"></div>
<script>
function showTab(tabId, groupId) {
const group = document.getElementById(groupId);
group.querySelectorAll('.tab-content').forEach(t => t.classList.remove('active'));
document.getElementById(tabId).classList.add('active');
}
Plotly.newPlot('globalActivationChart', [{ type: 'pie', values: [1, 2] }], {});
</script>
</body>
</html>`;
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(html),
}),
);
const artifact = makeArtifact({
id: "html-big-report",
title: "report.html",
mimeType: "text/html",
});
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={makeClassification({ type: "html" })}
/>,
);
await screen.findByTitle("report.html");
const iframe = container.querySelector("iframe");
expect(iframe).toBeTruthy();
expect(iframe?.getAttribute("sandbox")).toBe("allow-scripts");
expect(screen.queryByText(/couldn't be rendered/i)).toBeNull();
});
it("falls back to pre tag when no renderer matches", async () => {
const { globalRegistry } = await import(
"@/components/contextual/OutputRenderers"

View File

@@ -116,4 +116,11 @@ describe("buildReactArtifactSrcDoc", () => {
expect(doc).toContain("/^[A-Z]/.test(name)");
expect(doc).toContain("wrapWithProviders");
});
it("injects the fragment-link interceptor so #anchor clicks stay inside the iframe (regression)", () => {
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
expect(doc).toContain("__fragmentLinkInterceptor");
expect(doc).toContain('a[href^="#"]');
expect(doc).toContain("scrollIntoView");
});
});

View File

@@ -19,7 +19,10 @@
* React is loaded from unpkg with pinned version and SRI integrity hashes.
*/
import { TAILWIND_CDN_URL } from "@/lib/iframe-sandbox-csp";
import {
FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
TAILWIND_CDN_URL,
} from "@/lib/iframe-sandbox-csp";
export { transpileReactArtifactSource } from "./transpileReactArtifact";
@@ -95,6 +98,7 @@ export function buildReactArtifactSrcDoc(
}
</style>
<script src="${TAILWIND_CDN_URL}"></script>
${FRAGMENT_LINK_INTERCEPTOR_SCRIPT}
<script crossorigin="anonymous" src="https://unpkg.com/react@18.3.1/umd/react.production.min.js" integrity="sha384-DGyLxAyjq0f9SPpVevD6IgztCFlnMF6oW/XQGmfe+IsZ8TqEiDrcHkMLKI6fiB/Z"></script><!-- pragma: allowlist secret -->
<script crossorigin="anonymous" src="https://unpkg.com/react-dom@18.3.1/umd/react-dom.production.min.js" integrity="sha384-gTGxhz21lVGYNMcdJOyq01Edg0jhn/c22nsx0kyqP0TxaV5WVdsSH1fSDUf5YJj1"></script><!-- pragma: allowlist secret -->
</head>
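For context on how `${FRAGMENT_LINK_INTERCEPTOR_SCRIPT}` lands next to the Tailwind CDN tag: `wrapWithHeadInjection` (used in the ArtifactContent hunk) presumably splices a snippet into the document head. A minimal sketch of that assumed behavior — not the actual helper from `@/lib/iframe-sandbox-csp`:

```typescript
// Assumed behavior sketch: insert a snippet immediately after the opening
// <head> tag if one exists, otherwise prepend it so it still executes first.
export function wrapWithHeadInjectionSketch(
  html: string,
  snippet: string,
): string {
  const m = html.match(/<head[^>]*>/i);
  if (m && m.index !== undefined) {
    const at = m.index + m[0].length;
    return html.slice(0, at) + snippet + html.slice(at);
  }
  return snippet + html;
}
```

The real helper may also handle missing `<html>` wrappers or DOCTYPE edge cases; the tests in this PR only assert that the injected markers end up in the final srcdoc.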

View File

@@ -86,11 +86,11 @@ export function ChatInput({
title:
next === "advanced"
? "Switched to Advanced model"
: "Switched to Standard model",
: "Switched to Balanced model",
description:
next === "advanced"
? "Using the highest-capability model."
: "Using the balanced standard model.",
: "Using the balanced default model.",
});
}

View File

@@ -162,10 +162,15 @@ describe("ChatInput mode toggle", () => {
expect(mockSetCopilotChatMode).toHaveBeenCalledWith("extended_thinking");
});
it("hides toggle button when streaming", () => {
it("hides toggle buttons when streaming", () => {
mockFlagValue = true;
render(<ChatInput onSend={mockOnSend} isStreaming />);
expect(screen.queryByLabelText(/switch to/i)).toBeNull();
expect(
screen.queryByLabelText(/switch to (fast|extended thinking) mode/i),
).toBeNull();
expect(
screen.queryByLabelText(/switch to (advanced|balanced|standard) model/i),
).toBeNull();
});
it("shows mode toggle when hasSession is true and not streaming", () => {
@@ -234,7 +239,7 @@ describe("ChatInput model toggle", () => {
mockFlagValue = true;
mockCopilotLlmModel = "advanced";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to standard model/i));
fireEvent.click(screen.getByLabelText(/switch to balanced model/i));
expect(mockSetCopilotLlmModel).toHaveBeenCalledWith("standard");
});
@@ -288,10 +293,10 @@ describe("ChatInput model toggle", () => {
mockFlagValue = true;
mockCopilotLlmModel = "advanced";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to standard model/i));
fireEvent.click(screen.getByLabelText(/switch to balanced model/i));
expect(toast).toHaveBeenCalledWith(
expect.objectContaining({
title: expect.stringMatching(/switched to standard model/i),
title: expect.stringMatching(/switched to balanced model/i),
}),
);
});

View File

@@ -2,6 +2,11 @@
import { cn } from "@/lib/utils";
import { Flask } from "@phosphor-icons/react";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
// This button is only rendered on NEW chats (no active session).
// Once a session exists, it is hidden — the session's dry_run flag is
@@ -14,27 +19,31 @@ interface Props {
export function DryRunToggleButton({ isDryRun, onToggle }: Props) {
return (
<button
type="button"
aria-pressed={isDryRun}
onClick={onToggle}
className={cn(
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
isDryRun
? "bg-amber-100 text-amber-900 hover:bg-amber-200"
: "text-neutral-500 hover:bg-neutral-100 hover:text-neutral-700",
)}
aria-label={
isDryRun ? "Test mode active — click to disable" : "Enable Test mode"
}
title={
isDryRun
? "Test mode ON — new chats run agents as simulation (click to disable)"
: "Enable Test mode — new chats will run agents as simulation"
}
>
<Flask size={14} />
{isDryRun && "Test"}
</button>
<Tooltip>
<TooltipTrigger asChild>
<button
type="button"
aria-pressed={isDryRun}
onClick={onToggle}
className={cn(
"inline-flex h-9 items-center justify-center gap-1 rounded-full border border-neutral-200 bg-white px-2.5 text-xs font-medium shadow-sm transition-colors hover:bg-neutral-50",
isDryRun
? "text-amber-900"
: "text-neutral-500 hover:text-neutral-700",
)}
aria-label={isDryRun ? "Test mode active" : "Enable Test mode"}
>
<Flask size={14} />
<span className="hidden sm:inline">
{isDryRun ? "Test mode enabled" : "Enable test mode"}
</span>
</button>
</TooltipTrigger>
<TooltipContent>
{isDryRun
? "Test mode on — new sessions run without performing real actions (click to turn off)."
: "Turn on test mode to try prompts without performing real actions."}
</TooltipContent>
</Tooltip>
);
}

View File

@@ -2,6 +2,11 @@
import { cn } from "@/lib/utils";
import { Brain, Lightning } from "@phosphor-icons/react";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
import type { CopilotMode } from "../../../store";
interface Props {
@@ -11,37 +16,42 @@ interface Props {
export function ModeToggleButton({ mode, onToggle }: Props) {
const isExtended = mode === "extended_thinking";
const tooltipText = isExtended
? "Extended Thinking — deeper reasoning (click to switch to Fast)"
: "Fast mode — quicker responses (click to switch to Thinking)";
return (
<button
type="button"
aria-pressed={isExtended}
onClick={onToggle}
className={cn(
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
isExtended
? "bg-purple-100 text-purple-900 hover:bg-purple-200"
: "bg-amber-100 text-amber-900 hover:bg-amber-200",
)}
aria-label={
isExtended ? "Switch to Fast mode" : "Switch to Extended Thinking mode"
}
title={
isExtended
? "Extended Thinking mode — deeper reasoning (click to switch to Fast mode)"
: "Fast mode — quicker responses (click to switch to Extended Thinking)"
}
>
{isExtended ? (
<>
<Brain size={14} />
Thinking
</>
) : (
<>
<Lightning size={14} />
Fast
</>
)}
</button>
<Tooltip>
<TooltipTrigger asChild>
<button
type="button"
aria-pressed={isExtended}
onClick={onToggle}
className={cn(
"ml-2 inline-flex h-9 items-center justify-center gap-1 rounded-full border border-neutral-200 bg-white px-2.5 text-xs font-medium shadow-sm transition-colors hover:bg-neutral-50",
isExtended ? "text-purple-900" : "text-amber-900",
)}
aria-label={
isExtended
? "Switch to Fast mode"
: "Switch to Extended Thinking mode"
}
>
{isExtended ? (
<>
<Brain size={14} />
Thinking
</>
) : (
<>
<Lightning size={14} />
Fast
</>
)}
</button>
</TooltipTrigger>
<TooltipContent>{tooltipText}</TooltipContent>
</Tooltip>
);
}

View File

@@ -2,6 +2,11 @@
import { cn } from "@/lib/utils";
import { Cpu } from "@phosphor-icons/react";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
import type { CopilotLlmModel } from "../../../store";
interface Props {
@@ -12,27 +17,33 @@ interface Props {
export function ModelToggleButton({ model, onToggle }: Props) {
const isAdvanced = model === "advanced";
return (
<button
type="button"
aria-pressed={isAdvanced}
onClick={onToggle}
className={cn(
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
isAdvanced
? "bg-sky-100 text-sky-900 hover:bg-sky-200"
: "text-neutral-500 hover:bg-neutral-100 hover:text-neutral-700",
)}
aria-label={
isAdvanced ? "Switch to Standard model" : "Switch to Advanced model"
}
title={
isAdvanced
? "Advanced model — highest capability (click to switch to Standard)"
: "Standard model — click to switch to Advanced"
}
>
<Cpu size={14} />
{isAdvanced && "Advanced"}
</button>
<Tooltip>
<TooltipTrigger asChild>
<button
type="button"
aria-pressed={isAdvanced}
onClick={onToggle}
className={cn(
"inline-flex h-9 items-center justify-center gap-1 rounded-full border border-neutral-200 bg-white px-2.5 text-xs font-medium shadow-sm transition-colors hover:bg-neutral-50",
isAdvanced
? "text-sky-900"
: "text-neutral-500 hover:text-neutral-700",
)}
aria-label={
isAdvanced ? "Switch to Balanced model" : "Switch to Advanced model"
}
>
<Cpu size={14} />
<span className="hidden sm:inline">
{isAdvanced ? "Advanced" : "Balanced"}
</span>
</button>
</TooltipTrigger>
<TooltipContent>
{isAdvanced
? "Using the highest-capability model (click to switch to Balanced)."
: "Using the balanced default model (click to switch to Advanced)."}
</TooltipContent>
</Tooltip>
);
}

View File

@@ -1,21 +1,32 @@
import { render, screen, fireEvent, cleanup } from "@testing-library/react";
import {
render as rtlRender,
screen,
fireEvent,
cleanup,
} from "@testing-library/react";
import { afterEach, describe, expect, it, vi } from "vitest";
import type { ReactElement } from "react";
import { TooltipProvider } from "@/components/ui/tooltip";
import { DryRunToggleButton } from "../DryRunToggleButton";
afterEach(cleanup);
function render(ui: ReactElement) {
return rtlRender(<TooltipProvider>{ui}</TooltipProvider>);
}
// DryRunToggleButton only appears on new chats (no active session).
// It has no readOnly/isStreaming props — those scenarios are handled by hiding
// the button entirely at the ChatInput level when hasSession is true.
describe("DryRunToggleButton", () => {
it("shows Test label when isDryRun is true", () => {
it("shows enabled label when isDryRun is true", () => {
render(<DryRunToggleButton isDryRun={true} onToggle={vi.fn()} />);
expect(screen.getByText("Test")).toBeTruthy();
expect(screen.getByText("Test mode enabled")).toBeTruthy();
});
it("shows no text label when isDryRun is false", () => {
it("shows enable label when isDryRun is false", () => {
render(<DryRunToggleButton isDryRun={false} onToggle={vi.fn()} />);
expect(screen.queryByText("Test")).toBeNull();
expect(screen.getByText("Enable test mode")).toBeTruthy();
});
it("calls onToggle when clicked", () => {

View File

@@ -1,9 +1,20 @@
import { render, screen, fireEvent, cleanup } from "@testing-library/react";
import {
render as rtlRender,
screen,
fireEvent,
cleanup,
} from "@testing-library/react";
import { afterEach, describe, expect, it, vi } from "vitest";
import type { ReactElement } from "react";
import { TooltipProvider } from "@/components/ui/tooltip";
import { ModelToggleButton } from "../ModelToggleButton";
afterEach(cleanup);
function render(ui: ReactElement) {
return rtlRender(<TooltipProvider>{ui}</TooltipProvider>);
}
describe("ModelToggleButton", () => {
it("shows no text label when model is standard", () => {
render(<ModelToggleButton model="standard" onToggle={vi.fn()} />);
@@ -31,7 +42,7 @@ describe("ModelToggleButton", () => {
it("sets aria-pressed=true for advanced", () => {
render(<ModelToggleButton model="advanced" onToggle={vi.fn()} />);
const btn = screen.getByLabelText("Switch to Standard model");
const btn = screen.getByLabelText("Switch to Balanced model");
expect(btn.getAttribute("aria-pressed")).toBe("true");
});
});

View File

@@ -0,0 +1,147 @@
import { render, screen, cleanup } from "@/tests/integrations/test-utils";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { ChatMessagesContainer } from "../ChatMessagesContainer";
const mockScrollEl = {
scrollHeight: 100,
scrollTop: 0,
clientHeight: 500,
};
vi.mock("use-stick-to-bottom", () => ({
useStickToBottomContext: () => ({ scrollRef: { current: mockScrollEl } }),
Conversation: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
ConversationContent: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
ConversationScrollButton: () => null,
}));
vi.mock("@/components/ai-elements/conversation", () => ({
Conversation: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
ConversationContent: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
ConversationScrollButton: () => null,
}));
vi.mock("@/components/ai-elements/message", () => ({
Message: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
MessageContent: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
MessageActions: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
}));
vi.mock("../components/AssistantMessageActions", () => ({
AssistantMessageActions: () => null,
}));
vi.mock("../components/CopyButton", () => ({ CopyButton: () => null }));
vi.mock("../components/CollapsedToolGroup", () => ({
CollapsedToolGroup: () => null,
}));
vi.mock("../components/MessageAttachments", () => ({
MessageAttachments: () => null,
}));
vi.mock("../components/MessagePartRenderer", () => ({
MessagePartRenderer: () => null,
}));
vi.mock("../components/ReasoningCollapse", () => ({
ReasoningCollapse: () => null,
}));
vi.mock("../components/ThinkingIndicator", () => ({
ThinkingIndicator: () => null,
}));
vi.mock("../../JobStatsBar/TurnStatsBar", () => ({
TurnStatsBar: () => null,
}));
vi.mock("../../JobStatsBar/useElapsedTimer", () => ({
useElapsedTimer: () => ({ elapsedSeconds: 0 }),
}));
vi.mock("../../CopilotPendingReviews/CopilotPendingReviews", () => ({
CopilotPendingReviews: () => null,
}));
vi.mock("../helpers", () => ({
buildRenderSegments: () => [],
getTurnMessages: () => [],
parseSpecialMarkers: () => ({ markerType: null }),
splitReasoningAndResponse: (parts: unknown[]) => ({
reasoningParts: [],
responseParts: parts,
}),
}));
type ObserverCallback = (entries: { isIntersecting: boolean }[]) => void;
class MockIntersectionObserver {
static lastCallback: ObserverCallback | null = null;
private callback: ObserverCallback;
constructor(cb: ObserverCallback) {
this.callback = cb;
MockIntersectionObserver.lastCallback = cb;
}
observe() {}
disconnect() {}
unobserve() {}
takeRecords() {
return [];
}
root = null;
rootMargin = "";
thresholds = [];
}
const BASE_PROPS = {
messages: [],
status: "ready" as const,
error: undefined,
isLoading: false,
sessionID: "sess-1",
hasMoreMessages: true,
isLoadingMore: false,
onLoadMore: vi.fn(),
onRetry: vi.fn(),
};
describe("ChatMessagesContainer", () => {
beforeEach(() => {
mockScrollEl.scrollHeight = 100;
mockScrollEl.scrollTop = 0;
mockScrollEl.clientHeight = 500;
MockIntersectionObserver.lastCallback = null;
vi.stubGlobal("IntersectionObserver", MockIntersectionObserver);
});
afterEach(() => {
cleanup();
vi.unstubAllGlobals();
});
it("renders top sentinel for backward pagination", () => {
render(<ChatMessagesContainer {...BASE_PROPS} />);
expect(
screen.getByRole("button", { name: /load older messages/i }),
).toBeDefined();
});
it("hides sentinel when hasMoreMessages is false", () => {
render(<ChatMessagesContainer {...BASE_PROPS} hasMoreMessages={false} />);
expect(
screen.queryByRole("button", { name: /load older messages/i }),
).toBeNull();
});
it("hides sentinel when onLoadMore is not provided", () => {
render(<ChatMessagesContainer {...BASE_PROPS} onLoadMore={undefined} />);
expect(
screen.queryByRole("button", { name: /load older messages/i }),
).toBeNull();
});
});

View File

@@ -246,7 +246,7 @@ export function ChatSidebar() {
</SidebarHeader>
)}
{!isCollapsed && (
<SidebarHeader className="shrink-0 px-4 pb-4 pt-4 shadow-[0_4px_6px_-1px_rgba(0,0,0,0.05)]">
<SidebarHeader className="shrink-0 px-4 pb-3 pt-3 shadow-[0_4px_6px_-1px_rgba(0,0,0,0.05)]">
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}

View File

@@ -13,6 +13,10 @@ import {
getSuggestionThemes,
} from "./helpers";
import { SuggestionThemes } from "./components/SuggestionThemes/SuggestionThemes";
import { PulseChips } from "../PulseChips/PulseChips";
import { usePulseChips } from "../PulseChips/usePulseChips";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { EditNameDialog } from "./components/EditNameDialog/EditNameDialog";
interface Props {
inputLayoutId: string;
@@ -34,6 +38,8 @@ export function EmptySession({
}: Props) {
const { user } = useSupabase();
const greetingName = getGreetingName(user);
const isAgentBriefingEnabled = useGetFlag(Flag.AGENT_BRIEFING);
const pulseChips = usePulseChips();
const { data: suggestedPromptsResponse, isLoading: isLoadingPrompts } =
useGetV2GetSuggestedPrompts({
@@ -75,11 +81,16 @@ export function EmptySession({
<div className="mx-auto max-w-[52rem]">
<Text variant="h3" className="mb-1 !text-[1.375rem] text-zinc-700">
Hey, <span className="text-violet-600">{greetingName}</span>
<EditNameDialog currentName={greetingName} />
</Text>
<Text variant="h3" className="mb-8 !font-normal">
Tell me about your work — I&apos;ll find what to automate.
</Text>
{isAgentBriefingEnabled && (
<PulseChips chips={pulseChips} onChipClick={onSend} />
)}
<div className="mb-6">
<motion.div
layoutId={inputLayoutId}

View File

@@ -0,0 +1,107 @@
"use client";
import { Button } from "@/components/atoms/Button/Button";
import { Input } from "@/components/atoms/Input/Input";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { PencilSimpleIcon } from "@phosphor-icons/react";
import { useState } from "react";
interface Props {
currentName: string;
}
export function EditNameDialog({ currentName }: Props) {
const [isOpen, setIsOpen] = useState(false);
const [name, setName] = useState(currentName);
const [isSaving, setIsSaving] = useState(false);
const { refreshSession } = useSupabase();
const { toast } = useToast();
function handleOpenChange(open: boolean) {
if (open) setName(currentName);
setIsOpen(open);
}
async function handleSave() {
const trimmed = name.trim();
if (!trimmed) return;
setIsSaving(true);
try {
const res = await fetch("/api/auth/user", {
method: "PUT",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ full_name: trimmed }),
});
if (!res.ok) {
const body = await res.json();
toast({
title: "Failed to update name",
description: body.error ?? "Unknown error",
variant: "destructive",
});
return;
}
const session = await refreshSession();
if (session?.error) {
toast({
title: "Name saved, but session refresh failed",
description: session.error,
variant: "destructive",
});
setIsOpen(false);
return;
}
setIsOpen(false);
toast({ title: "Name updated" });
} finally {
setIsSaving(false);
}
}
return (
<Dialog
title="Edit display name"
styling={{ maxWidth: "24rem" }}
controlled={{ isOpen, set: handleOpenChange }}
>
<Dialog.Trigger>
<button
type="button"
className="ml-1 inline-flex items-center text-violet-500 transition-colors hover:text-violet-700"
>
<PencilSimpleIcon size={16} />
</button>
</Dialog.Trigger>
<Dialog.Content>
<div className="flex flex-col gap-4 px-1">
<Input
id="display-name"
label="Display name"
placeholder="Your name"
value={name}
onChange={(e) => setName(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter") {
e.preventDefault();
handleSave();
}
}}
/>
<Button
variant="primary"
onClick={handleSave}
disabled={!name.trim() || isSaving}
loading={isSaving}
>
Save
</Button>
</div>
</Dialog.Content>
</Dialog>
);
}
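`handleSave` above has three toast outcomes: API failure, save-succeeded-but-refresh-failed, and full success — exactly the cases the test file below exercises. Summarized as a decision table (an illustrative helper, not code from the component):

```typescript
// Sketch of EditNameDialog's save outcomes: PUT /api/auth/user first,
// then refreshSession(); each failure point gets its own toast.
type SaveOutcome =
  | { title: "Failed to update name"; variant: "destructive" }
  | { title: "Name saved, but session refresh failed"; variant: "destructive" }
  | { title: "Name updated" };

export function saveOutcome(
  apiOk: boolean,
  refreshError?: string,
): SaveOutcome {
  if (!apiOk) {
    // The PUT failed; nothing was saved, so refreshSession is never called.
    return { title: "Failed to update name", variant: "destructive" };
  }
  if (refreshError) {
    // The name is persisted, but the local session still shows the old one.
    return {
      title: "Name saved, but session refresh failed",
      variant: "destructive",
    };
  }
  return { title: "Name updated" };
}
```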

View File

@@ -0,0 +1,135 @@
import { beforeEach, describe, expect, test, vi } from "vitest";
import {
fireEvent,
render,
screen,
waitFor,
} from "@/tests/integrations/test-utils";
import { server } from "@/mocks/mock-server";
import { http, HttpResponse } from "msw";
import { EditNameDialog } from "../EditNameDialog";
const mockToast = vi.hoisted(() => vi.fn());
const mockRefreshSession = vi.hoisted(() => vi.fn());
vi.mock("@/components/molecules/Toast/use-toast", () => ({
useToast: () => ({ toast: mockToast }),
}));
vi.mock("@/lib/supabase/hooks/useSupabase", () => ({
useSupabase: () => ({
refreshSession: mockRefreshSession,
}),
}));
function mockUpdateNameSuccess() {
server.use(
http.put("/api/auth/user", () => {
return HttpResponse.json({ user: { id: "u1" } });
}),
);
}
function mockUpdateNameError(message = "Network error") {
server.use(
http.put("/api/auth/user", () => {
return HttpResponse.json({ error: message }, { status: 400 });
}),
);
}
async function openDialogAndGetInput() {
const trigger = screen.getByRole("button");
fireEvent.click(trigger);
await screen.findAllByLabelText(/display name/i);
const inputs =
document.querySelectorAll<HTMLInputElement>("input#display-name");
return inputs[0];
}
function getSaveButton() {
const saves = screen.getAllByRole("button", { name: /save/i });
return saves[0] as HTMLButtonElement;
}
describe("EditNameDialog", () => {
beforeEach(() => {
mockToast.mockReset();
mockRefreshSession.mockReset();
mockRefreshSession.mockResolvedValue({ user: { id: "u1" } });
});
test("opens dialog with current name prefilled", async () => {
mockUpdateNameSuccess();
render(<EditNameDialog currentName="Alice" />);
const input = await openDialogAndGetInput();
expect(input.value).toBe("Alice");
});
test("saves name via API route and closes dialog", async () => {
mockUpdateNameSuccess();
render(<EditNameDialog currentName="Alice" />);
const input = await openDialogAndGetInput();
fireEvent.change(input, { target: { value: "Bob" } });
fireEvent.click(getSaveButton());
await waitFor(() => {
expect(mockRefreshSession).toHaveBeenCalled();
});
expect(mockToast).toHaveBeenCalledWith({ title: "Name updated" });
});
test("shows error toast when API returns error", async () => {
mockUpdateNameError("Network error");
render(<EditNameDialog currentName="Alice" />);
const input = await openDialogAndGetInput();
fireEvent.change(input, { target: { value: "Bob" } });
fireEvent.click(getSaveButton());
await waitFor(() => {
expect(mockToast).toHaveBeenCalledWith(
expect.objectContaining({
title: "Failed to update name",
description: "Network error",
variant: "destructive",
}),
);
});
expect(mockRefreshSession).not.toHaveBeenCalled();
});
test("shows warning toast when refreshSession returns an error", async () => {
mockUpdateNameSuccess();
mockRefreshSession.mockResolvedValue({ error: "refresh failed" });
render(<EditNameDialog currentName="Alice" />);
const input = await openDialogAndGetInput();
fireEvent.change(input, { target: { value: "Bob" } });
fireEvent.click(getSaveButton());
await waitFor(() => {
expect(mockToast).toHaveBeenCalledWith(
expect.objectContaining({
title: "Name saved, but session refresh failed",
description: "refresh failed",
variant: "destructive",
}),
);
});
expect(mockToast).not.toHaveBeenCalledWith({ title: "Name updated" });
});
test("disables Save button while input is empty", async () => {
mockUpdateNameSuccess();
render(<EditNameDialog currentName="Alice" />);
const input = await openDialogAndGetInput();
fireEvent.change(input, { target: { value: " " } });
expect(getSaveButton().disabled).toBe(true);
});
});

View File

@@ -0,0 +1,93 @@
.glassPanel {
position: relative;
isolation: isolate;
}
.glassPanel::before {
content: "";
position: absolute;
inset: 0;
border-radius: inherit;
padding: 1px;
background: conic-gradient(
from var(--border-angle, 0deg),
rgba(129, 120, 228, 0.08),
rgba(129, 120, 228, 0.28),
rgba(168, 130, 255, 0.18),
rgba(129, 120, 228, 0.08),
rgba(99, 102, 241, 0.24),
rgba(129, 120, 228, 0.08)
);
-webkit-mask:
linear-gradient(#000 0 0) content-box,
linear-gradient(#000 0 0);
mask:
linear-gradient(#000 0 0) content-box,
linear-gradient(#000 0 0);
-webkit-mask-composite: xor;
mask-composite: exclude;
animation: rotate-border 6s linear infinite;
pointer-events: none;
z-index: -1;
}
@property --border-angle {
syntax: "<angle>";
initial-value: 0deg;
inherits: false;
}
@keyframes rotate-border {
to {
--border-angle: 360deg;
}
}
.chip {
overflow: hidden;
}
@media (hover: hover) {
.chip {
padding-bottom: 0.9rem;
}
}
@media (hover: none) {
.chip {
padding-bottom: 2.25rem;
}
}
.chipActions {
position: absolute;
inset-inline: 0;
bottom: 0;
background: rgba(255, 255, 255, 0.95);
backdrop-filter: blur(4px);
-webkit-backdrop-filter: blur(4px);
}
@media (hover: hover) {
.chipActions {
opacity: 0;
transform: translateY(100%);
transition:
opacity 0.2s ease-out,
transform 0.2s ease-out;
}
.chip:hover .chipActions {
opacity: 1;
transform: translateY(0);
}
.chipContent {
/* transition both properties changed on hover; opacity was set on
   .chip:hover .chipContent but not transitioned, so it snapped */
transition:
filter 0.2s ease-out,
opacity 0.2s ease-out;
}
.chip:hover .chipContent {
filter: blur(2px);
opacity: 0.5;
}
}


@@ -0,0 +1,116 @@
"use client";
import { Text } from "@/components/atoms/Text/Text";
import {
ArrowRightIcon,
EyeIcon,
ChatCircleDotsIcon,
} from "@phosphor-icons/react";
import NextLink from "next/link";
import { StatusBadge } from "@/app/(platform)/library/components/StatusBadge/StatusBadge";
import styles from "./PulseChips.module.css";
import type { PulseChipData } from "./types";
interface Props {
chips: PulseChipData[];
onChipClick?: (prompt: string) => void;
}
export function PulseChips({ chips, onChipClick }: Props) {
if (chips.length === 0) return null;
return (
<div
className={`${styles.glassPanel} mx-[0.6875rem] mb-5 rounded-large p-5`}
>
<div className="mb-3 flex items-center gap-3">
<Text variant="body-medium" className="text-zinc-600">
What&apos;s happening with your agents
</Text>
<NextLink
href="/library"
className="flex items-center gap-1 text-xs text-zinc-500 hover:text-zinc-700"
>
View all <ArrowRightIcon size={12} />
</NextLink>
</div>
<div className="flex gap-2 overflow-x-auto pb-1 scrollbar-thin scrollbar-track-transparent scrollbar-thumb-zinc-300">
{chips.map((chip) => (
<PulseChip key={chip.id} chip={chip} onAsk={onChipClick} />
))}
</div>
</div>
);
}
interface ChipProps {
chip: PulseChipData;
onAsk?: (prompt: string) => void;
}
function PulseChip({ chip, onAsk }: ChipProps) {
function handleAsk() {
const prompt = buildChipPrompt(chip);
onAsk?.(prompt);
}
return (
<div
className={`${styles.chip} relative flex w-[15rem] shrink-0 flex-col items-start gap-2 rounded-medium border border-zinc-100 bg-white px-3 py-2`}
>
<div className={`${styles.chipContent} w-full text-left`}>
{chip.priority === "success" ? (
<span className="inline-flex items-center gap-1.5 rounded-full px-2 py-0.5 text-xs font-medium text-emerald-600">
<span className="h-1.5 w-1.5 rounded-full bg-emerald-500" />
Completed
</span>
) : (
<StatusBadge status={chip.status} />
)}
<div className="mt-2 min-w-0">
<Text variant="small-medium" className="truncate text-zinc-900">
{chip.name}
</Text>
<Text variant="small" className="truncate text-zinc-500">
{chip.shortMessage}
</Text>
</div>
</div>
<div
className={`${styles.chipActions} flex items-center justify-center gap-1.5 rounded-b-medium px-3 py-1.5`}
>
<NextLink
href={`/library/agents/${chip.agentID}`}
className="flex items-center gap-1 rounded-md px-2 py-1 text-xs text-zinc-500 transition-colors hover:bg-zinc-100 hover:text-zinc-700"
>
<EyeIcon size={14} />
See
</NextLink>
<button
type="button"
onClick={handleAsk}
className="flex items-center gap-1 rounded-md px-2 py-1 text-xs text-zinc-500 transition-colors hover:bg-zinc-100 hover:text-zinc-700"
>
<ChatCircleDotsIcon size={14} />
Ask
</button>
</div>
</div>
);
}
function buildChipPrompt(chip: PulseChipData): string {
if (chip.priority === "success") {
return `${chip.name} just finished a run — can you summarize what it did?`;
}
switch (chip.status) {
case "error":
return `What happened with ${chip.name}? It has an error — can you check?`;
case "running":
return `Give me a status update on ${chip.name} — what has it done so far?`;
case "idle":
return `${chip.name} hasn't run recently. Should I keep it or update and re-run it?`;
default:
return `Tell me about ${chip.name} — what's its current status?`;
}
}


@@ -0,0 +1,105 @@
import { describe, expect, test, vi } from "vitest";
import { render, screen, fireEvent } from "@/tests/integrations/test-utils";
import { PulseChips } from "../PulseChips";
import type { PulseChipData } from "../types";
function makeChip(overrides: Partial<PulseChipData> = {}): PulseChipData {
return {
id: "chip-1",
agentID: "agent-1",
name: "Test Agent",
status: "running",
priority: "running",
shortMessage: "Doing work…",
...overrides,
};
}
describe("PulseChips", () => {
test("renders nothing when chips array is empty", () => {
const { container } = render(<PulseChips chips={[]} />);
expect(container.innerHTML).toBe("");
});
test("renders chip names and messages", () => {
const chips = [
makeChip({ id: "1", name: "Alpha Bot", shortMessage: "Running task A" }),
makeChip({ id: "2", name: "Beta Bot", shortMessage: "Running task B" }),
];
render(<PulseChips chips={chips} />);
expect(screen.getByText("Alpha Bot")).toBeDefined();
expect(screen.getByText("Running task A")).toBeDefined();
expect(screen.getByText("Beta Bot")).toBeDefined();
expect(screen.getByText("Running task B")).toBeDefined();
});
test("renders section heading and View all link", () => {
render(<PulseChips chips={[makeChip()]} />);
expect(screen.getByText("What's happening with your agents")).toBeDefined();
expect(screen.getByText("View all")).toBeDefined();
});
test("shows Completed badge for success priority chips", () => {
render(
<PulseChips
chips={[makeChip({ priority: "success", status: "idle" })]}
/>,
);
expect(screen.getByText("Completed")).toBeDefined();
});
test("calls onChipClick with generated prompt when Ask is clicked", () => {
const onChipClick = vi.fn();
render(
<PulseChips
chips={[
makeChip({
name: "Error Agent",
status: "error",
priority: "error",
}),
]}
onChipClick={onChipClick}
/>,
);
fireEvent.click(screen.getByText("Ask"));
expect(onChipClick).toHaveBeenCalledWith(
"What happened with Error Agent? It has an error — can you check?",
);
});
test("generates success prompt for completed chips", () => {
const onChipClick = vi.fn();
render(
<PulseChips
chips={[
makeChip({
name: "Done Agent",
priority: "success",
status: "idle",
}),
]}
onChipClick={onChipClick}
/>,
);
fireEvent.click(screen.getByText("Ask"));
expect(onChipClick).toHaveBeenCalledWith(
"Done Agent just finished a run — can you summarize what it did?",
);
});
test("renders See link pointing to agent detail page", () => {
render(<PulseChips chips={[makeChip({ agentID: "agent-xyz" })]} />);
const seeLink = screen.getByText("See").closest("a");
expect(seeLink?.getAttribute("href")).toBe("/library/agents/agent-xyz");
});
});


@@ -0,0 +1,13 @@
import type {
AgentStatus,
SitrepPriority,
} from "@/app/(platform)/library/types";
export interface PulseChipData {
id: string;
agentID: string;
name: string;
status: AgentStatus;
priority: SitrepPriority;
shortMessage: string;
}


@@ -0,0 +1,25 @@
"use client";
import { useLibraryAgents } from "@/hooks/useLibraryAgents/useLibraryAgents";
import { useSitrepItems } from "@/app/(platform)/library/components/SitrepItem/useSitrepItems";
import type { PulseChipData } from "./types";
import { useMemo } from "react";
const THREE_DAYS_MS = 3 * 24 * 60 * 60 * 1000;
export function usePulseChips(): PulseChipData[] {
const { agents } = useLibraryAgents();
const sitrepItems = useSitrepItems(agents, 5, THREE_DAYS_MS);
return useMemo(() => {
return sitrepItems.map((item) => ({
id: item.id,
agentID: item.agentID,
name: item.agentName,
status: item.status,
priority: item.priority,
shortMessage: item.message,
}));
}, [sitrepItems]);
}


@@ -6,6 +6,9 @@ import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { useRouter } from "next/navigation";
import { useEffect, useRef } from "react";
import { useResetRateLimit } from "../../hooks/useResetRateLimit";
import { formatCents } from "../usageHelpers";
export { formatCents };
interface Props {
isOpen: boolean;
@@ -18,10 +21,6 @@ interface Props {
onCreditChange?: () => void;
}
export function formatCents(cents: number): string {
return `$${(cents / 100).toFixed(2)}`;
}
export function RateLimitResetDialog({
isOpen,
onClose,


@@ -1,35 +1,10 @@
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
import { Button } from "@/components/atoms/Button/Button";
import Link from "next/link";
import { formatCents } from "../RateLimitResetDialog/RateLimitResetDialog";
import { formatCents, formatResetTime } from "../usageHelpers";
import { useResetRateLimit } from "../../hooks/useResetRateLimit";
export function formatResetTime(
resetsAt: Date | string,
now: Date = new Date(),
): string {
const resetDate =
typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
const diffMs = resetDate.getTime() - now.getTime();
if (diffMs <= 0) return "now";
const hours = Math.floor(diffMs / (1000 * 60 * 60));
// Under 24h: show relative time ("in 4h 23m")
if (hours < 24) {
const minutes = Math.floor((diffMs % (1000 * 60 * 60)) / (1000 * 60));
if (hours > 0) return `in ${hours}h ${minutes}m`;
return `in ${minutes}m`;
}
// Over 24h: show day and time in local timezone ("Mon 12:00 AM PST")
return resetDate.toLocaleString(undefined, {
weekday: "short",
hour: "numeric",
minute: "2-digit",
timeZoneName: "short",
});
}
export { formatResetTime };
function UsageBar({
label,


@@ -0,0 +1,28 @@
export function formatCents(cents: number): string {
return `$${(cents / 100).toFixed(2)}`;
}
export function formatResetTime(
resetsAt: Date | string,
now: Date = new Date(),
): string {
const resetDate =
typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
const diffMs = resetDate.getTime() - now.getTime();
if (diffMs <= 0) return "now";
const hours = Math.floor(diffMs / (1000 * 60 * 60));
if (hours < 24) {
const minutes = Math.floor((diffMs % (1000 * 60 * 60)) / (1000 * 60));
if (hours > 0) return `in ${hours}h ${minutes}m`;
return `in ${minutes}m`;
}
return resetDate.toLocaleString(undefined, {
weekday: "short",
hour: "numeric",
minute: "2-digit",
timeZoneName: "short",
});
}


@@ -0,0 +1,59 @@
import { describe, expect, it } from "vitest";
import { convertChatSessionMessagesToUiMessages } from "../convertChatSessionToUiMessages";
const SESSION_ID = "sess-test";
describe("convertChatSessionMessagesToUiMessages", () => {
it("does not drop user messages with null content", () => {
const result = convertChatSessionMessagesToUiMessages(
SESSION_ID,
[{ role: "user", content: null, sequence: 0 }],
{ isComplete: true },
);
expect(result.messages).toHaveLength(1);
expect(result.messages[0].role).toBe("user");
});
it("does not drop user messages with empty string content", () => {
const result = convertChatSessionMessagesToUiMessages(
SESSION_ID,
[{ role: "user", content: "", sequence: 0 }],
{ isComplete: true },
);
expect(result.messages).toHaveLength(1);
expect(result.messages[0].role).toBe("user");
});
it("still drops non-user messages with null content", () => {
const result = convertChatSessionMessagesToUiMessages(
SESSION_ID,
[{ role: "assistant", content: null, sequence: 0 }],
{ isComplete: true },
);
expect(result.messages).toHaveLength(0);
});
it("still drops non-user messages with empty string content", () => {
const result = convertChatSessionMessagesToUiMessages(
SESSION_ID,
[{ role: "assistant", content: "", sequence: 0 }],
{ isComplete: true },
);
expect(result.messages).toHaveLength(0);
});
it("includes user message with normal content", () => {
const result = convertChatSessionMessagesToUiMessages(
SESSION_ID,
[{ role: "user", content: "hello", sequence: 0 }],
{ isComplete: true },
);
expect(result.messages).toHaveLength(1);
expect(result.messages[0].role).toBe("user");
});
});


@@ -253,6 +253,11 @@ export function convertChatSessionMessagesToUiMessages(
}
}
// User messages must always be rendered, even with empty content, so the
// initial prompt is visible when reloading a session.
if (parts.length === 0 && msg.role === "user") {
parts.push({ type: "text", text: "", state: "done" });
}
if (parts.length === 0) return;
// Merge consecutive assistant messages into a single UIMessage


@@ -84,7 +84,7 @@ export function useCopilotPage() {
copilotModel: isModeToggleEnabled ? copilotLlmModel : undefined,
});
const { olderMessages, hasMore, isLoadingMore, loadMore } =
const { pagedMessages, hasMore, isLoadingMore, loadMore } =
useLoadMoreMessages({
sessionId,
initialOldestSequence: oldestSequence,
@@ -92,10 +92,11 @@ export function useCopilotPage() {
initialPageRawMessages: rawSessionMessages,
});
// Combine older (paginated) messages with current page messages,
// merging consecutive assistant UIMessages at the page boundary so
// reasoning + response parts stay in a single bubble.
const messages = concatWithAssistantMerge(olderMessages, currentMessages);
// Combine paginated messages with current page messages, merging consecutive
// assistant UIMessages at the page boundary so reasoning + response parts
// stay in a single bubble. Paged messages are older history prepended before
// the current page.
const messages = concatWithAssistantMerge(pagedMessages, currentMessages);
useCopilotNotifications(sessionId);


@@ -23,10 +23,10 @@ export function useLoadMoreMessages({
initialHasMore,
initialPageRawMessages,
}: UseLoadMoreMessagesArgs) {
// Store accumulated raw messages from all older pages (in ascending order).
// Accumulated raw messages from all extra pages (ascending order).
// Re-converting them all together ensures tool outputs are matched across
// inter-page boundaries.
const [olderRawMessages, setOlderRawMessages] = useState<unknown[]>([]);
const [pagedRawMessages, setPagedRawMessages] = useState<unknown[]>([]);
const [oldestSequence, setOldestSequence] = useState<number | null>(
initialOldestSequence,
);
@@ -37,16 +37,14 @@ export function useLoadMoreMessages({
// Epoch counter to discard stale loadMore responses after a reset
const epochRef = useRef(0);
// Track the sessionId and initial cursor to reset state on change
const prevSessionIdRef = useRef(sessionId);
const prevInitialOldestRef = useRef(initialOldestSequence);
// Sync initial values from parent when they change.
//
// The parent's `initialOldestSequence` drifts forward every time the
// session query refetches (e.g. after a stream completes — see
// `useCopilotStream` invalidation on `streaming → ready`). If we
// wiped `olderRawMessages` every time that happened, users who had
// wiped `pagedRawMessages` every time that happened, users who had
// scrolled back would lose their loaded history on each new turn and
// subsequent `loadMore` calls would fetch messages that overlap with
// the AI SDK's retained state in `currentMessages`, producing visible
@@ -62,8 +60,7 @@ export function useLoadMoreMessages({
if (prevSessionIdRef.current !== sessionId) {
// Session changed — full reset
prevSessionIdRef.current = sessionId;
prevInitialOldestRef.current = initialOldestSequence;
setOlderRawMessages([]);
setPagedRawMessages([]);
setOldestSequence(initialOldestSequence);
setHasMore(initialHasMore);
setIsLoadingMore(false);
@@ -73,42 +70,35 @@ export function useLoadMoreMessages({
return;
}
prevInitialOldestRef.current = initialOldestSequence;
// If we haven't paged back yet, mirror the parent so the first
// If we haven't paged yet, mirror the parent so the first
// `loadMore` starts from the correct cursor.
if (olderRawMessages.length === 0) {
if (pagedRawMessages.length === 0) {
setOldestSequence(initialOldestSequence);
setHasMore(initialHasMore);
}
}, [sessionId, initialOldestSequence, initialHasMore]);
// Convert all accumulated raw messages in one pass so tool outputs
// are matched across inter-page boundaries. Initial page tool outputs
// are included via extraToolOutputs to handle the boundary between
// the last older page and the initial/streaming page.
const olderMessages: UIMessage<unknown, UIDataTypes, UITools>[] =
// are matched across inter-page boundaries.
// Include initial page tool outputs so older paged pages can match
// tool calls whose outputs landed in the initial page.
const pagedMessages: UIMessage<unknown, UIDataTypes, UITools>[] =
useMemo(() => {
if (!sessionId || olderRawMessages.length === 0) return [];
if (!sessionId || pagedRawMessages.length === 0) return [];
const extraToolOutputs =
initialPageRawMessages.length > 0
? extractToolOutputsFromRaw(initialPageRawMessages)
: undefined;
return convertChatSessionMessagesToUiMessages(
sessionId,
olderRawMessages,
pagedRawMessages,
{ isComplete: true, extraToolOutputs },
).messages;
}, [sessionId, olderRawMessages, initialPageRawMessages]);
}, [sessionId, pagedRawMessages, initialPageRawMessages]);
async function loadMore() {
if (
!sessionId ||
!hasMore ||
isLoadingMoreRef.current ||
oldestSequence === null
)
return;
if (!sessionId || !hasMore || isLoadingMoreRef.current) return;
if (oldestSequence === null) return;
const requestEpoch = epochRef.current;
isLoadingMoreRef.current = true;
@@ -136,15 +126,20 @@ export function useLoadMoreMessages({
consecutiveErrorsRef.current = 0;
const newRaw = (response.data.messages ?? []) as unknown[];
setOlderRawMessages((prev) => {
const estimatedTotal = pagedRawMessages.length + newRaw.length;
setPagedRawMessages((prev) => {
const merged = [...newRaw, ...prev];
if (merged.length > MAX_OLDER_MESSAGES) {
return merged.slice(merged.length - MAX_OLDER_MESSAGES);
}
return merged;
});
// Note: after truncation, oldest_sequence may reference a dropped
// message. This is safe because we also set hasMore=false below,
// preventing further loads with the stale cursor.
setOldestSequence(response.data.oldest_sequence ?? null);
if (newRaw.length + olderRawMessages.length >= MAX_OLDER_MESSAGES) {
if (estimatedTotal >= MAX_OLDER_MESSAGES) {
setHasMore(false);
} else {
setHasMore(!!response.data.has_more_messages);
@@ -164,5 +159,5 @@ export function useLoadMoreMessages({
}
}
return { olderMessages, hasMore, isLoadingMore, loadMore };
return { pagedMessages, hasMore, isLoadingMore, loadMore };
}


@@ -2,14 +2,17 @@ import { Navbar } from "@/components/layout/Navbar/Navbar";
import { NetworkStatusMonitor } from "@/services/network-status/NetworkStatusMonitor";
import { ReactNode } from "react";
import { AdminImpersonationBanner } from "./admin/components/AdminImpersonationBanner";
import { AutoPilotBridgeProvider } from "@/contexts/AutoPilotBridgeContext";
export default function PlatformLayout({ children }: { children: ReactNode }) {
return (
<main className="flex h-screen w-full flex-col">
<NetworkStatusMonitor />
<Navbar />
<AdminImpersonationBanner />
<section className="flex-1">{children}</section>
</main>
<AutoPilotBridgeProvider>
<main className="flex h-screen w-full flex-col">
<NetworkStatusMonitor />
<Navbar />
<AdminImpersonationBanner />
<section className="flex-1">{children}</section>
</main>
</AutoPilotBridgeProvider>
);
}


@@ -137,8 +137,10 @@ describe("LibraryPage", () => {
user_id: "test-user",
name: "Work Agents",
agent_count: 3,
subfolder_count: 0,
color: null,
icon: null,
parent_id: null,
created_at: new Date(),
updated_at: new Date(),
},
@@ -147,8 +149,10 @@ describe("LibraryPage", () => {
user_id: "test-user",
name: "Personal",
agent_count: 1,
subfolder_count: 0,
color: null,
icon: null,
parent_id: null,
created_at: new Date(),
updated_at: new Date(),
},
@@ -158,12 +162,14 @@ describe("LibraryPage", () => {
render(<LibraryPage />);
await waitForAgentsToLoad();
expect(await screen.findByText("Work Agents")).toBeDefined();
expect(screen.getByText("Personal")).toBeDefined();
expect(screen.getAllByTestId("library-folder")).toHaveLength(2);
});
test("shows See runs link on agent card", async () => {
test("shows See tasks link on agent card", async () => {
setupHandlers({
agents: [makeAgent({ name: "Linked Agent", can_access_graph: true })],
});
@@ -172,7 +178,7 @@ describe("LibraryPage", () => {
await screen.findByText("Linked Agent");
const runLinks = screen.getAllByText("See runs");
const runLinks = screen.getAllByText("See tasks");
expect(runLinks.length).toBeGreaterThan(0);
});
@@ -190,7 +196,7 @@ describe("LibraryPage", () => {
expect(importButtons.length).toBeGreaterThan(0);
});
test("renders Jump Back In when there is an active execution", async () => {
test("renders running agent card when execution is active", async () => {
const agent = makeAgent({
id: "lib-1",
graph_id: "g-1",
@@ -218,6 +224,6 @@ describe("LibraryPage", () => {
render(<LibraryPage />);
expect(await screen.findByText("Jump Back In")).toBeDefined();
expect(await screen.findByText("Running Agent")).toBeDefined();
});
});


@@ -0,0 +1,44 @@
.glassPanel {
position: relative;
isolation: isolate;
}
.glassPanel::before {
content: "";
position: absolute;
inset: 0;
border-radius: inherit;
padding: 1px;
background: conic-gradient(
from var(--border-angle, 0deg),
rgba(129, 120, 228, 0.04),
rgba(129, 120, 228, 0.14),
rgba(168, 130, 255, 0.09),
rgba(129, 120, 228, 0.04),
rgba(99, 102, 241, 0.12),
rgba(129, 120, 228, 0.04)
);
-webkit-mask:
linear-gradient(#000 0 0) content-box,
linear-gradient(#000 0 0);
mask:
linear-gradient(#000 0 0) content-box,
linear-gradient(#000 0 0);
-webkit-mask-composite: xor;
mask-composite: exclude;
animation: rotate-border 6s linear infinite;
pointer-events: none;
z-index: -1;
}
@property --border-angle {
syntax: "<angle>";
initial-value: 0deg;
inherits: false;
}
@keyframes rotate-border {
to {
--border-angle: 360deg;
}
}


@@ -0,0 +1,36 @@
"use client";
import { Text } from "@/components/atoms/Text/Text";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { useState } from "react";
import type { FleetSummary, AgentStatusFilter } from "../../types";
import { BriefingTabContent } from "./BriefingTabContent";
import { StatsGrid } from "./StatsGrid";
import styles from "./AgentBriefingPanel.module.css";
interface Props {
summary: FleetSummary;
agents: LibraryAgent[];
}
export function AgentBriefingPanel({ summary, agents }: Props) {
const [userTab, setUserTab] = useState<AgentStatusFilter | null>(null);
const activeTab: AgentStatusFilter =
userTab ?? (summary.running > 0 ? "running" : "all");
return (
<div
className={`${styles.glassPanel} min-h-[14.75rem] rounded-large bg-gradient-to-br from-indigo-50/30 via-white/90 to-purple-50/25 px-5 pb-5 pt-[1.125rem] shadow-sm backdrop-blur-md`}
>
<Text variant="h5">Agent Briefing</Text>
<div className="mt-4 space-y-5">
<StatsGrid
summary={summary}
activeTab={activeTab}
onTabChange={setUserTab}
/>
<BriefingTabContent activeTab={activeTab} agents={agents} />
</div>
</div>
);
}


@@ -0,0 +1,361 @@
"use client";
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat";
import {
formatResetTime,
formatCents,
} from "@/app/(platform)/copilot/components/usageHelpers";
import { useResetRateLimit } from "@/app/(platform)/copilot/hooks/useResetRateLimit";
import { Button } from "@/components/atoms/Button/Button";
import { Badge } from "@/components/atoms/Badge/Badge";
import useCredits from "@/hooks/useCredits";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { useSitrepItems } from "../SitrepItem/useSitrepItems";
import { SitrepItem } from "../SitrepItem/SitrepItem";
import { useAgentStatusMap } from "../../hooks/useAgentStatus";
import type { AgentStatusFilter } from "../../types";
import { Text } from "@/components/atoms/Text/Text";
import Link from "next/link";
import { useState } from "react";
interface Props {
activeTab: AgentStatusFilter;
agents: LibraryAgent[];
}
export function BriefingTabContent({ activeTab, agents }: Props) {
if (activeTab === "all") {
return <UsageSection />;
}
if (
activeTab === "running" ||
activeTab === "attention" ||
activeTab === "completed"
) {
return <ExecutionListSection activeTab={activeTab} agents={agents} />;
}
return <AgentListSection activeTab={activeTab} agents={agents} />;
}
function UsageSection() {
const { data: usage } = useGetV2GetCopilotUsage({
query: {
select: (res) => res.data as CoPilotUsageStatus,
refetchInterval: 30000,
staleTime: 10000,
},
});
const isBillingEnabled = useGetFlag(Flag.ENABLE_PLATFORM_PAYMENT);
const { credits, fetchCredits } = useCredits({ fetchInitialCredits: true });
const resetCost = usage?.reset_cost;
const hasInsufficientCredits =
credits !== null && resetCost != null && credits < resetCost;
if (!usage?.daily || !usage?.weekly) return null;
return (
<div className="py-2">
<div className="flex items-center gap-2">
<Text variant="h5" className="text-neutral-800">
Usage limits
</Text>
{usage.tier && (
<Badge variant="info" size="small" className="bg-[rgb(224,237,255)]">
{usage.tier.charAt(0) + usage.tier.slice(1).toLowerCase()} plan
</Badge>
)}
<div className="flex-1" />
{isBillingEnabled && (
<Link
href="/profile/credits"
className="text-sm text-blue-600 hover:underline"
>
Manage billing
</Link>
)}
</div>
<div className="mt-4 grid grid-cols-1 gap-6 sm:grid-cols-2">
{usage.daily.limit > 0 && (
<UsageMeter
label="Today"
used={usage.daily.used}
limit={usage.daily.limit}
resetsAt={usage.daily.resets_at}
/>
)}
{usage.weekly.limit > 0 && (
<UsageMeter
label="This week"
used={usage.weekly.used}
limit={usage.weekly.limit}
resetsAt={usage.weekly.resets_at}
/>
)}
</div>
<UsageFooter
usage={usage}
hasInsufficientCredits={hasInsufficientCredits}
onCreditChange={fetchCredits}
/>
</div>
);
}
const MAX_VISIBLE = 6;
function ExecutionListSection({
activeTab,
agents,
}: {
activeTab: AgentStatusFilter;
agents: LibraryAgent[];
}) {
const allItems = useSitrepItems(agents, 50);
const [showAll, setShowAll] = useState(false);
const filtered = allItems.filter((item) => {
if (activeTab === "running") return item.priority === "running";
if (activeTab === "attention") return item.priority === "error";
if (activeTab === "completed") return item.priority === "success";
return false;
});
if (filtered.length === 0) {
return <EmptyMessage tab={activeTab} />;
}
const visible = showAll ? filtered : filtered.slice(0, MAX_VISIBLE);
const hasMore = filtered.length > MAX_VISIBLE;
return (
<div>
<div className="grid grid-cols-1 gap-3 lg:grid-cols-2">
{visible.map((item) => (
<SitrepItem key={item.id} item={item} />
))}
</div>
{hasMore && (
<div className="mt-3 flex justify-center">
<Button
variant="secondary"
size="small"
onClick={() => setShowAll(!showAll)}
>
{showAll ? "Collapse" : `Show all (${filtered.length})`}
</Button>
</div>
)}
</div>
);
}
const TAB_STATUS_LABEL: Record<string, string> = {
listening: "Waiting for trigger event",
scheduled: "Has a scheduled run",
idle: "No recent activity",
};
function getAgentStatusLabel(tab: string, agent: LibraryAgent): string {
if (tab === "scheduled" && agent.next_scheduled_run) {
const diff = new Date(agent.next_scheduled_run).getTime() - Date.now();
const minutes = Math.round(diff / 60_000);
if (minutes <= 0) return "Scheduled to run soon";
if (minutes < 60) return `Scheduled to run in ${minutes}m`;
const hours = Math.round(minutes / 60);
if (hours < 24) return `Scheduled to run in ${hours}h`;
const days = Math.round(hours / 24);
return `Scheduled to run in ${days}d`;
}
return TAB_STATUS_LABEL[tab] ?? "";
}
function AgentListSection({
activeTab,
agents,
}: {
activeTab: AgentStatusFilter;
agents: LibraryAgent[];
}) {
const [showAll, setShowAll] = useState(false);
const statusMap = useAgentStatusMap(agents);
const filtered = agents.filter((agent) => {
const status = statusMap.get(agent.graph_id)?.status;
if (activeTab === "listening") return status === "listening";
if (activeTab === "scheduled") return status === "scheduled";
if (activeTab === "idle") return status === "idle";
return false;
});
if (filtered.length === 0) {
return <EmptyMessage tab={activeTab} />;
}
const status =
activeTab === "listening"
? ("listening" as const)
: activeTab === "scheduled"
? ("scheduled" as const)
: ("idle" as const);
const visible = showAll ? filtered : filtered.slice(0, MAX_VISIBLE);
const hasMore = filtered.length > MAX_VISIBLE;
return (
<div>
<div className="grid grid-cols-1 gap-3 lg:grid-cols-2">
{visible.map((agent) => (
<SitrepItem
key={agent.id}
item={{
id: agent.id,
agentID: agent.id,
agentName: agent.name,
agentImageUrl: agent.image_url,
priority: status,
message: getAgentStatusLabel(activeTab, agent),
status,
}}
/>
))}
</div>
{hasMore && (
<div className="mt-3 flex justify-center">
<Button
variant="secondary"
size="small"
onClick={() => setShowAll(!showAll)}
>
{showAll ? "Collapse" : `Show all (${filtered.length})`}
</Button>
</div>
)}
</div>
);
}
function UsageFooter({
usage,
hasInsufficientCredits,
onCreditChange,
}: {
usage: CoPilotUsageStatus;
hasInsufficientCredits: boolean;
onCreditChange?: () => void;
}) {
const isDailyExhausted =
usage.daily.limit > 0 && usage.daily.used >= usage.daily.limit;
const isWeeklyExhausted =
usage.weekly.limit > 0 && usage.weekly.used >= usage.weekly.limit;
const resetCost = usage.reset_cost ?? 0;
const { resetUsage, isPending } = useResetRateLimit({ onCreditChange });
const showReset =
isDailyExhausted &&
!isWeeklyExhausted &&
resetCost > 0 &&
!hasInsufficientCredits;
const showAddCredits =
isDailyExhausted && !isWeeklyExhausted && hasInsufficientCredits;
if (!showReset && !showAddCredits) return null;
return (
<div className="mt-4 flex items-center gap-3">
{showReset && (
<Button
variant="primary"
size="small"
onClick={() => resetUsage()}
loading={isPending}
>
{isPending
? "Resetting..."
: `Reset daily limit for ${formatCents(resetCost)}`}
</Button>
)}
{showAddCredits && (
<Link
href="/profile/credits"
className="inline-flex items-center justify-center rounded-md bg-primary px-3 py-1.5 text-sm font-medium text-primary-foreground hover:bg-primary/90"
>
Add credits to reset
</Link>
)}
</div>
);
}
function UsageMeter({
label,
used,
limit,
resetsAt,
}: {
label: string;
used: number;
limit: number;
resetsAt: Date | string;
}) {
if (limit <= 0) return null;
const rawPercent = (used / limit) * 100;
const percent = Math.min(100, Math.round(rawPercent));
const isHigh = percent >= 80;
const percentLabel =
used > 0 && percent === 0 ? "<1% used" : `${percent}% used`;
return (
<div className="flex flex-col gap-2">
<div className="flex items-baseline justify-between">
<Text variant="body-medium" className="text-neutral-700">
{label}
</Text>
<Text variant="body" className="tabular-nums text-neutral-500">
{percentLabel}
</Text>
</div>
<div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
<div
className={`h-full rounded-full transition-[width] duration-300 ease-out ${
isHigh ? "bg-orange-500" : "bg-blue-500"
}`}
style={{ width: `${Math.max(used > 0 ? 1 : 0, percent)}%` }}
/>
</div>
<div className="flex items-baseline justify-between">
<Text variant="small" className="tabular-nums text-neutral-500">
{used.toLocaleString()} / {limit.toLocaleString()}
</Text>
<Text variant="small" className="text-neutral-400">
Resets {formatResetTime(resetsAt)}
</Text>
</div>
</div>
);
}
const EMPTY_MESSAGES: Record<string, string> = {
running: "No agents running right now",
attention: "No agents that need attention",
completed: "No recently completed runs",
listening: "No agents listening for events",
scheduled: "No agents with scheduled runs",
idle: "No idle agents",
};
function EmptyMessage({ tab }: { tab: AgentStatusFilter }) {
return (
<div className="flex items-center justify-center pt-4">
<Text variant="body-medium" className="text-zinc-600">
{EMPTY_MESSAGES[tab] ?? "No agents in this category"}
</Text>
</div>
);
}


@@ -0,0 +1,102 @@
"use client";
import { Text } from "@/components/atoms/Text/Text";
import { OverflowText } from "@/components/atoms/OverflowText/OverflowText";
import { Emoji } from "@/components/atoms/Emoji/Emoji";
import { cn } from "@/lib/utils";
import type { FleetSummary, AgentStatusFilter } from "../../types";
interface Props {
summary: FleetSummary;
activeTab: AgentStatusFilter;
onTabChange: (tab: AgentStatusFilter) => void;
}
const TILES: {
label: string;
key: keyof FleetSummary;
format?: (v: number) => string;
filter: AgentStatusFilter;
emoji: string;
color: string;
}[] = [
{
label: "Spent this month",
key: "monthlySpend",
format: (v) => `$${v.toLocaleString()}`,
filter: "all",
emoji: "💵",
color: "text-zinc-700",
},
{
label: "Running now",
key: "running",
filter: "running",
emoji: "🚩",
color: "text-blue-600",
},
{
label: "Recently completed",
key: "completed",
filter: "completed",
emoji: "🗃️",
color: "text-green-600",
},
{
label: "Needs attention",
key: "error",
filter: "attention",
emoji: "⚠️",
color: "text-red-500",
},
{
label: "Scheduled",
key: "scheduled",
filter: "scheduled",
emoji: "📅",
color: "text-yellow-600",
},
{
label: "Idle",
key: "idle",
filter: "idle",
emoji: "💤",
color: "text-zinc-400",
},
];
export function StatsGrid({ summary, activeTab, onTabChange }: Props) {
return (
<div className="grid grid-cols-1 gap-3 min-[450px]:grid-cols-2 sm:grid-cols-3 lg:grid-cols-6">
{TILES.map((tile) => {
const rawValue = summary[tile.key];
const value = tile.format ? tile.format(rawValue) : rawValue;
const isActive = activeTab === tile.filter;
return (
<button
key={tile.label}
type="button"
onClick={() => onTabChange(tile.filter)}
className={cn(
"flex min-w-0 flex-col gap-1 rounded-medium border p-3 text-left shadow-md transition-all hover:shadow-lg",
isActive
? "border-zinc-900 bg-zinc-50"
: "border-zinc-100 bg-white",
)}
>
<div className="flex min-w-0 items-center gap-1.5">
<Emoji text={tile.emoji} size={18} />
<OverflowText
value={tile.label}
variant="body"
className="text-zinc-800"
/>
</div>
<Text variant="h4">{value}</Text>
</button>
);
})}
</div>
);
}

View File

@@ -0,0 +1,52 @@
"use client";
import type { SelectOption } from "@/components/atoms/Select/Select";
import { Select } from "@/components/atoms/Select/Select";
import { FunnelIcon } from "@phosphor-icons/react";
import type { AgentStatusFilter, FleetSummary } from "../../types";
interface Props {
value: AgentStatusFilter;
onChange: (value: AgentStatusFilter) => void;
summary: FleetSummary;
}
function buildOptions(summary: FleetSummary): SelectOption[] {
return [
{ value: "all", label: "All Agents" },
{ value: "running", label: `Running (${summary.running})` },
{ value: "attention", label: `Needs Attention (${summary.error})` },
{ value: "listening", label: `Listening (${summary.listening})` },
{ value: "scheduled", label: `Scheduled (${summary.scheduled})` },
{ value: "idle", label: `Idle / Stale (${summary.idle})` },
{ value: "healthy", label: "Healthy" },
];
}
export function AgentFilterMenu({ value, onChange, summary }: Props) {
function handleChange(val: string) {
onChange(val as AgentStatusFilter);
}
const options = buildOptions(summary);
return (
<div className="flex items-center" data-testid="agent-filter-dropdown">
<span className="hidden whitespace-nowrap text-sm text-zinc-500 sm:inline">
filter
</span>
<FunnelIcon className="ml-1 h-4 w-4 sm:hidden" />
<Select
id="agent-status-filter"
label="Filter agents"
hideLabel
value={value}
onValueChange={handleChange}
options={options}
size="small"
className="ml-1 w-fit border-none !bg-transparent text-sm underline underline-offset-4 shadow-none"
wrapperClassName="mb-0"
/>
</div>
);
}

View File

@@ -0,0 +1,66 @@
"use client";
import {
EyeIcon,
ArrowsClockwiseIcon,
MonitorPlayIcon,
} from "@phosphor-icons/react";
import { cn } from "@/lib/utils";
import { useRouter } from "next/navigation";
import type { AgentStatus } from "../../types";
interface Props {
status: AgentStatus;
agentID: string;
executionID?: string;
className?: string;
}
export function ContextualActionButton({
status,
agentID,
executionID,
className,
}: Props) {
const router = useRouter();
const config = ACTION_CONFIG[status];
if (!config) return null;
const Icon = config.icon;
function handleClick(e: React.MouseEvent) {
e.preventDefault();
e.stopPropagation();
const params = new URLSearchParams();
if (executionID) params.set("activeItem", executionID);
const query = params.toString();
router.push(`/library/agents/${agentID}${query ? `?${query}` : ""}`);
}
return (
<button
type="button"
onClick={handleClick}
className={cn(
"inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800",
className,
)}
>
<Icon size={12} className="shrink-0" />
{config.label}
</button>
);
}
const ACTION_CONFIG: Record<
AgentStatus,
{ label: string; icon: typeof EyeIcon }
> = {
error: { label: "View error", icon: EyeIcon },
listening: { label: "Reconnect", icon: ArrowsClockwiseIcon },
running: { label: "Watch live", icon: MonitorPlayIcon },
idle: { label: "View", icon: EyeIcon },
scheduled: { label: "View", icon: EyeIcon },
};

View File

@@ -1,46 +0,0 @@
"use client";
import { ArrowRight, Lightning } from "@phosphor-icons/react";
import NextLink from "next/link";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { useJumpBackIn } from "./useJumpBackIn";
export function JumpBackIn() {
const { execution, isLoading } = useJumpBackIn();
if (isLoading || !execution) {
return null;
}
const href = execution.libraryAgentId
? `/library/agents/${execution.libraryAgentId}?activeTab=runs&activeItem=${execution.id}`
: "#";
return (
<div className="rounded-large bg-gradient-to-r from-zinc-200 via-zinc-200/60 to-indigo-200/50 p-[1px]">
<div className="flex items-center justify-between rounded-large bg-[#F6F7F8] px-5 py-4">
<div className="flex items-center gap-3">
<div className="flex h-9 w-9 items-center justify-center rounded-full bg-zinc-900">
<Lightning size={18} weight="fill" className="text-white" />
</div>
<div className="flex flex-col">
<Text variant="small" className="text-zinc-500">
{execution.statusLabel} · {execution.duration}
</Text>
<Text variant="body-medium" className="text-zinc-900">
{execution.agentName}
</Text>
</div>
</div>
<NextLink href={href}>
<Button variant="secondary" size="small" className="gap-1.5">
Jump Back In
<ArrowRight size={16} />
</Button>
</NextLink>
</div>
</div>
);
}

View File

@@ -1,82 +0,0 @@
"use client";
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import { okData } from "@/app/api/helpers";
import { useLibraryAgents } from "@/hooks/useLibraryAgents/useLibraryAgents";
import { useMemo } from "react";
function isActive(status: AgentExecutionStatus) {
return (
status === AgentExecutionStatus.RUNNING ||
status === AgentExecutionStatus.QUEUED ||
status === AgentExecutionStatus.REVIEW
);
}
function formatDuration(startedAt: Date | string | null | undefined): string {
if (!startedAt) return "";
const start = new Date(startedAt);
  if (Number.isNaN(start.getTime())) return "";
const ms = Date.now() - start.getTime();
if (ms < 0) return "";
const sec = Math.floor(ms / 1000);
if (sec < 5) return "a few seconds";
if (sec < 60) return `${sec}s`;
const min = Math.floor(sec / 60);
if (min < 60) return `${min}m ${sec % 60}s`;
const hr = Math.floor(min / 60);
return `${hr}h ${min % 60}m`;
}
function getStatusLabel(status: AgentExecutionStatus) {
if (status === AgentExecutionStatus.RUNNING) return "Running";
if (status === AgentExecutionStatus.QUEUED) return "Queued";
if (status === AgentExecutionStatus.REVIEW) return "Awaiting approval";
return "";
}
export function useJumpBackIn() {
const { data: executions, isLoading: executionsLoading } =
useGetV1ListAllExecutions({
query: { select: okData },
});
const { agentInfoMap, isRefreshing: agentsLoading } = useLibraryAgents();
const activeExecution = useMemo(() => {
if (!executions) return null;
const active = executions
.filter((e) => isActive(e.status))
.sort((a, b) => {
const aTime = a.started_at ? new Date(a.started_at).getTime() : 0;
const bTime = b.started_at ? new Date(b.started_at).getTime() : 0;
return bTime - aTime;
});
return active[0] ?? null;
}, [executions]);
const enriched = useMemo(() => {
if (!activeExecution) return null;
const info = agentInfoMap.get(activeExecution.graph_id);
return {
id: activeExecution.id,
agentName: info?.name ?? "Unknown Agent",
libraryAgentId: info?.library_agent_id,
status: activeExecution.status,
statusLabel: getStatusLabel(activeExecution.status),
duration: formatDuration(activeExecution.started_at),
};
}, [activeExecution, agentInfoMap]);
return {
execution: enriched,
isLoading: executionsLoading || agentsLoading,
};
}

View File

@@ -8,7 +8,7 @@ interface Props {
export function LibraryActionHeader({ setSearchTerm }: Props) {
return (
<>
<div className="mb-[32px] hidden items-center justify-center gap-4 md:flex">
<div className="mb-7 hidden items-center justify-center gap-4 md:flex">
<LibrarySearchBar setSearchTerm={setSearchTerm} />
<LibraryImportDialog />
</div>

View File

@@ -1,29 +1,40 @@
"use client";
import { Text } from "@/components/atoms/Text/Text";
import { CaretCircleRightIcon } from "@phosphor-icons/react";
import { EyeIcon, ChatCircleDotsIcon } from "@phosphor-icons/react";
import Image from "next/image";
import NextLink from "next/link";
import { useRouter } from "next/navigation";
import { motion } from "framer-motion";
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import Avatar, {
AvatarFallback,
AvatarImage,
} from "@/components/atoms/Avatar/Avatar";
import { Link } from "@/components/atoms/Link/Link";
import { cn } from "@/lib/utils";
import { AgentCardMenu } from "./components/AgentCardMenu";
import { FavoriteButton } from "./components/FavoriteButton";
import { useLibraryAgentCard } from "./useLibraryAgentCard";
import { useFavoriteAnimation } from "../../context/FavoriteAnimationContext";
import { StatusBadge } from "../StatusBadge/StatusBadge";
import { ContextualActionButton } from "../ContextualActionButton/ContextualActionButton";
import type { AgentStatusInfo } from "../../types";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
interface Props {
agent: LibraryAgent;
statusInfo: AgentStatusInfo;
draggable?: boolean;
}
export function LibraryAgentCard({ agent, draggable = true }: Props) {
const { id, name, graph_id, can_access_graph, image_url } = agent;
export function LibraryAgentCard({
agent,
statusInfo,
draggable = true,
}: Props) {
const { id, name, image_url } = agent;
const router = useRouter();
const { triggerFavoriteAnimation } = useFavoriteAnimation();
function handleDragStart(e: React.DragEvent<HTMLDivElement>) {
@@ -31,18 +42,14 @@ export function LibraryAgentCard({ agent, draggable = true }: Props) {
e.dataTransfer.effectAllowed = "move";
}
const {
isFromMarketplace,
isFavorite,
profile,
creator_image_url,
handleToggleFavorite,
} = useLibraryAgentCard({
const { isFavorite, handleToggleFavorite } = useLibraryAgentCard({
agent,
onFavoriteAdd: triggerFavoriteAnimation,
});
return (
const hasError = statusInfo.status === "error";
const card = (
<div
draggable={draggable}
onDragStart={handleDragStart}
@@ -52,7 +59,10 @@ export function LibraryAgentCard({ agent, draggable = true }: Props) {
layoutId={`agent-card-${id}`}
data-testid="library-agent-card"
data-agent-id={id}
className="group relative inline-flex h-[10.625rem] w-full max-w-[25rem] flex-col items-start justify-start gap-2.5 rounded-medium border border-zinc-100 bg-white hover:shadow-md"
className={cn(
"group relative inline-flex h-auto min-h-[10.625rem] w-full max-w-[25rem] flex-col items-start justify-start gap-2.5 rounded-medium border bg-white hover:shadow-md",
hasError ? "border-red-400" : "border-zinc-100",
)}
transition={{
type: "spring",
damping: 25,
@@ -61,23 +71,10 @@ export function LibraryAgentCard({ agent, draggable = true }: Props) {
style={{ willChange: "transform" }}
>
<NextLink href={`/library/agents/${id}`} className="flex-shrink-0">
<div className="relative flex items-center gap-2 px-4 pt-3">
<Avatar className="h-4 w-4 rounded-full">
<AvatarImage
src={
isFromMarketplace
? creator_image_url || "/avatar-placeholder.png"
: profile?.avatar_url || "/avatar-placeholder.png"
}
alt={`${name} creator avatar`}
/>
<AvatarFallback size={48}>{name.charAt(0)}</AvatarFallback>
</Avatar>
<Text
variant="small-medium"
className="uppercase tracking-wide text-zinc-400"
>
{isFromMarketplace ? "FROM MARKETPLACE" : "Built by you"}
<div className="relative flex items-center gap-3 pl-2 pr-4 pt-3">
<StatusBadge status={statusInfo.status} />
<Text variant="small" className="text-zinc-400">
{statusInfo.totalRuns} tasks
</Text>
</div>
</NextLink>
@@ -89,7 +86,7 @@ export function LibraryAgentCard({ agent, draggable = true }: Props) {
<AgentCardMenu agent={agent} />
<div className="flex w-full flex-1 flex-col px-4 pb-2">
<Link
<NextLink
href={`/library/agents/${id}`}
className="flex w-full items-start justify-between gap-2 no-underline hover:no-underline focus:ring-0"
>
@@ -126,30 +123,52 @@ export function LibraryAgentCard({ agent, draggable = true }: Props) {
className="flex-shrink-0 rounded-small object-cover"
/>
)}
</Link>
</NextLink>
<div className="mt-auto flex w-full justify-start gap-6 border-t border-zinc-100 pb-1 pt-3">
<Link
href={`/library/agents/${id}`}
<div className="mt-4 flex w-full items-center justify-end gap-1 border-t border-zinc-100 pb-0 pt-2">
<button
type="button"
onClick={() => router.push(`/library/agents/${id}`)}
data-testid="library-agent-card-see-runs-link"
className="flex items-center gap-1 text-[13px]"
className="inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800"
>
See runs <CaretCircleRightIcon size={20} />
</Link>
{can_access_graph && (
<Link
href={`/build?flowID=${graph_id}`}
data-testid="library-agent-card-open-in-builder-link"
className="flex items-center gap-1 text-[13px]"
isExternal
>
Open in builder <CaretCircleRightIcon size={20} />
</Link>
)}
<EyeIcon size={14} className="shrink-0" />
See tasks
</button>
<ContextualActionButton
status={statusInfo.status}
agentID={id}
executionID={statusInfo.activeExecutionID ?? undefined}
/>
<button
type="button"
onClick={() => {
                  const prompt = encodeURIComponent(
                    `Tell me about ${name}, its current status, recent runs, and how I can get the most out of it`,
                  );
router.push(`/copilot?autosubmit=true#prompt=${prompt}`);
}}
className="inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800"
>
<ChatCircleDotsIcon size={14} className="shrink-0" />
Chat
</button>
</div>
</div>
</motion.div>
</div>
);
if (hasError && statusInfo.lastError) {
return (
<Tooltip>
<TooltipTrigger asChild>{card}</TooltipTrigger>
<TooltipContent className="max-w-xs text-red-600">
{statusInfo.lastError}
</TooltipContent>
</Tooltip>
);
}
return card;
}

View File

@@ -169,6 +169,7 @@ export function AgentCardMenu({ agent }: AgentCardMenuProps) {
href={`/build?flowID=${agent.graph_id}&flowVersion=${agent.graph_version}`}
target="_blank"
className="flex items-center gap-2"
data-testid="library-agent-card-open-in-builder-link"
onClick={(e) => e.stopPropagation()}
>
Edit agent

View File

@@ -1,6 +1,7 @@
"use client";
import { LibraryAgentSort } from "@/app/api/__generated__/models/libraryAgentSort";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { InfiniteScroll } from "@/components/contextual/InfiniteScroll/InfiniteScroll";
import { LibraryAgentCard } from "../LibraryAgentCard/LibraryAgentCard";
@@ -16,8 +17,11 @@ import {
} from "framer-motion";
import { LibraryFolderEditDialog } from "../LibraryFolderEditDialog/LibraryFolderEditDialog";
import { LibraryFolderDeleteDialog } from "../LibraryFolderDeleteDialog/LibraryFolderDeleteDialog";
import { LibraryTab } from "../../types";
import type { LibraryTab, AgentStatusFilter, FleetSummary } from "../../types";
import { useLibraryAgentList } from "./useLibraryAgentList";
import { AgentBriefingPanel } from "../AgentBriefingPanel/AgentBriefingPanel";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { useAgentStatusMap, getAgentStatus } from "../../hooks/useAgentStatus";
// cancels the current spring and starts a new one from current state.
const containerVariants = {
@@ -70,6 +74,10 @@ interface Props {
tabs: LibraryTab[];
activeTab: string;
onTabChange: (tabId: string) => void;
statusFilter?: AgentStatusFilter;
onStatusFilterChange?: (filter: AgentStatusFilter) => void;
fleetSummary?: FleetSummary;
briefingAgents?: LibraryAgent[];
}
export function LibraryAgentList({
@@ -81,7 +89,12 @@ export function LibraryAgentList({
tabs,
activeTab,
onTabChange,
statusFilter = "all",
onStatusFilterChange,
fleetSummary,
briefingAgents,
}: Props) {
const isAgentBriefingEnabled = useGetFlag(Flag.AGENT_BRIEFING);
const shouldReduceMotion = useReducedMotion();
const activeContainerVariants = shouldReduceMotion
? reducedContainerVariants
@@ -95,7 +108,7 @@ export function LibraryAgentList({
const {
isFavoritesTab,
agentLoading,
allAgentsCount,
displayedCount,
favoritesCount,
agents,
hasNextPage,
@@ -116,18 +129,37 @@ export function LibraryAgentList({
selectedFolderId,
onFolderSelect,
activeTab,
statusFilter,
});
const agentStatusMap = useAgentStatusMap(agents);
return (
<>
{isAgentBriefingEnabled &&
!selectedFolderId &&
fleetSummary &&
briefingAgents &&
briefingAgents.length > 0 && (
<div className="mb-4">
<AgentBriefingPanel
summary={fleetSummary}
agents={briefingAgents}
/>
</div>
)}
{!selectedFolderId && (
<LibrarySubSection
tabs={tabs}
activeTab={activeTab}
onTabChange={onTabChange}
allCount={allAgentsCount}
allCount={displayedCount}
favoritesCount={favoritesCount}
setLibrarySort={setLibrarySort}
statusFilter={statusFilter}
onStatusFilterChange={onStatusFilterChange}
fleetSummary={fleetSummary}
/>
)}
@@ -219,7 +251,13 @@ export function LibraryAgentList({
0.04,
}}
>
<LibraryAgentCard agent={agent} />
<LibraryAgentCard
agent={agent}
statusInfo={getAgentStatus(
agentStatusMap,
agent.graph_id,
)}
/>
</motion.div>
))}
</motion.div>

View File

@@ -21,7 +21,12 @@ import { useToast } from "@/components/molecules/Toast/use-toast";
import { useFavoriteAgents } from "../../hooks/useFavoriteAgents";
import { getQueryClient } from "@/lib/react-query/queryClient";
import { useQueryClient } from "@tanstack/react-query";
import { useEffect, useRef, useState } from "react";
import { useEffect, useMemo, useRef, useState } from "react";
import type { AgentStatusFilter } from "../../types";
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
const FILTER_EXHAUST_THRESHOLD = 3;
interface Props {
searchTerm: string;
@@ -29,6 +34,7 @@ interface Props {
selectedFolderId: string | null;
onFolderSelect: (folderId: string | null) => void;
activeTab: string;
statusFilter?: AgentStatusFilter;
}
export function useLibraryAgentList({
@@ -37,12 +43,16 @@ export function useLibraryAgentList({
selectedFolderId,
onFolderSelect,
activeTab,
statusFilter = "all",
}: Props) {
const isFavoritesTab = activeTab === "favorites";
const { toast } = useToast();
const stableQueryClient = getQueryClient();
const queryClient = useQueryClient();
const prevSortRef = useRef<LibraryAgentSort | null>(null);
const [consecutiveEmptyPages, setConsecutiveEmptyPages] = useState(0);
const prevFilteredLengthRef = useRef(0);
const prevAgentsLengthRef = useRef(0);
const [editingFolder, setEditingFolder] = useState<LibraryFolder | null>(
null,
@@ -199,6 +209,90 @@ export function useLibraryAgentList({
const showFolders = !isFavoritesTab;
const { data: executions } = useGetV1ListAllExecutions({
query: { select: okData },
});
const { activeGraphIds, errorGraphIds, completedGraphIds } = useMemo(() => {
const active = new Set<string>();
const errors = new Set<string>();
const completed = new Set<string>();
const cutoff = Date.now() - 72 * 60 * 60 * 1000;
for (const exec of executions ?? []) {
if (
exec.status === AgentExecutionStatus.RUNNING ||
exec.status === AgentExecutionStatus.QUEUED ||
exec.status === AgentExecutionStatus.REVIEW
) {
active.add(exec.graph_id);
}
      const endedTs = exec.ended_at
        ? exec.ended_at instanceof Date
          ? exec.ended_at.getTime()
          : new Date(exec.ended_at).getTime()
        : 0;
if (
(exec.status === AgentExecutionStatus.FAILED ||
exec.status === AgentExecutionStatus.TERMINATED) &&
endedTs > cutoff
) {
errors.add(exec.graph_id);
}
if (exec.status === AgentExecutionStatus.COMPLETED && endedTs > cutoff) {
completed.add(exec.graph_id);
}
}
return {
activeGraphIds: active,
errorGraphIds: errors,
completedGraphIds: completed,
};
}, [executions]);
const filteredAgents = filterAgentsByStatus(
agents,
statusFilter,
activeGraphIds,
errorGraphIds,
completedGraphIds,
);
useEffect(() => {
if (statusFilter === "all") {
setConsecutiveEmptyPages(0);
prevFilteredLengthRef.current = filteredAgents.length;
prevAgentsLengthRef.current = agents.length;
return;
}
if (agents.length > prevAgentsLengthRef.current) {
const newFilteredCount = filteredAgents.length;
const previousCount = prevFilteredLengthRef.current;
if (newFilteredCount > previousCount) {
setConsecutiveEmptyPages(0);
} else {
setConsecutiveEmptyPages((prev) => prev + 1);
}
}
prevAgentsLengthRef.current = agents.length;
prevFilteredLengthRef.current = filteredAgents.length;
}, [agents.length, filteredAgents.length, statusFilter]);
useEffect(() => {
setConsecutiveEmptyPages(0);
prevFilteredLengthRef.current = 0;
prevAgentsLengthRef.current = 0;
}, [statusFilter]);
const filteredExhausted =
statusFilter !== "all" && consecutiveEmptyPages >= FILTER_EXHAUST_THRESHOLD;
// When a filter is active, show the filtered count instead of the API total.
const displayedCount =
statusFilter === "all" ? allAgentsCount : filteredAgents.length;
function handleFolderDeleted() {
if (selectedFolderId === deletingFolder?.id) {
onFolderSelect(null);
@@ -210,9 +304,10 @@ export function useLibraryAgentList({
agentLoading,
agentCount,
allAgentsCount,
displayedCount,
favoritesCount: favoriteAgentsData.agentCount,
agents,
hasNextPage: agentsHasNextPage,
agents: filteredAgents,
hasNextPage: agentsHasNextPage && !filteredExhausted,
isFetchingNextPage: agentsIsFetchingNextPage,
fetchNextPage: agentsFetchNextPage,
foldersData,
@@ -226,3 +321,46 @@ export function useLibraryAgentList({
handleFolderDeleted,
};
}
function filterAgentsByStatus<
T extends {
graph_id: string;
has_external_trigger: boolean;
recommended_schedule_cron?: string | null;
},
>(
agents: T[],
statusFilter: AgentStatusFilter,
activeGraphIds: Set<string>,
errorGraphIds: Set<string>,
completedGraphIds: Set<string>,
): T[] {
if (statusFilter === "all") return agents;
return agents.filter((agent) => {
const isRunning = activeGraphIds.has(agent.graph_id);
const hasError = errorGraphIds.has(agent.graph_id);
if (statusFilter === "running") return isRunning;
if (statusFilter === "attention") return hasError && !isRunning;
if (statusFilter === "completed")
return completedGraphIds.has(agent.graph_id);
if (statusFilter === "listening")
return !isRunning && !hasError && agent.has_external_trigger;
if (statusFilter === "scheduled")
return (
!isRunning &&
!hasError &&
!agent.has_external_trigger &&
!!agent.recommended_schedule_cron
);
if (statusFilter === "idle")
return (
!isRunning &&
!hasError &&
!agent.has_external_trigger &&
!agent.recommended_schedule_cron
);
if (statusFilter === "healthy") return !hasError;
return true;
});
}

View File

@@ -2,14 +2,11 @@
import { Text } from "@/components/atoms/Text/Text";
import { Button } from "@/components/atoms/Button/Button";
import {
FolderIcon,
FolderColor,
folderCardStyles,
resolveColor,
} from "./FolderIcon";
import { FolderIcon, FolderColor } from "./FolderIcon";
import { useState } from "react";
import { PencilSimpleIcon, TrashIcon } from "@phosphor-icons/react";
import type { AgentStatus } from "../../types";
import { StatusBadge } from "../StatusBadge/StatusBadge";
interface Props {
id: string;
@@ -21,6 +18,8 @@ interface Props {
onDelete?: () => void;
onAgentDrop?: (agentId: string, folderId: string) => void;
onClick?: () => void;
/** Worst status among child agents (optional, for status aggregation). */
worstStatus?: AgentStatus;
}
export function LibraryFolder({
@@ -33,11 +32,10 @@ export function LibraryFolder({
onDelete,
onAgentDrop,
onClick,
worstStatus,
}: Props) {
const [isHovered, setIsHovered] = useState(false);
const [isDragOver, setIsDragOver] = useState(false);
const resolvedColor = resolveColor(color);
const cardStyle = folderCardStyles[resolvedColor];
function handleDragOver(e: React.DragEvent<HTMLDivElement>) {
if (e.dataTransfer.types.includes("application/agent-id")) {
@@ -64,10 +62,10 @@ export function LibraryFolder({
<div
data-testid="library-folder"
data-folder-id={id}
className={`group relative inline-flex h-[10.625rem] w-full max-w-[25rem] cursor-pointer flex-col items-start justify-between gap-2.5 rounded-medium border p-4 transition-all duration-200 hover:shadow-md ${
className={`group relative inline-flex h-[10.625rem] w-full max-w-[25rem] cursor-pointer flex-col items-start justify-between gap-2.5 rounded-medium border p-4 shadow-sm backdrop-blur-md transition-all duration-200 hover:shadow-md ${
isDragOver
? "border-blue-400 bg-blue-50 ring-2 ring-blue-200"
: `${cardStyle.border} ${cardStyle.bg}`
: "border-indigo-200/40 bg-gradient-to-br from-indigo-50/40 via-white/70 to-purple-50/30"
}`}
onMouseEnter={() => setIsHovered(true)}
onMouseLeave={() => setIsHovered(false)}
@@ -76,7 +74,7 @@ export function LibraryFolder({
onDrop={handleDrop}
onClick={onClick}
>
<div className="flex w-full items-start justify-between gap-4">
<div className="flex w-full items-center justify-between gap-4">
{/* Left side - Folder name and agent count */}
<div className="flex flex-1 flex-col gap-2">
<Text
@@ -86,17 +84,22 @@ export function LibraryFolder({
>
{name}
</Text>
<Text
variant="small"
className="text-zinc-500"
data-testid="library-folder-agent-count"
>
{agentCount} {agentCount === 1 ? "agent" : "agents"}
</Text>
<div className="flex items-center gap-2">
<Text
variant="small"
className="text-zinc-500"
data-testid="library-folder-agent-count"
>
{agentCount} {agentCount === 1 ? "agent" : "agents"}
</Text>
{worstStatus && worstStatus !== "idle" && (
<StatusBadge status={worstStatus} />
)}
</div>
</div>
{/* Right side - Custom folder icon */}
<div className="flex-shrink-0">
<div className="relative top-5 flex flex-shrink-0 items-center">
<FolderIcon isOpen={isHovered} color={color} icon={icon} />
</div>
</div>
@@ -114,7 +117,7 @@ export function LibraryFolder({
e.stopPropagation();
onEdit?.();
}}
className={`h-8 w-8 border p-2 ${cardStyle.buttonBase} ${cardStyle.buttonHover}`}
className="h-8 w-8 border border-neutral-200 bg-white/80 p-2 text-neutral-500 hover:bg-white hover:text-neutral-700"
>
<PencilSimpleIcon className="h-4 w-4" />
</Button>
@@ -126,7 +129,7 @@ export function LibraryFolder({
e.stopPropagation();
onDelete?.();
}}
className={`h-8 w-8 border p-2 ${cardStyle.buttonBase} ${cardStyle.buttonHover}`}
className="h-8 w-8 border border-neutral-200 bg-white/80 p-2 text-neutral-500 hover:bg-white hover:text-neutral-700"
>
<TrashIcon className="h-4 w-4" />
</Button>

View File

@@ -19,11 +19,11 @@ export function LibrarySortMenu({ setLibrarySort }: Props) {
const { handleSortChange } = useLibrarySortMenu({ setLibrarySort });
return (
<div className="flex items-center" data-testid="sort-by-dropdown">
<span className="hidden whitespace-nowrap text-sm sm:inline">
<span className="hidden whitespace-nowrap text-sm text-zinc-500 sm:inline">
sort by
</span>
<Select onValueChange={handleSortChange}>
<SelectTrigger className="ml-1 w-fit space-x-1 border-none px-0 text-sm underline underline-offset-4 shadow-none">
<SelectTrigger className="!m-0 ml-1 w-fit space-x-1 border-none !bg-transparent px-[1rem] text-sm underline underline-offset-4 !shadow-none !ring-offset-transparent">
<ArrowDownNarrowWideIcon className="h-4 w-4 sm:hidden" />
<SelectValue placeholder="Last Modified" />
</SelectTrigger>

View File

@@ -6,9 +6,10 @@ import {
} from "@/components/molecules/TabsLine/TabsLine";
import { LibraryAgentSort } from "@/app/api/__generated__/models/libraryAgentSort";
import { useFavoriteAnimation } from "../../context/FavoriteAnimationContext";
import { LibraryTab } from "../../types";
import type { LibraryTab, AgentStatusFilter, FleetSummary } from "../../types";
import LibraryFolderCreationDialog from "../LibraryFolderCreationDialog/LibraryFolderCreationDialog";
import { LibrarySortMenu } from "../LibrarySortMenu/LibrarySortMenu";
import { AgentFilterMenu } from "../AgentFilterMenu/AgentFilterMenu";
interface Props {
tabs: LibraryTab[];
@@ -17,6 +18,9 @@ interface Props {
allCount: number;
favoritesCount: number;
setLibrarySort: (value: LibraryAgentSort) => void;
statusFilter?: AgentStatusFilter;
onStatusFilterChange?: (filter: AgentStatusFilter) => void;
fleetSummary?: FleetSummary;
}
export function LibrarySubSection({
@@ -26,6 +30,9 @@ export function LibrarySubSection({
allCount,
favoritesCount,
setLibrarySort,
statusFilter = "all",
onStatusFilterChange,
fleetSummary,
}: Props) {
const { registerFavoritesTabRef } = useFavoriteAnimation();
const favoritesRef = useRef<HTMLButtonElement>(null);
@@ -68,8 +75,15 @@ export function LibrarySubSection({
))}
</TabsLineList>
</TabsLine>
<div className="hidden items-center gap-6 md:flex">
<div className="relative top-1.5 hidden items-center gap-6 md:flex">
<LibraryFolderCreationDialog />
{fleetSummary && onStatusFilterChange && (
<AgentFilterMenu
value={statusFilter}
onChange={onStatusFilterChange}
summary={fleetSummary}
/>
)}
<LibrarySortMenu setLibrarySort={setLibrarySort} />
</div>
</div>

View File

@@ -0,0 +1,17 @@
.spinner {
aspect-ratio: 1;
border-radius: 50%;
background:
radial-gradient(farthest-side, currentColor 94%, #0000) top/3px 3px
no-repeat,
conic-gradient(#0000 30%, currentColor);
-webkit-mask: radial-gradient(farthest-side, #0000 calc(100% - 3px), #000 0);
mask: radial-gradient(farthest-side, #0000 calc(100% - 3px), #000 0);
animation: spin 1s infinite linear;
}
@keyframes spin {
100% {
transform: rotate(1turn);
}
}

View File

@@ -0,0 +1,175 @@
"use client";
import { OverflowText } from "@/components/atoms/OverflowText/OverflowText";
import { Text } from "@/components/atoms/Text/Text";
import {
WarningCircleIcon,
ClockCountdownIcon,
CheckCircleIcon,
ChatCircleDotsIcon,
EarIcon,
CalendarDotsIcon,
MoonIcon,
EyeIcon,
} from "@phosphor-icons/react";
import NextLink from "next/link";
import { cn } from "@/lib/utils";
import { useRouter } from "next/navigation";
import type { SitrepItemData, SitrepPriority } from "../../types";
import { ContextualActionButton } from "../ContextualActionButton/ContextualActionButton";
import styles from "./SitrepItem.module.css";
interface Props {
item: SitrepItemData;
}
const PRIORITY_CONFIG: Record<
SitrepPriority,
{
icon?: typeof WarningCircleIcon;
color: string;
bg: string;
cssSpinner?: boolean;
}
> = {
error: {
icon: WarningCircleIcon,
color: "text-red-500",
bg: "bg-red-50",
},
running: {
color: "text-zinc-800",
bg: "",
cssSpinner: true,
},
stale: {
icon: ClockCountdownIcon,
color: "text-yellow-600",
bg: "bg-yellow-50",
},
success: {
icon: CheckCircleIcon,
color: "text-green-600",
bg: "bg-green-50",
},
listening: {
icon: EarIcon,
color: "text-purple-500",
bg: "bg-purple-50",
},
scheduled: {
icon: CalendarDotsIcon,
color: "text-yellow-600",
bg: "bg-yellow-50",
},
idle: {
icon: MoonIcon,
color: "text-zinc-400",
bg: "bg-zinc-100",
},
};
export function SitrepItem({ item }: Props) {
const config = PRIORITY_CONFIG[item.priority];
const router = useRouter();
function handleAskAutoPilot() {
const prompt = buildAutoPilotPrompt(item);
const encoded = encodeURIComponent(prompt);
router.push(`/copilot?autosubmit=true#prompt=${encoded}`);
}
return (
<div
className={cn(
"flex flex-col gap-2 rounded-medium border border-zinc-200/50 bg-transparent p-2 sm:flex-row sm:items-center sm:gap-3",
)}
>
<div className="flex min-w-0 flex-1 items-center gap-3">
{item.agentImageUrl ? (
<img
src={item.agentImageUrl}
alt={item.agentName}
className="h-6 w-6 flex-shrink-0 rounded-full object-cover"
/>
) : (
<div
className={cn(
"flex h-6 w-6 flex-shrink-0 items-center justify-center rounded-full",
config.bg,
)}
>
{config.cssSpinner ? (
<div
className={cn(
styles.spinner,
"h-[21px] w-[21px] text-zinc-800",
)}
/>
) : (
config.icon && (
<config.icon size={14} className={config.color} weight="fill" />
)
)}
</div>
)}
<div className="min-w-0 flex-1">
<Text variant="body-medium" className="leading-tight text-zinc-900">
{item.agentName}
</Text>
<OverflowText
value={item.message}
variant="small"
className="leading-tight text-zinc-500"
/>
</div>
</div>
<div className="flex flex-shrink-0 flex-wrap items-center justify-center gap-1.5 sm:flex-nowrap sm:justify-end">
{item.priority === "success" ? (
<NextLink
href={`/library/agents/${item.agentID}${item.executionID ? `?activeItem=${item.executionID}` : ""}`}
className="inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800"
>
<EyeIcon size={14} className="shrink-0" />
See task
</NextLink>
) : (
<ContextualActionButton
status={item.status}
agentID={item.agentID}
executionID={item.executionID}
/>
)}
<button
type="button"
onClick={handleAskAutoPilot}
className="inline-flex items-center gap-1 rounded-md px-2 py-1.5 text-[13px] font-medium text-zinc-600 transition-colors hover:bg-zinc-50 hover:text-zinc-800"
>
<ChatCircleDotsIcon size={14} className="shrink-0" />
Ask AutoPilot
</button>
</div>
</div>
);
}
function buildAutoPilotPrompt(item: SitrepItemData): string {
switch (item.priority) {
case "error":
return `What happened with ${item.agentName}? It says "${item.message}" — can you check the logs and tell me what to fix?`;
case "running":
return `Give me a status update on the ${item.agentName} run — what has it found so far?`;
case "stale":
return `${item.agentName} hasn't run recently. Should I keep it or update and re-run it?`;
case "success":
return `Show me what ${item.agentName} found in its last run — summarize the results and any key takeaways.`;
case "listening":
return `What is ${item.agentName} listening for? Give me a summary of its trigger configuration.`;
case "scheduled":
return `When is ${item.agentName} scheduled to run next?`;
case "idle":
return `${item.agentName} has been idle. Should I keep it or update and re-run it?`;
}
}
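The `handleAskAutoPilot` hand-off above puts the prompt in the URL fragment, so it must round-trip through `encodeURIComponent`. A minimal sketch of that round trip — `buildCopilotUrl` and `readPromptFromHash` are illustrative helpers, not part of the component:

```typescript
// Minimal sketch of the prompt hand-off: the prompt travels in the URL
// fragment, so quotes, spaces, and "?" must survive percent-encoding.
function buildCopilotUrl(prompt: string): string {
  return `/copilot?autosubmit=true#prompt=${encodeURIComponent(prompt)}`;
}

// Hypothetical reader for the receiving page.
function readPromptFromHash(hash: string): string | null {
  const match = hash.match(/^#prompt=(.*)$/);
  return match ? decodeURIComponent(match[1]) : null;
}

const prompt = 'What happened with "Researcher v2"?';
const url = buildCopilotUrl(prompt);
const hash = url.slice(url.indexOf("#"));
console.log(readPromptFromHash(hash)); // What happened with "Researcher v2"?
```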

View File

@@ -0,0 +1,34 @@
"use client";
import { Text } from "@/components/atoms/Text/Text";
import { ClockCounterClockwise } from "@phosphor-icons/react";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { useSitrepItems } from "./useSitrepItems";
import { SitrepItem } from "./SitrepItem";
interface Props {
agents: LibraryAgent[];
maxItems?: number;
}
export function SitrepList({ agents, maxItems = 10 }: Props) {
const items = useSitrepItems(agents, maxItems);
if (items.length === 0) return null;
return (
<div>
<div className="mb-2 flex items-center gap-1.5">
<ClockCounterClockwise size={16} className="text-zinc-700" />
<Text variant="body-medium" className="text-zinc-700">
Recent tasks
</Text>
</div>
<div className="grid grid-cols-1 gap-1 lg:grid-cols-2">
{items.map((item) => (
<SitrepItem key={item.id} item={item} />
))}
</div>
</div>
);
}

View File

@@ -0,0 +1,198 @@
"use client";
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import type { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { okData } from "@/app/api/helpers";
import { useMemo } from "react";
import type { SitrepItemData, SitrepPriority } from "../../types";
import {
isActive,
isFailed,
toEndTime,
endedAfter,
runningMessage,
SEVENTY_TWO_HOURS_MS,
} from "../../hooks/executionHelpers";
export function useSitrepItems(
agents: LibraryAgent[],
maxItems: number,
scheduledWithinMs?: number,
): SitrepItemData[] {
const { data: executions } = useGetV1ListAllExecutions({
query: { select: okData },
});
return useMemo(() => {
if (agents.length === 0) return [];
const graphIdToAgent = new Map(agents.map((a) => [a.graph_id, a]));
const agentExecutions = groupByAgent(executions ?? [], graphIdToAgent);
const items: SitrepItemData[] = [];
const coveredAgentIds = new Set<string>();
for (const [agent, execs] of agentExecutions) {
const item = buildSitrepFromExecutions(agent, execs);
if (item) {
items.push(item);
coveredAgentIds.add(agent.id);
}
}
for (const agent of agents) {
if (coveredAgentIds.has(agent.id)) continue;
const configItem = buildSitrepFromConfig(agent, scheduledWithinMs);
if (configItem) items.push(configItem);
}
const order: Record<SitrepPriority, number> = {
error: 0,
running: 1,
stale: 2,
success: 3,
listening: 4,
scheduled: 5,
idle: 6,
};
items.sort((a, b) => order[a.priority] - order[b.priority]);
return items.slice(0, maxItems);
}, [agents, executions, maxItems, scheduledWithinMs]);
}
function groupByAgent(
executions: GraphExecutionMeta[],
graphIdToAgent: Map<string, LibraryAgent>,
): Map<LibraryAgent, GraphExecutionMeta[]> {
const map = new Map<LibraryAgent, GraphExecutionMeta[]>();
for (const exec of executions) {
const agent = graphIdToAgent.get(exec.graph_id);
if (!agent) continue;
const list = map.get(agent);
if (list) {
list.push(exec);
} else {
map.set(agent, [exec]);
}
}
return map;
}
function buildSitrepFromExecutions(
agent: LibraryAgent,
executions: GraphExecutionMeta[],
): SitrepItemData | null {
const active = executions.find((e) => isActive(e.status));
if (active) {
return {
id: `${agent.id}-${active.id}`,
agentID: agent.id,
agentName: agent.name,
executionID: active.id,
priority: "running",
message:
active.stats?.activity_status ??
runningMessage(active.status, active.started_at),
status: "running",
};
}
const cutoff = Date.now() - SEVENTY_TWO_HOURS_MS;
const recent = executions
.filter((e) => endedAfter(e, cutoff))
.sort((a, b) => toEndTime(b) - toEndTime(a));
const lastFailed = recent.find((e) => isFailed(e.status));
if (lastFailed) {
const errorMsg =
lastFailed.stats?.error ??
lastFailed.stats?.activity_status ??
"Execution failed";
return {
id: `${agent.id}-${lastFailed.id}`,
agentID: agent.id,
agentName: agent.name,
executionID: lastFailed.id,
priority: "error",
message: typeof errorMsg === "string" ? errorMsg : "Execution failed",
status: "error",
};
}
const lastCompleted = recent.find(
(e) => e.status === AgentExecutionStatus.COMPLETED,
);
if (lastCompleted) {
const summary =
lastCompleted.stats?.activity_status ?? "Completed successfully";
return {
id: `${agent.id}-${lastCompleted.id}`,
agentID: agent.id,
agentName: agent.name,
executionID: lastCompleted.id,
priority: "success",
message: typeof summary === "string" ? summary : "Completed successfully",
status: "idle",
};
}
return null;
}
function buildSitrepFromConfig(
agent: LibraryAgent,
scheduledWithinMs?: number,
): SitrepItemData | null {
if (agent.has_external_trigger) {
return {
id: `${agent.id}-listening`,
agentID: agent.id,
agentName: agent.name,
priority: "listening",
message: "Waiting for trigger event",
status: "listening",
};
}
if (agent.is_scheduled || agent.recommended_schedule_cron) {
if (!isNextRunWithin(agent.next_scheduled_run, scheduledWithinMs)) {
return null;
}
return {
id: `${agent.id}-scheduled`,
agentID: agent.id,
agentName: agent.name,
priority: "scheduled",
message: formatNextRun(agent.next_scheduled_run),
status: "scheduled",
};
}
return null;
}
function isNextRunWithin(
iso: string | undefined | null,
windowMs: number | undefined,
): boolean {
if (windowMs === undefined) return true;
if (!iso) return false;
const diff = new Date(iso).getTime() - Date.now();
return diff <= windowMs;
}
function formatNextRun(iso: string | undefined | null): string {
if (!iso) return "Has a scheduled run";
const diff = new Date(iso).getTime() - Date.now();
const minutes = Math.round(diff / 60_000);
if (minutes <= 0) return "Scheduled to run soon";
if (minutes < 60) return `Scheduled to run in ${minutes}m`;
const hours = Math.round(minutes / 60);
if (hours < 24) return `Scheduled to run in ${hours}h`;
const days = Math.round(hours / 24);
return `Scheduled to run in ${days}d`;
}
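The priority ranking the hook above applies can be shown in isolation — a small sketch of how the `order` record drives the sort, using made-up items where only `priority` matters:

```typescript
type SitrepPriority =
  | "error"
  | "running"
  | "stale"
  | "success"
  | "listening"
  | "scheduled"
  | "idle";

// Same ranking as the hook: lower number sorts first.
const order: Record<SitrepPriority, number> = {
  error: 0,
  running: 1,
  stale: 2,
  success: 3,
  listening: 4,
  scheduled: 5,
  idle: 6,
};

// Hypothetical sample items.
const items: { id: string; priority: SitrepPriority }[] = [
  { id: "a", priority: "idle" },
  { id: "b", priority: "error" },
  { id: "c", priority: "running" },
];

const sorted = [...items].sort((a, b) => order[a.priority] - order[b.priority]);
console.log(sorted.map((i) => i.id)); // [ "b", "c", "a" ]
```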

View File

@@ -0,0 +1,84 @@
"use client";
import { cn } from "@/lib/utils";
import type { AgentStatus } from "../../types";
const STATUS_CONFIG: Record<
AgentStatus,
{ label: string; bg: string; text: string; pulse: boolean }
> = {
running: {
label: "Running",
bg: "",
text: "text-blue-600",
pulse: true,
},
error: {
label: "Error",
bg: "",
text: "text-red-500",
pulse: false,
},
listening: {
label: "Listening",
bg: "",
text: "text-purple-500",
pulse: true,
},
scheduled: {
label: "Scheduled",
bg: "",
text: "text-yellow-600",
pulse: false,
},
idle: {
label: "Idle",
bg: "",
text: "text-zinc-500",
pulse: false,
},
};
interface Props {
status: AgentStatus;
className?: string;
}
export function StatusBadge({ status, className }: Props) {
const config = STATUS_CONFIG[status];
return (
<span
className={cn(
"inline-flex items-center gap-1.5 rounded-full px-2 py-0.5 text-xs font-medium",
config.bg,
config.text,
className,
)}
>
<span
className={cn(
"inline-block h-1.5 w-1.5 rounded-full",
config.pulse && "animate-pulse",
statusDotColor(status),
)}
/>
{config.label}
</span>
);
}
function statusDotColor(status: AgentStatus): string {
switch (status) {
case "running":
return "bg-blue-500";
case "error":
return "bg-red-500";
case "listening":
return "bg-purple-500";
case "scheduled":
return "bg-yellow-500";
case "idle":
return "bg-zinc-400";
}
}

View File

@@ -0,0 +1,59 @@
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import type { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
export const SEVENTY_TWO_HOURS_MS = 72 * 60 * 60 * 1000;
export function isActive(status: string): boolean {
return (
status === AgentExecutionStatus.RUNNING ||
status === AgentExecutionStatus.QUEUED ||
status === AgentExecutionStatus.REVIEW
);
}
export function isFailed(status: string): boolean {
return (
status === AgentExecutionStatus.FAILED ||
status === AgentExecutionStatus.TERMINATED
);
}
export function toEndTime(exec: GraphExecutionMeta): number {
if (!exec.ended_at) return 0;
return exec.ended_at instanceof Date
? exec.ended_at.getTime()
: new Date(exec.ended_at).getTime();
}
export function endedAfter(exec: GraphExecutionMeta, cutoff: number): boolean {
if (!exec.ended_at) return false;
return toEndTime(exec) > cutoff;
}
export function runningMessage(
status: string,
startedAt?: string | Date | null,
): string {
if (status === AgentExecutionStatus.QUEUED) return "Queued for execution";
if (status === AgentExecutionStatus.REVIEW) return "Awaiting review";
if (!startedAt) return "Currently executing";
const ms =
Date.now() -
(startedAt instanceof Date
? startedAt.getTime()
: new Date(startedAt).getTime());
return `Running for ${formatRelativeDuration(ms)}`;
}
export function formatRelativeDuration(ms: number): string {
const seconds = Math.floor(ms / 1000);
if (seconds < 60) return "a few seconds";
const minutes = Math.floor(seconds / 60);
if (minutes < 60) return `${minutes}m`;
const hours = Math.floor(minutes / 60);
const remainingMin = minutes % 60;
if (hours < 24)
return remainingMin > 0 ? `${hours}h ${remainingMin}m` : `${hours}h`;
const days = Math.floor(hours / 24);
return `${days}d ${hours % 24}h`;
}
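The duration buckets above can be sanity-checked with a standalone copy of the helper (same logic as `formatRelativeDuration`, no app imports):

```typescript
// Standalone copy of formatRelativeDuration, for illustration only —
// mirrors the helper above line for line.
function formatRelativeDuration(ms: number): string {
  const seconds = Math.floor(ms / 1000);
  if (seconds < 60) return "a few seconds";
  const minutes = Math.floor(seconds / 60);
  if (minutes < 60) return `${minutes}m`;
  const hours = Math.floor(minutes / 60);
  const remainingMin = minutes % 60;
  if (hours < 24)
    return remainingMin > 0 ? `${hours}h ${remainingMin}m` : `${hours}h`;
  const days = Math.floor(hours / 24);
  return `${days}d ${hours % 24}h`;
}

console.log(formatRelativeDuration(30_000)); // a few seconds
console.log(formatRelativeDuration(5 * 60_000)); // 5m
console.log(formatRelativeDuration(90 * 60_000)); // 1h 30m
console.log(formatRelativeDuration(26 * 3_600_000)); // 1d 2h
```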

View File

@@ -0,0 +1,213 @@
"use client";
import { useMemo } from "react";
import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import type { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { okData } from "@/app/api/helpers";
import type {
AgentStatus,
AgentHealth,
AgentStatusInfo,
FleetSummary,
} from "../types";
import {
isActive,
isFailed,
toEndTime,
SEVENTY_TWO_HOURS_MS,
} from "./executionHelpers";
function deriveHealth(
status: AgentStatus,
lastRunAt: string | null,
): AgentHealth {
if (status === "error") return "attention";
if (status === "idle" && lastRunAt) {
const daysSince =
(Date.now() - new Date(lastRunAt).getTime()) / (1000 * 60 * 60 * 24);
if (daysSince > 14) return "stale";
}
return "good";
}
function computeAgentStatus(
agent: LibraryAgent,
agentExecutions: GraphExecutionMeta[],
): AgentStatusInfo {
const activeExec = agentExecutions.find((e) => isActive(e.status));
let status: AgentStatus;
let lastError: string | null = null;
let lastRunAt: string | null = null;
const activeExecutionID = activeExec?.id ?? null;
if (activeExec) {
status = "running";
} else {
const cutoff = Date.now() - SEVENTY_TWO_HOURS_MS;
const recentFailed = agentExecutions.find(
(e) => isFailed(e.status) && e.ended_at && toEndTime(e) > cutoff,
);
if (recentFailed) {
status = "error";
const rawError =
recentFailed.stats?.error ?? recentFailed.stats?.activity_status;
lastError = typeof rawError === "string" ? rawError : "Execution failed";
} else if (agent.has_external_trigger) {
status = "listening";
} else if (agent.is_scheduled || agent.recommended_schedule_cron) {
status = "scheduled";
} else {
status = "idle";
}
}
const completedExecs = agentExecutions.filter((e) => e.ended_at);
if (completedExecs.length > 0) {
const sorted = completedExecs.sort((a, b) => toEndTime(b) - toEndTime(a));
const endedAt = sorted[0].ended_at;
lastRunAt =
endedAt instanceof Date ? endedAt.toISOString() : String(endedAt);
}
const totalRuns = agent.execution_count ?? agentExecutions.length;
return {
status,
health: deriveHealth(status, lastRunAt),
progress: null,
totalRuns,
lastRunAt,
lastError,
activeExecutionID,
monthlySpend: 0,
nextScheduledRun: null,
triggerType: agent.has_external_trigger ? "webhook" : null,
};
}
export function useAgentStatusMap(
agents: LibraryAgent[],
): Map<string, AgentStatusInfo> {
const { data: executions } = useGetV1ListAllExecutions({
query: { select: okData },
});
return useMemo(() => {
const map = new Map<string, AgentStatusInfo>();
const execsByGraph = new Map<string, GraphExecutionMeta[]>();
for (const exec of executions ?? []) {
const list = execsByGraph.get(exec.graph_id);
if (list) {
list.push(exec);
} else {
execsByGraph.set(exec.graph_id, [exec]);
}
}
for (const agent of agents) {
const agentExecs = execsByGraph.get(agent.graph_id) ?? [];
map.set(agent.graph_id, computeAgentStatus(agent, agentExecs));
}
return map;
}, [agents, executions]);
}
const DEFAULT_STATUS: AgentStatusInfo = {
status: "idle",
health: "good",
progress: null,
totalRuns: 0,
lastRunAt: null,
lastError: null,
activeExecutionID: null,
monthlySpend: 0,
nextScheduledRun: null,
triggerType: null,
};
export function getAgentStatus(
statusMap: Map<string, AgentStatusInfo>,
graphID: string,
): AgentStatusInfo {
return statusMap.get(graphID) ?? DEFAULT_STATUS;
}
export function useFleetSummary(agents: LibraryAgent[]): FleetSummary {
const { data: executions } = useGetV1ListAllExecutions({
query: { select: okData },
});
return useMemo(() => {
const counts: FleetSummary = {
running: 0,
error: 0,
completed: 0,
listening: 0,
scheduled: 0,
idle: 0,
monthlySpend: 0,
};
const activeGraphIds = new Set<string>();
const errorGraphIds = new Set<string>();
const completedGraphIds = new Set<string>();
if (executions) {
const cutoff = Date.now() - SEVENTY_TWO_HOURS_MS;
for (const exec of executions) {
if (isActive(exec.status)) {
activeGraphIds.add(exec.graph_id);
}
const endedTs = toEndTime(exec);
if (isFailed(exec.status) && endedTs > cutoff) {
errorGraphIds.add(exec.graph_id);
}
if (
exec.status === AgentExecutionStatus.COMPLETED &&
endedTs > cutoff
) {
completedGraphIds.add(exec.graph_id);
}
}
}
for (const agent of agents) {
if (activeGraphIds.has(agent.graph_id)) {
counts.running += 1;
} else if (errorGraphIds.has(agent.graph_id)) {
counts.error += 1;
} else if (agent.has_external_trigger) {
counts.listening += 1;
} else if (agent.is_scheduled || agent.recommended_schedule_cron) {
counts.scheduled += 1;
} else {
counts.idle += 1;
}
if (completedGraphIds.has(agent.graph_id)) {
counts.completed += 1;
}
}
return counts;
}, [agents, executions]);
}
export { deriveHealth };
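The health derivation above ("error" is always attention-worthy; "idle" goes stale after two weeks without a run) can be exercised with a standalone copy of `deriveHealth`:

```typescript
type AgentStatus = "running" | "error" | "listening" | "scheduled" | "idle";
type AgentHealth = "good" | "attention" | "stale";

// Standalone copy of deriveHealth, for illustration only.
function deriveHealth(
  status: AgentStatus,
  lastRunAt: string | null,
): AgentHealth {
  if (status === "error") return "attention";
  if (status === "idle" && lastRunAt) {
    const daysSince =
      (Date.now() - new Date(lastRunAt).getTime()) / (1000 * 60 * 60 * 24);
    if (daysSince > 14) return "stale";
  }
  return "good";
}

const fifteenDaysAgo = new Date(Date.now() - 15 * 86_400_000).toISOString();
console.log(deriveHealth("error", null)); // attention
console.log(deriveHealth("idle", fifteenDaysAgo)); // stale
console.log(deriveHealth("running", null)); // good
```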

View File

@@ -0,0 +1,116 @@
"use client";
import {
getGetV1ListAllExecutionsQueryKey,
useGetV1ListAllExecutions,
} from "@/app/api/__generated__/endpoints/graphs/graphs";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { okData } from "@/app/api/helpers";
import { useExecutionEvents } from "@/hooks/useExecutionEvents";
import { useQueryClient } from "@tanstack/react-query";
import { useCallback, useMemo } from "react";
import type { FleetSummary } from "../types";
import { isActive, isFailed, SEVENTY_TWO_HOURS_MS } from "./executionHelpers";
function isRecentFailure(
status: string,
endedAt?: string | Date | null,
): boolean {
if (!isFailed(status)) return false;
if (!endedAt) return false;
const ts =
endedAt instanceof Date ? endedAt.getTime() : new Date(endedAt).getTime();
return Date.now() - ts < SEVENTY_TWO_HOURS_MS;
}
function isRecentCompletion(
status: string,
endedAt?: string | Date | null,
): boolean {
if (status !== AgentExecutionStatus.COMPLETED) return false;
if (!endedAt) return false;
const ts =
endedAt instanceof Date ? endedAt.getTime() : new Date(endedAt).getTime();
return Date.now() - ts < SEVENTY_TWO_HOURS_MS;
}
export function useLibraryFleetSummary(
agents: LibraryAgent[],
): FleetSummary | undefined {
const queryClient = useQueryClient();
const { data: executions, isSuccess } = useGetV1ListAllExecutions({
query: { select: okData },
});
const graphIDs = useMemo(() => agents.map((a) => a.graph_id), [agents]);
const handleExecutionUpdate = useCallback(() => {
queryClient.invalidateQueries({
queryKey: getGetV1ListAllExecutionsQueryKey(),
});
}, [queryClient]);
useExecutionEvents({
graphIds: graphIDs.length > 0 ? graphIDs : undefined,
enabled: graphIDs.length > 0,
onExecutionUpdate: handleExecutionUpdate,
});
return useMemo(() => {
if (!isSuccess || !executions) return undefined;
const agentsWithActiveExecution = new Set<string>();
const agentsWithRecentFailure = new Set<string>();
const agentsWithRecentCompletion = new Set<string>();
for (const exec of executions) {
if (isActive(exec.status)) {
agentsWithActiveExecution.add(exec.graph_id);
}
if (isRecentFailure(exec.status, exec.ended_at)) {
agentsWithRecentFailure.add(exec.graph_id);
}
if (isRecentCompletion(exec.status, exec.ended_at)) {
agentsWithRecentCompletion.add(exec.graph_id);
}
}
const summary: FleetSummary = {
running: 0,
error: 0,
completed: 0,
listening: 0,
scheduled: 0,
idle: 0,
monthlySpend: 0,
};
for (const agent of agents) {
if (agentsWithActiveExecution.has(agent.graph_id)) {
summary.running += 1;
} else if (agentsWithRecentFailure.has(agent.graph_id)) {
summary.error += 1;
} else if (agent.has_external_trigger) {
summary.listening += 1;
} else if (agent.is_scheduled || agent.recommended_schedule_cron) {
summary.scheduled += 1;
} else {
summary.idle += 1;
}
// Parallel counter: mutually exclusive with running/error (which match
// the sitrep priority order used by the "Recently completed" tab list)
// but orthogonal to listening/scheduled/idle.
if (
!agentsWithActiveExecution.has(agent.graph_id) &&
!agentsWithRecentFailure.has(agent.graph_id) &&
agentsWithRecentCompletion.has(agent.graph_id)
) {
summary.completed += 1;
}
}
return summary;
}, [agents, executions, isSuccess]);
}
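The bucketing in `useLibraryFleetSummary` above assigns each agent exactly one primary bucket, with "completed" counted in parallel. A minimal sketch of that classifier, using assumed sample data and a simplified `AgentLike` shape:

```typescript
// Simplified agent shape for illustration; the real hook uses LibraryAgent.
interface AgentLike {
  graph_id: string;
  has_external_trigger: boolean;
  is_scheduled: boolean;
}

// Each agent lands in exactly one primary bucket, mirroring the
// if/else chain in the hook above.
function classify(
  agent: AgentLike,
  active: Set<string>,
  failed: Set<string>,
): "running" | "error" | "listening" | "scheduled" | "idle" {
  if (active.has(agent.graph_id)) return "running";
  if (failed.has(agent.graph_id)) return "error";
  if (agent.has_external_trigger) return "listening";
  if (agent.is_scheduled) return "scheduled";
  return "idle";
}

const active = new Set(["g1"]);
const failed = new Set(["g2"]);
const a1 = { graph_id: "g1", has_external_trigger: true, is_scheduled: false };
const a2 = { graph_id: "g3", has_external_trigger: false, is_scheduled: true };
console.log(classify(a1, active, failed)); // running (active wins over listening)
console.log(classify(a2, active, failed)); // scheduled
```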

View File

@@ -2,12 +2,14 @@
import { useEffect, useState, useCallback } from "react";
import { HeartIcon, ListIcon } from "@phosphor-icons/react";
import { JumpBackIn } from "./components/JumpBackIn/JumpBackIn";
import { LibraryActionHeader } from "./components/LibraryActionHeader/LibraryActionHeader";
import { LibraryAgentList } from "./components/LibraryAgentList/LibraryAgentList";
import { useLibraryListPage } from "./components/useLibraryListPage";
import { FavoriteAnimationProvider } from "./context/FavoriteAnimationContext";
import { LibraryTab } from "./types";
import type { LibraryTab, AgentStatusFilter } from "./types";
import { useLibraryFleetSummary } from "./hooks/useLibraryFleetSummary";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { useLibraryAgents } from "@/hooks/useLibraryAgents/useLibraryAgents";
const LIBRARY_TABS: LibraryTab[] = [
{ id: "all", title: "All", icon: ListIcon },
@@ -19,6 +21,10 @@ export default function LibraryPage() {
useLibraryListPage();
const [selectedFolderId, setSelectedFolderId] = useState<string | null>(null);
const [activeTab, setActiveTab] = useState(LIBRARY_TABS[0].id);
const [statusFilter, setStatusFilter] = useState<AgentStatusFilter>("all");
const isAgentBriefingEnabled = useGetFlag(Flag.AGENT_BRIEFING);
const { agents } = useLibraryAgents();
const fleetSummary = useLibraryFleetSummary(agents);
useEffect(() => {
document.title = "Library AutoGPT Platform";
@@ -40,7 +46,6 @@ export default function LibraryPage() {
>
<main className="pt-160 container min-h-screen space-y-4 pb-20 pt-16 sm:px-8 md:px-12">
<LibraryActionHeader setSearchTerm={setSearchTerm} />
<JumpBackIn />
<LibraryAgentList
searchTerm={searchTerm}
librarySort={librarySort}
@@ -50,6 +55,10 @@ export default function LibraryPage() {
tabs={LIBRARY_TABS}
activeTab={activeTab}
onTabChange={handleTabChange}
statusFilter={statusFilter}
onStatusFilterChange={setStatusFilter}
fleetSummary={isAgentBriefingEnabled ? fleetSummary : undefined}
briefingAgents={isAgentBriefingEnabled ? agents : undefined}
/>
</main>
</FavoriteAnimationProvider>

View File

@@ -1,7 +1,76 @@
import { Icon } from "@phosphor-icons/react";
import type { Icon } from "@phosphor-icons/react";
export interface LibraryTab {
id: string;
title: string;
icon: Icon;
}
/** Agent execution status — drives StatusBadge visuals & filtering. */
export type AgentStatus =
| "running"
| "error"
| "listening"
| "scheduled"
| "idle";
/** Derived health bucket for quick triage. */
export type AgentHealth = "good" | "attention" | "stale";
/** Real-time metadata that powers the Intelligence Layer features. */
export interface AgentStatusInfo {
status: AgentStatus;
health: AgentHealth;
/** 0-100 progress for currently running agents. */
progress: number | null;
totalRuns: number;
lastRunAt: string | null;
lastError: string | null;
/** ID of the currently active execution (when status is "running"). */
activeExecutionID: string | null;
monthlySpend: number;
nextScheduledRun: string | null;
triggerType: string | null;
}
/** Fleet-wide aggregate counts used by the Briefing Panel stats grid. */
export interface FleetSummary {
running: number;
error: number;
completed: number;
listening: number;
scheduled: number;
idle: number;
monthlySpend: number;
}
export type SitrepPriority =
| "error"
| "running"
| "stale"
| "success"
| "listening"
| "scheduled"
| "idle";
export interface SitrepItemData {
id: string;
agentID: string;
agentName: string;
agentImageUrl?: string | null;
executionID?: string;
priority: SitrepPriority;
message: string;
status: AgentStatus;
}
/** Filter options for the agent filter dropdown. */
export type AgentStatusFilter =
| "all"
| "running"
| "attention"
| "completed"
| "listening"
| "scheduled"
| "idle"
| "healthy";

View File

@@ -1,6 +1,8 @@
"use client";
import { useState } from "react";
import { Button } from "@/components/ui/button";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
import { useSubscriptionTierSection } from "./useSubscriptionTierSection";
type TierInfo = {
@@ -15,39 +17,70 @@ const TIERS: TierInfo[] = [
key: "FREE",
label: "Free",
multiplier: "1x",
description: "Base rate limits",
description: "Base AutoPilot capacity with standard rate limits",
},
{
key: "PRO",
label: "Pro",
multiplier: "5x",
description: "5x more AutoPilot capacity",
description: "5x AutoPilot capacity — run 5× more tasks per day/week",
},
{
key: "BUSINESS",
label: "Business",
multiplier: "20x",
description: "20x more AutoPilot capacity",
description: "20x AutoPilot capacity — ideal for teams and heavy workloads",
},
];
function formatCost(cents: number): string {
if (cents === 0) return "Free";
const TIER_ORDER = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"];
function formatCost(cents: number, tierKey: string): string {
if (tierKey === "FREE") return "Free";
if (cents === 0) return "Pricing available soon";
return `$${(cents / 100).toFixed(2)}/mo`;
}
export function SubscriptionTierSection() {
const { subscription, isLoading, error, isPending, changeTier } =
useSubscriptionTierSection();
const [tierError, setTierError] = useState<string | null>(null);
const {
subscription,
isLoading,
error,
tierError,
isPending,
pendingTier,
pendingUpgradeTier,
setPendingUpgradeTier,
confirmUpgrade,
isPaymentEnabled,
changeTier,
handleTierChange,
} = useSubscriptionTierSection();
const [confirmDowngradeTo, setConfirmDowngradeTo] = useState<string | null>(
null,
);
if (isLoading) return null;
if (isLoading) {
return (
<div className="space-y-4">
<Skeleton className="h-6 w-48" />
<div className="grid grid-cols-1 gap-3 sm:grid-cols-3">
<Skeleton className="h-40 rounded-lg" />
<Skeleton className="h-40 rounded-lg" />
<Skeleton className="h-40 rounded-lg" />
</div>
</div>
);
}
if (error) {
return (
<div className="space-y-4">
<h3 className="text-lg font-medium">Subscription Plan</h3>
<p className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400">
<p
role="alert"
className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400"
>
{error}
</p>
</div>
@@ -56,10 +89,30 @@ export function SubscriptionTierSection() {
if (!subscription) return null;
async function handleTierChange(tierKey: string) {
setTierError(null);
const err = await changeTier(tierKey);
if (err) setTierError(err);
const currentTier = subscription.tier;
if (currentTier === "ENTERPRISE") {
return (
<div className="space-y-4">
<h3 className="text-lg font-medium">Subscription Plan</h3>
<div className="rounded-lg border border-violet-500 bg-violet-50 p-4 dark:bg-violet-900/20">
<p className="font-semibold text-violet-700 dark:text-violet-200">
Enterprise Plan
</p>
<p className="mt-1 text-sm text-neutral-600 dark:text-neutral-400">
Your Enterprise plan is managed by your administrator. Contact your
account team for changes.
</p>
</div>
</div>
);
}
async function confirmDowngrade() {
if (!confirmDowngradeTo) return;
const tier = confirmDowngradeTo;
setConfirmDowngradeTo(null);
await changeTier(tier);
}
return (
@@ -67,24 +120,28 @@ export function SubscriptionTierSection() {
<h3 className="text-lg font-medium">Subscription Plan</h3>
{tierError && (
<p className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400">
<p
role="alert"
className="rounded-md border border-red-200 bg-red-50 px-3 py-2 text-sm text-red-700 dark:border-red-800 dark:bg-red-900/20 dark:text-red-400"
>
{tierError}
</p>
)}
<div className="grid grid-cols-1 gap-3 sm:grid-cols-3">
{TIERS.map((tier) => {
const isCurrent = subscription.tier === tier.key;
const isCurrent = currentTier === tier.key;
const cost = subscription.tier_costs[tier.key] ?? 0;
const currentTierOrder = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"];
const currentIdx = currentTierOrder.indexOf(subscription.tier);
const targetIdx = currentTierOrder.indexOf(tier.key);
const currentIdx = TIER_ORDER.indexOf(currentTier);
const targetIdx = TIER_ORDER.indexOf(tier.key);
const isUpgrade = targetIdx > currentIdx;
const isDowngrade = targetIdx < currentIdx;
const isThisPending = pendingTier === tier.key;
return (
<div
key={tier.key}
aria-current={isCurrent ? "true" : undefined}
className={`rounded-lg border p-4 ${
isCurrent
? "border-violet-500 bg-violet-50 dark:bg-violet-900/20"
@@ -100,7 +157,9 @@ export function SubscriptionTierSection() {
)}
</div>
<p className="mb-1 text-2xl font-bold">{formatCost(cost)}</p>
<p className="mb-1 text-2xl font-bold">
{formatCost(cost, tier.key)}
</p>
<p className="mb-1 text-sm font-medium text-neutral-600 dark:text-neutral-400">
{tier.multiplier} rate limits
</p>
@@ -108,14 +167,20 @@ export function SubscriptionTierSection() {
{tier.description}
</p>
{!isCurrent && (
{!isCurrent && isPaymentEnabled && (
<Button
className="w-full"
variant={isUpgrade ? "default" : "outline"}
disabled={isPending}
onClick={() => handleTierChange(tier.key)}
onClick={() =>
handleTierChange(
tier.key,
currentTier,
setConfirmDowngradeTo,
)
}
>
{isPending
{isThisPending
? "Updating..."
: isUpgrade
? `Upgrade to ${tier.label}`
@@ -129,12 +194,79 @@ export function SubscriptionTierSection() {
})}
</div>
{subscription.tier !== "FREE" && (
{currentTier !== "FREE" && isPaymentEnabled && (
<p className="text-sm text-neutral-500">
Your subscription is managed through Stripe. Changes take effect
immediately.
Your subscription is managed through Stripe. Upgrades and paid-tier
changes take effect immediately; downgrades to Free are scheduled for
the end of the current billing period.
</p>
)}
<Dialog
title="Confirm Downgrade"
controlled={{
isOpen: !!confirmDowngradeTo,
set: (open) => {
if (!open) setConfirmDowngradeTo(null);
},
}}
>
<Dialog.Content>
<p className="text-sm text-neutral-600 dark:text-neutral-400">
{confirmDowngradeTo === "FREE"
? "Downgrading to Free will schedule your subscription to cancel at the end of your current billing period. You keep your current plan until then."
: `Switching to ${TIERS.find((t) => t.key === confirmDowngradeTo)?.label ?? confirmDowngradeTo} will take effect immediately.`}{" "}
Are you sure?
</p>
<Dialog.Footer>
<Button
variant="outline"
onClick={() => setConfirmDowngradeTo(null)}
>
Cancel
</Button>
<Button
variant="destructive"
onClick={() => void confirmDowngrade()}
>
Confirm Downgrade
</Button>
</Dialog.Footer>
</Dialog.Content>
</Dialog>
<Dialog
title="Confirm Upgrade"
controlled={{
isOpen: !!pendingUpgradeTier,
set: (open) => {
if (!open) setPendingUpgradeTier(null);
},
}}
>
<Dialog.Content>
<p className="text-sm text-neutral-600 dark:text-neutral-400">
{subscription &&
subscription.proration_credit_cents > 0 &&
`Your unused ${currentTier.charAt(0) + currentTier.slice(1).toLowerCase()} subscription ($${(subscription.proration_credit_cents / 100).toFixed(2)}) will be applied as a credit to your next Stripe invoice. `}
You will be redirected to Stripe to complete your upgrade to{" "}
{TIERS.find((t) => t.key === pendingUpgradeTier)?.label ??
pendingUpgradeTier}
.
</p>
<Dialog.Footer>
<Button
variant="outline"
onClick={() => setPendingUpgradeTier(null)}
>
Cancel
</Button>
<Button onClick={() => void confirmUpgrade()}>
Continue to Checkout
</Button>
</Dialog.Footer>
</Dialog.Content>
</Dialog>
</div>
);
}

View File

@@ -0,0 +1,358 @@
import {
render,
screen,
fireEvent,
waitFor,
cleanup,
} from "@/tests/integrations/test-utils";
import { afterEach, describe, expect, it, vi } from "vitest";
import { SubscriptionTierSection } from "../SubscriptionTierSection";
// Mock next/navigation
const mockSearchParams = new URLSearchParams();
const mockRouterReplace = vi.fn();
vi.mock("next/navigation", async (importOriginal) => {
const actual = await importOriginal<typeof import("next/navigation")>();
return {
...actual,
useSearchParams: () => mockSearchParams,
useRouter: () => ({ push: vi.fn(), replace: mockRouterReplace }),
usePathname: () => "/profile/credits",
};
});
// Mock toast
const mockToast = vi.fn();
vi.mock("@/components/molecules/Toast/use-toast", () => ({
useToast: () => ({ toast: mockToast }),
}));
// Mock feature flags — default to payment enabled so button tests work
let mockPaymentEnabled = true;
vi.mock("@/services/feature-flags/use-get-flag", () => ({
Flag: { ENABLE_PLATFORM_PAYMENT: "enable-platform-payment" },
useGetFlag: () => mockPaymentEnabled,
}));
// Mock generated API hooks
const mockUseGetSubscriptionStatus = vi.fn();
const mockUseUpdateSubscriptionTier = vi.fn();
vi.mock("@/app/api/__generated__/endpoints/credits/credits", () => ({
useGetSubscriptionStatus: (opts: unknown) =>
mockUseGetSubscriptionStatus(opts),
useUpdateSubscriptionTier: () => mockUseUpdateSubscriptionTier(),
}));
// Mock Dialog (Radix portals don't work in happy-dom)
const MockDialogContent = ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
);
const MockDialogFooter = ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
);
function MockDialog({
controlled,
children,
}: {
controlled?: { isOpen: boolean; set: (open: boolean) => void };
children: React.ReactNode;
[key: string]: unknown;
}) {
return controlled?.isOpen ? <div role="dialog">{children}</div> : null;
}
MockDialog.Content = MockDialogContent;
MockDialog.Footer = MockDialogFooter;
vi.mock("@/components/molecules/Dialog/Dialog", () => ({
Dialog: MockDialog,
}));
function makeSubscription({
tier = "FREE",
monthlyCost = 0,
tierCosts = { FREE: 0, PRO: 1999, BUSINESS: 4999, ENTERPRISE: 0 },
prorationCreditCents = 0,
}: {
tier?: string;
monthlyCost?: number;
tierCosts?: Record<string, number>;
prorationCreditCents?: number;
} = {}) {
return {
tier,
monthly_cost: monthlyCost,
tier_costs: tierCosts,
proration_credit_cents: prorationCreditCents,
};
}
function setupMocks({
subscription = makeSubscription(),
isLoading = false,
queryError = null as Error | null,
mutateFn = vi.fn().mockResolvedValue({ status: 200, data: { url: "" } }),
isPending = false,
variables = undefined as { data?: { tier?: string } } | undefined,
} = {}) {
// The hook uses select: (data) => (data.status === 200 ? data.data : null)
// so the data value returned by the hook is already the transformed subscription object.
// We simulate that by returning the subscription directly as data.
mockUseGetSubscriptionStatus.mockReturnValue({
data: subscription,
isLoading,
error: queryError,
refetch: vi.fn(),
});
mockUseUpdateSubscriptionTier.mockReturnValue({
mutateAsync: mutateFn,
isPending,
variables,
});
}
afterEach(() => {
cleanup();
mockUseGetSubscriptionStatus.mockReset();
mockUseUpdateSubscriptionTier.mockReset();
mockToast.mockReset();
mockRouterReplace.mockReset();
mockSearchParams.delete("subscription");
mockPaymentEnabled = true;
});
describe("SubscriptionTierSection", () => {
it("renders skeleton cards while loading", () => {
setupMocks({ isLoading: true });
render(<SubscriptionTierSection />);
// Skeleton state: no tier cards should be rendered while loading
expect(screen.queryByText("Pro")).toBeNull();
expect(screen.queryByText("Business")).toBeNull();
});
it("renders error message when subscription fetch fails", () => {
setupMocks({
queryError: new Error("Network error"),
subscription: makeSubscription(),
});
// Override the data to simulate failed state
mockUseGetSubscriptionStatus.mockReturnValue({
data: null,
isLoading: false,
error: new Error("Network error"),
refetch: vi.fn(),
});
render(<SubscriptionTierSection />);
expect(screen.getByRole("alert")).toBeDefined();
expect(screen.getByText(/failed to load subscription info/i)).toBeDefined();
});
it("renders all three tier cards for FREE user", () => {
setupMocks();
render(<SubscriptionTierSection />);
// Use getAllByText to account for the tier label AND cost display both containing "Free"
expect(screen.getAllByText("Free").length).toBeGreaterThan(0);
expect(screen.getByText("Pro")).toBeDefined();
expect(screen.getByText("Business")).toBeDefined();
});
it("shows Current badge on the active tier", () => {
setupMocks({ subscription: makeSubscription({ tier: "PRO" }) });
render(<SubscriptionTierSection />);
expect(screen.getByText("Current")).toBeDefined();
// Upgrade to PRO button should NOT exist; Upgrade to BUSINESS and Downgrade to Free should
expect(
screen.queryByRole("button", { name: /upgrade to pro/i }),
).toBeNull();
expect(
screen.getByRole("button", { name: /upgrade to business/i }),
).toBeDefined();
expect(
screen.getByRole("button", { name: /downgrade to free/i }),
).toBeDefined();
});
it("displays tier costs from the API", () => {
setupMocks({
subscription: makeSubscription({
tier: "FREE",
tierCosts: { FREE: 0, PRO: 1999, BUSINESS: 4999, ENTERPRISE: 0 },
}),
});
render(<SubscriptionTierSection />);
expect(screen.getByText("$19.99/mo")).toBeDefined();
expect(screen.getByText("$49.99/mo")).toBeDefined();
// FREE tier label should still be visible (there may be multiple "Free" elements)
expect(screen.getAllByText("Free").length).toBeGreaterThan(0);
});
it("shows 'Pricing available soon' when tier cost is 0 for a paid tier", () => {
setupMocks({
subscription: makeSubscription({
tier: "FREE",
tierCosts: { FREE: 0, PRO: 0, BUSINESS: 0, ENTERPRISE: 0 },
}),
});
render(<SubscriptionTierSection />);
// PRO and BUSINESS with cost=0 should show "Pricing available soon"
expect(screen.getAllByText("Pricing available soon")).toHaveLength(2);
});
it("calls changeTier on upgrade click after confirmation dialog", async () => {
const mutateFn = vi
.fn()
.mockResolvedValue({ status: 200, data: { url: "" } });
setupMocks({ mutateFn });
render(<SubscriptionTierSection />);
// Clicking upgrade opens the confirmation dialog first
fireEvent.click(screen.getByRole("button", { name: /upgrade to pro/i }));
// Confirm via the dialog's "Continue to Checkout" button
fireEvent.click(
screen.getByRole("button", { name: /continue to checkout/i }),
);
await waitFor(() => {
expect(mutateFn).toHaveBeenCalledWith(
expect.objectContaining({
data: expect.objectContaining({ tier: "PRO" }),
}),
);
});
});
it("shows confirmation dialog on downgrade click", () => {
setupMocks({ subscription: makeSubscription({ tier: "PRO" }) });
render(<SubscriptionTierSection />);
fireEvent.click(screen.getByRole("button", { name: /downgrade to free/i }));
expect(screen.getByRole("dialog")).toBeDefined();
// The dialog title text appears in both a div and a button — just check the dialog is open
expect(screen.getAllByText(/confirm downgrade/i).length).toBeGreaterThan(0);
});
it("calls changeTier after downgrade confirmation", async () => {
const mutateFn = vi
.fn()
.mockResolvedValue({ status: 200, data: { url: "" } });
setupMocks({
subscription: makeSubscription({ tier: "PRO" }),
mutateFn,
});
render(<SubscriptionTierSection />);
fireEvent.click(screen.getByRole("button", { name: /downgrade to free/i }));
fireEvent.click(screen.getByRole("button", { name: /confirm downgrade/i }));
await waitFor(() => {
expect(mutateFn).toHaveBeenCalledWith(
expect.objectContaining({
data: expect.objectContaining({ tier: "FREE" }),
}),
);
});
});
it("dismisses dialog when Cancel is clicked", () => {
setupMocks({ subscription: makeSubscription({ tier: "PRO" }) });
render(<SubscriptionTierSection />);
fireEvent.click(screen.getByRole("button", { name: /downgrade to free/i }));
expect(screen.getByRole("dialog")).toBeDefined();
fireEvent.click(screen.getByRole("button", { name: /^cancel$/i }));
expect(screen.queryByRole("dialog")).toBeNull();
});
it("redirects to Stripe when checkout URL is returned", async () => {
// Replace window.location with a plain object so assigning .href doesn't
// trigger real navigation in the test DOM (which would throw or reload the test page).
const mockLocation = { href: "" };
vi.stubGlobal("location", mockLocation);
const mutateFn = vi.fn().mockResolvedValue({
status: 200,
data: { url: "https://checkout.stripe.com/pay/cs_test" },
});
setupMocks({ mutateFn });
render(<SubscriptionTierSection />);
// Upgrade opens confirmation dialog first — confirm via "Continue to Checkout"
fireEvent.click(screen.getByRole("button", { name: /upgrade to pro/i }));
fireEvent.click(
screen.getByRole("button", { name: /continue to checkout/i }),
);
await waitFor(() => {
expect(mockLocation.href).toBe("https://checkout.stripe.com/pay/cs_test");
});
vi.unstubAllGlobals();
});
it("shows an error alert when tier change fails", async () => {
const mutateFn = vi.fn().mockRejectedValue(new Error("Stripe unavailable"));
setupMocks({ mutateFn });
render(<SubscriptionTierSection />);
// Upgrade opens confirmation dialog first — confirm to trigger the mutation
fireEvent.click(screen.getByRole("button", { name: /upgrade to pro/i }));
fireEvent.click(
screen.getByRole("button", { name: /continue to checkout/i }),
);
await waitFor(() => {
expect(screen.getByRole("alert")).toBeDefined();
expect(screen.getByText(/stripe unavailable/i)).toBeDefined();
});
});
it("hides action buttons when payment flag is disabled", () => {
mockPaymentEnabled = false;
setupMocks({ subscription: makeSubscription({ tier: "FREE" }) });
render(<SubscriptionTierSection />);
// Tier cards still visible
expect(screen.getByText("Pro")).toBeDefined();
expect(screen.getByText("Business")).toBeDefined();
// No upgrade/downgrade buttons
expect(screen.queryByRole("button", { name: /upgrade/i })).toBeNull();
expect(screen.queryByRole("button", { name: /downgrade/i })).toBeNull();
});
it("shows ENTERPRISE message for ENTERPRISE tier users", () => {
setupMocks({ subscription: makeSubscription({ tier: "ENTERPRISE" }) });
render(<SubscriptionTierSection />);
// Enterprise heading text appears in a <p> (may match multiple), just verify it exists
expect(screen.getAllByText(/enterprise plan/i).length).toBeGreaterThan(0);
expect(screen.getByText(/managed by your administrator/i)).toBeDefined();
// No standard tier cards should be rendered
expect(screen.queryByText("Pro")).toBeNull();
expect(screen.queryByText("Business")).toBeNull();
});
it("shows success toast and clears URL param when ?subscription=success is present", async () => {
mockSearchParams.set("subscription", "success");
setupMocks();
render(<SubscriptionTierSection />);
await waitFor(() => {
expect(mockToast).toHaveBeenCalledWith(
expect.objectContaining({ title: "Subscription upgraded" }),
);
});
// URL param must be stripped so a page refresh doesn't re-trigger the toast
expect(mockRouterReplace).toHaveBeenCalledWith("/profile/credits");
});
it("clears URL param but shows no toast when ?subscription=cancelled is present", async () => {
mockSearchParams.set("subscription", "cancelled");
setupMocks();
render(<SubscriptionTierSection />);
// The cancelled param must be stripped from the URL (same hygiene as success)
await waitFor(() => {
expect(mockRouterReplace).toHaveBeenCalledWith("/profile/credits");
});
// No toast should fire — the user simply abandoned checkout
expect(mockToast).not.toHaveBeenCalled();
});
});
