Compare commits

...

26 Commits

Author SHA1 Message Date
Otto (AGPT)
e9afd9fa01 fix: cast OnboardingStep enum to text in funnel view
The completedSteps column is a platform."OnboardingStep" enum array.
UNNEST produces enum values that can't be compared directly to text
from the VALUES clause. Adding ::text cast fixes the type mismatch.
2026-03-13 11:55:37 +00:00
Zamil Majdy
ddb4f6e9de fix(analytics): address second batch of PR review comments
- user_onboarding_funnel: build complete 22-step grid with VALUES CTE
  so zero-completion steps are always present, fixing LAG comparisons
  against wrong predecessors; update docs to reflect all 22 steps
- users_activities: use COUNT(DISTINCT "id") for agent_count to avoid
  counting multiple version rows per graph; add COALESCE(..., 0) for
  agent_count, unique_agent_runs, agent_runs; update docs column list
  to include node_execution_incomplete and node_execution_review
- generate_views: update Step 3 comment to clarify NOLOGIN role needs
  WITH LOGIN PASSWORD not just WITH PASSWORD; add fail-fast validation
  for unknown --only view names with helpful error message
2026-03-12 00:47:55 +07:00
Zamil Majdy
f585d97928 fix(analytics): move new status columns to end of users_activities SELECT
CREATE OR REPLACE VIEW requires existing columns to stay in position.
Moving node_execution_incomplete and node_execution_review after
is_active_after_7d so the replacement doesn't shift existing columns.
2026-03-12 00:01:40 +07:00
Zamil Majdy
7d39234fdd fix(analytics): address PR review comments
- user_block_spending: use ->> instead of -> for JSONB field extraction
  before casting to int (avoids runtime cast errors)
- generate_views: create analytics_readonly as NOLOGIN to avoid a
  usable role with a known default password
- generate_views: percent-encode DB credentials in the URI builder so
  passwords with reserved chars (@, :, /) connect correctly
- graph_execution: remove WHERE filter on sensitive_action_safe_mode
  before DISTINCT ON so the latest LibraryAgent version always wins
  (fixes possibly_ai being sticky once any version had the flag set)
- retention_agent: use DISTINCT ON ordered by version DESC instead of
  MAX(name) so renamed agents resolve to their latest name
- retention_login_daily: add 90-day cohort_start filter to first_login
  CTE so the view matches its documented window
- user_onboarding_funnel: map the 8 missing OnboardingStep enum values
  (VISIT_COPILOT, RE_RUN_AGENT, SCHEDULE_AGENT, RUN_AGENTS, RUN_3_DAYS,
  TRIGGER_WEBHOOK, RUN_14_DAYS, RUN_AGENTS_100) to step_order 15-22
- users_activities: use updatedAt instead of createdAt for
  last_agent_save_time; add node_execution_incomplete and
  node_execution_review status columns
2026-03-11 23:48:42 +07:00
Zamil Majdy
6e9d4c4333 perf(analytics): fix fan-out in users_activities view
The original CTEs drove all joins from user_logins, causing a
O(users × executions × node_executions) fan-out that made the view
too heavy for Supabase to serve. Rewrote each CTE to aggregate its
own source table directly by userId, then LEFT JOIN the aggregates
in the final SELECT.
2026-03-11 23:39:14 +07:00
Zamil Majdy
8aad333a45 refactor(analytics): move generate_views.py to backend, add poetry run analytics-setup/analytics-views scripts 2026-03-11 16:23:29 +07:00
Zamil Majdy
856f0d980d fix(analytics): restrict analytics_readonly to analytics schema only via security_invoker=false views 2026-03-11 16:16:03 +07:00
Zamil Majdy
3c3aadd361 docs(analytics): add step-by-step quick start to generate_views.py docstring 2026-03-11 16:12:22 +07:00
Zamil Majdy
e87a693fdd feat(analytics): auto-load DB creds from backend/.env as fallback 2026-03-11 16:10:31 +07:00
Zamil Majdy
fe265c10d4 refactor(analytics): generate setup.sql via --setup flag, gitignore it 2026-03-11 16:01:52 +07:00
Zamil Majdy
5d00a94693 chore(analytics): remove auto-generated files, gitignore views.sql 2026-03-11 16:00:48 +07:00
Zamil Majdy
6e1605994d feat(analytics): add documented SQL views with generation script
Introduces an analytics/ layer that wraps production Postgres data in
safe, read-only views exposed under the analytics schema.

- 14 documented query files in queries/ (one per Looker data source)
  covering auth activities, user activity, execution metrics, onboarding
  funnel, and cohort retention (login + execution, weekly + daily)
- setup.sql — one-time schema creation and role/grant setup for the
  analytics_readonly role (auth, platform, analytics schemas)
- generate_views.py — reads queries/*.sql and applies
  CREATE OR REPLACE VIEW analytics.<name> to the database;
  supports --dry-run, --only, and --db-url flags
- views.sql — pre-generated combined reference output
- README.md — full setup, deployment, and integration guide

Looker, PostHog Data Warehouse, and Supabase MCP (for Otto) all
connect to the same analytics.* views instead of raw tables.
2026-03-11 15:36:27 +07:00
Zamil Majdy
15e3980d65 fix(frontend): buffer workspace file downloads to prevent truncation (#12349)
## Summary
- Workspace file downloads (images, CSVs, etc.) were silently truncated
(~10 KB lost from the end) when served through the Next.js proxy
- Root cause: `new NextResponse(response.body)` passes a
`ReadableStream` directly, which Next.js / Vercel silently truncates for
larger files
- Fix: fully buffer with `response.arrayBuffer()` before forwarding, and
set `Content-Length` from the actual buffer size
- Keeps the auth proxy intact — no signed URLs (which would be public
and expire, breaking chat history)

## Root cause verification
Confirmed locally on session `080f27f9-0379-4085-a67a-ee34cc40cd62`:
- Backend `write_workspace_file` logs **978,831 bytes** written
- Direct backend download (`curl
localhost:8006/api/workspace/files/.../download`): **978,831 bytes** 
- Browser download through Next.js proxy: **truncated** 

## Why not signed URLs?
- Signed URLs are effectively public — anyone with the link can download
the file (privacy concern)
- Signed URLs expire, but chat history persists — reopening a
conversation later would show broken downloads
- Buffering is fine: workspace files are capped at 100 MB, Vercel
function memory is 1 GB

## Related
- Discord thread: `#Truncated File Bug` channel
- Related PR #12319 (signed URL approach) — this fix is simpler and
preserves auth

## Test plan
- [ ] Download a workspace file (CSV, PNG, any type) through the copilot
UI
- [ ] Verify downloaded file size matches the original
- [ ] Verify PNGs open correctly and CSVs have all rows

cc @Swiftyos @uberdot @AdarshRawat1
2026-03-10 18:23:51 +00:00
Zamil Majdy
fe9eb2564b feat(copilot): HITL review for sensitive block execution (#12356)
## Summary
- Integrates existing Human-In-The-Loop (HITL) review infrastructure
into CoPilot's direct block execution (`run_block`) for blocks marked
with `is_sensitive_action=True`
- Removes the `PendingHumanReview → AgentGraphExecution` FK constraint
to support synthetic CoPilot session IDs (migration included)
- Adds `ReviewRequiredResponse` model + frontend `ReviewRequiredCard`
component to surface review status in the chat UI
- Auto-approval works within a CoPilot session: once a block is
approved, subsequent executions of the same block in the same session
are auto-approved (using `copilot-session-{session_id}` as
`graph_exec_id` and `copilot-node-{block_id}` as `node_id`)

## Test plan
- [x] All 11 `run_block_test.py` tests pass (3 new sensitive action
tests)
- [ ] Manual: Execute a block with `is_sensitive_action=True` in CoPilot
→ verify ReviewRequiredResponse is returned and rendered
- [ ] Manual: Approve in review panel → re-execute the same block →
verify auto-approval kicks in
- [ ] Manual: Verify non-sensitive blocks still execute without review
2026-03-10 18:20:11 +00:00
Otto
5641cdd3ca fix(backend): update test patches for validate_url → validate_url_host rename (#12358)
bfb843a renamed `validate_url` to `validate_url_host` in
`agent_browser`, `run_mcp_tool`, and MCP routes, but the corresponding
test files still patched the old name, causing `AttributeError` in CI.

Updates all mock patch targets and assertions across 3 test files:
- `agent_browser_test.py`
- `test_run_mcp_tool.py`  
- `mcp/test_routes.py`

---
Co-authored-by: Zamil Majdy (@majdyz) <zamil.majdy@agpt.co>
Co-authored-by: Reinier van der Leer (@Pwuts) <pwuts@agpt.co>
2026-03-10 17:22:11 +00:00
Otto
bfb843a56e Merge commit from fork
* Fix SSRF via user-controlled ollama_host field

Validate ollama_host against BLOCKED_IP_NETWORKS before passing to
ollama.AsyncClient(). The server-configured default (env: OLLAMA_HOST)
is allowed without validation; user-supplied values that differ are
checked for private/internal IP resolution.

Fixes GHSA-6jx2-4h7q-3fx3

* Generalize validate_ollama_host to validate_host; fix description line length

* Rename to validate_untrusted_host with whitelist parameter

* Apply PR suggestion: include whitelist in error message; run formatting

* Move whitelist check after URL normalization; match on netloc

* revert unrelated formatting changes

* Dedup validate_url and validate_untrusted_host; normalize whitelist

* Move _resolve_and_check_blocked after calling functions

* dedup and clean up

* make trusted_hostnames truly optional

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-03-10 15:51:58 +01:00
Abhimanyu Yadav
684845d946 fix(frontend/builder): handle discriminated unions and improve node layout (#12354)
## Summary
- **Discriminated union support (oneOf)**: Added a new `OneOfField`
component that properly
renders Pydantic discriminated unions. Hides the unusable parent object
handle, auto-populates
the discriminator value, shows a dropdown with variant titles (e.g.,
"Username" / "UserId"), and
filters out the internal discriminator field from the form.
Non-discriminated `oneOf` schemas
  fall back to existing `AnyOfField` behavior.
- **Collapsible object outputs**: Object-type outputs with nested keys
(e.g.,
`PersonLookupResponse.Url`, `PersonLookupResponse.profile`) are now
collapsed by default behind a
caret toggle. Nested keys show short names instead of the full
`Parent.Key` prefix.
- **Node layout cleanup**: Removed excessive bottom margin (`mb-6`) from
`FormRenderer`, hide the
Advanced toggle when no advanced fields exist, and add rounded bottom
corners on OUTPUT-type
  blocks.

<img width="440" height="427" alt="Screenshot 2026-03-10 at 11 31 55 AM"
src="https://github.com/user-attachments/assets/06cc5414-4e02-4371-bdeb-1695e7cb2c97"
/>
<img width="371" height="320" alt="Screenshot 2026-03-10 at 11 36 52 AM"
src="https://github.com/user-attachments/assets/1a55f87a-c602-4f4d-b91b-6e49f810e5d5"
/>

  ## Test plan
- [x] Add a Twitter Get User block — verify "Identifier" shows a
dropdown (Username/UserId) with
no unusable parent handle, discriminator field is hidden, and the block
can run without staying
  INCOMPLETE
- [x] Add any block with object outputs (e.g., PersonLookupResponse) —
verify nested keys are
  collapsed by default and expand on click with short labels
- [x] Verify blocks without advanced fields don't show the Advanced
toggle
- [x] Verify existing `anyOf` schemas (optional types, 3+ variant
unions) still render correctly
  - [x] Check OUTPUT-type blocks have rounded bottom corners

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
Co-authored-by: eureka928 <meobius123@gmail.com>
2026-03-10 14:13:32 +00:00
Bently
6a6b23c2e1 fix(frontend): Remove unused Otto Server Action causing 107K+ errors (#12336)
## Summary

Fixes [OPEN-3025](https://linear.app/autogpt/issue/OPEN-3025) —
**107,571+ Server Action errors** in production

Removes the orphaned `askOtto` Server Action that was left behind after
the Otto chat widget removal in PR #12082.

## Problem

Next.js Server Actions that are never imported are excluded from the
server manifest. Old client bundles still reference the action ID,
causing "not found" errors.

**Sentry impact:**
- **BUILDER-3BN:** 107,571 events
- **BUILDER-729:** 285 events  
- **BUILDER-3QH:** 1,611 events
- **36+ users affected**

## Root Cause

1. **Mar 2025:** Otto widget added to `/build` page with `askOtto`
Server Action
2. **Feb 2026:** Otto widget removed (PR #12082), but `actions.ts` left
behind
3. **Result:** Dead code → not in manifest → errors

## Evidence

```bash
# Zero imports across frontend:
grep -r "askOtto" src/ --exclude="actions.ts"
# → No results

# Server manifest missing the action:
cat .next/server/server-reference-manifest.json
# → Only includes login/supabase actions, NOT build/actions
```

## Changes

-  Delete
`autogpt_platform/frontend/src/app/(platform)/build/actions.ts`

## Testing

1. Verify no imports of `askOtto` in codebase 
2. Check Sentry for error drop after deploy
3. Monitor for new "Server Action not found" errors

## Checklist

- [x] Dead code confirmed (zero imports)
- [x] Sentry issues documented
- [x] Clear commit message with context
2026-03-10 09:03:38 +00:00
Dream
d0a1d72e8a fix(frontend/builder): batch undo history for cascading operations (#12344)
## Summary

Fixes undo in the Builder not working correctly when deleting nodes.
When a node is deleted, React Flow fires `onNodesChange` (node removal)
and `onEdgesChange` (cascading edge cleanup) as separate callbacks —
each independently pushing to the undo history stack. This creates
intermediate states that break undo:

- Single undo restores a partial state (e.g. edges pointing to a deleted
node)
- Multiple undos required to fully restore the graph
- Redo also produces inconsistent states

Resolves #10999

### Changes 🏗️

- **`historyStore.ts`** — Added microtask-based batching to
`pushState()`. Multiple calls within the same synchronous execution
(same event loop tick) are coalesced into a single history entry,
keeping only the first pre-change snapshot. Uses `queueMicrotask` so all
cascading store updates from a single user action settle before the
history entry is committed.
- Reset `pendingState` in `initializeHistory()` and `clear()` to prevent
stale batched state from leaking across graph loads or navigation.

**Side benefit:** Copy/paste operations that add multiple nodes and
edges now also produce a single history entry instead of one per
node/edge.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Place 3 blocks A, B, C and connect A→B→C
  - [x] Delete block C (removes node + cascading edge B→C)
  - [x] Delete connection A→B
  - [x] Undo — connection A→B restored (single undo, not multiple)
  - [x] Undo — block C and connection B→C restored
  - [x] Redo — block C removed again with its connections
- [x] Copy/paste multiple connected blocks — single undo reverts entire
paste

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
Co-authored-by: Abhimanyu Yadav <122007096+Abhi1992002@users.noreply.github.com>
2026-03-10 04:55:07 +00:00
Zamil Majdy
f1945d6a2f feat(platform/copilot): @@agptfile: file-ref protocol for tool call inputs + block input toggle (#12332)
## Summary

- **Problem**: When the LLM calls a tool with large file content, it
must rewrite all content token-by-token. This is wasteful since the
files are already accessible on disk.
- **Solution**: Introduces an \`@@agptfile:\` reference protocol. The
LLM passes a file path reference; the processor loads and substitutes
the content before executing the tool.

### Protocol

\`\`\`
@@agptfile:<uri>[<start>-<end>]
\`\`\`

**Supported URI types:**
| URI | Source |
|-----|--------|
| \`workspace://<file_id>\` | Persistent workspace file by ID |
| \`workspace:///<path>\` | Workspace file by virtual path |
| \`/absolute/path\` | Absolute host or sandbox path |

**Line range** is optional; omitting it reads the whole file.

### Backend changes

- Rename \`@file:\` → \`@@agptfile:\` prefix for uniqueness; extract
\`FILE_REF_PREFIX\` constant
- Extract shared execution-context ContextVars into
\`backend/copilot/context.py\` — eliminates duplicate ContextVar objects
that caused \`e2b_file_tools.py\` to always see empty context
- \`tool_adapter.py\` imports ContextVars from \`context.py\` (single
source of truth)
- \`expand_file_refs_in_string\` raises \`FileRefExpansionError\` on
failure (instead of inline error strings), blocking tool execution and
returning a clear error hint to the model
- Tighten URI regex: only expand refs starting with \`workspace://\` or
\`/\`
- Aggregate budget: 1 MB total expansion cap across all refs in one
string
- Per-file cap: 200 KB per individual ref
- Fix \`_read_file_handler\` to pass \`get_sdk_cwd()\` to
\`is_allowed_local_path\` — ephemeral working directory files were
incorrectly blocked
- Fix \`_is_allowed_local\` in \`e2b_file_tools.py\` to pass
\`get_sdk_cwd()\`
- Restrict local path allow-list to \`tool-results/\` subdirectory only
(was entire session project dir)
- Add \`raise_on_error\` param + remove two-pass \`_FILE_REF_ERROR_RE\`
detection
- Update system prompt docs and tool_adapter error messages

### Frontend changes

- \`BlockInputCard\`: hidden by default with Show/Hide toggle + \`mb-2\`
spacing

## Test plan

- [ ] \`poetry run pytest backend/copilot/ -x
--ignore=backend/copilot/sdk/file_ref_integration_test.py\` passes
- [ ] \`@@agptfile:workspace:///<path>[1-50]\` expands correctly in tool
calls
- [ ] Invalid line ranges produce \`[file-ref error: ...]\` inline
messages
- [ ] Files outside \`sdk_cwd\` / \`tool-results/\` are rejected
- [ ] Block input card shows hidden by default with toggle
2026-03-09 18:39:13 +00:00
Zamil Majdy
6491cb1e23 feat(copilot): local agent generation with validation, fixing, MCP & sub-agent support (#12238)
## Summary

Port the agent generation pipeline from the external AgentGenerator
service into local copilot tools, making the Claude Agent SDK itself
handle validation, fixing, and block recommendation — no separate inner
LLM calls needed.

Key capabilities:
- **Local agent generation**: Create, edit, and customize agents
entirely within the SDK session
- **Graph validation**: 9 validation checks (block existence, link
references, type compatibility, IO blocks, etc.)
- **Graph fixing**: 17+ auto-fix methods (ID repair, link rewiring, type
conversion, credential stripping, dynamic block sink names, etc.)
- **MCP tool blocks**: Guide and fixer support for MCPToolBlock nodes
with proper dynamic input schema handling
- **Sub-agent composition**: AgentExecutorBlock support with library
agent schema enrichment
- **Embedding fallback**: Falls back to OpenRouter for embeddings when
`openai_internal_api_key` is unavailable
- **Actionable error messages**: Excluded block types (MCP, Agent)
return specific hints redirecting to the correct tool

### New Tools
- `validate_agent_graph` — run 9 validation checks on agent JSON
- `fix_agent_graph` — apply 17+ auto-fixes to agent JSON
- `get_blocks_for_goal` — recommend blocks for a given goal (with
optimized descriptions)

### Refactored Tools
- `create_agent`, `edit_agent`, `customize_agent` — accept `agent_json`
for local generation with shared fix→validate→save pipeline
- `find_block` — added `include_schemas` parameter, excludes MCP/Agent
blocks with actionable hints
- `run_block` — actionable error messages for excluded block types
- `find_library_agent` — enriched with `graph_version`, `input_schema`,
`output_schema` for sub-agent composition

### Architecture
- Split 2,558-line `validation.py` into `fixer.py`, `validator.py`,
`helpers.py`, `pipeline.py`
- Extracted shared `fix_validate_and_save()` pipeline (was duplicated
across 3 tools)
- Shared `OPENROUTER_BASE_URL` constant across codebase
- Comprehensive test coverage: 78+ unit tests for fixer/validator, 8
run_block tests, 17 SDK compat tests

## Test plan
- [x] `poetry run format` passes
- [x] `poetry run pytest -s -vvv backend/copilot/` — all tests pass
- [x] CI green on all Python versions (3.11, 3.12, 3.13)
- [x] Manual E2E: copilot generates agents with correct IO blocks,
links, and node structure
- [x] Manual E2E: MCP tool blocks use bare field names for dynamic
inputs
- [x] Manual E2E: sub-agent composition with AgentExecutorBlock
2026-03-09 16:10:22 +00:00
nKOxxx
c7124a5240 Add documentation for Google Gemini integration (#12283)
## Summary
Adding comprehensive documentation for Google Gemini integration with
AutoGPT.

## Changes
- Added setup instructions for Gemini API
- Documented configuration options
- Added examples and best practices

## Related Issues
N/A - Documentation improvement

## Testing
- Verified documentation accuracy
- Tested all code examples

## Checklist
- [x] Code follows project style
- [x] Documentation updated
- [x] Tests pass (if applicable)
2026-03-09 15:13:28 +00:00
Zamil Majdy
5537cb2858 dx: add shared Claude Code skills as auto-triggered guidelines (#12297)
## Summary
- Add 8 Claude Code skills under \`.claude/skills/\` that act as
**auto-triggered guidelines** — the LLM invokes them automatically based
on context, no manual \`/command\` needed
- Skills: \`pr-review\`, \`pr-create\`, \`new-block\`,
\`openapi-regen\`, \`backend-check\`, \`frontend-check\`,
\`worktree-setup\`, \`code-style\`
- Each skill has an explicit TRIGGER condition so the LLM knows when to
apply it without being asked

## Changes

### Skills (all auto-triggered by context)
| Skill | Trigger |
|-------|---------|
| \`pr-review\` | User shares a PR URL or asks to address review
comments |
| \`pr-create\` | User asks to create a PR, push changes for review, or
submit work |
| \`new-block\` | User asks to create a new block or add a new
integration |
| \`openapi-regen\` | API routes change, new endpoints added, or
frontend types are stale |
| \`backend-check\` | Backend Python code has been modified |
| \`frontend-check\` | Frontend TypeScript/React code has been modified
|
| \`worktree-setup\` | User asks to work on a branch in isolation or set
up a worktree |
| \`code-style\` | Writing or reviewing Python code |

## Test plan
- [ ] Verify skills appear automatically in Claude Code when context
matches (no \`/command\` needed)
- [ ] Modify frontend code — confirm \`frontend-check\` fires
automatically
- [ ] Ask Claude to "create a PR" — confirm \`pr-create\` fires without
\`/pr-create\`
- [ ] Share a PR URL — confirm \`pr-review\` fires automatically

---------

Co-authored-by: Krzysztof Czerwinski <kpczerwinski@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 15:10:38 +00:00
Zamil Majdy
aef5f6d666 feat(copilot): E2B sandbox auto-pause between turns to eliminate idle billing (#12330)
## Summary

### Before
- E2B sandboxes ran continuously between CoPilot turns, billing for idle
time
- Sandbox timeout caused **termination** (kill), losing all session
state
- No explicit cleanup when sessions were deleted — sandboxes leaked
- Single timeout concept with no separation between pause and kill
semantics

### After
- **Per-turn pause**: `pause_sandbox()` is called in the `finally` block
after every CoPilot turn, stopping billing instantly between turns
(paused sandboxes cost \$0 compute)
- **Auto-pause safety net**: Sandboxes are created with
`lifecycle={"on_timeout": "pause"}` (`pause_timeout` = 4h default) so
they auto-pause rather than terminate if the explicit pause is missed
- **Auto-reconnect**: `AsyncSandbox.connect()` in e2b SDK v2
auto-resumes paused sandboxes transparently — no extra code needed
- **Session delete cleanup**: `kill_sandbox()` is now called in
`delete_chat_session()` to explicitly terminate sandboxes and free
resources
- **Two distinct timeouts**: `pause_timeout` (4h, e2b auto-pause) vs
`redis_ttl` (12h, session key lifetime)

### Key Changes

| File | Change |
|------|--------|
| `pyproject.toml` | Bump `e2b-code-interpreter` `1.x` → `2.x` |
| `e2b_sandbox.py` | Add `pause_sandbox()`, `kill_sandbox()`,
`_act_on_sandbox()` helper; `lifecycle={"on_timeout": "pause"}`;
separate `pause_timeout` / `redis_ttl` params |
| `sdk/service.py` | Call `pause_sandbox()` in `finally` block
**before** transcript upload; use walrus operator for type-safe
`e2b_api_key` narrowing |
| `model.py` | Call `kill_sandbox()` in `delete_chat_session()`; inline
import to avoid circular dependency |
| `config.py` | Add `e2b_active` property; rename `e2b_sandbox_timeout`
default to 4h |
| `e2b_sandbox_test.py` | Add `test_pause_then_reconnect_reuses_sandbox`
test; update all `sandbox_timeout` → `pause_timeout` |

### Verified E2E
- Used real `E2B_API_KEY` from k8s dev cluster to manually verify:
sandbox created → paused → `is_running() == False` → reconnected via
`connect()` → state preserved → killed

## Test plan
- [x] `poetry run pytest backend/copilot/tools/e2b_sandbox_test.py` —
all 19 tests pass
- [x] CI: test (3.11, 3.12, 3.13), types — all green
- [x] E2E verified with real E2B credentials
2026-03-09 14:55:10 +00:00
Ubbe
8063391d0a feat(frontend/copilot): pin interactive tool cards outside reasoning collapse (#12346)
## Summary

<img width="400" height="227" alt="Screenshot 2026-03-09 at 22 43 10"
src="https://github.com/user-attachments/assets/0116e260-860d-4466-9763-e02de2766e50"
/>

<img width="600" height="618" alt="Screenshot 2026-03-09 at 22 43 14"
src="https://github.com/user-attachments/assets/beaa6aca-afa8-483f-ac06-439bf162c951"
/>

- When the copilot stream finishes, tool calls that require user
interaction (credentials, inputs, clarification) are now **pinned**
outside the "Show reasoning" collapse instead of being hidden
- Added `isInteractiveToolPart()` helper that checks tool output's
`type` field against a set of interactive response types
- Modified `splitReasoningAndResponse()` to extract interactive tools
from reasoning into the visible response section
- Added styleguide section with 3 demos: `setup_requirements`,
`agent_details`, and `agent_saved` pinning scenarios

### Interactive response types kept visible:
`setup_requirements`, `agent_details`, `block_details`, `need_login`,
`input_validation_error`, `clarification_needed`, `suggested_goal`,
`agent_preview`, `agent_saved`

Error responses remain in reasoning (LLM explains them in final text).

Closes SECRT-2088

## Test plan
- [ ] Verify copilot stream with interactive tool (e.g. run_agent
requiring credentials) keeps the tool card visible after stream ends
- [ ] Verify non-interactive tools (find_block, bash_exec) still
collapse into "Show reasoning"
- [ ] Verify styleguide page at `/copilot/styleguide` renders the new
"Reasoning Collapse: Interactive Tool Pinning" section correctly
- [ ] Verify `pnpm types`, `pnpm lint`, `pnpm format` all pass

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 23:12:14 +08:00
Otto
0bbb12d688 fix(frontend/copilot): hide New Chat button on Autopilot homepage (#12321)
Requested by @0ubbe

The **New Chat** button was visible on the Autopilot homepage where
clicking it has no effect (since `sessionId` is already `null`). This
hides the button when no chat session is active, so it only appears when
the user is viewing a conversation and wants to start a new one.

**Changes:**
- `ChatSidebar.tsx` — hide button in both collapsed and expanded sidebar
states when `sessionId` is null
- `MobileDrawer.tsx` — same fix for mobile drawer

---
Co-authored-by: Ubbe <ubbe@users.noreply.github.com>
2026-03-09 22:41:11 +08:00
147 changed files with 13079 additions and 3842 deletions

View File

@@ -0,0 +1,17 @@
---
name: backend-check
description: Run the full backend formatting, linting, and test suite. Ensures code quality before commits and PRs. TRIGGER when backend Python code has been modified and needs validation.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Backend Check
## Steps
1. **Format**: `poetry run format` — runs formatting AND linting. NEVER run ruff/black/isort individually
2. **Fix** any remaining errors manually, re-run until clean
3. **Test**: `poetry run test` (runs DB setup + pytest). For specific files: `poetry run pytest -s -vvv <test_files>`
4. **Snapshots** (if needed): `poetry run pytest path/to/test.py --snapshot-update` — review with `git diff`

View File

@@ -0,0 +1,35 @@
---
name: code-style
description: Python code style preferences for the AutoGPT backend. Apply when writing or reviewing Python code. TRIGGER when writing new Python code, reviewing PRs, or refactoring backend code.
user-invocable: false
metadata:
author: autogpt-team
version: "1.0.0"
---
# Code Style
## Imports
- **Top-level only** — no local/inner imports. Move all imports to the top of the file.
## Typing
- **No duck typing** — avoid `hasattr`, `getattr`, `isinstance` for type dispatch. Use proper typed interfaces, unions, or protocols.
- **Pydantic models** over dataclass, namedtuple, or raw dict for structured data.
- **No linter suppressors** — avoid `# type: ignore`, `# noqa`, `# pyright: ignore` etc. 99% of the time the right fix is fixing the type/code, not silencing the tool.
## Code Structure
- **List comprehensions** over manual loop-and-append.
- **Early return** — guard clauses first, avoid deep nesting.
- **Flatten inline** — prefer short, concise expressions. Reduce `if/else` chains with direct returns or ternaries when readable.
- **Modular functions** — break complex logic into small, focused functions rather than long blocks with nested conditionals.
## Review Checklist
Before finishing, always ask:
- Can any function be split into smaller pieces?
- Is there unnecessary nesting that an early return would eliminate?
- Can any loop be a comprehension?
- Is there a simpler way to express this logic?

View File

@@ -0,0 +1,16 @@
---
name: frontend-check
description: Run the full frontend formatting, linting, and type checking suite. Ensures code quality before commits and PRs. TRIGGER when frontend TypeScript/React code has been modified and needs validation.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Frontend Check
## Steps (in order)
1. **Format**: `pnpm format` — NEVER run individual formatters
2. **Lint**: `pnpm lint` — fix errors, re-run until clean
3. **Types**: `pnpm types` — if it keeps failing after multiple attempts, stop and ask the user

View File

@@ -0,0 +1,29 @@
---
name: new-block
description: Create a new backend block following the Block SDK Guide. Guides through provider configuration, schema definition, authentication, and testing. TRIGGER when user asks to create a new block, add a new integration, or build a new node for the graph editor.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# New Block Creation
Read `docs/platform/block-sdk-guide.md` first for the full guide.
## Steps
1. **Provider config** (if external service): create `_config.py` with `ProviderBuilder`
2. **Block file** in `backend/blocks/` (from `autogpt_platform/backend/`):
- Generate a UUID once with `uuid.uuid4()`, then **hard-code that string** as `id` (IDs must be stable across imports)
- `Input(BlockSchema)` and `Output(BlockSchema)` classes
- `async def run` that `yield`s output fields
3. **Files**: use `store_media_file()` with `"for_block_output"` for outputs
4. **Test**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[MyBlock]' -xvs`
5. **Format**: `poetry run format`
## Rules
- Analyze interfaces: do inputs/outputs connect well with other blocks in a graph?
- Use top-level imports, avoid duck typing
- Always use `for_block_output` for block outputs

View File

@@ -0,0 +1,28 @@
---
name: openapi-regen
description: Regenerate the OpenAPI spec and frontend API client. Starts the backend REST server, fetches the spec, and regenerates the typed frontend hooks. TRIGGER when API routes change, new endpoints are added, or frontend API types are stale.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# OpenAPI Spec Regeneration
## Steps
1. **Run end-to-end** in a single shell block (so `REST_PID` persists):
```bash
cd autogpt_platform/backend && poetry run rest &
REST_PID=$!
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && kill $REST_PID && exit 1; done
cd ../frontend && pnpm generate:api:force
kill $REST_PID
pnpm types && pnpm lint && pnpm format
```
## Rules
- Always use `pnpm generate:api:force` (not `pnpm generate:api`)
- Don't manually edit files in `src/app/api/__generated__/`
- Generated hooks follow: `use{Method}{Version}{OperationName}`

View File

@@ -0,0 +1,31 @@
---
name: pr-create
description: Create a pull request for the current branch. TRIGGER when user asks to create a PR, open a pull request, push changes for review, or submit work for merging.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Create Pull Request
## Steps
1. **Check for existing PR**: `gh pr view --json url -q .url 2>/dev/null` — if a PR already exists, output its URL and stop
2. **Understand changes**: `git status`, `git diff dev...HEAD`, `git log dev..HEAD --oneline`
3. **Read PR template**: `.github/PULL_REQUEST_TEMPLATE.md`
4. **Draft PR title**: Use conventional commits format (see CLAUDE.md for types and scopes)
5. **Fill out PR template** as the body — be thorough in the Changes section
6. **Format first** (if relevant changes exist):
- Backend: `cd autogpt_platform/backend && poetry run format`
- Frontend: `cd autogpt_platform/frontend && pnpm format`
- Fix any lint errors, then commit formatting changes before pushing
7. **Push**: `git push -u origin HEAD`
8. **Create PR**: `gh pr create --base dev`
9. **Output** the PR URL
## Rules
- Always target `dev` branch
- Do NOT run tests — CI will handle that
- Use the PR template from `.github/PULL_REQUEST_TEMPLATE.md`

View File

@@ -0,0 +1,51 @@
---
name: pr-review
description: Address all open PR review comments systematically. Fetches comments, addresses each one, reacts +1/-1, and replies when clarification is needed. Keeps iterating until all comments are addressed and CI is green. TRIGGER when user shares a PR URL, asks to address review comments, fix PR feedback, or respond to reviewer comments.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# PR Review Comment Workflow
## Steps
1. **Find PR**: `gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT`
2. **Fetch comments** (all three sources):
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` (top-level reviews)
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` (inline review comments)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` (PR conversation comments)
3. **Skip** comments already reacted to by PR author
4. **For each unreacted comment**:
- Read referenced code, make the fix (or reply if you disagree/need info)
- **Inline review comments** (`pulls/{N}/comments`):
- React: `gh api repos/.../pulls/comments/{ID}/reactions -f content="+1"` (or `-1`)
- Reply: `gh api repos/.../pulls/{N}/comments/{ID}/replies -f body="..."`
- **PR conversation comments** (`issues/{N}/comments`):
- React: `gh api repos/.../issues/comments/{ID}/reactions -f content="+1"` (or `-1`)
- No threaded replies — post a new issue comment if needed
- **Top-level reviews**: no reaction API — address in code, reply via issue comment if needed
5. **Include autogpt-reviewer bot fixes** too
6. **Format**: `cd autogpt_platform/backend && poetry run format`, `cd autogpt_platform/frontend && pnpm format`
7. **Commit & push**
8. **Re-fetch comments** immediately — address any new unreacted ones before waiting on CI
9. **Stay productive while CI runs** — don't idle. In priority order:
- Run any pending local tests (`poetry run pytest`, e2e, etc.) and fix failures
- Address any remaining comments
- Only poll `gh pr checks {N}` as the last resort when there's truly nothing left to do
10. **If CI fails** — fix, go back to step 6
11. **Re-fetch comments again** after CI is green — address anything that appeared while CI was running
12. **Done** only when: all comments reacted AND CI is green.
## CRITICAL: Do Not Stop
**Loop is: address → format → commit → push → re-check comments → run local tests → wait CI → re-check comments → repeat.**
Never idle. If CI is running and you have nothing to address, run local tests. Waiting on CI is the last resort.
## Rules
- One todo per comment
- For inline review comments: reply on existing threads. For PR conversation comments: post a new issue comment (API doesn't support threaded replies)
- React to every comment: +1 addressed, -1 disagreed (with explanation)

View File

@@ -0,0 +1,45 @@
---
name: worktree-setup
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, generates Prisma client, and optionally starts the app (with port conflict resolution) or runs tests. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Worktree Setup
## Preferred: Use Branchlet
The repo has a `.branchlet.json` config — it handles env file copying, dependency installation, and Prisma generation automatically.
```bash
npm install -g branchlet # install once
branchlet create -n <name> -s <source-branch> -b <new-branch>
branchlet list --json # list all worktrees
```
## Manual Fallback
If branchlet isn't available:
1. `git worktree add ../<RepoName><N> <branch-name>`
2. Copy `.env` files: `backend/.env`, `frontend/.env`, `autogpt_platform/.env`, `db/docker/.env`
3. Install deps:
- `cd autogpt_platform/backend && poetry install && poetry run prisma generate`
- `cd autogpt_platform/frontend && pnpm install`
## Running the App
Free ports first — backend uses: 8001, 8002, 8003, 8005, 8006, 8007, 8008.
```bash
for port in 8001 8002 8003 8005 8006 8007 8008; do
lsof -ti :$port | xargs kill -9 2>/dev/null || true
done
cd <worktree>/autogpt_platform/backend && poetry run app
```
## CoPilot Testing Gotcha
SDK mode spawns a Claude subprocess — **won't work inside Claude Code**. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.

View File

@@ -0,0 +1,40 @@
-- =============================================================
-- View: analytics.auth_activities
-- Looker source alias: ds49 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Tracks authentication events (login, logout, SSO, password
-- reset, etc.) from Supabase's internal audit log.
-- Useful for monitoring sign-in patterns and detecting anomalies.
--
-- SOURCE TABLES
-- auth.audit_log_entries — Supabase internal auth event log
--
-- OUTPUT COLUMNS
-- created_at TIMESTAMPTZ When the auth event occurred
-- actor_id TEXT User ID who triggered the event
-- actor_via_sso TEXT Whether the action was via SSO ('true'/'false')
-- action TEXT Event type (e.g. 'login', 'logout', 'token_refreshed')
--
-- WINDOW
-- Rolling 90 days from current date
--
-- EXAMPLE QUERIES
-- -- Daily login counts
-- SELECT DATE_TRUNC('day', created_at) AS day, COUNT(*) AS logins
-- FROM analytics.auth_activities
-- WHERE action = 'login'
-- GROUP BY 1 ORDER BY 1;
--
-- -- SSO vs password login breakdown
-- SELECT actor_via_sso, COUNT(*) FROM analytics.auth_activities
-- WHERE action = 'login' GROUP BY 1;
-- =============================================================
SELECT
created_at,
payload->>'actor_id' AS actor_id,
payload->>'actor_via_sso' AS actor_via_sso,
payload->>'action' AS action
FROM auth.audit_log_entries
WHERE created_at >= NOW() - INTERVAL '90 days'

View File

@@ -0,0 +1,105 @@
-- =============================================================
-- View: analytics.graph_execution
-- Looker source alias: ds16 | Charts: 21
-- =============================================================
-- DESCRIPTION
-- One row per agent graph execution (last 90 days).
-- Unpacks the JSONB stats column into individual numeric columns
-- and normalises the executionStatus — runs that failed due to
-- insufficient credits are reclassified as 'NO_CREDITS' for
-- easier filtering. Error messages are scrubbed of IDs and URLs
-- to allow safe grouping.
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records
-- platform.AgentGraph — Agent graph metadata (for name)
-- platform.LibraryAgent — To flag possibly-AI (safe-mode) agents
--
-- OUTPUT COLUMNS
-- id TEXT Execution UUID
-- agentGraphId TEXT Agent graph UUID
-- agentGraphVersion INT Graph version number
-- executionStatus TEXT COMPLETED | FAILED | NO_CREDITS | RUNNING | QUEUED | TERMINATED
-- createdAt TIMESTAMPTZ When the execution was queued
-- updatedAt TIMESTAMPTZ Last status update time
-- userId TEXT Owner user UUID
-- agentGraphName TEXT Human-readable agent name
-- cputime DECIMAL Total CPU seconds consumed
-- walltime DECIMAL Total wall-clock seconds
-- node_count DECIMAL Number of nodes in the graph
-- nodes_cputime DECIMAL CPU time across all nodes
-- nodes_walltime DECIMAL Wall time across all nodes
-- execution_cost DECIMAL Credit cost of this execution
-- correctness_score FLOAT AI correctness score (if available)
-- possibly_ai BOOLEAN True if agent has sensitive_action_safe_mode enabled
-- groupedErrorMessage TEXT Scrubbed error string (IDs/URLs replaced with wildcards)
--
-- WINDOW
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Daily execution counts by status
-- SELECT DATE_TRUNC('day', "createdAt") AS day, "executionStatus", COUNT(*)
-- FROM analytics.graph_execution
-- GROUP BY 1, 2 ORDER BY 1;
--
-- -- Average cost per execution by agent
-- SELECT "agentGraphName", AVG("execution_cost") AS avg_cost, COUNT(*) AS runs
-- FROM analytics.graph_execution
-- WHERE "executionStatus" = 'COMPLETED'
-- GROUP BY 1 ORDER BY avg_cost DESC;
--
-- -- Top error messages
-- SELECT "groupedErrorMessage", COUNT(*) AS occurrences
-- FROM analytics.graph_execution
-- WHERE "executionStatus" = 'FAILED'
-- GROUP BY 1 ORDER BY 2 DESC LIMIT 20;
-- =============================================================
SELECT
ge."id" AS id,
ge."agentGraphId" AS agentGraphId,
ge."agentGraphVersion" AS agentGraphVersion,
CASE
WHEN jsonb_exists(ge."stats"::jsonb, 'error')
AND (
(ge."stats"::jsonb->>'error') ILIKE '%insufficient balance%'
OR (ge."stats"::jsonb->>'error') ILIKE '%you have no credits left%'
)
THEN 'NO_CREDITS'
ELSE CAST(ge."executionStatus" AS TEXT)
END AS executionStatus,
ge."createdAt" AS createdAt,
ge."updatedAt" AS updatedAt,
ge."userId" AS userId,
g."name" AS agentGraphName,
(ge."stats"::jsonb->>'cputime')::decimal AS cputime,
(ge."stats"::jsonb->>'walltime')::decimal AS walltime,
(ge."stats"::jsonb->>'node_count')::decimal AS node_count,
(ge."stats"::jsonb->>'nodes_cputime')::decimal AS nodes_cputime,
(ge."stats"::jsonb->>'nodes_walltime')::decimal AS nodes_walltime,
(ge."stats"::jsonb->>'cost')::decimal AS execution_cost,
(ge."stats"::jsonb->>'correctness_score')::float AS correctness_score,
COALESCE(la.possibly_ai, FALSE) AS possibly_ai,
REGEXP_REPLACE(
REGEXP_REPLACE(
TRIM(BOTH '"' FROM ge."stats"::jsonb->>'error'),
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
'\1\2/...', 'gi'
),
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
) AS groupedErrorMessage
FROM platform."AgentGraphExecution" ge
LEFT JOIN platform."AgentGraph" g
ON ge."agentGraphId" = g."id"
AND ge."agentGraphVersion" = g."version"
LEFT JOIN (
SELECT DISTINCT ON ("userId", "agentGraphId")
"userId", "agentGraphId",
("settings"::jsonb->>'sensitive_action_safe_mode')::boolean AS possibly_ai
FROM platform."LibraryAgent"
WHERE "isDeleted" = FALSE
AND "isArchived" = FALSE
ORDER BY "userId", "agentGraphId", "agentGraphVersion" DESC
) la ON la."userId" = ge."userId" AND la."agentGraphId" = ge."agentGraphId"
WHERE ge."createdAt" > CURRENT_DATE - INTERVAL '90 days'

View File

@@ -0,0 +1,101 @@
-- =============================================================
-- View: analytics.node_block_execution
-- Looker source alias: ds14 | Charts: 11
-- =============================================================
-- DESCRIPTION
-- One row per node (block) execution (last 90 days).
-- Unpacks stats JSONB and joins to identify which block type
-- was run. For failed nodes, joins the error output and
-- scrubs it for safe grouping.
--
-- SOURCE TABLES
-- platform.AgentNodeExecution — Node execution records
-- platform.AgentNode — Node → block mapping
-- platform.AgentBlock — Block name/ID
-- platform.AgentNodeExecutionInputOutput — Error output values
--
-- OUTPUT COLUMNS
-- id TEXT Node execution UUID
-- agentGraphExecutionId TEXT Parent graph execution UUID
-- agentNodeId TEXT Node UUID within the graph
-- executionStatus TEXT COMPLETED | FAILED | QUEUED | RUNNING | TERMINATED
-- addedTime TIMESTAMPTZ When the node was queued
-- queuedTime TIMESTAMPTZ When it entered the queue
-- startedTime TIMESTAMPTZ When execution started
-- endedTime TIMESTAMPTZ When execution finished
-- inputSize BIGINT Input payload size in bytes
-- outputSize BIGINT Output payload size in bytes
-- walltime NUMERIC Wall-clock seconds for this node
-- cputime NUMERIC CPU seconds for this node
-- llmRetryCount INT Number of LLM retries
-- llmCallCount INT Number of LLM API calls made
-- inputTokenCount BIGINT LLM input tokens consumed
-- outputTokenCount BIGINT LLM output tokens produced
-- blockName TEXT Human-readable block name (e.g. 'OpenAIBlock')
-- blockId TEXT Block UUID
-- groupedErrorMessage TEXT Scrubbed error (IDs/URLs wildcarded)
-- errorMessage TEXT Raw error output (only set when FAILED)
--
-- WINDOW
-- Rolling 90 days (addedTime > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Most-used blocks by execution count
-- SELECT "blockName", COUNT(*) AS executions,
-- COUNT(*) FILTER (WHERE "executionStatus"='FAILED') AS failures
-- FROM analytics.node_block_execution
-- GROUP BY 1 ORDER BY executions DESC LIMIT 20;
--
-- -- Average LLM token usage per block
-- SELECT "blockName",
-- AVG("inputTokenCount") AS avg_input_tokens,
-- AVG("outputTokenCount") AS avg_output_tokens
-- FROM analytics.node_block_execution
-- WHERE "llmCallCount" > 0
-- GROUP BY 1 ORDER BY avg_input_tokens DESC;
--
-- -- Top failure reasons
-- SELECT "blockName", "groupedErrorMessage", COUNT(*) AS count
-- FROM analytics.node_block_execution
-- WHERE "executionStatus" = 'FAILED'
-- GROUP BY 1, 2 ORDER BY count DESC LIMIT 20;
-- =============================================================
SELECT
ne."id" AS id,
ne."agentGraphExecutionId" AS agentGraphExecutionId,
ne."agentNodeId" AS agentNodeId,
CAST(ne."executionStatus" AS TEXT) AS executionStatus,
ne."addedTime" AS addedTime,
ne."queuedTime" AS queuedTime,
ne."startedTime" AS startedTime,
ne."endedTime" AS endedTime,
(ne."stats"::jsonb->>'input_size')::bigint AS inputSize,
(ne."stats"::jsonb->>'output_size')::bigint AS outputSize,
(ne."stats"::jsonb->>'walltime')::numeric AS walltime,
(ne."stats"::jsonb->>'cputime')::numeric AS cputime,
(ne."stats"::jsonb->>'llm_retry_count')::int AS llmRetryCount,
(ne."stats"::jsonb->>'llm_call_count')::int AS llmCallCount,
(ne."stats"::jsonb->>'input_token_count')::bigint AS inputTokenCount,
(ne."stats"::jsonb->>'output_token_count')::bigint AS outputTokenCount,
b."name" AS blockName,
b."id" AS blockId,
REGEXP_REPLACE(
REGEXP_REPLACE(
TRIM(BOTH '"' FROM eio."data"::text),
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
'\1\2/...', 'gi'
),
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
) AS groupedErrorMessage,
eio."data" AS errorMessage
FROM platform."AgentNodeExecution" ne
LEFT JOIN platform."AgentNode" nd
ON ne."agentNodeId" = nd."id"
LEFT JOIN platform."AgentBlock" b
ON nd."agentBlockId" = b."id"
LEFT JOIN platform."AgentNodeExecutionInputOutput" eio
ON eio."referencedByOutputExecId" = ne."id"
AND eio."name" = 'error'
AND ne."executionStatus" = 'FAILED'
WHERE ne."addedTime" > CURRENT_DATE - INTERVAL '90 days'

View File

@@ -0,0 +1,97 @@
-- =============================================================
-- View: analytics.retention_agent
-- Looker source alias: ds35 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention broken down per individual agent.
-- Cohort = week of a user's first use of THAT specific agent.
-- Tells you which agents keep users coming back vs. one-shot
-- use. Only includes cohorts from the last 180 days.
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records (user × agent × time)
-- platform.AgentGraph — Agent names
--
-- OUTPUT COLUMNS
-- agent_id TEXT Agent graph UUID
-- agent_label TEXT 'AgentName [first8chars]'
-- agent_label_n TEXT 'AgentName [first8chars] (n=total_users)'
-- cohort_week_start DATE Week users first ran this agent
-- cohort_label TEXT ISO week label
-- cohort_label_n TEXT ISO week label with cohort size
-- user_lifetime_week INT Weeks since first use of this agent
-- cohort_users BIGINT Users in this cohort for this agent
-- active_users BIGINT Users who ran the agent again in week k
-- retention_rate FLOAT active_users / cohort_users
-- cohort_users_w0 BIGINT cohort_users only at week 0 (safe to SUM)
-- agent_total_users BIGINT Total users across all cohorts for this agent
--
-- EXAMPLE QUERIES
-- -- Best-retained agents at week 2
-- SELECT agent_label, AVG(retention_rate) AS w2_retention
-- FROM analytics.retention_agent
-- WHERE user_lifetime_week = 2 AND cohort_users >= 10
-- GROUP BY 1 ORDER BY w2_retention DESC LIMIT 10;
--
-- -- Agents with most unique users
-- SELECT DISTINCT agent_label, agent_total_users
-- FROM analytics.retention_agent
-- ORDER BY agent_total_users DESC LIMIT 20;
-- =============================================================
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
events AS (
SELECT e."userId"::text AS user_id, e."agentGraphId" AS agent_id,
e."createdAt"::timestamptz AS created_at,
DATE_TRUNC('week', e."createdAt")::date AS week_start
FROM platform."AgentGraphExecution" e
),
first_use AS (
SELECT user_id, agent_id, MIN(created_at) AS first_use_at,
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
FROM events GROUP BY 1,2
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_weeks AS (SELECT DISTINCT user_id, agent_id, week_start FROM events),
user_week_age AS (
SELECT aw.user_id, aw.agent_id, fu.cohort_week_start,
((aw.week_start - DATE_TRUNC('week',fu.first_use_at)::date)/7)::int AS user_lifetime_week
FROM activity_weeks aw JOIN first_use fu USING (user_id, agent_id)
WHERE aw.week_start >= DATE_TRUNC('week',fu.first_use_at)::date
),
active_counts AS (
SELECT agent_id, cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2,3
),
cohort_sizes AS (
SELECT agent_id, cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_use GROUP BY 1,2
),
cohort_caps AS (
SELECT cs.agent_id, cs.cohort_week_start, cs.cohort_users,
LEAST((SELECT max_weeks FROM params),
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
FROM cohort_sizes cs
),
grid AS (
SELECT cc.agent_id, cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
),
agent_names AS (SELECT DISTINCT ON (g."id") g."id" AS agent_id, g."name" AS agent_name FROM platform."AgentGraph" g ORDER BY g."id", g."version" DESC),
agent_total_users AS (SELECT agent_id, SUM(cohort_users) AS agent_total_users FROM cohort_sizes GROUP BY 1)
SELECT
g.agent_id,
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||']' AS agent_label,
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||'] (n='||COALESCE(atu.agent_total_users,0)||')' AS agent_label_n,
g.cohort_week_start,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_week, g.cohort_users,
COALESCE(ac.active_users,0) AS active_users,
COALESCE(ac.active_users,0)::float / NULLIF(g.cohort_users,0) AS retention_rate,
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0,
COALESCE(atu.agent_total_users,0) AS agent_total_users
FROM grid g
LEFT JOIN active_counts ac ON ac.agent_id=g.agent_id AND ac.cohort_week_start=g.cohort_week_start AND ac.user_lifetime_week=g.user_lifetime_week
LEFT JOIN agent_names an ON an.agent_id=g.agent_id
LEFT JOIN agent_total_users atu ON atu.agent_id=g.agent_id
ORDER BY agent_label, g.cohort_week_start, g.user_lifetime_week;

View File

@@ -0,0 +1,81 @@
-- =============================================================
-- View: analytics.retention_execution_daily
-- Looker source alias: ds111 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Daily cohort retention based on agent executions.
-- Cohort anchor = day of user's FIRST ever execution.
-- Only includes cohorts from the last 90 days, up to day 30.
-- Great for early engagement analysis (did users run another
-- agent the next day?).
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records
--
-- OUTPUT COLUMNS
-- Same pattern as retention_login_daily.
-- cohort_day_start = day of first execution (not first login)
--
-- EXAMPLE QUERIES
-- -- Day-3 execution retention
-- SELECT cohort_label, retention_rate_bounded AS d3_retention
-- FROM analytics.retention_execution_daily
-- WHERE user_lifetime_day = 3 ORDER BY cohort_day_start;
-- =============================================================
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days') AS cohort_start),
events AS (
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
DATE_TRUNC('day', e."createdAt")::date AS day_start
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
),
first_exec AS (
SELECT user_id, MIN(created_at) AS first_exec_at,
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
FROM events GROUP BY 1
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
user_day_age AS (
SELECT ad.user_id, fe.cohort_day_start,
(ad.day_start - DATE_TRUNC('day',fe.first_exec_at)::date)::int AS user_lifetime_day
FROM activity_days ad JOIN first_exec fe USING (user_id)
WHERE ad.day_start >= DATE_TRUNC('day',fe.first_exec_at)::date
),
bounded_counts AS (
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_day_start, cs.cohort_users,
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
)
SELECT
g.cohort_day_start,
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_day, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
ORDER BY g.cohort_day_start, g.user_lifetime_day;

View File

@@ -0,0 +1,81 @@
-- =============================================================
-- View: analytics.retention_execution_weekly
-- Looker source alias: ds92 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention based on agent executions.
-- Cohort anchor = week of user's FIRST ever agent execution
-- (not first login). Only includes cohorts from the last 180 days.
-- Useful when you care about product engagement, not just visits.
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records
--
-- OUTPUT COLUMNS
-- Same pattern as retention_login_weekly.
-- cohort_week_start = week of first execution (not first login)
--
-- EXAMPLE QUERIES
-- -- Week-2 execution retention
-- SELECT cohort_label, retention_rate_bounded
-- FROM analytics.retention_execution_weekly
-- WHERE user_lifetime_week = 2 ORDER BY cohort_week_start;
-- =============================================================
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
events AS (
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
DATE_TRUNC('week', e."createdAt")::date AS week_start
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
),
first_exec AS (
SELECT user_id, MIN(created_at) AS first_exec_at,
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
FROM events GROUP BY 1
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
SELECT aw.user_id, fe.cohort_week_start,
((aw.week_start - DATE_TRUNC('week',fe.first_exec_at)::date)/7)::int AS user_lifetime_week
FROM activity_weeks aw JOIN first_exec fe USING (user_id)
WHERE aw.week_start >= DATE_TRUNC('week',fe.first_exec_at)::date
),
bounded_counts AS (
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_week_start, cs.cohort_users,
LEAST((SELECT max_weeks FROM params),
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
)
SELECT
g.cohort_week_start,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_week, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
ORDER BY g.cohort_week_start, g.user_lifetime_week;

View File

@@ -0,0 +1,94 @@
-- =============================================================
-- View: analytics.retention_login_daily
-- Looker source alias: ds112 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Daily cohort retention based on login sessions.
-- Same logic as retention_login_weekly but at day granularity,
-- showing up to day 30 for cohorts from the last 90 days.
-- Useful for analysing early activation (days 1-7) in detail.
--
-- SOURCE TABLES
-- auth.sessions — Login session records
--
-- OUTPUT COLUMNS (same pattern as retention_login_weekly)
-- cohort_day_start DATE First day the cohort logged in
-- cohort_label TEXT Date string (e.g. '2025-03-01')
-- cohort_label_n TEXT Date + cohort size (e.g. '2025-03-01 (n=12)')
-- user_lifetime_day INT Days since first login (0 = signup day)
-- cohort_users BIGINT Total users in cohort
-- active_users_bounded BIGINT Users active on exactly day k
-- retained_users_unbounded BIGINT Users active any time on/after day k
-- retention_rate_bounded FLOAT bounded / cohort_users
-- retention_rate_unbounded FLOAT unbounded / cohort_users
-- cohort_users_d0 BIGINT cohort_users only at day 0, else 0 (safe to SUM)
--
-- EXAMPLE QUERIES
-- -- Day-1 retention rate (came back next day)
-- SELECT cohort_label, retention_rate_bounded AS d1_retention
-- FROM analytics.retention_login_daily
-- WHERE user_lifetime_day = 1 ORDER BY cohort_day_start;
--
-- -- Average retention curve across all cohorts
-- SELECT user_lifetime_day,
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_d0), 0) AS avg_retention
-- FROM analytics.retention_login_daily
-- GROUP BY 1 ORDER BY 1;
-- =============================================================
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days')::date AS cohort_start),
events AS (
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
DATE_TRUNC('day', s.created_at)::date AS day_start
FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login AS (
SELECT user_id, MIN(created_at) AS first_login_time,
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
FROM events GROUP BY 1
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
user_day_age AS (
SELECT ad.user_id, fl.cohort_day_start,
(ad.day_start - DATE_TRUNC('day', fl.first_login_time)::date)::int AS user_lifetime_day
FROM activity_days ad JOIN first_login fl USING (user_id)
WHERE ad.day_start >= DATE_TRUNC('day', fl.first_login_time)::date
),
bounded_counts AS (
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_day_start, cs.cohort_users,
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
)
SELECT
g.cohort_day_start,
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_day, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
ORDER BY g.cohort_day_start, g.user_lifetime_day;

View File

@@ -0,0 +1,96 @@
-- =============================================================
-- View: analytics.retention_login_onboarded_weekly
-- Looker source alias: ds101 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention from login sessions, restricted to
-- users who "onboarded" — defined as running at least one
-- agent within 365 days of their first login.
-- Filters out users who signed up but never activated,
-- giving a cleaner view of engaged-user retention.
--
-- SOURCE TABLES
-- auth.sessions — Login session records
-- platform.AgentGraphExecution — Used to identify onboarders
--
-- OUTPUT COLUMNS
-- Same as retention_login_weekly (cohort_week_start, user_lifetime_week,
-- retention_rate_bounded, retention_rate_unbounded, etc.)
-- Only difference: cohort is filtered to onboarded users only.
--
-- EXAMPLE QUERIES
-- -- Compare week-4 retention: all users vs onboarded only
-- SELECT 'all_users' AS segment, AVG(retention_rate_bounded) AS w4_retention
-- FROM analytics.retention_login_weekly WHERE user_lifetime_week = 4
-- UNION ALL
-- SELECT 'onboarded', AVG(retention_rate_bounded)
-- FROM analytics.retention_login_onboarded_weekly WHERE user_lifetime_week = 4;
-- =============================================================
WITH params AS (SELECT 12::int AS max_weeks, 365::int AS onboarding_window_days),
events AS (
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
DATE_TRUNC('week', s.created_at)::date AS week_start
FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login_all AS (
SELECT user_id, MIN(created_at) AS first_login_time,
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
FROM events GROUP BY 1
),
onboarders AS (
SELECT fl.user_id FROM first_login_all fl
WHERE EXISTS (
SELECT 1 FROM platform."AgentGraphExecution" e
WHERE e."userId"::text = fl.user_id
AND e."createdAt" >= fl.first_login_time
AND e."createdAt" < fl.first_login_time
+ make_interval(days => (SELECT onboarding_window_days FROM params))
)
),
first_login AS (SELECT * FROM first_login_all WHERE user_id IN (SELECT user_id FROM onboarders)),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
SELECT aw.user_id, fl.cohort_week_start,
((aw.week_start - DATE_TRUNC('week',fl.first_login_time)::date)/7)::int AS user_lifetime_week
FROM activity_weeks aw JOIN first_login fl USING (user_id)
WHERE aw.week_start >= DATE_TRUNC('week',fl.first_login_time)::date
),
bounded_counts AS (
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_week_start, cs.cohort_users,
LEAST((SELECT max_weeks FROM params),
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
)
SELECT
g.cohort_week_start,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_week, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
ORDER BY g.cohort_week_start, g.user_lifetime_week;

View File

@@ -0,0 +1,103 @@
-- =============================================================
-- View: analytics.retention_login_weekly
-- Looker source alias: ds83 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention based on login sessions.
-- Users are grouped by the ISO week of their first ever login.
-- For each cohort × lifetime-week combination, outputs both:
-- - bounded rate: % active in exactly that week
-- - unbounded rate: % who were ever active on or after that week
-- Weeks are capped to the cohort's actual age (no future data points).
--
-- SOURCE TABLES
-- auth.sessions — Login session records
--
-- HOW TO READ THE OUTPUT
-- cohort_week_start The Monday of the week users first logged in
-- user_lifetime_week 0 = signup week, 1 = one week later, etc.
-- retention_rate_bounded = active_users_bounded / cohort_users
-- retention_rate_unbounded = retained_users_unbounded / cohort_users
--
-- OUTPUT COLUMNS
-- cohort_week_start DATE First day of the cohort's signup week
-- cohort_label TEXT ISO week label (e.g. '2025-W01')
-- cohort_label_n TEXT ISO week label with cohort size (e.g. '2025-W01 (n=42)')
-- user_lifetime_week INT Weeks since first login (0 = signup week)
-- cohort_users BIGINT Total users in this cohort (denominator)
-- active_users_bounded BIGINT Users active in exactly week k
-- retained_users_unbounded BIGINT Users active any time on/after week k
-- retention_rate_bounded FLOAT bounded active / cohort_users
-- retention_rate_unbounded FLOAT unbounded retained / cohort_users
-- cohort_users_w0 BIGINT cohort_users only at week 0, else 0 (safe to SUM in pivot tables)
--
-- EXAMPLE QUERIES
-- -- Week-1 retention rate per cohort
-- SELECT cohort_label, retention_rate_bounded AS w1_retention
-- FROM analytics.retention_login_weekly
-- WHERE user_lifetime_week = 1
-- ORDER BY cohort_week_start;
--
-- -- Overall average retention curve (all cohorts combined)
-- SELECT user_lifetime_week,
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_w0), 0) AS avg_retention
-- FROM analytics.retention_login_weekly
-- GROUP BY 1 ORDER BY 1;
-- =============================================================
WITH params AS (SELECT 12::int AS max_weeks),
events AS (
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
DATE_TRUNC('week', s.created_at)::date AS week_start
FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login AS (
SELECT user_id, MIN(created_at) AS first_login_time,
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
FROM events GROUP BY 1
),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
SELECT aw.user_id, fl.cohort_week_start,
((aw.week_start - DATE_TRUNC('week', fl.first_login_time)::date) / 7)::int AS user_lifetime_week
FROM activity_weeks aw JOIN first_login fl USING (user_id)
WHERE aw.week_start >= DATE_TRUNC('week', fl.first_login_time)::date
),
bounded_counts AS (
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_week_start, cs.cohort_users,
LEAST((SELECT max_weeks FROM params),
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date - cs.cohort_week_start)/7)::int)) AS cap_weeks
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
)
SELECT
g.cohort_week_start,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_week, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
ORDER BY g.cohort_week_start, g.user_lifetime_week

View File

@@ -0,0 +1,71 @@
-- =============================================================
-- View: analytics.user_block_spending
-- Looker source alias: ds6 | Charts: 5
-- =============================================================
-- DESCRIPTION
-- One row per credit transaction (last 90 days).
-- Shows how users spend credits broken down by block type,
-- LLM provider and model. Joins node execution stats for
-- token-level detail.
--
-- SOURCE TABLES
-- platform.CreditTransaction — Credit debit/credit records
-- platform.AgentNodeExecution — Node execution stats (for token counts)
--
-- OUTPUT COLUMNS
-- transactionKey TEXT Unique transaction identifier
-- userId TEXT User who was charged
-- amount DECIMAL Credit amount (positive = credit, negative = debit)
-- negativeAmount DECIMAL amount * -1 (convenience for spend charts)
-- transactionType TEXT Transaction type (e.g. 'USAGE', 'REFUND', 'TOP_UP')
-- transactionTime TIMESTAMPTZ When the transaction was recorded
-- blockId TEXT Block UUID that triggered the spend
-- blockName TEXT Human-readable block name
-- llm_provider TEXT LLM provider (e.g. 'openai', 'anthropic')
-- llm_model TEXT Model name (e.g. 'gpt-4o', 'claude-3-5-sonnet')
-- node_exec_id TEXT Linked node execution UUID
-- llm_call_count INT LLM API calls made in that execution
-- llm_retry_count INT LLM retries in that execution
-- llm_input_token_count INT Input tokens consumed
-- llm_output_token_count INT Output tokens produced
--
-- WINDOW
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Total spend per user (last 90 days)
-- SELECT "userId", SUM("negativeAmount") AS total_spent
-- FROM analytics.user_block_spending
-- WHERE "transactionType" = 'USAGE'
-- GROUP BY 1 ORDER BY total_spent DESC;
--
-- -- Spend by LLM provider + model
-- SELECT "llm_provider", "llm_model",
-- SUM("negativeAmount") AS total_cost,
-- SUM("llm_input_token_count") AS input_tokens,
-- SUM("llm_output_token_count") AS output_tokens
-- FROM analytics.user_block_spending
-- WHERE "llm_provider" IS NOT NULL
-- GROUP BY 1, 2 ORDER BY total_cost DESC;
-- =============================================================
SELECT
c."transactionKey" AS transactionKey,
c."userId" AS userId,
c."amount" AS amount,
c."amount" * -1 AS negativeAmount,
c."type" AS transactionType,
c."createdAt" AS transactionTime,
c.metadata->>'block_id' AS blockId,
c.metadata->>'block' AS blockName,
c.metadata->'input'->'credentials'->>'provider' AS llm_provider,
c.metadata->'input'->>'model' AS llm_model,
c.metadata->>'node_exec_id' AS node_exec_id,
(ne."stats"->>'llm_call_count')::int AS llm_call_count,
(ne."stats"->>'llm_retry_count')::int AS llm_retry_count,
(ne."stats"->>'input_token_count')::int AS llm_input_token_count,
(ne."stats"->>'output_token_count')::int AS llm_output_token_count
FROM platform."CreditTransaction" c
LEFT JOIN platform."AgentNodeExecution" ne
ON (c.metadata->>'node_exec_id') = ne."id"::text
WHERE c."createdAt" > CURRENT_DATE - INTERVAL '90 days'

View File

@@ -0,0 +1,45 @@
-- =============================================================
-- View: analytics.user_onboarding
-- Looker source alias: ds68 | Charts: 3
-- =============================================================
-- DESCRIPTION
-- One row per user onboarding record. Contains the user's
-- stated usage reason, selected integrations, completed
-- onboarding steps and optional first agent selection.
-- Full history (no date filter) since onboarding happens
-- once per user.
--
-- SOURCE TABLES
-- platform.UserOnboarding — Onboarding state per user
--
-- OUTPUT COLUMNS
-- id TEXT Onboarding record UUID
-- createdAt TIMESTAMPTZ When onboarding started
-- updatedAt TIMESTAMPTZ Last update to onboarding state
-- usageReason TEXT Why user signed up (e.g. 'work', 'personal')
-- integrations TEXT[] Array of integration names the user selected
-- userId TEXT User UUID
-- completedSteps TEXT[] Array of onboarding step enums completed
-- selectedStoreListingVersionId TEXT First marketplace agent the user chose (if any)
--
-- EXAMPLE QUERIES
-- -- Usage reason breakdown
-- SELECT "usageReason", COUNT(*) FROM analytics.user_onboarding GROUP BY 1;
--
-- -- Completion rate per step
-- SELECT step, COUNT(*) AS users_completed
-- FROM analytics.user_onboarding
-- CROSS JOIN LATERAL UNNEST("completedSteps") AS step
-- GROUP BY 1 ORDER BY users_completed DESC;
-- =============================================================
SELECT
id,
"createdAt",
"updatedAt",
"usageReason",
integrations,
"userId",
"completedSteps",
"selectedStoreListingVersionId"
FROM platform."UserOnboarding"

View File

@@ -0,0 +1,100 @@
-- =============================================================
-- View: analytics.user_onboarding_funnel
-- Looker source alias: ds74 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Pre-aggregated onboarding funnel showing how many users
-- completed each step and the drop-off percentage from the
-- previous step. One row per onboarding step (all 22 steps
-- always present, even with 0 completions — prevents sparse
-- gaps from making LAG compare the wrong predecessors).
--
-- SOURCE TABLES
-- platform.UserOnboarding — Onboarding records with completedSteps array
--
-- OUTPUT COLUMNS
-- step TEXT Onboarding step enum name (e.g. 'WELCOME', 'CONGRATS')
-- step_order INT Numeric position in the funnel (1=first, 22=last)
-- users_completed BIGINT Distinct users who completed this step
-- pct_from_prev NUMERIC % of users from the previous step who reached this one
--
-- STEP ORDER
-- 1 WELCOME 9 MARKETPLACE_VISIT 17 SCHEDULE_AGENT
-- 2 USAGE_REASON 10 MARKETPLACE_ADD_AGENT 18 RUN_AGENTS
-- 3 INTEGRATIONS 11 MARKETPLACE_RUN_AGENT 19 RUN_3_DAYS
-- 4 AGENT_CHOICE 12 BUILDER_OPEN 20 TRIGGER_WEBHOOK
-- 5 AGENT_NEW_RUN 13 BUILDER_SAVE_AGENT 21 RUN_14_DAYS
-- 6 AGENT_INPUT 14 BUILDER_RUN_AGENT 22 RUN_AGENTS_100
-- 7 CONGRATS 15 VISIT_COPILOT
-- 8 GET_RESULTS 16 RE_RUN_AGENT
--
-- WINDOW
-- Users who started onboarding in the last 90 days
--
-- EXAMPLE QUERIES
-- -- Full funnel
-- SELECT * FROM analytics.user_onboarding_funnel ORDER BY step_order;
--
-- -- Biggest drop-off point
-- SELECT step, pct_from_prev FROM analytics.user_onboarding_funnel
-- ORDER BY pct_from_prev ASC LIMIT 3;
-- =============================================================
WITH all_steps AS (
-- Complete ordered grid of all 22 steps so zero-completion steps
-- are always present, keeping LAG comparisons correct.
SELECT step_name, step_order
FROM (VALUES
('WELCOME', 1),
('USAGE_REASON', 2),
('INTEGRATIONS', 3),
('AGENT_CHOICE', 4),
('AGENT_NEW_RUN', 5),
('AGENT_INPUT', 6),
('CONGRATS', 7),
('GET_RESULTS', 8),
('MARKETPLACE_VISIT', 9),
('MARKETPLACE_ADD_AGENT', 10),
('MARKETPLACE_RUN_AGENT', 11),
('BUILDER_OPEN', 12),
('BUILDER_SAVE_AGENT', 13),
('BUILDER_RUN_AGENT', 14),
('VISIT_COPILOT', 15),
('RE_RUN_AGENT', 16),
('SCHEDULE_AGENT', 17),
('RUN_AGENTS', 18),
('RUN_3_DAYS', 19),
('TRIGGER_WEBHOOK', 20),
('RUN_14_DAYS', 21),
('RUN_AGENTS_100', 22)
) AS t(step_name, step_order)
),
raw AS (
SELECT
u."userId",
step_txt::text AS step
FROM platform."UserOnboarding" u
CROSS JOIN LATERAL UNNEST(u."completedSteps") AS step_txt
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
),
step_counts AS (
SELECT step, COUNT(DISTINCT "userId") AS users_completed
FROM raw GROUP BY step
),
funnel AS (
SELECT
a.step_name AS step,
a.step_order,
COALESCE(sc.users_completed, 0) AS users_completed,
ROUND(
100.0 * COALESCE(sc.users_completed, 0)
/ NULLIF(
LAG(COALESCE(sc.users_completed, 0)) OVER (ORDER BY a.step_order),
0
),
2
) AS pct_from_prev
FROM all_steps a
LEFT JOIN step_counts sc ON sc.step = a.step_name
)
SELECT * FROM funnel ORDER BY step_order

View File

@@ -0,0 +1,41 @@
-- =============================================================
-- View: analytics.user_onboarding_integration
-- Looker source alias: ds75 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Pre-aggregated count of users who selected each integration
-- during onboarding. One row per integration type, sorted
-- by popularity.
--
-- SOURCE TABLES
-- platform.UserOnboarding — integrations array column
--
-- OUTPUT COLUMNS
-- integration TEXT Integration name (e.g. 'github', 'slack', 'notion')
-- users_with_integration BIGINT Distinct users who selected this integration
--
-- WINDOW
-- Users who started onboarding in the last 90 days
--
-- EXAMPLE QUERIES
-- -- Full integration popularity ranking
-- SELECT * FROM analytics.user_onboarding_integration;
--
-- -- Top 5 integrations
-- SELECT * FROM analytics.user_onboarding_integration LIMIT 5;
-- =============================================================
WITH exploded AS (
SELECT
u."userId" AS user_id,
UNNEST(u."integrations") AS integration
FROM platform."UserOnboarding" u
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
)
SELECT
integration,
COUNT(DISTINCT user_id) AS users_with_integration
FROM exploded
WHERE integration IS NOT NULL AND integration <> ''
GROUP BY integration
ORDER BY users_with_integration DESC

View File

@@ -0,0 +1,145 @@
-- =============================================================
-- View: analytics.users_activities
-- Looker source alias: ds56 | Charts: 5
-- =============================================================
-- DESCRIPTION
-- One row per user with lifetime activity summary.
-- Joins login sessions with agent graphs, executions and
-- node-level runs to give a full picture of how engaged
-- each user is. Includes a convenience flag for 7-day
-- activation (did the user return at least 7 days after
-- their first login?).
--
-- SOURCE TABLES
-- auth.sessions — Login/session records
-- platform.AgentGraph — Graphs (agents) built by the user
-- platform.AgentGraphExecution — Agent run history
-- platform.AgentNodeExecution — Individual block execution history
--
-- PERFORMANCE NOTE
-- Each CTE aggregates its own table independently by userId.
-- This avoids the fan-out that occurs when driving every join
-- from user_logins across the two largest tables
-- (AgentGraphExecution and AgentNodeExecution).
--
-- OUTPUT COLUMNS
-- user_id TEXT Supabase user UUID
-- first_login_time TIMESTAMPTZ First ever session created_at
-- last_login_time TIMESTAMPTZ Most recent session created_at
-- last_visit_time TIMESTAMPTZ Max of last refresh or login
-- last_agent_save_time TIMESTAMPTZ Last time user saved an agent graph
-- agent_count BIGINT Number of distinct active graphs built (0 if none)
-- first_agent_run_time TIMESTAMPTZ First ever graph execution
-- last_agent_run_time TIMESTAMPTZ Most recent graph execution
-- unique_agent_runs BIGINT Distinct agent graphs ever run (0 if none)
-- agent_runs BIGINT Total graph execution count (0 if none)
-- node_execution_count BIGINT Total node executions across all runs
-- node_execution_failed BIGINT Node executions with FAILED status
-- node_execution_completed BIGINT Node executions with COMPLETED status
-- node_execution_terminated BIGINT Node executions with TERMINATED status
-- node_execution_queued BIGINT Node executions with QUEUED status
-- node_execution_running BIGINT Node executions with RUNNING status
-- is_active_after_7d INT 1=returned after day 7, 0=did not, NULL=too early to tell
-- node_execution_incomplete BIGINT Node executions with INCOMPLETE status
-- node_execution_review BIGINT Node executions with REVIEW status
--
-- EXAMPLE QUERIES
-- -- Users who ran at least one agent and returned after 7 days
-- SELECT COUNT(*) FROM analytics.users_activities
-- WHERE agent_runs > 0 AND is_active_after_7d = 1;
--
-- -- Top 10 most active users by agent runs
-- SELECT user_id, agent_runs, node_execution_count
-- FROM analytics.users_activities
-- ORDER BY agent_runs DESC LIMIT 10;
--
-- -- 7-day activation rate
-- SELECT
-- SUM(CASE WHEN is_active_after_7d = 1 THEN 1 ELSE 0 END)::float
-- / NULLIF(COUNT(CASE WHEN is_active_after_7d IS NOT NULL THEN 1 END), 0)
-- AS activation_rate
-- FROM analytics.users_activities;
-- =============================================================
WITH user_logins AS (
SELECT
user_id::text AS user_id,
MIN(created_at) AS first_login_time,
MAX(created_at) AS last_login_time,
GREATEST(
MAX(refreshed_at)::timestamptz,
MAX(created_at)::timestamptz
) AS last_visit_time
FROM auth.sessions
GROUP BY user_id
),
user_agents AS (
-- Aggregate AgentGraph directly by userId (no fan-out from user_logins)
SELECT
"userId"::text AS user_id,
MAX("updatedAt") AS last_agent_save_time,
COUNT(DISTINCT "id") AS agent_count
FROM platform."AgentGraph"
WHERE "isActive"
GROUP BY "userId"
),
user_graph_runs AS (
-- Aggregate AgentGraphExecution directly by userId
SELECT
"userId"::text AS user_id,
MIN("createdAt") AS first_agent_run_time,
MAX("createdAt") AS last_agent_run_time,
COUNT(DISTINCT "agentGraphId") AS unique_agent_runs,
COUNT("id") AS agent_runs
FROM platform."AgentGraphExecution"
GROUP BY "userId"
),
user_node_runs AS (
-- Aggregate AgentNodeExecution directly; resolve userId via a
-- single join to AgentGraphExecution instead of fanning out from
-- user_logins through both large tables.
SELECT
g."userId"::text AS user_id,
COUNT(*) AS node_execution_count,
COUNT(*) FILTER (WHERE n."executionStatus" = 'FAILED') AS node_execution_failed,
COUNT(*) FILTER (WHERE n."executionStatus" = 'COMPLETED') AS node_execution_completed,
COUNT(*) FILTER (WHERE n."executionStatus" = 'TERMINATED') AS node_execution_terminated,
COUNT(*) FILTER (WHERE n."executionStatus" = 'QUEUED') AS node_execution_queued,
COUNT(*) FILTER (WHERE n."executionStatus" = 'RUNNING') AS node_execution_running,
COUNT(*) FILTER (WHERE n."executionStatus" = 'INCOMPLETE') AS node_execution_incomplete,
COUNT(*) FILTER (WHERE n."executionStatus" = 'REVIEW') AS node_execution_review
FROM platform."AgentNodeExecution" n
JOIN platform."AgentGraphExecution" g
ON g."id" = n."agentGraphExecutionId"
GROUP BY g."userId"
)
SELECT
ul.user_id,
ul.first_login_time,
ul.last_login_time,
ul.last_visit_time,
ua.last_agent_save_time,
COALESCE(ua.agent_count, 0) AS agent_count,
gr.first_agent_run_time,
gr.last_agent_run_time,
COALESCE(gr.unique_agent_runs, 0) AS unique_agent_runs,
COALESCE(gr.agent_runs, 0) AS agent_runs,
COALESCE(nr.node_execution_count, 0) AS node_execution_count,
COALESCE(nr.node_execution_failed, 0) AS node_execution_failed,
COALESCE(nr.node_execution_completed, 0) AS node_execution_completed,
COALESCE(nr.node_execution_terminated, 0) AS node_execution_terminated,
COALESCE(nr.node_execution_queued, 0) AS node_execution_queued,
COALESCE(nr.node_execution_running, 0) AS node_execution_running,
CASE
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
AND ul.last_visit_time >= ul.first_login_time + INTERVAL '7 days' THEN 1
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
AND ul.last_visit_time < ul.first_login_time + INTERVAL '7 days' THEN 0
ELSE NULL
END AS is_active_after_7d,
COALESCE(nr.node_execution_incomplete, 0) AS node_execution_incomplete,
COALESCE(nr.node_execution_review, 0) AS node_execution_review
FROM user_logins ul
LEFT JOIN user_agents ua ON ul.user_id = ua.user_id
LEFT JOIN user_graph_runs gr ON ul.user_id = gr.user_id
LEFT JOIN user_node_runs nr ON ul.user_id = nr.user_id

View File

@@ -28,6 +28,7 @@ from backend.copilot.model import (
update_session_title,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.e2b_sandbox import kill_sandbox
from backend.copilot.tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
@@ -265,12 +266,12 @@ async def delete_session(
)
# Best-effort cleanup of the E2B sandbox (if any).
config = ChatConfig()
if config.use_e2b_sandbox and config.e2b_api_key:
from backend.copilot.tools.e2b_sandbox import kill_sandbox
# sandbox_id is in Redis; kill_sandbox() fetches it from there.
e2b_cfg = ChatConfig()
if e2b_cfg.e2b_active:
assert e2b_cfg.e2b_api_key # guaranteed by e2b_active check
try:
await kill_sandbox(session_id, config.e2b_api_key)
await kill_sandbox(session_id, e2b_cfg.e2b_api_key)
except Exception:
logger.warning(
"[E2B] Failed to kill sandbox for session %s", session_id[:12]

View File

@@ -638,7 +638,7 @@ async def test_process_review_action_auto_approve_creates_auto_approval_records(
# Mock get_node_executions to return node_id mapping
mock_get_node_executions = mocker.patch(
"backend.data.execution.get_node_executions"
"backend.api.features.executions.review.routes.get_node_executions"
)
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
mock_node_exec.node_exec_id = "test_node_123"
@@ -936,7 +936,7 @@ async def test_process_review_action_auto_approve_only_applies_to_approved_revie
# Mock get_node_executions to return node_id mapping
mock_get_node_executions = mocker.patch(
"backend.data.execution.get_node_executions"
"backend.api.features.executions.review.routes.get_node_executions"
)
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
mock_node_exec.node_exec_id = "node_exec_approved"
@@ -1148,7 +1148,7 @@ async def test_process_review_action_per_review_auto_approve_granularity(
# Mock get_node_executions to return batch node data
mock_get_node_executions = mocker.patch(
"backend.data.execution.get_node_executions"
"backend.api.features.executions.review.routes.get_node_executions"
)
# Create mock node executions for each review
mock_node_execs = []

View File

@@ -6,10 +6,15 @@ import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, HTTPException, Query, Security, status
from prisma.enums import ReviewStatus
from backend.copilot.constants import (
is_copilot_synthetic_id,
parse_node_id_from_exec_id,
)
from backend.data.execution import (
ExecutionContext,
ExecutionStatus,
get_graph_execution_meta,
get_node_executions,
)
from backend.data.graph import get_graph_settings
from backend.data.human_review import (
@@ -36,6 +41,38 @@ router = APIRouter(
)
async def _resolve_node_ids(
node_exec_ids: list[str],
graph_exec_id: str,
is_copilot: bool,
) -> dict[str, str]:
"""Resolve node_exec_id -> node_id for auto-approval records.
CoPilot synthetic IDs encode node_id in the format "{node_id}:{random}".
Graph executions look up node_id from NodeExecution records.
"""
if not node_exec_ids:
return {}
if is_copilot:
return {neid: parse_node_id_from_exec_id(neid) for neid in node_exec_ids}
node_execs = await get_node_executions(
graph_exec_id=graph_exec_id, include_exec_data=False
)
node_exec_map = {ne.node_exec_id: ne.node_id for ne in node_execs}
result = {}
for neid in node_exec_ids:
if neid in node_exec_map:
result[neid] = node_exec_map[neid]
else:
logger.error(
f"Failed to resolve node_id for {neid}: Node execution not found."
)
return result
@router.get(
"/pending",
summary="Get Pending Reviews",
@@ -110,14 +147,16 @@ async def list_pending_reviews_for_execution(
"""
# Verify user owns the graph execution before returning reviews
graph_exec = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
# (CoPilot synthetic IDs don't have graph execution records)
if not is_copilot_synthetic_id(graph_exec_id):
graph_exec = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
return await get_pending_reviews_for_execution(graph_exec_id, user_id)
@@ -160,30 +199,26 @@ async def process_review_action(
)
graph_exec_id = next(iter(graph_exec_ids))
is_copilot = is_copilot_synthetic_id(graph_exec_id)
# Validate execution status before processing reviews
graph_exec_meta = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec_meta:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
# Only allow processing reviews if execution is paused for review
# or incomplete (partial execution with some reviews already processed)
if graph_exec_meta.status not in (
ExecutionStatus.REVIEW,
ExecutionStatus.INCOMPLETE,
):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
f"Reviews can only be processed when execution is paused (REVIEW status). "
f"Current status: {graph_exec_meta.status}",
# Validate execution status for graph executions (skip for CoPilot synthetic IDs)
if not is_copilot:
graph_exec_meta = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec_meta:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
if graph_exec_meta.status not in (
ExecutionStatus.REVIEW,
ExecutionStatus.INCOMPLETE,
):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}",
)
# Build review decisions map and track which reviews requested auto-approval
# Auto-approved reviews use original data (no modifications allowed)
@@ -236,7 +271,7 @@ async def process_review_action(
)
return (node_id, False)
# Collect node_exec_ids that need auto-approval
# Collect node_exec_ids that need auto-approval and resolve their node_ids
node_exec_ids_needing_auto_approval = [
node_exec_id
for node_exec_id, review_result in updated_reviews.items()
@@ -244,29 +279,16 @@ async def process_review_action(
and auto_approve_requests.get(node_exec_id, False)
]
# Batch-fetch node executions to get node_ids
node_id_map = await _resolve_node_ids(
node_exec_ids_needing_auto_approval, graph_exec_id, is_copilot
)
# Deduplicate by node_id — one auto-approval per node
nodes_needing_auto_approval: dict[str, Any] = {}
if node_exec_ids_needing_auto_approval:
from backend.data.execution import get_node_executions
node_execs = await get_node_executions(
graph_exec_id=graph_exec_id, include_exec_data=False
)
node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
for node_exec_id in node_exec_ids_needing_auto_approval:
node_exec = node_exec_map.get(node_exec_id)
if node_exec:
review_result = updated_reviews[node_exec_id]
# Use the first approved review for this node (deduplicate by node_id)
if node_exec.node_id not in nodes_needing_auto_approval:
nodes_needing_auto_approval[node_exec.node_id] = review_result
else:
logger.error(
f"Failed to create auto-approval record for {node_exec_id}: "
f"Node execution not found. This may indicate a race condition "
f"or data inconsistency."
)
for node_exec_id in node_exec_ids_needing_auto_approval:
node_id = node_id_map.get(node_exec_id)
if node_id and node_id not in nodes_needing_auto_approval:
nodes_needing_auto_approval[node_id] = updated_reviews[node_exec_id]
# Execute all auto-approval creations in parallel (deduplicated by node_id)
auto_approval_results = await asyncio.gather(
@@ -281,13 +303,11 @@ async def process_review_action(
auto_approval_failed_count = 0
for result in auto_approval_results:
if isinstance(result, Exception):
# Unexpected exception during auto-approval creation
auto_approval_failed_count += 1
logger.error(
f"Unexpected exception during auto-approval creation: {result}"
)
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
# Auto-approval creation failed (returned False)
auto_approval_failed_count += 1
# Count results
@@ -302,22 +322,20 @@ async def process_review_action(
if review.status == ReviewStatus.REJECTED
)
# Resume execution only if ALL pending reviews for this execution have been processed
if updated_reviews:
# Resume graph execution only for real graph executions (not CoPilot)
# CoPilot sessions are resumed by the LLM retrying run_block with review_id
if not is_copilot and updated_reviews:
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
if not still_has_pending:
# Get the graph_id from any processed review
first_review = next(iter(updated_reviews.values()))
try:
# Fetch user and settings to build complete execution context
user = await get_user_by_id(user_id)
settings = await get_graph_settings(
user_id=user_id, graph_id=first_review.graph_id
)
# Preserve user's timezone preference when resuming execution
user_timezone = (
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
)

View File

@@ -24,7 +24,7 @@ from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests, validate_url
from backend.util.request import HTTPClientError, Requests, validate_url_host
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -80,7 +80,7 @@ async def discover_tools(
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
@@ -167,7 +167,7 @@ async def mcp_oauth_login(
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
@@ -187,7 +187,7 @@ async def mcp_oauth_login(
# Validate the auth server URL from metadata to prevent SSRF.
try:
await validate_url(auth_server_url, trusted_origins=[])
await validate_url_host(auth_server_url)
except ValueError as e:
raise fastapi.HTTPException(
status_code=400,
@@ -234,7 +234,7 @@ async def mcp_oauth_login(
if registration_endpoint:
# Validate the registration endpoint to prevent SSRF via metadata.
try:
await validate_url(registration_endpoint, trusted_origins=[])
await validate_url_host(registration_endpoint)
except ValueError:
pass # Skip registration, fall back to default client_id
else:
@@ -429,7 +429,7 @@ async def mcp_store_token(
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")

View File

@@ -32,9 +32,9 @@ async def client():
@pytest.fixture(autouse=True)
def _bypass_ssrf_validation():
"""Bypass validate_url in all route tests (test URLs don't resolve)."""
"""Bypass validate_url_host in all route tests (test URLs don't resolve)."""
with patch(
"backend.api.features.mcp.routes.validate_url",
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
):
yield
@@ -521,12 +521,12 @@ class TestStoreToken:
class TestSSRFValidation:
"""Verify that validate_url is enforced on all endpoints."""
"""Verify that validate_url_host is enforced on all endpoints."""
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
@@ -541,7 +541,7 @@ class TestSSRFValidation:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked private IP"),
):
@@ -556,7 +556,7 @@ class TestSSRFValidation:
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):

View File

@@ -418,6 +418,8 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
_optimized_description: ClassVar[str | None] = None
def __init__(
self,
id: str = "",
@@ -470,6 +472,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.block_type = block_type
self.webhook_config = webhook_config
self.is_sensitive_action = is_sensitive_action
# Read from ClassVar set by initialize_blocks()
self.optimized_description: str | None = type(self)._optimized_description
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
if self.webhook_config:
@@ -620,6 +624,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
graph_id: str,
graph_version: int,
execution_context: "ExecutionContext",
is_graph_execution: bool = True,
**kwargs,
) -> tuple[bool, BlockInput]:
"""
@@ -648,6 +653,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
graph_version=graph_version,
block_name=self.name,
editable=True,
is_graph_execution=is_graph_execution,
)
if decision is None:

View File

@@ -126,7 +126,7 @@ class PrintToConsoleBlock(Block):
output_schema=PrintToConsoleBlock.Output,
test_input={"text": "Hello, World!"},
is_sensitive_action=True,
disabled=True, # Disabled per Nick Tindle's request (OPEN-3000)
disabled=True,
test_output=[
("output", "Hello, World!"),
("status", "printed"),

View File

@@ -142,7 +142,7 @@ class BaseE2BExecutorMixin:
start_timestamp = ts_result.stdout.strip() if ts_result.stdout else None
# Execute the code
execution = await sandbox.run_code(
execution = await sandbox.run_code( # type: ignore[attr-defined]
code,
language=language.value,
on_error=lambda e: sandbox.kill(), # Kill the sandbox on error

View File

@@ -67,6 +67,7 @@ class HITLReviewHelper:
graph_version: int,
block_name: str = "Block",
editable: bool = False,
is_graph_execution: bool = True,
) -> Optional[ReviewResult]:
"""
Handle a review request for a block that requires human review.
@@ -143,10 +144,11 @@ class HITLReviewHelper:
logger.info(
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
)
await HITLReviewHelper.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
if is_graph_execution:
await HITLReviewHelper.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
return None # Signal that execution should pause
# Mark review as processed if not already done
@@ -168,6 +170,7 @@ class HITLReviewHelper:
graph_version: int,
block_name: str = "Block",
editable: bool = False,
is_graph_execution: bool = True,
) -> Optional[ReviewDecision]:
"""
Handle a review request and return the decision in a single call.
@@ -197,6 +200,7 @@ class HITLReviewHelper:
graph_version=graph_version,
block_name=block_name,
editable=editable,
is_graph_execution=is_graph_execution,
)
if review_result is None:

View File

@@ -17,7 +17,7 @@ from backend.blocks.jina._auth import (
from backend.blocks.search import GetRequest
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.request import HTTPClientError, HTTPServerError, validate_url
from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host
class SearchTheWebBlock(Block, GetRequest):
@@ -112,7 +112,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
) -> BlockOutput:
if input_data.raw_content:
try:
parsed_url, _, _ = await validate_url(input_data.url, [])
parsed_url, _, _ = await validate_url_host(input_data.url)
url = parsed_url.geturl()
except ValueError as e:
yield "error", f"Invalid URL: {e}"

View File

@@ -31,10 +31,14 @@ from backend.data.model import (
)
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.clients import OPENROUTER_BASE_URL
from backend.util.logging import TruncatedLogger
from backend.util.prompt import compress_context, estimate_token_count
from backend.util.request import validate_url_host
from backend.util.settings import Settings
from backend.util.text import TextFormatter
settings = Settings()
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
fmt = TextFormatter(autoescape=False)
@@ -804,6 +808,11 @@ async def llm_call(
if tools:
raise ValueError("Ollama does not support tools.")
# Validate user-provided Ollama host to prevent SSRF etc.
await validate_url_host(
ollama_host, trusted_hostnames=[settings.config.ollama_host]
)
client = ollama.AsyncClient(host=ollama_host)
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
@@ -825,7 +834,7 @@ async def llm_call(
elif provider == "open_router":
tools_param = tools if tools else openai.NOT_GIVEN
client = openai.AsyncOpenAI(
base_url="https://openrouter.ai/api/v1",
base_url=OPENROUTER_BASE_URL,
api_key=credentials.api_key.get_secret_value(),
)

View File

@@ -21,6 +21,7 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.clients import OPENROUTER_BASE_URL
from backend.util.logging import TruncatedLogger
logger = TruncatedLogger(logging.getLogger(__name__), "[Perplexity-Block]")
@@ -136,7 +137,7 @@ class PerplexityBlock(Block):
) -> dict[str, Any]:
"""Call Perplexity via OpenRouter and extract annotations."""
client = openai.AsyncOpenAI(
base_url="https://openrouter.ai/api/v1",
base_url=OPENROUTER_BASE_URL,
api_key=credentials.api_key.get_secret_value(),
)

View File

@@ -1,10 +1,13 @@
"""Configuration management for chat system."""
import os
from typing import Literal
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
from backend.util.clients import OPENROUTER_BASE_URL
class ChatConfig(BaseSettings):
"""Configuration for the chat system."""
@@ -19,7 +22,7 @@ class ChatConfig(BaseSettings):
)
api_key: str | None = Field(default=None, description="OpenAI API key")
base_url: str | None = Field(
default="https://openrouter.ai/api/v1",
default=OPENROUTER_BASE_URL,
description="Base URL for API (e.g., for OpenRouter)",
)
@@ -112,9 +115,37 @@ class ChatConfig(BaseSettings):
description="E2B sandbox template to use for copilot sessions.",
)
e2b_sandbox_timeout: int = Field(
default=43200, # 12 hours — same as session_ttl
description="E2B sandbox keepalive timeout in seconds.",
default=10800, # 3 hours — wall-clock timeout, not idle; explicit pause is primary
description="E2B sandbox running-time timeout (seconds). "
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
"mechanism; this is the safety net.",
)
e2b_sandbox_on_timeout: Literal["kill", "pause"] = Field(
default="pause",
description="E2B lifecycle action on timeout: 'pause' (default, free) or 'kill'.",
)
@property
def e2b_active(self) -> bool:
"""True when E2B is enabled and the API key is present.
Single source of truth for "should we use E2B right now?".
Prefer this over combining ``use_e2b_sandbox`` and ``e2b_api_key``
separately at call sites.
"""
return self.use_e2b_sandbox and bool(self.e2b_api_key)
@property
def active_e2b_api_key(self) -> str | None:
"""Return the E2B API key when E2B is enabled and configured, else None.
Combines the ``use_e2b_sandbox`` flag check and key presence into one.
Use in callers::
if api_key := config.active_e2b_api_key:
# E2B is active; api_key is narrowed to str
"""
return self.e2b_api_key if self.e2b_active else None
@field_validator("use_e2b_sandbox", mode="before")
@classmethod
@@ -164,7 +195,7 @@ class ChatConfig(BaseSettings):
if not v:
v = os.getenv("OPENAI_BASE_URL")
if not v:
v = "https://openrouter.ai/api/v1"
v = OPENROUTER_BASE_URL
return v
@field_validator("use_claude_agent_sdk", mode="before")

View File

@@ -0,0 +1,38 @@
"""Unit tests for ChatConfig."""
import pytest
from .config import ChatConfig
# Env vars that the ChatConfig validators read — must be cleared so they don't
# override the explicit constructor values we pass in each test.
_E2B_ENV_VARS = (
"CHAT_USE_E2B_SANDBOX",
"CHAT_E2B_API_KEY",
"E2B_API_KEY",
)
@pytest.fixture(autouse=True)
def _clean_e2b_env(monkeypatch: pytest.MonkeyPatch) -> None:
for var in _E2B_ENV_VARS:
monkeypatch.delenv(var, raising=False)
class TestE2BActive:
"""Tests for the e2b_active property — single source of truth for E2B usage."""
def test_both_enabled_and_key_present_returns_true(self):
"""e2b_active is True when use_e2b_sandbox=True and e2b_api_key is set."""
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key="test-key")
assert cfg.e2b_active is True
def test_enabled_but_missing_key_returns_false(self):
"""e2b_active is False when use_e2b_sandbox=True but e2b_api_key is absent."""
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key=None)
assert cfg.e2b_active is False
def test_disabled_returns_false(self):
"""e2b_active is False when use_e2b_sandbox=False regardless of key."""
cfg = ChatConfig(use_e2b_sandbox=False, e2b_api_key="test-key")
assert cfg.e2b_active is False

View File

@@ -6,6 +6,32 @@
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
# Prefix for all synthetic IDs generated by CoPilot block execution.
# Used to distinguish CoPilot-generated records from real graph execution records
# in PendingHumanReview and other tables.
COPILOT_SYNTHETIC_ID_PREFIX = "copilot-"
# Sub-prefixes for session-scoped and node-scoped synthetic IDs.
COPILOT_SESSION_PREFIX = f"{COPILOT_SYNTHETIC_ID_PREFIX}session-"
COPILOT_NODE_PREFIX = f"{COPILOT_SYNTHETIC_ID_PREFIX}node-"
# Separator used in synthetic node_exec_id to encode node_id.
# Format: "{node_id}:{random_hex}" — extract node_id via rsplit(":", 1)[0]
COPILOT_NODE_EXEC_ID_SEPARATOR = ":"
# Compaction notice messages shown to users.
COMPACTION_DONE_MSG = "Earlier messages were summarized to fit within context limits."
COMPACTION_TOOL_NAME = "context_compaction"
def is_copilot_synthetic_id(id_value: str) -> bool:
"""Check if an ID is a CoPilot synthetic ID (not from a real graph execution)."""
return id_value.startswith(COPILOT_SYNTHETIC_ID_PREFIX)
def parse_node_id_from_exec_id(node_exec_id: str) -> str:
"""Extract node_id from a synthetic node_exec_id.
Format: "{node_id}:{random_hex}" → returns "{node_id}".
"""
return node_exec_id.rsplit(COPILOT_NODE_EXEC_ID_SEPARATOR, 1)[0]

View File

@@ -0,0 +1,115 @@
"""Shared execution context for copilot SDK tool handlers.
All context variables and their accessors live here so that
``tool_adapter``, ``file_ref``, and ``e2b_file_tools`` can import them
without creating circular dependencies.
"""
import os
import re
from contextvars import ContextVar
from typing import TYPE_CHECKING
from backend.copilot.model import ChatSession
if TYPE_CHECKING:
from e2b import AsyncSandbox
# Allowed base directory for the Read tool.
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# Encoded project-directory name for the current session (e.g.
# "-private-tmp-copilot-<uuid>"). Set by set_execution_context() so path
# validation can scope tool-results reads to the current session.
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
"_current_sandbox", default=None
)
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does."""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
def set_execution_context(
user_id: str | None,
session: ChatSession,
sandbox: "AsyncSandbox | None" = None,
sdk_cwd: str | None = None,
) -> None:
"""Set per-turn context variables used by file-resolution tool handlers."""
_current_user_id.set(user_id)
_current_session.set(session)
_current_sandbox.set(sandbox)
_current_sdk_cwd.set(sdk_cwd or "")
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
def get_execution_context() -> tuple[str | None, ChatSession | None]:
"""Return the current (user_id, session) pair for the active request."""
return _current_user_id.get(), _current_session.get()
def get_current_sandbox() -> "AsyncSandbox | None":
"""Return the E2B sandbox for the current session, or None if not active."""
return _current_sandbox.get()
def get_sdk_cwd() -> str:
"""Return the SDK working directory for the current session (empty string if unset)."""
return _current_sdk_cwd.get()
E2B_WORKDIR = "/home/user"
def resolve_sandbox_path(path: str) -> str:
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
Raises :class:`ValueError` if the resolved path escapes the sandbox.
"""
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
normalized = os.path.normpath(candidate)
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
return normalized
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
"""Return True if *path* is within an allowed host-filesystem location.
Allowed:
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
- Files under ``~/.claude/projects/<encoded-cwd>/tool-results/`` (SDK tool-results)
"""
if not path:
return False
if path.startswith("~"):
resolved = os.path.realpath(os.path.expanduser(path))
elif not os.path.isabs(path) and sdk_cwd:
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
else:
resolved = os.path.realpath(path)
if sdk_cwd:
norm_cwd = os.path.realpath(sdk_cwd)
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
return True
encoded = _current_project_dir.get("")
if encoded:
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
if resolved == tool_results_dir or resolved.startswith(
tool_results_dir + os.sep
):
return True
return False

View File

@@ -0,0 +1,163 @@
"""Tests for context.py — execution context variables and path helpers."""
from __future__ import annotations
import os
import tempfile
from unittest.mock import MagicMock
import pytest
from backend.copilot.context import (
_SDK_PROJECTS_DIR,
_current_project_dir,
get_current_sandbox,
get_execution_context,
get_sdk_cwd,
is_allowed_local_path,
resolve_sandbox_path,
set_execution_context,
)
def _make_session() -> MagicMock:
s = MagicMock()
s.session_id = "test-session"
return s
# ---------------------------------------------------------------------------
# Context variable getters
# ---------------------------------------------------------------------------
def test_get_execution_context_defaults():
"""get_execution_context returns (None, session) when user_id is not set."""
set_execution_context(None, _make_session())
user_id, session = get_execution_context()
assert user_id is None
assert session is not None
def test_set_and_get_execution_context():
"""set_execution_context stores user_id and session."""
mock_session = _make_session()
set_execution_context("user-abc", mock_session)
user_id, session = get_execution_context()
assert user_id == "user-abc"
assert session is mock_session
def test_get_current_sandbox_none_by_default():
"""get_current_sandbox returns None when no sandbox is set."""
set_execution_context("u1", _make_session(), sandbox=None)
assert get_current_sandbox() is None
def test_get_current_sandbox_returns_set_value():
"""get_current_sandbox returns the sandbox set via set_execution_context."""
mock_sandbox = MagicMock()
set_execution_context("u1", _make_session(), sandbox=mock_sandbox)
assert get_current_sandbox() is mock_sandbox
def test_get_sdk_cwd_empty_when_not_set():
"""get_sdk_cwd returns empty string when sdk_cwd is not set."""
set_execution_context("u1", _make_session(), sdk_cwd=None)
assert get_sdk_cwd() == ""
def test_get_sdk_cwd_returns_set_value():
"""get_sdk_cwd returns the value set via set_execution_context."""
set_execution_context("u1", _make_session(), sdk_cwd="/tmp/copilot-test")
assert get_sdk_cwd() == "/tmp/copilot-test"
# ---------------------------------------------------------------------------
# is_allowed_local_path
# ---------------------------------------------------------------------------
def test_is_allowed_local_path_empty():
assert not is_allowed_local_path("")
def test_is_allowed_local_path_inside_sdk_cwd():
with tempfile.TemporaryDirectory() as cwd:
path = os.path.join(cwd, "file.txt")
assert is_allowed_local_path(path, cwd)
def test_is_allowed_local_path_sdk_cwd_itself():
with tempfile.TemporaryDirectory() as cwd:
assert is_allowed_local_path(cwd, cwd)
def test_is_allowed_local_path_outside_sdk_cwd():
with tempfile.TemporaryDirectory() as cwd:
assert not is_allowed_local_path("/etc/passwd", cwd)
def test_is_allowed_local_path_no_sdk_cwd_no_project_dir():
"""Without sdk_cwd or project_dir, all paths are rejected."""
_current_project_dir.set("")
assert not is_allowed_local_path("/tmp/some-file.txt", sdk_cwd=None)
def test_is_allowed_local_path_tool_results_dir():
"""Files under the tool-results directory for the current project are allowed."""
encoded = "test-encoded-dir"
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
path = os.path.join(tool_results_dir, "output.txt")
_current_project_dir.set(encoded)
try:
assert is_allowed_local_path(path, sdk_cwd=None)
finally:
_current_project_dir.set("")
def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
"""A path adjacent to tool-results/ but not inside it is rejected."""
encoded = "test-encoded-dir"
sibling_path = os.path.join(_SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
_current_project_dir.set(encoded)
try:
assert not is_allowed_local_path(sibling_path, sdk_cwd=None)
finally:
_current_project_dir.set("")
# ---------------------------------------------------------------------------
# resolve_sandbox_path
# ---------------------------------------------------------------------------
def test_resolve_sandbox_path_absolute_valid():
assert (
resolve_sandbox_path("/home/user/project/main.py")
== "/home/user/project/main.py"
)
def test_resolve_sandbox_path_relative():
assert resolve_sandbox_path("project/main.py") == "/home/user/project/main.py"
def test_resolve_sandbox_path_workdir_itself():
assert resolve_sandbox_path("/home/user") == "/home/user"
def test_resolve_sandbox_path_normalizes_dots():
assert resolve_sandbox_path("/home/user/a/../b") == "/home/user/b"
def test_resolve_sandbox_path_escape_raises():
with pytest.raises(ValueError, match="/home/user"):
resolve_sandbox_path("/home/user/../../etc/passwd")
def test_resolve_sandbox_path_absolute_outside_raises():
with pytest.raises(ValueError, match="/home/user"):
resolve_sandbox_path("/etc/passwd")

View File

@@ -0,0 +1,138 @@
"""Scheduler job to generate LLM-optimized block descriptions.
Runs periodically to rewrite block descriptions into concise, actionable
summaries that help the copilot LLM pick the right blocks during agent
generation.
"""
import asyncio
import logging
from backend.blocks import get_blocks
from backend.util.clients import get_database_manager_client, get_openai_client
logger = logging.getLogger(__name__)
SYSTEM_PROMPT = (
"You are a technical writer for an automation platform. "
"Rewrite the following block description to be concise (under 50 words), "
"informative, and actionable. Focus on what the block does and when to "
"use it. Output ONLY the rewritten description, nothing else. "
"Do not use markdown formatting."
)
# Rate-limit delay between sequential LLM calls (seconds)
_RATE_LIMIT_DELAY = 0.5
# Maximum tokens for optimized description generation
_MAX_DESCRIPTION_TOKENS = 150
# Model for generating optimized descriptions (fast, cheap)
_MODEL = "gpt-4o-mini"
async def _optimize_descriptions(blocks: list[dict[str, str]]) -> dict[str, str]:
"""Call the shared OpenAI client to rewrite each block description."""
client = get_openai_client()
if client is None:
logger.error(
"No OpenAI client configured, skipping block description optimization"
)
return {}
results: dict[str, str] = {}
for block in blocks:
block_id = block["id"]
block_name = block["name"]
description = block["description"]
try:
response = await client.chat.completions.create(
model=_MODEL,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{
"role": "user",
"content": f"Block name: {block_name}\nDescription: {description}",
},
],
max_tokens=_MAX_DESCRIPTION_TOKENS,
)
optimized = (response.choices[0].message.content or "").strip()
if optimized:
results[block_id] = optimized
logger.debug("Optimized description for %s", block_name)
else:
logger.warning("Empty response for block %s", block_name)
except Exception:
logger.warning(
"Failed to optimize description for %s", block_name, exc_info=True
)
await asyncio.sleep(_RATE_LIMIT_DELAY)
return results
def optimize_block_descriptions() -> dict[str, int]:
"""Generate optimized descriptions for blocks that don't have one yet.
Uses the shared OpenAI client to rewrite block descriptions into concise
summaries suitable for agent generation prompts.
Returns:
Dict with counts: processed, success, failed, skipped.
"""
db_client = get_database_manager_client()
blocks = db_client.get_blocks_needing_optimization()
if not blocks:
logger.info("All blocks already have optimized descriptions")
return {"processed": 0, "success": 0, "failed": 0, "skipped": 0}
logger.info("Found %d blocks needing optimized descriptions", len(blocks))
non_empty = [b for b in blocks if b.get("description", "").strip()]
skipped = len(blocks) - len(non_empty)
new_descriptions = asyncio.run(_optimize_descriptions(non_empty))
stats = {
"processed": len(non_empty),
"success": len(new_descriptions),
"failed": len(non_empty) - len(new_descriptions),
"skipped": skipped,
}
logger.info(
"Block description optimization complete: "
"%d/%d succeeded, %d failed, %d skipped",
stats["success"],
stats["processed"],
stats["failed"],
stats["skipped"],
)
if new_descriptions:
for block_id, optimized in new_descriptions.items():
db_client.update_block_optimized_description(block_id, optimized)
# Update in-memory descriptions first so the cache rebuilds with fresh data.
try:
block_classes = get_blocks()
for block_id, optimized in new_descriptions.items():
if block_id in block_classes:
block_classes[block_id]._optimized_description = optimized
logger.info(
"Updated %d in-memory block descriptions", len(new_descriptions)
)
except Exception:
logger.warning(
"Could not update in-memory block descriptions", exc_info=True
)
from backend.copilot.tools.agent_generator.blocks import (
reset_block_caches, # local to avoid circular import
)
reset_block_caches()
return stats

View File

@@ -0,0 +1,91 @@
"""Unit tests for optimize_blocks._optimize_descriptions."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from backend.copilot.optimize_blocks import _RATE_LIMIT_DELAY, _optimize_descriptions
def _make_client_response(text: str) -> MagicMock:
"""Build a minimal mock that looks like an OpenAI ChatCompletion response."""
choice = MagicMock()
choice.message.content = text
response = MagicMock()
response.choices = [choice]
return response
def _run(coro):
return asyncio.get_event_loop().run_until_complete(coro)
class TestOptimizeDescriptions:
"""Tests for _optimize_descriptions async function."""
def test_returns_empty_when_no_client(self):
with patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=None
):
result = _run(
_optimize_descriptions([{"id": "b1", "name": "B", "description": "d"}])
)
assert result == {}
def test_success_single_block(self):
client = MagicMock()
client.chat.completions.create = AsyncMock(
return_value=_make_client_response("Short desc.")
)
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
with (
patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
),
patch(
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
),
):
result = _run(_optimize_descriptions(blocks))
assert result == {"b1": "Short desc."}
client.chat.completions.create.assert_called_once()
def test_skips_block_on_exception(self):
client = MagicMock()
client.chat.completions.create = AsyncMock(side_effect=Exception("API error"))
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
with (
patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
),
patch(
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
),
):
result = _run(_optimize_descriptions(blocks))
assert result == {}
def test_sleeps_between_blocks(self):
client = MagicMock()
client.chat.completions.create = AsyncMock(
return_value=_make_client_response("desc")
)
blocks = [
{"id": "b1", "name": "B1", "description": "d1"},
{"id": "b2", "name": "B2", "description": "d2"},
]
sleep_mock = AsyncMock()
with (
patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
),
patch("backend.copilot.optimize_blocks.asyncio.sleep", sleep_mock),
):
_run(_optimize_descriptions(blocks))
assert sleep_mock.call_count == 2
sleep_mock.assert_called_with(_RATE_LIMIT_DELAY)

View File

@@ -26,6 +26,33 @@ your message as a Markdown link or image:
The `download_url` field in the `write_workspace_file` response is already
in the correct format — paste it directly after the `(` in the Markdown.
### Passing file content to tools — @@agptfile: references
Instead of copying large file contents into a tool argument, pass a file
reference and the platform will load the content for you.
Syntax: `@@agptfile:<uri>[<start>-<end>]`
- `<uri>` **must** start with `workspace://` or `/` (absolute path):
- `workspace://<file_id>` — workspace file by ID
- `workspace:///<path>` — workspace file by virtual path
- `/absolute/local/path` — ephemeral or sdk_cwd file
- E2B sandbox absolute path (e.g. `/home/user/script.py`)
- `[<start>-<end>]` is an optional 1-indexed inclusive line range.
- URIs that do not start with `workspace://` or `/` are **not** expanded.
Examples:
```
@@agptfile:workspace://abc123
@@agptfile:workspace://abc123[10-50]
@@agptfile:workspace:///reports/q1.md
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
@@agptfile:/home/user/script.py
```
You can embed a reference inside any string argument, or use it as the entire
value. Multiple references in one argument are all expanded.
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
All tasks must run in the foreground.

View File

@@ -0,0 +1,155 @@
## Agent Generation Guide
You can create, edit, and customize agents directly. You ARE the brain —
generate the agent JSON yourself using block schemas, then validate and save.
### Workflow for Creating/Editing Agents
1. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
search for relevant blocks. This returns block IDs, names, descriptions,
and full input/output schemas.
2. **Find library agents**: Call `find_library_agent` to discover reusable
agents that can be composed as sub-agents via `AgentExecutorBlock`.
3. **Generate JSON**: Build the agent JSON using block schemas:
- Use block IDs from step 1 as `block_id` in nodes
- Wire outputs to inputs using links
- Set design-time config in `input_default`
- Use `AgentInputBlock` for values the user provides at runtime
4. **Write to workspace**: Save the JSON to a workspace file so the user
can review it: `write_workspace_file(filename="agent.json", content=...)`
5. **Validate**: Call `validate_agent_graph` with the agent JSON to check
for errors
6. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
or fix manually based on the error descriptions. Iterate until valid.
7. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
the final `agent_json`
### Agent JSON Structure
```json
{
"id": "<UUID v4>", // auto-generated if omitted
"version": 1,
"is_active": true,
"name": "Agent Name",
"description": "What the agent does",
"nodes": [
{
"id": "<UUID v4>",
"block_id": "<block UUID from find_block>",
"input_default": {
"field_name": "design-time value"
},
"metadata": {
"position": {"x": 0, "y": 0},
"customized_name": "Optional display name"
}
}
],
"links": [
{
"id": "<UUID v4>",
"source_id": "<source node UUID>",
"source_name": "output_field_name",
"sink_id": "<sink node UUID>",
"sink_name": "input_field_name",
"is_static": false
}
]
}
```
### REQUIRED: AgentInputBlock and AgentOutputBlock
Every agent MUST include at least one AgentInputBlock and one AgentOutputBlock.
These define the agent's interface — what it accepts and what it produces.
**AgentInputBlock** (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`):
- Defines a user-facing input field on the agent
- Required `input_default` fields: `name` (str), `value` (default: null)
- Optional: `title`, `description`, `placeholder_values` (for dropdowns)
- Output: `result` — the user-provided value at runtime
- Create one AgentInputBlock per distinct input the agent needs
**AgentOutputBlock** (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`):
- Defines a user-facing output displayed after the agent runs
- Required `input_default` fields: `name` (str)
- The `value` input should be linked from another block's output
- Optional: `title`, `description`, `format` (Jinja2 template)
- Create one AgentOutputBlock per distinct result to show the user
Without these blocks, the agent has no interface and the user cannot provide
inputs or see outputs. NEVER skip them.
### Key Rules
- **Name & description**: Include `name` and `description` in the agent JSON
when creating a new agent, or when editing and the agent's purpose changed.
Without these the agent gets a generic default name.
- **Design-time vs runtime**: `input_default` = values known at build time.
For user-provided values, create an `AgentInputBlock` node and link its
output to the consuming block's input.
- **Credentials**: Do NOT require credentials upfront. Users configure
credentials later in the platform UI after the agent is saved.
- **Node spacing**: Position nodes with at least 800 X-units between them.
- **Nested properties**: Use `parentField_#_childField` notation in link
sink_name/source_name to access nested object fields.
- **is_static links**: Set `is_static: true` when the link carries a
design-time constant (matches a field in inputSchema with a default).
- **ConditionBlock**: Needs a `StoreValueBlock` wired to its `value2` input.
- **Prompt templates**: Use `{{variable}}` (double curly braces) for
literal braces in prompt strings — single `{` and `}` are for
template variables.
- **AgentExecutorBlock**: When composing sub-agents, set `graph_id` and
`graph_version` in input_default, and wire inputs/outputs to match
the sub-agent's schema.
### Using Sub-Agents (AgentExecutorBlock)
To compose agents using other agents as sub-agents:
1. Call `find_library_agent` to find the sub-agent — the response includes
`graph_id`, `graph_version`, `input_schema`, and `output_schema`
2. Create an `AgentExecutorBlock` node (ID: `e189baac-8c20-45a1-94a7-55177ea42565`)
3. Set `input_default`:
- `graph_id`: from the library agent's `graph_id`
- `graph_version`: from the library agent's `graph_version`
- `input_schema`: from the library agent's `input_schema` (JSON Schema)
- `output_schema`: from the library agent's `output_schema` (JSON Schema)
- `user_id`: leave as `""` (filled at runtime)
- `inputs`: `{}` (populated by links at runtime)
4. Wire inputs: link to sink names matching the sub-agent's `input_schema`
property names (e.g., if input_schema has a `"url"` property, use
`"url"` as the sink_name)
5. Wire outputs: link from source names matching the sub-agent's
`output_schema` property names
6. Pass `library_agent_ids` to `create_agent`/`customize_agent` with
the library agent IDs used, so the fixer can validate schemas
### Using MCP Tools (MCPToolBlock)
To use an MCP (Model Context Protocol) tool as a node in the agent:
1. The user must specify which MCP server URL and tool name they want
2. Create an `MCPToolBlock` node (ID: `a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4`)
3. Set `input_default`:
- `server_url`: the MCP server URL (e.g. `"https://mcp.example.com/sse"`)
- `selected_tool`: the tool name on that server
- `tool_input_schema`: JSON Schema for the tool's inputs
- `tool_arguments`: `{}` (populated by links or hardcoded values)
4. The block requires MCP credentials — the user configures these in the
platform UI after the agent is saved
5. Wire inputs using the tool argument field name directly as the sink_name
(e.g., `query`, NOT `tool_arguments_#_query`). The execution engine
automatically collects top-level fields matching tool_input_schema into
tool_arguments.
6. Output: `result` (the tool's return value) and `error` (error message)
### Example: Simple AI Text Processor
A minimal agent with input, processing, and output:
- Node 1: `AgentInputBlock` (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`,
input_default: {"name": "user_text", "title": "Text to process"},
output: "result")
- Node 2: `AITextGeneratorBlock` (input: "prompt" linked from Node 1's "result")
- Node 3: `AgentOutputBlock` (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`,
input_default: {"name": "summary", "title": "Summary"},
input: "value" linked from Node 2's output)

View File

@@ -8,8 +8,6 @@ SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
"""
from __future__ import annotations
import itertools
import json
import logging
@@ -17,36 +15,23 @@ import os
import shlex
from typing import Any, Callable
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
from backend.copilot.context import (
E2B_WORKDIR,
get_current_sandbox,
get_sdk_cwd,
is_allowed_local_path,
resolve_sandbox_path,
)
logger = logging.getLogger(__name__)
# Lazy imports to break circular dependency with tool_adapter.
def _get_sandbox(): # type: ignore[return]
from .tool_adapter import get_current_sandbox # noqa: E402
def _get_sandbox():
return get_current_sandbox()
def _is_allowed_local(path: str) -> bool:
from .tool_adapter import is_allowed_local_path # noqa: E402
return is_allowed_local_path(path)
def _resolve_remote(path: str) -> str:
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
Raises :class:`ValueError` if the resolved path escapes the sandbox.
"""
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
normalized = os.path.normpath(candidate)
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
return normalized
return is_allowed_local_path(path, get_sdk_cwd())
def _mcp(text: str, *, error: bool = False) -> dict[str, Any]:
@@ -63,7 +48,7 @@ def _get_sandbox_and_path(
if sandbox is None:
return _mcp("No E2B sandbox available", error=True)
try:
remote = _resolve_remote(file_path)
remote = resolve_sandbox_path(file_path)
except ValueError as exc:
return _mcp(str(exc), error=True)
return sandbox, remote
@@ -73,6 +58,7 @@ def _get_sandbox_and_path(
async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
"""Read lines from a sandbox file, falling back to the local host for SDK-internal paths."""
file_path: str = args.get("file_path", "")
offset: int = max(0, int(args.get("offset", 0)))
limit: int = max(1, int(args.get("limit", 2000)))
@@ -104,6 +90,7 @@ async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
"""Write content to a sandbox file, creating parent directories as needed."""
file_path: str = args.get("file_path", "")
content: str = args.get("content", "")
@@ -127,6 +114,7 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
"""Replace a substring in a sandbox file, with optional replace-all support."""
file_path: str = args.get("file_path", "")
old_string: str = args.get("old_string", "")
new_string: str = args.get("new_string", "")
@@ -172,6 +160,7 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
"""Find files matching a name pattern inside the sandbox using ``find``."""
pattern: str = args.get("pattern", "")
path: str = args.get("path", "")
@@ -183,7 +172,7 @@ async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
return _mcp("No E2B sandbox available", error=True)
try:
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
except ValueError as exc:
return _mcp(str(exc), error=True)
@@ -198,6 +187,7 @@ async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
"""Search file contents by regex inside the sandbox using ``grep -rn``."""
pattern: str = args.get("pattern", "")
path: str = args.get("path", "")
include: str = args.get("include", "")
@@ -210,7 +200,7 @@ async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
return _mcp("No E2B sandbox available", error=True)
try:
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
except ValueError as exc:
return _mcp(str(exc), error=True)
@@ -238,7 +228,7 @@ def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
return _mcp(f"Path not allowed: {file_path}", error=True)
expanded = os.path.realpath(os.path.expanduser(file_path))
try:
with open(expanded) as fh:
with open(expanded, encoding="utf-8", errors="replace") as fh:
selected = list(itertools.islice(fh, offset, offset + limit))
numbered = "".join(
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)

View File

@@ -7,59 +7,60 @@ import os
import pytest
from .e2b_file_tools import _read_local, _resolve_remote
from .tool_adapter import _current_project_dir
from backend.copilot.context import _current_project_dir
from .e2b_file_tools import _read_local, resolve_sandbox_path
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# ---------------------------------------------------------------------------
# _resolve_remote — sandbox path normalisation & boundary enforcement
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
# ---------------------------------------------------------------------------
class TestResolveRemote:
class TestResolveSandboxPath:
def test_relative_path_resolved(self):
assert _resolve_remote("src/main.py") == "/home/user/src/main.py"
assert resolve_sandbox_path("src/main.py") == "/home/user/src/main.py"
def test_absolute_within_sandbox(self):
assert _resolve_remote("/home/user/file.txt") == "/home/user/file.txt"
assert resolve_sandbox_path("/home/user/file.txt") == "/home/user/file.txt"
def test_workdir_itself(self):
assert _resolve_remote("/home/user") == "/home/user"
assert resolve_sandbox_path("/home/user") == "/home/user"
def test_relative_dotslash(self):
assert _resolve_remote("./README.md") == "/home/user/README.md"
assert resolve_sandbox_path("./README.md") == "/home/user/README.md"
def test_traversal_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("../../etc/passwd")
resolve_sandbox_path("../../etc/passwd")
def test_absolute_traversal_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/home/user/../../etc/passwd")
resolve_sandbox_path("/home/user/../../etc/passwd")
def test_absolute_outside_sandbox_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/etc/passwd")
resolve_sandbox_path("/etc/passwd")
def test_root_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/")
resolve_sandbox_path("/")
def test_home_other_user_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/home/other/file.txt")
resolve_sandbox_path("/home/other/file.txt")
def test_deep_nested_allowed(self):
assert _resolve_remote("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
assert resolve_sandbox_path("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
def test_trailing_slash_normalised(self):
assert _resolve_remote("src/") == "/home/user/src"
assert resolve_sandbox_path("src/") == "/home/user/src"
def test_double_dots_within_sandbox_ok(self):
"""Path that resolves back within /home/user is allowed."""
assert _resolve_remote("a/b/../c.txt") == "/home/user/a/c.txt"
assert resolve_sandbox_path("a/b/../c.txt") == "/home/user/a/c.txt"
# ---------------------------------------------------------------------------

View File

@@ -0,0 +1,281 @@
"""File reference protocol for tool call inputs.
Allows the LLM to pass a file reference instead of embedding large content
inline. The processor expands ``@@agptfile:<uri>[<start>-<end>]`` tokens in tool
arguments before the tool is executed.
Protocol
--------
@@agptfile:<uri>[<start>-<end>]
``<uri>`` (required)
- ``workspace://<file_id>`` — workspace file by ID
- ``workspace://<file_id>#<mime>`` — same, MIME hint is ignored for reads
- ``workspace:///<path>`` — workspace file by virtual path
- ``/absolute/local/path`` — ephemeral or sdk_cwd file (validated by
:func:`~backend.copilot.sdk.tool_adapter.is_allowed_local_path`)
- Any absolute path that resolves inside the E2B sandbox
(``/home/user/...``) when a sandbox is active
``[<start>-<end>]`` (optional)
Line range, 1-indexed inclusive. Examples: ``[1-100]``, ``[50-200]``.
Omit to read the entire file.
Examples
--------
@@agptfile:workspace://abc123
@@agptfile:workspace://abc123[10-50]
@@agptfile:workspace:///reports/q1.md
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
@@agptfile:/home/user/script.sh
"""
import itertools
import logging
import os
import re
from dataclasses import dataclass
from typing import Any
from backend.copilot.context import (
get_current_sandbox,
get_sdk_cwd,
is_allowed_local_path,
resolve_sandbox_path,
)
from backend.copilot.model import ChatSession
from backend.copilot.tools.workspace_files import get_manager
from backend.util.file import parse_workspace_uri
class FileRefExpansionError(Exception):
"""Raised when a ``@@agptfile:`` reference in tool call args fails to resolve.
Separating this from inline substitution lets callers (e.g. the MCP tool
wrapper) block tool execution and surface a helpful error to the model
rather than passing an ``[file-ref error: …]`` string as actual input.
"""
logger = logging.getLogger(__name__)
FILE_REF_PREFIX = "@@agptfile:"
# Matches: @@agptfile:<uri>[start-end]?
# Group 1 URI; must start with '/' (absolute path) or 'workspace://'
# Group 2 start line (optional)
# Group 3 end line (optional)
_FILE_REF_RE = re.compile(
re.escape(FILE_REF_PREFIX) + r"((?:workspace://|/)[^\[\s]*)(?:\[(\d+)-(\d+)\])?"
)
# Maximum characters returned for a single file reference expansion.
_MAX_EXPAND_CHARS = 200_000
# Maximum total characters across all @@agptfile: expansions in one string.
_MAX_TOTAL_EXPAND_CHARS = 1_000_000
@dataclass
class FileRef:
uri: str
start_line: int | None # 1-indexed, inclusive
end_line: int | None # 1-indexed, inclusive
def parse_file_ref(text: str) -> FileRef | None:
"""Return a :class:`FileRef` if *text* is a bare file reference token.
A "bare token" means the entire string matches the ``@@agptfile:...`` pattern
(after stripping whitespace). Use :func:`expand_file_refs_in_string` to
expand references embedded in larger strings.
"""
m = _FILE_REF_RE.fullmatch(text.strip())
if not m:
return None
start = int(m.group(2)) if m.group(2) else None
end = int(m.group(3)) if m.group(3) else None
if start is not None and start < 1:
return None
if end is not None and end < 1:
return None
if start is not None and end is not None and end < start:
return None
return FileRef(uri=m.group(1), start_line=start, end_line=end)
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
"""Slice *text* to the requested 1-indexed line range (inclusive)."""
if start is None and end is None:
return text
lines = text.splitlines(keepends=True)
s = (start - 1) if start is not None else 0
e = end if end is not None else len(lines)
selected = list(itertools.islice(lines, s, e))
return "".join(selected)
async def read_file_bytes(
uri: str,
user_id: str | None,
session: ChatSession,
) -> bytes:
"""Resolve *uri* to raw bytes using workspace, local, or E2B path logic.
Raises :class:`ValueError` if the URI cannot be resolved.
"""
# Strip MIME fragment (e.g. workspace://id#mime) before dispatching.
plain = uri.split("#")[0] if uri.startswith("workspace://") else uri
if plain.startswith("workspace://"):
if not user_id:
raise ValueError("workspace:// file references require authentication")
manager = await get_manager(user_id, session.session_id)
ws = parse_workspace_uri(plain)
try:
return await (
manager.read_file(ws.file_ref)
if ws.is_path
else manager.read_file_by_id(ws.file_ref)
)
except FileNotFoundError:
raise ValueError(f"File not found: {plain}")
except Exception as exc:
raise ValueError(f"Failed to read {plain}: {exc}") from exc
if is_allowed_local_path(plain, get_sdk_cwd()):
resolved = os.path.realpath(os.path.expanduser(plain))
try:
with open(resolved, "rb") as fh:
return fh.read()
except FileNotFoundError:
raise ValueError(f"File not found: {plain}")
except Exception as exc:
raise ValueError(f"Failed to read {plain}: {exc}") from exc
sandbox = get_current_sandbox()
if sandbox is not None:
try:
remote = resolve_sandbox_path(plain)
except ValueError as exc:
raise ValueError(
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
) from exc
try:
return bytes(await sandbox.files.read(remote, format="bytes"))
except Exception as exc:
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
raise ValueError(
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
)
async def resolve_file_ref(
ref: FileRef,
user_id: str | None,
session: ChatSession,
) -> str:
"""Resolve a :class:`FileRef` to its text content."""
raw = await read_file_bytes(ref.uri, user_id, session)
return _apply_line_range(
raw.decode("utf-8", errors="replace"), ref.start_line, ref.end_line
)
async def expand_file_refs_in_string(
text: str,
user_id: str | None,
session: "ChatSession",
*,
raise_on_error: bool = False,
) -> str:
"""Expand all ``@@agptfile:...`` tokens in *text*, returning the substituted string.
Non-reference text is passed through unchanged.
If *raise_on_error* is ``False`` (default), expansion errors are surfaced
inline as ``[file-ref error: <message>]`` — useful for display/log contexts
where partial expansion is acceptable.
If *raise_on_error* is ``True``, any resolution failure raises
:class:`FileRefExpansionError` immediately so the caller can block the
operation and surface a clean error to the model.
"""
if FILE_REF_PREFIX not in text:
return text
result: list[str] = []
last_end = 0
total_chars = 0
for m in _FILE_REF_RE.finditer(text):
result.append(text[last_end : m.start()])
start = int(m.group(2)) if m.group(2) else None
end = int(m.group(3)) if m.group(3) else None
if (start is not None and start < 1) or (end is not None and end < 1):
msg = f"line numbers must be >= 1: {m.group(0)}"
if raise_on_error:
raise FileRefExpansionError(msg)
result.append(f"[file-ref error: {msg}]")
last_end = m.end()
continue
if start is not None and end is not None and end < start:
msg = f"end line must be >= start line: {m.group(0)}"
if raise_on_error:
raise FileRefExpansionError(msg)
result.append(f"[file-ref error: {msg}]")
last_end = m.end()
continue
ref = FileRef(uri=m.group(1), start_line=start, end_line=end)
try:
content = await resolve_file_ref(ref, user_id, session)
if len(content) > _MAX_EXPAND_CHARS:
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
if remaining <= 0:
content = "[file-ref budget exhausted: total expansion limit reached]"
elif len(content) > remaining:
content = content[:remaining] + "\n... [total budget exhausted]"
total_chars += len(content)
result.append(content)
except ValueError as exc:
logger.warning("file-ref expansion failed for %r: %s", m.group(0), exc)
if raise_on_error:
raise FileRefExpansionError(str(exc)) from exc
result.append(f"[file-ref error: {exc}]")
last_end = m.end()
result.append(text[last_end:])
return "".join(result)
async def expand_file_refs_in_args(
args: dict[str, Any],
user_id: str | None,
session: "ChatSession",
) -> dict[str, Any]:
"""Recursively expand ``@@agptfile:...`` references in tool call arguments.
String values are expanded in-place. Nested dicts and lists are
traversed. Non-string scalars are returned unchanged.
Raises :class:`FileRefExpansionError` if any reference fails to resolve,
so the tool is *not* executed with an error string as its input. The
caller (the MCP tool wrapper) should convert this into an MCP error
response that lets the model correct the reference before retrying.
"""
if not args:
return args
async def _expand(value: Any) -> Any:
if isinstance(value, str):
return await expand_file_refs_in_string(
value, user_id, session, raise_on_error=True
)
if isinstance(value, dict):
return {k: await _expand(v) for k, v in value.items()}
if isinstance(value, list):
return [await _expand(item) for item in value]
return value
return {k: await _expand(v) for k, v in args.items()}

View File

@@ -0,0 +1,328 @@
"""Integration tests for @@agptfile: reference expansion in tool calls.
These tests verify the end-to-end behaviour of the file reference protocol:
- Parsing @@agptfile: tokens from tool arguments
- Resolving local-filesystem paths (sdk_cwd / ephemeral)
- Expanding references inside the tool-call pipeline (_execute_tool_sync)
- The extended Read tool handler (workspace:// pass-through via session context)
No real LLM or database is required; workspace reads are stubbed where needed.
"""
from __future__ import annotations
import os
import tempfile
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.sdk.file_ref import (
FileRef,
expand_file_refs_in_args,
expand_file_refs_in_string,
read_file_bytes,
resolve_file_ref,
)
from backend.copilot.sdk.tool_adapter import _read_file_handler
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_session(session_id: str = "integ-sess") -> MagicMock:
s = MagicMock()
s.session_id = session_id
return s
# ---------------------------------------------------------------------------
# Local-file resolution (sdk_cwd)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_resolve_file_ref_local_path():
"""resolve_file_ref reads a real local file when it's within sdk_cwd."""
with tempfile.TemporaryDirectory() as sdk_cwd:
# Write a test file inside sdk_cwd
test_file = os.path.join(sdk_cwd, "hello.txt")
with open(test_file, "w") as f:
f.write("line1\nline2\nline3\n")
session = _make_session()
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
ref = FileRef(uri=test_file, start_line=None, end_line=None)
content = await resolve_file_ref(ref, user_id="u1", session=session)
assert content == "line1\nline2\nline3\n"
@pytest.mark.asyncio
async def test_resolve_file_ref_local_path_with_line_range():
"""resolve_file_ref respects line ranges for local files."""
with tempfile.TemporaryDirectory() as sdk_cwd:
test_file = os.path.join(sdk_cwd, "multi.txt")
lines = [f"line{i}\n" for i in range(1, 11)] # line1 … line10
with open(test_file, "w") as f:
f.writelines(lines)
session = _make_session()
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
ref = FileRef(uri=test_file, start_line=3, end_line=5)
content = await resolve_file_ref(ref, user_id="u1", session=session)
assert content == "line3\nline4\nline5\n"
@pytest.mark.asyncio
async def test_resolve_file_ref_rejects_path_outside_sdk_cwd():
"""resolve_file_ref raises ValueError for paths outside sdk_cwd."""
with tempfile.TemporaryDirectory() as sdk_cwd:
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var:
mock_cwd_var.get.return_value = sdk_cwd
mock_sandbox_var.get.return_value = None
ref = FileRef(uri="/etc/passwd", start_line=None, end_line=None)
with pytest.raises(ValueError, match="not allowed"):
await resolve_file_ref(ref, user_id="u1", session=_make_session())
# ---------------------------------------------------------------------------
# expand_file_refs_in_string — integration with real files
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_string_with_real_file():
"""expand_file_refs_in_string replaces @@agptfile: token with actual content."""
with tempfile.TemporaryDirectory() as sdk_cwd:
test_file = os.path.join(sdk_cwd, "data.txt")
with open(test_file, "w") as f:
f.write("hello world\n")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_string(
f"Content: @@agptfile:{test_file}",
user_id="u1",
session=_make_session(),
)
assert result == "Content: hello world\n"
@pytest.mark.asyncio
async def test_expand_string_missing_file_is_surfaced_inline():
"""Missing file ref yields [file-ref error: …] inline rather than raising."""
with tempfile.TemporaryDirectory() as sdk_cwd:
missing = os.path.join(sdk_cwd, "does_not_exist.txt")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_string(
f"@@agptfile:{missing}",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "not found" in result.lower() or "not allowed" in result.lower()
# ---------------------------------------------------------------------------
# expand_file_refs_in_args — dict traversal with real files
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_args_replaces_file_ref_in_nested_dict():
"""Nested @@agptfile: references in args are fully expanded."""
with tempfile.TemporaryDirectory() as sdk_cwd:
file_a = os.path.join(sdk_cwd, "a.txt")
file_b = os.path.join(sdk_cwd, "b.txt")
with open(file_a, "w") as f:
f.write("AAA")
with open(file_b, "w") as f:
f.write("BBB")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{
"outer": {
"content_a": f"@@agptfile:{file_a}",
"content_b": f"start @@agptfile:{file_b} end",
},
"count": 42,
},
user_id="u1",
session=_make_session(),
)
assert result["outer"]["content_a"] == "AAA"
assert result["outer"]["content_b"] == "start BBB end"
assert result["count"] == 42
# ---------------------------------------------------------------------------
# _read_file_handler — extended to accept workspace:// and local paths
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_file_handler_local_file():
"""_read_file_handler reads a local file when it's within sdk_cwd."""
with tempfile.TemporaryDirectory() as sdk_cwd:
test_file = os.path.join(sdk_cwd, "read_test.txt")
lines = [f"L{i}\n" for i in range(1, 6)]
with open(test_file, "w") as f:
f.writelines(lines)
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj_var, patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", _make_session()),
):
mock_cwd_var.get.return_value = sdk_cwd
mock_proj_var.get.return_value = ""
result = await _read_file_handler(
{"file_path": test_file, "offset": 0, "limit": 5}
)
assert not result["isError"]
text = result["content"][0]["text"]
assert "L1" in text
assert "L5" in text
@pytest.mark.asyncio
async def test_read_file_handler_workspace_uri():
"""_read_file_handler handles workspace:// URIs via the workspace manager."""
mock_session = _make_session()
mock_manager = AsyncMock()
mock_manager.read_file_by_id.return_value = b"workspace file content\nline two\n"
with patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", mock_session),
), patch(
"backend.copilot.sdk.file_ref.get_manager",
new=AsyncMock(return_value=mock_manager),
):
result = await _read_file_handler(
{"file_path": "workspace://file-id-abc", "offset": 0, "limit": 10}
)
assert not result["isError"], result["content"][0]["text"]
text = result["content"][0]["text"]
assert "workspace file content" in text
assert "line two" in text
@pytest.mark.asyncio
async def test_read_file_handler_workspace_uri_no_session():
"""_read_file_handler returns error when workspace:// is used without session."""
with patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=(None, None),
):
result = await _read_file_handler({"file_path": "workspace://some-id"})
assert result["isError"]
assert "session" in result["content"][0]["text"].lower()
@pytest.mark.asyncio
async def test_read_file_handler_access_denied():
"""_read_file_handler rejects paths outside allowed locations."""
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox, patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", _make_session()),
):
mock_cwd.get.return_value = "/tmp/safe-dir"
mock_sandbox.get.return_value = None
result = await _read_file_handler({"file_path": "/etc/passwd"})
assert result["isError"]
assert "not allowed" in result["content"][0]["text"].lower()
# ---------------------------------------------------------------------------
# read_file_bytes — workspace:///path (virtual path) and E2B sandbox branch
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_file_bytes_workspace_virtual_path():
"""workspace:///path resolves via manager.read_file (is_path=True path)."""
session = _make_session()
mock_manager = AsyncMock()
mock_manager.read_file.return_value = b"virtual path content"
with patch(
"backend.copilot.sdk.file_ref.get_manager",
new=AsyncMock(return_value=mock_manager),
):
result = await read_file_bytes("workspace:///reports/q1.md", "user-1", session)
assert result == b"virtual path content"
mock_manager.read_file.assert_awaited_once_with("/reports/q1.md")
@pytest.mark.asyncio
async def test_read_file_bytes_e2b_sandbox_branch():
"""read_file_bytes reads from the E2B sandbox when a sandbox is active."""
session = _make_session()
mock_sandbox = AsyncMock()
mock_sandbox.files.read.return_value = bytearray(b"sandbox content")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj:
mock_cwd.get.return_value = ""
mock_sandbox_var.get.return_value = mock_sandbox
mock_proj.get.return_value = ""
result = await read_file_bytes("/home/user/script.sh", None, session)
assert result == b"sandbox content"
mock_sandbox.files.read.assert_awaited_once_with(
"/home/user/script.sh", format="bytes"
)
@pytest.mark.asyncio
async def test_read_file_bytes_e2b_path_escapes_sandbox_raises():
"""read_file_bytes raises ValueError for paths that escape the sandbox root."""
session = _make_session()
mock_sandbox = AsyncMock()
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj:
mock_cwd.get.return_value = ""
mock_sandbox_var.get.return_value = mock_sandbox
mock_proj.get.return_value = ""
with pytest.raises(ValueError, match="not allowed"):
await read_file_bytes("/etc/passwd", None, session)

View File

@@ -0,0 +1,382 @@
"""Tests for the @@agptfile: reference protocol (file_ref.py)."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.sdk.file_ref import (
_MAX_EXPAND_CHARS,
FileRef,
FileRefExpansionError,
_apply_line_range,
expand_file_refs_in_args,
expand_file_refs_in_string,
parse_file_ref,
)
# ---------------------------------------------------------------------------
# parse_file_ref
# ---------------------------------------------------------------------------
def test_parse_file_ref_workspace_id():
ref = parse_file_ref("@@agptfile:workspace://abc123")
assert ref == FileRef(uri="workspace://abc123", start_line=None, end_line=None)
def test_parse_file_ref_workspace_id_with_mime():
ref = parse_file_ref("@@agptfile:workspace://abc123#text/plain")
assert ref is not None
assert ref.uri == "workspace://abc123#text/plain"
assert ref.start_line is None
def test_parse_file_ref_workspace_path():
ref = parse_file_ref("@@agptfile:workspace:///reports/q1.md")
assert ref is not None
assert ref.uri == "workspace:///reports/q1.md"
def test_parse_file_ref_with_line_range():
ref = parse_file_ref("@@agptfile:workspace://abc123[10-50]")
assert ref == FileRef(uri="workspace://abc123", start_line=10, end_line=50)
def test_parse_file_ref_local_path():
ref = parse_file_ref("@@agptfile:/tmp/copilot-session/output.py[1-100]")
assert ref is not None
assert ref.uri == "/tmp/copilot-session/output.py"
assert ref.start_line == 1
assert ref.end_line == 100
def test_parse_file_ref_no_match():
assert parse_file_ref("just a normal string") is None
assert parse_file_ref("workspace://abc123") is None # missing @@agptfile: prefix
assert (
parse_file_ref("@@agptfile:workspace://abc123 extra") is None
) # not full match
def test_parse_file_ref_strips_whitespace():
ref = parse_file_ref(" @@agptfile:workspace://abc123 ")
assert ref is not None
assert ref.uri == "workspace://abc123"
def test_parse_file_ref_invalid_range_zero_start():
assert parse_file_ref("@@agptfile:workspace://abc123[0-5]") is None
def test_parse_file_ref_invalid_range_end_less_than_start():
assert parse_file_ref("@@agptfile:workspace://abc123[10-5]") is None
def test_parse_file_ref_invalid_range_zero_end():
assert parse_file_ref("@@agptfile:workspace://abc123[1-0]") is None
# ---------------------------------------------------------------------------
# _apply_line_range
# ---------------------------------------------------------------------------
TEXT = "line1\nline2\nline3\nline4\nline5\n"
def test_apply_line_range_no_range():
assert _apply_line_range(TEXT, None, None) == TEXT
def test_apply_line_range_start_only():
result = _apply_line_range(TEXT, 3, None)
assert result == "line3\nline4\nline5\n"
def test_apply_line_range_full():
result = _apply_line_range(TEXT, 2, 4)
assert result == "line2\nline3\nline4\n"
def test_apply_line_range_single_line():
result = _apply_line_range(TEXT, 2, 2)
assert result == "line2\n"
def test_apply_line_range_beyond_eof():
result = _apply_line_range(TEXT, 4, 999)
assert result == "line4\nline5\n"
# ---------------------------------------------------------------------------
# expand_file_refs_in_string
# ---------------------------------------------------------------------------
def _make_session(session_id: str = "sess-1") -> MagicMock:
session = MagicMock()
session.session_id = session_id
return session
async def _resolve_always(ref: FileRef, _user_id: str | None, _session: object) -> str:
"""Stub resolver that returns the URI and range as a descriptive string."""
if ref.start_line is not None:
return f"content:{ref.uri}[{ref.start_line}-{ref.end_line}]"
return f"content:{ref.uri}"
@pytest.mark.asyncio
async def test_expand_no_refs():
result = await expand_file_refs_in_string(
"no references here", user_id="u1", session=_make_session()
)
assert result == "no references here"
@pytest.mark.asyncio
async def test_expand_single_ref():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://abc123",
user_id="u1",
session=_make_session(),
)
assert result == "content:workspace://abc123"
@pytest.mark.asyncio
async def test_expand_ref_with_range():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://abc123[10-50]",
user_id="u1",
session=_make_session(),
)
assert result == "content:workspace://abc123[10-50]"
@pytest.mark.asyncio
async def test_expand_ref_embedded_in_text():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_string(
"Here is the file: @@agptfile:workspace://abc123 — done",
user_id="u1",
session=_make_session(),
)
assert result == "Here is the file: content:workspace://abc123 — done"
@pytest.mark.asyncio
async def test_expand_multiple_refs():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://file1 and @@agptfile:workspace://file2[1-5]",
user_id="u1",
session=_make_session(),
)
assert result == "content:workspace://file1 and content:workspace://file2[1-5]"
@pytest.mark.asyncio
async def test_expand_invalid_range_zero_start_surfaces_inline():
"""expand_file_refs_in_string surfaces [file-ref error: ...] for zero-start ranges."""
result = await expand_file_refs_in_string(
"@@agptfile:workspace://abc123[0-5]",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "line numbers must be >= 1" in result
@pytest.mark.asyncio
async def test_expand_invalid_range_end_less_than_start_surfaces_inline():
"""expand_file_refs_in_string surfaces [file-ref error: ...] when end < start."""
result = await expand_file_refs_in_string(
"prefix @@agptfile:workspace://abc123[10-5] suffix",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "end line must be >= start line" in result
assert "prefix" in result
assert "suffix" in result
@pytest.mark.asyncio
async def test_expand_ref_error_surfaces_inline():
async def _raise(*args, **kwargs): # noqa: ARG001
raise ValueError("file not found")
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_raise),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://bad",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "file not found" in result
# ---------------------------------------------------------------------------
# expand_file_refs_in_args
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_args_flat():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_args(
{"content": "@@agptfile:workspace://abc123", "other": 42},
user_id="u1",
session=_make_session(),
)
assert result["content"] == "content:workspace://abc123"
assert result["other"] == 42
@pytest.mark.asyncio
async def test_expand_args_nested_dict():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_args(
{"outer": {"inner": "@@agptfile:workspace://nested"}},
user_id="u1",
session=_make_session(),
)
assert result["outer"]["inner"] == "content:workspace://nested"
@pytest.mark.asyncio
async def test_expand_args_list():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_args(
{
"items": [
"@@agptfile:workspace://a",
"plain",
"@@agptfile:workspace://b[1-3]",
]
},
user_id="u1",
session=_make_session(),
)
assert result["items"] == [
"content:workspace://a",
"plain",
"content:workspace://b[1-3]",
]
@pytest.mark.asyncio
async def test_expand_args_empty():
result = await expand_file_refs_in_args({}, user_id="u1", session=_make_session())
assert result == {}
@pytest.mark.asyncio
async def test_expand_args_no_refs():
result = await expand_file_refs_in_args(
{"key": "no refs here", "num": 1},
user_id="u1",
session=_make_session(),
)
assert result == {"key": "no refs here", "num": 1}
@pytest.mark.asyncio
async def test_expand_args_raises_on_file_ref_error():
"""expand_file_refs_in_args raises FileRefExpansionError instead of passing
the inline error string to the tool, blocking tool execution."""
async def _raise(*args, **kwargs): # noqa: ARG001
raise ValueError("path does not exist")
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_raise),
):
with pytest.raises(FileRefExpansionError) as exc_info:
await expand_file_refs_in_args(
{"prompt": "@@agptfile:/home/user/missing.txt"},
user_id="u1",
session=_make_session(),
)
assert "path does not exist" in str(exc_info.value)
# ---------------------------------------------------------------------------
# Per-file truncation and aggregate budget
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_per_file_truncation():
"""Content exceeding _MAX_EXPAND_CHARS is truncated with a marker."""
oversized = "x" * (_MAX_EXPAND_CHARS + 100)
async def _resolve_oversized(ref: FileRef, _uid: str | None, _s: object) -> str:
return oversized
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_oversized),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://big-file",
user_id="u1",
session=_make_session(),
)
assert len(result) <= _MAX_EXPAND_CHARS + len("\n... [truncated]") + 10
assert "[truncated]" in result
@pytest.mark.asyncio
async def test_expand_aggregate_budget_exhausted():
"""When the aggregate budget is exhausted, later refs get the budget message."""
# Each file returns just under 300K; after ~4 files the 1M budget is used.
big_chunk = "y" * 300_000
async def _resolve_big(ref: FileRef, _uid: str | None, _s: object) -> str:
return big_chunk
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_big),
):
# 5 refs @ 300K each = 1.5M → last ref(s) should hit the aggregate limit
refs = " ".join(f"@@agptfile:workspace://f{i}" for i in range(5))
result = await expand_file_refs_in_string(
refs,
user_id="u1",
session=_make_session(),
)
assert "budget exhausted" in result

View File

@@ -0,0 +1,28 @@
## MCP Tool Guide
### Workflow
`run_mcp_tool` follows a two-step pattern:
1. **Discover** — call with only `server_url` to list available tools on the server.
2. **Execute** — call again with `server_url`, `tool_name`, and `tool_arguments` to run a tool.
### Known hosted MCP servers
Use these URLs directly without asking the user:
| Service | URL |
|---|---|
| Notion | `https://mcp.notion.com/mcp` |
| Linear | `https://mcp.linear.app/mcp` |
| Stripe | `https://mcp.stripe.com` |
| Intercom | `https://mcp.intercom.com/mcp` |
| Cloudflare | `https://mcp.cloudflare.com/mcp` |
| Atlassian / Jira | `https://mcp.atlassian.com/mcp` |
For other services, search the MCP registry at https://registry.modelcontextprotocol.io/.
### Authentication
If the server requires credentials, a `SetupRequirementsResponse` is returned with an OAuth
login prompt. Once the user completes the flow and confirms, retry the same call immediately.

View File

@@ -536,10 +536,12 @@ async def test_wait_for_stash_signaled():
result = await wait_for_stash(timeout=1.0)
assert result is True
assert _pto.get({}).get("WebSearch") == ["result data"]
pto = _pto.get()
assert pto is not None
assert pto.get("WebSearch") == ["result data"]
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_pto.set({})
_stash_event.set(None)
@@ -554,7 +556,7 @@ async def test_wait_for_stash_timeout():
assert result is False
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_pto.set({})
_stash_event.set(None)
@@ -573,10 +575,12 @@ async def test_wait_for_stash_already_stashed():
assert result is True
# But the stash itself is populated
assert _pto.get({}).get("Read") == ["file contents"]
pto = _pto.get()
assert pto is not None
assert pto.get("Read") == ["file contents"]
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_pto.set({})
_stash_event.set(None)

View File

@@ -10,12 +10,13 @@ import re
from collections.abc import Callable
from typing import Any, cast
from backend.copilot.context import is_allowed_local_path
from .tool_adapter import (
BLOCKED_TOOLS,
DANGEROUS_PATTERNS,
MCP_TOOL_PREFIX,
WORKSPACE_SCOPED_TOOLS,
is_allowed_local_path,
stash_pending_tool_output,
)

View File

@@ -9,8 +9,9 @@ import os
import pytest
from backend.copilot.context import _current_project_dir
from .security_hooks import _validate_tool_access, _validate_user_isolation
from .service import _is_tool_error_or_denial
SDK_CWD = "/tmp/copilot-abc123"
@@ -120,8 +121,6 @@ def test_read_no_cwd_denies_absolute():
def test_read_tool_results_allowed():
from .tool_adapter import _current_project_dir
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
# is_allowed_local_path requires the session's encoded cwd to be set
@@ -133,16 +132,14 @@ def test_read_tool_results_allowed():
_current_project_dir.reset(token)
def test_read_claude_projects_session_dir_allowed():
"""Files within the current session's project dir are allowed."""
from .tool_adapter import _current_project_dir
def test_read_claude_projects_settings_json_denied():
"""SDK-internal artifacts like settings.json are NOT accessible — only tool-results/ is."""
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
token = _current_project_dir.set("-tmp-copilot-abc123")
try:
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert not _is_denied(result)
assert _is_denied(result)
finally:
_current_project_dir.reset(token)
@@ -357,76 +354,3 @@ async def test_task_slot_released_on_failure(_hooks):
context={},
)
assert not _is_denied(result)
# -- _is_tool_error_or_denial ------------------------------------------------
class TestIsToolErrorOrDenial:
def test_none_content(self):
assert _is_tool_error_or_denial(None) is False
def test_empty_content(self):
assert _is_tool_error_or_denial("") is False
def test_benign_output(self):
assert _is_tool_error_or_denial("All good, no issues.") is False
def test_security_marker(self):
assert _is_tool_error_or_denial("[SECURITY] Tool access blocked") is True
def test_cannot_be_bypassed(self):
assert _is_tool_error_or_denial("This restriction cannot be bypassed.") is True
def test_not_allowed(self):
assert _is_tool_error_or_denial("Operation not allowed in sandbox") is True
def test_background_task_denial(self):
assert (
_is_tool_error_or_denial(
"Background task execution is not supported. "
"Run tasks in the foreground instead."
)
is True
)
def test_subtask_limit_denial(self):
assert (
_is_tool_error_or_denial(
"Maximum 2 concurrent sub-tasks. "
"Wait for running sub-tasks to finish, "
"or continue in the main conversation."
)
is True
)
def test_denied_marker(self):
assert (
_is_tool_error_or_denial("Access denied: insufficient privileges") is True
)
def test_blocked_marker(self):
assert _is_tool_error_or_denial("Request blocked by security policy") is True
def test_failed_marker(self):
assert _is_tool_error_or_denial("Failed to execute tool: timeout") is True
def test_mcp_iserror(self):
assert _is_tool_error_or_denial('{"isError": true, "content": []}') is True
def test_benign_error_in_value(self):
"""Content like '0 errors found' should not trigger — 'error' was removed."""
assert _is_tool_error_or_denial("0 errors found") is False
def test_benign_permission_field(self):
"""Schema descriptions mentioning 'permission' should not trigger."""
assert (
_is_tool_error_or_denial(
'{"fields": [{"name": "permission_level", "type": "int"}]}'
)
is False
)
def test_benign_not_found_in_listing(self):
"""File listing containing 'not found' in filenames should not trigger."""
assert _is_tool_error_or_denial("readme.md\nfile-not-found-handler.py") is False

View File

@@ -60,7 +60,7 @@ from ..service import (
_generate_session_title,
_is_langfuse_configured,
)
from ..tools.e2b_sandbox import get_or_create_sandbox
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tools.workspace_files import get_manager
from ..tracking import track_user_message
@@ -456,31 +456,6 @@ def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
def _is_tool_error_or_denial(content: str | None) -> bool:
"""Check if a tool message content indicates an error or denial.
Currently unused — ``_format_conversation_context`` includes all tool
results. Kept as a utility for future selective filtering.
"""
if not content:
return False
lower = content.lower()
return any(
marker in lower
for marker in (
"[security]",
"cannot be bypassed",
"not allowed",
"not supported", # background-task denial
"maximum", # subtask-limit denial
"denied",
"blocked",
"failed to", # internal tool execution failures
'"iserror": true', # MCP protocol error flag
)
)
async def _build_query_message(
current_message: str,
session: ChatSession,
@@ -784,28 +759,29 @@ async def stream_chat_completion_sdk(
async def _setup_e2b():
"""Set up E2B sandbox if configured, return sandbox or None."""
if config.use_e2b_sandbox and not config.e2b_api_key:
logger.warning(
"[E2B] [%s] E2B sandbox enabled but no API key configured "
"(CHAT_E2B_API_KEY / E2B_API_KEY) — falling back to bubblewrap",
session_id[:12],
)
return None
if config.use_e2b_sandbox and config.e2b_api_key:
try:
return await get_or_create_sandbox(
session_id,
api_key=config.e2b_api_key,
template=config.e2b_sandbox_template,
timeout=config.e2b_sandbox_timeout,
)
except Exception as e2b_err:
logger.error(
"[E2B] [%s] Setup failed: %s",
if not (e2b_api_key := config.active_e2b_api_key):
if config.use_e2b_sandbox:
logger.warning(
"[E2B] [%s] E2B sandbox enabled but no API key configured "
"(CHAT_E2B_API_KEY / E2B_API_KEY) — falling back to bubblewrap",
session_id[:12],
e2b_err,
exc_info=True,
)
return None
try:
return await get_or_create_sandbox(
session_id,
api_key=e2b_api_key,
template=config.e2b_sandbox_template,
timeout=config.e2b_sandbox_timeout,
on_timeout=config.e2b_sandbox_on_timeout,
)
except Exception as e2b_err:
logger.error(
"[E2B] [%s] Setup failed: %s",
session_id[:12],
e2b_err,
exc_info=True,
)
return None
async def _fetch_transcript():
@@ -837,7 +813,6 @@ async def stream_chat_completion_sdk(
system_prompt = base_system_prompt + get_sdk_supplement(
use_e2b=use_e2b, cwd=sdk_cwd
)
# Process transcript download result
transcript_msg_count = 0
if dl:
@@ -902,6 +877,11 @@ async def stream_chat_completion_sdk(
allowed = get_copilot_tool_names(use_e2b=use_e2b)
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
def _on_stderr(line: str) -> None:
sid = session_id[:12] if session_id else "?"
logger.info("[SDK] [%s] CLI stderr: %s", sid, line.rstrip())
sdk_options_kwargs: dict[str, Any] = {
"system_prompt": system_prompt,
"mcp_servers": {"copilot": mcp_server},
@@ -910,6 +890,7 @@ async def stream_chat_completion_sdk(
"hooks": security_hooks,
"cwd": sdk_cwd,
"max_buffer_size": config.claude_agent_max_buffer_size,
"stderr": _on_stderr,
}
if sdk_model:
sdk_options_kwargs["model"] = sdk_model
@@ -1082,6 +1063,19 @@ async def stream_chat_completion_sdk(
len(adapter.resolved_tool_calls),
)
# Log AssistantMessage API errors (e.g. invalid_request)
# so we can debug Anthropic API 400s surfaced by the CLI.
sdk_error = getattr(sdk_msg, "error", None)
if isinstance(sdk_msg, AssistantMessage) and sdk_error:
logger.error(
"[SDK] [%s] AssistantMessage has error=%s, "
"content_blocks=%d, content_preview=%s",
session_id[:12],
sdk_error,
len(sdk_msg.content),
str(sdk_msg.content)[:500],
)
# Race-condition fix: SDK hooks (PostToolUse) are
# executed asynchronously via start_soon() — the next
# message can arrive before the hook stashes output.
@@ -1416,6 +1410,17 @@ async def stream_chat_completion_sdk(
exc_info=True,
)
# --- Pause E2B sandbox to stop billing between turns ---
# Fire-and-forget: pausing is best-effort and must not block the
# response or the transcript upload. The task is anchored to
# _background_tasks to prevent garbage collection.
# Use pause_sandbox_direct to skip the Redis lookup and reconnect
# round-trip — e2b_sandbox is the live object from this turn.
if e2b_sandbox is not None:
task = asyncio.create_task(pause_sandbox_direct(e2b_sandbox, session_id))
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
# --- Upload transcript for next-turn --resume ---
# This MUST run in finally so the transcript is uploaded even when
# the streaming loop raises an exception.

View File

@@ -1,9 +1,10 @@
"""Tests for SDK service helpers."""
import asyncio
import base64
import os
from dataclasses import dataclass
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -212,7 +213,7 @@ class TestPromptSupplement:
# Workflows are now in individual tool descriptions (not separate sections)
# Check that key workflow concepts appear in tool descriptions
assert "suggested_goal" in docs or "clarifying_questions" in docs
assert "agent_json" in docs or "find_block" in docs
assert "run_mcp_tool" in docs
def test_baseline_supplement_completeness(self):
@@ -231,6 +232,48 @@ class TestPromptSupplement:
f"`{tool_name}`" in docs
), f"Tool '{tool_name}' missing from baseline supplement"
def test_pause_task_scheduled_before_transcript_upload(self):
"""Pause is scheduled as a background task before transcript upload begins.
The finally block in stream_response_sdk does:
(1) asyncio.create_task(pause_sandbox_direct(...)) — fire-and-forget
(2) await asyncio.shield(upload_transcript(...)) — awaited
Scheduling pause via create_task before awaiting upload ensures:
- Pause never blocks transcript upload (billing stops concurrently)
- On E2B timeout, pause silently fails; upload proceeds unaffected
"""
call_order: list[str] = []
async def _mock_pause(sandbox, session_id):
call_order.append("pause")
async def _mock_upload(**kwargs):
call_order.append("upload")
async def _simulate_teardown():
"""Mirror the service.py finally block teardown sequence."""
sandbox = MagicMock()
# (1) Schedule pause — mirrors lines ~1427-1429 in service.py
task = asyncio.create_task(_mock_pause(sandbox, "test-sess"))
# (2) Await transcript upload — mirrors lines ~1460-1468 in service.py
# Yielding to the event loop here lets the pause task start concurrently.
await _mock_upload(
user_id="u", session_id="test-sess", content="x", message_count=1
)
await task
asyncio.run(_simulate_teardown())
# Both must run; pause is scheduled before upload starts
assert "pause" in call_order
assert "upload" in call_order
# create_task schedules pause, then upload is awaited — pause runs
# concurrently during upload's first yield. The ordering guarantee is
# that create_task is CALLED before upload is AWAITED (see source order).
def test_baseline_supplement_no_duplicate_tools(self):
"""No tool should appear multiple times in baseline supplement."""
from backend.copilot.prompting import get_baseline_supplement

View File

@@ -9,14 +9,29 @@ import itertools
import json
import logging
import os
import re
import uuid
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any
from claude_agent_sdk import create_sdk_mcp_server, tool
from backend.copilot.context import (
_current_project_dir,
_current_sandbox,
_current_sdk_cwd,
_current_session,
_current_user_id,
_encode_cwd_for_cli,
get_execution_context,
get_sdk_cwd,
is_allowed_local_path,
)
from backend.copilot.model import ChatSession
from backend.copilot.sdk.file_ref import (
FileRefExpansionError,
expand_file_refs_in_args,
read_file_bytes,
)
from backend.copilot.tools import TOOL_REGISTRY
from backend.copilot.tools.base import BaseTool
from backend.util.truncate import truncate
@@ -28,84 +43,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
# in the path — prevents reading settings, credentials, or other sensitive files.
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# Max MCP response size in chars — keeps tool output under the SDK's 10 MB JSON buffer.
_MCP_MAX_CHARS = 500_000
# Context variable holding the encoded project directory name for the current
# session (e.g. "-private-tmp-copilot-<uuid>"). Set by set_execution_context()
# so that path validation can scope tool-results reads to the current session.
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does.
The CLI replaces all non-alphanumeric characters with ``-``.
"""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
"""Check whether *path* is an allowed host-filesystem path.
Allowed:
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
- Files under ``~/.claude/projects/<encoded-cwd>/`` — the SDK's
project directory for this session (tool-results, transcripts, etc.)
Both checks are scoped to the **current session** so sessions cannot
read each other's data.
"""
if not path:
return False
if path.startswith("~"):
resolved = os.path.realpath(os.path.expanduser(path))
elif not os.path.isabs(path) and sdk_cwd:
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
else:
resolved = os.path.realpath(path)
# Allow access within the SDK working directory
if sdk_cwd:
norm_cwd = os.path.realpath(sdk_cwd)
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
return True
# Allow access within the current session's CLI project directory
# (~/.claude/projects/<encoded-cwd>/).
encoded = _current_project_dir.get("")
if encoded:
session_project = os.path.join(_SDK_PROJECTS_DIR, encoded)
if resolved == session_project or resolved.startswith(session_project + os.sep):
return True
return False
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
MCP_SERVER_NAME = "copilot"
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
# Context variables to pass user/session info to tool execution
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
# E2B cloud sandbox for the current turn (None when E2B is not configured).
# Passed to bash_exec so commands run on E2B instead of the local bwrap sandbox.
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
"_current_sandbox", default=None
)
# Raw SDK working directory path (e.g. /tmp/copilot-<session_id>).
# Used by workspace tools to save binary files for the CLI's built-in Read.
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
@@ -149,24 +93,6 @@ def set_execution_context(
_stash_event.set(asyncio.Event())
def get_current_sandbox() -> "AsyncSandbox | None":
"""Return the E2B sandbox for the current turn, or None."""
return _current_sandbox.get()
def get_sdk_cwd() -> str:
"""Return the SDK ephemeral working directory for the current turn."""
return _current_sdk_cwd.get()
def get_execution_context() -> tuple[str | None, ChatSession | None]:
"""Get the current execution context."""
return (
_current_user_id.get(),
_current_session.get(),
)
def pop_pending_tool_output(tool_name: str) -> str | None:
"""Pop and return the oldest stashed output for *tool_name*.
@@ -259,7 +185,11 @@ async def _execute_tool_sync(
session: ChatSession,
args: dict[str, Any],
) -> dict[str, Any]:
"""Execute a tool synchronously and return MCP-formatted response."""
"""Execute a tool synchronously and return MCP-formatted response.
Note: ``@@agptfile:`` expansion is handled upstream in the ``_truncating`` wrapper
so all registered handlers (BaseTool, E2B, Read) expand uniformly.
"""
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
result = await base_tool.execute(
user_id=user_id,
@@ -320,42 +250,50 @@ def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Read a local file with optional offset/limit.
"""Read a file with optional offset/limit.
Only allows paths that pass :func:`is_allowed_local_path` — the current
session's tool-results directory and ephemeral working directory.
Supports ``workspace://`` URIs (delegated to the workspace manager) and
local paths within the session's allowed directories (sdk_cwd + tool-results).
"""
file_path = args.get("file_path", "")
offset = args.get("offset", 0)
limit = args.get("limit", 2000)
offset = max(0, int(args.get("offset", 0)))
limit = max(1, int(args.get("limit", 2000)))
if not is_allowed_local_path(file_path):
return {
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
"isError": True,
}
def _mcp_err(text: str) -> dict[str, Any]:
return {"content": [{"type": "text", "text": text}], "isError": True}
def _mcp_ok(text: str) -> dict[str, Any]:
return {"content": [{"type": "text", "text": text}], "isError": False}
if file_path.startswith("workspace://"):
user_id, session = get_execution_context()
if session is None:
return _mcp_err("workspace:// file references require an active session")
try:
raw = await read_file_bytes(file_path, user_id, session)
except ValueError as exc:
return _mcp_err(str(exc))
lines = raw.decode("utf-8", errors="replace").splitlines(keepends=True)
selected = list(itertools.islice(lines, offset, offset + limit))
numbered = "".join(
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
)
return _mcp_ok(numbered)
if not is_allowed_local_path(file_path, get_sdk_cwd()):
return _mcp_err(f"Path not allowed: {file_path}")
resolved = os.path.realpath(os.path.expanduser(file_path))
try:
with open(resolved) as f:
selected = list(itertools.islice(f, offset, offset + limit))
content = "".join(selected)
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
# don't delete here — the SDK may read in multiple chunks.
return {
"content": [{"type": "text", "text": content}],
"isError": False,
}
return _mcp_ok("".join(selected))
except FileNotFoundError:
return {
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
"isError": True,
}
return _mcp_err(f"File not found: {file_path}")
except Exception as e:
return {
"content": [{"type": "text", "text": f"Error reading file: {e}"}],
"isError": True,
}
return _mcp_err(f"Error reading file: {e}")
_READ_TOOL_NAME = "Read"
@@ -414,9 +352,23 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
response adapter before the SDK can apply its own head-truncation.
Also expands ``@@agptfile:`` references in args so every registered tool
(BaseTool, E2B file tools, Read) receives resolved content uniformly.
Applied once to every registered tool."""
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
user_id, session = get_execution_context()
if session is not None:
try:
args = await expand_file_refs_in_args(args, user_id, session)
except FileRefExpansionError as exc:
return _mcp_error(
f"@@agptfile: reference could not be resolved: {exc}. "
"Ensure the file exists before referencing it. "
"For sandbox paths use bash_exec to verify the file exists first; "
"for workspace files use a workspace:// URI."
)
result = await fn(args)
truncated = truncate(result, _MCP_MAX_CHARS)

View File

@@ -2,12 +2,12 @@
import pytest
from backend.copilot.context import get_sdk_cwd
from backend.util.truncate import truncate
from .tool_adapter import (
_MCP_MAX_CHARS,
_text_from_mcp_result,
get_sdk_cwd,
pop_pending_tool_output,
set_execution_context,
stash_pending_tool_output,

View File

@@ -12,6 +12,7 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .continue_run_block import ContinueRunBlockTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
@@ -19,7 +20,10 @@ from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsToo
from .find_agent import FindAgentTool
from .find_block import FindBlockTool
from .find_library_agent import FindLibraryAgentTool
from .fix_agent import FixAgentGraphTool
from .get_agent_building_guide import GetAgentBuildingGuideTool
from .get_doc_page import GetDocPageTool
from .get_mcp_guide import GetMCPGuideTool
from .manage_folders import (
CreateFolderTool,
DeleteFolderTool,
@@ -32,6 +36,7 @@ from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .run_mcp_tool import RunMCPToolTool
from .search_docs import SearchDocsTool
from .validate_agent import ValidateAgentGraphTool
from .web_fetch import WebFetchTool
from .workspace_files import (
DeleteWorkspaceFileTool,
@@ -64,10 +69,13 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"move_agents_to_folder": MoveAgentsToFolderTool(),
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"continue_run_block": ContinueRunBlockTool(),
"run_mcp_tool": RunMCPToolTool(),
"get_mcp_guide": GetMCPGuideTool(),
"view_agent_output": AgentOutputTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
"get_agent_building_guide": GetAgentBuildingGuideTool(),
# Web fetch for safe URL retrieval
"web_fetch": WebFetchTool(),
# Agent-browser multi-step automation (navigate, act, screenshot)
@@ -80,6 +88,9 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
# Feature request tools
"search_feature_requests": SearchFeatureRequestsTool(),
"create_feature_request": CreateFeatureRequestTool(),
# Agent generation tools (local validation/fixing)
"validate_agent_graph": ValidateAgentGraphTool(),
"fix_agent_graph": FixAgentGraphTool(),
# Workspace tools for CoPilot file operations
"list_workspace_files": ListWorkspaceFilesTool(),
"read_workspace_file": ReadWorkspaceFileTool(),

View File

@@ -33,7 +33,7 @@ import tempfile
from typing import Any
from backend.copilot.model import ChatSession
from backend.util.request import validate_url
from backend.util.request import validate_url_host
from .base import BaseTool
from .models import (
@@ -235,7 +235,7 @@ async def _restore_browser_state(
if url:
# Validate the saved URL to prevent SSRF via stored redirect targets.
try:
await validate_url(url, trusted_origins=[])
await validate_url_host(url)
except ValueError:
logger.warning(
"[browser] State restore: blocked SSRF URL %s", url[:200]
@@ -473,7 +473,7 @@ class BrowserNavigateTool(BaseTool):
)
try:
await validate_url(url, trusted_origins=[])
await validate_url_host(url)
except ValueError as e:
return ErrorResponse(
message=str(e),

View File

@@ -68,17 +68,18 @@ def _run_result(rc: int = 0, stdout: str = "", stderr: str = "") -> tuple:
# ---------------------------------------------------------------------------
# SSRF protection via shared validate_url (backend.util.request)
# SSRF protection via shared validate_url_host (backend.util.request)
# ---------------------------------------------------------------------------
# Patch target: validate_url is imported directly into agent_browser's module scope.
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url"
# Patch target: validate_url_host is imported directly into agent_browser's
# module scope.
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url_host"
class TestSsrfViaValidateUrl:
"""Verify that browser_navigate uses validate_url for SSRF protection.
"""Verify that browser_navigate uses validate_url_host for SSRF protection.
We mock validate_url itself (not the low-level socket) so these tests
We mock validate_url_host itself (not the low-level socket) so these tests
exercise the integration point, not the internals of request.py
(which has its own thorough test suite in request_test.py).
"""
@@ -89,7 +90,7 @@ class TestSsrfViaValidateUrl:
@pytest.mark.asyncio
async def test_blocked_ip_returns_blocked_url_error(self):
"""validate_url raises ValueError → tool returns blocked_url ErrorResponse."""
"""validate_url_host raises ValueError → tool returns blocked_url ErrorResponse."""
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
mock_validate.side_effect = ValueError(
"Access to blocked IP 10.0.0.1 is not allowed."
@@ -124,8 +125,8 @@ class TestSsrfViaValidateUrl:
assert result.error == "blocked_url"
@pytest.mark.asyncio
async def test_validate_url_called_with_empty_trusted_origins(self):
"""Confirms no trusted-origins bypass is granted — all URLs are validated."""
async def test_validate_url_host_called_without_trusted_hostnames(self):
"""Confirms no trusted-hostnames bypass is granted — all URLs are validated."""
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
mock_validate.return_value = (object(), False, ["1.2.3.4"])
with patch(
@@ -143,7 +144,7 @@ class TestSsrfViaValidateUrl:
session=self.session,
url="https://example.com",
)
mock_validate.assert_called_once_with("https://example.com", trusted_origins=[])
mock_validate.assert_called_once_with("https://example.com")
# ---------------------------------------------------------------------------

View File

@@ -1,20 +1,15 @@
"""Agent generator package - Creates agents from natural language."""
from .core import (
AgentGeneratorNotConfiguredError,
AgentJsonValidationError,
AgentSummary,
DecompositionResult,
DecompositionStep,
LibraryAgentSummary,
MarketplaceAgentSummary,
customize_template,
decompose_goal,
enrich_library_agents_from_steps,
extract_search_terms_from_steps,
extract_uuids_from_text,
generate_agent,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_library_agent_by_graph_id,
@@ -27,25 +22,20 @@ from .core import (
search_marketplace_agents_for_generation,
)
from .errors import get_user_message_for_error
from .service import health_check as check_external_service_health
from .service import is_external_service_configured
from .validation import AgentFixer, AgentValidator
__all__ = [
"AgentGeneratorNotConfiguredError",
"AgentFixer",
"AgentValidator",
"AgentJsonValidationError",
"AgentSummary",
"DecompositionResult",
"DecompositionStep",
"LibraryAgentSummary",
"MarketplaceAgentSummary",
"check_external_service_health",
"customize_template",
"decompose_goal",
"enrich_library_agents_from_steps",
"extract_search_terms_from_steps",
"extract_uuids_from_text",
"generate_agent",
"generate_agent_patch",
"get_agent_as_json",
"get_all_relevant_agents_for_generation",
"get_library_agent_by_graph_id",
@@ -54,7 +44,6 @@ __all__ = [
"get_library_agents_for_generation",
"get_user_message_for_error",
"graph_to_json",
"is_external_service_configured",
"json_to_graph",
"save_agent_to_library",
"search_marketplace_agents_for_generation",

View File

@@ -0,0 +1,66 @@
"""Block management for agent generation.
Provides cached access to block metadata for validation and fixing.
"""
import logging
from typing import Any, Type
from backend.blocks import get_blocks as get_block_classes
from backend.blocks._base import Block
logger = logging.getLogger(__name__)
__all__ = ["get_blocks_as_dicts", "reset_block_caches"]
# ---------------------------------------------------------------------------
# Module-level caches
# ---------------------------------------------------------------------------
_blocks_cache: list[dict[str, Any]] | None = None
def reset_block_caches() -> None:
"""Reset all module-level caches (useful after updating block descriptions)."""
global _blocks_cache
_blocks_cache = None
# ---------------------------------------------------------------------------
# 1. get_blocks_as_dicts
# ---------------------------------------------------------------------------
def get_blocks_as_dicts() -> list[dict[str, Any]]:
"""Get all available blocks as dicts (cached after first call).
Each dict contains the keys returned by ``Block.get_info().model_dump()``:
id, name, description, inputSchema, outputSchema, categories,
staticOutput, costs, contributors, uiType.
Returns:
List of block info dicts.
"""
global _blocks_cache
if _blocks_cache is not None:
return _blocks_cache
block_classes: dict[str, Type[Block]] = get_block_classes() # type: ignore[assignment]
blocks: list[dict[str, Any]] = []
for block_cls in block_classes.values():
try:
instance = block_cls()
info = instance.get_info().model_dump()
# Use optimized description if available (loaded at startup)
if instance.optimized_description:
info["description"] = instance.optimized_description
blocks.append(info)
except Exception:
logger.warning(
"Failed to load block info for %s, skipping",
getattr(block_cls, "__name__", "unknown"),
exc_info=True,
)
_blocks_cache = blocks
logger.info("Cached %d block dicts", len(blocks))
return _blocks_cache

View File

@@ -10,13 +10,7 @@ from backend.data.db_accessors import graph_db, library_db, store_db
from backend.data.graph import Graph, Link, Node
from backend.util.exceptions import DatabaseError, NotFoundError
from .service import (
customize_template_external,
decompose_goal_external,
generate_agent_external,
generate_agent_patch_external,
is_external_service_configured,
)
from .helpers import UUID_RE_STR
logger = logging.getLogger(__name__)
@@ -78,38 +72,7 @@ class DecompositionResult(TypedDict, total=False):
AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
def _to_dict_list(
agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None,
) -> list[dict[str, Any]] | None:
"""Convert typed agent summaries to plain dicts for external service calls."""
if agents is None:
return None
return [dict(a) for a in agents]
class AgentGeneratorNotConfiguredError(Exception):
"""Raised when the external Agent Generator service is not configured."""
pass
def _check_service_configured() -> None:
"""Check if the external Agent Generator service is configured.
Raises:
AgentGeneratorNotConfiguredError: If the service is not configured.
"""
if not is_external_service_configured():
raise AgentGeneratorNotConfiguredError(
"Agent Generator service is not configured. "
"Set AGENTGENERATOR_HOST environment variable to enable agent generation."
)
_UUID_PATTERN = re.compile(
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
re.IGNORECASE,
)
_UUID_PATTERN = re.compile(UUID_RE_STR, re.IGNORECASE)
def extract_uuids_from_text(text: str) -> list[str]:
@@ -553,69 +516,6 @@ async def enrich_library_agents_from_steps(
return all_agents
async def decompose_goal(
description: str,
context: str = "",
library_agents: Sequence[AgentSummary] | None = None,
) -> DecompositionResult | None:
"""Break down a goal into steps or return clarifying questions.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
library_agents: User's library agents available for sub-agent composition
Returns:
DecompositionResult with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
Or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for decompose_goal")
result = await decompose_goal_external(
description, context, _to_dict_list(library_agents)
)
return result # type: ignore[return-value]
async def generate_agent(
instructions: DecompositionResult | dict[str, Any],
library_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
Returns:
Agent JSON dict, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(
dict(instructions), _to_dict_list(library_agents)
)
if result:
if isinstance(result, dict) and result.get("type") == "error":
return result
if "id" not in result:
result["id"] = str(uuid.uuid4())
if "version" not in result:
result["version"] = 1
if "is_active" not in result:
result["is_active"] = True
return result
class AgentJsonValidationError(Exception):
"""Raised when agent JSON is invalid or missing required fields."""
@@ -792,70 +692,3 @@ async def get_agent_as_json(
return None
return graph_to_json(graph)
async def generate_agent_patch(
update_request: str,
current_agent: dict[str, Any],
library_agents: Sequence[AgentSummary] | None = None,
) -> dict[str, Any] | None:
"""Update an existing agent using natural language.
The external Agent Generator service handles:
- Generating the patch
- Applying the patch
- Fixing and validating the result
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
Returns:
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent_patch")
return await generate_agent_patch_external(
update_request,
current_agent,
_to_dict_list(library_agents),
)
async def customize_template(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any] | None:
"""Customize a template/marketplace agent using natural language.
This is used when users want to modify a template or marketplace agent
to fit their specific needs before adding it to their library.
The external Agent Generator service handles:
- Understanding the modification request
- Applying changes to the template
- Fixing and validating the result
Args:
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
Returns:
Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on unexpected error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for customize_template")
return await customize_template_external(
template_agent, modification_request, context
)

View File

@@ -1,165 +0,0 @@
"""Dummy Agent Generator for testing.
Returns mock responses matching the format expected from the external service.
Enable via AGENTGENERATOR_USE_DUMMY=true in settings.
WARNING: This is for testing only. Do not use in production.
"""
import asyncio
import logging
import uuid
from typing import Any
logger = logging.getLogger(__name__)
# Dummy decomposition result (instructions type)
DUMMY_DECOMPOSITION_RESULT: dict[str, Any] = {
"type": "instructions",
"steps": [
{
"description": "Get input from user",
"action": "input",
"block_name": "AgentInputBlock",
},
{
"description": "Process the input",
"action": "process",
"block_name": "TextFormatterBlock",
},
{
"description": "Return output to user",
"action": "output",
"block_name": "AgentOutputBlock",
},
],
}
# Block IDs from backend/blocks/io.py
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
def _generate_dummy_agent_json() -> dict[str, Any]:
"""Generate a minimal valid agent JSON for testing."""
input_node_id = str(uuid.uuid4())
output_node_id = str(uuid.uuid4())
return {
"id": str(uuid.uuid4()),
"version": 1,
"is_active": True,
"name": "Dummy Test Agent",
"description": "A dummy agent generated for testing purposes",
"nodes": [
{
"id": input_node_id,
"block_id": AGENT_INPUT_BLOCK_ID,
"input_default": {
"name": "input",
"title": "Input",
"description": "Enter your input",
"placeholder_values": [],
},
"metadata": {"position": {"x": 0, "y": 0}},
},
{
"id": output_node_id,
"block_id": AGENT_OUTPUT_BLOCK_ID,
"input_default": {
"name": "output",
"title": "Output",
"description": "Agent output",
"format": "{output}",
},
"metadata": {"position": {"x": 400, "y": 0}},
},
],
"links": [
{
"id": str(uuid.uuid4()),
"source_id": input_node_id,
"sink_id": output_node_id,
"source_name": "result",
"sink_name": "value",
"is_static": False,
},
],
}
async def decompose_goal_dummy(
description: str,
context: str = "",
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
"""Return dummy decomposition result."""
logger.info("Using dummy agent generator for decompose_goal")
return DUMMY_DECOMPOSITION_RESULT.copy()
async def generate_agent_dummy(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy agent synchronously (blocks for 30s, returns agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator (sync mode): returning agent JSON after 30s"
)
await asyncio.sleep(30)
return _generate_dummy_agent_json()
async def generate_agent_patch_dummy(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy patched agent synchronously (blocks for 30s, returns patched agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator patch (sync mode): returning patched agent after 30s"
)
await asyncio.sleep(30)
patched = current_agent.copy()
patched["description"] = (
f"{current_agent.get('description', '')} (updated: {update_request})"
)
return patched
async def customize_template_dummy(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any]:
"""Return dummy customized template (returns template with updated description)."""
logger.info("Using dummy agent generator for customize_template")
customized = template_agent.copy()
customized["description"] = (
f"{template_agent.get('description', '')} (customized: {modification_request})"
)
return customized
async def get_blocks_dummy() -> list[dict[str, Any]]:
"""Return dummy blocks list."""
logger.info("Using dummy agent generator for get_blocks")
return [
{"id": AGENT_INPUT_BLOCK_ID, "name": "AgentInputBlock"},
{"id": AGENT_OUTPUT_BLOCK_ID, "name": "AgentOutputBlock"},
]
async def health_check_dummy() -> bool:
"""Always returns healthy for dummy service."""
return True

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,913 @@
"""Unit tests for AgentFixer."""
from .fixer import (
_ADDTODICTIONARY_BLOCK_ID,
_ADDTOLIST_BLOCK_ID,
_CODE_EXECUTION_BLOCK_ID,
_DATA_SAMPLING_BLOCK_ID,
_GET_CURRENT_DATE_BLOCK_ID,
_STORE_VALUE_BLOCK_ID,
_TEXT_REPLACE_BLOCK_ID,
_UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
AGENT_EXECUTOR_BLOCK_ID,
MCP_TOOL_BLOCK_ID,
AgentFixer,
)
from .helpers import generate_uuid
def _make_agent(
nodes: list | None = None,
links: list | None = None,
agent_id: str | None = None,
) -> dict:
"""Build a minimal agent dict for testing."""
return {
"id": agent_id or generate_uuid(),
"name": "Test Agent",
"nodes": nodes or [],
"links": links or [],
}
def _make_node(
node_id: str | None = None,
block_id: str = "block-1",
input_default: dict | None = None,
position: tuple[int, int] = (0, 0),
) -> dict:
"""Build a minimal node dict for testing."""
return {
"id": node_id or generate_uuid(),
"block_id": block_id,
"input_default": input_default or {},
"metadata": {"position": {"x": position[0], "y": position[1]}},
}
def _make_link(
link_id: str | None = None,
source_id: str = "",
source_name: str = "output",
sink_id: str = "",
sink_name: str = "input",
is_static: bool = False,
) -> dict:
"""Build a minimal link dict for testing."""
return {
"id": link_id or generate_uuid(),
"source_id": source_id,
"source_name": source_name,
"sink_id": sink_id,
"sink_name": sink_name,
"is_static": is_static,
}
class TestFixAgentIds:
"""Tests for fix_agent_ids."""
def test_valid_uuids_unchanged(self):
fixer = AgentFixer()
agent_id = generate_uuid()
link_id = generate_uuid()
agent = _make_agent(agent_id=agent_id, links=[{"id": link_id}])
result = fixer.fix_agent_ids(agent)
assert result["id"] == agent_id
assert result["links"][0]["id"] == link_id
assert fixer.fixes_applied == []
def test_invalid_agent_id_replaced(self):
fixer = AgentFixer()
agent = _make_agent(agent_id="bad-id")
result = fixer.fix_agent_ids(agent)
assert result["id"] != "bad-id"
assert len(fixer.fixes_applied) == 1
assert "agent ID" in fixer.fixes_applied[0]
def test_invalid_link_id_replaced(self):
fixer = AgentFixer()
agent = _make_agent(links=[{"id": "not-a-uuid"}])
result = fixer.fix_agent_ids(agent)
assert result["links"][0]["id"] != "not-a-uuid"
assert len(fixer.fixes_applied) == 1
class TestFixDoubleCurlyBraces:
"""Tests for fix_double_curly_braces."""
def test_single_braces_converted_to_double(self):
fixer = AgentFixer()
node = _make_node(input_default={"prompt": "Hello {name}!"})
agent = _make_agent(nodes=[node])
result = fixer.fix_double_curly_braces(agent)
assert result["nodes"][0]["input_default"]["prompt"] == "Hello {{name}}!"
def test_double_braces_unchanged(self):
fixer = AgentFixer()
node = _make_node(input_default={"prompt": "Hello {{name}}!"})
agent = _make_agent(nodes=[node])
result = fixer.fix_double_curly_braces(agent)
assert result["nodes"][0]["input_default"]["prompt"] == "Hello {{name}}!"
assert fixer.fixes_applied == []
def test_non_string_prompt_skipped(self):
fixer = AgentFixer()
node = _make_node(input_default={"prompt": 42})
agent = _make_agent(nodes=[node])
result = fixer.fix_double_curly_braces(agent)
assert result["nodes"][0]["input_default"]["prompt"] == 42
def test_non_string_prompt_with_prompt_values_skipped(self):
"""Ensure non-string prompt fields don't crash re.search in the
prompt_values path."""
fixer = AgentFixer()
node_id = generate_uuid()
source_id = generate_uuid()
node = _make_node(
node_id=node_id, input_default={"prompt": None, "prompt_values": {}}
)
source_node = _make_node(node_id=source_id)
link = _make_link(
source_id=source_id,
source_name="output",
sink_id=node_id,
sink_name="prompt_values_$_name",
)
agent = _make_agent(nodes=[node, source_node], links=[link])
result = fixer.fix_double_curly_braces(agent)
# Should not crash and prompt stays None
assert result["nodes"][0]["input_default"]["prompt"] is None
class TestFixCredentials:
"""Tests for fix_credentials."""
def test_credentials_removed(self):
fixer = AgentFixer()
node = _make_node(
input_default={
"credentials": {"key": "secret"},
"url": "http://example.com",
}
)
agent = _make_agent(nodes=[node])
result = fixer.fix_credentials(agent)
assert "credentials" not in result["nodes"][0]["input_default"]
assert result["nodes"][0]["input_default"]["url"] == "http://example.com"
assert len(fixer.fixes_applied) == 1
def test_no_credentials_unchanged(self):
fixer = AgentFixer()
node = _make_node(input_default={"url": "http://example.com"})
agent = _make_agent(nodes=[node])
result = fixer.fix_credentials(agent)
assert result["nodes"][0]["input_default"]["url"] == "http://example.com"
assert fixer.fixes_applied == []
class TestFixCodeExecutionOutput:
"""Tests for fix_code_execution_output."""
def test_response_renamed_to_stdout_logs(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id=_CODE_EXECUTION_BLOCK_ID)
link = _make_link(source_id="n1", source_name="response", sink_id="n2")
agent = _make_agent(nodes=[node], links=[link])
result = fixer.fix_code_execution_output(agent)
assert result["links"][0]["source_name"] == "stdout_logs"
assert len(fixer.fixes_applied) == 1
def test_non_response_source_unchanged(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id=_CODE_EXECUTION_BLOCK_ID)
link = _make_link(source_id="n1", source_name="stdout_logs", sink_id="n2")
agent = _make_agent(nodes=[node], links=[link])
result = fixer.fix_code_execution_output(agent)
assert result["links"][0]["source_name"] == "stdout_logs"
assert fixer.fixes_applied == []
class TestFixDataSamplingSampleSize:
"""Tests for fix_data_sampling_sample_size."""
def test_sample_size_set_to_1(self):
fixer = AgentFixer()
node = _make_node(
node_id="n1",
block_id=_DATA_SAMPLING_BLOCK_ID,
input_default={"sample_size": 10},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_data_sampling_sample_size(agent)
assert result["nodes"][0]["input_default"]["sample_size"] == 1
def test_removes_links_to_sample_size(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id=_DATA_SAMPLING_BLOCK_ID)
link = _make_link(sink_id="n1", sink_name="sample_size", source_id="n2")
agent = _make_agent(nodes=[node], links=[link])
result = fixer.fix_data_sampling_sample_size(agent)
assert len(result["links"]) == 0
assert result["nodes"][0]["input_default"]["sample_size"] == 1
class TestFixTextReplaceNewParameter:
"""Tests for fix_text_replace_new_parameter."""
def test_empty_new_changed_to_space(self):
fixer = AgentFixer()
node = _make_node(
block_id=_TEXT_REPLACE_BLOCK_ID,
input_default={"new": ""},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_text_replace_new_parameter(agent)
assert result["nodes"][0]["input_default"]["new"] == " "
def test_nonempty_new_unchanged(self):
fixer = AgentFixer()
node = _make_node(
block_id=_TEXT_REPLACE_BLOCK_ID,
input_default={"new": "replacement"},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_text_replace_new_parameter(agent)
assert result["nodes"][0]["input_default"]["new"] == "replacement"
assert fixer.fixes_applied == []
class TestFixGetCurrentDateOffset:
"""Tests for fix_getcurrentdate_offset."""
def test_negative_offset_made_positive(self):
fixer = AgentFixer()
node = _make_node(
block_id=_GET_CURRENT_DATE_BLOCK_ID,
input_default={"offset": -5},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_getcurrentdate_offset(agent)
assert result["nodes"][0]["input_default"]["offset"] == 5
def test_positive_offset_unchanged(self):
fixer = AgentFixer()
node = _make_node(
block_id=_GET_CURRENT_DATE_BLOCK_ID,
input_default={"offset": 3},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_getcurrentdate_offset(agent)
assert result["nodes"][0]["input_default"]["offset"] == 3
assert fixer.fixes_applied == []
class TestFixNodeXCoordinates:
"""Tests for fix_node_x_coordinates."""
def test_close_nodes_spread_apart(self):
fixer = AgentFixer()
src_node = _make_node(node_id="src", position=(0, 0))
sink_node = _make_node(node_id="sink", position=(100, 0))
link = _make_link(source_id="src", sink_id="sink")
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
result = fixer.fix_node_x_coordinates(agent)
sink = next(n for n in result["nodes"] if n["id"] == "sink")
assert sink["metadata"]["position"]["x"] >= 800
def test_far_apart_nodes_unchanged(self):
fixer = AgentFixer()
src_node = _make_node(node_id="src", position=(0, 0))
sink_node = _make_node(node_id="sink", position=(1000, 0))
link = _make_link(source_id="src", sink_id="sink")
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
result = fixer.fix_node_x_coordinates(agent)
sink = next(n for n in result["nodes"] if n["id"] == "sink")
assert sink["metadata"]["position"]["x"] == 1000
assert fixer.fixes_applied == []
class TestFixAddToDictionaryBlocks:
"""Tests for fix_addtodictionary_blocks."""
def test_removes_create_dictionary_nodes(self):
fixer = AgentFixer()
create_dict_id = "b924ddf4-de4f-4b56-9a85-358930dcbc91"
dict_node = _make_node(node_id="dict-1", block_id=create_dict_id)
add_to_dict_node = _make_node(
node_id="add-1", block_id=_ADDTODICTIONARY_BLOCK_ID
)
link = _make_link(source_id="dict-1", sink_id="add-1")
agent = _make_agent(nodes=[dict_node, add_to_dict_node], links=[link])
result = fixer.fix_addtodictionary_blocks(agent)
node_ids = [n["id"] for n in result["nodes"]]
assert "dict-1" not in node_ids
assert "add-1" in node_ids
assert len(result["links"]) == 0
class TestFixStoreValueBeforeCondition:
"""Tests for fix_storevalue_before_condition."""
def test_inserts_storevalue_block(self):
fixer = AgentFixer()
condition_block_id = "715696a0-e1da-45c8-b209-c2fa9c3b0be6"
src_node = _make_node(node_id="src")
cond_node = _make_node(node_id="cond", block_id=condition_block_id)
link = _make_link(
source_id="src", source_name="output", sink_id="cond", sink_name="value2"
)
agent = _make_agent(nodes=[src_node, cond_node], links=[link])
result = fixer.fix_storevalue_before_condition(agent)
# Should have 3 nodes now (original 2 + new StoreValueBlock)
assert len(result["nodes"]) == 3
store_nodes = [
n for n in result["nodes"] if n["block_id"] == _STORE_VALUE_BLOCK_ID
]
assert len(store_nodes) == 1
assert store_nodes[0]["input_default"]["data"] is None
class TestFixAddToListBlocks:
"""Tests for fix_addtolist_blocks - self-reference links."""
def test_addtolist_gets_self_reference_link(self):
fixer = AgentFixer()
node = _make_node(node_id="atl-1", block_id=_ADDTOLIST_BLOCK_ID)
# Source link to AddToList (from some other node)
link = _make_link(
source_id="other",
source_name="output",
sink_id="atl-1",
sink_name="item",
)
other_node = _make_node(node_id="other")
agent = _make_agent(nodes=[other_node, node], links=[link])
result = fixer.fix_addtolist_blocks(agent)
# Should have a self-reference link: atl-1.updated_list -> atl-1.list
self_ref_links = [
lnk
for lnk in result["links"]
if lnk["source_id"] == "atl-1"
and lnk["sink_id"] == "atl-1"
and lnk["source_name"] == "updated_list"
and lnk["sink_name"] == "list"
]
assert len(self_ref_links) == 1
class TestFixLinkStaticProperties:
"""Tests for fix_link_static_properties."""
def test_sets_is_static_from_block_schema(self):
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(node_id="n1", block_id=block_id)
link = _make_link(source_id="n1", sink_id="n2", is_static=False)
agent = _make_agent(nodes=[node], links=[link])
blocks = [{"id": block_id, "staticOutput": True}]
result = fixer.fix_link_static_properties(agent, blocks)
assert result["links"][0]["is_static"] is True
def test_unknown_block_leaves_link_unchanged(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id="unknown-block")
link = _make_link(source_id="n1", sink_id="n2", is_static=True)
agent = _make_agent(nodes=[node], links=[link])
result = fixer.fix_link_static_properties(agent, blocks=[])
# Unknown block → skipped, link stays as-is
assert result["links"][0]["is_static"] is True
class TestFixAiModelParameter:
"""Tests for fix_ai_model_parameter."""
def test_missing_model_gets_default(self):
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(node_id="n1", block_id=block_id, input_default={})
agent = _make_agent(nodes=[node])
blocks = [
{
"id": block_id,
"categories": [{"category": "AI"}],
"inputSchema": {
"properties": {"model": {"type": "string"}},
},
}
]
result = fixer.fix_ai_model_parameter(agent, blocks)
assert result["nodes"][0]["input_default"]["model"] == "gpt-4o"
def test_valid_model_unchanged(self):
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(
node_id="n1",
block_id=block_id,
input_default={"model": "claude-opus-4-6"},
)
agent = _make_agent(nodes=[node])
blocks = [
{
"id": block_id,
"categories": [{"category": "AI"}],
"inputSchema": {
"properties": {"model": {"type": "string"}},
},
}
]
result = fixer.fix_ai_model_parameter(agent, blocks)
assert result["nodes"][0]["input_default"]["model"] == "claude-opus-4-6"
class TestFixAgentExecutorBlocks:
"""Tests for fix_agent_executor_blocks."""
def test_fills_schemas_from_library_agent(self):
fixer = AgentFixer()
lib_agent_id = generate_uuid()
node = _make_node(
node_id="n1",
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={
"graph_id": lib_agent_id,
"graph_version": 1,
"user_id": "user-1",
},
)
agent = _make_agent(nodes=[node])
# Library agents use graph_id as the lookup key
library_agents = [
{
"graph_id": lib_agent_id,
"graph_version": 2,
"input_schema": {"field1": {"type": "string"}},
"output_schema": {"result": {"type": "string"}},
}
]
result = fixer.fix_agent_executor_blocks(agent, library_agents)
node_result = result["nodes"][0]["input_default"]
assert node_result["graph_version"] == 2
assert node_result["input_schema"] == {"field1": {"type": "string"}}
assert node_result["output_schema"] == {"result": {"type": "string"}}
class TestFixInvalidNestedSinkLinks:
"""Tests for fix_invalid_nested_sink_links."""
def test_removes_numeric_index_links(self):
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(node_id="n1", block_id=block_id)
link = _make_link(source_id="n2", sink_id="n1", sink_name="values_#_0")
agent = _make_agent(nodes=[node], links=[link])
blocks = [
{
"id": block_id,
"inputSchema": {"properties": {"values": {"type": "array"}}},
}
]
result = fixer.fix_invalid_nested_sink_links(agent, blocks)
assert len(result["links"]) == 0
def test_valid_nested_links_kept(self):
fixer = AgentFixer()
block_id = generate_uuid()
node = _make_node(node_id="n1", block_id=block_id)
link = _make_link(source_id="n2", sink_id="n1", sink_name="values_#_name")
agent = _make_agent(nodes=[node], links=[link])
blocks = [
{
"id": block_id,
"inputSchema": {
"properties": {"values": {"type": "object"}},
},
}
]
result = fixer.fix_invalid_nested_sink_links(agent, blocks)
assert len(result["links"]) == 1
class TestApplyAllFixes:
"""Tests for apply_all_fixes orchestration."""
def test_is_sync(self):
"""apply_all_fixes should be a sync function."""
import inspect
assert not inspect.iscoroutinefunction(AgentFixer.apply_all_fixes)
def test_applies_multiple_fixes(self):
fixer = AgentFixer()
agent = _make_agent(
agent_id="bad-id",
nodes=[
_make_node(
block_id=_TEXT_REPLACE_BLOCK_ID,
input_default={"new": "", "credentials": {"key": "secret"}},
)
],
)
result = fixer.apply_all_fixes(agent)
# Agent ID should be fixed
assert result["id"] != "bad-id"
# Credentials should be removed
assert "credentials" not in result["nodes"][0]["input_default"]
# Text replace "new" should be space
assert result["nodes"][0]["input_default"]["new"] == " "
# Multiple fixes applied
assert len(fixer.fixes_applied) >= 3
def test_empty_agent_no_crash(self):
fixer = AgentFixer()
agent = _make_agent()
result = fixer.apply_all_fixes(agent)
assert "nodes" in result
assert "links" in result
def test_returns_deep_copy_behavior(self):
"""Fixer mutates in place — verify the same dict is returned."""
fixer = AgentFixer()
agent = _make_agent()
result = fixer.apply_all_fixes(agent)
assert result is agent
class TestFixMCPToolBlocks:
"""Tests for fix_mcp_tool_blocks."""
def test_adds_missing_tool_arguments(self):
fixer = AgentFixer()
node = _make_node(
node_id="n1",
block_id=MCP_TOOL_BLOCK_ID,
input_default={
"server_url": "https://mcp.example.com/sse",
"selected_tool": "search",
"tool_input_schema": {},
},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_mcp_tool_blocks(agent)
assert result["nodes"][0]["input_default"]["tool_arguments"] == {}
assert any("tool_arguments" in f for f in fixer.fixes_applied)
def test_adds_missing_tool_input_schema(self):
fixer = AgentFixer()
node = _make_node(
node_id="n1",
block_id=MCP_TOOL_BLOCK_ID,
input_default={
"server_url": "https://mcp.example.com/sse",
"selected_tool": "search",
"tool_arguments": {},
},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_mcp_tool_blocks(agent)
assert result["nodes"][0]["input_default"]["tool_input_schema"] == {}
assert any("tool_input_schema" in f for f in fixer.fixes_applied)
def test_populates_tool_arguments_from_schema(self):
fixer = AgentFixer()
node = _make_node(
node_id="n1",
block_id=MCP_TOOL_BLOCK_ID,
input_default={
"server_url": "https://mcp.example.com/sse",
"selected_tool": "search",
"tool_input_schema": {
"properties": {
"query": {"type": "string", "default": "hello"},
"limit": {"type": "integer"},
}
},
"tool_arguments": {},
},
)
agent = _make_agent(nodes=[node])
result = fixer.fix_mcp_tool_blocks(agent)
tool_args = result["nodes"][0]["input_default"]["tool_arguments"]
assert tool_args["query"] == "hello"
assert tool_args["limit"] is None
def test_no_op_when_already_complete(self):
fixer = AgentFixer()
node = _make_node(
node_id="n1",
block_id=MCP_TOOL_BLOCK_ID,
input_default={
"server_url": "https://mcp.example.com/sse",
"selected_tool": "search",
"tool_input_schema": {},
"tool_arguments": {},
},
)
agent = _make_agent(nodes=[node])
fixer.fix_mcp_tool_blocks(agent)
assert len(fixer.fixes_applied) == 0
class TestFixDynamicBlockSinkNames:
"""Tests for fix_dynamic_block_sink_names."""
def test_mcp_tool_arguments_prefix_removed(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id=MCP_TOOL_BLOCK_ID)
link = _make_link(
source_id="src", sink_id="n1", sink_name="tool_arguments_#_query"
)
agent = _make_agent(nodes=[node], links=[link])
fixer.fix_dynamic_block_sink_names(agent)
assert agent["links"][0]["sink_name"] == "query"
assert len(fixer.fixes_applied) == 1
def test_agent_executor_inputs_prefix_removed(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id=AGENT_EXECUTOR_BLOCK_ID)
link = _make_link(source_id="src", sink_id="n1", sink_name="inputs_#_url")
agent = _make_agent(nodes=[node], links=[link])
fixer.fix_dynamic_block_sink_names(agent)
assert agent["links"][0]["sink_name"] == "url"
assert len(fixer.fixes_applied) == 1
def test_bare_sink_name_unchanged(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id=MCP_TOOL_BLOCK_ID)
link = _make_link(source_id="src", sink_id="n1", sink_name="query")
agent = _make_agent(nodes=[node], links=[link])
fixer.fix_dynamic_block_sink_names(agent)
assert agent["links"][0]["sink_name"] == "query"
assert len(fixer.fixes_applied) == 0
def test_non_dynamic_block_unchanged(self):
fixer = AgentFixer()
node = _make_node(node_id="n1", block_id="some-other-block-id")
link = _make_link(source_id="src", sink_id="n1", sink_name="values_#_key")
agent = _make_agent(nodes=[node], links=[link])
fixer.fix_dynamic_block_sink_names(agent)
assert agent["links"][0]["sink_name"] == "values_#_key"
assert len(fixer.fixes_applied) == 0
class TestFixDataTypeMismatch:
"""Tests for fix_data_type_mismatch."""
@staticmethod
def _make_block(
block_id: str,
name: str = "TestBlock",
input_schema: dict | None = None,
output_schema: dict | None = None,
) -> dict:
return {
"id": block_id,
"name": name,
"inputSchema": input_schema or {"properties": {}},
"outputSchema": output_schema or {"properties": {}},
}
def test_inserts_converter_for_incompatible_types(self):
fixer = AgentFixer()
src_block_id = generate_uuid()
sink_block_id = generate_uuid()
src_node = _make_node(node_id="src", block_id=src_block_id)
sink_node = _make_node(node_id="sink", block_id=sink_block_id)
link = _make_link(
source_id="src",
source_name="result",
sink_id="sink",
sink_name="count",
)
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
blocks = [
self._make_block(
src_block_id,
name="Source",
output_schema={"properties": {"result": {"type": "string"}}},
),
self._make_block(
sink_block_id,
name="Sink",
input_schema={"properties": {"count": {"type": "integer"}}},
),
]
result = fixer.fix_data_type_mismatch(agent, blocks)
# A converter node should have been inserted
converter_nodes = [
n
for n in result["nodes"]
if n["block_id"] == _UNIVERSAL_TYPE_CONVERTER_BLOCK_ID
]
assert len(converter_nodes) == 1
assert converter_nodes[0]["input_default"]["type"] == "number"
# Original link replaced by two new links through the converter
assert len(result["links"]) == 2
src_to_converter = result["links"][0]
converter_to_sink = result["links"][1]
assert src_to_converter["source_id"] == "src"
assert src_to_converter["sink_id"] == converter_nodes[0]["id"]
assert src_to_converter["sink_name"] == "value"
assert converter_to_sink["source_id"] == converter_nodes[0]["id"]
assert converter_to_sink["source_name"] == "value"
assert converter_to_sink["sink_id"] == "sink"
assert converter_to_sink["sink_name"] == "count"
assert len(fixer.fixes_applied) == 1
def test_compatible_types_unchanged(self):
fixer = AgentFixer()
block_id = generate_uuid()
src_node = _make_node(node_id="src", block_id=block_id)
sink_node = _make_node(node_id="sink", block_id=block_id)
link = _make_link(
source_id="src",
source_name="output",
sink_id="sink",
sink_name="input",
)
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
blocks = [
self._make_block(
block_id,
input_schema={"properties": {"input": {"type": "string"}}},
output_schema={"properties": {"output": {"type": "string"}}},
),
]
result = fixer.fix_data_type_mismatch(agent, blocks)
# No converter inserted, original link kept
assert len(result["nodes"]) == 2
assert len(result["links"]) == 1
assert result["links"][0] is link
assert fixer.fixes_applied == []
def test_missing_block_keeps_link(self):
"""Links referencing unknown blocks are kept unchanged."""
fixer = AgentFixer()
src_node = _make_node(node_id="src", block_id="unknown-block")
sink_node = _make_node(node_id="sink", block_id="unknown-block")
link = _make_link(source_id="src", sink_id="sink")
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
result = fixer.fix_data_type_mismatch(agent, blocks=[])
assert len(result["links"]) == 1
assert result["links"][0] is link
def test_missing_type_info_keeps_link(self):
"""Links where source/sink type is not defined are kept unchanged."""
fixer = AgentFixer()
block_id = generate_uuid()
src_node = _make_node(node_id="src", block_id=block_id)
sink_node = _make_node(node_id="sink", block_id=block_id)
link = _make_link(
source_id="src",
source_name="output",
sink_id="sink",
sink_name="input",
)
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
# Block has no properties defined for the linked fields
blocks = [self._make_block(block_id)]
result = fixer.fix_data_type_mismatch(agent, blocks)
assert len(result["links"]) == 1
assert fixer.fixes_applied == []
def test_multiple_mismatches_insert_multiple_converters(self):
"""Each incompatible link gets its own converter node."""
fixer = AgentFixer()
src_block_id = generate_uuid()
sink_block_id = generate_uuid()
src_node = _make_node(node_id="src", block_id=src_block_id)
sink1 = _make_node(node_id="sink1", block_id=sink_block_id)
sink2 = _make_node(node_id="sink2", block_id=sink_block_id)
link1 = _make_link(
source_id="src", source_name="out", sink_id="sink1", sink_name="count"
)
link2 = _make_link(
source_id="src", source_name="out", sink_id="sink2", sink_name="count"
)
agent = _make_agent(nodes=[src_node, sink1, sink2], links=[link1, link2])
blocks = [
self._make_block(
src_block_id,
output_schema={"properties": {"out": {"type": "string"}}},
),
self._make_block(
sink_block_id,
input_schema={"properties": {"count": {"type": "integer"}}},
),
]
result = fixer.fix_data_type_mismatch(agent, blocks)
converter_nodes = [
n
for n in result["nodes"]
if n["block_id"] == _UNIVERSAL_TYPE_CONVERTER_BLOCK_ID
]
assert len(converter_nodes) == 2
# Each original link becomes two links through its own converter
assert len(result["links"]) == 4
assert len(fixer.fixes_applied) == 2

View File

@@ -0,0 +1,67 @@
"""Shared helpers for agent generation."""
import re
import uuid
from typing import Any
from .blocks import get_blocks_as_dicts
__all__ = [
"AGENT_EXECUTOR_BLOCK_ID",
"AGENT_INPUT_BLOCK_ID",
"AGENT_OUTPUT_BLOCK_ID",
"AgentDict",
"MCP_TOOL_BLOCK_ID",
"UUID_REGEX",
"are_types_compatible",
"generate_uuid",
"get_blocks_as_dicts",
"get_defined_property_type",
"is_uuid",
]
# Type alias for the agent JSON structure passed through
# the validation and fixing pipeline.
AgentDict = dict[str, Any]
# Shared base pattern (unanchored, lowercase hex); used for both full-string
# validation (UUID_REGEX) and text extraction (core._UUID_PATTERN).
UUID_RE_STR = r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[a-f0-9]{4}-[a-f0-9]{12}"
UUID_REGEX = re.compile(r"^" + UUID_RE_STR + r"$")
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
MCP_TOOL_BLOCK_ID = "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
def is_uuid(value: str) -> bool:
"""Check if a string is a valid UUID."""
return isinstance(value, str) and UUID_REGEX.match(value) is not None
def generate_uuid() -> str:
"""Generate a new UUID string."""
return str(uuid.uuid4())
def get_defined_property_type(schema: dict[str, Any], name: str) -> str | None:
"""Get property type from a schema, handling nested `_#_` notation."""
if "_#_" in name:
parent, child = name.split("_#_", 1)
parent_schema = schema.get(parent, {})
if "properties" in parent_schema and isinstance(
parent_schema["properties"], dict
):
return parent_schema["properties"].get(child, {}).get("type")
return None
return schema.get(name, {}).get("type")
def are_types_compatible(src: str, sink: str) -> bool:
"""Check if two schema types are compatible."""
if {src, sink} <= {"integer", "number"}:
return True
return src == sink

View File

@@ -0,0 +1,196 @@
"""Shared fix → validate → preview/save pipeline for agent tools."""
import json
import logging
from typing import Any, cast
from backend.copilot.tools.models import (
AgentPreviewResponse,
AgentSavedResponse,
ErrorResponse,
ToolResponseBase,
)
from .blocks import get_blocks_as_dicts
from .core import get_library_agents_by_ids, save_agent_to_library
from .fixer import AgentFixer
from .validator import AgentValidator
logger = logging.getLogger(__name__)
MAX_AGENT_JSON_SIZE = 1_000_000 # 1 MB
async def fetch_library_agents(
user_id: str | None,
library_agent_ids: list[str],
) -> list[dict[str, Any]] | None:
"""Fetch library agents by IDs for AgentExecutorBlock validation.
Returns None if no IDs provided or user is not authenticated.
"""
if not user_id or not library_agent_ids:
return None
try:
agents = await get_library_agents_by_ids(
user_id=user_id,
agent_ids=library_agent_ids,
)
return cast(list[dict[str, Any]], agents)
except Exception as e:
logger.warning(f"Failed to fetch library agents by IDs: {e}")
return None
async def fix_validate_and_save(
agent_json: dict[str, Any],
*,
user_id: str | None,
session_id: str | None,
save: bool = True,
is_update: bool = False,
default_name: str = "Agent",
preview_message: str | None = None,
save_message: str | None = None,
library_agents: list[dict[str, Any]] | None = None,
folder_id: str | None = None,
) -> ToolResponseBase:
"""Shared pipeline: auto-fix → validate → preview or save.
Args:
agent_json: The agent JSON dict (must already have id/version/is_active set).
user_id: The authenticated user's ID.
session_id: The chat session ID.
save: Whether to save or just preview.
is_update: Whether this is an update to an existing agent.
default_name: Fallback name if agent_json has none.
preview_message: Custom preview message (optional).
save_message: Custom save success message (optional).
library_agents: Library agents for AgentExecutorBlock validation/fixing.
Returns:
An appropriate ToolResponseBase subclass.
"""
# Size guard
json_size = len(json.dumps(agent_json))
if json_size > MAX_AGENT_JSON_SIZE:
return ErrorResponse(
message=(
f"Agent JSON is too large ({json_size:,} bytes, "
f"max {MAX_AGENT_JSON_SIZE:,}). Reduce the number of nodes."
),
error="agent_json_too_large",
session_id=session_id,
)
blocks = get_blocks_as_dicts()
# Auto-fix
try:
fixer = AgentFixer()
agent_json = fixer.apply_all_fixes(agent_json, blocks, library_agents)
fixes = fixer.get_fixes_applied()
if fixes:
logger.info(f"Applied {len(fixes)} auto-fixes to agent JSON")
except Exception as e:
logger.warning(f"Auto-fix failed: {e}")
# Validate
try:
validator = AgentValidator()
is_valid, _ = validator.validate(agent_json, blocks, library_agents)
if not is_valid:
errors = validator.errors
return ErrorResponse(
message=(
f"The agent has {len(errors)} validation error(s):\n"
+ "\n".join(f"- {e}" for e in errors[:5])
),
error="validation_failed",
details={"errors": errors},
session_id=session_id,
)
except Exception as e:
logger.error(f"Validation failed with exception: {e}", exc_info=True)
return ErrorResponse(
message="Failed to validate the agent. Please try again.",
error="validation_exception",
details={"exception": str(e)},
session_id=session_id,
)
agent_name = agent_json.get("name", default_name)
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
# Build a warning suffix when name/description is missing or generic
_GENERIC_NAMES = {
"agent",
"generated agent",
"customized agent",
"updated agent",
"new agent",
"my agent",
}
metadata_warnings: list[str] = []
if not agent_json.get("name") or agent_name.lower().strip() in _GENERIC_NAMES:
metadata_warnings.append("'name'")
if not agent_description:
metadata_warnings.append("'description'")
metadata_hint = ""
if metadata_warnings:
missing = " and ".join(metadata_warnings)
metadata_hint = (
f" Note: the agent is missing a meaningful {missing}. "
f"Please update the agent_json to include them."
)
if not save:
return AgentPreviewResponse(
message=(
(
preview_message
or f"Agent '{agent_name}' with {node_count} blocks is ready."
)
+ metadata_hint
),
agent_json=agent_json,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
try:
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id, is_update=is_update, folder_id=folder_id
)
return AgentSavedResponse(
message=(
(save_message or f"Agent '{created_graph.name}' has been saved!")
+ metadata_hint
),
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
logger.error(f"Failed to save agent: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to save the agent: {str(e)}",
error="save_failed",
details={"exception": str(e)},
session_id=session_id,
)

View File

@@ -1,511 +0,0 @@
"""External Agent Generator service client.
This module provides a client for communicating with the external Agent Generator
microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
will delegate to the external service instead of using the built-in LLM-based implementation.
"""
import logging
from typing import Any
import httpx
from backend.util.settings import Settings
from .dummy import (
customize_template_dummy,
decompose_goal_dummy,
generate_agent_dummy,
generate_agent_patch_dummy,
get_blocks_dummy,
health_check_dummy,
)
logger = logging.getLogger(__name__)
_dummy_mode_warned = False
def _create_error_response(
error_message: str,
error_type: str = "unknown",
details: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Create a standardized error response dict.
Args:
error_message: Human-readable error message
error_type: Machine-readable error type
details: Optional additional error details
Returns:
Error dict with type="error" and error details
"""
response: dict[str, Any] = {
"type": "error",
"error": error_message,
"error_type": error_type,
}
if details:
response["details"] = details
return response
def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
"""Classify an HTTP error into error_type and message.
Args:
e: The HTTP status error
Returns:
Tuple of (error_type, error_message)
"""
status = e.response.status_code
if status == 429:
return "rate_limit", f"Agent Generator rate limited: {e}"
elif status == 503:
return "service_unavailable", f"Agent Generator unavailable: {e}"
elif status == 504 or status == 408:
return "timeout", f"Agent Generator timed out: {e}"
else:
return "http_error", f"HTTP error calling Agent Generator: {e}"
def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
"""Classify a request error into error_type and message.
Args:
e: The request error
Returns:
Tuple of (error_type, error_message)
"""
error_str = str(e).lower()
if "timeout" in error_str or "timed out" in error_str:
return "timeout", f"Agent Generator request timed out: {e}"
elif "connect" in error_str:
return "connection_error", f"Could not connect to Agent Generator: {e}"
else:
return "request_error", f"Request error calling Agent Generator: {e}"
_client: httpx.AsyncClient | None = None
_settings: Settings | None = None
def _get_settings() -> Settings:
"""Get or create settings singleton."""
global _settings
if _settings is None:
_settings = Settings()
return _settings
def _is_dummy_mode() -> bool:
"""Check if dummy mode is enabled for testing."""
global _dummy_mode_warned
settings = _get_settings()
is_dummy = bool(settings.config.agentgenerator_use_dummy)
if is_dummy and not _dummy_mode_warned:
logger.warning(
"Agent Generator running in DUMMY MODE - returning mock responses. "
"Do not use in production!"
)
_dummy_mode_warned = True
return is_dummy
def is_external_service_configured() -> bool:
"""Check if external Agent Generator service is configured (or dummy mode)."""
settings = _get_settings()
return bool(settings.config.agentgenerator_host) or bool(
settings.config.agentgenerator_use_dummy
)
def _get_base_url() -> str:
"""Get the base URL for the external service."""
settings = _get_settings()
host = settings.config.agentgenerator_host
port = settings.config.agentgenerator_port
return f"http://{host}:{port}"
def _get_client() -> httpx.AsyncClient:
"""Get or create the HTTP client for the external service."""
global _client
if _client is None:
settings = _get_settings()
_client = httpx.AsyncClient(
base_url=_get_base_url(),
timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
)
return _client
async def decompose_goal_external(
description: str,
context: str = "",
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Call the external service to decompose a goal.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
library_agents: User's library agents available for sub-agent composition
Returns:
Dict with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
- {"type": "unachievable_goal", ...}
- {"type": "vague_goal", ...}
- {"type": "error", "error": "...", "error_type": "..."} on error
Or None on unexpected error
"""
if _is_dummy_mode():
return await decompose_goal_dummy(description, context, library_agents)
client = _get_client()
if context:
description = f"{description}\n\nAdditional context from user:\n{context}"
payload: dict[str, Any] = {"description": description}
if library_agents:
payload["library_agents"] = library_agents
try:
response = await client.post("/api/decompose-description", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator decomposition failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Map the response to the expected format
response_type = data.get("type")
if response_type == "instructions":
return {"type": "instructions", "steps": data.get("steps", [])}
elif response_type == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
elif response_type == "unachievable_goal":
return {
"type": "unachievable_goal",
"reason": data.get("reason"),
"suggested_goal": data.get("suggested_goal"),
}
elif response_type == "vague_goal":
return {
"type": "vague_goal",
"suggested_goal": data.get("suggested_goal"),
}
elif response_type == "error":
# Pass through error from the service
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
else:
logger.error(
f"Unknown response type from external service: {response_type}"
)
return _create_error_response(
f"Unknown response type from Agent Generator: {response_type}",
"invalid_response",
)
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def generate_agent_external(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
Returns:
Agent JSON dict or error dict {"type": "error", ...} on error
"""
if _is_dummy_mode():
return await generate_agent_dummy(instructions, library_agents)
client = _get_client()
# Build request payload
payload: dict[str, Any] = {"instructions": instructions}
if library_agents:
payload["library_agents"] = library_agents
try:
response = await client.post("/api/generate-agent", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator generation failed: {error_msg} (type: {error_type})"
)
return _create_error_response(error_msg, error_type)
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def generate_agent_patch_external(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
session_id: Session ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
"""
if _is_dummy_mode():
return await generate_agent_patch_dummy(
update_request, current_agent, library_agents
)
client = _get_client()
# Build request payload
payload: dict[str, Any] = {
"update_request": update_request,
"current_agent_json": current_agent,
}
if library_agents:
payload["library_agents"] = library_agents
try:
response = await client.post("/api/update-agent", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator patch generation failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
# Check if it's an error passed through
if data.get("type") == "error":
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
# Otherwise return the updated agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def customize_template_external(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any] | None:
"""Call the external service to customize a template/marketplace agent.
Args:
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
operation_id: Operation ID for async processing (enables Redis Streams callback)
session_id: Session ID for async processing (enables Redis Streams callback)
Returns:
Customized agent JSON, clarifying questions dict, or error dict on error
"""
if _is_dummy_mode():
return await customize_template_dummy(
template_agent, modification_request, context
)
client = _get_client()
request = modification_request
if context:
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
payload: dict[str, Any] = {
"template_agent_json": template_agent,
"modification_request": request,
}
try:
response = await client.post("/api/template-modification", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator template customization failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
# Check if it's an error passed through
if data.get("type") == "error":
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
# Otherwise return the customized agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def get_blocks_external() -> list[dict[str, Any]] | None:
"""Get available blocks from the external service.
Returns:
List of block info dicts or None on error
"""
if _is_dummy_mode():
return await get_blocks_dummy()
client = _get_client()
try:
response = await client.get("/api/blocks")
response.raise_for_status()
data = response.json()
if not data.get("success"):
logger.error("External service returned error getting blocks")
return None
return data.get("blocks", [])
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error getting blocks from external service: {e}")
return None
except httpx.RequestError as e:
logger.error(f"Request error getting blocks from external service: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error getting blocks from external service: {e}")
return None
async def health_check() -> bool:
"""Check if the external service is healthy.
Returns:
True if healthy, False otherwise
"""
if not is_external_service_configured():
return False
if _is_dummy_mode():
return await health_check_dummy()
client = _get_client()
try:
response = await client.get("/health")
response.raise_for_status()
data = response.json()
return data.get("status") == "healthy" and data.get("blocks_loaded", False)
except Exception as e:
logger.warning(f"External agent generator health check failed: {e}")
return False
async def close_client() -> None:
"""Close the HTTP client."""
global _client
if _client is not None:
await _client.aclose()
_client = None

View File

@@ -0,0 +1,17 @@
"""Agent generation validation — re-exports from split modules.
This module was split into:
- helpers.py: get_blocks_as_dicts, block cache
- fixer.py: AgentFixer class
- validator.py: AgentValidator class
"""
from .fixer import AgentFixer
from .helpers import get_blocks_as_dicts
from .validator import AgentValidator
__all__ = [
"AgentFixer",
"AgentValidator",
"get_blocks_as_dicts",
]

View File

@@ -0,0 +1,939 @@
"""AgentValidator — validates agent JSON graphs for correctness."""
import json
import logging
import re
from typing import Any
from .helpers import (
AGENT_EXECUTOR_BLOCK_ID,
AGENT_INPUT_BLOCK_ID,
AGENT_OUTPUT_BLOCK_ID,
MCP_TOOL_BLOCK_ID,
AgentDict,
are_types_compatible,
get_defined_property_type,
)
logger = logging.getLogger(__name__)
class AgentValidator:
"""
A comprehensive validator for AutoGPT agents that provides detailed error
reporting for LLM-based fixes.
"""
def __init__(self):
self.errors: list[str] = []
def add_error(self, error_message: str) -> None:
"""Add an error message to the validation errors list."""
self.errors.append(error_message)
def _values_equal(self, val1: Any, val2: Any) -> bool:
"""Compare two values, handling complex types like dicts and lists."""
if type(val1) is not type(val2):
return False
if isinstance(val1, dict):
return json.dumps(val1, sort_keys=True) == json.dumps(val2, sort_keys=True)
if isinstance(val1, list):
return json.dumps(val1, sort_keys=True) == json.dumps(val2, sort_keys=True)
return val1 == val2
def validate_block_existence(
self, agent: AgentDict, blocks: list[dict[str, Any]]
) -> bool:
"""
Validate that all block IDs used in the agent actually exist in the
blocks list. Returns True if all block IDs exist, False otherwise.
"""
valid = True
# Create a set of all valid block IDs for fast lookup
valid_block_ids = {block.get("id") for block in blocks if block.get("id")}
# Check each node's block_id
for node in agent.get("nodes", []):
block_id = node.get("block_id")
node_id = node.get("id")
if not block_id:
self.add_error(
f"Node '{node_id}' is missing a 'block_id' field. "
f"Every node must reference a valid block."
)
valid = False
continue
if block_id not in valid_block_ids:
self.add_error(
f"Node '{node_id}' references block_id '{block_id}' "
f"which does not exist in the available blocks. "
f"This block may have been deprecated, removed, or "
f"the ID is incorrect. Please use a valid block from "
f"the blocks library."
)
valid = False
return valid
def validate_link_node_references(self, agent: AgentDict) -> bool:
"""
Validate that all node IDs referenced in links actually exist in the
agent's nodes. Returns True if all link references are valid, False
otherwise.
"""
valid = True
# Create a set of all valid node IDs for fast lookup
valid_node_ids = {
node.get("id") for node in agent.get("nodes", []) if node.get("id")
}
# Check each link's source_id and sink_id
for link in agent.get("links", []):
link_id = link.get("id", "Unknown")
source_id = link.get("source_id")
sink_id = link.get("sink_id")
source_name = link.get("source_name", "")
sink_name = link.get("sink_name", "")
# Check source_id
if not source_id:
self.add_error(
f"Link '{link_id}' is missing a 'source_id' field. "
f"Every link must reference a valid source node."
)
valid = False
elif source_id not in valid_node_ids:
self.add_error(
f"Link '{link_id}' references source_id '{source_id}' "
f"which does not exist in the agent's nodes. The link "
f"from '{source_name}' cannot be established because "
f"the source node is missing."
)
valid = False
# Check sink_id
if not sink_id:
self.add_error(
f"Link '{link_id}' is missing a 'sink_id' field. "
f"Every link must reference a valid sink (destination) "
f"node."
)
valid = False
elif sink_id not in valid_node_ids:
self.add_error(
f"Link '{link_id}' references sink_id '{sink_id}' "
f"which does not exist in the agent's nodes. The link "
f"to '{sink_name}' cannot be established because the "
f"destination node is missing."
)
valid = False
return valid
def validate_required_inputs(
self, agent: AgentDict, blocks: list[dict[str, Any]]
) -> bool:
"""
Validate that all required inputs are provided for each node.
Returns True if all required inputs are satisfied, False otherwise.
"""
valid = True
block_lookup = {b.get("id", ""): b for b in blocks}
for node in agent.get("nodes", []):
block_id = node.get("block_id")
block = block_lookup.get(block_id)
if not block:
continue
required_inputs = block.get("inputSchema", {}).get("required", [])
input_defaults = node.get("input_default", {})
node_id = node.get("id")
linked_inputs = set(
link.get("sink_name")
for link in agent.get("links", [])
if link.get("sink_id") == node_id and link.get("sink_name")
)
for req_input in required_inputs:
if (
req_input not in input_defaults
and req_input not in linked_inputs
and req_input != "credentials"
):
block_name = block.get("name", "Unknown Block")
self.add_error(
f"Node '{node_id}' (block '{block_name}' - "
f"{block_id}) is missing required input "
f"'{req_input}'. This input must be either "
f"provided as a default value in the node's "
f"'input_default' field or connected via a link "
f"from another node's output."
)
valid = False
return valid
def validate_data_type_compatibility(
self, agent: AgentDict, blocks: list[dict[str, Any]]
) -> bool:
"""
Validate that linked data types are compatible between source and sink.
Returns True if all data types are compatible, False otherwise.
"""
valid = True
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
block_lookup = {block.get("id", ""): block for block in blocks}
for link in agent.get("links", []):
source_id = link.get("source_id")
sink_id = link.get("sink_id")
source_name = link.get("source_name")
sink_name = link.get("sink_name")
if not all(
isinstance(v, str) and v
for v in (source_id, sink_id, source_name, sink_name)
):
self.add_error(
f"Link '{link.get('id', 'Unknown')}' is missing required "
f"fields (source_id/sink_id/source_name/sink_name)."
)
valid = False
continue
source_node = node_lookup.get(source_id, "")
sink_node = node_lookup.get(sink_id, "")
if not source_node or not sink_node:
continue
source_block = block_lookup.get(source_node.get("block_id", ""))
sink_block = block_lookup.get(sink_node.get("block_id", ""))
if not source_block or not sink_block:
continue
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
source_type = get_defined_property_type(source_outputs, source_name)
sink_type = get_defined_property_type(sink_inputs, sink_name)
if (
source_type
and sink_type
and not are_types_compatible(source_type, sink_type)
):
source_block_name = source_block.get("name", "Unknown Block")
sink_block_name = sink_block.get("name", "Unknown Block")
self.add_error(
f"Data type mismatch in link '{link.get('id')}': "
f"Source '{source_block_name}' output "
f"'{link.get('source_name', '')}' outputs '{source_type}' "
f"type, but sink '{sink_block_name}' input "
f"'{link.get('sink_name', '')}' expects '{sink_type}' type. "
f"These types must match for the connection to work "
f"properly."
)
valid = False
return valid
def validate_nested_sink_links(
self, agent: AgentDict, blocks: list[dict[str, Any]]
) -> bool:
"""
Validate nested sink links (links with _#_ notation).
Returns True if all nested links are valid, False otherwise.
"""
valid = True
block_input_schemas = {
block.get("id", ""): block.get("inputSchema", {}).get("properties", {})
for block in blocks
}
block_names = {
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
}
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
for link in agent.get("links", []):
sink_name = link.get("sink_name", "")
sink_id = link.get("sink_id")
if not sink_name or not sink_id:
continue
if "_#_" in sink_name:
parent, child = sink_name.split("_#_", 1)
sink_node = node_lookup.get(sink_id)
if not sink_node:
continue
block_id = sink_node.get("block_id")
input_props = block_input_schemas.get(block_id, {})
parent_schema = input_props.get(parent)
if not parent_schema:
block_name = block_names.get(block_id, "Unknown Block")
self.add_error(
f"Invalid nested sink link '{sink_name}' for "
f"node '{sink_id}' (block "
f"'{block_name}' - {block_id}): Parent property "
f"'{parent}' does not exist in the block's "
f"input schema."
)
valid = False
continue
# Check if additionalProperties is allowed either directly
# or via anyOf
allows_additional_properties = parent_schema.get(
"additionalProperties", False
)
# Check anyOf for additionalProperties
if not allows_additional_properties and "anyOf" in parent_schema:
any_of_schemas = parent_schema.get("anyOf", [])
if isinstance(any_of_schemas, list):
for schema_option in any_of_schemas:
if isinstance(schema_option, dict) and schema_option.get(
"additionalProperties"
):
allows_additional_properties = True
break
if not allows_additional_properties:
if not (
isinstance(parent_schema, dict)
and "properties" in parent_schema
and isinstance(parent_schema["properties"], dict)
and child in parent_schema["properties"]
):
block_name = block_names.get(block_id, "Unknown Block")
self.add_error(
f"Invalid nested sink link '{sink_name}' "
f"for node '{link.get('sink_id', '')}' (block "
f"'{block_name}' - {block_id}): Child "
f"property '{child}' does not exist in "
f"parent '{parent}' schema. Available "
f"properties: "
f"{list(parent_schema.get('properties', {}).keys())}"
)
valid = False
return valid
def validate_prompt_double_curly_braces_spaces(self, agent: AgentDict) -> bool:
"""
Validate that prompt parameters do not contain spaces in double curly
braces.
Checks the 'prompt' parameter in input_default of each node and reports
errors if values within double curly braces ({{...}}) contain spaces.
For example, {{user name}} should be {{user_name}}.
Args:
agent: The agent dictionary to validate
Returns:
True if all prompts are valid (no spaces in double curly braces),
False otherwise
"""
valid = True
nodes = agent.get("nodes", [])
for node in nodes:
node_id = node.get("id")
input_default = node.get("input_default", {})
# Check if 'prompt' parameter exists
if "prompt" not in input_default:
continue
prompt_text = input_default["prompt"]
# Only process if it's a string
if not isinstance(prompt_text, str):
continue
# Find all double curly brace patterns with spaces
matches = re.finditer(r"\{\{([^}]+)\}\}", prompt_text)
for match in matches:
content = match.group(1)
if " " in content:
start_pos = match.start()
snippet_start = max(0, start_pos - 30)
snippet_end = min(len(prompt_text), match.end() + 30)
snippet = prompt_text[snippet_start:snippet_end]
self.add_error(
f"Node '{node_id}' has spaces in double curly "
f"braces in prompt parameter: "
f"'{{{{{content}}}}}' should be "
f"'{{{{{content.replace(' ', '_')}}}}}'. "
f"Context: ...{snippet}..."
)
valid = False
return valid
def validate_source_output_existence(
self, agent: AgentDict, blocks: list[dict[str, Any]]
) -> bool:
"""
Validate that all source_names in links exist in the corresponding
block's output schema.
Checks that for each link, the source_name field references a valid
output property in the source block's outputSchema. Also handles nested
outputs with _#_ notation.
Args:
agent: The agent dictionary to validate
blocks: List of available blocks with their schemas
Returns:
True if all source output fields exist, False otherwise
"""
valid = True
# Create lookup dictionaries for efficiency
block_output_schemas = {
block.get("id", ""): block.get("outputSchema", {}).get("properties", {})
for block in blocks
}
block_names = {
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
}
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
for link in agent.get("links", []):
source_id = link.get("source_id")
source_name = link.get("source_name", "")
link_id = link.get("id", "Unknown")
if not source_name:
self.add_error(
f"Link '{link_id}' is missing 'source_name'. "
f"Every link must specify which output field to read from."
)
valid = False
continue
source_node = node_lookup.get(source_id)
if not source_node:
# This error is already caught by
# validate_link_node_references
continue
block_id = source_node.get("block_id")
block_name = block_names.get(block_id, "Unknown Block")
# Special handling for AgentExecutorBlock - use dynamic
# output_schema from input_default
if block_id == AGENT_EXECUTOR_BLOCK_ID:
input_default = source_node.get("input_default", {})
dynamic_output_schema = input_default.get("output_schema", {})
if not isinstance(dynamic_output_schema, dict):
dynamic_output_schema = {}
output_props = dynamic_output_schema.get("properties", {})
if not isinstance(output_props, dict):
output_props = {}
else:
output_props = block_output_schemas.get(block_id, {})
# Handle nested source names (with _#_ notation)
if "_#_" in source_name:
parent, child = source_name.split("_#_", 1)
parent_schema = output_props.get(parent)
if not parent_schema:
self.add_error(
f"Invalid source output field '{source_name}' "
f"in link '{link_id}' from node '{source_id}' "
f"(block '{block_name}' - {block_id}): Parent "
f"property '{parent}' does not exist in the "
f"block's output schema."
)
valid = False
continue
# Check if additionalProperties is allowed either directly
# or via anyOf
allows_additional_properties = parent_schema.get(
"additionalProperties", False
)
if not allows_additional_properties and "anyOf" in parent_schema:
any_of_schemas = parent_schema.get("anyOf", [])
if isinstance(any_of_schemas, list):
for schema_option in any_of_schemas:
if isinstance(schema_option, dict) and schema_option.get(
"additionalProperties"
):
allows_additional_properties = True
break
# Also allow when items have
# additionalProperties (array of objects)
if (
isinstance(schema_option, dict)
and "items" in schema_option
):
items_schema = schema_option.get("items")
if isinstance(items_schema, dict) and items_schema.get(
"additionalProperties"
):
allows_additional_properties = True
break
# Only require child in properties when
# additionalProperties is not allowed
if not allows_additional_properties:
if not (
isinstance(parent_schema, dict)
and "properties" in parent_schema
and isinstance(parent_schema["properties"], dict)
and child in parent_schema["properties"]
):
available_props = (
list(parent_schema.get("properties", {}).keys())
if isinstance(parent_schema, dict)
else []
)
self.add_error(
f"Invalid nested source output field "
f"'{source_name}' in link '{link_id}' from "
f"node '{source_id}' (block "
f"'{block_name}' - {block_id}): Child "
f"property '{child}' does not exist in "
f"parent '{parent}' output schema. "
f"Available properties: {available_props}"
)
valid = False
else:
# Check simple (non-nested) source name
if source_name not in output_props:
available_outputs = list(output_props.keys())
self.add_error(
f"Invalid source output field '{source_name}' "
f"in link '{link_id}' from node '{source_id}' "
f"(block '{block_name}' - {block_id}): Output "
f"property '{source_name}' does not exist in "
f"the block's output schema. Available outputs: "
f"{available_outputs}"
)
valid = False
return valid
def validate_io_blocks(self, agent: AgentDict) -> bool:
"""
Validate that the agent has at least one AgentInputBlock and one
AgentOutputBlock. These blocks define the agent's interface.
Returns True if both are present, False otherwise.
"""
valid = True
block_ids = {node.get("block_id") for node in agent.get("nodes", [])}
if AGENT_INPUT_BLOCK_ID not in block_ids:
self.add_error(
f"Agent is missing an AgentInputBlock (block_id: "
f"'{AGENT_INPUT_BLOCK_ID}'). Every agent must have at "
f"least one AgentInputBlock to define user-facing inputs. "
f"Add a node with block_id '{AGENT_INPUT_BLOCK_ID}' and "
f"set input_default with 'name' and optionally 'title'."
)
valid = False
if AGENT_OUTPUT_BLOCK_ID not in block_ids:
self.add_error(
f"Agent is missing an AgentOutputBlock (block_id: "
f"'{AGENT_OUTPUT_BLOCK_ID}'). Every agent must have at "
f"least one AgentOutputBlock to define user-facing outputs. "
f"Add a node with block_id '{AGENT_OUTPUT_BLOCK_ID}' and "
f"set input_default with 'name', then link 'value' from "
f"another block's output."
)
valid = False
return valid
def validate_agent_executor_blocks(
self,
agent: AgentDict,
library_agents: list[dict[str, Any]] | None = None,
) -> bool:
"""
Validate AgentExecutorBlock nodes have required fields and valid
references.
Checks that AgentExecutorBlock nodes:
1. Have a valid graph_id in input_default (required)
2. If graph_id matches a known library agent, validates version
consistency
3. Sub-agent required inputs are connected via links (not hardcoded)
Note: Unknown graph_ids are not treated as errors - they could be valid
direct references to agents by their actual ID (not via library_agents).
This is consistent with fix_agent_executor_blocks() behavior.
Args:
agent: The agent dictionary to validate
library_agents: List of available library agents (for version
validation)
Returns:
True if all AgentExecutorBlock nodes are valid, False otherwise
"""
valid = True
nodes = agent.get("nodes", [])
links = agent.get("links", [])
# Create lookup for library agents
library_agent_lookup: dict[str, dict[str, Any]] = {}
if library_agents:
library_agent_lookup = {la.get("graph_id", ""): la for la in library_agents}
for node in nodes:
if node.get("block_id") != AGENT_EXECUTOR_BLOCK_ID:
continue
node_id = node.get("id")
input_default = node.get("input_default", {})
# Check for required graph_id
graph_id = input_default.get("graph_id")
if not graph_id:
self.add_error(
f"AgentExecutorBlock node '{node_id}' is missing "
f"required 'graph_id' in input_default. This field "
f"must reference the ID of the sub-agent to execute."
)
valid = False
continue
# If graph_id is not in library_agent_lookup, skip validation
if graph_id not in library_agent_lookup:
continue
# Validate version consistency for known library agents
library_agent = library_agent_lookup[graph_id]
expected_version = library_agent.get("graph_version")
current_version = input_default.get("graph_version")
if (
current_version
and expected_version
and current_version != expected_version
):
self.add_error(
f"AgentExecutorBlock node '{node_id}' has mismatched "
f"graph_version: got {current_version}, expected "
f"{expected_version} for library agent "
f"'{library_agent.get('name')}'"
)
valid = False
# Validate sub-agent inputs are properly linked (not hardcoded)
sub_agent_input_schema = library_agent.get("input_schema", {})
if not isinstance(sub_agent_input_schema, dict):
sub_agent_input_schema = {}
sub_agent_required_inputs = sub_agent_input_schema.get("required", [])
sub_agent_properties = sub_agent_input_schema.get("properties", {})
# Get all linked inputs to this node
linked_sub_agent_inputs: set[str] = set()
for link in links:
if link.get("sink_id") == node_id:
sink_name = link.get("sink_name", "")
if sink_name in sub_agent_properties:
linked_sub_agent_inputs.add(sink_name)
# Check for hardcoded inputs that should be linked
hardcoded_inputs = input_default.get("inputs", {})
input_schema = input_default.get("input_schema", {})
schema_properties = (
input_schema.get("properties", {})
if isinstance(input_schema, dict)
else {}
)
if isinstance(hardcoded_inputs, dict) and hardcoded_inputs:
for input_name, value in hardcoded_inputs.items():
if input_name not in sub_agent_properties:
continue
if value is None:
continue
# Skip if this input is already linked
if input_name in linked_sub_agent_inputs:
continue
prop_schema = schema_properties.get(input_name, {})
schema_default = (
prop_schema.get("default")
if isinstance(prop_schema, dict)
else None
)
if schema_default is not None and self._values_equal(
value, schema_default
):
continue
# This is a non-default hardcoded value without a link
self.add_error(
f"AgentExecutorBlock node '{node_id}' has "
f"hardcoded input '{input_name}'. Sub-agent "
f"inputs should be connected via links using "
f"'{input_name}' as sink_name, not hardcoded "
f"in input_default.inputs. Create a link from "
f"the appropriate source node."
)
valid = False
# Check for missing required sub-agent inputs.
# An input is satisfied if it is linked OR has an allowed
# hardcoded value (i.e. equals the schema default — the
# previous check already flags non-default hardcoded values).
hardcoded_inputs_dict = (
hardcoded_inputs if isinstance(hardcoded_inputs, dict) else {}
)
for req_input in sub_agent_required_inputs:
if req_input in linked_sub_agent_inputs:
continue
# Check if fixer populated it with a schema default value
if req_input in hardcoded_inputs_dict:
prop_schema = schema_properties.get(req_input, {})
schema_default = (
prop_schema.get("default")
if isinstance(prop_schema, dict)
else None
)
if schema_default is not None and self._values_equal(
hardcoded_inputs_dict[req_input], schema_default
):
continue
self.add_error(
f"AgentExecutorBlock node '{node_id}' is "
f"missing required sub-agent input "
f"'{req_input}'. Create a link to this node "
f"using sink_name '{req_input}' to connect "
f"the input."
)
valid = False
return valid
def validate_agent_executor_block_schemas(
self,
agent: AgentDict,
) -> bool:
"""
Validate that AgentExecutorBlock nodes have valid input_schema and
output_schema.
This validation runs regardless of library_agents availability and
ensures that the schemas are properly populated to prevent frontend
crashes.
Args:
agent: The agent dictionary to validate
Returns:
True if all AgentExecutorBlock nodes have valid schemas, False
otherwise
"""
valid = True
nodes = agent.get("nodes", [])
for node in nodes:
if node.get("block_id") != AGENT_EXECUTOR_BLOCK_ID:
continue
node_id = node.get("id")
input_default = node.get("input_default", {})
customized_name = (node.get("metadata") or {}).get(
"customized_name", "Unknown"
)
# Check input_schema
input_schema = input_default.get("input_schema")
if input_schema is None or not isinstance(input_schema, dict):
self.add_error(
f"AgentExecutorBlock node '{node_id}' "
f"({customized_name}) has missing or invalid "
f"input_schema. The input_schema must be a valid "
f"JSON Schema object with 'properties' and "
f"'required' fields."
)
valid = False
elif not input_schema.get("properties") and not input_schema.get("type"):
# Empty schema like {} is invalid
self.add_error(
f"AgentExecutorBlock node '{node_id}' "
f"({customized_name}) has empty input_schema. The "
f"input_schema must define the sub-agent's expected "
f"inputs. This usually indicates the sub-agent "
f"reference is incomplete or the library agent was "
f"not properly passed."
)
valid = False
# Check output_schema
output_schema = input_default.get("output_schema")
if output_schema is None or not isinstance(output_schema, dict):
self.add_error(
f"AgentExecutorBlock node '{node_id}' "
f"({customized_name}) has missing or invalid "
f"output_schema. The output_schema must be a valid "
f"JSON Schema object defining the sub-agent's "
f"outputs."
)
valid = False
elif not output_schema.get("properties") and not output_schema.get("type"):
# Empty schema like {} is invalid
self.add_error(
f"AgentExecutorBlock node '{node_id}' "
f"({customized_name}) has empty output_schema. "
f"The output_schema must define the sub-agent's "
f"expected outputs. This usually indicates the "
f"sub-agent reference is incomplete or the library "
f"agent was not properly passed."
)
valid = False
return valid
def validate_mcp_tool_blocks(self, agent: AgentDict) -> bool:
"""Validate that MCPToolBlock nodes have required fields.
Checks that each MCPToolBlock node has:
1. A non-empty `server_url` in input_default
2. A non-empty `selected_tool` in input_default
Returns True if all MCPToolBlock nodes are valid, False otherwise.
"""
valid = True
nodes = agent.get("nodes", [])
for node in nodes:
if node.get("block_id") != MCP_TOOL_BLOCK_ID:
continue
node_id = node.get("id", "unknown")
input_default = node.get("input_default", {})
customized_name = (node.get("metadata") or {}).get(
"customized_name", node_id
)
server_url = input_default.get("server_url")
if not server_url:
self.add_error(
f"MCPToolBlock node '{customized_name}' ({node_id}) is "
f"missing required 'server_url' in input_default. "
f"Set this to the MCP server URL "
f"(e.g. 'https://mcp.example.com/sse')."
)
valid = False
selected_tool = input_default.get("selected_tool")
if not selected_tool:
self.add_error(
f"MCPToolBlock node '{customized_name}' ({node_id}) is "
f"missing required 'selected_tool' in input_default. "
f"Set this to the name of the MCP tool to execute."
)
valid = False
return valid
def validate(
self,
agent: AgentDict,
blocks: list[dict[str, Any]],
library_agents: list[dict[str, Any]] | None = None,
) -> tuple[bool, str | None]:
"""
Comprehensive validation of an agent against available blocks.
Returns:
Tuple[bool, Optional[str]]: (is_valid, error_message)
- is_valid: True if agent passes all validations, False otherwise
- error_message: Detailed error message if validation fails, None
if successful
"""
logger.info("Validating agent...")
self.errors = []
checks = [
(
"Block existence",
self.validate_block_existence(agent, blocks),
),
(
"Link node references",
self.validate_link_node_references(agent),
),
(
"Required inputs",
self.validate_required_inputs(agent, blocks),
),
(
"Data type compatibility",
self.validate_data_type_compatibility(agent, blocks),
),
(
"Nested sink links",
self.validate_nested_sink_links(agent, blocks),
),
(
"Source output existence",
self.validate_source_output_existence(agent, blocks),
),
(
"Prompt double curly braces spaces",
self.validate_prompt_double_curly_braces_spaces(agent),
),
(
"IO blocks",
self.validate_io_blocks(agent),
),
# Always validate AgentExecutorBlock schemas to prevent
# frontend crashes
(
"AgentExecutorBlock schemas",
self.validate_agent_executor_block_schemas(agent),
),
(
"MCP tool blocks",
self.validate_mcp_tool_blocks(agent),
),
]
# Add AgentExecutorBlock detailed validation if library_agents
# provided
if library_agents:
checks.append(
(
"AgentExecutorBlock references",
self.validate_agent_executor_blocks(agent, library_agents),
)
)
all_passed = all(check[1] for check in checks)
if all_passed:
logger.info("Agent validation successful.")
return True, None
else:
error_message = "Agent validation failed with the following errors:\n\n"
for i, error in enumerate(self.errors, 1):
error_message += f"{i}. {error}\n"
logger.error(f"Agent validation failed: {error_message}")
return False, error_message

View File

@@ -0,0 +1,710 @@
"""Unit tests for AgentValidator."""
from .helpers import (
AGENT_EXECUTOR_BLOCK_ID,
AGENT_INPUT_BLOCK_ID,
AGENT_OUTPUT_BLOCK_ID,
MCP_TOOL_BLOCK_ID,
generate_uuid,
)
from .validator import AgentValidator
def _make_agent(
nodes: list | None = None,
links: list | None = None,
agent_id: str | None = None,
) -> dict:
"""Build a minimal agent dict for testing."""
return {
"id": agent_id or generate_uuid(),
"name": "Test Agent",
"nodes": nodes or [],
"links": links or [],
}
def _make_node(
node_id: str | None = None,
block_id: str = "block-1",
input_default: dict | None = None,
position: tuple[int, int] = (0, 0),
) -> dict:
return {
"id": node_id or generate_uuid(),
"block_id": block_id,
"input_default": input_default or {},
"metadata": {"position": {"x": position[0], "y": position[1]}},
}
def _make_link(
link_id: str | None = None,
source_id: str = "",
source_name: str = "output",
sink_id: str = "",
sink_name: str = "input",
) -> dict:
return {
"id": link_id or generate_uuid(),
"source_id": source_id,
"source_name": source_name,
"sink_id": sink_id,
"sink_name": sink_name,
}
def _make_block(
block_id: str = "block-1",
name: str = "TestBlock",
input_schema: dict | None = None,
output_schema: dict | None = None,
categories: list | None = None,
static_output: bool = False,
) -> dict:
return {
"id": block_id,
"name": name,
"inputSchema": input_schema or {"properties": {}, "required": []},
"outputSchema": output_schema or {"properties": {}},
"categories": categories or [],
"staticOutput": static_output,
}
# ============================================================================
# validate_block_existence
# ============================================================================
class TestValidateBlockExistence:
def test_valid_blocks_pass(self):
v = AgentValidator()
node = _make_node(block_id="b1")
block = _make_block(block_id="b1")
agent = _make_agent(nodes=[node])
assert v.validate_block_existence(agent, [block]) is True
assert v.errors == []
def test_missing_block_fails(self):
v = AgentValidator()
node = _make_node(block_id="nonexistent")
agent = _make_agent(nodes=[node])
assert v.validate_block_existence(agent, []) is False
assert len(v.errors) == 1
assert "does not exist" in v.errors[0]
def test_missing_block_id_field(self):
v = AgentValidator()
node = {"id": "n1", "input_default": {}, "metadata": {}}
agent = _make_agent(nodes=[node])
assert v.validate_block_existence(agent, []) is False
assert "missing a 'block_id'" in v.errors[0]
# ============================================================================
# validate_link_node_references
# ============================================================================
class TestValidateLinkNodeReferences:
def test_valid_references_pass(self):
v = AgentValidator()
n1 = _make_node(node_id="n1")
n2 = _make_node(node_id="n2")
link = _make_link(source_id="n1", sink_id="n2")
agent = _make_agent(nodes=[n1, n2], links=[link])
assert v.validate_link_node_references(agent) is True
assert v.errors == []
def test_invalid_source_fails(self):
v = AgentValidator()
n1 = _make_node(node_id="n1")
link = _make_link(source_id="missing", sink_id="n1")
agent = _make_agent(nodes=[n1], links=[link])
assert v.validate_link_node_references(agent) is False
assert any("source_id" in e for e in v.errors)
def test_invalid_sink_fails(self):
v = AgentValidator()
n1 = _make_node(node_id="n1")
link = _make_link(source_id="n1", sink_id="missing")
agent = _make_agent(nodes=[n1], links=[link])
assert v.validate_link_node_references(agent) is False
assert any("sink_id" in e for e in v.errors)
# ============================================================================
# validate_required_inputs
# ============================================================================
class TestValidateRequiredInputs:
def test_satisfied_by_default_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {"url": {"type": "string"}},
"required": ["url"],
},
)
node = _make_node(block_id="b1", input_default={"url": "http://example.com"})
agent = _make_agent(nodes=[node])
assert v.validate_required_inputs(agent, [block]) is True
assert v.errors == []
def test_satisfied_by_link_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {"url": {"type": "string"}},
"required": ["url"],
},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(source_id="n2", sink_id="n1", sink_name="url")
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_required_inputs(agent, [block]) is True
def test_missing_required_input_fails(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {"url": {"type": "string"}},
"required": ["url"],
},
)
node = _make_node(block_id="b1", input_default={})
agent = _make_agent(nodes=[node])
assert v.validate_required_inputs(agent, [block]) is False
assert any("missing required input" in e for e in v.errors)
def test_credentials_always_allowed_missing(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {"credentials": {"type": "object"}},
"required": ["credentials"],
},
)
node = _make_node(block_id="b1", input_default={})
agent = _make_agent(nodes=[node])
assert v.validate_required_inputs(agent, [block]) is True
# ============================================================================
# validate_data_type_compatibility
# ============================================================================
class TestValidateDataTypeCompatibility:
def test_matching_types_pass(self):
v = AgentValidator()
src_block = _make_block(
block_id="src-b",
output_schema={"properties": {"out": {"type": "string"}}},
)
sink_block = _make_block(
block_id="sink-b",
input_schema={"properties": {"inp": {"type": "string"}}, "required": []},
)
src_node = _make_node(node_id="n1", block_id="src-b")
sink_node = _make_node(node_id="n2", block_id="sink-b")
link = _make_link(
source_id="n1", source_name="out", sink_id="n2", sink_name="inp"
)
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
assert (
v.validate_data_type_compatibility(agent, [src_block, sink_block]) is True
)
def test_int_number_compatible(self):
v = AgentValidator()
src_block = _make_block(
block_id="src-b",
output_schema={"properties": {"out": {"type": "integer"}}},
)
sink_block = _make_block(
block_id="sink-b",
input_schema={"properties": {"inp": {"type": "number"}}, "required": []},
)
src_node = _make_node(node_id="n1", block_id="src-b")
sink_node = _make_node(node_id="n2", block_id="sink-b")
link = _make_link(
source_id="n1", source_name="out", sink_id="n2", sink_name="inp"
)
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
assert (
v.validate_data_type_compatibility(agent, [src_block, sink_block]) is True
)
def test_mismatched_types_fail(self):
v = AgentValidator()
src_block = _make_block(
block_id="src-b",
output_schema={"properties": {"out": {"type": "string"}}},
)
sink_block = _make_block(
block_id="sink-b",
input_schema={"properties": {"inp": {"type": "integer"}}, "required": []},
)
src_node = _make_node(node_id="n1", block_id="src-b")
sink_node = _make_node(node_id="n2", block_id="sink-b")
link = _make_link(
source_id="n1", source_name="out", sink_id="n2", sink_name="inp"
)
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
assert (
v.validate_data_type_compatibility(agent, [src_block, sink_block]) is False
)
assert any("mismatch" in e.lower() for e in v.errors)
# ============================================================================
# validate_source_output_existence
# ============================================================================
class TestValidateSourceOutputExistence:
def test_valid_source_output_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
output_schema={"properties": {"result": {"type": "string"}}},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(source_id="n1", source_name="result", sink_id="n2")
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_source_output_existence(agent, [block]) is True
def test_invalid_source_output_fails(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
output_schema={"properties": {"result": {"type": "string"}}},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(source_id="n1", source_name="nonexistent", sink_id="n2")
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_source_output_existence(agent, [block]) is False
assert any("does not exist" in e for e in v.errors)
# ============================================================================
# validate_prompt_double_curly_braces_spaces
# ============================================================================
class TestValidatePromptDoubleCurlyBracesSpaces:
def test_no_spaces_passes(self):
v = AgentValidator()
node = _make_node(input_default={"prompt": "Hello {{name}}!"})
agent = _make_agent(nodes=[node])
assert v.validate_prompt_double_curly_braces_spaces(agent) is True
def test_spaces_in_braces_fails(self):
v = AgentValidator()
node = _make_node(input_default={"prompt": "Hello {{user name}}!"})
agent = _make_agent(nodes=[node])
assert v.validate_prompt_double_curly_braces_spaces(agent) is False
assert any("spaces" in e for e in v.errors)
# ============================================================================
# validate_nested_sink_links
# ============================================================================
class TestValidateNestedSinkLinks:
def test_valid_nested_link_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {
"config": {
"type": "object",
"properties": {"key": {"type": "string"}},
}
},
"required": [],
},
)
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(sink_id="n1", sink_name="config_#_key", source_id="n2")
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_nested_sink_links(agent, [block]) is True
def test_invalid_parent_fails(self):
v = AgentValidator()
block = _make_block(block_id="b1")
node = _make_node(node_id="n1", block_id="b1")
link = _make_link(sink_id="n1", sink_name="nonexistent_#_key", source_id="n2")
agent = _make_agent(nodes=[node], links=[link])
assert v.validate_nested_sink_links(agent, [block]) is False
assert any("does not exist" in e for e in v.errors)
# ============================================================================
# validate_agent_executor_block_schemas
# ============================================================================
class TestValidateAgentExecutorBlockSchemas:
def test_valid_schemas_pass(self):
v = AgentValidator()
node = _make_node(
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={
"graph_id": generate_uuid(),
"input_schema": {"properties": {"q": {"type": "string"}}},
"output_schema": {"properties": {"result": {"type": "string"}}},
},
)
agent = _make_agent(nodes=[node])
assert v.validate_agent_executor_block_schemas(agent) is True
assert v.errors == []
def test_empty_input_schema_fails(self):
v = AgentValidator()
node = _make_node(
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={
"graph_id": generate_uuid(),
"input_schema": {},
"output_schema": {"properties": {"result": {"type": "string"}}},
},
)
agent = _make_agent(nodes=[node])
assert v.validate_agent_executor_block_schemas(agent) is False
assert any("empty input_schema" in e for e in v.errors)
def test_missing_output_schema_fails(self):
v = AgentValidator()
node = _make_node(
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={
"graph_id": generate_uuid(),
"input_schema": {"properties": {"q": {"type": "string"}}},
},
)
agent = _make_agent(nodes=[node])
assert v.validate_agent_executor_block_schemas(agent) is False
assert any("output_schema" in e for e in v.errors)
# ============================================================================
# validate_agent_executor_blocks
# ============================================================================
class TestValidateAgentExecutorBlocks:
def test_missing_graph_id_fails(self):
v = AgentValidator()
node = _make_node(
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={},
)
agent = _make_agent(nodes=[node])
assert v.validate_agent_executor_blocks(agent) is False
assert any("graph_id" in e for e in v.errors)
def test_valid_graph_id_passes(self):
v = AgentValidator()
node = _make_node(
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={"graph_id": generate_uuid()},
)
agent = _make_agent(nodes=[node])
assert v.validate_agent_executor_blocks(agent) is True
def test_version_mismatch_with_library_agent(self):
v = AgentValidator()
lib_id = generate_uuid()
node = _make_node(
node_id="n1",
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={"graph_id": lib_id, "graph_version": 1},
)
agent = _make_agent(nodes=[node])
library_agents = [{"graph_id": lib_id, "graph_version": 3, "name": "Sub Agent"}]
assert v.validate_agent_executor_blocks(agent, library_agents) is False
assert any("mismatched graph_version" in e for e in v.errors)
def test_required_input_satisfied_by_schema_default_passes(self):
"""Required sub-agent inputs filled with their schema default by the fixer
should NOT be flagged as missing."""
v = AgentValidator()
lib_id = generate_uuid()
node = _make_node(
node_id="n1",
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={
"graph_id": lib_id,
"input_schema": {
"properties": {"mode": {"type": "string", "default": "fast"}}
},
"inputs": {"mode": "fast"}, # fixer populated with schema default
},
)
agent = _make_agent(nodes=[node])
library_agents = [
{
"graph_id": lib_id,
"graph_version": 1,
"name": "Sub",
"input_schema": {
"required": ["mode"],
"properties": {"mode": {"type": "string", "default": "fast"}},
},
"output_schema": {},
}
]
assert v.validate_agent_executor_blocks(agent, library_agents) is True
assert v.errors == []
def test_required_input_not_linked_and_no_default_fails(self):
"""Required sub-agent inputs without a link or schema default must fail."""
v = AgentValidator()
lib_id = generate_uuid()
node = _make_node(
node_id="n1",
block_id=AGENT_EXECUTOR_BLOCK_ID,
input_default={
"graph_id": lib_id,
"input_schema": {"properties": {"query": {"type": "string"}}},
"inputs": {},
},
)
agent = _make_agent(nodes=[node])
library_agents = [
{
"graph_id": lib_id,
"graph_version": 1,
"name": "Sub",
"input_schema": {
"required": ["query"],
"properties": {"query": {"type": "string"}},
},
"output_schema": {},
}
]
assert v.validate_agent_executor_blocks(agent, library_agents) is False
assert any("missing required sub-agent input" in e for e in v.errors)
# ============================================================================
# validate_io_blocks
# ============================================================================
class TestValidateIoBlocks:
def test_missing_input_block_reports_error(self):
v = AgentValidator()
# Agent has output block but no input block
node = _make_node(block_id=AGENT_OUTPUT_BLOCK_ID)
agent = _make_agent(nodes=[node])
assert v.validate_io_blocks(agent) is False
assert len(v.errors) == 1
assert "AgentInputBlock" in v.errors[0]
def test_missing_output_block_reports_error(self):
v = AgentValidator()
# Agent has input block but no output block
node = _make_node(block_id=AGENT_INPUT_BLOCK_ID)
agent = _make_agent(nodes=[node])
assert v.validate_io_blocks(agent) is False
assert len(v.errors) == 1
assert "AgentOutputBlock" in v.errors[0]
def test_missing_both_io_blocks_reports_two_errors(self):
v = AgentValidator()
node = _make_node(block_id="some-other-block")
agent = _make_agent(nodes=[node])
assert v.validate_io_blocks(agent) is False
assert len(v.errors) == 2
def test_both_io_blocks_present_no_error(self):
v = AgentValidator()
input_node = _make_node(block_id=AGENT_INPUT_BLOCK_ID)
output_node = _make_node(block_id=AGENT_OUTPUT_BLOCK_ID)
agent = _make_agent(nodes=[input_node, output_node])
assert v.validate_io_blocks(agent) is True
assert v.errors == []
def test_empty_agent_reports_both_missing(self):
v = AgentValidator()
agent = _make_agent(nodes=[])
assert v.validate_io_blocks(agent) is False
assert len(v.errors) == 2
# ============================================================================
# validate (integration)
# ============================================================================
class TestValidate:
def test_valid_agent_passes(self):
v = AgentValidator()
block = _make_block(
block_id="b1",
input_schema={
"properties": {"url": {"type": "string"}},
"required": ["url"],
},
output_schema={"properties": {"result": {"type": "string"}}},
)
input_block = _make_block(
block_id=AGENT_INPUT_BLOCK_ID,
name="AgentInputBlock",
output_schema={"properties": {"result": {}}},
)
output_block = _make_block(
block_id=AGENT_OUTPUT_BLOCK_ID,
name="AgentOutputBlock",
)
input_node = _make_node(
node_id="n-in",
block_id=AGENT_INPUT_BLOCK_ID,
input_default={"name": "url"},
)
n1 = _make_node(
node_id="n1", block_id="b1", input_default={"url": "http://example.com"}
)
n2 = _make_node(
node_id="n2", block_id="b1", input_default={"url": "http://example2.com"}
)
output_node = _make_node(
node_id="n-out",
block_id=AGENT_OUTPUT_BLOCK_ID,
input_default={"name": "result"},
)
link = _make_link(
source_id="n1", source_name="result", sink_id="n2", sink_name="url"
)
agent = _make_agent(nodes=[input_node, n1, n2, output_node], links=[link])
is_valid, error_message = v.validate(agent, [block, input_block, output_block])
assert is_valid is True
assert error_message is None
def test_invalid_agent_returns_errors(self):
v = AgentValidator()
node = _make_node(block_id="nonexistent")
agent = _make_agent(nodes=[node])
is_valid, error_message = v.validate(agent, [])
assert is_valid is False
assert error_message is not None
assert "does not exist" in error_message
def test_empty_agent_fails_io_validation(self):
v = AgentValidator()
agent = _make_agent()
is_valid, error_message = v.validate(agent, [])
assert is_valid is False
assert error_message is not None
assert "AgentInputBlock" in error_message
assert "AgentOutputBlock" in error_message
class TestValidateMCPToolBlocks:
"""Tests for validate_mcp_tool_blocks."""
def test_missing_server_url_reports_error(self):
v = AgentValidator()
node = _make_node(
block_id=MCP_TOOL_BLOCK_ID,
input_default={"selected_tool": "my_tool"},
)
agent = _make_agent(nodes=[node])
result = v.validate_mcp_tool_blocks(agent)
assert result is False
assert any("server_url" in e for e in v.errors)
def test_missing_selected_tool_reports_error(self):
v = AgentValidator()
node = _make_node(
block_id=MCP_TOOL_BLOCK_ID,
input_default={"server_url": "https://mcp.example.com/sse"},
)
agent = _make_agent(nodes=[node])
result = v.validate_mcp_tool_blocks(agent)
assert result is False
assert any("selected_tool" in e for e in v.errors)
def test_valid_mcp_block_passes(self):
v = AgentValidator()
node = _make_node(
block_id=MCP_TOOL_BLOCK_ID,
input_default={
"server_url": "https://mcp.example.com/sse",
"selected_tool": "search",
"tool_input_schema": {"properties": {"query": {"type": "string"}}},
"tool_arguments": {},
},
)
agent = _make_agent(nodes=[node])
result = v.validate_mcp_tool_blocks(agent)
assert result is True
assert len(v.errors) == 0
def test_both_missing_reports_two_errors(self):
v = AgentValidator()
node = _make_node(
block_id=MCP_TOOL_BLOCK_ID,
input_default={},
)
agent = _make_agent(nodes=[node])
v.validate_mcp_tool_blocks(agent)
assert len(v.errors) == 2

View File

@@ -208,6 +208,9 @@ def _library_agent_to_info(agent: LibraryAgent) -> AgentInfo:
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
graph_version=agent.graph_version,
input_schema=agent.input_schema,
output_schema=agent.output_schema,
)

View File

@@ -21,10 +21,10 @@ from typing import Any
from e2b import AsyncSandbox
from e2b.exceptions import TimeoutException
from backend.copilot.context import E2B_WORKDIR, get_current_sandbox
from backend.copilot.model import ChatSession
from .base import BaseTool
from .e2b_sandbox import E2B_WORKDIR
from .models import BashExecResponse, ErrorResponse, ToolResponseBase
from .sandbox import get_workspace_dir, has_full_sandbox, run_sandboxed
@@ -94,9 +94,6 @@ class BashExecTool(BaseTool):
session_id=session_id,
)
# E2B path: run on remote cloud sandbox when available.
from backend.copilot.sdk.tool_adapter import get_current_sandbox
sandbox = get_current_sandbox()
if sandbox is not None:
return await self._execute_on_e2b(sandbox, command, timeout, session_id)

View File

@@ -0,0 +1,157 @@
"""Tool for continuing block execution after human review approval."""
import logging
from typing import Any
from prisma.enums import ReviewStatus
from backend.blocks import get_block
from backend.copilot.constants import (
COPILOT_NODE_PREFIX,
COPILOT_SESSION_PREFIX,
parse_node_id_from_exec_id,
)
from backend.copilot.model import ChatSession
from backend.data.db_accessors import review_db
from .base import BaseTool
from .helpers import execute_block, resolve_block_credentials
from .models import ErrorResponse, ToolResponseBase
logger = logging.getLogger(__name__)
class ContinueRunBlockTool(BaseTool):
"""Tool for continuing a block execution after human review approval."""
@property
def name(self) -> str:
return "continue_run_block"
@property
def description(self) -> str:
return (
"Continue executing a block after human review approval. "
"Use this after a run_block call returned review_required. "
"Pass the review_id from the review_required response. "
"The block will execute with the original pre-approved input data."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"review_id": {
"type": "string",
"description": (
"The review_id from a previous review_required response. "
"This resumes execution with the pre-approved input data."
),
},
},
"required": ["review_id"],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
review_id = (
kwargs.get("review_id", "").strip() if kwargs.get("review_id") else ""
)
session_id = session.session_id
if not review_id:
return ErrorResponse(
message="Please provide a review_id", session_id=session_id
)
if not user_id:
return ErrorResponse(
message="Authentication required", session_id=session_id
)
# Look up and validate the review record via adapter
reviews = await review_db().get_reviews_by_node_exec_ids([review_id], user_id)
review = reviews.get(review_id)
if not review:
return ErrorResponse(
message=(
f"Review '{review_id}' not found or already executed. "
"It may have been consumed by a previous continue_run_block call."
),
session_id=session_id,
)
# Validate the review belongs to this session
expected_graph_exec_id = f"{COPILOT_SESSION_PREFIX}{session_id}"
if review.graph_exec_id != expected_graph_exec_id:
return ErrorResponse(
message="Review does not belong to this session.",
session_id=session_id,
)
if review.status == ReviewStatus.WAITING:
return ErrorResponse(
message="Review has not been approved yet. "
"Please wait for the user to approve the review first.",
session_id=session_id,
)
if review.status == ReviewStatus.REJECTED:
return ErrorResponse(
message="Review was rejected. The block will not execute.",
session_id=session_id,
)
# Extract block_id from review_id: copilot-node-{block_id}:{random_hex}
block_id = parse_node_id_from_exec_id(review_id).removeprefix(
COPILOT_NODE_PREFIX
)
block = get_block(block_id)
if not block:
return ErrorResponse(
message=f"Block '{block_id}' not found", session_id=session_id
)
input_data: dict[str, Any] = (
review.payload if isinstance(review.payload, dict) else {}
)
logger.info(
f"Continuing block {block.name} ({block_id}) for user {user_id} "
f"with review_id={review_id}"
)
matched_creds, missing_creds = await resolve_block_credentials(
user_id, block, input_data
)
if missing_creds:
return ErrorResponse(
message=f"Block '{block.name}' requires credentials that are not configured.",
session_id=session_id,
)
result = await execute_block(
block=block,
block_id=block_id,
input_data=input_data,
user_id=user_id,
session_id=session_id,
node_exec_id=review_id,
matched_credentials=matched_creds,
)
# Delete review record after successful execution (one-time use)
if result.type != "error":
await review_db().delete_review_by_node_exec_id(review_id, user_id)
return result

View File

@@ -0,0 +1,186 @@
"""Tests for ContinueRunBlockTool."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from prisma.enums import ReviewStatus
from ._test_data import make_session
from .continue_run_block import ContinueRunBlockTool
from .models import BlockOutputResponse, ErrorResponse
_TEST_USER_ID = "test-user-continue"
def _make_review_model(
node_exec_id: str,
status: ReviewStatus = ReviewStatus.APPROVED,
payload: dict | None = None,
graph_exec_id: str = "",
):
"""Create a mock PendingHumanReviewModel."""
mock = MagicMock()
mock.node_exec_id = node_exec_id
mock.status = status
mock.payload = payload or {"text": "hello"}
mock.graph_exec_id = graph_exec_id
return mock
class TestContinueRunBlock:
@pytest.mark.asyncio(loop_scope="session")
async def test_missing_review_id_returns_error(self):
tool = ContinueRunBlockTool()
session = make_session(user_id=_TEST_USER_ID)
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
review_id="",
)
assert isinstance(response, ErrorResponse)
assert "review_id" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_review_not_found_returns_error(self):
tool = ContinueRunBlockTool()
session = make_session(user_id=_TEST_USER_ID)
mock_db = MagicMock()
mock_db.get_reviews_by_node_exec_ids = AsyncMock(return_value={})
with patch(
"backend.copilot.tools.continue_run_block.review_db",
return_value=mock_db,
):
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
review_id="copilot-node-some-block:abc12345",
)
assert isinstance(response, ErrorResponse)
assert "not found" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_waiting_review_returns_error(self):
tool = ContinueRunBlockTool()
session = make_session(user_id=_TEST_USER_ID)
review_id = "copilot-node-some-block:abc12345"
graph_exec_id = f"copilot-session-{session.session_id}"
review = _make_review_model(
review_id, status=ReviewStatus.WAITING, graph_exec_id=graph_exec_id
)
mock_db = MagicMock()
mock_db.get_reviews_by_node_exec_ids = AsyncMock(
return_value={review_id: review}
)
with patch(
"backend.copilot.tools.continue_run_block.review_db",
return_value=mock_db,
):
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
review_id=review_id,
)
assert isinstance(response, ErrorResponse)
assert "not been approved" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_rejected_review_returns_error(self):
tool = ContinueRunBlockTool()
session = make_session(user_id=_TEST_USER_ID)
review_id = "copilot-node-some-block:abc12345"
graph_exec_id = f"copilot-session-{session.session_id}"
review = _make_review_model(
review_id, status=ReviewStatus.REJECTED, graph_exec_id=graph_exec_id
)
mock_db = MagicMock()
mock_db.get_reviews_by_node_exec_ids = AsyncMock(
return_value={review_id: review}
)
with patch(
"backend.copilot.tools.continue_run_block.review_db",
return_value=mock_db,
):
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
review_id=review_id,
)
assert isinstance(response, ErrorResponse)
assert "rejected" in response.message.lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_approved_review_executes_block(self):
tool = ContinueRunBlockTool()
session = make_session(user_id=_TEST_USER_ID)
review_id = "copilot-node-delete-branch-id:abc12345"
graph_exec_id = f"copilot-session-{session.session_id}"
input_data = {"repo_url": "https://github.com/test/repo", "branch": "main"}
review = _make_review_model(
review_id,
status=ReviewStatus.APPROVED,
payload=input_data,
graph_exec_id=graph_exec_id,
)
mock_block = MagicMock()
mock_block.name = "Delete Branch"
async def mock_execute(data, **kwargs):
yield "result", "Branch deleted"
mock_block.execute = mock_execute
mock_block.input_schema.get_credentials_fields_info.return_value = []
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="test-workspace-id")
)
mock_db = MagicMock()
mock_db.get_reviews_by_node_exec_ids = AsyncMock(
return_value={review_id: review}
)
mock_db.delete_review_by_node_exec_id = AsyncMock(return_value=1)
with (
patch(
"backend.copilot.tools.continue_run_block.review_db",
return_value=mock_db,
),
patch(
"backend.copilot.tools.continue_run_block.get_block",
return_value=mock_block,
),
patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
),
patch(
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
):
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
review_id=review_id,
)
assert isinstance(response, BlockOutputResponse)
assert response.success is True
assert response.block_name == "Delete Branch"
# Verify review was deleted (one-time use)
mock_db.delete_review_by_node_exec_id.assert_called_once_with(
review_id, _TEST_USER_ID
)

View File

@@ -1,34 +1,20 @@
"""CreateAgentTool - Creates agents from natural language descriptions."""
"""CreateAgentTool - Creates agents from pre-built JSON."""
import logging
import uuid
from typing import Any
from backend.copilot.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,
decompose_goal,
enrich_library_agents_from_steps,
generate_agent,
get_user_message_for_error,
save_agent_to_library,
)
from .agent_generator.pipeline import fetch_library_agents, fix_validate_and_save
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
SuggestedGoalResponse,
ToolResponseBase,
)
from .models import ErrorResponse, ToolResponseBase
logger = logging.getLogger(__name__)
class CreateAgentTool(BaseTool):
"""Tool for creating agents from natural language descriptions."""
"""Tool for creating agents from pre-built JSON."""
@property
def name(self) -> str:
@@ -37,15 +23,12 @@ class CreateAgentTool(BaseTool):
@property
def description(self) -> str:
return (
"Create a new agent workflow from a natural language description. "
"First generates a preview, then saves to library if save=true. "
"\n\nWorkflow: (1) Always check find_library_agent first for existing building blocks. "
"(2) Call create_agent with description and library_agent_ids. "
"(3) If response contains suggested_goal: Present to user, ask for confirmation, "
"then call again with the suggested goal if accepted. "
"(4) If response contains clarifying_questions: Present to user, collect answers, "
"then call again with original description AND answers in the context parameter. "
"\n\nThis feedback loop ensures the generated agent matches user intent."
"Create a new agent workflow. Pass `agent_json` with the complete "
"agent graph JSON you generated using block schemas from find_block. "
"The tool validates, auto-fixes, and saves.\n\n"
"IMPORTANT: Before calling this tool, search for relevant existing agents "
"using find_library_agent that could be used as building blocks. "
"Pass their IDs in the library_agent_ids parameter."
)
@property
@@ -57,34 +40,26 @@ class CreateAgentTool(BaseTool):
return {
"type": "object",
"properties": {
"description": {
"type": "string",
"agent_json": {
"type": "object",
"description": (
"Natural language description of what the agent should do. "
"Be specific about inputs, outputs, and the workflow steps."
),
},
"context": {
"type": "string",
"description": (
"Additional context or answers to previous clarifying questions. "
"Include any preferences or constraints mentioned by the user."
"The agent JSON to validate and save. "
"Must contain 'nodes' and 'links' arrays, and optionally "
"'name' and 'description'."
),
},
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"List of library agent IDs to use as building blocks. "
"Search for relevant agents using find_library_agent first, "
"then pass their IDs here so they can be composed into the new agent."
"List of library agent IDs to use as building blocks."
),
},
"save": {
"type": "boolean",
"description": (
"Whether to save the agent to the user's library. "
"Default is true. Set to false for preview only."
"Whether to save the agent. Default is true. "
"Set to false for preview only."
),
"default": True,
},
@@ -97,7 +72,7 @@ class CreateAgentTool(BaseTool):
),
},
},
"required": ["description"],
"required": ["agent_json"],
}
async def _execute(
@@ -106,278 +81,49 @@ class CreateAgentTool(BaseTool):
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the create_agent tool.
Flow:
1. Decompose the description into steps (may return clarifying questions)
2. Generate agent JSON (external service handles fixing and validation)
3. Preview or save based on the save parameter
"""
description = kwargs.get("description", "").strip()
context = kwargs.get("context", "")
library_agent_ids = kwargs.get("library_agent_ids", [])
save = kwargs.get("save", True)
folder_id = kwargs.get("folder_id")
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
session_id = session.session_id if session else None
logger.info(
f"[AGENT_CREATE_DEBUG] START - description_len={len(description)}, "
f"library_agent_ids={library_agent_ids}, save={save}, user_id={user_id}, session_id={session_id}"
if not agent_json:
return ErrorResponse(
message=(
"Please provide agent_json with the complete agent graph. "
"Use find_block to discover blocks, then generate the JSON."
),
error="missing_agent_json",
session_id=session_id,
)
save = kwargs.get("save", True)
library_agent_ids = kwargs.get("library_agent_ids", [])
folder_id: str | None = kwargs.get("folder_id")
nodes = agent_json.get("nodes", [])
if not nodes:
return ErrorResponse(
message="The agent JSON has no nodes. An agent needs at least one block.",
error="empty_agent",
session_id=session_id,
)
# Ensure top-level fields
if "id" not in agent_json:
agent_json["id"] = str(uuid.uuid4())
if "version" not in agent_json:
agent_json["version"] = 1
if "is_active" not in agent_json:
agent_json["is_active"] = True
# Fetch library agents for AgentExecutorBlock validation
library_agents = await fetch_library_agents(user_id, library_agent_ids)
return await fix_validate_and_save(
agent_json,
user_id=user_id,
session_id=session_id,
save=save,
is_update=False,
default_name="Generated Agent",
library_agents=library_agents,
folder_id=folder_id,
)
if not description:
return ErrorResponse(
message="Please provide a description of what the agent should do.",
error="Missing description parameter",
session_id=session_id,
)
# Fetch library agents by IDs if provided
library_agents = None
if user_id and library_agent_ids:
try:
from .agent_generator import get_library_agents_by_ids
library_agents = await get_library_agents_by_ids(
user_id=user_id,
agent_ids=library_agent_ids,
)
logger.debug(
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents by IDs: {e}")
try:
decomposition_result = await decompose_goal(
description, context, library_agents
)
logger.info(
f"[AGENT_CREATE_DEBUG] DECOMPOSE - type={decomposition_result.get('type') if decomposition_result else None}, "
f"session_id={session_id}"
)
except AgentGeneratorNotConfiguredError:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured, session_id={session_id}"
)
return ErrorResponse(
message=(
"Agent generation is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
if decomposition_result is None:
return ErrorResponse(
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
error="decomposition_failed",
details={"description": description[:100]},
session_id=session_id,
)
if decomposition_result.get("type") == "error":
error_msg = decomposition_result.get("error", "Unknown error")
error_type = decomposition_result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="analyze the goal",
llm_parse_message="The AI had trouble understanding this request. Please try rephrasing your goal.",
)
return ErrorResponse(
message=user_message,
error=f"decomposition_failed:{error_type}",
details={
"description": description[:100],
"service_error": error_msg,
"error_type": error_type,
},
session_id=session_id,
)
if decomposition_result.get("type") == "clarifying_questions":
questions = decomposition_result.get("questions", [])
return ClarificationNeededResponse(
message=(
"I need some more information to create this agent. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
],
session_id=session_id,
)
if decomposition_result.get("type") == "unachievable_goal":
suggested = decomposition_result.get("suggested_goal", "")
reason = decomposition_result.get("reason", "")
return SuggestedGoalResponse(
message=(
f"This goal cannot be accomplished with the available blocks. {reason}"
),
suggested_goal=suggested,
reason=reason,
original_goal=description,
goal_type="unachievable",
session_id=session_id,
)
if decomposition_result.get("type") == "vague_goal":
suggested = decomposition_result.get("suggested_goal", "")
reason = decomposition_result.get(
"reason", "The goal needs more specific details"
)
return SuggestedGoalResponse(
message="The goal is too vague to create a specific workflow.",
suggested_goal=suggested,
reason=reason,
original_goal=description,
goal_type="vague",
session_id=session_id,
)
if user_id and library_agents is not None:
try:
library_agents = await enrich_library_agents_from_steps(
user_id=user_id,
decomposition_result=decomposition_result,
existing_agents=library_agents,
include_marketplace=True,
)
logger.debug(
f"After enrichment: {len(library_agents)} total agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to enrich library agents from steps: {e}")
try:
agent_json = await generate_agent(
decomposition_result,
library_agents,
)
logger.info(
f"[AGENT_CREATE_DEBUG] GENERATE - "
f"success={agent_json is not None}, "
f"is_error={isinstance(agent_json, dict) and agent_json.get('type') == 'error'}, "
f"session_id={session_id}"
)
except AgentGeneratorNotConfiguredError:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured during generation, session_id={session_id}"
)
return ErrorResponse(
message=(
"Agent generation is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
if agent_json is None:
return ErrorResponse(
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
error="generation_failed",
details={"description": description[:100]},
session_id=session_id,
)
if isinstance(agent_json, dict) and agent_json.get("type") == "error":
error_msg = agent_json.get("error", "Unknown error")
error_type = agent_json.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="generate the agent",
llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
validation_message=(
"I wasn't able to create a valid agent for this request. "
"The generated workflow had some structural issues. "
"Please try simplifying your goal or breaking it into smaller steps."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"generation_failed:{error_type}",
details={
"description": description[:100],
"service_error": error_msg,
"error_type": error_type,
},
session_id=session_id,
)
agent_name = agent_json.get("name", "Generated Agent")
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
logger.info(
f"[AGENT_CREATE_DEBUG] AGENT_JSON - name={agent_name}, "
f"nodes={node_count}, links={link_count}, save={save}, session_id={session_id}"
)
if not save:
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - AgentPreviewResponse, session_id={session_id}"
)
return AgentPreviewResponse(
message=(
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
f"Review it and call create_agent with save=true to save it to your library."
),
agent_json=agent_json,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
try:
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id, folder_id=folder_id
)
logger.info(
f"[AGENT_CREATE_DEBUG] SAVED - graph_id={created_graph.id}, "
f"library_agent_id={library_agent.id}, session_id={session_id}"
)
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - AgentSavedResponse, session_id={session_id}"
)
return AgentSavedResponse(
message=f"Agent '{created_graph.name}' has been saved to your library!",
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - save_failed: {str(e)}, session_id={session_id}"
)
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - ErrorResponse (save_failed), session_id={session_id}"
)
return ErrorResponse(
message=f"Failed to save the agent: {str(e)}",
error="save_failed",
details={"exception": str(e)},
session_id=session_id,
)

View File

@@ -1,19 +1,16 @@
"""Tests for CreateAgentTool response types."""
"""Tests for CreateAgentTool."""
from unittest.mock import AsyncMock, patch
from unittest.mock import MagicMock, patch
import pytest
from backend.copilot.tools.create_agent import CreateAgentTool
from backend.copilot.tools.models import (
ClarificationNeededResponse,
ErrorResponse,
SuggestedGoalResponse,
)
from backend.copilot.tools.models import AgentPreviewResponse, ErrorResponse
from ._test_data import make_session
_TEST_USER_ID = "test-user-create-agent"
_PIPELINE = "backend.copilot.tools.agent_generator.pipeline"
@pytest.fixture
@@ -26,102 +23,147 @@ def session():
return make_session(_TEST_USER_ID)
# ── Input validation tests ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_missing_description_returns_error(tool, session):
"""Missing description returns ErrorResponse."""
result = await tool._execute(user_id=_TEST_USER_ID, session=session, description="")
async def test_missing_agent_json_returns_error(tool, session):
"""Missing agent_json returns ErrorResponse."""
result = await tool._execute(user_id=_TEST_USER_ID, session=session)
assert isinstance(result, ErrorResponse)
assert result.error == "Missing description parameter"
assert result.error == "missing_agent_json"
# ── Local mode tests ────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_vague_goal_returns_suggested_goal_response(tool, session):
"""vague_goal decomposition result returns SuggestedGoalResponse, not ErrorResponse."""
vague_result = {
"type": "vague_goal",
"suggested_goal": "Monitor Twitter mentions for a specific keyword and send a daily digest email",
}
with (
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,
return_value=vague_result,
),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
description="monitor social media",
)
assert isinstance(result, SuggestedGoalResponse)
assert result.goal_type == "vague"
assert result.suggested_goal == vague_result["suggested_goal"]
assert result.original_goal == "monitor social media"
assert result.reason == "The goal needs more specific details"
assert not isinstance(result, ErrorResponse)
async def test_local_mode_empty_nodes_returns_error(tool, session):
"""Local mode with no nodes returns ErrorResponse."""
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json={"nodes": [], "links": []},
)
assert isinstance(result, ErrorResponse)
assert "no nodes" in result.message.lower()
@pytest.mark.asyncio
async def test_unachievable_goal_returns_suggested_goal_response(tool, session):
"""unachievable_goal decomposition result returns SuggestedGoalResponse, not ErrorResponse."""
unachievable_result = {
"type": "unachievable_goal",
"suggested_goal": "Summarize the latest news articles on a topic and send them by email",
"reason": "There are no blocks for mind-reading.",
}
with (
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,
return_value=unachievable_result,
),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
description="read my mind",
)
assert isinstance(result, SuggestedGoalResponse)
assert result.goal_type == "unachievable"
assert result.suggested_goal == unachievable_result["suggested_goal"]
assert result.original_goal == "read my mind"
assert result.reason == unachievable_result["reason"]
assert not isinstance(result, ErrorResponse)
@pytest.mark.asyncio
async def test_clarifying_questions_returns_clarification_needed_response(
tool, session
):
"""clarifying_questions decomposition result returns ClarificationNeededResponse."""
clarifying_result = {
"type": "clarifying_questions",
"questions": [
async def test_local_mode_preview(tool, session):
"""Local mode with save=false returns AgentPreviewResponse."""
agent_json = {
"name": "Test Agent",
"description": "A test agent",
"nodes": [
{
"question": "What platform should be monitored?",
"keyword": "platform",
"example": "Twitter, Reddit",
"id": "node-1",
"block_id": "block-1",
"input_default": {},
"metadata": {"position": {"x": 0, "y": 0}},
}
],
"links": [],
}
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
mock_fixer.get_fixes_applied.return_value = []
mock_validator = MagicMock()
mock_validator.validate.return_value = (True, None)
mock_validator.errors = []
with (
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,
return_value=clarifying_result,
),
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
description="monitor social media and alert me",
agent_json=agent_json,
save=False,
)
assert isinstance(result, ClarificationNeededResponse)
assert len(result.questions) == 1
assert result.questions[0].keyword == "platform"
assert isinstance(result, AgentPreviewResponse)
assert result.agent_name == "Test Agent"
assert result.node_count == 1
@pytest.mark.asyncio
async def test_local_mode_validation_failure(tool, session):
"""Local mode returns ErrorResponse when validation fails after fixing."""
agent_json = {
"nodes": [
{
"id": "node-1",
"block_id": "bad-block",
"input_default": {},
"metadata": {},
}
],
"links": [],
}
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
mock_fixer.get_fixes_applied.return_value = []
mock_validator = MagicMock()
mock_validator.validate.return_value = (False, "Block 'bad-block' not found")
mock_validator.errors = ["Block 'bad-block' not found"]
with (
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
)
assert isinstance(result, ErrorResponse)
assert result.error == "validation_failed"
assert "Block 'bad-block' not found" in result.message
@pytest.mark.asyncio
async def test_local_mode_no_auth_returns_error(tool, session):
"""Local mode with save=true and no user returns ErrorResponse."""
agent_json = {
"nodes": [
{
"id": "node-1",
"block_id": "block-1",
"input_default": {},
"metadata": {},
}
],
"links": [],
}
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
mock_fixer.get_fixes_applied.return_value = []
mock_validator = MagicMock()
mock_validator.validate.return_value = (True, None)
mock_validator.errors = []
with (
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
):
result = await tool._execute(
user_id=None,
session=session,
agent_json=agent_json,
save=True,
)
assert isinstance(result, ErrorResponse)
assert "logged in" in result.message.lower()

View File

@@ -1,34 +1,20 @@
"""CustomizeAgentTool - Customizes marketplace/template agents using natural language."""
"""CustomizeAgentTool - Customizes marketplace/template agents."""
import logging
import uuid
from typing import Any
from backend.copilot.model import ChatSession
from backend.data.db_accessors import store_db as get_store_db
from backend.util.exceptions import NotFoundError
from .agent_generator import (
AgentGeneratorNotConfiguredError,
customize_template,
get_user_message_for_error,
graph_to_json,
save_agent_to_library,
)
from .agent_generator.pipeline import fetch_library_agents, fix_validate_and_save
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
ToolResponseBase,
)
from .models import ErrorResponse, ToolResponseBase
logger = logging.getLogger(__name__)
class CustomizeAgentTool(BaseTool):
"""Tool for customizing marketplace/template agents using natural language."""
"""Tool for customizing marketplace/template agents."""
@property
def name(self) -> str:
@@ -37,9 +23,9 @@ class CustomizeAgentTool(BaseTool):
@property
def description(self) -> str:
return (
"Customize a marketplace or template agent using natural language. "
"Takes an existing agent from the marketplace and modifies it based on "
"the user's requirements before adding to their library."
"Customize a marketplace or template agent. Pass `agent_json` "
"with the complete customized agent JSON. The tool validates, "
"auto-fixes, and saves."
)
@property
@@ -51,32 +37,24 @@ class CustomizeAgentTool(BaseTool):
return {
"type": "object",
"properties": {
"agent_id": {
"type": "string",
"agent_json": {
"type": "object",
"description": (
"The marketplace agent ID in format 'creator/slug' "
"(e.g., 'autogpt/newsletter-writer'). "
"Get this from find_agent results."
"Complete customized agent JSON to validate and save. "
"Optionally include 'name' and 'description'."
),
},
"modifications": {
"type": "string",
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"Natural language description of how to customize the agent. "
"Be specific about what changes you want to make."
),
},
"context": {
"type": "string",
"description": (
"Additional context or answers to previous clarifying questions."
"List of library agent IDs to use as building blocks."
),
},
"save": {
"type": "boolean",
"description": (
"Whether to save the customized agent to the user's library. "
"Default is true. Set to false for preview only."
"Whether to save the customized agent. Default is true."
),
"default": True,
},
@@ -89,7 +67,7 @@ class CustomizeAgentTool(BaseTool):
),
},
},
"required": ["agent_id", "modifications"],
"required": ["agent_json"],
}
async def _execute(
@@ -98,247 +76,46 @@ class CustomizeAgentTool(BaseTool):
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the customize_agent tool.
Flow:
1. Parse the agent ID to get creator/slug
2. Fetch the template agent from the marketplace
3. Call customize_template with the modification request
4. Preview or save based on the save parameter
"""
agent_id = kwargs.get("agent_id", "").strip()
modifications = kwargs.get("modifications", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
folder_id = kwargs.get("folder_id")
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
session_id = session.session_id if session else None
if not agent_id:
return ErrorResponse(
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
error="missing_agent_id",
session_id=session_id,
)
if not modifications:
return ErrorResponse(
message="Please describe how you want to customize this agent.",
error="missing_modifications",
session_id=session_id,
)
# Parse agent_id in format "creator/slug"
parts = [p.strip() for p in agent_id.split("/")]
if len(parts) != 2 or not parts[0] or not parts[1]:
if not agent_json:
return ErrorResponse(
message=(
f"Invalid agent ID format: '{agent_id}'. "
"Expected format is 'creator/agent-name' "
"(e.g., 'autogpt/newsletter-writer')."
"Please provide agent_json with the complete customized agent graph."
),
error="invalid_agent_id_format",
error="missing_agent_json",
session_id=session_id,
)
creator_username, agent_slug = parts
save = kwargs.get("save", True)
library_agent_ids = kwargs.get("library_agent_ids", [])
folder_id: str | None = kwargs.get("folder_id")
store_db = get_store_db()
# Fetch the marketplace agent details
try:
agent_details = await store_db.get_store_agent_details(
username=creator_username, agent_name=agent_slug
)
except NotFoundError:
nodes = agent_json.get("nodes", [])
if not nodes:
return ErrorResponse(
message=(
f"Could not find marketplace agent '{agent_id}'. "
"Please check the agent ID and try again."
),
error="agent_not_found",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
return ErrorResponse(
message="Failed to fetch the marketplace agent. Please try again.",
error="fetch_error",
message="The agent JSON has no nodes.",
error="empty_agent",
session_id=session_id,
)
if not agent_details.store_listing_version_id:
return ErrorResponse(
message=(
f"The agent '{agent_id}' does not have an available version. "
"Please try a different agent."
),
error="no_version_available",
session_id=session_id,
)
# Ensure top-level fields before the fixer pipeline
if "id" not in agent_json:
agent_json["id"] = str(uuid.uuid4())
agent_json.setdefault("version", 1)
agent_json.setdefault("is_active", True)
# Get the full agent graph
try:
graph = await store_db.get_agent(agent_details.store_listing_version_id)
template_agent = graph_to_json(graph)
except Exception as e:
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
return ErrorResponse(
message="Failed to fetch the agent configuration. Please try again.",
error="graph_fetch_error",
session_id=session_id,
)
# Fetch library agents for AgentExecutorBlock validation
library_agents = await fetch_library_agents(user_id, library_agent_ids)
# Call customize_template
try:
result = await customize_template(
template_agent=template_agent,
modification_request=modifications,
context=context,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
"Agent customization is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error calling customize_template for {agent_id}: {e}")
return ErrorResponse(
message=(
"Failed to customize the agent due to a service error. "
"Please try again."
),
error="customization_service_error",
session_id=session_id,
)
if result is None:
return ErrorResponse(
message=(
"Failed to customize the agent. "
"The agent generation service may be unavailable or timed out. "
"Please try again."
),
error="customization_failed",
session_id=session_id,
)
# Handle error response
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="customize the agent",
llm_parse_message=(
"The AI had trouble customizing the agent. "
"Please try again or simplify your request."
),
validation_message=(
"The customized agent failed validation. "
"Please try rephrasing your request."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"customization_failed:{error_type}",
session_id=session_id,
)
# Handle clarifying questions
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
questions = result.get("questions") or []
if not isinstance(questions, list):
logger.error(
f"Unexpected clarifying questions format: {type(questions)}"
)
questions = []
return ClarificationNeededResponse(
message=(
"I need some more information to customize this agent. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
if isinstance(q, dict)
],
session_id=session_id,
)
# Result should be the customized agent JSON
if not isinstance(result, dict):
logger.error(f"Unexpected customize_template response type: {type(result)}")
return ErrorResponse(
message="Failed to customize the agent due to an unexpected response.",
error="unexpected_response_type",
session_id=session_id,
)
customized_agent = result
agent_name = customized_agent.get(
"name", f"Customized {agent_details.agent_name}"
return await fix_validate_and_save(
agent_json,
user_id=user_id,
session_id=session_id,
save=save,
is_update=False,
default_name="Customized Agent",
library_agents=library_agents,
folder_id=folder_id,
)
agent_description = customized_agent.get("description", "")
nodes = customized_agent.get("nodes")
links = customized_agent.get("links")
node_count = len(nodes) if isinstance(nodes, list) else 0
link_count = len(links) if isinstance(links, list) else 0
if not save:
return AgentPreviewResponse(
message=(
f"I've customized the agent '{agent_details.agent_name}'. "
f"The customized agent has {node_count} blocks. "
f"Review it and call customize_agent with save=true to save it."
),
agent_json=customized_agent,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
# Save to user's library
try:
created_graph, library_agent = await save_agent_to_library(
customized_agent, user_id, is_update=False, folder_id=folder_id
)
return AgentSavedResponse(
message=(
f"Customized agent '{created_graph.name}' "
f"(based on '{agent_details.agent_name}') "
f"has been saved to your library!"
),
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error saving customized agent: {e}")
return ErrorResponse(
message="Failed to save the customized agent. Please try again.",
error="save_failed",
session_id=session_id,
)

View File

@@ -0,0 +1,172 @@
"""Tests for CustomizeAgentTool local mode."""
from unittest.mock import MagicMock, patch
import pytest
from backend.copilot.tools.customize_agent import CustomizeAgentTool
from backend.copilot.tools.models import AgentPreviewResponse, ErrorResponse
from ._test_data import make_session
_TEST_USER_ID = "test-user-customize-agent"
_PIPELINE = "backend.copilot.tools.agent_generator.pipeline"
@pytest.fixture
def tool():
return CustomizeAgentTool()
@pytest.fixture
def session():
return make_session(_TEST_USER_ID)
# ── Input validation tests ───────────────────────────────────────────────
@pytest.mark.asyncio
async def test_missing_agent_json_returns_error(tool, session):
"""Missing agent_json returns ErrorResponse."""
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
)
assert isinstance(result, ErrorResponse)
assert result.error == "missing_agent_json"
# ── Local mode tests (agent_json provided) ───────────────────────────────
@pytest.mark.asyncio
async def test_local_mode_empty_nodes_returns_error(tool, session):
"""Local mode with no nodes returns ErrorResponse."""
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json={"nodes": [], "links": []},
)
assert isinstance(result, ErrorResponse)
assert "no nodes" in result.message.lower()
@pytest.mark.asyncio
async def test_local_mode_preview(tool, session):
"""Local mode with save=false returns AgentPreviewResponse."""
agent_json = {
"name": "Customized Agent",
"description": "A customized agent",
"nodes": [
{
"id": "node-1",
"block_id": "block-1",
"input_default": {},
"metadata": {"position": {"x": 0, "y": 0}},
}
],
"links": [],
}
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
mock_fixer.get_fixes_applied.return_value = []
mock_validator = MagicMock()
mock_validator.validate.return_value = (True, None)
mock_validator.errors = []
with (
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
save=False,
)
assert isinstance(result, AgentPreviewResponse)
assert result.agent_name == "Customized Agent"
assert result.node_count == 1
@pytest.mark.asyncio
async def test_local_mode_validation_failure(tool, session):
"""Local mode returns ErrorResponse when validation fails."""
agent_json = {
"nodes": [
{
"id": "node-1",
"block_id": "bad-block",
"input_default": {},
"metadata": {},
}
],
"links": [],
}
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
mock_fixer.get_fixes_applied.return_value = []
mock_validator = MagicMock()
mock_validator.validate.return_value = (False, "Block 'bad-block' not found")
mock_validator.errors = ["Block 'bad-block' not found"]
with (
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
)
assert isinstance(result, ErrorResponse)
assert result.error == "validation_failed"
assert "Block 'bad-block' not found" in result.message
@pytest.mark.asyncio
async def test_local_mode_no_auth_returns_error(tool, session):
"""Local mode with save=true and no user returns ErrorResponse."""
agent_json = {
"nodes": [
{
"id": "node-1",
"block_id": "block-1",
"input_default": {},
"metadata": {},
}
],
"links": [],
}
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
mock_fixer.get_fixes_applied.return_value = []
mock_validator = MagicMock()
mock_validator.validate.return_value = (True, None)
mock_validator.errors = []
with (
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
):
result = await tool._execute(
user_id=None,
session=session,
agent_json=agent_json,
save=True,
)
assert isinstance(result, ErrorResponse)
assert "logged in" in result.message.lower()

View File

@@ -10,13 +10,34 @@ Lifecycle
---------
1. **Turn start** connect to the existing sandbox (sandbox_id in Redis) or
create a new one via ``get_or_create_sandbox()``.
``connect()`` in e2b v2 auto-resumes paused sandboxes.
2. **Execution** ``bash_exec`` and MCP file tools operate directly on the
sandbox's ``/home/user`` filesystem.
3. **Session expiry** E2B sandbox is killed by its own timeout (session_ttl).
3. **Turn end** the sandbox is paused via ``pause_sandbox()`` (fire-and-forget)
so idle time between turns costs nothing. Paused sandboxes have no compute
cost.
4. **Session delete** ``kill_sandbox()`` fully terminates the sandbox.
Cost control
------------
Sandboxes are created with a configurable ``on_timeout`` lifecycle action
(default: ``"pause"``). The explicit per-turn ``pause_sandbox()`` call is the
primary mechanism; the lifecycle setting is a safety net. Paused sandboxes are
free.
The sandbox_id is stored in Redis. The same key doubles as a creation lock:
a ``"creating"`` sentinel value is written with a short TTL while a new sandbox
is being provisioned, preventing duplicate creation under concurrent requests.
E2B project-level "paused sandbox lifetime" should be set to match
``_SANDBOX_ID_TTL`` (48 h) so orphaned paused sandboxes are auto-killed before
the Redis key expires.
"""
import asyncio
import contextlib
import logging
from typing import Any, Awaitable, Callable, Literal
from e2b import AsyncSandbox
@@ -24,147 +45,245 @@ from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
_SANDBOX_REDIS_PREFIX = "copilot:e2b:sandbox:"
E2B_WORKDIR = "/home/user"
_CREATING = "__creating__"
_CREATION_LOCK_TTL = 60
_MAX_WAIT_ATTEMPTS = 20 # 20 * 0.5s = 10s max wait
_SANDBOX_KEY_PREFIX = "copilot:e2b:sandbox:"
_CREATING_SENTINEL = "creating"
# Short TTL for the "creating" sentinel — if the process dies mid-creation the
# lock auto-expires so other callers are not blocked forever.
_CREATION_LOCK_TTL = 60 # seconds
_MAX_WAIT_ATTEMPTS = 20 # 20 × 0.5 s = 10 s max wait
# Timeout for E2B API calls (pause/kill) — short because these are control-plane
# operations; if the sandbox is unreachable, fail fast and retry on the next turn.
_E2B_API_TIMEOUT_SECONDS = 10
# Redis TTL for the sandbox key. Must be ≥ the E2B project "paused sandbox
# lifetime" setting (recommended: set both to 48 h).
_SANDBOX_ID_TTL = 48 * 3600 # 48 hours
def _sandbox_key(session_id: str) -> str:
return f"{_SANDBOX_KEY_PREFIX}{session_id}"
async def _get_stored_sandbox_id(session_id: str) -> str | None:
redis = await get_redis_async()
raw = await redis.get(_sandbox_key(session_id))
value = raw.decode() if isinstance(raw, bytes) else raw
return None if value == _CREATING_SENTINEL else value
async def _set_stored_sandbox_id(session_id: str, sandbox_id: str) -> None:
redis = await get_redis_async()
await redis.set(_sandbox_key(session_id), sandbox_id, ex=_SANDBOX_ID_TTL)
async def _clear_stored_sandbox_id(session_id: str) -> None:
redis = await get_redis_async()
await redis.delete(_sandbox_key(session_id))
async def _try_reconnect(
sandbox_id: str, api_key: str, redis_key: str, timeout: int
sandbox_id: str, session_id: str, api_key: str
) -> "AsyncSandbox | None":
"""Try to reconnect to an existing sandbox. Returns None on failure."""
try:
sandbox = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
if await sandbox.is_running():
redis = await get_redis_async()
await redis.expire(redis_key, timeout)
# Refresh TTL so an active session cannot lose its sandbox_id at expiry.
await _set_stored_sandbox_id(session_id, sandbox_id)
return sandbox
except Exception as exc:
logger.warning("[E2B] Reconnect to %.12s failed: %s", sandbox_id, exc)
# Stale — clear Redis so a new sandbox can be created.
redis = await get_redis_async()
await redis.delete(redis_key)
# Stale — clear the sandbox_id from Redis so a new one can be created.
await _clear_stored_sandbox_id(session_id)
return None
async def get_or_create_sandbox(
session_id: str,
api_key: str,
timeout: int,
template: str = "base",
timeout: int = 43200,
on_timeout: Literal["kill", "pause"] = "pause",
) -> AsyncSandbox:
"""Return the existing E2B sandbox for *session_id* or create a new one.
The sandbox_id is persisted in Redis so the same sandbox is reused
across turns. Concurrent calls for the same session are serialised
via a Redis ``SET NX`` creation lock.
The sandbox key in Redis serves a dual purpose: it stores the sandbox_id
and acts as a creation lock via a ``"creating"`` sentinel value. This
removes the need for a separate lock key.
*timeout* controls how long the e2b sandbox may run continuously before
the ``on_timeout`` lifecycle rule fires (default: 3 h).
*on_timeout* controls what happens on timeout: ``"pause"`` (default, free)
or ``"kill"``.
"""
redis = await get_redis_async()
redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
key = _sandbox_key(session_id)
# 1. Try reconnecting to an existing sandbox.
raw = await redis.get(redis_key)
if raw:
sandbox_id = raw if isinstance(raw, str) else raw.decode()
if sandbox_id != _CREATING:
sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
for _ in range(_MAX_WAIT_ATTEMPTS):
raw = await redis.get(key)
value = raw.decode() if isinstance(raw, bytes) else raw
if value and value != _CREATING_SENTINEL:
# Existing sandbox ID — try to reconnect (auto-resumes if paused).
sandbox = await _try_reconnect(value, session_id, api_key)
if sandbox:
logger.info(
"[E2B] Reconnected to %.12s for session %.12s",
sandbox_id,
value,
session_id,
)
return sandbox
# _try_reconnect cleared the key — loop to create a new sandbox.
continue
# 2. Claim creation lock. If another request holds it, wait for the result.
claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
if not claimed:
for _ in range(_MAX_WAIT_ATTEMPTS):
if value == _CREATING_SENTINEL:
# Another coroutine is creating — wait for it to finish.
await asyncio.sleep(0.5)
raw = await redis.get(redis_key)
if not raw:
break # Lock expired — fall through to retry creation
sandbox_id = raw if isinstance(raw, str) else raw.decode()
if sandbox_id != _CREATING:
sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
if sandbox:
return sandbox
break # Stale sandbox cleared — fall through to create
continue
# Try to claim creation lock again after waiting.
claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
if not claimed:
# Another process may have created a sandbox — try to use it.
raw = await redis.get(redis_key)
if raw:
sandbox_id = raw if isinstance(raw, str) else raw.decode()
if sandbox_id != _CREATING:
sandbox = await _try_reconnect(
sandbox_id, api_key, redis_key, timeout
)
if sandbox:
return sandbox
raise RuntimeError(
f"Could not acquire E2B creation lock for session {session_id[:12]}"
)
# 3. Create a new sandbox.
try:
sandbox = await AsyncSandbox.create(
template=template, api_key=api_key, timeout=timeout
# No sandbox and no active creation — atomically claim the creation slot.
claimed = await redis.set(
key, _CREATING_SENTINEL, nx=True, ex=_CREATION_LOCK_TTL
)
except Exception:
await redis.delete(redis_key)
raise
if not claimed:
# Race lost — another coroutine just claimed it.
await asyncio.sleep(0.1)
continue
await redis.setex(redis_key, timeout, sandbox.sandbox_id)
logger.info(
"[E2B] Created sandbox %.12s for session %.12s",
sandbox.sandbox_id,
session_id,
)
return sandbox
# We hold the slot — create the sandbox.
try:
sandbox = await AsyncSandbox.create(
template=template,
api_key=api_key,
timeout=timeout,
lifecycle={"on_timeout": on_timeout},
)
try:
await _set_stored_sandbox_id(session_id, sandbox.sandbox_id)
except Exception:
# Redis save failed — kill the sandbox to avoid leaking it.
with contextlib.suppress(Exception):
await sandbox.kill()
raise
except Exception:
# Release the creation slot so other callers can proceed.
await redis.delete(key)
raise
logger.info(
"[E2B] Created sandbox %.12s for session %.12s",
sandbox.sandbox_id,
session_id,
)
return sandbox
raise RuntimeError(f"Could not acquire E2B sandbox for session {session_id[:12]}")
async def kill_sandbox(session_id: str, api_key: str) -> bool:
"""Kill the E2B sandbox for *session_id* and clean up its Redis entry.
async def _act_on_sandbox(
session_id: str,
api_key: str,
action: str,
fn: Callable[[AsyncSandbox], Awaitable[Any]],
*,
clear_stored_id: bool = False,
) -> bool:
"""Connect to the sandbox for *session_id* and run *fn* on it.
Returns ``True`` if a sandbox was found and killed, ``False`` otherwise.
Safe to call even when no sandbox exists for the session.
Shared by ``pause_sandbox`` and ``kill_sandbox``. Returns ``True`` on
success, ``False`` when no sandbox is found or the action fails.
If *clear_stored_id* is ``True``, the sandbox_id is removed from Redis
only after the action succeeds so a failed kill can be retried.
"""
redis = await get_redis_async()
redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
raw = await redis.get(redis_key)
if not raw:
sandbox_id = await _get_stored_sandbox_id(session_id)
if not sandbox_id:
return False
sandbox_id = raw if isinstance(raw, str) else raw.decode()
await redis.delete(redis_key)
if sandbox_id == _CREATING:
return False
async def _run() -> None:
await fn(await AsyncSandbox.connect(sandbox_id, api_key=api_key))
try:
async def _connect_and_kill():
sandbox = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
await sandbox.kill()
await asyncio.wait_for(_connect_and_kill(), timeout=10)
await asyncio.wait_for(_run(), timeout=_E2B_API_TIMEOUT_SECONDS)
if clear_stored_id:
await _clear_stored_sandbox_id(session_id)
logger.info(
"[E2B] Killed sandbox %.12s for session %.12s",
"[E2B] %s sandbox %.12s for session %.12s",
action.capitalize(),
sandbox_id,
session_id,
)
return True
except Exception as exc:
logger.warning(
"[E2B] Failed to kill sandbox %.12s for session %.12s: %s",
"[E2B] Failed to %s sandbox %.12s for session %.12s: %s",
action,
sandbox_id,
session_id,
exc,
)
return False
async def pause_sandbox(session_id: str, api_key: str) -> bool:
"""Pause the E2B sandbox for *session_id* to stop billing between turns.
Paused sandboxes cost nothing and are resumed automatically by
``get_or_create_sandbox()`` on the next turn (via ``AsyncSandbox.connect()``).
The sandbox_id is kept in Redis so reconnection works seamlessly.
Prefer ``pause_sandbox_direct()`` when the sandbox object is already in
scope — it skips the Redis lookup and reconnect round-trip.
Returns ``True`` if the sandbox was found and paused, ``False`` otherwise.
Safe to call even when no sandbox exists for the session.
"""
return await _act_on_sandbox(session_id, api_key, "pause", lambda sb: sb.pause())
async def pause_sandbox_direct(sandbox: "AsyncSandbox", session_id: str) -> bool:
"""Pause an already-connected sandbox without a reconnect round-trip.
Use this in callers that already hold the live sandbox object (e.g. turn
teardown in ``service.py``). Saves the Redis lookup and
``AsyncSandbox.connect()`` call that ``pause_sandbox()`` would make.
Returns ``True`` on success, ``False`` on failure or timeout.
"""
try:
await asyncio.wait_for(sandbox.pause(), timeout=_E2B_API_TIMEOUT_SECONDS)
logger.info(
"[E2B] Paused sandbox %.12s for session %.12s",
sandbox.sandbox_id,
session_id,
)
return True
except Exception as exc:
logger.warning(
"[E2B] Failed to pause sandbox %.12s for session %.12s: %s",
sandbox.sandbox_id,
session_id,
exc,
)
return False
async def kill_sandbox(
session_id: str,
api_key: str,
) -> bool:
"""Kill the E2B sandbox for *session_id* and clear its Redis entry.
Returns ``True`` if a sandbox was found and killed, ``False`` otherwise.
Safe to call even when no sandbox exists for the session.
"""
return await _act_on_sandbox(
session_id,
api_key,
"kill",
lambda sb: sb.kill(),
clear_stored_id=True,
)

View File

@@ -1,6 +1,12 @@
"""Tests for e2b_sandbox: get_or_create_sandbox, _try_reconnect, kill_sandbox.
Uses mock Redis and mock AsyncSandbox — no external dependencies.
sandbox_id is stored in Redis under _SANDBOX_KEY_PREFIX + session_id.
The same key doubles as a creation lock via a "creating" sentinel value.
Tests mock:
- ``get_redis_async`` (sandbox key storage + creation lock sentinel)
- ``AsyncSandbox`` (E2B SDK)
Tests are synchronous (using asyncio.run) to avoid conflicts with the
session-scoped event loop in conftest.py.
"""
@@ -11,36 +17,50 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from .e2b_sandbox import (
_CREATING,
_SANDBOX_REDIS_PREFIX,
_CREATING_SENTINEL,
_try_reconnect,
get_or_create_sandbox,
kill_sandbox,
pause_sandbox,
pause_sandbox_direct,
)
_KEY = f"{_SANDBOX_REDIS_PREFIX}sess-123"
_SESSION_ID = "sess-123"
_API_KEY = "test-api-key"
_SANDBOX_ID = "sb-abc"
_TIMEOUT = 300
def _mock_sandbox(sandbox_id: str = "sb-abc", running: bool = True) -> MagicMock:
def _mock_sandbox(sandbox_id: str = _SANDBOX_ID, running: bool = True) -> MagicMock:
sb = MagicMock()
sb.sandbox_id = sandbox_id
sb.is_running = AsyncMock(return_value=running)
sb.pause = AsyncMock()
sb.kill = AsyncMock()
return sb
def _mock_redis(get_val: str | bytes | None = None, set_nx_result: bool = True):
def _mock_redis(
set_nx_result: bool = True,
stored_sandbox_id: str | None = None,
) -> AsyncMock:
"""Create a mock redis client.
*stored_sandbox_id* is returned by ``get()`` calls (simulates the sandbox_id
stored under the ``_SANDBOX_KEY_PREFIX`` key). ``set_nx_result`` controls
whether the creation-slot ``SET NX`` succeeds.
If *stored_sandbox_id* is None the key is absent (no sandbox, no lock).
"""
r = AsyncMock()
r.get = AsyncMock(return_value=get_val)
raw = stored_sandbox_id.encode() if stored_sandbox_id else None
r.get = AsyncMock(return_value=raw)
r.set = AsyncMock(return_value=set_nx_result)
r.setex = AsyncMock()
r.delete = AsyncMock()
r.expire = AsyncMock()
return r
def _patch_redis(redis):
def _patch_redis(redis: AsyncMock):
return patch(
"backend.copilot.tools.e2b_sandbox.get_redis_async",
new_callable=AsyncMock,
@@ -55,6 +75,7 @@ def _patch_redis(redis):
class TestTryReconnect:
def test_reconnect_success(self):
"""Returns the sandbox when it connects and is running; refreshes Redis TTL."""
sb = _mock_sandbox()
redis = _mock_redis()
with (
@@ -62,36 +83,39 @@ class TestTryReconnect:
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
result = asyncio.run(_try_reconnect(_SANDBOX_ID, _SESSION_ID, _API_KEY))
assert result is sb
redis.expire.assert_awaited_once_with(_KEY, _TIMEOUT)
redis.delete.assert_not_awaited()
# TTL must be refreshed so an active session cannot lose its key at expiry.
redis.set.assert_awaited_once()
def test_reconnect_not_running_clears_key(self):
def test_reconnect_not_running_clears_redis(self):
"""Clears sandbox_id in Redis when the sandbox is no longer running."""
sb = _mock_sandbox(running=False)
redis = _mock_redis()
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
result = asyncio.run(_try_reconnect(_SANDBOX_ID, _SESSION_ID, _API_KEY))
assert result is None
redis.delete.assert_awaited_once_with(_KEY)
redis.delete.assert_awaited_once()
def test_reconnect_exception_clears_key(self):
redis = _mock_redis()
def test_reconnect_exception_clears_redis(self):
"""Clears sandbox_id in Redis when connect raises an exception."""
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
result = asyncio.run(_try_reconnect(_SANDBOX_ID, _SESSION_ID, _API_KEY))
assert result is None
redis.delete.assert_awaited_once_with(_KEY)
redis.delete.assert_awaited_once()
# ---------------------------------------------------------------------------
@@ -103,38 +127,63 @@ class TestGetOrCreateSandbox:
def test_reconnect_existing(self):
"""When Redis has a valid sandbox_id, reconnect to it."""
sb = _mock_sandbox()
redis = _mock_redis(get_val="sb-abc")
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
assert result is sb
mock_cls.create.assert_not_called()
# redis.set called once to refresh TTL, not to claim a creation slot
redis.set.assert_awaited_once()
def test_create_new_when_no_key(self):
"""When Redis is empty, claim lock and create a new sandbox."""
sb = _mock_sandbox("sb-new")
redis = _mock_redis(get_val=None, set_nx_result=True)
def test_create_new_when_no_stored_id(self):
"""When Redis has no sandbox_id, claim slot and create a new sandbox."""
new_sb = _mock_sandbox("sb-new")
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.create = AsyncMock(return_value=sb)
mock_cls.create = AsyncMock(return_value=new_sb)
result = asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
assert result is sb
redis.setex.assert_awaited_once_with(_KEY, _TIMEOUT, "sb-new")
assert result is new_sb
mock_cls.create.assert_awaited_once()
# Verify lifecycle param is set
_, kwargs = mock_cls.create.call_args
assert kwargs.get("lifecycle") == {"on_timeout": "pause"}
# sandbox_id should be saved to Redis
redis.set.assert_awaited()
def test_create_failure_clears_lock(self):
"""If sandbox creation fails, the Redis lock is deleted."""
redis = _mock_redis(get_val=None, set_nx_result=True)
def test_create_with_on_timeout_kill(self):
"""on_timeout='kill' is passed through to AsyncSandbox.create."""
new_sb = _mock_sandbox("sb-new")
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.create = AsyncMock(return_value=new_sb)
asyncio.run(
get_or_create_sandbox(
_SESSION_ID, _API_KEY, timeout=_TIMEOUT, on_timeout="kill"
)
)
_, kwargs = mock_cls.create.call_args
assert kwargs.get("lifecycle") == {"on_timeout": "kill"}
def test_create_failure_releases_slot(self):
"""If sandbox creation fails, the Redis creation slot is deleted."""
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
@@ -142,17 +191,53 @@ class TestGetOrCreateSandbox:
mock_cls.create = AsyncMock(side_effect=RuntimeError("quota"))
with pytest.raises(RuntimeError, match="quota"):
asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
redis.delete.assert_awaited_once_with(_KEY)
redis.delete.assert_awaited_once()
def test_wait_for_lock_then_reconnect(self):
"""When another process holds the lock, wait and reconnect."""
def test_redis_save_failure_kills_sandbox_and_releases_slot(self):
"""If Redis save fails after creation, sandbox is killed and slot released."""
new_sb = _mock_sandbox("sb-new")
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
# First set() call = creation slot SET NX (returns True).
# Second set() call = sandbox_id save (raises).
call_count = 0
async def _set_side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return True # creation slot claimed
raise RuntimeError("redis error")
redis.set = AsyncMock(side_effect=_set_side_effect)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.create = AsyncMock(return_value=new_sb)
with pytest.raises(RuntimeError, match="redis error"):
asyncio.run(
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
# Sandbox must be killed to avoid leaking it
new_sb.kill.assert_awaited_once()
# Creation slot must always be released
redis.delete.assert_awaited_once()
def test_wait_for_creating_sentinel_then_reconnect(self):
"""When the key holds the 'creating' sentinel, wait then reconnect."""
sb = _mock_sandbox("sb-other")
redis = _mock_redis()
redis.get = AsyncMock(side_effect=[_CREATING, "sb-other"])
# First get() returns the sentinel; second returns the real ID.
redis = AsyncMock()
creating_raw = _CREATING_SENTINEL.encode()
redis.get = AsyncMock(side_effect=[creating_raw, b"sb-other"])
redis.set = AsyncMock(return_value=False)
redis.delete = AsyncMock()
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
@@ -163,16 +248,21 @@ class TestGetOrCreateSandbox:
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
assert result is sb
def test_stale_reconnect_clears_and_creates(self):
"""When stored sandbox is stale, clear key and create a new one."""
"""When stored sandbox is stale (not running), clear it and create a new one."""
stale_sb = _mock_sandbox("sb-stale", running=False)
new_sb = _mock_sandbox("sb-fresh")
redis = _mock_redis(get_val="sb-stale", set_nx_result=True)
# First get() returns stale id (for reconnect check), then None (after clear).
redis = AsyncMock()
redis.get = AsyncMock(side_effect=[b"sb-stale", None])
redis.set = AsyncMock(return_value=True)
redis.delete = AsyncMock()
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
@@ -180,10 +270,11 @@ class TestGetOrCreateSandbox:
mock_cls.connect = AsyncMock(return_value=stale_sb)
mock_cls.create = AsyncMock(return_value=new_sb)
result = asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
assert result is new_sb
# Redis delete called at least once to clear stale id
redis.delete.assert_awaited()
@@ -194,70 +285,48 @@ class TestGetOrCreateSandbox:
class TestKillSandbox:
def test_kill_existing_sandbox(self):
"""Kill a running sandbox and clean up Redis."""
"""Kill a running sandbox and clear its Redis entry."""
sb = _mock_sandbox()
sb.kill = AsyncMock()
redis = _mock_redis(get_val="sb-abc")
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
assert result is True
redis.delete.assert_awaited_once_with(_KEY)
sb.kill.assert_awaited_once()
# Redis key cleared after successful kill
redis.delete.assert_awaited_once()
def test_kill_no_sandbox(self):
"""No-op when no sandbox exists in Redis."""
redis = _mock_redis(get_val=None)
"""No-op when Redis has no sandbox_id."""
redis = _mock_redis(stored_sandbox_id=None)
with _patch_redis(redis):
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
assert result is False
redis.delete.assert_not_awaited()
def test_kill_creating_state(self):
"""Clears Redis key but returns False when sandbox is still being created."""
redis = _mock_redis(get_val=_CREATING)
with _patch_redis(redis):
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
def test_kill_connect_failure_keeps_redis(self):
"""Returns False and leaves Redis entry intact when connect/kill fails.
assert result is False
redis.delete.assert_awaited_once_with(_KEY)
def test_kill_connect_failure(self):
"""Returns False and cleans Redis if connect/kill fails."""
redis = _mock_redis(get_val="sb-abc")
Keeping the sandbox_id in Redis allows the kill to be retried.
"""
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
assert result is False
redis.delete.assert_awaited_once_with(_KEY)
redis.delete.assert_not_awaited()
def test_kill_with_bytes_redis_value(self):
"""Redis may return bytes — kill_sandbox should decode correctly."""
sb = _mock_sandbox()
sb.kill = AsyncMock()
redis = _mock_redis(get_val=b"sb-abc")
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
assert result is True
sb.kill.assert_awaited_once()
def test_kill_timeout_returns_false(self):
"""Returns False when E2B API calls exceed the 10s timeout."""
redis = _mock_redis(get_val="sb-abc")
def test_kill_timeout_keeps_redis(self):
"""Returns False and leaves Redis entry intact when the E2B call times out."""
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
_patch_redis(redis),
patch(
@@ -266,7 +335,146 @@ class TestKillSandbox:
side_effect=asyncio.TimeoutError,
),
):
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
assert result is False
redis.delete.assert_not_awaited()
def test_kill_creating_sentinel_returns_false(self):
"""No-op when the key holds the 'creating' sentinel (no real sandbox yet)."""
redis = _mock_redis(stored_sandbox_id=_CREATING_SENTINEL)
with _patch_redis(redis):
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
assert result is False
# ---------------------------------------------------------------------------
# pause_sandbox
# ---------------------------------------------------------------------------
class TestPauseSandbox:
def test_pause_existing_sandbox(self):
"""Pause a running sandbox; Redis sandbox_id is preserved."""
sb = _mock_sandbox()
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
assert result is True
sb.pause.assert_awaited_once()
# sandbox_id should remain in Redis (not cleared on pause)
redis.delete.assert_not_awaited()
def test_pause_no_sandbox(self):
"""No-op when Redis has no sandbox_id."""
redis = _mock_redis(stored_sandbox_id=None)
with _patch_redis(redis):
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
assert result is False
def test_pause_connect_failure(self):
"""Returns False if connect fails."""
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
assert result is False
def test_pause_creating_sentinel_returns_false(self):
"""No-op when the key holds the 'creating' sentinel (no real sandbox yet)."""
redis = _mock_redis(stored_sandbox_id=_CREATING_SENTINEL)
with _patch_redis(redis):
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
assert result is False
def test_pause_timeout_returns_false(self):
"""Returns False and preserves Redis entry when the E2B API call times out."""
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
_patch_redis(redis),
patch(
"backend.copilot.tools.e2b_sandbox.asyncio.wait_for",
new_callable=AsyncMock,
side_effect=asyncio.TimeoutError,
),
):
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
assert result is False
# sandbox_id must remain in Redis so the next turn can reconnect
redis.delete.assert_not_awaited()
def test_pause_then_reconnect_reuses_sandbox(self):
"""After pause, get_or_create_sandbox reconnects the same sandbox.
Covers the pause->reconnect cycle: connect() auto-resumes a paused
sandbox, and is_running() returns True once resume completes, so the
same sandbox_id is reused rather than a new one being created.
"""
sb = _mock_sandbox(_SANDBOX_ID)
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
# Step 1: pause the sandbox
paused = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
assert paused is True
sb.pause.assert_awaited_once()
# Step 2: reconnect on next turn -- same sandbox should be returned
result = asyncio.run(
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
)
assert result is sb
mock_cls.create.assert_not_called()
# ---------------------------------------------------------------------------
# pause_sandbox_direct
# ---------------------------------------------------------------------------
class TestPauseSandboxDirect:
def test_pause_direct_success(self):
"""Pauses the sandbox directly without a Redis lookup or reconnect."""
sb = _mock_sandbox()
result = asyncio.run(pause_sandbox_direct(sb, _SESSION_ID))
assert result is True
sb.pause.assert_awaited_once()
def test_pause_direct_failure_returns_false(self):
"""Returns False when sandbox.pause() raises."""
sb = _mock_sandbox()
sb.pause = AsyncMock(side_effect=RuntimeError("e2b error"))
result = asyncio.run(pause_sandbox_direct(sb, _SESSION_ID))
assert result is False
def test_pause_direct_timeout_returns_false(self):
"""Returns False when sandbox.pause() exceeds the 10s timeout."""
sb = _mock_sandbox()
with patch(
"backend.copilot.tools.e2b_sandbox.asyncio.wait_for",
new_callable=AsyncMock,
side_effect=asyncio.TimeoutError,
):
result = asyncio.run(pause_sandbox_direct(sb, _SESSION_ID))
assert result is False
redis.delete.assert_awaited_once_with(_KEY)

View File

@@ -1,32 +1,20 @@
"""EditAgentTool - Edits existing agents using natural language."""
"""EditAgentTool - Edits existing agents using pre-built JSON."""
import logging
from typing import Any
from backend.copilot.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,
generate_agent_patch,
get_agent_as_json,
get_user_message_for_error,
save_agent_to_library,
)
from .agent_generator import get_agent_as_json
from .agent_generator.pipeline import fetch_library_agents, fix_validate_and_save
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
ToolResponseBase,
)
from .models import ErrorResponse, ToolResponseBase
logger = logging.getLogger(__name__)
class EditAgentTool(BaseTool):
"""Tool for editing existing agents using natural language."""
"""Tool for editing existing agents using pre-built JSON."""
@property
def name(self) -> str:
@@ -35,11 +23,12 @@ class EditAgentTool(BaseTool):
@property
def description(self) -> str:
return (
"Edit an existing agent from the user's library using natural language. "
"Generates updates to the agent while preserving unchanged parts. "
"\n\nIMPORTANT: Before calling this tool, if the changes involve adding new "
"Edit an existing agent. Pass `agent_json` with the complete "
"updated agent JSON you generated. The tool validates, auto-fixes, "
"and saves.\n\n"
"IMPORTANT: Before calling this tool, if the changes involve adding new "
"functionality, search for relevant existing agents using find_library_agent "
"that could be used as building blocks. Pass their IDs in library_agent_ids."
"that could be used as building blocks."
)
@property
@@ -58,26 +47,20 @@ class EditAgentTool(BaseTool):
"Can be a graph ID or library agent ID."
),
},
"changes": {
"type": "string",
"agent_json": {
"type": "object",
"description": (
"Natural language description of what changes to make. "
"Be specific about what to add, remove, or modify."
),
},
"context": {
"type": "string",
"description": (
"Additional context or answers to previous clarifying questions."
"Complete updated agent JSON to validate and save. "
"Must contain 'nodes' and 'links'. "
"Include 'name' and/or 'description' if they need "
"to be updated."
),
},
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"List of library agent IDs to use as building blocks for the changes. "
"If adding new functionality, search for relevant agents using "
"find_library_agent first, then pass their IDs here."
"List of library agent IDs to use as building blocks for the changes."
),
},
"save": {
@@ -89,7 +72,7 @@ class EditAgentTool(BaseTool):
"default": True,
},
},
"required": ["agent_id", "changes"],
"required": ["agent_id", "agent_json"],
}
async def _execute(
@@ -98,36 +81,39 @@ class EditAgentTool(BaseTool):
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the edit_agent tool.
Flow:
1. Fetch the current agent
2. Generate updated agent (external service handles fixing and validation)
3. Preview or save based on the save parameter
"""
agent_id = kwargs.get("agent_id", "").strip()
changes = kwargs.get("changes", "").strip()
context = kwargs.get("context", "")
library_agent_ids = kwargs.get("library_agent_ids", [])
save = kwargs.get("save", True)
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
session_id = session.session_id if session else None
if not agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",
error="Missing agent_id parameter",
error="missing_agent_id",
session_id=session_id,
)
if not changes:
if not agent_json:
return ErrorResponse(
message="Please describe what changes you want to make.",
error="Missing changes parameter",
message=(
"Please provide agent_json with the complete updated agent graph."
),
error="missing_agent_json",
session_id=session_id,
)
current_agent = await get_agent_as_json(agent_id, user_id)
save = kwargs.get("save", True)
library_agent_ids = kwargs.get("library_agent_ids", [])
nodes = agent_json.get("nodes", [])
if not nodes:
return ErrorResponse(
message="The agent JSON has no nodes.",
error="empty_agent",
session_id=session_id,
)
# Preserve original agent's ID
current_agent = await get_agent_as_json(agent_id, user_id)
if current_agent is None:
return ErrorResponse(
message=f"Could not find agent with ID '{agent_id}' in your library.",
@@ -135,142 +121,19 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Fetch library agents by IDs if provided
library_agents = None
if user_id and library_agent_ids:
try:
from .agent_generator import get_library_agents_by_ids
agent_json["id"] = current_agent.get("id", agent_id)
agent_json["version"] = current_agent.get("version", 1)
agent_json.setdefault("is_active", True)
graph_id = current_agent.get("id")
# Filter out the current agent being edited
filtered_ids = [id for id in library_agent_ids if id != graph_id]
# Fetch library agents for AgentExecutorBlock validation
library_agents = await fetch_library_agents(user_id, library_agent_ids)
library_agents = await get_library_agents_by_ids(
user_id=user_id,
agent_ids=filtered_ids,
)
logger.debug(
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents by IDs: {e}")
update_request = changes
if context:
update_request = f"{changes}\n\nAdditional context:\n{context}"
try:
result = await generate_agent_patch(
update_request,
current_agent,
library_agents,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
"Agent editing is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
if result is None:
return ErrorResponse(
message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.",
error="update_generation_failed",
details={"agent_id": agent_id, "changes": changes[:100]},
session_id=session_id,
)
# Check if the result is an error from the external service
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="generate the changes",
llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
validation_message="The generated changes failed validation. Please try rephrasing your request.",
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"update_generation_failed:{error_type}",
details={
"agent_id": agent_id,
"changes": changes[:100],
"service_error": error_msg,
"error_type": error_type,
},
session_id=session_id,
)
if result.get("type") == "clarifying_questions":
questions = result.get("questions", [])
return ClarificationNeededResponse(
message=(
"I need some more information about the changes. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
],
session_id=session_id,
)
updated_agent = result
agent_name = updated_agent.get("name", "Updated Agent")
agent_description = updated_agent.get("description", "")
node_count = len(updated_agent.get("nodes", []))
link_count = len(updated_agent.get("links", []))
if not save:
return AgentPreviewResponse(
message=(
f"I've updated the agent. "
f"The agent now has {node_count} blocks. "
f"Review it and call edit_agent with save=true to save the changes."
),
agent_json=updated_agent,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
try:
created_graph, library_agent = await save_agent_to_library(
updated_agent, user_id, is_update=True
)
return AgentSavedResponse(
message=f"Updated agent '{created_graph.name}' has been saved to your library!",
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to save the updated agent: {str(e)}",
error="save_failed",
details={"exception": str(e)},
session_id=session_id,
)
return await fix_validate_and_save(
agent_json,
user_id=user_id,
session_id=session_id,
save=save,
is_update=True,
default_name="Updated Agent",
library_agents=library_agents,
)

View File

@@ -32,7 +32,7 @@ COPILOT_EXCLUDED_BLOCK_TYPES = {
BlockType.NOTE, # Visual annotation only - no runtime behavior
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
BlockType.MCP_TOOL, # Has dedicated run_mcp_tool tool with proper discovery + auth flow
BlockType.MCP_TOOL, # Has dedicated run_mcp_tool tool with discovery + auth flow
}
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
@@ -72,6 +72,15 @@ class FindBlockTool(BaseTool):
"Use keywords like 'email', 'http', 'text', 'ai', etc."
),
},
"include_schemas": {
"type": "boolean",
"description": (
"If true, include full input_schema and output_schema "
"for each block. Use when generating agent JSON that "
"needs block schemas. Default is false."
),
"default": False,
},
},
"required": ["query"],
}
@@ -99,6 +108,7 @@ class FindBlockTool(BaseTool):
ErrorResponse: Error message
"""
query = kwargs.get("query", "").strip()
include_schemas = kwargs.get("include_schemas", False)
session_id = session.session_id
if not query:
@@ -143,15 +153,21 @@ class FindBlockTool(BaseTool):
):
continue
blocks.append(
BlockInfoSummary(
id=block_id,
name=block.name,
description=block.description or "",
categories=[c.value for c in block.categories],
)
summary = BlockInfoSummary(
id=block_id,
name=block.name,
description=block.optimized_description or block.description or "",
categories=[c.value for c in block.categories],
)
if include_schemas:
info = block.get_info()
summary.input_schema = info.inputSchema
summary.output_schema = info.outputSchema
summary.static_output = info.staticOutput
blocks.append(summary)
if len(blocks) >= _TARGET_RESULTS:
break

View File

@@ -25,6 +25,7 @@ def make_mock_block(
input_schema: dict | None = None,
output_schema: dict | None = None,
credentials_fields: dict | None = None,
static_output: bool = False,
):
"""Create a mock block for testing."""
mock = MagicMock()
@@ -33,6 +34,7 @@ def make_mock_block(
mock.description = f"{name} description"
mock.block_type = block_type
mock.disabled = disabled
mock.static_output = static_output
mock.input_schema = MagicMock()
mock.input_schema.jsonschema.return_value = input_schema or {
"properties": {},
@@ -42,6 +44,15 @@ def make_mock_block(
mock.output_schema = MagicMock()
mock.output_schema.jsonschema.return_value = output_schema or {}
mock.categories = []
mock.optimized_description = None
# Mock get_info() for include_schemas support
mock_info = MagicMock()
mock_info.inputSchema = input_schema or {"properties": {}, "required": []}
mock_info.outputSchema = output_schema or {}
mock_info.staticOutput = static_output
mock.get_info.return_value = mock_info
return mock
@@ -399,3 +410,92 @@ class TestFindBlockFiltering:
f"Average chars per block ({avg_chars}) exceeds 500. "
f"Total response: {total_chars} chars for {response.count} blocks."
)
@pytest.mark.asyncio(loop_scope="session")
async def test_include_schemas_false_omits_schemas(self):
"""Without include_schemas, schemas should be empty dicts."""
session = make_session(user_id=_TEST_USER_ID)
input_schema = {"properties": {"url": {"type": "string"}}, "required": ["url"]}
output_schema = {"properties": {"result": {"type": "string"}}}
search_results = [{"content_id": "block-1", "score": 0.9}]
block = make_mock_block(
"block-1",
"Test Block",
BlockType.STANDARD,
input_schema=input_schema,
output_schema=output_schema,
)
mock_search_db = MagicMock()
mock_search_db.unified_hybrid_search = AsyncMock(
return_value=(search_results, 1)
)
with (
patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
),
patch(
"backend.copilot.tools.find_block.get_block",
return_value=block,
),
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
query="test",
include_schemas=False,
)
assert isinstance(response, BlockListResponse)
assert response.blocks[0].input_schema == {}
assert response.blocks[0].output_schema == {}
assert response.blocks[0].static_output is False
@pytest.mark.asyncio(loop_scope="session")
async def test_include_schemas_true_populates_schemas(self):
"""With include_schemas=true, schemas should be populated from block info."""
session = make_session(user_id=_TEST_USER_ID)
input_schema = {"properties": {"url": {"type": "string"}}, "required": ["url"]}
output_schema = {"properties": {"result": {"type": "string"}}}
search_results = [{"content_id": "block-1", "score": 0.9}]
block = make_mock_block(
"block-1",
"Test Block",
BlockType.STANDARD,
input_schema=input_schema,
output_schema=output_schema,
static_output=True,
)
mock_search_db = MagicMock()
mock_search_db.unified_hybrid_search = AsyncMock(
return_value=(search_results, 1)
)
with (
patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
),
patch(
"backend.copilot.tools.find_block.get_block",
return_value=block,
),
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
query="test",
include_schemas=True,
)
assert isinstance(response, BlockListResponse)
assert response.blocks[0].input_schema == input_schema
assert response.blocks[0].output_schema == output_schema
assert response.blocks[0].static_output is True

View File

@@ -22,6 +22,9 @@ class FindLibraryAgentTool(BaseTool):
"Search for or list agents in the user's library. Use this to find "
"agents the user has already added to their library, including agents "
"they created or added from the marketplace. "
"When creating agents with sub-agent composition, use this to get "
"the agent's graph_id, graph_version, input_schema, and output_schema "
"needed for AgentExecutorBlock nodes. "
"Omit the query to list all agents."
)

View File

@@ -0,0 +1,134 @@
"""FixAgentGraphTool - Auto-fixes common agent JSON issues."""
import logging
from typing import Any
from backend.copilot.model import ChatSession
from .agent_generator.validation import AgentFixer, AgentValidator, get_blocks_as_dicts
from .base import BaseTool
from .models import ErrorResponse, FixResultResponse, ToolResponseBase
logger = logging.getLogger(__name__)
class FixAgentGraphTool(BaseTool):
"""Tool for auto-fixing common issues in agent JSON graphs."""
@property
def name(self) -> str:
return "fix_agent_graph"
@property
def description(self) -> str:
return (
"Auto-fix common issues in an agent JSON graph. Applies fixes for:\n"
"- Missing or invalid UUIDs on nodes and links\n"
"- StoreValueBlock prerequisites for ConditionBlock\n"
"- Double curly brace escaping in prompt templates\n"
"- AddToList/AddToDictionary prerequisite blocks\n"
"- CodeExecutionBlock output field naming\n"
"- Missing credentials configuration\n"
"- Node X coordinate spacing (800+ units apart)\n"
"- AI model default parameters\n"
"- Link static properties based on input schema\n"
"- Type mismatches (inserts conversion blocks)\n\n"
"Returns the fixed agent JSON plus a list of fixes applied. "
"After fixing, the agent is re-validated. If still invalid, "
"the remaining errors are included in the response."
)
@property
def requires_auth(self) -> bool:
return False
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"agent_json": {
"type": "object",
"description": (
"The agent JSON to fix. Must contain 'nodes' and 'links' arrays."
),
},
},
"required": ["agent_json"],
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
agent_json = kwargs.get("agent_json")
session_id = session.session_id if session else None
if not agent_json or not isinstance(agent_json, dict):
return ErrorResponse(
message="Please provide a valid agent JSON object.",
error="Missing or invalid agent_json parameter",
session_id=session_id,
)
nodes = agent_json.get("nodes", [])
if not nodes:
return ErrorResponse(
message="The agent JSON has no nodes. An agent needs at least one block.",
error="empty_agent",
session_id=session_id,
)
try:
blocks = get_blocks_as_dicts()
fixer = AgentFixer()
fixed_agent = fixer.apply_all_fixes(agent_json, blocks)
fixes_applied = fixer.get_fixes_applied()
except Exception as e:
logger.error(f"Fixer error: {e}", exc_info=True)
return ErrorResponse(
message=f"Auto-fix encountered an error: {str(e)}",
error="fix_exception",
session_id=session_id,
)
# Re-validate after fixing
try:
validator = AgentValidator()
is_valid, _ = validator.validate(fixed_agent, blocks)
remaining_errors = validator.errors if not is_valid else []
except Exception as e:
logger.warning(f"Post-fix validation error: {e}", exc_info=True)
remaining_errors = [f"Post-fix validation failed: {str(e)}"]
is_valid = False
if is_valid:
return FixResultResponse(
message=(
f"Applied {len(fixes_applied)} fix(es). "
"Agent graph is now valid!"
),
fixed_agent_json=fixed_agent,
fixes_applied=fixes_applied,
fix_count=len(fixes_applied),
valid_after_fix=True,
remaining_errors=[],
session_id=session_id,
)
return FixResultResponse(
message=(
f"Applied {len(fixes_applied)} fix(es), but "
f"{len(remaining_errors)} issue(s) remain. "
"Review the remaining errors and fix manually."
),
fixed_agent_json=fixed_agent,
fixes_applied=fixes_applied,
fix_count=len(fixes_applied),
valid_after_fix=False,
remaining_errors=remaining_errors,
session_id=session_id,
)

View File

@@ -0,0 +1,189 @@
"""Tests for FixAgentGraphTool."""
from unittest.mock import MagicMock, patch
import pytest
from backend.copilot.tools.fix_agent import FixAgentGraphTool
from backend.copilot.tools.models import ErrorResponse, FixResultResponse
from ._test_data import make_session
_TEST_USER_ID = "test-user-fix-agent"
@pytest.fixture
def tool():
return FixAgentGraphTool()
@pytest.fixture
def session():
return make_session(_TEST_USER_ID)
@pytest.mark.asyncio
async def test_missing_agent_json_returns_error(tool, session):
"""Missing agent_json returns ErrorResponse."""
result = await tool._execute(user_id=_TEST_USER_ID, session=session)
assert isinstance(result, ErrorResponse)
assert result.error is not None
assert "agent_json" in result.error.lower()
@pytest.mark.asyncio
async def test_empty_nodes_returns_error(tool, session):
"""Agent JSON with no nodes returns ErrorResponse."""
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json={"nodes": [], "links": []},
)
assert isinstance(result, ErrorResponse)
assert "no nodes" in result.message.lower()
@pytest.mark.asyncio
async def test_fix_and_validate_success(tool, session):
"""Fixer applies fixes and validator passes -> valid_after_fix=True."""
agent_json = {
"nodes": [
{
"id": "node-1",
"block_id": "block-1",
"input_default": {},
"metadata": {"position": {"x": 0, "y": 0}},
}
],
"links": [],
}
fixed_agent = dict(agent_json) # Fixer returns the full agent dict
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=fixed_agent)
mock_fixer.get_fixes_applied.return_value = ["Fixed node UUID format"]
mock_validator = MagicMock()
mock_validator.validate.return_value = (True, None)
mock_validator.errors = []
with (
patch(
"backend.copilot.tools.fix_agent.get_blocks_as_dicts",
return_value=[],
),
patch(
"backend.copilot.tools.fix_agent.AgentFixer",
return_value=mock_fixer,
),
patch(
"backend.copilot.tools.fix_agent.AgentValidator",
return_value=mock_validator,
),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
)
assert isinstance(result, FixResultResponse)
assert result.valid_after_fix is True
assert result.fix_count == 1
assert result.fixes_applied == ["Fixed node UUID format"]
assert result.remaining_errors == []
@pytest.mark.asyncio
async def test_fix_with_remaining_errors(tool, session):
"""Fixer applies some fixes but validation still fails."""
agent_json = {
"nodes": [
{
"id": "node-1",
"block_id": "block-1",
"input_default": {},
"metadata": {},
}
],
"links": [
{
"id": "link-1",
"source_id": "node-1",
"source_name": "output",
"sink_id": "node-2",
"sink_name": "input",
}
],
}
fixed_agent = dict(agent_json)
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(return_value=fixed_agent)
mock_fixer.get_fixes_applied.return_value = ["Fixed UUID"]
mock_validator = MagicMock()
mock_validator.validate.return_value = (
False,
"Link references non-existent node 'node-2'",
)
mock_validator.errors = ["Link references non-existent node 'node-2'"]
with (
patch(
"backend.copilot.tools.fix_agent.get_blocks_as_dicts",
return_value=[],
),
patch(
"backend.copilot.tools.fix_agent.AgentFixer",
return_value=mock_fixer,
),
patch(
"backend.copilot.tools.fix_agent.AgentValidator",
return_value=mock_validator,
),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
)
assert isinstance(result, FixResultResponse)
assert result.valid_after_fix is False
assert result.fix_count == 1
assert len(result.remaining_errors) == 1
@pytest.mark.asyncio
async def test_fixer_exception_returns_error(tool, session):
"""Fixer exception returns ErrorResponse."""
agent_json = {
"nodes": [{"id": "n1", "block_id": "b1", "input_default": {}, "metadata": {}}],
"links": [],
}
mock_fixer = MagicMock()
mock_fixer.apply_all_fixes = MagicMock(side_effect=RuntimeError("fixer crashed"))
with (
patch(
"backend.copilot.tools.fix_agent.get_blocks_as_dicts",
return_value=[],
),
patch(
"backend.copilot.tools.fix_agent.AgentFixer",
return_value=mock_fixer,
),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
)
assert isinstance(result, ErrorResponse)
assert result.error is not None
assert "fix_exception" in result.error

View File

@@ -0,0 +1,84 @@
"""GetAgentBuildingGuideTool - Returns the complete agent building guide."""
import logging
from pathlib import Path
from typing import Any
from backend.copilot.model import ChatSession
from .base import BaseTool
from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
_GUIDE_CACHE: str | None = None
def _load_guide() -> str:
global _GUIDE_CACHE
if _GUIDE_CACHE is None:
guide_path = Path(__file__).parent.parent / "sdk" / "agent_generation_guide.md"
_GUIDE_CACHE = guide_path.read_text(encoding="utf-8")
return _GUIDE_CACHE
class AgentBuildingGuideResponse(ToolResponseBase):
"""Response containing the agent building guide."""
type: ResponseType = ResponseType.AGENT_BUILDER_GUIDE
content: str
class GetAgentBuildingGuideTool(BaseTool):
"""Returns the complete guide for building agent JSON graphs.
Covers block IDs, link structure, AgentInputBlock, AgentOutputBlock,
AgentExecutorBlock (sub-agent composition), and MCPToolBlock usage.
"""
@property
def name(self) -> str:
return "get_agent_building_guide"
@property
def description(self) -> str:
return (
"Returns the complete guide for building agent JSON graphs, including "
"block IDs, link structure, AgentInputBlock, AgentOutputBlock, "
"AgentExecutorBlock (for sub-agent composition), and MCPToolBlock usage. "
"Call this before generating agent JSON to ensure correct structure."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {},
"required": [],
}
@property
def requires_auth(self) -> bool:
return False
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id if session else None
try:
content = _load_guide()
return AgentBuildingGuideResponse(
message="Agent building guide loaded.",
content=content,
session_id=session_id,
)
except Exception as e:
logger.error("Failed to load agent building guide: %s", e)
return ErrorResponse(
message="Failed to load agent building guide.",
error=str(e),
session_id=session_id,
)

View File

@@ -0,0 +1,79 @@
"""GetMCPGuideTool - Returns the MCP tool usage guide."""
import logging
from pathlib import Path
from typing import Any
from backend.copilot.model import ChatSession
from .base import BaseTool
from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
_GUIDE_CACHE: str | None = None
def _load_guide() -> str:
global _GUIDE_CACHE
if _GUIDE_CACHE is None:
guide_path = Path(__file__).parent.parent / "sdk" / "mcp_tool_guide.md"
_GUIDE_CACHE = guide_path.read_text(encoding="utf-8")
return _GUIDE_CACHE
class MCPGuideResponse(ToolResponseBase):
"""Response containing the MCP tool guide."""
type: ResponseType = ResponseType.MCP_GUIDE
content: str
class GetMCPGuideTool(BaseTool):
"""Returns the MCP tool usage guide with known server URLs and auth details."""
@property
def name(self) -> str:
return "get_mcp_guide"
@property
def description(self) -> str:
return (
"Returns the MCP tool guide: known hosted server URLs (Notion, Linear, "
"Stripe, Intercom, Cloudflare, Atlassian) and authentication workflow. "
"Call before using run_mcp_tool if you need a server URL or auth info."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {},
"required": [],
}
@property
def requires_auth(self) -> bool:
return False
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id if session else None
try:
content = _load_guide()
return MCPGuideResponse(
message="MCP guide loaded.",
content=content,
session_id=session_id,
)
except Exception as e:
logger.error("Failed to load MCP guide: %s", e)
return ErrorResponse(
message="Failed to load MCP guide.",
error=str(e),
session_id=session_id,
)

View File

@@ -1,7 +1,24 @@
"""Shared helpers for chat tools."""
import logging
from collections import defaultdict
from typing import Any
from pydantic_core import PydanticUndefined
from backend.blocks._base import AnyBlockSchema
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
from backend.data.db_accessors import workspace_db
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
from .utils import match_credentials_to_requirements
logger = logging.getLogger(__name__)
def get_inputs_from_schema(
input_schema: dict[str, Any],
@@ -27,3 +44,159 @@ def get_inputs_from_schema(
for name, schema in properties.items()
if name not in exclude
]
async def execute_block(
*,
block: AnyBlockSchema,
block_id: str,
input_data: dict[str, Any],
user_id: str,
session_id: str,
node_exec_id: str,
matched_credentials: dict[str, CredentialsMetaInput],
sensitive_action_safe_mode: bool = False,
) -> ToolResponseBase:
"""Execute a block with full context setup, credential injection, and error handling.
This is the shared execution path used by both ``run_block`` (after review
check) and ``continue_run_block`` (after approval).
Returns:
BlockOutputResponse on success, ErrorResponse on failure.
"""
try:
workspace = await workspace_db().get_or_create_workspace(user_id)
synthetic_graph_id = f"{COPILOT_SESSION_PREFIX}{session_id}"
synthetic_node_id = f"{COPILOT_NODE_PREFIX}{block_id}"
execution_context = ExecutionContext(
user_id=user_id,
graph_id=synthetic_graph_id,
graph_exec_id=synthetic_graph_id,
graph_version=1,
node_id=synthetic_node_id,
node_exec_id=node_exec_id,
workspace_id=workspace.id,
session_id=session_id,
sensitive_action_safe_mode=sensitive_action_safe_mode,
)
exec_kwargs: dict[str, Any] = {
"user_id": user_id,
"execution_context": execution_context,
"workspace_id": workspace.id,
"graph_exec_id": synthetic_graph_id,
"node_exec_id": node_exec_id,
"node_id": synthetic_node_id,
"graph_version": 1,
"graph_id": synthetic_graph_id,
}
# Inject credentials
creds_manager = IntegrationCredentialsManager()
for field_name, cred_meta in matched_credentials.items():
if field_name not in input_data:
input_data[field_name] = cred_meta.model_dump()
actual_credentials = await creds_manager.get(
user_id, cred_meta.id, lock=False
)
if actual_credentials:
exec_kwargs[field_name] = actual_credentials
else:
return ErrorResponse(
message=f"Failed to retrieve credentials for {field_name}",
session_id=session_id,
)
# Execute the block and collect outputs
outputs: dict[str, list[Any]] = defaultdict(list)
async for output_name, output_data in block.execute(
input_data,
**exec_kwargs,
):
outputs[output_name].append(output_data)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
block_id=block_id,
block_name=block.name,
outputs=dict(outputs),
success=True,
session_id=session_id,
)
except BlockError as e:
logger.warning(f"Block execution failed: {e}")
return ErrorResponse(
message=f"Block execution failed: {e}",
error=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to execute block: {str(e)}",
error=str(e),
session_id=session_id,
)
async def resolve_block_credentials(
user_id: str,
block: AnyBlockSchema,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""Resolve credentials for a block by matching user's available credentials.
Handles discriminated credentials (e.g. provider selection based on model).
Returns:
(matched_credentials, missing_credentials)
"""
input_data = input_data or {}
requirements = _resolve_discriminated_credentials(block, input_data)
if not requirements:
return {}, []
return await match_credentials_to_requirements(user_id, requirements)
def _resolve_discriminated_credentials(
block: AnyBlockSchema,
input_data: dict[str, Any],
) -> dict[str, CredentialsFieldInfo]:
"""Resolve credential requirements, applying discriminator logic where needed."""
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return {}
resolved: dict[str, CredentialsFieldInfo] = {}
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(field_info.discriminator)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
effective_field_info.discriminator_values.add(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
resolved[field_name] = effective_field_info
return resolved

View File

@@ -12,50 +12,52 @@ from backend.data.model import CredentialsMetaInput
class ResponseType(str, Enum):
"""Types of tool responses."""
# General
ERROR = "error"
NO_RESULTS = "no_results"
NEED_LOGIN = "need_login"
# Agent discovery & execution
AGENTS_FOUND = "agents_found"
AGENT_DETAILS = "agent_details"
SETUP_REQUIREMENTS = "setup_requirements"
INPUT_VALIDATION_ERROR = "input_validation_error"
EXECUTION_STARTED = "execution_started"
NEED_LOGIN = "need_login"
ERROR = "error"
NO_RESULTS = "no_results"
AGENT_OUTPUT = "agent_output"
UNDERSTANDING_UPDATED = "understanding_updated"
AGENT_PREVIEW = "agent_preview"
AGENT_SAVED = "agent_saved"
CLARIFICATION_NEEDED = "clarification_needed"
SUGGESTED_GOAL = "suggested_goal"
# Agent builder (create / edit / validate / fix)
AGENT_BUILDER_GUIDE = "agent_builder_guide"
AGENT_BUILDER_PREVIEW = "agent_builder_preview"
AGENT_BUILDER_SAVED = "agent_builder_saved"
AGENT_BUILDER_CLARIFICATION_NEEDED = "agent_builder_clarification_needed"
AGENT_BUILDER_VALIDATION_RESULT = "agent_builder_validation_result"
AGENT_BUILDER_FIX_RESULT = "agent_builder_fix_result"
# Block
BLOCK_LIST = "block_list"
BLOCK_DETAILS = "block_details"
BLOCK_OUTPUT = "block_output"
REVIEW_REQUIRED = "review_required"
# MCP
MCP_GUIDE = "mcp_guide"
MCP_TOOLS_DISCOVERED = "mcp_tools_discovered"
MCP_TOOL_OUTPUT = "mcp_tool_output"
# Docs
DOC_SEARCH_RESULTS = "doc_search_results"
DOC_PAGE = "doc_page"
# Workspace response types
# Workspace files
WORKSPACE_FILE_LIST = "workspace_file_list"
WORKSPACE_FILE_CONTENT = "workspace_file_content"
WORKSPACE_FILE_METADATA = "workspace_file_metadata"
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
# Long-running operation types
OPERATION_IN_PROGRESS = "operation_in_progress"
# Input validation
INPUT_VALIDATION_ERROR = "input_validation_error"
# Web fetch
WEB_FETCH = "web_fetch"
# Agent-browser multi-step automation (navigate, act, screenshot)
BROWSER_NAVIGATE = "browser_navigate"
BROWSER_ACT = "browser_act"
BROWSER_SCREENSHOT = "browser_screenshot"
# Code execution
BASH_EXEC = "bash_exec"
# Feature request types
FEATURE_REQUEST_SEARCH = "feature_request_search"
FEATURE_REQUEST_CREATED = "feature_request_created"
# Goal refinement
SUGGESTED_GOAL = "suggested_goal"
# MCP tool types
MCP_TOOLS_DISCOVERED = "mcp_tools_discovered"
MCP_TOOL_OUTPUT = "mcp_tool_output"
# Folder management types
# Folder management
FOLDER_CREATED = "folder_created"
FOLDER_LIST = "folder_list"
FOLDER_UPDATED = "folder_updated"
@@ -63,6 +65,21 @@ class ResponseType(str, Enum):
FOLDER_DELETED = "folder_deleted"
AGENTS_MOVED_TO_FOLDER = "agents_moved_to_folder"
# Browser automation
BROWSER_NAVIGATE = "browser_navigate"
BROWSER_ACT = "browser_act"
BROWSER_SCREENSHOT = "browser_screenshot"
# Code execution
BASH_EXEC = "bash_exec"
# Web
WEB_FETCH = "web_fetch"
# Feature requests
FEATURE_REQUEST_SEARCH = "feature_request_search"
FEATURE_REQUEST_CREATED = "feature_request_created"
# Base response model
class ToolResponseBase(BaseModel):
@@ -92,6 +109,15 @@ class AgentInfo(BaseModel):
has_external_trigger: bool | None = None
new_output: bool | None = None
graph_id: str | None = None
graph_version: int | None = None
input_schema: dict[str, Any] | None = Field(
default=None,
description="JSON Schema for the agent's inputs (for AgentExecutorBlock)",
)
output_schema: dict[str, Any] | None = Field(
default=None,
description="JSON Schema for the agent's outputs (for AgentExecutorBlock)",
)
inputs: dict[str, Any] | None = Field(
default=None,
description="Input schema for the agent, including field names, types, and defaults",
@@ -282,7 +308,7 @@ class ClarifyingQuestion(BaseModel):
class AgentPreviewResponse(ToolResponseBase):
"""Response for previewing a generated agent before saving."""
type: ResponseType = ResponseType.AGENT_PREVIEW
type: ResponseType = ResponseType.AGENT_BUILDER_PREVIEW
agent_json: dict[str, Any]
agent_name: str
description: str
@@ -293,7 +319,7 @@ class AgentPreviewResponse(ToolResponseBase):
class AgentSavedResponse(ToolResponseBase):
"""Response when an agent is saved to the library."""
type: ResponseType = ResponseType.AGENT_SAVED
type: ResponseType = ResponseType.AGENT_BUILDER_SAVED
agent_id: str
agent_name: str
library_agent_id: str
@@ -304,7 +330,7 @@ class AgentSavedResponse(ToolResponseBase):
class ClarificationNeededResponse(ToolResponseBase):
"""Response when the LLM needs more information from the user."""
type: ResponseType = ResponseType.CLARIFICATION_NEEDED
type: ResponseType = ResponseType.AGENT_BUILDER_CLARIFICATION_NEEDED
questions: list[ClarifyingQuestion] = Field(default_factory=list)
@@ -381,6 +407,10 @@ class BlockInfoSummary(BaseModel):
default_factory=dict,
description="Full JSON schema for block outputs",
)
static_output: bool = Field(
default=False,
description="Whether the block produces output without needing input",
)
required_inputs: list[BlockInputFieldInfo] = Field(
default_factory=list,
description="List of input fields for this block",
@@ -429,16 +459,19 @@ class BlockOutputResponse(ToolResponseBase):
success: bool = True
# Long-running operation models
class OperationInProgressResponse(ToolResponseBase):
"""Response when an operation is already in progress.
class ReviewRequiredResponse(ToolResponseBase):
"""Response when a block requires human review before execution."""
Returned for idempotency when the same tool_call_id is requested again
while the background task is still running.
"""
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
tool_call_id: str
type: ResponseType = ResponseType.REVIEW_REQUIRED
block_id: str
block_name: str
review_id: str = Field(description="The review ID for tracking approval status")
graph_exec_id: str = Field(
description="The graph execution ID for fetching review status"
)
input_data: dict[str, Any] = Field(
description="The input data that requires review"
)
class WebFetchResponse(ToolResponseBase):
@@ -548,6 +581,29 @@ class BrowserScreenshotResponse(ToolResponseBase):
filename: str
# Agent generation tool response models
class ValidationResultResponse(ToolResponseBase):
"""Response for validate_agent_graph tool."""
type: ResponseType = ResponseType.AGENT_BUILDER_VALIDATION_RESULT
valid: bool
errors: list[str] = Field(default_factory=list)
error_count: int = 0
class FixResultResponse(ToolResponseBase):
"""Response for fix_agent_graph tool."""
type: ResponseType = ResponseType.AGENT_BUILDER_FIX_RESULT
fixed_agent_json: dict[str, Any]
fixes_applied: list[str] = Field(default_factory=list)
fix_count: int = 0
valid_after_fix: bool = False
remaining_errors: list[str] = Field(default_factory=list)
# Folder management models

View File

@@ -534,7 +534,9 @@ class RunAgentTool(BaseTool):
return ExecutionStartedResponse(
message=(
f"Agent '{library_agent.name}' is awaiting human review. "
f"Check at {library_agent_link}."
f"The user can approve or reject inline. After approval, "
f"the execution resumes automatically. Use view_agent_output "
f"with execution_id='{execution.id}' to check the result."
),
session_id=session_id,
execution_id=execution.id,

View File

@@ -2,38 +2,34 @@
import logging
import uuid
from collections import defaultdict
from typing import Any
from pydantic_core import PydanticUndefined
from backend.blocks import get_block
from backend.blocks import BlockType, get_block
from backend.blocks._base import AnyBlockSchema
from backend.copilot.constants import (
COPILOT_NODE_EXEC_ID_SEPARATOR,
COPILOT_NODE_PREFIX,
COPILOT_SESSION_PREFIX,
)
from backend.copilot.model import ChatSession
from backend.data.db_accessors import workspace_db
from backend.data.db_accessors import review_db
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
from .base import BaseTool
from .find_block import COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES
from .helpers import get_inputs_from_schema
from .helpers import execute_block, get_inputs_from_schema, resolve_block_credentials
from .models import (
BlockDetails,
BlockDetailsResponse,
BlockOutputResponse,
ErrorResponse,
InputValidationErrorResponse,
ReviewRequiredResponse,
SetupInfo,
SetupRequirementsResponse,
ToolResponseBase,
UserReadiness,
)
from .utils import (
build_missing_credentials_from_field_info,
match_credentials_to_requirements,
)
from .utils import build_missing_credentials_from_field_info
logger = logging.getLogger(__name__)
@@ -52,7 +48,9 @@ class RunBlockTool(BaseTool):
"IMPORTANT: You MUST call find_block first to get the block's 'id' - "
"do NOT guess or make up block IDs. "
"On first attempt (without input_data), returns detailed schema showing "
"required inputs and outputs. Then call again with proper input_data to execute."
"required inputs and outputs. Then call again with proper input_data to execute. "
"If a block requires human review, use continue_run_block with the "
"review_id after the user approves."
)
@property
@@ -83,7 +81,7 @@ class RunBlockTool(BaseTool):
),
},
},
"required": ["block_id", "input_data"],
"required": ["block_id", "block_name", "input_data"],
}
@property
@@ -149,21 +147,27 @@ class RunBlockTool(BaseTool):
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
):
# Provide actionable guidance for blocks with dedicated tools
if block.block_type == BlockType.MCP_TOOL:
hint = (
" Use the `run_mcp_tool` tool instead — it handles "
"MCP server discovery, authentication, and execution."
)
elif block.block_type == BlockType.AGENT:
hint = " Use the `run_agent` tool instead."
else:
hint = " This block is designed for use within graphs only."
return ErrorResponse(
message=(
f"Block '{block.name}' cannot be run directly in CoPilot. "
"This block is designed for use within graphs only."
),
message=f"Block '{block.name}' cannot be run directly.{hint}",
session_id=session_id,
)
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
creds_manager = IntegrationCredentialsManager()
(
matched_credentials,
missing_credentials,
) = await self._resolve_block_credentials(user_id, block, input_data)
) = await resolve_block_credentials(user_id, block, input_data)
# Get block schemas for details/validation
try:
@@ -272,169 +276,97 @@ class RunBlockTool(BaseTool):
user_authenticated=True,
)
try:
# Get or create user's workspace for CoPilot file operations
workspace = await workspace_db().get_or_create_workspace(user_id)
# Generate synthetic IDs for CoPilot context.
# Encode node_id in node_exec_id so it can be extracted later
# (e.g. for auto-approve, where we need node_id but have no NodeExecution row).
synthetic_graph_id = f"{COPILOT_SESSION_PREFIX}{session.session_id}"
synthetic_node_id = f"{COPILOT_NODE_PREFIX}{block_id}"
# Generate synthetic IDs for CoPilot context
# Each chat session is treated as its own agent with one continuous run
# This means:
# - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
# - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
# - node_exec_id = unique per block execution
synthetic_graph_id = f"copilot-session-{session.session_id}"
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
synthetic_node_id = f"copilot-node-{block_id}"
synthetic_node_exec_id = (
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
)
# Create unified execution context with all required fields
execution_context = ExecutionContext(
# Execution identity
user_id=user_id,
graph_id=synthetic_graph_id,
graph_exec_id=synthetic_graph_exec_id,
graph_version=1, # Versions are 1-indexed
node_id=synthetic_node_id,
node_exec_id=synthetic_node_exec_id,
# Workspace with session scoping
workspace_id=workspace.id,
session_id=session.session_id,
)
# Prepare kwargs for block execution
# Keep individual kwargs for backwards compatibility with existing blocks
exec_kwargs: dict[str, Any] = {
"user_id": user_id,
"execution_context": execution_context,
# Legacy: individual kwargs for blocks not yet using execution_context
"workspace_id": workspace.id,
"graph_exec_id": synthetic_graph_exec_id,
"node_exec_id": synthetic_node_exec_id,
"node_id": synthetic_node_id,
"graph_version": 1, # Versions are 1-indexed
"graph_id": synthetic_graph_id,
}
for field_name, cred_meta in matched_credentials.items():
# Inject metadata into input_data (for validation)
if field_name not in input_data:
input_data[field_name] = cred_meta.model_dump()
# Fetch actual credentials and pass as kwargs (for execution)
actual_credentials = await creds_manager.get(
user_id, cred_meta.id, lock=False
)
if actual_credentials:
exec_kwargs[field_name] = actual_credentials
else:
return ErrorResponse(
message=f"Failed to retrieve credentials for {field_name}",
session_id=session_id,
)
# Execute the block and collect outputs
outputs: dict[str, list[Any]] = defaultdict(list)
async for output_name, output_data in block.execute(
input_data,
**exec_kwargs,
):
outputs[output_name].append(output_data)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
# Check for an existing WAITING review for this block with the same input.
# If the LLM retries run_block with identical input, we reuse the existing
# review instead of creating duplicates. Different inputs = new execution.
existing_reviews = await review_db().get_pending_reviews_for_execution(
synthetic_graph_id, user_id
)
existing_review = next(
(
r
for r in existing_reviews
if r.node_id == synthetic_node_id
and r.status.value == "WAITING"
and r.payload == input_data
),
None,
)
if existing_review:
return ReviewRequiredResponse(
message=(
f"Block '{block.name}' requires human review. "
f"After the user approves, call continue_run_block with "
f"review_id='{existing_review.node_exec_id}' to execute."
),
session_id=session_id,
block_id=block_id,
block_name=block.name,
outputs=dict(outputs),
success=True,
session_id=session_id,
review_id=existing_review.node_exec_id,
graph_exec_id=synthetic_graph_id,
input_data=input_data,
)
except BlockError as e:
logger.warning(f"Block execution failed: {e}")
return ErrorResponse(
message=f"Block execution failed: {e}",
error=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to execute block: {str(e)}",
error=str(e),
synthetic_node_exec_id = (
f"{synthetic_node_id}{COPILOT_NODE_EXEC_ID_SEPARATOR}"
f"{uuid.uuid4().hex[:8]}"
)
# Check for HITL review before execution.
# This creates the review record in the DB for CoPilot flows.
review_context = ExecutionContext(
user_id=user_id,
graph_id=synthetic_graph_id,
graph_exec_id=synthetic_graph_id,
graph_version=1,
node_id=synthetic_node_id,
node_exec_id=synthetic_node_exec_id,
sensitive_action_safe_mode=True,
)
should_pause, input_data = await block.is_block_exec_need_review(
input_data,
user_id=user_id,
node_id=synthetic_node_id,
node_exec_id=synthetic_node_exec_id,
graph_exec_id=synthetic_graph_id,
graph_id=synthetic_graph_id,
graph_version=1,
execution_context=review_context,
is_graph_execution=False,
)
if should_pause:
return ReviewRequiredResponse(
message=(
f"Block '{block.name}' requires human review. "
f"After the user approves, call continue_run_block with "
f"review_id='{synthetic_node_exec_id}' to execute."
),
session_id=session_id,
block_id=block_id,
block_name=block.name,
review_id=synthetic_node_exec_id,
graph_exec_id=synthetic_graph_id,
input_data=input_data,
)
async def _resolve_block_credentials(
self,
user_id: str,
block: AnyBlockSchema,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Resolve credentials for a block by matching user's available credentials.
Args:
user_id: User ID
block: Block to resolve credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple of (matched_credentials, missing_credentials) - matched credentials
are used for block execution, missing ones indicate setup requirements.
"""
input_data = input_data or {}
requirements = self._resolve_discriminated_credentials(block, input_data)
if not requirements:
return {}, []
return await match_credentials_to_requirements(user_id, requirements)
return await execute_block(
block=block,
block_id=block_id,
input_data=input_data,
user_id=user_id,
session_id=session_id,
node_exec_id=synthetic_node_exec_id,
matched_credentials=matched_credentials,
)
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
"""Extract non-credential inputs from block schema."""
schema = block.input_schema.jsonschema()
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
def _resolve_discriminated_credentials(
self,
block: AnyBlockSchema,
input_data: dict[str, Any],
) -> dict[str, CredentialsFieldInfo]:
"""Resolve credential requirements, applying discriminator logic where needed."""
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return {}
resolved: dict[str, CredentialsFieldInfo] = {}
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
# For host-scoped credentials, add the discriminator value
# (e.g., URL) so _credential_is_for_host can match it
effective_field_info.discriminator_values.add(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
resolved[field_name] = effective_field_info
return resolved

View File

@@ -12,6 +12,7 @@ from .models import (
BlockOutputResponse,
ErrorResponse,
InputValidationErrorResponse,
ReviewRequiredResponse,
)
from .run_block import RunBlockTool
@@ -27,9 +28,16 @@ def make_mock_block(
mock.name = name
mock.block_type = block_type
mock.disabled = disabled
mock.is_sensitive_action = False
mock.input_schema = MagicMock()
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
mock.input_schema.get_credentials_fields_info.return_value = []
mock.input_schema.get_credentials_fields_info.return_value = {}
mock.input_schema.get_credentials_fields.return_value = {}
async def _no_review(input_data, **kwargs):
return False, input_data
mock.is_block_exec_need_review = _no_review
return mock
@@ -46,6 +54,7 @@ def make_mock_block_with_schema(
mock.name = name
mock.block_type = BlockType.STANDARD
mock.disabled = False
mock.is_sensitive_action = False
mock.description = f"Test block: {name}"
input_schema = {
@@ -63,6 +72,12 @@ def make_mock_block_with_schema(
mock.output_schema = MagicMock()
mock.output_schema.jsonschema.return_value = output_schema
# Default: no review needed, pass through input_data unchanged
async def _no_review(input_data, **kwargs):
return False, input_data
mock.is_block_exec_need_review = _no_review
return mock
@@ -89,7 +104,7 @@ class TestRunBlockFiltering:
)
assert isinstance(response, ErrorResponse)
assert "cannot be run directly in CoPilot" in response.message
assert "cannot be run directly" in response.message
assert "designed for use within graphs only" in response.message
@pytest.mark.asyncio(loop_scope="session")
@@ -115,7 +130,7 @@ class TestRunBlockFiltering:
)
assert isinstance(response, ErrorResponse)
assert "cannot be run directly in CoPilot" in response.message
assert "cannot be run directly" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_non_excluded_block_passes_guard(self):
@@ -126,9 +141,15 @@ class TestRunBlockFiltering:
"standard-id", "HTTP Request", BlockType.STANDARD
)
with patch(
"backend.copilot.tools.run_block.get_block",
return_value=standard_block,
with (
patch(
"backend.copilot.tools.run_block.get_block",
return_value=standard_block,
),
patch(
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
):
tool = RunBlockTool()
response = await tool._execute(
@@ -141,7 +162,7 @@ class TestRunBlockFiltering:
# Should NOT be an ErrorResponse about CoPilot exclusion
# (may be other errors like missing credentials, but not the exclusion guard)
if isinstance(response, ErrorResponse):
assert "cannot be run directly in CoPilot" not in response.message
assert "cannot be run directly" not in response.message
class TestRunBlockInputValidation:
@@ -154,12 +175,7 @@ class TestRunBlockInputValidation:
@pytest.mark.asyncio(loop_scope="session")
async def test_unknown_input_fields_are_rejected(self):
"""run_block rejects unknown input fields instead of silently ignoring them.
Scenario: The AI Text Generator block has a field called 'model' (for LLM model
selection), but the LLM calling the tool guesses wrong and sends 'LLM_Model'
instead. The block should reject the request and return the valid schema.
"""
"""run_block rejects unknown input fields instead of silently ignoring them."""
session = make_session(user_id=_TEST_USER_ID)
mock_block = make_mock_block_with_schema(
@@ -182,27 +198,31 @@ class TestRunBlockInputValidation:
output_properties={"response": {"type": "string"}},
)
with patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
with (
patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
),
patch(
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
):
tool = RunBlockTool()
# Provide 'prompt' (correct) but 'LLM_Model' instead of 'model' (wrong key)
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="ai-text-gen-id",
input_data={
"prompt": "Write a haiku about coding",
"LLM_Model": "claude-opus-4-6", # WRONG KEY - should be 'model'
"LLM_Model": "claude-opus-4-6",
},
)
assert isinstance(response, InputValidationErrorResponse)
assert "LLM_Model" in response.unrecognized_fields
assert "Block was not executed" in response.message
assert "inputs" in response.model_dump() # valid schema included
assert "inputs" in response.model_dump()
@pytest.mark.asyncio(loop_scope="session")
async def test_multiple_wrong_keys_are_all_reported(self):
@@ -221,21 +241,26 @@ class TestRunBlockInputValidation:
required_fields=["prompt"],
)
with patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
with (
patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
),
patch(
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="ai-text-gen-id",
input_data={
"prompt": "Hello", # correct
"llm_model": "claude-opus-4-6", # WRONG - should be 'model'
"system_prompt": "Be helpful", # WRONG - should be 'sys_prompt'
"retries": 5, # WRONG - should be 'retry'
"prompt": "Hello",
"llm_model": "claude-opus-4-6",
"system_prompt": "Be helpful",
"retries": 5,
},
)
@@ -262,23 +287,26 @@ class TestRunBlockInputValidation:
required_fields=["prompt"],
)
with patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
with (
patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
),
patch(
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
):
tool = RunBlockTool()
# 'prompt' is missing AND 'LLM_Model' is an unknown field
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="ai-text-gen-id",
input_data={
"LLM_Model": "claude-opus-4-6", # wrong key, and 'prompt' is missing
"LLM_Model": "claude-opus-4-6",
},
)
# Unknown fields are caught first
assert isinstance(response, InputValidationErrorResponse)
assert "LLM_Model" in response.unrecognized_fields
@@ -313,7 +341,11 @@ class TestRunBlockInputValidation:
return_value=mock_block,
),
patch(
"backend.copilot.tools.run_block.workspace_db",
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
),
):
@@ -325,7 +357,7 @@ class TestRunBlockInputValidation:
block_id="ai-text-gen-id",
input_data={
"prompt": "Write a haiku",
"model": "gpt-4o-mini", # correct field name
"model": "gpt-4o-mini",
},
)
@@ -347,20 +379,191 @@ class TestRunBlockInputValidation:
required_fields=["prompt"],
)
with patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
with (
patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
),
patch(
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
):
tool = RunBlockTool()
# Only provide valid optional field, missing required 'prompt'
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="ai-text-gen-id",
input_data={
"model": "gpt-4o-mini", # valid but optional
"model": "gpt-4o-mini",
},
)
assert isinstance(response, BlockDetailsResponse)
class TestRunBlockSensitiveAction:
"""Tests for sensitive action HITL review in RunBlockTool.
run_block calls is_block_exec_need_review() explicitly before execution.
When review is needed (should_pause=True), ReviewRequiredResponse is returned.
"""
@pytest.mark.asyncio(loop_scope="session")
async def test_sensitive_block_paused_returns_review_required(self):
"""When is_block_exec_need_review returns should_pause=True, ReviewRequiredResponse is returned."""
session = make_session(user_id=_TEST_USER_ID)
input_data = {
"repo_url": "https://github.com/test/repo",
"branch": "feature-branch",
}
mock_block = make_mock_block_with_schema(
block_id="delete-branch-id",
name="Delete Branch",
input_properties={
"repo_url": {"type": "string"},
"branch": {"type": "string"},
},
required_fields=["repo_url", "branch"],
)
mock_block.is_sensitive_action = True
mock_block.is_block_exec_need_review = AsyncMock(
return_value=(True, input_data)
)
with (
patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
),
patch(
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="delete-branch-id",
input_data=input_data,
)
assert isinstance(response, ReviewRequiredResponse)
assert "requires human review" in response.message
assert "continue_run_block" in response.message
assert response.block_name == "Delete Branch"
@pytest.mark.asyncio(loop_scope="session")
async def test_sensitive_block_executes_after_approval(self):
"""After approval (should_pause=False), sensitive blocks execute and return outputs."""
session = make_session(user_id=_TEST_USER_ID)
input_data = {
"repo_url": "https://github.com/test/repo",
"branch": "feature-branch",
}
mock_block = make_mock_block_with_schema(
block_id="delete-branch-id",
name="Delete Branch",
input_properties={
"repo_url": {"type": "string"},
"branch": {"type": "string"},
},
required_fields=["repo_url", "branch"],
)
mock_block.is_sensitive_action = True
mock_block.is_block_exec_need_review = AsyncMock(
return_value=(False, input_data)
)
async def mock_execute(input_data, **kwargs):
yield "result", "Branch deleted successfully"
mock_block.execute = mock_execute
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="test-workspace-id")
)
with (
patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
),
patch(
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
),
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="delete-branch-id",
input_data=input_data,
)
assert isinstance(response, BlockOutputResponse)
assert response.success is True
@pytest.mark.asyncio(loop_scope="session")
async def test_non_sensitive_block_executes_normally(self):
"""Non-sensitive blocks skip review and execute directly."""
session = make_session(user_id=_TEST_USER_ID)
input_data = {"url": "https://example.com"}
mock_block = make_mock_block_with_schema(
block_id="http-request-id",
name="HTTP Request",
input_properties={
"url": {"type": "string"},
},
required_fields=["url"],
)
mock_block.is_sensitive_action = False
mock_block.is_block_exec_need_review = AsyncMock(
return_value=(False, input_data)
)
async def mock_execute(input_data, **kwargs):
yield "response", {"status": 200}
mock_block.execute = mock_execute
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="test-workspace-id")
)
with (
patch(
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
),
patch(
"backend.copilot.tools.helpers.match_credentials_to_requirements",
return_value=({}, []),
),
patch(
"backend.copilot.tools.helpers.workspace_db",
return_value=mock_workspace_db,
),
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="http-request-id",
input_data=input_data,
)
assert isinstance(response, BlockOutputResponse)
assert response.success is True

View File

@@ -14,7 +14,7 @@ from backend.blocks.mcp.helpers import (
)
from backend.copilot.model import ChatSession
from backend.copilot.tools.utils import build_missing_credentials_from_field_info
from backend.util.request import HTTPClientError, validate_url
from backend.util.request import HTTPClientError, validate_url_host
from .base import BaseTool
from .models import (
@@ -53,15 +53,9 @@ class RunMCPToolTool(BaseTool):
def description(self) -> str:
return (
"Connect to an MCP (Model Context Protocol) server to discover and execute its tools. "
"Two-step workflow: (1) Call with just `server_url` to discover available tools. "
"(2) Call again with `server_url`, `tool_name`, and `tool_arguments` to execute. "
"Known hosted servers (use directly): Notion (https://mcp.notion.com/mcp), "
"Linear (https://mcp.linear.app/mcp), Stripe (https://mcp.stripe.com), "
"Intercom (https://mcp.intercom.com/mcp), Cloudflare (https://mcp.cloudflare.com/mcp), "
"Atlassian/Jira (https://mcp.atlassian.com/mcp). "
"For other services, search the MCP registry at https://registry.modelcontextprotocol.io/. "
"Authentication: If the server requires credentials, user will be prompted to complete the MCP credential setup flow."
"Once connected and user confirms, retry the same call immediately."
"Two-step: (1) call with server_url to list available tools, "
"(2) call again with server_url + tool_name + tool_arguments to execute. "
"Call get_mcp_guide for known server URLs and auth details."
)
@property
@@ -150,7 +144,7 @@ class RunMCPToolTool(BaseTool):
# Validate URL to prevent SSRF — blocks loopback and private IP ranges
try:
await validate_url(server_url, trusted_origins=[])
await validate_url_host(server_url)
except ValueError as e:
msg = str(e)
if "Unable to resolve" in msg or "No IP addresses" in msg:

View File

@@ -65,9 +65,8 @@ async def test_run_block_returns_details_when_no_input_provided():
return_value=http_block,
):
# Mock credentials check to return no missing credentials
with patch.object(
RunBlockTool,
"_resolve_block_credentials",
with patch(
"backend.copilot.tools.run_block.resolve_block_credentials",
new_callable=AsyncMock,
return_value=({}, []), # (matched_credentials, missing_credentials)
):
@@ -123,9 +122,8 @@ async def test_run_block_returns_details_when_only_credentials_provided():
"backend.copilot.tools.run_block.get_block",
return_value=mock,
):
with patch.object(
RunBlockTool,
"_resolve_block_credentials",
with patch(
"backend.copilot.tools.run_block.resolve_block_credentials",
new_callable=AsyncMock,
return_value=(
{

View File

@@ -100,7 +100,7 @@ async def test_ssrf_blocked_url_returns_error():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url",
"backend.copilot.tools.run_mcp_tool.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
@@ -138,7 +138,7 @@ async def test_non_dict_tool_arguments_returns_error():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url",
"backend.copilot.tools.run_mcp_tool.validate_url_host",
new_callable=AsyncMock,
):
with patch(
@@ -171,7 +171,7 @@ async def test_discover_tools_returns_discovered_response():
mock_tools = _make_tool_list("fetch", "search")
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -208,7 +208,7 @@ async def test_discover_tools_with_credentials():
mock_tools = _make_tool_list("push_notification")
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -249,7 +249,7 @@ async def test_execute_tool_returns_output_response():
text_result = "# Example Domain\nThis domain is for examples."
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -285,7 +285,7 @@ async def test_execute_tool_parses_json_result():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -320,7 +320,7 @@ async def test_execute_tool_image_content():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -359,7 +359,7 @@ async def test_execute_tool_resource_content():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -399,7 +399,7 @@ async def test_execute_tool_multi_item_content():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -437,7 +437,7 @@ async def test_execute_tool_empty_content_returns_none():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -470,7 +470,7 @@ async def test_execute_tool_returns_error_on_tool_failure():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -512,7 +512,7 @@ async def test_auth_required_without_creds_returns_setup_requirements():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -555,7 +555,7 @@ async def test_auth_error_with_existing_creds_returns_error():
mock_creds.access_token = SecretStr("stale-token")
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -589,7 +589,7 @@ async def test_mcp_client_error_returns_error_response():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -621,7 +621,7 @@ async def test_unexpected_exception_returns_generic_error():
session = make_session(_USER_ID)
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
@@ -719,7 +719,7 @@ async def test_credential_lookup_normalizes_trailing_slash():
url_with_slash = "https://mcp.example.com/mcp/"
with patch(
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
):
with patch(
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",

View File

@@ -0,0 +1,116 @@
"""ValidateAgentGraphTool - Validates agent JSON structure."""
import logging
from typing import Any
from backend.copilot.model import ChatSession
from .agent_generator.validation import AgentValidator, get_blocks_as_dicts
from .base import BaseTool
from .models import ErrorResponse, ToolResponseBase, ValidationResultResponse
logger = logging.getLogger(__name__)
class ValidateAgentGraphTool(BaseTool):
"""Tool for validating agent JSON graphs."""
@property
def name(self) -> str:
return "validate_agent_graph"
@property
def description(self) -> str:
return (
"Validate an agent JSON graph for correctness. Checks:\n"
"- All block_ids reference real blocks\n"
"- All links reference valid source/sink nodes and fields\n"
"- Required input fields are wired or have defaults\n"
"- Data types are compatible across links\n"
"- Nested sink links use correct notation\n"
"- Prompt templates use proper curly brace escaping\n"
"- AgentExecutorBlock configurations are valid\n\n"
"Call this after generating agent JSON to verify correctness. "
"If validation fails, either fix issues manually based on the error "
"descriptions, or call fix_agent_graph to auto-fix common problems."
)
@property
def requires_auth(self) -> bool:
return False
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"agent_json": {
"type": "object",
"description": (
"The agent JSON to validate. Must contain 'nodes' and 'links' arrays. "
"Each node needs: id (UUID), block_id, input_default, metadata. "
"Each link needs: id (UUID), source_id, source_name, sink_id, sink_name."
),
},
},
"required": ["agent_json"],
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
agent_json = kwargs.get("agent_json")
session_id = session.session_id if session else None
if not agent_json or not isinstance(agent_json, dict):
return ErrorResponse(
message="Please provide a valid agent JSON object.",
error="Missing or invalid agent_json parameter",
session_id=session_id,
)
nodes = agent_json.get("nodes", [])
if not nodes:
return ErrorResponse(
message="The agent JSON has no nodes. An agent needs at least one block.",
error="empty_agent",
session_id=session_id,
)
try:
blocks = get_blocks_as_dicts()
validator = AgentValidator()
is_valid, error_message = validator.validate(agent_json, blocks)
except Exception as e:
logger.error(f"Validation error: {e}", exc_info=True)
return ErrorResponse(
message=f"Validation encountered an error: {str(e)}",
error="validation_exception",
session_id=session_id,
)
if is_valid:
return ValidationResultResponse(
message="Agent graph is valid! No issues found.",
valid=True,
errors=[],
error_count=0,
session_id=session_id,
)
# Parse individual errors from the validator's error list
errors = validator.errors if hasattr(validator, "errors") else []
if not errors and error_message:
errors = [error_message]
return ValidationResultResponse(
message=f"Found {len(errors)} validation error(s). Fix them and re-validate.",
valid=False,
errors=errors,
error_count=len(errors),
session_id=session_id,
)

View File

@@ -0,0 +1,160 @@
"""Tests for ValidateAgentGraphTool."""
from unittest.mock import MagicMock, patch
import pytest
from backend.copilot.tools.models import ErrorResponse, ValidationResultResponse
from backend.copilot.tools.validate_agent import ValidateAgentGraphTool
from ._test_data import make_session
_TEST_USER_ID = "test-user-validate-agent"
@pytest.fixture
def tool():
return ValidateAgentGraphTool()
@pytest.fixture
def session():
return make_session(_TEST_USER_ID)
@pytest.mark.asyncio
async def test_missing_agent_json_returns_error(tool, session):
"""Missing agent_json returns ErrorResponse."""
result = await tool._execute(user_id=_TEST_USER_ID, session=session)
assert isinstance(result, ErrorResponse)
assert result.error is not None
assert "agent_json" in result.error.lower()
@pytest.mark.asyncio
async def test_empty_nodes_returns_error(tool, session):
"""Agent JSON with no nodes returns ErrorResponse."""
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json={"nodes": [], "links": []},
)
assert isinstance(result, ErrorResponse)
assert "no nodes" in result.message.lower()
@pytest.mark.asyncio
async def test_valid_agent_returns_success(tool, session):
"""Valid agent returns ValidationResultResponse with valid=True."""
agent_json = {
"nodes": [
{
"id": "node-1",
"block_id": "block-1",
"input_default": {},
"metadata": {"position": {"x": 0, "y": 0}},
}
],
"links": [],
}
mock_validator = MagicMock()
mock_validator.validate.return_value = (True, None)
mock_validator.errors = []
with (
patch(
"backend.copilot.tools.validate_agent.get_blocks_as_dicts",
return_value=[],
),
patch(
"backend.copilot.tools.validate_agent.AgentValidator",
return_value=mock_validator,
),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
)
assert isinstance(result, ValidationResultResponse)
assert result.valid is True
assert result.error_count == 0
assert result.errors == []
@pytest.mark.asyncio
async def test_invalid_agent_returns_errors(tool, session):
"""Invalid agent returns ValidationResultResponse with errors."""
agent_json = {
"nodes": [
{
"id": "node-1",
"block_id": "nonexistent-block",
"input_default": {},
"metadata": {},
}
],
"links": [],
}
mock_validator = MagicMock()
mock_validator.validate.return_value = (False, "Validation failed")
mock_validator.errors = [
"Block 'nonexistent-block' not found in registry",
"Missing required input field 'prompt'",
]
with (
patch(
"backend.copilot.tools.validate_agent.get_blocks_as_dicts",
return_value=[],
),
patch(
"backend.copilot.tools.validate_agent.AgentValidator",
return_value=mock_validator,
),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
)
assert isinstance(result, ValidationResultResponse)
assert result.valid is False
assert result.error_count == 2
assert len(result.errors) == 2
@pytest.mark.asyncio
async def test_validation_exception_returns_error(tool, session):
"""Validator exception returns ErrorResponse."""
agent_json = {
"nodes": [{"id": "n1", "block_id": "b1", "input_default": {}, "metadata": {}}],
"links": [],
}
mock_validator = MagicMock()
mock_validator.validate.side_effect = RuntimeError("unexpected error")
with (
patch(
"backend.copilot.tools.validate_agent.get_blocks_as_dicts",
return_value=[],
),
patch(
"backend.copilot.tools.validate_agent.AgentValidator",
return_value=mock_validator,
),
):
result = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
agent_json=agent_json,
)
assert isinstance(result, ErrorResponse)
assert result.error is not None
assert "validation_exception" in result.error

View File

@@ -7,8 +7,12 @@ from typing import Any, Optional
from pydantic import BaseModel
from backend.copilot.context import (
E2B_WORKDIR,
get_current_sandbox,
resolve_sandbox_path,
)
from backend.copilot.model import ChatSession
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
from backend.copilot.tools.sandbox import make_session_path
from backend.data.db_accessors import workspace_db
from backend.util.settings import Config
@@ -83,13 +87,11 @@ def _resolve_sandbox_path(
) -> str | ErrorResponse:
"""Normalize *path* to an absolute sandbox path under :data:`E2B_WORKDIR`.
Delegates to :func:`~backend.copilot.sdk.e2b_file_tools._resolve_remote`
Delegates to :func:`~backend.copilot.sdk.e2b_file_tools.resolve_sandbox_path`
and wraps any ``ValueError`` into an :class:`ErrorResponse`.
"""
from backend.copilot.sdk.e2b_file_tools import _resolve_remote
try:
return _resolve_remote(path)
return resolve_sandbox_path(path)
except ValueError:
return ErrorResponse(
message=f"{param_name} must be within {E2B_WORKDIR}",
@@ -99,7 +101,6 @@ def _resolve_sandbox_path(
async def _read_source_path(source_path: str, session_id: str) -> bytes | ErrorResponse:
"""Read *source_path* from E2B sandbox or local ephemeral directory."""
from backend.copilot.sdk.tool_adapter import get_current_sandbox
sandbox = get_current_sandbox()
if sandbox is not None:
@@ -143,7 +144,6 @@ async def _save_to_path(
Returns the resolved path on success, or an ``ErrorResponse`` on failure.
"""
from backend.copilot.sdk.tool_adapter import get_current_sandbox
sandbox = get_current_sandbox()
if sandbox is not None:

Some files were not shown because too many files have changed in this diff Show More