Compare commits

...

103 Commits

Author SHA1 Message Date
Nicholas Tindle
2826532bc1 fix(backend): validate email format on waitlist join endpoint
Use pydantic.EmailStr for the email parameter so FastAPI/Pydantic
reject malformed emails before they reach the database.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 22:35:51 -06:00
Nicholas Tindle
09c5bc205f fix(backend): address PR review feedback from majdyz
- Remove try/except blocks from admin routes, rely on global exception
  handlers in rest_api.py
- Rename schema relation fields to PascalCase (WaitlistEntries,
  JoinedWaitlists, JoinedUsers) to match codebase convention
- Update all Prisma include/data references in db.py accordingly

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 19:50:56 -06:00
Nicholas Tindle
676fc6647b Merge branch 'dev' into ntindle/waitlist 2026-03-05 17:28:05 -06:00
Bently
5d56548e6b fix(frontend): prevent crash on /library with 401 error from pagination helper (#12292)
## Changes
Fixes crash on `/library` page when backend returns a 401 authentication
error.

### Problem

When the backend returns a 401 error, React Query still calls
`getNextPageParam` with the error response. The response doesn't have
the expected pagination structure, causing `pagination` to be
`undefined`. The code then crashes trying
 to access `pagination.current_page`.

Error:
TypeError: Cannot read properties of undefined (reading 'current_page')
    at Object.getNextPageParam

### Solution

Added a defensive null check in `getPaginationNextPageNumber()` to
handle cases where `pagination` is undefined:

```typescript
const { pagination } = lastPage.data;
if (!pagination) return undefined;
```
When undefined is returned, React Query interprets this as "no next page
available" and gracefully stops pagination instead of crashing.

Testing

- Manual testing: Verify /library page handles 401 errors without
crashing
- The fix is defensive and doesn't change behavior for successful
responses

Related Issues

Closes OPEN-2684
2026-03-05 19:52:36 +00:00
Otto
6ecf55d214 fix(frontend): fix 'Open link' button text color to white for contrast (#12304)
Requested by @ntindle

The Streamdown external link safety modal's "Open link" button had dark
text (`color: black`) on a dark background, making it unreadable.
Changed to `color: white` for proper contrast per our design system.

**File:** `autogpt_platform/frontend/src/app/globals.css`

Resolves SECRT-2061

---
Co-authored-by: Nick Tindle (@ntindle)
2026-03-05 19:50:39 +00:00
Bently
7c8c7bf395 feat(llm): add Claude Sonnet 4.6 model (#12158)
## Summary
Adds Claude Sonnet 4.6 (`claude-sonnet-4-6`) to the platform.

## Model Details (from [Anthropic
docs](https://www.anthropic.com/news/claude-sonnet-4-6))
- **API ID:** `claude-sonnet-4-6`
- **Pricing:** $3 / input MTok, $15 / output MTok (same as Sonnet 4.5)
- **Context window:** 200K tokens (1M beta)
- **Max output:** 64K tokens
- **Knowledge cutoff:** Aug 2025 (reliable), Jan 2026 (training data)

## Changes
- Added `CLAUDE_4_6_SONNET` to `LlmModel` enum
- Added metadata entry with correct context/output limits
- Updated Stagehand to use Sonnet 4.6 (better for browser automation
tasks)

## Why
Sonnet 4.6 brings major improvements in coding, computer use, and
reasoning. Developers with early access often prefer it to even Opus
4.5.

---------

Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2026-03-05 19:36:56 +00:00
Zamil Majdy
be18436e8f Merge branch 'master' of github.com:Significant-Gravitas/AutoGPT into dev 2026-03-06 02:31:40 +07:00
Zamil Majdy
ea0333c1fc fix(copilot): always upload transcript instead of size-based skip (#12303)
## Summary

Fixes copilot sessions "forgetting" previous turns due to stale
transcript storage.

**Root cause:** The transcript upload logic used byte size comparison
(`existing >= new → skip`) to prevent overwriting newer transcripts with
older ones. However, with `--resume` the CLI compacts old tool results,
so newer transcripts can have **fewer bytes** despite containing **more
conversation events**. This caused the stored transcript to freeze at
whatever the largest historical upload was — every subsequent turn
downloaded the same stale transcript and the agent lost context of
recent turns.

**Evidence from prod session `41a3814c`:**
- Stored transcript: 764KB (frozen, never updated)
- Turn 1 output: 379KB (75 lines) → upload skipped (764KB >= 379KB)
- Turn 2 output: 422KB (71 lines) → upload skipped (764KB >= 422KB)
- Turn 3 output: **empty** → upload skipped
- Agent resumed from the same stale 764KB transcript every turn, losing
context of the PR it created

**Fix:** Remove the size comparison entirely. The executor holds a
cluster lock per session, so concurrent uploads cannot race. Just always
overwrite with the latest transcript.

## Test plan
- [x] `poetry run pytest backend/copilot/sdk/transcript_test.py` — 25/25
pass
- [x] All pre-commit hooks pass
- [ ] After deploy: verify multi-turn sessions retain context across
turns
2026-03-06 02:26:52 +07:00
Zamil Majdy
21c705af6e fix(backend/copilot): prevent title update from overwriting session messages (#12302)
### Changes 🏗️

Fixes a race condition in `update_session_title()` where the background
title generation task could overwrite the Redis session cache with a
stale snapshot, causing the copilot to "forget" its previous turns.

**Root cause:** `update_session_title()` performs a read-modify-write on
the Redis cache (read full session → set title → write back). Meanwhile,
`upsert_chat_session()` writes a newer version with more messages during
streaming. If the title task reads early (e.g., 34 messages) and writes
late (after streaming persisted 101 messages), the stale 34-message
version overwrites the 101-message version. When the next message lands
on a different pod, it loads the stale session from Redis.

**Fix:** Replace the read-modify-write with a simple cache invalidation
(`invalidate_session_cache`). The title is already updated in the DB;
the next access just reloads from DB with the correct title and
messages. No locks, no deserialization of the full session blob, no risk
of stale overwrites.

**Evidence from prod logs (session `41a3814c`):**
- Pod `tm2jb` persisted session with 101 messages
- Pod `phflm` loaded session from Redis cache with only 35 messages (66
messages lost)
- The title background task ran between these events, overwriting the
cache

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] `poetry run pytest backend/copilot/model_test.py` — 15/15 pass
  - [x] All pre-commit hooks pass (ruff, black, isort, pyright)
- [ ] After deploy: verify long sessions no longer lose context on
multi-pod setups
2026-03-05 18:49:41 +00:00
Zamil Majdy
a576be9db2 fix(backend): install agent-browser + Chromium in Docker image (#12301)
The Copilot browser tool (`browser_navigate`, `browser_act`,
`browser_screenshot`) has been broken on dev because `agent-browser` CLI
+ Chromium were never installed in the backend Docker image.

### Changes 🏗️

- Added `npx playwright install-deps chromium` to install Chromium
runtime libraries (libnss3, libatk, etc.)
- Added `npm install -g agent-browser` to install the CLI
- Added `agent-browser install` to download the Chromium binary
- Layer is placed after existing COPY-from-builder lines to preserve
Docker cache ordering

### Root cause

Every `browser_navigate` call fails with:
```
WARNING  [browser_navigate] open failed for <url>: agent-browser is not installed
(run: npm install -g agent-browser && agent-browser install).
```
The error originates from `FileNotFoundError` in `agent_browser.py:101`
when the subprocess tries to execute the `agent-browser` binary which
doesn't exist in the container.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Verified `agent-browser` binary is missing from current dev pod
via `kubectl logs`
- [x] Confirmed session `01eeac29-5a7` shows repeated failures for all
URLs
- [ ] After deploy: verify browser_navigate works in a Copilot session
on dev

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)
2026-03-05 18:44:55 +00:00
dependabot[bot]
5e90585f10 chore(deps): bump crazy-max/ghaction-github-runtime from 3 to 4 (#12262)
Bumps
[crazy-max/ghaction-github-runtime](https://github.com/crazy-max/ghaction-github-runtime)
from 3 to 4.
<details>
<summary>Release notes</summary>
<p><em>Sourced from <a
href="https://github.com/crazy-max/ghaction-github-runtime/releases">crazy-max/ghaction-github-runtime's
releases</a>.</em></p>
<blockquote>
<h2>v3.1.0</h2>
<ul>
<li>Bump <code>@​actions/core</code> from 1.10.0 to 1.11.1 in <a
href="https://redirect.github.com/crazy-max/ghaction-github-runtime/pull/58">crazy-max/ghaction-github-runtime#58</a></li>
<li>Bump braces from 3.0.2 to 3.0.3 in <a
href="https://redirect.github.com/crazy-max/ghaction-github-runtime/pull/54">crazy-max/ghaction-github-runtime#54</a></li>
<li>Bump cross-spawn from 7.0.3 to 7.0.6 in <a
href="https://redirect.github.com/crazy-max/ghaction-github-runtime/pull/59">crazy-max/ghaction-github-runtime#59</a></li>
<li>Bump ip from 2.0.0 to 2.0.1 in <a
href="https://redirect.github.com/crazy-max/ghaction-github-runtime/pull/50">crazy-max/ghaction-github-runtime#50</a></li>
<li>Bump micromatch from 4.0.5 to 4.0.8 in <a
href="https://redirect.github.com/crazy-max/ghaction-github-runtime/pull/55">crazy-max/ghaction-github-runtime#55</a></li>
<li>Bump tar from 6.1.14 to 6.2.1 in <a
href="https://redirect.github.com/crazy-max/ghaction-github-runtime/pull/51">crazy-max/ghaction-github-runtime#51</a></li>
</ul>
<p><strong>Full Changelog</strong>: <a
href="https://github.com/crazy-max/ghaction-github-runtime/compare/v3.0.0...v3.1.0">https://github.com/crazy-max/ghaction-github-runtime/compare/v3.0.0...v3.1.0</a></p>
</blockquote>
</details>
<details>
<summary>Commits</summary>
<ul>
<li><a
href="04d248b846"><code>04d248b</code></a>
Merge pull request <a
href="https://redirect.github.com/crazy-max/ghaction-github-runtime/issues/76">#76</a>
from crazy-max/node24</li>
<li><a
href="c8f8e4e4e2"><code>c8f8e4e</code></a>
node 24 as default runtime</li>
<li><a
href="494a382acb"><code>494a382</code></a>
Merge pull request <a
href="https://redirect.github.com/crazy-max/ghaction-github-runtime/issues/68">#68</a>
from crazy-max/dependabot/npm_and_yarn/actions/core-2.0.1</li>
<li><a
href="5d51b8ef32"><code>5d51b8e</code></a>
Merge pull request <a
href="https://redirect.github.com/crazy-max/ghaction-github-runtime/issues/74">#74</a>
from crazy-max/dependabot/npm_and_yarn/minimatch-3.1.5</li>
<li><a
href="f7077dccce"><code>f7077dc</code></a>
chore: update generated content</li>
<li><a
href="4d1e03547a"><code>4d1e035</code></a>
chore(deps): bump minimatch from 3.1.2 to 3.1.5</li>
<li><a
href="b59d56d5bc"><code>b59d56d</code></a>
chore(deps): bump <code>@​actions/core</code> from 1.11.1 to 2.0.1</li>
<li><a
href="6d0e2ef281"><code>6d0e2ef</code></a>
Merge pull request <a
href="https://redirect.github.com/crazy-max/ghaction-github-runtime/issues/75">#75</a>
from crazy-max/esm</li>
<li><a
href="41d6f6acdb"><code>41d6f6a</code></a>
remove codecov config</li>
<li><a
href="b5018eca65"><code>b5018ec</code></a>
chore: update generated content</li>
<li>Additional commits viewable in <a
href="https://github.com/crazy-max/ghaction-github-runtime/compare/v3...v4">compare
view</a></li>
</ul>
</details>
<br />


[![Dependabot compatibility
score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=crazy-max/ghaction-github-runtime&package-manager=github_actions&previous-version=3&new-version=4)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)

Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.

[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)

---

<details>
<summary>Dependabot commands and options</summary>
<br />

You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot show <dependency name> ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)


</details>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2026-03-05 15:59:06 +00:00
Zamil Majdy
3e22a0e786 fix(copilot): pin claude-agent-sdk to 0.1.45 to fix tool_reference content block validation error (#12294)
Requested by @majdyz

## Problem

CoPilot throws `400 Invalid Anthropic Messages API request` errors on
first message, both locally and on Dev.

## Root Cause

The CLI's built-in `ToolSearch` tool returns `tool_reference` content
blocks (`{"type": "tool_reference", "tool_name":
"mcp__copilot__find_block"}`). When the CLI constructs the next
Anthropic API request, it passes these blocks as-is in the
`tool_result.content` field. However, the Anthropic Messages API only
accepts `text` and `image` content block types in tool results.

This causes a Zod validation error:
```
messages[3].content[0].content: Invalid input: expected string, received array
```

The error only manifests when using **OpenRouter** (`ANTHROPIC_BASE_URL`
set) because the Anthropic TypeScript SDK performs stricter client-side
Zod validation in that code path vs the subscription auth path.

PR #12288 bumped `claude-agent-sdk` from `0.1.39` to `^0.1.46`, which
upgraded the bundled Claude CLI from `v2.1.49` to `v2.1.69` where this
issue was introduced.

## Fix

Pin to `0.1.45` which has a CLI version that doesn't produce
`tool_reference` content blocks in tool results.

## Testing
- CoPilot first message should work without 400 errors via OpenRouter
- SDK compat tests should still pass
2026-03-05 13:12:26 +00:00
Ubbe
6abe39b33a feat(frontend/copilot): add per-turn work-done summary stats (#12257)
## Summary
- Adds per-turn work-done counters (e.g. "3 searches", "1 agent run")
shown as plain text on the final assistant message of each
user/assistant interaction pair
- Counters aggregate tool calls by category (searches, agents run,
blocks run, agents created/edited, agents scheduled)
- Copy and TTS actions now appear only on the final assistant message
per turn, with text aggregated from all assistant messages in that turn
- Removes the global JobStatsBar above the chat input

Resolves: SECRT-2026

## Test plan
- [ ] Work-done counters appear only on the last assistant message of
each turn (not on intermediate assistant messages)
- [ ] Counters increment correctly as tool call parts appear in messages
- [ ] Internal operations (add_understanding, search_docs, get_doc_page,
find_block) are NOT counted
- [ ] Max 3 counter categories shown, sorted by volume
- [ ] Copy/TTS actions appear only on the final assistant message per
turn
- [ ] Copy/TTS aggregate text from all assistant messages in the turn
- [ ] No counters or actions shown while streaming is still in progress
- [ ] No type errors, lint errors, or format issues introduced

Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 19:32:48 +08:00
Zamil Majdy
476cf1c601 feat(copilot): support Claude Code subscription auth for SDK mode (#12288)
## Summary

- Adds `CHAT_USE_CLAUDE_CODE_SUBSCRIPTION` config flag to let the
copilot SDK path use the Claude CLI's own subscription auth (from
`claude login`) instead of API keys
- When enabled, the SDK subprocess inherits CLI credentials — no
`ANTHROPIC_BASE_URL`/`AUTH_TOKEN` override is injected
- Forces SDK mode regardless of LaunchDarkly flag (baseline path uses
`openai.AsyncOpenAI` which requires an API key)
- Validates CLI installation on first use with clear error messages

## Setup

```bash
npm install -g @anthropic-ai/claude-code
claude login
# then set in .env:
CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true
```

## Changes

| File | Change |
|------|--------|
| `copilot/config.py` | New `use_claude_code_subscription` field + env
var validator |
| `copilot/sdk/service.py` | `_validate_claude_code_subscription()` +
`_build_sdk_env()` early-return + fail-fast guard |
| `copilot/executor/processor.py` | Force SDK mode via short-circuit
`or` |

## Test plan

- [ ] Set `CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true`, unset all API keys
- [ ] Run `claude login` on the host
- [ ] Start backend, send a copilot message — verify SDK subprocess uses
CLI auth
- [ ] Verify existing OpenRouter/API key flows still work (no
regression)
2026-03-05 09:55:35 +00:00
Zamil Majdy
25022f2d1e fix(copilot): handle empty tool_call arguments in baseline path (#12289)
## Summary

Handle empty/None `tool_call.arguments` in the baseline copilot path
that cause OpenRouter 400 errors when converting to Anthropic format.

## Changes

**`backend/copilot/baseline/service.py`**:
- Default empty `tc["arguments"]` to `"{}"` to prevent OpenRouter from
failing on empty tool arguments during format conversion.

## Test plan

- [x] Existing baseline tests pass
- [ ] Verify on staging: trigger a tool call in baseline mode and
confirm normal flow works
2026-03-05 09:53:05 +00:00
Zamil Majdy
ce1675cfc7 feat(copilot): add Langfuse tracing to baseline LLM path (#12281)
## Summary
Depends on #12276 (baseline code).

- Swap shared OpenAI client to `langfuse.openai.AsyncOpenAI` —
auto-captures all LLM calls (token usage, latency, model, prompts) as
Langfuse generations when configured
- Add `propagate_attributes()` context in baseline streaming for
`user_id`/`session_id` attribution, matching the SDK path's OTEL tracing
- No-op when Langfuse is not configured — `langfuse.openai.AsyncOpenAI`
falls back to standard `openai.AsyncOpenAI` behavior

## Observability parity

| Aspect | SDK path | Baseline path (after this PR) |
|--------|----------|-------------------------------|
| LLM call tracing | OTEL via `configure_claude_agent_sdk()` |
`langfuse.openai.AsyncOpenAI` auto-instrumentation |
| User/session context | `propagate_attributes()` |
`propagate_attributes()` |
| Langfuse prompts | Shared `_build_system_prompt()` | Shared
`_build_system_prompt()` |
| Token/cost tracking | Via OTEL spans | Via Langfuse generation objects
|

## Test plan
- [x] `poetry run format` passes (pyright, ruff, black, isort)
- [ ] Verify Langfuse traces appear for baseline path with
`CHAT_USE_CLAUDE_AGENT_SDK=false`
- [ ] Verify SDK path tracing is unaffected
2026-03-05 09:51:16 +00:00
Otto
3d0ede9f34 feat(backend/copilot): attach uploaded images and PDFs as multimodal vision blocks (#12273)
Requested by @majdyz

When users upload images or PDFs to CoPilot, the AI couldn't see the
content because the CLI's Zod validator rejects large base64 in MCP tool
results and even small images were misidentified (the CLI silently drops
or corrupts image content blocks in tool results).

## Approach

Embed uploaded images directly as **vision content blocks** in the user
message via `client._transport.write()`. The SDK's `client.query()` only
accepts string content, so we bypass it for multimodal messages —
writing a properly structured user message with `[...image_blocks,
{"type": "text", "text": query}]` directly to the transport. This
ensures the CLI binary receives images as native vision blocks, matching
how the Anthropic API handles multimodal input.

For binary files accessed via workspace tools at runtime, we save them
to the SDK's ephemeral working directory (`sdk_cwd`) and return a file
path for the CLI's built-in `Read` tool to handle natively.

## Changes

### Vision content blocks for attached files — `service.py`
- `_prepare_file_attachments` downloads workspace files before the
query, converts images to base64 vision blocks (`{"type": "image",
"source": {"type": "base64", ...}}`)
- When vision blocks are present, writes multimodal user message
directly to `client._transport` instead of using `client.query()`
- Non-image files (PDFs, text) are saved to `sdk_cwd` with a hint to use
the Read tool

### File-path based access for workspace tools — `workspace_files.py`
- `read_workspace_file` saves binary files to `sdk_cwd` instead of
returning base64, returning a path for the Read tool

### SDK context for ephemeral directory — `tool_adapter.py`
- Added `sdk_cwd` context variable so workspace tools can access the
ephemeral directory
- Removed inline base64 multimodal block machinery
(`_extract_content_block`, `_strip_base64_from_text`, `_BLOCK_BUILDERS`,
etc.)

### Frontend — rendering improvements
- `MessageAttachments.tsx` — uses `OutputRenderers` system
(`globalRegistry` + `OutputItem`) for image/video preview rendering
instead of custom components
- `GenericTool.tsx` — uses `OutputRenderers` system for inline image
rendering of base64 content
- `routes.py` — returns 409 for duplicate workspace filenames

### Tests
- `tool_adapter_test.py` — removed multimodal extraction/stripping
tests, added `get_sdk_cwd` tests
- `service_test.py` — rewritten for `_prepare_file_attachments` with
file-on-disk assertions

Closes OPEN-3022

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-03-05 09:09:59 +00:00
Ubbe
5474f7c495 feat(frontend/copilot): add output action buttons (upvote, downvote) with Langfuse feedback (#12260)
## Summary
- Feedback is submitted to the backend Langfuse integration
(`/api/chat/sessions/{id}/feedback`) for observability
- Downvote opens a modal dialog for optional detailed feedback text (max
2000 chars)
- Buttons are hidden during streaming and appear on hover; once feedback
is selected they stay visible

## Changes
- **`AssistantMessageActions.tsx`** (new): Renders copy (CopySimple),
thumbs-up, and thumbs-down buttons using `MessageAction` from the design
system. Visual states for selected feedback (green for upvote, red for
downvote with filled icons).
- **`FeedbackModal.tsx`** (new): Dialog with a textarea for optional
downvote comment, using the design system `Dialog` component.
- **`useMessageFeedback.ts`** (new): Hook managing per-message feedback
state and backend submission via `POST
/api/chat/sessions/{id}/feedback`.
- **`ChatMessagesContainer.tsx`** (modified): Renders
`AssistantMessageActions` after `MessageContent` for assistant messages
when not streaming.
- **`ChatContainer.tsx`** (modified): Passes `sessionID` prop through to
`ChatMessagesContainer`.

## Test plan
- [ ] Verify action buttons appear on hover over assistant messages
- [ ] Verify buttons are hidden during active streaming
- [ ] Click copy button → text copied to clipboard, success toast shown
- [ ] Click upvote → green highlight, "Thank you" toast, button locked
- [ ] Click downvote → red highlight, feedback modal opens
- [ ] Submit feedback modal with/without comment → modal closes,
feedback sent
- [ ] Cancel feedback modal → modal closes, downvote stays locked
- [ ] Verify feedback POST reaches `/api/chat/sessions/{id}/feedback`

### Linear issue
Closes SECRT-2051

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 17:01:05 +08:00
Abhimanyu Yadav
f1b771b7ee feat(platform): switch builder file inputs from base64 to workspace uploads (#12226)
## Summary

Builder node file inputs were stored as base64 data URIs directly in
graph JSON, bloating saves and causing lag. This PR uploads files to the
existing workspace system and stores lightweight `workspace://`
references instead.

## What changed

- **Upload**: When a user picks a file in a builder node input, it gets
uploaded to workspace storage and the graph stores a small
`workspace://file-id#mime/type` URI instead of a huge base64 string.

- **Delete**: When a user clears a file input, the workspace file is
soft-deleted from storage so it doesn't leave orphaned files behind.

- **Execution**: Wired up `workspace_id` on `ExecutionContext` so blocks
can resolve `workspace://` URIs during graph runs. `store_media_file()`
already knew how to handle them.

- **Output rendering**: Added a renderer that displays `workspace://`
URIs as images, videos, audio players, or download cards in node output.

- **Proxy fix**: Removed a `Content-Type: text/plain` override on
multipart form responses that was breaking the generated hooks' response
parsing.

Existing graphs with base64 `data:` URIs continue to work — no migration
needed.

## Test plan

- [x] Upload file in builder → spinner shows, completes, file label
appears
- [x] Save/reload graph → `workspace://` URI persists, not base64
- [x] Clear file input → workspace file is deleted
- [x] Run graph → blocks resolve `workspace://` files correctly
- [x] Output renders images/video/audio from `workspace://` URIs
- [x] Old graphs with base64 still work

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-05 08:38:18 +00:00
Otto
aa7a2f0a48 hotfix(frontend/signup): Add missing createUser() call in password signup (#12287)
Requested by @0ubbe

Password signup was missing the backend `createUser()` call that the
OAuth callback flow already had. This caused `getOnboardingStatus()` to
fail/hang for new users whose backend record didn't exist yet, resulting
in an infinite spinner after account creation.

## Root Cause

| Flow | createUser() | getOnboardingStatus() | Result |
|------|-------------|----------------------|--------|
| OAuth signup |  Called |  Works | Redirects correctly |
| Password signup |  Missing |  Fails/hangs | Infinite spinner |

## Fix

Adds `createUser()` call in `signup/actions.ts` after session is set,
before onboarding status check — matching the OAuth callback pattern.
Includes error handling with Sentry reporting.

## Testing

- Create a new password account → should redirect without spinner
- OAuth signup unaffected (no changes to that flow)

Fixes OPEN-3023

---------

Co-authored-by: Lluis Agusti <hi@llu.lu>
2026-03-05 16:11:51 +08:00
Nicholas Tindle
3722d05b9b fix(frontend/builder): make Google Drive file inputs chainable (#12274)
Resolves: OPEN-3018

Google Drive picker fields on INPUT blocks were missing connection
handles, making them non-chainable in the new builder.

### Changes 🏗️

- **Render `TitleFieldTemplate` with `InputNodeHandle`** — uses
`getHandleId()` with `fieldPathId.$id` (which correctly resolves to e.g.
`agpt_%_spreadsheet`), fixing the previous `_@_` handle error caused by
using `idSchema.$id` (undefined for custom RJSF FieldProps)
- **Override `showHandles: !!nodeId`** in uiOptions — the INPUT block's
`generate-ui-schema.ts` sets `showHandles: false`, but Google Drive
fields need handles to be chainable
- **Hide picker content when handle is connected** — uses
`useEdgeStore.isInputConnected()` to detect wired connections and
conditionally hides the picker/placeholder UI

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Add a Google Drive file input block to a graph in the new builder
  - [x] Verify the connection handle appears on the input
  - [x] Connect another block's output to the Google Drive input handle
- [x] Verify the picker UI hides when connected and reappears when
disconnected
- [x] Verify the Google Drive picker still works normally on non-INPUT
block nodes

🤖 Generated with [Claude Code](https://claude.com/claude-code)

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Changes input-handle ID generation and conditional rendering for
Google Drive fields in the builder; regressions could break edge
connections or hide the picker unexpectedly on some nodes.
> 
> **Overview**
> Google Drive picker fields now render a proper RJSF
`TitleFieldTemplate` (and thus input handles) using a computed
`handleId` derived from `fieldPathId.$id`, and force `showHandles` on
when a `nodeId` is present.
> 
> The picker/placeholder UI is now conditionally hidden when
`useEdgeStore.isInputConnected()` reports the input handle is connected,
preventing duplicate input UI when the value comes from an upstream
node.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
1f1df53a38. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: abhi1992002 <abhimanyu1992002@gmail.com>
Co-authored-by: Abhimanyu Yadav <122007096+Abhi1992002@users.noreply.github.com>
2026-03-05 04:28:01 +00:00
Zamil Majdy
592830ce9b feat(copilot): persist large tool outputs to workspace with retrieval instructions (#12279)
## Summary
- Large tool outputs (>80K chars) are now persisted to session workspace
storage before truncation, preventing permanent data loss
- Truncated output includes a head preview (50K chars) with clear
retrieval instructions referencing `read_workspace_file` with
offset/length
- Added `offset` and `length` parameters to `ReadWorkspaceFileTool` for
paginated reads of large files without re-triggering truncation

## Problem
Tool outputs exceeding 100K chars were permanently lost — truncated by
`StreamToolOutputAvailable.model_post_init` using middle-out truncation.
The model had no way to retrieve the full output later, causing
recursive read loops where the agent repeatedly tries to re-read
truncated data.

## Solution
1. **`BaseTool.execute()`** — When output exceeds 80K chars, persist
full output to workspace at `tool-outputs/{tool_call_id}.json`, then
replace with a head preview wrapped in `<tool-output-truncated>` tags
containing retrieval instructions
2. **`ReadWorkspaceFileTool`** — New `offset`/`length` parameters enable
paginated reads so the agent can fetch slices without re-triggering
truncation
3. **Graceful fallback** — If workspace write fails, returns raw output
unchanged for existing truncation to handle

## Test plan
- [x] `base_test.py`: 5 tests covering persist+preview, fallback on
error, small output passthrough, large output persistence, anonymous
user skip
- [x] `workspace_files_test.py`: Ranged read test covering offset+length
slice, offset-only, offset beyond file length
- [ ] CI passes
- [ ] Review comments addressed
2026-03-05 00:50:01 +00:00
Zamil Majdy
6cc680f71c feat(copilot): improve SDK loading time (#12280)
## Summary
- Skip CLI version check at worker init (saves ~300ms/request)
- Pre-warm bundled CLI binary at startup to warm OS page caches (~500ms
saved on first request per worker)
- Parallelize E2B setup, system prompt fetch, and transcript download
with `asyncio.gather()` (saves ~200-500ms)
- Enable Langfuse prompt caching with configurable TTL (default 300s)

## Test plan
- [ ] `poetry run pytest backend/copilot/sdk/service_test.py -s -vvv`
- [ ] Manual: send copilot messages via SDK path, verify resume still
works on multi-turn
- [ ] Check executor logs for "CLI pre-warm done" messages
2026-03-05 00:49:14 +00:00
Otto
b342bfa3ba fix(frontend): revalidate layout after email/password login (#12285)
Requested by @ntindle

After logging in with email/password, the page navigates but renders a
blank/unauthenticated state (just logo + cookie banner). A manual page
refresh fixes it.

The `login` server action calls `signInWithPassword()` server-side but
doesn't call `revalidatePath()`, so Next.js serves cached RSC payloads
that don't reflect the new auth state. The OAuth callback route already
does this correctly.

**Fix:** Add `revalidatePath(next, "layout")` after successful login,
matching the OAuth callback pattern.

Closes SECRT-2059
2026-03-04 22:25:48 +00:00
Zamil Majdy
0215332386 feat(copilot): remove legacy copilot, add baseline non-SDK mode with tool calling (#12276)
## Summary
- Remove ~1200 lines of broken/unmaintained non-SDK copilot streaming
code (retry logic, parallel tool calls, context window management)
- Add `stream_chat_completion_baseline()` as a clean fallback LLM path
with full tool-calling support when `CHAT_USE_CLAUDE_AGENT_SDK=false`
(e.g. when Anthropic is down)
- Baseline reuses the same shared `TOOL_REGISTRY`,
`get_available_tools()`, and `execute_tool()` as the SDK path
- Move baseline code to dedicated `baseline/` folder (mirrors `sdk/`
structure)
- Clean up SDK service: remove unused params, fix model/env resolution,
fix stream error persistence
- Clean up config: remove `max_retries`, `thinking_enabled` fields
(non-SDK only)

## Changes
| File | Action |
|------|--------|
| `backend/copilot/baseline/__init__.py` | New — package export |
| `backend/copilot/baseline/service.py` | New — baseline streaming with
tool-call loop |
| `backend/copilot/baseline/service_test.py` | New — multi-turn keyword
recall test |
| `backend/copilot/service.py` | Remove ~1200 lines of legacy code, keep
shared helpers only |
| `backend/copilot/executor/processor.py` | Simplify branching to SDK vs
baseline |
| `backend/copilot/sdk/service.py` | Remove unused params, fix model/env
separation, fix stream error persistence |
| `backend/copilot/config.py` | Remove `max_retries`, `thinking_enabled`
|
| `backend/copilot/service_test.py` | Keep SDK test only (baseline test
moved) |
| `backend/copilot/parallel_tool_calls_test.py` | Deleted (tested
removed code) |

## Test plan
- [x] `poetry run format` passes
- [x] CI passes (all 3 Python versions, types, CodeQL)
- [ ] SDK path works unchanged in production
- [x] Baseline path (`CHAT_USE_CLAUDE_AGENT_SDK=false`) streams
responses with tool calling
- [x] Baseline emits correct Vercel AI SDK stream protocol events
2026-03-04 13:51:46 +00:00
Zamil Majdy
160d6eddfb feat(copilot): enable OpenRouter broadcast for SDK /messages endpoint (#12277)
## Summary

OpenRouter Broadcast silently drops traces for the Anthropic-native
`/api/v1/messages` endpoint unless an `x-session-id` HTTP header is
present. This was confirmed by systematic testing against our Langfuse
integration:

| Test | Endpoint | `x-session-id` header | Broadcast to Langfuse |
|------|----------|-----------------------|----------------------|
| 1 | `/chat/completions` | N/A (body fields work) |  |
| 2 | `/messages` (body fields only) |  |  |
| 3 | `/messages` (header + body) |  |  |
| 4 | `/messages` (`metadata.user_id` only) |  |  |
| 5 | `/messages` (header only) |  |  |

**Root cause:** OpenRouter only triggers broadcast for the `/messages`
endpoint when the `x-session-id` HTTP header is present — body-level
`session_id` and `metadata.user_id` are insufficient.

### Changes
- **SDK path:** Inject `x-session-id` and `x-user-id` via
`ANTHROPIC_CUSTOM_HEADERS` env var in `_build_sdk_env()`, which the
Claude Agent SDK CLI reads and attaches to every outgoing API request
- **Non-SDK path:** Add `trace` object (`trace_name` + `environment`) to
`extra_body` for richer broadcast metadata in Langfuse

This creates complementary traces alongside the existing OTEL
integration: broadcast provides cost/usage data from OpenRouter while
OTEL provides full tool-call observability with `userId`, `sessionId`,
`environment`, and `tags`.

## Test plan
- [x] Verified via test script: `/messages` with `x-session-id` header →
trace appears in Langfuse with correct `sessionId`
- [x] Verified `/chat/completions` with `trace` object → trace appears
with custom `trace_name`
- [x] Pre-commit hooks pass (ruff, black, isort, pyright)
- [ ] Deploy to dev and verify broadcast traces appear for real copilot
SDK sessions
2026-03-04 09:07:48 +00:00
Ubbe
a5db9c05d0 feat(frontend/copilot): add text-to-speech and share output actions (#12256)
## Summary
- Add text-to-speech action button to CoPilot assistant messages using
the browser Web Speech API
- Add share action button that uses the Web Share API with clipboard
fallback
- Replace inline SVG copy icon with Phosphor CopyIcon for consistency

## Linked Issue
SECRT-2052

## Test plan
- [ ] Verify copy button still works
- [ ] Click speaker icon and verify TTS reads aloud
- [ ] Click stop while playing and verify speech stops
- [ ] Click share icon and verify native share or clipboard fallback

Note: This PR should be merged after SECRT-2051 PR

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-04 16:08:54 +08:00
Otto
b74d41d50c fix(backend): handle UniqueViolationError in workspace file retry path (#12267)
Requested by @majdyz

When two concurrent requests write to the same workspace file path with
`overwrite=True`, the retry after deleting the conflicting file could
also hit a `UniqueViolationError`. This raw Prisma exception was
bubbling up unhandled to Sentry as a high-priority alert
(AUTOGPT-SERVER-7ZA).

Now the retry path catches `UniqueViolationError` specifically and
converts it to a `ValueError` with a clear message, matching the
existing pattern for the non-overwrite path.

**Change:** `autogpt_platform/backend/backend/util/workspace.py` — added
a specific `UniqueViolationError` catch before the generic `Exception`
catch in the retry block.

**Risk:** Minimal — only affects the already-failing retry path. No
behavior change for success paths.

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-03-04 07:04:50 +00:00
Otto
a897f9e124 feat(copilot): render context compaction as tool-call UI events (#12250)
Requested by @majdyz

When CoPilot compacts (summarizes/truncates) conversation history to fit
within context limits, the user now sees it rendered like a tool call —
a spinner while compaction runs, then a completion notice.

**Backend:**
- Added `compaction_start_events()`, `compaction_end_events()`,
`compaction_events()` in `response_model.py` using the existing
tool-call SSE protocol (`tool-input-start` → `tool-input-available` →
`tool-output-available`)
- All three compaction paths (legacy `service.py`, SDK pre-query, SDK
mid-stream) use the same pattern
- Pre-query and SDK-internal compaction tracked independently so neither
suppresses the other

**Frontend:**
- Added `compaction` tool category to `GenericTool` with
`ArrowsClockwise` icon
- Shows "Summarizing earlier messages…" with spinner while running
- Shows "Earlier messages were summarized" when done
- No expandable accordion — just the status line

**Cleanup:**
- Removed unused `system_notice_start/end_events`,
`COMPACTION_STARTED_MSG`
- Removed unused `system_notice_events`, `system_error_events`,
`_system_text_events`

Closes SECRT-2053

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-03-04 05:31:32 +00:00
Zamil Majdy
7fd26d3554 feat(copilot): run_mcp_tool — MCP server discovery and execution in Otto (#12213)
## Summary

Enables Otto (the AutoGPT copilot) to connect to any MCP (Model Context
Protocol) server, discover its tools, and execute them — with the same
credential login UI used in the graph builder.

**Why a dedicated `run_mcp_tool` instead of reusing `run_block` +
MCPToolBlock?**
Two blockers make `run_block` unworkable for MCP:
1. **No discovery mode** — `MCPToolBlock` errors with "No tool selected"
when `selected_tool` is empty; the agent can't learn what tools exist
before picking one.
2. **Credential matching bug** — `find_matching_credential()` (the block
execution path) does NOT check MCP server URLs; it would match any
stored MCP OAuth credential regardless of server. The correct
`_credential_is_for_mcp_server()` helper only applies in the graph path.

## Changes

### Backend
- **New `run_mcp_tool` copilot tool** (`run_mcp_tool.py`) — two-stage
flow:
1. `run_mcp_tool(server_url)` → discovers available tools via
`MCPClient.list_tools()`
2. `run_mcp_tool(server_url, tool_name, tool_arguments)` → executes via
`MCPClient.call_tool()`
- Lazy auth: fast DB credential lookup first
(`MCPToolBlock._auto_lookup_credential`); on HTTP 401/403 with no stored
creds, returns `SetupRequirementsResponse` so the frontend renders the
existing CredentialsGroupedView OAuth login card
- **New response models** in `models.py`: `MCPToolsDiscoveredResponse`,
`MCPToolOutputResponse`, `MCPToolInfo`
- **Exclude MCPToolBlock** from `find_block` / `run_block`
(`COPILOT_EXCLUDED_BLOCK_TYPES`)
- **System prompt update** — MCP section with two-step flow,
`input_schema` guidance, auth-wait instruction, and registry URL
(`registry.modelcontextprotocol.io`)

### Frontend
- **`RunMCPToolComponent`** — routes between credential prompt (reuses
`SetupRequirementsCard` from RunBlock) and result card; discovery step
shows only a minimal in-progress animation (agent-internal, not
user-facing)
- **`MCPToolOutputCard`** — renders tool result as formatted JSON or
plain text
- **`helpers.tsx`** — type guards (`isMCPToolOutput`,
`isSetupRequirementsOutput`, `isErrorOutput`), output parsing, animation
text
- Registered `tool-run_mcp_tool` case in `ChatMessagesContainer`

## Test plan

- [ ] Call `run_mcp_tool(server_url)` with a public MCP server → see
discovery animation, agent gets tool list
- [ ] Call `run_mcp_tool(server_url, tool_name, tool_arguments)` → see
`MCPToolOutputCard` with result
- [ ] Call with an auth-required server and no stored creds →
`SetupRequirementsCard` renders with MCP OAuth button
- [ ] After connecting credentials, retry → executes successfully
- [ ] `find_block("MCP")` returns no results (MCPToolBlock excluded)
- [ ] Backend unit tests: mock `MCPClient` for discovery + execution +
auth error paths

---------

Co-authored-by: Otto (AGPT) <otto@agpt.co>
2026-03-04 05:30:38 +00:00
Zamil Majdy
b504cf9854 feat(copilot): Add agent-browser multi-step browser automation tools (#12230)
## Summary

Adds three new Copilot tools for multi-step browser automation using the
[agent-browser](https://github.com/vercel-labs/agent-browser) CLI
(Playwright-based local daemon):

- **`browser_navigate`** — navigate to a URL and get an
accessibility-tree snapshot with `@ref` IDs
- **`browser_act`** — interact with page elements (click, fill, scroll,
check, press, select, `dblclick`, `type`, `wait`, back, forward,
reload); returns updated snapshot
- **`browser_screenshot`** — capture annotated screenshot (with `@ref`
overlays) and save to user workspace

Also adds **`browse_web`** (Stagehand + Browserbase) for one-shot
JS-rendered page extraction.

### Why two browser tools?

| Tool | When to use |
|------|-------------|
| `browse_web` | Single-shot extraction — cloud Browserbase session, no
local daemon needed |
| `browser_navigate` / `browser_act` | Multi-step flows (login →
navigate → scrape), persistent session within a Copilot session |

### Design decisions

- **SSRF protection**: Uses the same `validate_url()` from
`backend.util.request` as HTTP blocks — async DNS, all IPs checked, full
RFC 1918 + link-local + IPv6 coverage
- **Session isolation**: `_run()` passes both `--session <id>` (isolated
Chromium context per Copilot session) **and** `--session-name <id>`
(persist cookies within a session), preventing cross-session state
leakage while supporting login flows
- **Snapshot truncation**: Interactive-only accessibility tree
(`snapshot -i`) capped at 20 000 chars with a continuation hint
- **Screenshot storage**: PNG bytes uploaded to user workspace via
`WriteWorkspaceFileTool`; returns `file_id` for retrieval

### Bugs fixed in this PR

- Session isolation bug: `--session-name` alone shared browser history
across different Copilot sessions; added `--session` to isolate contexts
- Missing actions: added `dblclick`, `type` (append without clearing),
`wait` (CSS selector or ms delay)

## Test plan

- [x] 53 unit tests covering all three tools, all actions, SSRF
integration, auth check, session isolation, snapshot truncation,
timeout, missing binary
- [x] Integration test: real `agent-browser` CLI + Anthropic API
tool-calling loop (3/3 scenarios passed)
- [x] Linting (Ruff, isort, Black, Pyright) all passing

```
backend/copilot/tools/agent_browser_test.py  53 passed in 17.79s
```
2026-03-03 21:55:28 +00:00
Zamil Majdy
29da8db48e feat(copilot): E2B cloud sandbox — unified file tools, persistent execution, output truncation (#12212)
## Summary

- **E2B file tools**: New MCP tools
(`read_file`/`write_file`/`edit_file`/`glob`/`grep`) that operate
directly on the E2B sandbox filesystem (`/home/user`). When E2B is
active, these replace SDK built-in `Read/Write/Edit/Glob/Grep` so all
tools share a single coherent filesystem with `bash_exec` — no sync
needed.
- **E2B sandbox lifecycle**: New `e2b_sandbox.py` manages sandbox
creation and reconnection via Redis, with stale-key cleanup on
reconnection failure.
- **E2B enabled by default**: `use_e2b_sandbox` defaults to `True`; set
`CHAT_USE_E2B_SANDBOX=false` to disable.
- **Centralized output truncation**: All MCP tool outputs are truncated
via `_truncating` wrapper and stashed (`_pending_tool_outputs`) to
bypass SDK's head-truncation for the frontend.
- **Frontend tool display**: `GenericTool.tsx` now renders bash
stdout/stderr, file content, edit diffs (old/new), todo lists, and
glob/grep results with category-specific icons and status text.
- **Workspace file tools + E2B**: `read_workspace_file`'s `save_to_path`
and `write_workspace_file`'s `source_path` route to E2B sandbox when
active.

## Files changed

| Area | Files | What |
|------|-------|------|
| E2B file tools | `sdk/e2b_file_tools.py`, `sdk/e2b_file_tools_test.py`
| MCP file tool handlers + tests |
| E2B sandbox | `tools/e2b_sandbox.py` | Sandbox lifecycle
(create/reconnect/Redis) |
| Tool adapter | `sdk/tool_adapter.py` | MCP server, truncation, stash,
path validation |
| Service | `sdk/service.py` | E2B integration, prompt supplements |
| Security | `sdk/security_hooks.py`, `sdk/security_hooks_test.py` |
Path validation for E2B mode |
| Bash exec | `tools/bash_exec.py` | E2B execution path |
| Workspace files | `tools/workspace_files.py`,
`tools/workspace_files_test.py` | E2B-aware save/source paths |
| Config | `copilot/config.py` | E2B config fields (default on) |
| Truncation | `util/truncate.py` | Middle-out truncation fix |
| Frontend | `GenericTool.tsx` | Tool-specific display rendering |

## Test plan

- [x] `security_hooks_test.py` — 43 tests (path validation, tool access,
deny messages)
- [x] `e2b_file_tools_test.py` — 19 tests (path resolution, local read
safety)
- [x] `workspace_files_test.py` — 17 tests (ephemeral path validation)
- [x] CI green (backend 3.11/3.12/3.13, lint, types, e2e)
2026-03-03 21:31:38 +00:00
Nicholas Tindle
757ec1f064 feat(platform): Add file upload to copilot chat [SECRT-1788] (#12220)
## Summary

- Add file attachment support to copilot chat (documents, images,
spreadsheets, video, audio)
- Show upload progress with spinner overlays on file chips during upload
- Display attached files as styled pills in sent user messages using AI
SDK's native `FileUIPart`
- Backend upload endpoint with virus scanning (ClamAV), per-file size
limits, and per-user storage caps
- Enrich chat stream with file metadata so the LLM can access files via
`read_workspace_file`

Resolves: [SECRT-1788](https://linear.app/autogpt/issue/SECRT-1788)

### Backend
| File | Change |
|------|--------|
| `chat/routes.py` | Accept `file_ids` in stream request, enrich user
message with file metadata |
| `workspace/routes.py` | New `POST /files/upload` and `GET
/storage/usage` endpoints |
| `executor/utils.py` | Thread `file_ids` through
`CoPilotExecutionEntry` and RabbitMQ |
| `settings.py` | Add `max_file_size_mb` and `max_workspace_storage_mb`
config |

### Frontend
| File | Change |
|------|--------|
| `AttachmentMenu.tsx` | **New** — `+` button with popover for file
category selection |
| `FileChips.tsx` | **New** — file preview chips with upload spinner
state |
| `MessageAttachments.tsx` | **New** — paperclip pills rendering
`FileUIPart` in chat bubbles |
| `upload/route.ts` | **New** — Next.js API proxy for multipart uploads
to backend |
| `ChatInput.tsx` | Integrate attachment menu, file chips, upload
progress |
| `useCopilotPage.ts` | Upload flow, `FileUIPart` construction,
transport `file_ids` extraction |
| `ChatMessagesContainer.tsx` | Render file parts as
`MessageAttachments` |
| `ChatContainer.tsx` / `EmptySession.tsx` | Thread `isUploadingFiles`
prop |
| `useChatInput.ts` | `canSendEmpty` option for file-only sends |
| `stream/route.ts` | Forward `file_ids` to backend |

## Test plan

- [x] Attach files via `+` button → file chips appear with X buttons
- [x] Remove a chip → file is removed from the list
- [x] Send message with files → chips show upload spinners → message
appears with file attachment pills
- [x] Upload failure → toast error, chips revert to editable (no phantom
message sent)
- [x] New session (empty form): same upload flow works
- [x] Messages without files render normally
- [x] Network tab: `file_ids` present in stream POST body

🤖 Generated with [Claude Code](https://claude.com/claude-code)

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Adds authenticated file upload/storage-quota enforcement and threads
`file_ids` through the chat streaming path, which affects data handling
and storage behavior. Risk is mitigated by UUID/workspace scoping, size
limits, and virus scanning but still touches security- and
reliability-sensitive upload flows.
> 
> **Overview**
> Copilot chat now supports attaching files: the frontend adds
drag-and-drop and an attach button, shows selected files as removable
chips with an upload-in-progress state, and renders sent attachments
using AI SDK `FileUIPart` with download links.
> 
> On send, files are uploaded to the backend (with client-side limits
and failure handling) and the chat stream request includes `file_ids`;
the backend sanitizes/filters IDs, scopes them to the user’s workspace,
appends an `[Attached files]` metadata block to the user message for the
LLM, and forwards the sanitized IDs through `enqueue_copilot_turn`.
> 
> The backend adds `POST /workspace/files/upload` (filename
sanitization, per-file size limit, ClamAV scan, and per-user storage
quota with post-write rollback) plus `GET /workspace/storage/usage`,
introduces `max_workspace_storage_mb` config, optimizes workspace size
calculation, and fixes executor cleanup to avoid un-awaited coroutine
warnings; new route tests cover file ID validation and upload quota/scan
behaviors.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
8d3b95d046. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-03 20:23:27 +00:00
Ubbe
9442c648a4 fix(platform/copilot): bypass Vercel SSE proxy, refactor hook architecture (#12254)
## Summary

Reliability, architecture, and UX improvements for the CoPilot SSE
streaming pipeline.

### Frontend

- **SSE proxy bypass**: Connect directly to the Python backend for SSE
streams, avoiding the Next.js serverless proxy and its 800s Vercel
function timeout ceiling
- **Hook refactor**: Decompose the 490-line `useCopilotPage` monolith
into focused domain modules:
- `helpers.ts` — pure functions (`deduplicateMessages`,
`resolveInProgressTools`)
- `store.ts` — Zustand store for shared UI state (`sessionToDelete`,
drawer open/close)
- `useCopilotStream.ts` — SSE transport, `useChat` wrapper,
reconnect/resume logic, stop+cancel
  - `useCopilotPage.ts` — thin orchestrator (~160 lines)
- **ChatMessagesContainer refactor**: Split 525-line monolith into
sub-components:
  - `helpers.ts` — pure text parsing (markers, workspace URLs)
- `components/ThinkingIndicator.tsx` — ScaleLoader animation + cycling
phrases with pulse
- `components/MessagePartRenderer.tsx` — tool dispatch switch +
workspace media
- **Stop UX fixes**:
- Guard `isReconnecting` and resume effect with `isUserStoppingRef` so
the input unlocks immediately after explicit stop (previously stuck
until page refresh)
- Inject cancellation marker locally in `stop()` so "You manually
stopped this chat" shows instantly
- **Thinking indicator polish**: Replace MorphingBlob SVG with
ScaleLoader (16px), fix initial dark circle flash via
`animation-fill-mode: backwards`, smooth `animate-pulse` text instead of
shimmer gradient
- **ChatSidebar consolidation**: Reads `sessionToDelete` from Zustand
store instead of duplicating delete state/mutation locally
- **Auth error handling**: `getAuthHeaders()` throws on failure instead
of silently returning empty headers; 401 errors show user-facing toast
- **Stale closure fix**: Use refs for reconnect guards to avoid stale
closures during rapid reconnect cycles
- **Session switch resume**: Clear `hasResumedRef` on session switch so
returning to a session with an active stream auto-reconnects
- **Target session cache invalidation**: Invalidate the target session's
React Query cache on switch so `active_stream` is accurate for resume
- **Dedup hardening**: Content-fingerprint dedup resets on non-assistant
messages, preventing legitimate repeated responses from being dropped
- **Marker prefixes**: Hex-suffixed markers (`[__COPILOT_ERROR_f7a1__]`)
to prevent LLM false-positives
- **Code style**: Remove unnecessary `useCallback` wrappers per project
convention, replace unsafe `as` cast with runtime type guard

### Backend (minimal)

- **Faster heartbeat**: 10s → 3s interval to keep SSE alive through
proxies/LBs
- **Faster stall detection**: SSE subscriber queue timeout 30s → 10s
- **Marker prefixes**: Matching hex-suffixed prefixes for error/system
markers

## Test plan

- [ ] Verify SSE streams connect directly to backend (no Next.js proxy
in network tab)
- [ ] Verify reconnect works on transient disconnects (up to 3 attempts
with backoff)
- [ ] Verify auth failure shows user-facing toast
- [ ] Verify switching sessions and switching back shows messages and
resumes active stream
- [ ] Verify deleting a chat from sidebar works (shared Zustand state)
- [ ] Verify mobile drawer delete works (shared Zustand state)
- [ ] Verify thinking indicator shows ScaleLoader + pulsing text, no
dark circle flash
- [ ] Verify stopping a stream immediately unlocks the input and shows
"You manually stopped this chat"
- [ ] Verify marker prefix parsing still works with hex-suffixed
prefixes
- [ ] `pnpm format && pnpm lint && pnpm types` pass

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-03 19:56:24 +08:00
Zamil Majdy
1c51dd18aa fix(test): backdate UserBalance.updatedAt in test_block_credit_reset (#12236)
## Root cause

The test constructs \`month3\` using \`datetime.now().replace(month=3,
day=1)\` — hardcoded to **March of the real current year**. When
\`update(balance=400)\` runs, Prisma auto-sets \`updatedAt\` to the
**real wall-clock time**.

The refill guard in \`BetaUserCredit.get_credits\` is:
\`\`\`python
if (snapshot_time.year, snapshot_time.month) == (cur_time.year,
cur_time.month):
    return balance  # same month → skip refill
\`\`\`

This means the test only fails when run **during the real month of
March**, because the mocked \`month3\` and the real \`updatedAt\` both
land in March:

| Test runs in | \`snapshot_time\` (real \`updatedAt\`) | \`cur_time\`
(mocked month3) | Same? | Result |
|---|---|---|---|---|
| January 2026 | \`(2026, 1)\` | \`(2026, 3)\` |  | refill triggers  |
| February 2026 | \`(2026, 2)\` | \`(2026, 3)\` |  | refill triggers 
|
| **March 2026** | **\`(2026, 3)\`** | **\`(2026, 3)\`** | **** |
**skips refill ** |
| April 2026 | \`(2026, 4)\` | \`(2026, 3)\` |  | refill triggers  |

It would silently pass again in April, then fail again next March 2027.

## Fix

Explicitly pass \`updatedAt=month2\` when updating the balance to 400,
so the month2→month3 transition is correctly detected regardless of when
the test actually runs. This matches the existing pattern used earlier
in the same test for the month1 setup.

## Test plan
- [ ] \`pytest backend/data/credit_test.py::test_block_credit_reset\`
passes
- [ ] No other credit tests broken
2026-03-01 07:46:04 +00:00
Zamil Majdy
6f4f80871d feat(copilot): Langfuse SDK tracing for Claude Agent SDK path (#12228)
## Problem

The Copilot SDK path (`ClaudeSDKClient`) routes API calls through `POST
/api/v1/messages` (Anthropic-native endpoint). OpenRouter Broadcast
**silently excludes** this endpoint — it only forwards `POST
/api/v1/chat/completions` (OpenAI-compat) to Langfuse. As a result, all
SDK-path turns were invisible in Langfuse.

**Root cause confirmed** via live pod test: two HTTP calls (one per
endpoint), only the `/chat/completions` one appeared in Langfuse.

## Solution

Add **Langfuse SDK direct tracing** in `sdk/service.py`, wrapping each
`stream_chat_completion_sdk()` call with a `generation` observation.

### What gets captured per user turn
| Field | Value |
|---|---|
| `name` | `copilot-sdk-session` |
| `model` | resolved SDK model |
| `input` | user message |
| `output` | final accumulated assistant text |
| `usage_details.input` | aggregated input tokens (from
`ResultMessage.usage`) |
| `usage_details.output` | aggregated output tokens |
| `cost_details.total` | total cost USD |
| trace `session_id` | copilot session ID |
| trace `user_id` | authenticated user ID |
| trace `tags` | `["sdk"]` |

Token counts and cost are **aggregated** across all internal Anthropic
API calls in the session (tool-use turns included), sourced from
`ResultMessage.usage`.

### Implementation notes
- Span is opened via
`start_as_current_observation(as_type='generation')` before
`ClaudeSDKClient` enters
- Span is **always closed in `finally`** — survives errors,
cancellations, and user stops
- Fails open: any Langfuse init error is caught and logged at `DEBUG`,
tracing is disabled for that turn but the session continues normally
- Only runs when `_is_langfuse_configured()` returns true (same guard as
the non-SDK path)

## Also included

`reproduce_openrouter_broadcast_gap.py` — standalone repro script (no
sensitive data) demonstrating that `/api/v1/messages` is not captured by
OpenRouter Broadcast while `/api/v1/chat/completions` is. To be filed
with OpenRouter support.

## Test plan

- [ ] Deploy to dev, send a Copilot message via the SDK path
- [ ] Confirm trace appears in Langfuse with `tags=["sdk"]`, correct
`session_id`/`user_id`, non-zero token counts
- [ ] Confirm session still works normally when `LANGFUSE_PUBLIC_KEY` is
not set (no-op path)
- [ ] Confirm session still works on error/cancellation (span closed in
finally)
2026-02-27 16:26:46 +00:00
Ubbe
e8cca6cd9a feat(frontend/copilot): migrate ChatInput to ai-sdk prompt-input component (#12207)
## Summary

- **Migrate ChatInput** to composable `PromptInput*` sub-components from
AI SDK Elements, replacing the custom implementation with a boxy,
Claude-style input layout (textarea + footer with tools and submit)
- **Eliminate JS-based DOM height manipulation** (60+ lines removed from
`useChatInput.ts`) in favor of CSS-driven auto-resize via
`min-h`/`max-h`, fixing input sizing jumps (SECRT-2040)
- **Change stop button color** from red to black (`bg-zinc-800`) per
SECRT-2038, while keeping mic recording button red
- **Add new UI primitives**: `InputGroup`, `Spinner`, `Textarea`, and
`prompt-input` composition layer

### New files
- `src/components/ai-elements/prompt-input.tsx` — Composable prompt
input sub-components (PromptInput, PromptInputBody, PromptInputTextarea,
PromptInputFooter, PromptInputTools, PromptInputButton,
PromptInputSubmit)
- `src/components/ui/input-group.tsx` — Layout primitive with flex-col
support, rounded-xlarge styling
- `src/components/ui/spinner.tsx` — Loading spinner using Phosphor
CircleNotch
- `src/components/ui/textarea.tsx` — Base shadcn Textarea component

### Modified files
- `ChatInput.tsx` — Rewritten to compose PromptInput* sub-components
with InputGroup
- `useChatInput.ts` — Simplified: removed maxRows, hasMultipleLines,
handleKeyDown, all DOM style manipulation
- `useVoiceRecording.ts` — Removed `baseHandleKeyDown` dependency;
PromptInputTextarea handles Enter→submit natively

## Resolves
- SECRT-2042: Migrate copilot chat input to ai-sdk prompt-input
component
- SECRT-2038: Change stop button color from red to black

## Test plan
- [ ] Type a message and send it — verify it submits and clears the
input
- [ ] Multi-line input grows smoothly without sizing jumps
- [ ] Press Enter to send, Shift+Enter for new line
- [ ] Voice recording: press space on empty input to start, space again
to stop
- [ ] Mic button stays red while recording; stop-generating button is
black
- [ ] Input has boxy rectangular shape with rounded-xlarge corners
- [ ] Streaming: stop button appears during generation, clicking it
stops the stream
- [ ] EmptySession layout renders correctly with the new input
- [ ] Input is disabled during transcription with "Transcribing..."
placeholder

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-27 15:24:19 +00:00
Abhimanyu Yadav
bf6308e87c fix(frontend): truncate node output and fix dialog overflow (#12222)
## Summary
- **Truncate node output on canvas**: Input content now uses
`shortContent={true}` so it renders truncated text instead of full
content. Output items are capped to 3 per pin with `.slice(0, 3)`.
- **Increase truncation limit**: `TextRenderer` truncation raised from
100 to 200 characters for better readability.
- **Fix dialog content overflow**: Removed legacy `ScrollArea` from the
Full Preview dialog (`NodeDataViewer`) — it was preventing proper width
constraint, causing JSON/code content to overflow beyond the dialog
boundary. Replaced with a simple flex container that respects the
dialog's width.
- **Reposition action buttons**: Copy/download buttons moved from
right-side/absolute positioning to below the content, aligned left, for
better layout with horizontally-scrollable content.
- **Add overflow protection to ContentRenderer**: Added
`overflow-hidden` and `pre` word-wrap rules to prevent content from
breaking out of the node card on the canvas.

## Test plan
- [x] Open a node with long JSON output on the builder canvas — verify
content is truncated
- [x] Click the expand button to open "Full Output Preview" dialog —
verify content stays within dialog bounds and scrolls horizontally if
needed
- [x] Verify copy/download buttons appear below the content,
left-aligned
- [x] Check that input data also shows truncated on the canvas node
- [x] Verify output items are capped at 3 per pin on the canvas

🤖 Generated with [Claude Code](https://claude.com/claude-code)
2026-02-27 10:53:16 +00:00
Otto
4e59143d16 Add plans/ to .gitignore (#12229)
Requested by @torantula

Adds `plans/` to `.gitignore` and removes one existing tracked plan
file.
2026-02-27 10:12:46 +00:00
Reinier van der Leer
d5efb6915b dx: Sync types & dependencies on pre-commit and post-checkout (#12211)
Our pre-commit hooks can use an update: the type check often fails based
on stale type definitions, the OpenAPI schema isn't synced/checked, and
the pre-push checks aren't installed by default.

### Changes 🏗️

- Regenerate Prisma `.pyi` type stub in on `prisma generate` hook:
Pyright prefers `.pyi` over `.py`, so a stale stub shadows the
regenerated `types.py`
- Also run setup hooks (dependency install, `prisma generate`, `pnpm
generate:api`) on `post-checkout`, to keep types and packages in sync
after switching branches
- Switch these hooks to `git diff` checks because `post-checkout`
doesn't support file triggers/filters
- Add `Check & Install dependencies - AutoGPT Platform - Frontend` hook
- Add `Sync API types - AutoGPT Platform - Backend -> Frontend` hook
- Fix non-ASCII issue in `export-api-schema` (`ensure_ascii=False`)
- Exclude `pnpm-lock.yaml` from `detect-secrets` hook (integrity hashes
cause ~1800 false positives)
- Add `default_stages: [pre-commit]`
- Add `post-checkout`, `pre-push` to `default_install_hook_types`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Tested locally
2026-02-26 22:28:59 +01:00
Nicholas Tindle
b9aac42056 Merge branch 'master' into dev 2026-02-26 13:39:34 -06:00
Otto
95651d33da feat(backend): add fpdf2 dependency for PDF operations in copilot executor (#12216)
Requested by @majdyz

Adds [fpdf2](https://github.com/py-pdf/fpdf2) (v2.8.6) to backend
dependencies to enable PDF generation and manipulation in the copilot
executor.

fpdf2 is a lightweight PDF generation library (no external dependencies,
pure Python) that allows creating PDFs with text, images, tables, and
more.
2026-02-26 18:21:34 +00:00
Zamil Majdy
b30418d833 fix(copilot): inject working directory into SDK prompt + workspace download links (#12215)
## Summary

- Replaces the static `_SDK_TOOL_SUPPLEMENT` placeholder path with
`_build_sdk_tool_supplement(cwd: str)` that injects the session-specific
working directory
- `sdk_cwd` is computed once via `_make_sdk_cwd(session_id)`,
`os.makedirs` is called after lock acquisition (inside the protected
`try/finally`), and the same variable is used everywhere — no drift
between prompt and execution directory
- Added `ValueError`/`OSError` error handling for cwd preparation with
proper `StreamError` emission
- Teaches the SDK agent how to share workspace files with the user via
`workspace://` Markdown links (images render inline, videos render with
player controls, other files as download links)
- `WorkspaceWriteResponse` now includes `download_url` (pre-formatted
`workspace://file_id#mime` string) and a normalised `mime_type` field
(MIME parameters stripped, lowercased)
- Frontend: workspace `workspace://` regular links now resolve to
absolute URLs so Streamdown's "Copy link" copies the full URL
- Frontend: Streamdown's "Open link" button colour overridden to match
the design system (violet accent) — previously near-invisible in dark
mode due to `--primary` resolving to near-white

## Motivation

The SDK agent was seeing a hardcoded placeholder path in the system
prompt instead of the real working directory, causing it to reference
wrong paths in tool calls. Additionally, there was no guidance for the
agent on how to share files it writes to the workspace with the user in
chat.

## Test plan

- [ ] CI green (test 3.11 / 3.12 / 3.13)
- [ ] Start a copilot session with `CHAT_USE_CLAUDE_AGENT_SDK=true` and
verify the agent references the correct `sdk_cwd` path in its tool calls
- [ ] Ask the agent to write a file and confirm it responds with a
clickable download link / inline image using the `workspace://` syntax
- [ ] Verify the "Open link" button in the Streamdown external-link
dialog is visible in both light and dark mode
- [ ] Click "Copy link" on a workspace file link and confirm it copies
the full URL (including host)
2026-02-26 17:26:19 +00:00
Otto
ed729ddbe2 feat(copilot): Wait for agent execution completion (#12147)
Adds the ability for CoPilot to wait for agent execution to complete
before returning results.

Closes SECRT-2003.

## Changes

### New: `execution_utils.py`
- `wait_for_execution()` — uses Redis pubsub to wait for execution to
reach terminal state
- `TERMINAL_STATUSES` — shared frozenset of completed/failed/terminated
- `PAUSED_STATUSES` — handles REVIEW (human-in-the-loop) as a
stop-waiting state
- `get_execution_outputs()` — helper to extract outputs

### `run_agent.py`
- New `wait_for_result` parameter (0-300 seconds)
- When >0, waits for execution to complete and returns outputs directly
- Handles completed, failed, terminated, review, and timeout states with
appropriate responses

### `agent_output.py` (view_agent_output)
- New `wait_if_running` parameter (0-300 seconds)
- Includes running/queued/review executions when waiting is requested
- Status-aware response messages (completed, failed, running, review,
etc.)

## How it works
1. After starting execution, subscribes to Redis pubsub channel for
execution events
2. Re-checks DB after subscribing to close the race window
3. `asyncio.wait_for` enforces the timeout
4. On completion: returns full outputs via `AgentOutputResponse`
5. On timeout: returns current state with guidance to check again later
6. On error/terminated: returns `ErrorResponse` with details
7. Redis connection always cleaned up via `finally` block

## Testing

- [x] Run an agent with `wait_for_result=0` — should return immediately
with execution ID (existing behavior)
- [x] Run a fast agent with `wait_for_result=60` — should return
completed outputs
- [x] Run a slow agent with `wait_for_result=5` — should timeout and
return current status
- [x] Use `view_agent_output` with `wait_if_running=0` on a completed
execution — should return outputs
- [x] Use `view_agent_output` with `wait_if_running=30` on a running
execution — should wait and return
- [ ] ~~Verify Redis connections are cleaned up (no leaked pubsub
connections after timeout)~~
- [ ] ~~Test with a failed execution — should return error response~~
- [ ] ~~Test with a terminated execution — should return error response
(not "still running")~~

## Collaboration

This PR was developed in collaboration with @Pwuts.

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-02-26 16:41:33 +00:00
Otto
8c7030af0b fix(copilot): handle 'all' keyword in find_library_agent tool (#12138)
When users ask CoPilot to "show all my agents" or similar, the LLM was
passing the literal string "all" as a search query to
`find_library_agent`, which matched no agents because there's no agent
named "all". (issue:
[SECRT-2002](https://linear.app/autogpt/issue/SECRT-2002))

## Changes

- **Make `query` parameter optional** in `FindLibraryAgentTool` - users
can now omit it to list all agents
- **Add special keyword handling** - keywords like "all", "*",
"everything", "any", or empty string are treated as "list all" rather
than literal searches
- **Update response messages** - differentiate between "listing all
agents" vs "searching for X"

## Example

Before:
```
User: Show me all my agents
CoPilot: find_library_agent(query="all")
Result: No agents matching 'all' found in your library
```

After:
```
User: Show me all my agents  
CoPilot: find_library_agent(query="all") OR find_library_agent()
Result: Found 5 agents in your library
```

## Testing

- [x] Test with "show me all my agents" prompt
- [x] Test with empty query
- [x] Test with specific search terms (should still work as before)

## Collaboration

This PR was developed in collaboration with @Pwuts.
2026-02-26 16:07:40 +00:00
Otto
195b14286a fix(frontend): fix Streamdown link safety modal and add origin check (#12209)
Requested by @ntindle

Fixes the Streamdown link safety modal in CoPilot with three changes:

**1. Fix invisible "Open link" button (HIGH)**
Added Streamdown's dist directory to the Tailwind content scan in
`tailwind.config.ts`. Previously, Tailwind was only scanning
`./src/**/*.{ts,tsx}`, so classes used by Streamdown's internal modal
components (like `bg-primary`, `text-primary-foreground`,
`hover:bg-primary/90`) were being purged. The "Open link" button
rendered invisible but remained clickable.

**2. Add same-origin URL whitelist (MEDIUM)**
Configured `linkSafety.onLinkCheck` on the `<Streamdown>` component in
`message.tsx` to whitelist same-origin URLs. Previously, ALL links
(including internal `/api/proxy/...` workspace download URLs) triggered
the "Open external link?" modal. Now same-origin links open directly.

**3. Add Storybook stories (LOW)**
Added `message.stories.tsx` with stories covering default messages, user
messages, internal/external links, the link safety modal, and
conversations.

### Testing
- [ ] Open link safety modal → "Open link" button is visible with proper
styling
- [ ] Click a workspace download link → opens directly (no modal)
- [ ] Click an external link → shows safety modal
- [ ] Verify in both light and dark mode
- [ ] Verify on mobile viewport
- [ ] Storybook stories render correctly

Fixes SECRT-2044
2026-02-26 15:19:54 +00:00
Zamil Majdy
29ca034e40 fix(backend/frontend): error handling, stream reconnection, and chat switching (#12205)
## Problem

CoPilot executions were experiencing:
1. **Duplicate error markers** - Both `execute()` and `_execute_async()`
called `mark_session_completed`, sending duplicate completion markers
2. **RuntimeError bypass** - RuntimeErrors that weren't SDK cleanup
issues bypassed error persistence logic
3. **Generic error messages** - StreamError showed "An error occurred"
instead of actual error text
4. **Empty chat on reconnect** - Messages cleared immediately when
reconnecting, before new messages arrived
5. **Stream not resuming** - Switching chats (A → B → A) didn't resume
active streams due to stale `hasResumedRef`
6. **Excessive diagnostic logging** - 60+ lines of STREAM_DIAG console
logs not needed in production

## Changes 🏗️

### 1. Consolidated Exception Handling
**Files:** `backend/copilot/executor/processor.py`,
`backend/copilot/sdk/service.py`

**processor.py:**
- Removed all error handling from `execute()` method
- Kept error handling only in `_execute_async()` where work happens
- Merged `CancelledError` and `BaseException` handlers into single catch
- Uses `isinstance()` to determine error message

**service.py:**
- Merged `CancelledError` and `Exception` handlers into single catch
- Moved RuntimeError check inside main Exception handler
- Prevents non-cancel-scope RuntimeErrors from bypassing error
persistence

**Impact:** Eliminates duplicate `mark_session_completed` calls, ~70
lines of code removed

---

### 2. Fixed StreamError Message
**File:** `backend/copilot/sdk/service.py`

- Changed from generic `"An error occurred. Please try again."`
- Now shows actual error: `errorText=error_msg`
- Provides real error details to frontend during active stream

---

### 3. Deferred Message Clearing on Reconnect
**File:** `frontend/src/app/(platform)/copilot/useCopilotPage.ts`

- Added `shouldClearOnNextMessageRef` flag
- Set flag when reconnect starts
- Clear old assistant messages only AFTER first new message arrives
- Prevents empty chat flicker during reconnection

---

### 4. Fixed Chat Switching Stream Resume
**File:** `frontend/src/app/(platform)/copilot/useCopilotPage.ts`

**Problem:** When switching Chat A → B → A, the stream didn't resume
because `hasResumedRef.current.get(sessionId)` was still `true`

**Fix:** Clear `hasResumedRef` entry when navigating away from session

**Flow now:**
1. In Chat A with active stream
2. Switch to Chat B → clears `hasResumedRef` for Chat A
3. Switch back to Chat A → `hasResumedRef` is false → resumes stream 

---

### 5. Removed Diagnostic Logging
**Files:** `frontend/useCopilotPage.ts`, `frontend/useChatSession.ts`,
`backend/stream_registry.py`, `backend/processor.py`,
`backend/routes.py`

- Removed all `[STREAM_DIAG]` console logs (60+ lines)
- Logs were useful for debugging but not needed in production
- Cleaner codebase, reduced noise in logs

---

### 6. Exception Handling Order Consistency
**File:** `backend/copilot/executor/processor.py`

- Made both CancelledError and regular exception branches follow same
pattern
- Set `error_msg` before logging in both cases
- Consistent code structure

---

## Architecture Quality: **9/10**

**Strengths:**
- Eliminated duplicate completion markers
- All RuntimeErrors now get proper error persistence
- Real error messages shown to users
- Stream resume works reliably when switching chats
- Cleaner codebase with diagnostic logs removed
- Consistent exception handling patterns

**Trade-offs:**
- Message clearing deferred means brief period with stale + new messages
(acceptable, prevents empty chat)

## Test Plan

- [x] Verify no duplicate completion markers sent
- [x] Trigger RuntimeError, verify error persists
- [x] Check StreamError shows actual error message
- [x] Reconnect, verify chat doesn't go empty
- [x] Switch Chat A → B → A with active stream, verify resume works
- [x] Verify no STREAM_DIAG logs in console
- [x] Run `pnpm format && pnpm lint && pnpm types` - all passed
- [x] Run `poetry run format` - all passed
- [ ] Test in production

## Checklist 📋

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan  
- [x] I have tested my changes according to the test plan
- [x] `.env.default` is updated or compatible (no config changes)
- [x] `docker-compose.yml` is updated or compatible (no config changes)
2026-02-26 13:32:25 +00:00
Reinier van der Leer
1d9dd782a8 feat(backend/api): Add POST /graphs endpoint to external API (#12208)
- Resolves [SECRT-2031: Add upload agent to Library endpoint on external
API](https://linear.app/autogpt/issue/SECRT-2031)

### Changes 🏗️

- Add `POST /graphs` to v1 external API
- Add support for requiring multiple scopes in `require_permission`
middleware
- Add `WRITE_GRAPH` and `WRITE_LIBRARY` API permission scopes

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Test `POST /graphs` endpoint through `/docs` Swagger UI
2026-02-26 12:54:39 +01:00
Krzysztof Czerwinski
a1cb3d2a91 feat(blocks): Add Telegram blocks (#12141)
Add Telegram blocks that allow the use of [Telegram bots' API
features](https://core.telegram.org/bots/features).

### Changes 🏗️

1. Credentials & API layer: Bot token auth via `APIKeyCredentials`,
helper functions for JSON API calls (call_telegram_api) and multipart
file uploads (call_telegram_api_with_file)
2. Trigger blocks:
- `TelegramMessageTriggerBlock` — receives messages (text, photo, voice,
audio, document, video, edited message) with configurable event filters
- `TelegramMessageReactionTriggerBlock` — fires on reaction changes
(private chats auto, groups require admin)
2. Action blocks (11 total):
  - Send: Message, Photo, Voice, Audio, Document, Video
  - Reply to Message, Edit Message, Delete Message
  - Get File (download by file_id)
3. Webhook manager: Registers/deregisters webhooks via Telegram's
setWebhook API, validates incoming requests using
X-Telegram-Bot-Api-Secret-Token header
4. Provider registration: Added TELEGRAM to ProviderName enum and
registered `TelegramWebhooksManager`
5. Media send blocks support both URL passthrough (Telegram fetches
directly) and file upload for workspace/data URI inputs

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Non-AI UUIDs
  - [x] Blocks work correctly
    - [x] SendTelegramMessageBlock
    - [x] SendTelegramPhotoBlock
    - [x] SendTelegramVoiceBlock
    - [x] SendTelegramAudioBlock
    - [x] SendTelegramDocumentBlock
    - [x] SendTelegramVideoBlock
    - [x] ReplyToTelegramMessageBlock
    - [x] GetTelegramFileBlock
    - [x] DeleteTelegramMessageBlock
    - [x] EditTelegramMessageBlock
    - [x] TelegramMessageTriggerBlock (works for every trigger type)
    - [x] TelegramMessageReactionTriggerBlock

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-02-26 10:25:08 +00:00
Otto
1b91327034 fix(builder): Show X button on edge line hover, not just button hover (#12083)
## Summary

Fixes the issue where the X button for removing connections between
nodes only appears when hovering directly over the button itself. Users
now see the button when hovering anywhere on the connection line.

## Changes

- Added an invisible interaction path along the edge with a 20px stroke
width
- The path triggers the same hover state as the button
- This makes the X button visible when hovering the line OR the button
- Preserves existing behavior for broken edges (always visible)

## Testing

1. Hover over an edge line (not the button) → X button should appear
2. Move from line to button → button should stay visible  
3. Move away from both → button should fade out
4. Broken edges should still show X button always

## Linear

Fixes SECRT-1943

## Screenshots

This is a UX improvement - no visual changes except the button now
appears on line hover.

<!-- greptile_comment -->

<h2>Greptile Overview</h2>

<details><summary><h3>Greptile Summary</h3></summary>

This PR improves the UX for edge deletion by adding an invisible
interaction path with a 20px stroke width that makes the delete button
(X) appear when hovering anywhere along the connection line, not just
when hovering directly over the button.

**Key Changes:**
- Added invisible `<path>` element before `BaseEdge` with
`stroke="transparent"` and `strokeWidth={20}`
- Path has `onMouseEnter` and `onMouseLeave` handlers that trigger the
same `setIsHovered` state used by the delete button
- Delete button visibility logic remains unchanged: fades in when
`isHovered` is true (or always visible for broken edges)
- Works uniformly for all edge types (regular, static, and broken edges)

**How It Works:**
The invisible path creates a wider hit area (20px) around the edge
curve, making it much easier for users to trigger the hover state. When
the mouse enters this area, `isHovered` becomes true, which causes the
delete button to fade in (via the existing opacity transition logic).
The button itself also has hover handlers, so moving from the line to
the button maintains the visible state smoothly.
</details>


<details><summary><h3>Confidence Score: 5/5</h3></summary>

- This PR is safe to merge with minimal risk - it's a small, focused UX
improvement with no logic changes
- The implementation is clean and focused: adds only 9 lines of code,
uses existing state management (`isHovered`), and doesn't modify any
deletion logic. The invisible path is a standard SVG/React pattern for
expanding hit areas, and the approach is consistent with how the delete
button already handles hover events. No breaking changes, no side
effects.
- No files require special attention
</details>


<!-- greptile_other_comments_section -->

<!-- /greptile_comment -->

---------

Co-authored-by: Krzysztof Czerwinski <kpczerwinski@gmail.com>
2026-02-26 10:02:01 +00:00
Krzysztof Czerwinski
c7cdb40c5b feat(platform): Update new builder search (#11806)
### Changes 🏗️

- Add materialized view for suggested blocks
- Make `search` in builder accept comma separated filter list in query
- Remove Otto suggestions
- Use hybrid search for blocks search in builder
- Exclude `AgentExecutorBlocks` from builder
- Remove `Block` suffix from builder items

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Materialized view function works (when running manually)
- [x] Higher execution count blocks are shown first in "suggested
blocks" (uses materialized view)
  - [x] Hybrid search works
- [x] `AgentExecutorBlocks` doesn't appear on search results and in
blocks list
  - [x] `Block` suffix isn't displayed on blocks names in builder items

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-02-26 09:56:40 +00:00
Nicholas Tindle
77fb4419d0 Handle workspace:// URLs in regular markdown links (#12166)
### Changes 🏗️

Extended the `resolveWorkspaceUrls` function to handle both image syntax
(`![alt](workspace://id#mime)`) and regular link syntax
(`[text](workspace://id)`).

Previously, only image links were being resolved. Regular workspace
links were being blocked by Streamdown's rehype-harden sanitizer because
`workspace://` is not in the allowed URL-scheme whitelist, causing
"[blocked]" to appear next to link text.

The fix:
- Refactored the function to process image links first (existing
behavior)
- Added a second regex replacement to handle regular links using a
negative lookbehind (`(?<!!)`) to avoid matching image syntax
- Both patterns now resolve `workspace://` URLs to proxy download URLs
via `/api/proxy`
- Updated JSDoc comments to clarify the dual handling

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Verified image links with MIME type hints still resolve correctly
- [x] Verified regular workspace links now resolve to proxy URLs instead
of being blocked
- [x] Confirmed negative lookbehind prevents double-processing of image
syntax

https://claude.ai/code/session_0184TVJJcEoB8wbX9htCnv4b

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Low Risk**
> Low risk: a small, localized frontend markdown preprocessing change
that only rewrites `workspace://` URLs to existing `/api/proxy` download
URLs; main risk is regex edge cases affecting link rendering.
> 
> **Overview**
> Updates `resolveWorkspaceUrls` in `ChatMessagesContainer` to rewrite
**both** `workspace://` image markdown and regular markdown links into
`/api/proxy` download URLs so Streamdown sanitization no longer shows
`[blocked]` for workspace links.
> 
> Image handling is preserved (including `#video/*` MIME hints via
`video:` alt text), and a second regex pass with a negative lookbehind
avoids double-processing image syntax when rewriting plain links.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
e17749b72c. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Ubbe <hi@ubbe.dev>
Co-authored-by: Lluis Agusti <hi@llu.lu>
2026-02-25 12:33:10 +00:00
Bently
9f002ce8f6 fix(frontend): improve UX for expired or duplicate password reset links (#12123)
## Summary
Improves the user experience when a password reset link has expired or
been used, replacing the confusing generic error with a clean, helpful
message.

## Changes
- Added `ExpiredLinkMessage` component that displays a user-friendly
error state
- Updated reset password page to detect expired/used links from:
- Supabase error format
(`error=access_denied&error_code=otp_expired&error_description=...`)
  - Internal clean format (`error=link_expired`)
- Enhanced callback route to detect and map expired/invalid link errors
- Clear, actionable UI with:
  - Friendly error message explaining what happened
  - "Send Me a New Link" button to request a new reset email
  - Login link for users who already have access

## Before
Users saw a confusing URL with error parameters and an unclear form:
```
/reset-password?error=access_denied&error_code=otp_expired&error_description=Email+link+is+invalid+or+has+expired
```

## After
Users see a clean, helpful message explaining the issue and how to fix
it.

<img width="548" height="454" alt="image"
src="https://github.com/user-attachments/assets/e867e522-146c-4d43-91b3-9e62d2957f95"
/>


Closes SECRT-1369

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [ ] Navigate to `/reset-password?error=link_expired` and verify the
ExpiredLinkMessage component appears
  - [ ] Click "Send Me a New Link" and verify the email form appears
- [ ] Navigate to
`/reset-password?error=access_denied&error_code=otp_expired` and verify
same behavior

<!-- greptile_comment -->

<details><summary><h3>Greptile Summary</h3></summary>

Improved password reset UX by adding an `ExpiredLinkMessage` component
that displays when users follow expired or already-used reset links. The
implementation detects expired link errors from Supabase
(`error_code=otp_expired`) and internal format (`error=link_expired`),
replacing confusing URL parameters with a clean message.

**Key changes:**
- Added error detection logic in both callback route and reset password
page to identify expired/invalid links
- Created new `ExpiredLinkMessage` component with friendly messaging
- Enhanced error handling to differentiate between expired links and
other errors

**Issues found:**
- The "Send Me a New Link" button misleadingly suggests it will send an
email, but it only reveals the email form - user must still enter email
and submit
- `access_denied` error detection may be too broad and could incorrectly
classify non-expired errors as expired links
</details>


<details><summary><h3>Confidence Score: 3/5</h3></summary>

- This PR improves UX but has logic issues that could mislead users
- The implementation correctly detects expired links and displays
helpful UI, but the "Send Me a New Link" button doesn't actually send an
email (just shows the form), which creates a misleading user experience.
Additionally, the `access_denied` error check is overly broad and could
incorrectly classify errors. These are functional issues that should be
addressed before merge.
- Pay close attention to `page.tsx` - the `handleSendNewLink` function
and error detection logic need refinement
</details>


<details><summary><h3>Flowchart</h3></summary>

```mermaid
flowchart TD
    Start[User clicks reset link with code] --> Callback[API: /auth/callback/reset-password]
    Callback --> CheckCode{Code valid?}
    
    CheckCode -->|No - expired/invalid/used| DetectError[Detect error type]
    DetectError --> CheckExpired{Contains expired/<br/>invalid/otp_expired/<br/>already/used?}
    CheckExpired -->|Yes| RedirectExpired[Redirect to /reset-password?error=link_expired]
    CheckExpired -->|No| RedirectOther[Redirect to /reset-password?error=message]
    
    CheckCode -->|Yes| RedirectSuccess[Redirect to /reset-password]
    
    RedirectExpired --> PageLoad[Page: /reset-password]
    RedirectOther --> PageLoad
    RedirectSuccess --> PageLoad
    
    PageLoad --> ParseParams[Parse URL params]
    ParseParams --> CheckErrorParams{Has error or<br/>error_code?}
    
    CheckErrorParams -->|Yes| CheckExpiredParams{error=link_expired OR<br/>errorCode=otp_expired OR<br/>error=access_denied OR<br/>description contains<br/>expired/invalid?}
    CheckExpiredParams -->|Yes| ShowExpired[Show ExpiredLinkMessage]
    CheckExpiredParams -->|No| ShowToast[Show error toast]
    
    CheckErrorParams -->|No| CheckUser{User<br/>authenticated?}
    
    ShowExpired --> ClickButton[User clicks 'Send Me a New Link']
    ClickButton --> HideExpired[setShowExpiredMessage false]
    HideExpired --> ShowForm[Show email form]
    
    ShowToast --> ClearParams[Clear error params from URL]
    ClearParams --> CheckUser
    
    CheckUser -->|Yes| ShowPasswordForm[Show password change form]
    CheckUser -->|No| ShowForm[Show email form]
```
</details>


<sub>Last reviewed commit: 80e9f40</sub>

<!-- greptile_other_comments_section -->

<!-- /greptile_comment -->
2026-02-25 12:11:55 +00:00
Ubbe
74691076c6 fix(frontend/copilot): show clarification and agent-saved cards without accordion (#12204)
### Background

The CoPilot tool UI wraps several output cards (clarification questions,
agent saved confirmation) inside a collapsible `ToolAccordion`. This
means users have to expand the accordion to see important interactive
content — clarification questions they need to answer, or confirmation
that their agent was created/updated.

### Changes 🏗️

- **Clarification questions always visible**: Moved
`ClarificationQuestionsCard` out of the `ToolAccordion` in both
`CreateAgent` and `EditAgent` tools so users immediately see and can
answer questions without expanding an accordion
- **Agent saved card always visible**: Moved the agent-saved
confirmation card out of the `ToolAccordion` in both tools so the
success state with library/builder links is immediately visible
- **Extracted `AgentSavedCard` component**: The agent-saved card was
duplicated between `CreateAgent` and `EditAgent` — extracted it into a
shared `copilot/components/AgentSavedCard/AgentSavedCard.tsx` component,
parameterized by `message` ("has been saved to your library!" vs "has
been updated!")
- **ClarificationQuestionsCard polish**: Updated spacing, icon
(`ChatTeardropDotsIcon`), typography variants, border styles, and number
badge sizing for a cleaner look
- **Minor atom tweaks**: Lightened `secondary` button variant
(`zinc-200` → `zinc-100`), changed textarea border radius from
`rounded-3xl` to `rounded-xl`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] `pnpm format` passes
  - [x] `pnpm lint` passes
  - [x] `pnpm types` passes
- [ ] Create an agent via CoPilot and verify the saved card shows
without accordion
- [ ] Trigger clarification questions and verify they show without
accordion
- [ ] Edit an agent via CoPilot and verify the updated card shows
without accordion
- [ ] Verify the ClarificationQuestionsCard styling looks correct
(spacing, icons, borders)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-25 18:13:01 +07:00
Otto
b15ad0df9b hotfix(frontend): fix null credits TypeError on /copilot (#12202)
Requested by @majdyz

Fix `TypeError: Cannot read properties of null (reading 'credits')` on
the /copilot page.

**Sentry:**
[BUILDER-71P](https://significant-gravitas.sentry.io/issues/7256025912/)
**Linear:** SENTRY-1110

## Root Cause

Two issues combined:

1. **`getUserCredit()` had a broken try/catch** — it wasn't `await`ing
`_get()`, so async errors (including null responses) were never caught
2. **`_makeClientRequest` returns `null` during logout** — when a user
is logging out and `/credits` races with auth teardown, the response is
`null`

Chain: logout starts → `/credits` fetch races → auth error →
`_makeClientRequest` returns `null` → `getUserCredit` passes `null`
through → `fetchCredits` does `null.credits` → 💥

## Fix

- `getUserCredit()`: Add `await` + null coalescing fallback to `{
credits: 0 }`
- `fetchCredits()`: Add optional chaining guard (`response?.credits ??
null`)
2026-02-25 10:38:08 +00:00
Abhimanyu Yadav
2136defea8 feat(library): implement folder organization system for agents (#12101)
### Changes 🏗️

This PR adds folder organization capabilities to the library, allowing
users to organize their agents into folders:

- Added new `LibraryFolder` model and database schema
- Created folder management API endpoints for CRUD operations
- Implemented folder tree structure with proper parent-child
relationships
- Added drag-and-drop functionality for moving agents between folders
- Created folder creation dialog with emoji picker for folder icons
- Added folder editing and deletion capabilities
- Implemented folder navigation in the library UI
- Added validation to prevent circular references and excessive nesting
- Created animation for favoriting agents
- Updated library agent list to show folder structure
- Added folder filtering to agent list queries

<img width="1512" height="950" alt="Screenshot 2026-02-13 at 9 08 45 PM"
src="https://github.com/user-attachments/assets/78778e03-4349-4d50-ad71-d83028ca004a"
/>

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Create a new folder with custom name, icon, and color
  - [x] Move agents into folders via drag and drop
  - [x] Move agents into folders via context menu
  - [x] Navigate between folders
  - [x] Edit folder properties (name, icon, color)
  - [x] Delete folders and verify agents return to root
  - [x] Verify favorite animation works when adding to favorites
  - [x] Test folder navigation with search functionality
  - [x] Verify folder tree structure is maintained

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

<!-- greptile_comment -->

<h2>Greptile Overview</h2>

<details><summary><h3>Greptile Summary</h3></summary>

This PR implements a comprehensive folder organization system for
library agents, enabling hierarchical structure up to 5 levels deep.

**Backend Changes:**
- Added `LibraryFolder` model with self-referential hierarchy
(`parentId` → `Parent`/`Children`)
- Implemented CRUD operations with validation for circular references
and depth limits (MAX_FOLDER_DEPTH=5)
- Added `folderId` foreign key to `LibraryAgent` table
- Created folder management endpoints: list, get, create, update, move,
delete, and bulk agent moves
- Proper soft-delete cascade handling for folders and their contained
agents

**Frontend Changes:**
- Created folder creation/edit/delete dialogs with emoji picker
integration
- Implemented folder navigation UI with breadcrumbs and folder tree
structure
- Added drag-and-drop support for moving agents between folders
- Created context menu for agent actions (move to folder, remove from
folder)
- Added favorite animation system with `FavoriteAnimationProvider`
- Integrated folder filtering into agent list queries

**Key Features:**
- Folders support custom names, emoji icons, and hex colors
- Unique constraint per parent folder per user prevents duplicate names
- Validation prevents circular folder hierarchies and excessive nesting
- Agents can be moved between folders via drag-drop or context menu
- Deleting a folder soft-deletes all descendant folders and contained
agents
</details>


<details><summary><h3>Confidence Score: 4/5</h3></summary>

- This PR is safe to merge with minor considerations for performance
optimization
- The implementation is well-structured with proper validation, error
handling, and database constraints. The folder hierarchy logic correctly
prevents circular references and enforces depth limits. However, there
are some performance concerns with N+1 queries in depth calculation and
circular reference checking that could be optimized for deeply nested
hierarchies. The foreign key constraint (ON DELETE RESTRICT) conflicts
with the hard-delete code path but shouldn't cause issues since
soft-deletes are the default. The client-side duplicate validation is
redundant but not harmful.
- Pay close attention to migration file (foreign key constraint) and
db.py (performance of recursive queries)
</details>


<details><summary><h3>Sequence Diagram</h3></summary>

```mermaid
sequenceDiagram
    participant User
    participant Frontend
    participant API
    participant DB

    User->>Frontend: Create folder with name/icon/color
    Frontend->>API: POST /v2/folders
    API->>DB: Validate parent exists & depth limit
    API->>DB: Check unique constraint (userId, parentId, name)
    DB-->>API: Folder created
    API-->>Frontend: LibraryFolder response
    Frontend-->>User: Show success toast

    User->>Frontend: Drag agent to folder
    Frontend->>API: POST /v2/folders/agents/bulk-move
    API->>DB: Verify folder exists
    API->>DB: Update LibraryAgent.folderId
    DB-->>API: Agents updated
    API-->>Frontend: Updated agents
    Frontend-->>User: Refresh agent list

    User->>Frontend: Navigate into folder
    Frontend->>API: GET /v2/library/agents?folder_id=X
    API->>DB: Query agents WHERE folderId=X
    DB-->>API: Filtered agents
    API-->>Frontend: Agent list
    Frontend-->>User: Display folder contents

    User->>Frontend: Delete folder
    Frontend->>API: DELETE /v2/folders/{id}
    API->>DB: Get descendant folders recursively
    API->>DB: Soft-delete folders + agents in transaction
    DB-->>API: Deletion complete
    API-->>Frontend: 204 No Content
    Frontend-->>User: Show success toast
```
</details>


<sub>Last reviewed commit: a6c2f64</sub>

<!-- greptile_other_comments_section -->

<!-- /greptile_comment -->
2026-02-24 15:04:56 +00:00
Zamil Majdy
6e61cb103c fix(copilot): workspace file listing fix (#12190)
Requested by @majdyz

Improves workspace file display in GenericTool:
- Base64 content decoding for workspace files
- Rich file object rendering (path, size, mime type)
- MCP text extraction from SDK tool responses (Read, Glob, Grep, Edit)
- Better file list formatting for both string and object file entries

---------

Co-authored-by: Otto (AGPT) <otto@agpt.co>
2026-02-24 12:33:24 +00:00
Nicholas Tindle
e688f4003e fix(backend): handle malformed emails in PII masking
Prevents IndexError when email has empty local part (e.g., "@example.com").

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-08 21:56:49 -06:00
Nicholas Tindle
3b6f1a4591 fix(frontend): check response status in delete mutation
Consistent with other mutations, now checks response.status === 200
before showing success toast.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-08 20:55:10 -06:00
Nicholas Tindle
6b1432d59e fix(backend): add null fallback for categories field
Prevents validation error when DB returns None for categories.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-08 20:42:00 -06:00
Nicholas Tindle
f91edde32a fix(backend): mask email PII in waitlist logging
Avoid logging raw email addresses by masking to first char + domain.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-08 20:35:13 -06:00
Nicholas Tindle
4ba6c44f61 fix(frontend): regenerate openapi.json with correct structure
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-08 20:22:11 -06:00
Nicholas Tindle
b4e16e7246 Merge branch 'dev' into ntindle/waitlist 2026-02-08 20:00:02 -06:00
Nicholas Tindle
adeeba76d1 fix(backend): remove unused pytest import
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-08 19:07:00 -06:00
Nicholas Tindle
c88918af4f Merge remote-tracking branch 'origin/dev' into ntindle/waitlist 2026-02-08 18:42:10 -06:00
Nicholas Tindle
69618a5e05 Merge branch 'dev' into ntindle/waitlist 2026-02-04 19:02:00 -06:00
Nicholas Tindle
3610be3e83 Merge branch 'dev' into ntindle/waitlist 2026-01-20 17:47:02 -06:00
Nicholas Tindle
9e1f7c9415 Merge branch 'dev' into ntindle/waitlist 2026-01-19 01:12:14 -06:00
Nicholas Tindle
0d03ebb43c fix: lint 2026-01-16 11:34:00 -06:00
Nicholas Tindle
1b37bd6da9 Merge branch 'dev' into ntindle/waitlist 2026-01-16 11:32:05 -06:00
Nicholas Tindle
db989a5eed fix: lint 2026-01-15 15:58:33 -06:00
Nicholas Tindle
e3a8c57a35 Merge branch 'dev' into ntindle/waitlist
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 15:54:38 -06:00
Nicholas Tindle
dfc8e53386 fix(backend): add assertions to fix type errors in waitlist admin functions
Prisma's update() returns T | None but we verify existence before updating,
so assert the result is not None to satisfy the type checker.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 15:48:30 -06:00
Nicholas Tindle
b5b7e5da92 fix(backend): don't mark waitlist DONE if email-only users pending
The notify_waitlist_users_on_launch function was marking waitlists as
DONE after notifying registered users, but ignoring unaffiliatedEmailUsers
who haven't been notified yet. Since DONE waitlists are excluded from
future notification queries, those email users would never receive
notifications when that functionality is implemented.

Now the waitlist remains in an active state if there are pending
email-only signups that still need notifications.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 15:41:31 -06:00
Nicholas Tindle
07ea2c2ab7 fix(backend): check waitlist existence before update in update_waitlist_admin
Added find_unique check before update() call to properly return 404 when
waitlist doesn't exist, following the established pattern used in other
waitlist admin functions.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 15:37:19 -06:00
Nicholas Tindle
9c873a0158 fix(backend): add exception handling to add_self_to_waitlist route
The public waitlist join route was missing exception handling, causing
500 errors for all failures. Now properly returns:
- 404 for waitlist not found
- 400 for closed/unavailable waitlists
- 500 for unexpected errors

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 15:30:54 -06:00
Nicholas Tindle
ed634db8f7 fix(backend): validate waitlist status enum at API boundary
Changed WaitlistUpdateRequest.status from str to the actual enum type.
Pydantic now validates the status value, returning 422 for invalid
values instead of a misleading 404 "Waitlist not found" error.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 15:26:26 -06:00
Nicholas Tindle
398197f3ea fix(frontend): add title attribute to YouTube iframe for accessibility
Screen readers need a title attribute on iframes to describe their
content. Added "YouTube video player" title to the embedded video.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 15:23:52 -06:00
Nicholas Tindle
b7df4cfdbf fix(backend): align migration FK with schema (SET NULL not CASCADE)
The migration had ON DELETE CASCADE for WaitlistEntry.storeListingId,
but the Prisma schema specifies onDelete: SetNull. This mismatch would
cause waitlist entries and all signup data to be deleted when a store
listing is removed, instead of just unlinking them.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 15:18:03 -06:00
Nicholas Tindle
5d8dd46759 fix(backend): align waitlist admin functions with established patterns
- delete_waitlist_admin: add find_unique check before update, raise
  ValueError if not found, add except ValueError: raise
- link_waitlist_to_listing_admin: add find_unique check for waitlist
  before update, remove dead code
- delete_waitlist route: add except ValueError: → 404, remove dead
  code bool check pattern

All waitlist admin functions now follow the consistent pattern:
1. find_unique to check existence
2. raise ValueError if not found
3. except ValueError: raise to bubble up
4. except Exception: raise DatabaseError

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 14:54:53 -06:00
Nicholas Tindle
f9518b6f8b fix(frontend): use generated query key for waitlist cache invalidation
The hardcoded query key string didn't match the actual generated key,
causing cache invalidation to fail after joining a waitlist. Now uses
the generated getGetV2GetWaitlistIdsTheCurrentUserHasJoinedQueryKey()
function for correct cache invalidation.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 14:44:35 -06:00
Nicholas Tindle
205b220e90 fix(backend): filter out DONE/CANCELED waitlists before sending notifications
The notify_waitlist_users_on_launch function was not filtering by
waitlist status, which could cause duplicate notifications when an
agent is re-approved. Now excludes DONE and CANCELED waitlists,
consistent with get_waitlist() and add_user_to_waitlist().

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 14:41:37 -06:00
Nicholas Tindle
29a232fcb4 fix(frontend): add URL validation and sandbox to video player
- Add getYouTubeVideoId() to extract video IDs from YouTube URLs
- Add isValidVideoUrl() to validate video URLs before rendering
- Create VideoPlayer component that:
  - Embeds YouTube videos via iframe with safe embed URL
  - Adds sandbox attribute to restrict iframe capabilities
  - Adds proper allow attributes for media playback
  - Falls back to native video element for valid non-YouTube URLs
  - Shows error state for invalid URLs

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 14:29:10 -06:00
Nicholas Tindle
a53f261812 feat(frontend): add TODO warning for email-only waitlist notifications
Adds a warning banner on the admin waitlist page indicating that
notifications for email-only signups (non-logged-in users) have not
been implemented yet.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 14:05:21 -06:00
Nicholas Tindle
00a20f77be feat(backend): add waitlist_launch email notification template
The WAITLIST_LAUNCH notification type was referencing a template that
didn't exist, causing FileNotFoundError when trying to notify users
that an agent they waitlisted has launched.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 16:38:04 -06:00
Nicholas Tindle
4d49536a40 Discard changes to autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts 2026-01-12 15:28:37 -07:00
Nicholas Tindle
6028a2528c refactor(frontend): consolidate waitlist modals and align with Figma design
- Merge JoinWaitlistModal into WaitlistDetailModal for unified experience
- Add MediaCarousel component supporting videos and images with play overlay
- Update WaitlistCard styling to match Figma (rounded-large, line-clamp-5, zinc-800 button)
- Update success state with party emoji and Close button per Figma design
- Add sticky footer for buttons during modal scroll
- Support email input for non-logged-in users

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 16:27:09 -06:00
Nicholas Tindle
b31cd05675 fix(backend): correct typo in unaffiliatedEmailUsers field name
- Rename unafilliatedEmailUsers -> unaffiliatedEmailUsers in schema.prisma
- Update migration SQL to use correct column name
- Update all references in db.py and model.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 15:33:38 -06:00
Nicholas Tindle
128366772f refactor(backend): remove apscheduler tables from prisma schema
- Remove apscheduler_jobs and apscheduler_jobs_batched_notifications models
- Delete migration 20260107000001_add_apscheduler_tables
- Remove index rename statements from waitlist migration

APScheduler tables are managed at runtime by APScheduler itself and
should not be part of the Prisma schema.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 15:29:17 -06:00
Nicholas Tindle
764cdf17fe refactor(frontend): migrate waitlist admin components to generated API hooks
- Convert WaitlistTable to use generated React Query hooks directly
- Convert CreateWaitlistButton to use generated hooks
- Update WaitlistDetailModal to use generated types and design system Dialog
- Remove deprecated waitlist types from types.ts
- Remove deprecated waitlist methods from BackendAPI client
- Delete actions.ts server actions (no longer needed)
- Replace lucide-react icons with Phosphor icons

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 15:26:34 -06:00
Nicholas Tindle
1dd83b4cf8 fix(frontend): add text color to status badge fallback in WaitlistTable
Ensures unknown status values have readable text contrast by adding
text-gray-700 to the fallback className.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 15:09:44 -06:00
Nicholas Tindle
24a34f7ce5 Merge branch 'dev' into ntindle/waitlist 2026-01-12 14:08:48 -07:00
Nicholas Tindle
20fe2c3877 fix(backend): remove PII-exposing fields from public waitlist model
Remove `owner` (User type) and `storeListing` (StoreListingWithVersions)
fields from StoreWaitlistEntry. These fields were never populated but
exposed PII types (email, stripe_customer_id, etc.) in the OpenAPI schema.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 14:52:51 -06:00
Nicholas Tindle
738c7e2bef fix(platform): address remaining PR review feedback for waitlist
Backend fixes:
- Fix optional field clearing by using model_fields_set
- Re-fetch waitlist data after join operation
- Only mark waitlist as DONE if all notifications succeed
- Fix race condition in email removal with transaction
- Rename waitlist_id to waitlistId for naming consistency

Frontend fixes:
- Migrate useWaitlistSection to generated API hooks
- Migrate JoinWaitlistModal to design system + generated hooks
- Migrate WaitlistSignupsDialog to design system + generated hooks
- Replace lucide-react icons with Phosphor in WaitlistTable
- Add proper error state in WaitlistSignupsDialog
- Update waitlistId naming across components

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 14:43:10 -06:00
Nicholas Tindle
9edfe0fb97 refactor(frontend): migrate EditWaitlistDialog to design system and generated API
- Replace legacy Dialog components with molecules/Dialog
- Replace legacy Input/Label/Textarea with atoms/Input
- Replace legacy Select with atoms/Select
- Replace @/lib/autogpt-server-api/types with @/app/api/__generated__/models
- Replace updateWaitlist action with usePutV2UpdateWaitlist hook
- Remove dependency on BackendAPI in favor of generated React Query hooks

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-09 16:49:35 -07:00
Nicholas Tindle
4aabe71001 fix(platform): address PR review feedback for waitlist feature
Backend fixes:
- Fix creator_username null check in store URL construction
- Add embed=True to link_waitlist_to_listing endpoint body param
- Fix race condition in email list with transaction wrapper
- Replace str(e) with generic error messages in admin ValueError handlers
- Add validation requiring user_id or email in waitlist join
- Configure WAITLIST_LAUNCH in notification system (data type, queue, template, subject)
- Change StoreListing cascade delete to SetNull to preserve waitlist data

Frontend fixes:
- Escape internal quotes in CSV export for proper RFC 4180 compliance
- Remove incorrect 'use server' directive from page.tsx
- Replace lucide-react Check icon with Phosphor Icons

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-09 16:40:35 -07:00
Nicholas Tindle
b3999669f2 refactor(platform): simplify waitlist code and remove type duplication
- Backend: Extract _waitlist_to_store_entry helper to reduce duplication
- Backend: Use dict comprehension in update_waitlist_admin for cleaner code
- Frontend: Import types directly from shared types file instead of re-exporting
- Frontend: Remove redundant isMember check in WaitlistCard handleJoinClick

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-09 16:25:27 -07:00
Swifty
8c45a5ee98 Merge branch 'dev' into ntindle/waitlist 2026-01-08 12:38:46 +01:00
Nicholas Tindle
4b654c7e9f fix(frontend): Fix lint and type errors in waitlist admin components
- Remove unused WaitlistSignup import
- Change button size from "sm" to "small"

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-07 22:48:53 -07:00
Nicholas Tindle
8d82e3b633 fix(backend): Use Prisma connect pattern for waitlist-listing relation
Use StoreListing relation with connect pattern instead of directly
setting storeListingId, which doesn't work with Prisma's typed update.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-07 22:01:18 -07:00
Nicholas Tindle
d4ecdb64ed feat(platform): Show "On the waitlist" status for joined users
- Add GET /api/store/waitlist/my-memberships endpoint to fetch user's joined waitlists
- Add get_user_waitlist_memberships() db function
- Update useWaitlistSection hook to fetch memberships when logged in
- Update WaitlistCard to show green "On the waitlist" button for members
- Update WaitlistDetailModal to show member status
- Add onSuccess callback to JoinWaitlistModal for optimistic UI updates

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-07 21:15:03 -07:00
Nicholas Tindle
a73fb8f114 feat(platform): Add waitlist feature with admin management and user notifications
Backend:
- Add waitlist admin API routes for CRUD operations
- Add admin functions for waitlist management (create, update, delete, list)
- Add WaitlistLaunchData notification type for user notifications
- Integrate waitlist notifications into store submission approval flow
- Auto-notify waitlist users when linked agent is approved

Frontend:
- Add admin waitlist management page with table, create/edit dialogs
- Add WaitlistSection component to marketplace homepage
- Add WaitlistCard, WaitlistDetailModal, JoinWaitlistModal components
- Add API client methods and types for waitlist operations

Database:
- Add WAITLIST_LAUNCH notification type enum
- Add baseline migration for APScheduler tables

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-07 20:38:15 -07:00
Nicholas Tindle
2c60aa64ef wip: adding waitlist 2026-01-06 22:13:35 -07:00
267 changed files with 28080 additions and 5539 deletions

View File

@@ -149,7 +149,7 @@ jobs:
driver-opts: network=host
- name: Set up Platform - Expose GHA cache to docker buildx CLI
uses: crazy-max/ghaction-github-runtime@v3
uses: crazy-max/ghaction-github-runtime@v4
- name: Set up Platform - Build Docker images (with cache)
working-directory: autogpt_platform

4
.gitignore vendored
View File

@@ -180,4 +180,6 @@ autogpt_platform/backend/settings.py
.claude/settings.local.json
CLAUDE.local.md
/autogpt_platform/backend/logs
.next
.next
# Implementation plans (generated by AI agents)
plans/

1
.nvmrc Normal file
View File

@@ -0,0 +1 @@
22

View File

@@ -1,3 +1,10 @@
default_install_hook_types:
- pre-commit
- pre-push
- post-checkout
default_stages: [pre-commit]
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
@@ -17,6 +24,7 @@ repos:
name: Detect secrets
description: Detects high entropy strings that are likely to be passwords.
files: ^autogpt_platform/
exclude: pnpm-lock\.yaml$
stages: [pre-push]
- repo: local
@@ -26,49 +34,106 @@ repos:
- id: poetry-install
name: Check & Install dependencies - AutoGPT Platform - Backend
alias: poetry-install-platform-backend
entry: poetry -C autogpt_platform/backend install
# include autogpt_libs source (since it's a path dependency)
files: ^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$
types: [file]
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$" || exit 0;
poetry -C autogpt_platform/backend install
'
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - AutoGPT Platform - Libs
alias: poetry-install-platform-libs
entry: poetry -C autogpt_platform/autogpt_libs install
files: ^autogpt_platform/autogpt_libs/poetry\.lock$
types: [file]
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^autogpt_platform/autogpt_libs/poetry\.lock$" || exit 0;
poetry -C autogpt_platform/autogpt_libs install
'
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: pnpm-install
name: Check & Install dependencies - AutoGPT Platform - Frontend
alias: pnpm-install-platform-frontend
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^autogpt_platform/frontend/pnpm-lock\.yaml$" || exit 0;
pnpm --prefix autogpt_platform/frontend install
'
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - AutoGPT
alias: poetry-install-classic-autogpt
entry: poetry -C classic/original_autogpt install
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/(original_autogpt|forge)/poetry\.lock$" || exit 0;
poetry -C classic/original_autogpt install
'
# include forge source (since it's a path dependency)
files: ^classic/(original_autogpt|forge)/poetry\.lock$
types: [file]
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - Forge
alias: poetry-install-classic-forge
entry: poetry -C classic/forge install
files: ^classic/forge/poetry\.lock$
types: [file]
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/forge/poetry\.lock$" || exit 0;
poetry -C classic/forge install
'
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - Benchmark
alias: poetry-install-classic-benchmark
entry: poetry -C classic/benchmark install
files: ^classic/benchmark/poetry\.lock$
types: [file]
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/benchmark/poetry\.lock$" || exit 0;
poetry -C classic/benchmark install
'
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- repo: local
# For proper type checking, Prisma client must be up-to-date.
@@ -76,12 +141,54 @@ repos:
- id: prisma-generate
name: Prisma Generate - AutoGPT Platform - Backend
alias: prisma-generate-platform-backend
entry: bash -c 'cd autogpt_platform/backend && poetry run prisma generate'
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^autogpt_platform/((backend|autogpt_libs)/poetry\.lock|backend/schema\.prisma)$" || exit 0;
cd autogpt_platform/backend
&& poetry run prisma generate
&& poetry run gen-prisma-stub
'
# include everything that triggers poetry install + the prisma schema
files: ^autogpt_platform/((backend|autogpt_libs)/poetry\.lock|backend/schema.prisma)$
types: [file]
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: export-api-schema
name: Export API schema - AutoGPT Platform - Backend -> Frontend
alias: export-api-schema-platform
entry: >
bash -c '
cd autogpt_platform/backend
&& poetry run export-api-schema --output ../frontend/src/app/api/openapi.json
&& cd ../frontend
&& pnpm prettier --write ./src/app/api/openapi.json
'
files: ^autogpt_platform/backend/
language: system
pass_filenames: false
- id: generate-api-client
name: Generate API client - AutoGPT Platform - Frontend
alias: generate-api-client-platform-frontend
entry: >
bash -c '
SCHEMA=autogpt_platform/frontend/src/app/api/openapi.json;
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --quiet "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF" -- "$SCHEMA" && exit 0
else
git diff --quiet HEAD -- "$SCHEMA" && exit 0
fi;
cd autogpt_platform/frontend && pnpm generate:api
'
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.2

View File

@@ -1,2 +1,3 @@
*.ignore.*
*.ign.*
*.ign.*
.application.logs

View File

@@ -95,7 +95,7 @@ ENV DEBIAN_FRONTEND=noninteractive
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
# for the bash_exec MCP tool.
# for the bash_exec MCP tool (fallback when E2B is not configured).
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.13 \
@@ -111,13 +111,29 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
# Copy poetry (build-time only, for `poetry install --only-root` to create entry points)
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
# Copy Node.js installation for Prisma
# Copy Node.js installation for Prisma and agent-browser.
# npm/npx are symlinks in the builder (-> ../lib/node_modules/npm/bin/*-cli.js);
# COPY resolves them to regular files, breaking require() paths. Recreate as
# proper symlinks so npm/npx can find their modules.
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
COPY --from=builder /usr/bin/npm /usr/bin/npm
COPY --from=builder /usr/bin/npx /usr/bin/npx
RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
# Install agent-browser (Copilot browser tool) + Chromium runtime dependencies.
# These are the runtime libraries Chromium/Playwright needs on Debian 13 (trixie).
RUN apt-get update && apt-get install -y --no-install-recommends \
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
fonts-liberation libfontconfig1 \
&& rm -rf /var/lib/apt/lists/* \
&& npm install -g agent-browser \
&& agent-browser install \
&& rm -rf /tmp/* /root/.npm
WORKDIR /app/autogpt_platform/backend
# Copy only the .venv from builder (not the entire /app directory)

View File

@@ -88,20 +88,23 @@ async def require_auth(
)
def require_permission(permission: APIKeyPermission):
def require_permission(*permissions: APIKeyPermission):
"""
Dependency function for checking specific permissions
Dependency function for checking required permissions.
All listed permissions must be present.
(works with API keys and OAuth tokens)
"""
async def check_permission(
async def check_permissions(
auth: APIAuthorizationInfo = Security(require_auth),
) -> APIAuthorizationInfo:
if permission not in auth.scopes:
missing = [p for p in permissions if p not in auth.scopes]
if missing:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Missing required permission: {permission.value}",
detail=f"Missing required permission(s): "
f"{', '.join(p.value for p in missing)}",
)
return auth
return check_permission
return check_permissions

View File

@@ -18,6 +18,7 @@ from backend.data import user as user_db
from backend.data.auth.base import APIAuthorizationInfo
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.executor.utils import add_graph_execution
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
from backend.util.settings import Settings
from .integrations import integrations_router
@@ -95,6 +96,43 @@ async def execute_graph_block(
return output
@v1_router.post(
path="/graphs",
tags=["graphs"],
status_code=201,
dependencies=[
Security(
require_permission(
APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY
)
)
],
)
async def create_graph(
graph: graph_db.Graph,
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY)
),
) -> graph_db.GraphModel:
"""
Create a new agent graph.
The graph will be validated and assigned a new ID.
It is automatically added to the user's library.
"""
from backend.api.features.library import db as library_db
graph_model = graph_db.make_graph_model(graph, auth.user_id)
graph_model.reassign_ids(user_id=auth.user_id, reassign_graph_id=True)
graph_model.validate_graph(for_run=False)
await graph_db.create_graph(graph_model, user_id=auth.user_id)
await library_db.create_library_agent(graph_model, auth.user_id)
activated_graph = await on_graph_activate(graph_model, user_id=auth.user_id)
return activated_graph
@v1_router.post(
path="/graphs/{graph_id}/execute/{graph_version}",
tags=["graphs"],

View File

@@ -0,0 +1,164 @@
import logging
import autogpt_libs.auth
import fastapi
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
logger = logging.getLogger(__name__)
router = fastapi.APIRouter(
prefix="/admin/waitlist",
tags=["store", "admin", "waitlist"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_admin_user)],
)
@router.post(
"",
summary="Create Waitlist",
response_model=store_model.WaitlistAdminResponse,
)
async def create_waitlist(
request: store_model.WaitlistCreateRequest,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Create a new waitlist (admin only).
Args:
request: Waitlist creation details
user_id: Authenticated admin user creating the waitlist
Returns:
WaitlistAdminResponse with the created waitlist details
"""
return await store_db.create_waitlist_admin(
admin_user_id=user_id,
data=request,
)
@router.get(
"",
summary="List All Waitlists",
response_model=store_model.WaitlistAdminListResponse,
)
async def list_waitlists():
"""
Get all waitlists with admin details (admin only).
Returns:
WaitlistAdminListResponse with all waitlists
"""
return await store_db.get_waitlists_admin()
@router.get(
"/{waitlist_id}",
summary="Get Waitlist Details",
response_model=store_model.WaitlistAdminResponse,
)
async def get_waitlist(
waitlist_id: str = fastapi.Path(..., description="The ID of the waitlist"),
):
"""
Get a single waitlist with admin details (admin only).
Args:
waitlist_id: ID of the waitlist to retrieve
Returns:
WaitlistAdminResponse with waitlist details
"""
return await store_db.get_waitlist_admin(waitlist_id)
@router.put(
"/{waitlist_id}",
summary="Update Waitlist",
response_model=store_model.WaitlistAdminResponse,
)
async def update_waitlist(
request: store_model.WaitlistUpdateRequest,
waitlist_id: str = fastapi.Path(..., description="The ID of the waitlist"),
):
"""
Update a waitlist (admin only).
Args:
waitlist_id: ID of the waitlist to update
request: Fields to update
Returns:
WaitlistAdminResponse with updated waitlist details
"""
return await store_db.update_waitlist_admin(waitlist_id, request)
@router.delete(
"/{waitlist_id}",
summary="Delete Waitlist",
)
async def delete_waitlist(
waitlist_id: str = fastapi.Path(..., description="The ID of the waitlist"),
):
"""
Soft delete a waitlist (admin only).
Args:
waitlist_id: ID of the waitlist to delete
Returns:
Success message
"""
await store_db.delete_waitlist_admin(waitlist_id)
return {"message": "Waitlist deleted successfully"}
@router.get(
"/{waitlist_id}/signups",
summary="Get Waitlist Signups",
response_model=store_model.WaitlistSignupListResponse,
)
async def get_waitlist_signups(
waitlist_id: str = fastapi.Path(..., description="The ID of the waitlist"),
):
"""
Get all signups for a waitlist (admin only).
Args:
waitlist_id: ID of the waitlist
Returns:
WaitlistSignupListResponse with all signups
"""
return await store_db.get_waitlist_signups_admin(waitlist_id)
@router.post(
"/{waitlist_id}/link",
summary="Link Waitlist to Store Listing",
response_model=store_model.WaitlistAdminResponse,
)
async def link_waitlist_to_listing(
waitlist_id: str = fastapi.Path(..., description="The ID of the waitlist"),
store_listing_id: str = fastapi.Body(
..., embed=True, description="The ID of the store listing"
),
):
"""
Link a waitlist to a store listing (admin only).
When the linked store listing is approved/published, waitlist users
will be automatically notified.
Args:
waitlist_id: ID of the waitlist
store_listing_id: ID of the store listing to link
Returns:
WaitlistAdminResponse with updated waitlist details
"""
return await store_db.link_waitlist_to_listing_admin(waitlist_id, store_listing_id)

View File

@@ -1,15 +1,17 @@
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from difflib import SequenceMatcher
from typing import Sequence
from typing import Any, Sequence, get_args, get_origin
import prisma
from prisma.enums import ContentType
from prisma.models import mv_suggested_blocks
import backend.api.features.library.db as library_db
import backend.api.features.library.model as library_model
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.blocks import load_all_blocks
from backend.blocks._base import (
AnyBlockSchema,
@@ -19,7 +21,6 @@ from backend.blocks._base import (
BlockType,
)
from backend.blocks.llm import LlmModel
from backend.data.db import query_raw_with_schema
from backend.integrations.providers import ProviderName
from backend.util.cache import cached
from backend.util.models import Pagination
@@ -42,6 +43,16 @@ MAX_LIBRARY_AGENT_RESULTS = 100
MAX_MARKETPLACE_AGENT_RESULTS = 100
MIN_SCORE_FOR_FILTERED_RESULTS = 10.0
# Boost blocks over marketplace agents in search results
BLOCK_SCORE_BOOST = 50.0
# Block IDs to exclude from search results
EXCLUDED_BLOCK_IDS = frozenset(
{
"e189baac-8c20-45a1-94a7-55177ea42565", # AgentExecutorBlock
}
)
SearchResultItem = BlockInfo | library_model.LibraryAgent | store_model.StoreAgent
@@ -64,8 +75,8 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type()
# Skip disabled blocks
if block.disabled:
# Skip disabled and excluded blocks
if block.disabled or block.id in EXCLUDED_BLOCK_IDS:
continue
# Skip blocks that don't have categories (all should have at least one)
if not block.categories:
@@ -116,6 +127,9 @@ def get_blocks(
# Skip disabled blocks
if block.disabled:
continue
# Skip excluded blocks
if block.id in EXCLUDED_BLOCK_IDS:
continue
# Skip blocks that don't match the category
if category and category not in {c.name.lower() for c in block.categories}:
continue
@@ -255,14 +269,25 @@ async def _build_cached_search_results(
"my_agents": 0,
}
block_results, block_total, integration_total = _collect_block_results(
normalized_query=normalized_query,
include_blocks=include_blocks,
include_integrations=include_integrations,
)
scored_items.extend(block_results)
total_items["blocks"] = block_total
total_items["integrations"] = integration_total
# Use hybrid search when query is present, otherwise list all blocks
if (include_blocks or include_integrations) and normalized_query:
block_results, block_total, integration_total = await _hybrid_search_blocks(
query=search_query,
include_blocks=include_blocks,
include_integrations=include_integrations,
)
scored_items.extend(block_results)
total_items["blocks"] = block_total
total_items["integrations"] = integration_total
elif include_blocks or include_integrations:
# No query - list all blocks using in-memory approach
block_results, block_total, integration_total = _collect_block_results(
include_blocks=include_blocks,
include_integrations=include_integrations,
)
scored_items.extend(block_results)
total_items["blocks"] = block_total
total_items["integrations"] = integration_total
if include_library_agents:
library_response = await library_db.list_library_agents(
@@ -307,10 +332,14 @@ async def _build_cached_search_results(
def _collect_block_results(
*,
normalized_query: str,
include_blocks: bool,
include_integrations: bool,
) -> tuple[list[_ScoredItem], int, int]:
"""
Collect all blocks for listing (no search query).
All blocks get BLOCK_SCORE_BOOST to prioritize them over marketplace agents.
"""
results: list[_ScoredItem] = []
block_count = 0
integration_count = 0
@@ -323,6 +352,10 @@ def _collect_block_results(
if block.disabled:
continue
# Skip excluded blocks
if block.id in EXCLUDED_BLOCK_IDS:
continue
block_info = block.get_info()
credentials = list(block.input_schema.get_credentials_fields().values())
is_integration = len(credentials) > 0
@@ -332,10 +365,6 @@ def _collect_block_results(
if not is_integration and not include_blocks:
continue
score = _score_block(block, block_info, normalized_query)
if not _should_include_item(score, normalized_query):
continue
filter_type: FilterType = "integrations" if is_integration else "blocks"
if is_integration:
integration_count += 1
@@ -346,8 +375,122 @@ def _collect_block_results(
_ScoredItem(
item=block_info,
filter_type=filter_type,
score=score,
sort_key=_get_item_name(block_info),
score=BLOCK_SCORE_BOOST,
sort_key=block_info.name.lower(),
)
)
return results, block_count, integration_count
async def _hybrid_search_blocks(
*,
query: str,
include_blocks: bool,
include_integrations: bool,
) -> tuple[list[_ScoredItem], int, int]:
"""
Search blocks using hybrid search with builder-specific filtering.
Uses unified_hybrid_search for semantic + lexical search, then applies
post-filtering for block/integration types and scoring adjustments.
Scoring:
- Base: hybrid relevance score (0-1) scaled to 0-100, plus BLOCK_SCORE_BOOST
to prioritize blocks over marketplace agents in combined results
- +30 for exact name match, +15 for prefix name match
- +20 if the block has an LlmModel field and the query matches an LLM model name
Args:
query: The search query string
include_blocks: Whether to include regular blocks
include_integrations: Whether to include integration blocks
Returns:
Tuple of (scored_items, block_count, integration_count)
"""
results: list[_ScoredItem] = []
block_count = 0
integration_count = 0
if not include_blocks and not include_integrations:
return results, block_count, integration_count
normalized_query = query.strip().lower()
# Fetch more results to account for post-filtering
search_results, _ = await unified_hybrid_search(
query=query,
content_types=[ContentType.BLOCK],
page=1,
page_size=150,
min_score=0.10,
)
# Load all blocks for getting BlockInfo
all_blocks = load_all_blocks()
for result in search_results:
block_id = result["content_id"]
# Skip excluded blocks
if block_id in EXCLUDED_BLOCK_IDS:
continue
metadata = result.get("metadata", {})
hybrid_score = result.get("relevance", 0.0)
# Get the actual block class
if block_id not in all_blocks:
continue
block_cls = all_blocks[block_id]
block: AnyBlockSchema = block_cls()
if block.disabled:
continue
# Check block/integration filter using metadata
is_integration = metadata.get("is_integration", False)
if is_integration and not include_integrations:
continue
if not is_integration and not include_blocks:
continue
# Get block info
block_info = block.get_info()
# Calculate final score: scale hybrid score and add builder-specific bonuses
# Hybrid scores are 0-1, builder scores were 0-200+
# Add BLOCK_SCORE_BOOST to prioritize blocks over marketplace agents
final_score = hybrid_score * 100 + BLOCK_SCORE_BOOST
# Add LLM model match bonus
has_llm_field = metadata.get("has_llm_model_field", False)
if has_llm_field and _matches_llm_model(block.input_schema, normalized_query):
final_score += 20
# Add exact/prefix match bonus for deterministic tie-breaking
name = block_info.name.lower()
if name == normalized_query:
final_score += 30
elif name.startswith(normalized_query):
final_score += 15
# Track counts
filter_type: FilterType = "integrations" if is_integration else "blocks"
if is_integration:
integration_count += 1
else:
block_count += 1
results.append(
_ScoredItem(
item=block_info,
filter_type=filter_type,
score=final_score,
sort_key=name,
)
)
@@ -472,6 +615,8 @@ async def _get_static_counts():
block: AnyBlockSchema = block_type()
if block.disabled:
continue
if block.id in EXCLUDED_BLOCK_IDS:
continue
all_blocks += 1
@@ -498,47 +643,25 @@ async def _get_static_counts():
}
def _contains_type(annotation: Any, target: type) -> bool:
"""Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
if annotation is target:
return True
origin = get_origin(annotation)
if origin is None:
return False
return any(_contains_type(arg, target) for arg in get_args(annotation))
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
for field in schema_cls.model_fields.values():
if field.annotation == LlmModel:
if _contains_type(field.annotation, LlmModel):
# Check if query matches any value in llm_models
if any(query in name for name in llm_models):
return True
return False
def _score_block(
block: AnyBlockSchema,
block_info: BlockInfo,
normalized_query: str,
) -> float:
if not normalized_query:
return 0.0
name = block_info.name.lower()
description = block_info.description.lower()
score = _score_primary_fields(name, description, normalized_query)
category_text = " ".join(
category.get("category", "").lower() for category in block_info.categories
)
score += _score_additional_field(category_text, normalized_query, 12, 6)
credentials_info = block.input_schema.get_credentials_fields_info().values()
provider_names = [
provider.value.lower()
for info in credentials_info
for provider in info.provider
]
provider_text = " ".join(provider_names)
score += _score_additional_field(provider_text, normalized_query, 15, 6)
if _matches_llm_model(block.input_schema, normalized_query):
score += 20
return score
def _score_library_agent(
agent: library_model.LibraryAgent,
normalized_query: str,
@@ -645,31 +768,20 @@ def _get_all_providers() -> dict[ProviderName, Provider]:
return providers
@cached(ttl_seconds=3600)
@cached(ttl_seconds=3600, shared_cache=True)
async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
suggested_blocks = []
# Sum the number of executions for each block type
# Prisma cannot group by nested relations, so we do a raw query
# Calculate the cutoff timestamp
timestamp_threshold = datetime.now(timezone.utc) - timedelta(days=30)
"""Return the most-executed blocks from the last 14 days.
results = await query_raw_with_schema(
"""
SELECT
agent_node."agentBlockId" AS block_id,
COUNT(execution.id) AS execution_count
FROM {schema_prefix}"AgentNodeExecution" execution
JOIN {schema_prefix}"AgentNode" agent_node ON execution."agentNodeId" = agent_node.id
WHERE execution."endedTime" >= $1::timestamp
GROUP BY agent_node."agentBlockId"
ORDER BY execution_count DESC;
""",
timestamp_threshold,
)
Queries the mv_suggested_blocks materialized view (refreshed hourly via pg_cron)
and returns the top `count` blocks sorted by execution count, excluding
Input/Output/Agent block types and blocks in EXCLUDED_BLOCK_IDS.
"""
results = await mv_suggested_blocks.prisma().find_many()
# Get the top blocks based on execution count
# But ignore Input and Output blocks
# But ignore Input, Output, Agent, and excluded blocks
blocks: list[tuple[BlockInfo, int]] = []
execution_counts = {row.block_id: row.execution_count for row in results}
for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type()
@@ -679,11 +791,9 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
BlockType.AGENT,
):
continue
# Find the execution count for this block
execution_count = next(
(row["execution_count"] for row in results if row["block_id"] == block.id),
0,
)
if block.id in EXCLUDED_BLOCK_IDS:
continue
execution_count = execution_counts.get(block.id, 0)
blocks.append((block.get_info(), execution_count))
# Sort blocks by execution count
blocks.sort(key=lambda x: x[1], reverse=True)

View File

@@ -27,7 +27,6 @@ class SearchEntry(BaseModel):
# Suggestions
class SuggestionsResponse(BaseModel):
otto_suggestions: list[str]
recent_searches: list[SearchEntry]
providers: list[ProviderName]
top_blocks: list[BlockInfo]

View File

@@ -1,5 +1,5 @@
import logging
from typing import Annotated, Sequence
from typing import Annotated, Sequence, cast, get_args
import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
@@ -10,6 +10,8 @@ from backend.util.models import Pagination
from . import db as builder_db
from . import model as builder_model
VALID_FILTER_VALUES = get_args(builder_model.FilterType)
logger = logging.getLogger(__name__)
router = fastapi.APIRouter(
@@ -49,11 +51,6 @@ async def get_suggestions(
Get all suggestions for the Blocks Menu.
"""
return builder_model.SuggestionsResponse(
otto_suggestions=[
"What blocks do I need to get started?",
"Help me create a list",
"Help me feed my data to Google Maps",
],
recent_searches=await builder_db.get_recent_searches(user_id),
providers=[
ProviderName.TWITTER,
@@ -151,7 +148,7 @@ async def get_providers(
async def search(
user_id: Annotated[str, fastapi.Security(get_user_id)],
search_query: Annotated[str | None, fastapi.Query()] = None,
filter: Annotated[list[builder_model.FilterType] | None, fastapi.Query()] = None,
filter: Annotated[str | None, fastapi.Query()] = None,
search_id: Annotated[str | None, fastapi.Query()] = None,
by_creator: Annotated[list[str] | None, fastapi.Query()] = None,
page: Annotated[int, fastapi.Query()] = 1,
@@ -160,9 +157,20 @@ async def search(
"""
Search for blocks (including integrations), marketplace agents, and user library agents.
"""
# If no filters are provided, then we will return all types
if not filter:
filter = [
# Parse and validate filter parameter
filters: list[builder_model.FilterType]
if filter:
filter_values = [f.strip() for f in filter.split(",")]
invalid_filters = [f for f in filter_values if f not in VALID_FILTER_VALUES]
if invalid_filters:
raise fastapi.HTTPException(
status_code=400,
detail=f"Invalid filter value(s): {', '.join(invalid_filters)}. "
f"Valid values are: {', '.join(VALID_FILTER_VALUES)}",
)
filters = cast(list[builder_model.FilterType], filter_values)
else:
filters = [
"blocks",
"integrations",
"marketplace_agents",
@@ -174,7 +182,7 @@ async def search(
cached_results = await builder_db.get_sorted_search_results(
user_id=user_id,
search_query=search_query,
filters=filter,
filters=filters,
by_creator=by_creator,
)
@@ -196,7 +204,7 @@ async def search(
user_id,
builder_model.SearchEntry(
search_query=search_query,
filter=filter,
filter=filters,
by_creator=by_creator,
search_id=search_id,
),

View File

@@ -2,6 +2,7 @@
import asyncio
import logging
import re
from collections.abc import AsyncGenerator
from typing import Annotated
from uuid import uuid4
@@ -9,7 +10,8 @@ from uuid import uuid4
from autogpt_libs import auth
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, Field
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
@@ -40,6 +42,8 @@ from backend.copilot.tools.models import (
ErrorResponse,
ExecutionStartedResponse,
InputValidationErrorResponse,
MCPToolOutputResponse,
MCPToolsDiscoveredResponse,
NeedLoginResponse,
NoResultsResponse,
SetupRequirementsResponse,
@@ -47,10 +51,14 @@ from backend.copilot.tools.models import (
UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
from backend.data.workspace import get_or_create_workspace
from backend.util.exceptions import NotFoundError
config = ChatConfig()
_UUID_RE = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
)
logger = logging.getLogger(__name__)
@@ -79,6 +87,9 @@ class StreamChatRequest(BaseModel):
message: str
is_user_message: bool = True
context: dict[str, str] | None = None # {url: str, content: str}
file_ids: list[str] | None = Field(
default=None, max_length=20
) # Workspace file IDs attached to this message
class CreateSessionResponse(BaseModel):
@@ -238,6 +249,18 @@ async def delete_session(
detail=f"Session {session_id} not found or access denied",
)
# Best-effort cleanup of the E2B sandbox (if any).
config = ChatConfig()
if config.use_e2b_sandbox and config.e2b_api_key:
from backend.copilot.tools.e2b_sandbox import kill_sandbox
try:
await kill_sandbox(session_id, config.e2b_api_key)
except Exception:
logger.warning(
"[E2B] Failed to kill sandbox for session %s", session_id[:12]
)
return Response(status_code=204)
@@ -394,6 +417,38 @@ async def stream_chat_post(
},
)
# Enrich message with file metadata if file_ids are provided.
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
sanitized_file_ids: list[str] | None = None
if request.file_ids and user_id:
# Filter to valid UUIDs only to prevent DB abuse
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
if valid_ids:
workspace = await get_or_create_workspace(user_id)
# Batch query instead of N+1
files = await UserWorkspaceFile.prisma().find_many(
where={
"id": {"in": valid_ids},
"workspaceId": workspace.id,
"isDeleted": False,
}
)
# Only keep IDs that actually exist in the user's workspace
sanitized_file_ids = [wf.id for wf in files] or None
file_lines: list[str] = [
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
for wf in files
]
if file_lines:
files_block = (
"\n\n[Attached files]\n"
+ "\n".join(file_lines)
+ "\nUse read_workspace_file with the file_id to access file contents."
)
request.message += files_block
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message re-fetches inside a lock to prevent
@@ -445,6 +500,7 @@ async def stream_chat_post(
turn_id=turn_id,
is_user_message=request.is_user_message,
context=request.context,
file_ids=sanitized_file_ids,
)
setup_time = (time.perf_counter() - stream_start_time) * 1000
@@ -487,7 +543,7 @@ async def stream_chat_post(
)
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
chunks_yielded += 1
if not first_chunk_yielded:
@@ -621,11 +677,10 @@ async def resume_session_stream(
if not active_session:
return Response(status_code=204)
# Subscribe from the beginning ("0-0") to replay all chunks for this turn.
# This is necessary because hydrated messages filter out incomplete tool calls
# to avoid "No tool invocation found" errors. The resume stream delivers
# those tool calls fresh with proper SDK state.
# The AI SDK's deduplication will handle any duplicate chunks.
# Always replay from the beginning ("0-0") on resume.
# We can't use last_message_id because it's the latest ID in the backend
# stream, not the latest the frontend received — the gap causes lost
# messages. The frontend deduplicates replayed content.
subscriber_queue = await stream_registry.subscribe_to_session(
session_id=session_id,
user_id=user_id,
@@ -641,7 +696,7 @@ async def resume_session_stream(
try:
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
if chunk_count < 3:
logger.info(
"Resume stream chunk",
@@ -801,6 +856,8 @@ ToolResponseUnion = (
| BlockOutputResponse
| DocSearchResultsResponse
| DocPageResponse
| MCPToolsDiscoveredResponse
| MCPToolOutputResponse
)

View File

@@ -0,0 +1,160 @@
"""Tests for chat route file_ids validation and enrichment."""
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from backend.api.features.chat import routes as chat_routes
app = fastapi.FastAPI()
app.include_router(chat_routes.router)
client = fastapi.testclient.TestClient(app)
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
# ---- file_ids Pydantic validation (B1) ----
def test_stream_chat_rejects_too_many_file_ids():
"""More than 20 file_ids should be rejected by Pydantic validation (422)."""
response = client.post(
"/sessions/sess-1/stream",
json={
"message": "hello",
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(21)],
},
)
assert response.status_code == 422
def _mock_stream_internals(mocker: pytest_mock.MockFixture):
"""Mock the async internals of stream_chat_post so tests can exercise
validation and enrichment logic without needing Redis/RabbitMQ."""
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = mocker.AsyncMock(return_value=None)
mocker.patch(
"backend.api.features.chat.routes.stream_registry",
mock_registry,
)
mocker.patch(
"backend.api.features.chat.routes.enqueue_copilot_turn",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
"""Exactly 20 file_ids should be accepted (not rejected by validation)."""
_mock_stream_internals(mocker)
# Patch workspace lookup as imported by the routes module
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
return_value=type("W", (), {"id": "ws-1"})(),
)
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
mocker.patch(
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
response = client.post(
"/sessions/sess-1/stream",
json={
"message": "hello",
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(20)],
},
)
# Should get past validation — 200 streaming response expected
assert response.status_code == 200
# ---- UUID format filtering ----
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
"""Non-UUID strings in file_ids should be silently filtered out
and NOT passed to the database query."""
_mock_stream_internals(mocker)
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
return_value=type("W", (), {"id": "ws-1"})(),
)
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
mocker.patch(
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
valid_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
client.post(
"/sessions/sess-1/stream",
json={
"message": "hello",
"file_ids": [
valid_id,
"not-a-uuid",
"../../../etc/passwd",
"",
],
},
)
# The find_many call should only receive the one valid UUID
mock_prisma.find_many.assert_called_once()
call_kwargs = mock_prisma.find_many.call_args[1]
assert call_kwargs["where"]["id"]["in"] == [valid_id]
# ---- Cross-workspace file_ids ----
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
"""The batch query should scope to the user's workspace."""
_mock_stream_internals(mocker)
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
return_value=type("W", (), {"id": "my-workspace-id"})(),
)
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
mocker.patch(
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
client.post(
"/sessions/sess-1/stream",
json={"message": "hi", "file_ids": [fid]},
)
call_kwargs = mock_prisma.find_many.call_args[1]
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
assert call_kwargs["where"]["isDeleted"] is False

View File

@@ -22,6 +22,7 @@ from backend.data.human_review import (
)
from backend.data.model import USER_TIMEZONE_NOT_SET
from backend.data.user import get_user_by_id
from backend.data.workspace import get_or_create_workspace
from backend.executor.utils import add_graph_execution
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
@@ -321,10 +322,13 @@ async def process_review_action(
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
)
workspace = await get_or_create_workspace(user_id)
execution_context = ExecutionContext(
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
user_timezone=user_timezone,
workspace_id=workspace.id,
)
await add_graph_execution(

File diff suppressed because it is too large Load Diff

View File

@@ -144,6 +144,7 @@ async def test_add_agent_to_library(mocker):
)
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.create = mocker.AsyncMock(
return_value=mock_library_agent_data
@@ -178,7 +179,6 @@ async def test_add_agent_to_library(mocker):
"agentGraphVersion": 1,
}
},
include={"AgentGraph": True},
)
# Check that create was called with the expected data including settings
create_call_args = mock_library_agent.return_value.create.call_args

View File

@@ -0,0 +1,10 @@
class FolderValidationError(Exception):
"""Raised when folder operations fail validation."""
pass
class FolderAlreadyExistsError(FolderValidationError):
"""Raised when a folder with the same name already exists in the location."""
pass

View File

@@ -26,6 +26,95 @@ class LibraryAgentStatus(str, Enum):
ERROR = "ERROR"
# === Folder Models ===
class LibraryFolder(pydantic.BaseModel):
"""Represents a folder for organizing library agents."""
id: str
user_id: str
name: str
icon: str | None = None
color: str | None = None
parent_id: str | None = None
created_at: datetime.datetime
updated_at: datetime.datetime
agent_count: int = 0 # Direct agents in folder
subfolder_count: int = 0 # Direct child folders
@staticmethod
def from_db(
folder: prisma.models.LibraryFolder,
agent_count: int = 0,
subfolder_count: int = 0,
) -> "LibraryFolder":
"""Factory method that constructs a LibraryFolder from a Prisma model."""
return LibraryFolder(
id=folder.id,
user_id=folder.userId,
name=folder.name,
icon=folder.icon,
color=folder.color,
parent_id=folder.parentId,
created_at=folder.createdAt,
updated_at=folder.updatedAt,
agent_count=agent_count,
subfolder_count=subfolder_count,
)
class LibraryFolderTree(LibraryFolder):
"""Folder with nested children for tree view."""
children: list["LibraryFolderTree"] = []
class FolderCreateRequest(pydantic.BaseModel):
"""Request model for creating a folder."""
name: str = pydantic.Field(..., min_length=1, max_length=100)
icon: str | None = None
color: str | None = pydantic.Field(
None, pattern=r"^#[0-9A-Fa-f]{6}$", description="Hex color code (#RRGGBB)"
)
parent_id: str | None = None
class FolderUpdateRequest(pydantic.BaseModel):
"""Request model for updating a folder."""
name: str | None = pydantic.Field(None, min_length=1, max_length=100)
icon: str | None = None
color: str | None = None
class FolderMoveRequest(pydantic.BaseModel):
"""Request model for moving a folder to a new parent."""
target_parent_id: str | None = None # None = move to root
class BulkMoveAgentsRequest(pydantic.BaseModel):
"""Request model for moving multiple agents to a folder."""
agent_ids: list[str]
folder_id: str | None = None # None = move to root
class FolderListResponse(pydantic.BaseModel):
"""Response schema for a list of folders."""
folders: list[LibraryFolder]
pagination: Pagination
class FolderTreeResponse(pydantic.BaseModel):
"""Response schema for folder tree structure."""
tree: list[LibraryFolderTree]
class MarketplaceListingCreator(pydantic.BaseModel):
"""Creator information for a marketplace listing."""
@@ -120,6 +209,9 @@ class LibraryAgent(pydantic.BaseModel):
can_access_graph: bool
is_latest_version: bool
is_favorite: bool
folder_id: str | None = None
folder_name: str | None = None # Denormalized for display
recommended_schedule_cron: str | None = None
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
marketplace_listing: Optional["MarketplaceListing"] = None
@@ -259,6 +351,8 @@ class LibraryAgent(pydantic.BaseModel):
can_access_graph=can_access_graph,
is_latest_version=is_latest_version,
is_favorite=agent.isFavorite,
folder_id=agent.folderId,
folder_name=agent.Folder.name if agent.Folder else None,
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
settings=_parse_settings(agent.settings),
marketplace_listing=marketplace_listing_data,
@@ -470,3 +564,7 @@ class LibraryAgentUpdateRequest(pydantic.BaseModel):
settings: Optional[GraphSettings] = pydantic.Field(
default=None, description="User-specific settings for this library agent"
)
folder_id: Optional[str] = pydantic.Field(
default=None,
description="Folder ID to move agent to (None to move to root)",
)

View File

@@ -1,9 +1,11 @@
import fastapi
from .agents import router as agents_router
from .folders import router as folders_router
from .presets import router as presets_router
router = fastapi.APIRouter()
router.include_router(presets_router)
router.include_router(folders_router)
router.include_router(agents_router)

View File

@@ -41,6 +41,14 @@ async def list_library_agents(
ge=1,
description="Number of agents per page (must be >= 1)",
),
folder_id: Optional[str] = Query(
None,
description="Filter by folder ID",
),
include_root_only: bool = Query(
False,
description="Only return agents without a folder (root-level agents)",
),
) -> library_model.LibraryAgentResponse:
"""
Get all agents in the user's library (both created and saved).
@@ -51,6 +59,8 @@ async def list_library_agents(
sort_by=sort_by,
page=page,
page_size=page_size,
folder_id=folder_id,
include_root_only=include_root_only,
)
@@ -168,6 +178,7 @@ async def update_library_agent(
is_favorite=payload.is_favorite,
is_archived=payload.is_archived,
settings=payload.settings,
folder_id=payload.folder_id,
)

View File

@@ -0,0 +1,287 @@
from typing import Optional
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Query, Security, status
from fastapi.responses import Response
from .. import db as library_db
from .. import model as library_model
router = APIRouter(
prefix="/folders",
tags=["library", "folders", "private"],
dependencies=[Security(autogpt_auth_lib.requires_user)],
)
@router.get(
"",
summary="List Library Folders",
response_model=library_model.FolderListResponse,
responses={
200: {"description": "List of folders"},
500: {"description": "Server error"},
},
)
async def list_folders(
user_id: str = Security(autogpt_auth_lib.get_user_id),
parent_id: Optional[str] = Query(
None,
description="Filter by parent folder ID. If not provided, returns root-level folders.",
),
include_relations: bool = Query(
True,
description="Include agent and subfolder relations (for counts)",
),
) -> library_model.FolderListResponse:
"""
List folders for the authenticated user.
Args:
user_id: ID of the authenticated user.
parent_id: Optional parent folder ID to filter by.
include_relations: Whether to include agent and subfolder relations for counts.
Returns:
A FolderListResponse containing folders.
"""
folders = await library_db.list_folders(
user_id=user_id,
parent_id=parent_id,
include_relations=include_relations,
)
return library_model.FolderListResponse(
folders=folders,
pagination=library_model.Pagination(
total_items=len(folders),
total_pages=1,
current_page=1,
page_size=len(folders),
),
)
@router.get(
"/tree",
summary="Get Folder Tree",
response_model=library_model.FolderTreeResponse,
responses={
200: {"description": "Folder tree structure"},
500: {"description": "Server error"},
},
)
async def get_folder_tree(
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.FolderTreeResponse:
"""
Get the full folder tree for the authenticated user.
Args:
user_id: ID of the authenticated user.
Returns:
A FolderTreeResponse containing the nested folder structure.
"""
tree = await library_db.get_folder_tree(user_id=user_id)
return library_model.FolderTreeResponse(tree=tree)
@router.get(
"/{folder_id}",
summary="Get Folder",
response_model=library_model.LibraryFolder,
responses={
200: {"description": "Folder details"},
404: {"description": "Folder not found"},
500: {"description": "Server error"},
},
)
async def get_folder(
folder_id: str,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryFolder:
"""
Get a specific folder.
Args:
folder_id: ID of the folder to retrieve.
user_id: ID of the authenticated user.
Returns:
The requested LibraryFolder.
"""
return await library_db.get_folder(folder_id=folder_id, user_id=user_id)
@router.post(
"",
summary="Create Folder",
status_code=status.HTTP_201_CREATED,
response_model=library_model.LibraryFolder,
responses={
201: {"description": "Folder created successfully"},
400: {"description": "Validation error"},
404: {"description": "Parent folder not found"},
409: {"description": "Folder name conflict"},
500: {"description": "Server error"},
},
)
async def create_folder(
payload: library_model.FolderCreateRequest,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryFolder:
"""
Create a new folder.
Args:
payload: The folder creation request.
user_id: ID of the authenticated user.
Returns:
The created LibraryFolder.
"""
return await library_db.create_folder(
user_id=user_id,
name=payload.name,
parent_id=payload.parent_id,
icon=payload.icon,
color=payload.color,
)
@router.patch(
"/{folder_id}",
summary="Update Folder",
response_model=library_model.LibraryFolder,
responses={
200: {"description": "Folder updated successfully"},
400: {"description": "Validation error"},
404: {"description": "Folder not found"},
409: {"description": "Folder name conflict"},
500: {"description": "Server error"},
},
)
async def update_folder(
folder_id: str,
payload: library_model.FolderUpdateRequest,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryFolder:
"""
Update a folder's properties.
Args:
folder_id: ID of the folder to update.
payload: The folder update request.
user_id: ID of the authenticated user.
Returns:
The updated LibraryFolder.
"""
return await library_db.update_folder(
folder_id=folder_id,
user_id=user_id,
name=payload.name,
icon=payload.icon,
color=payload.color,
)
@router.post(
"/{folder_id}/move",
summary="Move Folder",
response_model=library_model.LibraryFolder,
responses={
200: {"description": "Folder moved successfully"},
400: {"description": "Validation error (circular reference)"},
404: {"description": "Folder or target parent not found"},
409: {"description": "Folder name conflict in target location"},
500: {"description": "Server error"},
},
)
async def move_folder(
folder_id: str,
payload: library_model.FolderMoveRequest,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryFolder:
"""
Move a folder to a new parent.
Args:
folder_id: ID of the folder to move.
payload: The move request with target parent.
user_id: ID of the authenticated user.
Returns:
The moved LibraryFolder.
"""
return await library_db.move_folder(
folder_id=folder_id,
user_id=user_id,
target_parent_id=payload.target_parent_id,
)
@router.delete(
"/{folder_id}",
summary="Delete Folder",
status_code=status.HTTP_204_NO_CONTENT,
responses={
204: {"description": "Folder deleted successfully"},
404: {"description": "Folder not found"},
500: {"description": "Server error"},
},
)
async def delete_folder(
folder_id: str,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> Response:
"""
Soft-delete a folder and all its contents.
Args:
folder_id: ID of the folder to delete.
user_id: ID of the authenticated user.
Returns:
204 No Content if successful.
"""
await library_db.delete_folder(
folder_id=folder_id,
user_id=user_id,
soft_delete=True,
)
return Response(status_code=status.HTTP_204_NO_CONTENT)
# === Bulk Agent Operations ===
@router.post(
"/agents/bulk-move",
summary="Bulk Move Agents",
response_model=list[library_model.LibraryAgent],
responses={
200: {"description": "Agents moved successfully"},
404: {"description": "Folder not found"},
500: {"description": "Server error"},
},
)
async def bulk_move_agents(
payload: library_model.BulkMoveAgentsRequest,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> list[library_model.LibraryAgent]:
"""
Move multiple agents to a folder.
Args:
payload: The bulk move request with agent IDs and target folder.
user_id: ID of the authenticated user.
Returns:
The updated LibraryAgents.
"""
return await library_db.bulk_move_agents_to_folder(
agent_ids=payload.agent_ids,
folder_id=payload.folder_id,
user_id=user_id,
)

View File

@@ -115,6 +115,8 @@ async def test_get_library_agents_success(
sort_by=library_model.LibraryAgentSort.UPDATED_AT,
page=1,
page_size=15,
folder_id=None,
include_root_only=False,
)

View File

@@ -7,20 +7,24 @@ frontend can list available tools on an MCP server before placing a block.
import logging
from typing import Annotated, Any
from urllib.parse import urlparse
import fastapi
from autogpt_libs.auth import get_user_id
from fastapi import Security
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, SecretStr
from backend.api.features.integrations.router import CredentialsMetaResponse
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.blocks.mcp.helpers import (
auto_lookup_mcp_credential,
normalize_mcp_url,
server_host,
)
from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests
from backend.util.request import HTTPClientError, Requests, validate_url
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -74,32 +78,20 @@ async def discover_tools(
If the user has a stored MCP credential for this server URL, it will be
used automatically — no need to pass an explicit auth token.
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
auth_token = request.auth_token
# Auto-use stored MCP credential when no explicit token is provided.
if not auth_token:
mcp_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
best_cred = await auto_lookup_mcp_credential(
user_id, normalize_mcp_url(request.server_url)
)
# Find the freshest credential for this server URL
best_cred: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and (cred.metadata or {}).get("mcp_server_url") == request.server_url
):
if best_cred is None or (
(cred.access_token_expires_at or 0)
> (best_cred.access_token_expires_at or 0)
):
best_cred = cred
if best_cred:
# Refresh the token if expired before using it
best_cred = await creds_manager.refresh_if_needed(user_id, best_cred)
logger.info(
f"Using MCP credential {best_cred.id} for {request.server_url}, "
f"expires_at={best_cred.access_token_expires_at}"
)
auth_token = best_cred.access_token.get_secret_value()
client = MCPClient(request.server_url, auth_token=auth_token)
@@ -134,7 +126,7 @@ async def discover_tools(
],
server_name=(
init_result.get("serverInfo", {}).get("name")
or urlparse(request.server_url).hostname
or server_host(request.server_url)
or "MCP"
),
protocol_version=init_result.get("protocolVersion"),
@@ -173,7 +165,16 @@ async def mcp_oauth_login(
3. Performs Dynamic Client Registration (RFC 7591) if available
4. Returns the authorization URL for the frontend to open in a popup
"""
client = MCPClient(request.server_url)
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
# Normalize the URL so that credentials stored here are matched consistently
# by auto_lookup_mcp_credential (which also uses normalized URLs).
server_url = normalize_mcp_url(request.server_url)
client = MCPClient(server_url)
# Step 1: Discover protected-resource metadata (RFC 9728)
protected_resource = await client.discover_auth()
@@ -182,7 +183,16 @@ async def mcp_oauth_login(
if protected_resource and protected_resource.get("authorization_servers"):
auth_server_url = protected_resource["authorization_servers"][0]
resource_url = protected_resource.get("resource", request.server_url)
resource_url = protected_resource.get("resource", server_url)
# Validate the auth server URL from metadata to prevent SSRF.
try:
await validate_url(auth_server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(
status_code=400,
detail=f"Invalid authorization server URL in metadata: {e}",
)
# Step 2a: Discover auth-server metadata (RFC 8414)
metadata = await client.discover_auth_server_metadata(auth_server_url)
@@ -192,7 +202,7 @@ async def mcp_oauth_login(
# Don't assume a resource_url — omitting it lets the auth server choose
# the correct audience for the token (RFC 8707 resource is optional).
resource_url = None
metadata = await client.discover_auth_server_metadata(request.server_url)
metadata = await client.discover_auth_server_metadata(server_url)
if (
not metadata
@@ -222,12 +232,18 @@ async def mcp_oauth_login(
client_id = ""
client_secret = ""
if registration_endpoint:
reg_result = await _register_mcp_client(
registration_endpoint, redirect_uri, request.server_url
)
if reg_result:
client_id = reg_result.get("client_id", "")
client_secret = reg_result.get("client_secret", "")
# Validate the registration endpoint to prevent SSRF via metadata.
try:
await validate_url(registration_endpoint, trusted_origins=[])
except ValueError:
pass # Skip registration, fall back to default client_id
else:
reg_result = await _register_mcp_client(
registration_endpoint, redirect_uri, server_url
)
if reg_result:
client_id = reg_result.get("client_id", "")
client_secret = reg_result.get("client_secret", "")
if not client_id:
client_id = "autogpt-platform"
@@ -245,7 +261,7 @@ async def mcp_oauth_login(
"token_url": token_url,
"revoke_url": revoke_url,
"resource_url": resource_url,
"server_url": request.server_url,
"server_url": server_url,
"client_id": client_id,
"client_secret": client_secret,
},
@@ -342,7 +358,7 @@ async def mcp_oauth_callback(
credentials.metadata["mcp_token_url"] = meta["token_url"]
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
hostname = server_host(meta["server_url"])
credentials.title = f"MCP: {hostname}"
# Remove old MCP credentials for the same server to prevent stale token buildup.
@@ -357,7 +373,9 @@ async def mcp_oauth_callback(
):
await creds_manager.store.delete_creds_by_id(user_id, old.id)
logger.info(
f"Removed old MCP credential {old.id} for {meta['server_url']}"
"Removed old MCP credential %s for %s",
old.id,
server_host(meta["server_url"]),
)
except Exception:
logger.debug("Could not clean up old MCP credentials", exc_info=True)
@@ -375,6 +393,93 @@ async def mcp_oauth_callback(
)
# ======================== Bearer Token ======================== #
class MCPStoreTokenRequest(BaseModel):
"""Request to store a bearer token for an MCP server that doesn't support OAuth."""
server_url: str = Field(
description="MCP server URL the token authenticates against"
)
token: SecretStr = Field(
min_length=1, description="Bearer token / API key for the MCP server"
)
@router.post(
"/token",
summary="Store a bearer token for an MCP server",
)
async def mcp_store_token(
request: MCPStoreTokenRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> CredentialsMetaResponse:
"""
Store a manually provided bearer token as an MCP credential.
Used by the Copilot MCPSetupCard when the server doesn't support the MCP
OAuth discovery flow (returns 400 from /oauth/login). Subsequent
``run_mcp_tool`` calls will automatically pick up the token via
``_auto_lookup_credential``.
"""
token = request.token.get_secret_value().strip()
if not token:
raise fastapi.HTTPException(status_code=422, detail="Token must not be blank.")
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
# Normalize URL so trailing-slash variants match existing credentials.
server_url = normalize_mcp_url(request.server_url)
hostname = server_host(server_url)
# Collect IDs of old credentials to clean up after successful create.
old_cred_ids: list[str] = []
try:
old_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
old_cred_ids = [
old.id
for old in old_creds
if isinstance(old, OAuth2Credentials)
and normalize_mcp_url((old.metadata or {}).get("mcp_server_url", ""))
== server_url
]
except Exception:
logger.debug("Could not query old MCP token credentials", exc_info=True)
credentials = OAuth2Credentials(
provider=ProviderName.MCP.value,
title=f"MCP: {hostname}",
access_token=SecretStr(token),
scopes=[],
metadata={"mcp_server_url": server_url},
)
await creds_manager.create(user_id, credentials)
# Only delete old credentials after the new one is safely stored.
for old_id in old_cred_ids:
try:
await creds_manager.store.delete_creds_by_id(user_id, old_id)
except Exception:
logger.debug("Could not clean up old MCP token credential", exc_info=True)
return CredentialsMetaResponse(
id=credentials.id,
provider=credentials.provider,
type=credentials.type,
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=hostname,
)
# ======================== Helpers ======================== #
@@ -400,5 +505,7 @@ async def _register_mcp_client(
return data
return None
except Exception as e:
logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
logger.warning(
"Dynamic client registration failed for %s: %s", server_host(server_url), e
)
return None

View File

@@ -11,9 +11,11 @@ import httpx
import pytest
import pytest_asyncio
from autogpt_libs.auth import get_user_id
from pydantic import SecretStr
from backend.api.features.mcp.routes import router
from backend.blocks.mcp.client import MCPClientError, MCPTool
from backend.data.model import OAuth2Credentials
from backend.util.request import HTTPClientError
app = fastapi.FastAPI()
@@ -28,6 +30,16 @@ async def client():
yield c
@pytest.fixture(autouse=True)
def _bypass_ssrf_validation():
"""Bypass validate_url in all route tests (test URLs don't resolve)."""
with patch(
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
):
yield
class TestDiscoverTools:
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_success(self, client):
@@ -56,9 +68,12 @@ class TestDiscoverTools:
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
),
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={
@@ -107,10 +122,6 @@ class TestDiscoverTools:
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_auto_uses_stored_credential(self, client):
"""When no explicit token is given, stored MCP credentials are used."""
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
stored_cred = OAuth2Credentials(
provider="mcp",
title="MCP: example.com",
@@ -124,10 +135,12 @@ class TestDiscoverTools:
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=stored_cred,
),
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred)
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
@@ -149,9 +162,12 @@ class TestDiscoverTools:
async def test_discover_tools_mcp_error(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
),
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=MCPClientError("Connection refused")
@@ -169,9 +185,12 @@ class TestDiscoverTools:
async def test_discover_tools_generic_error(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
),
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
@@ -187,9 +206,12 @@ class TestDiscoverTools:
async def test_discover_tools_auth_required(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
),
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
@@ -207,9 +229,12 @@ class TestDiscoverTools:
async def test_discover_tools_forbidden(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
),
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
@@ -331,10 +356,6 @@ class TestOAuthLogin:
class TestOAuthCallback:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_success(self, client):
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
mock_creds = OAuth2Credentials(
provider="mcp",
title=None,
@@ -434,3 +455,118 @@ class TestOAuthCallback:
assert response.status_code == 400
assert "token exchange failed" in response.json()["detail"].lower()
class TestStoreToken:
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_success(self, client):
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
mock_cm.create = AsyncMock()
response = await client.post(
"/token",
json={
"server_url": "https://mcp.example.com/mcp",
"token": "my-api-key-123",
},
)
assert response.status_code == 200
data = response.json()
assert data["provider"] == "mcp"
assert data["type"] == "oauth2"
assert data["host"] == "mcp.example.com"
mock_cm.create.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_blank_rejected(self, client):
"""Blank token string (after stripping) should return 422."""
response = await client.post(
"/token",
json={
"server_url": "https://mcp.example.com/mcp",
"token": " ",
},
)
# Pydantic min_length=1 catches the whitespace-only token
assert response.status_code == 422
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_replaces_old_credential(self, client):
old_cred = OAuth2Credentials(
provider="mcp",
title="MCP: mcp.example.com",
access_token=SecretStr("old-token"),
scopes=[],
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
)
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[old_cred])
mock_cm.create = AsyncMock()
mock_cm.store.delete_creds_by_id = AsyncMock()
response = await client.post(
"/token",
json={
"server_url": "https://mcp.example.com/mcp",
"token": "new-token",
},
)
assert response.status_code == 200
mock_cm.store.delete_creds_by_id.assert_called_once_with(
"test-user-id", old_cred.id
)
class TestSSRFValidation:
"""Verify that validate_url is enforced on all endpoints."""
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
response = await client.post(
"/discover-tools",
json={"server_url": "http://localhost/mcp"},
)
assert response.status_code == 400
assert "blocked loopback" in response.json()["detail"].lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
side_effect=ValueError("blocked private IP"),
):
response = await client.post(
"/oauth/login",
json={"server_url": "http://10.0.0.1/mcp"},
)
assert response.status_code == 400
assert "blocked private ip" in response.json()["detail"].lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
response = await client.post(
"/token",
json={
"server_url": "http://127.0.0.1/mcp",
"token": "some-token",
},
)
assert response.status_code == 400
assert "blocked loopback" in response.json()["detail"].lower()

View File

@@ -9,15 +9,26 @@ import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Any, get_args, get_origin
from prisma.enums import ContentType
from backend.blocks.llm import LlmModel
from backend.data.db import query_raw_with_schema
logger = logging.getLogger(__name__)
def _contains_type(annotation: Any, target: type) -> bool:
"""Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
if annotation is target:
return True
origin = get_origin(annotation)
if origin is None:
return False
return any(_contains_type(arg, target) for arg in get_args(annotation))
@dataclass
class ContentItem:
"""Represents a piece of content to be embedded."""
@@ -188,45 +199,51 @@ class BlockHandler(ContentHandler):
try:
block_instance = block_cls()
# Skip disabled blocks - they shouldn't be indexed
if block_instance.disabled:
continue
# Build searchable text from block metadata
parts = []
if hasattr(block_instance, "name") and block_instance.name:
if block_instance.name:
parts.append(block_instance.name)
if (
hasattr(block_instance, "description")
and block_instance.description
):
if block_instance.description:
parts.append(block_instance.description)
if hasattr(block_instance, "categories") and block_instance.categories:
# Convert BlockCategory enum to strings
if block_instance.categories:
parts.append(
" ".join(str(cat.value) for cat in block_instance.categories)
)
# Add input/output schema info
if hasattr(block_instance, "input_schema"):
schema = block_instance.input_schema
if hasattr(schema, "model_json_schema"):
schema_dict = schema.model_json_schema()
if "properties" in schema_dict:
for prop_name, prop_info in schema_dict[
"properties"
].items():
if "description" in prop_info:
parts.append(
f"{prop_name}: {prop_info['description']}"
)
# Add input schema field descriptions
block_input_fields = block_instance.input_schema.model_fields
parts += [
f"{field_name}: {field_info.description}"
for field_name, field_info in block_input_fields.items()
if field_info.description
]
searchable_text = " ".join(parts)
# Convert categories set of enums to list of strings for JSON serialization
categories = getattr(block_instance, "categories", set())
categories_list = (
[cat.value for cat in categories] if categories else []
[cat.value for cat in block_instance.categories]
if block_instance.categories
else []
)
# Extract provider names from credentials fields
credentials_info = (
block_instance.input_schema.get_credentials_fields_info()
)
is_integration = len(credentials_info) > 0
provider_names = [
provider.value.lower()
for info in credentials_info.values()
for provider in info.provider
]
# Check if block has LlmModel field in input schema
has_llm_model_field = any(
_contains_type(field.annotation, LlmModel)
for field in block_instance.input_schema.model_fields.values()
)
items.append(
@@ -235,8 +252,11 @@ class BlockHandler(ContentHandler):
content_type=ContentType.BLOCK,
searchable_text=searchable_text,
metadata={
"name": getattr(block_instance, "name", ""),
"name": block_instance.name,
"categories": categories_list,
"providers": provider_names,
"has_llm_model_field": has_llm_model_field,
"is_integration": is_integration,
},
user_id=None, # Blocks are public
)

View File

@@ -82,9 +82,10 @@ async def test_block_handler_get_missing_items(mocker):
mock_block_instance.description = "Performs calculations"
mock_block_instance.categories = [MagicMock(value="MATH")]
mock_block_instance.disabled = False
mock_block_instance.input_schema.model_json_schema.return_value = {
"properties": {"expression": {"description": "Math expression to evaluate"}}
}
mock_field = MagicMock()
mock_field.description = "Math expression to evaluate"
mock_block_instance.input_schema.model_fields = {"expression": mock_field}
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
mock_block_class.return_value = mock_block_instance
mock_blocks = {"block-uuid-1": mock_block_class}
@@ -309,19 +310,19 @@ async def test_content_handlers_registry():
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_handles_missing_attributes():
"""Test BlockHandler gracefully handles blocks with missing attributes."""
async def test_block_handler_handles_empty_attributes():
"""Test BlockHandler handles blocks with empty/falsy attribute values."""
handler = BlockHandler()
# Mock block with minimal attributes
# Mock block with empty values (all attributes exist but are falsy)
mock_block_class = MagicMock()
mock_block_instance = MagicMock()
mock_block_instance.name = "Minimal Block"
mock_block_instance.disabled = False
# No description, categories, or schema
del mock_block_instance.description
del mock_block_instance.categories
del mock_block_instance.input_schema
mock_block_instance.description = ""
mock_block_instance.categories = set()
mock_block_instance.input_schema.model_fields = {}
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
mock_block_class.return_value = mock_block_instance
mock_blocks = {"block-minimal": mock_block_class}
@@ -352,6 +353,8 @@ async def test_block_handler_skips_failed_blocks():
good_instance.description = "Works fine"
good_instance.categories = []
good_instance.disabled = False
good_instance.input_schema.model_fields = {}
good_instance.input_schema.get_credentials_fields_info.return_value = {}
good_block.return_value = good_instance
bad_block = MagicMock()

View File

@@ -22,6 +22,7 @@ from backend.data.notifications import (
AgentApprovalData,
AgentRejectionData,
NotificationEventModel,
WaitlistLaunchData,
)
from backend.notifications.notifications import queue_notification_async
from backend.util.exceptions import DatabaseError
@@ -1730,6 +1731,29 @@ async def review_store_submission(
# Don't fail the review process if email sending fails
pass
# Notify waitlist users if this is an approval and has a linked waitlist
if is_approved and submission.StoreListing:
try:
frontend_base_url = (
settings.config.frontend_base_url
or settings.config.platform_base_url
)
store_agent = (
await prisma.models.StoreAgent.prisma().find_first_or_raise(
where={"storeListingVersionId": submission.id}
)
)
creator_username = store_agent.creator_username or "unknown"
store_url = f"{frontend_base_url}/marketplace/agent/{creator_username}/{store_agent.slug}"
await notify_waitlist_users_on_launch(
store_listing_id=submission.StoreListing.id,
agent_name=submission.name,
store_url=store_url,
)
except Exception as e:
logger.error(f"Failed to notify waitlist users on agent approval: {e}")
# Don't fail the approval process
# Convert to Pydantic model for consistency
return store_model.StoreSubmission(
listing_id=(submission.StoreListing.id if submission.StoreListing else ""),
@@ -1977,3 +2001,557 @@ async def get_agent_as_admin(
)
return graph
def _waitlist_to_store_entry(
waitlist: prisma.models.WaitlistEntry,
) -> store_model.StoreWaitlistEntry:
"""Convert a WaitlistEntry to StoreWaitlistEntry for public display."""
return store_model.StoreWaitlistEntry(
waitlistId=waitlist.id,
slug=waitlist.slug,
name=waitlist.name,
subHeading=waitlist.subHeading,
videoUrl=waitlist.videoUrl,
agentOutputDemoUrl=waitlist.agentOutputDemoUrl,
imageUrls=waitlist.imageUrls or [],
description=waitlist.description,
categories=waitlist.categories or [],
)
async def get_waitlist() -> list[store_model.StoreWaitlistEntry]:
"""Get all active waitlists for public display."""
try:
waitlists = await prisma.models.WaitlistEntry.prisma().find_many(
where=prisma.types.WaitlistEntryWhereInput(isDeleted=False),
)
# Filter out closed/done waitlists and sort by votes (descending)
excluded_statuses = {
prisma.enums.WaitlistExternalStatus.CANCELED,
prisma.enums.WaitlistExternalStatus.DONE,
}
active_waitlists = [w for w in waitlists if w.status not in excluded_statuses]
sorted_list = sorted(active_waitlists, key=lambda x: x.votes, reverse=True)
return [_waitlist_to_store_entry(w) for w in sorted_list]
except Exception as e:
logger.error(f"Error fetching waitlists: {e}")
raise DatabaseError("Failed to fetch waitlists") from e
async def get_user_waitlist_memberships(user_id: str) -> list[str]:
"""Get all waitlist IDs that a user has joined."""
try:
user = await prisma.models.User.prisma().find_unique(
where={"id": user_id},
include={"JoinedWaitlists": True},
)
if not user or not user.JoinedWaitlists:
return []
return [w.id for w in user.JoinedWaitlists]
except Exception as e:
logger.error(f"Error fetching user waitlist memberships: {e}")
raise DatabaseError("Failed to fetch waitlist memberships") from e
async def add_user_to_waitlist(
waitlist_id: str, user_id: str | None, email: str | None
) -> store_model.StoreWaitlistEntry:
"""
Add a user to a waitlist.
For logged-in users: connects via joinedUsers relation
For anonymous users: adds email to unaffiliatedEmailUsers array
"""
if not user_id and not email:
raise ValueError("Either user_id or email must be provided")
try:
# Find the waitlist
waitlist = await prisma.models.WaitlistEntry.prisma().find_unique(
where={"id": waitlist_id},
include={"JoinedUsers": True},
)
if not waitlist:
raise ValueError(f"Waitlist {waitlist_id} not found")
if waitlist.isDeleted:
raise ValueError(f"Waitlist {waitlist_id} is no longer available")
if waitlist.status in [
prisma.enums.WaitlistExternalStatus.CANCELED,
prisma.enums.WaitlistExternalStatus.DONE,
]:
raise ValueError(f"Waitlist {waitlist_id} is closed")
if user_id:
# Check if user already joined
joined_user_ids = [u.id for u in (waitlist.JoinedUsers or [])]
if user_id in joined_user_ids:
# Already joined - return waitlist info
logger.debug(f"User {user_id} already joined waitlist {waitlist_id}")
else:
# Connect user to waitlist
await prisma.models.WaitlistEntry.prisma().update(
where={"id": waitlist_id},
data={"JoinedUsers": {"connect": [{"id": user_id}]}},
)
logger.info(f"User {user_id} joined waitlist {waitlist_id}")
# If user was previously in email list, remove them
# Use transaction to prevent race conditions
if email:
async with transaction() as tx:
current_waitlist = await tx.waitlistentry.find_unique(
where={"id": waitlist_id}
)
if current_waitlist and email in (
current_waitlist.unaffiliatedEmailUsers or []
):
updated_emails: list[str] = [
e
for e in (current_waitlist.unaffiliatedEmailUsers or [])
if e != email
]
await tx.waitlistentry.update(
where={"id": waitlist_id},
data={"unaffiliatedEmailUsers": updated_emails},
)
elif email:
# Add email to unaffiliated list if not already present
# Use transaction to prevent race conditions with concurrent signups
async with transaction() as tx:
# Re-fetch within transaction to get latest state
current_waitlist = await tx.waitlistentry.find_unique(
where={"id": waitlist_id}
)
if current_waitlist:
current_emails: list[str] = list(
current_waitlist.unaffiliatedEmailUsers or []
)
if email not in current_emails:
current_emails.append(email)
await tx.waitlistentry.update(
where={"id": waitlist_id},
data={"unaffiliatedEmailUsers": current_emails},
)
# Mask email for logging to avoid PII exposure
parts = email.split("@") if "@" in email else []
local = parts[0] if len(parts) > 0 else ""
domain = parts[1] if len(parts) > 1 else "unknown"
masked = (local[0] if local else "?") + "***@" + domain
logger.info(f"Email {masked} added to waitlist {waitlist_id}")
else:
logger.debug(f"Email already exists on waitlist {waitlist_id}")
# Re-fetch to return updated data
updated_waitlist = await prisma.models.WaitlistEntry.prisma().find_unique(
where={"id": waitlist_id}
)
return _waitlist_to_store_entry(updated_waitlist or waitlist)
except ValueError:
raise
except Exception as e:
logger.error(f"Error adding user to waitlist: {e}")
raise DatabaseError("Failed to add user to waitlist") from e
# ============== Admin Waitlist Functions ==============
def _waitlist_to_admin_response(
waitlist: prisma.models.WaitlistEntry,
) -> store_model.WaitlistAdminResponse:
"""Convert a WaitlistEntry to WaitlistAdminResponse."""
joined_count = len(waitlist.JoinedUsers) if waitlist.JoinedUsers else 0
email_count = (
len(waitlist.unaffiliatedEmailUsers) if waitlist.unaffiliatedEmailUsers else 0
)
return store_model.WaitlistAdminResponse(
id=waitlist.id,
createdAt=waitlist.createdAt.isoformat() if waitlist.createdAt else "",
updatedAt=waitlist.updatedAt.isoformat() if waitlist.updatedAt else "",
slug=waitlist.slug,
name=waitlist.name,
subHeading=waitlist.subHeading,
description=waitlist.description,
categories=waitlist.categories or [],
imageUrls=waitlist.imageUrls or [],
videoUrl=waitlist.videoUrl,
agentOutputDemoUrl=waitlist.agentOutputDemoUrl,
status=waitlist.status or prisma.enums.WaitlistExternalStatus.NOT_STARTED,
votes=waitlist.votes,
signupCount=joined_count + email_count,
storeListingId=waitlist.storeListingId,
owningUserId=waitlist.owningUserId,
)
async def create_waitlist_admin(
admin_user_id: str,
data: store_model.WaitlistCreateRequest,
) -> store_model.WaitlistAdminResponse:
"""Create a new waitlist (admin only)."""
logger.info(f"Admin {admin_user_id} creating waitlist: {data.name}")
try:
waitlist = await prisma.models.WaitlistEntry.prisma().create(
data=prisma.types.WaitlistEntryCreateInput(
name=data.name,
slug=data.slug,
subHeading=data.subHeading,
description=data.description,
categories=data.categories,
imageUrls=data.imageUrls,
videoUrl=data.videoUrl,
agentOutputDemoUrl=data.agentOutputDemoUrl,
owningUserId=admin_user_id,
status=prisma.enums.WaitlistExternalStatus.NOT_STARTED,
),
include={"JoinedUsers": True},
)
return _waitlist_to_admin_response(waitlist)
except Exception as e:
logger.error(f"Error creating waitlist: {e}")
raise DatabaseError("Failed to create waitlist") from e
async def get_waitlists_admin() -> store_model.WaitlistAdminListResponse:
"""Get all waitlists with admin details."""
try:
waitlists = await prisma.models.WaitlistEntry.prisma().find_many(
where=prisma.types.WaitlistEntryWhereInput(isDeleted=False),
include={"JoinedUsers": True},
order={"createdAt": "desc"},
)
return store_model.WaitlistAdminListResponse(
waitlists=[_waitlist_to_admin_response(w) for w in waitlists],
totalCount=len(waitlists),
)
except Exception as e:
logger.error(f"Error fetching waitlists for admin: {e}")
raise DatabaseError("Failed to fetch waitlists") from e
async def get_waitlist_admin(
waitlist_id: str,
) -> store_model.WaitlistAdminResponse:
"""Get a single waitlist with admin details."""
try:
waitlist = await prisma.models.WaitlistEntry.prisma().find_unique(
where={"id": waitlist_id},
include={"JoinedUsers": True},
)
if not waitlist:
raise ValueError(f"Waitlist {waitlist_id} not found")
if waitlist.isDeleted:
raise ValueError(f"Waitlist {waitlist_id} has been deleted")
return _waitlist_to_admin_response(waitlist)
except ValueError:
raise
except Exception as e:
logger.error(f"Error fetching waitlist {waitlist_id}: {e}")
raise DatabaseError("Failed to fetch waitlist") from e
async def update_waitlist_admin(
waitlist_id: str,
data: store_model.WaitlistUpdateRequest,
) -> store_model.WaitlistAdminResponse:
"""Update a waitlist (admin only)."""
logger.info(f"Updating waitlist {waitlist_id}")
try:
# Check if waitlist exists first
existing = await prisma.models.WaitlistEntry.prisma().find_unique(
where={"id": waitlist_id}
)
if not existing:
raise ValueError(f"Waitlist {waitlist_id} not found")
if existing.isDeleted:
raise ValueError(f"Waitlist {waitlist_id} has been deleted")
# Build update data from explicitly provided fields
# Use model_fields_set to allow clearing fields by setting them to None
field_mappings = {
"name": data.name,
"slug": data.slug,
"subHeading": data.subHeading,
"description": data.description,
"categories": data.categories,
"imageUrls": data.imageUrls,
"videoUrl": data.videoUrl,
"agentOutputDemoUrl": data.agentOutputDemoUrl,
"storeListingId": data.storeListingId,
}
update_data: dict[str, Any] = {
k: v for k, v in field_mappings.items() if k in data.model_fields_set
}
# Add status if provided (already validated as enum by Pydantic)
if "status" in data.model_fields_set and data.status is not None:
update_data["status"] = data.status
if not update_data:
# No updates, just return current data
return await get_waitlist_admin(waitlist_id)
waitlist = await prisma.models.WaitlistEntry.prisma().update(
where={"id": waitlist_id},
data=prisma.types.WaitlistEntryUpdateInput(**update_data),
include={"JoinedUsers": True},
)
# We already verified existence above, so this should never be None
assert waitlist is not None
return _waitlist_to_admin_response(waitlist)
except ValueError:
raise
except Exception as e:
logger.error(f"Error updating waitlist {waitlist_id}: {e}")
raise DatabaseError("Failed to update waitlist") from e
async def delete_waitlist_admin(waitlist_id: str) -> None:
"""Soft delete a waitlist (admin only)."""
logger.info(f"Soft deleting waitlist {waitlist_id}")
try:
# Check if waitlist exists first
waitlist = await prisma.models.WaitlistEntry.prisma().find_unique(
where={"id": waitlist_id},
)
if not waitlist:
raise ValueError(f"Waitlist {waitlist_id} not found")
if waitlist.isDeleted:
raise ValueError(f"Waitlist {waitlist_id} has already been deleted")
await prisma.models.WaitlistEntry.prisma().update(
where={"id": waitlist_id},
data={"isDeleted": True},
)
except ValueError:
raise
except Exception as e:
logger.error(f"Error deleting waitlist {waitlist_id}: {e}")
raise DatabaseError("Failed to delete waitlist") from e
async def get_waitlist_signups_admin(
waitlist_id: str,
) -> store_model.WaitlistSignupListResponse:
"""Get all signups for a waitlist (admin only)."""
try:
waitlist = await prisma.models.WaitlistEntry.prisma().find_unique(
where={"id": waitlist_id},
include={"JoinedUsers": True},
)
if not waitlist:
raise ValueError(f"Waitlist {waitlist_id} not found")
signups: list[store_model.WaitlistSignup] = []
# Add user signups
for user in waitlist.JoinedUsers or []:
signups.append(
store_model.WaitlistSignup(
type="user",
userId=user.id,
email=user.email,
username=user.name,
)
)
# Add email signups
for email in waitlist.unaffiliatedEmailUsers or []:
signups.append(
store_model.WaitlistSignup(
type="email",
email=email,
)
)
return store_model.WaitlistSignupListResponse(
waitlistId=waitlist_id,
signups=signups,
totalCount=len(signups),
)
except ValueError:
raise
except Exception as e:
logger.error(f"Error fetching signups for waitlist {waitlist_id}: {e}")
raise DatabaseError("Failed to fetch waitlist signups") from e
async def link_waitlist_to_listing_admin(
waitlist_id: str,
store_listing_id: str,
) -> store_model.WaitlistAdminResponse:
"""Link a waitlist to a store listing (admin only)."""
logger.info(f"Linking waitlist {waitlist_id} to listing {store_listing_id}")
try:
# Verify the waitlist exists
waitlist = await prisma.models.WaitlistEntry.prisma().find_unique(
where={"id": waitlist_id}
)
if not waitlist:
raise ValueError(f"Waitlist {waitlist_id} not found")
if waitlist.isDeleted:
raise ValueError(f"Waitlist {waitlist_id} has been deleted")
# Verify the store listing exists
listing = await prisma.models.StoreListing.prisma().find_unique(
where={"id": store_listing_id}
)
if not listing:
raise ValueError(f"Store listing {store_listing_id} not found")
updated_waitlist = await prisma.models.WaitlistEntry.prisma().update(
where={"id": waitlist_id},
data={"StoreListing": {"connect": {"id": store_listing_id}}},
include={"JoinedUsers": True},
)
# We already verified existence above, so this should never be None
assert updated_waitlist is not None
return _waitlist_to_admin_response(updated_waitlist)
except ValueError:
raise
except Exception as e:
logger.error(f"Error linking waitlist to listing: {e}")
raise DatabaseError("Failed to link waitlist to listing") from e
async def notify_waitlist_users_on_launch(
store_listing_id: str,
agent_name: str,
store_url: str,
) -> int:
"""
Notify all users on waitlists linked to a store listing when the agent is launched.
Args:
store_listing_id: The ID of the store listing that was approved
agent_name: The name of the approved agent
store_url: The URL to the agent's store page
Returns:
The number of notifications sent
"""
logger.info(f"Notifying waitlist users for store listing {store_listing_id}")
try:
# Find all active waitlists linked to this store listing
# Exclude DONE and CANCELED to prevent duplicate notifications on re-approval
waitlists = await prisma.models.WaitlistEntry.prisma().find_many(
where={
"storeListingId": store_listing_id,
"isDeleted": False,
"status": {
"not_in": [
prisma.enums.WaitlistExternalStatus.DONE,
prisma.enums.WaitlistExternalStatus.CANCELED,
]
},
},
include={"JoinedUsers": True},
)
if not waitlists:
logger.info(
f"No active waitlists found for store listing {store_listing_id}"
)
return 0
notification_count = 0
launched_at = datetime.now(tz=timezone.utc)
for waitlist in waitlists:
# Track notification results for this waitlist
users_to_notify = waitlist.JoinedUsers or []
failed_user_ids: list[str] = []
# Notify registered users
for user in users_to_notify:
try:
notification_data = WaitlistLaunchData(
agent_name=agent_name,
waitlist_name=waitlist.name,
store_url=store_url,
launched_at=launched_at,
)
notification_event = NotificationEventModel[WaitlistLaunchData](
user_id=user.id,
type=prisma.enums.NotificationType.WAITLIST_LAUNCH,
data=notification_data,
)
await queue_notification_async(notification_event)
notification_count += 1
except Exception as e:
logger.error(
f"Failed to send waitlist launch notification to user {user.id}: {e}"
)
failed_user_ids.append(user.id)
# Note: For unaffiliated email users, you would need to send emails directly
# since they don't have user IDs for the notification system.
# This could be done via a separate email service.
# For now, we log these for potential manual follow-up or future implementation.
has_pending_email_users = bool(waitlist.unaffiliatedEmailUsers)
if has_pending_email_users:
logger.info(
f"Waitlist {waitlist.id} has {len(waitlist.unaffiliatedEmailUsers)} "
f"unaffiliated email users that need email notifications"
)
# Only mark waitlist as DONE if all registered user notifications succeeded
# AND there are no unaffiliated email users still waiting for notifications
if not failed_user_ids and not has_pending_email_users:
await prisma.models.WaitlistEntry.prisma().update(
where={"id": waitlist.id},
data={"status": prisma.enums.WaitlistExternalStatus.DONE},
)
logger.info(f"Updated waitlist {waitlist.id} status to DONE")
elif failed_user_ids:
logger.warning(
f"Waitlist {waitlist.id} not marked as DONE due to "
f"{len(failed_user_ids)} failed notifications"
)
elif has_pending_email_users:
logger.warning(
f"Waitlist {waitlist.id} not marked as DONE due to "
f"{len(waitlist.unaffiliatedEmailUsers)} pending email-only users"
)
logger.info(
f"Sent {notification_count} waitlist launch notifications for store listing {store_listing_id}"
)
return notification_count
except Exception as e:
logger.error(
f"Error notifying waitlist users for store listing {store_listing_id}: {e}"
)
# Don't raise - we don't want to fail the approval process
return 0

View File

@@ -224,6 +224,102 @@ class ReviewSubmissionRequest(pydantic.BaseModel):
internal_comments: str | None = None # Private admin notes
class StoreWaitlistEntry(pydantic.BaseModel):
"""Public waitlist entry - no PII fields exposed."""
waitlistId: str
slug: str
# Content fields
name: str
subHeading: str
videoUrl: str | None = None
agentOutputDemoUrl: str | None = None
imageUrls: list[str]
description: str
categories: list[str]
class StoreWaitlistsAllResponse(pydantic.BaseModel):
listings: list[StoreWaitlistEntry]
# Admin Waitlist Models
class WaitlistCreateRequest(pydantic.BaseModel):
"""Request model for creating a new waitlist."""
name: str
slug: str
subHeading: str
description: str
categories: list[str] = []
imageUrls: list[str] = []
videoUrl: str | None = None
agentOutputDemoUrl: str | None = None
class WaitlistUpdateRequest(pydantic.BaseModel):
"""Request model for updating a waitlist."""
name: str | None = None
slug: str | None = None
subHeading: str | None = None
description: str | None = None
categories: list[str] | None = None
imageUrls: list[str] | None = None
videoUrl: str | None = None
agentOutputDemoUrl: str | None = None
status: prisma.enums.WaitlistExternalStatus | None = None
storeListingId: str | None = None # Link to a store listing
class WaitlistAdminResponse(pydantic.BaseModel):
"""Admin response model with full waitlist details including internal data."""
id: str
createdAt: str
updatedAt: str
slug: str
name: str
subHeading: str
description: str
categories: list[str]
imageUrls: list[str]
videoUrl: str | None = None
agentOutputDemoUrl: str | None = None
status: prisma.enums.WaitlistExternalStatus
votes: int
signupCount: int # Total count of joinedUsers + unaffiliatedEmailUsers
storeListingId: str | None = None
owningUserId: str
class WaitlistSignup(pydantic.BaseModel):
"""Individual signup entry for a waitlist."""
type: str # "user" or "email"
userId: str | None = None
email: str | None = None
username: str | None = None # For user signups
class WaitlistSignupListResponse(pydantic.BaseModel):
"""Response model for listing waitlist signups."""
waitlistId: str
signups: list[WaitlistSignup]
totalCount: int
class WaitlistAdminListResponse(pydantic.BaseModel):
"""Response model for listing all waitlists (admin view)."""
waitlists: list[WaitlistAdminResponse]
totalCount: int
class UnifiedSearchResult(pydantic.BaseModel):
"""A single result from unified hybrid search across all content types."""

View File

@@ -8,6 +8,8 @@ import autogpt_libs.auth
import fastapi
import fastapi.responses
import prisma.enums
import pydantic
from autogpt_libs.auth.dependencies import get_optional_user_id
import backend.data.graph
import backend.util.json
@@ -81,6 +83,74 @@ async def update_or_create_profile(
return updated_profile
##############################################
############## Waitlist Endpoints ############
##############################################
@router.get(
"/waitlist",
summary="Get the agent waitlist",
tags=["store", "public"],
response_model=store_model.StoreWaitlistsAllResponse,
)
async def get_waitlist():
"""
Get all active waitlists for public display.
"""
waitlists = await store_db.get_waitlist()
return store_model.StoreWaitlistsAllResponse(listings=waitlists)
@router.get(
"/waitlist/my-memberships",
summary="Get waitlist IDs the current user has joined",
tags=["store", "private"],
)
async def get_my_waitlist_memberships(
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
) -> list[str]:
"""Returns list of waitlist IDs the authenticated user has joined."""
return await store_db.get_user_waitlist_memberships(user_id)
@router.post(
path="/waitlist/{waitlist_id}/join",
summary="Add self to the agent waitlist",
tags=["store", "public"],
response_model=store_model.StoreWaitlistEntry,
)
async def add_self_to_waitlist(
user_id: str | None = fastapi.Security(get_optional_user_id),
waitlist_id: str = fastapi.Path(..., description="The ID of the waitlist to join"),
email: pydantic.EmailStr | None = fastapi.Body(
default=None, embed=True, description="Email address for unauthenticated users"
),
):
"""
Add the current user to the agent waitlist.
"""
if not user_id and not email:
raise fastapi.HTTPException(
status_code=400,
detail="Either user authentication or email address is required",
)
try:
waitlist_entry = await store_db.add_user_to_waitlist(
waitlist_id=waitlist_id, user_id=user_id, email=email
)
return waitlist_entry
except ValueError as e:
error_msg = str(e)
if "not found" in error_msg:
raise fastapi.HTTPException(status_code=404, detail="Waitlist not found")
# Waitlist exists but is closed or unavailable
raise fastapi.HTTPException(status_code=400, detail=error_msg)
except Exception:
raise fastapi.HTTPException(
status_code=500, detail="An error occurred while joining the waitlist"
)
##############################################
############### Agent Endpoints ##############
##############################################

View File

@@ -3,15 +3,29 @@ Workspace API routes for managing user file storage.
"""
import logging
import os
import re
from typing import Annotated
from urllib.parse import quote
import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi import Query, UploadFile
from fastapi.responses import Response
from pydantic import BaseModel
from backend.data.workspace import WorkspaceFile, get_workspace, get_workspace_file
from backend.data.workspace import (
WorkspaceFile,
count_workspace_files,
get_or_create_workspace,
get_workspace,
get_workspace_file,
get_workspace_total_size,
soft_delete_workspace_file,
)
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
from backend.util.workspace_storage import get_workspace_storage
@@ -98,6 +112,25 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
raise
class UploadFileResponse(BaseModel):
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
class DeleteFileResponse(BaseModel):
deleted: bool
class StorageUsageResponse(BaseModel):
used_bytes: int
limit_bytes: int
used_percent: float
file_count: int
@router.get(
"/files/{file_id}/download",
summary="Download file by ID",
@@ -120,3 +153,148 @@ async def download_file(
raise fastapi.HTTPException(status_code=404, detail="File not found")
return await _create_file_download_response(file)
@router.delete(
"/files/{file_id}",
summary="Delete a workspace file",
)
async def delete_workspace_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
file_id: str,
) -> DeleteFileResponse:
"""
Soft-delete a workspace file and attempt to remove it from storage.
Used when a user clears a file input in the builder.
"""
workspace = await get_workspace(user_id)
if workspace is None:
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
manager = WorkspaceManager(user_id, workspace.id)
deleted = await manager.delete_file(file_id)
if not deleted:
raise fastapi.HTTPException(status_code=404, detail="File not found")
return DeleteFileResponse(deleted=True)
@router.post(
"/files/upload",
summary="Upload file to workspace",
)
async def upload_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
file: UploadFile,
session_id: str | None = Query(default=None),
) -> UploadFileResponse:
"""
Upload a file to the user's workspace.
Files are stored in session-scoped paths when session_id is provided,
so the agent's session-scoped tools can discover them automatically.
"""
config = Config()
# Sanitize filename — strip any directory components
filename = os.path.basename(file.filename or "upload") or "upload"
# Read file content with early abort on size limit
max_file_bytes = config.max_file_size_mb * 1024 * 1024
chunks: list[bytes] = []
total_size = 0
while chunk := await file.read(64 * 1024): # 64KB chunks
total_size += len(chunk)
if total_size > max_file_bytes:
raise fastapi.HTTPException(
status_code=413,
detail=f"File exceeds maximum size of {config.max_file_size_mb} MB",
)
chunks.append(chunk)
content = b"".join(chunks)
# Get or create workspace
workspace = await get_or_create_workspace(user_id)
# Pre-write storage cap check (soft check — final enforcement is post-write)
storage_limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
current_usage = await get_workspace_total_size(workspace.id)
if storage_limit_bytes and current_usage + len(content) > storage_limit_bytes:
used_percent = (current_usage / storage_limit_bytes) * 100
raise fastapi.HTTPException(
status_code=413,
detail={
"message": "Storage limit exceeded",
"used_bytes": current_usage,
"limit_bytes": storage_limit_bytes,
"used_percent": round(used_percent, 1),
},
)
# Warn at 80% usage
if (
storage_limit_bytes
and (usage_ratio := (current_usage + len(content)) / storage_limit_bytes) >= 0.8
):
logger.warning(
f"User {user_id} workspace storage at {usage_ratio * 100:.1f}% "
f"({current_usage + len(content)} / {storage_limit_bytes} bytes)"
)
# Virus scan
await scan_content_safe(content, filename=filename)
# Write file via WorkspaceManager
manager = WorkspaceManager(user_id, workspace.id, session_id)
try:
workspace_file = await manager.write_file(content, filename)
except ValueError as e:
raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
# Post-write storage check — eliminates TOCTOU race on the quota.
# If a concurrent upload pushed us over the limit, undo this write.
new_total = await get_workspace_total_size(workspace.id)
if storage_limit_bytes and new_total > storage_limit_bytes:
await soft_delete_workspace_file(workspace_file.id, workspace.id)
raise fastapi.HTTPException(
status_code=413,
detail={
"message": "Storage limit exceeded (concurrent upload)",
"used_bytes": new_total,
"limit_bytes": storage_limit_bytes,
},
)
return UploadFileResponse(
file_id=workspace_file.id,
name=workspace_file.name,
path=workspace_file.path,
mime_type=workspace_file.mime_type,
size_bytes=workspace_file.size_bytes,
)
@router.get(
"/storage/usage",
summary="Get workspace storage usage",
)
async def get_storage_usage(
user_id: Annotated[str, fastapi.Security(get_user_id)],
) -> StorageUsageResponse:
"""
Get storage usage information for the user's workspace.
"""
config = Config()
workspace = await get_or_create_workspace(user_id)
used_bytes = await get_workspace_total_size(workspace.id)
file_count = await count_workspace_files(workspace.id)
limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
return StorageUsageResponse(
used_bytes=used_bytes,
limit_bytes=limit_bytes,
used_percent=round((used_bytes / limit_bytes) * 100, 1) if limit_bytes else 0,
file_count=file_count,
)

View File

@@ -0,0 +1,359 @@
"""Tests for workspace file upload and download routes."""
import io
from datetime import datetime, timezone
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from backend.api.features.workspace import routes as workspace_routes
from backend.data.workspace import WorkspaceFile
app = fastapi.FastAPI()
app.include_router(workspace_routes.router)
@app.exception_handler(ValueError)
async def _value_error_handler(
request: fastapi.Request, exc: ValueError
) -> fastapi.responses.JSONResponse:
"""Mirror the production ValueError → 400 mapping from rest_api.py."""
return fastapi.responses.JSONResponse(status_code=400, content={"detail": str(exc)})
client = fastapi.testclient.TestClient(app)
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})()
_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)
MOCK_FILE = WorkspaceFile(
id="file-aaa-bbb",
workspace_id="ws-1",
created_at=_NOW,
updated_at=_NOW,
name="hello.txt",
path="/session/hello.txt",
mime_type="text/plain",
size_bytes=13,
storage_path="local://hello.txt",
)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def _upload(
filename: str = "hello.txt",
content: bytes = b"Hello, world!",
content_type: str = "text/plain",
):
"""Helper to POST a file upload."""
return client.post(
"/files/upload?session_id=sess-1",
files={"file": (filename, io.BytesIO(content), content_type)},
)
# ---- Happy path ----
def test_upload_happy_path(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload()
assert response.status_code == 200
data = response.json()
assert data["file_id"] == "file-aaa-bbb"
assert data["name"] == "hello.txt"
assert data["size_bytes"] == 13
# ---- Per-file size limit ----
def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
"""Files larger than max_file_size_mb should be rejected with 413."""
cfg = mocker.patch("backend.api.features.workspace.routes.Config")
cfg.return_value.max_file_size_mb = 0 # 0 MB → any content is too big
cfg.return_value.max_workspace_storage_mb = 500
response = _upload(content=b"x" * 1024)
assert response.status_code == 413
# ---- Storage quota exceeded ----
def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
# Current usage already at limit
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=500 * 1024 * 1024,
)
response = _upload()
assert response.status_code == 413
assert "Storage limit exceeded" in response.text
# ---- Post-write quota race (B2) ----
def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
"""If a concurrent upload tips the total over the limit after write,
the file should be soft-deleted and 413 returned."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
# Pre-write check passes (under limit), but post-write check fails
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
side_effect=[0, 600 * 1024 * 1024], # first call OK, second over limit
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
mock_delete = mocker.patch(
"backend.api.features.workspace.routes.soft_delete_workspace_file",
return_value=None,
)
response = _upload()
assert response.status_code == 413
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-1")
# ---- Any extension accepted (no allowlist) ----
def test_upload_any_extension(mocker: pytest_mock.MockFixture):
"""Any file extension should be accepted — ClamAV is the security layer."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload(filename="data.xyz", content=b"arbitrary")
assert response.status_code == 200
# ---- Virus scan rejection ----
def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
"""Files flagged by ClamAV should be rejected and never written to storage."""
from backend.api.features.store.exceptions import VirusDetectedError
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
side_effect=VirusDetectedError("Eicar-Test-Signature"),
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
assert response.status_code == 400
assert "Virus detected" in response.text
mock_manager.write_file.assert_not_called()
# ---- No file extension ----
def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
"""Files without an extension should be accepted and stored as-is."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload(
filename="Makefile",
content=b"all:\n\techo hello",
content_type="application/octet-stream",
)
assert response.status_code == 200
mock_manager.write_file.assert_called_once()
assert mock_manager.write_file.call_args[0][1] == "Makefile"
# ---- Filename sanitization (SF5) ----
def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
"""Path-traversal filenames should be reduced to their basename."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
# Filename with traversal
_upload(filename="../../etc/passwd.txt")
# write_file should have been called with just the basename
mock_manager.write_file.assert_called_once()
call_args = mock_manager.write_file.call_args
assert call_args[0][1] == "passwd.txt"
# ---- Download ----
def test_download_file_not_found(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_file",
return_value=None,
)
response = client.get("/files/some-file-id/download")
assert response.status_code == 404
# ---- Delete ----
def test_delete_file_success(mocker: pytest_mock.MockFixture):
"""Deleting an existing file should return {"deleted": true}."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=MOCK_WORKSPACE,
)
mock_manager = mocker.MagicMock()
mock_manager.delete_file = mocker.AsyncMock(return_value=True)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = client.delete("/files/file-aaa-bbb")
assert response.status_code == 200
assert response.json() == {"deleted": True}
mock_manager.delete_file.assert_called_once_with("file-aaa-bbb")
def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
"""Deleting a non-existent file should return 404."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=MOCK_WORKSPACE,
)
mock_manager = mocker.MagicMock()
mock_manager.delete_file = mocker.AsyncMock(return_value=False)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = client.delete("/files/nonexistent-id")
assert response.status_code == 404
assert "File not found" in response.text
def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
"""Deleting when user has no workspace should return 404."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=None,
)
response = client.delete("/files/file-aaa-bbb")
assert response.status_code == 404
assert "Workspace not found" in response.text

View File

@@ -19,6 +19,7 @@ from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.admin.waitlist_admin_routes
import backend.api.features.builder
import backend.api.features.builder.routes
import backend.api.features.chat.routes as chat_routes
@@ -41,6 +42,10 @@ import backend.data.user
import backend.integrations.webhooks.utils
import backend.util.service
import backend.util.settings
from backend.api.features.library.exceptions import (
FolderAlreadyExistsError,
FolderValidationError,
)
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
@@ -261,6 +266,10 @@ async def validation_error_handler(
app.add_exception_handler(PrismaError, handle_internal_http_error(500))
app.add_exception_handler(
FolderAlreadyExistsError, handle_internal_http_error(409, False)
)
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400, False))
app.add_exception_handler(NotFoundError, handle_internal_http_error(404, False))
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403, False))
app.add_exception_handler(RequestValidationError, validation_error_handler)
@@ -291,6 +300,11 @@ app.include_router(
tags=["v2", "admin"],
prefix="/api/store",
)
app.include_router(
backend.api.features.admin.waitlist_admin_routes.router,
tags=["v2", "admin"],
prefix="/api/store",
)
app.include_router(
backend.api.features.admin.credit_admin_routes.router,
tags=["v2", "admin"],

View File

@@ -116,6 +116,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
CLAUDE_4_6_OPUS = "claude-opus-4-6"
CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# AI/ML API models
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
@@ -274,6 +275,9 @@ MODEL_METADATA = {
LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
"anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
), # claude-opus-4-6
LlmModel.CLAUDE_4_6_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude Sonnet 4.6", "Anthropic", "Anthropic", 3
), # claude-sonnet-4-6
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
), # claude-opus-4-5-20251101

View File

@@ -6,7 +6,6 @@ and execute them. Works like AgentExecutorBlock — the user selects a tool from
dropdown and the input/output schema adapts dynamically.
"""
import json
import logging
from typing import Any, Literal
@@ -20,6 +19,11 @@ from backend.blocks._base import (
BlockType,
)
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.blocks.mcp.helpers import (
auto_lookup_mcp_credential,
normalize_mcp_url,
parse_mcp_content,
)
from backend.data.block import BlockInput, BlockOutput
from backend.data.model import (
CredentialsField,
@@ -179,31 +183,7 @@ class MCPToolBlock(Block):
f"{error_text or 'Unknown error'}"
)
# Extract text content from the result
output_parts = []
for item in result.content:
if item.get("type") == "text":
text = item.get("text", "")
# Try to parse as JSON for structured output
try:
output_parts.append(json.loads(text))
except (json.JSONDecodeError, ValueError):
output_parts.append(text)
elif item.get("type") == "image":
output_parts.append(
{
"type": "image",
"data": item.get("data"),
"mimeType": item.get("mimeType"),
}
)
elif item.get("type") == "resource":
output_parts.append(item.get("resource", {}))
# If single result, unwrap
if len(output_parts) == 1:
return output_parts[0]
return output_parts if output_parts else None
return parse_mcp_content(result.content)
@staticmethod
async def _auto_lookup_credential(
@@ -211,37 +191,10 @@ class MCPToolBlock(Block):
) -> "OAuth2Credentials | None":
"""Auto-lookup stored MCP credential for a server URL.
This is a fallback for nodes that don't have ``credentials`` explicitly
set (e.g. nodes created before the credential field was wired up).
Delegates to :func:`~backend.blocks.mcp.helpers.auto_lookup_mcp_credential`.
The caller should pass a normalized URL.
"""
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
try:
mgr = IntegrationCredentialsManager()
mcp_creds = await mgr.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
best: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and (cred.metadata or {}).get("mcp_server_url") == server_url
):
if best is None or (
(cred.access_token_expires_at or 0)
> (best.access_token_expires_at or 0)
):
best = cred
if best:
best = await mgr.refresh_if_needed(user_id, best)
logger.info(
"Auto-resolved MCP credential %s for %s", best.id, server_url
)
return best
except Exception:
logger.warning("Auto-lookup MCP credential failed", exc_info=True)
return None
return await auto_lookup_mcp_credential(user_id, server_url)
async def run(
self,
@@ -278,7 +231,7 @@ class MCPToolBlock(Block):
# the stored MCP credential for this server URL.
if credentials is None:
credentials = await self._auto_lookup_credential(
user_id, input_data.server_url
user_id, normalize_mcp_url(input_data.server_url)
)
auth_token = (

View File

@@ -55,7 +55,9 @@ class MCPClient:
server_url: str,
auth_token: str | None = None,
):
self.server_url = server_url.rstrip("/")
from backend.blocks.mcp.helpers import normalize_mcp_url
self.server_url = normalize_mcp_url(server_url)
self.auth_token = auth_token
self._request_id = 0
self._session_id: str | None = None

View File

@@ -0,0 +1,117 @@
"""Shared MCP helpers used by blocks, copilot tools, and API routes."""
from __future__ import annotations
import json
import logging
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
if TYPE_CHECKING:
from backend.data.model import OAuth2Credentials
logger = logging.getLogger(__name__)
def normalize_mcp_url(url: str) -> str:
"""Normalize an MCP server URL for consistent credential matching.
Strips leading/trailing whitespace and a single trailing slash so that
``https://mcp.example.com/`` and ``https://mcp.example.com`` resolve to
the same stored credential.
"""
return url.strip().rstrip("/")
def server_host(server_url: str) -> str:
"""Extract the hostname from a server URL for display purposes.
Uses ``parsed.hostname`` (never ``netloc``) to strip any embedded
username/password before surfacing the value in UI messages.
"""
try:
parsed = urlparse(server_url)
return parsed.hostname or server_url
except Exception:
return server_url
def parse_mcp_content(content: list[dict[str, Any]]) -> Any:
"""Parse MCP tool response content into a plain Python value.
- text items: parsed as JSON when possible, kept as str otherwise
- image items: kept as ``{type, data, mimeType}`` dict for frontend rendering
- resource items: unwrapped to their resource payload dict
Single-item responses are unwrapped from the list; multiple items are
returned as a list; empty content returns ``None``.
"""
output_parts: list[Any] = []
for item in content:
item_type = item.get("type")
if item_type == "text":
text = item.get("text", "")
try:
output_parts.append(json.loads(text))
except (json.JSONDecodeError, ValueError):
output_parts.append(text)
elif item_type == "image":
output_parts.append(
{
"type": "image",
"data": item.get("data"),
"mimeType": item.get("mimeType"),
}
)
elif item_type == "resource":
output_parts.append(item.get("resource", {}))
if len(output_parts) == 1:
return output_parts[0]
return output_parts or None
async def auto_lookup_mcp_credential(
user_id: str, server_url: str
) -> OAuth2Credentials | None:
"""Look up the best stored MCP credential for *server_url*.
The caller should pass a **normalized** URL (via :func:`normalize_mcp_url`)
so the comparison with ``mcp_server_url`` in credential metadata matches.
Returns the credential with the latest ``access_token_expires_at``, refreshed
if needed, or ``None`` when no match is found.
"""
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
try:
mgr = IntegrationCredentialsManager()
mcp_creds = await mgr.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
# Collect all matching credentials and pick the best one.
# Primary sort: latest access_token_expires_at (tokens with expiry
# are preferred over non-expiring ones). Secondary sort: last in
# iteration order, which corresponds to the most recently created
# row — this acts as a tiebreaker when multiple bearer tokens have
# no expiry (e.g. after a failed old-credential cleanup).
best: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and (cred.metadata or {}).get("mcp_server_url") == server_url
):
if best is None or (
(cred.access_token_expires_at or 0)
>= (best.access_token_expires_at or 0)
):
best = cred
if best:
best = await mgr.refresh_if_needed(user_id, best)
logger.info("Auto-resolved MCP credential %s for %s", best.id, server_url)
return best
except Exception:
logger.warning("Auto-lookup MCP credential failed", exc_info=True)
return None

View File

@@ -0,0 +1,98 @@
"""Unit tests for the shared MCP helpers."""
from backend.blocks.mcp.helpers import normalize_mcp_url, parse_mcp_content, server_host
# ---------------------------------------------------------------------------
# normalize_mcp_url
# ---------------------------------------------------------------------------
def test_normalize_trailing_slash():
assert normalize_mcp_url("https://mcp.example.com/") == "https://mcp.example.com"
def test_normalize_whitespace():
assert normalize_mcp_url(" https://mcp.example.com ") == "https://mcp.example.com"
def test_normalize_both():
assert (
normalize_mcp_url(" https://mcp.example.com/ ") == "https://mcp.example.com"
)
def test_normalize_noop():
assert normalize_mcp_url("https://mcp.example.com") == "https://mcp.example.com"
def test_normalize_path_with_trailing_slash():
assert (
normalize_mcp_url("https://mcp.example.com/path/")
== "https://mcp.example.com/path"
)
# ---------------------------------------------------------------------------
# server_host
# ---------------------------------------------------------------------------
def test_server_host_standard_url():
assert server_host("https://mcp.example.com/mcp") == "mcp.example.com"
def test_server_host_strips_credentials():
"""hostname must not expose user:pass."""
assert server_host("https://user:secret@mcp.example.com/mcp") == "mcp.example.com"
def test_server_host_with_port():
"""Port should not appear in hostname (hostname strips it)."""
assert server_host("https://mcp.example.com:8080/mcp") == "mcp.example.com"
def test_server_host_fallback():
"""Falls back to the raw string for un-parseable URLs."""
assert server_host("not-a-url") == "not-a-url"
# ---------------------------------------------------------------------------
# parse_mcp_content
# ---------------------------------------------------------------------------
def test_parse_text_plain():
assert parse_mcp_content([{"type": "text", "text": "hello world"}]) == "hello world"
def test_parse_text_json():
content = [{"type": "text", "text": '{"status": "ok", "count": 42}'}]
assert parse_mcp_content(content) == {"status": "ok", "count": 42}
def test_parse_image():
content = [{"type": "image", "data": "abc123==", "mimeType": "image/png"}]
assert parse_mcp_content(content) == {
"type": "image",
"data": "abc123==",
"mimeType": "image/png",
}
def test_parse_resource():
content = [
{"type": "resource", "resource": {"uri": "file:///tmp/out.txt", "text": "hi"}}
]
assert parse_mcp_content(content) == {"uri": "file:///tmp/out.txt", "text": "hi"}
def test_parse_multi_item():
content = [
{"type": "text", "text": "first"},
{"type": "text", "text": "second"},
]
assert parse_mcp_content(content) == ["first", "second"]
def test_parse_empty():
assert parse_mcp_content([]) is None

View File

@@ -83,7 +83,8 @@ class StagehandRecommendedLlmModel(str, Enum):
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
# Anthropic
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929" # Keep for backwards compat
CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
@property
def provider_name(self) -> str:
@@ -137,7 +138,7 @@ class StagehandObserveBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -227,7 +228,7 @@ class StagehandActBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -324,7 +325,7 @@ class StagehandExtractBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()

View File

@@ -0,0 +1,182 @@
"""
Telegram Bot API helper functions.
Provides utilities for making authenticated requests to the Telegram Bot API.
"""
import logging
from io import BytesIO
from typing import Any, Optional
from pydantic import BaseModel
from backend.data.model import APIKeyCredentials
from backend.util.request import Requests
logger = logging.getLogger(__name__)
TELEGRAM_API_BASE = "https://api.telegram.org"
class TelegramMessageResult(BaseModel, extra="allow"):
"""Result from Telegram send/edit message API calls."""
message_id: int = 0
chat: dict[str, Any] = {}
date: int = 0
text: str = ""
class TelegramFileResult(BaseModel, extra="allow"):
"""Result from Telegram getFile API call."""
file_id: str = ""
file_unique_id: str = ""
file_size: int = 0
file_path: str = ""
class TelegramAPIException(ValueError):
"""Exception raised for Telegram API errors."""
def __init__(self, message: str, error_code: int = 0):
super().__init__(message)
self.error_code = error_code
def get_bot_api_url(bot_token: str, method: str) -> str:
"""Construct Telegram Bot API URL for a method."""
return f"{TELEGRAM_API_BASE}/bot{bot_token}/{method}"
def get_file_url(bot_token: str, file_path: str) -> str:
"""Construct Telegram file download URL."""
return f"{TELEGRAM_API_BASE}/file/bot{bot_token}/{file_path}"
async def call_telegram_api(
credentials: APIKeyCredentials,
method: str,
data: Optional[dict[str, Any]] = None,
) -> TelegramMessageResult:
"""
Make a request to the Telegram Bot API.
Args:
credentials: Bot token credentials
method: API method name (e.g., "sendMessage", "getFile")
data: Request parameters
Returns:
API response result
Raises:
TelegramAPIException: If the API returns an error
"""
token = credentials.api_key.get_secret_value()
url = get_bot_api_url(token, method)
response = await Requests().post(url, json=data or {})
result = response.json()
if not result.get("ok"):
error_code = result.get("error_code", 0)
description = result.get("description", "Unknown error")
raise TelegramAPIException(description, error_code)
return TelegramMessageResult(**result.get("result", {}))
async def call_telegram_api_with_file(
credentials: APIKeyCredentials,
method: str,
file_field: str,
file_data: bytes,
filename: str,
content_type: str,
data: Optional[dict[str, Any]] = None,
) -> TelegramMessageResult:
"""
Make a multipart/form-data request to the Telegram Bot API with a file upload.
Args:
credentials: Bot token credentials
method: API method name (e.g., "sendPhoto", "sendVoice")
file_field: Form field name for the file (e.g., "photo", "voice")
file_data: Raw file bytes
filename: Filename for the upload
content_type: MIME type of the file
data: Additional form parameters
Returns:
API response result
Raises:
TelegramAPIException: If the API returns an error
"""
token = credentials.api_key.get_secret_value()
url = get_bot_api_url(token, method)
files = [(file_field, (filename, BytesIO(file_data), content_type))]
response = await Requests().post(url, files=files, data=data or {})
result = response.json()
if not result.get("ok"):
error_code = result.get("error_code", 0)
description = result.get("description", "Unknown error")
raise TelegramAPIException(description, error_code)
return TelegramMessageResult(**result.get("result", {}))
async def get_file_info(
credentials: APIKeyCredentials, file_id: str
) -> TelegramFileResult:
"""
Get file information from Telegram.
Args:
credentials: Bot token credentials
file_id: Telegram file_id from message
Returns:
File info dict containing file_id, file_unique_id, file_size, file_path
"""
result = await call_telegram_api(credentials, "getFile", {"file_id": file_id})
return TelegramFileResult(**result.model_dump())
async def get_file_download_url(credentials: APIKeyCredentials, file_id: str) -> str:
"""
Get the download URL for a Telegram file.
Args:
credentials: Bot token credentials
file_id: Telegram file_id from message
Returns:
Full download URL
"""
token = credentials.api_key.get_secret_value()
result = await get_file_info(credentials, file_id)
file_path = result.file_path
if not file_path:
raise TelegramAPIException("No file_path returned from getFile")
return get_file_url(token, file_path)
async def download_telegram_file(credentials: APIKeyCredentials, file_id: str) -> bytes:
"""
Download a file from Telegram servers.
Args:
credentials: Bot token credentials
file_id: Telegram file_id
Returns:
File content as bytes
"""
url = await get_file_download_url(credentials, file_id)
response = await Requests().get(url)
return response.content

View File

@@ -0,0 +1,43 @@
"""
Telegram Bot credentials handling.
Telegram bots use an API key (bot token) obtained from @BotFather.
"""
from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
# Bot token credentials (API key style)
TelegramCredentials = APIKeyCredentials
TelegramCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.TELEGRAM], Literal["api_key"]
]
def TelegramCredentialsField() -> TelegramCredentialsInput:
"""Creates a Telegram bot token credentials field."""
return CredentialsField(
description="Telegram Bot API token from @BotFather. "
"Create a bot at https://t.me/BotFather to get your token."
)
# Test credentials for unit tests
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="telegram",
api_key=SecretStr("test_telegram_bot_token"),
title="Mock Telegram Bot Token",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,377 @@
"""
Telegram trigger blocks for receiving messages via webhooks.
"""
import logging
from pydantic import BaseModel
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
BlockWebhookConfig,
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.telegram import TelegramWebhookType
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TelegramCredentialsField,
TelegramCredentialsInput,
)
logger = logging.getLogger(__name__)
# Example payload for testing
EXAMPLE_MESSAGE_PAYLOAD = {
"update_id": 123456789,
"message": {
"message_id": 1,
"from": {
"id": 12345678,
"is_bot": False,
"first_name": "John",
"last_name": "Doe",
"username": "johndoe",
"language_code": "en",
},
"chat": {
"id": 12345678,
"first_name": "John",
"last_name": "Doe",
"username": "johndoe",
"type": "private",
},
"date": 1234567890,
"text": "Hello, bot!",
},
}
class TelegramTriggerBase:
"""Base class for Telegram trigger blocks."""
class Input(BlockSchemaInput):
credentials: TelegramCredentialsInput = TelegramCredentialsField()
payload: dict = SchemaField(hidden=True, default_factory=dict)
class TelegramMessageTriggerBlock(TelegramTriggerBase, Block):
"""
Triggers when a message is received or edited in your Telegram bot.
Supports text, photos, voice messages, audio files, documents, and videos.
Connect the outputs to other blocks to process messages and send responses.
"""
class Input(TelegramTriggerBase.Input):
class EventsFilter(BaseModel):
"""Filter for message types to receive."""
text: bool = True
photo: bool = False
voice: bool = False
audio: bool = False
document: bool = False
video: bool = False
edited_message: bool = False
events: EventsFilter = SchemaField(
title="Message Types", description="Types of messages to receive"
)
class Output(BlockSchemaOutput):
payload: dict = SchemaField(
description="The complete webhook payload from Telegram"
)
chat_id: int = SchemaField(
description="The chat ID where the message was received. "
"Use this to send replies."
)
message_id: int = SchemaField(description="The unique message ID")
user_id: int = SchemaField(description="The user ID who sent the message")
username: str = SchemaField(description="Username of the sender (may be empty)")
first_name: str = SchemaField(description="First name of the sender")
event: str = SchemaField(
description="The message type (text, photo, voice, audio, etc.)"
)
text: str = SchemaField(
description="Text content of the message (for text messages)"
)
photo_file_id: str = SchemaField(
description="File ID of the photo (for photo messages). "
"Use GetTelegramFileBlock to download."
)
voice_file_id: str = SchemaField(
description="File ID of the voice message (for voice messages). "
"Use GetTelegramFileBlock to download."
)
audio_file_id: str = SchemaField(
description="File ID of the audio file (for audio messages). "
"Use GetTelegramFileBlock to download."
)
file_id: str = SchemaField(
description="File ID for document/video messages. "
"Use GetTelegramFileBlock to download."
)
file_name: str = SchemaField(
description="Original filename (for document/audio messages)"
)
caption: str = SchemaField(description="Caption for media messages")
is_edited: bool = SchemaField(
description="Whether this is an edit of a previously sent message"
)
def __init__(self):
super().__init__(
id="4435e4e0-df6e-4301-8f35-ad70b12fc9ec",
description="Triggers when a message is received or edited in your Telegram bot. "
"Supports text, photos, voice messages, audio files, documents, and videos.",
categories={BlockCategory.SOCIAL},
input_schema=TelegramMessageTriggerBlock.Input,
output_schema=TelegramMessageTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.TELEGRAM,
webhook_type=TelegramWebhookType.BOT,
resource_format="bot",
event_filter_input="events",
event_format="message.{event}",
),
test_input={
"events": {"text": True, "photo": True},
"credentials": TEST_CREDENTIALS_INPUT,
"payload": EXAMPLE_MESSAGE_PAYLOAD,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", EXAMPLE_MESSAGE_PAYLOAD),
("chat_id", 12345678),
("message_id", 1),
("user_id", 12345678),
("username", "johndoe"),
("first_name", "John"),
("is_edited", False),
("event", "text"),
("text", "Hello, bot!"),
("photo_file_id", ""),
("voice_file_id", ""),
("audio_file_id", ""),
("file_id", ""),
("file_name", ""),
("caption", ""),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
payload = input_data.payload
is_edited = "edited_message" in payload
message = payload.get("message") or payload.get("edited_message", {})
# Extract common fields
chat = message.get("chat", {})
sender = message.get("from", {})
yield "payload", payload
yield "chat_id", chat.get("id", 0)
yield "message_id", message.get("message_id", 0)
yield "user_id", sender.get("id", 0)
yield "username", sender.get("username", "")
yield "first_name", sender.get("first_name", "")
yield "is_edited", is_edited
# For edited messages, yield event as "edited_message" and extract
# all content fields from the edited message body
if is_edited:
yield "event", "edited_message"
yield "text", message.get("text", "")
photos = message.get("photo", [])
yield "photo_file_id", photos[-1].get("file_id", "") if photos else ""
voice = message.get("voice", {})
yield "voice_file_id", voice.get("file_id", "")
audio = message.get("audio", {})
yield "audio_file_id", audio.get("file_id", "")
document = message.get("document", {})
video = message.get("video", {})
yield "file_id", (document.get("file_id", "") or video.get("file_id", ""))
yield "file_name", (
document.get("file_name", "") or audio.get("file_name", "")
)
yield "caption", message.get("caption", "")
# Determine message type and extract content
elif "text" in message:
yield "event", "text"
yield "text", message.get("text", "")
yield "photo_file_id", ""
yield "voice_file_id", ""
yield "audio_file_id", ""
yield "file_id", ""
yield "file_name", ""
yield "caption", ""
elif "photo" in message:
# Get the largest photo (last in array)
photos = message.get("photo", [])
photo_fid = photos[-1].get("file_id", "") if photos else ""
yield "event", "photo"
yield "text", ""
yield "photo_file_id", photo_fid
yield "voice_file_id", ""
yield "audio_file_id", ""
yield "file_id", ""
yield "file_name", ""
yield "caption", message.get("caption", "")
elif "voice" in message:
voice = message.get("voice", {})
yield "event", "voice"
yield "text", ""
yield "photo_file_id", ""
yield "voice_file_id", voice.get("file_id", "")
yield "audio_file_id", ""
yield "file_id", ""
yield "file_name", ""
yield "caption", message.get("caption", "")
elif "audio" in message:
audio = message.get("audio", {})
yield "event", "audio"
yield "text", ""
yield "photo_file_id", ""
yield "voice_file_id", ""
yield "audio_file_id", audio.get("file_id", "")
yield "file_id", ""
yield "file_name", audio.get("file_name", "")
yield "caption", message.get("caption", "")
elif "document" in message:
document = message.get("document", {})
yield "event", "document"
yield "text", ""
yield "photo_file_id", ""
yield "voice_file_id", ""
yield "audio_file_id", ""
yield "file_id", document.get("file_id", "")
yield "file_name", document.get("file_name", "")
yield "caption", message.get("caption", "")
elif "video" in message:
video = message.get("video", {})
yield "event", "video"
yield "text", ""
yield "photo_file_id", ""
yield "voice_file_id", ""
yield "audio_file_id", ""
yield "file_id", video.get("file_id", "")
yield "file_name", video.get("file_name", "")
yield "caption", message.get("caption", "")
else:
yield "event", "other"
yield "text", ""
yield "photo_file_id", ""
yield "voice_file_id", ""
yield "audio_file_id", ""
yield "file_id", ""
yield "file_name", ""
yield "caption", ""
# Example payload for reaction trigger testing
EXAMPLE_REACTION_PAYLOAD = {
"update_id": 123456790,
"message_reaction": {
"chat": {
"id": 12345678,
"first_name": "John",
"last_name": "Doe",
"username": "johndoe",
"type": "private",
},
"message_id": 42,
"user": {
"id": 12345678,
"is_bot": False,
"first_name": "John",
"username": "johndoe",
},
"date": 1234567890,
"new_reaction": [{"type": "emoji", "emoji": "👍"}],
"old_reaction": [],
},
}
class TelegramMessageReactionTriggerBlock(TelegramTriggerBase, Block):
"""
Triggers when a reaction to a message is changed.
Works automatically in private chats. In group chats, the bot must be
an administrator to receive reaction updates.
"""
class Input(TelegramTriggerBase.Input):
pass
class Output(BlockSchemaOutput):
payload: dict = SchemaField(
description="The complete webhook payload from Telegram"
)
chat_id: int = SchemaField(
description="The chat ID where the reaction occurred"
)
message_id: int = SchemaField(description="The message ID that was reacted to")
user_id: int = SchemaField(description="The user ID who changed the reaction")
username: str = SchemaField(description="Username of the user (may be empty)")
new_reactions: list = SchemaField(
description="List of new reactions on the message"
)
old_reactions: list = SchemaField(
description="List of previous reactions on the message"
)
def __init__(self):
super().__init__(
id="82525328-9368-4966-8f0c-cd78e80181fd",
description="Triggers when a reaction to a message is changed. "
"Works in private chats automatically. "
"In groups, the bot must be an administrator.",
categories={BlockCategory.SOCIAL},
input_schema=TelegramMessageReactionTriggerBlock.Input,
output_schema=TelegramMessageReactionTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.TELEGRAM,
webhook_type=TelegramWebhookType.BOT,
resource_format="bot",
event_filter_input="",
event_format="message_reaction",
),
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"payload": EXAMPLE_REACTION_PAYLOAD,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", EXAMPLE_REACTION_PAYLOAD),
("chat_id", 12345678),
("message_id", 42),
("user_id", 12345678),
("username", "johndoe"),
("new_reactions", [{"type": "emoji", "emoji": "👍"}]),
("old_reactions", []),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
payload = input_data.payload
reaction = payload.get("message_reaction", {})
chat = reaction.get("chat", {})
user = reaction.get("user", {})
yield "payload", payload
yield "chat_id", chat.get("id", 0)
yield "message_id", reaction.get("message_id", 0)
yield "user_id", user.get("id", 0)
yield "username", user.get("username", "")
yield "new_reactions", reaction.get("new_reaction", [])
yield "old_reactions", reaction.get("old_reaction", [])

View File

@@ -34,10 +34,12 @@ def main(output: Path, pretty: bool):
"""Generate and output the OpenAPI JSON specification."""
openapi_schema = get_openapi_schema()
json_output = json.dumps(openapi_schema, indent=2 if pretty else None)
json_output = json.dumps(
openapi_schema, indent=2 if pretty else None, ensure_ascii=False
)
if output:
output.write_text(json_output)
output.write_text(json_output, encoding="utf-8")
click.echo(f"✅ OpenAPI specification written to {output}\n\nPreview:")
click.echo(f"\n{json_output[:500]} ...")
else:

View File

@@ -0,0 +1,3 @@
from .service import stream_chat_completion_baseline
__all__ = ["stream_chat_completion_baseline"]

View File

@@ -0,0 +1,420 @@
"""Baseline LLM fallback — OpenAI-compatible streaming with tool calling.
Used when ``CHAT_USE_CLAUDE_AGENT_SDK=false``, e.g. as a fallback when the
Claude Agent SDK / Anthropic API is unavailable. Routes through any
OpenAI-compatible provider (OpenRouter by default) and reuses the same
shared tool registry as the SDK path.
"""
import asyncio
import logging
import uuid
from collections.abc import AsyncGenerator
from typing import Any
import orjson
from langfuse import propagate_attributes
from backend.copilot.model import (
ChatMessage,
ChatSession,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from backend.copilot.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from backend.copilot.service import (
_build_system_prompt,
_generate_session_title,
client,
config,
)
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
from backend.util.prompt import compress_context
logger = logging.getLogger(__name__)
# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()
# Maximum number of tool-call rounds before forcing a text response.
_MAX_TOOL_ROUNDS = 30
async def _update_title_async(
session_id: str, message: str, user_id: str | None
) -> None:
"""Generate and persist a session title in the background."""
try:
title = await _generate_session_title(message, user_id, session_id)
if title:
await update_session_title(session_id, title)
except Exception as e:
logger.warning("[Baseline] Failed to update session title: %s", e)
async def _compress_session_messages(
messages: list[ChatMessage],
) -> list[ChatMessage]:
"""Compress session messages if they exceed the model's token limit.
Uses the shared compress_context() utility which supports LLM-based
summarization of older messages while keeping recent ones intact,
with progressive truncation and middle-out deletion as fallbacks.
"""
messages_dict = []
for msg in messages:
msg_dict: dict[str, Any] = {"role": msg.role}
if msg.content:
msg_dict["content"] = msg.content
messages_dict.append(msg_dict)
try:
result = await compress_context(
messages=messages_dict,
model=config.model,
client=client,
)
except Exception as e:
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
result = await compress_context(
messages=messages_dict,
model=config.model,
client=None,
)
if result.was_compacted:
logger.info(
"[Baseline] Context compacted: %d -> %d tokens "
"(%d summarized, %d dropped)",
result.original_token_count,
result.token_count,
result.messages_summarized,
result.messages_dropped,
)
return [
ChatMessage(role=m["role"], content=m.get("content"))
for m in result.messages
]
return messages
async def stream_chat_completion_baseline(
session_id: str,
message: str | None = None,
is_user_message: bool = True,
user_id: str | None = None,
session: ChatSession | None = None,
**_kwargs: Any,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Baseline LLM with tool calling via OpenAI-compatible API.
Designed as a fallback when the Claude Agent SDK is unavailable.
Uses the same tool registry as the SDK path but routes through any
OpenAI-compatible provider (e.g. OpenRouter).
Flow: stream response -> if tool_calls, execute them -> feed results back -> repeat.
"""
if session is None:
session = await get_chat_session(session_id, user_id)
if not session:
raise NotFoundError(
f"Session {session_id} not found. Please create a new session first."
)
# Append user message
new_role = "user" if is_user_message else "assistant"
if message and (
len(session.messages) == 0
or not (
session.messages[-1].role == new_role
and session.messages[-1].content == message
)
):
session.messages.append(ChatMessage(role=new_role, content=message))
if is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(message),
)
session = await upsert_chat_session(session)
# Generate title for new sessions
if is_user_message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
if len(user_messages) == 1:
first_message = user_messages[0].content or message or ""
if first_message:
task = asyncio.create_task(
_update_title_async(session_id, first_message, user_id)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
message_id = str(uuid.uuid4())
# Build system prompt only on the first turn to avoid mid-conversation
# changes from concurrent chats updating business understanding.
is_first_turn = len(session.messages) <= 1
if is_first_turn:
system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=False
)
else:
system_prompt, _ = await _build_system_prompt(
user_id=None, has_conversation_history=True
)
# Compress context if approaching the model's token limit
messages_for_context = await _compress_session_messages(session.messages)
# Build OpenAI message list from session history
openai_messages: list[dict[str, Any]] = [
{"role": "system", "content": system_prompt}
]
for msg in messages_for_context:
if msg.role in ("user", "assistant") and msg.content:
openai_messages.append({"role": msg.role, "content": msg.content})
tools = get_available_tools()
yield StreamStart(messageId=message_id, sessionId=session_id)
# Propagate user/session context to Langfuse so all LLM calls within
# this request are grouped under a single trace with proper attribution.
_trace_ctx: Any = None
try:
_trace_ctx = propagate_attributes(
user_id=user_id,
session_id=session_id,
trace_name="copilot-baseline",
tags=["baseline"],
)
_trace_ctx.__enter__()
except Exception:
logger.warning("[Baseline] Langfuse trace context setup failed")
assistant_text = ""
text_block_id = str(uuid.uuid4())
text_started = False
step_open = False
try:
for _round in range(_MAX_TOOL_ROUNDS):
# Open a new step for each LLM round
yield StreamStartStep()
step_open = True
# Stream a response from the model
create_kwargs: dict[str, Any] = dict(
model=config.model,
messages=openai_messages,
stream=True,
)
if tools:
create_kwargs["tools"] = tools
response = await client.chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
# Accumulate streamed response (text + tool calls)
round_text = ""
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
delta = chunk.choices[0].delta if chunk.choices else None
if not delta:
continue
# Text content
if delta.content:
if not text_started:
yield StreamTextStart(id=text_block_id)
text_started = True
round_text += delta.content
yield StreamTextDelta(id=text_block_id, delta=delta.content)
# Tool call fragments (streamed incrementally)
if delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index
if idx not in tool_calls_by_index:
tool_calls_by_index[idx] = {
"id": "",
"name": "",
"arguments": "",
}
entry = tool_calls_by_index[idx]
if tc.id:
entry["id"] = tc.id
if tc.function and tc.function.name:
entry["name"] = tc.function.name
if tc.function and tc.function.arguments:
entry["arguments"] += tc.function.arguments
# Close text block if we had one this round
if text_started:
yield StreamTextEnd(id=text_block_id)
text_started = False
text_block_id = str(uuid.uuid4())
# Accumulate text for session persistence
assistant_text += round_text
# No tool calls -> model is done
if not tool_calls_by_index:
yield StreamFinishStep()
step_open = False
break
# Close step before tool execution
yield StreamFinishStep()
step_open = False
# Append the assistant message with tool_calls to context.
assistant_msg: dict[str, Any] = {"role": "assistant"}
if round_text:
assistant_msg["content"] = round_text
assistant_msg["tool_calls"] = [
{
"id": tc["id"],
"type": "function",
"function": {
"name": tc["name"],
"arguments": tc["arguments"] or "{}",
},
}
for tc in tool_calls_by_index.values()
]
openai_messages.append(assistant_msg)
# Execute each tool call and stream events
for tc in tool_calls_by_index.values():
tool_call_id = tc["id"]
tool_name = tc["name"]
raw_args = tc["arguments"] or "{}"
try:
tool_args = orjson.loads(raw_args)
except orjson.JSONDecodeError as parse_err:
parse_error = (
f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
)
logger.warning("[Baseline] %s", parse_error)
yield StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=parse_error,
success=False,
)
openai_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": parse_error,
}
)
continue
yield StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
yield StreamToolInputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
input=tool_args,
)
# Execute via shared tool registry
try:
result: StreamToolOutputAvailable = await execute_tool(
tool_name=tool_name,
parameters=tool_args,
user_id=user_id,
session=session,
tool_call_id=tool_call_id,
)
yield result
tool_output = (
result.output
if isinstance(result.output, str)
else str(result.output)
)
except Exception as e:
error_output = f"Tool execution error: {e}"
logger.error(
"[Baseline] Tool %s failed: %s",
tool_name,
error_output,
exc_info=True,
)
yield StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=error_output,
success=False,
)
tool_output = error_output
# Append tool result to context for next round
openai_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": tool_output,
}
)
else:
# for-loop exhausted without break -> tool-round limit hit
limit_msg = (
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
"without a final response."
)
logger.error("[Baseline] %s", limit_msg)
yield StreamError(
errorText=limit_msg,
code="baseline_tool_round_limit",
)
except Exception as e:
error_msg = str(e) or type(e).__name__
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
# Close any open text/step before emitting error
if text_started:
yield StreamTextEnd(id=text_block_id)
if step_open:
yield StreamFinishStep()
yield StreamError(errorText=error_msg, code="baseline_error")
# Still persist whatever we got
finally:
# Close Langfuse trace context
if _trace_ctx is not None:
try:
_trace_ctx.__exit__(None, None, None)
except Exception:
logger.warning("[Baseline] Langfuse trace context teardown failed")
# Persist assistant response
if assistant_text:
session.messages.append(
ChatMessage(role="assistant", content=assistant_text)
)
try:
await upsert_chat_session(session)
except Exception as persist_err:
logger.error("[Baseline] Failed to persist session: %s", persist_err)
yield StreamFinish()

View File

@@ -0,0 +1,99 @@
import logging
from os import getenv
import pytest
from backend.copilot.baseline import stream_chat_completion_baseline
from backend.copilot.model import (
create_chat_session,
get_chat_session,
upsert_chat_session,
)
from backend.copilot.response_model import (
StreamError,
StreamFinish,
StreamStart,
StreamTextDelta,
)
logger = logging.getLogger(__name__)
@pytest.mark.asyncio(loop_scope="session")
async def test_baseline_multi_turn(setup_test_user, test_user_id):
"""Test that the baseline LLM path streams responses and maintains history.
Turn 1: Send a message with a unique keyword.
Turn 2: Ask the model to recall the keyword — proving conversation history
is correctly passed to the single-call LLM.
"""
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
# --- Turn 1: send a message with a unique keyword ---
keyword = "QUASAR99"
turn1_msg = (
f"Please remember this special keyword: {keyword}. "
"Just confirm you've noted it, keep your response brief."
)
turn1_text = ""
turn1_errors: list[str] = []
got_start = False
got_finish = False
async for chunk in stream_chat_completion_baseline(
session.session_id,
turn1_msg,
user_id=test_user_id,
):
if isinstance(chunk, StreamStart):
got_start = True
elif isinstance(chunk, StreamTextDelta):
turn1_text += chunk.delta
elif isinstance(chunk, StreamError):
turn1_errors.append(chunk.errorText)
elif isinstance(chunk, StreamFinish):
got_finish = True
assert got_start, "Turn 1 did not yield StreamStart"
assert got_finish, "Turn 1 did not yield StreamFinish"
assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
assert turn1_text, "Turn 1 produced no text"
logger.info(f"Turn 1 response: {turn1_text[:100]}")
# Reload session for turn 2
session = await get_chat_session(session.session_id, test_user_id)
assert session, "Session not found after turn 1"
# Verify messages were persisted (user + assistant)
assert (
len(session.messages) >= 2
), f"Expected at least 2 messages after turn 1, got {len(session.messages)}"
# --- Turn 2: ask model to recall the keyword ---
turn2_msg = "What was the special keyword I asked you to remember?"
turn2_text = ""
turn2_errors: list[str] = []
async for chunk in stream_chat_completion_baseline(
session.session_id,
turn2_msg,
user_id=test_user_id,
session=session,
):
if isinstance(chunk, StreamTextDelta):
turn2_text += chunk.delta
elif isinstance(chunk, StreamError):
turn2_errors.append(chunk.errorText)
assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
assert turn2_text, "Turn 2 produced no text"
assert keyword in turn2_text, (
f"Model did not recall keyword '{keyword}' in turn 2. "
f"Response: {turn2_text[:200]}"
)
logger.info(f"Turn 2 recalled keyword successfully: {turn2_text[:100]}")

View File

@@ -26,11 +26,6 @@ class ChatConfig(BaseSettings):
# Session TTL Configuration - 12 hours
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
# Streaming Configuration
max_retries: int = Field(
default=3,
description="Max retries for fallback path (SDK handles retries internally)",
)
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
max_agent_schedules: int = Field(
default=30, description="Maximum number of agent schedules"
@@ -67,11 +62,15 @@ class ChatConfig(BaseSettings):
default="CoPilot Prompt",
description="Name of the prompt in Langfuse to fetch",
)
langfuse_prompt_cache_ttl: int = Field(
default=300,
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
)
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,
description="Use Claude Agent SDK for chat completions",
description="Use Claude Agent SDK (True) or OpenAI-compatible LLM baseline (False)",
)
claude_agent_model: str | None = Field(
default=None,
@@ -85,25 +84,60 @@ class ChatConfig(BaseSettings):
)
claude_agent_max_subtasks: int = Field(
default=10,
description="Max number of sub-agent Tasks the SDK can spawn per session.",
description="Max number of concurrent sub-agent Tasks the SDK can run per session.",
)
claude_agent_use_resume: bool = Field(
default=True,
description="Use --resume for multi-turn conversations instead of "
"history compression. Falls back to compression when unavailable.",
)
# Extended thinking configuration for Claude models
thinking_enabled: bool = Field(
default=True,
description="Enable adaptive thinking for Claude models via OpenRouter",
use_claude_code_subscription: bool = Field(
default=False,
description="For personal/dev use: use Claude Code CLI subscription auth instead of API keys. Requires `claude login` on the host. Only works with SDK mode.",
)
# E2B Sandbox Configuration
use_e2b_sandbox: bool = Field(
default=True,
description="Use E2B cloud sandboxes for persistent bash/python execution. "
"When enabled, bash_exec routes commands to E2B and SDK file tools "
"operate directly on the sandbox via E2B's filesystem API.",
)
e2b_api_key: str | None = Field(
default=None,
description="E2B API key. Falls back to E2B_API_KEY environment variable.",
)
e2b_sandbox_template: str = Field(
default="base",
description="E2B sandbox template to use for copilot sessions.",
)
e2b_sandbox_timeout: int = Field(
default=43200, # 12 hours — same as session_ttl
description="E2B sandbox keepalive timeout in seconds.",
)
@field_validator("use_e2b_sandbox", mode="before")
@classmethod
def get_use_e2b_sandbox(cls, v):
"""Get use_e2b_sandbox from environment if not provided."""
env_val = os.getenv("CHAT_USE_E2B_SANDBOX", "").lower()
if env_val:
return env_val in ("true", "1", "yes", "on")
return True if v is None else v
@field_validator("e2b_api_key", mode="before")
@classmethod
def get_e2b_api_key(cls, v):
"""Get E2B API key from environment if not provided."""
if not v:
v = os.getenv("CHAT_E2B_API_KEY") or os.getenv("E2B_API_KEY")
return v
@field_validator("api_key", mode="before")
@classmethod
def get_api_key(cls, v):
"""Get API key from environment if not provided."""
if v is None:
if not v:
# Try to get from environment variables
# First check for CHAT_API_KEY (Pydantic prefix)
v = os.getenv("CHAT_API_KEY")
@@ -113,13 +147,16 @@ class ChatConfig(BaseSettings):
if not v:
# Fall back to OPENAI_API_KEY
v = os.getenv("OPENAI_API_KEY")
# Note: ANTHROPIC_API_KEY is intentionally NOT included here.
# The SDK CLI picks it up from the env directly. Including it
# would pair it with the OpenRouter base_url, causing auth failures.
return v
@field_validator("base_url", mode="before")
@classmethod
def get_base_url(cls, v):
"""Get base URL from environment if not provided."""
if v is None:
if not v:
# Check for OpenRouter or custom base URL
v = os.getenv("CHAT_BASE_URL")
if not v:
@@ -141,6 +178,15 @@ class ChatConfig(BaseSettings):
# Default to True (SDK enabled by default)
return True if v is None else v
@field_validator("use_claude_code_subscription", mode="before")
@classmethod
def get_use_claude_code_subscription(cls, v):
"""Get use_claude_code_subscription from environment if not provided."""
env_val = os.getenv("CHAT_USE_CLAUDE_CODE_SUBSCRIPTION", "").lower()
if env_val:
return env_val in ("true", "1", "yes", "on")
return False if v is None else v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",

View File

@@ -0,0 +1,11 @@
"""Shared constants for the CoPilot module."""
# Special message prefixes for text-based markers (parsed by frontend).
# The hex suffix makes accidental LLM generation of these strings virtually
# impossible, avoiding false-positive marker detection in normal conversation.
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
# Compaction notice messages shown to users.
COMPACTION_DONE_MSG = "Earlier messages were summarized to fit within context limits."
COMPACTION_TOOL_NAME = "context_compaction"

View File

@@ -16,7 +16,7 @@ from prisma.types import (
)
from backend.data import db
from backend.util.json import SafeJson
from backend.util.json import SafeJson, sanitize_string
from .model import ChatMessage, ChatSession, ChatSessionInfo
@@ -101,15 +101,16 @@ async def add_chat_message(
"sequence": sequence,
}
# Add optional string fields
# Add optional string fields — sanitize to strip PostgreSQL-incompatible
# control characters (null bytes etc.) that may appear in tool outputs.
if content is not None:
data["content"] = content
data["content"] = sanitize_string(content)
if name is not None:
data["name"] = name
if tool_call_id is not None:
data["toolCallId"] = tool_call_id
if refusal is not None:
data["refusal"] = refusal
data["refusal"] = sanitize_string(refusal)
# Add optional JSON fields only when they have values
if tool_calls is not None:
@@ -170,15 +171,16 @@ async def add_chat_messages_batch(
"createdAt": now,
}
# Add optional string fields
# Add optional string fields — sanitize to strip
# PostgreSQL-incompatible control characters.
if msg.get("content") is not None:
data["content"] = msg["content"]
data["content"] = sanitize_string(msg["content"])
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = msg["refusal"]
data["refusal"] = sanitize_string(msg["refusal"])
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
@@ -312,7 +314,7 @@ async def update_tool_message_content(
"toolCallId": tool_call_id,
},
data={
"content": new_content,
"content": sanitize_string(new_content),
},
)
if result == 0:

View File

@@ -4,6 +4,7 @@ This module contains the CoPilotExecutor class that consumes chat tasks from
RabbitMQ and processes them using a thread pool, following the graph executor pattern.
"""
import asyncio
import logging
import os
import threading
@@ -409,14 +410,19 @@ class CoPilotExecutor(AppProcess):
def on_run_done(f: Future):
logger.info(f"Run completed for {session_id}")
error_msg = None
try:
if exec_error := f.exception():
logger.error(f"Execution for {session_id} failed: {exec_error}")
error_msg = str(exec_error) or type(exec_error).__name__
logger.error(f"Execution for {session_id} failed: {error_msg}")
ack_message(reject=True, requeue=False)
else:
ack_message(reject=False, requeue=False)
except asyncio.CancelledError:
logger.info(f"Run completion callback cancelled for {session_id}")
except BaseException as e:
logger.exception(f"Error in run completion callback: {e}")
error_msg = str(e) or type(e).__name__
logger.exception(f"Error in run completion callback: {error_msg}")
finally:
# Release the cluster lock
if session_id in self._task_locks:

View File

@@ -6,11 +6,13 @@ in a thread-local context, following the graph executor pattern.
import asyncio
import logging
import os
import subprocess
import threading
import time
from backend.copilot import service as copilot_service
from backend.copilot import stream_registry
from backend.copilot.baseline import stream_chat_completion_baseline
from backend.copilot.config import ChatConfig
from backend.copilot.response_model import StreamFinish
from backend.copilot.sdk import service as sdk_service
@@ -108,8 +110,41 @@ class CoPilotProcessor:
)
self.execution_thread.start()
# Skip the SDK's per-request CLI version check — the bundled CLI is
# already version-matched to the SDK package.
os.environ.setdefault("CLAUDE_AGENT_SDK_SKIP_VERSION_CHECK", "1")
# Pre-warm the bundled CLI binary so the OS page-caches the ~185 MB
# executable. First spawn pays ~1.2 s; subsequent spawns ~0.65 s.
self._prewarm_cli()
logger.info(f"[CoPilotExecutor] Worker {self.tid} started")
def _prewarm_cli(self) -> None:
"""Run the bundled CLI binary once to warm OS page caches."""
try:
from claude_agent_sdk._internal.transport.subprocess_cli import (
SubprocessCLITransport,
)
cli_path = SubprocessCLITransport._find_bundled_cli(None) # type: ignore[arg-type]
if cli_path:
result = subprocess.run(
[cli_path, "-v"],
capture_output=True,
timeout=10,
)
if result.returncode == 0:
logger.info(f"[CoPilotExecutor] CLI pre-warm done: {cli_path}")
else:
logger.warning(
"[CoPilotExecutor] CLI pre-warm failed (rc=%d): %s",
result.returncode, # type: ignore[reportCallIssue]
cli_path,
)
except Exception as e:
logger.debug(f"[CoPilotExecutor] CLI pre-warm skipped: {e}")
def cleanup(self):
"""Clean up event-loop-bound resources before the loop is destroyed.
@@ -119,13 +154,16 @@ class CoPilotProcessor:
"""
from backend.util.workspace_storage import shutdown_workspace_storage
coro = shutdown_workspace_storage()
try:
future = asyncio.run_coroutine_threadsafe(
shutdown_workspace_storage(), self.execution_loop
)
future = asyncio.run_coroutine_threadsafe(coro, self.execution_loop)
future.result(timeout=5)
except Exception as e:
logger.warning(f"[CoPilotExecutor] Worker {self.tid} cleanup error: {e}")
coro.close() # Prevent "coroutine was never awaited" warning
error_msg = str(e) or type(e).__name__
logger.warning(
f"[CoPilotExecutor] Worker {self.tid} cleanup error: {error_msg}"
)
# Stop the event loop
self.execution_loop.call_soon_threadsafe(self.execution_loop.stop)
@@ -157,47 +195,30 @@ class CoPilotProcessor:
start_time = time.monotonic()
try:
# Run the async execution in our event loop
future = asyncio.run_coroutine_threadsafe(
self._execute_async(entry, cancel, cluster_lock, log),
self.execution_loop,
)
# Run the async execution in our event loop
future = asyncio.run_coroutine_threadsafe(
self._execute_async(entry, cancel, cluster_lock, log),
self.execution_loop,
)
# Wait for completion, checking cancel periodically
while not future.done():
try:
future.result(timeout=1.0)
except asyncio.TimeoutError:
if cancel.is_set():
log.info("Cancellation requested")
future.cancel()
break
# Refresh cluster lock to maintain ownership
cluster_lock.refresh()
if not future.cancelled():
# Get result to propagate any exceptions
future.result()
elapsed = time.monotonic() - start_time
log.info(f"Execution completed in {elapsed:.2f}s")
except BaseException as e:
elapsed = time.monotonic() - start_time
log.error(f"Execution failed after {elapsed:.2f}s: {e}")
# Safety net: if _execute_async's error handler failed to mark
# the session (e.g. RuntimeError from SDK cleanup), do it here.
# Wait for completion, checking cancel periodically
while not future.done():
try:
asyncio.run_coroutine_threadsafe(
stream_registry.mark_session_completed(
entry.session_id, error_message=str(e) or "Unknown error"
),
self.execution_loop,
).result(timeout=5.0)
except Exception as cleanup_err:
log.error(f"Safety net mark_session_completed failed: {cleanup_err}")
raise
future.result(timeout=1.0)
except asyncio.TimeoutError:
if cancel.is_set():
log.info("Cancellation requested")
future.cancel()
break
# Refresh cluster lock to maintain ownership
cluster_lock.refresh()
if not future.cancelled():
# Get result to propagate any exceptions
future.result()
elapsed = time.monotonic() - start_time
log.info(f"Execution completed in {elapsed:.2f}s")
async def _execute_async(
self,
@@ -208,7 +229,7 @@ class CoPilotProcessor:
):
"""Async execution logic for a CoPilot turn.
Calls the stream_chat_completion service function and publishes
Calls the chat completion service (SDK or baseline) and publishes
results to the stream registry.
Args:
@@ -219,11 +240,13 @@ class CoPilotProcessor:
"""
last_refresh = time.monotonic()
refresh_interval = 30.0 # Refresh lock every 30 seconds
error_msg = None
try:
# Choose service based on LaunchDarkly flag
# Choose service based on LaunchDarkly flag.
# Claude Code subscription forces SDK mode (CLI subprocess auth).
config = ChatConfig()
use_sdk = await is_feature_enabled(
use_sdk = config.use_claude_code_subscription or await is_feature_enabled(
Flag.COPILOT_SDK,
entry.user_id or "anonymous",
default=config.use_claude_agent_sdk,
@@ -231,9 +254,9 @@ class CoPilotProcessor:
stream_fn = (
sdk_service.stream_chat_completion_sdk
if use_sdk
else copilot_service.stream_chat_completion
else stream_chat_completion_baseline
)
log.info(f"Using {'SDK' if use_sdk else 'standard'} service")
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
# Stream chat completion and publish chunks to Redis.
async for chunk in stream_fn(
@@ -242,6 +265,7 @@ class CoPilotProcessor:
is_user_message=entry.is_user_message,
user_id=entry.user_id,
context=entry.context,
file_ids=entry.file_ids,
):
if cancel.is_set():
log.info("Cancel requested, breaking stream")
@@ -264,17 +288,26 @@ class CoPilotProcessor:
exc_info=True,
)
error_message = "Operation cancelled" if cancel.is_set() else None
await stream_registry.mark_session_completed(
entry.session_id, error_message=error_message
)
# Stream loop completed
if cancel.is_set():
log.info("Stream cancelled by user")
except BaseException as e:
log.error(f"Turn failed: {e}")
# Handle all exceptions (including CancelledError) with appropriate logging
if isinstance(e, asyncio.CancelledError):
log.info("Turn cancelled")
error_msg = "Operation cancelled"
else:
error_msg = str(e) or type(e).__name__
log.error(f"Turn failed: {error_msg}")
raise
finally:
# If no exception but user cancelled, still mark as cancelled
if not error_msg and cancel.is_set():
error_msg = "Operation cancelled"
try:
await stream_registry.mark_session_completed(
entry.session_id, error_message=str(e) or "Unknown error"
entry.session_id, error_message=error_msg
)
except Exception as mark_err:
log.error(f"mark_session_completed also failed: {mark_err}")
raise
log.error(f"Failed to mark session completed: {mark_err}")

View File

@@ -153,6 +153,9 @@ class CoPilotExecutionEntry(BaseModel):
context: dict[str, str] | None = None
"""Optional context for the message (e.g., {url: str, content: str})"""
file_ids: list[str] | None = None
"""Workspace file IDs attached to the user's message"""
class CancelCoPilotEvent(BaseModel):
"""Event to cancel a CoPilot operation."""
@@ -171,6 +174,7 @@ async def enqueue_copilot_turn(
turn_id: str,
is_user_message: bool = True,
context: dict[str, str] | None = None,
file_ids: list[str] | None = None,
) -> None:
"""Enqueue a CoPilot task for processing by the executor service.
@@ -181,6 +185,7 @@ async def enqueue_copilot_turn(
turn_id: Per-turn UUID for Redis stream isolation
is_user_message: Whether the message is from the user (vs system/assistant)
context: Optional context for the message (e.g., {url: str, content: str})
file_ids: Optional workspace file IDs attached to the user's message
"""
from backend.util.clients import get_async_copilot_queue
@@ -191,6 +196,7 @@ async def enqueue_copilot_turn(
message=message,
is_user_message=is_user_message,
context=context,
file_ids=file_ids,
)
queue_client = await get_async_copilot_queue()

View File

@@ -672,6 +672,16 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
async with _session_locks_mutex:
_session_locks.pop(session_id, None)
# Shut down any local browser daemon for this session (best-effort).
# Inline import required: all tool modules import ChatSession from this
# module, so any top-level import from tools.* would create a cycle.
try:
from .tools.agent_browser import close_browser_session
await close_browser_session(session_id, user_id=user_id)
except Exception as e:
logger.debug(f"Browser cleanup for session {session_id}: {e}")
return True
@@ -695,19 +705,10 @@ async def update_session_title(session_id: str, title: str) -> bool:
logger.warning(f"Session {session_id} not found for title update")
return False
# Update title in cache if it exists (instead of invalidating).
# This prevents race conditions where cache invalidation causes
# the frontend to see stale DB data while streaming is still in progress.
try:
cached = await _get_session_from_cache(session_id)
if cached:
cached.title = title
await cache_chat_session(cached)
except Exception as e:
# Not critical - title will be correct on next full cache refresh
logger.warning(
f"Failed to update title in cache for session {session_id}: {e}"
)
# Invalidate the cache so the next access reloads from DB with the
# updated title. This avoids a read-modify-write on the full session
# blob, which could overwrite concurrent message updates.
await invalidate_session_cache(session_id)
return True
except Exception as e:

View File

@@ -1,269 +0,0 @@
"""Tests for parallel tool call execution in CoPilot.
These tests mock _yield_tool_call to avoid importing the full copilot stack
which requires Prisma, DB connections, etc.
"""
import asyncio
import time
from typing import Any, cast
import pytest
@pytest.mark.asyncio
async def test_parallel_tool_calls_run_concurrently():
"""Multiple tool calls should complete in ~max(delays), not sum(delays)."""
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from backend.copilot.service import _execute_tool_calls_parallel
n_tools = 3
delay_per_tool = 0.2
tool_calls = [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": f"tool_{i}", "arguments": "{}"},
}
for i in range(n_tools)
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
original_yield = None
async def fake_yield(tc_list, idx, sess):
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"],
toolName=tc_list[idx]["function"]["name"],
input={},
)
await asyncio.sleep(delay_per_tool)
yield StreamToolOutputAvailable(
toolCallId=tc_list[idx]["id"],
toolName=tc_list[idx]["function"]["name"],
output="{}",
)
import backend.copilot.service as svc
original_yield = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
start = time.monotonic()
events = []
async for event in _execute_tool_calls_parallel(
tool_calls, cast(Any, FakeSession())
):
events.append(event)
elapsed = time.monotonic() - start
finally:
svc._yield_tool_call = original_yield
assert len(events) == n_tools * 2
# Parallel: should take ~delay, not ~n*delay
assert elapsed < delay_per_tool * (
n_tools - 0.5
), f"Took {elapsed:.2f}s, expected parallel (~{delay_per_tool}s)"
@pytest.mark.asyncio
async def test_single_tool_call_works():
"""Single tool call should work identically."""
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from backend.copilot.service import _execute_tool_calls_parallel
tool_calls = [
{
"id": "call_0",
"type": "function",
"function": {"name": "t", "arguments": "{}"},
}
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
async def fake_yield(tc_list, idx, sess):
yield StreamToolInputAvailable(toolCallId="call_0", toolName="t", input={})
yield StreamToolOutputAvailable(toolCallId="call_0", toolName="t", output="{}")
import backend.copilot.service as svc
orig = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
events = [
e
async for e in _execute_tool_calls_parallel(
tool_calls, cast(Any, FakeSession())
)
]
finally:
svc._yield_tool_call = orig
assert len(events) == 2
@pytest.mark.asyncio
async def test_retryable_error_propagates():
"""Retryable errors should be raised after all tools finish."""
from backend.copilot.response_model import StreamToolOutputAvailable
from backend.copilot.service import _execute_tool_calls_parallel
tool_calls = [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": f"t_{i}", "arguments": "{}"},
}
for i in range(2)
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
async def fake_yield(tc_list, idx, sess):
if idx == 1:
raise KeyError("bad")
from backend.copilot.response_model import StreamToolInputAvailable
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName="t_0", input={}
)
await asyncio.sleep(0.05)
yield StreamToolOutputAvailable(
toolCallId=tc_list[idx]["id"], toolName="t_0", output="{}"
)
import backend.copilot.service as svc
orig = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
events = []
with pytest.raises(KeyError):
async for event in _execute_tool_calls_parallel(
tool_calls, cast(Any, FakeSession())
):
events.append(event)
# First tool's events should still be yielded
assert any(isinstance(e, StreamToolOutputAvailable) for e in events)
finally:
svc._yield_tool_call = orig
@pytest.mark.asyncio
async def test_session_shared_across_parallel_tools():
"""All parallel tools should receive the same session instance."""
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from backend.copilot.service import _execute_tool_calls_parallel
tool_calls = [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": f"t_{i}", "arguments": "{}"},
}
for i in range(3)
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
observed_sessions = []
async def fake_yield(tc_list, idx, sess):
observed_sessions.append(sess)
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
)
yield StreamToolOutputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", output="{}"
)
import backend.copilot.service as svc
orig = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
async for _ in _execute_tool_calls_parallel(
tool_calls, cast(Any, FakeSession())
):
pass
finally:
svc._yield_tool_call = orig
assert len(observed_sessions) == 3
assert observed_sessions[0] is observed_sessions[1] is observed_sessions[2]
@pytest.mark.asyncio
async def test_cancellation_cleans_up():
"""Generator close should cancel in-flight tasks."""
from backend.copilot.response_model import StreamToolInputAvailable
from backend.copilot.service import _execute_tool_calls_parallel
tool_calls = [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": f"t_{i}", "arguments": "{}"},
}
for i in range(2)
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
started = asyncio.Event()
async def fake_yield(tc_list, idx, sess):
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
)
started.set()
await asyncio.sleep(10) # simulate long-running
import backend.copilot.service as svc
orig = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
gen = _execute_tool_calls_parallel(tool_calls, cast(Any, FakeSession()))
await gen.__anext__() # get first event
await started.wait()
await gen.aclose() # close generator
finally:
svc._yield_tool_call = orig
# If we get here without hanging, cleanup worked

View File

@@ -13,6 +13,7 @@ from typing import Any
from pydantic import BaseModel, Field
from backend.util.json import dumps as json_dumps
from backend.util.truncate import truncate
logger = logging.getLogger(__name__)
@@ -150,6 +151,9 @@ class StreamToolInputAvailable(StreamBaseResponse):
)
_MAX_TOOL_OUTPUT_SIZE = 100_000 # ~100 KB; truncate to avoid bloating SSE/DB
class StreamToolOutputAvailable(StreamBaseResponse):
"""Tool execution result."""
@@ -164,6 +168,10 @@ class StreamToolOutputAvailable(StreamBaseResponse):
default=True, description="Whether the tool execution succeeded"
)
def model_post_init(self, __context: Any) -> None:
"""Truncate oversized outputs after construction."""
self.output = truncate(self.output, _MAX_TOOL_OUTPUT_SIZE)
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-spec fields."""
data = {

View File

@@ -0,0 +1,239 @@
"""Compaction tracking for SDK-based chat sessions.
Encapsulates the state machine and event emission for context compaction,
both pre-query (history compressed before SDK query) and SDK-internal
(PreCompact hook fires mid-stream).
All compaction-related helpers live here: event builders, message filtering,
persistence, and the ``CompactionTracker`` state machine.
"""
import asyncio
import logging
import uuid
from collections.abc import Callable
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
from ..model import ChatMessage, ChatSession
from ..response_model import (
StreamBaseResponse,
StreamFinishStep,
StreamStartStep,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Event builders (private — use CompactionTracker or compaction_events)
# ---------------------------------------------------------------------------
def _start_events(tool_call_id: str) -> list[StreamBaseResponse]:
"""Build the opening events for a compaction tool call."""
return [
StreamStartStep(),
StreamToolInputStart(toolCallId=tool_call_id, toolName=COMPACTION_TOOL_NAME),
StreamToolInputAvailable(
toolCallId=tool_call_id, toolName=COMPACTION_TOOL_NAME, input={}
),
]
def _end_events(tool_call_id: str, message: str) -> list[StreamBaseResponse]:
"""Build the closing events for a compaction tool call."""
return [
StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=COMPACTION_TOOL_NAME,
output=message,
),
StreamFinishStep(),
]
def _new_tool_call_id() -> str:
return f"compaction-{uuid.uuid4().hex[:12]}"
# ---------------------------------------------------------------------------
# Public event builder
# ---------------------------------------------------------------------------
def emit_compaction(session: ChatSession) -> list[StreamBaseResponse]:
"""Create, persist, and return a self-contained compaction tool call.
Convenience for callers that don't use ``CompactionTracker`` (e.g. the
legacy non-SDK streaming path in ``service.py``).
"""
tc_id = _new_tool_call_id()
evts = compaction_events(COMPACTION_DONE_MSG, tool_call_id=tc_id)
_persist(session, tc_id, COMPACTION_DONE_MSG)
return evts
def compaction_events(
message: str, tool_call_id: str | None = None
) -> list[StreamBaseResponse]:
"""Emit a self-contained compaction tool call (already completed).
When *tool_call_id* is provided it is reused (e.g. for persistence that
must match an already-streamed start event). Otherwise a new ID is
generated.
"""
tc_id = tool_call_id or _new_tool_call_id()
return _start_events(tc_id) + _end_events(tc_id, message)
# ---------------------------------------------------------------------------
# Message filtering
# ---------------------------------------------------------------------------
def filter_compaction_messages(
messages: list[ChatMessage],
) -> list[ChatMessage]:
"""Remove synthetic compaction tool-call messages (UI-only artifacts).
Strips assistant messages whose only tool calls are compaction calls,
and their corresponding tool-result messages.
"""
compaction_ids: set[str] = set()
filtered: list[ChatMessage] = []
for msg in messages:
if msg.role == "assistant" and msg.tool_calls:
for tc in msg.tool_calls:
if tc.get("function", {}).get("name") == COMPACTION_TOOL_NAME:
compaction_ids.add(tc.get("id", ""))
real_calls = [
tc
for tc in msg.tool_calls
if tc.get("function", {}).get("name") != COMPACTION_TOOL_NAME
]
if not real_calls and not msg.content:
continue
if msg.role == "tool" and msg.tool_call_id in compaction_ids:
continue
filtered.append(msg)
return filtered
# ---------------------------------------------------------------------------
# Persistence
# ---------------------------------------------------------------------------
def _persist(session: ChatSession, tool_call_id: str, message: str) -> None:
"""Append compaction tool-call + result to session messages.
Compaction events are synthetic so they bypass the normal adapter
accumulation. This explicitly records them so they survive a page refresh.
"""
session.messages.append(
ChatMessage(
role="assistant",
content="",
tool_calls=[
{
"id": tool_call_id,
"type": "function",
"function": {
"name": COMPACTION_TOOL_NAME,
"arguments": "{}",
},
}
],
)
)
session.messages.append(
ChatMessage(role="tool", content=message, tool_call_id=tool_call_id)
)
# ---------------------------------------------------------------------------
# CompactionTracker — state machine for streaming sessions
# ---------------------------------------------------------------------------
class CompactionTracker:
"""Tracks compaction state and yields UI events.
Two compaction paths:
1. **Pre-query** — history compressed before the SDK query starts.
Call :meth:`emit_pre_query` to yield a self-contained tool call.
2. **SDK-internal** — ``PreCompact`` hook fires mid-stream.
Call :meth:`emit_start_if_ready` on heartbeat ticks and
:meth:`emit_end_if_ready` when a message arrives.
"""
def __init__(self) -> None:
self._compact_start = asyncio.Event()
self._start_emitted = False
self._done = False
self._tool_call_id = ""
@property
def on_compact(self) -> Callable[[], None]:
"""Callback for the PreCompact hook."""
return self._compact_start.set
# ------------------------------------------------------------------
# Pre-query compaction
# ------------------------------------------------------------------
def emit_pre_query(self, session: ChatSession) -> list[StreamBaseResponse]:
"""Emit + persist a self-contained compaction tool call."""
self._done = True
return emit_compaction(session)
# ------------------------------------------------------------------
# SDK-internal compaction
# ------------------------------------------------------------------
def reset_for_query(self) -> None:
"""Reset per-query state before a new SDK query."""
self._done = False
self._start_emitted = False
self._tool_call_id = ""
def emit_start_if_ready(self) -> list[StreamBaseResponse]:
"""If the PreCompact hook fired, emit start events (spinning tool)."""
if self._compact_start.is_set() and not self._start_emitted and not self._done:
self._compact_start.clear()
self._start_emitted = True
self._tool_call_id = _new_tool_call_id()
return _start_events(self._tool_call_id)
return []
async def emit_end_if_ready(self, session: ChatSession) -> list[StreamBaseResponse]:
"""If compaction is in progress, emit end events and persist."""
# Yield so pending hook tasks can set compact_start
await asyncio.sleep(0)
if self._done:
return []
if not self._start_emitted and not self._compact_start.is_set():
return []
if self._start_emitted:
# Close the open spinner
done_events = _end_events(self._tool_call_id, COMPACTION_DONE_MSG)
persist_id = self._tool_call_id
else:
# PreCompact fired but start never emitted — self-contained
persist_id = _new_tool_call_id()
done_events = compaction_events(
COMPACTION_DONE_MSG, tool_call_id=persist_id
)
self._compact_start.clear()
self._start_emitted = False
self._done = True
_persist(session, persist_id, COMPACTION_DONE_MSG)
return done_events

View File

@@ -0,0 +1,291 @@
"""Tests for sdk/compaction.py — event builders, filtering, persistence, and
CompactionTracker state machine."""
import pytest
from backend.copilot.constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.response_model import (
StreamFinishStep,
StreamStartStep,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from backend.copilot.sdk.compaction import (
CompactionTracker,
compaction_events,
emit_compaction,
filter_compaction_messages,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_session() -> ChatSession:
return ChatSession.new(user_id="test-user")
# ---------------------------------------------------------------------------
# compaction_events
# ---------------------------------------------------------------------------
class TestCompactionEvents:
def test_returns_start_and_end_events(self):
evts = compaction_events("done")
assert len(evts) == 5
assert isinstance(evts[0], StreamStartStep)
assert isinstance(evts[1], StreamToolInputStart)
assert isinstance(evts[2], StreamToolInputAvailable)
assert isinstance(evts[3], StreamToolOutputAvailable)
assert isinstance(evts[4], StreamFinishStep)
def test_uses_provided_tool_call_id(self):
evts = compaction_events("msg", tool_call_id="my-id")
tool_start = evts[1]
assert isinstance(tool_start, StreamToolInputStart)
assert tool_start.toolCallId == "my-id"
def test_generates_id_when_not_provided(self):
evts = compaction_events("msg")
tool_start = evts[1]
assert isinstance(tool_start, StreamToolInputStart)
assert tool_start.toolCallId.startswith("compaction-")
def test_tool_name_is_context_compaction(self):
evts = compaction_events("msg")
tool_start = evts[1]
assert isinstance(tool_start, StreamToolInputStart)
assert tool_start.toolName == COMPACTION_TOOL_NAME
# ---------------------------------------------------------------------------
# emit_compaction
# ---------------------------------------------------------------------------
class TestEmitCompaction:
def test_persists_to_session(self):
session = _make_session()
assert len(session.messages) == 0
evts = emit_compaction(session)
assert len(evts) == 5
# Should have appended 2 messages (assistant tool call + tool result)
assert len(session.messages) == 2
assert session.messages[0].role == "assistant"
assert session.messages[0].tool_calls is not None
assert (
session.messages[0].tool_calls[0]["function"]["name"]
== COMPACTION_TOOL_NAME
)
assert session.messages[1].role == "tool"
assert session.messages[1].content == COMPACTION_DONE_MSG
# ---------------------------------------------------------------------------
# filter_compaction_messages
# ---------------------------------------------------------------------------
class TestFilterCompactionMessages:
def test_removes_compaction_tool_calls(self):
msgs = [
ChatMessage(role="user", content="hello"),
ChatMessage(
role="assistant",
content="",
tool_calls=[
{
"id": "comp-1",
"type": "function",
"function": {"name": COMPACTION_TOOL_NAME, "arguments": "{}"},
}
],
),
ChatMessage(
role="tool", content=COMPACTION_DONE_MSG, tool_call_id="comp-1"
),
ChatMessage(role="assistant", content="world"),
]
filtered = filter_compaction_messages(msgs)
assert len(filtered) == 2
assert filtered[0].content == "hello"
assert filtered[1].content == "world"
def test_keeps_non_compaction_tool_calls(self):
msgs = [
ChatMessage(
role="assistant",
content="",
tool_calls=[
{
"id": "real-1",
"type": "function",
"function": {"name": "search", "arguments": "{}"},
}
],
),
ChatMessage(role="tool", content="result", tool_call_id="real-1"),
]
filtered = filter_compaction_messages(msgs)
assert len(filtered) == 2
def test_keeps_assistant_with_content_and_compaction_call(self):
"""If assistant message has both content and a compaction tool call,
the message is kept (has real content)."""
msgs = [
ChatMessage(
role="assistant",
content="I have content",
tool_calls=[
{
"id": "comp-1",
"type": "function",
"function": {"name": COMPACTION_TOOL_NAME, "arguments": "{}"},
}
],
),
]
filtered = filter_compaction_messages(msgs)
assert len(filtered) == 1
def test_empty_list(self):
assert filter_compaction_messages([]) == []
# ---------------------------------------------------------------------------
# CompactionTracker
# ---------------------------------------------------------------------------
class TestCompactionTracker:
def test_on_compact_sets_event(self):
tracker = CompactionTracker()
tracker.on_compact()
assert tracker._compact_start.is_set()
def test_emit_start_if_ready_no_event(self):
tracker = CompactionTracker()
assert tracker.emit_start_if_ready() == []
def test_emit_start_if_ready_with_event(self):
tracker = CompactionTracker()
tracker.on_compact()
evts = tracker.emit_start_if_ready()
assert len(evts) == 3
assert isinstance(evts[0], StreamStartStep)
assert isinstance(evts[1], StreamToolInputStart)
assert isinstance(evts[2], StreamToolInputAvailable)
def test_emit_start_only_once(self):
tracker = CompactionTracker()
tracker.on_compact()
evts1 = tracker.emit_start_if_ready()
assert len(evts1) == 3
# Second call should return empty
evts2 = tracker.emit_start_if_ready()
assert evts2 == []
@pytest.mark.asyncio
async def test_emit_end_after_start(self):
tracker = CompactionTracker()
session = _make_session()
tracker.on_compact()
tracker.emit_start_if_ready()
evts = await tracker.emit_end_if_ready(session)
assert len(evts) == 2
assert isinstance(evts[0], StreamToolOutputAvailable)
assert isinstance(evts[1], StreamFinishStep)
# Should persist
assert len(session.messages) == 2
@pytest.mark.asyncio
async def test_emit_end_without_start_self_contained(self):
"""If PreCompact fired but start was never emitted, emit_end
produces a self-contained compaction event."""
tracker = CompactionTracker()
session = _make_session()
tracker.on_compact()
# Don't call emit_start_if_ready
evts = await tracker.emit_end_if_ready(session)
assert len(evts) == 5 # Full self-contained event
assert isinstance(evts[0], StreamStartStep)
assert len(session.messages) == 2
@pytest.mark.asyncio
async def test_emit_end_no_op_when_done(self):
tracker = CompactionTracker()
session = _make_session()
tracker.on_compact()
tracker.emit_start_if_ready()
await tracker.emit_end_if_ready(session)
# Second call should be no-op
evts = await tracker.emit_end_if_ready(session)
assert evts == []
@pytest.mark.asyncio
async def test_emit_end_no_op_when_nothing_happened(self):
tracker = CompactionTracker()
session = _make_session()
evts = await tracker.emit_end_if_ready(session)
assert evts == []
def test_emit_pre_query(self):
tracker = CompactionTracker()
session = _make_session()
evts = tracker.emit_pre_query(session)
assert len(evts) == 5
assert len(session.messages) == 2
assert tracker._done is True
def test_reset_for_query(self):
tracker = CompactionTracker()
tracker._done = True
tracker._start_emitted = True
tracker._tool_call_id = "old"
tracker.reset_for_query()
assert tracker._done is False
assert tracker._start_emitted is False
assert tracker._tool_call_id == ""
@pytest.mark.asyncio
async def test_pre_query_blocks_sdk_compaction(self):
"""After pre-query compaction, SDK compaction events are suppressed."""
tracker = CompactionTracker()
session = _make_session()
tracker.emit_pre_query(session)
tracker.on_compact()
evts = tracker.emit_start_if_ready()
assert evts == [] # _done blocks it
@pytest.mark.asyncio
async def test_reset_allows_new_compaction(self):
"""After reset_for_query, compaction can fire again."""
tracker = CompactionTracker()
session = _make_session()
tracker.emit_pre_query(session)
tracker.reset_for_query()
tracker.on_compact()
evts = tracker.emit_start_if_ready()
assert len(evts) == 3 # Start events emitted
@pytest.mark.asyncio
async def test_tool_call_id_consistency(self):
"""Start and end events use the same tool_call_id."""
tracker = CompactionTracker()
session = _make_session()
tracker.on_compact()
start_evts = tracker.emit_start_if_ready()
end_evts = await tracker.emit_end_if_ready(session)
start_evt = start_evts[1]
end_evt = end_evts[0]
assert isinstance(start_evt, StreamToolInputStart)
assert isinstance(end_evt, StreamToolOutputAvailable)
assert start_evt.toolCallId == end_evt.toolCallId
# Persisted ID should also match
tool_calls = session.messages[0].tool_calls
assert tool_calls is not None
assert tool_calls[0]["id"] == start_evt.toolCallId

View File

@@ -10,6 +10,7 @@ import asyncio
import logging
import uuid
from collections.abc import AsyncGenerator
from typing import Any
from ..model import ChatSession
from ..response_model import StreamBaseResponse, StreamStart, StreamTextDelta
@@ -26,6 +27,7 @@ async def stream_chat_completion_dummy(
retry_count: int = 0,
session: ChatSession | None = None,
context: dict[str, str] | None = None,
**_kwargs: Any,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Stream dummy chat completion for testing.

View File

@@ -0,0 +1,362 @@
"""MCP file-tool handlers that route to the E2B cloud sandbox.
When E2B is active, these tools replace the SDK built-in Read/Write/Edit/
Glob/Grep so that all file operations share the same ``/home/user``
filesystem as ``bash_exec``.
SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
"""
from __future__ import annotations
import itertools
import json
import logging
import os
import shlex
from typing import Any, Callable
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
logger = logging.getLogger(__name__)
# Lazy imports to break circular dependency with tool_adapter.
def _get_sandbox(): # type: ignore[return]
from .tool_adapter import get_current_sandbox # noqa: E402
return get_current_sandbox()
def _is_allowed_local(path: str) -> bool:
from .tool_adapter import is_allowed_local_path # noqa: E402
return is_allowed_local_path(path)
def _resolve_remote(path: str) -> str:
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
Raises :class:`ValueError` if the resolved path escapes the sandbox.
"""
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
normalized = os.path.normpath(candidate)
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
return normalized
def _mcp(text: str, *, error: bool = False) -> dict[str, Any]:
if error:
text = json.dumps({"error": text, "type": "error"})
return {"content": [{"type": "text", "text": text}], "isError": error}
def _get_sandbox_and_path(
file_path: str,
) -> tuple[Any, str] | dict[str, Any]:
"""Common preamble: get sandbox + resolve path, or return MCP error."""
sandbox = _get_sandbox()
if sandbox is None:
return _mcp("No E2B sandbox available", error=True)
try:
remote = _resolve_remote(file_path)
except ValueError as exc:
return _mcp(str(exc), error=True)
return sandbox, remote
# Tool handlers
async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
file_path: str = args.get("file_path", "")
offset: int = max(0, int(args.get("offset", 0)))
limit: int = max(1, int(args.get("limit", 2000)))
if not file_path:
return _mcp("file_path is required", error=True)
# SDK-internal paths (tool-results, ephemeral working dir) stay on the host.
if _is_allowed_local(file_path):
return _read_local(file_path, offset, limit)
result = _get_sandbox_and_path(file_path)
if isinstance(result, dict):
return result
sandbox, remote = result
try:
raw: bytes = await sandbox.files.read(remote, format="bytes")
content = raw.decode("utf-8", errors="replace")
except Exception as exc:
return _mcp(f"Failed to read {remote}: {exc}", error=True)
lines = content.splitlines(keepends=True)
selected = list(itertools.islice(lines, offset, offset + limit))
numbered = "".join(
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
)
return _mcp(numbered)
async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
file_path: str = args.get("file_path", "")
content: str = args.get("content", "")
if not file_path:
return _mcp("file_path is required", error=True)
result = _get_sandbox_and_path(file_path)
if isinstance(result, dict):
return result
sandbox, remote = result
try:
parent = os.path.dirname(remote)
if parent and parent != E2B_WORKDIR:
await sandbox.files.make_dir(parent)
await sandbox.files.write(remote, content)
except Exception as exc:
return _mcp(f"Failed to write {remote}: {exc}", error=True)
return _mcp(f"Successfully wrote to {remote}")
async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
file_path: str = args.get("file_path", "")
old_string: str = args.get("old_string", "")
new_string: str = args.get("new_string", "")
replace_all: bool = args.get("replace_all", False)
if not file_path:
return _mcp("file_path is required", error=True)
if not old_string:
return _mcp("old_string is required", error=True)
result = _get_sandbox_and_path(file_path)
if isinstance(result, dict):
return result
sandbox, remote = result
try:
raw: bytes = await sandbox.files.read(remote, format="bytes")
content = raw.decode("utf-8", errors="replace")
except Exception as exc:
return _mcp(f"Failed to read {remote}: {exc}", error=True)
count = content.count(old_string)
if count == 0:
return _mcp(f"old_string not found in {file_path}", error=True)
if count > 1 and not replace_all:
return _mcp(
f"old_string appears {count} times in {file_path}. "
"Use replace_all=true or provide a more unique string.",
error=True,
)
updated = (
content.replace(old_string, new_string)
if replace_all
else content.replace(old_string, new_string, 1)
)
try:
await sandbox.files.write(remote, updated)
except Exception as exc:
return _mcp(f"Failed to write {remote}: {exc}", error=True)
return _mcp(f"Edited {remote} ({count} replacement{'s' if count > 1 else ''})")
async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
pattern: str = args.get("pattern", "")
path: str = args.get("path", "")
if not pattern:
return _mcp("pattern is required", error=True)
sandbox = _get_sandbox()
if sandbox is None:
return _mcp("No E2B sandbox available", error=True)
try:
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
except ValueError as exc:
return _mcp(str(exc), error=True)
cmd = f"find {shlex.quote(search_dir)} -name {shlex.quote(pattern)} -type f 2>/dev/null | head -500"
try:
result = await sandbox.commands.run(cmd, cwd=E2B_WORKDIR, timeout=10)
except Exception as exc:
return _mcp(f"Glob failed: {exc}", error=True)
files = [line for line in (result.stdout or "").strip().splitlines() if line]
return _mcp(json.dumps(files, indent=2))
async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
pattern: str = args.get("pattern", "")
path: str = args.get("path", "")
include: str = args.get("include", "")
if not pattern:
return _mcp("pattern is required", error=True)
sandbox = _get_sandbox()
if sandbox is None:
return _mcp("No E2B sandbox available", error=True)
try:
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
except ValueError as exc:
return _mcp(str(exc), error=True)
parts = ["grep", "-rn", "--color=never"]
if include:
parts.extend(["--include", include])
parts.extend([pattern, search_dir])
cmd = " ".join(shlex.quote(p) for p in parts) + " 2>/dev/null | head -200"
try:
result = await sandbox.commands.run(cmd, cwd=E2B_WORKDIR, timeout=15)
except Exception as exc:
return _mcp(f"Grep failed: {exc}", error=True)
output = (result.stdout or "").strip()
return _mcp(output if output else "No matches found.")
# Local read (for SDK-internal paths)
def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
"""Read from the host filesystem (defence-in-depth path check)."""
if not _is_allowed_local(file_path):
return _mcp(f"Path not allowed: {file_path}", error=True)
expanded = os.path.realpath(os.path.expanduser(file_path))
try:
with open(expanded) as fh:
selected = list(itertools.islice(fh, offset, offset + limit))
numbered = "".join(
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
)
return _mcp(numbered)
except FileNotFoundError:
return _mcp(f"File not found: {file_path}", error=True)
except Exception as exc:
return _mcp(f"Error reading {file_path}: {exc}", error=True)
# Tool descriptors (name, description, schema, handler)
E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
(
"read_file",
"Read a file from the cloud sandbox (/home/user). "
"Use offset and limit for large files.",
{
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Path (relative to /home/user, or absolute).",
},
"offset": {
"type": "integer",
"description": "Line to start reading from (0-indexed). Default: 0.",
},
"limit": {
"type": "integer",
"description": "Number of lines to read. Default: 2000.",
},
},
"required": ["file_path"],
},
_handle_read_file,
),
(
"write_file",
"Write or create a file in the cloud sandbox (/home/user). "
"Parent directories are created automatically. "
"To copy a workspace file into the sandbox, use "
"read_workspace_file with save_to_path instead.",
{
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Path (relative to /home/user, or absolute).",
},
"content": {"type": "string", "description": "Content to write."},
},
"required": ["file_path", "content"],
},
_handle_write_file,
),
(
"edit_file",
"Targeted text replacement in a sandbox file. "
"old_string must appear in the file and is replaced with new_string.",
{
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Path (relative to /home/user, or absolute).",
},
"old_string": {"type": "string", "description": "Text to find."},
"new_string": {"type": "string", "description": "Replacement text."},
"replace_all": {
"type": "boolean",
"description": "Replace all occurrences (default: false).",
},
},
"required": ["file_path", "old_string", "new_string"],
},
_handle_edit_file,
),
(
"glob",
"Search for files by name pattern in the cloud sandbox.",
{
"type": "object",
"properties": {
"pattern": {
"type": "string",
"description": "Glob pattern (e.g. *.py).",
},
"path": {
"type": "string",
"description": "Directory to search. Default: /home/user.",
},
},
"required": ["pattern"],
},
_handle_glob,
),
(
"grep",
"Search file contents by regex in the cloud sandbox.",
{
"type": "object",
"properties": {
"pattern": {"type": "string", "description": "Regex pattern."},
"path": {
"type": "string",
"description": "File or directory. Default: /home/user.",
},
"include": {
"type": "string",
"description": "Glob to filter files (e.g. *.py).",
},
},
"required": ["pattern"],
},
_handle_grep,
),
]
E2B_FILE_TOOL_NAMES: list[str] = [name for name, *_ in E2B_FILE_TOOLS]

View File

@@ -0,0 +1,153 @@
"""Tests for E2B file-tool path validation and local read safety.
Pure unit tests with no external dependencies (no E2B, no sandbox).
"""
import os
import pytest
from .e2b_file_tools import _read_local, _resolve_remote
from .tool_adapter import _current_project_dir
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# ---------------------------------------------------------------------------
# _resolve_remote — sandbox path normalisation & boundary enforcement
# ---------------------------------------------------------------------------
class TestResolveRemote:
def test_relative_path_resolved(self):
assert _resolve_remote("src/main.py") == "/home/user/src/main.py"
def test_absolute_within_sandbox(self):
assert _resolve_remote("/home/user/file.txt") == "/home/user/file.txt"
def test_workdir_itself(self):
assert _resolve_remote("/home/user") == "/home/user"
def test_relative_dotslash(self):
assert _resolve_remote("./README.md") == "/home/user/README.md"
def test_traversal_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("../../etc/passwd")
def test_absolute_traversal_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/home/user/../../etc/passwd")
def test_absolute_outside_sandbox_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/etc/passwd")
def test_root_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/")
def test_home_other_user_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/home/other/file.txt")
def test_deep_nested_allowed(self):
assert _resolve_remote("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
def test_trailing_slash_normalised(self):
assert _resolve_remote("src/") == "/home/user/src"
def test_double_dots_within_sandbox_ok(self):
"""Path that resolves back within /home/user is allowed."""
assert _resolve_remote("a/b/../c.txt") == "/home/user/a/c.txt"
# ---------------------------------------------------------------------------
# _read_local — host filesystem reads with allowlist enforcement
#
# In E2B mode, _read_local only allows tool-results paths (via
# is_allowed_local_path without sdk_cwd). Regular files live on the
# sandbox, not the host.
# ---------------------------------------------------------------------------
class TestReadLocal:
def _make_tool_results_file(self, encoded: str, filename: str, content: str) -> str:
"""Create a tool-results file and return its path."""
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
os.makedirs(tool_results_dir, exist_ok=True)
filepath = os.path.join(tool_results_dir, filename)
with open(filepath, "w") as f:
f.write(content)
return filepath
def test_read_tool_results_file(self):
"""Reading a tool-results file should succeed."""
encoded = "-tmp-copilot-e2b-test-read"
filepath = self._make_tool_results_file(
encoded, "result.txt", "line 1\nline 2\nline 3\n"
)
token = _current_project_dir.set(encoded)
try:
result = _read_local(filepath, offset=0, limit=2000)
assert result["isError"] is False
assert "line 1" in result["content"][0]["text"]
assert "line 2" in result["content"][0]["text"]
finally:
_current_project_dir.reset(token)
os.unlink(filepath)
def test_read_disallowed_path_blocked(self):
"""Reading /etc/passwd should be blocked by the allowlist."""
result = _read_local("/etc/passwd", offset=0, limit=10)
assert result["isError"] is True
assert "not allowed" in result["content"][0]["text"].lower()
def test_read_nonexistent_tool_results(self):
"""A tool-results path that doesn't exist returns FileNotFoundError."""
encoded = "-tmp-copilot-e2b-test-nofile"
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
os.makedirs(tool_results_dir, exist_ok=True)
filepath = os.path.join(tool_results_dir, "nonexistent.txt")
token = _current_project_dir.set(encoded)
try:
result = _read_local(filepath, offset=0, limit=10)
assert result["isError"] is True
assert "not found" in result["content"][0]["text"].lower()
finally:
_current_project_dir.reset(token)
os.rmdir(tool_results_dir)
def test_read_traversal_path_blocked(self):
"""A traversal attempt that escapes allowed directories is blocked."""
result = _read_local("/tmp/copilot-abc/../../etc/shadow", offset=0, limit=10)
assert result["isError"] is True
assert "not allowed" in result["content"][0]["text"].lower()
def test_read_arbitrary_host_path_blocked(self):
"""Arbitrary host paths are blocked even if they exist."""
result = _read_local("/proc/self/environ", offset=0, limit=10)
assert result["isError"] is True
def test_read_with_offset_and_limit(self):
"""Offset and limit should control which lines are returned."""
encoded = "-tmp-copilot-e2b-test-offset"
content = "".join(f"line {i}\n" for i in range(10))
filepath = self._make_tool_results_file(encoded, "lines.txt", content)
token = _current_project_dir.set(encoded)
try:
result = _read_local(filepath, offset=3, limit=2)
assert result["isError"] is False
text = result["content"][0]["text"]
assert "line 3" in text
assert "line 4" in text
assert "line 2" not in text
assert "line 5" not in text
finally:
_current_project_dir.reset(token)
os.unlink(filepath)
def test_read_without_project_dir_blocks_all(self):
"""Without _current_project_dir set, all paths are blocked."""
result = _read_local("/tmp/anything.txt", offset=0, limit=10)
assert result["isError"] is True

View File

@@ -0,0 +1,172 @@
"""Tests for OTEL tracing setup in the SDK copilot path."""
import os
from unittest.mock import MagicMock, patch
class TestSetupLangfuseOtel:
"""Tests for _setup_langfuse_otel()."""
def test_noop_when_langfuse_not_configured(self):
"""No env vars should be set when Langfuse credentials are missing."""
with patch(
"backend.copilot.sdk.service._is_langfuse_configured", return_value=False
):
from backend.copilot.sdk.service import _setup_langfuse_otel
# Clear any previously set env vars
env_keys = [
"LANGSMITH_OTEL_ENABLED",
"LANGSMITH_OTEL_ONLY",
"LANGSMITH_TRACING",
"OTEL_EXPORTER_OTLP_ENDPOINT",
"OTEL_EXPORTER_OTLP_HEADERS",
]
saved = {k: os.environ.pop(k, None) for k in env_keys}
try:
_setup_langfuse_otel()
for key in env_keys:
assert key not in os.environ, f"{key} should not be set"
finally:
for k, v in saved.items():
if v is not None:
os.environ[k] = v
def test_sets_env_vars_when_langfuse_configured(self):
"""OTEL env vars should be set when Langfuse credentials exist."""
mock_settings = MagicMock()
mock_settings.secrets.langfuse_public_key = "pk-test-123"
mock_settings.secrets.langfuse_secret_key = "sk-test-456"
mock_settings.secrets.langfuse_host = "https://langfuse.example.com"
mock_settings.secrets.langfuse_tracing_environment = "test"
with (
patch(
"backend.copilot.sdk.service._is_langfuse_configured",
return_value=True,
),
patch("backend.copilot.sdk.service.Settings", return_value=mock_settings),
patch(
"backend.copilot.sdk.service.configure_claude_agent_sdk",
return_value=True,
) as mock_configure,
):
from backend.copilot.sdk.service import _setup_langfuse_otel
# Clear env vars so setdefault works
env_keys = [
"LANGSMITH_OTEL_ENABLED",
"LANGSMITH_OTEL_ONLY",
"LANGSMITH_TRACING",
"OTEL_EXPORTER_OTLP_ENDPOINT",
"OTEL_EXPORTER_OTLP_HEADERS",
"OTEL_RESOURCE_ATTRIBUTES",
]
saved = {k: os.environ.pop(k, None) for k in env_keys}
try:
_setup_langfuse_otel()
assert os.environ["LANGSMITH_OTEL_ENABLED"] == "true"
assert os.environ["LANGSMITH_OTEL_ONLY"] == "true"
assert os.environ["LANGSMITH_TRACING"] == "true"
assert (
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"]
== "https://langfuse.example.com/api/public/otel"
)
assert "Authorization=Basic" in os.environ["OTEL_EXPORTER_OTLP_HEADERS"]
assert (
os.environ["OTEL_RESOURCE_ATTRIBUTES"]
== "langfuse.environment=test"
)
mock_configure.assert_called_once_with(tags=["sdk"])
finally:
for k, v in saved.items():
if v is not None:
os.environ[k] = v
elif k in os.environ:
del os.environ[k]
def test_existing_env_vars_not_overwritten(self):
"""Explicit env-var overrides should not be clobbered."""
mock_settings = MagicMock()
mock_settings.secrets.langfuse_public_key = "pk-test"
mock_settings.secrets.langfuse_secret_key = "sk-test"
mock_settings.secrets.langfuse_host = "https://langfuse.example.com"
with (
patch(
"backend.copilot.sdk.service._is_langfuse_configured",
return_value=True,
),
patch("backend.copilot.sdk.service.Settings", return_value=mock_settings),
patch(
"backend.copilot.sdk.service.configure_claude_agent_sdk",
return_value=True,
),
):
from backend.copilot.sdk.service import _setup_langfuse_otel
saved = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
try:
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = "https://custom.endpoint/v1"
_setup_langfuse_otel()
assert (
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"]
== "https://custom.endpoint/v1"
)
finally:
if saved is not None:
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = saved
elif "OTEL_EXPORTER_OTLP_ENDPOINT" in os.environ:
del os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"]
def test_graceful_failure_on_exception(self):
"""Setup should not raise even if internal code fails."""
with (
patch(
"backend.copilot.sdk.service._is_langfuse_configured",
return_value=True,
),
patch(
"backend.copilot.sdk.service.Settings",
side_effect=RuntimeError("settings unavailable"),
),
):
from backend.copilot.sdk.service import _setup_langfuse_otel
# Should not raise — just logs and returns
_setup_langfuse_otel()
class TestPropagateAttributesImport:
"""Verify langfuse.propagate_attributes is available."""
def test_propagate_attributes_is_importable(self):
from langfuse import propagate_attributes
assert callable(propagate_attributes)
def test_propagate_attributes_returns_context_manager(self):
from langfuse import propagate_attributes
ctx = propagate_attributes(user_id="u1", session_id="s1", tags=["test"])
assert hasattr(ctx, "__enter__")
assert hasattr(ctx, "__exit__")
class TestReceiveResponseCompat:
"""Verify ClaudeSDKClient.receive_response() exists (langsmith patches it)."""
def test_receive_response_exists(self):
from claude_agent_sdk import ClaudeSDKClient
assert hasattr(ClaudeSDKClient, "receive_response")
def test_receive_response_is_async_generator(self):
import inspect
from claude_agent_sdk import ClaudeSDKClient
method = getattr(ClaudeSDKClient, "receive_response")
assert inspect.isfunction(method) or inspect.ismethod(method)

View File

@@ -118,7 +118,7 @@ async def test_build_query_resume_up_to_date():
ChatMessage(role="user", content="what's new?"),
]
)
result = await _build_query_message(
result, was_compacted = await _build_query_message(
"what's new?",
session,
use_resume=True,
@@ -127,6 +127,7 @@ async def test_build_query_resume_up_to_date():
)
# transcript_msg_count == msg_count - 1, so no gap
assert result == "what's new?"
assert was_compacted is False
@pytest.mark.asyncio
@@ -141,7 +142,7 @@ async def test_build_query_resume_stale_transcript():
ChatMessage(role="user", content="turn 3"),
]
)
result = await _build_query_message(
result, was_compacted = await _build_query_message(
"turn 3",
session,
use_resume=True,
@@ -152,6 +153,7 @@ async def test_build_query_resume_stale_transcript():
assert "turn 2" in result
assert "reply 2" in result
assert "Now, the user says:\nturn 3" in result
assert was_compacted is False # gap context does not compact
@pytest.mark.asyncio
@@ -164,7 +166,7 @@ async def test_build_query_resume_zero_msg_count():
ChatMessage(role="user", content="new msg"),
]
)
result = await _build_query_message(
result, was_compacted = await _build_query_message(
"new msg",
session,
use_resume=True,
@@ -172,13 +174,14 @@ async def test_build_query_resume_zero_msg_count():
session_id="test-session",
)
assert result == "new msg"
assert was_compacted is False
@pytest.mark.asyncio
async def test_build_query_no_resume_single_message():
"""Without --resume and only 1 message, return raw message."""
session = _make_session([ChatMessage(role="user", content="first")])
result = await _build_query_message(
result, was_compacted = await _build_query_message(
"first",
session,
use_resume=False,
@@ -186,6 +189,7 @@ async def test_build_query_no_resume_single_message():
session_id="test-session",
)
assert result == "first"
assert was_compacted is False
@pytest.mark.asyncio
@@ -199,16 +203,16 @@ async def test_build_query_no_resume_multi_message(monkeypatch):
]
)
# Mock _compress_conversation_history to return the messages as-is
async def _mock_compress(sess):
return sess.messages[:-1]
# Mock _compress_messages to return the messages as-is
async def _mock_compress(msgs):
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_conversation_history",
"backend.copilot.sdk.service._compress_messages",
_mock_compress,
)
result = await _build_query_message(
result, was_compacted = await _build_query_message(
"new question",
session,
use_resume=False,
@@ -219,3 +223,33 @@ async def test_build_query_no_resume_multi_message(monkeypatch):
assert "older question" in result
assert "older answer" in result
assert "Now, the user says:\nnew question" in result
assert was_compacted is False # mock returns False
@pytest.mark.asyncio
async def test_build_query_no_resume_multi_message_compacted(monkeypatch):
"""When compression actually compacts, was_compacted should be True."""
session = _make_session(
[
ChatMessage(role="user", content="old"),
ChatMessage(role="assistant", content="reply"),
ChatMessage(role="user", content="new"),
]
)
async def _mock_compress(msgs):
return msgs, True # Simulate actual compaction
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages",
_mock_compress,
)
result, was_compacted = await _build_query_message(
"new",
session,
use_resume=False,
transcript_msg_count=0,
session_id="test-session",
)
assert was_compacted is True

View File

@@ -6,7 +6,6 @@ ensuring multi-user isolation and preventing unauthorized operations.
import json
import logging
import os
import re
from collections.abc import Callable
from typing import Any, cast
@@ -16,6 +15,7 @@ from .tool_adapter import (
DANGEROUS_PATTERNS,
MCP_TOOL_PREFIX,
WORKSPACE_SCOPED_TOOLS,
is_allowed_local_path,
stash_pending_tool_output,
)
@@ -38,40 +38,20 @@ def _validate_workspace_path(
) -> dict[str, Any]:
"""Validate that a workspace-scoped tool only accesses allowed paths.
Allowed directories:
Delegates to :func:`is_allowed_local_path` which permits:
- The SDK working directory (``/tmp/copilot-<session>/``)
- The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
- The current session's tool-results directory
(``~/.claude/projects/<encoded-cwd>/tool-results/``)
"""
path = tool_input.get("file_path") or tool_input.get("path") or ""
if not path:
# Glob/Grep without a path default to cwd which is already sandboxed
return {}
# Resolve relative paths against sdk_cwd (the SDK sets cwd so the LLM
# naturally uses relative paths like "test.txt" instead of absolute ones).
# Tilde paths (~/) are home-dir references, not relative — expand first.
if path.startswith("~"):
resolved = os.path.realpath(os.path.expanduser(path))
elif not os.path.isabs(path) and sdk_cwd:
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
else:
resolved = os.path.realpath(path)
# Allow access within the SDK working directory
if sdk_cwd:
norm_cwd = os.path.realpath(sdk_cwd)
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
return {}
# Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
claude_dir = os.path.realpath(os.path.expanduser("~/.claude/projects"))
tool_results_seg = os.sep + "tool-results" + os.sep
if resolved.startswith(claude_dir + os.sep) and tool_results_seg in resolved:
if is_allowed_local_path(path, sdk_cwd):
return {}
logger.warning(
f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
)
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
return _deny(
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
@@ -146,6 +126,7 @@ def create_security_hooks(
user_id: str | None,
sdk_cwd: str | None = None,
max_subtasks: int = 3,
on_compact: Callable[[], None] | None = None,
on_stop: Callable[[str, str], None] | None = None,
) -> dict[str, Any]:
"""Create the security hooks configuration for Claude Agent SDK.
@@ -160,7 +141,7 @@ def create_security_hooks(
Args:
user_id: Current user ID for isolation validation
sdk_cwd: SDK working directory for workspace-scoped tool validation
max_subtasks: Maximum Task (sub-agent) spawns allowed per session
max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when
the SDK finishes processing — used to read the JSONL transcript
before the CLI process exits.
@@ -172,8 +153,9 @@ def create_security_hooks(
from claude_agent_sdk import HookMatcher
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
# Per-session counter for Task sub-agent spawns
task_spawn_count = 0
# Per-session tracking for Task sub-agent concurrency.
# Set of tool_use_ids that consumed a slot — len() is the active count.
task_tool_use_ids: set[str] = set()
async def pre_tool_use_hook(
input_data: HookInput,
@@ -181,7 +163,6 @@ def create_security_hooks(
context: HookContext,
) -> SyncHookJSONOutput:
"""Combined pre-tool-use validation hook."""
nonlocal task_spawn_count
_ = context # unused but required by signature
tool_name = cast(str, input_data.get("tool_name", ""))
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
@@ -200,18 +181,18 @@ def create_security_hooks(
"(remove the run_in_background parameter)."
),
)
if task_spawn_count >= max_subtasks:
if len(task_tool_use_ids) >= max_subtasks:
logger.warning(
f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
)
return cast(
SyncHookJSONOutput,
_deny(
f"Maximum {max_subtasks} sub-tasks per session. "
"Please continue in the main conversation."
f"Maximum {max_subtasks} concurrent sub-tasks. "
"Wait for running sub-tasks to finish, "
"or continue in the main conversation."
),
)
task_spawn_count += 1
# Strip MCP prefix for consistent validation
is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX)
@@ -229,9 +210,24 @@ def create_security_hooks(
if result:
return cast(SyncHookJSONOutput, result)
# Reserve the Task slot only after all validations pass
if tool_name == "Task" and tool_use_id is not None:
task_tool_use_ids.add(tool_use_id)
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
return cast(SyncHookJSONOutput, {})
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
"""Release a Task concurrency slot if one was reserved."""
if tool_name == "Task" and tool_use_id in task_tool_use_ids:
task_tool_use_ids.discard(tool_use_id)
logger.info(
"[SDK] Task slot released, active=%d/%d, user=%s",
len(task_tool_use_ids),
max_subtasks,
user_id,
)
async def post_tool_use_hook(
input_data: HookInput,
tool_use_id: str | None,
@@ -246,6 +242,8 @@ def create_security_hooks(
"""
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
_release_task_slot(tool_name, tool_use_id)
is_builtin = not tool_name.startswith(MCP_TOOL_PREFIX)
logger.info(
"[SDK] PostToolUse: %s (builtin=%s, tool_use_id=%s)",
@@ -289,6 +287,9 @@ def create_security_hooks(
f"[SDK] Tool failed: {tool_name}, error={error}, "
f"user={user_id}, tool_use_id={tool_use_id}"
)
_release_task_slot(tool_name, tool_use_id)
return cast(SyncHookJSONOutput, {})
async def pre_compact_hook(
@@ -306,6 +307,8 @@ def create_security_hooks(
logger.info(
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
)
if on_compact is not None:
on_compact()
return cast(SyncHookJSONOutput, {})
# --- Stop hook: capture transcript path for stateless resume ---

View File

@@ -120,17 +120,31 @@ def test_read_no_cwd_denies_absolute():
def test_read_tool_results_allowed():
from .tool_adapter import _current_project_dir
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert result == {}
# is_allowed_local_path requires the session's encoded cwd to be set
token = _current_project_dir.set("-tmp-copilot-abc123")
try:
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert result == {}
finally:
_current_project_dir.reset(token)
def test_read_claude_projects_without_tool_results_denied():
def test_read_claude_projects_session_dir_allowed():
"""Files within the current session's project dir are allowed."""
from .tool_adapter import _current_project_dir
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert _is_denied(result)
token = _current_project_dir.set("-tmp-copilot-abc123")
try:
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert not _is_denied(result)
finally:
_current_project_dir.reset(token)
# -- Built-in Bash is blocked (use bash_exec MCP tool instead) ---------------
@@ -208,19 +222,22 @@ def test_bash_builtin_blocked_message_clarity():
@pytest.fixture()
def _hooks():
"""Create security hooks and return the PreToolUse handler."""
"""Create security hooks and return (pre, post, post_failure) handlers."""
from .security_hooks import create_security_hooks
hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
pre = hooks["PreToolUse"][0].hooks[0]
return pre
post = hooks["PostToolUse"][0].hooks[0]
post_failure = hooks["PostToolUseFailure"][0].hooks[0]
return pre, post, post_failure
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_task_background_blocked(_hooks):
"""Task with run_in_background=true must be denied."""
result = await _hooks(
pre, _, _ = _hooks
result = await pre(
{"tool_name": "Task", "tool_input": {"run_in_background": True, "prompt": "x"}},
tool_use_id=None,
context={},
@@ -233,9 +250,10 @@ async def test_task_background_blocked(_hooks):
@pytest.mark.asyncio
async def test_task_foreground_allowed(_hooks):
"""Task without run_in_background should be allowed."""
result = await _hooks(
pre, _, _ = _hooks
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "do stuff"}},
tool_use_id=None,
tool_use_id="tu-1",
context={},
)
assert not _is_denied(result)
@@ -245,25 +263,102 @@ async def test_task_foreground_allowed(_hooks):
@pytest.mark.asyncio
async def test_task_limit_enforced(_hooks):
"""Task spawns beyond max_subtasks should be denied."""
pre, _, _ = _hooks
# First two should pass
for _ in range(2):
result = await _hooks(
for i in range(2):
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "ok"}},
tool_use_id=None,
tool_use_id=f"tu-limit-{i}",
context={},
)
assert not _is_denied(result)
# Third should be denied (limit=2)
result = await _hooks(
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "over limit"}},
tool_use_id=None,
tool_use_id="tu-limit-2",
context={},
)
assert _is_denied(result)
assert "Maximum" in _reason(result)
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_task_slot_released_on_completion(_hooks):
"""Completing a Task should free a slot so new Tasks can be spawned."""
pre, post, _ = _hooks
# Fill both slots
for i in range(2):
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "ok"}},
tool_use_id=f"tu-comp-{i}",
context={},
)
assert not _is_denied(result)
# Third should be denied — at capacity
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "over"}},
tool_use_id="tu-comp-2",
context={},
)
assert _is_denied(result)
# Complete first task — frees a slot
await post(
{"tool_name": "Task", "tool_input": {}},
tool_use_id="tu-comp-0",
context={},
)
# Now a new Task should be allowed
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "after release"}},
tool_use_id="tu-comp-3",
context={},
)
assert not _is_denied(result)
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_task_slot_released_on_failure(_hooks):
"""A failed Task should also free its concurrency slot."""
pre, _, post_failure = _hooks
# Fill both slots
for i in range(2):
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "ok"}},
tool_use_id=f"tu-fail-{i}",
context={},
)
assert not _is_denied(result)
# At capacity
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "over"}},
tool_use_id="tu-fail-2",
context={},
)
assert _is_denied(result)
# Fail first task — should free a slot
await post_failure(
{"tool_name": "Task", "tool_input": {}, "error": "something broke"},
tool_use_id="tu-fail-0",
context={},
)
# New Task should be allowed
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "after failure"}},
tool_use_id="tu-fail-3",
context={},
)
assert not _is_denied(result)
# -- _is_tool_error_or_denial ------------------------------------------------
@@ -298,7 +393,9 @@ class TestIsToolErrorOrDenial:
def test_subtask_limit_denial(self):
assert (
_is_tool_error_or_denial(
"Maximum 2 sub-tasks per session. Please continue in the main conversation."
"Maximum 2 concurrent sub-tasks. "
"Wait for running sub-tasks to finish, "
"or continue in the main conversation."
)
is True
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,147 @@
"""Tests for SDK service helpers."""
import base64
import os
from dataclasses import dataclass
from unittest.mock import AsyncMock, patch
import pytest
from .service import _prepare_file_attachments
@dataclass
class _FakeFileInfo:
id: str
name: str
path: str
mime_type: str
size_bytes: int
_PATCH_TARGET = "backend.copilot.sdk.service.get_manager"
class TestPrepareFileAttachments:
@pytest.mark.asyncio
async def test_empty_list_returns_empty(self, tmp_path):
result = await _prepare_file_attachments([], "u", "s", str(tmp_path))
assert result.hint == ""
assert result.image_blocks == []
@pytest.mark.asyncio
async def test_image_embedded_as_vision_block(self, tmp_path):
"""JPEG images should become vision content blocks, not files on disk."""
raw = b"\xff\xd8\xff\xe0fake-jpeg"
info = _FakeFileInfo(
id="abc",
name="photo.jpg",
path="/photo.jpg",
mime_type="image/jpeg",
size_bytes=len(raw),
)
mgr = AsyncMock()
mgr.get_file_info.return_value = info
mgr.read_file_by_id.return_value = raw
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
result = await _prepare_file_attachments(
["abc"], "user1", "sess1", str(tmp_path)
)
assert "1 file" in result.hint
assert "photo.jpg" in result.hint
assert "embedded as image" in result.hint
assert len(result.image_blocks) == 1
block = result.image_blocks[0]
assert block["type"] == "image"
assert block["source"]["media_type"] == "image/jpeg"
assert block["source"]["data"] == base64.b64encode(raw).decode("ascii")
# Image should NOT be written to disk (embedded instead)
assert not os.path.exists(os.path.join(tmp_path, "photo.jpg"))
@pytest.mark.asyncio
async def test_pdf_saved_to_disk(self, tmp_path):
"""PDFs should be saved to disk for Read tool access, not embedded."""
info = _FakeFileInfo("f1", "doc.pdf", "/doc.pdf", "application/pdf", 50)
mgr = AsyncMock()
mgr.get_file_info.return_value = info
mgr.read_file_by_id.return_value = b"%PDF-1.4 fake"
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
result = await _prepare_file_attachments(["f1"], "u", "s", str(tmp_path))
assert result.image_blocks == []
saved = tmp_path / "doc.pdf"
assert saved.exists()
assert saved.read_bytes() == b"%PDF-1.4 fake"
assert str(saved) in result.hint
@pytest.mark.asyncio
async def test_mixed_images_and_files(self, tmp_path):
"""Images become blocks, non-images go to disk."""
infos = {
"id1": _FakeFileInfo("id1", "a.png", "/a.png", "image/png", 4),
"id2": _FakeFileInfo("id2", "b.pdf", "/b.pdf", "application/pdf", 4),
"id3": _FakeFileInfo("id3", "c.txt", "/c.txt", "text/plain", 4),
}
mgr = AsyncMock()
mgr.get_file_info.side_effect = lambda fid: infos[fid]
mgr.read_file_by_id.return_value = b"data"
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
result = await _prepare_file_attachments(
["id1", "id2", "id3"], "u", "s", str(tmp_path)
)
assert "3 files" in result.hint
assert "a.png" in result.hint
assert "b.pdf" in result.hint
assert "c.txt" in result.hint
# Only the image should be a vision block
assert len(result.image_blocks) == 1
assert result.image_blocks[0]["source"]["media_type"] == "image/png"
# Non-image files should be on disk
assert (tmp_path / "b.pdf").exists()
assert (tmp_path / "c.txt").exists()
# Read tool hint should appear (has non-image files)
assert "Read tool" in result.hint
@pytest.mark.asyncio
async def test_singular_noun(self, tmp_path):
info = _FakeFileInfo("x", "only.txt", "/only.txt", "text/plain", 2)
mgr = AsyncMock()
mgr.get_file_info.return_value = info
mgr.read_file_by_id.return_value = b"hi"
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
result = await _prepare_file_attachments(["x"], "u", "s", str(tmp_path))
assert "1 file." in result.hint
@pytest.mark.asyncio
async def test_missing_file_skipped(self, tmp_path):
mgr = AsyncMock()
mgr.get_file_info.return_value = None
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
result = await _prepare_file_attachments(
["missing-id"], "u", "s", str(tmp_path)
)
assert result.hint == ""
assert result.image_blocks == []
@pytest.mark.asyncio
async def test_image_only_no_read_hint(self, tmp_path):
"""When all files are images, no Read tool hint should appear."""
info = _FakeFileInfo("i1", "cat.png", "/cat.png", "image/png", 4)
mgr = AsyncMock()
mgr.get_file_info.return_value = info
mgr.read_file_by_id.return_value = b"data"
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
result = await _prepare_file_attachments(["i1"], "u", "s", str(tmp_path))
assert "Read tool" not in result.hint
assert len(result.image_blocks) == 1

View File

@@ -9,20 +9,84 @@ import itertools
import json
import logging
import os
import re
import uuid
from contextvars import ContextVar
from typing import Any
from typing import TYPE_CHECKING, Any
from claude_agent_sdk import create_sdk_mcp_server, tool
from backend.copilot.model import ChatSession
from backend.copilot.tools import TOOL_REGISTRY
from backend.copilot.tools.base import BaseTool
from backend.util.truncate import truncate
from .e2b_file_tools import E2B_FILE_TOOL_NAMES, E2B_FILE_TOOLS
if TYPE_CHECKING:
from e2b import AsyncSandbox
logger = logging.getLogger(__name__)
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
# in the path — prevents reading settings, credentials, or other sensitive files.
_SDK_PROJECTS_DIR = os.path.expanduser("~/.claude/projects/")
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# Max MCP response size in chars — keeps tool output under the SDK's 10 MB JSON buffer.
_MCP_MAX_CHARS = 500_000
# Context variable holding the encoded project directory name for the current
# session (e.g. "-private-tmp-copilot-<uuid>"). Set by set_execution_context()
# so that path validation can scope tool-results reads to the current session.
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does.
The CLI replaces all non-alphanumeric characters with ``-``.
"""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
"""Check whether *path* is an allowed host-filesystem path.
Allowed:
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
- Files under ``~/.claude/projects/<encoded-cwd>/`` — the SDK's
project directory for this session (tool-results, transcripts, etc.)
Both checks are scoped to the **current session** so sessions cannot
read each other's data.
"""
if not path:
return False
if path.startswith("~"):
resolved = os.path.realpath(os.path.expanduser(path))
elif not os.path.isabs(path) and sdk_cwd:
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
else:
resolved = os.path.realpath(path)
# Allow access within the SDK working directory
if sdk_cwd:
norm_cwd = os.path.realpath(sdk_cwd)
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
return True
# Allow access within the current session's CLI project directory
# (~/.claude/projects/<encoded-cwd>/).
encoded = _current_project_dir.get("")
if encoded:
session_project = os.path.join(_SDK_PROJECTS_DIR, encoded)
if resolved == session_project or resolved.startswith(session_project + os.sep):
return True
return False
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
MCP_SERVER_NAME = "copilot"
@@ -33,6 +97,15 @@ _current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
# E2B cloud sandbox for the current turn (None when E2B is not configured).
# Passed to bash_exec so commands run on E2B instead of the local bwrap sandbox.
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
"_current_sandbox", default=None
)
# Raw SDK working directory path (e.g. /tmp/copilot-<session_id>).
# Used by workspace tools to save binary files for the CLI's built-in Read.
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
@@ -53,22 +126,39 @@ _stash_event: ContextVar[asyncio.Event | None] = ContextVar(
def set_execution_context(
user_id: str | None,
session: ChatSession,
sandbox: "AsyncSandbox | None" = None,
sdk_cwd: str | None = None,
) -> None:
"""Set the execution context for tool calls.
This must be called before streaming begins to ensure tools have access
to user_id and session information.
to user_id, session, and (optionally) an E2B sandbox for bash execution.
Args:
user_id: Current user's ID.
session: Current chat session.
sandbox: Optional E2B sandbox; when set, bash_exec routes commands there.
sdk_cwd: SDK working directory; used to scope tool-results reads.
"""
_current_user_id.set(user_id)
_current_session.set(session)
_current_sandbox.set(sandbox)
_current_sdk_cwd.set(sdk_cwd or "")
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
_pending_tool_outputs.set({})
_stash_event.set(asyncio.Event())
def get_current_sandbox() -> "AsyncSandbox | None":
"""Return the E2B sandbox for the current turn, or None."""
return _current_sandbox.get()
def get_sdk_cwd() -> str:
"""Return the SDK ephemeral working directory for the current turn."""
return _current_sdk_cwd.get()
def get_execution_context() -> tuple[str | None, ChatSession | None]:
"""Get the current execution context."""
return (
@@ -182,66 +272,12 @@ async def _execute_tool_sync(
result.output if isinstance(result.output, str) else json.dumps(result.output)
)
# Stash the full output before the SDK potentially truncates it.
pending = _pending_tool_outputs.get(None)
if pending is not None:
pending.setdefault(base_tool.name, []).append(text)
content_blocks: list[dict[str, str]] = [{"type": "text", "text": text}]
# If the tool result contains inline image data, add an MCP image block
# so Claude can "see" the image (e.g. read_workspace_file on a small PNG).
image_block = _extract_image_block(text)
if image_block:
content_blocks.append(image_block)
return {
"content": content_blocks,
"content": [{"type": "text", "text": text}],
"isError": not result.success,
}
# MIME types that Claude can process as image content blocks.
_SUPPORTED_IMAGE_TYPES = frozenset(
{"image/png", "image/jpeg", "image/gif", "image/webp"}
)
def _extract_image_block(text: str) -> dict[str, str] | None:
"""Extract an MCP image content block from a tool result JSON string.
Detects workspace file responses with ``content_base64`` and an image
MIME type, returning an MCP-format image block that allows Claude to
"see" the image. Returns ``None`` if the result is not an inline image.
"""
try:
data = json.loads(text)
except (json.JSONDecodeError, TypeError):
return None
if not isinstance(data, dict):
return None
mime_type = data.get("mime_type", "")
base64_content = data.get("content_base64", "")
# Only inline small images — large ones would exceed Claude's limits.
# 32 KB raw ≈ ~43 KB base64.
_MAX_IMAGE_BASE64_BYTES = 43_000
if (
mime_type in _SUPPORTED_IMAGE_TYPES
and base64_content
and len(base64_content) <= _MAX_IMAGE_BASE64_BYTES
):
return {
"type": "image",
"data": base64_content,
"mimeType": mime_type,
}
return None
def _mcp_error(message: str) -> dict[str, Any]:
return {
"content": [
@@ -284,29 +320,32 @@ def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Read a file with optional offset/limit. Restricted to SDK working directory.
"""Read a local file with optional offset/limit.
After reading, the file is deleted to prevent accumulation in long-running pods.
Only allows paths that pass :func:`is_allowed_local_path` — the current
session's tool-results directory and ephemeral working directory.
"""
file_path = args.get("file_path", "")
offset = args.get("offset", 0)
limit = args.get("limit", 2000)
# Security: only allow reads under ~/.claude/projects/**/tool-results/
real_path = os.path.realpath(file_path)
if not real_path.startswith(_SDK_PROJECTS_DIR) or "tool-results" not in real_path:
if not is_allowed_local_path(file_path):
return {
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
"isError": True,
}
resolved = os.path.realpath(os.path.expanduser(file_path))
try:
with open(real_path) as f:
with open(resolved) as f:
selected = list(itertools.islice(f, offset, offset + limit))
content = "".join(selected)
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
# don't delete here — the SDK may read in multiple chunks.
return {"content": [{"type": "text", "text": content}], "isError": False}
return {
"content": [{"type": "text", "text": content}],
"isError": False,
}
except FileNotFoundError:
return {
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
@@ -344,50 +383,86 @@ _READ_TOOL_SCHEMA = {
}
# Create the MCP server configuration
def create_copilot_mcp_server():
# ---------------------------------------------------------------------------
# MCP result helpers
# ---------------------------------------------------------------------------
def _text_from_mcp_result(result: dict[str, Any]) -> str:
"""Extract concatenated text from an MCP response's content blocks."""
content = result.get("content", [])
if not isinstance(content, list):
return ""
return "".join(
b.get("text", "")
for b in content
if isinstance(b, dict) and b.get("type") == "text"
)
def create_copilot_mcp_server(*, use_e2b: bool = False):
"""Create an in-process MCP server configuration for CoPilot tools.
This can be passed to ClaudeAgentOptions.mcp_servers.
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
package being available. This function returns the configuration that
can be used with the SDK.
When *use_e2b* is True, five additional MCP file tools are registered
that route directly to the E2B sandbox filesystem, and the caller should
disable the corresponding SDK built-in tools via
:func:`get_sdk_disallowed_tools`.
"""
try:
from claude_agent_sdk import create_sdk_mcp_server, tool
# Create decorated tool functions
sdk_tools = []
def _truncating(fn, tool_name: str):
"""Wrap a tool handler so its response is truncated to stay under the
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
response adapter before the SDK can apply its own head-truncation.
for tool_name, base_tool in TOOL_REGISTRY.items():
handler = create_tool_handler(base_tool)
decorated = tool(
tool_name,
base_tool.description,
_build_input_schema(base_tool),
)(handler)
Applied once to every registered tool."""
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
result = await fn(args)
truncated = truncate(result, _MCP_MAX_CHARS)
# Stash the text so the response adapter can forward our
# middle-out truncated version to the frontend instead of the
# SDK's head-truncated version (for outputs >~100 KB the SDK
# persists to tool-results/ with a 2 KB head-only preview).
if not truncated.get("isError"):
text = _text_from_mcp_result(truncated)
if text:
stash_pending_tool_output(tool_name, text)
return truncated
return wrapper
sdk_tools = []
for tool_name, base_tool in TOOL_REGISTRY.items():
handler = create_tool_handler(base_tool)
decorated = tool(
tool_name,
base_tool.description,
_build_input_schema(base_tool),
)(_truncating(handler, tool_name))
sdk_tools.append(decorated)
# E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.
if use_e2b:
for name, desc, schema, handler in E2B_FILE_TOOLS:
decorated = tool(name, desc, schema)(_truncating(handler, name))
sdk_tools.append(decorated)
# Add the Read tool so the SDK can read back oversized tool results
read_tool = tool(
_READ_TOOL_NAME,
_READ_TOOL_DESCRIPTION,
_READ_TOOL_SCHEMA,
)(_read_file_handler)
sdk_tools.append(read_tool)
# Read tool for SDK-truncated tool results (always needed).
read_tool = tool(
_READ_TOOL_NAME,
_READ_TOOL_DESCRIPTION,
_READ_TOOL_SCHEMA,
)(_truncating(_read_file_handler, _READ_TOOL_NAME))
sdk_tools.append(read_tool)
server = create_sdk_mcp_server(
name=MCP_SERVER_NAME,
version="1.0.0",
tools=sdk_tools,
)
return server
except ImportError:
# Let ImportError propagate so service.py handles the fallback
raise
return create_sdk_mcp_server(
name=MCP_SERVER_NAME,
version="1.0.0",
tools=sdk_tools,
)
# SDK built-in tools allowed within the workspace directory.
@@ -397,16 +472,11 @@ def create_copilot_mcp_server():
# Task allows spawning sub-agents (rate-limited by security hooks).
# WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk.
# TodoWrite manages the task checklist shown in the UI — no security concern.
_SDK_BUILTIN_TOOLS = [
"Read",
"Write",
"Edit",
"Glob",
"Grep",
"Task",
"WebSearch",
"TodoWrite",
]
# In E2B mode, all five are disabled — MCP equivalents provide direct sandbox
# access. read_file also handles local tool-results and ephemeral reads.
_SDK_BUILTIN_FILE_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep"]
_SDK_BUILTIN_ALWAYS = ["Task", "WebSearch", "TodoWrite"]
_SDK_BUILTIN_TOOLS = [*_SDK_BUILTIN_FILE_TOOLS, *_SDK_BUILTIN_ALWAYS]
# SDK built-in tools that must be explicitly blocked.
# Bash: dangerous — agent uses mcp__copilot__bash_exec with kernel-level
@@ -453,11 +523,37 @@ DANGEROUS_PATTERNS = [
r"subprocess",
]
# List of tool names for allowed_tools configuration
# Include MCP tools, the MCP Read tool for oversized results,
# and SDK built-in file tools for workspace operations.
# Static tool name list for the non-E2B case (backward compatibility).
COPILOT_TOOL_NAMES = [
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
*_SDK_BUILTIN_TOOLS,
]
def get_copilot_tool_names(*, use_e2b: bool = False) -> list[str]:
"""Build the ``allowed_tools`` list for :class:`ClaudeAgentOptions`.
When *use_e2b* is True the SDK built-in file tools are replaced by MCP
equivalents that route to the E2B sandbox.
"""
if not use_e2b:
return list(COPILOT_TOOL_NAMES)
return [
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
*[f"{MCP_TOOL_PREFIX}{name}" for name in E2B_FILE_TOOL_NAMES],
*_SDK_BUILTIN_ALWAYS,
]
def get_sdk_disallowed_tools(*, use_e2b: bool = False) -> list[str]:
"""Build the ``disallowed_tools`` list for :class:`ClaudeAgentOptions`.
When *use_e2b* is True the SDK built-in file tools are also disabled
because MCP equivalents provide direct sandbox access.
"""
if not use_e2b:
return list(SDK_DISALLOWED_TOOLS)
return [*SDK_DISALLOWED_TOOLS, *_SDK_BUILTIN_FILE_TOOLS]

View File

@@ -0,0 +1,170 @@
"""Tests for tool_adapter helpers: truncation, stash, context vars."""
import pytest
from backend.util.truncate import truncate
from .tool_adapter import (
_MCP_MAX_CHARS,
_text_from_mcp_result,
get_sdk_cwd,
pop_pending_tool_output,
set_execution_context,
stash_pending_tool_output,
)
# ---------------------------------------------------------------------------
# _text_from_mcp_result
# ---------------------------------------------------------------------------
class TestTextFromMcpResult:
def test_single_text_block(self):
result = {"content": [{"type": "text", "text": "hello"}]}
assert _text_from_mcp_result(result) == "hello"
def test_multiple_text_blocks_concatenated(self):
result = {
"content": [
{"type": "text", "text": "one"},
{"type": "text", "text": "two"},
]
}
assert _text_from_mcp_result(result) == "onetwo"
def test_non_text_blocks_ignored(self):
result = {
"content": [
{"type": "image", "data": "..."},
{"type": "text", "text": "only this"},
]
}
assert _text_from_mcp_result(result) == "only this"
def test_empty_content_list(self):
assert _text_from_mcp_result({"content": []}) == ""
def test_missing_content_key(self):
assert _text_from_mcp_result({}) == ""
def test_non_list_content(self):
assert _text_from_mcp_result({"content": "raw string"}) == ""
def test_missing_text_field(self):
result = {"content": [{"type": "text"}]}
assert _text_from_mcp_result(result) == ""
# ---------------------------------------------------------------------------
# get_sdk_cwd
# ---------------------------------------------------------------------------
class TestGetSdkCwd:
def test_returns_empty_string_by_default(self):
set_execution_context(
user_id="test",
session=None, # type: ignore[arg-type]
sandbox=None,
)
assert get_sdk_cwd() == ""
def test_returns_set_value(self):
set_execution_context(
user_id="test",
session=None, # type: ignore[arg-type]
sandbox=None,
sdk_cwd="/tmp/copilot-test-123",
)
assert get_sdk_cwd() == "/tmp/copilot-test-123"
# ---------------------------------------------------------------------------
# stash / pop round-trip (the mechanism _truncating relies on)
# ---------------------------------------------------------------------------
class TestToolOutputStash:
@pytest.fixture(autouse=True)
def _init_context(self):
"""Initialise the context vars that stash_pending_tool_output needs."""
set_execution_context(
user_id="test",
session=None, # type: ignore[arg-type]
sandbox=None,
sdk_cwd="/tmp/test",
)
def test_stash_and_pop(self):
stash_pending_tool_output("my_tool", "output1")
assert pop_pending_tool_output("my_tool") == "output1"
def test_pop_empty_returns_none(self):
assert pop_pending_tool_output("nonexistent") is None
def test_fifo_order(self):
stash_pending_tool_output("t", "first")
stash_pending_tool_output("t", "second")
assert pop_pending_tool_output("t") == "first"
assert pop_pending_tool_output("t") == "second"
assert pop_pending_tool_output("t") is None
def test_dict_serialised_to_json(self):
stash_pending_tool_output("t", {"key": "value"})
assert pop_pending_tool_output("t") == '{"key": "value"}'
def test_separate_tool_names(self):
stash_pending_tool_output("a", "alpha")
stash_pending_tool_output("b", "beta")
assert pop_pending_tool_output("b") == "beta"
assert pop_pending_tool_output("a") == "alpha"
# ---------------------------------------------------------------------------
# _truncating wrapper (integration via create_copilot_mcp_server)
# ---------------------------------------------------------------------------
class TestTruncationAndStashIntegration:
"""Test truncation + stash behavior that _truncating relies on."""
@pytest.fixture(autouse=True)
def _init_context(self):
set_execution_context(
user_id="test",
session=None, # type: ignore[arg-type]
sandbox=None,
sdk_cwd="/tmp/test",
)
def test_small_output_stashed(self):
"""Non-error output is stashed for the response adapter."""
result = {
"content": [{"type": "text", "text": "small output"}],
"isError": False,
}
truncated = truncate(result, _MCP_MAX_CHARS)
text = _text_from_mcp_result(truncated)
assert text == "small output"
stash_pending_tool_output("test_tool", text)
assert pop_pending_tool_output("test_tool") == "small output"
def test_error_result_not_stashed(self):
"""Error results should not be stashed."""
result = {
"content": [{"type": "text", "text": "error msg"}],
"isError": True,
}
# _truncating only stashes when not result.get("isError")
if not result.get("isError"):
stash_pending_tool_output("err_tool", "should not happen")
assert pop_pending_tool_output("err_tool") is None
def test_large_output_truncated(self):
"""Output exceeding _MCP_MAX_CHARS is truncated before stashing."""
big_text = "x" * (_MCP_MAX_CHARS + 100_000)
result = {"content": [{"type": "text", "text": big_text}]}
truncated = truncate(result, _MCP_MAX_CHARS)
text = _text_from_mcp_result(truncated)
assert len(text) < len(big_text)
assert len(str(truncated)) <= _MCP_MAX_CHARS

View File

@@ -331,10 +331,10 @@ async def upload_transcript(
) -> None:
"""Strip progress entries and upload transcript to bucket storage.
Safety: only overwrites when the new (stripped) transcript is larger than
what is already stored. Since JSONL is append-only, the latest transcript
is always the longest. This prevents a slow/stale background task from
clobbering a newer upload from a concurrent turn.
The executor holds a cluster lock per session, so concurrent uploads for
the same session cannot happen. We always overwrite — with ``--resume``
the CLI may compact old tool results, so neither byte size nor line count
is a reliable proxy for "newer".
Args:
message_count: ``len(session.messages)`` at upload time — used by
@@ -353,20 +353,6 @@ async def upload_transcript(
storage = await get_workspace_storage()
wid, fid, fname = _storage_path_parts(user_id, session_id)
encoded = stripped.encode("utf-8")
new_size = len(encoded)
# Check existing transcript size to avoid overwriting newer with older
path = _build_storage_path(user_id, session_id, storage)
try:
existing = await storage.retrieve(path)
if len(existing) >= new_size:
logger.info(
f"[Transcript] Skipping upload — existing ({len(existing)}B) "
f">= new ({new_size}B) for session {session_id}"
)
return
except (FileNotFoundError, Exception):
pass # No existing transcript or retrieval error — proceed with upload
await storage.store(
workspace_id=wid,
@@ -375,11 +361,8 @@ async def upload_transcript(
content=encoded,
)
# Store metadata alongside the transcript so the next turn can detect
# staleness and only compress the gap instead of the full history.
# Wrapped in try/except so a metadata write failure doesn't orphan
# the already-uploaded transcript — the next turn will just fall back
# to full gap fill (msg_count=0).
# Update metadata so message_count stays current. The gap-fill logic
# in _build_query_message relies on it to avoid re-compressing messages.
try:
meta = {"message_count": message_count, "uploaded_at": time.time()}
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
@@ -393,7 +376,7 @@ async def upload_transcript(
logger.warning(f"[Transcript] Failed to write metadata for {session_id}: {e}")
logger.info(
f"[Transcript] Uploaded {new_size}B "
f"[Transcript] Uploaded {len(encoded)}B "
f"(stripped from {len(content)}B, msg_count={message_count}) "
f"for session {session_id}"
)

File diff suppressed because it is too large Load Diff

View File

@@ -4,75 +4,14 @@ from os import getenv
import pytest
from . import service as chat_service
from .model import create_chat_session, get_chat_session, upsert_chat_session
from .response_model import StreamError, StreamTextDelta, StreamToolOutputAvailable
from .response_model import StreamError, StreamTextDelta
from .sdk import service as sdk_service
from .sdk.transcript import download_transcript
logger = logging.getLogger(__name__)
@pytest.mark.asyncio(loop_scope="session")
async def test_stream_chat_completion(setup_test_user, test_user_id):
"""
Test the stream_chat_completion function.
"""
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
has_errors = False
assistant_message = ""
async for chunk in chat_service.stream_chat_completion(
session.session_id, "Hello, how are you?", user_id=session.user_id
):
logger.info(chunk)
if isinstance(chunk, StreamError):
has_errors = True
if isinstance(chunk, StreamTextDelta):
assistant_message += chunk.delta
# StreamFinish is published by mark_session_completed (processor layer),
# not by the service. The generator completing means the stream ended.
assert not has_errors, "Error occurred while streaming chat completion"
assert assistant_message, "Assistant message is empty"
@pytest.mark.asyncio(loop_scope="session")
async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
"""
Test the stream_chat_completion function.
"""
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
has_errors = False
had_tool_calls = False
async for chunk in chat_service.stream_chat_completion(
session.session_id,
"Please find me an agent that can help me with my business. Use the query 'moneny printing agent'",
user_id=session.user_id,
):
logger.info(chunk)
if isinstance(chunk, StreamError):
has_errors = True
if isinstance(chunk, StreamToolOutputAvailable):
had_tool_calls = True
assert not has_errors, "Error occurred while streaming chat completion"
assert had_tool_calls, "Tool calls did not occur"
session = await get_chat_session(session.session_id)
assert session, "Session not found"
assert session.usage, "Usage is empty"
@pytest.mark.asyncio(loop_scope="session")
async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
"""Test that the SDK --resume path captures and uses transcripts across turns.

View File

@@ -707,7 +707,6 @@ async def mark_session_completed(
True if session was newly marked completed, False if already completed/failed
"""
status: Literal["completed", "failed"] = "failed" if error_message else "completed"
redis = await get_redis_async()
meta_key = _get_session_meta_key(session_id)
@@ -734,7 +733,10 @@ async def mark_session_completed(
# This is the SINGLE place that publishes StreamFinish — services and
# the processor must NOT publish it themselves.
try:
await publish_chunk(turn_id, StreamFinish())
await publish_chunk(
turn_id,
StreamFinish(),
)
except Exception as e:
logger.error(
f"Failed to publish StreamFinish for session {session_id}: {e}. "

View File

@@ -1,12 +1,14 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from openai.types.chat import ChatCompletionToolParam
from backend.copilot.model import ChatSession
from backend.copilot.tracking import track_tool_called
from .add_understanding import AddUnderstandingTool
from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreenshotTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
@@ -20,6 +22,7 @@ from .find_library_agent import FindLibraryAgentTool
from .get_doc_page import GetDocPageTool
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .run_mcp_tool import RunMCPToolTool
from .search_docs import SearchDocsTool
from .web_fetch import WebFetchTool
from .workspace_files import (
@@ -30,6 +33,7 @@ from .workspace_files import (
)
if TYPE_CHECKING:
from backend.copilot.model import ChatSession
from backend.copilot.response_model import StreamToolOutputAvailable
logger = logging.getLogger(__name__)
@@ -45,11 +49,16 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"find_library_agent": FindLibraryAgentTool(),
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"run_mcp_tool": RunMCPToolTool(),
"view_agent_output": AgentOutputTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
# Web fetch for safe URL retrieval
"web_fetch": WebFetchTool(),
# Agent-browser multi-step automation (navigate, act, screenshot)
"browser_navigate": BrowserNavigateTool(),
"browser_act": BrowserActTool(),
"browser_screenshot": BrowserScreenshotTool(),
# Sandboxed code execution (bubblewrap)
"bash_exec": BashExecTool(),
# Persistent workspace tools (cloud storage, survives across sessions)
@@ -67,10 +76,17 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
find_agent_tool = TOOL_REGISTRY["find_agent"]
run_agent_tool = TOOL_REGISTRY["run_agent"]
# Generated from registry for OpenAI API
tools: list[ChatCompletionToolParam] = [
tool.as_openai_tool() for tool in TOOL_REGISTRY.values()
]
def get_available_tools() -> list[ChatCompletionToolParam]:
"""Return OpenAI tool schemas for tools available in the current environment.
Called per-request so that env-var or binary availability is evaluated
fresh each time (e.g. browser_* tools are excluded when agent-browser
CLI is not installed).
"""
return [
tool.as_openai_tool() for tool in TOOL_REGISTRY.values() if tool.is_available
]
def get_tool(tool_name: str) -> BaseTool | None:

View File

@@ -1,3 +1,4 @@
import logging
import uuid
from datetime import UTC, datetime
from os import getenv
@@ -12,12 +13,34 @@ from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import AITextGeneratorBlock
from backend.copilot.model import ChatSession
from backend.data import db as db_module
from backend.data.db import prisma
from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.model import APIKeyCredentials
from backend.data.user import get_or_create_user
from backend.integrations.credentials_store import IntegrationCredentialsStore
_logger = logging.getLogger(__name__)
async def _ensure_db_connected() -> None:
"""Ensure the Prisma connection is alive on the current event loop.
On Python 3.11, the httpx transport inside Prisma can reference a stale
(closed) event loop when session-scoped async fixtures are evaluated long
after the initial ``server`` fixture connected Prisma. A cheap health-check
followed by a reconnect fixes this without affecting other fixtures.
"""
try:
await prisma.query_raw("SELECT 1")
except Exception:
_logger.info("Prisma connection stale reconnecting")
try:
await db_module.disconnect()
except Exception:
pass
await db_module.connect()
def make_session(user_id: str):
return ChatSession(
@@ -43,6 +66,8 @@ async def setup_test_data(server):
Depends on ``server`` to ensure Prisma is connected.
"""
await _ensure_db_connected()
# 1. Create a test user
user_data = {
"sub": f"test-user-{uuid.uuid4()}",
@@ -164,6 +189,8 @@ async def setup_llm_test_data(server):
Depends on ``server`` to ensure Prisma is connected.
"""
await _ensure_db_connected()
key = getenv("OPENAI_API_KEY")
if not key:
return pytest.skip("OPENAI_API_KEY is not set")
@@ -330,6 +357,8 @@ async def setup_firecrawl_test_data(server):
Depends on ``server`` to ensure Prisma is connected.
"""
await _ensure_db_connected()
# 1. Create a test user
user_data = {
"sub": f"test-user-{uuid.uuid4()}",

View File

@@ -0,0 +1,876 @@
"""Agent-browser tools — multi-step browser automation for the Copilot.
Uses the agent-browser CLI (https://github.com/vercel-labs/agent-browser)
which runs a local Chromium instance managed by a persistent daemon.
- Runs locally — no cloud account required
- Full interaction support: click, fill, scroll, login flows, multi-step
- Session persistence via --session-name: cookies/auth carry across tool calls
within the same Copilot session, enabling login → navigate → extract workflows
- Screenshot with --annotate overlays @ref labels, saved to workspace for user
- The Claude Agent SDK's multi-turn loop handles orchestration — each tool call
is one browser action; the LLM chains them naturally
SSRF protection:
Uses the shared validate_url() from backend.util.request, which is the same
guard used by HTTP blocks and web_fetch. It resolves ALL DNS answers (not just
the first), blocks RFC 1918, loopback, link-local, 0.0.0.0/8, multicast,
and all relevant IPv6 ranges, and applies IDNA encoding to prevent Unicode
domain attacks.
Requires:
npm install -g agent-browser
agent-browser install (downloads Chromium, one-time per machine)
"""
import asyncio
import base64
import json
import logging
import os
import shutil
import tempfile
from typing import Any
from backend.copilot.model import ChatSession
from backend.util.request import validate_url
from .base import BaseTool
from .models import (
BrowserActResponse,
BrowserNavigateResponse,
BrowserScreenshotResponse,
ErrorResponse,
ToolResponseBase,
)
from .workspace_files import get_manager
logger = logging.getLogger(__name__)
# Per-command timeout (seconds). Navigation + networkidle wait can be slow.
_CMD_TIMEOUT = 45
# Accessibility tree can be very large; cap it to keep LLM context manageable.
_MAX_SNAPSHOT_CHARS = 20_000
# ---------------------------------------------------------------------------
# Subprocess helper
# ---------------------------------------------------------------------------
async def _run(
session_name: str,
*args: str,
timeout: int = _CMD_TIMEOUT,
) -> tuple[int, str, str]:
"""Run agent-browser for the given session and return (rc, stdout, stderr).
Uses both:
--session <name> → isolated Chromium context (no shared history/cookies
with other Copilot sessions — prevents cross-session
browser state leakage)
--session-name <name> → persist cookies/localStorage across tool calls within
the same session (enables login → navigate flows)
"""
cmd = [
"agent-browser",
"--session",
session_name,
"--session-name",
session_name,
*args,
]
proc = None
try:
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout)
return proc.returncode or 0, stdout.decode(), stderr.decode()
except asyncio.TimeoutError:
# Kill the orphaned subprocess so it does not linger in the process table.
if proc is not None and proc.returncode is None:
proc.kill()
try:
await proc.communicate()
except Exception:
pass # Best-effort reap; ignore errors during cleanup.
return 1, "", f"Command timed out after {timeout}s."
except FileNotFoundError:
return (
1,
"",
"agent-browser is not installed (run: npm install -g agent-browser && agent-browser install).",
)
async def _snapshot(session_name: str) -> str:
"""Return the current page's interactive accessibility tree, truncated."""
rc, stdout, stderr = await _run(session_name, "snapshot", "-i", "-c")
if rc != 0:
return f"[snapshot failed: {stderr[:300]}]"
text = stdout.strip()
if len(text) > _MAX_SNAPSHOT_CHARS:
suffix = "\n\n[Snapshot truncated — use browser_act to navigate further]"
keep = max(0, _MAX_SNAPSHOT_CHARS - len(suffix))
text = text[:keep] + suffix
return text
# ---------------------------------------------------------------------------
# Stateless session helpers — persist / restore browser state across pods
# ---------------------------------------------------------------------------
# Module-level cache of sessions known to be alive on this pod.
# Avoids the subprocess probe on every tool call within the same pod.
_alive_sessions: set[str] = set()
# Per-session locks to prevent concurrent _ensure_session calls from
# triggering duplicate _restore_browser_state for the same session.
# Protected by _session_locks_mutex to ensure setdefault/pop are not
# interleaved across await boundaries.
_session_locks: dict[str, asyncio.Lock] = {}
_session_locks_mutex = asyncio.Lock()
# Workspace filename for persisted browser state (auto-scoped to session).
# Dot-prefixed so it is hidden from user workspace listings.
_STATE_FILENAME = "._browser_state.json"
# Maximum concurrent subprocesses during cookie/storage restore.
_RESTORE_CONCURRENCY = 10
# Maximum cookies to restore per session. Pathological sites can accumulate
# thousands of cookies; restoring them all would be slow and is rarely useful.
_MAX_RESTORE_COOKIES = 100
# Background tasks for fire-and-forget state persistence.
# Prevents GC from collecting tasks before they complete.
_background_tasks: set[asyncio.Task] = set()
def _fire_and_forget_save(
session_name: str, user_id: str, session: ChatSession
) -> None:
"""Schedule state persistence as a background task (non-blocking).
State save is already best-effort (errors are swallowed), so running it
in the background avoids adding latency to tool responses.
"""
task = asyncio.create_task(_save_browser_state(session_name, user_id, session))
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
async def _has_local_session(session_name: str) -> bool:
"""Check if the local agent-browser daemon for this session is running."""
rc, _, _ = await _run(session_name, "get", "url", timeout=5)
return rc == 0
async def _save_browser_state(
session_name: str, user_id: str, session: ChatSession
) -> None:
"""Persist browser state (cookies, localStorage, URL) to workspace.
Best-effort: errors are logged but never propagate to the tool response.
"""
try:
# Gather state in parallel
(rc_url, url_out, _), (rc_ck, ck_out, _), (rc_ls, ls_out, _) = (
await asyncio.gather(
_run(session_name, "get", "url", timeout=10),
_run(session_name, "cookies", "get", "--json", timeout=10),
_run(session_name, "storage", "local", "--json", timeout=10),
)
)
state = {
"url": url_out.strip() if rc_url == 0 else "",
"cookies": (json.loads(ck_out) if rc_ck == 0 and ck_out.strip() else []),
"local_storage": (
json.loads(ls_out) if rc_ls == 0 and ls_out.strip() else {}
),
}
manager = await get_manager(user_id, session.session_id)
await manager.write_file(
content=json.dumps(state).encode("utf-8"),
filename=_STATE_FILENAME,
mime_type="application/json",
overwrite=True,
)
except Exception:
logger.warning(
"[browser] Failed to save browser state for session %s",
session_name,
exc_info=True,
)
async def _restore_browser_state(
session_name: str, user_id: str, session: ChatSession
) -> bool:
"""Restore browser state from workspace storage into a fresh daemon.
Best-effort: errors are logged but never propagate to the tool response.
Returns True on success (or no state to restore), False on failure.
"""
try:
manager = await get_manager(user_id, session.session_id)
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
if file_info is None:
return True # No saved state — first call or never saved
state_bytes = await manager.read_file(_STATE_FILENAME)
state = json.loads(state_bytes.decode("utf-8"))
url = state.get("url", "")
cookies = state.get("cookies", [])
local_storage = state.get("local_storage", {})
# Navigate first — starts daemon + sets the correct origin for cookies
if url:
# Validate the saved URL to prevent SSRF via stored redirect targets.
try:
await validate_url(url, trusted_origins=[])
except ValueError:
logger.warning(
"[browser] State restore: blocked SSRF URL %s", url[:200]
)
return False
rc, _, stderr = await _run(session_name, "open", url)
if rc != 0:
logger.warning(
"[browser] State restore: failed to open %s: %s",
url,
stderr[:200],
)
return False
await _run(session_name, "wait", "--load", "load", timeout=15)
# Restore cookies and localStorage in parallel via asyncio.gather.
# Semaphore caps concurrent subprocess spawns so we don't overwhelm the
# system when a session has hundreds of cookies.
sem = asyncio.Semaphore(_RESTORE_CONCURRENCY)
# Guard against pathological sites with thousands of cookies.
if len(cookies) > _MAX_RESTORE_COOKIES:
logger.debug(
"[browser] State restore: capping cookies from %d to %d",
len(cookies),
_MAX_RESTORE_COOKIES,
)
cookies = cookies[:_MAX_RESTORE_COOKIES]
async def _set_cookie(c: dict[str, Any]) -> None:
name = c.get("name", "")
value = c.get("value", "")
domain = c.get("domain", "")
path = c.get("path", "/")
if not (name and domain):
return
async with sem:
rc, _, stderr = await _run(
session_name,
"cookies",
"set",
name,
value,
"--domain",
domain,
"--path",
path,
timeout=5,
)
if rc != 0:
logger.debug(
"[browser] State restore: cookie set failed for %s: %s",
name,
stderr[:100],
)
async def _set_storage(key: str, val: object) -> None:
async with sem:
rc, _, stderr = await _run(
session_name,
"storage",
"local",
"set",
key,
str(val),
timeout=5,
)
if rc != 0:
logger.debug(
"[browser] State restore: localStorage set failed for %s: %s",
key,
stderr[:100],
)
await asyncio.gather(
*[_set_cookie(c) for c in cookies],
*[_set_storage(k, v) for k, v in local_storage.items()],
)
return True
except Exception:
logger.warning(
"[browser] Failed to restore browser state for session %s",
session_name,
exc_info=True,
)
return False
async def _ensure_session(
session_name: str, user_id: str, session: ChatSession
) -> None:
"""Ensure the local browser daemon has state. Restore from cloud if needed."""
if session_name in _alive_sessions:
return
async with _session_locks_mutex:
lock = _session_locks.setdefault(session_name, asyncio.Lock())
async with lock:
# Double-check after acquiring lock — another coroutine may have restored.
if session_name in _alive_sessions:
return
if await _has_local_session(session_name):
_alive_sessions.add(session_name)
return
if await _restore_browser_state(session_name, user_id, session):
_alive_sessions.add(session_name)
async def close_browser_session(session_name: str, user_id: str | None = None) -> None:
"""Shut down the local agent-browser daemon and clean up stored state.
Deletes ``._browser_state.json`` from workspace storage so cookies and
other credentials do not linger after the session is deleted.
Best-effort: errors are logged but never raised.
"""
_alive_sessions.discard(session_name)
async with _session_locks_mutex:
_session_locks.pop(session_name, None)
# Delete persisted browser state (cookies, localStorage) from workspace.
if user_id:
try:
manager = await get_manager(user_id, session_name)
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
if file_info is not None:
await manager.delete_file(file_info.id)
except Exception:
logger.debug(
"[browser] Failed to delete state file for session %s",
session_name,
exc_info=True,
)
try:
rc, _, stderr = await _run(session_name, "close", timeout=10)
if rc != 0:
logger.debug(
"[browser] close failed for session %s: %s",
session_name,
stderr[:200],
)
except Exception:
logger.debug(
"[browser] Exception closing browser session %s",
session_name,
exc_info=True,
)
# ---------------------------------------------------------------------------
# Tool: browser_navigate
# ---------------------------------------------------------------------------
class BrowserNavigateTool(BaseTool):
"""Navigate to a URL and return the page's interactive elements.
The browser session persists across tool calls within this Copilot session
(keyed to session_id), so cookies and auth state carry over. This enables
full login flows: navigate to login page → browser_act to fill credentials
→ browser_act to submit → browser_navigate to the target page.
"""
@property
def name(self) -> str:
return "browser_navigate"
@property
def description(self) -> str:
return (
"Navigate to a URL using a real browser. Returns an accessibility "
"tree snapshot listing the page's interactive elements with @ref IDs "
"(e.g. @e3) that can be used with browser_act. "
"Session persists — cookies and login state carry over between calls. "
"Use this (with browser_act) for multi-step interaction: login flows, "
"form filling, button clicks, or anything requiring page interaction. "
"For plain static pages, prefer web_fetch — no browser overhead. "
"For authenticated pages: navigate to the login page first, use browser_act "
"to fill credentials and submit, then navigate to the target page. "
"Note: for slow SPAs, the returned snapshot may reflect a partially-loaded "
"state. If elements seem missing, use browser_act with action='wait' and a "
"CSS selector or millisecond delay, then take a browser_screenshot to verify."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The HTTP/HTTPS URL to navigate to.",
},
"wait_for": {
"type": "string",
"enum": ["networkidle", "load", "domcontentloaded"],
"default": "networkidle",
"description": "When to consider navigation complete. Use 'networkidle' for SPAs (default).",
},
},
"required": ["url"],
}
@property
def requires_auth(self) -> bool:
return True
@property
def is_available(self) -> bool:
return shutil.which("agent-browser") is not None
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
) -> ToolResponseBase:
"""Navigate to *url*, wait for the page to settle, and return a snapshot.
The snapshot is an accessibility-tree listing of interactive elements.
Note: for slow SPAs that never fully idle, the snapshot may reflect a
partially-loaded state (the wait is best-effort).
"""
url: str = (kwargs.get("url") or "").strip()
wait_for: str = kwargs.get("wait_for") or "networkidle"
session_name = session.session_id
if not url:
return ErrorResponse(
message="Please provide a URL to navigate to.",
error="missing_url",
session_id=session_name,
)
try:
await validate_url(url, trusted_origins=[])
except ValueError as e:
return ErrorResponse(
message=str(e),
error="blocked_url",
session_id=session_name,
)
# Restore browser state from cloud if this is a different pod
if user_id:
await _ensure_session(session_name, user_id, session)
# Navigate
rc, _, stderr = await _run(session_name, "open", url)
if rc != 0:
logger.warning(
"[browser_navigate] open failed for %s: %s", url, stderr[:300]
)
return ErrorResponse(
message="Failed to navigate to URL.",
error="navigation_failed",
session_id=session_name,
)
# Wait for page to settle (best-effort: some SPAs never reach networkidle)
wait_rc, _, wait_err = await _run(session_name, "wait", "--load", wait_for)
if wait_rc != 0:
logger.warning(
"[browser_navigate] wait(%s) failed: %s", wait_for, wait_err[:300]
)
# Get current title and URL in parallel
(_, title_out, _), (_, url_out, _) = await asyncio.gather(
_run(session_name, "get", "title"),
_run(session_name, "get", "url"),
)
snapshot = await _snapshot(session_name)
result = BrowserNavigateResponse(
message=f"Navigated to {url}",
url=url_out.strip() or url,
title=title_out.strip(),
snapshot=snapshot,
session_id=session_name,
)
# Persist browser state to cloud for cross-pod continuity
if user_id:
_fire_and_forget_save(session_name, user_id, session)
return result
# ---------------------------------------------------------------------------
# Tool: browser_act
# ---------------------------------------------------------------------------
_NO_TARGET_ACTIONS = frozenset({"back", "forward", "reload"})
_SCROLL_ACTIONS = frozenset({"scroll"})
_TARGET_ONLY_ACTIONS = frozenset({"click", "dblclick", "hover", "check", "uncheck"})
_TARGET_VALUE_ACTIONS = frozenset({"fill", "type", "select"})
# wait <selector|ms>: waits for a DOM element or a fixed delay (e.g. "1000" for 1 s)
_WAIT_ACTIONS = frozenset({"wait"})
class BrowserActTool(BaseTool):
"""Perform an action on the current browser page and return the updated snapshot.
Use @ref IDs from the snapshot returned by browser_navigate (e.g. '@e3').
The LLM orchestrates multi-step flows by chaining browser_navigate and
browser_act calls across turns of the Claude Agent SDK conversation.
"""
@property
def name(self) -> str:
return "browser_act"
@property
def description(self) -> str:
return (
"Interact with the current browser page. Use @ref IDs from the "
"snapshot (e.g. '@e3') to target elements. Returns an updated snapshot. "
"Supported actions: click, dblclick, fill, type, scroll, hover, press, "
"check, uncheck, select, wait, back, forward, reload. "
"fill clears the field before typing; type appends without clearing. "
"wait accepts a CSS selector (waits for element) or milliseconds string (e.g. '1000'). "
"Example login flow: fill @e1 with email → fill @e2 with password → "
"click @e3 (submit) → browser_navigate to the target page."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"click",
"dblclick",
"fill",
"type",
"scroll",
"hover",
"press",
"check",
"uncheck",
"select",
"wait",
"back",
"forward",
"reload",
],
"description": "The action to perform.",
},
"target": {
"type": "string",
"description": (
"Element to target. Use @ref from snapshot (e.g. '@e3'), "
"a CSS selector, or a text description. "
"Required for: click, dblclick, fill, type, hover, check, uncheck, select. "
"For wait: a CSS selector to wait for, or milliseconds as a string (e.g. '1000')."
),
},
"value": {
"type": "string",
"description": (
"For fill/type: the text to enter. "
"For press: key name (e.g. 'Enter', 'Tab', 'Control+a'). "
"For select: the option value to select."
),
},
"direction": {
"type": "string",
"enum": ["up", "down", "left", "right"],
"default": "down",
"description": "For scroll: direction to scroll.",
},
},
"required": ["action"],
}
@property
def requires_auth(self) -> bool:
return True
@property
def is_available(self) -> bool:
return shutil.which("agent-browser") is not None
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
) -> ToolResponseBase:
"""Perform a browser action and return an updated page snapshot.
Validates the *action*/*target*/*value* combination, delegates to
``agent-browser``, waits for the page to settle, and returns the
accessibility-tree snapshot so the LLM can plan the next step.
"""
action: str = (kwargs.get("action") or "").strip()
target: str = (kwargs.get("target") or "").strip()
value: str = (kwargs.get("value") or "").strip()
direction: str = (kwargs.get("direction") or "down").strip()
session_name = session.session_id
if not action:
return ErrorResponse(
message="Please specify an action.",
error="missing_action",
session_id=session_name,
)
# Build the agent-browser command args
if action in _NO_TARGET_ACTIONS:
cmd_args = [action]
elif action in _SCROLL_ACTIONS:
cmd_args = ["scroll", direction]
elif action == "press":
if not value:
return ErrorResponse(
message="'press' requires a 'value' (key name, e.g. 'Enter').",
error="missing_value",
session_id=session_name,
)
cmd_args = ["press", value]
elif action in _TARGET_ONLY_ACTIONS:
if not target:
return ErrorResponse(
message=f"'{action}' requires a 'target' element.",
error="missing_target",
session_id=session_name,
)
cmd_args = [action, target]
elif action in _TARGET_VALUE_ACTIONS:
if not target or not value:
return ErrorResponse(
message=f"'{action}' requires both 'target' and 'value'.",
error="missing_params",
session_id=session_name,
)
cmd_args = [action, target, value]
elif action in _WAIT_ACTIONS:
if not target:
return ErrorResponse(
message=(
"'wait' requires a 'target': a CSS selector to wait for, "
"or milliseconds as a string (e.g. '1000')."
),
error="missing_target",
session_id=session_name,
)
cmd_args = ["wait", target]
else:
return ErrorResponse(
message=f"Unsupported action: {action}",
error="invalid_action",
session_id=session_name,
)
# Restore browser state from cloud if this is a different pod
if user_id:
await _ensure_session(session_name, user_id, session)
rc, _, stderr = await _run(session_name, *cmd_args)
if rc != 0:
logger.warning("[browser_act] %s failed: %s", action, stderr[:300])
return ErrorResponse(
message=f"Action '{action}' failed.",
error="action_failed",
session_id=session_name,
)
# Allow the page to settle after interaction (best-effort: SPAs may not idle)
settle_rc, _, settle_err = await _run(
session_name, "wait", "--load", "networkidle"
)
if settle_rc != 0:
logger.warning(
"[browser_act] post-action wait failed: %s", settle_err[:300]
)
snapshot = await _snapshot(session_name)
_, url_out, _ = await _run(session_name, "get", "url")
result = BrowserActResponse(
message=f"Performed '{action}'" + (f" on '{target}'" if target else ""),
action=action,
current_url=url_out.strip(),
snapshot=snapshot,
session_id=session_name,
)
# Persist browser state to cloud for cross-pod continuity
if user_id:
_fire_and_forget_save(session_name, user_id, session)
return result
# ---------------------------------------------------------------------------
# Tool: browser_screenshot
# ---------------------------------------------------------------------------
class BrowserScreenshotTool(BaseTool):
"""Capture a screenshot of the current browser page and save it to the workspace."""
@property
def name(self) -> str:
return "browser_screenshot"
@property
def description(self) -> str:
return (
"Take a screenshot of the current browser page and save it to the workspace. "
"IMPORTANT: After calling this tool, immediately call read_workspace_file "
"with the returned file_id to display the image inline to the user — "
"the screenshot is not visible until you do this. "
"With annotate=true (default), @ref labels are overlaid on interactive "
"elements, making it easy to see which @ref ID maps to which element on screen."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"annotate": {
"type": "boolean",
"default": True,
"description": "Overlay @ref labels on interactive elements (default: true).",
},
"filename": {
"type": "string",
"default": "screenshot.png",
"description": "Filename to save in the workspace.",
},
},
}
@property
def requires_auth(self) -> bool:
return True
@property
def is_available(self) -> bool:
return shutil.which("agent-browser") is not None
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
) -> ToolResponseBase:
"""Capture a PNG screenshot and upload it to the workspace.
Handles string-to-bool coercion for *annotate* (OpenAI function-call
payloads sometimes deliver ``"true"``/``"false"`` as strings).
Returns a :class:`BrowserScreenshotResponse` with the workspace
``file_id`` the LLM should pass to ``read_workspace_file``.
"""
raw_annotate = kwargs.get("annotate", True)
if isinstance(raw_annotate, str):
annotate = raw_annotate.strip().lower() in {"1", "true", "yes", "on"}
else:
annotate = bool(raw_annotate)
filename: str = (kwargs.get("filename") or "screenshot.png").strip()
session_name = session.session_id
# Restore browser state from cloud if this is a different pod
if user_id:
await _ensure_session(session_name, user_id, session)
tmp_fd, tmp_path = tempfile.mkstemp(suffix=".png")
os.close(tmp_fd)
try:
cmd_args = ["screenshot"]
if annotate:
cmd_args.append("--annotate")
cmd_args.append(tmp_path)
rc, _, stderr = await _run(session_name, *cmd_args)
if rc != 0:
logger.warning("[browser_screenshot] failed: %s", stderr[:300])
return ErrorResponse(
message="Failed to take screenshot.",
error="screenshot_failed",
session_id=session_name,
)
with open(tmp_path, "rb") as f:
png_bytes = f.read()
finally:
try:
os.unlink(tmp_path)
except OSError:
pass # Best-effort temp file cleanup; not critical if it fails.
# Upload to workspace so the user can view it
png_b64 = base64.b64encode(png_bytes).decode()
# Import here to avoid circular deps — workspace_files imports from .models
from .workspace_files import WorkspaceWriteResponse, WriteWorkspaceFileTool
write_resp = await WriteWorkspaceFileTool()._execute(
user_id=user_id,
session=session,
filename=filename,
content_base64=png_b64,
)
if not isinstance(write_resp, WorkspaceWriteResponse):
return ErrorResponse(
message="Screenshot taken but failed to save to workspace.",
error="workspace_write_failed",
session_id=session_name,
)
result = BrowserScreenshotResponse(
message=f"Screenshot saved to workspace as '{filename}'. Use read_workspace_file with file_id='{write_resp.file_id}' to retrieve it.",
file_id=write_resp.file_id,
filename=filename,
session_id=session_name,
)
# Persist browser state to cloud for cross-pod continuity
if user_id:
_fire_and_forget_save(session_name, user_id, session)
return result

File diff suppressed because it is too large Load Diff

View File

@@ -5,7 +5,7 @@ import re
from datetime import datetime, timedelta, timezone
from typing import Any
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, Field, field_validator
from backend.api.features.library.model import LibraryAgent
from backend.copilot.model import ChatSession
@@ -13,6 +13,7 @@ from backend.data.db_accessors import execution_db, library_db
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
from .base import BaseTool
from .execution_utils import TERMINAL_STATUSES, wait_for_execution
from .models import (
AgentOutputResponse,
ErrorResponse,
@@ -33,6 +34,7 @@ class AgentOutputInput(BaseModel):
store_slug: str = ""
execution_id: str = ""
run_time: str = "latest"
wait_if_running: int = Field(default=0, ge=0, le=300)
@field_validator(
"agent_name",
@@ -116,6 +118,11 @@ class AgentOutputTool(BaseTool):
Select which run to retrieve using:
- execution_id: Specific execution ID
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
Wait for completion (optional):
- wait_if_running: Max seconds to wait if execution is still running (0-300).
If the execution is running/queued, waits up to this many seconds for completion.
Returns current status on timeout. If already finished, returns immediately.
"""
@property
@@ -145,6 +152,13 @@ class AgentOutputTool(BaseTool):
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
),
},
"wait_if_running": {
"type": "integer",
"description": (
"Max seconds to wait if execution is still running (0-300). "
"If running, waits for completion. Returns current state on timeout."
),
},
},
"required": [],
}
@@ -224,10 +238,14 @@ class AgentOutputTool(BaseTool):
execution_id: str | None,
time_start: datetime | None,
time_end: datetime | None,
include_running: bool = False,
) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
"""
Fetch execution(s) based on filters.
Returns (single_execution, available_executions_meta, error_message).
Args:
include_running: If True, also look for running/queued executions (for waiting)
"""
exec_db = execution_db()
@@ -242,11 +260,25 @@ class AgentOutputTool(BaseTool):
return None, [], f"Execution '{execution_id}' not found"
return execution, [], None
# Get completed executions with time filters
# Determine which statuses to query
statuses = [ExecutionStatus.COMPLETED]
if include_running:
statuses.extend(
[
ExecutionStatus.RUNNING,
ExecutionStatus.QUEUED,
ExecutionStatus.INCOMPLETE,
ExecutionStatus.REVIEW,
ExecutionStatus.FAILED,
ExecutionStatus.TERMINATED,
]
)
# Get executions with time filters
executions = await exec_db.get_graph_executions(
graph_id=graph_id,
user_id=user_id,
statuses=[ExecutionStatus.COMPLETED],
statuses=statuses,
created_time_gte=time_start,
created_time_lte=time_end,
limit=10,
@@ -313,10 +345,33 @@ class AgentOutputTool(BaseTool):
for e in available_executions[:5]
]
message = f"Found execution outputs for agent '{agent.name}'"
# Build appropriate message based on execution status
if execution.status == ExecutionStatus.COMPLETED:
message = f"Found execution outputs for agent '{agent.name}'"
elif execution.status == ExecutionStatus.FAILED:
message = f"Execution for agent '{agent.name}' failed"
elif execution.status == ExecutionStatus.TERMINATED:
message = f"Execution for agent '{agent.name}' was terminated"
elif execution.status == ExecutionStatus.REVIEW:
message = (
f"Execution for agent '{agent.name}' is awaiting human review. "
"The user needs to approve it before it can continue."
)
elif execution.status in (
ExecutionStatus.RUNNING,
ExecutionStatus.QUEUED,
ExecutionStatus.INCOMPLETE,
):
message = (
f"Execution for agent '{agent.name}' is still {execution.status.value}. "
"Results may be incomplete. Use wait_if_running to wait for completion."
)
else:
message = f"Found execution for agent '{agent.name}' (status: {execution.status.value})"
if len(available_executions) > 1:
message += (
f". Showing latest of {len(available_executions)} matching executions."
f" Showing latest of {len(available_executions)} matching executions."
)
return AgentOutputResponse(
@@ -431,13 +486,17 @@ class AgentOutputTool(BaseTool):
# Parse time expression
time_start, time_end = parse_time_expression(input_data.run_time)
# Fetch execution(s)
# Check if we should wait for running executions
wait_timeout = input_data.wait_if_running
# Fetch execution(s) - include running if we're going to wait
execution, available_executions, exec_error = await self._get_execution(
user_id=user_id,
graph_id=agent.graph_id,
execution_id=input_data.execution_id or None,
time_start=time_start,
time_end=time_end,
include_running=wait_timeout > 0,
)
if exec_error:
@@ -446,4 +505,17 @@ class AgentOutputTool(BaseTool):
session_id=session_id,
)
# If we have an execution that's still running and we should wait
if execution and wait_timeout > 0 and execution.status not in TERMINAL_STATUSES:
logger.info(
f"Execution {execution.id} is {execution.status}, "
f"waiting up to {wait_timeout}s for completion"
)
execution = await wait_for_execution(
user_id=user_id,
graph_id=agent.graph_id,
execution_id=execution.id,
timeout_seconds=wait_timeout,
)
return self._build_response(agent, execution, available_executions, session_id)

View File

@@ -1,8 +1,13 @@
"""Shared agent search functionality for find_agent and find_library_agent tools."""
from __future__ import annotations
import logging
import re
from typing import Literal
from typing import TYPE_CHECKING, Literal
if TYPE_CHECKING:
from backend.api.features.library.model import LibraryAgent
from backend.data.db_accessors import library_db, store_db
from backend.util.exceptions import DatabaseError, NotFoundError
@@ -24,94 +29,24 @@ _UUID_PATTERN = re.compile(
re.IGNORECASE,
)
def _is_uuid(text: str) -> bool:
"""Check if text is a valid UUID v4."""
return bool(_UUID_PATTERN.match(text.strip()))
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
"""Fetch a library agent by ID (library agent ID or graph_id).
Tries multiple lookup strategies:
1. First by graph_id (AgentGraph primary key)
2. Then by library agent ID (LibraryAgent primary key)
Args:
user_id: The user ID
agent_id: The ID to look up (can be graph_id or library agent ID)
Returns:
AgentInfo if found, None otherwise
"""
lib_db = library_db()
try:
agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by graph_id {agent_id}: {e}",
exc_info=True,
)
try:
agent = await lib_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
except NotFoundError:
logger.debug(f"Library agent not found by library_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by library_id {agent_id}: {e}",
exc_info=True,
)
return None
# Keywords that should be treated as "list all" rather than a literal search
_LIST_ALL_KEYWORDS = frozenset({"all", "*", "everything", "any", ""})
async def search_agents(
query: str,
source: SearchSource,
session_id: str | None,
session_id: str | None = None,
user_id: str | None = None,
) -> ToolResponseBase:
"""
Search for agents in marketplace or user library.
For library searches, keywords like "all", "*", "everything", or an empty
query will list all agents without filtering.
Args:
query: Search query string
query: Search query string. Special keywords list all library agents.
source: "marketplace" or "library"
session_id: Chat session ID
user_id: User ID (required for library search)
@@ -119,7 +54,11 @@ async def search_agents(
Returns:
AgentsFoundResponse, NoResultsResponse, or ErrorResponse
"""
if not query:
# Normalize list-all keywords to empty string for library searches
if source == "library" and query.lower().strip() in _LIST_ALL_KEYWORDS:
query = ""
if source == "marketplace" and not query:
return ErrorResponse(
message="Please provide a search query", session_id=session_id
)
@@ -159,28 +98,18 @@ async def search_agents(
logger.info(f"Found agent by direct ID lookup: {agent.name}")
if not agents:
logger.info(f"Searching user library for: {query}")
search_term = query or None
logger.info(
f"{'Listing all agents in' if not query else 'Searching'} "
f"user library{'' if not query else f' for: {query}'}"
)
results = await library_db().list_library_agents(
user_id=user_id, # type: ignore[arg-type]
search_term=query,
page_size=10,
search_term=search_term,
page_size=50 if not query else 10,
)
for agent in results.agents:
agents.append(
AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
)
agents.append(_library_agent_to_info(agent))
logger.info(f"Found {len(agents)} agents in {source}")
except NotFoundError:
pass
@@ -193,42 +122,62 @@ async def search_agents(
)
if not agents:
suggestions = (
[
if source == "marketplace":
suggestions = [
"Try more general terms",
"Browse categories in the marketplace",
"Check spelling",
]
if source == "marketplace"
else [
no_results_msg = (
f"No agents found matching '{query}'. Let the user know they can "
"try different keywords or browse the marketplace. Also let them "
"know you can create a custom agent for them based on their needs."
)
elif not query:
# User asked to list all but library is empty
suggestions = [
"Browse the marketplace to find and add agents",
"Use find_agent to search the marketplace",
]
no_results_msg = (
"Your library is empty. Let the user know they can browse the "
"marketplace to find agents, or you can create a custom agent "
"for them based on their needs."
)
else:
suggestions = [
"Try different keywords",
"Use find_agent to search the marketplace",
"Check your library at /library",
]
)
no_results_msg = (
f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs."
if source == "marketplace"
else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs."
)
no_results_msg = (
f"No agents matching '{query}' found in your library. Let the "
"user know you can create a custom agent for them based on "
"their needs."
)
return NoResultsResponse(
message=no_results_msg, session_id=session_id, suggestions=suggestions
)
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} "
title += (
f"for '{query}'"
if source == "marketplace"
else f"in your library for '{query}'"
)
if source == "marketplace":
title = (
f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} for '{query}'"
)
elif not query:
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} in your library"
else:
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} in your library for '{query}'"
message = (
"Now you have found some options for the user to choose from. "
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
"Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs."
"Please ask the user if they would like to use any of these agents. "
"Let the user know we can create a custom agent for them based on their needs."
if source == "marketplace"
else "Found agents in the user's library. You can provide a link to view an agent at: "
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs."
else "Found agents in the user's library. You can provide a link to view "
"an agent at: /library/agents/{agent_id}. Use agent_output to get "
"execution results, or run_agent to execute. Let the user know we can "
"create a custom agent for them based on their needs."
)
return AgentsFoundResponse(
@@ -238,3 +187,67 @@ async def search_agents(
count=len(agents),
session_id=session_id,
)
def _is_uuid(text: str) -> bool:
"""Check if text is a valid UUID v4."""
return bool(_UUID_PATTERN.match(text.strip()))
def _library_agent_to_info(agent: LibraryAgent) -> AgentInfo:
"""Convert a library agent model to an AgentInfo."""
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
"""Fetch a library agent by ID (library agent ID or graph_id).
Tries multiple lookup strategies:
1. First by graph_id (AgentGraph primary key)
2. Then by library agent ID (LibraryAgent primary key)
"""
lib_db = library_db()
try:
agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return _library_agent_to_info(agent)
except NotFoundError:
logger.debug(f"Library agent not found by graph_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by graph_id {agent_id}: {e}",
exc_info=True,
)
try:
agent = await lib_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return _library_agent_to_info(agent)
except NotFoundError:
logger.debug(f"Library agent not found by library_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by library_id {agent_id}: {e}",
exc_info=True,
)
return None

View File

@@ -1,5 +1,6 @@
"""Base classes and shared utilities for chat tools."""
import json
import logging
from typing import Any
@@ -7,11 +8,98 @@ from openai.types.chat import ChatCompletionToolParam
from backend.copilot.model import ChatSession
from backend.copilot.response_model import StreamToolOutputAvailable
from backend.data.db_accessors import workspace_db
from backend.util.truncate import truncate
from backend.util.workspace import WorkspaceManager
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
logger = logging.getLogger(__name__)
# Persist full tool output to workspace when it exceeds this threshold.
# Must be below _MAX_TOOL_OUTPUT_SIZE (100K) in response_model.py so we
# capture the data before model_post_init middle-out truncation discards it.
_LARGE_OUTPUT_THRESHOLD = 80_000
# Character budget for the middle-out preview. The total preview + wrapper
# must stay below BOTH:
# - _MAX_TOOL_OUTPUT_SIZE (100K) in response_model.py (our own truncation)
# - Claude SDK's ~100 KB tool-result spill-to-disk threshold
# to avoid double truncation/spilling. 95K + ~300 wrapper = ~95.3K, under both.
_PREVIEW_CHARS = 95_000
# Fields whose values are binary/base64 data — truncating them produces
# garbage, so we replace them with a human-readable size summary instead.
_BINARY_FIELD_NAMES = {"content_base64"}
def _summarize_binary_fields(raw_json: str) -> str:
"""Replace known binary fields with a size summary so truncate() doesn't
produce garbled base64 in the middle-out preview."""
try:
data = json.loads(raw_json)
except (json.JSONDecodeError, TypeError):
return raw_json
if not isinstance(data, dict):
return raw_json
changed = False
for key in _BINARY_FIELD_NAMES:
if key in data and isinstance(data[key], str) and len(data[key]) > 1_000:
byte_size = len(data[key]) * 3 // 4 # approximate decoded size
data[key] = f"<binary, ~{byte_size:,} bytes>"
changed = True
return json.dumps(data, ensure_ascii=False) if changed else raw_json
async def _persist_and_summarize(
raw_output: str,
user_id: str,
session_id: str,
tool_call_id: str,
) -> str:
"""Persist full output to workspace and return a middle-out preview with retrieval instructions.
On failure, returns the original ``raw_output`` unchanged so that the
existing ``model_post_init`` middle-out truncation handles it as before.
"""
file_path = f"tool-outputs/{tool_call_id}.json"
try:
workspace = await workspace_db().get_or_create_workspace(user_id)
manager = WorkspaceManager(user_id, workspace.id, session_id)
await manager.write_file(
content=raw_output.encode("utf-8"),
filename=f"{tool_call_id}.json",
path=file_path,
mime_type="application/json",
overwrite=True,
)
except Exception:
logger.warning(
"Failed to persist large tool output for %s",
tool_call_id,
exc_info=True,
)
return raw_output # fall back to normal truncation
total = len(raw_output)
preview = truncate(_summarize_binary_fields(raw_output), _PREVIEW_CHARS)
retrieval = (
f"\nFull output ({total:,} chars) saved to workspace. "
f"Use read_workspace_file("
f'path="{file_path}", offset=<char_offset>, length=50000) '
f"to read any section."
)
return (
f'<tool-output-truncated total_chars={total} path="{file_path}">\n'
f"{preview}\n"
f"{retrieval}\n"
f"</tool-output-truncated>"
)
class BaseTool:
"""Base class for all chat tools."""
@@ -36,6 +124,16 @@ class BaseTool:
"""Whether this tool requires authentication."""
return False
@property
def is_available(self) -> bool:
"""Whether this tool is available in the current environment.
Override to check required env vars, binaries, or other dependencies.
Unavailable tools are excluded from the LLM tool list so the model is
never offered an option that will immediately fail.
"""
return True
def as_openai_tool(self) -> ChatCompletionToolParam:
"""Convert to OpenAI tool format."""
return ChatCompletionToolParam(
@@ -57,7 +155,7 @@ class BaseTool:
"""Execute the tool with authentication check.
Args:
user_id: User ID (may be anonymous like "anon_123")
user_id: User ID (None for anonymous users)
session_id: Chat session ID
**kwargs: Tool-specific parameters
@@ -81,10 +179,21 @@ class BaseTool:
try:
result = await self._execute(user_id, session, **kwargs)
raw_output = result.model_dump_json()
if (
len(raw_output) > _LARGE_OUTPUT_THRESHOLD
and user_id
and session.session_id
):
raw_output = await _persist_and_summarize(
raw_output, user_id, session.session_id, tool_call_id
)
return StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=self.name,
output=result.model_dump_json(),
output=raw_output,
)
except Exception as e:
logger.error(f"Error in {self.name}: {e}", exc_info=True)

View File

@@ -0,0 +1,194 @@
"""Tests for BaseTool large-output persistence in execute()."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.tools.base import (
_LARGE_OUTPUT_THRESHOLD,
BaseTool,
_persist_and_summarize,
_summarize_binary_fields,
)
from backend.copilot.tools.models import ResponseType, ToolResponseBase
class _HugeOutputTool(BaseTool):
"""Fake tool that returns an arbitrarily large output."""
def __init__(self, output_size: int) -> None:
self._output_size = output_size
@property
def name(self) -> str:
return "huge_output_tool"
@property
def description(self) -> str:
return "Returns a huge output"
@property
def parameters(self) -> dict:
return {"type": "object", "properties": {}}
async def _execute(self, user_id, session, **kwargs) -> ToolResponseBase:
return ToolResponseBase(
type=ResponseType.ERROR,
message="x" * self._output_size,
)
# ---------------------------------------------------------------------------
# _persist_and_summarize
# ---------------------------------------------------------------------------
class TestPersistAndSummarize:
@pytest.mark.asyncio
async def test_returns_middle_out_preview_with_retrieval_instructions(self):
raw = "A" * 200_000
mock_workspace = MagicMock()
mock_workspace.id = "ws-1"
mock_db = AsyncMock()
mock_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
mock_manager = AsyncMock()
with (
patch("backend.copilot.tools.base.workspace_db", return_value=mock_db),
patch(
"backend.copilot.tools.base.WorkspaceManager",
return_value=mock_manager,
),
):
result = await _persist_and_summarize(raw, "user-1", "session-1", "tc-123")
assert "<tool-output-truncated" in result
assert "</tool-output-truncated>" in result
assert "total_chars=200000" in result
assert 'path="tool-outputs/tc-123.json"' in result
assert "read_workspace_file" in result
# Middle-out sentinel from truncate()
assert "omitted" in result
# Total result is much shorter than the raw output
assert len(result) < len(raw)
# Verify write_file was called with full content
mock_manager.write_file.assert_awaited_once()
call_kwargs = mock_manager.write_file.call_args
assert call_kwargs.kwargs["content"] == raw.encode("utf-8")
assert call_kwargs.kwargs["path"] == "tool-outputs/tc-123.json"
@pytest.mark.asyncio
async def test_fallback_on_workspace_error(self):
"""If workspace write fails, return raw output for normal truncation."""
raw = "B" * 200_000
mock_db = AsyncMock()
mock_db.get_or_create_workspace = AsyncMock(side_effect=RuntimeError("boom"))
with patch("backend.copilot.tools.base.workspace_db", return_value=mock_db):
result = await _persist_and_summarize(raw, "user-1", "session-1", "tc-fail")
assert result == raw # unchanged — fallback to normal truncation
# ---------------------------------------------------------------------------
# BaseTool.execute — integration with persistence
# ---------------------------------------------------------------------------
class TestBaseToolExecuteLargeOutput:
@pytest.mark.asyncio
async def test_small_output_not_persisted(self):
"""Outputs under the threshold go through without persistence."""
tool = _HugeOutputTool(output_size=100)
session = MagicMock()
session.session_id = "s-1"
with patch(
"backend.copilot.tools.base._persist_and_summarize",
new_callable=AsyncMock,
) as persist_mock:
result = await tool.execute("user-1", session, "tc-small")
persist_mock.assert_not_awaited()
assert "<tool-output-truncated" not in str(result.output)
@pytest.mark.asyncio
async def test_large_output_persisted(self):
"""Outputs over the threshold trigger persistence + preview."""
tool = _HugeOutputTool(output_size=_LARGE_OUTPUT_THRESHOLD + 10_000)
session = MagicMock()
session.session_id = "s-1"
mock_workspace = MagicMock()
mock_workspace.id = "ws-1"
mock_db = AsyncMock()
mock_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
mock_manager = AsyncMock()
with (
patch("backend.copilot.tools.base.workspace_db", return_value=mock_db),
patch(
"backend.copilot.tools.base.WorkspaceManager",
return_value=mock_manager,
),
):
result = await tool.execute("user-1", session, "tc-big")
assert "<tool-output-truncated" in str(result.output)
assert "read_workspace_file" in str(result.output)
mock_manager.write_file.assert_awaited_once()
@pytest.mark.asyncio
async def test_no_persistence_without_user_id(self):
"""Anonymous users skip persistence (no workspace)."""
tool = _HugeOutputTool(output_size=_LARGE_OUTPUT_THRESHOLD + 10_000)
session = MagicMock()
session.session_id = "s-1"
# user_id=None → should not attempt persistence
with patch(
"backend.copilot.tools.base._persist_and_summarize",
new_callable=AsyncMock,
) as persist_mock:
result = await tool.execute(None, session, "tc-anon")
persist_mock.assert_not_awaited()
# Output is set but not wrapped in <tool-output-truncated> tags
# (it will be middle-out truncated by model_post_init instead)
assert "<tool-output-truncated" not in str(result.output)
# ---------------------------------------------------------------------------
# _summarize_binary_fields
# ---------------------------------------------------------------------------
class TestSummarizeBinaryFields:
def test_replaces_large_content_base64(self):
import json
data = {"content_base64": "A" * 10_000, "name": "file.png"}
result = json.loads(_summarize_binary_fields(json.dumps(data)))
assert result["name"] == "file.png"
assert "<binary" in result["content_base64"]
assert "bytes>" in result["content_base64"]
def test_preserves_small_content_base64(self):
import json
data = {"content_base64": "AQID", "name": "tiny.bin"}
result_str = _summarize_binary_fields(json.dumps(data))
result = json.loads(result_str)
assert result["content_base64"] == "AQID" # unchanged
def test_non_json_passthrough(self):
raw = "not json at all"
assert _summarize_binary_fields(raw) == raw
def test_no_binary_fields_unchanged(self):
import json
data = {"message": "hello", "type": "info"}
raw = json.dumps(data)
assert _summarize_binary_fields(raw) == raw

View File

@@ -1,19 +1,30 @@
"""Bash execution tool — run shell commands in a bubblewrap sandbox.
"""Bash execution tool — run shell commands on E2B or in a bubblewrap sandbox.
Full Bash scripting is allowed (loops, conditionals, pipes, functions, etc.).
Safety comes from OS-level isolation (bubblewrap): only system dirs visible
read-only, writable workspace only, clean env, no network.
When an E2B sandbox is available in the current execution context the command
runs directly on the remote E2B cloud environment. This means:
Requires bubblewrap (``bwrap``) — the tool is disabled when bwrap is not
available (e.g. macOS development).
- **Persistent filesystem**: files survive across turns via HTTP-based sync
with the sandbox's ``/home/user`` directory (E2B files API), shared with
SDK Read/Write/Edit tools.
- **Full internet access**: E2B sandboxes have unrestricted outbound network.
- **Execution isolation**: E2B provides a fresh, containerised Linux environment.
When E2B is *not* configured the tool falls back to **bubblewrap** (bwrap):
OS-level isolation with a whitelist-only filesystem, no network, and resource
limits. Requires bubblewrap to be installed (Linux only).
"""
import logging
import shlex
from typing import Any
from e2b import AsyncSandbox
from e2b.exceptions import TimeoutException
from backend.copilot.model import ChatSession
from .base import BaseTool
from .e2b_sandbox import E2B_WORKDIR
from .models import BashExecResponse, ErrorResponse, ToolResponseBase
from .sandbox import get_workspace_dir, has_full_sandbox, run_sandboxed
@@ -21,7 +32,7 @@ logger = logging.getLogger(__name__)
class BashExecTool(BaseTool):
"""Execute Bash commands in a bubblewrap sandbox."""
"""Execute Bash commands on E2B or in a bubblewrap sandbox."""
@property
def name(self) -> str:
@@ -29,28 +40,16 @@ class BashExecTool(BaseTool):
@property
def description(self) -> str:
if not has_full_sandbox():
return (
"Bash execution is DISABLED — bubblewrap sandbox is not "
"available on this platform. Do not call this tool."
)
return (
"Execute a Bash command or script in a bubblewrap sandbox. "
"Execute a Bash command or script. "
"Full Bash scripting is supported (loops, conditionals, pipes, "
"functions, etc.). "
"The sandbox shares the same working directory as the SDK Read/Write "
"tools — files created by either are accessible to both. "
"SECURITY: Only system directories (/usr, /bin, /lib, /etc) are "
"visible read-only, the per-session workspace is the only writable "
"path, environment variables are wiped (no secrets), all network "
"access is blocked at the kernel level, and resource limits are "
"enforced (max 64 processes, 512MB memory, 50MB file size). "
"Application code, configs, and other directories are NOT accessible. "
"To fetch web content, use the web_fetch tool instead. "
"The working directory is shared with the SDK Read/Write/Edit/Glob/Grep "
"tools — files created by either are immediately visible to both. "
"Execution is killed after the timeout (default 30s, max 120s). "
"Returns stdout and stderr. "
"Useful for file manipulation, data processing with Unix tools "
"(grep, awk, sed, jq, etc.), and running shell scripts."
"Useful for file manipulation, data processing, running scripts, "
"and installing packages."
)
@property
@@ -85,15 +84,8 @@ class BashExecTool(BaseTool):
) -> ToolResponseBase:
session_id = session.session_id if session else None
if not has_full_sandbox():
return ErrorResponse(
message="bash_exec requires bubblewrap sandbox (Linux only).",
error="sandbox_unavailable",
session_id=session_id,
)
command: str = (kwargs.get("command") or "").strip()
timeout: int = kwargs.get("timeout", 30)
timeout: int = int(kwargs.get("timeout", 30))
if not command:
return ErrorResponse(
@@ -102,6 +94,21 @@ class BashExecTool(BaseTool):
session_id=session_id,
)
# E2B path: run on remote cloud sandbox when available.
from backend.copilot.sdk.tool_adapter import get_current_sandbox
sandbox = get_current_sandbox()
if sandbox is not None:
return await self._execute_on_e2b(sandbox, command, timeout, session_id)
# Bubblewrap fallback: local isolated execution.
if not has_full_sandbox():
return ErrorResponse(
message="bash_exec requires bubblewrap sandbox (Linux only).",
error="sandbox_unavailable",
session_id=session_id,
)
workspace = get_workspace_dir(session_id or "default")
stdout, stderr, exit_code, timed_out = await run_sandboxed(
@@ -122,3 +129,43 @@ class BashExecTool(BaseTool):
timed_out=timed_out,
session_id=session_id,
)
async def _execute_on_e2b(
self,
sandbox: AsyncSandbox,
command: str,
timeout: int,
session_id: str | None,
) -> ToolResponseBase:
"""Execute *command* on the E2B sandbox via commands.run()."""
try:
result = await sandbox.commands.run(
f"bash -c {shlex.quote(command)}",
cwd=E2B_WORKDIR,
timeout=timeout,
envs={"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"},
)
return BashExecResponse(
message=f"Command executed on E2B (exit {result.exit_code})",
stdout=result.stdout or "",
stderr=result.stderr or "",
exit_code=result.exit_code,
timed_out=False,
session_id=session_id,
)
except Exception as exc:
if isinstance(exc, TimeoutException):
return BashExecResponse(
message="Execution timed out",
stdout="",
stderr=f"Timed out after {timeout}s",
exit_code=-1,
timed_out=True,
session_id=session_id,
)
logger.error("[E2B] bash_exec failed: %s", exc, exc_info=True)
return ErrorResponse(
message=f"E2B execution failed: {exc}",
error="e2b_execution_error",
session_id=session_id,
)

View File

@@ -0,0 +1,170 @@
"""E2B sandbox lifecycle for CoPilot: persistent cloud execution.
Each session gets a long-lived E2B cloud sandbox. ``bash_exec`` runs commands
directly on the sandbox via ``sandbox.commands.run()``. SDK file tools
(read_file/write_file/edit_file/glob/grep) route to the sandbox's
``/home/user`` directory via E2B's HTTP-based filesystem API — all tools
share a single coherent filesystem with no local sync required.
Lifecycle
---------
1. **Turn start** connect to the existing sandbox (sandbox_id in Redis) or
create a new one via ``get_or_create_sandbox()``.
2. **Execution** ``bash_exec`` and MCP file tools operate directly on the
sandbox's ``/home/user`` filesystem.
3. **Session expiry** E2B sandbox is killed by its own timeout (session_ttl).
"""
import asyncio
import logging
from e2b import AsyncSandbox
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
_SANDBOX_REDIS_PREFIX = "copilot:e2b:sandbox:"
E2B_WORKDIR = "/home/user"
_CREATING = "__creating__"
_CREATION_LOCK_TTL = 60
_MAX_WAIT_ATTEMPTS = 20 # 20 * 0.5s = 10s max wait
async def _try_reconnect(
sandbox_id: str, api_key: str, redis_key: str, timeout: int
) -> "AsyncSandbox | None":
"""Try to reconnect to an existing sandbox. Returns None on failure."""
try:
sandbox = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
if await sandbox.is_running():
redis = await get_redis_async()
await redis.expire(redis_key, timeout)
return sandbox
except Exception as exc:
logger.warning("[E2B] Reconnect to %.12s failed: %s", sandbox_id, exc)
# Stale — clear Redis so a new sandbox can be created.
redis = await get_redis_async()
await redis.delete(redis_key)
return None
async def get_or_create_sandbox(
session_id: str,
api_key: str,
template: str = "base",
timeout: int = 43200,
) -> AsyncSandbox:
"""Return the existing E2B sandbox for *session_id* or create a new one.
The sandbox_id is persisted in Redis so the same sandbox is reused
across turns. Concurrent calls for the same session are serialised
via a Redis ``SET NX`` creation lock.
"""
redis = await get_redis_async()
redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
# 1. Try reconnecting to an existing sandbox.
raw = await redis.get(redis_key)
if raw:
sandbox_id = raw if isinstance(raw, str) else raw.decode()
if sandbox_id != _CREATING:
sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
if sandbox:
logger.info(
"[E2B] Reconnected to %.12s for session %.12s",
sandbox_id,
session_id,
)
return sandbox
# 2. Claim creation lock. If another request holds it, wait for the result.
claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
if not claimed:
for _ in range(_MAX_WAIT_ATTEMPTS):
await asyncio.sleep(0.5)
raw = await redis.get(redis_key)
if not raw:
break # Lock expired — fall through to retry creation
sandbox_id = raw if isinstance(raw, str) else raw.decode()
if sandbox_id != _CREATING:
sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
if sandbox:
return sandbox
break # Stale sandbox cleared — fall through to create
# Try to claim creation lock again after waiting.
claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
if not claimed:
# Another process may have created a sandbox — try to use it.
raw = await redis.get(redis_key)
if raw:
sandbox_id = raw if isinstance(raw, str) else raw.decode()
if sandbox_id != _CREATING:
sandbox = await _try_reconnect(
sandbox_id, api_key, redis_key, timeout
)
if sandbox:
return sandbox
raise RuntimeError(
f"Could not acquire E2B creation lock for session {session_id[:12]}"
)
# 3. Create a new sandbox.
try:
sandbox = await AsyncSandbox.create(
template=template, api_key=api_key, timeout=timeout
)
except Exception:
await redis.delete(redis_key)
raise
await redis.setex(redis_key, timeout, sandbox.sandbox_id)
logger.info(
"[E2B] Created sandbox %.12s for session %.12s",
sandbox.sandbox_id,
session_id,
)
return sandbox
async def kill_sandbox(session_id: str, api_key: str) -> bool:
"""Kill the E2B sandbox for *session_id* and clean up its Redis entry.
Returns ``True`` if a sandbox was found and killed, ``False`` otherwise.
Safe to call even when no sandbox exists for the session.
"""
redis = await get_redis_async()
redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
raw = await redis.get(redis_key)
if not raw:
return False
sandbox_id = raw if isinstance(raw, str) else raw.decode()
await redis.delete(redis_key)
if sandbox_id == _CREATING:
return False
try:
async def _connect_and_kill():
sandbox = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
await sandbox.kill()
await asyncio.wait_for(_connect_and_kill(), timeout=10)
logger.info(
"[E2B] Killed sandbox %.12s for session %.12s",
sandbox_id,
session_id,
)
return True
except Exception as exc:
logger.warning(
"[E2B] Failed to kill sandbox %.12s for session %.12s: %s",
sandbox_id,
session_id,
exc,
)
return False

View File

@@ -0,0 +1,272 @@
"""Tests for e2b_sandbox: get_or_create_sandbox, _try_reconnect, kill_sandbox.
Uses mock Redis and mock AsyncSandbox — no external dependencies.
Tests are synchronous (using asyncio.run) to avoid conflicts with the
session-scoped event loop in conftest.py.
"""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from .e2b_sandbox import (
_CREATING,
_SANDBOX_REDIS_PREFIX,
_try_reconnect,
get_or_create_sandbox,
kill_sandbox,
)
_KEY = f"{_SANDBOX_REDIS_PREFIX}sess-123"
_API_KEY = "test-api-key"
_TIMEOUT = 300
def _mock_sandbox(sandbox_id: str = "sb-abc", running: bool = True) -> MagicMock:
sb = MagicMock()
sb.sandbox_id = sandbox_id
sb.is_running = AsyncMock(return_value=running)
return sb
def _mock_redis(get_val: str | bytes | None = None, set_nx_result: bool = True):
r = AsyncMock()
r.get = AsyncMock(return_value=get_val)
r.set = AsyncMock(return_value=set_nx_result)
r.setex = AsyncMock()
r.delete = AsyncMock()
r.expire = AsyncMock()
return r
def _patch_redis(redis):
return patch(
"backend.copilot.tools.e2b_sandbox.get_redis_async",
new_callable=AsyncMock,
return_value=redis,
)
# ---------------------------------------------------------------------------
# _try_reconnect
# ---------------------------------------------------------------------------
class TestTryReconnect:
def test_reconnect_success(self):
sb = _mock_sandbox()
redis = _mock_redis()
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
assert result is sb
redis.expire.assert_awaited_once_with(_KEY, _TIMEOUT)
redis.delete.assert_not_awaited()
def test_reconnect_not_running_clears_key(self):
sb = _mock_sandbox(running=False)
redis = _mock_redis()
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
assert result is None
redis.delete.assert_awaited_once_with(_KEY)
def test_reconnect_exception_clears_key(self):
redis = _mock_redis()
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
assert result is None
redis.delete.assert_awaited_once_with(_KEY)
# ---------------------------------------------------------------------------
# get_or_create_sandbox
# ---------------------------------------------------------------------------
class TestGetOrCreateSandbox:
def test_reconnect_existing(self):
"""When Redis has a valid sandbox_id, reconnect to it."""
sb = _mock_sandbox()
redis = _mock_redis(get_val="sb-abc")
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
)
assert result is sb
mock_cls.create.assert_not_called()
def test_create_new_when_no_key(self):
"""When Redis is empty, claim lock and create a new sandbox."""
sb = _mock_sandbox("sb-new")
redis = _mock_redis(get_val=None, set_nx_result=True)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.create = AsyncMock(return_value=sb)
result = asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
)
assert result is sb
redis.setex.assert_awaited_once_with(_KEY, _TIMEOUT, "sb-new")
def test_create_failure_clears_lock(self):
"""If sandbox creation fails, the Redis lock is deleted."""
redis = _mock_redis(get_val=None, set_nx_result=True)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.create = AsyncMock(side_effect=RuntimeError("quota"))
with pytest.raises(RuntimeError, match="quota"):
asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
)
redis.delete.assert_awaited_once_with(_KEY)
def test_wait_for_lock_then_reconnect(self):
"""When another process holds the lock, wait and reconnect."""
sb = _mock_sandbox("sb-other")
redis = _mock_redis()
redis.get = AsyncMock(side_effect=[_CREATING, "sb-other"])
redis.set = AsyncMock(return_value=False)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
patch(
"backend.copilot.tools.e2b_sandbox.asyncio.sleep",
new_callable=AsyncMock,
),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
)
assert result is sb
def test_stale_reconnect_clears_and_creates(self):
"""When stored sandbox is stale, clear key and create a new one."""
stale_sb = _mock_sandbox("sb-stale", running=False)
new_sb = _mock_sandbox("sb-fresh")
redis = _mock_redis(get_val="sb-stale", set_nx_result=True)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=stale_sb)
mock_cls.create = AsyncMock(return_value=new_sb)
result = asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
)
assert result is new_sb
redis.delete.assert_awaited()
# ---------------------------------------------------------------------------
# kill_sandbox
# ---------------------------------------------------------------------------
class TestKillSandbox:
def test_kill_existing_sandbox(self):
"""Kill a running sandbox and clean up Redis."""
sb = _mock_sandbox()
sb.kill = AsyncMock()
redis = _mock_redis(get_val="sb-abc")
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
assert result is True
redis.delete.assert_awaited_once_with(_KEY)
sb.kill.assert_awaited_once()
def test_kill_no_sandbox(self):
"""No-op when no sandbox exists in Redis."""
redis = _mock_redis(get_val=None)
with _patch_redis(redis):
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
assert result is False
redis.delete.assert_not_awaited()
def test_kill_creating_state(self):
"""Clears Redis key but returns False when sandbox is still being created."""
redis = _mock_redis(get_val=_CREATING)
with _patch_redis(redis):
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
assert result is False
redis.delete.assert_awaited_once_with(_KEY)
def test_kill_connect_failure(self):
"""Returns False and cleans Redis if connect/kill fails."""
redis = _mock_redis(get_val="sb-abc")
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
assert result is False
redis.delete.assert_awaited_once_with(_KEY)
def test_kill_with_bytes_redis_value(self):
"""Redis may return bytes — kill_sandbox should decode correctly."""
sb = _mock_sandbox()
sb.kill = AsyncMock()
redis = _mock_redis(get_val=b"sb-abc")
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
assert result is True
sb.kill.assert_awaited_once()
def test_kill_timeout_returns_false(self):
"""Returns False when E2B API calls exceed the 10s timeout."""
redis = _mock_redis(get_val="sb-abc")
with (
_patch_redis(redis),
patch(
"backend.copilot.tools.e2b_sandbox.asyncio.wait_for",
new_callable=AsyncMock,
side_effect=asyncio.TimeoutError,
),
):
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
assert result is False
redis.delete.assert_awaited_once_with(_KEY)

View File

@@ -0,0 +1,186 @@
"""Shared utilities for execution waiting and status handling."""
import asyncio
import logging
from typing import Any
from backend.data.db_accessors import execution_db
from backend.data.execution import (
AsyncRedisExecutionEventBus,
ExecutionStatus,
GraphExecution,
GraphExecutionEvent,
)
logger = logging.getLogger(__name__)
# Terminal statuses that indicate execution is complete
TERMINAL_STATUSES = frozenset(
{
ExecutionStatus.COMPLETED,
ExecutionStatus.FAILED,
ExecutionStatus.TERMINATED,
}
)
# Statuses where execution is paused but not finished (e.g. human-in-the-loop)
PAUSED_STATUSES = frozenset(
{
ExecutionStatus.REVIEW,
}
)
# Statuses that mean "stop waiting" (terminal or paused)
STOP_WAITING_STATUSES = TERMINAL_STATUSES | PAUSED_STATUSES
_POST_SUBSCRIBE_RECHECK_DELAY = 0.1 # seconds to wait for subscription to establish
async def wait_for_execution(
user_id: str,
graph_id: str,
execution_id: str,
timeout_seconds: int,
) -> GraphExecution | None:
"""
Wait for an execution to reach a terminal or paused status using Redis pubsub.
Handles the race condition between checking status and subscribing by
re-checking the DB after the subscription is established.
Args:
user_id: User ID
graph_id: Graph ID
execution_id: Execution ID to wait for
timeout_seconds: Max seconds to wait
Returns:
The execution with current status, or None if not found
"""
exec_db = execution_db()
# Quick check — maybe it's already done
execution = await exec_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)
if not execution:
return None
if execution.status in STOP_WAITING_STATUSES:
logger.debug(
f"Execution {execution_id} already in stop-waiting state: "
f"{execution.status}"
)
return execution
logger.info(
f"Waiting up to {timeout_seconds}s for execution {execution_id} "
f"(current status: {execution.status})"
)
event_bus = AsyncRedisExecutionEventBus()
channel_key = f"{user_id}/{graph_id}/{execution_id}"
# Mutable container so _subscribe_and_wait can surface the task even if
# asyncio.wait_for cancels the coroutine before it returns.
task_holder: list[asyncio.Task] = []
try:
result = await asyncio.wait_for(
_subscribe_and_wait(
event_bus, channel_key, user_id, execution_id, exec_db, task_holder
),
timeout=timeout_seconds,
)
return result
except asyncio.TimeoutError:
logger.info(f"Timeout waiting for execution {execution_id}")
except Exception as e:
logger.error(f"Error waiting for execution: {e}", exc_info=True)
finally:
for task in task_holder:
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
await event_bus.close()
# Return current state on timeout/error
return await exec_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)
async def _subscribe_and_wait(
event_bus: AsyncRedisExecutionEventBus,
channel_key: str,
user_id: str,
execution_id: str,
exec_db: Any,
task_holder: list[asyncio.Task],
) -> GraphExecution | None:
"""
Subscribe to execution events and wait for a terminal/paused status.
Appends the consumer task to ``task_holder`` so the caller can clean it up
even if this coroutine is cancelled by ``asyncio.wait_for``.
To avoid the race condition where the execution completes between the
initial DB check and the Redis subscription, we:
1. Start listening (which subscribes internally)
2. Re-check the DB after subscription is active
3. If still running, wait for pubsub events
"""
listen_iter = event_bus.listen_events(channel_key).__aiter__()
done = asyncio.Event()
result_execution: GraphExecution | None = None
async def _consume() -> None:
nonlocal result_execution
try:
async for event in listen_iter:
if isinstance(event, GraphExecutionEvent):
logger.debug(f"Received execution update: {event.status}")
if event.status in STOP_WAITING_STATUSES:
result_execution = await exec_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)
done.set()
return
except Exception as e:
logger.error(f"Error in execution consumer: {e}", exc_info=True)
done.set()
consume_task = asyncio.create_task(_consume())
task_holder.append(consume_task)
# Give the subscription a moment to establish, then re-check DB
await asyncio.sleep(_POST_SUBSCRIBE_RECHECK_DELAY)
execution = await exec_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)
if execution and execution.status in STOP_WAITING_STATUSES:
return execution
# Wait for the pubsub consumer to find a terminal event
await done.wait()
return result_execution
def get_execution_outputs(execution: GraphExecution | None) -> dict[str, Any] | None:
"""Extract outputs from an execution, or return None."""
if execution is None:
return None
return execution.outputs

View File

@@ -32,6 +32,7 @@ COPILOT_EXCLUDED_BLOCK_TYPES = {
BlockType.NOTE, # Visual annotation only - no runtime behavior
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
BlockType.MCP_TOOL, # Has dedicated run_mcp_tool tool with proper discovery + auth flow
}
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)

View File

@@ -19,9 +19,10 @@ class FindLibraryAgentTool(BaseTool):
@property
def description(self) -> str:
return (
"Search for agents in the user's library. Use this to find agents "
"the user has already added to their library, including agents they "
"created or added from the marketplace."
"Search for or list agents in the user's library. Use this to find "
"agents the user has already added to their library, including agents "
"they created or added from the marketplace. "
"Omit the query to list all agents."
)
@property
@@ -31,10 +32,13 @@ class FindLibraryAgentTool(BaseTool):
"properties": {
"query": {
"type": "string",
"description": "Search query to find agents by name or description.",
"description": (
"Search query to find agents by name or description. "
"Omit to list all agents in the library."
),
},
},
"required": ["query"],
"required": [],
}
@property
@@ -45,7 +49,7 @@ class FindLibraryAgentTool(BaseTool):
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
return await search_agents(
query=kwargs.get("query", "").strip(),
query=(kwargs.get("query") or "").strip(),
source="library",
session_id=session.session_id,
user_id=user_id,

View File

@@ -41,6 +41,10 @@ class ResponseType(str, Enum):
INPUT_VALIDATION_ERROR = "input_validation_error"
# Web fetch
WEB_FETCH = "web_fetch"
# Agent-browser multi-step automation (navigate, act, screenshot)
BROWSER_NAVIGATE = "browser_navigate"
BROWSER_ACT = "browser_act"
BROWSER_SCREENSHOT = "browser_screenshot"
# Code execution
BASH_EXEC = "bash_exec"
# Feature request types
@@ -48,6 +52,9 @@ class ResponseType(str, Enum):
FEATURE_REQUEST_CREATED = "feature_request_created"
# Goal refinement
SUGGESTED_GOAL = "suggested_goal"
# MCP tool types
MCP_TOOLS_DISCOVERED = "mcp_tools_discovered"
MCP_TOOL_OUTPUT = "mcp_tool_output"
# Base response model
@@ -476,3 +483,59 @@ class FeatureRequestCreatedResponse(ToolResponseBase):
issue_url: str
is_new_issue: bool # False if added to existing
customer_name: str
# MCP tool models
class MCPToolInfo(BaseModel):
"""Information about a single MCP tool discovered from a server."""
name: str
description: str
input_schema: dict[str, Any]
class MCPToolsDiscoveredResponse(ToolResponseBase):
"""Response when MCP tools are discovered from a server (agent-internal)."""
type: ResponseType = ResponseType.MCP_TOOLS_DISCOVERED
server_url: str
tools: list[MCPToolInfo]
class MCPToolOutputResponse(ToolResponseBase):
"""Response after executing an MCP tool."""
type: ResponseType = ResponseType.MCP_TOOL_OUTPUT
server_url: str
tool_name: str
result: Any = None
success: bool = True
# Agent-browser multi-step automation models
class BrowserNavigateResponse(ToolResponseBase):
"""Response for browser_navigate tool."""
type: ResponseType = ResponseType.BROWSER_NAVIGATE
url: str
title: str
snapshot: str # Interactive accessibility tree with @ref IDs
class BrowserActResponse(ToolResponseBase):
"""Response for browser_act tool."""
type: ResponseType = ResponseType.BROWSER_ACT
action: str
current_url: str = ""
snapshot: str # Updated accessibility tree after the action
class BrowserScreenshotResponse(ToolResponseBase):
"""Response for browser_screenshot tool."""
type: ResponseType = ResponseType.BROWSER_SCREENSHOT
file_id: str # Workspace file ID — use read_workspace_file to retrieve
filename: str

View File

@@ -9,6 +9,7 @@ from backend.copilot.config import ChatConfig
from backend.copilot.model import ChatSession
from backend.copilot.tracking import track_agent_run_success, track_agent_scheduled
from backend.data.db_accessors import graph_db, library_db, user_db
from backend.data.execution import ExecutionStatus
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.executor import utils as execution_utils
@@ -20,12 +21,15 @@ from backend.util.timezone_utils import (
)
from .base import BaseTool
from .execution_utils import get_execution_outputs, wait_for_execution
from .helpers import get_inputs_from_schema
from .models import (
AgentDetails,
AgentDetailsResponse,
AgentOutputResponse,
ErrorResponse,
ExecutionOptions,
ExecutionOutputInfo,
ExecutionStartedResponse,
InputValidationErrorResponse,
SetupInfo,
@@ -66,6 +70,7 @@ class RunAgentInput(BaseModel):
schedule_name: str = ""
cron: str = ""
timezone: str = "UTC"
wait_for_result: int = Field(default=0, ge=0, le=300)
@field_validator(
"username_agent_slug",
@@ -147,6 +152,14 @@ class RunAgentTool(BaseTool):
"type": "string",
"description": "IANA timezone for schedule (default: UTC)",
},
"wait_for_result": {
"type": "integer",
"description": (
"Max seconds to wait for execution to complete (0-300). "
"If >0, blocks until the execution finishes or times out. "
"Returns execution outputs when complete."
),
},
},
"required": [],
}
@@ -341,6 +354,7 @@ class RunAgentTool(BaseTool):
graph=graph,
graph_credentials=graph_credentials,
inputs=params.inputs,
wait_for_result=params.wait_for_result,
)
except NotFoundError as e:
@@ -424,8 +438,9 @@ class RunAgentTool(BaseTool):
graph: GraphModel,
graph_credentials: dict[str, CredentialsMetaInput],
inputs: dict[str, Any],
wait_for_result: int = 0,
) -> ToolResponseBase:
"""Execute an agent immediately."""
"""Execute an agent immediately, optionally waiting for completion."""
session_id = session.session_id
# Check rate limits
@@ -462,6 +477,91 @@ class RunAgentTool(BaseTool):
)
library_agent_link = f"/library/agents/{library_agent.id}"
# If wait_for_result is requested, wait for execution to complete
if wait_for_result > 0:
logger.info(
f"Waiting up to {wait_for_result}s for execution {execution.id}"
)
completed = await wait_for_execution(
user_id=user_id,
graph_id=library_agent.graph_id,
execution_id=execution.id,
timeout_seconds=wait_for_result,
)
if completed and completed.status == ExecutionStatus.COMPLETED:
outputs = get_execution_outputs(completed)
return AgentOutputResponse(
message=(
f"Agent '{library_agent.name}' completed successfully. "
f"View at {library_agent_link}."
),
session_id=session_id,
agent_name=library_agent.name,
agent_id=library_agent.graph_id,
library_agent_id=library_agent.id,
library_agent_link=library_agent_link,
execution=ExecutionOutputInfo(
execution_id=execution.id,
status=completed.status.value,
started_at=completed.started_at,
ended_at=completed.ended_at,
outputs=outputs or {},
),
)
elif completed and completed.status == ExecutionStatus.FAILED:
error_detail = completed.stats.error if completed.stats else None
return ErrorResponse(
message=(
f"Agent '{library_agent.name}' execution failed. "
f"View details at {library_agent_link}."
),
session_id=session_id,
error=error_detail,
)
elif completed and completed.status == ExecutionStatus.TERMINATED:
error_detail = completed.stats.error if completed.stats else None
return ErrorResponse(
message=(
f"Agent '{library_agent.name}' execution was terminated. "
f"View details at {library_agent_link}."
),
session_id=session_id,
error=error_detail,
)
elif completed and completed.status == ExecutionStatus.REVIEW:
return ExecutionStartedResponse(
message=(
f"Agent '{library_agent.name}' is awaiting human review. "
f"Check at {library_agent_link}."
),
session_id=session_id,
execution_id=execution.id,
graph_id=library_agent.graph_id,
graph_name=library_agent.name,
library_agent_id=library_agent.id,
library_agent_link=library_agent_link,
status=ExecutionStatus.REVIEW.value,
)
else:
status = completed.status.value if completed else "unknown"
return ExecutionStartedResponse(
message=(
f"Agent '{library_agent.name}' is still {status} after "
f"{wait_for_result}s. Check results later at "
f"{library_agent_link}. "
f"Use view_agent_output with wait_if_running to check again."
),
session_id=session_id,
execution_id=execution.id,
graph_id=library_agent.graph_id,
graph_name=library_agent.name,
library_agent_id=library_agent.id,
library_agent_link=library_agent_link,
status=status,
)
return ExecutionStartedResponse(
message=(
f"Agent '{library_agent.name}' execution started successfully. "

View File

@@ -0,0 +1,339 @@
"""Tool for discovering and executing MCP (Model Context Protocol) server tools."""
import logging
from typing import Any
from urllib.parse import urlparse
from backend.blocks.mcp.block import MCPToolBlock
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.blocks.mcp.helpers import (
auto_lookup_mcp_credential,
normalize_mcp_url,
parse_mcp_content,
server_host,
)
from backend.copilot.model import ChatSession
from backend.copilot.tools.utils import build_missing_credentials_from_field_info
from backend.util.request import HTTPClientError, validate_url
from .base import BaseTool
from .models import (
ErrorResponse,
MCPToolInfo,
MCPToolOutputResponse,
MCPToolsDiscoveredResponse,
SetupInfo,
SetupRequirementsResponse,
ToolResponseBase,
UserReadiness,
)
logger = logging.getLogger(__name__)
# HTTP status codes that indicate authentication is required
_AUTH_STATUS_CODES = {401, 403}
class RunMCPToolTool(BaseTool):
"""
Tool for discovering and executing tools on any MCP server.
Stage 1 — discovery: call with just server_url to get available tools.
Stage 2 — execution: call with server_url + tool_name + tool_arguments.
If the server requires OAuth credentials that the user hasn't connected yet,
a SetupRequirementsResponse is returned so the frontend can render the
same OAuth login UI as the graph builder.
"""
@property
def name(self) -> str:
return "run_mcp_tool"
@property
def description(self) -> str:
return (
"Connect to an MCP (Model Context Protocol) server to discover and execute its tools. "
"Call with just `server_url` to see available tools. "
"Then call again with `server_url`, `tool_name`, and `tool_arguments` to execute. "
"If the server requires authentication, the user will be prompted to connect it. "
"Find MCP servers at https://registry.modelcontextprotocol.io/ — hundreds of integrations "
"including GitHub, Postgres, Slack, filesystem, and more."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"server_url": {
"type": "string",
"description": (
"URL of the MCP server (Streamable HTTP endpoint), "
"e.g. https://mcp.example.com/mcp"
),
},
"tool_name": {
"type": "string",
"description": (
"Name of the MCP tool to execute. "
"Omit on first call to discover available tools."
),
},
"tool_arguments": {
"type": "object",
"description": (
"Arguments to pass to the selected tool. "
"Must match the tool's input schema returned during discovery."
),
},
},
"required": ["server_url"],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
server_url: str = (kwargs.get("server_url") or "").strip()
tool_name: str = (kwargs.get("tool_name") or "").strip()
raw_tool_arguments = kwargs.get("tool_arguments")
tool_arguments: dict[str, Any] = (
raw_tool_arguments if isinstance(raw_tool_arguments, dict) else {}
)
session_id = session.session_id
if raw_tool_arguments is not None and not isinstance(raw_tool_arguments, dict):
return ErrorResponse(
message="tool_arguments must be a JSON object.",
session_id=session_id,
)
if not server_url:
return ErrorResponse(
message="Please provide a server_url for the MCP server.",
session_id=session_id,
)
_parsed = urlparse(server_url)
if _parsed.username or _parsed.password:
return ErrorResponse(
message=(
"Do not include credentials in server_url. "
"Use the MCP credential setup flow instead."
),
session_id=session_id,
)
if _parsed.query or _parsed.fragment:
return ErrorResponse(
message=(
"Do not include query parameters or fragments in server_url. "
"Use the MCP credential setup flow instead."
),
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="Authentication required.",
session_id=session_id,
)
# Validate URL to prevent SSRF — blocks loopback and private IP ranges
try:
await validate_url(server_url, trusted_origins=[])
except ValueError as e:
msg = str(e)
if "Unable to resolve" in msg or "No IP addresses" in msg:
user_msg = (
f"Hostname not found: {server_host(server_url)}. "
"Please check the URL — the domain may not exist."
)
else:
user_msg = f"Blocked server URL: {msg}"
return ErrorResponse(message=user_msg, session_id=session_id)
# Fast DB lookup — no network call.
# Normalize for matching because stored credentials use normalized URLs.
creds = await auto_lookup_mcp_credential(user_id, normalize_mcp_url(server_url))
auth_token = creds.access_token.get_secret_value() if creds else None
client = MCPClient(server_url, auth_token=auth_token)
try:
await client.initialize()
if not tool_name:
# Stage 1: Discover available tools
return await self._discover_tools(client, server_url, session_id)
else:
# Stage 2: Execute the selected tool
return await self._execute_tool(
client, server_url, tool_name, tool_arguments, session_id
)
except HTTPClientError as e:
if e.status_code in _AUTH_STATUS_CODES and not creds:
# Server requires auth and user has no stored credentials
return self._build_setup_requirements(server_url, session_id)
logger.warning("MCP HTTP error for %s: %s", server_host(server_url), e)
return ErrorResponse(
message=f"MCP server returned HTTP {e.status_code}: {e}",
session_id=session_id,
)
except MCPClientError as e:
logger.warning("MCP client error for %s: %s", server_host(server_url), e)
return ErrorResponse(
message=str(e),
session_id=session_id,
)
except Exception:
logger.error(
"Unexpected error calling MCP server %s",
server_host(server_url),
exc_info=True,
)
return ErrorResponse(
message="An unexpected error occurred connecting to the MCP server. Please try again.",
session_id=session_id,
)
async def _discover_tools(
self,
client: MCPClient,
server_url: str,
session_id: str,
) -> MCPToolsDiscoveredResponse:
"""List available tools from an already-initialized MCPClient.
Called when the agent invokes run_mcp_tool with only server_url (no
tool_name). Returns MCPToolsDiscoveredResponse so the agent can
inspect tool schemas and choose one to execute in a follow-up call.
"""
tools = await client.list_tools()
tool_infos = [
MCPToolInfo(
name=t.name,
description=t.description,
input_schema=t.input_schema,
)
for t in tools
]
host = server_host(server_url)
return MCPToolsDiscoveredResponse(
message=(
f"Discovered {len(tool_infos)} tool(s) on {host}. "
"Call run_mcp_tool again with tool_name and tool_arguments to execute one."
),
server_url=server_url,
tools=tool_infos,
session_id=session_id,
)
async def _execute_tool(
self,
client: MCPClient,
server_url: str,
tool_name: str,
tool_arguments: dict[str, Any],
session_id: str,
) -> MCPToolOutputResponse | ErrorResponse:
"""Execute a specific tool on an already-initialized MCPClient.
Parses the MCP content response into a plain Python value:
- text items: parsed as JSON when possible, kept as str otherwise
- image items: kept as {type, data, mimeType} dict for frontend rendering
- resource items: unwrapped to their resource payload dict
Single-item responses are unwrapped from the list; multiple items are
returned as a list; empty content returns None.
"""
result = await client.call_tool(tool_name, tool_arguments)
if result.is_error:
error_text = " ".join(
item.get("text", "")
for item in result.content
if item.get("type") == "text"
)
return ErrorResponse(
message=f"MCP tool '{tool_name}' returned an error: {error_text or 'Unknown error'}",
session_id=session_id,
)
result_value = parse_mcp_content(result.content)
return MCPToolOutputResponse(
message=f"MCP tool '{tool_name}' executed successfully.",
server_url=server_url,
tool_name=tool_name,
result=result_value,
success=True,
session_id=session_id,
)
def _build_setup_requirements(
self,
server_url: str,
session_id: str,
) -> SetupRequirementsResponse | ErrorResponse:
"""Build a SetupRequirementsResponse for a missing MCP server credential."""
mcp_block = MCPToolBlock()
credentials_fields_info = mcp_block.input_schema.get_credentials_fields_info()
# Apply the server_url discriminator value so the frontend's CredentialsGroupedView
# can match the credential to the correct OAuth provider/server.
for field_info in credentials_fields_info.values():
if field_info.discriminator == "server_url":
field_info.discriminator_values.add(server_url)
missing_creds_dict = build_missing_credentials_from_field_info(
credentials_fields_info, matched_keys=set()
)
if not missing_creds_dict:
logger.error(
"No credential requirements found for MCP server %s"
"MCPToolBlock may not have credentials configured",
server_host(server_url),
)
return ErrorResponse(
message=(
f"The MCP server at {server_host(server_url)} requires authentication, "
"but no credential configuration was found."
),
session_id=session_id,
)
missing_creds_list = list(missing_creds_dict.values())
host = server_host(server_url)
return SetupRequirementsResponse(
message=(
f"The MCP server at {host} requires authentication. "
"Please connect your credentials to continue."
),
session_id=session_id,
setup_info=SetupInfo(
agent_id=server_url,
agent_name=f"MCP: {host}",
user_readiness=UserReadiness(
has_all_credentials=False,
missing_credentials=missing_creds_dict,
ready_to_run=False,
),
requirements={
"credentials": missing_creds_list,
"inputs": [],
"execution_modes": ["immediate"],
},
),
graph_id=None,
graph_version=None,
)

View File

@@ -0,0 +1,759 @@
"""Unit tests for the run_mcp_tool copilot tool."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
from backend.blocks.mcp.helpers import server_host
from ._test_data import make_session
from .models import (
ErrorResponse,
MCPToolOutputResponse,
MCPToolsDiscoveredResponse,
SetupRequirementsResponse,
)
from .run_mcp_tool import RunMCPToolTool
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_USER_ID = "test-user-run-mcp-tool"
_SERVER_URL = "https://remote.mcpservers.org/fetch/mcp"
def _make_tool_list(*names: str):
"""Build a list of mock MCPClientTool objects."""
tools = []
for name in names:
t = MagicMock()
t.name = name
t.description = f"Description for {name}"
t.input_schema = {"type": "object", "properties": {}, "required": []}
tools.append(t)
return tools
def _make_call_result(content: list[dict], is_error: bool = False) -> MagicMock:
result = MagicMock()
result.is_error = is_error
result.content = content
return result
# ---------------------------------------------------------------------------
# server_host helper
# ---------------------------------------------------------------------------
def test_server_host_plain_url():
assert server_host("https://mcp.example.com/mcp") == "mcp.example.com"
def test_server_host_strips_credentials():
"""netloc would expose user:pass — hostname must not."""
assert server_host("https://user:secret@mcp.example.com/mcp") == "mcp.example.com"
def test_server_host_with_port():
"""Port should not appear in the returned hostname (hostname strips it)."""
assert server_host("https://mcp.example.com:8080/mcp") == "mcp.example.com"
def test_server_host_invalid_url():
"""Falls back to the raw string for un-parseable URLs."""
result = server_host("not-a-url")
assert result == "not-a-url"
# ---------------------------------------------------------------------------
# Input validation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_missing_server_url_returns_error():
tool = RunMCPToolTool()
session = make_session(_USER_ID)
response = await tool._execute(user_id=_USER_ID, session=session)
assert isinstance(response, ErrorResponse)
assert "server_url" in response.message.lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_missing_user_id_returns_error():
tool = RunMCPToolTool()
session = make_session(_USER_ID)
response = await tool._execute(
user_id=None, session=session, server_url=_SERVER_URL
)
assert isinstance(response, ErrorResponse)
assert "authentication" in response.message.lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_ssrf_blocked_url_returns_error():
"""Private/loopback URLs must be rejected before any network call."""
tool = RunMCPToolTool()
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
response = await tool._execute(
user_id=_USER_ID, session=session, server_url="http://localhost/mcp"
)
assert isinstance(response, ErrorResponse)
assert (
"blocked" in response.message.lower() or "invalid" in response.message.lower()
)
@pytest.mark.asyncio(loop_scope="session")
async def test_credential_bearing_url_returns_error():
"""URLs with embedded user:pass@ must be rejected before any network call."""
tool = RunMCPToolTool()
session = make_session(_USER_ID)
response = await tool._execute(
user_id=_USER_ID,
session=session,
server_url="https://user:secret@mcp.example.com/mcp",
)
assert isinstance(response, ErrorResponse)
assert (
"credential" in response.message.lower()
or "do not include" in response.message.lower()
)
@pytest.mark.asyncio(loop_scope="session")
async def test_non_dict_tool_arguments_returns_error():
"""tool_arguments must be a JSON object — strings/arrays are rejected early."""
tool = RunMCPToolTool()
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url",
new_callable=AsyncMock,
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
):
response = await tool._execute(
user_id=_USER_ID,
session=session,
server_url=_SERVER_URL,
tool_name="fetch",
tool_arguments=["this", "is", "a", "list"], # wrong type
)
assert isinstance(response, ErrorResponse)
assert "json object" in response.message.lower()
# ---------------------------------------------------------------------------
# Stage 1 — Discovery
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_returns_discovered_response():
"""Calling with only server_url triggers discovery and returns tool list."""
tool = RunMCPToolTool()
session = make_session(_USER_ID)
mock_tools = _make_tool_list("fetch", "search")
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
):
mock_client = AsyncMock()
mock_client.list_tools = AsyncMock(return_value=mock_tools)
with patch(
"backend.copilot.tools.run_mcp_tool.MCPClient",
return_value=mock_client,
):
response = await tool._execute(
user_id=_USER_ID,
session=session,
server_url=_SERVER_URL,
)
assert isinstance(response, MCPToolsDiscoveredResponse)
assert len(response.tools) == 2
assert response.tools[0].name == "fetch"
assert response.tools[1].name == "search"
assert response.server_url == _SERVER_URL
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_with_credentials():
"""Stored credentials are passed as Bearer token to MCPClient."""
tool = RunMCPToolTool()
session = make_session(_USER_ID)
mock_creds = MagicMock()
mock_creds.access_token = SecretStr("test-token-abc")
mock_tools = _make_tool_list("push_notification")
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=mock_creds,
):
mock_client = AsyncMock()
mock_client.list_tools = AsyncMock(return_value=mock_tools)
with patch(
"backend.copilot.tools.run_mcp_tool.MCPClient",
) as MockMCPClient:
MockMCPClient.return_value = mock_client
response = await tool._execute(
user_id=_USER_ID,
session=session,
server_url=_SERVER_URL,
)
# Verify MCPClient was created with the resolved auth token
MockMCPClient.assert_called_once_with(
_SERVER_URL, auth_token="test-token-abc"
)
assert isinstance(response, MCPToolsDiscoveredResponse)
assert len(response.tools) == 1
# ---------------------------------------------------------------------------
# Stage 2 — Execution
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_execute_tool_returns_output_response():
"""Calling with tool_name executes the tool and returns MCPToolOutputResponse."""
tool = RunMCPToolTool()
session = make_session(_USER_ID)
text_result = "# Example Domain\nThis domain is for examples."
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
):
mock_result = _make_call_result([{"type": "text", "text": text_result}])
mock_client = AsyncMock()
mock_client.call_tool = AsyncMock(return_value=mock_result)
with patch(
"backend.copilot.tools.run_mcp_tool.MCPClient",
return_value=mock_client,
):
response = await tool._execute(
user_id=_USER_ID,
session=session,
server_url=_SERVER_URL,
tool_name="fetch",
tool_arguments={"url": "https://example.com"},
)
assert isinstance(response, MCPToolOutputResponse)
assert response.tool_name == "fetch"
assert response.server_url == _SERVER_URL
assert response.success is True
assert text_result in response.result
@pytest.mark.asyncio(loop_scope="session")
async def test_execute_tool_parses_json_result():
"""JSON text content items are parsed into Python objects."""
tool = RunMCPToolTool()
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
):
mock_result = _make_call_result(
[{"type": "text", "text": '{"status": "ok", "count": 42}'}]
)
mock_client = AsyncMock()
mock_client.call_tool = AsyncMock(return_value=mock_result)
with patch(
"backend.copilot.tools.run_mcp_tool.MCPClient",
return_value=mock_client,
):
response = await tool._execute(
user_id=_USER_ID,
session=session,
server_url=_SERVER_URL,
tool_name="status",
tool_arguments={},
)
assert isinstance(response, MCPToolOutputResponse)
assert response.result == {"status": "ok", "count": 42}
@pytest.mark.asyncio(loop_scope="session")
async def test_execute_tool_image_content():
"""Image content items are returned as {type, data, mimeType} dicts."""
tool = RunMCPToolTool()
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
):
mock_result = _make_call_result(
[{"type": "image", "data": "abc123==", "mimeType": "image/png"}]
)
mock_client = AsyncMock()
mock_client.call_tool = AsyncMock(return_value=mock_result)
with patch(
"backend.copilot.tools.run_mcp_tool.MCPClient",
return_value=mock_client,
):
response = await tool._execute(
user_id=_USER_ID,
session=session,
server_url=_SERVER_URL,
tool_name="screenshot",
tool_arguments={},
)
assert isinstance(response, MCPToolOutputResponse)
assert response.result == {
"type": "image",
"data": "abc123==",
"mimeType": "image/png",
}
@pytest.mark.asyncio(loop_scope="session")
async def test_execute_tool_resource_content():
"""Resource content items are unwrapped to their resource payload."""
tool = RunMCPToolTool()
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
):
mock_result = _make_call_result(
[
{
"type": "resource",
"resource": {"uri": "file:///tmp/out.txt", "text": "hello"},
}
]
)
mock_client = AsyncMock()
mock_client.call_tool = AsyncMock(return_value=mock_result)
with patch(
"backend.copilot.tools.run_mcp_tool.MCPClient",
return_value=mock_client,
):
response = await tool._execute(
user_id=_USER_ID,
session=session,
server_url=_SERVER_URL,
tool_name="read_file",
tool_arguments={},
)
assert isinstance(response, MCPToolOutputResponse)
assert response.result == {"uri": "file:///tmp/out.txt", "text": "hello"}
@pytest.mark.asyncio(loop_scope="session")
async def test_execute_tool_multi_item_content():
"""Multiple content items are returned as a list."""
tool = RunMCPToolTool()
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
):
mock_result = _make_call_result(
[
{"type": "text", "text": "part one"},
{"type": "text", "text": "part two"},
]
)
mock_client = AsyncMock()
mock_client.call_tool = AsyncMock(return_value=mock_result)
with patch(
"backend.copilot.tools.run_mcp_tool.MCPClient",
return_value=mock_client,
):
response = await tool._execute(
user_id=_USER_ID,
session=session,
server_url=_SERVER_URL,
tool_name="multi",
tool_arguments={},
)
assert isinstance(response, MCPToolOutputResponse)
assert response.result == ["part one", "part two"]
@pytest.mark.asyncio(loop_scope="session")
async def test_execute_tool_empty_content_returns_none():
"""Empty content list results in result=None."""
tool = RunMCPToolTool()
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
):
mock_result = _make_call_result([])
mock_client = AsyncMock()
mock_client.call_tool = AsyncMock(return_value=mock_result)
with patch(
"backend.copilot.tools.run_mcp_tool.MCPClient",
return_value=mock_client,
):
response = await tool._execute(
user_id=_USER_ID,
session=session,
server_url=_SERVER_URL,
tool_name="ping",
tool_arguments={},
)
assert isinstance(response, MCPToolOutputResponse)
assert response.result is None
@pytest.mark.asyncio(loop_scope="session")
async def test_execute_tool_returns_error_on_tool_failure():
"""When the MCP tool returns is_error=True, an ErrorResponse is returned."""
tool = RunMCPToolTool()
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
):
mock_result = _make_call_result(
[{"type": "text", "text": "Tool not found"}], is_error=True
)
mock_client = AsyncMock()
mock_client.call_tool = AsyncMock(return_value=mock_result)
with patch(
"backend.copilot.tools.run_mcp_tool.MCPClient",
return_value=mock_client,
):
response = await tool._execute(
user_id=_USER_ID,
session=session,
server_url=_SERVER_URL,
tool_name="nonexistent",
tool_arguments={},
)
assert isinstance(response, ErrorResponse)
assert "nonexistent" in response.message
# ---------------------------------------------------------------------------
# Auth / credential flow
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_auth_required_without_creds_returns_setup_requirements():
"""HTTP 401 from MCP with no stored creds → SetupRequirementsResponse."""
from backend.util.request import HTTPClientError
tool = RunMCPToolTool()
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None, # No stored credentials
):
mock_client = AsyncMock()
mock_client.initialize = AsyncMock(
side_effect=HTTPClientError("Unauthorized", status_code=401)
)
with patch(
"backend.copilot.tools.run_mcp_tool.MCPClient",
return_value=mock_client,
):
with patch.object(
RunMCPToolTool,
"_build_setup_requirements",
return_value=MagicMock(spec=SetupRequirementsResponse),
) as mock_build:
response = await tool._execute(
user_id=_USER_ID,
session=session,
server_url=_SERVER_URL,
)
mock_build.assert_called_once()
# Should have returned what _build_setup_requirements returned
assert response is mock_build.return_value
@pytest.mark.asyncio(loop_scope="session")
async def test_auth_error_with_existing_creds_returns_error():
"""HTTP 403 when creds ARE present → generic ErrorResponse (not setup card)."""
from backend.util.request import HTTPClientError
tool = RunMCPToolTool()
session = make_session(_USER_ID)
mock_creds = MagicMock()
mock_creds.access_token = SecretStr("stale-token")
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=mock_creds,
):
mock_client = AsyncMock()
mock_client.initialize = AsyncMock(
side_effect=HTTPClientError("Forbidden", status_code=403)
)
with patch(
"backend.copilot.tools.run_mcp_tool.MCPClient",
return_value=mock_client,
):
response = await tool._execute(
user_id=_USER_ID,
session=session,
server_url=_SERVER_URL,
)
assert isinstance(response, ErrorResponse)
assert "403" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_mcp_client_error_returns_error_response():
"""MCPClientError (protocol-level) maps to a clean ErrorResponse."""
from backend.blocks.mcp.client import MCPClientError
tool = RunMCPToolTool()
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
):
mock_client = AsyncMock()
mock_client.initialize = AsyncMock(
side_effect=MCPClientError("JSON-RPC protocol error")
)
with patch(
"backend.copilot.tools.run_mcp_tool.MCPClient",
return_value=mock_client,
):
response = await tool._execute(
user_id=_USER_ID,
session=session,
server_url=_SERVER_URL,
)
assert isinstance(response, ErrorResponse)
assert "JSON-RPC" in response.message or "protocol" in response.message.lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_unexpected_exception_returns_generic_error():
"""Unhandled exceptions inside the MCP call don't leak traceback text to the user."""
tool = RunMCPToolTool()
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
):
mock_client = AsyncMock()
# An unexpected error inside initialize (inside the try block)
mock_client.initialize = AsyncMock(
side_effect=ValueError("Unexpected internal error")
)
with patch(
"backend.copilot.tools.run_mcp_tool.MCPClient",
return_value=mock_client,
):
response = await tool._execute(
user_id=_USER_ID,
session=session,
server_url=_SERVER_URL,
)
assert isinstance(response, ErrorResponse)
# Must not leak the raw exception message
assert "Unexpected internal error" not in response.message
assert (
"unexpected" in response.message.lower() or "error" in response.message.lower()
)
# ---------------------------------------------------------------------------
# Tool metadata
# ---------------------------------------------------------------------------
def test_tool_name():
assert RunMCPToolTool().name == "run_mcp_tool"
def test_tool_requires_auth():
assert RunMCPToolTool().requires_auth is True
def test_tool_parameters_schema():
params = RunMCPToolTool().parameters
assert params["type"] == "object"
assert "server_url" in params["properties"]
assert "tool_name" in params["properties"]
assert "tool_arguments" in params["properties"]
assert params["required"] == ["server_url"]
# ---------------------------------------------------------------------------
# Query/fragment rejection
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_query_in_url_returns_error():
"""server_url with query parameters must be rejected."""
tool = RunMCPToolTool()
session = make_session(_USER_ID)
response = await tool._execute(
user_id=_USER_ID,
session=session,
server_url="https://mcp.example.com/mcp?key=val",
)
assert isinstance(response, ErrorResponse)
assert "query" in response.message.lower() or "fragment" in response.message.lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_fragment_in_url_returns_error():
"""server_url with a fragment must be rejected."""
tool = RunMCPToolTool()
session = make_session(_USER_ID)
response = await tool._execute(
user_id=_USER_ID,
session=session,
server_url="https://mcp.example.com/mcp#section",
)
assert isinstance(response, ErrorResponse)
assert "query" in response.message.lower() or "fragment" in response.message.lower()
# ---------------------------------------------------------------------------
# Credential lookup normalization
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_credential_lookup_normalizes_trailing_slash():
"""Credential lookup must normalize the URL (strip trailing slash)."""
tool = RunMCPToolTool()
session = make_session(_USER_ID)
url_with_slash = "https://mcp.example.com/mcp/"
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
) as mock_lookup:
mock_client = AsyncMock()
mock_client.list_tools = AsyncMock(return_value=[])
with patch(
"backend.copilot.tools.run_mcp_tool.MCPClient",
return_value=mock_client,
):
await tool._execute(
user_id=_USER_ID,
session=session,
server_url=url_with_slash,
)
# Credential lookup should use the normalized URL (no trailing slash)
mock_lookup.assert_called_once_with(_USER_ID, "https://mcp.example.com/mcp")
# ---------------------------------------------------------------------------
# _build_setup_requirements
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_build_setup_requirements_returns_setup_response():
"""_build_setup_requirements should return a SetupRequirementsResponse."""
tool = RunMCPToolTool()
result = tool._build_setup_requirements(
server_url=_SERVER_URL,
session_id="test-session",
)
assert isinstance(result, SetupRequirementsResponse)
assert result.setup_info.agent_id == _SERVER_URL
assert "authentication" in result.message.lower()

View File

@@ -30,6 +30,10 @@ _TEXT_CONTENT_TYPES = {
"application/xhtml+xml",
"application/rss+xml",
"application/atom+xml",
# RFC 7807 — JSON problem details; used by many REST APIs for error responses
"application/problem+json",
"application/problem+xml",
"application/ld+json",
}

View File

@@ -8,6 +8,7 @@ from typing import Any, Optional
from pydantic import BaseModel
from backend.copilot.model import ChatSession
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
from backend.copilot.tools.sandbox import make_session_path
from backend.data.db_accessors import workspace_db
from backend.util.settings import Config
@@ -20,7 +21,7 @@ from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
def _resolve_write_content(
async def _resolve_write_content(
content_text: str | None,
content_b64: str | None,
source_path: str | None,
@@ -30,6 +31,9 @@ def _resolve_write_content(
Returns the raw bytes on success, or an ``ErrorResponse`` on validation
failure (wrong number of sources, invalid path, file not found, etc.).
When an E2B sandbox is active, ``source_path`` reads from the sandbox
filesystem instead of the local ephemeral directory.
"""
# Normalise empty strings to None so counting and dispatch stay in sync.
if content_text is not None and content_text == "":
@@ -54,24 +58,7 @@ def _resolve_write_content(
)
if source_path is not None:
validated = _validate_ephemeral_path(
source_path, param_name="source_path", session_id=session_id
)
if isinstance(validated, ErrorResponse):
return validated
try:
with open(validated, "rb") as f:
return f.read()
except FileNotFoundError:
return ErrorResponse(
message=f"Source file not found: {source_path}",
session_id=session_id,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to read source file: {e}",
session_id=session_id,
)
return await _read_source_path(source_path, session_id)
if content_b64 is not None:
try:
@@ -91,6 +78,106 @@ def _resolve_write_content(
return content_text.encode("utf-8")
def _resolve_sandbox_path(
path: str, session_id: str | None, param_name: str
) -> str | ErrorResponse:
"""Normalize *path* to an absolute sandbox path under :data:`E2B_WORKDIR`.
Delegates to :func:`~backend.copilot.sdk.e2b_file_tools._resolve_remote`
and wraps any ``ValueError`` into an :class:`ErrorResponse`.
"""
from backend.copilot.sdk.e2b_file_tools import _resolve_remote
try:
return _resolve_remote(path)
except ValueError:
return ErrorResponse(
message=f"{param_name} must be within {E2B_WORKDIR}",
session_id=session_id,
)
async def _read_source_path(source_path: str, session_id: str) -> bytes | ErrorResponse:
"""Read *source_path* from E2B sandbox or local ephemeral directory."""
from backend.copilot.sdk.tool_adapter import get_current_sandbox
sandbox = get_current_sandbox()
if sandbox is not None:
remote = _resolve_sandbox_path(source_path, session_id, "source_path")
if isinstance(remote, ErrorResponse):
return remote
try:
data = await sandbox.files.read(remote, format="bytes")
return bytes(data)
except Exception as exc:
return ErrorResponse(
message=f"Source file not found on sandbox: {source_path} ({exc})",
session_id=session_id,
)
# Local fallback: validate path stays within ephemeral directory.
validated = _validate_ephemeral_path(
source_path, param_name="source_path", session_id=session_id
)
if isinstance(validated, ErrorResponse):
return validated
try:
with open(validated, "rb") as f:
return f.read()
except FileNotFoundError:
return ErrorResponse(
message=f"Source file not found: {source_path}",
session_id=session_id,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to read source file: {e}",
session_id=session_id,
)
async def _save_to_path(
path: str, content: bytes, session_id: str
) -> str | ErrorResponse:
"""Write *content* to *path* on E2B sandbox or local ephemeral directory.
Returns the resolved path on success, or an ``ErrorResponse`` on failure.
"""
from backend.copilot.sdk.tool_adapter import get_current_sandbox
sandbox = get_current_sandbox()
if sandbox is not None:
remote = _resolve_sandbox_path(path, session_id, "save_to_path")
if isinstance(remote, ErrorResponse):
return remote
try:
await sandbox.files.write(remote, content)
except Exception as exc:
return ErrorResponse(
message=f"Failed to write to sandbox: {path} ({exc})",
session_id=session_id,
)
return remote
validated = _validate_ephemeral_path(
path, param_name="save_to_path", session_id=session_id
)
if isinstance(validated, ErrorResponse):
return validated
try:
dir_path = os.path.dirname(validated)
if dir_path:
os.makedirs(dir_path, exist_ok=True)
with open(validated, "wb") as f:
f.write(content)
except Exception as exc:
return ErrorResponse(
message=f"Failed to write to local path: {path} ({exc})",
session_id=session_id,
)
return validated
def _validate_ephemeral_path(
path: str, *, param_name: str, session_id: str
) -> ErrorResponse | str:
@@ -131,7 +218,7 @@ def _is_text_mime(mime_type: str) -> bool:
return any(mime_type.startswith(t) for t in _TEXT_MIME_PREFIXES)
async def _get_manager(user_id: str, session_id: str) -> WorkspaceManager:
async def get_manager(user_id: str, session_id: str) -> WorkspaceManager:
"""Create a session-scoped WorkspaceManager."""
workspace = await workspace_db().get_or_create_workspace(user_id)
return WorkspaceManager(user_id, workspace.id, session_id)
@@ -214,7 +301,11 @@ class WorkspaceWriteResponse(ToolResponseBase):
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
# workspace:// URL the agent can embed directly in chat to give the user a link.
# Format: workspace://<file_id>#<mime_type> (frontend resolves to download URL)
download_url: str
source: str | None = None # "content", "base64", or "copied from <path>"
content_preview: str | None = None # First 200 chars for text files
@@ -295,7 +386,7 @@ class ListWorkspaceFilesTool(BaseTool):
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
try:
manager = await _get_manager(user_id, session_id)
manager = await get_manager(user_id, session_id)
files = await manager.list_files(
path=path_prefix, limit=limit, include_all_sessions=include_all_sessions
)
@@ -341,7 +432,7 @@ class ListWorkspaceFilesTool(BaseTool):
class ReadWorkspaceFileTool(BaseTool):
"""Tool for reading file content from workspace."""
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB for text/image files
PREVIEW_SIZE = 500
@property
@@ -357,8 +448,10 @@ class ReadWorkspaceFileTool(BaseTool):
"Specify either file_id or path to identify the file. "
"For small text files, returns content directly. "
"For large or binary files, returns metadata and a download URL. "
"Optionally use 'save_to_path' to copy the file to the ephemeral "
"working directory for processing with bash_exec or SDK tools. "
"Use 'save_to_path' to copy the file to the working directory "
"(sandbox or ephemeral) for processing with bash_exec or file tools. "
"Use 'offset' and 'length' for paginated reads of large files "
"(e.g., persisted tool outputs). "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@@ -382,9 +475,10 @@ class ReadWorkspaceFileTool(BaseTool):
"save_to_path": {
"type": "string",
"description": (
"If provided, save the file to this path in the ephemeral "
"working directory (e.g., '/tmp/copilot-.../data.csv') "
"so it can be processed with bash_exec or SDK tools. "
"If provided, save the file to this path in the working "
"directory (cloud sandbox when E2B is active, or "
"ephemeral dir otherwise) so it can be processed with "
"bash_exec or file tools. "
"The file content is still returned in the response."
),
},
@@ -395,6 +489,20 @@ class ReadWorkspaceFileTool(BaseTool):
"Default is false (auto-selects based on file size/type)."
),
},
"offset": {
"type": "integer",
"description": (
"Character offset to start reading from (0-based). "
"Use with 'length' for paginated reads of large files."
),
},
"length": {
"type": "integer",
"description": (
"Maximum number of characters to return. "
"Defaults to full file. Use with 'offset' for paginated reads."
),
},
},
"required": [], # At least one must be provided
}
@@ -419,23 +527,16 @@ class ReadWorkspaceFileTool(BaseTool):
path: Optional[str] = kwargs.get("path")
save_to_path: Optional[str] = kwargs.get("save_to_path")
force_download_url: bool = kwargs.get("force_download_url", False)
char_offset: int = max(0, kwargs.get("offset", 0))
char_length: Optional[int] = kwargs.get("length")
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path", session_id=session_id
)
# Validate and resolve save_to_path (use sanitized real path).
if save_to_path:
validated_save = _validate_ephemeral_path(
save_to_path, param_name="save_to_path", session_id=session_id
)
if isinstance(validated_save, ErrorResponse):
return validated_save
save_to_path = validated_save
try:
manager = await _get_manager(user_id, session_id)
manager = await get_manager(user_id, session_id)
resolved = await _resolve_file(manager, file_id, path, session_id)
if isinstance(resolved, ErrorResponse):
return resolved
@@ -445,11 +546,38 @@ class ReadWorkspaceFileTool(BaseTool):
cached_content: bytes | None = None
if save_to_path:
cached_content = await manager.read_file_by_id(target_file_id)
dir_path = os.path.dirname(save_to_path)
if dir_path:
os.makedirs(dir_path, exist_ok=True)
with open(save_to_path, "wb") as f:
f.write(cached_content)
result = await _save_to_path(save_to_path, cached_content, session_id)
if isinstance(result, ErrorResponse):
return result
save_to_path = result
# Ranged read: return a character slice directly.
if char_offset > 0 or char_length is not None:
raw = cached_content or await manager.read_file_by_id(target_file_id)
text = raw.decode("utf-8", errors="replace")
total_chars = len(text)
end = (
char_offset + char_length
if char_length is not None
else total_chars
)
slice_text = text[char_offset:end]
return WorkspaceFileContentResponse(
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type="text/plain",
content_base64=base64.b64encode(slice_text.encode("utf-8")).decode(
"utf-8"
),
message=(
f"Read chars {char_offset}"
f"{char_offset + len(slice_text)} "
f"of {total_chars:,} total "
f"from {file_info.name}"
),
session_id=session_id,
)
is_small = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES
is_text = _is_text_mime(file_info.mime_type)
@@ -625,7 +753,7 @@ class WriteWorkspaceFileTool(BaseTool):
content_text: str | None = kwargs.get("content")
content_b64: str | None = kwargs.get("content_base64")
resolved = _resolve_write_content(
resolved = await _resolve_write_content(
content_text,
content_b64,
source_path_arg,
@@ -644,7 +772,7 @@ class WriteWorkspaceFileTool(BaseTool):
try:
await scan_content_safe(content, filename=filename)
manager = await _get_manager(user_id, session_id)
manager = await get_manager(user_id, session_id)
rec = await manager.write_file(
content=content,
filename=filename,
@@ -680,11 +808,21 @@ class WriteWorkspaceFileTool(BaseTool):
except Exception:
pass
# Strip MIME parameters (e.g. "text/html; charset=utf-8" → "text/html")
# and normalise to lowercase so the fragment is URL-safe.
normalized_mime = (rec.mime_type or "").split(";", 1)[0].strip().lower()
download_url = (
f"workspace://{rec.id}#{normalized_mime}"
if normalized_mime
else f"workspace://{rec.id}"
)
return WorkspaceWriteResponse(
file_id=rec.id,
name=rec.name,
path=rec.path,
mime_type=normalized_mime,
size_bytes=rec.size_bytes,
download_url=download_url,
source=source,
content_preview=preview,
message=msg,
@@ -761,7 +899,7 @@ class DeleteWorkspaceFileTool(BaseTool):
)
try:
manager = await _get_manager(user_id, session_id)
manager = await get_manager(user_id, session_id)
resolved = await _resolve_file(manager, file_id, path, session_id)
if isinstance(resolved, ErrorResponse):
return resolved

View File

@@ -102,67 +102,68 @@ class TestValidateEphemeralPath:
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
class TestResolveWriteContent:
def test_no_sources_returns_error(self):
async def test_no_sources_returns_error(self):
from backend.copilot.tools.models import ErrorResponse
result = _resolve_write_content(None, None, None, "s1")
result = await _resolve_write_content(None, None, None, "s1")
assert isinstance(result, ErrorResponse)
def test_multiple_sources_returns_error(self):
async def test_multiple_sources_returns_error(self):
from backend.copilot.tools.models import ErrorResponse
result = _resolve_write_content("text", "b64data", None, "s1")
result = await _resolve_write_content("text", "b64data", None, "s1")
assert isinstance(result, ErrorResponse)
def test_plain_text_content(self):
result = _resolve_write_content("hello world", None, None, "s1")
async def test_plain_text_content(self):
result = await _resolve_write_content("hello world", None, None, "s1")
assert result == b"hello world"
def test_base64_content(self):
async def test_base64_content(self):
raw = b"binary data"
b64 = base64.b64encode(raw).decode()
result = _resolve_write_content(None, b64, None, "s1")
result = await _resolve_write_content(None, b64, None, "s1")
assert result == raw
def test_invalid_base64_returns_error(self):
async def test_invalid_base64_returns_error(self):
from backend.copilot.tools.models import ErrorResponse
result = _resolve_write_content(None, "not-valid-b64!!!", None, "s1")
result = await _resolve_write_content(None, "not-valid-b64!!!", None, "s1")
assert isinstance(result, ErrorResponse)
assert "base64" in result.message.lower()
def test_source_path(self, ephemeral_dir):
async def test_source_path(self, ephemeral_dir):
target = ephemeral_dir / "input.txt"
target.write_bytes(b"file content")
result = _resolve_write_content(None, None, str(target), "s1")
result = await _resolve_write_content(None, None, str(target), "s1")
assert result == b"file content"
def test_source_path_not_found(self, ephemeral_dir):
async def test_source_path_not_found(self, ephemeral_dir):
from backend.copilot.tools.models import ErrorResponse
missing = str(ephemeral_dir / "nope.txt")
result = _resolve_write_content(None, None, missing, "s1")
result = await _resolve_write_content(None, None, missing, "s1")
assert isinstance(result, ErrorResponse)
def test_source_path_outside_ephemeral(self, ephemeral_dir, tmp_path):
async def test_source_path_outside_ephemeral(self, ephemeral_dir, tmp_path):
from backend.copilot.tools.models import ErrorResponse
outside = tmp_path / "outside.txt"
outside.write_text("nope")
result = _resolve_write_content(None, None, str(outside), "s1")
result = await _resolve_write_content(None, None, str(outside), "s1")
assert isinstance(result, ErrorResponse)
def test_empty_string_sources_treated_as_none(self):
async def test_empty_string_sources_treated_as_none(self):
from backend.copilot.tools.models import ErrorResponse
# All empty strings → same as no sources
result = _resolve_write_content("", "", "", "s1")
result = await _resolve_write_content("", "", "", "s1")
assert isinstance(result, ErrorResponse)
def test_empty_string_source_path_with_text(self):
async def test_empty_string_source_path_with_text(self):
# source_path="" should be normalised to None, so only content counts
result = _resolve_write_content("hello", "", "", "s1")
result = await _resolve_write_content("hello", "", "", "s1")
assert result == b"hello"
@@ -235,6 +236,65 @@ async def test_workspace_file_round_trip(setup_test_data):
assert not any(f.file_id == file_id for f in list_resp2.files)
# ---------------------------------------------------------------------------
# Ranged reads (offset / length)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_read_workspace_file_with_offset_and_length(setup_test_data):
"""Read a slice of a text file using offset and length."""
user = setup_test_data["user"]
session = make_session(user.id)
# Write a known-content file
content = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" * 100 # 2600 chars
write_tool = WriteWorkspaceFileTool()
write_resp = await write_tool._execute(
user_id=user.id,
session=session,
filename="ranged_test.txt",
content=content,
)
assert isinstance(write_resp, WorkspaceWriteResponse), write_resp.message
file_id = write_resp.file_id
from backend.copilot.tools.workspace_files import WorkspaceFileContentResponse
read_tool = ReadWorkspaceFileTool()
# Read with offset=100, length=50
resp = await read_tool._execute(
user_id=user.id, session=session, file_id=file_id, offset=100, length=50
)
assert isinstance(resp, WorkspaceFileContentResponse), resp.message
decoded = base64.b64decode(resp.content_base64).decode()
assert decoded == content[100:150]
assert "100" in resp.message
assert "2,600" in resp.message # total chars (comma-formatted)
# Read with offset only (no length) — returns from offset to end
resp2 = await read_tool._execute(
user_id=user.id, session=session, file_id=file_id, offset=2500
)
assert isinstance(resp2, WorkspaceFileContentResponse)
decoded2 = base64.b64decode(resp2.content_base64).decode()
assert decoded2 == content[2500:]
assert len(decoded2) == 100
# Read with offset beyond file length — returns empty string
resp3 = await read_tool._execute(
user_id=user.id, session=session, file_id=file_id, offset=9999, length=10
)
assert isinstance(resp3, WorkspaceFileContentResponse)
decoded3 = base64.b64decode(resp3.content_base64).decode()
assert decoded3 == ""
# Cleanup
delete_tool = DeleteWorkspaceFileTool()
await delete_tool._execute(user_id=user.id, session=session, file_id=file_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_write_workspace_file_source_path(setup_test_data):
"""E2E: write a file from ephemeral source_path to workspace."""

View File

@@ -81,6 +81,7 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.CLAUDE_4_OPUS: 21,
LlmModel.CLAUDE_4_SONNET: 5,
LlmModel.CLAUDE_4_6_OPUS: 14,
LlmModel.CLAUDE_4_6_SONNET: 9,
LlmModel.CLAUDE_4_5_HAIKU: 4,
LlmModel.CLAUDE_4_5_OPUS: 14,
LlmModel.CLAUDE_4_5_SONNET: 9,

View File

@@ -178,9 +178,13 @@ async def test_block_credit_reset(server: SpinTestServer):
assert month2_balance == 1100 # Balance persists, no reset
# Now test the refill behavior when balance is low
# Set balance below refill threshold
# Set balance below refill threshold and backdate updatedAt to month2 so
# the month3 refill check sees a different (month2 → month3) transition.
# Without the explicit updatedAt, Prisma sets it to real-world NOW which
# may share the same calendar month as the mocked month3, suppressing refill.
await UserBalance.prisma().update(
where={"userId": DEFAULT_USER_ID}, data={"balance": 400}
where={"userId": DEFAULT_USER_ID},
data={"balance": 400, "updatedAt": month2},
)
# Create a month 2 transaction to update the last transaction time

View File

@@ -79,6 +79,12 @@ INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
}
LIBRARY_FOLDER_INCLUDE: prisma.types.LibraryFolderInclude = {
"LibraryAgents": {"where": {"isDeleted": False}},
"Children": {"where": {"isDeleted": False}},
}
def library_agent_include(
user_id: str,
include_nodes: bool = True,
@@ -105,6 +111,7 @@ def library_agent_include(
"""
result: prisma.types.LibraryAgentInclude = {
"Creator": True, # Always needed for creator info
"Folder": True, # Always needed for folder info
}
# Build AgentGraph include based on requested options

View File

@@ -184,7 +184,7 @@ async def find_webhook_by_credentials_and_props(
credentials_id: str,
webhook_type: str,
resource: str,
events: list[str],
events: Optional[list[str]],
) -> Webhook | None:
webhook = await IntegrationWebhook.prisma().find_first(
where={
@@ -192,7 +192,7 @@ async def find_webhook_by_credentials_and_props(
"credentialsId": credentials_id,
"webhookType": webhook_type,
"resource": resource,
"events": {"has_every": events},
**({"events": {"has_every": events}} if events else {}),
},
)
return Webhook.from_db(webhook) if webhook else None

View File

@@ -211,6 +211,22 @@ class AgentRejectionData(BaseNotificationData):
return value
class WaitlistLaunchData(BaseNotificationData):
"""Notification data for when an agent from a waitlist is launched."""
agent_name: str
waitlist_name: str
store_url: str
launched_at: datetime
@field_validator("launched_at")
@classmethod
def validate_timezone(cls, value: datetime):
if value.tzinfo is None:
raise ValueError("datetime must have timezone information")
return value
NotificationData = Annotated[
Union[
AgentRunData,
@@ -223,6 +239,7 @@ NotificationData = Annotated[
DailySummaryData,
RefundRequestData,
BaseSummaryData,
WaitlistLaunchData,
],
Field(discriminator="type"),
]
@@ -273,6 +290,7 @@ def get_notif_data_type(
NotificationType.REFUND_PROCESSED: RefundRequestData,
NotificationType.AGENT_APPROVED: AgentApprovalData,
NotificationType.AGENT_REJECTED: AgentRejectionData,
NotificationType.WAITLIST_LAUNCH: WaitlistLaunchData,
}[notification_type]
@@ -318,6 +336,7 @@ class NotificationTypeOverride:
NotificationType.REFUND_PROCESSED: QueueType.ADMIN,
NotificationType.AGENT_APPROVED: QueueType.IMMEDIATE,
NotificationType.AGENT_REJECTED: QueueType.IMMEDIATE,
NotificationType.WAITLIST_LAUNCH: QueueType.IMMEDIATE,
}
return BATCHING_RULES.get(self.notification_type, QueueType.IMMEDIATE)
@@ -337,6 +356,7 @@ class NotificationTypeOverride:
NotificationType.REFUND_PROCESSED: "refund_processed.html",
NotificationType.AGENT_APPROVED: "agent_approved.html",
NotificationType.AGENT_REJECTED: "agent_rejected.html",
NotificationType.WAITLIST_LAUNCH: "waitlist_launch.html",
}[self.notification_type]
@property
@@ -354,6 +374,7 @@ class NotificationTypeOverride:
NotificationType.REFUND_PROCESSED: "Refund for ${{data.amount / 100}} to {{data.user_name}} has been processed",
NotificationType.AGENT_APPROVED: "🎉 Your agent '{{data.agent_name}}' has been approved!",
NotificationType.AGENT_REJECTED: "Your agent '{{data.agent_name}}' needs some updates",
NotificationType.WAITLIST_LAUNCH: "🚀 {{data.agent_name}} is now available!",
}[self.notification_type]

View File

@@ -327,11 +327,16 @@ async def get_workspace_total_size(workspace_id: str) -> int:
"""
Get the total size of all files in a workspace.
Queries Prisma directly (skipping Pydantic model conversion) and only
fetches the ``sizeBytes`` column to minimise data transfer.
Args:
workspace_id: The workspace ID
Returns:
Total size in bytes
"""
files = await list_workspace_files(workspace_id)
return sum(file.size_bytes for file in files)
files = await UserWorkspaceFile.prisma().find_many(
where={"workspaceId": workspace_id, "isDeleted": False},
)
return sum(f.sizeBytes for f in files)

Some files were not shown because too many files have changed in this diff Show More