mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
2 Commits
master
...
swiftyos/c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7d6338c1fc | ||
|
|
ea69536165 |
@@ -1,17 +0,0 @@
|
||||
---
|
||||
name: backend-check
|
||||
description: Run the full backend formatting, linting, and test suite. Ensures code quality before commits and PRs. TRIGGER when backend Python code has been modified and needs validation.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Backend Check
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Format**: `poetry run format` — runs formatting AND linting. NEVER run ruff/black/isort individually
|
||||
2. **Fix** any remaining errors manually, re-run until clean
|
||||
3. **Test**: `poetry run test` (runs DB setup + pytest). For specific files: `poetry run pytest -s -vvv <test_files>`
|
||||
4. **Snapshots** (if needed): `poetry run pytest path/to/test.py --snapshot-update` — review with `git diff`
|
||||
@@ -1,35 +0,0 @@
|
||||
---
|
||||
name: code-style
|
||||
description: Python code style preferences for the AutoGPT backend. Apply when writing or reviewing Python code. TRIGGER when writing new Python code, reviewing PRs, or refactoring backend code.
|
||||
user-invocable: false
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Code Style
|
||||
|
||||
## Imports
|
||||
|
||||
- **Top-level only** — no local/inner imports. Move all imports to the top of the file.
|
||||
|
||||
## Typing
|
||||
|
||||
- **No duck typing** — avoid `hasattr`, `getattr`, `isinstance` for type dispatch. Use proper typed interfaces, unions, or protocols.
|
||||
- **Pydantic models** over dataclass, namedtuple, or raw dict for structured data.
|
||||
- **No linter suppressors** — avoid `# type: ignore`, `# noqa`, `# pyright: ignore` etc. 99% of the time the right fix is fixing the type/code, not silencing the tool.
|
||||
|
||||
## Code Structure
|
||||
|
||||
- **List comprehensions** over manual loop-and-append.
|
||||
- **Early return** — guard clauses first, avoid deep nesting.
|
||||
- **Flatten inline** — prefer short, concise expressions. Reduce `if/else` chains with direct returns or ternaries when readable.
|
||||
- **Modular functions** — break complex logic into small, focused functions rather than long blocks with nested conditionals.
|
||||
|
||||
## Review Checklist
|
||||
|
||||
Before finishing, always ask:
|
||||
- Can any function be split into smaller pieces?
|
||||
- Is there unnecessary nesting that an early return would eliminate?
|
||||
- Can any loop be a comprehension?
|
||||
- Is there a simpler way to express this logic?
|
||||
@@ -1,16 +0,0 @@
|
||||
---
|
||||
name: frontend-check
|
||||
description: Run the full frontend formatting, linting, and type checking suite. Ensures code quality before commits and PRs. TRIGGER when frontend TypeScript/React code has been modified and needs validation.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Frontend Check
|
||||
|
||||
## Steps (in order)
|
||||
|
||||
1. **Format**: `pnpm format` — NEVER run individual formatters
|
||||
2. **Lint**: `pnpm lint` — fix errors, re-run until clean
|
||||
3. **Types**: `pnpm types` — if it keeps failing after multiple attempts, stop and ask the user
|
||||
@@ -1,29 +0,0 @@
|
||||
---
|
||||
name: new-block
|
||||
description: Create a new backend block following the Block SDK Guide. Guides through provider configuration, schema definition, authentication, and testing. TRIGGER when user asks to create a new block, add a new integration, or build a new node for the graph editor.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# New Block Creation
|
||||
|
||||
Read `docs/platform/block-sdk-guide.md` first for the full guide.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Provider config** (if external service): create `_config.py` with `ProviderBuilder`
|
||||
2. **Block file** in `backend/blocks/` (from `autogpt_platform/backend/`):
|
||||
- Generate a UUID once with `uuid.uuid4()`, then **hard-code that string** as `id` (IDs must be stable across imports)
|
||||
- `Input(BlockSchema)` and `Output(BlockSchema)` classes
|
||||
- `async def run` that `yield`s output fields
|
||||
3. **Files**: use `store_media_file()` with `"for_block_output"` for outputs
|
||||
4. **Test**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[MyBlock]' -xvs`
|
||||
5. **Format**: `poetry run format`
|
||||
|
||||
## Rules
|
||||
|
||||
- Analyze interfaces: do inputs/outputs connect well with other blocks in a graph?
|
||||
- Use top-level imports, avoid duck typing
|
||||
- Always use `for_block_output` for block outputs
|
||||
@@ -1,28 +0,0 @@
|
||||
---
|
||||
name: openapi-regen
|
||||
description: Regenerate the OpenAPI spec and frontend API client. Starts the backend REST server, fetches the spec, and regenerates the typed frontend hooks. TRIGGER when API routes change, new endpoints are added, or frontend API types are stale.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# OpenAPI Spec Regeneration
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Run end-to-end** in a single shell block (so `REST_PID` persists):
|
||||
```bash
|
||||
cd autogpt_platform/backend && poetry run rest &
|
||||
REST_PID=$!
|
||||
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && kill $REST_PID && exit 1; done
|
||||
cd ../frontend && pnpm generate:api:force
|
||||
kill $REST_PID
|
||||
pnpm types && pnpm lint && pnpm format
|
||||
```
|
||||
|
||||
## Rules
|
||||
|
||||
- Always use `pnpm generate:api:force` (not `pnpm generate:api`)
|
||||
- Don't manually edit files in `src/app/api/__generated__/`
|
||||
- Generated hooks follow: `use{Method}{Version}{OperationName}`
|
||||
@@ -1,31 +0,0 @@
|
||||
---
|
||||
name: pr-create
|
||||
description: Create a pull request for the current branch. TRIGGER when user asks to create a PR, open a pull request, push changes for review, or submit work for merging.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Create Pull Request
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Check for existing PR**: `gh pr view --json url -q .url 2>/dev/null` — if a PR already exists, output its URL and stop
|
||||
2. **Understand changes**: `git status`, `git diff dev...HEAD`, `git log dev..HEAD --oneline`
|
||||
3. **Read PR template**: `.github/PULL_REQUEST_TEMPLATE.md`
|
||||
4. **Draft PR title**: Use conventional commits format (see CLAUDE.md for types and scopes)
|
||||
5. **Fill out PR template** as the body — be thorough in the Changes section
|
||||
6. **Format first** (if relevant changes exist):
|
||||
- Backend: `cd autogpt_platform/backend && poetry run format`
|
||||
- Frontend: `cd autogpt_platform/frontend && pnpm format`
|
||||
- Fix any lint errors, then commit formatting changes before pushing
|
||||
7. **Push**: `git push -u origin HEAD`
|
||||
8. **Create PR**: `gh pr create --base dev`
|
||||
9. **Output** the PR URL
|
||||
|
||||
## Rules
|
||||
|
||||
- Always target `dev` branch
|
||||
- Do NOT run tests — CI will handle that
|
||||
- Use the PR template from `.github/PULL_REQUEST_TEMPLATE.md`
|
||||
@@ -1,51 +0,0 @@
|
||||
---
|
||||
name: pr-review
|
||||
description: Address all open PR review comments systematically. Fetches comments, addresses each one, reacts +1/-1, and replies when clarification is needed. Keeps iterating until all comments are addressed and CI is green. TRIGGER when user shares a PR URL, asks to address review comments, fix PR feedback, or respond to reviewer comments.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# PR Review Comment Workflow
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Find PR**: `gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT`
|
||||
2. **Fetch comments** (all three sources):
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` (top-level reviews)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` (inline review comments)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` (PR conversation comments)
|
||||
3. **Skip** comments already reacted to by PR author
|
||||
4. **For each unreacted comment**:
|
||||
- Read referenced code, make the fix (or reply if you disagree/need info)
|
||||
- **Inline review comments** (`pulls/{N}/comments`):
|
||||
- React: `gh api repos/.../pulls/comments/{ID}/reactions -f content="+1"` (or `-1`)
|
||||
- Reply: `gh api repos/.../pulls/{N}/comments/{ID}/replies -f body="..."`
|
||||
- **PR conversation comments** (`issues/{N}/comments`):
|
||||
- React: `gh api repos/.../issues/comments/{ID}/reactions -f content="+1"` (or `-1`)
|
||||
- No threaded replies — post a new issue comment if needed
|
||||
- **Top-level reviews**: no reaction API — address in code, reply via issue comment if needed
|
||||
5. **Include autogpt-reviewer bot fixes** too
|
||||
6. **Format**: `cd autogpt_platform/backend && poetry run format`, `cd autogpt_platform/frontend && pnpm format`
|
||||
7. **Commit & push**
|
||||
8. **Re-fetch comments** immediately — address any new unreacted ones before waiting on CI
|
||||
9. **Stay productive while CI runs** — don't idle. In priority order:
|
||||
- Run any pending local tests (`poetry run pytest`, e2e, etc.) and fix failures
|
||||
- Address any remaining comments
|
||||
- Only poll `gh pr checks {N}` as the last resort when there's truly nothing left to do
|
||||
10. **If CI fails** — fix, go back to step 6
|
||||
11. **Re-fetch comments again** after CI is green — address anything that appeared while CI was running
|
||||
12. **Done** only when: all comments reacted AND CI is green.
|
||||
|
||||
## CRITICAL: Do Not Stop
|
||||
|
||||
**Loop is: address → format → commit → push → re-check comments → run local tests → wait CI → re-check comments → repeat.**
|
||||
|
||||
Never idle. If CI is running and you have nothing to address, run local tests. Waiting on CI is the last resort.
|
||||
|
||||
## Rules
|
||||
|
||||
- One todo per comment
|
||||
- For inline review comments: reply on existing threads. For PR conversation comments: post a new issue comment (API doesn't support threaded replies)
|
||||
- React to every comment: +1 addressed, -1 disagreed (with explanation)
|
||||
@@ -1,45 +0,0 @@
|
||||
---
|
||||
name: worktree-setup
|
||||
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, generates Prisma client, and optionally starts the app (with port conflict resolution) or runs tests. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Worktree Setup
|
||||
|
||||
## Preferred: Use Branchlet
|
||||
|
||||
The repo has a `.branchlet.json` config — it handles env file copying, dependency installation, and Prisma generation automatically.
|
||||
|
||||
```bash
|
||||
npm install -g branchlet # install once
|
||||
branchlet create -n <name> -s <source-branch> -b <new-branch>
|
||||
branchlet list --json # list all worktrees
|
||||
```
|
||||
|
||||
## Manual Fallback
|
||||
|
||||
If branchlet isn't available:
|
||||
|
||||
1. `git worktree add ../<RepoName><N> <branch-name>`
|
||||
2. Copy `.env` files: `backend/.env`, `frontend/.env`, `autogpt_platform/.env`, `db/docker/.env`
|
||||
3. Install deps:
|
||||
- `cd autogpt_platform/backend && poetry install && poetry run prisma generate`
|
||||
- `cd autogpt_platform/frontend && pnpm install`
|
||||
|
||||
## Running the App
|
||||
|
||||
Free ports first — backend uses: 8001, 8002, 8003, 8005, 8006, 8007, 8008.
|
||||
|
||||
```bash
|
||||
for port in 8001 8002 8003 8005 8006 8007 8008; do
|
||||
lsof -ti :$port | xargs kill -9 2>/dev/null || true
|
||||
done
|
||||
cd <worktree>/autogpt_platform/backend && poetry run app
|
||||
```
|
||||
|
||||
## CoPilot Testing Gotcha
|
||||
|
||||
SDK mode spawns a Claude subprocess — **won't work inside Claude Code**. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.
|
||||
@@ -9,6 +9,21 @@ import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_id() -> str:
|
||||
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user_id() -> str:
|
||||
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_test_user(test_user_id: str) -> str:
|
||||
return test_user_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def configured_snapshot(snapshot: Snapshot) -> Snapshot:
|
||||
"""Pre-configured snapshot fixture with standard settings."""
|
||||
|
||||
@@ -28,7 +28,6 @@ from backend.copilot.model import (
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
from backend.copilot.tools.models import (
|
||||
AgentDetailsResponse,
|
||||
AgentOutputResponse,
|
||||
@@ -53,7 +52,6 @@ from backend.copilot.tools.models import (
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
@@ -128,7 +126,6 @@ class SessionSummaryResponse(BaseModel):
|
||||
created_at: str
|
||||
updated_at: str
|
||||
title: str | None = None
|
||||
is_processing: bool
|
||||
|
||||
|
||||
class ListSessionsResponse(BaseModel):
|
||||
@@ -187,28 +184,6 @@ async def list_sessions(
|
||||
"""
|
||||
sessions, total_count = await get_user_sessions(user_id, limit, offset)
|
||||
|
||||
# Batch-check Redis for active stream status on each session
|
||||
processing_set: set[str] = set()
|
||||
if sessions:
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
pipe = redis.pipeline(transaction=False)
|
||||
for session in sessions:
|
||||
pipe.hget(
|
||||
f"{config.session_meta_prefix}{session.session_id}",
|
||||
"status",
|
||||
)
|
||||
statuses = await pipe.execute()
|
||||
processing_set = {
|
||||
session.session_id
|
||||
for session, st in zip(sessions, statuses)
|
||||
if st == "running"
|
||||
}
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch processing status from Redis; " "defaulting to empty"
|
||||
)
|
||||
|
||||
return ListSessionsResponse(
|
||||
sessions=[
|
||||
SessionSummaryResponse(
|
||||
@@ -216,7 +191,6 @@ async def list_sessions(
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
title=session.title,
|
||||
is_processing=session.session_id in processing_set,
|
||||
)
|
||||
for session in sessions
|
||||
],
|
||||
@@ -291,12 +265,12 @@ async def delete_session(
|
||||
)
|
||||
|
||||
# Best-effort cleanup of the E2B sandbox (if any).
|
||||
# sandbox_id is in Redis; kill_sandbox() fetches it from there.
|
||||
e2b_cfg = ChatConfig()
|
||||
if e2b_cfg.e2b_active:
|
||||
assert e2b_cfg.e2b_api_key # guaranteed by e2b_active check
|
||||
config = ChatConfig()
|
||||
if config.use_e2b_sandbox and config.e2b_api_key:
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
|
||||
try:
|
||||
await kill_sandbox(session_id, e2b_cfg.e2b_api_key)
|
||||
await kill_sandbox(session_id, config.e2b_api_key)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[E2B] Failed to kill sandbox for session %s", session_id[:12]
|
||||
|
||||
@@ -638,7 +638,7 @@ async def test_process_review_action_auto_approve_creates_auto_approval_records(
|
||||
|
||||
# Mock get_node_executions to return node_id mapping
|
||||
mock_get_node_executions = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_node_executions"
|
||||
"backend.data.execution.get_node_executions"
|
||||
)
|
||||
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
|
||||
mock_node_exec.node_exec_id = "test_node_123"
|
||||
@@ -936,7 +936,7 @@ async def test_process_review_action_auto_approve_only_applies_to_approved_revie
|
||||
|
||||
# Mock get_node_executions to return node_id mapping
|
||||
mock_get_node_executions = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_node_executions"
|
||||
"backend.data.execution.get_node_executions"
|
||||
)
|
||||
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
|
||||
mock_node_exec.node_exec_id = "node_exec_approved"
|
||||
@@ -1148,7 +1148,7 @@ async def test_process_review_action_per_review_auto_approve_granularity(
|
||||
|
||||
# Mock get_node_executions to return batch node data
|
||||
mock_get_node_executions = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_node_executions"
|
||||
"backend.data.execution.get_node_executions"
|
||||
)
|
||||
# Create mock node executions for each review
|
||||
mock_node_execs = []
|
||||
|
||||
@@ -6,15 +6,10 @@ import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, HTTPException, Query, Security, status
|
||||
from prisma.enums import ReviewStatus
|
||||
|
||||
from backend.copilot.constants import (
|
||||
is_copilot_synthetic_id,
|
||||
parse_node_id_from_exec_id,
|
||||
)
|
||||
from backend.data.execution import (
|
||||
ExecutionContext,
|
||||
ExecutionStatus,
|
||||
get_graph_execution_meta,
|
||||
get_node_executions,
|
||||
)
|
||||
from backend.data.graph import get_graph_settings
|
||||
from backend.data.human_review import (
|
||||
@@ -41,38 +36,6 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_node_ids(
|
||||
node_exec_ids: list[str],
|
||||
graph_exec_id: str,
|
||||
is_copilot: bool,
|
||||
) -> dict[str, str]:
|
||||
"""Resolve node_exec_id -> node_id for auto-approval records.
|
||||
|
||||
CoPilot synthetic IDs encode node_id in the format "{node_id}:{random}".
|
||||
Graph executions look up node_id from NodeExecution records.
|
||||
"""
|
||||
if not node_exec_ids:
|
||||
return {}
|
||||
|
||||
if is_copilot:
|
||||
return {neid: parse_node_id_from_exec_id(neid) for neid in node_exec_ids}
|
||||
|
||||
node_execs = await get_node_executions(
|
||||
graph_exec_id=graph_exec_id, include_exec_data=False
|
||||
)
|
||||
node_exec_map = {ne.node_exec_id: ne.node_id for ne in node_execs}
|
||||
|
||||
result = {}
|
||||
for neid in node_exec_ids:
|
||||
if neid in node_exec_map:
|
||||
result[neid] = node_exec_map[neid]
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to resolve node_id for {neid}: Node execution not found."
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@router.get(
|
||||
"/pending",
|
||||
summary="Get Pending Reviews",
|
||||
@@ -147,16 +110,14 @@ async def list_pending_reviews_for_execution(
|
||||
"""
|
||||
|
||||
# Verify user owns the graph execution before returning reviews
|
||||
# (CoPilot synthetic IDs don't have graph execution records)
|
||||
if not is_copilot_synthetic_id(graph_exec_id):
|
||||
graph_exec = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
graph_exec = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not graph_exec:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
if not graph_exec:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
|
||||
return await get_pending_reviews_for_execution(graph_exec_id, user_id)
|
||||
|
||||
@@ -199,26 +160,30 @@ async def process_review_action(
|
||||
)
|
||||
|
||||
graph_exec_id = next(iter(graph_exec_ids))
|
||||
is_copilot = is_copilot_synthetic_id(graph_exec_id)
|
||||
|
||||
# Validate execution status for graph executions (skip for CoPilot synthetic IDs)
|
||||
if not is_copilot:
|
||||
graph_exec_meta = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
# Validate execution status before processing reviews
|
||||
graph_exec_meta = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
|
||||
if not graph_exec_meta:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
|
||||
# Only allow processing reviews if execution is paused for review
|
||||
# or incomplete (partial execution with some reviews already processed)
|
||||
if graph_exec_meta.status not in (
|
||||
ExecutionStatus.REVIEW,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
|
||||
f"Reviews can only be processed when execution is paused (REVIEW status). "
|
||||
f"Current status: {graph_exec_meta.status}",
|
||||
)
|
||||
if not graph_exec_meta:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
if graph_exec_meta.status not in (
|
||||
ExecutionStatus.REVIEW,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}",
|
||||
)
|
||||
|
||||
# Build review decisions map and track which reviews requested auto-approval
|
||||
# Auto-approved reviews use original data (no modifications allowed)
|
||||
@@ -271,7 +236,7 @@ async def process_review_action(
|
||||
)
|
||||
return (node_id, False)
|
||||
|
||||
# Collect node_exec_ids that need auto-approval and resolve their node_ids
|
||||
# Collect node_exec_ids that need auto-approval
|
||||
node_exec_ids_needing_auto_approval = [
|
||||
node_exec_id
|
||||
for node_exec_id, review_result in updated_reviews.items()
|
||||
@@ -279,16 +244,29 @@ async def process_review_action(
|
||||
and auto_approve_requests.get(node_exec_id, False)
|
||||
]
|
||||
|
||||
node_id_map = await _resolve_node_ids(
|
||||
node_exec_ids_needing_auto_approval, graph_exec_id, is_copilot
|
||||
)
|
||||
|
||||
# Deduplicate by node_id — one auto-approval per node
|
||||
# Batch-fetch node executions to get node_ids
|
||||
nodes_needing_auto_approval: dict[str, Any] = {}
|
||||
for node_exec_id in node_exec_ids_needing_auto_approval:
|
||||
node_id = node_id_map.get(node_exec_id)
|
||||
if node_id and node_id not in nodes_needing_auto_approval:
|
||||
nodes_needing_auto_approval[node_id] = updated_reviews[node_exec_id]
|
||||
if node_exec_ids_needing_auto_approval:
|
||||
from backend.data.execution import get_node_executions
|
||||
|
||||
node_execs = await get_node_executions(
|
||||
graph_exec_id=graph_exec_id, include_exec_data=False
|
||||
)
|
||||
node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
|
||||
|
||||
for node_exec_id in node_exec_ids_needing_auto_approval:
|
||||
node_exec = node_exec_map.get(node_exec_id)
|
||||
if node_exec:
|
||||
review_result = updated_reviews[node_exec_id]
|
||||
# Use the first approved review for this node (deduplicate by node_id)
|
||||
if node_exec.node_id not in nodes_needing_auto_approval:
|
||||
nodes_needing_auto_approval[node_exec.node_id] = review_result
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to create auto-approval record for {node_exec_id}: "
|
||||
f"Node execution not found. This may indicate a race condition "
|
||||
f"or data inconsistency."
|
||||
)
|
||||
|
||||
# Execute all auto-approval creations in parallel (deduplicated by node_id)
|
||||
auto_approval_results = await asyncio.gather(
|
||||
@@ -303,11 +281,13 @@ async def process_review_action(
|
||||
auto_approval_failed_count = 0
|
||||
for result in auto_approval_results:
|
||||
if isinstance(result, Exception):
|
||||
# Unexpected exception during auto-approval creation
|
||||
auto_approval_failed_count += 1
|
||||
logger.error(
|
||||
f"Unexpected exception during auto-approval creation: {result}"
|
||||
)
|
||||
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
|
||||
# Auto-approval creation failed (returned False)
|
||||
auto_approval_failed_count += 1
|
||||
|
||||
# Count results
|
||||
@@ -322,20 +302,22 @@ async def process_review_action(
|
||||
if review.status == ReviewStatus.REJECTED
|
||||
)
|
||||
|
||||
# Resume graph execution only for real graph executions (not CoPilot)
|
||||
# CoPilot sessions are resumed by the LLM retrying run_block with review_id
|
||||
if not is_copilot and updated_reviews:
|
||||
# Resume execution only if ALL pending reviews for this execution have been processed
|
||||
if updated_reviews:
|
||||
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
|
||||
|
||||
if not still_has_pending:
|
||||
# Get the graph_id from any processed review
|
||||
first_review = next(iter(updated_reviews.values()))
|
||||
|
||||
try:
|
||||
# Fetch user and settings to build complete execution context
|
||||
user = await get_user_by_id(user_id)
|
||||
settings = await get_graph_settings(
|
||||
user_id=user_id, graph_id=first_review.graph_id
|
||||
)
|
||||
|
||||
# Preserve user's timezone preference when resuming execution
|
||||
user_timezone = (
|
||||
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
||||
)
|
||||
|
||||
@@ -165,6 +165,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
owner_user_id: str
|
||||
|
||||
image_url: str | None
|
||||
|
||||
@@ -205,9 +206,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
default_factory=list,
|
||||
description="List of recent executions with status, score, and summary",
|
||||
)
|
||||
can_access_graph: bool = pydantic.Field(
|
||||
description="Indicates whether the same user owns the corresponding graph"
|
||||
)
|
||||
can_access_graph: bool
|
||||
is_latest_version: bool
|
||||
is_favorite: bool
|
||||
folder_id: str | None = None
|
||||
@@ -325,6 +324,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id=agent.id,
|
||||
graph_id=agent.agentGraphId,
|
||||
graph_version=agent.agentGraphVersion,
|
||||
owner_user_id=agent.userId,
|
||||
image_url=agent.imageUrl,
|
||||
creator_name=creator_name,
|
||||
creator_image_url=creator_image_url,
|
||||
|
||||
@@ -42,6 +42,7 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
@@ -66,6 +67,7 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-2",
|
||||
graph_id="test-agent-2",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 2",
|
||||
description="Test Description 2",
|
||||
image_url=None,
|
||||
@@ -129,6 +131,7 @@ async def test_get_favorite_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Favorite Agent 1",
|
||||
description="Test Favorite Description 1",
|
||||
image_url=None,
|
||||
@@ -181,6 +184,7 @@ def test_add_agent_to_library_success(
|
||||
id="test-library-agent-id",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
|
||||
@@ -24,7 +24,7 @@ from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import HTTPClientError, Requests, validate_url_host
|
||||
from backend.util.request import HTTPClientError, Requests, validate_url
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -80,7 +80,7 @@ async def discover_tools(
|
||||
"""
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url_host(request.server_url)
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
@@ -167,7 +167,7 @@ async def mcp_oauth_login(
|
||||
"""
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url_host(request.server_url)
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
@@ -187,7 +187,7 @@ async def mcp_oauth_login(
|
||||
|
||||
# Validate the auth server URL from metadata to prevent SSRF.
|
||||
try:
|
||||
await validate_url_host(auth_server_url)
|
||||
await validate_url(auth_server_url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=400,
|
||||
@@ -234,7 +234,7 @@ async def mcp_oauth_login(
|
||||
if registration_endpoint:
|
||||
# Validate the registration endpoint to prevent SSRF via metadata.
|
||||
try:
|
||||
await validate_url_host(registration_endpoint)
|
||||
await validate_url(registration_endpoint, trusted_origins=[])
|
||||
except ValueError:
|
||||
pass # Skip registration, fall back to default client_id
|
||||
else:
|
||||
@@ -429,7 +429,7 @@ async def mcp_store_token(
|
||||
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url_host(request.server_url)
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
|
||||
@@ -32,9 +32,9 @@ async def client():
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _bypass_ssrf_validation():
|
||||
"""Bypass validate_url_host in all route tests (test URLs don't resolve)."""
|
||||
"""Bypass validate_url in all route tests (test URLs don't resolve)."""
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url_host",
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
yield
|
||||
@@ -521,12 +521,12 @@ class TestStoreToken:
|
||||
|
||||
|
||||
class TestSSRFValidation:
|
||||
"""Verify that validate_url_host is enforced on all endpoints."""
|
||||
"""Verify that validate_url is enforced on all endpoints."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url_host",
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked loopback"),
|
||||
):
|
||||
@@ -541,7 +541,7 @@ class TestSSRFValidation:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_login_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url_host",
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked private IP"),
|
||||
):
|
||||
@@ -556,7 +556,7 @@ class TestSSRFValidation:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_token_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url_host",
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked loopback"),
|
||||
):
|
||||
|
||||
@@ -29,6 +29,7 @@ from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
from typing_extensions import Optional, TypedDict
|
||||
|
||||
from backend.api.model import (
|
||||
BusinessUnderstandingPromptsResponse,
|
||||
CreateAPIKeyRequest,
|
||||
CreateAPIKeyResponse,
|
||||
CreateGraph,
|
||||
@@ -54,6 +55,7 @@ from backend.data.credit import (
|
||||
get_user_credit_model,
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.db_accessors import understanding_db
|
||||
from backend.data.graph import GraphSettings
|
||||
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
@@ -158,6 +160,22 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
return user.model_dump()
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/auth/user/understanding/prompts",
|
||||
summary="Get business understanding prompts",
|
||||
tags=["auth"],
|
||||
dependencies=[Security(requires_user)],
|
||||
response_model=BusinessUnderstandingPromptsResponse,
|
||||
)
|
||||
async def get_business_understanding_prompts_route(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> BusinessUnderstandingPromptsResponse:
|
||||
understanding = await understanding_db().get_business_understanding(user_id)
|
||||
return BusinessUnderstandingPromptsResponse(
|
||||
prompts=understanding.prompts if understanding else []
|
||||
)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/auth/user/email",
|
||||
summary="Update user email",
|
||||
|
||||
@@ -89,6 +89,48 @@ def test_update_user_email_route(
|
||||
)
|
||||
|
||||
|
||||
def test_get_business_understanding_prompts_route(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_understanding_db = Mock()
|
||||
mock_understanding_db.get_business_understanding = AsyncMock(
|
||||
return_value=Mock(prompts=["Prompt one", "Prompt two", "Prompt three"])
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.understanding_db",
|
||||
return_value=mock_understanding_db,
|
||||
)
|
||||
|
||||
response = client.get("/auth/user/understanding/prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": ["Prompt one", "Prompt two", "Prompt three"]}
|
||||
mock_understanding_db.get_business_understanding.assert_awaited_once_with(
|
||||
test_user_id
|
||||
)
|
||||
|
||||
|
||||
def test_get_business_understanding_prompts_route_returns_empty_list(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_understanding_db = Mock()
|
||||
mock_understanding_db.get_business_understanding = AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.understanding_db",
|
||||
return_value=mock_understanding_db,
|
||||
)
|
||||
|
||||
response = client.get("/auth/user/understanding/prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": []}
|
||||
mock_understanding_db.get_business_understanding.assert_awaited_once_with(
|
||||
test_user_id
|
||||
)
|
||||
|
||||
|
||||
# Blocks endpoints tests
|
||||
def test_get_graph_blocks(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
|
||||
@@ -85,6 +85,10 @@ class UpdateTimezoneRequest(pydantic.BaseModel):
|
||||
timezone: TimeZoneName
|
||||
|
||||
|
||||
class BusinessUnderstandingPromptsResponse(pydantic.BaseModel):
|
||||
prompts: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class NotificationPayload(pydantic.BaseModel):
|
||||
type: str
|
||||
event: str
|
||||
@@ -94,8 +98,3 @@ class NotificationPayload(pydantic.BaseModel):
|
||||
|
||||
class OnboardingNotificationPayload(NotificationPayload):
|
||||
step: OnboardingStep | None
|
||||
|
||||
|
||||
class CopilotCompletionPayload(NotificationPayload):
|
||||
session_id: str
|
||||
status: Literal["completed", "failed"]
|
||||
|
||||
@@ -418,8 +418,6 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
|
||||
|
||||
|
||||
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
_optimized_description: ClassVar[str | None] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str = "",
|
||||
@@ -472,8 +470,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.block_type = block_type
|
||||
self.webhook_config = webhook_config
|
||||
self.is_sensitive_action = is_sensitive_action
|
||||
# Read from ClassVar set by initialize_blocks()
|
||||
self.optimized_description: str | None = type(self)._optimized_description
|
||||
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
|
||||
|
||||
if self.webhook_config:
|
||||
@@ -624,7 +620,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: "ExecutionContext",
|
||||
is_graph_execution: bool = True,
|
||||
**kwargs,
|
||||
) -> tuple[bool, BlockInput]:
|
||||
"""
|
||||
@@ -653,7 +648,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
graph_version=graph_version,
|
||||
block_name=self.name,
|
||||
editable=True,
|
||||
is_graph_execution=is_graph_execution,
|
||||
)
|
||||
|
||||
if decision is None:
|
||||
|
||||
@@ -126,7 +126,7 @@ class PrintToConsoleBlock(Block):
|
||||
output_schema=PrintToConsoleBlock.Output,
|
||||
test_input={"text": "Hello, World!"},
|
||||
is_sensitive_action=True,
|
||||
disabled=True,
|
||||
disabled=True, # Disabled per Nick Tindle's request (OPEN-3000)
|
||||
test_output=[
|
||||
("output", "Hello, World!"),
|
||||
("status", "printed"),
|
||||
|
||||
@@ -142,7 +142,7 @@ class BaseE2BExecutorMixin:
|
||||
start_timestamp = ts_result.stdout.strip() if ts_result.stdout else None
|
||||
|
||||
# Execute the code
|
||||
execution = await sandbox.run_code( # type: ignore[attr-defined]
|
||||
execution = await sandbox.run_code(
|
||||
code,
|
||||
language=language.value,
|
||||
on_error=lambda e: sandbox.kill(), # Kill the sandbox on error
|
||||
|
||||
@@ -96,7 +96,6 @@ class SendEmailBlock(Block):
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Email sent successfully")],
|
||||
test_mock={"send_email": lambda *args, **kwargs: "Email sent successfully"},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
def github_repo_path(repo_url: str) -> str:
|
||||
"""Extract 'owner/repo' from a GitHub repository URL."""
|
||||
return repo_url.replace("https://github.com/", "")
|
||||
@@ -1,374 +0,0 @@
|
||||
import asyncio
|
||||
from enum import StrEnum
|
||||
from urllib.parse import quote
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
from ._utils import github_repo_path
|
||||
|
||||
|
||||
class GithubListCommitsBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to list commits from",
|
||||
default="main",
|
||||
)
|
||||
per_page: int = SchemaField(
|
||||
description="Number of commits to return (max 100)",
|
||||
default=30,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
page: int = SchemaField(
|
||||
description="Page number for pagination",
|
||||
default=1,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class CommitItem(TypedDict):
|
||||
sha: str
|
||||
message: str
|
||||
author: str
|
||||
date: str
|
||||
url: str
|
||||
|
||||
commit: CommitItem = SchemaField(
|
||||
title="Commit", description="A commit with its details"
|
||||
)
|
||||
commits: list[CommitItem] = SchemaField(
|
||||
description="List of commits with their details"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing commits failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8b13f579-d8b6-4dc2-a140-f770428805de",
|
||||
description="This block lists commits on a branch in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListCommitsBlock.Input,
|
||||
output_schema=GithubListCommitsBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "main",
|
||||
"per_page": 30,
|
||||
"page": 1,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"commits",
|
||||
[
|
||||
{
|
||||
"sha": "abc123",
|
||||
"message": "Initial commit",
|
||||
"author": "octocat",
|
||||
"date": "2024-01-01T00:00:00Z",
|
||||
"url": "https://github.com/owner/repo/commit/abc123",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"commit",
|
||||
{
|
||||
"sha": "abc123",
|
||||
"message": "Initial commit",
|
||||
"author": "octocat",
|
||||
"date": "2024-01-01T00:00:00Z",
|
||||
"url": "https://github.com/owner/repo/commit/abc123",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_commits": lambda *args, **kwargs: [
|
||||
{
|
||||
"sha": "abc123",
|
||||
"message": "Initial commit",
|
||||
"author": "octocat",
|
||||
"date": "2024-01-01T00:00:00Z",
|
||||
"url": "https://github.com/owner/repo/commit/abc123",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_commits(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
branch: str,
|
||||
per_page: int,
|
||||
page: int,
|
||||
) -> list[Output.CommitItem]:
|
||||
api = get_api(credentials)
|
||||
commits_url = repo_url + "/commits"
|
||||
params = {"sha": branch, "per_page": str(per_page), "page": str(page)}
|
||||
response = await api.get(commits_url, params=params)
|
||||
data = response.json()
|
||||
repo_path = github_repo_path(repo_url)
|
||||
return [
|
||||
GithubListCommitsBlock.Output.CommitItem(
|
||||
sha=c["sha"],
|
||||
message=c["commit"]["message"],
|
||||
author=(c["commit"].get("author") or {}).get("name", "Unknown"),
|
||||
date=(c["commit"].get("author") or {}).get("date", ""),
|
||||
url=f"https://github.com/{repo_path}/commit/{c['sha']}",
|
||||
)
|
||||
for c in data
|
||||
]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
commits = await self.list_commits(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.per_page,
|
||||
input_data.page,
|
||||
)
|
||||
yield "commits", commits
|
||||
for commit in commits:
|
||||
yield "commit", commit
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class FileOperation(StrEnum):
|
||||
"""File operations for GithubMultiFileCommitBlock.
|
||||
|
||||
UPSERT creates or overwrites a file (the Git Trees API does not distinguish
|
||||
between creation and update — the blob is placed at the given path regardless
|
||||
of whether a file already exists there).
|
||||
|
||||
DELETE removes a file from the tree.
|
||||
"""
|
||||
|
||||
UPSERT = "upsert"
|
||||
DELETE = "delete"
|
||||
|
||||
|
||||
class FileOperationInput(TypedDict):
|
||||
path: str
|
||||
content: str
|
||||
operation: FileOperation
|
||||
|
||||
|
||||
class GithubMultiFileCommitBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch to commit to",
|
||||
placeholder="feature-branch",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Commit message",
|
||||
placeholder="Add new feature",
|
||||
)
|
||||
files: list[FileOperationInput] = SchemaField(
|
||||
description=(
|
||||
"List of file operations. Each item has: "
|
||||
"'path' (file path), 'content' (file content, ignored for delete), "
|
||||
"'operation' (upsert/delete)"
|
||||
),
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
sha: str = SchemaField(description="SHA of the new commit")
|
||||
url: str = SchemaField(description="URL of the new commit")
|
||||
error: str = SchemaField(description="Error message if the commit failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="389eee51-a95e-4230-9bed-92167a327802",
|
||||
description=(
|
||||
"This block creates a single commit with multiple file "
|
||||
"upsert/delete operations using the Git Trees API."
|
||||
),
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMultiFileCommitBlock.Input,
|
||||
output_schema=GithubMultiFileCommitBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "feature",
|
||||
"commit_message": "Add files",
|
||||
"files": [
|
||||
{
|
||||
"path": "src/new.py",
|
||||
"content": "print('hello')",
|
||||
"operation": "upsert",
|
||||
},
|
||||
{
|
||||
"path": "src/old.py",
|
||||
"content": "",
|
||||
"operation": "delete",
|
||||
},
|
||||
],
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("sha", "newcommitsha"),
|
||||
("url", "https://github.com/owner/repo/commit/newcommitsha"),
|
||||
],
|
||||
test_mock={
|
||||
"multi_file_commit": lambda *args, **kwargs: (
|
||||
"newcommitsha",
|
||||
"https://github.com/owner/repo/commit/newcommitsha",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def multi_file_commit(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
files: list[FileOperationInput],
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
safe_branch = quote(branch, safe="")
|
||||
|
||||
# 1. Get the latest commit SHA for the branch
|
||||
ref_url = repo_url + f"/git/refs/heads/{safe_branch}"
|
||||
response = await api.get(ref_url)
|
||||
ref_data = response.json()
|
||||
latest_commit_sha = ref_data["object"]["sha"]
|
||||
|
||||
# 2. Get the tree SHA of the latest commit
|
||||
commit_url = repo_url + f"/git/commits/{latest_commit_sha}"
|
||||
response = await api.get(commit_url)
|
||||
commit_data = response.json()
|
||||
base_tree_sha = commit_data["tree"]["sha"]
|
||||
|
||||
# 3. Build tree entries for each file operation (blobs created concurrently)
|
||||
async def _create_blob(content: str) -> str:
|
||||
blob_url = repo_url + "/git/blobs"
|
||||
blob_response = await api.post(
|
||||
blob_url,
|
||||
json={"content": content, "encoding": "utf-8"},
|
||||
)
|
||||
return blob_response.json()["sha"]
|
||||
|
||||
tree_entries: list[dict] = []
|
||||
upsert_files = []
|
||||
for file_op in files:
|
||||
path = file_op["path"]
|
||||
operation = FileOperation(file_op.get("operation", "upsert"))
|
||||
|
||||
if operation == FileOperation.DELETE:
|
||||
tree_entries.append(
|
||||
{
|
||||
"path": path,
|
||||
"mode": "100644",
|
||||
"type": "blob",
|
||||
"sha": None, # null SHA = delete
|
||||
}
|
||||
)
|
||||
else:
|
||||
upsert_files.append((path, file_op.get("content", "")))
|
||||
|
||||
# Create all blobs concurrently
|
||||
if upsert_files:
|
||||
blob_shas = await asyncio.gather(
|
||||
*[_create_blob(content) for _, content in upsert_files]
|
||||
)
|
||||
for (path, _), blob_sha in zip(upsert_files, blob_shas):
|
||||
tree_entries.append(
|
||||
{
|
||||
"path": path,
|
||||
"mode": "100644",
|
||||
"type": "blob",
|
||||
"sha": blob_sha,
|
||||
}
|
||||
)
|
||||
|
||||
# 4. Create a new tree
|
||||
tree_url = repo_url + "/git/trees"
|
||||
tree_response = await api.post(
|
||||
tree_url,
|
||||
json={"base_tree": base_tree_sha, "tree": tree_entries},
|
||||
)
|
||||
new_tree_sha = tree_response.json()["sha"]
|
||||
|
||||
# 5. Create a new commit
|
||||
new_commit_url = repo_url + "/git/commits"
|
||||
commit_response = await api.post(
|
||||
new_commit_url,
|
||||
json={
|
||||
"message": commit_message,
|
||||
"tree": new_tree_sha,
|
||||
"parents": [latest_commit_sha],
|
||||
},
|
||||
)
|
||||
new_commit_sha = commit_response.json()["sha"]
|
||||
|
||||
# 6. Update the branch reference
|
||||
try:
|
||||
await api.patch(
|
||||
ref_url,
|
||||
json={"sha": new_commit_sha},
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Commit {new_commit_sha} was created but failed to update "
|
||||
f"ref heads/{branch}: {e}. "
|
||||
f"You can recover by manually updating the branch to {new_commit_sha}."
|
||||
) from e
|
||||
|
||||
repo_path = github_repo_path(repo_url)
|
||||
commit_web_url = f"https://github.com/{repo_path}/commit/{new_commit_sha}"
|
||||
return new_commit_sha, commit_web_url
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
sha, url = await self.multi_file_commit(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
input_data.files,
|
||||
)
|
||||
yield "sha", sha
|
||||
yield "url", url
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,5 +1,4 @@
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
@@ -21,8 +20,6 @@ from ._auth import (
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
MergeMethod = Literal["merge", "squash", "rebase"]
|
||||
|
||||
|
||||
class GithubListPullRequestsBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
@@ -561,109 +558,12 @@ class GithubListPRReviewersBlock(Block):
|
||||
yield "reviewer", reviewer
|
||||
|
||||
|
||||
class GithubMergePullRequestBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
pr_url: str = SchemaField(
|
||||
description="URL of the GitHub pull request",
|
||||
placeholder="https://github.com/owner/repo/pull/1",
|
||||
)
|
||||
merge_method: MergeMethod = SchemaField(
|
||||
description="Merge method to use: merge, squash, or rebase",
|
||||
default="merge",
|
||||
)
|
||||
commit_title: str = SchemaField(
|
||||
description="Title for the merge commit (optional, used for merge and squash)",
|
||||
default="",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the merge commit (optional, used for merge and squash)",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
sha: str = SchemaField(description="SHA of the merge commit")
|
||||
merged: bool = SchemaField(description="Whether the PR was merged")
|
||||
message: str = SchemaField(description="Merge status message")
|
||||
error: str = SchemaField(description="Error message if the merge failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="77456c22-33d8-4fd4-9eef-50b46a35bb48",
|
||||
description="This block merges a pull request using merge, squash, or rebase.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMergePullRequestBlock.Input,
|
||||
output_schema=GithubMergePullRequestBlock.Output,
|
||||
test_input={
|
||||
"pr_url": "https://github.com/owner/repo/pull/1",
|
||||
"merge_method": "squash",
|
||||
"commit_title": "",
|
||||
"commit_message": "",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("sha", "abc123"),
|
||||
("merged", True),
|
||||
("message", "Pull Request successfully merged"),
|
||||
],
|
||||
test_mock={
|
||||
"merge_pr": lambda *args, **kwargs: (
|
||||
"abc123",
|
||||
True,
|
||||
"Pull Request successfully merged",
|
||||
)
|
||||
},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def merge_pr(
|
||||
credentials: GithubCredentials,
|
||||
pr_url: str,
|
||||
merge_method: MergeMethod,
|
||||
commit_title: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, bool, str]:
|
||||
api = get_api(credentials)
|
||||
merge_url = prepare_pr_api_url(pr_url=pr_url, path="merge")
|
||||
data: dict[str, str] = {"merge_method": merge_method}
|
||||
if commit_title:
|
||||
data["commit_title"] = commit_title
|
||||
if commit_message:
|
||||
data["commit_message"] = commit_message
|
||||
response = await api.put(merge_url, json=data)
|
||||
result = response.json()
|
||||
return result["sha"], result["merged"], result["message"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
sha, merged, message = await self.merge_pr(
|
||||
credentials,
|
||||
input_data.pr_url,
|
||||
input_data.merge_method,
|
||||
input_data.commit_title,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "sha", sha
|
||||
yield "merged", merged
|
||||
yield "message", message
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
def prepare_pr_api_url(pr_url: str, path: str) -> str:
|
||||
# Pattern to capture the base repository URL and the pull request number
|
||||
pattern = r"^(?:(https?)://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
|
||||
pattern = r"^(?:https?://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
|
||||
match = re.match(pattern, pr_url)
|
||||
if not match:
|
||||
return pr_url
|
||||
|
||||
scheme, base_url, pr_number = match.groups()
|
||||
return f"{scheme or 'https'}://{base_url}/pulls/{pr_number}/{path}"
|
||||
base_url, pr_number = match.groups()
|
||||
return f"{base_url}/pulls/{pr_number}/{path}"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import base64
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
@@ -17,7 +19,6 @@ from ._auth import (
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
from ._utils import github_repo_path
|
||||
|
||||
|
||||
class GithubListTagsBlock(Block):
|
||||
@@ -88,7 +89,7 @@ class GithubListTagsBlock(Block):
|
||||
tags_url = repo_url + "/tags"
|
||||
response = await api.get(tags_url)
|
||||
data = response.json()
|
||||
repo_path = github_repo_path(repo_url)
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
tags: list[GithubListTagsBlock.Output.TagItem] = [
|
||||
{
|
||||
"name": tag["name"],
|
||||
@@ -114,6 +115,101 @@ class GithubListTagsBlock(Block):
|
||||
yield "tag", tag
|
||||
|
||||
|
||||
class GithubListBranchesBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class BranchItem(TypedDict):
|
||||
name: str
|
||||
url: str
|
||||
|
||||
branch: BranchItem = SchemaField(
|
||||
title="Branch",
|
||||
description="Branches with their name and file tree browser URL",
|
||||
)
|
||||
branches: list[BranchItem] = SchemaField(
|
||||
description="List of branches with their name and file tree browser URL"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="74243e49-2bec-4916-8bf4-db43d44aead5",
|
||||
description="This block lists all branches for a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListBranchesBlock.Input,
|
||||
output_schema=GithubListBranchesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"branches",
|
||||
[
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"branch",
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_branches": lambda *args, **kwargs: [
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_branches(
|
||||
credentials: GithubCredentials, repo_url: str
|
||||
) -> list[Output.BranchItem]:
|
||||
api = get_api(credentials)
|
||||
branches_url = repo_url + "/branches"
|
||||
response = await api.get(branches_url)
|
||||
data = response.json()
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
|
||||
{
|
||||
"name": branch["name"],
|
||||
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
|
||||
}
|
||||
for branch in data
|
||||
]
|
||||
return branches
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
branches = await self.list_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield "branches", branches
|
||||
for branch in branches:
|
||||
yield "branch", branch
|
||||
|
||||
|
||||
class GithubListDiscussionsBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
@@ -187,7 +283,7 @@ class GithubListDiscussionsBlock(Block):
|
||||
) -> list[Output.DiscussionItem]:
|
||||
api = get_api(credentials)
|
||||
# GitHub GraphQL API endpoint is different; we'll use api.post with custom URL
|
||||
repo_path = github_repo_path(repo_url)
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
owner, repo = repo_path.split("/")
|
||||
query = """
|
||||
query($owner: String!, $repo: String!, $num: Int!) {
|
||||
@@ -320,6 +416,564 @@ class GithubListReleasesBlock(Block):
|
||||
yield "release", release
|
||||
|
||||
|
||||
class GithubReadFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file in the repository",
|
||||
placeholder="path/to/file",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch to read from",
|
||||
placeholder="branch_name",
|
||||
default="master",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
text_content: str = SchemaField(
|
||||
description="Content of the file (decoded as UTF-8 text)"
|
||||
)
|
||||
raw_content: str = SchemaField(
|
||||
description="Raw base64-encoded content of the file"
|
||||
)
|
||||
size: int = SchemaField(description="The size of the file (in bytes)")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
|
||||
description="This block reads the content of a specified file from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFileBlock.Input,
|
||||
output_schema=GithubReadFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "path/to/file",
|
||||
"branch": "master",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("raw_content", "RmlsZSBjb250ZW50"),
|
||||
("text_content", "File content"),
|
||||
("size", 13),
|
||||
],
|
||||
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_file(
|
||||
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
|
||||
) -> tuple[str, int]:
|
||||
api = get_api(credentials)
|
||||
content_url = repo_url + f"/contents/{file_path}?ref={branch}"
|
||||
response = await api.get(content_url)
|
||||
data = response.json()
|
||||
|
||||
if isinstance(data, list):
|
||||
# Multiple entries of different types exist at this path
|
||||
if not (file := next((f for f in data if f["type"] == "file"), None)):
|
||||
raise TypeError("Not a file")
|
||||
data = file
|
||||
|
||||
if data["type"] != "file":
|
||||
raise TypeError("Not a file")
|
||||
|
||||
return data["content"], data["size"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
content, size = await self.read_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "raw_content", content
|
||||
yield "text_content", base64.b64decode(content).decode("utf-8")
|
||||
yield "size", size
|
||||
|
||||
|
||||
class GithubReadFolderBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
folder_path: str = SchemaField(
|
||||
description="Path to the folder in the repository",
|
||||
placeholder="path/to/folder",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to read from (defaults to master)",
|
||||
placeholder="branch_name",
|
||||
default="master",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class DirEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
|
||||
class FileEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
size: int
|
||||
|
||||
file: FileEntry = SchemaField(description="Files in the folder")
|
||||
dir: DirEntry = SchemaField(description="Directories in the folder")
|
||||
error: str = SchemaField(
|
||||
description="Error message if reading the folder failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
|
||||
description="This block reads the content of a specified folder from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFolderBlock.Input,
|
||||
output_schema=GithubReadFolderBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"folder_path": "path/to/folder",
|
||||
"branch": "master",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"file",
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
},
|
||||
),
|
||||
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
|
||||
],
|
||||
test_mock={
|
||||
"read_folder": lambda *args, **kwargs: (
|
||||
[
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
}
|
||||
],
|
||||
[{"name": "dir2", "path": "path/to/folder/dir2"}],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_folder(
|
||||
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
|
||||
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{folder_path}?ref={branch}"
|
||||
response = await api.get(contents_url)
|
||||
data = response.json()
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise TypeError("Not a folder")
|
||||
|
||||
files: list[GithubReadFolderBlock.Output.FileEntry] = [
|
||||
GithubReadFolderBlock.Output.FileEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
size=entry["size"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "file"
|
||||
]
|
||||
|
||||
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
|
||||
GithubReadFolderBlock.Output.DirEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "dir"
|
||||
]
|
||||
|
||||
return files, dirs
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
files, dirs = await self.read_folder(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.folder_path.lstrip("/"),
|
||||
input_data.branch,
|
||||
)
|
||||
for file in files:
|
||||
yield "file", file
|
||||
for dir in dirs:
|
||||
yield "dir", dir
|
||||
|
||||
|
||||
class GithubMakeBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
new_branch: str = SchemaField(
|
||||
description="Name of the new branch",
|
||||
placeholder="new_branch_name",
|
||||
)
|
||||
source_branch: str = SchemaField(
|
||||
description="Name of the source branch",
|
||||
placeholder="source_branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch creation operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
|
||||
description="This block creates a new branch from a specified source branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMakeBranchBlock.Input,
|
||||
output_schema=GithubMakeBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"new_branch": "new_branch_name",
|
||||
"source_branch": "source_branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch created successfully")],
|
||||
test_mock={
|
||||
"create_branch": lambda *args, **kwargs: "Branch created successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_branch(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
new_branch: str,
|
||||
source_branch: str,
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{source_branch}"
|
||||
response = await api.get(ref_url)
|
||||
data = response.json()
|
||||
sha = data["object"]["sha"]
|
||||
|
||||
# Create the new branch
|
||||
new_ref_url = repo_url + "/git/refs"
|
||||
data = {
|
||||
"ref": f"refs/heads/{new_branch}",
|
||||
"sha": sha,
|
||||
}
|
||||
response = await api.post(new_ref_url, json=data)
|
||||
return "Branch created successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
status = await self.create_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.new_branch,
|
||||
input_data.source_branch,
|
||||
)
|
||||
yield "status", status
|
||||
|
||||
|
||||
class GithubDeleteBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Name of the branch to delete",
|
||||
placeholder="branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch deletion operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch deletion failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
|
||||
description="This block deletes a specified branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubDeleteBranchBlock.Input,
|
||||
output_schema=GithubDeleteBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch deleted successfully")],
|
||||
test_mock={
|
||||
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_branch(
|
||||
credentials: GithubCredentials, repo_url: str, branch: str
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{branch}"
|
||||
await api.delete(ref_url)
|
||||
return "Branch deleted successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
status = await self.delete_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "status", status
|
||||
|
||||
|
||||
class GithubCreateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path where the file should be created",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="Content to write to the file",
|
||||
placeholder="File content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch where the file should be created",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Create new file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the created file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the file creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
|
||||
description="This block creates a new file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCreateFileBlock.Input,
|
||||
output_schema=GithubCreateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Test content",
|
||||
"branch": "main",
|
||||
"commit_message": "Create test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "abc123"),
|
||||
],
|
||||
test_mock={
|
||||
"create_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"abc123",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{file_path}"
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.create_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubUpdateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file to update",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="New content for the file",
|
||||
placeholder="Updated content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch containing the file",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Update file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the updated file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
|
||||
description="This block updates an existing file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubUpdateFileBlock.Input,
|
||||
output_schema=GithubUpdateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Updated content",
|
||||
"branch": "main",
|
||||
"commit_message": "Update test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "def456"),
|
||||
],
|
||||
test_mock={
|
||||
"update_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"def456",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{file_path}"
|
||||
params = {"ref": branch}
|
||||
response = await api.get(contents_url, params=params)
|
||||
data = response.json()
|
||||
|
||||
# Convert new content to base64
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"sha": data["sha"],
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.update_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubCreateRepositoryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
@@ -449,7 +1103,7 @@ class GithubListStargazersBlock(Block):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="e96d01ec-b55e-4a99-8ce8-c8776dce850b", # Generated unique UUID
|
||||
id="a4b9c2d1-e5f6-4g7h-8i9j-0k1l2m3n4o5p", # Generated unique UUID
|
||||
description="This block lists all users who have starred a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListStargazersBlock.Input,
|
||||
@@ -518,230 +1172,3 @@ class GithubListStargazersBlock(Block):
|
||||
yield "stargazers", stargazers
|
||||
for stargazer in stargazers:
|
||||
yield "stargazer", stargazer
|
||||
|
||||
|
||||
class GithubGetRepositoryInfoBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
name: str = SchemaField(description="Repository name")
|
||||
full_name: str = SchemaField(description="Full repository name (owner/repo)")
|
||||
description: str = SchemaField(description="Repository description")
|
||||
default_branch: str = SchemaField(description="Default branch name (e.g. main)")
|
||||
private: bool = SchemaField(description="Whether the repository is private")
|
||||
html_url: str = SchemaField(description="Web URL of the repository")
|
||||
clone_url: str = SchemaField(description="Git clone URL")
|
||||
stars: int = SchemaField(description="Number of stars")
|
||||
forks: int = SchemaField(description="Number of forks")
|
||||
open_issues: int = SchemaField(description="Number of open issues")
|
||||
error: str = SchemaField(
|
||||
description="Error message if fetching repo info failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="59d4f241-968a-4040-95da-348ac5c5ce27",
|
||||
description="This block retrieves metadata about a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubGetRepositoryInfoBlock.Input,
|
||||
output_schema=GithubGetRepositoryInfoBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("name", "repo"),
|
||||
("full_name", "owner/repo"),
|
||||
("description", "A test repo"),
|
||||
("default_branch", "main"),
|
||||
("private", False),
|
||||
("html_url", "https://github.com/owner/repo"),
|
||||
("clone_url", "https://github.com/owner/repo.git"),
|
||||
("stars", 42),
|
||||
("forks", 5),
|
||||
("open_issues", 3),
|
||||
],
|
||||
test_mock={
|
||||
"get_repo_info": lambda *args, **kwargs: {
|
||||
"name": "repo",
|
||||
"full_name": "owner/repo",
|
||||
"description": "A test repo",
|
||||
"default_branch": "main",
|
||||
"private": False,
|
||||
"html_url": "https://github.com/owner/repo",
|
||||
"clone_url": "https://github.com/owner/repo.git",
|
||||
"stargazers_count": 42,
|
||||
"forks_count": 5,
|
||||
"open_issues_count": 3,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_repo_info(credentials: GithubCredentials, repo_url: str) -> dict:
|
||||
api = get_api(credentials)
|
||||
response = await api.get(repo_url)
|
||||
return response.json()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.get_repo_info(credentials, input_data.repo_url)
|
||||
yield "name", data["name"]
|
||||
yield "full_name", data["full_name"]
|
||||
yield "description", data.get("description", "") or ""
|
||||
yield "default_branch", data["default_branch"]
|
||||
yield "private", data["private"]
|
||||
yield "html_url", data["html_url"]
|
||||
yield "clone_url", data["clone_url"]
|
||||
yield "stars", data["stargazers_count"]
|
||||
yield "forks", data["forks_count"]
|
||||
yield "open_issues", data["open_issues_count"]
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubForkRepositoryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository to fork",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
organization: str = SchemaField(
|
||||
description="Organization to fork into (leave empty to fork to your account)",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the forked repository")
|
||||
clone_url: str = SchemaField(description="Git clone URL of the fork")
|
||||
full_name: str = SchemaField(description="Full name of the fork (owner/repo)")
|
||||
error: str = SchemaField(description="Error message if the fork failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a439f2f4-835f-4dae-ba7b-0205ffa70be6",
|
||||
description="This block forks a GitHub repository to your account or an organization.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubForkRepositoryBlock.Input,
|
||||
output_schema=GithubForkRepositoryBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"organization": "",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/myuser/repo"),
|
||||
("clone_url", "https://github.com/myuser/repo.git"),
|
||||
("full_name", "myuser/repo"),
|
||||
],
|
||||
test_mock={
|
||||
"fork_repo": lambda *args, **kwargs: (
|
||||
"https://github.com/myuser/repo",
|
||||
"https://github.com/myuser/repo.git",
|
||||
"myuser/repo",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def fork_repo(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
organization: str,
|
||||
) -> tuple[str, str, str]:
|
||||
api = get_api(credentials)
|
||||
forks_url = repo_url + "/forks"
|
||||
data: dict[str, str] = {}
|
||||
if organization:
|
||||
data["organization"] = organization
|
||||
response = await api.post(forks_url, json=data)
|
||||
result = response.json()
|
||||
return result["html_url"], result["clone_url"], result["full_name"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, clone_url, full_name = await self.fork_repo(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.organization,
|
||||
)
|
||||
yield "url", url
|
||||
yield "clone_url", clone_url
|
||||
yield "full_name", full_name
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubStarRepositoryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository to star",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the star operation")
|
||||
error: str = SchemaField(description="Error message if starring failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="bd700764-53e3-44dd-a969-d1854088458f",
|
||||
description="This block stars a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubStarRepositoryBlock.Input,
|
||||
output_schema=GithubStarRepositoryBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Repository starred successfully")],
|
||||
test_mock={
|
||||
"star_repo": lambda *args, **kwargs: "Repository starred successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def star_repo(credentials: GithubCredentials, repo_url: str) -> str:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
repo_path = github_repo_path(repo_url)
|
||||
owner, repo = repo_path.split("/")
|
||||
await api.put(
|
||||
f"https://api.github.com/user/starred/{owner}/{repo}",
|
||||
headers={"Content-Length": "0"},
|
||||
)
|
||||
return "Repository starred successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = await self.star_repo(credentials, input_data.repo_url)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
@@ -1,452 +0,0 @@
|
||||
from urllib.parse import quote
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
from ._utils import github_repo_path
|
||||
|
||||
|
||||
class GithubListBranchesBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
per_page: int = SchemaField(
|
||||
description="Number of branches to return per page (max 100)",
|
||||
default=30,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
page: int = SchemaField(
|
||||
description="Page number for pagination",
|
||||
default=1,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class BranchItem(TypedDict):
|
||||
name: str
|
||||
url: str
|
||||
|
||||
branch: BranchItem = SchemaField(
|
||||
title="Branch",
|
||||
description="Branches with their name and file tree browser URL",
|
||||
)
|
||||
branches: list[BranchItem] = SchemaField(
|
||||
description="List of branches with their name and file tree browser URL"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing branches failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="74243e49-2bec-4916-8bf4-db43d44aead5",
|
||||
description="This block lists all branches for a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListBranchesBlock.Input,
|
||||
output_schema=GithubListBranchesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"per_page": 30,
|
||||
"page": 1,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"branches",
|
||||
[
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"branch",
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_branches": lambda *args, **kwargs: [
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_branches(
|
||||
credentials: GithubCredentials, repo_url: str, per_page: int, page: int
|
||||
) -> list[Output.BranchItem]:
|
||||
api = get_api(credentials)
|
||||
branches_url = repo_url + "/branches"
|
||||
response = await api.get(
|
||||
branches_url, params={"per_page": str(per_page), "page": str(page)}
|
||||
)
|
||||
data = response.json()
|
||||
repo_path = github_repo_path(repo_url)
|
||||
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
|
||||
{
|
||||
"name": branch["name"],
|
||||
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
|
||||
}
|
||||
for branch in data
|
||||
]
|
||||
return branches
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
branches = await self.list_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.per_page,
|
||||
input_data.page,
|
||||
)
|
||||
yield "branches", branches
|
||||
for branch in branches:
|
||||
yield "branch", branch
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubMakeBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
new_branch: str = SchemaField(
|
||||
description="Name of the new branch",
|
||||
placeholder="new_branch_name",
|
||||
)
|
||||
source_branch: str = SchemaField(
|
||||
description="Name of the source branch",
|
||||
placeholder="source_branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch creation operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
|
||||
description="This block creates a new branch from a specified source branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMakeBranchBlock.Input,
|
||||
output_schema=GithubMakeBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"new_branch": "new_branch_name",
|
||||
"source_branch": "source_branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch created successfully")],
|
||||
test_mock={
|
||||
"create_branch": lambda *args, **kwargs: "Branch created successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_branch(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
new_branch: str,
|
||||
source_branch: str,
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{quote(source_branch, safe='')}"
|
||||
response = await api.get(ref_url)
|
||||
data = response.json()
|
||||
sha = data["object"]["sha"]
|
||||
|
||||
# Create the new branch
|
||||
new_ref_url = repo_url + "/git/refs"
|
||||
data = {
|
||||
"ref": f"refs/heads/{new_branch}",
|
||||
"sha": sha,
|
||||
}
|
||||
response = await api.post(new_ref_url, json=data)
|
||||
return "Branch created successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = await self.create_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.new_branch,
|
||||
input_data.source_branch,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubDeleteBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Name of the branch to delete",
|
||||
placeholder="branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch deletion operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch deletion failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
|
||||
description="This block deletes a specified branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubDeleteBranchBlock.Input,
|
||||
output_schema=GithubDeleteBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch deleted successfully")],
|
||||
test_mock={
|
||||
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
|
||||
},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_branch(
|
||||
credentials: GithubCredentials, repo_url: str, branch: str
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{quote(branch, safe='')}"
|
||||
await api.delete(ref_url)
|
||||
return "Branch deleted successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = await self.delete_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubCompareBranchesBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
base: str = SchemaField(
|
||||
description="Base branch or commit SHA",
|
||||
placeholder="main",
|
||||
)
|
||||
head: str = SchemaField(
|
||||
description="Head branch or commit SHA to compare against base",
|
||||
placeholder="feature-branch",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class FileChange(TypedDict):
|
||||
filename: str
|
||||
status: str
|
||||
additions: int
|
||||
deletions: int
|
||||
patch: str
|
||||
|
||||
status: str = SchemaField(
|
||||
description="Comparison status: ahead, behind, diverged, or identical"
|
||||
)
|
||||
ahead_by: int = SchemaField(
|
||||
description="Number of commits head is ahead of base"
|
||||
)
|
||||
behind_by: int = SchemaField(
|
||||
description="Number of commits head is behind base"
|
||||
)
|
||||
total_commits: int = SchemaField(
|
||||
description="Total number of commits in the comparison"
|
||||
)
|
||||
diff: str = SchemaField(description="Unified diff of all file changes")
|
||||
file: FileChange = SchemaField(
|
||||
title="Changed File", description="A changed file with its diff"
|
||||
)
|
||||
files: list[FileChange] = SchemaField(
|
||||
description="List of changed files with their diffs"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if comparison failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2e4faa8c-6086-4546-ba77-172d1d560186",
|
||||
description="This block compares two branches or commits in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCompareBranchesBlock.Input,
|
||||
output_schema=GithubCompareBranchesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"base": "main",
|
||||
"head": "feature",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("status", "ahead"),
|
||||
("ahead_by", 2),
|
||||
("behind_by", 0),
|
||||
("total_commits", 2),
|
||||
("diff", "+++ b/file.py\n+new line"),
|
||||
(
|
||||
"files",
|
||||
[
|
||||
{
|
||||
"filename": "file.py",
|
||||
"status": "modified",
|
||||
"additions": 1,
|
||||
"deletions": 0,
|
||||
"patch": "+new line",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"file",
|
||||
{
|
||||
"filename": "file.py",
|
||||
"status": "modified",
|
||||
"additions": 1,
|
||||
"deletions": 0,
|
||||
"patch": "+new line",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"compare_branches": lambda *args, **kwargs: {
|
||||
"status": "ahead",
|
||||
"ahead_by": 2,
|
||||
"behind_by": 0,
|
||||
"total_commits": 2,
|
||||
"files": [
|
||||
{
|
||||
"filename": "file.py",
|
||||
"status": "modified",
|
||||
"additions": 1,
|
||||
"deletions": 0,
|
||||
"patch": "+new line",
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def compare_branches(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
base: str,
|
||||
head: str,
|
||||
) -> dict:
|
||||
api = get_api(credentials)
|
||||
safe_base = quote(base, safe="")
|
||||
safe_head = quote(head, safe="")
|
||||
compare_url = repo_url + f"/compare/{safe_base}...{safe_head}"
|
||||
response = await api.get(compare_url)
|
||||
return response.json()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.compare_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.base,
|
||||
input_data.head,
|
||||
)
|
||||
yield "status", data["status"]
|
||||
yield "ahead_by", data["ahead_by"]
|
||||
yield "behind_by", data["behind_by"]
|
||||
yield "total_commits", data["total_commits"]
|
||||
|
||||
files: list[GithubCompareBranchesBlock.Output.FileChange] = [
|
||||
GithubCompareBranchesBlock.Output.FileChange(
|
||||
filename=f["filename"],
|
||||
status=f["status"],
|
||||
additions=f["additions"],
|
||||
deletions=f["deletions"],
|
||||
patch=f.get("patch", ""),
|
||||
)
|
||||
for f in data.get("files", [])
|
||||
]
|
||||
|
||||
# Build unified diff
|
||||
diff_parts = []
|
||||
for f in data.get("files", []):
|
||||
patch = f.get("patch", "")
|
||||
if patch:
|
||||
diff_parts.append(f"+++ b/{f['filename']}\n{patch}")
|
||||
yield "diff", "\n".join(diff_parts)
|
||||
|
||||
yield "files", files
|
||||
for file in files:
|
||||
yield "file", file
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,720 +0,0 @@
|
||||
import base64
|
||||
from urllib.parse import quote
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class GithubReadFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file in the repository",
|
||||
placeholder="path/to/file",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch to read from",
|
||||
placeholder="branch_name",
|
||||
default="main",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
text_content: str = SchemaField(
|
||||
description="Content of the file (decoded as UTF-8 text)"
|
||||
)
|
||||
raw_content: str = SchemaField(
|
||||
description="Raw base64-encoded content of the file"
|
||||
)
|
||||
size: int = SchemaField(description="The size of the file (in bytes)")
|
||||
error: str = SchemaField(description="Error message if reading the file failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
|
||||
description="This block reads the content of a specified file from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFileBlock.Input,
|
||||
output_schema=GithubReadFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "path/to/file",
|
||||
"branch": "main",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("raw_content", "RmlsZSBjb250ZW50"),
|
||||
("text_content", "File content"),
|
||||
("size", 13),
|
||||
],
|
||||
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_file(
|
||||
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
|
||||
) -> tuple[str, int]:
|
||||
api = get_api(credentials)
|
||||
content_url = (
|
||||
repo_url
|
||||
+ f"/contents/{quote(file_path, safe='')}?ref={quote(branch, safe='')}"
|
||||
)
|
||||
response = await api.get(content_url)
|
||||
data = response.json()
|
||||
|
||||
if isinstance(data, list):
|
||||
# Multiple entries of different types exist at this path
|
||||
if not (file := next((f for f in data if f["type"] == "file"), None)):
|
||||
raise TypeError("Not a file")
|
||||
data = file
|
||||
|
||||
if data["type"] != "file":
|
||||
raise TypeError("Not a file")
|
||||
|
||||
return data["content"], data["size"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
content, size = await self.read_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "raw_content", content
|
||||
yield "text_content", base64.b64decode(content).decode("utf-8")
|
||||
yield "size", size
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubReadFolderBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
folder_path: str = SchemaField(
|
||||
description="Path to the folder in the repository",
|
||||
placeholder="path/to/folder",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to read from (defaults to main)",
|
||||
placeholder="branch_name",
|
||||
default="main",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class DirEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
|
||||
class FileEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
size: int
|
||||
|
||||
file: FileEntry = SchemaField(description="Files in the folder")
|
||||
dir: DirEntry = SchemaField(description="Directories in the folder")
|
||||
error: str = SchemaField(
|
||||
description="Error message if reading the folder failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
|
||||
description="This block reads the content of a specified folder from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFolderBlock.Input,
|
||||
output_schema=GithubReadFolderBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"folder_path": "path/to/folder",
|
||||
"branch": "main",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"file",
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
},
|
||||
),
|
||||
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
|
||||
],
|
||||
test_mock={
|
||||
"read_folder": lambda *args, **kwargs: (
|
||||
[
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
}
|
||||
],
|
||||
[{"name": "dir2", "path": "path/to/folder/dir2"}],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_folder(
|
||||
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
|
||||
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
|
||||
api = get_api(credentials)
|
||||
contents_url = (
|
||||
repo_url
|
||||
+ f"/contents/{quote(folder_path, safe='/')}?ref={quote(branch, safe='')}"
|
||||
)
|
||||
response = await api.get(contents_url)
|
||||
data = response.json()
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise TypeError("Not a folder")
|
||||
|
||||
files: list[GithubReadFolderBlock.Output.FileEntry] = [
|
||||
GithubReadFolderBlock.Output.FileEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
size=entry["size"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "file"
|
||||
]
|
||||
|
||||
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
|
||||
GithubReadFolderBlock.Output.DirEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "dir"
|
||||
]
|
||||
|
||||
return files, dirs
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
files, dirs = await self.read_folder(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.folder_path.lstrip("/"),
|
||||
input_data.branch,
|
||||
)
|
||||
for file in files:
|
||||
yield "file", file
|
||||
for dir in dirs:
|
||||
yield "dir", dir
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubCreateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path where the file should be created",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="Content to write to the file",
|
||||
placeholder="File content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch where the file should be created",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Create new file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the created file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the file creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
|
||||
description="This block creates a new file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCreateFileBlock.Input,
|
||||
output_schema=GithubCreateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Test content",
|
||||
"branch": "main",
|
||||
"commit_message": "Create test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "abc123"),
|
||||
],
|
||||
test_mock={
|
||||
"create_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"abc123",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.create_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubUpdateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file to update",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="New content for the file",
|
||||
placeholder="Updated content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch containing the file",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Update file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the updated file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
|
||||
description="This block updates an existing file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubUpdateFileBlock.Input,
|
||||
output_schema=GithubUpdateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Updated content",
|
||||
"branch": "main",
|
||||
"commit_message": "Update test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "def456"),
|
||||
],
|
||||
test_mock={
|
||||
"update_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"def456",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
|
||||
params = {"ref": branch}
|
||||
response = await api.get(contents_url, params=params)
|
||||
data = response.json()
|
||||
|
||||
# Convert new content to base64
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"sha": data["sha"],
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.update_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubSearchCodeBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
query: str = SchemaField(
|
||||
description="Search query (GitHub code search syntax)",
|
||||
placeholder="className language:python",
|
||||
)
|
||||
repo: str = SchemaField(
|
||||
description="Restrict search to a repository (owner/repo format, optional)",
|
||||
default="",
|
||||
placeholder="owner/repo",
|
||||
)
|
||||
per_page: int = SchemaField(
|
||||
description="Number of results to return (max 100)",
|
||||
default=30,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class SearchResult(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
repository: str
|
||||
url: str
|
||||
score: float
|
||||
|
||||
result: SearchResult = SchemaField(
|
||||
title="Result", description="A code search result"
|
||||
)
|
||||
results: list[SearchResult] = SchemaField(
|
||||
description="List of code search results"
|
||||
)
|
||||
total_count: int = SchemaField(description="Total number of matching results")
|
||||
error: str = SchemaField(description="Error message if search failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="47f94891-a2b1-4f1c-b5f2-573c043f721e",
|
||||
description="This block searches for code in GitHub repositories.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubSearchCodeBlock.Input,
|
||||
output_schema=GithubSearchCodeBlock.Output,
|
||||
test_input={
|
||||
"query": "addClass",
|
||||
"repo": "owner/repo",
|
||||
"per_page": 30,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("total_count", 1),
|
||||
(
|
||||
"results",
|
||||
[
|
||||
{
|
||||
"name": "file.py",
|
||||
"path": "src/file.py",
|
||||
"repository": "owner/repo",
|
||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||
"score": 1.0,
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"result",
|
||||
{
|
||||
"name": "file.py",
|
||||
"path": "src/file.py",
|
||||
"repository": "owner/repo",
|
||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||
"score": 1.0,
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"search_code": lambda *args, **kwargs: (
|
||||
1,
|
||||
[
|
||||
{
|
||||
"name": "file.py",
|
||||
"path": "src/file.py",
|
||||
"repository": "owner/repo",
|
||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||
"score": 1.0,
|
||||
}
|
||||
],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def search_code(
|
||||
credentials: GithubCredentials,
|
||||
query: str,
|
||||
repo: str,
|
||||
per_page: int,
|
||||
) -> tuple[int, list[Output.SearchResult]]:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
full_query = f"{query} repo:{repo}" if repo else query
|
||||
params = {"q": full_query, "per_page": str(per_page)}
|
||||
response = await api.get("https://api.github.com/search/code", params=params)
|
||||
data = response.json()
|
||||
results: list[GithubSearchCodeBlock.Output.SearchResult] = [
|
||||
GithubSearchCodeBlock.Output.SearchResult(
|
||||
name=item["name"],
|
||||
path=item["path"],
|
||||
repository=item["repository"]["full_name"],
|
||||
url=item["html_url"],
|
||||
score=item["score"],
|
||||
)
|
||||
for item in data["items"]
|
||||
]
|
||||
return data["total_count"], results
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
total_count, results = await self.search_code(
|
||||
credentials,
|
||||
input_data.query,
|
||||
input_data.repo,
|
||||
input_data.per_page,
|
||||
)
|
||||
yield "total_count", total_count
|
||||
yield "results", results
|
||||
for result in results:
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubGetRepositoryTreeBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to get the tree from",
|
||||
default="main",
|
||||
)
|
||||
recursive: bool = SchemaField(
|
||||
description="Whether to recursively list the entire tree",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class TreeEntry(TypedDict):
|
||||
path: str
|
||||
type: str
|
||||
size: int
|
||||
sha: str
|
||||
|
||||
entry: TreeEntry = SchemaField(
|
||||
title="Tree Entry", description="A file or directory in the tree"
|
||||
)
|
||||
entries: list[TreeEntry] = SchemaField(
|
||||
description="List of all files and directories in the tree"
|
||||
)
|
||||
truncated: bool = SchemaField(
|
||||
description="Whether the tree was truncated due to size"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if getting tree failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="89c5c0ec-172e-4001-a32c-bdfe4d0c9e81",
|
||||
description="This block lists the entire file tree of a GitHub repository recursively.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubGetRepositoryTreeBlock.Input,
|
||||
output_schema=GithubGetRepositoryTreeBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "main",
|
||||
"recursive": True,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("truncated", False),
|
||||
(
|
||||
"entries",
|
||||
[
|
||||
{
|
||||
"path": "src/main.py",
|
||||
"type": "blob",
|
||||
"size": 1234,
|
||||
"sha": "abc123",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"entry",
|
||||
{
|
||||
"path": "src/main.py",
|
||||
"type": "blob",
|
||||
"size": 1234,
|
||||
"sha": "abc123",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_tree": lambda *args, **kwargs: (
|
||||
False,
|
||||
[
|
||||
{
|
||||
"path": "src/main.py",
|
||||
"type": "blob",
|
||||
"size": 1234,
|
||||
"sha": "abc123",
|
||||
}
|
||||
],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_tree(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
branch: str,
|
||||
recursive: bool,
|
||||
) -> tuple[bool, list[Output.TreeEntry]]:
|
||||
api = get_api(credentials)
|
||||
tree_url = repo_url + f"/git/trees/{quote(branch, safe='')}"
|
||||
params = {"recursive": "1"} if recursive else {}
|
||||
response = await api.get(tree_url, params=params)
|
||||
data = response.json()
|
||||
entries: list[GithubGetRepositoryTreeBlock.Output.TreeEntry] = [
|
||||
GithubGetRepositoryTreeBlock.Output.TreeEntry(
|
||||
path=item["path"],
|
||||
type=item["type"],
|
||||
size=item.get("size", 0),
|
||||
sha=item["sha"],
|
||||
)
|
||||
for item in data["tree"]
|
||||
]
|
||||
return data.get("truncated", False), entries
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
truncated, entries = await self.get_tree(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.recursive,
|
||||
)
|
||||
yield "truncated", truncated
|
||||
yield "entries", entries
|
||||
for entry in entries:
|
||||
yield "entry", entry
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,120 +0,0 @@
|
||||
import inspect
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.github._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
|
||||
from backend.blocks.github.commits import FileOperation, GithubMultiFileCommitBlock
|
||||
from backend.blocks.github.pull_requests import (
|
||||
GithubMergePullRequestBlock,
|
||||
prepare_pr_api_url,
|
||||
)
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
# ── prepare_pr_api_url tests ──
|
||||
|
||||
|
||||
class TestPreparePrApiUrl:
|
||||
def test_https_scheme_preserved(self):
|
||||
result = prepare_pr_api_url("https://github.com/owner/repo/pull/42", "merge")
|
||||
assert result == "https://github.com/owner/repo/pulls/42/merge"
|
||||
|
||||
def test_http_scheme_preserved(self):
|
||||
result = prepare_pr_api_url("http://github.com/owner/repo/pull/1", "files")
|
||||
assert result == "http://github.com/owner/repo/pulls/1/files"
|
||||
|
||||
def test_no_scheme_defaults_to_https(self):
|
||||
result = prepare_pr_api_url("github.com/owner/repo/pull/5", "merge")
|
||||
assert result == "https://github.com/owner/repo/pulls/5/merge"
|
||||
|
||||
def test_reviewers_path(self):
|
||||
result = prepare_pr_api_url(
|
||||
"https://github.com/owner/repo/pull/99", "requested_reviewers"
|
||||
)
|
||||
assert result == "https://github.com/owner/repo/pulls/99/requested_reviewers"
|
||||
|
||||
def test_invalid_url_returned_as_is(self):
|
||||
url = "https://example.com/not-a-pr"
|
||||
assert prepare_pr_api_url(url, "merge") == url
|
||||
|
||||
def test_empty_string(self):
|
||||
assert prepare_pr_api_url("", "merge") == ""
|
||||
|
||||
|
||||
# ── Error-path block tests ──
|
||||
# When a block's run() yields ("error", msg), _execute() converts it to a
|
||||
# BlockExecutionError. We call block.execute() directly (not execute_block_test,
|
||||
# which returns early on empty test_output).
|
||||
|
||||
|
||||
def _mock_block(block, mocks: dict):
|
||||
"""Apply mocks to a block's static methods, wrapping sync mocks as async."""
|
||||
for name, mock_fn in mocks.items():
|
||||
original = getattr(block, name)
|
||||
if inspect.iscoroutinefunction(original):
|
||||
|
||||
async def async_mock(*args, _fn=mock_fn, **kwargs):
|
||||
return _fn(*args, **kwargs)
|
||||
|
||||
setattr(block, name, async_mock)
|
||||
else:
|
||||
setattr(block, name, mock_fn)
|
||||
|
||||
|
||||
def _raise(exc: Exception):
|
||||
"""Helper that returns a callable which raises the given exception."""
|
||||
|
||||
def _raiser(*args, **kwargs):
|
||||
raise exc
|
||||
|
||||
return _raiser
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_pr_error_path():
|
||||
block = GithubMergePullRequestBlock()
|
||||
_mock_block(block, {"merge_pr": _raise(RuntimeError("PR not mergeable"))})
|
||||
input_data = {
|
||||
"pr_url": "https://github.com/owner/repo/pull/1",
|
||||
"merge_method": "squash",
|
||||
"commit_title": "",
|
||||
"commit_message": "",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
with pytest.raises(BlockExecutionError, match="PR not mergeable"):
|
||||
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_file_commit_error_path():
|
||||
block = GithubMultiFileCommitBlock()
|
||||
_mock_block(block, {"multi_file_commit": _raise(RuntimeError("ref update failed"))})
|
||||
input_data = {
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "feature",
|
||||
"commit_message": "test",
|
||||
"files": [{"path": "a.py", "content": "x", "operation": "upsert"}],
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
with pytest.raises(BlockExecutionError, match="ref update failed"):
|
||||
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
|
||||
# ── FileOperation enum tests ──
|
||||
|
||||
|
||||
class TestFileOperation:
|
||||
def test_upsert_value(self):
|
||||
assert FileOperation.UPSERT == "upsert"
|
||||
|
||||
def test_delete_value(self):
|
||||
assert FileOperation.DELETE == "delete"
|
||||
|
||||
def test_invalid_value_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
FileOperation("create")
|
||||
|
||||
def test_invalid_value_raises_typo(self):
|
||||
with pytest.raises(ValueError):
|
||||
FileOperation("upser")
|
||||
@@ -241,8 +241,8 @@ class GmailBase(Block, ABC):
|
||||
h.ignore_links = False
|
||||
h.ignore_images = True
|
||||
return h.handle(html_content)
|
||||
except Exception:
|
||||
# Keep extraction resilient if html2text is unavailable or fails.
|
||||
except ImportError:
|
||||
# Fallback: return raw HTML if html2text is not available
|
||||
return html_content
|
||||
|
||||
# Handle content stored as attachment
|
||||
|
||||
@@ -67,7 +67,6 @@ class HITLReviewHelper:
|
||||
graph_version: int,
|
||||
block_name: str = "Block",
|
||||
editable: bool = False,
|
||||
is_graph_execution: bool = True,
|
||||
) -> Optional[ReviewResult]:
|
||||
"""
|
||||
Handle a review request for a block that requires human review.
|
||||
@@ -144,11 +143,10 @@ class HITLReviewHelper:
|
||||
logger.info(
|
||||
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
|
||||
)
|
||||
if is_graph_execution:
|
||||
await HITLReviewHelper.update_node_execution_status(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.REVIEW,
|
||||
)
|
||||
await HITLReviewHelper.update_node_execution_status(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.REVIEW,
|
||||
)
|
||||
return None # Signal that execution should pause
|
||||
|
||||
# Mark review as processed if not already done
|
||||
@@ -170,7 +168,6 @@ class HITLReviewHelper:
|
||||
graph_version: int,
|
||||
block_name: str = "Block",
|
||||
editable: bool = False,
|
||||
is_graph_execution: bool = True,
|
||||
) -> Optional[ReviewDecision]:
|
||||
"""
|
||||
Handle a review request and return the decision in a single call.
|
||||
@@ -200,7 +197,6 @@ class HITLReviewHelper:
|
||||
graph_version=graph_version,
|
||||
block_name=block_name,
|
||||
editable=editable,
|
||||
is_graph_execution=is_graph_execution,
|
||||
)
|
||||
|
||||
if review_result is None:
|
||||
|
||||
@@ -17,7 +17,7 @@ from backend.blocks.jina._auth import (
|
||||
from backend.blocks.search import GetRequest
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host
|
||||
from backend.util.request import HTTPClientError, HTTPServerError, validate_url
|
||||
|
||||
|
||||
class SearchTheWebBlock(Block, GetRequest):
|
||||
@@ -112,7 +112,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
|
||||
) -> BlockOutput:
|
||||
if input_data.raw_content:
|
||||
try:
|
||||
parsed_url, _, _ = await validate_url_host(input_data.url)
|
||||
parsed_url, _, _ = await validate_url(input_data.url, [])
|
||||
url = parsed_url.geturl()
|
||||
except ValueError as e:
|
||||
yield "error", f"Invalid URL: {e}"
|
||||
|
||||
@@ -31,14 +31,10 @@ from backend.data.model import (
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.prompt import compress_context, estimate_token_count
|
||||
from backend.util.request import validate_url_host
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.text import TextFormatter
|
||||
|
||||
settings = Settings()
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
||||
fmt = TextFormatter(autoescape=False)
|
||||
|
||||
@@ -140,31 +136,19 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
# OpenRouter models
|
||||
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
|
||||
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
|
||||
GEMINI_2_5_PRO_PREVIEW = "google/gemini-2.5-pro-preview-03-25"
|
||||
GEMINI_2_5_PRO = "google/gemini-2.5-pro"
|
||||
GEMINI_3_1_PRO_PREVIEW = "google/gemini-3.1-pro-preview"
|
||||
GEMINI_3_FLASH_PREVIEW = "google/gemini-3-flash-preview"
|
||||
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
|
||||
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
|
||||
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
|
||||
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
|
||||
GEMINI_3_1_FLASH_LITE_PREVIEW = "google/gemini-3.1-flash-lite-preview"
|
||||
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
|
||||
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
|
||||
MISTRAL_NEMO = "mistralai/mistral-nemo"
|
||||
MISTRAL_LARGE_3 = "mistralai/mistral-large-2512"
|
||||
MISTRAL_MEDIUM_3_1 = "mistralai/mistral-medium-3.1"
|
||||
MISTRAL_SMALL_3_2 = "mistralai/mistral-small-3.2-24b-instruct"
|
||||
CODESTRAL = "mistralai/codestral-2508"
|
||||
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
|
||||
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
|
||||
COHERE_COMMAND_A_03_2025 = "cohere/command-a-03-2025"
|
||||
COHERE_COMMAND_A_TRANSLATE_08_2025 = "cohere/command-a-translate-08-2025"
|
||||
COHERE_COMMAND_A_REASONING_08_2025 = "cohere/command-a-reasoning-08-2025"
|
||||
COHERE_COMMAND_A_VISION_07_2025 = "cohere/command-a-vision-07-2025"
|
||||
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
|
||||
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
|
||||
PERPLEXITY_SONAR = "perplexity/sonar"
|
||||
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
|
||||
PERPLEXITY_SONAR_REASONING_PRO = "perplexity/sonar-reasoning-pro"
|
||||
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
|
||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
|
||||
@@ -172,11 +156,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
|
||||
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
|
||||
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
|
||||
MICROSOFT_PHI_4 = "microsoft/phi-4"
|
||||
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
|
||||
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
|
||||
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
|
||||
GROK_3 = "x-ai/grok-3"
|
||||
GROK_4 = "x-ai/grok-4"
|
||||
GROK_4_FAST = "x-ai/grok-4-fast"
|
||||
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
|
||||
@@ -354,41 +336,17 @@ MODEL_METADATA = {
|
||||
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
|
||||
),
|
||||
# https://openrouter.ai/models
|
||||
LlmModel.GEMINI_2_5_PRO_PREVIEW: ModelMetadata(
|
||||
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
1050000,
|
||||
8192,
|
||||
"Gemini 2.5 Pro Preview 03.25",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 2.5 Pro",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_3_1_PRO_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 3.1 Pro Preview",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_3_FLASH_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 3 Flash Preview",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
1,
|
||||
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
|
||||
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
|
||||
),
|
||||
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
|
||||
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
|
||||
@@ -396,15 +354,6 @@ MODEL_METADATA = {
|
||||
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
|
||||
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
|
||||
),
|
||||
LlmModel.GEMINI_3_1_FLASH_LITE_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 3.1 Flash Lite Preview",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
1,
|
||||
),
|
||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
@@ -426,78 +375,12 @@ MODEL_METADATA = {
|
||||
LlmModel.MISTRAL_NEMO: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
|
||||
),
|
||||
LlmModel.MISTRAL_LARGE_3: ModelMetadata(
|
||||
"open_router",
|
||||
262144,
|
||||
None,
|
||||
"Mistral Large 3 2512",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
2,
|
||||
),
|
||||
LlmModel.MISTRAL_MEDIUM_3_1: ModelMetadata(
|
||||
"open_router",
|
||||
131072,
|
||||
None,
|
||||
"Mistral Medium 3.1",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
2,
|
||||
),
|
||||
LlmModel.MISTRAL_SMALL_3_2: ModelMetadata(
|
||||
"open_router",
|
||||
131072,
|
||||
131072,
|
||||
"Mistral Small 3.2 24B",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
1,
|
||||
),
|
||||
LlmModel.CODESTRAL: ModelMetadata(
|
||||
"open_router",
|
||||
256000,
|
||||
None,
|
||||
"Codestral 2508",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
1,
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_03_2025: ModelMetadata(
|
||||
"open_router", 256000, 8192, "Command A 03.2025", "OpenRouter", "Cohere", 2
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_TRANSLATE_08_2025: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
8192,
|
||||
"Command A Translate 08.2025",
|
||||
"OpenRouter",
|
||||
"Cohere",
|
||||
2,
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_REASONING_08_2025: ModelMetadata(
|
||||
"open_router",
|
||||
256000,
|
||||
32768,
|
||||
"Command A Reasoning 08.2025",
|
||||
"OpenRouter",
|
||||
"Cohere",
|
||||
3,
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_VISION_07_2025: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
8192,
|
||||
"Command A Vision 07.2025",
|
||||
"OpenRouter",
|
||||
"Cohere",
|
||||
2,
|
||||
),
|
||||
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
|
||||
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
|
||||
),
|
||||
@@ -510,15 +393,6 @@ MODEL_METADATA = {
|
||||
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
|
||||
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR_REASONING_PRO: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
8000,
|
||||
"Sonar Reasoning Pro",
|
||||
"OpenRouter",
|
||||
"Perplexity",
|
||||
2,
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
@@ -564,9 +438,6 @@ MODEL_METADATA = {
|
||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
|
||||
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
|
||||
),
|
||||
LlmModel.MICROSOFT_PHI_4: ModelMetadata(
|
||||
"open_router", 16384, 16384, "Phi-4", "OpenRouter", "Microsoft", 1
|
||||
),
|
||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
|
||||
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
|
||||
),
|
||||
@@ -576,15 +447,6 @@ MODEL_METADATA = {
|
||||
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
|
||||
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
|
||||
),
|
||||
LlmModel.GROK_3: ModelMetadata(
|
||||
"open_router",
|
||||
131072,
|
||||
131072,
|
||||
"Grok 3",
|
||||
"OpenRouter",
|
||||
"xAI",
|
||||
2,
|
||||
),
|
||||
LlmModel.GROK_4: ModelMetadata(
|
||||
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
|
||||
),
|
||||
@@ -942,11 +804,6 @@ async def llm_call(
|
||||
if tools:
|
||||
raise ValueError("Ollama does not support tools.")
|
||||
|
||||
# Validate user-provided Ollama host to prevent SSRF etc.
|
||||
await validate_url_host(
|
||||
ollama_host, trusted_hostnames=[settings.config.ollama_host]
|
||||
)
|
||||
|
||||
client = ollama.AsyncClient(host=ollama_host)
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
@@ -968,7 +825,7 @@ async def llm_call(
|
||||
elif provider == "open_router":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
client = openai.AsyncOpenAI(
|
||||
base_url=OPENROUTER_BASE_URL,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from enum import Enum
|
||||
from typing import Any, Literal
|
||||
|
||||
import openai
|
||||
from pydantic import SecretStr, field_validator
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
@@ -13,7 +13,6 @@ from backend.blocks._base import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -22,7 +21,6 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
from backend.util.logging import TruncatedLogger
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[Perplexity-Block]")
|
||||
@@ -36,20 +34,6 @@ class PerplexityModel(str, Enum):
|
||||
SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
||||
|
||||
|
||||
def _sanitize_perplexity_model(value: Any) -> PerplexityModel:
|
||||
"""Return a valid PerplexityModel, falling back to SONAR for invalid values."""
|
||||
if isinstance(value, PerplexityModel):
|
||||
return value
|
||||
try:
|
||||
return PerplexityModel(value)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Invalid PerplexityModel '{value}', "
|
||||
f"falling back to {PerplexityModel.SONAR.value}"
|
||||
)
|
||||
return PerplexityModel.SONAR
|
||||
|
||||
|
||||
PerplexityCredentials = CredentialsMetaInput[
|
||||
Literal[ProviderName.OPEN_ROUTER], Literal["api_key"]
|
||||
]
|
||||
@@ -88,25 +72,6 @@ class PerplexityBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
credentials: PerplexityCredentials = PerplexityCredentialsField()
|
||||
|
||||
@field_validator("model", mode="before")
|
||||
@classmethod
|
||||
def fallback_invalid_model(cls, v: Any) -> PerplexityModel:
|
||||
"""Fall back to SONAR if the model value is not a valid
|
||||
PerplexityModel (e.g. an OpenAI model ID set by the agent
|
||||
generator)."""
|
||||
return _sanitize_perplexity_model(v)
|
||||
|
||||
@classmethod
|
||||
def validate_data(cls, data: BlockInput) -> str | None:
|
||||
"""Sanitize the model field before JSON schema validation so that
|
||||
invalid values are replaced with the default instead of raising a
|
||||
BlockInputError."""
|
||||
model_value = data.get("model")
|
||||
if model_value is not None:
|
||||
data["model"] = _sanitize_perplexity_model(model_value).value
|
||||
return super().validate_data(data)
|
||||
|
||||
system_prompt: str = SchemaField(
|
||||
title="System Prompt",
|
||||
default="",
|
||||
@@ -171,7 +136,7 @@ class PerplexityBlock(Block):
|
||||
) -> dict[str, Any]:
|
||||
"""Call Perplexity via OpenRouter and extract annotations."""
|
||||
client = openai.AsyncOpenAI(
|
||||
base_url=OPENROUTER_BASE_URL,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
|
||||
@@ -2232,7 +2232,6 @@ class DeleteRedditPostBlock(Block):
|
||||
("post_id", "abc123"),
|
||||
],
|
||||
test_mock={"delete_post": lambda creds, post_id: True},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -2291,7 +2290,6 @@ class DeleteRedditCommentBlock(Block):
|
||||
("comment_id", "xyz789"),
|
||||
],
|
||||
test_mock={"delete_comment": lambda creds, comment_id: True},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -72,7 +72,6 @@ class Slant3DCreateOrderBlock(Slant3DBlockBase):
|
||||
"_make_request": lambda *args, **kwargs: {"orderId": "314144241"},
|
||||
"_convert_to_color": lambda *args, **kwargs: "black",
|
||||
},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
async def run(
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
"""Unit tests for PerplexityBlock model fallback behavior."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.perplexity import (
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
PerplexityBlock,
|
||||
PerplexityModel,
|
||||
)
|
||||
|
||||
|
||||
def _make_input(**overrides) -> dict:
|
||||
defaults = {
|
||||
"prompt": "test query",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return defaults
|
||||
|
||||
|
||||
class TestPerplexityModelFallback:
|
||||
"""Tests for fallback_invalid_model field_validator."""
|
||||
|
||||
def test_invalid_model_falls_back_to_sonar(self):
|
||||
inp = PerplexityBlock.Input(**_make_input(model="gpt-5.2-2025-12-11"))
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
|
||||
def test_another_invalid_model_falls_back_to_sonar(self):
|
||||
inp = PerplexityBlock.Input(**_make_input(model="gpt-4o"))
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
|
||||
def test_valid_model_string_is_kept(self):
|
||||
inp = PerplexityBlock.Input(**_make_input(model="perplexity/sonar-pro"))
|
||||
assert inp.model == PerplexityModel.SONAR_PRO
|
||||
|
||||
def test_valid_enum_value_is_kept(self):
|
||||
inp = PerplexityBlock.Input(
|
||||
**_make_input(model=PerplexityModel.SONAR_DEEP_RESEARCH)
|
||||
)
|
||||
assert inp.model == PerplexityModel.SONAR_DEEP_RESEARCH
|
||||
|
||||
def test_default_model_when_omitted(self):
|
||||
inp = PerplexityBlock.Input(**_make_input())
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_value",
|
||||
[
|
||||
"perplexity/sonar",
|
||||
"perplexity/sonar-pro",
|
||||
"perplexity/sonar-deep-research",
|
||||
],
|
||||
)
|
||||
def test_all_valid_models_accepted(self, model_value: str):
|
||||
inp = PerplexityBlock.Input(**_make_input(model=model_value))
|
||||
assert inp.model.value == model_value
|
||||
|
||||
|
||||
class TestPerplexityValidateData:
|
||||
"""Tests for validate_data which runs during block execution (before
|
||||
Pydantic instantiation). Invalid models must be sanitized here so
|
||||
JSON schema validation does not reject them."""
|
||||
|
||||
def test_invalid_model_sanitized_before_schema_validation(self):
|
||||
data = _make_input(model="gpt-5.2-2025-12-11")
|
||||
error = PerplexityBlock.Input.validate_data(data)
|
||||
assert error is None
|
||||
assert data["model"] == PerplexityModel.SONAR.value
|
||||
|
||||
def test_valid_model_unchanged_by_validate_data(self):
|
||||
data = _make_input(model="perplexity/sonar-pro")
|
||||
error = PerplexityBlock.Input.validate_data(data)
|
||||
assert error is None
|
||||
assert data["model"] == "perplexity/sonar-pro"
|
||||
|
||||
def test_missing_model_uses_default(self):
|
||||
data = _make_input() # no model key
|
||||
error = PerplexityBlock.Input.validate_data(data)
|
||||
assert error is None
|
||||
inp = PerplexityBlock.Input(**data)
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
@@ -1,13 +1,10 @@
|
||||
"""Configuration management for chat system."""
|
||||
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
|
||||
|
||||
class ChatConfig(BaseSettings):
|
||||
"""Configuration for the chat system."""
|
||||
@@ -22,7 +19,7 @@ class ChatConfig(BaseSettings):
|
||||
)
|
||||
api_key: str | None = Field(default=None, description="OpenAI API key")
|
||||
base_url: str | None = Field(
|
||||
default=OPENROUTER_BASE_URL,
|
||||
default="https://openrouter.ai/api/v1",
|
||||
description="Base URL for API (e.g., for OpenRouter)",
|
||||
)
|
||||
|
||||
@@ -115,37 +112,9 @@ class ChatConfig(BaseSettings):
|
||||
description="E2B sandbox template to use for copilot sessions.",
|
||||
)
|
||||
e2b_sandbox_timeout: int = Field(
|
||||
default=10800, # 3 hours — wall-clock timeout, not idle; explicit pause is primary
|
||||
description="E2B sandbox running-time timeout (seconds). "
|
||||
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
|
||||
"mechanism; this is the safety net.",
|
||||
default=43200, # 12 hours — same as session_ttl
|
||||
description="E2B sandbox keepalive timeout in seconds.",
|
||||
)
|
||||
e2b_sandbox_on_timeout: Literal["kill", "pause"] = Field(
|
||||
default="pause",
|
||||
description="E2B lifecycle action on timeout: 'pause' (default, free) or 'kill'.",
|
||||
)
|
||||
|
||||
@property
|
||||
def e2b_active(self) -> bool:
|
||||
"""True when E2B is enabled and the API key is present.
|
||||
|
||||
Single source of truth for "should we use E2B right now?".
|
||||
Prefer this over combining ``use_e2b_sandbox`` and ``e2b_api_key``
|
||||
separately at call sites.
|
||||
"""
|
||||
return self.use_e2b_sandbox and bool(self.e2b_api_key)
|
||||
|
||||
@property
|
||||
def active_e2b_api_key(self) -> str | None:
|
||||
"""Return the E2B API key when E2B is enabled and configured, else None.
|
||||
|
||||
Combines the ``use_e2b_sandbox`` flag check and key presence into one.
|
||||
Use in callers::
|
||||
|
||||
if api_key := config.active_e2b_api_key:
|
||||
# E2B is active; api_key is narrowed to str
|
||||
"""
|
||||
return self.e2b_api_key if self.e2b_active else None
|
||||
|
||||
@field_validator("use_e2b_sandbox", mode="before")
|
||||
@classmethod
|
||||
@@ -195,7 +164,7 @@ class ChatConfig(BaseSettings):
|
||||
if not v:
|
||||
v = os.getenv("OPENAI_BASE_URL")
|
||||
if not v:
|
||||
v = OPENROUTER_BASE_URL
|
||||
v = "https://openrouter.ai/api/v1"
|
||||
return v
|
||||
|
||||
@field_validator("use_claude_agent_sdk", mode="before")
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
"""Unit tests for ChatConfig."""
|
||||
|
||||
import pytest
|
||||
|
||||
from .config import ChatConfig
|
||||
|
||||
# Env vars that the ChatConfig validators read — must be cleared so they don't
|
||||
# override the explicit constructor values we pass in each test.
|
||||
_E2B_ENV_VARS = (
|
||||
"CHAT_USE_E2B_SANDBOX",
|
||||
"CHAT_E2B_API_KEY",
|
||||
"E2B_API_KEY",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_e2b_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
for var in _E2B_ENV_VARS:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
|
||||
class TestE2BActive:
|
||||
"""Tests for the e2b_active property — single source of truth for E2B usage."""
|
||||
|
||||
def test_both_enabled_and_key_present_returns_true(self):
|
||||
"""e2b_active is True when use_e2b_sandbox=True and e2b_api_key is set."""
|
||||
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key="test-key")
|
||||
assert cfg.e2b_active is True
|
||||
|
||||
def test_enabled_but_missing_key_returns_false(self):
|
||||
"""e2b_active is False when use_e2b_sandbox=True but e2b_api_key is absent."""
|
||||
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key=None)
|
||||
assert cfg.e2b_active is False
|
||||
|
||||
def test_disabled_returns_false(self):
|
||||
"""e2b_active is False when use_e2b_sandbox=False regardless of key."""
|
||||
cfg = ChatConfig(use_e2b_sandbox=False, e2b_api_key="test-key")
|
||||
assert cfg.e2b_active is False
|
||||
@@ -6,32 +6,6 @@
|
||||
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
|
||||
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
|
||||
|
||||
# Prefix for all synthetic IDs generated by CoPilot block execution.
|
||||
# Used to distinguish CoPilot-generated records from real graph execution records
|
||||
# in PendingHumanReview and other tables.
|
||||
COPILOT_SYNTHETIC_ID_PREFIX = "copilot-"
|
||||
|
||||
# Sub-prefixes for session-scoped and node-scoped synthetic IDs.
|
||||
COPILOT_SESSION_PREFIX = f"{COPILOT_SYNTHETIC_ID_PREFIX}session-"
|
||||
COPILOT_NODE_PREFIX = f"{COPILOT_SYNTHETIC_ID_PREFIX}node-"
|
||||
|
||||
# Separator used in synthetic node_exec_id to encode node_id.
|
||||
# Format: "{node_id}:{random_hex}" — extract node_id via rsplit(":", 1)[0]
|
||||
COPILOT_NODE_EXEC_ID_SEPARATOR = ":"
|
||||
|
||||
# Compaction notice messages shown to users.
|
||||
COMPACTION_DONE_MSG = "Earlier messages were summarized to fit within context limits."
|
||||
COMPACTION_TOOL_NAME = "context_compaction"
|
||||
|
||||
|
||||
def is_copilot_synthetic_id(id_value: str) -> bool:
|
||||
"""Check if an ID is a CoPilot synthetic ID (not from a real graph execution)."""
|
||||
return id_value.startswith(COPILOT_SYNTHETIC_ID_PREFIX)
|
||||
|
||||
|
||||
def parse_node_id_from_exec_id(node_exec_id: str) -> str:
|
||||
"""Extract node_id from a synthetic node_exec_id.
|
||||
|
||||
Format: "{node_id}:{random_hex}" → returns "{node_id}".
|
||||
"""
|
||||
return node_exec_id.rsplit(COPILOT_NODE_EXEC_ID_SEPARATOR, 1)[0]
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
"""Shared execution context for copilot SDK tool handlers.
|
||||
|
||||
All context variables and their accessors live here so that
|
||||
``tool_adapter``, ``file_ref``, and ``e2b_file_tools`` can import them
|
||||
without creating circular dependencies.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
# Allowed base directory for the Read tool.
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
# Encoded project-directory name for the current session (e.g.
|
||||
# "-private-tmp-copilot-<uuid>"). Set by set_execution_context() so path
|
||||
# validation can scope tool-results reads to the current session.
|
||||
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
|
||||
|
||||
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
|
||||
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||
"current_session", default=None
|
||||
)
|
||||
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
|
||||
"_current_sandbox", default=None
|
||||
)
|
||||
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
|
||||
|
||||
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does."""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
sandbox: "AsyncSandbox | None" = None,
|
||||
sdk_cwd: str | None = None,
|
||||
) -> None:
|
||||
"""Set per-turn context variables used by file-resolution tool handlers."""
|
||||
_current_user_id.set(user_id)
|
||||
_current_session.set(session)
|
||||
_current_sandbox.set(sandbox)
|
||||
_current_sdk_cwd.set(sdk_cwd or "")
|
||||
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
|
||||
|
||||
|
||||
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||
"""Return the current (user_id, session) pair for the active request."""
|
||||
return _current_user_id.get(), _current_session.get()
|
||||
|
||||
|
||||
def get_current_sandbox() -> "AsyncSandbox | None":
|
||||
"""Return the E2B sandbox for the current session, or None if not active."""
|
||||
return _current_sandbox.get()
|
||||
|
||||
|
||||
def get_sdk_cwd() -> str:
|
||||
"""Return the SDK working directory for the current session (empty string if unset)."""
|
||||
return _current_sdk_cwd.get()
|
||||
|
||||
|
||||
E2B_WORKDIR = "/home/user"
|
||||
|
||||
|
||||
def resolve_sandbox_path(path: str) -> str:
|
||||
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
|
||||
|
||||
Raises :class:`ValueError` if the resolved path escapes the sandbox.
|
||||
"""
|
||||
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
|
||||
normalized = os.path.normpath(candidate)
|
||||
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
|
||||
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
|
||||
return normalized
|
||||
|
||||
|
||||
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
"""Return True if *path* is within an allowed host-filesystem location.
|
||||
|
||||
Allowed:
|
||||
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/tool-results/`` (SDK tool-results)
|
||||
"""
|
||||
if not path:
|
||||
return False
|
||||
|
||||
if path.startswith("~"):
|
||||
resolved = os.path.realpath(os.path.expanduser(path))
|
||||
elif not os.path.isabs(path) and sdk_cwd:
|
||||
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
|
||||
else:
|
||||
resolved = os.path.realpath(path)
|
||||
|
||||
if sdk_cwd:
|
||||
norm_cwd = os.path.realpath(sdk_cwd)
|
||||
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
|
||||
return True
|
||||
|
||||
encoded = _current_project_dir.get("")
|
||||
if encoded:
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
if resolved == tool_results_dir or resolved.startswith(
|
||||
tool_results_dir + os.sep
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -1,163 +0,0 @@
|
||||
"""Tests for context.py — execution context variables and path helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import (
|
||||
_SDK_PROJECTS_DIR,
|
||||
_current_project_dir,
|
||||
get_current_sandbox,
|
||||
get_execution_context,
|
||||
get_sdk_cwd,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
set_execution_context,
|
||||
)
|
||||
|
||||
|
||||
def _make_session() -> MagicMock:
|
||||
s = MagicMock()
|
||||
s.session_id = "test-session"
|
||||
return s
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Context variable getters
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_execution_context_defaults():
|
||||
"""get_execution_context returns (None, session) when user_id is not set."""
|
||||
set_execution_context(None, _make_session())
|
||||
user_id, session = get_execution_context()
|
||||
assert user_id is None
|
||||
assert session is not None
|
||||
|
||||
|
||||
def test_set_and_get_execution_context():
|
||||
"""set_execution_context stores user_id and session."""
|
||||
mock_session = _make_session()
|
||||
set_execution_context("user-abc", mock_session)
|
||||
user_id, session = get_execution_context()
|
||||
assert user_id == "user-abc"
|
||||
assert session is mock_session
|
||||
|
||||
|
||||
def test_get_current_sandbox_none_by_default():
|
||||
"""get_current_sandbox returns None when no sandbox is set."""
|
||||
set_execution_context("u1", _make_session(), sandbox=None)
|
||||
assert get_current_sandbox() is None
|
||||
|
||||
|
||||
def test_get_current_sandbox_returns_set_value():
|
||||
"""get_current_sandbox returns the sandbox set via set_execution_context."""
|
||||
mock_sandbox = MagicMock()
|
||||
set_execution_context("u1", _make_session(), sandbox=mock_sandbox)
|
||||
assert get_current_sandbox() is mock_sandbox
|
||||
|
||||
|
||||
def test_get_sdk_cwd_empty_when_not_set():
|
||||
"""get_sdk_cwd returns empty string when sdk_cwd is not set."""
|
||||
set_execution_context("u1", _make_session(), sdk_cwd=None)
|
||||
assert get_sdk_cwd() == ""
|
||||
|
||||
|
||||
def test_get_sdk_cwd_returns_set_value():
|
||||
"""get_sdk_cwd returns the value set via set_execution_context."""
|
||||
set_execution_context("u1", _make_session(), sdk_cwd="/tmp/copilot-test")
|
||||
assert get_sdk_cwd() == "/tmp/copilot-test"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_allowed_local_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_is_allowed_local_path_empty():
|
||||
assert not is_allowed_local_path("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_inside_sdk_cwd():
|
||||
with tempfile.TemporaryDirectory() as cwd:
|
||||
path = os.path.join(cwd, "file.txt")
|
||||
assert is_allowed_local_path(path, cwd)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_sdk_cwd_itself():
|
||||
with tempfile.TemporaryDirectory() as cwd:
|
||||
assert is_allowed_local_path(cwd, cwd)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_outside_sdk_cwd():
|
||||
with tempfile.TemporaryDirectory() as cwd:
|
||||
assert not is_allowed_local_path("/etc/passwd", cwd)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_no_sdk_cwd_no_project_dir():
|
||||
"""Without sdk_cwd or project_dir, all paths are rejected."""
|
||||
_current_project_dir.set("")
|
||||
assert not is_allowed_local_path("/tmp/some-file.txt", sdk_cwd=None)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_tool_results_dir():
|
||||
"""Files under the tool-results directory for the current project are allowed."""
|
||||
encoded = "test-encoded-dir"
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
path = os.path.join(tool_results_dir, "output.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert is_allowed_local_path(path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
|
||||
"""A path adjacent to tool-results/ but not inside it is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
sibling_path = os.path.join(_SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert not is_allowed_local_path(sibling_path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_sandbox_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_absolute_valid():
|
||||
assert (
|
||||
resolve_sandbox_path("/home/user/project/main.py")
|
||||
== "/home/user/project/main.py"
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_relative():
|
||||
assert resolve_sandbox_path("project/main.py") == "/home/user/project/main.py"
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_workdir_itself():
|
||||
assert resolve_sandbox_path("/home/user") == "/home/user"
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_normalizes_dots():
|
||||
assert resolve_sandbox_path("/home/user/a/../b") == "/home/user/b"
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_escape_raises():
|
||||
with pytest.raises(ValueError, match="/home/user"):
|
||||
resolve_sandbox_path("/home/user/../../etc/passwd")
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_absolute_outside_raises():
|
||||
with pytest.raises(ValueError, match="/home/user"):
|
||||
resolve_sandbox_path("/etc/passwd")
|
||||
@@ -1,138 +0,0 @@
|
||||
"""Scheduler job to generate LLM-optimized block descriptions.
|
||||
|
||||
Runs periodically to rewrite block descriptions into concise, actionable
|
||||
summaries that help the copilot LLM pick the right blocks during agent
|
||||
generation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from backend.blocks import get_blocks
|
||||
from backend.util.clients import get_database_manager_client, get_openai_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"You are a technical writer for an automation platform. "
|
||||
"Rewrite the following block description to be concise (under 50 words), "
|
||||
"informative, and actionable. Focus on what the block does and when to "
|
||||
"use it. Output ONLY the rewritten description, nothing else. "
|
||||
"Do not use markdown formatting."
|
||||
)
|
||||
|
||||
# Rate-limit delay between sequential LLM calls (seconds)
|
||||
_RATE_LIMIT_DELAY = 0.5
|
||||
# Maximum tokens for optimized description generation
|
||||
_MAX_DESCRIPTION_TOKENS = 150
|
||||
# Model for generating optimized descriptions (fast, cheap)
|
||||
_MODEL = "gpt-4o-mini"
|
||||
|
||||
|
||||
async def _optimize_descriptions(blocks: list[dict[str, str]]) -> dict[str, str]:
|
||||
"""Call the shared OpenAI client to rewrite each block description."""
|
||||
client = get_openai_client()
|
||||
if client is None:
|
||||
logger.error(
|
||||
"No OpenAI client configured, skipping block description optimization"
|
||||
)
|
||||
return {}
|
||||
|
||||
results: dict[str, str] = {}
|
||||
for block in blocks:
|
||||
block_id = block["id"]
|
||||
block_name = block["name"]
|
||||
description = block["description"]
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Block name: {block_name}\nDescription: {description}",
|
||||
},
|
||||
],
|
||||
max_tokens=_MAX_DESCRIPTION_TOKENS,
|
||||
)
|
||||
optimized = (response.choices[0].message.content or "").strip()
|
||||
if optimized:
|
||||
results[block_id] = optimized
|
||||
logger.debug("Optimized description for %s", block_name)
|
||||
else:
|
||||
logger.warning("Empty response for block %s", block_name)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to optimize description for %s", block_name, exc_info=True
|
||||
)
|
||||
|
||||
await asyncio.sleep(_RATE_LIMIT_DELAY)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def optimize_block_descriptions() -> dict[str, int]:
|
||||
"""Generate optimized descriptions for blocks that don't have one yet.
|
||||
|
||||
Uses the shared OpenAI client to rewrite block descriptions into concise
|
||||
summaries suitable for agent generation prompts.
|
||||
|
||||
Returns:
|
||||
Dict with counts: processed, success, failed, skipped.
|
||||
"""
|
||||
db_client = get_database_manager_client()
|
||||
|
||||
blocks = db_client.get_blocks_needing_optimization()
|
||||
if not blocks:
|
||||
logger.info("All blocks already have optimized descriptions")
|
||||
return {"processed": 0, "success": 0, "failed": 0, "skipped": 0}
|
||||
|
||||
logger.info("Found %d blocks needing optimized descriptions", len(blocks))
|
||||
|
||||
non_empty = [b for b in blocks if b.get("description", "").strip()]
|
||||
skipped = len(blocks) - len(non_empty)
|
||||
|
||||
new_descriptions = asyncio.run(_optimize_descriptions(non_empty))
|
||||
|
||||
stats = {
|
||||
"processed": len(non_empty),
|
||||
"success": len(new_descriptions),
|
||||
"failed": len(non_empty) - len(new_descriptions),
|
||||
"skipped": skipped,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"Block description optimization complete: "
|
||||
"%d/%d succeeded, %d failed, %d skipped",
|
||||
stats["success"],
|
||||
stats["processed"],
|
||||
stats["failed"],
|
||||
stats["skipped"],
|
||||
)
|
||||
|
||||
if new_descriptions:
|
||||
for block_id, optimized in new_descriptions.items():
|
||||
db_client.update_block_optimized_description(block_id, optimized)
|
||||
|
||||
# Update in-memory descriptions first so the cache rebuilds with fresh data.
|
||||
try:
|
||||
block_classes = get_blocks()
|
||||
for block_id, optimized in new_descriptions.items():
|
||||
if block_id in block_classes:
|
||||
block_classes[block_id]._optimized_description = optimized
|
||||
logger.info(
|
||||
"Updated %d in-memory block descriptions", len(new_descriptions)
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Could not update in-memory block descriptions", exc_info=True
|
||||
)
|
||||
|
||||
from backend.copilot.tools.agent_generator.blocks import (
|
||||
reset_block_caches, # local to avoid circular import
|
||||
)
|
||||
|
||||
reset_block_caches()
|
||||
|
||||
return stats
|
||||
@@ -1,91 +0,0 @@
|
||||
"""Unit tests for optimize_blocks._optimize_descriptions."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.copilot.optimize_blocks import _RATE_LIMIT_DELAY, _optimize_descriptions
|
||||
|
||||
|
||||
def _make_client_response(text: str) -> MagicMock:
|
||||
"""Build a minimal mock that looks like an OpenAI ChatCompletion response."""
|
||||
choice = MagicMock()
|
||||
choice.message.content = text
|
||||
response = MagicMock()
|
||||
response.choices = [choice]
|
||||
return response
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.get_event_loop().run_until_complete(coro)
|
||||
|
||||
|
||||
class TestOptimizeDescriptions:
|
||||
"""Tests for _optimize_descriptions async function."""
|
||||
|
||||
def test_returns_empty_when_no_client(self):
|
||||
with patch(
|
||||
"backend.copilot.optimize_blocks.get_openai_client", return_value=None
|
||||
):
|
||||
result = _run(
|
||||
_optimize_descriptions([{"id": "b1", "name": "B", "description": "d"}])
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
def test_success_single_block(self):
|
||||
client = MagicMock()
|
||||
client.chat.completions.create = AsyncMock(
|
||||
return_value=_make_client_response("Short desc.")
|
||||
)
|
||||
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
|
||||
),
|
||||
):
|
||||
result = _run(_optimize_descriptions(blocks))
|
||||
|
||||
assert result == {"b1": "Short desc."}
|
||||
client.chat.completions.create.assert_called_once()
|
||||
|
||||
def test_skips_block_on_exception(self):
|
||||
client = MagicMock()
|
||||
client.chat.completions.create = AsyncMock(side_effect=Exception("API error"))
|
||||
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
|
||||
),
|
||||
):
|
||||
result = _run(_optimize_descriptions(blocks))
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_sleeps_between_blocks(self):
|
||||
client = MagicMock()
|
||||
client.chat.completions.create = AsyncMock(
|
||||
return_value=_make_client_response("desc")
|
||||
)
|
||||
blocks = [
|
||||
{"id": "b1", "name": "B1", "description": "d1"},
|
||||
{"id": "b2", "name": "B2", "description": "d2"},
|
||||
]
|
||||
sleep_mock = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
|
||||
),
|
||||
patch("backend.copilot.optimize_blocks.asyncio.sleep", sleep_mock),
|
||||
):
|
||||
_run(_optimize_descriptions(blocks))
|
||||
|
||||
assert sleep_mock.call_count == 2
|
||||
sleep_mock.assert_called_with(_RATE_LIMIT_DELAY)
|
||||
@@ -26,33 +26,6 @@ your message as a Markdown link or image:
|
||||
The `download_url` field in the `write_workspace_file` response is already
|
||||
in the correct format — paste it directly after the `(` in the Markdown.
|
||||
|
||||
### Passing file content to tools — @@agptfile: references
|
||||
Instead of copying large file contents into a tool argument, pass a file
|
||||
reference and the platform will load the content for you.
|
||||
|
||||
Syntax: `@@agptfile:<uri>[<start>-<end>]`
|
||||
|
||||
- `<uri>` **must** start with `workspace://` or `/` (absolute path):
|
||||
- `workspace://<file_id>` — workspace file by ID
|
||||
- `workspace:///<path>` — workspace file by virtual path
|
||||
- `/absolute/local/path` — ephemeral or sdk_cwd file
|
||||
- E2B sandbox absolute path (e.g. `/home/user/script.py`)
|
||||
- `[<start>-<end>]` is an optional 1-indexed inclusive line range.
|
||||
- URIs that do not start with `workspace://` or `/` are **not** expanded.
|
||||
|
||||
Examples:
|
||||
```
|
||||
@@agptfile:workspace://abc123
|
||||
@@agptfile:workspace://abc123[10-50]
|
||||
@@agptfile:workspace:///reports/q1.md
|
||||
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
|
||||
@@agptfile:/home/user/script.py
|
||||
```
|
||||
|
||||
You can embed a reference inside any string argument, or use it as the entire
|
||||
value. Multiple references in one argument are all expanded.
|
||||
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
## Agent Generation Guide
|
||||
|
||||
You can create, edit, and customize agents directly. You ARE the brain —
|
||||
generate the agent JSON yourself using block schemas, then validate and save.
|
||||
|
||||
### Workflow for Creating/Editing Agents
|
||||
|
||||
1. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
|
||||
search for relevant blocks. This returns block IDs, names, descriptions,
|
||||
and full input/output schemas.
|
||||
2. **Find library agents**: Call `find_library_agent` to discover reusable
|
||||
agents that can be composed as sub-agents via `AgentExecutorBlock`.
|
||||
3. **Generate JSON**: Build the agent JSON using block schemas:
|
||||
- Use block IDs from step 1 as `block_id` in nodes
|
||||
- Wire outputs to inputs using links
|
||||
- Set design-time config in `input_default`
|
||||
- Use `AgentInputBlock` for values the user provides at runtime
|
||||
4. **Write to workspace**: Save the JSON to a workspace file so the user
|
||||
can review it: `write_workspace_file(filename="agent.json", content=...)`
|
||||
5. **Validate**: Call `validate_agent_graph` with the agent JSON to check
|
||||
for errors
|
||||
6. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
|
||||
or fix manually based on the error descriptions. Iterate until valid.
|
||||
7. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
|
||||
the final `agent_json`
|
||||
|
||||
### Agent JSON Structure
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "<UUID v4>", // auto-generated if omitted
|
||||
"version": 1,
|
||||
"is_active": true,
|
||||
"name": "Agent Name",
|
||||
"description": "What the agent does",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "<UUID v4>",
|
||||
"block_id": "<block UUID from find_block>",
|
||||
"input_default": {
|
||||
"field_name": "design-time value"
|
||||
},
|
||||
"metadata": {
|
||||
"position": {"x": 0, "y": 0},
|
||||
"customized_name": "Optional display name"
|
||||
}
|
||||
}
|
||||
],
|
||||
"links": [
|
||||
{
|
||||
"id": "<UUID v4>",
|
||||
"source_id": "<source node UUID>",
|
||||
"source_name": "output_field_name",
|
||||
"sink_id": "<sink node UUID>",
|
||||
"sink_name": "input_field_name",
|
||||
"is_static": false
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### REQUIRED: AgentInputBlock and AgentOutputBlock
|
||||
|
||||
Every agent MUST include at least one AgentInputBlock and one AgentOutputBlock.
|
||||
These define the agent's interface — what it accepts and what it produces.
|
||||
|
||||
**AgentInputBlock** (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`):
|
||||
- Defines a user-facing input field on the agent
|
||||
- Required `input_default` fields: `name` (str), `value` (default: null)
|
||||
- Optional: `title`, `description`, `placeholder_values` (for dropdowns)
|
||||
- Output: `result` — the user-provided value at runtime
|
||||
- Create one AgentInputBlock per distinct input the agent needs
|
||||
|
||||
**AgentOutputBlock** (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`):
|
||||
- Defines a user-facing output displayed after the agent runs
|
||||
- Required `input_default` fields: `name` (str)
|
||||
- The `value` input should be linked from another block's output
|
||||
- Optional: `title`, `description`, `format` (Jinja2 template)
|
||||
- Create one AgentOutputBlock per distinct result to show the user
|
||||
|
||||
Without these blocks, the agent has no interface and the user cannot provide
|
||||
inputs or see outputs. NEVER skip them.
|
||||
|
||||
### Key Rules
|
||||
|
||||
- **Name & description**: Include `name` and `description` in the agent JSON
|
||||
when creating a new agent, or when editing and the agent's purpose changed.
|
||||
Without these the agent gets a generic default name.
|
||||
- **Design-time vs runtime**: `input_default` = values known at build time.
|
||||
For user-provided values, create an `AgentInputBlock` node and link its
|
||||
output to the consuming block's input.
|
||||
- **Credentials**: Do NOT require credentials upfront. Users configure
|
||||
credentials later in the platform UI after the agent is saved.
|
||||
- **Node spacing**: Position nodes with at least 800 X-units between them.
|
||||
- **Nested properties**: Use `parentField_#_childField` notation in link
|
||||
sink_name/source_name to access nested object fields.
|
||||
- **is_static links**: Set `is_static: true` when the link carries a
|
||||
design-time constant (matches a field in inputSchema with a default).
|
||||
- **ConditionBlock**: Needs a `StoreValueBlock` wired to its `value2` input.
|
||||
- **Prompt templates**: Use `{{variable}}` (double curly braces) for
|
||||
literal braces in prompt strings — single `{` and `}` are for
|
||||
template variables.
|
||||
- **AgentExecutorBlock**: When composing sub-agents, set `graph_id` and
|
||||
`graph_version` in input_default, and wire inputs/outputs to match
|
||||
the sub-agent's schema.
|
||||
|
||||
### Using Sub-Agents (AgentExecutorBlock)
|
||||
|
||||
To compose agents using other agents as sub-agents:
|
||||
1. Call `find_library_agent` to find the sub-agent — the response includes
|
||||
`graph_id`, `graph_version`, `input_schema`, and `output_schema`
|
||||
2. Create an `AgentExecutorBlock` node (ID: `e189baac-8c20-45a1-94a7-55177ea42565`)
|
||||
3. Set `input_default`:
|
||||
- `graph_id`: from the library agent's `graph_id`
|
||||
- `graph_version`: from the library agent's `graph_version`
|
||||
- `input_schema`: from the library agent's `input_schema` (JSON Schema)
|
||||
- `output_schema`: from the library agent's `output_schema` (JSON Schema)
|
||||
- `user_id`: leave as `""` (filled at runtime)
|
||||
- `inputs`: `{}` (populated by links at runtime)
|
||||
4. Wire inputs: link to sink names matching the sub-agent's `input_schema`
|
||||
property names (e.g., if input_schema has a `"url"` property, use
|
||||
`"url"` as the sink_name)
|
||||
5. Wire outputs: link from source names matching the sub-agent's
|
||||
`output_schema` property names
|
||||
6. Pass `library_agent_ids` to `create_agent`/`customize_agent` with
|
||||
the library agent IDs used, so the fixer can validate schemas
|
||||
|
||||
### Using MCP Tools (MCPToolBlock)
|
||||
|
||||
To use an MCP (Model Context Protocol) tool as a node in the agent:
|
||||
1. The user must specify which MCP server URL and tool name they want
|
||||
2. Create an `MCPToolBlock` node (ID: `a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4`)
|
||||
3. Set `input_default`:
|
||||
- `server_url`: the MCP server URL (e.g. `"https://mcp.example.com/sse"`)
|
||||
- `selected_tool`: the tool name on that server
|
||||
- `tool_input_schema`: JSON Schema for the tool's inputs
|
||||
- `tool_arguments`: `{}` (populated by links or hardcoded values)
|
||||
4. The block requires MCP credentials — the user configures these in the
|
||||
platform UI after the agent is saved
|
||||
5. Wire inputs using the tool argument field name directly as the sink_name
|
||||
(e.g., `query`, NOT `tool_arguments_#_query`). The execution engine
|
||||
automatically collects top-level fields matching tool_input_schema into
|
||||
tool_arguments.
|
||||
6. Output: `result` (the tool's return value) and `error` (error message)
|
||||
|
||||
### Example: Simple AI Text Processor
|
||||
|
||||
A minimal agent with input, processing, and output:
|
||||
- Node 1: `AgentInputBlock` (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`,
|
||||
input_default: {"name": "user_text", "title": "Text to process"},
|
||||
output: "result")
|
||||
- Node 2: `AITextGeneratorBlock` (input: "prompt" linked from Node 1's "result")
|
||||
- Node 3: `AgentOutputBlock` (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`,
|
||||
input_default: {"name": "summary", "title": "Summary"},
|
||||
input: "value" linked from Node 2's output)
|
||||
@@ -8,6 +8,8 @@ SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
|
||||
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
@@ -15,23 +17,36 @@ import os
|
||||
import shlex
|
||||
from typing import Any, Callable
|
||||
|
||||
from backend.copilot.context import (
|
||||
E2B_WORKDIR,
|
||||
get_current_sandbox,
|
||||
get_sdk_cwd,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_sandbox():
|
||||
# Lazy imports to break circular dependency with tool_adapter.
|
||||
|
||||
|
||||
def _get_sandbox(): # type: ignore[return]
|
||||
from .tool_adapter import get_current_sandbox # noqa: E402
|
||||
|
||||
return get_current_sandbox()
|
||||
|
||||
|
||||
def _is_allowed_local(path: str) -> bool:
|
||||
return is_allowed_local_path(path, get_sdk_cwd())
|
||||
from .tool_adapter import is_allowed_local_path # noqa: E402
|
||||
|
||||
return is_allowed_local_path(path)
|
||||
|
||||
|
||||
def _resolve_remote(path: str) -> str:
|
||||
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
|
||||
|
||||
Raises :class:`ValueError` if the resolved path escapes the sandbox.
|
||||
"""
|
||||
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
|
||||
normalized = os.path.normpath(candidate)
|
||||
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
|
||||
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
|
||||
return normalized
|
||||
|
||||
|
||||
def _mcp(text: str, *, error: bool = False) -> dict[str, Any]:
|
||||
@@ -48,7 +63,7 @@ def _get_sandbox_and_path(
|
||||
if sandbox is None:
|
||||
return _mcp("No E2B sandbox available", error=True)
|
||||
try:
|
||||
remote = resolve_sandbox_path(file_path)
|
||||
remote = _resolve_remote(file_path)
|
||||
except ValueError as exc:
|
||||
return _mcp(str(exc), error=True)
|
||||
return sandbox, remote
|
||||
@@ -58,7 +73,6 @@ def _get_sandbox_and_path(
|
||||
|
||||
|
||||
async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Read lines from a sandbox file, falling back to the local host for SDK-internal paths."""
|
||||
file_path: str = args.get("file_path", "")
|
||||
offset: int = max(0, int(args.get("offset", 0)))
|
||||
limit: int = max(1, int(args.get("limit", 2000)))
|
||||
@@ -90,7 +104,6 @@ async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
|
||||
async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Write content to a sandbox file, creating parent directories as needed."""
|
||||
file_path: str = args.get("file_path", "")
|
||||
content: str = args.get("content", "")
|
||||
|
||||
@@ -114,7 +127,6 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
|
||||
async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Replace a substring in a sandbox file, with optional replace-all support."""
|
||||
file_path: str = args.get("file_path", "")
|
||||
old_string: str = args.get("old_string", "")
|
||||
new_string: str = args.get("new_string", "")
|
||||
@@ -160,7 +172,6 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
|
||||
async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Find files matching a name pattern inside the sandbox using ``find``."""
|
||||
pattern: str = args.get("pattern", "")
|
||||
path: str = args.get("path", "")
|
||||
|
||||
@@ -172,7 +183,7 @@ async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
|
||||
return _mcp("No E2B sandbox available", error=True)
|
||||
|
||||
try:
|
||||
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
|
||||
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
|
||||
except ValueError as exc:
|
||||
return _mcp(str(exc), error=True)
|
||||
|
||||
@@ -187,7 +198,6 @@ async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
|
||||
async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Search file contents by regex inside the sandbox using ``grep -rn``."""
|
||||
pattern: str = args.get("pattern", "")
|
||||
path: str = args.get("path", "")
|
||||
include: str = args.get("include", "")
|
||||
@@ -200,7 +210,7 @@ async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
|
||||
return _mcp("No E2B sandbox available", error=True)
|
||||
|
||||
try:
|
||||
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
|
||||
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
|
||||
except ValueError as exc:
|
||||
return _mcp(str(exc), error=True)
|
||||
|
||||
@@ -228,7 +238,7 @@ def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
|
||||
return _mcp(f"Path not allowed: {file_path}", error=True)
|
||||
expanded = os.path.realpath(os.path.expanduser(file_path))
|
||||
try:
|
||||
with open(expanded, encoding="utf-8", errors="replace") as fh:
|
||||
with open(expanded) as fh:
|
||||
selected = list(itertools.islice(fh, offset, offset + limit))
|
||||
numbered = "".join(
|
||||
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
|
||||
|
||||
@@ -7,60 +7,59 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import _current_project_dir
|
||||
|
||||
from .e2b_file_tools import _read_local, resolve_sandbox_path
|
||||
from .e2b_file_tools import _read_local, _resolve_remote
|
||||
from .tool_adapter import _current_project_dir
|
||||
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
|
||||
# _resolve_remote — sandbox path normalisation & boundary enforcement
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveSandboxPath:
|
||||
class TestResolveRemote:
|
||||
def test_relative_path_resolved(self):
|
||||
assert resolve_sandbox_path("src/main.py") == "/home/user/src/main.py"
|
||||
assert _resolve_remote("src/main.py") == "/home/user/src/main.py"
|
||||
|
||||
def test_absolute_within_sandbox(self):
|
||||
assert resolve_sandbox_path("/home/user/file.txt") == "/home/user/file.txt"
|
||||
assert _resolve_remote("/home/user/file.txt") == "/home/user/file.txt"
|
||||
|
||||
def test_workdir_itself(self):
|
||||
assert resolve_sandbox_path("/home/user") == "/home/user"
|
||||
assert _resolve_remote("/home/user") == "/home/user"
|
||||
|
||||
def test_relative_dotslash(self):
|
||||
assert resolve_sandbox_path("./README.md") == "/home/user/README.md"
|
||||
assert _resolve_remote("./README.md") == "/home/user/README.md"
|
||||
|
||||
def test_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("../../etc/passwd")
|
||||
_resolve_remote("../../etc/passwd")
|
||||
|
||||
def test_absolute_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/home/user/../../etc/passwd")
|
||||
_resolve_remote("/home/user/../../etc/passwd")
|
||||
|
||||
def test_absolute_outside_sandbox_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/etc/passwd")
|
||||
_resolve_remote("/etc/passwd")
|
||||
|
||||
def test_root_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/")
|
||||
_resolve_remote("/")
|
||||
|
||||
def test_home_other_user_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/home/other/file.txt")
|
||||
_resolve_remote("/home/other/file.txt")
|
||||
|
||||
def test_deep_nested_allowed(self):
|
||||
assert resolve_sandbox_path("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
|
||||
assert _resolve_remote("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
|
||||
|
||||
def test_trailing_slash_normalised(self):
|
||||
assert resolve_sandbox_path("src/") == "/home/user/src"
|
||||
assert _resolve_remote("src/") == "/home/user/src"
|
||||
|
||||
def test_double_dots_within_sandbox_ok(self):
|
||||
"""Path that resolves back within /home/user is allowed."""
|
||||
assert resolve_sandbox_path("a/b/../c.txt") == "/home/user/a/c.txt"
|
||||
assert _resolve_remote("a/b/../c.txt") == "/home/user/a/c.txt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -1,281 +0,0 @@
|
||||
"""File reference protocol for tool call inputs.
|
||||
|
||||
Allows the LLM to pass a file reference instead of embedding large content
|
||||
inline. The processor expands ``@@agptfile:<uri>[<start>-<end>]`` tokens in tool
|
||||
arguments before the tool is executed.
|
||||
|
||||
Protocol
|
||||
--------
|
||||
|
||||
@@agptfile:<uri>[<start>-<end>]
|
||||
|
||||
``<uri>`` (required)
|
||||
- ``workspace://<file_id>`` — workspace file by ID
|
||||
- ``workspace://<file_id>#<mime>`` — same, MIME hint is ignored for reads
|
||||
- ``workspace:///<path>`` — workspace file by virtual path
|
||||
- ``/absolute/local/path`` — ephemeral or sdk_cwd file (validated by
|
||||
:func:`~backend.copilot.sdk.tool_adapter.is_allowed_local_path`)
|
||||
- Any absolute path that resolves inside the E2B sandbox
|
||||
(``/home/user/...``) when a sandbox is active
|
||||
|
||||
``[<start>-<end>]`` (optional)
|
||||
Line range, 1-indexed inclusive. Examples: ``[1-100]``, ``[50-200]``.
|
||||
Omit to read the entire file.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@agptfile:workspace://abc123
|
||||
@@agptfile:workspace://abc123[10-50]
|
||||
@@agptfile:workspace:///reports/q1.md
|
||||
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
|
||||
@@agptfile:/home/user/script.sh
|
||||
"""
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.context import (
|
||||
get_current_sandbox,
|
||||
get_sdk_cwd,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.workspace_files import get_manager
|
||||
from backend.util.file import parse_workspace_uri
|
||||
|
||||
|
||||
class FileRefExpansionError(Exception):
|
||||
"""Raised when a ``@@agptfile:`` reference in tool call args fails to resolve.
|
||||
|
||||
Separating this from inline substitution lets callers (e.g. the MCP tool
|
||||
wrapper) block tool execution and surface a helpful error to the model
|
||||
rather than passing an ``[file-ref error: …]`` string as actual input.
|
||||
"""
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FILE_REF_PREFIX = "@@agptfile:"
|
||||
|
||||
# Matches: @@agptfile:<uri>[start-end]?
|
||||
# Group 1 – URI; must start with '/' (absolute path) or 'workspace://'
|
||||
# Group 2 – start line (optional)
|
||||
# Group 3 – end line (optional)
|
||||
_FILE_REF_RE = re.compile(
|
||||
re.escape(FILE_REF_PREFIX) + r"((?:workspace://|/)[^\[\s]*)(?:\[(\d+)-(\d+)\])?"
|
||||
)
|
||||
|
||||
# Maximum characters returned for a single file reference expansion.
|
||||
_MAX_EXPAND_CHARS = 200_000
|
||||
# Maximum total characters across all @@agptfile: expansions in one string.
|
||||
_MAX_TOTAL_EXPAND_CHARS = 1_000_000
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileRef:
|
||||
uri: str
|
||||
start_line: int | None # 1-indexed, inclusive
|
||||
end_line: int | None # 1-indexed, inclusive
|
||||
|
||||
|
||||
def parse_file_ref(text: str) -> FileRef | None:
|
||||
"""Return a :class:`FileRef` if *text* is a bare file reference token.
|
||||
|
||||
A "bare token" means the entire string matches the ``@@agptfile:...`` pattern
|
||||
(after stripping whitespace). Use :func:`expand_file_refs_in_string` to
|
||||
expand references embedded in larger strings.
|
||||
"""
|
||||
m = _FILE_REF_RE.fullmatch(text.strip())
|
||||
if not m:
|
||||
return None
|
||||
start = int(m.group(2)) if m.group(2) else None
|
||||
end = int(m.group(3)) if m.group(3) else None
|
||||
if start is not None and start < 1:
|
||||
return None
|
||||
if end is not None and end < 1:
|
||||
return None
|
||||
if start is not None and end is not None and end < start:
|
||||
return None
|
||||
return FileRef(uri=m.group(1), start_line=start, end_line=end)
|
||||
|
||||
|
||||
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
|
||||
"""Slice *text* to the requested 1-indexed line range (inclusive)."""
|
||||
if start is None and end is None:
|
||||
return text
|
||||
lines = text.splitlines(keepends=True)
|
||||
s = (start - 1) if start is not None else 0
|
||||
e = end if end is not None else len(lines)
|
||||
selected = list(itertools.islice(lines, s, e))
|
||||
return "".join(selected)
|
||||
|
||||
|
||||
async def read_file_bytes(
|
||||
uri: str,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
) -> bytes:
|
||||
"""Resolve *uri* to raw bytes using workspace, local, or E2B path logic.
|
||||
|
||||
Raises :class:`ValueError` if the URI cannot be resolved.
|
||||
"""
|
||||
# Strip MIME fragment (e.g. workspace://id#mime) before dispatching.
|
||||
plain = uri.split("#")[0] if uri.startswith("workspace://") else uri
|
||||
|
||||
if plain.startswith("workspace://"):
|
||||
if not user_id:
|
||||
raise ValueError("workspace:// file references require authentication")
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
ws = parse_workspace_uri(plain)
|
||||
try:
|
||||
return await (
|
||||
manager.read_file(ws.file_ref)
|
||||
if ws.is_path
|
||||
else manager.read_file_by_id(ws.file_ref)
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"File not found: {plain}")
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
|
||||
if is_allowed_local_path(plain, get_sdk_cwd()):
|
||||
resolved = os.path.realpath(os.path.expanduser(plain))
|
||||
try:
|
||||
with open(resolved, "rb") as fh:
|
||||
return fh.read()
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"File not found: {plain}")
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
try:
|
||||
remote = resolve_sandbox_path(plain)
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||
) from exc
|
||||
try:
|
||||
return bytes(await sandbox.files.read(remote, format="bytes"))
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
|
||||
|
||||
raise ValueError(
|
||||
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||
)
|
||||
|
||||
|
||||
async def resolve_file_ref(
|
||||
ref: FileRef,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
) -> str:
|
||||
"""Resolve a :class:`FileRef` to its text content."""
|
||||
raw = await read_file_bytes(ref.uri, user_id, session)
|
||||
return _apply_line_range(
|
||||
raw.decode("utf-8", errors="replace"), ref.start_line, ref.end_line
|
||||
)
|
||||
|
||||
|
||||
async def expand_file_refs_in_string(
|
||||
text: str,
|
||||
user_id: str | None,
|
||||
session: "ChatSession",
|
||||
*,
|
||||
raise_on_error: bool = False,
|
||||
) -> str:
|
||||
"""Expand all ``@@agptfile:...`` tokens in *text*, returning the substituted string.
|
||||
|
||||
Non-reference text is passed through unchanged.
|
||||
|
||||
If *raise_on_error* is ``False`` (default), expansion errors are surfaced
|
||||
inline as ``[file-ref error: <message>]`` — useful for display/log contexts
|
||||
where partial expansion is acceptable.
|
||||
|
||||
If *raise_on_error* is ``True``, any resolution failure raises
|
||||
:class:`FileRefExpansionError` immediately so the caller can block the
|
||||
operation and surface a clean error to the model.
|
||||
"""
|
||||
if FILE_REF_PREFIX not in text:
|
||||
return text
|
||||
|
||||
result: list[str] = []
|
||||
last_end = 0
|
||||
total_chars = 0
|
||||
for m in _FILE_REF_RE.finditer(text):
|
||||
result.append(text[last_end : m.start()])
|
||||
start = int(m.group(2)) if m.group(2) else None
|
||||
end = int(m.group(3)) if m.group(3) else None
|
||||
if (start is not None and start < 1) or (end is not None and end < 1):
|
||||
msg = f"line numbers must be >= 1: {m.group(0)}"
|
||||
if raise_on_error:
|
||||
raise FileRefExpansionError(msg)
|
||||
result.append(f"[file-ref error: {msg}]")
|
||||
last_end = m.end()
|
||||
continue
|
||||
if start is not None and end is not None and end < start:
|
||||
msg = f"end line must be >= start line: {m.group(0)}"
|
||||
if raise_on_error:
|
||||
raise FileRefExpansionError(msg)
|
||||
result.append(f"[file-ref error: {msg}]")
|
||||
last_end = m.end()
|
||||
continue
|
||||
ref = FileRef(uri=m.group(1), start_line=start, end_line=end)
|
||||
try:
|
||||
content = await resolve_file_ref(ref, user_id, session)
|
||||
if len(content) > _MAX_EXPAND_CHARS:
|
||||
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
|
||||
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
|
||||
if remaining <= 0:
|
||||
content = "[file-ref budget exhausted: total expansion limit reached]"
|
||||
elif len(content) > remaining:
|
||||
content = content[:remaining] + "\n... [total budget exhausted]"
|
||||
total_chars += len(content)
|
||||
result.append(content)
|
||||
except ValueError as exc:
|
||||
logger.warning("file-ref expansion failed for %r: %s", m.group(0), exc)
|
||||
if raise_on_error:
|
||||
raise FileRefExpansionError(str(exc)) from exc
|
||||
result.append(f"[file-ref error: {exc}]")
|
||||
last_end = m.end()
|
||||
|
||||
result.append(text[last_end:])
|
||||
return "".join(result)
|
||||
|
||||
|
||||
async def expand_file_refs_in_args(
|
||||
args: dict[str, Any],
|
||||
user_id: str | None,
|
||||
session: "ChatSession",
|
||||
) -> dict[str, Any]:
|
||||
"""Recursively expand ``@@agptfile:...`` references in tool call arguments.
|
||||
|
||||
String values are expanded in-place. Nested dicts and lists are
|
||||
traversed. Non-string scalars are returned unchanged.
|
||||
|
||||
Raises :class:`FileRefExpansionError` if any reference fails to resolve,
|
||||
so the tool is *not* executed with an error string as its input. The
|
||||
caller (the MCP tool wrapper) should convert this into an MCP error
|
||||
response that lets the model correct the reference before retrying.
|
||||
"""
|
||||
if not args:
|
||||
return args
|
||||
|
||||
async def _expand(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
return await expand_file_refs_in_string(
|
||||
value, user_id, session, raise_on_error=True
|
||||
)
|
||||
if isinstance(value, dict):
|
||||
return {k: await _expand(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [await _expand(item) for item in value]
|
||||
return value
|
||||
|
||||
return {k: await _expand(v) for k, v in args.items()}
|
||||
@@ -1,328 +0,0 @@
|
||||
"""Integration tests for @@agptfile: reference expansion in tool calls.
|
||||
|
||||
These tests verify the end-to-end behaviour of the file reference protocol:
|
||||
- Parsing @@agptfile: tokens from tool arguments
|
||||
- Resolving local-filesystem paths (sdk_cwd / ephemeral)
|
||||
- Expanding references inside the tool-call pipeline (_execute_tool_sync)
|
||||
- The extended Read tool handler (workspace:// pass-through via session context)
|
||||
|
||||
No real LLM or database is required; workspace reads are stubbed where needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.sdk.file_ref import (
|
||||
FileRef,
|
||||
expand_file_refs_in_args,
|
||||
expand_file_refs_in_string,
|
||||
read_file_bytes,
|
||||
resolve_file_ref,
|
||||
)
|
||||
from backend.copilot.sdk.tool_adapter import _read_file_handler
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_session(session_id: str = "integ-sess") -> MagicMock:
|
||||
s = MagicMock()
|
||||
s.session_id = session_id
|
||||
return s
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local-file resolution (sdk_cwd)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_file_ref_local_path():
|
||||
"""resolve_file_ref reads a real local file when it's within sdk_cwd."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
# Write a test file inside sdk_cwd
|
||||
test_file = os.path.join(sdk_cwd, "hello.txt")
|
||||
with open(test_file, "w") as f:
|
||||
f.write("line1\nline2\nline3\n")
|
||||
|
||||
session = _make_session()
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
ref = FileRef(uri=test_file, start_line=None, end_line=None)
|
||||
content = await resolve_file_ref(ref, user_id="u1", session=session)
|
||||
|
||||
assert content == "line1\nline2\nline3\n"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_file_ref_local_path_with_line_range():
|
||||
"""resolve_file_ref respects line ranges for local files."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
test_file = os.path.join(sdk_cwd, "multi.txt")
|
||||
lines = [f"line{i}\n" for i in range(1, 11)] # line1 … line10
|
||||
with open(test_file, "w") as f:
|
||||
f.writelines(lines)
|
||||
|
||||
session = _make_session()
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
ref = FileRef(uri=test_file, start_line=3, end_line=5)
|
||||
content = await resolve_file_ref(ref, user_id="u1", session=session)
|
||||
|
||||
assert content == "line3\nline4\nline5\n"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_file_ref_rejects_path_outside_sdk_cwd():
|
||||
"""resolve_file_ref raises ValueError for paths outside sdk_cwd."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
|
||||
"backend.copilot.context._current_sandbox"
|
||||
) as mock_sandbox_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
mock_sandbox_var.get.return_value = None
|
||||
|
||||
ref = FileRef(uri="/etc/passwd", start_line=None, end_line=None)
|
||||
with pytest.raises(ValueError, match="not allowed"):
|
||||
await resolve_file_ref(ref, user_id="u1", session=_make_session())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_string — integration with real files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_string_with_real_file():
|
||||
"""expand_file_refs_in_string replaces @@agptfile: token with actual content."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
test_file = os.path.join(sdk_cwd, "data.txt")
|
||||
with open(test_file, "w") as f:
|
||||
f.write("hello world\n")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_string(
|
||||
f"Content: @@agptfile:{test_file}",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result == "Content: hello world\n"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_string_missing_file_is_surfaced_inline():
|
||||
"""Missing file ref yields [file-ref error: …] inline rather than raising."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
missing = os.path.join(sdk_cwd, "does_not_exist.txt")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_string(
|
||||
f"@@agptfile:{missing}",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert "[file-ref error:" in result
|
||||
assert "not found" in result.lower() or "not allowed" in result.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_args — dict traversal with real files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_replaces_file_ref_in_nested_dict():
|
||||
"""Nested @@agptfile: references in args are fully expanded."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
file_a = os.path.join(sdk_cwd, "a.txt")
|
||||
file_b = os.path.join(sdk_cwd, "b.txt")
|
||||
with open(file_a, "w") as f:
|
||||
f.write("AAA")
|
||||
with open(file_b, "w") as f:
|
||||
f.write("BBB")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{
|
||||
"outer": {
|
||||
"content_a": f"@@agptfile:{file_a}",
|
||||
"content_b": f"start @@agptfile:{file_b} end",
|
||||
},
|
||||
"count": 42,
|
||||
},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["outer"]["content_a"] == "AAA"
|
||||
assert result["outer"]["content_b"] == "start BBB end"
|
||||
assert result["count"] == 42
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_file_handler — extended to accept workspace:// and local paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_handler_local_file():
|
||||
"""_read_file_handler reads a local file when it's within sdk_cwd."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
test_file = os.path.join(sdk_cwd, "read_test.txt")
|
||||
lines = [f"L{i}\n" for i in range(1, 6)]
|
||||
with open(test_file, "w") as f:
|
||||
f.writelines(lines)
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
|
||||
"backend.copilot.context._current_project_dir"
|
||||
) as mock_proj_var, patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", _make_session()),
|
||||
):
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
mock_proj_var.get.return_value = ""
|
||||
|
||||
result = await _read_file_handler(
|
||||
{"file_path": test_file, "offset": 0, "limit": 5}
|
||||
)
|
||||
|
||||
assert not result["isError"]
|
||||
text = result["content"][0]["text"]
|
||||
assert "L1" in text
|
||||
assert "L5" in text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_handler_workspace_uri():
|
||||
"""_read_file_handler handles workspace:// URIs via the workspace manager."""
|
||||
mock_session = _make_session()
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.read_file_by_id.return_value = b"workspace file content\nline two\n"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", mock_session),
|
||||
), patch(
|
||||
"backend.copilot.sdk.file_ref.get_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
result = await _read_file_handler(
|
||||
{"file_path": "workspace://file-id-abc", "offset": 0, "limit": 10}
|
||||
)
|
||||
|
||||
assert not result["isError"], result["content"][0]["text"]
|
||||
text = result["content"][0]["text"]
|
||||
assert "workspace file content" in text
|
||||
assert "line two" in text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_handler_workspace_uri_no_session():
|
||||
"""_read_file_handler returns error when workspace:// is used without session."""
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=(None, None),
|
||||
):
|
||||
result = await _read_file_handler({"file_path": "workspace://some-id"})
|
||||
|
||||
assert result["isError"]
|
||||
assert "session" in result["content"][0]["text"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_handler_access_denied():
|
||||
"""_read_file_handler rejects paths outside allowed locations."""
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
|
||||
"backend.copilot.context._current_sandbox"
|
||||
) as mock_sandbox, patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", _make_session()),
|
||||
):
|
||||
mock_cwd.get.return_value = "/tmp/safe-dir"
|
||||
mock_sandbox.get.return_value = None
|
||||
|
||||
result = await _read_file_handler({"file_path": "/etc/passwd"})
|
||||
|
||||
assert result["isError"]
|
||||
assert "not allowed" in result["content"][0]["text"].lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# read_file_bytes — workspace:///path (virtual path) and E2B sandbox branch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_bytes_workspace_virtual_path():
|
||||
"""workspace:///path resolves via manager.read_file (is_path=True path)."""
|
||||
session = _make_session()
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.read_file.return_value = b"virtual path content"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.get_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
result = await read_file_bytes("workspace:///reports/q1.md", "user-1", session)
|
||||
|
||||
assert result == b"virtual path content"
|
||||
mock_manager.read_file.assert_awaited_once_with("/reports/q1.md")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_bytes_e2b_sandbox_branch():
|
||||
"""read_file_bytes reads from the E2B sandbox when a sandbox is active."""
|
||||
session = _make_session()
|
||||
mock_sandbox = AsyncMock()
|
||||
mock_sandbox.files.read.return_value = bytearray(b"sandbox content")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
|
||||
"backend.copilot.context._current_sandbox"
|
||||
) as mock_sandbox_var, patch(
|
||||
"backend.copilot.context._current_project_dir"
|
||||
) as mock_proj:
|
||||
mock_cwd.get.return_value = ""
|
||||
mock_sandbox_var.get.return_value = mock_sandbox
|
||||
mock_proj.get.return_value = ""
|
||||
|
||||
result = await read_file_bytes("/home/user/script.sh", None, session)
|
||||
|
||||
assert result == b"sandbox content"
|
||||
mock_sandbox.files.read.assert_awaited_once_with(
|
||||
"/home/user/script.sh", format="bytes"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_bytes_e2b_path_escapes_sandbox_raises():
|
||||
"""read_file_bytes raises ValueError for paths that escape the sandbox root."""
|
||||
session = _make_session()
|
||||
mock_sandbox = AsyncMock()
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
|
||||
"backend.copilot.context._current_sandbox"
|
||||
) as mock_sandbox_var, patch(
|
||||
"backend.copilot.context._current_project_dir"
|
||||
) as mock_proj:
|
||||
mock_cwd.get.return_value = ""
|
||||
mock_sandbox_var.get.return_value = mock_sandbox
|
||||
mock_proj.get.return_value = ""
|
||||
|
||||
with pytest.raises(ValueError, match="not allowed"):
|
||||
await read_file_bytes("/etc/passwd", None, session)
|
||||
@@ -1,382 +0,0 @@
|
||||
"""Tests for the @@agptfile: reference protocol (file_ref.py)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.sdk.file_ref import (
|
||||
_MAX_EXPAND_CHARS,
|
||||
FileRef,
|
||||
FileRefExpansionError,
|
||||
_apply_line_range,
|
||||
expand_file_refs_in_args,
|
||||
expand_file_refs_in_string,
|
||||
parse_file_ref,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_ref
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_file_ref_workspace_id():
|
||||
ref = parse_file_ref("@@agptfile:workspace://abc123")
|
||||
assert ref == FileRef(uri="workspace://abc123", start_line=None, end_line=None)
|
||||
|
||||
|
||||
def test_parse_file_ref_workspace_id_with_mime():
|
||||
ref = parse_file_ref("@@agptfile:workspace://abc123#text/plain")
|
||||
assert ref is not None
|
||||
assert ref.uri == "workspace://abc123#text/plain"
|
||||
assert ref.start_line is None
|
||||
|
||||
|
||||
def test_parse_file_ref_workspace_path():
|
||||
ref = parse_file_ref("@@agptfile:workspace:///reports/q1.md")
|
||||
assert ref is not None
|
||||
assert ref.uri == "workspace:///reports/q1.md"
|
||||
|
||||
|
||||
def test_parse_file_ref_with_line_range():
|
||||
ref = parse_file_ref("@@agptfile:workspace://abc123[10-50]")
|
||||
assert ref == FileRef(uri="workspace://abc123", start_line=10, end_line=50)
|
||||
|
||||
|
||||
def test_parse_file_ref_local_path():
|
||||
ref = parse_file_ref("@@agptfile:/tmp/copilot-session/output.py[1-100]")
|
||||
assert ref is not None
|
||||
assert ref.uri == "/tmp/copilot-session/output.py"
|
||||
assert ref.start_line == 1
|
||||
assert ref.end_line == 100
|
||||
|
||||
|
||||
def test_parse_file_ref_no_match():
|
||||
assert parse_file_ref("just a normal string") is None
|
||||
assert parse_file_ref("workspace://abc123") is None # missing @@agptfile: prefix
|
||||
assert (
|
||||
parse_file_ref("@@agptfile:workspace://abc123 extra") is None
|
||||
) # not full match
|
||||
|
||||
|
||||
def test_parse_file_ref_strips_whitespace():
|
||||
ref = parse_file_ref(" @@agptfile:workspace://abc123 ")
|
||||
assert ref is not None
|
||||
assert ref.uri == "workspace://abc123"
|
||||
|
||||
|
||||
def test_parse_file_ref_invalid_range_zero_start():
|
||||
assert parse_file_ref("@@agptfile:workspace://abc123[0-5]") is None
|
||||
|
||||
|
||||
def test_parse_file_ref_invalid_range_end_less_than_start():
|
||||
assert parse_file_ref("@@agptfile:workspace://abc123[10-5]") is None
|
||||
|
||||
|
||||
def test_parse_file_ref_invalid_range_zero_end():
|
||||
assert parse_file_ref("@@agptfile:workspace://abc123[1-0]") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _apply_line_range
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
TEXT = "line1\nline2\nline3\nline4\nline5\n"
|
||||
|
||||
|
||||
def test_apply_line_range_no_range():
|
||||
assert _apply_line_range(TEXT, None, None) == TEXT
|
||||
|
||||
|
||||
def test_apply_line_range_start_only():
|
||||
result = _apply_line_range(TEXT, 3, None)
|
||||
assert result == "line3\nline4\nline5\n"
|
||||
|
||||
|
||||
def test_apply_line_range_full():
|
||||
result = _apply_line_range(TEXT, 2, 4)
|
||||
assert result == "line2\nline3\nline4\n"
|
||||
|
||||
|
||||
def test_apply_line_range_single_line():
|
||||
result = _apply_line_range(TEXT, 2, 2)
|
||||
assert result == "line2\n"
|
||||
|
||||
|
||||
def test_apply_line_range_beyond_eof():
|
||||
result = _apply_line_range(TEXT, 4, 999)
|
||||
assert result == "line4\nline5\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_string
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_session(session_id: str = "sess-1") -> MagicMock:
|
||||
session = MagicMock()
|
||||
session.session_id = session_id
|
||||
return session
|
||||
|
||||
|
||||
async def _resolve_always(ref: FileRef, _user_id: str | None, _session: object) -> str:
|
||||
"""Stub resolver that returns the URI and range as a descriptive string."""
|
||||
if ref.start_line is not None:
|
||||
return f"content:{ref.uri}[{ref.start_line}-{ref.end_line}]"
|
||||
return f"content:{ref.uri}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_no_refs():
|
||||
result = await expand_file_refs_in_string(
|
||||
"no references here", user_id="u1", session=_make_session()
|
||||
)
|
||||
assert result == "no references here"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_single_ref():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_string(
|
||||
"@@agptfile:workspace://abc123",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result == "content:workspace://abc123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_ref_with_range():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_string(
|
||||
"@@agptfile:workspace://abc123[10-50]",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result == "content:workspace://abc123[10-50]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_ref_embedded_in_text():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_string(
|
||||
"Here is the file: @@agptfile:workspace://abc123 — done",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result == "Here is the file: content:workspace://abc123 — done"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_multiple_refs():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_string(
|
||||
"@@agptfile:workspace://file1 and @@agptfile:workspace://file2[1-5]",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result == "content:workspace://file1 and content:workspace://file2[1-5]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_invalid_range_zero_start_surfaces_inline():
|
||||
"""expand_file_refs_in_string surfaces [file-ref error: ...] for zero-start ranges."""
|
||||
result = await expand_file_refs_in_string(
|
||||
"@@agptfile:workspace://abc123[0-5]",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert "[file-ref error:" in result
|
||||
assert "line numbers must be >= 1" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_invalid_range_end_less_than_start_surfaces_inline():
|
||||
"""expand_file_refs_in_string surfaces [file-ref error: ...] when end < start."""
|
||||
result = await expand_file_refs_in_string(
|
||||
"prefix @@agptfile:workspace://abc123[10-5] suffix",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert "[file-ref error:" in result
|
||||
assert "end line must be >= start line" in result
|
||||
assert "prefix" in result
|
||||
assert "suffix" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_ref_error_surfaces_inline():
|
||||
async def _raise(*args, **kwargs): # noqa: ARG001
|
||||
raise ValueError("file not found")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_raise),
|
||||
):
|
||||
result = await expand_file_refs_in_string(
|
||||
"@@agptfile:workspace://bad",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert "[file-ref error:" in result
|
||||
assert "file not found" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_args
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_flat():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_args(
|
||||
{"content": "@@agptfile:workspace://abc123", "other": 42},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result["content"] == "content:workspace://abc123"
|
||||
assert result["other"] == 42
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_nested_dict():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_args(
|
||||
{"outer": {"inner": "@@agptfile:workspace://nested"}},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result["outer"]["inner"] == "content:workspace://nested"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_list():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_args(
|
||||
{
|
||||
"items": [
|
||||
"@@agptfile:workspace://a",
|
||||
"plain",
|
||||
"@@agptfile:workspace://b[1-3]",
|
||||
]
|
||||
},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result["items"] == [
|
||||
"content:workspace://a",
|
||||
"plain",
|
||||
"content:workspace://b[1-3]",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_empty():
|
||||
result = await expand_file_refs_in_args({}, user_id="u1", session=_make_session())
|
||||
assert result == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_no_refs():
|
||||
result = await expand_file_refs_in_args(
|
||||
{"key": "no refs here", "num": 1},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result == {"key": "no refs here", "num": 1}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_raises_on_file_ref_error():
|
||||
"""expand_file_refs_in_args raises FileRefExpansionError instead of passing
|
||||
the inline error string to the tool, blocking tool execution."""
|
||||
|
||||
async def _raise(*args, **kwargs): # noqa: ARG001
|
||||
raise ValueError("path does not exist")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_raise),
|
||||
):
|
||||
with pytest.raises(FileRefExpansionError) as exc_info:
|
||||
await expand_file_refs_in_args(
|
||||
{"prompt": "@@agptfile:/home/user/missing.txt"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert "path does not exist" in str(exc_info.value)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-file truncation and aggregate budget
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_per_file_truncation():
|
||||
"""Content exceeding _MAX_EXPAND_CHARS is truncated with a marker."""
|
||||
oversized = "x" * (_MAX_EXPAND_CHARS + 100)
|
||||
|
||||
async def _resolve_oversized(ref: FileRef, _uid: str | None, _s: object) -> str:
|
||||
return oversized
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_oversized),
|
||||
):
|
||||
result = await expand_file_refs_in_string(
|
||||
"@@agptfile:workspace://big-file",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert len(result) <= _MAX_EXPAND_CHARS + len("\n... [truncated]") + 10
|
||||
assert "[truncated]" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_aggregate_budget_exhausted():
|
||||
"""When the aggregate budget is exhausted, later refs get the budget message."""
|
||||
# Each file returns just under 300K; after ~4 files the 1M budget is used.
|
||||
big_chunk = "y" * 300_000
|
||||
|
||||
async def _resolve_big(ref: FileRef, _uid: str | None, _s: object) -> str:
|
||||
return big_chunk
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_big),
|
||||
):
|
||||
# 5 refs @ 300K each = 1.5M → last ref(s) should hit the aggregate limit
|
||||
refs = " ".join(f"@@agptfile:workspace://f{i}" for i in range(5))
|
||||
result = await expand_file_refs_in_string(
|
||||
refs,
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert "budget exhausted" in result
|
||||
@@ -1,42 +0,0 @@
|
||||
## MCP Tool Guide
|
||||
|
||||
### Workflow
|
||||
|
||||
`run_mcp_tool` follows a two-step pattern:
|
||||
|
||||
1. **Discover** — call with only `server_url` to list available tools on the server.
|
||||
2. **Execute** — call again with `server_url`, `tool_name`, and `tool_arguments` to run a tool.
|
||||
|
||||
### Known hosted MCP servers
|
||||
|
||||
Use these URLs directly without asking the user:
|
||||
|
||||
| Service | URL |
|
||||
|---|---|
|
||||
| Notion | `https://mcp.notion.com/mcp` |
|
||||
| Linear | `https://mcp.linear.app/mcp` |
|
||||
| Stripe | `https://mcp.stripe.com` |
|
||||
| Intercom | `https://mcp.intercom.com/mcp` |
|
||||
| Cloudflare | `https://mcp.cloudflare.com/mcp` |
|
||||
| Atlassian / Jira | `https://mcp.atlassian.com/mcp` |
|
||||
|
||||
For other services, search the MCP registry at https://registry.modelcontextprotocol.io/.
|
||||
|
||||
### Authentication
|
||||
|
||||
If the server requires credentials, a `SetupRequirementsResponse` is returned with an OAuth
|
||||
login prompt. Once the user completes the flow and confirms, retry the same call immediately.
|
||||
|
||||
### Communication style
|
||||
|
||||
Avoid technical jargon like "MCP server", "OAuth", or "credentials" when talking to the user.
|
||||
Use plain, friendly language instead:
|
||||
|
||||
| Instead of… | Say… |
|
||||
|---|---|
|
||||
| "Let me connect to Sentry's MCP server and discover what tools are available." | "I can connect to Sentry and help identify important issues." |
|
||||
| "Let me connect to Sentry's MCP server now." | "Next, I'll connect to Sentry." |
|
||||
| "The MCP server at mcp.sentry.dev requires authentication. Please connect your credentials to continue." | "To continue, sign in to Sentry and approve access." |
|
||||
| "Sentry's MCP server needs OAuth authentication. You should see a prompt to connect your Sentry account…" | "You should see a prompt to sign in to Sentry. Once connected, I can help surface critical issues right away." |
|
||||
|
||||
Use **"connect to [Service]"** or **"sign in to [Service]"** — never "MCP server", "OAuth", or "credentials".
|
||||
@@ -536,12 +536,10 @@ async def test_wait_for_stash_signaled():
|
||||
result = await wait_for_stash(timeout=1.0)
|
||||
|
||||
assert result is True
|
||||
pto = _pto.get()
|
||||
assert pto is not None
|
||||
assert pto.get("WebSearch") == ["result data"]
|
||||
assert _pto.get({}).get("WebSearch") == ["result data"]
|
||||
|
||||
# Cleanup
|
||||
_pto.set({})
|
||||
_pto.set({}) # type: ignore[arg-type]
|
||||
_stash_event.set(None)
|
||||
|
||||
|
||||
@@ -556,7 +554,7 @@ async def test_wait_for_stash_timeout():
|
||||
assert result is False
|
||||
|
||||
# Cleanup
|
||||
_pto.set({})
|
||||
_pto.set({}) # type: ignore[arg-type]
|
||||
_stash_event.set(None)
|
||||
|
||||
|
||||
@@ -575,12 +573,10 @@ async def test_wait_for_stash_already_stashed():
|
||||
assert result is True
|
||||
|
||||
# But the stash itself is populated
|
||||
pto = _pto.get()
|
||||
assert pto is not None
|
||||
assert pto.get("Read") == ["file contents"]
|
||||
assert _pto.get({}).get("Read") == ["file contents"]
|
||||
|
||||
# Cleanup
|
||||
_pto.set({})
|
||||
_pto.set({}) # type: ignore[arg-type]
|
||||
_stash_event.set(None)
|
||||
|
||||
|
||||
|
||||
@@ -10,13 +10,12 @@ import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
from backend.copilot.context import is_allowed_local_path
|
||||
|
||||
from .tool_adapter import (
|
||||
BLOCKED_TOOLS,
|
||||
DANGEROUS_PATTERNS,
|
||||
MCP_TOOL_PREFIX,
|
||||
WORKSPACE_SCOPED_TOOLS,
|
||||
is_allowed_local_path,
|
||||
stash_pending_tool_output,
|
||||
)
|
||||
|
||||
|
||||
@@ -9,9 +9,8 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import _current_project_dir
|
||||
|
||||
from .security_hooks import _validate_tool_access, _validate_user_isolation
|
||||
from .service import _is_tool_error_or_denial
|
||||
|
||||
SDK_CWD = "/tmp/copilot-abc123"
|
||||
|
||||
@@ -121,6 +120,8 @@ def test_read_no_cwd_denies_absolute():
|
||||
|
||||
|
||||
def test_read_tool_results_allowed():
|
||||
from .tool_adapter import _current_project_dir
|
||||
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
|
||||
# is_allowed_local_path requires the session's encoded cwd to be set
|
||||
@@ -132,14 +133,16 @@ def test_read_tool_results_allowed():
|
||||
_current_project_dir.reset(token)
|
||||
|
||||
|
||||
def test_read_claude_projects_settings_json_denied():
|
||||
"""SDK-internal artifacts like settings.json are NOT accessible — only tool-results/ is."""
|
||||
def test_read_claude_projects_session_dir_allowed():
|
||||
"""Files within the current session's project dir are allowed."""
|
||||
from .tool_adapter import _current_project_dir
|
||||
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
|
||||
token = _current_project_dir.set("-tmp-copilot-abc123")
|
||||
try:
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert _is_denied(result)
|
||||
assert not _is_denied(result)
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
|
||||
@@ -354,3 +357,76 @@ async def test_task_slot_released_on_failure(_hooks):
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
|
||||
# -- _is_tool_error_or_denial ------------------------------------------------
|
||||
|
||||
|
||||
class TestIsToolErrorOrDenial:
|
||||
def test_none_content(self):
|
||||
assert _is_tool_error_or_denial(None) is False
|
||||
|
||||
def test_empty_content(self):
|
||||
assert _is_tool_error_or_denial("") is False
|
||||
|
||||
def test_benign_output(self):
|
||||
assert _is_tool_error_or_denial("All good, no issues.") is False
|
||||
|
||||
def test_security_marker(self):
|
||||
assert _is_tool_error_or_denial("[SECURITY] Tool access blocked") is True
|
||||
|
||||
def test_cannot_be_bypassed(self):
|
||||
assert _is_tool_error_or_denial("This restriction cannot be bypassed.") is True
|
||||
|
||||
def test_not_allowed(self):
|
||||
assert _is_tool_error_or_denial("Operation not allowed in sandbox") is True
|
||||
|
||||
def test_background_task_denial(self):
|
||||
assert (
|
||||
_is_tool_error_or_denial(
|
||||
"Background task execution is not supported. "
|
||||
"Run tasks in the foreground instead."
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_subtask_limit_denial(self):
|
||||
assert (
|
||||
_is_tool_error_or_denial(
|
||||
"Maximum 2 concurrent sub-tasks. "
|
||||
"Wait for running sub-tasks to finish, "
|
||||
"or continue in the main conversation."
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_denied_marker(self):
|
||||
assert (
|
||||
_is_tool_error_or_denial("Access denied: insufficient privileges") is True
|
||||
)
|
||||
|
||||
def test_blocked_marker(self):
|
||||
assert _is_tool_error_or_denial("Request blocked by security policy") is True
|
||||
|
||||
def test_failed_marker(self):
|
||||
assert _is_tool_error_or_denial("Failed to execute tool: timeout") is True
|
||||
|
||||
def test_mcp_iserror(self):
|
||||
assert _is_tool_error_or_denial('{"isError": true, "content": []}') is True
|
||||
|
||||
def test_benign_error_in_value(self):
|
||||
"""Content like '0 errors found' should not trigger — 'error' was removed."""
|
||||
assert _is_tool_error_or_denial("0 errors found") is False
|
||||
|
||||
def test_benign_permission_field(self):
|
||||
"""Schema descriptions mentioning 'permission' should not trigger."""
|
||||
assert (
|
||||
_is_tool_error_or_denial(
|
||||
'{"fields": [{"name": "permission_level", "type": "int"}]}'
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
def test_benign_not_found_in_listing(self):
|
||||
"""File listing containing 'not found' in filenames should not trigger."""
|
||||
assert _is_tool_error_or_denial("readme.md\nfile-not-found-handler.py") is False
|
||||
|
||||
@@ -60,7 +60,7 @@ from ..service import (
|
||||
_generate_session_title,
|
||||
_is_langfuse_configured,
|
||||
)
|
||||
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
|
||||
from ..tools.e2b_sandbox import get_or_create_sandbox
|
||||
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
||||
from ..tools.workspace_files import get_manager
|
||||
from ..tracking import track_user_message
|
||||
@@ -456,6 +456,31 @@ def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
|
||||
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
|
||||
|
||||
|
||||
def _is_tool_error_or_denial(content: str | None) -> bool:
|
||||
"""Check if a tool message content indicates an error or denial.
|
||||
|
||||
Currently unused — ``_format_conversation_context`` includes all tool
|
||||
results. Kept as a utility for future selective filtering.
|
||||
"""
|
||||
if not content:
|
||||
return False
|
||||
lower = content.lower()
|
||||
return any(
|
||||
marker in lower
|
||||
for marker in (
|
||||
"[security]",
|
||||
"cannot be bypassed",
|
||||
"not allowed",
|
||||
"not supported", # background-task denial
|
||||
"maximum", # subtask-limit denial
|
||||
"denied",
|
||||
"blocked",
|
||||
"failed to", # internal tool execution failures
|
||||
'"iserror": true', # MCP protocol error flag
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _build_query_message(
|
||||
current_message: str,
|
||||
session: ChatSession,
|
||||
@@ -759,29 +784,28 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
async def _setup_e2b():
|
||||
"""Set up E2B sandbox if configured, return sandbox or None."""
|
||||
if not (e2b_api_key := config.active_e2b_api_key):
|
||||
if config.use_e2b_sandbox:
|
||||
logger.warning(
|
||||
"[E2B] [%s] E2B sandbox enabled but no API key configured "
|
||||
"(CHAT_E2B_API_KEY / E2B_API_KEY) — falling back to bubblewrap",
|
||||
session_id[:12],
|
||||
)
|
||||
return None
|
||||
try:
|
||||
return await get_or_create_sandbox(
|
||||
session_id,
|
||||
api_key=e2b_api_key,
|
||||
template=config.e2b_sandbox_template,
|
||||
timeout=config.e2b_sandbox_timeout,
|
||||
on_timeout=config.e2b_sandbox_on_timeout,
|
||||
)
|
||||
except Exception as e2b_err:
|
||||
logger.error(
|
||||
"[E2B] [%s] Setup failed: %s",
|
||||
if config.use_e2b_sandbox and not config.e2b_api_key:
|
||||
logger.warning(
|
||||
"[E2B] [%s] E2B sandbox enabled but no API key configured "
|
||||
"(CHAT_E2B_API_KEY / E2B_API_KEY) — falling back to bubblewrap",
|
||||
session_id[:12],
|
||||
e2b_err,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
if config.use_e2b_sandbox and config.e2b_api_key:
|
||||
try:
|
||||
return await get_or_create_sandbox(
|
||||
session_id,
|
||||
api_key=config.e2b_api_key,
|
||||
template=config.e2b_sandbox_template,
|
||||
timeout=config.e2b_sandbox_timeout,
|
||||
)
|
||||
except Exception as e2b_err:
|
||||
logger.error(
|
||||
"[E2B] [%s] Setup failed: %s",
|
||||
session_id[:12],
|
||||
e2b_err,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
async def _fetch_transcript():
|
||||
@@ -813,6 +837,7 @@ async def stream_chat_completion_sdk(
|
||||
system_prompt = base_system_prompt + get_sdk_supplement(
|
||||
use_e2b=use_e2b, cwd=sdk_cwd
|
||||
)
|
||||
|
||||
# Process transcript download result
|
||||
transcript_msg_count = 0
|
||||
if dl:
|
||||
@@ -877,11 +902,6 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
allowed = get_copilot_tool_names(use_e2b=use_e2b)
|
||||
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
|
||||
|
||||
def _on_stderr(line: str) -> None:
|
||||
sid = session_id[:12] if session_id else "?"
|
||||
logger.info("[SDK] [%s] CLI stderr: %s", sid, line.rstrip())
|
||||
|
||||
sdk_options_kwargs: dict[str, Any] = {
|
||||
"system_prompt": system_prompt,
|
||||
"mcp_servers": {"copilot": mcp_server},
|
||||
@@ -890,7 +910,6 @@ async def stream_chat_completion_sdk(
|
||||
"hooks": security_hooks,
|
||||
"cwd": sdk_cwd,
|
||||
"max_buffer_size": config.claude_agent_max_buffer_size,
|
||||
"stderr": _on_stderr,
|
||||
}
|
||||
if sdk_model:
|
||||
sdk_options_kwargs["model"] = sdk_model
|
||||
@@ -1063,19 +1082,6 @@ async def stream_chat_completion_sdk(
|
||||
len(adapter.resolved_tool_calls),
|
||||
)
|
||||
|
||||
# Log AssistantMessage API errors (e.g. invalid_request)
|
||||
# so we can debug Anthropic API 400s surfaced by the CLI.
|
||||
sdk_error = getattr(sdk_msg, "error", None)
|
||||
if isinstance(sdk_msg, AssistantMessage) and sdk_error:
|
||||
logger.error(
|
||||
"[SDK] [%s] AssistantMessage has error=%s, "
|
||||
"content_blocks=%d, content_preview=%s",
|
||||
session_id[:12],
|
||||
sdk_error,
|
||||
len(sdk_msg.content),
|
||||
str(sdk_msg.content)[:500],
|
||||
)
|
||||
|
||||
# Race-condition fix: SDK hooks (PostToolUse) are
|
||||
# executed asynchronously via start_soon() — the next
|
||||
# message can arrive before the hook stashes output.
|
||||
@@ -1410,17 +1416,6 @@ async def stream_chat_completion_sdk(
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# --- Pause E2B sandbox to stop billing between turns ---
|
||||
# Fire-and-forget: pausing is best-effort and must not block the
|
||||
# response or the transcript upload. The task is anchored to
|
||||
# _background_tasks to prevent garbage collection.
|
||||
# Use pause_sandbox_direct to skip the Redis lookup and reconnect
|
||||
# round-trip — e2b_sandbox is the live object from this turn.
|
||||
if e2b_sandbox is not None:
|
||||
task = asyncio.create_task(pause_sandbox_direct(e2b_sandbox, session_id))
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
# --- Upload transcript for next-turn --resume ---
|
||||
# This MUST run in finally so the transcript is uploaded even when
|
||||
# the streaming loop raises an exception.
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
"""Tests for SDK service helpers."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -213,7 +212,7 @@ class TestPromptSupplement:
|
||||
|
||||
# Workflows are now in individual tool descriptions (not separate sections)
|
||||
# Check that key workflow concepts appear in tool descriptions
|
||||
assert "agent_json" in docs or "find_block" in docs
|
||||
assert "suggested_goal" in docs or "clarifying_questions" in docs
|
||||
assert "run_mcp_tool" in docs
|
||||
|
||||
def test_baseline_supplement_completeness(self):
|
||||
@@ -232,48 +231,6 @@ class TestPromptSupplement:
|
||||
f"`{tool_name}`" in docs
|
||||
), f"Tool '{tool_name}' missing from baseline supplement"
|
||||
|
||||
def test_pause_task_scheduled_before_transcript_upload(self):
|
||||
"""Pause is scheduled as a background task before transcript upload begins.
|
||||
|
||||
The finally block in stream_response_sdk does:
|
||||
(1) asyncio.create_task(pause_sandbox_direct(...)) — fire-and-forget
|
||||
(2) await asyncio.shield(upload_transcript(...)) — awaited
|
||||
|
||||
Scheduling pause via create_task before awaiting upload ensures:
|
||||
- Pause never blocks transcript upload (billing stops concurrently)
|
||||
- On E2B timeout, pause silently fails; upload proceeds unaffected
|
||||
"""
|
||||
call_order: list[str] = []
|
||||
|
||||
async def _mock_pause(sandbox, session_id):
|
||||
call_order.append("pause")
|
||||
|
||||
async def _mock_upload(**kwargs):
|
||||
call_order.append("upload")
|
||||
|
||||
async def _simulate_teardown():
|
||||
"""Mirror the service.py finally block teardown sequence."""
|
||||
sandbox = MagicMock()
|
||||
|
||||
# (1) Schedule pause — mirrors lines ~1427-1429 in service.py
|
||||
task = asyncio.create_task(_mock_pause(sandbox, "test-sess"))
|
||||
|
||||
# (2) Await transcript upload — mirrors lines ~1460-1468 in service.py
|
||||
# Yielding to the event loop here lets the pause task start concurrently.
|
||||
await _mock_upload(
|
||||
user_id="u", session_id="test-sess", content="x", message_count=1
|
||||
)
|
||||
await task
|
||||
|
||||
asyncio.run(_simulate_teardown())
|
||||
|
||||
# Both must run; pause is scheduled before upload starts
|
||||
assert "pause" in call_order
|
||||
assert "upload" in call_order
|
||||
# create_task schedules pause, then upload is awaited — pause runs
|
||||
# concurrently during upload's first yield. The ordering guarantee is
|
||||
# that create_task is CALLED before upload is AWAITED (see source order).
|
||||
|
||||
def test_baseline_supplement_no_duplicate_tools(self):
|
||||
"""No tool should appear multiple times in baseline supplement."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
|
||||
@@ -9,29 +9,14 @@ import itertools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
from backend.copilot.context import (
|
||||
_current_project_dir,
|
||||
_current_sandbox,
|
||||
_current_sdk_cwd,
|
||||
_current_session,
|
||||
_current_user_id,
|
||||
_encode_cwd_for_cli,
|
||||
get_execution_context,
|
||||
get_sdk_cwd,
|
||||
is_allowed_local_path,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.sdk.file_ref import (
|
||||
FileRefExpansionError,
|
||||
expand_file_refs_in_args,
|
||||
read_file_bytes,
|
||||
)
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
from backend.copilot.tools.base import BaseTool
|
||||
from backend.util.truncate import truncate
|
||||
@@ -43,13 +28,84 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
|
||||
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
|
||||
# in the path — prevents reading settings, credentials, or other sensitive files.
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
# Max MCP response size in chars — keeps tool output under the SDK's 10 MB JSON buffer.
|
||||
_MCP_MAX_CHARS = 500_000
|
||||
|
||||
# Context variable holding the encoded project directory name for the current
|
||||
# session (e.g. "-private-tmp-copilot-<uuid>"). Set by set_execution_context()
|
||||
# so that path validation can scope tool-results reads to the current session.
|
||||
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
|
||||
|
||||
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does.
|
||||
|
||||
The CLI replaces all non-alphanumeric characters with ``-``.
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
"""Check whether *path* is an allowed host-filesystem path.
|
||||
|
||||
Allowed:
|
||||
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/`` — the SDK's
|
||||
project directory for this session (tool-results, transcripts, etc.)
|
||||
|
||||
Both checks are scoped to the **current session** so sessions cannot
|
||||
read each other's data.
|
||||
"""
|
||||
if not path:
|
||||
return False
|
||||
|
||||
if path.startswith("~"):
|
||||
resolved = os.path.realpath(os.path.expanduser(path))
|
||||
elif not os.path.isabs(path) and sdk_cwd:
|
||||
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
|
||||
else:
|
||||
resolved = os.path.realpath(path)
|
||||
|
||||
# Allow access within the SDK working directory
|
||||
if sdk_cwd:
|
||||
norm_cwd = os.path.realpath(sdk_cwd)
|
||||
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
|
||||
return True
|
||||
|
||||
# Allow access within the current session's CLI project directory
|
||||
# (~/.claude/projects/<encoded-cwd>/).
|
||||
encoded = _current_project_dir.get("")
|
||||
if encoded:
|
||||
session_project = os.path.join(_SDK_PROJECTS_DIR, encoded)
|
||||
if resolved == session_project or resolved.startswith(session_project + os.sep):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
|
||||
MCP_SERVER_NAME = "copilot"
|
||||
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
|
||||
|
||||
# Context variables to pass user/session info to tool execution
|
||||
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
|
||||
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||
"current_session", default=None
|
||||
)
|
||||
# E2B cloud sandbox for the current turn (None when E2B is not configured).
|
||||
# Passed to bash_exec so commands run on E2B instead of the local bwrap sandbox.
|
||||
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
|
||||
"_current_sandbox", default=None
|
||||
)
|
||||
# Raw SDK working directory path (e.g. /tmp/copilot-<session_id>).
|
||||
# Used by workspace tools to save binary files for the CLI's built-in Read.
|
||||
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
|
||||
|
||||
# Stash for MCP tool outputs before the SDK potentially truncates them.
|
||||
# Keyed by tool_name → full output string. Consumed (popped) by the
|
||||
# response adapter when it builds StreamToolOutputAvailable.
|
||||
@@ -93,6 +149,24 @@ def set_execution_context(
|
||||
_stash_event.set(asyncio.Event())
|
||||
|
||||
|
||||
def get_current_sandbox() -> "AsyncSandbox | None":
|
||||
"""Return the E2B sandbox for the current turn, or None."""
|
||||
return _current_sandbox.get()
|
||||
|
||||
|
||||
def get_sdk_cwd() -> str:
|
||||
"""Return the SDK ephemeral working directory for the current turn."""
|
||||
return _current_sdk_cwd.get()
|
||||
|
||||
|
||||
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||
"""Get the current execution context."""
|
||||
return (
|
||||
_current_user_id.get(),
|
||||
_current_session.get(),
|
||||
)
|
||||
|
||||
|
||||
def pop_pending_tool_output(tool_name: str) -> str | None:
|
||||
"""Pop and return the oldest stashed output for *tool_name*.
|
||||
|
||||
@@ -185,11 +259,7 @@ async def _execute_tool_sync(
|
||||
session: ChatSession,
|
||||
args: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Execute a tool synchronously and return MCP-formatted response.
|
||||
|
||||
Note: ``@@agptfile:`` expansion is handled upstream in the ``_truncating`` wrapper
|
||||
so all registered handlers (BaseTool, E2B, Read) expand uniformly.
|
||||
"""
|
||||
"""Execute a tool synchronously and return MCP-formatted response."""
|
||||
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
|
||||
result = await base_tool.execute(
|
||||
user_id=user_id,
|
||||
@@ -250,50 +320,42 @@ def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
|
||||
|
||||
|
||||
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Read a file with optional offset/limit.
|
||||
"""Read a local file with optional offset/limit.
|
||||
|
||||
Supports ``workspace://`` URIs (delegated to the workspace manager) and
|
||||
local paths within the session's allowed directories (sdk_cwd + tool-results).
|
||||
Only allows paths that pass :func:`is_allowed_local_path` — the current
|
||||
session's tool-results directory and ephemeral working directory.
|
||||
"""
|
||||
file_path = args.get("file_path", "")
|
||||
offset = max(0, int(args.get("offset", 0)))
|
||||
limit = max(1, int(args.get("limit", 2000)))
|
||||
offset = args.get("offset", 0)
|
||||
limit = args.get("limit", 2000)
|
||||
|
||||
def _mcp_err(text: str) -> dict[str, Any]:
|
||||
return {"content": [{"type": "text", "text": text}], "isError": True}
|
||||
|
||||
def _mcp_ok(text: str) -> dict[str, Any]:
|
||||
return {"content": [{"type": "text", "text": text}], "isError": False}
|
||||
|
||||
if file_path.startswith("workspace://"):
|
||||
user_id, session = get_execution_context()
|
||||
if session is None:
|
||||
return _mcp_err("workspace:// file references require an active session")
|
||||
try:
|
||||
raw = await read_file_bytes(file_path, user_id, session)
|
||||
except ValueError as exc:
|
||||
return _mcp_err(str(exc))
|
||||
lines = raw.decode("utf-8", errors="replace").splitlines(keepends=True)
|
||||
selected = list(itertools.islice(lines, offset, offset + limit))
|
||||
numbered = "".join(
|
||||
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
|
||||
)
|
||||
return _mcp_ok(numbered)
|
||||
|
||||
if not is_allowed_local_path(file_path, get_sdk_cwd()):
|
||||
return _mcp_err(f"Path not allowed: {file_path}")
|
||||
if not is_allowed_local_path(file_path):
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
resolved = os.path.realpath(os.path.expanduser(file_path))
|
||||
try:
|
||||
with open(resolved) as f:
|
||||
selected = list(itertools.islice(f, offset, offset + limit))
|
||||
content = "".join(selected)
|
||||
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
|
||||
# don't delete here — the SDK may read in multiple chunks.
|
||||
return _mcp_ok("".join(selected))
|
||||
return {
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"isError": False,
|
||||
}
|
||||
except FileNotFoundError:
|
||||
return _mcp_err(f"File not found: {file_path}")
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
|
||||
"isError": True,
|
||||
}
|
||||
except Exception as e:
|
||||
return _mcp_err(f"Error reading file: {e}")
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"Error reading file: {e}"}],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
|
||||
_READ_TOOL_NAME = "Read"
|
||||
@@ -352,23 +414,9 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
|
||||
response adapter before the SDK can apply its own head-truncation.
|
||||
|
||||
Also expands ``@@agptfile:`` references in args so every registered tool
|
||||
(BaseTool, E2B file tools, Read) receives resolved content uniformly.
|
||||
|
||||
Applied once to every registered tool."""
|
||||
|
||||
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
|
||||
user_id, session = get_execution_context()
|
||||
if session is not None:
|
||||
try:
|
||||
args = await expand_file_refs_in_args(args, user_id, session)
|
||||
except FileRefExpansionError as exc:
|
||||
return _mcp_error(
|
||||
f"@@agptfile: reference could not be resolved: {exc}. "
|
||||
"Ensure the file exists before referencing it. "
|
||||
"For sandbox paths use bash_exec to verify the file exists first; "
|
||||
"for workspace files use a workspace:// URI."
|
||||
)
|
||||
result = await fn(args)
|
||||
truncated = truncate(result, _MCP_MAX_CHARS)
|
||||
|
||||
|
||||
@@ -2,12 +2,12 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import get_sdk_cwd
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
from .tool_adapter import (
|
||||
_MCP_MAX_CHARS,
|
||||
_text_from_mcp_result,
|
||||
get_sdk_cwd,
|
||||
pop_pending_tool_output,
|
||||
set_execution_context,
|
||||
stash_pending_tool_output,
|
||||
|
||||
@@ -23,11 +23,6 @@ from typing import Any, Literal
|
||||
|
||||
import orjson
|
||||
|
||||
from backend.api.model import CopilotCompletionPayload
|
||||
from backend.data.notification_bus import (
|
||||
AsyncRedisNotificationEventBus,
|
||||
NotificationEvent,
|
||||
)
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
from .config import ChatConfig
|
||||
@@ -43,7 +38,6 @@ from .response_model import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
_notification_bus = AsyncRedisNotificationEventBus()
|
||||
|
||||
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
|
||||
_local_sessions: dict[str, asyncio.Task] = {}
|
||||
@@ -751,29 +745,6 @@ async def mark_session_completed(
|
||||
|
||||
# Clean up local session reference if exists
|
||||
_local_sessions.pop(session_id, None)
|
||||
|
||||
# Publish copilot completion notification via WebSocket
|
||||
if meta:
|
||||
parsed = _parse_session_meta(meta, session_id)
|
||||
if parsed.user_id:
|
||||
try:
|
||||
await _notification_bus.publish(
|
||||
NotificationEvent(
|
||||
user_id=parsed.user_id,
|
||||
payload=CopilotCompletionPayload(
|
||||
type="copilot_completion",
|
||||
event="session_completed",
|
||||
session_id=session_id,
|
||||
status=status,
|
||||
),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to publish copilot completion notification "
|
||||
f"for session {session_id}: {e}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
from .continue_run_block import ContinueRunBlockTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
from .edit_agent import EditAgentTool
|
||||
@@ -20,10 +19,7 @@ from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsToo
|
||||
from .find_agent import FindAgentTool
|
||||
from .find_block import FindBlockTool
|
||||
from .find_library_agent import FindLibraryAgentTool
|
||||
from .fix_agent import FixAgentGraphTool
|
||||
from .get_agent_building_guide import GetAgentBuildingGuideTool
|
||||
from .get_doc_page import GetDocPageTool
|
||||
from .get_mcp_guide import GetMCPGuideTool
|
||||
from .manage_folders import (
|
||||
CreateFolderTool,
|
||||
DeleteFolderTool,
|
||||
@@ -36,7 +32,6 @@ from .run_agent import RunAgentTool
|
||||
from .run_block import RunBlockTool
|
||||
from .run_mcp_tool import RunMCPToolTool
|
||||
from .search_docs import SearchDocsTool
|
||||
from .validate_agent import ValidateAgentGraphTool
|
||||
from .web_fetch import WebFetchTool
|
||||
from .workspace_files import (
|
||||
DeleteWorkspaceFileTool,
|
||||
@@ -69,13 +64,10 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"move_agents_to_folder": MoveAgentsToFolderTool(),
|
||||
"run_agent": RunAgentTool(),
|
||||
"run_block": RunBlockTool(),
|
||||
"continue_run_block": ContinueRunBlockTool(),
|
||||
"run_mcp_tool": RunMCPToolTool(),
|
||||
"get_mcp_guide": GetMCPGuideTool(),
|
||||
"view_agent_output": AgentOutputTool(),
|
||||
"search_docs": SearchDocsTool(),
|
||||
"get_doc_page": GetDocPageTool(),
|
||||
"get_agent_building_guide": GetAgentBuildingGuideTool(),
|
||||
# Web fetch for safe URL retrieval
|
||||
"web_fetch": WebFetchTool(),
|
||||
# Agent-browser multi-step automation (navigate, act, screenshot)
|
||||
@@ -88,9 +80,6 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
# Feature request tools
|
||||
"search_feature_requests": SearchFeatureRequestsTool(),
|
||||
"create_feature_request": CreateFeatureRequestTool(),
|
||||
# Agent generation tools (local validation/fixing)
|
||||
"validate_agent_graph": ValidateAgentGraphTool(),
|
||||
"fix_agent_graph": FixAgentGraphTool(),
|
||||
# Workspace tools for CoPilot file operations
|
||||
"list_workspace_files": ListWorkspaceFilesTool(),
|
||||
"read_workspace_file": ReadWorkspaceFileTool(),
|
||||
|
||||
@@ -33,7 +33,7 @@ import tempfile
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.util.request import validate_url_host
|
||||
from backend.util.request import validate_url
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
@@ -235,7 +235,7 @@ async def _restore_browser_state(
|
||||
if url:
|
||||
# Validate the saved URL to prevent SSRF via stored redirect targets.
|
||||
try:
|
||||
await validate_url_host(url)
|
||||
await validate_url(url, trusted_origins=[])
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"[browser] State restore: blocked SSRF URL %s", url[:200]
|
||||
@@ -473,7 +473,7 @@ class BrowserNavigateTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
await validate_url_host(url)
|
||||
await validate_url(url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
return ErrorResponse(
|
||||
message=str(e),
|
||||
|
||||
@@ -68,18 +68,17 @@ def _run_result(rc: int = 0, stdout: str = "", stderr: str = "") -> tuple:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSRF protection via shared validate_url_host (backend.util.request)
|
||||
# SSRF protection via shared validate_url (backend.util.request)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Patch target: validate_url_host is imported directly into agent_browser's
|
||||
# module scope.
|
||||
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url_host"
|
||||
# Patch target: validate_url is imported directly into agent_browser's module scope.
|
||||
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url"
|
||||
|
||||
|
||||
class TestSsrfViaValidateUrl:
|
||||
"""Verify that browser_navigate uses validate_url_host for SSRF protection.
|
||||
"""Verify that browser_navigate uses validate_url for SSRF protection.
|
||||
|
||||
We mock validate_url_host itself (not the low-level socket) so these tests
|
||||
We mock validate_url itself (not the low-level socket) so these tests
|
||||
exercise the integration point, not the internals of request.py
|
||||
(which has its own thorough test suite in request_test.py).
|
||||
"""
|
||||
@@ -90,7 +89,7 @@ class TestSsrfViaValidateUrl:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocked_ip_returns_blocked_url_error(self):
|
||||
"""validate_url_host raises ValueError → tool returns blocked_url ErrorResponse."""
|
||||
"""validate_url raises ValueError → tool returns blocked_url ErrorResponse."""
|
||||
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
|
||||
mock_validate.side_effect = ValueError(
|
||||
"Access to blocked IP 10.0.0.1 is not allowed."
|
||||
@@ -125,8 +124,8 @@ class TestSsrfViaValidateUrl:
|
||||
assert result.error == "blocked_url"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_url_host_called_without_trusted_hostnames(self):
|
||||
"""Confirms no trusted-hostnames bypass is granted — all URLs are validated."""
|
||||
async def test_validate_url_called_with_empty_trusted_origins(self):
|
||||
"""Confirms no trusted-origins bypass is granted — all URLs are validated."""
|
||||
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
|
||||
mock_validate.return_value = (object(), False, ["1.2.3.4"])
|
||||
with patch(
|
||||
@@ -144,7 +143,7 @@ class TestSsrfViaValidateUrl:
|
||||
session=self.session,
|
||||
url="https://example.com",
|
||||
)
|
||||
mock_validate.assert_called_once_with("https://example.com")
|
||||
mock_validate.assert_called_once_with("https://example.com", trusted_origins=[])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -1,15 +1,20 @@
|
||||
"""Agent generator package - Creates agents from natural language."""
|
||||
|
||||
from .core import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
AgentJsonValidationError,
|
||||
AgentSummary,
|
||||
DecompositionResult,
|
||||
DecompositionStep,
|
||||
LibraryAgentSummary,
|
||||
MarketplaceAgentSummary,
|
||||
customize_template,
|
||||
decompose_goal,
|
||||
enrich_library_agents_from_steps,
|
||||
extract_search_terms_from_steps,
|
||||
extract_uuids_from_text,
|
||||
generate_agent,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
get_all_relevant_agents_for_generation,
|
||||
get_library_agent_by_graph_id,
|
||||
@@ -22,20 +27,25 @@ from .core import (
|
||||
search_marketplace_agents_for_generation,
|
||||
)
|
||||
from .errors import get_user_message_for_error
|
||||
from .validation import AgentFixer, AgentValidator
|
||||
from .service import health_check as check_external_service_health
|
||||
from .service import is_external_service_configured
|
||||
|
||||
__all__ = [
|
||||
"AgentFixer",
|
||||
"AgentValidator",
|
||||
"AgentGeneratorNotConfiguredError",
|
||||
"AgentJsonValidationError",
|
||||
"AgentSummary",
|
||||
"DecompositionResult",
|
||||
"DecompositionStep",
|
||||
"LibraryAgentSummary",
|
||||
"MarketplaceAgentSummary",
|
||||
"check_external_service_health",
|
||||
"customize_template",
|
||||
"decompose_goal",
|
||||
"enrich_library_agents_from_steps",
|
||||
"extract_search_terms_from_steps",
|
||||
"extract_uuids_from_text",
|
||||
"generate_agent",
|
||||
"generate_agent_patch",
|
||||
"get_agent_as_json",
|
||||
"get_all_relevant_agents_for_generation",
|
||||
"get_library_agent_by_graph_id",
|
||||
@@ -44,6 +54,7 @@ __all__ = [
|
||||
"get_library_agents_for_generation",
|
||||
"get_user_message_for_error",
|
||||
"graph_to_json",
|
||||
"is_external_service_configured",
|
||||
"json_to_graph",
|
||||
"save_agent_to_library",
|
||||
"search_marketplace_agents_for_generation",
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
"""Block management for agent generation.
|
||||
|
||||
Provides cached access to block metadata for validation and fixing.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Type
|
||||
|
||||
from backend.blocks import get_blocks as get_block_classes
|
||||
from backend.blocks._base import Block
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ["get_blocks_as_dicts", "reset_block_caches"]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level caches
|
||||
# ---------------------------------------------------------------------------
|
||||
_blocks_cache: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
def reset_block_caches() -> None:
|
||||
"""Reset all module-level caches (useful after updating block descriptions)."""
|
||||
global _blocks_cache
|
||||
_blocks_cache = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. get_blocks_as_dicts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_blocks_as_dicts() -> list[dict[str, Any]]:
|
||||
"""Get all available blocks as dicts (cached after first call).
|
||||
|
||||
Each dict contains the keys returned by ``Block.get_info().model_dump()``:
|
||||
id, name, description, inputSchema, outputSchema, categories,
|
||||
staticOutput, costs, contributors, uiType.
|
||||
|
||||
Returns:
|
||||
List of block info dicts.
|
||||
"""
|
||||
global _blocks_cache
|
||||
if _blocks_cache is not None:
|
||||
return _blocks_cache
|
||||
|
||||
block_classes: dict[str, Type[Block]] = get_block_classes() # type: ignore[assignment]
|
||||
blocks: list[dict[str, Any]] = []
|
||||
for block_cls in block_classes.values():
|
||||
try:
|
||||
instance = block_cls()
|
||||
info = instance.get_info().model_dump()
|
||||
# Use optimized description if available (loaded at startup)
|
||||
if instance.optimized_description:
|
||||
info["description"] = instance.optimized_description
|
||||
blocks.append(info)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to load block info for %s, skipping",
|
||||
getattr(block_cls, "__name__", "unknown"),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
_blocks_cache = blocks
|
||||
logger.info("Cached %d block dicts", len(blocks))
|
||||
return _blocks_cache
|
||||
@@ -10,7 +10,13 @@ from backend.data.db_accessors import graph_db, library_db, store_db
|
||||
from backend.data.graph import Graph, Link, Node
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .helpers import UUID_RE_STR
|
||||
from .service import (
|
||||
customize_template_external,
|
||||
decompose_goal_external,
|
||||
generate_agent_external,
|
||||
generate_agent_patch_external,
|
||||
is_external_service_configured,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -72,7 +78,38 @@ class DecompositionResult(TypedDict, total=False):
|
||||
AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
|
||||
|
||||
|
||||
_UUID_PATTERN = re.compile(UUID_RE_STR, re.IGNORECASE)
|
||||
def _to_dict_list(
|
||||
agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None,
|
||||
) -> list[dict[str, Any]] | None:
|
||||
"""Convert typed agent summaries to plain dicts for external service calls."""
|
||||
if agents is None:
|
||||
return None
|
||||
return [dict(a) for a in agents]
|
||||
|
||||
|
||||
class AgentGeneratorNotConfiguredError(Exception):
|
||||
"""Raised when the external Agent Generator service is not configured."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def _check_service_configured() -> None:
|
||||
"""Check if the external Agent Generator service is configured.
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the service is not configured.
|
||||
"""
|
||||
if not is_external_service_configured():
|
||||
raise AgentGeneratorNotConfiguredError(
|
||||
"Agent Generator service is not configured. "
|
||||
"Set AGENTGENERATOR_HOST environment variable to enable agent generation."
|
||||
)
|
||||
|
||||
|
||||
_UUID_PATTERN = re.compile(
|
||||
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def extract_uuids_from_text(text: str) -> list[str]:
|
||||
@@ -516,6 +553,69 @@ async def enrich_library_agents_from_steps(
|
||||
return all_agents
|
||||
|
||||
|
||||
async def decompose_goal(
|
||||
description: str,
|
||||
context: str = "",
|
||||
library_agents: Sequence[AgentSummary] | None = None,
|
||||
) -> DecompositionResult | None:
|
||||
"""Break down a goal into steps or return clarifying questions.
|
||||
|
||||
Args:
|
||||
description: Natural language goal description
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
DecompositionResult with either:
|
||||
- {"type": "clarifying_questions", "questions": [...]}
|
||||
- {"type": "instructions", "steps": [...]}
|
||||
Or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for decompose_goal")
|
||||
result = await decompose_goal_external(
|
||||
description, context, _to_dict_list(library_agents)
|
||||
)
|
||||
return result # type: ignore[return-value]
|
||||
|
||||
|
||||
async def generate_agent(
|
||||
instructions: DecompositionResult | dict[str, Any],
|
||||
library_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Generate agent JSON from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
Agent JSON dict, error dict {"type": "error", ...}, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent")
|
||||
result = await generate_agent_external(
|
||||
dict(instructions), _to_dict_list(library_agents)
|
||||
)
|
||||
|
||||
if result:
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
return result
|
||||
if "id" not in result:
|
||||
result["id"] = str(uuid.uuid4())
|
||||
if "version" not in result:
|
||||
result["version"] = 1
|
||||
if "is_active" not in result:
|
||||
result["is_active"] = True
|
||||
return result
|
||||
|
||||
|
||||
class AgentJsonValidationError(Exception):
|
||||
"""Raised when agent JSON is invalid or missing required fields."""
|
||||
|
||||
@@ -692,3 +792,70 @@ async def get_agent_as_json(
|
||||
return None
|
||||
|
||||
return graph_to_json(graph)
|
||||
|
||||
|
||||
async def generate_agent_patch(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: Sequence[AgentSummary] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Update an existing agent using natural language.
|
||||
|
||||
The external Agent Generator service handles:
|
||||
- Generating the patch
|
||||
- Applying the patch
|
||||
- Fixing and validating the result
|
||||
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||
error dict {"type": "error", ...}, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
||||
return await generate_agent_patch_external(
|
||||
update_request,
|
||||
current_agent,
|
||||
_to_dict_list(library_agents),
|
||||
)
|
||||
|
||||
|
||||
async def customize_template(
|
||||
template_agent: dict[str, Any],
|
||||
modification_request: str,
|
||||
context: str = "",
|
||||
) -> dict[str, Any] | None:
|
||||
"""Customize a template/marketplace agent using natural language.
|
||||
|
||||
This is used when users want to modify a template or marketplace agent
|
||||
to fit their specific needs before adding it to their library.
|
||||
|
||||
The external Agent Generator service handles:
|
||||
- Understanding the modification request
|
||||
- Applying changes to the template
|
||||
- Fixing and validating the result
|
||||
|
||||
Args:
|
||||
template_agent: The template agent JSON to customize
|
||||
modification_request: Natural language description of customizations
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
|
||||
Returns:
|
||||
Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||
error dict {"type": "error", ...}, or None on unexpected error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for customize_template")
|
||||
return await customize_template_external(
|
||||
template_agent, modification_request, context
|
||||
)
|
||||
|
||||
@@ -0,0 +1,165 @@
|
||||
"""Dummy Agent Generator for testing.
|
||||
|
||||
Returns mock responses matching the format expected from the external service.
|
||||
Enable via AGENTGENERATOR_USE_DUMMY=true in settings.
|
||||
|
||||
WARNING: This is for testing only. Do not use in production.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Dummy decomposition result (instructions type)
|
||||
DUMMY_DECOMPOSITION_RESULT: dict[str, Any] = {
|
||||
"type": "instructions",
|
||||
"steps": [
|
||||
{
|
||||
"description": "Get input from user",
|
||||
"action": "input",
|
||||
"block_name": "AgentInputBlock",
|
||||
},
|
||||
{
|
||||
"description": "Process the input",
|
||||
"action": "process",
|
||||
"block_name": "TextFormatterBlock",
|
||||
},
|
||||
{
|
||||
"description": "Return output to user",
|
||||
"action": "output",
|
||||
"block_name": "AgentOutputBlock",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# Block IDs from backend/blocks/io.py
|
||||
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
|
||||
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
|
||||
|
||||
|
||||
def _generate_dummy_agent_json() -> dict[str, Any]:
|
||||
"""Generate a minimal valid agent JSON for testing."""
|
||||
input_node_id = str(uuid.uuid4())
|
||||
output_node_id = str(uuid.uuid4())
|
||||
|
||||
return {
|
||||
"id": str(uuid.uuid4()),
|
||||
"version": 1,
|
||||
"is_active": True,
|
||||
"name": "Dummy Test Agent",
|
||||
"description": "A dummy agent generated for testing purposes",
|
||||
"nodes": [
|
||||
{
|
||||
"id": input_node_id,
|
||||
"block_id": AGENT_INPUT_BLOCK_ID,
|
||||
"input_default": {
|
||||
"name": "input",
|
||||
"title": "Input",
|
||||
"description": "Enter your input",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
"metadata": {"position": {"x": 0, "y": 0}},
|
||||
},
|
||||
{
|
||||
"id": output_node_id,
|
||||
"block_id": AGENT_OUTPUT_BLOCK_ID,
|
||||
"input_default": {
|
||||
"name": "output",
|
||||
"title": "Output",
|
||||
"description": "Agent output",
|
||||
"format": "{output}",
|
||||
},
|
||||
"metadata": {"position": {"x": 400, "y": 0}},
|
||||
},
|
||||
],
|
||||
"links": [
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": input_node_id,
|
||||
"sink_id": output_node_id,
|
||||
"source_name": "result",
|
||||
"sink_name": "value",
|
||||
"is_static": False,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
async def decompose_goal_dummy(
|
||||
description: str,
|
||||
context: str = "",
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Return dummy decomposition result."""
|
||||
logger.info("Using dummy agent generator for decompose_goal")
|
||||
return DUMMY_DECOMPOSITION_RESULT.copy()
|
||||
|
||||
|
||||
async def generate_agent_dummy(
|
||||
instructions: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Return dummy agent synchronously (blocks for 30s, returns agent JSON).
|
||||
|
||||
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
|
||||
"""
|
||||
logger.info(
|
||||
"Using dummy agent generator (sync mode): returning agent JSON after 30s"
|
||||
)
|
||||
await asyncio.sleep(30)
|
||||
return _generate_dummy_agent_json()
|
||||
|
||||
|
||||
async def generate_agent_patch_dummy(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Return dummy patched agent synchronously (blocks for 30s, returns patched agent JSON).
|
||||
|
||||
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
|
||||
"""
|
||||
logger.info(
|
||||
"Using dummy agent generator patch (sync mode): returning patched agent after 30s"
|
||||
)
|
||||
await asyncio.sleep(30)
|
||||
patched = current_agent.copy()
|
||||
patched["description"] = (
|
||||
f"{current_agent.get('description', '')} (updated: {update_request})"
|
||||
)
|
||||
return patched
|
||||
|
||||
|
||||
async def customize_template_dummy(
|
||||
template_agent: dict[str, Any],
|
||||
modification_request: str,
|
||||
context: str = "",
|
||||
) -> dict[str, Any]:
|
||||
"""Return dummy customized template (returns template with updated description)."""
|
||||
logger.info("Using dummy agent generator for customize_template")
|
||||
customized = template_agent.copy()
|
||||
customized["description"] = (
|
||||
f"{template_agent.get('description', '')} (customized: {modification_request})"
|
||||
)
|
||||
return customized
|
||||
|
||||
|
||||
async def get_blocks_dummy() -> list[dict[str, Any]]:
|
||||
"""Return dummy blocks list."""
|
||||
logger.info("Using dummy agent generator for get_blocks")
|
||||
return [
|
||||
{"id": AGENT_INPUT_BLOCK_ID, "name": "AgentInputBlock"},
|
||||
{"id": AGENT_OUTPUT_BLOCK_ID, "name": "AgentOutputBlock"},
|
||||
]
|
||||
|
||||
|
||||
async def health_check_dummy() -> bool:
|
||||
"""Always returns healthy for dummy service."""
|
||||
return True
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,67 +0,0 @@
|
||||
"""Shared helpers for agent generation."""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from .blocks import get_blocks_as_dicts
|
||||
|
||||
__all__ = [
|
||||
"AGENT_EXECUTOR_BLOCK_ID",
|
||||
"AGENT_INPUT_BLOCK_ID",
|
||||
"AGENT_OUTPUT_BLOCK_ID",
|
||||
"AgentDict",
|
||||
"MCP_TOOL_BLOCK_ID",
|
||||
"UUID_REGEX",
|
||||
"are_types_compatible",
|
||||
"generate_uuid",
|
||||
"get_blocks_as_dicts",
|
||||
"get_defined_property_type",
|
||||
"is_uuid",
|
||||
]
|
||||
|
||||
|
||||
# Type alias for the agent JSON structure passed through
|
||||
# the validation and fixing pipeline.
|
||||
AgentDict = dict[str, Any]
|
||||
|
||||
# Shared base pattern (unanchored, lowercase hex); used for both full-string
|
||||
# validation (UUID_REGEX) and text extraction (core._UUID_PATTERN).
|
||||
UUID_RE_STR = r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[a-f0-9]{4}-[a-f0-9]{12}"
|
||||
|
||||
UUID_REGEX = re.compile(r"^" + UUID_RE_STR + r"$")
|
||||
|
||||
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
|
||||
MCP_TOOL_BLOCK_ID = "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
|
||||
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
|
||||
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
|
||||
|
||||
|
||||
def is_uuid(value: str) -> bool:
|
||||
"""Check if a string is a valid UUID."""
|
||||
return isinstance(value, str) and UUID_REGEX.match(value) is not None
|
||||
|
||||
|
||||
def generate_uuid() -> str:
|
||||
"""Generate a new UUID string."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def get_defined_property_type(schema: dict[str, Any], name: str) -> str | None:
|
||||
"""Get property type from a schema, handling nested `_#_` notation."""
|
||||
if "_#_" in name:
|
||||
parent, child = name.split("_#_", 1)
|
||||
parent_schema = schema.get(parent, {})
|
||||
if "properties" in parent_schema and isinstance(
|
||||
parent_schema["properties"], dict
|
||||
):
|
||||
return parent_schema["properties"].get(child, {}).get("type")
|
||||
return None
|
||||
return schema.get(name, {}).get("type")
|
||||
|
||||
|
||||
def are_types_compatible(src: str, sink: str) -> bool:
|
||||
"""Check if two schema types are compatible."""
|
||||
if {src, sink} <= {"integer", "number"}:
|
||||
return True
|
||||
return src == sink
|
||||
@@ -1,196 +0,0 @@
|
||||
"""Shared fix → validate → preview/save pipeline for agent tools."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
from backend.copilot.tools.models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
from .blocks import get_blocks_as_dicts
|
||||
from .core import get_library_agents_by_ids, save_agent_to_library
|
||||
from .fixer import AgentFixer
|
||||
from .validator import AgentValidator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_AGENT_JSON_SIZE = 1_000_000 # 1 MB
|
||||
|
||||
|
||||
async def fetch_library_agents(
|
||||
user_id: str | None,
|
||||
library_agent_ids: list[str],
|
||||
) -> list[dict[str, Any]] | None:
|
||||
"""Fetch library agents by IDs for AgentExecutorBlock validation.
|
||||
|
||||
Returns None if no IDs provided or user is not authenticated.
|
||||
"""
|
||||
if not user_id or not library_agent_ids:
|
||||
return None
|
||||
try:
|
||||
agents = await get_library_agents_by_ids(
|
||||
user_id=user_id,
|
||||
agent_ids=library_agent_ids,
|
||||
)
|
||||
return cast(list[dict[str, Any]], agents)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agents by IDs: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def fix_validate_and_save(
|
||||
agent_json: dict[str, Any],
|
||||
*,
|
||||
user_id: str | None,
|
||||
session_id: str | None,
|
||||
save: bool = True,
|
||||
is_update: bool = False,
|
||||
default_name: str = "Agent",
|
||||
preview_message: str | None = None,
|
||||
save_message: str | None = None,
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
folder_id: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""Shared pipeline: auto-fix → validate → preview or save.
|
||||
|
||||
Args:
|
||||
agent_json: The agent JSON dict (must already have id/version/is_active set).
|
||||
user_id: The authenticated user's ID.
|
||||
session_id: The chat session ID.
|
||||
save: Whether to save or just preview.
|
||||
is_update: Whether this is an update to an existing agent.
|
||||
default_name: Fallback name if agent_json has none.
|
||||
preview_message: Custom preview message (optional).
|
||||
save_message: Custom save success message (optional).
|
||||
library_agents: Library agents for AgentExecutorBlock validation/fixing.
|
||||
|
||||
Returns:
|
||||
An appropriate ToolResponseBase subclass.
|
||||
"""
|
||||
# Size guard
|
||||
json_size = len(json.dumps(agent_json))
|
||||
if json_size > MAX_AGENT_JSON_SIZE:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Agent JSON is too large ({json_size:,} bytes, "
|
||||
f"max {MAX_AGENT_JSON_SIZE:,}). Reduce the number of nodes."
|
||||
),
|
||||
error="agent_json_too_large",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
blocks = get_blocks_as_dicts()
|
||||
|
||||
# Auto-fix
|
||||
try:
|
||||
fixer = AgentFixer()
|
||||
agent_json = fixer.apply_all_fixes(agent_json, blocks, library_agents)
|
||||
fixes = fixer.get_fixes_applied()
|
||||
if fixes:
|
||||
logger.info(f"Applied {len(fixes)} auto-fixes to agent JSON")
|
||||
except Exception as e:
|
||||
logger.warning(f"Auto-fix failed: {e}")
|
||||
|
||||
# Validate
|
||||
try:
|
||||
validator = AgentValidator()
|
||||
is_valid, _ = validator.validate(agent_json, blocks, library_agents)
|
||||
if not is_valid:
|
||||
errors = validator.errors
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"The agent has {len(errors)} validation error(s):\n"
|
||||
+ "\n".join(f"- {e}" for e in errors[:5])
|
||||
),
|
||||
error="validation_failed",
|
||||
details={"errors": errors},
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Validation failed with exception: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to validate the agent. Please try again.",
|
||||
error="validation_exception",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agent_name = agent_json.get("name", default_name)
|
||||
agent_description = agent_json.get("description", "")
|
||||
node_count = len(agent_json.get("nodes", []))
|
||||
link_count = len(agent_json.get("links", []))
|
||||
|
||||
# Build a warning suffix when name/description is missing or generic
|
||||
_GENERIC_NAMES = {
|
||||
"agent",
|
||||
"generated agent",
|
||||
"customized agent",
|
||||
"updated agent",
|
||||
"new agent",
|
||||
"my agent",
|
||||
}
|
||||
metadata_warnings: list[str] = []
|
||||
if not agent_json.get("name") or agent_name.lower().strip() in _GENERIC_NAMES:
|
||||
metadata_warnings.append("'name'")
|
||||
if not agent_description:
|
||||
metadata_warnings.append("'description'")
|
||||
metadata_hint = ""
|
||||
if metadata_warnings:
|
||||
missing = " and ".join(metadata_warnings)
|
||||
metadata_hint = (
|
||||
f" Note: the agent is missing a meaningful {missing}. "
|
||||
f"Please update the agent_json to include them."
|
||||
)
|
||||
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
(
|
||||
preview_message
|
||||
or f"Agent '{agent_name}' with {node_count} blocks is ready."
|
||||
)
|
||||
+ metadata_hint
|
||||
),
|
||||
agent_json=agent_json,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
agent_json, user_id, is_update=is_update, folder_id=folder_id
|
||||
)
|
||||
return AgentSavedResponse(
|
||||
message=(
|
||||
(save_message or f"Agent '{created_graph.name}' has been saved!")
|
||||
+ metadata_hint
|
||||
),
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save agent: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to save the agent: {str(e)}",
|
||||
error="save_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -0,0 +1,511 @@
|
||||
"""External Agent Generator service client.
|
||||
|
||||
This module provides a client for communicating with the external Agent Generator
|
||||
microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
|
||||
will delegate to the external service instead of using the built-in LLM-based implementation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .dummy import (
|
||||
customize_template_dummy,
|
||||
decompose_goal_dummy,
|
||||
generate_agent_dummy,
|
||||
generate_agent_patch_dummy,
|
||||
get_blocks_dummy,
|
||||
health_check_dummy,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_dummy_mode_warned = False
|
||||
|
||||
|
||||
def _create_error_response(
|
||||
error_message: str,
|
||||
error_type: str = "unknown",
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a standardized error response dict.
|
||||
|
||||
Args:
|
||||
error_message: Human-readable error message
|
||||
error_type: Machine-readable error type
|
||||
details: Optional additional error details
|
||||
|
||||
Returns:
|
||||
Error dict with type="error" and error details
|
||||
"""
|
||||
response: dict[str, Any] = {
|
||||
"type": "error",
|
||||
"error": error_message,
|
||||
"error_type": error_type,
|
||||
}
|
||||
if details:
|
||||
response["details"] = details
|
||||
return response
|
||||
|
||||
|
||||
def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
|
||||
"""Classify an HTTP error into error_type and message.
|
||||
|
||||
Args:
|
||||
e: The HTTP status error
|
||||
|
||||
Returns:
|
||||
Tuple of (error_type, error_message)
|
||||
"""
|
||||
status = e.response.status_code
|
||||
if status == 429:
|
||||
return "rate_limit", f"Agent Generator rate limited: {e}"
|
||||
elif status == 503:
|
||||
return "service_unavailable", f"Agent Generator unavailable: {e}"
|
||||
elif status == 504 or status == 408:
|
||||
return "timeout", f"Agent Generator timed out: {e}"
|
||||
else:
|
||||
return "http_error", f"HTTP error calling Agent Generator: {e}"
|
||||
|
||||
|
||||
def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
|
||||
"""Classify a request error into error_type and message.
|
||||
|
||||
Args:
|
||||
e: The request error
|
||||
|
||||
Returns:
|
||||
Tuple of (error_type, error_message)
|
||||
"""
|
||||
error_str = str(e).lower()
|
||||
if "timeout" in error_str or "timed out" in error_str:
|
||||
return "timeout", f"Agent Generator request timed out: {e}"
|
||||
elif "connect" in error_str:
|
||||
return "connection_error", f"Could not connect to Agent Generator: {e}"
|
||||
else:
|
||||
return "request_error", f"Request error calling Agent Generator: {e}"
|
||||
|
||||
|
||||
_client: httpx.AsyncClient | None = None
|
||||
_settings: Settings | None = None
|
||||
|
||||
|
||||
def _get_settings() -> Settings:
|
||||
"""Get or create settings singleton."""
|
||||
global _settings
|
||||
if _settings is None:
|
||||
_settings = Settings()
|
||||
return _settings
|
||||
|
||||
|
||||
def _is_dummy_mode() -> bool:
|
||||
"""Check if dummy mode is enabled for testing."""
|
||||
global _dummy_mode_warned
|
||||
settings = _get_settings()
|
||||
is_dummy = bool(settings.config.agentgenerator_use_dummy)
|
||||
if is_dummy and not _dummy_mode_warned:
|
||||
logger.warning(
|
||||
"Agent Generator running in DUMMY MODE - returning mock responses. "
|
||||
"Do not use in production!"
|
||||
)
|
||||
_dummy_mode_warned = True
|
||||
return is_dummy
|
||||
|
||||
|
||||
def is_external_service_configured() -> bool:
|
||||
"""Check if external Agent Generator service is configured (or dummy mode)."""
|
||||
settings = _get_settings()
|
||||
return bool(settings.config.agentgenerator_host) or bool(
|
||||
settings.config.agentgenerator_use_dummy
|
||||
)
|
||||
|
||||
|
||||
def _get_base_url() -> str:
|
||||
"""Get the base URL for the external service."""
|
||||
settings = _get_settings()
|
||||
host = settings.config.agentgenerator_host
|
||||
port = settings.config.agentgenerator_port
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def _get_client() -> httpx.AsyncClient:
|
||||
"""Get or create the HTTP client for the external service."""
|
||||
global _client
|
||||
if _client is None:
|
||||
settings = _get_settings()
|
||||
_client = httpx.AsyncClient(
|
||||
base_url=_get_base_url(),
|
||||
timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
|
||||
)
|
||||
return _client
|
||||
|
||||
|
||||
async def decompose_goal_external(
|
||||
description: str,
|
||||
context: str = "",
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to decompose a goal.
|
||||
|
||||
Args:
|
||||
description: Natural language goal description
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
Dict with either:
|
||||
- {"type": "clarifying_questions", "questions": [...]}
|
||||
- {"type": "instructions", "steps": [...]}
|
||||
- {"type": "unachievable_goal", ...}
|
||||
- {"type": "vague_goal", ...}
|
||||
- {"type": "error", "error": "...", "error_type": "..."} on error
|
||||
Or None on unexpected error
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await decompose_goal_dummy(description, context, library_agents)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
if context:
|
||||
description = f"{description}\n\nAdditional context from user:\n{context}"
|
||||
|
||||
payload: dict[str, Any] = {"description": description}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
|
||||
try:
|
||||
response = await client.post("/api/decompose-description", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator decomposition failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
# Map the response to the expected format
|
||||
response_type = data.get("type")
|
||||
if response_type == "instructions":
|
||||
return {"type": "instructions", "steps": data.get("steps", [])}
|
||||
elif response_type == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
elif response_type == "unachievable_goal":
|
||||
return {
|
||||
"type": "unachievable_goal",
|
||||
"reason": data.get("reason"),
|
||||
"suggested_goal": data.get("suggested_goal"),
|
||||
}
|
||||
elif response_type == "vague_goal":
|
||||
return {
|
||||
"type": "vague_goal",
|
||||
"suggested_goal": data.get("suggested_goal"),
|
||||
}
|
||||
elif response_type == "error":
|
||||
# Pass through error from the service
|
||||
return _create_error_response(
|
||||
data.get("error", "Unknown error"),
|
||||
data.get("error_type", "unknown"),
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Unknown response type from external service: {response_type}"
|
||||
)
|
||||
return _create_error_response(
|
||||
f"Unknown response type from Agent Generator: {response_type}",
|
||||
"invalid_response",
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
|
||||
async def generate_agent_external(
|
||||
instructions: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate an agent from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
Agent JSON dict or error dict {"type": "error", ...} on error
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await generate_agent_dummy(instructions, library_agents)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
payload: dict[str, Any] = {"instructions": instructions}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
|
||||
try:
|
||||
response = await client.post("/api/generate-agent", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator generation failed: {error_msg} (type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
|
||||
async def generate_agent_patch_external(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate a patch for an existing agent.
|
||||
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
session_id: Session ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await generate_agent_patch_dummy(
|
||||
update_request, current_agent, library_agents
|
||||
)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
payload: dict[str, Any] = {
|
||||
"update_request": update_request,
|
||||
"current_agent_json": current_agent,
|
||||
}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
|
||||
try:
|
||||
response = await client.post("/api/update-agent", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator patch generation failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
# Check if it's clarifying questions
|
||||
if data.get("type") == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
|
||||
# Check if it's an error passed through
|
||||
if data.get("type") == "error":
|
||||
return _create_error_response(
|
||||
data.get("error", "Unknown error"),
|
||||
data.get("error_type", "unknown"),
|
||||
)
|
||||
|
||||
# Otherwise return the updated agent JSON
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
|
||||
async def customize_template_external(
|
||||
template_agent: dict[str, Any],
|
||||
modification_request: str,
|
||||
context: str = "",
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to customize a template/marketplace agent.
|
||||
|
||||
Args:
|
||||
template_agent: The template agent JSON to customize
|
||||
modification_request: Natural language description of customizations
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
session_id: Session ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Customized agent JSON, clarifying questions dict, or error dict on error
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await customize_template_dummy(
|
||||
template_agent, modification_request, context
|
||||
)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
request = modification_request
|
||||
if context:
|
||||
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"template_agent_json": template_agent,
|
||||
"modification_request": request,
|
||||
}
|
||||
|
||||
try:
|
||||
response = await client.post("/api/template-modification", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator template customization failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
# Check if it's clarifying questions
|
||||
if data.get("type") == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
|
||||
# Check if it's an error passed through
|
||||
if data.get("type") == "error":
|
||||
return _create_error_response(
|
||||
data.get("error", "Unknown error"),
|
||||
data.get("error_type", "unknown"),
|
||||
)
|
||||
|
||||
# Otherwise return the customized agent JSON
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
|
||||
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
||||
"""Get available blocks from the external service.
|
||||
|
||||
Returns:
|
||||
List of block info dicts or None on error
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await get_blocks_dummy()
|
||||
|
||||
client = _get_client()
|
||||
|
||||
try:
|
||||
response = await client.get("/api/blocks")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
logger.error("External service returned error getting blocks")
|
||||
return None
|
||||
|
||||
return data.get("blocks", [])
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error getting blocks from external service: {e}")
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error getting blocks from external service: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error getting blocks from external service: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def health_check() -> bool:
|
||||
"""Check if the external service is healthy.
|
||||
|
||||
Returns:
|
||||
True if healthy, False otherwise
|
||||
"""
|
||||
if not is_external_service_configured():
|
||||
return False
|
||||
|
||||
if _is_dummy_mode():
|
||||
return await health_check_dummy()
|
||||
|
||||
client = _get_client()
|
||||
|
||||
try:
|
||||
response = await client.get("/health")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("status") == "healthy" and data.get("blocks_loaded", False)
|
||||
except Exception as e:
|
||||
logger.warning(f"External agent generator health check failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def close_client() -> None:
|
||||
"""Close the HTTP client."""
|
||||
global _client
|
||||
if _client is not None:
|
||||
await _client.aclose()
|
||||
_client = None
|
||||
@@ -1,17 +0,0 @@
|
||||
"""Agent generation validation — re-exports from split modules.
|
||||
|
||||
This module was split into:
|
||||
- helpers.py: get_blocks_as_dicts, block cache
|
||||
- fixer.py: AgentFixer class
|
||||
- validator.py: AgentValidator class
|
||||
"""
|
||||
|
||||
from .fixer import AgentFixer
|
||||
from .helpers import get_blocks_as_dicts
|
||||
from .validator import AgentValidator
|
||||
|
||||
__all__ = [
|
||||
"AgentFixer",
|
||||
"AgentValidator",
|
||||
"get_blocks_as_dicts",
|
||||
]
|
||||
@@ -1,939 +0,0 @@
|
||||
"""AgentValidator — validates agent JSON graphs for correctness."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from .helpers import (
|
||||
AGENT_EXECUTOR_BLOCK_ID,
|
||||
AGENT_INPUT_BLOCK_ID,
|
||||
AGENT_OUTPUT_BLOCK_ID,
|
||||
MCP_TOOL_BLOCK_ID,
|
||||
AgentDict,
|
||||
are_types_compatible,
|
||||
get_defined_property_type,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentValidator:
|
||||
"""
|
||||
A comprehensive validator for AutoGPT agents that provides detailed error
|
||||
reporting for LLM-based fixes.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.errors: list[str] = []
|
||||
|
||||
def add_error(self, error_message: str) -> None:
|
||||
"""Add an error message to the validation errors list."""
|
||||
self.errors.append(error_message)
|
||||
|
||||
def _values_equal(self, val1: Any, val2: Any) -> bool:
|
||||
"""Compare two values, handling complex types like dicts and lists."""
|
||||
if type(val1) is not type(val2):
|
||||
return False
|
||||
if isinstance(val1, dict):
|
||||
return json.dumps(val1, sort_keys=True) == json.dumps(val2, sort_keys=True)
|
||||
if isinstance(val1, list):
|
||||
return json.dumps(val1, sort_keys=True) == json.dumps(val2, sort_keys=True)
|
||||
return val1 == val2
|
||||
|
||||
def validate_block_existence(
|
||||
self, agent: AgentDict, blocks: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that all block IDs used in the agent actually exist in the
|
||||
blocks list. Returns True if all block IDs exist, False otherwise.
|
||||
"""
|
||||
valid = True
|
||||
|
||||
# Create a set of all valid block IDs for fast lookup
|
||||
valid_block_ids = {block.get("id") for block in blocks if block.get("id")}
|
||||
|
||||
# Check each node's block_id
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
node_id = node.get("id")
|
||||
|
||||
if not block_id:
|
||||
self.add_error(
|
||||
f"Node '{node_id}' is missing a 'block_id' field. "
|
||||
f"Every node must reference a valid block."
|
||||
)
|
||||
valid = False
|
||||
continue
|
||||
|
||||
if block_id not in valid_block_ids:
|
||||
self.add_error(
|
||||
f"Node '{node_id}' references block_id '{block_id}' "
|
||||
f"which does not exist in the available blocks. "
|
||||
f"This block may have been deprecated, removed, or "
|
||||
f"the ID is incorrect. Please use a valid block from "
|
||||
f"the blocks library."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_link_node_references(self, agent: AgentDict) -> bool:
|
||||
"""
|
||||
Validate that all node IDs referenced in links actually exist in the
|
||||
agent's nodes. Returns True if all link references are valid, False
|
||||
otherwise.
|
||||
"""
|
||||
valid = True
|
||||
|
||||
# Create a set of all valid node IDs for fast lookup
|
||||
valid_node_ids = {
|
||||
node.get("id") for node in agent.get("nodes", []) if node.get("id")
|
||||
}
|
||||
|
||||
# Check each link's source_id and sink_id
|
||||
for link in agent.get("links", []):
|
||||
link_id = link.get("id", "Unknown")
|
||||
source_id = link.get("source_id")
|
||||
sink_id = link.get("sink_id")
|
||||
source_name = link.get("source_name", "")
|
||||
sink_name = link.get("sink_name", "")
|
||||
|
||||
# Check source_id
|
||||
if not source_id:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' is missing a 'source_id' field. "
|
||||
f"Every link must reference a valid source node."
|
||||
)
|
||||
valid = False
|
||||
elif source_id not in valid_node_ids:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' references source_id '{source_id}' "
|
||||
f"which does not exist in the agent's nodes. The link "
|
||||
f"from '{source_name}' cannot be established because "
|
||||
f"the source node is missing."
|
||||
)
|
||||
valid = False
|
||||
|
||||
# Check sink_id
|
||||
if not sink_id:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' is missing a 'sink_id' field. "
|
||||
f"Every link must reference a valid sink (destination) "
|
||||
f"node."
|
||||
)
|
||||
valid = False
|
||||
elif sink_id not in valid_node_ids:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' references sink_id '{sink_id}' "
|
||||
f"which does not exist in the agent's nodes. The link "
|
||||
f"to '{sink_name}' cannot be established because the "
|
||||
f"destination node is missing."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_required_inputs(
|
||||
self, agent: AgentDict, blocks: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that all required inputs are provided for each node.
|
||||
Returns True if all required inputs are satisfied, False otherwise.
|
||||
"""
|
||||
valid = True
|
||||
|
||||
block_lookup = {b.get("id", ""): b for b in blocks}
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
block = block_lookup.get(block_id)
|
||||
|
||||
if not block:
|
||||
continue
|
||||
|
||||
required_inputs = block.get("inputSchema", {}).get("required", [])
|
||||
input_defaults = node.get("input_default", {})
|
||||
node_id = node.get("id")
|
||||
|
||||
linked_inputs = set(
|
||||
link.get("sink_name")
|
||||
for link in agent.get("links", [])
|
||||
if link.get("sink_id") == node_id and link.get("sink_name")
|
||||
)
|
||||
|
||||
for req_input in required_inputs:
|
||||
if (
|
||||
req_input not in input_defaults
|
||||
and req_input not in linked_inputs
|
||||
and req_input != "credentials"
|
||||
):
|
||||
block_name = block.get("name", "Unknown Block")
|
||||
self.add_error(
|
||||
f"Node '{node_id}' (block '{block_name}' - "
|
||||
f"{block_id}) is missing required input "
|
||||
f"'{req_input}'. This input must be either "
|
||||
f"provided as a default value in the node's "
|
||||
f"'input_default' field or connected via a link "
|
||||
f"from another node's output."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_data_type_compatibility(
|
||||
self, agent: AgentDict, blocks: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that linked data types are compatible between source and sink.
|
||||
Returns True if all data types are compatible, False otherwise.
|
||||
"""
|
||||
valid = True
|
||||
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
|
||||
block_lookup = {block.get("id", ""): block for block in blocks}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
source_id = link.get("source_id")
|
||||
sink_id = link.get("sink_id")
|
||||
source_name = link.get("source_name")
|
||||
sink_name = link.get("sink_name")
|
||||
|
||||
if not all(
|
||||
isinstance(v, str) and v
|
||||
for v in (source_id, sink_id, source_name, sink_name)
|
||||
):
|
||||
self.add_error(
|
||||
f"Link '{link.get('id', 'Unknown')}' is missing required "
|
||||
f"fields (source_id/sink_id/source_name/sink_name)."
|
||||
)
|
||||
valid = False
|
||||
continue
|
||||
|
||||
source_node = node_lookup.get(source_id, "")
|
||||
sink_node = node_lookup.get(sink_id, "")
|
||||
|
||||
if not source_node or not sink_node:
|
||||
continue
|
||||
|
||||
source_block = block_lookup.get(source_node.get("block_id", ""))
|
||||
sink_block = block_lookup.get(sink_node.get("block_id", ""))
|
||||
|
||||
if not source_block or not sink_block:
|
||||
continue
|
||||
|
||||
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
|
||||
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
|
||||
|
||||
source_type = get_defined_property_type(source_outputs, source_name)
|
||||
sink_type = get_defined_property_type(sink_inputs, sink_name)
|
||||
|
||||
if (
|
||||
source_type
|
||||
and sink_type
|
||||
and not are_types_compatible(source_type, sink_type)
|
||||
):
|
||||
source_block_name = source_block.get("name", "Unknown Block")
|
||||
sink_block_name = sink_block.get("name", "Unknown Block")
|
||||
self.add_error(
|
||||
f"Data type mismatch in link '{link.get('id')}': "
|
||||
f"Source '{source_block_name}' output "
|
||||
f"'{link.get('source_name', '')}' outputs '{source_type}' "
|
||||
f"type, but sink '{sink_block_name}' input "
|
||||
f"'{link.get('sink_name', '')}' expects '{sink_type}' type. "
|
||||
f"These types must match for the connection to work "
|
||||
f"properly."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_nested_sink_links(
|
||||
self, agent: AgentDict, blocks: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""
|
||||
Validate nested sink links (links with _#_ notation).
|
||||
Returns True if all nested links are valid, False otherwise.
|
||||
"""
|
||||
valid = True
|
||||
block_input_schemas = {
|
||||
block.get("id", ""): block.get("inputSchema", {}).get("properties", {})
|
||||
for block in blocks
|
||||
}
|
||||
block_names = {
|
||||
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
|
||||
}
|
||||
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
sink_name = link.get("sink_name", "")
|
||||
sink_id = link.get("sink_id")
|
||||
|
||||
if not sink_name or not sink_id:
|
||||
continue
|
||||
|
||||
if "_#_" in sink_name:
|
||||
parent, child = sink_name.split("_#_", 1)
|
||||
|
||||
sink_node = node_lookup.get(sink_id)
|
||||
if not sink_node:
|
||||
continue
|
||||
|
||||
block_id = sink_node.get("block_id")
|
||||
input_props = block_input_schemas.get(block_id, {})
|
||||
|
||||
parent_schema = input_props.get(parent)
|
||||
if not parent_schema:
|
||||
block_name = block_names.get(block_id, "Unknown Block")
|
||||
self.add_error(
|
||||
f"Invalid nested sink link '{sink_name}' for "
|
||||
f"node '{sink_id}' (block "
|
||||
f"'{block_name}' - {block_id}): Parent property "
|
||||
f"'{parent}' does not exist in the block's "
|
||||
f"input schema."
|
||||
)
|
||||
valid = False
|
||||
continue
|
||||
|
||||
# Check if additionalProperties is allowed either directly
|
||||
# or via anyOf
|
||||
allows_additional_properties = parent_schema.get(
|
||||
"additionalProperties", False
|
||||
)
|
||||
|
||||
# Check anyOf for additionalProperties
|
||||
if not allows_additional_properties and "anyOf" in parent_schema:
|
||||
any_of_schemas = parent_schema.get("anyOf", [])
|
||||
if isinstance(any_of_schemas, list):
|
||||
for schema_option in any_of_schemas:
|
||||
if isinstance(schema_option, dict) and schema_option.get(
|
||||
"additionalProperties"
|
||||
):
|
||||
allows_additional_properties = True
|
||||
break
|
||||
|
||||
if not allows_additional_properties:
|
||||
if not (
|
||||
isinstance(parent_schema, dict)
|
||||
and "properties" in parent_schema
|
||||
and isinstance(parent_schema["properties"], dict)
|
||||
and child in parent_schema["properties"]
|
||||
):
|
||||
block_name = block_names.get(block_id, "Unknown Block")
|
||||
self.add_error(
|
||||
f"Invalid nested sink link '{sink_name}' "
|
||||
f"for node '{link.get('sink_id', '')}' (block "
|
||||
f"'{block_name}' - {block_id}): Child "
|
||||
f"property '{child}' does not exist in "
|
||||
f"parent '{parent}' schema. Available "
|
||||
f"properties: "
|
||||
f"{list(parent_schema.get('properties', {}).keys())}"
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_prompt_double_curly_braces_spaces(self, agent: AgentDict) -> bool:
|
||||
"""
|
||||
Validate that prompt parameters do not contain spaces in double curly
|
||||
braces.
|
||||
|
||||
Checks the 'prompt' parameter in input_default of each node and reports
|
||||
errors if values within double curly braces ({{...}}) contain spaces.
|
||||
For example, {{user name}} should be {{user_name}}.
|
||||
|
||||
Args:
|
||||
agent: The agent dictionary to validate
|
||||
|
||||
Returns:
|
||||
True if all prompts are valid (no spaces in double curly braces),
|
||||
False otherwise
|
||||
"""
|
||||
valid = True
|
||||
nodes = agent.get("nodes", [])
|
||||
|
||||
for node in nodes:
|
||||
node_id = node.get("id")
|
||||
input_default = node.get("input_default", {})
|
||||
|
||||
# Check if 'prompt' parameter exists
|
||||
if "prompt" not in input_default:
|
||||
continue
|
||||
|
||||
prompt_text = input_default["prompt"]
|
||||
|
||||
# Only process if it's a string
|
||||
if not isinstance(prompt_text, str):
|
||||
continue
|
||||
|
||||
# Find all double curly brace patterns with spaces
|
||||
matches = re.finditer(r"\{\{([^}]+)\}\}", prompt_text)
|
||||
|
||||
for match in matches:
|
||||
content = match.group(1)
|
||||
if " " in content:
|
||||
start_pos = match.start()
|
||||
snippet_start = max(0, start_pos - 30)
|
||||
snippet_end = min(len(prompt_text), match.end() + 30)
|
||||
snippet = prompt_text[snippet_start:snippet_end]
|
||||
|
||||
self.add_error(
|
||||
f"Node '{node_id}' has spaces in double curly "
|
||||
f"braces in prompt parameter: "
|
||||
f"'{{{{{content}}}}}' should be "
|
||||
f"'{{{{{content.replace(' ', '_')}}}}}'. "
|
||||
f"Context: ...{snippet}..."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_source_output_existence(
|
||||
self, agent: AgentDict, blocks: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that all source_names in links exist in the corresponding
|
||||
block's output schema.
|
||||
|
||||
Checks that for each link, the source_name field references a valid
|
||||
output property in the source block's outputSchema. Also handles nested
|
||||
outputs with _#_ notation.
|
||||
|
||||
Args:
|
||||
agent: The agent dictionary to validate
|
||||
blocks: List of available blocks with their schemas
|
||||
|
||||
Returns:
|
||||
True if all source output fields exist, False otherwise
|
||||
"""
|
||||
valid = True
|
||||
|
||||
# Create lookup dictionaries for efficiency
|
||||
block_output_schemas = {
|
||||
block.get("id", ""): block.get("outputSchema", {}).get("properties", {})
|
||||
for block in blocks
|
||||
}
|
||||
block_names = {
|
||||
block.get("id", ""): block.get("name", "Unknown Block") for block in blocks
|
||||
}
|
||||
node_lookup = {node.get("id", ""): node for node in agent.get("nodes", [])}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
source_id = link.get("source_id")
|
||||
source_name = link.get("source_name", "")
|
||||
link_id = link.get("id", "Unknown")
|
||||
|
||||
if not source_name:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' is missing 'source_name'. "
|
||||
f"Every link must specify which output field to read from."
|
||||
)
|
||||
valid = False
|
||||
continue
|
||||
|
||||
source_node = node_lookup.get(source_id)
|
||||
if not source_node:
|
||||
# This error is already caught by
|
||||
# validate_link_node_references
|
||||
continue
|
||||
|
||||
block_id = source_node.get("block_id")
|
||||
block_name = block_names.get(block_id, "Unknown Block")
|
||||
|
||||
# Special handling for AgentExecutorBlock - use dynamic
|
||||
# output_schema from input_default
|
||||
if block_id == AGENT_EXECUTOR_BLOCK_ID:
|
||||
input_default = source_node.get("input_default", {})
|
||||
dynamic_output_schema = input_default.get("output_schema", {})
|
||||
if not isinstance(dynamic_output_schema, dict):
|
||||
dynamic_output_schema = {}
|
||||
output_props = dynamic_output_schema.get("properties", {})
|
||||
if not isinstance(output_props, dict):
|
||||
output_props = {}
|
||||
else:
|
||||
output_props = block_output_schemas.get(block_id, {})
|
||||
|
||||
# Handle nested source names (with _#_ notation)
|
||||
if "_#_" in source_name:
|
||||
parent, child = source_name.split("_#_", 1)
|
||||
|
||||
parent_schema = output_props.get(parent)
|
||||
if not parent_schema:
|
||||
self.add_error(
|
||||
f"Invalid source output field '{source_name}' "
|
||||
f"in link '{link_id}' from node '{source_id}' "
|
||||
f"(block '{block_name}' - {block_id}): Parent "
|
||||
f"property '{parent}' does not exist in the "
|
||||
f"block's output schema."
|
||||
)
|
||||
valid = False
|
||||
continue
|
||||
|
||||
# Check if additionalProperties is allowed either directly
|
||||
# or via anyOf
|
||||
allows_additional_properties = parent_schema.get(
|
||||
"additionalProperties", False
|
||||
)
|
||||
if not allows_additional_properties and "anyOf" in parent_schema:
|
||||
any_of_schemas = parent_schema.get("anyOf", [])
|
||||
if isinstance(any_of_schemas, list):
|
||||
for schema_option in any_of_schemas:
|
||||
if isinstance(schema_option, dict) and schema_option.get(
|
||||
"additionalProperties"
|
||||
):
|
||||
allows_additional_properties = True
|
||||
break
|
||||
# Also allow when items have
|
||||
# additionalProperties (array of objects)
|
||||
if (
|
||||
isinstance(schema_option, dict)
|
||||
and "items" in schema_option
|
||||
):
|
||||
items_schema = schema_option.get("items")
|
||||
if isinstance(items_schema, dict) and items_schema.get(
|
||||
"additionalProperties"
|
||||
):
|
||||
allows_additional_properties = True
|
||||
break
|
||||
|
||||
# Only require child in properties when
|
||||
# additionalProperties is not allowed
|
||||
if not allows_additional_properties:
|
||||
if not (
|
||||
isinstance(parent_schema, dict)
|
||||
and "properties" in parent_schema
|
||||
and isinstance(parent_schema["properties"], dict)
|
||||
and child in parent_schema["properties"]
|
||||
):
|
||||
available_props = (
|
||||
list(parent_schema.get("properties", {}).keys())
|
||||
if isinstance(parent_schema, dict)
|
||||
else []
|
||||
)
|
||||
self.add_error(
|
||||
f"Invalid nested source output field "
|
||||
f"'{source_name}' in link '{link_id}' from "
|
||||
f"node '{source_id}' (block "
|
||||
f"'{block_name}' - {block_id}): Child "
|
||||
f"property '{child}' does not exist in "
|
||||
f"parent '{parent}' output schema. "
|
||||
f"Available properties: {available_props}"
|
||||
)
|
||||
valid = False
|
||||
else:
|
||||
# Check simple (non-nested) source name
|
||||
if source_name not in output_props:
|
||||
available_outputs = list(output_props.keys())
|
||||
self.add_error(
|
||||
f"Invalid source output field '{source_name}' "
|
||||
f"in link '{link_id}' from node '{source_id}' "
|
||||
f"(block '{block_name}' - {block_id}): Output "
|
||||
f"property '{source_name}' does not exist in "
|
||||
f"the block's output schema. Available outputs: "
|
||||
f"{available_outputs}"
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_io_blocks(self, agent: AgentDict) -> bool:
|
||||
"""
|
||||
Validate that the agent has at least one AgentInputBlock and one
|
||||
AgentOutputBlock. These blocks define the agent's interface.
|
||||
|
||||
Returns True if both are present, False otherwise.
|
||||
"""
|
||||
valid = True
|
||||
block_ids = {node.get("block_id") for node in agent.get("nodes", [])}
|
||||
|
||||
if AGENT_INPUT_BLOCK_ID not in block_ids:
|
||||
self.add_error(
|
||||
f"Agent is missing an AgentInputBlock (block_id: "
|
||||
f"'{AGENT_INPUT_BLOCK_ID}'). Every agent must have at "
|
||||
f"least one AgentInputBlock to define user-facing inputs. "
|
||||
f"Add a node with block_id '{AGENT_INPUT_BLOCK_ID}' and "
|
||||
f"set input_default with 'name' and optionally 'title'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
if AGENT_OUTPUT_BLOCK_ID not in block_ids:
|
||||
self.add_error(
|
||||
f"Agent is missing an AgentOutputBlock (block_id: "
|
||||
f"'{AGENT_OUTPUT_BLOCK_ID}'). Every agent must have at "
|
||||
f"least one AgentOutputBlock to define user-facing outputs. "
|
||||
f"Add a node with block_id '{AGENT_OUTPUT_BLOCK_ID}' and "
|
||||
f"set input_default with 'name', then link 'value' from "
|
||||
f"another block's output."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_agent_executor_blocks(
|
||||
self,
|
||||
agent: AgentDict,
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Validate AgentExecutorBlock nodes have required fields and valid
|
||||
references.
|
||||
|
||||
Checks that AgentExecutorBlock nodes:
|
||||
1. Have a valid graph_id in input_default (required)
|
||||
2. If graph_id matches a known library agent, validates version
|
||||
consistency
|
||||
3. Sub-agent required inputs are connected via links (not hardcoded)
|
||||
|
||||
Note: Unknown graph_ids are not treated as errors - they could be valid
|
||||
direct references to agents by their actual ID (not via library_agents).
|
||||
This is consistent with fix_agent_executor_blocks() behavior.
|
||||
|
||||
Args:
|
||||
agent: The agent dictionary to validate
|
||||
library_agents: List of available library agents (for version
|
||||
validation)
|
||||
|
||||
Returns:
|
||||
True if all AgentExecutorBlock nodes are valid, False otherwise
|
||||
"""
|
||||
valid = True
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
|
||||
# Create lookup for library agents
|
||||
library_agent_lookup: dict[str, dict[str, Any]] = {}
|
||||
if library_agents:
|
||||
library_agent_lookup = {la.get("graph_id", ""): la for la in library_agents}
|
||||
|
||||
for node in nodes:
|
||||
if node.get("block_id") != AGENT_EXECUTOR_BLOCK_ID:
|
||||
continue
|
||||
|
||||
node_id = node.get("id")
|
||||
input_default = node.get("input_default", {})
|
||||
|
||||
# Check for required graph_id
|
||||
graph_id = input_default.get("graph_id")
|
||||
if not graph_id:
|
||||
self.add_error(
|
||||
f"AgentExecutorBlock node '{node_id}' is missing "
|
||||
f"required 'graph_id' in input_default. This field "
|
||||
f"must reference the ID of the sub-agent to execute."
|
||||
)
|
||||
valid = False
|
||||
continue
|
||||
|
||||
# If graph_id is not in library_agent_lookup, skip validation
|
||||
if graph_id not in library_agent_lookup:
|
||||
continue
|
||||
|
||||
# Validate version consistency for known library agents
|
||||
library_agent = library_agent_lookup[graph_id]
|
||||
expected_version = library_agent.get("graph_version")
|
||||
current_version = input_default.get("graph_version")
|
||||
if (
|
||||
current_version
|
||||
and expected_version
|
||||
and current_version != expected_version
|
||||
):
|
||||
self.add_error(
|
||||
f"AgentExecutorBlock node '{node_id}' has mismatched "
|
||||
f"graph_version: got {current_version}, expected "
|
||||
f"{expected_version} for library agent "
|
||||
f"'{library_agent.get('name')}'"
|
||||
)
|
||||
valid = False
|
||||
|
||||
# Validate sub-agent inputs are properly linked (not hardcoded)
|
||||
sub_agent_input_schema = library_agent.get("input_schema", {})
|
||||
if not isinstance(sub_agent_input_schema, dict):
|
||||
sub_agent_input_schema = {}
|
||||
sub_agent_required_inputs = sub_agent_input_schema.get("required", [])
|
||||
sub_agent_properties = sub_agent_input_schema.get("properties", {})
|
||||
|
||||
# Get all linked inputs to this node
|
||||
linked_sub_agent_inputs: set[str] = set()
|
||||
for link in links:
|
||||
if link.get("sink_id") == node_id:
|
||||
sink_name = link.get("sink_name", "")
|
||||
if sink_name in sub_agent_properties:
|
||||
linked_sub_agent_inputs.add(sink_name)
|
||||
|
||||
# Check for hardcoded inputs that should be linked
|
||||
hardcoded_inputs = input_default.get("inputs", {})
|
||||
input_schema = input_default.get("input_schema", {})
|
||||
schema_properties = (
|
||||
input_schema.get("properties", {})
|
||||
if isinstance(input_schema, dict)
|
||||
else {}
|
||||
)
|
||||
if isinstance(hardcoded_inputs, dict) and hardcoded_inputs:
|
||||
for input_name, value in hardcoded_inputs.items():
|
||||
if input_name not in sub_agent_properties:
|
||||
continue
|
||||
if value is None:
|
||||
continue
|
||||
# Skip if this input is already linked
|
||||
if input_name in linked_sub_agent_inputs:
|
||||
continue
|
||||
prop_schema = schema_properties.get(input_name, {})
|
||||
schema_default = (
|
||||
prop_schema.get("default")
|
||||
if isinstance(prop_schema, dict)
|
||||
else None
|
||||
)
|
||||
if schema_default is not None and self._values_equal(
|
||||
value, schema_default
|
||||
):
|
||||
continue
|
||||
# This is a non-default hardcoded value without a link
|
||||
self.add_error(
|
||||
f"AgentExecutorBlock node '{node_id}' has "
|
||||
f"hardcoded input '{input_name}'. Sub-agent "
|
||||
f"inputs should be connected via links using "
|
||||
f"'{input_name}' as sink_name, not hardcoded "
|
||||
f"in input_default.inputs. Create a link from "
|
||||
f"the appropriate source node."
|
||||
)
|
||||
valid = False
|
||||
|
||||
# Check for missing required sub-agent inputs.
|
||||
# An input is satisfied if it is linked OR has an allowed
|
||||
# hardcoded value (i.e. equals the schema default — the
|
||||
# previous check already flags non-default hardcoded values).
|
||||
hardcoded_inputs_dict = (
|
||||
hardcoded_inputs if isinstance(hardcoded_inputs, dict) else {}
|
||||
)
|
||||
for req_input in sub_agent_required_inputs:
|
||||
if req_input in linked_sub_agent_inputs:
|
||||
continue
|
||||
# Check if fixer populated it with a schema default value
|
||||
if req_input in hardcoded_inputs_dict:
|
||||
prop_schema = schema_properties.get(req_input, {})
|
||||
schema_default = (
|
||||
prop_schema.get("default")
|
||||
if isinstance(prop_schema, dict)
|
||||
else None
|
||||
)
|
||||
if schema_default is not None and self._values_equal(
|
||||
hardcoded_inputs_dict[req_input], schema_default
|
||||
):
|
||||
continue
|
||||
self.add_error(
|
||||
f"AgentExecutorBlock node '{node_id}' is "
|
||||
f"missing required sub-agent input "
|
||||
f"'{req_input}'. Create a link to this node "
|
||||
f"using sink_name '{req_input}' to connect "
|
||||
f"the input."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_agent_executor_block_schemas(
|
||||
self,
|
||||
agent: AgentDict,
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that AgentExecutorBlock nodes have valid input_schema and
|
||||
output_schema.
|
||||
|
||||
This validation runs regardless of library_agents availability and
|
||||
ensures that the schemas are properly populated to prevent frontend
|
||||
crashes.
|
||||
|
||||
Args:
|
||||
agent: The agent dictionary to validate
|
||||
|
||||
Returns:
|
||||
True if all AgentExecutorBlock nodes have valid schemas, False
|
||||
otherwise
|
||||
"""
|
||||
valid = True
|
||||
nodes = agent.get("nodes", [])
|
||||
|
||||
for node in nodes:
|
||||
if node.get("block_id") != AGENT_EXECUTOR_BLOCK_ID:
|
||||
continue
|
||||
|
||||
node_id = node.get("id")
|
||||
input_default = node.get("input_default", {})
|
||||
customized_name = (node.get("metadata") or {}).get(
|
||||
"customized_name", "Unknown"
|
||||
)
|
||||
|
||||
# Check input_schema
|
||||
input_schema = input_default.get("input_schema")
|
||||
if input_schema is None or not isinstance(input_schema, dict):
|
||||
self.add_error(
|
||||
f"AgentExecutorBlock node '{node_id}' "
|
||||
f"({customized_name}) has missing or invalid "
|
||||
f"input_schema. The input_schema must be a valid "
|
||||
f"JSON Schema object with 'properties' and "
|
||||
f"'required' fields."
|
||||
)
|
||||
valid = False
|
||||
elif not input_schema.get("properties") and not input_schema.get("type"):
|
||||
# Empty schema like {} is invalid
|
||||
self.add_error(
|
||||
f"AgentExecutorBlock node '{node_id}' "
|
||||
f"({customized_name}) has empty input_schema. The "
|
||||
f"input_schema must define the sub-agent's expected "
|
||||
f"inputs. This usually indicates the sub-agent "
|
||||
f"reference is incomplete or the library agent was "
|
||||
f"not properly passed."
|
||||
)
|
||||
valid = False
|
||||
|
||||
# Check output_schema
|
||||
output_schema = input_default.get("output_schema")
|
||||
if output_schema is None or not isinstance(output_schema, dict):
|
||||
self.add_error(
|
||||
f"AgentExecutorBlock node '{node_id}' "
|
||||
f"({customized_name}) has missing or invalid "
|
||||
f"output_schema. The output_schema must be a valid "
|
||||
f"JSON Schema object defining the sub-agent's "
|
||||
f"outputs."
|
||||
)
|
||||
valid = False
|
||||
elif not output_schema.get("properties") and not output_schema.get("type"):
|
||||
# Empty schema like {} is invalid
|
||||
self.add_error(
|
||||
f"AgentExecutorBlock node '{node_id}' "
|
||||
f"({customized_name}) has empty output_schema. "
|
||||
f"The output_schema must define the sub-agent's "
|
||||
f"expected outputs. This usually indicates the "
|
||||
f"sub-agent reference is incomplete or the library "
|
||||
f"agent was not properly passed."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_mcp_tool_blocks(self, agent: AgentDict) -> bool:
|
||||
"""Validate that MCPToolBlock nodes have required fields.
|
||||
|
||||
Checks that each MCPToolBlock node has:
|
||||
1. A non-empty `server_url` in input_default
|
||||
2. A non-empty `selected_tool` in input_default
|
||||
|
||||
Returns True if all MCPToolBlock nodes are valid, False otherwise.
|
||||
"""
|
||||
valid = True
|
||||
nodes = agent.get("nodes", [])
|
||||
|
||||
for node in nodes:
|
||||
if node.get("block_id") != MCP_TOOL_BLOCK_ID:
|
||||
continue
|
||||
|
||||
node_id = node.get("id", "unknown")
|
||||
input_default = node.get("input_default", {})
|
||||
customized_name = (node.get("metadata") or {}).get(
|
||||
"customized_name", node_id
|
||||
)
|
||||
|
||||
server_url = input_default.get("server_url")
|
||||
if not server_url:
|
||||
self.add_error(
|
||||
f"MCPToolBlock node '{customized_name}' ({node_id}) is "
|
||||
f"missing required 'server_url' in input_default. "
|
||||
f"Set this to the MCP server URL "
|
||||
f"(e.g. 'https://mcp.example.com/sse')."
|
||||
)
|
||||
valid = False
|
||||
|
||||
selected_tool = input_default.get("selected_tool")
|
||||
if not selected_tool:
|
||||
self.add_error(
|
||||
f"MCPToolBlock node '{customized_name}' ({node_id}) is "
|
||||
f"missing required 'selected_tool' in input_default. "
|
||||
f"Set this to the name of the MCP tool to execute."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate(
|
||||
self,
|
||||
agent: AgentDict,
|
||||
blocks: list[dict[str, Any]],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Comprehensive validation of an agent against available blocks.
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str]]: (is_valid, error_message)
|
||||
- is_valid: True if agent passes all validations, False otherwise
|
||||
- error_message: Detailed error message if validation fails, None
|
||||
if successful
|
||||
"""
|
||||
logger.info("Validating agent...")
|
||||
self.errors = []
|
||||
|
||||
checks = [
|
||||
(
|
||||
"Block existence",
|
||||
self.validate_block_existence(agent, blocks),
|
||||
),
|
||||
(
|
||||
"Link node references",
|
||||
self.validate_link_node_references(agent),
|
||||
),
|
||||
(
|
||||
"Required inputs",
|
||||
self.validate_required_inputs(agent, blocks),
|
||||
),
|
||||
(
|
||||
"Data type compatibility",
|
||||
self.validate_data_type_compatibility(agent, blocks),
|
||||
),
|
||||
(
|
||||
"Nested sink links",
|
||||
self.validate_nested_sink_links(agent, blocks),
|
||||
),
|
||||
(
|
||||
"Source output existence",
|
||||
self.validate_source_output_existence(agent, blocks),
|
||||
),
|
||||
(
|
||||
"Prompt double curly braces spaces",
|
||||
self.validate_prompt_double_curly_braces_spaces(agent),
|
||||
),
|
||||
(
|
||||
"IO blocks",
|
||||
self.validate_io_blocks(agent),
|
||||
),
|
||||
# Always validate AgentExecutorBlock schemas to prevent
|
||||
# frontend crashes
|
||||
(
|
||||
"AgentExecutorBlock schemas",
|
||||
self.validate_agent_executor_block_schemas(agent),
|
||||
),
|
||||
(
|
||||
"MCP tool blocks",
|
||||
self.validate_mcp_tool_blocks(agent),
|
||||
),
|
||||
]
|
||||
|
||||
# Add AgentExecutorBlock detailed validation if library_agents
|
||||
# provided
|
||||
if library_agents:
|
||||
checks.append(
|
||||
(
|
||||
"AgentExecutorBlock references",
|
||||
self.validate_agent_executor_blocks(agent, library_agents),
|
||||
)
|
||||
)
|
||||
|
||||
all_passed = all(check[1] for check in checks)
|
||||
|
||||
if all_passed:
|
||||
logger.info("Agent validation successful.")
|
||||
return True, None
|
||||
else:
|
||||
error_message = "Agent validation failed with the following errors:\n\n"
|
||||
for i, error in enumerate(self.errors, 1):
|
||||
error_message += f"{i}. {error}\n"
|
||||
|
||||
logger.error(f"Agent validation failed: {error_message}")
|
||||
return False, error_message
|
||||
@@ -1,710 +0,0 @@
|
||||
"""Unit tests for AgentValidator."""
|
||||
|
||||
from .helpers import (
|
||||
AGENT_EXECUTOR_BLOCK_ID,
|
||||
AGENT_INPUT_BLOCK_ID,
|
||||
AGENT_OUTPUT_BLOCK_ID,
|
||||
MCP_TOOL_BLOCK_ID,
|
||||
generate_uuid,
|
||||
)
|
||||
from .validator import AgentValidator
|
||||
|
||||
|
||||
def _make_agent(
|
||||
nodes: list | None = None,
|
||||
links: list | None = None,
|
||||
agent_id: str | None = None,
|
||||
) -> dict:
|
||||
"""Build a minimal agent dict for testing."""
|
||||
return {
|
||||
"id": agent_id or generate_uuid(),
|
||||
"name": "Test Agent",
|
||||
"nodes": nodes or [],
|
||||
"links": links or [],
|
||||
}
|
||||
|
||||
|
||||
def _make_node(
|
||||
node_id: str | None = None,
|
||||
block_id: str = "block-1",
|
||||
input_default: dict | None = None,
|
||||
position: tuple[int, int] = (0, 0),
|
||||
) -> dict:
|
||||
return {
|
||||
"id": node_id or generate_uuid(),
|
||||
"block_id": block_id,
|
||||
"input_default": input_default or {},
|
||||
"metadata": {"position": {"x": position[0], "y": position[1]}},
|
||||
}
|
||||
|
||||
|
||||
def _make_link(
|
||||
link_id: str | None = None,
|
||||
source_id: str = "",
|
||||
source_name: str = "output",
|
||||
sink_id: str = "",
|
||||
sink_name: str = "input",
|
||||
) -> dict:
|
||||
return {
|
||||
"id": link_id or generate_uuid(),
|
||||
"source_id": source_id,
|
||||
"source_name": source_name,
|
||||
"sink_id": sink_id,
|
||||
"sink_name": sink_name,
|
||||
}
|
||||
|
||||
|
||||
def _make_block(
|
||||
block_id: str = "block-1",
|
||||
name: str = "TestBlock",
|
||||
input_schema: dict | None = None,
|
||||
output_schema: dict | None = None,
|
||||
categories: list | None = None,
|
||||
static_output: bool = False,
|
||||
) -> dict:
|
||||
return {
|
||||
"id": block_id,
|
||||
"name": name,
|
||||
"inputSchema": input_schema or {"properties": {}, "required": []},
|
||||
"outputSchema": output_schema or {"properties": {}},
|
||||
"categories": categories or [],
|
||||
"staticOutput": static_output,
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_block_existence
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidateBlockExistence:
|
||||
def test_valid_blocks_pass(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(block_id="b1")
|
||||
block = _make_block(block_id="b1")
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_block_existence(agent, [block]) is True
|
||||
assert v.errors == []
|
||||
|
||||
def test_missing_block_fails(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(block_id="nonexistent")
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_block_existence(agent, []) is False
|
||||
assert len(v.errors) == 1
|
||||
assert "does not exist" in v.errors[0]
|
||||
|
||||
def test_missing_block_id_field(self):
|
||||
v = AgentValidator()
|
||||
node = {"id": "n1", "input_default": {}, "metadata": {}}
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_block_existence(agent, []) is False
|
||||
assert "missing a 'block_id'" in v.errors[0]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_link_node_references
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidateLinkNodeReferences:
|
||||
def test_valid_references_pass(self):
|
||||
v = AgentValidator()
|
||||
n1 = _make_node(node_id="n1")
|
||||
n2 = _make_node(node_id="n2")
|
||||
link = _make_link(source_id="n1", sink_id="n2")
|
||||
agent = _make_agent(nodes=[n1, n2], links=[link])
|
||||
|
||||
assert v.validate_link_node_references(agent) is True
|
||||
assert v.errors == []
|
||||
|
||||
def test_invalid_source_fails(self):
|
||||
v = AgentValidator()
|
||||
n1 = _make_node(node_id="n1")
|
||||
link = _make_link(source_id="missing", sink_id="n1")
|
||||
agent = _make_agent(nodes=[n1], links=[link])
|
||||
|
||||
assert v.validate_link_node_references(agent) is False
|
||||
assert any("source_id" in e for e in v.errors)
|
||||
|
||||
def test_invalid_sink_fails(self):
|
||||
v = AgentValidator()
|
||||
n1 = _make_node(node_id="n1")
|
||||
link = _make_link(source_id="n1", sink_id="missing")
|
||||
agent = _make_agent(nodes=[n1], links=[link])
|
||||
|
||||
assert v.validate_link_node_references(agent) is False
|
||||
assert any("sink_id" in e for e in v.errors)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_required_inputs
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidateRequiredInputs:
|
||||
def test_satisfied_by_default_passes(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {"url": {"type": "string"}},
|
||||
"required": ["url"],
|
||||
},
|
||||
)
|
||||
node = _make_node(block_id="b1", input_default={"url": "http://example.com"})
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_required_inputs(agent, [block]) is True
|
||||
assert v.errors == []
|
||||
|
||||
def test_satisfied_by_link_passes(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {"url": {"type": "string"}},
|
||||
"required": ["url"],
|
||||
},
|
||||
)
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(source_id="n2", sink_id="n1", sink_name="url")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_required_inputs(agent, [block]) is True
|
||||
|
||||
def test_missing_required_input_fails(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {"url": {"type": "string"}},
|
||||
"required": ["url"],
|
||||
},
|
||||
)
|
||||
node = _make_node(block_id="b1", input_default={})
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_required_inputs(agent, [block]) is False
|
||||
assert any("missing required input" in e for e in v.errors)
|
||||
|
||||
def test_credentials_always_allowed_missing(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {"credentials": {"type": "object"}},
|
||||
"required": ["credentials"],
|
||||
},
|
||||
)
|
||||
node = _make_node(block_id="b1", input_default={})
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_required_inputs(agent, [block]) is True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_data_type_compatibility
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidateDataTypeCompatibility:
|
||||
def test_matching_types_pass(self):
|
||||
v = AgentValidator()
|
||||
src_block = _make_block(
|
||||
block_id="src-b",
|
||||
output_schema={"properties": {"out": {"type": "string"}}},
|
||||
)
|
||||
sink_block = _make_block(
|
||||
block_id="sink-b",
|
||||
input_schema={"properties": {"inp": {"type": "string"}}, "required": []},
|
||||
)
|
||||
src_node = _make_node(node_id="n1", block_id="src-b")
|
||||
sink_node = _make_node(node_id="n2", block_id="sink-b")
|
||||
link = _make_link(
|
||||
source_id="n1", source_name="out", sink_id="n2", sink_name="inp"
|
||||
)
|
||||
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
|
||||
|
||||
assert (
|
||||
v.validate_data_type_compatibility(agent, [src_block, sink_block]) is True
|
||||
)
|
||||
|
||||
def test_int_number_compatible(self):
|
||||
v = AgentValidator()
|
||||
src_block = _make_block(
|
||||
block_id="src-b",
|
||||
output_schema={"properties": {"out": {"type": "integer"}}},
|
||||
)
|
||||
sink_block = _make_block(
|
||||
block_id="sink-b",
|
||||
input_schema={"properties": {"inp": {"type": "number"}}, "required": []},
|
||||
)
|
||||
src_node = _make_node(node_id="n1", block_id="src-b")
|
||||
sink_node = _make_node(node_id="n2", block_id="sink-b")
|
||||
link = _make_link(
|
||||
source_id="n1", source_name="out", sink_id="n2", sink_name="inp"
|
||||
)
|
||||
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
|
||||
|
||||
assert (
|
||||
v.validate_data_type_compatibility(agent, [src_block, sink_block]) is True
|
||||
)
|
||||
|
||||
def test_mismatched_types_fail(self):
|
||||
v = AgentValidator()
|
||||
src_block = _make_block(
|
||||
block_id="src-b",
|
||||
output_schema={"properties": {"out": {"type": "string"}}},
|
||||
)
|
||||
sink_block = _make_block(
|
||||
block_id="sink-b",
|
||||
input_schema={"properties": {"inp": {"type": "integer"}}, "required": []},
|
||||
)
|
||||
src_node = _make_node(node_id="n1", block_id="src-b")
|
||||
sink_node = _make_node(node_id="n2", block_id="sink-b")
|
||||
link = _make_link(
|
||||
source_id="n1", source_name="out", sink_id="n2", sink_name="inp"
|
||||
)
|
||||
agent = _make_agent(nodes=[src_node, sink_node], links=[link])
|
||||
|
||||
assert (
|
||||
v.validate_data_type_compatibility(agent, [src_block, sink_block]) is False
|
||||
)
|
||||
assert any("mismatch" in e.lower() for e in v.errors)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_source_output_existence
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidateSourceOutputExistence:
|
||||
def test_valid_source_output_passes(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
output_schema={"properties": {"result": {"type": "string"}}},
|
||||
)
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(source_id="n1", source_name="result", sink_id="n2")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_source_output_existence(agent, [block]) is True
|
||||
|
||||
def test_invalid_source_output_fails(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
output_schema={"properties": {"result": {"type": "string"}}},
|
||||
)
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(source_id="n1", source_name="nonexistent", sink_id="n2")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_source_output_existence(agent, [block]) is False
|
||||
assert any("does not exist" in e for e in v.errors)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_prompt_double_curly_braces_spaces
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidatePromptDoubleCurlyBracesSpaces:
|
||||
def test_no_spaces_passes(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(input_default={"prompt": "Hello {{name}}!"})
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_prompt_double_curly_braces_spaces(agent) is True
|
||||
|
||||
def test_spaces_in_braces_fails(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(input_default={"prompt": "Hello {{user name}}!"})
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_prompt_double_curly_braces_spaces(agent) is False
|
||||
assert any("spaces" in e for e in v.errors)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_nested_sink_links
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidateNestedSinkLinks:
|
||||
def test_valid_nested_link_passes(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {"key": {"type": "string"}},
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(sink_id="n1", sink_name="config_#_key", source_id="n2")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_nested_sink_links(agent, [block]) is True
|
||||
|
||||
def test_invalid_parent_fails(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(block_id="b1")
|
||||
node = _make_node(node_id="n1", block_id="b1")
|
||||
link = _make_link(sink_id="n1", sink_name="nonexistent_#_key", source_id="n2")
|
||||
agent = _make_agent(nodes=[node], links=[link])
|
||||
|
||||
assert v.validate_nested_sink_links(agent, [block]) is False
|
||||
assert any("does not exist" in e for e in v.errors)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_agent_executor_block_schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidateAgentExecutorBlockSchemas:
|
||||
def test_valid_schemas_pass(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_default={
|
||||
"graph_id": generate_uuid(),
|
||||
"input_schema": {"properties": {"q": {"type": "string"}}},
|
||||
"output_schema": {"properties": {"result": {"type": "string"}}},
|
||||
},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_agent_executor_block_schemas(agent) is True
|
||||
assert v.errors == []
|
||||
|
||||
def test_empty_input_schema_fails(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_default={
|
||||
"graph_id": generate_uuid(),
|
||||
"input_schema": {},
|
||||
"output_schema": {"properties": {"result": {"type": "string"}}},
|
||||
},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_agent_executor_block_schemas(agent) is False
|
||||
assert any("empty input_schema" in e for e in v.errors)
|
||||
|
||||
def test_missing_output_schema_fails(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_default={
|
||||
"graph_id": generate_uuid(),
|
||||
"input_schema": {"properties": {"q": {"type": "string"}}},
|
||||
},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_agent_executor_block_schemas(agent) is False
|
||||
assert any("output_schema" in e for e in v.errors)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_agent_executor_blocks
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidateAgentExecutorBlocks:
|
||||
def test_missing_graph_id_fails(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_default={},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_agent_executor_blocks(agent) is False
|
||||
assert any("graph_id" in e for e in v.errors)
|
||||
|
||||
def test_valid_graph_id_passes(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_default={"graph_id": generate_uuid()},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_agent_executor_blocks(agent) is True
|
||||
|
||||
def test_version_mismatch_with_library_agent(self):
|
||||
v = AgentValidator()
|
||||
lib_id = generate_uuid()
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_default={"graph_id": lib_id, "graph_version": 1},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
library_agents = [{"graph_id": lib_id, "graph_version": 3, "name": "Sub Agent"}]
|
||||
|
||||
assert v.validate_agent_executor_blocks(agent, library_agents) is False
|
||||
assert any("mismatched graph_version" in e for e in v.errors)
|
||||
|
||||
def test_required_input_satisfied_by_schema_default_passes(self):
|
||||
"""Required sub-agent inputs filled with their schema default by the fixer
|
||||
should NOT be flagged as missing."""
|
||||
v = AgentValidator()
|
||||
lib_id = generate_uuid()
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_default={
|
||||
"graph_id": lib_id,
|
||||
"input_schema": {
|
||||
"properties": {"mode": {"type": "string", "default": "fast"}}
|
||||
},
|
||||
"inputs": {"mode": "fast"}, # fixer populated with schema default
|
||||
},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
library_agents = [
|
||||
{
|
||||
"graph_id": lib_id,
|
||||
"graph_version": 1,
|
||||
"name": "Sub",
|
||||
"input_schema": {
|
||||
"required": ["mode"],
|
||||
"properties": {"mode": {"type": "string", "default": "fast"}},
|
||||
},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
assert v.validate_agent_executor_blocks(agent, library_agents) is True
|
||||
assert v.errors == []
|
||||
|
||||
def test_required_input_not_linked_and_no_default_fails(self):
|
||||
"""Required sub-agent inputs without a link or schema default must fail."""
|
||||
v = AgentValidator()
|
||||
lib_id = generate_uuid()
|
||||
node = _make_node(
|
||||
node_id="n1",
|
||||
block_id=AGENT_EXECUTOR_BLOCK_ID,
|
||||
input_default={
|
||||
"graph_id": lib_id,
|
||||
"input_schema": {"properties": {"query": {"type": "string"}}},
|
||||
"inputs": {},
|
||||
},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
library_agents = [
|
||||
{
|
||||
"graph_id": lib_id,
|
||||
"graph_version": 1,
|
||||
"name": "Sub",
|
||||
"input_schema": {
|
||||
"required": ["query"],
|
||||
"properties": {"query": {"type": "string"}},
|
||||
},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
assert v.validate_agent_executor_blocks(agent, library_agents) is False
|
||||
assert any("missing required sub-agent input" in e for e in v.errors)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate_io_blocks
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidateIoBlocks:
|
||||
def test_missing_input_block_reports_error(self):
|
||||
v = AgentValidator()
|
||||
# Agent has output block but no input block
|
||||
node = _make_node(block_id=AGENT_OUTPUT_BLOCK_ID)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_io_blocks(agent) is False
|
||||
assert len(v.errors) == 1
|
||||
assert "AgentInputBlock" in v.errors[0]
|
||||
|
||||
def test_missing_output_block_reports_error(self):
|
||||
v = AgentValidator()
|
||||
# Agent has input block but no output block
|
||||
node = _make_node(block_id=AGENT_INPUT_BLOCK_ID)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_io_blocks(agent) is False
|
||||
assert len(v.errors) == 1
|
||||
assert "AgentOutputBlock" in v.errors[0]
|
||||
|
||||
def test_missing_both_io_blocks_reports_two_errors(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(block_id="some-other-block")
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
assert v.validate_io_blocks(agent) is False
|
||||
assert len(v.errors) == 2
|
||||
|
||||
def test_both_io_blocks_present_no_error(self):
|
||||
v = AgentValidator()
|
||||
input_node = _make_node(block_id=AGENT_INPUT_BLOCK_ID)
|
||||
output_node = _make_node(block_id=AGENT_OUTPUT_BLOCK_ID)
|
||||
agent = _make_agent(nodes=[input_node, output_node])
|
||||
|
||||
assert v.validate_io_blocks(agent) is True
|
||||
assert v.errors == []
|
||||
|
||||
def test_empty_agent_reports_both_missing(self):
|
||||
v = AgentValidator()
|
||||
agent = _make_agent(nodes=[])
|
||||
|
||||
assert v.validate_io_blocks(agent) is False
|
||||
assert len(v.errors) == 2
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# validate (integration)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestValidate:
|
||||
def test_valid_agent_passes(self):
|
||||
v = AgentValidator()
|
||||
block = _make_block(
|
||||
block_id="b1",
|
||||
input_schema={
|
||||
"properties": {"url": {"type": "string"}},
|
||||
"required": ["url"],
|
||||
},
|
||||
output_schema={"properties": {"result": {"type": "string"}}},
|
||||
)
|
||||
input_block = _make_block(
|
||||
block_id=AGENT_INPUT_BLOCK_ID,
|
||||
name="AgentInputBlock",
|
||||
output_schema={"properties": {"result": {}}},
|
||||
)
|
||||
output_block = _make_block(
|
||||
block_id=AGENT_OUTPUT_BLOCK_ID,
|
||||
name="AgentOutputBlock",
|
||||
)
|
||||
input_node = _make_node(
|
||||
node_id="n-in",
|
||||
block_id=AGENT_INPUT_BLOCK_ID,
|
||||
input_default={"name": "url"},
|
||||
)
|
||||
n1 = _make_node(
|
||||
node_id="n1", block_id="b1", input_default={"url": "http://example.com"}
|
||||
)
|
||||
n2 = _make_node(
|
||||
node_id="n2", block_id="b1", input_default={"url": "http://example2.com"}
|
||||
)
|
||||
output_node = _make_node(
|
||||
node_id="n-out",
|
||||
block_id=AGENT_OUTPUT_BLOCK_ID,
|
||||
input_default={"name": "result"},
|
||||
)
|
||||
link = _make_link(
|
||||
source_id="n1", source_name="result", sink_id="n2", sink_name="url"
|
||||
)
|
||||
agent = _make_agent(nodes=[input_node, n1, n2, output_node], links=[link])
|
||||
|
||||
is_valid, error_message = v.validate(agent, [block, input_block, output_block])
|
||||
|
||||
assert is_valid is True
|
||||
assert error_message is None
|
||||
|
||||
def test_invalid_agent_returns_errors(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(block_id="nonexistent")
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
is_valid, error_message = v.validate(agent, [])
|
||||
|
||||
assert is_valid is False
|
||||
assert error_message is not None
|
||||
assert "does not exist" in error_message
|
||||
|
||||
def test_empty_agent_fails_io_validation(self):
|
||||
v = AgentValidator()
|
||||
agent = _make_agent()
|
||||
|
||||
is_valid, error_message = v.validate(agent, [])
|
||||
|
||||
assert is_valid is False
|
||||
assert error_message is not None
|
||||
assert "AgentInputBlock" in error_message
|
||||
assert "AgentOutputBlock" in error_message
|
||||
|
||||
|
||||
class TestValidateMCPToolBlocks:
|
||||
"""Tests for validate_mcp_tool_blocks."""
|
||||
|
||||
def test_missing_server_url_reports_error(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(
|
||||
block_id=MCP_TOOL_BLOCK_ID,
|
||||
input_default={"selected_tool": "my_tool"},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
result = v.validate_mcp_tool_blocks(agent)
|
||||
|
||||
assert result is False
|
||||
assert any("server_url" in e for e in v.errors)
|
||||
|
||||
def test_missing_selected_tool_reports_error(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(
|
||||
block_id=MCP_TOOL_BLOCK_ID,
|
||||
input_default={"server_url": "https://mcp.example.com/sse"},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
result = v.validate_mcp_tool_blocks(agent)
|
||||
|
||||
assert result is False
|
||||
assert any("selected_tool" in e for e in v.errors)
|
||||
|
||||
def test_valid_mcp_block_passes(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(
|
||||
block_id=MCP_TOOL_BLOCK_ID,
|
||||
input_default={
|
||||
"server_url": "https://mcp.example.com/sse",
|
||||
"selected_tool": "search",
|
||||
"tool_input_schema": {"properties": {"query": {"type": "string"}}},
|
||||
"tool_arguments": {},
|
||||
},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
result = v.validate_mcp_tool_blocks(agent)
|
||||
|
||||
assert result is True
|
||||
assert len(v.errors) == 0
|
||||
|
||||
def test_both_missing_reports_two_errors(self):
|
||||
v = AgentValidator()
|
||||
node = _make_node(
|
||||
block_id=MCP_TOOL_BLOCK_ID,
|
||||
input_default={},
|
||||
)
|
||||
agent = _make_agent(nodes=[node])
|
||||
|
||||
v.validate_mcp_tool_blocks(agent)
|
||||
|
||||
assert len(v.errors) == 2
|
||||
@@ -208,9 +208,6 @@ def _library_agent_to_info(agent: LibraryAgent) -> AgentInfo:
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
graph_version=agent.graph_version,
|
||||
input_schema=agent.input_schema,
|
||||
output_schema=agent.output_schema,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -21,10 +21,10 @@ from typing import Any
|
||||
from e2b import AsyncSandbox
|
||||
from e2b.exceptions import TimeoutException
|
||||
|
||||
from backend.copilot.context import E2B_WORKDIR, get_current_sandbox
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .e2b_sandbox import E2B_WORKDIR
|
||||
from .models import BashExecResponse, ErrorResponse, ToolResponseBase
|
||||
from .sandbox import get_workspace_dir, has_full_sandbox, run_sandboxed
|
||||
|
||||
@@ -94,6 +94,9 @@ class BashExecTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# E2B path: run on remote cloud sandbox when available.
|
||||
from backend.copilot.sdk.tool_adapter import get_current_sandbox
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
return await self._execute_on_e2b(sandbox, command, timeout, session_id)
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
"""Tool for continuing block execution after human review approval."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from prisma.enums import ReviewStatus
|
||||
|
||||
from backend.blocks import get_block
|
||||
from backend.copilot.constants import (
|
||||
COPILOT_NODE_PREFIX,
|
||||
COPILOT_SESSION_PREFIX,
|
||||
parse_node_id_from_exec_id,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import review_db
|
||||
|
||||
from .base import BaseTool
|
||||
from .helpers import execute_block, resolve_block_credentials
|
||||
from .models import ErrorResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContinueRunBlockTool(BaseTool):
|
||||
"""Tool for continuing a block execution after human review approval."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "continue_run_block"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Continue executing a block after human review approval. "
|
||||
"Use this after a run_block call returned review_required. "
|
||||
"Pass the review_id from the review_required response. "
|
||||
"The block will execute with the original pre-approved input data."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"review_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The review_id from a previous review_required response. "
|
||||
"This resumes execution with the pre-approved input data."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["review_id"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
review_id = (
|
||||
kwargs.get("review_id", "").strip() if kwargs.get("review_id") else ""
|
||||
)
|
||||
session_id = session.session_id
|
||||
|
||||
if not review_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide a review_id", session_id=session_id
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required", session_id=session_id
|
||||
)
|
||||
|
||||
# Look up and validate the review record via adapter
|
||||
reviews = await review_db().get_reviews_by_node_exec_ids([review_id], user_id)
|
||||
review = reviews.get(review_id)
|
||||
|
||||
if not review:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Review '{review_id}' not found or already executed. "
|
||||
"It may have been consumed by a previous continue_run_block call."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Validate the review belongs to this session
|
||||
expected_graph_exec_id = f"{COPILOT_SESSION_PREFIX}{session_id}"
|
||||
if review.graph_exec_id != expected_graph_exec_id:
|
||||
return ErrorResponse(
|
||||
message="Review does not belong to this session.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if review.status == ReviewStatus.WAITING:
|
||||
return ErrorResponse(
|
||||
message="Review has not been approved yet. "
|
||||
"Please wait for the user to approve the review first.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if review.status == ReviewStatus.REJECTED:
|
||||
return ErrorResponse(
|
||||
message="Review was rejected. The block will not execute.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Extract block_id from review_id: copilot-node-{block_id}:{random_hex}
|
||||
block_id = parse_node_id_from_exec_id(review_id).removeprefix(
|
||||
COPILOT_NODE_PREFIX
|
||||
)
|
||||
block = get_block(block_id)
|
||||
if not block:
|
||||
return ErrorResponse(
|
||||
message=f"Block '{block_id}' not found", session_id=session_id
|
||||
)
|
||||
|
||||
input_data: dict[str, Any] = (
|
||||
review.payload if isinstance(review.payload, dict) else {}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Continuing block {block.name} ({block_id}) for user {user_id} "
|
||||
f"with review_id={review_id}"
|
||||
)
|
||||
|
||||
matched_creds, missing_creds = await resolve_block_credentials(
|
||||
user_id, block, input_data
|
||||
)
|
||||
if missing_creds:
|
||||
return ErrorResponse(
|
||||
message=f"Block '{block.name}' requires credentials that are not configured.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id=block_id,
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
node_exec_id=review_id,
|
||||
matched_credentials=matched_creds,
|
||||
)
|
||||
|
||||
# Delete review record after successful execution (one-time use)
|
||||
if result.type != "error":
|
||||
await review_db().delete_review_by_node_exec_id(review_id, user_id)
|
||||
|
||||
return result
|
||||
@@ -1,186 +0,0 @@
|
||||
"""Tests for ContinueRunBlockTool."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import ReviewStatus
|
||||
|
||||
from ._test_data import make_session
|
||||
from .continue_run_block import ContinueRunBlockTool
|
||||
from .models import BlockOutputResponse, ErrorResponse
|
||||
|
||||
_TEST_USER_ID = "test-user-continue"
|
||||
|
||||
|
||||
def _make_review_model(
|
||||
node_exec_id: str,
|
||||
status: ReviewStatus = ReviewStatus.APPROVED,
|
||||
payload: dict | None = None,
|
||||
graph_exec_id: str = "",
|
||||
):
|
||||
"""Create a mock PendingHumanReviewModel."""
|
||||
mock = MagicMock()
|
||||
mock.node_exec_id = node_exec_id
|
||||
mock.status = status
|
||||
mock.payload = payload or {"text": "hello"}
|
||||
mock.graph_exec_id = graph_exec_id
|
||||
return mock
|
||||
|
||||
|
||||
class TestContinueRunBlock:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_missing_review_id_returns_error(self):
|
||||
tool = ContinueRunBlockTool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
review_id="",
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "review_id" in response.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_review_not_found_returns_error(self):
|
||||
tool = ContinueRunBlockTool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_reviews_by_node_exec_ids = AsyncMock(return_value={})
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.continue_run_block.review_db",
|
||||
return_value=mock_db,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
review_id="copilot-node-some-block:abc12345",
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "not found" in response.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_waiting_review_returns_error(self):
|
||||
tool = ContinueRunBlockTool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
review_id = "copilot-node-some-block:abc12345"
|
||||
graph_exec_id = f"copilot-session-{session.session_id}"
|
||||
review = _make_review_model(
|
||||
review_id, status=ReviewStatus.WAITING, graph_exec_id=graph_exec_id
|
||||
)
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_reviews_by_node_exec_ids = AsyncMock(
|
||||
return_value={review_id: review}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.continue_run_block.review_db",
|
||||
return_value=mock_db,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
review_id=review_id,
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "not been approved" in response.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_rejected_review_returns_error(self):
|
||||
tool = ContinueRunBlockTool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
review_id = "copilot-node-some-block:abc12345"
|
||||
graph_exec_id = f"copilot-session-{session.session_id}"
|
||||
review = _make_review_model(
|
||||
review_id, status=ReviewStatus.REJECTED, graph_exec_id=graph_exec_id
|
||||
)
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_reviews_by_node_exec_ids = AsyncMock(
|
||||
return_value={review_id: review}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.continue_run_block.review_db",
|
||||
return_value=mock_db,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
review_id=review_id,
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "rejected" in response.message.lower()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_approved_review_executes_block(self):
|
||||
tool = ContinueRunBlockTool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
review_id = "copilot-node-delete-branch-id:abc12345"
|
||||
graph_exec_id = f"copilot-session-{session.session_id}"
|
||||
input_data = {"repo_url": "https://github.com/test/repo", "branch": "main"}
|
||||
review = _make_review_model(
|
||||
review_id,
|
||||
status=ReviewStatus.APPROVED,
|
||||
payload=input_data,
|
||||
graph_exec_id=graph_exec_id,
|
||||
)
|
||||
|
||||
mock_block = MagicMock()
|
||||
mock_block.name = "Delete Branch"
|
||||
|
||||
async def mock_execute(data, **kwargs):
|
||||
yield "result", "Branch deleted"
|
||||
|
||||
mock_block.execute = mock_execute
|
||||
mock_block.input_schema.get_credentials_fields_info.return_value = []
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="test-workspace-id")
|
||||
)
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_reviews_by_node_exec_ids = AsyncMock(
|
||||
return_value={review_id: review}
|
||||
)
|
||||
mock_db.delete_review_by_node_exec_id = AsyncMock(return_value=1)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.continue_run_block.review_db",
|
||||
return_value=mock_db,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.continue_run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
review_id=review_id,
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert response.success is True
|
||||
assert response.block_name == "Delete Branch"
|
||||
# Verify review was deleted (one-time use)
|
||||
mock_db.delete_review_by_node_exec_id.assert_called_once_with(
|
||||
review_id, _TEST_USER_ID
|
||||
)
|
||||
@@ -1,20 +1,34 @@
|
||||
"""CreateAgentTool - Creates agents from pre-built JSON."""
|
||||
"""CreateAgentTool - Creates agents from natural language descriptions."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .agent_generator.pipeline import fetch_library_agents, fix_validate_and_save
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
decompose_goal,
|
||||
enrich_library_agents_from_steps,
|
||||
generate_agent,
|
||||
get_user_message_for_error,
|
||||
save_agent_to_library,
|
||||
)
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ToolResponseBase
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
SuggestedGoalResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CreateAgentTool(BaseTool):
|
||||
"""Tool for creating agents from pre-built JSON."""
|
||||
"""Tool for creating agents from natural language descriptions."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -23,12 +37,15 @@ class CreateAgentTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Create a new agent workflow. Pass `agent_json` with the complete "
|
||||
"agent graph JSON you generated using block schemas from find_block. "
|
||||
"The tool validates, auto-fixes, and saves.\n\n"
|
||||
"IMPORTANT: Before calling this tool, search for relevant existing agents "
|
||||
"using find_library_agent that could be used as building blocks. "
|
||||
"Pass their IDs in the library_agent_ids parameter."
|
||||
"Create a new agent workflow from a natural language description. "
|
||||
"First generates a preview, then saves to library if save=true. "
|
||||
"\n\nWorkflow: (1) Always check find_library_agent first for existing building blocks. "
|
||||
"(2) Call create_agent with description and library_agent_ids. "
|
||||
"(3) If response contains suggested_goal: Present to user, ask for confirmation, "
|
||||
"then call again with the suggested goal if accepted. "
|
||||
"(4) If response contains clarifying_questions: Present to user, collect answers, "
|
||||
"then call again with original description AND answers in the context parameter. "
|
||||
"\n\nThis feedback loop ensures the generated agent matches user intent."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -40,26 +57,34 @@ class CreateAgentTool(BaseTool):
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_json": {
|
||||
"type": "object",
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The agent JSON to validate and save. "
|
||||
"Must contain 'nodes' and 'links' arrays, and optionally "
|
||||
"'name' and 'description'."
|
||||
"Natural language description of what the agent should do. "
|
||||
"Be specific about inputs, outputs, and the workflow steps."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Additional context or answers to previous clarifying questions. "
|
||||
"Include any preferences or constraints mentioned by the user."
|
||||
),
|
||||
},
|
||||
"library_agent_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"List of library agent IDs to use as building blocks."
|
||||
"List of library agent IDs to use as building blocks. "
|
||||
"Search for relevant agents using find_library_agent first, "
|
||||
"then pass their IDs here so they can be composed into the new agent."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Whether to save the agent. Default is true. "
|
||||
"Set to false for preview only."
|
||||
"Whether to save the agent to the user's library. "
|
||||
"Default is true. Set to false for preview only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
@@ -72,7 +97,7 @@ class CreateAgentTool(BaseTool):
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["agent_json"],
|
||||
"required": ["description"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
@@ -81,49 +106,278 @@ class CreateAgentTool(BaseTool):
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
|
||||
"""Execute the create_agent tool.
|
||||
|
||||
Flow:
|
||||
1. Decompose the description into steps (may return clarifying questions)
|
||||
2. Generate agent JSON (external service handles fixing and validation)
|
||||
3. Preview or save based on the save parameter
|
||||
"""
|
||||
description = kwargs.get("description", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
library_agent_ids = kwargs.get("library_agent_ids", [])
|
||||
save = kwargs.get("save", True)
|
||||
folder_id = kwargs.get("folder_id")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_json:
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] START - description_len={len(description)}, "
|
||||
f"library_agent_ids={library_agent_ids}, save={save}, user_id={user_id}, session_id={session_id}"
|
||||
)
|
||||
|
||||
if not description:
|
||||
return ErrorResponse(
|
||||
message="Please provide a description of what the agent should do.",
|
||||
error="Missing description parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Fetch library agents by IDs if provided
|
||||
library_agents = None
|
||||
if user_id and library_agent_ids:
|
||||
try:
|
||||
from .agent_generator import get_library_agents_by_ids
|
||||
|
||||
library_agents = await get_library_agents_by_ids(
|
||||
user_id=user_id,
|
||||
agent_ids=library_agent_ids,
|
||||
)
|
||||
logger.debug(
|
||||
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agents by IDs: {e}")
|
||||
|
||||
try:
|
||||
decomposition_result = await decompose_goal(
|
||||
description, context, library_agents
|
||||
)
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] DECOMPOSE - type={decomposition_result.get('type') if decomposition_result else None}, "
|
||||
f"session_id={session_id}"
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
logger.error(
|
||||
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured, session_id={session_id}"
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Please provide agent_json with the complete agent graph. "
|
||||
"Use find_block to discover blocks, then generate the JSON."
|
||||
"Agent generation is not available. "
|
||||
"The Agent Generator service is not configured."
|
||||
),
|
||||
error="missing_agent_json",
|
||||
error="service_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
save = kwargs.get("save", True)
|
||||
library_agent_ids = kwargs.get("library_agent_ids", [])
|
||||
folder_id: str | None = kwargs.get("folder_id")
|
||||
|
||||
nodes = agent_json.get("nodes", [])
|
||||
if not nodes:
|
||||
if decomposition_result is None:
|
||||
return ErrorResponse(
|
||||
message="The agent JSON has no nodes. An agent needs at least one block.",
|
||||
error="empty_agent",
|
||||
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
|
||||
error="decomposition_failed",
|
||||
details={"description": description[:100]},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Ensure top-level fields
|
||||
if "id" not in agent_json:
|
||||
agent_json["id"] = str(uuid.uuid4())
|
||||
if "version" not in agent_json:
|
||||
agent_json["version"] = 1
|
||||
if "is_active" not in agent_json:
|
||||
agent_json["is_active"] = True
|
||||
if decomposition_result.get("type") == "error":
|
||||
error_msg = decomposition_result.get("error", "Unknown error")
|
||||
error_type = decomposition_result.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="analyze the goal",
|
||||
llm_parse_message="The AI had trouble understanding this request. Please try rephrasing your goal.",
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"decomposition_failed:{error_type}",
|
||||
details={
|
||||
"description": description[:100],
|
||||
"service_error": error_msg,
|
||||
"error_type": error_type,
|
||||
},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Fetch library agents for AgentExecutorBlock validation
|
||||
library_agents = await fetch_library_agents(user_id, library_agent_ids)
|
||||
if decomposition_result.get("type") == "clarifying_questions":
|
||||
questions = decomposition_result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information to create this agent. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return await fix_validate_and_save(
|
||||
agent_json,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
save=save,
|
||||
is_update=False,
|
||||
default_name="Generated Agent",
|
||||
library_agents=library_agents,
|
||||
folder_id=folder_id,
|
||||
if decomposition_result.get("type") == "unachievable_goal":
|
||||
suggested = decomposition_result.get("suggested_goal", "")
|
||||
reason = decomposition_result.get("reason", "")
|
||||
return SuggestedGoalResponse(
|
||||
message=(
|
||||
f"This goal cannot be accomplished with the available blocks. {reason}"
|
||||
),
|
||||
suggested_goal=suggested,
|
||||
reason=reason,
|
||||
original_goal=description,
|
||||
goal_type="unachievable",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if decomposition_result.get("type") == "vague_goal":
|
||||
suggested = decomposition_result.get("suggested_goal", "")
|
||||
reason = decomposition_result.get(
|
||||
"reason", "The goal needs more specific details"
|
||||
)
|
||||
return SuggestedGoalResponse(
|
||||
message="The goal is too vague to create a specific workflow.",
|
||||
suggested_goal=suggested,
|
||||
reason=reason,
|
||||
original_goal=description,
|
||||
goal_type="vague",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if user_id and library_agents is not None:
|
||||
try:
|
||||
library_agents = await enrich_library_agents_from_steps(
|
||||
user_id=user_id,
|
||||
decomposition_result=decomposition_result,
|
||||
existing_agents=library_agents,
|
||||
include_marketplace=True,
|
||||
)
|
||||
logger.debug(
|
||||
f"After enrichment: {len(library_agents)} total agents for sub-agent composition"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to enrich library agents from steps: {e}")
|
||||
|
||||
try:
|
||||
agent_json = await generate_agent(
|
||||
decomposition_result,
|
||||
library_agents,
|
||||
)
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] GENERATE - "
|
||||
f"success={agent_json is not None}, "
|
||||
f"is_error={isinstance(agent_json, dict) and agent_json.get('type') == 'error'}, "
|
||||
f"session_id={session_id}"
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
logger.error(
|
||||
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured during generation, session_id={session_id}"
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent generation is not available. "
|
||||
"The Agent Generator service is not configured."
|
||||
),
|
||||
error="service_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if agent_json is None:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
|
||||
error="generation_failed",
|
||||
details={"description": description[:100]},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if isinstance(agent_json, dict) and agent_json.get("type") == "error":
|
||||
error_msg = agent_json.get("error", "Unknown error")
|
||||
error_type = agent_json.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="generate the agent",
|
||||
llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
|
||||
validation_message=(
|
||||
"I wasn't able to create a valid agent for this request. "
|
||||
"The generated workflow had some structural issues. "
|
||||
"Please try simplifying your goal or breaking it into smaller steps."
|
||||
),
|
||||
error_details=error_msg,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"generation_failed:{error_type}",
|
||||
details={
|
||||
"description": description[:100],
|
||||
"service_error": error_msg,
|
||||
"error_type": error_type,
|
||||
},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agent_name = agent_json.get("name", "Generated Agent")
|
||||
agent_description = agent_json.get("description", "")
|
||||
node_count = len(agent_json.get("nodes", []))
|
||||
link_count = len(agent_json.get("links", []))
|
||||
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] AGENT_JSON - name={agent_name}, "
|
||||
f"nodes={node_count}, links={link_count}, save={save}, session_id={session_id}"
|
||||
)
|
||||
|
||||
if not save:
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] RETURN - AgentPreviewResponse, session_id={session_id}"
|
||||
)
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
|
||||
f"Review it and call create_agent with save=true to save it to your library."
|
||||
),
|
||||
agent_json=agent_json,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
agent_json, user_id, folder_id=folder_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] SAVED - graph_id={created_graph.id}, "
|
||||
f"library_agent_id={library_agent.id}, session_id={session_id}"
|
||||
)
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] RETURN - AgentSavedResponse, session_id={session_id}"
|
||||
)
|
||||
return AgentSavedResponse(
|
||||
message=f"Agent '{created_graph.name}' has been saved to your library!",
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[AGENT_CREATE_DEBUG] ERROR - save_failed: {str(e)}, session_id={session_id}"
|
||||
)
|
||||
logger.info(
|
||||
f"[AGENT_CREATE_DEBUG] RETURN - ErrorResponse (save_failed), session_id={session_id}"
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to save the agent: {str(e)}",
|
||||
error="save_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
"""Tests for CreateAgentTool."""
|
||||
"""Tests for CreateAgentTool response types."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.tools.create_agent import CreateAgentTool
|
||||
from backend.copilot.tools.models import AgentPreviewResponse, ErrorResponse
|
||||
from backend.copilot.tools.models import (
|
||||
ClarificationNeededResponse,
|
||||
ErrorResponse,
|
||||
SuggestedGoalResponse,
|
||||
)
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
_TEST_USER_ID = "test-user-create-agent"
|
||||
_PIPELINE = "backend.copilot.tools.agent_generator.pipeline"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -23,147 +26,102 @@ def session():
|
||||
return make_session(_TEST_USER_ID)
|
||||
|
||||
|
||||
# ── Input validation tests ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_agent_json_returns_error(tool, session):
|
||||
"""Missing agent_json returns ErrorResponse."""
|
||||
result = await tool._execute(user_id=_TEST_USER_ID, session=session)
|
||||
async def test_missing_description_returns_error(tool, session):
|
||||
"""Missing description returns ErrorResponse."""
|
||||
result = await tool._execute(user_id=_TEST_USER_ID, session=session, description="")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_agent_json"
|
||||
|
||||
|
||||
# ── Local mode tests ────────────────────────────────────────────────────
|
||||
assert result.error == "Missing description parameter"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_mode_empty_nodes_returns_error(tool, session):
|
||||
"""Local mode with no nodes returns ErrorResponse."""
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json={"nodes": [], "links": []},
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "no nodes" in result.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_mode_preview(tool, session):
|
||||
"""Local mode with save=false returns AgentPreviewResponse."""
|
||||
agent_json = {
|
||||
"name": "Test Agent",
|
||||
"description": "A test agent",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"block_id": "block-1",
|
||||
"input_default": {},
|
||||
"metadata": {"position": {"x": 0, "y": 0}},
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
async def test_vague_goal_returns_suggested_goal_response(tool, session):
|
||||
"""vague_goal decomposition result returns SuggestedGoalResponse, not ErrorResponse."""
|
||||
vague_result = {
|
||||
"type": "vague_goal",
|
||||
"suggested_goal": "Monitor Twitter mentions for a specific keyword and send a daily digest email",
|
||||
}
|
||||
|
||||
mock_fixer = MagicMock()
|
||||
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
|
||||
mock_fixer.get_fixes_applied.return_value = []
|
||||
|
||||
mock_validator = MagicMock()
|
||||
mock_validator.validate.return_value = (True, None)
|
||||
mock_validator.errors = []
|
||||
|
||||
with (
|
||||
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
|
||||
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
|
||||
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.decompose_goal",
|
||||
new_callable=AsyncMock,
|
||||
return_value=vague_result,
|
||||
),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json=agent_json,
|
||||
save=False,
|
||||
description="monitor social media",
|
||||
)
|
||||
|
||||
assert isinstance(result, AgentPreviewResponse)
|
||||
assert result.agent_name == "Test Agent"
|
||||
assert result.node_count == 1
|
||||
assert isinstance(result, SuggestedGoalResponse)
|
||||
assert result.goal_type == "vague"
|
||||
assert result.suggested_goal == vague_result["suggested_goal"]
|
||||
assert result.original_goal == "monitor social media"
|
||||
assert result.reason == "The goal needs more specific details"
|
||||
assert not isinstance(result, ErrorResponse)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_mode_validation_failure(tool, session):
|
||||
"""Local mode returns ErrorResponse when validation fails after fixing."""
|
||||
agent_json = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"block_id": "bad-block",
|
||||
"input_default": {},
|
||||
"metadata": {},
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
async def test_unachievable_goal_returns_suggested_goal_response(tool, session):
|
||||
"""unachievable_goal decomposition result returns SuggestedGoalResponse, not ErrorResponse."""
|
||||
unachievable_result = {
|
||||
"type": "unachievable_goal",
|
||||
"suggested_goal": "Summarize the latest news articles on a topic and send them by email",
|
||||
"reason": "There are no blocks for mind-reading.",
|
||||
}
|
||||
|
||||
mock_fixer = MagicMock()
|
||||
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
|
||||
mock_fixer.get_fixes_applied.return_value = []
|
||||
|
||||
mock_validator = MagicMock()
|
||||
mock_validator.validate.return_value = (False, "Block 'bad-block' not found")
|
||||
mock_validator.errors = ["Block 'bad-block' not found"]
|
||||
|
||||
with (
|
||||
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
|
||||
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
|
||||
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.decompose_goal",
|
||||
new_callable=AsyncMock,
|
||||
return_value=unachievable_result,
|
||||
),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json=agent_json,
|
||||
description="read my mind",
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "validation_failed"
|
||||
assert "Block 'bad-block' not found" in result.message
|
||||
assert isinstance(result, SuggestedGoalResponse)
|
||||
assert result.goal_type == "unachievable"
|
||||
assert result.suggested_goal == unachievable_result["suggested_goal"]
|
||||
assert result.original_goal == "read my mind"
|
||||
assert result.reason == unachievable_result["reason"]
|
||||
assert not isinstance(result, ErrorResponse)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_mode_no_auth_returns_error(tool, session):
|
||||
"""Local mode with save=true and no user returns ErrorResponse."""
|
||||
agent_json = {
|
||||
"nodes": [
|
||||
async def test_clarifying_questions_returns_clarification_needed_response(
|
||||
tool, session
|
||||
):
|
||||
"""clarifying_questions decomposition result returns ClarificationNeededResponse."""
|
||||
clarifying_result = {
|
||||
"type": "clarifying_questions",
|
||||
"questions": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"block_id": "block-1",
|
||||
"input_default": {},
|
||||
"metadata": {},
|
||||
"question": "What platform should be monitored?",
|
||||
"keyword": "platform",
|
||||
"example": "Twitter, Reddit",
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
}
|
||||
|
||||
mock_fixer = MagicMock()
|
||||
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
|
||||
mock_fixer.get_fixes_applied.return_value = []
|
||||
|
||||
mock_validator = MagicMock()
|
||||
mock_validator.validate.return_value = (True, None)
|
||||
mock_validator.errors = []
|
||||
|
||||
with (
|
||||
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
|
||||
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
|
||||
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.decompose_goal",
|
||||
new_callable=AsyncMock,
|
||||
return_value=clarifying_result,
|
||||
),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=None,
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json=agent_json,
|
||||
save=True,
|
||||
description="monitor social media and alert me",
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "logged in" in result.message.lower()
|
||||
assert isinstance(result, ClarificationNeededResponse)
|
||||
assert len(result.questions) == 1
|
||||
assert result.questions[0].keyword == "platform"
|
||||
|
||||
@@ -1,20 +1,34 @@
|
||||
"""CustomizeAgentTool - Customizes marketplace/template agents."""
|
||||
"""CustomizeAgentTool - Customizes marketplace/template agents using natural language."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import store_db as get_store_db
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .agent_generator.pipeline import fetch_library_agents, fix_validate_and_save
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
customize_template,
|
||||
get_user_message_for_error,
|
||||
graph_to_json,
|
||||
save_agent_to_library,
|
||||
)
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ToolResponseBase
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomizeAgentTool(BaseTool):
|
||||
"""Tool for customizing marketplace/template agents."""
|
||||
"""Tool for customizing marketplace/template agents using natural language."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -23,9 +37,9 @@ class CustomizeAgentTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Customize a marketplace or template agent. Pass `agent_json` "
|
||||
"with the complete customized agent JSON. The tool validates, "
|
||||
"auto-fixes, and saves."
|
||||
"Customize a marketplace or template agent using natural language. "
|
||||
"Takes an existing agent from the marketplace and modifies it based on "
|
||||
"the user's requirements before adding to their library."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -37,24 +51,32 @@ class CustomizeAgentTool(BaseTool):
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_json": {
|
||||
"type": "object",
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Complete customized agent JSON to validate and save. "
|
||||
"Optionally include 'name' and 'description'."
|
||||
"The marketplace agent ID in format 'creator/slug' "
|
||||
"(e.g., 'autogpt/newsletter-writer'). "
|
||||
"Get this from find_agent results."
|
||||
),
|
||||
},
|
||||
"library_agent_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"modifications": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"List of library agent IDs to use as building blocks."
|
||||
"Natural language description of how to customize the agent. "
|
||||
"Be specific about what changes you want to make."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Additional context or answers to previous clarifying questions."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Whether to save the customized agent. Default is true."
|
||||
"Whether to save the customized agent to the user's library. "
|
||||
"Default is true. Set to false for preview only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
@@ -67,7 +89,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["agent_json"],
|
||||
"required": ["agent_id", "modifications"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
@@ -76,46 +98,247 @@ class CustomizeAgentTool(BaseTool):
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
|
||||
"""Execute the customize_agent tool.
|
||||
|
||||
Flow:
|
||||
1. Parse the agent ID to get creator/slug
|
||||
2. Fetch the template agent from the marketplace
|
||||
3. Call customize_template with the modification request
|
||||
4. Preview or save based on the save parameter
|
||||
"""
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
modifications = kwargs.get("modifications", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
folder_id = kwargs.get("folder_id")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_json:
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
|
||||
error="missing_agent_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not modifications:
|
||||
return ErrorResponse(
|
||||
message="Please describe how you want to customize this agent.",
|
||||
error="missing_modifications",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Parse agent_id in format "creator/slug"
|
||||
parts = [p.strip() for p in agent_id.split("/")]
|
||||
if len(parts) != 2 or not parts[0] or not parts[1]:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Please provide agent_json with the complete customized agent graph."
|
||||
f"Invalid agent ID format: '{agent_id}'. "
|
||||
"Expected format is 'creator/agent-name' "
|
||||
"(e.g., 'autogpt/newsletter-writer')."
|
||||
),
|
||||
error="missing_agent_json",
|
||||
error="invalid_agent_id_format",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
save = kwargs.get("save", True)
|
||||
library_agent_ids = kwargs.get("library_agent_ids", [])
|
||||
folder_id: str | None = kwargs.get("folder_id")
|
||||
creator_username, agent_slug = parts
|
||||
|
||||
nodes = agent_json.get("nodes", [])
|
||||
if not nodes:
|
||||
store_db = get_store_db()
|
||||
|
||||
# Fetch the marketplace agent details
|
||||
try:
|
||||
agent_details = await store_db.get_store_agent_details(
|
||||
username=creator_username, agent_name=agent_slug
|
||||
)
|
||||
except NotFoundError:
|
||||
return ErrorResponse(
|
||||
message="The agent JSON has no nodes.",
|
||||
error="empty_agent",
|
||||
message=(
|
||||
f"Could not find marketplace agent '{agent_id}'. "
|
||||
"Please check the agent ID and try again."
|
||||
),
|
||||
error="agent_not_found",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message="Failed to fetch the marketplace agent. Please try again.",
|
||||
error="fetch_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Ensure top-level fields before the fixer pipeline
|
||||
if "id" not in agent_json:
|
||||
agent_json["id"] = str(uuid.uuid4())
|
||||
agent_json.setdefault("version", 1)
|
||||
agent_json.setdefault("is_active", True)
|
||||
if not agent_details.store_listing_version_id:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"The agent '{agent_id}' does not have an available version. "
|
||||
"Please try a different agent."
|
||||
),
|
||||
error="no_version_available",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Fetch library agents for AgentExecutorBlock validation
|
||||
library_agents = await fetch_library_agents(user_id, library_agent_ids)
|
||||
# Get the full agent graph
|
||||
try:
|
||||
graph = await store_db.get_agent(agent_details.store_listing_version_id)
|
||||
template_agent = graph_to_json(graph)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message="Failed to fetch the agent configuration. Please try again.",
|
||||
error="graph_fetch_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return await fix_validate_and_save(
|
||||
agent_json,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
save=save,
|
||||
is_update=False,
|
||||
default_name="Customized Agent",
|
||||
library_agents=library_agents,
|
||||
folder_id=folder_id,
|
||||
# Call customize_template
|
||||
try:
|
||||
result = await customize_template(
|
||||
template_agent=template_agent,
|
||||
modification_request=modifications,
|
||||
context=context,
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent customization is not available. "
|
||||
"The Agent Generator service is not configured."
|
||||
),
|
||||
error="service_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling customize_template for {agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Failed to customize the agent due to a service error. "
|
||||
"Please try again."
|
||||
),
|
||||
error="customization_service_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Failed to customize the agent. "
|
||||
"The agent generation service may be unavailable or timed out. "
|
||||
"Please try again."
|
||||
),
|
||||
error="customization_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Handle error response
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
error_type = result.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="customize the agent",
|
||||
llm_parse_message=(
|
||||
"The AI had trouble customizing the agent. "
|
||||
"Please try again or simplify your request."
|
||||
),
|
||||
validation_message=(
|
||||
"The customized agent failed validation. "
|
||||
"Please try rephrasing your request."
|
||||
),
|
||||
error_details=error_msg,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"customization_failed:{error_type}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Handle clarifying questions
|
||||
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
|
||||
questions = result.get("questions") or []
|
||||
if not isinstance(questions, list):
|
||||
logger.error(
|
||||
f"Unexpected clarifying questions format: {type(questions)}"
|
||||
)
|
||||
questions = []
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information to customize this agent. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
if isinstance(q, dict)
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Result should be the customized agent JSON
|
||||
if not isinstance(result, dict):
|
||||
logger.error(f"Unexpected customize_template response type: {type(result)}")
|
||||
return ErrorResponse(
|
||||
message="Failed to customize the agent due to an unexpected response.",
|
||||
error="unexpected_response_type",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
customized_agent = result
|
||||
|
||||
agent_name = customized_agent.get(
|
||||
"name", f"Customized {agent_details.agent_name}"
|
||||
)
|
||||
agent_description = customized_agent.get("description", "")
|
||||
nodes = customized_agent.get("nodes")
|
||||
links = customized_agent.get("links")
|
||||
node_count = len(nodes) if isinstance(nodes, list) else 0
|
||||
link_count = len(links) if isinstance(links, list) else 0
|
||||
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've customized the agent '{agent_details.agent_name}'. "
|
||||
f"The customized agent has {node_count} blocks. "
|
||||
f"Review it and call customize_agent with save=true to save it."
|
||||
),
|
||||
agent_json=customized_agent,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Save to user's library
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
customized_agent, user_id, is_update=False, folder_id=folder_id
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
message=(
|
||||
f"Customized agent '{created_graph.name}' "
|
||||
f"(based on '{agent_details.agent_name}') "
|
||||
f"has been saved to your library!"
|
||||
),
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving customized agent: {e}")
|
||||
return ErrorResponse(
|
||||
message="Failed to save the customized agent. Please try again.",
|
||||
error="save_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -1,172 +0,0 @@
|
||||
"""Tests for CustomizeAgentTool local mode."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.tools.customize_agent import CustomizeAgentTool
|
||||
from backend.copilot.tools.models import AgentPreviewResponse, ErrorResponse
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
_TEST_USER_ID = "test-user-customize-agent"
|
||||
_PIPELINE = "backend.copilot.tools.agent_generator.pipeline"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return CustomizeAgentTool()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session():
|
||||
return make_session(_TEST_USER_ID)
|
||||
|
||||
|
||||
# ── Input validation tests ───────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_agent_json_returns_error(tool, session):
|
||||
"""Missing agent_json returns ErrorResponse."""
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_agent_json"
|
||||
|
||||
|
||||
# ── Local mode tests (agent_json provided) ───────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_mode_empty_nodes_returns_error(tool, session):
|
||||
"""Local mode with no nodes returns ErrorResponse."""
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json={"nodes": [], "links": []},
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "no nodes" in result.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_mode_preview(tool, session):
|
||||
"""Local mode with save=false returns AgentPreviewResponse."""
|
||||
agent_json = {
|
||||
"name": "Customized Agent",
|
||||
"description": "A customized agent",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"block_id": "block-1",
|
||||
"input_default": {},
|
||||
"metadata": {"position": {"x": 0, "y": 0}},
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
}
|
||||
|
||||
mock_fixer = MagicMock()
|
||||
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
|
||||
mock_fixer.get_fixes_applied.return_value = []
|
||||
|
||||
mock_validator = MagicMock()
|
||||
mock_validator.validate.return_value = (True, None)
|
||||
mock_validator.errors = []
|
||||
|
||||
with (
|
||||
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
|
||||
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
|
||||
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json=agent_json,
|
||||
save=False,
|
||||
)
|
||||
|
||||
assert isinstance(result, AgentPreviewResponse)
|
||||
assert result.agent_name == "Customized Agent"
|
||||
assert result.node_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_mode_validation_failure(tool, session):
|
||||
"""Local mode returns ErrorResponse when validation fails."""
|
||||
agent_json = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"block_id": "bad-block",
|
||||
"input_default": {},
|
||||
"metadata": {},
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
}
|
||||
|
||||
mock_fixer = MagicMock()
|
||||
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
|
||||
mock_fixer.get_fixes_applied.return_value = []
|
||||
|
||||
mock_validator = MagicMock()
|
||||
mock_validator.validate.return_value = (False, "Block 'bad-block' not found")
|
||||
mock_validator.errors = ["Block 'bad-block' not found"]
|
||||
|
||||
with (
|
||||
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
|
||||
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
|
||||
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json=agent_json,
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "validation_failed"
|
||||
assert "Block 'bad-block' not found" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_mode_no_auth_returns_error(tool, session):
|
||||
"""Local mode with save=true and no user returns ErrorResponse."""
|
||||
agent_json = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"block_id": "block-1",
|
||||
"input_default": {},
|
||||
"metadata": {},
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
}
|
||||
|
||||
mock_fixer = MagicMock()
|
||||
mock_fixer.apply_all_fixes = MagicMock(return_value=agent_json)
|
||||
mock_fixer.get_fixes_applied.return_value = []
|
||||
|
||||
mock_validator = MagicMock()
|
||||
mock_validator.validate.return_value = (True, None)
|
||||
mock_validator.errors = []
|
||||
|
||||
with (
|
||||
patch(f"{_PIPELINE}.get_blocks_as_dicts", return_value=[]),
|
||||
patch(f"{_PIPELINE}.AgentFixer", return_value=mock_fixer),
|
||||
patch(f"{_PIPELINE}.AgentValidator", return_value=mock_validator),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=None,
|
||||
session=session,
|
||||
agent_json=agent_json,
|
||||
save=True,
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "logged in" in result.message.lower()
|
||||
@@ -10,34 +10,13 @@ Lifecycle
|
||||
---------
|
||||
1. **Turn start** – connect to the existing sandbox (sandbox_id in Redis) or
|
||||
create a new one via ``get_or_create_sandbox()``.
|
||||
``connect()`` in e2b v2 auto-resumes paused sandboxes.
|
||||
2. **Execution** – ``bash_exec`` and MCP file tools operate directly on the
|
||||
sandbox's ``/home/user`` filesystem.
|
||||
3. **Turn end** – the sandbox is paused via ``pause_sandbox()`` (fire-and-forget)
|
||||
so idle time between turns costs nothing. Paused sandboxes have no compute
|
||||
cost.
|
||||
4. **Session delete** – ``kill_sandbox()`` fully terminates the sandbox.
|
||||
|
||||
Cost control
|
||||
------------
|
||||
Sandboxes are created with a configurable ``on_timeout`` lifecycle action
|
||||
(default: ``"pause"``). The explicit per-turn ``pause_sandbox()`` call is the
|
||||
primary mechanism; the lifecycle setting is a safety net. Paused sandboxes are
|
||||
free.
|
||||
|
||||
The sandbox_id is stored in Redis. The same key doubles as a creation lock:
|
||||
a ``"creating"`` sentinel value is written with a short TTL while a new sandbox
|
||||
is being provisioned, preventing duplicate creation under concurrent requests.
|
||||
|
||||
E2B project-level "paused sandbox lifetime" should be set to match
|
||||
``_SANDBOX_ID_TTL`` (48 h) so orphaned paused sandboxes are auto-killed before
|
||||
the Redis key expires.
|
||||
3. **Session expiry** – E2B sandbox is killed by its own timeout (session_ttl).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Literal
|
||||
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
@@ -45,245 +24,147 @@ from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SANDBOX_KEY_PREFIX = "copilot:e2b:sandbox:"
|
||||
_CREATING_SENTINEL = "creating"
|
||||
|
||||
# Short TTL for the "creating" sentinel — if the process dies mid-creation the
|
||||
# lock auto-expires so other callers are not blocked forever.
|
||||
_CREATION_LOCK_TTL = 60 # seconds
|
||||
|
||||
_MAX_WAIT_ATTEMPTS = 20 # 20 × 0.5 s = 10 s max wait
|
||||
|
||||
# Timeout for E2B API calls (pause/kill) — short because these are control-plane
|
||||
# operations; if the sandbox is unreachable, fail fast and retry on the next turn.
|
||||
_E2B_API_TIMEOUT_SECONDS = 10
|
||||
|
||||
# Redis TTL for the sandbox key. Must be ≥ the E2B project "paused sandbox
|
||||
# lifetime" setting (recommended: set both to 48 h).
|
||||
_SANDBOX_ID_TTL = 48 * 3600 # 48 hours
|
||||
|
||||
|
||||
def _sandbox_key(session_id: str) -> str:
|
||||
return f"{_SANDBOX_KEY_PREFIX}{session_id}"
|
||||
|
||||
|
||||
async def _get_stored_sandbox_id(session_id: str) -> str | None:
|
||||
redis = await get_redis_async()
|
||||
raw = await redis.get(_sandbox_key(session_id))
|
||||
value = raw.decode() if isinstance(raw, bytes) else raw
|
||||
return None if value == _CREATING_SENTINEL else value
|
||||
|
||||
|
||||
async def _set_stored_sandbox_id(session_id: str, sandbox_id: str) -> None:
|
||||
redis = await get_redis_async()
|
||||
await redis.set(_sandbox_key(session_id), sandbox_id, ex=_SANDBOX_ID_TTL)
|
||||
|
||||
|
||||
async def _clear_stored_sandbox_id(session_id: str) -> None:
|
||||
redis = await get_redis_async()
|
||||
await redis.delete(_sandbox_key(session_id))
|
||||
_SANDBOX_REDIS_PREFIX = "copilot:e2b:sandbox:"
|
||||
E2B_WORKDIR = "/home/user"
|
||||
_CREATING = "__creating__"
|
||||
_CREATION_LOCK_TTL = 60
|
||||
_MAX_WAIT_ATTEMPTS = 20 # 20 * 0.5s = 10s max wait
|
||||
|
||||
|
||||
async def _try_reconnect(
|
||||
sandbox_id: str, session_id: str, api_key: str
|
||||
sandbox_id: str, api_key: str, redis_key: str, timeout: int
|
||||
) -> "AsyncSandbox | None":
|
||||
"""Try to reconnect to an existing sandbox. Returns None on failure."""
|
||||
try:
|
||||
sandbox = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
|
||||
if await sandbox.is_running():
|
||||
# Refresh TTL so an active session cannot lose its sandbox_id at expiry.
|
||||
await _set_stored_sandbox_id(session_id, sandbox_id)
|
||||
redis = await get_redis_async()
|
||||
await redis.expire(redis_key, timeout)
|
||||
return sandbox
|
||||
except Exception as exc:
|
||||
logger.warning("[E2B] Reconnect to %.12s failed: %s", sandbox_id, exc)
|
||||
|
||||
# Stale — clear the sandbox_id from Redis so a new one can be created.
|
||||
await _clear_stored_sandbox_id(session_id)
|
||||
# Stale — clear Redis so a new sandbox can be created.
|
||||
redis = await get_redis_async()
|
||||
await redis.delete(redis_key)
|
||||
return None
|
||||
|
||||
|
||||
async def get_or_create_sandbox(
|
||||
session_id: str,
|
||||
api_key: str,
|
||||
timeout: int,
|
||||
template: str = "base",
|
||||
on_timeout: Literal["kill", "pause"] = "pause",
|
||||
timeout: int = 43200,
|
||||
) -> AsyncSandbox:
|
||||
"""Return the existing E2B sandbox for *session_id* or create a new one.
|
||||
|
||||
The sandbox key in Redis serves a dual purpose: it stores the sandbox_id
|
||||
and acts as a creation lock via a ``"creating"`` sentinel value. This
|
||||
removes the need for a separate lock key.
|
||||
|
||||
*timeout* controls how long the e2b sandbox may run continuously before
|
||||
the ``on_timeout`` lifecycle rule fires (default: 3 h).
|
||||
*on_timeout* controls what happens on timeout: ``"pause"`` (default, free)
|
||||
or ``"kill"``.
|
||||
The sandbox_id is persisted in Redis so the same sandbox is reused
|
||||
across turns. Concurrent calls for the same session are serialised
|
||||
via a Redis ``SET NX`` creation lock.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
key = _sandbox_key(session_id)
|
||||
redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
|
||||
|
||||
for _ in range(_MAX_WAIT_ATTEMPTS):
|
||||
raw = await redis.get(key)
|
||||
value = raw.decode() if isinstance(raw, bytes) else raw
|
||||
|
||||
if value and value != _CREATING_SENTINEL:
|
||||
# Existing sandbox ID — try to reconnect (auto-resumes if paused).
|
||||
sandbox = await _try_reconnect(value, session_id, api_key)
|
||||
# 1. Try reconnecting to an existing sandbox.
|
||||
raw = await redis.get(redis_key)
|
||||
if raw:
|
||||
sandbox_id = raw if isinstance(raw, str) else raw.decode()
|
||||
if sandbox_id != _CREATING:
|
||||
sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
|
||||
if sandbox:
|
||||
logger.info(
|
||||
"[E2B] Reconnected to %.12s for session %.12s",
|
||||
value,
|
||||
sandbox_id,
|
||||
session_id,
|
||||
)
|
||||
return sandbox
|
||||
# _try_reconnect cleared the key — loop to create a new sandbox.
|
||||
continue
|
||||
|
||||
if value == _CREATING_SENTINEL:
|
||||
# Another coroutine is creating — wait for it to finish.
|
||||
# 2. Claim creation lock. If another request holds it, wait for the result.
|
||||
claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
|
||||
if not claimed:
|
||||
for _ in range(_MAX_WAIT_ATTEMPTS):
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
raw = await redis.get(redis_key)
|
||||
if not raw:
|
||||
break # Lock expired — fall through to retry creation
|
||||
sandbox_id = raw if isinstance(raw, str) else raw.decode()
|
||||
if sandbox_id != _CREATING:
|
||||
sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
|
||||
if sandbox:
|
||||
return sandbox
|
||||
break # Stale sandbox cleared — fall through to create
|
||||
|
||||
# No sandbox and no active creation — atomically claim the creation slot.
|
||||
claimed = await redis.set(
|
||||
key, _CREATING_SENTINEL, nx=True, ex=_CREATION_LOCK_TTL
|
||||
)
|
||||
# Try to claim creation lock again after waiting.
|
||||
claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
|
||||
if not claimed:
|
||||
# Race lost — another coroutine just claimed it.
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
# We hold the slot — create the sandbox.
|
||||
try:
|
||||
sandbox = await AsyncSandbox.create(
|
||||
template=template,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
lifecycle={"on_timeout": on_timeout},
|
||||
# Another process may have created a sandbox — try to use it.
|
||||
raw = await redis.get(redis_key)
|
||||
if raw:
|
||||
sandbox_id = raw if isinstance(raw, str) else raw.decode()
|
||||
if sandbox_id != _CREATING:
|
||||
sandbox = await _try_reconnect(
|
||||
sandbox_id, api_key, redis_key, timeout
|
||||
)
|
||||
if sandbox:
|
||||
return sandbox
|
||||
raise RuntimeError(
|
||||
f"Could not acquire E2B creation lock for session {session_id[:12]}"
|
||||
)
|
||||
try:
|
||||
await _set_stored_sandbox_id(session_id, sandbox.sandbox_id)
|
||||
except Exception:
|
||||
# Redis save failed — kill the sandbox to avoid leaking it.
|
||||
with contextlib.suppress(Exception):
|
||||
await sandbox.kill()
|
||||
raise
|
||||
except Exception:
|
||||
# Release the creation slot so other callers can proceed.
|
||||
await redis.delete(key)
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
"[E2B] Created sandbox %.12s for session %.12s",
|
||||
sandbox.sandbox_id,
|
||||
session_id,
|
||||
)
|
||||
return sandbox
|
||||
|
||||
raise RuntimeError(f"Could not acquire E2B sandbox for session {session_id[:12]}")
|
||||
|
||||
|
||||
async def _act_on_sandbox(
|
||||
session_id: str,
|
||||
api_key: str,
|
||||
action: str,
|
||||
fn: Callable[[AsyncSandbox], Awaitable[Any]],
|
||||
*,
|
||||
clear_stored_id: bool = False,
|
||||
) -> bool:
|
||||
"""Connect to the sandbox for *session_id* and run *fn* on it.
|
||||
|
||||
Shared by ``pause_sandbox`` and ``kill_sandbox``. Returns ``True`` on
|
||||
success, ``False`` when no sandbox is found or the action fails.
|
||||
If *clear_stored_id* is ``True``, the sandbox_id is removed from Redis
|
||||
only after the action succeeds so a failed kill can be retried.
|
||||
"""
|
||||
sandbox_id = await _get_stored_sandbox_id(session_id)
|
||||
if not sandbox_id:
|
||||
return False
|
||||
|
||||
async def _run() -> None:
|
||||
await fn(await AsyncSandbox.connect(sandbox_id, api_key=api_key))
|
||||
|
||||
# 3. Create a new sandbox.
|
||||
try:
|
||||
await asyncio.wait_for(_run(), timeout=_E2B_API_TIMEOUT_SECONDS)
|
||||
if clear_stored_id:
|
||||
await _clear_stored_sandbox_id(session_id)
|
||||
logger.info(
|
||||
"[E2B] %s sandbox %.12s for session %.12s",
|
||||
action.capitalize(),
|
||||
sandbox_id,
|
||||
session_id,
|
||||
sandbox = await AsyncSandbox.create(
|
||||
template=template, api_key=api_key, timeout=timeout
|
||||
)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[E2B] Failed to %s sandbox %.12s for session %.12s: %s",
|
||||
action,
|
||||
sandbox_id,
|
||||
session_id,
|
||||
exc,
|
||||
)
|
||||
return False
|
||||
except Exception:
|
||||
await redis.delete(redis_key)
|
||||
raise
|
||||
|
||||
await redis.setex(redis_key, timeout, sandbox.sandbox_id)
|
||||
logger.info(
|
||||
"[E2B] Created sandbox %.12s for session %.12s",
|
||||
sandbox.sandbox_id,
|
||||
session_id,
|
||||
)
|
||||
return sandbox
|
||||
|
||||
|
||||
async def pause_sandbox(session_id: str, api_key: str) -> bool:
|
||||
"""Pause the E2B sandbox for *session_id* to stop billing between turns.
|
||||
|
||||
Paused sandboxes cost nothing and are resumed automatically by
|
||||
``get_or_create_sandbox()`` on the next turn (via ``AsyncSandbox.connect()``).
|
||||
The sandbox_id is kept in Redis so reconnection works seamlessly.
|
||||
|
||||
Prefer ``pause_sandbox_direct()`` when the sandbox object is already in
|
||||
scope — it skips the Redis lookup and reconnect round-trip.
|
||||
|
||||
Returns ``True`` if the sandbox was found and paused, ``False`` otherwise.
|
||||
Safe to call even when no sandbox exists for the session.
|
||||
"""
|
||||
return await _act_on_sandbox(session_id, api_key, "pause", lambda sb: sb.pause())
|
||||
|
||||
|
||||
async def pause_sandbox_direct(sandbox: "AsyncSandbox", session_id: str) -> bool:
|
||||
"""Pause an already-connected sandbox without a reconnect round-trip.
|
||||
|
||||
Use this in callers that already hold the live sandbox object (e.g. turn
|
||||
teardown in ``service.py``). Saves the Redis lookup and
|
||||
``AsyncSandbox.connect()`` call that ``pause_sandbox()`` would make.
|
||||
|
||||
Returns ``True`` on success, ``False`` on failure or timeout.
|
||||
"""
|
||||
try:
|
||||
await asyncio.wait_for(sandbox.pause(), timeout=_E2B_API_TIMEOUT_SECONDS)
|
||||
logger.info(
|
||||
"[E2B] Paused sandbox %.12s for session %.12s",
|
||||
sandbox.sandbox_id,
|
||||
session_id,
|
||||
)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[E2B] Failed to pause sandbox %.12s for session %.12s: %s",
|
||||
sandbox.sandbox_id,
|
||||
session_id,
|
||||
exc,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def kill_sandbox(
|
||||
session_id: str,
|
||||
api_key: str,
|
||||
) -> bool:
|
||||
"""Kill the E2B sandbox for *session_id* and clear its Redis entry.
|
||||
async def kill_sandbox(session_id: str, api_key: str) -> bool:
|
||||
"""Kill the E2B sandbox for *session_id* and clean up its Redis entry.
|
||||
|
||||
Returns ``True`` if a sandbox was found and killed, ``False`` otherwise.
|
||||
Safe to call even when no sandbox exists for the session.
|
||||
"""
|
||||
return await _act_on_sandbox(
|
||||
session_id,
|
||||
api_key,
|
||||
"kill",
|
||||
lambda sb: sb.kill(),
|
||||
clear_stored_id=True,
|
||||
)
|
||||
redis = await get_redis_async()
|
||||
redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
|
||||
raw = await redis.get(redis_key)
|
||||
if not raw:
|
||||
return False
|
||||
|
||||
sandbox_id = raw if isinstance(raw, str) else raw.decode()
|
||||
await redis.delete(redis_key)
|
||||
|
||||
if sandbox_id == _CREATING:
|
||||
return False
|
||||
|
||||
try:
|
||||
|
||||
async def _connect_and_kill():
|
||||
sandbox = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
|
||||
await sandbox.kill()
|
||||
|
||||
await asyncio.wait_for(_connect_and_kill(), timeout=10)
|
||||
logger.info(
|
||||
"[E2B] Killed sandbox %.12s for session %.12s",
|
||||
sandbox_id,
|
||||
session_id,
|
||||
)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[E2B] Failed to kill sandbox %.12s for session %.12s: %s",
|
||||
sandbox_id,
|
||||
session_id,
|
||||
exc,
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -1,12 +1,6 @@
|
||||
"""Tests for e2b_sandbox: get_or_create_sandbox, _try_reconnect, kill_sandbox.
|
||||
|
||||
sandbox_id is stored in Redis under _SANDBOX_KEY_PREFIX + session_id.
|
||||
The same key doubles as a creation lock via a "creating" sentinel value.
|
||||
|
||||
Tests mock:
|
||||
- ``get_redis_async`` (sandbox key storage + creation lock sentinel)
|
||||
- ``AsyncSandbox`` (E2B SDK)
|
||||
|
||||
Uses mock Redis and mock AsyncSandbox — no external dependencies.
|
||||
Tests are synchronous (using asyncio.run) to avoid conflicts with the
|
||||
session-scoped event loop in conftest.py.
|
||||
"""
|
||||
@@ -17,50 +11,36 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from .e2b_sandbox import (
|
||||
_CREATING_SENTINEL,
|
||||
_CREATING,
|
||||
_SANDBOX_REDIS_PREFIX,
|
||||
_try_reconnect,
|
||||
get_or_create_sandbox,
|
||||
kill_sandbox,
|
||||
pause_sandbox,
|
||||
pause_sandbox_direct,
|
||||
)
|
||||
|
||||
_SESSION_ID = "sess-123"
|
||||
_KEY = f"{_SANDBOX_REDIS_PREFIX}sess-123"
|
||||
_API_KEY = "test-api-key"
|
||||
_SANDBOX_ID = "sb-abc"
|
||||
_TIMEOUT = 300
|
||||
|
||||
|
||||
def _mock_sandbox(sandbox_id: str = _SANDBOX_ID, running: bool = True) -> MagicMock:
|
||||
def _mock_sandbox(sandbox_id: str = "sb-abc", running: bool = True) -> MagicMock:
|
||||
sb = MagicMock()
|
||||
sb.sandbox_id = sandbox_id
|
||||
sb.is_running = AsyncMock(return_value=running)
|
||||
sb.pause = AsyncMock()
|
||||
sb.kill = AsyncMock()
|
||||
return sb
|
||||
|
||||
|
||||
def _mock_redis(
|
||||
set_nx_result: bool = True,
|
||||
stored_sandbox_id: str | None = None,
|
||||
) -> AsyncMock:
|
||||
"""Create a mock redis client.
|
||||
|
||||
*stored_sandbox_id* is returned by ``get()`` calls (simulates the sandbox_id
|
||||
stored under the ``_SANDBOX_KEY_PREFIX`` key). ``set_nx_result`` controls
|
||||
whether the creation-slot ``SET NX`` succeeds.
|
||||
|
||||
If *stored_sandbox_id* is None the key is absent (no sandbox, no lock).
|
||||
"""
|
||||
def _mock_redis(get_val: str | bytes | None = None, set_nx_result: bool = True):
|
||||
r = AsyncMock()
|
||||
raw = stored_sandbox_id.encode() if stored_sandbox_id else None
|
||||
r.get = AsyncMock(return_value=raw)
|
||||
r.get = AsyncMock(return_value=get_val)
|
||||
r.set = AsyncMock(return_value=set_nx_result)
|
||||
r.setex = AsyncMock()
|
||||
r.delete = AsyncMock()
|
||||
r.expire = AsyncMock()
|
||||
return r
|
||||
|
||||
|
||||
def _patch_redis(redis: AsyncMock):
|
||||
def _patch_redis(redis):
|
||||
return patch(
|
||||
"backend.copilot.tools.e2b_sandbox.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
@@ -75,7 +55,6 @@ def _patch_redis(redis: AsyncMock):
|
||||
|
||||
class TestTryReconnect:
|
||||
def test_reconnect_success(self):
|
||||
"""Returns the sandbox when it connects and is running; refreshes Redis TTL."""
|
||||
sb = _mock_sandbox()
|
||||
redis = _mock_redis()
|
||||
with (
|
||||
@@ -83,39 +62,36 @@ class TestTryReconnect:
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(_try_reconnect(_SANDBOX_ID, _SESSION_ID, _API_KEY))
|
||||
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
|
||||
|
||||
assert result is sb
|
||||
redis.expire.assert_awaited_once_with(_KEY, _TIMEOUT)
|
||||
redis.delete.assert_not_awaited()
|
||||
# TTL must be refreshed so an active session cannot lose its key at expiry.
|
||||
redis.set.assert_awaited_once()
|
||||
|
||||
def test_reconnect_not_running_clears_redis(self):
|
||||
"""Clears sandbox_id in Redis when the sandbox is no longer running."""
|
||||
def test_reconnect_not_running_clears_key(self):
|
||||
sb = _mock_sandbox(running=False)
|
||||
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
|
||||
redis = _mock_redis()
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(_try_reconnect(_SANDBOX_ID, _SESSION_ID, _API_KEY))
|
||||
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
|
||||
|
||||
assert result is None
|
||||
redis.delete.assert_awaited_once()
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
def test_reconnect_exception_clears_redis(self):
|
||||
"""Clears sandbox_id in Redis when connect raises an exception."""
|
||||
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
|
||||
def test_reconnect_exception_clears_key(self):
|
||||
redis = _mock_redis()
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
|
||||
result = asyncio.run(_try_reconnect(_SANDBOX_ID, _SESSION_ID, _API_KEY))
|
||||
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
|
||||
|
||||
assert result is None
|
||||
redis.delete.assert_awaited_once()
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -127,63 +103,38 @@ class TestGetOrCreateSandbox:
|
||||
def test_reconnect_existing(self):
|
||||
"""When Redis has a valid sandbox_id, reconnect to it."""
|
||||
sb = _mock_sandbox()
|
||||
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
|
||||
redis = _mock_redis(get_val="sb-abc")
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is sb
|
||||
mock_cls.create.assert_not_called()
|
||||
# redis.set called once to refresh TTL, not to claim a creation slot
|
||||
redis.set.assert_awaited_once()
|
||||
|
||||
def test_create_new_when_no_stored_id(self):
|
||||
"""When Redis has no sandbox_id, claim slot and create a new sandbox."""
|
||||
new_sb = _mock_sandbox("sb-new")
|
||||
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
|
||||
def test_create_new_when_no_key(self):
|
||||
"""When Redis is empty, claim lock and create a new sandbox."""
|
||||
sb = _mock_sandbox("sb-new")
|
||||
redis = _mock_redis(get_val=None, set_nx_result=True)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.create = AsyncMock(return_value=new_sb)
|
||||
mock_cls.create = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is new_sb
|
||||
mock_cls.create.assert_awaited_once()
|
||||
# Verify lifecycle param is set
|
||||
_, kwargs = mock_cls.create.call_args
|
||||
assert kwargs.get("lifecycle") == {"on_timeout": "pause"}
|
||||
# sandbox_id should be saved to Redis
|
||||
redis.set.assert_awaited()
|
||||
assert result is sb
|
||||
redis.setex.assert_awaited_once_with(_KEY, _TIMEOUT, "sb-new")
|
||||
|
||||
def test_create_with_on_timeout_kill(self):
|
||||
"""on_timeout='kill' is passed through to AsyncSandbox.create."""
|
||||
new_sb = _mock_sandbox("sb-new")
|
||||
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.create = AsyncMock(return_value=new_sb)
|
||||
asyncio.run(
|
||||
get_or_create_sandbox(
|
||||
_SESSION_ID, _API_KEY, timeout=_TIMEOUT, on_timeout="kill"
|
||||
)
|
||||
)
|
||||
|
||||
_, kwargs = mock_cls.create.call_args
|
||||
assert kwargs.get("lifecycle") == {"on_timeout": "kill"}
|
||||
|
||||
def test_create_failure_releases_slot(self):
|
||||
"""If sandbox creation fails, the Redis creation slot is deleted."""
|
||||
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
|
||||
def test_create_failure_clears_lock(self):
|
||||
"""If sandbox creation fails, the Redis lock is deleted."""
|
||||
redis = _mock_redis(get_val=None, set_nx_result=True)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
@@ -191,53 +142,17 @@ class TestGetOrCreateSandbox:
|
||||
mock_cls.create = AsyncMock(side_effect=RuntimeError("quota"))
|
||||
with pytest.raises(RuntimeError, match="quota"):
|
||||
asyncio.run(
|
||||
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
redis.delete.assert_awaited_once()
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
def test_redis_save_failure_kills_sandbox_and_releases_slot(self):
|
||||
"""If Redis save fails after creation, sandbox is killed and slot released."""
|
||||
new_sb = _mock_sandbox("sb-new")
|
||||
redis = _mock_redis(set_nx_result=True, stored_sandbox_id=None)
|
||||
# First set() call = creation slot SET NX (returns True).
|
||||
# Second set() call = sandbox_id save (raises).
|
||||
call_count = 0
|
||||
|
||||
async def _set_side_effect(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return True # creation slot claimed
|
||||
raise RuntimeError("redis error")
|
||||
|
||||
redis.set = AsyncMock(side_effect=_set_side_effect)
|
||||
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.create = AsyncMock(return_value=new_sb)
|
||||
with pytest.raises(RuntimeError, match="redis error"):
|
||||
asyncio.run(
|
||||
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
# Sandbox must be killed to avoid leaking it
|
||||
new_sb.kill.assert_awaited_once()
|
||||
# Creation slot must always be released
|
||||
redis.delete.assert_awaited_once()
|
||||
|
||||
def test_wait_for_creating_sentinel_then_reconnect(self):
|
||||
"""When the key holds the 'creating' sentinel, wait then reconnect."""
|
||||
def test_wait_for_lock_then_reconnect(self):
|
||||
"""When another process holds the lock, wait and reconnect."""
|
||||
sb = _mock_sandbox("sb-other")
|
||||
# First get() returns the sentinel; second returns the real ID.
|
||||
redis = AsyncMock()
|
||||
creating_raw = _CREATING_SENTINEL.encode()
|
||||
redis.get = AsyncMock(side_effect=[creating_raw, b"sb-other"])
|
||||
redis = _mock_redis()
|
||||
redis.get = AsyncMock(side_effect=[_CREATING, "sb-other"])
|
||||
redis.set = AsyncMock(return_value=False)
|
||||
redis.delete = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
@@ -248,21 +163,16 @@ class TestGetOrCreateSandbox:
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is sb
|
||||
|
||||
def test_stale_reconnect_clears_and_creates(self):
|
||||
"""When stored sandbox is stale (not running), clear it and create a new one."""
|
||||
"""When stored sandbox is stale, clear key and create a new one."""
|
||||
stale_sb = _mock_sandbox("sb-stale", running=False)
|
||||
new_sb = _mock_sandbox("sb-fresh")
|
||||
# First get() returns stale id (for reconnect check), then None (after clear).
|
||||
redis = AsyncMock()
|
||||
redis.get = AsyncMock(side_effect=[b"sb-stale", None])
|
||||
redis.set = AsyncMock(return_value=True)
|
||||
redis.delete = AsyncMock()
|
||||
|
||||
redis = _mock_redis(get_val="sb-stale", set_nx_result=True)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
@@ -270,11 +180,10 @@ class TestGetOrCreateSandbox:
|
||||
mock_cls.connect = AsyncMock(return_value=stale_sb)
|
||||
mock_cls.create = AsyncMock(return_value=new_sb)
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is new_sb
|
||||
# Redis delete called at least once to clear stale id
|
||||
redis.delete.assert_awaited()
|
||||
|
||||
|
||||
@@ -285,48 +194,70 @@ class TestGetOrCreateSandbox:
|
||||
|
||||
class TestKillSandbox:
|
||||
def test_kill_existing_sandbox(self):
|
||||
"""Kill a running sandbox and clear its Redis entry."""
|
||||
"""Kill a running sandbox and clean up Redis."""
|
||||
sb = _mock_sandbox()
|
||||
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
|
||||
sb.kill = AsyncMock()
|
||||
redis = _mock_redis(get_val="sb-abc")
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is True
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
sb.kill.assert_awaited_once()
|
||||
|
||||
def test_kill_no_sandbox(self):
|
||||
"""No-op when no sandbox exists in Redis."""
|
||||
redis = _mock_redis(get_val=None)
|
||||
with _patch_redis(redis):
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_not_awaited()
|
||||
|
||||
def test_kill_creating_state(self):
|
||||
"""Clears Redis key but returns False when sandbox is still being created."""
|
||||
redis = _mock_redis(get_val=_CREATING)
|
||||
with _patch_redis(redis):
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
def test_kill_connect_failure(self):
|
||||
"""Returns False and cleans Redis if connect/kill fails."""
|
||||
redis = _mock_redis(get_val="sb-abc")
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
def test_kill_with_bytes_redis_value(self):
|
||||
"""Redis may return bytes — kill_sandbox should decode correctly."""
|
||||
sb = _mock_sandbox()
|
||||
sb.kill = AsyncMock()
|
||||
redis = _mock_redis(get_val=b"sb-abc")
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is True
|
||||
sb.kill.assert_awaited_once()
|
||||
# Redis key cleared after successful kill
|
||||
redis.delete.assert_awaited_once()
|
||||
|
||||
def test_kill_no_sandbox(self):
|
||||
"""No-op when Redis has no sandbox_id."""
|
||||
redis = _mock_redis(stored_sandbox_id=None)
|
||||
with _patch_redis(redis):
|
||||
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_kill_connect_failure_keeps_redis(self):
|
||||
"""Returns False and leaves Redis entry intact when connect/kill fails.
|
||||
|
||||
Keeping the sandbox_id in Redis allows the kill to be retried.
|
||||
"""
|
||||
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
|
||||
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_not_awaited()
|
||||
|
||||
def test_kill_timeout_keeps_redis(self):
|
||||
"""Returns False and leaves Redis entry intact when the E2B call times out."""
|
||||
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
|
||||
def test_kill_timeout_returns_false(self):
|
||||
"""Returns False when E2B API calls exceed the 10s timeout."""
|
||||
redis = _mock_redis(get_val="sb-abc")
|
||||
with (
|
||||
_patch_redis(redis),
|
||||
patch(
|
||||
@@ -335,146 +266,7 @@ class TestKillSandbox:
|
||||
side_effect=asyncio.TimeoutError,
|
||||
),
|
||||
):
|
||||
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_not_awaited()
|
||||
|
||||
def test_kill_creating_sentinel_returns_false(self):
|
||||
"""No-op when the key holds the 'creating' sentinel (no real sandbox yet)."""
|
||||
redis = _mock_redis(stored_sandbox_id=_CREATING_SENTINEL)
|
||||
with _patch_redis(redis):
|
||||
result = asyncio.run(kill_sandbox(_SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# pause_sandbox
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPauseSandbox:
|
||||
def test_pause_existing_sandbox(self):
|
||||
"""Pause a running sandbox; Redis sandbox_id is preserved."""
|
||||
sb = _mock_sandbox()
|
||||
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is True
|
||||
sb.pause.assert_awaited_once()
|
||||
# sandbox_id should remain in Redis (not cleared on pause)
|
||||
redis.delete.assert_not_awaited()
|
||||
|
||||
def test_pause_no_sandbox(self):
|
||||
"""No-op when Redis has no sandbox_id."""
|
||||
redis = _mock_redis(stored_sandbox_id=None)
|
||||
with _patch_redis(redis):
|
||||
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_pause_connect_failure(self):
|
||||
"""Returns False if connect fails."""
|
||||
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
|
||||
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_pause_creating_sentinel_returns_false(self):
|
||||
"""No-op when the key holds the 'creating' sentinel (no real sandbox yet)."""
|
||||
redis = _mock_redis(stored_sandbox_id=_CREATING_SENTINEL)
|
||||
with _patch_redis(redis):
|
||||
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_pause_timeout_returns_false(self):
|
||||
"""Returns False and preserves Redis entry when the E2B API call times out."""
|
||||
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
|
||||
with (
|
||||
_patch_redis(redis),
|
||||
patch(
|
||||
"backend.copilot.tools.e2b_sandbox.asyncio.wait_for",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=asyncio.TimeoutError,
|
||||
),
|
||||
):
|
||||
result = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
|
||||
|
||||
assert result is False
|
||||
# sandbox_id must remain in Redis so the next turn can reconnect
|
||||
redis.delete.assert_not_awaited()
|
||||
|
||||
def test_pause_then_reconnect_reuses_sandbox(self):
|
||||
"""After pause, get_or_create_sandbox reconnects the same sandbox.
|
||||
|
||||
Covers the pause->reconnect cycle: connect() auto-resumes a paused
|
||||
sandbox, and is_running() returns True once resume completes, so the
|
||||
same sandbox_id is reused rather than a new one being created.
|
||||
"""
|
||||
sb = _mock_sandbox(_SANDBOX_ID)
|
||||
redis = _mock_redis(stored_sandbox_id=_SANDBOX_ID)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
|
||||
# Step 1: pause the sandbox
|
||||
paused = asyncio.run(pause_sandbox(_SESSION_ID, _API_KEY))
|
||||
assert paused is True
|
||||
sb.pause.assert_awaited_once()
|
||||
|
||||
# Step 2: reconnect on next turn -- same sandbox should be returned
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox(_SESSION_ID, _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is sb
|
||||
mock_cls.create.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# pause_sandbox_direct
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPauseSandboxDirect:
|
||||
def test_pause_direct_success(self):
|
||||
"""Pauses the sandbox directly without a Redis lookup or reconnect."""
|
||||
sb = _mock_sandbox()
|
||||
result = asyncio.run(pause_sandbox_direct(sb, _SESSION_ID))
|
||||
|
||||
assert result is True
|
||||
sb.pause.assert_awaited_once()
|
||||
|
||||
def test_pause_direct_failure_returns_false(self):
|
||||
"""Returns False when sandbox.pause() raises."""
|
||||
sb = _mock_sandbox()
|
||||
sb.pause = AsyncMock(side_effect=RuntimeError("e2b error"))
|
||||
result = asyncio.run(pause_sandbox_direct(sb, _SESSION_ID))
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_pause_direct_timeout_returns_false(self):
|
||||
"""Returns False when sandbox.pause() exceeds the 10s timeout."""
|
||||
sb = _mock_sandbox()
|
||||
with patch(
|
||||
"backend.copilot.tools.e2b_sandbox.asyncio.wait_for",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=asyncio.TimeoutError,
|
||||
):
|
||||
result = asyncio.run(pause_sandbox_direct(sb, _SESSION_ID))
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
@@ -1,20 +1,32 @@
|
||||
"""EditAgentTool - Edits existing agents using pre-built JSON."""
|
||||
"""EditAgentTool - Edits existing agents using natural language."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .agent_generator import get_agent_as_json
|
||||
from .agent_generator.pipeline import fetch_library_agents, fix_validate_and_save
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
get_user_message_for_error,
|
||||
save_agent_to_library,
|
||||
)
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ToolResponseBase
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EditAgentTool(BaseTool):
|
||||
"""Tool for editing existing agents using pre-built JSON."""
|
||||
"""Tool for editing existing agents using natural language."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -23,12 +35,11 @@ class EditAgentTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit an existing agent. Pass `agent_json` with the complete "
|
||||
"updated agent JSON you generated. The tool validates, auto-fixes, "
|
||||
"and saves.\n\n"
|
||||
"IMPORTANT: Before calling this tool, if the changes involve adding new "
|
||||
"Edit an existing agent from the user's library using natural language. "
|
||||
"Generates updates to the agent while preserving unchanged parts. "
|
||||
"\n\nIMPORTANT: Before calling this tool, if the changes involve adding new "
|
||||
"functionality, search for relevant existing agents using find_library_agent "
|
||||
"that could be used as building blocks."
|
||||
"that could be used as building blocks. Pass their IDs in library_agent_ids."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -47,20 +58,26 @@ class EditAgentTool(BaseTool):
|
||||
"Can be a graph ID or library agent ID."
|
||||
),
|
||||
},
|
||||
"agent_json": {
|
||||
"type": "object",
|
||||
"changes": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Complete updated agent JSON to validate and save. "
|
||||
"Must contain 'nodes' and 'links'. "
|
||||
"Include 'name' and/or 'description' if they need "
|
||||
"to be updated."
|
||||
"Natural language description of what changes to make. "
|
||||
"Be specific about what to add, remove, or modify."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Additional context or answers to previous clarifying questions."
|
||||
),
|
||||
},
|
||||
"library_agent_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"List of library agent IDs to use as building blocks for the changes."
|
||||
"List of library agent IDs to use as building blocks for the changes. "
|
||||
"If adding new functionality, search for relevant agents using "
|
||||
"find_library_agent first, then pass their IDs here."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
@@ -72,7 +89,7 @@ class EditAgentTool(BaseTool):
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
"required": ["agent_id", "agent_json"],
|
||||
"required": ["agent_id", "changes"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
@@ -81,39 +98,36 @@ class EditAgentTool(BaseTool):
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the edit_agent tool.
|
||||
|
||||
Flow:
|
||||
1. Fetch the current agent
|
||||
2. Generate updated agent (external service handles fixing and validation)
|
||||
3. Preview or save based on the save parameter
|
||||
"""
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
agent_json: dict[str, Any] | None = kwargs.get("agent_json")
|
||||
changes = kwargs.get("changes", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
library_agent_ids = kwargs.get("library_agent_ids", [])
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the agent ID to edit.",
|
||||
error="missing_agent_id",
|
||||
error="Missing agent_id parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not agent_json:
|
||||
if not changes:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Please provide agent_json with the complete updated agent graph."
|
||||
),
|
||||
error="missing_agent_json",
|
||||
message="Please describe what changes you want to make.",
|
||||
error="Missing changes parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
save = kwargs.get("save", True)
|
||||
library_agent_ids = kwargs.get("library_agent_ids", [])
|
||||
|
||||
nodes = agent_json.get("nodes", [])
|
||||
if not nodes:
|
||||
return ErrorResponse(
|
||||
message="The agent JSON has no nodes.",
|
||||
error="empty_agent",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Preserve original agent's ID
|
||||
current_agent = await get_agent_as_json(agent_id, user_id)
|
||||
|
||||
if current_agent is None:
|
||||
return ErrorResponse(
|
||||
message=f"Could not find agent with ID '{agent_id}' in your library.",
|
||||
@@ -121,19 +135,142 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agent_json["id"] = current_agent.get("id", agent_id)
|
||||
agent_json["version"] = current_agent.get("version", 1)
|
||||
agent_json.setdefault("is_active", True)
|
||||
# Fetch library agents by IDs if provided
|
||||
library_agents = None
|
||||
if user_id and library_agent_ids:
|
||||
try:
|
||||
from .agent_generator import get_library_agents_by_ids
|
||||
|
||||
# Fetch library agents for AgentExecutorBlock validation
|
||||
library_agents = await fetch_library_agents(user_id, library_agent_ids)
|
||||
graph_id = current_agent.get("id")
|
||||
# Filter out the current agent being edited
|
||||
filtered_ids = [id for id in library_agent_ids if id != graph_id]
|
||||
|
||||
return await fix_validate_and_save(
|
||||
agent_json,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
save=save,
|
||||
is_update=True,
|
||||
default_name="Updated Agent",
|
||||
library_agents=library_agents,
|
||||
)
|
||||
library_agents = await get_library_agents_by_ids(
|
||||
user_id=user_id,
|
||||
agent_ids=filtered_ids,
|
||||
)
|
||||
logger.debug(
|
||||
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agents by IDs: {e}")
|
||||
|
||||
update_request = changes
|
||||
if context:
|
||||
update_request = f"{changes}\n\nAdditional context:\n{context}"
|
||||
|
||||
try:
|
||||
result = await generate_agent_patch(
|
||||
update_request,
|
||||
current_agent,
|
||||
library_agents,
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent editing is not available. "
|
||||
"The Agent Generator service is not configured."
|
||||
),
|
||||
error="service_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.",
|
||||
error="update_generation_failed",
|
||||
details={"agent_id": agent_id, "changes": changes[:100]},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if the result is an error from the external service
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
error_type = result.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="generate the changes",
|
||||
llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
|
||||
validation_message="The generated changes failed validation. Please try rephrasing your request.",
|
||||
error_details=error_msg,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"update_generation_failed:{error_type}",
|
||||
details={
|
||||
"agent_id": agent_id,
|
||||
"changes": changes[:100],
|
||||
"service_error": error_msg,
|
||||
"error_type": error_type,
|
||||
},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if result.get("type") == "clarifying_questions":
|
||||
questions = result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information about the changes. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
updated_agent = result
|
||||
|
||||
agent_name = updated_agent.get("name", "Updated Agent")
|
||||
agent_description = updated_agent.get("description", "")
|
||||
node_count = len(updated_agent.get("nodes", []))
|
||||
link_count = len(updated_agent.get("links", []))
|
||||
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've updated the agent. "
|
||||
f"The agent now has {node_count} blocks. "
|
||||
f"Review it and call edit_agent with save=true to save the changes."
|
||||
),
|
||||
agent_json=updated_agent,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
updated_agent, user_id, is_update=True
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
message=f"Updated agent '{created_graph.name}' has been saved to your library!",
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to save the updated agent: {str(e)}",
|
||||
error="save_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -32,7 +32,7 @@ COPILOT_EXCLUDED_BLOCK_TYPES = {
|
||||
BlockType.NOTE, # Visual annotation only - no runtime behavior
|
||||
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
|
||||
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
|
||||
BlockType.MCP_TOOL, # Has dedicated run_mcp_tool tool with discovery + auth flow
|
||||
BlockType.MCP_TOOL, # Has dedicated run_mcp_tool tool with proper discovery + auth flow
|
||||
}
|
||||
|
||||
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
|
||||
@@ -72,15 +72,6 @@ class FindBlockTool(BaseTool):
|
||||
"Use keywords like 'email', 'http', 'text', 'ai', etc."
|
||||
),
|
||||
},
|
||||
"include_schemas": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"If true, include full input_schema and output_schema "
|
||||
"for each block. Use when generating agent JSON that "
|
||||
"needs block schemas. Default is false."
|
||||
),
|
||||
"default": False,
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
@@ -108,7 +99,6 @@ class FindBlockTool(BaseTool):
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
include_schemas = kwargs.get("include_schemas", False)
|
||||
session_id = session.session_id
|
||||
|
||||
if not query:
|
||||
@@ -153,21 +143,15 @@ class FindBlockTool(BaseTool):
|
||||
):
|
||||
continue
|
||||
|
||||
summary = BlockInfoSummary(
|
||||
id=block_id,
|
||||
name=block.name,
|
||||
description=block.optimized_description or block.description or "",
|
||||
categories=[c.value for c in block.categories],
|
||||
blocks.append(
|
||||
BlockInfoSummary(
|
||||
id=block_id,
|
||||
name=block.name,
|
||||
description=block.description or "",
|
||||
categories=[c.value for c in block.categories],
|
||||
)
|
||||
)
|
||||
|
||||
if include_schemas:
|
||||
info = block.get_info()
|
||||
summary.input_schema = info.inputSchema
|
||||
summary.output_schema = info.outputSchema
|
||||
summary.static_output = info.staticOutput
|
||||
|
||||
blocks.append(summary)
|
||||
|
||||
if len(blocks) >= _TARGET_RESULTS:
|
||||
break
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ def make_mock_block(
|
||||
input_schema: dict | None = None,
|
||||
output_schema: dict | None = None,
|
||||
credentials_fields: dict | None = None,
|
||||
static_output: bool = False,
|
||||
):
|
||||
"""Create a mock block for testing."""
|
||||
mock = MagicMock()
|
||||
@@ -34,7 +33,6 @@ def make_mock_block(
|
||||
mock.description = f"{name} description"
|
||||
mock.block_type = block_type
|
||||
mock.disabled = disabled
|
||||
mock.static_output = static_output
|
||||
mock.input_schema = MagicMock()
|
||||
mock.input_schema.jsonschema.return_value = input_schema or {
|
||||
"properties": {},
|
||||
@@ -44,15 +42,6 @@ def make_mock_block(
|
||||
mock.output_schema = MagicMock()
|
||||
mock.output_schema.jsonschema.return_value = output_schema or {}
|
||||
mock.categories = []
|
||||
mock.optimized_description = None
|
||||
|
||||
# Mock get_info() for include_schemas support
|
||||
mock_info = MagicMock()
|
||||
mock_info.inputSchema = input_schema or {"properties": {}, "required": []}
|
||||
mock_info.outputSchema = output_schema or {}
|
||||
mock_info.staticOutput = static_output
|
||||
mock.get_info.return_value = mock_info
|
||||
|
||||
return mock
|
||||
|
||||
|
||||
@@ -410,92 +399,3 @@ class TestFindBlockFiltering:
|
||||
f"Average chars per block ({avg_chars}) exceeds 500. "
|
||||
f"Total response: {total_chars} chars for {response.count} blocks."
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_include_schemas_false_omits_schemas(self):
|
||||
"""Without include_schemas, schemas should be empty dicts."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
input_schema = {"properties": {"url": {"type": "string"}}, "required": ["url"]}
|
||||
output_schema = {"properties": {"result": {"type": "string"}}}
|
||||
|
||||
search_results = [{"content_id": "block-1", "score": 0.9}]
|
||||
block = make_mock_block(
|
||||
"block-1",
|
||||
"Test Block",
|
||||
BlockType.STANDARD,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
)
|
||||
|
||||
mock_search_db = MagicMock()
|
||||
mock_search_db.unified_hybrid_search = AsyncMock(
|
||||
return_value=(search_results, 1)
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.find_block.search",
|
||||
return_value=mock_search_db,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
return_value=block,
|
||||
),
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
query="test",
|
||||
include_schemas=False,
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockListResponse)
|
||||
assert response.blocks[0].input_schema == {}
|
||||
assert response.blocks[0].output_schema == {}
|
||||
assert response.blocks[0].static_output is False
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_include_schemas_true_populates_schemas(self):
|
||||
"""With include_schemas=true, schemas should be populated from block info."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
input_schema = {"properties": {"url": {"type": "string"}}, "required": ["url"]}
|
||||
output_schema = {"properties": {"result": {"type": "string"}}}
|
||||
|
||||
search_results = [{"content_id": "block-1", "score": 0.9}]
|
||||
block = make_mock_block(
|
||||
"block-1",
|
||||
"Test Block",
|
||||
BlockType.STANDARD,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
mock_search_db = MagicMock()
|
||||
mock_search_db.unified_hybrid_search = AsyncMock(
|
||||
return_value=(search_results, 1)
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.find_block.search",
|
||||
return_value=mock_search_db,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
return_value=block,
|
||||
),
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
query="test",
|
||||
include_schemas=True,
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockListResponse)
|
||||
assert response.blocks[0].input_schema == input_schema
|
||||
assert response.blocks[0].output_schema == output_schema
|
||||
assert response.blocks[0].static_output is True
|
||||
|
||||
@@ -22,9 +22,6 @@ class FindLibraryAgentTool(BaseTool):
|
||||
"Search for or list agents in the user's library. Use this to find "
|
||||
"agents the user has already added to their library, including agents "
|
||||
"they created or added from the marketplace. "
|
||||
"When creating agents with sub-agent composition, use this to get "
|
||||
"the agent's graph_id, graph_version, input_schema, and output_schema "
|
||||
"needed for AgentExecutorBlock nodes. "
|
||||
"Omit the query to list all agents."
|
||||
)
|
||||
|
||||
|
||||
@@ -1,134 +0,0 @@
|
||||
"""FixAgentGraphTool - Auto-fixes common agent JSON issues."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .agent_generator.validation import AgentFixer, AgentValidator, get_blocks_as_dicts
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, FixResultResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FixAgentGraphTool(BaseTool):
|
||||
"""Tool for auto-fixing common issues in agent JSON graphs."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "fix_agent_graph"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Auto-fix common issues in an agent JSON graph. Applies fixes for:\n"
|
||||
"- Missing or invalid UUIDs on nodes and links\n"
|
||||
"- StoreValueBlock prerequisites for ConditionBlock\n"
|
||||
"- Double curly brace escaping in prompt templates\n"
|
||||
"- AddToList/AddToDictionary prerequisite blocks\n"
|
||||
"- CodeExecutionBlock output field naming\n"
|
||||
"- Missing credentials configuration\n"
|
||||
"- Node X coordinate spacing (800+ units apart)\n"
|
||||
"- AI model default parameters\n"
|
||||
"- Link static properties based on input schema\n"
|
||||
"- Type mismatches (inserts conversion blocks)\n\n"
|
||||
"Returns the fixed agent JSON plus a list of fixes applied. "
|
||||
"After fixing, the agent is re-validated. If still invalid, "
|
||||
"the remaining errors are included in the response."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_json": {
|
||||
"type": "object",
|
||||
"description": (
|
||||
"The agent JSON to fix. Must contain 'nodes' and 'links' arrays."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["agent_json"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
agent_json = kwargs.get("agent_json")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_json or not isinstance(agent_json, dict):
|
||||
return ErrorResponse(
|
||||
message="Please provide a valid agent JSON object.",
|
||||
error="Missing or invalid agent_json parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
nodes = agent_json.get("nodes", [])
|
||||
|
||||
if not nodes:
|
||||
return ErrorResponse(
|
||||
message="The agent JSON has no nodes. An agent needs at least one block.",
|
||||
error="empty_agent",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
blocks = get_blocks_as_dicts()
|
||||
fixer = AgentFixer()
|
||||
fixed_agent = fixer.apply_all_fixes(agent_json, blocks)
|
||||
fixes_applied = fixer.get_fixes_applied()
|
||||
except Exception as e:
|
||||
logger.error(f"Fixer error: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Auto-fix encountered an error: {str(e)}",
|
||||
error="fix_exception",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Re-validate after fixing
|
||||
try:
|
||||
validator = AgentValidator()
|
||||
is_valid, _ = validator.validate(fixed_agent, blocks)
|
||||
remaining_errors = validator.errors if not is_valid else []
|
||||
except Exception as e:
|
||||
logger.warning(f"Post-fix validation error: {e}", exc_info=True)
|
||||
remaining_errors = [f"Post-fix validation failed: {str(e)}"]
|
||||
is_valid = False
|
||||
|
||||
if is_valid:
|
||||
return FixResultResponse(
|
||||
message=(
|
||||
f"Applied {len(fixes_applied)} fix(es). "
|
||||
"Agent graph is now valid!"
|
||||
),
|
||||
fixed_agent_json=fixed_agent,
|
||||
fixes_applied=fixes_applied,
|
||||
fix_count=len(fixes_applied),
|
||||
valid_after_fix=True,
|
||||
remaining_errors=[],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return FixResultResponse(
|
||||
message=(
|
||||
f"Applied {len(fixes_applied)} fix(es), but "
|
||||
f"{len(remaining_errors)} issue(s) remain. "
|
||||
"Review the remaining errors and fix manually."
|
||||
),
|
||||
fixed_agent_json=fixed_agent,
|
||||
fixes_applied=fixes_applied,
|
||||
fix_count=len(fixes_applied),
|
||||
valid_after_fix=False,
|
||||
remaining_errors=remaining_errors,
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -1,189 +0,0 @@
|
||||
"""Tests for FixAgentGraphTool."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.tools.fix_agent import FixAgentGraphTool
|
||||
from backend.copilot.tools.models import ErrorResponse, FixResultResponse
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
_TEST_USER_ID = "test-user-fix-agent"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return FixAgentGraphTool()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session():
|
||||
return make_session(_TEST_USER_ID)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_agent_json_returns_error(tool, session):
|
||||
"""Missing agent_json returns ErrorResponse."""
|
||||
result = await tool._execute(user_id=_TEST_USER_ID, session=session)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error is not None
|
||||
assert "agent_json" in result.error.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_nodes_returns_error(tool, session):
|
||||
"""Agent JSON with no nodes returns ErrorResponse."""
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json={"nodes": [], "links": []},
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "no nodes" in result.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fix_and_validate_success(tool, session):
|
||||
"""Fixer applies fixes and validator passes -> valid_after_fix=True."""
|
||||
agent_json = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"block_id": "block-1",
|
||||
"input_default": {},
|
||||
"metadata": {"position": {"x": 0, "y": 0}},
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
}
|
||||
|
||||
fixed_agent = dict(agent_json) # Fixer returns the full agent dict
|
||||
|
||||
mock_fixer = MagicMock()
|
||||
mock_fixer.apply_all_fixes = MagicMock(return_value=fixed_agent)
|
||||
mock_fixer.get_fixes_applied.return_value = ["Fixed node UUID format"]
|
||||
|
||||
mock_validator = MagicMock()
|
||||
mock_validator.validate.return_value = (True, None)
|
||||
mock_validator.errors = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.fix_agent.get_blocks_as_dicts",
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.fix_agent.AgentFixer",
|
||||
return_value=mock_fixer,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.fix_agent.AgentValidator",
|
||||
return_value=mock_validator,
|
||||
),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json=agent_json,
|
||||
)
|
||||
|
||||
assert isinstance(result, FixResultResponse)
|
||||
assert result.valid_after_fix is True
|
||||
assert result.fix_count == 1
|
||||
assert result.fixes_applied == ["Fixed node UUID format"]
|
||||
assert result.remaining_errors == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fix_with_remaining_errors(tool, session):
|
||||
"""Fixer applies some fixes but validation still fails."""
|
||||
agent_json = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node-1",
|
||||
"block_id": "block-1",
|
||||
"input_default": {},
|
||||
"metadata": {},
|
||||
}
|
||||
],
|
||||
"links": [
|
||||
{
|
||||
"id": "link-1",
|
||||
"source_id": "node-1",
|
||||
"source_name": "output",
|
||||
"sink_id": "node-2",
|
||||
"sink_name": "input",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
fixed_agent = dict(agent_json)
|
||||
|
||||
mock_fixer = MagicMock()
|
||||
mock_fixer.apply_all_fixes = MagicMock(return_value=fixed_agent)
|
||||
mock_fixer.get_fixes_applied.return_value = ["Fixed UUID"]
|
||||
|
||||
mock_validator = MagicMock()
|
||||
mock_validator.validate.return_value = (
|
||||
False,
|
||||
"Link references non-existent node 'node-2'",
|
||||
)
|
||||
mock_validator.errors = ["Link references non-existent node 'node-2'"]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.fix_agent.get_blocks_as_dicts",
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.fix_agent.AgentFixer",
|
||||
return_value=mock_fixer,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.fix_agent.AgentValidator",
|
||||
return_value=mock_validator,
|
||||
),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json=agent_json,
|
||||
)
|
||||
|
||||
assert isinstance(result, FixResultResponse)
|
||||
assert result.valid_after_fix is False
|
||||
assert result.fix_count == 1
|
||||
assert len(result.remaining_errors) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fixer_exception_returns_error(tool, session):
|
||||
"""Fixer exception returns ErrorResponse."""
|
||||
agent_json = {
|
||||
"nodes": [{"id": "n1", "block_id": "b1", "input_default": {}, "metadata": {}}],
|
||||
"links": [],
|
||||
}
|
||||
|
||||
mock_fixer = MagicMock()
|
||||
mock_fixer.apply_all_fixes = MagicMock(side_effect=RuntimeError("fixer crashed"))
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.fix_agent.get_blocks_as_dicts",
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.fix_agent.AgentFixer",
|
||||
return_value=mock_fixer,
|
||||
),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_json=agent_json,
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error is not None
|
||||
assert "fix_exception" in result.error
|
||||
@@ -1,84 +0,0 @@
|
||||
"""GetAgentBuildingGuideTool - Returns the complete agent building guide."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_GUIDE_CACHE: str | None = None
|
||||
|
||||
|
||||
def _load_guide() -> str:
|
||||
global _GUIDE_CACHE
|
||||
if _GUIDE_CACHE is None:
|
||||
guide_path = Path(__file__).parent.parent / "sdk" / "agent_generation_guide.md"
|
||||
_GUIDE_CACHE = guide_path.read_text(encoding="utf-8")
|
||||
return _GUIDE_CACHE
|
||||
|
||||
|
||||
class AgentBuildingGuideResponse(ToolResponseBase):
|
||||
"""Response containing the agent building guide."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_BUILDER_GUIDE
|
||||
content: str
|
||||
|
||||
|
||||
class GetAgentBuildingGuideTool(BaseTool):
|
||||
"""Returns the complete guide for building agent JSON graphs.
|
||||
|
||||
Covers block IDs, link structure, AgentInputBlock, AgentOutputBlock,
|
||||
AgentExecutorBlock (sub-agent composition), and MCPToolBlock usage.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "get_agent_building_guide"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Returns the complete guide for building agent JSON graphs, including "
|
||||
"block IDs, link structure, AgentInputBlock, AgentOutputBlock, "
|
||||
"AgentExecutorBlock (for sub-agent composition), and MCPToolBlock usage. "
|
||||
"Call this before generating agent JSON to ensure correct structure."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id if session else None
|
||||
try:
|
||||
content = _load_guide()
|
||||
return AgentBuildingGuideResponse(
|
||||
message="Agent building guide loaded.",
|
||||
content=content,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to load agent building guide: %s", e)
|
||||
return ErrorResponse(
|
||||
message="Failed to load agent building guide.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -1,79 +0,0 @@
|
||||
"""GetMCPGuideTool - Returns the MCP tool usage guide."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_GUIDE_CACHE: str | None = None
|
||||
|
||||
|
||||
def _load_guide() -> str:
|
||||
global _GUIDE_CACHE
|
||||
if _GUIDE_CACHE is None:
|
||||
guide_path = Path(__file__).parent.parent / "sdk" / "mcp_tool_guide.md"
|
||||
_GUIDE_CACHE = guide_path.read_text(encoding="utf-8")
|
||||
return _GUIDE_CACHE
|
||||
|
||||
|
||||
class MCPGuideResponse(ToolResponseBase):
|
||||
"""Response containing the MCP tool guide."""
|
||||
|
||||
type: ResponseType = ResponseType.MCP_GUIDE
|
||||
content: str
|
||||
|
||||
|
||||
class GetMCPGuideTool(BaseTool):
|
||||
"""Returns the MCP tool usage guide with known server URLs and auth details."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "get_mcp_guide"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Returns the MCP tool guide: known hosted server URLs (Notion, Linear, "
|
||||
"Stripe, Intercom, Cloudflare, Atlassian) and authentication workflow. "
|
||||
"Call before using run_mcp_tool if you need a server URL or auth info."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id if session else None
|
||||
try:
|
||||
content = _load_guide()
|
||||
return MCPGuideResponse(
|
||||
message="MCP guide loaded.",
|
||||
content=content,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to load MCP guide: %s", e)
|
||||
return ErrorResponse(
|
||||
message="Failed to load MCP guide.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -1,24 +1,7 @@
|
||||
"""Shared helpers for chat tools."""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
|
||||
from .utils import match_credentials_to_requirements
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_inputs_from_schema(
|
||||
input_schema: dict[str, Any],
|
||||
@@ -44,159 +27,3 @@ def get_inputs_from_schema(
|
||||
for name, schema in properties.items()
|
||||
if name not in exclude
|
||||
]
|
||||
|
||||
|
||||
async def execute_block(
|
||||
*,
|
||||
block: AnyBlockSchema,
|
||||
block_id: str,
|
||||
input_data: dict[str, Any],
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
node_exec_id: str,
|
||||
matched_credentials: dict[str, CredentialsMetaInput],
|
||||
sensitive_action_safe_mode: bool = False,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute a block with full context setup, credential injection, and error handling.
|
||||
|
||||
This is the shared execution path used by both ``run_block`` (after review
|
||||
check) and ``continue_run_block`` (after approval).
|
||||
|
||||
Returns:
|
||||
BlockOutputResponse on success, ErrorResponse on failure.
|
||||
"""
|
||||
try:
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
|
||||
synthetic_graph_id = f"{COPILOT_SESSION_PREFIX}{session_id}"
|
||||
synthetic_node_id = f"{COPILOT_NODE_PREFIX}{block_id}"
|
||||
|
||||
execution_context = ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
graph_version=1,
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
workspace_id=workspace.id,
|
||||
session_id=session_id,
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
)
|
||||
|
||||
exec_kwargs: dict[str, Any] = {
|
||||
"user_id": user_id,
|
||||
"execution_context": execution_context,
|
||||
"workspace_id": workspace.id,
|
||||
"graph_exec_id": synthetic_graph_id,
|
||||
"node_exec_id": node_exec_id,
|
||||
"node_id": synthetic_node_id,
|
||||
"graph_version": 1,
|
||||
"graph_id": synthetic_graph_id,
|
||||
}
|
||||
|
||||
# Inject credentials
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
for field_name, cred_meta in matched_credentials.items():
|
||||
if field_name not in input_data:
|
||||
input_data[field_name] = cred_meta.model_dump()
|
||||
|
||||
actual_credentials = await creds_manager.get(
|
||||
user_id, cred_meta.id, lock=False
|
||||
)
|
||||
if actual_credentials:
|
||||
exec_kwargs[field_name] = actual_credentials
|
||||
else:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to retrieve credentials for {field_name}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Execute the block and collect outputs
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in block.execute(
|
||||
input_data,
|
||||
**exec_kwargs,
|
||||
):
|
||||
outputs[output_name].append(output_data)
|
||||
|
||||
return BlockOutputResponse(
|
||||
message=f"Block '{block.name}' executed successfully",
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
outputs=dict(outputs),
|
||||
success=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except BlockError as e:
|
||||
logger.warning(f"Block execution failed: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Block execution failed: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to execute block: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
async def resolve_block_credentials(
|
||||
user_id: str,
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any] | None = None,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""Resolve credentials for a block by matching user's available credentials.
|
||||
|
||||
Handles discriminated credentials (e.g. provider selection based on model).
|
||||
|
||||
Returns:
|
||||
(matched_credentials, missing_credentials)
|
||||
"""
|
||||
input_data = input_data or {}
|
||||
requirements = _resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
if not requirements:
|
||||
return {}, []
|
||||
|
||||
return await match_credentials_to_requirements(user_id, requirements)
|
||||
|
||||
|
||||
def _resolve_discriminated_credentials(
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any],
|
||||
) -> dict[str, CredentialsFieldInfo]:
|
||||
"""Resolve credential requirements, applying discriminator logic where needed."""
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
if not credentials_fields_info:
|
||||
return {}
|
||||
|
||||
resolved: dict[str, CredentialsFieldInfo] = {}
|
||||
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
effective_field_info = field_info
|
||||
|
||||
if field_info.discriminator and field_info.discriminator_mapping:
|
||||
discriminator_value = input_data.get(field_info.discriminator)
|
||||
if discriminator_value is None:
|
||||
field = block.input_schema.model_fields.get(field_info.discriminator)
|
||||
if field and field.default is not PydanticUndefined:
|
||||
discriminator_value = field.default
|
||||
|
||||
if (
|
||||
discriminator_value
|
||||
and discriminator_value in field_info.discriminator_mapping
|
||||
):
|
||||
effective_field_info = field_info.discriminate(discriminator_value)
|
||||
effective_field_info.discriminator_values.add(discriminator_value)
|
||||
logger.debug(
|
||||
f"Discriminated provider for {field_name}: "
|
||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
||||
)
|
||||
|
||||
resolved[field_name] = effective_field_info
|
||||
|
||||
return resolved
|
||||
|
||||
@@ -12,52 +12,50 @@ from backend.data.model import CredentialsMetaInput
|
||||
class ResponseType(str, Enum):
|
||||
"""Types of tool responses."""
|
||||
|
||||
# General
|
||||
ERROR = "error"
|
||||
NO_RESULTS = "no_results"
|
||||
NEED_LOGIN = "need_login"
|
||||
|
||||
# Agent discovery & execution
|
||||
AGENTS_FOUND = "agents_found"
|
||||
AGENT_DETAILS = "agent_details"
|
||||
SETUP_REQUIREMENTS = "setup_requirements"
|
||||
INPUT_VALIDATION_ERROR = "input_validation_error"
|
||||
EXECUTION_STARTED = "execution_started"
|
||||
NEED_LOGIN = "need_login"
|
||||
ERROR = "error"
|
||||
NO_RESULTS = "no_results"
|
||||
AGENT_OUTPUT = "agent_output"
|
||||
UNDERSTANDING_UPDATED = "understanding_updated"
|
||||
SUGGESTED_GOAL = "suggested_goal"
|
||||
|
||||
# Agent builder (create / edit / validate / fix)
|
||||
AGENT_BUILDER_GUIDE = "agent_builder_guide"
|
||||
AGENT_BUILDER_PREVIEW = "agent_builder_preview"
|
||||
AGENT_BUILDER_SAVED = "agent_builder_saved"
|
||||
AGENT_BUILDER_CLARIFICATION_NEEDED = "agent_builder_clarification_needed"
|
||||
AGENT_BUILDER_VALIDATION_RESULT = "agent_builder_validation_result"
|
||||
AGENT_BUILDER_FIX_RESULT = "agent_builder_fix_result"
|
||||
|
||||
# Block
|
||||
AGENT_PREVIEW = "agent_preview"
|
||||
AGENT_SAVED = "agent_saved"
|
||||
CLARIFICATION_NEEDED = "clarification_needed"
|
||||
BLOCK_LIST = "block_list"
|
||||
BLOCK_DETAILS = "block_details"
|
||||
BLOCK_OUTPUT = "block_output"
|
||||
REVIEW_REQUIRED = "review_required"
|
||||
|
||||
# MCP
|
||||
MCP_GUIDE = "mcp_guide"
|
||||
MCP_TOOLS_DISCOVERED = "mcp_tools_discovered"
|
||||
MCP_TOOL_OUTPUT = "mcp_tool_output"
|
||||
|
||||
# Docs
|
||||
DOC_SEARCH_RESULTS = "doc_search_results"
|
||||
DOC_PAGE = "doc_page"
|
||||
|
||||
# Workspace files
|
||||
# Workspace response types
|
||||
WORKSPACE_FILE_LIST = "workspace_file_list"
|
||||
WORKSPACE_FILE_CONTENT = "workspace_file_content"
|
||||
WORKSPACE_FILE_METADATA = "workspace_file_metadata"
|
||||
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
|
||||
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
|
||||
|
||||
# Folder management
|
||||
# Long-running operation types
|
||||
OPERATION_IN_PROGRESS = "operation_in_progress"
|
||||
# Input validation
|
||||
INPUT_VALIDATION_ERROR = "input_validation_error"
|
||||
# Web fetch
|
||||
WEB_FETCH = "web_fetch"
|
||||
# Agent-browser multi-step automation (navigate, act, screenshot)
|
||||
BROWSER_NAVIGATE = "browser_navigate"
|
||||
BROWSER_ACT = "browser_act"
|
||||
BROWSER_SCREENSHOT = "browser_screenshot"
|
||||
# Code execution
|
||||
BASH_EXEC = "bash_exec"
|
||||
# Feature request types
|
||||
FEATURE_REQUEST_SEARCH = "feature_request_search"
|
||||
FEATURE_REQUEST_CREATED = "feature_request_created"
|
||||
# Goal refinement
|
||||
SUGGESTED_GOAL = "suggested_goal"
|
||||
# MCP tool types
|
||||
MCP_TOOLS_DISCOVERED = "mcp_tools_discovered"
|
||||
MCP_TOOL_OUTPUT = "mcp_tool_output"
|
||||
# Folder management types
|
||||
FOLDER_CREATED = "folder_created"
|
||||
FOLDER_LIST = "folder_list"
|
||||
FOLDER_UPDATED = "folder_updated"
|
||||
@@ -65,21 +63,6 @@ class ResponseType(str, Enum):
|
||||
FOLDER_DELETED = "folder_deleted"
|
||||
AGENTS_MOVED_TO_FOLDER = "agents_moved_to_folder"
|
||||
|
||||
# Browser automation
|
||||
BROWSER_NAVIGATE = "browser_navigate"
|
||||
BROWSER_ACT = "browser_act"
|
||||
BROWSER_SCREENSHOT = "browser_screenshot"
|
||||
|
||||
# Code execution
|
||||
BASH_EXEC = "bash_exec"
|
||||
|
||||
# Web
|
||||
WEB_FETCH = "web_fetch"
|
||||
|
||||
# Feature requests
|
||||
FEATURE_REQUEST_SEARCH = "feature_request_search"
|
||||
FEATURE_REQUEST_CREATED = "feature_request_created"
|
||||
|
||||
|
||||
# Base response model
|
||||
class ToolResponseBase(BaseModel):
|
||||
@@ -109,15 +92,6 @@ class AgentInfo(BaseModel):
|
||||
has_external_trigger: bool | None = None
|
||||
new_output: bool | None = None
|
||||
graph_id: str | None = None
|
||||
graph_version: int | None = None
|
||||
input_schema: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="JSON Schema for the agent's inputs (for AgentExecutorBlock)",
|
||||
)
|
||||
output_schema: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="JSON Schema for the agent's outputs (for AgentExecutorBlock)",
|
||||
)
|
||||
inputs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Input schema for the agent, including field names, types, and defaults",
|
||||
@@ -308,7 +282,7 @@ class ClarifyingQuestion(BaseModel):
|
||||
class AgentPreviewResponse(ToolResponseBase):
|
||||
"""Response for previewing a generated agent before saving."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_BUILDER_PREVIEW
|
||||
type: ResponseType = ResponseType.AGENT_PREVIEW
|
||||
agent_json: dict[str, Any]
|
||||
agent_name: str
|
||||
description: str
|
||||
@@ -319,7 +293,7 @@ class AgentPreviewResponse(ToolResponseBase):
|
||||
class AgentSavedResponse(ToolResponseBase):
|
||||
"""Response when an agent is saved to the library."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_BUILDER_SAVED
|
||||
type: ResponseType = ResponseType.AGENT_SAVED
|
||||
agent_id: str
|
||||
agent_name: str
|
||||
library_agent_id: str
|
||||
@@ -330,7 +304,7 @@ class AgentSavedResponse(ToolResponseBase):
|
||||
class ClarificationNeededResponse(ToolResponseBase):
|
||||
"""Response when the LLM needs more information from the user."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_BUILDER_CLARIFICATION_NEEDED
|
||||
type: ResponseType = ResponseType.CLARIFICATION_NEEDED
|
||||
questions: list[ClarifyingQuestion] = Field(default_factory=list)
|
||||
|
||||
|
||||
@@ -407,10 +381,6 @@ class BlockInfoSummary(BaseModel):
|
||||
default_factory=dict,
|
||||
description="Full JSON schema for block outputs",
|
||||
)
|
||||
static_output: bool = Field(
|
||||
default=False,
|
||||
description="Whether the block produces output without needing input",
|
||||
)
|
||||
required_inputs: list[BlockInputFieldInfo] = Field(
|
||||
default_factory=list,
|
||||
description="List of input fields for this block",
|
||||
@@ -459,19 +429,16 @@ class BlockOutputResponse(ToolResponseBase):
|
||||
success: bool = True
|
||||
|
||||
|
||||
class ReviewRequiredResponse(ToolResponseBase):
|
||||
"""Response when a block requires human review before execution."""
|
||||
# Long-running operation models
|
||||
class OperationInProgressResponse(ToolResponseBase):
|
||||
"""Response when an operation is already in progress.
|
||||
|
||||
type: ResponseType = ResponseType.REVIEW_REQUIRED
|
||||
block_id: str
|
||||
block_name: str
|
||||
review_id: str = Field(description="The review ID for tracking approval status")
|
||||
graph_exec_id: str = Field(
|
||||
description="The graph execution ID for fetching review status"
|
||||
)
|
||||
input_data: dict[str, Any] = Field(
|
||||
description="The input data that requires review"
|
||||
)
|
||||
Returned for idempotency when the same tool_call_id is requested again
|
||||
while the background task is still running.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class WebFetchResponse(ToolResponseBase):
|
||||
@@ -581,29 +548,6 @@ class BrowserScreenshotResponse(ToolResponseBase):
|
||||
filename: str
|
||||
|
||||
|
||||
# Agent generation tool response models
|
||||
|
||||
|
||||
class ValidationResultResponse(ToolResponseBase):
|
||||
"""Response for validate_agent_graph tool."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_BUILDER_VALIDATION_RESULT
|
||||
valid: bool
|
||||
errors: list[str] = Field(default_factory=list)
|
||||
error_count: int = 0
|
||||
|
||||
|
||||
class FixResultResponse(ToolResponseBase):
|
||||
"""Response for fix_agent_graph tool."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_BUILDER_FIX_RESULT
|
||||
fixed_agent_json: dict[str, Any]
|
||||
fixes_applied: list[str] = Field(default_factory=list)
|
||||
fix_count: int = 0
|
||||
valid_after_fix: bool = False
|
||||
remaining_errors: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
# Folder management models
|
||||
|
||||
|
||||
|
||||
@@ -534,9 +534,7 @@ class RunAgentTool(BaseTool):
|
||||
return ExecutionStartedResponse(
|
||||
message=(
|
||||
f"Agent '{library_agent.name}' is awaiting human review. "
|
||||
f"The user can approve or reject inline. After approval, "
|
||||
f"the execution resumes automatically. Use view_agent_output "
|
||||
f"with execution_id='{execution.id}' to check the result."
|
||||
f"Check at {library_agent_link}."
|
||||
),
|
||||
session_id=session_id,
|
||||
execution_id=execution.id,
|
||||
|
||||
@@ -2,34 +2,38 @@
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from backend.blocks import BlockType, get_block
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.blocks import get_block
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
from backend.copilot.constants import (
|
||||
COPILOT_NODE_EXEC_ID_SEPARATOR,
|
||||
COPILOT_NODE_PREFIX,
|
||||
COPILOT_SESSION_PREFIX,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import review_db
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
from .base import BaseTool
|
||||
from .find_block import COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
from .helpers import execute_block, get_inputs_from_schema, resolve_block_credentials
|
||||
from .helpers import get_inputs_from_schema
|
||||
from .models import (
|
||||
BlockDetails,
|
||||
BlockDetailsResponse,
|
||||
BlockOutputResponse,
|
||||
ErrorResponse,
|
||||
InputValidationErrorResponse,
|
||||
ReviewRequiredResponse,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
from .utils import build_missing_credentials_from_field_info
|
||||
from .utils import (
|
||||
build_missing_credentials_from_field_info,
|
||||
match_credentials_to_requirements,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -48,9 +52,7 @@ class RunBlockTool(BaseTool):
|
||||
"IMPORTANT: You MUST call find_block first to get the block's 'id' - "
|
||||
"do NOT guess or make up block IDs. "
|
||||
"On first attempt (without input_data), returns detailed schema showing "
|
||||
"required inputs and outputs. Then call again with proper input_data to execute. "
|
||||
"If a block requires human review, use continue_run_block with the "
|
||||
"review_id after the user approves."
|
||||
"required inputs and outputs. Then call again with proper input_data to execute."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -81,7 +83,7 @@ class RunBlockTool(BaseTool):
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["block_id", "block_name", "input_data"],
|
||||
"required": ["block_id", "input_data"],
|
||||
}
|
||||
|
||||
@property
|
||||
@@ -147,27 +149,21 @@ class RunBlockTool(BaseTool):
|
||||
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
||||
):
|
||||
# Provide actionable guidance for blocks with dedicated tools
|
||||
if block.block_type == BlockType.MCP_TOOL:
|
||||
hint = (
|
||||
" Use the `run_mcp_tool` tool instead — it handles "
|
||||
"MCP server discovery, authentication, and execution."
|
||||
)
|
||||
elif block.block_type == BlockType.AGENT:
|
||||
hint = " Use the `run_agent` tool instead."
|
||||
else:
|
||||
hint = " This block is designed for use within graphs only."
|
||||
return ErrorResponse(
|
||||
message=f"Block '{block.name}' cannot be run directly.{hint}",
|
||||
message=(
|
||||
f"Block '{block.name}' cannot be run directly in CoPilot. "
|
||||
"This block is designed for use within graphs only."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
(
|
||||
matched_credentials,
|
||||
missing_credentials,
|
||||
) = await resolve_block_credentials(user_id, block, input_data)
|
||||
) = await self._resolve_block_credentials(user_id, block, input_data)
|
||||
|
||||
# Get block schemas for details/validation
|
||||
try:
|
||||
@@ -276,97 +272,169 @@ class RunBlockTool(BaseTool):
|
||||
user_authenticated=True,
|
||||
)
|
||||
|
||||
# Generate synthetic IDs for CoPilot context.
|
||||
# Encode node_id in node_exec_id so it can be extracted later
|
||||
# (e.g. for auto-approve, where we need node_id but have no NodeExecution row).
|
||||
synthetic_graph_id = f"{COPILOT_SESSION_PREFIX}{session.session_id}"
|
||||
synthetic_node_id = f"{COPILOT_NODE_PREFIX}{block_id}"
|
||||
try:
|
||||
# Get or create user's workspace for CoPilot file operations
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
|
||||
# Check for an existing WAITING review for this block with the same input.
|
||||
# If the LLM retries run_block with identical input, we reuse the existing
|
||||
# review instead of creating duplicates. Different inputs = new execution.
|
||||
existing_reviews = await review_db().get_pending_reviews_for_execution(
|
||||
synthetic_graph_id, user_id
|
||||
)
|
||||
existing_review = next(
|
||||
(
|
||||
r
|
||||
for r in existing_reviews
|
||||
if r.node_id == synthetic_node_id
|
||||
and r.status.value == "WAITING"
|
||||
and r.payload == input_data
|
||||
),
|
||||
None,
|
||||
)
|
||||
if existing_review:
|
||||
return ReviewRequiredResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' requires human review. "
|
||||
f"After the user approves, call continue_run_block with "
|
||||
f"review_id='{existing_review.node_exec_id}' to execute."
|
||||
),
|
||||
session_id=session_id,
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
review_id=existing_review.node_exec_id,
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
input_data=input_data,
|
||||
# Generate synthetic IDs for CoPilot context
|
||||
# Each chat session is treated as its own agent with one continuous run
|
||||
# This means:
|
||||
# - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
|
||||
# - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
|
||||
# - node_exec_id = unique per block execution
|
||||
synthetic_graph_id = f"copilot-session-{session.session_id}"
|
||||
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
|
||||
synthetic_node_id = f"copilot-node-{block_id}"
|
||||
synthetic_node_exec_id = (
|
||||
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
|
||||
)
|
||||
|
||||
synthetic_node_exec_id = (
|
||||
f"{synthetic_node_id}{COPILOT_NODE_EXEC_ID_SEPARATOR}"
|
||||
f"{uuid.uuid4().hex[:8]}"
|
||||
)
|
||||
|
||||
# Check for HITL review before execution.
|
||||
# This creates the review record in the DB for CoPilot flows.
|
||||
review_context = ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
graph_version=1,
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=synthetic_node_exec_id,
|
||||
sensitive_action_safe_mode=True,
|
||||
)
|
||||
should_pause, input_data = await block.is_block_exec_need_review(
|
||||
input_data,
|
||||
user_id=user_id,
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=synthetic_node_exec_id,
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
graph_version=1,
|
||||
execution_context=review_context,
|
||||
is_graph_execution=False,
|
||||
)
|
||||
if should_pause:
|
||||
return ReviewRequiredResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' requires human review. "
|
||||
f"After the user approves, call continue_run_block with "
|
||||
f"review_id='{synthetic_node_exec_id}' to execute."
|
||||
),
|
||||
session_id=session_id,
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
review_id=synthetic_node_exec_id,
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
input_data=input_data,
|
||||
# Create unified execution context with all required fields
|
||||
execution_context = ExecutionContext(
|
||||
# Execution identity
|
||||
user_id=user_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
graph_exec_id=synthetic_graph_exec_id,
|
||||
graph_version=1, # Versions are 1-indexed
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=synthetic_node_exec_id,
|
||||
# Workspace with session scoping
|
||||
workspace_id=workspace.id,
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
return await execute_block(
|
||||
block=block,
|
||||
block_id=block_id,
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
node_exec_id=synthetic_node_exec_id,
|
||||
matched_credentials=matched_credentials,
|
||||
)
|
||||
# Prepare kwargs for block execution
|
||||
# Keep individual kwargs for backwards compatibility with existing blocks
|
||||
exec_kwargs: dict[str, Any] = {
|
||||
"user_id": user_id,
|
||||
"execution_context": execution_context,
|
||||
# Legacy: individual kwargs for blocks not yet using execution_context
|
||||
"workspace_id": workspace.id,
|
||||
"graph_exec_id": synthetic_graph_exec_id,
|
||||
"node_exec_id": synthetic_node_exec_id,
|
||||
"node_id": synthetic_node_id,
|
||||
"graph_version": 1, # Versions are 1-indexed
|
||||
"graph_id": synthetic_graph_id,
|
||||
}
|
||||
|
||||
for field_name, cred_meta in matched_credentials.items():
|
||||
# Inject metadata into input_data (for validation)
|
||||
if field_name not in input_data:
|
||||
input_data[field_name] = cred_meta.model_dump()
|
||||
|
||||
# Fetch actual credentials and pass as kwargs (for execution)
|
||||
actual_credentials = await creds_manager.get(
|
||||
user_id, cred_meta.id, lock=False
|
||||
)
|
||||
if actual_credentials:
|
||||
exec_kwargs[field_name] = actual_credentials
|
||||
else:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to retrieve credentials for {field_name}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Execute the block and collect outputs
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in block.execute(
|
||||
input_data,
|
||||
**exec_kwargs,
|
||||
):
|
||||
outputs[output_name].append(output_data)
|
||||
|
||||
return BlockOutputResponse(
|
||||
message=f"Block '{block.name}' executed successfully",
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
outputs=dict(outputs),
|
||||
success=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except BlockError as e:
|
||||
logger.warning(f"Block execution failed: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Block execution failed: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to execute block: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _resolve_block_credentials(
|
||||
self,
|
||||
user_id: str,
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any] | None = None,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Resolve credentials for a block by matching user's available credentials.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
block: Block to resolve credentials for
|
||||
input_data: Input data for the block (used to determine provider via discriminator)
|
||||
|
||||
Returns:
|
||||
tuple of (matched_credentials, missing_credentials) - matched credentials
|
||||
are used for block execution, missing ones indicate setup requirements.
|
||||
"""
|
||||
input_data = input_data or {}
|
||||
requirements = self._resolve_discriminated_credentials(block, input_data)
|
||||
|
||||
if not requirements:
|
||||
return {}, []
|
||||
|
||||
return await match_credentials_to_requirements(user_id, requirements)
|
||||
|
||||
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
|
||||
"""Extract non-credential inputs from block schema."""
|
||||
schema = block.input_schema.jsonschema()
|
||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
|
||||
|
||||
def _resolve_discriminated_credentials(
|
||||
self,
|
||||
block: AnyBlockSchema,
|
||||
input_data: dict[str, Any],
|
||||
) -> dict[str, CredentialsFieldInfo]:
|
||||
"""Resolve credential requirements, applying discriminator logic where needed."""
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
if not credentials_fields_info:
|
||||
return {}
|
||||
|
||||
resolved: dict[str, CredentialsFieldInfo] = {}
|
||||
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
effective_field_info = field_info
|
||||
|
||||
if field_info.discriminator and field_info.discriminator_mapping:
|
||||
discriminator_value = input_data.get(field_info.discriminator)
|
||||
if discriminator_value is None:
|
||||
field = block.input_schema.model_fields.get(
|
||||
field_info.discriminator
|
||||
)
|
||||
if field and field.default is not PydanticUndefined:
|
||||
discriminator_value = field.default
|
||||
|
||||
if (
|
||||
discriminator_value
|
||||
and discriminator_value in field_info.discriminator_mapping
|
||||
):
|
||||
effective_field_info = field_info.discriminate(discriminator_value)
|
||||
# For host-scoped credentials, add the discriminator value
|
||||
# (e.g., URL) so _credential_is_for_host can match it
|
||||
effective_field_info.discriminator_values.add(discriminator_value)
|
||||
logger.debug(
|
||||
f"Discriminated provider for {field_name}: "
|
||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
||||
)
|
||||
|
||||
resolved[field_name] = effective_field_info
|
||||
|
||||
return resolved
|
||||
|
||||
@@ -12,7 +12,6 @@ from .models import (
|
||||
BlockOutputResponse,
|
||||
ErrorResponse,
|
||||
InputValidationErrorResponse,
|
||||
ReviewRequiredResponse,
|
||||
)
|
||||
from .run_block import RunBlockTool
|
||||
|
||||
@@ -28,16 +27,9 @@ def make_mock_block(
|
||||
mock.name = name
|
||||
mock.block_type = block_type
|
||||
mock.disabled = disabled
|
||||
mock.is_sensitive_action = False
|
||||
mock.input_schema = MagicMock()
|
||||
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
||||
mock.input_schema.get_credentials_fields_info.return_value = {}
|
||||
mock.input_schema.get_credentials_fields.return_value = {}
|
||||
|
||||
async def _no_review(input_data, **kwargs):
|
||||
return False, input_data
|
||||
|
||||
mock.is_block_exec_need_review = _no_review
|
||||
mock.input_schema.get_credentials_fields_info.return_value = []
|
||||
return mock
|
||||
|
||||
|
||||
@@ -54,7 +46,6 @@ def make_mock_block_with_schema(
|
||||
mock.name = name
|
||||
mock.block_type = BlockType.STANDARD
|
||||
mock.disabled = False
|
||||
mock.is_sensitive_action = False
|
||||
mock.description = f"Test block: {name}"
|
||||
|
||||
input_schema = {
|
||||
@@ -72,12 +63,6 @@ def make_mock_block_with_schema(
|
||||
mock.output_schema = MagicMock()
|
||||
mock.output_schema.jsonschema.return_value = output_schema
|
||||
|
||||
# Default: no review needed, pass through input_data unchanged
|
||||
async def _no_review(input_data, **kwargs):
|
||||
return False, input_data
|
||||
|
||||
mock.is_block_exec_need_review = _no_review
|
||||
|
||||
return mock
|
||||
|
||||
|
||||
@@ -104,7 +89,7 @@ class TestRunBlockFiltering:
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "cannot be run directly" in response.message
|
||||
assert "cannot be run directly in CoPilot" in response.message
|
||||
assert "designed for use within graphs only" in response.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@@ -130,7 +115,7 @@ class TestRunBlockFiltering:
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "cannot be run directly" in response.message
|
||||
assert "cannot be run directly in CoPilot" in response.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_non_excluded_block_passes_guard(self):
|
||||
@@ -141,15 +126,9 @@ class TestRunBlockFiltering:
|
||||
"standard-id", "HTTP Request", BlockType.STANDARD
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=standard_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=standard_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
@@ -162,7 +141,7 @@ class TestRunBlockFiltering:
|
||||
# Should NOT be an ErrorResponse about CoPilot exclusion
|
||||
# (may be other errors like missing credentials, but not the exclusion guard)
|
||||
if isinstance(response, ErrorResponse):
|
||||
assert "cannot be run directly" not in response.message
|
||||
assert "cannot be run directly in CoPilot" not in response.message
|
||||
|
||||
|
||||
class TestRunBlockInputValidation:
|
||||
@@ -175,7 +154,12 @@ class TestRunBlockInputValidation:
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unknown_input_fields_are_rejected(self):
|
||||
"""run_block rejects unknown input fields instead of silently ignoring them."""
|
||||
"""run_block rejects unknown input fields instead of silently ignoring them.
|
||||
|
||||
Scenario: The AI Text Generator block has a field called 'model' (for LLM model
|
||||
selection), but the LLM calling the tool guesses wrong and sends 'LLM_Model'
|
||||
instead. The block should reject the request and return the valid schema.
|
||||
"""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
mock_block = make_mock_block_with_schema(
|
||||
@@ -198,31 +182,27 @@ class TestRunBlockInputValidation:
|
||||
output_properties={"response": {"type": "string"}},
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
|
||||
# Provide 'prompt' (correct) but 'LLM_Model' instead of 'model' (wrong key)
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="ai-text-gen-id",
|
||||
input_data={
|
||||
"prompt": "Write a haiku about coding",
|
||||
"LLM_Model": "claude-opus-4-6",
|
||||
"LLM_Model": "claude-opus-4-6", # WRONG KEY - should be 'model'
|
||||
},
|
||||
)
|
||||
|
||||
assert isinstance(response, InputValidationErrorResponse)
|
||||
assert "LLM_Model" in response.unrecognized_fields
|
||||
assert "Block was not executed" in response.message
|
||||
assert "inputs" in response.model_dump()
|
||||
assert "inputs" in response.model_dump() # valid schema included
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_multiple_wrong_keys_are_all_reported(self):
|
||||
@@ -241,26 +221,21 @@ class TestRunBlockInputValidation:
|
||||
required_fields=["prompt"],
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="ai-text-gen-id",
|
||||
input_data={
|
||||
"prompt": "Hello",
|
||||
"llm_model": "claude-opus-4-6",
|
||||
"system_prompt": "Be helpful",
|
||||
"retries": 5,
|
||||
"prompt": "Hello", # correct
|
||||
"llm_model": "claude-opus-4-6", # WRONG - should be 'model'
|
||||
"system_prompt": "Be helpful", # WRONG - should be 'sys_prompt'
|
||||
"retries": 5, # WRONG - should be 'retry'
|
||||
},
|
||||
)
|
||||
|
||||
@@ -287,26 +262,23 @@ class TestRunBlockInputValidation:
|
||||
required_fields=["prompt"],
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
|
||||
# 'prompt' is missing AND 'LLM_Model' is an unknown field
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="ai-text-gen-id",
|
||||
input_data={
|
||||
"LLM_Model": "claude-opus-4-6",
|
||||
"LLM_Model": "claude-opus-4-6", # wrong key, and 'prompt' is missing
|
||||
},
|
||||
)
|
||||
|
||||
# Unknown fields are caught first
|
||||
assert isinstance(response, InputValidationErrorResponse)
|
||||
assert "LLM_Model" in response.unrecognized_fields
|
||||
|
||||
@@ -341,11 +313,7 @@ class TestRunBlockInputValidation:
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
"backend.copilot.tools.run_block.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
),
|
||||
):
|
||||
@@ -357,7 +325,7 @@ class TestRunBlockInputValidation:
|
||||
block_id="ai-text-gen-id",
|
||||
input_data={
|
||||
"prompt": "Write a haiku",
|
||||
"model": "gpt-4o-mini",
|
||||
"model": "gpt-4o-mini", # correct field name
|
||||
},
|
||||
)
|
||||
|
||||
@@ -379,191 +347,20 @@ class TestRunBlockInputValidation:
|
||||
required_fields=["prompt"],
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
with patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
|
||||
# Only provide valid optional field, missing required 'prompt'
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="ai-text-gen-id",
|
||||
input_data={
|
||||
"model": "gpt-4o-mini",
|
||||
"model": "gpt-4o-mini", # valid but optional
|
||||
},
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockDetailsResponse)
|
||||
|
||||
|
||||
class TestRunBlockSensitiveAction:
|
||||
"""Tests for sensitive action HITL review in RunBlockTool.
|
||||
|
||||
run_block calls is_block_exec_need_review() explicitly before execution.
|
||||
When review is needed (should_pause=True), ReviewRequiredResponse is returned.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_sensitive_block_paused_returns_review_required(self):
|
||||
"""When is_block_exec_need_review returns should_pause=True, ReviewRequiredResponse is returned."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
input_data = {
|
||||
"repo_url": "https://github.com/test/repo",
|
||||
"branch": "feature-branch",
|
||||
}
|
||||
mock_block = make_mock_block_with_schema(
|
||||
block_id="delete-branch-id",
|
||||
name="Delete Branch",
|
||||
input_properties={
|
||||
"repo_url": {"type": "string"},
|
||||
"branch": {"type": "string"},
|
||||
},
|
||||
required_fields=["repo_url", "branch"],
|
||||
)
|
||||
mock_block.is_sensitive_action = True
|
||||
mock_block.is_block_exec_need_review = AsyncMock(
|
||||
return_value=(True, input_data)
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="delete-branch-id",
|
||||
input_data=input_data,
|
||||
)
|
||||
|
||||
assert isinstance(response, ReviewRequiredResponse)
|
||||
assert "requires human review" in response.message
|
||||
assert "continue_run_block" in response.message
|
||||
assert response.block_name == "Delete Branch"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_sensitive_block_executes_after_approval(self):
|
||||
"""After approval (should_pause=False), sensitive blocks execute and return outputs."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
input_data = {
|
||||
"repo_url": "https://github.com/test/repo",
|
||||
"branch": "feature-branch",
|
||||
}
|
||||
mock_block = make_mock_block_with_schema(
|
||||
block_id="delete-branch-id",
|
||||
name="Delete Branch",
|
||||
input_properties={
|
||||
"repo_url": {"type": "string"},
|
||||
"branch": {"type": "string"},
|
||||
},
|
||||
required_fields=["repo_url", "branch"],
|
||||
)
|
||||
mock_block.is_sensitive_action = True
|
||||
mock_block.is_block_exec_need_review = AsyncMock(
|
||||
return_value=(False, input_data)
|
||||
)
|
||||
|
||||
async def mock_execute(input_data, **kwargs):
|
||||
yield "result", "Branch deleted successfully"
|
||||
|
||||
mock_block.execute = mock_execute
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="test-workspace-id")
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
),
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="delete-branch-id",
|
||||
input_data=input_data,
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert response.success is True
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_non_sensitive_block_executes_normally(self):
|
||||
"""Non-sensitive blocks skip review and execute directly."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
input_data = {"url": "https://example.com"}
|
||||
mock_block = make_mock_block_with_schema(
|
||||
block_id="http-request-id",
|
||||
name="HTTP Request",
|
||||
input_properties={
|
||||
"url": {"type": "string"},
|
||||
},
|
||||
required_fields=["url"],
|
||||
)
|
||||
mock_block.is_sensitive_action = False
|
||||
mock_block.is_block_exec_need_review = AsyncMock(
|
||||
return_value=(False, input_data)
|
||||
)
|
||||
|
||||
async def mock_execute(input_data, **kwargs):
|
||||
yield "response", {"status": 200}
|
||||
|
||||
mock_block.execute = mock_execute
|
||||
|
||||
mock_workspace_db = MagicMock()
|
||||
mock_workspace_db.get_or_create_workspace = AsyncMock(
|
||||
return_value=MagicMock(id="test-workspace-id")
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=mock_block,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.match_credentials_to_requirements",
|
||||
return_value=({}, []),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.workspace_db",
|
||||
return_value=mock_workspace_db,
|
||||
),
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="http-request-id",
|
||||
input_data=input_data,
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockOutputResponse)
|
||||
assert response.success is True
|
||||
|
||||
@@ -14,7 +14,7 @@ from backend.blocks.mcp.helpers import (
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.utils import build_missing_credentials_from_field_info
|
||||
from backend.util.request import HTTPClientError, validate_url_host
|
||||
from backend.util.request import HTTPClientError, validate_url
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
@@ -34,11 +34,6 @@ logger = logging.getLogger(__name__)
|
||||
_AUTH_STATUS_CODES = {401, 403}
|
||||
|
||||
|
||||
def _service_name(host: str) -> str:
|
||||
"""Strip the 'mcp.' prefix from an MCP hostname: 'mcp.sentry.dev' → 'sentry.dev'"""
|
||||
return host[4:] if host.startswith("mcp.") else host
|
||||
|
||||
|
||||
class RunMCPToolTool(BaseTool):
|
||||
"""
|
||||
Tool for discovering and executing tools on any MCP server.
|
||||
@@ -58,9 +53,15 @@ class RunMCPToolTool(BaseTool):
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Connect to an MCP (Model Context Protocol) server to discover and execute its tools. "
|
||||
"Two-step: (1) call with server_url to list available tools, "
|
||||
"(2) call again with server_url + tool_name + tool_arguments to execute. "
|
||||
"Call get_mcp_guide for known server URLs and auth details."
|
||||
"Two-step workflow: (1) Call with just `server_url` to discover available tools. "
|
||||
"(2) Call again with `server_url`, `tool_name`, and `tool_arguments` to execute. "
|
||||
"Known hosted servers (use directly): Notion (https://mcp.notion.com/mcp), "
|
||||
"Linear (https://mcp.linear.app/mcp), Stripe (https://mcp.stripe.com), "
|
||||
"Intercom (https://mcp.intercom.com/mcp), Cloudflare (https://mcp.cloudflare.com/mcp), "
|
||||
"Atlassian/Jira (https://mcp.atlassian.com/mcp). "
|
||||
"For other services, search the MCP registry at https://registry.modelcontextprotocol.io/. "
|
||||
"Authentication: If the server requires credentials, user will be prompted to complete the MCP credential setup flow."
|
||||
"Once connected and user confirms, retry the same call immediately."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -149,7 +150,7 @@ class RunMCPToolTool(BaseTool):
|
||||
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges
|
||||
try:
|
||||
await validate_url_host(server_url)
|
||||
await validate_url(server_url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
msg = str(e)
|
||||
if "Unable to resolve" in msg or "No IP addresses" in msg:
|
||||
@@ -308,8 +309,8 @@ class RunMCPToolTool(BaseTool):
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Unable to connect to {_service_name(server_host(server_url))} "
|
||||
"— no credentials configured."
|
||||
f"The MCP server at {server_host(server_url)} requires authentication, "
|
||||
"but no credential configuration was found."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -317,13 +318,15 @@ class RunMCPToolTool(BaseTool):
|
||||
missing_creds_list = list(missing_creds_dict.values())
|
||||
|
||||
host = server_host(server_url)
|
||||
service = _service_name(host)
|
||||
return SetupRequirementsResponse(
|
||||
message=(f"To continue, sign in to {service} and approve access."),
|
||||
message=(
|
||||
f"The MCP server at {host} requires authentication. "
|
||||
"Please connect your credentials to continue."
|
||||
),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=server_url,
|
||||
agent_name=service,
|
||||
agent_name=f"MCP: {host}",
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
missing_credentials=missing_creds_dict,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user