Compare commits

...

135 Commits

Author SHA1 Message Date
Zamil Majdy
5cc72e7608 fix(backend): validate non-finite float in extract_openrouter_cost, restore UTILITIES comment
- Add math.isfinite() guard in extract_openrouter_cost so inf/nan header
  values are rejected instead of stored (Sentry finding)
- Add test cases for inf, -inf, and nan header values
- Restore the '# ------- UTILITIES ------- #' section separator in
  manager.py that was accidentally dropped during the drain-on-shutdown commit
2026-04-08 10:26:41 +07:00
Zamil Majdy
fa0650214d refactor(backend): extract shared _update_title_async to copilot/service.py
The function was duplicated in both baseline/service.py and sdk/service.py
with identical logic. Consolidate into the shared service module alongside
_generate_session_title which it wraps.
2026-04-08 10:14:29 +07:00
Zamil Majdy
3e016508d4 fix(backend): replace getattr/hasattr duck-typing with try/except in baseline cost extraction
Consistent with the pattern used in llm.extract_openrouter_cost():
use try/except (AttributeError, ValueError) instead of getattr(response, '_response', None)
+ hasattr(raw_resp, 'headers') so there is only one attribute-access pattern in the codebase.
2026-04-08 10:08:52 +07:00
Zamil Majdy
b80d7abda9 fix(backend): remove per-worker Prisma connect, route DB via DatabaseManagerAsyncClient
- Remove local import + db.connect()/disconnect() from CoPilotProcessor.on_executor_start
  DB calls already route through db_accessors (chat_db, user_db) which fall back to
  DatabaseManagerAsyncClient RPC when db.is_connected() is False
- Fix rate_limit._fetch_user_tier to use user_db().get_user_by_id() instead of
  PrismaUser.prisma() directly — avoids requiring Prisma connected on worker event loop
- Add subscription_tier field to User Pydantic model, mapped in User.from_db() so
  the RPC path returns the tier value without a direct Prisma connection
2026-04-08 10:05:18 +07:00
Zamil Majdy
0e310c788a fix(backend): fix TestGetPlatformCostDashboard mocks to match 3-query implementation
get_platform_cost_dashboard runs 3 concurrent queries (by_provider,
by_user, COUNT DISTINCT userId) but the unit tests only provided 2
side_effect values, causing StopAsyncIteration on the third call.
Updated all three test cases to supply a third mock return value and
corrected await_count assertion from 2 to 3.
2026-04-07 23:15:42 +07:00
Zamil Majdy
91af007c18 fix(backend): guard against non-finite cost_usd in persist_and_record_usage
float('inf') and float('nan') do not raise ValueError/TypeError so they
bypass the existing try/except. Passing them to usd_to_microdollars causes
OverflowError at round(inf * 1_000_000). Add math.isfinite(val) and val >= 0
check (matching the same pattern used in baseline/service.py and llm.py)
before assigning cost_float.
2026-04-07 22:56:20 +07:00
Zamil Majdy
e7ca81ed89 fix(backend): address coderabbitai nitpicks in cost tracking files
- token_tracking.py: convert logger.info %s calls to f-strings per style guide
- cost_tracking.py: simplify metadata=meta (was redundantly `meta or None`);
  move token_tracking imports to module level to remove # noqa: PLC0415 suppressors
- baseline/service.py: remove dead UnboundLocalError from except tuple since
  response is initialized to None before the try block
2026-04-07 22:54:52 +07:00
Zamil Majdy
5164fa878f fix(backend): fix race condition on _copilot_tasks concurrent iteration during drain
Add _pending_log_tasks_lock to token_tracking.py so that add/discard
operations on _pending_log_tasks are always lock-protected. Update
drain_pending_cost_logs in cost_tracking.py to acquire the copilot
tasks lock (not its own lock) when taking a snapshot of the copilot
set, preventing RuntimeError: Set changed size during iteration during
graceful shutdown when done callbacks fire concurrently.
2026-04-07 22:48:51 +07:00
Zamil Majdy
cf605ef5a3 fix(backend): fix race condition on _pending_log_tasks and uncapped total_users
- cost_tracking.py: add threading.Lock (_pending_log_tasks_lock) around all
  add/discard/iterate access to _pending_log_tasks; worker thread done callbacks
  and drain_pending_cost_logs() run concurrently across loops, causing
  RuntimeError: Set changed size during iteration without a lock

- platform_cost.py: add a separate COUNT(DISTINCT userId) query so total_users
  is accurate for >100 active users; previously it was silently capped at
  MAX_USER_ROWS=100 because it was derived from len(by_user_rows)
2026-04-07 22:25:25 +07:00
Zamil Majdy
e7bd05c6f1 fix(backend): filter drain tasks by current event loop to prevent cross-loop asyncio.wait() crash 2026-04-07 21:42:17 +07:00
Zamil Majdy
22fb3549e3 test(frontend): fix null-user dash assertion using getAllByText to handle multiple matches 2026-04-07 21:31:43 +07:00
Zamil Majdy
1c3fe1444e fix(backend): address 4 unresolved review threads on cost tracking
1. cost_tracking.py: replace shared _log_semaphore with per-loop dict
   (_log_semaphores + _get_log_semaphore()) — asyncio.Semaphore is not
   thread-safe and must not be shared across executor worker threads/loops

2. cost_tracking.py: only honor provider_cost_type when provider_cost is
   also present (not None); use tracking_amount (not raw stats.provider_cost)
   in usd_to_microdollars() to avoid unit mismatches

3. token_tracking.py: add semaphore to _schedule_cost_log (same pattern
   as cost_tracking.py) to bound concurrent DB inserts under load; fix
   forward-reference string in _pending_log_tasks type annotation

4. baseline/service.py: validate x-total-cost header with math.isfinite
   and max(0.0, cost) guard before accumulating — rejects nan/inf values
   that float() accepts but that should never reach the persistence path
2026-04-07 21:25:56 +07:00
Zamil Majdy
b89321a688 Merge remote-tracking branch 'origin/dev' into codex/platform-cost-tracking 2026-04-07 21:16:02 +07:00
Krzysztof Czerwinski
67bdef13e7 feat(platform): load copilot messages from newest first with cursor-based pagination (#12328)
Copilot chat sessions with long histories loaded all messages at once,
causing slow initial loads. This PR adds cursor-based pagination so only
the most recent messages load initially, with older messages fetched on
demand as the user scrolls up.

### Changes 🏗️

**Backend:**
- Cursor-based pagination on `GET /sessions/{session_id}` (`limit`,
`before_sequence` params)
- `user_id` relation filter on the paginated query — ownership check and
message fetch now run in parallel
- Backward boundary expansion to keep tool-call / assistant message
pairs intact at page edges
- Unit tests for paginated queries

**Frontend:**
- `useLoadMoreMessages` hook + `LoadMoreSentinel` (IntersectionObserver)
for infinite scroll upward
- `ScrollPreserver` to maintain scroll position when older messages are
prepended
- Session-keyed `Conversation` remount with one-frame opacity hide to
eliminate scroll flash on switch
- Scrollbar moved to the correct scroll container; loading spinner no
longer causes overflow

### Checklist 📋

- [x] Pagination: only recent messages load initially; older pages load
on scroll-up
- [x] Scroll position preserved on prepend; no flash on session switch
- [x] Tool-call boundary pairs stay intact across page edges
- [x] Stream reconnection still works on initial load

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2026-04-07 12:43:47 +00:00
Ubbe
e67dd93ee8 refactor(frontend): remove stale feature flags and stabilize share execution (#12697)
## Why

Stale feature flags add noise to the codebase and make it harder to
understand which flags are actually gating live features. Four flags
were defined but never referenced anywhere in the frontend, and the
"Share Execution Results" flag has been stable long enough to remove its
gate.

## What

- Remove 4 unused flags from the `Flag` enum and `defaultFlags`:
`NEW_BLOCK_MENU`, `GRAPH_SEARCH`, `ENABLE_ENHANCED_OUTPUT_HANDLING`,
`AGENT_FAVORITING`
- Remove the `SHARE_EXECUTION_RESULTS` flag and its conditional — the
`ShareRunButton` now always renders

## How

- Deleted enum entries and default values in `use-get-flag.ts`
- Removed the `useGetFlag` call and conditional wrapper around
`<ShareRunButton />` in `SelectedRunActions.tsx`

## Changes

- `src/services/feature-flags/use-get-flag.ts` — removed 5 flags from
enum + defaults
- `src/app/(platform)/library/.../SelectedRunActions.tsx` — removed flag
import, condition; share button always renders

### Checklist

- [x] My PR is small and focused on one change
- [x] I've tested my changes locally
- [x] `pnpm format && pnpm lint` pass

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-07 19:28:40 +07:00
Otto
3140a60816 fix(frontend/builder): allow horizontal scroll for JSON output data (#12638)
Requested by @Abhi1992002 

## Why

JSON output data in the "Complete Output Data" dialog and node output
panel gets clipped — text overflows and is hidden with no way to scroll
right. Reported by Zamil in #frontend.

## What

The `ContentRenderer` wrapper divs used `overflow-hidden` which
prevented the `JSONRenderer`'s `overflow-x-auto` from working. Changed
both wrapper divs from `overflow-hidden` to `overflow-x-auto`.

```diff
- overflow-hidden [&>*]:rounded-xlarge [&>*]:!text-xs [&_pre]:whitespace-pre-wrap [&_pre]:break-words
+ overflow-x-auto [&>*]:rounded-xlarge [&>*]:!text-xs [&_pre]:whitespace-pre-wrap [&_pre]:break-words

- overflow-hidden [&>*]:rounded-xlarge [&>*]:!text-xs
+ overflow-x-auto [&>*]:rounded-xlarge [&>*]:!text-xs
```

## Scope
- 1 file changed (`ContentRenderer.tsx`)
- 2 lines: `overflow-hidden` → `overflow-x-auto`
- CSS only, no logic changes

Resolves SECRT-2206

Co-authored-by: Abhimanyu Yadav <122007096+Abhi1992002@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2026-04-07 19:11:09 +07:00
Nicholas Tindle
41c2ee9f83 feat(platform): add copilot artifact preview panel (#12629)
### Why / What / How

Copilot artifacts were not previewing reliably: PDFs downloaded instead
of rendering, Python code could still render like markdown, JSX/TSX
artifacts were brittle, HTML dashboards/charts could fail to execute,
and users had to manually open artifact panes after generation. The pane
also got stuck at maximized width when trying to drag it smaller.

This PR adds a dedicated copilot artifact panel and preview pipeline
across the backend/frontend boundary. It preserves artifact metadata
needed for classification, adds extension-first preview routing,
introduces dedicated preview/rendering paths for HTML/CSV/code/PDF/React
artifacts, auto-opens new or edited assistant artifacts, and fixes the
maximized-pane resize path so dragging exits maximized mode immediately.

### Changes 🏗️

- add artifact card and artifact panel UI in copilot, including
persisted panel state and resize/maximize/minimize behavior
- add shared artifact extraction/classification helpers and auto-open
behavior for new or edited assistant messages with artifacts
- add preview/rendering support for HTML, CSV, PDF, code, and React
artifact files
- fix code artifacts such as Python to render through the code renderer
with a dark code surface instead of markdown-style output
- improve JSX/TSX preview behavior with provider wrapping, fallback
export selection, and explicit runtime error surfaces
- allow script execution inside HTML previews so embedded chart
dashboards can render
- update workspace artifact/backend API handling and regenerate the
frontend OpenAPI client
- add regression coverage for artifact helpers, React preview runtime,
auto-open behavior, code rendering, and panel store behavior

- post-review hardening: correct download path for cross-origin URLs,
defer scroll restore until content mounts, gate auto-open behind the
ARTIFACTS flag, parse CSVs with RFC 4180-compliant quoted newlines + BOM
handling, distinguish 413 vs 409 on upload, normalize empty session_id,
and keep AnimatePresence mounted so the panel exit animation plays

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] `pnpm format`
  - [x] `pnpm lint`
  - [x] `pnpm types`
  - [x] `pnpm test:unit`

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Adds a new Copilot artifact preview surface that executes
user/AI-generated HTML/React in sandboxed iframes and changes workspace
file upload/listing behavior, so regressions could affect file handling
and client security assumptions despite sandboxing safeguards.
> 
> **Overview**
> Adds an **Artifacts** feature (flagged by `Flag.ARTIFACTS`) to
Copilot: workspace file links/attachments now render as `ArtifactCard`s
and can open a new resizable/minimizable `ArtifactPanel` with history,
auto-open behavior, copy/download actions, and persisted panel width.
> 
> Introduces a richer artifact preview pipeline with type classification
and dedicated renderers for **HTML**, **CSV**, **PDF**, **code
(Shiki-highlighted)**, and **React/TSX** (transpiled and executed in a
sandboxed iframe), plus safer download filename handling and content
caching/scroll restore.
> 
> Extends the workspace backend API by adding `GET /workspace/files`
pagination, standardizing operation IDs in OpenAPI, attaching
`metadata.origin` on uploads/agent-created files, normalizing empty
`session_id`, improving upload error mapping (409 vs 413), and hardening
post-quota soft-delete error handling; updates and expands test coverage
accordingly.
> 
> <sup>Reviewed by [Cursor Bugbot](https://cursor.com/bugbot) for commit
b732d10eca. Bugbot is set up for automated
code reviews on this repo. Configure
[here](https://www.cursor.com/dashboard/bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-07 11:24:22 +00:00
Zamil Majdy
630d6d4705 fix(backend): add semaphore to executor cost log tasks; fix type annotation
- Add `_log_semaphore = asyncio.Semaphore(50)` to cost_tracking.py to bound
  concurrent DB inserts (mirrors platform_cost.py's existing semaphore)
- Narrow `_extract_model_name` param type from `Any` to `str | dict | None`
- Add `test_get_dashboard_cache_hit` to verify TTL cache deduplicates DB calls
- Add `scope="col"` to all table `<th>` elements for screen-reader accessibility
- Add `(local time)` labels to date filter inputs to clarify timezone behaviour
2026-04-07 18:15:57 +07:00
Zamil Majdy
7c685c6677 fix(backend): update platform_cost_test to expect masked email in dashboard
_mask_email() is applied to by_user emails in get_platform_cost_dashboard().
Test now asserts 'a***@b.com' instead of 'a@b.com'.
2026-04-07 18:02:00 +07:00
Ubbe
ca748ee12a feat(frontend): refine AutoPilot onboarding — branding, auto-advance, soft cap, polish (#12686)
### Why / What / How

**Why:** The onboarding flow had inconsistent branding ("Autopilot" vs
"AutoPilot"), a heavy progress bar that dominated the header, an extra
click on the role screen, and no guidance on how many pain points to
select — leading to users selecting everything or nothing useful.

**What:** Copy & brand fixes, UX improvements (auto-advance, soft cap),
and visual polish (progress bar, checkmark badges, purple focus inputs).

**How:**
- Replaced all "Autopilot" with "AutoPilot" (capital P) across screens
1-3
- Removed the `?` tooltip on screen 1 (users will learn about AutoPilot
from the access email)
- Changed name label to conversational "What should I call you?"
- Screen 2: auto-advances 350ms after role selection (except "Other"
which still shows input + button)
- Screen 3: soft cap of 3 selections with green confirmation text and
shake animation on overflow attempt
- Thinned progress bar from ~10px to 3px (Linear/Notion style)
- Added purple checkmark badges on selected cards
- Updated Input atom focus state to purple ring

### Changes 🏗️

- **WelcomeStep**: "AutoPilot" branding, removed tooltip, conversational
label
- **RoleStep**: Updated subtitle, auto-advance on non-"Other" role
select, Continue button only for "Other"
- **PainPointsStep**: Soft cap of 3 with dynamic helper text and shake
animation
- **usePainPointsStep**: Added `atLimit`/`shaking` state, wrapped
`togglePainPoint` with cap logic
- **store.ts**: `togglePainPoint` returns early when at 3 and adding
- **ProgressBar**: 3px height, removed glow shadow
- **SelectableCard**: Added purple checkmark badge on selected state
- **Input atom**: Focus ring changed from zinc to purple
- **tailwind.config.ts**: Added `shake` keyframe and `animate-shake`
utility

### Checklist 📋

#### For code changes:
- [ ] I have clearly listed my changes in the PR description
- [ ] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  - [ ] Navigate through full onboarding flow (screens 1→2→3→4)
  - [ ] Verify "AutoPilot" branding on all screens (no "Autopilot")
  - [ ] Verify screen 2 auto-advances after tapping a role (non-"Other")
  - [ ] Verify "Other" role still shows text input and Continue button
  - [ ] Verify Back button works correctly from screen 2 and 3
  - [ ] Select 3 pain points and verify green "3 selected" text
  - [ ] Attempt 4th selection and verify shake animation + swap message
  - [ ] Deselect one and verify can select a different one
  - [ ] Verify checkmark badges appear on selected cards
  - [ ] Verify progress bar is thin (3px) and subtle
  - [ ] Verify input focus state is purple across onboarding inputs
- [ ] Verify "Something else" + other text input still works on screen 3

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-07 17:58:36 +07:00
Zamil Majdy
bbdf13c7a8 test(backend): add missing cost-tracking tests for Exa research, Apollo people, and copilot baseline
- ExaCreateResearchBlock, ExaGetResearchBlock, ExaWaitForResearchBlock: verify
  merge_stats is called with provider_cost=cost_dollars.total when completed, and
  not called when costDollars is absent
- SearchPeopleBlock: verify provider_cost=len(people) with type='items'
- Copilot baseline: 4 tests for x-total-cost header extraction in
  _baseline_llm_caller — including accumulation across turns and extraction in
  the finally block when the stream raises
2026-04-07 17:44:15 +07:00
Zamil Majdy
e1ea4cf326 test(frontend): rewrite PlatformCostContent tests to mock Orval hooks
Component now uses React Query hooks (useGetV2GetPlatformCostDashboard,
useGetV2GetPlatformCostLogs) instead of server actions, so tests must
mock @/app/api/__generated__/endpoints/admin/admin rather than ../actions.

Adds 16 test cases covering loading state, empty/data renders, tabs,
filters, pagination, null email/user handling, and tracking type badges.
2026-04-07 17:34:15 +07:00
Zamil Majdy
db6b4444e0 fix(platform): address autogpt-reviewer should-fix items
- Remove dead _pending_log_tasks/schedule_cost_log/drain_pending_cost_logs
  from platform_cost.py (only cost_tracking.py and token_tracking.py have
  active task registries; drain comment updated to match)
- Replace vars(other) iteration in NodeExecutionStats.__iadd__ with
  type(other).model_fields to avoid any potential __pydantic_extra__ leakage
- Fix rate-override clear: onRateOverride(key, null) deletes the key so
  defaultRateFor() takes effect instead of pinning estimated cost to $0
- Type extract_openrouter_cost parameter as OpenAIChatCompletion
- Fix early-return guard in persist_and_record_usage: allow through when
  all token counts are 0 but cost_usd is provided (fully-cached responses)
- Add missing tests: LLM retry cost (only last attempt merged), zero-token
  copilot cost, Exa search + similar merge_stats coverage
2026-04-07 17:23:42 +07:00
Zamil Majdy
9b1175473b fix(backend/copilot): update stale comment in processor.py after cost log routing change
token_tracking.py now routes cost logs through DatabaseManagerAsyncClient
(platform_cost_db()), so the Prisma connect in on_executor_start() is for
copilot/db.py and rate_limit.py direct Prisma usage.
2026-04-07 17:11:56 +07:00
Zamil Majdy
752a238166 refactor(frontend): replace server actions with React Query hooks for cost dashboard
usePlatformCostContent.ts now calls useGetV2GetPlatformCostDashboard and
useGetV2GetPlatformCostLogs directly (with okData selector) so the browser
gets proper caching, deduplication, and background refetch.

actions.ts is retained as a plain helper module (no 'use server') because
the co-located test file imports from it; the functions are no longer called
by the hook.
2026-04-07 17:08:42 +07:00
Zamil Majdy
2a73d1baa9 fix(backend/copilot): route copilot cost logging through DatabaseManagerAsyncClient
The copilot executor's token_tracking.py was using schedule_cost_log()
which calls execute_raw_with_schema() directly on the Prisma singleton.
In the copilot_executor process, Prisma is not reliably connected due to
event-loop binding issues, causing ClientNotConnectedError on every turn.

Fix: route cost logging through platform_cost_db() -> DatabaseManagerAsyncClient
RPC (same approach already used by the block executor). Also fix
_copilot_block_name() to extract only the service tag from the log prefix
(e.g. "[SDK][session-id][T1]" -> "copilot:SDK") instead of the full suffix.

Update cost_tracking.py drain to drain token_tracking._pending_log_tasks,
and update token_tracking_test.py mocks to match new call site.
2026-04-07 16:58:41 +07:00
Zamil Majdy
254e6057f4 fix(backend/copilot): connect Prisma in copilot executor for cost log writes
schedule_cost_log() in token_tracking.py writes PlatformCostLog rows via
execute_raw_with_schema(), which requires an active Prisma connection.
Connect Prisma at on_executor_start() so cost tracking is not silently dropped
in the copilot executor process.
2026-04-07 16:54:00 +07:00
Zamil Majdy
a616e5a060 fix(backend): address PR review — email masking, semaphore, openrouter cost style
- Mask user emails in admin API responses (dashboard + logs) to reduce
  PII exposure in proxy/CDN logs; _mask_email() shows first 2 chars only
- Add _log_semaphore(50) in platform_cost.py to bound concurrent DB inserts
  and provide back-pressure under high load
- Refactor extract_openrouter_cost() to use try/except AttributeError
  instead of getattr/hasattr, and log a WARNING when _response is missing
  so SDK changes are detectable
- Add comment to usePlatformCostContent.ts explaining why server actions
  are used instead of React Query (server-side withRoleAccess constraint)
2026-04-07 16:38:07 +07:00
Zamil Majdy
c9461836c6 fix(backend): address all open review comments on platform cost tracking
- Normalize provider to lowercase at write time; drop LOWER() in filter so
  the (provider, createdAt) index is used without function overhead
- Drop COALESCE(trackingType, metadata->>'tracking_type') fallback — new rows
  always have trackingType set at write time
- Derive total_users from len(by_user_rows) instead of a separate
  COUNT(DISTINCT userId) query (saves one aggregation per dashboard load)
- Add 30-second TTLCache for dashboard endpoint (cachetools, maxsize=256)
- Add backpressure/bounds comment to _pending_log_tasks in platform_cost.py
- Convert f-string logger calls in token_tracking.py to lazy %s formatting
- Add 6 block-level tests for ExaCodeContextBlock and ExaContentsBlock cost
  paths: valid/invalid/zero cost_dollars strings and None cost_dollars
- Update existing tests to match provider-lowercasing and 2-query dashboard
2026-04-07 16:23:10 +07:00
Zamil Majdy
50a8df3d67 Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into codex/platform-cost-tracking 2026-04-07 16:12:47 +07:00
Zamil Majdy
243b12778f dx: improve pr-test skill — inline screenshots, flow captions, and test evaluation (#12692)
## Changes

### 1. Inline image enforcement (Step 7)
- Added `CRITICAL` warning: never post a bare directory tree link
- Added post-comment verification block that greps for `![` tags and
exits 1 if none found — agents can't silently skip inline embedding

### 2. Structured screenshot captions (Step 6)
- `SCREENSHOT_EXPLANATIONS` now requires **Flow** (which scenario),
**Steps** (exact actions taken), **Evidence** (what this proves)
- Good/bad example included so agents know what format is expected
- A bare "shows the page" caption is explicitly rejected

### 3. Test completeness evaluation (Step 8) — new step
After posting screenshots, the agent must evaluate coverage against the
test plan and post a formal GitHub review:
- **`APPROVE`** — every scenario tested with screenshot + DB/API
evidence, no blockers
- **`REQUEST_CHANGES`** — lists exact gaps: untested scenarios, missing
evidence, confirmed bugs
- Per-scenario checklist (/) required in the review body
- Cannot auto-approve without ticking every item in the test plan

## Why

- Agents were posting `https://github.com/.../tree/test-screenshots/...`
instead of `![name](url)` inline
- Screenshot captions were too vague to be useful ("shows the page")
- No mechanism to catch incomplete test runs — agent could skip
scenarios and still post a passing report

## Checklist

- [x] `.claude/skills/pr-test/SKILL.md` updated
- [x] No production code changes — skill/dx only
- [x] Pre-commit hooks pass
2026-04-07 16:04:08 +07:00
Zamil Majdy
3f7a8dc44d fix(backend/copilot): use active_model instead of config.model for cost attribution
When fast mode is selected, the baseline copilot uses a different model
(active_model from _resolve_baseline_model) than the config default. Using
config.model for cost attribution would misattribute costs to the wrong model.
2026-04-07 15:54:36 +07:00
Zamil Majdy
1c15d6a6cc fix(frontend): update platform-costs tests for Skeleton loading state and removed duration_seconds
- Update loading state test to check for absent Skeleton elements (animate-pulse)
  rather than absent 'Loading...' text (which was removed in previous commit)
- Update helpers.test.ts to test sandbox_seconds instead of the removed duration_seconds
2026-04-07 15:52:38 +07:00
Zamil Majdy
a31be77408 fix(platform): address additional reviewer feedback on platform cost dashboard
- Extract _extract_model_name() helper in cost_tracking.py to replace nested isinstance checks
- Replace Lucide icons with Phosphor equivalents in admin/layout.tsx
- Replace Loading... text with Skeleton components in PlatformCostContent
- Switch Promise.all to Promise.allSettled in usePlatformCostContent for partial data resilience
- Fix hardcoded border-blue-600/text-blue-600 with design token border-primary/text-primary
- Remove dead duration_seconds case from helpers.ts and TrackingBadge (backend never emits it)
2026-04-07 15:46:39 +07:00
Zamil Majdy
1d45f2f18c fix(platform): fix baseline copilot OpenRouter cost extraction and credit_cost check
- Fix wrong attribute: baseline/service.py used getattr(response, 'response')
  but AsyncStream exposes the raw httpx response via '_response' (with
  underscore), matching the pattern in llm.py:extract_openrouter_cost().
  OpenRouter cost tracking in baseline copilot was silently failing.
- Fix falsy zero-cost guard: change `if credit_cost:` to `if credit_cost is
  not None:` so free-tier blocks (credit_cost=0) include the field in metadata.
2026-04-07 15:25:02 +07:00
Zamil Majdy
27e34e9514 fix(platform): drain pending cost logs on shutdown, remove dark: badges
- Wire drain_pending_cost_logs() into ExecutionManager.cleanup() so
  in-flight INSERT tasks are awaited before process exit during rolling
  deployments (uses the existing node_execution_loop; no-op if the loop
  was never started, e.g. in tests)
- Remove prohibited dark: Tailwind classes from TrackingBadge badges;
  light tokens (text-green-700, text-blue-700, …) now apply in all
  themes — design system handles dark mode via CSS variables
2026-04-07 15:07:41 +07:00
Zamil Majdy
16d696edcc fix(platform): address autogpt-reviewer blockers and should-fix items
- Fix LLM retry double-counting: track tokens per attempt but only merge
  provider_cost on the successful attempt, not across all retries
- Add drain_pending_cost_logs() to platform_cost.py; update cost_tracking
  to drain both executor and copilot task sets on shutdown
- Remove prohibited dark: Tailwind classes from PlatformCostContent error
  div, replace with Alert component (design system error variant)
- Add block-level cost tracking tests for: JinaEmbeddingBlock (with/without
  usage), UnrealTextToSpeechBlock (character count), GoogleMapsSearchBlock
  (place count), AddLeadToCampaignBlock (lead count)
- Add __iadd__ edge case tests: provider_cost_type first-write-to-None and
  None does not overwrite existing value
- Rename metadata key provider_cost_usd to provider_cost_raw (value unit
  varies by tracking type; only cost_usd uses USD)
- Add test verifying per_run providers have no provider_cost_raw in metadata
2026-04-07 15:05:44 +07:00
Zamil Majdy
f87bbd5966 fix(backend): route platform cost logging through DatabaseManagerAsyncClient
Fixes ClientNotConnectedError in the executor process by routing
log_platform_cost through the DatabaseManagerAsyncClient RPC proxy
instead of calling execute_raw_with_schema directly on the unconnected
module-level prisma instance.
2026-04-07 14:53:46 +07:00
Nicholas Tindle
b64d1ed9fa Merge branch 'dev' into codex/platform-cost-tracking 2026-04-06 14:31:13 -05:00
An Vy Le
43c81910ae fix(backend/copilot): skip AI blocks without model property in fix_ai_model_parameter (#12688)
### Why / What / How

**Why:** Some AI-category blocks do not expose a `"model"` input
property in their `inputSchema`. The `fix_ai_model_parameter` fixer was
unconditionally injecting a default model value (e.g. `"gpt-4o"`) into
any node whose block has category `"AI"`, regardless of whether that
block actually accepts a `model` input. This causes the agent JSON to
include an invalid field for those blocks.

**What:** Guard the model-injection logic with a check that `"model"`
exists in the block's `inputSchema.properties` before attempting to set
or validate the field. AI blocks that have no model selector are now
skipped entirely.

**How:** In `fix_ai_model_parameter`, after confirming `is_ai_block`,
extract `input_properties` from the block's `inputSchema.properties` and
`continue` if `"model"` is absent. The subsequent `model_schema` lookup
is also simplified to reuse the already-fetched `input_properties` dict.
A regression test is added to cover this case.

### Changes 🏗️

- `backend/copilot/tools/agent_generator/fixer.py`: In
`fix_ai_model_parameter`, skip AI-category nodes whose block
`inputSchema.properties` does not contain a `"model"` key; reuse
`input_properties` for the subsequent `model_schema` lookup.
- `backend/copilot/tools/agent_generator/fixer_test.py`: Add
`test_ai_block_without_model_property_is_skipped` to
`TestFixAiModelParameter`.

### Checklist 📋

#### For code changes:
- [ ] I have clearly listed my changes in the PR description
- [ ] I have made a test plan
- [ ] I have tested my changes according to the test plan:
- [ ] Run `poetry run pytest
backend/copilot/tools/agent_generator/fixer_test.py` — all 50 tests pass
(49 pre-existing + 1 new)

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 17:14:11 +00:00
Zamil Majdy
3895d95826 fix(platform): address reviewer comments — tests, a11y, and frontend polish
Backend:
- Add block cost tracking tests for ExaCodeContext, ExaContents, and
  SearchOrganizations blocks (high-severity reviewer ask)
- Add test verifying FAILED status skips cost log in manager
- Add test for empty org list tracking zero items cost

Frontend:
- Rename trackingBadge() → TrackingBadge component (PascalCase convention)
- Move toLocalInput/toUtcIso helpers from usePlatformCostContent.ts to helpers.ts
- Add aria-label to ProviderTable rate override inputs
- Add role="alert" to error state div in PlatformCostContent
- Add Clear Filters button next to Apply
- Fix text-gray-500 → text-muted-foreground in page.tsx (dark mode)
- Dark-mode-compatible error div styling
- Strengthen PlatformCostContent test assertion (exact count instead of >= 1)
- Add tab panel visibility tests and toLocalInput/toUtcIso unit tests
2026-04-06 21:49:23 +07:00
Zamil Majdy
181208528f fix(backend): update token_tracking_test mock targets after _schedule_log refactor
After extracting _schedule_log into schedule_cost_log() in platform_cost.py,
token_tracking no longer has log_platform_cost_safe as an attribute.
Update patch targets to backend.data.platform_cost.log_platform_cost_safe.
2026-04-06 21:37:38 +07:00
Zamil Majdy
0365a26c85 refactor(backend): add clarifying comment to NodeExecutionStats.__iadd__ vars() usage 2026-04-06 21:25:00 +07:00
Zamil Majdy
fb63ae54f0 refactor(platform): address review comments on platform cost tracking
Backend:
- Extract shared _schedule_log into schedule_cost_log() in platform_cost.py
  so both cost_tracking and token_tracking drain a single task set
- Add DEFAULT_DASHBOARD_DAYS=30 default for dashboard queries to avoid
  full-table scans when no date filter is provided
- Add MAX_PROVIDER_ROWS=500 / MAX_USER_ROWS=100 named constants
- Fix typing.Optional -> X | None union syntax in routes
- Fix logger f-strings to lazy %s format in platform_cost_routes
- Fix token_tracking condition to allow logging when cost_usd is set
  even if total_tokens is 0 (fully-cached responses)
- Fix test_get_dashboard_success to use real PlatformCostDashboard instance
- Add invalid input tests (422 for bad dates, page_size=0/201, page=0)
- Add test_does_not_raise_when_block_usage_cost_raises
- Add test_provider_cost_zero_is_not_none

Frontend:
- Fix TrackingBadge dark mode colors using design tokens
- Fix UserTable null key for deleted users (use unknown-{idx} fallback)
- Fix ProviderTable rate input from uncontrolled to controlled
- Fix "use server" directive on page component (not a server action)
- Add ARIA label and tabpanel roles to tab UI
- Fix LogsTable fragile cast with safe formatLogDate helper
2026-04-06 21:19:49 +07:00
Zamil Majdy
6de79fb73f fix: resolve merge conflicts with dev branch
Keep cost_usd field alongside new thinking_stripper and session_messages
fields added in dev for baseline copilot state.
2026-04-06 21:00:23 +07:00
Ubbe
a11199aa67 dx(frontend): set up React integration testing with Vitest + RTL + MSW (#12667)
## Summary
- Establish React integration tests (Vitest + RTL + MSW) as the primary
frontend testing strategy (~90% of tests)
- Update all contributor documentation (TESTING.md, CONTRIBUTING.md,
AGENTS.md) to reflect the integration-first convention
- Add `NuqsTestingAdapter` and `TooltipProvider` to the shared test
wrapper so page-level tests work out of the box
- Write 8 integration tests for the library page as a reference example
for the pattern

## Why
We had the testing infrastructure (Vitest, RTL, MSW, Orval-generated
handlers) but no established convention for page-level integration
tests. Most existing tests were for stores or small components. Since
our frontend is client-first, we need a documented, repeatable pattern
for testing full pages with mocked APIs.

## What
- **Docs**: Rewrote `TESTING.md` as a comprehensive guide. Updated
testing sections in `CONTRIBUTING.md`, `frontend/AGENTS.md`,
`platform/AGENTS.md`, and `autogpt_platform/AGENTS.md`
- **Test infra**: Added `NuqsTestingAdapter` (for `nuqs` query state
hooks) and `TooltipProvider` (for Radix tooltips) to `test-utils.tsx`
- **Reference tests**: `library/__tests__/main.test.tsx` with 8 tests
covering agent rendering, tabs, folders, search bar, and Jump Back In

## How
- Convention: tests live in `__tests__/` next to `page.tsx`, named
descriptively (`main.test.tsx`, `search.test.tsx`)
- Pattern: `setupHandlers()` → `render(<Page />)` → `findBy*` assertions
- MSW handlers from
`@/app/api/__generated__/endpoints/{tag}/{tag}.msw.ts` for API mocking
- Custom `render()` from `@/tests/integrations/test-utils` wraps all
required providers

## Test plan
- [x] All 422 unit/integration tests pass (`pnpm test:unit`)
- [x] `pnpm format` clean
- [x] `pnpm lint` clean (no new errors)
- [x] `pnpm types` — pre-existing onboarding type errors only, no new
errors

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-04-06 13:17:08 +00:00
Zamil Majdy
5f82a71d5f feat(copilot): add Fast/Thinking mode toggle with full tool parity (#12623)
### Why / What / How

Users need a way to choose between fast, cheap responses (Sonnet) and
deep reasoning (Opus) in the copilot. Previously only the SDK/Opus path
existed, and the baseline path was a degraded fallback with no tool
calling, no file attachments, no E2B sandbox, and no permission
enforcement.

This PR adds a copilot mode toggle and brings the baseline (fast) path
to full feature parity with the SDK (extended thinking) path.

### Changes 🏗️

#### 1. Mode toggle (UI → full stack)
- Add Fast / Thinking mode toggle to ChatInput footer (Phosphor
`Brain`/`Zap` icons via lucide-react)
- Thread `mode: "fast" | "extended_thinking" | null` from
`StreamChatRequest` → RabbitMQ queue → executor → service selection
- Fast → baseline service (Sonnet 4 via OpenRouter), Thinking → SDK
service (Opus 4.6)
- Toggle gated behind `CHAT_MODE_OPTION` feature flag with server-side
enforcement
- Mode persists in localStorage with SSR-safe init

#### 2. Baseline service full tool parity
- **Tool call persistence**: Store structured `ChatMessage` entries
(assistant + tool results) instead of flat concatenated text — enables
frontend to render tool call details and maintain context across turns
- **E2B sandbox**: Wire up `get_or_create_sandbox()` so `bash_exec`
routes to E2B (image download, Python/PIL compression, filesystem
access)
- **File attachments**: Accept `file_ids`, download workspace files,
embed images as OpenAI vision blocks, save non-images to working dir
- **Permissions**: Filter tool list via `CopilotPermissions`
(whitelist/blacklist)
- **URL context**: Pass `context` dict to user message for URL-shared
content
- **Execution context**: Pass `sandbox`, `sdk_cwd`, `permissions` to
`set_execution_context()`
- **Model**: Changed `fast_model` from `google/gemini-2.5-flash` to
`anthropic/claude-sonnet-4` for reliable function calling
- **Temp dir cleanup**: Lazy `mkdtemp` (only when files attached) +
`shutil.rmtree` in finally

#### 3. Transcript support for Fast mode
- Baseline service now downloads / validates / loads / appends / uploads
transcripts (parity with SDK)
- Enables seamless mode switching mid-conversation via shared transcript
- Upload shielded from cancellation, bounded at 5s timeout

#### 4. Feature-flag infrastructure fixes
- `FORCE_FLAG_*` env-var overrides on both backend and frontend for
local dev / E2E
- LaunchDarkly context parity (frontend mirrors backend user context)
- `CHAT_MODE_OPTION` default flipped to `false` to match backend

#### 5. Other hardening
- Double-submit ref guard in `useChatInput` + reconnect dedup in
`useCopilotStream`
- `copilotModeRef` pattern to read latest mode without recreating
transport
- Shared `CopilotMode` type across frontend files
- File name collision handling with numeric suffix
- Path sanitization in file description hints (`os.path.basename`)

### Test plan
- [x] 30 new unit tests: `_env_flag_override` (12), `envFlagOverride`
(8), `_filter_tools_by_permissions` (4), `_prepare_baseline_attachments`
(6)
- [x] E2E tested on dev: fast mode creates E2B sandbox, calls 7-10
tools, generates and renders images
- [x] Mode switching mid-session works (shared transcript + session
messages)
- [x] Server-side flag gate enforced (crafted `mode=fast` stripped when
flag off)
- [x] All 37 CI checks green
- [x] Verified via agent-browser: workspace images render correctly in
all message positions

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Zamil Majdy <majdy.zamil@gmail.com>
2026-04-06 19:54:36 +07:00
Zamil Majdy
d57da6c078 refactor(platform): extract usePlatformCostContent hook
PlatformCostContent.tsx was 248 lines mixing data fetching, URL state,
filter inputs, rate overrides, and rendering. Per frontend convention,
extract the stateful/effectful logic into a dedicated hook:

- usePlatformCostContent.ts (new, 142 lines) — owns:
  - dashboard/logs/pagination fetching via effect
  - URL ↔ filter input sync (startInput, endInput, providerInput, userInput)
  - rateOverrides state + handleRateOverride
  - toLocalInput/toUtcIso datetime helpers
  - updateUrl + handleFilter actions
  - totalEstimatedCost reducer
- PlatformCostContent.tsx (now 182 lines) — pure rendering only.
2026-04-05 15:56:32 +02:00
Zamil Majdy
689cd67a13 refactor(platform): address autogpt-reviewer feedback (batch 2)
- resolve_tracking: replace hardcoded provider string literals with
  ProviderName enum values + _CHARACTER_BILLED_PROVIDERS /
  _WALLTIME_BILLED_PROVIDERS frozensets (nice-to-have #2).
- NodeExecutionStats.__iadd__: replace double model_dump() with
  vars()-based iteration for ~10-50x speedup on each merge_stats() call
  (hot path — runs once per block per yield across 20+ blocks).
- Add 3 accumulation tests for provider_cost semantics:
  - Multiple provider_cost values sum (not last-write-wins)
  - None never overwrites a set value
  - provider_cost_type is last-write-wins (documented semantics)
2026-04-05 15:54:12 +02:00
Zamil Majdy
dca89d1586 refactor(platform): address autogpt-reviewer feedback (batch 1)
- cost_tracking.py: replace `Any` types with NodeExecutionEntry + Block
- Extract usd_to_microdollars utility in platform_cost.py, used by
  cost_tracking.py and copilot/token_tracking.py.
- llm.py: extract x-total-cost header parsing to extract_openrouter_cost()
  helper + 8 unit tests covering present/absent/empty/non-numeric/zero
  cases. Previously untested blocker.
- token_tracking.py: extract COPILOT_BLOCK_ID, COPILOT_CREDENTIAL_ID
  constants + _copilot_block_name() helper (clearer than inline
  f"copilot:{log_prefix.strip(' []')}".rstrip(":")).
- platform_cost.py: cap by_provider query at LIMIT 500 (defensive bound).
- TrackingBadge.tsx: drop dark: classes per frontend convention, add
  "items" badge color.
- PlatformCostContent.tsx: drop dark: classes from error banner,
  add role="tablist"/role="tab"/aria-selected to tabs, add htmlFor
  to filter input labels.
- admin/layout.tsx: Receipt icon moved from lucide-react to phosphor.
- ProviderTable.tsx: add "(unsaved)" label to Rate column header to
  signal per-session only.
2026-04-05 15:46:50 +02:00
Zamil Majdy
2f63fcd383 test(frontend): update platform-costs helpers tests for per-type rate estimation
Updates the test suite to match the new per-type rate estimation logic:
- rateOverrides now use composite keys (provider:tracking_type)
- trackingValue appends unit suffixes (tokens, chars, items)
- characters/items tracking reads from total_tracking_amount
- adds coverage for default rates across characters, items, duration types
2026-04-05 15:30:32 +02:00
Zamil Majdy
f04cd08e40 feat(platform): add trackingAmount column + per-type rate estimation
Problem
- cost_tracking.py was multiplying stats.provider_cost by 1M to get
  cost_microdollars regardless of tracking_type. When provider_cost_type
  was "items" or "characters", 5.0 items got stored as $5 USD.
- The dashboard had no way to aggregate item/character counts since
  they aren't naturally carried by inputTokens/outputTokens/duration.
- Dashboard estimation only handled cost_usd/tokens/per_run; characters,
  items, sandbox_seconds, walltime_seconds showed "-" always.

Fix
- Add PlatformCostLog.trackingAmount (Float?) column + migration.
- cost_tracking.py: only treat provider_cost as USD when tracking_type
  is "cost_usd"; always populate trackingAmount with resolve_tracking's
  amount so the dashboard can aggregate it.
- Dashboard query: SUM(trackingAmount) as total_tracking_amount.
- ProviderCostSummary (backend + regenerated TS): add total_tracking_amount.
- Frontend helpers: DEFAULT_COST_PER_1K_CHARS, DEFAULT_COST_PER_ITEM,
  DEFAULT_COST_PER_SECOND tables for characters/items/duration rates.
  estimateCostForRow dispatches per tracking_type and multiplies the
  correct amount by the correct rate.
- ProviderTable: show editable rate input for every tracking_type
  (not only per_run), with unit label ($/1K tokens, $/1K chars, $/item,
  $/second, $/run). Rate overrides keyed on "provider:tracking_type".
2026-04-05 15:23:31 +02:00
Zamil Majdy
44714f1b25 refactor(platform): use provider_cost_type Literal instead of output_size misuse
Blocks previously called merge_stats(NodeExecutionStats(output_size=...))
to signal "per-request" billing or "N items returned", but `output_size`
is semantically the output payload byte count and is always overridden
by the executor wrapper (manager.py:440 = len(json.dumps(output_data))).
Those calls were silently dead code.

Changes:
- Add ProviderCostType Literal enum on NodeExecutionStats with the
  canonical set of tracking types: cost_usd, tokens, characters,
  sandbox_seconds, walltime_seconds, per_run, items.
- Add provider_cost_type field to NodeExecutionStats so blocks can
  declare their billing model explicitly instead of resolve_tracking
  guessing from provider name.
- resolve_tracking honors provider_cost_type first, falling back to
  heuristics only when not set.
- Remove 26 dead merge_stats(output_size=1) calls across 15 blocks.
- Replace 5 merge_stats(output_size=len(X)) calls with explicit
  provider_cost+provider_cost_type (items/characters) so the count
  is preserved through the wrapper's output_size override.
- Clean up unused NodeExecutionStats imports in 14 files.
- Add tests for block-declared provider_cost_type pathway.
2026-04-05 14:56:44 +02:00
Zamil Majdy
78b95f8a76 fix(platform): add provider_cost tracking to Exa code_context block 2026-04-05 14:44:30 +02:00
Zamil Majdy
6f0c1dfa11 fix(platform): close tracking gaps found during audit
- resolve_tracking: read `script` field for elevenlabs in addition to
  `script_input`/`text` — VideoNarrationBlock uses `script`, was
  producing tracking_amount=0 characters before.
- exa/similar.py + exa/research.py (3 blocks): extract provider_cost
  from response.cost_dollars.total via merge_stats so tracking_type
  ends up as "cost_usd" with real dollar amounts instead of
  falling through to per_run.
- Add test for script field resolution.

Audit finding: `output_size` set via merge_stats in blocks is
always overridden by the executor wrapper (manager.py:440 computes
byte count of serialized output), and `walltime` is also set by
the wrapper (manager.py:667). So the existing merge_stats(output_size=1)
calls in ~15 blocks are dead code for cost tracking purposes; they
don't hurt but don't add data either. The real tracking data sources
are: (1) input/output_token_count from LLM blocks, (2) provider_cost
from APIs that return USD, (3) input_data for per-character TTS,
(4) auto-populated walltime for wall-clock billing.
2026-04-05 14:38:24 +02:00
Zamil Majdy
5e595231da test(platform): align actions tests with string date passthrough
The actions intentionally pass raw ISO strings (cast to Date) to the
generated client to avoid Date.toString() producing non-ISO output
that FastAPI rejects. Update the tests to match this behavior rather
than expecting Date instances.
2026-04-05 12:54:42 +02:00
Zamil Majdy
7b36bed8a5 fix(platform): address autogpt-reviewer feedback on cost tracking
- cost_tracking.py + token_tracking.py: switch back to asyncio.create_task
  for true fire-and-forget on hot path, but hold strong references in a
  module-level set (with done-callback discard) so tasks can't be GC'd
  mid-flight. Addresses both the "await blocks executor" concern and the
  "task may vanish before completion" concern.
- cost_tracking.py: `> 0` checks instead of truthy for output_size/walltime
  so legitimate zero values aren't stored as NULL.
- platform_cost_routes_test.py: add explicit 403 test for non-admin JWT
  and extend 401 test to cover /logs endpoint.
- actions.ts: forward raw ISO strings to generated client instead of Date
  objects — the client calls .toString() which produces human-readable
  format that FastAPI rejects with 422. Fixes timezone filter on the
  admin dashboard.
2026-04-05 12:50:03 +02:00
Zamil Majdy
372900c141 fix(platform): address 5 self-review items on cost tracking
- cost_tracking.py: drop asyncio.create_task fire-and-forget (risked task
  GC mid-flight per Python docs); await log_platform_cost_safe directly.
  Wrap body in try/except so logging never disrupts executor.
- token_tracking.py: same create_task fix; await directly.
- platform_cost.py: document that by_provider rows are keyed on
  (provider, tracking_type) so the same provider can appear multiple times.
- PlatformCostContent.tsx: convert datetime-local (naive local time) to
  UTC ISO before URL serialization so filter windows match admin's wall
  clock regardless of backend timezone. Convert back to local for input
  display.
2026-04-05 11:55:00 +02:00
Nicholas Tindle
1a305db162 ci(frontend): add Playwright E2E coverage reporting to Codecov (#12665)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-04 00:55:09 -05:00
Zamil Majdy
7afd2b249d fix(platform): address 9 should-fix items from PR review
1. Fix route path double-nesting: /api/admin/platform-costs/{dashboard,logs}
2. Fix falsy zero suppression: pass raw token counts instead of `or None`
3. Split 546-line PlatformCostContent into SummaryCard, ProviderTable,
   UserTable, LogsTable, TrackingBadge sub-components
4. Add merge_stats accumulation tests and integration test for
   on_node_execution -> log_system_credential_cost wiring
5. Add source citations for DEFAULT_COST_PER_RUN values
6. Extract MICRODOLLARS_PER_USD constant, use in all conversion sites
7. Parallelize COUNT + SELECT in get_platform_cost_logs with asyncio.gather
8. Remove dead block_name parameter from resolve_tracking()
9. Remove unrelated store.test.ts (added by this PR, not on dev)
2026-04-03 23:14:38 +02:00
Zamil Majdy
8d22653810 fix(platform): address 4 review blockers on cost tracking
- Fire-and-forget cost logging via asyncio.create_task() instead of await
  to avoid blocking executor and copilot streaming paths on DB INSERT
- Add trackingType column to PlatformCostLog schema, migration, and INSERT;
  update dashboard/logs queries to use COALESCE(column, JSONB) for backward
  compat and index-friendly GROUP BY
- Admin auth test now explicitly mocks get_jwt_payload to raise 401 instead
  of relying on bare FastAPI app behavior
- Blocker 3 (nullable user_id) was already addressed in prior commit
2026-04-03 22:43:57 +02:00
Zamil Majdy
48a653dc63 fix(copilot): prevent duplicate side effects from double-submit and stale-cache race (#12660)
## Why

#12604 (intermediate persistence) introduced two bugs on dev:

1. **Duplicate user messages** — `set_turn_duration` calls
`invalidate_session_cache()` which deletes the Redis key. Concurrent
`get_chat_session()` calls re-populate it from DB with stale data. The
executor loads this stale cache, misses the user message, and re-appends
it.

2. **Tool outputs lost on hydration** — Intermediate flushes save
assistant messages to DB before `StreamToolInputAvailable` sets
`tool_calls` on them. Since `_save_session_to_db` is append-only (uses
`start_sequence`), the `tool_calls` update is lost — subsequent flushes
start past that index. On page refresh / SSE reconnect, tool UIs
(SetupRequirementsCard, run_block output, etc.) are invisible.

3. **Sessions stuck running** — If a tool call hangs (e.g. WebSearch
provider not responding), the stream never completes,
`mark_session_completed` never runs, and the `active_stream` flag stays
stale in Redis.

## What

- **In-place cache update** in `set_turn_duration` — replaces
`invalidate_session_cache()` with a read-modify-write that patches the
duration on the cached session, eliminating the stale-cache repopulation
window
- **tool_calls backfill** — tracks the flush watermark and assistant
message index; when `StreamToolInputAvailable` sets `tool_calls` on an
already-flushed assistant, updates the DB record directly via
`update_message_tool_calls()`
- **Improved message dedup** — `is_message_duplicate()` /
`maybe_append_user_message()` scans trailing same-role messages (current
turn) instead of only checking `messages[-1]`
- **Idle timeout** — aborts the stream with a retryable error if no
meaningful SDK message arrives for 10 minutes, preventing hung tool
calls from leaving sessions stuck

## Changes

- `copilot/db.py` — `update_message_tool_calls()`, in-place cache update
in `set_turn_duration`
- `copilot/model.py` — `is_message_duplicate()`,
`maybe_append_user_message()`
- `copilot/sdk/service.py` — flush watermark tracking, tool_calls
backfill, idle timeout
- `copilot/baseline/service.py` — use `maybe_append_user_message()`
- `copilot/model_test.py` — unit tests for dedup
- `copilot/db_test.py` — unit tests for set_turn_duration cache update

## Checklist

- [x] My PR title follows [conventional
commit](https://www.conventionalcommits.org/) format
- [x] Out-of-scope changes are less than 20% of the PR
- [x] Changes to `data/*.py` validated for user ID checks (N/A)
- [x] Protected routes updated in middleware (N/A)
2026-04-04 01:09:42 +07:00
Toran Bruce Richards
f6ddcbc6cb feat(platform): Add all 12 Z.ai GLM models via OpenRouter (#12672)
## Summary

Add Z.ai (Zhipu AI) GLM model family to the platform LLM blocks, routed
through OpenRouter. This enables users to select any of the 12 Z.ai
models across all LLM-powered blocks (AI Text Generator, AI
Conversation, AI Structured Response, AI Text Summarizer, AI List
Generator).

## Gap Analysis

All 12 Z.ai models currently available on OpenRouter's API were missing
from the AutoGPT platform:

| Model | Context Window | Max Output | Price Tier | Cost |
|-------|---------------|------------|------------|------|
| GLM 4 32B | 128K | N/A | Tier 1 | 1 |
| GLM 4.5 | 131K | 98K | Tier 2 | 2 |
| GLM 4.5 Air | 131K | 98K | Tier 1 | 1 |
| GLM 4.5 Air (Free) | 131K | 96K | Tier 1 | 1 |
| GLM 4.5V (vision) | 65K | 16K | Tier 2 | 2 |
| GLM 4.6 | 204K | 204K | Tier 1 | 1 |
| GLM 4.6V (vision) | 131K | 131K | Tier 1 | 1 |
| GLM 4.7 | 202K | 65K | Tier 1 | 1 |
| GLM 4.7 Flash | 202K | N/A | Tier 1 | 1 |
| GLM 5 | 80K | 131K | Tier 2 | 2 |
| GLM 5 Turbo | 202K | 131K | Tier 3 | 4 |
| GLM 5V Turbo (vision) | 202K | 131K | Tier 3 | 4 |

## Changes

- **`autogpt_platform/backend/backend/blocks/llm.py`**: Added 12
`LlmModel` enum entries and corresponding `MODEL_METADATA` with context
windows, max output tokens, display names, and price tiers sourced from
OpenRouter API
- **`autogpt_platform/backend/backend/data/block_cost_config.py`**:
Added `MODEL_COST` entries for all 12 models, with costs scaled to match
pricing (1 for budget, 2 for mid-range, 4 for premium)

## How it works

All Z.ai models route through the existing OpenRouter provider
(`open_router`) — no new provider or API client code needed. Users with
an OpenRouter API key can immediately select any Z.ai model from the
model dropdown in any LLM block.

## Related

- Linear: REQ-83

---------

Co-authored-by: AutoGPT CoPilot <copilot@agpt.co>
2026-04-03 15:48:33 +00:00
Zamil Majdy
b00e16b438 fix(platform): fix model_test to use Optional fields for None-skip test
Use provider_cost (Optional) and error (Optional) instead of
walltime (non-nullable float) to test __iadd__ None-skip behavior.
2026-04-03 17:25:15 +02:00
Zamil Majdy
b5acfb7855 fix: resolve merge conflict with dev in helpers.test.ts 2026-04-03 17:12:07 +02:00
Zamil Majdy
1ee0bd6619 fix(platform): use round() for microdollar conversion and add cost tracking tests
- Fix float->int truncation bug in token_tracking.py and cost_tracking.py
  where int(cost * 1_000_000) would under-count (e.g. 0.0015 -> 1499
  instead of 1500). Now uses round() for correct rounding.
- Extract _resolve_tracking and _log_system_credential_cost from
  manager.py into dedicated cost_tracking.py module for testability.
- Add unit tests for all 8+ provider branches in resolve_tracking,
  log_system_credential_cost happy/skip paths, and model conversion.
- Add NodeExecutionStats.__iadd__ regression tests for None-skip behavior.
- Add frontend component tests for PlatformCostContent (14 tests) and
  actions.ts server actions (7 tests) to improve codecov patch coverage.
2026-04-03 17:04:07 +02:00
Zamil Majdy
98f13a6e5d feat(copilot): add create -> dry-run -> fix loop to agent generation (#12578)
## Summary
- Instructs the copilot LLM to automatically dry-run agents after
creating or editing them, inspect the output for wiring/data-flow
issues, and fix iteratively before presenting the agent as ready to the
user
- Updates tool descriptions (run_agent, get_agent_building_guide),
prompting supplement, and agent generation guide with clear workflow
instructions and error pattern guidance
- Adds Tool Discovery Priority to shared tool notes (find_block ->
run_mcp_tool -> SendAuthenticatedWebRequestBlock -> manual API)
- Adds 37 tests: prompt regression tests + functional tests (tool schema
validation, Pydantic model, guide workflow ordering)
- **Frontend**: Fixes host-scoped credential UX — replaces duplicate
credentials for the same host instead of stacking them, wires up delete
functionality with confirmation modal, updates button text contextually
("Update headers" vs "Add headers")

## Test plan
- [x] All 37 `dry_run_loop_test.py` tests pass (prompt content, tool
schemas, Pydantic model, guide ordering)
- [x] Existing `tool_schema_test.py` passes (110 tests including
character budget gate)
- [x] Ruff lint and format pass
- [x] Pyright type checking passes
- [x] Frontend: `pnpm lint`, `pnpm types` pass
- [x] Manual verification: confirm copilot follows the create -> dry-run
-> fix workflow when asked to build an agent
- [x] Manual verification: confirm host-scoped credentials replace
instead of duplicate
2026-04-03 14:48:57 +00:00
Zamil Majdy
613978a611 ci: add gitleaks secret scanning to pre-commit hooks (#12649)
### Why / What / How

**Why:** We had no local pre-commit protection against accidentally
committing secrets. The existing `detect-secrets` hook only ran on
`pre-push`, which is too late — secrets are already in git history by
that point. GitHub's push protection only covers known provider patterns
and runs server-side.

**What:** Adds a 3-layer defense against secret leaks: local pre-commit
hooks (gitleaks + detect-secrets), and a CI workflow as a safety net.

**How:** 
- Moved `detect-secrets` from `pre-push` to `pre-commit` stage
- Added `gitleaks` as a second pre-commit hook (Go binary, faster and
more comprehensive rule set)
- Added `.gitleaks.toml` config with allowlists for known false
positives (test fixtures, dev docker JWTs, Firebase public keys, lock
files, docs examples)
- Added `repo-secret-scan.yml` CI workflow using `gitleaks-action` on
PRs/pushes to master/dev

### Changes 🏗️

- `.pre-commit-config.yaml`: Moved `detect-secrets` to pre-commit stage,
added baseline arg, added `gitleaks` hook
- `.gitleaks.toml`: New config with tuned allowlists for this repo's
false positives
- `.secrets.baseline`: Empty baseline for detect-secrets to track known
findings
- `.github/workflows/repo-secret-scan.yml`: New CI workflow running
gitleaks on every PR and push

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Ran `gitleaks detect --no-git` against the full repo — only `.env`
files (gitignored) remain as findings
  - [x] Verified gitleaks catches a test secret file correctly
- [x] Pre-commit hooks pass on commit (both detect-secrets and gitleaks
passed)

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)
2026-04-03 14:01:26 +00:00
Zamil Majdy
2b0e8a5a9f feat(platform): add rate-limit tiering system for CoPilot (#12581)
## Summary
- Adds a four-tier subscription system (FREE/PRO/BUSINESS/ENTERPRISE)
for CoPilot with configurable multipliers (1x/5x/20x/60x) applied on top
of the base LaunchDarkly/config limits
- Stores user tier in the database (`User.subscriptionTier` column as a
Prisma enum, defaults to PRO for beta testing) with admin API endpoints
for tier management
- Includes tier info in usage status responses and OTEL/Langfuse trace
metadata for observability

## Tier Structure
| Tier | Multiplier | Daily Tokens | Weekly Tokens | Notes |
|------|-----------|-------------|--------------|-------|
| FREE | 1x | 2.5M | 12.5M | Base tier (unused during beta) |
| PRO | 5x | 12.5M | 62.5M | Default on sign-up (beta) |
| BUSINESS | 20x | 50M | 250M | Manual upgrade for select users |
| ENTERPRISE | 60x | 150M | 750M | Highest tier, custom |

## Changes
- **`rate_limit.py`**: `SubscriptionTier` enum
(FREE/PRO/BUSINESS/ENTERPRISE), `TIER_MULTIPLIERS`, `get_user_tier()`,
`set_user_tier()`, update `get_global_rate_limits()` to apply tier
multiplier and return 3-tuple, add `tier` field to `CoPilotUsageStatus`
- **`rate_limit_admin_routes.py`**: Add `GET/POST
/admin/rate_limit/tier` endpoints, include `tier` in
`UserRateLimitResponse`
- **`routes.py`** (chat): Include tier in `/usage` endpoint response
- **`sdk/service.py`**: Send `subscription_tier` in OTEL/Langfuse trace
metadata
- **`schema.prisma`**: Add `SubscriptionTier` enum and
`subscriptionTier` column to `User` model (default: PRO)
- **`config.py`**: Update docs to reflect tier system
- **Migration**: `20260326200000_add_rate_limit_tier` — creates enum,
migrates STANDARD→PRO, adds BUSINESS, sets default to PRO

## Test plan
- [x] 72 unit tests all passing (43 rate_limit + 11 admin routes + 18
chat routes)
- [ ] Verify FREE tier users get base limits (2.5M daily, 12.5M weekly)
- [ ] Verify PRO tier users get 5x limits (12.5M daily, 62.5M weekly)
- [ ] Verify BUSINESS tier users get 20x limits (50M daily, 250M weekly)
- [ ] Verify ENTERPRISE tier users get 60x limits (150M daily, 750M
weekly)
- [ ] Verify admin can read and set user tiers via API
- [ ] Verify tier info appears in Langfuse traces
- [ ] Verify migration applies cleanly (creates enum, migrates STANDARD
users to PRO, adds BUSINESS, default PRO)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2026-04-03 13:36:01 +00:00
Zamil Majdy
08bb05141c dx: enhance pr-address skill with detailed codecov coverage guidance (#12662)
Enhanced pr-address skill codecov section with local coverage commands,
priority guide, and troubleshooting steps.
2026-04-03 13:15:46 +00:00
Zamil Majdy
4190f75b0b test: additional coverage for platform cost and token tracking 2026-04-03 15:09:54 +02:00
Zamil Majdy
71315aa982 fix(backend): use actual provider in persist_and_record_usage cost logging
The provider field was hardcoded to "open_router" for all PlatformCostLog
entries, even when the SDK (Anthropic) path was the caller. Add a provider
parameter that defaults to "open_router" for backward compatibility and
pass "anthropic" from the SDK service layer.
2026-04-03 14:37:47 +02:00
Nicholas Tindle
3ccaa5e103 ci(frontend): make frontend coverage checks informational (non-blocking) (#12663)
### Why / What / How

**Why:** Frontend test coverage is still ramping up. The default
component status checks (project + patch at 80%) would block merges for
insufficient coverage on frontend changes, which isn't practical yet.

**What:** Override the platform-frontend component's coverage statuses
to be `informational: true`, so they report but don't block merges.

**How:** Added explicit `statuses` to the `platform-frontend` component
in `codecov.yml` with `informational: true` on both project and patch
checks, overriding the `default_rules`.

### Changes 🏗️

- **`codecov.yml`**: Added `informational: true` to platform-frontend
component's project and patch status checks

### Checklist 📋

#### For code changes:
- [ ] I have clearly listed my changes in the PR description
- [ ] I have made a test plan
- [ ] I have tested my changes according to the test plan:
- [ ] Verify Codecov frontend status checks show as informational
(non-blocking) on PRs touching frontend code

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Low Risk**
> Low risk: Codecov configuration-only change that affects merge gating
for frontend coverage statuses but does not alter runtime code.
> 
> **Overview**
> Updates `codecov.yml` to override the `platform-frontend` component’s
coverage `statuses` so both **project** and **patch** checks are marked
`informational: true` (non-blocking), while leaving the default
component coverage rules unchanged for other components.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
f8e8426a31. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-03 12:22:05 +00:00
Zamil Majdy
960f893295 test(platform): add unit tests for platform cost helpers and data layer
Extract pure helper functions (formatMicrodollars, formatTokens,
formatDuration, estimateCostForRow, trackingValue, toDateOrUndefined)
from PlatformCostContent.tsx into helpers.ts for testability. Add 26
vitest cases covering all formatting and cost-estimation branches.

Add backend tests for _build_where and _json_or_none in
platform_cost.py (11 pytest cases covering filter combinations).
2026-04-03 14:15:28 +02:00
Krzysztof Czerwinski
09e42041ce fix(frontend): AutoPilot notification follow-ups — branding, UX, persistence, and cross-tab sync (#12428)
AutoPilot (copilot) notifications had several follow-up issues after
initial implementation: old "Otto" branding, UX quirks, a service-worker
crash, notification state that didn't persist or sync across tabs, a
broken notification sound, and noisy Sentry alerts from SSR.

### Changes 🏗️

- **Rename "Otto" → "AutoPilot"** in all notification surfaces: browser
notifications, document title badge, permission dialog copy, and
notification banner copy
- **Agent Activity icon**: changed from `Bell` to `Pulse` (Phosphor) in
the navbar dropdown
- **Centered dialog buttons**: the "Stay in the loop" permission dialog
buttons are now centered instead of right-aligned
- **Service worker notification fix**: wrapped `new Notification()` in
try-catch so it degrades gracefully in service worker / PWA contexts
instead of throwing `TypeError: Illegal constructor`
- **Persist notification state**: `completedSessionIDs` is now stored in
localStorage (`copilot-completed-sessions`) so it survives page
refreshes and new tabs
- **Cross-tab sync**: a `storage` event listener keeps
`completedSessionIDs` and `document.title` in sync across all open tabs
— clearing a notification in one tab clears it everywhere
- **Fix notification sound**: corrected the sound file path from
`/sounds/notification.mp3` to `/notification.mp3` and added a
`.gitignore` exception (root `.gitignore` has a blanket `*.mp3` ignore
rule from legacy AutoGPT agent days)
- **Fix SSR Sentry noise**: guarded the Copilot Zustand store
initialization with a client-side check so `storage.get()` is never
called during SSR, eliminating spurious Sentry alerts (BUILDER-7CB, 7CC,
7C7) while keeping the Sentry reporting in `local-storage.ts` intact for
genuinely unexpected SSR access

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Verify "AutoPilot" appears (not "Otto") in browser notification,
document title, permission dialog, and banner
  - [x] Verify Pulse icon in navbar Agent Activity dropdown
  - [x] Verify "Stay in the loop" dialog buttons are centered
- [x] Open two tabs on copilot → trigger completion → both tabs show
badge/checkmark
  - [x] Click completed session in tab 1 → badge clears in both tabs
  - [x] Refresh a tab → completed session state is preserved
  - [x] Verify notification sound plays on completion
  - [x] Verify no Sentry alerts from SSR localStorage access

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-03 11:44:22 +00:00
Zamil Majdy
759effab60 test(frontend): add unit tests for onboarding store and GenericTool helpers
Improve frontend patch coverage with comprehensive tests for the
onboarding wizard zustand store and GenericTool helper functions.
2026-04-03 13:25:52 +02:00
Zamil Majdy
a50e95f210 feat(backend/copilot): add include_graph option to find_library_agent (#12622)
## Why

The copilot's `edit_agent` tool requires the LLM to provide a complete
agent JSON (all nodes + links), but the LLM had **no way to see the
current graph structure** before editing. It was editing blindly —
guessing/hallucinating the entire node+link structure and replacing the
graph wholesale.

## What

- Add `include_graph` boolean parameter (default `false`) to the
existing `find_library_agent` tool
- When `true`, each returned `AgentInfo` includes a `graph` field with
the full graph JSON (nodes, links, `input_default` values)
- Update the agent generation guide to instruct the LLM to always fetch
the current graph before editing

## How

- Added `graph: dict[str, Any] | None` field to `AgentInfo` model
- Added `_enrich_agents_with_graph()` helper in `agent_search.py` that
calls the existing `get_agent_as_json()` utility to fetch full graph
data
- Threaded `include_graph` parameter through `find_library_agent` →
`search_agents` → `_search_library`
- Updated `agent_generation_guide.md` to add an "if editing" step that
fetches the graph first

No new tools introduced — reuses existing `find_library_agent` with one
optional flag.

## Test plan

- [x] Unit tests: 2 new tests added
(`test_include_graph_fetches_nodes_and_links`,
`test_include_graph_false_does_not_fetch`)
- [x] All 7 `agent_search_test.py` tests pass
- [x] All pre-commit hooks pass (lint, format, typecheck)
- [ ] Verify copilot correctly uses `include_graph=true` before editing
an agent (manual test)
2026-04-03 11:20:57 +00:00
Zamil Majdy
92b395d82a fix(backend): use OpenRouter client for simulator to support non-OpenAI models (#12656)
## Why

Dry-run block simulation is failing in production with `404 - model
gemini-2.5-flash does not exist`. The simulator's default model
(`google/gemini-2.5-flash`) is a non-OpenAI model that requires
OpenRouter routing, but the shared `get_openai_client()` prefers the
direct OpenAI key, creating a client that can't handle non-OpenAI
models. The old code also stripped the provider prefix, sending
`gemini-2.5-flash` to OpenAI's API.

## What

- Added `prefer_openrouter` keyword parameter to `get_openai_client()` —
when True, prefers the OpenRouter key (returns None if unavailable,
rather than falling back to an incompatible direct OpenAI client)
- Simulator now calls `get_openai_client(prefer_openrouter=True)` so
`google/gemini-2.5-flash` routes correctly through OpenRouter
- Removed the redundant `SIMULATION_MODEL` env var override and the
now-unnecessary provider prefix stripping from `_simulator_model()`

## How

`get_openai_client()` is decorated with `@cached(ttl_seconds=3600)`
which keys by args, so `get_openai_client()` and
`get_openai_client(prefer_openrouter=True)` are cached independently.
When `prefer_openrouter=True` and no OpenRouter key exists, returns
`None` instead of falling back — the simulator already handles `None`
with a clear error message.

### Checklist
- [x] All 24 dry-run tests pass
- [x] Test asserts `get_openai_client` is called with
`prefer_openrouter=True`
- [x] Format, lint, and pyright pass
- [x] No changes to user-facing APIs
- [ ] Deploy to staging and verify simulation works

---------

Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2026-04-03 11:19:09 +00:00
Zamil Majdy
45b6ada739 Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into codex/platform-cost-tracking 2026-04-03 13:07:06 +02:00
Ubbe
86abfbd394 feat(frontend): redesign onboarding wizard with Autopilot-first flow (#12640)
### Why / What / How

<img width="800" height="827" alt="Screenshot 2026-04-02 at 15 40 24"
src="https://github.com/user-attachments/assets/69a381c1-2884-434b-9406-4a3f7eec87cf"
/>
<img width="800" height="825" alt="Screenshot 2026-04-02 at 15 40 41"
src="https://github.com/user-attachments/assets/c6191a68-a8ba-482b-ba47-c06c71d69f0c"
/>
<img width="800" height="825" alt="Screenshot 2026-04-02 at 15 40 48"
src="https://github.com/user-attachments/assets/31b632b9-59cb-4bf7-a6a0-6158846fcf9a"
/>
<img width="800" height="812" alt="Screenshot 2026-04-02 at 15 40 54"
src="https://github.com/user-attachments/assets/64e38a15-2e56-4c0e-bd84-987bf6076bf7"
/>



**Why:** The existing onboarding flow was outdated and didn't align with
the new Autopilot-first experience. New users need a streamlined,
visually polished wizard that collects their role and pain points to
personalize Autopilot suggestions.

**What:** Complete redesign of the onboarding wizard as a 4-step flow:
Welcome → Role selection → Pain points → Preparing workspace. Uses the
design system throughout (atoms/molecules), adds animations, and syncs
steps with URL search params.

**How:** 
- Zustand store manages wizard state (name, role, pain points, current
step)
- Steps synced to `?step=N` URL params for browser navigation support
- Pain points reordered based on selected role (e.g. Sales sees "Finding
leads" first)
- Design system components used exclusively (no raw shadcn `ui/`
imports)
- New reusable components: `FadeIn` (atom), `TypingText` (molecule) with
Storybook stories
- `AutoGPTLogo` made sizeable via Tailwind className prop, migrated in
Navbar
- Fixed `SetupAnalytics` crash (client component was rendered inside
`<head>`)

### Changes 🏗️

- **New onboarding wizard** (`steps/WelcomeStep`, `RoleStep`,
`PainPointsStep`, `PreparingStep`)
- **New shared components**: `ProgressBar`, `StepIndicator`,
`SelectableCard`, `CardCarousel`
- **New design system components**: `FadeIn` atom with stories,
`TypingText` molecule with stories
- **`AutoGPTLogo`** — size now controlled via `className` prop instead
of numeric `size`
- **Navbar** — migrated from legacy `IconAutoGPTLogo` to design system
`AutoGPTLogo`
- **Layout fix** — moved `SetupAnalytics` from `<head>` to `<body>` to
fix React hydration crash
- **Role-based pain point ordering** — top picks surfaced first based on
role selection
- **URL-synced steps** — `?step=N` search params for back/forward
navigation
- Removed old onboarding pages (1-welcome through 6-congrats, reset
page)
- Emoji/image assets for role selection cards

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Complete onboarding flow from step 1 through 4 as a new user
  - [x] Verify back button navigates to previous step
  - [x] Verify progress bar advances correctly (hidden on step 4)
  - [x] Verify step indicator dots show for steps 1-3
  - [x] Verify role selection reorders pain points on next step
  - [x] Verify "Other" role/pain point shows text input
  - [x] Verify typing animation on PreparingStep title
  - [x] Verify fade-in animations on all steps
  - [x] Verify URL updates with `?step=N` on navigation
  - [x] Verify browser back/forward works with step URLs
  - [x] Verify mobile horizontal scroll on card grids
  - [x] Verify `pnpm types` passes cleanly

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-03 18:06:57 +07:00
Nicholas Tindle
a7f4093424 ci(platform): set up Codecov coverage reporting across platform and classic (#12655)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-03 03:48:30 -05:00
Nicholas Tindle
e33b1e2105 feat(classic): update classic autogpt a bit to make it more useful for my day to day (#11797)
## Summary

This PR modernizes AutoGPT Classic to make it more useful for day-to-day
autonomous agent development. Major changes include consolidating the
project structure, adding new prompt strategies, modernizing the
benchmark system, and improving the development experience.

**Note: AutoGPT Classic is an experimental, unsupported project
preserved for educational/historical purposes. Dependencies will not be
actively updated.**

## Changes 🏗️

### Project Structure & Build System
- **Consolidated Poetry projects** - Merged `forge/`,
`original_autogpt/`, and benchmark packages into a single
`pyproject.toml` at `classic/` root
- **Removed old benchmark infrastructure** - Deleted the complex
`agbenchmark` package (3000+ lines) in favor of the new
`direct_benchmark` harness
- **Removed frontend** - Deleted `benchmark/frontend/` React app (no
longer needed)
- **Cleaned up CI workflows** - Simplified GitHub Actions workflows for
the consolidated project structure
- **Added CLAUDE.md** - Documentation for working with the codebase
using Claude Code

### New Direct Benchmark System
- **`direct_benchmark` harness** - New streamlined benchmark runner
with:
  - Rich TUI with multi-panel layout showing parallel test execution
  - Incremental resume and selective reset capabilities
  - CI mode for non-interactive environments
  - Step-level logging with colored prefixes
  - "Would have passed" tracking for timed-out challenges
  - Copy-paste completion blocks for sharing results

### Multiple Prompt Strategies
Added pluggable prompt strategy system supporting:
- **one_shot** - Single-prompt completion
- **plan_execute** - Plan first, then execute steps
- **rewoo** - Reasoning without observation (deferred tool execution)
- **react** - Reason + Act iterative loop
- **lats** - Language Agent Tree Search (MCTS-based exploration)
- **sub_agent** - Multi-agent delegation architecture
- **debate** - Multi-agent debate for consensus

### LLM Provider Improvements
- Added support for modern **Anthropic Claude models**
(claude-3.5-sonnet, claude-3-haiku, etc.)
- Added **Groq** provider support
- Improved tool call error feedback for LLM self-correction
- Fixed deprecated API usage

### Web Components
- **Replaced Selenium with Playwright** for web browsing (better async
support, faster)
- Added **lightweight web fetch component** for simple URL fetching
- **Modernized web search** with tiered provider system (Tavily, Serper,
Google)

### Agent Capabilities
- **Workspace permissions system** - Pattern-based allow/deny lists for
agent commands
- **Rich interactive selector** for command approval with scopes
(once/agent/workspace/deny)
- **TodoComponent** with LLM-powered task decomposition
- **Platform blocks integration** - Connect to AutoGPT Platform API for
additional blocks
- **Sub-agent architecture** - Agents can spawn and coordinate
sub-agents

### Developer Experience
- **Python 3.12+ support** with CI testing on 3.12, 3.13, 3.14
- **Current working directory as default workspace** - Run `autogpt`
from any project directory
- Simplified log format (removed timestamps)
- Improved configuration and setup flow
- External benchmark adapters for GAIA, SWE-bench, and AgentBench

### Bug Fixes
- Fixed N/A command loop when using native tool calling
- Fixed auto-advance plan steps in Plan-Execute strategy
- Fixed approve+feedback to execute command then send feedback
- Fixed parallel tool calls in action history
- Always recreate Docker containers for code execution
- Various pyright type errors resolved
- Linting and formatting issues fixed across codebase

## Test Plan

- [x] CI lint, type, and test checks pass
- [x] Run `poetry install` from `classic/` directory
- [x] Run `poetry run autogpt` and verify CLI starts
- [x] Run `poetry run direct-benchmark run --tests ReadFile` to verify
benchmark works

## Notes

- This is a WIP PR for personal use improvements
- The project is marked as **unsupported** - no active maintenance
planned
- Contains known vulnerabilities in dependencies (intentionally not
updated)

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> CI/build workflows are substantially reworked (runner matrix removal,
path/layout changes, new benchmark runner), so breakage is most likely
in automation and packaging rather than runtime behavior.
> 
> **Overview**
> **Modernizes the `classic/` project layout and automation around a
single consolidated Poetry project** (root
`classic/pyproject.toml`/`poetry.lock`) and updates docs
(`classic/README.md`, new `classic/CLAUDE.md`) accordingly.
> 
> **Replaces the old `agbenchmark` CI usage with `direct-benchmark` in
GitHub Actions**, including new/updated benchmark smoke and regression
workflows, standardized `working-directory: classic`, and a move to
**Python 3.12** on Ubuntu-only runners (plus updated caching, coverage
flags, and required `ANTHROPIC_API_KEY` wiring).
> 
> Cleans up repo/dev tooling by removing the classic frontend workflow,
deleting the Forge VCR cassette submodule (`.gitmodules`) and associated
CI steps, consolidating `flake8`/`isort`/`pyright` pre-commit hooks to
run from `classic/`, updating ignores for new report/workspace
artifacts, and updating `classic/Dockerfile.autogpt` to build from
Python 3.12 with the consolidated project structure.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
de67834dac. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-04-03 07:16:36 +00:00
Zamil Majdy
fff101e037 feat(backend): add SQL query block with multi-database support for CoPilot analytics (#12569)
## Summary
- Add a read-only SQL query block for CoPilot/AutoPilot analytics access
- Supports **multiple databases**: PostgreSQL, MySQL, SQLite, MSSQL via
SQLAlchemy
- Enforces read-only queries (SELECT only) with defense-in-depth SQL
validation using sqlparse
- SSRF protection: blocks connections to private/internal IPs
- Credentials stored securely via the platform credential system

## Changes
- New `SQLQueryBlock` in `backend/blocks/sql_query_block.py` with
`DatabaseType` enum
- SQLAlchemy-based execution with dialect-specific read-only and timeout
settings
- Connection URL validation ensuring driver matches selected database
type
- Comprehensive test suite (62 tests) including URL validation,
sanitization, serialization
- Documentation in `docs/integrations/block-integrations/data.md`
- Added `DATABASE` provider to `ProviderName` enum

### Checklist 📋
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan

#### Test plan:
- [x] Unit tests pass for query validation, URL validation, error
sanitization, value serialization
- [x] Read-only enforcement rejects INSERT/UPDATE/DELETE/DROP
- [x] Multi-statement injection blocked
- [x] SSRF protection blocks private IPs
- [x] Connection URL driver validation works for all 4 database types

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-03 06:43:40 +00:00
Zamil Majdy
da544d3411 fix(platform): fix CI - regenerate API schema + fix Date type mismatch
- Regenerate openapi.json using export-api-schema command (CI-compatible)
- Convert string date params to Date objects before passing to generated
  API functions (orval generates Date | null for datetime fields)
- pnpm types passes cleanly
2026-04-02 20:56:08 +02:00
Zamil Majdy
54e5059d7c fix(platform): use generated Pagination type + estimate token costs
- Import Pagination from generated client instead of hand-written types
- Add DEFAULT_COST_PER_1K_TOKENS for OpenAI/Anthropic/Groq/Ollama
- estimateCostForRow now computes cost from token count when provider
  doesn't report actual USD (tokens * rate_per_1k / 1000)
- Added date comment for when default rates were checked
2026-04-02 20:43:23 +02:00
Zamil Majdy
1d7d2f77f3 feat(platform): tracking-aware dashboard with generated API client
Backend:
- ProviderCostSummary now includes tracking_type and total_duration_seconds
- CostLogRow includes tracking_type and duration
- SQL queries extract tracking_type from metadata JSON

Frontend:
- Replaced hand-written types/client with generated API client (orval)
- Actions use getV2GetPlatformCostDashboard/getV2GetPlatformCostLogs
- Provider table shows: Type badge, Usage metric, Known Cost, Estimated Cost
- Per-run providers have editable $/run input with defaults
- Summary cards show "Known Cost" vs "Estimated Total"
- Logs table shows tracking_type badge + duration column
- Color-coded badges: cost_usd(green), tokens(blue), duration(orange),
  characters(purple), per_run(gray)
2026-04-02 20:27:58 +02:00
Zamil Majdy
567bc73ec4 fix(blocks): regenerate block docs after merge with dev 2026-04-02 19:32:15 +02:00
Zamil Majdy
61ef54af05 Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into codex/platform-cost-tracking 2026-04-02 19:28:37 +02:00
Zamil Majdy
405403e6b7 fix(backend): initialize response before try block to satisfy pyright 2026-04-02 19:21:22 +02:00
Zamil Majdy
f1ac05b2e0 fix(backend): propagate dry-run mode to special blocks with LLM-powered simulation (#12575)
## Summary
- **OrchestratorBlock & AgentExecutorBlock** now execute for real in
dry-run mode so the orchestrator can make LLM calls and agent executors
can spawn child graphs. Their downstream tool blocks and child-graph
blocks are still simulated via `simulate_block()`. Credential fields
from node defaults are restored since `validate_exec()` wipes them in
dry-run mode. Agent-mode iterations capped at 1 in dry-run.
- **All blocks** (including MCPToolBlock) are simulated via a single
generic `simulate_block()` path. The LLM prompt is grounded by
`inspect.getsource(block.run)`, giving the simulator access to the exact
implementation of each block's `run()` method. This produces realistic
mock responses for any block type without needing block-specific
simulation logic.
- Updated agent generation guide to document special block dry-run
behavior.
- Minor frontend fixes: exported `formatCents` from
`RateLimitResetDialog` for reuse in `UsagePanelContent`, used `useRef`
for stable callback references in `useResetRateLimit` to avoid stale
closures.
- 74 tests (21 existing dry-run + 53 new simulator tests covering prompt
building, passthrough logic, and special block dry-run).

## Design

The simulator (`backend/executor/simulator.py`) uses a two-tier
approach:

1. **Passthrough blocks** (OrchestratorBlock, AgentExecutorBlock):
`prepare_dry_run()` returns modified input_data so these blocks execute
for real in `manager.py`. OrchestratorBlock gets `max_iterations=1`
(agent mode) or 0 (traditional mode). AgentExecutorBlock spawns real
child graph executions whose blocks inherit `dry_run=True`.

2. **All other blocks**: `simulate_block()` builds an LLM prompt
containing:
   - Block name and description
   - Input/output schemas (JSON Schema)
   - The block's `run()` source code via `inspect.getsource(block.run)`
- The actual input values (with credentials stripped and long values
truncated)

The LLM then role-plays the block's execution, producing realistic
outputs grounded in the actual implementation.

Special handling for input/output blocks: `AgentInputBlock` and
`AgentOutputBlock` are pure passthrough (no LLM call needed).

## Test plan
- [x] All 74 tests pass (`pytest backend/copilot/tools/test_dry_run.py
backend/executor/simulator_test.py`)
- [x] Pre-commit hooks pass (ruff, isort, black, pyright, frontend
typecheck)
- [x] CI: all checks green
- [x] E2E: dry-run execution completes with `is_dry_run=true`, cost=0,
no errors
- [x] E2E: normal (non-dry-run) execution unchanged
- [x] E2E: Create agent with OrchestratorBlock + tool blocks, run with
`dry_run=True`, verify orchestrator makes real LLM calls while tool
blocks are simulated
- [x] E2E: AgentExecutorBlock spawns child graph in dry-run, child
blocks are LLM-simulated
- [x] E2E: Builder simulate button works end-to-end with special blocks

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-02 17:09:55 +00:00
Zamil Majdy
f115607779 fix(copilot): recognize Agent tool name and route CLI state into workspace (#12635)
### Why / What / How

**Why:** The Claude Agent SDK CLI renamed the sub-agent tool from
`"Task"` to `"Agent"` in v2.x. Our security hooks only checked for
`"Task"`, so all sub-agent security controls were silently bypassed on
production: concurrency limiting didn't apply, and slot tracking was
broken. This was discovered via Langfuse trace analysis of session
`62b1b2b9` where background sub-agents ran unchecked.

Additionally, the CLI writes sub-agent output to `/tmp/claude-<uid>/`
and project state to `$HOME/.claude/` — both outside the per-session
workspace (`/tmp/copilot-<session>/`). This caused `PermissionError` in
E2B sandboxes and silently lost sub-agent results.

The frontend also had no rendering for the `Agent` / `TaskOutput` SDK
built-in tools — they fell through to the generic "other" category with
no context-aware display.

**What:**
1. Fix the sub-agent tool name recognition (`"Task"` → `{"Task",
"Agent"}`)
2. Allow `run_in_background` — the SDK handles async lifecycle cleanly
(returns `isAsync:true`, model polls via `TaskOutput`)
3. Route CLI state into the workspace via `CLAUDE_CODE_TMPDIR` and
`HOME` env vars
4. Add lifecycle hooks (`SubagentStart`/`SubagentStop`) for
observability
5. Add frontend `"agent"` tool category with proper UI rendering

**How:**
- Security hooks check `tool_name in _SUBAGENT_TOOLS` (frozenset of
`"Task"` and `"Agent"`)
- Background agents are allowed but still count against `max_subtasks`
concurrency limit
- Frontend detects `isAsync: true` output → shows "Agent started
(background)" not "Agent completed"
- `TaskOutput` tool shows retrieval status and collected results
- Robot icon and agent-specific accordion rendering for both foreground
and background agents

### Changes 🏗️

**Backend:**
- **`security_hooks.py`**: Replace `tool_name == "Task"` with `tool_name
in _SUBAGENT_TOOLS`. Remove `run_in_background` deny block (SDK handles
async lifecycle). Add `SubagentStart`/`SubagentStop` hooks.
- **`tool_adapter.py`**: Add `"Agent"` to `_SDK_BUILTIN_ALWAYS` list
alongside `"Task"`.
- **`service.py`**: Set `CLAUDE_CODE_TMPDIR=sdk_cwd` and `HOME=sdk_cwd`
in SDK subprocess env.
- **`security_hooks_test.py`**: Update background tests (allowed, not
blocked). Add test for background agents counting against concurrency
limit.

**Frontend:**
- **`GenericTool/helpers.ts`**: Add `"agent"` tool category for `Agent`,
`Task`, `TaskOutput`. Agent-specific animation text detecting `isAsync`
output. Input summaries from description/prompt fields.
- **`GenericTool/GenericTool.tsx`**: Add `RobotIcon` for agent category.
Add `getAgentAccordionData()` with async-aware title/content.
`TaskOutput` shows retrieval status.
- **`useChatSession.ts`**: Fix pre-existing TS error (void mutation
body).

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] All security hooks tests pass (background allowed + limit
enforced)
  - [x] Pre-commit hooks (ruff, black, isort, pyright, tsc) all pass
  - [x] E2E test: copilot agent create+run scenario PASS
- [ ] Deploy to dev and test copilot sub-agent spawning with background
mode

#### For configuration changes:
- [x] `.env.default` is updated or already compatible
- [x] `docker-compose.yml` is updated or already compatible
2026-04-03 00:09:19 +07:00
Zamil Majdy
1aef8b7155 fix(backend/copilot): fix tool output file reading between E2B and host (#12646)
### Why / What / How

**Why:** When copilot tools return large outputs (e.g. 3MB+ base64
images from API calls), the agent cannot process them in the E2B
sandbox. Three compounding issues prevent seamless file access:
1. The `<tool-output-truncated path="...">` tag uses a bare `path=`
attribute that the model confuses with a local filesystem path (it's
actually a workspace path)
2. `is_allowed_local_path` rejects `tool-outputs/` directories (only
`tool-results/` was allowed)
3. SDK-internal files read via the `Read` tool are not available in the
E2B sandbox for `bash_exec` processing

**What:** Fixes all three issues so that large tool outputs can be
seamlessly read and processed in both host and E2B contexts.

**How:**
- Changed `path=` → `workspace_path=` in the truncation tag to
disambiguate workspace vs filesystem paths
- Added `save_to_path` guidance in the retrieval instructions for E2B
users
- Extended `is_allowed_local_path` to accept both `tool-results` and
`tool-outputs` directories
- Added automatic bridging: when E2B is active and `Read` accesses an
SDK-internal file, the file is automatically copied to `/tmp/<filename>`
in the sandbox
- Updated system prompting to explain both SDK tool-result bridging and
workspace `<tool-output-truncated>` handling

### Changes 🏗️

- **`tools/base.py`**: `_persist_and_summarize` now uses
`workspace_path=` attribute and includes `save_to_path` example for E2B
processing
- **`context.py`**: `is_allowed_local_path` accepts both `tool-results`
and `tool-outputs` directory names
- **`sdk/e2b_file_tools.py`**: `_handle_read_file` bridges SDK-internal
files to `/tmp/` in E2B sandbox; new `_bridge_to_sandbox` helper
- **`prompting.py`**: Updated "SDK tool-result files" section and added
"Large tool outputs saved to workspace" section
- **Tests**: Added `tool-outputs` path validation tests in
`context_test.py` and `e2b_file_tools_test.py`; updated `base_test.py`
assertion for `workspace_path`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] `poetry run pytest backend/copilot/tools/base_test.py` — all 9
tests pass (persistence, truncation, binary fields)
  - [x] `poetry run format` and `poetry run lint` pass clean
  - [x] All pre-commit hooks pass
- [ ] `context_test.py`, `e2b_file_tools_test.py`,
`security_hooks_test.py` — blocked by pre-existing DB migration issue on
worktree (missing `User.subscriptionTier` column); CI will validate
these
2026-04-03 00:08:04 +07:00
Zamil Majdy
ab16e63b0a fix(platform): pass model name to copilot cost tracking
Both SDK and Baseline paths now pass config.model to
persist_and_record_usage so PlatformCostLog records the actual
model (e.g. anthropic/claude-sonnet-4) for filtering/grouping.
2026-04-02 19:03:47 +02:00
Zamil Majdy
45d3193727 fix(platform): move baseline cost extraction to finally + accumulate multi-round costs
- Move x-total-cost header extraction to finally block so cost is
  captured even when stream errors mid-way (we already paid)
- Accumulate cost across multi-round tool-calling turns instead of
  overwriting with last round only
- Handle UnboundLocalError if response was never assigned
2026-04-02 19:00:44 +02:00
Zamil Majdy
9a08011d7d fix(platform): move opentelemetry import to top-level in both copilot paths 2026-04-02 18:57:36 +02:00
Zamil Majdy
6fa66ac7da feat(platform): add cost/token OTEL attributes to both copilot paths
Both SDK and Baseline copilot paths now set OpenTelemetry span
attributes for cost tracking before the trace context closes:
- gen_ai.usage.prompt_tokens
- gen_ai.usage.completion_tokens
- gen_ai.usage.cost_usd (when available)
- gen_ai.usage.cache_read_tokens (SDK only)
- gen_ai.usage.cache_creation_tokens (SDK only)

Also extracts x-total-cost from OpenRouter response headers in the
Baseline streaming path, giving actual USD cost for both modes.

These attributes flow to Langfuse/any OTEL backend for cost dashboards.
2026-04-02 18:55:12 +02:00
Zamil Majdy
4bad08394c feat(platform): extract OpenRouter cost from baseline copilot path
The baseline copilot path uses the same OpenRouter API but wasn't
extracting the x-total-cost header. Now extracts cost from the
streaming response headers and passes it to persist_and_record_usage,
giving us actual USD cost for both copilot modes.
2026-04-02 18:39:55 +02:00
Zamil Majdy
993c43b623 feat(platform): add merge_stats to remaining blocks (FAL, Revid, D-ID, E2B, YouTube, Weather, TTS, Enrichlayer)
Every system credential block now has explicit merge_stats tracking.
No block relies on the generic fallback anymore.
2026-04-02 18:22:02 +02:00
Zamil Majdy
a8a62eeefc feat(platform): add merge_stats tracking to all system credential blocks
Every block that uses system credentials now calls merge_stats with
meaningful data after the API response:
- Google Maps: output_size = number of places returned (= detail API calls)
- Apollo people/org: output_size = results count
- Apollo person: output_size = 1 per enrichment
- SmartLead: output_size = leads added or 1 per operation
- Ideogram: output_size = 1 per image
- Replicate: output_size = 1 per prediction
- Nvidia: output_size = 1 per inference
- ScreenshotOne: output_size = 1 per screenshot
- ZeroBounce: output_size = 1 per email validated
- Mem0: output_size = 1 per memory operation
2026-04-02 18:13:15 +02:00
Zamil Majdy
173614bcc5 fix(platform): audit and fix per-provider tracking accuracy
- Fix ElevenLabs/D-ID field name: script -> script_input
- Remove incorrect Google Maps api_calls formula, use per_run instead
- Remove D-ID from generation_seconds (walltime includes polling)
- Jina embeddings: extract total_tokens from response.usage
- Simplify tracking types: cost_usd, tokens, characters,
  sandbox_seconds, walltime_seconds, per_run
2026-04-02 17:58:24 +02:00
Zamil Majdy
fbe634fb19 fix(platform): handle null user_id in cost logs and fix 0.0 cost stored as NULL
- Add null-safe optional chaining for user_id.slice() in LogsTable, displaying
  "Deleted user" when user_id is null to prevent frontend crash
- Change `if cost_float` to `if cost_float is not None` in token_tracking.py
  so that a legitimate $0.00 cost is stored as 0 instead of NULL
2026-04-02 17:38:59 +02:00
Zamil Majdy
a338c72c42 Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into codex/platform-cost-tracking 2026-04-02 17:36:14 +02:00
Zamil Majdy
7f4398efa3 feat(platform): provider-specific tracking types for accurate cost metrics
Replace one-size-fits-all tracking cascade with provider-aware logic:
- cost_usd: OpenRouter (x-total-cost header), Exa (cost_dollars)
- tokens: OpenAI, Anthropic, Groq, Ollama (token counts)
- characters: Unreal Speech, ElevenLabs (input text length)
- api_calls: Google Maps (1 nearby + N detail calls)
- sandbox_seconds: E2B (sandbox execution time)
- generation_seconds: FAL, Revid, D-ID, Replicate (video/image gen time)
- per_run: Apollo, SmartLead, ZeroBounce, Jina, etc.
2026-04-02 17:30:15 +02:00
Zamil Majdy
c2a054c511 fix(backend): prevent provider_cost loss on stats merge and widen costMicrodollars to BigInt
- NodeExecutionStats.__iadd__ was overwriting accumulated provider_cost
  with None when merging stats that lacked provider_cost (e.g. the final
  llm_call_count/llm_retry_count merge). Skip None values in __iadd__
  so existing data is never erased.
- Widen PlatformCostLog.costMicrodollars from Int (max ~$2,147) to
  BigInt to prevent theoretical overflow for high-cost aggregated
  node executions.
2026-04-02 17:28:27 +02:00
Zamil Majdy
83b00f4789 feat(platform): add copilot/autopilot cost tracking via token_tracking.py
Copilot uses OpenRouter via a separate code path (not through the block
executor). This integrates PlatformCostLog into the shared
persist_and_record_usage() function which is called by both SDK and
baseline copilot paths, capturing:
- Every LLM turn (main conversation, title gen, context compression)
- Tokens (prompt + completion + cache)
- Actual USD cost when available (SDK path provides cost_usd)
- Session ID for correlation
2026-04-02 17:17:53 +02:00
Nicholas Tindle
0da949ba42 feat(e2b): set git committer identity from user's GitHub profile (#12650)
## Summary

Sets git author/committer identity in E2B sandboxes using the user's
connected GitHub account profile, so commits are properly attributed.

## Changes

### `integration_creds.py`
- Added `get_github_user_git_identity(user_id)` that fetches the user's
name and email from the GitHub `/user` API
- Uses TTL cache (10 min) to avoid repeated API calls
- Falls back to GitHub noreply email
(`{id}+{login}@users.noreply.github.com`) when user has a private email
- Falls back to `login` if `name` is not set

### `bash_exec.py`
- After injecting integration env vars, calls
`get_github_user_git_identity()` and sets `GIT_AUTHOR_NAME`,
`GIT_AUTHOR_EMAIL`, `GIT_COMMITTER_NAME`, `GIT_COMMITTER_EMAIL`
- Only sets these if the user has a connected GitHub account

### `bash_exec_test.py`
- Added tests covering: identity set from GitHub profile, no identity
when GitHub not connected, no injection when no user_id

## Why
Previously, commits made inside E2B sandboxes had no author identity
set, leading to unattributed commits. This dynamically resolves identity
from the user's actual GitHub account rather than hardcoding a default.

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Adds outbound calls to GitHub’s `/user` API during `bash_exec` runs
and injects returned identity into the sandbox environment, which could
impact reliability (network/timeouts) and attribution behavior. Caching
mitigates repeated calls but incorrect/expired tokens or API failures
may lead to missing identity in commits.
> 
> **Overview**
> Sets git author/committer environment variables in the E2B `bash_exec`
path by fetching the connected user’s GitHub profile and injecting
`GIT_AUTHOR_*`/`GIT_COMMITTER_*` into the sandbox env.
> 
> Introduces `get_github_user_git_identity()` with TTL caching
(including a short-lived null cache), fallback to GitHub noreply email
when needed, and ensures `invalidate_user_provider_cache()` also clears
identity caches for the `github` provider. Updates tests to cover
identity injection behavior and the new cache invalidation semantics.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
955ec81efe. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: AutoGPT <autopilot@agpt.co>
2026-04-02 15:07:22 +00:00
Zamil Majdy
95524e94b3 feat(platform): add tracking_type and tracking_amount to cost log metadata
Standardize cost tracking across providers:
- cost_usd: actual dollar cost (OpenRouter, Exa)
- tokens: total token count (LLM blocks)
- duration_seconds: execution time (video gen, sandboxes)
- per_run: flat per-request (all others)
2026-04-02 17:04:50 +02:00
Zamil Majdy
2c517ff9a1 feat(platform): add per-provider cost extraction
- OpenRouter: Extract actual USD cost from x-total-cost response header
- Exa (search, contents): Write cost_dollars.total to execution_stats
- LLM blocks: Store provider_cost in stats when available
- Add provider_cost field to NodeExecutionStats
- Hook now converts provider_cost to costMicrodollars in PlatformCostLog
- Metadata includes both credit_cost and provider_cost_usd when available
2026-04-02 16:57:34 +02:00
Zamil Majdy
7020ae2189 fix(backend): handle NULL userId in platform cost models and queries
Make user_id Optional[str] in UserCostSummary and CostLogRow to handle
cases where the referenced user has been deleted. Use .get() for safe
access to user_id from query result rows. Regenerate OpenAPI schema.
2026-04-02 16:54:09 +02:00
Zamil Majdy
b9336984be fix(platform): re-add credit_cost to platform cost log metadata
Include the block's credit cost (from block_cost_config) in the log
metadata so every entry has a known cost proxy even when the provider
doesn't expose actual dollar costs.
2026-04-02 16:37:28 +02:00
Zamil Majdy
9924dedddc fix(platform): address bot review comments (sentry + coderabbit)
- CRITICAL: Use execute_raw_with_schema for INSERT (not query_raw)
- Remove accidentally committed transcripts/
- Add dry_run guard to skip cost logging for simulated executions
- Change onDelete: Cascade → SetNull to preserve cost history
- Add standalone createdAt index for date-only queries
- Add deterministic tiebreaker (id) to pagination ORDER BY
- Update migration SQL to match schema changes
2026-04-02 16:26:01 +02:00
Zamil Majdy
c054799b4f fix: regenerate API schema and block docs 2026-04-02 16:23:12 +02:00
Zamil Majdy
f3b5d584a3 fix(platform): address PR review round 5
- Replace ServerCrash icon with Receipt for Platform Costs sidebar
2026-04-02 16:02:00 +02:00
Zamil Majdy
476d9dcf80 fix(platform): address PR review round 4
- Add tests for query parameter forwarding and pagination
2026-04-02 16:00:08 +02:00
Zamil Majdy
072b623f8b fix(platform): address PR review round 3
- Remove duplicate block_usage_cost call from cost logging
- Add case-insensitive provider filter using LOWER()
- Add platform_cost_routes_test.py with basic endpoint tests
2026-04-02 15:58:00 +02:00
Zamil Majdy
26b0c95936 fix(platform): address PR review round 2
- Parallelize dashboard queries with asyncio.gather for ~3x speedup
- Move json import to top-level
- Use consistent p. table alias across all dashboard queries
2026-04-02 15:55:03 +02:00
Zamil Majdy
308357de84 fix(platform): address PR review round 1
- Parameterize LIMIT/OFFSET in SQL queries to prevent injection
- Only log platform cost on successful block execution
- Convert model enum values to strings for proper logging
- Add error handling with try/catch/finally in frontend useEffect
- Drive filter state from URL params to prevent desync
- Add dark mode support using design tokens
- Return total_users count in dashboard for accurate reporting
- Add credit_cost to metadata as cost proxy until per-token pricing
2026-04-02 15:51:28 +02:00
Zamil Majdy
1a6c50c6cc feat(platform): add platform cost tracking for system credentials
Track real API costs incurred when users consume system-managed credentials.
Captures provider, tokens, duration, and model per block execution and
surfaces an admin dashboard with provider/user aggregation and raw logs.
2026-04-02 15:42:18 +02:00
Zamil Majdy
6b031085bd feat(platform): add generic ask_question copilot tool (#12647)
### Why / What / How

**Why:** The copilot can ask clarifying questions in plain text, but
that text gets collapsed into hidden "reasoning" UI when the LLM also
calls tools in the same turn. This makes clarification questions
invisible to users. The existing `ClarificationNeededResponse` model and
`ClarificationQuestionsCard` UI component were built for this purpose
but had no tool wiring them up.

**What:** Adds a generic `ask_question` tool that produces a visible,
interactive clarification card instead of collapsible plain text. Unlike
the agent-generation-specific `clarify_agent_request` proposed in
#12601, this tool is workflow-agnostic — usable for agent building,
editing, troubleshooting, or any flow needing user input.

**How:** 
- Backend: New `AskQuestionTool` reuses existing
`ClarificationNeededResponse` model. Registered in `TOOL_REGISTRY` and
`ToolName` permissions.
- Frontend: New `AskQuestion/` renderer reuses
`ClarificationQuestionsCard` from CreateAgent. Registered in
`CUSTOM_TOOL_TYPES` (prevents collapse into reasoning) and
`MessagePartRenderer`.
- Guide: `agent_generation_guide.md` updated to reference `ask_question`
for the clarification step.

### Changes 🏗️

- **`copilot/tools/ask_question.py`** — New generic tool: takes
`question`, optional `options[]` and `keyword`, returns
`ClarificationNeededResponse`
- **`copilot/tools/__init__.py`** — Register `ask_question` in
`TOOL_REGISTRY`
- **`copilot/permissions.py`** — Add `ask_question` to `ToolName`
literal
- **`copilot/sdk/agent_generation_guide.md`** — Reference `ask_question`
tool in clarification step
- **`ChatMessagesContainer/helpers.ts`** — Add `tool-ask_question` to
`CUSTOM_TOOL_TYPES`
- **`MessagePartRenderer.tsx`** — Add switch case for
`tool-ask_question`
- **`AskQuestion/AskQuestion.tsx`** — Renderer reusing
`ClarificationQuestionsCard`
- **`AskQuestion/helpers.ts`** — Output parsing and animation text

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Backend format + pyright pass
  - [x] Frontend lint + types pass
  - [x] Pre-commit hooks pass
- [ ] Manual test: copilot uses `ask_question` and card renders visibly
(not collapsed)
2026-04-02 12:56:48 +00:00
Toran Bruce Richards
11b846dd49 fix(blocks): rename placeholder_values to options on AgentDropdownInputBlock (#12595)
## Summary

Resolves [REQ-78](https://linear.app/autogpt/issue/REQ-78): The
`placeholder_values` field on `AgentDropdownInputBlock` is misleadingly
named. In every major UI framework "placeholder" means non-binding hint
text that disappears on focus, but this field actually creates a
dropdown selector that restricts the user to only those values.

## Changes

### Core rename (`autogpt_platform/backend/backend/blocks/io.py`)
- Renamed `placeholder_values` → `options` on
`AgentDropdownInputBlock.Input`
- Added clear field description: *"If provided, renders the input as a
dropdown selector restricted to these values. Leave empty for free-text
input."*
- Updated class docstring to describe actual behavior
- Overrode `model_construct()` to remap legacy `placeholder_values` →
`options` for **backward compatibility** with existing persisted agent
JSON

### Tests (`autogpt_platform/backend/backend/blocks/test/test_block.py`)
- Updated existing tests to use canonical `options` field name
- Added 2 new backward-compat tests verifying legacy
`placeholder_values` still works through both `model_construct()` and
`Graph._generate_schema()` paths

### Documentation
- Updated
`autogpt_platform/backend/backend/copilot/sdk/agent_generation_guide.md`
— changed field name in CoPilot SDK guide
- Updated `docs/integrations/block-integrations/basic.md` — changed
field name and description in public docs

### Load tests
(`autogpt_platform/backend/load-tests/tests/api/graph-execution-test.js`)
- Removed spurious `placeholder_values: {}` from AgentInputBlock node
(this field never existed on AgentInputBlock)
- Fixed execution input to use `value` instead of `placeholder_values`

## Backward Compatibility

Existing agents with `placeholder_values` in their persisted
`input_default` JSON will continue to work — the `model_construct()`
override transparently remaps the old key to `options`. No database
migration needed since the field is stored inside a JSON blob, not as a
dedicated column.

## Testing

- All existing tests updated and passing
- 2 new backward-compat tests added
- No frontend changes needed (frontend reads `enum` from generated JSON
Schema, not the field name directly)

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-04-02 05:56:17 +00:00
Zamil Majdy
b9e29c96bd fix(backend/copilot): detect prompt-too-long in AssistantMessage content and ResultMessage success subtype (#12642)
## Why

PR #12625 fixed the prompt-too-long retry mechanism for most paths, but
two SDK-specific paths were still broken. The dev session `d2f7cba3`
kept accumulating synthetic "Prompt is too long" error entries on every
turn, growing the transcript from 2.5 MB → 3.2 MB, making recovery
impossible.

Root causes identified from production logs (`[T25]`, `[T28]`):

**Path 1 — AssistantMessage content check:**
When the Claude API rejects a prompt, the SDK surfaces it as
`AssistantMessage(error="invalid_request", content=[TextBlock("Prompt is
too long")])`. Our check only inspected `error_text = str(sdk_error)`
which is `"invalid_request"` — not a prompt-too-long pattern. The
content was then streamed out as `StreamText`, setting `events_yielded =
1`, which blocked retry even when the ResultMessage fired.

**Path 2 — ResultMessage success subtype:**
After the SDK auto-compacts internally (via `PreCompact` hook) and the
compacted transcript is _still_ too long, the SDK returns
`ResultMessage(subtype="success", result="Prompt is too long")`. Our
check only ran for `subtype="error"`. With `subtype="success"`, the
stream "completed normally", appended the synthetic error entry to the
transcript via `transcript_builder`, and uploaded it to GCS — causing
the transcript to grow on each failed turn.

## What

- **AssistantMessage handler**: when `sdk_error` is set, also check the
content text. `sdk_error` being non-`None` confirms this is an API error
message (not user-generated content), so content inspection is safe.
- **ResultMessage handler**: check `result` for prompt-too-long patterns
regardless of `subtype`, covering the SDK auto-compact path where
`subtype="success"` with `result="Prompt is too long"`.

## How

Two targeted one-line condition expansions in `_run_stream_attempt`,
plus two new integration tests in `retry_scenarios_test.py` that
reproduce each broken path and verify retry fires correctly.

## Changes

- `backend/copilot/sdk/service.py`: fix AssistantMessage content check +
ResultMessage subtype-independent check
- `backend/copilot/sdk/retry_scenarios_test.py`: add 2 integration tests
for the new scenarios

## Checklist

- [x] Tests added for both new scenarios (45 total, all pass)
- [x] Formatted (`poetry run format`)
- [x] No false-positive risk: AssistantMessage check gated behind
`sdk_error is not None`
- [x] Root cause verified from production pod logs
2026-04-01 22:32:09 +00:00
Zamil Majdy
4ac0ba570a fix(backend): fix copilot credential loading across event loops (#12628)
## Why

CoPilot autopilot sessions are inconsistently failing to load user
credentials (specifically GitHub OAuth). Some sessions proceed normally,
some show "provide credentials" prompts despite the user having valid
creds, and some are completely blocked.

Production logs confirmed the root cause: `RuntimeError: Task got Future
<Future pending> attached to a different loop` in the credential refresh
path, cascading into null-cache poisoning that blocks credential lookups
for 60 seconds.

## What

Three interrelated bugs in the credential system:

1. **`refresh_if_needed` always acquired Redis locks even with
`lock=False`** — The `lock` parameter only controlled the inner
credential lock, but the outer "refresh" scope lock was always acquired.
The copilot executor uses multiple worker threads with separate event
loops; the `asyncio.Lock` inside `AsyncRedisKeyedMutex` was bound to one
loop and failed on others.

2. **Stale event loop in `locks()` singleton** — Both
`IntegrationCredentialsManager` and `IntegrationCredentialsStore` cached
their `AsyncRedisKeyedMutex` without tracking which event loop created
it. When a different worker thread (with a different loop) reused the
singleton, it got the "Future attached to different loop" error.

3. **Null-cache poisoning on refresh failure** — When OAuth refresh
failed (due to the event loop error), the code fell through to cache "no
credentials found" for 60 seconds via `_null_cache`. This blocked ALL
subsequent credential lookups for that user+provider, even though the
credentials existed and could refresh fine on retry.

## How

- Split `refresh_if_needed` into `_refresh_locked` / `_refresh_unlocked`
so `lock=False` truly skips ALL Redis locking (safe for copilot's
best-effort background injection)
- Added event loop tracking to `locks()` in both
`IntegrationCredentialsManager` and `IntegrationCredentialsStore` —
recreates the mutex when the running loop changes
- Only populate `_null_cache` when the user genuinely has no
credentials; skip caching when OAuth refresh failed transiently
- Updated existing test to verify null-cache is not poisoned on refresh
failure

## Test plan

- [x] All 14 existing `integration_creds_test.py` tests pass
- [x] Updated
`test_oauth2_refresh_failure_returns_none_without_null_cache` verifies
null-cache is not populated on refresh failure
- [x] Format, lint, and typecheck pass
- [ ] Deploy to staging and verify copilot sessions consistently load
GitHub credentials
2026-04-02 00:11:38 +07:00
Zamil Majdy
d61a2c6cd0 Revert "fix(backend/copilot): detect prompt-too-long in AssistantMessage content and ResultMessage success subtype"
This reverts commit 1c301b4b61.
2026-04-01 18:59:38 +02:00
Zamil Majdy
1c301b4b61 fix(backend/copilot): detect prompt-too-long in AssistantMessage content and ResultMessage success subtype
The SDK returns AssistantMessage(error="invalid_request", content=[TextBlock("Prompt is too long")])
followed by ResultMessage(subtype="success", result="Prompt is too long") when the transcript is
rejected after internal auto-compaction. Both paths bypassed the retry mechanism:

- AssistantMessage handler only checked error_text ("invalid_request"), not the content which
  holds the actual error description. The content was then streamed as text, setting events_yielded=1,
  which blocked retry even when ResultMessage fired.
- ResultMessage handler only triggered prompt-too-long detection for subtype="error", not
  subtype="success". The stream "completed normally", stored the synthetic error entry in the
  transcript, and uploaded it — causing the transcript to grow unboundedly on each failed turn.

Fixes:
1. AssistantMessage handler: when sdk_error is set (confirmed error message), also check content
   text. sdk_error being set guarantees this is an API error, not user-generated content, so
   content inspection is safe.
2. ResultMessage handler: check result for prompt-too-long regardless of subtype, covering the
   case where the SDK auto-compacts internally but the result is still too long.

Adds integration tests for both new scenarios.
2026-04-01 18:28:46 +02:00
Zamil Majdy
24d0c35ed3 fix(backend/copilot): prompt-too-long retry, compaction churn, model-aware compression, and truncated tool call recovery (#12625)
## Why

CoPilot has several context management issues that degrade long
sessions:
1. "Prompt is too long" errors crash the session instead of triggering
retry/compaction
2. Stale thinking blocks bloat transcripts, causing unnecessary
compaction every turn
3. Compression target is hardcoded regardless of model context window
size
4. Truncated tool calls (empty `{}` args from max_tokens) kill the
session instead of guiding the model to self-correct

## What

**Fix 1: Prompt-too-long retry bypass (SENTRY-1207)**
The SDK surfaces "prompt too long" via `AssistantMessage.error` and
`ResultMessage.result` — neither triggered the retry/compaction loop
(only Python exceptions did). Now both paths are intercepted and
re-raised.

**Fix 2: Strip stale thinking blocks before upload**
Thinking/redacted_thinking blocks in non-last assistant entries are
10-50K tokens each but only needed for API signature verification in the
*last* message. Stripping before upload reduces transcript size and
prevents per-turn compaction.

**Fix 3: Model-aware compression target**
`compress_context()` now computes `target_tokens` from the model's
context window (e.g. 140K for Opus 200K) instead of a hardcoded 120K
default. Larger models retain more history; smaller models compress more
aggressively.

**Fix 4: Self-correcting truncated tool calls**
When the model's response exceeds max_tokens, tool call inputs get
silently truncated to `{}`. Previously this tripped a circuit breaker
after 3 attempts. Now the MCP wrapper detects empty args and returns
guidance: "write in chunks with `cat >>`, pass via
`@@agptfile:filename`". The model can self-correct instead of the
session dying.

## How

- **service.py**: `_is_prompt_too_long` checks in both
`AssistantMessage.error` and `ResultMessage` error handlers. Circuit
breaker limit raised from 3→5.
- **transcript.py**: `strip_stale_thinking_blocks()` reverse-scans for
last assistant `message.id`, strips thinking blocks from all others.
Called in `upload_transcript()`.
- **prompt.py**: `get_compression_target(model)` computes
`context_window - 60K overhead`. `compress_context()` uses it when
`target_tokens` is None.
- **tool_adapter.py**: `_truncating` wrapper intercepts empty args on
tools with required params, returns actionable guidance instead of
failing.

## Related

- Fixes SENTRY-1207
- Sessions: `d2f7cba3` (repeated compaction), `08b807d4` (prompt too
long), `130d527c` (truncated tool calls)
- Extends #12413, consolidates #12626

## Test plan

- [x] 6 unit tests for `strip_stale_thinking_blocks`
- [x] 1 integration test for ResultMessage prompt-too-long → compaction
retry
- [x] Pyright clean (0 errors), all pre-commit hooks pass
- [ ] E2E: Load transcripts from affected sessions and verify behavior
2026-04-01 15:10:57 +00:00
Zamil Majdy
8aae7751dc fix(backend/copilot): prevent duplicate block execution from pre-launch arg mismatch (#12632)
## Why

CoPilot sessions are duplicating Linear tickets and GitHub PRs.
Investigation of 5 production sessions (March 31st) found that 3/5
created duplicate Linear issues — each with consecutive IDs at the exact
same timestamp, but only one visible in Langfuse traces.

Production gcloud logs confirm: **279 arg mismatch warnings per day**,
**37 duplicate block execution pairs**, and all LinearCreateIssueBlock
failures in pairs.

Related: SECRT-2204

## What

Replace the speculative pre-launch mechanism with the SDK's native
parallel dispatch via `readOnlyHint` tool annotations. Remove ~580 lines
of pre-launch infrastructure code.

## How

### Root cause
The pre-launch mechanism had three compounding bugs:
1. **Arg mismatch**: The SDK CLI normalises args between the
`AssistantMessage` (used for pre-launch) and the MCP `tools/call`
dispatch, causing frequent mismatches (279/day in prod)
2. **FIFO desync on denial**: Security hooks can deny tool calls,
causing the CLI to skip the MCP dispatch — but the pre-launched task
stays in the FIFO queue, misaligning all subsequent matches
3. **Cancel race**: `task.cancel()` is best-effort in asyncio — if the
HTTP call to Linear/GitHub already completed, the side effect is
irreversible

### Fix
- **Removed** `pre_launch_tool_call()`, `cancel_pending_tool_tasks()`,
`_tool_task_queues` ContextVar, all FIFO queue logic, and all 4
`cancel_pending_tool_tasks()` calls in `service.py`
- **Added** `readOnlyHint=True` annotations on 15+ read-only tools
(`find_block`, `search_docs`, `list_workspace_files`, etc.) — the SDK
CLI natively dispatches these in parallel ([ref:
anthropics/claude-code#14353](https://github.com/anthropics/claude-code/issues/14353))
- Side-effect tools (`run_block`, `bash_exec`, `create_agent`, etc.)
have no annotation → CLI runs them sequentially → no duplicate execution
risk

### Net change: -578 lines, +105 lines
2026-04-01 13:42:54 +00:00
An Vy Le
725da7e887 dx(backend/copilot): clarify ambiguous agent goals using find_block before generation (#12601)
### Why / What / How

**Why:** When a user asks CoPilot to build an agent with an ambiguous
goal (output format, delivery channel, data source, or trigger
unspecified), the agent generator previously made assumptions and jumped
straight into JSON generation. This produced agents that didn't match
what the user actually wanted, requiring multiple correction cycles.

**What:** Adds a "Clarifying Before Building" section to the agent
generation guide. When the goal is ambiguous, CoPilot first calls
`find_block` to discover what the platform actually supports for the
ambiguous dimension, then asks the user one concrete question grounded
in real platform options (e.g. "The platform supports Gmail, Slack, and
Google Docs — which should the agent use for delivery?"). Only after the
user answers does the full agent generation workflow proceed.

**How:** The clarification instruction is added to
`agent_generation_guide.md` — the guide loaded on-demand via
`get_agent_building_guide` when the LLM is about to build an agent. This
avoids polluting the system prompt supplement (which loads for every
CoPilot conversation, not just agent building). No dedicated tool is
needed — the LLM asks naturally in conversation text after discovering
real platform options via `find_block`.

### Changes 🏗️

- `backend/copilot/sdk/agent_generation_guide.md`: Adds "Clarifying
Before Building" section before the workflow steps. Instructs the model
to call `find_block` for the ambiguous dimension, ask the user one
grounded question, wait for the answer, then proceed to generation.
- `backend/copilot/prompting_test.py`: New test file verifying the guide
contains the clarification section and references `find_block`.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [ ] Ask CoPilot to "build an agent to send a report" (ambiguous
output) — verify it calls `find_block` for delivery options and asks one
grounded question before generating JSON
- [ ] Ask CoPilot to "build an agent to scrape prices from Amazon and
email me daily" (specific goal) — verify it skips clarification and
proceeds directly to agent generation
- [ ] Verify the clarification question lists real block options (e.g.
Gmail, Slack, Google Docs) rather than abstract options

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-04-01 13:32:12 +00:00
seer-by-sentry[bot]
bd9e9ec614 fix(frontend): remove LaunchDarkly local storage bootstrapping (#12606)
### Why / What / How

<!-- Why: Why does this PR exist? What problem does it solve, or what's
broken/missing without it? -->
This PR fixes
[BUILDER-7HD](https://sentry.io/organizations/significant-gravitas/issues/7374387984/).
The issue was that: LaunchDarkly SDK fails to construct streaming URL
due to non-string `_url` from malformed `localStorage` bootstrap data.
<!-- What: What does this PR change? Summarize the changes at a high
level. -->
Removed the `bootstrap: "localStorage"` option from the LaunchDarkly
provider configuration.
<!-- How: How does it work? Describe the approach, key implementation
details, or architecture decisions. -->
This change ensures that LaunchDarkly no longer attempts to load initial
feature flag values from local storage. Flag values will now always be
fetched directly from the LaunchDarkly service, preventing potential
issues with stale local storage data.

### Changes 🏗️

<!-- List the key changes. Keep it higher level than the diff but
specific enough to highlight what's new/modified. -->
- Removed the `bootstrap: "localStorage"` option from the LaunchDarkly
provider configuration.
- LaunchDarkly will now always fetch flag values directly from its
service, bypassing local storage.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [ ] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
- [ ] Verify that LaunchDarkly flags are loaded correctly without
issues.
- [ ] Ensure no errors related to `localStorage` or streaming URL
construction appear in the console.

<details>
  <summary>Example test plan</summary>
  
  - [ ] Create from scratch and execute an agent with at least 3 blocks
- [ ] Import an agent from file upload, and confirm it executes
correctly
  - [ ] Upload agent to marketplace
- [ ] Import an agent from marketplace and confirm it executes correctly
  - [ ] Edit an agent from monitor, and confirm it executes correctly
</details>

#### For configuration changes:

- [ ] `.env.default` is updated or already compatible with my changes
- [ ] `docker-compose.yml` is updated or already compatible with my
changes
- [ ] I have included a list of my configuration changes in the PR
description (under **Changes**)

<details>
  <summary>Examples of configuration changes</summary>

  - Changing ports
  - Adding new services that need to communicate with each other
  - Secrets or environment variable changes
  - New or infrastructure changes such as databases
</details>

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
Co-authored-by: seer-by-sentry[bot] <157164994+seer-by-sentry[bot]@users.noreply.github.com>
2026-04-01 19:12:54 +07:00
Nicholas Tindle
88589764b5 dx(platform): normalize agent instructions for Claude and Codex (#12592)
### Why / What / How

Why: repo guidance was split between Claude-specific `CLAUDE.md` files
and Codex-specific `AGENTS.md` files, which duplicated instruction
content and made the same repository behave differently across agents.
The repo also had Claude skills under `.claude/skills` but no
Codex-visible repo skill path.

What: this PR bridges the repo's Claude skills into Codex and normalizes
shared instruction files so `AGENTS.md` becomes the canonical source
while each `CLAUDE.md` imports its sibling `AGENTS.md`.

How: add a repo-local `.agents/skills` symlink pointing to
`../.claude/skills`; move nested `CLAUDE.md` content into sibling
`AGENTS.md` files; replace each repo `CLAUDE.md` with a one-line
`@AGENTS.md` shim so Claude and Codex read the same scoped guidance
without duplicating text. The root `CLAUDE.md` now imports the root
`AGENTS.md` rather than symlinking to it.

Note: the instruction-file normalization commit was created with
`--no-verify` because the repo's frontend pre-commit `tsc` hook
currently fails on unrelated existing errors, largely missing
`autogpt_platform/frontend/src/app/api/__generated__/*` modules.

### Changes 🏗️

- Add `.agents/skills` as a repo-local symlink to `../.claude/skills` so
Codex discovers the existing Claude repo skills.
- Add a real root `CLAUDE.md` shim that imports the canonical root
`AGENTS.md`.
- Promote nested scoped instruction content into sibling `AGENTS.md`
files under `autogpt_platform/`, `autogpt_platform/backend/`,
`autogpt_platform/frontend/`, `autogpt_platform/frontend/src/tests/`,
and `docs/`.
- Replace the corresponding nested `CLAUDE.md` files with one-line
`@AGENTS.md` shims.
- Preserve the existing scoped instruction hierarchy while making the
shared content cross-compatible between Claude and Codex.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verified `.agents/skills` resolves to `../.claude/skills`
  - [x] Verified each repo `CLAUDE.md` now contains only `@AGENTS.md`
- [x] Verified the expected `AGENTS.md` files exist at the root and
nested scoped directories
- [x] Verified the branch contains only the intended agent-guidance
commits relative to `dev` and the working tree is clean

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

No runtime configuration changes are included in this PR.

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Low Risk**
> Low risk: documentation/instruction-file reshuffle plus an
`.agents/skills` pointer; no runtime code paths are modified.
> 
> **Overview**
> Unifies agent guidance so **`AGENTS.md` becomes canonical** and all
corresponding `CLAUDE.md` files become 1-line shims (`@AGENTS.md`) at
the repo root, `autogpt_platform/`, backend, frontend, frontend tests,
and `docs/`.
> 
> Adds `.agents/skills` pointing to `../.claude/skills` so non-Claude
agents discover the same shared skills/instructions, eliminating
duplicated/agent-specific guidance content.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
839483c3b6. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->
2026-04-01 09:08:51 +00:00
Zamil Majdy
c659f3b058 fix(copilot): fix dry-run simulation showing INCOMPLETE/error status (#12580)
## Summary
- **Backend**: Strip empty `error` pins from dry-run simulation outputs
that the simulator always includes (set to `""` meaning "no error").
This was causing the LLM to misinterpret successful simulations as
failures and report "INCOMPLETE" status to users
- **Backend**: Add explicit "Status: COMPLETED" to dry-run response
message to prevent LLM misinterpretation
- **Backend**: Update simulation prompt to exclude `error` from the
"MUST include" keys list, and instruct LLM to omit error unless
simulating a logical failure
- **Frontend**: Fix `isRunBlockErrorOutput()` type guard that was too
broad (`"error" in output` matched BlockOutputResponse objects, not just
ErrorResponse), causing dry-run results to be displayed as errors
- **Frontend**: Fix `parseOutput()` fallback matching to not classify
BlockOutputResponse as ErrorResponse
- **Frontend**: Filter out empty error pins from `BlockOutputCard`
display and accordion metadata output key counting
- **Frontend**: Clear stale execution results before dry-run/no-input
runs so the UI shows fresh output
- **Frontend**: Fix first-click simulate race condition by invalidating
execution details query after WebSocket subscription confirms

## Test plan
- [x] All 12 existing + 5 new dry-run tests pass (`poetry run pytest
backend/copilot/tools/test_dry_run.py -x -v`)
- [x] All 23 helpers tests pass (`poetry run pytest
backend/copilot/tools/helpers_test.py -x -v`)
- [x] All 13 run_block tests pass (`poetry run pytest
backend/copilot/tools/run_block_test.py -x -v`)
- [x] Backend linting passes (ruff check + format)
- [x] Frontend linting passes (next lint)
- [ ] Manual: trigger dry-run on a block with error output pin (e.g.
Komodo Image Generator) — should show "Simulated" status with clean
output, no misleading "error" section
- [ ] Manual: first click on Simulate button should immediately show
results (no race condition)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
2026-03-31 21:03:00 +00:00
Zamil Majdy
80581a8364 fix(copilot): add tool call circuit breakers and intermediate persistence (#12604)
## Why

CoPilot session `d2f7cba3` took **82 minutes** and cost **$20.66** for a
single user message. Root causes:
1. Redis session meta key expired after 1h, making the session invisible
to the resume endpoint — causing empty page on reload
2. Redis stream key also expired during sub-agent gaps (task_progress
events produced no chunks)
3. No intermediate persistence — session messages only saved to DB after
the entire turn completes
4. Sub-agents retried similar WebSearch queries (addressed via prompt
guidance)

## What

### Redis TTL fixes (root cause of empty session on reload)
- `publish_chunk()` now periodically refreshes **both** the session meta
key AND stream key TTL (every 60s).
- `task_progress` SDK events now emit `StreamHeartbeat` chunks, ensuring
`publish_chunk` is called even during long sub-agent gaps where no real
chunks are produced.
- Without this fix, turns exceeding the 1h `stream_ttl` lose their
"running" status and stream data, making `get_active_session()` return
False.

### Intermediate DB persistence
- Session messages flushed to DB every **30 seconds** or **10 new
messages** during the stream loop.
- Uses `asyncio.shield(upsert_chat_session())` matching the existing
`finally` block pattern.

### Orphaned message cleanup on rollback
- On stream attempt rollback, orphaned messages persisted by
intermediate flushes are now cleaned up from the DB via
`delete_messages_from_sequence`.
- Prevents stale messages from resurfacing on page reload after a failed
retry.

### Prompt guidance
- Added web search best practices to code supplement (search efficiency,
sub-agent scope separation).

### Approach: root cause fixes, not capability limits
- **No tool call caps** — artificial limits on WebSearch or total tool
calls would reduce autopilot capability without addressing why searches
were redundant.
- **Task tool remains enabled** — sub-agent delegation via Task is a
core capability. The existing `max_subtasks` concurrency guard is
sufficient.
- The real fixes (TTL refresh, persistence, prompt guidance) address the
underlying bugs and behavioral issues.

## How

### Files changed
- `stream_registry.py` — Redis meta + stream key TTL refresh in
`publish_chunk()`, module-level keepalive tracker
- `response_adapter.py` — `task_progress` SystemMessage →
StreamHeartbeat emission
- `service.py` — Intermediate DB persistence in `_run_stream_attempt`
stream loop, orphan cleanup on rollback
- `db.py` — `delete_messages_from_sequence` for rollback cleanup
- `prompting.py` — Web search best practices

### GCP log evidence
```
# Meta key expired during 82-min turn:
09:49 — GET_SESSION: active_session=False, msg_count=1  ← meta gone
10:18 — Session persisted in finally with 189 messages   ← turn completed

# T13 (1h45min) same bug reproduced live:
16:20 — task_progress events still arriving, but active_session=False

# Actual cost:
Turn usage: cache_read=347916, cache_create=212472, output=12375, cost_usd=20.66
```

### Test plan
- [x] task_progress emits StreamHeartbeat
- [x] Task background blocked, foreground allowed, slot release on
completion/failure
- [x] CI green (lint, type-check, tests, e2e, CodeQL)

---------

Co-authored-by: Zamil Majdy <majdy.zamil@gmail.com>
2026-03-31 21:01:56 +00:00
lif
3c046eb291 fix(frontend): show all agent outputs instead of only the last one (#12504)
Fixes #9175

### Changes 🏗️

The Agent Outputs panel only displayed the last execution result per
output node, discarding all prior outputs during a run.

**Root cause:** In `AgentOutputs.tsx`, the `outputs` useMemo extracted
only the last element from `nodeExecutionResults`:
```tsx
const latestResult = executionResults[executionResults.length - 1];
```

**Fix:** Changed `.map()` to `.flatMap()` over output nodes, iterating
through all `executionResults` for each node. Each execution result now
gets its own renderer lookup and metadata entry, so the panel shows
every output produced during the run.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verified TypeScript compiles without errors
- [x] Confirmed the flatMap logic correctly iterates all execution
results
  - [x] Verified existing filter for null renderers is preserved
- [x] Run an agent with multiple outputs and confirm all show in the
panel

---------

Signed-off-by: majiayu000 <1835304752@qq.com>
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-03-31 20:31:12 +00:00
Zamil Majdy
3e25488b2d feat(copilot): add session-level dry_run flag to autopilot sessions (#12582)
## Summary
- Adds a session-level `dry_run` flag that forces ALL tool calls
(`run_block`, `run_agent`) in a copilot/autopilot session to use dry-run
simulation mode
- Stores the flag in a typed `ChatSessionMetadata` JSON model on the
`ChatSession` DB row, accessed via `session.dry_run` property
- Adds `dry_run` to the AutoPilot block Input schema so graph builders
can create dry-run autopilot nodes
- Refactors multiple copilot tools from `**kwargs` to explicit
parameters for type safety

## Changes
- **Prisma schema**: Added `metadata` JSON column to `ChatSession` model
with migration
- **Python models**: Added `ChatSessionMetadata` model with `dry_run`
field, added `metadata` field to `ChatSessionInfo` and `ChatSession`,
updated `from_db()`, `new()`, and `create_chat_session()`
- **Session propagation**: `set_execution_context(user_id, session)`
called from `baseline/service.py` so tool handlers can read
session-level flags via `session.dry_run`
- **Tool enforcement**: `run_block` and `run_agent` check
`session.dry_run` and force `dry_run=True` when set; `run_agent` blocks
scheduling in dry-run sessions
- **AutoPilot block**: Added `dry_run` input field, passes it when
creating sessions
- **Chat API**: Added `CreateSessionRequest` model with `dry_run` field
to `POST /sessions` endpoint; added `metadata` to session responses
- **Frontend**: Updated `useChatSession.ts` to pass body to the create
session mutation
- **Tool refactoring**: Multiple copilot tools refactored from
`**kwargs` to explicit named parameters (agent_browser, manage_folders,
workspace_files, connect_integration, agent_output, bash_exec, etc.) for
better type safety

## Test plan
- [x] Unit tests for `ChatSession.new()` with dry_run parameter
- [x] Unit tests for `RunBlockTool` session dry_run override
- [x] Unit tests for `RunAgentTool` session dry_run override
- [x] Unit tests for session dry_run blocks scheduling
- [x] Existing dry_run tests still pass (12/12)
- [x] Existing permissions tests still pass
- [x] All pre-commit hooks pass (ruff, isort, pyright, tsc)
- [ ] Manual: Create autopilot session with `dry_run=True`, verify
run_block/run_agent calls use simulation

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-31 16:27:36 +00:00
Abhimanyu Yadav
57b17dc8e1 feat(platform): generic managed credential system with AgentMail auto-provisioning (#12537)
### Why / What / How

**Why:** We need a third credential type: **system-provided but unique
per user** (managed credentials). Currently we have system credentials
(same for all users) and user credentials (user provides their own
keys). Managed credentials bridge the gap — the platform provisions them
automatically, one per user, for integrations like AgentMail where each
user needs their own pod-scoped API key.

**What:**
- Generic **managed credential provider registry** — any integration can
register a provider that auto-provisions per-user credentials
- **AgentMail** is the first consumer: creates a pod + pod-scoped API
key using the org-level API key
- Managed credentials appear in the credential dropdown like normal API
keys but with `autogpt_managed=True` — users **cannot update or delete**
them
- **Auto-provisioning** on `GET /credentials` — lazily creates managed
credentials when users browse their credential list
- **Account deletion cleanup** utility — revokes external resources
(pods, API keys) before user deletion
- **Frontend UX** — hides the delete button for managed credentials on
the integrations page

**How:**

### Backend

**New files:**
- `backend/integrations/managed_credentials.py` —
`ManagedCredentialProvider` ABC, global registry,
`ensure_managed_credentials()` (with per-user asyncio lock +
`asyncio.gather` for concurrency), `cleanup_managed_credentials()`
- `backend/integrations/managed_providers/__init__.py` —
`register_all()` called at startup
- `backend/integrations/managed_providers/agentmail.py` —
`AgentMailManagedProvider` with `provision()` (creates pod + API key via
agentmail SDK) and `deprovision()` (deletes pod)

**Modified files:**
- `credentials_store.py` — `autogpt_managed` guards on update/delete,
`has_managed_credential()` / `add_managed_credential()` helpers
- `model.py` — `autogpt_managed: bool` + `metadata: dict` on
`_BaseCredentials`
- `router.py` — calls `ensure_managed_credentials()` in list endpoints,
removed explicit `/agentmail/connect` endpoint
- `user.py` — `cleanup_user_managed_credentials()` for account deletion
- `rest_api.py` — registers managed providers at startup
- `settings.py` — `agentmail_api_key` setting

### Frontend
- Added `autogpt_managed` to `CredentialsMetaResponse` type
- Conditionally hides delete button on integrations page for managed
credentials

### Key design decisions
- **Auto-provision in API layer, not data layer** — keeps
`get_all_creds()` side-effect-free
- **Race-safe** — per-(user, provider) asyncio lock with double-check
pattern prevents duplicate pods
- **Idempotent** — AgentMail SDK `client_id` ensures pod creation is
idempotent; `add_managed_credential()` uses upsert under Redis lock
- **Error-resilient** — provisioning failures are logged but never block
credential listing

### Changes 🏗️

| File | Action | Description |
|------|--------|-------------|
| `backend/integrations/managed_credentials.py` | NEW | ABC, registry,
ensure/cleanup |
| `backend/integrations/managed_providers/__init__.py` | NEW | Registers
all providers at startup |
| `backend/integrations/managed_providers/agentmail.py` | NEW |
AgentMail provisioning/deprovisioning |
| `backend/integrations/credentials_store.py` | MODIFY | Guards +
managed credential helpers |
| `backend/data/model.py` | MODIFY | `autogpt_managed` + `metadata`
fields |
| `backend/api/features/integrations/router.py` | MODIFY |
Auto-provision on list, removed `/agentmail/connect` |
| `backend/data/user.py` | MODIFY | Account deletion cleanup |
| `backend/api/rest_api.py` | MODIFY | Provider registration at startup
|
| `backend/util/settings.py` | MODIFY | `agentmail_api_key` setting |
| `frontend/.../integrations/page.tsx` | MODIFY | Hide delete for
managed creds |
| `frontend/.../types.ts` | MODIFY | `autogpt_managed` field |

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] 23 tests pass in `router_test.py` (9 new tests for
ensure/cleanup/auto-provisioning)
  - [x] `poetry run format && poetry run lint` — clean
  - [x] OpenAPI schema regenerated
- [x] Manual: verify managed credential appears in AgentMail block
dropdown
  - [x] Manual: verify delete button hidden for managed credentials
- [x] Manual: verify managed credential cannot be deleted via API (403)

#### For configuration changes:
- [x] `.env.default` is updated with `AGENTMAIL_API_KEY=`

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-03-31 12:56:18 +00:00
Krishna Chaitanya
a20188ae59 fix(blocks): validate non-empty input in AIConversationBlock before LLM call (#12545)
### Why / What / How

**Why:** When `AIConversationBlock` receives an empty messages list and
an empty prompt, the block blindly forwards the empty array to the
downstream LLM API, which returns a cryptic `400 Bad Request` error:
`"Invalid 'messages': empty array. Expected an array with minimum length
1."` This is confusing for users who don't understand why their agent
failed.

**What:** Add early input validation in `AIConversationBlock.run()` that
raises a clear `ValueError` when both `messages` and `prompt` are empty.
Also add three unit tests covering the validation logic.

**How:** A simple guard clause at the top of the `run` method checks `if
not input_data.messages and not input_data.prompt` before the LLM call
is made. If both are empty, a descriptive `ValueError` is raised. If
either one has content, the block proceeds normally.

### Changes

- `autogpt_platform/backend/backend/blocks/llm.py`: Add validation guard
in `AIConversationBlock.run()` to reject empty messages + empty prompt
before calling the LLM
- `autogpt_platform/backend/backend/blocks/test/test_llm.py`: Add
`TestAIConversationBlockValidation` with three tests:
- `test_empty_messages_and_empty_prompt_raises_error` — validates the
guard clause
- `test_empty_messages_with_prompt_succeeds` — ensures prompt-only usage
still works
- `test_nonempty_messages_with_empty_prompt_succeeds` — ensures
messages-only usage still works

### Checklist

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Lint passes (`ruff check`)
  - [x] Formatting passes (`ruff format`)
- [x] New unit tests validate the empty-input guard and the happy paths

Closes #11875

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-03-31 12:43:42 +00:00
2702 changed files with 84571 additions and 826447 deletions

1
.agents/skills Symbolic link
View File

@@ -0,0 +1 @@
../.claude/skills

10
.claude/settings.json Normal file
View File

@@ -0,0 +1,10 @@
{
"permissions": {
"allowedTools": [
"Read", "Grep", "Glob",
"Bash(ls:*)", "Bash(cat:*)", "Bash(grep:*)", "Bash(find:*)",
"Bash(git status:*)", "Bash(git diff:*)", "Bash(git log:*)", "Bash(git worktree:*)",
"Bash(tmux:*)", "Bash(sleep:*)", "Bash(branchlet:*)"
]
}
}

View File

@@ -95,6 +95,28 @@ Address comments **one at a time**: fix → commit → push → inline reply →
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in <commit-sha>: <description>"` |
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in <commit-sha>: <description>"` |
## Codecov coverage
Codecov patch target is **80%** on changed lines. Checks are **informational** (not blocking) but should be green.
### Running coverage locally
**Backend** (from `autogpt_platform/backend/`):
```bash
poetry run pytest -s -vv --cov=backend --cov-branch --cov-report term-missing
```
**Frontend** (from `autogpt_platform/frontend/`):
```bash
pnpm vitest run --coverage
```
### When codecov/patch fails
1. Find uncovered files: `git diff --name-only $(gh pr view --json baseRefName --jq '.baseRefName')...HEAD`
2. For each uncovered file — extract inline logic to `helpers.ts`/`helpers.py` and test those (highest ROI). Colocate tests as `*_test.py` (backend) or `__tests__/*.test.ts` (frontend).
3. Run coverage locally to verify, commit, push.
## Format and commit
After fixing, format the changed code:

View File

@@ -530,9 +530,19 @@ After showing all screenshots, output a **detailed** summary table:
# but Homebrew bash is 5.x; Linux typically has bash 5.x). If running on Bash <4, use a
# plain variable with a lookup function instead.
declare -A SCREENSHOT_EXPLANATIONS=(
["01-login-page.png"]="Shows the login page loaded successfully with SSO options visible."
["02-builder-with-block.png"]="The builder canvas displays the newly added block connected to the trigger."
# ... one entry per screenshot, using the same explanations you showed the user above
# Each explanation MUST answer three things:
# 1. FLOW: Which test scenario / user journey is this part of?
# 2. STEPS: What exact actions were taken to reach this state?
# 3. EVIDENCE: What does this screenshot prove (pass/fail/data)?
#
# Good example:
# ["03-cost-log-after-run.png"]="Flow: LLM block cost tracking. Steps: Logged in as tester@gmail.com → ran 'Cost Test Agent' → waited for COMPLETED status. Evidence: PlatformCostLog table shows 1 new row with cost_microdollars=1234 and correct user_id."
#
# Bad example (too vague — never do this):
# ["03-cost-log.png"]="Shows the cost log table."
["01-login-page.png"]="Flow: Login flow. Steps: Opened /login. Evidence: Login page renders with email/password fields and SSO options visible."
["02-builder-with-block.png"]="Flow: Block execution. Steps: Logged in → /build → added LLM block. Evidence: Builder canvas shows block connected to trigger, ready to run."
# ... one entry per screenshot using the flow/steps/evidence format above
)
TEST_RESULTS_TABLE="| 1 | Login flow | PASS | N/A | 01-login-before.png, 02-login-after.png |
@@ -547,6 +557,9 @@ Upload screenshots to the PR using the GitHub Git API (no local git operations
**This step is MANDATORY. Every test run MUST post a PR comment with screenshots. No exceptions.**
> **CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.**
> Every screenshot MUST appear as `![name](raw_url)` inline in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes.
```bash
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
REPO="Significant-Gravitas/AutoGPT"
@@ -582,12 +595,25 @@ for img in "${SCREENSHOT_FILES[@]}"; do
done
TREE_JSON+=']'
# Step 2: Create tree, commit, and branch ref
# Step 2: Create tree, commit (with parent), and branch ref
TREE_SHA=$(echo "$TREE_JSON" | jq -c '{tree: .}' | gh api "repos/${REPO}/git/trees" --input - --jq '.sha')
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
-f tree="$TREE_SHA" \
--jq '.sha')
# Resolve existing branch tip as parent (avoids orphan commits on repeat runs)
PARENT_SHA=$(gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" --jq '.object.sha' 2>/dev/null || true)
if [ -n "$PARENT_SHA" ]; then
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
-f tree="$TREE_SHA" \
-f "parents[]=$PARENT_SHA" \
--jq '.sha')
else
# First commit on this branch — no parent
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
-f tree="$TREE_SHA" \
--jq '.sha')
fi
gh api "repos/${REPO}/git/refs" \
-f ref="refs/heads/${SCREENSHOTS_BRANCH}" \
-f sha="$COMMIT_SHA" 2>/dev/null \
@@ -656,17 +682,123 @@ ${IMAGE_MARKDOWN}
${FAILED_SECTION}
INNEREOF
gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -F body=@"$COMMENT_FILE"
POSTED_BODY=$(gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -F body=@"$COMMENT_FILE" --jq '.body')
rm -f "$COMMENT_FILE"
```
**The PR comment MUST include:**
1. A summary table of all scenarios with PASS/FAIL and before/after API evidence
2. Every successfully uploaded screenshot rendered inline; any failed uploads listed with manual attachment instructions
3. A 1-2 sentence explanation below each screenshot describing what it proves
3. A structured explanation below each screenshot covering: **Flow** (which scenario), **Steps** (exact actions taken to reach this state), **Evidence** (what this proves — pass/fail/data values). A bare "shows the page" caption is not acceptable.
This approach uses the GitHub Git API to create blobs, trees, commits, and refs entirely server-side. No local `git checkout` or `git push` — safe for worktrees and won't interfere with the PR branch.
**Verify inline rendering after posting — this is required, not optional:**
```bash
# 1. Confirm the posted comment body contains inline image markdown syntax
if ! echo "$POSTED_BODY" | grep -q '!\['; then
echo "❌ FAIL: No inline image tags in posted comment body. Re-check IMAGE_MARKDOWN and re-post."
exit 1
fi
# 2. Verify at least one raw URL actually resolves (catches wrong branch name, wrong path, etc.)
FIRST_IMG_URL=$(echo "$POSTED_BODY" | grep -o 'https://raw.githubusercontent.com[^)]*' | head -1)
if [ -n "$FIRST_IMG_URL" ]; then
HTTP_STATUS=$(curl -s -o /dev/null -w "%{http_code}" --max-time 10 "$FIRST_IMG_URL")
if [ "$HTTP_STATUS" = "200" ]; then
echo "✅ Inline images confirmed and raw URL resolves (HTTP 200)"
else
echo "❌ FAIL: Raw image URL returned HTTP $HTTP_STATUS — images will not render inline."
echo " URL: $FIRST_IMG_URL"
echo " Check branch name, path, and that the push succeeded."
exit 1
fi
else
echo "⚠️ Could not extract a raw URL from the comment — verify manually."
fi
```
## Step 8: Evaluate test completeness and post a GitHub review
After posting the PR comment, evaluate whether the test run actually covered everything it needed to. This is NOT a rubber-stamp — be critical. Then post a formal GitHub review so the PR author and reviewers can see the verdict.
### 8a. Evaluate against the test plan
Re-read `$RESULTS_DIR/test-plan.md` (written in Step 2) and `$RESULTS_DIR/test-report.md` (written in Step 5). For each scenario in the plan, answer:
> **Note:** `test-report.md` is written in Step 5. If it doesn't exist, write it before proceeding here — see the Step 5 template. Do not skip evaluation because the file is missing; create it from your notes instead.
| Question | Pass criteria |
|----------|--------------|
| Was it tested? | Explicit steps were executed, not just described |
| Is there screenshot evidence? | At least one before/after screenshot per scenario |
| Did the core feature work correctly? | Expected state matches actual state |
| Were negative cases tested? | At least one failure/rejection case per feature |
| Was DB/API state verified (not just UI)? | Raw API response or DB query confirms state change |
Build a verdict:
- **APPROVE** — every scenario tested, evidence present, no bugs found or all bugs are minor/known
- **REQUEST_CHANGES** — one or more: untested scenarios, missing evidence, bugs found, data not verified
### 8b. Post the GitHub review
```bash
EVAL_FILE=$(mktemp)
# === STEP A: Write header ===
cat > "$EVAL_FILE" << 'ENDEVAL'
## 🧪 Test Evaluation
### Coverage checklist
ENDEVAL
# === STEP B: Append ONE line per scenario — do this BEFORE calculating verdict ===
# Format: "- ✅ **Scenario N name**: <what was done and verified>"
# or "- ❌ **Scenario N name**: <what is missing or broken>"
# Examples:
# echo "- ✅ **Scenario 1 Login flow**: tested, screenshot evidence present, auth token verified via API" >> "$EVAL_FILE"
# echo "- ❌ **Scenario 3 Cost logging**: NOT verified in DB — UI showed entry but raw SQL query was skipped" >> "$EVAL_FILE"
#
# !!! IMPORTANT: append ALL scenario lines here before proceeding to STEP C !!!
# === STEP C: Derive verdict from the checklist — runs AFTER all lines are appended ===
FAIL_COUNT=$(grep -c "^- ❌" "$EVAL_FILE" || true)
if [ "$FAIL_COUNT" -eq 0 ]; then
VERDICT="APPROVE"
else
VERDICT="REQUEST_CHANGES"
fi
# === STEP D: Append verdict section ===
cat >> "$EVAL_FILE" << ENDVERDICT
### Verdict
ENDVERDICT
if [ "$VERDICT" = "APPROVE" ]; then
echo "✅ All scenarios covered with evidence. No blocking issues found." >> "$EVAL_FILE"
else
echo "$FAIL_COUNT scenario(s) incomplete or have confirmed bugs. See ❌ items above." >> "$EVAL_FILE"
echo "" >> "$EVAL_FILE"
echo "**Required before merge:** address each ❌ item above." >> "$EVAL_FILE"
fi
# === STEP E: Post the review ===
gh api "repos/${REPO}/pulls/$PR_NUMBER/reviews" \
--method POST \
-f body="$(cat "$EVAL_FILE")" \
-f event="$VERDICT"
rm -f "$EVAL_FILE"
```
**Rules:**
- Never auto-approve without checking every scenario in the test plan
- `REQUEST_CHANGES` if ANY scenario is untested, lacks DB/API evidence, or has a confirmed bug
- The evaluation body must list every scenario explicitly (✅ or ❌) — not just the failures
- If you find new bugs during evaluation, add them to the request-changes body and (if `--fix` flag is set) fix them before posting
## Fix mode (--fix flag)
When `--fix` is present, the standard is HIGHER. Do not just note issues — FIX them immediately.

View File

@@ -0,0 +1,224 @@
---
name: write-frontend-tests
description: "Analyze the current branch diff against dev, plan integration tests for changed frontend pages/components, and write them. TRIGGER when user asks to write frontend tests, add test coverage, or 'write tests for my changes'."
user-invocable: true
args: "[base branch] — defaults to dev. Optionally pass a specific base branch to diff against."
metadata:
author: autogpt-team
version: "1.0.0"
---
# Write Frontend Tests
Analyze the current branch's frontend changes, plan integration tests, and write them.
## References
Before writing any tests, read the testing rules and conventions:
- `autogpt_platform/frontend/TESTING.md` — testing strategy, file locations, examples
- `autogpt_platform/frontend/src/tests/AGENTS.md` — detailed testing rules, MSW patterns, decision flowchart
- `autogpt_platform/frontend/src/tests/integrations/test-utils.tsx` — custom render with providers
- `autogpt_platform/frontend/src/tests/integrations/vitest.setup.tsx` — MSW server setup
## Step 1: Identify changed frontend files
```bash
BASE_BRANCH="${ARGUMENTS:-dev}"
cd autogpt_platform/frontend
# Get changed frontend files (excluding generated, config, and test files)
git diff "$BASE_BRANCH"...HEAD --name-only -- src/ \
| grep -v '__generated__' \
| grep -v '__tests__' \
| grep -v '\.test\.' \
| grep -v '\.stories\.' \
| grep -v '\.spec\.'
```
Also read the diff to understand what changed:
```bash
git diff "$BASE_BRANCH"...HEAD --stat -- src/
git diff "$BASE_BRANCH"...HEAD -- src/ | head -500
```
## Step 2: Categorize changes and find test targets
For each changed file, determine:
1. **Is it a page?** (`page.tsx`) — these are the primary test targets
2. **Is it a hook?** (`use*.ts`) — test via the page that uses it
3. **Is it a component?** (`.tsx` in `components/`) — test via the parent page unless it's complex enough to warrant isolation
4. **Is it a helper?** (`helpers.ts`, `utils.ts`) — unit test directly if pure logic
**Priority order:**
1. Pages with new/changed data fetching or user interactions
2. Components with complex internal logic (modals, forms, wizards)
3. Hooks with non-trivial business logic
4. Pure helper functions
Skip: styling-only changes, type-only changes, config changes.
## Step 3: Check for existing tests
For each test target, check if tests already exist:
```bash
# For a page at src/app/(platform)/library/page.tsx
ls src/app/\(platform\)/library/__tests__/ 2>/dev/null
# For a component at src/app/(platform)/library/components/AgentCard/AgentCard.tsx
ls src/app/\(platform\)/library/components/AgentCard/__tests__/ 2>/dev/null
```
Note which targets have no tests (need new files) vs which have tests that need updating.
## Step 4: Identify API endpoints used
For each test target, find which API hooks are used:
```bash
# Find generated API hook imports in the changed files
grep -rn 'from.*__generated__/endpoints' src/app/\(platform\)/library/
grep -rn 'use[A-Z].*V[12]' src/app/\(platform\)/library/
```
For each API hook found, locate the corresponding MSW handler:
```bash
# If the page uses useGetV2ListLibraryAgents, find its MSW handlers
grep -rn 'getGetV2ListLibraryAgents.*Handler' src/app/api/__generated__/endpoints/library/library.msw.ts
```
List every MSW handler you will need (200 for happy path, 4xx for error paths).
## Step 5: Write the test plan
Before writing code, output a plan as a numbered list:
```
Test plan for [branch name]:
1. src/app/(platform)/library/__tests__/main.test.tsx (NEW)
- Renders page with agent list (MSW 200)
- Shows loading state
- Shows error state (MSW 422)
- Handles empty agent list
2. src/app/(platform)/library/__tests__/search.test.tsx (NEW)
- Filters agents by search query
- Shows no results message
- Clears search
3. src/app/(platform)/library/components/AgentCard/__tests__/AgentCard.test.tsx (UPDATE)
- Add test for new "duplicate" action
```
Present this plan to the user. Wait for confirmation before proceeding. If the user has feedback, adjust the plan.
## Step 6: Write the tests
For each test file in the plan, follow these conventions:
### File structure
```tsx
import { render, screen, waitFor } from "@/tests/integrations/test-utils";
import { server } from "@/mocks/mock-server";
// Import MSW handlers for endpoints the page uses
import {
getGetV2ListLibraryAgentsMockHandler200,
getGetV2ListLibraryAgentsMockHandler422,
} from "@/app/api/__generated__/endpoints/library/library.msw";
// Import the component under test
import LibraryPage from "../page";
describe("LibraryPage", () => {
test("renders agent list from API", async () => {
server.use(getGetV2ListLibraryAgentsMockHandler200());
render(<LibraryPage />);
expect(await screen.findByText(/my agents/i)).toBeDefined();
});
test("shows error state on API failure", async () => {
server.use(getGetV2ListLibraryAgentsMockHandler422());
render(<LibraryPage />);
expect(await screen.findByText(/error/i)).toBeDefined();
});
});
```
### Rules
- Use `render()` from `@/tests/integrations/test-utils` (NOT from `@testing-library/react` directly)
- Use `server.use()` to set up MSW handlers BEFORE rendering
- Use `findBy*` (async) for elements that appear after data fetching — NOT `getBy*`
- Use `getBy*` only for elements that are immediately present in the DOM
- Use `screen` queries — do NOT destructure from `render()`
- Use `waitFor` when asserting side effects or state changes after interactions
- Import `fireEvent` or `userEvent` from the test-utils for interactions
- Do NOT mock internal hooks or functions — mock at the API boundary via MSW
- Do NOT use `act()` manually — `render` and `fireEvent` handle it
- Keep tests focused: one behavior per test
- Use descriptive test names that read like sentences
### Test location
```
# For pages: __tests__/ next to page.tsx
src/app/(platform)/library/__tests__/main.test.tsx
# For complex standalone components: __tests__/ inside component folder
src/app/(platform)/library/components/AgentCard/__tests__/AgentCard.test.tsx
# For pure helpers: co-located .test.ts
src/app/(platform)/library/helpers.test.ts
```
### Custom MSW overrides
When the auto-generated faker data is not enough, override with specific data:
```tsx
import { http, HttpResponse } from "msw";
server.use(
http.get("http://localhost:3000/api/proxy/api/v2/library/agents", () => {
return HttpResponse.json({
agents: [
{ id: "1", name: "Test Agent", description: "A test agent" },
],
pagination: { total_items: 1, total_pages: 1, page: 1, page_size: 10 },
});
}),
);
```
Use the proxy URL pattern: `http://localhost:3000/api/proxy/api/v{version}/{path}` — this matches the MSW base URL configured in `orval.config.ts`.
## Step 7: Run and verify
After writing all tests:
```bash
cd autogpt_platform/frontend
pnpm test:unit --reporter=verbose
```
If tests fail:
1. Read the error output carefully
2. Fix the test (not the source code, unless there is a genuine bug)
3. Re-run until all pass
Then run the full checks:
```bash
pnpm format
pnpm lint
pnpm types
```

View File

@@ -6,11 +6,19 @@ on:
paths:
- '.github/workflows/classic-autogpt-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/direct_benchmark/**'
- 'classic/forge/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
pull_request:
branches: [ master, dev, release-* ]
paths:
- '.github/workflows/classic-autogpt-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/direct_benchmark/**'
- 'classic/forge/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
concurrency:
group: ${{ format('classic-autogpt-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
@@ -19,47 +27,22 @@ concurrency:
defaults:
run:
shell: bash
working-directory: classic/original_autogpt
working-directory: classic
jobs:
test:
permissions:
contents: read
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
runs-on: ubuntu-latest
steps:
# Quite slow on macOS (2~4 minutes to set up Docker)
# - name: Set up Docker (macOS)
# if: runner.os == 'macOS'
# uses: crazy-max/ghaction-setup-docker@v3
- name: Start MinIO service (Linux)
if: runner.os == 'Linux'
- name: Start MinIO service
working-directory: '.'
run: |
docker pull minio/minio:edge-cicd
docker run -d -p 9000:9000 minio/minio:edge-cicd
- name: Start MinIO service (macOS)
if: runner.os == 'macOS'
working-directory: ${{ runner.temp }}
run: |
brew install minio/stable/minio
mkdir data
minio server ./data &
# No MinIO on Windows:
# - Windows doesn't support running Linux Docker containers
# - It doesn't seem possible to start background processes on Windows. They are
# killed after the step returns.
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
- name: Checkout repository
uses: actions/checkout@v4
with:
@@ -71,41 +54,23 @@ jobs:
git config --global user.name "Auto-GPT-Bot"
git config --global user.email "github-bot@agpt.co"
- name: Set up Python ${{ matrix.python-version }}
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
python-version: "3.12"
- id: get_date
name: Get date
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('classic/original_autogpt/poetry.lock') }}
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
- name: Install Poetry (Unix)
if: runner.os != 'Windows'
run: |
curl -sSL https://install.python-poetry.org | python3 -
if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
- name: Install Python dependencies
run: poetry install
@@ -116,12 +81,13 @@ jobs:
--cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
--numprocesses=logical --durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
tests/unit tests/integration
original_autogpt/tests/unit original_autogpt/tests/integration
env:
CI: true
PLAIN_OUTPUT: True
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
S3_ENDPOINT_URL: http://127.0.0.1:9000
AWS_ACCESS_KEY_ID: minioadmin
AWS_SECRET_ACCESS_KEY: minioadmin
@@ -135,11 +101,11 @@ jobs:
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: autogpt-agent,${{ runner.os }}
flags: autogpt-agent
- name: Upload logs to artifact
if: always()
uses: actions/upload-artifact@v4
with:
name: test-logs
path: classic/original_autogpt/logs/
path: classic/logs/

View File

@@ -148,7 +148,7 @@ jobs:
--entrypoint poetry ${{ env.IMAGE_NAME }} run \
pytest -v --cov=autogpt --cov-branch --cov-report term-missing \
--numprocesses=4 --durations=10 \
tests/unit tests/integration 2>&1 | tee test_output.txt
original_autogpt/tests/unit original_autogpt/tests/integration 2>&1 | tee test_output.txt
test_failure=${PIPESTATUS[0]}

View File

@@ -10,10 +10,9 @@ on:
- '.github/workflows/classic-autogpts-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/benchmark/**'
- 'classic/run'
- 'classic/cli.py'
- 'classic/setup.py'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- '!**/*.md'
pull_request:
branches: [ master, dev, release-* ]
@@ -21,10 +20,9 @@ on:
- '.github/workflows/classic-autogpts-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/benchmark/**'
- 'classic/run'
- 'classic/cli.py'
- 'classic/setup.py'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- '!**/*.md'
defaults:
@@ -35,13 +33,9 @@ defaults:
jobs:
serve-agent-protocol:
runs-on: ubuntu-latest
strategy:
matrix:
agent-name: [ original_autogpt ]
fail-fast: false
timeout-minutes: 20
env:
min-python-version: '3.10'
min-python-version: '3.12'
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -55,22 +49,22 @@ jobs:
python-version: ${{ env.min-python-version }}
- name: Install Poetry
working-directory: ./classic/${{ matrix.agent-name }}/
run: |
curl -sSL https://install.python-poetry.org | python -
- name: Run regression tests
- name: Install dependencies
run: poetry install
- name: Run smoke tests with direct-benchmark
run: |
./run agent start ${{ matrix.agent-name }}
cd ${{ matrix.agent-name }}
poetry run agbenchmark --mock --test=BasicRetrieval --test=Battleship --test=WebArenaTask_0
poetry run agbenchmark --test=WriteFile
poetry run direct-benchmark run \
--strategies one_shot \
--models claude \
--tests ReadFile,WriteFile \
--json
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AGENT_NAME: ${{ matrix.agent-name }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
REQUESTS_CA_BUNDLE: /etc/ssl/certs/ca-certificates.crt
HELICONE_CACHE_ENABLED: false
HELICONE_PROPERTY_AGENT: ${{ matrix.agent-name }}
REPORTS_FOLDER: ${{ format('../../reports/{0}', matrix.agent-name) }}
TELEMETRY_ENVIRONMENT: autogpt-ci
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
NONINTERACTIVE_MODE: "true"
CI: true

View File

@@ -1,18 +1,24 @@
name: Classic - AGBenchmark CI
name: Classic - Direct Benchmark CI
on:
push:
branches: [ master, dev, ci-test* ]
paths:
- 'classic/benchmark/**'
- '!classic/benchmark/reports/**'
- 'classic/direct_benchmark/**'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- .github/workflows/classic-benchmark-ci.yml
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
pull_request:
branches: [ master, dev, release-* ]
paths:
- 'classic/benchmark/**'
- '!classic/benchmark/reports/**'
- 'classic/direct_benchmark/**'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- .github/workflows/classic-benchmark-ci.yml
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
concurrency:
group: ${{ format('benchmark-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
@@ -23,95 +29,16 @@ defaults:
shell: bash
env:
min-python-version: '3.10'
min-python-version: '3.12'
jobs:
test:
permissions:
contents: read
benchmark-tests:
runs-on: ubuntu-latest
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
defaults:
run:
shell: bash
working-directory: classic/benchmark
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('classic/benchmark/poetry.lock') }}
- name: Install Poetry (Unix)
if: runner.os != 'Windows'
run: |
curl -sSL https://install.python-poetry.org | python3 -
if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
- name: Install Python dependencies
run: poetry install
- name: Run pytest with coverage
run: |
poetry run pytest -vv \
--cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
tests
env:
CI: true
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Upload test results to Codecov
if: ${{ !cancelled() }} # Run even if tests fail
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: agbenchmark,${{ runner.os }}
self-test-with-agent:
runs-on: ubuntu-latest
strategy:
matrix:
agent-name: [forge]
fail-fast: false
timeout-minutes: 20
working-directory: classic
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -124,53 +51,120 @@ jobs:
with:
python-version: ${{ env.min-python-version }}
- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python -
curl -sSL https://install.python-poetry.org | python3 -
- name: Install dependencies
run: poetry install
- name: Run basic benchmark tests
run: |
echo "Testing ReadFile challenge with one_shot strategy..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--tests ReadFile \
--json
echo "Testing WriteFile challenge..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--tests WriteFile \
--json
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NONINTERACTIVE_MODE: "true"
- name: Test category filtering
run: |
echo "Testing coding category..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--categories coding \
--tests ReadFile,WriteFile \
--json
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NONINTERACTIVE_MODE: "true"
- name: Test multiple strategies
run: |
echo "Testing multiple strategies..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot,plan_execute \
--models claude \
--tests ReadFile \
--parallel 2 \
--json
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NONINTERACTIVE_MODE: "true"
# Run regression tests on maintain challenges
regression-tests:
runs-on: ubuntu-latest
timeout-minutes: 45
if: github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev'
defaults:
run:
shell: bash
working-directory: classic
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
- name: Set up Python ${{ env.min-python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ env.min-python-version }}
- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python3 -
- name: Install dependencies
run: poetry install
- name: Run regression tests
working-directory: classic
run: |
./run agent start ${{ matrix.agent-name }}
cd ${{ matrix.agent-name }}
set +e # Ignore non-zero exit codes and continue execution
echo "Running the following command: poetry run agbenchmark --maintain --mock"
poetry run agbenchmark --maintain --mock
EXIT_CODE=$?
set -e # Stop ignoring non-zero exit codes
# Check if the exit code was 5, and if so, exit with 0 instead
if [ $EXIT_CODE -eq 5 ]; then
echo "regression_tests.json is empty."
fi
echo "Running the following command: poetry run agbenchmark --mock"
poetry run agbenchmark --mock
echo "Running the following command: poetry run agbenchmark --mock --category=data"
poetry run agbenchmark --mock --category=data
echo "Running the following command: poetry run agbenchmark --mock --category=coding"
poetry run agbenchmark --mock --category=coding
# echo "Running the following command: poetry run agbenchmark --test=WriteFile"
# poetry run agbenchmark --test=WriteFile
cd ../benchmark
poetry install
echo "Adding the BUILD_SKILL_TREE environment variable. This will attempt to add new elements in the skill tree. If new elements are added, the CI fails because they should have been pushed"
export BUILD_SKILL_TREE=true
# poetry run agbenchmark --mock
# CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../classic/frontend/assets)') || echo "No diffs"
# if [ ! -z "$CHANGED" ]; then
# echo "There are unstaged changes please run agbenchmark and commit those changes since they are needed."
# echo "$CHANGED"
# exit 1
# else
# echo "No unstaged changes."
# fi
echo "Running regression tests (previously beaten challenges)..."
poetry run direct-benchmark run \
--fresh \
--strategies one_shot \
--models claude \
--maintain \
--parallel 4 \
--json
env:
CI: true
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
TELEMETRY_ENVIRONMENT: autogpt-benchmark-ci
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
NONINTERACTIVE_MODE: "true"

View File

@@ -6,13 +6,15 @@ on:
paths:
- '.github/workflows/classic-forge-ci.yml'
- 'classic/forge/**'
- '!classic/forge/tests/vcr_cassettes'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
pull_request:
branches: [ master, dev, release-* ]
paths:
- '.github/workflows/classic-forge-ci.yml'
- 'classic/forge/**'
- '!classic/forge/tests/vcr_cassettes'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
concurrency:
group: ${{ format('forge-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
@@ -21,131 +23,60 @@ concurrency:
defaults:
run:
shell: bash
working-directory: classic/forge
working-directory: classic
jobs:
test:
permissions:
contents: read
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
platform-os: [ubuntu, macos, macos-arm64, windows]
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
runs-on: ubuntu-latest
steps:
# Quite slow on macOS (2~4 minutes to set up Docker)
# - name: Set up Docker (macOS)
# if: runner.os == 'macOS'
# uses: crazy-max/ghaction-setup-docker@v3
- name: Start MinIO service (Linux)
if: runner.os == 'Linux'
- name: Start MinIO service
working-directory: '.'
run: |
docker pull minio/minio:edge-cicd
docker run -d -p 9000:9000 minio/minio:edge-cicd
- name: Start MinIO service (macOS)
if: runner.os == 'macOS'
working-directory: ${{ runner.temp }}
run: |
brew install minio/stable/minio
mkdir data
minio server ./data &
# No MinIO on Windows:
# - Windows doesn't support running Linux Docker containers
# - It doesn't seem possible to start background processes on Windows. They are
# killed after the step returns.
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
- name: Checkout cassettes
if: ${{ startsWith(github.event_name, 'pull_request') }}
env:
PR_BASE: ${{ github.event.pull_request.base.ref }}
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
run: |
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
cassette_base_branch="${PR_BASE}"
cd tests/vcr_cassettes
if ! git ls-remote --exit-code --heads origin $cassette_base_branch ; then
cassette_base_branch="master"
fi
if git ls-remote --exit-code --heads origin $cassette_branch ; then
git fetch origin $cassette_branch
git fetch origin $cassette_base_branch
git checkout $cassette_branch
# Pick non-conflicting cassette updates from the base branch
git merge --no-commit --strategy-option=ours origin/$cassette_base_branch
echo "Using cassettes from mirror branch '$cassette_branch'," \
"synced to upstream branch '$cassette_base_branch'."
else
git checkout -b $cassette_branch
echo "Branch '$cassette_branch' does not exist in cassette submodule." \
"Using cassettes from '$cassette_base_branch'."
fi
- name: Set up Python ${{ matrix.python-version }}
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
python-version: "3.12"
- name: Set up Python dependency cache
# On Windows, unpacking cached dependencies takes longer than just installing them
if: runner.os != 'Windows'
uses: actions/cache@v4
with:
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
key: poetry-${{ runner.os }}-${{ hashFiles('classic/forge/poetry.lock') }}
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
- name: Install Poetry (Unix)
if: runner.os != 'Windows'
run: |
curl -sSL https://install.python-poetry.org | python3 -
if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Install Poetry (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
$env:PATH += ";$env:APPDATA\Python\Scripts"
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
- name: Install Python dependencies
run: poetry install
- name: Install Playwright browsers
run: poetry run playwright install chromium
- name: Run pytest with coverage
run: |
poetry run pytest -vv \
--cov=forge --cov-branch --cov-report term-missing --cov-report xml \
--durations=10 \
--junitxml=junit.xml -o junit_family=legacy \
forge
forge/forge forge/tests
env:
CI: true
PLAIN_OUTPUT: True
# API keys - tests that need these will skip if not available
# Secrets are not available to fork PRs (GitHub security feature)
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
S3_ENDPOINT_URL: http://127.0.0.1:9000
AWS_ACCESS_KEY_ID: minioadmin
AWS_SECRET_ACCESS_KEY: minioadmin
@@ -159,85 +90,11 @@ jobs:
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: forge,${{ runner.os }}
- id: setup_git_auth
name: Set up git token authentication
# Cassettes may be pushed even when tests fail
if: success() || failure()
run: |
config_key="http.${{ github.server_url }}/.extraheader"
if [ "${{ runner.os }}" = 'macOS' ]; then
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64)
else
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64 -w0)
fi
git config "$config_key" \
"Authorization: Basic $base64_pat"
cd tests/vcr_cassettes
git config "$config_key" \
"Authorization: Basic $base64_pat"
echo "config_key=$config_key" >> $GITHUB_OUTPUT
- id: push_cassettes
name: Push updated cassettes
# For pull requests, push updated cassettes even when tests fail
if: github.event_name == 'push' || (! github.event.pull_request.head.repo.fork && (success() || failure()))
env:
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
run: |
if [ "${{ startsWith(github.event_name, 'pull_request') }}" = "true" ]; then
is_pull_request=true
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
else
cassette_branch="${{ github.ref_name }}"
fi
cd tests/vcr_cassettes
# Commit & push changes to cassettes if any
if ! git diff --quiet; then
git add .
git commit -m "Auto-update cassettes"
git push origin HEAD:$cassette_branch
if [ ! $is_pull_request ]; then
cd ../..
git add tests/vcr_cassettes
git commit -m "Update cassette submodule"
git push origin HEAD:$cassette_branch
fi
echo "updated=true" >> $GITHUB_OUTPUT
else
echo "updated=false" >> $GITHUB_OUTPUT
echo "No cassette changes to commit"
fi
- name: Post Set up git token auth
if: steps.setup_git_auth.outcome == 'success'
run: |
git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
git submodule foreach git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
- name: Apply "behaviour change" label and comment on PR
if: ${{ startsWith(github.event_name, 'pull_request') }}
run: |
PR_NUMBER="${{ github.event.pull_request.number }}"
TOKEN="${{ secrets.PAT_REVIEW }}"
REPO="${{ github.repository }}"
if [[ "${{ steps.push_cassettes.outputs.updated }}" == "true" ]]; then
echo "Adding label and comment..."
echo $TOKEN | gh auth login --with-token
gh issue edit $PR_NUMBER --add-label "behaviour change"
gh issue comment $PR_NUMBER --body "You changed AutoGPT's behaviour on ${{ runner.os }}. The cassettes have been updated and will be merged to the submodule when this Pull Request gets merged."
fi
flags: forge
- name: Upload logs to artifact
if: always()
uses: actions/upload-artifact@v4
with:
name: test-logs
path: classic/forge/logs/
path: classic/logs/

View File

@@ -1,60 +0,0 @@
name: Classic - Frontend CI/CD
on:
push:
branches:
- master
- dev
- 'ci-test*' # This will match any branch that starts with "ci-test"
paths:
- 'classic/frontend/**'
- '.github/workflows/classic-frontend-ci.yml'
pull_request:
paths:
- 'classic/frontend/**'
- '.github/workflows/classic-frontend-ci.yml'
jobs:
build:
permissions:
contents: write
pull-requests: write
runs-on: ubuntu-latest
env:
BUILD_BRANCH: ${{ format('classic-frontend-build/{0}', github.ref_name) }}
steps:
- name: Checkout Repo
uses: actions/checkout@v4
- name: Setup Flutter
uses: subosito/flutter-action@v2
with:
flutter-version: '3.13.2'
- name: Build Flutter to Web
run: |
cd classic/frontend
flutter build web --base-href /app/
# - name: Commit and Push to ${{ env.BUILD_BRANCH }}
# if: github.event_name == 'push'
# run: |
# git config --local user.email "action@github.com"
# git config --local user.name "GitHub Action"
# git add classic/frontend/build/web
# git checkout -B ${{ env.BUILD_BRANCH }}
# git commit -m "Update frontend build to ${GITHUB_SHA:0:7}" -a
# git push -f origin ${{ env.BUILD_BRANCH }}
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
if: github.event_name == 'push'
uses: peter-evans/create-pull-request@v8
with:
add-paths: classic/frontend/build/web
base: ${{ github.ref_name }}
branch: ${{ env.BUILD_BRANCH }}
delete-branch: true
title: "Update frontend build in `${{ github.ref_name }}`"
body: "This PR updates the frontend build based on commit ${{ github.sha }}."
commit-message: "Update frontend build based on commit ${{ github.sha }}"

View File

@@ -7,7 +7,9 @@ on:
- '.github/workflows/classic-python-checks-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/benchmark/**'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- '**.py'
- '!classic/forge/tests/vcr_cassettes'
pull_request:
@@ -16,7 +18,9 @@ on:
- '.github/workflows/classic-python-checks-ci.yml'
- 'classic/original_autogpt/**'
- 'classic/forge/**'
- 'classic/benchmark/**'
- 'classic/direct_benchmark/**'
- 'classic/pyproject.toml'
- 'classic/poetry.lock'
- '**.py'
- '!classic/forge/tests/vcr_cassettes'
@@ -27,44 +31,13 @@ concurrency:
defaults:
run:
shell: bash
working-directory: classic
jobs:
get-changed-parts:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- id: changes-in
name: Determine affected subprojects
uses: dorny/paths-filter@v3
with:
filters: |
original_autogpt:
- classic/original_autogpt/autogpt/**
- classic/original_autogpt/tests/**
- classic/original_autogpt/poetry.lock
forge:
- classic/forge/forge/**
- classic/forge/tests/**
- classic/forge/poetry.lock
benchmark:
- classic/benchmark/agbenchmark/**
- classic/benchmark/tests/**
- classic/benchmark/poetry.lock
outputs:
changed-parts: ${{ steps.changes-in.outputs.changes }}
lint:
needs: get-changed-parts
runs-on: ubuntu-latest
env:
min-python-version: "3.10"
strategy:
matrix:
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
fail-fast: false
min-python-version: "3.12"
steps:
- name: Checkout repository
@@ -81,42 +54,31 @@ jobs:
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
# Install dependencies
- name: Install Python dependencies
run: poetry -C classic/${{ matrix.sub-package }} install
run: poetry install
# Lint
- name: Lint (isort)
run: poetry run isort --check .
working-directory: classic/${{ matrix.sub-package }}
- name: Lint (Black)
if: success() || failure()
run: poetry run black --check .
working-directory: classic/${{ matrix.sub-package }}
- name: Lint (Flake8)
if: success() || failure()
run: poetry run flake8 .
working-directory: classic/${{ matrix.sub-package }}
types:
needs: get-changed-parts
runs-on: ubuntu-latest
env:
min-python-version: "3.10"
strategy:
matrix:
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
fail-fast: false
min-python-version: "3.12"
steps:
- name: Checkout repository
@@ -133,19 +95,16 @@ jobs:
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
- name: Install Poetry
run: curl -sSL https://install.python-poetry.org | python3 -
# Install dependencies
- name: Install Python dependencies
run: poetry -C classic/${{ matrix.sub-package }} install
run: poetry install
# Typecheck
- name: Typecheck
if: success() || failure()
run: poetry run pyright
working-directory: classic/${{ matrix.sub-package }}

View File

@@ -269,12 +269,14 @@ jobs:
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
- name: Run pytest
- name: Run pytest with coverage
run: |
if [[ "${{ runner.debug }}" == "1" ]]; then
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG \
--cov=backend --cov-branch --cov-report term-missing --cov-report xml
else
poetry run pytest -s -vv
poetry run pytest -s -vv \
--cov=backend --cov-branch --cov-report term-missing --cov-report xml
fi
env:
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
@@ -287,11 +289,13 @@ jobs:
REDIS_PORT: "6379"
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
# - name: Upload coverage reports to Codecov
# uses: codecov/codecov-action@v4
# with:
# token: ${{ secrets.CODECOV_TOKEN }}
# flags: backend,${{ runner.os }}
- name: Upload coverage reports to Codecov
if: ${{ !cancelled() }}
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: platform-backend
files: ./autogpt_platform/backend/coverage.xml
env:
CI: true

View File

@@ -148,3 +148,11 @@ jobs:
- name: Run Integration Tests
run: pnpm test:unit
- name: Upload coverage reports to Codecov
if: ${{ !cancelled() }}
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: platform-frontend
files: ./autogpt_platform/frontend/coverage/cobertura-coverage.xml

View File

@@ -179,21 +179,30 @@ jobs:
pip install pyyaml
# Resolve extends and generate a flat compose file that bake can understand
export NEXT_PUBLIC_SOURCEMAPS NEXT_PUBLIC_PW_TEST
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
# Ensure NEXT_PUBLIC_SOURCEMAPS is in resolved compose
# (docker compose config on some versions drops this arg)
if ! grep -q "NEXT_PUBLIC_SOURCEMAPS" docker-compose.resolved.yml; then
echo "Injecting NEXT_PUBLIC_SOURCEMAPS into resolved compose (docker compose config dropped it)"
sed -i '/NEXT_PUBLIC_PW_TEST/a\ NEXT_PUBLIC_SOURCEMAPS: "true"' docker-compose.resolved.yml
fi
# Add cache configuration to the resolved compose file
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
--source docker-compose.resolved.yml \
--cache-from "type=gha" \
--cache-to "type=gha,mode=max" \
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend/**') }}" \
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}" \
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}-sourcemaps" \
--git-ref "${{ github.ref }}"
# Build with bake using the resolved compose file (now includes cache config)
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
env:
NEXT_PUBLIC_PW_TEST: true
NEXT_PUBLIC_SOURCEMAPS: true
- name: Set up tests - Cache E2E test data
id: e2e-data-cache
@@ -279,6 +288,11 @@ jobs:
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Copy source maps from Docker for E2E coverage
run: |
FRONTEND_CONTAINER=$(docker compose -f ../docker-compose.resolved.yml ps -q frontend)
docker cp "$FRONTEND_CONTAINER":/app/.next/static .next-static-coverage
- name: Set up tests - Install dependencies
run: pnpm install --frozen-lockfile
@@ -289,6 +303,15 @@ jobs:
run: pnpm test:no-build
continue-on-error: false
- name: Upload E2E coverage to Codecov
if: ${{ !cancelled() }}
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
flags: platform-frontend-e2e
files: ./autogpt_platform/frontend/coverage/e2e/cobertura-coverage.xml
disable_search: true
- name: Upload Playwright report
if: always()
uses: actions/upload-artifact@v4

10
.gitignore vendored
View File

@@ -3,6 +3,7 @@
classic/original_autogpt/keys.py
classic/original_autogpt/*.json
auto_gpt_workspace/*
.autogpt/
*.mpeg
.env
# Root .env files
@@ -16,6 +17,7 @@ log-ingestion.txt
/logs
*.log
*.mp3
!autogpt_platform/frontend/public/notification.mp3
mem.sqlite3
venvAutoGPT
@@ -159,6 +161,10 @@ CURRENT_BULLETIN.md
# AgBenchmark
classic/benchmark/agbenchmark/reports/
classic/reports/
classic/direct_benchmark/reports/
classic/.benchmark_workspaces/
classic/direct_benchmark/.benchmark_workspaces/
# Nodejs
package-lock.json
@@ -177,9 +183,13 @@ autogpt_platform/backend/settings.py
*.ign.*
.test-contents
**/.claude/settings.local.json
.claude/settings.local.json
CLAUDE.local.md
/autogpt_platform/backend/logs
# Test database
test.db
.next
# Implementation plans (generated by AI agents)
plans/

36
.gitleaks.toml Normal file
View File

@@ -0,0 +1,36 @@
title = "AutoGPT Gitleaks Config"
[extend]
useDefault = true
[allowlist]
description = "Global allowlist"
paths = [
# Template/example env files (no real secrets)
'''\.env\.(default|example|template)$''',
# Lock files
'''pnpm-lock\.yaml$''',
'''poetry\.lock$''',
# Secrets baseline
'''\.secrets\.baseline$''',
# Build artifacts and caches (should not be committed)
'''__pycache__/''',
'''classic/frontend/build/''',
# Docker dev setup (local dev JWTs/keys only)
'''autogpt_platform/db/docker/''',
# Load test configs (dev JWTs)
'''load-tests/configs/''',
# Test files with fake/fixture keys (_test.py, test_*.py, conftest.py)
'''(_test|test_.*|conftest)\.py$''',
# Documentation (only contains placeholder keys in curl/API examples)
'''docs/.*\.md$''',
# Firebase config (public API keys by design)
'''google-services\.json$''',
'''classic/frontend/(lib|web)/''',
]
# CI test-only encryption key (marked DO NOT USE IN PRODUCTION)
regexes = [
'''dvziYgz0KSK8FENhju0ZYi8''',
# LLM model name enum values falsely flagged as API keys
'''Llama-\d.*Instruct''',
]

3
.gitmodules vendored
View File

@@ -1,3 +0,0 @@
[submodule "classic/forge/tests/vcr_cassettes"]
path = classic/forge/tests/vcr_cassettes
url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes

View File

@@ -23,9 +23,15 @@ repos:
- id: detect-secrets
name: Detect secrets
description: Detects high entropy strings that are likely to be passwords.
args: ["--baseline", ".secrets.baseline"]
files: ^autogpt_platform/
exclude: pnpm-lock\.yaml$
stages: [pre-push]
exclude: (pnpm-lock\.yaml|\.env\.(default|example|template))$
- repo: https://github.com/gitleaks/gitleaks
rev: v8.24.3
hooks:
- id: gitleaks
name: Detect secrets (gitleaks)
- repo: local
# For proper type checking, all dependencies need to be up-to-date.
@@ -84,51 +90,16 @@ repos:
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - AutoGPT
alias: poetry-install-classic-autogpt
name: Check & Install dependencies - Classic
alias: poetry-install-classic
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/(original_autogpt|forge)/poetry\.lock$" || exit 0;
poetry -C classic/original_autogpt install
'
# include forge source (since it's a path dependency)
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - Forge
alias: poetry-install-classic-forge
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/forge/poetry\.lock$" || exit 0;
poetry -C classic/forge install
'
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - Benchmark
alias: poetry-install-classic-benchmark
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/benchmark/poetry\.lock$" || exit 0;
poetry -C classic/benchmark install
fi | grep -qE "^classic/poetry\.lock$" || exit 0;
poetry -C classic install
'
always_run: true
language: system
@@ -223,26 +194,10 @@ repos:
language: system
- id: isort
name: Lint (isort) - Classic - AutoGPT
alias: isort-classic-autogpt
entry: poetry -P classic/original_autogpt run isort -p autogpt
files: ^classic/original_autogpt/
types: [file, python]
language: system
- id: isort
name: Lint (isort) - Classic - Forge
alias: isort-classic-forge
entry: poetry -P classic/forge run isort -p forge
files: ^classic/forge/
types: [file, python]
language: system
- id: isort
name: Lint (isort) - Classic - Benchmark
alias: isort-classic-benchmark
entry: poetry -P classic/benchmark run isort -p agbenchmark
files: ^classic/benchmark/
name: Lint (isort) - Classic
alias: isort-classic
entry: bash -c 'cd classic && poetry run isort $(echo "$@" | sed "s|classic/||g")' --
files: ^classic/(original_autogpt|forge|direct_benchmark)/
types: [file, python]
language: system
@@ -256,26 +211,13 @@ repos:
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
# To have flake8 load the config of the individual subprojects, we have to call
# them separately.
# Use consolidated flake8 config at classic/.flake8
hooks:
- id: flake8
name: Lint (Flake8) - Classic - AutoGPT
alias: flake8-classic-autogpt
files: ^classic/original_autogpt/(autogpt|scripts|tests)/
args: [--config=classic/original_autogpt/.flake8]
- id: flake8
name: Lint (Flake8) - Classic - Forge
alias: flake8-classic-forge
files: ^classic/forge/(forge|tests)/
args: [--config=classic/forge/.flake8]
- id: flake8
name: Lint (Flake8) - Classic - Benchmark
alias: flake8-classic-benchmark
files: ^classic/benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
args: [--config=classic/benchmark/.flake8]
name: Lint (Flake8) - Classic
alias: flake8-classic
files: ^classic/(original_autogpt|forge|direct_benchmark)/
args: [--config=classic/.flake8]
- repo: local
hooks:
@@ -311,29 +253,10 @@ repos:
pass_filenames: false
- id: pyright
name: Typecheck - Classic - AutoGPT
alias: pyright-classic-autogpt
entry: poetry -C classic/original_autogpt run pyright
# include forge source (since it's a path dependency) but exclude *_test.py files:
files: ^(classic/original_autogpt/((autogpt|scripts|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
types: [file]
language: system
pass_filenames: false
- id: pyright
name: Typecheck - Classic - Forge
alias: pyright-classic-forge
entry: poetry -C classic/forge run pyright
files: ^classic/forge/(forge/|poetry\.lock$)
types: [file]
language: system
pass_filenames: false
- id: pyright
name: Typecheck - Classic - Benchmark
alias: pyright-classic-benchmark
entry: poetry -C classic/benchmark run pyright
files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
name: Typecheck - Classic
alias: pyright-classic
entry: poetry -C classic run pyright
files: ^classic/(original_autogpt|forge|direct_benchmark)/.*\.py$|^classic/poetry\.lock$
types: [file]
language: system
pass_filenames: false
@@ -360,26 +283,9 @@ repos:
# pass_filenames: false
# - id: pytest
# name: Run tests - Classic - AutoGPT (excl. slow tests)
# alias: pytest-classic-autogpt
# entry: bash -c 'cd classic/original_autogpt && poetry run pytest --cov=autogpt -m "not slow" tests/unit tests/integration'
# # include forge source (since it's a path dependency) but exclude *_test.py files:
# files: ^(classic/original_autogpt/((autogpt|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
# language: system
# pass_filenames: false
# - id: pytest
# name: Run tests - Classic - Forge (excl. slow tests)
# alias: pytest-classic-forge
# entry: bash -c 'cd classic/forge && poetry run pytest --cov=forge -m "not slow"'
# files: ^classic/forge/(forge/|tests/|poetry\.lock$)
# language: system
# pass_filenames: false
# - id: pytest
# name: Run tests - Classic - Benchmark
# alias: pytest-classic-benchmark
# entry: bash -c 'cd classic/benchmark && poetry run pytest --cov=benchmark'
# files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
# name: Run tests - Classic (excl. slow tests)
# alias: pytest-classic
# entry: bash -c 'cd classic && poetry run pytest -m "not slow"'
# files: ^classic/(original_autogpt|forge|direct_benchmark)/
# language: system
# pass_filenames: false

467
.secrets.baseline Normal file
View File

@@ -0,0 +1,467 @@
{
"version": "1.5.0",
"plugins_used": [
{
"name": "ArtifactoryDetector"
},
{
"name": "AWSKeyDetector"
},
{
"name": "AzureStorageKeyDetector"
},
{
"name": "Base64HighEntropyString",
"limit": 4.5
},
{
"name": "BasicAuthDetector"
},
{
"name": "CloudantDetector"
},
{
"name": "DiscordBotTokenDetector"
},
{
"name": "GitHubTokenDetector"
},
{
"name": "GitLabTokenDetector"
},
{
"name": "HexHighEntropyString",
"limit": 3.0
},
{
"name": "IbmCloudIamDetector"
},
{
"name": "IbmCosHmacDetector"
},
{
"name": "IPPublicDetector"
},
{
"name": "JwtTokenDetector"
},
{
"name": "KeywordDetector",
"keyword_exclude": ""
},
{
"name": "MailchimpDetector"
},
{
"name": "NpmDetector"
},
{
"name": "OpenAIDetector"
},
{
"name": "PrivateKeyDetector"
},
{
"name": "PypiTokenDetector"
},
{
"name": "SendGridDetector"
},
{
"name": "SlackDetector"
},
{
"name": "SoftlayerDetector"
},
{
"name": "SquareOAuthDetector"
},
{
"name": "StripeDetector"
},
{
"name": "TelegramBotTokenDetector"
},
{
"name": "TwilioKeyDetector"
}
],
"filters_used": [
{
"path": "detect_secrets.filters.allowlist.is_line_allowlisted"
},
{
"path": "detect_secrets.filters.common.is_ignored_due_to_verification_policies",
"min_level": 2
},
{
"path": "detect_secrets.filters.heuristic.is_indirect_reference"
},
{
"path": "detect_secrets.filters.heuristic.is_likely_id_string"
},
{
"path": "detect_secrets.filters.heuristic.is_lock_file"
},
{
"path": "detect_secrets.filters.heuristic.is_not_alphanumeric_string"
},
{
"path": "detect_secrets.filters.heuristic.is_potential_uuid"
},
{
"path": "detect_secrets.filters.heuristic.is_prefixed_with_dollar_sign"
},
{
"path": "detect_secrets.filters.heuristic.is_sequential_string"
},
{
"path": "detect_secrets.filters.heuristic.is_swagger_file"
},
{
"path": "detect_secrets.filters.heuristic.is_templated_secret"
},
{
"path": "detect_secrets.filters.regex.should_exclude_file",
"pattern": [
"\\.env$",
"pnpm-lock\\.yaml$",
"\\.env\\.(default|example|template)$",
"__pycache__",
"_test\\.py$",
"test_.*\\.py$",
"conftest\\.py$",
"poetry\\.lock$",
"node_modules"
]
}
],
"results": {
"autogpt_platform/backend/backend/api/external/v1/integrations.py": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/backend/backend/api/external/v1/integrations.py",
"hashed_secret": "665b1e3851eefefa3fb878654292f16597d25155",
"is_verified": false,
"line_number": 289
}
],
"autogpt_platform/backend/backend/blocks/airtable/_config.py": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/backend/backend/blocks/airtable/_config.py",
"hashed_secret": "57e168b03afb7c1ee3cdc4ee3db2fe1cc6e0df26",
"is_verified": false,
"line_number": 29
}
],
"autogpt_platform/backend/backend/blocks/dataforseo/_config.py": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/backend/backend/blocks/dataforseo/_config.py",
"hashed_secret": "32ce93887331fa5d192f2876ea15ec000c7d58b8",
"is_verified": false,
"line_number": 12
}
],
"autogpt_platform/backend/backend/blocks/github/checks.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/checks.py",
"hashed_secret": "8ac6f92737d8586790519c5d7bfb4d2eb172c238",
"is_verified": false,
"line_number": 108
}
],
"autogpt_platform/backend/backend/blocks/github/ci.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/ci.py",
"hashed_secret": "90bd1b48e958257948487b90bee080ba5ed00caa",
"is_verified": false,
"line_number": 123
}
],
"autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
"hashed_secret": "f96896dafced7387dcd22343b8ea29d3d2c65663",
"is_verified": false,
"line_number": 42
},
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
"hashed_secret": "b80a94d5e70bedf4f5f89d2f5a5255cc9492d12e",
"is_verified": false,
"line_number": 193
},
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
"hashed_secret": "75b17e517fe1b3136394f6bec80c4f892da75e42",
"is_verified": false,
"line_number": 344
},
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
"hashed_secret": "b0bfb5e4e2394e7f8906e5ed1dffd88b2bc89dd5",
"is_verified": false,
"line_number": 534
}
],
"autogpt_platform/backend/backend/blocks/github/statuses.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/github/statuses.py",
"hashed_secret": "8ac6f92737d8586790519c5d7bfb4d2eb172c238",
"is_verified": false,
"line_number": 85
}
],
"autogpt_platform/backend/backend/blocks/google/docs.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/google/docs.py",
"hashed_secret": "c95da0c6696342c867ef0c8258d2f74d20fd94d4",
"is_verified": false,
"line_number": 203
}
],
"autogpt_platform/backend/backend/blocks/google/sheets.py": [
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/google/sheets.py",
"hashed_secret": "bd5a04fa3667e693edc13239b6d310c5c7a8564b",
"is_verified": false,
"line_number": 57
}
],
"autogpt_platform/backend/backend/blocks/linear/_config.py": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/backend/backend/blocks/linear/_config.py",
"hashed_secret": "b37f020f42d6d613b6ce30103e4d408c4499b3bb",
"is_verified": false,
"line_number": 53
}
],
"autogpt_platform/backend/backend/blocks/medium.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/medium.py",
"hashed_secret": "ff998abc1ce6d8f01a675fa197368e44c8916e9c",
"is_verified": false,
"line_number": 131
}
],
"autogpt_platform/backend/backend/blocks/replicate/replicate_block.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/replicate/replicate_block.py",
"hashed_secret": "8bbdd6f26368f58ea4011d13d7f763cb662e66f0",
"is_verified": false,
"line_number": 55
}
],
"autogpt_platform/backend/backend/blocks/slant3d/webhook.py": [
{
"type": "Hex High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/slant3d/webhook.py",
"hashed_secret": "36263c76947443b2f6e6b78153967ac4a7da99f9",
"is_verified": false,
"line_number": 100
}
],
"autogpt_platform/backend/backend/blocks/talking_head.py": [
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/backend/backend/blocks/talking_head.py",
"hashed_secret": "44ce2d66222529eea4a32932823466fc0601c799",
"is_verified": false,
"line_number": 113
}
],
"autogpt_platform/backend/backend/blocks/wordpress/_config.py": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/backend/backend/blocks/wordpress/_config.py",
"hashed_secret": "e62679512436161b78e8a8d68c8829c2a1031ccb",
"is_verified": false,
"line_number": 17
}
],
"autogpt_platform/backend/backend/util/cache.py": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/backend/backend/util/cache.py",
"hashed_secret": "37f0c918c3fa47ca4a70e42037f9f123fdfbc75b",
"is_verified": false,
"line_number": 449
}
],
"autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts",
"hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8",
"is_verified": false,
"line_number": 6
}
],
"autogpt_platform/frontend/src/app/(platform)/dictionaries/en.json": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/dictionaries/en.json",
"hashed_secret": "8be3c943b1609fffbfc51aad666d0a04adf83c9d",
"is_verified": false,
"line_number": 5
}
],
"autogpt_platform/frontend/src/app/(platform)/dictionaries/es.json": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/dictionaries/es.json",
"hashed_secret": "5a6d1c612954979ea99ee33dbb2d231b00f6ac0a",
"is_verified": false,
"line_number": 5
}
],
"autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts",
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
"is_verified": false,
"line_number": 6
},
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts",
"hashed_secret": "f72cbb45464d487064610c5411c576ca4019d380",
"is_verified": false,
"line_number": 8
}
],
"autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts",
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
"is_verified": false,
"line_number": 5
},
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts",
"hashed_secret": "f72cbb45464d487064610c5411c576ca4019d380",
"is_verified": false,
"line_number": 7
}
],
"autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx",
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
"is_verified": false,
"line_number": 192
},
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx",
"hashed_secret": "86275db852204937bbdbdebe5fabe8536e030ab6",
"is_verified": false,
"line_number": 193
}
],
"autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts",
"hashed_secret": "47acd2028cf81b5da88ddeedb2aea4eca4b71fbd",
"is_verified": false,
"line_number": 102
},
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts",
"hashed_secret": "8be3c943b1609fffbfc51aad666d0a04adf83c9d",
"is_verified": false,
"line_number": 103
}
],
"autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts": [
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "9c486c92f1a7420e1045c7ad963fbb7ba3621025",
"is_verified": false,
"line_number": 73
},
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "9277508c7a6effc8fb59163efbfada189e35425c",
"is_verified": false,
"line_number": 75
},
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "8dc7e2cb1d0935897d541bf5facab389b8a50340",
"is_verified": false,
"line_number": 77
},
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "79a26ad48775944299be6aaf9fb1d5302c1ed75b",
"is_verified": false,
"line_number": 79
},
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "a3b62b44500a1612e48d4cab8294df81561b3b1a",
"is_verified": false,
"line_number": 81
},
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "a58979bd0b21ef4f50417d001008e60dd7a85c64",
"is_verified": false,
"line_number": 83
},
{
"type": "Base64 High Entropy String",
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
"hashed_secret": "6cb6e075f8e8c7c850f9d128d6608e5dbe209a79",
"is_verified": false,
"line_number": 85
}
],
"autogpt_platform/frontend/src/lib/constants.ts": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/lib/constants.ts",
"hashed_secret": "27b924db06a28cc755fb07c54f0fddc30659fe4d",
"is_verified": false,
"line_number": 10
}
],
"autogpt_platform/frontend/src/tests/credentials/index.ts": [
{
"type": "Secret Keyword",
"filename": "autogpt_platform/frontend/src/tests/credentials/index.ts",
"hashed_secret": "c18006fc138809314751cd1991f1e0b820fabd37",
"is_verified": false,
"line_number": 4
}
]
},
"generated_at": "2026-04-02T13:10:54Z"
}

View File

@@ -1,6 +1,6 @@
# AutoGPT Platform Contribution Guide
This guide provides context for Codex when updating the **autogpt_platform** folder.
This guide provides context for coding agents when updating the **autogpt_platform** folder.
## Directory overview
@@ -30,7 +30,7 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
- Regenerate with `pnpm generate:api`
- Pattern: `use{Method}{Version}{OperationName}`
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
5. **Testing**: Integration tests (Vitest + RTL + MSW) are the default (~90%, page-level). Playwright for E2E critical flows. Storybook for design system components. See `autogpt_platform/frontend/TESTING.md`
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
@@ -47,7 +47,9 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
## Testing
- Backend: `poetry run test` (runs pytest with a docker based postgres + prisma).
- Frontend: `pnpm test` or `pnpm test-ui` for Playwright tests. See `docs/content/platform/contributing/tests.md` for tips.
- Frontend integration tests: `pnpm test:unit` (Vitest + RTL + MSW, primary testing approach).
- Frontend E2E tests: `pnpm test` or `pnpm test-ui` for Playwright tests.
- See `autogpt_platform/frontend/TESTING.md` for the full testing strategy.
Always run the relevant linters and tests before committing.
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).

1
CLAUDE.md Normal file
View File

@@ -0,0 +1 @@
@AGENTS.md

120
autogpt_platform/AGENTS.md Normal file
View File

@@ -0,0 +1,120 @@
# AutoGPT Platform
This file provides guidance to coding agents when working with code in this repository.
## Repository Overview
AutoGPT Platform is a monorepo containing:
- **Backend** (`backend`): Python FastAPI server with async support
- **Frontend** (`frontend`): Next.js React application
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
## Component Documentation
- **Backend**: See @backend/AGENTS.md for backend-specific commands, architecture, and development tasks
- **Frontend**: See @frontend/AGENTS.md for frontend-specific commands, architecture, and development patterns
## Key Concepts
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security
### Environment Configuration
#### Configuration Files
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
#### Docker Environment Loading Order
1. `.env.default` files provide base configuration (tracked in git)
2. `.env` files provide user-specific overrides (gitignored)
3. Docker Compose `environment:` sections provide service-specific overrides
4. Shell environment variables have highest precedence
#### Key Points
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
### Branching Strategy
- **`dev`** is the main development branch. All PRs should target `dev`.
- **`master`** is the production branch. Only used for production releases.
### Creating Pull Requests
- Create the PR against the `dev` branch of the repository.
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
```bash
PR_BODY=$(mktemp)
cat > "$PR_BODY" << 'PREOF'
## Summary
- use `backticks` freely here
PREOF
gh pr create --title "..." --body-file "$PR_BODY" --base dev
rm "$PR_BODY"
```
- Run the github pre-commit hooks to ensure code quality.
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, follow a test-first approach:
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
2. **Implement the fix/feature** — write the minimal code to make the test pass.
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
This ensures every change is covered by a test and that the test actually validates the intended behavior.
### Reviewing/Revising Pull Requests
Use `/pr-review` to review a PR or `/pr-address` to address comments.
When fetching comments manually:
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
### Conventional Commits
Use this format for commit messages and Pull Request titles:
**Conventional Commit Types:**
- `feat`: Introduces a new feature to the codebase
- `fix`: Patches a bug in the codebase
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
- `ci`: Changes to CI configuration
- `docs`: Documentation-only changes
- `dx`: Improvements to the developer experience
**Recommended Base Scopes:**
- `platform`: Changes affecting both frontend and backend
- `frontend`
- `backend`
- `infra`
- `blocks`: Modifications/additions of individual blocks
**Subscope Examples:**
- `backend/executor`
- `backend/db`
- `frontend/builder` (includes changes to the block UI component)
- `infra/prod`
Use these scopes and subscopes for clarity and consistency in commit messages.

View File

@@ -1,120 +1 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Repository Overview
AutoGPT Platform is a monorepo containing:
- **Backend** (`backend`): Python FastAPI server with async support
- **Frontend** (`frontend`): Next.js React application
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
## Component Documentation
- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns
## Key Concepts
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security
### Environment Configuration
#### Configuration Files
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
#### Docker Environment Loading Order
1. `.env.default` files provide base configuration (tracked in git)
2. `.env` files provide user-specific overrides (gitignored)
3. Docker Compose `environment:` sections provide service-specific overrides
4. Shell environment variables have highest precedence
#### Key Points
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
### Branching Strategy
- **`dev`** is the main development branch. All PRs should target `dev`.
- **`master`** is the production branch. Only used for production releases.
### Creating Pull Requests
- Create the PR against the `dev` branch of the repository.
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
```bash
PR_BODY=$(mktemp)
cat > "$PR_BODY" << 'PREOF'
## Summary
- use `backticks` freely here
PREOF
gh pr create --title "..." --body-file "$PR_BODY" --base dev
rm "$PR_BODY"
```
- Run the github pre-commit hooks to ensure code quality.
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, follow a test-first approach:
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
2. **Implement the fix/feature** — write the minimal code to make the test pass.
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
This ensures every change is covered by a test and that the test actually validates the intended behavior.
### Reviewing/Revising Pull Requests
Use `/pr-review` to review a PR or `/pr-address` to address comments.
When fetching comments manually:
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
### Conventional Commits
Use this format for commit messages and Pull Request titles:
**Conventional Commit Types:**
- `feat`: Introduces a new feature to the codebase
- `fix`: Patches a bug in the codebase
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
- `ci`: Changes to CI configuration
- `docs`: Documentation-only changes
- `dx`: Improvements to the developer experience
**Recommended Base Scopes:**
- `platform`: Changes affecting both frontend and backend
- `frontend`
- `backend`
- `infra`
- `blocks`: Modifications/additions of individual blocks
**Subscope Examples:**
- `backend/executor`
- `backend/db`
- `frontend/builder` (includes changes to the block UI component)
- `infra/prod`
Use these scopes and subscopes for clarity and consistency in commit messages.
@AGENTS.md

View File

@@ -178,6 +178,7 @@ SMTP_USERNAME=
SMTP_PASSWORD=
# Business & Marketing Tools
AGENTMAIL_API_KEY=
APOLLO_API_KEY=
ENRICHLAYER_API_KEY=
AYRSHARE_API_KEY=

View File

@@ -0,0 +1,227 @@
# Backend
This file provides guidance to coding agents when working with the backend.
## Essential Commands
To run something with Python package dependencies you MUST use `poetry run ...`.
```bash
# Install dependencies
poetry install
# Run database migrations
poetry run prisma migrate dev
# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d
# Run the backend as a whole
poetry run app
# Run tests
poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name
# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
# Lint and format
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
More details can be found in @TESTING.md
### Creating/Updating Snapshots
When you first write a test or when the expected output changes:
```bash
poetry run pytest path/to/test.py --snapshot-update
```
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
## Architecture
- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
## Code Style
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
- **Pydantic models** over dataclass/namedtuple/dict for structured data
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
- **List comprehensions** over manual loop-and-append
- **Early return** — guard clauses first, avoid deep nesting
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
- **`max(0, value)` guards** — for computed values that should never be negative
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
## Testing Approach
- Uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
- After refactoring, update mock targets to match new module paths
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, write the test **before** the implementation:
```python
# 1. Write a failing test marked xfail
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
def test_widget_handles_empty_input():
result = widget.process("")
assert result == Widget.EMPTY_RESULT
# 2. Run it — confirm it fails (XFAIL)
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
# 3. Implement the fix
# 4. Remove xfail, run again — confirm it passes
def test_widget_handles_empty_input():
result = widget.process("")
assert result == Widget.EMPTY_RESULT
```
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
## Database Schema
Key models (defined in `schema.prisma`):
- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
- `AgentNode`: Individual nodes in a workflow
- `StoreListing`: Marketplace listings for sharing agents
## Environment Configuration
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
## Common Development Tasks
### Adding a new block
Follow the comprehensive [Block SDK Guide](@../../docs/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization
Quick steps:
1. Create new file in `backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
ex: do the inputs and outputs tie well together?
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
#### Handling files in blocks with `store_media_file()`
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
**Examples:**
```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
file=input_data.video,
execution_context=execution_context,
return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc
# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
file=input_data.image,
execution_context=execution_context,
return_format="for_external_api",
)
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
# OUTPUT: Returning result from block
result_url = await store_media_file(
file=generated_image_url,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```
**Key points:**
- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it
### Modifying the API
1. Update route in `backend/api/features/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
## Workspace & Media Files
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
- Working on CoPilot file upload/download features
- Building blocks that handle `MediaFileType` inputs/outputs
- Modifying `WorkspaceManager` or `store_media_file()`
- Debugging file persistence or virus scanning issues
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
## Security Implementation
### Cache Protection Middleware
- Located in `backend/api/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications

View File

@@ -1,227 +1 @@
# CLAUDE.md - Backend
This file provides guidance to Claude Code when working with the backend.
## Essential Commands
To run something with Python package dependencies you MUST use `poetry run ...`.
```bash
# Install dependencies
poetry install
# Run database migrations
poetry run prisma migrate dev
# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d
# Run the backend as a whole
poetry run app
# Run tests
poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name
# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
# Lint and format
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
More details can be found in @TESTING.md
### Creating/Updating Snapshots
When you first write a test or when the expected output changes:
```bash
poetry run pytest path/to/test.py --snapshot-update
```
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
## Architecture
- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
## Code Style
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
- **Pydantic models** over dataclass/namedtuple/dict for structured data
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
- **List comprehensions** over manual loop-and-append
- **Early return** — guard clauses first, avoid deep nesting
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
- **`max(0, value)` guards** — for computed values that should never be negative
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
## Testing Approach
- Uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
- After refactoring, update mock targets to match new module paths
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, write the test **before** the implementation:
```python
# 1. Write a failing test marked xfail
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
def test_widget_handles_empty_input():
result = widget.process("")
assert result == Widget.EMPTY_RESULT
# 2. Run it — confirm it fails (XFAIL)
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
# 3. Implement the fix
# 4. Remove xfail, run again — confirm it passes
def test_widget_handles_empty_input():
result = widget.process("")
assert result == Widget.EMPTY_RESULT
```
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
## Database Schema
Key models (defined in `schema.prisma`):
- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
- `AgentNode`: Individual nodes in a workflow
- `StoreListing`: Marketplace listings for sharing agents
## Environment Configuration
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
## Common Development Tasks
### Adding a new block
Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization
Quick steps:
1. Create new file in `backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
ex: do the inputs and outputs tie well together?
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
#### Handling files in blocks with `store_media_file()`
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
**Examples:**
```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
file=input_data.video,
execution_context=execution_context,
return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc
# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
file=input_data.image,
execution_context=execution_context,
return_format="for_external_api",
)
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
# OUTPUT: Returning result from block
result_url = await store_media_file(
file=generated_image_url,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```
**Key points:**
- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it
### Modifying the API
1. Update route in `backend/api/features/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
## Workspace & Media Files
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
- Working on CoPilot file upload/download features
- Building blocks that handle `MediaFileType` inputs/outputs
- Modifying `WorkspaceManager` or `store_media_file()`
- Debugging file persistence or virus scanning issues
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
## Security Implementation
### Cache Protection Middleware
- Located in `backend/api/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications
@AGENTS.md

View File

@@ -31,7 +31,10 @@ from backend.data.model import (
UserPasswordCredentials,
is_sdk_default,
)
from backend.integrations.credentials_store import provider_matches
from backend.integrations.credentials_store import (
is_system_credential,
provider_matches,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
@@ -618,6 +621,11 @@ async def delete_credential(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if is_system_credential(cred_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="System-managed credentials cannot be deleted",
)
creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
if not creds:
raise HTTPException(

View File

@@ -72,7 +72,7 @@ class RunAgentRequest(BaseModel):
def _create_ephemeral_session(user_id: str) -> ChatSession:
"""Create an ephemeral session for stateless API requests."""
return ChatSession.new(user_id)
return ChatSession.new(user_id, dry_run=False)
@tools_router.post(

View File

@@ -0,0 +1,98 @@
import logging
from datetime import datetime
from autogpt_libs.auth import get_user_id, requires_admin_user
from cachetools import TTLCache
from fastapi import APIRouter, Query, Security
from pydantic import BaseModel
from backend.data.platform_cost import (
CostLogRow,
PlatformCostDashboard,
get_platform_cost_dashboard,
get_platform_cost_logs,
)
from backend.util.models import Pagination
logger = logging.getLogger(__name__)
# Cache dashboard results for 30 seconds per unique filter combination.
# The table is append-only so stale reads are acceptable for analytics.
_DASHBOARD_CACHE_TTL = 30
_dashboard_cache: TTLCache[tuple, PlatformCostDashboard] = TTLCache(
maxsize=256, ttl=_DASHBOARD_CACHE_TTL
)
router = APIRouter(
prefix="/platform-costs",
tags=["platform-cost", "admin"],
dependencies=[Security(requires_admin_user)],
)
class PlatformCostLogsResponse(BaseModel):
logs: list[CostLogRow]
pagination: Pagination
@router.get(
"/dashboard",
response_model=PlatformCostDashboard,
summary="Get Platform Cost Dashboard",
)
async def get_cost_dashboard(
admin_user_id: str = Security(get_user_id),
start: datetime | None = Query(None),
end: datetime | None = Query(None),
provider: str | None = Query(None),
user_id: str | None = Query(None),
):
logger.info("Admin %s fetching platform cost dashboard", admin_user_id)
cache_key = (start, end, provider, user_id)
cached = _dashboard_cache.get(cache_key)
if cached is not None:
return cached
result = await get_platform_cost_dashboard(
start=start,
end=end,
provider=provider,
user_id=user_id,
)
_dashboard_cache[cache_key] = result
return result
@router.get(
"/logs",
response_model=PlatformCostLogsResponse,
summary="Get Platform Cost Logs",
)
async def get_cost_logs(
admin_user_id: str = Security(get_user_id),
start: datetime | None = Query(None),
end: datetime | None = Query(None),
provider: str | None = Query(None),
user_id: str | None = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(50, ge=1, le=200),
):
logger.info("Admin %s fetching platform cost logs", admin_user_id)
logs, total = await get_platform_cost_logs(
start=start,
end=end,
provider=provider,
user_id=user_id,
page=page,
page_size=page_size,
)
total_pages = (total + page_size - 1) // page_size
return PlatformCostLogsResponse(
logs=logs,
pagination=Pagination(
total_items=total,
total_pages=total_pages,
current_page=page,
page_size=page_size,
),
)

View File

@@ -0,0 +1,192 @@
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from backend.data.platform_cost import PlatformCostDashboard
from . import platform_cost_routes
from .platform_cost_routes import router as platform_cost_router
app = fastapi.FastAPI()
app.include_router(platform_cost_router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
"""Setup admin auth overrides for all tests in this module"""
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
# Clear TTL cache so each test starts cold.
platform_cost_routes._dashboard_cache.clear()
yield
app.dependency_overrides.clear()
def test_get_dashboard_success(
mocker: pytest_mock.MockerFixture,
) -> None:
real_dashboard = PlatformCostDashboard(
by_provider=[],
by_user=[],
total_cost_microdollars=0,
total_requests=0,
total_users=0,
)
mocker.patch(
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
AsyncMock(return_value=real_dashboard),
)
response = client.get("/platform-costs/dashboard")
assert response.status_code == 200
data = response.json()
assert "by_provider" in data
assert "by_user" in data
assert data["total_cost_microdollars"] == 0
def test_get_logs_success(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
AsyncMock(return_value=([], 0)),
)
response = client.get("/platform-costs/logs")
assert response.status_code == 200
data = response.json()
assert data["logs"] == []
assert data["pagination"]["total_items"] == 0
def test_get_dashboard_with_filters(
mocker: pytest_mock.MockerFixture,
) -> None:
real_dashboard = PlatformCostDashboard(
by_provider=[],
by_user=[],
total_cost_microdollars=0,
total_requests=0,
total_users=0,
)
mock_dashboard = AsyncMock(return_value=real_dashboard)
mocker.patch(
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
mock_dashboard,
)
response = client.get(
"/platform-costs/dashboard",
params={
"start": "2026-01-01T00:00:00",
"end": "2026-04-01T00:00:00",
"provider": "openai",
"user_id": "test-user-123",
},
)
assert response.status_code == 200
mock_dashboard.assert_called_once()
call_kwargs = mock_dashboard.call_args.kwargs
assert call_kwargs["provider"] == "openai"
assert call_kwargs["user_id"] == "test-user-123"
assert call_kwargs["start"] is not None
assert call_kwargs["end"] is not None
def test_get_logs_with_pagination(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
AsyncMock(return_value=([], 0)),
)
response = client.get(
"/platform-costs/logs",
params={"page": 2, "page_size": 25, "provider": "anthropic"},
)
assert response.status_code == 200
data = response.json()
assert data["pagination"]["current_page"] == 2
assert data["pagination"]["page_size"] == 25
def test_get_dashboard_requires_admin() -> None:
import fastapi
from fastapi import HTTPException
def reject_jwt(request: fastapi.Request):
raise HTTPException(status_code=401, detail="Not authenticated")
app.dependency_overrides[get_jwt_payload] = reject_jwt
try:
response = client.get("/platform-costs/dashboard")
assert response.status_code == 401
response = client.get("/platform-costs/logs")
assert response.status_code == 401
finally:
app.dependency_overrides.clear()
def test_get_dashboard_rejects_non_admin(mock_jwt_user, mock_jwt_admin) -> None:
"""Non-admin JWT must be rejected with 403 by requires_admin_user."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
try:
response = client.get("/platform-costs/dashboard")
assert response.status_code == 403
response = client.get("/platform-costs/logs")
assert response.status_code == 403
finally:
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
def test_get_logs_invalid_page_size_too_large() -> None:
"""page_size > 200 must be rejected with 422."""
response = client.get("/platform-costs/logs", params={"page_size": 201})
assert response.status_code == 422
def test_get_logs_invalid_page_size_zero() -> None:
"""page_size = 0 (below ge=1) must be rejected with 422."""
response = client.get("/platform-costs/logs", params={"page_size": 0})
assert response.status_code == 422
def test_get_logs_invalid_page_negative() -> None:
"""page < 1 must be rejected with 422."""
response = client.get("/platform-costs/logs", params={"page": 0})
assert response.status_code == 422
def test_get_dashboard_invalid_date_format() -> None:
"""Malformed start date must be rejected with 422."""
response = client.get("/platform-costs/dashboard", params={"start": "not-a-date"})
assert response.status_code == 422
def test_get_dashboard_cache_hit(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Second identical request returns cached result without calling the DB again."""
real_dashboard = PlatformCostDashboard(
by_provider=[],
by_user=[],
total_cost_microdollars=42,
total_requests=1,
total_users=1,
)
mock_fn = mocker.patch(
"backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
AsyncMock(return_value=real_dashboard),
)
client.get("/platform-costs/dashboard")
client.get("/platform-costs/dashboard")
mock_fn.assert_awaited_once() # second request hit the cache

View File

@@ -9,11 +9,14 @@ from pydantic import BaseModel
from backend.copilot.config import ChatConfig
from backend.copilot.rate_limit import (
SubscriptionTier,
get_global_rate_limits,
get_usage_status,
get_user_tier,
reset_user_usage,
set_user_tier,
)
from backend.data.user import get_user_by_email, get_user_email_by_id
from backend.data.user import get_user_by_email, get_user_email_by_id, search_users
logger = logging.getLogger(__name__)
@@ -33,6 +36,17 @@ class UserRateLimitResponse(BaseModel):
weekly_token_limit: int
daily_tokens_used: int
weekly_tokens_used: int
tier: SubscriptionTier
class UserTierResponse(BaseModel):
user_id: str
tier: SubscriptionTier
class SetUserTierRequest(BaseModel):
user_id: str
tier: SubscriptionTier
async def _resolve_user_id(
@@ -86,10 +100,10 @@ async def get_user_rate_limit(
logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)
daily_limit, weekly_limit = await get_global_rate_limits(
daily_limit, weekly_limit, tier = await get_global_rate_limits(
resolved_id, config.daily_token_limit, config.weekly_token_limit
)
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit)
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier)
return UserRateLimitResponse(
user_id=resolved_id,
@@ -98,6 +112,7 @@ async def get_user_rate_limit(
weekly_token_limit=weekly_limit,
daily_tokens_used=usage.daily.used,
weekly_tokens_used=usage.weekly.used,
tier=tier,
)
@@ -125,10 +140,10 @@ async def reset_user_rate_limit(
logger.exception("Failed to reset user usage")
raise HTTPException(status_code=500, detail="Failed to reset usage") from e
daily_limit, weekly_limit = await get_global_rate_limits(
daily_limit, weekly_limit, tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
usage = await get_usage_status(user_id, daily_limit, weekly_limit)
usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier)
try:
resolved_email = await get_user_email_by_id(user_id)
@@ -143,4 +158,102 @@ async def reset_user_rate_limit(
weekly_token_limit=weekly_limit,
daily_tokens_used=usage.daily.used,
weekly_tokens_used=usage.weekly.used,
tier=tier,
)
@router.get(
"/rate_limit/tier",
response_model=UserTierResponse,
summary="Get User Rate Limit Tier",
)
async def get_user_rate_limit_tier(
user_id: str,
admin_user_id: str = Security(get_user_id),
) -> UserTierResponse:
"""Get a user's current rate-limit tier. Admin-only.
Returns 404 if the user does not exist in the database.
"""
logger.info("Admin %s checking tier for user %s", admin_user_id, user_id)
resolved_email = await get_user_email_by_id(user_id)
if resolved_email is None:
raise HTTPException(status_code=404, detail=f"User {user_id} not found")
tier = await get_user_tier(user_id)
return UserTierResponse(user_id=user_id, tier=tier)
@router.post(
"/rate_limit/tier",
response_model=UserTierResponse,
summary="Set User Rate Limit Tier",
)
async def set_user_rate_limit_tier(
request: SetUserTierRequest,
admin_user_id: str = Security(get_user_id),
) -> UserTierResponse:
"""Set a user's rate-limit tier. Admin-only.
Returns 404 if the user does not exist in the database.
"""
try:
resolved_email = await get_user_email_by_id(request.user_id)
except Exception:
logger.warning(
"Failed to resolve email for user %s",
request.user_id,
exc_info=True,
)
resolved_email = None
if resolved_email is None:
raise HTTPException(status_code=404, detail=f"User {request.user_id} not found")
old_tier = await get_user_tier(request.user_id)
logger.info(
"Admin %s changing tier for user %s (%s): %s -> %s",
admin_user_id,
request.user_id,
resolved_email,
old_tier.value,
request.tier.value,
)
try:
await set_user_tier(request.user_id, request.tier)
except Exception as e:
logger.exception("Failed to set user tier")
raise HTTPException(status_code=500, detail="Failed to set tier") from e
return UserTierResponse(user_id=request.user_id, tier=request.tier)
class UserSearchResult(BaseModel):
user_id: str
user_email: Optional[str] = None
@router.get(
"/rate_limit/search_users",
response_model=list[UserSearchResult],
summary="Search Users by Name or Email",
)
async def admin_search_users(
query: str,
limit: int = 20,
admin_user_id: str = Security(get_user_id),
) -> list[UserSearchResult]:
"""Search users by partial email or name. Admin-only.
Queries the User table directly — returns results even for users
without credit transaction history.
"""
if len(query.strip()) < 3:
raise HTTPException(
status_code=400,
detail="Search query must be at least 3 characters.",
)
logger.info("Admin %s searching users with query=%r", admin_user_id, query)
results = await search_users(query, limit=max(1, min(limit, 50)))
return [UserSearchResult(user_id=uid, user_email=email) for uid, email in results]

View File

@@ -9,7 +9,7 @@ import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from pytest_snapshot.plugin import Snapshot
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
from backend.copilot.rate_limit import CoPilotUsageStatus, SubscriptionTier, UsageWindow
from .rate_limit_admin_routes import router as rate_limit_admin_router
@@ -57,7 +57,7 @@ def _patch_rate_limit_deps(
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000),
return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
@@ -89,6 +89,7 @@ def test_get_rate_limit(
assert data["weekly_token_limit"] == 12_500_000
assert data["daily_tokens_used"] == 500_000
assert data["weekly_tokens_used"] == 3_000_000
assert data["tier"] == "FREE"
configured_snapshot.assert_match(
json.dumps(data, indent=2, sort_keys=True) + "\n",
@@ -162,6 +163,7 @@ def test_reset_user_usage_daily_only(
assert data["daily_tokens_used"] == 0
# Weekly is untouched
assert data["weekly_tokens_used"] == 3_000_000
assert data["tier"] == "FREE"
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)
@@ -192,6 +194,7 @@ def test_reset_user_usage_daily_and_weekly(
data = response.json()
assert data["daily_tokens_used"] == 0
assert data["weekly_tokens_used"] == 0
assert data["tier"] == "FREE"
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)
@@ -228,7 +231,7 @@ def test_get_rate_limit_email_lookup_failure(
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000),
return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
@@ -261,3 +264,303 @@ def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None:
json={"user_id": "test"},
)
assert response.status_code == 403
# ---------------------------------------------------------------------------
# Tier management endpoints
# ---------------------------------------------------------------------------
def test_get_user_tier(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test getting a user's rate-limit tier."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TARGET_EMAIL,
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.PRO,
)
response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["tier"] == "PRO"
def test_get_user_tier_user_not_found(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that getting tier for a non-existent user returns 404."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=None,
)
response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})
assert response.status_code == 404
def test_set_user_tier(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test setting a user's rate-limit tier (upgrade)."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TARGET_EMAIL,
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.FREE,
)
mock_set = mocker.patch(
f"{_MOCK_MODULE}.set_user_tier",
new_callable=AsyncMock,
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "ENTERPRISE"},
)
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["tier"] == "ENTERPRISE"
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.ENTERPRISE)
def test_set_user_tier_downgrade(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test downgrading a user's tier from PRO to FREE."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TARGET_EMAIL,
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.PRO,
)
mock_set = mocker.patch(
f"{_MOCK_MODULE}.set_user_tier",
new_callable=AsyncMock,
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "FREE"},
)
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["tier"] == "FREE"
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.FREE)
def test_set_user_tier_invalid_tier(
target_user_id: str,
) -> None:
"""Test that setting an invalid tier returns 422."""
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "invalid"},
)
assert response.status_code == 422
def test_set_user_tier_invalid_tier_uppercase(
target_user_id: str,
) -> None:
"""Test that setting an unrecognised uppercase tier (e.g. 'INVALID') returns 422.
Regression: ensures Pydantic enum validation rejects values that are not
members of SubscriptionTier, even when they look like valid enum names.
"""
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "INVALID"},
)
assert response.status_code == 422
body = response.json()
assert "detail" in body
def test_set_user_tier_email_lookup_failure_returns_404(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that email lookup failure returns 404 (user unverifiable)."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
side_effect=Exception("DB connection failed"),
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "PRO"},
)
assert response.status_code == 404
def test_set_user_tier_user_not_found(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that setting tier for a non-existent user returns 404."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=None,
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "PRO"},
)
assert response.status_code == 404
def test_set_user_tier_db_failure(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that DB failure on set tier returns 500."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TARGET_EMAIL,
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.FREE,
)
mocker.patch(
f"{_MOCK_MODULE}.set_user_tier",
new_callable=AsyncMock,
side_effect=Exception("DB connection refused"),
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "PRO"},
)
assert response.status_code == 500
def test_tier_endpoints_require_admin_role(mock_jwt_user) -> None:
"""Test that tier admin endpoints require admin role."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
response = client.get("/admin/rate_limit/tier", params={"user_id": "test"})
assert response.status_code == 403
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": "test", "tier": "PRO"},
)
assert response.status_code == 403
# ─── search_users endpoint ──────────────────────────────────────────
def test_search_users_returns_matching_users(
mocker: pytest_mock.MockerFixture,
admin_user_id: str,
) -> None:
"""Partial search should return all matching users from the User table."""
mocker.patch(
_MOCK_MODULE + ".search_users",
new_callable=AsyncMock,
return_value=[
("user-1", "zamil.majdy@gmail.com"),
("user-2", "zamil.majdy@agpt.co"),
],
)
response = client.get("/admin/rate_limit/search_users", params={"query": "zamil"})
assert response.status_code == 200
results = response.json()
assert len(results) == 2
assert results[0]["user_email"] == "zamil.majdy@gmail.com"
assert results[1]["user_email"] == "zamil.majdy@agpt.co"
def test_search_users_empty_results(
mocker: pytest_mock.MockerFixture,
admin_user_id: str,
) -> None:
"""Search with no matches returns empty list."""
mocker.patch(
_MOCK_MODULE + ".search_users",
new_callable=AsyncMock,
return_value=[],
)
response = client.get(
"/admin/rate_limit/search_users", params={"query": "nonexistent"}
)
assert response.status_code == 200
assert response.json() == []
def test_search_users_short_query_rejected(
admin_user_id: str,
) -> None:
"""Query shorter than 3 characters should return 400."""
response = client.get("/admin/rate_limit/search_users", params={"query": "ab"})
assert response.status_code == 400
def test_search_users_negative_limit_clamped(
mocker: pytest_mock.MockerFixture,
admin_user_id: str,
) -> None:
"""Negative limit should be clamped to 1, not passed through."""
mock_search = mocker.patch(
_MOCK_MODULE + ".search_users",
new_callable=AsyncMock,
return_value=[],
)
response = client.get(
"/admin/rate_limit/search_users", params={"query": "test", "limit": -1}
)
assert response.status_code == 200
mock_search.assert_awaited_once_with("test", limit=1)
def test_search_users_requires_admin_role(mock_jwt_user) -> None:
"""Test that the search_users endpoint requires admin role."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
response = client.get("/admin/rate_limit/search_users", params={"query": "test"})
assert response.status_code == 403

View File

@@ -11,15 +11,17 @@ from autogpt_libs import auth
from fastapi import APIRouter, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig
from backend.copilot.config import ChatConfig, CopilotMode
from backend.copilot.db import get_chat_messages_paginated
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.model import (
ChatMessage,
ChatSession,
ChatSessionMetadata,
append_and_save_message,
create_chat_session,
delete_chat_session,
@@ -110,6 +112,23 @@ class StreamChatRequest(BaseModel):
file_ids: list[str] | None = Field(
default=None, max_length=20
) # Workspace file IDs attached to this message
mode: CopilotMode | None = Field(
default=None,
description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. "
"If None, uses the server default (extended_thinking).",
)
class CreateSessionRequest(BaseModel):
"""Request model for creating a new chat session.
``dry_run`` is a **top-level** field — do not nest it inside ``metadata``.
Extra/unknown fields are rejected (422) to prevent silent mis-use.
"""
model_config = ConfigDict(extra="forbid")
dry_run: bool = False
class CreateSessionResponse(BaseModel):
@@ -118,6 +137,7 @@ class CreateSessionResponse(BaseModel):
id: str
created_at: str
user_id: str | None
metadata: ChatSessionMetadata = ChatSessionMetadata()
class ActiveStreamInfo(BaseModel):
@@ -136,8 +156,11 @@ class SessionDetailResponse(BaseModel):
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
has_more_messages: bool = False
oldest_sequence: int | None = None
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
metadata: ChatSessionMetadata = ChatSessionMetadata()
class SessionSummaryResponse(BaseModel):
@@ -248,6 +271,7 @@ async def list_sessions(
)
async def create_session(
user_id: Annotated[str, Security(auth.get_user_id)],
request: CreateSessionRequest | None = None,
) -> CreateSessionResponse:
"""
Create a new chat session.
@@ -256,22 +280,28 @@ async def create_session(
Args:
user_id: The authenticated user ID parsed from the JWT (required).
request: Optional request body. When provided, ``dry_run=True``
forces run_block and run_agent calls to use dry-run simulation.
Returns:
CreateSessionResponse: Details of the created session.
"""
dry_run = request.dry_run if request else False
logger.info(
f"Creating session with user_id: "
f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
f"{', dry_run=True' if dry_run else ''}"
)
session = await create_chat_session(user_id)
session = await create_chat_session(user_id, dry_run=dry_run)
return CreateSessionResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
user_id=session.user_id,
metadata=session.metadata,
)
@@ -367,59 +397,78 @@ async def update_session_title_route(
async def get_session(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
limit: int = Query(default=50, ge=1, le=200),
before_sequence: int | None = Query(default=None, ge=0),
) -> SessionDetailResponse:
"""
Retrieve the details of a specific chat session.
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
If there's an active stream for this session, returns active_stream info for reconnection.
Supports cursor-based pagination via ``limit`` and ``before_sequence``.
When no pagination params are provided, returns the most recent messages.
Args:
session_id: The unique identifier for the desired chat session.
user_id: The optional authenticated user ID, or None for anonymous access.
user_id: The authenticated user's ID.
limit: Maximum number of messages to return (1-200, default 50).
before_sequence: Return messages with sequence < this value (cursor).
Returns:
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
SessionDetailResponse: Details for the requested session, including
active_stream info and pagination metadata.
"""
session = await get_chat_session(session_id, user_id)
if not session:
page = await get_chat_messages_paginated(
session_id, limit, before_sequence, user_id=user_id
)
if page is None:
raise NotFoundError(f"Session {session_id} not found.")
messages = [message.model_dump() for message in page.messages]
messages = [message.model_dump() for message in session.messages]
# Check if there's an active stream for this session
# Only check active stream on initial load (not on "load more" requests)
active_stream_info = None
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_session:
# Keep the assistant message (including tool_calls) so the frontend can
# render the correct tool UI (e.g. CreateAgent with mini game).
# convertChatSessionToUiMessages handles isComplete=false by setting
# tool parts without output to state "input-available".
active_stream_info = ActiveStreamInfo(
turn_id=active_session.turn_id,
last_message_id=last_message_id,
if before_sequence is None:
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_session:
active_stream_info = ActiveStreamInfo(
turn_id=active_session.turn_id,
last_message_id=last_message_id,
)
# Skip session metadata on "load more" — frontend only needs messages
if before_sequence is not None:
return SessionDetailResponse(
id=page.session.session_id,
created_at=page.session.started_at.isoformat(),
updated_at=page.session.updated_at.isoformat(),
user_id=page.session.user_id or None,
messages=messages,
active_stream=None,
has_more_messages=page.has_more,
oldest_sequence=page.oldest_sequence,
total_prompt_tokens=0,
total_completion_tokens=0,
)
# Sum token usage from session
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
total_prompt = sum(u.prompt_tokens for u in page.session.usage)
total_completion = sum(u.completion_tokens for u in page.session.usage)
return SessionDetailResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
updated_at=session.updated_at.isoformat(),
user_id=session.user_id or None,
id=page.session.session_id,
created_at=page.session.started_at.isoformat(),
updated_at=page.session.updated_at.isoformat(),
user_id=page.session.user_id or None,
messages=messages,
active_stream=active_stream_info,
has_more_messages=page.has_more,
oldest_sequence=page.oldest_sequence,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
metadata=page.session.metadata,
)
@@ -433,8 +482,9 @@ async def get_copilot_usage(
Returns current token usage vs limits for daily and weekly windows.
Global defaults sourced from LaunchDarkly (falling back to config).
Includes the user's rate-limit tier.
"""
daily_limit, weekly_limit = await get_global_rate_limits(
daily_limit, weekly_limit, tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
return await get_usage_status(
@@ -442,6 +492,7 @@ async def get_copilot_usage(
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
rate_limit_reset_cost=config.rate_limit_reset_cost,
tier=tier,
)
@@ -493,7 +544,7 @@ async def reset_copilot_usage(
detail="Rate limit reset is not available (credit system is disabled).",
)
daily_limit, weekly_limit = await get_global_rate_limits(
daily_limit, weekly_limit, tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
@@ -527,10 +578,13 @@ async def reset_copilot_usage(
try:
# Verify the user is actually at or over their daily limit.
# (rate_limit_reset_cost intentionally omitted — this object is only
# used for limit checks, not returned to the client.)
usage_status = await get_usage_status(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
tier=tier,
)
if daily_limit > 0 and usage_status.daily.used < daily_limit:
raise HTTPException(
@@ -606,6 +660,7 @@ async def reset_copilot_usage(
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
rate_limit_reset_cost=config.rate_limit_reset_cost,
tier=tier,
)
return RateLimitResetResponse(
@@ -716,7 +771,7 @@ async def stream_chat_post(
# Global defaults sourced from LaunchDarkly, falling back to config.
if user_id:
try:
daily_limit, weekly_limit = await get_global_rate_limits(
daily_limit, weekly_limit, _ = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
await check_rate_limit(
@@ -811,6 +866,7 @@ async def stream_chat_post(
is_user_message=request.is_user_message,
context=request.context,
file_ids=sanitized_file_ids,
mode=request.mode,
)
setup_time = (time.perf_counter() - stream_start_time) * 1000
@@ -1174,7 +1230,7 @@ async def health_check() -> dict:
)
# Create and retrieve session to verify full data layer
session = await create_chat_session(health_check_user_id)
session = await create_chat_session(health_check_user_id, dry_run=False)
await get_chat_session(session.session_id, health_check_user_id)
return {

View File

@@ -9,6 +9,7 @@ import pytest
import pytest_mock
from backend.api.features.chat import routes as chat_routes
from backend.copilot.rate_limit import SubscriptionTier
app = fastapi.FastAPI()
app.include_router(chat_routes.router)
@@ -331,14 +332,28 @@ def _mock_usage(
*,
daily_used: int = 500,
weekly_used: int = 2000,
daily_limit: int = 10000,
weekly_limit: int = 50000,
tier: "SubscriptionTier" = SubscriptionTier.FREE,
) -> AsyncMock:
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
"""Mock get_usage_status and get_global_rate_limits for usage endpoint tests.
Mocks both ``get_global_rate_limits`` (returns the given limits + tier) and
``get_usage_status`` so that tests exercise the endpoint without hitting
LaunchDarkly or Prisma.
"""
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
mocker.patch(
"backend.api.features.chat.routes.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(daily_limit, weekly_limit, tier),
)
resets_at = datetime.now(UTC) + timedelta(days=1)
status = CoPilotUsageStatus(
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
daily=UsageWindow(used=daily_used, limit=daily_limit, resets_at=resets_at),
weekly=UsageWindow(used=weekly_used, limit=weekly_limit, resets_at=resets_at),
)
return mocker.patch(
"backend.api.features.chat.routes.get_usage_status",
@@ -369,6 +384,7 @@ def test_usage_returns_daily_and_weekly(
daily_token_limit=10000,
weekly_token_limit=50000,
rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost,
tier=SubscriptionTier.FREE,
)
@@ -376,11 +392,9 @@ def test_usage_uses_config_limits(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
mock_get = _mock_usage(mocker)
"""The endpoint forwards resolved limits from get_global_rate_limits to get_usage_status."""
mock_get = _mock_usage(mocker, daily_limit=99999, weekly_limit=77777)
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 500)
response = client.get("/usage")
@@ -391,6 +405,7 @@ def test_usage_uses_config_limits(
daily_token_limit=99999,
weekly_token_limit=77777,
rate_limit_reset_cost=500,
tier=SubscriptionTier.FREE,
)
@@ -469,3 +484,98 @@ def test_suggested_prompts_empty_prompts(
assert response.status_code == 200
assert response.json() == {"themes": []}
# ─── Create session: dry_run contract ─────────────────────────────────
def _mock_create_chat_session(mocker: pytest_mock.MockerFixture):
"""Mock create_chat_session to return a fake session."""
from backend.copilot.model import ChatSession
async def _fake_create(user_id: str, *, dry_run: bool):
return ChatSession.new(user_id, dry_run=dry_run)
return mocker.patch(
"backend.api.features.chat.routes.create_chat_session",
new_callable=AsyncMock,
side_effect=_fake_create,
)
def test_create_session_dry_run_true(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""Sending ``{"dry_run": true}`` sets metadata.dry_run to True."""
_mock_create_chat_session(mocker)
response = client.post("/sessions", json={"dry_run": True})
assert response.status_code == 200
assert response.json()["metadata"]["dry_run"] is True
def test_create_session_dry_run_default_false(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""Empty body defaults dry_run to False."""
_mock_create_chat_session(mocker)
response = client.post("/sessions")
assert response.status_code == 200
assert response.json()["metadata"]["dry_run"] is False
def test_create_session_rejects_nested_metadata(
test_user_id: str,
) -> None:
"""Sending ``{"metadata": {"dry_run": true}}`` must return 422, not silently
default to ``dry_run=False``. This guards against the common mistake of
nesting dry_run inside metadata instead of providing it at the top level."""
response = client.post(
"/sessions",
json={"metadata": {"dry_run": True}},
)
assert response.status_code == 422
class TestStreamChatRequestModeValidation:
"""Pydantic-level validation of the ``mode`` field on StreamChatRequest."""
def test_rejects_invalid_mode_value(self) -> None:
"""Any string outside the Literal set must raise ValidationError."""
from pydantic import ValidationError
from backend.api.features.chat.routes import StreamChatRequest
with pytest.raises(ValidationError):
StreamChatRequest(message="hi", mode="turbo") # type: ignore[arg-type]
def test_accepts_fast_mode(self) -> None:
from backend.api.features.chat.routes import StreamChatRequest
req = StreamChatRequest(message="hi", mode="fast")
assert req.mode == "fast"
def test_accepts_extended_thinking_mode(self) -> None:
from backend.api.features.chat.routes import StreamChatRequest
req = StreamChatRequest(message="hi", mode="extended_thinking")
assert req.mode == "extended_thinking"
def test_accepts_none_mode(self) -> None:
"""``mode=None`` is valid (server decides via feature flags)."""
from backend.api.features.chat.routes import StreamChatRequest
req = StreamChatRequest(message="hi", mode=None)
assert req.mode is None
def test_mode_defaults_to_none_when_omitted(self) -> None:
from backend.api.features.chat.routes import StreamChatRequest
req = StreamChatRequest(message="hi")
assert req.mode is None

View File

@@ -40,11 +40,15 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
from backend.integrations.credentials_store import provider_matches
from backend.integrations.credentials_store import (
is_system_credential,
provider_matches,
)
from backend.integrations.creds_manager import (
IntegrationCredentialsManager,
create_mcp_oauth_handler,
)
from backend.integrations.managed_credentials import ensure_managed_credentials
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
@@ -110,6 +114,7 @@ class CredentialsMetaResponse(BaseModel):
default=None,
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
)
is_managed: bool = False
@model_validator(mode="before")
@classmethod
@@ -148,6 +153,7 @@ def to_meta_response(cred: Credentials) -> CredentialsMetaResponse:
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred),
is_managed=cred.is_managed,
)
@@ -224,6 +230,9 @@ async def callback(
async def list_credentials(
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
# Fire-and-forget: provision missing managed credentials in the background.
# The credential appears on the next page load; listing is never blocked.
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
credentials = await creds_manager.store.get_all_creds(user_id)
return [
@@ -238,6 +247,7 @@ async def list_credentials_by_provider(
],
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
return [
@@ -332,6 +342,11 @@ async def delete_credentials(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if is_system_credential(cred_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="System-managed credentials cannot be deleted",
)
creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
if not creds:
raise HTTPException(
@@ -342,6 +357,11 @@ async def delete_credentials(
status_code=status.HTTP_404_NOT_FOUND,
detail="Credentials not found",
)
if creds.is_managed:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="AutoGPT-managed credentials cannot be deleted",
)
try:
await remove_all_webhooks_for_credentials(user_id, creds, force)

View File

@@ -1,6 +1,7 @@
"""Tests for credentials API security: no secret leakage, SDK defaults filtered."""
from unittest.mock import AsyncMock, patch
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import fastapi
import fastapi.testclient
@@ -276,3 +277,294 @@ class TestCreateCredentialNoSecretInResponse:
assert resp.status_code == 403
mock_mgr.create.assert_not_called()
class TestManagedCredentials:
"""AutoGPT-managed credentials cannot be deleted by users."""
def test_delete_is_managed_returns_403(self):
cred = APIKeyCredentials(
id="managed-cred-1",
provider="agent_mail",
title="AgentMail (managed by AutoGPT)",
api_key=SecretStr("sk-managed-key"),
is_managed=True,
)
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_creds_by_id = AsyncMock(return_value=cred)
resp = client.request("DELETE", "/agent_mail/credentials/managed-cred-1")
assert resp.status_code == 403
assert "AutoGPT-managed" in resp.json()["detail"]
def test_list_credentials_includes_is_managed_field(self):
managed = APIKeyCredentials(
id="managed-1",
provider="agent_mail",
title="AgentMail (managed)",
api_key=SecretStr("sk-key"),
is_managed=True,
)
regular = APIKeyCredentials(
id="regular-1",
provider="openai",
title="My Key",
api_key=SecretStr("sk-key"),
)
with patch(
"backend.api.features.integrations.router.creds_manager"
) as mock_mgr:
mock_mgr.store.get_all_creds = AsyncMock(return_value=[managed, regular])
resp = client.get("/credentials")
assert resp.status_code == 200
data = resp.json()
managed_cred = next(c for c in data if c["id"] == "managed-1")
regular_cred = next(c for c in data if c["id"] == "regular-1")
assert managed_cred["is_managed"] is True
assert regular_cred["is_managed"] is False
# ---------------------------------------------------------------------------
# Managed credential provisioning infrastructure
# ---------------------------------------------------------------------------
def _make_managed_cred(
provider: str = "agent_mail", pod_id: str = "pod-abc"
) -> APIKeyCredentials:
return APIKeyCredentials(
id="managed-auto",
provider=provider,
title="AgentMail (managed by AutoGPT)",
api_key=SecretStr("sk-pod-key"),
is_managed=True,
metadata={"pod_id": pod_id},
)
def _make_store_mock(**kwargs) -> MagicMock:
"""Create a store mock with a working async ``locks()`` context manager."""
@asynccontextmanager
async def _noop_locked(key):
yield
locks_obj = MagicMock()
locks_obj.locked = _noop_locked
store = MagicMock(**kwargs)
store.locks = AsyncMock(return_value=locks_obj)
return store
class TestEnsureManagedCredentials:
"""Unit tests for the ensure/cleanup helpers in managed_credentials.py."""
@pytest.mark.asyncio
async def test_provisions_when_missing(self):
"""Provider.provision() is called when no managed credential exists."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock(return_value=cred)
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=False)
store.add_managed_credential = AsyncMock()
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_awaited_once_with("user-1")
store.add_managed_credential.assert_awaited_once_with("user-1", cred)
@pytest.mark.asyncio
async def test_skips_when_already_exists(self):
"""Provider.provision() is NOT called when managed credential exists."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock()
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=True)
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_not_awaited()
@pytest.mark.asyncio
async def test_skips_when_unavailable(self):
"""Provider.provision() is NOT called when provider is not available."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=False)
provider.provision = AsyncMock()
store = _make_store_mock()
store.has_managed_credential = AsyncMock()
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
provider.provision.assert_not_awaited()
store.has_managed_credential.assert_not_awaited()
@pytest.mark.asyncio
async def test_provision_failure_does_not_propagate(self):
"""A failed provision is logged but does not raise."""
from backend.integrations.managed_credentials import (
_PROVIDERS,
_provisioned_users,
ensure_managed_credentials,
)
provider = MagicMock()
provider.provider_name = "test_provider"
provider.is_available = AsyncMock(return_value=True)
provider.provision = AsyncMock(side_effect=RuntimeError("boom"))
store = _make_store_mock()
store.has_managed_credential = AsyncMock(return_value=False)
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["test_provider"] = provider
_provisioned_users.pop("user-1", None)
try:
await ensure_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
_provisioned_users.pop("user-1", None)
# No exception raised — provisioning failure is swallowed.
class TestCleanupManagedCredentials:
"""Unit tests for cleanup_managed_credentials."""
@pytest.mark.asyncio
async def test_calls_deprovision_for_managed_creds(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "agent_mail"
provider.deprovision = AsyncMock()
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[cred])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["agent_mail"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
provider.deprovision.assert_awaited_once_with("user-1", cred)
@pytest.mark.asyncio
async def test_skips_non_managed_creds(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
regular = _make_api_key_cred()
provider = MagicMock()
provider.provider_name = "openai"
provider.deprovision = AsyncMock()
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[regular])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["openai"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
provider.deprovision.assert_not_awaited()
@pytest.mark.asyncio
async def test_deprovision_failure_does_not_propagate(self):
from backend.integrations.managed_credentials import (
_PROVIDERS,
cleanup_managed_credentials,
)
cred = _make_managed_cred()
provider = MagicMock()
provider.provider_name = "agent_mail"
provider.deprovision = AsyncMock(side_effect=RuntimeError("boom"))
store = MagicMock()
store.get_all_creds = AsyncMock(return_value=[cred])
saved = dict(_PROVIDERS)
_PROVIDERS.clear()
_PROVIDERS["agent_mail"] = provider
try:
await cleanup_managed_credentials("user-1", store)
finally:
_PROVIDERS.clear()
_PROVIDERS.update(saved)
# No exception raised — cleanup failure is swallowed.

View File

@@ -481,6 +481,11 @@ async def create_library_agent(
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
),
**(
{"Folder": {"connect": {"id": folder_id}}}
if folder_id and graph_entry is graph
else {}
),
},
},
include=library_agent_include(

View File

@@ -12,6 +12,7 @@ Tests cover:
5. Complete OAuth flow end-to-end
"""
import asyncio
import base64
import hashlib
import secrets
@@ -58,14 +59,27 @@ async def test_user(server, test_user_id: str):
yield test_user_id
# Cleanup - delete in correct order due to foreign key constraints
await PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id})
await PrismaOAuthRefreshToken.prisma().delete_many(where={"userId": test_user_id})
await PrismaOAuthAuthorizationCode.prisma().delete_many(
where={"userId": test_user_id}
)
await PrismaOAuthApplication.prisma().delete_many(where={"ownerId": test_user_id})
await PrismaUser.prisma().delete(where={"id": test_user_id})
# Cleanup - delete in correct order due to foreign key constraints.
# Wrap in try/except because the event loop or Prisma engine may already
# be closed during session teardown on Python 3.12+.
try:
await asyncio.gather(
PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id}),
PrismaOAuthRefreshToken.prisma().delete_many(
where={"userId": test_user_id}
),
PrismaOAuthAuthorizationCode.prisma().delete_many(
where={"userId": test_user_id}
),
)
await asyncio.gather(
PrismaOAuthApplication.prisma().delete_many(
where={"ownerId": test_user_id}
),
PrismaUser.prisma().delete(where={"id": test_user_id}),
)
except RuntimeError:
pass
@pytest_asyncio.fixture

View File

@@ -0,0 +1,61 @@
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import pytest
from backend.api.features.v1 import v1_router
app = fastapi.FastAPI()
app.include_router(v1_router)
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def test_onboarding_profile_success(mocker):
mock_extract = mocker.patch(
"backend.api.features.v1.extract_business_understanding",
new_callable=AsyncMock,
)
mock_upsert = mocker.patch(
"backend.api.features.v1.upsert_business_understanding",
new_callable=AsyncMock,
)
from backend.data.understanding import BusinessUnderstandingInput
mock_extract.return_value = BusinessUnderstandingInput.model_construct(
user_name="John",
user_role="Founder/CEO",
pain_points=["Finding leads"],
suggested_prompts={"Learn": ["How do I automate lead gen?"]},
)
mock_upsert.return_value = AsyncMock()
response = client.post(
"/onboarding/profile",
json={
"user_name": "John",
"user_role": "Founder/CEO",
"pain_points": ["Finding leads", "Email & outreach"],
},
)
assert response.status_code == 200
mock_extract.assert_awaited_once()
mock_upsert.assert_awaited_once()
def test_onboarding_profile_missing_fields():
response = client.post(
"/onboarding/profile",
json={"user_name": "John"},
)
assert response.status_code == 422

View File

@@ -189,6 +189,7 @@ async def test_create_store_submission(mocker):
notifyOnAgentApproved=True,
notifyOnAgentRejected=True,
timezone="Europe/Delft",
subscriptionTier=prisma.enums.SubscriptionTier.FREE, # type: ignore[reportCallIssue,reportAttributeAccessIssue]
)
mock_agent = prisma.models.AgentGraph(
id="agent-id",

View File

@@ -63,12 +63,17 @@ from backend.data.onboarding import (
UserOnboardingUpdate,
complete_onboarding_step,
complete_re_run_agent,
format_onboarding_for_extraction,
get_recommended_agents,
get_user_onboarding,
onboarding_enabled,
reset_user_onboarding,
update_user_onboarding,
)
from backend.data.tally import extract_business_understanding
from backend.data.understanding import (
BusinessUnderstandingInput,
upsert_business_understanding,
)
from backend.data.user import (
get_or_create_user,
get_user_by_id,
@@ -282,35 +287,33 @@ async def get_onboarding_agents(
return await get_recommended_agents(user_id)
class OnboardingStatusResponse(pydantic.BaseModel):
"""Response for onboarding status check."""
class OnboardingProfileRequest(pydantic.BaseModel):
"""Request body for onboarding profile submission."""
is_onboarding_enabled: bool
is_chat_enabled: bool
user_name: str = pydantic.Field(min_length=1, max_length=100)
user_role: str = pydantic.Field(min_length=1, max_length=100)
pain_points: list[str] = pydantic.Field(default_factory=list, max_length=20)
class OnboardingStatusResponse(pydantic.BaseModel):
"""Response for onboarding completion check."""
is_completed: bool
@v1_router.get(
"/onboarding/enabled",
summary="Is onboarding enabled",
"/onboarding/completed",
summary="Check if onboarding is completed",
tags=["onboarding", "public"],
response_model=OnboardingStatusResponse,
dependencies=[Security(requires_user)],
)
async def is_onboarding_enabled(
async def is_onboarding_completed(
user_id: Annotated[str, Security(get_user_id)],
) -> OnboardingStatusResponse:
# Check if chat is enabled for user
is_chat_enabled = await is_feature_enabled(Flag.CHAT, user_id, False)
# If chat is enabled, skip legacy onboarding
if is_chat_enabled:
return OnboardingStatusResponse(
is_onboarding_enabled=False,
is_chat_enabled=True,
)
user_onboarding = await get_user_onboarding(user_id)
return OnboardingStatusResponse(
is_onboarding_enabled=await onboarding_enabled(),
is_chat_enabled=False,
is_completed=OnboardingStep.VISIT_COPILOT in user_onboarding.completedSteps,
)
@@ -325,6 +328,38 @@ async def reset_onboarding(user_id: Annotated[str, Security(get_user_id)]):
return await reset_user_onboarding(user_id)
@v1_router.post(
"/onboarding/profile",
summary="Submit onboarding profile",
tags=["onboarding"],
dependencies=[Security(requires_user)],
)
async def submit_onboarding_profile(
data: OnboardingProfileRequest,
user_id: Annotated[str, Security(get_user_id)],
):
formatted = format_onboarding_for_extraction(
user_name=data.user_name,
user_role=data.user_role,
pain_points=data.pain_points,
)
try:
understanding_input = await extract_business_understanding(formatted)
except Exception:
understanding_input = BusinessUnderstandingInput.model_construct()
# Ensure the direct fields are set even if LLM missed them
understanding_input.user_name = data.user_name
understanding_input.user_role = data.user_role
if not understanding_input.pain_points:
understanding_input.pain_points = data.pain_points
await upsert_business_understanding(user_id, understanding_input)
return {"status": "ok"}
########################################################
##################### Blocks ###########################
########################################################

View File

@@ -12,7 +12,7 @@ import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi import Query, UploadFile
from fastapi.responses import Response
from pydantic import BaseModel
from pydantic import BaseModel, Field
from backend.data.workspace import (
WorkspaceFile,
@@ -131,9 +131,26 @@ class StorageUsageResponse(BaseModel):
file_count: int
class WorkspaceFileItem(BaseModel):
id: str
name: str
path: str
mime_type: str
size_bytes: int
metadata: dict = Field(default_factory=dict)
created_at: str
class ListFilesResponse(BaseModel):
files: list[WorkspaceFileItem]
offset: int = 0
has_more: bool = False
@router.get(
"/files/{file_id}/download",
summary="Download file by ID",
operation_id="getWorkspaceDownloadFileById",
)
async def download_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
@@ -158,6 +175,7 @@ async def download_file(
@router.delete(
"/files/{file_id}",
summary="Delete a workspace file",
operation_id="deleteWorkspaceFile",
)
async def delete_workspace_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
@@ -183,6 +201,7 @@ async def delete_workspace_file(
@router.post(
"/files/upload",
summary="Upload file to workspace",
operation_id="uploadWorkspaceFile",
)
async def upload_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
@@ -196,6 +215,9 @@ async def upload_file(
Files are stored in session-scoped paths when session_id is provided,
so the agent's session-scoped tools can discover them automatically.
"""
# Empty-string session_id drops session scoping; normalize to None.
session_id = session_id or None
config = Config()
# Sanitize filename — strip any directory components
@@ -250,16 +272,27 @@ async def upload_file(
manager = WorkspaceManager(user_id, workspace.id, session_id)
try:
workspace_file = await manager.write_file(
content, filename, overwrite=overwrite
content, filename, overwrite=overwrite, metadata={"origin": "user-upload"}
)
except ValueError as e:
raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
# write_file raises ValueError for both path-conflict and size-limit
# cases; map each to its correct HTTP status.
message = str(e)
if message.startswith("File too large"):
raise fastapi.HTTPException(status_code=413, detail=message) from e
raise fastapi.HTTPException(status_code=409, detail=message) from e
# Post-write storage check — eliminates TOCTOU race on the quota.
# If a concurrent upload pushed us over the limit, undo this write.
new_total = await get_workspace_total_size(workspace.id)
if storage_limit_bytes and new_total > storage_limit_bytes:
await soft_delete_workspace_file(workspace_file.id, workspace.id)
try:
await soft_delete_workspace_file(workspace_file.id, workspace.id)
except Exception as e:
logger.warning(
f"Failed to soft-delete over-quota file {workspace_file.id} "
f"in workspace {workspace.id}: {e}"
)
raise fastapi.HTTPException(
status_code=413,
detail={
@@ -281,6 +314,7 @@ async def upload_file(
@router.get(
"/storage/usage",
summary="Get workspace storage usage",
operation_id="getWorkspaceStorageUsage",
)
async def get_storage_usage(
user_id: Annotated[str, fastapi.Security(get_user_id)],
@@ -301,3 +335,57 @@ async def get_storage_usage(
used_percent=round((used_bytes / limit_bytes) * 100, 1) if limit_bytes else 0,
file_count=file_count,
)
@router.get(
"/files",
summary="List workspace files",
operation_id="listWorkspaceFiles",
)
async def list_workspace_files(
user_id: Annotated[str, fastapi.Security(get_user_id)],
session_id: str | None = Query(default=None),
limit: int = Query(default=200, ge=1, le=1000),
offset: int = Query(default=0, ge=0),
) -> ListFilesResponse:
"""
List files in the user's workspace.
When session_id is provided, only files for that session are returned.
Otherwise, all files across sessions are listed. Results are paginated
via `limit`/`offset`; `has_more` indicates whether additional pages exist.
"""
workspace = await get_or_create_workspace(user_id)
# Treat empty-string session_id the same as omitted — an empty value
# would otherwise silently list files across every session instead of
# scoping to one.
session_id = session_id or None
manager = WorkspaceManager(user_id, workspace.id, session_id)
include_all = session_id is None
# Fetch one extra to compute has_more without a separate count query.
files = await manager.list_files(
limit=limit + 1,
offset=offset,
include_all_sessions=include_all,
)
has_more = len(files) > limit
page = files[:limit]
return ListFilesResponse(
files=[
WorkspaceFileItem(
id=f.id,
name=f.name,
path=f.path,
mime_type=f.mime_type,
size_bytes=f.size_bytes,
metadata=f.metadata or {},
created_at=f.created_at.isoformat(),
)
for f in page
],
offset=offset,
has_more=has_more,
)

View File

@@ -1,48 +1,28 @@
"""Tests for workspace file upload and download routes."""
import io
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from backend.api.features.workspace import routes as workspace_routes
from backend.data.workspace import WorkspaceFile
from backend.api.features.workspace.routes import router
from backend.data.workspace import Workspace, WorkspaceFile
app = fastapi.FastAPI()
app.include_router(workspace_routes.router)
app.include_router(router)
@app.exception_handler(ValueError)
async def _value_error_handler(
request: fastapi.Request, exc: ValueError
) -> fastapi.responses.JSONResponse:
"""Mirror the production ValueError → 400 mapping from rest_api.py."""
"""Mirror the production ValueError → 400 mapping from the REST app."""
return fastapi.responses.JSONResponse(status_code=400, content={"detail": str(exc)})
client = fastapi.testclient.TestClient(app)
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})()
_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)
MOCK_FILE = WorkspaceFile(
id="file-aaa-bbb",
workspace_id="ws-1",
created_at=_NOW,
updated_at=_NOW,
name="hello.txt",
path="/session/hello.txt",
mime_type="text/plain",
size_bytes=13,
storage_path="local://hello.txt",
)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
@@ -53,25 +33,201 @@ def setup_app_auth(mock_jwt_user):
app.dependency_overrides.clear()
def _make_workspace(user_id: str = "test-user-id") -> Workspace:
return Workspace(
id="ws-001",
user_id=user_id,
created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
)
def _make_file(**overrides) -> WorkspaceFile:
defaults = {
"id": "file-001",
"workspace_id": "ws-001",
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
"updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
"name": "test.txt",
"path": "/test.txt",
"storage_path": "local://test.txt",
"mime_type": "text/plain",
"size_bytes": 100,
"checksum": None,
"is_deleted": False,
"deleted_at": None,
"metadata": {},
}
defaults.update(overrides)
return WorkspaceFile(**defaults)
def _make_file_mock(**overrides) -> MagicMock:
"""Create a mock WorkspaceFile to simulate DB records with null fields."""
defaults = {
"id": "file-001",
"name": "test.txt",
"path": "/test.txt",
"mime_type": "text/plain",
"size_bytes": 100,
"metadata": {},
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
}
defaults.update(overrides)
mock = MagicMock(spec=WorkspaceFile)
for k, v in defaults.items():
setattr(mock, k, v)
return mock
# -- list_workspace_files tests --
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_returns_all_when_no_session(mock_manager_cls, mock_get_workspace):
mock_get_workspace.return_value = _make_workspace()
files = [
_make_file(id="f1", name="a.txt", metadata={"origin": "user-upload"}),
_make_file(id="f2", name="b.csv", metadata={"origin": "agent-created"}),
]
mock_instance = AsyncMock()
mock_instance.list_files.return_value = files
mock_manager_cls.return_value = mock_instance
response = client.get("/files")
assert response.status_code == 200
data = response.json()
assert len(data["files"]) == 2
assert data["has_more"] is False
assert data["offset"] == 0
assert data["files"][0]["id"] == "f1"
assert data["files"][0]["metadata"] == {"origin": "user-upload"}
assert data["files"][1]["id"] == "f2"
mock_instance.list_files.assert_called_once_with(
limit=201, offset=0, include_all_sessions=True
)
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_scopes_to_session_when_provided(
mock_manager_cls, mock_get_workspace, test_user_id
):
mock_get_workspace.return_value = _make_workspace(user_id=test_user_id)
mock_instance = AsyncMock()
mock_instance.list_files.return_value = []
mock_manager_cls.return_value = mock_instance
response = client.get("/files?session_id=sess-123")
assert response.status_code == 200
data = response.json()
assert data["files"] == []
assert data["has_more"] is False
mock_manager_cls.assert_called_once_with(test_user_id, "ws-001", "sess-123")
mock_instance.list_files.assert_called_once_with(
limit=201, offset=0, include_all_sessions=False
)
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_null_metadata_coerced_to_empty_dict(
mock_manager_cls, mock_get_workspace
):
"""Route uses `f.metadata or {}` for pre-existing files with null metadata."""
mock_get_workspace.return_value = _make_workspace()
mock_instance = AsyncMock()
mock_instance.list_files.return_value = [_make_file_mock(metadata=None)]
mock_manager_cls.return_value = mock_instance
response = client.get("/files")
assert response.status_code == 200
assert response.json()["files"][0]["metadata"] == {}
# -- upload_file metadata tests --
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.get_workspace_total_size")
@patch("backend.api.features.workspace.routes.scan_content_safe")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_upload_passes_user_upload_origin_metadata(
mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
):
mock_get_workspace.return_value = _make_workspace()
mock_total_size.return_value = 100
written = _make_file(id="new-file", name="doc.pdf")
mock_instance = AsyncMock()
mock_instance.write_file.return_value = written
mock_manager_cls.return_value = mock_instance
response = client.post(
"/files/upload",
files={"file": ("doc.pdf", b"fake-pdf-content", "application/pdf")},
)
assert response.status_code == 200
mock_instance.write_file.assert_called_once()
call_kwargs = mock_instance.write_file.call_args
assert call_kwargs.kwargs.get("metadata") == {"origin": "user-upload"}
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.get_workspace_total_size")
@patch("backend.api.features.workspace.routes.scan_content_safe")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_upload_returns_409_on_file_conflict(
mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
):
mock_get_workspace.return_value = _make_workspace()
mock_total_size.return_value = 100
mock_instance = AsyncMock()
mock_instance.write_file.side_effect = ValueError("File already exists at path")
mock_manager_cls.return_value = mock_instance
response = client.post(
"/files/upload",
files={"file": ("dup.txt", b"content", "text/plain")},
)
assert response.status_code == 409
assert "already exists" in response.json()["detail"]
# -- Restored upload/download/delete security + invariant tests --
def _upload(
filename: str = "hello.txt",
content: bytes = b"Hello, world!",
content_type: str = "text/plain",
):
"""Helper to POST a file upload."""
return client.post(
"/files/upload?session_id=sess-1",
files={"file": (filename, io.BytesIO(content), content_type)},
)
# ---- Happy path ----
_MOCK_FILE = WorkspaceFile(
id="file-aaa-bbb",
workspace_id="ws-001",
created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
name="hello.txt",
path="/sessions/sess-1/hello.txt",
mime_type="text/plain",
size_bytes=13,
storage_path="local://hello.txt",
)
def test_upload_happy_path(mocker: pytest_mock.MockFixture):
def test_upload_happy_path(mocker):
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
@@ -82,7 +238,7 @@ def test_upload_happy_path(mocker: pytest_mock.MockFixture):
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -96,10 +252,7 @@ def test_upload_happy_path(mocker: pytest_mock.MockFixture):
assert data["size_bytes"] == 13
# ---- Per-file size limit ----
def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
def test_upload_exceeds_max_file_size(mocker):
"""Files larger than max_file_size_mb should be rejected with 413."""
cfg = mocker.patch("backend.api.features.workspace.routes.Config")
cfg.return_value.max_file_size_mb = 0 # 0 MB → any content is too big
@@ -109,15 +262,11 @@ def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
assert response.status_code == 413
# ---- Storage quota exceeded ----
def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
def test_upload_storage_quota_exceeded(mocker):
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
return_value=_make_workspace(),
)
# Current usage already at limit
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=500 * 1024 * 1024,
@@ -128,27 +277,22 @@ def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
assert "Storage limit exceeded" in response.text
# ---- Post-write quota race (B2) ----
def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
"""If a concurrent upload tips the total over the limit after write,
the file should be soft-deleted and 413 returned."""
def test_upload_post_write_quota_race(mocker):
"""Concurrent upload tipping over limit after write should soft-delete + 413."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
return_value=_make_workspace(),
)
# Pre-write check passes (under limit), but post-write check fails
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
side_effect=[0, 600 * 1024 * 1024], # first call OK, second over limit
side_effect=[0, 600 * 1024 * 1024],
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -160,17 +304,14 @@ def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
response = _upload()
assert response.status_code == 413
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-1")
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-001")
# ---- Any extension accepted (no allowlist) ----
def test_upload_any_extension(mocker: pytest_mock.MockFixture):
def test_upload_any_extension(mocker):
"""Any file extension should be accepted — ClamAV is the security layer."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
@@ -181,7 +322,7 @@ def test_upload_any_extension(mocker: pytest_mock.MockFixture):
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -191,16 +332,13 @@ def test_upload_any_extension(mocker: pytest_mock.MockFixture):
assert response.status_code == 200
# ---- Virus scan rejection ----
def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
def test_upload_blocked_by_virus_scan(mocker):
"""Files flagged by ClamAV should be rejected and never written to storage."""
from backend.api.features.store.exceptions import VirusDetectedError
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
@@ -211,7 +349,7 @@ def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
side_effect=VirusDetectedError("Eicar-Test-Signature"),
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -219,18 +357,14 @@ def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
assert response.status_code == 400
assert "Virus detected" in response.text
mock_manager.write_file.assert_not_called()
# ---- No file extension ----
def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
def test_upload_file_without_extension(mocker):
"""Files without an extension should be accepted and stored as-is."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
@@ -241,7 +375,7 @@ def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
@@ -257,14 +391,11 @@ def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
assert mock_manager.write_file.call_args[0][1] == "Makefile"
# ---- Filename sanitization (SF5) ----
def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
def test_upload_strips_path_components(mocker):
"""Path-traversal filenames should be reduced to their basename."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
@@ -275,28 +406,23 @@ def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
# Filename with traversal
_upload(filename="../../etc/passwd.txt")
# write_file should have been called with just the basename
mock_manager.write_file.assert_called_once()
call_args = mock_manager.write_file.call_args
assert call_args[0][1] == "passwd.txt"
# ---- Download ----
def test_download_file_not_found(mocker: pytest_mock.MockFixture):
def test_download_file_not_found(mocker):
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=MOCK_WORKSPACE,
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_file",
@@ -307,14 +433,11 @@ def test_download_file_not_found(mocker: pytest_mock.MockFixture):
assert response.status_code == 404
# ---- Delete ----
def test_delete_file_success(mocker: pytest_mock.MockFixture):
def test_delete_file_success(mocker):
"""Deleting an existing file should return {"deleted": true}."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=MOCK_WORKSPACE,
return_value=_make_workspace(),
)
mock_manager = mocker.MagicMock()
mock_manager.delete_file = mocker.AsyncMock(return_value=True)
@@ -329,11 +452,11 @@ def test_delete_file_success(mocker: pytest_mock.MockFixture):
mock_manager.delete_file.assert_called_once_with("file-aaa-bbb")
def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
def test_delete_file_not_found(mocker):
"""Deleting a non-existent file should return 404."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=MOCK_WORKSPACE,
return_value=_make_workspace(),
)
mock_manager = mocker.MagicMock()
mock_manager.delete_file = mocker.AsyncMock(return_value=False)
@@ -347,7 +470,7 @@ def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
assert "File not found" in response.text
def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
def test_delete_file_no_workspace(mocker):
"""Deleting when user has no workspace should return 404."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
@@ -357,3 +480,123 @@ def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
response = client.delete("/files/file-aaa-bbb")
assert response.status_code == 404
assert "Workspace not found" in response.text
def test_upload_write_file_too_large_returns_413(mocker):
"""write_file raises ValueError("File too large: …") → must map to 413."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(
side_effect=ValueError("File too large: 900 bytes exceeds 1MB limit")
)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload()
assert response.status_code == 413
assert "File too large" in response.text
def test_upload_write_file_conflict_returns_409(mocker):
"""Non-'File too large' ValueErrors from write_file stay as 409."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=_make_workspace(),
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(
side_effect=ValueError("File already exists at path: /sessions/x/a.txt")
)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload()
assert response.status_code == 409
assert "already exists" in response.text
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_has_more_true_when_limit_exceeded(
mock_manager_cls, mock_get_workspace
):
"""The limit+1 fetch trick must flip has_more=True and trim the page."""
mock_get_workspace.return_value = _make_workspace()
# Backend was asked for limit+1=3, and returned exactly 3 items.
files = [
_make_file(id="f1", name="a.txt"),
_make_file(id="f2", name="b.txt"),
_make_file(id="f3", name="c.txt"),
]
mock_instance = AsyncMock()
mock_instance.list_files.return_value = files
mock_manager_cls.return_value = mock_instance
response = client.get("/files?limit=2")
assert response.status_code == 200
data = response.json()
assert data["has_more"] is True
assert len(data["files"]) == 2
assert data["files"][0]["id"] == "f1"
assert data["files"][1]["id"] == "f2"
mock_instance.list_files.assert_called_once_with(
limit=3, offset=0, include_all_sessions=True
)
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_has_more_false_when_exactly_page_size(
mock_manager_cls, mock_get_workspace
):
"""Exactly `limit` rows means we're on the last page — has_more=False."""
mock_get_workspace.return_value = _make_workspace()
files = [_make_file(id="f1", name="a.txt"), _make_file(id="f2", name="b.txt")]
mock_instance = AsyncMock()
mock_instance.list_files.return_value = files
mock_manager_cls.return_value = mock_instance
response = client.get("/files?limit=2")
assert response.status_code == 200
data = response.json()
assert data["has_more"] is False
assert len(data["files"]) == 2
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
@patch("backend.api.features.workspace.routes.WorkspaceManager")
def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace):
mock_get_workspace.return_value = _make_workspace()
mock_instance = AsyncMock()
mock_instance.list_files.return_value = []
mock_manager_cls.return_value = mock_instance
response = client.get("/files?offset=50&limit=10")
assert response.status_code == 200
assert response.json()["offset"] == 50
mock_instance.list_files.assert_called_once_with(
limit=11, offset=50, include_all_sessions=True
)

View File

@@ -18,6 +18,7 @@ from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.platform_cost_routes
import backend.api.features.admin.rate_limit_admin_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.builder
@@ -118,6 +119,11 @@ async def lifespan_context(app: fastapi.FastAPI):
AutoRegistry.patch_integrations()
# Register managed credential providers (e.g. AgentMail)
from backend.integrations.managed_providers import register_all
register_all()
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
@@ -324,6 +330,11 @@ app.include_router(
tags=["v2", "admin"],
prefix="/api/copilot",
)
app.include_router(
backend.api.features.admin.platform_cost_routes.router,
tags=["v2", "admin"],
prefix="/api/admin",
)
app.include_router(
backend.api.features.executions.review.routes.router,
tags=["v2", "executions", "review"],

View File

@@ -698,13 +698,30 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
if should_pause:
return
# Validate the input data (original or reviewer-modified) once
if error := self.input_schema.validate_data(input_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
block_name=self.name,
block_id=self.id,
)
# Validate the input data (original or reviewer-modified) once.
# In dry-run mode, credential fields may contain sentinel None values
# that would fail JSON schema required checks. We still validate the
# non-credential fields so blocks that execute for real during dry-run
# (e.g. AgentExecutorBlock) get proper input validation.
is_dry_run = getattr(kwargs.get("execution_context"), "dry_run", False)
if is_dry_run:
cred_field_names = set(self.input_schema.get_credentials_fields().keys())
non_cred_data = {
k: v for k, v in input_data.items() if k not in cred_field_names
}
if error := self.input_schema.validate_data(non_cred_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
block_name=self.name,
block_id=self.id,
)
else:
if error := self.input_schema.validate_data(input_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
block_name=self.name,
block_id=self.id,
)
# Use the validated input data
async for output_name, output_data in self.run(

View File

@@ -49,11 +49,17 @@ class AgentExecutorBlock(Block):
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
required_fields = cls.get_input_schema(data).get("required", [])
return set(required_fields) - set(data)
# Check against the nested `inputs` dict, not the top-level node
# data — required fields like "topic" live inside data["inputs"],
# not at data["topic"].
provided = data.get("inputs", {})
return set(required_fields) - set(provided)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
return validate_with_jsonschema(cls.get_input_schema(data), data)
return validate_with_jsonschema(
cls.get_input_schema(data), data.get("inputs", {})
)
class Output(BlockSchema):
# Use BlockSchema to avoid automatic error field that could clash with graph outputs
@@ -88,6 +94,7 @@ class AgentExecutorBlock(Block):
execution_context=execution_context.model_copy(
update={"parent_execution_id": graph_exec_id},
),
dry_run=execution_context.dry_run,
)
logger = execution_utils.LogMetadata(
@@ -149,14 +156,19 @@ class AgentExecutorBlock(Block):
ExecutionStatus.TERMINATED,
ExecutionStatus.FAILED,
]:
logger.debug(
f"Execution {log_id} received event {event.event_type} with status {event.status}"
logger.info(
f"Execution {log_id} skipping event {event.event_type} status={event.status} "
f"node={getattr(event, 'node_exec_id', '?')}"
)
continue
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
# If the graph execution is COMPLETED, TERMINATED, or FAILED,
# we can stop listening for further events.
logger.info(
f"Execution {log_id} graph completed with status {event.status}, "
f"yielded {len(yielded_node_exec_ids)} outputs"
)
self.merge_stats(
NodeExecutionStats(
extra_cost=event.stats.cost if event.stats else 0,

View File

@@ -17,7 +17,7 @@ from backend.blocks.apollo.models import (
PrimaryPhone,
SearchOrganizationsRequest,
)
from backend.data.model import CredentialsField, SchemaField
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
class SearchOrganizationsBlock(Block):
@@ -218,6 +218,11 @@ To find IDs, identify the values for organization_id when you call this endpoint
) -> BlockOutput:
query = SearchOrganizationsRequest(**input_data.model_dump())
organizations = await self.search_organizations(query, credentials)
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(organizations)), provider_cost_type="items"
)
)
for organization in organizations:
yield "organization", organization
yield "organizations", organizations

View File

@@ -21,7 +21,7 @@ from backend.blocks.apollo.models import (
SearchPeopleRequest,
SenorityLevels,
)
from backend.data.model import CredentialsField, SchemaField
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
class SearchPeopleBlock(Block):
@@ -366,4 +366,9 @@ class SearchPeopleBlock(Block):
*(enrich_or_fallback(person) for person in people)
)
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(people)), provider_cost_type="items"
)
)
yield "people", people

View File

@@ -146,6 +146,21 @@ class AutoPilotBlock(Block):
advanced=True,
)
dry_run: bool = SchemaField(
description=(
"When enabled, run_block and run_agent tool calls in this "
"autopilot session are forced to use dry-run simulation mode. "
"No real API calls, side effects, or credits are consumed "
"by those tools. Useful for testing agent wiring and "
"previewing outputs. "
"Only applies when creating a new session (session_id is empty). "
"When reusing an existing session_id, the session's original "
"dry_run setting is preserved."
),
default=False,
advanced=True,
)
# timeout_seconds removed: the SDK manages its own heartbeat-based
# timeouts internally; wrapping with asyncio.timeout corrupts the
# SDK's internal stream (see service.py CRITICAL comment).
@@ -232,11 +247,11 @@ class AutoPilotBlock(Block):
},
)
async def create_session(self, user_id: str) -> str:
async def create_session(self, user_id: str, *, dry_run: bool) -> str:
"""Create a new chat session and return its ID (mockable for tests)."""
from backend.copilot.model import create_chat_session # avoid circular import
session = await create_chat_session(user_id)
session = await create_chat_session(user_id, dry_run=dry_run)
return session.session_id
async def execute_copilot(
@@ -367,7 +382,9 @@ class AutoPilotBlock(Block):
# even if the downstream stream fails (avoids orphaned sessions).
sid = input_data.session_id
if not sid:
sid = await self.create_session(execution_context.user_id)
sid = await self.create_session(
execution_context.user_id, dry_run=input_data.dry_run
)
# NOTE: No asyncio.timeout() here — the SDK manages its own
# heartbeat-based timeouts internally. Wrapping with asyncio.timeout

View File

@@ -0,0 +1,712 @@
"""Unit tests for merge_stats cost tracking in individual blocks.
Covers the exa code_context, exa contents, and apollo organization blocks
to verify provider cost is correctly extracted and reported.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, NodeExecutionStats
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
TEST_EXA_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="exa",
api_key=SecretStr("mock-exa-api-key"),
title="Mock Exa API key",
expires_at=None,
)
TEST_EXA_CREDENTIALS_INPUT = {
"provider": TEST_EXA_CREDENTIALS.provider,
"id": TEST_EXA_CREDENTIALS.id,
"type": TEST_EXA_CREDENTIALS.type,
"title": TEST_EXA_CREDENTIALS.title,
}
# ---------------------------------------------------------------------------
# ExaCodeContextBlock — cost_dollars is a string like "0.005"
# ---------------------------------------------------------------------------
class TestExaCodeContextBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_float_cost(self):
"""float(cost_dollars) parsed from API string and passed to merge_stats."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
api_response = {
"requestId": "req-1",
"query": "how to use hooks",
"response": "Here are some examples...",
"resultsCount": 3,
"costDollars": "0.005",
"searchTime": 1.2,
"outputTokens": 100,
}
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.exa.code_context.Requests.post",
new_callable=AsyncMock,
return_value=mock_resp,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = ExaCodeContextBlock.Input(
query="how to use hooks",
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
)
results = []
async for output in block.run(
input_data,
credentials=TEST_EXA_CREDENTIALS,
):
results.append(output)
assert len(accumulated) == 1
assert accumulated[0].provider_cost == pytest.approx(0.005)
@pytest.mark.asyncio
async def test_invalid_cost_dollars_does_not_raise(self):
"""When cost_dollars cannot be parsed as float, merge_stats is not called."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
api_response = {
"requestId": "req-2",
"query": "query",
"response": "response",
"resultsCount": 0,
"costDollars": "N/A",
"searchTime": 0.5,
"outputTokens": 0,
}
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
merge_calls: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.exa.code_context.Requests.post",
new_callable=AsyncMock,
return_value=mock_resp,
),
patch.object(
block, "merge_stats", side_effect=lambda s: merge_calls.append(s)
),
):
input_data = ExaCodeContextBlock.Input(
query="query",
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(
input_data,
credentials=TEST_EXA_CREDENTIALS,
):
pass
assert merge_calls == []
@pytest.mark.asyncio
async def test_zero_cost_is_tracked(self):
"""A zero cost_dollars string '0.0' should still be recorded."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
api_response = {
"requestId": "req-3",
"query": "query",
"response": "...",
"resultsCount": 1,
"costDollars": "0.0",
"searchTime": 0.1,
"outputTokens": 10,
}
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.exa.code_context.Requests.post",
new_callable=AsyncMock,
return_value=mock_resp,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = ExaCodeContextBlock.Input(
query="query",
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(
input_data,
credentials=TEST_EXA_CREDENTIALS,
):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 0.0
# ---------------------------------------------------------------------------
# ExaContentsBlock — response.cost_dollars.total (CostDollars model)
# ---------------------------------------------------------------------------
class TestExaContentsBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_cost_dollars_total(self):
"""provider_cost equals response.cost_dollars.total when present."""
from backend.blocks.exa.contents import ExaContentsBlock
from backend.blocks.exa.helpers import CostDollars
block = ExaContentsBlock()
cost_dollars = CostDollars(total=0.012)
mock_response = MagicMock()
mock_response.results = []
mock_response.context = None
mock_response.statuses = None
mock_response.cost_dollars = cost_dollars
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.exa.contents.AsyncExa",
return_value=MagicMock(
get_contents=AsyncMock(return_value=mock_response)
),
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = ExaContentsBlock.Input(
urls=["https://example.com"],
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(
input_data,
credentials=TEST_EXA_CREDENTIALS,
):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == pytest.approx(0.012)
@pytest.mark.asyncio
async def test_no_merge_stats_when_cost_dollars_absent(self):
"""When response.cost_dollars is None, merge_stats is not called."""
from backend.blocks.exa.contents import ExaContentsBlock
block = ExaContentsBlock()
mock_response = MagicMock()
mock_response.results = []
mock_response.context = None
mock_response.statuses = None
mock_response.cost_dollars = None
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.exa.contents.AsyncExa",
return_value=MagicMock(
get_contents=AsyncMock(return_value=mock_response)
),
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = ExaContentsBlock.Input(
urls=["https://example.com"],
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(
input_data,
credentials=TEST_EXA_CREDENTIALS,
):
pass
assert accumulated == []
# ---------------------------------------------------------------------------
# SearchOrganizationsBlock — provider_cost = float(len(organizations))
# ---------------------------------------------------------------------------
class TestSearchOrganizationsBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_org_count(self):
"""provider_cost == number of returned organizations, type == 'items'."""
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
)
from backend.blocks.apollo.models import Organization
from backend.blocks.apollo.organization import SearchOrganizationsBlock
block = SearchOrganizationsBlock()
fake_orgs = [Organization(id=str(i), name=f"Org{i}") for i in range(3)]
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
SearchOrganizationsBlock,
"search_organizations",
new_callable=AsyncMock,
return_value=fake_orgs,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = SearchOrganizationsBlock.Input(
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
)
results = []
async for output in block.run(
input_data,
credentials=APOLLO_CREDS,
):
results.append(output)
assert len(accumulated) == 1
assert accumulated[0].provider_cost == pytest.approx(3.0)
assert accumulated[0].provider_cost_type == "items"
@pytest.mark.asyncio
async def test_empty_org_list_tracks_zero(self):
"""An empty organization list results in provider_cost=0.0."""
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
)
from backend.blocks.apollo.organization import SearchOrganizationsBlock
block = SearchOrganizationsBlock()
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
SearchOrganizationsBlock,
"search_organizations",
new_callable=AsyncMock,
return_value=[],
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = SearchOrganizationsBlock.Input(
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(
input_data,
credentials=APOLLO_CREDS,
):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 0.0
assert accumulated[0].provider_cost_type == "items"
# ---------------------------------------------------------------------------
# JinaEmbeddingBlock — token count from usage.total_tokens
# ---------------------------------------------------------------------------
class TestJinaEmbeddingBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_token_count(self):
"""provider token count is recorded when API returns usage.total_tokens."""
from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS
from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
block = JinaEmbeddingBlock()
api_response = {
"data": [{"embedding": [0.1, 0.2, 0.3]}],
"usage": {"total_tokens": 42},
}
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.jina.embeddings.Requests.post",
new_callable=AsyncMock,
return_value=mock_resp,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = JinaEmbeddingBlock.Input(
texts=["hello world"],
credentials=JINA_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=JINA_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].input_token_count == 42
@pytest.mark.asyncio
async def test_no_merge_stats_when_usage_absent(self):
"""When API response omits usage field, merge_stats is not called."""
from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS
from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
block = JinaEmbeddingBlock()
api_response = {
"data": [{"embedding": [0.1, 0.2, 0.3]}],
}
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
accumulated: list[NodeExecutionStats] = []
with (
patch(
"backend.blocks.jina.embeddings.Requests.post",
new_callable=AsyncMock,
return_value=mock_resp,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = JinaEmbeddingBlock.Input(
texts=["hello"],
credentials=JINA_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=JINA_CREDS):
pass
assert accumulated == []
# ---------------------------------------------------------------------------
# UnrealTextToSpeechBlock — character count from input text length
# ---------------------------------------------------------------------------
class TestUnrealTextToSpeechBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_character_count(self):
"""provider_cost equals len(text) with type='characters'."""
from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
from backend.blocks.text_to_speech_block import (
TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
)
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
block = UnrealTextToSpeechBlock()
test_text = "Hello, world!"
with (
patch.object(
UnrealTextToSpeechBlock,
"call_unreal_speech_api",
new_callable=AsyncMock,
return_value={"OutputUri": "https://example.com/audio.mp3"},
),
patch.object(block, "merge_stats") as mock_merge,
):
input_data = UnrealTextToSpeechBlock.Input(
text=test_text,
credentials=TTS_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=TTS_CREDS):
pass
mock_merge.assert_called_once()
stats = mock_merge.call_args[0][0]
assert stats.provider_cost == float(len(test_text))
assert stats.provider_cost_type == "characters"
@pytest.mark.asyncio
async def test_empty_text_gives_zero_characters(self):
"""An empty text string results in provider_cost=0.0."""
from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
from backend.blocks.text_to_speech_block import (
TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
)
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
block = UnrealTextToSpeechBlock()
with (
patch.object(
UnrealTextToSpeechBlock,
"call_unreal_speech_api",
new_callable=AsyncMock,
return_value={"OutputUri": "https://example.com/audio.mp3"},
),
patch.object(block, "merge_stats") as mock_merge,
):
input_data = UnrealTextToSpeechBlock.Input(
text="",
credentials=TTS_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=TTS_CREDS):
pass
mock_merge.assert_called_once()
stats = mock_merge.call_args[0][0]
assert stats.provider_cost == 0.0
assert stats.provider_cost_type == "characters"
# ---------------------------------------------------------------------------
# GoogleMapsSearchBlock — item count from search_places results
# ---------------------------------------------------------------------------
class TestGoogleMapsSearchBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_place_count(self):
"""provider_cost equals number of returned places, type == 'items'."""
from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS
from backend.blocks.google_maps import (
TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT,
)
from backend.blocks.google_maps import GoogleMapsSearchBlock
block = GoogleMapsSearchBlock()
fake_places = [{"name": f"Place{i}", "address": f"Addr{i}"} for i in range(4)]
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
GoogleMapsSearchBlock,
"search_places",
return_value=fake_places,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = GoogleMapsSearchBlock.Input(
query="coffee shops",
credentials=MAPS_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=MAPS_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 4.0
assert accumulated[0].provider_cost_type == "items"
@pytest.mark.asyncio
async def test_empty_results_tracks_zero(self):
"""Zero places returned results in provider_cost=0.0."""
from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS
from backend.blocks.google_maps import (
TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT,
)
from backend.blocks.google_maps import GoogleMapsSearchBlock
block = GoogleMapsSearchBlock()
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
GoogleMapsSearchBlock,
"search_places",
return_value=[],
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = GoogleMapsSearchBlock.Input(
query="nothing here",
credentials=MAPS_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=MAPS_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 0.0
assert accumulated[0].provider_cost_type == "items"
# ---------------------------------------------------------------------------
# SmartLeadAddLeadsBlock — item count from lead_list length
# ---------------------------------------------------------------------------
class TestSmartLeadAddLeadsBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_lead_count(self):
"""provider_cost equals number of leads uploaded, type == 'items'."""
from backend.blocks.smartlead._auth import TEST_CREDENTIALS as SL_CREDS
from backend.blocks.smartlead._auth import (
TEST_CREDENTIALS_INPUT as SL_CREDS_INPUT,
)
from backend.blocks.smartlead.campaign import AddLeadToCampaignBlock
from backend.blocks.smartlead.models import (
AddLeadsToCampaignResponse,
LeadInput,
)
block = AddLeadToCampaignBlock()
fake_leads = [
LeadInput(first_name="Alice", last_name="A", email="alice@example.com"),
LeadInput(first_name="Bob", last_name="B", email="bob@example.com"),
]
fake_response = AddLeadsToCampaignResponse(
ok=True,
upload_count=2,
total_leads=2,
block_count=0,
duplicate_count=0,
invalid_email_count=0,
invalid_emails=[],
already_added_to_campaign=0,
unsubscribed_leads=[],
is_lead_limit_exhausted=False,
lead_import_stopped_count=0,
bounce_count=0,
)
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
AddLeadToCampaignBlock,
"add_leads_to_campaign",
new_callable=AsyncMock,
return_value=fake_response,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = AddLeadToCampaignBlock.Input(
campaign_id=123,
lead_list=fake_leads,
credentials=SL_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=SL_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 2.0
assert accumulated[0].provider_cost_type == "items"
# ---------------------------------------------------------------------------
# SearchPeopleBlock — item count from people list length
# ---------------------------------------------------------------------------
class TestSearchPeopleBlockCostTracking:
@pytest.mark.asyncio
async def test_merge_stats_called_with_people_count(self):
"""provider_cost equals number of returned people, type == 'items'."""
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
)
from backend.blocks.apollo.models import Contact
from backend.blocks.apollo.people import SearchPeopleBlock
block = SearchPeopleBlock()
fake_people = [Contact(id=str(i), first_name=f"Person{i}") for i in range(5)]
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
SearchPeopleBlock,
"search_people",
new_callable=AsyncMock,
return_value=fake_people,
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = SearchPeopleBlock.Input(
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=APOLLO_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == pytest.approx(5.0)
assert accumulated[0].provider_cost_type == "items"
@pytest.mark.asyncio
async def test_empty_people_list_tracks_zero(self):
"""An empty people list results in provider_cost=0.0."""
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
)
from backend.blocks.apollo.people import SearchPeopleBlock
block = SearchPeopleBlock()
accumulated: list[NodeExecutionStats] = []
with (
patch.object(
SearchPeopleBlock,
"search_people",
new_callable=AsyncMock,
return_value=[],
),
patch.object(
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
),
):
input_data = SearchPeopleBlock.Input(
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
)
async for _ in block.run(input_data, credentials=APOLLO_CREDS):
pass
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 0.0
assert accumulated[0].provider_cost_type == "items"

View File

@@ -9,6 +9,7 @@ from typing import Union
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -116,3 +117,10 @@ class ExaCodeContextBlock(Block):
yield "cost_dollars", context.cost_dollars
yield "search_time", context.search_time
yield "output_tokens", context.output_tokens
# Parse cost_dollars (API returns as string, e.g. "0.005")
try:
cost_usd = float(context.cost_dollars)
self.merge_stats(NodeExecutionStats(provider_cost=cost_usd))
except (ValueError, TypeError):
pass

View File

@@ -4,6 +4,7 @@ from typing import Optional
from exa_py import AsyncExa
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -223,3 +224,6 @@ class ExaContentsBlock(Block):
if response.cost_dollars:
yield "cost_dollars", response.cost_dollars
self.merge_stats(
NodeExecutionStats(provider_cost=response.cost_dollars.total)
)

View File

@@ -0,0 +1,575 @@
"""Tests for cost tracking in Exa blocks.
Covers the cost_dollars → provider_cost → merge_stats path for both
ExaContentsBlock and ExaCodeContextBlock.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks.exa._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
from backend.data.model import NodeExecutionStats
class TestExaCodeContextCostTracking:
"""ExaCodeContextBlock parses cost_dollars (string) and calls merge_stats."""
@pytest.mark.asyncio
async def test_valid_cost_string_is_parsed_and_merged(self):
"""A numeric cost string like '0.005' is merged as provider_cost."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
api_response = {
"requestId": "req-1",
"query": "test query",
"response": "some code",
"resultsCount": 3,
"costDollars": "0.005",
"searchTime": 1.2,
"outputTokens": 100,
}
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
outputs = []
async for key, value in block.run(
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
outputs.append((key, value))
assert any(k == "cost_dollars" for k, _ in outputs)
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.005)
@pytest.mark.asyncio
async def test_invalid_cost_string_does_not_raise(self):
"""A non-numeric cost_dollars value is swallowed silently."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
api_response = {
"requestId": "req-2",
"query": "test",
"response": "code",
"resultsCount": 0,
"costDollars": "N/A",
"searchTime": 0.5,
"outputTokens": 0,
}
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
outputs = []
async for key, value in block.run(
block.Input(query="test", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
outputs.append((key, value))
# No merge_stats call because float() raised ValueError
assert len(merged) == 0
@pytest.mark.asyncio
async def test_zero_cost_string_is_merged(self):
"""'0.0' is a valid cost — should still be tracked."""
from backend.blocks.exa.code_context import ExaCodeContextBlock
block = ExaCodeContextBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
api_response = {
"requestId": "req-3",
"query": "free query",
"response": "result",
"resultsCount": 1,
"costDollars": "0.0",
"searchTime": 0.1,
"outputTokens": 10,
}
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
mock_resp = MagicMock()
mock_resp.json.return_value = api_response
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
async for _ in block.run(
block.Input(query="free query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.0)
class TestExaContentsCostTracking:
"""ExaContentsBlock merges cost_dollars.total as provider_cost."""
@pytest.mark.asyncio
async def test_cost_dollars_total_is_merged(self):
"""When the SDK response includes cost_dollars, its total is merged."""
from backend.blocks.exa.contents import ExaContentsBlock
from backend.blocks.exa.helpers import CostDollars
block = ExaContentsBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.statuses = None
mock_sdk_response.cost_dollars = CostDollars(total=0.012)
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.012)
@pytest.mark.asyncio
async def test_no_cost_dollars_skips_merge(self):
"""When cost_dollars is absent, merge_stats is not called."""
from backend.blocks.exa.contents import ExaContentsBlock
block = ExaContentsBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.statuses = None
mock_sdk_response.cost_dollars = None
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 0
@pytest.mark.asyncio
async def test_zero_cost_dollars_is_merged(self):
"""A total of 0.0 (free tier) should still be merged."""
from backend.blocks.exa.contents import ExaContentsBlock
from backend.blocks.exa.helpers import CostDollars
block = ExaContentsBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.statuses = None
mock_sdk_response.cost_dollars = CostDollars(total=0.0)
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.0)
class TestExaSearchCostTracking:
"""ExaSearchBlock merges cost_dollars.total as provider_cost."""
@pytest.mark.asyncio
async def test_cost_dollars_total_is_merged(self):
"""When the SDK response includes cost_dollars, its total is merged."""
from backend.blocks.exa.helpers import CostDollars
from backend.blocks.exa.search import ExaSearchBlock
block = ExaSearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.resolved_search_type = None
mock_sdk_response.cost_dollars = CostDollars(total=0.008)
with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.search = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.008)
@pytest.mark.asyncio
async def test_no_cost_dollars_skips_merge(self):
"""When cost_dollars is absent, merge_stats is not called."""
from backend.blocks.exa.search import ExaSearchBlock
block = ExaSearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.resolved_search_type = None
mock_sdk_response.cost_dollars = None
with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.search = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 0
class TestExaSimilarCostTracking:
"""ExaFindSimilarBlock merges cost_dollars.total as provider_cost."""
@pytest.mark.asyncio
async def test_cost_dollars_total_is_merged(self):
"""When the SDK response includes cost_dollars, its total is merged."""
from backend.blocks.exa.helpers import CostDollars
from backend.blocks.exa.similar import ExaFindSimilarBlock
block = ExaFindSimilarBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.request_id = "req-1"
mock_sdk_response.cost_dollars = CostDollars(total=0.015)
with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.015)
@pytest.mark.asyncio
async def test_no_cost_dollars_skips_merge(self):
"""When cost_dollars is absent, merge_stats is not called."""
from backend.blocks.exa.similar import ExaFindSimilarBlock
block = ExaFindSimilarBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
mock_sdk_response = MagicMock()
mock_sdk_response.results = []
mock_sdk_response.context = None
mock_sdk_response.request_id = "req-2"
mock_sdk_response.cost_dollars = None
with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls:
mock_exa = MagicMock()
mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response)
mock_exa_cls.return_value = mock_exa
async for _ in block.run(
block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 0
# ---------------------------------------------------------------------------
# ExaCreateResearchBlock — cost_dollars from completed poll response
# ---------------------------------------------------------------------------
COMPLETED_RESEARCH_RESPONSE = {
"researchId": "test-research-id",
"status": "completed",
"model": "exa-research",
"instructions": "test instructions",
"createdAt": 1700000000000,
"finishedAt": 1700000060000,
"costDollars": {
"total": 0.05,
"numSearches": 3,
"numPages": 10,
"reasoningTokens": 500,
},
"output": {"content": "Research findings...", "parsed": None},
}
PENDING_RESEARCH_RESPONSE = {
"researchId": "test-research-id",
"status": "pending",
"model": "exa-research",
"instructions": "test instructions",
"createdAt": 1700000000000,
}
class TestExaCreateResearchBlockCostTracking:
"""ExaCreateResearchBlock merges cost from completed poll response."""
@pytest.mark.asyncio
async def test_cost_merged_when_research_completes(self):
"""merge_stats called with provider_cost=total when poll returns completed."""
from backend.blocks.exa.research import ExaCreateResearchBlock
block = ExaCreateResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
create_resp = MagicMock()
create_resp.json.return_value = PENDING_RESEARCH_RESPONSE
poll_resp = MagicMock()
poll_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
mock_instance = MagicMock()
mock_instance.post = AsyncMock(return_value=create_resp)
mock_instance.get = AsyncMock(return_value=poll_resp)
with (
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
patch("asyncio.sleep", new=AsyncMock()),
):
async for _ in block.run(
block.Input(
instructions="test instructions",
wait_for_completion=True,
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.05)
@pytest.mark.asyncio
async def test_no_merge_when_no_cost_dollars(self):
"""When completed response has no costDollars, merge_stats is not called."""
from backend.blocks.exa.research import ExaCreateResearchBlock
block = ExaCreateResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
create_resp = MagicMock()
create_resp.json.return_value = PENDING_RESEARCH_RESPONSE
poll_resp = MagicMock()
poll_resp.json.return_value = no_cost_response
mock_instance = MagicMock()
mock_instance.post = AsyncMock(return_value=create_resp)
mock_instance.get = AsyncMock(return_value=poll_resp)
with (
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
patch("asyncio.sleep", new=AsyncMock()),
):
async for _ in block.run(
block.Input(
instructions="test instructions",
wait_for_completion=True,
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert merged == []
# ---------------------------------------------------------------------------
# ExaGetResearchBlock — cost_dollars from single GET response
# ---------------------------------------------------------------------------
class TestExaGetResearchBlockCostTracking:
"""ExaGetResearchBlock merges cost when the fetched research has cost_dollars."""
@pytest.mark.asyncio
async def test_cost_merged_from_completed_research(self):
"""merge_stats called with provider_cost=total when research has costDollars."""
from backend.blocks.exa.research import ExaGetResearchBlock
block = ExaGetResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
get_resp = MagicMock()
get_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
mock_instance = MagicMock()
mock_instance.get = AsyncMock(return_value=get_resp)
with patch("backend.blocks.exa.research.Requests", return_value=mock_instance):
async for _ in block.run(
block.Input(
research_id="test-research-id",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.05)
@pytest.mark.asyncio
async def test_no_merge_when_no_cost_dollars(self):
"""When research has no costDollars, merge_stats is not called."""
from backend.blocks.exa.research import ExaGetResearchBlock
block = ExaGetResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
get_resp = MagicMock()
get_resp.json.return_value = no_cost_response
mock_instance = MagicMock()
mock_instance.get = AsyncMock(return_value=get_resp)
with patch("backend.blocks.exa.research.Requests", return_value=mock_instance):
async for _ in block.run(
block.Input(
research_id="test-research-id",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert merged == []
# ---------------------------------------------------------------------------
# ExaWaitForResearchBlock — cost_dollars from polling response
# ---------------------------------------------------------------------------
class TestExaWaitForResearchBlockCostTracking:
"""ExaWaitForResearchBlock merges cost when the polled research has cost_dollars."""
@pytest.mark.asyncio
async def test_cost_merged_when_research_completes(self):
"""merge_stats called with provider_cost=total once polling returns completed."""
from backend.blocks.exa.research import ExaWaitForResearchBlock
block = ExaWaitForResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
poll_resp = MagicMock()
poll_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
mock_instance = MagicMock()
mock_instance.get = AsyncMock(return_value=poll_resp)
with (
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
patch("asyncio.sleep", new=AsyncMock()),
):
async for _ in block.run(
block.Input(
research_id="test-research-id",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.05)
@pytest.mark.asyncio
async def test_no_merge_when_no_cost_dollars(self):
"""When completed research has no costDollars, merge_stats is not called."""
from backend.blocks.exa.research import ExaWaitForResearchBlock
block = ExaWaitForResearchBlock()
merged: list[NodeExecutionStats] = []
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
poll_resp = MagicMock()
poll_resp.json.return_value = no_cost_response
mock_instance = MagicMock()
mock_instance.get = AsyncMock(return_value=poll_resp)
with (
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
patch("asyncio.sleep", new=AsyncMock()),
):
async for _ in block.run(
block.Input(
research_id="test-research-id",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
),
credentials=TEST_CREDENTIALS,
):
pass
assert merged == []

View File

@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -232,6 +233,11 @@ class ExaCreateResearchBlock(Block):
if research.cost_dollars:
yield "cost_total", research.cost_dollars.total
self.merge_stats(
NodeExecutionStats(
provider_cost=research.cost_dollars.total
)
)
return
await asyncio.sleep(check_interval)
@@ -346,6 +352,9 @@ class ExaGetResearchBlock(Block):
yield "cost_searches", research.cost_dollars.num_searches
yield "cost_pages", research.cost_dollars.num_pages
yield "cost_reasoning_tokens", research.cost_dollars.reasoning_tokens
self.merge_stats(
NodeExecutionStats(provider_cost=research.cost_dollars.total)
)
yield "error_message", research.error
@@ -432,6 +441,9 @@ class ExaWaitForResearchBlock(Block):
if research.cost_dollars:
yield "cost_total", research.cost_dollars.total
self.merge_stats(
NodeExecutionStats(provider_cost=research.cost_dollars.total)
)
return

View File

@@ -4,6 +4,7 @@ from typing import Optional
from exa_py import AsyncExa
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -206,3 +207,6 @@ class ExaSearchBlock(Block):
if response.cost_dollars:
yield "cost_dollars", response.cost_dollars
self.merge_stats(
NodeExecutionStats(provider_cost=response.cost_dollars.total)
)

View File

@@ -3,6 +3,7 @@ from typing import Optional
from exa_py import AsyncExa
from backend.data.model import NodeExecutionStats
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -167,3 +168,6 @@ class ExaFindSimilarBlock(Block):
if response.cost_dollars:
yield "cost_dollars", response.cost_dollars
self.merge_stats(
NodeExecutionStats(provider_cost=response.cost_dollars.total)
)

View File

@@ -14,6 +14,7 @@ from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -117,6 +118,11 @@ class GoogleMapsSearchBlock(Block):
input_data.radius,
input_data.max_results,
)
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(places)), provider_cost_type="items"
)
)
for place in places:
yield "place", place

View File

@@ -2,6 +2,8 @@ import copy
from datetime import date, time
from typing import Any, Optional
from pydantic import AliasChoices, Field
from backend.blocks._base import (
Block,
BlockCategory,
@@ -467,7 +469,8 @@ class AgentFileInputBlock(AgentInputBlock):
class AgentDropdownInputBlock(AgentInputBlock):
"""
A specialized text input block that relies on placeholder_values to present a dropdown.
A specialized text input block that presents a dropdown selector
restricted to a fixed set of values.
"""
class Input(AgentInputBlock.Input):
@@ -477,16 +480,23 @@ class AgentDropdownInputBlock(AgentInputBlock):
advanced=False,
title="Default Value",
)
placeholder_values: list = SchemaField(
description="Possible values for the dropdown.",
# Use Field() directly (not SchemaField) to pass validation_alias,
# which handles backward compat for legacy "placeholder_values" across
# all construction paths (model_construct, __init__, model_validate).
options: list = Field(
default_factory=list,
advanced=False,
title="Dropdown Options",
description=(
"If provided, renders the input as a dropdown selector "
"restricted to these values. Leave empty for free-text input."
),
validation_alias=AliasChoices("options", "placeholder_values"),
json_schema_extra={"advanced": False, "secret": False},
)
def generate_schema(self):
schema = super().generate_schema()
if possible_values := self.placeholder_values:
if possible_values := self.options:
schema["enum"] = possible_values
return schema
@@ -504,13 +514,13 @@ class AgentDropdownInputBlock(AgentInputBlock):
{
"value": "Option A",
"name": "dropdown_1",
"placeholder_values": ["Option A", "Option B", "Option C"],
"options": ["Option A", "Option B", "Option C"],
"description": "Dropdown example 1",
},
{
"value": "Option C",
"name": "dropdown_2",
"placeholder_values": ["Option A", "Option B", "Option C"],
"options": ["Option A", "Option B", "Option C"],
"description": "Dropdown example 2",
},
],

View File

@@ -10,7 +10,7 @@ from backend.blocks.jina._auth import (
JinaCredentialsField,
JinaCredentialsInput,
)
from backend.data.model import SchemaField
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util.request import Requests
@@ -45,5 +45,13 @@ class JinaEmbeddingBlock(Block):
}
data = {"input": input_data.texts, "model": input_data.model}
response = await Requests().post(url, headers=headers, json=data)
embeddings = [e["embedding"] for e in response.json()["data"]]
resp_json = response.json()
embeddings = [e["embedding"] for e in resp_json["data"]]
usage = resp_json.get("usage", {})
if usage.get("total_tokens"):
self.merge_stats(
NodeExecutionStats(
input_token_count=usage.get("total_tokens", 0),
)
)
yield "embeddings", embeddings

View File

@@ -1,6 +1,7 @@
# This file contains a lot of prompt block strings that would trigger "line too long"
# flake8: noqa: E501
import logging
import math
import re
import secrets
from abc import ABC
@@ -13,6 +14,7 @@ import ollama
import openai
from anthropic.types import ToolParam
from groq import AsyncGroq
from openai.types.chat import ChatCompletion as OpenAIChatCompletion
from pydantic import BaseModel, SecretStr
from backend.blocks._base import (
@@ -205,6 +207,19 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
KIMI_K2 = "moonshotai/kimi-k2"
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
QWEN3_CODER = "qwen/qwen3-coder"
# Z.ai (Zhipu) models
ZAI_GLM_4_32B = "z-ai/glm-4-32b"
ZAI_GLM_4_5 = "z-ai/glm-4.5"
ZAI_GLM_4_5_AIR = "z-ai/glm-4.5-air"
ZAI_GLM_4_5_AIR_FREE = "z-ai/glm-4.5-air:free"
ZAI_GLM_4_5V = "z-ai/glm-4.5v"
ZAI_GLM_4_6 = "z-ai/glm-4.6"
ZAI_GLM_4_6V = "z-ai/glm-4.6v"
ZAI_GLM_4_7 = "z-ai/glm-4.7"
ZAI_GLM_4_7_FLASH = "z-ai/glm-4.7-flash"
ZAI_GLM_5 = "z-ai/glm-5"
ZAI_GLM_5_TURBO = "z-ai/glm-5-turbo"
ZAI_GLM_5V_TURBO = "z-ai/glm-5v-turbo"
# Llama API models
LLAMA_API_LLAMA_4_SCOUT = "Llama-4-Scout-17B-16E-Instruct-FP8"
LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
@@ -630,6 +645,43 @@ MODEL_METADATA = {
LlmModel.QWEN3_CODER: ModelMetadata(
"open_router", 262144, 262144, "Qwen 3 Coder", "OpenRouter", "Qwen", 3
),
# https://openrouter.ai/models?q=z-ai
LlmModel.ZAI_GLM_4_32B: ModelMetadata(
"open_router", 128000, 128000, "GLM 4 32B", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_4_5: ModelMetadata(
"open_router", 131072, 98304, "GLM 4.5", "OpenRouter", "Z.ai", 2
),
LlmModel.ZAI_GLM_4_5_AIR: ModelMetadata(
"open_router", 131072, 98304, "GLM 4.5 Air", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_4_5_AIR_FREE: ModelMetadata(
"open_router", 131072, 96000, "GLM 4.5 Air (Free)", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_4_5V: ModelMetadata(
"open_router", 65536, 16384, "GLM 4.5V", "OpenRouter", "Z.ai", 2
),
LlmModel.ZAI_GLM_4_6: ModelMetadata(
"open_router", 204800, 204800, "GLM 4.6", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_4_6V: ModelMetadata(
"open_router", 131072, 131072, "GLM 4.6V", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_4_7: ModelMetadata(
"open_router", 202752, 65535, "GLM 4.7", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_4_7_FLASH: ModelMetadata(
"open_router", 202752, 202752, "GLM 4.7 Flash", "OpenRouter", "Z.ai", 1
),
LlmModel.ZAI_GLM_5: ModelMetadata(
"open_router", 80000, 80000, "GLM 5", "OpenRouter", "Z.ai", 2
),
LlmModel.ZAI_GLM_5_TURBO: ModelMetadata(
"open_router", 202752, 131072, "GLM 5 Turbo", "OpenRouter", "Z.ai", 3
),
LlmModel.ZAI_GLM_5V_TURBO: ModelMetadata(
"open_router", 202752, 131072, "GLM 5V Turbo", "OpenRouter", "Z.ai", 3
),
# Llama API models
LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata(
"llama_api",
@@ -687,6 +739,7 @@ class LLMResponse(BaseModel):
prompt_tokens: int
completion_tokens: int
reasoning: Optional[str] = None
provider_cost: float | None = None
def convert_openai_tool_fmt_to_anthropic(
@@ -721,6 +774,35 @@ def convert_openai_tool_fmt_to_anthropic(
return anthropic_tools
def extract_openrouter_cost(response: OpenAIChatCompletion) -> float | None:
"""Extract OpenRouter's `x-total-cost` header from an OpenAI SDK response.
OpenRouter returns the per-request USD cost in a response header. The
OpenAI SDK exposes the raw httpx response via an undocumented `_response`
attribute. We use try/except AttributeError so that if the SDK ever drops
or renames that attribute, the warning is visible in logs rather than
silently degrading to no cost tracking.
"""
try:
raw_resp = response._response # type: ignore[attr-defined]
except AttributeError:
logger.warning(
"OpenAI SDK response missing _response attribute"
" — OpenRouter cost tracking unavailable"
)
return None
try:
cost_header = raw_resp.headers.get("x-total-cost")
if not cost_header:
return None
cost = float(cost_header)
if not math.isfinite(cost):
return None
return cost
except (ValueError, TypeError, AttributeError):
return None
def extract_openai_reasoning(response) -> str | None:
"""Extract reasoning from OpenAI-compatible response if available."""
"""Note: This will likely not working since the reasoning is not present in another Response API"""
@@ -1053,6 +1135,7 @@ async def llm_call(
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
reasoning=reasoning,
provider_cost=extract_openrouter_cost(response),
)
elif provider == "llama_api":
tools_param = tools if tools else openai.NOT_GIVEN
@@ -1360,6 +1443,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
error_feedback_message = ""
llm_model = input_data.model
last_attempt_cost: float | None = None
for retry_count in range(input_data.retry):
logger.debug(f"LLM request: {prompt}")
@@ -1377,12 +1461,15 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
max_tokens=input_data.max_tokens,
)
response_text = llm_response.response
self.merge_stats(
NodeExecutionStats(
input_token_count=llm_response.prompt_tokens,
output_token_count=llm_response.completion_tokens,
)
# Merge token counts for every attempt (each call costs tokens).
# provider_cost (actual USD) is tracked separately and only merged
# on success to avoid double-counting across retries.
token_stats = NodeExecutionStats(
input_token_count=llm_response.prompt_tokens,
output_token_count=llm_response.completion_tokens,
)
self.merge_stats(token_stats)
last_attempt_cost = llm_response.provider_cost
logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
if input_data.expected_format:
@@ -1451,6 +1538,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
NodeExecutionStats(
llm_call_count=retry_count + 1,
llm_retry_count=retry_count,
provider_cost=last_attempt_cost,
)
)
yield "response", response_obj
@@ -1471,6 +1559,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
NodeExecutionStats(
llm_call_count=retry_count + 1,
llm_retry_count=retry_count,
provider_cost=last_attempt_cost,
)
)
yield "response", {"response": response_text}
@@ -2016,6 +2105,19 @@ class AIConversationBlock(AIBlockBase):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
has_messages = any(
isinstance(m, dict)
and isinstance(m.get("content"), str)
and bool(m["content"].strip())
for m in (input_data.messages or [])
)
has_prompt = bool(input_data.prompt and input_data.prompt.strip())
if not has_messages and not has_prompt:
raise ValueError(
"Cannot call LLM with no messages and no prompt. "
"Provide at least one message or a non-empty prompt."
)
response = await self.llm_call(
AIStructuredResponseGeneratorBlock.Input(
prompt=input_data.prompt,

View File

@@ -89,6 +89,12 @@ class MCPToolBlock(Block):
default={},
hidden=True,
)
tool_description: str = SchemaField(
description="Description of the selected MCP tool. "
"Populated automatically when a tool is selected.",
default="",
hidden=True,
)
tool_arguments: dict[str, Any] = SchemaField(
description="Arguments to pass to the selected MCP tool. "

View File

@@ -23,7 +23,7 @@ from backend.blocks.smartlead.models import (
SaveSequencesResponse,
Sequence,
)
from backend.data.model import CredentialsField, SchemaField
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
class CreateCampaignBlock(Block):
@@ -226,6 +226,12 @@ class AddLeadToCampaignBlock(Block):
response = await self.add_leads_to_campaign(
input_data.campaign_id, input_data.lead_list, credentials
)
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(input_data.lead_list)),
provider_cost_type="items",
)
)
yield "campaign_id", input_data.campaign_id
yield "upload_count", response.upload_count

View File

@@ -0,0 +1,323 @@
import asyncio
from typing import Any, Literal
from pydantic import SecretStr
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import DBAPIError, OperationalError, ProgrammingError
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.sql_query_helpers import (
_DATABASE_TYPE_DEFAULT_PORT,
_DATABASE_TYPE_TO_DRIVER,
DatabaseType,
_execute_query,
_sanitize_error,
_validate_query_is_read_only,
_validate_single_statement,
)
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
SchemaField,
UserPasswordCredentials,
)
from backend.integrations.providers import ProviderName
from backend.util.request import resolve_and_check_blocked
TEST_CREDENTIALS = UserPasswordCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="database",
username=SecretStr("test_user"),
password=SecretStr("test_pass"),
title="Mock Database credentials",
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
DatabaseCredentials = UserPasswordCredentials
DatabaseCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.DATABASE],
Literal["user_password"],
]
def DatabaseCredentialsField() -> DatabaseCredentialsInput:
return CredentialsField(
description="Database username and password",
)
class SQLQueryBlock(Block):
class Input(BlockSchemaInput):
database_type: DatabaseType = SchemaField(
default=DatabaseType.POSTGRES,
description="Database engine",
advanced=False,
)
host: SecretStr = SchemaField(
description=(
"Database hostname or IP address. "
"Treated as a secret to avoid leaking infrastructure details. "
"Private/internal IPs are blocked (SSRF protection)."
),
placeholder="db.example.com",
secret=True,
)
port: int | None = SchemaField(
default=None,
description=(
"Database port (leave empty for default: "
"PostgreSQL: 5432, MySQL: 3306, MSSQL: 1433)"
),
ge=1,
le=65535,
)
database: str = SchemaField(
description="Name of the database to connect to",
placeholder="my_database",
)
query: str = SchemaField(
description="SQL query to execute",
placeholder="SELECT * FROM analytics.daily_active_users LIMIT 10",
)
read_only: bool = SchemaField(
default=True,
description=(
"When enabled (default), only SELECT queries are allowed "
"and the database session is set to read-only mode. "
"Disable to allow write operations (INSERT, UPDATE, DELETE, etc.)."
),
)
timeout: int = SchemaField(
default=30,
description="Query timeout in seconds (max 120)",
ge=1,
le=120,
)
max_rows: int = SchemaField(
default=1000,
description="Maximum number of rows to return (max 10000)",
ge=1,
le=10000,
)
credentials: DatabaseCredentialsInput = DatabaseCredentialsField()
class Output(BlockSchemaOutput):
results: list[dict[str, Any]] = SchemaField(
description="Query results as a list of row dictionaries"
)
columns: list[str] = SchemaField(
description="Column names from the query result"
)
row_count: int = SchemaField(description="Number of rows returned")
truncated: bool = SchemaField(
description=(
"True when the result set was capped by max_rows, "
"indicating additional rows exist in the database"
)
)
affected_rows: int = SchemaField(
description="Number of rows affected by a write query (INSERT/UPDATE/DELETE)"
)
error: str = SchemaField(description="Error message if the query failed")
def __init__(self):
super().__init__(
id="4dc35c0f-4fd8-465e-9616-5a216f1ba2bc",
description=(
"Execute a SQL query. Read-only by default for safety "
"-- disable to allow write operations. "
"Supports PostgreSQL, MySQL, and MSSQL via SQLAlchemy."
),
categories={BlockCategory.DATA},
input_schema=SQLQueryBlock.Input,
output_schema=SQLQueryBlock.Output,
test_input={
"query": "SELECT 1 AS test_col",
"database_type": DatabaseType.POSTGRES,
"host": "localhost",
"database": "test_db",
"timeout": 30,
"max_rows": 1000,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("results", [{"test_col": 1}]),
("columns", ["test_col"]),
("row_count", 1),
("truncated", False),
],
test_mock={
"execute_query": lambda *_args, **_kwargs: (
[{"test_col": 1}],
["test_col"],
-1,
False,
),
"check_host_allowed": lambda *_args, **_kwargs: ["127.0.0.1"],
},
)
@staticmethod
async def check_host_allowed(host: str) -> list[str]:
"""Validate that the given host is not a private/blocked address.
Returns the list of resolved IP addresses so the caller can pin the
connection to the validated IP (preventing DNS rebinding / TOCTOU).
Raises ValueError or OSError if the host is blocked.
Extracted as a method so it can be mocked during block tests.
"""
return await resolve_and_check_blocked(host)
@staticmethod
def execute_query(
connection_url: URL | str,
query: str,
timeout: int,
max_rows: int,
read_only: bool = True,
database_type: DatabaseType = DatabaseType.POSTGRES,
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
"""Execute a SQL query and return (rows, columns, affected_rows, truncated).
Delegates to ``_execute_query`` in ``sql_query_helpers``.
Extracted as a method so it can be mocked during block tests.
"""
return _execute_query(
connection_url=connection_url,
query=query,
timeout=timeout,
max_rows=max_rows,
read_only=read_only,
database_type=database_type,
)
async def run(
self,
input_data: Input,
*,
credentials: DatabaseCredentials,
**_kwargs: Any,
) -> BlockOutput:
# Validate query structure and read-only constraints.
error = self._validate_query(input_data)
if error:
yield "error", error
return
# Validate host and resolve for SSRF protection.
host, pinned_host, error = await self._resolve_host(input_data)
if error:
yield "error", error
return
# Build connection URL and execute.
port = input_data.port or _DATABASE_TYPE_DEFAULT_PORT[input_data.database_type]
username = credentials.username.get_secret_value()
connection_url = URL.create(
drivername=_DATABASE_TYPE_TO_DRIVER[input_data.database_type],
username=username,
password=credentials.password.get_secret_value(),
host=pinned_host,
port=port,
database=input_data.database,
)
conn_str = connection_url.render_as_string(hide_password=True)
db_name = input_data.database
def _sanitize(err: Exception) -> str:
return _sanitize_error(
str(err).strip(),
conn_str,
host=pinned_host,
original_host=host,
username=username,
port=port,
database=db_name,
)
try:
results, columns, affected, truncated = await asyncio.to_thread(
self.execute_query,
connection_url=connection_url,
query=input_data.query,
timeout=input_data.timeout,
max_rows=input_data.max_rows,
read_only=input_data.read_only,
database_type=input_data.database_type,
)
yield "results", results
yield "columns", columns
yield "row_count", len(results)
yield "truncated", truncated
if affected >= 0:
yield "affected_rows", affected
except OperationalError as e:
yield (
"error",
self._classify_operational_error(
_sanitize(e),
input_data.timeout,
),
)
except ProgrammingError as e:
yield "error", f"SQL error: {_sanitize(e)}"
except DBAPIError as e:
yield "error", f"Database error: {_sanitize(e)}"
except ModuleNotFoundError:
yield (
"error",
(
f"Database driver not available for "
f"{input_data.database_type.value}. "
f"Please contact the platform administrator."
),
)
@staticmethod
def _validate_query(input_data: "SQLQueryBlock.Input") -> str | None:
"""Validate query structure and read-only constraints."""
stmt_error, parsed_stmt = _validate_single_statement(input_data.query)
if stmt_error:
return stmt_error
assert parsed_stmt is not None
if input_data.read_only:
return _validate_query_is_read_only(parsed_stmt)
return None
async def _resolve_host(
self, input_data: "SQLQueryBlock.Input"
) -> tuple[str, str, str | None]:
"""Validate and resolve the database host. Returns (host, pinned_ip, error)."""
host = input_data.host.get_secret_value().strip()
if not host:
return "", "", "Database host is required."
if host.startswith("/"):
return host, "", "Unix socket connections are not allowed."
try:
resolved_ips = await self.check_host_allowed(host)
except (ValueError, OSError) as e:
return host, "", f"Blocked host: {str(e).strip()}"
return host, resolved_ips[0], None
@staticmethod
def _classify_operational_error(sanitized_msg: str, timeout: int) -> str:
"""Classify an already-sanitized OperationalError for user display."""
lower = sanitized_msg.lower()
if "timeout" in lower or "cancel" in lower:
return f"Query timed out after {timeout}s."
if "connect" in lower:
return f"Failed to connect to database: {sanitized_msg}"
return f"Database error: {sanitized_msg}"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,430 @@
import re
from datetime import date, datetime, time
from decimal import Decimal
from enum import Enum
from typing import Any
import sqlparse
from sqlalchemy import create_engine, text
from sqlalchemy.engine.url import URL
class DatabaseType(str, Enum):
POSTGRES = "postgres"
MYSQL = "mysql"
MSSQL = "mssql"
# Defense-in-depth: reject queries containing data-modifying keywords.
# These are checked against parsed SQL tokens (not raw text) so column names
# and string literals do not cause false positives.
_DISALLOWED_KEYWORDS = {
"INSERT",
"UPDATE",
"DELETE",
"DROP",
"ALTER",
"CREATE",
"TRUNCATE",
"GRANT",
"REVOKE",
"COPY",
"EXECUTE",
"CALL",
"SET",
"RESET",
"DISCARD",
"NOTIFY",
"DO",
# MySQL file exfiltration: LOAD DATA LOCAL INFILE reads server/client files
"LOAD",
# MySQL REPLACE is INSERT-or-UPDATE; data modification
"REPLACE",
# ANSI MERGE (UPSERT) modifies data
"MERGE",
# MSSQL BULK INSERT loads external files into tables
"BULK",
# MSSQL EXEC / EXEC sp_name runs stored procedures (arbitrary code)
"EXEC",
}
# Map DatabaseType enum values to the expected SQLAlchemy driver prefix.
_DATABASE_TYPE_TO_DRIVER = {
DatabaseType.POSTGRES: "postgresql",
DatabaseType.MYSQL: "mysql+pymysql",
DatabaseType.MSSQL: "mssql+pymssql",
}
# Connection timeout in seconds passed to the DBAPI driver (connect_timeout /
# login_timeout). This bounds how long the driver waits to establish a TCP
# connection to the database server. It is separate from the per-statement
# timeout configured via SET commands inside _configure_session().
_CONNECT_TIMEOUT_SECONDS = 10
# Default ports for each database type.
_DATABASE_TYPE_DEFAULT_PORT = {
DatabaseType.POSTGRES: 5432,
DatabaseType.MYSQL: 3306,
DatabaseType.MSSQL: 1433,
}
def _sanitize_error(
error_msg: str,
connection_string: str,
*,
host: str = "",
original_host: str = "",
username: str = "",
port: int = 0,
database: str = "",
) -> str:
"""Remove connection string, credentials, and infrastructure details
from error messages so they are safe to expose to the LLM.
Scrubs:
- The full connection string
- URL-embedded credentials (``://user:pass@``)
- ``password=<value>`` key-value pairs
- The database hostname / IP used for the connection
- The original (pre-resolution) hostname provided by the user
- Any IPv4 addresses that appear in the message
- Any bracketed IPv6 addresses (e.g. ``[::1]``, ``[fe80::1%eth0]``)
- The database username
- The database port number
- The database name
"""
sanitized = error_msg.replace(connection_string, "<connection_string>")
sanitized = re.sub(r"password=[^\s&]+", "password=***", sanitized)
sanitized = re.sub(r"://[^@]+@", "://***:***@", sanitized)
# Replace the known host (may be an IP already) before the generic IP pass.
# Also replace the original (pre-DNS-resolution) hostname if it differs.
if original_host and original_host != host:
sanitized = sanitized.replace(original_host, "<host>")
if host:
sanitized = sanitized.replace(host, "<host>")
# Replace any remaining IPv4 addresses (e.g. resolved IPs the driver logs)
sanitized = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", "<ip>", sanitized)
# Replace bracketed IPv6 addresses (e.g. "[::1]", "[fe80::1%eth0]")
sanitized = re.sub(r"\[[0-9a-fA-F:]+(?:%[^\]]+)?\]", "<ip>", sanitized)
# Replace the database username (handles double-quoted, single-quoted,
# and unquoted formats across PostgreSQL, MySQL, and MSSQL error messages).
if username:
sanitized = re.sub(
r"""for user ["']?""" + re.escape(username) + r"""["']?""",
"for user <user>",
sanitized,
)
# Catch remaining bare occurrences in various quote styles:
# - PostgreSQL: "FATAL: role "myuser" does not exist"
# - MySQL: "Access denied for user 'myuser'@'host'"
# - MSSQL: "Login failed for user 'myuser'"
sanitized = sanitized.replace(f'"{username}"', "<user>")
sanitized = sanitized.replace(f"'{username}'", "<user>")
# Replace the port number (handles "port 5432" and ":5432" formats)
if port:
port_str = re.escape(str(port))
sanitized = re.sub(
r"(?:port |:)" + port_str + r"(?![0-9])",
lambda m: ("port " if m.group().startswith("p") else ":") + "<port>",
sanitized,
)
# Replace the database name to avoid leaking internal infrastructure names.
# Use word-boundary regex to prevent mangling when the database name is a
# common substring (e.g. "test", "data", "on").
if database:
sanitized = re.sub(r"\b" + re.escape(database) + r"\b", "<database>", sanitized)
return sanitized
def _extract_keyword_tokens(parsed: sqlparse.sql.Statement) -> list[str]:
"""Extract keyword tokens from a parsed SQL statement.
Uses sqlparse token type classification to collect Keyword/DML/DDL/DCL
tokens. String literals and identifiers have different token types, so
they are naturally excluded from the result.
"""
return [
token.normalized.upper()
for token in parsed.flatten()
if token.ttype
in (
sqlparse.tokens.Keyword,
sqlparse.tokens.Keyword.DML,
sqlparse.tokens.Keyword.DDL,
sqlparse.tokens.Keyword.DCL,
)
]
def _has_disallowed_into(stmt: sqlparse.sql.Statement) -> bool:
"""Check if a statement contains a disallowed ``INTO`` clause.
``SELECT ... INTO @variable`` is a valid read-only MySQL syntax that stores
a query result into a session-scoped user variable. All other forms of
``INTO`` are data-modifying or file-writing and must be blocked:
* ``SELECT ... INTO new_table`` (PostgreSQL / MSSQL creates a table)
* ``SELECT ... INTO OUTFILE`` (MySQL writes to the filesystem)
* ``SELECT ... INTO DUMPFILE`` (MySQL writes to the filesystem)
* ``INSERT INTO ...`` (already blocked by INSERT being in the
disallowed set, but we reject INTO as well for defense-in-depth)
Returns ``True`` if the statement contains a disallowed ``INTO``.
"""
flat = list(stmt.flatten())
for i, token in enumerate(flat):
if not (
token.ttype in (sqlparse.tokens.Keyword,)
and token.normalized.upper() == "INTO"
):
continue
# Look at the first non-whitespace token after INTO.
j = i + 1
while j < len(flat) and flat[j].ttype is sqlparse.tokens.Text.Whitespace:
j += 1
if j >= len(flat):
# INTO at the very end malformed, block it.
return True
next_token = flat[j]
# MySQL user variable: either a single Name starting with "@"
# (e.g. ``@total``) or a bare ``@`` Operator token followed by a Name.
if next_token.ttype is sqlparse.tokens.Name and next_token.value.startswith(
"@"
):
continue
if next_token.ttype is sqlparse.tokens.Operator and next_token.value == "@":
continue
# Everything else (table name, OUTFILE, DUMPFILE, etc.) is disallowed.
return True
return False
def _validate_query_is_read_only(stmt: sqlparse.sql.Statement) -> str | None:
"""Validate that a parsed SQL statement is read-only (SELECT/WITH only).
Accepts an already-parsed statement from ``_validate_single_statement``
to avoid re-parsing. Checks:
1. Statement type must be SELECT (sqlparse classifies WITH...SELECT as SELECT)
2. No disallowed keywords (INSERT, UPDATE, DELETE, DROP, etc.)
3. No disallowed INTO clauses (allows MySQL ``SELECT ... INTO @variable``)
Returns an error message if the query is not read-only, None otherwise.
"""
# sqlparse returns 'SELECT' for SELECT and WITH...SELECT queries
if stmt.get_type() != "SELECT":
return "Only SELECT queries are allowed."
# Defense-in-depth: check parsed keyword tokens for disallowed keywords
for kw in _extract_keyword_tokens(stmt):
# Normalize multi-word tokens (e.g. "SET LOCAL" -> "SET")
base_kw = kw.split()[0] if " " in kw else kw
if base_kw in _DISALLOWED_KEYWORDS:
return f"Disallowed SQL keyword: {kw}"
# Contextual check for INTO: allow MySQL @variable syntax, block everything else
if _has_disallowed_into(stmt):
return "Disallowed SQL keyword: INTO"
return None
def _validate_single_statement(
query: str,
) -> tuple[str | None, sqlparse.sql.Statement | None]:
"""Validate that the query contains exactly one non-empty SQL statement.
Returns (error_message, parsed_statement). If error_message is not None,
the query is invalid and parsed_statement will be None.
"""
stripped = query.strip().rstrip(";").strip()
if not stripped:
return "Query is empty.", None
# Parse the SQL using sqlparse for proper tokenization
statements = sqlparse.parse(stripped)
# Filter out empty statements and comment-only statements
statements = [
s
for s in statements
if s.tokens
and str(s).strip()
and not all(
t.is_whitespace or t.ttype in sqlparse.tokens.Comment for t in s.flatten()
)
]
if not statements:
return "Query is empty.", None
# Reject multiple statements -- prevents injection via semicolons
if len(statements) > 1:
return "Only single statements are allowed.", None
return None, statements[0]
def _serialize_value(value: Any) -> Any:
"""Convert database-specific types to JSON-serializable Python types."""
if isinstance(value, Decimal):
# NaN / Infinity are not valid JSON numbers; serialize as strings.
if value.is_nan() or value.is_infinite():
return str(value)
# Use int for whole numbers; use str for fractional to preserve exact
# precision (float would silently round high-precision analytics values).
if value == value.to_integral_value():
return int(value)
return str(value)
if isinstance(value, (datetime, date, time)):
return value.isoformat()
if isinstance(value, memoryview):
return bytes(value).hex()
if isinstance(value, bytes):
return value.hex()
return value
def _configure_session(
conn: Any,
dialect_name: str,
timeout_ms: str,
read_only: bool,
) -> None:
"""Set session-level timeout and read-only mode for the given dialect.
Timeout limitations by database:
* **PostgreSQL** ``statement_timeout`` reliably cancels any running
statement (SELECT or DML) after the configured duration.
* **MySQL** ``MAX_EXECUTION_TIME`` only applies to **read-only SELECT**
statements. DML (INSERT/UPDATE/DELETE) and DDL are *not* bounded by
this hint; they rely on the server's ``wait_timeout`` /
``interactive_timeout`` instead. There is no session-level setting in
MySQL that reliably cancels long-running writes.
* **MSSQL** ``SET LOCK_TIMEOUT`` only limits how long the server waits
to acquire a **lock**. CPU-bound queries (e.g. large scans, hash
joins) that do not block on locks will *not* be cancelled. MSSQL has
no session-level ``statement_timeout`` equivalent; the closest
mechanism is Resource Governor (requires sysadmin configuration) or
``CONTEXT_INFO``-based external monitoring.
Note: SQLite is not supported by this block. The ``_configure_session``
function is a no-op for unrecognised dialect names, so an SQLite engine
would skip all SET commands silently. The block's ``DatabaseType`` enum
intentionally excludes SQLite.
"""
if dialect_name == "postgresql":
conn.execute(text("SET statement_timeout = " + timeout_ms))
if read_only:
conn.execute(text("SET default_transaction_read_only = ON"))
elif dialect_name == "mysql":
# NOTE: MAX_EXECUTION_TIME only applies to SELECT statements.
# Write queries (INSERT/UPDATE/DELETE) are not bounded by this
# setting; they rely on the database's wait_timeout instead.
# See docstring above for full limitations.
conn.execute(text("SET SESSION MAX_EXECUTION_TIME = " + timeout_ms))
if read_only:
conn.execute(text("SET SESSION TRANSACTION READ ONLY"))
elif dialect_name == "mssql":
# MSSQL: SET LOCK_TIMEOUT limits lock-wait time (ms) only.
# CPU-bound queries without lock contention are NOT cancelled.
# See docstring above for full limitations.
conn.execute(text("SET LOCK_TIMEOUT " + timeout_ms))
# MSSQL lacks a session-level read-only mode like
# PostgreSQL/MySQL. Read-only enforcement is handled by
# the SQL validation layer (_validate_query_is_read_only)
# and the ROLLBACK in the finally block.
def _run_in_transaction(
conn: Any,
dialect_name: str,
query: str,
max_rows: int,
read_only: bool,
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
"""Execute a query inside an explicit transaction, returning results.
Returns ``(rows, columns, affected_rows, truncated)`` where *truncated*
is ``True`` when ``fetchmany`` returned exactly ``max_rows`` rows,
indicating that additional rows may exist in the result set.
"""
# MSSQL uses T-SQL "BEGIN TRANSACTION"; others use "BEGIN".
begin_stmt = "BEGIN TRANSACTION" if dialect_name == "mssql" else "BEGIN"
conn.execute(text(begin_stmt))
try:
result = conn.execute(text(query))
affected = result.rowcount if not result.returns_rows else -1
columns = list(result.keys()) if result.returns_rows else []
rows = result.fetchmany(max_rows) if result.returns_rows else []
truncated = len(rows) == max_rows
results = [
{col: _serialize_value(val) for col, val in zip(columns, row)}
for row in rows
]
except Exception:
try:
conn.execute(text("ROLLBACK"))
except Exception:
pass
raise
else:
conn.execute(text("ROLLBACK" if read_only else "COMMIT"))
return results, columns, affected, truncated
def _execute_query(
connection_url: URL | str,
query: str,
timeout: int,
max_rows: int,
read_only: bool = True,
database_type: DatabaseType = DatabaseType.POSTGRES,
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
"""Execute a SQL query and return (rows, columns, affected_rows, truncated).
Uses SQLAlchemy to connect to any supported database.
For SELECT queries, rows are limited to ``max_rows`` via DBAPI fetchmany.
``truncated`` is ``True`` when the result set was capped by ``max_rows``.
For write queries, affected_rows contains the rowcount from the driver.
When ``read_only`` is True, the database session is set to read-only
mode and the transaction is always rolled back.
"""
# Determine driver-specific connection timeout argument.
# pymssql uses "login_timeout", while PostgreSQL/MySQL use "connect_timeout".
timeout_key = (
"login_timeout" if database_type == DatabaseType.MSSQL else "connect_timeout"
)
engine = create_engine(
connection_url, connect_args={timeout_key: _CONNECT_TIMEOUT_SECONDS}
)
try:
with engine.connect() as conn:
# Use AUTOCOMMIT so SET commands take effect immediately.
conn = conn.execution_options(isolation_level="AUTOCOMMIT")
# Compute timeout in milliseconds. The value is Pydantic-validated
# (ge=1, le=120), but we use int() as defense-in-depth.
# NOTE: SET commands do not support bind parameters in most
# databases, so we use str(int(...)) for safe interpolation.
timeout_ms = str(int(timeout * 1000))
_configure_session(conn, engine.dialect.name, timeout_ms, read_only)
return _run_in_transaction(
conn, engine.dialect.name, query, max_rows, read_only
)
finally:
engine.dispose()

View File

@@ -300,13 +300,27 @@ def test_agent_input_block_ignores_legacy_placeholder_values():
def test_dropdown_input_block_produces_enum():
"""Verify AgentDropdownInputBlock.Input.generate_schema() produces enum."""
options = ["Option A", "Option B"]
"""Verify AgentDropdownInputBlock.Input.generate_schema() produces enum
using the canonical 'options' field name."""
opts = ["Option A", "Option B"]
instance = AgentDropdownInputBlock.Input.model_construct(
name="choice", value=None, placeholder_values=options
name="choice", value=None, options=opts
)
schema = instance.generate_schema()
assert schema.get("enum") == options
assert schema.get("enum") == opts
def test_dropdown_input_block_legacy_placeholder_values_produces_enum():
"""Verify backward compat: passing legacy 'placeholder_values' to
AgentDropdownInputBlock still produces enum via model_construct remap."""
opts = ["Option A", "Option B"]
instance = AgentDropdownInputBlock.Input.model_construct(
name="choice", value=None, placeholder_values=opts
)
schema = instance.generate_schema()
assert (
schema.get("enum") == opts
), "Legacy placeholder_values should be remapped to options"
def test_generate_schema_integration_legacy_placeholder_values():
@@ -329,11 +343,11 @@ def test_generate_schema_integration_legacy_placeholder_values():
def test_generate_schema_integration_dropdown_produces_enum():
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
— verifies enum IS produced for dropdown blocks."""
— verifies enum IS produced for dropdown blocks using canonical field name."""
dropdown_input_default = {
"name": "color",
"value": None,
"placeholder_values": ["Red", "Green", "Blue"],
"options": ["Red", "Green", "Blue"],
}
result = BaseGraph._generate_schema(
(AgentDropdownInputBlock.Input, dropdown_input_default),
@@ -344,3 +358,36 @@ def test_generate_schema_integration_dropdown_produces_enum():
"Green",
"Blue",
], "Graph schema should contain enum from AgentDropdownInputBlock"
def test_generate_schema_integration_dropdown_legacy_placeholder_values():
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
using legacy 'placeholder_values' — verifies backward compat produces enum."""
legacy_dropdown_input_default = {
"name": "color",
"value": None,
"placeholder_values": ["Red", "Green", "Blue"],
}
result = BaseGraph._generate_schema(
(AgentDropdownInputBlock.Input, legacy_dropdown_input_default),
)
color_props = result["properties"]["color"]
assert color_props.get("enum") == [
"Red",
"Green",
"Blue",
], "Legacy placeholder_values should still produce enum via model_construct remap"
def test_dropdown_input_block_init_legacy_placeholder_values():
"""Verify backward compat: constructing AgentDropdownInputBlock.Input via
model_validate with legacy 'placeholder_values' correctly maps to 'options'."""
opts = ["Option A", "Option B"]
instance = AgentDropdownInputBlock.Input.model_validate(
{"name": "choice", "value": None, "placeholder_values": opts}
)
assert (
instance.options == opts
), "Legacy placeholder_values should be remapped to options via model_validate"
schema = instance.generate_schema()
assert schema.get("enum") == opts

View File

@@ -199,6 +199,66 @@ class TestLLMStatsTracking:
assert block.execution_stats.llm_call_count == 2 # retry_count + 1 = 1 + 1 = 2
assert block.execution_stats.llm_retry_count == 1
@pytest.mark.asyncio
async def test_retry_cost_uses_last_attempt_only(self):
"""provider_cost is only merged from the final successful attempt.
Intermediate retry costs are intentionally dropped to avoid
double-counting: the cost of failed attempts is captured in
last_attempt_cost only when the loop eventually succeeds.
"""
import backend.blocks.llm as llm
block = llm.AIStructuredResponseGeneratorBlock()
call_count = 0
async def mock_llm_call(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
# First attempt: fails validation, returns cost $0.01
return llm.LLMResponse(
raw_response="",
prompt=[],
response='<json_output id="test123456">{"wrong": "key"}</json_output>',
tool_calls=None,
prompt_tokens=10,
completion_tokens=5,
reasoning=None,
provider_cost=0.01,
)
# Second attempt: succeeds, returns cost $0.02
return llm.LLMResponse(
raw_response="",
prompt=[],
response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
tool_calls=None,
prompt_tokens=20,
completion_tokens=10,
reasoning=None,
provider_cost=0.02,
)
block.llm_call = mock_llm_call # type: ignore
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test prompt",
expected_format={"key1": "desc1", "key2": "desc2"},
model=llm.DEFAULT_LLM_MODEL,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2,
)
with patch("secrets.token_hex", return_value="test123456"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
# Only the final successful attempt's cost is merged
assert block.execution_stats.provider_cost == pytest.approx(0.02)
# Tokens from both attempts accumulate
assert block.execution_stats.input_token_count == 30
assert block.execution_stats.output_token_count == 15
@pytest.mark.asyncio
async def test_ai_text_summarizer_multiple_chunks(self):
"""Test that AITextSummarizerBlock correctly accumulates stats across multiple chunks."""
@@ -488,6 +548,154 @@ class TestLLMStatsTracking:
assert outputs["response"] == {"result": "test"}
class TestAIConversationBlockValidation:
"""Test that AIConversationBlock validates inputs before calling the LLM."""
@pytest.mark.asyncio
async def test_empty_messages_and_empty_prompt_raises_error(self):
"""Empty messages with no prompt should raise ValueError, not a cryptic API error."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_empty_messages_with_prompt_succeeds(self):
"""Empty messages but a non-empty prompt should proceed without error."""
block = llm.AIConversationBlock()
async def mock_llm_call(input_data, credentials):
return {"response": "OK"}
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
input_data = llm.AIConversationBlock.Input(
messages=[],
prompt="Hello, how are you?",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
outputs = {}
async for name, data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[name] = data
assert outputs["response"] == "OK"
@pytest.mark.asyncio
async def test_nonempty_messages_with_empty_prompt_succeeds(self):
"""Non-empty messages with no prompt should proceed without error."""
block = llm.AIConversationBlock()
async def mock_llm_call(input_data, credentials):
return {"response": "response from conversation"}
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": "Hello"}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
outputs = {}
async for name, data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[name] = data
assert outputs["response"] == "response from conversation"
@pytest.mark.asyncio
async def test_messages_with_empty_content_raises_error(self):
"""Messages with empty content strings should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": ""}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_whitespace_content_raises_error(self):
"""Messages with whitespace-only content should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": " "}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_none_entry_raises_error(self):
"""Messages list containing None should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[None],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_empty_dict_raises_error(self):
"""Messages list containing empty dict should be treated as no messages."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_messages_with_none_content_raises_error(self):
"""Messages with content=None should not crash with AttributeError."""
block = llm.AIConversationBlock()
input_data = llm.AIConversationBlock.Input(
messages=[{"role": "user", "content": None}],
prompt="",
model=llm.DEFAULT_LLM_MODEL,
credentials=_TEST_AI_CREDENTIALS,
)
with pytest.raises(ValueError, match="no messages and no prompt"):
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
pass
class TestAITextSummarizerValidation:
"""Test that AITextSummarizerBlock validates LLM responses are strings."""
@@ -839,3 +1047,63 @@ class TestLlmModelMissing:
assert (
llm.LlmModel("extra/google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
)
class TestExtractOpenRouterCost:
"""Tests for extract_openrouter_cost — the x-total-cost header parser."""
def _mk_response(self, headers: dict | None):
response = MagicMock()
if headers is None:
response._response = None
else:
raw = MagicMock()
raw.headers = headers
response._response = raw
return response
def test_extracts_numeric_cost(self):
response = self._mk_response({"x-total-cost": "0.0042"})
assert llm.extract_openrouter_cost(response) == 0.0042
def test_returns_none_when_header_missing(self):
response = self._mk_response({})
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_when_header_empty_string(self):
response = self._mk_response({"x-total-cost": ""})
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_when_header_non_numeric(self):
response = self._mk_response({"x-total-cost": "not-a-number"})
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_when_no_response_attr(self):
response = MagicMock(spec=[]) # no _response attr
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_when_raw_is_none(self):
response = self._mk_response(None)
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_when_raw_has_no_headers(self):
response = MagicMock()
response._response = MagicMock(spec=[]) # no headers attr
assert llm.extract_openrouter_cost(response) is None
def test_returns_zero_for_zero_cost(self):
"""Zero-cost is a valid value (free tier) and must not become None."""
response = self._mk_response({"x-total-cost": "0"})
assert llm.extract_openrouter_cost(response) == 0.0
def test_returns_none_for_inf(self):
response = self._mk_response({"x-total-cost": "inf"})
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_for_negative_inf(self):
response = self._mk_response({"x-total-cost": "-inf"})
assert llm.extract_openrouter_cost(response) is None
def test_returns_none_for_nan(self):
response = self._mk_response({"x-total-cost": "nan"})
assert llm.extract_openrouter_cost(response) is None

View File

@@ -13,6 +13,7 @@ from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -104,4 +105,10 @@ class UnrealTextToSpeechBlock(Block):
input_data.text,
input_data.voice_id,
)
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(input_data.text)),
provider_cost_type="characters",
)
)
yield "mp3_url", api_response["OutputUri"]

File diff suppressed because it is too large Load Diff

View File

@@ -31,7 +31,7 @@ async def test_baseline_multi_turn(setup_test_user, test_user_id):
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
session = await create_chat_session(test_user_id, dry_run=False)
session = await upsert_chat_session(session)
# --- Turn 1: send a message with a unique keyword ---

View File

@@ -0,0 +1,799 @@
"""Unit tests for baseline service pure-logic helpers.
These tests cover ``_baseline_conversation_updater`` and ``_BaselineStreamState``
without requiring API keys, database connections, or network access.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from openai.types.chat import ChatCompletionToolParam
from backend.copilot.baseline.service import (
_baseline_conversation_updater,
_BaselineStreamState,
_compress_session_messages,
_ThinkingStripper,
)
from backend.copilot.model import ChatMessage
from backend.copilot.transcript_builder import TranscriptBuilder
from backend.util.prompt import CompressResult
from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult
class TestBaselineStreamState:
def test_defaults(self):
state = _BaselineStreamState()
assert state.pending_events == []
assert state.assistant_text == ""
assert state.text_started is False
assert state.turn_prompt_tokens == 0
assert state.turn_completion_tokens == 0
assert state.text_block_id # Should be a UUID string
def test_mutable_fields(self):
state = _BaselineStreamState()
state.assistant_text = "hello"
state.turn_prompt_tokens = 100
state.turn_completion_tokens = 50
assert state.assistant_text == "hello"
assert state.turn_prompt_tokens == 100
assert state.turn_completion_tokens == 50
class TestBaselineConversationUpdater:
"""Tests for _baseline_conversation_updater which updates the OpenAI
message list and transcript builder after each LLM call."""
def _make_transcript_builder(self) -> TranscriptBuilder:
builder = TranscriptBuilder()
builder.append_user("test question")
return builder
def test_text_only_response(self):
"""When the LLM returns text without tool calls, the updater appends
a single assistant message and records it in the transcript."""
messages: list = []
builder = self._make_transcript_builder()
response = LLMLoopResponse(
response_text="Hello, world!",
tool_calls=[],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
_baseline_conversation_updater(
messages,
response,
tool_results=None,
transcript_builder=builder,
model="test-model",
)
assert len(messages) == 1
assert messages[0]["role"] == "assistant"
assert messages[0]["content"] == "Hello, world!"
# Transcript should have user + assistant
assert builder.entry_count == 2
assert builder.last_entry_type == "assistant"
def test_tool_calls_response(self):
"""When the LLM returns tool calls, the updater appends the assistant
message with tool_calls and tool result messages."""
messages: list = []
builder = self._make_transcript_builder()
response = LLMLoopResponse(
response_text="Let me search...",
tool_calls=[
LLMToolCall(
id="tc_1",
name="search",
arguments='{"query": "test"}',
),
],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
tool_results = [
ToolCallResult(
tool_call_id="tc_1",
tool_name="search",
content="Found result",
),
]
_baseline_conversation_updater(
messages,
response,
tool_results=tool_results,
transcript_builder=builder,
model="test-model",
)
# Messages: assistant (with tool_calls) + tool result
assert len(messages) == 2
assert messages[0]["role"] == "assistant"
assert messages[0]["content"] == "Let me search..."
assert len(messages[0]["tool_calls"]) == 1
assert messages[0]["tool_calls"][0]["id"] == "tc_1"
assert messages[1]["role"] == "tool"
assert messages[1]["tool_call_id"] == "tc_1"
assert messages[1]["content"] == "Found result"
# Transcript: user + assistant(tool_use) + user(tool_result)
assert builder.entry_count == 3
def test_tool_calls_without_text(self):
"""Tool calls without accompanying text should still work."""
messages: list = []
builder = self._make_transcript_builder()
response = LLMLoopResponse(
response_text=None,
tool_calls=[
LLMToolCall(id="tc_1", name="run", arguments="{}"),
],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
tool_results = [
ToolCallResult(tool_call_id="tc_1", tool_name="run", content="done"),
]
_baseline_conversation_updater(
messages,
response,
tool_results=tool_results,
transcript_builder=builder,
model="test-model",
)
assert len(messages) == 2
assert "content" not in messages[0] # No text content
assert messages[0]["tool_calls"][0]["function"]["name"] == "run"
def test_no_text_no_tools(self):
"""When the response has no text and no tool calls, nothing is appended."""
messages: list = []
builder = self._make_transcript_builder()
response = LLMLoopResponse(
response_text=None,
tool_calls=[],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
_baseline_conversation_updater(
messages,
response,
tool_results=None,
transcript_builder=builder,
model="test-model",
)
assert len(messages) == 0
# Only the user entry from setup
assert builder.entry_count == 1
def test_multiple_tool_calls(self):
"""Multiple tool calls in a single response are all recorded."""
messages: list = []
builder = self._make_transcript_builder()
response = LLMLoopResponse(
response_text=None,
tool_calls=[
LLMToolCall(id="tc_1", name="tool_a", arguments="{}"),
LLMToolCall(id="tc_2", name="tool_b", arguments='{"x": 1}'),
],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
tool_results = [
ToolCallResult(tool_call_id="tc_1", tool_name="tool_a", content="result_a"),
ToolCallResult(tool_call_id="tc_2", tool_name="tool_b", content="result_b"),
]
_baseline_conversation_updater(
messages,
response,
tool_results=tool_results,
transcript_builder=builder,
model="test-model",
)
# 1 assistant + 2 tool results
assert len(messages) == 3
assert len(messages[0]["tool_calls"]) == 2
assert messages[1]["tool_call_id"] == "tc_1"
assert messages[2]["tool_call_id"] == "tc_2"
def test_invalid_tool_arguments_handled(self):
"""Tool call with invalid JSON arguments: the arguments field is
stored as-is in the message, and orjson failure falls back to {}
in the transcript content_blocks."""
messages: list = []
builder = self._make_transcript_builder()
response = LLMLoopResponse(
response_text=None,
tool_calls=[
LLMToolCall(id="tc_1", name="tool_x", arguments="not-json"),
],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
tool_results = [
ToolCallResult(tool_call_id="tc_1", tool_name="tool_x", content="ok"),
]
_baseline_conversation_updater(
messages,
response,
tool_results=tool_results,
transcript_builder=builder,
model="test-model",
)
# Should not raise — invalid JSON falls back to {} in transcript
assert len(messages) == 2
assert messages[0]["tool_calls"][0]["function"]["arguments"] == "not-json"
class TestCompressSessionMessagesPreservesToolCalls:
"""``_compress_session_messages`` must round-trip tool_calls + tool_call_id.
Compression serialises ChatMessage to dict for ``compress_context`` and
reifies the result back to ChatMessage. A regression that drops
``tool_calls`` or ``tool_call_id`` would corrupt the OpenAI message
list and break downstream tool-execution rounds.
"""
@pytest.mark.asyncio
async def test_compressed_output_keeps_tool_calls_and_ids(self):
# Simulate compression that returns a summary + the most recent
# assistant(tool_call) + tool(tool_result) intact.
summary = {"role": "system", "content": "prior turns: user asked X"}
assistant_with_tc = {
"role": "assistant",
"content": "calling tool",
"tool_calls": [
{
"id": "tc_abc",
"type": "function",
"function": {"name": "search", "arguments": '{"q":"y"}'},
}
],
}
tool_result = {
"role": "tool",
"tool_call_id": "tc_abc",
"content": "search result",
}
compress_result = CompressResult(
messages=[summary, assistant_with_tc, tool_result],
token_count=100,
was_compacted=True,
original_token_count=5000,
messages_summarized=10,
messages_dropped=0,
)
# Input: messages that should be compressed.
input_messages = [
ChatMessage(role="user", content="q1"),
ChatMessage(
role="assistant",
content="calling tool",
tool_calls=[
{
"id": "tc_abc",
"type": "function",
"function": {
"name": "search",
"arguments": '{"q":"y"}',
},
}
],
),
ChatMessage(
role="tool",
tool_call_id="tc_abc",
content="search result",
),
]
with patch(
"backend.copilot.baseline.service.compress_context",
new=AsyncMock(return_value=compress_result),
):
compressed = await _compress_session_messages(
input_messages, model="openrouter/anthropic/claude-opus-4"
)
# Summary, assistant(tool_calls), tool(tool_call_id).
assert len(compressed) == 3
# Assistant message must keep its tool_calls intact.
assistant_msg = compressed[1]
assert assistant_msg.role == "assistant"
assert assistant_msg.tool_calls is not None
assert len(assistant_msg.tool_calls) == 1
assert assistant_msg.tool_calls[0]["id"] == "tc_abc"
assert assistant_msg.tool_calls[0]["function"]["name"] == "search"
# Tool-role message must keep tool_call_id for OpenAI linkage.
tool_msg = compressed[2]
assert tool_msg.role == "tool"
assert tool_msg.tool_call_id == "tc_abc"
assert tool_msg.content == "search result"
@pytest.mark.asyncio
async def test_uncompressed_passthrough_keeps_fields(self):
"""When compression is a no-op (was_compacted=False), the original
messages must be returned unchanged — including tool_calls."""
input_messages = [
ChatMessage(
role="assistant",
content="c",
tool_calls=[
{
"id": "t1",
"type": "function",
"function": {"name": "f", "arguments": "{}"},
}
],
),
ChatMessage(role="tool", tool_call_id="t1", content="ok"),
]
noop_result = CompressResult(
messages=[], # ignored when was_compacted=False
token_count=10,
was_compacted=False,
)
with patch(
"backend.copilot.baseline.service.compress_context",
new=AsyncMock(return_value=noop_result),
):
out = await _compress_session_messages(
input_messages, model="openrouter/anthropic/claude-opus-4"
)
assert out is input_messages # same list returned
assert out[0].tool_calls is not None
assert out[0].tool_calls[0]["id"] == "t1"
assert out[1].tool_call_id == "t1"
# ---- _ThinkingStripper tests ---- #
def test_thinking_stripper_basic_thinking_tag() -> None:
"""<thinking>...</thinking> blocks are fully stripped."""
s = _ThinkingStripper()
assert s.process("<thinking>internal reasoning here</thinking>Hello!") == "Hello!"
def test_thinking_stripper_internal_reasoning_tag() -> None:
"""<internal_reasoning>...</internal_reasoning> blocks (Gemini) are stripped."""
s = _ThinkingStripper()
assert (
s.process("<internal_reasoning>step by step</internal_reasoning>Answer")
== "Answer"
)
def test_thinking_stripper_split_across_chunks() -> None:
"""Tags split across multiple chunks are handled correctly."""
s = _ThinkingStripper()
out = s.process("Hello <thin")
out += s.process("king>secret</thinking> world")
assert out == "Hello world"
def test_thinking_stripper_plain_text_preserved() -> None:
"""Plain text with the word 'thinking' is not stripped."""
s = _ThinkingStripper()
assert (
s.process("I am thinking about this problem")
== "I am thinking about this problem"
)
def test_thinking_stripper_multiple_blocks() -> None:
"""Multiple reasoning blocks in one stream are all stripped."""
s = _ThinkingStripper()
result = s.process(
"A<thinking>x</thinking>B<internal_reasoning>y</internal_reasoning>C"
)
assert result == "ABC"
def test_thinking_stripper_flush_discards_unclosed() -> None:
"""Unclosed reasoning block is discarded on flush."""
s = _ThinkingStripper()
s.process("Start<thinking>never closed")
flushed = s.flush()
assert "never closed" not in flushed
def test_thinking_stripper_empty_block() -> None:
"""Empty reasoning blocks are handled gracefully."""
s = _ThinkingStripper()
assert s.process("Before<thinking></thinking>After") == "BeforeAfter"
# ---- _filter_tools_by_permissions tests ---- #
def _make_tool(name: str) -> ChatCompletionToolParam:
"""Build a minimal OpenAI ChatCompletionToolParam."""
return ChatCompletionToolParam(
type="function",
function={"name": name, "parameters": {}},
)
class TestFilterToolsByPermissions:
"""Tests for _filter_tools_by_permissions."""
@patch(
"backend.copilot.permissions.all_known_tool_names",
return_value=frozenset({"run_block", "web_fetch", "bash_exec"}),
)
def test_empty_permissions_returns_all(self, _mock_names):
"""Empty permissions (no filtering) returns every tool unchanged."""
from backend.copilot.baseline.service import _filter_tools_by_permissions
from backend.copilot.permissions import CopilotPermissions
tools = [_make_tool("run_block"), _make_tool("web_fetch")]
perms = CopilotPermissions()
result = _filter_tools_by_permissions(tools, perms)
assert result == tools
@patch(
"backend.copilot.permissions.all_known_tool_names",
return_value=frozenset({"run_block", "web_fetch", "bash_exec"}),
)
def test_allowlist_keeps_only_matching(self, _mock_names):
"""Explicit allowlist (tools_exclude=False) keeps only listed tools."""
from backend.copilot.baseline.service import _filter_tools_by_permissions
from backend.copilot.permissions import CopilotPermissions
tools = [
_make_tool("run_block"),
_make_tool("web_fetch"),
_make_tool("bash_exec"),
]
perms = CopilotPermissions(tools=["web_fetch"], tools_exclude=False)
result = _filter_tools_by_permissions(tools, perms)
assert len(result) == 1
assert result[0]["function"]["name"] == "web_fetch"
@patch(
"backend.copilot.permissions.all_known_tool_names",
return_value=frozenset({"run_block", "web_fetch", "bash_exec"}),
)
def test_blacklist_excludes_listed(self, _mock_names):
"""Blacklist (tools_exclude=True) removes only the listed tools."""
from backend.copilot.baseline.service import _filter_tools_by_permissions
from backend.copilot.permissions import CopilotPermissions
tools = [
_make_tool("run_block"),
_make_tool("web_fetch"),
_make_tool("bash_exec"),
]
perms = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
result = _filter_tools_by_permissions(tools, perms)
names = [t["function"]["name"] for t in result]
assert "bash_exec" not in names
assert "run_block" in names
assert "web_fetch" in names
assert len(result) == 2
@patch(
"backend.copilot.permissions.all_known_tool_names",
return_value=frozenset({"run_block", "web_fetch", "bash_exec"}),
)
def test_unknown_tool_name_filtered_out(self, _mock_names):
"""A tool whose name is not in all_known_tool_names is dropped."""
from backend.copilot.baseline.service import _filter_tools_by_permissions
from backend.copilot.permissions import CopilotPermissions
tools = [_make_tool("run_block"), _make_tool("unknown_tool")]
perms = CopilotPermissions(tools=["run_block"], tools_exclude=False)
result = _filter_tools_by_permissions(tools, perms)
names = [t["function"]["name"] for t in result]
assert "unknown_tool" not in names
assert names == ["run_block"]
# ---- _prepare_baseline_attachments tests ---- #
class TestPrepareBaselineAttachments:
"""Tests for _prepare_baseline_attachments."""
@pytest.mark.asyncio
async def test_empty_file_ids(self):
"""Empty file_ids returns empty hint and blocks."""
from backend.copilot.baseline.service import _prepare_baseline_attachments
hint, blocks = await _prepare_baseline_attachments([], "user1", "sess1", "/tmp")
assert hint == ""
assert blocks == []
@pytest.mark.asyncio
async def test_empty_user_id(self):
"""Empty user_id returns empty hint and blocks."""
from backend.copilot.baseline.service import _prepare_baseline_attachments
hint, blocks = await _prepare_baseline_attachments(
["file1"], "", "sess1", "/tmp"
)
assert hint == ""
assert blocks == []
@pytest.mark.asyncio
async def test_image_file_returns_vision_blocks(self):
"""A PNG image within size limits is returned as a base64 vision block."""
from backend.copilot.baseline.service import _prepare_baseline_attachments
fake_info = AsyncMock()
fake_info.name = "photo.png"
fake_info.mime_type = "image/png"
fake_info.size_bytes = 1024
fake_manager = AsyncMock()
fake_manager.get_file_info = AsyncMock(return_value=fake_info)
fake_manager.read_file_by_id = AsyncMock(return_value=b"\x89PNG_FAKE_DATA")
with patch(
"backend.copilot.baseline.service.get_workspace_manager",
new=AsyncMock(return_value=fake_manager),
):
hint, blocks = await _prepare_baseline_attachments(
["fid1"], "user1", "sess1", "/tmp/workdir"
)
assert len(blocks) == 1
assert blocks[0]["type"] == "image"
assert blocks[0]["source"]["media_type"] == "image/png"
assert blocks[0]["source"]["type"] == "base64"
assert "photo.png" in hint
assert "embedded as image" in hint
@pytest.mark.asyncio
async def test_non_image_file_saved_to_working_dir(self, tmp_path):
"""A non-image file is written to working_dir."""
from backend.copilot.baseline.service import _prepare_baseline_attachments
fake_info = AsyncMock()
fake_info.name = "data.csv"
fake_info.mime_type = "text/csv"
fake_info.size_bytes = 42
fake_manager = AsyncMock()
fake_manager.get_file_info = AsyncMock(return_value=fake_info)
fake_manager.read_file_by_id = AsyncMock(return_value=b"col1,col2\na,b")
with patch(
"backend.copilot.baseline.service.get_workspace_manager",
new=AsyncMock(return_value=fake_manager),
):
hint, blocks = await _prepare_baseline_attachments(
["fid1"], "user1", "sess1", str(tmp_path)
)
assert blocks == []
assert "data.csv" in hint
assert "saved to" in hint
saved = tmp_path / "data.csv"
assert saved.exists()
assert saved.read_bytes() == b"col1,col2\na,b"
@pytest.mark.asyncio
async def test_file_not_found_skipped(self):
"""When get_file_info returns None the file is silently skipped."""
from backend.copilot.baseline.service import _prepare_baseline_attachments
fake_manager = AsyncMock()
fake_manager.get_file_info = AsyncMock(return_value=None)
with patch(
"backend.copilot.baseline.service.get_workspace_manager",
new=AsyncMock(return_value=fake_manager),
):
hint, blocks = await _prepare_baseline_attachments(
["missing_id"], "user1", "sess1", "/tmp"
)
assert hint == ""
assert blocks == []
@pytest.mark.asyncio
async def test_workspace_manager_error(self):
"""When get_workspace_manager raises, returns empty results."""
from backend.copilot.baseline.service import _prepare_baseline_attachments
with patch(
"backend.copilot.baseline.service.get_workspace_manager",
new=AsyncMock(side_effect=RuntimeError("connection failed")),
):
hint, blocks = await _prepare_baseline_attachments(
["fid1"], "user1", "sess1", "/tmp"
)
assert hint == ""
assert blocks == []
class TestBaselineCostExtraction:
"""Tests for x-total-cost header extraction in _baseline_llm_caller."""
@pytest.mark.asyncio
async def test_cost_usd_extracted_from_response_header(self):
"""state.cost_usd is set from x-total-cost header when present."""
from backend.copilot.baseline.service import (
_baseline_llm_caller,
_BaselineStreamState,
)
state = _BaselineStreamState(model="gpt-4o-mini")
# Build a mock raw httpx response with the cost header
mock_raw_response = MagicMock()
mock_raw_response.headers = {"x-total-cost": "0.0123"}
# Build a mock async streaming response that yields no chunks but has
# a _response attribute pointing to the mock httpx response
mock_stream_response = MagicMock()
mock_stream_response._response = mock_raw_response
async def empty_aiter():
return
yield # make it an async generator
mock_stream_response.__aiter__ = lambda self: empty_aiter()
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(
return_value=mock_stream_response
)
with patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_client,
):
await _baseline_llm_caller(
messages=[{"role": "user", "content": "hi"}],
tools=[],
state=state,
)
assert state.cost_usd == pytest.approx(0.0123)
@pytest.mark.asyncio
async def test_cost_usd_accumulates_across_calls(self):
"""cost_usd accumulates when _baseline_llm_caller is called multiple times."""
from backend.copilot.baseline.service import (
_baseline_llm_caller,
_BaselineStreamState,
)
state = _BaselineStreamState(model="gpt-4o-mini")
def make_stream_mock(cost: str) -> MagicMock:
mock_raw = MagicMock()
mock_raw.headers = {"x-total-cost": cost}
mock_stream = MagicMock()
mock_stream._response = mock_raw
async def empty_aiter():
return
yield
mock_stream.__aiter__ = lambda self: empty_aiter()
return mock_stream
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(
side_effect=[make_stream_mock("0.01"), make_stream_mock("0.02")]
)
with patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_client,
):
await _baseline_llm_caller(
messages=[{"role": "user", "content": "first"}],
tools=[],
state=state,
)
await _baseline_llm_caller(
messages=[{"role": "user", "content": "second"}],
tools=[],
state=state,
)
assert state.cost_usd == pytest.approx(0.03)
@pytest.mark.asyncio
async def test_no_cost_when_header_absent(self):
"""state.cost_usd remains None when response has no x-total-cost header."""
from backend.copilot.baseline.service import (
_baseline_llm_caller,
_BaselineStreamState,
)
state = _BaselineStreamState(model="gpt-4o-mini")
mock_raw = MagicMock()
mock_raw.headers = {}
mock_stream = MagicMock()
mock_stream._response = mock_raw
async def empty_aiter():
return
yield
mock_stream.__aiter__ = lambda self: empty_aiter()
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)
with patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_client,
):
await _baseline_llm_caller(
messages=[{"role": "user", "content": "hi"}],
tools=[],
state=state,
)
assert state.cost_usd is None
@pytest.mark.asyncio
async def test_cost_extracted_even_when_stream_raises(self):
"""cost_usd is captured in the finally block even when streaming fails."""
from backend.copilot.baseline.service import (
_baseline_llm_caller,
_BaselineStreamState,
)
state = _BaselineStreamState(model="gpt-4o-mini")
mock_raw = MagicMock()
mock_raw.headers = {"x-total-cost": "0.005"}
mock_stream = MagicMock()
mock_stream._response = mock_raw
async def failing_aiter():
raise RuntimeError("stream error")
yield # make it an async generator
mock_stream.__aiter__ = lambda self: failing_aiter()
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)
with (
patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_client,
),
pytest.raises(RuntimeError, match="stream error"),
):
await _baseline_llm_caller(
messages=[{"role": "user", "content": "hi"}],
tools=[],
state=state,
)
assert state.cost_usd == pytest.approx(0.005)

View File

@@ -0,0 +1,667 @@
"""Integration tests for baseline transcript flow.
Exercises the real helpers in ``baseline/service.py`` that download,
validate, load, append to, backfill, and upload the transcript.
Storage is mocked via ``download_transcript`` / ``upload_transcript``
patches; no network access is required.
"""
import json as stdlib_json
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.baseline.service import (
_load_prior_transcript,
_record_turn_to_transcript,
_resolve_baseline_model,
_upload_final_transcript,
is_transcript_stale,
should_upload_transcript,
)
from backend.copilot.service import config
from backend.copilot.transcript import (
STOP_REASON_END_TURN,
STOP_REASON_TOOL_USE,
TranscriptDownload,
)
from backend.copilot.transcript_builder import TranscriptBuilder
from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult
def _make_transcript_content(*roles: str) -> str:
"""Build a minimal valid JSONL transcript from role names."""
lines = []
parent = ""
for i, role in enumerate(roles):
uid = f"uuid-{i}"
entry: dict = {
"type": role,
"uuid": uid,
"parentUuid": parent,
"message": {
"role": role,
"content": [{"type": "text", "text": f"{role} message {i}"}],
},
}
if role == "assistant":
entry["message"]["id"] = f"msg_{i}"
entry["message"]["model"] = "test-model"
entry["message"]["type"] = "message"
entry["message"]["stop_reason"] = STOP_REASON_END_TURN
lines.append(stdlib_json.dumps(entry))
parent = uid
return "\n".join(lines) + "\n"
class TestResolveBaselineModel:
"""Model selection honours the per-request mode."""
def test_fast_mode_selects_fast_model(self):
assert _resolve_baseline_model("fast") == config.fast_model
def test_extended_thinking_selects_default_model(self):
assert _resolve_baseline_model("extended_thinking") == config.model
def test_none_mode_selects_default_model(self):
"""Critical: baseline users without a mode MUST keep the default (opus)."""
assert _resolve_baseline_model(None) == config.model
def test_default_and_fast_models_differ(self):
"""Sanity: the two tiers are actually distinct in production config."""
assert config.model != config.fast_model
class TestLoadPriorTranscript:
"""``_load_prior_transcript`` wraps the download + validate + load flow."""
@pytest.mark.asyncio
async def test_loads_fresh_transcript(self):
builder = TranscriptBuilder()
content = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=content, message_count=2)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
transcript_builder=builder,
)
assert covers is True
assert builder.entry_count == 2
assert builder.last_entry_type == "assistant"
@pytest.mark.asyncio
async def test_rejects_stale_transcript(self):
"""msg_count strictly less than session-1 is treated as stale."""
builder = TranscriptBuilder()
content = _make_transcript_content("user", "assistant")
# session has 6 messages, transcript only covers 2 → stale.
download = TranscriptDownload(content=content, message_count=2)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=6,
transcript_builder=builder,
)
assert covers is False
assert builder.is_empty
@pytest.mark.asyncio
async def test_missing_transcript_returns_false(self):
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=None),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
transcript_builder=builder,
)
assert covers is False
assert builder.is_empty
@pytest.mark.asyncio
async def test_invalid_transcript_returns_false(self):
builder = TranscriptBuilder()
download = TranscriptDownload(
content='{"type":"progress","uuid":"a"}\n',
message_count=1,
)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
transcript_builder=builder,
)
assert covers is False
assert builder.is_empty
@pytest.mark.asyncio
async def test_download_exception_returns_false(self):
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(side_effect=RuntimeError("boom")),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
transcript_builder=builder,
)
assert covers is False
assert builder.is_empty
@pytest.mark.asyncio
async def test_zero_message_count_not_stale(self):
"""When msg_count is 0 (unknown), staleness check is skipped."""
builder = TranscriptBuilder()
download = TranscriptDownload(
content=_make_transcript_content("user", "assistant"),
message_count=0,
)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=20,
transcript_builder=builder,
)
assert covers is True
assert builder.entry_count == 2
class TestUploadFinalTranscript:
"""``_upload_final_transcript`` serialises and calls storage."""
@pytest.mark.asyncio
async def test_uploads_valid_transcript(self):
builder = TranscriptBuilder()
builder.append_user(content="hi")
builder.append_assistant(
content_blocks=[{"type": "text", "text": "hello"}],
model="test-model",
stop_reason=STOP_REASON_END_TURN,
)
upload_mock = AsyncMock(return_value=None)
with patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
):
await _upload_final_transcript(
user_id="user-1",
session_id="session-1",
transcript_builder=builder,
session_msg_count=2,
)
upload_mock.assert_awaited_once()
assert upload_mock.await_args is not None
call_kwargs = upload_mock.await_args.kwargs
assert call_kwargs["user_id"] == "user-1"
assert call_kwargs["session_id"] == "session-1"
assert call_kwargs["message_count"] == 2
assert "hello" in call_kwargs["content"]
@pytest.mark.asyncio
async def test_skips_upload_when_builder_empty(self):
builder = TranscriptBuilder()
upload_mock = AsyncMock(return_value=None)
with patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
):
await _upload_final_transcript(
user_id="user-1",
session_id="session-1",
transcript_builder=builder,
session_msg_count=0,
)
upload_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_swallows_upload_exceptions(self):
"""Upload failures should not propagate (flow continues for the user)."""
builder = TranscriptBuilder()
builder.append_user(content="hi")
builder.append_assistant(
content_blocks=[{"type": "text", "text": "hello"}],
model="test-model",
stop_reason=STOP_REASON_END_TURN,
)
with patch(
"backend.copilot.baseline.service.upload_transcript",
new=AsyncMock(side_effect=RuntimeError("storage unavailable")),
):
# Should not raise.
await _upload_final_transcript(
user_id="user-1",
session_id="session-1",
transcript_builder=builder,
session_msg_count=2,
)
class TestRecordTurnToTranscript:
"""``_record_turn_to_transcript`` translates LLMLoopResponse → transcript."""
def test_records_final_assistant_text(self):
builder = TranscriptBuilder()
builder.append_user(content="hi")
response = LLMLoopResponse(
response_text="hello there",
tool_calls=[],
raw_response=None,
)
_record_turn_to_transcript(
response,
tool_results=None,
transcript_builder=builder,
model="test-model",
)
assert builder.entry_count == 2
assert builder.last_entry_type == "assistant"
jsonl = builder.to_jsonl()
assert "hello there" in jsonl
assert STOP_REASON_END_TURN in jsonl
def test_records_tool_use_then_tool_result(self):
"""Anthropic ordering: assistant(tool_use) → user(tool_result)."""
builder = TranscriptBuilder()
builder.append_user(content="use a tool")
response = LLMLoopResponse(
response_text=None,
tool_calls=[
LLMToolCall(id="call-1", name="echo", arguments='{"text":"hi"}')
],
raw_response=None,
)
tool_results = [
ToolCallResult(tool_call_id="call-1", tool_name="echo", content="hi")
]
_record_turn_to_transcript(
response,
tool_results,
transcript_builder=builder,
model="test-model",
)
# user, assistant(tool_use), user(tool_result) = 3 entries
assert builder.entry_count == 3
jsonl = builder.to_jsonl()
assert STOP_REASON_TOOL_USE in jsonl
assert "tool_use" in jsonl
assert "tool_result" in jsonl
assert "call-1" in jsonl
def test_records_nothing_on_empty_response(self):
builder = TranscriptBuilder()
builder.append_user(content="hi")
response = LLMLoopResponse(
response_text=None,
tool_calls=[],
raw_response=None,
)
_record_turn_to_transcript(
response,
tool_results=None,
transcript_builder=builder,
model="test-model",
)
assert builder.entry_count == 1
def test_malformed_tool_args_dont_crash(self):
"""Bad JSON in tool arguments falls back to {} without raising."""
builder = TranscriptBuilder()
builder.append_user(content="hi")
response = LLMLoopResponse(
response_text=None,
tool_calls=[LLMToolCall(id="call-1", name="echo", arguments="{not-json")],
raw_response=None,
)
tool_results = [
ToolCallResult(tool_call_id="call-1", tool_name="echo", content="ok")
]
_record_turn_to_transcript(
response,
tool_results,
transcript_builder=builder,
model="test-model",
)
assert builder.entry_count == 3
jsonl = builder.to_jsonl()
assert '"input":{}' in jsonl
class TestRoundTrip:
"""End-to-end: load prior → append new turn → upload."""
@pytest.mark.asyncio
async def test_full_round_trip(self):
prior = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=prior, message_count=2)
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
transcript_builder=builder,
)
assert covers is True
assert builder.entry_count == 2
# New user turn.
builder.append_user(content="new question")
assert builder.entry_count == 3
# New assistant turn.
response = LLMLoopResponse(
response_text="new answer",
tool_calls=[],
raw_response=None,
)
_record_turn_to_transcript(
response,
tool_results=None,
transcript_builder=builder,
model="test-model",
)
assert builder.entry_count == 4
# Upload.
upload_mock = AsyncMock(return_value=None)
with patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
):
await _upload_final_transcript(
user_id="user-1",
session_id="session-1",
transcript_builder=builder,
session_msg_count=4,
)
upload_mock.assert_awaited_once()
assert upload_mock.await_args is not None
uploaded = upload_mock.await_args.kwargs["content"]
assert "new question" in uploaded
assert "new answer" in uploaded
# Original content preserved in the round trip.
assert "user message 0" in uploaded
assert "assistant message 1" in uploaded
@pytest.mark.asyncio
async def test_backfill_append_guard(self):
"""Backfill only runs when the last entry is not already assistant."""
builder = TranscriptBuilder()
builder.append_user(content="hi")
# Simulate the backfill guard from stream_chat_completion_baseline.
assistant_text = "partial text before error"
if builder.last_entry_type != "assistant":
builder.append_assistant(
content_blocks=[{"type": "text", "text": assistant_text}],
model="test-model",
stop_reason=STOP_REASON_END_TURN,
)
assert builder.last_entry_type == "assistant"
assert "partial text before error" in builder.to_jsonl()
# Second invocation: the guard must prevent double-append.
initial_count = builder.entry_count
if builder.last_entry_type != "assistant":
builder.append_assistant(
content_blocks=[{"type": "text", "text": "duplicate"}],
model="test-model",
stop_reason=STOP_REASON_END_TURN,
)
assert builder.entry_count == initial_count
class TestIsTranscriptStale:
"""``is_transcript_stale`` gates prior-transcript loading."""
def test_none_download_is_not_stale(self):
assert is_transcript_stale(None, session_msg_count=5) is False
def test_zero_message_count_is_not_stale(self):
"""Legacy transcripts without msg_count tracking must remain usable."""
dl = TranscriptDownload(content="", message_count=0)
assert is_transcript_stale(dl, session_msg_count=20) is False
def test_stale_when_covers_less_than_prefix(self):
dl = TranscriptDownload(content="", message_count=2)
# session has 6 messages; transcript must cover at least 5 (6-1).
assert is_transcript_stale(dl, session_msg_count=6) is True
def test_fresh_when_covers_full_prefix(self):
dl = TranscriptDownload(content="", message_count=5)
assert is_transcript_stale(dl, session_msg_count=6) is False
def test_fresh_when_exceeds_prefix(self):
"""Race: transcript ahead of session count is still acceptable."""
dl = TranscriptDownload(content="", message_count=10)
assert is_transcript_stale(dl, session_msg_count=6) is False
def test_boundary_equal_to_prefix_minus_one(self):
dl = TranscriptDownload(content="", message_count=5)
assert is_transcript_stale(dl, session_msg_count=6) is False
class TestShouldUploadTranscript:
"""``should_upload_transcript`` gates the final upload."""
def test_upload_allowed_for_user_with_coverage(self):
assert should_upload_transcript("user-1", True) is True
def test_upload_skipped_when_no_user(self):
assert should_upload_transcript(None, True) is False
def test_upload_skipped_when_empty_user(self):
assert should_upload_transcript("", True) is False
def test_upload_skipped_without_coverage(self):
"""Partial transcript must never clobber a more complete stored one."""
assert should_upload_transcript("user-1", False) is False
def test_upload_skipped_when_no_user_and_no_coverage(self):
assert should_upload_transcript(None, False) is False
class TestTranscriptLifecycle:
"""End-to-end: download → validate → build → upload.
Simulates the full transcript lifecycle inside
``stream_chat_completion_baseline`` by mocking the storage layer and
driving each step through the real helpers.
"""
@pytest.mark.asyncio
async def test_full_lifecycle_happy_path(self):
"""Fresh download, append a turn, upload covers the session."""
builder = TranscriptBuilder()
prior = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=prior, message_count=2)
upload_mock = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
),
patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
),
):
# --- 1. Download & load prior transcript ---
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
transcript_builder=builder,
)
assert covers is True
# --- 2. Append a new user turn + a new assistant response ---
builder.append_user(content="follow-up question")
_record_turn_to_transcript(
LLMLoopResponse(
response_text="follow-up answer",
tool_calls=[],
raw_response=None,
),
tool_results=None,
transcript_builder=builder,
model="test-model",
)
# --- 3. Gate + upload ---
assert (
should_upload_transcript(
user_id="user-1", transcript_covers_prefix=covers
)
is True
)
await _upload_final_transcript(
user_id="user-1",
session_id="session-1",
transcript_builder=builder,
session_msg_count=4,
)
upload_mock.assert_awaited_once()
assert upload_mock.await_args is not None
uploaded = upload_mock.await_args.kwargs["content"]
assert "follow-up question" in uploaded
assert "follow-up answer" in uploaded
# Original prior-turn content preserved.
assert "user message 0" in uploaded
assert "assistant message 1" in uploaded
@pytest.mark.asyncio
async def test_lifecycle_stale_download_suppresses_upload(self):
"""Stale download → covers=False → upload must be skipped."""
builder = TranscriptBuilder()
# session has 10 msgs but stored transcript only covers 2 → stale.
stale = TranscriptDownload(
content=_make_transcript_content("user", "assistant"),
message_count=2,
)
upload_mock = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=stale),
),
patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=10,
transcript_builder=builder,
)
assert covers is False
# The caller's gate mirrors the production path.
assert (
should_upload_transcript(user_id="user-1", transcript_covers_prefix=covers)
is False
)
upload_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_lifecycle_anonymous_user_skips_upload(self):
"""Anonymous (user_id=None) → upload gate must return False."""
builder = TranscriptBuilder()
builder.append_user(content="hi")
builder.append_assistant(
content_blocks=[{"type": "text", "text": "hello"}],
model="test-model",
stop_reason=STOP_REASON_END_TURN,
)
assert (
should_upload_transcript(user_id=None, transcript_covers_prefix=True)
is False
)
@pytest.mark.asyncio
async def test_lifecycle_missing_download_still_uploads_new_content(self):
"""No prior transcript → covers defaults to True in the service,
new turn should upload cleanly."""
builder = TranscriptBuilder()
upload_mock = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=None),
),
patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=1,
transcript_builder=builder,
)
# No download: covers is False, so the production path would
# skip upload. This protects against overwriting a future
# more-complete transcript with a single-turn snapshot.
assert covers is False
assert (
should_upload_transcript(
user_id="user-1", transcript_covers_prefix=covers
)
is False
)
upload_mock.assert_not_awaited()

View File

@@ -8,18 +8,35 @@ from pydantic_settings import BaseSettings
from backend.util.clients import OPENROUTER_BASE_URL
# Per-request routing mode for a single chat turn.
# - 'fast': route to the baseline OpenAI-compatible path with the cheaper model.
# - 'extended_thinking': route to the Claude Agent SDK path with the default
# (opus) model.
# ``None`` means "no override"; the server falls back to the Claude Code
# subscription flag → LaunchDarkly COPILOT_SDK → config.use_claude_agent_sdk.
CopilotMode = Literal["fast", "extended_thinking"]
class ChatConfig(BaseSettings):
"""Configuration for the chat system."""
# OpenAI API Configuration
model: str = Field(
default="anthropic/claude-opus-4.6", description="Default model to use"
default="anthropic/claude-opus-4.6",
description="Default model for extended thinking mode",
)
fast_model: str = Field(
default="anthropic/claude-sonnet-4",
description="Model for fast mode (baseline path). Should be faster/cheaper than the default model.",
)
title_model: str = Field(
default="openai/gpt-4o-mini",
description="Model to use for generating session titles (should be fast/cheap)",
)
simulation_model: str = Field(
default="google/gemini-2.5-flash",
description="Model for dry-run block simulation (should be fast/cheap with good JSON output)",
)
api_key: str | None = Field(default=None, description="OpenAI API key")
base_url: str | None = Field(
default=OPENROUTER_BASE_URL,
@@ -77,11 +94,11 @@ class ChatConfig(BaseSettings):
# allows ~70-100 turns/day.
# Checked at the HTTP layer (routes.py) before each turn.
#
# TODO: These are deploy-time constants applied identically to every user.
# If per-user or per-plan limits are needed (e.g., free tier vs paid), these
# must move to the database (e.g., a UserPlan table) and get_usage_status /
# check_rate_limit would look up each user's specific limits instead of
# reading config.daily_token_limit / config.weekly_token_limit.
# These are base limits for the FREE tier. Higher tiers (PRO, BUSINESS,
# ENTERPRISE) multiply these by their tier multiplier (see
# rate_limit.TIER_MULTIPLIERS). User tier is stored in the
# User.subscriptionTier DB column and resolved inside
# get_global_rate_limits().
daily_token_limit: int = Field(
default=2_500_000,
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",

View File

@@ -149,7 +149,8 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
Allowed:
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
- Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``.
- Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``
or ``tool-outputs/...``.
The SDK nests tool-results under a conversation UUID directory;
the UUID segment is validated with ``_UUID_RE``.
"""
@@ -174,17 +175,20 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
# Defence-in-depth: ensure project_dir didn't escape the base.
if not project_dir.startswith(SDK_PROJECTS_DIR + os.sep):
return False
# Only allow: <encoded-cwd>/<uuid>/tool-results/<file>
# Only allow: <encoded-cwd>/<uuid>/<tool-dir>/<file>
# The SDK always creates a conversation UUID directory between
# the project dir and tool-results/.
# the project dir and the tool directory.
# Accept both "tool-results" (SDK's persisted outputs) and
# "tool-outputs" (the model sometimes confuses workspace paths
# with filesystem paths and generates this variant).
if resolved.startswith(project_dir + os.sep):
relative = resolved[len(project_dir) + 1 :]
parts = relative.split(os.sep)
# Require exactly: [<uuid>, "tool-results", <file>, ...]
# Require exactly: [<uuid>, "tool-results"|"tool-outputs", <file>, ...]
if (
len(parts) >= 3
and _UUID_RE.match(parts[0])
and parts[1] == "tool-results"
and parts[1] in ("tool-results", "tool-outputs")
):
return True

View File

@@ -134,6 +134,21 @@ def test_is_allowed_local_path_tool_results_with_uuid():
_current_project_dir.set("")
def test_is_allowed_local_path_tool_outputs_with_uuid():
"""Files under <encoded-cwd>/<uuid>/tool-outputs/ are also allowed."""
encoded = "test-encoded-dir"
conv_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
path = os.path.join(
SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-outputs", "output.json"
)
_current_project_dir.set(encoded)
try:
assert is_allowed_local_path(path, sdk_cwd=None)
finally:
_current_project_dir.set("")
def test_is_allowed_local_path_tool_results_without_uuid_rejected():
"""Direct <encoded-cwd>/tool-results/ (no UUID) is rejected."""
encoded = "test-encoded-dir"
@@ -159,7 +174,7 @@ def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
def test_is_allowed_local_path_valid_uuid_wrong_segment_name_rejected():
"""A valid UUID dir but non-'tool-results' second segment is rejected."""
"""A valid UUID dir but non-'tool-results'/'tool-outputs' second segment is rejected."""
encoded = "test-encoded-dir"
uuid_str = "12345678-1234-5678-9abc-def012345678"
path = os.path.join(

View File

@@ -14,15 +14,32 @@ from prisma.types import (
ChatSessionUpdateInput,
ChatSessionWhereInput,
)
from pydantic import BaseModel
from backend.data import db
from backend.util.json import SafeJson, sanitize_string
from .model import ChatMessage, ChatSession, ChatSessionInfo, invalidate_session_cache
from .model import (
ChatMessage,
ChatSession,
ChatSessionInfo,
ChatSessionMetadata,
cache_chat_session,
)
from .model import get_chat_session as get_chat_session_cached
logger = logging.getLogger(__name__)
class PaginatedMessages(BaseModel):
"""Result of a paginated message query."""
messages: list[ChatMessage]
has_more: bool
oldest_sequence: int | None
session: ChatSessionInfo
async def get_chat_session(session_id: str) -> ChatSession | None:
"""Get a chat session by ID from the database."""
session = await PrismaChatSession.prisma().find_unique(
@@ -32,9 +49,120 @@ async def get_chat_session(session_id: str) -> ChatSession | None:
return ChatSession.from_db(session) if session else None
async def get_chat_session_metadata(session_id: str) -> ChatSessionInfo | None:
"""Get chat session metadata (without messages) for ownership validation."""
session = await PrismaChatSession.prisma().find_unique(
where={"id": session_id},
)
return ChatSessionInfo.from_db(session) if session else None
async def get_chat_messages_paginated(
session_id: str,
limit: int = 50,
before_sequence: int | None = None,
user_id: str | None = None,
) -> PaginatedMessages | None:
"""Get paginated messages for a session, newest first.
Verifies session existence (and ownership when ``user_id`` is provided)
in parallel with the message query. Returns ``None`` when the session
is not found or does not belong to the user.
Args:
session_id: The chat session ID.
limit: Max messages to return.
before_sequence: Cursor — return messages with sequence < this value.
user_id: If provided, filters via ``Session.userId`` so only the
session owner's messages are returned (acts as an ownership guard).
"""
# Build session-existence / ownership check
session_where: ChatSessionWhereInput = {"id": session_id}
if user_id is not None:
session_where["userId"] = user_id
# Build message include — fetch paginated messages in the same query
msg_include: dict[str, Any] = {
"order_by": {"sequence": "desc"},
"take": limit + 1,
}
if before_sequence is not None:
msg_include["where"] = {"sequence": {"lt": before_sequence}}
# Single query: session existence/ownership + paginated messages
session = await PrismaChatSession.prisma().find_first(
where=session_where,
include={"Messages": msg_include},
)
if session is None:
return None
session_info = ChatSessionInfo.from_db(session)
results = list(session.Messages) if session.Messages else []
has_more = len(results) > limit
results = results[:limit]
# Reverse to ascending order
results.reverse()
# Tool-call boundary fix: if the oldest message is a tool message,
# expand backward to include the preceding assistant message that
# owns the tool_calls, so convertChatSessionMessagesToUiMessages
# can pair them correctly.
_BOUNDARY_SCAN_LIMIT = 10
if results and results[0].role == "tool":
boundary_where: dict[str, Any] = {
"sessionId": session_id,
"sequence": {"lt": results[0].sequence},
}
if user_id is not None:
boundary_where["Session"] = {"is": {"userId": user_id}}
extra = await PrismaChatMessage.prisma().find_many(
where=boundary_where,
order={"sequence": "desc"},
take=_BOUNDARY_SCAN_LIMIT,
)
# Find the first non-tool message (should be the assistant)
boundary_msgs = []
found_owner = False
for msg in extra:
boundary_msgs.append(msg)
if msg.role != "tool":
found_owner = True
break
boundary_msgs.reverse()
if not found_owner:
logger.warning(
"Boundary expansion did not find owning assistant message "
"for session=%s before sequence=%s (%d msgs scanned)",
session_id,
results[0].sequence,
len(extra),
)
if boundary_msgs:
results = boundary_msgs + results
# Only mark has_more if the expanded boundary isn't the
# very start of the conversation (sequence 0).
if boundary_msgs[0].sequence > 0:
has_more = True
messages = [ChatMessage.from_db(m) for m in results]
oldest_sequence = messages[0].sequence if messages else None
return PaginatedMessages(
messages=messages,
has_more=has_more,
oldest_sequence=oldest_sequence,
session=session_info,
)
async def create_chat_session(
session_id: str,
user_id: str,
metadata: ChatSessionMetadata | None = None,
) -> ChatSessionInfo:
"""Create a new chat session in the database."""
data = ChatSessionCreateInput(
@@ -43,6 +171,7 @@ async def create_chat_session(
credentials=SafeJson({}),
successfulAgentRuns=SafeJson({}),
successfulAgentSchedules=SafeJson({}),
metadata=SafeJson((metadata or ChatSessionMetadata()).model_dump()),
)
prisma_session = await PrismaChatSession.prisma().create(data=data)
return ChatSessionInfo.from_db(prisma_session)
@@ -57,7 +186,12 @@ async def update_chat_session(
total_completion_tokens: int | None = None,
title: str | None = None,
) -> ChatSession | None:
"""Update a chat session's metadata."""
"""Update a chat session's mutable fields.
Note: ``metadata`` (which includes ``dry_run``) is intentionally omitted —
it is set once at creation time and treated as immutable for the lifetime
of the session.
"""
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
if credentials is not None:
@@ -367,8 +501,11 @@ async def update_tool_message_content(
async def set_turn_duration(session_id: str, duration_ms: int) -> None:
"""Set durationMs on the last assistant message in a session.
Also invalidates the Redis session cache so the next GET returns
the updated duration.
Updates the Redis cache in-place instead of invalidating it.
Invalidation would delete the key, creating a window where concurrent
``get_chat_session`` calls re-populate the cache from DB — potentially
with stale data if the DB write from the previous turn hasn't propagated.
This race caused duplicate user messages on the next turn.
"""
last_msg = await PrismaChatMessage.prisma().find_first(
where={"sessionId": session_id, "role": "assistant"},
@@ -379,5 +516,13 @@ async def set_turn_duration(session_id: str, duration_ms: int) -> None:
where={"id": last_msg.id},
data={"durationMs": duration_ms},
)
# Invalidate cache so the session is re-fetched from DB with durationMs
await invalidate_session_cache(session_id)
# Update cache in-place rather than invalidating to avoid a
# race window where the empty cache gets re-populated with
# stale data by a concurrent get_chat_session call.
session = await get_chat_session_cached(session_id)
if session and session.messages:
for msg in reversed(session.messages):
if msg.role == "assistant":
msg.duration_ms = duration_ms
break
await cache_chat_session(session)

View File

@@ -0,0 +1,388 @@
"""Unit tests for copilot.db — paginated message queries."""
from __future__ import annotations
from datetime import UTC, datetime
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from backend.copilot.db import (
PaginatedMessages,
get_chat_messages_paginated,
set_turn_duration,
)
from backend.copilot.model import ChatMessage as CopilotChatMessage
from backend.copilot.model import ChatSession, get_chat_session, upsert_chat_session
def _make_msg(
sequence: int,
role: str = "assistant",
content: str | None = "hello",
tool_calls: Any = None,
) -> PrismaChatMessage:
"""Build a minimal PrismaChatMessage for testing."""
return PrismaChatMessage(
id=f"msg-{sequence}",
createdAt=datetime.now(UTC),
sessionId="sess-1",
role=role,
content=content,
sequence=sequence,
toolCalls=tool_calls,
name=None,
toolCallId=None,
refusal=None,
functionCall=None,
)
def _make_session(
session_id: str = "sess-1",
user_id: str = "user-1",
messages: list[PrismaChatMessage] | None = None,
) -> PrismaChatSession:
"""Build a minimal PrismaChatSession for testing."""
now = datetime.now(UTC)
session = PrismaChatSession.model_construct(
id=session_id,
createdAt=now,
updatedAt=now,
userId=user_id,
credentials={},
successfulAgentRuns={},
successfulAgentSchedules={},
totalPromptTokens=0,
totalCompletionTokens=0,
title=None,
metadata={},
Messages=messages or [],
)
return session
SESSION_ID = "sess-1"
@pytest.fixture()
def mock_db():
"""Patch ChatSession.prisma().find_first and ChatMessage.prisma().find_many.
find_first is used for the main query (session + included messages).
find_many is used only for boundary expansion queries.
"""
with (
patch.object(PrismaChatSession, "prisma") as mock_session_prisma,
patch.object(PrismaChatMessage, "prisma") as mock_msg_prisma,
):
find_first = AsyncMock()
mock_session_prisma.return_value.find_first = find_first
find_many = AsyncMock(return_value=[])
mock_msg_prisma.return_value.find_many = find_many
yield find_first, find_many
# ---------- Basic pagination ----------
@pytest.mark.asyncio
async def test_basic_page_returns_messages_ascending(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""Messages are returned in ascending sequence order."""
find_first, _ = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(3), _make_msg(2), _make_msg(1)],
)
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
assert isinstance(page, PaginatedMessages)
assert [m.sequence for m in page.messages] == [1, 2, 3]
assert page.has_more is False
assert page.oldest_sequence == 1
@pytest.mark.asyncio
async def test_has_more_when_results_exceed_limit(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""has_more is True when DB returns more than limit items."""
find_first, _ = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(3), _make_msg(2), _make_msg(1)],
)
page = await get_chat_messages_paginated(SESSION_ID, limit=2)
assert page is not None
assert page.has_more is True
assert len(page.messages) == 2
assert [m.sequence for m in page.messages] == [2, 3]
@pytest.mark.asyncio
async def test_empty_session_returns_no_messages(
mock_db: tuple[AsyncMock, AsyncMock],
):
find_first, _ = mock_db
find_first.return_value = _make_session(messages=[])
page = await get_chat_messages_paginated(SESSION_ID, limit=50)
assert page is not None
assert page.messages == []
assert page.has_more is False
assert page.oldest_sequence is None
@pytest.mark.asyncio
async def test_before_sequence_filters_correctly(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""before_sequence is passed as a where filter inside the Messages include."""
find_first, _ = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(2), _make_msg(1)],
)
await get_chat_messages_paginated(SESSION_ID, limit=50, before_sequence=5)
call_kwargs = find_first.call_args
include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include")
assert include["Messages"]["where"] == {"sequence": {"lt": 5}}
@pytest.mark.asyncio
async def test_no_where_on_messages_without_before_sequence(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""Without before_sequence, the Messages include has no where clause."""
find_first, _ = mock_db
find_first.return_value = _make_session(messages=[_make_msg(1)])
await get_chat_messages_paginated(SESSION_ID, limit=50)
call_kwargs = find_first.call_args
include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include")
assert "where" not in include["Messages"]
@pytest.mark.asyncio
async def test_user_id_filter_applied_to_session_where(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""user_id adds a userId filter to the session-level where clause."""
find_first, _ = mock_db
find_first.return_value = _make_session(messages=[_make_msg(1)])
await get_chat_messages_paginated(SESSION_ID, limit=50, user_id="user-abc")
call_kwargs = find_first.call_args
where = call_kwargs.kwargs.get("where") or call_kwargs[1].get("where")
assert where["userId"] == "user-abc"
@pytest.mark.asyncio
async def test_session_not_found_returns_none(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""Returns None when session doesn't exist or user doesn't own it."""
find_first, _ = mock_db
find_first.return_value = None
page = await get_chat_messages_paginated(SESSION_ID, limit=50)
assert page is None
@pytest.mark.asyncio
async def test_session_info_included_in_result(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""PaginatedMessages includes session metadata."""
find_first, _ = mock_db
find_first.return_value = _make_session(messages=[_make_msg(1)])
page = await get_chat_messages_paginated(SESSION_ID, limit=50)
assert page is not None
assert page.session.session_id == SESSION_ID
# ---------- Backward boundary expansion ----------
@pytest.mark.asyncio
async def test_boundary_expansion_includes_assistant(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""When page starts with a tool message, expand backward to include
the owning assistant message."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(5, role="tool"), _make_msg(4, role="tool")],
)
find_many.return_value = [_make_msg(3, role="assistant")]
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
assert page is not None
assert [m.sequence for m in page.messages] == [3, 4, 5]
assert page.messages[0].role == "assistant"
assert page.oldest_sequence == 3
@pytest.mark.asyncio
async def test_boundary_expansion_includes_multiple_tool_msgs(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""Boundary expansion scans past consecutive tool messages to find
the owning assistant."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(7, role="tool")],
)
find_many.return_value = [
_make_msg(6, role="tool"),
_make_msg(5, role="tool"),
_make_msg(4, role="assistant"),
]
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
assert page is not None
assert [m.sequence for m in page.messages] == [4, 5, 6, 7]
assert page.messages[0].role == "assistant"
@pytest.mark.asyncio
async def test_boundary_expansion_sets_has_more_when_not_at_start(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""After boundary expansion, has_more=True if expanded msgs aren't at seq 0."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(3, role="tool")],
)
find_many.return_value = [_make_msg(2, role="assistant")]
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
assert page is not None
assert page.has_more is True
@pytest.mark.asyncio
async def test_boundary_expansion_no_has_more_at_conversation_start(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""has_more stays False when boundary expansion reaches seq 0."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(1, role="tool")],
)
find_many.return_value = [_make_msg(0, role="assistant")]
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
assert page is not None
assert page.has_more is False
assert page.oldest_sequence == 0
@pytest.mark.asyncio
async def test_no_boundary_expansion_when_first_msg_not_tool(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""No boundary expansion when the first message is not a tool message."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(3, role="user"), _make_msg(2, role="assistant")],
)
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
assert page is not None
assert find_many.call_count == 0
assert [m.sequence for m in page.messages] == [2, 3]
@pytest.mark.asyncio
async def test_boundary_expansion_warns_when_no_owner_found(
mock_db: tuple[AsyncMock, AsyncMock],
):
"""When boundary scan doesn't find a non-tool message, a warning is logged
and the boundary messages are still included."""
find_first, find_many = mock_db
find_first.return_value = _make_session(
messages=[_make_msg(10, role="tool")],
)
find_many.return_value = [_make_msg(i, role="tool") for i in range(9, -1, -1)]
with patch("backend.copilot.db.logger") as mock_logger:
page = await get_chat_messages_paginated(SESSION_ID, limit=5)
mock_logger.warning.assert_called_once()
assert page is not None
assert page.messages[0].role == "tool"
assert len(page.messages) > 1
# ---------- Turn duration (integration tests) ----------
@pytest.mark.asyncio(loop_scope="session")
async def test_set_turn_duration_updates_cache_in_place(setup_test_user, test_user_id):
"""set_turn_duration patches the cached session without invalidation.
Verifies that after calling set_turn_duration the Redis-cached session
reflects the updated durationMs on the last assistant message, without
the cache having been deleted and re-populated (which could race with
concurrent get_chat_session calls).
"""
session = ChatSession.new(user_id=test_user_id, dry_run=False)
session.messages = [
CopilotChatMessage(role="user", content="hello"),
CopilotChatMessage(role="assistant", content="hi there"),
]
session = await upsert_chat_session(session)
# Ensure the session is in cache
cached = await get_chat_session(session.session_id, test_user_id)
assert cached is not None
assert cached.messages[-1].duration_ms is None
# Update turn duration — should patch cache in-place
await set_turn_duration(session.session_id, 1234)
# Read from cache (not DB) — the cache should already have the update
updated = await get_chat_session(session.session_id, test_user_id)
assert updated is not None
assistant_msgs = [m for m in updated.messages if m.role == "assistant"]
assert len(assistant_msgs) == 1
assert assistant_msgs[0].duration_ms == 1234
@pytest.mark.asyncio(loop_scope="session")
async def test_set_turn_duration_no_assistant_message(setup_test_user, test_user_id):
"""set_turn_duration is a no-op when there are no assistant messages."""
session = ChatSession.new(user_id=test_user_id, dry_run=False)
session.messages = [
CopilotChatMessage(role="user", content="hello"),
]
session = await upsert_chat_session(session)
# Should not raise
await set_turn_duration(session.session_id, 5678)
cached = await get_chat_session(session.session_id, test_user_id)
assert cached is not None
# User message should not have durationMs
assert cached.messages[0].duration_ms is None

View File

@@ -13,7 +13,7 @@ import time
from backend.copilot import stream_registry
from backend.copilot.baseline import stream_chat_completion_baseline
from backend.copilot.config import ChatConfig
from backend.copilot.config import ChatConfig, CopilotMode
from backend.copilot.response_model import StreamError
from backend.copilot.sdk import service as sdk_service
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
@@ -30,6 +30,57 @@ from .utils import CoPilotExecutionEntry, CoPilotLogMetadata
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
# ============ Mode Routing ============ #
async def resolve_effective_mode(
mode: CopilotMode | None,
user_id: str | None,
) -> CopilotMode | None:
"""Strip ``mode`` when the user is not entitled to the toggle.
The UI gates the mode toggle behind ``CHAT_MODE_OPTION``; the
processor enforces the same gate server-side so an authenticated
user cannot bypass the flag by crafting a request directly.
"""
if mode is None:
return None
allowed = await is_feature_enabled(
Flag.CHAT_MODE_OPTION,
user_id or "anonymous",
default=False,
)
if not allowed:
logger.info(f"Ignoring mode={mode} — CHAT_MODE_OPTION is disabled for user")
return None
return mode
async def resolve_use_sdk_for_mode(
mode: CopilotMode | None,
user_id: str | None,
*,
use_claude_code_subscription: bool,
config_default: bool,
) -> bool:
"""Pick the SDK vs baseline path for a single turn.
Per-request ``mode`` wins whenever it is set (after the
``CHAT_MODE_OPTION`` gate has been applied upstream). Otherwise
falls back to the Claude Code subscription override, then the
``COPILOT_SDK`` LaunchDarkly flag, then the config default.
"""
if mode == "fast":
return False
if mode == "extended_thinking":
return True
return use_claude_code_subscription or await is_feature_enabled(
Flag.COPILOT_SDK,
user_id or "anonymous",
default=config_default,
)
# ============ Module Entry Points ============ #
# Thread-local storage for processor instances
@@ -100,8 +151,8 @@ class CoPilotProcessor:
This method is called once per worker thread to set up the async event
loop and initialize any required resources.
Database is accessed only through DatabaseManager, so we don't need to connect
to Prisma directly.
DB operations route through DatabaseManagerAsyncClient (RPC) via the
db_accessors pattern — no direct Prisma connection is needed here.
"""
configure_logging()
set_service_name("CoPilotExecutor")
@@ -250,21 +301,26 @@ class CoPilotProcessor:
if config.test_mode:
stream_fn = stream_chat_completion_dummy
log.warning("Using DUMMY service (CHAT_TEST_MODE=true)")
effective_mode = None
else:
use_sdk = (
config.use_claude_code_subscription
or await is_feature_enabled(
Flag.COPILOT_SDK,
entry.user_id or "anonymous",
default=config.use_claude_agent_sdk,
)
# Enforce server-side feature-flag gate so unauthorised
# users cannot force a mode by crafting the request.
effective_mode = await resolve_effective_mode(entry.mode, entry.user_id)
use_sdk = await resolve_use_sdk_for_mode(
effective_mode,
entry.user_id,
use_claude_code_subscription=config.use_claude_code_subscription,
config_default=config.use_claude_agent_sdk,
)
stream_fn = (
sdk_service.stream_chat_completion_sdk
if use_sdk
else stream_chat_completion_baseline
)
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
log.info(
f"Using {'SDK' if use_sdk else 'baseline'} service "
f"(mode={effective_mode or 'default'})"
)
# Stream chat completion and publish chunks to Redis.
# stream_and_publish wraps the raw stream with registry
@@ -276,6 +332,7 @@ class CoPilotProcessor:
user_id=entry.user_id,
context=entry.context,
file_ids=entry.file_ids,
mode=effective_mode,
)
async for chunk in stream_registry.stream_and_publish(
session_id=entry.session_id,

View File

@@ -0,0 +1,175 @@
"""Unit tests for CoPilot mode routing logic in the processor.
Tests cover the mode→service mapping:
- 'fast' → baseline service
- 'extended_thinking' → SDK service
- None → feature flag / config fallback
as well as the ``CHAT_MODE_OPTION`` server-side gate. The tests import
the real production helpers from ``processor.py`` so the routing logic
has meaningful coverage.
"""
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.executor.processor import (
resolve_effective_mode,
resolve_use_sdk_for_mode,
)
class TestResolveUseSdkForMode:
"""Tests for the per-request mode routing logic."""
@pytest.mark.asyncio
async def test_fast_mode_uses_baseline(self):
"""mode='fast' always routes to baseline, regardless of flags."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=True),
):
assert (
await resolve_use_sdk_for_mode(
"fast",
"user-1",
use_claude_code_subscription=True,
config_default=True,
)
is False
)
@pytest.mark.asyncio
async def test_extended_thinking_uses_sdk(self):
"""mode='extended_thinking' always routes to SDK, regardless of flags."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
):
assert (
await resolve_use_sdk_for_mode(
"extended_thinking",
"user-1",
use_claude_code_subscription=False,
config_default=False,
)
is True
)
@pytest.mark.asyncio
async def test_none_mode_uses_subscription_override(self):
"""mode=None with claude_code_subscription=True routes to SDK."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
):
assert (
await resolve_use_sdk_for_mode(
None,
"user-1",
use_claude_code_subscription=True,
config_default=False,
)
is True
)
@pytest.mark.asyncio
async def test_none_mode_uses_feature_flag(self):
"""mode=None with feature flag enabled routes to SDK."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=True),
) as flag_mock:
assert (
await resolve_use_sdk_for_mode(
None,
"user-1",
use_claude_code_subscription=False,
config_default=False,
)
is True
)
flag_mock.assert_awaited_once()
@pytest.mark.asyncio
async def test_none_mode_uses_config_default(self):
"""mode=None falls back to config.use_claude_agent_sdk."""
# When LaunchDarkly returns the default (True), we expect SDK routing.
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=True),
):
assert (
await resolve_use_sdk_for_mode(
None,
"user-1",
use_claude_code_subscription=False,
config_default=True,
)
is True
)
@pytest.mark.asyncio
async def test_none_mode_all_disabled(self):
"""mode=None with all flags off routes to baseline."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
):
assert (
await resolve_use_sdk_for_mode(
None,
"user-1",
use_claude_code_subscription=False,
config_default=False,
)
is False
)
class TestResolveEffectiveMode:
"""Tests for the CHAT_MODE_OPTION server-side gate."""
@pytest.mark.asyncio
async def test_none_mode_passes_through(self):
"""mode=None is returned as-is without a flag check."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
) as flag_mock:
assert await resolve_effective_mode(None, "user-1") is None
flag_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_mode_stripped_when_flag_disabled(self):
"""When CHAT_MODE_OPTION is off, mode is dropped to None."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
):
assert await resolve_effective_mode("fast", "user-1") is None
assert await resolve_effective_mode("extended_thinking", "user-1") is None
@pytest.mark.asyncio
async def test_mode_preserved_when_flag_enabled(self):
"""When CHAT_MODE_OPTION is on, the user-selected mode is preserved."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=True),
):
assert await resolve_effective_mode("fast", "user-1") == "fast"
assert (
await resolve_effective_mode("extended_thinking", "user-1")
== "extended_thinking"
)
@pytest.mark.asyncio
async def test_anonymous_user_with_mode(self):
"""Anonymous users (user_id=None) still pass through the gate."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
) as flag_mock:
assert await resolve_effective_mode("fast", None) is None
flag_mock.assert_awaited_once()

View File

@@ -9,6 +9,7 @@ import logging
from pydantic import BaseModel
from backend.copilot.config import CopilotMode
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
@@ -156,6 +157,9 @@ class CoPilotExecutionEntry(BaseModel):
file_ids: list[str] | None = None
"""Workspace file IDs attached to the user's message"""
mode: CopilotMode | None = None
"""Autopilot mode override: 'fast' or 'extended_thinking'. None = server default."""
class CancelCoPilotEvent(BaseModel):
"""Event to cancel a CoPilot operation."""
@@ -175,6 +179,7 @@ async def enqueue_copilot_turn(
is_user_message: bool = True,
context: dict[str, str] | None = None,
file_ids: list[str] | None = None,
mode: CopilotMode | None = None,
) -> None:
"""Enqueue a CoPilot task for processing by the executor service.
@@ -186,6 +191,7 @@ async def enqueue_copilot_turn(
is_user_message: Whether the message is from the user (vs system/assistant)
context: Optional context for the message (e.g., {url: str, content: str})
file_ids: Optional workspace file IDs attached to the user's message
mode: Autopilot mode override ('fast' or 'extended_thinking'). None = server default.
"""
from backend.util.clients import get_async_copilot_queue
@@ -197,6 +203,7 @@ async def enqueue_copilot_turn(
is_user_message=is_user_message,
context=context,
file_ids=file_ids,
mode=mode,
)
queue_client = await get_async_copilot_queue()

View File

@@ -0,0 +1,123 @@
"""Tests for CoPilot executor utils (queue config, message models, logging)."""
from backend.copilot.executor.utils import (
COPILOT_EXECUTION_EXCHANGE,
COPILOT_EXECUTION_QUEUE_NAME,
COPILOT_EXECUTION_ROUTING_KEY,
CancelCoPilotEvent,
CoPilotExecutionEntry,
CoPilotLogMetadata,
create_copilot_queue_config,
)
class TestCoPilotExecutionEntry:
def test_basic_fields(self):
entry = CoPilotExecutionEntry(
session_id="s1",
user_id="u1",
message="hello",
)
assert entry.session_id == "s1"
assert entry.user_id == "u1"
assert entry.message == "hello"
assert entry.is_user_message is True
assert entry.mode is None
assert entry.context is None
assert entry.file_ids is None
def test_mode_field(self):
entry = CoPilotExecutionEntry(
session_id="s1",
user_id="u1",
message="test",
mode="fast",
)
assert entry.mode == "fast"
entry2 = CoPilotExecutionEntry(
session_id="s1",
user_id="u1",
message="test",
mode="extended_thinking",
)
assert entry2.mode == "extended_thinking"
def test_optional_fields(self):
entry = CoPilotExecutionEntry(
session_id="s1",
user_id="u1",
message="test",
turn_id="t1",
context={"url": "https://example.com"},
file_ids=["f1", "f2"],
is_user_message=False,
)
assert entry.turn_id == "t1"
assert entry.context == {"url": "https://example.com"}
assert entry.file_ids == ["f1", "f2"]
assert entry.is_user_message is False
def test_serialization_roundtrip(self):
entry = CoPilotExecutionEntry(
session_id="s1",
user_id="u1",
message="hello",
mode="fast",
)
json_str = entry.model_dump_json()
restored = CoPilotExecutionEntry.model_validate_json(json_str)
assert restored == entry
class TestCancelCoPilotEvent:
def test_basic(self):
event = CancelCoPilotEvent(session_id="s1")
assert event.session_id == "s1"
def test_serialization(self):
event = CancelCoPilotEvent(session_id="s1")
restored = CancelCoPilotEvent.model_validate_json(event.model_dump_json())
assert restored.session_id == "s1"
class TestCreateCopilotQueueConfig:
def test_returns_valid_config(self):
config = create_copilot_queue_config()
assert len(config.exchanges) == 2
assert len(config.queues) == 2
def test_execution_queue_properties(self):
config = create_copilot_queue_config()
exec_queue = next(
q for q in config.queues if q.name == COPILOT_EXECUTION_QUEUE_NAME
)
assert exec_queue.durable is True
assert exec_queue.exchange == COPILOT_EXECUTION_EXCHANGE
assert exec_queue.routing_key == COPILOT_EXECUTION_ROUTING_KEY
def test_cancel_queue_uses_fanout(self):
config = create_copilot_queue_config()
cancel_queue = next(
q for q in config.queues if q.name != COPILOT_EXECUTION_QUEUE_NAME
)
assert cancel_queue.exchange is not None
assert cancel_queue.exchange.type.value == "fanout"
class TestCoPilotLogMetadata:
def test_creates_logger_with_metadata(self):
import logging
base_logger = logging.getLogger("test")
log = CoPilotLogMetadata(base_logger, session_id="s1", user_id="u1")
assert log is not None
def test_filters_none_values(self):
import logging
base_logger = logging.getLogger("test")
log = CoPilotLogMetadata(
base_logger, session_id="s1", user_id=None, turn_id="t1"
)
assert log is not None

View File

@@ -59,6 +59,16 @@ _null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
)
# GitHub user identity caches (keyed by user_id only, not provider tuple).
# Declared here so invalidate_user_provider_cache() can reference them.
_GH_IDENTITY_CACHE_TTL = 600.0 # 10 min — profile data rarely changes
_gh_identity_cache: TTLCache[str, dict[str, str]] = TTLCache(
maxsize=_CACHE_MAX_SIZE, ttl=_GH_IDENTITY_CACHE_TTL
)
_gh_identity_null_cache: TTLCache[str, bool] = TTLCache(
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
)
def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
"""Remove the cached entry for *user_id*/*provider* from both caches.
@@ -66,11 +76,19 @@ def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
Call this after storing new credentials so that the next
``get_provider_token()`` call performs a fresh DB lookup instead of
serving a stale TTL-cached result.
For GitHub specifically, also clears the git-identity caches so that
``get_github_user_git_identity()`` re-fetches the user's profile on
the next call instead of serving stale identity data.
"""
key = (user_id, provider)
_token_cache.pop(key, None)
_null_cache.pop(key, None)
if provider == "github":
_gh_identity_cache.pop(user_id, None)
_gh_identity_null_cache.pop(user_id, None)
# Register this module's cache-bust function with the credentials manager so
# that any create/update/delete operation immediately evicts stale cache
@@ -123,6 +141,7 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
[c for c in creds_list if c.type == "oauth2"],
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
)
refresh_failed = False
for creds in oauth2_creds:
if creds.type == "oauth2":
try:
@@ -141,6 +160,7 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
# Do NOT fall back to the stale token — it is likely expired
# or revoked. Returning None forces the caller to re-auth,
# preventing the LLM from receiving a non-functional token.
refresh_failed = True
continue
_token_cache[cache_key] = token
return token
@@ -152,8 +172,12 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
_token_cache[cache_key] = token
return token
# No credentials found — cache to avoid repeated DB hits.
_null_cache[cache_key] = True
# Only cache "not connected" when the user truly has no credentials for this
# provider. If we had OAuth credentials but refresh failed (e.g. transient
# network error, event-loop mismatch), do NOT cache the negative result —
# the next call should retry the refresh instead of being blocked for 60 s.
if not refresh_failed:
_null_cache[cache_key] = True
return None
@@ -171,3 +195,76 @@ async def get_integration_env_vars(user_id: str) -> dict[str, str]:
for var in var_names:
env[var] = token
return env
# ---------------------------------------------------------------------------
# GitHub user identity (for git committer env vars)
# ---------------------------------------------------------------------------
async def get_github_user_git_identity(user_id: str) -> dict[str, str] | None:
"""Fetch the GitHub user's name and email for git committer env vars.
Uses the ``/user`` GitHub API endpoint with the user's stored token.
Returns a dict with ``GIT_AUTHOR_NAME``, ``GIT_AUTHOR_EMAIL``,
``GIT_COMMITTER_NAME``, and ``GIT_COMMITTER_EMAIL`` if the user has a
connected GitHub account. Returns ``None`` otherwise.
Results are cached for 10 minutes; "not connected" results are cached for
60 s (same as null-token cache).
"""
if user_id in _gh_identity_null_cache:
return None
if cached := _gh_identity_cache.get(user_id):
return cached
token = await get_provider_token(user_id, "github")
if not token:
_gh_identity_null_cache[user_id] = True
return None
import aiohttp
try:
async with aiohttp.ClientSession() as session:
async with session.get(
"https://api.github.com/user",
headers={
"Authorization": f"token {token}",
"Accept": "application/vnd.github+json",
},
timeout=aiohttp.ClientTimeout(total=5),
) as resp:
if resp.status != 200:
logger.warning(
"[git-identity] GitHub /user returned %s for user %s",
resp.status,
user_id,
)
return None
data = await resp.json()
except Exception as exc:
logger.warning(
"[git-identity] Failed to fetch GitHub profile for user %s: %s",
user_id,
exc,
)
return None
name = data.get("name") or data.get("login") or "AutoGPT User"
# GitHub may return email=null if the user has set their email to private.
# Fall back to the noreply address GitHub generates for every account.
email = data.get("email")
if not email:
gh_id = data.get("id", "")
login = data.get("login", "user")
email = f"{gh_id}+{login}@users.noreply.github.com"
identity = {
"GIT_AUTHOR_NAME": name,
"GIT_AUTHOR_EMAIL": email,
"GIT_COMMITTER_NAME": name,
"GIT_COMMITTER_EMAIL": email,
}
_gh_identity_cache[user_id] = identity
return identity

View File

@@ -9,6 +9,8 @@ from backend.copilot.integration_creds import (
_NULL_CACHE_TTL,
_TOKEN_CACHE_TTL,
PROVIDER_ENV_VARS,
_gh_identity_cache,
_gh_identity_null_cache,
_null_cache,
_token_cache,
get_integration_env_vars,
@@ -49,9 +51,13 @@ def clear_caches():
"""Ensure clean caches before and after every test."""
_token_cache.clear()
_null_cache.clear()
_gh_identity_cache.clear()
_gh_identity_null_cache.clear()
yield
_token_cache.clear()
_null_cache.clear()
_gh_identity_cache.clear()
_gh_identity_null_cache.clear()
class TestInvalidateUserProviderCache:
@@ -77,6 +83,34 @@ class TestInvalidateUserProviderCache:
invalidate_user_provider_cache(_USER, _PROVIDER)
assert other_key in _token_cache
def test_clears_gh_identity_cache_for_github_provider(self):
"""When provider is 'github', identity caches must also be cleared."""
_gh_identity_cache[_USER] = {
"GIT_AUTHOR_NAME": "Old Name",
"GIT_AUTHOR_EMAIL": "old@example.com",
"GIT_COMMITTER_NAME": "Old Name",
"GIT_COMMITTER_EMAIL": "old@example.com",
}
invalidate_user_provider_cache(_USER, "github")
assert _USER not in _gh_identity_cache
def test_clears_gh_identity_null_cache_for_github_provider(self):
"""When provider is 'github', the identity null-cache must also be cleared."""
_gh_identity_null_cache[_USER] = True
invalidate_user_provider_cache(_USER, "github")
assert _USER not in _gh_identity_null_cache
def test_does_not_clear_gh_identity_cache_for_other_providers(self):
"""When provider is NOT 'github', identity caches must be left alone."""
_gh_identity_cache[_USER] = {
"GIT_AUTHOR_NAME": "Some Name",
"GIT_AUTHOR_EMAIL": "some@example.com",
"GIT_COMMITTER_NAME": "Some Name",
"GIT_COMMITTER_EMAIL": "some@example.com",
}
invalidate_user_provider_cache(_USER, "some-other-provider")
assert _USER in _gh_identity_cache
class TestGetProviderToken:
@pytest.mark.asyncio(loop_scope="session")
@@ -129,8 +163,15 @@ class TestGetProviderToken:
assert result == "oauth-tok"
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth2_refresh_failure_returns_none(self):
"""On refresh failure, return None instead of caching a stale token."""
async def test_oauth2_refresh_failure_returns_none_without_null_cache(self):
"""On refresh failure, return None but do NOT cache in null_cache.
The user has credentials — they just couldn't be refreshed right now
(e.g. transient network error or event-loop mismatch in the copilot
executor). Caching a negative result would block all credential
lookups for 60 s even though the creds exist and may refresh fine
on the next attempt.
"""
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
@@ -141,6 +182,8 @@ class TestGetProviderToken:
# Stale tokens must NOT be returned — forces re-auth.
assert result is None
# Must NOT cache negative result when refresh failed — next call retries.
assert (_USER, _PROVIDER) not in _null_cache
@pytest.mark.asyncio(loop_scope="session")
async def test_no_credentials_caches_null_entry(self):
@@ -176,6 +219,96 @@ class TestGetProviderToken:
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL
class TestThreadSafetyLocks:
"""Bug reproduction: shared AsyncRedisKeyedMutex across threads caused
'Future attached to a different loop' when copilot workers accessed
credentials from different event loops."""
@pytest.mark.asyncio(loop_scope="session")
async def test_store_locks_returns_per_thread_instance(self):
"""IntegrationCredentialsStore.locks() must return different instances
for different threads (via @thread_cached)."""
import asyncio
import concurrent.futures
from backend.integrations.credentials_store import IntegrationCredentialsStore
store = IntegrationCredentialsStore()
async def get_locks_id():
mock_redis = AsyncMock()
with patch(
"backend.integrations.credentials_store.get_redis_async",
return_value=mock_redis,
):
locks = await store.locks()
return id(locks)
# Get locks from main thread
main_id = await get_locks_id()
# Get locks from a worker thread
def run_in_thread():
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(get_locks_id())
finally:
loop.close()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
worker_id = await asyncio.get_event_loop().run_in_executor(
pool, run_in_thread
)
assert main_id != worker_id, (
"Store.locks() returned the same instance across threads. "
"This would cause 'Future attached to a different loop' errors."
)
@pytest.mark.asyncio(loop_scope="session")
async def test_manager_delegates_to_store_locks(self):
"""IntegrationCredentialsManager.locks() should delegate to store."""
from backend.integrations.creds_manager import IntegrationCredentialsManager
manager = IntegrationCredentialsManager()
mock_redis = AsyncMock()
with patch(
"backend.integrations.credentials_store.get_redis_async",
return_value=mock_redis,
):
locks = await manager.locks()
# Should have gotten it from the store
assert locks is not None
class TestRefreshUnlockedPath:
"""Bug reproduction: copilot worker threads need lock-free refresh because
Redis-backed asyncio.Lock created on one event loop can't be used on another."""
@pytest.mark.asyncio(loop_scope="session")
async def test_refresh_if_needed_lock_false_skips_redis(self):
"""refresh_if_needed(lock=False) must not touch Redis locks at all."""
from backend.integrations.creds_manager import IntegrationCredentialsManager
manager = IntegrationCredentialsManager()
creds = _make_oauth2_creds()
mock_handler = MagicMock()
mock_handler.needs_refresh = MagicMock(return_value=False)
with patch(
"backend.integrations.creds_manager._get_provider_oauth_handler",
new_callable=AsyncMock,
return_value=mock_handler,
):
result = await manager.refresh_if_needed(_USER, creds, lock=False)
# Should return credentials without touching locks
assert result.id == creds.id
class TestGetIntegrationEnvVars:
@pytest.mark.asyncio(loop_scope="session")
async def test_injects_all_env_vars_for_provider(self):

View File

@@ -46,6 +46,16 @@ def _get_session_cache_key(session_id: str) -> str:
# ===================== Chat data models ===================== #
class ChatSessionMetadata(BaseModel):
"""Typed metadata stored in the ``metadata`` JSON column of ChatSession.
Add new session-level flags here instead of adding DB columns —
no migration required for new fields as long as a default is provided.
"""
dry_run: bool = False
class ChatMessage(BaseModel):
role: str
content: str | None = None
@@ -54,6 +64,7 @@ class ChatMessage(BaseModel):
refusal: str | None = None
tool_calls: list[dict] | None = None
function_call: dict | None = None
sequence: int | None = None
duration_ms: int | None = None
@staticmethod
@@ -67,10 +78,54 @@ class ChatMessage(BaseModel):
refusal=prisma_message.refusal,
tool_calls=_parse_json_field(prisma_message.toolCalls),
function_call=_parse_json_field(prisma_message.functionCall),
sequence=prisma_message.sequence,
duration_ms=prisma_message.durationMs,
)
def is_message_duplicate(
messages: list[ChatMessage],
role: str,
content: str,
) -> bool:
"""Check whether *content* is already present in the current pending turn.
Only inspects trailing messages that share the given *role* (i.e. the
current turn). This ensures legitimately repeated messages across different
turns are not suppressed, while same-turn duplicates from stale cache are
still caught.
"""
for m in reversed(messages):
if m.role == role:
if m.content == content:
return True
else:
break
return False
def maybe_append_user_message(
session: "ChatSession",
message: str | None,
is_user_message: bool,
) -> bool:
"""Append a user/assistant message to the session if not already present.
The route handler already persists the user message before enqueueing,
so we check trailing same-role messages to avoid re-appending when the
session cache is slightly stale.
Returns True if the message was appended, False if skipped.
"""
if not message:
return False
role = "user" if is_user_message else "assistant"
if is_message_duplicate(session.messages, role, message):
return False
session.messages.append(ChatMessage(role=role, content=message))
return True
class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
@@ -90,6 +145,12 @@ class ChatSessionInfo(BaseModel):
updated_at: datetime
successful_agent_runs: dict[str, int] = {}
successful_agent_schedules: dict[str, int] = {}
metadata: ChatSessionMetadata = ChatSessionMetadata()
@property
def dry_run(self) -> bool:
"""Convenience accessor for ``metadata.dry_run``."""
return self.metadata.dry_run
@classmethod
def from_db(cls, prisma_session: PrismaChatSession) -> Self:
@@ -103,6 +164,10 @@ class ChatSessionInfo(BaseModel):
prisma_session.successfulAgentSchedules, default={}
)
# Parse typed metadata from the JSON column.
raw_metadata = _parse_json_field(prisma_session.metadata, default={})
metadata = ChatSessionMetadata.model_validate(raw_metadata)
# Calculate usage from token counts.
# NOTE: Per-turn cache_read_tokens / cache_creation_tokens breakdown
# is lost after persistence — the DB only stores aggregate prompt and
@@ -128,6 +193,7 @@ class ChatSessionInfo(BaseModel):
updated_at=prisma_session.updatedAt,
successful_agent_runs=successful_agent_runs,
successful_agent_schedules=successful_agent_schedules,
metadata=metadata,
)
@@ -135,7 +201,7 @@ class ChatSession(ChatSessionInfo):
messages: list[ChatMessage]
@classmethod
def new(cls, user_id: str) -> Self:
def new(cls, user_id: str, *, dry_run: bool) -> Self:
return cls(
session_id=str(uuid.uuid4()),
user_id=user_id,
@@ -145,6 +211,7 @@ class ChatSession(ChatSessionInfo):
credentials={},
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
metadata=ChatSessionMetadata(dry_run=dry_run),
)
@classmethod
@@ -532,6 +599,7 @@ async def _save_session_to_db(
await db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
metadata=session.metadata,
)
existing_message_count = 0
@@ -609,21 +677,27 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
return session
async def create_chat_session(user_id: str) -> ChatSession:
async def create_chat_session(user_id: str, *, dry_run: bool) -> ChatSession:
"""Create a new chat session and persist it.
Args:
user_id: The authenticated user ID.
dry_run: When True, run_block and run_agent tool calls in this
session are forced to use dry-run simulation mode.
Raises:
DatabaseError: If the database write fails. We fail fast to ensure
callers never receive a non-persisted session that only exists
in cache (which would be lost when the cache expires).
"""
session = ChatSession.new(user_id)
session = ChatSession.new(user_id, dry_run=dry_run)
# Create in database first - fail fast if this fails
try:
await chat_db().create_chat_session(
session_id=session.session_id,
user_id=user_id,
metadata=session.metadata,
)
except Exception as e:
logger.error(f"Failed to create session {session.session_id} in database: {e}")

View File

@@ -17,6 +17,8 @@ from .model import (
ChatSession,
Usage,
get_chat_session,
is_message_duplicate,
maybe_append_user_message,
upsert_chat_session,
)
@@ -46,7 +48,7 @@ messages = [
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_serialization_deserialization():
s = ChatSession.new(user_id="abc123")
s = ChatSession.new(user_id="abc123", dry_run=False)
s.messages = messages
s.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]
serialized = s.model_dump_json()
@@ -57,7 +59,7 @@ async def test_chatsession_serialization_deserialization():
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage(setup_test_user, test_user_id):
s = ChatSession.new(user_id=test_user_id)
s = ChatSession.new(user_id=test_user_id, dry_run=False)
s.messages = messages
s = await upsert_chat_session(s)
@@ -75,7 +77,7 @@ async def test_chatsession_redis_storage_user_id_mismatch(
setup_test_user, test_user_id
):
s = ChatSession.new(user_id=test_user_id)
s = ChatSession.new(user_id=test_user_id, dry_run=False)
s.messages = messages
s = await upsert_chat_session(s)
@@ -90,7 +92,7 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
from backend.data.redis_client import get_redis_async
# Create session with messages including assistant message
s = ChatSession.new(user_id=test_user_id)
s = ChatSession.new(user_id=test_user_id, dry_run=False)
s.messages = messages # Contains user, assistant, and tool messages
assert s.session_id is not None, "Session id is not set"
# Upsert to save to both cache and DB
@@ -241,7 +243,7 @@ _raw_tc2 = {
def test_add_tool_call_appends_to_existing_assistant():
"""When the last assistant is from the current turn, tool_call is added to it."""
session = ChatSession.new(user_id="u")
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="working on it"),
@@ -254,7 +256,7 @@ def test_add_tool_call_appends_to_existing_assistant():
def test_add_tool_call_creates_assistant_when_none_exists():
"""When there's no current-turn assistant, a new one is created."""
session = ChatSession.new(user_id="u")
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="user", content="hi"),
]
@@ -267,7 +269,7 @@ def test_add_tool_call_creates_assistant_when_none_exists():
def test_add_tool_call_does_not_cross_user_boundary():
"""A user message acts as a boundary — previous assistant is not modified."""
session = ChatSession.new(user_id="u")
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="assistant", content="old turn"),
ChatMessage(role="user", content="new message"),
@@ -282,7 +284,7 @@ def test_add_tool_call_does_not_cross_user_boundary():
def test_add_tool_call_multiple_times():
"""Multiple long-running tool calls accumulate on the same assistant."""
session = ChatSession.new(user_id="u")
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="doing stuff"),
@@ -300,7 +302,7 @@ def test_add_tool_call_multiple_times():
def test_to_openai_messages_merges_split_assistants():
"""End-to-end: session with split assistants produces valid OpenAI messages."""
session = ChatSession.new(user_id="u")
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="user", content="build agent"),
ChatMessage(role="assistant", content="Let me build that"),
@@ -352,7 +354,7 @@ async def test_concurrent_saves_collision_detection(setup_test_user, test_user_i
import asyncio
# Create a session with initial messages
session = ChatSession.new(user_id=test_user_id)
session = ChatSession.new(user_id=test_user_id, dry_run=False)
for i in range(3):
session.messages.append(
ChatMessage(
@@ -424,3 +426,151 @@ async def test_concurrent_saves_collision_detection(setup_test_user, test_user_i
assert "Streaming message 1" in contents
assert "Streaming message 2" in contents
assert "Callback result" in contents
# --------------------------------------------------------------------------- #
# is_message_duplicate #
# --------------------------------------------------------------------------- #
def test_duplicate_detected_in_trailing_same_role():
"""Duplicate user message at the tail is detected."""
msgs = [
ChatMessage(role="user", content="hello"),
ChatMessage(role="assistant", content="hi there"),
ChatMessage(role="user", content="yes"),
]
assert is_message_duplicate(msgs, "user", "yes") is True
def test_duplicate_not_detected_across_turns():
"""Same text in a previous turn (separated by assistant) is NOT a duplicate."""
msgs = [
ChatMessage(role="user", content="yes"),
ChatMessage(role="assistant", content="ok"),
]
assert is_message_duplicate(msgs, "user", "yes") is False
def test_no_duplicate_on_empty_messages():
"""Empty message list never reports a duplicate."""
assert is_message_duplicate([], "user", "hello") is False
def test_no_duplicate_when_content_differs():
"""Different content in the trailing same-role block is not a duplicate."""
msgs = [
ChatMessage(role="assistant", content="response"),
ChatMessage(role="user", content="first message"),
]
assert is_message_duplicate(msgs, "user", "second message") is False
def test_duplicate_with_multiple_trailing_same_role():
"""Detects duplicate among multiple consecutive same-role messages."""
msgs = [
ChatMessage(role="assistant", content="response"),
ChatMessage(role="user", content="msg1"),
ChatMessage(role="user", content="msg2"),
]
assert is_message_duplicate(msgs, "user", "msg1") is True
assert is_message_duplicate(msgs, "user", "msg2") is True
assert is_message_duplicate(msgs, "user", "msg3") is False
def test_duplicate_check_for_assistant_role():
"""Works correctly when checking assistant role too."""
msgs = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="hello"),
ChatMessage(role="assistant", content="how can I help?"),
]
assert is_message_duplicate(msgs, "assistant", "hello") is True
assert is_message_duplicate(msgs, "assistant", "new response") is False
def test_no_false_positive_when_content_is_none():
"""Messages with content=None in the trailing block do not match."""
msgs = [
ChatMessage(role="user", content=None),
ChatMessage(role="user", content="hello"),
]
assert is_message_duplicate(msgs, "user", "hello") is True
# None-content message should not match any string
msgs2 = [
ChatMessage(role="user", content=None),
]
assert is_message_duplicate(msgs2, "user", "hello") is False
def test_all_same_role_messages():
"""When all messages share the same role, the entire list is scanned."""
msgs = [
ChatMessage(role="user", content="first"),
ChatMessage(role="user", content="second"),
ChatMessage(role="user", content="third"),
]
assert is_message_duplicate(msgs, "user", "first") is True
assert is_message_duplicate(msgs, "user", "new") is False
# --------------------------------------------------------------------------- #
# maybe_append_user_message #
# --------------------------------------------------------------------------- #
def test_maybe_append_user_message_appends_new():
"""A new user message is appended and returns True."""
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="assistant", content="hello"),
]
result = maybe_append_user_message(session, "new msg", is_user_message=True)
assert result is True
assert len(session.messages) == 2
assert session.messages[-1].role == "user"
assert session.messages[-1].content == "new msg"
def test_maybe_append_user_message_skips_duplicate():
"""A duplicate user message is skipped and returns False."""
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="assistant", content="hello"),
ChatMessage(role="user", content="dup"),
]
result = maybe_append_user_message(session, "dup", is_user_message=True)
assert result is False
assert len(session.messages) == 2
def test_maybe_append_user_message_none_message():
"""None/empty message returns False without appending."""
session = ChatSession.new(user_id="u", dry_run=False)
assert maybe_append_user_message(session, None, is_user_message=True) is False
assert maybe_append_user_message(session, "", is_user_message=True) is False
assert len(session.messages) == 0
def test_maybe_append_assistant_message():
"""Works for assistant role when is_user_message=False."""
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="user", content="hi"),
]
result = maybe_append_user_message(session, "response", is_user_message=False)
assert result is True
assert session.messages[-1].role == "assistant"
assert session.messages[-1].content == "response"
def test_maybe_append_assistant_skips_duplicate():
"""Duplicate assistant message is skipped."""
session = ChatSession.new(user_id="u", dry_run=False)
session.messages = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="dup"),
]
result = maybe_append_user_message(session, "dup", is_user_message=False)
assert result is False
assert len(session.messages) == 2

View File

@@ -66,6 +66,7 @@ from pydantic import BaseModel, PrivateAttr
ToolName = Literal[
# Platform tools (must match keys in TOOL_REGISTRY)
"add_understanding",
"ask_question",
"bash_exec",
"browser_act",
"browser_navigate",
@@ -102,6 +103,7 @@ ToolName = Literal[
"web_fetch",
"write_workspace_file",
# SDK built-ins
"Agent",
"Edit",
"Glob",
"Grep",

View File

@@ -544,6 +544,7 @@ class TestApplyToolPermissions:
class TestSdkBuiltinToolNames:
def test_expected_builtins_present(self):
expected = {
"Agent",
"Read",
"Write",
"Edit",

View File

@@ -18,6 +18,18 @@ After `write_workspace_file`, embed the `download_url` in Markdown:
- Image: `![chart](workspace://file_id#image/png)`
- Video: `![recording](workspace://file_id#video/mp4)`
### Handling binary/image data in tool outputs — CRITICAL
When a tool output contains base64-encoded binary data (images, PDFs, etc.):
1. **NEVER** try to inline or render the base64 content in your response.
2. **Save** the data to workspace using `write_workspace_file` (pass the base64 data URI as content).
3. **Show** the result via the workspace download URL in Markdown: `![image](workspace://file_id#image/png)`.
### Passing large data between tools — CRITICAL
When tool outputs produce large text that you need to feed into another tool:
- **NEVER** copy-paste the full text into the next tool call argument.
- **Save** the output to a file (workspace or local), then use `@@agptfile:` references.
- This avoids token limits and ensures data integrity.
### File references — @@agptfile:
Pass large file content to tools by reference: `@@agptfile:<uri>[<start>-<end>]`
- `workspace://<file_id>` or `workspace:///<path>` — workspace files
@@ -107,6 +119,28 @@ Do not re-fetch or re-generate data you already have from prior tool calls.
After building the file, reference it with `@@agptfile:` in other tools:
`@@agptfile:/home/user/report.md`
### Web search best practices
- If 3 similar web searches don't return the specific data you need, conclude
it isn't publicly available and work with what you have.
- Prefer fewer, well-targeted searches over many variations of the same query.
- When spawning sub-agents for research, ensure each has a distinct
non-overlapping scope to avoid redundant searches.
### Tool Discovery Priority
When the user asks to interact with a service or API, follow this order:
1. **find_block first** — Search platform blocks with `find_block`. The platform has hundreds of built-in blocks (Google Sheets, Docs, Calendar, Gmail, Slack, GitHub, etc.) that work without extra setup.
2. **run_mcp_tool** — If no matching block exists, check if a hosted MCP server is available for the service. Only use known MCP server URLs from the registry.
3. **SendAuthenticatedWebRequestBlock** — If no block or MCP server exists, use `SendAuthenticatedWebRequestBlock` with existing host-scoped credentials. Check available credentials via `connect_integration`.
4. **Manual API call** — As a last resort, guide the user to set up credentials and use `SendAuthenticatedWebRequestBlock` with direct API calls.
**Never skip step 1.** Built-in blocks are more reliable, tested, and user-friendly than MCP or raw API calls.
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
All tasks must run in the foreground.
@@ -131,6 +165,11 @@ parent autopilot handles orchestration.
# E2B-only notes — E2B has full internet access so gh CLI works there.
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
_E2B_TOOL_NOTES = """
### SDK tool-result files in E2B
When you `Read` an SDK tool-result file, it is automatically copied into the
sandbox so `bash_exec` can access it for further processing.
The exact sandbox path is shown in the `[Sandbox copy available at ...]` note.
### GitHub CLI (`gh`) and git
- If the user has connected their GitHub account, both `gh` and `git` are
pre-authenticated — use them directly without any manual login step.
@@ -196,19 +235,22 @@ def _build_storage_supplement(
- Files here **survive across sessions indefinitely**
### Moving files between storages
- **{file_move_name_1_to_2}**: Copy to persistent workspace
- **{file_move_name_2_to_1}**: Download for processing
- **{file_move_name_1_to_2}**: `write_workspace_file(filename="output.json", source_path="/path/to/local/file")`
- **{file_move_name_2_to_1}**: `read_workspace_file(path="tool-outputs/data.json", save_to_path="{working_dir}/data.json")`
### File persistence
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
### SDK tool-result files
When tool outputs are large, the SDK truncates them and saves the full output to
a local file under `~/.claude/projects/.../tool-results/`. To read these files,
always use `Read` (NOT `bash_exec`, NOT `read_workspace_file`).
These files are on the host filesystem — `bash_exec` runs in the sandbox and
CANNOT access them. `read_workspace_file` reads from cloud workspace storage,
where SDK tool-results are NOT stored.
a local file under `~/.claude/projects/.../tool-results/` (or `tool-outputs/`).
To read these files, use `Read` — it reads from the host filesystem.
### Large tool outputs saved to workspace
When a tool output contains `<tool-output-truncated workspace_path="...">`, the
full output is in workspace storage (NOT on the local filesystem). To access it:
- Use `read_workspace_file(path="...", offset=..., length=50000)` for reading sections.
- To process in the sandbox, use `read_workspace_file(path="...", save_to_path="{working_dir}/file.json")` first, then use `bash_exec` on the local copy.
{_SHARED_TOOL_NOTES}{extra_notes}"""

View File

@@ -0,0 +1,28 @@
"""Tests for agent generation guide — verifies clarification section."""
from pathlib import Path
class TestAgentGenerationGuideContainsClarifySection:
"""The agent generation guide must include the clarification section."""
def test_guide_includes_clarify_section(self):
guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
content = guide_path.read_text(encoding="utf-8")
assert "Before or During Building" in content
def test_guide_mentions_find_block_for_clarification(self):
guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
content = guide_path.read_text(encoding="utf-8")
clarify_section = content.split("Before or During Building")[1].split(
"### Workflow"
)[0]
assert "find_block" in clarify_section
def test_guide_mentions_ask_question_tool(self):
guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
content = guide_path.read_text(encoding="utf-8")
clarify_section = content.split("Before or During Building")[1].split(
"### Workflow"
)[0]
assert "ask_question" in clarify_section

View File

@@ -9,11 +9,15 @@ UTC). Fails open when Redis is unavailable to avoid blocking users.
import asyncio
import logging
from datetime import UTC, datetime, timedelta
from enum import Enum
from prisma.models import User as PrismaUser
from pydantic import BaseModel, Field
from redis.exceptions import RedisError
from backend.data.db_accessors import user_db
from backend.data.redis_client import get_redis_async
from backend.util.cache import cached
logger = logging.getLogger(__name__)
@@ -21,6 +25,40 @@ logger = logging.getLogger(__name__)
_USAGE_KEY_PREFIX = "copilot:usage"
# ---------------------------------------------------------------------------
# Subscription tier definitions
# ---------------------------------------------------------------------------
class SubscriptionTier(str, Enum):
"""Subscription tiers with increasing token allowances.
Mirrors the ``SubscriptionTier`` enum in ``schema.prisma``.
Once ``prisma generate`` is run, this can be replaced with::
from prisma.enums import SubscriptionTier
"""
FREE = "FREE"
PRO = "PRO"
BUSINESS = "BUSINESS"
ENTERPRISE = "ENTERPRISE"
# Multiplier applied to the base limits (from LD / config) for each tier.
# Intentionally int (not float): keeps limits as whole token counts and avoids
# floating-point rounding. If fractional multipliers are ever needed, change
# the type and round the result in get_global_rate_limits().
TIER_MULTIPLIERS: dict[SubscriptionTier, int] = {
SubscriptionTier.FREE: 1,
SubscriptionTier.PRO: 5,
SubscriptionTier.BUSINESS: 20,
SubscriptionTier.ENTERPRISE: 60,
}
DEFAULT_TIER = SubscriptionTier.FREE
class UsageWindow(BaseModel):
"""Usage within a single time window."""
@@ -36,6 +74,7 @@ class CoPilotUsageStatus(BaseModel):
daily: UsageWindow
weekly: UsageWindow
tier: SubscriptionTier = DEFAULT_TIER
reset_cost: int = Field(
default=0,
description="Credit cost (in cents) to reset the daily limit. 0 = feature disabled.",
@@ -66,6 +105,7 @@ async def get_usage_status(
daily_token_limit: int,
weekly_token_limit: int,
rate_limit_reset_cost: int = 0,
tier: SubscriptionTier = DEFAULT_TIER,
) -> CoPilotUsageStatus:
"""Get current usage status for a user.
@@ -74,6 +114,7 @@ async def get_usage_status(
daily_token_limit: Max tokens per day (0 = unlimited).
weekly_token_limit: Max tokens per week (0 = unlimited).
rate_limit_reset_cost: Credit cost (cents) to reset daily limit (0 = disabled).
tier: The user's rate-limit tier (included in the response).
Returns:
CoPilotUsageStatus with current usage and limits.
@@ -103,6 +144,7 @@ async def get_usage_status(
limit=weekly_token_limit,
resets_at=_weekly_reset_time(now=now),
),
tier=tier,
reset_cost=rate_limit_reset_cost,
)
@@ -161,8 +203,9 @@ async def reset_daily_usage(user_id: str, daily_token_limit: int = 0) -> bool:
daily_token_limit: The configured daily token limit. When positive,
the weekly counter is reduced by this amount.
Fails open: returns False if Redis is unavailable (consistent with
the fail-open design of this module).
Returns False if Redis is unavailable so the caller can handle
compensation (fail-closed for billed operations, unlike the read-only
rate-limit checks which fail-open).
"""
now = datetime.now(UTC)
try:
@@ -342,20 +385,103 @@ async def record_token_usage(
)
class _UserNotFoundError(Exception):
"""Raised when a user record is missing or has no subscription tier.
Used internally by ``_fetch_user_tier`` to signal a cache-miss condition:
by raising instead of returning ``DEFAULT_TIER``, we prevent the ``@cached``
decorator from storing the fallback value. This avoids a race condition
where a non-existent user's DEFAULT_TIER is cached, then the user is
created with a higher tier but receives the stale cached FREE tier for
up to 5 minutes.
"""
@cached(maxsize=1000, ttl_seconds=300, shared_cache=True)
async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
"""Fetch the user's rate-limit tier from the database (cached via Redis).
Uses ``shared_cache=True`` so that tier changes propagate across all pods
immediately when the cache entry is invalidated (via ``cache_delete``).
Only successful DB lookups of existing users with a valid tier are cached.
Raises ``_UserNotFoundError`` when the user is missing or has no tier, so
the ``@cached`` decorator does **not** store a fallback value. This
prevents a race condition where a non-existent user's ``DEFAULT_TIER`` is
cached and then persists after the user is created with a higher tier.
"""
try:
user = await user_db().get_user_by_id(user_id)
except Exception:
raise _UserNotFoundError(user_id)
if user.subscription_tier:
return SubscriptionTier(user.subscription_tier)
raise _UserNotFoundError(user_id)
async def get_user_tier(user_id: str) -> SubscriptionTier:
"""Look up the user's rate-limit tier from the database.
Successful results are cached for 5 minutes (via ``_fetch_user_tier``)
to avoid a DB round-trip on every rate-limit check.
Falls back to ``DEFAULT_TIER`` **without caching** when the DB is
unreachable or returns an unrecognised value, so the next call retries
the query instead of serving a stale fallback for up to 5 minutes.
"""
try:
return await _fetch_user_tier(user_id)
except Exception as exc:
logger.warning(
"Failed to resolve rate-limit tier for user %s, defaulting to %s: %s",
user_id[:8],
DEFAULT_TIER.value,
exc,
)
return DEFAULT_TIER
# Expose cache management on the public function so callers (including tests)
# never need to reach into the private ``_fetch_user_tier``.
get_user_tier.cache_clear = _fetch_user_tier.cache_clear # type: ignore[attr-defined]
get_user_tier.cache_delete = _fetch_user_tier.cache_delete # type: ignore[attr-defined]
async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None:
"""Persist the user's rate-limit tier to the database.
Also invalidates the ``get_user_tier`` cache for this user so that
subsequent rate-limit checks immediately see the new tier.
Raises:
prisma.errors.RecordNotFoundError: If the user does not exist.
"""
await PrismaUser.prisma().update(
where={"id": user_id},
data={"subscriptionTier": tier.value},
)
# Invalidate cached tier so rate-limit checks pick up the change immediately.
get_user_tier.cache_delete(user_id) # type: ignore[attr-defined]
async def get_global_rate_limits(
user_id: str,
config_daily: int,
config_weekly: int,
) -> tuple[int, int]:
) -> tuple[int, int, SubscriptionTier]:
"""Resolve global rate limits from LaunchDarkly, falling back to config.
The base limits (from LD or config) are multiplied by the user's
tier multiplier so that higher tiers receive proportionally larger
allowances.
Args:
user_id: User ID for LD flag evaluation context.
config_daily: Fallback daily limit from ChatConfig.
config_weekly: Fallback weekly limit from ChatConfig.
Returns:
(daily_token_limit, weekly_token_limit) tuple.
(daily_token_limit, weekly_token_limit, tier) 3-tuple.
"""
# Lazy import to avoid circular dependency:
# rate_limit -> feature_flag -> settings -> ... -> rate_limit
@@ -377,7 +503,15 @@ async def get_global_rate_limits(
except (TypeError, ValueError):
logger.warning("Invalid LD value for weekly token limit: %r", weekly_raw)
weekly = config_weekly
return daily, weekly
# Apply tier multiplier
tier = await get_user_tier(user_id)
multiplier = TIER_MULTIPLIERS.get(tier, 1)
if multiplier != 1:
daily = daily * multiplier
weekly = weekly * multiplier
return daily, weekly, tier
async def reset_user_usage(user_id: str, *, reset_weekly: bool = False) -> None:

File diff suppressed because it is too large Load Diff

View File

@@ -9,7 +9,7 @@ import pytest
from fastapi import HTTPException
from backend.api.features.chat.routes import reset_copilot_usage
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
from backend.copilot.rate_limit import CoPilotUsageStatus, SubscriptionTier, UsageWindow
from backend.util.exceptions import InsufficientBalanceError
@@ -53,6 +53,18 @@ def _mock_settings(enable_credit: bool = True):
return mock
def _mock_rate_limits(
daily: int = 2_500_000,
weekly: int = 12_500_000,
tier: SubscriptionTier = SubscriptionTier.PRO,
):
"""Mock get_global_rate_limits to return fixed limits (no tier multiplier)."""
return patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(daily, weekly, tier)),
)
@pytest.mark.asyncio
class TestResetCopilotUsage:
async def test_feature_disabled_returns_400(self):
@@ -70,6 +82,7 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", _make_config(daily_token_limit=0)),
patch(f"{_MODULE}.settings", _mock_settings()),
_mock_rate_limits(daily=0),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
@@ -83,6 +96,7 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -112,6 +126,7 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -141,6 +156,7 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
@@ -171,6 +187,7 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=3)),
):
with pytest.raises(HTTPException) as exc_info:
@@ -208,6 +225,7 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -228,6 +246,7 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", _make_config()),
patch(f"{_MODULE}.settings", _mock_settings()),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=None)),
):
with pytest.raises(HTTPException) as exc_info:
@@ -245,6 +264,7 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
@@ -275,6 +295,7 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),

View File

@@ -3,26 +3,62 @@
You can create, edit, and customize agents directly. You ARE the brain —
generate the agent JSON yourself using block schemas, then validate and save.
### Clarifying — Before or During Building
Use `ask_question` whenever the user's intent is ambiguous — whether
that's before starting or midway through the workflow. Common moments:
- **Before building**: output format, delivery channel, data source, or
trigger is unspecified.
- **During block discovery**: multiple blocks could fit and the user
should choose.
- **During JSON generation**: a wiring decision depends on user
preference.
Steps:
1. Call `find_block` (or another discovery tool) to learn what the
platform actually supports for the ambiguous dimension.
2. Call `ask_question` with a concrete question listing the discovered
options (e.g. "The platform supports Gmail, Slack, and Google Docs —
which should the agent use for delivery?").
3. **Wait for the user's answer** before continuing.
**Skip this** when the goal already specifies all dimensions (e.g.
"scrape prices from Amazon and email me daily").
### Workflow for Creating/Editing Agents
1. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
1. **If editing**: First narrow to the specific agent by UUID, then fetch its
graph: `find_library_agent(query="<agent_id>", include_graph=true)`. This
returns the full graph structure (nodes + links). **Never edit blindly**
always inspect the current graph first so you know exactly what to change.
Avoid using `include_graph=true` with broad keyword searches, as fetching
multiple graphs at once is expensive and consumes LLM context budget.
2. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
search for relevant blocks. This returns block IDs, names, descriptions,
and full input/output schemas.
2. **Find library agents**: Call `find_library_agent` to discover reusable
3. **Find library agents**: Call `find_library_agent` to discover reusable
agents that can be composed as sub-agents via `AgentExecutorBlock`.
3. **Generate JSON**: Build the agent JSON using block schemas:
- Use block IDs from step 1 as `block_id` in nodes
4. **Generate/modify JSON**: Build or modify the agent JSON using block schemas:
- Use block IDs from step 2 as `block_id` in nodes
- Wire outputs to inputs using links
- Set design-time config in `input_default`
- Use `AgentInputBlock` for values the user provides at runtime
4. **Write to workspace**: Save the JSON to a workspace file so the user
- When editing, apply targeted changes and preserve unchanged parts
5. **Write to workspace**: Save the JSON to a workspace file so the user
can review it: `write_workspace_file(filename="agent.json", content=...)`
5. **Validate**: Call `validate_agent_graph` with the agent JSON to check
6. **Validate**: Call `validate_agent_graph` with the agent JSON to check
for errors
6. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
7. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
or fix manually based on the error descriptions. Iterate until valid.
7. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
8. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
the final `agent_json`
8. **Dry-run**: ALWAYS call `run_agent` with `dry_run=True` and
`wait_for_result=120` to verify the agent works end-to-end.
9. **Inspect & fix**: Check the dry-run output for errors. If issues are
found, call `edit_agent` to fix and dry-run again. Repeat until the
simulation passes or the problems are clearly unfixable.
See "REQUIRED: Dry-Run Verification Loop" section below for details.
### Agent JSON Structure
@@ -74,8 +110,8 @@ These define the agent's interface — what it accepts and what it produces.
**AgentDropdownInputBlock** (ID: `655d6fdf-a334-421c-b733-520549c07cd1`):
- Specialized input block that presents a dropdown/select to the user
- Required `input_default` fields: `name` (str), `placeholder_values` (list of options, must have at least one)
- Optional: `title`, `description`, `value` (default selection)
- Required `input_default` fields: `name` (str)
- Optional: `options` (list of dropdown values; when omitted/empty, input behaves as free-text), `title`, `description`, `value` (default selection)
- Output: `result` — the user-selected value at runtime
- Use this instead of AgentInputBlock when the user should pick from a fixed set of options
@@ -216,19 +252,62 @@ call in a loop until the task is complete:
Regular blocks work exactly like sub-agents as tools — wire each input
field from `source_name: "tools"` on the Orchestrator side.
### Testing with Dry Run
### REQUIRED: Dry-Run Verification Loop (create -> dry-run -> fix)
After saving an agent, suggest a dry run to validate wiring without consuming
real API calls, credentials, or credits:
After creating or editing an agent, you MUST dry-run it before telling the
user the agent is ready. NEVER skip this step.
1. **Run**: Call `run_agent` or `run_block` with `dry_run=True` and provide
sample inputs. This executes the graph with mock outputs, verifying that
links resolve correctly and required inputs are satisfied.
2. **Check results**: Call `view_agent_output` with `show_execution_details=True`
to inspect the full node-by-node execution trace. This shows what each node
received as input and produced as output, making it easy to spot wiring issues.
3. **Iterate**: If the dry run reveals wiring issues or missing inputs, fix
the agent JSON and re-save before suggesting a real execution.
#### Step-by-step workflow
1. **Create/Edit**: Call `create_agent` or `edit_agent` to save the agent.
2. **Dry-run**: Call `run_agent` with `dry_run=True`, `wait_for_result=120`,
and realistic sample inputs that exercise every path in the agent. This
simulates execution using an LLM for each block — no real API calls,
credentials, or credits are consumed.
3. **Inspect output**: Examine the dry-run result for problems. If
`wait_for_result` returns only a summary, call
`view_agent_output(execution_id=..., show_execution_details=True)` to
see the full node-by-node execution trace. Look for:
- **Errors / failed nodes** — a node raised an exception or returned an
error status. Common causes: wrong `source_name`/`sink_name` in links,
missing `input_default` values, or referencing a nonexistent block output.
- **Null / empty outputs** — data did not flow through a link. Verify that
`source_name` and `sink_name` match the block schemas exactly (case-
sensitive, including nested `_#_` notation).
- **Nodes that never executed** — the node was not reached. Likely a
missing or broken link from an upstream node.
- **Unexpected values** — data arrived but in the wrong type or
structure. Check type compatibility between linked ports.
4. **Fix**: If any issues are found, call `edit_agent` with the corrected
agent JSON, then go back to step 2.
5. **Repeat**: Continue the dry-run -> fix cycle until the simulation passes
or the problems are clearly unfixable. If you stop making progress,
report the remaining issues to the user and ask for guidance.
#### Good vs bad dry-run output
**Good output** (agent is ready):
- All nodes executed successfully (no errors in the execution trace)
- Data flows through every link with non-null, correctly-typed values
- The final `AgentOutputBlock` contains a meaningful result
- Status is `COMPLETED`
**Bad output** (needs fixing):
- Status is `FAILED` — check the error message for the failing node
- An output node received `null` — trace back to find the broken link
- A node received data in the wrong format (e.g. string where list expected)
- Nodes downstream of a failing node were skipped entirely
**Special block behaviour in dry-run mode:**
- **OrchestratorBlock** and **AgentExecutorBlock** execute for real so the
orchestrator can make LLM calls and agent executors can spawn child graphs.
Their downstream tool blocks and child-graph blocks are still simulated.
Note: real LLM inference calls are made (consuming API quota), even though
platform credits are not charged. Agent-mode iterations are capped at 1 in
dry-run to keep it fast.
- **MCPToolBlock** is simulated using the selected tool's name and JSON Schema
so the LLM can produce a realistic mock response without connecting to the
MCP server.
### Example: Simple AI Text Processor

View File

@@ -25,7 +25,7 @@ from backend.copilot.sdk.compaction import (
def _make_session() -> ChatSession:
return ChatSession.new(user_id="test-user")
return ChatSession.new(user_id="test-user", dry_run=False)
# ---------------------------------------------------------------------------

View File

@@ -2,14 +2,30 @@
from __future__ import annotations
from collections.abc import AsyncIterator
from unittest.mock import patch
from uuid import uuid4
import pytest
import pytest_asyncio
from backend.util import json
@pytest_asyncio.fixture(scope="session", loop_scope="session", name="server")
async def _server_noop() -> None:
"""No-op server stub — SDK tests don't need the full backend."""
return None
@pytest_asyncio.fixture(
scope="session", loop_scope="session", autouse=True, name="graph_cleanup"
)
async def _graph_cleanup_noop() -> AsyncIterator[None]:
"""No-op graph cleanup stub."""
yield
@pytest.fixture()
def mock_chat_config():
"""Mock ChatConfig so compact_transcript tests skip real config lookup."""

View File

@@ -8,6 +8,9 @@ SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
"""
import asyncio
import base64
import hashlib
import itertools
import json
import logging
@@ -28,6 +31,12 @@ from backend.copilot.context import (
logger = logging.getLogger(__name__)
# Default number of lines returned by ``read_file`` when the caller does not
# specify a limit. Also used as the threshold in ``bridge_to_sandbox`` to
# decide whether the model is requesting the full file (and thus whether the
# bridge copy is worthwhile).
_DEFAULT_READ_LIMIT = 2000
async def _check_sandbox_symlink_escape(
sandbox: Any,
@@ -89,7 +98,7 @@ def _get_sandbox_and_path(
return sandbox, remote
async def _sandbox_write(sandbox: Any, path: str, content: str) -> None:
async def _sandbox_write(sandbox: Any, path: str, content: str | bytes) -> None:
"""Write *content* to *path* inside the sandbox.
The E2B filesystem API (``sandbox.files.write``) and the command API
@@ -102,11 +111,14 @@ async def _sandbox_write(sandbox: Any, path: str, content: str) -> None:
To work around this, writes targeting ``/tmp`` are performed via
``tee`` through the command API, which runs as the sandbox ``user``
and can therefore always overwrite user-owned files.
*content* may be ``str`` (text) or ``bytes`` (binary). Both paths
are handled correctly: text is encoded to bytes for the base64 shell
pipe, and raw bytes are passed through without any encoding.
"""
if path == "/tmp" or path.startswith("/tmp/"):
import base64 as _b64
encoded = _b64.b64encode(content.encode()).decode()
raw = content.encode() if isinstance(content, str) else content
encoded = base64.b64encode(raw).decode()
result = await sandbox.commands.run(
f"echo {shlex.quote(encoded)} | base64 -d > {shlex.quote(path)}",
cwd=E2B_WORKDIR,
@@ -128,14 +140,25 @@ async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
"""Read lines from a sandbox file, falling back to the local host for SDK-internal paths."""
file_path: str = args.get("file_path", "")
offset: int = max(0, int(args.get("offset", 0)))
limit: int = max(1, int(args.get("limit", 2000)))
limit: int = max(1, int(args.get("limit", _DEFAULT_READ_LIMIT)))
if not file_path:
return _mcp("file_path is required", error=True)
# SDK-internal paths (tool-results, ephemeral working dir) stay on the host.
# SDK-internal paths (tool-results/tool-outputs, ephemeral working dir)
# stay on the host. When E2B is active, also copy the file into the
# sandbox so bash_exec can access it for further processing.
if _is_allowed_local(file_path):
return _read_local(file_path, offset, limit)
result = _read_local(file_path, offset, limit)
if not result.get("isError"):
sandbox = _get_sandbox()
if sandbox is not None:
annotation = await bridge_and_annotate(
sandbox, file_path, offset, limit
)
if annotation:
result["content"][0]["text"] += annotation
return result
result = _get_sandbox_and_path(file_path)
if isinstance(result, dict):
@@ -302,6 +325,103 @@ async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
return _mcp(output if output else "No matches found.")
# Bridging: copy SDK-internal files into E2B sandbox
# Files larger than this are written to /home/user/ via sandbox.files.write()
# instead of /tmp/ via shell base64, to avoid shell argument length limits
# and E2B command timeouts. Base64 expands content by ~33%, so keep this
# well under the typical Linux ARG_MAX (128 KB).
_BRIDGE_SHELL_MAX_BYTES = 32 * 1024 # 32 KB
# Files larger than this are skipped entirely to avoid excessive transfer times.
_BRIDGE_SKIP_BYTES = 50 * 1024 * 1024 # 50 MB
async def bridge_to_sandbox(
sandbox: Any, file_path: str, offset: int, limit: int
) -> str | None:
"""Best-effort copy of a host-side SDK file into the E2B sandbox.
When the model reads an SDK-internal file (e.g. tool-results), it often
wants to process the data with bash. Copying the file into the sandbox
under a stable name lets ``bash_exec`` access it without extra steps.
Only copies when offset=0 and limit is large enough to indicate the model
wants the full file. Errors are logged but never propagated.
Returns the sandbox path on success, or ``None`` on skip/failure.
Size handling:
- <= 32 KB: written to ``/tmp/<hash>-<basename>`` via shell base64
(``_sandbox_write``). Kept small to stay within ARG_MAX.
- 32 KB - 50 MB: written to ``/home/user/<hash>-<basename>`` via
``sandbox.files.write()`` to avoid shell argument length limits.
- > 50 MB: skipped entirely with a warning.
The sandbox filename is prefixed with a short hash of the full source
path to avoid collisions when different source files share the same
basename (e.g. multiple ``result.json`` files).
"""
if offset != 0 or limit < _DEFAULT_READ_LIMIT:
return None
try:
expanded = os.path.realpath(os.path.expanduser(file_path))
basename = os.path.basename(expanded)
source_id = hashlib.sha256(expanded.encode()).hexdigest()[:12]
unique_name = f"{source_id}-{basename}"
file_size = os.path.getsize(expanded)
if file_size > _BRIDGE_SKIP_BYTES:
logger.warning(
"[E2B] Skipping bridge for large file (%d bytes): %s",
file_size,
basename,
)
return None
def _read_bytes() -> bytes:
with open(expanded, "rb") as fh:
return fh.read()
raw_content = await asyncio.to_thread(_read_bytes)
try:
text_content: str | None = raw_content.decode("utf-8")
except UnicodeDecodeError:
text_content = None
data: str | bytes = text_content if text_content is not None else raw_content
if file_size <= _BRIDGE_SHELL_MAX_BYTES:
sandbox_path = f"/tmp/{unique_name}"
await _sandbox_write(sandbox, sandbox_path, data)
else:
sandbox_path = f"/home/user/{unique_name}"
await sandbox.files.write(sandbox_path, data)
logger.info(
"[E2B] Bridged SDK file to sandbox: %s -> %s", basename, sandbox_path
)
return sandbox_path
except Exception:
logger.warning(
"[E2B] Failed to bridge SDK file to sandbox: %s",
file_path,
exc_info=True,
)
return None
async def bridge_and_annotate(
sandbox: Any, file_path: str, offset: int, limit: int
) -> str | None:
"""Bridge a host file to the sandbox and return a newline-prefixed annotation.
Combines ``bridge_to_sandbox`` with the standard annotation suffix so
callers don't need to duplicate the pattern. Returns a string like
``"\\n[Sandbox copy available at /tmp/abc-file.txt]"`` on success, or
``None`` if bridging was skipped or failed.
"""
sandbox_path = await bridge_to_sandbox(sandbox, file_path, offset, limit)
if sandbox_path is None:
return None
return f"\n[Sandbox copy available at {sandbox_path}]"
# Local read (for SDK-internal paths)

View File

@@ -3,6 +3,7 @@
Pure unit tests with no external dependencies (no E2B, no sandbox).
"""
import hashlib
import os
import shutil
from types import SimpleNamespace
@@ -13,12 +14,26 @@ import pytest
from backend.copilot.context import E2B_WORKDIR, SDK_PROJECTS_DIR, _current_project_dir
from .e2b_file_tools import (
_BRIDGE_SHELL_MAX_BYTES,
_BRIDGE_SKIP_BYTES,
_DEFAULT_READ_LIMIT,
_check_sandbox_symlink_escape,
_read_local,
_sandbox_write,
bridge_and_annotate,
bridge_to_sandbox,
resolve_sandbox_path,
)
def _expected_bridge_path(file_path: str, prefix: str = "/tmp") -> str:
"""Compute the expected sandbox path for a bridged file."""
expanded = os.path.realpath(os.path.expanduser(file_path))
basename = os.path.basename(expanded)
source_id = hashlib.sha256(expanded.encode()).hexdigest()[:12]
return f"{prefix}/{source_id}-{basename}"
# ---------------------------------------------------------------------------
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
# ---------------------------------------------------------------------------
@@ -91,9 +106,9 @@ class TestResolveSandboxPath:
# ---------------------------------------------------------------------------
# _read_local — host filesystem reads with allowlist enforcement
#
# In E2B mode, _read_local only allows tool-results paths (via
# is_allowed_local_path without sdk_cwd). Regular files live on the
# sandbox, not the host.
# In E2B mode, _read_local only allows tool-results/tool-outputs paths
# (via is_allowed_local_path without sdk_cwd). Regular files live on
# the sandbox, not the host.
# ---------------------------------------------------------------------------
@@ -119,7 +134,7 @@ class TestReadLocal:
)
token = _current_project_dir.set(encoded)
try:
result = _read_local(filepath, offset=0, limit=2000)
result = _read_local(filepath, offset=0, limit=_DEFAULT_READ_LIMIT)
assert result["isError"] is False
assert "line 1" in result["content"][0]["text"]
assert "line 2" in result["content"][0]["text"]
@@ -127,6 +142,25 @@ class TestReadLocal:
_current_project_dir.reset(token)
os.unlink(filepath)
def test_read_tool_outputs_file(self):
"""Reading a tool-outputs file should also succeed."""
encoded = "-tmp-copilot-e2b-test-read-outputs"
tool_outputs_dir = os.path.join(
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-outputs"
)
os.makedirs(tool_outputs_dir, exist_ok=True)
filepath = os.path.join(tool_outputs_dir, "sdk-abc123.json")
with open(filepath, "w") as f:
f.write('{"data": "test"}\n')
token = _current_project_dir.set(encoded)
try:
result = _read_local(filepath, offset=0, limit=_DEFAULT_READ_LIMIT)
assert result["isError"] is False
assert "test" in result["content"][0]["text"]
finally:
_current_project_dir.reset(token)
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
def test_read_disallowed_path_blocked(self):
"""Reading /etc/passwd should be blocked by the allowlist."""
result = _read_local("/etc/passwd", offset=0, limit=10)
@@ -335,3 +369,199 @@ class TestSandboxWrite:
encoded_in_cmd = call_args.split("echo ")[1].split(" |")[0].strip("'")
decoded = base64.b64decode(encoded_in_cmd).decode()
assert decoded == content
# ---------------------------------------------------------------------------
# bridge_to_sandbox — copy SDK-internal files into E2B sandbox
# ---------------------------------------------------------------------------
def _make_bridge_sandbox() -> SimpleNamespace:
"""Build a sandbox mock suitable for bridge_to_sandbox tests."""
run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
files = SimpleNamespace(write=AsyncMock())
return SimpleNamespace(commands=commands, files=files)
class TestBridgeToSandbox:
@pytest.mark.asyncio
async def test_happy_path_small_file(self, tmp_path):
"""A small file is bridged to /tmp/<hash>-<basename> via _sandbox_write."""
f = tmp_path / "result.json"
f.write_text('{"ok": true}')
sandbox = _make_bridge_sandbox()
result = await bridge_to_sandbox(
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
)
expected = _expected_bridge_path(str(f))
assert result == expected
sandbox.commands.run.assert_called_once()
cmd = sandbox.commands.run.call_args[0][0]
assert "result.json" in cmd
sandbox.files.write.assert_not_called()
@pytest.mark.asyncio
async def test_skip_when_offset_nonzero(self, tmp_path):
"""Bridging is skipped when offset != 0 (partial read)."""
f = tmp_path / "data.txt"
f.write_text("content")
sandbox = _make_bridge_sandbox()
result = await bridge_to_sandbox(
sandbox, str(f), offset=10, limit=_DEFAULT_READ_LIMIT
)
assert result is None
sandbox.commands.run.assert_not_called()
sandbox.files.write.assert_not_called()
@pytest.mark.asyncio
async def test_skip_when_limit_too_small(self, tmp_path):
"""Bridging is skipped when limit < _DEFAULT_READ_LIMIT (partial read)."""
f = tmp_path / "data.txt"
f.write_text("content")
sandbox = _make_bridge_sandbox()
await bridge_to_sandbox(sandbox, str(f), offset=0, limit=100)
sandbox.commands.run.assert_not_called()
sandbox.files.write.assert_not_called()
@pytest.mark.asyncio
async def test_nonexistent_file_does_not_raise(self, tmp_path):
"""Bridging a non-existent file logs but does not propagate errors."""
sandbox = _make_bridge_sandbox()
await bridge_to_sandbox(
sandbox, str(tmp_path / "ghost.txt"), offset=0, limit=_DEFAULT_READ_LIMIT
)
sandbox.commands.run.assert_not_called()
sandbox.files.write.assert_not_called()
@pytest.mark.asyncio
async def test_sandbox_write_failure_returns_none(self, tmp_path):
"""If sandbox write fails, returns None (best-effort)."""
f = tmp_path / "data.txt"
f.write_text("content")
sandbox = _make_bridge_sandbox()
sandbox.commands.run.side_effect = RuntimeError("E2B timeout")
result = await bridge_to_sandbox(
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
)
assert result is None
@pytest.mark.asyncio
async def test_large_file_uses_files_api(self, tmp_path):
"""Files > 32 KB but <= 50 MB are written to /home/user/ via files.write."""
f = tmp_path / "big.json"
f.write_bytes(b"x" * (_BRIDGE_SHELL_MAX_BYTES + 1))
sandbox = _make_bridge_sandbox()
result = await bridge_to_sandbox(
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
)
expected = _expected_bridge_path(str(f), prefix="/home/user")
assert result == expected
sandbox.files.write.assert_called_once()
call_args = sandbox.files.write.call_args[0]
assert call_args[0] == expected
sandbox.commands.run.assert_not_called()
@pytest.mark.asyncio
async def test_small_binary_file_preserves_bytes(self, tmp_path):
"""A small binary file is bridged to /tmp via base64 without corruption."""
binary_data = bytes(range(256))
f = tmp_path / "image.png"
f.write_bytes(binary_data)
sandbox = _make_bridge_sandbox()
result = await bridge_to_sandbox(
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
)
expected = _expected_bridge_path(str(f))
assert result == expected
sandbox.commands.run.assert_called_once()
cmd = sandbox.commands.run.call_args[0][0]
assert "base64" in cmd
sandbox.files.write.assert_not_called()
@pytest.mark.asyncio
async def test_large_binary_file_writes_raw_bytes(self, tmp_path):
"""A large binary file is bridged to /home/user/ as raw bytes."""
binary_data = bytes(range(256)) * 200
f = tmp_path / "photo.jpg"
f.write_bytes(binary_data)
sandbox = _make_bridge_sandbox()
result = await bridge_to_sandbox(
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
)
expected = _expected_bridge_path(str(f), prefix="/home/user")
assert result == expected
sandbox.files.write.assert_called_once()
call_args = sandbox.files.write.call_args[0]
assert call_args[0] == expected
assert call_args[1] == binary_data
sandbox.commands.run.assert_not_called()
@pytest.mark.asyncio
async def test_very_large_file_skipped(self, tmp_path):
"""Files > 50 MB are skipped entirely."""
f = tmp_path / "huge.bin"
# Create a sparse file to avoid actually writing 50 MB
with open(f, "wb") as fh:
fh.seek(_BRIDGE_SKIP_BYTES + 1)
fh.write(b"\0")
sandbox = _make_bridge_sandbox()
result = await bridge_to_sandbox(
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
)
assert result is None
sandbox.commands.run.assert_not_called()
sandbox.files.write.assert_not_called()
# ---------------------------------------------------------------------------
# bridge_and_annotate — shared helper wrapping bridge_to_sandbox + annotation
# ---------------------------------------------------------------------------
class TestBridgeAndAnnotate:
@pytest.mark.asyncio
async def test_returns_annotation_on_success(self, tmp_path):
"""On success, returns a newline-prefixed annotation with the sandbox path."""
f = tmp_path / "data.json"
f.write_text('{"ok": true}')
sandbox = _make_bridge_sandbox()
annotation = await bridge_and_annotate(
sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
)
expected_path = _expected_bridge_path(str(f))
assert annotation == f"\n[Sandbox copy available at {expected_path}]"
@pytest.mark.asyncio
async def test_returns_none_when_skipped(self, tmp_path):
"""When bridging is skipped (e.g. offset != 0), returns None."""
f = tmp_path / "data.json"
f.write_text("content")
sandbox = _make_bridge_sandbox()
annotation = await bridge_and_annotate(
sandbox, str(f), offset=10, limit=_DEFAULT_READ_LIMIT
)
assert annotation is None

View File

@@ -275,7 +275,7 @@ class TestCompactionE2E:
# --- Step 7: CompactionTracker receives PreCompact hook ---
tracker = CompactionTracker()
session = ChatSession.new(user_id="test-user")
session = ChatSession.new(user_id="test-user", dry_run=False)
tracker.on_compact(str(session_file))
# --- Step 8: Next SDK message arrives → emit_start ---
@@ -376,7 +376,7 @@ class TestCompactionE2E:
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
tracker = CompactionTracker()
session = ChatSession.new(user_id="test")
session = ChatSession.new(user_id="test", dry_run=False)
builder = TranscriptBuilder()
# --- First query with compaction ---

View File

@@ -20,6 +20,7 @@ config = ChatConfig()
def build_sdk_env(
session_id: str | None = None,
user_id: str | None = None,
sdk_cwd: str | None = None,
) -> dict[str, str]:
"""Build env vars for the SDK CLI subprocess.
@@ -29,25 +30,35 @@ def build_sdk_env(
``ANTHROPIC_API_KEY`` from the parent environment.
3. **OpenRouter** (default) — overrides base URL and auth token to
route through the proxy, with Langfuse trace headers.
When *sdk_cwd* is provided, ``CLAUDE_CODE_TMPDIR`` is set so that
the CLI writes temp/sub-agent output inside the per-session workspace
directory rather than an inaccessible system temp path.
"""
# --- Mode 1: Claude Code subscription auth ---
if config.use_claude_code_subscription:
validate_subscription()
return {
env: dict[str, str] = {
"ANTHROPIC_API_KEY": "",
"ANTHROPIC_AUTH_TOKEN": "",
"ANTHROPIC_BASE_URL": "",
}
if sdk_cwd:
env["CLAUDE_CODE_TMPDIR"] = sdk_cwd
return env
# --- Mode 2: Direct Anthropic (no proxy hop) ---
if not config.openrouter_active:
return {}
env = {}
if sdk_cwd:
env["CLAUDE_CODE_TMPDIR"] = sdk_cwd
return env
# --- Mode 3: OpenRouter proxy ---
base = (config.base_url or "").rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
env: dict[str, str] = {
env = {
"ANTHROPIC_BASE_URL": base,
"ANTHROPIC_AUTH_TOKEN": config.api_key or "",
"ANTHROPIC_API_KEY": "", # force CLI to use AUTH_TOKEN
@@ -65,4 +76,7 @@ def build_sdk_env(
if parts:
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
if sdk_cwd:
env["CLAUDE_CODE_TMPDIR"] = sdk_cwd
return env

Some files were not shown because too many files have changed in this diff Show More