Compare commits

...

37 Commits

Author SHA1 Message Date
Lluis Agusti
948119d071 feat(frontend): add hit-area utility for improved touch targets
Install the hit-area Tailwind CSS utility (via shadcn registry) and apply
it to small interactive elements in the Copilot and Library Agent pages.
The utility expands clickable zones using CSS pseudo-elements without
affecting visual layout.

Targets improved:
- File chip remove button (FileChips.tsx) — ~12px → hit-area-4
- Chat sidebar "more actions" trigger (ChatSidebar.tsx) — ~20px → hit-area-3
- Mobile dropdown trigger (CopilotPage.tsx) — ~20px → hit-area-3
- Collapsed tool group toggle (CollapsedToolGroup.tsx) — hit-area-y-2
- MessageAction buttons (message.tsx) — 32px → hit-area-2
- Sidebar dropdown triggers ×4 (*ActionsDropdown.tsx) — ~20px → hit-area-3

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-11 16:32:09 +08:00
Zamil Majdy
15e3980d65 fix(frontend): buffer workspace file downloads to prevent truncation (#12349)
## Summary
- Workspace file downloads (images, CSVs, etc.) were silently truncated
(~10 KB lost from the end) when served through the Next.js proxy
- Root cause: `new NextResponse(response.body)` passes a
`ReadableStream` directly, which Next.js / Vercel silently truncates for
larger files
- Fix: fully buffer with `response.arrayBuffer()` before forwarding, and
set `Content-Length` from the actual buffer size
- Keeps the auth proxy intact — no signed URLs (which would be public
and expire, breaking chat history)

## Root cause verification
Confirmed locally on session `080f27f9-0379-4085-a67a-ee34cc40cd62`:
- Backend `write_workspace_file` logs **978,831 bytes** written
- Direct backend download (`curl
localhost:8006/api/workspace/files/.../download`): **978,831 bytes** 
- Browser download through Next.js proxy: **truncated** 

## Why not signed URLs?
- Signed URLs are effectively public — anyone with the link can download
the file (privacy concern)
- Signed URLs expire, but chat history persists — reopening a
conversation later would show broken downloads
- Buffering is fine: workspace files are capped at 100 MB, Vercel
function memory is 1 GB

## Related
- Discord thread: `#Truncated File Bug` channel
- Related PR #12319 (signed URL approach) — this fix is simpler and
preserves auth

## Test plan
- [ ] Download a workspace file (CSV, PNG, any type) through the copilot
UI
- [ ] Verify downloaded file size matches the original
- [ ] Verify PNGs open correctly and CSVs have all rows

cc @Swiftyos @uberdot @AdarshRawat1
2026-03-10 18:23:51 +00:00
Zamil Majdy
fe9eb2564b feat(copilot): HITL review for sensitive block execution (#12356)
## Summary
- Integrates existing Human-In-The-Loop (HITL) review infrastructure
into CoPilot's direct block execution (`run_block`) for blocks marked
with `is_sensitive_action=True`
- Removes the `PendingHumanReview → AgentGraphExecution` FK constraint
to support synthetic CoPilot session IDs (migration included)
- Adds `ReviewRequiredResponse` model + frontend `ReviewRequiredCard`
component to surface review status in the chat UI
- Auto-approval works within a CoPilot session: once a block is
approved, subsequent executions of the same block in the same session
are auto-approved (using `copilot-session-{session_id}` as
`graph_exec_id` and `copilot-node-{block_id}` as `node_id`)

## Test plan
- [x] All 11 `run_block_test.py` tests pass (3 new sensitive action
tests)
- [ ] Manual: Execute a block with `is_sensitive_action=True` in CoPilot
→ verify ReviewRequiredResponse is returned and rendered
- [ ] Manual: Approve in review panel → re-execute the same block →
verify auto-approval kicks in
- [ ] Manual: Verify non-sensitive blocks still execute without review
2026-03-10 18:20:11 +00:00
Otto
5641cdd3ca fix(backend): update test patches for validate_url → validate_url_host rename (#12358)
bfb843a renamed `validate_url` to `validate_url_host` in
`agent_browser`, `run_mcp_tool`, and MCP routes, but the corresponding
test files still patched the old name, causing `AttributeError` in CI.

Updates all mock patch targets and assertions across 3 test files:
- `agent_browser_test.py`
- `test_run_mcp_tool.py`  
- `mcp/test_routes.py`

---
Co-authored-by: Zamil Majdy (@majdyz) <zamil.majdy@agpt.co>
Co-authored-by: Reinier van der Leer (@Pwuts) <pwuts@agpt.co>
2026-03-10 17:22:11 +00:00
Otto
bfb843a56e Merge commit from fork
* Fix SSRF via user-controlled ollama_host field

Validate ollama_host against BLOCKED_IP_NETWORKS before passing to
ollama.AsyncClient(). The server-configured default (env: OLLAMA_HOST)
is allowed without validation; user-supplied values that differ are
checked for private/internal IP resolution.

Fixes GHSA-6jx2-4h7q-3fx3

* Generalize validate_ollama_host to validate_host; fix description line length

* Rename to validate_untrusted_host with whitelist parameter

* Apply PR suggestion: include whitelist in error message; run formatting

* Move whitelist check after URL normalization; match on netloc

* revert unrelated formatting changes

* Dedup validate_url and validate_untrusted_host; normalize whitelist

* Move _resolve_and_check_blocked after calling functions

* dedup and clean up

* make trusted_hostnames truly optional

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-03-10 15:51:58 +01:00
Abhimanyu Yadav
684845d946 fix(frontend/builder): handle discriminated unions and improve node layout (#12354)
## Summary
- **Discriminated union support (oneOf)**: Added a new `OneOfField`
component that properly
renders Pydantic discriminated unions. Hides the unusable parent object
handle, auto-populates
the discriminator value, shows a dropdown with variant titles (e.g.,
"Username" / "UserId"), and
filters out the internal discriminator field from the form.
Non-discriminated `oneOf` schemas
  fall back to existing `AnyOfField` behavior.
- **Collapsible object outputs**: Object-type outputs with nested keys
(e.g.,
`PersonLookupResponse.Url`, `PersonLookupResponse.profile`) are now
collapsed by default behind a
caret toggle. Nested keys show short names instead of the full
`Parent.Key` prefix.
- **Node layout cleanup**: Removed excessive bottom margin (`mb-6`) from
`FormRenderer`, hide the
Advanced toggle when no advanced fields exist, and add rounded bottom
corners on OUTPUT-type
  blocks.

<img width="440" height="427" alt="Screenshot 2026-03-10 at 11 31 55 AM"
src="https://github.com/user-attachments/assets/06cc5414-4e02-4371-bdeb-1695e7cb2c97"
/>
<img width="371" height="320" alt="Screenshot 2026-03-10 at 11 36 52 AM"
src="https://github.com/user-attachments/assets/1a55f87a-c602-4f4d-b91b-6e49f810e5d5"
/>

  ## Test plan
- [x] Add a Twitter Get User block — verify "Identifier" shows a
dropdown (Username/UserId) with
no unusable parent handle, discriminator field is hidden, and the block
can run without staying
  INCOMPLETE
- [x] Add any block with object outputs (e.g., PersonLookupResponse) —
verify nested keys are
  collapsed by default and expand on click with short labels
- [x] Verify blocks without advanced fields don't show the Advanced
toggle
- [x] Verify existing `anyOf` schemas (optional types, 3+ variant
unions) still render correctly
  - [x] Check OUTPUT-type blocks have rounded bottom corners

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
Co-authored-by: eureka928 <meobius123@gmail.com>
2026-03-10 14:13:32 +00:00
Bently
6a6b23c2e1 fix(frontend): Remove unused Otto Server Action causing 107K+ errors (#12336)
## Summary

Fixes [OPEN-3025](https://linear.app/autogpt/issue/OPEN-3025) —
**107,571+ Server Action errors** in production

Removes the orphaned `askOtto` Server Action that was left behind after
the Otto chat widget removal in PR #12082.

## Problem

Next.js Server Actions that are never imported are excluded from the
server manifest. Old client bundles still reference the action ID,
causing "not found" errors.

**Sentry impact:**
- **BUILDER-3BN:** 107,571 events
- **BUILDER-729:** 285 events  
- **BUILDER-3QH:** 1,611 events
- **36+ users affected**

## Root Cause

1. **Mar 2025:** Otto widget added to `/build` page with `askOtto`
Server Action
2. **Feb 2026:** Otto widget removed (PR #12082), but `actions.ts` left
behind
3. **Result:** Dead code → not in manifest → errors

## Evidence

```bash
# Zero imports across frontend:
grep -r "askOtto" src/ --exclude="actions.ts"
# → No results

# Server manifest missing the action:
cat .next/server/server-reference-manifest.json
# → Only includes login/supabase actions, NOT build/actions
```

## Changes

-  Delete
`autogpt_platform/frontend/src/app/(platform)/build/actions.ts`

## Testing

1. Verify no imports of `askOtto` in codebase 
2. Check Sentry for error drop after deploy
3. Monitor for new "Server Action not found" errors

## Checklist

- [x] Dead code confirmed (zero imports)
- [x] Sentry issues documented
- [x] Clear commit message with context
2026-03-10 09:03:38 +00:00
Dream
d0a1d72e8a fix(frontend/builder): batch undo history for cascading operations (#12344)
## Summary

Fixes undo in the Builder not working correctly when deleting nodes.
When a node is deleted, React Flow fires `onNodesChange` (node removal)
and `onEdgesChange` (cascading edge cleanup) as separate callbacks —
each independently pushing to the undo history stack. This creates
intermediate states that break undo:

- Single undo restores a partial state (e.g. edges pointing to a deleted
node)
- Multiple undos required to fully restore the graph
- Redo also produces inconsistent states

Resolves #10999

### Changes 🏗️

- **`historyStore.ts`** — Added microtask-based batching to
`pushState()`. Multiple calls within the same synchronous execution
(same event loop tick) are coalesced into a single history entry,
keeping only the first pre-change snapshot. Uses `queueMicrotask` so all
cascading store updates from a single user action settle before the
history entry is committed.
- Reset `pendingState` in `initializeHistory()` and `clear()` to prevent
stale batched state from leaking across graph loads or navigation.

**Side benefit:** Copy/paste operations that add multiple nodes and
edges now also produce a single history entry instead of one per
node/edge.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Place 3 blocks A, B, C and connect A→B→C
  - [x] Delete block C (removes node + cascading edge B→C)
  - [x] Delete connection A→B
  - [x] Undo — connection A→B restored (single undo, not multiple)
  - [x] Undo — block C and connection B→C restored
  - [x] Redo — block C removed again with its connections
- [x] Copy/paste multiple connected blocks — single undo reverts entire
paste

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
Co-authored-by: Abhimanyu Yadav <122007096+Abhi1992002@users.noreply.github.com>
2026-03-10 04:55:07 +00:00
Zamil Majdy
f1945d6a2f feat(platform/copilot): @@agptfile: file-ref protocol for tool call inputs + block input toggle (#12332)
## Summary

- **Problem**: When the LLM calls a tool with large file content, it
must rewrite all content token-by-token. This is wasteful since the
files are already accessible on disk.
- **Solution**: Introduces an \`@@agptfile:\` reference protocol. The
LLM passes a file path reference; the processor loads and substitutes
the content before executing the tool.

### Protocol

\`\`\`
@@agptfile:<uri>[<start>-<end>]
\`\`\`

**Supported URI types:**
| URI | Source |
|-----|--------|
| \`workspace://<file_id>\` | Persistent workspace file by ID |
| \`workspace:///<path>\` | Workspace file by virtual path |
| \`/absolute/path\` | Absolute host or sandbox path |

**Line range** is optional; omitting it reads the whole file.

### Backend changes

- Rename \`@file:\` → \`@@agptfile:\` prefix for uniqueness; extract
\`FILE_REF_PREFIX\` constant
- Extract shared execution-context ContextVars into
\`backend/copilot/context.py\` — eliminates duplicate ContextVar objects
that caused \`e2b_file_tools.py\` to always see empty context
- \`tool_adapter.py\` imports ContextVars from \`context.py\` (single
source of truth)
- \`expand_file_refs_in_string\` raises \`FileRefExpansionError\` on
failure (instead of inline error strings), blocking tool execution and
returning a clear error hint to the model
- Tighten URI regex: only expand refs starting with \`workspace://\` or
\`/\`
- Aggregate budget: 1 MB total expansion cap across all refs in one
string
- Per-file cap: 200 KB per individual ref
- Fix \`_read_file_handler\` to pass \`get_sdk_cwd()\` to
\`is_allowed_local_path\` — ephemeral working directory files were
incorrectly blocked
- Fix \`_is_allowed_local\` in \`e2b_file_tools.py\` to pass
\`get_sdk_cwd()\`
- Restrict local path allow-list to \`tool-results/\` subdirectory only
(was entire session project dir)
- Add \`raise_on_error\` param + remove two-pass \`_FILE_REF_ERROR_RE\`
detection
- Update system prompt docs and tool_adapter error messages

### Frontend changes

- \`BlockInputCard\`: hidden by default with Show/Hide toggle + \`mb-2\`
spacing

## Test plan

- [ ] \`poetry run pytest backend/copilot/ -x
--ignore=backend/copilot/sdk/file_ref_integration_test.py\` passes
- [ ] \`@@agptfile:workspace:///<path>[1-50]\` expands correctly in tool
calls
- [ ] Invalid line ranges produce \`[file-ref error: ...]\` inline
messages
- [ ] Files outside \`sdk_cwd\` / \`tool-results/\` are rejected
- [ ] Block input card shows hidden by default with toggle
2026-03-09 18:39:13 +00:00
Zamil Majdy
6491cb1e23 feat(copilot): local agent generation with validation, fixing, MCP & sub-agent support (#12238)
## Summary

Port the agent generation pipeline from the external AgentGenerator
service into local copilot tools, making the Claude Agent SDK itself
handle validation, fixing, and block recommendation — no separate inner
LLM calls needed.

Key capabilities:
- **Local agent generation**: Create, edit, and customize agents
entirely within the SDK session
- **Graph validation**: 9 validation checks (block existence, link
references, type compatibility, IO blocks, etc.)
- **Graph fixing**: 17+ auto-fix methods (ID repair, link rewiring, type
conversion, credential stripping, dynamic block sink names, etc.)
- **MCP tool blocks**: Guide and fixer support for MCPToolBlock nodes
with proper dynamic input schema handling
- **Sub-agent composition**: AgentExecutorBlock support with library
agent schema enrichment
- **Embedding fallback**: Falls back to OpenRouter for embeddings when
`openai_internal_api_key` is unavailable
- **Actionable error messages**: Excluded block types (MCP, Agent)
return specific hints redirecting to the correct tool

### New Tools
- `validate_agent_graph` — run 9 validation checks on agent JSON
- `fix_agent_graph` — apply 17+ auto-fixes to agent JSON
- `get_blocks_for_goal` — recommend blocks for a given goal (with
optimized descriptions)

### Refactored Tools
- `create_agent`, `edit_agent`, `customize_agent` — accept `agent_json`
for local generation with shared fix→validate→save pipeline
- `find_block` — added `include_schemas` parameter, excludes MCP/Agent
blocks with actionable hints
- `run_block` — actionable error messages for excluded block types
- `find_library_agent` — enriched with `graph_version`, `input_schema`,
`output_schema` for sub-agent composition

### Architecture
- Split 2,558-line `validation.py` into `fixer.py`, `validator.py`,
`helpers.py`, `pipeline.py`
- Extracted shared `fix_validate_and_save()` pipeline (was duplicated
across 3 tools)
- Shared `OPENROUTER_BASE_URL` constant across codebase
- Comprehensive test coverage: 78+ unit tests for fixer/validator, 8
run_block tests, 17 SDK compat tests

## Test plan
- [x] `poetry run format` passes
- [x] `poetry run pytest -s -vvv backend/copilot/` — all tests pass
- [x] CI green on all Python versions (3.11, 3.12, 3.13)
- [x] Manual E2E: copilot generates agents with correct IO blocks,
links, and node structure
- [x] Manual E2E: MCP tool blocks use bare field names for dynamic
inputs
- [x] Manual E2E: sub-agent composition with AgentExecutorBlock
2026-03-09 16:10:22 +00:00
nKOxxx
c7124a5240 Add documentation for Google Gemini integration (#12283)
## Summary
Adding comprehensive documentation for Google Gemini integration with
AutoGPT.

## Changes
- Added setup instructions for Gemini API
- Documented configuration options
- Added examples and best practices

## Related Issues
N/A - Documentation improvement

## Testing
- Verified documentation accuracy
- Tested all code examples

## Checklist
- [x] Code follows project style
- [x] Documentation updated
- [x] Tests pass (if applicable)
2026-03-09 15:13:28 +00:00
Zamil Majdy
5537cb2858 dx: add shared Claude Code skills as auto-triggered guidelines (#12297)
## Summary
- Add 8 Claude Code skills under \`.claude/skills/\` that act as
**auto-triggered guidelines** — the LLM invokes them automatically based
on context, no manual \`/command\` needed
- Skills: \`pr-review\`, \`pr-create\`, \`new-block\`,
\`openapi-regen\`, \`backend-check\`, \`frontend-check\`,
\`worktree-setup\`, \`code-style\`
- Each skill has an explicit TRIGGER condition so the LLM knows when to
apply it without being asked

## Changes

### Skills (all auto-triggered by context)
| Skill | Trigger |
|-------|---------|
| \`pr-review\` | User shares a PR URL or asks to address review
comments |
| \`pr-create\` | User asks to create a PR, push changes for review, or
submit work |
| \`new-block\` | User asks to create a new block or add a new
integration |
| \`openapi-regen\` | API routes change, new endpoints added, or
frontend types are stale |
| \`backend-check\` | Backend Python code has been modified |
| \`frontend-check\` | Frontend TypeScript/React code has been modified
|
| \`worktree-setup\` | User asks to work on a branch in isolation or set
up a worktree |
| \`code-style\` | Writing or reviewing Python code |

## Test plan
- [ ] Verify skills appear automatically in Claude Code when context
matches (no \`/command\` needed)
- [ ] Modify frontend code — confirm \`frontend-check\` fires
automatically
- [ ] Ask Claude to "create a PR" — confirm \`pr-create\` fires without
\`/pr-create\`
- [ ] Share a PR URL — confirm \`pr-review\` fires automatically

---------

Co-authored-by: Krzysztof Czerwinski <kpczerwinski@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 15:10:38 +00:00
Zamil Majdy
aef5f6d666 feat(copilot): E2B sandbox auto-pause between turns to eliminate idle billing (#12330)
## Summary

### Before
- E2B sandboxes ran continuously between CoPilot turns, billing for idle
time
- Sandbox timeout caused **termination** (kill), losing all session
state
- No explicit cleanup when sessions were deleted — sandboxes leaked
- Single timeout concept with no separation between pause and kill
semantics

### After
- **Per-turn pause**: `pause_sandbox()` is called in the `finally` block
after every CoPilot turn, stopping billing instantly between turns
(paused sandboxes cost \$0 compute)
- **Auto-pause safety net**: Sandboxes are created with
`lifecycle={"on_timeout": "pause"}` (`pause_timeout` = 4h default) so
they auto-pause rather than terminate if the explicit pause is missed
- **Auto-reconnect**: `AsyncSandbox.connect()` in e2b SDK v2
auto-resumes paused sandboxes transparently — no extra code needed
- **Session delete cleanup**: `kill_sandbox()` is now called in
`delete_chat_session()` to explicitly terminate sandboxes and free
resources
- **Two distinct timeouts**: `pause_timeout` (4h, e2b auto-pause) vs
`redis_ttl` (12h, session key lifetime)

### Key Changes

| File | Change |
|------|--------|
| `pyproject.toml` | Bump `e2b-code-interpreter` `1.x` → `2.x` |
| `e2b_sandbox.py` | Add `pause_sandbox()`, `kill_sandbox()`,
`_act_on_sandbox()` helper; `lifecycle={"on_timeout": "pause"}`;
separate `pause_timeout` / `redis_ttl` params |
| `sdk/service.py` | Call `pause_sandbox()` in `finally` block
**before** transcript upload; use walrus operator for type-safe
`e2b_api_key` narrowing |
| `model.py` | Call `kill_sandbox()` in `delete_chat_session()`; inline
import to avoid circular dependency |
| `config.py` | Add `e2b_active` property; rename `e2b_sandbox_timeout`
default to 4h |
| `e2b_sandbox_test.py` | Add `test_pause_then_reconnect_reuses_sandbox`
test; update all `sandbox_timeout` → `pause_timeout` |

### Verified E2E
- Used real `E2B_API_KEY` from k8s dev cluster to manually verify:
sandbox created → paused → `is_running() == False` → reconnected via
`connect()` → state preserved → killed

## Test plan
- [x] `poetry run pytest backend/copilot/tools/e2b_sandbox_test.py` —
all 19 tests pass
- [x] CI: test (3.11, 3.12, 3.13), types — all green
- [x] E2E verified with real E2B credentials
2026-03-09 14:55:10 +00:00
Ubbe
8063391d0a feat(frontend/copilot): pin interactive tool cards outside reasoning collapse (#12346)
## Summary

<img width="400" height="227" alt="Screenshot 2026-03-09 at 22 43 10"
src="https://github.com/user-attachments/assets/0116e260-860d-4466-9763-e02de2766e50"
/>

<img width="600" height="618" alt="Screenshot 2026-03-09 at 22 43 14"
src="https://github.com/user-attachments/assets/beaa6aca-afa8-483f-ac06-439bf162c951"
/>

- When the copilot stream finishes, tool calls that require user
interaction (credentials, inputs, clarification) are now **pinned**
outside the "Show reasoning" collapse instead of being hidden
- Added `isInteractiveToolPart()` helper that checks tool output's
`type` field against a set of interactive response types
- Modified `splitReasoningAndResponse()` to extract interactive tools
from reasoning into the visible response section
- Added styleguide section with 3 demos: `setup_requirements`,
`agent_details`, and `agent_saved` pinning scenarios

### Interactive response types kept visible:
`setup_requirements`, `agent_details`, `block_details`, `need_login`,
`input_validation_error`, `clarification_needed`, `suggested_goal`,
`agent_preview`, `agent_saved`

Error responses remain in reasoning (LLM explains them in final text).

Closes SECRT-2088

## Test plan
- [ ] Verify copilot stream with interactive tool (e.g. run_agent
requiring credentials) keeps the tool card visible after stream ends
- [ ] Verify non-interactive tools (find_block, bash_exec) still
collapse into "Show reasoning"
- [ ] Verify styleguide page at `/copilot/styleguide` renders the new
"Reasoning Collapse: Interactive Tool Pinning" section correctly
- [ ] Verify `pnpm types`, `pnpm lint`, `pnpm format` all pass

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 23:12:14 +08:00
Otto
0bbb12d688 fix(frontend/copilot): hide New Chat button on Autopilot homepage (#12321)
Requested by @0ubbe

The **New Chat** button was visible on the Autopilot homepage where
clicking it has no effect (since `sessionId` is already `null`). This
hides the button when no chat session is active, so it only appears when
the user is viewing a conversation and wants to start a new one.

**Changes:**
- `ChatSidebar.tsx` — hide button in both collapsed and expanded sidebar
states when `sessionId` is null
- `MobileDrawer.tsx` — same fix for mobile drawer

---
Co-authored-by: Ubbe <ubbe@users.noreply.github.com>
2026-03-09 22:41:11 +08:00
Otto
eadc68f2a5 feat(frontend/copilot): move microphone button to right side of input box (#12320)
Requested by @olivia-1421

Moves the microphone/recording button from the left-side tools group to
the right side, next to the submit button. The left side is now reserved
for the attachment/upload (plus) button only.

**Before:** `[ 📎 🎤 ] .................. [ ➤ ]`
**After:**  `[ 📎 ] .................. [ 🎤 ➤ ]`

---
Co-authored-by: Olivia <olivia-1421@users.noreply.github.com>

---------

Co-authored-by: Ubbe <hi@ubbe.dev>
Co-authored-by: Lluis Agusti <hi@llu.lu>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 18:37:02 +08:00
Reinier van der Leer
eca7b5e793 Merge commit from fork 2026-03-08 10:24:44 +01:00
Otto
c304a4937a fix(backend): Handle manual run attempts for triggered agents (#12298)
When a webhook-triggered agent is executed directly (e.g. via Copilot)
without actual webhook data, `GraphExecution.from_db()` crashes with
`KeyError: 'payload'` because it does a hard key access on
`exec.input_data["payload"]` for webhook blocks.

This caused 232 Sentry events (AUTOGPT-SERVER-821) and multiple
INCOMPLETE graph executions due to retries.

**Changes:**

1. **Defensive fix in `from_db()`** — use `.get("payload")` instead of
`["payload"]` to handle missing keys gracefully (matches existing
pattern for input blocks using `.get("value")`)

2. **Upfront refusal in `_construct_starting_node_execution_input()`** —
refuse execution of webhook/webhook_manual blocks when no payload is
provided. The check is placed after `nodes_input_masks` application, so
legitimate webhook triggers (which inject payload via
`nodes_input_masks`) pass through fine.

Resolves [SENTRY-1113: Copilot is able to manually initiate runs for
triggered agents (which
fails)](https://linear.app/autogpt/issue/SENTRY-1113/copilot-is-able-to-manually-initiate-runs-for-triggered-agents-which)

---
Co-authored-by: Reinier van der Leer (@Pwuts) <pwuts@agpt.co>
2026-03-06 20:47:51 +00:00
Zamil Majdy
8cfabcf4fd refactor(backend/copilot): centralize prompt building in prompting.py (#12324)
## Summary

Centralizes all prompt building logic into a new
`backend/copilot/prompting.py` module with clear SDK vs baseline and
local vs E2B distinctions.

### Key Changes

**New `prompting.py` module:**
- `get_sdk_supplement(use_e2b, cwd)` - For SDK mode (NO tool docs -
Claude gets schemas automatically)
- `get_baseline_supplement(use_e2b, cwd)` - For baseline mode (WITH
auto-generated tool docs from TOOL_REGISTRY)
- Handles local/E2B storage differences

**SDK mode (`sdk/service.py`):**
- Removed 165+ lines of duplicate constants
- Now imports and uses `get_sdk_supplement()`
- Cleaner, more maintainable

**Baseline mode (`baseline/service.py`):**
- Now appends `get_baseline_supplement()` to system prompt
- Baseline mode finally gets tool documentation!

**Enhanced tool descriptions:**
- `create_agent`: Added feedback loop workflow (suggested_goal,
clarifying_questions)
- `run_mcp_tool`: Added known server URLs, 2-step workflow, auth
handling

**Tests:**
- Updated to verify SDK excludes tool docs, baseline includes them
- All existing tests pass

### Architecture Benefits

 Single source of truth for prompt supplements
 Clear SDK vs baseline distinction (SDK doesn't need tool docs)
 Clear local vs E2B distinction (storage systems)
 Easy to maintain and update
 Eliminates code duplication

## Test plan

- [x] Unit tests pass (TestPromptSupplement class)
- [x] SDK mode excludes tool documentation
- [x] Baseline mode includes tool documentation
- [x] E2B vs local mode differences handled correctly
2026-03-06 18:56:20 +00:00
Zamil Majdy
7bf407b66c Merge branch 'master' of github.com:Significant-Gravitas/AutoGPT into dev 2026-03-07 02:01:41 +07:00
Zamil Majdy
7ead4c040f hotfix(backend/copilot): capture tool results in transcript (#12323)
## Summary
- Fixes tool results not being captured in the CoPilot transcript during
SDK-based streaming
- Adds `transcript_builder.add_user_message()` call with `tool_result`
content block when a `StreamToolOutputAvailable` event is received
- Ensures transcript accurately reflects the full conversation including
tool outputs, which is critical for Langfuse tracing and debugging

## Context
After the transcript refactor in #12318, tool call results from the SDK
streaming loop were not being recorded in the transcript. This meant
Langfuse traces were missing tool outputs, making it hard to debug agent
behavior.

## Test plan
- [ ] Verify CoPilot conversation with tool calls captures tool results
in Langfuse traces
- [ ] Verify transcript includes tool_result content blocks after tool
execution
2026-03-06 18:58:48 +00:00
Abhimanyu Yadav
0f813f1bf9 feat(copilot): Add folder management tools to CoPilot (#12290)
Adds folder management capabilities to the CoPilot, allowing users to
organize agents into folders directly from the chat interface.

<img width="823" height="356" alt="Screenshot 2026-03-05 at 5 26 30 PM"
src="https://github.com/user-attachments/assets/4c55f926-1e71-488f-9eb6-fca87c4ab01b"
/>
<img width="797" height="150" alt="Screenshot 2026-03-05 at 5 28 40 PM"
src="https://github.com/user-attachments/assets/5c9c6f8b-57ac-4122-b17d-b9f091bb7c4e"
/>
<img width="763" height="196" alt="Screenshot 2026-03-05 at 5 28 36 PM"
src="https://github.com/user-attachments/assets/d1b22b5d-921d-44ac-90e8-a5820bb3146d"
/>
<img width="756" height="199" alt="Screenshot 2026-03-05 at 5 30 17 PM"
src="https://github.com/user-attachments/assets/40a59748-f42e-4521-bae0-cc786918a9b5"
/>

### Changes

**Backend -- 6 new CoPilot tools** (`manage_folders.py`):
- `create_folder` -- Create folders with optional parent, icon, and
color
- `list_folders` -- List folder tree or children of a specific folder,
with optional `include_agents` to show agents inside each folder
- `update_folder` -- Rename or change icon/color
- `move_folder` -- Reparent a folder or move to root
- `delete_folder` -- Soft-delete (agents moved to root, not deleted)
- `move_agents_to_folder` -- Bulk-move agents into a folder or back to
root

**Backend -- DatabaseManager RPC registration**:
- Registered all 7 folder DB functions (`create_folder`, `list_folders`,
`get_folder_tree`, `update_folder`, `move_folder`, `delete_folder`,
`bulk_move_agents_to_folder`) in `DatabaseManager` and
`DatabaseManagerAsyncClient` so they work via RPC in the CoPilotExecutor
process
- `manage_folders.py` uses `db_accessors.library_db()` pattern
(consistent with all other copilot tools) instead of direct Prisma
imports

**Backend -- folder_id threading**:
- `create_agent` and `customize_agent` tools accept optional `folder_id`
to save agents directly into a folder
- `save_agent_to_library` -> `create_graph_in_library` ->
`create_library_agent` pipeline passes `folder_id` through
- `create_library_agent` refactored from `asyncio.gather` to sequential
loop to support conditional `folderId` assignment on the main graph only
(not sub-graphs)

**Backend -- system prompt and models**:
- Added folder tool descriptions and usage guidance to Otto's system
prompt
- Added `FolderAgentSummary` model for lightweight agent info in folder
listings
- Added 6 `ResponseType` enum values and corresponding Pydantic response
models (`FolderInfo`, `FolderTreeInfo`, `FolderCreatedResponse`, etc.)

**Frontend -- FolderTool UI component**:
- `FolderTool.tsx` -- Renders folder operations in chat using the
`file-tree` molecule component for tree view, with `FileIcon` for agents
and `FolderIcon` for folders (both `text-neutral-600`)
- `helpers.ts` -- Type guards, output parsing, animation text helpers,
and `FolderAgentSummary` type
- `MessagePartRenderer.tsx` -- Routes 6 folder tool types to
`FolderTool` component
- Flat folder list view shows agents inside `FolderCard` when
`include_agents` is set

**Frontend -- file-tree molecule**:
- Fixed 3 pre-existing lint errors in `file-tree.tsx` (unused `ref`,
`handleSelect`, `className` params)
- Updated tree indicator line color from `bg-neutral-100` to
`bg-neutral-400` for visibility
- Added `file-tree.stories.tsx` with 5 stories: Default, AllExpanded,
FoldersOnly, WithInitialSelection, NoIndicator
- Added `ui/scroll-area.tsx` (dependency of file-tree, was missing from
non-legacy ui folder)

### Checklist

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Create a folder via copilot chat ("create a folder called
Marketing")
  - [x] List folders ("show me my folders")
- [x] List folders with agents ("show me my folders and the agents in
them")
- [x] Update folder name/icon/color ("rename Marketing folder to Sales")
- [x] Move folder to a different parent ("move Sales into the Projects
folder")
  - [x] Delete a folder and verify agents move to root
- [x] Move agents into a folder ("put my newsletter agent in the
Marketing folder")
- [x] Create agent with folder_id ("create a scraper agent and save it
in my Tools folder")
- [x] Verify FolderTool UI renders loading, success, error, and empty
states correctly
- [x] Verify folder tree renders nested folders with file-tree component
- [x] Verify agents appear as FileIcon nodes in tree view when
include_agents is true
  - [x] Verify file-tree storybook stories render correctly
2026-03-06 14:59:03 +00:00
Reinier van der Leer
aa08063939 refactor(backend/db): Improve & clean up Marketplace DB layer & API (#12284)
These changes were part of #12206, but here they are separately for
easier review.
This is all primarily to make the v2 API (#11678) work possible/easier.

### Changes 🏗️

- Fix relations between `Profile`, `StoreListing`, and `AgentGraph`
- Redefine `StoreSubmission` view with more efficient joins (100x
speed-up on dev DB) and more consistent field names
- Clean up query functions in `store/db.py`
- Clean up models in `store/model.py`
- Add missing fields to `StoreAgent` and `StoreSubmission` views
- Rename ambiguous `agent_id` -> `graph_id`
- Clean up API route definitions & docs in `store/routes.py`
  - Make routes more consistent
- Avoid collision edge-case between `/agents/{username}/{agent_name}`
and `/agents/{store_listing_version_id}/*`
- Replace all usages of legacy `BackendAPI` for store endpoints with
generated client
- Remove scope requirements on public store endpoints in v1 external API

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Test all Marketplace views (including admin views)
    - [x] Download an agent from the marketplace
  - [x] Submit an agent to the Marketplace
  - [x] Approve/reject Marketplace submission
2026-03-06 14:38:12 +00:00
Zamil Majdy
bde6a4c0df Merge branch 'master' of github.com:Significant-Gravitas/AutoGPT into dev
# Conflicts:
#	autogpt_platform/backend/backend/copilot/sdk/service.py
2026-03-06 21:07:37 +07:00
Zamil Majdy
d56452898a hotfix(backend/copilot): refactor transcript to SDK-based atomic full-context model (#12318)
## Summary

Major refactor to eliminate CLI transcript race conditions and simplify
the codebase by building transcripts directly from SDK messages instead
of reading CLI files.

## Problem

The previous approach had race conditions:
- SDK reads CLI transcript file during stop hook
- CLI may not have finished writing → incomplete transcript
- Complex merge logic to detect and fix incomplete writes
- ~200 lines of synthetic entry detection and merge code

## Solution

**Atomic Full-Context Transcript Model:**
- Build transcript from SDK messages during streaming
(`TranscriptBuilder`)
- Each upload REPLACES the previous transcript entirely (atomic)
- No CLI file reading → no race conditions
- Eliminates all merge complexity

## Key Changes

### Core Refactor
- **NEW**: `transcript_builder.py` - Build JSONL from SDK messages
during streaming
- **SIMPLIFIED**: `transcript.py` - Removed merge logic, simplified
upload/download
- **SIMPLIFIED**: `service.py` - Use TranscriptBuilder, removed stop
hook callback
- **CLEANED**: `security_hooks.py` - Removed `on_stop` parameter

### Performance & Code Quality
- **orjson migration**: Use `backend.util.json` (2-3x faster than
stdlib)
- Added `fallback` parameter to `json.loads()` for cleaner error
handling
- Moved SDK imports to top-level per code style guidelines

### Bug Fixes
- Fixed garbage collection bug in background task handling
- Fixed double upload bug in timeout handling  
- Downgraded PII-risk logging from WARNING to DEBUG
- Added 30s timeout to prevent session lock hang

## Code Removed (~200 lines)

- `merge_with_previous_transcript()` - No longer needed
- `read_transcript_file()` - No longer needed
- `CapturedTranscript` dataclass - No longer needed
- `_on_stop()` callback - No longer needed
- Synthetic entry detection logic - No longer needed
- Manual append/merge logic in finally block - No longer needed

## Testing

-  All transcript tests passing (24/24)
-  Verified with real session logs showing proper transcript growth
-  Verified with Langfuse traces showing proper turn tracking (1-8)

## Transcript Growth Pattern

From session logs:
- **Turn 1**: 2 entries (initial)
- **Turn 2**: 5 entries (+3), 2257B uploaded
- **Turn N**: ~2N entries (linear growth)

Each upload is the **complete atomic state** - always REPLACES, never
incremental.

## Files Changed

```
backend/copilot/sdk/transcript_builder.py (NEW)   | +140 lines
backend/copilot/sdk/transcript.py                  | -198, +125 lines  
backend/copilot/sdk/service.py                     | -214, +160 lines
backend/copilot/sdk/security_hooks.py              | -33, +10 lines
backend/copilot/sdk/transcript_test.py             | -85, +36 lines
backend/util/json.py                               | +45 lines
```

**Net result**: -200 lines, more reliable, faster JSON operations.

## Migration Notes

This is a **breaking change** for any code that:
- Directly calls `merge_with_previous_transcript()` or
`read_transcript_file()`
- Relies on incremental transcript uploads
- Expects stop hook callbacks

All internal usage has been updated.

---

@ntindle - Tagging for autogpt-reviewer
2026-03-06 21:03:49 +07:00
Ubbe
7507240177 feat(copilot): collapse repeated tool calls and fix stream stuck on completion (#12282)
## Summary
- **Frontend:** Group consecutive completed generic tool parts into
collapsible summary rows with a "Reasoning" collapse for finalized
messages. Merge consecutive assistant messages on hydration to avoid
split bubbles. Extract GenericTool helpers. Add `reconnectExhausted`
state and a brief delay before refetching session to reduce stale
`active_stream` reconnect cycles.
- **Backend:** Make transcript upload fire-and-forget instead of
blocking the generator exit. The 30s upload timeout in
`_try_upload_transcript` was delaying `mark_session_completed()`,
keeping the SSE stream alive with only heartbeats after the LLM had
finished — causing the UI to stay stuck in "streaming" state.

## Test plan
- [ ] Send a message in Copilot that triggers multiple tool calls —
verify they collapse into a grouped summary row once completed
- [ ] Verify the final text response appears below the collapsed
reasoning section
- [ ] Confirm the stream properly closes after the agent finishes (no
stuck "Stop" button)
- [ ] Refresh mid-stream and verify reconnection works correctly
- [ ] Click Stop during streaming — verify the UI becomes responsive
immediately

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-06 21:21:59 +08:00
Abhimanyu Yadav
d7c3f5b8fc fix(frontend): bypass Next.js proxy for file uploads to fix 413 error (#12315)
## Summary
- File uploads routed through the Next.js API proxy (`/api/proxy/...`)
fail with HTTP 413 for files >4.5MB due to Vercel's serverless function
body size limit
- Created shared `uploadFileDirect` utility (`src/lib/direct-upload.ts`)
that uploads files directly from the browser to the Python backend,
bypassing the proxy entirely
- Updated `useWorkspaceUpload` to use direct upload instead of the
generated hook (which went through the proxy)
- Deduplicated the copilot page's inline upload logic to use the same
shared utility

## Changes 🏗️
- **New**: `src/lib/direct-upload.ts` — shared utility for
direct-to-backend file uploads (up to 256MB)
- **Updated**: `useWorkspaceUpload.ts` — replaced proxy-based generated
hook with `uploadFileDirect`
- **Updated**: `useCopilotPage.ts` — replaced inline upload logic with
shared `uploadFileDirect`, removed unused imports

## Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Upload a file >5MB via workspace file input (e.g. in agent
builder) — should succeed without 413
  - [x] Upload a file >5MB via copilot chat — should succeed without 413
  - [x] Upload a small file (<1MB) via both paths — should still work
  - [x] Verify file delete still works from workspace file input
2026-03-06 12:20:18 +00:00
Otto
3e108a813a fix(backend): Use db_manager for workspace in add_graph_execution (#12312)
When `add_graph_execution` is called from a context where the global
Prisma client isn't connected (e.g. CoPilot tools, external API), the
call to `get_or_create_workspace(user_id)` crashes with
`ClientNotConnectedError` because it directly accesses
`UserWorkspace.prisma()`.

The fix adds `workspace_db` to the existing `if prisma.is_connected()`
fallback pattern, consistent with how all other DB calls in the function
already work.

**Sentry:** AUTOGPT-SERVER-83T (and ~15 related issues going back to Jan
2026)

---
Co-authored-by: Reinier van der Leer (@Pwuts) <pwuts@agpt.co>

Co-authored-by: Reinier van der Leer (@Pwuts) <pwuts@agpt.co>
2026-03-06 08:48:15 +01:00
Krzysztof Czerwinski
08c49a78f8 feat(copilot): UX improvements (#12258)
CoPilot conversation UX improvements (SECRT-2055):

1. **Rename conversations** — Inline rename via the session dropdown
menu. New `PATCH /sessions/{session_id}/title` endpoint with server-side
validation (rejects blank/whitespace-only titles, normalizes
whitespace). Pressing Enter or clicking away submits; Escape cancels
without submitting.

2. **New Chat button moved to top & sticky** — The 'New Chat' button is
now at the top of the sidebar (under 'Your chats') instead of the
footer, and stays fixed — only the session list below it scrolls. A
subtle shadow separator mirrors the original footer style.

3. **Auto-generated title appears live** — After the first message in a
new chat, the sidebar polls for the backend-generated title and animates
it in smoothly once available. The backend also guards against
auto-title overwriting a user-set title.

4. **External Link popup redesign** — Replaced the CSS-hacked external
link confirmation dialog with a proper AutoGPT `Dialog` component using
the design system (`Button`, `Text`, `Dialog`). Removed the old
`globals.css` workaround.

<img width="321" height="263" alt="Screenshot 2026-03-03 at 6 31 50 pm"
src="https://github.com/user-attachments/assets/3cdd1c6f-cca6-4f16-8165-15a1dc2d53f7"
/>

<img width="374" height="74" alt="Screenshot 2026-03-02 at 6 39 07 pm"
src="https://github.com/user-attachments/assets/6f9fc953-5fa7-4469-9eab-7074e7604519"
/>

<img width="548" height="293" alt="Screenshot 2026-03-02 at 6 36 28 pm"
src="https://github.com/user-attachments/assets/0f34683b-7281-4826-ac6f-ac7926e67854"
/>

### Changes 🏗️

**Backend:**
- `routes.py`: Added `PATCH /sessions/{session_id}/title` endpoint with
`UpdateSessionTitleRequest` Pydantic model — validates non-blank title,
normalizes whitespace, returns 404 vs 500 correctly
- `routes_test.py`: New test file — 7 test cases covering success,
whitespace trimming, blank rejection (422), not found (404), internal
failure (500)
- `service.py`: Auto-title generation now checks if a user-set title
already exists before overwriting
- `openapi.json`: Updated with new endpoint schema

**Frontend:**
- `ChatSidebar.tsx`: Inline rename (Enter/blur submits, Escape cancels
via ref flag); "New Chat" button sticky at top with shadow separator;
session title animates when auto-generated title appears
(`AnimatePresence`)
- `useCopilotPage.ts`: Polls for auto-generated title after stream ends,
stops as soon as title appears in cache
- `MobileDrawer.tsx`: Updated to match sidebar layout changes
- `DeleteChatDialog.tsx`: Removed redundant `onClose` prop (controlled
Dialog already handles close)
- `message.tsx`: Added `ExternalLinkModal` using AutoGPT design system;
removed redundant `onClose` prop
- `globals.css`: Removed old CSS hack for external link modal

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Create a new chat, send a message — verify auto-generated title
appears in sidebar without refresh
- [x] Rename a chat via dropdown — Enter submits, Escape reverts, blank
title rejected
- [x] Rename a chat, then send another message — verify user title is
not overwritten by auto-title
- [x] With many chats, scroll the sidebar — verify "New Chat" button
stays fixed at top
- [x] Click an external link in a message — verify the new dialog
appears with AutoGPT styling

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-06 06:01:41 +00:00
Bently
5d56548e6b fix(frontend): prevent crash on /library with 401 error from pagination helper (#12292)
## Changes
Fixes crash on `/library` page when backend returns a 401 authentication
error.

### Problem

When the backend returns a 401 error, React Query still calls
`getNextPageParam` with the error response. The response doesn't have
the expected pagination structure, causing `pagination` to be
`undefined`. The code then crashes trying
 to access `pagination.current_page`.

Error:
TypeError: Cannot read properties of undefined (reading 'current_page')
    at Object.getNextPageParam

### Solution

Added a defensive null check in `getPaginationNextPageNumber()` to
handle cases where `pagination` is undefined:

```typescript
const { pagination } = lastPage.data;
if (!pagination) return undefined;
```
When undefined is returned, React Query interprets this as "no next page
available" and gracefully stops pagination instead of crashing.

Testing

- Manual testing: Verify /library page handles 401 errors without
crashing
- The fix is defensive and doesn't change behavior for successful
responses

Related Issues

Closes OPEN-2684
2026-03-05 19:52:36 +00:00
Otto
6ecf55d214 fix(frontend): fix 'Open link' button text color to white for contrast (#12304)
Requested by @ntindle

The Streamdown external link safety modal's "Open link" button had dark
text (`color: black`) on a dark background, making it unreadable.
Changed to `color: white` for proper contrast per our design system.

**File:** `autogpt_platform/frontend/src/app/globals.css`

Resolves SECRT-2061

---
Co-authored-by: Nick Tindle (@ntindle)
2026-03-05 19:50:39 +00:00
Bently
7c8c7bf395 feat(llm): add Claude Sonnet 4.6 model (#12158)
## Summary
Adds Claude Sonnet 4.6 (`claude-sonnet-4-6`) to the platform.

## Model Details (from [Anthropic
docs](https://www.anthropic.com/news/claude-sonnet-4-6))
- **API ID:** `claude-sonnet-4-6`
- **Pricing:** $3 / input MTok, $15 / output MTok (same as Sonnet 4.5)
- **Context window:** 200K tokens (1M beta)
- **Max output:** 64K tokens
- **Knowledge cutoff:** Aug 2025 (reliable), Jan 2026 (training data)

## Changes
- Added `CLAUDE_4_6_SONNET` to `LlmModel` enum
- Added metadata entry with correct context/output limits
- Updated Stagehand to use Sonnet 4.6 (better for browser automation
tasks)

## Why
Sonnet 4.6 brings major improvements in coding, computer use, and
reasoning. Developers with early access often prefer it to even Opus
4.5.

---------

Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2026-03-05 19:36:56 +00:00
Zamil Majdy
0b9e0665dd Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT 2026-03-06 02:32:36 +07:00
Zamil Majdy
be18436e8f Merge branch 'master' of github.com:Significant-Gravitas/AutoGPT into dev 2026-03-06 02:31:40 +07:00
Zamil Majdy
f6f268a1f0 Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into HEAD 2026-03-06 02:29:56 +07:00
Zamil Majdy
ea0333c1fc fix(copilot): always upload transcript instead of size-based skip (#12303)
## Summary

Fixes copilot sessions "forgetting" previous turns due to stale
transcript storage.

**Root cause:** The transcript upload logic used byte size comparison
(`existing >= new → skip`) to prevent overwriting newer transcripts with
older ones. However, with `--resume` the CLI compacts old tool results,
so newer transcripts can have **fewer bytes** despite containing **more
conversation events**. This caused the stored transcript to freeze at
whatever the largest historical upload was — every subsequent turn
downloaded the same stale transcript and the agent lost context of
recent turns.

**Evidence from prod session `41a3814c`:**
- Stored transcript: 764KB (frozen, never updated)
- Turn 1 output: 379KB (75 lines) → upload skipped (764KB >= 379KB)
- Turn 2 output: 422KB (71 lines) → upload skipped (764KB >= 422KB)
- Turn 3 output: **empty** → upload skipped
- Agent resumed from the same stale 764KB transcript every turn, losing
context of the PR it created

**Fix:** Remove the size comparison entirely. The executor holds a
cluster lock per session, so concurrent uploads cannot race. Just always
overwrite with the latest transcript.

## Test plan
- [x] `poetry run pytest backend/copilot/sdk/transcript_test.py` — 25/25
pass
- [x] All pre-commit hooks pass
- [ ] After deploy: verify multi-turn sessions retain context across
turns
2026-03-06 02:26:52 +07:00
Otto
aa7a2f0a48 hotfix(frontend/signup): Add missing createUser() call in password signup (#12287)
Requested by @0ubbe

Password signup was missing the backend `createUser()` call that the
OAuth callback flow already had. This caused `getOnboardingStatus()` to
fail/hang for new users whose backend record didn't exist yet, resulting
in an infinite spinner after account creation.

## Root Cause

| Flow | createUser() | getOnboardingStatus() | Result |
|------|-------------|----------------------|--------|
| OAuth signup |  Called |  Works | Redirects correctly |
| Password signup |  Missing |  Fails/hangs | Infinite spinner |

## Fix

Adds `createUser()` call in `signup/actions.ts` after session is set,
before onboarding status check — matching the OAuth callback pattern.
Includes error handling with Sentry reporting.

## Testing

- Create a new password account → should redirect without spinner
- OAuth signup unaffected (no changes to that flow)

Fixes OPEN-3023

---------

Co-authored-by: Lluis Agusti <hi@llu.lu>
2026-03-05 16:11:51 +08:00
256 changed files with 18787 additions and 8002 deletions

View File

@@ -0,0 +1,17 @@
---
name: backend-check
description: Run the full backend formatting, linting, and test suite. Ensures code quality before commits and PRs. TRIGGER when backend Python code has been modified and needs validation.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Backend Check
## Steps
1. **Format**: `poetry run format` — runs formatting AND linting. NEVER run ruff/black/isort individually
2. **Fix** any remaining errors manually, re-run until clean
3. **Test**: `poetry run test` (runs DB setup + pytest). For specific files: `poetry run pytest -s -vvv <test_files>`
4. **Snapshots** (if needed): `poetry run pytest path/to/test.py --snapshot-update` — review with `git diff`

View File

@@ -0,0 +1,35 @@
---
name: code-style
description: Python code style preferences for the AutoGPT backend. Apply when writing or reviewing Python code. TRIGGER when writing new Python code, reviewing PRs, or refactoring backend code.
user-invocable: false
metadata:
author: autogpt-team
version: "1.0.0"
---
# Code Style
## Imports
- **Top-level only** — no local/inner imports. Move all imports to the top of the file.
## Typing
- **No duck typing** — avoid `hasattr`, `getattr`, `isinstance` for type dispatch. Use proper typed interfaces, unions, or protocols.
- **Pydantic models** over dataclass, namedtuple, or raw dict for structured data.
- **No linter suppressors** — avoid `# type: ignore`, `# noqa`, `# pyright: ignore` etc. 99% of the time the right fix is fixing the type/code, not silencing the tool.
## Code Structure
- **List comprehensions** over manual loop-and-append.
- **Early return** — guard clauses first, avoid deep nesting.
- **Flatten inline** — prefer short, concise expressions. Reduce `if/else` chains with direct returns or ternaries when readable.
- **Modular functions** — break complex logic into small, focused functions rather than long blocks with nested conditionals.
## Review Checklist
Before finishing, always ask:
- Can any function be split into smaller pieces?
- Is there unnecessary nesting that an early return would eliminate?
- Can any loop be a comprehension?
- Is there a simpler way to express this logic?

View File

@@ -0,0 +1,16 @@
---
name: frontend-check
description: Run the full frontend formatting, linting, and type checking suite. Ensures code quality before commits and PRs. TRIGGER when frontend TypeScript/React code has been modified and needs validation.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Frontend Check
## Steps (in order)
1. **Format**: `pnpm format` — NEVER run individual formatters
2. **Lint**: `pnpm lint` — fix errors, re-run until clean
3. **Types**: `pnpm types` — if it keeps failing after multiple attempts, stop and ask the user

View File

@@ -0,0 +1,29 @@
---
name: new-block
description: Create a new backend block following the Block SDK Guide. Guides through provider configuration, schema definition, authentication, and testing. TRIGGER when user asks to create a new block, add a new integration, or build a new node for the graph editor.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# New Block Creation
Read `docs/platform/block-sdk-guide.md` first for the full guide.
## Steps
1. **Provider config** (if external service): create `_config.py` with `ProviderBuilder`
2. **Block file** in `backend/blocks/` (from `autogpt_platform/backend/`):
- Generate a UUID once with `uuid.uuid4()`, then **hard-code that string** as `id` (IDs must be stable across imports)
- `Input(BlockSchema)` and `Output(BlockSchema)` classes
- `async def run` that `yield`s output fields
3. **Files**: use `store_media_file()` with `"for_block_output"` for outputs
4. **Test**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[MyBlock]' -xvs`
5. **Format**: `poetry run format`
## Rules
- Analyze interfaces: do inputs/outputs connect well with other blocks in a graph?
- Use top-level imports, avoid duck typing
- Always use `for_block_output` for block outputs

View File

@@ -0,0 +1,28 @@
---
name: openapi-regen
description: Regenerate the OpenAPI spec and frontend API client. Starts the backend REST server, fetches the spec, and regenerates the typed frontend hooks. TRIGGER when API routes change, new endpoints are added, or frontend API types are stale.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# OpenAPI Spec Regeneration
## Steps
1. **Run end-to-end** in a single shell block (so `REST_PID` persists):
```bash
cd autogpt_platform/backend && poetry run rest &
REST_PID=$!
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && kill $REST_PID && exit 1; done
cd ../frontend && pnpm generate:api:force
kill $REST_PID
pnpm types && pnpm lint && pnpm format
```
## Rules
- Always use `pnpm generate:api:force` (not `pnpm generate:api`)
- Don't manually edit files in `src/app/api/__generated__/`
- Generated hooks follow: `use{Method}{Version}{OperationName}`

View File

@@ -0,0 +1,31 @@
---
name: pr-create
description: Create a pull request for the current branch. TRIGGER when user asks to create a PR, open a pull request, push changes for review, or submit work for merging.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Create Pull Request
## Steps
1. **Check for existing PR**: `gh pr view --json url -q .url 2>/dev/null` — if a PR already exists, output its URL and stop
2. **Understand changes**: `git status`, `git diff dev...HEAD`, `git log dev..HEAD --oneline`
3. **Read PR template**: `.github/PULL_REQUEST_TEMPLATE.md`
4. **Draft PR title**: Use conventional commits format (see CLAUDE.md for types and scopes)
5. **Fill out PR template** as the body — be thorough in the Changes section
6. **Format first** (if relevant changes exist):
- Backend: `cd autogpt_platform/backend && poetry run format`
- Frontend: `cd autogpt_platform/frontend && pnpm format`
- Fix any lint errors, then commit formatting changes before pushing
7. **Push**: `git push -u origin HEAD`
8. **Create PR**: `gh pr create --base dev`
9. **Output** the PR URL
## Rules
- Always target `dev` branch
- Do NOT run tests — CI will handle that
- Use the PR template from `.github/PULL_REQUEST_TEMPLATE.md`

View File

@@ -0,0 +1,51 @@
---
name: pr-review
description: Address all open PR review comments systematically. Fetches comments, addresses each one, reacts +1/-1, and replies when clarification is needed. Keeps iterating until all comments are addressed and CI is green. TRIGGER when user shares a PR URL, asks to address review comments, fix PR feedback, or respond to reviewer comments.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# PR Review Comment Workflow
## Steps
1. **Find PR**: `gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT`
2. **Fetch comments** (all three sources):
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` (top-level reviews)
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` (inline review comments)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` (PR conversation comments)
3. **Skip** comments already reacted to by PR author
4. **For each unreacted comment**:
- Read referenced code, make the fix (or reply if you disagree/need info)
- **Inline review comments** (`pulls/{N}/comments`):
- React: `gh api repos/.../pulls/comments/{ID}/reactions -f content="+1"` (or `-1`)
- Reply: `gh api repos/.../pulls/{N}/comments/{ID}/replies -f body="..."`
- **PR conversation comments** (`issues/{N}/comments`):
- React: `gh api repos/.../issues/comments/{ID}/reactions -f content="+1"` (or `-1`)
- No threaded replies — post a new issue comment if needed
- **Top-level reviews**: no reaction API — address in code, reply via issue comment if needed
5. **Include autogpt-reviewer bot fixes** too
6. **Format**: `cd autogpt_platform/backend && poetry run format`, `cd autogpt_platform/frontend && pnpm format`
7. **Commit & push**
8. **Re-fetch comments** immediately — address any new unreacted ones before waiting on CI
9. **Stay productive while CI runs** — don't idle. In priority order:
- Run any pending local tests (`poetry run pytest`, e2e, etc.) and fix failures
- Address any remaining comments
- Only poll `gh pr checks {N}` as the last resort when there's truly nothing left to do
10. **If CI fails** — fix, go back to step 6
11. **Re-fetch comments again** after CI is green — address anything that appeared while CI was running
12. **Done** only when: all comments reacted AND CI is green.
## CRITICAL: Do Not Stop
**Loop is: address → format → commit → push → re-check comments → run local tests → wait CI → re-check comments → repeat.**
Never idle. If CI is running and you have nothing to address, run local tests. Waiting on CI is the last resort.
## Rules
- One todo per comment
- For inline review comments: reply on existing threads. For PR conversation comments: post a new issue comment (API doesn't support threaded replies)
- React to every comment: +1 addressed, -1 disagreed (with explanation)

View File

@@ -0,0 +1,45 @@
---
name: worktree-setup
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, generates Prisma client, and optionally starts the app (with port conflict resolution) or runs tests. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Worktree Setup
## Preferred: Use Branchlet
The repo has a `.branchlet.json` config — it handles env file copying, dependency installation, and Prisma generation automatically.
```bash
npm install -g branchlet # install once
branchlet create -n <name> -s <source-branch> -b <new-branch>
branchlet list --json # list all worktrees
```
## Manual Fallback
If branchlet isn't available:
1. `git worktree add ../<RepoName><N> <branch-name>`
2. Copy `.env` files: `backend/.env`, `frontend/.env`, `autogpt_platform/.env`, `db/docker/.env`
3. Install deps:
- `cd autogpt_platform/backend && poetry install && poetry run prisma generate`
- `cd autogpt_platform/frontend && pnpm install`
## Running the App
Free ports first — backend uses: 8001, 8002, 8003, 8005, 8006, 8007, 8008.
```bash
for port in 8001 8002 8003 8005 8006 8007 8008; do
lsof -ti :$port | xargs kill -9 2>/dev/null || true
done
cd <worktree>/autogpt_platform/backend && poetry run app
```
## CoPilot Testing Gotcha
SDK mode spawns a Claude subprocess — **won't work inside Claude Code**. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.

View File

@@ -1,7 +1,7 @@
import logging
import urllib.parse
from collections import defaultdict
from typing import Annotated, Any, Literal, Optional, Sequence
from typing import Annotated, Any, Optional, Sequence
from fastapi import APIRouter, Body, HTTPException, Security
from prisma.enums import AgentExecutionStatus, APIKeyPermission
@@ -9,9 +9,10 @@ from pydantic import BaseModel, Field
from typing_extensions import TypedDict
import backend.api.features.store.cache as store_cache
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
import backend.blocks
from backend.api.external.middleware import require_permission
from backend.api.external.middleware import require_auth, require_permission
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data import user as user_db
@@ -230,13 +231,13 @@ async def get_graph_execution_results(
@v1_router.get(
path="/store/agents",
tags=["store"],
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
response_model=store_model.StoreAgentsResponse,
)
async def get_store_agents(
featured: bool = False,
creator: str | None = None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
sorted_by: store_db.StoreAgentsSortOptions | None = None,
search_query: str | None = None,
category: str | None = None,
page: int = 1,
@@ -278,7 +279,7 @@ async def get_store_agents(
@v1_router.get(
path="/store/agents/{username}/{agent_name}",
tags=["store"],
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
response_model=store_model.StoreAgentDetails,
)
async def get_store_agent(
@@ -306,13 +307,13 @@ async def get_store_agent(
@v1_router.get(
path="/store/creators",
tags=["store"],
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
response_model=store_model.CreatorsResponse,
)
async def get_store_creators(
featured: bool = False,
search_query: str | None = None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
page: int = 1,
page_size: int = 20,
) -> store_model.CreatorsResponse:
@@ -348,7 +349,7 @@ async def get_store_creators(
@v1_router.get(
path="/store/creators/{username}",
tags=["store"],
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
response_model=store_model.CreatorDetails,
)
async def get_store_creator(

View File

@@ -24,14 +24,13 @@ router = fastapi.APIRouter(
@router.get(
"/listings",
summary="Get Admin Listings History",
response_model=store_model.StoreListingsWithVersionsResponse,
)
async def get_admin_listings_with_versions(
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
search: typing.Optional[str] = None,
page: int = 1,
page_size: int = 20,
):
) -> store_model.StoreListingsWithVersionsAdminViewResponse:
"""
Get store listings with their version history for admins.
@@ -45,36 +44,26 @@ async def get_admin_listings_with_versions(
page_size: Number of items per page
Returns:
StoreListingsWithVersionsResponse with listings and their versions
Paginated listings with their versions
"""
try:
listings = await store_db.get_admin_listings_with_versions(
status=status,
search_query=search,
page=page,
page_size=page_size,
)
return listings
except Exception as e:
logger.exception("Error getting admin listings with versions: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={
"detail": "An error occurred while retrieving listings with versions"
},
)
listings = await store_db.get_admin_listings_with_versions(
status=status,
search_query=search,
page=page,
page_size=page_size,
)
return listings
@router.post(
"/submissions/{store_listing_version_id}/review",
summary="Review Store Submission",
response_model=store_model.StoreSubmission,
)
async def review_submission(
store_listing_version_id: str,
request: store_model.ReviewSubmissionRequest,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
) -> store_model.StoreSubmissionAdminView:
"""
Review a store listing submission.
@@ -84,31 +73,24 @@ async def review_submission(
user_id: Authenticated admin user performing the review
Returns:
StoreSubmission with updated review information
StoreSubmissionAdminView with updated review information
"""
try:
already_approved = await store_db.check_submission_already_approved(
store_listing_version_id=store_listing_version_id,
)
submission = await store_db.review_store_submission(
store_listing_version_id=store_listing_version_id,
is_approved=request.is_approved,
external_comments=request.comments,
internal_comments=request.internal_comments or "",
reviewer_id=user_id,
)
already_approved = await store_db.check_submission_already_approved(
store_listing_version_id=store_listing_version_id,
)
submission = await store_db.review_store_submission(
store_listing_version_id=store_listing_version_id,
is_approved=request.is_approved,
external_comments=request.comments,
internal_comments=request.internal_comments or "",
reviewer_id=user_id,
)
state_changed = already_approved != request.is_approved
# Clear caches when the request is approved as it updates what is shown on the store
if state_changed:
store_cache.clear_all_caches()
return submission
except Exception as e:
logger.exception("Error reviewing submission: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while reviewing the submission"},
)
state_changed = already_approved != request.is_approved
# Clear caches whenever approval state changes, since store visibility can change
if state_changed:
store_cache.clear_all_caches()
return submission
@router.get(

View File

@@ -11,7 +11,7 @@ from autogpt_libs import auth
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
@@ -25,8 +25,10 @@ from backend.copilot.model import (
delete_chat_session,
get_chat_session,
get_user_sessions,
update_session_title,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.e2b_sandbox import kill_sandbox
from backend.copilot.tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
@@ -141,6 +143,20 @@ class CancelSessionResponse(BaseModel):
reason: str | None = None
class UpdateSessionTitleRequest(BaseModel):
"""Request model for updating a session's title."""
title: str
@field_validator("title")
@classmethod
def title_must_not_be_blank(cls, v: str) -> str:
stripped = v.strip()
if not stripped:
raise ValueError("Title must not be blank")
return stripped
# ========== Routes ==========
@@ -250,12 +266,12 @@ async def delete_session(
)
# Best-effort cleanup of the E2B sandbox (if any).
config = ChatConfig()
if config.use_e2b_sandbox and config.e2b_api_key:
from backend.copilot.tools.e2b_sandbox import kill_sandbox
# sandbox_id is in Redis; kill_sandbox() fetches it from there.
e2b_cfg = ChatConfig()
if e2b_cfg.e2b_active:
assert e2b_cfg.e2b_api_key # guaranteed by e2b_active check
try:
await kill_sandbox(session_id, config.e2b_api_key)
await kill_sandbox(session_id, e2b_cfg.e2b_api_key)
except Exception:
logger.warning(
"[E2B] Failed to kill sandbox for session %s", session_id[:12]
@@ -264,6 +280,43 @@ async def delete_session(
return Response(status_code=204)
@router.patch(
"/sessions/{session_id}/title",
summary="Update session title",
dependencies=[Security(auth.requires_user)],
status_code=200,
responses={404: {"description": "Session not found or access denied"}},
)
async def update_session_title_route(
session_id: str,
request: UpdateSessionTitleRequest,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> dict:
"""
Update the title of a chat session.
Allows the user to rename their chat session.
Args:
session_id: The session ID to update.
request: Request body containing the new title.
user_id: The authenticated user's ID.
Returns:
dict: Status of the update.
Raises:
HTTPException: 404 if session not found or not owned by user.
"""
success = await update_session_title(session_id, user_id, request.title)
if not success:
raise HTTPException(
status_code=404,
detail=f"Session {session_id} not found or access denied",
)
return {"status": "ok"}
@router.get(
"/sessions/{session_id}",
)
@@ -753,7 +806,6 @@ async def resume_session_stream(
@router.patch(
"/sessions/{session_id}/assign-user",
dependencies=[Security(auth.requires_user)],
status_code=200,
)
async def session_assign_user(
session_id: str,

View File

@@ -1,4 +1,6 @@
"""Tests for chat route file_ids validation and enrichment."""
"""Tests for chat API routes: session title update and file attachment validation."""
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
@@ -17,6 +19,7 @@ TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
"""Setup auth overrides for all tests in this module"""
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
@@ -24,7 +27,95 @@ def setup_app_auth(mock_jwt_user):
app.dependency_overrides.clear()
# ---- file_ids Pydantic validation (B1) ----
def _mock_update_session_title(
mocker: pytest_mock.MockerFixture, *, success: bool = True
):
"""Mock update_session_title."""
return mocker.patch(
"backend.api.features.chat.routes.update_session_title",
new_callable=AsyncMock,
return_value=success,
)
# ─── Update title: success ─────────────────────────────────────────────
def test_update_title_success(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
mock_update = _mock_update_session_title(mocker, success=True)
response = client.patch(
"/sessions/sess-1/title",
json={"title": "My project"},
)
assert response.status_code == 200
assert response.json() == {"status": "ok"}
mock_update.assert_called_once_with("sess-1", test_user_id, "My project")
def test_update_title_trims_whitespace(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
mock_update = _mock_update_session_title(mocker, success=True)
response = client.patch(
"/sessions/sess-1/title",
json={"title": " trimmed "},
)
assert response.status_code == 200
mock_update.assert_called_once_with("sess-1", test_user_id, "trimmed")
# ─── Update title: blank / whitespace-only → 422 ──────────────────────
def test_update_title_blank_rejected(
test_user_id: str,
) -> None:
"""Whitespace-only titles must be rejected before hitting the DB."""
response = client.patch(
"/sessions/sess-1/title",
json={"title": " "},
)
assert response.status_code == 422
def test_update_title_empty_rejected(
test_user_id: str,
) -> None:
response = client.patch(
"/sessions/sess-1/title",
json={"title": ""},
)
assert response.status_code == 422
# ─── Update title: session not found or wrong user → 404 ──────────────
def test_update_title_not_found(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
_mock_update_session_title(mocker, success=False)
response = client.patch(
"/sessions/sess-1/title",
json={"title": "New name"},
)
assert response.status_code == 404
# ─── file_ids Pydantic validation ─────────────────────────────────────
def test_stream_chat_rejects_too_many_file_ids():
@@ -92,7 +183,7 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
assert response.status_code == 200
# ---- UUID format filtering ----
# ─── UUID format filtering ─────────────────────────────────────────────
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
@@ -131,7 +222,7 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
assert call_kwargs["where"]["id"]["in"] == [valid_id]
# ---- Cross-workspace file_ids ----
# ─── Cross-workspace file_ids ─────────────────────────────────────────
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):

View File

@@ -638,7 +638,7 @@ async def test_process_review_action_auto_approve_creates_auto_approval_records(
# Mock get_node_executions to return node_id mapping
mock_get_node_executions = mocker.patch(
"backend.data.execution.get_node_executions"
"backend.api.features.executions.review.routes.get_node_executions"
)
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
mock_node_exec.node_exec_id = "test_node_123"
@@ -936,7 +936,7 @@ async def test_process_review_action_auto_approve_only_applies_to_approved_revie
# Mock get_node_executions to return node_id mapping
mock_get_node_executions = mocker.patch(
"backend.data.execution.get_node_executions"
"backend.api.features.executions.review.routes.get_node_executions"
)
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
mock_node_exec.node_exec_id = "node_exec_approved"
@@ -1148,7 +1148,7 @@ async def test_process_review_action_per_review_auto_approve_granularity(
# Mock get_node_executions to return batch node data
mock_get_node_executions = mocker.patch(
"backend.data.execution.get_node_executions"
"backend.api.features.executions.review.routes.get_node_executions"
)
# Create mock node executions for each review
mock_node_execs = []

View File

@@ -6,10 +6,15 @@ import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, HTTPException, Query, Security, status
from prisma.enums import ReviewStatus
from backend.copilot.constants import (
is_copilot_synthetic_id,
parse_node_id_from_exec_id,
)
from backend.data.execution import (
ExecutionContext,
ExecutionStatus,
get_graph_execution_meta,
get_node_executions,
)
from backend.data.graph import get_graph_settings
from backend.data.human_review import (
@@ -36,6 +41,38 @@ router = APIRouter(
)
async def _resolve_node_ids(
node_exec_ids: list[str],
graph_exec_id: str,
is_copilot: bool,
) -> dict[str, str]:
"""Resolve node_exec_id -> node_id for auto-approval records.
CoPilot synthetic IDs encode node_id in the format "{node_id}:{random}".
Graph executions look up node_id from NodeExecution records.
"""
if not node_exec_ids:
return {}
if is_copilot:
return {neid: parse_node_id_from_exec_id(neid) for neid in node_exec_ids}
node_execs = await get_node_executions(
graph_exec_id=graph_exec_id, include_exec_data=False
)
node_exec_map = {ne.node_exec_id: ne.node_id for ne in node_execs}
result = {}
for neid in node_exec_ids:
if neid in node_exec_map:
result[neid] = node_exec_map[neid]
else:
logger.error(
f"Failed to resolve node_id for {neid}: Node execution not found."
)
return result
@router.get(
"/pending",
summary="Get Pending Reviews",
@@ -110,14 +147,16 @@ async def list_pending_reviews_for_execution(
"""
# Verify user owns the graph execution before returning reviews
graph_exec = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
# (CoPilot synthetic IDs don't have graph execution records)
if not is_copilot_synthetic_id(graph_exec_id):
graph_exec = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
return await get_pending_reviews_for_execution(graph_exec_id, user_id)
@@ -160,30 +199,26 @@ async def process_review_action(
)
graph_exec_id = next(iter(graph_exec_ids))
is_copilot = is_copilot_synthetic_id(graph_exec_id)
# Validate execution status before processing reviews
graph_exec_meta = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec_meta:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
# Only allow processing reviews if execution is paused for review
# or incomplete (partial execution with some reviews already processed)
if graph_exec_meta.status not in (
ExecutionStatus.REVIEW,
ExecutionStatus.INCOMPLETE,
):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
f"Reviews can only be processed when execution is paused (REVIEW status). "
f"Current status: {graph_exec_meta.status}",
# Validate execution status for graph executions (skip for CoPilot synthetic IDs)
if not is_copilot:
graph_exec_meta = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec_meta:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
if graph_exec_meta.status not in (
ExecutionStatus.REVIEW,
ExecutionStatus.INCOMPLETE,
):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}",
)
# Build review decisions map and track which reviews requested auto-approval
# Auto-approved reviews use original data (no modifications allowed)
@@ -236,7 +271,7 @@ async def process_review_action(
)
return (node_id, False)
# Collect node_exec_ids that need auto-approval
# Collect node_exec_ids that need auto-approval and resolve their node_ids
node_exec_ids_needing_auto_approval = [
node_exec_id
for node_exec_id, review_result in updated_reviews.items()
@@ -244,29 +279,16 @@ async def process_review_action(
and auto_approve_requests.get(node_exec_id, False)
]
# Batch-fetch node executions to get node_ids
node_id_map = await _resolve_node_ids(
node_exec_ids_needing_auto_approval, graph_exec_id, is_copilot
)
# Deduplicate by node_id — one auto-approval per node
nodes_needing_auto_approval: dict[str, Any] = {}
if node_exec_ids_needing_auto_approval:
from backend.data.execution import get_node_executions
node_execs = await get_node_executions(
graph_exec_id=graph_exec_id, include_exec_data=False
)
node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
for node_exec_id in node_exec_ids_needing_auto_approval:
node_exec = node_exec_map.get(node_exec_id)
if node_exec:
review_result = updated_reviews[node_exec_id]
# Use the first approved review for this node (deduplicate by node_id)
if node_exec.node_id not in nodes_needing_auto_approval:
nodes_needing_auto_approval[node_exec.node_id] = review_result
else:
logger.error(
f"Failed to create auto-approval record for {node_exec_id}: "
f"Node execution not found. This may indicate a race condition "
f"or data inconsistency."
)
for node_exec_id in node_exec_ids_needing_auto_approval:
node_id = node_id_map.get(node_exec_id)
if node_id and node_id not in nodes_needing_auto_approval:
nodes_needing_auto_approval[node_id] = updated_reviews[node_exec_id]
# Execute all auto-approval creations in parallel (deduplicated by node_id)
auto_approval_results = await asyncio.gather(
@@ -281,13 +303,11 @@ async def process_review_action(
auto_approval_failed_count = 0
for result in auto_approval_results:
if isinstance(result, Exception):
# Unexpected exception during auto-approval creation
auto_approval_failed_count += 1
logger.error(
f"Unexpected exception during auto-approval creation: {result}"
)
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
# Auto-approval creation failed (returned False)
auto_approval_failed_count += 1
# Count results
@@ -302,22 +322,20 @@ async def process_review_action(
if review.status == ReviewStatus.REJECTED
)
# Resume execution only if ALL pending reviews for this execution have been processed
if updated_reviews:
# Resume graph execution only for real graph executions (not CoPilot)
# CoPilot sessions are resumed by the LLM retrying run_block with review_id
if not is_copilot and updated_reviews:
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
if not still_has_pending:
# Get the graph_id from any processed review
first_review = next(iter(updated_reviews.values()))
try:
# Fetch user and settings to build complete execution context
user = await get_user_by_id(user_id)
settings = await get_graph_settings(
user_id=user_id, graph_id=first_review.graph_id
)
# Preserve user's timezone preference when resuming execution
user_timezone = (
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
)

View File

@@ -8,7 +8,6 @@ import prisma.errors
import prisma.models
import prisma.types
import backend.api.features.store.exceptions as store_exceptions
import backend.api.features.store.image_gen as store_image_gen
import backend.api.features.store.media as store_media
import backend.data.graph as graph_db
@@ -251,7 +250,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
The requested LibraryAgent.
Raises:
AgentNotFoundError: If the specified agent does not exist.
NotFoundError: If the specified agent does not exist.
DatabaseError: If there's an error during retrieval.
"""
library_agent = await prisma.models.LibraryAgent.prisma().find_first(
@@ -398,6 +397,7 @@ async def create_library_agent(
hitl_safe_mode: bool = True,
sensitive_action_safe_mode: bool = False,
create_library_agents_for_sub_graphs: bool = True,
folder_id: str | None = None,
) -> list[library_model.LibraryAgent]:
"""
Adds an agent to the user's library (LibraryAgent table).
@@ -414,12 +414,18 @@ async def create_library_agent(
If the graph has sub-graphs, the parent graph will always be the first entry in the list.
Raises:
AgentNotFoundError: If the specified agent does not exist.
NotFoundError: If the specified agent does not exist.
DatabaseError: If there's an error during creation or if image generation fails.
"""
logger.info(
f"Creating library agent for graph #{graph.id} v{graph.version}; user:<redacted>"
)
# Authorization: FK only checks existence, not ownership.
# Verify the folder belongs to this user to prevent cross-user nesting.
if folder_id:
await get_folder(folder_id, user_id)
graph_entries = (
[graph, *graph.sub_graphs] if create_library_agents_for_sub_graphs else [graph]
)
@@ -432,7 +438,6 @@ async def create_library_agent(
isCreatedByUser=(user_id == user_id),
useGraphIsActiveVersion=True,
User={"connect": {"id": user_id}},
# Creator={"connect": {"id": user_id}},
AgentGraph={
"connect": {
"graphVersionId": {
@@ -448,6 +453,11 @@ async def create_library_agent(
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
),
**(
{"Folder": {"connect": {"id": folder_id}}}
if folder_id and graph_entry is graph
else {}
),
),
include=library_agent_include(
user_id, include_nodes=False, include_executions=False
@@ -529,6 +539,7 @@ async def update_agent_version_in_library(
async def create_graph_in_library(
graph: graph_db.Graph,
user_id: str,
folder_id: str | None = None,
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
"""Create a new graph and add it to the user's library."""
graph.version = 1
@@ -542,6 +553,7 @@ async def create_graph_in_library(
user_id=user_id,
sensitive_action_safe_mode=True,
create_library_agents_for_sub_graphs=False,
folder_id=folder_id,
)
if created_graph.is_active:
@@ -817,7 +829,7 @@ async def add_store_agent_to_library(
The newly created LibraryAgent if successfully added, the existing corresponding one if any.
Raises:
AgentNotFoundError: If the store listing or associated agent is not found.
NotFoundError: If the store listing or associated agent is not found.
DatabaseError: If there's an issue creating the LibraryAgent record.
"""
logger.debug(
@@ -832,7 +844,7 @@ async def add_store_agent_to_library(
)
if not store_listing_version or not store_listing_version.AgentGraph:
logger.warning(f"Store listing version not found: {store_listing_version_id}")
raise store_exceptions.AgentNotFoundError(
raise NotFoundError(
f"Store listing version {store_listing_version_id} not found or invalid"
)
@@ -846,7 +858,7 @@ async def add_store_agent_to_library(
include_subgraphs=False,
)
if not graph_model:
raise store_exceptions.AgentNotFoundError(
raise NotFoundError(
f"Graph #{graph.id} v{graph.version} not found or accessible"
)
@@ -1481,6 +1493,67 @@ async def bulk_move_agents_to_folder(
return [library_model.LibraryAgent.from_db(agent) for agent in agents]
def collect_tree_ids(
nodes: list[library_model.LibraryFolderTree],
visited: set[str] | None = None,
) -> list[str]:
"""Collect all folder IDs from a folder tree."""
if visited is None:
visited = set()
ids: list[str] = []
for n in nodes:
if n.id in visited:
continue
visited.add(n.id)
ids.append(n.id)
ids.extend(collect_tree_ids(n.children, visited))
return ids
async def get_folder_agent_summaries(
user_id: str, folder_id: str
) -> list[dict[str, str | None]]:
"""Get a lightweight list of agents in a folder (id, name, description)."""
all_agents: list[library_model.LibraryAgent] = []
for page in itertools.count(1):
resp = await list_library_agents(
user_id=user_id, folder_id=folder_id, page=page
)
all_agents.extend(resp.agents)
if page >= resp.pagination.total_pages:
break
return [
{"id": a.id, "name": a.name, "description": a.description} for a in all_agents
]
async def get_root_agent_summaries(
user_id: str,
) -> list[dict[str, str | None]]:
"""Get a lightweight list of root-level agents (folderId IS NULL)."""
all_agents: list[library_model.LibraryAgent] = []
for page in itertools.count(1):
resp = await list_library_agents(
user_id=user_id, include_root_only=True, page=page
)
all_agents.extend(resp.agents)
if page >= resp.pagination.total_pages:
break
return [
{"id": a.id, "name": a.name, "description": a.description} for a in all_agents
]
async def get_folder_agents_map(
user_id: str, folder_ids: list[str]
) -> dict[str, list[dict[str, str | None]]]:
"""Get agent summaries for multiple folders concurrently."""
results = await asyncio.gather(
*(get_folder_agent_summaries(user_id, fid) for fid in folder_ids)
)
return dict(zip(folder_ids, results))
##############################################
########### Presets DB Functions #############
##############################################

View File

@@ -4,7 +4,6 @@ import prisma.enums
import prisma.models
import pytest
import backend.api.features.store.exceptions
from backend.data.db import connect
from backend.data.includes import library_agent_include
@@ -218,7 +217,7 @@ async def test_add_agent_to_library_not_found(mocker):
)
# Call function and verify exception
with pytest.raises(backend.api.features.store.exceptions.AgentNotFoundError):
with pytest.raises(db.NotFoundError):
await db.add_store_agent_to_library("version123", "test-user")
# Verify mock called correctly

View File

@@ -24,7 +24,7 @@ from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests, validate_url
from backend.util.request import HTTPClientError, Requests, validate_url_host
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -80,7 +80,7 @@ async def discover_tools(
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
@@ -167,7 +167,7 @@ async def mcp_oauth_login(
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
@@ -187,7 +187,7 @@ async def mcp_oauth_login(
# Validate the auth server URL from metadata to prevent SSRF.
try:
await validate_url(auth_server_url, trusted_origins=[])
await validate_url_host(auth_server_url)
except ValueError as e:
raise fastapi.HTTPException(
status_code=400,
@@ -234,7 +234,7 @@ async def mcp_oauth_login(
if registration_endpoint:
# Validate the registration endpoint to prevent SSRF via metadata.
try:
await validate_url(registration_endpoint, trusted_origins=[])
await validate_url_host(registration_endpoint)
except ValueError:
pass # Skip registration, fall back to default client_id
else:
@@ -429,7 +429,7 @@ async def mcp_store_token(
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")

View File

@@ -32,9 +32,9 @@ async def client():
@pytest.fixture(autouse=True)
def _bypass_ssrf_validation():
"""Bypass validate_url in all route tests (test URLs don't resolve)."""
"""Bypass validate_url_host in all route tests (test URLs don't resolve)."""
with patch(
"backend.api.features.mcp.routes.validate_url",
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
):
yield
@@ -521,12 +521,12 @@ class TestStoreToken:
class TestSSRFValidation:
"""Verify that validate_url is enforced on all endpoints."""
"""Verify that validate_url_host is enforced on all endpoints."""
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
@@ -541,7 +541,7 @@ class TestSSRFValidation:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked private IP"),
):
@@ -556,7 +556,7 @@ class TestSSRFValidation:
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):

View File

@@ -1,5 +1,3 @@
from typing import Literal
from backend.util.cache import cached
from . import db as store_db
@@ -23,7 +21,7 @@ def clear_all_caches():
async def _get_cached_store_agents(
featured: bool,
creator: str | None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None,
sorted_by: store_db.StoreAgentsSortOptions | None,
search_query: str | None,
category: str | None,
page: int,
@@ -57,7 +55,7 @@ async def _get_cached_agent_details(
async def _get_cached_store_creators(
featured: bool,
search_query: str | None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None,
sorted_by: store_db.StoreCreatorsSortOptions | None,
page: int,
page_size: int,
):
@@ -75,4 +73,4 @@ async def _get_cached_store_creators(
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
async def _get_cached_creator_details(username: str):
"""Cached helper to get creator details."""
return await store_db.get_store_creator_details(username=username.lower())
return await store_db.get_store_creator(username=username.lower())

File diff suppressed because it is too large Load Diff

View File

@@ -26,7 +26,7 @@ async def test_get_store_agents(mocker):
mock_agents = [
prisma.models.StoreAgent(
listing_id="test-id",
storeListingVersionId="version123",
listing_version_id="version123",
slug="test-agent",
agent_name="Test Agent",
agent_video=None,
@@ -40,11 +40,11 @@ async def test_get_store_agents(mocker):
runs=10,
rating=4.5,
versions=["1.0"],
agentGraphVersions=["1"],
agentGraphId="test-graph-id",
graph_id="test-graph-id",
graph_versions=["1"],
updated_at=datetime.now(),
is_available=False,
useForOnboarding=False,
use_for_onboarding=False,
)
]
@@ -68,10 +68,10 @@ async def test_get_store_agents(mocker):
@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_agent_details(mocker):
# Mock data
# Mock data - StoreAgent view already contains the active version data
mock_agent = prisma.models.StoreAgent(
listing_id="test-id",
storeListingVersionId="version123",
listing_version_id="version123",
slug="test-agent",
agent_name="Test Agent",
agent_video="video.mp4",
@@ -85,102 +85,38 @@ async def test_get_store_agent_details(mocker):
runs=10,
rating=4.5,
versions=["1.0"],
agentGraphVersions=["1"],
agentGraphId="test-graph-id",
updated_at=datetime.now(),
is_available=False,
useForOnboarding=False,
)
# Mock active version agent (what we want to return for active version)
mock_active_agent = prisma.models.StoreAgent(
listing_id="test-id",
storeListingVersionId="active-version-id",
slug="test-agent",
agent_name="Test Agent Active",
agent_video="active_video.mp4",
agent_image=["active_image.jpg"],
featured=False,
creator_username="creator",
creator_avatar="avatar.jpg",
sub_heading="Test heading active",
description="Test description active",
categories=["test"],
runs=15,
rating=4.8,
versions=["1.0", "2.0"],
agentGraphVersions=["1", "2"],
agentGraphId="test-graph-id-active",
graph_id="test-graph-id",
graph_versions=["1"],
updated_at=datetime.now(),
is_available=True,
useForOnboarding=False,
use_for_onboarding=False,
)
# Create a mock StoreListing result
mock_store_listing = mocker.MagicMock()
mock_store_listing.activeVersionId = "active-version-id"
mock_store_listing.hasApprovedVersion = True
mock_store_listing.ActiveVersion = mocker.MagicMock()
mock_store_listing.ActiveVersion.recommendedScheduleCron = None
# Mock StoreAgent prisma call - need to handle multiple calls
# Mock StoreAgent prisma call
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
# Set up side_effect to return different results for different calls
def mock_find_first_side_effect(*args, **kwargs):
where_clause = kwargs.get("where", {})
if "storeListingVersionId" in where_clause:
# Second call for active version
return mock_active_agent
else:
# First call for initial lookup
return mock_agent
mock_store_agent.return_value.find_first = mocker.AsyncMock(
side_effect=mock_find_first_side_effect
)
# Mock Profile prisma call
mock_profile = mocker.MagicMock()
mock_profile.userId = "user-id-123"
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
mock_profile_db.return_value.find_first = mocker.AsyncMock(
return_value=mock_profile
)
# Mock StoreListing prisma call
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
return_value=mock_store_listing
)
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
# Call function
result = await db.get_store_agent_details("creator", "test-agent")
# Verify results - should use active version data
# Verify results - constructed from the StoreAgent view
assert result.slug == "test-agent"
assert result.agent_name == "Test Agent Active" # From active version
assert result.active_version_id == "active-version-id"
assert result.agent_name == "Test Agent"
assert result.active_version_id == "version123"
assert result.has_approved_version is True
assert (
result.store_listing_version_id == "active-version-id"
) # Should be active version ID
assert result.store_listing_version_id == "version123"
assert result.graph_id == "test-graph-id"
assert result.runs == 10
assert result.rating == 4.5
# Verify mocks called correctly - now expecting 2 calls
assert mock_store_agent.return_value.find_first.call_count == 2
# Check the specific calls
calls = mock_store_agent.return_value.find_first.call_args_list
assert calls[0] == mocker.call(
# Verify single StoreAgent lookup
mock_store_agent.return_value.find_first.assert_called_once_with(
where={"creator_username": "creator", "slug": "test-agent"}
)
assert calls[1] == mocker.call(where={"storeListingVersionId": "active-version-id"})
mock_store_listing_db.return_value.find_first.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_creator_details(mocker):
async def test_get_store_creator(mocker):
# Mock data
mock_creator_data = prisma.models.Creator(
name="Test Creator",
@@ -202,7 +138,7 @@ async def test_get_store_creator_details(mocker):
mock_creator.return_value.find_unique.return_value = mock_creator_data
# Call function
result = await db.get_store_creator_details("creator")
result = await db.get_store_creator("creator")
# Verify results
assert result.username == "creator"
@@ -218,61 +154,110 @@ async def test_get_store_creator_details(mocker):
@pytest.mark.asyncio(loop_scope="session")
async def test_create_store_submission(mocker):
# Mock data
now = datetime.now()
# Mock agent graph (with no pending submissions) and user with profile
mock_profile = prisma.models.Profile(
id="profile-id",
userId="user-id",
name="Test User",
username="testuser",
description="Test",
isFeatured=False,
links=[],
createdAt=now,
updatedAt=now,
)
mock_user = prisma.models.User(
id="user-id",
email="test@example.com",
createdAt=now,
updatedAt=now,
Profile=[mock_profile],
emailVerified=True,
metadata="{}", # type: ignore[reportArgumentType]
integrations="",
maxEmailsPerDay=1,
notifyOnAgentRun=True,
notifyOnZeroBalance=True,
notifyOnLowBalance=True,
notifyOnBlockExecutionFailed=True,
notifyOnContinuousAgentError=True,
notifyOnDailySummary=True,
notifyOnWeeklySummary=True,
notifyOnMonthlySummary=True,
notifyOnAgentApproved=True,
notifyOnAgentRejected=True,
timezone="Europe/Delft",
)
mock_agent = prisma.models.AgentGraph(
id="agent-id",
version=1,
userId="user-id",
createdAt=datetime.now(),
createdAt=now,
isActive=True,
StoreListingVersions=[],
User=mock_user,
)
mock_listing = prisma.models.StoreListing(
# Mock the created StoreListingVersion (returned by create)
mock_store_listing_obj = prisma.models.StoreListing(
id="listing-id",
createdAt=datetime.now(),
updatedAt=datetime.now(),
createdAt=now,
updatedAt=now,
isDeleted=False,
hasApprovedVersion=False,
slug="test-agent",
agentGraphId="agent-id",
agentGraphVersion=1,
owningUserId="user-id",
Versions=[
prisma.models.StoreListingVersion(
id="version-id",
agentGraphId="agent-id",
agentGraphVersion=1,
name="Test Agent",
description="Test description",
createdAt=datetime.now(),
updatedAt=datetime.now(),
subHeading="Test heading",
imageUrls=["image.jpg"],
categories=["test"],
isFeatured=False,
isDeleted=False,
version=1,
storeListingId="listing-id",
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
isAvailable=True,
)
],
useForOnboarding=False,
)
mock_version = prisma.models.StoreListingVersion(
id="version-id",
agentGraphId="agent-id",
agentGraphVersion=1,
name="Test Agent",
description="Test description",
createdAt=now,
updatedAt=now,
subHeading="",
imageUrls=[],
categories=[],
isFeatured=False,
isDeleted=False,
version=1,
storeListingId="listing-id",
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
isAvailable=True,
submittedAt=now,
StoreListing=mock_store_listing_obj,
)
# Mock prisma calls
mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
mock_store_listing = mocker.patch("prisma.models.StoreListing.prisma")
mock_store_listing.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_store_listing.return_value.create = mocker.AsyncMock(return_value=mock_listing)
# Mock transaction context manager
mock_tx = mocker.MagicMock()
mocker.patch(
"backend.api.features.store.db.transaction",
return_value=mocker.AsyncMock(
__aenter__=mocker.AsyncMock(return_value=mock_tx),
__aexit__=mocker.AsyncMock(return_value=False),
),
)
mock_sl = mocker.patch("prisma.models.StoreListing.prisma")
mock_sl.return_value.find_unique = mocker.AsyncMock(return_value=None)
mock_slv = mocker.patch("prisma.models.StoreListingVersion.prisma")
mock_slv.return_value.create = mocker.AsyncMock(return_value=mock_version)
# Call function
result = await db.create_store_submission(
user_id="user-id",
agent_id="agent-id",
agent_version=1,
graph_id="agent-id",
graph_version=1,
slug="test-agent",
name="Test Agent",
description="Test description",
@@ -281,11 +266,11 @@ async def test_create_store_submission(mocker):
# Verify results
assert result.name == "Test Agent"
assert result.description == "Test description"
assert result.store_listing_version_id == "version-id"
assert result.listing_version_id == "version-id"
# Verify mocks called correctly
mock_agent_graph.return_value.find_first.assert_called_once()
mock_store_listing.return_value.create.assert_called_once()
mock_slv.return_value.create.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
@@ -318,7 +303,6 @@ async def test_update_profile(mocker):
description="Test description",
links=["link1"],
avatar_url="avatar.jpg",
is_featured=False,
)
# Call function
@@ -389,7 +373,7 @@ async def test_get_store_agents_with_search_and_filters_parameterized():
creators=["creator1'; DROP TABLE Users; --", "creator2"],
category="AI'; DELETE FROM StoreAgent; --",
featured=True,
sorted_by="rating",
sorted_by=db.StoreAgentsSortOptions.RATING,
page=1,
page_size=20,
)

View File

@@ -57,12 +57,6 @@ class StoreError(ValueError):
pass
class AgentNotFoundError(NotFoundError):
"""Raised when an agent is not found"""
pass
class CreatorNotFoundError(NotFoundError):
"""Raised when a creator is not found"""

View File

@@ -568,7 +568,7 @@ async def hybrid_search(
SELECT uce."contentId" as "storeListingVersionId"
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON uce."contentId" = sa."storeListingVersionId"
ON uce."contentId" = sa.listing_version_id
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
AND uce."userId" IS NULL
AND uce.search @@ plainto_tsquery('english', {query_param})
@@ -582,7 +582,7 @@ async def hybrid_search(
SELECT uce."contentId", uce.embedding
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON uce."contentId" = sa."storeListingVersionId"
ON uce."contentId" = sa.listing_version_id
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
AND uce."userId" IS NULL
AND {where_clause}
@@ -605,7 +605,7 @@ async def hybrid_search(
sa.featured,
sa.is_available,
sa.updated_at,
sa."agentGraphId",
sa.graph_id,
-- Searchable text for BM25 reranking
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
-- Semantic score
@@ -627,9 +627,9 @@ async def hybrid_search(
sa.runs as popularity_raw
FROM candidates c
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON c."storeListingVersionId" = sa."storeListingVersionId"
ON c."storeListingVersionId" = sa.listing_version_id
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
ON sa."storeListingVersionId" = uce."contentId"
ON sa.listing_version_id = uce."contentId"
AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
),
max_vals AS (
@@ -665,7 +665,7 @@ async def hybrid_search(
featured,
is_available,
updated_at,
"agentGraphId",
graph_id,
searchable_text,
semantic_score,
lexical_score,

View File

@@ -1,11 +1,14 @@
import datetime
from typing import List
from typing import TYPE_CHECKING, List, Self
import prisma.enums
import pydantic
from backend.util.models import Pagination
if TYPE_CHECKING:
import prisma.models
class ChangelogEntry(pydantic.BaseModel):
version: str
@@ -13,9 +16,9 @@ class ChangelogEntry(pydantic.BaseModel):
date: datetime.datetime
class MyAgent(pydantic.BaseModel):
agent_id: str
agent_version: int
class MyUnpublishedAgent(pydantic.BaseModel):
graph_id: str
graph_version: int
agent_name: str
agent_image: str | None = None
description: str
@@ -23,8 +26,8 @@ class MyAgent(pydantic.BaseModel):
recommended_schedule_cron: str | None = None
class MyAgentsResponse(pydantic.BaseModel):
agents: list[MyAgent]
class MyUnpublishedAgentsResponse(pydantic.BaseModel):
agents: list[MyUnpublishedAgent]
pagination: Pagination
@@ -40,6 +43,21 @@ class StoreAgent(pydantic.BaseModel):
rating: float
agent_graph_id: str
@classmethod
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgent":
return cls(
slug=agent.slug,
agent_name=agent.agent_name,
agent_image=agent.agent_image[0] if agent.agent_image else "",
creator=agent.creator_username or "Needs Profile",
creator_avatar=agent.creator_avatar or "",
sub_heading=agent.sub_heading,
description=agent.description,
runs=agent.runs,
rating=agent.rating,
agent_graph_id=agent.graph_id,
)
class StoreAgentsResponse(pydantic.BaseModel):
agents: list[StoreAgent]
@@ -62,81 +80,192 @@ class StoreAgentDetails(pydantic.BaseModel):
runs: int
rating: float
versions: list[str]
agentGraphVersions: list[str]
agentGraphId: str
graph_id: str
graph_versions: list[str]
last_updated: datetime.datetime
recommended_schedule_cron: str | None = None
active_version_id: str | None = None
has_approved_version: bool = False
active_version_id: str
has_approved_version: bool
# Optional changelog data when include_changelog=True
changelog: list[ChangelogEntry] | None = None
class Creator(pydantic.BaseModel):
name: str
username: str
description: str
avatar_url: str
num_agents: int
agent_rating: float
agent_runs: int
is_featured: bool
class CreatorsResponse(pydantic.BaseModel):
creators: List[Creator]
pagination: Pagination
class CreatorDetails(pydantic.BaseModel):
name: str
username: str
description: str
links: list[str]
avatar_url: str
agent_rating: float
agent_runs: int
top_categories: list[str]
@classmethod
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgentDetails":
return cls(
store_listing_version_id=agent.listing_version_id,
slug=agent.slug,
agent_name=agent.agent_name,
agent_video=agent.agent_video or "",
agent_output_demo=agent.agent_output_demo or "",
agent_image=agent.agent_image,
creator=agent.creator_username or "",
creator_avatar=agent.creator_avatar or "",
sub_heading=agent.sub_heading,
description=agent.description,
categories=agent.categories,
runs=agent.runs,
rating=agent.rating,
versions=agent.versions,
graph_id=agent.graph_id,
graph_versions=agent.graph_versions,
last_updated=agent.updated_at,
recommended_schedule_cron=agent.recommended_schedule_cron,
active_version_id=agent.listing_version_id,
has_approved_version=True, # StoreAgent view only has approved agents
)
class Profile(pydantic.BaseModel):
name: str
"""Marketplace user profile (only attributes that the user can update)"""
username: str
name: str
description: str
avatar_url: str | None
links: list[str]
avatar_url: str
is_featured: bool = False
class ProfileDetails(Profile):
"""Marketplace user profile (including read-only fields)"""
is_featured: bool
@classmethod
def from_db(cls, profile: "prisma.models.Profile") -> "ProfileDetails":
return cls(
name=profile.name,
username=profile.username,
avatar_url=profile.avatarUrl,
description=profile.description,
links=profile.links,
is_featured=profile.isFeatured,
)
class CreatorDetails(ProfileDetails):
"""Marketplace creator profile details, including aggregated stats"""
num_agents: int
agent_runs: int
agent_rating: float
top_categories: list[str]
@classmethod
def from_db(cls, creator: "prisma.models.Creator") -> "CreatorDetails": # type: ignore[override]
return cls(
name=creator.name,
username=creator.username,
avatar_url=creator.avatar_url,
description=creator.description,
links=creator.links,
is_featured=creator.is_featured,
num_agents=creator.num_agents,
agent_runs=creator.agent_runs,
agent_rating=creator.agent_rating,
top_categories=creator.top_categories,
)
class CreatorsResponse(pydantic.BaseModel):
creators: List[CreatorDetails]
pagination: Pagination
class StoreSubmission(pydantic.BaseModel):
# From StoreListing:
listing_id: str
agent_id: str
agent_version: int
user_id: str
slug: str
# From StoreListingVersion:
listing_version_id: str
listing_version: int
graph_id: str
graph_version: int
name: str
sub_heading: str
slug: str
description: str
instructions: str | None = None
instructions: str | None
categories: list[str]
image_urls: list[str]
date_submitted: datetime.datetime
status: prisma.enums.SubmissionStatus
runs: int
rating: float
store_listing_version_id: str | None = None
version: int | None = None # Actual version number from the database
video_url: str | None
agent_output_demo_url: str | None
submitted_at: datetime.datetime | None
changes_summary: str | None
status: prisma.enums.SubmissionStatus
reviewed_at: datetime.datetime | None = None
reviewer_id: str | None = None
review_comments: str | None = None # External comments visible to creator
internal_comments: str | None = None # Private notes for admin use only
reviewed_at: datetime.datetime | None = None
changes_summary: str | None = None
# Additional fields for editing
video_url: str | None = None
agent_output_demo_url: str | None = None
categories: list[str] = []
# Aggregated from AgentGraphExecutions and StoreListingReviews:
run_count: int = 0
review_count: int = 0
review_avg_rating: float = 0.0
@classmethod
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
"""Construct from the StoreSubmission Prisma view."""
return cls(
listing_id=_sub.listing_id,
user_id=_sub.user_id,
slug=_sub.slug,
listing_version_id=_sub.listing_version_id,
listing_version=_sub.listing_version,
graph_id=_sub.graph_id,
graph_version=_sub.graph_version,
name=_sub.name,
sub_heading=_sub.sub_heading,
description=_sub.description,
instructions=_sub.instructions,
categories=_sub.categories,
image_urls=_sub.image_urls,
video_url=_sub.video_url,
agent_output_demo_url=_sub.agent_output_demo_url,
submitted_at=_sub.submitted_at,
changes_summary=_sub.changes_summary,
status=_sub.status,
reviewed_at=_sub.reviewed_at,
reviewer_id=_sub.reviewer_id,
review_comments=_sub.review_comments,
run_count=_sub.run_count,
review_count=_sub.review_count,
review_avg_rating=_sub.review_avg_rating,
)
@classmethod
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
"""
Construct from the StoreListingVersion Prisma model (with StoreListing included)
"""
if not (_l := _lv.StoreListing):
raise ValueError("StoreListingVersion must have included StoreListing")
return cls(
listing_id=_l.id,
user_id=_l.owningUserId,
slug=_l.slug,
listing_version_id=_lv.id,
listing_version=_lv.version,
graph_id=_lv.agentGraphId,
graph_version=_lv.agentGraphVersion,
name=_lv.name,
sub_heading=_lv.subHeading,
description=_lv.description,
instructions=_lv.instructions,
categories=_lv.categories,
image_urls=_lv.imageUrls,
video_url=_lv.videoUrl,
agent_output_demo_url=_lv.agentOutputDemoUrl,
submitted_at=_lv.submittedAt,
changes_summary=_lv.changesSummary,
status=_lv.submissionStatus,
reviewed_at=_lv.reviewedAt,
reviewer_id=_lv.reviewerId,
review_comments=_lv.reviewComments,
)
class StoreSubmissionsResponse(pydantic.BaseModel):
@@ -144,33 +273,12 @@ class StoreSubmissionsResponse(pydantic.BaseModel):
pagination: Pagination
class StoreListingWithVersions(pydantic.BaseModel):
"""A store listing with its version history"""
listing_id: str
slug: str
agent_id: str
agent_version: int
active_version_id: str | None = None
has_approved_version: bool = False
creator_email: str | None = None
latest_version: StoreSubmission | None = None
versions: list[StoreSubmission] = []
class StoreListingsWithVersionsResponse(pydantic.BaseModel):
"""Response model for listings with version history"""
listings: list[StoreListingWithVersions]
pagination: Pagination
class StoreSubmissionRequest(pydantic.BaseModel):
agent_id: str = pydantic.Field(
..., min_length=1, description="Agent ID cannot be empty"
graph_id: str = pydantic.Field(
..., min_length=1, description="Graph ID cannot be empty"
)
agent_version: int = pydantic.Field(
..., gt=0, description="Agent version must be greater than 0"
graph_version: int = pydantic.Field(
..., gt=0, description="Graph version must be greater than 0"
)
slug: str
name: str
@@ -198,12 +306,42 @@ class StoreSubmissionEditRequest(pydantic.BaseModel):
recommended_schedule_cron: str | None = None
class ProfileDetails(pydantic.BaseModel):
name: str
username: str
description: str
links: list[str]
avatar_url: str | None = None
class StoreSubmissionAdminView(StoreSubmission):
internal_comments: str | None # Private admin notes
@classmethod
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
return cls(
**StoreSubmission.from_db(_sub).model_dump(),
internal_comments=_sub.internal_comments,
)
@classmethod
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
return cls(
**StoreSubmission.from_listing_version(_lv).model_dump(),
internal_comments=_lv.internalComments,
)
class StoreListingWithVersionsAdminView(pydantic.BaseModel):
"""A store listing with its version history"""
listing_id: str
graph_id: str
slug: str
active_listing_version_id: str | None = None
has_approved_version: bool = False
creator_email: str | None = None
latest_version: StoreSubmissionAdminView | None = None
versions: list[StoreSubmissionAdminView] = []
class StoreListingsWithVersionsAdminViewResponse(pydantic.BaseModel):
"""Response model for listings with version history"""
listings: list[StoreListingWithVersionsAdminView]
pagination: Pagination
class StoreReview(pydantic.BaseModel):

View File

@@ -1,203 +0,0 @@
import datetime
import prisma.enums
from . import model as store_model
def test_pagination():
pagination = store_model.Pagination(
total_items=100, total_pages=5, current_page=2, page_size=20
)
assert pagination.total_items == 100
assert pagination.total_pages == 5
assert pagination.current_page == 2
assert pagination.page_size == 20
def test_store_agent():
agent = store_model.StoreAgent(
slug="test-agent",
agent_name="Test Agent",
agent_image="test.jpg",
creator="creator1",
creator_avatar="avatar.jpg",
sub_heading="Test subheading",
description="Test description",
runs=50,
rating=4.5,
agent_graph_id="test-graph-id",
)
assert agent.slug == "test-agent"
assert agent.agent_name == "Test Agent"
assert agent.runs == 50
assert agent.rating == 4.5
assert agent.agent_graph_id == "test-graph-id"
def test_store_agents_response():
response = store_model.StoreAgentsResponse(
agents=[
store_model.StoreAgent(
slug="test-agent",
agent_name="Test Agent",
agent_image="test.jpg",
creator="creator1",
creator_avatar="avatar.jpg",
sub_heading="Test subheading",
description="Test description",
runs=50,
rating=4.5,
agent_graph_id="test-graph-id",
)
],
pagination=store_model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=20
),
)
assert len(response.agents) == 1
assert response.pagination.total_items == 1
def test_store_agent_details():
details = store_model.StoreAgentDetails(
store_listing_version_id="version123",
slug="test-agent",
agent_name="Test Agent",
agent_video="video.mp4",
agent_output_demo="demo.mp4",
agent_image=["image1.jpg", "image2.jpg"],
creator="creator1",
creator_avatar="avatar.jpg",
sub_heading="Test subheading",
description="Test description",
categories=["cat1", "cat2"],
runs=50,
rating=4.5,
versions=["1.0", "2.0"],
agentGraphVersions=["1", "2"],
agentGraphId="test-graph-id",
last_updated=datetime.datetime.now(),
)
assert details.slug == "test-agent"
assert len(details.agent_image) == 2
assert len(details.categories) == 2
assert len(details.versions) == 2
def test_creator():
creator = store_model.Creator(
agent_rating=4.8,
agent_runs=1000,
name="Test Creator",
username="creator1",
description="Test description",
avatar_url="avatar.jpg",
num_agents=5,
is_featured=False,
)
assert creator.name == "Test Creator"
assert creator.num_agents == 5
def test_creators_response():
response = store_model.CreatorsResponse(
creators=[
store_model.Creator(
agent_rating=4.8,
agent_runs=1000,
name="Test Creator",
username="creator1",
description="Test description",
avatar_url="avatar.jpg",
num_agents=5,
is_featured=False,
)
],
pagination=store_model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=20
),
)
assert len(response.creators) == 1
assert response.pagination.total_items == 1
def test_creator_details():
details = store_model.CreatorDetails(
name="Test Creator",
username="creator1",
description="Test description",
links=["link1.com", "link2.com"],
avatar_url="avatar.jpg",
agent_rating=4.8,
agent_runs=1000,
top_categories=["cat1", "cat2"],
)
assert details.name == "Test Creator"
assert len(details.links) == 2
assert details.agent_rating == 4.8
assert len(details.top_categories) == 2
def test_store_submission():
submission = store_model.StoreSubmission(
listing_id="listing123",
agent_id="agent123",
agent_version=1,
sub_heading="Test subheading",
name="Test Agent",
slug="test-agent",
description="Test description",
image_urls=["image1.jpg", "image2.jpg"],
date_submitted=datetime.datetime(2023, 1, 1),
status=prisma.enums.SubmissionStatus.PENDING,
runs=50,
rating=4.5,
)
assert submission.name == "Test Agent"
assert len(submission.image_urls) == 2
assert submission.status == prisma.enums.SubmissionStatus.PENDING
def test_store_submissions_response():
response = store_model.StoreSubmissionsResponse(
submissions=[
store_model.StoreSubmission(
listing_id="listing123",
agent_id="agent123",
agent_version=1,
sub_heading="Test subheading",
name="Test Agent",
slug="test-agent",
description="Test description",
image_urls=["image1.jpg"],
date_submitted=datetime.datetime(2023, 1, 1),
status=prisma.enums.SubmissionStatus.PENDING,
runs=50,
rating=4.5,
)
],
pagination=store_model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=20
),
)
assert len(response.submissions) == 1
assert response.pagination.total_items == 1
def test_store_submission_request():
request = store_model.StoreSubmissionRequest(
agent_id="agent123",
agent_version=1,
slug="test-agent",
name="Test Agent",
sub_heading="Test subheading",
video_url="video.mp4",
image_urls=["image1.jpg", "image2.jpg"],
description="Test description",
categories=["cat1", "cat2"],
)
assert request.agent_id == "agent123"
assert request.agent_version == 1
assert len(request.image_urls) == 2
assert len(request.categories) == 2

View File

@@ -1,16 +1,17 @@
import logging
import tempfile
import typing
import urllib.parse
from typing import Literal
import autogpt_libs.auth
import fastapi
import fastapi.responses
import prisma.enums
from fastapi import Query, Security
from pydantic import BaseModel
import backend.data.graph
import backend.util.json
from backend.util.exceptions import NotFoundError
from backend.util.models import Pagination
from . import cache as store_cache
@@ -34,22 +35,15 @@ router = fastapi.APIRouter()
"/profile",
summary="Get user profile",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.ProfileDetails,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def get_profile(
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Get the profile details for the authenticated user.
Cached for 1 hour per user.
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.ProfileDetails:
"""Get the profile details for the authenticated user."""
profile = await store_db.get_user_profile(user_id)
if profile is None:
return fastapi.responses.JSONResponse(
status_code=404,
content={"detail": "Profile not found"},
)
raise NotFoundError("User does not have a profile yet")
return profile
@@ -57,98 +51,17 @@ async def get_profile(
"/profile",
summary="Update user profile",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.CreatorDetails,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def update_or_create_profile(
profile: store_model.Profile,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Update the store profile for the authenticated user.
Args:
profile (Profile): The updated profile details
user_id (str): ID of the authenticated user
Returns:
CreatorDetails: The updated profile
Raises:
HTTPException: If there is an error updating the profile
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.ProfileDetails:
"""Update the store profile for the authenticated user."""
updated_profile = await store_db.update_profile(user_id=user_id, profile=profile)
return updated_profile
##############################################
############### Agent Endpoints ##############
##############################################
@router.get(
"/agents",
summary="List store agents",
tags=["store", "public"],
response_model=store_model.StoreAgentsResponse,
)
async def get_agents(
featured: bool = False,
creator: str | None = None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
search_query: str | None = None,
category: str | None = None,
page: int = 1,
page_size: int = 20,
):
"""
Get a paginated list of agents from the store with optional filtering and sorting.
Args:
featured (bool, optional): Filter to only show featured agents. Defaults to False.
creator (str | None, optional): Filter agents by creator username. Defaults to None.
sorted_by (str | None, optional): Sort agents by "runs" or "rating". Defaults to None.
search_query (str | None, optional): Search agents by name, subheading and description. Defaults to None.
category (str | None, optional): Filter agents by category. Defaults to None.
page (int, optional): Page number for pagination. Defaults to 1.
page_size (int, optional): Number of agents per page. Defaults to 20.
Returns:
StoreAgentsResponse: Paginated list of agents matching the filters
Raises:
HTTPException: If page or page_size are less than 1
Used for:
- Home Page Featured Agents
- Home Page Top Agents
- Search Results
- Agent Details - Other Agents By Creator
- Agent Details - Similar Agents
- Creator Details - Agents By Creator
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
agents = await store_cache._get_cached_store_agents(
featured=featured,
creator=creator,
sorted_by=sorted_by,
search_query=search_query,
category=category,
page=page,
page_size=page_size,
)
return agents
##############################################
############### Search Endpoints #############
##############################################
@@ -158,60 +71,30 @@ async def get_agents(
"/search",
summary="Unified search across all content types",
tags=["store", "public"],
response_model=store_model.UnifiedSearchResponse,
)
async def unified_search(
query: str,
content_types: list[str] | None = fastapi.Query(
content_types: list[prisma.enums.ContentType] | None = Query(
default=None,
description="Content types to search: STORE_AGENT, BLOCK, DOCUMENTATION. If not specified, searches all.",
description="Content types to search. If not specified, searches all.",
),
page: int = 1,
page_size: int = 20,
user_id: str | None = fastapi.Security(
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
user_id: str | None = Security(
autogpt_libs.auth.get_optional_user_id, use_cache=False
),
):
) -> store_model.UnifiedSearchResponse:
"""
Search across all content types (store agents, blocks, documentation) using hybrid search.
Search across all content types (marketplace agents, blocks, documentation)
using hybrid search.
Combines semantic (embedding-based) and lexical (text-based) search for best results.
Args:
query: The search query string
content_types: Optional list of content types to filter by (STORE_AGENT, BLOCK, DOCUMENTATION)
page: Page number for pagination (default 1)
page_size: Number of results per page (default 20)
user_id: Optional authenticated user ID (for user-scoped content in future)
Returns:
UnifiedSearchResponse: Paginated list of search results with relevance scores
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
# Convert string content types to enum
content_type_enums: list[prisma.enums.ContentType] | None = None
if content_types:
try:
content_type_enums = [prisma.enums.ContentType(ct) for ct in content_types]
except ValueError as e:
raise fastapi.HTTPException(
status_code=422,
detail=f"Invalid content type. Valid values: STORE_AGENT, BLOCK, DOCUMENTATION. Error: {e}",
)
# Perform unified hybrid search
results, total = await store_hybrid_search.unified_hybrid_search(
query=query,
content_types=content_type_enums,
content_types=content_types,
user_id=user_id,
page=page,
page_size=page_size,
@@ -245,22 +128,69 @@ async def unified_search(
)
##############################################
############### Agent Endpoints ##############
##############################################
@router.get(
"/agents",
summary="List store agents",
tags=["store", "public"],
)
async def get_agents(
featured: bool = Query(
default=False, description="Filter to only show featured agents"
),
creator: str | None = Query(
default=None, description="Filter agents by creator username"
),
category: str | None = Query(default=None, description="Filter agents by category"),
search_query: str | None = Query(
default=None, description="Literal + semantic search on names and descriptions"
),
sorted_by: store_db.StoreAgentsSortOptions | None = Query(
default=None,
description="Property to sort results by. Ignored if search_query is provided.",
),
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
) -> store_model.StoreAgentsResponse:
"""
Get a paginated list of agents from the marketplace,
with optional filtering and sorting.
Used for:
- Home Page Featured Agents
- Home Page Top Agents
- Search Results
- Agent Details - Other Agents By Creator
- Agent Details - Similar Agents
- Creator Details - Agents By Creator
"""
agents = await store_cache._get_cached_store_agents(
featured=featured,
creator=creator,
sorted_by=sorted_by,
search_query=search_query,
category=category,
page=page,
page_size=page_size,
)
return agents
@router.get(
"/agents/{username}/{agent_name}",
summary="Get specific agent",
tags=["store", "public"],
response_model=store_model.StoreAgentDetails,
)
async def get_agent(
async def get_agent_by_name(
username: str,
agent_name: str,
include_changelog: bool = fastapi.Query(default=False),
):
"""
This is only used on the AgentDetails Page.
It returns the store listing agents details.
"""
include_changelog: bool = Query(default=False),
) -> store_model.StoreAgentDetails:
"""Get details of a marketplace agent"""
username = urllib.parse.unquote(username).lower()
# URL decode the agent name since it comes from the URL path
agent_name = urllib.parse.unquote(agent_name).lower()
@@ -270,76 +200,82 @@ async def get_agent(
return agent
@router.get(
"/graph/{store_listing_version_id}",
summary="Get agent graph",
tags=["store"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
async def get_graph_meta_by_store_listing_version_id(
store_listing_version_id: str,
) -> backend.data.graph.GraphModelWithoutNodes:
"""
Get Agent Graph from Store Listing Version ID.
"""
graph = await store_db.get_available_graph(store_listing_version_id)
return graph
@router.get(
"/agents/{store_listing_version_id}",
summary="Get agent by version",
tags=["store"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreAgentDetails,
)
async def get_store_agent(store_listing_version_id: str):
"""
Get Store Agent Details from Store Listing Version ID.
"""
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
return agent
@router.post(
"/agents/{username}/{agent_name}/review",
summary="Create agent review",
tags=["store"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreReview,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def create_review(
async def post_user_review_for_agent(
username: str,
agent_name: str,
review: store_model.StoreReviewCreate,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Create a review for a store agent.
Args:
username: Creator's username
agent_name: Name/slug of the agent
review: Review details including score and optional comments
user_id: ID of authenticated user creating the review
Returns:
The created review
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.StoreReview:
"""Post a user review on a marketplace agent listing"""
username = urllib.parse.unquote(username).lower()
agent_name = urllib.parse.unquote(agent_name).lower()
# Create the review
created_review = await store_db.create_store_review(
user_id=user_id,
store_listing_version_id=review.store_listing_version_id,
score=review.score,
comments=review.comments,
)
return created_review
@router.get(
"/listings/versions/{store_listing_version_id}",
summary="Get agent by version",
tags=["store"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def get_agent_by_listing_version(
store_listing_version_id: str,
) -> store_model.StoreAgentDetails:
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
return agent
@router.get(
"/listings/versions/{store_listing_version_id}/graph",
summary="Get agent graph",
tags=["store"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def get_graph_meta_by_store_listing_version_id(
store_listing_version_id: str,
) -> backend.data.graph.GraphModelWithoutNodes:
"""Get outline of graph belonging to a specific marketplace listing version"""
graph = await store_db.get_available_graph(store_listing_version_id)
return graph
@router.get(
"/listings/versions/{store_listing_version_id}/graph/download",
summary="Download agent file",
tags=["store", "public"],
)
async def download_agent_file(
store_listing_version_id: str,
) -> fastapi.responses.FileResponse:
"""Download agent graph file for a specific marketplace listing version"""
graph_data = await store_db.get_agent(store_listing_version_id)
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
# Sending graph as a stream (similar to marketplace v1)
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as tmp_file:
tmp_file.write(backend.util.json.dumps(graph_data))
tmp_file.flush()
return fastapi.responses.FileResponse(
tmp_file.name, filename=file_name, media_type="application/json"
)
##############################################
############# Creator Endpoints #############
##############################################
@@ -349,37 +285,19 @@ async def create_review(
"/creators",
summary="List store creators",
tags=["store", "public"],
response_model=store_model.CreatorsResponse,
)
async def get_creators(
featured: bool = False,
search_query: str | None = None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
page: int = 1,
page_size: int = 20,
):
"""
This is needed for:
- Home Page Featured Creators
- Search Results Page
---
To support this functionality we need:
- featured: bool - to limit the list to just featured agents
- search_query: str - vector search based on the creators profile description.
- sorted_by: [agent_rating, agent_runs] -
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
featured: bool = Query(
default=False, description="Filter to only show featured creators"
),
search_query: str | None = Query(
default=None, description="Literal + semantic search on names and descriptions"
),
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
) -> store_model.CreatorsResponse:
"""List or search marketplace creators"""
creators = await store_cache._get_cached_store_creators(
featured=featured,
search_query=search_query,
@@ -391,18 +309,12 @@ async def get_creators(
@router.get(
"/creator/{username}",
"/creators/{username}",
summary="Get creator details",
tags=["store", "public"],
response_model=store_model.CreatorDetails,
)
async def get_creator(
username: str,
):
"""
Get the details of a creator.
- Creator Details Page
"""
async def get_creator(username: str) -> store_model.CreatorDetails:
"""Get details on a marketplace creator"""
username = urllib.parse.unquote(username).lower()
creator = await store_cache._get_cached_creator_details(username=username)
return creator
@@ -414,20 +326,17 @@ async def get_creator(
@router.get(
"/myagents",
"/my-unpublished-agents",
summary="Get my agents",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.MyAgentsResponse,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def get_my_agents(
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
page: typing.Annotated[int, fastapi.Query(ge=1)] = 1,
page_size: typing.Annotated[int, fastapi.Query(ge=1)] = 20,
):
"""
Get user's own agents.
"""
async def get_my_unpublished_agents(
user_id: str = Security(autogpt_libs.auth.get_user_id),
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
) -> store_model.MyUnpublishedAgentsResponse:
"""List the authenticated user's unpublished agents"""
agents = await store_db.get_my_agents(user_id, page=page, page_size=page_size)
return agents
@@ -436,28 +345,17 @@ async def get_my_agents(
"/submissions/{submission_id}",
summary="Delete store submission",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=bool,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def delete_submission(
submission_id: str,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Delete a store listing submission.
Args:
user_id (str): ID of the authenticated user
submission_id (str): ID of the submission to be deleted
Returns:
bool: True if the submission was successfully deleted, False otherwise
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> bool:
"""Delete a marketplace listing submission"""
result = await store_db.delete_store_submission(
user_id=user_id,
submission_id=submission_id,
)
return result
@@ -465,37 +363,14 @@ async def delete_submission(
"/submissions",
summary="List my submissions",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreSubmissionsResponse,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def get_submissions(
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
page: int = 1,
page_size: int = 20,
):
"""
Get a paginated list of store submissions for the authenticated user.
Args:
user_id (str): ID of the authenticated user
page (int, optional): Page number for pagination. Defaults to 1.
page_size (int, optional): Number of submissions per page. Defaults to 20.
Returns:
StoreListingsResponse: Paginated list of store submissions
Raises:
HTTPException: If page or page_size are less than 1
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
user_id: str = Security(autogpt_libs.auth.get_user_id),
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
) -> store_model.StoreSubmissionsResponse:
"""List the authenticated user's marketplace listing submissions"""
listings = await store_db.get_store_submissions(
user_id=user_id,
page=page,
@@ -508,30 +383,17 @@ async def get_submissions(
"/submissions",
summary="Create store submission",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreSubmission,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def create_submission(
submission_request: store_model.StoreSubmissionRequest,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Create a new store listing submission.
Args:
submission_request (StoreSubmissionRequest): The submission details
user_id (str): ID of the authenticated user submitting the listing
Returns:
StoreSubmission: The created store submission
Raises:
HTTPException: If there is an error creating the submission
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.StoreSubmission:
"""Submit a new marketplace listing for review"""
result = await store_db.create_store_submission(
user_id=user_id,
agent_id=submission_request.agent_id,
agent_version=submission_request.agent_version,
graph_id=submission_request.graph_id,
graph_version=submission_request.graph_version,
slug=submission_request.slug,
name=submission_request.name,
video_url=submission_request.video_url,
@@ -544,7 +406,6 @@ async def create_submission(
changes_summary=submission_request.changes_summary or "Initial Submission",
recommended_schedule_cron=submission_request.recommended_schedule_cron,
)
return result
@@ -552,28 +413,14 @@ async def create_submission(
"/submissions/{store_listing_version_id}",
summary="Edit store submission",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreSubmission,
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def edit_submission(
store_listing_version_id: str,
submission_request: store_model.StoreSubmissionEditRequest,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Edit an existing store listing submission.
Args:
store_listing_version_id (str): ID of the store listing version to edit
submission_request (StoreSubmissionRequest): The updated submission details
user_id (str): ID of the authenticated user editing the listing
Returns:
StoreSubmission: The updated store submission
Raises:
HTTPException: If there is an error editing the submission
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.StoreSubmission:
"""Update a pending marketplace listing submission"""
result = await store_db.edit_store_submission(
user_id=user_id,
store_listing_version_id=store_listing_version_id,
@@ -588,7 +435,6 @@ async def edit_submission(
changes_summary=submission_request.changes_summary,
recommended_schedule_cron=submission_request.recommended_schedule_cron,
)
return result
@@ -596,115 +442,61 @@ async def edit_submission(
"/submissions/media",
summary="Upload submission media",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def upload_submission_media(
file: fastapi.UploadFile,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Upload media (images/videos) for a store listing submission.
Args:
file (UploadFile): The media file to upload
user_id (str): ID of the authenticated user uploading the media
Returns:
str: URL of the uploaded media file
Raises:
HTTPException: If there is an error uploading the media
"""
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> str:
"""Upload media for a marketplace listing submission"""
media_url = await store_media.upload_media(user_id=user_id, file=file)
return media_url
class ImageURLResponse(BaseModel):
image_url: str
@router.post(
"/submissions/generate_image",
summary="Generate submission image",
tags=["store", "private"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def generate_image(
agent_id: str,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
) -> fastapi.responses.Response:
graph_id: str,
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> ImageURLResponse:
"""
Generate an image for a store listing submission.
Args:
agent_id (str): ID of the agent to generate an image for
user_id (str): ID of the authenticated user
Returns:
JSONResponse: JSON containing the URL of the generated image
Generate an image for a marketplace listing submission based on the properties
of a given graph.
"""
agent = await backend.data.graph.get_graph(
graph_id=agent_id, version=None, user_id=user_id
graph = await backend.data.graph.get_graph(
graph_id=graph_id, version=None, user_id=user_id
)
if not agent:
raise fastapi.HTTPException(
status_code=404, detail=f"Agent with ID {agent_id} not found"
)
if not graph:
raise NotFoundError(f"Agent graph #{graph_id} not found")
# Use .jpeg here since we are generating JPEG images
filename = f"agent_{agent_id}.jpeg"
filename = f"agent_{graph_id}.jpeg"
existing_url = await store_media.check_media_exists(user_id, filename)
if existing_url:
logger.info(f"Using existing image for agent {agent_id}")
return fastapi.responses.JSONResponse(content={"image_url": existing_url})
logger.info(f"Using existing image for agent graph {graph_id}")
return ImageURLResponse(image_url=existing_url)
# Generate agent image as JPEG
image = await store_image_gen.generate_agent_image(agent=agent)
image = await store_image_gen.generate_agent_image(agent=graph)
# Create UploadFile with the correct filename and content_type
image_file = fastapi.UploadFile(
file=image,
filename=filename,
)
image_url = await store_media.upload_media(
user_id=user_id, file=image_file, use_file_name=True
)
return fastapi.responses.JSONResponse(content={"image_url": image_url})
@router.get(
"/download/agents/{store_listing_version_id}",
summary="Download agent file",
tags=["store", "public"],
)
async def download_agent_file(
store_listing_version_id: str = fastapi.Path(
..., description="The ID of the agent to download"
),
) -> fastapi.responses.FileResponse:
"""
Download the agent file by streaming its content.
Args:
store_listing_version_id (str): The ID of the agent to download
Returns:
StreamingResponse: A streaming response containing the agent's graph data.
Raises:
HTTPException: If the agent is not found or an unexpected error occurs.
"""
graph_data = await store_db.get_agent(store_listing_version_id)
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
# Sending graph as a stream (similar to marketplace v1)
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as tmp_file:
tmp_file.write(backend.util.json.dumps(graph_data))
tmp_file.flush()
return fastapi.responses.FileResponse(
tmp_file.name, filename=file_name, media_type="application/json"
)
return ImageURLResponse(image_url=image_url)
##############################################

View File

@@ -8,6 +8,8 @@ import pytest
import pytest_mock
from pytest_snapshot.plugin import Snapshot
from backend.api.features.store.db import StoreAgentsSortOptions
from . import model as store_model
from . import routes as store_routes
@@ -196,7 +198,7 @@ def test_get_agents_sorted(
mock_db_call.assert_called_once_with(
featured=False,
creators=None,
sorted_by="runs",
sorted_by=StoreAgentsSortOptions.RUNS,
search_query=None,
category=None,
page=1,
@@ -380,9 +382,11 @@ def test_get_agent_details(
runs=100,
rating=4.5,
versions=["1.0.0", "1.1.0"],
agentGraphVersions=["1", "2"],
agentGraphId="test-graph-id",
graph_versions=["1", "2"],
graph_id="test-graph-id",
last_updated=FIXED_NOW,
active_version_id="test-version-id",
has_approved_version=True,
)
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agent_details")
mock_db_call.return_value = mocked_value
@@ -435,15 +439,17 @@ def test_get_creators_pagination(
) -> None:
mocked_value = store_model.CreatorsResponse(
creators=[
store_model.Creator(
store_model.CreatorDetails(
name=f"Creator {i}",
username=f"creator{i}",
description=f"Creator {i} description",
avatar_url=f"avatar{i}.jpg",
num_agents=1,
agent_rating=4.5,
agent_runs=100,
description=f"Creator {i} description",
links=[f"user{i}.link.com"],
is_featured=False,
num_agents=1,
agent_runs=100,
agent_rating=4.5,
top_categories=["cat1", "cat2", "cat3"],
)
for i in range(5)
],
@@ -496,19 +502,19 @@ def test_get_creator_details(
mocked_value = store_model.CreatorDetails(
name="Test User",
username="creator1",
avatar_url="avatar.jpg",
description="Test creator description",
links=["link1.com", "link2.com"],
avatar_url="avatar.jpg",
agent_rating=4.8,
is_featured=True,
num_agents=5,
agent_runs=1000,
agent_rating=4.8,
top_categories=["category1", "category2"],
)
mock_db_call = mocker.patch(
"backend.api.features.store.db.get_store_creator_details"
)
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_creator")
mock_db_call.return_value = mocked_value
response = client.get("/creator/creator1")
response = client.get("/creators/creator1")
assert response.status_code == 200
data = store_model.CreatorDetails.model_validate(response.json())
@@ -528,19 +534,26 @@ def test_get_submissions_success(
submissions=[
store_model.StoreSubmission(
listing_id="test-listing-id",
name="Test Agent",
description="Test agent description",
image_urls=["test.jpg"],
date_submitted=FIXED_NOW,
status=prisma.enums.SubmissionStatus.APPROVED,
runs=50,
rating=4.2,
agent_id="test-agent-id",
agent_version=1,
sub_heading="Test agent subheading",
user_id="test-user-id",
slug="test-agent",
video_url="test.mp4",
listing_version_id="test-version-id",
listing_version=1,
graph_id="test-agent-id",
graph_version=1,
name="Test Agent",
sub_heading="Test agent subheading",
description="Test agent description",
instructions="Click the button!",
categories=["test-category"],
image_urls=["test.jpg"],
video_url="test.mp4",
agent_output_demo_url="demo_video.mp4",
submitted_at=FIXED_NOW,
changes_summary="Initial Submission",
status=prisma.enums.SubmissionStatus.APPROVED,
run_count=50,
review_count=5,
review_avg_rating=4.2,
)
],
pagination=store_model.Pagination(

View File

@@ -11,6 +11,7 @@ import pytest
from backend.util.models import Pagination
from . import cache as store_cache
from .db import StoreAgentsSortOptions
from .model import StoreAgent, StoreAgentsResponse
@@ -215,7 +216,7 @@ class TestCacheDeletion:
await store_cache._get_cached_store_agents(
featured=True,
creator="testuser",
sorted_by="rating",
sorted_by=StoreAgentsSortOptions.RATING,
search_query="AI assistant",
category="productivity",
page=2,
@@ -227,7 +228,7 @@ class TestCacheDeletion:
deleted = store_cache._get_cached_store_agents.cache_delete(
featured=True,
creator="testuser",
sorted_by="rating",
sorted_by=StoreAgentsSortOptions.RATING,
search_query="AI assistant",
category="productivity",
page=2,
@@ -239,7 +240,7 @@ class TestCacheDeletion:
deleted = store_cache._get_cached_store_agents.cache_delete(
featured=True,
creator="testuser",
sorted_by="rating",
sorted_by=StoreAgentsSortOptions.RATING,
search_query="AI assistant",
category="productivity",
page=2,

View File

@@ -449,7 +449,6 @@ async def execute_graph_block(
async def upload_file(
user_id: Annotated[str, Security(get_user_id)],
file: UploadFile = File(...),
provider: str = "gcs",
expiration_hours: int = 24,
) -> UploadFileResponse:
"""
@@ -512,7 +511,6 @@ async def upload_file(
storage_path = await cloud_storage.store_file(
content=content,
filename=file_name,
provider=provider,
expiration_hours=expiration_hours,
user_id=user_id,
)

View File

@@ -515,7 +515,6 @@ async def test_upload_file_success(test_user_id: str):
result = await upload_file(
file=upload_file_mock,
user_id=test_user_id,
provider="gcs",
expiration_hours=24,
)
@@ -533,7 +532,6 @@ async def test_upload_file_success(test_user_id: str):
mock_handler.store_file.assert_called_once_with(
content=file_content,
filename="test.txt",
provider="gcs",
expiration_hours=24,
user_id=test_user_id,
)

View File

@@ -55,6 +55,7 @@ from backend.util.exceptions import (
MissingConfigError,
NotAuthorizedError,
NotFoundError,
PreconditionFailed,
)
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
from backend.util.service import UnhealthyServiceError
@@ -275,6 +276,7 @@ app.add_exception_handler(RequestValidationError, validation_error_handler)
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
app.add_exception_handler(ValueError, handle_internal_http_error(400))
app.add_exception_handler(PreconditionFailed, handle_internal_http_error(428))
app.add_exception_handler(Exception, handle_internal_http_error(500))
app.include_router(backend.api.features.v1.v1_router, tags=["v1"], prefix="/api")

View File

@@ -418,6 +418,8 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
_optimized_description: ClassVar[str | None] = None
def __init__(
self,
id: str = "",
@@ -470,6 +472,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.block_type = block_type
self.webhook_config = webhook_config
self.is_sensitive_action = is_sensitive_action
# Read from ClassVar set by initialize_blocks()
self.optimized_description: str | None = type(self)._optimized_description
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
if self.webhook_config:
@@ -620,6 +624,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
graph_id: str,
graph_version: int,
execution_context: "ExecutionContext",
is_graph_execution: bool = True,
**kwargs,
) -> tuple[bool, BlockInput]:
"""
@@ -648,6 +653,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
graph_version=graph_version,
block_name=self.name,
editable=True,
is_graph_execution=is_graph_execution,
)
if decision is None:

View File

@@ -126,7 +126,7 @@ class PrintToConsoleBlock(Block):
output_schema=PrintToConsoleBlock.Output,
test_input={"text": "Hello, World!"},
is_sensitive_action=True,
disabled=True, # Disabled per Nick Tindle's request (OPEN-3000)
disabled=True,
test_output=[
("output", "Hello, World!"),
("status", "printed"),

View File

@@ -142,7 +142,7 @@ class BaseE2BExecutorMixin:
start_timestamp = ts_result.stdout.strip() if ts_result.stdout else None
# Execute the code
execution = await sandbox.run_code(
execution = await sandbox.run_code( # type: ignore[attr-defined]
code,
language=language.value,
on_error=lambda e: sandbox.kill(), # Kill the sandbox on error

View File

@@ -67,6 +67,7 @@ class HITLReviewHelper:
graph_version: int,
block_name: str = "Block",
editable: bool = False,
is_graph_execution: bool = True,
) -> Optional[ReviewResult]:
"""
Handle a review request for a block that requires human review.
@@ -143,10 +144,11 @@ class HITLReviewHelper:
logger.info(
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
)
await HITLReviewHelper.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
if is_graph_execution:
await HITLReviewHelper.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
return None # Signal that execution should pause
# Mark review as processed if not already done
@@ -168,6 +170,7 @@ class HITLReviewHelper:
graph_version: int,
block_name: str = "Block",
editable: bool = False,
is_graph_execution: bool = True,
) -> Optional[ReviewDecision]:
"""
Handle a review request and return the decision in a single call.
@@ -197,6 +200,7 @@ class HITLReviewHelper:
graph_version=graph_version,
block_name=block_name,
editable=editable,
is_graph_execution=is_graph_execution,
)
if review_result is None:

View File

@@ -17,7 +17,7 @@ from backend.blocks.jina._auth import (
from backend.blocks.search import GetRequest
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.request import HTTPClientError, HTTPServerError, validate_url
from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host
class SearchTheWebBlock(Block, GetRequest):
@@ -112,7 +112,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
) -> BlockOutput:
if input_data.raw_content:
try:
parsed_url, _, _ = await validate_url(input_data.url, [])
parsed_url, _, _ = await validate_url_host(input_data.url)
url = parsed_url.geturl()
except ValueError as e:
yield "error", f"Invalid URL: {e}"

View File

@@ -31,10 +31,14 @@ from backend.data.model import (
)
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.clients import OPENROUTER_BASE_URL
from backend.util.logging import TruncatedLogger
from backend.util.prompt import compress_context, estimate_token_count
from backend.util.request import validate_url_host
from backend.util.settings import Settings
from backend.util.text import TextFormatter
settings = Settings()
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
fmt = TextFormatter(autoescape=False)
@@ -116,6 +120,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
CLAUDE_4_6_OPUS = "claude-opus-4-6"
CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# AI/ML API models
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
@@ -274,6 +279,9 @@ MODEL_METADATA = {
LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
"anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
), # claude-opus-4-6
LlmModel.CLAUDE_4_6_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude Sonnet 4.6", "Anthropic", "Anthropic", 3
), # claude-sonnet-4-6
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
), # claude-opus-4-5-20251101
@@ -800,6 +808,11 @@ async def llm_call(
if tools:
raise ValueError("Ollama does not support tools.")
# Validate user-provided Ollama host to prevent SSRF etc.
await validate_url_host(
ollama_host, trusted_hostnames=[settings.config.ollama_host]
)
client = ollama.AsyncClient(host=ollama_host)
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
@@ -821,7 +834,7 @@ async def llm_call(
elif provider == "open_router":
tools_param = tools if tools else openai.NOT_GIVEN
client = openai.AsyncOpenAI(
base_url="https://openrouter.ai/api/v1",
base_url=OPENROUTER_BASE_URL,
api_key=credentials.api_key.get_secret_value(),
)

View File

@@ -21,6 +21,7 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.clients import OPENROUTER_BASE_URL
from backend.util.logging import TruncatedLogger
logger = TruncatedLogger(logging.getLogger(__name__), "[Perplexity-Block]")
@@ -136,7 +137,7 @@ class PerplexityBlock(Block):
) -> dict[str, Any]:
"""Call Perplexity via OpenRouter and extract annotations."""
client = openai.AsyncOpenAI(
base_url="https://openrouter.ai/api/v1",
base_url=OPENROUTER_BASE_URL,
api_key=credentials.api_key.get_secret_value(),
)

View File

@@ -83,7 +83,8 @@ class StagehandRecommendedLlmModel(str, Enum):
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
# Anthropic
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929" # Keep for backwards compat
CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
@property
def provider_name(self) -> str:
@@ -137,7 +138,7 @@ class StagehandObserveBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -227,7 +228,7 @@ class StagehandActBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -324,7 +325,7 @@ class StagehandExtractBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()

View File

@@ -1,8 +1,8 @@
import logging
from typing import Literal
from pydantic import BaseModel
from backend.api.features.store.db import StoreAgentsSortOptions
from backend.blocks._base import (
Block,
BlockCategory,
@@ -176,8 +176,8 @@ class SearchStoreAgentsBlock(Block):
category: str | None = SchemaField(
description="Filter by category", default=None
)
sort_by: Literal["rating", "runs", "name", "updated_at"] = SchemaField(
description="How to sort the results", default="rating"
sort_by: StoreAgentsSortOptions = SchemaField(
description="How to sort the results", default=StoreAgentsSortOptions.RATING
)
limit: int = SchemaField(
description="Maximum number of results to return", default=10, ge=1, le=100
@@ -278,7 +278,7 @@ class SearchStoreAgentsBlock(Block):
self,
query: str | None = None,
category: str | None = None,
sort_by: Literal["rating", "runs", "name", "updated_at"] = "rating",
sort_by: StoreAgentsSortOptions = StoreAgentsSortOptions.RATING,
limit: int = 10,
) -> SearchAgentsResponse:
"""

View File

@@ -2,6 +2,7 @@ from unittest.mock import MagicMock
import pytest
from backend.api.features.store.db import StoreAgentsSortOptions
from backend.blocks.system.library_operations import (
AddToLibraryFromStoreBlock,
LibraryAgent,
@@ -121,7 +122,10 @@ async def test_search_store_agents_block(mocker):
)
input_data = block.Input(
query="test", category="productivity", sort_by="rating", limit=10
query="test",
category="productivity",
sort_by=StoreAgentsSortOptions.RATING, # type: ignore[reportArgumentType]
limit=10,
)
outputs = {}

View File

@@ -22,6 +22,7 @@ from backend.copilot.model import (
update_session_title,
upsert_chat_session,
)
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.response_model import (
StreamBaseResponse,
StreamError,
@@ -62,8 +63,8 @@ async def _update_title_async(
"""Generate and persist a session title in the background."""
try:
title = await _generate_session_title(message, user_id, session_id)
if title:
await update_session_title(session_id, title)
if title and user_id:
await update_session_title(session_id, user_id, title, only_if_empty=True)
except Exception as e:
logger.warning("[Baseline] Failed to update session title: %s", e)
@@ -176,14 +177,17 @@ async def stream_chat_completion_baseline(
# changes from concurrent chats updating business understanding.
is_first_turn = len(session.messages) <= 1
if is_first_turn:
system_prompt, _ = await _build_system_prompt(
base_system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=False
)
else:
system_prompt, _ = await _build_system_prompt(
base_system_prompt, _ = await _build_system_prompt(
user_id=None, has_conversation_history=True
)
# Append tool documentation and technical notes
system_prompt = base_system_prompt + get_baseline_supplement()
# Compress context if approaching the model's token limit
messages_for_context = await _compress_session_messages(session.messages)

View File

@@ -1,10 +1,13 @@
"""Configuration management for chat system."""
import os
from typing import Literal
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
from backend.util.clients import OPENROUTER_BASE_URL
class ChatConfig(BaseSettings):
"""Configuration for the chat system."""
@@ -19,7 +22,7 @@ class ChatConfig(BaseSettings):
)
api_key: str | None = Field(default=None, description="OpenAI API key")
base_url: str | None = Field(
default="https://openrouter.ai/api/v1",
default=OPENROUTER_BASE_URL,
description="Base URL for API (e.g., for OpenRouter)",
)
@@ -112,9 +115,37 @@ class ChatConfig(BaseSettings):
description="E2B sandbox template to use for copilot sessions.",
)
e2b_sandbox_timeout: int = Field(
default=43200, # 12 hours — same as session_ttl
description="E2B sandbox keepalive timeout in seconds.",
default=10800, # 3 hours — wall-clock timeout, not idle; explicit pause is primary
description="E2B sandbox running-time timeout (seconds). "
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
"mechanism; this is the safety net.",
)
e2b_sandbox_on_timeout: Literal["kill", "pause"] = Field(
default="pause",
description="E2B lifecycle action on timeout: 'pause' (default, free) or 'kill'.",
)
@property
def e2b_active(self) -> bool:
"""True when E2B is enabled and the API key is present.
Single source of truth for "should we use E2B right now?".
Prefer this over combining ``use_e2b_sandbox`` and ``e2b_api_key``
separately at call sites.
"""
return self.use_e2b_sandbox and bool(self.e2b_api_key)
@property
def active_e2b_api_key(self) -> str | None:
"""Return the E2B API key when E2B is enabled and configured, else None.
Combines the ``use_e2b_sandbox`` flag check and key presence into one.
Use in callers::
if api_key := config.active_e2b_api_key:
# E2B is active; api_key is narrowed to str
"""
return self.e2b_api_key if self.e2b_active else None
@field_validator("use_e2b_sandbox", mode="before")
@classmethod
@@ -164,7 +195,7 @@ class ChatConfig(BaseSettings):
if not v:
v = os.getenv("OPENAI_BASE_URL")
if not v:
v = "https://openrouter.ai/api/v1"
v = OPENROUTER_BASE_URL
return v
@field_validator("use_claude_agent_sdk", mode="before")

View File

@@ -0,0 +1,38 @@
"""Unit tests for ChatConfig."""
import pytest
from .config import ChatConfig
# Env vars that the ChatConfig validators read — must be cleared so they don't
# override the explicit constructor values we pass in each test.
_E2B_ENV_VARS = (
"CHAT_USE_E2B_SANDBOX",
"CHAT_E2B_API_KEY",
"E2B_API_KEY",
)
@pytest.fixture(autouse=True)
def _clean_e2b_env(monkeypatch: pytest.MonkeyPatch) -> None:
for var in _E2B_ENV_VARS:
monkeypatch.delenv(var, raising=False)
class TestE2BActive:
"""Tests for the e2b_active property — single source of truth for E2B usage."""
def test_both_enabled_and_key_present_returns_true(self):
"""e2b_active is True when use_e2b_sandbox=True and e2b_api_key is set."""
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key="test-key")
assert cfg.e2b_active is True
def test_enabled_but_missing_key_returns_false(self):
"""e2b_active is False when use_e2b_sandbox=True but e2b_api_key is absent."""
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key=None)
assert cfg.e2b_active is False
def test_disabled_returns_false(self):
"""e2b_active is False when use_e2b_sandbox=False regardless of key."""
cfg = ChatConfig(use_e2b_sandbox=False, e2b_api_key="test-key")
assert cfg.e2b_active is False

View File

@@ -6,6 +6,32 @@
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
# Prefix for all synthetic IDs generated by CoPilot block execution.
# Used to distinguish CoPilot-generated records from real graph execution records
# in PendingHumanReview and other tables.
COPILOT_SYNTHETIC_ID_PREFIX = "copilot-"
# Sub-prefixes for session-scoped and node-scoped synthetic IDs.
COPILOT_SESSION_PREFIX = f"{COPILOT_SYNTHETIC_ID_PREFIX}session-"
COPILOT_NODE_PREFIX = f"{COPILOT_SYNTHETIC_ID_PREFIX}node-"
# Separator used in synthetic node_exec_id to encode node_id.
# Format: "{node_id}:{random_hex}" — extract node_id via rsplit(":", 1)[0]
COPILOT_NODE_EXEC_ID_SEPARATOR = ":"
# Compaction notice messages shown to users.
COMPACTION_DONE_MSG = "Earlier messages were summarized to fit within context limits."
COMPACTION_TOOL_NAME = "context_compaction"
def is_copilot_synthetic_id(id_value: str) -> bool:
"""Check if an ID is a CoPilot synthetic ID (not from a real graph execution)."""
return id_value.startswith(COPILOT_SYNTHETIC_ID_PREFIX)
def parse_node_id_from_exec_id(node_exec_id: str) -> str:
"""Extract node_id from a synthetic node_exec_id.
Format: "{node_id}:{random_hex}" → returns "{node_id}".
"""
return node_exec_id.rsplit(COPILOT_NODE_EXEC_ID_SEPARATOR, 1)[0]

View File

@@ -0,0 +1,115 @@
"""Shared execution context for copilot SDK tool handlers.
All context variables and their accessors live here so that
``tool_adapter``, ``file_ref``, and ``e2b_file_tools`` can import them
without creating circular dependencies.
"""
import os
import re
from contextvars import ContextVar
from typing import TYPE_CHECKING
from backend.copilot.model import ChatSession
if TYPE_CHECKING:
from e2b import AsyncSandbox
# Allowed base directory for the Read tool.
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# Encoded project-directory name for the current session (e.g.
# "-private-tmp-copilot-<uuid>"). Set by set_execution_context() so path
# validation can scope tool-results reads to the current session.
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
"_current_sandbox", default=None
)
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does."""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
def set_execution_context(
user_id: str | None,
session: ChatSession,
sandbox: "AsyncSandbox | None" = None,
sdk_cwd: str | None = None,
) -> None:
"""Set per-turn context variables used by file-resolution tool handlers."""
_current_user_id.set(user_id)
_current_session.set(session)
_current_sandbox.set(sandbox)
_current_sdk_cwd.set(sdk_cwd or "")
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
def get_execution_context() -> tuple[str | None, ChatSession | None]:
"""Return the current (user_id, session) pair for the active request."""
return _current_user_id.get(), _current_session.get()
def get_current_sandbox() -> "AsyncSandbox | None":
"""Return the E2B sandbox for the current session, or None if not active."""
return _current_sandbox.get()
def get_sdk_cwd() -> str:
"""Return the SDK working directory for the current session (empty string if unset)."""
return _current_sdk_cwd.get()
E2B_WORKDIR = "/home/user"
def resolve_sandbox_path(path: str) -> str:
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
Raises :class:`ValueError` if the resolved path escapes the sandbox.
"""
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
normalized = os.path.normpath(candidate)
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
return normalized
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
"""Return True if *path* is within an allowed host-filesystem location.
Allowed:
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
- Files under ``~/.claude/projects/<encoded-cwd>/tool-results/`` (SDK tool-results)
"""
if not path:
return False
if path.startswith("~"):
resolved = os.path.realpath(os.path.expanduser(path))
elif not os.path.isabs(path) and sdk_cwd:
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
else:
resolved = os.path.realpath(path)
if sdk_cwd:
norm_cwd = os.path.realpath(sdk_cwd)
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
return True
encoded = _current_project_dir.get("")
if encoded:
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
if resolved == tool_results_dir or resolved.startswith(
tool_results_dir + os.sep
):
return True
return False

View File

@@ -0,0 +1,163 @@
"""Tests for context.py — execution context variables and path helpers."""
from __future__ import annotations
import os
import tempfile
from unittest.mock import MagicMock
import pytest
from backend.copilot.context import (
_SDK_PROJECTS_DIR,
_current_project_dir,
get_current_sandbox,
get_execution_context,
get_sdk_cwd,
is_allowed_local_path,
resolve_sandbox_path,
set_execution_context,
)
def _make_session() -> MagicMock:
s = MagicMock()
s.session_id = "test-session"
return s
# ---------------------------------------------------------------------------
# Context variable getters
# ---------------------------------------------------------------------------
def test_get_execution_context_defaults():
"""get_execution_context returns (None, session) when user_id is not set."""
set_execution_context(None, _make_session())
user_id, session = get_execution_context()
assert user_id is None
assert session is not None
def test_set_and_get_execution_context():
"""set_execution_context stores user_id and session."""
mock_session = _make_session()
set_execution_context("user-abc", mock_session)
user_id, session = get_execution_context()
assert user_id == "user-abc"
assert session is mock_session
def test_get_current_sandbox_none_by_default():
"""get_current_sandbox returns None when no sandbox is set."""
set_execution_context("u1", _make_session(), sandbox=None)
assert get_current_sandbox() is None
def test_get_current_sandbox_returns_set_value():
"""get_current_sandbox returns the sandbox set via set_execution_context."""
mock_sandbox = MagicMock()
set_execution_context("u1", _make_session(), sandbox=mock_sandbox)
assert get_current_sandbox() is mock_sandbox
def test_get_sdk_cwd_empty_when_not_set():
"""get_sdk_cwd returns empty string when sdk_cwd is not set."""
set_execution_context("u1", _make_session(), sdk_cwd=None)
assert get_sdk_cwd() == ""
def test_get_sdk_cwd_returns_set_value():
"""get_sdk_cwd returns the value set via set_execution_context."""
set_execution_context("u1", _make_session(), sdk_cwd="/tmp/copilot-test")
assert get_sdk_cwd() == "/tmp/copilot-test"
# ---------------------------------------------------------------------------
# is_allowed_local_path
# ---------------------------------------------------------------------------
def test_is_allowed_local_path_empty():
assert not is_allowed_local_path("")
def test_is_allowed_local_path_inside_sdk_cwd():
with tempfile.TemporaryDirectory() as cwd:
path = os.path.join(cwd, "file.txt")
assert is_allowed_local_path(path, cwd)
def test_is_allowed_local_path_sdk_cwd_itself():
with tempfile.TemporaryDirectory() as cwd:
assert is_allowed_local_path(cwd, cwd)
def test_is_allowed_local_path_outside_sdk_cwd():
with tempfile.TemporaryDirectory() as cwd:
assert not is_allowed_local_path("/etc/passwd", cwd)
def test_is_allowed_local_path_no_sdk_cwd_no_project_dir():
"""Without sdk_cwd or project_dir, all paths are rejected."""
_current_project_dir.set("")
assert not is_allowed_local_path("/tmp/some-file.txt", sdk_cwd=None)
def test_is_allowed_local_path_tool_results_dir():
"""Files under the tool-results directory for the current project are allowed."""
encoded = "test-encoded-dir"
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
path = os.path.join(tool_results_dir, "output.txt")
_current_project_dir.set(encoded)
try:
assert is_allowed_local_path(path, sdk_cwd=None)
finally:
_current_project_dir.set("")
def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
"""A path adjacent to tool-results/ but not inside it is rejected."""
encoded = "test-encoded-dir"
sibling_path = os.path.join(_SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
_current_project_dir.set(encoded)
try:
assert not is_allowed_local_path(sibling_path, sdk_cwd=None)
finally:
_current_project_dir.set("")
# ---------------------------------------------------------------------------
# resolve_sandbox_path
# ---------------------------------------------------------------------------
def test_resolve_sandbox_path_absolute_valid():
assert (
resolve_sandbox_path("/home/user/project/main.py")
== "/home/user/project/main.py"
)
def test_resolve_sandbox_path_relative():
assert resolve_sandbox_path("project/main.py") == "/home/user/project/main.py"
def test_resolve_sandbox_path_workdir_itself():
assert resolve_sandbox_path("/home/user") == "/home/user"
def test_resolve_sandbox_path_normalizes_dots():
assert resolve_sandbox_path("/home/user/a/../b") == "/home/user/b"
def test_resolve_sandbox_path_escape_raises():
with pytest.raises(ValueError, match="/home/user"):
resolve_sandbox_path("/home/user/../../etc/passwd")
def test_resolve_sandbox_path_absolute_outside_raises():
with pytest.raises(ValueError, match="/home/user"):
resolve_sandbox_path("/etc/passwd")

View File

@@ -81,6 +81,35 @@ async def update_chat_session(
return ChatSession.from_db(session) if session else None
async def update_chat_session_title(
session_id: str,
user_id: str,
title: str,
*,
only_if_empty: bool = False,
) -> bool:
"""Update the title of a chat session, scoped to the owning user.
Always filters by (session_id, user_id) so callers cannot mutate another
user's session even when they know the session_id.
Args:
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
guard so auto-generated titles never overwrite a user-set title.
Returns True if a row was updated, False otherwise (session not found,
wrong user, or — when only_if_empty — title was already set).
"""
where: ChatSessionWhereInput = {"id": session_id, "userId": user_id}
if only_if_empty:
where["title"] = None
result = await PrismaChatSession.prisma().update_many(
where=where,
data={"title": title, "updatedAt": datetime.now(UTC)},
)
return result > 0
async def add_chat_message(
session_id: str,
role: str,

View File

@@ -469,8 +469,16 @@ async def upsert_chat_session(
)
db_error = e
# Save to cache (best-effort, even if DB failed)
# Save to cache (best-effort, even if DB failed).
# Title updates (update_session_title) run *outside* this lock because
# they only touch the title field, not messages. So a concurrent rename
# or auto-title may have written a newer title to Redis while this
# upsert was in progress. Always prefer the cached title to avoid
# overwriting it with the stale in-memory copy.
try:
existing_cached = await _get_session_from_cache(session.session_id)
if existing_cached and existing_cached.title:
session = session.model_copy(update={"title": existing_cached.title})
await cache_chat_session(session)
except Exception as e:
# If DB succeeded but cache failed, raise cache error
@@ -685,30 +693,48 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
return True
async def update_session_title(session_id: str, title: str) -> bool:
"""Update only the title of a chat session.
async def update_session_title(
session_id: str,
user_id: str,
title: str,
*,
only_if_empty: bool = False,
) -> bool:
"""Update the title of a chat session, scoped to the owning user.
This is a lightweight operation that doesn't touch messages, avoiding
race conditions with concurrent message updates. Use this for background
title generation instead of upsert_chat_session.
Lightweight operation that doesn't touch messages, avoiding race conditions
with concurrent message updates.
Args:
session_id: The session ID to update.
user_id: Owning user — the DB query filters on this.
title: The new title to set.
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
so auto-generated titles never overwrite a user-set title.
Returns:
True if updated successfully, False otherwise.
True if updated successfully, False otherwise (not found, wrong user,
or — when only_if_empty — title was already set).
"""
try:
result = await chat_db().update_chat_session(session_id=session_id, title=title)
if result is None:
logger.warning(f"Session {session_id} not found for title update")
updated = await chat_db().update_chat_session_title(
session_id, user_id, title, only_if_empty=only_if_empty
)
if not updated:
return False
# Invalidate the cache so the next access reloads from DB with the
# updated title. This avoids a read-modify-write on the full session
# blob, which could overwrite concurrent message updates.
await invalidate_session_cache(session_id)
# Update title in cache if it exists (instead of invalidating).
# This prevents race conditions where cache invalidation causes
# the frontend to see stale DB data while streaming is still in progress.
try:
cached = await _get_session_from_cache(session_id)
if cached:
cached.title = title
await cache_chat_session(cached)
except Exception as e:
logger.warning(
f"Cache title update failed for session {session_id} (non-critical): {e}"
)
return True
except Exception as e:

View File

@@ -0,0 +1,138 @@
"""Scheduler job to generate LLM-optimized block descriptions.
Runs periodically to rewrite block descriptions into concise, actionable
summaries that help the copilot LLM pick the right blocks during agent
generation.
"""
import asyncio
import logging
from backend.blocks import get_blocks
from backend.util.clients import get_database_manager_client, get_openai_client
logger = logging.getLogger(__name__)
SYSTEM_PROMPT = (
"You are a technical writer for an automation platform. "
"Rewrite the following block description to be concise (under 50 words), "
"informative, and actionable. Focus on what the block does and when to "
"use it. Output ONLY the rewritten description, nothing else. "
"Do not use markdown formatting."
)
# Rate-limit delay between sequential LLM calls (seconds)
_RATE_LIMIT_DELAY = 0.5
# Maximum tokens for optimized description generation
_MAX_DESCRIPTION_TOKENS = 150
# Model for generating optimized descriptions (fast, cheap)
_MODEL = "gpt-4o-mini"
async def _optimize_descriptions(blocks: list[dict[str, str]]) -> dict[str, str]:
"""Call the shared OpenAI client to rewrite each block description."""
client = get_openai_client()
if client is None:
logger.error(
"No OpenAI client configured, skipping block description optimization"
)
return {}
results: dict[str, str] = {}
for block in blocks:
block_id = block["id"]
block_name = block["name"]
description = block["description"]
try:
response = await client.chat.completions.create(
model=_MODEL,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{
"role": "user",
"content": f"Block name: {block_name}\nDescription: {description}",
},
],
max_tokens=_MAX_DESCRIPTION_TOKENS,
)
optimized = (response.choices[0].message.content or "").strip()
if optimized:
results[block_id] = optimized
logger.debug("Optimized description for %s", block_name)
else:
logger.warning("Empty response for block %s", block_name)
except Exception:
logger.warning(
"Failed to optimize description for %s", block_name, exc_info=True
)
await asyncio.sleep(_RATE_LIMIT_DELAY)
return results
def optimize_block_descriptions() -> dict[str, int]:
"""Generate optimized descriptions for blocks that don't have one yet.
Uses the shared OpenAI client to rewrite block descriptions into concise
summaries suitable for agent generation prompts.
Returns:
Dict with counts: processed, success, failed, skipped.
"""
db_client = get_database_manager_client()
blocks = db_client.get_blocks_needing_optimization()
if not blocks:
logger.info("All blocks already have optimized descriptions")
return {"processed": 0, "success": 0, "failed": 0, "skipped": 0}
logger.info("Found %d blocks needing optimized descriptions", len(blocks))
non_empty = [b for b in blocks if b.get("description", "").strip()]
skipped = len(blocks) - len(non_empty)
new_descriptions = asyncio.run(_optimize_descriptions(non_empty))
stats = {
"processed": len(non_empty),
"success": len(new_descriptions),
"failed": len(non_empty) - len(new_descriptions),
"skipped": skipped,
}
logger.info(
"Block description optimization complete: "
"%d/%d succeeded, %d failed, %d skipped",
stats["success"],
stats["processed"],
stats["failed"],
stats["skipped"],
)
if new_descriptions:
for block_id, optimized in new_descriptions.items():
db_client.update_block_optimized_description(block_id, optimized)
# Update in-memory descriptions first so the cache rebuilds with fresh data.
try:
block_classes = get_blocks()
for block_id, optimized in new_descriptions.items():
if block_id in block_classes:
block_classes[block_id]._optimized_description = optimized
logger.info(
"Updated %d in-memory block descriptions", len(new_descriptions)
)
except Exception:
logger.warning(
"Could not update in-memory block descriptions", exc_info=True
)
from backend.copilot.tools.agent_generator.blocks import (
reset_block_caches, # local to avoid circular import
)
reset_block_caches()
return stats

View File

@@ -0,0 +1,91 @@
"""Unit tests for optimize_blocks._optimize_descriptions."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from backend.copilot.optimize_blocks import _RATE_LIMIT_DELAY, _optimize_descriptions
def _make_client_response(text: str) -> MagicMock:
"""Build a minimal mock that looks like an OpenAI ChatCompletion response."""
choice = MagicMock()
choice.message.content = text
response = MagicMock()
response.choices = [choice]
return response
def _run(coro):
return asyncio.get_event_loop().run_until_complete(coro)
class TestOptimizeDescriptions:
"""Tests for _optimize_descriptions async function."""
def test_returns_empty_when_no_client(self):
with patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=None
):
result = _run(
_optimize_descriptions([{"id": "b1", "name": "B", "description": "d"}])
)
assert result == {}
def test_success_single_block(self):
client = MagicMock()
client.chat.completions.create = AsyncMock(
return_value=_make_client_response("Short desc.")
)
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
with (
patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
),
patch(
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
),
):
result = _run(_optimize_descriptions(blocks))
assert result == {"b1": "Short desc."}
client.chat.completions.create.assert_called_once()
def test_skips_block_on_exception(self):
client = MagicMock()
client.chat.completions.create = AsyncMock(side_effect=Exception("API error"))
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
with (
patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
),
patch(
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
),
):
result = _run(_optimize_descriptions(blocks))
assert result == {}
def test_sleeps_between_blocks(self):
client = MagicMock()
client.chat.completions.create = AsyncMock(
return_value=_make_client_response("desc")
)
blocks = [
{"id": "b1", "name": "B1", "description": "d1"},
{"id": "b2", "name": "B2", "description": "d2"},
]
sleep_mock = AsyncMock()
with (
patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
),
patch("backend.copilot.optimize_blocks.asyncio.sleep", sleep_mock),
):
_run(_optimize_descriptions(blocks))
assert sleep_mock.call_count == 2
sleep_mock.assert_called_with(_RATE_LIMIT_DELAY)

View File

@@ -0,0 +1,218 @@
"""Centralized prompt building logic for CoPilot.
This module contains all prompt construction functions and constants,
handling the distinction between:
- SDK mode vs Baseline mode (tool documentation needs)
- Local mode vs E2B mode (storage/filesystem differences)
"""
from backend.copilot.tools import TOOL_REGISTRY
# Shared technical notes that apply to both SDK and baseline modes
_SHARED_TOOL_NOTES = """\
### Sharing files with the user
After saving a file to the persistent workspace with `write_workspace_file`,
share it with the user by embedding the `download_url` from the response in
your message as a Markdown link or image:
- **Any file** — shows as a clickable download link:
`[report.csv](workspace://file_id#text/csv)`
- **Image** — renders inline in chat:
`![chart](workspace://file_id#image/png)`
- **Video** — renders inline in chat with player controls:
`![recording](workspace://file_id#video/mp4)`
The `download_url` field in the `write_workspace_file` response is already
in the correct format — paste it directly after the `(` in the Markdown.
### Passing file content to tools — @@agptfile: references
Instead of copying large file contents into a tool argument, pass a file
reference and the platform will load the content for you.
Syntax: `@@agptfile:<uri>[<start>-<end>]`
- `<uri>` **must** start with `workspace://` or `/` (absolute path):
- `workspace://<file_id>` — workspace file by ID
- `workspace:///<path>` — workspace file by virtual path
- `/absolute/local/path` — ephemeral or sdk_cwd file
- E2B sandbox absolute path (e.g. `/home/user/script.py`)
- `[<start>-<end>]` is an optional 1-indexed inclusive line range.
- URIs that do not start with `workspace://` or `/` are **not** expanded.
Examples:
```
@@agptfile:workspace://abc123
@@agptfile:workspace://abc123[10-50]
@@agptfile:workspace:///reports/q1.md
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
@@agptfile:/home/user/script.py
```
You can embed a reference inside any string argument, or use it as the entire
value. Multiple references in one argument are all expanded.
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
All tasks must run in the foreground.
"""
# Environment-specific supplement templates
def _build_storage_supplement(
working_dir: str,
sandbox_type: str,
storage_system_1_name: str,
storage_system_1_characteristics: list[str],
storage_system_1_persistence: list[str],
file_move_name_1_to_2: str,
file_move_name_2_to_1: str,
) -> str:
"""Build storage/filesystem supplement for a specific environment.
Template function handles all formatting (bullets, indentation, markdown).
Callers provide clean data as lists of strings.
Args:
working_dir: Working directory path
sandbox_type: Description of bash_exec sandbox
storage_system_1_name: Name of primary storage (ephemeral or cloud)
storage_system_1_characteristics: List of characteristic descriptions
storage_system_1_persistence: List of persistence behavior descriptions
file_move_name_1_to_2: Direction label for primary→persistent
file_move_name_2_to_1: Direction label for persistent→primary
"""
# Format lists as bullet points with proper indentation
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
persistence = "\n".join(f" - {p}" for p in storage_system_1_persistence)
return f"""
## Tool notes
### Shell commands
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
for shell commands — it runs {sandbox_type}.
### Working directory
- Your working directory is: `{working_dir}`
- All SDK file tools AND `bash_exec` operate on the same filesystem
- Use relative paths or absolute paths under `{working_dir}` for all file operations
### Two storage systems — CRITICAL to understand
1. **{storage_system_1_name}** (`{working_dir}`):
{characteristics}
{persistence}
2. **Persistent workspace** (cloud storage):
- Files here **survive across sessions indefinitely**
### Moving files between storages
- **{file_move_name_1_to_2}**: Copy to persistent workspace
- **{file_move_name_2_to_1}**: Download for processing
### File persistence
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
{_SHARED_TOOL_NOTES}"""
# Pre-built supplements for common environments
def _get_local_storage_supplement(cwd: str) -> str:
"""Local ephemeral storage (files lost between turns)."""
return _build_storage_supplement(
working_dir=cwd,
sandbox_type="in a network-isolated sandbox",
storage_system_1_name="Ephemeral working directory",
storage_system_1_characteristics=[
"Shared by SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec`",
],
storage_system_1_persistence=[
"Files here are **lost between turns** — do NOT rely on them persisting",
"Use for temporary work: running scripts, processing data, etc.",
],
file_move_name_1_to_2="Ephemeral → Persistent",
file_move_name_2_to_1="Persistent → Ephemeral",
)
def _get_cloud_sandbox_supplement() -> str:
"""Cloud persistent sandbox (files survive across turns in session)."""
return _build_storage_supplement(
working_dir="/home/user",
sandbox_type="in a cloud sandbox with full internet access",
storage_system_1_name="Cloud sandbox",
storage_system_1_characteristics=[
"Shared by all file tools AND `bash_exec` — same filesystem",
"Full Linux environment with internet access",
],
storage_system_1_persistence=[
"Files **persist across turns** within the current session",
"Lost when the session expires (12 h inactivity)",
],
file_move_name_1_to_2="Sandbox → Persistent",
file_move_name_2_to_1="Persistent → Sandbox",
)
def _generate_tool_documentation() -> str:
"""Auto-generate tool documentation from TOOL_REGISTRY.
NOTE: This is ONLY used in baseline mode (direct OpenAI API).
SDK mode doesn't need it since Claude gets tool schemas automatically.
This generates a complete list of available tools with their descriptions,
ensuring the documentation stays in sync with the actual tool implementations.
All workflow guidance is now embedded in individual tool descriptions.
Only documents tools that are available in the current environment
(checked via tool.is_available property).
"""
docs = "\n## AVAILABLE TOOLS\n\n"
# Sort tools alphabetically for consistent output
# Filter by is_available to match get_available_tools() behavior
for name in sorted(TOOL_REGISTRY.keys()):
tool = TOOL_REGISTRY[name]
if not tool.is_available:
continue
schema = tool.as_openai_tool()
desc = schema["function"].get("description", "No description available")
# Format as bullet list with tool name in code style
docs += f"- **`{name}`**: {desc}\n"
return docs
def get_sdk_supplement(use_e2b: bool, cwd: str = "") -> str:
"""Get the supplement for SDK mode (Claude Agent SDK).
SDK mode does NOT include tool documentation because Claude automatically
receives tool schemas from the SDK. Only includes technical notes about
storage systems and execution environment.
Args:
use_e2b: Whether E2B cloud sandbox is being used
cwd: Current working directory (only used in local_storage mode)
Returns:
The supplement string to append to the system prompt
"""
if use_e2b:
return _get_cloud_sandbox_supplement()
return _get_local_storage_supplement(cwd)
def get_baseline_supplement() -> str:
"""Get the supplement for baseline mode (direct OpenAI API).
Baseline mode INCLUDES auto-generated tool documentation because the
direct API doesn't automatically provide tool schemas to Claude.
Also includes shared technical notes (but NOT SDK-specific environment details).
Returns:
The supplement string to append to the system prompt
"""
tool_docs = _generate_tool_documentation()
return tool_docs + _SHARED_TOOL_NOTES

View File

@@ -0,0 +1,155 @@
## Agent Generation Guide
You can create, edit, and customize agents directly. You ARE the brain —
generate the agent JSON yourself using block schemas, then validate and save.
### Workflow for Creating/Editing Agents
1. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
search for relevant blocks. This returns block IDs, names, descriptions,
and full input/output schemas.
2. **Find library agents**: Call `find_library_agent` to discover reusable
agents that can be composed as sub-agents via `AgentExecutorBlock`.
3. **Generate JSON**: Build the agent JSON using block schemas:
- Use block IDs from step 1 as `block_id` in nodes
- Wire outputs to inputs using links
- Set design-time config in `input_default`
- Use `AgentInputBlock` for values the user provides at runtime
4. **Write to workspace**: Save the JSON to a workspace file so the user
can review it: `write_workspace_file(filename="agent.json", content=...)`
5. **Validate**: Call `validate_agent_graph` with the agent JSON to check
for errors
6. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
or fix manually based on the error descriptions. Iterate until valid.
7. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
the final `agent_json`
### Agent JSON Structure
```json
{
"id": "<UUID v4>", // auto-generated if omitted
"version": 1,
"is_active": true,
"name": "Agent Name",
"description": "What the agent does",
"nodes": [
{
"id": "<UUID v4>",
"block_id": "<block UUID from find_block>",
"input_default": {
"field_name": "design-time value"
},
"metadata": {
"position": {"x": 0, "y": 0},
"customized_name": "Optional display name"
}
}
],
"links": [
{
"id": "<UUID v4>",
"source_id": "<source node UUID>",
"source_name": "output_field_name",
"sink_id": "<sink node UUID>",
"sink_name": "input_field_name",
"is_static": false
}
]
}
```
### REQUIRED: AgentInputBlock and AgentOutputBlock
Every agent MUST include at least one AgentInputBlock and one AgentOutputBlock.
These define the agent's interface — what it accepts and what it produces.
**AgentInputBlock** (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`):
- Defines a user-facing input field on the agent
- Required `input_default` fields: `name` (str), `value` (default: null)
- Optional: `title`, `description`, `placeholder_values` (for dropdowns)
- Output: `result` — the user-provided value at runtime
- Create one AgentInputBlock per distinct input the agent needs
**AgentOutputBlock** (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`):
- Defines a user-facing output displayed after the agent runs
- Required `input_default` fields: `name` (str)
- The `value` input should be linked from another block's output
- Optional: `title`, `description`, `format` (Jinja2 template)
- Create one AgentOutputBlock per distinct result to show the user
Without these blocks, the agent has no interface and the user cannot provide
inputs or see outputs. NEVER skip them.
### Key Rules
- **Name & description**: Include `name` and `description` in the agent JSON
when creating a new agent, or when editing and the agent's purpose changed.
Without these the agent gets a generic default name.
- **Design-time vs runtime**: `input_default` = values known at build time.
For user-provided values, create an `AgentInputBlock` node and link its
output to the consuming block's input.
- **Credentials**: Do NOT require credentials upfront. Users configure
credentials later in the platform UI after the agent is saved.
- **Node spacing**: Position nodes with at least 800 X-units between them.
- **Nested properties**: Use `parentField_#_childField` notation in link
sink_name/source_name to access nested object fields.
- **is_static links**: Set `is_static: true` when the link carries a
design-time constant (matches a field in inputSchema with a default).
- **ConditionBlock**: Needs a `StoreValueBlock` wired to its `value2` input.
- **Prompt templates**: Use `{{variable}}` (double curly braces) for
literal braces in prompt strings — single `{` and `}` are for
template variables.
- **AgentExecutorBlock**: When composing sub-agents, set `graph_id` and
`graph_version` in input_default, and wire inputs/outputs to match
the sub-agent's schema.
### Using Sub-Agents (AgentExecutorBlock)
To compose agents using other agents as sub-agents:
1. Call `find_library_agent` to find the sub-agent — the response includes
`graph_id`, `graph_version`, `input_schema`, and `output_schema`
2. Create an `AgentExecutorBlock` node (ID: `e189baac-8c20-45a1-94a7-55177ea42565`)
3. Set `input_default`:
- `graph_id`: from the library agent's `graph_id`
- `graph_version`: from the library agent's `graph_version`
- `input_schema`: from the library agent's `input_schema` (JSON Schema)
- `output_schema`: from the library agent's `output_schema` (JSON Schema)
- `user_id`: leave as `""` (filled at runtime)
- `inputs`: `{}` (populated by links at runtime)
4. Wire inputs: link to sink names matching the sub-agent's `input_schema`
property names (e.g., if input_schema has a `"url"` property, use
`"url"` as the sink_name)
5. Wire outputs: link from source names matching the sub-agent's
`output_schema` property names
6. Pass `library_agent_ids` to `create_agent`/`customize_agent` with
the library agent IDs used, so the fixer can validate schemas
### Using MCP Tools (MCPToolBlock)
To use an MCP (Model Context Protocol) tool as a node in the agent:
1. The user must specify which MCP server URL and tool name they want
2. Create an `MCPToolBlock` node (ID: `a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4`)
3. Set `input_default`:
- `server_url`: the MCP server URL (e.g. `"https://mcp.example.com/sse"`)
- `selected_tool`: the tool name on that server
- `tool_input_schema`: JSON Schema for the tool's inputs
- `tool_arguments`: `{}` (populated by links or hardcoded values)
4. The block requires MCP credentials — the user configures these in the
platform UI after the agent is saved
5. Wire inputs using the tool argument field name directly as the sink_name
(e.g., `query`, NOT `tool_arguments_#_query`). The execution engine
automatically collects top-level fields matching tool_input_schema into
tool_arguments.
6. Output: `result` (the tool's return value) and `error` (error message)
### Example: Simple AI Text Processor
A minimal agent with input, processing, and output:
- Node 1: `AgentInputBlock` (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`,
input_default: {"name": "user_text", "title": "Text to process"},
output: "result")
- Node 2: `AITextGeneratorBlock` (input: "prompt" linked from Node 1's "result")
- Node 3: `AgentOutputBlock` (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`,
input_default: {"name": "summary", "title": "Summary"},
input: "value" linked from Node 2's output)

View File

@@ -8,8 +8,6 @@ SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
"""
from __future__ import annotations
import itertools
import json
import logging
@@ -17,36 +15,23 @@ import os
import shlex
from typing import Any, Callable
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
from backend.copilot.context import (
E2B_WORKDIR,
get_current_sandbox,
get_sdk_cwd,
is_allowed_local_path,
resolve_sandbox_path,
)
logger = logging.getLogger(__name__)
# Lazy imports to break circular dependency with tool_adapter.
def _get_sandbox(): # type: ignore[return]
from .tool_adapter import get_current_sandbox # noqa: E402
def _get_sandbox():
return get_current_sandbox()
def _is_allowed_local(path: str) -> bool:
from .tool_adapter import is_allowed_local_path # noqa: E402
return is_allowed_local_path(path)
def _resolve_remote(path: str) -> str:
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
Raises :class:`ValueError` if the resolved path escapes the sandbox.
"""
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
normalized = os.path.normpath(candidate)
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
return normalized
return is_allowed_local_path(path, get_sdk_cwd())
def _mcp(text: str, *, error: bool = False) -> dict[str, Any]:
@@ -63,7 +48,7 @@ def _get_sandbox_and_path(
if sandbox is None:
return _mcp("No E2B sandbox available", error=True)
try:
remote = _resolve_remote(file_path)
remote = resolve_sandbox_path(file_path)
except ValueError as exc:
return _mcp(str(exc), error=True)
return sandbox, remote
@@ -73,6 +58,7 @@ def _get_sandbox_and_path(
async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
"""Read lines from a sandbox file, falling back to the local host for SDK-internal paths."""
file_path: str = args.get("file_path", "")
offset: int = max(0, int(args.get("offset", 0)))
limit: int = max(1, int(args.get("limit", 2000)))
@@ -104,6 +90,7 @@ async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
"""Write content to a sandbox file, creating parent directories as needed."""
file_path: str = args.get("file_path", "")
content: str = args.get("content", "")
@@ -127,6 +114,7 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
"""Replace a substring in a sandbox file, with optional replace-all support."""
file_path: str = args.get("file_path", "")
old_string: str = args.get("old_string", "")
new_string: str = args.get("new_string", "")
@@ -172,6 +160,7 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
"""Find files matching a name pattern inside the sandbox using ``find``."""
pattern: str = args.get("pattern", "")
path: str = args.get("path", "")
@@ -183,7 +172,7 @@ async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
return _mcp("No E2B sandbox available", error=True)
try:
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
except ValueError as exc:
return _mcp(str(exc), error=True)
@@ -198,6 +187,7 @@ async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
"""Search file contents by regex inside the sandbox using ``grep -rn``."""
pattern: str = args.get("pattern", "")
path: str = args.get("path", "")
include: str = args.get("include", "")
@@ -210,7 +200,7 @@ async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
return _mcp("No E2B sandbox available", error=True)
try:
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
except ValueError as exc:
return _mcp(str(exc), error=True)
@@ -238,7 +228,7 @@ def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
return _mcp(f"Path not allowed: {file_path}", error=True)
expanded = os.path.realpath(os.path.expanduser(file_path))
try:
with open(expanded) as fh:
with open(expanded, encoding="utf-8", errors="replace") as fh:
selected = list(itertools.islice(fh, offset, offset + limit))
numbered = "".join(
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)

View File

@@ -7,59 +7,60 @@ import os
import pytest
from .e2b_file_tools import _read_local, _resolve_remote
from .tool_adapter import _current_project_dir
from backend.copilot.context import _current_project_dir
from .e2b_file_tools import _read_local, resolve_sandbox_path
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# ---------------------------------------------------------------------------
# _resolve_remote — sandbox path normalisation & boundary enforcement
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
# ---------------------------------------------------------------------------
class TestResolveRemote:
class TestResolveSandboxPath:
def test_relative_path_resolved(self):
assert _resolve_remote("src/main.py") == "/home/user/src/main.py"
assert resolve_sandbox_path("src/main.py") == "/home/user/src/main.py"
def test_absolute_within_sandbox(self):
assert _resolve_remote("/home/user/file.txt") == "/home/user/file.txt"
assert resolve_sandbox_path("/home/user/file.txt") == "/home/user/file.txt"
def test_workdir_itself(self):
assert _resolve_remote("/home/user") == "/home/user"
assert resolve_sandbox_path("/home/user") == "/home/user"
def test_relative_dotslash(self):
assert _resolve_remote("./README.md") == "/home/user/README.md"
assert resolve_sandbox_path("./README.md") == "/home/user/README.md"
def test_traversal_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("../../etc/passwd")
resolve_sandbox_path("../../etc/passwd")
def test_absolute_traversal_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/home/user/../../etc/passwd")
resolve_sandbox_path("/home/user/../../etc/passwd")
def test_absolute_outside_sandbox_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/etc/passwd")
resolve_sandbox_path("/etc/passwd")
def test_root_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/")
resolve_sandbox_path("/")
def test_home_other_user_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/home/other/file.txt")
resolve_sandbox_path("/home/other/file.txt")
def test_deep_nested_allowed(self):
assert _resolve_remote("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
assert resolve_sandbox_path("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
def test_trailing_slash_normalised(self):
assert _resolve_remote("src/") == "/home/user/src"
assert resolve_sandbox_path("src/") == "/home/user/src"
def test_double_dots_within_sandbox_ok(self):
"""Path that resolves back within /home/user is allowed."""
assert _resolve_remote("a/b/../c.txt") == "/home/user/a/c.txt"
assert resolve_sandbox_path("a/b/../c.txt") == "/home/user/a/c.txt"
# ---------------------------------------------------------------------------

View File

@@ -0,0 +1,281 @@
"""File reference protocol for tool call inputs.
Allows the LLM to pass a file reference instead of embedding large content
inline. The processor expands ``@@agptfile:<uri>[<start>-<end>]`` tokens in tool
arguments before the tool is executed.
Protocol
--------
@@agptfile:<uri>[<start>-<end>]
``<uri>`` (required)
- ``workspace://<file_id>`` — workspace file by ID
- ``workspace://<file_id>#<mime>`` — same, MIME hint is ignored for reads
- ``workspace:///<path>`` — workspace file by virtual path
- ``/absolute/local/path`` — ephemeral or sdk_cwd file (validated by
:func:`~backend.copilot.sdk.tool_adapter.is_allowed_local_path`)
- Any absolute path that resolves inside the E2B sandbox
(``/home/user/...``) when a sandbox is active
``[<start>-<end>]`` (optional)
Line range, 1-indexed inclusive. Examples: ``[1-100]``, ``[50-200]``.
Omit to read the entire file.
Examples
--------
@@agptfile:workspace://abc123
@@agptfile:workspace://abc123[10-50]
@@agptfile:workspace:///reports/q1.md
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
@@agptfile:/home/user/script.sh
"""
import itertools
import logging
import os
import re
from dataclasses import dataclass
from typing import Any
from backend.copilot.context import (
get_current_sandbox,
get_sdk_cwd,
is_allowed_local_path,
resolve_sandbox_path,
)
from backend.copilot.model import ChatSession
from backend.copilot.tools.workspace_files import get_manager
from backend.util.file import parse_workspace_uri
class FileRefExpansionError(Exception):
"""Raised when a ``@@agptfile:`` reference in tool call args fails to resolve.
Separating this from inline substitution lets callers (e.g. the MCP tool
wrapper) block tool execution and surface a helpful error to the model
rather than passing an ``[file-ref error: …]`` string as actual input.
"""
logger = logging.getLogger(__name__)
FILE_REF_PREFIX = "@@agptfile:"
# Matches: @@agptfile:<uri>[start-end]?
# Group 1 URI; must start with '/' (absolute path) or 'workspace://'
# Group 2 start line (optional)
# Group 3 end line (optional)
_FILE_REF_RE = re.compile(
re.escape(FILE_REF_PREFIX) + r"((?:workspace://|/)[^\[\s]*)(?:\[(\d+)-(\d+)\])?"
)
# Maximum characters returned for a single file reference expansion.
_MAX_EXPAND_CHARS = 200_000
# Maximum total characters across all @@agptfile: expansions in one string.
_MAX_TOTAL_EXPAND_CHARS = 1_000_000
@dataclass
class FileRef:
uri: str
start_line: int | None # 1-indexed, inclusive
end_line: int | None # 1-indexed, inclusive
def parse_file_ref(text: str) -> FileRef | None:
"""Return a :class:`FileRef` if *text* is a bare file reference token.
A "bare token" means the entire string matches the ``@@agptfile:...`` pattern
(after stripping whitespace). Use :func:`expand_file_refs_in_string` to
expand references embedded in larger strings.
"""
m = _FILE_REF_RE.fullmatch(text.strip())
if not m:
return None
start = int(m.group(2)) if m.group(2) else None
end = int(m.group(3)) if m.group(3) else None
if start is not None and start < 1:
return None
if end is not None and end < 1:
return None
if start is not None and end is not None and end < start:
return None
return FileRef(uri=m.group(1), start_line=start, end_line=end)
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
"""Slice *text* to the requested 1-indexed line range (inclusive)."""
if start is None and end is None:
return text
lines = text.splitlines(keepends=True)
s = (start - 1) if start is not None else 0
e = end if end is not None else len(lines)
selected = list(itertools.islice(lines, s, e))
return "".join(selected)
async def read_file_bytes(
uri: str,
user_id: str | None,
session: ChatSession,
) -> bytes:
"""Resolve *uri* to raw bytes using workspace, local, or E2B path logic.
Raises :class:`ValueError` if the URI cannot be resolved.
"""
# Strip MIME fragment (e.g. workspace://id#mime) before dispatching.
plain = uri.split("#")[0] if uri.startswith("workspace://") else uri
if plain.startswith("workspace://"):
if not user_id:
raise ValueError("workspace:// file references require authentication")
manager = await get_manager(user_id, session.session_id)
ws = parse_workspace_uri(plain)
try:
return await (
manager.read_file(ws.file_ref)
if ws.is_path
else manager.read_file_by_id(ws.file_ref)
)
except FileNotFoundError:
raise ValueError(f"File not found: {plain}")
except Exception as exc:
raise ValueError(f"Failed to read {plain}: {exc}") from exc
if is_allowed_local_path(plain, get_sdk_cwd()):
resolved = os.path.realpath(os.path.expanduser(plain))
try:
with open(resolved, "rb") as fh:
return fh.read()
except FileNotFoundError:
raise ValueError(f"File not found: {plain}")
except Exception as exc:
raise ValueError(f"Failed to read {plain}: {exc}") from exc
sandbox = get_current_sandbox()
if sandbox is not None:
try:
remote = resolve_sandbox_path(plain)
except ValueError as exc:
raise ValueError(
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
) from exc
try:
return bytes(await sandbox.files.read(remote, format="bytes"))
except Exception as exc:
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
raise ValueError(
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
)
async def resolve_file_ref(
ref: FileRef,
user_id: str | None,
session: ChatSession,
) -> str:
"""Resolve a :class:`FileRef` to its text content."""
raw = await read_file_bytes(ref.uri, user_id, session)
return _apply_line_range(
raw.decode("utf-8", errors="replace"), ref.start_line, ref.end_line
)
async def expand_file_refs_in_string(
text: str,
user_id: str | None,
session: "ChatSession",
*,
raise_on_error: bool = False,
) -> str:
"""Expand all ``@@agptfile:...`` tokens in *text*, returning the substituted string.
Non-reference text is passed through unchanged.
If *raise_on_error* is ``False`` (default), expansion errors are surfaced
inline as ``[file-ref error: <message>]`` — useful for display/log contexts
where partial expansion is acceptable.
If *raise_on_error* is ``True``, any resolution failure raises
:class:`FileRefExpansionError` immediately so the caller can block the
operation and surface a clean error to the model.
"""
if FILE_REF_PREFIX not in text:
return text
result: list[str] = []
last_end = 0
total_chars = 0
for m in _FILE_REF_RE.finditer(text):
result.append(text[last_end : m.start()])
start = int(m.group(2)) if m.group(2) else None
end = int(m.group(3)) if m.group(3) else None
if (start is not None and start < 1) or (end is not None and end < 1):
msg = f"line numbers must be >= 1: {m.group(0)}"
if raise_on_error:
raise FileRefExpansionError(msg)
result.append(f"[file-ref error: {msg}]")
last_end = m.end()
continue
if start is not None and end is not None and end < start:
msg = f"end line must be >= start line: {m.group(0)}"
if raise_on_error:
raise FileRefExpansionError(msg)
result.append(f"[file-ref error: {msg}]")
last_end = m.end()
continue
ref = FileRef(uri=m.group(1), start_line=start, end_line=end)
try:
content = await resolve_file_ref(ref, user_id, session)
if len(content) > _MAX_EXPAND_CHARS:
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
if remaining <= 0:
content = "[file-ref budget exhausted: total expansion limit reached]"
elif len(content) > remaining:
content = content[:remaining] + "\n... [total budget exhausted]"
total_chars += len(content)
result.append(content)
except ValueError as exc:
logger.warning("file-ref expansion failed for %r: %s", m.group(0), exc)
if raise_on_error:
raise FileRefExpansionError(str(exc)) from exc
result.append(f"[file-ref error: {exc}]")
last_end = m.end()
result.append(text[last_end:])
return "".join(result)
async def expand_file_refs_in_args(
args: dict[str, Any],
user_id: str | None,
session: "ChatSession",
) -> dict[str, Any]:
"""Recursively expand ``@@agptfile:...`` references in tool call arguments.
String values are expanded in-place. Nested dicts and lists are
traversed. Non-string scalars are returned unchanged.
Raises :class:`FileRefExpansionError` if any reference fails to resolve,
so the tool is *not* executed with an error string as its input. The
caller (the MCP tool wrapper) should convert this into an MCP error
response that lets the model correct the reference before retrying.
"""
if not args:
return args
async def _expand(value: Any) -> Any:
if isinstance(value, str):
return await expand_file_refs_in_string(
value, user_id, session, raise_on_error=True
)
if isinstance(value, dict):
return {k: await _expand(v) for k, v in value.items()}
if isinstance(value, list):
return [await _expand(item) for item in value]
return value
return {k: await _expand(v) for k, v in args.items()}

View File

@@ -0,0 +1,328 @@
"""Integration tests for @@agptfile: reference expansion in tool calls.
These tests verify the end-to-end behaviour of the file reference protocol:
- Parsing @@agptfile: tokens from tool arguments
- Resolving local-filesystem paths (sdk_cwd / ephemeral)
- Expanding references inside the tool-call pipeline (_execute_tool_sync)
- The extended Read tool handler (workspace:// pass-through via session context)
No real LLM or database is required; workspace reads are stubbed where needed.
"""
from __future__ import annotations
import os
import tempfile
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.sdk.file_ref import (
FileRef,
expand_file_refs_in_args,
expand_file_refs_in_string,
read_file_bytes,
resolve_file_ref,
)
from backend.copilot.sdk.tool_adapter import _read_file_handler
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_session(session_id: str = "integ-sess") -> MagicMock:
s = MagicMock()
s.session_id = session_id
return s
# ---------------------------------------------------------------------------
# Local-file resolution (sdk_cwd)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_resolve_file_ref_local_path():
"""resolve_file_ref reads a real local file when it's within sdk_cwd."""
with tempfile.TemporaryDirectory() as sdk_cwd:
# Write a test file inside sdk_cwd
test_file = os.path.join(sdk_cwd, "hello.txt")
with open(test_file, "w") as f:
f.write("line1\nline2\nline3\n")
session = _make_session()
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
ref = FileRef(uri=test_file, start_line=None, end_line=None)
content = await resolve_file_ref(ref, user_id="u1", session=session)
assert content == "line1\nline2\nline3\n"
@pytest.mark.asyncio
async def test_resolve_file_ref_local_path_with_line_range():
"""resolve_file_ref respects line ranges for local files."""
with tempfile.TemporaryDirectory() as sdk_cwd:
test_file = os.path.join(sdk_cwd, "multi.txt")
lines = [f"line{i}\n" for i in range(1, 11)] # line1 … line10
with open(test_file, "w") as f:
f.writelines(lines)
session = _make_session()
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
ref = FileRef(uri=test_file, start_line=3, end_line=5)
content = await resolve_file_ref(ref, user_id="u1", session=session)
assert content == "line3\nline4\nline5\n"
@pytest.mark.asyncio
async def test_resolve_file_ref_rejects_path_outside_sdk_cwd():
"""resolve_file_ref raises ValueError for paths outside sdk_cwd."""
with tempfile.TemporaryDirectory() as sdk_cwd:
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var:
mock_cwd_var.get.return_value = sdk_cwd
mock_sandbox_var.get.return_value = None
ref = FileRef(uri="/etc/passwd", start_line=None, end_line=None)
with pytest.raises(ValueError, match="not allowed"):
await resolve_file_ref(ref, user_id="u1", session=_make_session())
# ---------------------------------------------------------------------------
# expand_file_refs_in_string — integration with real files
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_string_with_real_file():
"""expand_file_refs_in_string replaces @@agptfile: token with actual content."""
with tempfile.TemporaryDirectory() as sdk_cwd:
test_file = os.path.join(sdk_cwd, "data.txt")
with open(test_file, "w") as f:
f.write("hello world\n")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_string(
f"Content: @@agptfile:{test_file}",
user_id="u1",
session=_make_session(),
)
assert result == "Content: hello world\n"
@pytest.mark.asyncio
async def test_expand_string_missing_file_is_surfaced_inline():
"""Missing file ref yields [file-ref error: …] inline rather than raising."""
with tempfile.TemporaryDirectory() as sdk_cwd:
missing = os.path.join(sdk_cwd, "does_not_exist.txt")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_string(
f"@@agptfile:{missing}",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "not found" in result.lower() or "not allowed" in result.lower()
# ---------------------------------------------------------------------------
# expand_file_refs_in_args — dict traversal with real files
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_args_replaces_file_ref_in_nested_dict():
"""Nested @@agptfile: references in args are fully expanded."""
with tempfile.TemporaryDirectory() as sdk_cwd:
file_a = os.path.join(sdk_cwd, "a.txt")
file_b = os.path.join(sdk_cwd, "b.txt")
with open(file_a, "w") as f:
f.write("AAA")
with open(file_b, "w") as f:
f.write("BBB")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{
"outer": {
"content_a": f"@@agptfile:{file_a}",
"content_b": f"start @@agptfile:{file_b} end",
},
"count": 42,
},
user_id="u1",
session=_make_session(),
)
assert result["outer"]["content_a"] == "AAA"
assert result["outer"]["content_b"] == "start BBB end"
assert result["count"] == 42
# ---------------------------------------------------------------------------
# _read_file_handler — extended to accept workspace:// and local paths
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_file_handler_local_file():
"""_read_file_handler reads a local file when it's within sdk_cwd."""
with tempfile.TemporaryDirectory() as sdk_cwd:
test_file = os.path.join(sdk_cwd, "read_test.txt")
lines = [f"L{i}\n" for i in range(1, 6)]
with open(test_file, "w") as f:
f.writelines(lines)
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj_var, patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", _make_session()),
):
mock_cwd_var.get.return_value = sdk_cwd
mock_proj_var.get.return_value = ""
result = await _read_file_handler(
{"file_path": test_file, "offset": 0, "limit": 5}
)
assert not result["isError"]
text = result["content"][0]["text"]
assert "L1" in text
assert "L5" in text
@pytest.mark.asyncio
async def test_read_file_handler_workspace_uri():
"""_read_file_handler handles workspace:// URIs via the workspace manager."""
mock_session = _make_session()
mock_manager = AsyncMock()
mock_manager.read_file_by_id.return_value = b"workspace file content\nline two\n"
with patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", mock_session),
), patch(
"backend.copilot.sdk.file_ref.get_manager",
new=AsyncMock(return_value=mock_manager),
):
result = await _read_file_handler(
{"file_path": "workspace://file-id-abc", "offset": 0, "limit": 10}
)
assert not result["isError"], result["content"][0]["text"]
text = result["content"][0]["text"]
assert "workspace file content" in text
assert "line two" in text
@pytest.mark.asyncio
async def test_read_file_handler_workspace_uri_no_session():
"""_read_file_handler returns error when workspace:// is used without session."""
with patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=(None, None),
):
result = await _read_file_handler({"file_path": "workspace://some-id"})
assert result["isError"]
assert "session" in result["content"][0]["text"].lower()
@pytest.mark.asyncio
async def test_read_file_handler_access_denied():
"""_read_file_handler rejects paths outside allowed locations."""
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox, patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", _make_session()),
):
mock_cwd.get.return_value = "/tmp/safe-dir"
mock_sandbox.get.return_value = None
result = await _read_file_handler({"file_path": "/etc/passwd"})
assert result["isError"]
assert "not allowed" in result["content"][0]["text"].lower()
# ---------------------------------------------------------------------------
# read_file_bytes — workspace:///path (virtual path) and E2B sandbox branch
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_file_bytes_workspace_virtual_path():
"""workspace:///path resolves via manager.read_file (is_path=True path)."""
session = _make_session()
mock_manager = AsyncMock()
mock_manager.read_file.return_value = b"virtual path content"
with patch(
"backend.copilot.sdk.file_ref.get_manager",
new=AsyncMock(return_value=mock_manager),
):
result = await read_file_bytes("workspace:///reports/q1.md", "user-1", session)
assert result == b"virtual path content"
mock_manager.read_file.assert_awaited_once_with("/reports/q1.md")
@pytest.mark.asyncio
async def test_read_file_bytes_e2b_sandbox_branch():
"""read_file_bytes reads from the E2B sandbox when a sandbox is active."""
session = _make_session()
mock_sandbox = AsyncMock()
mock_sandbox.files.read.return_value = bytearray(b"sandbox content")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj:
mock_cwd.get.return_value = ""
mock_sandbox_var.get.return_value = mock_sandbox
mock_proj.get.return_value = ""
result = await read_file_bytes("/home/user/script.sh", None, session)
assert result == b"sandbox content"
mock_sandbox.files.read.assert_awaited_once_with(
"/home/user/script.sh", format="bytes"
)
@pytest.mark.asyncio
async def test_read_file_bytes_e2b_path_escapes_sandbox_raises():
"""read_file_bytes raises ValueError for paths that escape the sandbox root."""
session = _make_session()
mock_sandbox = AsyncMock()
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj:
mock_cwd.get.return_value = ""
mock_sandbox_var.get.return_value = mock_sandbox
mock_proj.get.return_value = ""
with pytest.raises(ValueError, match="not allowed"):
await read_file_bytes("/etc/passwd", None, session)

View File

@@ -0,0 +1,382 @@
"""Tests for the @@agptfile: reference protocol (file_ref.py)."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.sdk.file_ref import (
_MAX_EXPAND_CHARS,
FileRef,
FileRefExpansionError,
_apply_line_range,
expand_file_refs_in_args,
expand_file_refs_in_string,
parse_file_ref,
)
# ---------------------------------------------------------------------------
# parse_file_ref
# ---------------------------------------------------------------------------
def test_parse_file_ref_workspace_id():
ref = parse_file_ref("@@agptfile:workspace://abc123")
assert ref == FileRef(uri="workspace://abc123", start_line=None, end_line=None)
def test_parse_file_ref_workspace_id_with_mime():
ref = parse_file_ref("@@agptfile:workspace://abc123#text/plain")
assert ref is not None
assert ref.uri == "workspace://abc123#text/plain"
assert ref.start_line is None
def test_parse_file_ref_workspace_path():
ref = parse_file_ref("@@agptfile:workspace:///reports/q1.md")
assert ref is not None
assert ref.uri == "workspace:///reports/q1.md"
def test_parse_file_ref_with_line_range():
ref = parse_file_ref("@@agptfile:workspace://abc123[10-50]")
assert ref == FileRef(uri="workspace://abc123", start_line=10, end_line=50)
def test_parse_file_ref_local_path():
ref = parse_file_ref("@@agptfile:/tmp/copilot-session/output.py[1-100]")
assert ref is not None
assert ref.uri == "/tmp/copilot-session/output.py"
assert ref.start_line == 1
assert ref.end_line == 100
def test_parse_file_ref_no_match():
assert parse_file_ref("just a normal string") is None
assert parse_file_ref("workspace://abc123") is None # missing @@agptfile: prefix
assert (
parse_file_ref("@@agptfile:workspace://abc123 extra") is None
) # not full match
def test_parse_file_ref_strips_whitespace():
ref = parse_file_ref(" @@agptfile:workspace://abc123 ")
assert ref is not None
assert ref.uri == "workspace://abc123"
def test_parse_file_ref_invalid_range_zero_start():
assert parse_file_ref("@@agptfile:workspace://abc123[0-5]") is None
def test_parse_file_ref_invalid_range_end_less_than_start():
assert parse_file_ref("@@agptfile:workspace://abc123[10-5]") is None
def test_parse_file_ref_invalid_range_zero_end():
assert parse_file_ref("@@agptfile:workspace://abc123[1-0]") is None
# ---------------------------------------------------------------------------
# _apply_line_range
# ---------------------------------------------------------------------------
TEXT = "line1\nline2\nline3\nline4\nline5\n"
def test_apply_line_range_no_range():
assert _apply_line_range(TEXT, None, None) == TEXT
def test_apply_line_range_start_only():
result = _apply_line_range(TEXT, 3, None)
assert result == "line3\nline4\nline5\n"
def test_apply_line_range_full():
result = _apply_line_range(TEXT, 2, 4)
assert result == "line2\nline3\nline4\n"
def test_apply_line_range_single_line():
result = _apply_line_range(TEXT, 2, 2)
assert result == "line2\n"
def test_apply_line_range_beyond_eof():
result = _apply_line_range(TEXT, 4, 999)
assert result == "line4\nline5\n"
# ---------------------------------------------------------------------------
# expand_file_refs_in_string
# ---------------------------------------------------------------------------
def _make_session(session_id: str = "sess-1") -> MagicMock:
session = MagicMock()
session.session_id = session_id
return session
async def _resolve_always(ref: FileRef, _user_id: str | None, _session: object) -> str:
"""Stub resolver that returns the URI and range as a descriptive string."""
if ref.start_line is not None:
return f"content:{ref.uri}[{ref.start_line}-{ref.end_line}]"
return f"content:{ref.uri}"
@pytest.mark.asyncio
async def test_expand_no_refs():
result = await expand_file_refs_in_string(
"no references here", user_id="u1", session=_make_session()
)
assert result == "no references here"
@pytest.mark.asyncio
async def test_expand_single_ref():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://abc123",
user_id="u1",
session=_make_session(),
)
assert result == "content:workspace://abc123"
@pytest.mark.asyncio
async def test_expand_ref_with_range():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://abc123[10-50]",
user_id="u1",
session=_make_session(),
)
assert result == "content:workspace://abc123[10-50]"
@pytest.mark.asyncio
async def test_expand_ref_embedded_in_text():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_string(
"Here is the file: @@agptfile:workspace://abc123 — done",
user_id="u1",
session=_make_session(),
)
assert result == "Here is the file: content:workspace://abc123 — done"
@pytest.mark.asyncio
async def test_expand_multiple_refs():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://file1 and @@agptfile:workspace://file2[1-5]",
user_id="u1",
session=_make_session(),
)
assert result == "content:workspace://file1 and content:workspace://file2[1-5]"
@pytest.mark.asyncio
async def test_expand_invalid_range_zero_start_surfaces_inline():
"""expand_file_refs_in_string surfaces [file-ref error: ...] for zero-start ranges."""
result = await expand_file_refs_in_string(
"@@agptfile:workspace://abc123[0-5]",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "line numbers must be >= 1" in result
@pytest.mark.asyncio
async def test_expand_invalid_range_end_less_than_start_surfaces_inline():
"""expand_file_refs_in_string surfaces [file-ref error: ...] when end < start."""
result = await expand_file_refs_in_string(
"prefix @@agptfile:workspace://abc123[10-5] suffix",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "end line must be >= start line" in result
assert "prefix" in result
assert "suffix" in result
@pytest.mark.asyncio
async def test_expand_ref_error_surfaces_inline():
async def _raise(*args, **kwargs): # noqa: ARG001
raise ValueError("file not found")
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_raise),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://bad",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "file not found" in result
# ---------------------------------------------------------------------------
# expand_file_refs_in_args
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_args_flat():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_args(
{"content": "@@agptfile:workspace://abc123", "other": 42},
user_id="u1",
session=_make_session(),
)
assert result["content"] == "content:workspace://abc123"
assert result["other"] == 42
@pytest.mark.asyncio
async def test_expand_args_nested_dict():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_args(
{"outer": {"inner": "@@agptfile:workspace://nested"}},
user_id="u1",
session=_make_session(),
)
assert result["outer"]["inner"] == "content:workspace://nested"
@pytest.mark.asyncio
async def test_expand_args_list():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_args(
{
"items": [
"@@agptfile:workspace://a",
"plain",
"@@agptfile:workspace://b[1-3]",
]
},
user_id="u1",
session=_make_session(),
)
assert result["items"] == [
"content:workspace://a",
"plain",
"content:workspace://b[1-3]",
]
@pytest.mark.asyncio
async def test_expand_args_empty():
result = await expand_file_refs_in_args({}, user_id="u1", session=_make_session())
assert result == {}
@pytest.mark.asyncio
async def test_expand_args_no_refs():
result = await expand_file_refs_in_args(
{"key": "no refs here", "num": 1},
user_id="u1",
session=_make_session(),
)
assert result == {"key": "no refs here", "num": 1}
@pytest.mark.asyncio
async def test_expand_args_raises_on_file_ref_error():
"""expand_file_refs_in_args raises FileRefExpansionError instead of passing
the inline error string to the tool, blocking tool execution."""
async def _raise(*args, **kwargs): # noqa: ARG001
raise ValueError("path does not exist")
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_raise),
):
with pytest.raises(FileRefExpansionError) as exc_info:
await expand_file_refs_in_args(
{"prompt": "@@agptfile:/home/user/missing.txt"},
user_id="u1",
session=_make_session(),
)
assert "path does not exist" in str(exc_info.value)
# ---------------------------------------------------------------------------
# Per-file truncation and aggregate budget
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_per_file_truncation():
"""Content exceeding _MAX_EXPAND_CHARS is truncated with a marker."""
oversized = "x" * (_MAX_EXPAND_CHARS + 100)
async def _resolve_oversized(ref: FileRef, _uid: str | None, _s: object) -> str:
return oversized
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_oversized),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://big-file",
user_id="u1",
session=_make_session(),
)
assert len(result) <= _MAX_EXPAND_CHARS + len("\n... [truncated]") + 10
assert "[truncated]" in result
@pytest.mark.asyncio
async def test_expand_aggregate_budget_exhausted():
"""When the aggregate budget is exhausted, later refs get the budget message."""
# Each file returns just under 300K; after ~4 files the 1M budget is used.
big_chunk = "y" * 300_000
async def _resolve_big(ref: FileRef, _uid: str | None, _s: object) -> str:
return big_chunk
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_big),
):
# 5 refs @ 300K each = 1.5M → last ref(s) should hit the aggregate limit
refs = " ".join(f"@@agptfile:workspace://f{i}" for i in range(5))
result = await expand_file_refs_in_string(
refs,
user_id="u1",
session=_make_session(),
)
assert "budget exhausted" in result

View File

@@ -0,0 +1,28 @@
## MCP Tool Guide
### Workflow
`run_mcp_tool` follows a two-step pattern:
1. **Discover** — call with only `server_url` to list available tools on the server.
2. **Execute** — call again with `server_url`, `tool_name`, and `tool_arguments` to run a tool.
### Known hosted MCP servers
Use these URLs directly without asking the user:
| Service | URL |
|---|---|
| Notion | `https://mcp.notion.com/mcp` |
| Linear | `https://mcp.linear.app/mcp` |
| Stripe | `https://mcp.stripe.com` |
| Intercom | `https://mcp.intercom.com/mcp` |
| Cloudflare | `https://mcp.cloudflare.com/mcp` |
| Atlassian / Jira | `https://mcp.atlassian.com/mcp` |
For other services, search the MCP registry at https://registry.modelcontextprotocol.io/.
### Authentication
If the server requires credentials, a `SetupRequirementsResponse` is returned with an OAuth
login prompt. Once the user completes the flow and confirms, retry the same call immediately.

View File

@@ -536,10 +536,12 @@ async def test_wait_for_stash_signaled():
result = await wait_for_stash(timeout=1.0)
assert result is True
assert _pto.get({}).get("WebSearch") == ["result data"]
pto = _pto.get()
assert pto is not None
assert pto.get("WebSearch") == ["result data"]
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_pto.set({})
_stash_event.set(None)
@@ -554,7 +556,7 @@ async def test_wait_for_stash_timeout():
assert result is False
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_pto.set({})
_stash_event.set(None)
@@ -573,10 +575,12 @@ async def test_wait_for_stash_already_stashed():
assert result is True
# But the stash itself is populated
assert _pto.get({}).get("Read") == ["file contents"]
pto = _pto.get()
assert pto is not None
assert pto.get("Read") == ["file contents"]
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_pto.set({})
_stash_event.set(None)

View File

@@ -10,12 +10,13 @@ import re
from collections.abc import Callable
from typing import Any, cast
from backend.copilot.context import is_allowed_local_path
from .tool_adapter import (
BLOCKED_TOOLS,
DANGEROUS_PATTERNS,
MCP_TOOL_PREFIX,
WORKSPACE_SCOPED_TOOLS,
is_allowed_local_path,
stash_pending_tool_output,
)
@@ -127,7 +128,6 @@ def create_security_hooks(
sdk_cwd: str | None = None,
max_subtasks: int = 3,
on_compact: Callable[[], None] | None = None,
on_stop: Callable[[str, str], None] | None = None,
) -> dict[str, Any]:
"""Create the security hooks configuration for Claude Agent SDK.
@@ -136,15 +136,12 @@ def create_security_hooks(
- PostToolUse: Log successful tool executions
- PostToolUseFailure: Log and handle failed tool executions
- PreCompact: Log context compaction events (SDK handles compaction automatically)
- Stop: Capture transcript path for stateless resume (when *on_stop* is provided)
Args:
user_id: Current user ID for isolation validation
sdk_cwd: SDK working directory for workspace-scoped tool validation
max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when
the SDK finishes processing — used to read the JSONL transcript
before the CLI process exits.
on_compact: Callback invoked when SDK starts compacting context.
Returns:
Hooks configuration dict for ClaudeAgentOptions
@@ -311,30 +308,6 @@ def create_security_hooks(
on_compact()
return cast(SyncHookJSONOutput, {})
# --- Stop hook: capture transcript path for stateless resume ---
async def stop_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Capture transcript path when SDK finishes processing.
The Stop hook fires while the CLI process is still alive, giving us
a reliable window to read the JSONL transcript before SIGTERM.
"""
_ = context, tool_use_id
transcript_path = cast(str, input_data.get("transcript_path", ""))
sdk_session_id = cast(str, input_data.get("session_id", ""))
if transcript_path and on_stop:
logger.info(
f"[SDK] Stop hook: transcript_path={transcript_path}, "
f"sdk_session_id={sdk_session_id[:12]}..."
)
on_stop(transcript_path, sdk_session_id)
return cast(SyncHookJSONOutput, {})
hooks: dict[str, Any] = {
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
@@ -344,9 +317,6 @@ def create_security_hooks(
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
}
if on_stop is not None:
hooks["Stop"] = [HookMatcher(matcher=None, hooks=[stop_hook])]
return hooks
except ImportError:
# Fallback for when SDK isn't available - return empty hooks

View File

@@ -9,8 +9,9 @@ import os
import pytest
from backend.copilot.context import _current_project_dir
from .security_hooks import _validate_tool_access, _validate_user_isolation
from .service import _is_tool_error_or_denial
SDK_CWD = "/tmp/copilot-abc123"
@@ -120,8 +121,6 @@ def test_read_no_cwd_denies_absolute():
def test_read_tool_results_allowed():
from .tool_adapter import _current_project_dir
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
# is_allowed_local_path requires the session's encoded cwd to be set
@@ -133,16 +132,14 @@ def test_read_tool_results_allowed():
_current_project_dir.reset(token)
def test_read_claude_projects_session_dir_allowed():
"""Files within the current session's project dir are allowed."""
from .tool_adapter import _current_project_dir
def test_read_claude_projects_settings_json_denied():
"""SDK-internal artifacts like settings.json are NOT accessible — only tool-results/ is."""
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
token = _current_project_dir.set("-tmp-copilot-abc123")
try:
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert not _is_denied(result)
assert _is_denied(result)
finally:
_current_project_dir.reset(token)
@@ -357,76 +354,3 @@ async def test_task_slot_released_on_failure(_hooks):
context={},
)
assert not _is_denied(result)
# -- _is_tool_error_or_denial ------------------------------------------------
class TestIsToolErrorOrDenial:
def test_none_content(self):
assert _is_tool_error_or_denial(None) is False
def test_empty_content(self):
assert _is_tool_error_or_denial("") is False
def test_benign_output(self):
assert _is_tool_error_or_denial("All good, no issues.") is False
def test_security_marker(self):
assert _is_tool_error_or_denial("[SECURITY] Tool access blocked") is True
def test_cannot_be_bypassed(self):
assert _is_tool_error_or_denial("This restriction cannot be bypassed.") is True
def test_not_allowed(self):
assert _is_tool_error_or_denial("Operation not allowed in sandbox") is True
def test_background_task_denial(self):
assert (
_is_tool_error_or_denial(
"Background task execution is not supported. "
"Run tasks in the foreground instead."
)
is True
)
def test_subtask_limit_denial(self):
assert (
_is_tool_error_or_denial(
"Maximum 2 concurrent sub-tasks. "
"Wait for running sub-tasks to finish, "
"or continue in the main conversation."
)
is True
)
def test_denied_marker(self):
assert (
_is_tool_error_or_denial("Access denied: insufficient privileges") is True
)
def test_blocked_marker(self):
assert _is_tool_error_or_denial("Request blocked by security policy") is True
def test_failed_marker(self):
assert _is_tool_error_or_denial("Failed to execute tool: timeout") is True
def test_mcp_iserror(self):
assert _is_tool_error_or_denial('{"isError": true, "content": []}') is True
def test_benign_error_in_value(self):
"""Content like '0 errors found' should not trigger — 'error' was removed."""
assert _is_tool_error_or_denial("0 errors found") is False
def test_benign_permission_field(self):
"""Schema descriptions mentioning 'permission' should not trigger."""
assert (
_is_tool_error_or_denial(
'{"fields": [{"name": "permission_level", "type": "int"}]}'
)
is False
)
def test_benign_not_found_in_listing(self):
"""File listing containing 'not found' in filenames should not trigger."""
assert _is_tool_error_or_denial("readme.md\nfile-not-found-handler.py") is False

View File

@@ -12,7 +12,6 @@ import subprocess
import sys
import uuid
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any, cast
import openai
@@ -21,6 +20,9 @@ from claude_agent_sdk import (
ClaudeAgentOptions,
ClaudeSDKClient,
ResultMessage,
TextBlock,
ThinkingBlock,
ToolResultBlock,
ToolUseBlock,
)
from langfuse import propagate_attributes
@@ -42,6 +44,7 @@ from ..model import (
update_session_title,
upsert_chat_session,
)
from ..prompting import get_sdk_supplement
from ..response_model import (
StreamBaseResponse,
StreamError,
@@ -57,7 +60,7 @@ from ..service import (
_generate_session_title,
_is_langfuse_configured,
)
from ..tools.e2b_sandbox import get_or_create_sandbox
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tools.workspace_files import get_manager
from ..tracking import track_user_message
@@ -74,11 +77,11 @@ from .tool_adapter import (
from .transcript import (
cleanup_cli_project_dir,
download_transcript,
read_transcript_file,
upload_transcript,
validate_transcript,
write_transcript_to_tempfile,
)
from .transcript_builder import TranscriptBuilder
logger = logging.getLogger(__name__)
config = ChatConfig()
@@ -137,19 +140,6 @@ _setup_langfuse_otel()
_background_tasks: set[asyncio.Task[Any]] = set()
@dataclass
class CapturedTranscript:
"""Info captured by the SDK Stop hook for stateless --resume."""
path: str = ""
sdk_session_id: str = ""
raw_content: str = ""
@property
def available(self) -> bool:
return bool(self.path)
_SDK_CWD_PREFIX = WORKSPACE_PREFIX
# Heartbeat interval — keep SSE alive through proxies/LBs during tool execution.
@@ -157,140 +147,6 @@ _SDK_CWD_PREFIX = WORKSPACE_PREFIX
_HEARTBEAT_INTERVAL = 10.0 # seconds
# Appended to the system prompt to inform the agent about available tools.
# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead,
# which has kernel-level network isolation (unshare --net).
_SHARED_TOOL_NOTES = """\
### Sharing files with the user
After saving a file to the persistent workspace with `write_workspace_file`,
share it with the user by embedding the `download_url` from the response in
your message as a Markdown link or image:
- **Any file** — shows as a clickable download link:
`[report.csv](workspace://file_id#text/csv)`
- **Image** — renders inline in chat:
`![chart](workspace://file_id#image/png)`
- **Video** — renders inline in chat with player controls:
`![recording](workspace://file_id#video/mp4)`
The `download_url` field in the `write_workspace_file` response is already
in the correct format — paste it directly after the `(` in the Markdown.
### Long-running tools
Long-running tools (create_agent, edit_agent, etc.) are handled
asynchronously. You will receive an immediate response; the actual result
is delivered to the user via a background stream.
### Large tool outputs
When a tool output exceeds the display limit, it is automatically saved to
the persistent workspace. The truncated output includes a
`<tool-output-truncated>` tag with the workspace path. Use
`read_workspace_file(path="...", offset=N, length=50000)` to retrieve
additional sections.
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
All tasks must run in the foreground.
"""
_LOCAL_TOOL_SUPPLEMENT = (
"""
## Tool notes
### Shell commands
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
for shell commands — it runs in a network-isolated sandbox.
### Working directory
- Your working directory is: `{cwd}`
- All SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec` operate inside this
directory. This is the ONLY writable path — do not attempt to read or write
anywhere else on the filesystem.
- Use relative paths or absolute paths under `{cwd}` for all file operations.
### Two storage systems — CRITICAL to understand
1. **Ephemeral working directory** (`{cwd}`):
- Shared by SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec`
- Files here are **lost between turns** — do NOT rely on them persisting
- Use for temporary work: running scripts, processing data, etc.
2. **Persistent workspace** (cloud storage):
- Files here **survive across turns and sessions**
- Use `write_workspace_file` to save important files (code, outputs, configs)
- Use `read_workspace_file` to retrieve previously saved files
- Use `list_workspace_files` to see what files you've saved before
- Call `list_workspace_files(include_all_sessions=True)` to see files from
all sessions
### Moving files between ephemeral and persistent storage
- **Ephemeral → Persistent**: Use `write_workspace_file` with either:
- `content` param (plain text) — for text files
- `source_path` param — to copy any file directly from the ephemeral dir
- **Persistent → Ephemeral**: Use `read_workspace_file` with `save_to_path`
param to download a workspace file to the ephemeral dir for processing
### File persistence workflow
When you create or modify important files (code, configs, outputs), you MUST:
1. Save them using `write_workspace_file` so they persist
2. At the start of a new turn, call `list_workspace_files` to see what files
are available from previous turns
"""
+ _SHARED_TOOL_NOTES
)
_E2B_TOOL_SUPPLEMENT = (
"""
## Tool notes
### Shell commands
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
for shell commands — it runs in a cloud sandbox with full internet access.
### Working directory
- Your working directory is: `/home/user` (cloud sandbox)
- All file tools (`read_file`, `write_file`, `edit_file`, `glob`, `grep`)
AND `bash_exec` operate on the **same cloud sandbox filesystem**.
- Files created by `bash_exec` are immediately visible to `read_file` and
vice-versa — they share one filesystem.
- Use relative paths (resolved from `/home/user`) or absolute paths.
### Two storage systems — CRITICAL to understand
1. **Cloud sandbox** (`/home/user`):
- Shared by all file tools AND `bash_exec` — same filesystem
- Files **persist across turns** within the current session
- Full Linux environment with internet access
- Lost when the session expires (12 h inactivity)
2. **Persistent workspace** (cloud storage):
- Files here **survive across sessions indefinitely**
- Use `write_workspace_file` to save important files permanently
- Use `read_workspace_file` to retrieve previously saved files
- Use `list_workspace_files` to see what files you've saved before
- Call `list_workspace_files(include_all_sessions=True)` to see files from
all sessions
### Moving files between sandbox and persistent storage
- **Sandbox → Persistent**: Use `write_workspace_file` with `source_path`
to copy from the sandbox to permanent storage
- **Persistent → Sandbox**: Use `read_workspace_file` with `save_to_path`
to download into the sandbox for processing
### File persistence workflow
Important files that must survive beyond this session should be saved with
`write_workspace_file`. Sandbox files persist across turns but are lost
when the session expires.
"""
+ _SHARED_TOOL_NOTES
)
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
@@ -451,6 +307,50 @@ def _cleanup_sdk_tool_results(cwd: str) -> None:
pass
def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
"""Convert SDK content blocks to transcript format.
Handles TextBlock, ToolUseBlock, ToolResultBlock, and ThinkingBlock.
Unknown block types are logged and skipped.
"""
result: list[dict[str, Any]] = []
for block in blocks or []:
if isinstance(block, TextBlock):
result.append({"type": "text", "text": block.text})
elif isinstance(block, ToolUseBlock):
result.append(
{
"type": "tool_use",
"id": block.id,
"name": block.name,
"input": block.input,
}
)
elif isinstance(block, ToolResultBlock):
tool_result_entry: dict[str, Any] = {
"type": "tool_result",
"tool_use_id": block.tool_use_id,
"content": block.content,
}
if block.is_error:
tool_result_entry["is_error"] = True
result.append(tool_result_entry)
elif isinstance(block, ThinkingBlock):
result.append(
{
"type": "thinking",
"thinking": block.thinking,
"signature": block.signature,
}
)
else:
logger.warning(
f"[SDK] Unknown content block type: {type(block).__name__}. "
f"This may indicate a new SDK version with additional block types."
)
return result
async def _compress_messages(
messages: list[ChatMessage],
) -> tuple[list[ChatMessage], bool]:
@@ -556,31 +456,6 @@ def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
def _is_tool_error_or_denial(content: str | None) -> bool:
"""Check if a tool message content indicates an error or denial.
Currently unused — ``_format_conversation_context`` includes all tool
results. Kept as a utility for future selective filtering.
"""
if not content:
return False
lower = content.lower()
return any(
marker in lower
for marker in (
"[security]",
"cannot be bypassed",
"not allowed",
"not supported", # background-task denial
"maximum", # subtask-limit denial
"denied",
"blocked",
"failed to", # internal tool execution failures
'"iserror": true', # MCP protocol error flag
)
)
async def _build_query_message(
current_message: str,
session: ChatSession,
@@ -806,6 +681,11 @@ async def stream_chat_completion_sdk(
user_id=user_id, session_id=session_id, message_length=len(message)
)
# Structured log prefix: [SDK][<session>][T<turn>]
# Turn = number of user messages (1-based), computed AFTER appending the new message.
turn = sum(1 for m in session.messages if m.role == "user")
log_prefix = f"[SDK][{session_id[:12]}][T{turn}]"
session = await upsert_chat_session(session)
# Generate title for new sessions (first user message)
@@ -823,10 +703,11 @@ async def stream_chat_completion_sdk(
message_id = str(uuid.uuid4())
stream_id = str(uuid.uuid4())
stream_completed = False
ended_with_stream_error = False
e2b_sandbox = None
use_resume = False
resume_file: str | None = None
captured_transcript = CapturedTranscript()
transcript_builder = TranscriptBuilder()
sdk_cwd = ""
# Acquire stream lock to prevent concurrent streams to the same session
@@ -841,7 +722,7 @@ async def stream_chat_completion_sdk(
if lock_owner != stream_id:
# Another stream is active
logger.warning(
f"[SDK] Session {session_id} already has an active stream: {lock_owner}"
f"{log_prefix} Session already has an active stream: {lock_owner}"
)
yield StreamError(
errorText="Another stream is already active for this session. "
@@ -865,7 +746,7 @@ async def stream_chat_completion_sdk(
sdk_cwd = _make_sdk_cwd(session_id)
os.makedirs(sdk_cwd, exist_ok=True)
except (ValueError, OSError) as e:
logger.error("[SDK] [%s] Invalid SDK cwd: %s", session_id[:12], e)
logger.error("%s Invalid SDK cwd: %s", log_prefix, e)
yield StreamError(
errorText="Unable to initialize working directory.",
code="sdk_cwd_error",
@@ -878,28 +759,29 @@ async def stream_chat_completion_sdk(
async def _setup_e2b():
"""Set up E2B sandbox if configured, return sandbox or None."""
if config.use_e2b_sandbox and not config.e2b_api_key:
logger.warning(
"[E2B] [%s] E2B sandbox enabled but no API key configured "
"(CHAT_E2B_API_KEY / E2B_API_KEY) — falling back to bubblewrap",
session_id[:12],
)
return None
if config.use_e2b_sandbox and config.e2b_api_key:
try:
return await get_or_create_sandbox(
session_id,
api_key=config.e2b_api_key,
template=config.e2b_sandbox_template,
timeout=config.e2b_sandbox_timeout,
)
except Exception as e2b_err:
logger.error(
"[E2B] [%s] Setup failed: %s",
if not (e2b_api_key := config.active_e2b_api_key):
if config.use_e2b_sandbox:
logger.warning(
"[E2B] [%s] E2B sandbox enabled but no API key configured "
"(CHAT_E2B_API_KEY / E2B_API_KEY) — falling back to bubblewrap",
session_id[:12],
e2b_err,
exc_info=True,
)
return None
try:
return await get_or_create_sandbox(
session_id,
api_key=e2b_api_key,
template=config.e2b_sandbox_template,
timeout=config.e2b_sandbox_timeout,
on_timeout=config.e2b_sandbox_on_timeout,
)
except Exception as e2b_err:
logger.error(
"[E2B] [%s] Setup failed: %s",
session_id[:12],
e2b_err,
exc_info=True,
)
return None
async def _fetch_transcript():
@@ -909,12 +791,13 @@ async def stream_chat_completion_sdk(
):
return None
try:
return await download_transcript(user_id, session_id)
return await download_transcript(
user_id, session_id, log_prefix=log_prefix
)
except Exception as transcript_err:
logger.warning(
"[SDK] [%s] Transcript download failed, continuing without "
"--resume: %s",
session_id[:12],
"%s Transcript download failed, continuing without " "--resume: %s",
log_prefix,
transcript_err,
)
return None
@@ -926,21 +809,26 @@ async def stream_chat_completion_sdk(
)
use_e2b = e2b_sandbox is not None
system_prompt = base_system_prompt + (
_E2B_TOOL_SUPPLEMENT
if use_e2b
else _LOCAL_TOOL_SUPPLEMENT.format(cwd=sdk_cwd)
# Append appropriate supplement (Claude gets tool schemas automatically)
system_prompt = base_system_prompt + get_sdk_supplement(
use_e2b=use_e2b, cwd=sdk_cwd
)
# Process transcript download result
transcript_msg_count = 0
if dl:
is_valid = validate_transcript(dl.content)
dl_lines = dl.content.strip().split("\n") if dl.content else []
logger.info(
"%s Downloaded transcript: %dB, %d lines, " "msg_count=%d, valid=%s",
log_prefix,
len(dl.content),
len(dl_lines),
dl.message_count,
is_valid,
)
if is_valid:
logger.info(
f"[SDK] Transcript available for session {session_id}: "
f"{len(dl.content)}B, msg_count={dl.message_count}"
)
# Load previous FULL context into builder
transcript_builder.load_previous(dl.content, log_prefix=log_prefix)
resume_file = write_transcript_to_tempfile(
dl.content, session_id, sdk_cwd
)
@@ -948,16 +836,14 @@ async def stream_chat_completion_sdk(
use_resume = True
transcript_msg_count = dl.message_count
logger.debug(
f"[SDK] Using --resume ({len(dl.content)}B, "
f"{log_prefix} Using --resume ({len(dl.content)}B, "
f"msg_count={transcript_msg_count})"
)
else:
logger.warning(
f"[SDK] Transcript downloaded but invalid for {session_id}"
)
logger.warning(f"{log_prefix} Transcript downloaded but invalid")
elif config.claude_agent_use_resume and user_id and len(session.messages) > 1:
logger.warning(
f"[SDK] No transcript available for {session_id} "
f"{log_prefix} No transcript available "
f"({len(session.messages)} messages in session)"
)
@@ -979,25 +865,6 @@ async def stream_chat_completion_sdk(
sdk_model = _resolve_sdk_model()
# --- Transcript capture via Stop hook ---
# Read the file content immediately — the SDK may clean up
# the file before our finally block runs.
def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
captured_transcript.path = transcript_path
captured_transcript.sdk_session_id = sdk_session_id
content = read_transcript_file(transcript_path)
if content:
captured_transcript.raw_content = content
logger.info(
f"[SDK] Stop hook: captured {len(content)}B from "
f"{transcript_path}"
)
else:
logger.warning(
f"[SDK] Stop hook: transcript file empty/missing at "
f"{transcript_path}"
)
# Track SDK-internal compaction (PreCompact hook → start, next msg → end)
compaction = CompactionTracker()
@@ -1005,12 +872,16 @@ async def stream_chat_completion_sdk(
user_id,
sdk_cwd=sdk_cwd,
max_subtasks=config.claude_agent_max_subtasks,
on_stop=_on_stop if config.claude_agent_use_resume else None,
on_compact=compaction.on_compact,
)
allowed = get_copilot_tool_names(use_e2b=use_e2b)
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
def _on_stderr(line: str) -> None:
sid = session_id[:12] if session_id else "?"
logger.info("[SDK] [%s] CLI stderr: %s", sid, line.rstrip())
sdk_options_kwargs: dict[str, Any] = {
"system_prompt": system_prompt,
"mcp_servers": {"copilot": mcp_server},
@@ -1019,6 +890,7 @@ async def stream_chat_completion_sdk(
"hooks": security_hooks,
"cwd": sdk_cwd,
"max_buffer_size": config.claude_agent_max_buffer_size,
"stderr": _on_stderr,
}
if sdk_model:
sdk_options_kwargs["model"] = sdk_model
@@ -1040,7 +912,10 @@ async def stream_chat_completion_sdk(
session_id=session_id,
trace_name="copilot-sdk",
tags=["sdk"],
metadata={"resume": str(use_resume)},
metadata={
"resume": str(use_resume),
"conversation_turn": str(turn),
},
)
_otel_ctx.__enter__()
@@ -1074,9 +949,9 @@ async def stream_chat_completion_sdk(
query_message = f"{query_message}\n\n{attachments.hint}"
logger.info(
"[SDK] [%s] Sending query — resume=%s, total_msgs=%d, "
"%s Sending query — resume=%s, total_msgs=%d, "
"query_len=%d, attached_files=%d, image_blocks=%d",
session_id[:12],
log_prefix,
use_resume,
len(session.messages),
len(query_message),
@@ -1105,15 +980,19 @@ async def stream_chat_completion_sdk(
await client._transport.write( # noqa: SLF001
json.dumps(user_msg) + "\n"
)
# Capture user message in transcript (multimodal)
transcript_builder.append_user(content=content_blocks)
else:
await client.query(query_message, session_id=session_id)
# Capture actual user message in transcript (not the engineered query)
# query_message may include context wrappers, but transcript needs raw input
transcript_builder.append_user(content=current_message)
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False
ended_with_stream_error = False
# Use an explicit async iterator with non-cancelling heartbeats.
# CRITICAL: we must NOT cancel __anext__() mid-flight — doing so
# (via asyncio.timeout or wait_for) corrupts the SDK's internal
@@ -1150,8 +1029,8 @@ async def stream_chat_completion_sdk(
sdk_msg = done.pop().result()
except StopAsyncIteration:
logger.info(
"[SDK] [%s] Stream ended normally (StopAsyncIteration)",
session_id[:12],
"%s Stream ended normally (StopAsyncIteration)",
log_prefix,
)
break
except Exception as stream_err:
@@ -1160,8 +1039,8 @@ async def stream_chat_completion_sdk(
# so the session can still be saved and the
# frontend gets a clean finish.
logger.error(
"[SDK] [%s] Stream error from SDK: %s",
session_id[:12],
"%s Stream error from SDK: %s",
log_prefix,
stream_err,
exc_info=True,
)
@@ -1173,9 +1052,9 @@ async def stream_chat_completion_sdk(
break
logger.info(
"[SDK] [%s] Received: %s %s "
"%s Received: %s %s "
"(unresolved=%d, current=%d, resolved=%d)",
session_id[:12],
log_prefix,
type(sdk_msg).__name__,
getattr(sdk_msg, "subtype", ""),
len(adapter.current_tool_calls)
@@ -1184,6 +1063,19 @@ async def stream_chat_completion_sdk(
len(adapter.resolved_tool_calls),
)
# Log AssistantMessage API errors (e.g. invalid_request)
# so we can debug Anthropic API 400s surfaced by the CLI.
sdk_error = getattr(sdk_msg, "error", None)
if isinstance(sdk_msg, AssistantMessage) and sdk_error:
logger.error(
"[SDK] [%s] AssistantMessage has error=%s, "
"content_blocks=%d, content_preview=%s",
session_id[:12],
sdk_error,
len(sdk_msg.content),
str(sdk_msg.content)[:500],
)
# Race-condition fix: SDK hooks (PostToolUse) are
# executed asynchronously via start_soon() — the next
# message can arrive before the hook stashes output.
@@ -1210,10 +1102,10 @@ async def stream_chat_completion_sdk(
await asyncio.sleep(0)
else:
logger.warning(
"[SDK] [%s] Timed out waiting for "
"%s Timed out waiting for "
"PostToolUse hook stash "
"(%d unresolved tool calls)",
session_id[:12],
log_prefix,
len(adapter.current_tool_calls)
- len(adapter.resolved_tool_calls),
)
@@ -1221,9 +1113,9 @@ async def stream_chat_completion_sdk(
# Log ResultMessage details for debugging
if isinstance(sdk_msg, ResultMessage):
logger.info(
"[SDK] [%s] Received: ResultMessage %s "
"%s Received: ResultMessage %s "
"(unresolved=%d, current=%d, resolved=%d)",
session_id[:12],
log_prefix,
sdk_msg.subtype,
len(adapter.current_tool_calls)
- len(adapter.resolved_tool_calls),
@@ -1232,8 +1124,8 @@ async def stream_chat_completion_sdk(
)
if sdk_msg.subtype in ("error", "error_during_execution"):
logger.error(
"[SDK] [%s] SDK execution failed with error: %s",
session_id[:12],
"%s SDK execution failed with error: %s",
log_prefix,
sdk_msg.result or "(no error message provided)",
)
@@ -1258,8 +1150,8 @@ async def stream_chat_completion_sdk(
out_len = len(str(response.output))
extra = f", output_len={out_len}"
logger.info(
"[SDK] [%s] Tool event: %s, tool=%s%s",
session_id[:12],
"%s Tool event: %s, tool=%s%s",
log_prefix,
type(response).__name__,
getattr(response, "toolName", "N/A"),
extra,
@@ -1268,8 +1160,8 @@ async def stream_chat_completion_sdk(
# Log errors being sent to frontend
if isinstance(response, StreamError):
logger.error(
"[SDK] [%s] Sending error to frontend: %s (code=%s)",
session_id[:12],
"%s Sending error to frontend: %s (code=%s)",
log_prefix,
response.errorText,
response.code,
)
@@ -1314,29 +1206,44 @@ async def stream_chat_completion_sdk(
has_appended_assistant = True
elif isinstance(response, StreamToolOutputAvailable):
content = (
response.output
if isinstance(response.output, str)
else json.dumps(response.output, ensure_ascii=False)
)
session.messages.append(
ChatMessage(
role="tool",
content=(
response.output
if isinstance(response.output, str)
else str(response.output)
),
content=content,
tool_call_id=response.toolCallId,
)
)
transcript_builder.append_tool_result(
tool_use_id=response.toolCallId,
content=content,
)
has_tool_results = True
elif isinstance(response, StreamFinish):
stream_completed = True
# Append assistant entry AFTER convert_message so that
# any stashed tool results from the previous turn are
# recorded first, preserving the required API order:
# assistant(tool_use) → tool_result → assistant(text).
if isinstance(sdk_msg, AssistantMessage):
transcript_builder.append_assistant(
content_blocks=_format_sdk_content_blocks(sdk_msg.content),
model=sdk_msg.model,
)
except asyncio.CancelledError:
# Task/generator was cancelled (e.g. client disconnect,
# server shutdown). Log and let the safety-net / finally
# blocks handle cleanup.
logger.warning(
"[SDK] [%s] Streaming loop cancelled (asyncio.CancelledError)",
session_id[:12],
"%s Streaming loop cancelled (asyncio.CancelledError)",
log_prefix,
)
raise
finally:
@@ -1350,7 +1257,8 @@ async def stream_chat_completion_sdk(
except (asyncio.CancelledError, StopAsyncIteration):
# Expected: task was cancelled or exhausted during cleanup
logger.info(
"[SDK] Pending __anext__ task completed during cleanup"
"%s Pending __anext__ task completed during cleanup",
log_prefix,
)
# Safety net: if tools are still unresolved after the
@@ -1359,9 +1267,9 @@ async def stream_chat_completion_sdk(
# them now so the frontend stops showing spinners.
if adapter.has_unresolved_tool_calls:
logger.warning(
"[SDK] [%s] %d unresolved tool(s) after stream loop — "
"%s %d unresolved tool(s) after stream loop — "
"flushing as safety net",
session_id[:12],
log_prefix,
len(adapter.current_tool_calls) - len(adapter.resolved_tool_calls),
)
safety_responses: list[StreamBaseResponse] = []
@@ -1372,11 +1280,20 @@ async def stream_chat_completion_sdk(
(StreamToolInputAvailable, StreamToolOutputAvailable),
):
logger.info(
"[SDK] [%s] Safety flush: %s, tool=%s",
session_id[:12],
"%s Safety flush: %s, tool=%s",
log_prefix,
type(response).__name__,
getattr(response, "toolName", "N/A"),
)
if isinstance(response, StreamToolOutputAvailable):
transcript_builder.append_tool_result(
tool_use_id=response.toolCallId,
content=(
response.output
if isinstance(response.output, str)
else json.dumps(response.output, ensure_ascii=False)
),
)
yield response
# If the stream ended without a ResultMessage, the SDK
@@ -1386,8 +1303,8 @@ async def stream_chat_completion_sdk(
# StreamFinish is published by mark_session_completed in the processor.
if not stream_completed and not ended_with_stream_error:
logger.info(
"[SDK] [%s] Stream ended without ResultMessage (stopped by user)",
session_id[:12],
"%s Stream ended without ResultMessage (stopped by user)",
log_prefix,
)
closing_responses: list[StreamBaseResponse] = []
adapter._end_text_if_open(closing_responses)
@@ -1408,69 +1325,36 @@ async def stream_chat_completion_sdk(
) and not has_appended_assistant:
session.messages.append(assistant_response)
# --- Upload transcript for next-turn --resume ---
# After async with the SDK task group has exited, so the Stop
# hook has already fired and the CLI has been SIGTERMed. The
# CLI uses appendFileSync, so all writes are safely on disk.
if config.claude_agent_use_resume and user_id:
# With --resume the CLI appends to the resume file (most
# complete). Otherwise use the Stop hook path.
if use_resume and resume_file:
raw_transcript = read_transcript_file(resume_file)
logger.debug("[SDK] Transcript source: resume file")
elif captured_transcript.path:
raw_transcript = read_transcript_file(captured_transcript.path)
logger.debug(
"[SDK] Transcript source: stop hook (%s), read result: %s",
captured_transcript.path,
f"{len(raw_transcript)}B" if raw_transcript else "None",
)
else:
raw_transcript = None
# Transcript upload is handled exclusively in the finally block
# to avoid double-uploads (the success path used to upload the
# old resume file, then the finally block overwrote it with the
# stop hook content — which could be smaller after compaction).
if not raw_transcript:
logger.debug(
"[SDK] No usable transcript — CLI file had no "
"conversation entries (expected for first turn "
"without --resume)"
)
if raw_transcript:
# Shield the upload from generator cancellation so a
# client disconnect / page refresh doesn't lose the
# transcript. The upload must finish even if the SSE
# connection is torn down.
await asyncio.shield(
_try_upload_transcript(
user_id,
session_id,
raw_transcript,
message_count=len(session.messages),
)
)
logger.info(
"[SDK] [%s] Stream completed successfully with %d messages",
session_id[:12],
len(session.messages),
)
if ended_with_stream_error:
logger.warning(
"%s Stream ended with SDK error after %d messages",
log_prefix,
len(session.messages),
)
else:
logger.info(
"%s Stream completed successfully with %d messages",
log_prefix,
len(session.messages),
)
except BaseException as e:
# Catch BaseException to handle both Exception and CancelledError
# (CancelledError inherits from BaseException in Python 3.8+)
if isinstance(e, asyncio.CancelledError):
logger.warning("[SDK] [%s] Session cancelled", session_id[:12])
logger.warning("%s Session cancelled", log_prefix)
error_msg = "Operation cancelled"
else:
error_msg = str(e) or type(e).__name__
# SDK cleanup RuntimeError is expected during cancellation, log as warning
if isinstance(e, RuntimeError) and "cancel scope" in str(e):
logger.warning(
"[SDK] [%s] SDK cleanup error: %s", session_id[:12], error_msg
)
logger.warning("%s SDK cleanup error: %s", log_prefix, error_msg)
else:
logger.error(
f"[SDK] [%s] Error: {error_msg}", session_id[:12], exc_info=True
)
logger.error("%s Error: %s", log_prefix, error_msg, exc_info=True)
# Append error marker to session (non-invasive text parsing approach)
# The finally block will persist the session with this error marker
@@ -1481,8 +1365,8 @@ async def stream_chat_completion_sdk(
)
)
logger.debug(
"[SDK] [%s] Appended error marker, will be persisted in finally",
session_id[:12],
"%s Appended error marker, will be persisted in finally",
log_prefix,
)
# Yield StreamError for immediate feedback (only for non-cancellation errors)
@@ -1514,47 +1398,72 @@ async def stream_chat_completion_sdk(
try:
await asyncio.shield(upsert_chat_session(session))
logger.info(
"[SDK] [%s] Session persisted in finally with %d messages",
session_id[:12],
"%s Session persisted in finally with %d messages",
log_prefix,
len(session.messages),
)
except Exception as persist_err:
logger.error(
"[SDK] [%s] Failed to persist session in finally: %s",
session_id[:12],
"%s Failed to persist session in finally: %s",
log_prefix,
persist_err,
exc_info=True,
)
# --- Pause E2B sandbox to stop billing between turns ---
# Fire-and-forget: pausing is best-effort and must not block the
# response or the transcript upload. The task is anchored to
# _background_tasks to prevent garbage collection.
# Use pause_sandbox_direct to skip the Redis lookup and reconnect
# round-trip — e2b_sandbox is the live object from this turn.
if e2b_sandbox is not None:
task = asyncio.create_task(pause_sandbox_direct(e2b_sandbox, session_id))
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
# --- Upload transcript for next-turn --resume ---
# This MUST run in finally so the transcript is uploaded even when
# the streaming loop raises an exception. The CLI uses
# appendFileSync, so whatever was written before the error/SIGTERM
# is safely on disk and still useful for the next turn.
if config.claude_agent_use_resume and user_id:
# the streaming loop raises an exception.
# The transcript represents the COMPLETE active context (atomic).
if config.claude_agent_use_resume and user_id and session is not None:
try:
# Prefer content captured in the Stop hook (read before
# cleanup removes the file). Fall back to the resume
# file when the stop hook didn't fire (e.g. error before
# completion) so we don't lose the prior transcript.
raw_transcript = captured_transcript.raw_content or None
if not raw_transcript and use_resume and resume_file:
raw_transcript = read_transcript_file(resume_file)
# Build complete transcript from captured SDK messages
transcript_content = transcript_builder.to_jsonl()
if raw_transcript and session is not None:
await asyncio.shield(
_try_upload_transcript(
user_id,
session_id,
raw_transcript,
message_count=len(session.messages),
)
if not transcript_content:
logger.warning(
"%s No transcript to upload (builder empty)", log_prefix
)
elif not validate_transcript(transcript_content):
logger.warning(
"%s Transcript invalid, skipping upload (entries=%d)",
log_prefix,
transcript_builder.entry_count,
)
else:
logger.warning(f"[SDK] No transcript to upload for {session_id}")
logger.info(
"%s Uploading complete transcript (entries=%d, bytes=%d)",
log_prefix,
transcript_builder.entry_count,
len(transcript_content),
)
# Shield upload from cancellation - let it complete even if
# the finally block is interrupted. No timeout to avoid race
# conditions where backgrounded uploads overwrite newer transcripts.
await asyncio.shield(
upload_transcript(
user_id=user_id,
session_id=session_id,
content=transcript_content,
message_count=len(session.messages),
log_prefix=log_prefix,
)
)
except Exception as upload_err:
logger.error(
f"[SDK] Transcript upload failed in finally: {upload_err}",
"%s Transcript upload failed in finally: %s",
log_prefix,
upload_err,
exc_info=True,
)
@@ -1565,33 +1474,6 @@ async def stream_chat_completion_sdk(
await lock.release()
async def _try_upload_transcript(
user_id: str,
session_id: str,
raw_content: str,
message_count: int = 0,
) -> bool:
"""Strip progress entries and upload transcript (with timeout).
Returns True if the upload completed without error.
"""
try:
async with asyncio.timeout(30):
await upload_transcript(
user_id, session_id, raw_content, message_count=message_count
)
return True
except asyncio.TimeoutError:
logger.warning(f"[SDK] Transcript upload timed out for {session_id}")
return False
except Exception as e:
logger.error(
f"[SDK] Failed to upload transcript for {session_id}: {e}",
exc_info=True,
)
return False
async def _update_title_async(
session_id: str, message: str, user_id: str | None = None
) -> None:
@@ -1600,8 +1482,8 @@ async def _update_title_async(
title = await _generate_session_title(
message, user_id=user_id, session_id=session_id
)
if title:
await update_session_title(session_id, title)
if title and user_id:
await update_session_title(session_id, user_id, title, only_if_empty=True)
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
except Exception as e:
logger.warning(f"[SDK] Failed to update session title: {e}")

View File

@@ -1,9 +1,10 @@
"""Tests for SDK service helpers."""
import asyncio
import base64
import os
from dataclasses import dataclass
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -145,3 +146,145 @@ class TestPrepareFileAttachments:
assert "Read tool" not in result.hint
assert len(result.image_blocks) == 1
class TestPromptSupplement:
"""Tests for centralized prompt supplement generation."""
def test_sdk_supplement_excludes_tool_docs(self):
"""SDK mode should NOT include tool documentation (Claude gets schemas automatically)."""
from backend.copilot.prompting import get_sdk_supplement
# Test both local and E2B modes
local_supplement = get_sdk_supplement(use_e2b=False, cwd="/tmp/test")
e2b_supplement = get_sdk_supplement(use_e2b=True, cwd="")
# Should NOT have tool list section
assert "## AVAILABLE TOOLS" not in local_supplement
assert "## AVAILABLE TOOLS" not in e2b_supplement
# Should still have technical notes
assert "## Tool notes" in local_supplement
assert "## Tool notes" in e2b_supplement
def test_baseline_supplement_includes_tool_docs(self):
"""Baseline mode MUST include tool documentation (direct API needs it)."""
from backend.copilot.prompting import get_baseline_supplement
supplement = get_baseline_supplement()
# MUST have tool list section
assert "## AVAILABLE TOOLS" in supplement
# Should NOT have environment-specific notes (SDK-only)
assert "## Tool notes" not in supplement
def test_baseline_supplement_includes_key_tools(self):
"""Baseline supplement should document all essential tools."""
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.tools import TOOL_REGISTRY
docs = get_baseline_supplement()
# Core agent workflow tools (always available)
assert "`create_agent`" in docs
assert "`run_agent`" in docs
assert "`find_library_agent`" in docs
assert "`edit_agent`" in docs
# MCP integration (always available)
assert "`run_mcp_tool`" in docs
# Folder management (always available)
assert "`create_folder`" in docs
# Browser tools only if available (Playwright may not be installed in CI)
if (
TOOL_REGISTRY.get("browser_navigate")
and TOOL_REGISTRY["browser_navigate"].is_available
):
assert "`browser_navigate`" in docs
def test_baseline_supplement_includes_workflows(self):
"""Baseline supplement should include workflow guidance in tool descriptions."""
from backend.copilot.prompting import get_baseline_supplement
docs = get_baseline_supplement()
# Workflows are now in individual tool descriptions (not separate sections)
# Check that key workflow concepts appear in tool descriptions
assert "agent_json" in docs or "find_block" in docs
assert "run_mcp_tool" in docs
def test_baseline_supplement_completeness(self):
"""All available tools from TOOL_REGISTRY should appear in baseline supplement."""
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.tools import TOOL_REGISTRY
docs = get_baseline_supplement()
# Verify each available registered tool is documented
# (matches _generate_tool_documentation which filters by is_available)
for tool_name, tool in TOOL_REGISTRY.items():
if not tool.is_available:
continue
assert (
f"`{tool_name}`" in docs
), f"Tool '{tool_name}' missing from baseline supplement"
def test_pause_task_scheduled_before_transcript_upload(self):
"""Pause is scheduled as a background task before transcript upload begins.
The finally block in stream_response_sdk does:
(1) asyncio.create_task(pause_sandbox_direct(...)) — fire-and-forget
(2) await asyncio.shield(upload_transcript(...)) — awaited
Scheduling pause via create_task before awaiting upload ensures:
- Pause never blocks transcript upload (billing stops concurrently)
- On E2B timeout, pause silently fails; upload proceeds unaffected
"""
call_order: list[str] = []
async def _mock_pause(sandbox, session_id):
call_order.append("pause")
async def _mock_upload(**kwargs):
call_order.append("upload")
async def _simulate_teardown():
"""Mirror the service.py finally block teardown sequence."""
sandbox = MagicMock()
# (1) Schedule pause — mirrors lines ~1427-1429 in service.py
task = asyncio.create_task(_mock_pause(sandbox, "test-sess"))
# (2) Await transcript upload — mirrors lines ~1460-1468 in service.py
# Yielding to the event loop here lets the pause task start concurrently.
await _mock_upload(
user_id="u", session_id="test-sess", content="x", message_count=1
)
await task
asyncio.run(_simulate_teardown())
# Both must run; pause is scheduled before upload starts
assert "pause" in call_order
assert "upload" in call_order
# create_task schedules pause, then upload is awaited — pause runs
# concurrently during upload's first yield. The ordering guarantee is
# that create_task is CALLED before upload is AWAITED (see source order).
def test_baseline_supplement_no_duplicate_tools(self):
"""No tool should appear multiple times in baseline supplement."""
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.tools import TOOL_REGISTRY
docs = get_baseline_supplement()
# Count occurrences of each available tool in the entire supplement
for tool_name, tool in TOOL_REGISTRY.items():
if not tool.is_available:
continue
# Count how many times this tool appears as a bullet point
count = docs.count(f"- **`{tool_name}`**")
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"

View File

@@ -9,14 +9,29 @@ import itertools
import json
import logging
import os
import re
import uuid
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any
from claude_agent_sdk import create_sdk_mcp_server, tool
from backend.copilot.context import (
_current_project_dir,
_current_sandbox,
_current_sdk_cwd,
_current_session,
_current_user_id,
_encode_cwd_for_cli,
get_execution_context,
get_sdk_cwd,
is_allowed_local_path,
)
from backend.copilot.model import ChatSession
from backend.copilot.sdk.file_ref import (
FileRefExpansionError,
expand_file_refs_in_args,
read_file_bytes,
)
from backend.copilot.tools import TOOL_REGISTRY
from backend.copilot.tools.base import BaseTool
from backend.util.truncate import truncate
@@ -28,84 +43,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
# in the path — prevents reading settings, credentials, or other sensitive files.
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# Max MCP response size in chars — keeps tool output under the SDK's 10 MB JSON buffer.
_MCP_MAX_CHARS = 500_000
# Context variable holding the encoded project directory name for the current
# session (e.g. "-private-tmp-copilot-<uuid>"). Set by set_execution_context()
# so that path validation can scope tool-results reads to the current session.
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does.
The CLI replaces all non-alphanumeric characters with ``-``.
"""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
"""Check whether *path* is an allowed host-filesystem path.
Allowed:
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
- Files under ``~/.claude/projects/<encoded-cwd>/`` — the SDK's
project directory for this session (tool-results, transcripts, etc.)
Both checks are scoped to the **current session** so sessions cannot
read each other's data.
"""
if not path:
return False
if path.startswith("~"):
resolved = os.path.realpath(os.path.expanduser(path))
elif not os.path.isabs(path) and sdk_cwd:
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
else:
resolved = os.path.realpath(path)
# Allow access within the SDK working directory
if sdk_cwd:
norm_cwd = os.path.realpath(sdk_cwd)
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
return True
# Allow access within the current session's CLI project directory
# (~/.claude/projects/<encoded-cwd>/).
encoded = _current_project_dir.get("")
if encoded:
session_project = os.path.join(_SDK_PROJECTS_DIR, encoded)
if resolved == session_project or resolved.startswith(session_project + os.sep):
return True
return False
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
MCP_SERVER_NAME = "copilot"
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
# Context variables to pass user/session info to tool execution
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
# E2B cloud sandbox for the current turn (None when E2B is not configured).
# Passed to bash_exec so commands run on E2B instead of the local bwrap sandbox.
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
"_current_sandbox", default=None
)
# Raw SDK working directory path (e.g. /tmp/copilot-<session_id>).
# Used by workspace tools to save binary files for the CLI's built-in Read.
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
@@ -149,24 +93,6 @@ def set_execution_context(
_stash_event.set(asyncio.Event())
def get_current_sandbox() -> "AsyncSandbox | None":
"""Return the E2B sandbox for the current turn, or None."""
return _current_sandbox.get()
def get_sdk_cwd() -> str:
"""Return the SDK ephemeral working directory for the current turn."""
return _current_sdk_cwd.get()
def get_execution_context() -> tuple[str | None, ChatSession | None]:
"""Get the current execution context."""
return (
_current_user_id.get(),
_current_session.get(),
)
def pop_pending_tool_output(tool_name: str) -> str | None:
"""Pop and return the oldest stashed output for *tool_name*.
@@ -259,7 +185,11 @@ async def _execute_tool_sync(
session: ChatSession,
args: dict[str, Any],
) -> dict[str, Any]:
"""Execute a tool synchronously and return MCP-formatted response."""
"""Execute a tool synchronously and return MCP-formatted response.
Note: ``@@agptfile:`` expansion is handled upstream in the ``_truncating`` wrapper
so all registered handlers (BaseTool, E2B, Read) expand uniformly.
"""
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
result = await base_tool.execute(
user_id=user_id,
@@ -320,42 +250,50 @@ def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Read a local file with optional offset/limit.
"""Read a file with optional offset/limit.
Only allows paths that pass :func:`is_allowed_local_path` — the current
session's tool-results directory and ephemeral working directory.
Supports ``workspace://`` URIs (delegated to the workspace manager) and
local paths within the session's allowed directories (sdk_cwd + tool-results).
"""
file_path = args.get("file_path", "")
offset = args.get("offset", 0)
limit = args.get("limit", 2000)
offset = max(0, int(args.get("offset", 0)))
limit = max(1, int(args.get("limit", 2000)))
if not is_allowed_local_path(file_path):
return {
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
"isError": True,
}
def _mcp_err(text: str) -> dict[str, Any]:
return {"content": [{"type": "text", "text": text}], "isError": True}
def _mcp_ok(text: str) -> dict[str, Any]:
return {"content": [{"type": "text", "text": text}], "isError": False}
if file_path.startswith("workspace://"):
user_id, session = get_execution_context()
if session is None:
return _mcp_err("workspace:// file references require an active session")
try:
raw = await read_file_bytes(file_path, user_id, session)
except ValueError as exc:
return _mcp_err(str(exc))
lines = raw.decode("utf-8", errors="replace").splitlines(keepends=True)
selected = list(itertools.islice(lines, offset, offset + limit))
numbered = "".join(
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
)
return _mcp_ok(numbered)
if not is_allowed_local_path(file_path, get_sdk_cwd()):
return _mcp_err(f"Path not allowed: {file_path}")
resolved = os.path.realpath(os.path.expanduser(file_path))
try:
with open(resolved) as f:
selected = list(itertools.islice(f, offset, offset + limit))
content = "".join(selected)
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
# don't delete here — the SDK may read in multiple chunks.
return {
"content": [{"type": "text", "text": content}],
"isError": False,
}
return _mcp_ok("".join(selected))
except FileNotFoundError:
return {
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
"isError": True,
}
return _mcp_err(f"File not found: {file_path}")
except Exception as e:
return {
"content": [{"type": "text", "text": f"Error reading file: {e}"}],
"isError": True,
}
return _mcp_err(f"Error reading file: {e}")
_READ_TOOL_NAME = "Read"
@@ -414,9 +352,23 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
response adapter before the SDK can apply its own head-truncation.
Also expands ``@@agptfile:`` references in args so every registered tool
(BaseTool, E2B file tools, Read) receives resolved content uniformly.
Applied once to every registered tool."""
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
user_id, session = get_execution_context()
if session is not None:
try:
args = await expand_file_refs_in_args(args, user_id, session)
except FileRefExpansionError as exc:
return _mcp_error(
f"@@agptfile: reference could not be resolved: {exc}. "
"Ensure the file exists before referencing it. "
"For sandbox paths use bash_exec to verify the file exists first; "
"for workspace files use a workspace:// URI."
)
result = await fn(args)
truncated = truncate(result, _MCP_MAX_CHARS)

View File

@@ -2,12 +2,12 @@
import pytest
from backend.copilot.context import get_sdk_cwd
from backend.util.truncate import truncate
from .tool_adapter import (
_MCP_MAX_CHARS,
_text_from_mcp_result,
get_sdk_cwd,
pop_pending_tool_output,
set_execution_context,
stash_pending_tool_output,

View File

@@ -10,13 +10,14 @@ Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
filesystem for self-hosted) — no DB column needed.
"""
import json
import logging
import os
import re
import time
from dataclasses import dataclass
from backend.util import json
logger = logging.getLogger(__name__)
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
@@ -58,41 +59,37 @@ def strip_progress_entries(content: str) -> str:
Removes entries whose ``type`` is in ``STRIPPABLE_TYPES`` and reparents
any remaining child entries so the ``parentUuid`` chain stays intact.
Typically reduces transcript size by ~30%.
Entries that are not stripped or reparented are kept as their original
raw JSON line to avoid unnecessary re-serialization that changes
whitespace or key ordering.
"""
lines = content.strip().split("\n")
entries: list[dict] = []
# Parse entries, keeping the original line alongside the parsed dict.
parsed: list[tuple[str, dict | None]] = []
for line in lines:
try:
entries.append(json.loads(line))
except json.JSONDecodeError:
# Keep unparseable lines as-is (safety)
entries.append({"_raw": line})
parsed.append((line, json.loads(line, fallback=None)))
# First pass: identify stripped UUIDs and build parent map.
stripped_uuids: set[str] = set()
uuid_to_parent: dict[str, str] = {}
kept: list[dict] = []
for entry in entries:
if "_raw" in entry:
kept.append(entry)
for _line, entry in parsed:
if not isinstance(entry, dict):
continue
uid = entry.get("uuid", "")
parent = entry.get("parentUuid", "")
entry_type = entry.get("type", "")
if uid:
uuid_to_parent[uid] = parent
if entry.get("type", "") in STRIPPABLE_TYPES and uid:
stripped_uuids.add(uid)
if entry_type in STRIPPABLE_TYPES:
if uid:
stripped_uuids.add(uid)
else:
kept.append(entry)
# Reparent: walk up chain through stripped entries to find surviving ancestor
for entry in kept:
if "_raw" in entry:
# Second pass: keep non-stripped entries, reparenting where needed.
# Preserve original line when no reparenting is required.
reparented: set[str] = set()
for _line, entry in parsed:
if not isinstance(entry, dict):
continue
parent = entry.get("parentUuid", "")
original_parent = parent
@@ -100,63 +97,32 @@ def strip_progress_entries(content: str) -> str:
parent = uuid_to_parent.get(parent, "")
if parent != original_parent:
entry["parentUuid"] = parent
uid = entry.get("uuid", "")
if uid:
reparented.add(uid)
result_lines: list[str] = []
for entry in kept:
if "_raw" in entry:
result_lines.append(entry["_raw"])
else:
for line, entry in parsed:
if not isinstance(entry, dict):
result_lines.append(line)
continue
if entry.get("type", "") in STRIPPABLE_TYPES:
continue
uid = entry.get("uuid", "")
if uid in reparented:
# Re-serialize only entries whose parentUuid was changed.
result_lines.append(json.dumps(entry, separators=(",", ":")))
else:
result_lines.append(line)
return "\n".join(result_lines) + "\n"
# ---------------------------------------------------------------------------
# Local file I/O (read from CLI's JSONL, write temp file for --resume)
# Local file I/O (write temp file for --resume)
# ---------------------------------------------------------------------------
def read_transcript_file(transcript_path: str) -> str | None:
"""Read a JSONL transcript file from disk.
Returns the raw JSONL content, or ``None`` if the file is missing, empty,
or only contains metadata (≤2 lines with no conversation messages).
"""
if not transcript_path or not os.path.isfile(transcript_path):
logger.debug(f"[Transcript] File not found: {transcript_path}")
return None
try:
with open(transcript_path) as f:
content = f.read()
if not content.strip():
logger.debug("[Transcript] File is empty: %s", transcript_path)
return None
lines = content.strip().split("\n")
# Validate that the transcript has real conversation content
# (not just metadata like queue-operation entries).
if not validate_transcript(content):
logger.debug(
"[Transcript] No conversation content (%d lines) in %s",
len(lines),
transcript_path,
)
return None
logger.info(
f"[Transcript] Read {len(lines)} lines, "
f"{len(content)} bytes from {transcript_path}"
)
return content
except (json.JSONDecodeError, OSError) as e:
logger.warning(f"[Transcript] Failed to read {transcript_path}: {e}")
return None
def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
"""Sanitize an ID for safe use in file paths.
@@ -171,14 +137,6 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does.
The CLI replaces all non-alphanumeric characters with ``-``.
"""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""Remove the CLI's project directory for a specific working directory.
@@ -188,7 +146,8 @@ def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""
import shutil
cwd_encoded = _encode_cwd_for_cli(sdk_cwd)
# Encode cwd the same way CLI does (replaces non-alphanumeric with -)
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
@@ -248,32 +207,29 @@ def write_transcript_to_tempfile(
def validate_transcript(content: str | None) -> bool:
"""Check that a transcript has actual conversation messages.
A valid transcript for resume needs at least one user message and one
assistant message (not just queue-operation / file-history-snapshot
metadata).
A valid transcript needs at least one assistant message (not just
queue-operation / file-history-snapshot metadata). We do NOT require
a ``type: "user"`` entry because with ``--resume`` the user's message
is passed as a CLI query parameter and does not appear in the
transcript file.
"""
if not content or not content.strip():
return False
lines = content.strip().split("\n")
if len(lines) < 2:
return False
has_user = False
has_assistant = False
for line in lines:
try:
entry = json.loads(line)
msg_type = entry.get("type")
if msg_type == "user":
has_user = True
elif msg_type == "assistant":
has_assistant = True
except json.JSONDecodeError:
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
return False
if entry.get("type") == "assistant":
has_assistant = True
return has_user and has_assistant
return has_assistant
# ---------------------------------------------------------------------------
@@ -328,58 +284,56 @@ async def upload_transcript(
session_id: str,
content: str,
message_count: int = 0,
log_prefix: str = "[Transcript]",
) -> None:
"""Strip progress entries and upload transcript to bucket storage.
"""Strip progress entries and upload complete transcript.
Safety: only overwrites when the new (stripped) transcript is larger than
what is already stored. Since JSONL is append-only, the latest transcript
is always the longest. This prevents a slow/stale background task from
clobbering a newer upload from a concurrent turn.
The transcript represents the FULL active context (atomic).
Each upload REPLACES the previous transcript entirely.
The executor holds a cluster lock per session, so concurrent uploads for
the same session cannot happen.
Args:
message_count: ``len(session.messages)`` at upload time — used by
the next turn to detect staleness and compress only the gap.
content: Complete JSONL transcript (from TranscriptBuilder).
message_count: ``len(session.messages)`` at upload time.
"""
from backend.util.workspace_storage import get_workspace_storage
# Strip metadata entries (progress, file-history-snapshot, etc.)
# Note: SDK-built transcripts shouldn't have these, but strip for safety
stripped = strip_progress_entries(content)
if not validate_transcript(stripped):
# Log entry types for debugging — helps identify why validation failed
entry_types: list[str] = []
for line in stripped.strip().split("\n"):
entry = json.loads(line, fallback={"type": "INVALID_JSON"})
entry_types.append(entry.get("type", "?"))
logger.warning(
f"[Transcript] Skipping upload — stripped content not valid "
f"for session {session_id}"
"%s Skipping upload — stripped content not valid "
"(types=%s, stripped_len=%d, raw_len=%d)",
log_prefix,
entry_types,
len(stripped),
len(content),
)
logger.debug("%s Raw content preview: %s", log_prefix, content[:500])
logger.debug("%s Stripped content: %s", log_prefix, stripped[:500])
return
storage = await get_workspace_storage()
wid, fid, fname = _storage_path_parts(user_id, session_id)
encoded = stripped.encode("utf-8")
new_size = len(encoded)
# Check existing transcript size to avoid overwriting newer with older
path = _build_storage_path(user_id, session_id, storage)
content_skipped = False
try:
existing = await storage.retrieve(path)
if len(existing) >= new_size:
logger.info(
f"[Transcript] Skipping content upload — existing ({len(existing)}B) "
f">= new ({new_size}B) for session {session_id}"
)
content_skipped = True
except (FileNotFoundError, Exception):
pass # No existing transcript or retrieval error — proceed with upload
await storage.store(
workspace_id=wid,
file_id=fid,
filename=fname,
content=encoded,
)
if not content_skipped:
await storage.store(
workspace_id=wid,
file_id=fid,
filename=fname,
content=encoded,
)
# Always update metadata (even when content is skipped) so message_count
# stays current. The gap-fill logic in _build_query_message relies on
# message_count to avoid re-compressing the same messages every turn.
# Update metadata so message_count stays current. The gap-fill logic
# in _build_query_message relies on it to avoid re-compressing messages.
try:
meta = {"message_count": message_count, "uploaded_at": time.time()}
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
@@ -390,18 +344,18 @@ async def upload_transcript(
content=json.dumps(meta).encode("utf-8"),
)
except Exception as e:
logger.warning(f"[Transcript] Failed to write metadata for {session_id}: {e}")
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
logger.info(
f"[Transcript] Uploaded {new_size}B "
f"(stripped from {len(content)}B, msg_count={message_count}, "
f"content_skipped={content_skipped}) "
f"for session {session_id}"
f"{log_prefix} Uploaded {len(encoded)}B "
f"(stripped from {len(content)}B, msg_count={message_count})"
)
async def download_transcript(
user_id: str, session_id: str
user_id: str,
session_id: str,
log_prefix: str = "[Transcript]",
) -> TranscriptDownload | None:
"""Download transcript and metadata from bucket storage.
@@ -417,10 +371,10 @@ async def download_transcript(
data = await storage.retrieve(path)
content = data.decode("utf-8")
except FileNotFoundError:
logger.debug(f"[Transcript] No transcript in storage for {session_id}")
logger.debug(f"{log_prefix} No transcript in storage")
return None
except Exception as e:
logger.warning(f"[Transcript] Failed to download transcript: {e}")
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
return None
# Try to load metadata (best-effort — old transcripts won't have it)
@@ -437,16 +391,13 @@ async def download_transcript(
meta_path = f"local://{mwid}/{mfid}/{mfname}"
meta_data = await storage.retrieve(meta_path)
meta = json.loads(meta_data.decode("utf-8"))
meta = json.loads(meta_data.decode("utf-8"), fallback={})
message_count = meta.get("message_count", 0)
uploaded_at = meta.get("uploaded_at", 0.0)
except (FileNotFoundError, json.JSONDecodeError, Exception):
except (FileNotFoundError, Exception):
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
logger.info(
f"[Transcript] Downloaded {len(content)}B "
f"(msg_count={message_count}) for session {session_id}"
)
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
return TranscriptDownload(
content=content,
message_count=message_count,

View File

@@ -0,0 +1,188 @@
"""Build complete JSONL transcript from SDK messages.
The transcript represents the FULL active context at any point in time.
Each upload REPLACES the previous transcript atomically.
Flow:
Turn 1: Upload [msg1, msg2]
Turn 2: Download [msg1, msg2] → Upload [msg1, msg2, msg3, msg4] (REPLACE)
Turn 3: Download [msg1, msg2, msg3, msg4] → Upload [all messages] (REPLACE)
The transcript is never incremental - always the complete atomic state.
"""
import logging
from typing import Any
from uuid import uuid4
from pydantic import BaseModel
from backend.util import json
from .transcript import STRIPPABLE_TYPES
logger = logging.getLogger(__name__)
class TranscriptEntry(BaseModel):
"""Single transcript entry (user or assistant turn)."""
type: str
uuid: str
parentUuid: str | None
message: dict[str, Any]
class TranscriptBuilder:
"""Build complete JSONL transcript from SDK messages.
This builder maintains the FULL conversation state, not incremental changes.
The output is always the complete active context.
"""
def __init__(self) -> None:
self._entries: list[TranscriptEntry] = []
self._last_uuid: str | None = None
def _last_is_assistant(self) -> bool:
return bool(self._entries) and self._entries[-1].type == "assistant"
def _last_message_id(self) -> str:
"""Return the message.id of the last entry, or '' if none."""
if self._entries:
return self._entries[-1].message.get("id", "")
return ""
def load_previous(self, content: str, log_prefix: str = "[Transcript]") -> None:
"""Load complete previous transcript.
This loads the FULL previous context. As new messages come in,
we append to this state. The final output is the complete context
(previous + new), not just the delta.
"""
if not content or not content.strip():
return
lines = content.strip().split("\n")
for line_num, line in enumerate(lines, 1):
if not line.strip():
continue
data = json.loads(line, fallback=None)
if data is None:
logger.warning(
"%s Failed to parse transcript line %d/%d",
log_prefix,
line_num,
len(lines),
)
continue
# Load all non-strippable entries (user/assistant/system/etc.)
# Skip only STRIPPABLE_TYPES to match strip_progress_entries() behavior
entry_type = data.get("type", "")
if entry_type in STRIPPABLE_TYPES:
continue
entry = TranscriptEntry(
type=data["type"],
uuid=data.get("uuid") or str(uuid4()),
parentUuid=data.get("parentUuid"),
message=data.get("message", {}),
)
self._entries.append(entry)
self._last_uuid = entry.uuid
logger.info(
"%s Loaded %d entries from previous transcript (last_uuid=%s)",
log_prefix,
len(self._entries),
self._last_uuid[:12] if self._last_uuid else None,
)
def append_user(self, content: str | list[dict], uuid: str | None = None) -> None:
"""Append a user entry."""
msg_uuid = uuid or str(uuid4())
self._entries.append(
TranscriptEntry(
type="user",
uuid=msg_uuid,
parentUuid=self._last_uuid,
message={"role": "user", "content": content},
)
)
self._last_uuid = msg_uuid
def append_tool_result(self, tool_use_id: str, content: str) -> None:
"""Append a tool result as a user entry (one per tool call)."""
self.append_user(
content=[
{"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
]
)
def append_assistant(
self,
content_blocks: list[dict],
model: str = "",
stop_reason: str | None = None,
) -> None:
"""Append an assistant entry.
Consecutive assistant entries automatically share the same message ID
so the CLI can merge them (thinking → text → tool_use) into a single
API message on ``--resume``. A new ID is assigned whenever an
assistant entry follows a non-assistant entry (user message or tool
result), because that marks the start of a new API response.
"""
message_id = (
self._last_message_id()
if self._last_is_assistant()
else f"msg_sdk_{uuid4().hex[:24]}"
)
msg_uuid = str(uuid4())
self._entries.append(
TranscriptEntry(
type="assistant",
uuid=msg_uuid,
parentUuid=self._last_uuid,
message={
"role": "assistant",
"model": model,
"id": message_id,
"type": "message",
"content": content_blocks,
"stop_reason": stop_reason,
"stop_sequence": None,
},
)
)
self._last_uuid = msg_uuid
def to_jsonl(self) -> str:
"""Export complete context as JSONL.
Consecutive assistant entries are kept separate to match the
native CLI format — the SDK merges them internally on resume.
Returns the FULL conversation state (all entries), not incremental.
This output REPLACES any previous transcript.
"""
if not self._entries:
return ""
lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
return "\n".join(lines) + "\n"
@property
def entry_count(self) -> int:
"""Total number of entries in the complete context."""
return len(self._entries)
@property
def is_empty(self) -> bool:
"""Whether this builder has any entries."""
return len(self._entries) == 0

View File

@@ -1,11 +1,11 @@
"""Unit tests for JSONL transcript management utilities."""
import json
import os
from backend.util import json
from .transcript import (
STRIPPABLE_TYPES,
read_transcript_file,
strip_progress_entries,
validate_transcript,
write_transcript_to_tempfile,
@@ -38,49 +38,6 @@ PROGRESS_ENTRY = {
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
# --- read_transcript_file ---
class TestReadTranscriptFile:
def test_returns_content_for_valid_file(self, tmp_path):
path = tmp_path / "session.jsonl"
path.write_text(VALID_TRANSCRIPT)
result = read_transcript_file(str(path))
assert result is not None
assert "user" in result
def test_returns_none_for_missing_file(self):
assert read_transcript_file("/nonexistent/path.jsonl") is None
def test_returns_none_for_empty_path(self):
assert read_transcript_file("") is None
def test_returns_none_for_empty_file(self, tmp_path):
path = tmp_path / "empty.jsonl"
path.write_text("")
assert read_transcript_file(str(path)) is None
def test_returns_none_for_metadata_only(self, tmp_path):
content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
path = tmp_path / "meta.jsonl"
path.write_text(content)
assert read_transcript_file(str(path)) is None
def test_returns_none_for_invalid_json(self, tmp_path):
path = tmp_path / "bad.jsonl"
path.write_text("not json\n{}\n{}\n")
assert read_transcript_file(str(path)) is None
def test_no_size_limit(self, tmp_path):
"""Large files are accepted — bucket storage has no size limit."""
big_content = {"type": "user", "uuid": "u9", "data": "x" * 1_000_000}
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, big_content, ASST_MSG)
path = tmp_path / "big.jsonl"
path.write_text(content)
result = read_transcript_file(str(path))
assert result is not None
# --- write_transcript_to_tempfile ---
@@ -155,12 +112,56 @@ class TestValidateTranscript:
assert validate_transcript(content) is False
def test_assistant_only_no_user(self):
"""With --resume the user message is a CLI query param, not a transcript entry.
A transcript with only assistant entries is valid."""
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, ASST_MSG)
assert validate_transcript(content) is False
assert validate_transcript(content) is True
def test_resume_transcript_without_user_entry(self):
"""Simulates a real --resume stop hook transcript: the CLI session file
has summary + assistant entries but no user entry."""
summary = {"type": "summary", "uuid": "s1", "text": "context..."}
asst1 = {
"type": "assistant",
"uuid": "a1",
"message": {"role": "assistant", "content": "Hello!"},
}
asst2 = {
"type": "assistant",
"uuid": "a2",
"parentUuid": "a1",
"message": {"role": "assistant", "content": "Sure, let me help."},
}
content = _make_jsonl(summary, asst1, asst2)
assert validate_transcript(content) is True
def test_single_assistant_entry(self):
"""A transcript with just one assistant line is valid — the CLI may
produce short transcripts for simple responses with no tool use."""
content = json.dumps(ASST_MSG) + "\n"
assert validate_transcript(content) is True
def test_invalid_json_returns_false(self):
assert validate_transcript("not json\n{}\n{}\n") is False
def test_malformed_json_after_valid_assistant_returns_false(self):
"""Validation must scan all lines - malformed JSON anywhere should fail."""
valid_asst = json.dumps(ASST_MSG)
malformed = "not valid json"
content = valid_asst + "\n" + malformed + "\n"
assert validate_transcript(content) is False
def test_blank_lines_are_skipped(self):
"""Transcripts with blank lines should be valid if they contain assistant entries."""
content = (
json.dumps(USER_MSG)
+ "\n\n" # blank line
+ json.dumps(ASST_MSG)
+ "\n"
+ "\n" # another blank line
)
assert validate_transcript(content) is True
# --- strip_progress_entries ---
@@ -253,3 +254,31 @@ class TestStripProgressEntries:
assert "queue-operation" not in result_types
assert "user" in result_types
assert "assistant" in result_types
def test_preserves_original_line_formatting(self):
"""Non-reparented entries keep their original JSON formatting."""
# orjson produces compact JSON - test that we preserve the exact input
# when no reparenting is needed (no re-serialization)
original_line = json.dumps(USER_MSG)
content = original_line + "\n" + json.dumps(ASST_MSG) + "\n"
result = strip_progress_entries(content)
result_lines = result.strip().split("\n")
# Original line should be byte-identical (not re-serialized)
assert result_lines[0] == original_line
def test_reparented_entries_are_reserialized(self):
"""Entries whose parentUuid changes must be re-serialized."""
progress = {"type": "progress", "uuid": "p1", "parentUuid": "u1"}
asst = {
"type": "assistant",
"uuid": "a1",
"parentUuid": "p1",
"message": {"role": "assistant", "content": "done"},
}
content = _make_jsonl(USER_MSG, progress, asst)
result = strip_progress_entries(content)
lines = result.strip().split("\n")
asst_entry = json.loads(lines[-1])
assert asst_entry["parentUuid"] == "u1" # reparented

View File

@@ -18,7 +18,7 @@ from langfuse.openai import (
from backend.data.db_accessors import understanding_db
from backend.data.understanding import format_understanding_for_prompt
from backend.util.exceptions import NotFoundError
from backend.util.exceptions import NotAuthorizedError, NotFoundError
from backend.util.settings import AppEnvironment, Settings
from .config import ChatConfig
@@ -34,8 +34,9 @@ client = LangfuseAsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
langfuse = get_client()
# Default system prompt used when Langfuse is not configured
# This is a snapshot of the "CoPilot Prompt" from Langfuse (version 11)
DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations.
# Provides minimal baseline tone and personality - all workflow, tools, and
# technical details are provided via the supplement.
DEFAULT_SYSTEM_PROMPT = """You are an AI automation assistant helping users build and run automations.
Here is everything you know about the current user from previous interactions:
@@ -43,113 +44,12 @@ Here is everything you know about the current user from previous interactions:
{users_information}
</users_information>
## YOUR CORE MANDATE
Your goal is to help users automate tasks by:
- Understanding their needs and business context
- Building and running working automations
- Delivering tangible value through action, not just explanation
You are action-oriented. Your success is measured by:
- **Value Delivery**: Does the user think "wow, that was amazing" or "what was the point"?
- **Demonstrable Proof**: Show working automations, not descriptions of what's possible
- **Time Saved**: Focus on tangible efficiency gains
- **Quality Output**: Deliver results that meet or exceed expectations
## YOUR WORKFLOW
Adapt flexibly to the conversation context. Not every interaction requires all stages:
1. **Explore & Understand**: Learn about the user's business, tasks, and goals. Use `add_understanding` to capture important context that will improve future conversations.
2. **Assess Automation Potential**: Help the user understand whether and how AI can automate their task.
3. **Prepare for AI**: Provide brief, actionable guidance on prerequisites (data, access, etc.).
4. **Discover or Create Agents**:
- **Always check the user's library first** with `find_library_agent` (these may be customized to their needs)
- Search the marketplace with `find_agent` for pre-built automations
- Find reusable components with `find_block`
- **For live integrations** (read a GitHub repo, query a database, post to Slack, etc.) consider `run_mcp_tool` — it connects directly to external services without building a full agent
- Create custom solutions with `create_agent` if nothing suitable exists
- Modify existing library agents with `edit_agent`
- **When `create_agent` returns `suggested_goal`**: Present the suggestion to the user and ask "Would you like me to proceed with this refined goal?" If they accept, call `create_agent` again with the suggested goal.
- **When `create_agent` returns `clarifying_questions`**: After the user answers, call `create_agent` again with the original description AND the answers in the `context` parameter.
5. **Execute**: Run automations immediately, schedule them, or set up webhooks using `run_agent`. Test specific components with `run_block`.
6. **Show Results**: Display outputs using `agent_output`.
## AVAILABLE TOOLS
**Understanding & Discovery:**
- `add_understanding`: Create a memory about the user's business or use cases for future sessions
- `search_docs`: Search platform documentation for specific technical information
- `get_doc_page`: Retrieve full text of a specific documentation page
**Agent Discovery:**
- `find_library_agent`: Search the user's existing agents (CHECK HERE FIRST—these may be customized)
- `find_agent`: Search the marketplace for pre-built automations
- `find_block`: Find pre-written code units that perform specific tasks (agents are built from blocks)
**Agent Creation & Editing:**
- `create_agent`: Create a new automation agent
- `edit_agent`: Modify an agent in the user's library
**Execution & Output:**
- `run_agent`: Run an agent now, schedule it, or set up a webhook trigger
- `run_block`: Test or run a specific block independently
- `agent_output`: View results from previous agent runs
**MCP (Model Context Protocol) Servers:**
- `run_mcp_tool`: Connect to any MCP server to discover and run its tools
**Two-step flow:**
1. `run_mcp_tool(server_url)` → returns a list of available tools. Each tool has `name`, `description`, and `input_schema` (JSON Schema). Read `input_schema.properties` to understand what arguments are needed.
2. `run_mcp_tool(server_url, tool_name, tool_arguments)` → executes the tool. Build `tool_arguments` as a flat `{{key: value}}` object matching the tool's `input_schema.properties`.
**Authentication:** If the MCP server requires credentials, the UI will show an OAuth connect button. Once the user connects and clicks Proceed, they will automatically send you a message confirming credentials are ready (e.g. "I've connected the MCP server credentials. Please retry run_mcp_tool..."). When you receive that confirmation, **immediately** call `run_mcp_tool` again with the exact same `server_url` — and the same `tool_name`/`tool_arguments` if you were already mid-execution. Do not ask the user what to do next; just retry.
**Finding server URLs (fastest → slowest):**
1. **Known hosted servers** — use directly, no lookup:
- Notion: `https://mcp.notion.com/mcp`
- Linear: `https://mcp.linear.app/mcp`
- Stripe: `https://mcp.stripe.com`
- Intercom: `https://mcp.intercom.com/mcp`
- Cloudflare: `https://mcp.cloudflare.com/mcp`
- Atlassian (Jira/Confluence): `https://mcp.atlassian.com/mcp`
2. **`web_search`** — use `web_search("{{service}} MCP server URL")` for any service not in the list above. This is the fastest way to find unlisted servers.
3. **Registry API** — `web_fetch("https://registry.modelcontextprotocol.io/v0.1/servers?search={{query}}&limit=10")` to browse what's available. Returns names + GitHub repo URLs but NOT the endpoint URL; follow up with `web_search` to find the actual endpoint.
- **Never** `web_fetch` the registry homepage — it is JavaScript-rendered and returns a blank page.
**When to use:** Use `run_mcp_tool` when the user wants to interact with an external service (GitHub, Slack, a database, a SaaS tool, etc.) via its MCP integration. Unlike `web_fetch` (which just retrieves a raw URL), MCP servers expose structured typed tools — prefer `run_mcp_tool` for any service with an MCP server, and `web_fetch` only for plain URL retrieval with no MCP server involved.
**CRITICAL**: `run_mcp_tool` is **always available** in your tool list. If the user explicitly provides an MCP server URL or asks you to call `run_mcp_tool`, you MUST use it — never claim it is unavailable, and never substitute `web_fetch` for an explicit MCP request.
## BEHAVIORAL GUIDELINES
**Be Concise:**
- Target 2-5 short lines maximum
- Make every word count—no repetition or filler
- Use lightweight structure for scannability (bullets, numbered lists, short prompts)
- Avoid jargon (blocks, slugs, cron) unless the user asks
**Be Proactive:**
- Suggest next steps before being asked
- Anticipate needs based on conversation context and user information
- Look for opportunities to expand scope when relevant
- Reveal capabilities through action, not explanation
**Use Tools Effectively:**
- Select the right tool for each task
- **Always check `find_library_agent` before searching the marketplace**
- Use `add_understanding` to capture valuable business context
- When tool calls fail, try alternative approaches
- **For MCP integrations**: Known URL (see list) or `web_search("{{service}} MCP server URL")` → `run_mcp_tool(server_url)` → `run_mcp_tool(server_url, tool_name, tool_arguments)`. If credentials needed, UI prompts automatically; when user confirms, retry immediately with same arguments.
**Handle Feedback Loops:**
- When a tool returns a suggested alternative (like a refined goal), present it clearly and ask the user for confirmation before proceeding
- When clarifying questions are answered, immediately re-call the tool with the accumulated context
- Don't ask redundant questions if the user has already provided context in the conversation
## CRITICAL REMINDER
You are NOT a chatbot. You are NOT documentation. You are a partner who helps busy business owners get value quickly by showing proof through working automations. Bias toward action over explanation."""
Be concise, proactive, and action-oriented. Bias toward showing working solutions over lengthy explanations."""
# ---------------------------------------------------------------------------
@@ -298,6 +198,12 @@ async def assign_user_to_session(
session = await get_chat_session(session_id, None)
if not session:
raise NotFoundError(f"Session {session_id} not found")
if session.user_id is not None and session.user_id != user_id:
logger.warning(
f"[SECURITY] Attempt to claim session {session_id} by user {user_id}, "
f"but it already belongs to user {session.user_id}"
)
raise NotAuthorizedError(f"Not authorized to claim session {session_id}")
session.user_id = user_id
session = await upsert_chat_session(session)
return session

View File

@@ -12,6 +12,7 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .continue_run_block import ContinueRunBlockTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
@@ -19,11 +20,23 @@ from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsToo
from .find_agent import FindAgentTool
from .find_block import FindBlockTool
from .find_library_agent import FindLibraryAgentTool
from .fix_agent import FixAgentGraphTool
from .get_agent_building_guide import GetAgentBuildingGuideTool
from .get_doc_page import GetDocPageTool
from .get_mcp_guide import GetMCPGuideTool
from .manage_folders import (
CreateFolderTool,
DeleteFolderTool,
ListFoldersTool,
MoveAgentsToFolderTool,
MoveFolderTool,
UpdateFolderTool,
)
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .run_mcp_tool import RunMCPToolTool
from .search_docs import SearchDocsTool
from .validate_agent import ValidateAgentGraphTool
from .web_fetch import WebFetchTool
from .workspace_files import (
DeleteWorkspaceFileTool,
@@ -47,12 +60,22 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"find_agent": FindAgentTool(),
"find_block": FindBlockTool(),
"find_library_agent": FindLibraryAgentTool(),
# Folder management tools
"create_folder": CreateFolderTool(),
"list_folders": ListFoldersTool(),
"update_folder": UpdateFolderTool(),
"move_folder": MoveFolderTool(),
"delete_folder": DeleteFolderTool(),
"move_agents_to_folder": MoveAgentsToFolderTool(),
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"continue_run_block": ContinueRunBlockTool(),
"run_mcp_tool": RunMCPToolTool(),
"get_mcp_guide": GetMCPGuideTool(),
"view_agent_output": AgentOutputTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
"get_agent_building_guide": GetAgentBuildingGuideTool(),
# Web fetch for safe URL retrieval
"web_fetch": WebFetchTool(),
# Agent-browser multi-step automation (navigate, act, screenshot)
@@ -65,6 +88,9 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
# Feature request tools
"search_feature_requests": SearchFeatureRequestsTool(),
"create_feature_request": CreateFeatureRequestTool(),
# Agent generation tools (local validation/fixing)
"validate_agent_graph": ValidateAgentGraphTool(),
"fix_agent_graph": FixAgentGraphTool(),
# Workspace tools for CoPilot file operations
"list_workspace_files": ListWorkspaceFilesTool(),
"read_workspace_file": ReadWorkspaceFileTool(),

View File

@@ -151,8 +151,8 @@ async def setup_test_data(server):
unique_slug = f"test-agent-{str(uuid.uuid4())[:8]}"
store_submission = await store_db.create_store_submission(
user_id=user.id,
agent_id=created_graph.id,
agent_version=created_graph.version,
graph_id=created_graph.id,
graph_version=created_graph.version,
slug=unique_slug,
name="Test Agent",
description="A simple test agent",
@@ -161,10 +161,10 @@ async def setup_test_data(server):
image_urls=["https://example.com/image.jpg"],
)
assert store_submission.store_listing_version_id is not None
assert store_submission.listing_version_id is not None
# 4. Approve the store listing version
await store_db.review_store_submission(
store_listing_version_id=store_submission.store_listing_version_id,
store_listing_version_id=store_submission.listing_version_id,
is_approved=True,
external_comments="Approved for testing",
internal_comments="Test approval",
@@ -321,8 +321,8 @@ async def setup_llm_test_data(server):
unique_slug = f"llm-test-agent-{str(uuid.uuid4())[:8]}"
store_submission = await store_db.create_store_submission(
user_id=user.id,
agent_id=created_graph.id,
agent_version=created_graph.version,
graph_id=created_graph.id,
graph_version=created_graph.version,
slug=unique_slug,
name="LLM Test Agent",
description="An agent with LLM capabilities",
@@ -330,9 +330,9 @@ async def setup_llm_test_data(server):
categories=["testing", "ai"],
image_urls=["https://example.com/image.jpg"],
)
assert store_submission.store_listing_version_id is not None
assert store_submission.listing_version_id is not None
await store_db.review_store_submission(
store_listing_version_id=store_submission.store_listing_version_id,
store_listing_version_id=store_submission.listing_version_id,
is_approved=True,
external_comments="Approved for testing",
internal_comments="Test approval for LLM agent",
@@ -476,8 +476,8 @@ async def setup_firecrawl_test_data(server):
unique_slug = f"firecrawl-test-agent-{str(uuid.uuid4())[:8]}"
store_submission = await store_db.create_store_submission(
user_id=user.id,
agent_id=created_graph.id,
agent_version=created_graph.version,
graph_id=created_graph.id,
graph_version=created_graph.version,
slug=unique_slug,
name="Firecrawl Test Agent",
description="An agent with Firecrawl integration (no credentials)",
@@ -485,9 +485,9 @@ async def setup_firecrawl_test_data(server):
categories=["testing", "scraping"],
image_urls=["https://example.com/image.jpg"],
)
assert store_submission.store_listing_version_id is not None
assert store_submission.listing_version_id is not None
await store_db.review_store_submission(
store_listing_version_id=store_submission.store_listing_version_id,
store_listing_version_id=store_submission.listing_version_id,
is_approved=True,
external_comments="Approved for testing",
internal_comments="Test approval for Firecrawl agent",

View File

@@ -33,7 +33,7 @@ import tempfile
from typing import Any
from backend.copilot.model import ChatSession
from backend.util.request import validate_url
from backend.util.request import validate_url_host
from .base import BaseTool
from .models import (
@@ -235,7 +235,7 @@ async def _restore_browser_state(
if url:
# Validate the saved URL to prevent SSRF via stored redirect targets.
try:
await validate_url(url, trusted_origins=[])
await validate_url_host(url)
except ValueError:
logger.warning(
"[browser] State restore: blocked SSRF URL %s", url[:200]
@@ -473,7 +473,7 @@ class BrowserNavigateTool(BaseTool):
)
try:
await validate_url(url, trusted_origins=[])
await validate_url_host(url)
except ValueError as e:
return ErrorResponse(
message=str(e),

View File

@@ -68,17 +68,18 @@ def _run_result(rc: int = 0, stdout: str = "", stderr: str = "") -> tuple:
# ---------------------------------------------------------------------------
# SSRF protection via shared validate_url (backend.util.request)
# SSRF protection via shared validate_url_host (backend.util.request)
# ---------------------------------------------------------------------------
# Patch target: validate_url is imported directly into agent_browser's module scope.
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url"
# Patch target: validate_url_host is imported directly into agent_browser's
# module scope.
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url_host"
class TestSsrfViaValidateUrl:
"""Verify that browser_navigate uses validate_url for SSRF protection.
"""Verify that browser_navigate uses validate_url_host for SSRF protection.
We mock validate_url itself (not the low-level socket) so these tests
We mock validate_url_host itself (not the low-level socket) so these tests
exercise the integration point, not the internals of request.py
(which has its own thorough test suite in request_test.py).
"""
@@ -89,7 +90,7 @@ class TestSsrfViaValidateUrl:
@pytest.mark.asyncio
async def test_blocked_ip_returns_blocked_url_error(self):
"""validate_url raises ValueError → tool returns blocked_url ErrorResponse."""
"""validate_url_host raises ValueError → tool returns blocked_url ErrorResponse."""
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
mock_validate.side_effect = ValueError(
"Access to blocked IP 10.0.0.1 is not allowed."
@@ -124,8 +125,8 @@ class TestSsrfViaValidateUrl:
assert result.error == "blocked_url"
@pytest.mark.asyncio
async def test_validate_url_called_with_empty_trusted_origins(self):
"""Confirms no trusted-origins bypass is granted — all URLs are validated."""
async def test_validate_url_host_called_without_trusted_hostnames(self):
"""Confirms no trusted-hostnames bypass is granted — all URLs are validated."""
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
mock_validate.return_value = (object(), False, ["1.2.3.4"])
with patch(
@@ -143,7 +144,7 @@ class TestSsrfViaValidateUrl:
session=self.session,
url="https://example.com",
)
mock_validate.assert_called_once_with("https://example.com", trusted_origins=[])
mock_validate.assert_called_once_with("https://example.com")
# ---------------------------------------------------------------------------

View File

@@ -1,20 +1,15 @@
"""Agent generator package - Creates agents from natural language."""
from .core import (
AgentGeneratorNotConfiguredError,
AgentJsonValidationError,
AgentSummary,
DecompositionResult,
DecompositionStep,
LibraryAgentSummary,
MarketplaceAgentSummary,
customize_template,
decompose_goal,
enrich_library_agents_from_steps,
extract_search_terms_from_steps,
extract_uuids_from_text,
generate_agent,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_library_agent_by_graph_id,
@@ -27,25 +22,20 @@ from .core import (
search_marketplace_agents_for_generation,
)
from .errors import get_user_message_for_error
from .service import health_check as check_external_service_health
from .service import is_external_service_configured
from .validation import AgentFixer, AgentValidator
__all__ = [
"AgentGeneratorNotConfiguredError",
"AgentFixer",
"AgentValidator",
"AgentJsonValidationError",
"AgentSummary",
"DecompositionResult",
"DecompositionStep",
"LibraryAgentSummary",
"MarketplaceAgentSummary",
"check_external_service_health",
"customize_template",
"decompose_goal",
"enrich_library_agents_from_steps",
"extract_search_terms_from_steps",
"extract_uuids_from_text",
"generate_agent",
"generate_agent_patch",
"get_agent_as_json",
"get_all_relevant_agents_for_generation",
"get_library_agent_by_graph_id",
@@ -54,7 +44,6 @@ __all__ = [
"get_library_agents_for_generation",
"get_user_message_for_error",
"graph_to_json",
"is_external_service_configured",
"json_to_graph",
"save_agent_to_library",
"search_marketplace_agents_for_generation",

View File

@@ -0,0 +1,66 @@
"""Block management for agent generation.
Provides cached access to block metadata for validation and fixing.
"""
import logging
from typing import Any, Type
from backend.blocks import get_blocks as get_block_classes
from backend.blocks._base import Block
logger = logging.getLogger(__name__)
__all__ = ["get_blocks_as_dicts", "reset_block_caches"]
# ---------------------------------------------------------------------------
# Module-level caches
# ---------------------------------------------------------------------------
_blocks_cache: list[dict[str, Any]] | None = None
def reset_block_caches() -> None:
"""Reset all module-level caches (useful after updating block descriptions)."""
global _blocks_cache
_blocks_cache = None
# ---------------------------------------------------------------------------
# 1. get_blocks_as_dicts
# ---------------------------------------------------------------------------
def get_blocks_as_dicts() -> list[dict[str, Any]]:
"""Get all available blocks as dicts (cached after first call).
Each dict contains the keys returned by ``Block.get_info().model_dump()``:
id, name, description, inputSchema, outputSchema, categories,
staticOutput, costs, contributors, uiType.
Returns:
List of block info dicts.
"""
global _blocks_cache
if _blocks_cache is not None:
return _blocks_cache
block_classes: dict[str, Type[Block]] = get_block_classes() # type: ignore[assignment]
blocks: list[dict[str, Any]] = []
for block_cls in block_classes.values():
try:
instance = block_cls()
info = instance.get_info().model_dump()
# Use optimized description if available (loaded at startup)
if instance.optimized_description:
info["description"] = instance.optimized_description
blocks.append(info)
except Exception:
logger.warning(
"Failed to load block info for %s, skipping",
getattr(block_cls, "__name__", "unknown"),
exc_info=True,
)
_blocks_cache = blocks
logger.info("Cached %d block dicts", len(blocks))
return _blocks_cache

View File

@@ -10,13 +10,7 @@ from backend.data.db_accessors import graph_db, library_db, store_db
from backend.data.graph import Graph, Link, Node
from backend.util.exceptions import DatabaseError, NotFoundError
from .service import (
customize_template_external,
decompose_goal_external,
generate_agent_external,
generate_agent_patch_external,
is_external_service_configured,
)
from .helpers import UUID_RE_STR
logger = logging.getLogger(__name__)
@@ -78,38 +72,7 @@ class DecompositionResult(TypedDict, total=False):
AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
def _to_dict_list(
agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None,
) -> list[dict[str, Any]] | None:
"""Convert typed agent summaries to plain dicts for external service calls."""
if agents is None:
return None
return [dict(a) for a in agents]
class AgentGeneratorNotConfiguredError(Exception):
"""Raised when the external Agent Generator service is not configured."""
pass
def _check_service_configured() -> None:
"""Check if the external Agent Generator service is configured.
Raises:
AgentGeneratorNotConfiguredError: If the service is not configured.
"""
if not is_external_service_configured():
raise AgentGeneratorNotConfiguredError(
"Agent Generator service is not configured. "
"Set AGENTGENERATOR_HOST environment variable to enable agent generation."
)
_UUID_PATTERN = re.compile(
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
re.IGNORECASE,
)
_UUID_PATTERN = re.compile(UUID_RE_STR, re.IGNORECASE)
def extract_uuids_from_text(text: str) -> list[str]:
@@ -553,69 +516,6 @@ async def enrich_library_agents_from_steps(
return all_agents
async def decompose_goal(
description: str,
context: str = "",
library_agents: Sequence[AgentSummary] | None = None,
) -> DecompositionResult | None:
"""Break down a goal into steps or return clarifying questions.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
library_agents: User's library agents available for sub-agent composition
Returns:
DecompositionResult with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
Or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for decompose_goal")
result = await decompose_goal_external(
description, context, _to_dict_list(library_agents)
)
return result # type: ignore[return-value]
async def generate_agent(
instructions: DecompositionResult | dict[str, Any],
library_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
Returns:
Agent JSON dict, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(
dict(instructions), _to_dict_list(library_agents)
)
if result:
if isinstance(result, dict) and result.get("type") == "error":
return result
if "id" not in result:
result["id"] = str(uuid.uuid4())
if "version" not in result:
result["version"] = 1
if "is_active" not in result:
result["is_active"] = True
return result
class AgentJsonValidationError(Exception):
"""Raised when agent JSON is invalid or missing required fields."""
@@ -695,7 +595,10 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
async def save_agent_to_library(
agent_json: dict[str, Any], user_id: str, is_update: bool = False
agent_json: dict[str, Any],
user_id: str,
is_update: bool = False,
folder_id: str | None = None,
) -> tuple[Graph, Any]:
"""Save agent to database and user's library.
@@ -703,6 +606,7 @@ async def save_agent_to_library(
agent_json: Agent JSON dict
user_id: User ID
is_update: Whether this is an update to an existing agent
folder_id: Optional folder ID to place the agent in
Returns:
Tuple of (created Graph, LibraryAgent)
@@ -711,7 +615,7 @@ async def save_agent_to_library(
db = library_db()
if is_update:
return await db.update_graph_in_library(graph, user_id)
return await db.create_graph_in_library(graph, user_id)
return await db.create_graph_in_library(graph, user_id, folder_id=folder_id)
def graph_to_json(graph: Graph) -> dict[str, Any]:
@@ -788,70 +692,3 @@ async def get_agent_as_json(
return None
return graph_to_json(graph)
async def generate_agent_patch(
update_request: str,
current_agent: dict[str, Any],
library_agents: Sequence[AgentSummary] | None = None,
) -> dict[str, Any] | None:
"""Update an existing agent using natural language.
The external Agent Generator service handles:
- Generating the patch
- Applying the patch
- Fixing and validating the result
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
Returns:
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent_patch")
return await generate_agent_patch_external(
update_request,
current_agent,
_to_dict_list(library_agents),
)
async def customize_template(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any] | None:
"""Customize a template/marketplace agent using natural language.
This is used when users want to modify a template or marketplace agent
to fit their specific needs before adding it to their library.
The external Agent Generator service handles:
- Understanding the modification request
- Applying changes to the template
- Fixing and validating the result
Args:
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
Returns:
Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on unexpected error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for customize_template")
return await customize_template_external(
template_agent, modification_request, context
)

View File

@@ -1,165 +0,0 @@
"""Dummy Agent Generator for testing.
Returns mock responses matching the format expected from the external service.
Enable via AGENTGENERATOR_USE_DUMMY=true in settings.
WARNING: This is for testing only. Do not use in production.
"""
import asyncio
import logging
import uuid
from typing import Any
logger = logging.getLogger(__name__)
# Dummy decomposition result (instructions type)
DUMMY_DECOMPOSITION_RESULT: dict[str, Any] = {
"type": "instructions",
"steps": [
{
"description": "Get input from user",
"action": "input",
"block_name": "AgentInputBlock",
},
{
"description": "Process the input",
"action": "process",
"block_name": "TextFormatterBlock",
},
{
"description": "Return output to user",
"action": "output",
"block_name": "AgentOutputBlock",
},
],
}
# Block IDs from backend/blocks/io.py
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
def _generate_dummy_agent_json() -> dict[str, Any]:
"""Generate a minimal valid agent JSON for testing."""
input_node_id = str(uuid.uuid4())
output_node_id = str(uuid.uuid4())
return {
"id": str(uuid.uuid4()),
"version": 1,
"is_active": True,
"name": "Dummy Test Agent",
"description": "A dummy agent generated for testing purposes",
"nodes": [
{
"id": input_node_id,
"block_id": AGENT_INPUT_BLOCK_ID,
"input_default": {
"name": "input",
"title": "Input",
"description": "Enter your input",
"placeholder_values": [],
},
"metadata": {"position": {"x": 0, "y": 0}},
},
{
"id": output_node_id,
"block_id": AGENT_OUTPUT_BLOCK_ID,
"input_default": {
"name": "output",
"title": "Output",
"description": "Agent output",
"format": "{output}",
},
"metadata": {"position": {"x": 400, "y": 0}},
},
],
"links": [
{
"id": str(uuid.uuid4()),
"source_id": input_node_id,
"sink_id": output_node_id,
"source_name": "result",
"sink_name": "value",
"is_static": False,
},
],
}
async def decompose_goal_dummy(
description: str,
context: str = "",
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
"""Return dummy decomposition result."""
logger.info("Using dummy agent generator for decompose_goal")
return DUMMY_DECOMPOSITION_RESULT.copy()
async def generate_agent_dummy(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy agent synchronously (blocks for 30s, returns agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator (sync mode): returning agent JSON after 30s"
)
await asyncio.sleep(30)
return _generate_dummy_agent_json()
async def generate_agent_patch_dummy(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy patched agent synchronously (blocks for 30s, returns patched agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator patch (sync mode): returning patched agent after 30s"
)
await asyncio.sleep(30)
patched = current_agent.copy()
patched["description"] = (
f"{current_agent.get('description', '')} (updated: {update_request})"
)
return patched
async def customize_template_dummy(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any]:
"""Return dummy customized template (returns template with updated description)."""
logger.info("Using dummy agent generator for customize_template")
customized = template_agent.copy()
customized["description"] = (
f"{template_agent.get('description', '')} (customized: {modification_request})"
)
return customized
async def get_blocks_dummy() -> list[dict[str, Any]]:
"""Return dummy blocks list."""
logger.info("Using dummy agent generator for get_blocks")
return [
{"id": AGENT_INPUT_BLOCK_ID, "name": "AgentInputBlock"},
{"id": AGENT_OUTPUT_BLOCK_ID, "name": "AgentOutputBlock"},
]
async def health_check_dummy() -> bool:
"""Always returns healthy for dummy service."""
return True

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,913 @@
"""Unit tests for AgentFixer."""
from .fixer import (
_ADDTODICTIONARY_BLOCK_ID,
_ADDTOLIST_BLOCK_ID,
_CODE_EXECUTION_BLOCK_ID,
_DATA_SAMPLING_BLOCK_ID,
_GET_CURRENT_DATE_BLOCK_ID,
_STORE_VALUE_BLOCK_ID,
_TEXT_REPLACE_BLOCK_ID,
_UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
AGENT_EXECUTOR_BLOCK_ID,
MCP_TOOL_BLOCK_ID,
AgentFixer,
)
from .helpers import generate_uuid
def _make_agent(
nodes: list | None = None,
links: list | None = None,
agent_id: str | None = None,
) -> dict:
"""Build a minimal agent dict for testing."""
return {
"id": agent_id or generate_uuid(),
"name": "Test Agent",
"nodes": nodes or [],
"links": links or [],
}
def _make_node(
node_id: str | None = None,
block_id: str = "block-1",
input_default: dict | None = None,
position: tuple[int, int] = (0, 0),
) -> dict:
"""Build a minimal node dict for testing."""
return {
"id": node_id or generate_uuid(),
"block_id": block_id,
"input_default": input_default or {},
"metadata": {"position": {"x": position[0], "y": position[1]}},
}
def _make_link(
link_id: str | None = None,
source_id: str = "",
source_name: str = "output",
sink_id: str = "",
sink_name: str = "input",
is_static: bool = False,
) -> dict:
"""Build a minimal link dict for testing."""
return {
"id": link_id or generate_uuid(),
"source_id": source_id,
"source_name": source_name,
"sink_id": sink_id,
"sink_name": sink_name,
"is_static": is_static,
}
class TestFixAgentIds:
"""Tests for fix_agent_ids."""
def test_valid_uuids_unchanged(self):
fixer = AgentFixer()
agent_id = generate_uuid()
link_id = generate_uuid()
agent = _make_agent(agent_id=agent_id, links=[{"id": link_id}])
result = fixer.fix_agent_ids(agent)
assert result["id"] == agent_id
assert result["links"][0]["id"] == link_id
assert fixer.fixes_applied == []
def test_invalid_agent_id_replaced(self):
fixer = AgentFixer()
agent = _make_agent(agent_id="bad-id")
result = fixer.fix_agent_ids(agent)
assert result["id"] != "bad-id"
assert len(fixer.fixes_applied) == 1
assert "agent ID" in fixer.fixes_applied[0]
def test_invalid_link_id_replaced(self):
fixer = AgentFixer()
agent = _make_agent(links=[{"id": "not-a-uuid"}])
result = fixer.fix_agent_ids(agent)
assert result["links"][0]["id"] != "not-a-uuid"
assert len(fixer.fixes_applied) == 1
class TestFixDoubleCurlyBraces:
"""Tests for fix_double_curly_braces."""
def test_single_braces_converted_to_double(self):
fixer = AgentFixer()
node = _make_node(input_default={"prompt": "Hello {name}!"})
agent = _make_agent(nodes=[node])
result = fixer.fix_double_curly_braces(agent)
assert result["nodes"][0]["input_default"]["prompt"] == "Hello {{name}}!"
def test_double_braces_unchanged(self):
fixer = AgentFixer()
node = _make_node(input_default={"prompt": "Hello {{name}}!"})
agent = _make_agent(nodes=[node])
result = fixer.fix_double_curly_braces(agent)
assert result["nodes"][0]["input_default"]["prompt"] == "Hello {{name}}!"
assert fixer.fixes_applied == []
def test_non_string_prompt_skipped(self):
fixer = AgentFixer()
node = _make_node(input_default={"prompt": 42})
agent = _make_agent(nodes=[node])
result = fixer.fix_double_curly_braces(agent)
assert result["nodes"][0]["input_default"]["prompt"] == 42
def test_non_string_prompt_with_prompt_values_skipped(self):
"""Ensure non-string prompt fields don't crash re.search in the
prompt_values path."""
fixer = AgentFixer()
node_id = generate_uuid()
source_id = generate_uuid()
node = _make_node(
node_id=node_id, input_default={"prompt": None, "prompt_values": {}}
)
source_node = _make_node(node_id=source_id)
link = _make_link(
source_id=source_id,
source_name="output",
sink_id=node_id,
sink_name="prompt_values_$_name",
)
agent = _make_agent(nodes=[node, source_node], links=[link])
result = fixer.fix_double_curly_braces(agent)
# Should not crash and prompt stays None
assert result["nodes"][0]["input_default"]["prompt"] is None
class TestFixCredentials:
"""Tests for fix_credentials."""
def test_credentials_removed(self):
fixer = AgentFixer()
node = _make_node(
input_default={
"credentials": {"key": "secret"},
"url": "http://example.com",
}
)
agent = _make_agent(nodes=[node])
result = fixer.fix_credentials(agent)
assert "credentials" not in result["nodes"][0]["input_default"]
assert result["nodes"][0]["input_default"]["url"] == "http://example.com"
assert len(fixer.fixes_applied) == 1
def test_no_credentials_unchanged(self):
fixer = AgentFixer()
node = _make_node(input_default={"url": "http://example.com"})
agent = _make_agent(nodes=[node])
result = fixer.fix_credentials(agent)
assert result["nodes"][0]["input_default"]["url"] == "http://example.com"
assert fixer.fixes_applied == []
class TestFixCodeExecutionOutput:
"""Tests for fix_code_execution_output."""
def test_response_renamed_to_stdout_logs(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id=_CODE_EXECUTION_BLOCK_ID)
link = _make_link(source_id="n1", source_name="response", sink_id="n2")
agent = _make_agent(nodes=[node], links=[link])
result = fixer.fix_code_execution_output(agent)
assert result["links"][0]["source_name"] == "stdout_logs"
assert len(fixer.fixes_applied) == 1
def test_non_response_source_unchanged(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id=_CODE_EXECUTION_BLOCK_ID)
link = _make_link(source_id="n1", source_name="stdout_logs", sink_id="n2")
agent = _make_agent(nodes=[node], links=[link])
result = fixer.fix_code_execution_output(agent)
assert result["links"][0]["source_name"] == "stdout_logs"
assert fixer.fixes_applied == []
class TestFixDataSamplingSampleSize:
"""Tests for fix_data_sampling_sample_size."""
def test_sample_size_set_to_1(self):
fixer = AgentFixer()
node = _make_node(
node_id="n1",
block_id=_DATA_SAMPLING_BLOCK_ID,
input_default={"sample_size": 10},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_data_sampling_sample_size(agent)
assert result["nodes"][0]["input_default"]["sample_size"] == 1
def test_removes_links_to_sample_size(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id=_DATA_SAMPLING_BLOCK_ID)
link = _make_link(sink_id="n1", sink_name="sample_size", source_id="n2")
agent = _make_agent(nodes=[node], links=[link])
result = fixer.fix_data_sampling_sample_size(agent)
assert len(result["links"]) == 0
assert result["nodes"][0]["input_default"]["sample_size"] == 1
class TestFixTextReplaceNewParameter:
"""Tests for fix_text_replace_new_parameter."""
def test_empty_new_changed_to_space(self):
fixer = AgentFixer()
node = _make_node(
block_id=_TEXT_REPLACE_BLOCK_ID,
input_default={"new": ""},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_text_replace_new_parameter(agent)
assert result["nodes"][0]["input_default"]["new"] == " "
def test_nonempty_new_unchanged(self):
fixer = AgentFixer()
node = _make_node(
block_id=_TEXT_REPLACE_BLOCK_ID,
input_default={"new": "replacement"},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_text_replace_new_parameter(agent)
assert result["nodes"][0]["input_default"]["new"] == "replacement"
assert fixer.fixes_applied == []
class TestFixGetCurrentDateOffset:
"""Tests for fix_getcurrentdate_offset."""
def test_negative_offset_made_positive(self):
fixer = AgentFixer()
node = _make_node(
block_id=_GET_CURRENT_DATE_BLOCK_ID,
input_default={"offset": -5},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_getcurrentdate_offset(agent)
assert result["nodes"][0]["input_default"]["offset"] == 5
def test_positive_offset_unchanged(self):
fixer = AgentFixer()
node = _make_node(
block_id=_GET_CURRENT_DATE_BLOCK_ID,
input_default={"offset": 3},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_getcurrentdate_offset(agent)
assert result["nodes"][0]["input_default"]["offset"] == 3
assert fixer.fixes_applied == []
class TestFixNodeXCoordinates:
"""Tests for fix_node_x_coordinates."""
def test_close_nodes_spread_apart(self):
fixer = AgentFixer()
src_node = _make_node(node_id="src", position=(0, 0))
sink_node = _make_node(node_id="sink", position=(100, 0))
link = _make_link(source_id="src", sink_id="sink")
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
result = fixer.fix_node_x_coordinates(agent)
sink = next(n for n in result["nodes"] if n["id"] == "sink")
assert sink["metadata"]["position"]["x"] >= 800
def test_far_apart_nodes_unchanged(self):
fixer = AgentFixer()
src_node = _make_node(node_id="src", position=(0, 0))
sink_node = _make_node(node_id="sink", position=(1000, 0))
link = _make_link(source_id="src", sink_id="sink")
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
result = fixer.fix_node_x_coordinates(agent)
sink = next(n for n in result["nodes"] if n["id"] == "sink")
assert sink["metadata"]["position"]["x"] == 1000
assert fixer.fixes_applied == []
class TestFixAddToDictionaryBlocks:
"""Tests for fix_addtodictionary_blocks."""
def test_removes_create_dictionary_nodes(self):
fixer = AgentFixer()
create_dict_id = "b924ddf4-de4f-4b56-9a85-358930dcbc91"
dict_node = _make_node(node_id="dict-1", block_id=create_dict_id)
add_to_dict_node = _make_node(
node_id="add-1", block_id=_ADDTODICTIONARY_BLOCK_ID
)
link = _make_link(source_id="dict-1", sink_id="add-1")
agent = _make_agent(nodes=[dict_node, add_to_dict_node], links=[link])
result = fixer.fix_addtodictionary_blocks(agent)
node_ids = [n["id"] for n in result["nodes"]]
assert "dict-1" not in node_ids
assert "add-1" in node_ids
assert len(result["links"]) == 0
class TestFixStoreValueBeforeCondition:
"""Tests for fix_storevalue_before_condition."""
def test_inserts_storevalue_block(self):
fixer = AgentFixer()
condition_block_id = "715696a0-e1da-45c8-b209-c2fa9c3b0be6"
src_node = _make_node(node_id="src")
cond_node = _make_node(node_id="cond", block_id=condition_block_id)
link = _make_link(
source_id="src", source_name="output", sink_id="cond", sink_name="value2"
)
agent = _make_agent(nodes=[src_node, cond_node], links=[link])
result = fixer.fix_storevalue_before_condition(agent)
# Should have 3 nodes now (original 2 + new StoreValueBlock)
assert len(result["nodes"]) == 3
store_nodes = [
n for n in result["nodes"] if n["block_id"] == _STORE_VALUE_BLOCK_ID
]
assert len(store_nodes) == 1
assert store_nodes[0]["input_default"]["data"] is None
class TestFixAddToListBlocks:
"""Tests for fix_addtolist_blocks - self-reference links."""
def test_addtolist_gets_self_reference_link(self):
fixer = AgentFixer()
node = _make_node(node_id="atl-1", block_id=_ADDTOLIST_BLOCK_ID)
# Source link to AddToList (from some other node)
link = _make_link(
source_id="other",
source_name="output",
sink_id="atl-1",
sink_name="item",
)
other_node = _make_node(node_id="other")
agent = _make_agent(nodes=[other_node, node], links=[link])
result = fixer.fix_addtolist_blocks(agent)
# Should have a self-reference link: atl-1.updated_list -> atl-1.list
self_ref_links = [
lnk
for lnk in result["links"]
if lnk["source_id"] == "atl-1"
and lnk["sink_id"] == "atl-1"
and lnk["source_name"] == "updated_list"
and lnk["sink_name"] == "list"
]
assert len(self_ref_links) == 1
class TestFixLinkStaticProperties:
"""Tests for fix_link_static_properties."""
def test_sets_is_static_from_block_schema(self):
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(node_id="n1", block_id=block_id)
link = _make_link(source_id="n1", sink_id="n2", is_static=False)
agent = _make_agent(nodes=[node], links=[link])
blocks = [{"id": block_id, "staticOutput": True}]
result = fixer.fix_link_static_properties(agent, blocks)
assert result["links"][0]["is_static"] is True
def test_unknown_block_leaves_link_unchanged(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id="unknown-block")
link = _make_link(source_id="n1", sink_id="n2", is_static=True)
agent = _make_agent(nodes=[node], links=[link])
result = fixer.fix_link_static_properties(agent, blocks=[])
# Unknown block → skipped, link stays as-is
assert result["links"][0]["is_static"] is True
class TestFixAiModelParameter:
"""Tests for fix_ai_model_parameter."""
def test_missing_model_gets_default(self):
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(node_id="n1", block_id=block_id, input_default={})
agent = _make_agent(nodes=[node])
blocks = [
{
"id": block_id,
"categories": [{"category": "AI"}],
"inputSchema": {
"properties": {"model": {"type": "string"}},
},
}
]
result = fixer.fix_ai_model_parameter(agent, blocks)
assert result["nodes"][0]["input_default"]["model"] == "gpt-4o"
def test_valid_model_unchanged(self):
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(
node_id="n1",
block_id=block_id,
input_default={"model": "claude-opus-4-6"},
)
agent = _make_agent(nodes=[node])
blocks = [
{
"id": block_id,
"categories": [{"category": "AI"}],
"inputSchema": {
"properties": {"model": {"type": "string"}},
},
}
]
result = fixer.fix_ai_model_parameter(agent, blocks)
assert result["nodes"][0]["input_default"]["model"] == "claude-opus-4-6"
class TestFixAgentExecutorBlocks:
"""Tests for fix_agent_executor_blocks."""
def test_fills_schemas_from_library_agent(self):
fixer = AgentFixer()
lib_agent_id = generate_uuid()
node = _make_node(
node_id="n1",
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={
"graph_id": lib_agent_id,
"graph_version": 1,
"user_id": "user-1",
},
)
agent = _make_agent(nodes=[node])
# Library agents use graph_id as the lookup key
library_agents = [
{
"graph_id": lib_agent_id,
"graph_version": 2,
"input_schema": {"field1": {"type": "string"}},
"output_schema": {"result": {"type": "string"}},
}
]
result = fixer.fix_agent_executor_blocks(agent, library_agents)
node_result = result["nodes"][0]["input_default"]
assert node_result["graph_version"] == 2
assert node_result["input_schema"] == {"field1": {"type": "string"}}
assert node_result["output_schema"] == {"result": {"type": "string"}}
class TestFixInvalidNestedSinkLinks:
"""Tests for fix_invalid_nested_sink_links."""
def test_removes_numeric_index_links(self):
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(node_id="n1", block_id=block_id)
link = _make_link(source_id="n2", sink_id="n1", sink_name="values_#_0")
agent = _make_agent(nodes=[node], links=[link])
blocks = [
{
"id": block_id,
"inputSchema": {"properties": {"values": {"type": "array"}}},
}
]
result = fixer.fix_invalid_nested_sink_links(agent, blocks)
assert len(result["links"]) == 0
def test_valid_nested_links_kept(self):
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(node_id="n1", block_id=block_id)
link = _make_link(source_id="n2", sink_id="n1", sink_name="values_#_name")
agent = _make_agent(nodes=[node], links=[link])
blocks = [
{
"id": block_id,
"inputSchema": {
"properties": {"values": {"type": "object"}},
},
}
]
result = fixer.fix_invalid_nested_sink_links(agent, blocks)
assert len(result["links"]) == 1
class TestApplyAllFixes:
"""Tests for apply_all_fixes orchestration."""
def test_is_sync(self):
"""apply_all_fixes should be a sync function."""
import inspect
assert not inspect.iscoroutinefunction(AgentFixer.apply_all_fixes)
def test_applies_multiple_fixes(self):
fixer = AgentFixer()
agent = _make_agent(
agent_id="bad-id",
nodes=[
_make_node(
block_id=_TEXT_REPLACE_BLOCK_ID,
input_default={"new": "", "credentials": {"key": "secret"}},
)
],
)
result = fixer.apply_all_fixes(agent)
# Agent ID should be fixed
assert result["id"] != "bad-id"
# Credentials should be removed
assert "credentials" not in result["nodes"][0]["input_default"]
# Text replace "new" should be space
assert result["nodes"][0]["input_default"]["new"] == " "
# Multiple fixes applied
assert len(fixer.fixes_applied) >= 3
def test_empty_agent_no_crash(self):
fixer = AgentFixer()
agent = _make_agent()
result = fixer.apply_all_fixes(agent)
assert "nodes" in result
assert "links" in result
def test_returns_deep_copy_behavior(self):
"""Fixer mutates in place — verify the same dict is returned."""
fixer = AgentFixer()
agent = _make_agent()
result = fixer.apply_all_fixes(agent)
assert result is agent
class TestFixMCPToolBlocks:
"""Tests for fix_mcp_tool_blocks."""
def test_adds_missing_tool_arguments(self):
fixer = AgentFixer()
node = _make_node(
node_id="n1",
block_id=MCP_TOOL_BLOCK_ID,
input_default={
"server_url": "https://mcp.example.com/sse",
"selected_tool": "search",
"tool_input_schema": {},
},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_mcp_tool_blocks(agent)
assert result["nodes"][0]["input_default"]["tool_arguments"] == {}
assert any("tool_arguments" in f for f in fixer.fixes_applied)
def test_adds_missing_tool_input_schema(self):
fixer = AgentFixer()
node = _make_node(
node_id="n1",
block_id=MCP_TOOL_BLOCK_ID,
input_default={
"server_url": "https://mcp.example.com/sse",
"selected_tool": "search",
"tool_arguments": {},
},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_mcp_tool_blocks(agent)
assert result["nodes"][0]["input_default"]["tool_input_schema"] == {}
assert any("tool_input_schema" in f for f in fixer.fixes_applied)
def test_populates_tool_arguments_from_schema(self):
fixer = AgentFixer()
node = _make_node(
node_id="n1",
block_id=MCP_TOOL_BLOCK_ID,
input_default={
"server_url": "https://mcp.example.com/sse",
"selected_tool": "search",
"tool_input_schema": {
"properties": {
"query": {"type": "string", "default": "hello"},
"limit": {"type": "integer"},
}
},
"tool_arguments": {},
},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_mcp_tool_blocks(agent)
tool_args = result["nodes"][0]["input_default"]["tool_arguments"]
assert tool_args["query"] == "hello"
assert tool_args["limit"] is None
def test_no_op_when_already_complete(self):
fixer = AgentFixer()
node = _make_node(
node_id="n1",
block_id=MCP_TOOL_BLOCK_ID,
input_default={
"server_url": "https://mcp.example.com/sse",
"selected_tool": "search",
"tool_input_schema": {},
"tool_arguments": {},
},
)
agent = _make_agent(nodes=[node])
fixer.fix_mcp_tool_blocks(agent)
assert len(fixer.fixes_applied) == 0
class TestFixDynamicBlockSinkNames:
"""Tests for fix_dynamic_block_sink_names."""
def test_mcp_tool_arguments_prefix_removed(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id=MCP_TOOL_BLOCK_ID)
link = _make_link(
source_id="src", sink_id="n1", sink_name="tool_arguments_#_query"
)
agent = _make_agent(nodes=[node], links=[link])
fixer.fix_dynamic_block_sink_names(agent)
assert agent["links"][0]["sink_name"] == "query"
assert len(fixer.fixes_applied) == 1
def test_agent_executor_inputs_prefix_removed(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id=AGENT_EXECUTOR_BLOCK_ID)
link = _make_link(source_id="src", sink_id="n1", sink_name="inputs_#_url")
agent = _make_agent(nodes=[node], links=[link])
fixer.fix_dynamic_block_sink_names(agent)
assert agent["links"][0]["sink_name"] == "url"
assert len(fixer.fixes_applied) == 1
def test_bare_sink_name_unchanged(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id=MCP_TOOL_BLOCK_ID)
link = _make_link(source_id="src", sink_id="n1", sink_name="query")
agent = _make_agent(nodes=[node], links=[link])
fixer.fix_dynamic_block_sink_names(agent)
assert agent["links"][0]["sink_name"] == "query"
assert len(fixer.fixes_applied) == 0
def test_non_dynamic_block_unchanged(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id="some-other-block-id")
link = _make_link(source_id="src", sink_id="n1", sink_name="values_#_key")
agent = _make_agent(nodes=[node], links=[link])
fixer.fix_dynamic_block_sink_names(agent)
assert agent["links"][0]["sink_name"] == "values_#_key"
assert len(fixer.fixes_applied) == 0
class TestFixDataTypeMismatch:
"""Tests for fix_data_type_mismatch."""
@staticmethod
def _make_block(
block_id: str,
name: str = "TestBlock",
input_schema: dict | None = None,
output_schema: dict | None = None,
) -> dict:
return {
"id": block_id,
"name": name,
"inputSchema": input_schema or {"properties": {}},
"outputSchema": output_schema or {"properties": {}},
}
def test_inserts_converter_for_incompatible_types(self):
fixer = AgentFixer()
src_block_id = generate_uuid()
sink_block_id = generate_uuid()
src_node = _make_node(node_id="src", block_id=src_block_id)
sink_node = _make_node(node_id="sink", block_id=sink_block_id)
link = _make_link(
source_id="src",
source_name="result",
sink_id="sink",
sink_name="count",
)
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
blocks = [
self._make_block(
src_block_id,
name="Source",
output_schema={"properties": {"result": {"type": "string"}}},
),
self._make_block(
sink_block_id,
name="Sink",
input_schema={"properties": {"count": {"type": "integer"}}},
),
]
result = fixer.fix_data_type_mismatch(agent, blocks)
# A converter node should have been inserted
converter_nodes = [
n
for n in result["nodes"]
if n["block_id"] == _UNIVERSAL_TYPE_CONVERTER_BLOCK_ID
]
assert len(converter_nodes) == 1
assert converter_nodes[0]["input_default"]["type"] == "number"
# Original link replaced by two new links through the converter
assert len(result["links"]) == 2
src_to_converter = result["links"][0]
converter_to_sink = result["links"][1]
assert src_to_converter["source_id"] == "src"
assert src_to_converter["sink_id"] == converter_nodes[0]["id"]
assert src_to_converter["sink_name"] == "value"
assert converter_to_sink["source_id"] == converter_nodes[0]["id"]
assert converter_to_sink["source_name"] == "value"
assert converter_to_sink["sink_id"] == "sink"
assert converter_to_sink["sink_name"] == "count"
assert len(fixer.fixes_applied) == 1
def test_compatible_types_unchanged(self):
fixer = AgentFixer()
block_id = generate_uuid()
src_node = _make_node(node_id="src", block_id=block_id)
sink_node = _make_node(node_id="sink", block_id=block_id)
link = _make_link(
source_id="src",
source_name="output",
sink_id="sink",
sink_name="input",
)
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
blocks = [
self._make_block(
block_id,
input_schema={"properties": {"input": {"type": "string"}}},
output_schema={"properties": {"output": {"type": "string"}}},
),
]
result = fixer.fix_data_type_mismatch(agent, blocks)
# No converter inserted, original link kept
assert len(result["nodes"]) == 2
assert len(result["links"]) == 1
assert result["links"][0] is link
assert fixer.fixes_applied == []
def test_missing_block_keeps_link(self):
"""Links referencing unknown blocks are kept unchanged."""
fixer = AgentFixer()
src_node = _make_node(node_id="src", block_id="unknown-block")
sink_node = _make_node(node_id="sink", block_id="unknown-block")
link = _make_link(source_id="src", sink_id="sink")
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
result = fixer.fix_data_type_mismatch(agent, blocks=[])
assert len(result["links"]) == 1
assert result["links"][0] is link
def test_missing_type_info_keeps_link(self):
"""Links where source/sink type is not defined are kept unchanged."""
fixer = AgentFixer()
block_id = generate_uuid()
src_node = _make_node(node_id="src", block_id=block_id)
sink_node = _make_node(node_id="sink", block_id=block_id)
link = _make_link(
source_id="src",
source_name="output",
sink_id="sink",
sink_name="input",
)
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
# Block has no properties defined for the linked fields
blocks = [self._make_block(block_id)]
result = fixer.fix_data_type_mismatch(agent, blocks)
assert len(result["links"]) == 1
assert fixer.fixes_applied == []
def test_multiple_mismatches_insert_multiple_converters(self):
"""Each incompatible link gets its own converter node."""
fixer = AgentFixer()
src_block_id = generate_uuid()
sink_block_id = generate_uuid()
src_node = _make_node(node_id="src", block_id=src_block_id)
sink1 = _make_node(node_id="sink1", block_id=sink_block_id)
sink2 = _make_node(node_id="sink2", block_id=sink_block_id)
link1 = _make_link(
source_id="src", source_name="out", sink_id="sink1", sink_name="count"
)
link2 = _make_link(
source_id="src", source_name="out", sink_id="sink2", sink_name="count"
)
agent = _make_agent(nodes=[src_node, sink1, sink2], links=[link1, link2])
blocks = [
self._make_block(
src_block_id,
output_schema={"properties": {"out": {"type": "string"}}},
),
self._make_block(
sink_block_id,
input_schema={"properties": {"count": {"type": "integer"}}},
),
]
result = fixer.fix_data_type_mismatch(agent, blocks)
converter_nodes = [
n
for n in result["nodes"]
if n["block_id"] == _UNIVERSAL_TYPE_CONVERTER_BLOCK_ID
]
assert len(converter_nodes) == 2
# Each original link becomes two links through its own converter
assert len(result["links"]) == 4
assert len(fixer.fixes_applied) == 2

View File

@@ -0,0 +1,67 @@
"""Shared helpers for agent generation."""
import re
import uuid
from typing import Any
from .blocks import get_blocks_as_dicts
__all__ = [
"AGENT_EXECUTOR_BLOCK_ID",
"AGENT_INPUT_BLOCK_ID",
"AGENT_OUTPUT_BLOCK_ID",
"AgentDict",
"MCP_TOOL_BLOCK_ID",
"UUID_REGEX",
"are_types_compatible",
"generate_uuid",
"get_blocks_as_dicts",
"get_defined_property_type",
"is_uuid",
]
# Type alias for the agent JSON structure passed through
# the validation and fixing pipeline.
AgentDict = dict[str, Any]
# Shared base pattern (unanchored, lowercase hex); used for both full-string
# validation (UUID_REGEX) and text extraction (core._UUID_PATTERN).
UUID_RE_STR = r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[a-f0-9]{4}-[a-f0-9]{12}"
UUID_REGEX = re.compile(r"^" + UUID_RE_STR + r"$")
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
MCP_TOOL_BLOCK_ID = "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
def is_uuid(value: str) -> bool:
"""Check if a string is a valid UUID."""
return isinstance(value, str) and UUID_REGEX.match(value) is not None
def generate_uuid() -> str:
"""Generate a new UUID string."""
return str(uuid.uuid4())
def get_defined_property_type(schema: dict[str, Any], name: str) -> str | None:
"""Get property type from a schema, handling nested `_#_` notation."""
if "_#_" in name:
parent, child = name.split("_#_", 1)
parent_schema = schema.get(parent, {})
if "properties" in parent_schema and isinstance(
parent_schema["properties"], dict
):
return parent_schema["properties"].get(child, {}).get("type")
return None
return schema.get(name, {}).get("type")
def are_types_compatible(src: str, sink: str) -> bool:
"""Check if two schema types are compatible."""
if {src, sink} <= {"integer", "number"}:
return True
return src == sink

View File

@@ -0,0 +1,196 @@
"""Shared fix → validate → preview/save pipeline for agent tools."""
import json
import logging
from typing import Any, cast
from backend.copilot.tools.models import (
AgentPreviewResponse,
AgentSavedResponse,
ErrorResponse,
ToolResponseBase,
)
from .blocks import get_blocks_as_dicts
from .core import get_library_agents_by_ids, save_agent_to_library
from .fixer import AgentFixer
from .validator import AgentValidator
logger = logging.getLogger(__name__)
MAX_AGENT_JSON_SIZE = 1_000_000 # 1 MB
async def fetch_library_agents(
user_id: str | None,
library_agent_ids: list[str],
) -> list[dict[str, Any]] | None:
"""Fetch library agents by IDs for AgentExecutorBlock validation.
Returns None if no IDs provided or user is not authenticated.
"""
if not user_id or not library_agent_ids:
return None
try:
agents = await get_library_agents_by_ids(
user_id=user_id,
agent_ids=library_agent_ids,
)
return cast(list[dict[str, Any]], agents)
except Exception as e:
logger.warning(f"Failed to fetch library agents by IDs: {e}")
return None
async def fix_validate_and_save(
agent_json: dict[str, Any],
*,
user_id: str | None,
session_id: str | None,
save: bool = True,
is_update: bool = False,
default_name: str = "Agent",
preview_message: str | None = None,
save_message: str | None = None,
library_agents: list[dict[str, Any]] | None = None,
folder_id: str | None = None,
) -> ToolResponseBase:
"""Shared pipeline: auto-fix → validate → preview or save.
Args:
agent_json: The agent JSON dict (must already have id/version/is_active set).
user_id: The authenticated user's ID.
session_id: The chat session ID.
save: Whether to save or just preview.
is_update: Whether this is an update to an existing agent.
default_name: Fallback name if agent_json has none.
preview_message: Custom preview message (optional).
save_message: Custom save success message (optional).
library_agents: Library agents for AgentExecutorBlock validation/fixing.
Returns:
An appropriate ToolResponseBase subclass.
"""
# Size guard
json_size = len(json.dumps(agent_json))
if json_size > MAX_AGENT_JSON_SIZE:
return ErrorResponse(
message=(
f"Agent JSON is too large ({json_size:,} bytes, "
f"max {MAX_AGENT_JSON_SIZE:,}). Reduce the number of nodes."
),
error="agent_json_too_large",
session_id=session_id,
)
blocks = get_blocks_as_dicts()
# Auto-fix
try:
fixer = AgentFixer()
agent_json = fixer.apply_all_fixes(agent_json, blocks, library_agents)
fixes = fixer.get_fixes_applied()
if fixes:
logger.info(f"Applied {len(fixes)} auto-fixes to agent JSON")
except Exception as e:
logger.warning(f"Auto-fix failed: {e}")
# Validate
try:
validator = AgentValidator()
is_valid, _ = validator.validate(agent_json, blocks, library_agents)
if not is_valid:
errors = validator.errors
return ErrorResponse(
message=(
f"The agent has {len(errors)} validation error(s):\n"
+ "\n".join(f"- {e}" for e in errors[:5])
),
error="validation_failed",
details={"errors": errors},
session_id=session_id,
)
except Exception as e:
logger.error(f"Validation failed with exception: {e}", exc_info=True)
return ErrorResponse(
message="Failed to validate the agent. Please try again.",
error="validation_exception",
details={"exception": str(e)},
session_id=session_id,
)
agent_name = agent_json.get("name", default_name)
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
# Build a warning suffix when name/description is missing or generic
_GENERIC_NAMES = {
"agent",
"generated agent",
"customized agent",
"updated agent",
"new agent",
"my agent",
}
metadata_warnings: list[str] = []
if not agent_json.get("name") or agent_name.lower().strip() in _GENERIC_NAMES:
metadata_warnings.append("'name'")
if not agent_description:
metadata_warnings.append("'description'")
metadata_hint = ""
if metadata_warnings:
missing = " and ".join(metadata_warnings)
metadata_hint = (
f" Note: the agent is missing a meaningful {missing}. "
f"Please update the agent_json to include them."
)
if not save:
return AgentPreviewResponse(
message=(
(
preview_message
or f"Agent '{agent_name}' with {node_count} blocks is ready."
)
+ metadata_hint
),
agent_json=agent_json,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
try:
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id, is_update=is_update, folder_id=folder_id
)
return AgentSavedResponse(
message=(
(save_message or f"Agent '{created_graph.name}' has been saved!")
+ metadata_hint
),
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
logger.error(f"Failed to save agent: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to save the agent: {str(e)}",
error="save_failed",
details={"exception": str(e)},
session_id=session_id,
)

View File

@@ -1,511 +0,0 @@
"""External Agent Generator service client.
This module provides a client for communicating with the external Agent Generator
microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
will delegate to the external service instead of using the built-in LLM-based implementation.
"""
import logging
from typing import Any
import httpx
from backend.util.settings import Settings
from .dummy import (
customize_template_dummy,
decompose_goal_dummy,
generate_agent_dummy,
generate_agent_patch_dummy,
get_blocks_dummy,
health_check_dummy,
)
logger = logging.getLogger(__name__)
_dummy_mode_warned = False
def _create_error_response(
error_message: str,
error_type: str = "unknown",
details: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Create a standardized error response dict.
Args:
error_message: Human-readable error message
error_type: Machine-readable error type
details: Optional additional error details
Returns:
Error dict with type="error" and error details
"""
response: dict[str, Any] = {
"type": "error",
"error": error_message,
"error_type": error_type,
}
if details:
response["details"] = details
return response
def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
"""Classify an HTTP error into error_type and message.
Args:
e: The HTTP status error
Returns:
Tuple of (error_type, error_message)
"""
status = e.response.status_code
if status == 429:
return "rate_limit", f"Agent Generator rate limited: {e}"
elif status == 503:
return "service_unavailable", f"Agent Generator unavailable: {e}"
elif status == 504 or status == 408:
return "timeout", f"Agent Generator timed out: {e}"
else:
return "http_error", f"HTTP error calling Agent Generator: {e}"
def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
"""Classify a request error into error_type and message.
Args:
e: The request error
Returns:
Tuple of (error_type, error_message)
"""
error_str = str(e).lower()
if "timeout" in error_str or "timed out" in error_str:
return "timeout", f"Agent Generator request timed out: {e}"
elif "connect" in error_str:
return "connection_error", f"Could not connect to Agent Generator: {e}"
else:
return "request_error", f"Request error calling Agent Generator: {e}"
_client: httpx.AsyncClient | None = None
_settings: Settings | None = None
def _get_settings() -> Settings:
"""Get or create settings singleton."""
global _settings
if _settings is None:
_settings = Settings()
return _settings
def _is_dummy_mode() -> bool:
"""Check if dummy mode is enabled for testing."""
global _dummy_mode_warned
settings = _get_settings()
is_dummy = bool(settings.config.agentgenerator_use_dummy)
if is_dummy and not _dummy_mode_warned:
logger.warning(
"Agent Generator running in DUMMY MODE - returning mock responses. "
"Do not use in production!"
)
_dummy_mode_warned = True
return is_dummy
def is_external_service_configured() -> bool:
"""Check if external Agent Generator service is configured (or dummy mode)."""
settings = _get_settings()
return bool(settings.config.agentgenerator_host) or bool(
settings.config.agentgenerator_use_dummy
)
def _get_base_url() -> str:
"""Get the base URL for the external service."""
settings = _get_settings()
host = settings.config.agentgenerator_host
port = settings.config.agentgenerator_port
return f"http://{host}:{port}"
def _get_client() -> httpx.AsyncClient:
"""Get or create the HTTP client for the external service."""
global _client
if _client is None:
settings = _get_settings()
_client = httpx.AsyncClient(
base_url=_get_base_url(),
timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
)
return _client
async def decompose_goal_external(
description: str,
context: str = "",
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Call the external service to decompose a goal.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
library_agents: User's library agents available for sub-agent composition
Returns:
Dict with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
- {"type": "unachievable_goal", ...}
- {"type": "vague_goal", ...}
- {"type": "error", "error": "...", "error_type": "..."} on error
Or None on unexpected error
"""
if _is_dummy_mode():
return await decompose_goal_dummy(description, context, library_agents)
client = _get_client()
if context:
description = f"{description}\n\nAdditional context from user:\n{context}"
payload: dict[str, Any] = {"description": description}
if library_agents:
payload["library_agents"] = library_agents
try:
response = await client.post("/api/decompose-description", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator decomposition failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Map the response to the expected format
response_type = data.get("type")
if response_type == "instructions":
return {"type": "instructions", "steps": data.get("steps", [])}
elif response_type == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
elif response_type == "unachievable_goal":
return {
"type": "unachievable_goal",
"reason": data.get("reason"),
"suggested_goal": data.get("suggested_goal"),
}
elif response_type == "vague_goal":
return {
"type": "vague_goal",
"suggested_goal": data.get("suggested_goal"),
}
elif response_type == "error":
# Pass through error from the service
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
else:
logger.error(
f"Unknown response type from external service: {response_type}"
)
return _create_error_response(
f"Unknown response type from Agent Generator: {response_type}",
"invalid_response",
)
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def generate_agent_external(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
Returns:
Agent JSON dict or error dict {"type": "error", ...} on error
"""
if _is_dummy_mode():
return await generate_agent_dummy(instructions, library_agents)
client = _get_client()
# Build request payload
payload: dict[str, Any] = {"instructions": instructions}
if library_agents:
payload["library_agents"] = library_agents
try:
response = await client.post("/api/generate-agent", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator generation failed: {error_msg} (type: {error_type})"
)
return _create_error_response(error_msg, error_type)
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def generate_agent_patch_external(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
session_id: Session ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
"""
if _is_dummy_mode():
return await generate_agent_patch_dummy(
update_request, current_agent, library_agents
)
client = _get_client()
# Build request payload
payload: dict[str, Any] = {
"update_request": update_request,
"current_agent_json": current_agent,
}
if library_agents:
payload["library_agents"] = library_agents
try:
response = await client.post("/api/update-agent", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator patch generation failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
# Check if it's an error passed through
if data.get("type") == "error":
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
# Otherwise return the updated agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def customize_template_external(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any] | None:
"""Call the external service to customize a template/marketplace agent.
Args:
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
operation_id: Operation ID for async processing (enables Redis Streams callback)
session_id: Session ID for async processing (enables Redis Streams callback)
Returns:
Customized agent JSON, clarifying questions dict, or error dict on error
"""
if _is_dummy_mode():
return await customize_template_dummy(
template_agent, modification_request, context
)
client = _get_client()
request = modification_request
if context:
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
payload: dict[str, Any] = {
"template_agent_json": template_agent,
"modification_request": request,
}
try:
response = await client.post("/api/template-modification", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator template customization failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
# Check if it's an error passed through
if data.get("type") == "error":
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
# Otherwise return the customized agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def get_blocks_external() -> list[dict[str, Any]] | None:
"""Get available blocks from the external service.
Returns:
List of block info dicts or None on error
"""
if _is_dummy_mode():
return await get_blocks_dummy()
client = _get_client()
try:
response = await client.get("/api/blocks")
response.raise_for_status()
data = response.json()
if not data.get("success"):
logger.error("External service returned error getting blocks")
return None
return data.get("blocks", [])
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error getting blocks from external service: {e}")
return None
except httpx.RequestError as e:
logger.error(f"Request error getting blocks from external service: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error getting blocks from external service: {e}")
return None
async def health_check() -> bool:
"""Check if the external service is healthy.
Returns:
True if healthy, False otherwise
"""
if not is_external_service_configured():
return False
if _is_dummy_mode():
return await health_check_dummy()
client = _get_client()
try:
response = await client.get("/health")
response.raise_for_status()
data = response.json()
return data.get("status") == "healthy" and data.get("blocks_loaded", False)
except Exception as e:
logger.warning(f"External agent generator health check failed: {e}")
return False
async def close_client() -> None:
"""Close the HTTP client."""
global _client
if _client is not None:
await _client.aclose()
_client = None

View File

@@ -0,0 +1,17 @@
"""Agent generation validation — re-exports from split modules.
This module was split into:
- helpers.py: get_blocks_as_dicts, block cache
- fixer.py: AgentFixer class
- validator.py: AgentValidator class
"""
from .fixer import AgentFixer
from .helpers import get_blocks_as_dicts
from .validator import AgentValidator
__all__ = [
"AgentFixer",
"AgentValidator",
"get_blocks_as_dicts",
]

View File

@@ -0,0 +1,939 @@
"""AgentValidator — validates agent JSON graphs for correctness."""
import json
import logging
import re
from typing import Any
from .helpers import (
AGENT_EXECUTOR_BLOCK_ID,
AGENT_INPUT_BLOCK_ID,
AGENT_OUTPUT_BLOCK_ID,
MCP_TOOL_BLOCK_ID,
AgentDict,
are_types_compatible,
get_defined_property_type,
)
logger = logging.getLogger(__name__)
class AgentValidator:
"""
A comprehensive validator for AutoGPT agents that provides detailed error
reporting for LLM-based fixes.
"""
def __init__(self):
self.errors: list[str] = []
def add_error(self, error_message: str) -> None:
"""Add an error message to the validation errors list."""
self.errors.append(error_message)
def _values_equal(self, val1: Any, val2: Any) -> bool:
"""Compare two values, handling complex types like dicts and lists."""
if type(val1) is not type(val2):
return False
if isinstance(val1, dict):
return json.dumps(val1, sort_keys=True) == json.dumps(val2, sort_keys=True)
if isinstance(val1, list):
return json.dumps(val1, sort_keys=True) == json.dumps(val2, sort_keys=True)
return val1 == val2
def validate_block_existence(
self, agent: AgentDict, blocks: list[dict[str, Any]]
) -> bool:
"""
Validate that all block IDs used in the agent actually exist in the
blocks list. Returns True if all block IDs exist, False otherwise.
"""
valid = True
# Create a set of all valid block IDs for fast lookup
valid_block_ids = {block.get("id") for block in blocks if block.get("id")}
# Check each node's block_id
for node in agent.get("nodes", []):
block_id = node.get("block_id")
node_id = node.get("id")
if not block_id:
self.add_error(
f"Node '{node_id}' is missing a 'block_id' field. "
f"Every node must reference a valid block."
)
valid = False
continue
if block_id not in valid_block_ids:
self.add_error(
f"Node '{node_id}' references block_id '{block_id}' "
f"which does not exist in the available blocks. "
f"This block may have been deprecated, removed, or "
f"the ID is incorrect. Please use a valid block from "
f"the blocks library."
)
valid = False
return valid
def validate_link_node_references(self, agent: AgentDict) -> bool:
"""
Validate that all node IDs referenced in links actually exist in the
agent's nodes. Returns True if all link references are valid, False
otherwise.
"""
valid = True
# Create a set of all valid node IDs for fast lookup
valid_node_ids = {
node.get("id") for node in agent.get("nodes", []) if node.get("id")
}
# Check each link's source_id and sink_id
for link in agent.get("links", []):
link_id = link.get("id", "Unknown")
source_id = link.get("source_id")
sink_id = link.get("sink_id")
source_name = link.get("source_name", "")
sink_name = link.get("sink_name", "")
# Check source_id
if not source_id:
self.add_error(
f"Link '{link_id}' is missing a 'source_id' field. "
f"Every link must reference a valid source node."
)
valid = False
elif source_id not in valid_node_ids:
self.add_error(
f"Link '{link_id}' references source_id '{source_id}' "
f"which does not exist in the agent's nodes. The link "
f"from '{source_name}' cannot be established because "
f"the source node is missing."
)
valid = False
# Check sink_id
if not sink_id:
self.add_error(
f"Link '{link_id}' is missing a 'sink_id' field. "
f"Every link must reference a valid sink (destination) "
f"node."
)
valid = False
elif sink_id not in valid_node_ids:
self.add_error(
f"Link '{link_id}' references sink_id '{sink_id}' "
f"which does not exist in the agent's nodes. The link "
f"to '{sink_name}' cannot be established because the "
f"destination node is missing."
)
valid = False
return valid
def validate_required_inputs(
self, agent: AgentDict, blocks: list[dict[str, Any]]
) -> bool:
"""
Validate that all required inputs are provided for each node.
Returns True if all required inputs are satisfied, False otherwise.
"""
valid = True
block_lookup = {b.get("id", ""): b for b in blocks}
for node in agent.get("nodes", []):
block_id = node.get("block_id")
block = block_lookup.get(block_id)
if not block:
continue
required_inputs = block.get("inputSchema", {}).get("required", [])
input_defaults = node.get("input_default", {})
node_id = node.get("id")
linked_inputs = set(
link.get("sink_name")
for link in agent.get("links", [])
if link.get("sink_id") == node_id and link.get("sink_name")
)
for req_input in required_inputs:
if (
req_input not in input_defaults
and req_input not in linked_inputs
and req_input != "credentials"
):
block_name = block.get("name", "Unknown Block")
self.add_error(
f"Node '{node_id}' (block '{block_name}' - "
f"{block_id}) is missing required input "
f"'{req_input}'. This input must be either "
f"provided as a default value in the node's "
f"'input_default' field or connected via a link "
f"from another node's output."
)
valid = False
return valid
def validate_data_type_compatibility(
self, agent: AgentDict, blocks: list[dict[str, Any]]
) -> bool:
"""
Validate that linked data types are compatible between source and sink.
Returns True if all data types are compatible, False otherwise.
"""
valid = True
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
block_lookup = {block.get("id", ""): block for block in blocks}
for link in agent.get("links", []):
source_id = link.get("source_id")
sink_id = link.get("sink_id")
source_name = link.get("source_name")
sink_name = link.get("sink_name")
if not all(
isinstance(v, str) and v
for v in (source_id, sink_id, source_name, sink_name)
):
self.add_error(
f"Link '{link.get('id', 'Unknown')}' is missing required "
f"fields (source_id/sink_id/source_name/sink_name)."
)
valid = False
continue
source_node = node_lookup.get(source_id, "")
sink_node = node_lookup.get(sink_id, "")
if not source_node or not sink_node:
continue
source_block = block_lookup.get(source_node.get("block_id", ""))
sink_block = block_lookup.get(sink_node.get("block_id", ""))
if not source_block or not sink_block:
continue
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
source_type = get_defined_property_type(source_outputs, source_name)
sink_type = get_defined_property_type(sink_inputs, sink_name)
if (
source_type
and sink_type
and not are_types_compatible(source_type, sink_type)
):
source_block_name = source_block.get("name", "Unknown Block")
sink_block_name = sink_block.get("name", "Unknown Block")
self.add_error(
f"Data type mismatch in link '{link.get('id')}': "
f"Source '{source_block_name}' output "
f"'{link.get('source_name', '')}' outputs '{source_type}' "
f"type, but sink '{sink_block_name}' input "
f"'{link.get('sink_name', '')}' expects '{sink_type}' type. "
f"These types must match for the connection to work "
f"properly."
)
valid = False
return valid
def validate_nested_sink_links(
self, agent: AgentDict, blocks: list[dict[str, Any]]
) -> bool:
"""
Validate nested sink links (links with _#_ notation).
Returns True if all nested links are valid, False otherwise.
"""
valid = True
block_input_schemas = {
block.get("id", ""): block.get("inputSchema", {}).get("properties", {})
for block in blocks
}
block_names = {
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
}
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
for link in agent.get("links", []):
sink_name = link.get("sink_name", "")
sink_id = link.get("sink_id")
if not sink_name or not sink_id:
continue
if "_#_" in sink_name:
parent, child = sink_name.split("_#_", 1)
sink_node = node_lookup.get(sink_id)
if not sink_node:
continue
block_id = sink_node.get("block_id")
input_props = block_input_schemas.get(block_id, {})
parent_schema = input_props.get(parent)
if not parent_schema:
block_name = block_names.get(block_id, "Unknown Block")
self.add_error(
f"Invalid nested sink link '{sink_name}' for "
f"node '{sink_id}' (block "
f"'{block_name}' - {block_id}): Parent property "
f"'{parent}' does not exist in the block's "
f"input schema."
)
valid = False
continue
# Check if additionalProperties is allowed either directly
# or via anyOf
allows_additional_properties = parent_schema.get(
"additionalProperties", False
)
# Check anyOf for additionalProperties
if not allows_additional_properties and "anyOf" in parent_schema:
any_of_schemas = parent_schema.get("anyOf", [])
if isinstance(any_of_schemas, list):
for schema_option in any_of_schemas:
if isinstance(schema_option, dict) and schema_option.get(
"additionalProperties"
):
allows_additional_properties = True
break
if not allows_additional_properties:
if not (
isinstance(parent_schema, dict)
and "properties" in parent_schema
and isinstance(parent_schema["properties"], dict)
and child in parent_schema["properties"]
):
block_name = block_names.get(block_id, "Unknown Block")
self.add_error(
f"Invalid nested sink link '{sink_name}' "
f"for node '{link.get('sink_id', '')}' (block "
f"'{block_name}' - {block_id}): Child "
f"property '{child}' does not exist in "
f"parent '{parent}' schema. Available "
f"properties: "
f"{list(parent_schema.get('properties', {}).keys())}"
)
valid = False
return valid
def validate_prompt_double_curly_braces_spaces(self, agent: AgentDict) -> bool:
"""
Validate that prompt parameters do not contain spaces in double curly
braces.
Checks the 'prompt' parameter in input_default of each node and reports
errors if values within double curly braces ({{...}}) contain spaces.
For example, {{user name}} should be {{user_name}}.
Args:
agent: The agent dictionary to validate
Returns:
True if all prompts are valid (no spaces in double curly braces),
False otherwise
"""
valid = True
nodes = agent.get("nodes", [])
for node in nodes:
node_id = node.get("id")
input_default = node.get("input_default", {})
# Check if 'prompt' parameter exists
if "prompt" not in input_default:
continue
prompt_text = input_default["prompt"]
# Only process if it's a string
if not isinstance(prompt_text, str):
continue
# Find all double curly brace patterns with spaces
matches = re.finditer(r"\{\{([^}]+)\}\}", prompt_text)
for match in matches:
content = match.group(1)
if " " in content:
start_pos = match.start()
snippet_start = max(0, start_pos - 30)
snippet_end = min(len(prompt_text), match.end() + 30)
snippet = prompt_text[snippet_start:snippet_end]
self.add_error(
f"Node '{node_id}' has spaces in double curly "
f"braces in prompt parameter: "
f"'{{{{{content}}}}}' should be "
f"'{{{{{content.replace(' ', '_')}}}}}'. "
f"Context: ...{snippet}..."
)
valid = False
return valid
def validate_source_output_existence(
self, agent: AgentDict, blocks: list[dict[str, Any]]
) -> bool:
"""
Validate that all source_names in links exist in the corresponding
block's output schema.
Checks that for each link, the source_name field references a valid
output property in the source block's outputSchema. Also handles nested
outputs with _#_ notation.
Args:
agent: The agent dictionary to validate
blocks: List of available blocks with their schemas
Returns:
True if all source output fields exist, False otherwise
"""
valid = True
# Create lookup dictionaries for efficiency
block_output_schemas = {
block.get("id", ""): block.get("outputSchema", {}).get("properties", {})
for block in blocks
}
block_names = {
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
}
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
for link in agent.get("links", []):
source_id = link.get("source_id")
source_name = link.get("source_name", "")
link_id = link.get("id", "Unknown")
if not source_name:
self.add_error(
f"Link '{link_id}' is missing 'source_name'. "
f"Every link must specify which output field to read from."
)
valid = False
continue
source_node = node_lookup.get(source_id)
if not source_node:
# This error is already caught by
# validate_link_node_references
continue
block_id = source_node.get("block_id")
block_name = block_names.get(block_id, "Unknown Block")
# Special handling for AgentExecutorBlock - use dynamic
# output_schema from input_default
if block_id == AGENT_EXECUTOR_BLOCK_ID:
input_default = source_node.get("input_default", {})
dynamic_output_schema = input_default.get("output_schema", {})
if not isinstance(dynamic_output_schema, dict):
dynamic_output_schema = {}
output_props = dynamic_output_schema.get("properties", {})
if not isinstance(output_props, dict):
output_props = {}
else:
output_props = block_output_schemas.get(block_id, {})
# Handle nested source names (with _#_ notation)
if "_#_" in source_name:
parent, child = source_name.split("_#_", 1)
parent_schema = output_props.get(parent)
if not parent_schema:
self.add_error(
f"Invalid source output field '{source_name}' "
f"in link '{link_id}' from node '{source_id}' "
f"(block '{block_name}' - {block_id}): Parent "
f"property '{parent}' does not exist in the "
f"block's output schema."
)
valid = False
continue
# Check if additionalProperties is allowed either directly
# or via anyOf
allows_additional_properties = parent_schema.get(
"additionalProperties", False
)
if not allows_additional_properties and "anyOf" in parent_schema:
any_of_schemas = parent_schema.get("anyOf", [])
if isinstance(any_of_schemas, list):
for schema_option in any_of_schemas:
if isinstance(schema_option, dict) and schema_option.get(
"additionalProperties"
):
allows_additional_properties = True
break
# Also allow when items have
# additionalProperties (array of objects)
if (
isinstance(schema_option, dict)
and "items" in schema_option
):
items_schema = schema_option.get("items")
if isinstance(items_schema, dict) and items_schema.get(
"additionalProperties"
):
allows_additional_properties = True
break
# Only require child in properties when
# additionalProperties is not allowed
if not allows_additional_properties:
if not (
isinstance(parent_schema, dict)
and "properties" in parent_schema
and isinstance(parent_schema["properties"], dict)
and child in parent_schema["properties"]
):
available_props = (
list(parent_schema.get("properties", {}).keys())
if isinstance(parent_schema, dict)
else []
)
self.add_error(
f"Invalid nested source output field "
f"'{source_name}' in link '{link_id}' from "
f"node '{source_id}' (block "
f"'{block_name}' - {block_id}): Child "
f"property '{child}' does not exist in "
f"parent '{parent}' output schema. "
f"Available properties: {available_props}"
)
valid = False
else:
# Check simple (non-nested) source name
if source_name not in output_props:
available_outputs = list(output_props.keys())
self.add_error(
f"Invalid source output field '{source_name}' "
f"in link '{link_id}' from node '{source_id}' "
f"(block '{block_name}' - {block_id}): Output "
f"property '{source_name}' does not exist in "
f"the block's output schema. Available outputs: "
f"{available_outputs}"
)
valid = False
return valid
def validate_io_blocks(self, agent: AgentDict) -> bool:
"""
Validate that the agent has at least one AgentInputBlock and one
AgentOutputBlock. These blocks define the agent's interface.
Returns True if both are present, False otherwise.
"""
valid = True
block_ids = {node.get("block_id") for node in agent.get("nodes", [])}
if AGENT_INPUT_BLOCK_ID not in block_ids:
self.add_error(
f"Agent is missing an AgentInputBlock (block_id: "
f"'{AGENT_INPUT_BLOCK_ID}'). Every agent must have at "
f"least one AgentInputBlock to define user-facing inputs. "
f"Add a node with block_id '{AGENT_INPUT_BLOCK_ID}' and "
f"set input_default with 'name' and optionally 'title'."
)
valid = False
if AGENT_OUTPUT_BLOCK_ID not in block_ids:
self.add_error(
f"Agent is missing an AgentOutputBlock (block_id: "
f"'{AGENT_OUTPUT_BLOCK_ID}'). Every agent must have at "
f"least one AgentOutputBlock to define user-facing outputs. "
f"Add a node with block_id '{AGENT_OUTPUT_BLOCK_ID}' and "
f"set input_default with 'name', then link 'value' from "
f"another block's output."
)
valid = False
return valid
def validate_agent_executor_blocks(
self,
agent: AgentDict,
library_agents: list[dict[str, Any]] | None = None,
) -> bool:
"""
Validate AgentExecutorBlock nodes have required fields and valid
references.
Checks that AgentExecutorBlock nodes:
1. Have a valid graph_id in input_default (required)
2. If graph_id matches a known library agent, validates version
consistency
3. Sub-agent required inputs are connected via links (not hardcoded)
Note: Unknown graph_ids are not treated as errors - they could be valid
direct references to agents by their actual ID (not via library_agents).
This is consistent with fix_agent_executor_blocks() behavior.
Args:
agent: The agent dictionary to validate
library_agents: List of available library agents (for version
validation)
Returns:
True if all AgentExecutorBlock nodes are valid, False otherwise
"""
valid = True
nodes = agent.get("nodes", [])
links = agent.get("links", [])
# Create lookup for library agents
library_agent_lookup: dict[str, dict[str, Any]] = {}
if library_agents:
library_agent_lookup = {la.get("graph_id", ""): la for la in library_agents}
for node in nodes:
if node.get("block_id") != AGENT_EXECUTOR_BLOCK_ID:
continue
node_id = node.get("id")
input_default = node.get("input_default", {})
# Check for required graph_id
graph_id = input_default.get("graph_id")
if not graph_id:
self.add_error(
f"AgentExecutorBlock node '{node_id}' is missing "
f"required 'graph_id' in input_default. This field "
f"must reference the ID of the sub-agent to execute."
)
valid = False
continue
# If graph_id is not in library_agent_lookup, skip validation
if graph_id not in library_agent_lookup:
continue
# Validate version consistency for known library agents
library_agent = library_agent_lookup[graph_id]
expected_version = library_agent.get("graph_version")
current_version = input_default.get("graph_version")
if (
current_version
and expected_version
and current_version != expected_version
):
self.add_error(
f"AgentExecutorBlock node '{node_id}' has mismatched "
f"graph_version: got {current_version}, expected "
f"{expected_version} for library agent "
f"'{library_agent.get('name')}'"
)
valid = False
# Validate sub-agent inputs are properly linked (not hardcoded)
sub_agent_input_schema = library_agent.get("input_schema", {})
if not isinstance(sub_agent_input_schema, dict):
sub_agent_input_schema = {}
sub_agent_required_inputs = sub_agent_input_schema.get("required", [])
sub_agent_properties = sub_agent_input_schema.get("properties", {})
# Get all linked inputs to this node
linked_sub_agent_inputs: set[str] = set()
for link in links:
if link.get("sink_id") == node_id:
sink_name = link.get("sink_name", "")
if sink_name in sub_agent_properties:
linked_sub_agent_inputs.add(sink_name)
# Check for hardcoded inputs that should be linked
hardcoded_inputs = input_default.get("inputs", {})
input_schema = input_default.get("input_schema", {})
schema_properties = (
input_schema.get("properties", {})
if isinstance(input_schema, dict)
else {}
)
if isinstance(hardcoded_inputs, dict) and hardcoded_inputs:
for input_name, value in hardcoded_inputs.items():
if input_name not in sub_agent_properties:
continue
if value is None:
continue
# Skip if this input is already linked
if input_name in linked_sub_agent_inputs:
continue
prop_schema = schema_properties.get(input_name, {})
schema_default = (
prop_schema.get("default")
if isinstance(prop_schema, dict)
else None
)
if schema_default is not None and self._values_equal(
value, schema_default
):
continue
# This is a non-default hardcoded value without a link
self.add_error(
f"AgentExecutorBlock node '{node_id}' has "
f"hardcoded input '{input_name}'. Sub-agent "
f"inputs should be connected via links using "
f"'{input_name}' as sink_name, not hardcoded "
f"in input_default.inputs. Create a link from "
f"the appropriate source node."
)
valid = False
# Check for missing required sub-agent inputs.
# An input is satisfied if it is linked OR has an allowed
# hardcoded value (i.e. equals the schema default — the
# previous check already flags non-default hardcoded values).
hardcoded_inputs_dict = (
hardcoded_inputs if isinstance(hardcoded_inputs, dict) else {}
)
for req_input in sub_agent_required_inputs:
if req_input in linked_sub_agent_inputs:
continue
# Check if fixer populated it with a schema default value
if req_input in hardcoded_inputs_dict:
prop_schema = schema_properties.get(req_input, {})
schema_default = (
prop_schema.get("default")
if isinstance(prop_schema, dict)
else None
)
if schema_default is not None and self._values_equal(
hardcoded_inputs_dict[req_input], schema_default
):
continue
self.add_error(
f"AgentExecutorBlock node '{node_id}' is "
f"missing required sub-agent input "
f"'{req_input}'. Create a link to this node "
f"using sink_name '{req_input}' to connect "
f"the input."
)
valid = False
return valid
def validate_agent_executor_block_schemas(
self,
agent: AgentDict,
) -> bool:
"""
Validate that AgentExecutorBlock nodes have valid input_schema and
output_schema.
This validation runs regardless of library_agents availability and
ensures that the schemas are properly populated to prevent frontend
crashes.
Args:
agent: The agent dictionary to validate
Returns:
True if all AgentExecutorBlock nodes have valid schemas, False
otherwise
"""
valid = True
nodes = agent.get("nodes", [])
for node in nodes:
if node.get("block_id") != AGENT_EXECUTOR_BLOCK_ID:
continue
node_id = node.get("id")
input_default = node.get("input_default", {})
customized_name = (node.get("metadata") or {}).get(
"customized_name", "Unknown"
)
# Check input_schema
input_schema = input_default.get("input_schema")
if input_schema is None or not isinstance(input_schema, dict):
self.add_error(
f"AgentExecutorBlock node '{node_id}' "
f"({customized_name}) has missing or invalid "
f"input_schema. The input_schema must be a valid "
f"JSON Schema object with 'properties' and "
f"'required' fields."
)
valid = False
elif not input_schema.get("properties") and not input_schema.get("type"):
# Empty schema like {} is invalid
self.add_error(
f"AgentExecutorBlock node '{node_id}' "
f"({customized_name}) has empty input_schema. The "
f"input_schema must define the sub-agent's expected "
f"inputs. This usually indicates the sub-agent "
f"reference is incomplete or the library agent was "
f"not properly passed."
)
valid = False
# Check output_schema
output_schema = input_default.get("output_schema")
if output_schema is None or not isinstance(output_schema, dict):
self.add_error(
f"AgentExecutorBlock node '{node_id}' "
f"({customized_name}) has missing or invalid "
f"output_schema. The output_schema must be a valid "
f"JSON Schema object defining the sub-agent's "
f"outputs."
)
valid = False
elif not output_schema.get("properties") and not output_schema.get("type"):
# Empty schema like {} is invalid
self.add_error(
f"AgentExecutorBlock node '{node_id}' "
f"({customized_name}) has empty output_schema. "
f"The output_schema must define the sub-agent's "
f"expected outputs. This usually indicates the "
f"sub-agent reference is incomplete or the library "
f"agent was not properly passed."
)
valid = False
return valid
def validate_mcp_tool_blocks(self, agent: AgentDict) -> bool:
"""Validate that MCPToolBlock nodes have required fields.
Checks that each MCPToolBlock node has:
1. A non-empty `server_url` in input_default
2. A non-empty `selected_tool` in input_default
Returns True if all MCPToolBlock nodes are valid, False otherwise.
"""
valid = True
nodes = agent.get("nodes", [])
for node in nodes:
if node.get("block_id") != MCP_TOOL_BLOCK_ID:
continue
node_id = node.get("id", "unknown")
input_default = node.get("input_default", {})
customized_name = (node.get("metadata") or {}).get(
"customized_name", node_id
)
server_url = input_default.get("server_url")
if not server_url:
self.add_error(
f"MCPToolBlock node '{customized_name}' ({node_id}) is "
f"missing required 'server_url' in input_default. "
f"Set this to the MCP server URL "
f"(e.g. 'https://mcp.example.com/sse')."
)
valid = False
selected_tool = input_default.get("selected_tool")
if not selected_tool:
self.add_error(
f"MCPToolBlock node '{customized_name}' ({node_id}) is "
f"missing required 'selected_tool' in input_default. "
f"Set this to the name of the MCP tool to execute."
)
valid = False
return valid
def validate(
self,
agent: AgentDict,
blocks: list[dict[str, Any]],
library_agents: list[dict[str, Any]] | None = None,
) -> tuple[bool, str | None]:
"""
Comprehensive validation of an agent against available blocks.
Returns:
Tuple[bool, Optional[str]]: (is_valid, error_message)
- is_valid: True if agent passes all validations, False otherwise
- error_message: Detailed error message if validation fails, None
if successful
"""
logger.info("Validating agent...")
self.errors = []
checks = [
(
"Block existence",
self.validate_block_existence(agent, blocks),
),
(
"Link node references",
self.validate_link_node_references(agent),
),
(
"Required inputs",
self.validate_required_inputs(agent, blocks),
),
(
"Data type compatibility",
self.validate_data_type_compatibility(agent, blocks),
),
(
"Nested sink links",
self.validate_nested_sink_links(agent, blocks),
),
(
"Source output existence",
self.validate_source_output_existence(agent, blocks),
),
(
"Prompt double curly braces spaces",
self.validate_prompt_double_curly_braces_spaces(agent),
),
(
"IO blocks",
self.validate_io_blocks(agent),
),
# Always validate AgentExecutorBlock schemas to prevent
# frontend crashes
(
"AgentExecutorBlock schemas",
self.validate_agent_executor_block_schemas(agent),
),
(
"MCP tool blocks",
self.validate_mcp_tool_blocks(agent),
),
]
# Add AgentExecutorBlock detailed validation if library_agents
# provided
if library_agents:
checks.append(
(
"AgentExecutorBlock references",
self.validate_agent_executor_blocks(agent, library_agents),
)
)
all_passed = all(check[1] for check in checks)
if all_passed:
logger.info("Agent validation successful.")
return True, None
else:
error_message = "Agent validation failed with the following errors:\n\n"
for i, error in enumerate(self.errors, 1):
error_message += f"{i}. {error}\n"
logger.error(f"Agent validation failed: {error_message}")
return False, error_message

View File

@@ -0,0 +1,710 @@
"""Unit tests for AgentValidator."""
from .helpers import (
AGENT_EXECUTOR_BLOCK_ID,
AGENT_INPUT_BLOCK_ID,
AGENT_OUTPUT_BLOCK_ID,
MCP_TOOL_BLOCK_ID,
generate_uuid,
)
from .validator import AgentValidator
def _make_agent(
nodes: list | None = None,
links: list | None = None,
agent_id: str | None = None,
) -> dict:
"""Build a minimal agent dict for testing."""
return {
"id": agent_id or generate_uuid(),
"name": "Test Agent",
"nodes": nodes or [],
"links": links or [],
}
def _make_node(
node_id: str | None = None,
block_id: str = "block-1",
input_default: dict | None = None,
position: tuple[int, int] = (0, 0),
) -> dict:
return {
"id": node_id or generate_uuid(),
"block_id": block_id,
"input_default": input_default or {},
"metadata": {"position": {"x": position[0], "y": position[1]}},
}
def _make_link(
link_id: str | None = None,
source_id: str = "",
source_name: str = "output",
sink_id: str = "",
sink_name: str = "input",
) -> dict:
return {
"id": link_id or generate_uuid(),
"source_id": source_id,
"source_name": source_name,
"sink_id": sink_id,
"sink_name": sink_name,
}
def _make_block(
block_id: str = "block-1",
name: str = "TestBlock",
input_schema: dict | None = None,
output_schema: dict | None = None,
categories: list | None = None,
static_output: bool = False,
) -> dict:
return {
"id": block_id,
"name": name,
"inputSchema": input_schema or {"properties": {}, "required": []},
"outputSchema": output_schema or {"properties": {}},
"categories": categories or [],
"staticOutput": static_output,
}
# ============================================================================
# validate_block_existence
# ============================================================================
class TestValidateBlockExistence:
def test_valid_blocks_pass(self):
v = AgentValidator()
node = _make_node(block_id="b1")
block = _make_block(block_id="b1")
agent = _make_agent(nodes=[node])
assert v.validate_block_existence(agent, [block]) is True
assert v.errors == []
def test_missing_block_fails(self):
v = AgentValidator()
node = _make_node(block_id="nonexistent")
agent = _make_agent(nodes=[node])
assert v.validate_block_existence(agent, []) is False
assert len(v.errors) == 1
assert "does not exist" in v.errors[0]
def test_missing_block_id_field(self):
v = AgentValidator()
node = {"id": "n1", "input_default": {}, "metadata": {}}
agent = _make_agent(nodes=[node])
assert v.validate_block_existence(agent, []) is False
assert "missing a 'block_id'" in v.errors[0]
# ============================================================================
# validate_link_node_references
# ============================================================================
class TestValidateLinkNodeReferences:
def test_valid_references_pass(self):
v = AgentValidator()
n1 = _make_node(node_id="n1")
n2 = _make_node(node_id="n2")
link = _make_link(source_id="n1", sink_id="n2")
agent = _make_agent(nodes=[n1, n2], links=[link])
assert v.validate_link_node_references(agent) is True
assert v.errors == []
def test_invalid_source_fails(self):
v = AgentValidator()
n1 = _make_node(node_id="n1")
link = _make_link(source_id="missing", sink_id="n1")
agent = _make_agent(nodes=[n1], links=[link])
assert v.validate_link_node_references(agent) is False
assert any("source_id" in e for e in v.errors)
def test_invalid_sink_fails(self):
v = AgentValidator()
n1 = _make_node(node_id="n1")
link = _make_link(source_id="n1", sink_id="missing")
agent = _make_agent(nodes=[n1], links=[link])
assert v.validate_link_node_references(agent) is False
assert any("sink_id" in e for e in v.errors)
# ============================================================================
# validate_required_inputs
# ============================================================================
class TestValidateRequiredInputs:
def test_satisfied_by_default_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {"url": {"type": "string"}},
"required": ["url"],
},
)
node = _make_node(block_id="b1", input_default={"url": "http://example.com"})
agent = _make_agent(nodes=[node])
assert v.validate_required_inputs(agent, [block]) is True
assert v.errors == []
def test_satisfied_by_link_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {"url": {"type": "string"}},
"required": ["url"],
},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(source_id="n2", sink_id="n1", sink_name="url")
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_required_inputs(agent, [block]) is True
def test_missing_required_input_fails(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {"url": {"type": "string"}},
"required": ["url"],
},
)
node = _make_node(block_id="b1", input_default={})
agent = _make_agent(nodes=[node])
assert v.validate_required_inputs(agent, [block]) is False
assert any("missing required input" in e for e in v.errors)
def test_credentials_always_allowed_missing(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {"credentials": {"type": "object"}},
"required": ["credentials"],
},
)
node = _make_node(block_id="b1", input_default={})
agent = _make_agent(nodes=[node])
assert v.validate_required_inputs(agent, [block]) is True
# ============================================================================
# validate_data_type_compatibility
# ============================================================================
class TestValidateDataTypeCompatibility:
def test_matching_types_pass(self):
v = AgentValidator()
src_block = _make_block(
block_id="src-b",
output_schema={"properties": {"out": {"type": "string"}}},
)
sink_block = _make_block(
block_id="sink-b",
input_schema={"properties": {"inp": {"type": "string"}}, "required": []},
)
src_node = _make_node(node_id="n1", block_id="src-b")
sink_node = _make_node(node_id="n2", block_id="sink-b")
link = _make_link(
source_id="n1", source_name="out", sink_id="n2", sink_name="inp"
)
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
assert (
v.validate_data_type_compatibility(agent, [src_block, sink_block]) is True
)
def test_int_number_compatible(self):
v = AgentValidator()
src_block = _make_block(
block_id="src-b",
output_schema={"properties": {"out": {"type": "integer"}}},
)
sink_block = _make_block(
block_id="sink-b",
input_schema={"properties": {"inp": {"type": "number"}}, "required": []},
)
src_node = _make_node(node_id="n1", block_id="src-b")
sink_node = _make_node(node_id="n2", block_id="sink-b")
link = _make_link(
source_id="n1", source_name="out", sink_id="n2", sink_name="inp"
)
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
assert (
v.validate_data_type_compatibility(agent, [src_block, sink_block]) is True
)
def test_mismatched_types_fail(self):
v = AgentValidator()
src_block = _make_block(
block_id="src-b",
output_schema={"properties": {"out": {"type": "string"}}},
)
sink_block = _make_block(
block_id="sink-b",
input_schema={"properties": {"inp": {"type": "integer"}}, "required": []},
)
src_node = _make_node(node_id="n1", block_id="src-b")
sink_node = _make_node(node_id="n2", block_id="sink-b")
link = _make_link(
source_id="n1", source_name="out", sink_id="n2", sink_name="inp"
)
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
assert (
v.validate_data_type_compatibility(agent, [src_block, sink_block]) is False
)
assert any("mismatch" in e.lower() for e in v.errors)
# ============================================================================
# validate_source_output_existence
# ============================================================================
class TestValidateSourceOutputExistence:
def test_valid_source_output_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
output_schema={"properties": {"result": {"type": "string"}}},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(source_id="n1", source_name="result", sink_id="n2")
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_source_output_existence(agent, [block]) is True
def test_invalid_source_output_fails(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
output_schema={"properties": {"result": {"type": "string"}}},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(source_id="n1", source_name="nonexistent", sink_id="n2")
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_source_output_existence(agent, [block]) is False
assert any("does not exist" in e for e in v.errors)
# ============================================================================
# validate_prompt_double_curly_braces_spaces
# ============================================================================
class TestValidatePromptDoubleCurlyBracesSpaces:
def test_no_spaces_passes(self):
v = AgentValidator()
node = _make_node(input_default={"prompt": "Hello {{name}}!"})
agent = _make_agent(nodes=[node])
assert v.validate_prompt_double_curly_braces_spaces(agent) is True
def test_spaces_in_braces_fails(self):
v = AgentValidator()
node = _make_node(input_default={"prompt": "Hello {{user name}}!"})
agent = _make_agent(nodes=[node])
assert v.validate_prompt_double_curly_braces_spaces(agent) is False
assert any("spaces" in e for e in v.errors)
# ============================================================================
# validate_nested_sink_links
# ============================================================================
class TestValidateNestedSinkLinks:
def test_valid_nested_link_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {
"config": {
"type": "object",
"properties": {"key": {"type": "string"}},
}
},
"required": [],
},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(sink_id="n1", sink_name="config_#_key", source_id="n2")
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_nested_sink_links(agent, [block]) is True
def test_invalid_parent_fails(self):
v = AgentValidator()
block = _make_block(block_id="b1")
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(sink_id="n1", sink_name="nonexistent_#_key", source_id="n2")
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_nested_sink_links(agent, [block]) is False
assert any("does not exist" in e for e in v.errors)
# ============================================================================
# validate_agent_executor_block_schemas
# ============================================================================
class TestValidateAgentExecutorBlockSchemas:
def test_valid_schemas_pass(self):
v = AgentValidator()
node = _make_node(
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={
"graph_id": generate_uuid(),
"input_schema": {"properties": {"q": {"type": "string"}}},
"output_schema": {"properties": {"result": {"type": "string"}}},
},
)
agent = _make_agent(nodes=[node])
assert v.validate_agent_executor_block_schemas(agent) is True
assert v.errors == []
def test_empty_input_schema_fails(self):
v = AgentValidator()
node = _make_node(
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={
"graph_id": generate_uuid(),
"input_schema": {},
"output_schema": {"properties": {"result": {"type": "string"}}},
},
)
agent = _make_agent(nodes=[node])
assert v.validate_agent_executor_block_schemas(agent) is False
assert any("empty input_schema" in e for e in v.errors)
def test_missing_output_schema_fails(self):
v = AgentValidator()
node = _make_node(
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={
"graph_id": generate_uuid(),
"input_schema": {"properties": {"q": {"type": "string"}}},
},
)
agent = _make_agent(nodes=[node])
assert v.validate_agent_executor_block_schemas(agent) is False
assert any("output_schema" in e for e in v.errors)
# ============================================================================
# validate_agent_executor_blocks
# ============================================================================
class TestValidateAgentExecutorBlocks:
def test_missing_graph_id_fails(self):
v = AgentValidator()
node = _make_node(
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={},
)
agent = _make_agent(nodes=[node])
assert v.validate_agent_executor_blocks(agent) is False
assert any("graph_id" in e for e in v.errors)
def test_valid_graph_id_passes(self):
v = AgentValidator()
node = _make_node(
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={"graph_id": generate_uuid()},
)
agent = _make_agent(nodes=[node])
assert v.validate_agent_executor_blocks(agent) is True
def test_version_mismatch_with_library_agent(self):
v = AgentValidator()
lib_id = generate_uuid()
node = _make_node(
node_id="n1",
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={"graph_id": lib_id, "graph_version": 1},
)
agent = _make_agent(nodes=[node])
library_agents = [{"graph_id": lib_id, "graph_version": 3, "name": "Sub Agent"}]
assert v.validate_agent_executor_blocks(agent, library_agents) is False
assert any("mismatched graph_version" in e for e in v.errors)
def test_required_input_satisfied_by_schema_default_passes(self):
"""Required sub-agent inputs filled with their schema default by the fixer
should NOT be flagged as missing."""
v = AgentValidator()
lib_id = generate_uuid()
node = _make_node(
node_id="n1",
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={
"graph_id": lib_id,
"input_schema": {
"properties": {"mode": {"type": "string", "default": "fast"}}
},
"inputs": {"mode": "fast"}, # fixer populated with schema default
},
)
agent = _make_agent(nodes=[node])
library_agents = [
{
"graph_id": lib_id,
"graph_version": 1,
"name": "Sub",
"input_schema": {
"required": ["mode"],
"properties": {"mode": {"type": "string", "default": "fast"}},
},
"output_schema": {},
}
]
assert v.validate_agent_executor_blocks(agent, library_agents) is True
assert v.errors == []
def test_required_input_not_linked_and_no_default_fails(self):
"""Required sub-agent inputs without a link or schema default must fail."""
v = AgentValidator()
lib_id = generate_uuid()
node = _make_node(
node_id="n1",
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={
"graph_id": lib_id,
"input_schema": {"properties": {"query": {"type": "string"}}},
"inputs": {},
},
)
agent = _make_agent(nodes=[node])
library_agents = [
{
"graph_id": lib_id,
"graph_version": 1,
"name": "Sub",
"input_schema": {
"required": ["query"],
"properties": {"query": {"type": "string"}},
},
"output_schema": {},
}
]
assert v.validate_agent_executor_blocks(agent, library_agents) is False
assert any("missing required sub-agent input" in e for e in v.errors)
# ============================================================================
# validate_io_blocks
# ============================================================================
class TestValidateIoBlocks:
def test_missing_input_block_reports_error(self):
v = AgentValidator()
# Agent has output block but no input block
node = _make_node(block_id=AGENT_OUTPUT_BLOCK_ID)
agent = _make_agent(nodes=[node])
assert v.validate_io_blocks(agent) is False
assert len(v.errors) == 1
assert "AgentInputBlock" in v.errors[0]
def test_missing_output_block_reports_error(self):
v = AgentValidator()
# Agent has input block but no output block
node = _make_node(block_id=AGENT_INPUT_BLOCK_ID)
agent = _make_agent(nodes=[node])
assert v.validate_io_blocks(agent) is False
assert len(v.errors) == 1
assert "AgentOutputBlock" in v.errors[0]
def test_missing_both_io_blocks_reports_two_errors(self):
v = AgentValidator()
node = _make_node(block_id="some-other-block")
agent = _make_agent(nodes=[node])
assert v.validate_io_blocks(agent) is False
assert len(v.errors) == 2
def test_both_io_blocks_present_no_error(self):
v = AgentValidator()
input_node = _make_node(block_id=AGENT_INPUT_BLOCK_ID)
output_node = _make_node(block_id=AGENT_OUTPUT_BLOCK_ID)
agent = _make_agent(nodes=[input_node, output_node])
assert v.validate_io_blocks(agent) is True
assert v.errors == []
def test_empty_agent_reports_both_missing(self):
v = AgentValidator()
agent = _make_agent(nodes=[])
assert v.validate_io_blocks(agent) is False
assert len(v.errors) == 2
# ============================================================================
# validate (integration)
# ============================================================================
class TestValidate:
def test_valid_agent_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {"url": {"type": "string"}},
"required": ["url"],
},
output_schema={"properties": {"result": {"type": "string"}}},
)
input_block = _make_block(
block_id=AGENT_INPUT_BLOCK_ID,
name="AgentInputBlock",
output_schema={"properties": {"result": {}}},
)
output_block = _make_block(
block_id=AGENT_OUTPUT_BLOCK_ID,
name="AgentOutputBlock",
)
input_node = _make_node(
node_id="n-in",
block_id=AGENT_INPUT_BLOCK_ID,
input_default={"name": "url"},
)
n1 = _make_node(
node_id="n1", block_id="b1", input_default={"url": "http://example.com"}
)
n2 = _make_node(
node_id="n2", block_id="b1", input_default={"url": "http://example2.com"}
)
output_node = _make_node(
node_id="n-out",
block_id=AGENT_OUTPUT_BLOCK_ID,
input_default={"name": "result"},
)
link = _make_link(
source_id="n1", source_name="result", sink_id="n2", sink_name="url"
)
agent = _make_agent(nodes=[input_node, n1, n2, output_node], links=[link])
is_valid, error_message = v.validate(agent, [block, input_block, output_block])
assert is_valid is True
assert error_message is None
def test_invalid_agent_returns_errors(self):
v = AgentValidator()
node = _make_node(block_id="nonexistent")
agent = _make_agent(nodes=[node])
is_valid, error_message = v.validate(agent, [])
assert is_valid is False
assert error_message is not None
assert "does not exist" in error_message
def test_empty_agent_fails_io_validation(self):
v = AgentValidator()
agent = _make_agent()
is_valid, error_message = v.validate(agent, [])
assert is_valid is False
assert error_message is not None
assert "AgentInputBlock" in error_message
assert "AgentOutputBlock" in error_message
class TestValidateMCPToolBlocks:
"""Tests for validate_mcp_tool_blocks."""
def test_missing_server_url_reports_error(self):
v = AgentValidator()
node = _make_node(
block_id=MCP_TOOL_BLOCK_ID,
input_default={"selected_tool": "my_tool"},
)
agent = _make_agent(nodes=[node])
result = v.validate_mcp_tool_blocks(agent)
assert result is False
assert any("server_url" in e for e in v.errors)
def test_missing_selected_tool_reports_error(self):
v = AgentValidator()
node = _make_node(
block_id=MCP_TOOL_BLOCK_ID,
input_default={"server_url": "https://mcp.example.com/sse"},
)
agent = _make_agent(nodes=[node])
result = v.validate_mcp_tool_blocks(agent)
assert result is False
assert any("selected_tool" in e for e in v.errors)
def test_valid_mcp_block_passes(self):
v = AgentValidator()
node = _make_node(
block_id=MCP_TOOL_BLOCK_ID,
input_default={
"server_url": "https://mcp.example.com/sse",
"selected_tool": "search",
"tool_input_schema": {"properties": {"query": {"type": "string"}}},
"tool_arguments": {},
},
)
agent = _make_agent(nodes=[node])
result = v.validate_mcp_tool_blocks(agent)
assert result is True
assert len(v.errors) == 0
def test_both_missing_reports_two_errors(self):
v = AgentValidator()
node = _make_node(
block_id=MCP_TOOL_BLOCK_ID,
input_default={},
)
agent = _make_agent(nodes=[node])
v.validate_mcp_tool_blocks(agent)
assert len(v.errors) == 2

View File

@@ -208,6 +208,9 @@ def _library_agent_to_info(agent: LibraryAgent) -> AgentInfo:
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
graph_version=agent.graph_version,
input_schema=agent.input_schema,
output_schema=agent.output_schema,
)

View File

@@ -21,10 +21,10 @@ from typing import Any
from e2b import AsyncSandbox
from e2b.exceptions import TimeoutException
from backend.copilot.context import E2B_WORKDIR, get_current_sandbox
from backend.copilot.model import ChatSession
from .base import BaseTool
from .e2b_sandbox import E2B_WORKDIR
from .models import BashExecResponse, ErrorResponse, ToolResponseBase
from .sandbox import get_workspace_dir, has_full_sandbox, run_sandboxed
@@ -94,9 +94,6 @@ class BashExecTool(BaseTool):
session_id=session_id,
)
# E2B path: run on remote cloud sandbox when available.
from backend.copilot.sdk.tool_adapter import get_current_sandbox
sandbox = get_current_sandbox()
if sandbox is not None:
return await self._execute_on_e2b(sandbox, command, timeout, session_id)

View File

@@ -0,0 +1,157 @@
"""Tool for continuing block execution after human review approval."""
import logging
from typing import Any
from prisma.enums import ReviewStatus
from backend.blocks import get_block
from backend.copilot.constants import (
COPILOT_NODE_PREFIX,
COPILOT_SESSION_PREFIX,
parse_node_id_from_exec_id,
)
from backend.copilot.model import ChatSession
from backend.data.db_accessors import review_db
from .base import BaseTool
from .helpers import execute_block, resolve_block_credentials
from .models import ErrorResponse, ToolResponseBase
logger = logging.getLogger(__name__)
class ContinueRunBlockTool(BaseTool):
"""Tool for continuing a block execution after human review approval."""
@property
def name(self) -> str:
return "continue_run_block"
@property
def description(self) -> str:
return (
"Continue executing a block after human review approval. "
"Use this after a run_block call returned review_required. "
"Pass the review_id from the review_required response. "
"The block will execute with the original pre-approved input data."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"review_id": {
"type": "string",
"description": (
"The review_id from a previous review_required response. "
"This resumes execution with the pre-approved input data."
),
},
},
"required": ["review_id"],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
review_id = (
kwargs.get("review_id", "").strip() if kwargs.get("review_id") else ""
)
session_id = session.session_id
if not review_id:
return ErrorResponse(
message="Please provide a review_id", session_id=session_id
)
if not user_id:
return ErrorResponse(
message="Authentication required", session_id=session_id
)
# Look up and validate the review record via adapter
reviews = await review_db().get_reviews_by_node_exec_ids([review_id], user_id)
review = reviews.get(review_id)
if not review:
return ErrorResponse(
message=(
f"Review '{review_id}' not found or already executed. "
"It may have been consumed by a previous continue_run_block call."
),
session_id=session_id,
)
# Validate the review belongs to this session
expected_graph_exec_id = f"{COPILOT_SESSION_PREFIX}{session_id}"
if review.graph_exec_id != expected_graph_exec_id:
return ErrorResponse(
message="Review does not belong to this session.",
session_id=session_id,
)
if review.status == ReviewStatus.WAITING:
return ErrorResponse(
message="Review has not been approved yet. "
"Please wait for the user to approve the review first.",
session_id=session_id,
)
if review.status == ReviewStatus.REJECTED:
return ErrorResponse(
message="Review was rejected. The block will not execute.",
session_id=session_id,
)
# Extract block_id from review_id: copilot-node-{block_id}:{random_hex}
block_id = parse_node_id_from_exec_id(review_id).removeprefix(
COPILOT_NODE_PREFIX
)
block = get_block(block_id)
if not block:
return ErrorResponse(
message=f"Block '{block_id}' not found", session_id=session_id
)
input_data: dict[str, Any] = (
review.payload if isinstance(review.payload, dict) else {}
)
logger.info(
f"Continuing block {block.name} ({block_id}) for user {user_id} "
f"with review_id={review_id}"
)
matched_creds, missing_creds = await resolve_block_credentials(
user_id, block, input_data
)
if missing_creds:
return ErrorResponse(
message=f"Block '{block.name}' requires credentials that are not configured.",
session_id=session_id,
)
result = await execute_block(
block=block,
block_id=block_id,
input_data=input_data,
user_id=user_id,
session_id=session_id,
node_exec_id=review_id,
matched_credentials=matched_creds,
)
# Delete review record after successful execution (one-time use)
if result.type != "error":
await review_db().delete_review_by_node_exec_id(review_id, user_id)
return result

View File

@@ -0,0 +1,186 @@
"""Tests for ContinueRunBlockTool."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from prisma.enums import ReviewStatus
from ._test_data import make_session
from .continue_run_block import ContinueRunBlockTool
from .models import BlockOutputResponse, ErrorResponse
_TEST_USER_ID = "test-user-continue"
def _make_review_model(
node_exec_id: str,
status: ReviewStatus = ReviewStatus.APPROVED,
payload: dict | None = None,
graph_exec_id: str = "",
):
"""Create a mock PendingHumanReviewModel."""
mock = MagicMock()
mock.node_exec_id = node_exec_id
mock.status = status
mock.payload = payload or {"text": "hello"}
mock.graph_exec_id = graph_exec_id
return mock
class TestContinueRunBlock:
@pytest.mark.asyncio(loop_scope="session")
async def test_missing_review_id_returns_error(self):
tool = ContinueRunBlockTool()
session = make_session(user_id=_TEST_USER_ID)
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
review_id="",
)
assert isinstance(response, ErrorResponse)
assert "review_id" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_review_not_found_returns_error(self):
tool = ContinueRunBlockTool()
session = make_session(user_id=_TEST_USER_ID)
mock_db = MagicMock()
mock_db.get_reviews_by_node_exec_ids = AsyncMock(return_value={})
with patch(
"backend.copilot.tools.continue_run_block.review_db",
return_value=mock_db,
):
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
review_id="copilot-node-some-block:abc12345",
)
assert isinstance(response, ErrorResponse)
assert "not found" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_waiting_review_returns_error(self):
tool = ContinueRunBlockTool()
session = make_session(user_id=_TEST_USER_ID)
review_id = "copilot-node-some-block:abc12345"
graph_exec_id = f"copilot-session-{session.session_id}"
review = _make_review_model(
review_id, status=ReviewStatus.WAITING, graph_exec_id=graph_exec_id
)
mock_db = MagicMock()
mock_db.get_reviews_by_node_exec_ids = AsyncMock(
return_value={review_id: review}
)
with patch(
"backend.copilot.tools.continue_run_block.review_db",
return_value=mock_db,
):
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
review_id=review_id,
)
assert isinstance(response, ErrorResponse)
assert "not been approved" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_rejected_review_returns_error(self):
tool = ContinueRunBlockTool()
session = make_session(user_id=_TEST_USER_ID)
review_id = "copilot-node-some-block:abc12345"
graph_exec_id = f"copilot-session-{session.session_id}"
review = _make_review_model(
review_id, status=ReviewStatus.REJECTED, graph_exec_id=graph_exec_id
)
mock_db = MagicMock()
mock_db.get_reviews_by_node_exec_ids = AsyncMock(
return_value={review_id: review}
)
with patch(
"backend.copilot.tools.continue_run_block.review_db",
return_value=mock_db,
):
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
review_id=review_id,
)
assert isinstance(response, ErrorResponse)
assert "rejected" in response.message.lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_approved_review_executes_block(self):
tool = ContinueRunBlockTool()
session = make_session(user_id=_TEST_USER_ID)
review_id = "copilot-node-delete-branch-id:abc12345"
graph_exec_id = f"copilot-session-{session.session_id}"
input_data = {"repo_url": "https://github.com/test/repo", "branch": "main"}
review = _make_review_model(
review_id,
status=ReviewStatus.APPROVED,
payload=input_data,
graph_exec_id=graph_exec_id,
)
mock_block = MagicMock()
mock_block.name = "Delete Branch"
async def mock_execute(data, **kwargs):
yield "result", "Branch deleted"
mock_block.execute = mock_execute
mock_block.input_schema.get_credentials_fields_info.return_value = []
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="test-workspace-id")
)
mock_db = MagicMock()
mock_db.get_reviews_by_node_exec_ids = AsyncMock(
return_value={review_id: review}
)
mock_db.delete_review_by_node_exec_id = AsyncMock(return_value=1)
with (
patch(
"backend.copilot.tools.continue_run_block.review_db",
return_value=mock_db,
),
patch(
"backend.copilot.tools.continue_run_block.get_block",
return_value=mock_block,
),
patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
),
patch(
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
):
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
review_id=review_id,
)
assert isinstance(response, BlockOutputResponse)
assert response.success is True
assert response.block_name == "Delete Branch"
# Verify review was deleted (one-time use)
mock_db.delete_review_by_node_exec_id.assert_called_once_with(
review_id, _TEST_USER_ID
)

View File

@@ -1,34 +1,20 @@
"""CreateAgentTool - Creates agents from natural language descriptions."""
"""CreateAgentTool - Creates agents from pre-built JSON."""
import logging
import uuid
from typing import Any
from backend.copilot.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,
decompose_goal,
enrich_library_agents_from_steps,
generate_agent,
get_user_message_for_error,
save_agent_to_library,
)
from .agent_generator.pipeline import fetch_library_agents, fix_validate_and_save
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
SuggestedGoalResponse,
ToolResponseBase,
)
from .models import ErrorResponse, ToolResponseBase
logger = logging.getLogger(__name__)
class CreateAgentTool(BaseTool):
"""Tool for creating agents from natural language descriptions."""
"""Tool for creating agents from pre-built JSON."""
@property
def name(self) -> str:
@@ -37,11 +23,12 @@ class CreateAgentTool(BaseTool):
@property
def description(self) -> str:
return (
"Create a new agent workflow from a natural language description. "
"First generates a preview, then saves to library if save=true. "
"\n\nIMPORTANT: Before calling this tool, search for relevant existing agents "
"Create a new agent workflow. Pass `agent_json` with the complete "
"agent graph JSON you generated using block schemas from find_block. "
"The tool validates, auto-fixes, and saves.\n\n"
"IMPORTANT: Before calling this tool, search for relevant existing agents "
"using find_library_agent that could be used as building blocks. "
"Pass their IDs in the library_agent_ids parameter so the generator can compose them."
"Pass their IDs in the library_agent_ids parameter."
)
@property
@@ -53,39 +40,39 @@ class CreateAgentTool(BaseTool):
return {
"type": "object",
"properties": {
"description": {
"type": "string",
"agent_json": {
"type": "object",
"description": (
"Natural language description of what the agent should do. "
"Be specific about inputs, outputs, and the workflow steps."
),
},
"context": {
"type": "string",
"description": (
"Additional context or answers to previous clarifying questions. "
"Include any preferences or constraints mentioned by the user."
"The agent JSON to validate and save. "
"Must contain 'nodes' and 'links' arrays, and optionally "
"'name' and 'description'."
),
},
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"List of library agent IDs to use as building blocks. "
"Search for relevant agents using find_library_agent first, "
"then pass their IDs here so they can be composed into the new agent."
"List of library agent IDs to use as building blocks."
),
},
"save": {
"type": "boolean",
"description": (
"Whether to save the agent to the user's library. "
"Default is true. Set to false for preview only."
"Whether to save the agent. Default is true. "
"Set to false for preview only."
),
"default": True,
},
"folder_id": {
"type": "string",
"description": (
"Optional folder ID to save the agent into. "
"If not provided, the agent is saved at root level. "
"Use list_folders to find available folders."
),
},
},
"required": ["description"],
"required": ["agent_json"],
}
async def _execute(
@@ -94,277 +81,49 @@ class CreateAgentTool(BaseTool):
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the create_agent tool.
Flow:
1. Decompose the description into steps (may return clarifying questions)
2. Generate agent JSON (external service handles fixing and validation)
3. Preview or save based on the save parameter
"""
description = kwargs.get("description", "").strip()
context = kwargs.get("context", "")
library_agent_ids = kwargs.get("library_agent_ids", [])
save = kwargs.get("save", True)
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
session_id = session.session_id if session else None
logger.info(
f"[AGENT_CREATE_DEBUG] START - description_len={len(description)}, "
f"library_agent_ids={library_agent_ids}, save={save}, user_id={user_id}, session_id={session_id}"
if not agent_json:
return ErrorResponse(
message=(
"Please provide agent_json with the complete agent graph. "
"Use find_block to discover blocks, then generate the JSON."
),
error="missing_agent_json",
session_id=session_id,
)
save = kwargs.get("save", True)
library_agent_ids = kwargs.get("library_agent_ids", [])
folder_id: str | None = kwargs.get("folder_id")
nodes = agent_json.get("nodes", [])
if not nodes:
return ErrorResponse(
message="The agent JSON has no nodes. An agent needs at least one block.",
error="empty_agent",
session_id=session_id,
)
# Ensure top-level fields
if "id" not in agent_json:
agent_json["id"] = str(uuid.uuid4())
if "version" not in agent_json:
agent_json["version"] = 1
if "is_active" not in agent_json:
agent_json["is_active"] = True
# Fetch library agents for AgentExecutorBlock validation
library_agents = await fetch_library_agents(user_id, library_agent_ids)
return await fix_validate_and_save(
agent_json,
user_id=user_id,
session_id=session_id,
save=save,
is_update=False,
default_name="Generated Agent",
library_agents=library_agents,
folder_id=folder_id,
)
if not description:
return ErrorResponse(
message="Please provide a description of what the agent should do.",
error="Missing description parameter",
session_id=session_id,
)
# Fetch library agents by IDs if provided
library_agents = None
if user_id and library_agent_ids:
try:
from .agent_generator import get_library_agents_by_ids
library_agents = await get_library_agents_by_ids(
user_id=user_id,
agent_ids=library_agent_ids,
)
logger.debug(
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents by IDs: {e}")
try:
decomposition_result = await decompose_goal(
description, context, library_agents
)
logger.info(
f"[AGENT_CREATE_DEBUG] DECOMPOSE - type={decomposition_result.get('type') if decomposition_result else None}, "
f"session_id={session_id}"
)
except AgentGeneratorNotConfiguredError:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured, session_id={session_id}"
)
return ErrorResponse(
message=(
"Agent generation is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
if decomposition_result is None:
return ErrorResponse(
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
error="decomposition_failed",
details={"description": description[:100]},
session_id=session_id,
)
if decomposition_result.get("type") == "error":
error_msg = decomposition_result.get("error", "Unknown error")
error_type = decomposition_result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="analyze the goal",
llm_parse_message="The AI had trouble understanding this request. Please try rephrasing your goal.",
)
return ErrorResponse(
message=user_message,
error=f"decomposition_failed:{error_type}",
details={
"description": description[:100],
"service_error": error_msg,
"error_type": error_type,
},
session_id=session_id,
)
if decomposition_result.get("type") == "clarifying_questions":
questions = decomposition_result.get("questions", [])
return ClarificationNeededResponse(
message=(
"I need some more information to create this agent. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
],
session_id=session_id,
)
if decomposition_result.get("type") == "unachievable_goal":
suggested = decomposition_result.get("suggested_goal", "")
reason = decomposition_result.get("reason", "")
return SuggestedGoalResponse(
message=(
f"This goal cannot be accomplished with the available blocks. {reason}"
),
suggested_goal=suggested,
reason=reason,
original_goal=description,
goal_type="unachievable",
session_id=session_id,
)
if decomposition_result.get("type") == "vague_goal":
suggested = decomposition_result.get("suggested_goal", "")
reason = decomposition_result.get(
"reason", "The goal needs more specific details"
)
return SuggestedGoalResponse(
message="The goal is too vague to create a specific workflow.",
suggested_goal=suggested,
reason=reason,
original_goal=description,
goal_type="vague",
session_id=session_id,
)
if user_id and library_agents is not None:
try:
library_agents = await enrich_library_agents_from_steps(
user_id=user_id,
decomposition_result=decomposition_result,
existing_agents=library_agents,
include_marketplace=True,
)
logger.debug(
f"After enrichment: {len(library_agents)} total agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to enrich library agents from steps: {e}")
try:
agent_json = await generate_agent(
decomposition_result,
library_agents,
)
logger.info(
f"[AGENT_CREATE_DEBUG] GENERATE - "
f"success={agent_json is not None}, "
f"is_error={isinstance(agent_json, dict) and agent_json.get('type') == 'error'}, "
f"session_id={session_id}"
)
except AgentGeneratorNotConfiguredError:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured during generation, session_id={session_id}"
)
return ErrorResponse(
message=(
"Agent generation is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
if agent_json is None:
return ErrorResponse(
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
error="generation_failed",
details={"description": description[:100]},
session_id=session_id,
)
if isinstance(agent_json, dict) and agent_json.get("type") == "error":
error_msg = agent_json.get("error", "Unknown error")
error_type = agent_json.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="generate the agent",
llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
validation_message=(
"I wasn't able to create a valid agent for this request. "
"The generated workflow had some structural issues. "
"Please try simplifying your goal or breaking it into smaller steps."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"generation_failed:{error_type}",
details={
"description": description[:100],
"service_error": error_msg,
"error_type": error_type,
},
session_id=session_id,
)
agent_name = agent_json.get("name", "Generated Agent")
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
logger.info(
f"[AGENT_CREATE_DEBUG] AGENT_JSON - name={agent_name}, "
f"nodes={node_count}, links={link_count}, save={save}, session_id={session_id}"
)
if not save:
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - AgentPreviewResponse, session_id={session_id}"
)
return AgentPreviewResponse(
message=(
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
f"Review it and call create_agent with save=true to save it to your library."
),
agent_json=agent_json,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
try:
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id
)
logger.info(
f"[AGENT_CREATE_DEBUG] SAVED - graph_id={created_graph.id}, "
f"library_agent_id={library_agent.id}, session_id={session_id}"
)
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - AgentSavedResponse, session_id={session_id}"
)
return AgentSavedResponse(
message=f"Agent '{created_graph.name}' has been saved to your library!",
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - save_failed: {str(e)}, session_id={session_id}"
)
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - ErrorResponse (save_failed), session_id={session_id}"
)
return ErrorResponse(
message=f"Failed to save the agent: {str(e)}",
error="save_failed",
details={"exception": str(e)},
session_id=session_id,
)

View File

@@ -1,19 +1,16 @@
"""Tests for CreateAgentTool response types."""
"""Tests for CreateAgentTool."""
from unittest.mock import AsyncMock, patch
from unittest.mock import MagicMock, patch
import pytest
from backend.copilot.tools.create_agent import CreateAgentTool
from backend.copilot.tools.models import (
ClarificationNeededResponse,
ErrorResponse,
SuggestedGoalResponse,
)
from backend.copilot.tools.models import AgentPreviewResponse, ErrorResponse
from ._test_data import make_session
_TEST_USER_ID = "test-user-create-agent"
_PIPELINE = "backend.copilot.tools.agent_generator.pipeline"
@pytest.fixture
@@ -26,102 +23,147 @@ def session():
return make_session(_TEST_USER_ID)
# ── Input validation tests ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_missing_description_returns_error(tool, session):
"""Missing description returns ErrorResponse."""
result = await tool._execute(user_id=_TEST_USER_ID, session=session, description="")
async def test_missing_agent_json_returns_error(tool, session):
"""Missing agent_json returns ErrorResponse."""
result = await tool._execute(user_id=_TEST_USER_ID, session=session)
assert isinstance(result, ErrorResponse)
assert result.error == "Missing description parameter"
assert result.error == "missing_agent_json"
# ── Local mode tests ────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_vague_goal_returns_suggested_goal_response(tool, session):
"""vague_goal decomposition result returns SuggestedGoalResponse, not ErrorResponse."""
vague_result = {
"type": "vague_goal",
"suggested_goal": "Monitor Twitter mentions for a specific keyword and send a daily digest email",
}
with (
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,
return_value=vague_result,
),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
description="monitor social media",
)
assert isinstance(result, SuggestedGoalResponse)
assert result.goal_type == "vague"
assert result.suggested_goal == vague_result["suggested_goal"]
assert result.original_goal == "monitor social media"
assert result.reason == "The goal needs more specific details"
assert not isinstance(result, ErrorResponse)
async def test_local_mode_empty_nodes_returns_error(tool, session):
"""Local mode with no nodes returns ErrorResponse."""
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json={"nodes": [], "links": []},
)
assert isinstance(result, ErrorResponse)
assert "no nodes" in result.message.lower()
@pytest.mark.asyncio
async def test_unachievable_goal_returns_suggested_goal_response(tool, session):
"""unachievable_goal decomposition result returns SuggestedGoalResponse, not ErrorResponse."""
unachievable_result = {
"type": "unachievable_goal",
"suggested_goal": "Summarize the latest news articles on a topic and send them by email",
"reason": "There are no blocks for mind-reading.",
}
with (
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,
return_value=unachievable_result,
),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
description="read my mind",
)
assert isinstance(result, SuggestedGoalResponse)
assert result.goal_type == "unachievable"
assert result.suggested_goal == unachievable_result["suggested_goal"]
assert result.original_goal == "read my mind"
assert result.reason == unachievable_result["reason"]
assert not isinstance(result, ErrorResponse)
@pytest.mark.asyncio
async def test_clarifying_questions_returns_clarification_needed_response(
tool, session
):
"""clarifying_questions decomposition result returns ClarificationNeededResponse."""
clarifying_result = {
"type": "clarifying_questions",
"questions": [
async def test_local_mode_preview(tool, session):
"""Local mode with save=false returns AgentPreviewResponse."""
agent_json = {
"name": "Test Agent",
"description": "A test agent",
"nodes": [
{
"question": "What platform should be monitored?",
"keyword": "platform",
"example": "Twitter, Reddit",
"id": "node-1",
"block_id": "block-1",
"input_default": {},
"metadata": {"position": {"x": 0, "y": 0}},
}
],
"links": [],
}
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
mock_fixer.get_fixes_applied.return_value = []
mock_validator = MagicMock()
mock_validator.validate.return_value = (True, None)
mock_validator.errors = []
with (
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,
return_value=clarifying_result,
),
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
description="monitor social media and alert me",
agent_json=agent_json,
save=False,
)
assert isinstance(result, ClarificationNeededResponse)
assert len(result.questions) == 1
assert result.questions[0].keyword == "platform"
assert isinstance(result, AgentPreviewResponse)
assert result.agent_name == "Test Agent"
assert result.node_count == 1
@pytest.mark.asyncio
async def test_local_mode_validation_failure(tool, session):
"""Local mode returns ErrorResponse when validation fails after fixing."""
agent_json = {
"nodes": [
{
"id": "node-1",
"block_id": "bad-block",
"input_default": {},
"metadata": {},
}
],
"links": [],
}
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
mock_fixer.get_fixes_applied.return_value = []
mock_validator = MagicMock()
mock_validator.validate.return_value = (False, "Block 'bad-block' not found")
mock_validator.errors = ["Block 'bad-block' not found"]
with (
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
)
assert isinstance(result, ErrorResponse)
assert result.error == "validation_failed"
assert "Block 'bad-block' not found" in result.message
@pytest.mark.asyncio
async def test_local_mode_no_auth_returns_error(tool, session):
"""Local mode with save=true and no user returns ErrorResponse."""
agent_json = {
"nodes": [
{
"id": "node-1",
"block_id": "block-1",
"input_default": {},
"metadata": {},
}
],
"links": [],
}
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
mock_fixer.get_fixes_applied.return_value = []
mock_validator = MagicMock()
mock_validator.validate.return_value = (True, None)
mock_validator.errors = []
with (
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
):
result = await tool._execute(
user_id=None,
session=session,
agent_json=agent_json,
save=True,
)
assert isinstance(result, ErrorResponse)
assert "logged in" in result.message.lower()

View File

@@ -1,34 +1,20 @@
"""CustomizeAgentTool - Customizes marketplace/template agents using natural language."""
"""CustomizeAgentTool - Customizes marketplace/template agents."""
import logging
import uuid
from typing import Any
from backend.api.features.store.exceptions import AgentNotFoundError
from backend.copilot.model import ChatSession
from backend.data.db_accessors import store_db as get_store_db
from .agent_generator import (
AgentGeneratorNotConfiguredError,
customize_template,
get_user_message_for_error,
graph_to_json,
save_agent_to_library,
)
from .agent_generator.pipeline import fetch_library_agents, fix_validate_and_save
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
ToolResponseBase,
)
from .models import ErrorResponse, ToolResponseBase
logger = logging.getLogger(__name__)
class CustomizeAgentTool(BaseTool):
"""Tool for customizing marketplace/template agents using natural language."""
"""Tool for customizing marketplace/template agents."""
@property
def name(self) -> str:
@@ -37,9 +23,9 @@ class CustomizeAgentTool(BaseTool):
@property
def description(self) -> str:
return (
"Customize a marketplace or template agent using natural language. "
"Takes an existing agent from the marketplace and modifies it based on "
"the user's requirements before adding to their library."
"Customize a marketplace or template agent. Pass `agent_json` "
"with the complete customized agent JSON. The tool validates, "
"auto-fixes, and saves."
)
@property
@@ -51,37 +37,37 @@ class CustomizeAgentTool(BaseTool):
return {
"type": "object",
"properties": {
"agent_id": {
"type": "string",
"agent_json": {
"type": "object",
"description": (
"The marketplace agent ID in format 'creator/slug' "
"(e.g., 'autogpt/newsletter-writer'). "
"Get this from find_agent results."
"Complete customized agent JSON to validate and save. "
"Optionally include 'name' and 'description'."
),
},
"modifications": {
"type": "string",
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"Natural language description of how to customize the agent. "
"Be specific about what changes you want to make."
),
},
"context": {
"type": "string",
"description": (
"Additional context or answers to previous clarifying questions."
"List of library agent IDs to use as building blocks."
),
},
"save": {
"type": "boolean",
"description": (
"Whether to save the customized agent to the user's library. "
"Default is true. Set to false for preview only."
"Whether to save the customized agent. Default is true."
),
"default": True,
},
"folder_id": {
"type": "string",
"description": (
"Optional folder ID to save the agent into. "
"If not provided, the agent is saved at root level. "
"Use list_folders to find available folders."
),
},
},
"required": ["agent_id", "modifications"],
"required": ["agent_json"],
}
async def _execute(
@@ -90,246 +76,46 @@ class CustomizeAgentTool(BaseTool):
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the customize_agent tool.
Flow:
1. Parse the agent ID to get creator/slug
2. Fetch the template agent from the marketplace
3. Call customize_template with the modification request
4. Preview or save based on the save parameter
"""
agent_id = kwargs.get("agent_id", "").strip()
modifications = kwargs.get("modifications", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
session_id = session.session_id if session else None
if not agent_id:
return ErrorResponse(
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
error="missing_agent_id",
session_id=session_id,
)
if not modifications:
return ErrorResponse(
message="Please describe how you want to customize this agent.",
error="missing_modifications",
session_id=session_id,
)
# Parse agent_id in format "creator/slug"
parts = [p.strip() for p in agent_id.split("/")]
if len(parts) != 2 or not parts[0] or not parts[1]:
if not agent_json:
return ErrorResponse(
message=(
f"Invalid agent ID format: '{agent_id}'. "
"Expected format is 'creator/agent-name' "
"(e.g., 'autogpt/newsletter-writer')."
"Please provide agent_json with the complete customized agent graph."
),
error="invalid_agent_id_format",
error="missing_agent_json",
session_id=session_id,
)
creator_username, agent_slug = parts
save = kwargs.get("save", True)
library_agent_ids = kwargs.get("library_agent_ids", [])
folder_id: str | None = kwargs.get("folder_id")
store_db = get_store_db()
# Fetch the marketplace agent details
try:
agent_details = await store_db.get_store_agent_details(
username=creator_username, agent_name=agent_slug
)
except AgentNotFoundError:
nodes = agent_json.get("nodes", [])
if not nodes:
return ErrorResponse(
message=(
f"Could not find marketplace agent '{agent_id}'. "
"Please check the agent ID and try again."
),
error="agent_not_found",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
return ErrorResponse(
message="Failed to fetch the marketplace agent. Please try again.",
error="fetch_error",
message="The agent JSON has no nodes.",
error="empty_agent",
session_id=session_id,
)
if not agent_details.store_listing_version_id:
return ErrorResponse(
message=(
f"The agent '{agent_id}' does not have an available version. "
"Please try a different agent."
),
error="no_version_available",
session_id=session_id,
)
# Ensure top-level fields before the fixer pipeline
if "id" not in agent_json:
agent_json["id"] = str(uuid.uuid4())
agent_json.setdefault("version", 1)
agent_json.setdefault("is_active", True)
# Get the full agent graph
try:
graph = await store_db.get_agent(agent_details.store_listing_version_id)
template_agent = graph_to_json(graph)
except Exception as e:
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
return ErrorResponse(
message="Failed to fetch the agent configuration. Please try again.",
error="graph_fetch_error",
session_id=session_id,
)
# Fetch library agents for AgentExecutorBlock validation
library_agents = await fetch_library_agents(user_id, library_agent_ids)
# Call customize_template
try:
result = await customize_template(
template_agent=template_agent,
modification_request=modifications,
context=context,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
"Agent customization is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error calling customize_template for {agent_id}: {e}")
return ErrorResponse(
message=(
"Failed to customize the agent due to a service error. "
"Please try again."
),
error="customization_service_error",
session_id=session_id,
)
if result is None:
return ErrorResponse(
message=(
"Failed to customize the agent. "
"The agent generation service may be unavailable or timed out. "
"Please try again."
),
error="customization_failed",
session_id=session_id,
)
# Handle error response
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="customize the agent",
llm_parse_message=(
"The AI had trouble customizing the agent. "
"Please try again or simplify your request."
),
validation_message=(
"The customized agent failed validation. "
"Please try rephrasing your request."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"customization_failed:{error_type}",
session_id=session_id,
)
# Handle clarifying questions
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
questions = result.get("questions") or []
if not isinstance(questions, list):
logger.error(
f"Unexpected clarifying questions format: {type(questions)}"
)
questions = []
return ClarificationNeededResponse(
message=(
"I need some more information to customize this agent. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
if isinstance(q, dict)
],
session_id=session_id,
)
# Result should be the customized agent JSON
if not isinstance(result, dict):
logger.error(f"Unexpected customize_template response type: {type(result)}")
return ErrorResponse(
message="Failed to customize the agent due to an unexpected response.",
error="unexpected_response_type",
session_id=session_id,
)
customized_agent = result
agent_name = customized_agent.get(
"name", f"Customized {agent_details.agent_name}"
return await fix_validate_and_save(
agent_json,
user_id=user_id,
session_id=session_id,
save=save,
is_update=False,
default_name="Customized Agent",
library_agents=library_agents,
folder_id=folder_id,
)
agent_description = customized_agent.get("description", "")
nodes = customized_agent.get("nodes")
links = customized_agent.get("links")
node_count = len(nodes) if isinstance(nodes, list) else 0
link_count = len(links) if isinstance(links, list) else 0
if not save:
return AgentPreviewResponse(
message=(
f"I've customized the agent '{agent_details.agent_name}'. "
f"The customized agent has {node_count} blocks. "
f"Review it and call customize_agent with save=true to save it."
),
agent_json=customized_agent,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
# Save to user's library
try:
created_graph, library_agent = await save_agent_to_library(
customized_agent, user_id, is_update=False
)
return AgentSavedResponse(
message=(
f"Customized agent '{created_graph.name}' "
f"(based on '{agent_details.agent_name}') "
f"has been saved to your library!"
),
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error saving customized agent: {e}")
return ErrorResponse(
message="Failed to save the customized agent. Please try again.",
error="save_failed",
session_id=session_id,
)

View File

@@ -0,0 +1,172 @@
"""Tests for CustomizeAgentTool local mode."""
from unittest.mock import MagicMock, patch
import pytest
from backend.copilot.tools.customize_agent import CustomizeAgentTool
from backend.copilot.tools.models import AgentPreviewResponse, ErrorResponse
from ._test_data import make_session
_TEST_USER_ID = "test-user-customize-agent"
_PIPELINE = "backend.copilot.tools.agent_generator.pipeline"
@pytest.fixture
def tool():
return CustomizeAgentTool()
@pytest.fixture
def session():
return make_session(_TEST_USER_ID)
# ── Input validation tests ───────────────────────────────────────────────
@pytest.mark.asyncio
async def test_missing_agent_json_returns_error(tool, session):
"""Missing agent_json returns ErrorResponse."""
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
)
assert isinstance(result, ErrorResponse)
assert result.error == "missing_agent_json"
# ── Local mode tests (agent_json provided) ───────────────────────────────
@pytest.mark.asyncio
async def test_local_mode_empty_nodes_returns_error(tool, session):
"""Local mode with no nodes returns ErrorResponse."""
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json={"nodes": [], "links": []},
)
assert isinstance(result, ErrorResponse)
assert "no nodes" in result.message.lower()
@pytest.mark.asyncio
async def test_local_mode_preview(tool, session):
"""Local mode with save=false returns AgentPreviewResponse."""
agent_json = {
"name": "Customized Agent",
"description": "A customized agent",
"nodes": [
{
"id": "node-1",
"block_id": "block-1",
"input_default": {},
"metadata": {"position": {"x": 0, "y": 0}},
}
],
"links": [],
}
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
mock_fixer.get_fixes_applied.return_value = []
mock_validator = MagicMock()
mock_validator.validate.return_value = (True, None)
mock_validator.errors = []
with (
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
save=False,
)
assert isinstance(result, AgentPreviewResponse)
assert result.agent_name == "Customized Agent"
assert result.node_count == 1
@pytest.mark.asyncio
async def test_local_mode_validation_failure(tool, session):
"""Local mode returns ErrorResponse when validation fails."""
agent_json = {
"nodes": [
{
"id": "node-1",
"block_id": "bad-block",
"input_default": {},
"metadata": {},
}
],
"links": [],
}
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
mock_fixer.get_fixes_applied.return_value = []
mock_validator = MagicMock()
mock_validator.validate.return_value = (False, "Block 'bad-block' not found")
mock_validator.errors = ["Block 'bad-block' not found"]
with (
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
)
assert isinstance(result, ErrorResponse)
assert result.error == "validation_failed"
assert "Block 'bad-block' not found" in result.message
@pytest.mark.asyncio
async def test_local_mode_no_auth_returns_error(tool, session):
"""Local mode with save=true and no user returns ErrorResponse."""
agent_json = {
"nodes": [
{
"id": "node-1",
"block_id": "block-1",
"input_default": {},
"metadata": {},
}
],
"links": [],
}
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
mock_fixer.get_fixes_applied.return_value = []
mock_validator = MagicMock()
mock_validator.validate.return_value = (True, None)
mock_validator.errors = []
with (
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
):
result = await tool._execute(
user_id=None,
session=session,
agent_json=agent_json,
save=True,
)
assert isinstance(result, ErrorResponse)
assert "logged in" in result.message.lower()

View File

@@ -10,13 +10,34 @@ Lifecycle
---------
1. **Turn start** connect to the existing sandbox (sandbox_id in Redis) or
create a new one via ``get_or_create_sandbox()``.
``connect()`` in e2b v2 auto-resumes paused sandboxes.
2. **Execution** ``bash_exec`` and MCP file tools operate directly on the
sandbox's ``/home/user`` filesystem.
3. **Session expiry** E2B sandbox is killed by its own timeout (session_ttl).
3. **Turn end** the sandbox is paused via ``pause_sandbox()`` (fire-and-forget)
so idle time between turns costs nothing. Paused sandboxes have no compute
cost.
4. **Session delete** ``kill_sandbox()`` fully terminates the sandbox.
Cost control
------------
Sandboxes are created with a configurable ``on_timeout`` lifecycle action
(default: ``"pause"``). The explicit per-turn ``pause_sandbox()`` call is the
primary mechanism; the lifecycle setting is a safety net. Paused sandboxes are
free.
The sandbox_id is stored in Redis. The same key doubles as a creation lock:
a ``"creating"`` sentinel value is written with a short TTL while a new sandbox
is being provisioned, preventing duplicate creation under concurrent requests.
E2B project-level "paused sandbox lifetime" should be set to match
``_SANDBOX_ID_TTL`` (48 h) so orphaned paused sandboxes are auto-killed before
the Redis key expires.
"""
import asyncio
import contextlib
import logging
from typing import Any, Awaitable, Callable, Literal
from e2b import AsyncSandbox
@@ -24,147 +45,245 @@ from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
_SANDBOX_REDIS_PREFIX = "copilot:e2b:sandbox:"
E2B_WORKDIR = "/home/user"
_CREATING = "__creating__"
_CREATION_LOCK_TTL = 60
_MAX_WAIT_ATTEMPTS = 20 # 20 * 0.5s = 10s max wait
_SANDBOX_KEY_PREFIX = "copilot:e2b:sandbox:"
_CREATING_SENTINEL = "creating"
# Short TTL for the "creating" sentinel — if the process dies mid-creation the
# lock auto-expires so other callers are not blocked forever.
_CREATION_LOCK_TTL = 60 # seconds
_MAX_WAIT_ATTEMPTS = 20 # 20 × 0.5 s = 10 s max wait
# Timeout for E2B API calls (pause/kill) — short because these are control-plane
# operations; if the sandbox is unreachable, fail fast and retry on the next turn.
_E2B_API_TIMEOUT_SECONDS = 10
# Redis TTL for the sandbox key. Must be ≥ the E2B project "paused sandbox
# lifetime" setting (recommended: set both to 48 h).
_SANDBOX_ID_TTL = 48 * 3600 # 48 hours
def _sandbox_key(session_id: str) -> str:
return f"{_SANDBOX_KEY_PREFIX}{session_id}"
async def _get_stored_sandbox_id(session_id: str) -> str | None:
redis = await get_redis_async()
raw = await redis.get(_sandbox_key(session_id))
value = raw.decode() if isinstance(raw, bytes) else raw
return None if value == _CREATING_SENTINEL else value
async def _set_stored_sandbox_id(session_id: str, sandbox_id: str) -> None:
redis = await get_redis_async()
await redis.set(_sandbox_key(session_id), sandbox_id, ex=_SANDBOX_ID_TTL)
async def _clear_stored_sandbox_id(session_id: str) -> None:
redis = await get_redis_async()
await redis.delete(_sandbox_key(session_id))
async def _try_reconnect(
sandbox_id: str, api_key: str, redis_key: str, timeout: int
sandbox_id: str, session_id: str, api_key: str
) -> "AsyncSandbox | None":
"""Try to reconnect to an existing sandbox. Returns None on failure."""
try:
sandbox = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
if await sandbox.is_running():
redis = await get_redis_async()
await redis.expire(redis_key, timeout)
# Refresh TTL so an active session cannot lose its sandbox_id at expiry.
await _set_stored_sandbox_id(session_id, sandbox_id)
return sandbox
except Exception as exc:
logger.warning("[E2B] Reconnect to %.12s failed: %s", sandbox_id, exc)
# Stale — clear Redis so a new sandbox can be created.
redis = await get_redis_async()
await redis.delete(redis_key)
# Stale — clear the sandbox_id from Redis so a new one can be created.
await _clear_stored_sandbox_id(session_id)
return None
async def get_or_create_sandbox(
session_id: str,
api_key: str,
timeout: int,
template: str = "base",
timeout: int = 43200,
on_timeout: Literal["kill", "pause"] = "pause",
) -> AsyncSandbox:
"""Return the existing E2B sandbox for *session_id* or create a new one.
The sandbox_id is persisted in Redis so the same sandbox is reused
across turns. Concurrent calls for the same session are serialised
via a Redis ``SET NX`` creation lock.
The sandbox key in Redis serves a dual purpose: it stores the sandbox_id
and acts as a creation lock via a ``"creating"`` sentinel value. This
removes the need for a separate lock key.
*timeout* controls how long the e2b sandbox may run continuously before
the ``on_timeout`` lifecycle rule fires (default: 3 h).
*on_timeout* controls what happens on timeout: ``"pause"`` (default, free)
or ``"kill"``.
"""
redis = await get_redis_async()
redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
key = _sandbox_key(session_id)
# 1. Try reconnecting to an existing sandbox.
raw = await redis.get(redis_key)
if raw:
sandbox_id = raw if isinstance(raw, str) else raw.decode()
if sandbox_id != _CREATING:
sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
for _ in range(_MAX_WAIT_ATTEMPTS):
raw = await redis.get(key)
value = raw.decode() if isinstance(raw, bytes) else raw
if value and value != _CREATING_SENTINEL:
# Existing sandbox ID — try to reconnect (auto-resumes if paused).
sandbox = await _try_reconnect(value, session_id, api_key)
if sandbox:
logger.info(
"[E2B] Reconnected to %.12s for session %.12s",
sandbox_id,
value,
session_id,
)
return sandbox
# _try_reconnect cleared the key — loop to create a new sandbox.
continue
# 2. Claim creation lock. If another request holds it, wait for the result.
claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
if not claimed:
for _ in range(_MAX_WAIT_ATTEMPTS):
if value == _CREATING_SENTINEL:
# Another coroutine is creating — wait for it to finish.
await asyncio.sleep(0.5)
raw = await redis.get(redis_key)
if not raw:
break # Lock expired — fall through to retry creation
sandbox_id = raw if isinstance(raw, str) else raw.decode()
if sandbox_id != _CREATING:
sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
if sandbox:
return sandbox
break # Stale sandbox cleared — fall through to create
continue
# Try to claim creation lock again after waiting.
claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
if not claimed:
# Another process may have created a sandbox — try to use it.
raw = await redis.get(redis_key)
if raw:
sandbox_id = raw if isinstance(raw, str) else raw.decode()
if sandbox_id != _CREATING:
sandbox = await _try_reconnect(
sandbox_id, api_key, redis_key, timeout
)
if sandbox:
return sandbox
raise RuntimeError(
f"Could not acquire E2B creation lock for session {session_id[:12]}"
)
# 3. Create a new sandbox.
try:
sandbox = await AsyncSandbox.create(
template=template, api_key=api_key, timeout=timeout
# No sandbox and no active creation — atomically claim the creation slot.
claimed = await redis.set(
key, _CREATING_SENTINEL, nx=True, ex=_CREATION_LOCK_TTL
)
except Exception:
await redis.delete(redis_key)
raise
if not claimed:
# Race lost — another coroutine just claimed it.
await asyncio.sleep(0.1)
continue
await redis.setex(redis_key, timeout, sandbox.sandbox_id)
logger.info(
"[E2B] Created sandbox %.12s for session %.12s",
sandbox.sandbox_id,
session_id,
)
return sandbox
# We hold the slot — create the sandbox.
try:
sandbox = await AsyncSandbox.create(
template=template,
api_key=api_key,
timeout=timeout,
lifecycle={"on_timeout": on_timeout},
)
try:
await _set_stored_sandbox_id(session_id, sandbox.sandbox_id)
except Exception:
# Redis save failed — kill the sandbox to avoid leaking it.
with contextlib.suppress(Exception):
await sandbox.kill()
raise
except Exception:
# Release the creation slot so other callers can proceed.
await redis.delete(key)
raise
logger.info(
"[E2B] Created sandbox %.12s for session %.12s",
sandbox.sandbox_id,
session_id,
)
return sandbox
raise RuntimeError(f"Could not acquire E2B sandbox for session {session_id[:12]}")
async def kill_sandbox(session_id: str, api_key: str) -> bool:
"""Kill the E2B sandbox for *session_id* and clean up its Redis entry.
async def _act_on_sandbox(
session_id: str,
api_key: str,
action: str,
fn: Callable[[AsyncSandbox], Awaitable[Any]],
*,
clear_stored_id: bool = False,
) -> bool:
"""Connect to the sandbox for *session_id* and run *fn* on it.
Returns ``True`` if a sandbox was found and killed, ``False`` otherwise.
Safe to call even when no sandbox exists for the session.
Shared by ``pause_sandbox`` and ``kill_sandbox``. Returns ``True`` on
success, ``False`` when no sandbox is found or the action fails.
If *clear_stored_id* is ``True``, the sandbox_id is removed from Redis
only after the action succeeds so a failed kill can be retried.
"""
redis = await get_redis_async()
redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
raw = await redis.get(redis_key)
if not raw:
sandbox_id = await _get_stored_sandbox_id(session_id)
if not sandbox_id:
return False
sandbox_id = raw if isinstance(raw, str) else raw.decode()
await redis.delete(redis_key)
if sandbox_id == _CREATING:
return False
async def _run() -> None:
await fn(await AsyncSandbox.connect(sandbox_id, api_key=api_key))
try:
async def _connect_and_kill():
sandbox = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
await sandbox.kill()
await asyncio.wait_for(_connect_and_kill(), timeout=10)
await asyncio.wait_for(_run(), timeout=_E2B_API_TIMEOUT_SECONDS)
if clear_stored_id:
await _clear_stored_sandbox_id(session_id)
logger.info(
"[E2B] Killed sandbox %.12s for session %.12s",
"[E2B] %s sandbox %.12s for session %.12s",
action.capitalize(),
sandbox_id,
session_id,
)
return True
except Exception as exc:
logger.warning(
"[E2B] Failed to kill sandbox %.12s for session %.12s: %s",
"[E2B] Failed to %s sandbox %.12s for session %.12s: %s",
action,
sandbox_id,
session_id,
exc,
)
return False
async def pause_sandbox(session_id: str, api_key: str) -> bool:
"""Pause the E2B sandbox for *session_id* to stop billing between turns.
Paused sandboxes cost nothing and are resumed automatically by
``get_or_create_sandbox()`` on the next turn (via ``AsyncSandbox.connect()``).
The sandbox_id is kept in Redis so reconnection works seamlessly.
Prefer ``pause_sandbox_direct()`` when the sandbox object is already in
scope — it skips the Redis lookup and reconnect round-trip.
Returns ``True`` if the sandbox was found and paused, ``False`` otherwise.
Safe to call even when no sandbox exists for the session.
"""
return await _act_on_sandbox(session_id, api_key, "pause", lambda sb: sb.pause())
async def pause_sandbox_direct(sandbox: "AsyncSandbox", session_id: str) -> bool:
"""Pause an already-connected sandbox without a reconnect round-trip.
Use this in callers that already hold the live sandbox object (e.g. turn
teardown in ``service.py``). Saves the Redis lookup and
``AsyncSandbox.connect()`` call that ``pause_sandbox()`` would make.
Returns ``True`` on success, ``False`` on failure or timeout.
"""
try:
await asyncio.wait_for(sandbox.pause(), timeout=_E2B_API_TIMEOUT_SECONDS)
logger.info(
"[E2B] Paused sandbox %.12s for session %.12s",
sandbox.sandbox_id,
session_id,
)
return True
except Exception as exc:
logger.warning(
"[E2B] Failed to pause sandbox %.12s for session %.12s: %s",
sandbox.sandbox_id,
session_id,
exc,
)
return False
async def kill_sandbox(
session_id: str,
api_key: str,
) -> bool:
"""Kill the E2B sandbox for *session_id* and clear its Redis entry.
Returns ``True`` if a sandbox was found and killed, ``False`` otherwise.
Safe to call even when no sandbox exists for the session.
"""
return await _act_on_sandbox(
session_id,
api_key,
"kill",
lambda sb: sb.kill(),
clear_stored_id=True,
)

View File

@@ -1,6 +1,12 @@
"""Tests for e2b_sandbox: get_or_create_sandbox, _try_reconnect, kill_sandbox.
Uses mock Redis and mock AsyncSandbox — no external dependencies.
sandbox_id is stored in Redis under _SANDBOX_KEY_PREFIX + session_id.
The same key doubles as a creation lock via a "creating" sentinel value.
Tests mock:
- ``get_redis_async`` (sandbox key storage + creation lock sentinel)
- ``AsyncSandbox`` (E2B SDK)
Tests are synchronous (using asyncio.run) to avoid conflicts with the
session-scoped event loop in conftest.py.
"""
@@ -11,36 +17,50 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from .e2b_sandbox import (
_CREATING,
_SANDBOX_REDIS_PREFIX,
_CREATING_SENTINEL,
_try_reconnect,
get_or_create_sandbox,
kill_sandbox,
pause_sandbox,
pause_sandbox_direct,
)
_KEY = f"{_SANDBOX_REDIS_PREFIX}sess-123"
_SESSION_ID = "sess-123"
_API_KEY = "test-api-key"
_SANDBOX_ID = "sb-abc"
_TIMEOUT = 300
def _mock_sandbox(sandbox_id: str = "sb-abc", running: bool = True) -> MagicMock:
def _mock_sandbox(sandbox_id: str = _SANDBOX_ID, running: bool = True) -> MagicMock:
sb = MagicMock()
sb.sandbox_id = sandbox_id
sb.is_running = AsyncMock(return_value=running)
sb.pause = AsyncMock()
sb.kill = AsyncMock()
return sb
def _mock_redis(get_val: str | bytes | None = None, set_nx_result: bool = True):
def _mock_redis(
set_nx_result: bool = True,
stored_sandbox_id: str | None = None,
) -> AsyncMock:
"""Create a mock redis client.
*stored_sandbox_id* is returned by ``get()`` calls (simulates the sandbox_id
stored under the ``_SANDBOX_KEY_PREFIX`` key). ``set_nx_result`` controls
whether the creation-slot ``SET NX`` succeeds.
If *stored_sandbox_id* is None the key is absent (no sandbox, no lock).
"""
r = AsyncMock()
r.get = AsyncMock(return_value=get_val)
raw = stored_sandbox_id.encode() if stored_sandbox_id else None
r.get = AsyncMock(return_value=raw)
r.set = AsyncMock(return_value=set_nx_result)
r.setex = AsyncMock()
r.delete = AsyncMock()
r.expire = AsyncMock()
return r
def _patch_redis(redis):
def _patch_redis(redis: AsyncMock):
return patch(
"backend.copilot.tools.e2b_sandbox.get_redis_async",
new_callable=AsyncMock,
@@ -55,6 +75,7 @@ def _patch_redis(redis):
class TestTryReconnect:
def test_reconnect_success(self):
"""Returns the sandbox when it connects and is running; refreshes Redis TTL."""
sb = _mock_sandbox()
redis = _mock_redis()
with (
@@ -62,36 +83,39 @@ class TestTryReconnect:
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
result = asyncio.run(_try_reconnect(_SANDBOX_ID, _SESSION_ID, _API_KEY))
assert result is sb
redis.expire.assert_awaited_once_with(_KEY, _TIMEOUT)
redis.delete.assert_not_awaited()
# TTL must be refreshed so an active session cannot lose its key at expiry.
redis.set.assert_awaited_once()
def test_reconnect_not_running_clears_key(self):
def test_reconnect_not_running_clears_redis(self):
"""Clears sandbox_id in Redis when the sandbox is no longer running."""
sb = _mock_sandbox(running=False)
redis = _mock_redis()
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
result = asyncio.run(_try_reconnect(_SANDBOX_ID, _SESSION_ID, _API_KEY))
assert result is None
redis.delete.assert_awaited_once_with(_KEY)
redis.delete.assert_awaited_once()
def test_reconnect_exception_clears_key(self):
redis = _mock_redis()
def test_reconnect_exception_clears_redis(self):
"""Clears sandbox_id in Redis when connect raises an exception."""
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
result = asyncio.run(_try_reconnect(_SANDBOX_ID, _SESSION_ID, _API_KEY))
assert result is None
redis.delete.assert_awaited_once_with(_KEY)
redis.delete.assert_awaited_once()
# ---------------------------------------------------------------------------
@@ -103,38 +127,63 @@ class TestGetOrCreateSandbox:
def test_reconnect_existing(self):
"""When Redis has a valid sandbox_id, reconnect to it."""
sb = _mock_sandbox()
redis = _mock_redis(get_val="sb-abc")
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
assert result is sb
mock_cls.create.assert_not_called()
# redis.set called once to refresh TTL, not to claim a creation slot
redis.set.assert_awaited_once()
def test_create_new_when_no_key(self):
"""When Redis is empty, claim lock and create a new sandbox."""
sb = _mock_sandbox("sb-new")
redis = _mock_redis(get_val=None, set_nx_result=True)
def test_create_new_when_no_stored_id(self):
"""When Redis has no sandbox_id, claim slot and create a new sandbox."""
new_sb = _mock_sandbox("sb-new")
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.create = AsyncMock(return_value=sb)
mock_cls.create = AsyncMock(return_value=new_sb)
result = asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
assert result is sb
redis.setex.assert_awaited_once_with(_KEY, _TIMEOUT, "sb-new")
assert result is new_sb
mock_cls.create.assert_awaited_once()
# Verify lifecycle param is set
_, kwargs = mock_cls.create.call_args
assert kwargs.get("lifecycle") == {"on_timeout": "pause"}
# sandbox_id should be saved to Redis
redis.set.assert_awaited()
def test_create_failure_clears_lock(self):
"""If sandbox creation fails, the Redis lock is deleted."""
redis = _mock_redis(get_val=None, set_nx_result=True)
def test_create_with_on_timeout_kill(self):
"""on_timeout='kill' is passed through to AsyncSandbox.create."""
new_sb = _mock_sandbox("sb-new")
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.create = AsyncMock(return_value=new_sb)
asyncio.run(
get_or_create_sandbox(
_SESSION_ID, _API_KEY, timeout=_TIMEOUT, on_timeout="kill"
)
)
_, kwargs = mock_cls.create.call_args
assert kwargs.get("lifecycle") == {"on_timeout": "kill"}
def test_create_failure_releases_slot(self):
"""If sandbox creation fails, the Redis creation slot is deleted."""
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
@@ -142,17 +191,53 @@ class TestGetOrCreateSandbox:
mock_cls.create = AsyncMock(side_effect=RuntimeError("quota"))
with pytest.raises(RuntimeError, match="quota"):
asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
redis.delete.assert_awaited_once_with(_KEY)
redis.delete.assert_awaited_once()
def test_wait_for_lock_then_reconnect(self):
"""When another process holds the lock, wait and reconnect."""
def test_redis_save_failure_kills_sandbox_and_releases_slot(self):
"""If Redis save fails after creation, sandbox is killed and slot released."""
new_sb = _mock_sandbox("sb-new")
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
# First set() call = creation slot SET NX (returns True).
# Second set() call = sandbox_id save (raises).
call_count = 0
async def _set_side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return True # creation slot claimed
raise RuntimeError("redis error")
redis.set = AsyncMock(side_effect=_set_side_effect)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.create = AsyncMock(return_value=new_sb)
with pytest.raises(RuntimeError, match="redis error"):
asyncio.run(
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
# Sandbox must be killed to avoid leaking it
new_sb.kill.assert_awaited_once()
# Creation slot must always be released
redis.delete.assert_awaited_once()
def test_wait_for_creating_sentinel_then_reconnect(self):
"""When the key holds the 'creating' sentinel, wait then reconnect."""
sb = _mock_sandbox("sb-other")
redis = _mock_redis()
redis.get = AsyncMock(side_effect=[_CREATING, "sb-other"])
# First get() returns the sentinel; second returns the real ID.
redis = AsyncMock()
creating_raw = _CREATING_SENTINEL.encode()
redis.get = AsyncMock(side_effect=[creating_raw, b"sb-other"])
redis.set = AsyncMock(return_value=False)
redis.delete = AsyncMock()
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
@@ -163,16 +248,21 @@ class TestGetOrCreateSandbox:
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
assert result is sb
def test_stale_reconnect_clears_and_creates(self):
"""When stored sandbox is stale, clear key and create a new one."""
"""When stored sandbox is stale (not running), clear it and create a new one."""
stale_sb = _mock_sandbox("sb-stale", running=False)
new_sb = _mock_sandbox("sb-fresh")
redis = _mock_redis(get_val="sb-stale", set_nx_result=True)
# First get() returns stale id (for reconnect check), then None (after clear).
redis = AsyncMock()
redis.get = AsyncMock(side_effect=[b"sb-stale", None])
redis.set = AsyncMock(return_value=True)
redis.delete = AsyncMock()
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
@@ -180,10 +270,11 @@ class TestGetOrCreateSandbox:
mock_cls.connect = AsyncMock(return_value=stale_sb)
mock_cls.create = AsyncMock(return_value=new_sb)
result = asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
assert result is new_sb
# Redis delete called at least once to clear stale id
redis.delete.assert_awaited()
@@ -194,70 +285,48 @@ class TestGetOrCreateSandbox:
class TestKillSandbox:
def test_kill_existing_sandbox(self):
"""Kill a running sandbox and clean up Redis."""
"""Kill a running sandbox and clear its Redis entry."""
sb = _mock_sandbox()
sb.kill = AsyncMock()
redis = _mock_redis(get_val="sb-abc")
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
assert result is True
redis.delete.assert_awaited_once_with(_KEY)
sb.kill.assert_awaited_once()
# Redis key cleared after successful kill
redis.delete.assert_awaited_once()
def test_kill_no_sandbox(self):
"""No-op when no sandbox exists in Redis."""
redis = _mock_redis(get_val=None)
"""No-op when Redis has no sandbox_id."""
redis = _mock_redis(stored_sandbox_id=None)
with _patch_redis(redis):
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
assert result is False
redis.delete.assert_not_awaited()
def test_kill_creating_state(self):
"""Clears Redis key but returns False when sandbox is still being created."""
redis = _mock_redis(get_val=_CREATING)
with _patch_redis(redis):
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
def test_kill_connect_failure_keeps_redis(self):
"""Returns False and leaves Redis entry intact when connect/kill fails.
assert result is False
redis.delete.assert_awaited_once_with(_KEY)
def test_kill_connect_failure(self):
"""Returns False and cleans Redis if connect/kill fails."""
redis = _mock_redis(get_val="sb-abc")
Keeping the sandbox_id in Redis allows the kill to be retried.
"""
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
assert result is False
redis.delete.assert_awaited_once_with(_KEY)
redis.delete.assert_not_awaited()
def test_kill_with_bytes_redis_value(self):
"""Redis may return bytes — kill_sandbox should decode correctly."""
sb = _mock_sandbox()
sb.kill = AsyncMock()
redis = _mock_redis(get_val=b"sb-abc")
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
assert result is True
sb.kill.assert_awaited_once()
def test_kill_timeout_returns_false(self):
"""Returns False when E2B API calls exceed the 10s timeout."""
redis = _mock_redis(get_val="sb-abc")
def test_kill_timeout_keeps_redis(self):
"""Returns False and leaves Redis entry intact when the E2B call times out."""
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
_patch_redis(redis),
patch(
@@ -266,7 +335,146 @@ class TestKillSandbox:
side_effect=asyncio.TimeoutError,
),
):
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
assert result is False
redis.delete.assert_not_awaited()
def test_kill_creating_sentinel_returns_false(self):
"""No-op when the key holds the 'creating' sentinel (no real sandbox yet)."""
redis = _mock_redis(stored_sandbox_id=_CREATING_SENTINEL)
with _patch_redis(redis):
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
assert result is False
# ---------------------------------------------------------------------------
# pause_sandbox
# ---------------------------------------------------------------------------
class TestPauseSandbox:
def test_pause_existing_sandbox(self):
"""Pause a running sandbox; Redis sandbox_id is preserved."""
sb = _mock_sandbox()
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
assert result is True
sb.pause.assert_awaited_once()
# sandbox_id should remain in Redis (not cleared on pause)
redis.delete.assert_not_awaited()
def test_pause_no_sandbox(self):
"""No-op when Redis has no sandbox_id."""
redis = _mock_redis(stored_sandbox_id=None)
with _patch_redis(redis):
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
assert result is False
def test_pause_connect_failure(self):
"""Returns False if connect fails."""
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
assert result is False
def test_pause_creating_sentinel_returns_false(self):
"""No-op when the key holds the 'creating' sentinel (no real sandbox yet)."""
redis = _mock_redis(stored_sandbox_id=_CREATING_SENTINEL)
with _patch_redis(redis):
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
assert result is False
def test_pause_timeout_returns_false(self):
"""Returns False and preserves Redis entry when the E2B API call times out."""
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
_patch_redis(redis),
patch(
"backend.copilot.tools.e2b_sandbox.asyncio.wait_for",
new_callable=AsyncMock,
side_effect=asyncio.TimeoutError,
),
):
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
assert result is False
# sandbox_id must remain in Redis so the next turn can reconnect
redis.delete.assert_not_awaited()
def test_pause_then_reconnect_reuses_sandbox(self):
"""After pause, get_or_create_sandbox reconnects the same sandbox.
Covers the pause->reconnect cycle: connect() auto-resumes a paused
sandbox, and is_running() returns True once resume completes, so the
same sandbox_id is reused rather than a new one being created.
"""
sb = _mock_sandbox(_SANDBOX_ID)
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
# Step 1: pause the sandbox
paused = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
assert paused is True
sb.pause.assert_awaited_once()
# Step 2: reconnect on next turn -- same sandbox should be returned
result = asyncio.run(
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
assert result is sb
mock_cls.create.assert_not_called()
# ---------------------------------------------------------------------------
# pause_sandbox_direct
# ---------------------------------------------------------------------------
class TestPauseSandboxDirect:
def test_pause_direct_success(self):
"""Pauses the sandbox directly without a Redis lookup or reconnect."""
sb = _mock_sandbox()
result = asyncio.run(pause_sandbox_direct(sb, _SESSION_ID))
assert result is True
sb.pause.assert_awaited_once()
def test_pause_direct_failure_returns_false(self):
"""Returns False when sandbox.pause() raises."""
sb = _mock_sandbox()
sb.pause = AsyncMock(side_effect=RuntimeError("e2b error"))
result = asyncio.run(pause_sandbox_direct(sb, _SESSION_ID))
assert result is False
def test_pause_direct_timeout_returns_false(self):
"""Returns False when sandbox.pause() exceeds the 10s timeout."""
sb = _mock_sandbox()
with patch(
"backend.copilot.tools.e2b_sandbox.asyncio.wait_for",
new_callable=AsyncMock,
side_effect=asyncio.TimeoutError,
):
result = asyncio.run(pause_sandbox_direct(sb, _SESSION_ID))
assert result is False
redis.delete.assert_awaited_once_with(_KEY)

View File

@@ -1,32 +1,20 @@
"""EditAgentTool - Edits existing agents using natural language."""
"""EditAgentTool - Edits existing agents using pre-built JSON."""
import logging
from typing import Any
from backend.copilot.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,
generate_agent_patch,
get_agent_as_json,
get_user_message_for_error,
save_agent_to_library,
)
from .agent_generator import get_agent_as_json
from .agent_generator.pipeline import fetch_library_agents, fix_validate_and_save
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
ToolResponseBase,
)
from .models import ErrorResponse, ToolResponseBase
logger = logging.getLogger(__name__)
class EditAgentTool(BaseTool):
"""Tool for editing existing agents using natural language."""
"""Tool for editing existing agents using pre-built JSON."""
@property
def name(self) -> str:
@@ -35,11 +23,12 @@ class EditAgentTool(BaseTool):
@property
def description(self) -> str:
return (
"Edit an existing agent from the user's library using natural language. "
"Generates updates to the agent while preserving unchanged parts. "
"\n\nIMPORTANT: Before calling this tool, if the changes involve adding new "
"Edit an existing agent. Pass `agent_json` with the complete "
"updated agent JSON you generated. The tool validates, auto-fixes, "
"and saves.\n\n"
"IMPORTANT: Before calling this tool, if the changes involve adding new "
"functionality, search for relevant existing agents using find_library_agent "
"that could be used as building blocks. Pass their IDs in library_agent_ids."
"that could be used as building blocks."
)
@property
@@ -58,26 +47,20 @@ class EditAgentTool(BaseTool):
"Can be a graph ID or library agent ID."
),
},
"changes": {
"type": "string",
"agent_json": {
"type": "object",
"description": (
"Natural language description of what changes to make. "
"Be specific about what to add, remove, or modify."
),
},
"context": {
"type": "string",
"description": (
"Additional context or answers to previous clarifying questions."
"Complete updated agent JSON to validate and save. "
"Must contain 'nodes' and 'links'. "
"Include 'name' and/or 'description' if they need "
"to be updated."
),
},
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"List of library agent IDs to use as building blocks for the changes. "
"If adding new functionality, search for relevant agents using "
"find_library_agent first, then pass their IDs here."
"List of library agent IDs to use as building blocks for the changes."
),
},
"save": {
@@ -89,7 +72,7 @@ class EditAgentTool(BaseTool):
"default": True,
},
},
"required": ["agent_id", "changes"],
"required": ["agent_id", "agent_json"],
}
async def _execute(
@@ -98,36 +81,39 @@ class EditAgentTool(BaseTool):
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the edit_agent tool.
Flow:
1. Fetch the current agent
2. Generate updated agent (external service handles fixing and validation)
3. Preview or save based on the save parameter
"""
agent_id = kwargs.get("agent_id", "").strip()
changes = kwargs.get("changes", "").strip()
context = kwargs.get("context", "")
library_agent_ids = kwargs.get("library_agent_ids", [])
save = kwargs.get("save", True)
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
session_id = session.session_id if session else None
if not agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",
error="Missing agent_id parameter",
error="missing_agent_id",
session_id=session_id,
)
if not changes:
if not agent_json:
return ErrorResponse(
message="Please describe what changes you want to make.",
error="Missing changes parameter",
message=(
"Please provide agent_json with the complete updated agent graph."
),
error="missing_agent_json",
session_id=session_id,
)
current_agent = await get_agent_as_json(agent_id, user_id)
save = kwargs.get("save", True)
library_agent_ids = kwargs.get("library_agent_ids", [])
nodes = agent_json.get("nodes", [])
if not nodes:
return ErrorResponse(
message="The agent JSON has no nodes.",
error="empty_agent",
session_id=session_id,
)
# Preserve original agent's ID
current_agent = await get_agent_as_json(agent_id, user_id)
if current_agent is None:
return ErrorResponse(
message=f"Could not find agent with ID '{agent_id}' in your library.",
@@ -135,142 +121,19 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Fetch library agents by IDs if provided
library_agents = None
if user_id and library_agent_ids:
try:
from .agent_generator import get_library_agents_by_ids
agent_json["id"] = current_agent.get("id", agent_id)
agent_json["version"] = current_agent.get("version", 1)
agent_json.setdefault("is_active", True)
graph_id = current_agent.get("id")
# Filter out the current agent being edited
filtered_ids = [id for id in library_agent_ids if id != graph_id]
# Fetch library agents for AgentExecutorBlock validation
library_agents = await fetch_library_agents(user_id, library_agent_ids)
library_agents = await get_library_agents_by_ids(
user_id=user_id,
agent_ids=filtered_ids,
)
logger.debug(
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents by IDs: {e}")
update_request = changes
if context:
update_request = f"{changes}\n\nAdditional context:\n{context}"
try:
result = await generate_agent_patch(
update_request,
current_agent,
library_agents,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
"Agent editing is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
if result is None:
return ErrorResponse(
message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.",
error="update_generation_failed",
details={"agent_id": agent_id, "changes": changes[:100]},
session_id=session_id,
)
# Check if the result is an error from the external service
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="generate the changes",
llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
validation_message="The generated changes failed validation. Please try rephrasing your request.",
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"update_generation_failed:{error_type}",
details={
"agent_id": agent_id,
"changes": changes[:100],
"service_error": error_msg,
"error_type": error_type,
},
session_id=session_id,
)
if result.get("type") == "clarifying_questions":
questions = result.get("questions", [])
return ClarificationNeededResponse(
message=(
"I need some more information about the changes. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
],
session_id=session_id,
)
updated_agent = result
agent_name = updated_agent.get("name", "Updated Agent")
agent_description = updated_agent.get("description", "")
node_count = len(updated_agent.get("nodes", []))
link_count = len(updated_agent.get("links", []))
if not save:
return AgentPreviewResponse(
message=(
f"I've updated the agent. "
f"The agent now has {node_count} blocks. "
f"Review it and call edit_agent with save=true to save the changes."
),
agent_json=updated_agent,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
try:
created_graph, library_agent = await save_agent_to_library(
updated_agent, user_id, is_update=True
)
return AgentSavedResponse(
message=f"Updated agent '{created_graph.name}' has been saved to your library!",
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to save the updated agent: {str(e)}",
error="save_failed",
details={"exception": str(e)},
session_id=session_id,
)
return await fix_validate_and_save(
agent_json,
user_id=user_id,
session_id=session_id,
save=save,
is_update=True,
default_name="Updated Agent",
library_agents=library_agents,
)

View File

@@ -32,7 +32,7 @@ COPILOT_EXCLUDED_BLOCK_TYPES = {
BlockType.NOTE, # Visual annotation only - no runtime behavior
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
BlockType.MCP_TOOL, # Has dedicated run_mcp_tool tool with proper discovery + auth flow
BlockType.MCP_TOOL, # Has dedicated run_mcp_tool tool with discovery + auth flow
}
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
@@ -72,6 +72,15 @@ class FindBlockTool(BaseTool):
"Use keywords like 'email', 'http', 'text', 'ai', etc."
),
},
"include_schemas": {
"type": "boolean",
"description": (
"If true, include full input_schema and output_schema "
"for each block. Use when generating agent JSON that "
"needs block schemas. Default is false."
),
"default": False,
},
},
"required": ["query"],
}
@@ -99,6 +108,7 @@ class FindBlockTool(BaseTool):
ErrorResponse: Error message
"""
query = kwargs.get("query", "").strip()
include_schemas = kwargs.get("include_schemas", False)
session_id = session.session_id
if not query:
@@ -143,15 +153,21 @@ class FindBlockTool(BaseTool):
):
continue
blocks.append(
BlockInfoSummary(
id=block_id,
name=block.name,
description=block.description or "",
categories=[c.value for c in block.categories],
)
summary = BlockInfoSummary(
id=block_id,
name=block.name,
description=block.optimized_description or block.description or "",
categories=[c.value for c in block.categories],
)
if include_schemas:
info = block.get_info()
summary.input_schema = info.inputSchema
summary.output_schema = info.outputSchema
summary.static_output = info.staticOutput
blocks.append(summary)
if len(blocks) >= _TARGET_RESULTS:
break

View File

@@ -25,6 +25,7 @@ def make_mock_block(
input_schema: dict | None = None,
output_schema: dict | None = None,
credentials_fields: dict | None = None,
static_output: bool = False,
):
"""Create a mock block for testing."""
mock = MagicMock()
@@ -33,6 +34,7 @@ def make_mock_block(
mock.description = f"{name} description"
mock.block_type = block_type
mock.disabled = disabled
mock.static_output = static_output
mock.input_schema = MagicMock()
mock.input_schema.jsonschema.return_value = input_schema or {
"properties": {},
@@ -42,6 +44,15 @@ def make_mock_block(
mock.output_schema = MagicMock()
mock.output_schema.jsonschema.return_value = output_schema or {}
mock.categories = []
mock.optimized_description = None
# Mock get_info() for include_schemas support
mock_info = MagicMock()
mock_info.inputSchema = input_schema or {"properties": {}, "required": []}
mock_info.outputSchema = output_schema or {}
mock_info.staticOutput = static_output
mock.get_info.return_value = mock_info
return mock
@@ -399,3 +410,92 @@ class TestFindBlockFiltering:
f"Average chars per block ({avg_chars}) exceeds 500. "
f"Total response: {total_chars} chars for {response.count} blocks."
)
@pytest.mark.asyncio(loop_scope="session")
async def test_include_schemas_false_omits_schemas(self):
"""Without include_schemas, schemas should be empty dicts."""
session = make_session(user_id=_TEST_USER_ID)
input_schema = {"properties": {"url": {"type": "string"}}, "required": ["url"]}
output_schema = {"properties": {"result": {"type": "string"}}}
search_results = [{"content_id": "block-1", "score": 0.9}]
block = make_mock_block(
"block-1",
"Test Block",
BlockType.STANDARD,
input_schema=input_schema,
output_schema=output_schema,
)
mock_search_db = MagicMock()
mock_search_db.unified_hybrid_search = AsyncMock(
return_value=(search_results, 1)
)
with (
patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
),
patch(
"backend.copilot.tools.find_block.get_block",
return_value=block,
),
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
query="test",
include_schemas=False,
)
assert isinstance(response, BlockListResponse)
assert response.blocks[0].input_schema == {}
assert response.blocks[0].output_schema == {}
assert response.blocks[0].static_output is False
@pytest.mark.asyncio(loop_scope="session")
async def test_include_schemas_true_populates_schemas(self):
"""With include_schemas=true, schemas should be populated from block info."""
session = make_session(user_id=_TEST_USER_ID)
input_schema = {"properties": {"url": {"type": "string"}}, "required": ["url"]}
output_schema = {"properties": {"result": {"type": "string"}}}
search_results = [{"content_id": "block-1", "score": 0.9}]
block = make_mock_block(
"block-1",
"Test Block",
BlockType.STANDARD,
input_schema=input_schema,
output_schema=output_schema,
static_output=True,
)
mock_search_db = MagicMock()
mock_search_db.unified_hybrid_search = AsyncMock(
return_value=(search_results, 1)
)
with (
patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
),
patch(
"backend.copilot.tools.find_block.get_block",
return_value=block,
),
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
query="test",
include_schemas=True,
)
assert isinstance(response, BlockListResponse)
assert response.blocks[0].input_schema == input_schema
assert response.blocks[0].output_schema == output_schema
assert response.blocks[0].static_output is True

View File

@@ -22,6 +22,9 @@ class FindLibraryAgentTool(BaseTool):
"Search for or list agents in the user's library. Use this to find "
"agents the user has already added to their library, including agents "
"they created or added from the marketplace. "
"When creating agents with sub-agent composition, use this to get "
"the agent's graph_id, graph_version, input_schema, and output_schema "
"needed for AgentExecutorBlock nodes. "
"Omit the query to list all agents."
)

Some files were not shown because too many files have changed in this diff Show More