mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
26 Commits
swiftyos/i
...
feat/analy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9afd9fa01 | ||
|
|
ddb4f6e9de | ||
|
|
f585d97928 | ||
|
|
7d39234fdd | ||
|
|
6e9d4c4333 | ||
|
|
8aad333a45 | ||
|
|
856f0d980d | ||
|
|
3c3aadd361 | ||
|
|
e87a693fdd | ||
|
|
fe265c10d4 | ||
|
|
5d00a94693 | ||
|
|
6e1605994d | ||
|
|
15e3980d65 | ||
|
|
fe9eb2564b | ||
|
|
5641cdd3ca | ||
|
|
bfb843a56e | ||
|
|
684845d946 | ||
|
|
6a6b23c2e1 | ||
|
|
d0a1d72e8a | ||
|
|
f1945d6a2f | ||
|
|
6491cb1e23 | ||
|
|
c7124a5240 | ||
|
|
5537cb2858 | ||
|
|
aef5f6d666 | ||
|
|
8063391d0a | ||
|
|
0bbb12d688 |
17
.claude/skills/backend-check/SKILL.md
Normal file
17
.claude/skills/backend-check/SKILL.md
Normal file
@@ -0,0 +1,17 @@
|
||||
---
|
||||
name: backend-check
|
||||
description: Run the full backend formatting, linting, and test suite. Ensures code quality before commits and PRs. TRIGGER when backend Python code has been modified and needs validation.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Backend Check
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Format**: `poetry run format` — runs formatting AND linting. NEVER run ruff/black/isort individually
|
||||
2. **Fix** any remaining errors manually, re-run until clean
|
||||
3. **Test**: `poetry run test` (runs DB setup + pytest). For specific files: `poetry run pytest -s -vvv <test_files>`
|
||||
4. **Snapshots** (if needed): `poetry run pytest path/to/test.py --snapshot-update` — review with `git diff`
|
||||
35
.claude/skills/code-style/SKILL.md
Normal file
35
.claude/skills/code-style/SKILL.md
Normal file
@@ -0,0 +1,35 @@
|
||||
---
|
||||
name: code-style
|
||||
description: Python code style preferences for the AutoGPT backend. Apply when writing or reviewing Python code. TRIGGER when writing new Python code, reviewing PRs, or refactoring backend code.
|
||||
user-invocable: false
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Code Style
|
||||
|
||||
## Imports
|
||||
|
||||
- **Top-level only** — no local/inner imports. Move all imports to the top of the file.
|
||||
|
||||
## Typing
|
||||
|
||||
- **No duck typing** — avoid `hasattr`, `getattr`, `isinstance` for type dispatch. Use proper typed interfaces, unions, or protocols.
|
||||
- **Pydantic models** over dataclass, namedtuple, or raw dict for structured data.
|
||||
- **No linter suppressors** — avoid `# type: ignore`, `# noqa`, `# pyright: ignore` etc. 99% of the time the right fix is fixing the type/code, not silencing the tool.
|
||||
|
||||
## Code Structure
|
||||
|
||||
- **List comprehensions** over manual loop-and-append.
|
||||
- **Early return** — guard clauses first, avoid deep nesting.
|
||||
- **Flatten inline** — prefer short, concise expressions. Reduce `if/else` chains with direct returns or ternaries when readable.
|
||||
- **Modular functions** — break complex logic into small, focused functions rather than long blocks with nested conditionals.
|
||||
|
||||
## Review Checklist
|
||||
|
||||
Before finishing, always ask:
|
||||
- Can any function be split into smaller pieces?
|
||||
- Is there unnecessary nesting that an early return would eliminate?
|
||||
- Can any loop be a comprehension?
|
||||
- Is there a simpler way to express this logic?
|
||||
16
.claude/skills/frontend-check/SKILL.md
Normal file
16
.claude/skills/frontend-check/SKILL.md
Normal file
@@ -0,0 +1,16 @@
|
||||
---
|
||||
name: frontend-check
|
||||
description: Run the full frontend formatting, linting, and type checking suite. Ensures code quality before commits and PRs. TRIGGER when frontend TypeScript/React code has been modified and needs validation.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Frontend Check
|
||||
|
||||
## Steps (in order)
|
||||
|
||||
1. **Format**: `pnpm format` — NEVER run individual formatters
|
||||
2. **Lint**: `pnpm lint` — fix errors, re-run until clean
|
||||
3. **Types**: `pnpm types` — if it keeps failing after multiple attempts, stop and ask the user
|
||||
29
.claude/skills/new-block/SKILL.md
Normal file
29
.claude/skills/new-block/SKILL.md
Normal file
@@ -0,0 +1,29 @@
|
||||
---
|
||||
name: new-block
|
||||
description: Create a new backend block following the Block SDK Guide. Guides through provider configuration, schema definition, authentication, and testing. TRIGGER when user asks to create a new block, add a new integration, or build a new node for the graph editor.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# New Block Creation
|
||||
|
||||
Read `docs/platform/block-sdk-guide.md` first for the full guide.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Provider config** (if external service): create `_config.py` with `ProviderBuilder`
|
||||
2. **Block file** in `backend/blocks/` (from `autogpt_platform/backend/`):
|
||||
- Generate a UUID once with `uuid.uuid4()`, then **hard-code that string** as `id` (IDs must be stable across imports)
|
||||
- `Input(BlockSchema)` and `Output(BlockSchema)` classes
|
||||
- `async def run` that `yield`s output fields
|
||||
3. **Files**: use `store_media_file()` with `"for_block_output"` for outputs
|
||||
4. **Test**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[MyBlock]' -xvs`
|
||||
5. **Format**: `poetry run format`
|
||||
|
||||
## Rules
|
||||
|
||||
- Analyze interfaces: do inputs/outputs connect well with other blocks in a graph?
|
||||
- Use top-level imports, avoid duck typing
|
||||
- Always use `for_block_output` for block outputs
|
||||
28
.claude/skills/openapi-regen/SKILL.md
Normal file
28
.claude/skills/openapi-regen/SKILL.md
Normal file
@@ -0,0 +1,28 @@
|
||||
---
|
||||
name: openapi-regen
|
||||
description: Regenerate the OpenAPI spec and frontend API client. Starts the backend REST server, fetches the spec, and regenerates the typed frontend hooks. TRIGGER when API routes change, new endpoints are added, or frontend API types are stale.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# OpenAPI Spec Regeneration
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Run end-to-end** in a single shell block (so `REST_PID` persists):
|
||||
```bash
|
||||
cd autogpt_platform/backend && poetry run rest &
|
||||
REST_PID=$!
|
||||
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && kill $REST_PID && exit 1; done
|
||||
cd ../frontend && pnpm generate:api:force
|
||||
kill $REST_PID
|
||||
pnpm types && pnpm lint && pnpm format
|
||||
```
|
||||
|
||||
## Rules
|
||||
|
||||
- Always use `pnpm generate:api:force` (not `pnpm generate:api`)
|
||||
- Don't manually edit files in `src/app/api/__generated__/`
|
||||
- Generated hooks follow: `use{Method}{Version}{OperationName}`
|
||||
31
.claude/skills/pr-create/SKILL.md
Normal file
31
.claude/skills/pr-create/SKILL.md
Normal file
@@ -0,0 +1,31 @@
|
||||
---
|
||||
name: pr-create
|
||||
description: Create a pull request for the current branch. TRIGGER when user asks to create a PR, open a pull request, push changes for review, or submit work for merging.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Create Pull Request
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Check for existing PR**: `gh pr view --json url -q .url 2>/dev/null` — if a PR already exists, output its URL and stop
|
||||
2. **Understand changes**: `git status`, `git diff dev...HEAD`, `git log dev..HEAD --oneline`
|
||||
3. **Read PR template**: `.github/PULL_REQUEST_TEMPLATE.md`
|
||||
4. **Draft PR title**: Use conventional commits format (see CLAUDE.md for types and scopes)
|
||||
5. **Fill out PR template** as the body — be thorough in the Changes section
|
||||
6. **Format first** (if relevant changes exist):
|
||||
- Backend: `cd autogpt_platform/backend && poetry run format`
|
||||
- Frontend: `cd autogpt_platform/frontend && pnpm format`
|
||||
- Fix any lint errors, then commit formatting changes before pushing
|
||||
7. **Push**: `git push -u origin HEAD`
|
||||
8. **Create PR**: `gh pr create --base dev`
|
||||
9. **Output** the PR URL
|
||||
|
||||
## Rules
|
||||
|
||||
- Always target `dev` branch
|
||||
- Do NOT run tests — CI will handle that
|
||||
- Use the PR template from `.github/PULL_REQUEST_TEMPLATE.md`
|
||||
51
.claude/skills/pr-review/SKILL.md
Normal file
51
.claude/skills/pr-review/SKILL.md
Normal file
@@ -0,0 +1,51 @@
|
||||
---
|
||||
name: pr-review
|
||||
description: Address all open PR review comments systematically. Fetches comments, addresses each one, reacts +1/-1, and replies when clarification is needed. Keeps iterating until all comments are addressed and CI is green. TRIGGER when user shares a PR URL, asks to address review comments, fix PR feedback, or respond to reviewer comments.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# PR Review Comment Workflow
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Find PR**: `gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT`
|
||||
2. **Fetch comments** (all three sources):
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` (top-level reviews)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` (inline review comments)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` (PR conversation comments)
|
||||
3. **Skip** comments already reacted to by PR author
|
||||
4. **For each unreacted comment**:
|
||||
- Read referenced code, make the fix (or reply if you disagree/need info)
|
||||
- **Inline review comments** (`pulls/{N}/comments`):
|
||||
- React: `gh api repos/.../pulls/comments/{ID}/reactions -f content="+1"` (or `-1`)
|
||||
- Reply: `gh api repos/.../pulls/{N}/comments/{ID}/replies -f body="..."`
|
||||
- **PR conversation comments** (`issues/{N}/comments`):
|
||||
- React: `gh api repos/.../issues/comments/{ID}/reactions -f content="+1"` (or `-1`)
|
||||
- No threaded replies — post a new issue comment if needed
|
||||
- **Top-level reviews**: no reaction API — address in code, reply via issue comment if needed
|
||||
5. **Include autogpt-reviewer bot fixes** too
|
||||
6. **Format**: `cd autogpt_platform/backend && poetry run format`, `cd autogpt_platform/frontend && pnpm format`
|
||||
7. **Commit & push**
|
||||
8. **Re-fetch comments** immediately — address any new unreacted ones before waiting on CI
|
||||
9. **Stay productive while CI runs** — don't idle. In priority order:
|
||||
- Run any pending local tests (`poetry run pytest`, e2e, etc.) and fix failures
|
||||
- Address any remaining comments
|
||||
- Only poll `gh pr checks {N}` as the last resort when there's truly nothing left to do
|
||||
10. **If CI fails** — fix, go back to step 6
|
||||
11. **Re-fetch comments again** after CI is green — address anything that appeared while CI was running
|
||||
12. **Done** only when: all comments reacted AND CI is green.
|
||||
|
||||
## CRITICAL: Do Not Stop
|
||||
|
||||
**Loop is: address → format → commit → push → re-check comments → run local tests → wait CI → re-check comments → repeat.**
|
||||
|
||||
Never idle. If CI is running and you have nothing to address, run local tests. Waiting on CI is the last resort.
|
||||
|
||||
## Rules
|
||||
|
||||
- One todo per comment
|
||||
- For inline review comments: reply on existing threads. For PR conversation comments: post a new issue comment (API doesn't support threaded replies)
|
||||
- React to every comment: +1 addressed, -1 disagreed (with explanation)
|
||||
45
.claude/skills/worktree-setup/SKILL.md
Normal file
45
.claude/skills/worktree-setup/SKILL.md
Normal file
@@ -0,0 +1,45 @@
|
||||
---
|
||||
name: worktree-setup
|
||||
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, generates Prisma client, and optionally starts the app (with port conflict resolution) or runs tests. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Worktree Setup
|
||||
|
||||
## Preferred: Use Branchlet
|
||||
|
||||
The repo has a `.branchlet.json` config — it handles env file copying, dependency installation, and Prisma generation automatically.
|
||||
|
||||
```bash
|
||||
npm install -g branchlet # install once
|
||||
branchlet create -n <name> -s <source-branch> -b <new-branch>
|
||||
branchlet list --json # list all worktrees
|
||||
```
|
||||
|
||||
## Manual Fallback
|
||||
|
||||
If branchlet isn't available:
|
||||
|
||||
1. `git worktree add ../<RepoName><N> <branch-name>`
|
||||
2. Copy `.env` files: `backend/.env`, `frontend/.env`, `autogpt_platform/.env`, `db/docker/.env`
|
||||
3. Install deps:
|
||||
- `cd autogpt_platform/backend && poetry install && poetry run prisma generate`
|
||||
- `cd autogpt_platform/frontend && pnpm install`
|
||||
|
||||
## Running the App
|
||||
|
||||
Free ports first — backend uses: 8001, 8002, 8003, 8005, 8006, 8007, 8008.
|
||||
|
||||
```bash
|
||||
for port in 8001 8002 8003 8005 8006 8007 8008; do
|
||||
lsof -ti :$port | xargs kill -9 2>/dev/null || true
|
||||
done
|
||||
cd <worktree>/autogpt_platform/backend && poetry run app
|
||||
```
|
||||
|
||||
## CoPilot Testing Gotcha
|
||||
|
||||
SDK mode spawns a Claude subprocess — **won't work inside Claude Code**. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.
|
||||
40
autogpt_platform/analytics/queries/auth_activities.sql
Normal file
40
autogpt_platform/analytics/queries/auth_activities.sql
Normal file
@@ -0,0 +1,40 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.auth_activities
|
||||
-- Looker source alias: ds49 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Tracks authentication events (login, logout, SSO, password
|
||||
-- reset, etc.) from Supabase's internal audit log.
|
||||
-- Useful for monitoring sign-in patterns and detecting anomalies.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.audit_log_entries — Supabase internal auth event log
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- created_at TIMESTAMPTZ When the auth event occurred
|
||||
-- actor_id TEXT User ID who triggered the event
|
||||
-- actor_via_sso TEXT Whether the action was via SSO ('true'/'false')
|
||||
-- action TEXT Event type (e.g. 'login', 'logout', 'token_refreshed')
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days from current date
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Daily login counts
|
||||
-- SELECT DATE_TRUNC('day', created_at) AS day, COUNT(*) AS logins
|
||||
-- FROM analytics.auth_activities
|
||||
-- WHERE action = 'login'
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
--
|
||||
-- -- SSO vs password login breakdown
|
||||
-- SELECT actor_via_sso, COUNT(*) FROM analytics.auth_activities
|
||||
-- WHERE action = 'login' GROUP BY 1;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
created_at,
|
||||
payload->>'actor_id' AS actor_id,
|
||||
payload->>'actor_via_sso' AS actor_via_sso,
|
||||
payload->>'action' AS action
|
||||
FROM auth.audit_log_entries
|
||||
WHERE created_at >= NOW() - INTERVAL '90 days'
|
||||
105
autogpt_platform/analytics/queries/graph_execution.sql
Normal file
105
autogpt_platform/analytics/queries/graph_execution.sql
Normal file
@@ -0,0 +1,105 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.graph_execution
|
||||
-- Looker source alias: ds16 | Charts: 21
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per agent graph execution (last 90 days).
|
||||
-- Unpacks the JSONB stats column into individual numeric columns
|
||||
-- and normalises the executionStatus — runs that failed due to
|
||||
-- insufficient credits are reclassified as 'NO_CREDITS' for
|
||||
-- easier filtering. Error messages are scrubbed of IDs and URLs
|
||||
-- to allow safe grouping.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentGraphExecution — Execution records
|
||||
-- platform.AgentGraph — Agent graph metadata (for name)
|
||||
-- platform.LibraryAgent — To flag possibly-AI (safe-mode) agents
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- id TEXT Execution UUID
|
||||
-- agentGraphId TEXT Agent graph UUID
|
||||
-- agentGraphVersion INT Graph version number
|
||||
-- executionStatus TEXT COMPLETED | FAILED | NO_CREDITS | RUNNING | QUEUED | TERMINATED
|
||||
-- createdAt TIMESTAMPTZ When the execution was queued
|
||||
-- updatedAt TIMESTAMPTZ Last status update time
|
||||
-- userId TEXT Owner user UUID
|
||||
-- agentGraphName TEXT Human-readable agent name
|
||||
-- cputime DECIMAL Total CPU seconds consumed
|
||||
-- walltime DECIMAL Total wall-clock seconds
|
||||
-- node_count DECIMAL Number of nodes in the graph
|
||||
-- nodes_cputime DECIMAL CPU time across all nodes
|
||||
-- nodes_walltime DECIMAL Wall time across all nodes
|
||||
-- execution_cost DECIMAL Credit cost of this execution
|
||||
-- correctness_score FLOAT AI correctness score (if available)
|
||||
-- possibly_ai BOOLEAN True if agent has sensitive_action_safe_mode enabled
|
||||
-- groupedErrorMessage TEXT Scrubbed error string (IDs/URLs replaced with wildcards)
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Daily execution counts by status
|
||||
-- SELECT DATE_TRUNC('day', "createdAt") AS day, "executionStatus", COUNT(*)
|
||||
-- FROM analytics.graph_execution
|
||||
-- GROUP BY 1, 2 ORDER BY 1;
|
||||
--
|
||||
-- -- Average cost per execution by agent
|
||||
-- SELECT "agentGraphName", AVG("execution_cost") AS avg_cost, COUNT(*) AS runs
|
||||
-- FROM analytics.graph_execution
|
||||
-- WHERE "executionStatus" = 'COMPLETED'
|
||||
-- GROUP BY 1 ORDER BY avg_cost DESC;
|
||||
--
|
||||
-- -- Top error messages
|
||||
-- SELECT "groupedErrorMessage", COUNT(*) AS occurrences
|
||||
-- FROM analytics.graph_execution
|
||||
-- WHERE "executionStatus" = 'FAILED'
|
||||
-- GROUP BY 1 ORDER BY 2 DESC LIMIT 20;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
ge."id" AS id,
|
||||
ge."agentGraphId" AS agentGraphId,
|
||||
ge."agentGraphVersion" AS agentGraphVersion,
|
||||
CASE
|
||||
WHEN jsonb_exists(ge."stats"::jsonb, 'error')
|
||||
AND (
|
||||
(ge."stats"::jsonb->>'error') ILIKE '%insufficient balance%'
|
||||
OR (ge."stats"::jsonb->>'error') ILIKE '%you have no credits left%'
|
||||
)
|
||||
THEN 'NO_CREDITS'
|
||||
ELSE CAST(ge."executionStatus" AS TEXT)
|
||||
END AS executionStatus,
|
||||
ge."createdAt" AS createdAt,
|
||||
ge."updatedAt" AS updatedAt,
|
||||
ge."userId" AS userId,
|
||||
g."name" AS agentGraphName,
|
||||
(ge."stats"::jsonb->>'cputime')::decimal AS cputime,
|
||||
(ge."stats"::jsonb->>'walltime')::decimal AS walltime,
|
||||
(ge."stats"::jsonb->>'node_count')::decimal AS node_count,
|
||||
(ge."stats"::jsonb->>'nodes_cputime')::decimal AS nodes_cputime,
|
||||
(ge."stats"::jsonb->>'nodes_walltime')::decimal AS nodes_walltime,
|
||||
(ge."stats"::jsonb->>'cost')::decimal AS execution_cost,
|
||||
(ge."stats"::jsonb->>'correctness_score')::float AS correctness_score,
|
||||
COALESCE(la.possibly_ai, FALSE) AS possibly_ai,
|
||||
REGEXP_REPLACE(
|
||||
REGEXP_REPLACE(
|
||||
TRIM(BOTH '"' FROM ge."stats"::jsonb->>'error'),
|
||||
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
|
||||
'\1\2/...', 'gi'
|
||||
),
|
||||
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
|
||||
) AS groupedErrorMessage
|
||||
FROM platform."AgentGraphExecution" ge
|
||||
LEFT JOIN platform."AgentGraph" g
|
||||
ON ge."agentGraphId" = g."id"
|
||||
AND ge."agentGraphVersion" = g."version"
|
||||
LEFT JOIN (
|
||||
SELECT DISTINCT ON ("userId", "agentGraphId")
|
||||
"userId", "agentGraphId",
|
||||
("settings"::jsonb->>'sensitive_action_safe_mode')::boolean AS possibly_ai
|
||||
FROM platform."LibraryAgent"
|
||||
WHERE "isDeleted" = FALSE
|
||||
AND "isArchived" = FALSE
|
||||
ORDER BY "userId", "agentGraphId", "agentGraphVersion" DESC
|
||||
) la ON la."userId" = ge."userId" AND la."agentGraphId" = ge."agentGraphId"
|
||||
WHERE ge."createdAt" > CURRENT_DATE - INTERVAL '90 days'
|
||||
101
autogpt_platform/analytics/queries/node_block_execution.sql
Normal file
101
autogpt_platform/analytics/queries/node_block_execution.sql
Normal file
@@ -0,0 +1,101 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.node_block_execution
|
||||
-- Looker source alias: ds14 | Charts: 11
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per node (block) execution (last 90 days).
|
||||
-- Unpacks stats JSONB and joins to identify which block type
|
||||
-- was run. For failed nodes, joins the error output and
|
||||
-- scrubs it for safe grouping.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentNodeExecution — Node execution records
|
||||
-- platform.AgentNode — Node → block mapping
|
||||
-- platform.AgentBlock — Block name/ID
|
||||
-- platform.AgentNodeExecutionInputOutput — Error output values
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- id TEXT Node execution UUID
|
||||
-- agentGraphExecutionId TEXT Parent graph execution UUID
|
||||
-- agentNodeId TEXT Node UUID within the graph
|
||||
-- executionStatus TEXT COMPLETED | FAILED | QUEUED | RUNNING | TERMINATED
|
||||
-- addedTime TIMESTAMPTZ When the node was queued
|
||||
-- queuedTime TIMESTAMPTZ When it entered the queue
|
||||
-- startedTime TIMESTAMPTZ When execution started
|
||||
-- endedTime TIMESTAMPTZ When execution finished
|
||||
-- inputSize BIGINT Input payload size in bytes
|
||||
-- outputSize BIGINT Output payload size in bytes
|
||||
-- walltime NUMERIC Wall-clock seconds for this node
|
||||
-- cputime NUMERIC CPU seconds for this node
|
||||
-- llmRetryCount INT Number of LLM retries
|
||||
-- llmCallCount INT Number of LLM API calls made
|
||||
-- inputTokenCount BIGINT LLM input tokens consumed
|
||||
-- outputTokenCount BIGINT LLM output tokens produced
|
||||
-- blockName TEXT Human-readable block name (e.g. 'OpenAIBlock')
|
||||
-- blockId TEXT Block UUID
|
||||
-- groupedErrorMessage TEXT Scrubbed error (IDs/URLs wildcarded)
|
||||
-- errorMessage TEXT Raw error output (only set when FAILED)
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days (addedTime > CURRENT_DATE - 90 days)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Most-used blocks by execution count
|
||||
-- SELECT "blockName", COUNT(*) AS executions,
|
||||
-- COUNT(*) FILTER (WHERE "executionStatus"='FAILED') AS failures
|
||||
-- FROM analytics.node_block_execution
|
||||
-- GROUP BY 1 ORDER BY executions DESC LIMIT 20;
|
||||
--
|
||||
-- -- Average LLM token usage per block
|
||||
-- SELECT "blockName",
|
||||
-- AVG("inputTokenCount") AS avg_input_tokens,
|
||||
-- AVG("outputTokenCount") AS avg_output_tokens
|
||||
-- FROM analytics.node_block_execution
|
||||
-- WHERE "llmCallCount" > 0
|
||||
-- GROUP BY 1 ORDER BY avg_input_tokens DESC;
|
||||
--
|
||||
-- -- Top failure reasons
|
||||
-- SELECT "blockName", "groupedErrorMessage", COUNT(*) AS count
|
||||
-- FROM analytics.node_block_execution
|
||||
-- WHERE "executionStatus" = 'FAILED'
|
||||
-- GROUP BY 1, 2 ORDER BY count DESC LIMIT 20;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
ne."id" AS id,
|
||||
ne."agentGraphExecutionId" AS agentGraphExecutionId,
|
||||
ne."agentNodeId" AS agentNodeId,
|
||||
CAST(ne."executionStatus" AS TEXT) AS executionStatus,
|
||||
ne."addedTime" AS addedTime,
|
||||
ne."queuedTime" AS queuedTime,
|
||||
ne."startedTime" AS startedTime,
|
||||
ne."endedTime" AS endedTime,
|
||||
(ne."stats"::jsonb->>'input_size')::bigint AS inputSize,
|
||||
(ne."stats"::jsonb->>'output_size')::bigint AS outputSize,
|
||||
(ne."stats"::jsonb->>'walltime')::numeric AS walltime,
|
||||
(ne."stats"::jsonb->>'cputime')::numeric AS cputime,
|
||||
(ne."stats"::jsonb->>'llm_retry_count')::int AS llmRetryCount,
|
||||
(ne."stats"::jsonb->>'llm_call_count')::int AS llmCallCount,
|
||||
(ne."stats"::jsonb->>'input_token_count')::bigint AS inputTokenCount,
|
||||
(ne."stats"::jsonb->>'output_token_count')::bigint AS outputTokenCount,
|
||||
b."name" AS blockName,
|
||||
b."id" AS blockId,
|
||||
REGEXP_REPLACE(
|
||||
REGEXP_REPLACE(
|
||||
TRIM(BOTH '"' FROM eio."data"::text),
|
||||
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
|
||||
'\1\2/...', 'gi'
|
||||
),
|
||||
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
|
||||
) AS groupedErrorMessage,
|
||||
eio."data" AS errorMessage
|
||||
FROM platform."AgentNodeExecution" ne
|
||||
LEFT JOIN platform."AgentNode" nd
|
||||
ON ne."agentNodeId" = nd."id"
|
||||
LEFT JOIN platform."AgentBlock" b
|
||||
ON nd."agentBlockId" = b."id"
|
||||
LEFT JOIN platform."AgentNodeExecutionInputOutput" eio
|
||||
ON eio."referencedByOutputExecId" = ne."id"
|
||||
AND eio."name" = 'error'
|
||||
AND ne."executionStatus" = 'FAILED'
|
||||
WHERE ne."addedTime" > CURRENT_DATE - INTERVAL '90 days'
|
||||
97
autogpt_platform/analytics/queries/retention_agent.sql
Normal file
97
autogpt_platform/analytics/queries/retention_agent.sql
Normal file
@@ -0,0 +1,97 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_agent
|
||||
-- Looker source alias: ds35 | Charts: 2
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Weekly cohort retention broken down per individual agent.
|
||||
-- Cohort = week of a user's first use of THAT specific agent.
|
||||
-- Tells you which agents keep users coming back vs. one-shot
|
||||
-- use. Only includes cohorts from the last 180 days.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentGraphExecution — Execution records (user × agent × time)
|
||||
-- platform.AgentGraph — Agent names
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- agent_id TEXT Agent graph UUID
|
||||
-- agent_label TEXT 'AgentName [first8chars]'
|
||||
-- agent_label_n TEXT 'AgentName [first8chars] (n=total_users)'
|
||||
-- cohort_week_start DATE Week users first ran this agent
|
||||
-- cohort_label TEXT ISO week label
|
||||
-- cohort_label_n TEXT ISO week label with cohort size
|
||||
-- user_lifetime_week INT Weeks since first use of this agent
|
||||
-- cohort_users BIGINT Users in this cohort for this agent
|
||||
-- active_users BIGINT Users who ran the agent again in week k
|
||||
-- retention_rate FLOAT active_users / cohort_users
|
||||
-- cohort_users_w0 BIGINT cohort_users only at week 0 (safe to SUM)
|
||||
-- agent_total_users BIGINT Total users across all cohorts for this agent
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Best-retained agents at week 2
|
||||
-- SELECT agent_label, AVG(retention_rate) AS w2_retention
|
||||
-- FROM analytics.retention_agent
|
||||
-- WHERE user_lifetime_week = 2 AND cohort_users >= 10
|
||||
-- GROUP BY 1 ORDER BY w2_retention DESC LIMIT 10;
|
||||
--
|
||||
-- -- Agents with most unique users
|
||||
-- SELECT DISTINCT agent_label, agent_total_users
|
||||
-- FROM analytics.retention_agent
|
||||
-- ORDER BY agent_total_users DESC LIMIT 20;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
|
||||
events AS (
|
||||
SELECT e."userId"::text AS user_id, e."agentGraphId" AS agent_id,
|
||||
e."createdAt"::timestamptz AS created_at,
|
||||
DATE_TRUNC('week', e."createdAt")::date AS week_start
|
||||
FROM platform."AgentGraphExecution" e
|
||||
),
|
||||
first_use AS (
|
||||
SELECT user_id, agent_id, MIN(created_at) AS first_use_at,
|
||||
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||
FROM events GROUP BY 1,2
|
||||
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||
),
|
||||
activity_weeks AS (SELECT DISTINCT user_id, agent_id, week_start FROM events),
|
||||
user_week_age AS (
|
||||
SELECT aw.user_id, aw.agent_id, fu.cohort_week_start,
|
||||
((aw.week_start - DATE_TRUNC('week',fu.first_use_at)::date)/7)::int AS user_lifetime_week
|
||||
FROM activity_weeks aw JOIN first_use fu USING (user_id, agent_id)
|
||||
WHERE aw.week_start >= DATE_TRUNC('week',fu.first_use_at)::date
|
||||
),
|
||||
active_counts AS (
|
||||
SELECT agent_id, cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users
|
||||
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2,3
|
||||
),
|
||||
cohort_sizes AS (
|
||||
SELECT agent_id, cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_use GROUP BY 1,2
|
||||
),
|
||||
cohort_caps AS (
|
||||
SELECT cs.agent_id, cs.cohort_week_start, cs.cohort_users,
|
||||
LEAST((SELECT max_weeks FROM params),
|
||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.agent_id, cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||
),
|
||||
agent_names AS (SELECT DISTINCT ON (g."id") g."id" AS agent_id, g."name" AS agent_name FROM platform."AgentGraph" g ORDER BY g."id", g."version" DESC),
|
||||
agent_total_users AS (SELECT agent_id, SUM(cohort_users) AS agent_total_users FROM cohort_sizes GROUP BY 1)
|
||||
SELECT
|
||||
g.agent_id,
|
||||
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||']' AS agent_label,
|
||||
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||'] (n='||COALESCE(atu.agent_total_users,0)||')' AS agent_label_n,
|
||||
g.cohort_week_start,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_week, g.cohort_users,
|
||||
COALESCE(ac.active_users,0) AS active_users,
|
||||
COALESCE(ac.active_users,0)::float / NULLIF(g.cohort_users,0) AS retention_rate,
|
||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0,
|
||||
COALESCE(atu.agent_total_users,0) AS agent_total_users
|
||||
FROM grid g
|
||||
LEFT JOIN active_counts ac ON ac.agent_id=g.agent_id AND ac.cohort_week_start=g.cohort_week_start AND ac.user_lifetime_week=g.user_lifetime_week
|
||||
LEFT JOIN agent_names an ON an.agent_id=g.agent_id
|
||||
LEFT JOIN agent_total_users atu ON atu.agent_id=g.agent_id
|
||||
ORDER BY agent_label, g.cohort_week_start, g.user_lifetime_week;
|
||||
@@ -0,0 +1,81 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_execution_daily
|
||||
-- Looker source alias: ds111 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Daily cohort retention based on agent executions.
|
||||
-- Cohort anchor = day of user's FIRST ever execution.
|
||||
-- Only includes cohorts from the last 90 days, up to day 30.
|
||||
-- Great for early engagement analysis (did users run another
|
||||
-- agent the next day?).
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentGraphExecution — Execution records
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- Same pattern as retention_login_daily.
|
||||
-- cohort_day_start = day of first execution (not first login)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Day-3 execution retention
|
||||
-- SELECT cohort_label, retention_rate_bounded AS d3_retention
|
||||
-- FROM analytics.retention_execution_daily
|
||||
-- WHERE user_lifetime_day = 3 ORDER BY cohort_day_start;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days') AS cohort_start),
|
||||
events AS (
|
||||
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
|
||||
DATE_TRUNC('day', e."createdAt")::date AS day_start
|
||||
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
|
||||
),
|
||||
first_exec AS (
|
||||
SELECT user_id, MIN(created_at) AS first_exec_at,
|
||||
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
|
||||
FROM events GROUP BY 1
|
||||
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||
),
|
||||
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
|
||||
user_day_age AS (
|
||||
SELECT ad.user_id, fe.cohort_day_start,
|
||||
(ad.day_start - DATE_TRUNC('day',fe.first_exec_at)::date)::int AS user_lifetime_day
|
||||
FROM activity_days ad JOIN first_exec fe USING (user_id)
|
||||
WHERE ad.day_start >= DATE_TRUNC('day',fe.first_exec_at)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_day_start, cs.cohort_users,
|
||||
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_day_start,
|
||||
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
|
||||
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_day, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
|
||||
ORDER BY g.cohort_day_start, g.user_lifetime_day;
|
||||
@@ -0,0 +1,81 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_execution_weekly
|
||||
-- Looker source alias: ds92 | Charts: 2
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Weekly cohort retention based on agent executions.
|
||||
-- Cohort anchor = week of user's FIRST ever agent execution
|
||||
-- (not first login). Only includes cohorts from the last 180 days.
|
||||
-- Useful when you care about product engagement, not just visits.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentGraphExecution — Execution records
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- Same pattern as retention_login_weekly.
|
||||
-- cohort_week_start = week of first execution (not first login)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Week-2 execution retention
|
||||
-- SELECT cohort_label, retention_rate_bounded
|
||||
-- FROM analytics.retention_execution_weekly
|
||||
-- WHERE user_lifetime_week = 2 ORDER BY cohort_week_start;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
|
||||
events AS (
|
||||
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
|
||||
DATE_TRUNC('week', e."createdAt")::date AS week_start
|
||||
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
|
||||
),
|
||||
first_exec AS (
|
||||
SELECT user_id, MIN(created_at) AS first_exec_at,
|
||||
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||
FROM events GROUP BY 1
|
||||
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||
),
|
||||
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||
user_week_age AS (
|
||||
SELECT aw.user_id, fe.cohort_week_start,
|
||||
((aw.week_start - DATE_TRUNC('week',fe.first_exec_at)::date)/7)::int AS user_lifetime_week
|
||||
FROM activity_weeks aw JOIN first_exec fe USING (user_id)
|
||||
WHERE aw.week_start >= DATE_TRUNC('week',fe.first_exec_at)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||
LEAST((SELECT max_weeks FROM params),
|
||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_week_start,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_week, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||
ORDER BY g.cohort_week_start, g.user_lifetime_week;
|
||||
94
autogpt_platform/analytics/queries/retention_login_daily.sql
Normal file
94
autogpt_platform/analytics/queries/retention_login_daily.sql
Normal file
@@ -0,0 +1,94 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_login_daily
|
||||
-- Looker source alias: ds112 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Daily cohort retention based on login sessions.
|
||||
-- Same logic as retention_login_weekly but at day granularity,
|
||||
-- showing up to day 30 for cohorts from the last 90 days.
|
||||
-- Useful for analysing early activation (days 1-7) in detail.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.sessions — Login session records
|
||||
--
|
||||
-- OUTPUT COLUMNS (same pattern as retention_login_weekly)
|
||||
-- cohort_day_start DATE First day the cohort logged in
|
||||
-- cohort_label TEXT Date string (e.g. '2025-03-01')
|
||||
-- cohort_label_n TEXT Date + cohort size (e.g. '2025-03-01 (n=12)')
|
||||
-- user_lifetime_day INT Days since first login (0 = signup day)
|
||||
-- cohort_users BIGINT Total users in cohort
|
||||
-- active_users_bounded BIGINT Users active on exactly day k
|
||||
-- retained_users_unbounded BIGINT Users active any time on/after day k
|
||||
-- retention_rate_bounded FLOAT bounded / cohort_users
|
||||
-- retention_rate_unbounded FLOAT unbounded / cohort_users
|
||||
-- cohort_users_d0 BIGINT cohort_users only at day 0, else 0 (safe to SUM)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Day-1 retention rate (came back next day)
|
||||
-- SELECT cohort_label, retention_rate_bounded AS d1_retention
|
||||
-- FROM analytics.retention_login_daily
|
||||
-- WHERE user_lifetime_day = 1 ORDER BY cohort_day_start;
|
||||
--
|
||||
-- -- Average retention curve across all cohorts
|
||||
-- SELECT user_lifetime_day,
|
||||
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_d0), 0) AS avg_retention
|
||||
-- FROM analytics.retention_login_daily
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days')::date AS cohort_start),
|
||||
events AS (
|
||||
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||
DATE_TRUNC('day', s.created_at)::date AS day_start
|
||||
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||
),
|
||||
first_login AS (
|
||||
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
|
||||
FROM events GROUP BY 1
|
||||
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||
),
|
||||
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
|
||||
user_day_age AS (
|
||||
SELECT ad.user_id, fl.cohort_day_start,
|
||||
(ad.day_start - DATE_TRUNC('day', fl.first_login_time)::date)::int AS user_lifetime_day
|
||||
FROM activity_days ad JOIN first_login fl USING (user_id)
|
||||
WHERE ad.day_start >= DATE_TRUNC('day', fl.first_login_time)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_day_start, cs.cohort_users,
|
||||
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_day_start,
|
||||
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
|
||||
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_day, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
|
||||
ORDER BY g.cohort_day_start, g.user_lifetime_day;
|
||||
@@ -0,0 +1,96 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_login_onboarded_weekly
|
||||
-- Looker source alias: ds101 | Charts: 2
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Weekly cohort retention from login sessions, restricted to
|
||||
-- users who "onboarded" — defined as running at least one
|
||||
-- agent within 365 days of their first login.
|
||||
-- Filters out users who signed up but never activated,
|
||||
-- giving a cleaner view of engaged-user retention.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.sessions — Login session records
|
||||
-- platform.AgentGraphExecution — Used to identify onboarders
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- Same as retention_login_weekly (cohort_week_start, user_lifetime_week,
|
||||
-- retention_rate_bounded, retention_rate_unbounded, etc.)
|
||||
-- Only difference: cohort is filtered to onboarded users only.
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Compare week-4 retention: all users vs onboarded only
|
||||
-- SELECT 'all_users' AS segment, AVG(retention_rate_bounded) AS w4_retention
|
||||
-- FROM analytics.retention_login_weekly WHERE user_lifetime_week = 4
|
||||
-- UNION ALL
|
||||
-- SELECT 'onboarded', AVG(retention_rate_bounded)
|
||||
-- FROM analytics.retention_login_onboarded_weekly WHERE user_lifetime_week = 4;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 12::int AS max_weeks, 365::int AS onboarding_window_days),
|
||||
events AS (
|
||||
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||
DATE_TRUNC('week', s.created_at)::date AS week_start
|
||||
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||
),
|
||||
first_login_all AS (
|
||||
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||
FROM events GROUP BY 1
|
||||
),
|
||||
onboarders AS (
|
||||
SELECT fl.user_id FROM first_login_all fl
|
||||
WHERE EXISTS (
|
||||
SELECT 1 FROM platform."AgentGraphExecution" e
|
||||
WHERE e."userId"::text = fl.user_id
|
||||
AND e."createdAt" >= fl.first_login_time
|
||||
AND e."createdAt" < fl.first_login_time
|
||||
+ make_interval(days => (SELECT onboarding_window_days FROM params))
|
||||
)
|
||||
),
|
||||
first_login AS (SELECT * FROM first_login_all WHERE user_id IN (SELECT user_id FROM onboarders)),
|
||||
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||
user_week_age AS (
|
||||
SELECT aw.user_id, fl.cohort_week_start,
|
||||
((aw.week_start - DATE_TRUNC('week',fl.first_login_time)::date)/7)::int AS user_lifetime_week
|
||||
FROM activity_weeks aw JOIN first_login fl USING (user_id)
|
||||
WHERE aw.week_start >= DATE_TRUNC('week',fl.first_login_time)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||
LEAST((SELECT max_weeks FROM params),
|
||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_week_start,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_week, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||
ORDER BY g.cohort_week_start, g.user_lifetime_week;
|
||||
103
autogpt_platform/analytics/queries/retention_login_weekly.sql
Normal file
103
autogpt_platform/analytics/queries/retention_login_weekly.sql
Normal file
@@ -0,0 +1,103 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_login_weekly
|
||||
-- Looker source alias: ds83 | Charts: 2
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Weekly cohort retention based on login sessions.
|
||||
-- Users are grouped by the ISO week of their first ever login.
|
||||
-- For each cohort × lifetime-week combination, outputs both:
|
||||
-- - bounded rate: % active in exactly that week
|
||||
-- - unbounded rate: % who were ever active on or after that week
|
||||
-- Weeks are capped to the cohort's actual age (no future data points).
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.sessions — Login session records
|
||||
--
|
||||
-- HOW TO READ THE OUTPUT
|
||||
-- cohort_week_start The Monday of the week users first logged in
|
||||
-- user_lifetime_week 0 = signup week, 1 = one week later, etc.
|
||||
-- retention_rate_bounded = active_users_bounded / cohort_users
|
||||
-- retention_rate_unbounded = retained_users_unbounded / cohort_users
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- cohort_week_start DATE First day of the cohort's signup week
|
||||
-- cohort_label TEXT ISO week label (e.g. '2025-W01')
|
||||
-- cohort_label_n TEXT ISO week label with cohort size (e.g. '2025-W01 (n=42)')
|
||||
-- user_lifetime_week INT Weeks since first login (0 = signup week)
|
||||
-- cohort_users BIGINT Total users in this cohort (denominator)
|
||||
-- active_users_bounded BIGINT Users active in exactly week k
|
||||
-- retained_users_unbounded BIGINT Users active any time on/after week k
|
||||
-- retention_rate_bounded FLOAT bounded active / cohort_users
|
||||
-- retention_rate_unbounded FLOAT unbounded retained / cohort_users
|
||||
-- cohort_users_w0 BIGINT cohort_users only at week 0, else 0 (safe to SUM in pivot tables)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Week-1 retention rate per cohort
|
||||
-- SELECT cohort_label, retention_rate_bounded AS w1_retention
|
||||
-- FROM analytics.retention_login_weekly
|
||||
-- WHERE user_lifetime_week = 1
|
||||
-- ORDER BY cohort_week_start;
|
||||
--
|
||||
-- -- Overall average retention curve (all cohorts combined)
|
||||
-- SELECT user_lifetime_week,
|
||||
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_w0), 0) AS avg_retention
|
||||
-- FROM analytics.retention_login_weekly
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 12::int AS max_weeks),
|
||||
events AS (
|
||||
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||
DATE_TRUNC('week', s.created_at)::date AS week_start
|
||||
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||
),
|
||||
first_login AS (
|
||||
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||
FROM events GROUP BY 1
|
||||
),
|
||||
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||
user_week_age AS (
|
||||
SELECT aw.user_id, fl.cohort_week_start,
|
||||
((aw.week_start - DATE_TRUNC('week', fl.first_login_time)::date) / 7)::int AS user_lifetime_week
|
||||
FROM activity_weeks aw JOIN first_login fl USING (user_id)
|
||||
WHERE aw.week_start >= DATE_TRUNC('week', fl.first_login_time)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||
LEAST((SELECT max_weeks FROM params),
|
||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date - cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_week_start,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_week, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||
ORDER BY g.cohort_week_start, g.user_lifetime_week
|
||||
71
autogpt_platform/analytics/queries/user_block_spending.sql
Normal file
71
autogpt_platform/analytics/queries/user_block_spending.sql
Normal file
@@ -0,0 +1,71 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.user_block_spending
|
||||
-- Looker source alias: ds6 | Charts: 5
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per credit transaction (last 90 days).
|
||||
-- Shows how users spend credits broken down by block type,
|
||||
-- LLM provider and model. Joins node execution stats for
|
||||
-- token-level detail.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.CreditTransaction — Credit debit/credit records
|
||||
-- platform.AgentNodeExecution — Node execution stats (for token counts)
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- transactionKey TEXT Unique transaction identifier
|
||||
-- userId TEXT User who was charged
|
||||
-- amount DECIMAL Credit amount (positive = credit, negative = debit)
|
||||
-- negativeAmount DECIMAL amount * -1 (convenience for spend charts)
|
||||
-- transactionType TEXT Transaction type (e.g. 'USAGE', 'REFUND', 'TOP_UP')
|
||||
-- transactionTime TIMESTAMPTZ When the transaction was recorded
|
||||
-- blockId TEXT Block UUID that triggered the spend
|
||||
-- blockName TEXT Human-readable block name
|
||||
-- llm_provider TEXT LLM provider (e.g. 'openai', 'anthropic')
|
||||
-- llm_model TEXT Model name (e.g. 'gpt-4o', 'claude-3-5-sonnet')
|
||||
-- node_exec_id TEXT Linked node execution UUID
|
||||
-- llm_call_count INT LLM API calls made in that execution
|
||||
-- llm_retry_count INT LLM retries in that execution
|
||||
-- llm_input_token_count INT Input tokens consumed
|
||||
-- llm_output_token_count INT Output tokens produced
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Total spend per user (last 90 days)
|
||||
-- SELECT "userId", SUM("negativeAmount") AS total_spent
|
||||
-- FROM analytics.user_block_spending
|
||||
-- WHERE "transactionType" = 'USAGE'
|
||||
-- GROUP BY 1 ORDER BY total_spent DESC;
|
||||
--
|
||||
-- -- Spend by LLM provider + model
|
||||
-- SELECT "llm_provider", "llm_model",
|
||||
-- SUM("negativeAmount") AS total_cost,
|
||||
-- SUM("llm_input_token_count") AS input_tokens,
|
||||
-- SUM("llm_output_token_count") AS output_tokens
|
||||
-- FROM analytics.user_block_spending
|
||||
-- WHERE "llm_provider" IS NOT NULL
|
||||
-- GROUP BY 1, 2 ORDER BY total_cost DESC;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
c."transactionKey" AS transactionKey,
|
||||
c."userId" AS userId,
|
||||
c."amount" AS amount,
|
||||
c."amount" * -1 AS negativeAmount,
|
||||
c."type" AS transactionType,
|
||||
c."createdAt" AS transactionTime,
|
||||
c.metadata->>'block_id' AS blockId,
|
||||
c.metadata->>'block' AS blockName,
|
||||
c.metadata->'input'->'credentials'->>'provider' AS llm_provider,
|
||||
c.metadata->'input'->>'model' AS llm_model,
|
||||
c.metadata->>'node_exec_id' AS node_exec_id,
|
||||
(ne."stats"->>'llm_call_count')::int AS llm_call_count,
|
||||
(ne."stats"->>'llm_retry_count')::int AS llm_retry_count,
|
||||
(ne."stats"->>'input_token_count')::int AS llm_input_token_count,
|
||||
(ne."stats"->>'output_token_count')::int AS llm_output_token_count
|
||||
FROM platform."CreditTransaction" c
|
||||
LEFT JOIN platform."AgentNodeExecution" ne
|
||||
ON (c.metadata->>'node_exec_id') = ne."id"::text
|
||||
WHERE c."createdAt" > CURRENT_DATE - INTERVAL '90 days'
|
||||
45
autogpt_platform/analytics/queries/user_onboarding.sql
Normal file
45
autogpt_platform/analytics/queries/user_onboarding.sql
Normal file
@@ -0,0 +1,45 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.user_onboarding
|
||||
-- Looker source alias: ds68 | Charts: 3
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per user onboarding record. Contains the user's
|
||||
-- stated usage reason, selected integrations, completed
|
||||
-- onboarding steps and optional first agent selection.
|
||||
-- Full history (no date filter) since onboarding happens
|
||||
-- once per user.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.UserOnboarding — Onboarding state per user
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- id TEXT Onboarding record UUID
|
||||
-- createdAt TIMESTAMPTZ When onboarding started
|
||||
-- updatedAt TIMESTAMPTZ Last update to onboarding state
|
||||
-- usageReason TEXT Why user signed up (e.g. 'work', 'personal')
|
||||
-- integrations TEXT[] Array of integration names the user selected
|
||||
-- userId TEXT User UUID
|
||||
-- completedSteps TEXT[] Array of onboarding step enums completed
|
||||
-- selectedStoreListingVersionId TEXT First marketplace agent the user chose (if any)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Usage reason breakdown
|
||||
-- SELECT "usageReason", COUNT(*) FROM analytics.user_onboarding GROUP BY 1;
|
||||
--
|
||||
-- -- Completion rate per step
|
||||
-- SELECT step, COUNT(*) AS users_completed
|
||||
-- FROM analytics.user_onboarding
|
||||
-- CROSS JOIN LATERAL UNNEST("completedSteps") AS step
|
||||
-- GROUP BY 1 ORDER BY users_completed DESC;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
id,
|
||||
"createdAt",
|
||||
"updatedAt",
|
||||
"usageReason",
|
||||
integrations,
|
||||
"userId",
|
||||
"completedSteps",
|
||||
"selectedStoreListingVersionId"
|
||||
FROM platform."UserOnboarding"
|
||||
100
autogpt_platform/analytics/queries/user_onboarding_funnel.sql
Normal file
100
autogpt_platform/analytics/queries/user_onboarding_funnel.sql
Normal file
@@ -0,0 +1,100 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.user_onboarding_funnel
|
||||
-- Looker source alias: ds74 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Pre-aggregated onboarding funnel showing how many users
|
||||
-- completed each step and the drop-off percentage from the
|
||||
-- previous step. One row per onboarding step (all 22 steps
|
||||
-- always present, even with 0 completions — prevents sparse
|
||||
-- gaps from making LAG compare the wrong predecessors).
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.UserOnboarding — Onboarding records with completedSteps array
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- step TEXT Onboarding step enum name (e.g. 'WELCOME', 'CONGRATS')
|
||||
-- step_order INT Numeric position in the funnel (1=first, 22=last)
|
||||
-- users_completed BIGINT Distinct users who completed this step
|
||||
-- pct_from_prev NUMERIC % of users from the previous step who reached this one
|
||||
--
|
||||
-- STEP ORDER
|
||||
-- 1 WELCOME 9 MARKETPLACE_VISIT 17 SCHEDULE_AGENT
|
||||
-- 2 USAGE_REASON 10 MARKETPLACE_ADD_AGENT 18 RUN_AGENTS
|
||||
-- 3 INTEGRATIONS 11 MARKETPLACE_RUN_AGENT 19 RUN_3_DAYS
|
||||
-- 4 AGENT_CHOICE 12 BUILDER_OPEN 20 TRIGGER_WEBHOOK
|
||||
-- 5 AGENT_NEW_RUN 13 BUILDER_SAVE_AGENT 21 RUN_14_DAYS
|
||||
-- 6 AGENT_INPUT 14 BUILDER_RUN_AGENT 22 RUN_AGENTS_100
|
||||
-- 7 CONGRATS 15 VISIT_COPILOT
|
||||
-- 8 GET_RESULTS 16 RE_RUN_AGENT
|
||||
--
|
||||
-- WINDOW
|
||||
-- Users who started onboarding in the last 90 days
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Full funnel
|
||||
-- SELECT * FROM analytics.user_onboarding_funnel ORDER BY step_order;
|
||||
--
|
||||
-- -- Biggest drop-off point
|
||||
-- SELECT step, pct_from_prev FROM analytics.user_onboarding_funnel
|
||||
-- ORDER BY pct_from_prev ASC LIMIT 3;
|
||||
-- =============================================================
|
||||
|
||||
WITH all_steps AS (
|
||||
-- Complete ordered grid of all 22 steps so zero-completion steps
|
||||
-- are always present, keeping LAG comparisons correct.
|
||||
SELECT step_name, step_order
|
||||
FROM (VALUES
|
||||
('WELCOME', 1),
|
||||
('USAGE_REASON', 2),
|
||||
('INTEGRATIONS', 3),
|
||||
('AGENT_CHOICE', 4),
|
||||
('AGENT_NEW_RUN', 5),
|
||||
('AGENT_INPUT', 6),
|
||||
('CONGRATS', 7),
|
||||
('GET_RESULTS', 8),
|
||||
('MARKETPLACE_VISIT', 9),
|
||||
('MARKETPLACE_ADD_AGENT', 10),
|
||||
('MARKETPLACE_RUN_AGENT', 11),
|
||||
('BUILDER_OPEN', 12),
|
||||
('BUILDER_SAVE_AGENT', 13),
|
||||
('BUILDER_RUN_AGENT', 14),
|
||||
('VISIT_COPILOT', 15),
|
||||
('RE_RUN_AGENT', 16),
|
||||
('SCHEDULE_AGENT', 17),
|
||||
('RUN_AGENTS', 18),
|
||||
('RUN_3_DAYS', 19),
|
||||
('TRIGGER_WEBHOOK', 20),
|
||||
('RUN_14_DAYS', 21),
|
||||
('RUN_AGENTS_100', 22)
|
||||
) AS t(step_name, step_order)
|
||||
),
|
||||
raw AS (
|
||||
SELECT
|
||||
u."userId",
|
||||
step_txt::text AS step
|
||||
FROM platform."UserOnboarding" u
|
||||
CROSS JOIN LATERAL UNNEST(u."completedSteps") AS step_txt
|
||||
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
|
||||
),
|
||||
step_counts AS (
|
||||
SELECT step, COUNT(DISTINCT "userId") AS users_completed
|
||||
FROM raw GROUP BY step
|
||||
),
|
||||
funnel AS (
|
||||
SELECT
|
||||
a.step_name AS step,
|
||||
a.step_order,
|
||||
COALESCE(sc.users_completed, 0) AS users_completed,
|
||||
ROUND(
|
||||
100.0 * COALESCE(sc.users_completed, 0)
|
||||
/ NULLIF(
|
||||
LAG(COALESCE(sc.users_completed, 0)) OVER (ORDER BY a.step_order),
|
||||
0
|
||||
),
|
||||
2
|
||||
) AS pct_from_prev
|
||||
FROM all_steps a
|
||||
LEFT JOIN step_counts sc ON sc.step = a.step_name
|
||||
)
|
||||
SELECT * FROM funnel ORDER BY step_order
|
||||
@@ -0,0 +1,41 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.user_onboarding_integration
|
||||
-- Looker source alias: ds75 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Pre-aggregated count of users who selected each integration
|
||||
-- during onboarding. One row per integration type, sorted
|
||||
-- by popularity.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.UserOnboarding — integrations array column
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- integration TEXT Integration name (e.g. 'github', 'slack', 'notion')
|
||||
-- users_with_integration BIGINT Distinct users who selected this integration
|
||||
--
|
||||
-- WINDOW
|
||||
-- Users who started onboarding in the last 90 days
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Full integration popularity ranking
|
||||
-- SELECT * FROM analytics.user_onboarding_integration;
|
||||
--
|
||||
-- -- Top 5 integrations
|
||||
-- SELECT * FROM analytics.user_onboarding_integration LIMIT 5;
|
||||
-- =============================================================
|
||||
|
||||
WITH exploded AS (
|
||||
SELECT
|
||||
u."userId" AS user_id,
|
||||
UNNEST(u."integrations") AS integration
|
||||
FROM platform."UserOnboarding" u
|
||||
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
|
||||
)
|
||||
SELECT
|
||||
integration,
|
||||
COUNT(DISTINCT user_id) AS users_with_integration
|
||||
FROM exploded
|
||||
WHERE integration IS NOT NULL AND integration <> ''
|
||||
GROUP BY integration
|
||||
ORDER BY users_with_integration DESC
|
||||
145
autogpt_platform/analytics/queries/users_activities.sql
Normal file
145
autogpt_platform/analytics/queries/users_activities.sql
Normal file
@@ -0,0 +1,145 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.users_activities
|
||||
-- Looker source alias: ds56 | Charts: 5
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per user with lifetime activity summary.
|
||||
-- Joins login sessions with agent graphs, executions and
|
||||
-- node-level runs to give a full picture of how engaged
|
||||
-- each user is. Includes a convenience flag for 7-day
|
||||
-- activation (did the user return at least 7 days after
|
||||
-- their first login?).
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.sessions — Login/session records
|
||||
-- platform.AgentGraph — Graphs (agents) built by the user
|
||||
-- platform.AgentGraphExecution — Agent run history
|
||||
-- platform.AgentNodeExecution — Individual block execution history
|
||||
--
|
||||
-- PERFORMANCE NOTE
|
||||
-- Each CTE aggregates its own table independently by userId.
|
||||
-- This avoids the fan-out that occurs when driving every join
|
||||
-- from user_logins across the two largest tables
|
||||
-- (AgentGraphExecution and AgentNodeExecution).
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- user_id TEXT Supabase user UUID
|
||||
-- first_login_time TIMESTAMPTZ First ever session created_at
|
||||
-- last_login_time TIMESTAMPTZ Most recent session created_at
|
||||
-- last_visit_time TIMESTAMPTZ Max of last refresh or login
|
||||
-- last_agent_save_time TIMESTAMPTZ Last time user saved an agent graph
|
||||
-- agent_count BIGINT Number of distinct active graphs built (0 if none)
|
||||
-- first_agent_run_time TIMESTAMPTZ First ever graph execution
|
||||
-- last_agent_run_time TIMESTAMPTZ Most recent graph execution
|
||||
-- unique_agent_runs BIGINT Distinct agent graphs ever run (0 if none)
|
||||
-- agent_runs BIGINT Total graph execution count (0 if none)
|
||||
-- node_execution_count BIGINT Total node executions across all runs
|
||||
-- node_execution_failed BIGINT Node executions with FAILED status
|
||||
-- node_execution_completed BIGINT Node executions with COMPLETED status
|
||||
-- node_execution_terminated BIGINT Node executions with TERMINATED status
|
||||
-- node_execution_queued BIGINT Node executions with QUEUED status
|
||||
-- node_execution_running BIGINT Node executions with RUNNING status
|
||||
-- is_active_after_7d INT 1=returned after day 7, 0=did not, NULL=too early to tell
|
||||
-- node_execution_incomplete BIGINT Node executions with INCOMPLETE status
|
||||
-- node_execution_review BIGINT Node executions with REVIEW status
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Users who ran at least one agent and returned after 7 days
|
||||
-- SELECT COUNT(*) FROM analytics.users_activities
|
||||
-- WHERE agent_runs > 0 AND is_active_after_7d = 1;
|
||||
--
|
||||
-- -- Top 10 most active users by agent runs
|
||||
-- SELECT user_id, agent_runs, node_execution_count
|
||||
-- FROM analytics.users_activities
|
||||
-- ORDER BY agent_runs DESC LIMIT 10;
|
||||
--
|
||||
-- -- 7-day activation rate
|
||||
-- SELECT
|
||||
-- SUM(CASE WHEN is_active_after_7d = 1 THEN 1 ELSE 0 END)::float
|
||||
-- / NULLIF(COUNT(CASE WHEN is_active_after_7d IS NOT NULL THEN 1 END), 0)
|
||||
-- AS activation_rate
|
||||
-- FROM analytics.users_activities;
|
||||
-- =============================================================
|
||||
|
||||
WITH user_logins AS (
|
||||
SELECT
|
||||
user_id::text AS user_id,
|
||||
MIN(created_at) AS first_login_time,
|
||||
MAX(created_at) AS last_login_time,
|
||||
GREATEST(
|
||||
MAX(refreshed_at)::timestamptz,
|
||||
MAX(created_at)::timestamptz
|
||||
) AS last_visit_time
|
||||
FROM auth.sessions
|
||||
GROUP BY user_id
|
||||
),
|
||||
user_agents AS (
|
||||
-- Aggregate AgentGraph directly by userId (no fan-out from user_logins)
|
||||
SELECT
|
||||
"userId"::text AS user_id,
|
||||
MAX("updatedAt") AS last_agent_save_time,
|
||||
COUNT(DISTINCT "id") AS agent_count
|
||||
FROM platform."AgentGraph"
|
||||
WHERE "isActive"
|
||||
GROUP BY "userId"
|
||||
),
|
||||
user_graph_runs AS (
|
||||
-- Aggregate AgentGraphExecution directly by userId
|
||||
SELECT
|
||||
"userId"::text AS user_id,
|
||||
MIN("createdAt") AS first_agent_run_time,
|
||||
MAX("createdAt") AS last_agent_run_time,
|
||||
COUNT(DISTINCT "agentGraphId") AS unique_agent_runs,
|
||||
COUNT("id") AS agent_runs
|
||||
FROM platform."AgentGraphExecution"
|
||||
GROUP BY "userId"
|
||||
),
|
||||
user_node_runs AS (
|
||||
-- Aggregate AgentNodeExecution directly; resolve userId via a
|
||||
-- single join to AgentGraphExecution instead of fanning out from
|
||||
-- user_logins through both large tables.
|
||||
SELECT
|
||||
g."userId"::text AS user_id,
|
||||
COUNT(*) AS node_execution_count,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'FAILED') AS node_execution_failed,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'COMPLETED') AS node_execution_completed,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'TERMINATED') AS node_execution_terminated,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'QUEUED') AS node_execution_queued,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'RUNNING') AS node_execution_running,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'INCOMPLETE') AS node_execution_incomplete,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'REVIEW') AS node_execution_review
|
||||
FROM platform."AgentNodeExecution" n
|
||||
JOIN platform."AgentGraphExecution" g
|
||||
ON g."id" = n."agentGraphExecutionId"
|
||||
GROUP BY g."userId"
|
||||
)
|
||||
SELECT
|
||||
ul.user_id,
|
||||
ul.first_login_time,
|
||||
ul.last_login_time,
|
||||
ul.last_visit_time,
|
||||
ua.last_agent_save_time,
|
||||
COALESCE(ua.agent_count, 0) AS agent_count,
|
||||
gr.first_agent_run_time,
|
||||
gr.last_agent_run_time,
|
||||
COALESCE(gr.unique_agent_runs, 0) AS unique_agent_runs,
|
||||
COALESCE(gr.agent_runs, 0) AS agent_runs,
|
||||
COALESCE(nr.node_execution_count, 0) AS node_execution_count,
|
||||
COALESCE(nr.node_execution_failed, 0) AS node_execution_failed,
|
||||
COALESCE(nr.node_execution_completed, 0) AS node_execution_completed,
|
||||
COALESCE(nr.node_execution_terminated, 0) AS node_execution_terminated,
|
||||
COALESCE(nr.node_execution_queued, 0) AS node_execution_queued,
|
||||
COALESCE(nr.node_execution_running, 0) AS node_execution_running,
|
||||
CASE
|
||||
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
|
||||
AND ul.last_visit_time >= ul.first_login_time + INTERVAL '7 days' THEN 1
|
||||
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
|
||||
AND ul.last_visit_time < ul.first_login_time + INTERVAL '7 days' THEN 0
|
||||
ELSE NULL
|
||||
END AS is_active_after_7d,
|
||||
COALESCE(nr.node_execution_incomplete, 0) AS node_execution_incomplete,
|
||||
COALESCE(nr.node_execution_review, 0) AS node_execution_review
|
||||
FROM user_logins ul
|
||||
LEFT JOIN user_agents ua ON ul.user_id = ua.user_id
|
||||
LEFT JOIN user_graph_runs gr ON ul.user_id = gr.user_id
|
||||
LEFT JOIN user_node_runs nr ON ul.user_id = nr.user_id
|
||||
@@ -28,6 +28,7 @@ from backend.copilot.model import (
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
from backend.copilot.tools.models import (
|
||||
AgentDetailsResponse,
|
||||
AgentOutputResponse,
|
||||
@@ -265,12 +266,12 @@ async def delete_session(
|
||||
)
|
||||
|
||||
# Best-effort cleanup of the E2B sandbox (if any).
|
||||
config = ChatConfig()
|
||||
if config.use_e2b_sandbox and config.e2b_api_key:
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
|
||||
# sandbox_id is in Redis; kill_sandbox() fetches it from there.
|
||||
e2b_cfg = ChatConfig()
|
||||
if e2b_cfg.e2b_active:
|
||||
assert e2b_cfg.e2b_api_key # guaranteed by e2b_active check
|
||||
try:
|
||||
await kill_sandbox(session_id, config.e2b_api_key)
|
||||
await kill_sandbox(session_id, e2b_cfg.e2b_api_key)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[E2B] Failed to kill sandbox for session %s", session_id[:12]
|
||||
|
||||
@@ -638,7 +638,7 @@ async def test_process_review_action_auto_approve_creates_auto_approval_records(
|
||||
|
||||
# Mock get_node_executions to return node_id mapping
|
||||
mock_get_node_executions = mocker.patch(
|
||||
"backend.data.execution.get_node_executions"
|
||||
"backend.api.features.executions.review.routes.get_node_executions"
|
||||
)
|
||||
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
|
||||
mock_node_exec.node_exec_id = "test_node_123"
|
||||
@@ -936,7 +936,7 @@ async def test_process_review_action_auto_approve_only_applies_to_approved_revie
|
||||
|
||||
# Mock get_node_executions to return node_id mapping
|
||||
mock_get_node_executions = mocker.patch(
|
||||
"backend.data.execution.get_node_executions"
|
||||
"backend.api.features.executions.review.routes.get_node_executions"
|
||||
)
|
||||
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
|
||||
mock_node_exec.node_exec_id = "node_exec_approved"
|
||||
@@ -1148,7 +1148,7 @@ async def test_process_review_action_per_review_auto_approve_granularity(
|
||||
|
||||
# Mock get_node_executions to return batch node data
|
||||
mock_get_node_executions = mocker.patch(
|
||||
"backend.data.execution.get_node_executions"
|
||||
"backend.api.features.executions.review.routes.get_node_executions"
|
||||
)
|
||||
# Create mock node executions for each review
|
||||
mock_node_execs = []
|
||||
|
||||
@@ -6,10 +6,15 @@ import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, HTTPException, Query, Security, status
|
||||
from prisma.enums import ReviewStatus
|
||||
|
||||
from backend.copilot.constants import (
|
||||
is_copilot_synthetic_id,
|
||||
parse_node_id_from_exec_id,
|
||||
)
|
||||
from backend.data.execution import (
|
||||
ExecutionContext,
|
||||
ExecutionStatus,
|
||||
get_graph_execution_meta,
|
||||
get_node_executions,
|
||||
)
|
||||
from backend.data.graph import get_graph_settings
|
||||
from backend.data.human_review import (
|
||||
@@ -36,6 +41,38 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_node_ids(
|
||||
node_exec_ids: list[str],
|
||||
graph_exec_id: str,
|
||||
is_copilot: bool,
|
||||
) -> dict[str, str]:
|
||||
"""Resolve node_exec_id -> node_id for auto-approval records.
|
||||
|
||||
CoPilot synthetic IDs encode node_id in the format "{node_id}:{random}".
|
||||
Graph executions look up node_id from NodeExecution records.
|
||||
"""
|
||||
if not node_exec_ids:
|
||||
return {}
|
||||
|
||||
if is_copilot:
|
||||
return {neid: parse_node_id_from_exec_id(neid) for neid in node_exec_ids}
|
||||
|
||||
node_execs = await get_node_executions(
|
||||
graph_exec_id=graph_exec_id, include_exec_data=False
|
||||
)
|
||||
node_exec_map = {ne.node_exec_id: ne.node_id for ne in node_execs}
|
||||
|
||||
result = {}
|
||||
for neid in node_exec_ids:
|
||||
if neid in node_exec_map:
|
||||
result[neid] = node_exec_map[neid]
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to resolve node_id for {neid}: Node execution not found."
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@router.get(
|
||||
"/pending",
|
||||
summary="Get Pending Reviews",
|
||||
@@ -110,14 +147,16 @@ async def list_pending_reviews_for_execution(
|
||||
"""
|
||||
|
||||
# Verify user owns the graph execution before returning reviews
|
||||
graph_exec = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not graph_exec:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
# (CoPilot synthetic IDs don't have graph execution records)
|
||||
if not is_copilot_synthetic_id(graph_exec_id):
|
||||
graph_exec = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not graph_exec:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
|
||||
return await get_pending_reviews_for_execution(graph_exec_id, user_id)
|
||||
|
||||
@@ -160,30 +199,26 @@ async def process_review_action(
|
||||
)
|
||||
|
||||
graph_exec_id = next(iter(graph_exec_ids))
|
||||
is_copilot = is_copilot_synthetic_id(graph_exec_id)
|
||||
|
||||
# Validate execution status before processing reviews
|
||||
graph_exec_meta = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
|
||||
if not graph_exec_meta:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
|
||||
# Only allow processing reviews if execution is paused for review
|
||||
# or incomplete (partial execution with some reviews already processed)
|
||||
if graph_exec_meta.status not in (
|
||||
ExecutionStatus.REVIEW,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
|
||||
f"Reviews can only be processed when execution is paused (REVIEW status). "
|
||||
f"Current status: {graph_exec_meta.status}",
|
||||
# Validate execution status for graph executions (skip for CoPilot synthetic IDs)
|
||||
if not is_copilot:
|
||||
graph_exec_meta = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not graph_exec_meta:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
if graph_exec_meta.status not in (
|
||||
ExecutionStatus.REVIEW,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}",
|
||||
)
|
||||
|
||||
# Build review decisions map and track which reviews requested auto-approval
|
||||
# Auto-approved reviews use original data (no modifications allowed)
|
||||
@@ -236,7 +271,7 @@ async def process_review_action(
|
||||
)
|
||||
return (node_id, False)
|
||||
|
||||
# Collect node_exec_ids that need auto-approval
|
||||
# Collect node_exec_ids that need auto-approval and resolve their node_ids
|
||||
node_exec_ids_needing_auto_approval = [
|
||||
node_exec_id
|
||||
for node_exec_id, review_result in updated_reviews.items()
|
||||
@@ -244,29 +279,16 @@ async def process_review_action(
|
||||
and auto_approve_requests.get(node_exec_id, False)
|
||||
]
|
||||
|
||||
# Batch-fetch node executions to get node_ids
|
||||
node_id_map = await _resolve_node_ids(
|
||||
node_exec_ids_needing_auto_approval, graph_exec_id, is_copilot
|
||||
)
|
||||
|
||||
# Deduplicate by node_id — one auto-approval per node
|
||||
nodes_needing_auto_approval: dict[str, Any] = {}
|
||||
if node_exec_ids_needing_auto_approval:
|
||||
from backend.data.execution import get_node_executions
|
||||
|
||||
node_execs = await get_node_executions(
|
||||
graph_exec_id=graph_exec_id, include_exec_data=False
|
||||
)
|
||||
node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
|
||||
|
||||
for node_exec_id in node_exec_ids_needing_auto_approval:
|
||||
node_exec = node_exec_map.get(node_exec_id)
|
||||
if node_exec:
|
||||
review_result = updated_reviews[node_exec_id]
|
||||
# Use the first approved review for this node (deduplicate by node_id)
|
||||
if node_exec.node_id not in nodes_needing_auto_approval:
|
||||
nodes_needing_auto_approval[node_exec.node_id] = review_result
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to create auto-approval record for {node_exec_id}: "
|
||||
f"Node execution not found. This may indicate a race condition "
|
||||
f"or data inconsistency."
|
||||
)
|
||||
for node_exec_id in node_exec_ids_needing_auto_approval:
|
||||
node_id = node_id_map.get(node_exec_id)
|
||||
if node_id and node_id not in nodes_needing_auto_approval:
|
||||
nodes_needing_auto_approval[node_id] = updated_reviews[node_exec_id]
|
||||
|
||||
# Execute all auto-approval creations in parallel (deduplicated by node_id)
|
||||
auto_approval_results = await asyncio.gather(
|
||||
@@ -281,13 +303,11 @@ async def process_review_action(
|
||||
auto_approval_failed_count = 0
|
||||
for result in auto_approval_results:
|
||||
if isinstance(result, Exception):
|
||||
# Unexpected exception during auto-approval creation
|
||||
auto_approval_failed_count += 1
|
||||
logger.error(
|
||||
f"Unexpected exception during auto-approval creation: {result}"
|
||||
)
|
||||
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
|
||||
# Auto-approval creation failed (returned False)
|
||||
auto_approval_failed_count += 1
|
||||
|
||||
# Count results
|
||||
@@ -302,22 +322,20 @@ async def process_review_action(
|
||||
if review.status == ReviewStatus.REJECTED
|
||||
)
|
||||
|
||||
# Resume execution only if ALL pending reviews for this execution have been processed
|
||||
if updated_reviews:
|
||||
# Resume graph execution only for real graph executions (not CoPilot)
|
||||
# CoPilot sessions are resumed by the LLM retrying run_block with review_id
|
||||
if not is_copilot and updated_reviews:
|
||||
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
|
||||
|
||||
if not still_has_pending:
|
||||
# Get the graph_id from any processed review
|
||||
first_review = next(iter(updated_reviews.values()))
|
||||
|
||||
try:
|
||||
# Fetch user and settings to build complete execution context
|
||||
user = await get_user_by_id(user_id)
|
||||
settings = await get_graph_settings(
|
||||
user_id=user_id, graph_id=first_review.graph_id
|
||||
)
|
||||
|
||||
# Preserve user's timezone preference when resuming execution
|
||||
user_timezone = (
|
||||
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import HTTPClientError, Requests, validate_url
|
||||
from backend.util.request import HTTPClientError, Requests, validate_url_host
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -80,7 +80,7 @@ async def discover_tools(
|
||||
"""
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
await validate_url_host(request.server_url)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
@@ -167,7 +167,7 @@ async def mcp_oauth_login(
|
||||
"""
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
await validate_url_host(request.server_url)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
@@ -187,7 +187,7 @@ async def mcp_oauth_login(
|
||||
|
||||
# Validate the auth server URL from metadata to prevent SSRF.
|
||||
try:
|
||||
await validate_url(auth_server_url, trusted_origins=[])
|
||||
await validate_url_host(auth_server_url)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=400,
|
||||
@@ -234,7 +234,7 @@ async def mcp_oauth_login(
|
||||
if registration_endpoint:
|
||||
# Validate the registration endpoint to prevent SSRF via metadata.
|
||||
try:
|
||||
await validate_url(registration_endpoint, trusted_origins=[])
|
||||
await validate_url_host(registration_endpoint)
|
||||
except ValueError:
|
||||
pass # Skip registration, fall back to default client_id
|
||||
else:
|
||||
@@ -429,7 +429,7 @@ async def mcp_store_token(
|
||||
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
await validate_url_host(request.server_url)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
|
||||
@@ -32,9 +32,9 @@ async def client():
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _bypass_ssrf_validation():
|
||||
"""Bypass validate_url in all route tests (test URLs don't resolve)."""
|
||||
"""Bypass validate_url_host in all route tests (test URLs don't resolve)."""
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
"backend.api.features.mcp.routes.validate_url_host",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
yield
|
||||
@@ -521,12 +521,12 @@ class TestStoreToken:
|
||||
|
||||
|
||||
class TestSSRFValidation:
|
||||
"""Verify that validate_url is enforced on all endpoints."""
|
||||
"""Verify that validate_url_host is enforced on all endpoints."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
"backend.api.features.mcp.routes.validate_url_host",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked loopback"),
|
||||
):
|
||||
@@ -541,7 +541,7 @@ class TestSSRFValidation:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_login_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
"backend.api.features.mcp.routes.validate_url_host",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked private IP"),
|
||||
):
|
||||
@@ -556,7 +556,7 @@ class TestSSRFValidation:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_token_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
"backend.api.features.mcp.routes.validate_url_host",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked loopback"),
|
||||
):
|
||||
|
||||
@@ -418,6 +418,8 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
|
||||
|
||||
|
||||
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
_optimized_description: ClassVar[str | None] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str = "",
|
||||
@@ -470,6 +472,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.block_type = block_type
|
||||
self.webhook_config = webhook_config
|
||||
self.is_sensitive_action = is_sensitive_action
|
||||
# Read from ClassVar set by initialize_blocks()
|
||||
self.optimized_description: str | None = type(self)._optimized_description
|
||||
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
|
||||
|
||||
if self.webhook_config:
|
||||
@@ -620,6 +624,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: "ExecutionContext",
|
||||
is_graph_execution: bool = True,
|
||||
**kwargs,
|
||||
) -> tuple[bool, BlockInput]:
|
||||
"""
|
||||
@@ -648,6 +653,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
graph_version=graph_version,
|
||||
block_name=self.name,
|
||||
editable=True,
|
||||
is_graph_execution=is_graph_execution,
|
||||
)
|
||||
|
||||
if decision is None:
|
||||
|
||||
@@ -126,7 +126,7 @@ class PrintToConsoleBlock(Block):
|
||||
output_schema=PrintToConsoleBlock.Output,
|
||||
test_input={"text": "Hello, World!"},
|
||||
is_sensitive_action=True,
|
||||
disabled=True, # Disabled per Nick Tindle's request (OPEN-3000)
|
||||
disabled=True,
|
||||
test_output=[
|
||||
("output", "Hello, World!"),
|
||||
("status", "printed"),
|
||||
|
||||
@@ -142,7 +142,7 @@ class BaseE2BExecutorMixin:
|
||||
start_timestamp = ts_result.stdout.strip() if ts_result.stdout else None
|
||||
|
||||
# Execute the code
|
||||
execution = await sandbox.run_code(
|
||||
execution = await sandbox.run_code( # type: ignore[attr-defined]
|
||||
code,
|
||||
language=language.value,
|
||||
on_error=lambda e: sandbox.kill(), # Kill the sandbox on error
|
||||
|
||||
@@ -67,6 +67,7 @@ class HITLReviewHelper:
|
||||
graph_version: int,
|
||||
block_name: str = "Block",
|
||||
editable: bool = False,
|
||||
is_graph_execution: bool = True,
|
||||
) -> Optional[ReviewResult]:
|
||||
"""
|
||||
Handle a review request for a block that requires human review.
|
||||
@@ -143,10 +144,11 @@ class HITLReviewHelper:
|
||||
logger.info(
|
||||
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
|
||||
)
|
||||
await HITLReviewHelper.update_node_execution_status(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.REVIEW,
|
||||
)
|
||||
if is_graph_execution:
|
||||
await HITLReviewHelper.update_node_execution_status(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.REVIEW,
|
||||
)
|
||||
return None # Signal that execution should pause
|
||||
|
||||
# Mark review as processed if not already done
|
||||
@@ -168,6 +170,7 @@ class HITLReviewHelper:
|
||||
graph_version: int,
|
||||
block_name: str = "Block",
|
||||
editable: bool = False,
|
||||
is_graph_execution: bool = True,
|
||||
) -> Optional[ReviewDecision]:
|
||||
"""
|
||||
Handle a review request and return the decision in a single call.
|
||||
@@ -197,6 +200,7 @@ class HITLReviewHelper:
|
||||
graph_version=graph_version,
|
||||
block_name=block_name,
|
||||
editable=editable,
|
||||
is_graph_execution=is_graph_execution,
|
||||
)
|
||||
|
||||
if review_result is None:
|
||||
|
||||
@@ -17,7 +17,7 @@ from backend.blocks.jina._auth import (
|
||||
from backend.blocks.search import GetRequest
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.request import HTTPClientError, HTTPServerError, validate_url
|
||||
from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host
|
||||
|
||||
|
||||
class SearchTheWebBlock(Block, GetRequest):
|
||||
@@ -112,7 +112,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
|
||||
) -> BlockOutput:
|
||||
if input_data.raw_content:
|
||||
try:
|
||||
parsed_url, _, _ = await validate_url(input_data.url, [])
|
||||
parsed_url, _, _ = await validate_url_host(input_data.url)
|
||||
url = parsed_url.geturl()
|
||||
except ValueError as e:
|
||||
yield "error", f"Invalid URL: {e}"
|
||||
|
||||
@@ -31,10 +31,14 @@ from backend.data.model import (
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.prompt import compress_context, estimate_token_count
|
||||
from backend.util.request import validate_url_host
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.text import TextFormatter
|
||||
|
||||
settings = Settings()
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
||||
fmt = TextFormatter(autoescape=False)
|
||||
|
||||
@@ -804,6 +808,11 @@ async def llm_call(
|
||||
if tools:
|
||||
raise ValueError("Ollama does not support tools.")
|
||||
|
||||
# Validate user-provided Ollama host to prevent SSRF etc.
|
||||
await validate_url_host(
|
||||
ollama_host, trusted_hostnames=[settings.config.ollama_host]
|
||||
)
|
||||
|
||||
client = ollama.AsyncClient(host=ollama_host)
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
@@ -825,7 +834,7 @@ async def llm_call(
|
||||
elif provider == "open_router":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
client = openai.AsyncOpenAI(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
base_url=OPENROUTER_BASE_URL,
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
from backend.util.logging import TruncatedLogger
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[Perplexity-Block]")
|
||||
@@ -136,7 +137,7 @@ class PerplexityBlock(Block):
|
||||
) -> dict[str, Any]:
|
||||
"""Call Perplexity via OpenRouter and extract annotations."""
|
||||
client = openai.AsyncOpenAI(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
base_url=OPENROUTER_BASE_URL,
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
"""Configuration management for chat system."""
|
||||
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
|
||||
|
||||
class ChatConfig(BaseSettings):
|
||||
"""Configuration for the chat system."""
|
||||
@@ -19,7 +22,7 @@ class ChatConfig(BaseSettings):
|
||||
)
|
||||
api_key: str | None = Field(default=None, description="OpenAI API key")
|
||||
base_url: str | None = Field(
|
||||
default="https://openrouter.ai/api/v1",
|
||||
default=OPENROUTER_BASE_URL,
|
||||
description="Base URL for API (e.g., for OpenRouter)",
|
||||
)
|
||||
|
||||
@@ -112,9 +115,37 @@ class ChatConfig(BaseSettings):
|
||||
description="E2B sandbox template to use for copilot sessions.",
|
||||
)
|
||||
e2b_sandbox_timeout: int = Field(
|
||||
default=43200, # 12 hours — same as session_ttl
|
||||
description="E2B sandbox keepalive timeout in seconds.",
|
||||
default=10800, # 3 hours — wall-clock timeout, not idle; explicit pause is primary
|
||||
description="E2B sandbox running-time timeout (seconds). "
|
||||
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
|
||||
"mechanism; this is the safety net.",
|
||||
)
|
||||
e2b_sandbox_on_timeout: Literal["kill", "pause"] = Field(
|
||||
default="pause",
|
||||
description="E2B lifecycle action on timeout: 'pause' (default, free) or 'kill'.",
|
||||
)
|
||||
|
||||
@property
|
||||
def e2b_active(self) -> bool:
|
||||
"""True when E2B is enabled and the API key is present.
|
||||
|
||||
Single source of truth for "should we use E2B right now?".
|
||||
Prefer this over combining ``use_e2b_sandbox`` and ``e2b_api_key``
|
||||
separately at call sites.
|
||||
"""
|
||||
return self.use_e2b_sandbox and bool(self.e2b_api_key)
|
||||
|
||||
@property
|
||||
def active_e2b_api_key(self) -> str | None:
|
||||
"""Return the E2B API key when E2B is enabled and configured, else None.
|
||||
|
||||
Combines the ``use_e2b_sandbox`` flag check and key presence into one.
|
||||
Use in callers::
|
||||
|
||||
if api_key := config.active_e2b_api_key:
|
||||
# E2B is active; api_key is narrowed to str
|
||||
"""
|
||||
return self.e2b_api_key if self.e2b_active else None
|
||||
|
||||
@field_validator("use_e2b_sandbox", mode="before")
|
||||
@classmethod
|
||||
@@ -164,7 +195,7 @@ class ChatConfig(BaseSettings):
|
||||
if not v:
|
||||
v = os.getenv("OPENAI_BASE_URL")
|
||||
if not v:
|
||||
v = "https://openrouter.ai/api/v1"
|
||||
v = OPENROUTER_BASE_URL
|
||||
return v
|
||||
|
||||
@field_validator("use_claude_agent_sdk", mode="before")
|
||||
|
||||
38
autogpt_platform/backend/backend/copilot/config_test.py
Normal file
38
autogpt_platform/backend/backend/copilot/config_test.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Unit tests for ChatConfig."""
|
||||
|
||||
import pytest
|
||||
|
||||
from .config import ChatConfig
|
||||
|
||||
# Env vars that the ChatConfig validators read — must be cleared so they don't
|
||||
# override the explicit constructor values we pass in each test.
|
||||
_E2B_ENV_VARS = (
|
||||
"CHAT_USE_E2B_SANDBOX",
|
||||
"CHAT_E2B_API_KEY",
|
||||
"E2B_API_KEY",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_e2b_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
for var in _E2B_ENV_VARS:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
|
||||
class TestE2BActive:
|
||||
"""Tests for the e2b_active property — single source of truth for E2B usage."""
|
||||
|
||||
def test_both_enabled_and_key_present_returns_true(self):
|
||||
"""e2b_active is True when use_e2b_sandbox=True and e2b_api_key is set."""
|
||||
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key="test-key")
|
||||
assert cfg.e2b_active is True
|
||||
|
||||
def test_enabled_but_missing_key_returns_false(self):
|
||||
"""e2b_active is False when use_e2b_sandbox=True but e2b_api_key is absent."""
|
||||
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key=None)
|
||||
assert cfg.e2b_active is False
|
||||
|
||||
def test_disabled_returns_false(self):
|
||||
"""e2b_active is False when use_e2b_sandbox=False regardless of key."""
|
||||
cfg = ChatConfig(use_e2b_sandbox=False, e2b_api_key="test-key")
|
||||
assert cfg.e2b_active is False
|
||||
@@ -6,6 +6,32 @@
|
||||
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
|
||||
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
|
||||
|
||||
# Prefix for all synthetic IDs generated by CoPilot block execution.
|
||||
# Used to distinguish CoPilot-generated records from real graph execution records
|
||||
# in PendingHumanReview and other tables.
|
||||
COPILOT_SYNTHETIC_ID_PREFIX = "copilot-"
|
||||
|
||||
# Sub-prefixes for session-scoped and node-scoped synthetic IDs.
|
||||
COPILOT_SESSION_PREFIX = f"{COPILOT_SYNTHETIC_ID_PREFIX}session-"
|
||||
COPILOT_NODE_PREFIX = f"{COPILOT_SYNTHETIC_ID_PREFIX}node-"
|
||||
|
||||
# Separator used in synthetic node_exec_id to encode node_id.
|
||||
# Format: "{node_id}:{random_hex}" — extract node_id via rsplit(":", 1)[0]
|
||||
COPILOT_NODE_EXEC_ID_SEPARATOR = ":"
|
||||
|
||||
# Compaction notice messages shown to users.
|
||||
COMPACTION_DONE_MSG = "Earlier messages were summarized to fit within context limits."
|
||||
COMPACTION_TOOL_NAME = "context_compaction"
|
||||
|
||||
|
||||
def is_copilot_synthetic_id(id_value: str) -> bool:
|
||||
"""Check if an ID is a CoPilot synthetic ID (not from a real graph execution)."""
|
||||
return id_value.startswith(COPILOT_SYNTHETIC_ID_PREFIX)
|
||||
|
||||
|
||||
def parse_node_id_from_exec_id(node_exec_id: str) -> str:
|
||||
"""Extract node_id from a synthetic node_exec_id.
|
||||
|
||||
Format: "{node_id}:{random_hex}" → returns "{node_id}".
|
||||
"""
|
||||
return node_exec_id.rsplit(COPILOT_NODE_EXEC_ID_SEPARATOR, 1)[0]
|
||||
|
||||
115
autogpt_platform/backend/backend/copilot/context.py
Normal file
115
autogpt_platform/backend/backend/copilot/context.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Shared execution context for copilot SDK tool handlers.
|
||||
|
||||
All context variables and their accessors live here so that
|
||||
``tool_adapter``, ``file_ref``, and ``e2b_file_tools`` can import them
|
||||
without creating circular dependencies.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
# Allowed base directory for the Read tool.
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
# Encoded project-directory name for the current session (e.g.
|
||||
# "-private-tmp-copilot-<uuid>"). Set by set_execution_context() so path
|
||||
# validation can scope tool-results reads to the current session.
|
||||
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
|
||||
|
||||
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
|
||||
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||
"current_session", default=None
|
||||
)
|
||||
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
|
||||
"_current_sandbox", default=None
|
||||
)
|
||||
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
|
||||
|
||||
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does."""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
sandbox: "AsyncSandbox | None" = None,
|
||||
sdk_cwd: str | None = None,
|
||||
) -> None:
|
||||
"""Set per-turn context variables used by file-resolution tool handlers."""
|
||||
_current_user_id.set(user_id)
|
||||
_current_session.set(session)
|
||||
_current_sandbox.set(sandbox)
|
||||
_current_sdk_cwd.set(sdk_cwd or "")
|
||||
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
|
||||
|
||||
|
||||
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||
"""Return the current (user_id, session) pair for the active request."""
|
||||
return _current_user_id.get(), _current_session.get()
|
||||
|
||||
|
||||
def get_current_sandbox() -> "AsyncSandbox | None":
|
||||
"""Return the E2B sandbox for the current session, or None if not active."""
|
||||
return _current_sandbox.get()
|
||||
|
||||
|
||||
def get_sdk_cwd() -> str:
|
||||
"""Return the SDK working directory for the current session (empty string if unset)."""
|
||||
return _current_sdk_cwd.get()
|
||||
|
||||
|
||||
E2B_WORKDIR = "/home/user"
|
||||
|
||||
|
||||
def resolve_sandbox_path(path: str) -> str:
|
||||
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
|
||||
|
||||
Raises :class:`ValueError` if the resolved path escapes the sandbox.
|
||||
"""
|
||||
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
|
||||
normalized = os.path.normpath(candidate)
|
||||
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
|
||||
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
|
||||
return normalized
|
||||
|
||||
|
||||
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
"""Return True if *path* is within an allowed host-filesystem location.
|
||||
|
||||
Allowed:
|
||||
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/tool-results/`` (SDK tool-results)
|
||||
"""
|
||||
if not path:
|
||||
return False
|
||||
|
||||
if path.startswith("~"):
|
||||
resolved = os.path.realpath(os.path.expanduser(path))
|
||||
elif not os.path.isabs(path) and sdk_cwd:
|
||||
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
|
||||
else:
|
||||
resolved = os.path.realpath(path)
|
||||
|
||||
if sdk_cwd:
|
||||
norm_cwd = os.path.realpath(sdk_cwd)
|
||||
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
|
||||
return True
|
||||
|
||||
encoded = _current_project_dir.get("")
|
||||
if encoded:
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
if resolved == tool_results_dir or resolved.startswith(
|
||||
tool_results_dir + os.sep
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
163
autogpt_platform/backend/backend/copilot/context_test.py
Normal file
163
autogpt_platform/backend/backend/copilot/context_test.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Tests for context.py — execution context variables and path helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import (
|
||||
_SDK_PROJECTS_DIR,
|
||||
_current_project_dir,
|
||||
get_current_sandbox,
|
||||
get_execution_context,
|
||||
get_sdk_cwd,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
set_execution_context,
|
||||
)
|
||||
|
||||
|
||||
def _make_session() -> MagicMock:
|
||||
s = MagicMock()
|
||||
s.session_id = "test-session"
|
||||
return s
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Context variable getters
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_execution_context_defaults():
|
||||
"""get_execution_context returns (None, session) when user_id is not set."""
|
||||
set_execution_context(None, _make_session())
|
||||
user_id, session = get_execution_context()
|
||||
assert user_id is None
|
||||
assert session is not None
|
||||
|
||||
|
||||
def test_set_and_get_execution_context():
|
||||
"""set_execution_context stores user_id and session."""
|
||||
mock_session = _make_session()
|
||||
set_execution_context("user-abc", mock_session)
|
||||
user_id, session = get_execution_context()
|
||||
assert user_id == "user-abc"
|
||||
assert session is mock_session
|
||||
|
||||
|
||||
def test_get_current_sandbox_none_by_default():
|
||||
"""get_current_sandbox returns None when no sandbox is set."""
|
||||
set_execution_context("u1", _make_session(), sandbox=None)
|
||||
assert get_current_sandbox() is None
|
||||
|
||||
|
||||
def test_get_current_sandbox_returns_set_value():
|
||||
"""get_current_sandbox returns the sandbox set via set_execution_context."""
|
||||
mock_sandbox = MagicMock()
|
||||
set_execution_context("u1", _make_session(), sandbox=mock_sandbox)
|
||||
assert get_current_sandbox() is mock_sandbox
|
||||
|
||||
|
||||
def test_get_sdk_cwd_empty_when_not_set():
|
||||
"""get_sdk_cwd returns empty string when sdk_cwd is not set."""
|
||||
set_execution_context("u1", _make_session(), sdk_cwd=None)
|
||||
assert get_sdk_cwd() == ""
|
||||
|
||||
|
||||
def test_get_sdk_cwd_returns_set_value():
|
||||
"""get_sdk_cwd returns the value set via set_execution_context."""
|
||||
set_execution_context("u1", _make_session(), sdk_cwd="/tmp/copilot-test")
|
||||
assert get_sdk_cwd() == "/tmp/copilot-test"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_allowed_local_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_is_allowed_local_path_empty():
|
||||
assert not is_allowed_local_path("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_inside_sdk_cwd():
|
||||
with tempfile.TemporaryDirectory() as cwd:
|
||||
path = os.path.join(cwd, "file.txt")
|
||||
assert is_allowed_local_path(path, cwd)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_sdk_cwd_itself():
|
||||
with tempfile.TemporaryDirectory() as cwd:
|
||||
assert is_allowed_local_path(cwd, cwd)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_outside_sdk_cwd():
|
||||
with tempfile.TemporaryDirectory() as cwd:
|
||||
assert not is_allowed_local_path("/etc/passwd", cwd)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_no_sdk_cwd_no_project_dir():
|
||||
"""Without sdk_cwd or project_dir, all paths are rejected."""
|
||||
_current_project_dir.set("")
|
||||
assert not is_allowed_local_path("/tmp/some-file.txt", sdk_cwd=None)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_tool_results_dir():
|
||||
"""Files under the tool-results directory for the current project are allowed."""
|
||||
encoded = "test-encoded-dir"
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
path = os.path.join(tool_results_dir, "output.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert is_allowed_local_path(path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
|
||||
"""A path adjacent to tool-results/ but not inside it is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
sibling_path = os.path.join(_SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert not is_allowed_local_path(sibling_path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_sandbox_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_absolute_valid():
|
||||
assert (
|
||||
resolve_sandbox_path("/home/user/project/main.py")
|
||||
== "/home/user/project/main.py"
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_relative():
|
||||
assert resolve_sandbox_path("project/main.py") == "/home/user/project/main.py"
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_workdir_itself():
|
||||
assert resolve_sandbox_path("/home/user") == "/home/user"
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_normalizes_dots():
|
||||
assert resolve_sandbox_path("/home/user/a/../b") == "/home/user/b"
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_escape_raises():
|
||||
with pytest.raises(ValueError, match="/home/user"):
|
||||
resolve_sandbox_path("/home/user/../../etc/passwd")
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_absolute_outside_raises():
|
||||
with pytest.raises(ValueError, match="/home/user"):
|
||||
resolve_sandbox_path("/etc/passwd")
|
||||
138
autogpt_platform/backend/backend/copilot/optimize_blocks.py
Normal file
138
autogpt_platform/backend/backend/copilot/optimize_blocks.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Scheduler job to generate LLM-optimized block descriptions.
|
||||
|
||||
Runs periodically to rewrite block descriptions into concise, actionable
|
||||
summaries that help the copilot LLM pick the right blocks during agent
|
||||
generation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from backend.blocks import get_blocks
|
||||
from backend.util.clients import get_database_manager_client, get_openai_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"You are a technical writer for an automation platform. "
|
||||
"Rewrite the following block description to be concise (under 50 words), "
|
||||
"informative, and actionable. Focus on what the block does and when to "
|
||||
"use it. Output ONLY the rewritten description, nothing else. "
|
||||
"Do not use markdown formatting."
|
||||
)
|
||||
|
||||
# Rate-limit delay between sequential LLM calls (seconds)
|
||||
_RATE_LIMIT_DELAY = 0.5
|
||||
# Maximum tokens for optimized description generation
|
||||
_MAX_DESCRIPTION_TOKENS = 150
|
||||
# Model for generating optimized descriptions (fast, cheap)
|
||||
_MODEL = "gpt-4o-mini"
|
||||
|
||||
|
||||
async def _optimize_descriptions(blocks: list[dict[str, str]]) -> dict[str, str]:
|
||||
"""Call the shared OpenAI client to rewrite each block description."""
|
||||
client = get_openai_client()
|
||||
if client is None:
|
||||
logger.error(
|
||||
"No OpenAI client configured, skipping block description optimization"
|
||||
)
|
||||
return {}
|
||||
|
||||
results: dict[str, str] = {}
|
||||
for block in blocks:
|
||||
block_id = block["id"]
|
||||
block_name = block["name"]
|
||||
description = block["description"]
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Block name: {block_name}\nDescription: {description}",
|
||||
},
|
||||
],
|
||||
max_tokens=_MAX_DESCRIPTION_TOKENS,
|
||||
)
|
||||
optimized = (response.choices[0].message.content or "").strip()
|
||||
if optimized:
|
||||
results[block_id] = optimized
|
||||
logger.debug("Optimized description for %s", block_name)
|
||||
else:
|
||||
logger.warning("Empty response for block %s", block_name)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to optimize description for %s", block_name, exc_info=True
|
||||
)
|
||||
|
||||
await asyncio.sleep(_RATE_LIMIT_DELAY)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def optimize_block_descriptions() -> dict[str, int]:
|
||||
"""Generate optimized descriptions for blocks that don't have one yet.
|
||||
|
||||
Uses the shared OpenAI client to rewrite block descriptions into concise
|
||||
summaries suitable for agent generation prompts.
|
||||
|
||||
Returns:
|
||||
Dict with counts: processed, success, failed, skipped.
|
||||
"""
|
||||
db_client = get_database_manager_client()
|
||||
|
||||
blocks = db_client.get_blocks_needing_optimization()
|
||||
if not blocks:
|
||||
logger.info("All blocks already have optimized descriptions")
|
||||
return {"processed": 0, "success": 0, "failed": 0, "skipped": 0}
|
||||
|
||||
logger.info("Found %d blocks needing optimized descriptions", len(blocks))
|
||||
|
||||
non_empty = [b for b in blocks if b.get("description", "").strip()]
|
||||
skipped = len(blocks) - len(non_empty)
|
||||
|
||||
new_descriptions = asyncio.run(_optimize_descriptions(non_empty))
|
||||
|
||||
stats = {
|
||||
"processed": len(non_empty),
|
||||
"success": len(new_descriptions),
|
||||
"failed": len(non_empty) - len(new_descriptions),
|
||||
"skipped": skipped,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"Block description optimization complete: "
|
||||
"%d/%d succeeded, %d failed, %d skipped",
|
||||
stats["success"],
|
||||
stats["processed"],
|
||||
stats["failed"],
|
||||
stats["skipped"],
|
||||
)
|
||||
|
||||
if new_descriptions:
|
||||
for block_id, optimized in new_descriptions.items():
|
||||
db_client.update_block_optimized_description(block_id, optimized)
|
||||
|
||||
# Update in-memory descriptions first so the cache rebuilds with fresh data.
|
||||
try:
|
||||
block_classes = get_blocks()
|
||||
for block_id, optimized in new_descriptions.items():
|
||||
if block_id in block_classes:
|
||||
block_classes[block_id]._optimized_description = optimized
|
||||
logger.info(
|
||||
"Updated %d in-memory block descriptions", len(new_descriptions)
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Could not update in-memory block descriptions", exc_info=True
|
||||
)
|
||||
|
||||
from backend.copilot.tools.agent_generator.blocks import (
|
||||
reset_block_caches, # local to avoid circular import
|
||||
)
|
||||
|
||||
reset_block_caches()
|
||||
|
||||
return stats
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Unit tests for optimize_blocks._optimize_descriptions."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.copilot.optimize_blocks import _RATE_LIMIT_DELAY, _optimize_descriptions
|
||||
|
||||
|
||||
def _make_client_response(text: str) -> MagicMock:
|
||||
"""Build a minimal mock that looks like an OpenAI ChatCompletion response."""
|
||||
choice = MagicMock()
|
||||
choice.message.content = text
|
||||
response = MagicMock()
|
||||
response.choices = [choice]
|
||||
return response
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.get_event_loop().run_until_complete(coro)
|
||||
|
||||
|
||||
class TestOptimizeDescriptions:
|
||||
"""Tests for _optimize_descriptions async function."""
|
||||
|
||||
def test_returns_empty_when_no_client(self):
|
||||
with patch(
|
||||
"backend.copilot.optimize_blocks.get_openai_client", return_value=None
|
||||
):
|
||||
result = _run(
|
||||
_optimize_descriptions([{"id": "b1", "name": "B", "description": "d"}])
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
def test_success_single_block(self):
|
||||
client = MagicMock()
|
||||
client.chat.completions.create = AsyncMock(
|
||||
return_value=_make_client_response("Short desc.")
|
||||
)
|
||||
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
|
||||
),
|
||||
):
|
||||
result = _run(_optimize_descriptions(blocks))
|
||||
|
||||
assert result == {"b1": "Short desc."}
|
||||
client.chat.completions.create.assert_called_once()
|
||||
|
||||
def test_skips_block_on_exception(self):
|
||||
client = MagicMock()
|
||||
client.chat.completions.create = AsyncMock(side_effect=Exception("API error"))
|
||||
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
|
||||
),
|
||||
):
|
||||
result = _run(_optimize_descriptions(blocks))
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_sleeps_between_blocks(self):
|
||||
client = MagicMock()
|
||||
client.chat.completions.create = AsyncMock(
|
||||
return_value=_make_client_response("desc")
|
||||
)
|
||||
blocks = [
|
||||
{"id": "b1", "name": "B1", "description": "d1"},
|
||||
{"id": "b2", "name": "B2", "description": "d2"},
|
||||
]
|
||||
sleep_mock = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
|
||||
),
|
||||
patch("backend.copilot.optimize_blocks.asyncio.sleep", sleep_mock),
|
||||
):
|
||||
_run(_optimize_descriptions(blocks))
|
||||
|
||||
assert sleep_mock.call_count == 2
|
||||
sleep_mock.assert_called_with(_RATE_LIMIT_DELAY)
|
||||
@@ -26,6 +26,33 @@ your message as a Markdown link or image:
|
||||
The `download_url` field in the `write_workspace_file` response is already
|
||||
in the correct format — paste it directly after the `(` in the Markdown.
|
||||
|
||||
### Passing file content to tools — @@agptfile: references
|
||||
Instead of copying large file contents into a tool argument, pass a file
|
||||
reference and the platform will load the content for you.
|
||||
|
||||
Syntax: `@@agptfile:<uri>[<start>-<end>]`
|
||||
|
||||
- `<uri>` **must** start with `workspace://` or `/` (absolute path):
|
||||
- `workspace://<file_id>` — workspace file by ID
|
||||
- `workspace:///<path>` — workspace file by virtual path
|
||||
- `/absolute/local/path` — ephemeral or sdk_cwd file
|
||||
- E2B sandbox absolute path (e.g. `/home/user/script.py`)
|
||||
- `[<start>-<end>]` is an optional 1-indexed inclusive line range.
|
||||
- URIs that do not start with `workspace://` or `/` are **not** expanded.
|
||||
|
||||
Examples:
|
||||
```
|
||||
@@agptfile:workspace://abc123
|
||||
@@agptfile:workspace://abc123[10-50]
|
||||
@@agptfile:workspace:///reports/q1.md
|
||||
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
|
||||
@@agptfile:/home/user/script.py
|
||||
```
|
||||
|
||||
You can embed a reference inside any string argument, or use it as the entire
|
||||
value. Multiple references in one argument are all expanded.
|
||||
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
## Agent Generation Guide
|
||||
|
||||
You can create, edit, and customize agents directly. You ARE the brain —
|
||||
generate the agent JSON yourself using block schemas, then validate and save.
|
||||
|
||||
### Workflow for Creating/Editing Agents
|
||||
|
||||
1. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
|
||||
search for relevant blocks. This returns block IDs, names, descriptions,
|
||||
and full input/output schemas.
|
||||
2. **Find library agents**: Call `find_library_agent` to discover reusable
|
||||
agents that can be composed as sub-agents via `AgentExecutorBlock`.
|
||||
3. **Generate JSON**: Build the agent JSON using block schemas:
|
||||
- Use block IDs from step 1 as `block_id` in nodes
|
||||
- Wire outputs to inputs using links
|
||||
- Set design-time config in `input_default`
|
||||
- Use `AgentInputBlock` for values the user provides at runtime
|
||||
4. **Write to workspace**: Save the JSON to a workspace file so the user
|
||||
can review it: `write_workspace_file(filename="agent.json", content=...)`
|
||||
5. **Validate**: Call `validate_agent_graph` with the agent JSON to check
|
||||
for errors
|
||||
6. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
|
||||
or fix manually based on the error descriptions. Iterate until valid.
|
||||
7. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
|
||||
the final `agent_json`
|
||||
|
||||
### Agent JSON Structure
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "<UUID v4>", // auto-generated if omitted
|
||||
"version": 1,
|
||||
"is_active": true,
|
||||
"name": "Agent Name",
|
||||
"description": "What the agent does",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "<UUID v4>",
|
||||
"block_id": "<block UUID from find_block>",
|
||||
"input_default": {
|
||||
"field_name": "design-time value"
|
||||
},
|
||||
"metadata": {
|
||||
"position": {"x": 0, "y": 0},
|
||||
"customized_name": "Optional display name"
|
||||
}
|
||||
}
|
||||
],
|
||||
"links": [
|
||||
{
|
||||
"id": "<UUID v4>",
|
||||
"source_id": "<source node UUID>",
|
||||
"source_name": "output_field_name",
|
||||
"sink_id": "<sink node UUID>",
|
||||
"sink_name": "input_field_name",
|
||||
"is_static": false
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### REQUIRED: AgentInputBlock and AgentOutputBlock
|
||||
|
||||
Every agent MUST include at least one AgentInputBlock and one AgentOutputBlock.
|
||||
These define the agent's interface — what it accepts and what it produces.
|
||||
|
||||
**AgentInputBlock** (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`):
|
||||
- Defines a user-facing input field on the agent
|
||||
- Required `input_default` fields: `name` (str), `value` (default: null)
|
||||
- Optional: `title`, `description`, `placeholder_values` (for dropdowns)
|
||||
- Output: `result` — the user-provided value at runtime
|
||||
- Create one AgentInputBlock per distinct input the agent needs
|
||||
|
||||
**AgentOutputBlock** (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`):
|
||||
- Defines a user-facing output displayed after the agent runs
|
||||
- Required `input_default` fields: `name` (str)
|
||||
- The `value` input should be linked from another block's output
|
||||
- Optional: `title`, `description`, `format` (Jinja2 template)
|
||||
- Create one AgentOutputBlock per distinct result to show the user
|
||||
|
||||
Without these blocks, the agent has no interface and the user cannot provide
|
||||
inputs or see outputs. NEVER skip them.
|
||||
|
||||
### Key Rules
|
||||
|
||||
- **Name & description**: Include `name` and `description` in the agent JSON
|
||||
when creating a new agent, or when editing and the agent's purpose changed.
|
||||
Without these the agent gets a generic default name.
|
||||
- **Design-time vs runtime**: `input_default` = values known at build time.
|
||||
For user-provided values, create an `AgentInputBlock` node and link its
|
||||
output to the consuming block's input.
|
||||
- **Credentials**: Do NOT require credentials upfront. Users configure
|
||||
credentials later in the platform UI after the agent is saved.
|
||||
- **Node spacing**: Position nodes with at least 800 X-units between them.
|
||||
- **Nested properties**: Use `parentField_#_childField` notation in link
|
||||
sink_name/source_name to access nested object fields.
|
||||
- **is_static links**: Set `is_static: true` when the link carries a
|
||||
design-time constant (matches a field in inputSchema with a default).
|
||||
- **ConditionBlock**: Needs a `StoreValueBlock` wired to its `value2` input.
|
||||
- **Prompt templates**: Use `{{variable}}` (double curly braces) for
|
||||
literal braces in prompt strings — single `{` and `}` are for
|
||||
template variables.
|
||||
- **AgentExecutorBlock**: When composing sub-agents, set `graph_id` and
|
||||
`graph_version` in input_default, and wire inputs/outputs to match
|
||||
the sub-agent's schema.
|
||||
|
||||
### Using Sub-Agents (AgentExecutorBlock)
|
||||
|
||||
To compose agents using other agents as sub-agents:
|
||||
1. Call `find_library_agent` to find the sub-agent — the response includes
|
||||
`graph_id`, `graph_version`, `input_schema`, and `output_schema`
|
||||
2. Create an `AgentExecutorBlock` node (ID: `e189baac-8c20-45a1-94a7-55177ea42565`)
|
||||
3. Set `input_default`:
|
||||
- `graph_id`: from the library agent's `graph_id`
|
||||
- `graph_version`: from the library agent's `graph_version`
|
||||
- `input_schema`: from the library agent's `input_schema` (JSON Schema)
|
||||
- `output_schema`: from the library agent's `output_schema` (JSON Schema)
|
||||
- `user_id`: leave as `""` (filled at runtime)
|
||||
- `inputs`: `{}` (populated by links at runtime)
|
||||
4. Wire inputs: link to sink names matching the sub-agent's `input_schema`
|
||||
property names (e.g., if input_schema has a `"url"` property, use
|
||||
`"url"` as the sink_name)
|
||||
5. Wire outputs: link from source names matching the sub-agent's
|
||||
`output_schema` property names
|
||||
6. Pass `library_agent_ids` to `create_agent`/`customize_agent` with
|
||||
the library agent IDs used, so the fixer can validate schemas
|
||||
|
||||
### Using MCP Tools (MCPToolBlock)
|
||||
|
||||
To use an MCP (Model Context Protocol) tool as a node in the agent:
|
||||
1. The user must specify which MCP server URL and tool name they want
|
||||
2. Create an `MCPToolBlock` node (ID: `a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4`)
|
||||
3. Set `input_default`:
|
||||
- `server_url`: the MCP server URL (e.g. `"https://mcp.example.com/sse"`)
|
||||
- `selected_tool`: the tool name on that server
|
||||
- `tool_input_schema`: JSON Schema for the tool's inputs
|
||||
- `tool_arguments`: `{}` (populated by links or hardcoded values)
|
||||
4. The block requires MCP credentials — the user configures these in the
|
||||
platform UI after the agent is saved
|
||||
5. Wire inputs using the tool argument field name directly as the sink_name
|
||||
(e.g., `query`, NOT `tool_arguments_#_query`). The execution engine
|
||||
automatically collects top-level fields matching tool_input_schema into
|
||||
tool_arguments.
|
||||
6. Output: `result` (the tool's return value) and `error` (error message)
|
||||
|
||||
### Example: Simple AI Text Processor
|
||||
|
||||
A minimal agent with input, processing, and output:
|
||||
- Node 1: `AgentInputBlock` (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`,
|
||||
input_default: {"name": "user_text", "title": "Text to process"},
|
||||
output: "result")
|
||||
- Node 2: `AITextGeneratorBlock` (input: "prompt" linked from Node 1's "result")
|
||||
- Node 3: `AgentOutputBlock` (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`,
|
||||
input_default: {"name": "summary", "title": "Summary"},
|
||||
input: "value" linked from Node 2's output)
|
||||
@@ -8,8 +8,6 @@ SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
|
||||
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
@@ -17,36 +15,23 @@ import os
|
||||
import shlex
|
||||
from typing import Any, Callable
|
||||
|
||||
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
|
||||
from backend.copilot.context import (
|
||||
E2B_WORKDIR,
|
||||
get_current_sandbox,
|
||||
get_sdk_cwd,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Lazy imports to break circular dependency with tool_adapter.
|
||||
|
||||
|
||||
def _get_sandbox(): # type: ignore[return]
|
||||
from .tool_adapter import get_current_sandbox # noqa: E402
|
||||
|
||||
def _get_sandbox():
|
||||
return get_current_sandbox()
|
||||
|
||||
|
||||
def _is_allowed_local(path: str) -> bool:
|
||||
from .tool_adapter import is_allowed_local_path # noqa: E402
|
||||
|
||||
return is_allowed_local_path(path)
|
||||
|
||||
|
||||
def _resolve_remote(path: str) -> str:
|
||||
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
|
||||
|
||||
Raises :class:`ValueError` if the resolved path escapes the sandbox.
|
||||
"""
|
||||
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
|
||||
normalized = os.path.normpath(candidate)
|
||||
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
|
||||
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
|
||||
return normalized
|
||||
return is_allowed_local_path(path, get_sdk_cwd())
|
||||
|
||||
|
||||
def _mcp(text: str, *, error: bool = False) -> dict[str, Any]:
|
||||
@@ -63,7 +48,7 @@ def _get_sandbox_and_path(
|
||||
if sandbox is None:
|
||||
return _mcp("No E2B sandbox available", error=True)
|
||||
try:
|
||||
remote = _resolve_remote(file_path)
|
||||
remote = resolve_sandbox_path(file_path)
|
||||
except ValueError as exc:
|
||||
return _mcp(str(exc), error=True)
|
||||
return sandbox, remote
|
||||
@@ -73,6 +58,7 @@ def _get_sandbox_and_path(
|
||||
|
||||
|
||||
async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Read lines from a sandbox file, falling back to the local host for SDK-internal paths."""
|
||||
file_path: str = args.get("file_path", "")
|
||||
offset: int = max(0, int(args.get("offset", 0)))
|
||||
limit: int = max(1, int(args.get("limit", 2000)))
|
||||
@@ -104,6 +90,7 @@ async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
|
||||
async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Write content to a sandbox file, creating parent directories as needed."""
|
||||
file_path: str = args.get("file_path", "")
|
||||
content: str = args.get("content", "")
|
||||
|
||||
@@ -127,6 +114,7 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
|
||||
async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Replace a substring in a sandbox file, with optional replace-all support."""
|
||||
file_path: str = args.get("file_path", "")
|
||||
old_string: str = args.get("old_string", "")
|
||||
new_string: str = args.get("new_string", "")
|
||||
@@ -172,6 +160,7 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
|
||||
async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Find files matching a name pattern inside the sandbox using ``find``."""
|
||||
pattern: str = args.get("pattern", "")
|
||||
path: str = args.get("path", "")
|
||||
|
||||
@@ -183,7 +172,7 @@ async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
|
||||
return _mcp("No E2B sandbox available", error=True)
|
||||
|
||||
try:
|
||||
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
|
||||
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
|
||||
except ValueError as exc:
|
||||
return _mcp(str(exc), error=True)
|
||||
|
||||
@@ -198,6 +187,7 @@ async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
|
||||
async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Search file contents by regex inside the sandbox using ``grep -rn``."""
|
||||
pattern: str = args.get("pattern", "")
|
||||
path: str = args.get("path", "")
|
||||
include: str = args.get("include", "")
|
||||
@@ -210,7 +200,7 @@ async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
|
||||
return _mcp("No E2B sandbox available", error=True)
|
||||
|
||||
try:
|
||||
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
|
||||
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
|
||||
except ValueError as exc:
|
||||
return _mcp(str(exc), error=True)
|
||||
|
||||
@@ -238,7 +228,7 @@ def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
|
||||
return _mcp(f"Path not allowed: {file_path}", error=True)
|
||||
expanded = os.path.realpath(os.path.expanduser(file_path))
|
||||
try:
|
||||
with open(expanded) as fh:
|
||||
with open(expanded, encoding="utf-8", errors="replace") as fh:
|
||||
selected = list(itertools.islice(fh, offset, offset + limit))
|
||||
numbered = "".join(
|
||||
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
|
||||
|
||||
@@ -7,59 +7,60 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from .e2b_file_tools import _read_local, _resolve_remote
|
||||
from .tool_adapter import _current_project_dir
|
||||
from backend.copilot.context import _current_project_dir
|
||||
|
||||
from .e2b_file_tools import _read_local, resolve_sandbox_path
|
||||
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _resolve_remote — sandbox path normalisation & boundary enforcement
|
||||
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveRemote:
|
||||
class TestResolveSandboxPath:
|
||||
def test_relative_path_resolved(self):
|
||||
assert _resolve_remote("src/main.py") == "/home/user/src/main.py"
|
||||
assert resolve_sandbox_path("src/main.py") == "/home/user/src/main.py"
|
||||
|
||||
def test_absolute_within_sandbox(self):
|
||||
assert _resolve_remote("/home/user/file.txt") == "/home/user/file.txt"
|
||||
assert resolve_sandbox_path("/home/user/file.txt") == "/home/user/file.txt"
|
||||
|
||||
def test_workdir_itself(self):
|
||||
assert _resolve_remote("/home/user") == "/home/user"
|
||||
assert resolve_sandbox_path("/home/user") == "/home/user"
|
||||
|
||||
def test_relative_dotslash(self):
|
||||
assert _resolve_remote("./README.md") == "/home/user/README.md"
|
||||
assert resolve_sandbox_path("./README.md") == "/home/user/README.md"
|
||||
|
||||
def test_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
_resolve_remote("../../etc/passwd")
|
||||
resolve_sandbox_path("../../etc/passwd")
|
||||
|
||||
def test_absolute_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
_resolve_remote("/home/user/../../etc/passwd")
|
||||
resolve_sandbox_path("/home/user/../../etc/passwd")
|
||||
|
||||
def test_absolute_outside_sandbox_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
_resolve_remote("/etc/passwd")
|
||||
resolve_sandbox_path("/etc/passwd")
|
||||
|
||||
def test_root_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
_resolve_remote("/")
|
||||
resolve_sandbox_path("/")
|
||||
|
||||
def test_home_other_user_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
_resolve_remote("/home/other/file.txt")
|
||||
resolve_sandbox_path("/home/other/file.txt")
|
||||
|
||||
def test_deep_nested_allowed(self):
|
||||
assert _resolve_remote("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
|
||||
assert resolve_sandbox_path("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
|
||||
|
||||
def test_trailing_slash_normalised(self):
|
||||
assert _resolve_remote("src/") == "/home/user/src"
|
||||
assert resolve_sandbox_path("src/") == "/home/user/src"
|
||||
|
||||
def test_double_dots_within_sandbox_ok(self):
|
||||
"""Path that resolves back within /home/user is allowed."""
|
||||
assert _resolve_remote("a/b/../c.txt") == "/home/user/a/c.txt"
|
||||
assert resolve_sandbox_path("a/b/../c.txt") == "/home/user/a/c.txt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
281
autogpt_platform/backend/backend/copilot/sdk/file_ref.py
Normal file
281
autogpt_platform/backend/backend/copilot/sdk/file_ref.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""File reference protocol for tool call inputs.
|
||||
|
||||
Allows the LLM to pass a file reference instead of embedding large content
|
||||
inline. The processor expands ``@@agptfile:<uri>[<start>-<end>]`` tokens in tool
|
||||
arguments before the tool is executed.
|
||||
|
||||
Protocol
|
||||
--------
|
||||
|
||||
@@agptfile:<uri>[<start>-<end>]
|
||||
|
||||
``<uri>`` (required)
|
||||
- ``workspace://<file_id>`` — workspace file by ID
|
||||
- ``workspace://<file_id>#<mime>`` — same, MIME hint is ignored for reads
|
||||
- ``workspace:///<path>`` — workspace file by virtual path
|
||||
- ``/absolute/local/path`` — ephemeral or sdk_cwd file (validated by
|
||||
:func:`~backend.copilot.sdk.tool_adapter.is_allowed_local_path`)
|
||||
- Any absolute path that resolves inside the E2B sandbox
|
||||
(``/home/user/...``) when a sandbox is active
|
||||
|
||||
``[<start>-<end>]`` (optional)
|
||||
Line range, 1-indexed inclusive. Examples: ``[1-100]``, ``[50-200]``.
|
||||
Omit to read the entire file.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@agptfile:workspace://abc123
|
||||
@@agptfile:workspace://abc123[10-50]
|
||||
@@agptfile:workspace:///reports/q1.md
|
||||
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
|
||||
@@agptfile:/home/user/script.sh
|
||||
"""
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.context import (
|
||||
get_current_sandbox,
|
||||
get_sdk_cwd,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.workspace_files import get_manager
|
||||
from backend.util.file import parse_workspace_uri
|
||||
|
||||
|
||||
class FileRefExpansionError(Exception):
|
||||
"""Raised when a ``@@agptfile:`` reference in tool call args fails to resolve.
|
||||
|
||||
Separating this from inline substitution lets callers (e.g. the MCP tool
|
||||
wrapper) block tool execution and surface a helpful error to the model
|
||||
rather than passing an ``[file-ref error: …]`` string as actual input.
|
||||
"""
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FILE_REF_PREFIX = "@@agptfile:"
|
||||
|
||||
# Matches: @@agptfile:<uri>[start-end]?
|
||||
# Group 1 – URI; must start with '/' (absolute path) or 'workspace://'
|
||||
# Group 2 – start line (optional)
|
||||
# Group 3 – end line (optional)
|
||||
_FILE_REF_RE = re.compile(
|
||||
re.escape(FILE_REF_PREFIX) + r"((?:workspace://|/)[^\[\s]*)(?:\[(\d+)-(\d+)\])?"
|
||||
)
|
||||
|
||||
# Maximum characters returned for a single file reference expansion.
|
||||
_MAX_EXPAND_CHARS = 200_000
|
||||
# Maximum total characters across all @@agptfile: expansions in one string.
|
||||
_MAX_TOTAL_EXPAND_CHARS = 1_000_000
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileRef:
|
||||
uri: str
|
||||
start_line: int | None # 1-indexed, inclusive
|
||||
end_line: int | None # 1-indexed, inclusive
|
||||
|
||||
|
||||
def parse_file_ref(text: str) -> FileRef | None:
|
||||
"""Return a :class:`FileRef` if *text* is a bare file reference token.
|
||||
|
||||
A "bare token" means the entire string matches the ``@@agptfile:...`` pattern
|
||||
(after stripping whitespace). Use :func:`expand_file_refs_in_string` to
|
||||
expand references embedded in larger strings.
|
||||
"""
|
||||
m = _FILE_REF_RE.fullmatch(text.strip())
|
||||
if not m:
|
||||
return None
|
||||
start = int(m.group(2)) if m.group(2) else None
|
||||
end = int(m.group(3)) if m.group(3) else None
|
||||
if start is not None and start < 1:
|
||||
return None
|
||||
if end is not None and end < 1:
|
||||
return None
|
||||
if start is not None and end is not None and end < start:
|
||||
return None
|
||||
return FileRef(uri=m.group(1), start_line=start, end_line=end)
|
||||
|
||||
|
||||
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
|
||||
"""Slice *text* to the requested 1-indexed line range (inclusive)."""
|
||||
if start is None and end is None:
|
||||
return text
|
||||
lines = text.splitlines(keepends=True)
|
||||
s = (start - 1) if start is not None else 0
|
||||
e = end if end is not None else len(lines)
|
||||
selected = list(itertools.islice(lines, s, e))
|
||||
return "".join(selected)
|
||||
|
||||
|
||||
async def read_file_bytes(
|
||||
uri: str,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
) -> bytes:
|
||||
"""Resolve *uri* to raw bytes using workspace, local, or E2B path logic.
|
||||
|
||||
Raises :class:`ValueError` if the URI cannot be resolved.
|
||||
"""
|
||||
# Strip MIME fragment (e.g. workspace://id#mime) before dispatching.
|
||||
plain = uri.split("#")[0] if uri.startswith("workspace://") else uri
|
||||
|
||||
if plain.startswith("workspace://"):
|
||||
if not user_id:
|
||||
raise ValueError("workspace:// file references require authentication")
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
ws = parse_workspace_uri(plain)
|
||||
try:
|
||||
return await (
|
||||
manager.read_file(ws.file_ref)
|
||||
if ws.is_path
|
||||
else manager.read_file_by_id(ws.file_ref)
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"File not found: {plain}")
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
|
||||
if is_allowed_local_path(plain, get_sdk_cwd()):
|
||||
resolved = os.path.realpath(os.path.expanduser(plain))
|
||||
try:
|
||||
with open(resolved, "rb") as fh:
|
||||
return fh.read()
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"File not found: {plain}")
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
try:
|
||||
remote = resolve_sandbox_path(plain)
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||
) from exc
|
||||
try:
|
||||
return bytes(await sandbox.files.read(remote, format="bytes"))
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
|
||||
|
||||
raise ValueError(
|
||||
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||
)
|
||||
|
||||
|
||||
async def resolve_file_ref(
|
||||
ref: FileRef,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
) -> str:
|
||||
"""Resolve a :class:`FileRef` to its text content."""
|
||||
raw = await read_file_bytes(ref.uri, user_id, session)
|
||||
return _apply_line_range(
|
||||
raw.decode("utf-8", errors="replace"), ref.start_line, ref.end_line
|
||||
)
|
||||
|
||||
|
||||
async def expand_file_refs_in_string(
|
||||
text: str,
|
||||
user_id: str | None,
|
||||
session: "ChatSession",
|
||||
*,
|
||||
raise_on_error: bool = False,
|
||||
) -> str:
|
||||
"""Expand all ``@@agptfile:...`` tokens in *text*, returning the substituted string.
|
||||
|
||||
Non-reference text is passed through unchanged.
|
||||
|
||||
If *raise_on_error* is ``False`` (default), expansion errors are surfaced
|
||||
inline as ``[file-ref error: <message>]`` — useful for display/log contexts
|
||||
where partial expansion is acceptable.
|
||||
|
||||
If *raise_on_error* is ``True``, any resolution failure raises
|
||||
:class:`FileRefExpansionError` immediately so the caller can block the
|
||||
operation and surface a clean error to the model.
|
||||
"""
|
||||
if FILE_REF_PREFIX not in text:
|
||||
return text
|
||||
|
||||
result: list[str] = []
|
||||
last_end = 0
|
||||
total_chars = 0
|
||||
for m in _FILE_REF_RE.finditer(text):
|
||||
result.append(text[last_end : m.start()])
|
||||
start = int(m.group(2)) if m.group(2) else None
|
||||
end = int(m.group(3)) if m.group(3) else None
|
||||
if (start is not None and start < 1) or (end is not None and end < 1):
|
||||
msg = f"line numbers must be >= 1: {m.group(0)}"
|
||||
if raise_on_error:
|
||||
raise FileRefExpansionError(msg)
|
||||
result.append(f"[file-ref error: {msg}]")
|
||||
last_end = m.end()
|
||||
continue
|
||||
if start is not None and end is not None and end < start:
|
||||
msg = f"end line must be >= start line: {m.group(0)}"
|
||||
if raise_on_error:
|
||||
raise FileRefExpansionError(msg)
|
||||
result.append(f"[file-ref error: {msg}]")
|
||||
last_end = m.end()
|
||||
continue
|
||||
ref = FileRef(uri=m.group(1), start_line=start, end_line=end)
|
||||
try:
|
||||
content = await resolve_file_ref(ref, user_id, session)
|
||||
if len(content) > _MAX_EXPAND_CHARS:
|
||||
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
|
||||
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
|
||||
if remaining <= 0:
|
||||
content = "[file-ref budget exhausted: total expansion limit reached]"
|
||||
elif len(content) > remaining:
|
||||
content = content[:remaining] + "\n... [total budget exhausted]"
|
||||
total_chars += len(content)
|
||||
result.append(content)
|
||||
except ValueError as exc:
|
||||
logger.warning("file-ref expansion failed for %r: %s", m.group(0), exc)
|
||||
if raise_on_error:
|
||||
raise FileRefExpansionError(str(exc)) from exc
|
||||
result.append(f"[file-ref error: {exc}]")
|
||||
last_end = m.end()
|
||||
|
||||
result.append(text[last_end:])
|
||||
return "".join(result)
|
||||
|
||||
|
||||
async def expand_file_refs_in_args(
|
||||
args: dict[str, Any],
|
||||
user_id: str | None,
|
||||
session: "ChatSession",
|
||||
) -> dict[str, Any]:
|
||||
"""Recursively expand ``@@agptfile:...`` references in tool call arguments.
|
||||
|
||||
String values are expanded in-place. Nested dicts and lists are
|
||||
traversed. Non-string scalars are returned unchanged.
|
||||
|
||||
Raises :class:`FileRefExpansionError` if any reference fails to resolve,
|
||||
so the tool is *not* executed with an error string as its input. The
|
||||
caller (the MCP tool wrapper) should convert this into an MCP error
|
||||
response that lets the model correct the reference before retrying.
|
||||
"""
|
||||
if not args:
|
||||
return args
|
||||
|
||||
async def _expand(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
return await expand_file_refs_in_string(
|
||||
value, user_id, session, raise_on_error=True
|
||||
)
|
||||
if isinstance(value, dict):
|
||||
return {k: await _expand(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [await _expand(item) for item in value]
|
||||
return value
|
||||
|
||||
return {k: await _expand(v) for k, v in args.items()}
|
||||
@@ -0,0 +1,328 @@
|
||||
"""Integration tests for @@agptfile: reference expansion in tool calls.
|
||||
|
||||
These tests verify the end-to-end behaviour of the file reference protocol:
|
||||
- Parsing @@agptfile: tokens from tool arguments
|
||||
- Resolving local-filesystem paths (sdk_cwd / ephemeral)
|
||||
- Expanding references inside the tool-call pipeline (_execute_tool_sync)
|
||||
- The extended Read tool handler (workspace:// pass-through via session context)
|
||||
|
||||
No real LLM or database is required; workspace reads are stubbed where needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.sdk.file_ref import (
|
||||
FileRef,
|
||||
expand_file_refs_in_args,
|
||||
expand_file_refs_in_string,
|
||||
read_file_bytes,
|
||||
resolve_file_ref,
|
||||
)
|
||||
from backend.copilot.sdk.tool_adapter import _read_file_handler
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_session(session_id: str = "integ-sess") -> MagicMock:
|
||||
s = MagicMock()
|
||||
s.session_id = session_id
|
||||
return s
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local-file resolution (sdk_cwd)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_file_ref_local_path():
|
||||
"""resolve_file_ref reads a real local file when it's within sdk_cwd."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
# Write a test file inside sdk_cwd
|
||||
test_file = os.path.join(sdk_cwd, "hello.txt")
|
||||
with open(test_file, "w") as f:
|
||||
f.write("line1\nline2\nline3\n")
|
||||
|
||||
session = _make_session()
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
ref = FileRef(uri=test_file, start_line=None, end_line=None)
|
||||
content = await resolve_file_ref(ref, user_id="u1", session=session)
|
||||
|
||||
assert content == "line1\nline2\nline3\n"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_file_ref_local_path_with_line_range():
|
||||
"""resolve_file_ref respects line ranges for local files."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
test_file = os.path.join(sdk_cwd, "multi.txt")
|
||||
lines = [f"line{i}\n" for i in range(1, 11)] # line1 … line10
|
||||
with open(test_file, "w") as f:
|
||||
f.writelines(lines)
|
||||
|
||||
session = _make_session()
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
ref = FileRef(uri=test_file, start_line=3, end_line=5)
|
||||
content = await resolve_file_ref(ref, user_id="u1", session=session)
|
||||
|
||||
assert content == "line3\nline4\nline5\n"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_file_ref_rejects_path_outside_sdk_cwd():
|
||||
"""resolve_file_ref raises ValueError for paths outside sdk_cwd."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
|
||||
"backend.copilot.context._current_sandbox"
|
||||
) as mock_sandbox_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
mock_sandbox_var.get.return_value = None
|
||||
|
||||
ref = FileRef(uri="/etc/passwd", start_line=None, end_line=None)
|
||||
with pytest.raises(ValueError, match="not allowed"):
|
||||
await resolve_file_ref(ref, user_id="u1", session=_make_session())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_string — integration with real files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_string_with_real_file():
|
||||
"""expand_file_refs_in_string replaces @@agptfile: token with actual content."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
test_file = os.path.join(sdk_cwd, "data.txt")
|
||||
with open(test_file, "w") as f:
|
||||
f.write("hello world\n")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_string(
|
||||
f"Content: @@agptfile:{test_file}",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result == "Content: hello world\n"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_string_missing_file_is_surfaced_inline():
|
||||
"""Missing file ref yields [file-ref error: …] inline rather than raising."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
missing = os.path.join(sdk_cwd, "does_not_exist.txt")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_string(
|
||||
f"@@agptfile:{missing}",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert "[file-ref error:" in result
|
||||
assert "not found" in result.lower() or "not allowed" in result.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_args — dict traversal with real files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_replaces_file_ref_in_nested_dict():
|
||||
"""Nested @@agptfile: references in args are fully expanded."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
file_a = os.path.join(sdk_cwd, "a.txt")
|
||||
file_b = os.path.join(sdk_cwd, "b.txt")
|
||||
with open(file_a, "w") as f:
|
||||
f.write("AAA")
|
||||
with open(file_b, "w") as f:
|
||||
f.write("BBB")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{
|
||||
"outer": {
|
||||
"content_a": f"@@agptfile:{file_a}",
|
||||
"content_b": f"start @@agptfile:{file_b} end",
|
||||
},
|
||||
"count": 42,
|
||||
},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["outer"]["content_a"] == "AAA"
|
||||
assert result["outer"]["content_b"] == "start BBB end"
|
||||
assert result["count"] == 42
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_file_handler — extended to accept workspace:// and local paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_handler_local_file():
|
||||
"""_read_file_handler reads a local file when it's within sdk_cwd."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
test_file = os.path.join(sdk_cwd, "read_test.txt")
|
||||
lines = [f"L{i}\n" for i in range(1, 6)]
|
||||
with open(test_file, "w") as f:
|
||||
f.writelines(lines)
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
|
||||
"backend.copilot.context._current_project_dir"
|
||||
) as mock_proj_var, patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", _make_session()),
|
||||
):
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
mock_proj_var.get.return_value = ""
|
||||
|
||||
result = await _read_file_handler(
|
||||
{"file_path": test_file, "offset": 0, "limit": 5}
|
||||
)
|
||||
|
||||
assert not result["isError"]
|
||||
text = result["content"][0]["text"]
|
||||
assert "L1" in text
|
||||
assert "L5" in text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_handler_workspace_uri():
|
||||
"""_read_file_handler handles workspace:// URIs via the workspace manager."""
|
||||
mock_session = _make_session()
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.read_file_by_id.return_value = b"workspace file content\nline two\n"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", mock_session),
|
||||
), patch(
|
||||
"backend.copilot.sdk.file_ref.get_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
result = await _read_file_handler(
|
||||
{"file_path": "workspace://file-id-abc", "offset": 0, "limit": 10}
|
||||
)
|
||||
|
||||
assert not result["isError"], result["content"][0]["text"]
|
||||
text = result["content"][0]["text"]
|
||||
assert "workspace file content" in text
|
||||
assert "line two" in text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_handler_workspace_uri_no_session():
|
||||
"""_read_file_handler returns error when workspace:// is used without session."""
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=(None, None),
|
||||
):
|
||||
result = await _read_file_handler({"file_path": "workspace://some-id"})
|
||||
|
||||
assert result["isError"]
|
||||
assert "session" in result["content"][0]["text"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_handler_access_denied():
|
||||
"""_read_file_handler rejects paths outside allowed locations."""
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
|
||||
"backend.copilot.context._current_sandbox"
|
||||
) as mock_sandbox, patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", _make_session()),
|
||||
):
|
||||
mock_cwd.get.return_value = "/tmp/safe-dir"
|
||||
mock_sandbox.get.return_value = None
|
||||
|
||||
result = await _read_file_handler({"file_path": "/etc/passwd"})
|
||||
|
||||
assert result["isError"]
|
||||
assert "not allowed" in result["content"][0]["text"].lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# read_file_bytes — workspace:///path (virtual path) and E2B sandbox branch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_bytes_workspace_virtual_path():
|
||||
"""workspace:///path resolves via manager.read_file (is_path=True path)."""
|
||||
session = _make_session()
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.read_file.return_value = b"virtual path content"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.get_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
result = await read_file_bytes("workspace:///reports/q1.md", "user-1", session)
|
||||
|
||||
assert result == b"virtual path content"
|
||||
mock_manager.read_file.assert_awaited_once_with("/reports/q1.md")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_bytes_e2b_sandbox_branch():
|
||||
"""read_file_bytes reads from the E2B sandbox when a sandbox is active."""
|
||||
session = _make_session()
|
||||
mock_sandbox = AsyncMock()
|
||||
mock_sandbox.files.read.return_value = bytearray(b"sandbox content")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
|
||||
"backend.copilot.context._current_sandbox"
|
||||
) as mock_sandbox_var, patch(
|
||||
"backend.copilot.context._current_project_dir"
|
||||
) as mock_proj:
|
||||
mock_cwd.get.return_value = ""
|
||||
mock_sandbox_var.get.return_value = mock_sandbox
|
||||
mock_proj.get.return_value = ""
|
||||
|
||||
result = await read_file_bytes("/home/user/script.sh", None, session)
|
||||
|
||||
assert result == b"sandbox content"
|
||||
mock_sandbox.files.read.assert_awaited_once_with(
|
||||
"/home/user/script.sh", format="bytes"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_bytes_e2b_path_escapes_sandbox_raises():
|
||||
"""read_file_bytes raises ValueError for paths that escape the sandbox root."""
|
||||
session = _make_session()
|
||||
mock_sandbox = AsyncMock()
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
|
||||
"backend.copilot.context._current_sandbox"
|
||||
) as mock_sandbox_var, patch(
|
||||
"backend.copilot.context._current_project_dir"
|
||||
) as mock_proj:
|
||||
mock_cwd.get.return_value = ""
|
||||
mock_sandbox_var.get.return_value = mock_sandbox
|
||||
mock_proj.get.return_value = ""
|
||||
|
||||
with pytest.raises(ValueError, match="not allowed"):
|
||||
await read_file_bytes("/etc/passwd", None, session)
|
||||
382
autogpt_platform/backend/backend/copilot/sdk/file_ref_test.py
Normal file
382
autogpt_platform/backend/backend/copilot/sdk/file_ref_test.py
Normal file
@@ -0,0 +1,382 @@
|
||||
"""Tests for the @@agptfile: reference protocol (file_ref.py)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.sdk.file_ref import (
|
||||
_MAX_EXPAND_CHARS,
|
||||
FileRef,
|
||||
FileRefExpansionError,
|
||||
_apply_line_range,
|
||||
expand_file_refs_in_args,
|
||||
expand_file_refs_in_string,
|
||||
parse_file_ref,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_ref
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_file_ref_workspace_id():
|
||||
ref = parse_file_ref("@@agptfile:workspace://abc123")
|
||||
assert ref == FileRef(uri="workspace://abc123", start_line=None, end_line=None)
|
||||
|
||||
|
||||
def test_parse_file_ref_workspace_id_with_mime():
|
||||
ref = parse_file_ref("@@agptfile:workspace://abc123#text/plain")
|
||||
assert ref is not None
|
||||
assert ref.uri == "workspace://abc123#text/plain"
|
||||
assert ref.start_line is None
|
||||
|
||||
|
||||
def test_parse_file_ref_workspace_path():
|
||||
ref = parse_file_ref("@@agptfile:workspace:///reports/q1.md")
|
||||
assert ref is not None
|
||||
assert ref.uri == "workspace:///reports/q1.md"
|
||||
|
||||
|
||||
def test_parse_file_ref_with_line_range():
|
||||
ref = parse_file_ref("@@agptfile:workspace://abc123[10-50]")
|
||||
assert ref == FileRef(uri="workspace://abc123", start_line=10, end_line=50)
|
||||
|
||||
|
||||
def test_parse_file_ref_local_path():
|
||||
ref = parse_file_ref("@@agptfile:/tmp/copilot-session/output.py[1-100]")
|
||||
assert ref is not None
|
||||
assert ref.uri == "/tmp/copilot-session/output.py"
|
||||
assert ref.start_line == 1
|
||||
assert ref.end_line == 100
|
||||
|
||||
|
||||
def test_parse_file_ref_no_match():
|
||||
assert parse_file_ref("just a normal string") is None
|
||||
assert parse_file_ref("workspace://abc123") is None # missing @@agptfile: prefix
|
||||
assert (
|
||||
parse_file_ref("@@agptfile:workspace://abc123 extra") is None
|
||||
) # not full match
|
||||
|
||||
|
||||
def test_parse_file_ref_strips_whitespace():
|
||||
ref = parse_file_ref(" @@agptfile:workspace://abc123 ")
|
||||
assert ref is not None
|
||||
assert ref.uri == "workspace://abc123"
|
||||
|
||||
|
||||
def test_parse_file_ref_invalid_range_zero_start():
|
||||
assert parse_file_ref("@@agptfile:workspace://abc123[0-5]") is None
|
||||
|
||||
|
||||
def test_parse_file_ref_invalid_range_end_less_than_start():
|
||||
assert parse_file_ref("@@agptfile:workspace://abc123[10-5]") is None
|
||||
|
||||
|
||||
def test_parse_file_ref_invalid_range_zero_end():
|
||||
assert parse_file_ref("@@agptfile:workspace://abc123[1-0]") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _apply_line_range
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
TEXT = "line1\nline2\nline3\nline4\nline5\n"
|
||||
|
||||
|
||||
def test_apply_line_range_no_range():
|
||||
assert _apply_line_range(TEXT, None, None) == TEXT
|
||||
|
||||
|
||||
def test_apply_line_range_start_only():
|
||||
result = _apply_line_range(TEXT, 3, None)
|
||||
assert result == "line3\nline4\nline5\n"
|
||||
|
||||
|
||||
def test_apply_line_range_full():
|
||||
result = _apply_line_range(TEXT, 2, 4)
|
||||
assert result == "line2\nline3\nline4\n"
|
||||
|
||||
|
||||
def test_apply_line_range_single_line():
|
||||
result = _apply_line_range(TEXT, 2, 2)
|
||||
assert result == "line2\n"
|
||||
|
||||
|
||||
def test_apply_line_range_beyond_eof():
|
||||
result = _apply_line_range(TEXT, 4, 999)
|
||||
assert result == "line4\nline5\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_string
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_session(session_id: str = "sess-1") -> MagicMock:
|
||||
session = MagicMock()
|
||||
session.session_id = session_id
|
||||
return session
|
||||
|
||||
|
||||
async def _resolve_always(ref: FileRef, _user_id: str | None, _session: object) -> str:
|
||||
"""Stub resolver that returns the URI and range as a descriptive string."""
|
||||
if ref.start_line is not None:
|
||||
return f"content:{ref.uri}[{ref.start_line}-{ref.end_line}]"
|
||||
return f"content:{ref.uri}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_no_refs():
|
||||
result = await expand_file_refs_in_string(
|
||||
"no references here", user_id="u1", session=_make_session()
|
||||
)
|
||||
assert result == "no references here"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_single_ref():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_string(
|
||||
"@@agptfile:workspace://abc123",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result == "content:workspace://abc123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_ref_with_range():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_string(
|
||||
"@@agptfile:workspace://abc123[10-50]",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result == "content:workspace://abc123[10-50]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_ref_embedded_in_text():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_string(
|
||||
"Here is the file: @@agptfile:workspace://abc123 — done",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result == "Here is the file: content:workspace://abc123 — done"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_multiple_refs():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_string(
|
||||
"@@agptfile:workspace://file1 and @@agptfile:workspace://file2[1-5]",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result == "content:workspace://file1 and content:workspace://file2[1-5]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_invalid_range_zero_start_surfaces_inline():
|
||||
"""expand_file_refs_in_string surfaces [file-ref error: ...] for zero-start ranges."""
|
||||
result = await expand_file_refs_in_string(
|
||||
"@@agptfile:workspace://abc123[0-5]",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert "[file-ref error:" in result
|
||||
assert "line numbers must be >= 1" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_invalid_range_end_less_than_start_surfaces_inline():
|
||||
"""expand_file_refs_in_string surfaces [file-ref error: ...] when end < start."""
|
||||
result = await expand_file_refs_in_string(
|
||||
"prefix @@agptfile:workspace://abc123[10-5] suffix",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert "[file-ref error:" in result
|
||||
assert "end line must be >= start line" in result
|
||||
assert "prefix" in result
|
||||
assert "suffix" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_ref_error_surfaces_inline():
|
||||
async def _raise(*args, **kwargs): # noqa: ARG001
|
||||
raise ValueError("file not found")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_raise),
|
||||
):
|
||||
result = await expand_file_refs_in_string(
|
||||
"@@agptfile:workspace://bad",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert "[file-ref error:" in result
|
||||
assert "file not found" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_args
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_flat():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_args(
|
||||
{"content": "@@agptfile:workspace://abc123", "other": 42},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result["content"] == "content:workspace://abc123"
|
||||
assert result["other"] == 42
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_nested_dict():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_args(
|
||||
{"outer": {"inner": "@@agptfile:workspace://nested"}},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result["outer"]["inner"] == "content:workspace://nested"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_list():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_args(
|
||||
{
|
||||
"items": [
|
||||
"@@agptfile:workspace://a",
|
||||
"plain",
|
||||
"@@agptfile:workspace://b[1-3]",
|
||||
]
|
||||
},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result["items"] == [
|
||||
"content:workspace://a",
|
||||
"plain",
|
||||
"content:workspace://b[1-3]",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_empty():
|
||||
result = await expand_file_refs_in_args({}, user_id="u1", session=_make_session())
|
||||
assert result == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_no_refs():
|
||||
result = await expand_file_refs_in_args(
|
||||
{"key": "no refs here", "num": 1},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result == {"key": "no refs here", "num": 1}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_raises_on_file_ref_error():
|
||||
"""expand_file_refs_in_args raises FileRefExpansionError instead of passing
|
||||
the inline error string to the tool, blocking tool execution."""
|
||||
|
||||
async def _raise(*args, **kwargs): # noqa: ARG001
|
||||
raise ValueError("path does not exist")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_raise),
|
||||
):
|
||||
with pytest.raises(FileRefExpansionError) as exc_info:
|
||||
await expand_file_refs_in_args(
|
||||
{"prompt": "@@agptfile:/home/user/missing.txt"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert "path does not exist" in str(exc_info.value)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-file truncation and aggregate budget
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_per_file_truncation():
|
||||
"""Content exceeding _MAX_EXPAND_CHARS is truncated with a marker."""
|
||||
oversized = "x" * (_MAX_EXPAND_CHARS + 100)
|
||||
|
||||
async def _resolve_oversized(ref: FileRef, _uid: str | None, _s: object) -> str:
|
||||
return oversized
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_oversized),
|
||||
):
|
||||
result = await expand_file_refs_in_string(
|
||||
"@@agptfile:workspace://big-file",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert len(result) <= _MAX_EXPAND_CHARS + len("\n... [truncated]") + 10
|
||||
assert "[truncated]" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_aggregate_budget_exhausted():
|
||||
"""When the aggregate budget is exhausted, later refs get the budget message."""
|
||||
# Each file returns just under 300K; after ~4 files the 1M budget is used.
|
||||
big_chunk = "y" * 300_000
|
||||
|
||||
async def _resolve_big(ref: FileRef, _uid: str | None, _s: object) -> str:
|
||||
return big_chunk
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_big),
|
||||
):
|
||||
# 5 refs @ 300K each = 1.5M → last ref(s) should hit the aggregate limit
|
||||
refs = " ".join(f"@@agptfile:workspace://f{i}" for i in range(5))
|
||||
result = await expand_file_refs_in_string(
|
||||
refs,
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert "budget exhausted" in result
|
||||
@@ -0,0 +1,28 @@
|
||||
## MCP Tool Guide
|
||||
|
||||
### Workflow
|
||||
|
||||
`run_mcp_tool` follows a two-step pattern:
|
||||
|
||||
1. **Discover** — call with only `server_url` to list available tools on the server.
|
||||
2. **Execute** — call again with `server_url`, `tool_name`, and `tool_arguments` to run a tool.
|
||||
|
||||
### Known hosted MCP servers
|
||||
|
||||
Use these URLs directly without asking the user:
|
||||
|
||||
| Service | URL |
|
||||
|---|---|
|
||||
| Notion | `https://mcp.notion.com/mcp` |
|
||||
| Linear | `https://mcp.linear.app/mcp` |
|
||||
| Stripe | `https://mcp.stripe.com` |
|
||||
| Intercom | `https://mcp.intercom.com/mcp` |
|
||||
| Cloudflare | `https://mcp.cloudflare.com/mcp` |
|
||||
| Atlassian / Jira | `https://mcp.atlassian.com/mcp` |
|
||||
|
||||
For other services, search the MCP registry at https://registry.modelcontextprotocol.io/.
|
||||
|
||||
### Authentication
|
||||
|
||||
If the server requires credentials, a `SetupRequirementsResponse` is returned with an OAuth
|
||||
login prompt. Once the user completes the flow and confirms, retry the same call immediately.
|
||||
@@ -536,10 +536,12 @@ async def test_wait_for_stash_signaled():
|
||||
result = await wait_for_stash(timeout=1.0)
|
||||
|
||||
assert result is True
|
||||
assert _pto.get({}).get("WebSearch") == ["result data"]
|
||||
pto = _pto.get()
|
||||
assert pto is not None
|
||||
assert pto.get("WebSearch") == ["result data"]
|
||||
|
||||
# Cleanup
|
||||
_pto.set({}) # type: ignore[arg-type]
|
||||
_pto.set({})
|
||||
_stash_event.set(None)
|
||||
|
||||
|
||||
@@ -554,7 +556,7 @@ async def test_wait_for_stash_timeout():
|
||||
assert result is False
|
||||
|
||||
# Cleanup
|
||||
_pto.set({}) # type: ignore[arg-type]
|
||||
_pto.set({})
|
||||
_stash_event.set(None)
|
||||
|
||||
|
||||
@@ -573,10 +575,12 @@ async def test_wait_for_stash_already_stashed():
|
||||
assert result is True
|
||||
|
||||
# But the stash itself is populated
|
||||
assert _pto.get({}).get("Read") == ["file contents"]
|
||||
pto = _pto.get()
|
||||
assert pto is not None
|
||||
assert pto.get("Read") == ["file contents"]
|
||||
|
||||
# Cleanup
|
||||
_pto.set({}) # type: ignore[arg-type]
|
||||
_pto.set({})
|
||||
_stash_event.set(None)
|
||||
|
||||
|
||||
|
||||
@@ -10,12 +10,13 @@ import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
from backend.copilot.context import is_allowed_local_path
|
||||
|
||||
from .tool_adapter import (
|
||||
BLOCKED_TOOLS,
|
||||
DANGEROUS_PATTERNS,
|
||||
MCP_TOOL_PREFIX,
|
||||
WORKSPACE_SCOPED_TOOLS,
|
||||
is_allowed_local_path,
|
||||
stash_pending_tool_output,
|
||||
)
|
||||
|
||||
|
||||
@@ -9,8 +9,9 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import _current_project_dir
|
||||
|
||||
from .security_hooks import _validate_tool_access, _validate_user_isolation
|
||||
from .service import _is_tool_error_or_denial
|
||||
|
||||
SDK_CWD = "/tmp/copilot-abc123"
|
||||
|
||||
@@ -120,8 +121,6 @@ def test_read_no_cwd_denies_absolute():
|
||||
|
||||
|
||||
def test_read_tool_results_allowed():
|
||||
from .tool_adapter import _current_project_dir
|
||||
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
|
||||
# is_allowed_local_path requires the session's encoded cwd to be set
|
||||
@@ -133,16 +132,14 @@ def test_read_tool_results_allowed():
|
||||
_current_project_dir.reset(token)
|
||||
|
||||
|
||||
def test_read_claude_projects_session_dir_allowed():
|
||||
"""Files within the current session's project dir are allowed."""
|
||||
from .tool_adapter import _current_project_dir
|
||||
|
||||
def test_read_claude_projects_settings_json_denied():
|
||||
"""SDK-internal artifacts like settings.json are NOT accessible — only tool-results/ is."""
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
|
||||
token = _current_project_dir.set("-tmp-copilot-abc123")
|
||||
try:
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert not _is_denied(result)
|
||||
assert _is_denied(result)
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
|
||||
@@ -357,76 +354,3 @@ async def test_task_slot_released_on_failure(_hooks):
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
|
||||
# -- _is_tool_error_or_denial ------------------------------------------------
|
||||
|
||||
|
||||
class TestIsToolErrorOrDenial:
|
||||
def test_none_content(self):
|
||||
assert _is_tool_error_or_denial(None) is False
|
||||
|
||||
def test_empty_content(self):
|
||||
assert _is_tool_error_or_denial("") is False
|
||||
|
||||
def test_benign_output(self):
|
||||
assert _is_tool_error_or_denial("All good, no issues.") is False
|
||||
|
||||
def test_security_marker(self):
|
||||
assert _is_tool_error_or_denial("[SECURITY] Tool access blocked") is True
|
||||
|
||||
def test_cannot_be_bypassed(self):
|
||||
assert _is_tool_error_or_denial("This restriction cannot be bypassed.") is True
|
||||
|
||||
def test_not_allowed(self):
|
||||
assert _is_tool_error_or_denial("Operation not allowed in sandbox") is True
|
||||
|
||||
def test_background_task_denial(self):
|
||||
assert (
|
||||
_is_tool_error_or_denial(
|
||||
"Background task execution is not supported. "
|
||||
"Run tasks in the foreground instead."
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_subtask_limit_denial(self):
|
||||
assert (
|
||||
_is_tool_error_or_denial(
|
||||
"Maximum 2 concurrent sub-tasks. "
|
||||
"Wait for running sub-tasks to finish, "
|
||||
"or continue in the main conversation."
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_denied_marker(self):
|
||||
assert (
|
||||
_is_tool_error_or_denial("Access denied: insufficient privileges") is True
|
||||
)
|
||||
|
||||
def test_blocked_marker(self):
|
||||
assert _is_tool_error_or_denial("Request blocked by security policy") is True
|
||||
|
||||
def test_failed_marker(self):
|
||||
assert _is_tool_error_or_denial("Failed to execute tool: timeout") is True
|
||||
|
||||
def test_mcp_iserror(self):
|
||||
assert _is_tool_error_or_denial('{"isError": true, "content": []}') is True
|
||||
|
||||
def test_benign_error_in_value(self):
|
||||
"""Content like '0 errors found' should not trigger — 'error' was removed."""
|
||||
assert _is_tool_error_or_denial("0 errors found") is False
|
||||
|
||||
def test_benign_permission_field(self):
|
||||
"""Schema descriptions mentioning 'permission' should not trigger."""
|
||||
assert (
|
||||
_is_tool_error_or_denial(
|
||||
'{"fields": [{"name": "permission_level", "type": "int"}]}'
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
def test_benign_not_found_in_listing(self):
|
||||
"""File listing containing 'not found' in filenames should not trigger."""
|
||||
assert _is_tool_error_or_denial("readme.md\nfile-not-found-handler.py") is False
|
||||
|
||||
@@ -60,7 +60,7 @@ from ..service import (
|
||||
_generate_session_title,
|
||||
_is_langfuse_configured,
|
||||
)
|
||||
from ..tools.e2b_sandbox import get_or_create_sandbox
|
||||
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
|
||||
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
||||
from ..tools.workspace_files import get_manager
|
||||
from ..tracking import track_user_message
|
||||
@@ -456,31 +456,6 @@ def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
|
||||
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
|
||||
|
||||
|
||||
def _is_tool_error_or_denial(content: str | None) -> bool:
|
||||
"""Check if a tool message content indicates an error or denial.
|
||||
|
||||
Currently unused — ``_format_conversation_context`` includes all tool
|
||||
results. Kept as a utility for future selective filtering.
|
||||
"""
|
||||
if not content:
|
||||
return False
|
||||
lower = content.lower()
|
||||
return any(
|
||||
marker in lower
|
||||
for marker in (
|
||||
"[security]",
|
||||
"cannot be bypassed",
|
||||
"not allowed",
|
||||
"not supported", # background-task denial
|
||||
"maximum", # subtask-limit denial
|
||||
"denied",
|
||||
"blocked",
|
||||
"failed to", # internal tool execution failures
|
||||
'"iserror": true', # MCP protocol error flag
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _build_query_message(
|
||||
current_message: str,
|
||||
session: ChatSession,
|
||||
@@ -784,28 +759,29 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
async def _setup_e2b():
|
||||
"""Set up E2B sandbox if configured, return sandbox or None."""
|
||||
if config.use_e2b_sandbox and not config.e2b_api_key:
|
||||
logger.warning(
|
||||
"[E2B] [%s] E2B sandbox enabled but no API key configured "
|
||||
"(CHAT_E2B_API_KEY / E2B_API_KEY) — falling back to bubblewrap",
|
||||
session_id[:12],
|
||||
)
|
||||
return None
|
||||
if config.use_e2b_sandbox and config.e2b_api_key:
|
||||
try:
|
||||
return await get_or_create_sandbox(
|
||||
session_id,
|
||||
api_key=config.e2b_api_key,
|
||||
template=config.e2b_sandbox_template,
|
||||
timeout=config.e2b_sandbox_timeout,
|
||||
)
|
||||
except Exception as e2b_err:
|
||||
logger.error(
|
||||
"[E2B] [%s] Setup failed: %s",
|
||||
if not (e2b_api_key := config.active_e2b_api_key):
|
||||
if config.use_e2b_sandbox:
|
||||
logger.warning(
|
||||
"[E2B] [%s] E2B sandbox enabled but no API key configured "
|
||||
"(CHAT_E2B_API_KEY / E2B_API_KEY) — falling back to bubblewrap",
|
||||
session_id[:12],
|
||||
e2b_err,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
try:
|
||||
return await get_or_create_sandbox(
|
||||
session_id,
|
||||
api_key=e2b_api_key,
|
||||
template=config.e2b_sandbox_template,
|
||||
timeout=config.e2b_sandbox_timeout,
|
||||
on_timeout=config.e2b_sandbox_on_timeout,
|
||||
)
|
||||
except Exception as e2b_err:
|
||||
logger.error(
|
||||
"[E2B] [%s] Setup failed: %s",
|
||||
session_id[:12],
|
||||
e2b_err,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
async def _fetch_transcript():
|
||||
@@ -837,7 +813,6 @@ async def stream_chat_completion_sdk(
|
||||
system_prompt = base_system_prompt + get_sdk_supplement(
|
||||
use_e2b=use_e2b, cwd=sdk_cwd
|
||||
)
|
||||
|
||||
# Process transcript download result
|
||||
transcript_msg_count = 0
|
||||
if dl:
|
||||
@@ -902,6 +877,11 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
allowed = get_copilot_tool_names(use_e2b=use_e2b)
|
||||
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
|
||||
|
||||
def _on_stderr(line: str) -> None:
|
||||
sid = session_id[:12] if session_id else "?"
|
||||
logger.info("[SDK] [%s] CLI stderr: %s", sid, line.rstrip())
|
||||
|
||||
sdk_options_kwargs: dict[str, Any] = {
|
||||
"system_prompt": system_prompt,
|
||||
"mcp_servers": {"copilot": mcp_server},
|
||||
@@ -910,6 +890,7 @@ async def stream_chat_completion_sdk(
|
||||
"hooks": security_hooks,
|
||||
"cwd": sdk_cwd,
|
||||
"max_buffer_size": config.claude_agent_max_buffer_size,
|
||||
"stderr": _on_stderr,
|
||||
}
|
||||
if sdk_model:
|
||||
sdk_options_kwargs["model"] = sdk_model
|
||||
@@ -1082,6 +1063,19 @@ async def stream_chat_completion_sdk(
|
||||
len(adapter.resolved_tool_calls),
|
||||
)
|
||||
|
||||
# Log AssistantMessage API errors (e.g. invalid_request)
|
||||
# so we can debug Anthropic API 400s surfaced by the CLI.
|
||||
sdk_error = getattr(sdk_msg, "error", None)
|
||||
if isinstance(sdk_msg, AssistantMessage) and sdk_error:
|
||||
logger.error(
|
||||
"[SDK] [%s] AssistantMessage has error=%s, "
|
||||
"content_blocks=%d, content_preview=%s",
|
||||
session_id[:12],
|
||||
sdk_error,
|
||||
len(sdk_msg.content),
|
||||
str(sdk_msg.content)[:500],
|
||||
)
|
||||
|
||||
# Race-condition fix: SDK hooks (PostToolUse) are
|
||||
# executed asynchronously via start_soon() — the next
|
||||
# message can arrive before the hook stashes output.
|
||||
@@ -1416,6 +1410,17 @@ async def stream_chat_completion_sdk(
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# --- Pause E2B sandbox to stop billing between turns ---
|
||||
# Fire-and-forget: pausing is best-effort and must not block the
|
||||
# response or the transcript upload. The task is anchored to
|
||||
# _background_tasks to prevent garbage collection.
|
||||
# Use pause_sandbox_direct to skip the Redis lookup and reconnect
|
||||
# round-trip — e2b_sandbox is the live object from this turn.
|
||||
if e2b_sandbox is not None:
|
||||
task = asyncio.create_task(pause_sandbox_direct(e2b_sandbox, session_id))
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
# --- Upload transcript for next-turn --resume ---
|
||||
# This MUST run in finally so the transcript is uploaded even when
|
||||
# the streaming loop raises an exception.
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Tests for SDK service helpers."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -212,7 +213,7 @@ class TestPromptSupplement:
|
||||
|
||||
# Workflows are now in individual tool descriptions (not separate sections)
|
||||
# Check that key workflow concepts appear in tool descriptions
|
||||
assert "suggested_goal" in docs or "clarifying_questions" in docs
|
||||
assert "agent_json" in docs or "find_block" in docs
|
||||
assert "run_mcp_tool" in docs
|
||||
|
||||
def test_baseline_supplement_completeness(self):
|
||||
@@ -231,6 +232,48 @@ class TestPromptSupplement:
|
||||
f"`{tool_name}`" in docs
|
||||
), f"Tool '{tool_name}' missing from baseline supplement"
|
||||
|
||||
def test_pause_task_scheduled_before_transcript_upload(self):
|
||||
"""Pause is scheduled as a background task before transcript upload begins.
|
||||
|
||||
The finally block in stream_response_sdk does:
|
||||
(1) asyncio.create_task(pause_sandbox_direct(...)) — fire-and-forget
|
||||
(2) await asyncio.shield(upload_transcript(...)) — awaited
|
||||
|
||||
Scheduling pause via create_task before awaiting upload ensures:
|
||||
- Pause never blocks transcript upload (billing stops concurrently)
|
||||
- On E2B timeout, pause silently fails; upload proceeds unaffected
|
||||
"""
|
||||
call_order: list[str] = []
|
||||
|
||||
async def _mock_pause(sandbox, session_id):
|
||||
call_order.append("pause")
|
||||
|
||||
async def _mock_upload(**kwargs):
|
||||
call_order.append("upload")
|
||||
|
||||
async def _simulate_teardown():
|
||||
"""Mirror the service.py finally block teardown sequence."""
|
||||
sandbox = MagicMock()
|
||||
|
||||
# (1) Schedule pause — mirrors lines ~1427-1429 in service.py
|
||||
task = asyncio.create_task(_mock_pause(sandbox, "test-sess"))
|
||||
|
||||
# (2) Await transcript upload — mirrors lines ~1460-1468 in service.py
|
||||
# Yielding to the event loop here lets the pause task start concurrently.
|
||||
await _mock_upload(
|
||||
user_id="u", session_id="test-sess", content="x", message_count=1
|
||||
)
|
||||
await task
|
||||
|
||||
asyncio.run(_simulate_teardown())
|
||||
|
||||
# Both must run; pause is scheduled before upload starts
|
||||
assert "pause" in call_order
|
||||
assert "upload" in call_order
|
||||
# create_task schedules pause, then upload is awaited — pause runs
|
||||
# concurrently during upload's first yield. The ordering guarantee is
|
||||
# that create_task is CALLED before upload is AWAITED (see source order).
|
||||
|
||||
def test_baseline_supplement_no_duplicate_tools(self):
|
||||
"""No tool should appear multiple times in baseline supplement."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
|
||||
@@ -9,14 +9,29 @@ import itertools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
from backend.copilot.context import (
|
||||
_current_project_dir,
|
||||
_current_sandbox,
|
||||
_current_sdk_cwd,
|
||||
_current_session,
|
||||
_current_user_id,
|
||||
_encode_cwd_for_cli,
|
||||
get_execution_context,
|
||||
get_sdk_cwd,
|
||||
is_allowed_local_path,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.sdk.file_ref import (
|
||||
FileRefExpansionError,
|
||||
expand_file_refs_in_args,
|
||||
read_file_bytes,
|
||||
)
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
from backend.copilot.tools.base import BaseTool
|
||||
from backend.util.truncate import truncate
|
||||
@@ -28,84 +43,13 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
|
||||
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
|
||||
# in the path — prevents reading settings, credentials, or other sensitive files.
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
# Max MCP response size in chars — keeps tool output under the SDK's 10 MB JSON buffer.
|
||||
_MCP_MAX_CHARS = 500_000
|
||||
|
||||
# Context variable holding the encoded project directory name for the current
|
||||
# session (e.g. "-private-tmp-copilot-<uuid>"). Set by set_execution_context()
|
||||
# so that path validation can scope tool-results reads to the current session.
|
||||
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
|
||||
|
||||
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does.
|
||||
|
||||
The CLI replaces all non-alphanumeric characters with ``-``.
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
"""Check whether *path* is an allowed host-filesystem path.
|
||||
|
||||
Allowed:
|
||||
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/`` — the SDK's
|
||||
project directory for this session (tool-results, transcripts, etc.)
|
||||
|
||||
Both checks are scoped to the **current session** so sessions cannot
|
||||
read each other's data.
|
||||
"""
|
||||
if not path:
|
||||
return False
|
||||
|
||||
if path.startswith("~"):
|
||||
resolved = os.path.realpath(os.path.expanduser(path))
|
||||
elif not os.path.isabs(path) and sdk_cwd:
|
||||
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
|
||||
else:
|
||||
resolved = os.path.realpath(path)
|
||||
|
||||
# Allow access within the SDK working directory
|
||||
if sdk_cwd:
|
||||
norm_cwd = os.path.realpath(sdk_cwd)
|
||||
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
|
||||
return True
|
||||
|
||||
# Allow access within the current session's CLI project directory
|
||||
# (~/.claude/projects/<encoded-cwd>/).
|
||||
encoded = _current_project_dir.get("")
|
||||
if encoded:
|
||||
session_project = os.path.join(_SDK_PROJECTS_DIR, encoded)
|
||||
if resolved == session_project or resolved.startswith(session_project + os.sep):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
|
||||
MCP_SERVER_NAME = "copilot"
|
||||
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
|
||||
|
||||
# Context variables to pass user/session info to tool execution
|
||||
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
|
||||
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||
"current_session", default=None
|
||||
)
|
||||
# E2B cloud sandbox for the current turn (None when E2B is not configured).
|
||||
# Passed to bash_exec so commands run on E2B instead of the local bwrap sandbox.
|
||||
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
|
||||
"_current_sandbox", default=None
|
||||
)
|
||||
# Raw SDK working directory path (e.g. /tmp/copilot-<session_id>).
|
||||
# Used by workspace tools to save binary files for the CLI's built-in Read.
|
||||
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
|
||||
|
||||
# Stash for MCP tool outputs before the SDK potentially truncates them.
|
||||
# Keyed by tool_name → full output string. Consumed (popped) by the
|
||||
# response adapter when it builds StreamToolOutputAvailable.
|
||||
@@ -149,24 +93,6 @@ def set_execution_context(
|
||||
_stash_event.set(asyncio.Event())
|
||||
|
||||
|
||||
def get_current_sandbox() -> "AsyncSandbox | None":
|
||||
"""Return the E2B sandbox for the current turn, or None."""
|
||||
return _current_sandbox.get()
|
||||
|
||||
|
||||
def get_sdk_cwd() -> str:
|
||||
"""Return the SDK ephemeral working directory for the current turn."""
|
||||
return _current_sdk_cwd.get()
|
||||
|
||||
|
||||
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||
"""Get the current execution context."""
|
||||
return (
|
||||
_current_user_id.get(),
|
||||
_current_session.get(),
|
||||
)
|
||||
|
||||
|
||||
def pop_pending_tool_output(tool_name: str) -> str | None:
|
||||
"""Pop and return the oldest stashed output for *tool_name*.
|
||||
|
||||
@@ -259,7 +185,11 @@ async def _execute_tool_sync(
|
||||
session: ChatSession,
|
||||
args: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Execute a tool synchronously and return MCP-formatted response."""
|
||||
"""Execute a tool synchronously and return MCP-formatted response.
|
||||
|
||||
Note: ``@@agptfile:`` expansion is handled upstream in the ``_truncating`` wrapper
|
||||
so all registered handlers (BaseTool, E2B, Read) expand uniformly.
|
||||
"""
|
||||
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
|
||||
result = await base_tool.execute(
|
||||
user_id=user_id,
|
||||
@@ -320,42 +250,50 @@ def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
|
||||
|
||||
|
||||
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Read a local file with optional offset/limit.
|
||||
"""Read a file with optional offset/limit.
|
||||
|
||||
Only allows paths that pass :func:`is_allowed_local_path` — the current
|
||||
session's tool-results directory and ephemeral working directory.
|
||||
Supports ``workspace://`` URIs (delegated to the workspace manager) and
|
||||
local paths within the session's allowed directories (sdk_cwd + tool-results).
|
||||
"""
|
||||
file_path = args.get("file_path", "")
|
||||
offset = args.get("offset", 0)
|
||||
limit = args.get("limit", 2000)
|
||||
offset = max(0, int(args.get("offset", 0)))
|
||||
limit = max(1, int(args.get("limit", 2000)))
|
||||
|
||||
if not is_allowed_local_path(file_path):
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
|
||||
"isError": True,
|
||||
}
|
||||
def _mcp_err(text: str) -> dict[str, Any]:
|
||||
return {"content": [{"type": "text", "text": text}], "isError": True}
|
||||
|
||||
def _mcp_ok(text: str) -> dict[str, Any]:
|
||||
return {"content": [{"type": "text", "text": text}], "isError": False}
|
||||
|
||||
if file_path.startswith("workspace://"):
|
||||
user_id, session = get_execution_context()
|
||||
if session is None:
|
||||
return _mcp_err("workspace:// file references require an active session")
|
||||
try:
|
||||
raw = await read_file_bytes(file_path, user_id, session)
|
||||
except ValueError as exc:
|
||||
return _mcp_err(str(exc))
|
||||
lines = raw.decode("utf-8", errors="replace").splitlines(keepends=True)
|
||||
selected = list(itertools.islice(lines, offset, offset + limit))
|
||||
numbered = "".join(
|
||||
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
|
||||
)
|
||||
return _mcp_ok(numbered)
|
||||
|
||||
if not is_allowed_local_path(file_path, get_sdk_cwd()):
|
||||
return _mcp_err(f"Path not allowed: {file_path}")
|
||||
|
||||
resolved = os.path.realpath(os.path.expanduser(file_path))
|
||||
try:
|
||||
with open(resolved) as f:
|
||||
selected = list(itertools.islice(f, offset, offset + limit))
|
||||
content = "".join(selected)
|
||||
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
|
||||
# don't delete here — the SDK may read in multiple chunks.
|
||||
return {
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"isError": False,
|
||||
}
|
||||
return _mcp_ok("".join(selected))
|
||||
except FileNotFoundError:
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
|
||||
"isError": True,
|
||||
}
|
||||
return _mcp_err(f"File not found: {file_path}")
|
||||
except Exception as e:
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"Error reading file: {e}"}],
|
||||
"isError": True,
|
||||
}
|
||||
return _mcp_err(f"Error reading file: {e}")
|
||||
|
||||
|
||||
_READ_TOOL_NAME = "Read"
|
||||
@@ -414,9 +352,23 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
|
||||
response adapter before the SDK can apply its own head-truncation.
|
||||
|
||||
Also expands ``@@agptfile:`` references in args so every registered tool
|
||||
(BaseTool, E2B file tools, Read) receives resolved content uniformly.
|
||||
|
||||
Applied once to every registered tool."""
|
||||
|
||||
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
|
||||
user_id, session = get_execution_context()
|
||||
if session is not None:
|
||||
try:
|
||||
args = await expand_file_refs_in_args(args, user_id, session)
|
||||
except FileRefExpansionError as exc:
|
||||
return _mcp_error(
|
||||
f"@@agptfile: reference could not be resolved: {exc}. "
|
||||
"Ensure the file exists before referencing it. "
|
||||
"For sandbox paths use bash_exec to verify the file exists first; "
|
||||
"for workspace files use a workspace:// URI."
|
||||
)
|
||||
result = await fn(args)
|
||||
truncated = truncate(result, _MCP_MAX_CHARS)
|
||||
|
||||
|
||||
@@ -2,12 +2,12 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import get_sdk_cwd
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
from .tool_adapter import (
|
||||
_MCP_MAX_CHARS,
|
||||
_text_from_mcp_result,
|
||||
get_sdk_cwd,
|
||||
pop_pending_tool_output,
|
||||
set_execution_context,
|
||||
stash_pending_tool_output,
|
||||
|
||||
@@ -12,6 +12,7 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
from .continue_run_block import ContinueRunBlockTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
from .edit_agent import EditAgentTool
|
||||
@@ -19,7 +20,10 @@ from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsToo
|
||||
from .find_agent import FindAgentTool
|
||||
from .find_block import FindBlockTool
|
||||
from .find_library_agent import FindLibraryAgentTool
|
||||
from .fix_agent import FixAgentGraphTool
|
||||
from .get_agent_building_guide import GetAgentBuildingGuideTool
|
||||
from .get_doc_page import GetDocPageTool
|
||||
from .get_mcp_guide import GetMCPGuideTool
|
||||
from .manage_folders import (
|
||||
CreateFolderTool,
|
||||
DeleteFolderTool,
|
||||
@@ -32,6 +36,7 @@ from .run_agent import RunAgentTool
|
||||
from .run_block import RunBlockTool
|
||||
from .run_mcp_tool import RunMCPToolTool
|
||||
from .search_docs import SearchDocsTool
|
||||
from .validate_agent import ValidateAgentGraphTool
|
||||
from .web_fetch import WebFetchTool
|
||||
from .workspace_files import (
|
||||
DeleteWorkspaceFileTool,
|
||||
@@ -64,10 +69,13 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"move_agents_to_folder": MoveAgentsToFolderTool(),
|
||||
"run_agent": RunAgentTool(),
|
||||
"run_block": RunBlockTool(),
|
||||
"continue_run_block": ContinueRunBlockTool(),
|
||||
"run_mcp_tool": RunMCPToolTool(),
|
||||
"get_mcp_guide": GetMCPGuideTool(),
|
||||
"view_agent_output": AgentOutputTool(),
|
||||
"search_docs": SearchDocsTool(),
|
||||
"get_doc_page": GetDocPageTool(),
|
||||
"get_agent_building_guide": GetAgentBuildingGuideTool(),
|
||||
# Web fetch for safe URL retrieval
|
||||
"web_fetch": WebFetchTool(),
|
||||
# Agent-browser multi-step automation (navigate, act, screenshot)
|
||||
@@ -80,6 +88,9 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
# Feature request tools
|
||||
"search_feature_requests": SearchFeatureRequestsTool(),
|
||||
"create_feature_request": CreateFeatureRequestTool(),
|
||||
# Agent generation tools (local validation/fixing)
|
||||
"validate_agent_graph": ValidateAgentGraphTool(),
|
||||
"fix_agent_graph": FixAgentGraphTool(),
|
||||
# Workspace tools for CoPilot file operations
|
||||
"list_workspace_files": ListWorkspaceFilesTool(),
|
||||
"read_workspace_file": ReadWorkspaceFileTool(),
|
||||
|
||||
@@ -33,7 +33,7 @@ import tempfile
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.util.request import validate_url
|
||||
from backend.util.request import validate_url_host
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
@@ -235,7 +235,7 @@ async def _restore_browser_state(
|
||||
if url:
|
||||
# Validate the saved URL to prevent SSRF via stored redirect targets.
|
||||
try:
|
||||
await validate_url(url, trusted_origins=[])
|
||||
await validate_url_host(url)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"[browser] State restore: blocked SSRF URL %s", url[:200]
|
||||
@@ -473,7 +473,7 @@ class BrowserNavigateTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
await validate_url(url, trusted_origins=[])
|
||||
await validate_url_host(url)
|
||||
except ValueError as e:
|
||||
return ErrorResponse(
|
||||
message=str(e),
|
||||
|
||||
@@ -68,17 +68,18 @@ def _run_result(rc: int = 0, stdout: str = "", stderr: str = "") -> tuple:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSRF protection via shared validate_url (backend.util.request)
|
||||
# SSRF protection via shared validate_url_host (backend.util.request)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Patch target: validate_url is imported directly into agent_browser's module scope.
|
||||
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url"
|
||||
# Patch target: validate_url_host is imported directly into agent_browser's
|
||||
# module scope.
|
||||
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url_host"
|
||||
|
||||
|
||||
class TestSsrfViaValidateUrl:
|
||||
"""Verify that browser_navigate uses validate_url for SSRF protection.
|
||||
"""Verify that browser_navigate uses validate_url_host for SSRF protection.
|
||||
|
||||
We mock validate_url itself (not the low-level socket) so these tests
|
||||
We mock validate_url_host itself (not the low-level socket) so these tests
|
||||
exercise the integration point, not the internals of request.py
|
||||
(which has its own thorough test suite in request_test.py).
|
||||
"""
|
||||
@@ -89,7 +90,7 @@ class TestSsrfViaValidateUrl:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocked_ip_returns_blocked_url_error(self):
|
||||
"""validate_url raises ValueError → tool returns blocked_url ErrorResponse."""
|
||||
"""validate_url_host raises ValueError → tool returns blocked_url ErrorResponse."""
|
||||
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
|
||||
mock_validate.side_effect = ValueError(
|
||||
"Access to blocked IP 10.0.0.1 is not allowed."
|
||||
@@ -124,8 +125,8 @@ class TestSsrfViaValidateUrl:
|
||||
assert result.error == "blocked_url"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_url_called_with_empty_trusted_origins(self):
|
||||
"""Confirms no trusted-origins bypass is granted — all URLs are validated."""
|
||||
async def test_validate_url_host_called_without_trusted_hostnames(self):
|
||||
"""Confirms no trusted-hostnames bypass is granted — all URLs are validated."""
|
||||
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
|
||||
mock_validate.return_value = (object(), False, ["1.2.3.4"])
|
||||
with patch(
|
||||
@@ -143,7 +144,7 @@ class TestSsrfViaValidateUrl:
|
||||
session=self.session,
|
||||
url="https://example.com",
|
||||
)
|
||||
mock_validate.assert_called_once_with("https://example.com", trusted_origins=[])
|
||||
mock_validate.assert_called_once_with("https://example.com")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -1,20 +1,15 @@
|
||||
"""Agent generator package - Creates agents from natural language."""
|
||||
|
||||
from .core import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
AgentJsonValidationError,
|
||||
AgentSummary,
|
||||
DecompositionResult,
|
||||
DecompositionStep,
|
||||
LibraryAgentSummary,
|
||||
MarketplaceAgentSummary,
|
||||
customize_template,
|
||||
decompose_goal,
|
||||
enrich_library_agents_from_steps,
|
||||
extract_search_terms_from_steps,
|
||||
extract_uuids_from_text,
|
||||
generate_agent,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
get_all_relevant_agents_for_generation,
|
||||
get_library_agent_by_graph_id,
|
||||
@@ -27,25 +22,20 @@ from .core import (
|
||||
search_marketplace_agents_for_generation,
|
||||
)
|
||||
from .errors import get_user_message_for_error
|
||||
from .service import health_check as check_external_service_health
|
||||
from .service import is_external_service_configured
|
||||
from .validation import AgentFixer, AgentValidator
|
||||
|
||||
__all__ = [
|
||||
"AgentGeneratorNotConfiguredError",
|
||||
"AgentFixer",
|
||||
"AgentValidator",
|
||||
"AgentJsonValidationError",
|
||||
"AgentSummary",
|
||||
"DecompositionResult",
|
||||
"DecompositionStep",
|
||||
"LibraryAgentSummary",
|
||||
"MarketplaceAgentSummary",
|
||||
"check_external_service_health",
|
||||
"customize_template",
|
||||
"decompose_goal",
|
||||
"enrich_library_agents_from_steps",
|
||||
"extract_search_terms_from_steps",
|
||||
"extract_uuids_from_text",
|
||||
"generate_agent",
|
||||
"generate_agent_patch",
|
||||
"get_agent_as_json",
|
||||
"get_all_relevant_agents_for_generation",
|
||||
"get_library_agent_by_graph_id",
|
||||
@@ -54,7 +44,6 @@ __all__ = [
|
||||
"get_library_agents_for_generation",
|
||||
"get_user_message_for_error",
|
||||
"graph_to_json",
|
||||
"is_external_service_configured",
|
||||
"json_to_graph",
|
||||
"save_agent_to_library",
|
||||
"search_marketplace_agents_for_generation",
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
"""Block management for agent generation.
|
||||
|
||||
Provides cached access to block metadata for validation and fixing.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Type
|
||||
|
||||
from backend.blocks import get_blocks as get_block_classes
|
||||
from backend.blocks._base import Block
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ["get_blocks_as_dicts", "reset_block_caches"]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level caches
|
||||
# ---------------------------------------------------------------------------
|
||||
_blocks_cache: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
def reset_block_caches() -> None:
|
||||
"""Reset all module-level caches (useful after updating block descriptions)."""
|
||||
global _blocks_cache
|
||||
_blocks_cache = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. get_blocks_as_dicts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_blocks_as_dicts() -> list[dict[str, Any]]:
|
||||
"""Get all available blocks as dicts (cached after first call).
|
||||
|
||||
Each dict contains the keys returned by ``Block.get_info().model_dump()``:
|
||||
id, name, description, inputSchema, outputSchema, categories,
|
||||
staticOutput, costs, contributors, uiType.
|
||||
|
||||
Returns:
|
||||
List of block info dicts.
|
||||
"""
|
||||
global _blocks_cache
|
||||
if _blocks_cache is not None:
|
||||
return _blocks_cache
|
||||
|
||||
block_classes: dict[str, Type[Block]] = get_block_classes() # type: ignore[assignment]
|
||||
blocks: list[dict[str, Any]] = []
|
||||
for block_cls in block_classes.values():
|
||||
try:
|
||||
instance = block_cls()
|
||||
info = instance.get_info().model_dump()
|
||||
# Use optimized description if available (loaded at startup)
|
||||
if instance.optimized_description:
|
||||
info["description"] = instance.optimized_description
|
||||
blocks.append(info)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to load block info for %s, skipping",
|
||||
getattr(block_cls, "__name__", "unknown"),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
_blocks_cache = blocks
|
||||
logger.info("Cached %d block dicts", len(blocks))
|
||||
return _blocks_cache
|
||||
@@ -10,13 +10,7 @@ from backend.data.db_accessors import graph_db, library_db, store_db
|
||||
from backend.data.graph import Graph, Link, Node
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .service import (
|
||||
customize_template_external,
|
||||
decompose_goal_external,
|
||||
generate_agent_external,
|
||||
generate_agent_patch_external,
|
||||
is_external_service_configured,
|
||||
)
|
||||
from .helpers import UUID_RE_STR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -78,38 +72,7 @@ class DecompositionResult(TypedDict, total=False):
|
||||
AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
|
||||
|
||||
|
||||
def _to_dict_list(
|
||||
agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None,
|
||||
) -> list[dict[str, Any]] | None:
|
||||
"""Convert typed agent summaries to plain dicts for external service calls."""
|
||||
if agents is None:
|
||||
return None
|
||||
return [dict(a) for a in agents]
|
||||
|
||||
|
||||
class AgentGeneratorNotConfiguredError(Exception):
|
||||
"""Raised when the external Agent Generator service is not configured."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def _check_service_configured() -> None:
|
||||
"""Check if the external Agent Generator service is configured.
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the service is not configured.
|
||||
"""
|
||||
if not is_external_service_configured():
|
||||
raise AgentGeneratorNotConfiguredError(
|
||||
"Agent Generator service is not configured. "
|
||||
"Set AGENTGENERATOR_HOST environment variable to enable agent generation."
|
||||
)
|
||||
|
||||
|
||||
_UUID_PATTERN = re.compile(
|
||||
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_UUID_PATTERN = re.compile(UUID_RE_STR, re.IGNORECASE)
|
||||
|
||||
|
||||
def extract_uuids_from_text(text: str) -> list[str]:
|
||||
@@ -553,69 +516,6 @@ async def enrich_library_agents_from_steps(
|
||||
return all_agents
|
||||
|
||||
|
||||
async def decompose_goal(
|
||||
description: str,
|
||||
context: str = "",
|
||||
library_agents: Sequence[AgentSummary] | None = None,
|
||||
) -> DecompositionResult | None:
|
||||
"""Break down a goal into steps or return clarifying questions.
|
||||
|
||||
Args:
|
||||
description: Natural language goal description
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
DecompositionResult with either:
|
||||
- {"type": "clarifying_questions", "questions": [...]}
|
||||
- {"type": "instructions", "steps": [...]}
|
||||
Or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for decompose_goal")
|
||||
result = await decompose_goal_external(
|
||||
description, context, _to_dict_list(library_agents)
|
||||
)
|
||||
return result # type: ignore[return-value]
|
||||
|
||||
|
||||
async def generate_agent(
|
||||
instructions: DecompositionResult | dict[str, Any],
|
||||
library_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Generate agent JSON from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
Agent JSON dict, error dict {"type": "error", ...}, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent")
|
||||
result = await generate_agent_external(
|
||||
dict(instructions), _to_dict_list(library_agents)
|
||||
)
|
||||
|
||||
if result:
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
return result
|
||||
if "id" not in result:
|
||||
result["id"] = str(uuid.uuid4())
|
||||
if "version" not in result:
|
||||
result["version"] = 1
|
||||
if "is_active" not in result:
|
||||
result["is_active"] = True
|
||||
return result
|
||||
|
||||
|
||||
class AgentJsonValidationError(Exception):
|
||||
"""Raised when agent JSON is invalid or missing required fields."""
|
||||
|
||||
@@ -792,70 +692,3 @@ async def get_agent_as_json(
|
||||
return None
|
||||
|
||||
return graph_to_json(graph)
|
||||
|
||||
|
||||
async def generate_agent_patch(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: Sequence[AgentSummary] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Update an existing agent using natural language.
|
||||
|
||||
The external Agent Generator service handles:
|
||||
- Generating the patch
|
||||
- Applying the patch
|
||||
- Fixing and validating the result
|
||||
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||
error dict {"type": "error", ...}, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
||||
return await generate_agent_patch_external(
|
||||
update_request,
|
||||
current_agent,
|
||||
_to_dict_list(library_agents),
|
||||
)
|
||||
|
||||
|
||||
async def customize_template(
|
||||
template_agent: dict[str, Any],
|
||||
modification_request: str,
|
||||
context: str = "",
|
||||
) -> dict[str, Any] | None:
|
||||
"""Customize a template/marketplace agent using natural language.
|
||||
|
||||
This is used when users want to modify a template or marketplace agent
|
||||
to fit their specific needs before adding it to their library.
|
||||
|
||||
The external Agent Generator service handles:
|
||||
- Understanding the modification request
|
||||
- Applying changes to the template
|
||||
- Fixing and validating the result
|
||||
|
||||
Args:
|
||||
template_agent: The template agent JSON to customize
|
||||
modification_request: Natural language description of customizations
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
|
||||
Returns:
|
||||
Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||
error dict {"type": "error", ...}, or None on unexpected error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for customize_template")
|
||||
return await customize_template_external(
|
||||
template_agent, modification_request, context
|
||||
)
|
||||
|
||||
@@ -1,165 +0,0 @@
|
||||
"""Dummy Agent Generator for testing.
|
||||
|
||||
Returns mock responses matching the format expected from the external service.
|
||||
Enable via AGENTGENERATOR_USE_DUMMY=true in settings.
|
||||
|
||||
WARNING: This is for testing only. Do not use in production.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Dummy decomposition result (instructions type)
|
||||
DUMMY_DECOMPOSITION_RESULT: dict[str, Any] = {
|
||||
"type": "instructions",
|
||||
"steps": [
|
||||
{
|
||||
"description": "Get input from user",
|
||||
"action": "input",
|
||||
"block_name": "AgentInputBlock",
|
||||
},
|
||||
{
|
||||
"description": "Process the input",
|
||||
"action": "process",
|
||||
"block_name": "TextFormatterBlock",
|
||||
},
|
||||
{
|
||||
"description": "Return output to user",
|
||||
"action": "output",
|
||||
"block_name": "AgentOutputBlock",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# Block IDs from backend/blocks/io.py
|
||||
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
|
||||
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
|
||||
|
||||
|
||||
def _generate_dummy_agent_json() -> dict[str, Any]:
|
||||
"""Generate a minimal valid agent JSON for testing."""
|
||||
input_node_id = str(uuid.uuid4())
|
||||
output_node_id = str(uuid.uuid4())
|
||||
|
||||
return {
|
||||
"id": str(uuid.uuid4()),
|
||||
"version": 1,
|
||||
"is_active": True,
|
||||
"name": "Dummy Test Agent",
|
||||
"description": "A dummy agent generated for testing purposes",
|
||||
"nodes": [
|
||||
{
|
||||
"id": input_node_id,
|
||||
"block_id": AGENT_INPUT_BLOCK_ID,
|
||||
"input_default": {
|
||||
"name": "input",
|
||||
"title": "Input",
|
||||
"description": "Enter your input",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
"metadata": {"position": {"x": 0, "y": 0}},
|
||||
},
|
||||
{
|
||||
"id": output_node_id,
|
||||
"block_id": AGENT_OUTPUT_BLOCK_ID,
|
||||
"input_default": {
|
||||
"name": "output",
|
||||
"title": "Output",
|
||||
"description": "Agent output",
|
||||
"format": "{output}",
|
||||
},
|
||||
"metadata": {"position": {"x": 400, "y": 0}},
|
||||
},
|
||||
],
|
||||
"links": [
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": input_node_id,
|
||||
"sink_id": output_node_id,
|
||||
"source_name": "result",
|
||||
"sink_name": "value",
|
||||
"is_static": False,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
async def decompose_goal_dummy(
|
||||
description: str,
|
||||
context: str = "",
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Return dummy decomposition result."""
|
||||
logger.info("Using dummy agent generator for decompose_goal")
|
||||
return DUMMY_DECOMPOSITION_RESULT.copy()
|
||||
|
||||
|
||||
async def generate_agent_dummy(
|
||||
instructions: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Return dummy agent synchronously (blocks for 30s, returns agent JSON).
|
||||
|
||||
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
|
||||
"""
|
||||
logger.info(
|
||||
"Using dummy agent generator (sync mode): returning agent JSON after 30s"
|
||||
)
|
||||
await asyncio.sleep(30)
|
||||
return _generate_dummy_agent_json()
|
||||
|
||||
|
||||
async def generate_agent_patch_dummy(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Return dummy patched agent synchronously (blocks for 30s, returns patched agent JSON).
|
||||
|
||||
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
|
||||
"""
|
||||
logger.info(
|
||||
"Using dummy agent generator patch (sync mode): returning patched agent after 30s"
|
||||
)
|
||||
await asyncio.sleep(30)
|
||||
patched = current_agent.copy()
|
||||
patched["description"] = (
|
||||
f"{current_agent.get('description', '')} (updated: {update_request})"
|
||||
)
|
||||
return patched
|
||||
|
||||
|
||||
async def customize_template_dummy(
|
||||
template_agent: dict[str, Any],
|
||||
modification_request: str,
|
||||
context: str = "",
|
||||
) -> dict[str, Any]:
|
||||
"""Return dummy customized template (returns template with updated description)."""
|
||||
logger.info("Using dummy agent generator for customize_template")
|
||||
customized = template_agent.copy()
|
||||
customized["description"] = (
|
||||
f"{template_agent.get('description', '')} (customized: {modification_request})"
|
||||
)
|
||||
return customized
|
||||
|
||||
|
||||
async def get_blocks_dummy() -> list[dict[str, Any]]:
|
||||
"""Return dummy blocks list."""
|
||||
logger.info("Using dummy agent generator for get_blocks")
|
||||
return [
|
||||
{"id": AGENT_INPUT_BLOCK_ID, "name": "AgentInputBlock"},
|
||||
{"id": AGENT_OUTPUT_BLOCK_ID, "name": "AgentOutputBlock"},
|
||||
]
|
||||
|
||||
|
||||
async def health_check_dummy() -> bool:
|
||||
"""Always returns healthy for dummy service."""
|
||||
return True
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,913 @@
|
||||
"""Unit tests for AgentFixer."""
|
||||
|
||||
from .fixer import (
|
||||
_ADDTODICTIONARY_BLOCK_ID,
|
||||
_ADDTOLIST_BLOCK_ID,
|
||||
_CODE_EXECUTION_BLOCK_ID,
|
||||
_DATA_SAMPLING_BLOCK_ID,
|
||||
_GET_CURRENT_DATE_BLOCK_ID,
|
||||
_STORE_VALUE_BLOCK_ID,
|
||||
_TEXT_REPLACE_BLOCK_ID,
|
||||
_UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
|
||||
AGENT_EXECUTOR_BLOCK_ID,
|
||||
MCP_TOOL_BLOCK_ID,
|
||||
AgentFixer,
|
||||
)
|
||||
from .helpers import generate_uuid
|
||||
|
||||
|
||||
def _make_agent(
|
||||
nodes: list | None = None,
|
||||
links: list | None = None,
|
||||
agent_id: str | None = None,
|
||||
) -> dict:
|
||||
"""Build a minimal agent dict for testing."""
|
||||
return {
|
||||
"id": agent_id or generate_uuid(),
|
||||
"name": "Test Agent",
|
||||
"nodes": nodes or [],
|
||||
"links": links or [],
|
||||
}
|
||||
|
||||
|
||||
def _make_node(
|
||||
node_id: str | None = None,
|
||||
block_id: str = "block-1",
|
||||
input_default: dict | None = None,
|
||||
position: tuple[int, int] = (0, 0),
|
||||
) -> dict:
|
||||
"""Build a minimal node dict for testing."""
|
||||
return {
|
||||
"id": node_id or generate_uuid(),
|
||||
"block_id": block_id,
|
||||
"input_default": input_default or {},
|
||||
"metadata": {"position": {"x": position[0], "y": position[1]}},
|
||||
}
|
||||
|
||||
|
||||
def _make_link(
|
||||
link_id: str | None = None,
|
||||
source_id: str = "",
|
||||
source_name: str = "output",
|
||||
sink_id: str = "",
|
||||
sink_name: str = "input",
|
||||
is_static: bool = False,
|
||||
) -> dict:
|
||||
"""Build a minimal link dict for testing."""
|
||||
return {
|
||||
"id": link_id or generate_uuid(),
|
||||
"source_id": source_id,
|
||||
"source_name": source_name,
|
||||
"sink_id": sink_id,
|
||||
"sink_name": sink_name,
|
||||
"is_static": is_static,
|
||||
}
|
||||
|
||||
|
||||
class TestFixAgentIds:
|
||||
"""Tests for fix_agent_ids."""
|
||||
|
||||
def test_valid_uuids_unchanged(self):
|
||||
fixer = AgentFixer()
|
||||
agent_id = generate_uuid()
|
||||
link_id = generate_uuid()
|
||||
agent = _make_agent(agent_id=agent_id, links=[{"id": link_id}])
|
||||
|
||||
result = fixer.fix_agent_ids(agent)
|
||||
|
||||
assert result["id"] == agent_id
|
||||
assert result["links"][0]["id"] == link_id
|
||||
assert fixer.fixes_applied == []
|
||||
|
||||
def test_invalid_agent_id_replaced(self):
|
||||
fixer = AgentFixer()
|
||||
agent = _make_agent(agent_id="bad-id")
|
||||
|
||||
result = fixer.fix_agent_ids(agent)
|
||||
|
||||
assert result["id"] != "bad-id"
|
||||
assert len(fixer.fixes_applied) == 1
|
||||
assert "agent ID" in fixer.fixes_applied[0]
|
||||
|
||||
def test_invalid_link_id_replaced(self):
|
||||
fixer = AgentFixer()
|
||||
agent = _make_agent(links=[{"id": "not-a-uuid"}])
|
||||
|
||||
result = fixer.fix_agent_ids(agent)
|
||||
|
||||
assert result["links"][0]["id"] != "not-a-uuid"
|
||||
assert len(fixer.fixes_applied) == 1
|
||||
|
||||
|
||||
class TestFixDoubleCurlyBraces:
|
||||
"""Tests for fix_double_curly_braces."""
|
||||
|
||||
def test_single_braces_converted_to_double(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(input_default={"prompt": "Hello {name}!"})
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
result = fixer.fix_double_curly_braces(agent)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["prompt"] == "Hello {{name}}!"
|
||||
|
||||
def test_double_braces_unchanged(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(input_default={"prompt": "Hello {{name}}!"})
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
result = fixer.fix_double_curly_braces(agent)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["prompt"] == "Hello {{name}}!"
|
||||
assert fixer.fixes_applied == []
|
||||
|
||||
def test_non_string_prompt_skipped(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(input_default={"prompt": 42})
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
result = fixer.fix_double_curly_braces(agent)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["prompt"] == 42
|
||||
|
||||
def test_non_string_prompt_with_prompt_values_skipped(self):
|
||||
"""Ensure non-string prompt fields don't crash re.search in the
|
||||
prompt_values path."""
|
||||
fixer = AgentFixer()
|
||||
node_id = generate_uuid()
|
||||
source_id = generate_uuid()
|
||||
node = _make_node(
|
||||
node_id=node_id, input_default={"prompt": None, "prompt_values": {}}
|
||||
)
|
||||
source_node = _make_node(node_id=source_id)
|
||||
link = _make_link(
|
||||
source_id=source_id,
|
||||
source_name="output",
|
||||
sink_id=node_id,
|
||||
sink_name="prompt_values_$_name",
|
||||
)
|
||||
agent = _make_agent(nodes=[node, source_node], links=[link])
|
||||
|
||||
result = fixer.fix_double_curly_braces(agent)
|
||||
|
||||
# Should not crash and prompt stays None
|
||||
assert result["nodes"][0]["input_default"]["prompt"] is None
|
||||
|
||||
|
||||
class TestFixCredentials:
|
||||
"""Tests for fix_credentials."""
|
||||
|
||||
def test_credentials_removed(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(
|
||||
input_default={
|
||||
"credentials": {"key": "secret"},
|
||||
"url": "http://example.com",
|
||||
}
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
result = fixer.fix_credentials(agent)
|
||||
|
||||
assert "credentials" not in result["nodes"][0]["input_default"]
|
||||
assert result["nodes"][0]["input_default"]["url"] == "http://example.com"
|
||||
assert len(fixer.fixes_applied) == 1
|
||||
|
||||
def test_no_credentials_unchanged(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(input_default={"url": "http://example.com"})
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
result = fixer.fix_credentials(agent)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["url"] == "http://example.com"
|
||||
assert fixer.fixes_applied == []
|
||||
|
||||
|
||||
class TestFixCodeExecutionOutput:
|
||||
"""Tests for fix_code_execution_output."""
|
||||
|
||||
def test_response_renamed_to_stdout_logs(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(node_id="n1", block_id=_CODE_EXECUTION_BLOCK_ID)
|
||||
link = _make_link(source_id="n1", source_name="response", sink_id="n2")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
result = fixer.fix_code_execution_output(agent)
|
||||
|
||||
assert result["links"][0]["source_name"] == "stdout_logs"
|
||||
assert len(fixer.fixes_applied) == 1
|
||||
|
||||
def test_non_response_source_unchanged(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(node_id="n1", block_id=_CODE_EXECUTION_BLOCK_ID)
|
||||
link = _make_link(source_id="n1", source_name="stdout_logs", sink_id="n2")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
result = fixer.fix_code_execution_output(agent)
|
||||
|
||||
assert result["links"][0]["source_name"] == "stdout_logs"
|
||||
assert fixer.fixes_applied == []
|
||||
|
||||
|
||||
class TestFixDataSamplingSampleSize:
|
||||
"""Tests for fix_data_sampling_sample_size."""
|
||||
|
||||
def test_sample_size_set_to_1(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=_DATA_SAMPLING_BLOCK_ID,
|
||||
input_default={"sample_size": 10},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
result = fixer.fix_data_sampling_sample_size(agent)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["sample_size"] == 1
|
||||
|
||||
def test_removes_links_to_sample_size(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(node_id="n1", block_id=_DATA_SAMPLING_BLOCK_ID)
|
||||
link = _make_link(sink_id="n1", sink_name="sample_size", source_id="n2")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
result = fixer.fix_data_sampling_sample_size(agent)
|
||||
|
||||
assert len(result["links"]) == 0
|
||||
assert result["nodes"][0]["input_default"]["sample_size"] == 1
|
||||
|
||||
|
||||
class TestFixTextReplaceNewParameter:
|
||||
"""Tests for fix_text_replace_new_parameter."""
|
||||
|
||||
def test_empty_new_changed_to_space(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(
|
||||
block_id=_TEXT_REPLACE_BLOCK_ID,
|
||||
input_default={"new": ""},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
result = fixer.fix_text_replace_new_parameter(agent)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["new"] == " "
|
||||
|
||||
def test_nonempty_new_unchanged(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(
|
||||
block_id=_TEXT_REPLACE_BLOCK_ID,
|
||||
input_default={"new": "replacement"},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
result = fixer.fix_text_replace_new_parameter(agent)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["new"] == "replacement"
|
||||
assert fixer.fixes_applied == []
|
||||
|
||||
|
||||
class TestFixGetCurrentDateOffset:
|
||||
"""Tests for fix_getcurrentdate_offset."""
|
||||
|
||||
def test_negative_offset_made_positive(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(
|
||||
block_id=_GET_CURRENT_DATE_BLOCK_ID,
|
||||
input_default={"offset": -5},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
result = fixer.fix_getcurrentdate_offset(agent)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["offset"] == 5
|
||||
|
||||
def test_positive_offset_unchanged(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(
|
||||
block_id=_GET_CURRENT_DATE_BLOCK_ID,
|
||||
input_default={"offset": 3},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
result = fixer.fix_getcurrentdate_offset(agent)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["offset"] == 3
|
||||
assert fixer.fixes_applied == []
|
||||
|
||||
|
||||
class TestFixNodeXCoordinates:
|
||||
"""Tests for fix_node_x_coordinates."""
|
||||
|
||||
def test_close_nodes_spread_apart(self):
|
||||
fixer = AgentFixer()
|
||||
src_node = _make_node(node_id="src", position=(0, 0))
|
||||
sink_node = _make_node(node_id="sink", position=(100, 0))
|
||||
link = _make_link(source_id="src", sink_id="sink")
|
||||
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
|
||||
|
||||
result = fixer.fix_node_x_coordinates(agent)
|
||||
|
||||
sink = next(n for n in result["nodes"] if n["id"] == "sink")
|
||||
assert sink["metadata"]["position"]["x"] >= 800
|
||||
|
||||
def test_far_apart_nodes_unchanged(self):
|
||||
fixer = AgentFixer()
|
||||
src_node = _make_node(node_id="src", position=(0, 0))
|
||||
sink_node = _make_node(node_id="sink", position=(1000, 0))
|
||||
link = _make_link(source_id="src", sink_id="sink")
|
||||
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
|
||||
|
||||
result = fixer.fix_node_x_coordinates(agent)
|
||||
|
||||
sink = next(n for n in result["nodes"] if n["id"] == "sink")
|
||||
assert sink["metadata"]["position"]["x"] == 1000
|
||||
assert fixer.fixes_applied == []
|
||||
|
||||
|
||||
class TestFixAddToDictionaryBlocks:
|
||||
"""Tests for fix_addtodictionary_blocks."""
|
||||
|
||||
def test_removes_create_dictionary_nodes(self):
|
||||
fixer = AgentFixer()
|
||||
create_dict_id = "b924ddf4-de4f-4b56-9a85-358930dcbc91"
|
||||
dict_node = _make_node(node_id="dict-1", block_id=create_dict_id)
|
||||
add_to_dict_node = _make_node(
|
||||
node_id="add-1", block_id=_ADDTODICTIONARY_BLOCK_ID
|
||||
)
|
||||
link = _make_link(source_id="dict-1", sink_id="add-1")
|
||||
agent = _make_agent(nodes=[dict_node, add_to_dict_node], links=[link])
|
||||
|
||||
result = fixer.fix_addtodictionary_blocks(agent)
|
||||
|
||||
node_ids = [n["id"] for n in result["nodes"]]
|
||||
assert "dict-1" not in node_ids
|
||||
assert "add-1" in node_ids
|
||||
assert len(result["links"]) == 0
|
||||
|
||||
|
||||
class TestFixStoreValueBeforeCondition:
|
||||
"""Tests for fix_storevalue_before_condition."""
|
||||
|
||||
def test_inserts_storevalue_block(self):
|
||||
fixer = AgentFixer()
|
||||
condition_block_id = "715696a0-e1da-45c8-b209-c2fa9c3b0be6"
|
||||
src_node = _make_node(node_id="src")
|
||||
cond_node = _make_node(node_id="cond", block_id=condition_block_id)
|
||||
link = _make_link(
|
||||
source_id="src", source_name="output", sink_id="cond", sink_name="value2"
|
||||
)
|
||||
agent = _make_agent(nodes=[src_node, cond_node], links=[link])
|
||||
|
||||
result = fixer.fix_storevalue_before_condition(agent)
|
||||
|
||||
# Should have 3 nodes now (original 2 + new StoreValueBlock)
|
||||
assert len(result["nodes"]) == 3
|
||||
store_nodes = [
|
||||
n for n in result["nodes"] if n["block_id"] == _STORE_VALUE_BLOCK_ID
|
||||
]
|
||||
assert len(store_nodes) == 1
|
||||
assert store_nodes[0]["input_default"]["data"] is None
|
||||
|
||||
|
||||
class TestFixAddToListBlocks:
|
||||
"""Tests for fix_addtolist_blocks - self-reference links."""
|
||||
|
||||
def test_addtolist_gets_self_reference_link(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(node_id="atl-1", block_id=_ADDTOLIST_BLOCK_ID)
|
||||
# Source link to AddToList (from some other node)
|
||||
link = _make_link(
|
||||
source_id="other",
|
||||
source_name="output",
|
||||
sink_id="atl-1",
|
||||
sink_name="item",
|
||||
)
|
||||
other_node = _make_node(node_id="other")
|
||||
agent = _make_agent(nodes=[other_node, node], links=[link])
|
||||
|
||||
result = fixer.fix_addtolist_blocks(agent)
|
||||
|
||||
# Should have a self-reference link: atl-1.updated_list -> atl-1.list
|
||||
self_ref_links = [
|
||||
lnk
|
||||
for lnk in result["links"]
|
||||
if lnk["source_id"] == "atl-1"
|
||||
and lnk["sink_id"] == "atl-1"
|
||||
and lnk["source_name"] == "updated_list"
|
||||
and lnk["sink_name"] == "list"
|
||||
]
|
||||
assert len(self_ref_links) == 1
|
||||
|
||||
|
||||
class TestFixLinkStaticProperties:
|
||||
"""Tests for fix_link_static_properties."""
|
||||
|
||||
def test_sets_is_static_from_block_schema(self):
|
||||
fixer = AgentFixer()
|
||||
block_id = generate_uuid()
|
||||
node = _make_node(node_id="n1", block_id=block_id)
|
||||
link = _make_link(source_id="n1", sink_id="n2", is_static=False)
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
blocks = [{"id": block_id, "staticOutput": True}]
|
||||
|
||||
result = fixer.fix_link_static_properties(agent, blocks)
|
||||
|
||||
assert result["links"][0]["is_static"] is True
|
||||
|
||||
def test_unknown_block_leaves_link_unchanged(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(node_id="n1", block_id="unknown-block")
|
||||
link = _make_link(source_id="n1", sink_id="n2", is_static=True)
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
result = fixer.fix_link_static_properties(agent, blocks=[])
|
||||
|
||||
# Unknown block → skipped, link stays as-is
|
||||
assert result["links"][0]["is_static"] is True
|
||||
|
||||
|
||||
class TestFixAiModelParameter:
|
||||
"""Tests for fix_ai_model_parameter."""
|
||||
|
||||
def test_missing_model_gets_default(self):
|
||||
fixer = AgentFixer()
|
||||
block_id = generate_uuid()
|
||||
node = _make_node(node_id="n1", block_id=block_id, input_default={})
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
blocks = [
|
||||
{
|
||||
"id": block_id,
|
||||
"categories": [{"category": "AI"}],
|
||||
"inputSchema": {
|
||||
"properties": {"model": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
result = fixer.fix_ai_model_parameter(agent, blocks)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["model"] == "gpt-4o"
|
||||
|
||||
def test_valid_model_unchanged(self):
|
||||
fixer = AgentFixer()
|
||||
block_id = generate_uuid()
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=block_id,
|
||||
input_default={"model": "claude-opus-4-6"},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
blocks = [
|
||||
{
|
||||
"id": block_id,
|
||||
"categories": [{"category": "AI"}],
|
||||
"inputSchema": {
|
||||
"properties": {"model": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
result = fixer.fix_ai_model_parameter(agent, blocks)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["model"] == "claude-opus-4-6"
|
||||
|
||||
|
||||
class TestFixAgentExecutorBlocks:
|
||||
"""Tests for fix_agent_executor_blocks."""
|
||||
|
||||
def test_fills_schemas_from_library_agent(self):
|
||||
fixer = AgentFixer()
|
||||
lib_agent_id = generate_uuid()
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_default={
|
||||
"graph_id": lib_agent_id,
|
||||
"graph_version": 1,
|
||||
"user_id": "user-1",
|
||||
},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
# Library agents use graph_id as the lookup key
|
||||
library_agents = [
|
||||
{
|
||||
"graph_id": lib_agent_id,
|
||||
"graph_version": 2,
|
||||
"input_schema": {"field1": {"type": "string"}},
|
||||
"output_schema": {"result": {"type": "string"}},
|
||||
}
|
||||
]
|
||||
|
||||
result = fixer.fix_agent_executor_blocks(agent, library_agents)
|
||||
|
||||
node_result = result["nodes"][0]["input_default"]
|
||||
assert node_result["graph_version"] == 2
|
||||
assert node_result["input_schema"] == {"field1": {"type": "string"}}
|
||||
assert node_result["output_schema"] == {"result": {"type": "string"}}
|
||||
|
||||
|
||||
class TestFixInvalidNestedSinkLinks:
|
||||
"""Tests for fix_invalid_nested_sink_links."""
|
||||
|
||||
def test_removes_numeric_index_links(self):
|
||||
fixer = AgentFixer()
|
||||
block_id = generate_uuid()
|
||||
node = _make_node(node_id="n1", block_id=block_id)
|
||||
link = _make_link(source_id="n2", sink_id="n1", sink_name="values_#_0")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
blocks = [
|
||||
{
|
||||
"id": block_id,
|
||||
"inputSchema": {"properties": {"values": {"type": "array"}}},
|
||||
}
|
||||
]
|
||||
|
||||
result = fixer.fix_invalid_nested_sink_links(agent, blocks)
|
||||
|
||||
assert len(result["links"]) == 0
|
||||
|
||||
def test_valid_nested_links_kept(self):
|
||||
fixer = AgentFixer()
|
||||
block_id = generate_uuid()
|
||||
node = _make_node(node_id="n1", block_id=block_id)
|
||||
link = _make_link(source_id="n2", sink_id="n1", sink_name="values_#_name")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
blocks = [
|
||||
{
|
||||
"id": block_id,
|
||||
"inputSchema": {
|
||||
"properties": {"values": {"type": "object"}},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
result = fixer.fix_invalid_nested_sink_links(agent, blocks)
|
||||
|
||||
assert len(result["links"]) == 1
|
||||
|
||||
|
||||
class TestApplyAllFixes:
|
||||
"""Tests for apply_all_fixes orchestration."""
|
||||
|
||||
def test_is_sync(self):
|
||||
"""apply_all_fixes should be a sync function."""
|
||||
import inspect
|
||||
|
||||
assert not inspect.iscoroutinefunction(AgentFixer.apply_all_fixes)
|
||||
|
||||
def test_applies_multiple_fixes(self):
|
||||
fixer = AgentFixer()
|
||||
agent = _make_agent(
|
||||
agent_id="bad-id",
|
||||
nodes=[
|
||||
_make_node(
|
||||
block_id=_TEXT_REPLACE_BLOCK_ID,
|
||||
input_default={"new": "", "credentials": {"key": "secret"}},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
result = fixer.apply_all_fixes(agent)
|
||||
|
||||
# Agent ID should be fixed
|
||||
assert result["id"] != "bad-id"
|
||||
# Credentials should be removed
|
||||
assert "credentials" not in result["nodes"][0]["input_default"]
|
||||
# Text replace "new" should be space
|
||||
assert result["nodes"][0]["input_default"]["new"] == " "
|
||||
# Multiple fixes applied
|
||||
assert len(fixer.fixes_applied) >= 3
|
||||
|
||||
def test_empty_agent_no_crash(self):
|
||||
fixer = AgentFixer()
|
||||
agent = _make_agent()
|
||||
|
||||
result = fixer.apply_all_fixes(agent)
|
||||
|
||||
assert "nodes" in result
|
||||
assert "links" in result
|
||||
|
||||
def test_returns_deep_copy_behavior(self):
|
||||
"""Fixer mutates in place — verify the same dict is returned."""
|
||||
fixer = AgentFixer()
|
||||
agent = _make_agent()
|
||||
result = fixer.apply_all_fixes(agent)
|
||||
assert result is agent
|
||||
|
||||
|
||||
class TestFixMCPToolBlocks:
|
||||
"""Tests for fix_mcp_tool_blocks."""
|
||||
|
||||
def test_adds_missing_tool_arguments(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=MCP_TOOL_BLOCK_ID,
|
||||
input_default={
|
||||
"server_url": "https://mcp.example.com/sse",
|
||||
"selected_tool": "search",
|
||||
"tool_input_schema": {},
|
||||
},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
result = fixer.fix_mcp_tool_blocks(agent)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["tool_arguments"] == {}
|
||||
assert any("tool_arguments" in f for f in fixer.fixes_applied)
|
||||
|
||||
def test_adds_missing_tool_input_schema(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=MCP_TOOL_BLOCK_ID,
|
||||
input_default={
|
||||
"server_url": "https://mcp.example.com/sse",
|
||||
"selected_tool": "search",
|
||||
"tool_arguments": {},
|
||||
},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
result = fixer.fix_mcp_tool_blocks(agent)
|
||||
|
||||
assert result["nodes"][0]["input_default"]["tool_input_schema"] == {}
|
||||
assert any("tool_input_schema" in f for f in fixer.fixes_applied)
|
||||
|
||||
def test_populates_tool_arguments_from_schema(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=MCP_TOOL_BLOCK_ID,
|
||||
input_default={
|
||||
"server_url": "https://mcp.example.com/sse",
|
||||
"selected_tool": "search",
|
||||
"tool_input_schema": {
|
||||
"properties": {
|
||||
"query": {"type": "string", "default": "hello"},
|
||||
"limit": {"type": "integer"},
|
||||
}
|
||||
},
|
||||
"tool_arguments": {},
|
||||
},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
result = fixer.fix_mcp_tool_blocks(agent)
|
||||
|
||||
tool_args = result["nodes"][0]["input_default"]["tool_arguments"]
|
||||
assert tool_args["query"] == "hello"
|
||||
assert tool_args["limit"] is None
|
||||
|
||||
def test_no_op_when_already_complete(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=MCP_TOOL_BLOCK_ID,
|
||||
input_default={
|
||||
"server_url": "https://mcp.example.com/sse",
|
||||
"selected_tool": "search",
|
||||
"tool_input_schema": {},
|
||||
"tool_arguments": {},
|
||||
},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
fixer.fix_mcp_tool_blocks(agent)
|
||||
|
||||
assert len(fixer.fixes_applied) == 0
|
||||
|
||||
|
||||
class TestFixDynamicBlockSinkNames:
|
||||
"""Tests for fix_dynamic_block_sink_names."""
|
||||
|
||||
def test_mcp_tool_arguments_prefix_removed(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(node_id="n1", block_id=MCP_TOOL_BLOCK_ID)
|
||||
link = _make_link(
|
||||
source_id="src", sink_id="n1", sink_name="tool_arguments_#_query"
|
||||
)
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
fixer.fix_dynamic_block_sink_names(agent)
|
||||
|
||||
assert agent["links"][0]["sink_name"] == "query"
|
||||
assert len(fixer.fixes_applied) == 1
|
||||
|
||||
def test_agent_executor_inputs_prefix_removed(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(node_id="n1", block_id=AGENT_EXECUTOR_BLOCK_ID)
|
||||
link = _make_link(source_id="src", sink_id="n1", sink_name="inputs_#_url")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
fixer.fix_dynamic_block_sink_names(agent)
|
||||
|
||||
assert agent["links"][0]["sink_name"] == "url"
|
||||
assert len(fixer.fixes_applied) == 1
|
||||
|
||||
def test_bare_sink_name_unchanged(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(node_id="n1", block_id=MCP_TOOL_BLOCK_ID)
|
||||
link = _make_link(source_id="src", sink_id="n1", sink_name="query")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
fixer.fix_dynamic_block_sink_names(agent)
|
||||
|
||||
assert agent["links"][0]["sink_name"] == "query"
|
||||
assert len(fixer.fixes_applied) == 0
|
||||
|
||||
def test_non_dynamic_block_unchanged(self):
|
||||
fixer = AgentFixer()
|
||||
node = _make_node(node_id="n1", block_id="some-other-block-id")
|
||||
link = _make_link(source_id="src", sink_id="n1", sink_name="values_#_key")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
fixer.fix_dynamic_block_sink_names(agent)
|
||||
|
||||
assert agent["links"][0]["sink_name"] == "values_#_key"
|
||||
assert len(fixer.fixes_applied) == 0
|
||||
|
||||
|
||||
class TestFixDataTypeMismatch:
|
||||
"""Tests for fix_data_type_mismatch."""
|
||||
|
||||
@staticmethod
|
||||
def _make_block(
|
||||
block_id: str,
|
||||
name: str = "TestBlock",
|
||||
input_schema: dict | None = None,
|
||||
output_schema: dict | None = None,
|
||||
) -> dict:
|
||||
return {
|
||||
"id": block_id,
|
||||
"name": name,
|
||||
"inputSchema": input_schema or {"properties": {}},
|
||||
"outputSchema": output_schema or {"properties": {}},
|
||||
}
|
||||
|
||||
def test_inserts_converter_for_incompatible_types(self):
|
||||
fixer = AgentFixer()
|
||||
src_block_id = generate_uuid()
|
||||
sink_block_id = generate_uuid()
|
||||
|
||||
src_node = _make_node(node_id="src", block_id=src_block_id)
|
||||
sink_node = _make_node(node_id="sink", block_id=sink_block_id)
|
||||
link = _make_link(
|
||||
source_id="src",
|
||||
source_name="result",
|
||||
sink_id="sink",
|
||||
sink_name="count",
|
||||
)
|
||||
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
|
||||
|
||||
blocks = [
|
||||
self._make_block(
|
||||
src_block_id,
|
||||
name="Source",
|
||||
output_schema={"properties": {"result": {"type": "string"}}},
|
||||
),
|
||||
self._make_block(
|
||||
sink_block_id,
|
||||
name="Sink",
|
||||
input_schema={"properties": {"count": {"type": "integer"}}},
|
||||
),
|
||||
]
|
||||
|
||||
result = fixer.fix_data_type_mismatch(agent, blocks)
|
||||
|
||||
# A converter node should have been inserted
|
||||
converter_nodes = [
|
||||
n
|
||||
for n in result["nodes"]
|
||||
if n["block_id"] == _UNIVERSAL_TYPE_CONVERTER_BLOCK_ID
|
||||
]
|
||||
assert len(converter_nodes) == 1
|
||||
assert converter_nodes[0]["input_default"]["type"] == "number"
|
||||
|
||||
# Original link replaced by two new links through the converter
|
||||
assert len(result["links"]) == 2
|
||||
src_to_converter = result["links"][0]
|
||||
converter_to_sink = result["links"][1]
|
||||
|
||||
assert src_to_converter["source_id"] == "src"
|
||||
assert src_to_converter["sink_id"] == converter_nodes[0]["id"]
|
||||
assert src_to_converter["sink_name"] == "value"
|
||||
|
||||
assert converter_to_sink["source_id"] == converter_nodes[0]["id"]
|
||||
assert converter_to_sink["source_name"] == "value"
|
||||
assert converter_to_sink["sink_id"] == "sink"
|
||||
assert converter_to_sink["sink_name"] == "count"
|
||||
|
||||
assert len(fixer.fixes_applied) == 1
|
||||
|
||||
def test_compatible_types_unchanged(self):
|
||||
fixer = AgentFixer()
|
||||
block_id = generate_uuid()
|
||||
|
||||
src_node = _make_node(node_id="src", block_id=block_id)
|
||||
sink_node = _make_node(node_id="sink", block_id=block_id)
|
||||
link = _make_link(
|
||||
source_id="src",
|
||||
source_name="output",
|
||||
sink_id="sink",
|
||||
sink_name="input",
|
||||
)
|
||||
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
|
||||
|
||||
blocks = [
|
||||
self._make_block(
|
||||
block_id,
|
||||
input_schema={"properties": {"input": {"type": "string"}}},
|
||||
output_schema={"properties": {"output": {"type": "string"}}},
|
||||
),
|
||||
]
|
||||
|
||||
result = fixer.fix_data_type_mismatch(agent, blocks)
|
||||
|
||||
# No converter inserted, original link kept
|
||||
assert len(result["nodes"]) == 2
|
||||
assert len(result["links"]) == 1
|
||||
assert result["links"][0] is link
|
||||
assert fixer.fixes_applied == []
|
||||
|
||||
def test_missing_block_keeps_link(self):
|
||||
"""Links referencing unknown blocks are kept unchanged."""
|
||||
fixer = AgentFixer()
|
||||
src_node = _make_node(node_id="src", block_id="unknown-block")
|
||||
sink_node = _make_node(node_id="sink", block_id="unknown-block")
|
||||
link = _make_link(source_id="src", sink_id="sink")
|
||||
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
|
||||
|
||||
result = fixer.fix_data_type_mismatch(agent, blocks=[])
|
||||
|
||||
assert len(result["links"]) == 1
|
||||
assert result["links"][0] is link
|
||||
|
||||
def test_missing_type_info_keeps_link(self):
|
||||
"""Links where source/sink type is not defined are kept unchanged."""
|
||||
fixer = AgentFixer()
|
||||
block_id = generate_uuid()
|
||||
src_node = _make_node(node_id="src", block_id=block_id)
|
||||
sink_node = _make_node(node_id="sink", block_id=block_id)
|
||||
link = _make_link(
|
||||
source_id="src",
|
||||
source_name="output",
|
||||
sink_id="sink",
|
||||
sink_name="input",
|
||||
)
|
||||
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
|
||||
|
||||
# Block has no properties defined for the linked fields
|
||||
blocks = [self._make_block(block_id)]
|
||||
|
||||
result = fixer.fix_data_type_mismatch(agent, blocks)
|
||||
|
||||
assert len(result["links"]) == 1
|
||||
assert fixer.fixes_applied == []
|
||||
|
||||
def test_multiple_mismatches_insert_multiple_converters(self):
|
||||
"""Each incompatible link gets its own converter node."""
|
||||
fixer = AgentFixer()
|
||||
src_block_id = generate_uuid()
|
||||
sink_block_id = generate_uuid()
|
||||
|
||||
src_node = _make_node(node_id="src", block_id=src_block_id)
|
||||
sink1 = _make_node(node_id="sink1", block_id=sink_block_id)
|
||||
sink2 = _make_node(node_id="sink2", block_id=sink_block_id)
|
||||
link1 = _make_link(
|
||||
source_id="src", source_name="out", sink_id="sink1", sink_name="count"
|
||||
)
|
||||
link2 = _make_link(
|
||||
source_id="src", source_name="out", sink_id="sink2", sink_name="count"
|
||||
)
|
||||
agent = _make_agent(nodes=[src_node, sink1, sink2], links=[link1, link2])
|
||||
|
||||
blocks = [
|
||||
self._make_block(
|
||||
src_block_id,
|
||||
output_schema={"properties": {"out": {"type": "string"}}},
|
||||
),
|
||||
self._make_block(
|
||||
sink_block_id,
|
||||
input_schema={"properties": {"count": {"type": "integer"}}},
|
||||
),
|
||||
]
|
||||
|
||||
result = fixer.fix_data_type_mismatch(agent, blocks)
|
||||
|
||||
converter_nodes = [
|
||||
n
|
||||
for n in result["nodes"]
|
||||
if n["block_id"] == _UNIVERSAL_TYPE_CONVERTER_BLOCK_ID
|
||||
]
|
||||
assert len(converter_nodes) == 2
|
||||
# Each original link becomes two links through its own converter
|
||||
assert len(result["links"]) == 4
|
||||
assert len(fixer.fixes_applied) == 2
|
||||
@@ -0,0 +1,67 @@
|
||||
"""Shared helpers for agent generation."""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from .blocks import get_blocks_as_dicts
|
||||
|
||||
__all__ = [
|
||||
"AGENT_EXECUTOR_BLOCK_ID",
|
||||
"AGENT_INPUT_BLOCK_ID",
|
||||
"AGENT_OUTPUT_BLOCK_ID",
|
||||
"AgentDict",
|
||||
"MCP_TOOL_BLOCK_ID",
|
||||
"UUID_REGEX",
|
||||
"are_types_compatible",
|
||||
"generate_uuid",
|
||||
"get_blocks_as_dicts",
|
||||
"get_defined_property_type",
|
||||
"is_uuid",
|
||||
]
|
||||
|
||||
|
||||
# Type alias for the agent JSON structure passed through
|
||||
# the validation and fixing pipeline.
|
||||
AgentDict = dict[str, Any]
|
||||
|
||||
# Shared base pattern (unanchored, lowercase hex); used for both full-string
|
||||
# validation (UUID_REGEX) and text extraction (core._UUID_PATTERN).
|
||||
UUID_RE_STR = r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[a-f0-9]{4}-[a-f0-9]{12}"
|
||||
|
||||
UUID_REGEX = re.compile(r"^" + UUID_RE_STR + r"$")
|
||||
|
||||
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
|
||||
MCP_TOOL_BLOCK_ID = "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
|
||||
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
|
||||
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
|
||||
|
||||
|
||||
def is_uuid(value: str) -> bool:
|
||||
"""Check if a string is a valid UUID."""
|
||||
return isinstance(value, str) and UUID_REGEX.match(value) is not None
|
||||
|
||||
|
||||
def generate_uuid() -> str:
|
||||
"""Generate a new UUID string."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def get_defined_property_type(schema: dict[str, Any], name: str) -> str | None:
|
||||
"""Get property type from a schema, handling nested `_#_` notation."""
|
||||
if "_#_" in name:
|
||||
parent, child = name.split("_#_", 1)
|
||||
parent_schema = schema.get(parent, {})
|
||||
if "properties" in parent_schema and isinstance(
|
||||
parent_schema["properties"], dict
|
||||
):
|
||||
return parent_schema["properties"].get(child, {}).get("type")
|
||||
return None
|
||||
return schema.get(name, {}).get("type")
|
||||
|
||||
|
||||
def are_types_compatible(src: str, sink: str) -> bool:
|
||||
"""Check if two schema types are compatible."""
|
||||
if {src, sink} <= {"integer", "number"}:
|
||||
return True
|
||||
return src == sink
|
||||
@@ -0,0 +1,196 @@
|
||||
"""Shared fix → validate → preview/save pipeline for agent tools."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
from backend.copilot.tools.models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
from .blocks import get_blocks_as_dicts
|
||||
from .core import get_library_agents_by_ids, save_agent_to_library
|
||||
from .fixer import AgentFixer
|
||||
from .validator import AgentValidator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_AGENT_JSON_SIZE = 1_000_000 # 1 MB
|
||||
|
||||
|
||||
async def fetch_library_agents(
|
||||
user_id: str | None,
|
||||
library_agent_ids: list[str],
|
||||
) -> list[dict[str, Any]] | None:
|
||||
"""Fetch library agents by IDs for AgentExecutorBlock validation.
|
||||
|
||||
Returns None if no IDs provided or user is not authenticated.
|
||||
"""
|
||||
if not user_id or not library_agent_ids:
|
||||
return None
|
||||
try:
|
||||
agents = await get_library_agents_by_ids(
|
||||
user_id=user_id,
|
||||
agent_ids=library_agent_ids,
|
||||
)
|
||||
return cast(list[dict[str, Any]], agents)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agents by IDs: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def fix_validate_and_save(
|
||||
agent_json: dict[str, Any],
|
||||
*,
|
||||
user_id: str | None,
|
||||
session_id: str | None,
|
||||
save: bool = True,
|
||||
is_update: bool = False,
|
||||
default_name: str = "Agent",
|
||||
preview_message: str | None = None,
|
||||
save_message: str | None = None,
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
folder_id: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""Shared pipeline: auto-fix → validate → preview or save.
|
||||
|
||||
Args:
|
||||
agent_json: The agent JSON dict (must already have id/version/is_active set).
|
||||
user_id: The authenticated user's ID.
|
||||
session_id: The chat session ID.
|
||||
save: Whether to save or just preview.
|
||||
is_update: Whether this is an update to an existing agent.
|
||||
default_name: Fallback name if agent_json has none.
|
||||
preview_message: Custom preview message (optional).
|
||||
save_message: Custom save success message (optional).
|
||||
library_agents: Library agents for AgentExecutorBlock validation/fixing.
|
||||
|
||||
Returns:
|
||||
An appropriate ToolResponseBase subclass.
|
||||
"""
|
||||
# Size guard
|
||||
json_size = len(json.dumps(agent_json))
|
||||
if json_size > MAX_AGENT_JSON_SIZE:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Agent JSON is too large ({json_size:,} bytes, "
|
||||
f"max {MAX_AGENT_JSON_SIZE:,}). Reduce the number of nodes."
|
||||
),
|
||||
error="agent_json_too_large",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
blocks = get_blocks_as_dicts()
|
||||
|
||||
# Auto-fix
|
||||
try:
|
||||
fixer = AgentFixer()
|
||||
agent_json = fixer.apply_all_fixes(agent_json, blocks, library_agents)
|
||||
fixes = fixer.get_fixes_applied()
|
||||
if fixes:
|
||||
logger.info(f"Applied {len(fixes)} auto-fixes to agent JSON")
|
||||
except Exception as e:
|
||||
logger.warning(f"Auto-fix failed: {e}")
|
||||
|
||||
# Validate
|
||||
try:
|
||||
validator = AgentValidator()
|
||||
is_valid, _ = validator.validate(agent_json, blocks, library_agents)
|
||||
if not is_valid:
|
||||
errors = validator.errors
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"The agent has {len(errors)} validation error(s):\n"
|
||||
+ "\n".join(f"- {e}" for e in errors[:5])
|
||||
),
|
||||
error="validation_failed",
|
||||
details={"errors": errors},
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Validation failed with exception: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to validate the agent. Please try again.",
|
||||
error="validation_exception",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agent_name = agent_json.get("name", default_name)
|
||||
agent_description = agent_json.get("description", "")
|
||||
node_count = len(agent_json.get("nodes", []))
|
||||
link_count = len(agent_json.get("links", []))
|
||||
|
||||
# Build a warning suffix when name/description is missing or generic
|
||||
_GENERIC_NAMES = {
|
||||
"agent",
|
||||
"generated agent",
|
||||
"customized agent",
|
||||
"updated agent",
|
||||
"new agent",
|
||||
"my agent",
|
||||
}
|
||||
metadata_warnings: list[str] = []
|
||||
if not agent_json.get("name") or agent_name.lower().strip() in _GENERIC_NAMES:
|
||||
metadata_warnings.append("'name'")
|
||||
if not agent_description:
|
||||
metadata_warnings.append("'description'")
|
||||
metadata_hint = ""
|
||||
if metadata_warnings:
|
||||
missing = " and ".join(metadata_warnings)
|
||||
metadata_hint = (
|
||||
f" Note: the agent is missing a meaningful {missing}. "
|
||||
f"Please update the agent_json to include them."
|
||||
)
|
||||
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
(
|
||||
preview_message
|
||||
or f"Agent '{agent_name}' with {node_count} blocks is ready."
|
||||
)
|
||||
+ metadata_hint
|
||||
),
|
||||
agent_json=agent_json,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
agent_json, user_id, is_update=is_update, folder_id=folder_id
|
||||
)
|
||||
return AgentSavedResponse(
|
||||
message=(
|
||||
(save_message or f"Agent '{created_graph.name}' has been saved!")
|
||||
+ metadata_hint
|
||||
),
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save agent: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to save the agent: {str(e)}",
|
||||
error="save_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -1,511 +0,0 @@
|
||||
"""External Agent Generator service client.
|
||||
|
||||
This module provides a client for communicating with the external Agent Generator
|
||||
microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
|
||||
will delegate to the external service instead of using the built-in LLM-based implementation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .dummy import (
|
||||
customize_template_dummy,
|
||||
decompose_goal_dummy,
|
||||
generate_agent_dummy,
|
||||
generate_agent_patch_dummy,
|
||||
get_blocks_dummy,
|
||||
health_check_dummy,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_dummy_mode_warned = False
|
||||
|
||||
|
||||
def _create_error_response(
|
||||
error_message: str,
|
||||
error_type: str = "unknown",
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a standardized error response dict.
|
||||
|
||||
Args:
|
||||
error_message: Human-readable error message
|
||||
error_type: Machine-readable error type
|
||||
details: Optional additional error details
|
||||
|
||||
Returns:
|
||||
Error dict with type="error" and error details
|
||||
"""
|
||||
response: dict[str, Any] = {
|
||||
"type": "error",
|
||||
"error": error_message,
|
||||
"error_type": error_type,
|
||||
}
|
||||
if details:
|
||||
response["details"] = details
|
||||
return response
|
||||
|
||||
|
||||
def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
|
||||
"""Classify an HTTP error into error_type and message.
|
||||
|
||||
Args:
|
||||
e: The HTTP status error
|
||||
|
||||
Returns:
|
||||
Tuple of (error_type, error_message)
|
||||
"""
|
||||
status = e.response.status_code
|
||||
if status == 429:
|
||||
return "rate_limit", f"Agent Generator rate limited: {e}"
|
||||
elif status == 503:
|
||||
return "service_unavailable", f"Agent Generator unavailable: {e}"
|
||||
elif status == 504 or status == 408:
|
||||
return "timeout", f"Agent Generator timed out: {e}"
|
||||
else:
|
||||
return "http_error", f"HTTP error calling Agent Generator: {e}"
|
||||
|
||||
|
||||
def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
|
||||
"""Classify a request error into error_type and message.
|
||||
|
||||
Args:
|
||||
e: The request error
|
||||
|
||||
Returns:
|
||||
Tuple of (error_type, error_message)
|
||||
"""
|
||||
error_str = str(e).lower()
|
||||
if "timeout" in error_str or "timed out" in error_str:
|
||||
return "timeout", f"Agent Generator request timed out: {e}"
|
||||
elif "connect" in error_str:
|
||||
return "connection_error", f"Could not connect to Agent Generator: {e}"
|
||||
else:
|
||||
return "request_error", f"Request error calling Agent Generator: {e}"
|
||||
|
||||
|
||||
_client: httpx.AsyncClient | None = None
|
||||
_settings: Settings | None = None
|
||||
|
||||
|
||||
def _get_settings() -> Settings:
|
||||
"""Get or create settings singleton."""
|
||||
global _settings
|
||||
if _settings is None:
|
||||
_settings = Settings()
|
||||
return _settings
|
||||
|
||||
|
||||
def _is_dummy_mode() -> bool:
|
||||
"""Check if dummy mode is enabled for testing."""
|
||||
global _dummy_mode_warned
|
||||
settings = _get_settings()
|
||||
is_dummy = bool(settings.config.agentgenerator_use_dummy)
|
||||
if is_dummy and not _dummy_mode_warned:
|
||||
logger.warning(
|
||||
"Agent Generator running in DUMMY MODE - returning mock responses. "
|
||||
"Do not use in production!"
|
||||
)
|
||||
_dummy_mode_warned = True
|
||||
return is_dummy
|
||||
|
||||
|
||||
def is_external_service_configured() -> bool:
|
||||
"""Check if external Agent Generator service is configured (or dummy mode)."""
|
||||
settings = _get_settings()
|
||||
return bool(settings.config.agentgenerator_host) or bool(
|
||||
settings.config.agentgenerator_use_dummy
|
||||
)
|
||||
|
||||
|
||||
def _get_base_url() -> str:
|
||||
"""Get the base URL for the external service."""
|
||||
settings = _get_settings()
|
||||
host = settings.config.agentgenerator_host
|
||||
port = settings.config.agentgenerator_port
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def _get_client() -> httpx.AsyncClient:
|
||||
"""Get or create the HTTP client for the external service."""
|
||||
global _client
|
||||
if _client is None:
|
||||
settings = _get_settings()
|
||||
_client = httpx.AsyncClient(
|
||||
base_url=_get_base_url(),
|
||||
timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
|
||||
)
|
||||
return _client
|
||||
|
||||
|
||||
async def decompose_goal_external(
|
||||
description: str,
|
||||
context: str = "",
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to decompose a goal.
|
||||
|
||||
Args:
|
||||
description: Natural language goal description
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
Dict with either:
|
||||
- {"type": "clarifying_questions", "questions": [...]}
|
||||
- {"type": "instructions", "steps": [...]}
|
||||
- {"type": "unachievable_goal", ...}
|
||||
- {"type": "vague_goal", ...}
|
||||
- {"type": "error", "error": "...", "error_type": "..."} on error
|
||||
Or None on unexpected error
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await decompose_goal_dummy(description, context, library_agents)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
if context:
|
||||
description = f"{description}\n\nAdditional context from user:\n{context}"
|
||||
|
||||
payload: dict[str, Any] = {"description": description}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
|
||||
try:
|
||||
response = await client.post("/api/decompose-description", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator decomposition failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
# Map the response to the expected format
|
||||
response_type = data.get("type")
|
||||
if response_type == "instructions":
|
||||
return {"type": "instructions", "steps": data.get("steps", [])}
|
||||
elif response_type == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
elif response_type == "unachievable_goal":
|
||||
return {
|
||||
"type": "unachievable_goal",
|
||||
"reason": data.get("reason"),
|
||||
"suggested_goal": data.get("suggested_goal"),
|
||||
}
|
||||
elif response_type == "vague_goal":
|
||||
return {
|
||||
"type": "vague_goal",
|
||||
"suggested_goal": data.get("suggested_goal"),
|
||||
}
|
||||
elif response_type == "error":
|
||||
# Pass through error from the service
|
||||
return _create_error_response(
|
||||
data.get("error", "Unknown error"),
|
||||
data.get("error_type", "unknown"),
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Unknown response type from external service: {response_type}"
|
||||
)
|
||||
return _create_error_response(
|
||||
f"Unknown response type from Agent Generator: {response_type}",
|
||||
"invalid_response",
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
|
||||
async def generate_agent_external(
|
||||
instructions: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate an agent from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
Agent JSON dict or error dict {"type": "error", ...} on error
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await generate_agent_dummy(instructions, library_agents)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
payload: dict[str, Any] = {"instructions": instructions}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
|
||||
try:
|
||||
response = await client.post("/api/generate-agent", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator generation failed: {error_msg} (type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
|
||||
async def generate_agent_patch_external(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate a patch for an existing agent.
|
||||
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
session_id: Session ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await generate_agent_patch_dummy(
|
||||
update_request, current_agent, library_agents
|
||||
)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
payload: dict[str, Any] = {
|
||||
"update_request": update_request,
|
||||
"current_agent_json": current_agent,
|
||||
}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
|
||||
try:
|
||||
response = await client.post("/api/update-agent", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator patch generation failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
# Check if it's clarifying questions
|
||||
if data.get("type") == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
|
||||
# Check if it's an error passed through
|
||||
if data.get("type") == "error":
|
||||
return _create_error_response(
|
||||
data.get("error", "Unknown error"),
|
||||
data.get("error_type", "unknown"),
|
||||
)
|
||||
|
||||
# Otherwise return the updated agent JSON
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
|
||||
async def customize_template_external(
|
||||
template_agent: dict[str, Any],
|
||||
modification_request: str,
|
||||
context: str = "",
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to customize a template/marketplace agent.
|
||||
|
||||
Args:
|
||||
template_agent: The template agent JSON to customize
|
||||
modification_request: Natural language description of customizations
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
session_id: Session ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Customized agent JSON, clarifying questions dict, or error dict on error
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await customize_template_dummy(
|
||||
template_agent, modification_request, context
|
||||
)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
request = modification_request
|
||||
if context:
|
||||
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"template_agent_json": template_agent,
|
||||
"modification_request": request,
|
||||
}
|
||||
|
||||
try:
|
||||
response = await client.post("/api/template-modification", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator template customization failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
# Check if it's clarifying questions
|
||||
if data.get("type") == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
|
||||
# Check if it's an error passed through
|
||||
if data.get("type") == "error":
|
||||
return _create_error_response(
|
||||
data.get("error", "Unknown error"),
|
||||
data.get("error_type", "unknown"),
|
||||
)
|
||||
|
||||
# Otherwise return the customized agent JSON
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
|
||||
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
||||
"""Get available blocks from the external service.
|
||||
|
||||
Returns:
|
||||
List of block info dicts or None on error
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await get_blocks_dummy()
|
||||
|
||||
client = _get_client()
|
||||
|
||||
try:
|
||||
response = await client.get("/api/blocks")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
logger.error("External service returned error getting blocks")
|
||||
return None
|
||||
|
||||
return data.get("blocks", [])
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error getting blocks from external service: {e}")
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error getting blocks from external service: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error getting blocks from external service: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def health_check() -> bool:
|
||||
"""Check if the external service is healthy.
|
||||
|
||||
Returns:
|
||||
True if healthy, False otherwise
|
||||
"""
|
||||
if not is_external_service_configured():
|
||||
return False
|
||||
|
||||
if _is_dummy_mode():
|
||||
return await health_check_dummy()
|
||||
|
||||
client = _get_client()
|
||||
|
||||
try:
|
||||
response = await client.get("/health")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("status") == "healthy" and data.get("blocks_loaded", False)
|
||||
except Exception as e:
|
||||
logger.warning(f"External agent generator health check failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def close_client() -> None:
|
||||
"""Close the HTTP client."""
|
||||
global _client
|
||||
if _client is not None:
|
||||
await _client.aclose()
|
||||
_client = None
|
||||
@@ -0,0 +1,17 @@
|
||||
"""Agent generation validation — re-exports from split modules.
|
||||
|
||||
This module was split into:
|
||||
- helpers.py: get_blocks_as_dicts, block cache
|
||||
- fixer.py: AgentFixer class
|
||||
- validator.py: AgentValidator class
|
||||
"""
|
||||
|
||||
from .fixer import AgentFixer
|
||||
from .helpers import get_blocks_as_dicts
|
||||
from .validator import AgentValidator
|
||||
|
||||
__all__ = [
|
||||
"AgentFixer",
|
||||
"AgentValidator",
|
||||
"get_blocks_as_dicts",
|
||||
]
|
||||
@@ -0,0 +1,939 @@
|
||||
"""AgentValidator — validates agent JSON graphs for correctness."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from .helpers import (
|
||||
AGENT_EXECUTOR_BLOCK_ID,
|
||||
AGENT_INPUT_BLOCK_ID,
|
||||
AGENT_OUTPUT_BLOCK_ID,
|
||||
MCP_TOOL_BLOCK_ID,
|
||||
AgentDict,
|
||||
are_types_compatible,
|
||||
get_defined_property_type,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentValidator:
|
||||
"""
|
||||
A comprehensive validator for AutoGPT agents that provides detailed error
|
||||
reporting for LLM-based fixes.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.errors: list[str] = []
|
||||
|
||||
def add_error(self, error_message: str) -> None:
|
||||
"""Add an error message to the validation errors list."""
|
||||
self.errors.append(error_message)
|
||||
|
||||
def _values_equal(self, val1: Any, val2: Any) -> bool:
|
||||
"""Compare two values, handling complex types like dicts and lists."""
|
||||
if type(val1) is not type(val2):
|
||||
return False
|
||||
if isinstance(val1, dict):
|
||||
return json.dumps(val1, sort_keys=True) == json.dumps(val2, sort_keys=True)
|
||||
if isinstance(val1, list):
|
||||
return json.dumps(val1, sort_keys=True) == json.dumps(val2, sort_keys=True)
|
||||
return val1 == val2
|
||||
|
||||
def validate_block_existence(
|
||||
self, agent: AgentDict, blocks: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that all block IDs used in the agent actually exist in the
|
||||
blocks list. Returns True if all block IDs exist, False otherwise.
|
||||
"""
|
||||
valid = True
|
||||
|
||||
# Create a set of all valid block IDs for fast lookup
|
||||
valid_block_ids = {block.get("id") for block in blocks if block.get("id")}
|
||||
|
||||
# Check each node's block_id
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
node_id = node.get("id")
|
||||
|
||||
if not block_id:
|
||||
self.add_error(
|
||||
f"Node '{node_id}' is missing a 'block_id' field. "
|
||||
f"Every node must reference a valid block."
|
||||
)
|
||||
valid = False
|
||||
continue
|
||||
|
||||
if block_id not in valid_block_ids:
|
||||
self.add_error(
|
||||
f"Node '{node_id}' references block_id '{block_id}' "
|
||||
f"which does not exist in the available blocks. "
|
||||
f"This block may have been deprecated, removed, or "
|
||||
f"the ID is incorrect. Please use a valid block from "
|
||||
f"the blocks library."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_link_node_references(self, agent: AgentDict) -> bool:
|
||||
"""
|
||||
Validate that all node IDs referenced in links actually exist in the
|
||||
agent's nodes. Returns True if all link references are valid, False
|
||||
otherwise.
|
||||
"""
|
||||
valid = True
|
||||
|
||||
# Create a set of all valid node IDs for fast lookup
|
||||
valid_node_ids = {
|
||||
node.get("id") for node in agent.get("nodes", []) if node.get("id")
|
||||
}
|
||||
|
||||
# Check each link's source_id and sink_id
|
||||
for link in agent.get("links", []):
|
||||
link_id = link.get("id", "Unknown")
|
||||
source_id = link.get("source_id")
|
||||
sink_id = link.get("sink_id")
|
||||
source_name = link.get("source_name", "")
|
||||
sink_name = link.get("sink_name", "")
|
||||
|
||||
# Check source_id
|
||||
if not source_id:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' is missing a 'source_id' field. "
|
||||
f"Every link must reference a valid source node."
|
||||
)
|
||||
valid = False
|
||||
elif source_id not in valid_node_ids:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' references source_id '{source_id}' "
|
||||
f"which does not exist in the agent's nodes. The link "
|
||||
f"from '{source_name}' cannot be established because "
|
||||
f"the source node is missing."
|
||||
)
|
||||
valid = False
|
||||
|
||||
# Check sink_id
|
||||
if not sink_id:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' is missing a 'sink_id' field. "
|
||||
f"Every link must reference a valid sink (destination) "
|
||||
f"node."
|
||||
)
|
||||
valid = False
|
||||
elif sink_id not in valid_node_ids:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' references sink_id '{sink_id}' "
|
||||
f"which does not exist in the agent's nodes. The link "
|
||||
f"to '{sink_name}' cannot be established because the "
|
||||
f"destination node is missing."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_required_inputs(
|
||||
self, agent: AgentDict, blocks: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that all required inputs are provided for each node.
|
||||
Returns True if all required inputs are satisfied, False otherwise.
|
||||
"""
|
||||
valid = True
|
||||
|
||||
block_lookup = {b.get("id", ""): b for b in blocks}
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
block = block_lookup.get(block_id)
|
||||
|
||||
if not block:
|
||||
continue
|
||||
|
||||
required_inputs = block.get("inputSchema", {}).get("required", [])
|
||||
input_defaults = node.get("input_default", {})
|
||||
node_id = node.get("id")
|
||||
|
||||
linked_inputs = set(
|
||||
link.get("sink_name")
|
||||
for link in agent.get("links", [])
|
||||
if link.get("sink_id") == node_id and link.get("sink_name")
|
||||
)
|
||||
|
||||
for req_input in required_inputs:
|
||||
if (
|
||||
req_input not in input_defaults
|
||||
and req_input not in linked_inputs
|
||||
and req_input != "credentials"
|
||||
):
|
||||
block_name = block.get("name", "Unknown Block")
|
||||
self.add_error(
|
||||
f"Node '{node_id}' (block '{block_name}' - "
|
||||
f"{block_id}) is missing required input "
|
||||
f"'{req_input}'. This input must be either "
|
||||
f"provided as a default value in the node's "
|
||||
f"'input_default' field or connected via a link "
|
||||
f"from another node's output."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_data_type_compatibility(
|
||||
self, agent: AgentDict, blocks: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that linked data types are compatible between source and sink.
|
||||
Returns True if all data types are compatible, False otherwise.
|
||||
"""
|
||||
valid = True
|
||||
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
|
||||
block_lookup = {block.get("id", ""): block for block in blocks}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
source_id = link.get("source_id")
|
||||
sink_id = link.get("sink_id")
|
||||
source_name = link.get("source_name")
|
||||
sink_name = link.get("sink_name")
|
||||
|
||||
if not all(
|
||||
isinstance(v, str) and v
|
||||
for v in (source_id, sink_id, source_name, sink_name)
|
||||
):
|
||||
self.add_error(
|
||||
f"Link '{link.get('id', 'Unknown')}' is missing required "
|
||||
f"fields (source_id/sink_id/source_name/sink_name)."
|
||||
)
|
||||
valid = False
|
||||
continue
|
||||
|
||||
source_node = node_lookup.get(source_id, "")
|
||||
sink_node = node_lookup.get(sink_id, "")
|
||||
|
||||
if not source_node or not sink_node:
|
||||
continue
|
||||
|
||||
source_block = block_lookup.get(source_node.get("block_id", ""))
|
||||
sink_block = block_lookup.get(sink_node.get("block_id", ""))
|
||||
|
||||
if not source_block or not sink_block:
|
||||
continue
|
||||
|
||||
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
|
||||
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
|
||||
|
||||
source_type = get_defined_property_type(source_outputs, source_name)
|
||||
sink_type = get_defined_property_type(sink_inputs, sink_name)
|
||||
|
||||
if (
|
||||
source_type
|
||||
and sink_type
|
||||
and not are_types_compatible(source_type, sink_type)
|
||||
):
|
||||
source_block_name = source_block.get("name", "Unknown Block")
|
||||
sink_block_name = sink_block.get("name", "Unknown Block")
|
||||
self.add_error(
|
||||
f"Data type mismatch in link '{link.get('id')}': "
|
||||
f"Source '{source_block_name}' output "
|
||||
f"'{link.get('source_name', '')}' outputs '{source_type}' "
|
||||
f"type, but sink '{sink_block_name}' input "
|
||||
f"'{link.get('sink_name', '')}' expects '{sink_type}' type. "
|
||||
f"These types must match for the connection to work "
|
||||
f"properly."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_nested_sink_links(
|
||||
self, agent: AgentDict, blocks: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""
|
||||
Validate nested sink links (links with _#_ notation).
|
||||
Returns True if all nested links are valid, False otherwise.
|
||||
"""
|
||||
valid = True
|
||||
block_input_schemas = {
|
||||
block.get("id", ""): block.get("inputSchema", {}).get("properties", {})
|
||||
for block in blocks
|
||||
}
|
||||
block_names = {
|
||||
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
|
||||
}
|
||||
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
sink_name = link.get("sink_name", "")
|
||||
sink_id = link.get("sink_id")
|
||||
|
||||
if not sink_name or not sink_id:
|
||||
continue
|
||||
|
||||
if "_#_" in sink_name:
|
||||
parent, child = sink_name.split("_#_", 1)
|
||||
|
||||
sink_node = node_lookup.get(sink_id)
|
||||
if not sink_node:
|
||||
continue
|
||||
|
||||
block_id = sink_node.get("block_id")
|
||||
input_props = block_input_schemas.get(block_id, {})
|
||||
|
||||
parent_schema = input_props.get(parent)
|
||||
if not parent_schema:
|
||||
block_name = block_names.get(block_id, "Unknown Block")
|
||||
self.add_error(
|
||||
f"Invalid nested sink link '{sink_name}' for "
|
||||
f"node '{sink_id}' (block "
|
||||
f"'{block_name}' - {block_id}): Parent property "
|
||||
f"'{parent}' does not exist in the block's "
|
||||
f"input schema."
|
||||
)
|
||||
valid = False
|
||||
continue
|
||||
|
||||
# Check if additionalProperties is allowed either directly
|
||||
# or via anyOf
|
||||
allows_additional_properties = parent_schema.get(
|
||||
"additionalProperties", False
|
||||
)
|
||||
|
||||
# Check anyOf for additionalProperties
|
||||
if not allows_additional_properties and "anyOf" in parent_schema:
|
||||
any_of_schemas = parent_schema.get("anyOf", [])
|
||||
if isinstance(any_of_schemas, list):
|
||||
for schema_option in any_of_schemas:
|
||||
if isinstance(schema_option, dict) and schema_option.get(
|
||||
"additionalProperties"
|
||||
):
|
||||
allows_additional_properties = True
|
||||
break
|
||||
|
||||
if not allows_additional_properties:
|
||||
if not (
|
||||
isinstance(parent_schema, dict)
|
||||
and "properties" in parent_schema
|
||||
and isinstance(parent_schema["properties"], dict)
|
||||
and child in parent_schema["properties"]
|
||||
):
|
||||
block_name = block_names.get(block_id, "Unknown Block")
|
||||
self.add_error(
|
||||
f"Invalid nested sink link '{sink_name}' "
|
||||
f"for node '{link.get('sink_id', '')}' (block "
|
||||
f"'{block_name}' - {block_id}): Child "
|
||||
f"property '{child}' does not exist in "
|
||||
f"parent '{parent}' schema. Available "
|
||||
f"properties: "
|
||||
f"{list(parent_schema.get('properties', {}).keys())}"
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_prompt_double_curly_braces_spaces(self, agent: AgentDict) -> bool:
|
||||
"""
|
||||
Validate that prompt parameters do not contain spaces in double curly
|
||||
braces.
|
||||
|
||||
Checks the 'prompt' parameter in input_default of each node and reports
|
||||
errors if values within double curly braces ({{...}}) contain spaces.
|
||||
For example, {{user name}} should be {{user_name}}.
|
||||
|
||||
Args:
|
||||
agent: The agent dictionary to validate
|
||||
|
||||
Returns:
|
||||
True if all prompts are valid (no spaces in double curly braces),
|
||||
False otherwise
|
||||
"""
|
||||
valid = True
|
||||
nodes = agent.get("nodes", [])
|
||||
|
||||
for node in nodes:
|
||||
node_id = node.get("id")
|
||||
input_default = node.get("input_default", {})
|
||||
|
||||
# Check if 'prompt' parameter exists
|
||||
if "prompt" not in input_default:
|
||||
continue
|
||||
|
||||
prompt_text = input_default["prompt"]
|
||||
|
||||
# Only process if it's a string
|
||||
if not isinstance(prompt_text, str):
|
||||
continue
|
||||
|
||||
# Find all double curly brace patterns with spaces
|
||||
matches = re.finditer(r"\{\{([^}]+)\}\}", prompt_text)
|
||||
|
||||
for match in matches:
|
||||
content = match.group(1)
|
||||
if " " in content:
|
||||
start_pos = match.start()
|
||||
snippet_start = max(0, start_pos - 30)
|
||||
snippet_end = min(len(prompt_text), match.end() + 30)
|
||||
snippet = prompt_text[snippet_start:snippet_end]
|
||||
|
||||
self.add_error(
|
||||
f"Node '{node_id}' has spaces in double curly "
|
||||
f"braces in prompt parameter: "
|
||||
f"'{{{{{content}}}}}' should be "
|
||||
f"'{{{{{content.replace(' ', '_')}}}}}'. "
|
||||
f"Context: ...{snippet}..."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_source_output_existence(
|
||||
self, agent: AgentDict, blocks: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that all source_names in links exist in the corresponding
|
||||
block's output schema.
|
||||
|
||||
Checks that for each link, the source_name field references a valid
|
||||
output property in the source block's outputSchema. Also handles nested
|
||||
outputs with _#_ notation.
|
||||
|
||||
Args:
|
||||
agent: The agent dictionary to validate
|
||||
blocks: List of available blocks with their schemas
|
||||
|
||||
Returns:
|
||||
True if all source output fields exist, False otherwise
|
||||
"""
|
||||
valid = True
|
||||
|
||||
# Create lookup dictionaries for efficiency
|
||||
block_output_schemas = {
|
||||
block.get("id", ""): block.get("outputSchema", {}).get("properties", {})
|
||||
for block in blocks
|
||||
}
|
||||
block_names = {
|
||||
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
|
||||
}
|
||||
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
source_id = link.get("source_id")
|
||||
source_name = link.get("source_name", "")
|
||||
link_id = link.get("id", "Unknown")
|
||||
|
||||
if not source_name:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' is missing 'source_name'. "
|
||||
f"Every link must specify which output field to read from."
|
||||
)
|
||||
valid = False
|
||||
continue
|
||||
|
||||
source_node = node_lookup.get(source_id)
|
||||
if not source_node:
|
||||
# This error is already caught by
|
||||
# validate_link_node_references
|
||||
continue
|
||||
|
||||
block_id = source_node.get("block_id")
|
||||
block_name = block_names.get(block_id, "Unknown Block")
|
||||
|
||||
# Special handling for AgentExecutorBlock - use dynamic
|
||||
# output_schema from input_default
|
||||
if block_id == AGENT_EXECUTOR_BLOCK_ID:
|
||||
input_default = source_node.get("input_default", {})
|
||||
dynamic_output_schema = input_default.get("output_schema", {})
|
||||
if not isinstance(dynamic_output_schema, dict):
|
||||
dynamic_output_schema = {}
|
||||
output_props = dynamic_output_schema.get("properties", {})
|
||||
if not isinstance(output_props, dict):
|
||||
output_props = {}
|
||||
else:
|
||||
output_props = block_output_schemas.get(block_id, {})
|
||||
|
||||
# Handle nested source names (with _#_ notation)
|
||||
if "_#_" in source_name:
|
||||
parent, child = source_name.split("_#_", 1)
|
||||
|
||||
parent_schema = output_props.get(parent)
|
||||
if not parent_schema:
|
||||
self.add_error(
|
||||
f"Invalid source output field '{source_name}' "
|
||||
f"in link '{link_id}' from node '{source_id}' "
|
||||
f"(block '{block_name}' - {block_id}): Parent "
|
||||
f"property '{parent}' does not exist in the "
|
||||
f"block's output schema."
|
||||
)
|
||||
valid = False
|
||||
continue
|
||||
|
||||
# Check if additionalProperties is allowed either directly
|
||||
# or via anyOf
|
||||
allows_additional_properties = parent_schema.get(
|
||||
"additionalProperties", False
|
||||
)
|
||||
if not allows_additional_properties and "anyOf" in parent_schema:
|
||||
any_of_schemas = parent_schema.get("anyOf", [])
|
||||
if isinstance(any_of_schemas, list):
|
||||
for schema_option in any_of_schemas:
|
||||
if isinstance(schema_option, dict) and schema_option.get(
|
||||
"additionalProperties"
|
||||
):
|
||||
allows_additional_properties = True
|
||||
break
|
||||
# Also allow when items have
|
||||
# additionalProperties (array of objects)
|
||||
if (
|
||||
isinstance(schema_option, dict)
|
||||
and "items" in schema_option
|
||||
):
|
||||
items_schema = schema_option.get("items")
|
||||
if isinstance(items_schema, dict) and items_schema.get(
|
||||
"additionalProperties"
|
||||
):
|
||||
allows_additional_properties = True
|
||||
break
|
||||
|
||||
# Only require child in properties when
|
||||
# additionalProperties is not allowed
|
||||
if not allows_additional_properties:
|
||||
if not (
|
||||
isinstance(parent_schema, dict)
|
||||
and "properties" in parent_schema
|
||||
and isinstance(parent_schema["properties"], dict)
|
||||
and child in parent_schema["properties"]
|
||||
):
|
||||
available_props = (
|
||||
list(parent_schema.get("properties", {}).keys())
|
||||
if isinstance(parent_schema, dict)
|
||||
else []
|
||||
)
|
||||
self.add_error(
|
||||
f"Invalid nested source output field "
|
||||
f"'{source_name}' in link '{link_id}' from "
|
||||
f"node '{source_id}' (block "
|
||||
f"'{block_name}' - {block_id}): Child "
|
||||
f"property '{child}' does not exist in "
|
||||
f"parent '{parent}' output schema. "
|
||||
f"Available properties: {available_props}"
|
||||
)
|
||||
valid = False
|
||||
else:
|
||||
# Check simple (non-nested) source name
|
||||
if source_name not in output_props:
|
||||
available_outputs = list(output_props.keys())
|
||||
self.add_error(
|
||||
f"Invalid source output field '{source_name}' "
|
||||
f"in link '{link_id}' from node '{source_id}' "
|
||||
f"(block '{block_name}' - {block_id}): Output "
|
||||
f"property '{source_name}' does not exist in "
|
||||
f"the block's output schema. Available outputs: "
|
||||
f"{available_outputs}"
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_io_blocks(self, agent: AgentDict) -> bool:
|
||||
"""
|
||||
Validate that the agent has at least one AgentInputBlock and one
|
||||
AgentOutputBlock. These blocks define the agent's interface.
|
||||
|
||||
Returns True if both are present, False otherwise.
|
||||
"""
|
||||
valid = True
|
||||
block_ids = {node.get("block_id") for node in agent.get("nodes", [])}
|
||||
|
||||
if AGENT_INPUT_BLOCK_ID not in block_ids:
|
||||
self.add_error(
|
||||
f"Agent is missing an AgentInputBlock (block_id: "
|
||||
f"'{AGENT_INPUT_BLOCK_ID}'). Every agent must have at "
|
||||
f"least one AgentInputBlock to define user-facing inputs. "
|
||||
f"Add a node with block_id '{AGENT_INPUT_BLOCK_ID}' and "
|
||||
f"set input_default with 'name' and optionally 'title'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
if AGENT_OUTPUT_BLOCK_ID not in block_ids:
|
||||
self.add_error(
|
||||
f"Agent is missing an AgentOutputBlock (block_id: "
|
||||
f"'{AGENT_OUTPUT_BLOCK_ID}'). Every agent must have at "
|
||||
f"least one AgentOutputBlock to define user-facing outputs. "
|
||||
f"Add a node with block_id '{AGENT_OUTPUT_BLOCK_ID}' and "
|
||||
f"set input_default with 'name', then link 'value' from "
|
||||
f"another block's output."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_agent_executor_blocks(
|
||||
self,
|
||||
agent: AgentDict,
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Validate AgentExecutorBlock nodes have required fields and valid
|
||||
references.
|
||||
|
||||
Checks that AgentExecutorBlock nodes:
|
||||
1. Have a valid graph_id in input_default (required)
|
||||
2. If graph_id matches a known library agent, validates version
|
||||
consistency
|
||||
3. Sub-agent required inputs are connected via links (not hardcoded)
|
||||
|
||||
Note: Unknown graph_ids are not treated as errors - they could be valid
|
||||
direct references to agents by their actual ID (not via library_agents).
|
||||
This is consistent with fix_agent_executor_blocks() behavior.
|
||||
|
||||
Args:
|
||||
agent: The agent dictionary to validate
|
||||
library_agents: List of available library agents (for version
|
||||
validation)
|
||||
|
||||
Returns:
|
||||
True if all AgentExecutorBlock nodes are valid, False otherwise
|
||||
"""
|
||||
valid = True
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
|
||||
# Create lookup for library agents
|
||||
library_agent_lookup: dict[str, dict[str, Any]] = {}
|
||||
if library_agents:
|
||||
library_agent_lookup = {la.get("graph_id", ""): la for la in library_agents}
|
||||
|
||||
for node in nodes:
|
||||
if node.get("block_id") != AGENT_EXECUTOR_BLOCK_ID:
|
||||
continue
|
||||
|
||||
node_id = node.get("id")
|
||||
input_default = node.get("input_default", {})
|
||||
|
||||
# Check for required graph_id
|
||||
graph_id = input_default.get("graph_id")
|
||||
if not graph_id:
|
||||
self.add_error(
|
||||
f"AgentExecutorBlock node '{node_id}' is missing "
|
||||
f"required 'graph_id' in input_default. This field "
|
||||
f"must reference the ID of the sub-agent to execute."
|
||||
)
|
||||
valid = False
|
||||
continue
|
||||
|
||||
# If graph_id is not in library_agent_lookup, skip validation
|
||||
if graph_id not in library_agent_lookup:
|
||||
continue
|
||||
|
||||
# Validate version consistency for known library agents
|
||||
library_agent = library_agent_lookup[graph_id]
|
||||
expected_version = library_agent.get("graph_version")
|
||||
current_version = input_default.get("graph_version")
|
||||
if (
|
||||
current_version
|
||||
and expected_version
|
||||
and current_version != expected_version
|
||||
):
|
||||
self.add_error(
|
||||
f"AgentExecutorBlock node '{node_id}' has mismatched "
|
||||
f"graph_version: got {current_version}, expected "
|
||||
f"{expected_version} for library agent "
|
||||
f"'{library_agent.get('name')}'"
|
||||
)
|
||||
valid = False
|
||||
|
||||
# Validate sub-agent inputs are properly linked (not hardcoded)
|
||||
sub_agent_input_schema = library_agent.get("input_schema", {})
|
||||
if not isinstance(sub_agent_input_schema, dict):
|
||||
sub_agent_input_schema = {}
|
||||
sub_agent_required_inputs = sub_agent_input_schema.get("required", [])
|
||||
sub_agent_properties = sub_agent_input_schema.get("properties", {})
|
||||
|
||||
# Get all linked inputs to this node
|
||||
linked_sub_agent_inputs: set[str] = set()
|
||||
for link in links:
|
||||
if link.get("sink_id") == node_id:
|
||||
sink_name = link.get("sink_name", "")
|
||||
if sink_name in sub_agent_properties:
|
||||
linked_sub_agent_inputs.add(sink_name)
|
||||
|
||||
# Check for hardcoded inputs that should be linked
|
||||
hardcoded_inputs = input_default.get("inputs", {})
|
||||
input_schema = input_default.get("input_schema", {})
|
||||
schema_properties = (
|
||||
input_schema.get("properties", {})
|
||||
if isinstance(input_schema, dict)
|
||||
else {}
|
||||
)
|
||||
if isinstance(hardcoded_inputs, dict) and hardcoded_inputs:
|
||||
for input_name, value in hardcoded_inputs.items():
|
||||
if input_name not in sub_agent_properties:
|
||||
continue
|
||||
if value is None:
|
||||
continue
|
||||
# Skip if this input is already linked
|
||||
if input_name in linked_sub_agent_inputs:
|
||||
continue
|
||||
prop_schema = schema_properties.get(input_name, {})
|
||||
schema_default = (
|
||||
prop_schema.get("default")
|
||||
if isinstance(prop_schema, dict)
|
||||
else None
|
||||
)
|
||||
if schema_default is not None and self._values_equal(
|
||||
value, schema_default
|
||||
):
|
||||
continue
|
||||
# This is a non-default hardcoded value without a link
|
||||
self.add_error(
|
||||
f"AgentExecutorBlock node '{node_id}' has "
|
||||
f"hardcoded input '{input_name}'. Sub-agent "
|
||||
f"inputs should be connected via links using "
|
||||
f"'{input_name}' as sink_name, not hardcoded "
|
||||
f"in input_default.inputs. Create a link from "
|
||||
f"the appropriate source node."
|
||||
)
|
||||
valid = False
|
||||
|
||||
# Check for missing required sub-agent inputs.
|
||||
# An input is satisfied if it is linked OR has an allowed
|
||||
# hardcoded value (i.e. equals the schema default — the
|
||||
# previous check already flags non-default hardcoded values).
|
||||
hardcoded_inputs_dict = (
|
||||
hardcoded_inputs if isinstance(hardcoded_inputs, dict) else {}
|
||||
)
|
||||
for req_input in sub_agent_required_inputs:
|
||||
if req_input in linked_sub_agent_inputs:
|
||||
continue
|
||||
# Check if fixer populated it with a schema default value
|
||||
if req_input in hardcoded_inputs_dict:
|
||||
prop_schema = schema_properties.get(req_input, {})
|
||||
schema_default = (
|
||||
prop_schema.get("default")
|
||||
if isinstance(prop_schema, dict)
|
||||
else None
|
||||
)
|
||||
if schema_default is not None and self._values_equal(
|
||||
hardcoded_inputs_dict[req_input], schema_default
|
||||
):
|
||||
continue
|
||||
self.add_error(
|
||||
f"AgentExecutorBlock node '{node_id}' is "
|
||||
f"missing required sub-agent input "
|
||||
f"'{req_input}'. Create a link to this node "
|
||||
f"using sink_name '{req_input}' to connect "
|
||||
f"the input."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_agent_executor_block_schemas(
|
||||
self,
|
||||
agent: AgentDict,
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that AgentExecutorBlock nodes have valid input_schema and
|
||||
output_schema.
|
||||
|
||||
This validation runs regardless of library_agents availability and
|
||||
ensures that the schemas are properly populated to prevent frontend
|
||||
crashes.
|
||||
|
||||
Args:
|
||||
agent: The agent dictionary to validate
|
||||
|
||||
Returns:
|
||||
True if all AgentExecutorBlock nodes have valid schemas, False
|
||||
otherwise
|
||||
"""
|
||||
valid = True
|
||||
nodes = agent.get("nodes", [])
|
||||
|
||||
for node in nodes:
|
||||
if node.get("block_id") != AGENT_EXECUTOR_BLOCK_ID:
|
||||
continue
|
||||
|
||||
node_id = node.get("id")
|
||||
input_default = node.get("input_default", {})
|
||||
customized_name = (node.get("metadata") or {}).get(
|
||||
"customized_name", "Unknown"
|
||||
)
|
||||
|
||||
# Check input_schema
|
||||
input_schema = input_default.get("input_schema")
|
||||
if input_schema is None or not isinstance(input_schema, dict):
|
||||
self.add_error(
|
||||
f"AgentExecutorBlock node '{node_id}' "
|
||||
f"({customized_name}) has missing or invalid "
|
||||
f"input_schema. The input_schema must be a valid "
|
||||
f"JSON Schema object with 'properties' and "
|
||||
f"'required' fields."
|
||||
)
|
||||
valid = False
|
||||
elif not input_schema.get("properties") and not input_schema.get("type"):
|
||||
# Empty schema like {} is invalid
|
||||
self.add_error(
|
||||
f"AgentExecutorBlock node '{node_id}' "
|
||||
f"({customized_name}) has empty input_schema. The "
|
||||
f"input_schema must define the sub-agent's expected "
|
||||
f"inputs. This usually indicates the sub-agent "
|
||||
f"reference is incomplete or the library agent was "
|
||||
f"not properly passed."
|
||||
)
|
||||
valid = False
|
||||
|
||||
# Check output_schema
|
||||
output_schema = input_default.get("output_schema")
|
||||
if output_schema is None or not isinstance(output_schema, dict):
|
||||
self.add_error(
|
||||
f"AgentExecutorBlock node '{node_id}' "
|
||||
f"({customized_name}) has missing or invalid "
|
||||
f"output_schema. The output_schema must be a valid "
|
||||
f"JSON Schema object defining the sub-agent's "
|
||||
f"outputs."
|
||||
)
|
||||
valid = False
|
||||
elif not output_schema.get("properties") and not output_schema.get("type"):
|
||||
# Empty schema like {} is invalid
|
||||
self.add_error(
|
||||
f"AgentExecutorBlock node '{node_id}' "
|
||||
f"({customized_name}) has empty output_schema. "
|
||||
f"The output_schema must define the sub-agent's "
|
||||
f"expected outputs. This usually indicates the "
|
||||
f"sub-agent reference is incomplete or the library "
|
||||
f"agent was not properly passed."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_mcp_tool_blocks(self, agent: AgentDict) -> bool:
|
||||
"""Validate that MCPToolBlock nodes have required fields.
|
||||
|
||||
Checks that each MCPToolBlock node has:
|
||||
1. A non-empty `server_url` in input_default
|
||||
2. A non-empty `selected_tool` in input_default
|
||||
|
||||
Returns True if all MCPToolBlock nodes are valid, False otherwise.
|
||||
"""
|
||||
valid = True
|
||||
nodes = agent.get("nodes", [])
|
||||
|
||||
for node in nodes:
|
||||
if node.get("block_id") != MCP_TOOL_BLOCK_ID:
|
||||
continue
|
||||
|
||||
node_id = node.get("id", "unknown")
|
||||
input_default = node.get("input_default", {})
|
||||
customized_name = (node.get("metadata") or {}).get(
|
||||
"customized_name", node_id
|
||||
)
|
||||
|
||||
server_url = input_default.get("server_url")
|
||||
if not server_url:
|
||||
self.add_error(
|
||||
f"MCPToolBlock node '{customized_name}' ({node_id}) is "
|
||||
f"missing required 'server_url' in input_default. "
|
||||
f"Set this to the MCP server URL "
|
||||
f"(e.g. 'https://mcp.example.com/sse')."
|
||||
)
|
||||
valid = False
|
||||
|
||||
selected_tool = input_default.get("selected_tool")
|
||||
if not selected_tool:
|
||||
self.add_error(
|
||||
f"MCPToolBlock node '{customized_name}' ({node_id}) is "
|
||||
f"missing required 'selected_tool' in input_default. "
|
||||
f"Set this to the name of the MCP tool to execute."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate(
|
||||
self,
|
||||
agent: AgentDict,
|
||||
blocks: list[dict[str, Any]],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Comprehensive validation of an agent against available blocks.
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str]]: (is_valid, error_message)
|
||||
- is_valid: True if agent passes all validations, False otherwise
|
||||
- error_message: Detailed error message if validation fails, None
|
||||
if successful
|
||||
"""
|
||||
logger.info("Validating agent...")
|
||||
self.errors = []
|
||||
|
||||
checks = [
|
||||
(
|
||||
"Block existence",
|
||||
self.validate_block_existence(agent, blocks),
|
||||
),
|
||||
(
|
||||
"Link node references",
|
||||
self.validate_link_node_references(agent),
|
||||
),
|
||||
(
|
||||
"Required inputs",
|
||||
self.validate_required_inputs(agent, blocks),
|
||||
),
|
||||
(
|
||||
"Data type compatibility",
|
||||
self.validate_data_type_compatibility(agent, blocks),
|
||||
),
|
||||
(
|
||||
"Nested sink links",
|
||||
self.validate_nested_sink_links(agent, blocks),
|
||||
),
|
||||
(
|
||||
"Source output existence",
|
||||
self.validate_source_output_existence(agent, blocks),
|
||||
),
|
||||
(
|
||||
"Prompt double curly braces spaces",
|
||||
self.validate_prompt_double_curly_braces_spaces(agent),
|
||||
),
|
||||
(
|
||||
"IO blocks",
|
||||
self.validate_io_blocks(agent),
|
||||
),
|
||||
# Always validate AgentExecutorBlock schemas to prevent
|
||||
# frontend crashes
|
||||
(
|
||||
"AgentExecutorBlock schemas",
|
||||
self.validate_agent_executor_block_schemas(agent),
|
||||
),
|
||||
(
|
||||
"MCP tool blocks",
|
||||
self.validate_mcp_tool_blocks(agent),
|
||||
),
|
||||
]
|
||||
|
||||
# Add AgentExecutorBlock detailed validation if library_agents
|
||||
# provided
|
||||
if library_agents:
|
||||
checks.append(
|
||||
(
|
||||
"AgentExecutorBlock references",
|
||||
self.validate_agent_executor_blocks(agent, library_agents),
|
||||
)
|
||||
)
|
||||
|
||||
all_passed = all(check[1] for check in checks)
|
||||
|
||||
if all_passed:
|
||||
logger.info("Agent validation successful.")
|
||||
return True, None
|
||||
else:
|
||||
error_message = "Agent validation failed with the following errors:\n\n"
|
||||
for i, error in enumerate(self.errors, 1):
|
||||
error_message += f"{i}. {error}\n"
|
||||
|
||||
logger.error(f"Agent validation failed: {error_message}")
|
||||
return False, error_message
|
||||
@@ -0,0 +1,710 @@
|
||||
"""Unit tests for AgentValidator."""
|
||||
|
||||
from .helpers import (
|
||||
AGENT_EXECUTOR_BLOCK_ID,
|
||||
AGENT_INPUT_BLOCK_ID,
|
||||
AGENT_OUTPUT_BLOCK_ID,
|
||||
MCP_TOOL_BLOCK_ID,
|
||||
generate_uuid,
|
||||
)
|
||||
from .validator import AgentValidator
|
||||
|
||||
|
||||
def _make_agent(
|
||||
nodes: list | None = None,
|
||||
links: list | None = None,
|
||||
agent_id: str | None = None,
|
||||
) -> dict:
|
||||
"""Build a minimal agent dict for testing."""
|
||||
return {
|
||||
"id": agent_id or generate_uuid(),
|
||||
"name": "Test Agent",
|
||||
"nodes": nodes or [],
|
||||
"links": links or [],
|
||||
}
|
||||
|
||||
|
||||
def _make_node(
|
||||
node_id: str | None = None,
|
||||
block_id: str = "block-1",
|
||||
input_default: dict | None = None,
|
||||
position: tuple[int, int] = (0, 0),
|
||||
) -> dict:
|
||||
return {
|
||||
"id": node_id or generate_uuid(),
|
||||
"block_id": block_id,
|
||||
"input_default": input_default or {},
|
||||
"metadata": {"position": {"x": position[0], "y": position[1]}},
|
||||
}
|
||||
|
||||
|
||||
def _make_link(
|
||||
link_id: str | None = None,
|
||||
source_id: str = "",
|
||||
source_name: str = "output",
|
||||
sink_id: str = "",
|
||||
sink_name: str = "input",
|
||||
) -> dict:
|
||||
return {
|
||||
"id": link_id or generate_uuid(),
|
||||
"source_id": source_id,
|
||||
"source_name": source_name,
|
||||
"sink_id": sink_id,
|
||||
"sink_name": sink_name,
|
||||
}
|
||||
|
||||
|
||||
def _make_block(
|
||||
block_id: str = "block-1",
|
||||
name: str = "TestBlock",
|
||||
input_schema: dict | None = None,
|
||||
output_schema: dict | None = None,
|
||||
categories: list | None = None,
|
||||
static_output: bool = False,
|
||||
) -> dict:
|
||||
return {
|
||||
"id": block_id,
|
||||
"name": name,
|
||||
"inputSchema": input_schema or {"properties": {}, "required": []},
|
||||
"outputSchema": output_schema or {"properties": {}},
|
||||
"categories": categories or [],
|
||||
"staticOutput": static_output,
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_block_existence
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidateBlockExistence:
|
||||
def test_valid_blocks_pass(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(block_id="b1")
|
||||
block = _make_block(block_id="b1")
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_block_existence(agent, [block]) is True
|
||||
assert v.errors == []
|
||||
|
||||
def test_missing_block_fails(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(block_id="nonexistent")
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_block_existence(agent, []) is False
|
||||
assert len(v.errors) == 1
|
||||
assert "does not exist" in v.errors[0]
|
||||
|
||||
def test_missing_block_id_field(self):
|
||||
v = AgentValidator()
|
||||
node = {"id": "n1", "input_default": {}, "metadata": {}}
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_block_existence(agent, []) is False
|
||||
assert "missing a 'block_id'" in v.errors[0]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_link_node_references
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidateLinkNodeReferences:
|
||||
def test_valid_references_pass(self):
|
||||
v = AgentValidator()
|
||||
n1 = _make_node(node_id="n1")
|
||||
n2 = _make_node(node_id="n2")
|
||||
link = _make_link(source_id="n1", sink_id="n2")
|
||||
agent = _make_agent(nodes=[n1, n2], links=[link])
|
||||
|
||||
assert v.validate_link_node_references(agent) is True
|
||||
assert v.errors == []
|
||||
|
||||
def test_invalid_source_fails(self):
|
||||
v = AgentValidator()
|
||||
n1 = _make_node(node_id="n1")
|
||||
link = _make_link(source_id="missing", sink_id="n1")
|
||||
agent = _make_agent(nodes=[n1], links=[link])
|
||||
|
||||
assert v.validate_link_node_references(agent) is False
|
||||
assert any("source_id" in e for e in v.errors)
|
||||
|
||||
def test_invalid_sink_fails(self):
|
||||
v = AgentValidator()
|
||||
n1 = _make_node(node_id="n1")
|
||||
link = _make_link(source_id="n1", sink_id="missing")
|
||||
agent = _make_agent(nodes=[n1], links=[link])
|
||||
|
||||
assert v.validate_link_node_references(agent) is False
|
||||
assert any("sink_id" in e for e in v.errors)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_required_inputs
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidateRequiredInputs:
|
||||
def test_satisfied_by_default_passes(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {"url": {"type": "string"}},
|
||||
"required": ["url"],
|
||||
},
|
||||
)
|
||||
node = _make_node(block_id="b1", input_default={"url": "http://example.com"})
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_required_inputs(agent, [block]) is True
|
||||
assert v.errors == []
|
||||
|
||||
def test_satisfied_by_link_passes(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {"url": {"type": "string"}},
|
||||
"required": ["url"],
|
||||
},
|
||||
)
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(source_id="n2", sink_id="n1", sink_name="url")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_required_inputs(agent, [block]) is True
|
||||
|
||||
def test_missing_required_input_fails(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {"url": {"type": "string"}},
|
||||
"required": ["url"],
|
||||
},
|
||||
)
|
||||
node = _make_node(block_id="b1", input_default={})
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_required_inputs(agent, [block]) is False
|
||||
assert any("missing required input" in e for e in v.errors)
|
||||
|
||||
def test_credentials_always_allowed_missing(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {"credentials": {"type": "object"}},
|
||||
"required": ["credentials"],
|
||||
},
|
||||
)
|
||||
node = _make_node(block_id="b1", input_default={})
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_required_inputs(agent, [block]) is True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_data_type_compatibility
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidateDataTypeCompatibility:
|
||||
def test_matching_types_pass(self):
|
||||
v = AgentValidator()
|
||||
src_block = _make_block(
|
||||
block_id="src-b",
|
||||
output_schema={"properties": {"out": {"type": "string"}}},
|
||||
)
|
||||
sink_block = _make_block(
|
||||
block_id="sink-b",
|
||||
input_schema={"properties": {"inp": {"type": "string"}}, "required": []},
|
||||
)
|
||||
src_node = _make_node(node_id="n1", block_id="src-b")
|
||||
sink_node = _make_node(node_id="n2", block_id="sink-b")
|
||||
link = _make_link(
|
||||
source_id="n1", source_name="out", sink_id="n2", sink_name="inp"
|
||||
)
|
||||
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
|
||||
|
||||
assert (
|
||||
v.validate_data_type_compatibility(agent, [src_block, sink_block]) is True
|
||||
)
|
||||
|
||||
def test_int_number_compatible(self):
|
||||
v = AgentValidator()
|
||||
src_block = _make_block(
|
||||
block_id="src-b",
|
||||
output_schema={"properties": {"out": {"type": "integer"}}},
|
||||
)
|
||||
sink_block = _make_block(
|
||||
block_id="sink-b",
|
||||
input_schema={"properties": {"inp": {"type": "number"}}, "required": []},
|
||||
)
|
||||
src_node = _make_node(node_id="n1", block_id="src-b")
|
||||
sink_node = _make_node(node_id="n2", block_id="sink-b")
|
||||
link = _make_link(
|
||||
source_id="n1", source_name="out", sink_id="n2", sink_name="inp"
|
||||
)
|
||||
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
|
||||
|
||||
assert (
|
||||
v.validate_data_type_compatibility(agent, [src_block, sink_block]) is True
|
||||
)
|
||||
|
||||
def test_mismatched_types_fail(self):
|
||||
v = AgentValidator()
|
||||
src_block = _make_block(
|
||||
block_id="src-b",
|
||||
output_schema={"properties": {"out": {"type": "string"}}},
|
||||
)
|
||||
sink_block = _make_block(
|
||||
block_id="sink-b",
|
||||
input_schema={"properties": {"inp": {"type": "integer"}}, "required": []},
|
||||
)
|
||||
src_node = _make_node(node_id="n1", block_id="src-b")
|
||||
sink_node = _make_node(node_id="n2", block_id="sink-b")
|
||||
link = _make_link(
|
||||
source_id="n1", source_name="out", sink_id="n2", sink_name="inp"
|
||||
)
|
||||
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
|
||||
|
||||
assert (
|
||||
v.validate_data_type_compatibility(agent, [src_block, sink_block]) is False
|
||||
)
|
||||
assert any("mismatch" in e.lower() for e in v.errors)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_source_output_existence
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidateSourceOutputExistence:
|
||||
def test_valid_source_output_passes(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
output_schema={"properties": {"result": {"type": "string"}}},
|
||||
)
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(source_id="n1", source_name="result", sink_id="n2")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_source_output_existence(agent, [block]) is True
|
||||
|
||||
def test_invalid_source_output_fails(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
output_schema={"properties": {"result": {"type": "string"}}},
|
||||
)
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(source_id="n1", source_name="nonexistent", sink_id="n2")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_source_output_existence(agent, [block]) is False
|
||||
assert any("does not exist" in e for e in v.errors)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_prompt_double_curly_braces_spaces
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidatePromptDoubleCurlyBracesSpaces:
|
||||
def test_no_spaces_passes(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(input_default={"prompt": "Hello {{name}}!"})
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_prompt_double_curly_braces_spaces(agent) is True
|
||||
|
||||
def test_spaces_in_braces_fails(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(input_default={"prompt": "Hello {{user name}}!"})
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_prompt_double_curly_braces_spaces(agent) is False
|
||||
assert any("spaces" in e for e in v.errors)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_nested_sink_links
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidateNestedSinkLinks:
|
||||
def test_valid_nested_link_passes(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {"key": {"type": "string"}},
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(sink_id="n1", sink_name="config_#_key", source_id="n2")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_nested_sink_links(agent, [block]) is True
|
||||
|
||||
def test_invalid_parent_fails(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(block_id="b1")
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(sink_id="n1", sink_name="nonexistent_#_key", source_id="n2")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_nested_sink_links(agent, [block]) is False
|
||||
assert any("does not exist" in e for e in v.errors)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_agent_executor_block_schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidateAgentExecutorBlockSchemas:
|
||||
def test_valid_schemas_pass(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_default={
|
||||
"graph_id": generate_uuid(),
|
||||
"input_schema": {"properties": {"q": {"type": "string"}}},
|
||||
"output_schema": {"properties": {"result": {"type": "string"}}},
|
||||
},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_agent_executor_block_schemas(agent) is True
|
||||
assert v.errors == []
|
||||
|
||||
def test_empty_input_schema_fails(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_default={
|
||||
"graph_id": generate_uuid(),
|
||||
"input_schema": {},
|
||||
"output_schema": {"properties": {"result": {"type": "string"}}},
|
||||
},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_agent_executor_block_schemas(agent) is False
|
||||
assert any("empty input_schema" in e for e in v.errors)
|
||||
|
||||
def test_missing_output_schema_fails(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_default={
|
||||
"graph_id": generate_uuid(),
|
||||
"input_schema": {"properties": {"q": {"type": "string"}}},
|
||||
},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_agent_executor_block_schemas(agent) is False
|
||||
assert any("output_schema" in e for e in v.errors)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_agent_executor_blocks
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidateAgentExecutorBlocks:
|
||||
def test_missing_graph_id_fails(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_default={},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_agent_executor_blocks(agent) is False
|
||||
assert any("graph_id" in e for e in v.errors)
|
||||
|
||||
def test_valid_graph_id_passes(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_default={"graph_id": generate_uuid()},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_agent_executor_blocks(agent) is True
|
||||
|
||||
def test_version_mismatch_with_library_agent(self):
|
||||
v = AgentValidator()
|
||||
lib_id = generate_uuid()
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_default={"graph_id": lib_id, "graph_version": 1},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
library_agents = [{"graph_id": lib_id, "graph_version": 3, "name": "Sub Agent"}]
|
||||
|
||||
assert v.validate_agent_executor_blocks(agent, library_agents) is False
|
||||
assert any("mismatched graph_version" in e for e in v.errors)
|
||||
|
||||
def test_required_input_satisfied_by_schema_default_passes(self):
|
||||
"""Required sub-agent inputs filled with their schema default by the fixer
|
||||
should NOT be flagged as missing."""
|
||||
v = AgentValidator()
|
||||
lib_id = generate_uuid()
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_default={
|
||||
"graph_id": lib_id,
|
||||
"input_schema": {
|
||||
"properties": {"mode": {"type": "string", "default": "fast"}}
|
||||
},
|
||||
"inputs": {"mode": "fast"}, # fixer populated with schema default
|
||||
},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
library_agents = [
|
||||
{
|
||||
"graph_id": lib_id,
|
||||
"graph_version": 1,
|
||||
"name": "Sub",
|
||||
"input_schema": {
|
||||
"required": ["mode"],
|
||||
"properties": {"mode": {"type": "string", "default": "fast"}},
|
||||
},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
assert v.validate_agent_executor_blocks(agent, library_agents) is True
|
||||
assert v.errors == []
|
||||
|
||||
def test_required_input_not_linked_and_no_default_fails(self):
|
||||
"""Required sub-agent inputs without a link or schema default must fail."""
|
||||
v = AgentValidator()
|
||||
lib_id = generate_uuid()
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_default={
|
||||
"graph_id": lib_id,
|
||||
"input_schema": {"properties": {"query": {"type": "string"}}},
|
||||
"inputs": {},
|
||||
},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
library_agents = [
|
||||
{
|
||||
"graph_id": lib_id,
|
||||
"graph_version": 1,
|
||||
"name": "Sub",
|
||||
"input_schema": {
|
||||
"required": ["query"],
|
||||
"properties": {"query": {"type": "string"}},
|
||||
},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
assert v.validate_agent_executor_blocks(agent, library_agents) is False
|
||||
assert any("missing required sub-agent input" in e for e in v.errors)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_io_blocks
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidateIoBlocks:
|
||||
def test_missing_input_block_reports_error(self):
|
||||
v = AgentValidator()
|
||||
# Agent has output block but no input block
|
||||
node = _make_node(block_id=AGENT_OUTPUT_BLOCK_ID)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_io_blocks(agent) is False
|
||||
assert len(v.errors) == 1
|
||||
assert "AgentInputBlock" in v.errors[0]
|
||||
|
||||
def test_missing_output_block_reports_error(self):
|
||||
v = AgentValidator()
|
||||
# Agent has input block but no output block
|
||||
node = _make_node(block_id=AGENT_INPUT_BLOCK_ID)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_io_blocks(agent) is False
|
||||
assert len(v.errors) == 1
|
||||
assert "AgentOutputBlock" in v.errors[0]
|
||||
|
||||
def test_missing_both_io_blocks_reports_two_errors(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(block_id="some-other-block")
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_io_blocks(agent) is False
|
||||
assert len(v.errors) == 2
|
||||
|
||||
def test_both_io_blocks_present_no_error(self):
|
||||
v = AgentValidator()
|
||||
input_node = _make_node(block_id=AGENT_INPUT_BLOCK_ID)
|
||||
output_node = _make_node(block_id=AGENT_OUTPUT_BLOCK_ID)
|
||||
agent = _make_agent(nodes=[input_node, output_node])
|
||||
|
||||
assert v.validate_io_blocks(agent) is True
|
||||
assert v.errors == []
|
||||
|
||||
def test_empty_agent_reports_both_missing(self):
|
||||
v = AgentValidator()
|
||||
agent = _make_agent(nodes=[])
|
||||
|
||||
assert v.validate_io_blocks(agent) is False
|
||||
assert len(v.errors) == 2
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate (integration)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidate:
|
||||
def test_valid_agent_passes(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {"url": {"type": "string"}},
|
||||
"required": ["url"],
|
||||
},
|
||||
output_schema={"properties": {"result": {"type": "string"}}},
|
||||
)
|
||||
input_block = _make_block(
|
||||
block_id=AGENT_INPUT_BLOCK_ID,
|
||||
name="AgentInputBlock",
|
||||
output_schema={"properties": {"result": {}}},
|
||||
)
|
||||
output_block = _make_block(
|
||||
block_id=AGENT_OUTPUT_BLOCK_ID,
|
||||
name="AgentOutputBlock",
|
||||
)
|
||||
input_node = _make_node(
|
||||
node_id="n-in",
|
||||
block_id=AGENT_INPUT_BLOCK_ID,
|
||||
input_default={"name": "url"},
|
||||
)
|
||||
n1 = _make_node(
|
||||
node_id="n1", block_id="b1", input_default={"url": "http://example.com"}
|
||||
)
|
||||
n2 = _make_node(
|
||||
node_id="n2", block_id="b1", input_default={"url": "http://example2.com"}
|
||||
)
|
||||
output_node = _make_node(
|
||||
node_id="n-out",
|
||||
block_id=AGENT_OUTPUT_BLOCK_ID,
|
||||
input_default={"name": "result"},
|
||||
)
|
||||
link = _make_link(
|
||||
source_id="n1", source_name="result", sink_id="n2", sink_name="url"
|
||||
)
|
||||
agent = _make_agent(nodes=[input_node, n1, n2, output_node], links=[link])
|
||||
|
||||
is_valid, error_message = v.validate(agent, [block, input_block, output_block])
|
||||
|
||||
assert is_valid is True
|
||||
assert error_message is None
|
||||
|
||||
def test_invalid_agent_returns_errors(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(block_id="nonexistent")
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
is_valid, error_message = v.validate(agent, [])
|
||||
|
||||
assert is_valid is False
|
||||
assert error_message is not None
|
||||
assert "does not exist" in error_message
|
||||
|
||||
def test_empty_agent_fails_io_validation(self):
|
||||
v = AgentValidator()
|
||||
agent = _make_agent()
|
||||
|
||||
is_valid, error_message = v.validate(agent, [])
|
||||
|
||||
assert is_valid is False
|
||||
assert error_message is not None
|
||||
assert "AgentInputBlock" in error_message
|
||||
assert "AgentOutputBlock" in error_message
|
||||
|
||||
|
||||
class TestValidateMCPToolBlocks:
|
||||
"""Tests for validate_mcp_tool_blocks."""
|
||||
|
||||
def test_missing_server_url_reports_error(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(
|
||||
block_id=MCP_TOOL_BLOCK_ID,
|
||||
input_default={"selected_tool": "my_tool"},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
result = v.validate_mcp_tool_blocks(agent)
|
||||
|
||||
assert result is False
|
||||
assert any("server_url" in e for e in v.errors)
|
||||
|
||||
def test_missing_selected_tool_reports_error(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(
|
||||
block_id=MCP_TOOL_BLOCK_ID,
|
||||
input_default={"server_url": "https://mcp.example.com/sse"},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
result = v.validate_mcp_tool_blocks(agent)
|
||||
|
||||
assert result is False
|
||||
assert any("selected_tool" in e for e in v.errors)
|
||||
|
||||
def test_valid_mcp_block_passes(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(
|
||||
block_id=MCP_TOOL_BLOCK_ID,
|
||||
input_default={
|
||||
"server_url": "https://mcp.example.com/sse",
|
||||
"selected_tool": "search",
|
||||
"tool_input_schema": {"properties": {"query": {"type": "string"}}},
|
||||
"tool_arguments": {},
|
||||
},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
result = v.validate_mcp_tool_blocks(agent)
|
||||
|
||||
assert result is True
|
||||
assert len(v.errors) == 0
|
||||
|
||||
def test_both_missing_reports_two_errors(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(
|
||||
block_id=MCP_TOOL_BLOCK_ID,
|
||||
input_default={},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
v.validate_mcp_tool_blocks(agent)
|
||||
|
||||
assert len(v.errors) == 2
|
||||
@@ -208,6 +208,9 @@ def _library_agent_to_info(agent: LibraryAgent) -> AgentInfo:
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
graph_version=agent.graph_version,
|
||||
input_schema=agent.input_schema,
|
||||
output_schema=agent.output_schema,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -21,10 +21,10 @@ from typing import Any
|
||||
from e2b import AsyncSandbox
|
||||
from e2b.exceptions import TimeoutException
|
||||
|
||||
from backend.copilot.context import E2B_WORKDIR, get_current_sandbox
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .e2b_sandbox import E2B_WORKDIR
|
||||
from .models import BashExecResponse, ErrorResponse, ToolResponseBase
|
||||
from .sandbox import get_workspace_dir, has_full_sandbox, run_sandboxed
|
||||
|
||||
@@ -94,9 +94,6 @@ class BashExecTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# E2B path: run on remote cloud sandbox when available.
|
||||
from backend.copilot.sdk.tool_adapter import get_current_sandbox
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
return await self._execute_on_e2b(sandbox, command, timeout, session_id)
|
||||
|
||||
@@ -0,0 +1,157 @@
|
||||
"""Tool for continuing block execution after human review approval."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from prisma.enums import ReviewStatus
|
||||
|
||||
from backend.blocks import get_block
|
||||
from backend.copilot.constants import (
|
||||
COPILOT_NODE_PREFIX,
|
||||
COPILOT_SESSION_PREFIX,
|
||||
parse_node_id_from_exec_id,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import review_db
|
||||
|
||||
from .base import BaseTool
|
||||
from .helpers import execute_block, resolve_block_credentials
|
||||
from .models import ErrorResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContinueRunBlockTool(BaseTool):
|
||||
"""Tool for continuing a block execution after human review approval."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "continue_run_block"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Continue executing a block after human review approval. "
|
||||
"Use this after a run_block call returned review_required. "
|
||||
"Pass the review_id from the review_required response. "
|
||||
"The block will execute with the original pre-approved input data."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"review_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The review_id from a previous review_required response. "
|
||||
"This resumes execution with the pre-approved input data."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["review_id"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
review_id = (
|
||||
kwargs.get("review_id", "").strip() if kwargs.get("review_id") else ""
|
||||
)
|
||||
session_id = session.session_id
|
||||
|
||||
if not review_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide a review_id", session_id=session_id
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required", session_id=session_id
|
||||
)
|
||||
|
||||
# Look up and validate the review record via adapter
|
||||
reviews = await review_db().get_reviews_by_node_exec_ids([review_id], user_id)
|
||||
review = reviews.get(review_id)
|
||||
|
||||
if not review:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Review '{review_id}' not found or already executed. "
|
||||
"It may have been consumed by a previous continue_run_block call."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Validate the review belongs to this session
|
||||
expected_graph_exec_id = f"{COPILOT_SESSION_PREFIX}{session_id}"
|
||||
if review.graph_exec_id != expected_graph_exec_id:
|
||||
return ErrorResponse(
|
||||
message="Review does not belong to this session.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if review.status == ReviewStatus.WAITING:
|
||||
return ErrorResponse(
|
||||
message="Review has not been approved yet. "
|
||||
"Please wait for the user to approve the review first.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if review.status == ReviewStatus.REJECTED:
|
||||
return ErrorResponse(
|
||||
message="Review was rejected. The block will not execute.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Extract block_id from review_id: copilot-node-{block_id}:{random_hex}
|
||||
block_id = parse_node_id_from_exec_id(review_id).removeprefix(
|
||||
COPILOT_NODE_PREFIX
|
||||
)
|
||||
block = get_block(block_id)
|
||||
if not block:
|
||||
return ErrorResponse(
|
||||
message=f"Block '{block_id}' not found", session_id=session_id
|
||||
)
|
||||
|
||||
input_data: dict[str, Any] = (
|
||||
review.payload if isinstance(review.payload, dict) else {}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Continuing block {block.name} ({block_id}) for user {user_id} "
|
||||
f"with review_id={review_id}"
|
||||
)
|
||||
|
||||
matched_creds, missing_creds = await resolve_block_credentials(
|
||||
user_id, block, input_data
|
||||
)
|
||||
if missing_creds:
|
||||
return ErrorResponse(
|
||||
message=f"Block '{block.name}' requires credentials that are not configured.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id=block_id,
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
node_exec_id=review_id,
|
||||
matched_credentials=matched_creds,
|
||||
)
|
||||
|
||||
# Delete review record after successful execution (one-time use)
|
||||
if result.type != "error":
|
||||
await review_db().delete_review_by_node_exec_id(review_id, user_id)
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,186 @@
|
||||
"""Tests for ContinueRunBlockTool."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import ReviewStatus
|
||||
|
||||
from ._test_data import make_session
|
||||
from .continue_run_block import ContinueRunBlockTool
|
||||
from .models import BlockOutputResponse, ErrorResponse
|
||||
|
||||
_TEST_USER_ID = "test-user-continue"
|
||||
|
||||
|
||||
def _make_review_model(
|
||||
node_exec_id: str,
|
||||
status: ReviewStatus = ReviewStatus.APPROVED,
|
||||
payload: dict | None = None,
|
||||
graph_exec_id: str = "",
|
||||
):
|
||||
"""Create a mock PendingHumanReviewModel."""
|
||||
mock = MagicMock()
|
||||
mock.node_exec_id = node_exec_id
|
||||
mock.status = status
|
||||
mock.payload = payload or {"text": "hello"}
|
||||
mock.graph_exec_id = graph_exec_id
|
||||
return mock
|
||||
|
||||
|
||||
class TestContinueRunBlock:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_missing_review_id_returns_error(self):
|
||||
tool = ContinueRunBlockTool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
review_id="",
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "review_id" in response.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_review_not_found_returns_error(self):
|
||||
tool = ContinueRunBlockTool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_reviews_by_node_exec_ids = AsyncMock(return_value={})
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.continue_run_block.review_db",
|
||||
return_value=mock_db,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
review_id="copilot-node-some-block:abc12345",
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "not found" in response.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_waiting_review_returns_error(self):
|
||||
tool = ContinueRunBlockTool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
review_id = "copilot-node-some-block:abc12345"
|
||||
graph_exec_id = f"copilot-session-{session.session_id}"
|
||||
review = _make_review_model(
|
||||
review_id, status=ReviewStatus.WAITING, graph_exec_id=graph_exec_id
|
||||
)
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_reviews_by_node_exec_ids = AsyncMock(
|
||||
return_value={review_id: review}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.continue_run_block.review_db",
|
||||
return_value=mock_db,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
review_id=review_id,
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "not been approved" in response.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_rejected_review_returns_error(self):
|
||||
tool = ContinueRunBlockTool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
review_id = "copilot-node-some-block:abc12345"
|
||||
graph_exec_id = f"copilot-session-{session.session_id}"
|
||||
review = _make_review_model(
|
||||
review_id, status=ReviewStatus.REJECTED, graph_exec_id=graph_exec_id
|
||||
)
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_reviews_by_node_exec_ids = AsyncMock(
|
||||
return_value={review_id: review}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.continue_run_block.review_db",
|
||||
return_value=mock_db,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
review_id=review_id,
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "rejected" in response.message.lower()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_approved_review_executes_block(self):
|
||||
tool = ContinueRunBlockTool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
review_id = "copilot-node-delete-branch-id:abc12345"
|
||||
graph_exec_id = f"copilot-session-{session.session_id}"
|
||||
input_data = {"repo_url": "https://github.com/test/repo", "branch": "main"}
|
||||
review = _make_review_model(
|
||||
review_id,
|
||||
status=ReviewStatus.APPROVED,
|
||||
payload=input_data,
|
||||
graph_exec_id=graph_exec_id,
|
||||
)
|
||||
|
||||
mock_block = MagicMock()
|
||||
mock_block.name = "Delete Branch"
|
||||
|
||||
async def mock_execute(data, **kwargs):
|
||||
yield "result", "Branch deleted"
|
||||
|
||||
mock_block.execute = mock_execute
|
||||
mock_block.input_schema.get_credentials_fields_info.return_value = []
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="test-workspace-id")
|
||||
)
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_reviews_by_node_exec_ids = AsyncMock(
|
||||
return_value={review_id: review}
|
||||
)
|
||||
mock_db.delete_review_by_node_exec_id = AsyncMock(return_value=1)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.continue_run_block.review_db",
|
||||
return_value=mock_db,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.continue_run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
review_id=review_id,
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert response.success is True
|
||||
assert response.block_name == "Delete Branch"
|
||||
# Verify review was deleted (one-time use)
|
||||
mock_db.delete_review_by_node_exec_id.assert_called_once_with(
|
||||
review_id, _TEST_USER_ID
|
||||
)
|
||||
@@ -1,34 +1,20 @@
|
||||
"""CreateAgentTool - Creates agents from natural language descriptions."""
|
||||
"""CreateAgentTool - Creates agents from pre-built JSON."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
decompose_goal,
|
||||
enrich_library_agents_from_steps,
|
||||
generate_agent,
|
||||
get_user_message_for_error,
|
||||
save_agent_to_library,
|
||||
)
|
||||
from .agent_generator.pipeline import fetch_library_agents, fix_validate_and_save
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
SuggestedGoalResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .models import ErrorResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CreateAgentTool(BaseTool):
|
||||
"""Tool for creating agents from natural language descriptions."""
|
||||
"""Tool for creating agents from pre-built JSON."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -37,15 +23,12 @@ class CreateAgentTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Create a new agent workflow from a natural language description. "
|
||||
"First generates a preview, then saves to library if save=true. "
|
||||
"\n\nWorkflow: (1) Always check find_library_agent first for existing building blocks. "
|
||||
"(2) Call create_agent with description and library_agent_ids. "
|
||||
"(3) If response contains suggested_goal: Present to user, ask for confirmation, "
|
||||
"then call again with the suggested goal if accepted. "
|
||||
"(4) If response contains clarifying_questions: Present to user, collect answers, "
|
||||
"then call again with original description AND answers in the context parameter. "
|
||||
"\n\nThis feedback loop ensures the generated agent matches user intent."
|
||||
"Create a new agent workflow. Pass `agent_json` with the complete "
|
||||
"agent graph JSON you generated using block schemas from find_block. "
|
||||
"The tool validates, auto-fixes, and saves.\n\n"
|
||||
"IMPORTANT: Before calling this tool, search for relevant existing agents "
|
||||
"using find_library_agent that could be used as building blocks. "
|
||||
"Pass their IDs in the library_agent_ids parameter."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -57,34 +40,26 @@ class CreateAgentTool(BaseTool):
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string",
|
||||
"agent_json": {
|
||||
"type": "object",
|
||||
"description": (
|
||||
"Natural language description of what the agent should do. "
|
||||
"Be specific about inputs, outputs, and the workflow steps."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Additional context or answers to previous clarifying questions. "
|
||||
"Include any preferences or constraints mentioned by the user."
|
||||
"The agent JSON to validate and save. "
|
||||
"Must contain 'nodes' and 'links' arrays, and optionally "
|
||||
"'name' and 'description'."
|
||||
),
|
||||
},
|
||||
"library_agent_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"List of library agent IDs to use as building blocks. "
|
||||
"Search for relevant agents using find_library_agent first, "
|
||||
"then pass their IDs here so they can be composed into the new agent."
|
||||
"List of library agent IDs to use as building blocks."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Whether to save the agent to the user's library. "
|
||||
"Default is true. Set to false for preview only."
|
||||
"Whether to save the agent. Default is true. "
|
||||
"Set to false for preview only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
@@ -97,7 +72,7 @@ class CreateAgentTool(BaseTool):
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["description"],
|
||||
"required": ["agent_json"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
@@ -106,278 +81,49 @@ class CreateAgentTool(BaseTool):
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the create_agent tool.
|
||||
|
||||
Flow:
|
||||
1. Decompose the description into steps (may return clarifying questions)
|
||||
2. Generate agent JSON (external service handles fixing and validation)
|
||||
3. Preview or save based on the save parameter
|
||||
"""
|
||||
description = kwargs.get("description", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
library_agent_ids = kwargs.get("library_agent_ids", [])
|
||||
save = kwargs.get("save", True)
|
||||
folder_id = kwargs.get("folder_id")
|
||||
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] START - description_len={len(description)}, "
|
||||
f"library_agent_ids={library_agent_ids}, save={save}, user_id={user_id}, session_id={session_id}"
|
||||
if not agent_json:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Please provide agent_json with the complete agent graph. "
|
||||
"Use find_block to discover blocks, then generate the JSON."
|
||||
),
|
||||
error="missing_agent_json",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
save = kwargs.get("save", True)
|
||||
library_agent_ids = kwargs.get("library_agent_ids", [])
|
||||
folder_id: str | None = kwargs.get("folder_id")
|
||||
|
||||
nodes = agent_json.get("nodes", [])
|
||||
if not nodes:
|
||||
return ErrorResponse(
|
||||
message="The agent JSON has no nodes. An agent needs at least one block.",
|
||||
error="empty_agent",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Ensure top-level fields
|
||||
if "id" not in agent_json:
|
||||
agent_json["id"] = str(uuid.uuid4())
|
||||
if "version" not in agent_json:
|
||||
agent_json["version"] = 1
|
||||
if "is_active" not in agent_json:
|
||||
agent_json["is_active"] = True
|
||||
|
||||
# Fetch library agents for AgentExecutorBlock validation
|
||||
library_agents = await fetch_library_agents(user_id, library_agent_ids)
|
||||
|
||||
return await fix_validate_and_save(
|
||||
agent_json,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
save=save,
|
||||
is_update=False,
|
||||
default_name="Generated Agent",
|
||||
library_agents=library_agents,
|
||||
folder_id=folder_id,
|
||||
)
|
||||
|
||||
if not description:
|
||||
return ErrorResponse(
|
||||
message="Please provide a description of what the agent should do.",
|
||||
error="Missing description parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Fetch library agents by IDs if provided
|
||||
library_agents = None
|
||||
if user_id and library_agent_ids:
|
||||
try:
|
||||
from .agent_generator import get_library_agents_by_ids
|
||||
|
||||
library_agents = await get_library_agents_by_ids(
|
||||
user_id=user_id,
|
||||
agent_ids=library_agent_ids,
|
||||
)
|
||||
logger.debug(
|
||||
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agents by IDs: {e}")
|
||||
|
||||
try:
|
||||
decomposition_result = await decompose_goal(
|
||||
description, context, library_agents
|
||||
)
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] DECOMPOSE - type={decomposition_result.get('type') if decomposition_result else None}, "
|
||||
f"session_id={session_id}"
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
logger.error(
|
||||
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured, session_id={session_id}"
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent generation is not available. "
|
||||
"The Agent Generator service is not configured."
|
||||
),
|
||||
error="service_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if decomposition_result is None:
|
||||
return ErrorResponse(
|
||||
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
|
||||
error="decomposition_failed",
|
||||
details={"description": description[:100]},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if decomposition_result.get("type") == "error":
|
||||
error_msg = decomposition_result.get("error", "Unknown error")
|
||||
error_type = decomposition_result.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="analyze the goal",
|
||||
llm_parse_message="The AI had trouble understanding this request. Please try rephrasing your goal.",
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"decomposition_failed:{error_type}",
|
||||
details={
|
||||
"description": description[:100],
|
||||
"service_error": error_msg,
|
||||
"error_type": error_type,
|
||||
},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if decomposition_result.get("type") == "clarifying_questions":
|
||||
questions = decomposition_result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information to create this agent. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if decomposition_result.get("type") == "unachievable_goal":
|
||||
suggested = decomposition_result.get("suggested_goal", "")
|
||||
reason = decomposition_result.get("reason", "")
|
||||
return SuggestedGoalResponse(
|
||||
message=(
|
||||
f"This goal cannot be accomplished with the available blocks. {reason}"
|
||||
),
|
||||
suggested_goal=suggested,
|
||||
reason=reason,
|
||||
original_goal=description,
|
||||
goal_type="unachievable",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if decomposition_result.get("type") == "vague_goal":
|
||||
suggested = decomposition_result.get("suggested_goal", "")
|
||||
reason = decomposition_result.get(
|
||||
"reason", "The goal needs more specific details"
|
||||
)
|
||||
return SuggestedGoalResponse(
|
||||
message="The goal is too vague to create a specific workflow.",
|
||||
suggested_goal=suggested,
|
||||
reason=reason,
|
||||
original_goal=description,
|
||||
goal_type="vague",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if user_id and library_agents is not None:
|
||||
try:
|
||||
library_agents = await enrich_library_agents_from_steps(
|
||||
user_id=user_id,
|
||||
decomposition_result=decomposition_result,
|
||||
existing_agents=library_agents,
|
||||
include_marketplace=True,
|
||||
)
|
||||
logger.debug(
|
||||
f"After enrichment: {len(library_agents)} total agents for sub-agent composition"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to enrich library agents from steps: {e}")
|
||||
|
||||
try:
|
||||
agent_json = await generate_agent(
|
||||
decomposition_result,
|
||||
library_agents,
|
||||
)
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] GENERATE - "
|
||||
f"success={agent_json is not None}, "
|
||||
f"is_error={isinstance(agent_json, dict) and agent_json.get('type') == 'error'}, "
|
||||
f"session_id={session_id}"
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
logger.error(
|
||||
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured during generation, session_id={session_id}"
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent generation is not available. "
|
||||
"The Agent Generator service is not configured."
|
||||
),
|
||||
error="service_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if agent_json is None:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
|
||||
error="generation_failed",
|
||||
details={"description": description[:100]},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if isinstance(agent_json, dict) and agent_json.get("type") == "error":
|
||||
error_msg = agent_json.get("error", "Unknown error")
|
||||
error_type = agent_json.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="generate the agent",
|
||||
llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
|
||||
validation_message=(
|
||||
"I wasn't able to create a valid agent for this request. "
|
||||
"The generated workflow had some structural issues. "
|
||||
"Please try simplifying your goal or breaking it into smaller steps."
|
||||
),
|
||||
error_details=error_msg,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"generation_failed:{error_type}",
|
||||
details={
|
||||
"description": description[:100],
|
||||
"service_error": error_msg,
|
||||
"error_type": error_type,
|
||||
},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agent_name = agent_json.get("name", "Generated Agent")
|
||||
agent_description = agent_json.get("description", "")
|
||||
node_count = len(agent_json.get("nodes", []))
|
||||
link_count = len(agent_json.get("links", []))
|
||||
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] AGENT_JSON - name={agent_name}, "
|
||||
f"nodes={node_count}, links={link_count}, save={save}, session_id={session_id}"
|
||||
)
|
||||
|
||||
if not save:
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] RETURN - AgentPreviewResponse, session_id={session_id}"
|
||||
)
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
|
||||
f"Review it and call create_agent with save=true to save it to your library."
|
||||
),
|
||||
agent_json=agent_json,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
agent_json, user_id, folder_id=folder_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] SAVED - graph_id={created_graph.id}, "
|
||||
f"library_agent_id={library_agent.id}, session_id={session_id}"
|
||||
)
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] RETURN - AgentSavedResponse, session_id={session_id}"
|
||||
)
|
||||
return AgentSavedResponse(
|
||||
message=f"Agent '{created_graph.name}' has been saved to your library!",
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[AGENT_CREATE_DEBUG] ERROR - save_failed: {str(e)}, session_id={session_id}"
|
||||
)
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] RETURN - ErrorResponse (save_failed), session_id={session_id}"
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to save the agent: {str(e)}",
|
||||
error="save_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
"""Tests for CreateAgentTool response types."""
|
||||
"""Tests for CreateAgentTool."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.tools.create_agent import CreateAgentTool
|
||||
from backend.copilot.tools.models import (
|
||||
ClarificationNeededResponse,
|
||||
ErrorResponse,
|
||||
SuggestedGoalResponse,
|
||||
)
|
||||
from backend.copilot.tools.models import AgentPreviewResponse, ErrorResponse
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
_TEST_USER_ID = "test-user-create-agent"
|
||||
_PIPELINE = "backend.copilot.tools.agent_generator.pipeline"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -26,102 +23,147 @@ def session():
|
||||
return make_session(_TEST_USER_ID)
|
||||
|
||||
|
||||
# ── Input validation tests ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_description_returns_error(tool, session):
|
||||
"""Missing description returns ErrorResponse."""
|
||||
result = await tool._execute(user_id=_TEST_USER_ID, session=session, description="")
|
||||
async def test_missing_agent_json_returns_error(tool, session):
|
||||
"""Missing agent_json returns ErrorResponse."""
|
||||
result = await tool._execute(user_id=_TEST_USER_ID, session=session)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "Missing description parameter"
|
||||
assert result.error == "missing_agent_json"
|
||||
|
||||
|
||||
# ── Local mode tests ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vague_goal_returns_suggested_goal_response(tool, session):
|
||||
"""vague_goal decomposition result returns SuggestedGoalResponse, not ErrorResponse."""
|
||||
vague_result = {
|
||||
"type": "vague_goal",
|
||||
"suggested_goal": "Monitor Twitter mentions for a specific keyword and send a daily digest email",
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.decompose_goal",
|
||||
new_callable=AsyncMock,
|
||||
return_value=vague_result,
|
||||
),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
description="monitor social media",
|
||||
)
|
||||
|
||||
assert isinstance(result, SuggestedGoalResponse)
|
||||
assert result.goal_type == "vague"
|
||||
assert result.suggested_goal == vague_result["suggested_goal"]
|
||||
assert result.original_goal == "monitor social media"
|
||||
assert result.reason == "The goal needs more specific details"
|
||||
assert not isinstance(result, ErrorResponse)
|
||||
async def test_local_mode_empty_nodes_returns_error(tool, session):
|
||||
"""Local mode with no nodes returns ErrorResponse."""
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json={"nodes": [], "links": []},
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "no nodes" in result.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unachievable_goal_returns_suggested_goal_response(tool, session):
|
||||
"""unachievable_goal decomposition result returns SuggestedGoalResponse, not ErrorResponse."""
|
||||
unachievable_result = {
|
||||
"type": "unachievable_goal",
|
||||
"suggested_goal": "Summarize the latest news articles on a topic and send them by email",
|
||||
"reason": "There are no blocks for mind-reading.",
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.decompose_goal",
|
||||
new_callable=AsyncMock,
|
||||
return_value=unachievable_result,
|
||||
),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
description="read my mind",
|
||||
)
|
||||
|
||||
assert isinstance(result, SuggestedGoalResponse)
|
||||
assert result.goal_type == "unachievable"
|
||||
assert result.suggested_goal == unachievable_result["suggested_goal"]
|
||||
assert result.original_goal == "read my mind"
|
||||
assert result.reason == unachievable_result["reason"]
|
||||
assert not isinstance(result, ErrorResponse)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clarifying_questions_returns_clarification_needed_response(
|
||||
tool, session
|
||||
):
|
||||
"""clarifying_questions decomposition result returns ClarificationNeededResponse."""
|
||||
clarifying_result = {
|
||||
"type": "clarifying_questions",
|
||||
"questions": [
|
||||
async def test_local_mode_preview(tool, session):
|
||||
"""Local mode with save=false returns AgentPreviewResponse."""
|
||||
agent_json = {
|
||||
"name": "Test Agent",
|
||||
"description": "A test agent",
|
||||
"nodes": [
|
||||
{
|
||||
"question": "What platform should be monitored?",
|
||||
"keyword": "platform",
|
||||
"example": "Twitter, Reddit",
|
||||
"id": "node-1",
|
||||
"block_id": "block-1",
|
||||
"input_default": {},
|
||||
"metadata": {"position": {"x": 0, "y": 0}},
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
}
|
||||
|
||||
mock_fixer = MagicMock()
|
||||
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
|
||||
mock_fixer.get_fixes_applied.return_value = []
|
||||
|
||||
mock_validator = MagicMock()
|
||||
mock_validator.validate.return_value = (True, None)
|
||||
mock_validator.errors = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.decompose_goal",
|
||||
new_callable=AsyncMock,
|
||||
return_value=clarifying_result,
|
||||
),
|
||||
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
|
||||
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
|
||||
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
description="monitor social media and alert me",
|
||||
agent_json=agent_json,
|
||||
save=False,
|
||||
)
|
||||
|
||||
assert isinstance(result, ClarificationNeededResponse)
|
||||
assert len(result.questions) == 1
|
||||
assert result.questions[0].keyword == "platform"
|
||||
assert isinstance(result, AgentPreviewResponse)
|
||||
assert result.agent_name == "Test Agent"
|
||||
assert result.node_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_mode_validation_failure(tool, session):
|
||||
"""Local mode returns ErrorResponse when validation fails after fixing."""
|
||||
agent_json = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"block_id": "bad-block",
|
||||
"input_default": {},
|
||||
"metadata": {},
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
}
|
||||
|
||||
mock_fixer = MagicMock()
|
||||
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
|
||||
mock_fixer.get_fixes_applied.return_value = []
|
||||
|
||||
mock_validator = MagicMock()
|
||||
mock_validator.validate.return_value = (False, "Block 'bad-block' not found")
|
||||
mock_validator.errors = ["Block 'bad-block' not found"]
|
||||
|
||||
with (
|
||||
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
|
||||
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
|
||||
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json=agent_json,
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "validation_failed"
|
||||
assert "Block 'bad-block' not found" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_mode_no_auth_returns_error(tool, session):
|
||||
"""Local mode with save=true and no user returns ErrorResponse."""
|
||||
agent_json = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"block_id": "block-1",
|
||||
"input_default": {},
|
||||
"metadata": {},
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
}
|
||||
|
||||
mock_fixer = MagicMock()
|
||||
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
|
||||
mock_fixer.get_fixes_applied.return_value = []
|
||||
|
||||
mock_validator = MagicMock()
|
||||
mock_validator.validate.return_value = (True, None)
|
||||
mock_validator.errors = []
|
||||
|
||||
with (
|
||||
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
|
||||
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
|
||||
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=None,
|
||||
session=session,
|
||||
agent_json=agent_json,
|
||||
save=True,
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "logged in" in result.message.lower()
|
||||
|
||||
@@ -1,34 +1,20 @@
|
||||
"""CustomizeAgentTool - Customizes marketplace/template agents using natural language."""
|
||||
"""CustomizeAgentTool - Customizes marketplace/template agents."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import store_db as get_store_db
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
customize_template,
|
||||
get_user_message_for_error,
|
||||
graph_to_json,
|
||||
save_agent_to_library,
|
||||
)
|
||||
from .agent_generator.pipeline import fetch_library_agents, fix_validate_and_save
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .models import ErrorResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomizeAgentTool(BaseTool):
|
||||
"""Tool for customizing marketplace/template agents using natural language."""
|
||||
"""Tool for customizing marketplace/template agents."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -37,9 +23,9 @@ class CustomizeAgentTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Customize a marketplace or template agent using natural language. "
|
||||
"Takes an existing agent from the marketplace and modifies it based on "
|
||||
"the user's requirements before adding to their library."
|
||||
"Customize a marketplace or template agent. Pass `agent_json` "
|
||||
"with the complete customized agent JSON. The tool validates, "
|
||||
"auto-fixes, and saves."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -51,32 +37,24 @@ class CustomizeAgentTool(BaseTool):
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"agent_json": {
|
||||
"type": "object",
|
||||
"description": (
|
||||
"The marketplace agent ID in format 'creator/slug' "
|
||||
"(e.g., 'autogpt/newsletter-writer'). "
|
||||
"Get this from find_agent results."
|
||||
"Complete customized agent JSON to validate and save. "
|
||||
"Optionally include 'name' and 'description'."
|
||||
),
|
||||
},
|
||||
"modifications": {
|
||||
"type": "string",
|
||||
"library_agent_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"Natural language description of how to customize the agent. "
|
||||
"Be specific about what changes you want to make."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Additional context or answers to previous clarifying questions."
|
||||
"List of library agent IDs to use as building blocks."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Whether to save the customized agent to the user's library. "
|
||||
"Default is true. Set to false for preview only."
|
||||
"Whether to save the customized agent. Default is true."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
@@ -89,7 +67,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["agent_id", "modifications"],
|
||||
"required": ["agent_json"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
@@ -98,247 +76,46 @@ class CustomizeAgentTool(BaseTool):
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the customize_agent tool.
|
||||
|
||||
Flow:
|
||||
1. Parse the agent ID to get creator/slug
|
||||
2. Fetch the template agent from the marketplace
|
||||
3. Call customize_template with the modification request
|
||||
4. Preview or save based on the save parameter
|
||||
"""
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
modifications = kwargs.get("modifications", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
folder_id = kwargs.get("folder_id")
|
||||
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
|
||||
error="missing_agent_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not modifications:
|
||||
return ErrorResponse(
|
||||
message="Please describe how you want to customize this agent.",
|
||||
error="missing_modifications",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Parse agent_id in format "creator/slug"
|
||||
parts = [p.strip() for p in agent_id.split("/")]
|
||||
if len(parts) != 2 or not parts[0] or not parts[1]:
|
||||
if not agent_json:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Invalid agent ID format: '{agent_id}'. "
|
||||
"Expected format is 'creator/agent-name' "
|
||||
"(e.g., 'autogpt/newsletter-writer')."
|
||||
"Please provide agent_json with the complete customized agent graph."
|
||||
),
|
||||
error="invalid_agent_id_format",
|
||||
error="missing_agent_json",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
creator_username, agent_slug = parts
|
||||
save = kwargs.get("save", True)
|
||||
library_agent_ids = kwargs.get("library_agent_ids", [])
|
||||
folder_id: str | None = kwargs.get("folder_id")
|
||||
|
||||
store_db = get_store_db()
|
||||
|
||||
# Fetch the marketplace agent details
|
||||
try:
|
||||
agent_details = await store_db.get_store_agent_details(
|
||||
username=creator_username, agent_name=agent_slug
|
||||
)
|
||||
except NotFoundError:
|
||||
nodes = agent_json.get("nodes", [])
|
||||
if not nodes:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Could not find marketplace agent '{agent_id}'. "
|
||||
"Please check the agent ID and try again."
|
||||
),
|
||||
error="agent_not_found",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message="Failed to fetch the marketplace agent. Please try again.",
|
||||
error="fetch_error",
|
||||
message="The agent JSON has no nodes.",
|
||||
error="empty_agent",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not agent_details.store_listing_version_id:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"The agent '{agent_id}' does not have an available version. "
|
||||
"Please try a different agent."
|
||||
),
|
||||
error="no_version_available",
|
||||
session_id=session_id,
|
||||
)
|
||||
# Ensure top-level fields before the fixer pipeline
|
||||
if "id" not in agent_json:
|
||||
agent_json["id"] = str(uuid.uuid4())
|
||||
agent_json.setdefault("version", 1)
|
||||
agent_json.setdefault("is_active", True)
|
||||
|
||||
# Get the full agent graph
|
||||
try:
|
||||
graph = await store_db.get_agent(agent_details.store_listing_version_id)
|
||||
template_agent = graph_to_json(graph)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message="Failed to fetch the agent configuration. Please try again.",
|
||||
error="graph_fetch_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
# Fetch library agents for AgentExecutorBlock validation
|
||||
library_agents = await fetch_library_agents(user_id, library_agent_ids)
|
||||
|
||||
# Call customize_template
|
||||
try:
|
||||
result = await customize_template(
|
||||
template_agent=template_agent,
|
||||
modification_request=modifications,
|
||||
context=context,
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent customization is not available. "
|
||||
"The Agent Generator service is not configured."
|
||||
),
|
||||
error="service_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling customize_template for {agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Failed to customize the agent due to a service error. "
|
||||
"Please try again."
|
||||
),
|
||||
error="customization_service_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Failed to customize the agent. "
|
||||
"The agent generation service may be unavailable or timed out. "
|
||||
"Please try again."
|
||||
),
|
||||
error="customization_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Handle error response
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
error_type = result.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="customize the agent",
|
||||
llm_parse_message=(
|
||||
"The AI had trouble customizing the agent. "
|
||||
"Please try again or simplify your request."
|
||||
),
|
||||
validation_message=(
|
||||
"The customized agent failed validation. "
|
||||
"Please try rephrasing your request."
|
||||
),
|
||||
error_details=error_msg,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"customization_failed:{error_type}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Handle clarifying questions
|
||||
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
|
||||
questions = result.get("questions") or []
|
||||
if not isinstance(questions, list):
|
||||
logger.error(
|
||||
f"Unexpected clarifying questions format: {type(questions)}"
|
||||
)
|
||||
questions = []
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information to customize this agent. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
if isinstance(q, dict)
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Result should be the customized agent JSON
|
||||
if not isinstance(result, dict):
|
||||
logger.error(f"Unexpected customize_template response type: {type(result)}")
|
||||
return ErrorResponse(
|
||||
message="Failed to customize the agent due to an unexpected response.",
|
||||
error="unexpected_response_type",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
customized_agent = result
|
||||
|
||||
agent_name = customized_agent.get(
|
||||
"name", f"Customized {agent_details.agent_name}"
|
||||
return await fix_validate_and_save(
|
||||
agent_json,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
save=save,
|
||||
is_update=False,
|
||||
default_name="Customized Agent",
|
||||
library_agents=library_agents,
|
||||
folder_id=folder_id,
|
||||
)
|
||||
agent_description = customized_agent.get("description", "")
|
||||
nodes = customized_agent.get("nodes")
|
||||
links = customized_agent.get("links")
|
||||
node_count = len(nodes) if isinstance(nodes, list) else 0
|
||||
link_count = len(links) if isinstance(links, list) else 0
|
||||
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've customized the agent '{agent_details.agent_name}'. "
|
||||
f"The customized agent has {node_count} blocks. "
|
||||
f"Review it and call customize_agent with save=true to save it."
|
||||
),
|
||||
agent_json=customized_agent,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Save to user's library
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
customized_agent, user_id, is_update=False, folder_id=folder_id
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
message=(
|
||||
f"Customized agent '{created_graph.name}' "
|
||||
f"(based on '{agent_details.agent_name}') "
|
||||
f"has been saved to your library!"
|
||||
),
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving customized agent: {e}")
|
||||
return ErrorResponse(
|
||||
message="Failed to save the customized agent. Please try again.",
|
||||
error="save_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,172 @@
|
||||
"""Tests for CustomizeAgentTool local mode."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.tools.customize_agent import CustomizeAgentTool
|
||||
from backend.copilot.tools.models import AgentPreviewResponse, ErrorResponse
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
_TEST_USER_ID = "test-user-customize-agent"
|
||||
_PIPELINE = "backend.copilot.tools.agent_generator.pipeline"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return CustomizeAgentTool()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session():
|
||||
return make_session(_TEST_USER_ID)
|
||||
|
||||
|
||||
# ── Input validation tests ───────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_agent_json_returns_error(tool, session):
|
||||
"""Missing agent_json returns ErrorResponse."""
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_agent_json"
|
||||
|
||||
|
||||
# ── Local mode tests (agent_json provided) ───────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_mode_empty_nodes_returns_error(tool, session):
|
||||
"""Local mode with no nodes returns ErrorResponse."""
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json={"nodes": [], "links": []},
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "no nodes" in result.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_mode_preview(tool, session):
|
||||
"""Local mode with save=false returns AgentPreviewResponse."""
|
||||
agent_json = {
|
||||
"name": "Customized Agent",
|
||||
"description": "A customized agent",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"block_id": "block-1",
|
||||
"input_default": {},
|
||||
"metadata": {"position": {"x": 0, "y": 0}},
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
}
|
||||
|
||||
mock_fixer = MagicMock()
|
||||
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
|
||||
mock_fixer.get_fixes_applied.return_value = []
|
||||
|
||||
mock_validator = MagicMock()
|
||||
mock_validator.validate.return_value = (True, None)
|
||||
mock_validator.errors = []
|
||||
|
||||
with (
|
||||
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
|
||||
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
|
||||
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json=agent_json,
|
||||
save=False,
|
||||
)
|
||||
|
||||
assert isinstance(result, AgentPreviewResponse)
|
||||
assert result.agent_name == "Customized Agent"
|
||||
assert result.node_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_mode_validation_failure(tool, session):
|
||||
"""Local mode returns ErrorResponse when validation fails."""
|
||||
agent_json = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"block_id": "bad-block",
|
||||
"input_default": {},
|
||||
"metadata": {},
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
}
|
||||
|
||||
mock_fixer = MagicMock()
|
||||
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
|
||||
mock_fixer.get_fixes_applied.return_value = []
|
||||
|
||||
mock_validator = MagicMock()
|
||||
mock_validator.validate.return_value = (False, "Block 'bad-block' not found")
|
||||
mock_validator.errors = ["Block 'bad-block' not found"]
|
||||
|
||||
with (
|
||||
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
|
||||
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
|
||||
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json=agent_json,
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "validation_failed"
|
||||
assert "Block 'bad-block' not found" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_mode_no_auth_returns_error(tool, session):
|
||||
"""Local mode with save=true and no user returns ErrorResponse."""
|
||||
agent_json = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"block_id": "block-1",
|
||||
"input_default": {},
|
||||
"metadata": {},
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
}
|
||||
|
||||
mock_fixer = MagicMock()
|
||||
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
|
||||
mock_fixer.get_fixes_applied.return_value = []
|
||||
|
||||
mock_validator = MagicMock()
|
||||
mock_validator.validate.return_value = (True, None)
|
||||
mock_validator.errors = []
|
||||
|
||||
with (
|
||||
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
|
||||
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
|
||||
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=None,
|
||||
session=session,
|
||||
agent_json=agent_json,
|
||||
save=True,
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "logged in" in result.message.lower()
|
||||
@@ -10,13 +10,34 @@ Lifecycle
|
||||
---------
|
||||
1. **Turn start** – connect to the existing sandbox (sandbox_id in Redis) or
|
||||
create a new one via ``get_or_create_sandbox()``.
|
||||
``connect()`` in e2b v2 auto-resumes paused sandboxes.
|
||||
2. **Execution** – ``bash_exec`` and MCP file tools operate directly on the
|
||||
sandbox's ``/home/user`` filesystem.
|
||||
3. **Session expiry** – E2B sandbox is killed by its own timeout (session_ttl).
|
||||
3. **Turn end** – the sandbox is paused via ``pause_sandbox()`` (fire-and-forget)
|
||||
so idle time between turns costs nothing. Paused sandboxes have no compute
|
||||
cost.
|
||||
4. **Session delete** – ``kill_sandbox()`` fully terminates the sandbox.
|
||||
|
||||
Cost control
|
||||
------------
|
||||
Sandboxes are created with a configurable ``on_timeout`` lifecycle action
|
||||
(default: ``"pause"``). The explicit per-turn ``pause_sandbox()`` call is the
|
||||
primary mechanism; the lifecycle setting is a safety net. Paused sandboxes are
|
||||
free.
|
||||
|
||||
The sandbox_id is stored in Redis. The same key doubles as a creation lock:
|
||||
a ``"creating"`` sentinel value is written with a short TTL while a new sandbox
|
||||
is being provisioned, preventing duplicate creation under concurrent requests.
|
||||
|
||||
E2B project-level "paused sandbox lifetime" should be set to match
|
||||
``_SANDBOX_ID_TTL`` (48 h) so orphaned paused sandboxes are auto-killed before
|
||||
the Redis key expires.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Literal
|
||||
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
@@ -24,147 +45,245 @@ from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SANDBOX_REDIS_PREFIX = "copilot:e2b:sandbox:"
|
||||
E2B_WORKDIR = "/home/user"
|
||||
_CREATING = "__creating__"
|
||||
_CREATION_LOCK_TTL = 60
|
||||
_MAX_WAIT_ATTEMPTS = 20 # 20 * 0.5s = 10s max wait
|
||||
_SANDBOX_KEY_PREFIX = "copilot:e2b:sandbox:"
|
||||
_CREATING_SENTINEL = "creating"
|
||||
|
||||
# Short TTL for the "creating" sentinel — if the process dies mid-creation the
|
||||
# lock auto-expires so other callers are not blocked forever.
|
||||
_CREATION_LOCK_TTL = 60 # seconds
|
||||
|
||||
_MAX_WAIT_ATTEMPTS = 20 # 20 × 0.5 s = 10 s max wait
|
||||
|
||||
# Timeout for E2B API calls (pause/kill) — short because these are control-plane
|
||||
# operations; if the sandbox is unreachable, fail fast and retry on the next turn.
|
||||
_E2B_API_TIMEOUT_SECONDS = 10
|
||||
|
||||
# Redis TTL for the sandbox key. Must be ≥ the E2B project "paused sandbox
|
||||
# lifetime" setting (recommended: set both to 48 h).
|
||||
_SANDBOX_ID_TTL = 48 * 3600 # 48 hours
|
||||
|
||||
|
||||
def _sandbox_key(session_id: str) -> str:
|
||||
return f"{_SANDBOX_KEY_PREFIX}{session_id}"
|
||||
|
||||
|
||||
async def _get_stored_sandbox_id(session_id: str) -> str | None:
|
||||
redis = await get_redis_async()
|
||||
raw = await redis.get(_sandbox_key(session_id))
|
||||
value = raw.decode() if isinstance(raw, bytes) else raw
|
||||
return None if value == _CREATING_SENTINEL else value
|
||||
|
||||
|
||||
async def _set_stored_sandbox_id(session_id: str, sandbox_id: str) -> None:
|
||||
redis = await get_redis_async()
|
||||
await redis.set(_sandbox_key(session_id), sandbox_id, ex=_SANDBOX_ID_TTL)
|
||||
|
||||
|
||||
async def _clear_stored_sandbox_id(session_id: str) -> None:
|
||||
redis = await get_redis_async()
|
||||
await redis.delete(_sandbox_key(session_id))
|
||||
|
||||
|
||||
async def _try_reconnect(
|
||||
sandbox_id: str, api_key: str, redis_key: str, timeout: int
|
||||
sandbox_id: str, session_id: str, api_key: str
|
||||
) -> "AsyncSandbox | None":
|
||||
"""Try to reconnect to an existing sandbox. Returns None on failure."""
|
||||
try:
|
||||
sandbox = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
|
||||
if await sandbox.is_running():
|
||||
redis = await get_redis_async()
|
||||
await redis.expire(redis_key, timeout)
|
||||
# Refresh TTL so an active session cannot lose its sandbox_id at expiry.
|
||||
await _set_stored_sandbox_id(session_id, sandbox_id)
|
||||
return sandbox
|
||||
except Exception as exc:
|
||||
logger.warning("[E2B] Reconnect to %.12s failed: %s", sandbox_id, exc)
|
||||
|
||||
# Stale — clear Redis so a new sandbox can be created.
|
||||
redis = await get_redis_async()
|
||||
await redis.delete(redis_key)
|
||||
# Stale — clear the sandbox_id from Redis so a new one can be created.
|
||||
await _clear_stored_sandbox_id(session_id)
|
||||
return None
|
||||
|
||||
|
||||
async def get_or_create_sandbox(
|
||||
session_id: str,
|
||||
api_key: str,
|
||||
timeout: int,
|
||||
template: str = "base",
|
||||
timeout: int = 43200,
|
||||
on_timeout: Literal["kill", "pause"] = "pause",
|
||||
) -> AsyncSandbox:
|
||||
"""Return the existing E2B sandbox for *session_id* or create a new one.
|
||||
|
||||
The sandbox_id is persisted in Redis so the same sandbox is reused
|
||||
across turns. Concurrent calls for the same session are serialised
|
||||
via a Redis ``SET NX`` creation lock.
|
||||
The sandbox key in Redis serves a dual purpose: it stores the sandbox_id
|
||||
and acts as a creation lock via a ``"creating"`` sentinel value. This
|
||||
removes the need for a separate lock key.
|
||||
|
||||
*timeout* controls how long the e2b sandbox may run continuously before
|
||||
the ``on_timeout`` lifecycle rule fires (default: 3 h).
|
||||
*on_timeout* controls what happens on timeout: ``"pause"`` (default, free)
|
||||
or ``"kill"``.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
|
||||
key = _sandbox_key(session_id)
|
||||
|
||||
# 1. Try reconnecting to an existing sandbox.
|
||||
raw = await redis.get(redis_key)
|
||||
if raw:
|
||||
sandbox_id = raw if isinstance(raw, str) else raw.decode()
|
||||
if sandbox_id != _CREATING:
|
||||
sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
|
||||
for _ in range(_MAX_WAIT_ATTEMPTS):
|
||||
raw = await redis.get(key)
|
||||
value = raw.decode() if isinstance(raw, bytes) else raw
|
||||
|
||||
if value and value != _CREATING_SENTINEL:
|
||||
# Existing sandbox ID — try to reconnect (auto-resumes if paused).
|
||||
sandbox = await _try_reconnect(value, session_id, api_key)
|
||||
if sandbox:
|
||||
logger.info(
|
||||
"[E2B] Reconnected to %.12s for session %.12s",
|
||||
sandbox_id,
|
||||
value,
|
||||
session_id,
|
||||
)
|
||||
return sandbox
|
||||
# _try_reconnect cleared the key — loop to create a new sandbox.
|
||||
continue
|
||||
|
||||
# 2. Claim creation lock. If another request holds it, wait for the result.
|
||||
claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
|
||||
if not claimed:
|
||||
for _ in range(_MAX_WAIT_ATTEMPTS):
|
||||
if value == _CREATING_SENTINEL:
|
||||
# Another coroutine is creating — wait for it to finish.
|
||||
await asyncio.sleep(0.5)
|
||||
raw = await redis.get(redis_key)
|
||||
if not raw:
|
||||
break # Lock expired — fall through to retry creation
|
||||
sandbox_id = raw if isinstance(raw, str) else raw.decode()
|
||||
if sandbox_id != _CREATING:
|
||||
sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
|
||||
if sandbox:
|
||||
return sandbox
|
||||
break # Stale sandbox cleared — fall through to create
|
||||
continue
|
||||
|
||||
# Try to claim creation lock again after waiting.
|
||||
claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
|
||||
if not claimed:
|
||||
# Another process may have created a sandbox — try to use it.
|
||||
raw = await redis.get(redis_key)
|
||||
if raw:
|
||||
sandbox_id = raw if isinstance(raw, str) else raw.decode()
|
||||
if sandbox_id != _CREATING:
|
||||
sandbox = await _try_reconnect(
|
||||
sandbox_id, api_key, redis_key, timeout
|
||||
)
|
||||
if sandbox:
|
||||
return sandbox
|
||||
raise RuntimeError(
|
||||
f"Could not acquire E2B creation lock for session {session_id[:12]}"
|
||||
)
|
||||
|
||||
# 3. Create a new sandbox.
|
||||
try:
|
||||
sandbox = await AsyncSandbox.create(
|
||||
template=template, api_key=api_key, timeout=timeout
|
||||
# No sandbox and no active creation — atomically claim the creation slot.
|
||||
claimed = await redis.set(
|
||||
key, _CREATING_SENTINEL, nx=True, ex=_CREATION_LOCK_TTL
|
||||
)
|
||||
except Exception:
|
||||
await redis.delete(redis_key)
|
||||
raise
|
||||
if not claimed:
|
||||
# Race lost — another coroutine just claimed it.
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
await redis.setex(redis_key, timeout, sandbox.sandbox_id)
|
||||
logger.info(
|
||||
"[E2B] Created sandbox %.12s for session %.12s",
|
||||
sandbox.sandbox_id,
|
||||
session_id,
|
||||
)
|
||||
return sandbox
|
||||
# We hold the slot — create the sandbox.
|
||||
try:
|
||||
sandbox = await AsyncSandbox.create(
|
||||
template=template,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
lifecycle={"on_timeout": on_timeout},
|
||||
)
|
||||
try:
|
||||
await _set_stored_sandbox_id(session_id, sandbox.sandbox_id)
|
||||
except Exception:
|
||||
# Redis save failed — kill the sandbox to avoid leaking it.
|
||||
with contextlib.suppress(Exception):
|
||||
await sandbox.kill()
|
||||
raise
|
||||
except Exception:
|
||||
# Release the creation slot so other callers can proceed.
|
||||
await redis.delete(key)
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
"[E2B] Created sandbox %.12s for session %.12s",
|
||||
sandbox.sandbox_id,
|
||||
session_id,
|
||||
)
|
||||
return sandbox
|
||||
|
||||
raise RuntimeError(f"Could not acquire E2B sandbox for session {session_id[:12]}")
|
||||
|
||||
|
||||
async def kill_sandbox(session_id: str, api_key: str) -> bool:
|
||||
"""Kill the E2B sandbox for *session_id* and clean up its Redis entry.
|
||||
async def _act_on_sandbox(
|
||||
session_id: str,
|
||||
api_key: str,
|
||||
action: str,
|
||||
fn: Callable[[AsyncSandbox], Awaitable[Any]],
|
||||
*,
|
||||
clear_stored_id: bool = False,
|
||||
) -> bool:
|
||||
"""Connect to the sandbox for *session_id* and run *fn* on it.
|
||||
|
||||
Returns ``True`` if a sandbox was found and killed, ``False`` otherwise.
|
||||
Safe to call even when no sandbox exists for the session.
|
||||
Shared by ``pause_sandbox`` and ``kill_sandbox``. Returns ``True`` on
|
||||
success, ``False`` when no sandbox is found or the action fails.
|
||||
If *clear_stored_id* is ``True``, the sandbox_id is removed from Redis
|
||||
only after the action succeeds so a failed kill can be retried.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
|
||||
raw = await redis.get(redis_key)
|
||||
if not raw:
|
||||
sandbox_id = await _get_stored_sandbox_id(session_id)
|
||||
if not sandbox_id:
|
||||
return False
|
||||
|
||||
sandbox_id = raw if isinstance(raw, str) else raw.decode()
|
||||
await redis.delete(redis_key)
|
||||
|
||||
if sandbox_id == _CREATING:
|
||||
return False
|
||||
async def _run() -> None:
|
||||
await fn(await AsyncSandbox.connect(sandbox_id, api_key=api_key))
|
||||
|
||||
try:
|
||||
|
||||
async def _connect_and_kill():
|
||||
sandbox = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
|
||||
await sandbox.kill()
|
||||
|
||||
await asyncio.wait_for(_connect_and_kill(), timeout=10)
|
||||
await asyncio.wait_for(_run(), timeout=_E2B_API_TIMEOUT_SECONDS)
|
||||
if clear_stored_id:
|
||||
await _clear_stored_sandbox_id(session_id)
|
||||
logger.info(
|
||||
"[E2B] Killed sandbox %.12s for session %.12s",
|
||||
"[E2B] %s sandbox %.12s for session %.12s",
|
||||
action.capitalize(),
|
||||
sandbox_id,
|
||||
session_id,
|
||||
)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[E2B] Failed to kill sandbox %.12s for session %.12s: %s",
|
||||
"[E2B] Failed to %s sandbox %.12s for session %.12s: %s",
|
||||
action,
|
||||
sandbox_id,
|
||||
session_id,
|
||||
exc,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def pause_sandbox(session_id: str, api_key: str) -> bool:
|
||||
"""Pause the E2B sandbox for *session_id* to stop billing between turns.
|
||||
|
||||
Paused sandboxes cost nothing and are resumed automatically by
|
||||
``get_or_create_sandbox()`` on the next turn (via ``AsyncSandbox.connect()``).
|
||||
The sandbox_id is kept in Redis so reconnection works seamlessly.
|
||||
|
||||
Prefer ``pause_sandbox_direct()`` when the sandbox object is already in
|
||||
scope — it skips the Redis lookup and reconnect round-trip.
|
||||
|
||||
Returns ``True`` if the sandbox was found and paused, ``False`` otherwise.
|
||||
Safe to call even when no sandbox exists for the session.
|
||||
"""
|
||||
return await _act_on_sandbox(session_id, api_key, "pause", lambda sb: sb.pause())
|
||||
|
||||
|
||||
async def pause_sandbox_direct(sandbox: "AsyncSandbox", session_id: str) -> bool:
|
||||
"""Pause an already-connected sandbox without a reconnect round-trip.
|
||||
|
||||
Use this in callers that already hold the live sandbox object (e.g. turn
|
||||
teardown in ``service.py``). Saves the Redis lookup and
|
||||
``AsyncSandbox.connect()`` call that ``pause_sandbox()`` would make.
|
||||
|
||||
Returns ``True`` on success, ``False`` on failure or timeout.
|
||||
"""
|
||||
try:
|
||||
await asyncio.wait_for(sandbox.pause(), timeout=_E2B_API_TIMEOUT_SECONDS)
|
||||
logger.info(
|
||||
"[E2B] Paused sandbox %.12s for session %.12s",
|
||||
sandbox.sandbox_id,
|
||||
session_id,
|
||||
)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[E2B] Failed to pause sandbox %.12s for session %.12s: %s",
|
||||
sandbox.sandbox_id,
|
||||
session_id,
|
||||
exc,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def kill_sandbox(
|
||||
session_id: str,
|
||||
api_key: str,
|
||||
) -> bool:
|
||||
"""Kill the E2B sandbox for *session_id* and clear its Redis entry.
|
||||
|
||||
Returns ``True`` if a sandbox was found and killed, ``False`` otherwise.
|
||||
Safe to call even when no sandbox exists for the session.
|
||||
"""
|
||||
return await _act_on_sandbox(
|
||||
session_id,
|
||||
api_key,
|
||||
"kill",
|
||||
lambda sb: sb.kill(),
|
||||
clear_stored_id=True,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
"""Tests for e2b_sandbox: get_or_create_sandbox, _try_reconnect, kill_sandbox.
|
||||
|
||||
Uses mock Redis and mock AsyncSandbox — no external dependencies.
|
||||
sandbox_id is stored in Redis under _SANDBOX_KEY_PREFIX + session_id.
|
||||
The same key doubles as a creation lock via a "creating" sentinel value.
|
||||
|
||||
Tests mock:
|
||||
- ``get_redis_async`` (sandbox key storage + creation lock sentinel)
|
||||
- ``AsyncSandbox`` (E2B SDK)
|
||||
|
||||
Tests are synchronous (using asyncio.run) to avoid conflicts with the
|
||||
session-scoped event loop in conftest.py.
|
||||
"""
|
||||
@@ -11,36 +17,50 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from .e2b_sandbox import (
|
||||
_CREATING,
|
||||
_SANDBOX_REDIS_PREFIX,
|
||||
_CREATING_SENTINEL,
|
||||
_try_reconnect,
|
||||
get_or_create_sandbox,
|
||||
kill_sandbox,
|
||||
pause_sandbox,
|
||||
pause_sandbox_direct,
|
||||
)
|
||||
|
||||
_KEY = f"{_SANDBOX_REDIS_PREFIX}sess-123"
|
||||
_SESSION_ID = "sess-123"
|
||||
_API_KEY = "test-api-key"
|
||||
_SANDBOX_ID = "sb-abc"
|
||||
_TIMEOUT = 300
|
||||
|
||||
|
||||
def _mock_sandbox(sandbox_id: str = "sb-abc", running: bool = True) -> MagicMock:
|
||||
def _mock_sandbox(sandbox_id: str = _SANDBOX_ID, running: bool = True) -> MagicMock:
|
||||
sb = MagicMock()
|
||||
sb.sandbox_id = sandbox_id
|
||||
sb.is_running = AsyncMock(return_value=running)
|
||||
sb.pause = AsyncMock()
|
||||
sb.kill = AsyncMock()
|
||||
return sb
|
||||
|
||||
|
||||
def _mock_redis(get_val: str | bytes | None = None, set_nx_result: bool = True):
|
||||
def _mock_redis(
|
||||
set_nx_result: bool = True,
|
||||
stored_sandbox_id: str | None = None,
|
||||
) -> AsyncMock:
|
||||
"""Create a mock redis client.
|
||||
|
||||
*stored_sandbox_id* is returned by ``get()`` calls (simulates the sandbox_id
|
||||
stored under the ``_SANDBOX_KEY_PREFIX`` key). ``set_nx_result`` controls
|
||||
whether the creation-slot ``SET NX`` succeeds.
|
||||
|
||||
If *stored_sandbox_id* is None the key is absent (no sandbox, no lock).
|
||||
"""
|
||||
r = AsyncMock()
|
||||
r.get = AsyncMock(return_value=get_val)
|
||||
raw = stored_sandbox_id.encode() if stored_sandbox_id else None
|
||||
r.get = AsyncMock(return_value=raw)
|
||||
r.set = AsyncMock(return_value=set_nx_result)
|
||||
r.setex = AsyncMock()
|
||||
r.delete = AsyncMock()
|
||||
r.expire = AsyncMock()
|
||||
return r
|
||||
|
||||
|
||||
def _patch_redis(redis):
|
||||
def _patch_redis(redis: AsyncMock):
|
||||
return patch(
|
||||
"backend.copilot.tools.e2b_sandbox.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
@@ -55,6 +75,7 @@ def _patch_redis(redis):
|
||||
|
||||
class TestTryReconnect:
|
||||
def test_reconnect_success(self):
|
||||
"""Returns the sandbox when it connects and is running; refreshes Redis TTL."""
|
||||
sb = _mock_sandbox()
|
||||
redis = _mock_redis()
|
||||
with (
|
||||
@@ -62,36 +83,39 @@ class TestTryReconnect:
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
|
||||
result = asyncio.run(_try_reconnect(_SANDBOX_ID, _SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is sb
|
||||
redis.expire.assert_awaited_once_with(_KEY, _TIMEOUT)
|
||||
redis.delete.assert_not_awaited()
|
||||
# TTL must be refreshed so an active session cannot lose its key at expiry.
|
||||
redis.set.assert_awaited_once()
|
||||
|
||||
def test_reconnect_not_running_clears_key(self):
|
||||
def test_reconnect_not_running_clears_redis(self):
|
||||
"""Clears sandbox_id in Redis when the sandbox is no longer running."""
|
||||
sb = _mock_sandbox(running=False)
|
||||
redis = _mock_redis()
|
||||
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
|
||||
result = asyncio.run(_try_reconnect(_SANDBOX_ID, _SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is None
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
redis.delete.assert_awaited_once()
|
||||
|
||||
def test_reconnect_exception_clears_key(self):
|
||||
redis = _mock_redis()
|
||||
def test_reconnect_exception_clears_redis(self):
|
||||
"""Clears sandbox_id in Redis when connect raises an exception."""
|
||||
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
|
||||
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
|
||||
result = asyncio.run(_try_reconnect(_SANDBOX_ID, _SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is None
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
redis.delete.assert_awaited_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -103,38 +127,63 @@ class TestGetOrCreateSandbox:
|
||||
def test_reconnect_existing(self):
|
||||
"""When Redis has a valid sandbox_id, reconnect to it."""
|
||||
sb = _mock_sandbox()
|
||||
redis = _mock_redis(get_val="sb-abc")
|
||||
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is sb
|
||||
mock_cls.create.assert_not_called()
|
||||
# redis.set called once to refresh TTL, not to claim a creation slot
|
||||
redis.set.assert_awaited_once()
|
||||
|
||||
def test_create_new_when_no_key(self):
|
||||
"""When Redis is empty, claim lock and create a new sandbox."""
|
||||
sb = _mock_sandbox("sb-new")
|
||||
redis = _mock_redis(get_val=None, set_nx_result=True)
|
||||
def test_create_new_when_no_stored_id(self):
|
||||
"""When Redis has no sandbox_id, claim slot and create a new sandbox."""
|
||||
new_sb = _mock_sandbox("sb-new")
|
||||
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.create = AsyncMock(return_value=sb)
|
||||
mock_cls.create = AsyncMock(return_value=new_sb)
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is sb
|
||||
redis.setex.assert_awaited_once_with(_KEY, _TIMEOUT, "sb-new")
|
||||
assert result is new_sb
|
||||
mock_cls.create.assert_awaited_once()
|
||||
# Verify lifecycle param is set
|
||||
_, kwargs = mock_cls.create.call_args
|
||||
assert kwargs.get("lifecycle") == {"on_timeout": "pause"}
|
||||
# sandbox_id should be saved to Redis
|
||||
redis.set.assert_awaited()
|
||||
|
||||
def test_create_failure_clears_lock(self):
|
||||
"""If sandbox creation fails, the Redis lock is deleted."""
|
||||
redis = _mock_redis(get_val=None, set_nx_result=True)
|
||||
def test_create_with_on_timeout_kill(self):
|
||||
"""on_timeout='kill' is passed through to AsyncSandbox.create."""
|
||||
new_sb = _mock_sandbox("sb-new")
|
||||
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.create = AsyncMock(return_value=new_sb)
|
||||
asyncio.run(
|
||||
get_or_create_sandbox(
|
||||
_SESSION_ID, _API_KEY, timeout=_TIMEOUT, on_timeout="kill"
|
||||
)
|
||||
)
|
||||
|
||||
_, kwargs = mock_cls.create.call_args
|
||||
assert kwargs.get("lifecycle") == {"on_timeout": "kill"}
|
||||
|
||||
def test_create_failure_releases_slot(self):
|
||||
"""If sandbox creation fails, the Redis creation slot is deleted."""
|
||||
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
@@ -142,17 +191,53 @@ class TestGetOrCreateSandbox:
|
||||
mock_cls.create = AsyncMock(side_effect=RuntimeError("quota"))
|
||||
with pytest.raises(RuntimeError, match="quota"):
|
||||
asyncio.run(
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
redis.delete.assert_awaited_once()
|
||||
|
||||
def test_wait_for_lock_then_reconnect(self):
|
||||
"""When another process holds the lock, wait and reconnect."""
|
||||
def test_redis_save_failure_kills_sandbox_and_releases_slot(self):
|
||||
"""If Redis save fails after creation, sandbox is killed and slot released."""
|
||||
new_sb = _mock_sandbox("sb-new")
|
||||
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
|
||||
# First set() call = creation slot SET NX (returns True).
|
||||
# Second set() call = sandbox_id save (raises).
|
||||
call_count = 0
|
||||
|
||||
async def _set_side_effect(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return True # creation slot claimed
|
||||
raise RuntimeError("redis error")
|
||||
|
||||
redis.set = AsyncMock(side_effect=_set_side_effect)
|
||||
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.create = AsyncMock(return_value=new_sb)
|
||||
with pytest.raises(RuntimeError, match="redis error"):
|
||||
asyncio.run(
|
||||
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
# Sandbox must be killed to avoid leaking it
|
||||
new_sb.kill.assert_awaited_once()
|
||||
# Creation slot must always be released
|
||||
redis.delete.assert_awaited_once()
|
||||
|
||||
def test_wait_for_creating_sentinel_then_reconnect(self):
|
||||
"""When the key holds the 'creating' sentinel, wait then reconnect."""
|
||||
sb = _mock_sandbox("sb-other")
|
||||
redis = _mock_redis()
|
||||
redis.get = AsyncMock(side_effect=[_CREATING, "sb-other"])
|
||||
# First get() returns the sentinel; second returns the real ID.
|
||||
redis = AsyncMock()
|
||||
creating_raw = _CREATING_SENTINEL.encode()
|
||||
redis.get = AsyncMock(side_effect=[creating_raw, b"sb-other"])
|
||||
redis.set = AsyncMock(return_value=False)
|
||||
redis.delete = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
@@ -163,16 +248,21 @@ class TestGetOrCreateSandbox:
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is sb
|
||||
|
||||
def test_stale_reconnect_clears_and_creates(self):
|
||||
"""When stored sandbox is stale, clear key and create a new one."""
|
||||
"""When stored sandbox is stale (not running), clear it and create a new one."""
|
||||
stale_sb = _mock_sandbox("sb-stale", running=False)
|
||||
new_sb = _mock_sandbox("sb-fresh")
|
||||
redis = _mock_redis(get_val="sb-stale", set_nx_result=True)
|
||||
# First get() returns stale id (for reconnect check), then None (after clear).
|
||||
redis = AsyncMock()
|
||||
redis.get = AsyncMock(side_effect=[b"sb-stale", None])
|
||||
redis.set = AsyncMock(return_value=True)
|
||||
redis.delete = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
@@ -180,10 +270,11 @@ class TestGetOrCreateSandbox:
|
||||
mock_cls.connect = AsyncMock(return_value=stale_sb)
|
||||
mock_cls.create = AsyncMock(return_value=new_sb)
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is new_sb
|
||||
# Redis delete called at least once to clear stale id
|
||||
redis.delete.assert_awaited()
|
||||
|
||||
|
||||
@@ -194,70 +285,48 @@ class TestGetOrCreateSandbox:
|
||||
|
||||
class TestKillSandbox:
|
||||
def test_kill_existing_sandbox(self):
|
||||
"""Kill a running sandbox and clean up Redis."""
|
||||
"""Kill a running sandbox and clear its Redis entry."""
|
||||
sb = _mock_sandbox()
|
||||
sb.kill = AsyncMock()
|
||||
redis = _mock_redis(get_val="sb-abc")
|
||||
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is True
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
sb.kill.assert_awaited_once()
|
||||
# Redis key cleared after successful kill
|
||||
redis.delete.assert_awaited_once()
|
||||
|
||||
def test_kill_no_sandbox(self):
|
||||
"""No-op when no sandbox exists in Redis."""
|
||||
redis = _mock_redis(get_val=None)
|
||||
"""No-op when Redis has no sandbox_id."""
|
||||
redis = _mock_redis(stored_sandbox_id=None)
|
||||
with _patch_redis(redis):
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_not_awaited()
|
||||
|
||||
def test_kill_creating_state(self):
|
||||
"""Clears Redis key but returns False when sandbox is still being created."""
|
||||
redis = _mock_redis(get_val=_CREATING)
|
||||
with _patch_redis(redis):
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
def test_kill_connect_failure_keeps_redis(self):
|
||||
"""Returns False and leaves Redis entry intact when connect/kill fails.
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
def test_kill_connect_failure(self):
|
||||
"""Returns False and cleans Redis if connect/kill fails."""
|
||||
redis = _mock_redis(get_val="sb-abc")
|
||||
Keeping the sandbox_id in Redis allows the kill to be retried.
|
||||
"""
|
||||
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
redis.delete.assert_not_awaited()
|
||||
|
||||
def test_kill_with_bytes_redis_value(self):
|
||||
"""Redis may return bytes — kill_sandbox should decode correctly."""
|
||||
sb = _mock_sandbox()
|
||||
sb.kill = AsyncMock()
|
||||
redis = _mock_redis(get_val=b"sb-abc")
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is True
|
||||
sb.kill.assert_awaited_once()
|
||||
|
||||
def test_kill_timeout_returns_false(self):
|
||||
"""Returns False when E2B API calls exceed the 10s timeout."""
|
||||
redis = _mock_redis(get_val="sb-abc")
|
||||
def test_kill_timeout_keeps_redis(self):
|
||||
"""Returns False and leaves Redis entry intact when the E2B call times out."""
|
||||
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
|
||||
with (
|
||||
_patch_redis(redis),
|
||||
patch(
|
||||
@@ -266,7 +335,146 @@ class TestKillSandbox:
|
||||
side_effect=asyncio.TimeoutError,
|
||||
),
|
||||
):
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_not_awaited()
|
||||
|
||||
def test_kill_creating_sentinel_returns_false(self):
|
||||
"""No-op when the key holds the 'creating' sentinel (no real sandbox yet)."""
|
||||
redis = _mock_redis(stored_sandbox_id=_CREATING_SENTINEL)
|
||||
with _patch_redis(redis):
|
||||
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# pause_sandbox
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPauseSandbox:
|
||||
def test_pause_existing_sandbox(self):
|
||||
"""Pause a running sandbox; Redis sandbox_id is preserved."""
|
||||
sb = _mock_sandbox()
|
||||
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is True
|
||||
sb.pause.assert_awaited_once()
|
||||
# sandbox_id should remain in Redis (not cleared on pause)
|
||||
redis.delete.assert_not_awaited()
|
||||
|
||||
def test_pause_no_sandbox(self):
|
||||
"""No-op when Redis has no sandbox_id."""
|
||||
redis = _mock_redis(stored_sandbox_id=None)
|
||||
with _patch_redis(redis):
|
||||
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_pause_connect_failure(self):
|
||||
"""Returns False if connect fails."""
|
||||
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
|
||||
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_pause_creating_sentinel_returns_false(self):
|
||||
"""No-op when the key holds the 'creating' sentinel (no real sandbox yet)."""
|
||||
redis = _mock_redis(stored_sandbox_id=_CREATING_SENTINEL)
|
||||
with _patch_redis(redis):
|
||||
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_pause_timeout_returns_false(self):
|
||||
"""Returns False and preserves Redis entry when the E2B API call times out."""
|
||||
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
|
||||
with (
|
||||
_patch_redis(redis),
|
||||
patch(
|
||||
"backend.copilot.tools.e2b_sandbox.asyncio.wait_for",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=asyncio.TimeoutError,
|
||||
),
|
||||
):
|
||||
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is False
|
||||
# sandbox_id must remain in Redis so the next turn can reconnect
|
||||
redis.delete.assert_not_awaited()
|
||||
|
||||
def test_pause_then_reconnect_reuses_sandbox(self):
|
||||
"""After pause, get_or_create_sandbox reconnects the same sandbox.
|
||||
|
||||
Covers the pause->reconnect cycle: connect() auto-resumes a paused
|
||||
sandbox, and is_running() returns True once resume completes, so the
|
||||
same sandbox_id is reused rather than a new one being created.
|
||||
"""
|
||||
sb = _mock_sandbox(_SANDBOX_ID)
|
||||
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
|
||||
# Step 1: pause the sandbox
|
||||
paused = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
|
||||
assert paused is True
|
||||
sb.pause.assert_awaited_once()
|
||||
|
||||
# Step 2: reconnect on next turn -- same sandbox should be returned
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is sb
|
||||
mock_cls.create.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# pause_sandbox_direct
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPauseSandboxDirect:
|
||||
def test_pause_direct_success(self):
|
||||
"""Pauses the sandbox directly without a Redis lookup or reconnect."""
|
||||
sb = _mock_sandbox()
|
||||
result = asyncio.run(pause_sandbox_direct(sb, _SESSION_ID))
|
||||
|
||||
assert result is True
|
||||
sb.pause.assert_awaited_once()
|
||||
|
||||
def test_pause_direct_failure_returns_false(self):
|
||||
"""Returns False when sandbox.pause() raises."""
|
||||
sb = _mock_sandbox()
|
||||
sb.pause = AsyncMock(side_effect=RuntimeError("e2b error"))
|
||||
result = asyncio.run(pause_sandbox_direct(sb, _SESSION_ID))
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_pause_direct_timeout_returns_false(self):
|
||||
"""Returns False when sandbox.pause() exceeds the 10s timeout."""
|
||||
sb = _mock_sandbox()
|
||||
with patch(
|
||||
"backend.copilot.tools.e2b_sandbox.asyncio.wait_for",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=asyncio.TimeoutError,
|
||||
):
|
||||
result = asyncio.run(pause_sandbox_direct(sb, _SESSION_ID))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
@@ -1,32 +1,20 @@
|
||||
"""EditAgentTool - Edits existing agents using natural language."""
|
||||
"""EditAgentTool - Edits existing agents using pre-built JSON."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
get_user_message_for_error,
|
||||
save_agent_to_library,
|
||||
)
|
||||
from .agent_generator import get_agent_as_json
|
||||
from .agent_generator.pipeline import fetch_library_agents, fix_validate_and_save
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .models import ErrorResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EditAgentTool(BaseTool):
|
||||
"""Tool for editing existing agents using natural language."""
|
||||
"""Tool for editing existing agents using pre-built JSON."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -35,11 +23,12 @@ class EditAgentTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit an existing agent from the user's library using natural language. "
|
||||
"Generates updates to the agent while preserving unchanged parts. "
|
||||
"\n\nIMPORTANT: Before calling this tool, if the changes involve adding new "
|
||||
"Edit an existing agent. Pass `agent_json` with the complete "
|
||||
"updated agent JSON you generated. The tool validates, auto-fixes, "
|
||||
"and saves.\n\n"
|
||||
"IMPORTANT: Before calling this tool, if the changes involve adding new "
|
||||
"functionality, search for relevant existing agents using find_library_agent "
|
||||
"that could be used as building blocks. Pass their IDs in library_agent_ids."
|
||||
"that could be used as building blocks."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -58,26 +47,20 @@ class EditAgentTool(BaseTool):
|
||||
"Can be a graph ID or library agent ID."
|
||||
),
|
||||
},
|
||||
"changes": {
|
||||
"type": "string",
|
||||
"agent_json": {
|
||||
"type": "object",
|
||||
"description": (
|
||||
"Natural language description of what changes to make. "
|
||||
"Be specific about what to add, remove, or modify."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Additional context or answers to previous clarifying questions."
|
||||
"Complete updated agent JSON to validate and save. "
|
||||
"Must contain 'nodes' and 'links'. "
|
||||
"Include 'name' and/or 'description' if they need "
|
||||
"to be updated."
|
||||
),
|
||||
},
|
||||
"library_agent_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"List of library agent IDs to use as building blocks for the changes. "
|
||||
"If adding new functionality, search for relevant agents using "
|
||||
"find_library_agent first, then pass their IDs here."
|
||||
"List of library agent IDs to use as building blocks for the changes."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
@@ -89,7 +72,7 @@ class EditAgentTool(BaseTool):
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
"required": ["agent_id", "changes"],
|
||||
"required": ["agent_id", "agent_json"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
@@ -98,36 +81,39 @@ class EditAgentTool(BaseTool):
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the edit_agent tool.
|
||||
|
||||
Flow:
|
||||
1. Fetch the current agent
|
||||
2. Generate updated agent (external service handles fixing and validation)
|
||||
3. Preview or save based on the save parameter
|
||||
"""
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
changes = kwargs.get("changes", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
library_agent_ids = kwargs.get("library_agent_ids", [])
|
||||
save = kwargs.get("save", True)
|
||||
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the agent ID to edit.",
|
||||
error="Missing agent_id parameter",
|
||||
error="missing_agent_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not changes:
|
||||
if not agent_json:
|
||||
return ErrorResponse(
|
||||
message="Please describe what changes you want to make.",
|
||||
error="Missing changes parameter",
|
||||
message=(
|
||||
"Please provide agent_json with the complete updated agent graph."
|
||||
),
|
||||
error="missing_agent_json",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
current_agent = await get_agent_as_json(agent_id, user_id)
|
||||
save = kwargs.get("save", True)
|
||||
library_agent_ids = kwargs.get("library_agent_ids", [])
|
||||
|
||||
nodes = agent_json.get("nodes", [])
|
||||
if not nodes:
|
||||
return ErrorResponse(
|
||||
message="The agent JSON has no nodes.",
|
||||
error="empty_agent",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Preserve original agent's ID
|
||||
current_agent = await get_agent_as_json(agent_id, user_id)
|
||||
if current_agent is None:
|
||||
return ErrorResponse(
|
||||
message=f"Could not find agent with ID '{agent_id}' in your library.",
|
||||
@@ -135,142 +121,19 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Fetch library agents by IDs if provided
|
||||
library_agents = None
|
||||
if user_id and library_agent_ids:
|
||||
try:
|
||||
from .agent_generator import get_library_agents_by_ids
|
||||
agent_json["id"] = current_agent.get("id", agent_id)
|
||||
agent_json["version"] = current_agent.get("version", 1)
|
||||
agent_json.setdefault("is_active", True)
|
||||
|
||||
graph_id = current_agent.get("id")
|
||||
# Filter out the current agent being edited
|
||||
filtered_ids = [id for id in library_agent_ids if id != graph_id]
|
||||
# Fetch library agents for AgentExecutorBlock validation
|
||||
library_agents = await fetch_library_agents(user_id, library_agent_ids)
|
||||
|
||||
library_agents = await get_library_agents_by_ids(
|
||||
user_id=user_id,
|
||||
agent_ids=filtered_ids,
|
||||
)
|
||||
logger.debug(
|
||||
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agents by IDs: {e}")
|
||||
|
||||
update_request = changes
|
||||
if context:
|
||||
update_request = f"{changes}\n\nAdditional context:\n{context}"
|
||||
|
||||
try:
|
||||
result = await generate_agent_patch(
|
||||
update_request,
|
||||
current_agent,
|
||||
library_agents,
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent editing is not available. "
|
||||
"The Agent Generator service is not configured."
|
||||
),
|
||||
error="service_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.",
|
||||
error="update_generation_failed",
|
||||
details={"agent_id": agent_id, "changes": changes[:100]},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if the result is an error from the external service
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
error_type = result.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="generate the changes",
|
||||
llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
|
||||
validation_message="The generated changes failed validation. Please try rephrasing your request.",
|
||||
error_details=error_msg,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"update_generation_failed:{error_type}",
|
||||
details={
|
||||
"agent_id": agent_id,
|
||||
"changes": changes[:100],
|
||||
"service_error": error_msg,
|
||||
"error_type": error_type,
|
||||
},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if result.get("type") == "clarifying_questions":
|
||||
questions = result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information about the changes. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
updated_agent = result
|
||||
|
||||
agent_name = updated_agent.get("name", "Updated Agent")
|
||||
agent_description = updated_agent.get("description", "")
|
||||
node_count = len(updated_agent.get("nodes", []))
|
||||
link_count = len(updated_agent.get("links", []))
|
||||
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've updated the agent. "
|
||||
f"The agent now has {node_count} blocks. "
|
||||
f"Review it and call edit_agent with save=true to save the changes."
|
||||
),
|
||||
agent_json=updated_agent,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
updated_agent, user_id, is_update=True
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
message=f"Updated agent '{created_graph.name}' has been saved to your library!",
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to save the updated agent: {str(e)}",
|
||||
error="save_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
return await fix_validate_and_save(
|
||||
agent_json,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
save=save,
|
||||
is_update=True,
|
||||
default_name="Updated Agent",
|
||||
library_agents=library_agents,
|
||||
)
|
||||
|
||||
@@ -32,7 +32,7 @@ COPILOT_EXCLUDED_BLOCK_TYPES = {
|
||||
BlockType.NOTE, # Visual annotation only - no runtime behavior
|
||||
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
|
||||
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
|
||||
BlockType.MCP_TOOL, # Has dedicated run_mcp_tool tool with proper discovery + auth flow
|
||||
BlockType.MCP_TOOL, # Has dedicated run_mcp_tool tool with discovery + auth flow
|
||||
}
|
||||
|
||||
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
|
||||
@@ -72,6 +72,15 @@ class FindBlockTool(BaseTool):
|
||||
"Use keywords like 'email', 'http', 'text', 'ai', etc."
|
||||
),
|
||||
},
|
||||
"include_schemas": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"If true, include full input_schema and output_schema "
|
||||
"for each block. Use when generating agent JSON that "
|
||||
"needs block schemas. Default is false."
|
||||
),
|
||||
"default": False,
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
@@ -99,6 +108,7 @@ class FindBlockTool(BaseTool):
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
include_schemas = kwargs.get("include_schemas", False)
|
||||
session_id = session.session_id
|
||||
|
||||
if not query:
|
||||
@@ -143,15 +153,21 @@ class FindBlockTool(BaseTool):
|
||||
):
|
||||
continue
|
||||
|
||||
blocks.append(
|
||||
BlockInfoSummary(
|
||||
id=block_id,
|
||||
name=block.name,
|
||||
description=block.description or "",
|
||||
categories=[c.value for c in block.categories],
|
||||
)
|
||||
summary = BlockInfoSummary(
|
||||
id=block_id,
|
||||
name=block.name,
|
||||
description=block.optimized_description or block.description or "",
|
||||
categories=[c.value for c in block.categories],
|
||||
)
|
||||
|
||||
if include_schemas:
|
||||
info = block.get_info()
|
||||
summary.input_schema = info.inputSchema
|
||||
summary.output_schema = info.outputSchema
|
||||
summary.static_output = info.staticOutput
|
||||
|
||||
blocks.append(summary)
|
||||
|
||||
if len(blocks) >= _TARGET_RESULTS:
|
||||
break
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ def make_mock_block(
|
||||
input_schema: dict | None = None,
|
||||
output_schema: dict | None = None,
|
||||
credentials_fields: dict | None = None,
|
||||
static_output: bool = False,
|
||||
):
|
||||
"""Create a mock block for testing."""
|
||||
mock = MagicMock()
|
||||
@@ -33,6 +34,7 @@ def make_mock_block(
|
||||
mock.description = f"{name} description"
|
||||
mock.block_type = block_type
|
||||
mock.disabled = disabled
|
||||
mock.static_output = static_output
|
||||
mock.input_schema = MagicMock()
|
||||
mock.input_schema.jsonschema.return_value = input_schema or {
|
||||
"properties": {},
|
||||
@@ -42,6 +44,15 @@ def make_mock_block(
|
||||
mock.output_schema = MagicMock()
|
||||
mock.output_schema.jsonschema.return_value = output_schema or {}
|
||||
mock.categories = []
|
||||
mock.optimized_description = None
|
||||
|
||||
# Mock get_info() for include_schemas support
|
||||
mock_info = MagicMock()
|
||||
mock_info.inputSchema = input_schema or {"properties": {}, "required": []}
|
||||
mock_info.outputSchema = output_schema or {}
|
||||
mock_info.staticOutput = static_output
|
||||
mock.get_info.return_value = mock_info
|
||||
|
||||
return mock
|
||||
|
||||
|
||||
@@ -399,3 +410,92 @@ class TestFindBlockFiltering:
|
||||
f"Average chars per block ({avg_chars}) exceeds 500. "
|
||||
f"Total response: {total_chars} chars for {response.count} blocks."
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_include_schemas_false_omits_schemas(self):
|
||||
"""Without include_schemas, schemas should be empty dicts."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
input_schema = {"properties": {"url": {"type": "string"}}, "required": ["url"]}
|
||||
output_schema = {"properties": {"result": {"type": "string"}}}
|
||||
|
||||
search_results = [{"content_id": "block-1", "score": 0.9}]
|
||||
block = make_mock_block(
|
||||
"block-1",
|
||||
"Test Block",
|
||||
BlockType.STANDARD,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
)
|
||||
|
||||
mock_search_db = MagicMock()
|
||||
mock_search_db.unified_hybrid_search = AsyncMock(
|
||||
return_value=(search_results, 1)
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.find_block.search",
|
||||
return_value=mock_search_db,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
return_value=block,
|
||||
),
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
query="test",
|
||||
include_schemas=False,
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockListResponse)
|
||||
assert response.blocks[0].input_schema == {}
|
||||
assert response.blocks[0].output_schema == {}
|
||||
assert response.blocks[0].static_output is False
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_include_schemas_true_populates_schemas(self):
|
||||
"""With include_schemas=true, schemas should be populated from block info."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
input_schema = {"properties": {"url": {"type": "string"}}, "required": ["url"]}
|
||||
output_schema = {"properties": {"result": {"type": "string"}}}
|
||||
|
||||
search_results = [{"content_id": "block-1", "score": 0.9}]
|
||||
block = make_mock_block(
|
||||
"block-1",
|
||||
"Test Block",
|
||||
BlockType.STANDARD,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
mock_search_db = MagicMock()
|
||||
mock_search_db.unified_hybrid_search = AsyncMock(
|
||||
return_value=(search_results, 1)
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.find_block.search",
|
||||
return_value=mock_search_db,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
return_value=block,
|
||||
),
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
query="test",
|
||||
include_schemas=True,
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockListResponse)
|
||||
assert response.blocks[0].input_schema == input_schema
|
||||
assert response.blocks[0].output_schema == output_schema
|
||||
assert response.blocks[0].static_output is True
|
||||
|
||||
@@ -22,6 +22,9 @@ class FindLibraryAgentTool(BaseTool):
|
||||
"Search for or list agents in the user's library. Use this to find "
|
||||
"agents the user has already added to their library, including agents "
|
||||
"they created or added from the marketplace. "
|
||||
"When creating agents with sub-agent composition, use this to get "
|
||||
"the agent's graph_id, graph_version, input_schema, and output_schema "
|
||||
"needed for AgentExecutorBlock nodes. "
|
||||
"Omit the query to list all agents."
|
||||
)
|
||||
|
||||
|
||||
134
autogpt_platform/backend/backend/copilot/tools/fix_agent.py
Normal file
134
autogpt_platform/backend/backend/copilot/tools/fix_agent.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""FixAgentGraphTool - Auto-fixes common agent JSON issues."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .agent_generator.validation import AgentFixer, AgentValidator, get_blocks_as_dicts
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, FixResultResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FixAgentGraphTool(BaseTool):
|
||||
"""Tool for auto-fixing common issues in agent JSON graphs."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "fix_agent_graph"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Auto-fix common issues in an agent JSON graph. Applies fixes for:\n"
|
||||
"- Missing or invalid UUIDs on nodes and links\n"
|
||||
"- StoreValueBlock prerequisites for ConditionBlock\n"
|
||||
"- Double curly brace escaping in prompt templates\n"
|
||||
"- AddToList/AddToDictionary prerequisite blocks\n"
|
||||
"- CodeExecutionBlock output field naming\n"
|
||||
"- Missing credentials configuration\n"
|
||||
"- Node X coordinate spacing (800+ units apart)\n"
|
||||
"- AI model default parameters\n"
|
||||
"- Link static properties based on input schema\n"
|
||||
"- Type mismatches (inserts conversion blocks)\n\n"
|
||||
"Returns the fixed agent JSON plus a list of fixes applied. "
|
||||
"After fixing, the agent is re-validated. If still invalid, "
|
||||
"the remaining errors are included in the response."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_json": {
|
||||
"type": "object",
|
||||
"description": (
|
||||
"The agent JSON to fix. Must contain 'nodes' and 'links' arrays."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["agent_json"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
agent_json = kwargs.get("agent_json")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_json or not isinstance(agent_json, dict):
|
||||
return ErrorResponse(
|
||||
message="Please provide a valid agent JSON object.",
|
||||
error="Missing or invalid agent_json parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
nodes = agent_json.get("nodes", [])
|
||||
|
||||
if not nodes:
|
||||
return ErrorResponse(
|
||||
message="The agent JSON has no nodes. An agent needs at least one block.",
|
||||
error="empty_agent",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
blocks = get_blocks_as_dicts()
|
||||
fixer = AgentFixer()
|
||||
fixed_agent = fixer.apply_all_fixes(agent_json, blocks)
|
||||
fixes_applied = fixer.get_fixes_applied()
|
||||
except Exception as e:
|
||||
logger.error(f"Fixer error: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Auto-fix encountered an error: {str(e)}",
|
||||
error="fix_exception",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Re-validate after fixing
|
||||
try:
|
||||
validator = AgentValidator()
|
||||
is_valid, _ = validator.validate(fixed_agent, blocks)
|
||||
remaining_errors = validator.errors if not is_valid else []
|
||||
except Exception as e:
|
||||
logger.warning(f"Post-fix validation error: {e}", exc_info=True)
|
||||
remaining_errors = [f"Post-fix validation failed: {str(e)}"]
|
||||
is_valid = False
|
||||
|
||||
if is_valid:
|
||||
return FixResultResponse(
|
||||
message=(
|
||||
f"Applied {len(fixes_applied)} fix(es). "
|
||||
"Agent graph is now valid!"
|
||||
),
|
||||
fixed_agent_json=fixed_agent,
|
||||
fixes_applied=fixes_applied,
|
||||
fix_count=len(fixes_applied),
|
||||
valid_after_fix=True,
|
||||
remaining_errors=[],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return FixResultResponse(
|
||||
message=(
|
||||
f"Applied {len(fixes_applied)} fix(es), but "
|
||||
f"{len(remaining_errors)} issue(s) remain. "
|
||||
"Review the remaining errors and fix manually."
|
||||
),
|
||||
fixed_agent_json=fixed_agent,
|
||||
fixes_applied=fixes_applied,
|
||||
fix_count=len(fixes_applied),
|
||||
valid_after_fix=False,
|
||||
remaining_errors=remaining_errors,
|
||||
session_id=session_id,
|
||||
)
|
||||
189
autogpt_platform/backend/backend/copilot/tools/fix_agent_test.py
Normal file
189
autogpt_platform/backend/backend/copilot/tools/fix_agent_test.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""Tests for FixAgentGraphTool."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.tools.fix_agent import FixAgentGraphTool
|
||||
from backend.copilot.tools.models import ErrorResponse, FixResultResponse
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
_TEST_USER_ID = "test-user-fix-agent"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return FixAgentGraphTool()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session():
|
||||
return make_session(_TEST_USER_ID)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_agent_json_returns_error(tool, session):
|
||||
"""Missing agent_json returns ErrorResponse."""
|
||||
result = await tool._execute(user_id=_TEST_USER_ID, session=session)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error is not None
|
||||
assert "agent_json" in result.error.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_nodes_returns_error(tool, session):
|
||||
"""Agent JSON with no nodes returns ErrorResponse."""
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json={"nodes": [], "links": []},
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "no nodes" in result.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fix_and_validate_success(tool, session):
|
||||
"""Fixer applies fixes and validator passes -> valid_after_fix=True."""
|
||||
agent_json = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"block_id": "block-1",
|
||||
"input_default": {},
|
||||
"metadata": {"position": {"x": 0, "y": 0}},
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
}
|
||||
|
||||
fixed_agent = dict(agent_json) # Fixer returns the full agent dict
|
||||
|
||||
mock_fixer = MagicMock()
|
||||
mock_fixer.apply_all_fixes = MagicMock(return_value=fixed_agent)
|
||||
mock_fixer.get_fixes_applied.return_value = ["Fixed node UUID format"]
|
||||
|
||||
mock_validator = MagicMock()
|
||||
mock_validator.validate.return_value = (True, None)
|
||||
mock_validator.errors = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.fix_agent.get_blocks_as_dicts",
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.fix_agent.AgentFixer",
|
||||
return_value=mock_fixer,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.fix_agent.AgentValidator",
|
||||
return_value=mock_validator,
|
||||
),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json=agent_json,
|
||||
)
|
||||
|
||||
assert isinstance(result, FixResultResponse)
|
||||
assert result.valid_after_fix is True
|
||||
assert result.fix_count == 1
|
||||
assert result.fixes_applied == ["Fixed node UUID format"]
|
||||
assert result.remaining_errors == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fix_with_remaining_errors(tool, session):
|
||||
"""Fixer applies some fixes but validation still fails."""
|
||||
agent_json = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"block_id": "block-1",
|
||||
"input_default": {},
|
||||
"metadata": {},
|
||||
}
|
||||
],
|
||||
"links": [
|
||||
{
|
||||
"id": "link-1",
|
||||
"source_id": "node-1",
|
||||
"source_name": "output",
|
||||
"sink_id": "node-2",
|
||||
"sink_name": "input",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
fixed_agent = dict(agent_json)
|
||||
|
||||
mock_fixer = MagicMock()
|
||||
mock_fixer.apply_all_fixes = MagicMock(return_value=fixed_agent)
|
||||
mock_fixer.get_fixes_applied.return_value = ["Fixed UUID"]
|
||||
|
||||
mock_validator = MagicMock()
|
||||
mock_validator.validate.return_value = (
|
||||
False,
|
||||
"Link references non-existent node 'node-2'",
|
||||
)
|
||||
mock_validator.errors = ["Link references non-existent node 'node-2'"]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.fix_agent.get_blocks_as_dicts",
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.fix_agent.AgentFixer",
|
||||
return_value=mock_fixer,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.fix_agent.AgentValidator",
|
||||
return_value=mock_validator,
|
||||
),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json=agent_json,
|
||||
)
|
||||
|
||||
assert isinstance(result, FixResultResponse)
|
||||
assert result.valid_after_fix is False
|
||||
assert result.fix_count == 1
|
||||
assert len(result.remaining_errors) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fixer_exception_returns_error(tool, session):
|
||||
"""Fixer exception returns ErrorResponse."""
|
||||
agent_json = {
|
||||
"nodes": [{"id": "n1", "block_id": "b1", "input_default": {}, "metadata": {}}],
|
||||
"links": [],
|
||||
}
|
||||
|
||||
mock_fixer = MagicMock()
|
||||
mock_fixer.apply_all_fixes = MagicMock(side_effect=RuntimeError("fixer crashed"))
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.fix_agent.get_blocks_as_dicts",
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.fix_agent.AgentFixer",
|
||||
return_value=mock_fixer,
|
||||
),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json=agent_json,
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error is not None
|
||||
assert "fix_exception" in result.error
|
||||
@@ -0,0 +1,84 @@
|
||||
"""GetAgentBuildingGuideTool - Returns the complete agent building guide."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_GUIDE_CACHE: str | None = None
|
||||
|
||||
|
||||
def _load_guide() -> str:
|
||||
global _GUIDE_CACHE
|
||||
if _GUIDE_CACHE is None:
|
||||
guide_path = Path(__file__).parent.parent / "sdk" / "agent_generation_guide.md"
|
||||
_GUIDE_CACHE = guide_path.read_text(encoding="utf-8")
|
||||
return _GUIDE_CACHE
|
||||
|
||||
|
||||
class AgentBuildingGuideResponse(ToolResponseBase):
|
||||
"""Response containing the agent building guide."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_BUILDER_GUIDE
|
||||
content: str
|
||||
|
||||
|
||||
class GetAgentBuildingGuideTool(BaseTool):
|
||||
"""Returns the complete guide for building agent JSON graphs.
|
||||
|
||||
Covers block IDs, link structure, AgentInputBlock, AgentOutputBlock,
|
||||
AgentExecutorBlock (sub-agent composition), and MCPToolBlock usage.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "get_agent_building_guide"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Returns the complete guide for building agent JSON graphs, including "
|
||||
"block IDs, link structure, AgentInputBlock, AgentOutputBlock, "
|
||||
"AgentExecutorBlock (for sub-agent composition), and MCPToolBlock usage. "
|
||||
"Call this before generating agent JSON to ensure correct structure."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id if session else None
|
||||
try:
|
||||
content = _load_guide()
|
||||
return AgentBuildingGuideResponse(
|
||||
message="Agent building guide loaded.",
|
||||
content=content,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to load agent building guide: %s", e)
|
||||
return ErrorResponse(
|
||||
message="Failed to load agent building guide.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -0,0 +1,79 @@
|
||||
"""GetMCPGuideTool - Returns the MCP tool usage guide."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_GUIDE_CACHE: str | None = None
|
||||
|
||||
|
||||
def _load_guide() -> str:
|
||||
global _GUIDE_CACHE
|
||||
if _GUIDE_CACHE is None:
|
||||
guide_path = Path(__file__).parent.parent / "sdk" / "mcp_tool_guide.md"
|
||||
_GUIDE_CACHE = guide_path.read_text(encoding="utf-8")
|
||||
return _GUIDE_CACHE
|
||||
|
||||
|
||||
class MCPGuideResponse(ToolResponseBase):
|
||||
"""Response containing the MCP tool guide."""
|
||||
|
||||
type: ResponseType = ResponseType.MCP_GUIDE
|
||||
content: str
|
||||
|
||||
|
||||
class GetMCPGuideTool(BaseTool):
|
||||
"""Returns the MCP tool usage guide with known server URLs and auth details."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "get_mcp_guide"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Returns the MCP tool guide: known hosted server URLs (Notion, Linear, "
|
||||
"Stripe, Intercom, Cloudflare, Atlassian) and authentication workflow. "
|
||||
"Call before using run_mcp_tool if you need a server URL or auth info."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id if session else None
|
||||
try:
|
||||
content = _load_guide()
|
||||
return MCPGuideResponse(
|
||||
message="MCP guide loaded.",
|
||||
content=content,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to load MCP guide: %s", e)
|
||||
return ErrorResponse(
|
||||
message="Failed to load MCP guide.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -1,7 +1,24 @@
|
||||
"""Shared helpers for chat tools."""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
|
||||
from .utils import match_credentials_to_requirements
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_inputs_from_schema(
|
||||
input_schema: dict[str, Any],
|
||||
@@ -27,3 +44,159 @@ def get_inputs_from_schema(
|
||||
for name, schema in properties.items()
|
||||
if name not in exclude
|
||||
]
|
||||
|
||||
|
||||
async def execute_block(
|
||||
*,
|
||||
block: AnyBlockSchema,
|
||||
block_id: str,
|
||||
input_data: dict[str, Any],
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
node_exec_id: str,
|
||||
matched_credentials: dict[str, CredentialsMetaInput],
|
||||
sensitive_action_safe_mode: bool = False,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute a block with full context setup, credential injection, and error handling.
|
||||
|
||||
This is the shared execution path used by both ``run_block`` (after review
|
||||
check) and ``continue_run_block`` (after approval).
|
||||
|
||||
Returns:
|
||||
BlockOutputResponse on success, ErrorResponse on failure.
|
||||
"""
|
||||
try:
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
|
||||
synthetic_graph_id = f"{COPILOT_SESSION_PREFIX}{session_id}"
|
||||
synthetic_node_id = f"{COPILOT_NODE_PREFIX}{block_id}"
|
||||
|
||||
execution_context = ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
graph_version=1,
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
workspace_id=workspace.id,
|
||||
session_id=session_id,
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
)
|
||||
|
||||
exec_kwargs: dict[str, Any] = {
|
||||
"user_id": user_id,
|
||||
"execution_context": execution_context,
|
||||
"workspace_id": workspace.id,
|
||||
"graph_exec_id": synthetic_graph_id,
|
||||
"node_exec_id": node_exec_id,
|
||||
"node_id": synthetic_node_id,
|
||||
"graph_version": 1,
|
||||
"graph_id": synthetic_graph_id,
|
||||
}
|
||||
|
||||
# Inject credentials
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
for field_name, cred_meta in matched_credentials.items():
|
||||
if field_name not in input_data:
|
||||
input_data[field_name] = cred_meta.model_dump()
|
||||
|
||||
actual_credentials = await creds_manager.get(
|
||||
user_id, cred_meta.id, lock=False
|
||||
)
|
||||
if actual_credentials:
|
||||
exec_kwargs[field_name] = actual_credentials
|
||||
else:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to retrieve credentials for {field_name}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Execute the block and collect outputs
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in block.execute(
|
||||
input_data,
|
||||
**exec_kwargs,
|
||||
):
|
||||
outputs[output_name].append(output_data)
|
||||
|
||||
return BlockOutputResponse(
|
||||
message=f"Block '{block.name}' executed successfully",
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
outputs=dict(outputs),
|
||||
success=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except BlockError as e:
|
||||
logger.warning(f"Block execution failed: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Block execution failed: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to execute block: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
async def resolve_block_credentials(
|
||||
user_id: str,
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any] | None = None,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""Resolve credentials for a block by matching user's available credentials.
|
||||
|
||||
Handles discriminated credentials (e.g. provider selection based on model).
|
||||
|
||||
Returns:
|
||||
(matched_credentials, missing_credentials)
|
||||
"""
|
||||
input_data = input_data or {}
|
||||
requirements = _resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
if not requirements:
|
||||
return {}, []
|
||||
|
||||
return await match_credentials_to_requirements(user_id, requirements)
|
||||
|
||||
|
||||
def _resolve_discriminated_credentials(
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any],
|
||||
) -> dict[str, CredentialsFieldInfo]:
|
||||
"""Resolve credential requirements, applying discriminator logic where needed."""
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
if not credentials_fields_info:
|
||||
return {}
|
||||
|
||||
resolved: dict[str, CredentialsFieldInfo] = {}
|
||||
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
effective_field_info = field_info
|
||||
|
||||
if field_info.discriminator and field_info.discriminator_mapping:
|
||||
discriminator_value = input_data.get(field_info.discriminator)
|
||||
if discriminator_value is None:
|
||||
field = block.input_schema.model_fields.get(field_info.discriminator)
|
||||
if field and field.default is not PydanticUndefined:
|
||||
discriminator_value = field.default
|
||||
|
||||
if (
|
||||
discriminator_value
|
||||
and discriminator_value in field_info.discriminator_mapping
|
||||
):
|
||||
effective_field_info = field_info.discriminate(discriminator_value)
|
||||
effective_field_info.discriminator_values.add(discriminator_value)
|
||||
logger.debug(
|
||||
f"Discriminated provider for {field_name}: "
|
||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
||||
)
|
||||
|
||||
resolved[field_name] = effective_field_info
|
||||
|
||||
return resolved
|
||||
|
||||
@@ -12,50 +12,52 @@ from backend.data.model import CredentialsMetaInput
|
||||
class ResponseType(str, Enum):
|
||||
"""Types of tool responses."""
|
||||
|
||||
# General
|
||||
ERROR = "error"
|
||||
NO_RESULTS = "no_results"
|
||||
NEED_LOGIN = "need_login"
|
||||
|
||||
# Agent discovery & execution
|
||||
AGENTS_FOUND = "agents_found"
|
||||
AGENT_DETAILS = "agent_details"
|
||||
SETUP_REQUIREMENTS = "setup_requirements"
|
||||
INPUT_VALIDATION_ERROR = "input_validation_error"
|
||||
EXECUTION_STARTED = "execution_started"
|
||||
NEED_LOGIN = "need_login"
|
||||
ERROR = "error"
|
||||
NO_RESULTS = "no_results"
|
||||
AGENT_OUTPUT = "agent_output"
|
||||
UNDERSTANDING_UPDATED = "understanding_updated"
|
||||
AGENT_PREVIEW = "agent_preview"
|
||||
AGENT_SAVED = "agent_saved"
|
||||
CLARIFICATION_NEEDED = "clarification_needed"
|
||||
SUGGESTED_GOAL = "suggested_goal"
|
||||
|
||||
# Agent builder (create / edit / validate / fix)
|
||||
AGENT_BUILDER_GUIDE = "agent_builder_guide"
|
||||
AGENT_BUILDER_PREVIEW = "agent_builder_preview"
|
||||
AGENT_BUILDER_SAVED = "agent_builder_saved"
|
||||
AGENT_BUILDER_CLARIFICATION_NEEDED = "agent_builder_clarification_needed"
|
||||
AGENT_BUILDER_VALIDATION_RESULT = "agent_builder_validation_result"
|
||||
AGENT_BUILDER_FIX_RESULT = "agent_builder_fix_result"
|
||||
|
||||
# Block
|
||||
BLOCK_LIST = "block_list"
|
||||
BLOCK_DETAILS = "block_details"
|
||||
BLOCK_OUTPUT = "block_output"
|
||||
REVIEW_REQUIRED = "review_required"
|
||||
|
||||
# MCP
|
||||
MCP_GUIDE = "mcp_guide"
|
||||
MCP_TOOLS_DISCOVERED = "mcp_tools_discovered"
|
||||
MCP_TOOL_OUTPUT = "mcp_tool_output"
|
||||
|
||||
# Docs
|
||||
DOC_SEARCH_RESULTS = "doc_search_results"
|
||||
DOC_PAGE = "doc_page"
|
||||
# Workspace response types
|
||||
|
||||
# Workspace files
|
||||
WORKSPACE_FILE_LIST = "workspace_file_list"
|
||||
WORKSPACE_FILE_CONTENT = "workspace_file_content"
|
||||
WORKSPACE_FILE_METADATA = "workspace_file_metadata"
|
||||
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
|
||||
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
|
||||
# Long-running operation types
|
||||
OPERATION_IN_PROGRESS = "operation_in_progress"
|
||||
# Input validation
|
||||
INPUT_VALIDATION_ERROR = "input_validation_error"
|
||||
# Web fetch
|
||||
WEB_FETCH = "web_fetch"
|
||||
# Agent-browser multi-step automation (navigate, act, screenshot)
|
||||
BROWSER_NAVIGATE = "browser_navigate"
|
||||
BROWSER_ACT = "browser_act"
|
||||
BROWSER_SCREENSHOT = "browser_screenshot"
|
||||
# Code execution
|
||||
BASH_EXEC = "bash_exec"
|
||||
# Feature request types
|
||||
FEATURE_REQUEST_SEARCH = "feature_request_search"
|
||||
FEATURE_REQUEST_CREATED = "feature_request_created"
|
||||
# Goal refinement
|
||||
SUGGESTED_GOAL = "suggested_goal"
|
||||
# MCP tool types
|
||||
MCP_TOOLS_DISCOVERED = "mcp_tools_discovered"
|
||||
MCP_TOOL_OUTPUT = "mcp_tool_output"
|
||||
# Folder management types
|
||||
|
||||
# Folder management
|
||||
FOLDER_CREATED = "folder_created"
|
||||
FOLDER_LIST = "folder_list"
|
||||
FOLDER_UPDATED = "folder_updated"
|
||||
@@ -63,6 +65,21 @@ class ResponseType(str, Enum):
|
||||
FOLDER_DELETED = "folder_deleted"
|
||||
AGENTS_MOVED_TO_FOLDER = "agents_moved_to_folder"
|
||||
|
||||
# Browser automation
|
||||
BROWSER_NAVIGATE = "browser_navigate"
|
||||
BROWSER_ACT = "browser_act"
|
||||
BROWSER_SCREENSHOT = "browser_screenshot"
|
||||
|
||||
# Code execution
|
||||
BASH_EXEC = "bash_exec"
|
||||
|
||||
# Web
|
||||
WEB_FETCH = "web_fetch"
|
||||
|
||||
# Feature requests
|
||||
FEATURE_REQUEST_SEARCH = "feature_request_search"
|
||||
FEATURE_REQUEST_CREATED = "feature_request_created"
|
||||
|
||||
|
||||
# Base response model
|
||||
class ToolResponseBase(BaseModel):
|
||||
@@ -92,6 +109,15 @@ class AgentInfo(BaseModel):
|
||||
has_external_trigger: bool | None = None
|
||||
new_output: bool | None = None
|
||||
graph_id: str | None = None
|
||||
graph_version: int | None = None
|
||||
input_schema: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="JSON Schema for the agent's inputs (for AgentExecutorBlock)",
|
||||
)
|
||||
output_schema: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="JSON Schema for the agent's outputs (for AgentExecutorBlock)",
|
||||
)
|
||||
inputs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Input schema for the agent, including field names, types, and defaults",
|
||||
@@ -282,7 +308,7 @@ class ClarifyingQuestion(BaseModel):
|
||||
class AgentPreviewResponse(ToolResponseBase):
|
||||
"""Response for previewing a generated agent before saving."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_PREVIEW
|
||||
type: ResponseType = ResponseType.AGENT_BUILDER_PREVIEW
|
||||
agent_json: dict[str, Any]
|
||||
agent_name: str
|
||||
description: str
|
||||
@@ -293,7 +319,7 @@ class AgentPreviewResponse(ToolResponseBase):
|
||||
class AgentSavedResponse(ToolResponseBase):
|
||||
"""Response when an agent is saved to the library."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_SAVED
|
||||
type: ResponseType = ResponseType.AGENT_BUILDER_SAVED
|
||||
agent_id: str
|
||||
agent_name: str
|
||||
library_agent_id: str
|
||||
@@ -304,7 +330,7 @@ class AgentSavedResponse(ToolResponseBase):
|
||||
class ClarificationNeededResponse(ToolResponseBase):
|
||||
"""Response when the LLM needs more information from the user."""
|
||||
|
||||
type: ResponseType = ResponseType.CLARIFICATION_NEEDED
|
||||
type: ResponseType = ResponseType.AGENT_BUILDER_CLARIFICATION_NEEDED
|
||||
questions: list[ClarifyingQuestion] = Field(default_factory=list)
|
||||
|
||||
|
||||
@@ -381,6 +407,10 @@ class BlockInfoSummary(BaseModel):
|
||||
default_factory=dict,
|
||||
description="Full JSON schema for block outputs",
|
||||
)
|
||||
static_output: bool = Field(
|
||||
default=False,
|
||||
description="Whether the block produces output without needing input",
|
||||
)
|
||||
required_inputs: list[BlockInputFieldInfo] = Field(
|
||||
default_factory=list,
|
||||
description="List of input fields for this block",
|
||||
@@ -429,16 +459,19 @@ class BlockOutputResponse(ToolResponseBase):
|
||||
success: bool = True
|
||||
|
||||
|
||||
# Long-running operation models
|
||||
class OperationInProgressResponse(ToolResponseBase):
|
||||
"""Response when an operation is already in progress.
|
||||
class ReviewRequiredResponse(ToolResponseBase):
|
||||
"""Response when a block requires human review before execution."""
|
||||
|
||||
Returned for idempotency when the same tool_call_id is requested again
|
||||
while the background task is still running.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
|
||||
tool_call_id: str
|
||||
type: ResponseType = ResponseType.REVIEW_REQUIRED
|
||||
block_id: str
|
||||
block_name: str
|
||||
review_id: str = Field(description="The review ID for tracking approval status")
|
||||
graph_exec_id: str = Field(
|
||||
description="The graph execution ID for fetching review status"
|
||||
)
|
||||
input_data: dict[str, Any] = Field(
|
||||
description="The input data that requires review"
|
||||
)
|
||||
|
||||
|
||||
class WebFetchResponse(ToolResponseBase):
|
||||
@@ -548,6 +581,29 @@ class BrowserScreenshotResponse(ToolResponseBase):
|
||||
filename: str
|
||||
|
||||
|
||||
# Agent generation tool response models
|
||||
|
||||
|
||||
class ValidationResultResponse(ToolResponseBase):
|
||||
"""Response for validate_agent_graph tool."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_BUILDER_VALIDATION_RESULT
|
||||
valid: bool
|
||||
errors: list[str] = Field(default_factory=list)
|
||||
error_count: int = 0
|
||||
|
||||
|
||||
class FixResultResponse(ToolResponseBase):
|
||||
"""Response for fix_agent_graph tool."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_BUILDER_FIX_RESULT
|
||||
fixed_agent_json: dict[str, Any]
|
||||
fixes_applied: list[str] = Field(default_factory=list)
|
||||
fix_count: int = 0
|
||||
valid_after_fix: bool = False
|
||||
remaining_errors: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
# Folder management models
|
||||
|
||||
|
||||
|
||||
@@ -534,7 +534,9 @@ class RunAgentTool(BaseTool):
|
||||
return ExecutionStartedResponse(
|
||||
message=(
|
||||
f"Agent '{library_agent.name}' is awaiting human review. "
|
||||
f"Check at {library_agent_link}."
|
||||
f"The user can approve or reject inline. After approval, "
|
||||
f"the execution resumes automatically. Use view_agent_output "
|
||||
f"with execution_id='{execution.id}' to check the result."
|
||||
),
|
||||
session_id=session_id,
|
||||
execution_id=execution.id,
|
||||
|
||||
@@ -2,38 +2,34 @@
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.blocks import get_block
|
||||
from backend.blocks import BlockType, get_block
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
from backend.copilot.constants import (
|
||||
COPILOT_NODE_EXEC_ID_SEPARATOR,
|
||||
COPILOT_NODE_PREFIX,
|
||||
COPILOT_SESSION_PREFIX,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.data.db_accessors import review_db
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
from .base import BaseTool
|
||||
from .find_block import COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
from .helpers import get_inputs_from_schema
|
||||
from .helpers import execute_block, get_inputs_from_schema, resolve_block_credentials
|
||||
from .models import (
|
||||
BlockDetails,
|
||||
BlockDetailsResponse,
|
||||
BlockOutputResponse,
|
||||
ErrorResponse,
|
||||
InputValidationErrorResponse,
|
||||
ReviewRequiredResponse,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
from .utils import (
|
||||
build_missing_credentials_from_field_info,
|
||||
match_credentials_to_requirements,
|
||||
)
|
||||
from .utils import build_missing_credentials_from_field_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -52,7 +48,9 @@ class RunBlockTool(BaseTool):
|
||||
"IMPORTANT: You MUST call find_block first to get the block's 'id' - "
|
||||
"do NOT guess or make up block IDs. "
|
||||
"On first attempt (without input_data), returns detailed schema showing "
|
||||
"required inputs and outputs. Then call again with proper input_data to execute."
|
||||
"required inputs and outputs. Then call again with proper input_data to execute. "
|
||||
"If a block requires human review, use continue_run_block with the "
|
||||
"review_id after the user approves."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -83,7 +81,7 @@ class RunBlockTool(BaseTool):
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["block_id", "input_data"],
|
||||
"required": ["block_id", "block_name", "input_data"],
|
||||
}
|
||||
|
||||
@property
|
||||
@@ -149,21 +147,27 @@ class RunBlockTool(BaseTool):
|
||||
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
||||
):
|
||||
# Provide actionable guidance for blocks with dedicated tools
|
||||
if block.block_type == BlockType.MCP_TOOL:
|
||||
hint = (
|
||||
" Use the `run_mcp_tool` tool instead — it handles "
|
||||
"MCP server discovery, authentication, and execution."
|
||||
)
|
||||
elif block.block_type == BlockType.AGENT:
|
||||
hint = " Use the `run_agent` tool instead."
|
||||
else:
|
||||
hint = " This block is designed for use within graphs only."
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' cannot be run directly in CoPilot. "
|
||||
"This block is designed for use within graphs only."
|
||||
),
|
||||
message=f"Block '{block.name}' cannot be run directly.{hint}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
(
|
||||
matched_credentials,
|
||||
missing_credentials,
|
||||
) = await self._resolve_block_credentials(user_id, block, input_data)
|
||||
) = await resolve_block_credentials(user_id, block, input_data)
|
||||
|
||||
# Get block schemas for details/validation
|
||||
try:
|
||||
@@ -272,169 +276,97 @@ class RunBlockTool(BaseTool):
|
||||
user_authenticated=True,
|
||||
)
|
||||
|
||||
try:
|
||||
# Get or create user's workspace for CoPilot file operations
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
# Generate synthetic IDs for CoPilot context.
|
||||
# Encode node_id in node_exec_id so it can be extracted later
|
||||
# (e.g. for auto-approve, where we need node_id but have no NodeExecution row).
|
||||
synthetic_graph_id = f"{COPILOT_SESSION_PREFIX}{session.session_id}"
|
||||
synthetic_node_id = f"{COPILOT_NODE_PREFIX}{block_id}"
|
||||
|
||||
# Generate synthetic IDs for CoPilot context
|
||||
# Each chat session is treated as its own agent with one continuous run
|
||||
# This means:
|
||||
# - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
|
||||
# - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
|
||||
# - node_exec_id = unique per block execution
|
||||
synthetic_graph_id = f"copilot-session-{session.session_id}"
|
||||
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
|
||||
synthetic_node_id = f"copilot-node-{block_id}"
|
||||
synthetic_node_exec_id = (
|
||||
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
|
||||
)
|
||||
|
||||
# Create unified execution context with all required fields
|
||||
execution_context = ExecutionContext(
|
||||
# Execution identity
|
||||
user_id=user_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
graph_exec_id=synthetic_graph_exec_id,
|
||||
graph_version=1, # Versions are 1-indexed
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=synthetic_node_exec_id,
|
||||
# Workspace with session scoping
|
||||
workspace_id=workspace.id,
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
# Prepare kwargs for block execution
|
||||
# Keep individual kwargs for backwards compatibility with existing blocks
|
||||
exec_kwargs: dict[str, Any] = {
|
||||
"user_id": user_id,
|
||||
"execution_context": execution_context,
|
||||
# Legacy: individual kwargs for blocks not yet using execution_context
|
||||
"workspace_id": workspace.id,
|
||||
"graph_exec_id": synthetic_graph_exec_id,
|
||||
"node_exec_id": synthetic_node_exec_id,
|
||||
"node_id": synthetic_node_id,
|
||||
"graph_version": 1, # Versions are 1-indexed
|
||||
"graph_id": synthetic_graph_id,
|
||||
}
|
||||
|
||||
for field_name, cred_meta in matched_credentials.items():
|
||||
# Inject metadata into input_data (for validation)
|
||||
if field_name not in input_data:
|
||||
input_data[field_name] = cred_meta.model_dump()
|
||||
|
||||
# Fetch actual credentials and pass as kwargs (for execution)
|
||||
actual_credentials = await creds_manager.get(
|
||||
user_id, cred_meta.id, lock=False
|
||||
)
|
||||
if actual_credentials:
|
||||
exec_kwargs[field_name] = actual_credentials
|
||||
else:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to retrieve credentials for {field_name}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Execute the block and collect outputs
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in block.execute(
|
||||
input_data,
|
||||
**exec_kwargs,
|
||||
):
|
||||
outputs[output_name].append(output_data)
|
||||
|
||||
return BlockOutputResponse(
|
||||
message=f"Block '{block.name}' executed successfully",
|
||||
# Check for an existing WAITING review for this block with the same input.
|
||||
# If the LLM retries run_block with identical input, we reuse the existing
|
||||
# review instead of creating duplicates. Different inputs = new execution.
|
||||
existing_reviews = await review_db().get_pending_reviews_for_execution(
|
||||
synthetic_graph_id, user_id
|
||||
)
|
||||
existing_review = next(
|
||||
(
|
||||
r
|
||||
for r in existing_reviews
|
||||
if r.node_id == synthetic_node_id
|
||||
and r.status.value == "WAITING"
|
||||
and r.payload == input_data
|
||||
),
|
||||
None,
|
||||
)
|
||||
if existing_review:
|
||||
return ReviewRequiredResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' requires human review. "
|
||||
f"After the user approves, call continue_run_block with "
|
||||
f"review_id='{existing_review.node_exec_id}' to execute."
|
||||
),
|
||||
session_id=session_id,
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
outputs=dict(outputs),
|
||||
success=True,
|
||||
session_id=session_id,
|
||||
review_id=existing_review.node_exec_id,
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
input_data=input_data,
|
||||
)
|
||||
|
||||
except BlockError as e:
|
||||
logger.warning(f"Block execution failed: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Block execution failed: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to execute block: {str(e)}",
|
||||
error=str(e),
|
||||
synthetic_node_exec_id = (
|
||||
f"{synthetic_node_id}{COPILOT_NODE_EXEC_ID_SEPARATOR}"
|
||||
f"{uuid.uuid4().hex[:8]}"
|
||||
)
|
||||
|
||||
# Check for HITL review before execution.
|
||||
# This creates the review record in the DB for CoPilot flows.
|
||||
review_context = ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
graph_version=1,
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=synthetic_node_exec_id,
|
||||
sensitive_action_safe_mode=True,
|
||||
)
|
||||
should_pause, input_data = await block.is_block_exec_need_review(
|
||||
input_data,
|
||||
user_id=user_id,
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=synthetic_node_exec_id,
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
graph_version=1,
|
||||
execution_context=review_context,
|
||||
is_graph_execution=False,
|
||||
)
|
||||
if should_pause:
|
||||
return ReviewRequiredResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' requires human review. "
|
||||
f"After the user approves, call continue_run_block with "
|
||||
f"review_id='{synthetic_node_exec_id}' to execute."
|
||||
),
|
||||
session_id=session_id,
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
review_id=synthetic_node_exec_id,
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
input_data=input_data,
|
||||
)
|
||||
|
||||
async def _resolve_block_credentials(
|
||||
self,
|
||||
user_id: str,
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any] | None = None,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Resolve credentials for a block by matching user's available credentials.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
block: Block to resolve credentials for
|
||||
input_data: Input data for the block (used to determine provider via discriminator)
|
||||
|
||||
Returns:
|
||||
tuple of (matched_credentials, missing_credentials) - matched credentials
|
||||
are used for block execution, missing ones indicate setup requirements.
|
||||
"""
|
||||
input_data = input_data or {}
|
||||
requirements = self._resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
if not requirements:
|
||||
return {}, []
|
||||
|
||||
return await match_credentials_to_requirements(user_id, requirements)
|
||||
return await execute_block(
|
||||
block=block,
|
||||
block_id=block_id,
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
node_exec_id=synthetic_node_exec_id,
|
||||
matched_credentials=matched_credentials,
|
||||
)
|
||||
|
||||
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
|
||||
"""Extract non-credential inputs from block schema."""
|
||||
schema = block.input_schema.jsonschema()
|
||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
|
||||
|
||||
def _resolve_discriminated_credentials(
|
||||
self,
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any],
|
||||
) -> dict[str, CredentialsFieldInfo]:
|
||||
"""Resolve credential requirements, applying discriminator logic where needed."""
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
if not credentials_fields_info:
|
||||
return {}
|
||||
|
||||
resolved: dict[str, CredentialsFieldInfo] = {}
|
||||
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
effective_field_info = field_info
|
||||
|
||||
if field_info.discriminator and field_info.discriminator_mapping:
|
||||
discriminator_value = input_data.get(field_info.discriminator)
|
||||
if discriminator_value is None:
|
||||
field = block.input_schema.model_fields.get(
|
||||
field_info.discriminator
|
||||
)
|
||||
if field and field.default is not PydanticUndefined:
|
||||
discriminator_value = field.default
|
||||
|
||||
if (
|
||||
discriminator_value
|
||||
and discriminator_value in field_info.discriminator_mapping
|
||||
):
|
||||
effective_field_info = field_info.discriminate(discriminator_value)
|
||||
# For host-scoped credentials, add the discriminator value
|
||||
# (e.g., URL) so _credential_is_for_host can match it
|
||||
effective_field_info.discriminator_values.add(discriminator_value)
|
||||
logger.debug(
|
||||
f"Discriminated provider for {field_name}: "
|
||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
||||
)
|
||||
|
||||
resolved[field_name] = effective_field_info
|
||||
|
||||
return resolved
|
||||
|
||||
@@ -12,6 +12,7 @@ from .models import (
|
||||
BlockOutputResponse,
|
||||
ErrorResponse,
|
||||
InputValidationErrorResponse,
|
||||
ReviewRequiredResponse,
|
||||
)
|
||||
from .run_block import RunBlockTool
|
||||
|
||||
@@ -27,9 +28,16 @@ def make_mock_block(
|
||||
mock.name = name
|
||||
mock.block_type = block_type
|
||||
mock.disabled = disabled
|
||||
mock.is_sensitive_action = False
|
||||
mock.input_schema = MagicMock()
|
||||
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
||||
mock.input_schema.get_credentials_fields_info.return_value = []
|
||||
mock.input_schema.get_credentials_fields_info.return_value = {}
|
||||
mock.input_schema.get_credentials_fields.return_value = {}
|
||||
|
||||
async def _no_review(input_data, **kwargs):
|
||||
return False, input_data
|
||||
|
||||
mock.is_block_exec_need_review = _no_review
|
||||
return mock
|
||||
|
||||
|
||||
@@ -46,6 +54,7 @@ def make_mock_block_with_schema(
|
||||
mock.name = name
|
||||
mock.block_type = BlockType.STANDARD
|
||||
mock.disabled = False
|
||||
mock.is_sensitive_action = False
|
||||
mock.description = f"Test block: {name}"
|
||||
|
||||
input_schema = {
|
||||
@@ -63,6 +72,12 @@ def make_mock_block_with_schema(
|
||||
mock.output_schema = MagicMock()
|
||||
mock.output_schema.jsonschema.return_value = output_schema
|
||||
|
||||
# Default: no review needed, pass through input_data unchanged
|
||||
async def _no_review(input_data, **kwargs):
|
||||
return False, input_data
|
||||
|
||||
mock.is_block_exec_need_review = _no_review
|
||||
|
||||
return mock
|
||||
|
||||
|
||||
@@ -89,7 +104,7 @@ class TestRunBlockFiltering:
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "cannot be run directly in CoPilot" in response.message
|
||||
assert "cannot be run directly" in response.message
|
||||
assert "designed for use within graphs only" in response.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@@ -115,7 +130,7 @@ class TestRunBlockFiltering:
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "cannot be run directly in CoPilot" in response.message
|
||||
assert "cannot be run directly" in response.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_non_excluded_block_passes_guard(self):
|
||||
@@ -126,9 +141,15 @@ class TestRunBlockFiltering:
|
||||
"standard-id", "HTTP Request", BlockType.STANDARD
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=standard_block,
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=standard_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
@@ -141,7 +162,7 @@ class TestRunBlockFiltering:
|
||||
# Should NOT be an ErrorResponse about CoPilot exclusion
|
||||
# (may be other errors like missing credentials, but not the exclusion guard)
|
||||
if isinstance(response, ErrorResponse):
|
||||
assert "cannot be run directly in CoPilot" not in response.message
|
||||
assert "cannot be run directly" not in response.message
|
||||
|
||||
|
||||
class TestRunBlockInputValidation:
|
||||
@@ -154,12 +175,7 @@ class TestRunBlockInputValidation:
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unknown_input_fields_are_rejected(self):
|
||||
"""run_block rejects unknown input fields instead of silently ignoring them.
|
||||
|
||||
Scenario: The AI Text Generator block has a field called 'model' (for LLM model
|
||||
selection), but the LLM calling the tool guesses wrong and sends 'LLM_Model'
|
||||
instead. The block should reject the request and return the valid schema.
|
||||
"""
|
||||
"""run_block rejects unknown input fields instead of silently ignoring them."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
mock_block = make_mock_block_with_schema(
|
||||
@@ -182,27 +198,31 @@ class TestRunBlockInputValidation:
|
||||
output_properties={"response": {"type": "string"}},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
|
||||
# Provide 'prompt' (correct) but 'LLM_Model' instead of 'model' (wrong key)
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="ai-text-gen-id",
|
||||
input_data={
|
||||
"prompt": "Write a haiku about coding",
|
||||
"LLM_Model": "claude-opus-4-6", # WRONG KEY - should be 'model'
|
||||
"LLM_Model": "claude-opus-4-6",
|
||||
},
|
||||
)
|
||||
|
||||
assert isinstance(response, InputValidationErrorResponse)
|
||||
assert "LLM_Model" in response.unrecognized_fields
|
||||
assert "Block was not executed" in response.message
|
||||
assert "inputs" in response.model_dump() # valid schema included
|
||||
assert "inputs" in response.model_dump()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_multiple_wrong_keys_are_all_reported(self):
|
||||
@@ -221,21 +241,26 @@ class TestRunBlockInputValidation:
|
||||
required_fields=["prompt"],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="ai-text-gen-id",
|
||||
input_data={
|
||||
"prompt": "Hello", # correct
|
||||
"llm_model": "claude-opus-4-6", # WRONG - should be 'model'
|
||||
"system_prompt": "Be helpful", # WRONG - should be 'sys_prompt'
|
||||
"retries": 5, # WRONG - should be 'retry'
|
||||
"prompt": "Hello",
|
||||
"llm_model": "claude-opus-4-6",
|
||||
"system_prompt": "Be helpful",
|
||||
"retries": 5,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -262,23 +287,26 @@ class TestRunBlockInputValidation:
|
||||
required_fields=["prompt"],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
|
||||
# 'prompt' is missing AND 'LLM_Model' is an unknown field
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="ai-text-gen-id",
|
||||
input_data={
|
||||
"LLM_Model": "claude-opus-4-6", # wrong key, and 'prompt' is missing
|
||||
"LLM_Model": "claude-opus-4-6",
|
||||
},
|
||||
)
|
||||
|
||||
# Unknown fields are caught first
|
||||
assert isinstance(response, InputValidationErrorResponse)
|
||||
assert "LLM_Model" in response.unrecognized_fields
|
||||
|
||||
@@ -313,7 +341,11 @@ class TestRunBlockInputValidation:
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.workspace_db",
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
),
|
||||
):
|
||||
@@ -325,7 +357,7 @@ class TestRunBlockInputValidation:
|
||||
block_id="ai-text-gen-id",
|
||||
input_data={
|
||||
"prompt": "Write a haiku",
|
||||
"model": "gpt-4o-mini", # correct field name
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -347,20 +379,191 @@ class TestRunBlockInputValidation:
|
||||
required_fields=["prompt"],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
|
||||
# Only provide valid optional field, missing required 'prompt'
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="ai-text-gen-id",
|
||||
input_data={
|
||||
"model": "gpt-4o-mini", # valid but optional
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockDetailsResponse)
|
||||
|
||||
|
||||
class TestRunBlockSensitiveAction:
|
||||
"""Tests for sensitive action HITL review in RunBlockTool.
|
||||
|
||||
run_block calls is_block_exec_need_review() explicitly before execution.
|
||||
When review is needed (should_pause=True), ReviewRequiredResponse is returned.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_sensitive_block_paused_returns_review_required(self):
|
||||
"""When is_block_exec_need_review returns should_pause=True, ReviewRequiredResponse is returned."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
input_data = {
|
||||
"repo_url": "https://github.com/test/repo",
|
||||
"branch": "feature-branch",
|
||||
}
|
||||
mock_block = make_mock_block_with_schema(
|
||||
block_id="delete-branch-id",
|
||||
name="Delete Branch",
|
||||
input_properties={
|
||||
"repo_url": {"type": "string"},
|
||||
"branch": {"type": "string"},
|
||||
},
|
||||
required_fields=["repo_url", "branch"],
|
||||
)
|
||||
mock_block.is_sensitive_action = True
|
||||
mock_block.is_block_exec_need_review = AsyncMock(
|
||||
return_value=(True, input_data)
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="delete-branch-id",
|
||||
input_data=input_data,
|
||||
)
|
||||
|
||||
assert isinstance(response, ReviewRequiredResponse)
|
||||
assert "requires human review" in response.message
|
||||
assert "continue_run_block" in response.message
|
||||
assert response.block_name == "Delete Branch"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_sensitive_block_executes_after_approval(self):
|
||||
"""After approval (should_pause=False), sensitive blocks execute and return outputs."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
input_data = {
|
||||
"repo_url": "https://github.com/test/repo",
|
||||
"branch": "feature-branch",
|
||||
}
|
||||
mock_block = make_mock_block_with_schema(
|
||||
block_id="delete-branch-id",
|
||||
name="Delete Branch",
|
||||
input_properties={
|
||||
"repo_url": {"type": "string"},
|
||||
"branch": {"type": "string"},
|
||||
},
|
||||
required_fields=["repo_url", "branch"],
|
||||
)
|
||||
mock_block.is_sensitive_action = True
|
||||
mock_block.is_block_exec_need_review = AsyncMock(
|
||||
return_value=(False, input_data)
|
||||
)
|
||||
|
||||
async def mock_execute(input_data, **kwargs):
|
||||
yield "result", "Branch deleted successfully"
|
||||
|
||||
mock_block.execute = mock_execute
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="test-workspace-id")
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
),
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="delete-branch-id",
|
||||
input_data=input_data,
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert response.success is True
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_non_sensitive_block_executes_normally(self):
|
||||
"""Non-sensitive blocks skip review and execute directly."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
input_data = {"url": "https://example.com"}
|
||||
mock_block = make_mock_block_with_schema(
|
||||
block_id="http-request-id",
|
||||
name="HTTP Request",
|
||||
input_properties={
|
||||
"url": {"type": "string"},
|
||||
},
|
||||
required_fields=["url"],
|
||||
)
|
||||
mock_block.is_sensitive_action = False
|
||||
mock_block.is_block_exec_need_review = AsyncMock(
|
||||
return_value=(False, input_data)
|
||||
)
|
||||
|
||||
async def mock_execute(input_data, **kwargs):
|
||||
yield "response", {"status": 200}
|
||||
|
||||
mock_block.execute = mock_execute
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="test-workspace-id")
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
),
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="http-request-id",
|
||||
input_data=input_data,
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert response.success is True
|
||||
|
||||
@@ -14,7 +14,7 @@ from backend.blocks.mcp.helpers import (
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.utils import build_missing_credentials_from_field_info
|
||||
from backend.util.request import HTTPClientError, validate_url
|
||||
from backend.util.request import HTTPClientError, validate_url_host
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
@@ -53,15 +53,9 @@ class RunMCPToolTool(BaseTool):
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Connect to an MCP (Model Context Protocol) server to discover and execute its tools. "
|
||||
"Two-step workflow: (1) Call with just `server_url` to discover available tools. "
|
||||
"(2) Call again with `server_url`, `tool_name`, and `tool_arguments` to execute. "
|
||||
"Known hosted servers (use directly): Notion (https://mcp.notion.com/mcp), "
|
||||
"Linear (https://mcp.linear.app/mcp), Stripe (https://mcp.stripe.com), "
|
||||
"Intercom (https://mcp.intercom.com/mcp), Cloudflare (https://mcp.cloudflare.com/mcp), "
|
||||
"Atlassian/Jira (https://mcp.atlassian.com/mcp). "
|
||||
"For other services, search the MCP registry at https://registry.modelcontextprotocol.io/. "
|
||||
"Authentication: If the server requires credentials, user will be prompted to complete the MCP credential setup flow."
|
||||
"Once connected and user confirms, retry the same call immediately."
|
||||
"Two-step: (1) call with server_url to list available tools, "
|
||||
"(2) call again with server_url + tool_name + tool_arguments to execute. "
|
||||
"Call get_mcp_guide for known server URLs and auth details."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -150,7 +144,7 @@ class RunMCPToolTool(BaseTool):
|
||||
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges
|
||||
try:
|
||||
await validate_url(server_url, trusted_origins=[])
|
||||
await validate_url_host(server_url)
|
||||
except ValueError as e:
|
||||
msg = str(e)
|
||||
if "Unable to resolve" in msg or "No IP addresses" in msg:
|
||||
|
||||
@@ -65,9 +65,8 @@ async def test_run_block_returns_details_when_no_input_provided():
|
||||
return_value=http_block,
|
||||
):
|
||||
# Mock credentials check to return no missing credentials
|
||||
with patch.object(
|
||||
RunBlockTool,
|
||||
"_resolve_block_credentials",
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.resolve_block_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=({}, []), # (matched_credentials, missing_credentials)
|
||||
):
|
||||
@@ -123,9 +122,8 @@ async def test_run_block_returns_details_when_only_credentials_provided():
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock,
|
||||
):
|
||||
with patch.object(
|
||||
RunBlockTool,
|
||||
"_resolve_block_credentials",
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.resolve_block_credentials",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(
|
||||
{
|
||||
|
||||
@@ -100,7 +100,7 @@ async def test_ssrf_blocked_url_returns_error():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url",
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked loopback"),
|
||||
):
|
||||
@@ -138,7 +138,7 @@ async def test_non_dict_tool_arguments_returns_error():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url",
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
with patch(
|
||||
@@ -171,7 +171,7 @@ async def test_discover_tools_returns_discovered_response():
|
||||
mock_tools = _make_tool_list("fetch", "search")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -208,7 +208,7 @@ async def test_discover_tools_with_credentials():
|
||||
mock_tools = _make_tool_list("push_notification")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -249,7 +249,7 @@ async def test_execute_tool_returns_output_response():
|
||||
text_result = "# Example Domain\nThis domain is for examples."
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -285,7 +285,7 @@ async def test_execute_tool_parses_json_result():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -320,7 +320,7 @@ async def test_execute_tool_image_content():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -359,7 +359,7 @@ async def test_execute_tool_resource_content():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -399,7 +399,7 @@ async def test_execute_tool_multi_item_content():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -437,7 +437,7 @@ async def test_execute_tool_empty_content_returns_none():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -470,7 +470,7 @@ async def test_execute_tool_returns_error_on_tool_failure():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -512,7 +512,7 @@ async def test_auth_required_without_creds_returns_setup_requirements():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -555,7 +555,7 @@ async def test_auth_error_with_existing_creds_returns_error():
|
||||
mock_creds.access_token = SecretStr("stale-token")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -589,7 +589,7 @@ async def test_mcp_client_error_returns_error_response():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -621,7 +621,7 @@ async def test_unexpected_exception_returns_generic_error():
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
@@ -719,7 +719,7 @@ async def test_credential_lookup_normalizes_trailing_slash():
|
||||
url_with_slash = "https://mcp.example.com/mcp/"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url_host", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
|
||||
116
autogpt_platform/backend/backend/copilot/tools/validate_agent.py
Normal file
116
autogpt_platform/backend/backend/copilot/tools/validate_agent.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""ValidateAgentGraphTool - Validates agent JSON structure."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .agent_generator.validation import AgentValidator, get_blocks_as_dicts
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ToolResponseBase, ValidationResultResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ValidateAgentGraphTool(BaseTool):
|
||||
"""Tool for validating agent JSON graphs."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "validate_agent_graph"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Validate an agent JSON graph for correctness. Checks:\n"
|
||||
"- All block_ids reference real blocks\n"
|
||||
"- All links reference valid source/sink nodes and fields\n"
|
||||
"- Required input fields are wired or have defaults\n"
|
||||
"- Data types are compatible across links\n"
|
||||
"- Nested sink links use correct notation\n"
|
||||
"- Prompt templates use proper curly brace escaping\n"
|
||||
"- AgentExecutorBlock configurations are valid\n\n"
|
||||
"Call this after generating agent JSON to verify correctness. "
|
||||
"If validation fails, either fix issues manually based on the error "
|
||||
"descriptions, or call fix_agent_graph to auto-fix common problems."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_json": {
|
||||
"type": "object",
|
||||
"description": (
|
||||
"The agent JSON to validate. Must contain 'nodes' and 'links' arrays. "
|
||||
"Each node needs: id (UUID), block_id, input_default, metadata. "
|
||||
"Each link needs: id (UUID), source_id, source_name, sink_id, sink_name."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["agent_json"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
agent_json = kwargs.get("agent_json")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_json or not isinstance(agent_json, dict):
|
||||
return ErrorResponse(
|
||||
message="Please provide a valid agent JSON object.",
|
||||
error="Missing or invalid agent_json parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
nodes = agent_json.get("nodes", [])
|
||||
|
||||
if not nodes:
|
||||
return ErrorResponse(
|
||||
message="The agent JSON has no nodes. An agent needs at least one block.",
|
||||
error="empty_agent",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
blocks = get_blocks_as_dicts()
|
||||
validator = AgentValidator()
|
||||
is_valid, error_message = validator.validate(agent_json, blocks)
|
||||
except Exception as e:
|
||||
logger.error(f"Validation error: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Validation encountered an error: {str(e)}",
|
||||
error="validation_exception",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if is_valid:
|
||||
return ValidationResultResponse(
|
||||
message="Agent graph is valid! No issues found.",
|
||||
valid=True,
|
||||
errors=[],
|
||||
error_count=0,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Parse individual errors from the validator's error list
|
||||
errors = validator.errors if hasattr(validator, "errors") else []
|
||||
if not errors and error_message:
|
||||
errors = [error_message]
|
||||
|
||||
return ValidationResultResponse(
|
||||
message=f"Found {len(errors)} validation error(s). Fix them and re-validate.",
|
||||
valid=False,
|
||||
errors=errors,
|
||||
error_count=len(errors),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -0,0 +1,160 @@
|
||||
"""Tests for ValidateAgentGraphTool."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.tools.models import ErrorResponse, ValidationResultResponse
|
||||
from backend.copilot.tools.validate_agent import ValidateAgentGraphTool
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
_TEST_USER_ID = "test-user-validate-agent"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return ValidateAgentGraphTool()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session():
|
||||
return make_session(_TEST_USER_ID)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_agent_json_returns_error(tool, session):
|
||||
"""Missing agent_json returns ErrorResponse."""
|
||||
result = await tool._execute(user_id=_TEST_USER_ID, session=session)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error is not None
|
||||
assert "agent_json" in result.error.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_nodes_returns_error(tool, session):
|
||||
"""Agent JSON with no nodes returns ErrorResponse."""
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json={"nodes": [], "links": []},
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "no nodes" in result.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_agent_returns_success(tool, session):
|
||||
"""Valid agent returns ValidationResultResponse with valid=True."""
|
||||
agent_json = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"block_id": "block-1",
|
||||
"input_default": {},
|
||||
"metadata": {"position": {"x": 0, "y": 0}},
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
}
|
||||
|
||||
mock_validator = MagicMock()
|
||||
mock_validator.validate.return_value = (True, None)
|
||||
mock_validator.errors = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.validate_agent.get_blocks_as_dicts",
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.validate_agent.AgentValidator",
|
||||
return_value=mock_validator,
|
||||
),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json=agent_json,
|
||||
)
|
||||
|
||||
assert isinstance(result, ValidationResultResponse)
|
||||
assert result.valid is True
|
||||
assert result.error_count == 0
|
||||
assert result.errors == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_agent_returns_errors(tool, session):
|
||||
"""Invalid agent returns ValidationResultResponse with errors."""
|
||||
agent_json = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"block_id": "nonexistent-block",
|
||||
"input_default": {},
|
||||
"metadata": {},
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
}
|
||||
|
||||
mock_validator = MagicMock()
|
||||
mock_validator.validate.return_value = (False, "Validation failed")
|
||||
mock_validator.errors = [
|
||||
"Block 'nonexistent-block' not found in registry",
|
||||
"Missing required input field 'prompt'",
|
||||
]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.validate_agent.get_blocks_as_dicts",
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.validate_agent.AgentValidator",
|
||||
return_value=mock_validator,
|
||||
),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json=agent_json,
|
||||
)
|
||||
|
||||
assert isinstance(result, ValidationResultResponse)
|
||||
assert result.valid is False
|
||||
assert result.error_count == 2
|
||||
assert len(result.errors) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_exception_returns_error(tool, session):
|
||||
"""Validator exception returns ErrorResponse."""
|
||||
agent_json = {
|
||||
"nodes": [{"id": "n1", "block_id": "b1", "input_default": {}, "metadata": {}}],
|
||||
"links": [],
|
||||
}
|
||||
|
||||
mock_validator = MagicMock()
|
||||
mock_validator.validate.side_effect = RuntimeError("unexpected error")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.validate_agent.get_blocks_as_dicts",
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.validate_agent.AgentValidator",
|
||||
return_value=mock_validator,
|
||||
),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json=agent_json,
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error is not None
|
||||
assert "validation_exception" in result.error
|
||||
@@ -7,8 +7,12 @@ from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.context import (
|
||||
E2B_WORKDIR,
|
||||
get_current_sandbox,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
|
||||
from backend.copilot.tools.sandbox import make_session_path
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.settings import Config
|
||||
@@ -83,13 +87,11 @@ def _resolve_sandbox_path(
|
||||
) -> str | ErrorResponse:
|
||||
"""Normalize *path* to an absolute sandbox path under :data:`E2B_WORKDIR`.
|
||||
|
||||
Delegates to :func:`~backend.copilot.sdk.e2b_file_tools._resolve_remote`
|
||||
Delegates to :func:`~backend.copilot.sdk.e2b_file_tools.resolve_sandbox_path`
|
||||
and wraps any ``ValueError`` into an :class:`ErrorResponse`.
|
||||
"""
|
||||
from backend.copilot.sdk.e2b_file_tools import _resolve_remote
|
||||
|
||||
try:
|
||||
return _resolve_remote(path)
|
||||
return resolve_sandbox_path(path)
|
||||
except ValueError:
|
||||
return ErrorResponse(
|
||||
message=f"{param_name} must be within {E2B_WORKDIR}",
|
||||
@@ -99,7 +101,6 @@ def _resolve_sandbox_path(
|
||||
|
||||
async def _read_source_path(source_path: str, session_id: str) -> bytes | ErrorResponse:
|
||||
"""Read *source_path* from E2B sandbox or local ephemeral directory."""
|
||||
from backend.copilot.sdk.tool_adapter import get_current_sandbox
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
@@ -143,7 +144,6 @@ async def _save_to_path(
|
||||
|
||||
Returns the resolved path on success, or an ``ErrorResponse`` on failure.
|
||||
"""
|
||||
from backend.copilot.sdk.tool_adapter import get_current_sandbox
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user