Compare commits

..

7 Commits

Author SHA1 Message Date
Swifty
a2c97d428e Merge branch 'swiftyos/tracing' of github.com:Significant-Gravitas/AutoGPT into swiftyos/tracing 2026-03-03 14:43:28 +01:00
Swifty
78fee94569 fix(backend): update poetry.lock content-hash
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-03 14:42:44 +01:00
Swifty
b913e8f9de Merge branch 'dev' into swiftyos/tracing 2026-03-03 14:29:34 +01:00
Swifty
7f16a10e9e review comments. 2026-03-03 14:29:06 +01:00
Swifty
57f56c0caa ensure usage is included 2026-03-03 11:39:24 +01:00
Swifty
e8b82cd268 lint 2026-03-03 11:17:17 +01:00
Swifty
4a108ad5d2 Update tracing so that it admits traces in the same format as openrouter broadcast uses 2026-03-02 16:25:45 +01:00
315 changed files with 11873 additions and 39825 deletions

View File

@@ -1,17 +0,0 @@
---
name: backend-check
description: Run the full backend formatting, linting, and test suite. Ensures code quality before commits and PRs. TRIGGER when backend Python code has been modified and needs validation.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Backend Check
## Steps
1. **Format**: `poetry run format` — runs formatting AND linting. NEVER run ruff/black/isort individually
2. **Fix** any remaining errors manually, re-run until clean
3. **Test**: `poetry run test` (runs DB setup + pytest). For specific files: `poetry run pytest -s -vvv <test_files>`
4. **Snapshots** (if needed): `poetry run pytest path/to/test.py --snapshot-update` — review with `git diff`

View File

@@ -1,35 +0,0 @@
---
name: code-style
description: Python code style preferences for the AutoGPT backend. Apply when writing or reviewing Python code. TRIGGER when writing new Python code, reviewing PRs, or refactoring backend code.
user-invocable: false
metadata:
author: autogpt-team
version: "1.0.0"
---
# Code Style
## Imports
- **Top-level only** — no local/inner imports. Move all imports to the top of the file.
## Typing
- **No duck typing** — avoid `hasattr`, `getattr`, `isinstance` for type dispatch. Use proper typed interfaces, unions, or protocols.
- **Pydantic models** over dataclass, namedtuple, or raw dict for structured data.
- **No linter suppressors** — avoid `# type: ignore`, `# noqa`, `# pyright: ignore` etc. 99% of the time the right fix is fixing the type/code, not silencing the tool.
## Code Structure
- **List comprehensions** over manual loop-and-append.
- **Early return** — guard clauses first, avoid deep nesting.
- **Flatten inline** — prefer short, concise expressions. Reduce `if/else` chains with direct returns or ternaries when readable.
- **Modular functions** — break complex logic into small, focused functions rather than long blocks with nested conditionals.
## Review Checklist
Before finishing, always ask:
- Can any function be split into smaller pieces?
- Is there unnecessary nesting that an early return would eliminate?
- Can any loop be a comprehension?
- Is there a simpler way to express this logic?

View File

@@ -1,16 +0,0 @@
---
name: frontend-check
description: Run the full frontend formatting, linting, and type checking suite. Ensures code quality before commits and PRs. TRIGGER when frontend TypeScript/React code has been modified and needs validation.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Frontend Check
## Steps (in order)
1. **Format**: `pnpm format` — NEVER run individual formatters
2. **Lint**: `pnpm lint` — fix errors, re-run until clean
3. **Types**: `pnpm types` — if it keeps failing after multiple attempts, stop and ask the user

View File

@@ -1,29 +0,0 @@
---
name: new-block
description: Create a new backend block following the Block SDK Guide. Guides through provider configuration, schema definition, authentication, and testing. TRIGGER when user asks to create a new block, add a new integration, or build a new node for the graph editor.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# New Block Creation
Read `docs/platform/block-sdk-guide.md` first for the full guide.
## Steps
1. **Provider config** (if external service): create `_config.py` with `ProviderBuilder`
2. **Block file** in `backend/blocks/` (from `autogpt_platform/backend/`):
- Generate a UUID once with `uuid.uuid4()`, then **hard-code that string** as `id` (IDs must be stable across imports)
- `Input(BlockSchema)` and `Output(BlockSchema)` classes
- `async def run` that `yield`s output fields
3. **Files**: use `store_media_file()` with `"for_block_output"` for outputs
4. **Test**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[MyBlock]' -xvs`
5. **Format**: `poetry run format`
## Rules
- Analyze interfaces: do inputs/outputs connect well with other blocks in a graph?
- Use top-level imports, avoid duck typing
- Always use `for_block_output` for block outputs

View File

@@ -1,28 +0,0 @@
---
name: openapi-regen
description: Regenerate the OpenAPI spec and frontend API client. Starts the backend REST server, fetches the spec, and regenerates the typed frontend hooks. TRIGGER when API routes change, new endpoints are added, or frontend API types are stale.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# OpenAPI Spec Regeneration
## Steps
1. **Run end-to-end** in a single shell block (so `REST_PID` persists):
```bash
cd autogpt_platform/backend && poetry run rest &
REST_PID=$!
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && kill $REST_PID && exit 1; done
cd ../frontend && pnpm generate:api:force
kill $REST_PID
pnpm types && pnpm lint && pnpm format
```
## Rules
- Always use `pnpm generate:api:force` (not `pnpm generate:api`)
- Don't manually edit files in `src/app/api/__generated__/`
- Generated hooks follow: `use{Method}{Version}{OperationName}`

View File

@@ -1,31 +0,0 @@
---
name: pr-create
description: Create a pull request for the current branch. TRIGGER when user asks to create a PR, open a pull request, push changes for review, or submit work for merging.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Create Pull Request
## Steps
1. **Check for existing PR**: `gh pr view --json url -q .url 2>/dev/null` — if a PR already exists, output its URL and stop
2. **Understand changes**: `git status`, `git diff dev...HEAD`, `git log dev..HEAD --oneline`
3. **Read PR template**: `.github/PULL_REQUEST_TEMPLATE.md`
4. **Draft PR title**: Use conventional commits format (see CLAUDE.md for types and scopes)
5. **Fill out PR template** as the body — be thorough in the Changes section
6. **Format first** (if relevant changes exist):
- Backend: `cd autogpt_platform/backend && poetry run format`
- Frontend: `cd autogpt_platform/frontend && pnpm format`
- Fix any lint errors, then commit formatting changes before pushing
7. **Push**: `git push -u origin HEAD`
8. **Create PR**: `gh pr create --base dev`
9. **Output** the PR URL
## Rules
- Always target `dev` branch
- Do NOT run tests — CI will handle that
- Use the PR template from `.github/PULL_REQUEST_TEMPLATE.md`

View File

@@ -1,51 +0,0 @@
---
name: pr-review
description: Address all open PR review comments systematically. Fetches comments, addresses each one, reacts +1/-1, and replies when clarification is needed. Keeps iterating until all comments are addressed and CI is green. TRIGGER when user shares a PR URL, asks to address review comments, fix PR feedback, or respond to reviewer comments.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# PR Review Comment Workflow
## Steps
1. **Find PR**: `gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT`
2. **Fetch comments** (all three sources):
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` (top-level reviews)
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` (inline review comments)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` (PR conversation comments)
3. **Skip** comments already reacted to by PR author
4. **For each unreacted comment**:
- Read referenced code, make the fix (or reply if you disagree/need info)
- **Inline review comments** (`pulls/{N}/comments`):
- React: `gh api repos/.../pulls/comments/{ID}/reactions -f content="+1"` (or `-1`)
- Reply: `gh api repos/.../pulls/{N}/comments/{ID}/replies -f body="..."`
- **PR conversation comments** (`issues/{N}/comments`):
- React: `gh api repos/.../issues/comments/{ID}/reactions -f content="+1"` (or `-1`)
- No threaded replies — post a new issue comment if needed
- **Top-level reviews**: no reaction API — address in code, reply via issue comment if needed
5. **Include autogpt-reviewer bot fixes** too
6. **Format**: `cd autogpt_platform/backend && poetry run format`, `cd autogpt_platform/frontend && pnpm format`
7. **Commit & push**
8. **Re-fetch comments** immediately — address any new unreacted ones before waiting on CI
9. **Stay productive while CI runs** — don't idle. In priority order:
- Run any pending local tests (`poetry run pytest`, e2e, etc.) and fix failures
- Address any remaining comments
- Only poll `gh pr checks {N}` as the last resort when there's truly nothing left to do
10. **If CI fails** — fix, go back to step 6
11. **Re-fetch comments again** after CI is green — address anything that appeared while CI was running
12. **Done** only when: all comments reacted AND CI is green.
## CRITICAL: Do Not Stop
**Loop is: address → format → commit → push → re-check comments → run local tests → wait CI → re-check comments → repeat.**
Never idle. If CI is running and you have nothing to address, run local tests. Waiting on CI is the last resort.
## Rules
- One todo per comment
- For inline review comments: reply on existing threads. For PR conversation comments: post a new issue comment (API doesn't support threaded replies)
- React to every comment: +1 addressed, -1 disagreed (with explanation)

View File

@@ -1,45 +0,0 @@
---
name: worktree-setup
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, generates Prisma client, and optionally starts the app (with port conflict resolution) or runs tests. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Worktree Setup
## Preferred: Use Branchlet
The repo has a `.branchlet.json` config — it handles env file copying, dependency installation, and Prisma generation automatically.
```bash
npm install -g branchlet # install once
branchlet create -n <name> -s <source-branch> -b <new-branch>
branchlet list --json # list all worktrees
```
## Manual Fallback
If branchlet isn't available:
1. `git worktree add ../<RepoName><N> <branch-name>`
2. Copy `.env` files: `backend/.env`, `frontend/.env`, `autogpt_platform/.env`, `db/docker/.env`
3. Install deps:
- `cd autogpt_platform/backend && poetry install && poetry run prisma generate`
- `cd autogpt_platform/frontend && pnpm install`
## Running the App
Free ports first — backend uses: 8001, 8002, 8003, 8005, 8006, 8007, 8008.
```bash
for port in 8001 8002 8003 8005 8006 8007 8008; do
lsof -ti :$port | xargs kill -9 2>/dev/null || true
done
cd <worktree>/autogpt_platform/backend && poetry run app
```
## CoPilot Testing Gotcha
SDK mode spawns a Claude subprocess — **won't work inside Claude Code**. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.

View File

@@ -149,7 +149,7 @@ jobs:
driver-opts: network=host
- name: Set up Platform - Expose GHA cache to docker buildx CLI
uses: crazy-max/ghaction-github-runtime@v4
uses: crazy-max/ghaction-github-runtime@v3
- name: Set up Platform - Build Docker images (with cache)
working-directory: autogpt_platform

1
.nvmrc
View File

@@ -1 +0,0 @@
22

View File

@@ -1,3 +1,2 @@
*.ignore.*
*.ign.*
.application.logs
*.ign.*

View File

@@ -65,6 +65,12 @@ LANGFUSE_PUBLIC_KEY=
LANGFUSE_SECRET_KEY=
LANGFUSE_HOST=https://cloud.langfuse.com
# OTLP Tracing
# Base host for OTLP trace ingestion (for example Product Intelligence)
OTLP_TRACING_HOST=
# Bearer token for OTLP trace ingestion endpoint (optional)
OTLP_TRACING_TOKEN=
# OAuth Credentials
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
# e.g. http://localhost:3000/auth/integrations/oauth_callback

View File

@@ -95,7 +95,7 @@ ENV DEBIAN_FRONTEND=noninteractive
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
# for the bash_exec MCP tool (fallback when E2B is not configured).
# for the bash_exec MCP tool.
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.13 \
@@ -111,29 +111,13 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
# Copy poetry (build-time only, for `poetry install --only-root` to create entry points)
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
# Copy Node.js installation for Prisma and agent-browser.
# npm/npx are symlinks in the builder (-> ../lib/node_modules/npm/bin/*-cli.js);
# COPY resolves them to regular files, breaking require() paths. Recreate as
# proper symlinks so npm/npx can find their modules.
# Copy Node.js installation for Prisma
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
COPY --from=builder /usr/bin/npm /usr/bin/npm
COPY --from=builder /usr/bin/npx /usr/bin/npx
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
# Install agent-browser (Copilot browser tool) + Chromium runtime dependencies.
# These are the runtime libraries Chromium/Playwright needs on Debian 13 (trixie).
RUN apt-get update && apt-get install -y --no-install-recommends \
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
fonts-liberation libfontconfig1 \
&& rm -rf /var/lib/apt/lists/* \
&& npm install -g agent-browser \
&& agent-browser install \
&& rm -rf /tmp/* /root/.npm
WORKDIR /app/autogpt_platform/backend
# Copy only the .venv from builder (not the entire /app directory)

View File

@@ -1,7 +1,7 @@
import logging
import urllib.parse
from collections import defaultdict
from typing import Annotated, Any, Optional, Sequence
from typing import Annotated, Any, Literal, Optional, Sequence
from fastapi import APIRouter, Body, HTTPException, Security
from prisma.enums import AgentExecutionStatus, APIKeyPermission
@@ -9,10 +9,9 @@ from pydantic import BaseModel, Field
from typing_extensions import TypedDict
import backend.api.features.store.cache as store_cache
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
import backend.blocks
from backend.api.external.middleware import require_auth, require_permission
from backend.api.external.middleware import require_permission
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data import user as user_db
@@ -231,13 +230,13 @@ async def get_graph_execution_results(
@v1_router.get(
path="/store/agents",
tags=["store"],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
response_model=store_model.StoreAgentsResponse,
)
async def get_store_agents(
featured: bool = False,
creator: str | None = None,
sorted_by: store_db.StoreAgentsSortOptions | None = None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
search_query: str | None = None,
category: str | None = None,
page: int = 1,
@@ -279,7 +278,7 @@ async def get_store_agents(
@v1_router.get(
path="/store/agents/{username}/{agent_name}",
tags=["store"],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
response_model=store_model.StoreAgentDetails,
)
async def get_store_agent(
@@ -307,13 +306,13 @@ async def get_store_agent(
@v1_router.get(
path="/store/creators",
tags=["store"],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
response_model=store_model.CreatorsResponse,
)
async def get_store_creators(
featured: bool = False,
search_query: str | None = None,
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
page: int = 1,
page_size: int = 20,
) -> store_model.CreatorsResponse:
@@ -349,7 +348,7 @@ async def get_store_creators(
@v1_router.get(
path="/store/creators/{username}",
tags=["store"],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
response_model=store_model.CreatorDetails,
)
async def get_store_creator(

View File

@@ -24,13 +24,14 @@ router = fastapi.APIRouter(
@router.get(
"/listings",
summary="Get Admin Listings History",
response_model=store_model.StoreListingsWithVersionsResponse,
)
async def get_admin_listings_with_versions(
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
search: typing.Optional[str] = None,
page: int = 1,
page_size: int = 20,
) -> store_model.StoreListingsWithVersionsAdminViewResponse:
):
"""
Get store listings with their version history for admins.
@@ -44,26 +45,36 @@ async def get_admin_listings_with_versions(
page_size: Number of items per page
Returns:
Paginated listings with their versions
StoreListingsWithVersionsResponse with listings and their versions
"""
listings = await store_db.get_admin_listings_with_versions(
status=status,
search_query=search,
page=page,
page_size=page_size,
)
return listings
try:
listings = await store_db.get_admin_listings_with_versions(
status=status,
search_query=search,
page=page,
page_size=page_size,
)
return listings
except Exception as e:
logger.exception("Error getting admin listings with versions: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={
"detail": "An error occurred while retrieving listings with versions"
},
)
@router.post(
"/submissions/{store_listing_version_id}/review",
summary="Review Store Submission",
response_model=store_model.StoreSubmission,
)
async def review_submission(
store_listing_version_id: str,
request: store_model.ReviewSubmissionRequest,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
) -> store_model.StoreSubmissionAdminView:
):
"""
Review a store listing submission.
@@ -73,24 +84,31 @@ async def review_submission(
user_id: Authenticated admin user performing the review
Returns:
StoreSubmissionAdminView with updated review information
StoreSubmission with updated review information
"""
already_approved = await store_db.check_submission_already_approved(
store_listing_version_id=store_listing_version_id,
)
submission = await store_db.review_store_submission(
store_listing_version_id=store_listing_version_id,
is_approved=request.is_approved,
external_comments=request.comments,
internal_comments=request.internal_comments or "",
reviewer_id=user_id,
)
try:
already_approved = await store_db.check_submission_already_approved(
store_listing_version_id=store_listing_version_id,
)
submission = await store_db.review_store_submission(
store_listing_version_id=store_listing_version_id,
is_approved=request.is_approved,
external_comments=request.comments,
internal_comments=request.internal_comments or "",
reviewer_id=user_id,
)
state_changed = already_approved != request.is_approved
# Clear caches whenever approval state changes, since store visibility can change
if state_changed:
store_cache.clear_all_caches()
return submission
state_changed = already_approved != request.is_approved
# Clear caches when the request is approved as it updates what is shown on the store
if state_changed:
store_cache.clear_all_caches()
return submission
except Exception as e:
logger.exception("Error reviewing submission: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while reviewing the submission"},
)
@router.get(

View File

@@ -2,7 +2,6 @@
import asyncio
import logging
import re
from collections.abc import AsyncGenerator
from typing import Annotated
from uuid import uuid4
@@ -10,8 +9,7 @@ from uuid import uuid4
from autogpt_libs import auth
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
@@ -25,10 +23,8 @@ from backend.copilot.model import (
delete_chat_session,
get_chat_session,
get_user_sessions,
update_session_title,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.e2b_sandbox import kill_sandbox
from backend.copilot.tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
@@ -44,8 +40,6 @@ from backend.copilot.tools.models import (
ErrorResponse,
ExecutionStartedResponse,
InputValidationErrorResponse,
MCPToolOutputResponse,
MCPToolsDiscoveredResponse,
NeedLoginResponse,
NoResultsResponse,
SetupRequirementsResponse,
@@ -53,14 +47,10 @@ from backend.copilot.tools.models import (
UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
from backend.data.workspace import get_or_create_workspace
from backend.util.exceptions import NotFoundError
config = ChatConfig()
_UUID_RE = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
)
logger = logging.getLogger(__name__)
@@ -89,9 +79,6 @@ class StreamChatRequest(BaseModel):
message: str
is_user_message: bool = True
context: dict[str, str] | None = None # {url: str, content: str}
file_ids: list[str] | None = Field(
default=None, max_length=20
) # Workspace file IDs attached to this message
class CreateSessionResponse(BaseModel):
@@ -143,20 +130,6 @@ class CancelSessionResponse(BaseModel):
reason: str | None = None
class UpdateSessionTitleRequest(BaseModel):
"""Request model for updating a session's title."""
title: str
@field_validator("title")
@classmethod
def title_must_not_be_blank(cls, v: str) -> str:
stripped = v.strip()
if not stripped:
raise ValueError("Title must not be blank")
return stripped
# ========== Routes ==========
@@ -265,58 +238,9 @@ async def delete_session(
detail=f"Session {session_id} not found or access denied",
)
# Best-effort cleanup of the E2B sandbox (if any).
# sandbox_id is in Redis; kill_sandbox() fetches it from there.
e2b_cfg = ChatConfig()
if e2b_cfg.e2b_active:
assert e2b_cfg.e2b_api_key # guaranteed by e2b_active check
try:
await kill_sandbox(session_id, e2b_cfg.e2b_api_key)
except Exception:
logger.warning(
"[E2B] Failed to kill sandbox for session %s", session_id[:12]
)
return Response(status_code=204)
@router.patch(
"/sessions/{session_id}/title",
summary="Update session title",
dependencies=[Security(auth.requires_user)],
status_code=200,
responses={404: {"description": "Session not found or access denied"}},
)
async def update_session_title_route(
session_id: str,
request: UpdateSessionTitleRequest,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> dict:
"""
Update the title of a chat session.
Allows the user to rename their chat session.
Args:
session_id: The session ID to update.
request: Request body containing the new title.
user_id: The authenticated user's ID.
Returns:
dict: Status of the update.
Raises:
HTTPException: 404 if session not found or not owned by user.
"""
success = await update_session_title(session_id, user_id, request.title)
if not success:
raise HTTPException(
status_code=404,
detail=f"Session {session_id} not found or access denied",
)
return {"status": "ok"}
@router.get(
"/sessions/{session_id}",
)
@@ -470,38 +394,6 @@ async def stream_chat_post(
},
)
# Enrich message with file metadata if file_ids are provided.
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
sanitized_file_ids: list[str] | None = None
if request.file_ids and user_id:
# Filter to valid UUIDs only to prevent DB abuse
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
if valid_ids:
workspace = await get_or_create_workspace(user_id)
# Batch query instead of N+1
files = await UserWorkspaceFile.prisma().find_many(
where={
"id": {"in": valid_ids},
"workspaceId": workspace.id,
"isDeleted": False,
}
)
# Only keep IDs that actually exist in the user's workspace
sanitized_file_ids = [wf.id for wf in files] or None
file_lines: list[str] = [
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
for wf in files
]
if file_lines:
files_block = (
"\n\n[Attached files]\n"
+ "\n".join(file_lines)
+ "\nUse read_workspace_file with the file_id to access file contents."
)
request.message += files_block
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message re-fetches inside a lock to prevent
@@ -553,7 +445,6 @@ async def stream_chat_post(
turn_id=turn_id,
is_user_message=request.is_user_message,
context=request.context,
file_ids=sanitized_file_ids,
)
setup_time = (time.perf_counter() - stream_start_time) * 1000
@@ -806,6 +697,7 @@ async def resume_session_stream(
@router.patch(
"/sessions/{session_id}/assign-user",
dependencies=[Security(auth.requires_user)],
status_code=200,
)
async def session_assign_user(
session_id: str,
@@ -908,8 +800,6 @@ ToolResponseUnion = (
| BlockOutputResponse
| DocSearchResultsResponse
| DocPageResponse
| MCPToolsDiscoveredResponse
| MCPToolOutputResponse
)

View File

@@ -1,251 +0,0 @@
"""Tests for chat API routes: session title update and file attachment validation."""
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from backend.api.features.chat import routes as chat_routes
app = fastapi.FastAPI()
app.include_router(chat_routes.router)
client = fastapi.testclient.TestClient(app)
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
"""Setup auth overrides for all tests in this module"""
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def _mock_update_session_title(
mocker: pytest_mock.MockerFixture, *, success: bool = True
):
"""Mock update_session_title."""
return mocker.patch(
"backend.api.features.chat.routes.update_session_title",
new_callable=AsyncMock,
return_value=success,
)
# ─── Update title: success ─────────────────────────────────────────────
def test_update_title_success(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
mock_update = _mock_update_session_title(mocker, success=True)
response = client.patch(
"/sessions/sess-1/title",
json={"title": "My project"},
)
assert response.status_code == 200
assert response.json() == {"status": "ok"}
mock_update.assert_called_once_with("sess-1", test_user_id, "My project")
def test_update_title_trims_whitespace(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
mock_update = _mock_update_session_title(mocker, success=True)
response = client.patch(
"/sessions/sess-1/title",
json={"title": " trimmed "},
)
assert response.status_code == 200
mock_update.assert_called_once_with("sess-1", test_user_id, "trimmed")
# ─── Update title: blank / whitespace-only → 422 ──────────────────────
def test_update_title_blank_rejected(
test_user_id: str,
) -> None:
"""Whitespace-only titles must be rejected before hitting the DB."""
response = client.patch(
"/sessions/sess-1/title",
json={"title": " "},
)
assert response.status_code == 422
def test_update_title_empty_rejected(
test_user_id: str,
) -> None:
response = client.patch(
"/sessions/sess-1/title",
json={"title": ""},
)
assert response.status_code == 422
# ─── Update title: session not found or wrong user → 404 ──────────────
def test_update_title_not_found(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
_mock_update_session_title(mocker, success=False)
response = client.patch(
"/sessions/sess-1/title",
json={"title": "New name"},
)
assert response.status_code == 404
# ─── file_ids Pydantic validation ─────────────────────────────────────
def test_stream_chat_rejects_too_many_file_ids():
"""More than 20 file_ids should be rejected by Pydantic validation (422)."""
response = client.post(
"/sessions/sess-1/stream",
json={
"message": "hello",
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(21)],
},
)
assert response.status_code == 422
def _mock_stream_internals(mocker: pytest_mock.MockFixture):
"""Mock the async internals of stream_chat_post so tests can exercise
validation and enrichment logic without needing Redis/RabbitMQ."""
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = mocker.AsyncMock(return_value=None)
mocker.patch(
"backend.api.features.chat.routes.stream_registry",
mock_registry,
)
mocker.patch(
"backend.api.features.chat.routes.enqueue_copilot_turn",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
"""Exactly 20 file_ids should be accepted (not rejected by validation)."""
_mock_stream_internals(mocker)
# Patch workspace lookup as imported by the routes module
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
return_value=type("W", (), {"id": "ws-1"})(),
)
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
mocker.patch(
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
response = client.post(
"/sessions/sess-1/stream",
json={
"message": "hello",
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(20)],
},
)
# Should get past validation — 200 streaming response expected
assert response.status_code == 200
# ─── UUID format filtering ─────────────────────────────────────────────
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
"""Non-UUID strings in file_ids should be silently filtered out
and NOT passed to the database query."""
_mock_stream_internals(mocker)
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
return_value=type("W", (), {"id": "ws-1"})(),
)
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
mocker.patch(
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
valid_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
client.post(
"/sessions/sess-1/stream",
json={
"message": "hello",
"file_ids": [
valid_id,
"not-a-uuid",
"../../../etc/passwd",
"",
],
},
)
# The find_many call should only receive the one valid UUID
mock_prisma.find_many.assert_called_once()
call_kwargs = mock_prisma.find_many.call_args[1]
assert call_kwargs["where"]["id"]["in"] == [valid_id]
# ─── Cross-workspace file_ids ─────────────────────────────────────────
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
"""The batch query should scope to the user's workspace."""
_mock_stream_internals(mocker)
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
return_value=type("W", (), {"id": "my-workspace-id"})(),
)
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
mocker.patch(
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
client.post(
"/sessions/sess-1/stream",
json={"message": "hi", "file_ids": [fid]},
)
call_kwargs = mock_prisma.find_many.call_args[1]
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
assert call_kwargs["where"]["isDeleted"] is False

View File

@@ -22,7 +22,6 @@ from backend.data.human_review import (
)
from backend.data.model import USER_TIMEZONE_NOT_SET
from backend.data.user import get_user_by_id
from backend.data.workspace import get_or_create_workspace
from backend.executor.utils import add_graph_execution
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
@@ -322,13 +321,10 @@ async def process_review_action(
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
)
workspace = await get_or_create_workspace(user_id)
execution_context = ExecutionContext(
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
user_timezone=user_timezone,
workspace_id=workspace.id,
)
await add_graph_execution(

View File

@@ -8,6 +8,7 @@ import prisma.errors
import prisma.models
import prisma.types
import backend.api.features.store.exceptions as store_exceptions
import backend.api.features.store.image_gen as store_image_gen
import backend.api.features.store.media as store_media
import backend.data.graph as graph_db
@@ -250,7 +251,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
The requested LibraryAgent.
Raises:
NotFoundError: If the specified agent does not exist.
AgentNotFoundError: If the specified agent does not exist.
DatabaseError: If there's an error during retrieval.
"""
library_agent = await prisma.models.LibraryAgent.prisma().find_first(
@@ -397,7 +398,6 @@ async def create_library_agent(
hitl_safe_mode: bool = True,
sensitive_action_safe_mode: bool = False,
create_library_agents_for_sub_graphs: bool = True,
folder_id: str | None = None,
) -> list[library_model.LibraryAgent]:
"""
Adds an agent to the user's library (LibraryAgent table).
@@ -414,18 +414,12 @@ async def create_library_agent(
If the graph has sub-graphs, the parent graph will always be the first entry in the list.
Raises:
NotFoundError: If the specified agent does not exist.
AgentNotFoundError: If the specified agent does not exist.
DatabaseError: If there's an error during creation or if image generation fails.
"""
logger.info(
f"Creating library agent for graph #{graph.id} v{graph.version}; user:<redacted>"
)
# Authorization: FK only checks existence, not ownership.
# Verify the folder belongs to this user to prevent cross-user nesting.
if folder_id:
await get_folder(folder_id, user_id)
graph_entries = (
[graph, *graph.sub_graphs] if create_library_agents_for_sub_graphs else [graph]
)
@@ -438,6 +432,7 @@ async def create_library_agent(
isCreatedByUser=(user_id == user_id),
useGraphIsActiveVersion=True,
User={"connect": {"id": user_id}},
# Creator={"connect": {"id": user_id}},
AgentGraph={
"connect": {
"graphVersionId": {
@@ -453,11 +448,6 @@ async def create_library_agent(
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
),
**(
{"Folder": {"connect": {"id": folder_id}}}
if folder_id and graph_entry is graph
else {}
),
),
include=library_agent_include(
user_id, include_nodes=False, include_executions=False
@@ -539,7 +529,6 @@ async def update_agent_version_in_library(
async def create_graph_in_library(
graph: graph_db.Graph,
user_id: str,
folder_id: str | None = None,
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
"""Create a new graph and add it to the user's library."""
graph.version = 1
@@ -553,7 +542,6 @@ async def create_graph_in_library(
user_id=user_id,
sensitive_action_safe_mode=True,
create_library_agents_for_sub_graphs=False,
folder_id=folder_id,
)
if created_graph.is_active:
@@ -829,7 +817,7 @@ async def add_store_agent_to_library(
The newly created LibraryAgent if successfully added, the existing corresponding one if any.
Raises:
NotFoundError: If the store listing or associated agent is not found.
AgentNotFoundError: If the store listing or associated agent is not found.
DatabaseError: If there's an issue creating the LibraryAgent record.
"""
logger.debug(
@@ -844,7 +832,7 @@ async def add_store_agent_to_library(
)
if not store_listing_version or not store_listing_version.AgentGraph:
logger.warning(f"Store listing version not found: {store_listing_version_id}")
raise NotFoundError(
raise store_exceptions.AgentNotFoundError(
f"Store listing version {store_listing_version_id} not found or invalid"
)
@@ -858,7 +846,7 @@ async def add_store_agent_to_library(
include_subgraphs=False,
)
if not graph_model:
raise NotFoundError(
raise store_exceptions.AgentNotFoundError(
f"Graph #{graph.id} v{graph.version} not found or accessible"
)
@@ -1493,67 +1481,6 @@ async def bulk_move_agents_to_folder(
return [library_model.LibraryAgent.from_db(agent) for agent in agents]
def collect_tree_ids(
nodes: list[library_model.LibraryFolderTree],
visited: set[str] | None = None,
) -> list[str]:
"""Collect all folder IDs from a folder tree."""
if visited is None:
visited = set()
ids: list[str] = []
for n in nodes:
if n.id in visited:
continue
visited.add(n.id)
ids.append(n.id)
ids.extend(collect_tree_ids(n.children, visited))
return ids
async def get_folder_agent_summaries(
user_id: str, folder_id: str
) -> list[dict[str, str | None]]:
"""Get a lightweight list of agents in a folder (id, name, description)."""
all_agents: list[library_model.LibraryAgent] = []
for page in itertools.count(1):
resp = await list_library_agents(
user_id=user_id, folder_id=folder_id, page=page
)
all_agents.extend(resp.agents)
if page >= resp.pagination.total_pages:
break
return [
{"id": a.id, "name": a.name, "description": a.description} for a in all_agents
]
async def get_root_agent_summaries(
user_id: str,
) -> list[dict[str, str | None]]:
"""Get a lightweight list of root-level agents (folderId IS NULL)."""
all_agents: list[library_model.LibraryAgent] = []
for page in itertools.count(1):
resp = await list_library_agents(
user_id=user_id, include_root_only=True, page=page
)
all_agents.extend(resp.agents)
if page >= resp.pagination.total_pages:
break
return [
{"id": a.id, "name": a.name, "description": a.description} for a in all_agents
]
async def get_folder_agents_map(
user_id: str, folder_ids: list[str]
) -> dict[str, list[dict[str, str | None]]]:
"""Get agent summaries for multiple folders concurrently."""
results = await asyncio.gather(
*(get_folder_agent_summaries(user_id, fid) for fid in folder_ids)
)
return dict(zip(folder_ids, results))
##############################################
########### Presets DB Functions #############
##############################################

View File

@@ -4,6 +4,7 @@ import prisma.enums
import prisma.models
import pytest
import backend.api.features.store.exceptions
from backend.data.db import connect
from backend.data.includes import library_agent_include
@@ -217,7 +218,7 @@ async def test_add_agent_to_library_not_found(mocker):
)
# Call function and verify exception
with pytest.raises(db.NotFoundError):
with pytest.raises(backend.api.features.store.exceptions.AgentNotFoundError):
await db.add_store_agent_to_library("version123", "test-user")
# Verify mock called correctly

View File

@@ -7,24 +7,20 @@ frontend can list available tools on an MCP server before placing a block.
import logging
from typing import Annotated, Any
from urllib.parse import urlparse
import fastapi
from autogpt_libs.auth import get_user_id
from fastapi import Security
from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field
from backend.api.features.integrations.router import CredentialsMetaResponse
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.blocks.mcp.helpers import (
auto_lookup_mcp_credential,
normalize_mcp_url,
server_host,
)
from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests, validate_url_host
from backend.util.request import HTTPClientError, Requests
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -78,20 +74,32 @@ async def discover_tools(
If the user has a stored MCP credential for this server URL, it will be
used automatically — no need to pass an explicit auth token.
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
auth_token = request.auth_token
# Auto-use stored MCP credential when no explicit token is provided.
if not auth_token:
best_cred = await auto_lookup_mcp_credential(
user_id, normalize_mcp_url(request.server_url)
mcp_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
# Find the freshest credential for this server URL
best_cred: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and (cred.metadata or {}).get("mcp_server_url") == request.server_url
):
if best_cred is None or (
(cred.access_token_expires_at or 0)
> (best_cred.access_token_expires_at or 0)
):
best_cred = cred
if best_cred:
# Refresh the token if expired before using it
best_cred = await creds_manager.refresh_if_needed(user_id, best_cred)
logger.info(
f"Using MCP credential {best_cred.id} for {request.server_url}, "
f"expires_at={best_cred.access_token_expires_at}"
)
auth_token = best_cred.access_token.get_secret_value()
client = MCPClient(request.server_url, auth_token=auth_token)
@@ -126,7 +134,7 @@ async def discover_tools(
],
server_name=(
init_result.get("serverInfo", {}).get("name")
or server_host(request.server_url)
or urlparse(request.server_url).hostname
or "MCP"
),
protocol_version=init_result.get("protocolVersion"),
@@ -165,16 +173,7 @@ async def mcp_oauth_login(
3. Performs Dynamic Client Registration (RFC 7591) if available
4. Returns the authorization URL for the frontend to open in a popup
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
# Normalize the URL so that credentials stored here are matched consistently
# by auto_lookup_mcp_credential (which also uses normalized URLs).
server_url = normalize_mcp_url(request.server_url)
client = MCPClient(server_url)
client = MCPClient(request.server_url)
# Step 1: Discover protected-resource metadata (RFC 9728)
protected_resource = await client.discover_auth()
@@ -183,16 +182,7 @@ async def mcp_oauth_login(
if protected_resource and protected_resource.get("authorization_servers"):
auth_server_url = protected_resource["authorization_servers"][0]
resource_url = protected_resource.get("resource", server_url)
# Validate the auth server URL from metadata to prevent SSRF.
try:
await validate_url_host(auth_server_url)
except ValueError as e:
raise fastapi.HTTPException(
status_code=400,
detail=f"Invalid authorization server URL in metadata: {e}",
)
resource_url = protected_resource.get("resource", request.server_url)
# Step 2a: Discover auth-server metadata (RFC 8414)
metadata = await client.discover_auth_server_metadata(auth_server_url)
@@ -202,7 +192,7 @@ async def mcp_oauth_login(
# Don't assume a resource_url — omitting it lets the auth server choose
# the correct audience for the token (RFC 8707 resource is optional).
resource_url = None
metadata = await client.discover_auth_server_metadata(server_url)
metadata = await client.discover_auth_server_metadata(request.server_url)
if (
not metadata
@@ -232,18 +222,12 @@ async def mcp_oauth_login(
client_id = ""
client_secret = ""
if registration_endpoint:
# Validate the registration endpoint to prevent SSRF via metadata.
try:
await validate_url_host(registration_endpoint)
except ValueError:
pass # Skip registration, fall back to default client_id
else:
reg_result = await _register_mcp_client(
registration_endpoint, redirect_uri, server_url
)
if reg_result:
client_id = reg_result.get("client_id", "")
client_secret = reg_result.get("client_secret", "")
reg_result = await _register_mcp_client(
registration_endpoint, redirect_uri, request.server_url
)
if reg_result:
client_id = reg_result.get("client_id", "")
client_secret = reg_result.get("client_secret", "")
if not client_id:
client_id = "autogpt-platform"
@@ -261,7 +245,7 @@ async def mcp_oauth_login(
"token_url": token_url,
"revoke_url": revoke_url,
"resource_url": resource_url,
"server_url": server_url,
"server_url": request.server_url,
"client_id": client_id,
"client_secret": client_secret,
},
@@ -358,7 +342,7 @@ async def mcp_oauth_callback(
credentials.metadata["mcp_token_url"] = meta["token_url"]
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
hostname = server_host(meta["server_url"])
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
credentials.title = f"MCP: {hostname}"
# Remove old MCP credentials for the same server to prevent stale token buildup.
@@ -373,9 +357,7 @@ async def mcp_oauth_callback(
):
await creds_manager.store.delete_creds_by_id(user_id, old.id)
logger.info(
"Removed old MCP credential %s for %s",
old.id,
server_host(meta["server_url"]),
f"Removed old MCP credential {old.id} for {meta['server_url']}"
)
except Exception:
logger.debug("Could not clean up old MCP credentials", exc_info=True)
@@ -393,93 +375,6 @@ async def mcp_oauth_callback(
)
# ======================== Bearer Token ======================== #
class MCPStoreTokenRequest(BaseModel):
"""Request to store a bearer token for an MCP server that doesn't support OAuth."""
server_url: str = Field(
description="MCP server URL the token authenticates against"
)
token: SecretStr = Field(
min_length=1, description="Bearer token / API key for the MCP server"
)
@router.post(
"/token",
summary="Store a bearer token for an MCP server",
)
async def mcp_store_token(
request: MCPStoreTokenRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> CredentialsMetaResponse:
"""
Store a manually provided bearer token as an MCP credential.
Used by the Copilot MCPSetupCard when the server doesn't support the MCP
OAuth discovery flow (returns 400 from /oauth/login). Subsequent
``run_mcp_tool`` calls will automatically pick up the token via
``_auto_lookup_credential``.
"""
token = request.token.get_secret_value().strip()
if not token:
raise fastapi.HTTPException(status_code=422, detail="Token must not be blank.")
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url_host(request.server_url)
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
# Normalize URL so trailing-slash variants match existing credentials.
server_url = normalize_mcp_url(request.server_url)
hostname = server_host(server_url)
# Collect IDs of old credentials to clean up after successful create.
old_cred_ids: list[str] = []
try:
old_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
old_cred_ids = [
old.id
for old in old_creds
if isinstance(old, OAuth2Credentials)
and normalize_mcp_url((old.metadata or {}).get("mcp_server_url", ""))
== server_url
]
except Exception:
logger.debug("Could not query old MCP token credentials", exc_info=True)
credentials = OAuth2Credentials(
provider=ProviderName.MCP.value,
title=f"MCP: {hostname}",
access_token=SecretStr(token),
scopes=[],
metadata={"mcp_server_url": server_url},
)
await creds_manager.create(user_id, credentials)
# Only delete old credentials after the new one is safely stored.
for old_id in old_cred_ids:
try:
await creds_manager.store.delete_creds_by_id(user_id, old_id)
except Exception:
logger.debug("Could not clean up old MCP token credential", exc_info=True)
return CredentialsMetaResponse(
id=credentials.id,
provider=credentials.provider,
type=credentials.type,
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=hostname,
)
# ======================== Helpers ======================== #
@@ -505,7 +400,5 @@ async def _register_mcp_client(
return data
return None
except Exception as e:
logger.warning(
"Dynamic client registration failed for %s: %s", server_host(server_url), e
)
logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
return None

View File

@@ -11,11 +11,9 @@ import httpx
import pytest
import pytest_asyncio
from autogpt_libs.auth import get_user_id
from pydantic import SecretStr
from backend.api.features.mcp.routes import router
from backend.blocks.mcp.client import MCPClientError, MCPTool
from backend.data.model import OAuth2Credentials
from backend.util.request import HTTPClientError
app = fastapi.FastAPI()
@@ -30,16 +28,6 @@ async def client():
yield c
@pytest.fixture(autouse=True)
def _bypass_ssrf_validation():
"""Bypass validate_url_host in all route tests (test URLs don't resolve)."""
with patch(
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
):
yield
class TestDiscoverTools:
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_success(self, client):
@@ -68,12 +56,9 @@ class TestDiscoverTools:
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
),
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={
@@ -122,6 +107,10 @@ class TestDiscoverTools:
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_auto_uses_stored_credential(self, client):
"""When no explicit token is given, stored MCP credentials are used."""
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
stored_cred = OAuth2Credentials(
provider="mcp",
title="MCP: example.com",
@@ -135,12 +124,10 @@ class TestDiscoverTools:
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=stored_cred,
),
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred)
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
@@ -162,12 +149,9 @@ class TestDiscoverTools:
async def test_discover_tools_mcp_error(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
),
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=MCPClientError("Connection refused")
@@ -185,12 +169,9 @@ class TestDiscoverTools:
async def test_discover_tools_generic_error(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
),
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
@@ -206,12 +187,9 @@ class TestDiscoverTools:
async def test_discover_tools_auth_required(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
),
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
@@ -229,12 +207,9 @@ class TestDiscoverTools:
async def test_discover_tools_forbidden(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
),
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
@@ -356,6 +331,10 @@ class TestOAuthLogin:
class TestOAuthCallback:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_success(self, client):
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
mock_creds = OAuth2Credentials(
provider="mcp",
title=None,
@@ -455,118 +434,3 @@ class TestOAuthCallback:
assert response.status_code == 400
assert "token exchange failed" in response.json()["detail"].lower()
class TestStoreToken:
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_success(self, client):
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
mock_cm.create = AsyncMock()
response = await client.post(
"/token",
json={
"server_url": "https://mcp.example.com/mcp",
"token": "my-api-key-123",
},
)
assert response.status_code == 200
data = response.json()
assert data["provider"] == "mcp"
assert data["type"] == "oauth2"
assert data["host"] == "mcp.example.com"
mock_cm.create.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_blank_rejected(self, client):
"""Blank token string (after stripping) should return 422."""
response = await client.post(
"/token",
json={
"server_url": "https://mcp.example.com/mcp",
"token": " ",
},
)
# Pydantic min_length=1 catches the whitespace-only token
assert response.status_code == 422
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_replaces_old_credential(self, client):
old_cred = OAuth2Credentials(
provider="mcp",
title="MCP: mcp.example.com",
access_token=SecretStr("old-token"),
scopes=[],
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
)
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[old_cred])
mock_cm.create = AsyncMock()
mock_cm.store.delete_creds_by_id = AsyncMock()
response = await client.post(
"/token",
json={
"server_url": "https://mcp.example.com/mcp",
"token": "new-token",
},
)
assert response.status_code == 200
mock_cm.store.delete_creds_by_id.assert_called_once_with(
"test-user-id", old_cred.id
)
class TestSSRFValidation:
"""Verify that validate_url_host is enforced on all endpoints."""
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
response = await client.post(
"/discover-tools",
json={"server_url": "http://localhost/mcp"},
)
assert response.status_code == 400
assert "blocked loopback" in response.json()["detail"].lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked private IP"),
):
response = await client.post(
"/oauth/login",
json={"server_url": "http://10.0.0.1/mcp"},
)
assert response.status_code == 400
assert "blocked private ip" in response.json()["detail"].lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url_host",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
response = await client.post(
"/token",
json={
"server_url": "http://127.0.0.1/mcp",
"token": "some-token",
},
)
assert response.status_code == 400
assert "blocked loopback" in response.json()["detail"].lower()

View File

@@ -1,3 +1,5 @@
from typing import Literal
from backend.util.cache import cached
from . import db as store_db
@@ -21,7 +23,7 @@ def clear_all_caches():
async def _get_cached_store_agents(
featured: bool,
creator: str | None,
sorted_by: store_db.StoreAgentsSortOptions | None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None,
search_query: str | None,
category: str | None,
page: int,
@@ -55,7 +57,7 @@ async def _get_cached_agent_details(
async def _get_cached_store_creators(
featured: bool,
search_query: str | None,
sorted_by: store_db.StoreCreatorsSortOptions | None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None,
page: int,
page_size: int,
):
@@ -73,4 +75,4 @@ async def _get_cached_store_creators(
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
async def _get_cached_creator_details(username: str):
"""Cached helper to get creator details."""
return await store_db.get_store_creator(username=username.lower())
return await store_db.get_store_creator_details(username=username.lower())

File diff suppressed because it is too large Load Diff

View File

@@ -26,7 +26,7 @@ async def test_get_store_agents(mocker):
mock_agents = [
prisma.models.StoreAgent(
listing_id="test-id",
listing_version_id="version123",
storeListingVersionId="version123",
slug="test-agent",
agent_name="Test Agent",
agent_video=None,
@@ -40,11 +40,11 @@ async def test_get_store_agents(mocker):
runs=10,
rating=4.5,
versions=["1.0"],
graph_id="test-graph-id",
graph_versions=["1"],
agentGraphVersions=["1"],
agentGraphId="test-graph-id",
updated_at=datetime.now(),
is_available=False,
use_for_onboarding=False,
useForOnboarding=False,
)
]
@@ -68,10 +68,10 @@ async def test_get_store_agents(mocker):
@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_agent_details(mocker):
# Mock data - StoreAgent view already contains the active version data
# Mock data
mock_agent = prisma.models.StoreAgent(
listing_id="test-id",
listing_version_id="version123",
storeListingVersionId="version123",
slug="test-agent",
agent_name="Test Agent",
agent_video="video.mp4",
@@ -85,38 +85,102 @@ async def test_get_store_agent_details(mocker):
runs=10,
rating=4.5,
versions=["1.0"],
graph_id="test-graph-id",
graph_versions=["1"],
agentGraphVersions=["1"],
agentGraphId="test-graph-id",
updated_at=datetime.now(),
is_available=True,
use_for_onboarding=False,
is_available=False,
useForOnboarding=False,
)
# Mock StoreAgent prisma call
# Mock active version agent (what we want to return for active version)
mock_active_agent = prisma.models.StoreAgent(
listing_id="test-id",
storeListingVersionId="active-version-id",
slug="test-agent",
agent_name="Test Agent Active",
agent_video="active_video.mp4",
agent_image=["active_image.jpg"],
featured=False,
creator_username="creator",
creator_avatar="avatar.jpg",
sub_heading="Test heading active",
description="Test description active",
categories=["test"],
runs=15,
rating=4.8,
versions=["1.0", "2.0"],
agentGraphVersions=["1", "2"],
agentGraphId="test-graph-id-active",
updated_at=datetime.now(),
is_available=True,
useForOnboarding=False,
)
# Create a mock StoreListing result
mock_store_listing = mocker.MagicMock()
mock_store_listing.activeVersionId = "active-version-id"
mock_store_listing.hasApprovedVersion = True
mock_store_listing.ActiveVersion = mocker.MagicMock()
mock_store_listing.ActiveVersion.recommendedScheduleCron = None
# Mock StoreAgent prisma call - need to handle multiple calls
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
# Set up side_effect to return different results for different calls
def mock_find_first_side_effect(*args, **kwargs):
where_clause = kwargs.get("where", {})
if "storeListingVersionId" in where_clause:
# Second call for active version
return mock_active_agent
else:
# First call for initial lookup
return mock_agent
mock_store_agent.return_value.find_first = mocker.AsyncMock(
side_effect=mock_find_first_side_effect
)
# Mock Profile prisma call
mock_profile = mocker.MagicMock()
mock_profile.userId = "user-id-123"
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
mock_profile_db.return_value.find_first = mocker.AsyncMock(
return_value=mock_profile
)
# Mock StoreListing prisma call
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
return_value=mock_store_listing
)
# Call function
result = await db.get_store_agent_details("creator", "test-agent")
# Verify results - constructed from the StoreAgent view
# Verify results - should use active version data
assert result.slug == "test-agent"
assert result.agent_name == "Test Agent"
assert result.active_version_id == "version123"
assert result.agent_name == "Test Agent Active" # From active version
assert result.active_version_id == "active-version-id"
assert result.has_approved_version is True
assert result.store_listing_version_id == "version123"
assert result.graph_id == "test-graph-id"
assert result.runs == 10
assert result.rating == 4.5
assert (
result.store_listing_version_id == "active-version-id"
) # Should be active version ID
# Verify single StoreAgent lookup
mock_store_agent.return_value.find_first.assert_called_once_with(
# Verify mocks called correctly - now expecting 2 calls
assert mock_store_agent.return_value.find_first.call_count == 2
# Check the specific calls
calls = mock_store_agent.return_value.find_first.call_args_list
assert calls[0] == mocker.call(
where={"creator_username": "creator", "slug": "test-agent"}
)
assert calls[1] == mocker.call(where={"storeListingVersionId": "active-version-id"})
mock_store_listing_db.return_value.find_first.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_creator(mocker):
async def test_get_store_creator_details(mocker):
# Mock data
mock_creator_data = prisma.models.Creator(
name="Test Creator",
@@ -138,7 +202,7 @@ async def test_get_store_creator(mocker):
mock_creator.return_value.find_unique.return_value = mock_creator_data
# Call function
result = await db.get_store_creator("creator")
result = await db.get_store_creator_details("creator")
# Verify results
assert result.username == "creator"
@@ -154,110 +218,61 @@ async def test_get_store_creator(mocker):
@pytest.mark.asyncio(loop_scope="session")
async def test_create_store_submission(mocker):
now = datetime.now()
# Mock agent graph (with no pending submissions) and user with profile
mock_profile = prisma.models.Profile(
id="profile-id",
userId="user-id",
name="Test User",
username="testuser",
description="Test",
isFeatured=False,
links=[],
createdAt=now,
updatedAt=now,
)
mock_user = prisma.models.User(
id="user-id",
email="test@example.com",
createdAt=now,
updatedAt=now,
Profile=[mock_profile],
emailVerified=True,
metadata="{}", # type: ignore[reportArgumentType]
integrations="",
maxEmailsPerDay=1,
notifyOnAgentRun=True,
notifyOnZeroBalance=True,
notifyOnLowBalance=True,
notifyOnBlockExecutionFailed=True,
notifyOnContinuousAgentError=True,
notifyOnDailySummary=True,
notifyOnWeeklySummary=True,
notifyOnMonthlySummary=True,
notifyOnAgentApproved=True,
notifyOnAgentRejected=True,
timezone="Europe/Delft",
)
# Mock data
mock_agent = prisma.models.AgentGraph(
id="agent-id",
version=1,
userId="user-id",
createdAt=now,
createdAt=datetime.now(),
isActive=True,
StoreListingVersions=[],
User=mock_user,
)
# Mock the created StoreListingVersion (returned by create)
mock_store_listing_obj = prisma.models.StoreListing(
mock_listing = prisma.models.StoreListing(
id="listing-id",
createdAt=now,
updatedAt=now,
createdAt=datetime.now(),
updatedAt=datetime.now(),
isDeleted=False,
hasApprovedVersion=False,
slug="test-agent",
agentGraphId="agent-id",
owningUserId="user-id",
useForOnboarding=False,
)
mock_version = prisma.models.StoreListingVersion(
id="version-id",
agentGraphId="agent-id",
agentGraphVersion=1,
name="Test Agent",
description="Test description",
createdAt=now,
updatedAt=now,
subHeading="",
imageUrls=[],
categories=[],
isFeatured=False,
isDeleted=False,
version=1,
storeListingId="listing-id",
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
isAvailable=True,
submittedAt=now,
StoreListing=mock_store_listing_obj,
owningUserId="user-id",
Versions=[
prisma.models.StoreListingVersion(
id="version-id",
agentGraphId="agent-id",
agentGraphVersion=1,
name="Test Agent",
description="Test description",
createdAt=datetime.now(),
updatedAt=datetime.now(),
subHeading="Test heading",
imageUrls=["image.jpg"],
categories=["test"],
isFeatured=False,
isDeleted=False,
version=1,
storeListingId="listing-id",
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
isAvailable=True,
)
],
useForOnboarding=False,
)
# Mock prisma calls
mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
# Mock transaction context manager
mock_tx = mocker.MagicMock()
mocker.patch(
"backend.api.features.store.db.transaction",
return_value=mocker.AsyncMock(
__aenter__=mocker.AsyncMock(return_value=mock_tx),
__aexit__=mocker.AsyncMock(return_value=False),
),
)
mock_sl = mocker.patch("prisma.models.StoreListing.prisma")
mock_sl.return_value.find_unique = mocker.AsyncMock(return_value=None)
mock_slv = mocker.patch("prisma.models.StoreListingVersion.prisma")
mock_slv.return_value.create = mocker.AsyncMock(return_value=mock_version)
mock_store_listing = mocker.patch("prisma.models.StoreListing.prisma")
mock_store_listing.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_store_listing.return_value.create = mocker.AsyncMock(return_value=mock_listing)
# Call function
result = await db.create_store_submission(
user_id="user-id",
graph_id="agent-id",
graph_version=1,
agent_id="agent-id",
agent_version=1,
slug="test-agent",
name="Test Agent",
description="Test description",
@@ -266,11 +281,11 @@ async def test_create_store_submission(mocker):
# Verify results
assert result.name == "Test Agent"
assert result.description == "Test description"
assert result.listing_version_id == "version-id"
assert result.store_listing_version_id == "version-id"
# Verify mocks called correctly
mock_agent_graph.return_value.find_first.assert_called_once()
mock_slv.return_value.create.assert_called_once()
mock_store_listing.return_value.create.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
@@ -303,6 +318,7 @@ async def test_update_profile(mocker):
description="Test description",
links=["link1"],
avatar_url="avatar.jpg",
is_featured=False,
)
# Call function
@@ -373,7 +389,7 @@ async def test_get_store_agents_with_search_and_filters_parameterized():
creators=["creator1'; DROP TABLE Users; --", "creator2"],
category="AI'; DELETE FROM StoreAgent; --",
featured=True,
sorted_by=db.StoreAgentsSortOptions.RATING,
sorted_by="rating",
page=1,
page_size=20,
)

View File

@@ -57,6 +57,12 @@ class StoreError(ValueError):
pass
class AgentNotFoundError(NotFoundError):
"""Raised when an agent is not found"""
pass
class CreatorNotFoundError(NotFoundError):
"""Raised when a creator is not found"""

View File

@@ -568,7 +568,7 @@ async def hybrid_search(
SELECT uce."contentId" as "storeListingVersionId"
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON uce."contentId" = sa.listing_version_id
ON uce."contentId" = sa."storeListingVersionId"
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
AND uce."userId" IS NULL
AND uce.search @@ plainto_tsquery('english', {query_param})
@@ -582,7 +582,7 @@ async def hybrid_search(
SELECT uce."contentId", uce.embedding
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON uce."contentId" = sa.listing_version_id
ON uce."contentId" = sa."storeListingVersionId"
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
AND uce."userId" IS NULL
AND {where_clause}
@@ -605,7 +605,7 @@ async def hybrid_search(
sa.featured,
sa.is_available,
sa.updated_at,
sa.graph_id,
sa."agentGraphId",
-- Searchable text for BM25 reranking
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
-- Semantic score
@@ -627,9 +627,9 @@ async def hybrid_search(
sa.runs as popularity_raw
FROM candidates c
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON c."storeListingVersionId" = sa.listing_version_id
ON c."storeListingVersionId" = sa."storeListingVersionId"
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
ON sa.listing_version_id = uce."contentId"
ON sa."storeListingVersionId" = uce."contentId"
AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
),
max_vals AS (
@@ -665,7 +665,7 @@ async def hybrid_search(
featured,
is_available,
updated_at,
graph_id,
"agentGraphId",
searchable_text,
semantic_score,
lexical_score,

View File

@@ -1,14 +1,11 @@
import datetime
from typing import TYPE_CHECKING, List, Self
from typing import List
import prisma.enums
import pydantic
from backend.util.models import Pagination
if TYPE_CHECKING:
import prisma.models
class ChangelogEntry(pydantic.BaseModel):
version: str
@@ -16,9 +13,9 @@ class ChangelogEntry(pydantic.BaseModel):
date: datetime.datetime
class MyUnpublishedAgent(pydantic.BaseModel):
graph_id: str
graph_version: int
class MyAgent(pydantic.BaseModel):
agent_id: str
agent_version: int
agent_name: str
agent_image: str | None = None
description: str
@@ -26,8 +23,8 @@ class MyUnpublishedAgent(pydantic.BaseModel):
recommended_schedule_cron: str | None = None
class MyUnpublishedAgentsResponse(pydantic.BaseModel):
agents: list[MyUnpublishedAgent]
class MyAgentsResponse(pydantic.BaseModel):
agents: list[MyAgent]
pagination: Pagination
@@ -43,21 +40,6 @@ class StoreAgent(pydantic.BaseModel):
rating: float
agent_graph_id: str
@classmethod
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgent":
return cls(
slug=agent.slug,
agent_name=agent.agent_name,
agent_image=agent.agent_image[0] if agent.agent_image else "",
creator=agent.creator_username or "Needs Profile",
creator_avatar=agent.creator_avatar or "",
sub_heading=agent.sub_heading,
description=agent.description,
runs=agent.runs,
rating=agent.rating,
agent_graph_id=agent.graph_id,
)
class StoreAgentsResponse(pydantic.BaseModel):
agents: list[StoreAgent]
@@ -80,192 +62,81 @@ class StoreAgentDetails(pydantic.BaseModel):
runs: int
rating: float
versions: list[str]
graph_id: str
graph_versions: list[str]
agentGraphVersions: list[str]
agentGraphId: str
last_updated: datetime.datetime
recommended_schedule_cron: str | None = None
active_version_id: str
has_approved_version: bool
active_version_id: str | None = None
has_approved_version: bool = False
# Optional changelog data when include_changelog=True
changelog: list[ChangelogEntry] | None = None
@classmethod
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgentDetails":
return cls(
store_listing_version_id=agent.listing_version_id,
slug=agent.slug,
agent_name=agent.agent_name,
agent_video=agent.agent_video or "",
agent_output_demo=agent.agent_output_demo or "",
agent_image=agent.agent_image,
creator=agent.creator_username or "",
creator_avatar=agent.creator_avatar or "",
sub_heading=agent.sub_heading,
description=agent.description,
categories=agent.categories,
runs=agent.runs,
rating=agent.rating,
versions=agent.versions,
graph_id=agent.graph_id,
graph_versions=agent.graph_versions,
last_updated=agent.updated_at,
recommended_schedule_cron=agent.recommended_schedule_cron,
active_version_id=agent.listing_version_id,
has_approved_version=True, # StoreAgent view only has approved agents
)
class Profile(pydantic.BaseModel):
"""Marketplace user profile (only attributes that the user can update)"""
username: str
class Creator(pydantic.BaseModel):
name: str
username: str
description: str
avatar_url: str | None
links: list[str]
class ProfileDetails(Profile):
"""Marketplace user profile (including read-only fields)"""
is_featured: bool
@classmethod
def from_db(cls, profile: "prisma.models.Profile") -> "ProfileDetails":
return cls(
name=profile.name,
username=profile.username,
avatar_url=profile.avatarUrl,
description=profile.description,
links=profile.links,
is_featured=profile.isFeatured,
)
class CreatorDetails(ProfileDetails):
"""Marketplace creator profile details, including aggregated stats"""
avatar_url: str
num_agents: int
agent_runs: int
agent_rating: float
top_categories: list[str]
@classmethod
def from_db(cls, creator: "prisma.models.Creator") -> "CreatorDetails": # type: ignore[override]
return cls(
name=creator.name,
username=creator.username,
avatar_url=creator.avatar_url,
description=creator.description,
links=creator.links,
is_featured=creator.is_featured,
num_agents=creator.num_agents,
agent_runs=creator.agent_runs,
agent_rating=creator.agent_rating,
top_categories=creator.top_categories,
)
agent_runs: int
is_featured: bool
class CreatorsResponse(pydantic.BaseModel):
creators: List[CreatorDetails]
creators: List[Creator]
pagination: Pagination
class StoreSubmission(pydantic.BaseModel):
# From StoreListing:
listing_id: str
user_id: str
slug: str
class CreatorDetails(pydantic.BaseModel):
name: str
username: str
description: str
links: list[str]
avatar_url: str
agent_rating: float
agent_runs: int
top_categories: list[str]
# From StoreListingVersion:
listing_version_id: str
listing_version: int
graph_id: str
graph_version: int
class Profile(pydantic.BaseModel):
name: str
username: str
description: str
links: list[str]
avatar_url: str
is_featured: bool = False
class StoreSubmission(pydantic.BaseModel):
listing_id: str
agent_id: str
agent_version: int
name: str
sub_heading: str
slug: str
description: str
instructions: str | None
categories: list[str]
instructions: str | None = None
image_urls: list[str]
video_url: str | None
agent_output_demo_url: str | None
submitted_at: datetime.datetime | None
changes_summary: str | None
date_submitted: datetime.datetime
status: prisma.enums.SubmissionStatus
reviewed_at: datetime.datetime | None = None
runs: int
rating: float
store_listing_version_id: str | None = None
version: int | None = None # Actual version number from the database
reviewer_id: str | None = None
review_comments: str | None = None # External comments visible to creator
internal_comments: str | None = None # Private notes for admin use only
reviewed_at: datetime.datetime | None = None
changes_summary: str | None = None
# Aggregated from AgentGraphExecutions and StoreListingReviews:
run_count: int = 0
review_count: int = 0
review_avg_rating: float = 0.0
@classmethod
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
"""Construct from the StoreSubmission Prisma view."""
return cls(
listing_id=_sub.listing_id,
user_id=_sub.user_id,
slug=_sub.slug,
listing_version_id=_sub.listing_version_id,
listing_version=_sub.listing_version,
graph_id=_sub.graph_id,
graph_version=_sub.graph_version,
name=_sub.name,
sub_heading=_sub.sub_heading,
description=_sub.description,
instructions=_sub.instructions,
categories=_sub.categories,
image_urls=_sub.image_urls,
video_url=_sub.video_url,
agent_output_demo_url=_sub.agent_output_demo_url,
submitted_at=_sub.submitted_at,
changes_summary=_sub.changes_summary,
status=_sub.status,
reviewed_at=_sub.reviewed_at,
reviewer_id=_sub.reviewer_id,
review_comments=_sub.review_comments,
run_count=_sub.run_count,
review_count=_sub.review_count,
review_avg_rating=_sub.review_avg_rating,
)
@classmethod
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
"""
Construct from the StoreListingVersion Prisma model (with StoreListing included)
"""
if not (_l := _lv.StoreListing):
raise ValueError("StoreListingVersion must have included StoreListing")
return cls(
listing_id=_l.id,
user_id=_l.owningUserId,
slug=_l.slug,
listing_version_id=_lv.id,
listing_version=_lv.version,
graph_id=_lv.agentGraphId,
graph_version=_lv.agentGraphVersion,
name=_lv.name,
sub_heading=_lv.subHeading,
description=_lv.description,
instructions=_lv.instructions,
categories=_lv.categories,
image_urls=_lv.imageUrls,
video_url=_lv.videoUrl,
agent_output_demo_url=_lv.agentOutputDemoUrl,
submitted_at=_lv.submittedAt,
changes_summary=_lv.changesSummary,
status=_lv.submissionStatus,
reviewed_at=_lv.reviewedAt,
reviewer_id=_lv.reviewerId,
review_comments=_lv.reviewComments,
)
# Additional fields for editing
video_url: str | None = None
agent_output_demo_url: str | None = None
categories: list[str] = []
class StoreSubmissionsResponse(pydantic.BaseModel):
@@ -273,12 +144,33 @@ class StoreSubmissionsResponse(pydantic.BaseModel):
pagination: Pagination
class StoreListingWithVersions(pydantic.BaseModel):
"""A store listing with its version history"""
listing_id: str
slug: str
agent_id: str
agent_version: int
active_version_id: str | None = None
has_approved_version: bool = False
creator_email: str | None = None
latest_version: StoreSubmission | None = None
versions: list[StoreSubmission] = []
class StoreListingsWithVersionsResponse(pydantic.BaseModel):
"""Response model for listings with version history"""
listings: list[StoreListingWithVersions]
pagination: Pagination
class StoreSubmissionRequest(pydantic.BaseModel):
graph_id: str = pydantic.Field(
..., min_length=1, description="Graph ID cannot be empty"
agent_id: str = pydantic.Field(
..., min_length=1, description="Agent ID cannot be empty"
)
graph_version: int = pydantic.Field(
..., gt=0, description="Graph version must be greater than 0"
agent_version: int = pydantic.Field(
..., gt=0, description="Agent version must be greater than 0"
)
slug: str
name: str
@@ -306,42 +198,12 @@ class StoreSubmissionEditRequest(pydantic.BaseModel):
recommended_schedule_cron: str | None = None
class StoreSubmissionAdminView(StoreSubmission):
internal_comments: str | None # Private admin notes
@classmethod
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
return cls(
**StoreSubmission.from_db(_sub).model_dump(),
internal_comments=_sub.internal_comments,
)
@classmethod
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
return cls(
**StoreSubmission.from_listing_version(_lv).model_dump(),
internal_comments=_lv.internalComments,
)
class StoreListingWithVersionsAdminView(pydantic.BaseModel):
"""A store listing with its version history"""
listing_id: str
graph_id: str
slug: str
active_listing_version_id: str | None = None
has_approved_version: bool = False
creator_email: str | None = None
latest_version: StoreSubmissionAdminView | None = None
versions: list[StoreSubmissionAdminView] = []
class StoreListingsWithVersionsAdminViewResponse(pydantic.BaseModel):
"""Response model for listings with version history"""
listings: list[StoreListingWithVersionsAdminView]
pagination: Pagination
class ProfileDetails(pydantic.BaseModel):
name: str
username: str
description: str
links: list[str]
avatar_url: str | None = None
class StoreReview(pydantic.BaseModel):

View File

@@ -0,0 +1,203 @@
import datetime
import prisma.enums
from . import model as store_model
def test_pagination():
pagination = store_model.Pagination(
total_items=100, total_pages=5, current_page=2, page_size=20
)
assert pagination.total_items == 100
assert pagination.total_pages == 5
assert pagination.current_page == 2
assert pagination.page_size == 20
def test_store_agent():
agent = store_model.StoreAgent(
slug="test-agent",
agent_name="Test Agent",
agent_image="test.jpg",
creator="creator1",
creator_avatar="avatar.jpg",
sub_heading="Test subheading",
description="Test description",
runs=50,
rating=4.5,
agent_graph_id="test-graph-id",
)
assert agent.slug == "test-agent"
assert agent.agent_name == "Test Agent"
assert agent.runs == 50
assert agent.rating == 4.5
assert agent.agent_graph_id == "test-graph-id"
def test_store_agents_response():
response = store_model.StoreAgentsResponse(
agents=[
store_model.StoreAgent(
slug="test-agent",
agent_name="Test Agent",
agent_image="test.jpg",
creator="creator1",
creator_avatar="avatar.jpg",
sub_heading="Test subheading",
description="Test description",
runs=50,
rating=4.5,
agent_graph_id="test-graph-id",
)
],
pagination=store_model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=20
),
)
assert len(response.agents) == 1
assert response.pagination.total_items == 1
def test_store_agent_details():
details = store_model.StoreAgentDetails(
store_listing_version_id="version123",
slug="test-agent",
agent_name="Test Agent",
agent_video="video.mp4",
agent_output_demo="demo.mp4",
agent_image=["image1.jpg", "image2.jpg"],
creator="creator1",
creator_avatar="avatar.jpg",
sub_heading="Test subheading",
description="Test description",
categories=["cat1", "cat2"],
runs=50,
rating=4.5,
versions=["1.0", "2.0"],
agentGraphVersions=["1", "2"],
agentGraphId="test-graph-id",
last_updated=datetime.datetime.now(),
)
assert details.slug == "test-agent"
assert len(details.agent_image) == 2
assert len(details.categories) == 2
assert len(details.versions) == 2
def test_creator():
creator = store_model.Creator(
agent_rating=4.8,
agent_runs=1000,
name="Test Creator",
username="creator1",
description="Test description",
avatar_url="avatar.jpg",
num_agents=5,
is_featured=False,
)
assert creator.name == "Test Creator"
assert creator.num_agents == 5
def test_creators_response():
response = store_model.CreatorsResponse(
creators=[
store_model.Creator(
agent_rating=4.8,
agent_runs=1000,
name="Test Creator",
username="creator1",
description="Test description",
avatar_url="avatar.jpg",
num_agents=5,
is_featured=False,
)
],
pagination=store_model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=20
),
)
assert len(response.creators) == 1
assert response.pagination.total_items == 1
def test_creator_details():
details = store_model.CreatorDetails(
name="Test Creator",
username="creator1",
description="Test description",
links=["link1.com", "link2.com"],
avatar_url="avatar.jpg",
agent_rating=4.8,
agent_runs=1000,
top_categories=["cat1", "cat2"],
)
assert details.name == "Test Creator"
assert len(details.links) == 2
assert details.agent_rating == 4.8
assert len(details.top_categories) == 2
def test_store_submission():
submission = store_model.StoreSubmission(
listing_id="listing123",
agent_id="agent123",
agent_version=1,
sub_heading="Test subheading",
name="Test Agent",
slug="test-agent",
description="Test description",
image_urls=["image1.jpg", "image2.jpg"],
date_submitted=datetime.datetime(2023, 1, 1),
status=prisma.enums.SubmissionStatus.PENDING,
runs=50,
rating=4.5,
)
assert submission.name == "Test Agent"
assert len(submission.image_urls) == 2
assert submission.status == prisma.enums.SubmissionStatus.PENDING
def test_store_submissions_response():
response = store_model.StoreSubmissionsResponse(
submissions=[
store_model.StoreSubmission(
listing_id="listing123",
agent_id="agent123",
agent_version=1,
sub_heading="Test subheading",
name="Test Agent",
slug="test-agent",
description="Test description",
image_urls=["image1.jpg"],
date_submitted=datetime.datetime(2023, 1, 1),
status=prisma.enums.SubmissionStatus.PENDING,
runs=50,
rating=4.5,
)
],
pagination=store_model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=20
),
)
assert len(response.submissions) == 1
assert response.pagination.total_items == 1
def test_store_submission_request():
request = store_model.StoreSubmissionRequest(
agent_id="agent123",
agent_version=1,
slug="test-agent",
name="Test Agent",
sub_heading="Test subheading",
video_url="video.mp4",
image_urls=["image1.jpg", "image2.jpg"],
description="Test description",
categories=["cat1", "cat2"],
)
assert request.agent_id == "agent123"
assert request.agent_version == 1
assert len(request.image_urls) == 2
assert len(request.categories) == 2

View File

@@ -1,17 +1,16 @@
import logging
import tempfile
import typing
import urllib.parse
from typing import Literal
import autogpt_libs.auth
import fastapi
import fastapi.responses
import prisma.enums
from fastapi import Query, Security
from pydantic import BaseModel
import backend.data.graph
import backend.util.json
from backend.util.exceptions import NotFoundError
from backend.util.models import Pagination
from . import cache as store_cache
@@ -35,15 +34,22 @@ router = fastapi.APIRouter()
"/profile",
summary="Get user profile",
tags=["store", "private"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.ProfileDetails,
)
async def get_profile(
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.ProfileDetails:
"""Get the profile details for the authenticated user."""
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Get the profile details for the authenticated user.
Cached for 1 hour per user.
"""
profile = await store_db.get_user_profile(user_id)
if profile is None:
raise NotFoundError("User does not have a profile yet")
return fastapi.responses.JSONResponse(
status_code=404,
content={"detail": "Profile not found"},
)
return profile
@@ -51,17 +57,98 @@ async def get_profile(
"/profile",
summary="Update user profile",
tags=["store", "private"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.CreatorDetails,
)
async def update_or_create_profile(
profile: store_model.Profile,
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.ProfileDetails:
"""Update the store profile for the authenticated user."""
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Update the store profile for the authenticated user.
Args:
profile (Profile): The updated profile details
user_id (str): ID of the authenticated user
Returns:
CreatorDetails: The updated profile
Raises:
HTTPException: If there is an error updating the profile
"""
updated_profile = await store_db.update_profile(user_id=user_id, profile=profile)
return updated_profile
##############################################
############### Agent Endpoints ##############
##############################################
@router.get(
"/agents",
summary="List store agents",
tags=["store", "public"],
response_model=store_model.StoreAgentsResponse,
)
async def get_agents(
featured: bool = False,
creator: str | None = None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
search_query: str | None = None,
category: str | None = None,
page: int = 1,
page_size: int = 20,
):
"""
Get a paginated list of agents from the store with optional filtering and sorting.
Args:
featured (bool, optional): Filter to only show featured agents. Defaults to False.
creator (str | None, optional): Filter agents by creator username. Defaults to None.
sorted_by (str | None, optional): Sort agents by "runs" or "rating". Defaults to None.
search_query (str | None, optional): Search agents by name, subheading and description. Defaults to None.
category (str | None, optional): Filter agents by category. Defaults to None.
page (int, optional): Page number for pagination. Defaults to 1.
page_size (int, optional): Number of agents per page. Defaults to 20.
Returns:
StoreAgentsResponse: Paginated list of agents matching the filters
Raises:
HTTPException: If page or page_size are less than 1
Used for:
- Home Page Featured Agents
- Home Page Top Agents
- Search Results
- Agent Details - Other Agents By Creator
- Agent Details - Similar Agents
- Creator Details - Agents By Creator
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
agents = await store_cache._get_cached_store_agents(
featured=featured,
creator=creator,
sorted_by=sorted_by,
search_query=search_query,
category=category,
page=page,
page_size=page_size,
)
return agents
##############################################
############### Search Endpoints #############
##############################################
@@ -71,30 +158,60 @@ async def update_or_create_profile(
"/search",
summary="Unified search across all content types",
tags=["store", "public"],
response_model=store_model.UnifiedSearchResponse,
)
async def unified_search(
query: str,
content_types: list[prisma.enums.ContentType] | None = Query(
content_types: list[str] | None = fastapi.Query(
default=None,
description="Content types to search. If not specified, searches all.",
description="Content types to search: STORE_AGENT, BLOCK, DOCUMENTATION. If not specified, searches all.",
),
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
user_id: str | None = Security(
page: int = 1,
page_size: int = 20,
user_id: str | None = fastapi.Security(
autogpt_libs.auth.get_optional_user_id, use_cache=False
),
) -> store_model.UnifiedSearchResponse:
):
"""
Search across all content types (marketplace agents, blocks, documentation)
using hybrid search.
Search across all content types (store agents, blocks, documentation) using hybrid search.
Combines semantic (embedding-based) and lexical (text-based) search for best results.
Args:
query: The search query string
content_types: Optional list of content types to filter by (STORE_AGENT, BLOCK, DOCUMENTATION)
page: Page number for pagination (default 1)
page_size: Number of results per page (default 20)
user_id: Optional authenticated user ID (for user-scoped content in future)
Returns:
UnifiedSearchResponse: Paginated list of search results with relevance scores
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
# Convert string content types to enum
content_type_enums: list[prisma.enums.ContentType] | None = None
if content_types:
try:
content_type_enums = [prisma.enums.ContentType(ct) for ct in content_types]
except ValueError as e:
raise fastapi.HTTPException(
status_code=422,
detail=f"Invalid content type. Valid values: STORE_AGENT, BLOCK, DOCUMENTATION. Error: {e}",
)
# Perform unified hybrid search
results, total = await store_hybrid_search.unified_hybrid_search(
query=query,
content_types=content_types,
content_types=content_type_enums,
user_id=user_id,
page=page,
page_size=page_size,
@@ -128,69 +245,22 @@ async def unified_search(
)
##############################################
############### Agent Endpoints ##############
##############################################
@router.get(
"/agents",
summary="List store agents",
tags=["store", "public"],
)
async def get_agents(
featured: bool = Query(
default=False, description="Filter to only show featured agents"
),
creator: str | None = Query(
default=None, description="Filter agents by creator username"
),
category: str | None = Query(default=None, description="Filter agents by category"),
search_query: str | None = Query(
default=None, description="Literal + semantic search on names and descriptions"
),
sorted_by: store_db.StoreAgentsSortOptions | None = Query(
default=None,
description="Property to sort results by. Ignored if search_query is provided.",
),
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
) -> store_model.StoreAgentsResponse:
"""
Get a paginated list of agents from the marketplace,
with optional filtering and sorting.
Used for:
- Home Page Featured Agents
- Home Page Top Agents
- Search Results
- Agent Details - Other Agents By Creator
- Agent Details - Similar Agents
- Creator Details - Agents By Creator
"""
agents = await store_cache._get_cached_store_agents(
featured=featured,
creator=creator,
sorted_by=sorted_by,
search_query=search_query,
category=category,
page=page,
page_size=page_size,
)
return agents
@router.get(
"/agents/{username}/{agent_name}",
summary="Get specific agent",
tags=["store", "public"],
response_model=store_model.StoreAgentDetails,
)
async def get_agent_by_name(
async def get_agent(
username: str,
agent_name: str,
include_changelog: bool = Query(default=False),
) -> store_model.StoreAgentDetails:
"""Get details of a marketplace agent"""
include_changelog: bool = fastapi.Query(default=False),
):
"""
This is only used on the AgentDetails Page.
It returns the store listing agents details.
"""
username = urllib.parse.unquote(username).lower()
# URL decode the agent name since it comes from the URL path
agent_name = urllib.parse.unquote(agent_name).lower()
@@ -200,82 +270,76 @@ async def get_agent_by_name(
return agent
@router.get(
"/graph/{store_listing_version_id}",
summary="Get agent graph",
tags=["store"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
async def get_graph_meta_by_store_listing_version_id(
store_listing_version_id: str,
) -> backend.data.graph.GraphModelWithoutNodes:
"""
Get Agent Graph from Store Listing Version ID.
"""
graph = await store_db.get_available_graph(store_listing_version_id)
return graph
@router.get(
"/agents/{store_listing_version_id}",
summary="Get agent by version",
tags=["store"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreAgentDetails,
)
async def get_store_agent(store_listing_version_id: str):
"""
Get Store Agent Details from Store Listing Version ID.
"""
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
return agent
@router.post(
"/agents/{username}/{agent_name}/review",
summary="Create agent review",
tags=["store"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreReview,
)
async def post_user_review_for_agent(
async def create_review(
username: str,
agent_name: str,
review: store_model.StoreReviewCreate,
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.StoreReview:
"""Post a user review on a marketplace agent listing"""
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Create a review for a store agent.
Args:
username: Creator's username
agent_name: Name/slug of the agent
review: Review details including score and optional comments
user_id: ID of authenticated user creating the review
Returns:
The created review
"""
username = urllib.parse.unquote(username).lower()
agent_name = urllib.parse.unquote(agent_name).lower()
# Create the review
created_review = await store_db.create_store_review(
user_id=user_id,
store_listing_version_id=review.store_listing_version_id,
score=review.score,
comments=review.comments,
)
return created_review
@router.get(
"/listings/versions/{store_listing_version_id}",
summary="Get agent by version",
tags=["store"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def get_agent_by_listing_version(
store_listing_version_id: str,
) -> store_model.StoreAgentDetails:
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
return agent
@router.get(
"/listings/versions/{store_listing_version_id}/graph",
summary="Get agent graph",
tags=["store"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def get_graph_meta_by_store_listing_version_id(
store_listing_version_id: str,
) -> backend.data.graph.GraphModelWithoutNodes:
"""Get outline of graph belonging to a specific marketplace listing version"""
graph = await store_db.get_available_graph(store_listing_version_id)
return graph
@router.get(
"/listings/versions/{store_listing_version_id}/graph/download",
summary="Download agent file",
tags=["store", "public"],
)
async def download_agent_file(
store_listing_version_id: str,
) -> fastapi.responses.FileResponse:
"""Download agent graph file for a specific marketplace listing version"""
graph_data = await store_db.get_agent(store_listing_version_id)
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
# Sending graph as a stream (similar to marketplace v1)
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as tmp_file:
tmp_file.write(backend.util.json.dumps(graph_data))
tmp_file.flush()
return fastapi.responses.FileResponse(
tmp_file.name, filename=file_name, media_type="application/json"
)
##############################################
############# Creator Endpoints #############
##############################################
@@ -285,19 +349,37 @@ async def download_agent_file(
"/creators",
summary="List store creators",
tags=["store", "public"],
response_model=store_model.CreatorsResponse,
)
async def get_creators(
featured: bool = Query(
default=False, description="Filter to only show featured creators"
),
search_query: str | None = Query(
default=None, description="Literal + semantic search on names and descriptions"
),
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
) -> store_model.CreatorsResponse:
"""List or search marketplace creators"""
featured: bool = False,
search_query: str | None = None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
page: int = 1,
page_size: int = 20,
):
"""
This is needed for:
- Home Page Featured Creators
- Search Results Page
---
To support this functionality we need:
- featured: bool - to limit the list to just featured agents
- search_query: str - vector search based on the creators profile description.
- sorted_by: [agent_rating, agent_runs] -
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
creators = await store_cache._get_cached_store_creators(
featured=featured,
search_query=search_query,
@@ -309,12 +391,18 @@ async def get_creators(
@router.get(
"/creators/{username}",
"/creator/{username}",
summary="Get creator details",
tags=["store", "public"],
response_model=store_model.CreatorDetails,
)
async def get_creator(username: str) -> store_model.CreatorDetails:
"""Get details on a marketplace creator"""
async def get_creator(
username: str,
):
"""
Get the details of a creator.
- Creator Details Page
"""
username = urllib.parse.unquote(username).lower()
creator = await store_cache._get_cached_creator_details(username=username)
return creator
@@ -326,17 +414,20 @@ async def get_creator(username: str) -> store_model.CreatorDetails:
@router.get(
"/my-unpublished-agents",
"/myagents",
summary="Get my agents",
tags=["store", "private"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.MyAgentsResponse,
)
async def get_my_unpublished_agents(
user_id: str = Security(autogpt_libs.auth.get_user_id),
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
) -> store_model.MyUnpublishedAgentsResponse:
"""List the authenticated user's unpublished agents"""
async def get_my_agents(
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
page: typing.Annotated[int, fastapi.Query(ge=1)] = 1,
page_size: typing.Annotated[int, fastapi.Query(ge=1)] = 20,
):
"""
Get user's own agents.
"""
agents = await store_db.get_my_agents(user_id, page=page, page_size=page_size)
return agents
@@ -345,17 +436,28 @@ async def get_my_unpublished_agents(
"/submissions/{submission_id}",
summary="Delete store submission",
tags=["store", "private"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=bool,
)
async def delete_submission(
submission_id: str,
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> bool:
"""Delete a marketplace listing submission"""
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Delete a store listing submission.
Args:
user_id (str): ID of the authenticated user
submission_id (str): ID of the submission to be deleted
Returns:
bool: True if the submission was successfully deleted, False otherwise
"""
result = await store_db.delete_store_submission(
user_id=user_id,
submission_id=submission_id,
)
return result
@@ -363,14 +465,37 @@ async def delete_submission(
"/submissions",
summary="List my submissions",
tags=["store", "private"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreSubmissionsResponse,
)
async def get_submissions(
user_id: str = Security(autogpt_libs.auth.get_user_id),
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
) -> store_model.StoreSubmissionsResponse:
"""List the authenticated user's marketplace listing submissions"""
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
page: int = 1,
page_size: int = 20,
):
"""
Get a paginated list of store submissions for the authenticated user.
Args:
user_id (str): ID of the authenticated user
page (int, optional): Page number for pagination. Defaults to 1.
page_size (int, optional): Number of submissions per page. Defaults to 20.
Returns:
StoreListingsResponse: Paginated list of store submissions
Raises:
HTTPException: If page or page_size are less than 1
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
listings = await store_db.get_store_submissions(
user_id=user_id,
page=page,
@@ -383,17 +508,30 @@ async def get_submissions(
"/submissions",
summary="Create store submission",
tags=["store", "private"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreSubmission,
)
async def create_submission(
submission_request: store_model.StoreSubmissionRequest,
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.StoreSubmission:
"""Submit a new marketplace listing for review"""
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Create a new store listing submission.
Args:
submission_request (StoreSubmissionRequest): The submission details
user_id (str): ID of the authenticated user submitting the listing
Returns:
StoreSubmission: The created store submission
Raises:
HTTPException: If there is an error creating the submission
"""
result = await store_db.create_store_submission(
user_id=user_id,
graph_id=submission_request.graph_id,
graph_version=submission_request.graph_version,
agent_id=submission_request.agent_id,
agent_version=submission_request.agent_version,
slug=submission_request.slug,
name=submission_request.name,
video_url=submission_request.video_url,
@@ -406,6 +544,7 @@ async def create_submission(
changes_summary=submission_request.changes_summary or "Initial Submission",
recommended_schedule_cron=submission_request.recommended_schedule_cron,
)
return result
@@ -413,14 +552,28 @@ async def create_submission(
"/submissions/{store_listing_version_id}",
summary="Edit store submission",
tags=["store", "private"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreSubmission,
)
async def edit_submission(
store_listing_version_id: str,
submission_request: store_model.StoreSubmissionEditRequest,
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.StoreSubmission:
"""Update a pending marketplace listing submission"""
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Edit an existing store listing submission.
Args:
store_listing_version_id (str): ID of the store listing version to edit
submission_request (StoreSubmissionRequest): The updated submission details
user_id (str): ID of the authenticated user editing the listing
Returns:
StoreSubmission: The updated store submission
Raises:
HTTPException: If there is an error editing the submission
"""
result = await store_db.edit_store_submission(
user_id=user_id,
store_listing_version_id=store_listing_version_id,
@@ -435,6 +588,7 @@ async def edit_submission(
changes_summary=submission_request.changes_summary,
recommended_schedule_cron=submission_request.recommended_schedule_cron,
)
return result
@@ -442,61 +596,115 @@ async def edit_submission(
"/submissions/media",
summary="Upload submission media",
tags=["store", "private"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
async def upload_submission_media(
file: fastapi.UploadFile,
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> str:
"""Upload media for a marketplace listing submission"""
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Upload media (images/videos) for a store listing submission.
Args:
file (UploadFile): The media file to upload
user_id (str): ID of the authenticated user uploading the media
Returns:
str: URL of the uploaded media file
Raises:
HTTPException: If there is an error uploading the media
"""
media_url = await store_media.upload_media(user_id=user_id, file=file)
return media_url
class ImageURLResponse(BaseModel):
image_url: str
@router.post(
"/submissions/generate_image",
summary="Generate submission image",
tags=["store", "private"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
async def generate_image(
graph_id: str,
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> ImageURLResponse:
agent_id: str,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
) -> fastapi.responses.Response:
"""
Generate an image for a marketplace listing submission based on the properties
of a given graph.
Generate an image for a store listing submission.
Args:
agent_id (str): ID of the agent to generate an image for
user_id (str): ID of the authenticated user
Returns:
JSONResponse: JSON containing the URL of the generated image
"""
graph = await backend.data.graph.get_graph(
graph_id=graph_id, version=None, user_id=user_id
agent = await backend.data.graph.get_graph(
graph_id=agent_id, version=None, user_id=user_id
)
if not graph:
raise NotFoundError(f"Agent graph #{graph_id} not found")
if not agent:
raise fastapi.HTTPException(
status_code=404, detail=f"Agent with ID {agent_id} not found"
)
# Use .jpeg here since we are generating JPEG images
filename = f"agent_{graph_id}.jpeg"
filename = f"agent_{agent_id}.jpeg"
existing_url = await store_media.check_media_exists(user_id, filename)
if existing_url:
logger.info(f"Using existing image for agent graph {graph_id}")
return ImageURLResponse(image_url=existing_url)
logger.info(f"Using existing image for agent {agent_id}")
return fastapi.responses.JSONResponse(content={"image_url": existing_url})
# Generate agent image as JPEG
image = await store_image_gen.generate_agent_image(agent=graph)
image = await store_image_gen.generate_agent_image(agent=agent)
# Create UploadFile with the correct filename and content_type
image_file = fastapi.UploadFile(
file=image,
filename=filename,
)
image_url = await store_media.upload_media(
user_id=user_id, file=image_file, use_file_name=True
)
return ImageURLResponse(image_url=image_url)
return fastapi.responses.JSONResponse(content={"image_url": image_url})
@router.get(
"/download/agents/{store_listing_version_id}",
summary="Download agent file",
tags=["store", "public"],
)
async def download_agent_file(
store_listing_version_id: str = fastapi.Path(
..., description="The ID of the agent to download"
),
) -> fastapi.responses.FileResponse:
"""
Download the agent file by streaming its content.
Args:
store_listing_version_id (str): The ID of the agent to download
Returns:
StreamingResponse: A streaming response containing the agent's graph data.
Raises:
HTTPException: If the agent is not found or an unexpected error occurs.
"""
graph_data = await store_db.get_agent(store_listing_version_id)
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
# Sending graph as a stream (similar to marketplace v1)
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as tmp_file:
tmp_file.write(backend.util.json.dumps(graph_data))
tmp_file.flush()
return fastapi.responses.FileResponse(
tmp_file.name, filename=file_name, media_type="application/json"
)
##############################################

View File

@@ -8,8 +8,6 @@ import pytest
import pytest_mock
from pytest_snapshot.plugin import Snapshot
from backend.api.features.store.db import StoreAgentsSortOptions
from . import model as store_model
from . import routes as store_routes
@@ -198,7 +196,7 @@ def test_get_agents_sorted(
mock_db_call.assert_called_once_with(
featured=False,
creators=None,
sorted_by=StoreAgentsSortOptions.RUNS,
sorted_by="runs",
search_query=None,
category=None,
page=1,
@@ -382,11 +380,9 @@ def test_get_agent_details(
runs=100,
rating=4.5,
versions=["1.0.0", "1.1.0"],
graph_versions=["1", "2"],
graph_id="test-graph-id",
agentGraphVersions=["1", "2"],
agentGraphId="test-graph-id",
last_updated=FIXED_NOW,
active_version_id="test-version-id",
has_approved_version=True,
)
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agent_details")
mock_db_call.return_value = mocked_value
@@ -439,17 +435,15 @@ def test_get_creators_pagination(
) -> None:
mocked_value = store_model.CreatorsResponse(
creators=[
store_model.CreatorDetails(
store_model.Creator(
name=f"Creator {i}",
username=f"creator{i}",
avatar_url=f"avatar{i}.jpg",
description=f"Creator {i} description",
links=[f"user{i}.link.com"],
is_featured=False,
avatar_url=f"avatar{i}.jpg",
num_agents=1,
agent_runs=100,
agent_rating=4.5,
top_categories=["cat1", "cat2", "cat3"],
agent_runs=100,
is_featured=False,
)
for i in range(5)
],
@@ -502,19 +496,19 @@ def test_get_creator_details(
mocked_value = store_model.CreatorDetails(
name="Test User",
username="creator1",
avatar_url="avatar.jpg",
description="Test creator description",
links=["link1.com", "link2.com"],
is_featured=True,
num_agents=5,
agent_runs=1000,
avatar_url="avatar.jpg",
agent_rating=4.8,
agent_runs=1000,
top_categories=["category1", "category2"],
)
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_creator")
mock_db_call = mocker.patch(
"backend.api.features.store.db.get_store_creator_details"
)
mock_db_call.return_value = mocked_value
response = client.get("/creators/creator1")
response = client.get("/creator/creator1")
assert response.status_code == 200
data = store_model.CreatorDetails.model_validate(response.json())
@@ -534,26 +528,19 @@ def test_get_submissions_success(
submissions=[
store_model.StoreSubmission(
listing_id="test-listing-id",
user_id="test-user-id",
slug="test-agent",
listing_version_id="test-version-id",
listing_version=1,
graph_id="test-agent-id",
graph_version=1,
name="Test Agent",
sub_heading="Test agent subheading",
description="Test agent description",
instructions="Click the button!",
categories=["test-category"],
image_urls=["test.jpg"],
video_url="test.mp4",
agent_output_demo_url="demo_video.mp4",
submitted_at=FIXED_NOW,
changes_summary="Initial Submission",
date_submitted=FIXED_NOW,
status=prisma.enums.SubmissionStatus.APPROVED,
run_count=50,
review_count=5,
review_avg_rating=4.2,
runs=50,
rating=4.2,
agent_id="test-agent-id",
agent_version=1,
sub_heading="Test agent subheading",
slug="test-agent",
video_url="test.mp4",
categories=["test-category"],
)
],
pagination=store_model.Pagination(

View File

@@ -11,7 +11,6 @@ import pytest
from backend.util.models import Pagination
from . import cache as store_cache
from .db import StoreAgentsSortOptions
from .model import StoreAgent, StoreAgentsResponse
@@ -216,7 +215,7 @@ class TestCacheDeletion:
await store_cache._get_cached_store_agents(
featured=True,
creator="testuser",
sorted_by=StoreAgentsSortOptions.RATING,
sorted_by="rating",
search_query="AI assistant",
category="productivity",
page=2,
@@ -228,7 +227,7 @@ class TestCacheDeletion:
deleted = store_cache._get_cached_store_agents.cache_delete(
featured=True,
creator="testuser",
sorted_by=StoreAgentsSortOptions.RATING,
sorted_by="rating",
search_query="AI assistant",
category="productivity",
page=2,
@@ -240,7 +239,7 @@ class TestCacheDeletion:
deleted = store_cache._get_cached_store_agents.cache_delete(
featured=True,
creator="testuser",
sorted_by=StoreAgentsSortOptions.RATING,
sorted_by="rating",
search_query="AI assistant",
category="productivity",
page=2,

View File

@@ -449,6 +449,7 @@ async def execute_graph_block(
async def upload_file(
user_id: Annotated[str, Security(get_user_id)],
file: UploadFile = File(...),
provider: str = "gcs",
expiration_hours: int = 24,
) -> UploadFileResponse:
"""
@@ -511,6 +512,7 @@ async def upload_file(
storage_path = await cloud_storage.store_file(
content=content,
filename=file_name,
provider=provider,
expiration_hours=expiration_hours,
user_id=user_id,
)

View File

@@ -515,6 +515,7 @@ async def test_upload_file_success(test_user_id: str):
result = await upload_file(
file=upload_file_mock,
user_id=test_user_id,
provider="gcs",
expiration_hours=24,
)
@@ -532,6 +533,7 @@ async def test_upload_file_success(test_user_id: str):
mock_handler.store_file.assert_called_once_with(
content=file_content,
filename="test.txt",
provider="gcs",
expiration_hours=24,
user_id=test_user_id,
)

View File

@@ -3,29 +3,15 @@ Workspace API routes for managing user file storage.
"""
import logging
import os
import re
from typing import Annotated
from urllib.parse import quote
import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi import Query, UploadFile
from fastapi.responses import Response
from pydantic import BaseModel
from backend.data.workspace import (
WorkspaceFile,
count_workspace_files,
get_or_create_workspace,
get_workspace,
get_workspace_file,
get_workspace_total_size,
soft_delete_workspace_file,
)
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
from backend.data.workspace import WorkspaceFile, get_workspace, get_workspace_file
from backend.util.workspace_storage import get_workspace_storage
@@ -112,25 +98,6 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
raise
class UploadFileResponse(BaseModel):
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
class DeleteFileResponse(BaseModel):
deleted: bool
class StorageUsageResponse(BaseModel):
used_bytes: int
limit_bytes: int
used_percent: float
file_count: int
@router.get(
"/files/{file_id}/download",
summary="Download file by ID",
@@ -153,148 +120,3 @@ async def download_file(
raise fastapi.HTTPException(status_code=404, detail="File not found")
return await _create_file_download_response(file)
@router.delete(
"/files/{file_id}",
summary="Delete a workspace file",
)
async def delete_workspace_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
file_id: str,
) -> DeleteFileResponse:
"""
Soft-delete a workspace file and attempt to remove it from storage.
Used when a user clears a file input in the builder.
"""
workspace = await get_workspace(user_id)
if workspace is None:
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
manager = WorkspaceManager(user_id, workspace.id)
deleted = await manager.delete_file(file_id)
if not deleted:
raise fastapi.HTTPException(status_code=404, detail="File not found")
return DeleteFileResponse(deleted=True)
@router.post(
"/files/upload",
summary="Upload file to workspace",
)
async def upload_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
file: UploadFile,
session_id: str | None = Query(default=None),
) -> UploadFileResponse:
"""
Upload a file to the user's workspace.
Files are stored in session-scoped paths when session_id is provided,
so the agent's session-scoped tools can discover them automatically.
"""
config = Config()
# Sanitize filename — strip any directory components
filename = os.path.basename(file.filename or "upload") or "upload"
# Read file content with early abort on size limit
max_file_bytes = config.max_file_size_mb * 1024 * 1024
chunks: list[bytes] = []
total_size = 0
while chunk := await file.read(64 * 1024): # 64KB chunks
total_size += len(chunk)
if total_size > max_file_bytes:
raise fastapi.HTTPException(
status_code=413,
detail=f"File exceeds maximum size of {config.max_file_size_mb} MB",
)
chunks.append(chunk)
content = b"".join(chunks)
# Get or create workspace
workspace = await get_or_create_workspace(user_id)
# Pre-write storage cap check (soft check — final enforcement is post-write)
storage_limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
current_usage = await get_workspace_total_size(workspace.id)
if storage_limit_bytes and current_usage + len(content) > storage_limit_bytes:
used_percent = (current_usage / storage_limit_bytes) * 100
raise fastapi.HTTPException(
status_code=413,
detail={
"message": "Storage limit exceeded",
"used_bytes": current_usage,
"limit_bytes": storage_limit_bytes,
"used_percent": round(used_percent, 1),
},
)
# Warn at 80% usage
if (
storage_limit_bytes
and (usage_ratio := (current_usage + len(content)) / storage_limit_bytes) >= 0.8
):
logger.warning(
f"User {user_id} workspace storage at {usage_ratio * 100:.1f}% "
f"({current_usage + len(content)} / {storage_limit_bytes} bytes)"
)
# Virus scan
await scan_content_safe(content, filename=filename)
# Write file via WorkspaceManager
manager = WorkspaceManager(user_id, workspace.id, session_id)
try:
workspace_file = await manager.write_file(content, filename)
except ValueError as e:
raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
# Post-write storage check — eliminates TOCTOU race on the quota.
# If a concurrent upload pushed us over the limit, undo this write.
new_total = await get_workspace_total_size(workspace.id)
if storage_limit_bytes and new_total > storage_limit_bytes:
await soft_delete_workspace_file(workspace_file.id, workspace.id)
raise fastapi.HTTPException(
status_code=413,
detail={
"message": "Storage limit exceeded (concurrent upload)",
"used_bytes": new_total,
"limit_bytes": storage_limit_bytes,
},
)
return UploadFileResponse(
file_id=workspace_file.id,
name=workspace_file.name,
path=workspace_file.path,
mime_type=workspace_file.mime_type,
size_bytes=workspace_file.size_bytes,
)
@router.get(
"/storage/usage",
summary="Get workspace storage usage",
)
async def get_storage_usage(
user_id: Annotated[str, fastapi.Security(get_user_id)],
) -> StorageUsageResponse:
"""
Get storage usage information for the user's workspace.
"""
config = Config()
workspace = await get_or_create_workspace(user_id)
used_bytes = await get_workspace_total_size(workspace.id)
file_count = await count_workspace_files(workspace.id)
limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
return StorageUsageResponse(
used_bytes=used_bytes,
limit_bytes=limit_bytes,
used_percent=round((used_bytes / limit_bytes) * 100, 1) if limit_bytes else 0,
file_count=file_count,
)

View File

@@ -1,359 +0,0 @@
"""Tests for workspace file upload and download routes."""
import io
from datetime import datetime, timezone
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from backend.api.features.workspace import routes as workspace_routes
from backend.data.workspace import WorkspaceFile
app = fastapi.FastAPI()
app.include_router(workspace_routes.router)
@app.exception_handler(ValueError)
async def _value_error_handler(
request: fastapi.Request, exc: ValueError
) -> fastapi.responses.JSONResponse:
"""Mirror the production ValueError → 400 mapping from rest_api.py."""
return fastapi.responses.JSONResponse(status_code=400, content={"detail": str(exc)})
client = fastapi.testclient.TestClient(app)
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})()
_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)
MOCK_FILE = WorkspaceFile(
id="file-aaa-bbb",
workspace_id="ws-1",
created_at=_NOW,
updated_at=_NOW,
name="hello.txt",
path="/session/hello.txt",
mime_type="text/plain",
size_bytes=13,
storage_path="local://hello.txt",
)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def _upload(
filename: str = "hello.txt",
content: bytes = b"Hello, world!",
content_type: str = "text/plain",
):
"""Helper to POST a file upload."""
return client.post(
"/files/upload?session_id=sess-1",
files={"file": (filename, io.BytesIO(content), content_type)},
)
# ---- Happy path ----
def test_upload_happy_path(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload()
assert response.status_code == 200
data = response.json()
assert data["file_id"] == "file-aaa-bbb"
assert data["name"] == "hello.txt"
assert data["size_bytes"] == 13
# ---- Per-file size limit ----
def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
"""Files larger than max_file_size_mb should be rejected with 413."""
cfg = mocker.patch("backend.api.features.workspace.routes.Config")
cfg.return_value.max_file_size_mb = 0 # 0 MB → any content is too big
cfg.return_value.max_workspace_storage_mb = 500
response = _upload(content=b"x" * 1024)
assert response.status_code == 413
# ---- Storage quota exceeded ----
def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
# Current usage already at limit
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=500 * 1024 * 1024,
)
response = _upload()
assert response.status_code == 413
assert "Storage limit exceeded" in response.text
# ---- Post-write quota race (B2) ----
def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
"""If a concurrent upload tips the total over the limit after write,
the file should be soft-deleted and 413 returned."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
# Pre-write check passes (under limit), but post-write check fails
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
side_effect=[0, 600 * 1024 * 1024], # first call OK, second over limit
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
mock_delete = mocker.patch(
"backend.api.features.workspace.routes.soft_delete_workspace_file",
return_value=None,
)
response = _upload()
assert response.status_code == 413
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-1")
# ---- Any extension accepted (no allowlist) ----
def test_upload_any_extension(mocker: pytest_mock.MockFixture):
"""Any file extension should be accepted — ClamAV is the security layer."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload(filename="data.xyz", content=b"arbitrary")
assert response.status_code == 200
# ---- Virus scan rejection ----
def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
"""Files flagged by ClamAV should be rejected and never written to storage."""
from backend.api.features.store.exceptions import VirusDetectedError
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
side_effect=VirusDetectedError("Eicar-Test-Signature"),
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
assert response.status_code == 400
assert "Virus detected" in response.text
mock_manager.write_file.assert_not_called()
# ---- No file extension ----
def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
"""Files without an extension should be accepted and stored as-is."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload(
filename="Makefile",
content=b"all:\n\techo hello",
content_type="application/octet-stream",
)
assert response.status_code == 200
mock_manager.write_file.assert_called_once()
assert mock_manager.write_file.call_args[0][1] == "Makefile"
# ---- Filename sanitization (SF5) ----
def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
"""Path-traversal filenames should be reduced to their basename."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
# Filename with traversal
_upload(filename="../../etc/passwd.txt")
# write_file should have been called with just the basename
mock_manager.write_file.assert_called_once()
call_args = mock_manager.write_file.call_args
assert call_args[0][1] == "passwd.txt"
# ---- Download ----
def test_download_file_not_found(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_file",
return_value=None,
)
response = client.get("/files/some-file-id/download")
assert response.status_code == 404
# ---- Delete ----
def test_delete_file_success(mocker: pytest_mock.MockFixture):
"""Deleting an existing file should return {"deleted": true}."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=MOCK_WORKSPACE,
)
mock_manager = mocker.MagicMock()
mock_manager.delete_file = mocker.AsyncMock(return_value=True)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = client.delete("/files/file-aaa-bbb")
assert response.status_code == 200
assert response.json() == {"deleted": True}
mock_manager.delete_file.assert_called_once_with("file-aaa-bbb")
def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
"""Deleting a non-existent file should return 404."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=MOCK_WORKSPACE,
)
mock_manager = mocker.MagicMock()
mock_manager.delete_file = mocker.AsyncMock(return_value=False)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = client.delete("/files/nonexistent-id")
assert response.status_code == 404
assert "File not found" in response.text
def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
"""Deleting when user has no workspace should return 404."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=None,
)
response = client.delete("/files/file-aaa-bbb")
assert response.status_code == 404
assert "Workspace not found" in response.text

View File

@@ -37,10 +37,8 @@ import backend.api.features.workspace.routes as workspace_routes
import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.llm_registry
import backend.data.user
import backend.integrations.webhooks.utils
import backend.server.v2.llm
import backend.util.service
import backend.util.settings
from backend.api.features.library.exceptions import (
@@ -57,7 +55,6 @@ from backend.util.exceptions import (
MissingConfigError,
NotAuthorizedError,
NotFoundError,
PreconditionFailed,
)
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
from backend.util.service import UnhealthyServiceError
@@ -119,30 +116,11 @@ async def lifespan_context(app: fastapi.FastAPI):
AutoRegistry.patch_integrations()
# Refresh LLM registry before initializing blocks so blocks can use registry data
# Note: Graceful fallback for now since no blocks consume registry yet (comes in PR #5)
# When block integration lands, this should fail hard or skip block initialization
try:
await backend.data.llm_registry.refresh_llm_registry()
logger.info("LLM registry refreshed successfully at startup")
except Exception as e:
logger.warning(
f"Failed to refresh LLM registry at startup: {e}. "
"Blocks will initialize with empty registry."
)
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
try:
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
except Exception as e:
logger.warning(
f"Failed to migrate LLM models at startup: {e}. "
"This is expected in test environments without AgentNode table."
)
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
with launch_darkly_context():
@@ -297,7 +275,6 @@ app.add_exception_handler(RequestValidationError, validation_error_handler)
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
app.add_exception_handler(ValueError, handle_internal_http_error(400))
app.add_exception_handler(PreconditionFailed, handle_internal_http_error(428))
app.add_exception_handler(Exception, handle_internal_http_error(500))
app.include_router(backend.api.features.v1.v1_router, tags=["v1"], prefix="/api")
@@ -369,11 +346,6 @@ app.include_router(
tags=["oauth"],
prefix="/api/oauth",
)
app.include_router(
backend.server.v2.llm.router,
tags=["v2", "llm"],
prefix="/api",
)
app.mount("/external-api", external_api)

View File

@@ -418,8 +418,6 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
_optimized_description: ClassVar[str | None] = None
def __init__(
self,
id: str = "",
@@ -472,8 +470,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.block_type = block_type
self.webhook_config = webhook_config
self.is_sensitive_action = is_sensitive_action
# Read from ClassVar set by initialize_blocks()
self.optimized_description: str | None = type(self)._optimized_description
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
if self.webhook_config:

View File

@@ -142,7 +142,7 @@ class BaseE2BExecutorMixin:
start_timestamp = ts_result.stdout.strip() if ts_result.stdout else None
# Execute the code
execution = await sandbox.run_code( # type: ignore[attr-defined]
execution = await sandbox.run_code(
code,
language=language.value,
on_error=lambda e: sandbox.kill(), # Kill the sandbox on error

View File

@@ -17,7 +17,7 @@ from backend.blocks.jina._auth import (
from backend.blocks.search import GetRequest
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host
from backend.util.request import HTTPClientError, HTTPServerError, validate_url
class SearchTheWebBlock(Block, GetRequest):
@@ -112,7 +112,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
) -> BlockOutput:
if input_data.raw_content:
try:
parsed_url, _, _ = await validate_url_host(input_data.url)
parsed_url, _, _ = await validate_url(input_data.url, [])
url = parsed_url.geturl()
except ValueError as e:
yield "error", f"Invalid URL: {e}"

View File

@@ -31,14 +31,10 @@ from backend.data.model import (
)
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.clients import OPENROUTER_BASE_URL
from backend.util.logging import TruncatedLogger
from backend.util.prompt import compress_context, estimate_token_count
from backend.util.request import validate_url_host
from backend.util.settings import Settings
from backend.util.text import TextFormatter
settings = Settings()
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
fmt = TextFormatter(autoescape=False)
@@ -120,7 +116,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
CLAUDE_4_6_OPUS = "claude-opus-4-6"
CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# AI/ML API models
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
@@ -279,9 +274,6 @@ MODEL_METADATA = {
LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
"anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
), # claude-opus-4-6
LlmModel.CLAUDE_4_6_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude Sonnet 4.6", "Anthropic", "Anthropic", 3
), # claude-sonnet-4-6
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
), # claude-opus-4-5-20251101
@@ -808,11 +800,6 @@ async def llm_call(
if tools:
raise ValueError("Ollama does not support tools.")
# Validate user-provided Ollama host to prevent SSRF etc.
await validate_url_host(
ollama_host, trusted_hostnames=[settings.config.ollama_host]
)
client = ollama.AsyncClient(host=ollama_host)
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
@@ -834,7 +821,7 @@ async def llm_call(
elif provider == "open_router":
tools_param = tools if tools else openai.NOT_GIVEN
client = openai.AsyncOpenAI(
base_url=OPENROUTER_BASE_URL,
base_url="https://openrouter.ai/api/v1",
api_key=credentials.api_key.get_secret_value(),
)

View File

@@ -6,6 +6,7 @@ and execute them. Works like AgentExecutorBlock — the user selects a tool from
dropdown and the input/output schema adapts dynamically.
"""
import json
import logging
from typing import Any, Literal
@@ -19,11 +20,6 @@ from backend.blocks._base import (
BlockType,
)
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.blocks.mcp.helpers import (
auto_lookup_mcp_credential,
normalize_mcp_url,
parse_mcp_content,
)
from backend.data.block import BlockInput, BlockOutput
from backend.data.model import (
CredentialsField,
@@ -183,7 +179,31 @@ class MCPToolBlock(Block):
f"{error_text or 'Unknown error'}"
)
return parse_mcp_content(result.content)
# Extract text content from the result
output_parts = []
for item in result.content:
if item.get("type") == "text":
text = item.get("text", "")
# Try to parse as JSON for structured output
try:
output_parts.append(json.loads(text))
except (json.JSONDecodeError, ValueError):
output_parts.append(text)
elif item.get("type") == "image":
output_parts.append(
{
"type": "image",
"data": item.get("data"),
"mimeType": item.get("mimeType"),
}
)
elif item.get("type") == "resource":
output_parts.append(item.get("resource", {}))
# If single result, unwrap
if len(output_parts) == 1:
return output_parts[0]
return output_parts if output_parts else None
@staticmethod
async def _auto_lookup_credential(
@@ -191,10 +211,37 @@ class MCPToolBlock(Block):
) -> "OAuth2Credentials | None":
"""Auto-lookup stored MCP credential for a server URL.
Delegates to :func:`~backend.blocks.mcp.helpers.auto_lookup_mcp_credential`.
The caller should pass a normalized URL.
This is a fallback for nodes that don't have ``credentials`` explicitly
set (e.g. nodes created before the credential field was wired up).
"""
return await auto_lookup_mcp_credential(user_id, server_url)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
try:
mgr = IntegrationCredentialsManager()
mcp_creds = await mgr.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
best: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and (cred.metadata or {}).get("mcp_server_url") == server_url
):
if best is None or (
(cred.access_token_expires_at or 0)
> (best.access_token_expires_at or 0)
):
best = cred
if best:
best = await mgr.refresh_if_needed(user_id, best)
logger.info(
"Auto-resolved MCP credential %s for %s", best.id, server_url
)
return best
except Exception:
logger.warning("Auto-lookup MCP credential failed", exc_info=True)
return None
async def run(
self,
@@ -231,7 +278,7 @@ class MCPToolBlock(Block):
# the stored MCP credential for this server URL.
if credentials is None:
credentials = await self._auto_lookup_credential(
user_id, normalize_mcp_url(input_data.server_url)
user_id, input_data.server_url
)
auth_token = (

View File

@@ -55,9 +55,7 @@ class MCPClient:
server_url: str,
auth_token: str | None = None,
):
from backend.blocks.mcp.helpers import normalize_mcp_url
self.server_url = normalize_mcp_url(server_url)
self.server_url = server_url.rstrip("/")
self.auth_token = auth_token
self._request_id = 0
self._session_id: str | None = None

View File

@@ -1,117 +0,0 @@
"""Shared MCP helpers used by blocks, copilot tools, and API routes."""
from __future__ import annotations
import json
import logging
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
if TYPE_CHECKING:
from backend.data.model import OAuth2Credentials
logger = logging.getLogger(__name__)
def normalize_mcp_url(url: str) -> str:
"""Normalize an MCP server URL for consistent credential matching.
Strips leading/trailing whitespace and a single trailing slash so that
``https://mcp.example.com/`` and ``https://mcp.example.com`` resolve to
the same stored credential.
"""
return url.strip().rstrip("/")
def server_host(server_url: str) -> str:
"""Extract the hostname from a server URL for display purposes.
Uses ``parsed.hostname`` (never ``netloc``) to strip any embedded
username/password before surfacing the value in UI messages.
"""
try:
parsed = urlparse(server_url)
return parsed.hostname or server_url
except Exception:
return server_url
def parse_mcp_content(content: list[dict[str, Any]]) -> Any:
"""Parse MCP tool response content into a plain Python value.
- text items: parsed as JSON when possible, kept as str otherwise
- image items: kept as ``{type, data, mimeType}`` dict for frontend rendering
- resource items: unwrapped to their resource payload dict
Single-item responses are unwrapped from the list; multiple items are
returned as a list; empty content returns ``None``.
"""
output_parts: list[Any] = []
for item in content:
item_type = item.get("type")
if item_type == "text":
text = item.get("text", "")
try:
output_parts.append(json.loads(text))
except (json.JSONDecodeError, ValueError):
output_parts.append(text)
elif item_type == "image":
output_parts.append(
{
"type": "image",
"data": item.get("data"),
"mimeType": item.get("mimeType"),
}
)
elif item_type == "resource":
output_parts.append(item.get("resource", {}))
if len(output_parts) == 1:
return output_parts[0]
return output_parts or None
async def auto_lookup_mcp_credential(
user_id: str, server_url: str
) -> OAuth2Credentials | None:
"""Look up the best stored MCP credential for *server_url*.
The caller should pass a **normalized** URL (via :func:`normalize_mcp_url`)
so the comparison with ``mcp_server_url`` in credential metadata matches.
Returns the credential with the latest ``access_token_expires_at``, refreshed
if needed, or ``None`` when no match is found.
"""
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
try:
mgr = IntegrationCredentialsManager()
mcp_creds = await mgr.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
# Collect all matching credentials and pick the best one.
# Primary sort: latest access_token_expires_at (tokens with expiry
# are preferred over non-expiring ones). Secondary sort: last in
# iteration order, which corresponds to the most recently created
# row — this acts as a tiebreaker when multiple bearer tokens have
# no expiry (e.g. after a failed old-credential cleanup).
best: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and (cred.metadata or {}).get("mcp_server_url") == server_url
):
if best is None or (
(cred.access_token_expires_at or 0)
>= (best.access_token_expires_at or 0)
):
best = cred
if best:
best = await mgr.refresh_if_needed(user_id, best)
logger.info("Auto-resolved MCP credential %s for %s", best.id, server_url)
return best
except Exception:
logger.warning("Auto-lookup MCP credential failed", exc_info=True)
return None

View File

@@ -1,98 +0,0 @@
"""Unit tests for the shared MCP helpers."""
from backend.blocks.mcp.helpers import normalize_mcp_url, parse_mcp_content, server_host
# ---------------------------------------------------------------------------
# normalize_mcp_url
# ---------------------------------------------------------------------------
def test_normalize_trailing_slash():
assert normalize_mcp_url("https://mcp.example.com/") == "https://mcp.example.com"
def test_normalize_whitespace():
assert normalize_mcp_url(" https://mcp.example.com ") == "https://mcp.example.com"
def test_normalize_both():
assert (
normalize_mcp_url(" https://mcp.example.com/ ") == "https://mcp.example.com"
)
def test_normalize_noop():
assert normalize_mcp_url("https://mcp.example.com") == "https://mcp.example.com"
def test_normalize_path_with_trailing_slash():
assert (
normalize_mcp_url("https://mcp.example.com/path/")
== "https://mcp.example.com/path"
)
# ---------------------------------------------------------------------------
# server_host
# ---------------------------------------------------------------------------
def test_server_host_standard_url():
assert server_host("https://mcp.example.com/mcp") == "mcp.example.com"
def test_server_host_strips_credentials():
"""hostname must not expose user:pass."""
assert server_host("https://user:secret@mcp.example.com/mcp") == "mcp.example.com"
def test_server_host_with_port():
"""Port should not appear in hostname (hostname strips it)."""
assert server_host("https://mcp.example.com:8080/mcp") == "mcp.example.com"
def test_server_host_fallback():
"""Falls back to the raw string for un-parseable URLs."""
assert server_host("not-a-url") == "not-a-url"
# ---------------------------------------------------------------------------
# parse_mcp_content
# ---------------------------------------------------------------------------
def test_parse_text_plain():
assert parse_mcp_content([{"type": "text", "text": "hello world"}]) == "hello world"
def test_parse_text_json():
content = [{"type": "text", "text": '{"status": "ok", "count": 42}'}]
assert parse_mcp_content(content) == {"status": "ok", "count": 42}
def test_parse_image():
content = [{"type": "image", "data": "abc123==", "mimeType": "image/png"}]
assert parse_mcp_content(content) == {
"type": "image",
"data": "abc123==",
"mimeType": "image/png",
}
def test_parse_resource():
content = [
{"type": "resource", "resource": {"uri": "file:///tmp/out.txt", "text": "hi"}}
]
assert parse_mcp_content(content) == {"uri": "file:///tmp/out.txt", "text": "hi"}
def test_parse_multi_item():
content = [
{"type": "text", "text": "first"},
{"type": "text", "text": "second"},
]
assert parse_mcp_content(content) == ["first", "second"]
def test_parse_empty():
assert parse_mcp_content([]) is None

View File

@@ -21,7 +21,6 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.clients import OPENROUTER_BASE_URL
from backend.util.logging import TruncatedLogger
logger = TruncatedLogger(logging.getLogger(__name__), "[Perplexity-Block]")
@@ -137,7 +136,7 @@ class PerplexityBlock(Block):
) -> dict[str, Any]:
"""Call Perplexity via OpenRouter and extract annotations."""
client = openai.AsyncOpenAI(
base_url=OPENROUTER_BASE_URL,
base_url="https://openrouter.ai/api/v1",
api_key=credentials.api_key.get_secret_value(),
)

View File

@@ -83,8 +83,7 @@ class StagehandRecommendedLlmModel(str, Enum):
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
# Anthropic
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929" # Keep for backwards compat
CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
@property
def provider_name(self) -> str:
@@ -138,7 +137,7 @@ class StagehandObserveBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -228,7 +227,7 @@ class StagehandActBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -325,7 +324,7 @@ class StagehandExtractBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()

View File

@@ -1,8 +1,8 @@
import logging
from typing import Literal
from pydantic import BaseModel
from backend.api.features.store.db import StoreAgentsSortOptions
from backend.blocks._base import (
Block,
BlockCategory,
@@ -176,8 +176,8 @@ class SearchStoreAgentsBlock(Block):
category: str | None = SchemaField(
description="Filter by category", default=None
)
sort_by: StoreAgentsSortOptions = SchemaField(
description="How to sort the results", default=StoreAgentsSortOptions.RATING
sort_by: Literal["rating", "runs", "name", "updated_at"] = SchemaField(
description="How to sort the results", default="rating"
)
limit: int = SchemaField(
description="Maximum number of results to return", default=10, ge=1, le=100
@@ -278,7 +278,7 @@ class SearchStoreAgentsBlock(Block):
self,
query: str | None = None,
category: str | None = None,
sort_by: StoreAgentsSortOptions = StoreAgentsSortOptions.RATING,
sort_by: Literal["rating", "runs", "name", "updated_at"] = "rating",
limit: int = 10,
) -> SearchAgentsResponse:
"""

View File

@@ -2,7 +2,6 @@ from unittest.mock import MagicMock
import pytest
from backend.api.features.store.db import StoreAgentsSortOptions
from backend.blocks.system.library_operations import (
AddToLibraryFromStoreBlock,
LibraryAgent,
@@ -122,10 +121,7 @@ async def test_search_store_agents_block(mocker):
)
input_data = block.Input(
query="test",
category="productivity",
sort_by=StoreAgentsSortOptions.RATING, # type: ignore[reportArgumentType]
limit=10,
query="test", category="productivity", sort_by="rating", limit=10
)
outputs = {}

View File

@@ -1,3 +0,0 @@
from .service import stream_chat_completion_baseline
__all__ = ["stream_chat_completion_baseline"]

View File

@@ -1,424 +0,0 @@
"""Baseline LLM fallback — OpenAI-compatible streaming with tool calling.
Used when ``CHAT_USE_CLAUDE_AGENT_SDK=false``, e.g. as a fallback when the
Claude Agent SDK / Anthropic API is unavailable. Routes through any
OpenAI-compatible provider (OpenRouter by default) and reuses the same
shared tool registry as the SDK path.
"""
import asyncio
import logging
import uuid
from collections.abc import AsyncGenerator
from typing import Any
import orjson
from langfuse import propagate_attributes
from backend.copilot.model import (
ChatMessage,
ChatSession,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from backend.copilot.service import (
_build_system_prompt,
_generate_session_title,
client,
config,
)
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
from backend.util.prompt import compress_context
logger = logging.getLogger(__name__)
# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()
# Maximum number of tool-call rounds before forcing a text response.
_MAX_TOOL_ROUNDS = 30
async def _update_title_async(
session_id: str, message: str, user_id: str | None
) -> None:
"""Generate and persist a session title in the background."""
try:
title = await _generate_session_title(message, user_id, session_id)
if title and user_id:
await update_session_title(session_id, user_id, title, only_if_empty=True)
except Exception as e:
logger.warning("[Baseline] Failed to update session title: %s", e)
async def _compress_session_messages(
messages: list[ChatMessage],
) -> list[ChatMessage]:
"""Compress session messages if they exceed the model's token limit.
Uses the shared compress_context() utility which supports LLM-based
summarization of older messages while keeping recent ones intact,
with progressive truncation and middle-out deletion as fallbacks.
"""
messages_dict = []
for msg in messages:
msg_dict: dict[str, Any] = {"role": msg.role}
if msg.content:
msg_dict["content"] = msg.content
messages_dict.append(msg_dict)
try:
result = await compress_context(
messages=messages_dict,
model=config.model,
client=client,
)
except Exception as e:
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
result = await compress_context(
messages=messages_dict,
model=config.model,
client=None,
)
if result.was_compacted:
logger.info(
"[Baseline] Context compacted: %d -> %d tokens "
"(%d summarized, %d dropped)",
result.original_token_count,
result.token_count,
result.messages_summarized,
result.messages_dropped,
)
return [
ChatMessage(role=m["role"], content=m.get("content"))
for m in result.messages
]
return messages
async def stream_chat_completion_baseline(
session_id: str,
message: str | None = None,
is_user_message: bool = True,
user_id: str | None = None,
session: ChatSession | None = None,
**_kwargs: Any,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Baseline LLM with tool calling via OpenAI-compatible API.
Designed as a fallback when the Claude Agent SDK is unavailable.
Uses the same tool registry as the SDK path but routes through any
OpenAI-compatible provider (e.g. OpenRouter).
Flow: stream response -> if tool_calls, execute them -> feed results back -> repeat.
"""
if session is None:
session = await get_chat_session(session_id, user_id)
if not session:
raise NotFoundError(
f"Session {session_id} not found. Please create a new session first."
)
# Append user message
new_role = "user" if is_user_message else "assistant"
if message and (
len(session.messages) == 0
or not (
session.messages[-1].role == new_role
and session.messages[-1].content == message
)
):
session.messages.append(ChatMessage(role=new_role, content=message))
if is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(message),
)
session = await upsert_chat_session(session)
# Generate title for new sessions
if is_user_message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
if len(user_messages) == 1:
first_message = user_messages[0].content or message or ""
if first_message:
task = asyncio.create_task(
_update_title_async(session_id, first_message, user_id)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
message_id = str(uuid.uuid4())
# Build system prompt only on the first turn to avoid mid-conversation
# changes from concurrent chats updating business understanding.
is_first_turn = len(session.messages) <= 1
if is_first_turn:
base_system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=False
)
else:
base_system_prompt, _ = await _build_system_prompt(
user_id=None, has_conversation_history=True
)
# Append tool documentation and technical notes
system_prompt = base_system_prompt + get_baseline_supplement()
# Compress context if approaching the model's token limit
messages_for_context = await _compress_session_messages(session.messages)
# Build OpenAI message list from session history
openai_messages: list[dict[str, Any]] = [
{"role": "system", "content": system_prompt}
]
for msg in messages_for_context:
if msg.role in ("user", "assistant") and msg.content:
openai_messages.append({"role": msg.role, "content": msg.content})
tools = get_available_tools()
yield StreamStart(messageId=message_id, sessionId=session_id)
# Propagate user/session context to Langfuse so all LLM calls within
# this request are grouped under a single trace with proper attribution.
_trace_ctx: Any = None
try:
_trace_ctx = propagate_attributes(
user_id=user_id,
session_id=session_id,
trace_name="copilot-baseline",
tags=["baseline"],
)
_trace_ctx.__enter__()
except Exception:
logger.warning("[Baseline] Langfuse trace context setup failed")
assistant_text = ""
text_block_id = str(uuid.uuid4())
text_started = False
step_open = False
try:
for _round in range(_MAX_TOOL_ROUNDS):
# Open a new step for each LLM round
yield StreamStartStep()
step_open = True
# Stream a response from the model
create_kwargs: dict[str, Any] = dict(
model=config.model,
messages=openai_messages,
stream=True,
)
if tools:
create_kwargs["tools"] = tools
response = await client.chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
# Accumulate streamed response (text + tool calls)
round_text = ""
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
delta = chunk.choices[0].delta if chunk.choices else None
if not delta:
continue
# Text content
if delta.content:
if not text_started:
yield StreamTextStart(id=text_block_id)
text_started = True
round_text += delta.content
yield StreamTextDelta(id=text_block_id, delta=delta.content)
# Tool call fragments (streamed incrementally)
if delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index
if idx not in tool_calls_by_index:
tool_calls_by_index[idx] = {
"id": "",
"name": "",
"arguments": "",
}
entry = tool_calls_by_index[idx]
if tc.id:
entry["id"] = tc.id
if tc.function and tc.function.name:
entry["name"] = tc.function.name
if tc.function and tc.function.arguments:
entry["arguments"] += tc.function.arguments
# Close text block if we had one this round
if text_started:
yield StreamTextEnd(id=text_block_id)
text_started = False
text_block_id = str(uuid.uuid4())
# Accumulate text for session persistence
assistant_text += round_text
# No tool calls -> model is done
if not tool_calls_by_index:
yield StreamFinishStep()
step_open = False
break
# Close step before tool execution
yield StreamFinishStep()
step_open = False
# Append the assistant message with tool_calls to context.
assistant_msg: dict[str, Any] = {"role": "assistant"}
if round_text:
assistant_msg["content"] = round_text
assistant_msg["tool_calls"] = [
{
"id": tc["id"],
"type": "function",
"function": {
"name": tc["name"],
"arguments": tc["arguments"] or "{}",
},
}
for tc in tool_calls_by_index.values()
]
openai_messages.append(assistant_msg)
# Execute each tool call and stream events
for tc in tool_calls_by_index.values():
tool_call_id = tc["id"]
tool_name = tc["name"]
raw_args = tc["arguments"] or "{}"
try:
tool_args = orjson.loads(raw_args)
except orjson.JSONDecodeError as parse_err:
parse_error = (
f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
)
logger.warning("[Baseline] %s", parse_error)
yield StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=parse_error,
success=False,
)
openai_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": parse_error,
}
)
continue
yield StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
yield StreamToolInputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
input=tool_args,
)
# Execute via shared tool registry
try:
result: StreamToolOutputAvailable = await execute_tool(
tool_name=tool_name,
parameters=tool_args,
user_id=user_id,
session=session,
tool_call_id=tool_call_id,
)
yield result
tool_output = (
result.output
if isinstance(result.output, str)
else str(result.output)
)
except Exception as e:
error_output = f"Tool execution error: {e}"
logger.error(
"[Baseline] Tool %s failed: %s",
tool_name,
error_output,
exc_info=True,
)
yield StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=error_output,
success=False,
)
tool_output = error_output
# Append tool result to context for next round
openai_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": tool_output,
}
)
else:
# for-loop exhausted without break -> tool-round limit hit
limit_msg = (
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
"without a final response."
)
logger.error("[Baseline] %s", limit_msg)
yield StreamError(
errorText=limit_msg,
code="baseline_tool_round_limit",
)
except Exception as e:
error_msg = str(e) or type(e).__name__
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
# Close any open text/step before emitting error
if text_started:
yield StreamTextEnd(id=text_block_id)
if step_open:
yield StreamFinishStep()
yield StreamError(errorText=error_msg, code="baseline_error")
# Still persist whatever we got
finally:
# Close Langfuse trace context
if _trace_ctx is not None:
try:
_trace_ctx.__exit__(None, None, None)
except Exception:
logger.warning("[Baseline] Langfuse trace context teardown failed")
# Persist assistant response
if assistant_text:
session.messages.append(
ChatMessage(role="assistant", content=assistant_text)
)
try:
await upsert_chat_session(session)
except Exception as persist_err:
logger.error("[Baseline] Failed to persist session: %s", persist_err)
yield StreamFinish()

View File

@@ -1,99 +0,0 @@
import logging
from os import getenv
import pytest
from backend.copilot.baseline import stream_chat_completion_baseline
from backend.copilot.model import (
create_chat_session,
get_chat_session,
upsert_chat_session,
)
from backend.copilot.response_model import (
StreamError,
StreamFinish,
StreamStart,
StreamTextDelta,
)
logger = logging.getLogger(__name__)
@pytest.mark.asyncio(loop_scope="session")
async def test_baseline_multi_turn(setup_test_user, test_user_id):
"""Test that the baseline LLM path streams responses and maintains history.
Turn 1: Send a message with a unique keyword.
Turn 2: Ask the model to recall the keyword — proving conversation history
is correctly passed to the single-call LLM.
"""
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
# --- Turn 1: send a message with a unique keyword ---
keyword = "QUASAR99"
turn1_msg = (
f"Please remember this special keyword: {keyword}. "
"Just confirm you've noted it, keep your response brief."
)
turn1_text = ""
turn1_errors: list[str] = []
got_start = False
got_finish = False
async for chunk in stream_chat_completion_baseline(
session.session_id,
turn1_msg,
user_id=test_user_id,
):
if isinstance(chunk, StreamStart):
got_start = True
elif isinstance(chunk, StreamTextDelta):
turn1_text += chunk.delta
elif isinstance(chunk, StreamError):
turn1_errors.append(chunk.errorText)
elif isinstance(chunk, StreamFinish):
got_finish = True
assert got_start, "Turn 1 did not yield StreamStart"
assert got_finish, "Turn 1 did not yield StreamFinish"
assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
assert turn1_text, "Turn 1 produced no text"
logger.info(f"Turn 1 response: {turn1_text[:100]}")
# Reload session for turn 2
session = await get_chat_session(session.session_id, test_user_id)
assert session, "Session not found after turn 1"
# Verify messages were persisted (user + assistant)
assert (
len(session.messages) >= 2
), f"Expected at least 2 messages after turn 1, got {len(session.messages)}"
# --- Turn 2: ask model to recall the keyword ---
turn2_msg = "What was the special keyword I asked you to remember?"
turn2_text = ""
turn2_errors: list[str] = []
async for chunk in stream_chat_completion_baseline(
session.session_id,
turn2_msg,
user_id=test_user_id,
session=session,
):
if isinstance(chunk, StreamTextDelta):
turn2_text += chunk.delta
elif isinstance(chunk, StreamError):
turn2_errors.append(chunk.errorText)
assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
assert turn2_text, "Turn 2 produced no text"
assert keyword in turn2_text, (
f"Model did not recall keyword '{keyword}' in turn 2. "
f"Response: {turn2_text[:200]}"
)
logger.info(f"Turn 2 recalled keyword successfully: {turn2_text[:100]}")

View File

@@ -1,13 +1,10 @@
"""Configuration management for chat system."""
import os
from typing import Literal
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
from backend.util.clients import OPENROUTER_BASE_URL
class ChatConfig(BaseSettings):
"""Configuration for the chat system."""
@@ -22,13 +19,18 @@ class ChatConfig(BaseSettings):
)
api_key: str | None = Field(default=None, description="OpenAI API key")
base_url: str | None = Field(
default=OPENROUTER_BASE_URL,
default="https://openrouter.ai/api/v1",
description="Base URL for API (e.g., for OpenRouter)",
)
# Session TTL Configuration - 12 hours
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
# Streaming Configuration
max_retries: int = Field(
default=3,
description="Max retries for fallback path (SDK handles retries internally)",
)
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
max_agent_schedules: int = Field(
default=30, description="Maximum number of agent schedules"
@@ -65,15 +67,11 @@ class ChatConfig(BaseSettings):
default="CoPilot Prompt",
description="Name of the prompt in Langfuse to fetch",
)
langfuse_prompt_cache_ttl: int = Field(
default=300,
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
)
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,
description="Use Claude Agent SDK (True) or OpenAI-compatible LLM baseline (False)",
description="Use Claude Agent SDK for chat completions",
)
claude_agent_model: str | None = Field(
default=None,
@@ -94,81 +92,18 @@ class ChatConfig(BaseSettings):
description="Use --resume for multi-turn conversations instead of "
"history compression. Falls back to compression when unavailable.",
)
use_claude_code_subscription: bool = Field(
default=False,
description="For personal/dev use: use Claude Code CLI subscription auth instead of API keys. Requires `claude login` on the host. Only works with SDK mode.",
)
# E2B Sandbox Configuration
use_e2b_sandbox: bool = Field(
# Extended thinking configuration for Claude models
thinking_enabled: bool = Field(
default=True,
description="Use E2B cloud sandboxes for persistent bash/python execution. "
"When enabled, bash_exec routes commands to E2B and SDK file tools "
"operate directly on the sandbox via E2B's filesystem API.",
description="Enable adaptive thinking for Claude models via OpenRouter",
)
e2b_api_key: str | None = Field(
default=None,
description="E2B API key. Falls back to E2B_API_KEY environment variable.",
)
e2b_sandbox_template: str = Field(
default="base",
description="E2B sandbox template to use for copilot sessions.",
)
e2b_sandbox_timeout: int = Field(
default=10800, # 3 hours — wall-clock timeout, not idle; explicit pause is primary
description="E2B sandbox running-time timeout (seconds). "
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
"mechanism; this is the safety net.",
)
e2b_sandbox_on_timeout: Literal["kill", "pause"] = Field(
default="pause",
description="E2B lifecycle action on timeout: 'pause' (default, free) or 'kill'.",
)
@property
def e2b_active(self) -> bool:
"""True when E2B is enabled and the API key is present.
Single source of truth for "should we use E2B right now?".
Prefer this over combining ``use_e2b_sandbox`` and ``e2b_api_key``
separately at call sites.
"""
return self.use_e2b_sandbox and bool(self.e2b_api_key)
@property
def active_e2b_api_key(self) -> str | None:
"""Return the E2B API key when E2B is enabled and configured, else None.
Combines the ``use_e2b_sandbox`` flag check and key presence into one.
Use in callers::
if api_key := config.active_e2b_api_key:
# E2B is active; api_key is narrowed to str
"""
return self.e2b_api_key if self.e2b_active else None
@field_validator("use_e2b_sandbox", mode="before")
@classmethod
def get_use_e2b_sandbox(cls, v):
"""Get use_e2b_sandbox from environment if not provided."""
env_val = os.getenv("CHAT_USE_E2B_SANDBOX", "").lower()
if env_val:
return env_val in ("true", "1", "yes", "on")
return True if v is None else v
@field_validator("e2b_api_key", mode="before")
@classmethod
def get_e2b_api_key(cls, v):
"""Get E2B API key from environment if not provided."""
if not v:
v = os.getenv("CHAT_E2B_API_KEY") or os.getenv("E2B_API_KEY")
return v
@field_validator("api_key", mode="before")
@classmethod
def get_api_key(cls, v):
"""Get API key from environment if not provided."""
if not v:
if v is None:
# Try to get from environment variables
# First check for CHAT_API_KEY (Pydantic prefix)
v = os.getenv("CHAT_API_KEY")
@@ -178,16 +113,13 @@ class ChatConfig(BaseSettings):
if not v:
# Fall back to OPENAI_API_KEY
v = os.getenv("OPENAI_API_KEY")
# Note: ANTHROPIC_API_KEY is intentionally NOT included here.
# The SDK CLI picks it up from the env directly. Including it
# would pair it with the OpenRouter base_url, causing auth failures.
return v
@field_validator("base_url", mode="before")
@classmethod
def get_base_url(cls, v):
"""Get base URL from environment if not provided."""
if not v:
if v is None:
# Check for OpenRouter or custom base URL
v = os.getenv("CHAT_BASE_URL")
if not v:
@@ -195,7 +127,7 @@ class ChatConfig(BaseSettings):
if not v:
v = os.getenv("OPENAI_BASE_URL")
if not v:
v = OPENROUTER_BASE_URL
v = "https://openrouter.ai/api/v1"
return v
@field_validator("use_claude_agent_sdk", mode="before")
@@ -209,15 +141,6 @@ class ChatConfig(BaseSettings):
# Default to True (SDK enabled by default)
return True if v is None else v
@field_validator("use_claude_code_subscription", mode="before")
@classmethod
def get_use_claude_code_subscription(cls, v):
"""Get use_claude_code_subscription from environment if not provided."""
env_val = os.getenv("CHAT_USE_CLAUDE_CODE_SUBSCRIPTION", "").lower()
if env_val:
return env_val in ("true", "1", "yes", "on")
return False if v is None else v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",

View File

@@ -1,38 +0,0 @@
"""Unit tests for ChatConfig."""
import pytest
from .config import ChatConfig
# Env vars that the ChatConfig validators read — must be cleared so they don't
# override the explicit constructor values we pass in each test.
_E2B_ENV_VARS = (
"CHAT_USE_E2B_SANDBOX",
"CHAT_E2B_API_KEY",
"E2B_API_KEY",
)
@pytest.fixture(autouse=True)
def _clean_e2b_env(monkeypatch: pytest.MonkeyPatch) -> None:
for var in _E2B_ENV_VARS:
monkeypatch.delenv(var, raising=False)
class TestE2BActive:
"""Tests for the e2b_active property — single source of truth for E2B usage."""
def test_both_enabled_and_key_present_returns_true(self):
"""e2b_active is True when use_e2b_sandbox=True and e2b_api_key is set."""
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key="test-key")
assert cfg.e2b_active is True
def test_enabled_but_missing_key_returns_false(self):
"""e2b_active is False when use_e2b_sandbox=True but e2b_api_key is absent."""
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key=None)
assert cfg.e2b_active is False
def test_disabled_returns_false(self):
"""e2b_active is False when use_e2b_sandbox=False regardless of key."""
cfg = ChatConfig(use_e2b_sandbox=False, e2b_api_key="test-key")
assert cfg.e2b_active is False

View File

@@ -1,11 +0,0 @@
"""Shared constants for the CoPilot module."""
# Special message prefixes for text-based markers (parsed by frontend).
# The hex suffix makes accidental LLM generation of these strings virtually
# impossible, avoiding false-positive marker detection in normal conversation.
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
# Compaction notice messages shown to users.
COMPACTION_DONE_MSG = "Earlier messages were summarized to fit within context limits."
COMPACTION_TOOL_NAME = "context_compaction"

View File

@@ -1,115 +0,0 @@
"""Shared execution context for copilot SDK tool handlers.
All context variables and their accessors live here so that
``tool_adapter``, ``file_ref``, and ``e2b_file_tools`` can import them
without creating circular dependencies.
"""
import os
import re
from contextvars import ContextVar
from typing import TYPE_CHECKING
from backend.copilot.model import ChatSession
if TYPE_CHECKING:
from e2b import AsyncSandbox
# Allowed base directory for the Read tool.
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# Encoded project-directory name for the current session (e.g.
# "-private-tmp-copilot-<uuid>"). Set by set_execution_context() so path
# validation can scope tool-results reads to the current session.
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
"_current_sandbox", default=None
)
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does."""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
def set_execution_context(
user_id: str | None,
session: ChatSession,
sandbox: "AsyncSandbox | None" = None,
sdk_cwd: str | None = None,
) -> None:
"""Set per-turn context variables used by file-resolution tool handlers."""
_current_user_id.set(user_id)
_current_session.set(session)
_current_sandbox.set(sandbox)
_current_sdk_cwd.set(sdk_cwd or "")
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
def get_execution_context() -> tuple[str | None, ChatSession | None]:
"""Return the current (user_id, session) pair for the active request."""
return _current_user_id.get(), _current_session.get()
def get_current_sandbox() -> "AsyncSandbox | None":
"""Return the E2B sandbox for the current session, or None if not active."""
return _current_sandbox.get()
def get_sdk_cwd() -> str:
"""Return the SDK working directory for the current session (empty string if unset)."""
return _current_sdk_cwd.get()
E2B_WORKDIR = "/home/user"
def resolve_sandbox_path(path: str) -> str:
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
Raises :class:`ValueError` if the resolved path escapes the sandbox.
"""
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
normalized = os.path.normpath(candidate)
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
return normalized
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
"""Return True if *path* is within an allowed host-filesystem location.
Allowed:
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
- Files under ``~/.claude/projects/<encoded-cwd>/tool-results/`` (SDK tool-results)
"""
if not path:
return False
if path.startswith("~"):
resolved = os.path.realpath(os.path.expanduser(path))
elif not os.path.isabs(path) and sdk_cwd:
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
else:
resolved = os.path.realpath(path)
if sdk_cwd:
norm_cwd = os.path.realpath(sdk_cwd)
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
return True
encoded = _current_project_dir.get("")
if encoded:
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
if resolved == tool_results_dir or resolved.startswith(
tool_results_dir + os.sep
):
return True
return False

View File

@@ -1,163 +0,0 @@
"""Tests for context.py — execution context variables and path helpers."""
from __future__ import annotations
import os
import tempfile
from unittest.mock import MagicMock
import pytest
from backend.copilot.context import (
_SDK_PROJECTS_DIR,
_current_project_dir,
get_current_sandbox,
get_execution_context,
get_sdk_cwd,
is_allowed_local_path,
resolve_sandbox_path,
set_execution_context,
)
def _make_session() -> MagicMock:
s = MagicMock()
s.session_id = "test-session"
return s
# ---------------------------------------------------------------------------
# Context variable getters
# ---------------------------------------------------------------------------
def test_get_execution_context_defaults():
"""get_execution_context returns (None, session) when user_id is not set."""
set_execution_context(None, _make_session())
user_id, session = get_execution_context()
assert user_id is None
assert session is not None
def test_set_and_get_execution_context():
"""set_execution_context stores user_id and session."""
mock_session = _make_session()
set_execution_context("user-abc", mock_session)
user_id, session = get_execution_context()
assert user_id == "user-abc"
assert session is mock_session
def test_get_current_sandbox_none_by_default():
"""get_current_sandbox returns None when no sandbox is set."""
set_execution_context("u1", _make_session(), sandbox=None)
assert get_current_sandbox() is None
def test_get_current_sandbox_returns_set_value():
"""get_current_sandbox returns the sandbox set via set_execution_context."""
mock_sandbox = MagicMock()
set_execution_context("u1", _make_session(), sandbox=mock_sandbox)
assert get_current_sandbox() is mock_sandbox
def test_get_sdk_cwd_empty_when_not_set():
"""get_sdk_cwd returns empty string when sdk_cwd is not set."""
set_execution_context("u1", _make_session(), sdk_cwd=None)
assert get_sdk_cwd() == ""
def test_get_sdk_cwd_returns_set_value():
"""get_sdk_cwd returns the value set via set_execution_context."""
set_execution_context("u1", _make_session(), sdk_cwd="/tmp/copilot-test")
assert get_sdk_cwd() == "/tmp/copilot-test"
# ---------------------------------------------------------------------------
# is_allowed_local_path
# ---------------------------------------------------------------------------
def test_is_allowed_local_path_empty():
assert not is_allowed_local_path("")
def test_is_allowed_local_path_inside_sdk_cwd():
with tempfile.TemporaryDirectory() as cwd:
path = os.path.join(cwd, "file.txt")
assert is_allowed_local_path(path, cwd)
def test_is_allowed_local_path_sdk_cwd_itself():
with tempfile.TemporaryDirectory() as cwd:
assert is_allowed_local_path(cwd, cwd)
def test_is_allowed_local_path_outside_sdk_cwd():
with tempfile.TemporaryDirectory() as cwd:
assert not is_allowed_local_path("/etc/passwd", cwd)
def test_is_allowed_local_path_no_sdk_cwd_no_project_dir():
"""Without sdk_cwd or project_dir, all paths are rejected."""
_current_project_dir.set("")
assert not is_allowed_local_path("/tmp/some-file.txt", sdk_cwd=None)
def test_is_allowed_local_path_tool_results_dir():
"""Files under the tool-results directory for the current project are allowed."""
encoded = "test-encoded-dir"
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
path = os.path.join(tool_results_dir, "output.txt")
_current_project_dir.set(encoded)
try:
assert is_allowed_local_path(path, sdk_cwd=None)
finally:
_current_project_dir.set("")
def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
"""A path adjacent to tool-results/ but not inside it is rejected."""
encoded = "test-encoded-dir"
sibling_path = os.path.join(_SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
_current_project_dir.set(encoded)
try:
assert not is_allowed_local_path(sibling_path, sdk_cwd=None)
finally:
_current_project_dir.set("")
# ---------------------------------------------------------------------------
# resolve_sandbox_path
# ---------------------------------------------------------------------------
def test_resolve_sandbox_path_absolute_valid():
assert (
resolve_sandbox_path("/home/user/project/main.py")
== "/home/user/project/main.py"
)
def test_resolve_sandbox_path_relative():
assert resolve_sandbox_path("project/main.py") == "/home/user/project/main.py"
def test_resolve_sandbox_path_workdir_itself():
assert resolve_sandbox_path("/home/user") == "/home/user"
def test_resolve_sandbox_path_normalizes_dots():
assert resolve_sandbox_path("/home/user/a/../b") == "/home/user/b"
def test_resolve_sandbox_path_escape_raises():
with pytest.raises(ValueError, match="/home/user"):
resolve_sandbox_path("/home/user/../../etc/passwd")
def test_resolve_sandbox_path_absolute_outside_raises():
with pytest.raises(ValueError, match="/home/user"):
resolve_sandbox_path("/etc/passwd")

View File

@@ -16,7 +16,7 @@ from prisma.types import (
)
from backend.data import db
from backend.util.json import SafeJson, sanitize_string
from backend.util.json import SafeJson
from .model import ChatMessage, ChatSession, ChatSessionInfo
@@ -81,35 +81,6 @@ async def update_chat_session(
return ChatSession.from_db(session) if session else None
async def update_chat_session_title(
session_id: str,
user_id: str,
title: str,
*,
only_if_empty: bool = False,
) -> bool:
"""Update the title of a chat session, scoped to the owning user.
Always filters by (session_id, user_id) so callers cannot mutate another
user's session even when they know the session_id.
Args:
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
guard so auto-generated titles never overwrite a user-set title.
Returns True if a row was updated, False otherwise (session not found,
wrong user, or — when only_if_empty — title was already set).
"""
where: ChatSessionWhereInput = {"id": session_id, "userId": user_id}
if only_if_empty:
where["title"] = None
result = await PrismaChatSession.prisma().update_many(
where=where,
data={"title": title, "updatedAt": datetime.now(UTC)},
)
return result > 0
async def add_chat_message(
session_id: str,
role: str,
@@ -130,16 +101,15 @@ async def add_chat_message(
"sequence": sequence,
}
# Add optional string fields — sanitize to strip PostgreSQL-incompatible
# control characters (null bytes etc.) that may appear in tool outputs.
# Add optional string fields
if content is not None:
data["content"] = sanitize_string(content)
data["content"] = content
if name is not None:
data["name"] = name
if tool_call_id is not None:
data["toolCallId"] = tool_call_id
if refusal is not None:
data["refusal"] = sanitize_string(refusal)
data["refusal"] = refusal
# Add optional JSON fields only when they have values
if tool_calls is not None:
@@ -200,16 +170,15 @@ async def add_chat_messages_batch(
"createdAt": now,
}
# Add optional string fields — sanitize to strip
# PostgreSQL-incompatible control characters.
# Add optional string fields
if msg.get("content") is not None:
data["content"] = sanitize_string(msg["content"])
data["content"] = msg["content"]
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = sanitize_string(msg["refusal"])
data["refusal"] = msg["refusal"]
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
@@ -343,7 +312,7 @@ async def update_tool_message_content(
"toolCallId": tool_call_id,
},
data={
"content": sanitize_string(new_content),
"content": new_content,
},
)
if result == 0:

View File

@@ -6,13 +6,11 @@ in a thread-local context, following the graph executor pattern.
import asyncio
import logging
import os
import subprocess
import threading
import time
from backend.copilot import service as copilot_service
from backend.copilot import stream_registry
from backend.copilot.baseline import stream_chat_completion_baseline
from backend.copilot.config import ChatConfig
from backend.copilot.response_model import StreamFinish
from backend.copilot.sdk import service as sdk_service
@@ -110,41 +108,8 @@ class CoPilotProcessor:
)
self.execution_thread.start()
# Skip the SDK's per-request CLI version check — the bundled CLI is
# already version-matched to the SDK package.
os.environ.setdefault("CLAUDE_AGENT_SDK_SKIP_VERSION_CHECK", "1")
# Pre-warm the bundled CLI binary so the OS page-caches the ~185 MB
# executable. First spawn pays ~1.2 s; subsequent spawns ~0.65 s.
self._prewarm_cli()
logger.info(f"[CoPilotExecutor] Worker {self.tid} started")
def _prewarm_cli(self) -> None:
"""Run the bundled CLI binary once to warm OS page caches."""
try:
from claude_agent_sdk._internal.transport.subprocess_cli import (
SubprocessCLITransport,
)
cli_path = SubprocessCLITransport._find_bundled_cli(None) # type: ignore[arg-type]
if cli_path:
result = subprocess.run(
[cli_path, "-v"],
capture_output=True,
timeout=10,
)
if result.returncode == 0:
logger.info(f"[CoPilotExecutor] CLI pre-warm done: {cli_path}")
else:
logger.warning(
"[CoPilotExecutor] CLI pre-warm failed (rc=%d): %s",
result.returncode, # type: ignore[reportCallIssue]
cli_path,
)
except Exception as e:
logger.debug(f"[CoPilotExecutor] CLI pre-warm skipped: {e}")
def cleanup(self):
"""Clean up event-loop-bound resources before the loop is destroyed.
@@ -154,12 +119,12 @@ class CoPilotProcessor:
"""
from backend.util.workspace_storage import shutdown_workspace_storage
coro = shutdown_workspace_storage()
try:
future = asyncio.run_coroutine_threadsafe(coro, self.execution_loop)
future = asyncio.run_coroutine_threadsafe(
shutdown_workspace_storage(), self.execution_loop
)
future.result(timeout=5)
except Exception as e:
coro.close() # Prevent "coroutine was never awaited" warning
error_msg = str(e) or type(e).__name__
logger.warning(
f"[CoPilotExecutor] Worker {self.tid} cleanup error: {error_msg}"
@@ -229,7 +194,7 @@ class CoPilotProcessor:
):
"""Async execution logic for a CoPilot turn.
Calls the chat completion service (SDK or baseline) and publishes
Calls the stream_chat_completion service function and publishes
results to the stream registry.
Args:
@@ -243,10 +208,9 @@ class CoPilotProcessor:
error_msg = None
try:
# Choose service based on LaunchDarkly flag.
# Claude Code subscription forces SDK mode (CLI subprocess auth).
# Choose service based on LaunchDarkly flag
config = ChatConfig()
use_sdk = config.use_claude_code_subscription or await is_feature_enabled(
use_sdk = await is_feature_enabled(
Flag.COPILOT_SDK,
entry.user_id or "anonymous",
default=config.use_claude_agent_sdk,
@@ -254,9 +218,9 @@ class CoPilotProcessor:
stream_fn = (
sdk_service.stream_chat_completion_sdk
if use_sdk
else stream_chat_completion_baseline
else copilot_service.stream_chat_completion
)
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
log.info(f"Using {'SDK' if use_sdk else 'standard'} service")
# Stream chat completion and publish chunks to Redis.
async for chunk in stream_fn(
@@ -265,7 +229,6 @@ class CoPilotProcessor:
is_user_message=entry.is_user_message,
user_id=entry.user_id,
context=entry.context,
file_ids=entry.file_ids,
):
if cancel.is_set():
log.info("Cancel requested, breaking stream")

View File

@@ -153,9 +153,6 @@ class CoPilotExecutionEntry(BaseModel):
context: dict[str, str] | None = None
"""Optional context for the message (e.g., {url: str, content: str})"""
file_ids: list[str] | None = None
"""Workspace file IDs attached to the user's message"""
class CancelCoPilotEvent(BaseModel):
"""Event to cancel a CoPilot operation."""
@@ -174,7 +171,6 @@ async def enqueue_copilot_turn(
turn_id: str,
is_user_message: bool = True,
context: dict[str, str] | None = None,
file_ids: list[str] | None = None,
) -> None:
"""Enqueue a CoPilot task for processing by the executor service.
@@ -185,7 +181,6 @@ async def enqueue_copilot_turn(
turn_id: Per-turn UUID for Redis stream isolation
is_user_message: Whether the message is from the user (vs system/assistant)
context: Optional context for the message (e.g., {url: str, content: str})
file_ids: Optional workspace file IDs attached to the user's message
"""
from backend.util.clients import get_async_copilot_queue
@@ -196,7 +191,6 @@ async def enqueue_copilot_turn(
message=message,
is_user_message=is_user_message,
context=context,
file_ids=file_ids,
)
queue_client = await get_async_copilot_queue()

View File

@@ -469,16 +469,8 @@ async def upsert_chat_session(
)
db_error = e
# Save to cache (best-effort, even if DB failed).
# Title updates (update_session_title) run *outside* this lock because
# they only touch the title field, not messages. So a concurrent rename
# or auto-title may have written a newer title to Redis while this
# upsert was in progress. Always prefer the cached title to avoid
# overwriting it with the stale in-memory copy.
# Save to cache (best-effort, even if DB failed)
try:
existing_cached = await _get_session_from_cache(session.session_id)
if existing_cached and existing_cached.title:
session = session.model_copy(update={"title": existing_cached.title})
await cache_chat_session(session)
except Exception as e:
# If DB succeeded but cache failed, raise cache error
@@ -680,47 +672,27 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
async with _session_locks_mutex:
_session_locks.pop(session_id, None)
# Shut down any local browser daemon for this session (best-effort).
# Inline import required: all tool modules import ChatSession from this
# module, so any top-level import from tools.* would create a cycle.
try:
from .tools.agent_browser import close_browser_session
await close_browser_session(session_id, user_id=user_id)
except Exception as e:
logger.debug(f"Browser cleanup for session {session_id}: {e}")
return True
async def update_session_title(
session_id: str,
user_id: str,
title: str,
*,
only_if_empty: bool = False,
) -> bool:
"""Update the title of a chat session, scoped to the owning user.
async def update_session_title(session_id: str, title: str) -> bool:
"""Update only the title of a chat session.
Lightweight operation that doesn't touch messages, avoiding race conditions
with concurrent message updates.
This is a lightweight operation that doesn't touch messages, avoiding
race conditions with concurrent message updates. Use this for background
title generation instead of upsert_chat_session.
Args:
session_id: The session ID to update.
user_id: Owning user — the DB query filters on this.
title: The new title to set.
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
so auto-generated titles never overwrite a user-set title.
Returns:
True if updated successfully, False otherwise (not found, wrong user,
or — when only_if_empty — title was already set).
True if updated successfully, False otherwise.
"""
try:
updated = await chat_db().update_chat_session_title(
session_id, user_id, title, only_if_empty=only_if_empty
)
if not updated:
result = await chat_db().update_chat_session(session_id=session_id, title=title)
if result is None:
logger.warning(f"Session {session_id} not found for title update")
return False
# Update title in cache if it exists (instead of invalidating).
@@ -732,8 +704,9 @@ async def update_session_title(
cached.title = title
await cache_chat_session(cached)
except Exception as e:
# Not critical - title will be correct on next full cache refresh
logger.warning(
f"Cache title update failed for session {session_id} (non-critical): {e}"
f"Failed to update title in cache for session {session_id}: {e}"
)
return True

View File

@@ -1,138 +0,0 @@
"""Scheduler job to generate LLM-optimized block descriptions.
Runs periodically to rewrite block descriptions into concise, actionable
summaries that help the copilot LLM pick the right blocks during agent
generation.
"""
import asyncio
import logging
from backend.blocks import get_blocks
from backend.util.clients import get_database_manager_client, get_openai_client
logger = logging.getLogger(__name__)
SYSTEM_PROMPT = (
"You are a technical writer for an automation platform. "
"Rewrite the following block description to be concise (under 50 words), "
"informative, and actionable. Focus on what the block does and when to "
"use it. Output ONLY the rewritten description, nothing else. "
"Do not use markdown formatting."
)
# Rate-limit delay between sequential LLM calls (seconds)
_RATE_LIMIT_DELAY = 0.5
# Maximum tokens for optimized description generation
_MAX_DESCRIPTION_TOKENS = 150
# Model for generating optimized descriptions (fast, cheap)
_MODEL = "gpt-4o-mini"
async def _optimize_descriptions(blocks: list[dict[str, str]]) -> dict[str, str]:
"""Call the shared OpenAI client to rewrite each block description."""
client = get_openai_client()
if client is None:
logger.error(
"No OpenAI client configured, skipping block description optimization"
)
return {}
results: dict[str, str] = {}
for block in blocks:
block_id = block["id"]
block_name = block["name"]
description = block["description"]
try:
response = await client.chat.completions.create(
model=_MODEL,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{
"role": "user",
"content": f"Block name: {block_name}\nDescription: {description}",
},
],
max_tokens=_MAX_DESCRIPTION_TOKENS,
)
optimized = (response.choices[0].message.content or "").strip()
if optimized:
results[block_id] = optimized
logger.debug("Optimized description for %s", block_name)
else:
logger.warning("Empty response for block %s", block_name)
except Exception:
logger.warning(
"Failed to optimize description for %s", block_name, exc_info=True
)
await asyncio.sleep(_RATE_LIMIT_DELAY)
return results
def optimize_block_descriptions() -> dict[str, int]:
"""Generate optimized descriptions for blocks that don't have one yet.
Uses the shared OpenAI client to rewrite block descriptions into concise
summaries suitable for agent generation prompts.
Returns:
Dict with counts: processed, success, failed, skipped.
"""
db_client = get_database_manager_client()
blocks = db_client.get_blocks_needing_optimization()
if not blocks:
logger.info("All blocks already have optimized descriptions")
return {"processed": 0, "success": 0, "failed": 0, "skipped": 0}
logger.info("Found %d blocks needing optimized descriptions", len(blocks))
non_empty = [b for b in blocks if b.get("description", "").strip()]
skipped = len(blocks) - len(non_empty)
new_descriptions = asyncio.run(_optimize_descriptions(non_empty))
stats = {
"processed": len(non_empty),
"success": len(new_descriptions),
"failed": len(non_empty) - len(new_descriptions),
"skipped": skipped,
}
logger.info(
"Block description optimization complete: "
"%d/%d succeeded, %d failed, %d skipped",
stats["success"],
stats["processed"],
stats["failed"],
stats["skipped"],
)
if new_descriptions:
for block_id, optimized in new_descriptions.items():
db_client.update_block_optimized_description(block_id, optimized)
# Update in-memory descriptions first so the cache rebuilds with fresh data.
try:
block_classes = get_blocks()
for block_id, optimized in new_descriptions.items():
if block_id in block_classes:
block_classes[block_id]._optimized_description = optimized
logger.info(
"Updated %d in-memory block descriptions", len(new_descriptions)
)
except Exception:
logger.warning(
"Could not update in-memory block descriptions", exc_info=True
)
from backend.copilot.tools.agent_generator.blocks import (
reset_block_caches, # local to avoid circular import
)
reset_block_caches()
return stats

View File

@@ -1,91 +0,0 @@
"""Unit tests for optimize_blocks._optimize_descriptions."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from backend.copilot.optimize_blocks import _RATE_LIMIT_DELAY, _optimize_descriptions
def _make_client_response(text: str) -> MagicMock:
"""Build a minimal mock that looks like an OpenAI ChatCompletion response."""
choice = MagicMock()
choice.message.content = text
response = MagicMock()
response.choices = [choice]
return response
def _run(coro):
return asyncio.get_event_loop().run_until_complete(coro)
class TestOptimizeDescriptions:
"""Tests for _optimize_descriptions async function."""
def test_returns_empty_when_no_client(self):
with patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=None
):
result = _run(
_optimize_descriptions([{"id": "b1", "name": "B", "description": "d"}])
)
assert result == {}
def test_success_single_block(self):
client = MagicMock()
client.chat.completions.create = AsyncMock(
return_value=_make_client_response("Short desc.")
)
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
with (
patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
),
patch(
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
),
):
result = _run(_optimize_descriptions(blocks))
assert result == {"b1": "Short desc."}
client.chat.completions.create.assert_called_once()
def test_skips_block_on_exception(self):
client = MagicMock()
client.chat.completions.create = AsyncMock(side_effect=Exception("API error"))
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
with (
patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
),
patch(
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
),
):
result = _run(_optimize_descriptions(blocks))
assert result == {}
def test_sleeps_between_blocks(self):
client = MagicMock()
client.chat.completions.create = AsyncMock(
return_value=_make_client_response("desc")
)
blocks = [
{"id": "b1", "name": "B1", "description": "d1"},
{"id": "b2", "name": "B2", "description": "d2"},
]
sleep_mock = AsyncMock()
with (
patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
),
patch("backend.copilot.optimize_blocks.asyncio.sleep", sleep_mock),
):
_run(_optimize_descriptions(blocks))
assert sleep_mock.call_count == 2
sleep_mock.assert_called_with(_RATE_LIMIT_DELAY)

View File

@@ -0,0 +1,422 @@
"""Lightweight OTLP JSON trace exporter for CoPilot LLM calls.
Sends spans to a remote OTLP-compatible endpoint (e.g. Product Intelligence)
in the ExportTraceServiceRequest JSON format. Payload construction and the
HTTP POST run in background asyncio tasks so streaming latency is unaffected.
Configuration (via backend.util.settings.Secrets):
OTLP_TRACING_HOST base URL of the trace ingestion service
(e.g. "https://traces.example.com")
OTLP_TRACING_TOKEN optional Bearer token for authentication
"""
import asyncio
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any
import httpx
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
_settings = Settings()
# Resolve the endpoint once at import time.
_TRACING_HOST = (_settings.secrets.otlp_tracing_host or "").rstrip("/")
_TRACING_TOKEN = _settings.secrets.otlp_tracing_token.get_secret_value()
_TRACING_ENABLED = bool(_TRACING_HOST)
# Shared async client — created lazily on first use.
_client: httpx.AsyncClient | None = None
def _get_client() -> httpx.AsyncClient:
global _client
if _client is None:
headers: dict[str, str] = {"Content-Type": "application/json"}
if _TRACING_TOKEN:
headers["Authorization"] = f"Bearer {_TRACING_TOKEN}"
_client = httpx.AsyncClient(headers=headers, timeout=10.0)
return _client
def _nano(ts: float) -> str:
"""Convert a ``time.time()`` float to nanosecond string for OTLP."""
return str(int(ts * 1_000_000_000))
def _kv(key: str, value: Any) -> dict | None:
"""Build an OTLP KeyValue entry, returning None for missing values."""
if value is None:
return None
if isinstance(value, str):
return {"key": key, "value": {"stringValue": value}}
if isinstance(value, bool):
return {"key": key, "value": {"stringValue": str(value).lower()}}
if isinstance(value, int):
return {"key": key, "value": {"intValue": str(value)}}
if isinstance(value, float):
return {"key": key, "value": {"doubleValue": value}}
# Fallback: serialise as string
return {"key": key, "value": {"stringValue": str(value)}}
def _build_completion_text(
assistant_content: str | None,
tool_calls: list[dict[str, Any]] | None,
) -> str | None:
"""Build completion text that includes tool calls in the format
the Product Intelligence system can parse: ``tool_name{json_args}``.
"""
parts: list[str] = []
if tool_calls:
for tc in tool_calls:
fn = tc.get("function", {})
name = fn.get("name", "")
args = fn.get("arguments", "{}")
if name:
parts.append(f"{name}{args}")
if assistant_content:
parts.append(assistant_content)
return "\n".join(parts) if parts else None
def _model_provider_slug(model: str) -> str:
text = (model or "").strip().lower()
if not text:
return "unknown"
return text.split("/", 1)[0]
def _model_provider_name(slug: str) -> str:
known = {
"openai": "OpenAI",
"anthropic": "Anthropic",
"google": "Google",
"meta": "Meta",
"mistral": "Mistral",
"deepseek": "DeepSeek",
"x-ai": "xAI",
"xai": "xAI",
"qwen": "Qwen",
"nvidia": "NVIDIA",
"cohere": "Cohere",
}
return known.get(slug, slug)
@dataclass
class TraceContext:
"""Accumulates trace data during LLM streaming for OTLP emission.
Used by both SDK and non-SDK CoPilot paths to collect usage metrics,
tool calls, and timing information in a consistent structure.
"""
model: str = ""
user_id: str | None = None
session_id: str | None = None
start_time: float = 0.0
# Accumulated during streaming
text_parts: list[str] = field(default_factory=list)
tool_calls: list[dict[str, Any]] = field(default_factory=list)
usage: dict[str, Any] = field(default_factory=dict)
cost_usd: float | None = None
def emit(
self,
*,
finish_reason: str | None = None,
messages: list[dict[str, Any]] | None = None,
) -> None:
"""Build and emit the trace as a fire-and-forget background task."""
fr = finish_reason or ("tool_calls" if self.tool_calls else "stop")
emit_trace(
model=self.model,
messages=messages or [],
assistant_content="".join(self.text_parts) or None,
finish_reason=fr,
prompt_tokens=(self.usage.get("prompt") or self.usage.get("input_tokens")),
completion_tokens=(
self.usage.get("completion") or self.usage.get("output_tokens")
),
total_tokens=self.usage.get("total"),
total_cost_usd=self.cost_usd,
cache_creation_input_tokens=self.usage.get("cache_creation_input_tokens"),
cache_read_input_tokens=(
self.usage.get("cached") or self.usage.get("cache_read_input_tokens")
),
reasoning_tokens=self.usage.get("reasoning"),
user_id=self.user_id,
session_id=self.session_id,
tool_calls=self.tool_calls or None,
start_time=self.start_time,
end_time=time.time(),
)
def _build_otlp_payload(
*,
trace_id: str,
model: str,
messages: list[dict[str, Any]],
assistant_content: str | None = None,
finish_reason: str = "stop",
prompt_tokens: int | None = None,
completion_tokens: int | None = None,
total_tokens: int | None = None,
total_cost_usd: float | None = None,
cache_creation_input_tokens: int | None = None,
cache_read_input_tokens: int | None = None,
reasoning_tokens: int | None = None,
user_id: str | None = None,
session_id: str | None = None,
tool_calls: list[dict[str, Any]] | None = None,
start_time: float | None = None,
end_time: float | None = None,
) -> dict:
"""Build an ``ExportTraceServiceRequest`` JSON payload."""
provider_slug = _model_provider_slug(model)
provider_name = _model_provider_name(provider_slug)
prompt_payload: str | None = None
if messages:
prompt_payload = json.dumps({"messages": messages}, default=str)
completion_payload: str | None = None
completion_text = _build_completion_text(assistant_content, tool_calls)
if completion_text is not None:
completion_obj: dict[str, Any] = {
"completion": completion_text,
"reasoning": None,
"rawRequest": {
"model": model,
"stream": True,
"stream_options": {"include_usage": True},
"tool_choice": "auto",
"user": user_id,
"posthogDistinctId": user_id,
"session_id": session_id,
},
}
completion_payload = json.dumps(completion_obj, default=str)
attrs: list[dict] = []
for kv in [
_kv("trace.name", "OpenRouter Request"),
_kv("span.type", "generation"),
_kv("span.level", "DEFAULT"),
_kv("gen_ai.operation.name", "chat"),
_kv("gen_ai.system", provider_slug),
_kv("gen_ai.provider.name", provider_slug),
_kv("gen_ai.request.model", model),
_kv("gen_ai.response.model", model),
_kv("gen_ai.response.finish_reason", finish_reason),
_kv("gen_ai.response.finish_reasons", json.dumps([finish_reason])),
_kv("gen_ai.usage.input_tokens", prompt_tokens),
_kv("gen_ai.usage.output_tokens", completion_tokens),
_kv("gen_ai.usage.total_tokens", total_tokens),
_kv("gen_ai.usage.input_tokens.cached", cache_read_input_tokens),
_kv(
"gen_ai.usage.input_tokens.cache_creation",
cache_creation_input_tokens,
),
_kv("gen_ai.usage.output_tokens.reasoning", reasoning_tokens),
_kv("user.id", user_id),
_kv("session.id", session_id),
_kv("trace.metadata.openrouter.source", "openrouter"),
_kv("trace.metadata.openrouter.user_id", user_id),
_kv("gen_ai.usage.total_cost", total_cost_usd),
_kv("trace.metadata.openrouter.provider_name", provider_name),
_kv("trace.metadata.openrouter.provider_slug", provider_slug),
_kv("trace.metadata.openrouter.finish_reason", finish_reason),
]:
if kv is not None:
attrs.append(kv)
if prompt_payload is not None:
attrs.append({"key": "trace.input", "value": {"stringValue": prompt_payload}})
attrs.append({"key": "span.input", "value": {"stringValue": prompt_payload}})
attrs.append({"key": "gen_ai.prompt", "value": {"stringValue": prompt_payload}})
if completion_payload is not None:
attrs.append(
{
"key": "trace.output",
"value": {"stringValue": completion_payload},
}
)
attrs.append(
{
"key": "span.output",
"value": {"stringValue": completion_payload},
}
)
attrs.append(
{
"key": "gen_ai.completion",
"value": {"stringValue": completion_payload},
}
)
span = {
"traceId": trace_id,
"startTimeUnixNano": _nano(start_time or time.time()),
"endTimeUnixNano": _nano(end_time or time.time()),
"attributes": attrs,
}
return {
"resourceSpans": [
{
"resource": {
"attributes": [
{
"key": "service.name",
"value": {"stringValue": "openrouter"},
},
{
"key": "openrouter.trace.id",
"value": {
"stringValue": (
f"gen-{int(end_time or time.time())}"
f"-{trace_id[:20]}"
)
},
},
]
},
"scopeSpans": [{"spans": [span]}],
}
]
}
async def _send_trace(payload: dict) -> None:
"""POST the OTLP payload to the configured tracing host."""
url = f"{_TRACING_HOST}/v1/traces"
try:
client = _get_client()
resp = await client.post(url, json=payload)
if resp.status_code >= 400:
logger.debug(
"[OTLP] Trace POST returned %d: %s",
resp.status_code,
resp.text[:200],
)
else:
logger.debug("[OTLP] Trace sent successfully (%d)", resp.status_code)
except Exception as e:
logger.warning("[OTLP] Failed to send trace: %s", e)
# Background task set with backpressure cap.
_bg_tasks: set[asyncio.Task[Any]] = set()
_MAX_BG_TASKS = 64
async def _build_and_send_trace(
*,
model: str,
messages: list[dict[str, Any]],
assistant_content: str | None,
finish_reason: str,
prompt_tokens: int | None,
completion_tokens: int | None,
total_tokens: int | None,
total_cost_usd: float | None,
cache_creation_input_tokens: int | None,
cache_read_input_tokens: int | None,
reasoning_tokens: int | None,
user_id: str | None,
session_id: str | None,
tool_calls: list[dict[str, Any]] | None,
start_time: float | None,
end_time: float | None,
) -> None:
"""Build the OTLP payload and send it — runs entirely in a background task."""
trace_id = uuid.uuid4().hex
payload = _build_otlp_payload(
trace_id=trace_id,
model=model,
messages=messages,
assistant_content=assistant_content,
finish_reason=finish_reason,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
total_cost_usd=total_cost_usd,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
reasoning_tokens=reasoning_tokens,
user_id=user_id,
session_id=session_id,
tool_calls=tool_calls,
start_time=start_time,
end_time=end_time,
)
await _send_trace(payload)
def emit_trace(
*,
model: str,
messages: list[dict[str, Any]],
assistant_content: str | None = None,
finish_reason: str = "stop",
prompt_tokens: int | None = None,
completion_tokens: int | None = None,
total_tokens: int | None = None,
total_cost_usd: float | None = None,
cache_creation_input_tokens: int | None = None,
cache_read_input_tokens: int | None = None,
reasoning_tokens: int | None = None,
user_id: str | None = None,
session_id: str | None = None,
tool_calls: list[dict[str, Any]] | None = None,
start_time: float | None = None,
end_time: float | None = None,
) -> None:
"""Fire-and-forget: build and send an OTLP trace span.
Safe to call from async context — both payload serialization and the
HTTP POST run in a background task so they never block the event loop.
"""
if not _TRACING_ENABLED:
return
if len(_bg_tasks) >= _MAX_BG_TASKS:
logger.warning(
"[OTLP] Backpressure: dropping trace (%d tasks queued)",
len(_bg_tasks),
)
return
task = asyncio.create_task(
_build_and_send_trace(
model=model,
messages=messages,
assistant_content=assistant_content,
finish_reason=finish_reason,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
total_cost_usd=total_cost_usd,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
reasoning_tokens=reasoning_tokens,
user_id=user_id,
session_id=session_id,
tool_calls=tool_calls,
start_time=start_time,
end_time=end_time,
)
)
_bg_tasks.add(task)
task.add_done_callback(_bg_tasks.discard)

View File

@@ -0,0 +1,269 @@
"""Tests for parallel tool call execution in CoPilot.
These tests mock _yield_tool_call to avoid importing the full copilot stack
which requires Prisma, DB connections, etc.
"""
import asyncio
import time
from typing import Any, cast
import pytest
@pytest.mark.asyncio
async def test_parallel_tool_calls_run_concurrently():
"""Multiple tool calls should complete in ~max(delays), not sum(delays)."""
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from backend.copilot.service import _execute_tool_calls_parallel
n_tools = 3
delay_per_tool = 0.2
tool_calls = [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": f"tool_{i}", "arguments": "{}"},
}
for i in range(n_tools)
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
original_yield = None
async def fake_yield(tc_list, idx, sess):
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"],
toolName=tc_list[idx]["function"]["name"],
input={},
)
await asyncio.sleep(delay_per_tool)
yield StreamToolOutputAvailable(
toolCallId=tc_list[idx]["id"],
toolName=tc_list[idx]["function"]["name"],
output="{}",
)
import backend.copilot.service as svc
original_yield = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
start = time.monotonic()
events = []
async for event in _execute_tool_calls_parallel(
tool_calls, cast(Any, FakeSession())
):
events.append(event)
elapsed = time.monotonic() - start
finally:
svc._yield_tool_call = original_yield
assert len(events) == n_tools * 2
# Parallel: should take ~delay, not ~n*delay
assert elapsed < delay_per_tool * (
n_tools - 0.5
), f"Took {elapsed:.2f}s, expected parallel (~{delay_per_tool}s)"
@pytest.mark.asyncio
async def test_single_tool_call_works():
"""Single tool call should work identically."""
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from backend.copilot.service import _execute_tool_calls_parallel
tool_calls = [
{
"id": "call_0",
"type": "function",
"function": {"name": "t", "arguments": "{}"},
}
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
async def fake_yield(tc_list, idx, sess):
yield StreamToolInputAvailable(toolCallId="call_0", toolName="t", input={})
yield StreamToolOutputAvailable(toolCallId="call_0", toolName="t", output="{}")
import backend.copilot.service as svc
orig = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
events = [
e
async for e in _execute_tool_calls_parallel(
tool_calls, cast(Any, FakeSession())
)
]
finally:
svc._yield_tool_call = orig
assert len(events) == 2
@pytest.mark.asyncio
async def test_retryable_error_propagates():
"""Retryable errors should be raised after all tools finish."""
from backend.copilot.response_model import StreamToolOutputAvailable
from backend.copilot.service import _execute_tool_calls_parallel
tool_calls = [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": f"t_{i}", "arguments": "{}"},
}
for i in range(2)
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
async def fake_yield(tc_list, idx, sess):
if idx == 1:
raise KeyError("bad")
from backend.copilot.response_model import StreamToolInputAvailable
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName="t_0", input={}
)
await asyncio.sleep(0.05)
yield StreamToolOutputAvailable(
toolCallId=tc_list[idx]["id"], toolName="t_0", output="{}"
)
import backend.copilot.service as svc
orig = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
events = []
with pytest.raises(KeyError):
async for event in _execute_tool_calls_parallel(
tool_calls, cast(Any, FakeSession())
):
events.append(event)
# First tool's events should still be yielded
assert any(isinstance(e, StreamToolOutputAvailable) for e in events)
finally:
svc._yield_tool_call = orig
@pytest.mark.asyncio
async def test_session_shared_across_parallel_tools():
"""All parallel tools should receive the same session instance."""
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from backend.copilot.service import _execute_tool_calls_parallel
tool_calls = [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": f"t_{i}", "arguments": "{}"},
}
for i in range(3)
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
observed_sessions = []
async def fake_yield(tc_list, idx, sess):
observed_sessions.append(sess)
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
)
yield StreamToolOutputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", output="{}"
)
import backend.copilot.service as svc
orig = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
async for _ in _execute_tool_calls_parallel(
tool_calls, cast(Any, FakeSession())
):
pass
finally:
svc._yield_tool_call = orig
assert len(observed_sessions) == 3
assert observed_sessions[0] is observed_sessions[1] is observed_sessions[2]
@pytest.mark.asyncio
async def test_cancellation_cleans_up():
"""Generator close should cancel in-flight tasks."""
from backend.copilot.response_model import StreamToolInputAvailable
from backend.copilot.service import _execute_tool_calls_parallel
tool_calls = [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": f"t_{i}", "arguments": "{}"},
}
for i in range(2)
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
started = asyncio.Event()
async def fake_yield(tc_list, idx, sess):
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
)
started.set()
await asyncio.sleep(10) # simulate long-running
import backend.copilot.service as svc
orig = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
gen = _execute_tool_calls_parallel(tool_calls, cast(Any, FakeSession()))
await gen.__anext__() # get first event
await started.wait()
await gen.aclose() # close generator
finally:
svc._yield_tool_call = orig
# If we get here without hanging, cleanup worked

View File

@@ -1,218 +0,0 @@
"""Centralized prompt building logic for CoPilot.
This module contains all prompt construction functions and constants,
handling the distinction between:
- SDK mode vs Baseline mode (tool documentation needs)
- Local mode vs E2B mode (storage/filesystem differences)
"""
from backend.copilot.tools import TOOL_REGISTRY
# Shared technical notes that apply to both SDK and baseline modes
_SHARED_TOOL_NOTES = """\
### Sharing files with the user
After saving a file to the persistent workspace with `write_workspace_file`,
share it with the user by embedding the `download_url` from the response in
your message as a Markdown link or image:
- **Any file** — shows as a clickable download link:
`[report.csv](workspace://file_id#text/csv)`
- **Image** — renders inline in chat:
`![chart](workspace://file_id#image/png)`
- **Video** — renders inline in chat with player controls:
`![recording](workspace://file_id#video/mp4)`
The `download_url` field in the `write_workspace_file` response is already
in the correct format — paste it directly after the `(` in the Markdown.
### Passing file content to tools — @@agptfile: references
Instead of copying large file contents into a tool argument, pass a file
reference and the platform will load the content for you.
Syntax: `@@agptfile:<uri>[<start>-<end>]`
- `<uri>` **must** start with `workspace://` or `/` (absolute path):
- `workspace://<file_id>` — workspace file by ID
- `workspace:///<path>` — workspace file by virtual path
- `/absolute/local/path` — ephemeral or sdk_cwd file
- E2B sandbox absolute path (e.g. `/home/user/script.py`)
- `[<start>-<end>]` is an optional 1-indexed inclusive line range.
- URIs that do not start with `workspace://` or `/` are **not** expanded.
Examples:
```
@@agptfile:workspace://abc123
@@agptfile:workspace://abc123[10-50]
@@agptfile:workspace:///reports/q1.md
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
@@agptfile:/home/user/script.py
```
You can embed a reference inside any string argument, or use it as the entire
value. Multiple references in one argument are all expanded.
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
All tasks must run in the foreground.
"""
# Environment-specific supplement templates
def _build_storage_supplement(
working_dir: str,
sandbox_type: str,
storage_system_1_name: str,
storage_system_1_characteristics: list[str],
storage_system_1_persistence: list[str],
file_move_name_1_to_2: str,
file_move_name_2_to_1: str,
) -> str:
"""Build storage/filesystem supplement for a specific environment.
Template function handles all formatting (bullets, indentation, markdown).
Callers provide clean data as lists of strings.
Args:
working_dir: Working directory path
sandbox_type: Description of bash_exec sandbox
storage_system_1_name: Name of primary storage (ephemeral or cloud)
storage_system_1_characteristics: List of characteristic descriptions
storage_system_1_persistence: List of persistence behavior descriptions
file_move_name_1_to_2: Direction label for primary→persistent
file_move_name_2_to_1: Direction label for persistent→primary
"""
# Format lists as bullet points with proper indentation
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
persistence = "\n".join(f" - {p}" for p in storage_system_1_persistence)
return f"""
## Tool notes
### Shell commands
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
for shell commands — it runs {sandbox_type}.
### Working directory
- Your working directory is: `{working_dir}`
- All SDK file tools AND `bash_exec` operate on the same filesystem
- Use relative paths or absolute paths under `{working_dir}` for all file operations
### Two storage systems — CRITICAL to understand
1. **{storage_system_1_name}** (`{working_dir}`):
{characteristics}
{persistence}
2. **Persistent workspace** (cloud storage):
- Files here **survive across sessions indefinitely**
### Moving files between storages
- **{file_move_name_1_to_2}**: Copy to persistent workspace
- **{file_move_name_2_to_1}**: Download for processing
### File persistence
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
{_SHARED_TOOL_NOTES}"""
# Pre-built supplements for common environments
def _get_local_storage_supplement(cwd: str) -> str:
"""Local ephemeral storage (files lost between turns)."""
return _build_storage_supplement(
working_dir=cwd,
sandbox_type="in a network-isolated sandbox",
storage_system_1_name="Ephemeral working directory",
storage_system_1_characteristics=[
"Shared by SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec`",
],
storage_system_1_persistence=[
"Files here are **lost between turns** — do NOT rely on them persisting",
"Use for temporary work: running scripts, processing data, etc.",
],
file_move_name_1_to_2="Ephemeral → Persistent",
file_move_name_2_to_1="Persistent → Ephemeral",
)
def _get_cloud_sandbox_supplement() -> str:
"""Cloud persistent sandbox (files survive across turns in session)."""
return _build_storage_supplement(
working_dir="/home/user",
sandbox_type="in a cloud sandbox with full internet access",
storage_system_1_name="Cloud sandbox",
storage_system_1_characteristics=[
"Shared by all file tools AND `bash_exec` — same filesystem",
"Full Linux environment with internet access",
],
storage_system_1_persistence=[
"Files **persist across turns** within the current session",
"Lost when the session expires (12 h inactivity)",
],
file_move_name_1_to_2="Sandbox → Persistent",
file_move_name_2_to_1="Persistent → Sandbox",
)
def _generate_tool_documentation() -> str:
"""Auto-generate tool documentation from TOOL_REGISTRY.
NOTE: This is ONLY used in baseline mode (direct OpenAI API).
SDK mode doesn't need it since Claude gets tool schemas automatically.
This generates a complete list of available tools with their descriptions,
ensuring the documentation stays in sync with the actual tool implementations.
All workflow guidance is now embedded in individual tool descriptions.
Only documents tools that are available in the current environment
(checked via tool.is_available property).
"""
docs = "\n## AVAILABLE TOOLS\n\n"
# Sort tools alphabetically for consistent output
# Filter by is_available to match get_available_tools() behavior
for name in sorted(TOOL_REGISTRY.keys()):
tool = TOOL_REGISTRY[name]
if not tool.is_available:
continue
schema = tool.as_openai_tool()
desc = schema["function"].get("description", "No description available")
# Format as bullet list with tool name in code style
docs += f"- **`{name}`**: {desc}\n"
return docs
def get_sdk_supplement(use_e2b: bool, cwd: str = "") -> str:
"""Get the supplement for SDK mode (Claude Agent SDK).
SDK mode does NOT include tool documentation because Claude automatically
receives tool schemas from the SDK. Only includes technical notes about
storage systems and execution environment.
Args:
use_e2b: Whether E2B cloud sandbox is being used
cwd: Current working directory (only used in local_storage mode)
Returns:
The supplement string to append to the system prompt
"""
if use_e2b:
return _get_cloud_sandbox_supplement()
return _get_local_storage_supplement(cwd)
def get_baseline_supplement() -> str:
"""Get the supplement for baseline mode (direct OpenAI API).
Baseline mode INCLUDES auto-generated tool documentation because the
direct API doesn't automatically provide tool schemas to Claude.
Also includes shared technical notes (but NOT SDK-specific environment details).
Returns:
The supplement string to append to the system prompt
"""
tool_docs = _generate_tool_documentation()
return tool_docs + _SHARED_TOOL_NOTES

View File

@@ -13,7 +13,6 @@ from typing import Any
from pydantic import BaseModel, Field
from backend.util.json import dumps as json_dumps
from backend.util.truncate import truncate
logger = logging.getLogger(__name__)
@@ -151,9 +150,6 @@ class StreamToolInputAvailable(StreamBaseResponse):
)
_MAX_TOOL_OUTPUT_SIZE = 100_000 # ~100 KB; truncate to avoid bloating SSE/DB
class StreamToolOutputAvailable(StreamBaseResponse):
"""Tool execution result."""
@@ -168,10 +164,6 @@ class StreamToolOutputAvailable(StreamBaseResponse):
default=True, description="Whether the tool execution succeeded"
)
def model_post_init(self, __context: Any) -> None:
"""Truncate oversized outputs after construction."""
self.output = truncate(self.output, _MAX_TOOL_OUTPUT_SIZE)
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-spec fields."""
data = {

View File

@@ -1,155 +0,0 @@
## Agent Generation Guide
You can create, edit, and customize agents directly. You ARE the brain —
generate the agent JSON yourself using block schemas, then validate and save.
### Workflow for Creating/Editing Agents
1. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
search for relevant blocks. This returns block IDs, names, descriptions,
and full input/output schemas.
2. **Find library agents**: Call `find_library_agent` to discover reusable
agents that can be composed as sub-agents via `AgentExecutorBlock`.
3. **Generate JSON**: Build the agent JSON using block schemas:
- Use block IDs from step 1 as `block_id` in nodes
- Wire outputs to inputs using links
- Set design-time config in `input_default`
- Use `AgentInputBlock` for values the user provides at runtime
4. **Write to workspace**: Save the JSON to a workspace file so the user
can review it: `write_workspace_file(filename="agent.json", content=...)`
5. **Validate**: Call `validate_agent_graph` with the agent JSON to check
for errors
6. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
or fix manually based on the error descriptions. Iterate until valid.
7. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
the final `agent_json`
### Agent JSON Structure
```json
{
"id": "<UUID v4>", // auto-generated if omitted
"version": 1,
"is_active": true,
"name": "Agent Name",
"description": "What the agent does",
"nodes": [
{
"id": "<UUID v4>",
"block_id": "<block UUID from find_block>",
"input_default": {
"field_name": "design-time value"
},
"metadata": {
"position": {"x": 0, "y": 0},
"customized_name": "Optional display name"
}
}
],
"links": [
{
"id": "<UUID v4>",
"source_id": "<source node UUID>",
"source_name": "output_field_name",
"sink_id": "<sink node UUID>",
"sink_name": "input_field_name",
"is_static": false
}
]
}
```
### REQUIRED: AgentInputBlock and AgentOutputBlock
Every agent MUST include at least one AgentInputBlock and one AgentOutputBlock.
These define the agent's interface — what it accepts and what it produces.
**AgentInputBlock** (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`):
- Defines a user-facing input field on the agent
- Required `input_default` fields: `name` (str), `value` (default: null)
- Optional: `title`, `description`, `placeholder_values` (for dropdowns)
- Output: `result` — the user-provided value at runtime
- Create one AgentInputBlock per distinct input the agent needs
**AgentOutputBlock** (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`):
- Defines a user-facing output displayed after the agent runs
- Required `input_default` fields: `name` (str)
- The `value` input should be linked from another block's output
- Optional: `title`, `description`, `format` (Jinja2 template)
- Create one AgentOutputBlock per distinct result to show the user
Without these blocks, the agent has no interface and the user cannot provide
inputs or see outputs. NEVER skip them.
### Key Rules
- **Name & description**: Include `name` and `description` in the agent JSON
when creating a new agent, or when editing and the agent's purpose changed.
Without these the agent gets a generic default name.
- **Design-time vs runtime**: `input_default` = values known at build time.
For user-provided values, create an `AgentInputBlock` node and link its
output to the consuming block's input.
- **Credentials**: Do NOT require credentials upfront. Users configure
credentials later in the platform UI after the agent is saved.
- **Node spacing**: Position nodes with at least 800 X-units between them.
- **Nested properties**: Use `parentField_#_childField` notation in link
sink_name/source_name to access nested object fields.
- **is_static links**: Set `is_static: true` when the link carries a
design-time constant (matches a field in inputSchema with a default).
- **ConditionBlock**: Needs a `StoreValueBlock` wired to its `value2` input.
- **Prompt templates**: Use `{{variable}}` (double curly braces) for
literal braces in prompt strings — single `{` and `}` are for
template variables.
- **AgentExecutorBlock**: When composing sub-agents, set `graph_id` and
`graph_version` in input_default, and wire inputs/outputs to match
the sub-agent's schema.
### Using Sub-Agents (AgentExecutorBlock)
To compose agents using other agents as sub-agents:
1. Call `find_library_agent` to find the sub-agent — the response includes
`graph_id`, `graph_version`, `input_schema`, and `output_schema`
2. Create an `AgentExecutorBlock` node (ID: `e189baac-8c20-45a1-94a7-55177ea42565`)
3. Set `input_default`:
- `graph_id`: from the library agent's `graph_id`
- `graph_version`: from the library agent's `graph_version`
- `input_schema`: from the library agent's `input_schema` (JSON Schema)
- `output_schema`: from the library agent's `output_schema` (JSON Schema)
- `user_id`: leave as `""` (filled at runtime)
- `inputs`: `{}` (populated by links at runtime)
4. Wire inputs: link to sink names matching the sub-agent's `input_schema`
property names (e.g., if input_schema has a `"url"` property, use
`"url"` as the sink_name)
5. Wire outputs: link from source names matching the sub-agent's
`output_schema` property names
6. Pass `library_agent_ids` to `create_agent`/`customize_agent` with
the library agent IDs used, so the fixer can validate schemas
### Using MCP Tools (MCPToolBlock)
To use an MCP (Model Context Protocol) tool as a node in the agent:
1. The user must specify which MCP server URL and tool name they want
2. Create an `MCPToolBlock` node (ID: `a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4`)
3. Set `input_default`:
- `server_url`: the MCP server URL (e.g. `"https://mcp.example.com/sse"`)
- `selected_tool`: the tool name on that server
- `tool_input_schema`: JSON Schema for the tool's inputs
- `tool_arguments`: `{}` (populated by links or hardcoded values)
4. The block requires MCP credentials — the user configures these in the
platform UI after the agent is saved
5. Wire inputs using the tool argument field name directly as the sink_name
(e.g., `query`, NOT `tool_arguments_#_query`). The execution engine
automatically collects top-level fields matching tool_input_schema into
tool_arguments.
6. Output: `result` (the tool's return value) and `error` (error message)
### Example: Simple AI Text Processor
A minimal agent with input, processing, and output:
- Node 1: `AgentInputBlock` (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`,
input_default: {"name": "user_text", "title": "Text to process"},
output: "result")
- Node 2: `AITextGeneratorBlock` (input: "prompt" linked from Node 1's "result")
- Node 3: `AgentOutputBlock` (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`,
input_default: {"name": "summary", "title": "Summary"},
input: "value" linked from Node 2's output)

View File

@@ -1,239 +0,0 @@
"""Compaction tracking for SDK-based chat sessions.
Encapsulates the state machine and event emission for context compaction,
both pre-query (history compressed before SDK query) and SDK-internal
(PreCompact hook fires mid-stream).
All compaction-related helpers live here: event builders, message filtering,
persistence, and the ``CompactionTracker`` state machine.
"""
import asyncio
import logging
import uuid
from collections.abc import Callable
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
from ..model import ChatMessage, ChatSession
from ..response_model import (
StreamBaseResponse,
StreamFinishStep,
StreamStartStep,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Event builders (private — use CompactionTracker or compaction_events)
# ---------------------------------------------------------------------------
def _start_events(tool_call_id: str) -> list[StreamBaseResponse]:
"""Build the opening events for a compaction tool call."""
return [
StreamStartStep(),
StreamToolInputStart(toolCallId=tool_call_id, toolName=COMPACTION_TOOL_NAME),
StreamToolInputAvailable(
toolCallId=tool_call_id, toolName=COMPACTION_TOOL_NAME, input={}
),
]
def _end_events(tool_call_id: str, message: str) -> list[StreamBaseResponse]:
"""Build the closing events for a compaction tool call."""
return [
StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=COMPACTION_TOOL_NAME,
output=message,
),
StreamFinishStep(),
]
def _new_tool_call_id() -> str:
return f"compaction-{uuid.uuid4().hex[:12]}"
# ---------------------------------------------------------------------------
# Public event builder
# ---------------------------------------------------------------------------
def emit_compaction(session: ChatSession) -> list[StreamBaseResponse]:
"""Create, persist, and return a self-contained compaction tool call.
Convenience for callers that don't use ``CompactionTracker`` (e.g. the
legacy non-SDK streaming path in ``service.py``).
"""
tc_id = _new_tool_call_id()
evts = compaction_events(COMPACTION_DONE_MSG, tool_call_id=tc_id)
_persist(session, tc_id, COMPACTION_DONE_MSG)
return evts
def compaction_events(
message: str, tool_call_id: str | None = None
) -> list[StreamBaseResponse]:
"""Emit a self-contained compaction tool call (already completed).
When *tool_call_id* is provided it is reused (e.g. for persistence that
must match an already-streamed start event). Otherwise a new ID is
generated.
"""
tc_id = tool_call_id or _new_tool_call_id()
return _start_events(tc_id) + _end_events(tc_id, message)
# ---------------------------------------------------------------------------
# Message filtering
# ---------------------------------------------------------------------------
def filter_compaction_messages(
messages: list[ChatMessage],
) -> list[ChatMessage]:
"""Remove synthetic compaction tool-call messages (UI-only artifacts).
Strips assistant messages whose only tool calls are compaction calls,
and their corresponding tool-result messages.
"""
compaction_ids: set[str] = set()
filtered: list[ChatMessage] = []
for msg in messages:
if msg.role == "assistant" and msg.tool_calls:
for tc in msg.tool_calls:
if tc.get("function", {}).get("name") == COMPACTION_TOOL_NAME:
compaction_ids.add(tc.get("id", ""))
real_calls = [
tc
for tc in msg.tool_calls
if tc.get("function", {}).get("name") != COMPACTION_TOOL_NAME
]
if not real_calls and not msg.content:
continue
if msg.role == "tool" and msg.tool_call_id in compaction_ids:
continue
filtered.append(msg)
return filtered
# ---------------------------------------------------------------------------
# Persistence
# ---------------------------------------------------------------------------
def _persist(session: ChatSession, tool_call_id: str, message: str) -> None:
"""Append compaction tool-call + result to session messages.
Compaction events are synthetic so they bypass the normal adapter
accumulation. This explicitly records them so they survive a page refresh.
"""
session.messages.append(
ChatMessage(
role="assistant",
content="",
tool_calls=[
{
"id": tool_call_id,
"type": "function",
"function": {
"name": COMPACTION_TOOL_NAME,
"arguments": "{}",
},
}
],
)
)
session.messages.append(
ChatMessage(role="tool", content=message, tool_call_id=tool_call_id)
)
# ---------------------------------------------------------------------------
# CompactionTracker — state machine for streaming sessions
# ---------------------------------------------------------------------------
class CompactionTracker:
"""Tracks compaction state and yields UI events.
Two compaction paths:
1. **Pre-query** — history compressed before the SDK query starts.
Call :meth:`emit_pre_query` to yield a self-contained tool call.
2. **SDK-internal** — ``PreCompact`` hook fires mid-stream.
Call :meth:`emit_start_if_ready` on heartbeat ticks and
:meth:`emit_end_if_ready` when a message arrives.
"""
def __init__(self) -> None:
self._compact_start = asyncio.Event()
self._start_emitted = False
self._done = False
self._tool_call_id = ""
@property
def on_compact(self) -> Callable[[], None]:
"""Callback for the PreCompact hook."""
return self._compact_start.set
# ------------------------------------------------------------------
# Pre-query compaction
# ------------------------------------------------------------------
def emit_pre_query(self, session: ChatSession) -> list[StreamBaseResponse]:
"""Emit + persist a self-contained compaction tool call."""
self._done = True
return emit_compaction(session)
# ------------------------------------------------------------------
# SDK-internal compaction
# ------------------------------------------------------------------
def reset_for_query(self) -> None:
"""Reset per-query state before a new SDK query."""
self._done = False
self._start_emitted = False
self._tool_call_id = ""
def emit_start_if_ready(self) -> list[StreamBaseResponse]:
"""If the PreCompact hook fired, emit start events (spinning tool)."""
if self._compact_start.is_set() and not self._start_emitted and not self._done:
self._compact_start.clear()
self._start_emitted = True
self._tool_call_id = _new_tool_call_id()
return _start_events(self._tool_call_id)
return []
async def emit_end_if_ready(self, session: ChatSession) -> list[StreamBaseResponse]:
"""If compaction is in progress, emit end events and persist."""
# Yield so pending hook tasks can set compact_start
await asyncio.sleep(0)
if self._done:
return []
if not self._start_emitted and not self._compact_start.is_set():
return []
if self._start_emitted:
# Close the open spinner
done_events = _end_events(self._tool_call_id, COMPACTION_DONE_MSG)
persist_id = self._tool_call_id
else:
# PreCompact fired but start never emitted — self-contained
persist_id = _new_tool_call_id()
done_events = compaction_events(
COMPACTION_DONE_MSG, tool_call_id=persist_id
)
self._compact_start.clear()
self._start_emitted = False
self._done = True
_persist(session, persist_id, COMPACTION_DONE_MSG)
return done_events

View File

@@ -1,291 +0,0 @@
"""Tests for sdk/compaction.py — event builders, filtering, persistence, and
CompactionTracker state machine."""
import pytest
from backend.copilot.constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.response_model import (
StreamFinishStep,
StreamStartStep,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from backend.copilot.sdk.compaction import (
CompactionTracker,
compaction_events,
emit_compaction,
filter_compaction_messages,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_session() -> ChatSession:
return ChatSession.new(user_id="test-user")
# ---------------------------------------------------------------------------
# compaction_events
# ---------------------------------------------------------------------------
class TestCompactionEvents:
def test_returns_start_and_end_events(self):
evts = compaction_events("done")
assert len(evts) == 5
assert isinstance(evts[0], StreamStartStep)
assert isinstance(evts[1], StreamToolInputStart)
assert isinstance(evts[2], StreamToolInputAvailable)
assert isinstance(evts[3], StreamToolOutputAvailable)
assert isinstance(evts[4], StreamFinishStep)
def test_uses_provided_tool_call_id(self):
evts = compaction_events("msg", tool_call_id="my-id")
tool_start = evts[1]
assert isinstance(tool_start, StreamToolInputStart)
assert tool_start.toolCallId == "my-id"
def test_generates_id_when_not_provided(self):
evts = compaction_events("msg")
tool_start = evts[1]
assert isinstance(tool_start, StreamToolInputStart)
assert tool_start.toolCallId.startswith("compaction-")
def test_tool_name_is_context_compaction(self):
evts = compaction_events("msg")
tool_start = evts[1]
assert isinstance(tool_start, StreamToolInputStart)
assert tool_start.toolName == COMPACTION_TOOL_NAME
# ---------------------------------------------------------------------------
# emit_compaction
# ---------------------------------------------------------------------------
class TestEmitCompaction:
def test_persists_to_session(self):
session = _make_session()
assert len(session.messages) == 0
evts = emit_compaction(session)
assert len(evts) == 5
# Should have appended 2 messages (assistant tool call + tool result)
assert len(session.messages) == 2
assert session.messages[0].role == "assistant"
assert session.messages[0].tool_calls is not None
assert (
session.messages[0].tool_calls[0]["function"]["name"]
== COMPACTION_TOOL_NAME
)
assert session.messages[1].role == "tool"
assert session.messages[1].content == COMPACTION_DONE_MSG
# ---------------------------------------------------------------------------
# filter_compaction_messages
# ---------------------------------------------------------------------------
class TestFilterCompactionMessages:
def test_removes_compaction_tool_calls(self):
msgs = [
ChatMessage(role="user", content="hello"),
ChatMessage(
role="assistant",
content="",
tool_calls=[
{
"id": "comp-1",
"type": "function",
"function": {"name": COMPACTION_TOOL_NAME, "arguments": "{}"},
}
],
),
ChatMessage(
role="tool", content=COMPACTION_DONE_MSG, tool_call_id="comp-1"
),
ChatMessage(role="assistant", content="world"),
]
filtered = filter_compaction_messages(msgs)
assert len(filtered) == 2
assert filtered[0].content == "hello"
assert filtered[1].content == "world"
def test_keeps_non_compaction_tool_calls(self):
msgs = [
ChatMessage(
role="assistant",
content="",
tool_calls=[
{
"id": "real-1",
"type": "function",
"function": {"name": "search", "arguments": "{}"},
}
],
),
ChatMessage(role="tool", content="result", tool_call_id="real-1"),
]
filtered = filter_compaction_messages(msgs)
assert len(filtered) == 2
def test_keeps_assistant_with_content_and_compaction_call(self):
"""If assistant message has both content and a compaction tool call,
the message is kept (has real content)."""
msgs = [
ChatMessage(
role="assistant",
content="I have content",
tool_calls=[
{
"id": "comp-1",
"type": "function",
"function": {"name": COMPACTION_TOOL_NAME, "arguments": "{}"},
}
],
),
]
filtered = filter_compaction_messages(msgs)
assert len(filtered) == 1
def test_empty_list(self):
assert filter_compaction_messages([]) == []
# ---------------------------------------------------------------------------
# CompactionTracker
# ---------------------------------------------------------------------------
class TestCompactionTracker:
def test_on_compact_sets_event(self):
tracker = CompactionTracker()
tracker.on_compact()
assert tracker._compact_start.is_set()
def test_emit_start_if_ready_no_event(self):
tracker = CompactionTracker()
assert tracker.emit_start_if_ready() == []
def test_emit_start_if_ready_with_event(self):
tracker = CompactionTracker()
tracker.on_compact()
evts = tracker.emit_start_if_ready()
assert len(evts) == 3
assert isinstance(evts[0], StreamStartStep)
assert isinstance(evts[1], StreamToolInputStart)
assert isinstance(evts[2], StreamToolInputAvailable)
def test_emit_start_only_once(self):
tracker = CompactionTracker()
tracker.on_compact()
evts1 = tracker.emit_start_if_ready()
assert len(evts1) == 3
# Second call should return empty
evts2 = tracker.emit_start_if_ready()
assert evts2 == []
@pytest.mark.asyncio
async def test_emit_end_after_start(self):
tracker = CompactionTracker()
session = _make_session()
tracker.on_compact()
tracker.emit_start_if_ready()
evts = await tracker.emit_end_if_ready(session)
assert len(evts) == 2
assert isinstance(evts[0], StreamToolOutputAvailable)
assert isinstance(evts[1], StreamFinishStep)
# Should persist
assert len(session.messages) == 2
@pytest.mark.asyncio
async def test_emit_end_without_start_self_contained(self):
"""If PreCompact fired but start was never emitted, emit_end
produces a self-contained compaction event."""
tracker = CompactionTracker()
session = _make_session()
tracker.on_compact()
# Don't call emit_start_if_ready
evts = await tracker.emit_end_if_ready(session)
assert len(evts) == 5 # Full self-contained event
assert isinstance(evts[0], StreamStartStep)
assert len(session.messages) == 2
@pytest.mark.asyncio
async def test_emit_end_no_op_when_done(self):
tracker = CompactionTracker()
session = _make_session()
tracker.on_compact()
tracker.emit_start_if_ready()
await tracker.emit_end_if_ready(session)
# Second call should be no-op
evts = await tracker.emit_end_if_ready(session)
assert evts == []
@pytest.mark.asyncio
async def test_emit_end_no_op_when_nothing_happened(self):
tracker = CompactionTracker()
session = _make_session()
evts = await tracker.emit_end_if_ready(session)
assert evts == []
def test_emit_pre_query(self):
tracker = CompactionTracker()
session = _make_session()
evts = tracker.emit_pre_query(session)
assert len(evts) == 5
assert len(session.messages) == 2
assert tracker._done is True
def test_reset_for_query(self):
tracker = CompactionTracker()
tracker._done = True
tracker._start_emitted = True
tracker._tool_call_id = "old"
tracker.reset_for_query()
assert tracker._done is False
assert tracker._start_emitted is False
assert tracker._tool_call_id == ""
@pytest.mark.asyncio
async def test_pre_query_blocks_sdk_compaction(self):
"""After pre-query compaction, SDK compaction events are suppressed."""
tracker = CompactionTracker()
session = _make_session()
tracker.emit_pre_query(session)
tracker.on_compact()
evts = tracker.emit_start_if_ready()
assert evts == [] # _done blocks it
@pytest.mark.asyncio
async def test_reset_allows_new_compaction(self):
"""After reset_for_query, compaction can fire again."""
tracker = CompactionTracker()
session = _make_session()
tracker.emit_pre_query(session)
tracker.reset_for_query()
tracker.on_compact()
evts = tracker.emit_start_if_ready()
assert len(evts) == 3 # Start events emitted
@pytest.mark.asyncio
async def test_tool_call_id_consistency(self):
"""Start and end events use the same tool_call_id."""
tracker = CompactionTracker()
session = _make_session()
tracker.on_compact()
start_evts = tracker.emit_start_if_ready()
end_evts = await tracker.emit_end_if_ready(session)
start_evt = start_evts[1]
end_evt = end_evts[0]
assert isinstance(start_evt, StreamToolInputStart)
assert isinstance(end_evt, StreamToolOutputAvailable)
assert start_evt.toolCallId == end_evt.toolCallId
# Persisted ID should also match
tool_calls = session.messages[0].tool_calls
assert tool_calls is not None
assert tool_calls[0]["id"] == start_evt.toolCallId

View File

@@ -10,7 +10,6 @@ import asyncio
import logging
import uuid
from collections.abc import AsyncGenerator
from typing import Any
from ..model import ChatSession
from ..response_model import StreamBaseResponse, StreamStart, StreamTextDelta
@@ -27,7 +26,6 @@ async def stream_chat_completion_dummy(
retry_count: int = 0,
session: ChatSession | None = None,
context: dict[str, str] | None = None,
**_kwargs: Any,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Stream dummy chat completion for testing.

View File

@@ -1,352 +0,0 @@
"""MCP file-tool handlers that route to the E2B cloud sandbox.
When E2B is active, these tools replace the SDK built-in Read/Write/Edit/
Glob/Grep so that all file operations share the same ``/home/user``
filesystem as ``bash_exec``.
SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
"""
import itertools
import json
import logging
import os
import shlex
from typing import Any, Callable
from backend.copilot.context import (
E2B_WORKDIR,
get_current_sandbox,
get_sdk_cwd,
is_allowed_local_path,
resolve_sandbox_path,
)
logger = logging.getLogger(__name__)
def _get_sandbox():
return get_current_sandbox()
def _is_allowed_local(path: str) -> bool:
return is_allowed_local_path(path, get_sdk_cwd())
def _mcp(text: str, *, error: bool = False) -> dict[str, Any]:
if error:
text = json.dumps({"error": text, "type": "error"})
return {"content": [{"type": "text", "text": text}], "isError": error}
def _get_sandbox_and_path(
file_path: str,
) -> tuple[Any, str] | dict[str, Any]:
"""Common preamble: get sandbox + resolve path, or return MCP error."""
sandbox = _get_sandbox()
if sandbox is None:
return _mcp("No E2B sandbox available", error=True)
try:
remote = resolve_sandbox_path(file_path)
except ValueError as exc:
return _mcp(str(exc), error=True)
return sandbox, remote
# Tool handlers
async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
"""Read lines from a sandbox file, falling back to the local host for SDK-internal paths."""
file_path: str = args.get("file_path", "")
offset: int = max(0, int(args.get("offset", 0)))
limit: int = max(1, int(args.get("limit", 2000)))
if not file_path:
return _mcp("file_path is required", error=True)
# SDK-internal paths (tool-results, ephemeral working dir) stay on the host.
if _is_allowed_local(file_path):
return _read_local(file_path, offset, limit)
result = _get_sandbox_and_path(file_path)
if isinstance(result, dict):
return result
sandbox, remote = result
try:
raw: bytes = await sandbox.files.read(remote, format="bytes")
content = raw.decode("utf-8", errors="replace")
except Exception as exc:
return _mcp(f"Failed to read {remote}: {exc}", error=True)
lines = content.splitlines(keepends=True)
selected = list(itertools.islice(lines, offset, offset + limit))
numbered = "".join(
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
)
return _mcp(numbered)
async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
"""Write content to a sandbox file, creating parent directories as needed."""
file_path: str = args.get("file_path", "")
content: str = args.get("content", "")
if not file_path:
return _mcp("file_path is required", error=True)
result = _get_sandbox_and_path(file_path)
if isinstance(result, dict):
return result
sandbox, remote = result
try:
parent = os.path.dirname(remote)
if parent and parent != E2B_WORKDIR:
await sandbox.files.make_dir(parent)
await sandbox.files.write(remote, content)
except Exception as exc:
return _mcp(f"Failed to write {remote}: {exc}", error=True)
return _mcp(f"Successfully wrote to {remote}")
async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
"""Replace a substring in a sandbox file, with optional replace-all support."""
file_path: str = args.get("file_path", "")
old_string: str = args.get("old_string", "")
new_string: str = args.get("new_string", "")
replace_all: bool = args.get("replace_all", False)
if not file_path:
return _mcp("file_path is required", error=True)
if not old_string:
return _mcp("old_string is required", error=True)
result = _get_sandbox_and_path(file_path)
if isinstance(result, dict):
return result
sandbox, remote = result
try:
raw: bytes = await sandbox.files.read(remote, format="bytes")
content = raw.decode("utf-8", errors="replace")
except Exception as exc:
return _mcp(f"Failed to read {remote}: {exc}", error=True)
count = content.count(old_string)
if count == 0:
return _mcp(f"old_string not found in {file_path}", error=True)
if count > 1 and not replace_all:
return _mcp(
f"old_string appears {count} times in {file_path}. "
"Use replace_all=true or provide a more unique string.",
error=True,
)
updated = (
content.replace(old_string, new_string)
if replace_all
else content.replace(old_string, new_string, 1)
)
try:
await sandbox.files.write(remote, updated)
except Exception as exc:
return _mcp(f"Failed to write {remote}: {exc}", error=True)
return _mcp(f"Edited {remote} ({count} replacement{'s' if count > 1 else ''})")
async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
"""Find files matching a name pattern inside the sandbox using ``find``."""
pattern: str = args.get("pattern", "")
path: str = args.get("path", "")
if not pattern:
return _mcp("pattern is required", error=True)
sandbox = _get_sandbox()
if sandbox is None:
return _mcp("No E2B sandbox available", error=True)
try:
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
except ValueError as exc:
return _mcp(str(exc), error=True)
cmd = f"find {shlex.quote(search_dir)} -name {shlex.quote(pattern)} -type f 2>/dev/null | head -500"
try:
result = await sandbox.commands.run(cmd, cwd=E2B_WORKDIR, timeout=10)
except Exception as exc:
return _mcp(f"Glob failed: {exc}", error=True)
files = [line for line in (result.stdout or "").strip().splitlines() if line]
return _mcp(json.dumps(files, indent=2))
async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
"""Search file contents by regex inside the sandbox using ``grep -rn``."""
pattern: str = args.get("pattern", "")
path: str = args.get("path", "")
include: str = args.get("include", "")
if not pattern:
return _mcp("pattern is required", error=True)
sandbox = _get_sandbox()
if sandbox is None:
return _mcp("No E2B sandbox available", error=True)
try:
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
except ValueError as exc:
return _mcp(str(exc), error=True)
parts = ["grep", "-rn", "--color=never"]
if include:
parts.extend(["--include", include])
parts.extend([pattern, search_dir])
cmd = " ".join(shlex.quote(p) for p in parts) + " 2>/dev/null | head -200"
try:
result = await sandbox.commands.run(cmd, cwd=E2B_WORKDIR, timeout=15)
except Exception as exc:
return _mcp(f"Grep failed: {exc}", error=True)
output = (result.stdout or "").strip()
return _mcp(output if output else "No matches found.")
# Local read (for SDK-internal paths)
def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
"""Read from the host filesystem (defence-in-depth path check)."""
if not _is_allowed_local(file_path):
return _mcp(f"Path not allowed: {file_path}", error=True)
expanded = os.path.realpath(os.path.expanduser(file_path))
try:
with open(expanded, encoding="utf-8", errors="replace") as fh:
selected = list(itertools.islice(fh, offset, offset + limit))
numbered = "".join(
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
)
return _mcp(numbered)
except FileNotFoundError:
return _mcp(f"File not found: {file_path}", error=True)
except Exception as exc:
return _mcp(f"Error reading {file_path}: {exc}", error=True)
# Tool descriptors (name, description, schema, handler)
E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
(
"read_file",
"Read a file from the cloud sandbox (/home/user). "
"Use offset and limit for large files.",
{
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Path (relative to /home/user, or absolute).",
},
"offset": {
"type": "integer",
"description": "Line to start reading from (0-indexed). Default: 0.",
},
"limit": {
"type": "integer",
"description": "Number of lines to read. Default: 2000.",
},
},
"required": ["file_path"],
},
_handle_read_file,
),
(
"write_file",
"Write or create a file in the cloud sandbox (/home/user). "
"Parent directories are created automatically. "
"To copy a workspace file into the sandbox, use "
"read_workspace_file with save_to_path instead.",
{
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Path (relative to /home/user, or absolute).",
},
"content": {"type": "string", "description": "Content to write."},
},
"required": ["file_path", "content"],
},
_handle_write_file,
),
(
"edit_file",
"Targeted text replacement in a sandbox file. "
"old_string must appear in the file and is replaced with new_string.",
{
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Path (relative to /home/user, or absolute).",
},
"old_string": {"type": "string", "description": "Text to find."},
"new_string": {"type": "string", "description": "Replacement text."},
"replace_all": {
"type": "boolean",
"description": "Replace all occurrences (default: false).",
},
},
"required": ["file_path", "old_string", "new_string"],
},
_handle_edit_file,
),
(
"glob",
"Search for files by name pattern in the cloud sandbox.",
{
"type": "object",
"properties": {
"pattern": {
"type": "string",
"description": "Glob pattern (e.g. *.py).",
},
"path": {
"type": "string",
"description": "Directory to search. Default: /home/user.",
},
},
"required": ["pattern"],
},
_handle_glob,
),
(
"grep",
"Search file contents by regex in the cloud sandbox.",
{
"type": "object",
"properties": {
"pattern": {"type": "string", "description": "Regex pattern."},
"path": {
"type": "string",
"description": "File or directory. Default: /home/user.",
},
"include": {
"type": "string",
"description": "Glob to filter files (e.g. *.py).",
},
},
"required": ["pattern"],
},
_handle_grep,
),
]
E2B_FILE_TOOL_NAMES: list[str] = [name for name, *_ in E2B_FILE_TOOLS]

View File

@@ -1,154 +0,0 @@
"""Tests for E2B file-tool path validation and local read safety.
Pure unit tests with no external dependencies (no E2B, no sandbox).
"""
import os
import pytest
from backend.copilot.context import _current_project_dir
from .e2b_file_tools import _read_local, resolve_sandbox_path
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# ---------------------------------------------------------------------------
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
# ---------------------------------------------------------------------------
class TestResolveSandboxPath:
def test_relative_path_resolved(self):
assert resolve_sandbox_path("src/main.py") == "/home/user/src/main.py"
def test_absolute_within_sandbox(self):
assert resolve_sandbox_path("/home/user/file.txt") == "/home/user/file.txt"
def test_workdir_itself(self):
assert resolve_sandbox_path("/home/user") == "/home/user"
def test_relative_dotslash(self):
assert resolve_sandbox_path("./README.md") == "/home/user/README.md"
def test_traversal_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("../../etc/passwd")
def test_absolute_traversal_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("/home/user/../../etc/passwd")
def test_absolute_outside_sandbox_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("/etc/passwd")
def test_root_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("/")
def test_home_other_user_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("/home/other/file.txt")
def test_deep_nested_allowed(self):
assert resolve_sandbox_path("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
def test_trailing_slash_normalised(self):
assert resolve_sandbox_path("src/") == "/home/user/src"
def test_double_dots_within_sandbox_ok(self):
"""Path that resolves back within /home/user is allowed."""
assert resolve_sandbox_path("a/b/../c.txt") == "/home/user/a/c.txt"
# ---------------------------------------------------------------------------
# _read_local — host filesystem reads with allowlist enforcement
#
# In E2B mode, _read_local only allows tool-results paths (via
# is_allowed_local_path without sdk_cwd). Regular files live on the
# sandbox, not the host.
# ---------------------------------------------------------------------------
class TestReadLocal:
def _make_tool_results_file(self, encoded: str, filename: str, content: str) -> str:
"""Create a tool-results file and return its path."""
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
os.makedirs(tool_results_dir, exist_ok=True)
filepath = os.path.join(tool_results_dir, filename)
with open(filepath, "w") as f:
f.write(content)
return filepath
def test_read_tool_results_file(self):
"""Reading a tool-results file should succeed."""
encoded = "-tmp-copilot-e2b-test-read"
filepath = self._make_tool_results_file(
encoded, "result.txt", "line 1\nline 2\nline 3\n"
)
token = _current_project_dir.set(encoded)
try:
result = _read_local(filepath, offset=0, limit=2000)
assert result["isError"] is False
assert "line 1" in result["content"][0]["text"]
assert "line 2" in result["content"][0]["text"]
finally:
_current_project_dir.reset(token)
os.unlink(filepath)
def test_read_disallowed_path_blocked(self):
"""Reading /etc/passwd should be blocked by the allowlist."""
result = _read_local("/etc/passwd", offset=0, limit=10)
assert result["isError"] is True
assert "not allowed" in result["content"][0]["text"].lower()
def test_read_nonexistent_tool_results(self):
"""A tool-results path that doesn't exist returns FileNotFoundError."""
encoded = "-tmp-copilot-e2b-test-nofile"
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
os.makedirs(tool_results_dir, exist_ok=True)
filepath = os.path.join(tool_results_dir, "nonexistent.txt")
token = _current_project_dir.set(encoded)
try:
result = _read_local(filepath, offset=0, limit=10)
assert result["isError"] is True
assert "not found" in result["content"][0]["text"].lower()
finally:
_current_project_dir.reset(token)
os.rmdir(tool_results_dir)
def test_read_traversal_path_blocked(self):
"""A traversal attempt that escapes allowed directories is blocked."""
result = _read_local("/tmp/copilot-abc/../../etc/shadow", offset=0, limit=10)
assert result["isError"] is True
assert "not allowed" in result["content"][0]["text"].lower()
def test_read_arbitrary_host_path_blocked(self):
"""Arbitrary host paths are blocked even if they exist."""
result = _read_local("/proc/self/environ", offset=0, limit=10)
assert result["isError"] is True
def test_read_with_offset_and_limit(self):
"""Offset and limit should control which lines are returned."""
encoded = "-tmp-copilot-e2b-test-offset"
content = "".join(f"line {i}\n" for i in range(10))
filepath = self._make_tool_results_file(encoded, "lines.txt", content)
token = _current_project_dir.set(encoded)
try:
result = _read_local(filepath, offset=3, limit=2)
assert result["isError"] is False
text = result["content"][0]["text"]
assert "line 3" in text
assert "line 4" in text
assert "line 2" not in text
assert "line 5" not in text
finally:
_current_project_dir.reset(token)
os.unlink(filepath)
def test_read_without_project_dir_blocks_all(self):
"""Without _current_project_dir set, all paths are blocked."""
result = _read_local("/tmp/anything.txt", offset=0, limit=10)
assert result["isError"] is True

View File

@@ -1,281 +0,0 @@
"""File reference protocol for tool call inputs.
Allows the LLM to pass a file reference instead of embedding large content
inline. The processor expands ``@@agptfile:<uri>[<start>-<end>]`` tokens in tool
arguments before the tool is executed.
Protocol
--------
@@agptfile:<uri>[<start>-<end>]
``<uri>`` (required)
- ``workspace://<file_id>`` — workspace file by ID
- ``workspace://<file_id>#<mime>`` — same, MIME hint is ignored for reads
- ``workspace:///<path>`` — workspace file by virtual path
- ``/absolute/local/path`` — ephemeral or sdk_cwd file (validated by
:func:`~backend.copilot.sdk.tool_adapter.is_allowed_local_path`)
- Any absolute path that resolves inside the E2B sandbox
(``/home/user/...``) when a sandbox is active
``[<start>-<end>]`` (optional)
Line range, 1-indexed inclusive. Examples: ``[1-100]``, ``[50-200]``.
Omit to read the entire file.
Examples
--------
@@agptfile:workspace://abc123
@@agptfile:workspace://abc123[10-50]
@@agptfile:workspace:///reports/q1.md
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
@@agptfile:/home/user/script.sh
"""
import itertools
import logging
import os
import re
from dataclasses import dataclass
from typing import Any
from backend.copilot.context import (
get_current_sandbox,
get_sdk_cwd,
is_allowed_local_path,
resolve_sandbox_path,
)
from backend.copilot.model import ChatSession
from backend.copilot.tools.workspace_files import get_manager
from backend.util.file import parse_workspace_uri
class FileRefExpansionError(Exception):
"""Raised when a ``@@agptfile:`` reference in tool call args fails to resolve.
Separating this from inline substitution lets callers (e.g. the MCP tool
wrapper) block tool execution and surface a helpful error to the model
rather than passing an ``[file-ref error: …]`` string as actual input.
"""
logger = logging.getLogger(__name__)
FILE_REF_PREFIX = "@@agptfile:"
# Matches: @@agptfile:<uri>[start-end]?
# Group 1 URI; must start with '/' (absolute path) or 'workspace://'
# Group 2 start line (optional)
# Group 3 end line (optional)
_FILE_REF_RE = re.compile(
re.escape(FILE_REF_PREFIX) + r"((?:workspace://|/)[^\[\s]*)(?:\[(\d+)-(\d+)\])?"
)
# Maximum characters returned for a single file reference expansion.
_MAX_EXPAND_CHARS = 200_000
# Maximum total characters across all @@agptfile: expansions in one string.
_MAX_TOTAL_EXPAND_CHARS = 1_000_000
@dataclass
class FileRef:
uri: str
start_line: int | None # 1-indexed, inclusive
end_line: int | None # 1-indexed, inclusive
def parse_file_ref(text: str) -> FileRef | None:
"""Return a :class:`FileRef` if *text* is a bare file reference token.
A "bare token" means the entire string matches the ``@@agptfile:...`` pattern
(after stripping whitespace). Use :func:`expand_file_refs_in_string` to
expand references embedded in larger strings.
"""
m = _FILE_REF_RE.fullmatch(text.strip())
if not m:
return None
start = int(m.group(2)) if m.group(2) else None
end = int(m.group(3)) if m.group(3) else None
if start is not None and start < 1:
return None
if end is not None and end < 1:
return None
if start is not None and end is not None and end < start:
return None
return FileRef(uri=m.group(1), start_line=start, end_line=end)
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
"""Slice *text* to the requested 1-indexed line range (inclusive)."""
if start is None and end is None:
return text
lines = text.splitlines(keepends=True)
s = (start - 1) if start is not None else 0
e = end if end is not None else len(lines)
selected = list(itertools.islice(lines, s, e))
return "".join(selected)
async def read_file_bytes(
uri: str,
user_id: str | None,
session: ChatSession,
) -> bytes:
"""Resolve *uri* to raw bytes using workspace, local, or E2B path logic.
Raises :class:`ValueError` if the URI cannot be resolved.
"""
# Strip MIME fragment (e.g. workspace://id#mime) before dispatching.
plain = uri.split("#")[0] if uri.startswith("workspace://") else uri
if plain.startswith("workspace://"):
if not user_id:
raise ValueError("workspace:// file references require authentication")
manager = await get_manager(user_id, session.session_id)
ws = parse_workspace_uri(plain)
try:
return await (
manager.read_file(ws.file_ref)
if ws.is_path
else manager.read_file_by_id(ws.file_ref)
)
except FileNotFoundError:
raise ValueError(f"File not found: {plain}")
except Exception as exc:
raise ValueError(f"Failed to read {plain}: {exc}") from exc
if is_allowed_local_path(plain, get_sdk_cwd()):
resolved = os.path.realpath(os.path.expanduser(plain))
try:
with open(resolved, "rb") as fh:
return fh.read()
except FileNotFoundError:
raise ValueError(f"File not found: {plain}")
except Exception as exc:
raise ValueError(f"Failed to read {plain}: {exc}") from exc
sandbox = get_current_sandbox()
if sandbox is not None:
try:
remote = resolve_sandbox_path(plain)
except ValueError as exc:
raise ValueError(
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
) from exc
try:
return bytes(await sandbox.files.read(remote, format="bytes"))
except Exception as exc:
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
raise ValueError(
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
)
async def resolve_file_ref(
ref: FileRef,
user_id: str | None,
session: ChatSession,
) -> str:
"""Resolve a :class:`FileRef` to its text content."""
raw = await read_file_bytes(ref.uri, user_id, session)
return _apply_line_range(
raw.decode("utf-8", errors="replace"), ref.start_line, ref.end_line
)
async def expand_file_refs_in_string(
text: str,
user_id: str | None,
session: "ChatSession",
*,
raise_on_error: bool = False,
) -> str:
"""Expand all ``@@agptfile:...`` tokens in *text*, returning the substituted string.
Non-reference text is passed through unchanged.
If *raise_on_error* is ``False`` (default), expansion errors are surfaced
inline as ``[file-ref error: <message>]`` — useful for display/log contexts
where partial expansion is acceptable.
If *raise_on_error* is ``True``, any resolution failure raises
:class:`FileRefExpansionError` immediately so the caller can block the
operation and surface a clean error to the model.
"""
if FILE_REF_PREFIX not in text:
return text
result: list[str] = []
last_end = 0
total_chars = 0
for m in _FILE_REF_RE.finditer(text):
result.append(text[last_end : m.start()])
start = int(m.group(2)) if m.group(2) else None
end = int(m.group(3)) if m.group(3) else None
if (start is not None and start < 1) or (end is not None and end < 1):
msg = f"line numbers must be >= 1: {m.group(0)}"
if raise_on_error:
raise FileRefExpansionError(msg)
result.append(f"[file-ref error: {msg}]")
last_end = m.end()
continue
if start is not None and end is not None and end < start:
msg = f"end line must be >= start line: {m.group(0)}"
if raise_on_error:
raise FileRefExpansionError(msg)
result.append(f"[file-ref error: {msg}]")
last_end = m.end()
continue
ref = FileRef(uri=m.group(1), start_line=start, end_line=end)
try:
content = await resolve_file_ref(ref, user_id, session)
if len(content) > _MAX_EXPAND_CHARS:
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
if remaining <= 0:
content = "[file-ref budget exhausted: total expansion limit reached]"
elif len(content) > remaining:
content = content[:remaining] + "\n... [total budget exhausted]"
total_chars += len(content)
result.append(content)
except ValueError as exc:
logger.warning("file-ref expansion failed for %r: %s", m.group(0), exc)
if raise_on_error:
raise FileRefExpansionError(str(exc)) from exc
result.append(f"[file-ref error: {exc}]")
last_end = m.end()
result.append(text[last_end:])
return "".join(result)
async def expand_file_refs_in_args(
args: dict[str, Any],
user_id: str | None,
session: "ChatSession",
) -> dict[str, Any]:
"""Recursively expand ``@@agptfile:...`` references in tool call arguments.
String values are expanded in-place. Nested dicts and lists are
traversed. Non-string scalars are returned unchanged.
Raises :class:`FileRefExpansionError` if any reference fails to resolve,
so the tool is *not* executed with an error string as its input. The
caller (the MCP tool wrapper) should convert this into an MCP error
response that lets the model correct the reference before retrying.
"""
if not args:
return args
async def _expand(value: Any) -> Any:
if isinstance(value, str):
return await expand_file_refs_in_string(
value, user_id, session, raise_on_error=True
)
if isinstance(value, dict):
return {k: await _expand(v) for k, v in value.items()}
if isinstance(value, list):
return [await _expand(item) for item in value]
return value
return {k: await _expand(v) for k, v in args.items()}

View File

@@ -1,328 +0,0 @@
"""Integration tests for @@agptfile: reference expansion in tool calls.
These tests verify the end-to-end behaviour of the file reference protocol:
- Parsing @@agptfile: tokens from tool arguments
- Resolving local-filesystem paths (sdk_cwd / ephemeral)
- Expanding references inside the tool-call pipeline (_execute_tool_sync)
- The extended Read tool handler (workspace:// pass-through via session context)
No real LLM or database is required; workspace reads are stubbed where needed.
"""
from __future__ import annotations
import os
import tempfile
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.sdk.file_ref import (
FileRef,
expand_file_refs_in_args,
expand_file_refs_in_string,
read_file_bytes,
resolve_file_ref,
)
from backend.copilot.sdk.tool_adapter import _read_file_handler
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_session(session_id: str = "integ-sess") -> MagicMock:
s = MagicMock()
s.session_id = session_id
return s
# ---------------------------------------------------------------------------
# Local-file resolution (sdk_cwd)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_resolve_file_ref_local_path():
"""resolve_file_ref reads a real local file when it's within sdk_cwd."""
with tempfile.TemporaryDirectory() as sdk_cwd:
# Write a test file inside sdk_cwd
test_file = os.path.join(sdk_cwd, "hello.txt")
with open(test_file, "w") as f:
f.write("line1\nline2\nline3\n")
session = _make_session()
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
ref = FileRef(uri=test_file, start_line=None, end_line=None)
content = await resolve_file_ref(ref, user_id="u1", session=session)
assert content == "line1\nline2\nline3\n"
@pytest.mark.asyncio
async def test_resolve_file_ref_local_path_with_line_range():
"""resolve_file_ref respects line ranges for local files."""
with tempfile.TemporaryDirectory() as sdk_cwd:
test_file = os.path.join(sdk_cwd, "multi.txt")
lines = [f"line{i}\n" for i in range(1, 11)] # line1 … line10
with open(test_file, "w") as f:
f.writelines(lines)
session = _make_session()
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
ref = FileRef(uri=test_file, start_line=3, end_line=5)
content = await resolve_file_ref(ref, user_id="u1", session=session)
assert content == "line3\nline4\nline5\n"
@pytest.mark.asyncio
async def test_resolve_file_ref_rejects_path_outside_sdk_cwd():
"""resolve_file_ref raises ValueError for paths outside sdk_cwd."""
with tempfile.TemporaryDirectory() as sdk_cwd:
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var:
mock_cwd_var.get.return_value = sdk_cwd
mock_sandbox_var.get.return_value = None
ref = FileRef(uri="/etc/passwd", start_line=None, end_line=None)
with pytest.raises(ValueError, match="not allowed"):
await resolve_file_ref(ref, user_id="u1", session=_make_session())
# ---------------------------------------------------------------------------
# expand_file_refs_in_string — integration with real files
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_string_with_real_file():
"""expand_file_refs_in_string replaces @@agptfile: token with actual content."""
with tempfile.TemporaryDirectory() as sdk_cwd:
test_file = os.path.join(sdk_cwd, "data.txt")
with open(test_file, "w") as f:
f.write("hello world\n")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_string(
f"Content: @@agptfile:{test_file}",
user_id="u1",
session=_make_session(),
)
assert result == "Content: hello world\n"
@pytest.mark.asyncio
async def test_expand_string_missing_file_is_surfaced_inline():
"""Missing file ref yields [file-ref error: …] inline rather than raising."""
with tempfile.TemporaryDirectory() as sdk_cwd:
missing = os.path.join(sdk_cwd, "does_not_exist.txt")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_string(
f"@@agptfile:{missing}",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "not found" in result.lower() or "not allowed" in result.lower()
# ---------------------------------------------------------------------------
# expand_file_refs_in_args — dict traversal with real files
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_args_replaces_file_ref_in_nested_dict():
"""Nested @@agptfile: references in args are fully expanded."""
with tempfile.TemporaryDirectory() as sdk_cwd:
file_a = os.path.join(sdk_cwd, "a.txt")
file_b = os.path.join(sdk_cwd, "b.txt")
with open(file_a, "w") as f:
f.write("AAA")
with open(file_b, "w") as f:
f.write("BBB")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{
"outer": {
"content_a": f"@@agptfile:{file_a}",
"content_b": f"start @@agptfile:{file_b} end",
},
"count": 42,
},
user_id="u1",
session=_make_session(),
)
assert result["outer"]["content_a"] == "AAA"
assert result["outer"]["content_b"] == "start BBB end"
assert result["count"] == 42
# ---------------------------------------------------------------------------
# _read_file_handler — extended to accept workspace:// and local paths
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_file_handler_local_file():
"""_read_file_handler reads a local file when it's within sdk_cwd."""
with tempfile.TemporaryDirectory() as sdk_cwd:
test_file = os.path.join(sdk_cwd, "read_test.txt")
lines = [f"L{i}\n" for i in range(1, 6)]
with open(test_file, "w") as f:
f.writelines(lines)
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj_var, patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", _make_session()),
):
mock_cwd_var.get.return_value = sdk_cwd
mock_proj_var.get.return_value = ""
result = await _read_file_handler(
{"file_path": test_file, "offset": 0, "limit": 5}
)
assert not result["isError"]
text = result["content"][0]["text"]
assert "L1" in text
assert "L5" in text
@pytest.mark.asyncio
async def test_read_file_handler_workspace_uri():
"""_read_file_handler handles workspace:// URIs via the workspace manager."""
mock_session = _make_session()
mock_manager = AsyncMock()
mock_manager.read_file_by_id.return_value = b"workspace file content\nline two\n"
with patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", mock_session),
), patch(
"backend.copilot.sdk.file_ref.get_manager",
new=AsyncMock(return_value=mock_manager),
):
result = await _read_file_handler(
{"file_path": "workspace://file-id-abc", "offset": 0, "limit": 10}
)
assert not result["isError"], result["content"][0]["text"]
text = result["content"][0]["text"]
assert "workspace file content" in text
assert "line two" in text
@pytest.mark.asyncio
async def test_read_file_handler_workspace_uri_no_session():
"""_read_file_handler returns error when workspace:// is used without session."""
with patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=(None, None),
):
result = await _read_file_handler({"file_path": "workspace://some-id"})
assert result["isError"]
assert "session" in result["content"][0]["text"].lower()
@pytest.mark.asyncio
async def test_read_file_handler_access_denied():
"""_read_file_handler rejects paths outside allowed locations."""
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox, patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", _make_session()),
):
mock_cwd.get.return_value = "/tmp/safe-dir"
mock_sandbox.get.return_value = None
result = await _read_file_handler({"file_path": "/etc/passwd"})
assert result["isError"]
assert "not allowed" in result["content"][0]["text"].lower()
# ---------------------------------------------------------------------------
# read_file_bytes — workspace:///path (virtual path) and E2B sandbox branch
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_file_bytes_workspace_virtual_path():
"""workspace:///path resolves via manager.read_file (is_path=True path)."""
session = _make_session()
mock_manager = AsyncMock()
mock_manager.read_file.return_value = b"virtual path content"
with patch(
"backend.copilot.sdk.file_ref.get_manager",
new=AsyncMock(return_value=mock_manager),
):
result = await read_file_bytes("workspace:///reports/q1.md", "user-1", session)
assert result == b"virtual path content"
mock_manager.read_file.assert_awaited_once_with("/reports/q1.md")
@pytest.mark.asyncio
async def test_read_file_bytes_e2b_sandbox_branch():
"""read_file_bytes reads from the E2B sandbox when a sandbox is active."""
session = _make_session()
mock_sandbox = AsyncMock()
mock_sandbox.files.read.return_value = bytearray(b"sandbox content")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj:
mock_cwd.get.return_value = ""
mock_sandbox_var.get.return_value = mock_sandbox
mock_proj.get.return_value = ""
result = await read_file_bytes("/home/user/script.sh", None, session)
assert result == b"sandbox content"
mock_sandbox.files.read.assert_awaited_once_with(
"/home/user/script.sh", format="bytes"
)
@pytest.mark.asyncio
async def test_read_file_bytes_e2b_path_escapes_sandbox_raises():
"""read_file_bytes raises ValueError for paths that escape the sandbox root."""
session = _make_session()
mock_sandbox = AsyncMock()
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj:
mock_cwd.get.return_value = ""
mock_sandbox_var.get.return_value = mock_sandbox
mock_proj.get.return_value = ""
with pytest.raises(ValueError, match="not allowed"):
await read_file_bytes("/etc/passwd", None, session)

View File

@@ -1,382 +0,0 @@
"""Tests for the @@agptfile: reference protocol (file_ref.py)."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.sdk.file_ref import (
_MAX_EXPAND_CHARS,
FileRef,
FileRefExpansionError,
_apply_line_range,
expand_file_refs_in_args,
expand_file_refs_in_string,
parse_file_ref,
)
# ---------------------------------------------------------------------------
# parse_file_ref
# ---------------------------------------------------------------------------
def test_parse_file_ref_workspace_id():
ref = parse_file_ref("@@agptfile:workspace://abc123")
assert ref == FileRef(uri="workspace://abc123", start_line=None, end_line=None)
def test_parse_file_ref_workspace_id_with_mime():
ref = parse_file_ref("@@agptfile:workspace://abc123#text/plain")
assert ref is not None
assert ref.uri == "workspace://abc123#text/plain"
assert ref.start_line is None
def test_parse_file_ref_workspace_path():
ref = parse_file_ref("@@agptfile:workspace:///reports/q1.md")
assert ref is not None
assert ref.uri == "workspace:///reports/q1.md"
def test_parse_file_ref_with_line_range():
ref = parse_file_ref("@@agptfile:workspace://abc123[10-50]")
assert ref == FileRef(uri="workspace://abc123", start_line=10, end_line=50)
def test_parse_file_ref_local_path():
ref = parse_file_ref("@@agptfile:/tmp/copilot-session/output.py[1-100]")
assert ref is not None
assert ref.uri == "/tmp/copilot-session/output.py"
assert ref.start_line == 1
assert ref.end_line == 100
def test_parse_file_ref_no_match():
assert parse_file_ref("just a normal string") is None
assert parse_file_ref("workspace://abc123") is None # missing @@agptfile: prefix
assert (
parse_file_ref("@@agptfile:workspace://abc123 extra") is None
) # not full match
def test_parse_file_ref_strips_whitespace():
ref = parse_file_ref(" @@agptfile:workspace://abc123 ")
assert ref is not None
assert ref.uri == "workspace://abc123"
def test_parse_file_ref_invalid_range_zero_start():
assert parse_file_ref("@@agptfile:workspace://abc123[0-5]") is None
def test_parse_file_ref_invalid_range_end_less_than_start():
assert parse_file_ref("@@agptfile:workspace://abc123[10-5]") is None
def test_parse_file_ref_invalid_range_zero_end():
assert parse_file_ref("@@agptfile:workspace://abc123[1-0]") is None
# ---------------------------------------------------------------------------
# _apply_line_range
# ---------------------------------------------------------------------------
TEXT = "line1\nline2\nline3\nline4\nline5\n"
def test_apply_line_range_no_range():
assert _apply_line_range(TEXT, None, None) == TEXT
def test_apply_line_range_start_only():
result = _apply_line_range(TEXT, 3, None)
assert result == "line3\nline4\nline5\n"
def test_apply_line_range_full():
result = _apply_line_range(TEXT, 2, 4)
assert result == "line2\nline3\nline4\n"
def test_apply_line_range_single_line():
result = _apply_line_range(TEXT, 2, 2)
assert result == "line2\n"
def test_apply_line_range_beyond_eof():
result = _apply_line_range(TEXT, 4, 999)
assert result == "line4\nline5\n"
# ---------------------------------------------------------------------------
# expand_file_refs_in_string
# ---------------------------------------------------------------------------
def _make_session(session_id: str = "sess-1") -> MagicMock:
session = MagicMock()
session.session_id = session_id
return session
async def _resolve_always(ref: FileRef, _user_id: str | None, _session: object) -> str:
"""Stub resolver that returns the URI and range as a descriptive string."""
if ref.start_line is not None:
return f"content:{ref.uri}[{ref.start_line}-{ref.end_line}]"
return f"content:{ref.uri}"
@pytest.mark.asyncio
async def test_expand_no_refs():
result = await expand_file_refs_in_string(
"no references here", user_id="u1", session=_make_session()
)
assert result == "no references here"
@pytest.mark.asyncio
async def test_expand_single_ref():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://abc123",
user_id="u1",
session=_make_session(),
)
assert result == "content:workspace://abc123"
@pytest.mark.asyncio
async def test_expand_ref_with_range():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://abc123[10-50]",
user_id="u1",
session=_make_session(),
)
assert result == "content:workspace://abc123[10-50]"
@pytest.mark.asyncio
async def test_expand_ref_embedded_in_text():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_string(
"Here is the file: @@agptfile:workspace://abc123 — done",
user_id="u1",
session=_make_session(),
)
assert result == "Here is the file: content:workspace://abc123 — done"
@pytest.mark.asyncio
async def test_expand_multiple_refs():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://file1 and @@agptfile:workspace://file2[1-5]",
user_id="u1",
session=_make_session(),
)
assert result == "content:workspace://file1 and content:workspace://file2[1-5]"
@pytest.mark.asyncio
async def test_expand_invalid_range_zero_start_surfaces_inline():
"""expand_file_refs_in_string surfaces [file-ref error: ...] for zero-start ranges."""
result = await expand_file_refs_in_string(
"@@agptfile:workspace://abc123[0-5]",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "line numbers must be >= 1" in result
@pytest.mark.asyncio
async def test_expand_invalid_range_end_less_than_start_surfaces_inline():
"""expand_file_refs_in_string surfaces [file-ref error: ...] when end < start."""
result = await expand_file_refs_in_string(
"prefix @@agptfile:workspace://abc123[10-5] suffix",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "end line must be >= start line" in result
assert "prefix" in result
assert "suffix" in result
@pytest.mark.asyncio
async def test_expand_ref_error_surfaces_inline():
async def _raise(*args, **kwargs): # noqa: ARG001
raise ValueError("file not found")
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_raise),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://bad",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "file not found" in result
# ---------------------------------------------------------------------------
# expand_file_refs_in_args
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_args_flat():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_args(
{"content": "@@agptfile:workspace://abc123", "other": 42},
user_id="u1",
session=_make_session(),
)
assert result["content"] == "content:workspace://abc123"
assert result["other"] == 42
@pytest.mark.asyncio
async def test_expand_args_nested_dict():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_args(
{"outer": {"inner": "@@agptfile:workspace://nested"}},
user_id="u1",
session=_make_session(),
)
assert result["outer"]["inner"] == "content:workspace://nested"
@pytest.mark.asyncio
async def test_expand_args_list():
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_always),
):
result = await expand_file_refs_in_args(
{
"items": [
"@@agptfile:workspace://a",
"plain",
"@@agptfile:workspace://b[1-3]",
]
},
user_id="u1",
session=_make_session(),
)
assert result["items"] == [
"content:workspace://a",
"plain",
"content:workspace://b[1-3]",
]
@pytest.mark.asyncio
async def test_expand_args_empty():
result = await expand_file_refs_in_args({}, user_id="u1", session=_make_session())
assert result == {}
@pytest.mark.asyncio
async def test_expand_args_no_refs():
result = await expand_file_refs_in_args(
{"key": "no refs here", "num": 1},
user_id="u1",
session=_make_session(),
)
assert result == {"key": "no refs here", "num": 1}
@pytest.mark.asyncio
async def test_expand_args_raises_on_file_ref_error():
"""expand_file_refs_in_args raises FileRefExpansionError instead of passing
the inline error string to the tool, blocking tool execution."""
async def _raise(*args, **kwargs): # noqa: ARG001
raise ValueError("path does not exist")
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_raise),
):
with pytest.raises(FileRefExpansionError) as exc_info:
await expand_file_refs_in_args(
{"prompt": "@@agptfile:/home/user/missing.txt"},
user_id="u1",
session=_make_session(),
)
assert "path does not exist" in str(exc_info.value)
# ---------------------------------------------------------------------------
# Per-file truncation and aggregate budget
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_per_file_truncation():
"""Content exceeding _MAX_EXPAND_CHARS is truncated with a marker."""
oversized = "x" * (_MAX_EXPAND_CHARS + 100)
async def _resolve_oversized(ref: FileRef, _uid: str | None, _s: object) -> str:
return oversized
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_oversized),
):
result = await expand_file_refs_in_string(
"@@agptfile:workspace://big-file",
user_id="u1",
session=_make_session(),
)
assert len(result) <= _MAX_EXPAND_CHARS + len("\n... [truncated]") + 10
assert "[truncated]" in result
@pytest.mark.asyncio
async def test_expand_aggregate_budget_exhausted():
"""When the aggregate budget is exhausted, later refs get the budget message."""
# Each file returns just under 300K; after ~4 files the 1M budget is used.
big_chunk = "y" * 300_000
async def _resolve_big(ref: FileRef, _uid: str | None, _s: object) -> str:
return big_chunk
with patch(
"backend.copilot.sdk.file_ref.resolve_file_ref",
new=AsyncMock(side_effect=_resolve_big),
):
# 5 refs @ 300K each = 1.5M → last ref(s) should hit the aggregate limit
refs = " ".join(f"@@agptfile:workspace://f{i}" for i in range(5))
result = await expand_file_refs_in_string(
refs,
user_id="u1",
session=_make_session(),
)
assert "budget exhausted" in result

View File

@@ -1,28 +0,0 @@
## MCP Tool Guide
### Workflow
`run_mcp_tool` follows a two-step pattern:
1. **Discover** — call with only `server_url` to list available tools on the server.
2. **Execute** — call again with `server_url`, `tool_name`, and `tool_arguments` to run a tool.
### Known hosted MCP servers
Use these URLs directly without asking the user:
| Service | URL |
|---|---|
| Notion | `https://mcp.notion.com/mcp` |
| Linear | `https://mcp.linear.app/mcp` |
| Stripe | `https://mcp.stripe.com` |
| Intercom | `https://mcp.intercom.com/mcp` |
| Cloudflare | `https://mcp.cloudflare.com/mcp` |
| Atlassian / Jira | `https://mcp.atlassian.com/mcp` |
For other services, search the MCP registry at https://registry.modelcontextprotocol.io/.
### Authentication
If the server requires credentials, a `SetupRequirementsResponse` is returned with an OAuth
login prompt. Once the user completes the flow and confirms, retry the same call immediately.

View File

@@ -1,172 +0,0 @@
"""Tests for OTEL tracing setup in the SDK copilot path."""
import os
from unittest.mock import MagicMock, patch
class TestSetupLangfuseOtel:
"""Tests for _setup_langfuse_otel()."""
def test_noop_when_langfuse_not_configured(self):
"""No env vars should be set when Langfuse credentials are missing."""
with patch(
"backend.copilot.sdk.service._is_langfuse_configured", return_value=False
):
from backend.copilot.sdk.service import _setup_langfuse_otel
# Clear any previously set env vars
env_keys = [
"LANGSMITH_OTEL_ENABLED",
"LANGSMITH_OTEL_ONLY",
"LANGSMITH_TRACING",
"OTEL_EXPORTER_OTLP_ENDPOINT",
"OTEL_EXPORTER_OTLP_HEADERS",
]
saved = {k: os.environ.pop(k, None) for k in env_keys}
try:
_setup_langfuse_otel()
for key in env_keys:
assert key not in os.environ, f"{key} should not be set"
finally:
for k, v in saved.items():
if v is not None:
os.environ[k] = v
def test_sets_env_vars_when_langfuse_configured(self):
"""OTEL env vars should be set when Langfuse credentials exist."""
mock_settings = MagicMock()
mock_settings.secrets.langfuse_public_key = "pk-test-123"
mock_settings.secrets.langfuse_secret_key = "sk-test-456"
mock_settings.secrets.langfuse_host = "https://langfuse.example.com"
mock_settings.secrets.langfuse_tracing_environment = "test"
with (
patch(
"backend.copilot.sdk.service._is_langfuse_configured",
return_value=True,
),
patch("backend.copilot.sdk.service.Settings", return_value=mock_settings),
patch(
"backend.copilot.sdk.service.configure_claude_agent_sdk",
return_value=True,
) as mock_configure,
):
from backend.copilot.sdk.service import _setup_langfuse_otel
# Clear env vars so setdefault works
env_keys = [
"LANGSMITH_OTEL_ENABLED",
"LANGSMITH_OTEL_ONLY",
"LANGSMITH_TRACING",
"OTEL_EXPORTER_OTLP_ENDPOINT",
"OTEL_EXPORTER_OTLP_HEADERS",
"OTEL_RESOURCE_ATTRIBUTES",
]
saved = {k: os.environ.pop(k, None) for k in env_keys}
try:
_setup_langfuse_otel()
assert os.environ["LANGSMITH_OTEL_ENABLED"] == "true"
assert os.environ["LANGSMITH_OTEL_ONLY"] == "true"
assert os.environ["LANGSMITH_TRACING"] == "true"
assert (
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"]
== "https://langfuse.example.com/api/public/otel"
)
assert "Authorization=Basic" in os.environ["OTEL_EXPORTER_OTLP_HEADERS"]
assert (
os.environ["OTEL_RESOURCE_ATTRIBUTES"]
== "langfuse.environment=test"
)
mock_configure.assert_called_once_with(tags=["sdk"])
finally:
for k, v in saved.items():
if v is not None:
os.environ[k] = v
elif k in os.environ:
del os.environ[k]
def test_existing_env_vars_not_overwritten(self):
"""Explicit env-var overrides should not be clobbered."""
mock_settings = MagicMock()
mock_settings.secrets.langfuse_public_key = "pk-test"
mock_settings.secrets.langfuse_secret_key = "sk-test"
mock_settings.secrets.langfuse_host = "https://langfuse.example.com"
with (
patch(
"backend.copilot.sdk.service._is_langfuse_configured",
return_value=True,
),
patch("backend.copilot.sdk.service.Settings", return_value=mock_settings),
patch(
"backend.copilot.sdk.service.configure_claude_agent_sdk",
return_value=True,
),
):
from backend.copilot.sdk.service import _setup_langfuse_otel
saved = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
try:
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = "https://custom.endpoint/v1"
_setup_langfuse_otel()
assert (
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"]
== "https://custom.endpoint/v1"
)
finally:
if saved is not None:
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = saved
elif "OTEL_EXPORTER_OTLP_ENDPOINT" in os.environ:
del os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"]
def test_graceful_failure_on_exception(self):
"""Setup should not raise even if internal code fails."""
with (
patch(
"backend.copilot.sdk.service._is_langfuse_configured",
return_value=True,
),
patch(
"backend.copilot.sdk.service.Settings",
side_effect=RuntimeError("settings unavailable"),
),
):
from backend.copilot.sdk.service import _setup_langfuse_otel
# Should not raise — just logs and returns
_setup_langfuse_otel()
class TestPropagateAttributesImport:
"""Verify langfuse.propagate_attributes is available."""
def test_propagate_attributes_is_importable(self):
from langfuse import propagate_attributes
assert callable(propagate_attributes)
def test_propagate_attributes_returns_context_manager(self):
from langfuse import propagate_attributes
ctx = propagate_attributes(user_id="u1", session_id="s1", tags=["test"])
assert hasattr(ctx, "__enter__")
assert hasattr(ctx, "__exit__")
class TestReceiveResponseCompat:
"""Verify ClaudeSDKClient.receive_response() exists (langsmith patches it)."""
def test_receive_response_exists(self):
from claude_agent_sdk import ClaudeSDKClient
assert hasattr(ClaudeSDKClient, "receive_response")
def test_receive_response_is_async_generator(self):
import inspect
from claude_agent_sdk import ClaudeSDKClient
method = getattr(ClaudeSDKClient, "receive_response")
assert inspect.isfunction(method) or inspect.ismethod(method)

View File

@@ -118,7 +118,7 @@ async def test_build_query_resume_up_to_date():
ChatMessage(role="user", content="what's new?"),
]
)
result, was_compacted = await _build_query_message(
result = await _build_query_message(
"what's new?",
session,
use_resume=True,
@@ -127,7 +127,6 @@ async def test_build_query_resume_up_to_date():
)
# transcript_msg_count == msg_count - 1, so no gap
assert result == "what's new?"
assert was_compacted is False
@pytest.mark.asyncio
@@ -142,7 +141,7 @@ async def test_build_query_resume_stale_transcript():
ChatMessage(role="user", content="turn 3"),
]
)
result, was_compacted = await _build_query_message(
result = await _build_query_message(
"turn 3",
session,
use_resume=True,
@@ -153,7 +152,6 @@ async def test_build_query_resume_stale_transcript():
assert "turn 2" in result
assert "reply 2" in result
assert "Now, the user says:\nturn 3" in result
assert was_compacted is False # gap context does not compact
@pytest.mark.asyncio
@@ -166,7 +164,7 @@ async def test_build_query_resume_zero_msg_count():
ChatMessage(role="user", content="new msg"),
]
)
result, was_compacted = await _build_query_message(
result = await _build_query_message(
"new msg",
session,
use_resume=True,
@@ -174,14 +172,13 @@ async def test_build_query_resume_zero_msg_count():
session_id="test-session",
)
assert result == "new msg"
assert was_compacted is False
@pytest.mark.asyncio
async def test_build_query_no_resume_single_message():
"""Without --resume and only 1 message, return raw message."""
session = _make_session([ChatMessage(role="user", content="first")])
result, was_compacted = await _build_query_message(
result = await _build_query_message(
"first",
session,
use_resume=False,
@@ -189,7 +186,6 @@ async def test_build_query_no_resume_single_message():
session_id="test-session",
)
assert result == "first"
assert was_compacted is False
@pytest.mark.asyncio
@@ -203,16 +199,16 @@ async def test_build_query_no_resume_multi_message(monkeypatch):
]
)
# Mock _compress_messages to return the messages as-is
async def _mock_compress(msgs):
return msgs, False
# Mock _compress_conversation_history to return the messages as-is
async def _mock_compress(sess):
return sess.messages[:-1]
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages",
"backend.copilot.sdk.service._compress_conversation_history",
_mock_compress,
)
result, was_compacted = await _build_query_message(
result = await _build_query_message(
"new question",
session,
use_resume=False,
@@ -223,33 +219,3 @@ async def test_build_query_no_resume_multi_message(monkeypatch):
assert "older question" in result
assert "older answer" in result
assert "Now, the user says:\nnew question" in result
assert was_compacted is False # mock returns False
@pytest.mark.asyncio
async def test_build_query_no_resume_multi_message_compacted(monkeypatch):
"""When compression actually compacts, was_compacted should be True."""
session = _make_session(
[
ChatMessage(role="user", content="old"),
ChatMessage(role="assistant", content="reply"),
ChatMessage(role="user", content="new"),
]
)
async def _mock_compress(msgs):
return msgs, True # Simulate actual compaction
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages",
_mock_compress,
)
result, was_compacted = await _build_query_message(
"new",
session,
use_resume=False,
transcript_msg_count=0,
session_id="test-session",
)
assert was_compacted is True

View File

@@ -536,12 +536,10 @@ async def test_wait_for_stash_signaled():
result = await wait_for_stash(timeout=1.0)
assert result is True
pto = _pto.get()
assert pto is not None
assert pto.get("WebSearch") == ["result data"]
assert _pto.get({}).get("WebSearch") == ["result data"]
# Cleanup
_pto.set({})
_pto.set({}) # type: ignore[arg-type]
_stash_event.set(None)
@@ -556,7 +554,7 @@ async def test_wait_for_stash_timeout():
assert result is False
# Cleanup
_pto.set({})
_pto.set({}) # type: ignore[arg-type]
_stash_event.set(None)
@@ -575,12 +573,10 @@ async def test_wait_for_stash_already_stashed():
assert result is True
# But the stash itself is populated
pto = _pto.get()
assert pto is not None
assert pto.get("Read") == ["file contents"]
assert _pto.get({}).get("Read") == ["file contents"]
# Cleanup
_pto.set({})
_pto.set({}) # type: ignore[arg-type]
_stash_event.set(None)

View File

@@ -6,12 +6,11 @@ ensuring multi-user isolation and preventing unauthorized operations.
import json
import logging
import os
import re
from collections.abc import Callable
from typing import Any, cast
from backend.copilot.context import is_allowed_local_path
from .tool_adapter import (
BLOCKED_TOOLS,
DANGEROUS_PATTERNS,
@@ -39,20 +38,40 @@ def _validate_workspace_path(
) -> dict[str, Any]:
"""Validate that a workspace-scoped tool only accesses allowed paths.
Delegates to :func:`is_allowed_local_path` which permits:
Allowed directories:
- The SDK working directory (``/tmp/copilot-<session>/``)
- The current session's tool-results directory
(``~/.claude/projects/<encoded-cwd>/tool-results/``)
- The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
"""
path = tool_input.get("file_path") or tool_input.get("path") or ""
if not path:
# Glob/Grep without a path default to cwd which is already sandboxed
return {}
if is_allowed_local_path(path, sdk_cwd):
# Resolve relative paths against sdk_cwd (the SDK sets cwd so the LLM
# naturally uses relative paths like "test.txt" instead of absolute ones).
# Tilde paths (~/) are home-dir references, not relative — expand first.
if path.startswith("~"):
resolved = os.path.realpath(os.path.expanduser(path))
elif not os.path.isabs(path) and sdk_cwd:
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
else:
resolved = os.path.realpath(path)
# Allow access within the SDK working directory
if sdk_cwd:
norm_cwd = os.path.realpath(sdk_cwd)
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
return {}
# Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
claude_dir = os.path.realpath(os.path.expanduser("~/.claude/projects"))
tool_results_seg = os.sep + "tool-results" + os.sep
if resolved.startswith(claude_dir + os.sep) and tool_results_seg in resolved:
return {}
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
logger.warning(
f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
)
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
return _deny(
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
@@ -127,7 +146,7 @@ def create_security_hooks(
user_id: str | None,
sdk_cwd: str | None = None,
max_subtasks: int = 3,
on_compact: Callable[[], None] | None = None,
on_stop: Callable[[str, str], None] | None = None,
) -> dict[str, Any]:
"""Create the security hooks configuration for Claude Agent SDK.
@@ -136,12 +155,15 @@ def create_security_hooks(
- PostToolUse: Log successful tool executions
- PostToolUseFailure: Log and handle failed tool executions
- PreCompact: Log context compaction events (SDK handles compaction automatically)
- Stop: Capture transcript path for stateless resume (when *on_stop* is provided)
Args:
user_id: Current user ID for isolation validation
sdk_cwd: SDK working directory for workspace-scoped tool validation
max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
on_compact: Callback invoked when SDK starts compacting context.
on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when
the SDK finishes processing — used to read the JSONL transcript
before the CLI process exits.
Returns:
Hooks configuration dict for ClaudeAgentOptions
@@ -304,8 +326,30 @@ def create_security_hooks(
logger.info(
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
)
if on_compact is not None:
on_compact()
return cast(SyncHookJSONOutput, {})
# --- Stop hook: capture transcript path for stateless resume ---
async def stop_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Capture transcript path when SDK finishes processing.
The Stop hook fires while the CLI process is still alive, giving us
a reliable window to read the JSONL transcript before SIGTERM.
"""
_ = context, tool_use_id
transcript_path = cast(str, input_data.get("transcript_path", ""))
sdk_session_id = cast(str, input_data.get("session_id", ""))
if transcript_path and on_stop:
logger.info(
f"[SDK] Stop hook: transcript_path={transcript_path}, "
f"sdk_session_id={sdk_session_id[:12]}..."
)
on_stop(transcript_path, sdk_session_id)
return cast(SyncHookJSONOutput, {})
hooks: dict[str, Any] = {
@@ -317,6 +361,9 @@ def create_security_hooks(
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
}
if on_stop is not None:
hooks["Stop"] = [HookMatcher(matcher=None, hooks=[stop_hook])]
return hooks
except ImportError:
# Fallback for when SDK isn't available - return empty hooks

View File

@@ -9,9 +9,8 @@ import os
import pytest
from backend.copilot.context import _current_project_dir
from .security_hooks import _validate_tool_access, _validate_user_isolation
from .service import _is_tool_error_or_denial
SDK_CWD = "/tmp/copilot-abc123"
@@ -123,25 +122,15 @@ def test_read_no_cwd_denies_absolute():
def test_read_tool_results_allowed():
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
# is_allowed_local_path requires the session's encoded cwd to be set
token = _current_project_dir.set("-tmp-copilot-abc123")
try:
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert result == {}
finally:
_current_project_dir.reset(token)
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_claude_projects_settings_json_denied():
"""SDK-internal artifacts like settings.json are NOT accessible — only tool-results/ is."""
def test_read_claude_projects_without_tool_results_denied():
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
token = _current_project_dir.set("-tmp-copilot-abc123")
try:
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert _is_denied(result)
finally:
_current_project_dir.reset(token)
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert _is_denied(result)
# -- Built-in Bash is blocked (use bash_exec MCP tool instead) ---------------
@@ -354,3 +343,76 @@ async def test_task_slot_released_on_failure(_hooks):
context={},
)
assert not _is_denied(result)
# -- _is_tool_error_or_denial ------------------------------------------------
class TestIsToolErrorOrDenial:
def test_none_content(self):
assert _is_tool_error_or_denial(None) is False
def test_empty_content(self):
assert _is_tool_error_or_denial("") is False
def test_benign_output(self):
assert _is_tool_error_or_denial("All good, no issues.") is False
def test_security_marker(self):
assert _is_tool_error_or_denial("[SECURITY] Tool access blocked") is True
def test_cannot_be_bypassed(self):
assert _is_tool_error_or_denial("This restriction cannot be bypassed.") is True
def test_not_allowed(self):
assert _is_tool_error_or_denial("Operation not allowed in sandbox") is True
def test_background_task_denial(self):
assert (
_is_tool_error_or_denial(
"Background task execution is not supported. "
"Run tasks in the foreground instead."
)
is True
)
def test_subtask_limit_denial(self):
assert (
_is_tool_error_or_denial(
"Maximum 2 concurrent sub-tasks. "
"Wait for running sub-tasks to finish, "
"or continue in the main conversation."
)
is True
)
def test_denied_marker(self):
assert (
_is_tool_error_or_denial("Access denied: insufficient privileges") is True
)
def test_blocked_marker(self):
assert _is_tool_error_or_denial("Request blocked by security policy") is True
def test_failed_marker(self):
assert _is_tool_error_or_denial("Failed to execute tool: timeout") is True
def test_mcp_iserror(self):
assert _is_tool_error_or_denial('{"isError": true, "content": []}') is True
def test_benign_error_in_value(self):
"""Content like '0 errors found' should not trigger — 'error' was removed."""
assert _is_tool_error_or_denial("0 errors found") is False
def test_benign_permission_field(self):
"""Schema descriptions mentioning 'permission' should not trigger."""
assert (
_is_tool_error_or_denial(
'{"fields": [{"name": "permission_level", "type": "int"}]}'
)
is False
)
def test_benign_not_found_in_listing(self):
"""File listing containing 'not found' in filenames should not trigger."""
assert _is_tool_error_or_denial("readme.md\nfile-not-found-handler.py") is False

File diff suppressed because it is too large Load Diff

View File

@@ -1,290 +0,0 @@
"""Tests for SDK service helpers."""
import asyncio
import base64
import os
from dataclasses import dataclass
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from .service import _prepare_file_attachments
@dataclass
class _FakeFileInfo:
id: str
name: str
path: str
mime_type: str
size_bytes: int
_PATCH_TARGET = "backend.copilot.sdk.service.get_manager"
class TestPrepareFileAttachments:
@pytest.mark.asyncio
async def test_empty_list_returns_empty(self, tmp_path):
result = await _prepare_file_attachments([], "u", "s", str(tmp_path))
assert result.hint == ""
assert result.image_blocks == []
@pytest.mark.asyncio
async def test_image_embedded_as_vision_block(self, tmp_path):
"""JPEG images should become vision content blocks, not files on disk."""
raw = b"\xff\xd8\xff\xe0fake-jpeg"
info = _FakeFileInfo(
id="abc",
name="photo.jpg",
path="/photo.jpg",
mime_type="image/jpeg",
size_bytes=len(raw),
)
mgr = AsyncMock()
mgr.get_file_info.return_value = info
mgr.read_file_by_id.return_value = raw
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
result = await _prepare_file_attachments(
["abc"], "user1", "sess1", str(tmp_path)
)
assert "1 file" in result.hint
assert "photo.jpg" in result.hint
assert "embedded as image" in result.hint
assert len(result.image_blocks) == 1
block = result.image_blocks[0]
assert block["type"] == "image"
assert block["source"]["media_type"] == "image/jpeg"
assert block["source"]["data"] == base64.b64encode(raw).decode("ascii")
# Image should NOT be written to disk (embedded instead)
assert not os.path.exists(os.path.join(tmp_path, "photo.jpg"))
@pytest.mark.asyncio
async def test_pdf_saved_to_disk(self, tmp_path):
"""PDFs should be saved to disk for Read tool access, not embedded."""
info = _FakeFileInfo("f1", "doc.pdf", "/doc.pdf", "application/pdf", 50)
mgr = AsyncMock()
mgr.get_file_info.return_value = info
mgr.read_file_by_id.return_value = b"%PDF-1.4 fake"
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
result = await _prepare_file_attachments(["f1"], "u", "s", str(tmp_path))
assert result.image_blocks == []
saved = tmp_path / "doc.pdf"
assert saved.exists()
assert saved.read_bytes() == b"%PDF-1.4 fake"
assert str(saved) in result.hint
@pytest.mark.asyncio
async def test_mixed_images_and_files(self, tmp_path):
"""Images become blocks, non-images go to disk."""
infos = {
"id1": _FakeFileInfo("id1", "a.png", "/a.png", "image/png", 4),
"id2": _FakeFileInfo("id2", "b.pdf", "/b.pdf", "application/pdf", 4),
"id3": _FakeFileInfo("id3", "c.txt", "/c.txt", "text/plain", 4),
}
mgr = AsyncMock()
mgr.get_file_info.side_effect = lambda fid: infos[fid]
mgr.read_file_by_id.return_value = b"data"
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
result = await _prepare_file_attachments(
["id1", "id2", "id3"], "u", "s", str(tmp_path)
)
assert "3 files" in result.hint
assert "a.png" in result.hint
assert "b.pdf" in result.hint
assert "c.txt" in result.hint
# Only the image should be a vision block
assert len(result.image_blocks) == 1
assert result.image_blocks[0]["source"]["media_type"] == "image/png"
# Non-image files should be on disk
assert (tmp_path / "b.pdf").exists()
assert (tmp_path / "c.txt").exists()
# Read tool hint should appear (has non-image files)
assert "Read tool" in result.hint
@pytest.mark.asyncio
async def test_singular_noun(self, tmp_path):
info = _FakeFileInfo("x", "only.txt", "/only.txt", "text/plain", 2)
mgr = AsyncMock()
mgr.get_file_info.return_value = info
mgr.read_file_by_id.return_value = b"hi"
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
result = await _prepare_file_attachments(["x"], "u", "s", str(tmp_path))
assert "1 file." in result.hint
@pytest.mark.asyncio
async def test_missing_file_skipped(self, tmp_path):
mgr = AsyncMock()
mgr.get_file_info.return_value = None
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
result = await _prepare_file_attachments(
["missing-id"], "u", "s", str(tmp_path)
)
assert result.hint == ""
assert result.image_blocks == []
@pytest.mark.asyncio
async def test_image_only_no_read_hint(self, tmp_path):
"""When all files are images, no Read tool hint should appear."""
info = _FakeFileInfo("i1", "cat.png", "/cat.png", "image/png", 4)
mgr = AsyncMock()
mgr.get_file_info.return_value = info
mgr.read_file_by_id.return_value = b"data"
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
result = await _prepare_file_attachments(["i1"], "u", "s", str(tmp_path))
assert "Read tool" not in result.hint
assert len(result.image_blocks) == 1
class TestPromptSupplement:
"""Tests for centralized prompt supplement generation."""
def test_sdk_supplement_excludes_tool_docs(self):
"""SDK mode should NOT include tool documentation (Claude gets schemas automatically)."""
from backend.copilot.prompting import get_sdk_supplement
# Test both local and E2B modes
local_supplement = get_sdk_supplement(use_e2b=False, cwd="/tmp/test")
e2b_supplement = get_sdk_supplement(use_e2b=True, cwd="")
# Should NOT have tool list section
assert "## AVAILABLE TOOLS" not in local_supplement
assert "## AVAILABLE TOOLS" not in e2b_supplement
# Should still have technical notes
assert "## Tool notes" in local_supplement
assert "## Tool notes" in e2b_supplement
def test_baseline_supplement_includes_tool_docs(self):
"""Baseline mode MUST include tool documentation (direct API needs it)."""
from backend.copilot.prompting import get_baseline_supplement
supplement = get_baseline_supplement()
# MUST have tool list section
assert "## AVAILABLE TOOLS" in supplement
# Should NOT have environment-specific notes (SDK-only)
assert "## Tool notes" not in supplement
def test_baseline_supplement_includes_key_tools(self):
"""Baseline supplement should document all essential tools."""
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.tools import TOOL_REGISTRY
docs = get_baseline_supplement()
# Core agent workflow tools (always available)
assert "`create_agent`" in docs
assert "`run_agent`" in docs
assert "`find_library_agent`" in docs
assert "`edit_agent`" in docs
# MCP integration (always available)
assert "`run_mcp_tool`" in docs
# Folder management (always available)
assert "`create_folder`" in docs
# Browser tools only if available (Playwright may not be installed in CI)
if (
TOOL_REGISTRY.get("browser_navigate")
and TOOL_REGISTRY["browser_navigate"].is_available
):
assert "`browser_navigate`" in docs
def test_baseline_supplement_includes_workflows(self):
"""Baseline supplement should include workflow guidance in tool descriptions."""
from backend.copilot.prompting import get_baseline_supplement
docs = get_baseline_supplement()
# Workflows are now in individual tool descriptions (not separate sections)
# Check that key workflow concepts appear in tool descriptions
assert "agent_json" in docs or "find_block" in docs
assert "run_mcp_tool" in docs
def test_baseline_supplement_completeness(self):
"""All available tools from TOOL_REGISTRY should appear in baseline supplement."""
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.tools import TOOL_REGISTRY
docs = get_baseline_supplement()
# Verify each available registered tool is documented
# (matches _generate_tool_documentation which filters by is_available)
for tool_name, tool in TOOL_REGISTRY.items():
if not tool.is_available:
continue
assert (
f"`{tool_name}`" in docs
), f"Tool '{tool_name}' missing from baseline supplement"
def test_pause_task_scheduled_before_transcript_upload(self):
"""Pause is scheduled as a background task before transcript upload begins.
The finally block in stream_response_sdk does:
(1) asyncio.create_task(pause_sandbox_direct(...)) — fire-and-forget
(2) await asyncio.shield(upload_transcript(...)) — awaited
Scheduling pause via create_task before awaiting upload ensures:
- Pause never blocks transcript upload (billing stops concurrently)
- On E2B timeout, pause silently fails; upload proceeds unaffected
"""
call_order: list[str] = []
async def _mock_pause(sandbox, session_id):
call_order.append("pause")
async def _mock_upload(**kwargs):
call_order.append("upload")
async def _simulate_teardown():
"""Mirror the service.py finally block teardown sequence."""
sandbox = MagicMock()
# (1) Schedule pause — mirrors lines ~1427-1429 in service.py
task = asyncio.create_task(_mock_pause(sandbox, "test-sess"))
# (2) Await transcript upload — mirrors lines ~1460-1468 in service.py
# Yielding to the event loop here lets the pause task start concurrently.
await _mock_upload(
user_id="u", session_id="test-sess", content="x", message_count=1
)
await task
asyncio.run(_simulate_teardown())
# Both must run; pause is scheduled before upload starts
assert "pause" in call_order
assert "upload" in call_order
# create_task schedules pause, then upload is awaited — pause runs
# concurrently during upload's first yield. The ordering guarantee is
# that create_task is CALLED before upload is AWAITED (see source order).
def test_baseline_supplement_no_duplicate_tools(self):
"""No tool should appear multiple times in baseline supplement."""
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.tools import TOOL_REGISTRY
docs = get_baseline_supplement()
# Count occurrences of each available tool in the entire supplement
for tool_name, tool in TOOL_REGISTRY.items():
if not tool.is_available:
continue
# Count how many times this tool appears as a bullet point
count = docs.count(f"- **`{tool_name}`**")
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"

View File

@@ -11,45 +11,28 @@ import logging
import os
import uuid
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any
from typing import Any
from claude_agent_sdk import create_sdk_mcp_server, tool
from backend.copilot.context import (
_current_project_dir,
_current_sandbox,
_current_sdk_cwd,
_current_session,
_current_user_id,
_encode_cwd_for_cli,
get_execution_context,
get_sdk_cwd,
is_allowed_local_path,
)
from backend.copilot.model import ChatSession
from backend.copilot.sdk.file_ref import (
FileRefExpansionError,
expand_file_refs_in_args,
read_file_bytes,
)
from backend.copilot.tools import TOOL_REGISTRY
from backend.copilot.tools.base import BaseTool
from backend.util.truncate import truncate
from .e2b_file_tools import E2B_FILE_TOOL_NAMES, E2B_FILE_TOOLS
if TYPE_CHECKING:
from e2b import AsyncSandbox
logger = logging.getLogger(__name__)
# Max MCP response size in chars — keeps tool output under the SDK's 10 MB JSON buffer.
_MCP_MAX_CHARS = 500_000
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
# in the path — prevents reading settings, credentials, or other sensitive files.
_SDK_PROJECTS_DIR = os.path.expanduser("~/.claude/projects/")
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
MCP_SERVER_NAME = "copilot"
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
# Context variables to pass user/session info to tool execution
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
@@ -70,29 +53,30 @@ _stash_event: ContextVar[asyncio.Event | None] = ContextVar(
def set_execution_context(
user_id: str | None,
session: ChatSession,
sandbox: "AsyncSandbox | None" = None,
sdk_cwd: str | None = None,
) -> None:
"""Set the execution context for tool calls.
This must be called before streaming begins to ensure tools have access
to user_id, session, and (optionally) an E2B sandbox for bash execution.
to user_id and session information.
Args:
user_id: Current user's ID.
session: Current chat session.
sandbox: Optional E2B sandbox; when set, bash_exec routes commands there.
sdk_cwd: SDK working directory; used to scope tool-results reads.
"""
_current_user_id.set(user_id)
_current_session.set(session)
_current_sandbox.set(sandbox)
_current_sdk_cwd.set(sdk_cwd or "")
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
_pending_tool_outputs.set({})
_stash_event.set(asyncio.Event())
def get_execution_context() -> tuple[str | None, ChatSession | None]:
"""Get the current execution context."""
return (
_current_user_id.get(),
_current_session.get(),
)
def pop_pending_tool_output(tool_name: str) -> str | None:
"""Pop and return the oldest stashed output for *tool_name*.
@@ -185,11 +169,7 @@ async def _execute_tool_sync(
session: ChatSession,
args: dict[str, Any],
) -> dict[str, Any]:
"""Execute a tool synchronously and return MCP-formatted response.
Note: ``@@agptfile:`` expansion is handled upstream in the ``_truncating`` wrapper
so all registered handlers (BaseTool, E2B, Read) expand uniformly.
"""
"""Execute a tool synchronously and return MCP-formatted response."""
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
result = await base_tool.execute(
user_id=user_id,
@@ -202,12 +182,66 @@ async def _execute_tool_sync(
result.output if isinstance(result.output, str) else json.dumps(result.output)
)
# Stash the full output before the SDK potentially truncates it.
pending = _pending_tool_outputs.get(None)
if pending is not None:
pending.setdefault(base_tool.name, []).append(text)
content_blocks: list[dict[str, str]] = [{"type": "text", "text": text}]
# If the tool result contains inline image data, add an MCP image block
# so Claude can "see" the image (e.g. read_workspace_file on a small PNG).
image_block = _extract_image_block(text)
if image_block:
content_blocks.append(image_block)
return {
"content": [{"type": "text", "text": text}],
"content": content_blocks,
"isError": not result.success,
}
# MIME types that Claude can process as image content blocks.
_SUPPORTED_IMAGE_TYPES = frozenset(
{"image/png", "image/jpeg", "image/gif", "image/webp"}
)
def _extract_image_block(text: str) -> dict[str, str] | None:
"""Extract an MCP image content block from a tool result JSON string.
Detects workspace file responses with ``content_base64`` and an image
MIME type, returning an MCP-format image block that allows Claude to
"see" the image. Returns ``None`` if the result is not an inline image.
"""
try:
data = json.loads(text)
except (json.JSONDecodeError, TypeError):
return None
if not isinstance(data, dict):
return None
mime_type = data.get("mime_type", "")
base64_content = data.get("content_base64", "")
# Only inline small images — large ones would exceed Claude's limits.
# 32 KB raw ≈ ~43 KB base64.
_MAX_IMAGE_BASE64_BYTES = 43_000
if (
mime_type in _SUPPORTED_IMAGE_TYPES
and base64_content
and len(base64_content) <= _MAX_IMAGE_BASE64_BYTES
):
return {
"type": "image",
"data": base64_content,
"mimeType": mime_type,
}
return None
def _mcp_error(message: str) -> dict[str, Any]:
return {
"content": [
@@ -250,50 +284,39 @@ def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Read a file with optional offset/limit.
"""Read a file with optional offset/limit. Restricted to SDK working directory.
Supports ``workspace://`` URIs (delegated to the workspace manager) and
local paths within the session's allowed directories (sdk_cwd + tool-results).
After reading, the file is deleted to prevent accumulation in long-running pods.
"""
file_path = args.get("file_path", "")
offset = max(0, int(args.get("offset", 0)))
limit = max(1, int(args.get("limit", 2000)))
offset = args.get("offset", 0)
limit = args.get("limit", 2000)
def _mcp_err(text: str) -> dict[str, Any]:
return {"content": [{"type": "text", "text": text}], "isError": True}
# Security: only allow reads under ~/.claude/projects/**/tool-results/
real_path = os.path.realpath(file_path)
if not real_path.startswith(_SDK_PROJECTS_DIR) or "tool-results" not in real_path:
return {
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
"isError": True,
}
def _mcp_ok(text: str) -> dict[str, Any]:
return {"content": [{"type": "text", "text": text}], "isError": False}
if file_path.startswith("workspace://"):
user_id, session = get_execution_context()
if session is None:
return _mcp_err("workspace:// file references require an active session")
try:
raw = await read_file_bytes(file_path, user_id, session)
except ValueError as exc:
return _mcp_err(str(exc))
lines = raw.decode("utf-8", errors="replace").splitlines(keepends=True)
selected = list(itertools.islice(lines, offset, offset + limit))
numbered = "".join(
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
)
return _mcp_ok(numbered)
if not is_allowed_local_path(file_path, get_sdk_cwd()):
return _mcp_err(f"Path not allowed: {file_path}")
resolved = os.path.realpath(os.path.expanduser(file_path))
try:
with open(resolved) as f:
with open(real_path) as f:
selected = list(itertools.islice(f, offset, offset + limit))
content = "".join(selected)
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
# don't delete here — the SDK may read in multiple chunks.
return _mcp_ok("".join(selected))
return {"content": [{"type": "text", "text": content}], "isError": False}
except FileNotFoundError:
return _mcp_err(f"File not found: {file_path}")
return {
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
"isError": True,
}
except Exception as e:
return _mcp_err(f"Error reading file: {e}")
return {
"content": [{"type": "text", "text": f"Error reading file: {e}"}],
"isError": True,
}
_READ_TOOL_NAME = "Read"
@@ -321,100 +344,50 @@ _READ_TOOL_SCHEMA = {
}
# ---------------------------------------------------------------------------
# MCP result helpers
# ---------------------------------------------------------------------------
def _text_from_mcp_result(result: dict[str, Any]) -> str:
"""Extract concatenated text from an MCP response's content blocks."""
content = result.get("content", [])
if not isinstance(content, list):
return ""
return "".join(
b.get("text", "")
for b in content
if isinstance(b, dict) and b.get("type") == "text"
)
def create_copilot_mcp_server(*, use_e2b: bool = False):
# Create the MCP server configuration
def create_copilot_mcp_server():
"""Create an in-process MCP server configuration for CoPilot tools.
When *use_e2b* is True, five additional MCP file tools are registered
that route directly to the E2B sandbox filesystem, and the caller should
disable the corresponding SDK built-in tools via
:func:`get_sdk_disallowed_tools`.
This can be passed to ClaudeAgentOptions.mcp_servers.
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
package being available. This function returns the configuration that
can be used with the SDK.
"""
try:
from claude_agent_sdk import create_sdk_mcp_server, tool
def _truncating(fn, tool_name: str):
"""Wrap a tool handler so its response is truncated to stay under the
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
response adapter before the SDK can apply its own head-truncation.
# Create decorated tool functions
sdk_tools = []
Also expands ``@@agptfile:`` references in args so every registered tool
(BaseTool, E2B file tools, Read) receives resolved content uniformly.
Applied once to every registered tool."""
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
user_id, session = get_execution_context()
if session is not None:
try:
args = await expand_file_refs_in_args(args, user_id, session)
except FileRefExpansionError as exc:
return _mcp_error(
f"@@agptfile: reference could not be resolved: {exc}. "
"Ensure the file exists before referencing it. "
"For sandbox paths use bash_exec to verify the file exists first; "
"for workspace files use a workspace:// URI."
)
result = await fn(args)
truncated = truncate(result, _MCP_MAX_CHARS)
# Stash the text so the response adapter can forward our
# middle-out truncated version to the frontend instead of the
# SDK's head-truncated version (for outputs >~100 KB the SDK
# persists to tool-results/ with a 2 KB head-only preview).
if not truncated.get("isError"):
text = _text_from_mcp_result(truncated)
if text:
stash_pending_tool_output(tool_name, text)
return truncated
return wrapper
sdk_tools = []
for tool_name, base_tool in TOOL_REGISTRY.items():
handler = create_tool_handler(base_tool)
decorated = tool(
tool_name,
base_tool.description,
_build_input_schema(base_tool),
)(_truncating(handler, tool_name))
sdk_tools.append(decorated)
# E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.
if use_e2b:
for name, desc, schema, handler in E2B_FILE_TOOLS:
decorated = tool(name, desc, schema)(_truncating(handler, name))
for tool_name, base_tool in TOOL_REGISTRY.items():
handler = create_tool_handler(base_tool)
decorated = tool(
tool_name,
base_tool.description,
_build_input_schema(base_tool),
)(handler)
sdk_tools.append(decorated)
# Read tool for SDK-truncated tool results (always needed).
read_tool = tool(
_READ_TOOL_NAME,
_READ_TOOL_DESCRIPTION,
_READ_TOOL_SCHEMA,
)(_truncating(_read_file_handler, _READ_TOOL_NAME))
sdk_tools.append(read_tool)
# Add the Read tool so the SDK can read back oversized tool results
read_tool = tool(
_READ_TOOL_NAME,
_READ_TOOL_DESCRIPTION,
_READ_TOOL_SCHEMA,
)(_read_file_handler)
sdk_tools.append(read_tool)
return create_sdk_mcp_server(
name=MCP_SERVER_NAME,
version="1.0.0",
tools=sdk_tools,
)
server = create_sdk_mcp_server(
name=MCP_SERVER_NAME,
version="1.0.0",
tools=sdk_tools,
)
return server
except ImportError:
# Let ImportError propagate so service.py handles the fallback
raise
# SDK built-in tools allowed within the workspace directory.
@@ -424,11 +397,16 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
# Task allows spawning sub-agents (rate-limited by security hooks).
# WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk.
# TodoWrite manages the task checklist shown in the UI — no security concern.
# In E2B mode, all five are disabled — MCP equivalents provide direct sandbox
# access. read_file also handles local tool-results and ephemeral reads.
_SDK_BUILTIN_FILE_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep"]
_SDK_BUILTIN_ALWAYS = ["Task", "WebSearch", "TodoWrite"]
_SDK_BUILTIN_TOOLS = [*_SDK_BUILTIN_FILE_TOOLS, *_SDK_BUILTIN_ALWAYS]
_SDK_BUILTIN_TOOLS = [
"Read",
"Write",
"Edit",
"Glob",
"Grep",
"Task",
"WebSearch",
"TodoWrite",
]
# SDK built-in tools that must be explicitly blocked.
# Bash: dangerous — agent uses mcp__copilot__bash_exec with kernel-level
@@ -436,10 +414,15 @@ _SDK_BUILTIN_TOOLS = [*_SDK_BUILTIN_FILE_TOOLS, *_SDK_BUILTIN_ALWAYS]
# WebFetch: SSRF risk — can reach internal network (localhost, 10.x, etc.).
# Agent uses the SSRF-protected mcp__copilot__web_fetch tool instead.
# AskUserQuestion: interactive CLI tool — no terminal in copilot context.
# ToolSearch: SDK bug — ToolSearch results are passed as raw match objects
# instead of text content blocks, causing Anthropic API 400 errors.
# All copilot tools are already explicitly listed in allowed_tools,
# so dynamic tool search is unnecessary.
SDK_DISALLOWED_TOOLS = [
"Bash",
"WebFetch",
"AskUserQuestion",
"ToolSearch",
]
# Tools that are blocked entirely in security hooks (defence-in-depth).
@@ -475,37 +458,11 @@ DANGEROUS_PATTERNS = [
r"subprocess",
]
# Static tool name list for the non-E2B case (backward compatibility).
# List of tool names for allowed_tools configuration
# Include MCP tools, the MCP Read tool for oversized results,
# and SDK built-in file tools for workspace operations.
COPILOT_TOOL_NAMES = [
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
*_SDK_BUILTIN_TOOLS,
]
def get_copilot_tool_names(*, use_e2b: bool = False) -> list[str]:
"""Build the ``allowed_tools`` list for :class:`ClaudeAgentOptions`.
When *use_e2b* is True the SDK built-in file tools are replaced by MCP
equivalents that route to the E2B sandbox.
"""
if not use_e2b:
return list(COPILOT_TOOL_NAMES)
return [
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
*[f"{MCP_TOOL_PREFIX}{name}" for name in E2B_FILE_TOOL_NAMES],
*_SDK_BUILTIN_ALWAYS,
]
def get_sdk_disallowed_tools(*, use_e2b: bool = False) -> list[str]:
"""Build the ``disallowed_tools`` list for :class:`ClaudeAgentOptions`.
When *use_e2b* is True the SDK built-in file tools are also disabled
because MCP equivalents provide direct sandbox access.
"""
if not use_e2b:
return list(SDK_DISALLOWED_TOOLS)
return [*SDK_DISALLOWED_TOOLS, *_SDK_BUILTIN_FILE_TOOLS]

View File

@@ -1,170 +0,0 @@
"""Tests for tool_adapter helpers: truncation, stash, context vars."""
import pytest
from backend.copilot.context import get_sdk_cwd
from backend.util.truncate import truncate
from .tool_adapter import (
_MCP_MAX_CHARS,
_text_from_mcp_result,
pop_pending_tool_output,
set_execution_context,
stash_pending_tool_output,
)
# ---------------------------------------------------------------------------
# _text_from_mcp_result
# ---------------------------------------------------------------------------
class TestTextFromMcpResult:
def test_single_text_block(self):
result = {"content": [{"type": "text", "text": "hello"}]}
assert _text_from_mcp_result(result) == "hello"
def test_multiple_text_blocks_concatenated(self):
result = {
"content": [
{"type": "text", "text": "one"},
{"type": "text", "text": "two"},
]
}
assert _text_from_mcp_result(result) == "onetwo"
def test_non_text_blocks_ignored(self):
result = {
"content": [
{"type": "image", "data": "..."},
{"type": "text", "text": "only this"},
]
}
assert _text_from_mcp_result(result) == "only this"
def test_empty_content_list(self):
assert _text_from_mcp_result({"content": []}) == ""
def test_missing_content_key(self):
assert _text_from_mcp_result({}) == ""
def test_non_list_content(self):
assert _text_from_mcp_result({"content": "raw string"}) == ""
def test_missing_text_field(self):
result = {"content": [{"type": "text"}]}
assert _text_from_mcp_result(result) == ""
# ---------------------------------------------------------------------------
# get_sdk_cwd
# ---------------------------------------------------------------------------
class TestGetSdkCwd:
def test_returns_empty_string_by_default(self):
set_execution_context(
user_id="test",
session=None, # type: ignore[arg-type]
sandbox=None,
)
assert get_sdk_cwd() == ""
def test_returns_set_value(self):
set_execution_context(
user_id="test",
session=None, # type: ignore[arg-type]
sandbox=None,
sdk_cwd="/tmp/copilot-test-123",
)
assert get_sdk_cwd() == "/tmp/copilot-test-123"
# ---------------------------------------------------------------------------
# stash / pop round-trip (the mechanism _truncating relies on)
# ---------------------------------------------------------------------------
class TestToolOutputStash:
@pytest.fixture(autouse=True)
def _init_context(self):
"""Initialise the context vars that stash_pending_tool_output needs."""
set_execution_context(
user_id="test",
session=None, # type: ignore[arg-type]
sandbox=None,
sdk_cwd="/tmp/test",
)
def test_stash_and_pop(self):
stash_pending_tool_output("my_tool", "output1")
assert pop_pending_tool_output("my_tool") == "output1"
def test_pop_empty_returns_none(self):
assert pop_pending_tool_output("nonexistent") is None
def test_fifo_order(self):
stash_pending_tool_output("t", "first")
stash_pending_tool_output("t", "second")
assert pop_pending_tool_output("t") == "first"
assert pop_pending_tool_output("t") == "second"
assert pop_pending_tool_output("t") is None
def test_dict_serialised_to_json(self):
stash_pending_tool_output("t", {"key": "value"})
assert pop_pending_tool_output("t") == '{"key": "value"}'
def test_separate_tool_names(self):
stash_pending_tool_output("a", "alpha")
stash_pending_tool_output("b", "beta")
assert pop_pending_tool_output("b") == "beta"
assert pop_pending_tool_output("a") == "alpha"
# ---------------------------------------------------------------------------
# _truncating wrapper (integration via create_copilot_mcp_server)
# ---------------------------------------------------------------------------
class TestTruncationAndStashIntegration:
"""Test truncation + stash behavior that _truncating relies on."""
@pytest.fixture(autouse=True)
def _init_context(self):
set_execution_context(
user_id="test",
session=None, # type: ignore[arg-type]
sandbox=None,
sdk_cwd="/tmp/test",
)
def test_small_output_stashed(self):
"""Non-error output is stashed for the response adapter."""
result = {
"content": [{"type": "text", "text": "small output"}],
"isError": False,
}
truncated = truncate(result, _MCP_MAX_CHARS)
text = _text_from_mcp_result(truncated)
assert text == "small output"
stash_pending_tool_output("test_tool", text)
assert pop_pending_tool_output("test_tool") == "small output"
def test_error_result_not_stashed(self):
"""Error results should not be stashed."""
result = {
"content": [{"type": "text", "text": "error msg"}],
"isError": True,
}
# _truncating only stashes when not result.get("isError")
if not result.get("isError"):
stash_pending_tool_output("err_tool", "should not happen")
assert pop_pending_tool_output("err_tool") is None
def test_large_output_truncated(self):
"""Output exceeding _MCP_MAX_CHARS is truncated before stashing."""
big_text = "x" * (_MCP_MAX_CHARS + 100_000)
result = {"content": [{"type": "text", "text": big_text}]}
truncated = truncate(result, _MCP_MAX_CHARS)
text = _text_from_mcp_result(truncated)
assert len(text) < len(big_text)
assert len(str(truncated)) <= _MCP_MAX_CHARS

View File

@@ -10,14 +10,13 @@ Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
filesystem for self-hosted) — no DB column needed.
"""
import json
import logging
import os
import re
import time
from dataclasses import dataclass
from backend.util import json
logger = logging.getLogger(__name__)
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
@@ -59,37 +58,41 @@ def strip_progress_entries(content: str) -> str:
Removes entries whose ``type`` is in ``STRIPPABLE_TYPES`` and reparents
any remaining child entries so the ``parentUuid`` chain stays intact.
Typically reduces transcript size by ~30%.
Entries that are not stripped or reparented are kept as their original
raw JSON line to avoid unnecessary re-serialization that changes
whitespace or key ordering.
"""
lines = content.strip().split("\n")
# Parse entries, keeping the original line alongside the parsed dict.
parsed: list[tuple[str, dict | None]] = []
entries: list[dict] = []
for line in lines:
parsed.append((line, json.loads(line, fallback=None)))
try:
entries.append(json.loads(line))
except json.JSONDecodeError:
# Keep unparseable lines as-is (safety)
entries.append({"_raw": line})
# First pass: identify stripped UUIDs and build parent map.
stripped_uuids: set[str] = set()
uuid_to_parent: dict[str, str] = {}
kept: list[dict] = []
for _line, entry in parsed:
if not isinstance(entry, dict):
for entry in entries:
if "_raw" in entry:
kept.append(entry)
continue
uid = entry.get("uuid", "")
parent = entry.get("parentUuid", "")
entry_type = entry.get("type", "")
if uid:
uuid_to_parent[uid] = parent
if entry.get("type", "") in STRIPPABLE_TYPES and uid:
stripped_uuids.add(uid)
# Second pass: keep non-stripped entries, reparenting where needed.
# Preserve original line when no reparenting is required.
reparented: set[str] = set()
for _line, entry in parsed:
if not isinstance(entry, dict):
if entry_type in STRIPPABLE_TYPES:
if uid:
stripped_uuids.add(uid)
else:
kept.append(entry)
# Reparent: walk up chain through stripped entries to find surviving ancestor
for entry in kept:
if "_raw" in entry:
continue
parent = entry.get("parentUuid", "")
original_parent = parent
@@ -97,32 +100,63 @@ def strip_progress_entries(content: str) -> str:
parent = uuid_to_parent.get(parent, "")
if parent != original_parent:
entry["parentUuid"] = parent
uid = entry.get("uuid", "")
if uid:
reparented.add(uid)
result_lines: list[str] = []
for line, entry in parsed:
if not isinstance(entry, dict):
result_lines.append(line)
continue
if entry.get("type", "") in STRIPPABLE_TYPES:
continue
uid = entry.get("uuid", "")
if uid in reparented:
# Re-serialize only entries whose parentUuid was changed.
result_lines.append(json.dumps(entry, separators=(",", ":")))
for entry in kept:
if "_raw" in entry:
result_lines.append(entry["_raw"])
else:
result_lines.append(line)
result_lines.append(json.dumps(entry, separators=(",", ":")))
return "\n".join(result_lines) + "\n"
# ---------------------------------------------------------------------------
# Local file I/O (write temp file for --resume)
# Local file I/O (read from CLI's JSONL, write temp file for --resume)
# ---------------------------------------------------------------------------
def read_transcript_file(transcript_path: str) -> str | None:
"""Read a JSONL transcript file from disk.
Returns the raw JSONL content, or ``None`` if the file is missing, empty,
or only contains metadata (≤2 lines with no conversation messages).
"""
if not transcript_path or not os.path.isfile(transcript_path):
logger.debug(f"[Transcript] File not found: {transcript_path}")
return None
try:
with open(transcript_path) as f:
content = f.read()
if not content.strip():
logger.debug("[Transcript] File is empty: %s", transcript_path)
return None
lines = content.strip().split("\n")
# Validate that the transcript has real conversation content
# (not just metadata like queue-operation entries).
if not validate_transcript(content):
logger.debug(
"[Transcript] No conversation content (%d lines) in %s",
len(lines),
transcript_path,
)
return None
logger.info(
f"[Transcript] Read {len(lines)} lines, "
f"{len(content)} bytes from {transcript_path}"
)
return content
except (json.JSONDecodeError, OSError) as e:
logger.warning(f"[Transcript] Failed to read {transcript_path}: {e}")
return None
def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
"""Sanitize an ID for safe use in file paths.
@@ -137,6 +171,14 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does.
The CLI replaces all non-alphanumeric characters with ``-``.
"""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""Remove the CLI's project directory for a specific working directory.
@@ -146,8 +188,7 @@ def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""
import shutil
# Encode cwd the same way CLI does (replaces non-alphanumeric with -)
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
cwd_encoded = _encode_cwd_for_cli(sdk_cwd)
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
@@ -207,29 +248,32 @@ def write_transcript_to_tempfile(
def validate_transcript(content: str | None) -> bool:
"""Check that a transcript has actual conversation messages.
A valid transcript needs at least one assistant message (not just
queue-operation / file-history-snapshot metadata). We do NOT require
a ``type: "user"`` entry because with ``--resume`` the user's message
is passed as a CLI query parameter and does not appear in the
transcript file.
A valid transcript for resume needs at least one user message and one
assistant message (not just queue-operation / file-history-snapshot
metadata).
"""
if not content or not content.strip():
return False
lines = content.strip().split("\n")
if len(lines) < 2:
return False
has_user = False
has_assistant = False
for line in lines:
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
try:
entry = json.loads(line)
msg_type = entry.get("type")
if msg_type == "user":
has_user = True
elif msg_type == "assistant":
has_assistant = True
except json.JSONDecodeError:
return False
if entry.get("type") == "assistant":
has_assistant = True
return has_assistant
return has_user and has_assistant
# ---------------------------------------------------------------------------
@@ -284,46 +328,45 @@ async def upload_transcript(
session_id: str,
content: str,
message_count: int = 0,
log_prefix: str = "[Transcript]",
) -> None:
"""Strip progress entries and upload complete transcript.
"""Strip progress entries and upload transcript to bucket storage.
The transcript represents the FULL active context (atomic).
Each upload REPLACES the previous transcript entirely.
The executor holds a cluster lock per session, so concurrent uploads for
the same session cannot happen.
Safety: only overwrites when the new (stripped) transcript is larger than
what is already stored. Since JSONL is append-only, the latest transcript
is always the longest. This prevents a slow/stale background task from
clobbering a newer upload from a concurrent turn.
Args:
content: Complete JSONL transcript (from TranscriptBuilder).
message_count: ``len(session.messages)`` at upload time.
message_count: ``len(session.messages)`` at upload time — used by
the next turn to detect staleness and compress only the gap.
"""
from backend.util.workspace_storage import get_workspace_storage
# Strip metadata entries (progress, file-history-snapshot, etc.)
# Note: SDK-built transcripts shouldn't have these, but strip for safety
stripped = strip_progress_entries(content)
if not validate_transcript(stripped):
# Log entry types for debugging — helps identify why validation failed
entry_types: list[str] = []
for line in stripped.strip().split("\n"):
entry = json.loads(line, fallback={"type": "INVALID_JSON"})
entry_types.append(entry.get("type", "?"))
logger.warning(
"%s Skipping upload — stripped content not valid "
"(types=%s, stripped_len=%d, raw_len=%d)",
log_prefix,
entry_types,
len(stripped),
len(content),
f"[Transcript] Skipping upload — stripped content not valid "
f"for session {session_id}"
)
logger.debug("%s Raw content preview: %s", log_prefix, content[:500])
logger.debug("%s Stripped content: %s", log_prefix, stripped[:500])
return
storage = await get_workspace_storage()
wid, fid, fname = _storage_path_parts(user_id, session_id)
encoded = stripped.encode("utf-8")
new_size = len(encoded)
# Check existing transcript size to avoid overwriting newer with older
path = _build_storage_path(user_id, session_id, storage)
try:
existing = await storage.retrieve(path)
if len(existing) >= new_size:
logger.info(
f"[Transcript] Skipping upload — existing ({len(existing)}B) "
f">= new ({new_size}B) for session {session_id}"
)
return
except (FileNotFoundError, Exception):
pass # No existing transcript or retrieval error — proceed with upload
await storage.store(
workspace_id=wid,
@@ -332,8 +375,11 @@ async def upload_transcript(
content=encoded,
)
# Update metadata so message_count stays current. The gap-fill logic
# in _build_query_message relies on it to avoid re-compressing messages.
# Store metadata alongside the transcript so the next turn can detect
# staleness and only compress the gap instead of the full history.
# Wrapped in try/except so a metadata write failure doesn't orphan
# the already-uploaded transcript — the next turn will just fall back
# to full gap fill (msg_count=0).
try:
meta = {"message_count": message_count, "uploaded_at": time.time()}
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
@@ -344,18 +390,17 @@ async def upload_transcript(
content=json.dumps(meta).encode("utf-8"),
)
except Exception as e:
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
logger.warning(f"[Transcript] Failed to write metadata for {session_id}: {e}")
logger.info(
f"{log_prefix} Uploaded {len(encoded)}B "
f"(stripped from {len(content)}B, msg_count={message_count})"
f"[Transcript] Uploaded {new_size}B "
f"(stripped from {len(content)}B, msg_count={message_count}) "
f"for session {session_id}"
)
async def download_transcript(
user_id: str,
session_id: str,
log_prefix: str = "[Transcript]",
user_id: str, session_id: str
) -> TranscriptDownload | None:
"""Download transcript and metadata from bucket storage.
@@ -371,10 +416,10 @@ async def download_transcript(
data = await storage.retrieve(path)
content = data.decode("utf-8")
except FileNotFoundError:
logger.debug(f"{log_prefix} No transcript in storage")
logger.debug(f"[Transcript] No transcript in storage for {session_id}")
return None
except Exception as e:
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
logger.warning(f"[Transcript] Failed to download transcript: {e}")
return None
# Try to load metadata (best-effort — old transcripts won't have it)
@@ -391,13 +436,16 @@ async def download_transcript(
meta_path = f"local://{mwid}/{mfid}/{mfname}"
meta_data = await storage.retrieve(meta_path)
meta = json.loads(meta_data.decode("utf-8"), fallback={})
meta = json.loads(meta_data.decode("utf-8"))
message_count = meta.get("message_count", 0)
uploaded_at = meta.get("uploaded_at", 0.0)
except (FileNotFoundError, Exception):
except (FileNotFoundError, json.JSONDecodeError, Exception):
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
logger.info(
f"[Transcript] Downloaded {len(content)}B "
f"(msg_count={message_count}) for session {session_id}"
)
return TranscriptDownload(
content=content,
message_count=message_count,

View File

@@ -1,188 +0,0 @@
"""Build complete JSONL transcript from SDK messages.
The transcript represents the FULL active context at any point in time.
Each upload REPLACES the previous transcript atomically.
Flow:
Turn 1: Upload [msg1, msg2]
Turn 2: Download [msg1, msg2] → Upload [msg1, msg2, msg3, msg4] (REPLACE)
Turn 3: Download [msg1, msg2, msg3, msg4] → Upload [all messages] (REPLACE)
The transcript is never incremental - always the complete atomic state.
"""
import logging
from typing import Any
from uuid import uuid4
from pydantic import BaseModel
from backend.util import json
from .transcript import STRIPPABLE_TYPES
logger = logging.getLogger(__name__)
class TranscriptEntry(BaseModel):
"""Single transcript entry (user or assistant turn)."""
type: str
uuid: str
parentUuid: str | None
message: dict[str, Any]
class TranscriptBuilder:
"""Build complete JSONL transcript from SDK messages.
This builder maintains the FULL conversation state, not incremental changes.
The output is always the complete active context.
"""
def __init__(self) -> None:
self._entries: list[TranscriptEntry] = []
self._last_uuid: str | None = None
def _last_is_assistant(self) -> bool:
return bool(self._entries) and self._entries[-1].type == "assistant"
def _last_message_id(self) -> str:
"""Return the message.id of the last entry, or '' if none."""
if self._entries:
return self._entries[-1].message.get("id", "")
return ""
def load_previous(self, content: str, log_prefix: str = "[Transcript]") -> None:
"""Load complete previous transcript.
This loads the FULL previous context. As new messages come in,
we append to this state. The final output is the complete context
(previous + new), not just the delta.
"""
if not content or not content.strip():
return
lines = content.strip().split("\n")
for line_num, line in enumerate(lines, 1):
if not line.strip():
continue
data = json.loads(line, fallback=None)
if data is None:
logger.warning(
"%s Failed to parse transcript line %d/%d",
log_prefix,
line_num,
len(lines),
)
continue
# Load all non-strippable entries (user/assistant/system/etc.)
# Skip only STRIPPABLE_TYPES to match strip_progress_entries() behavior
entry_type = data.get("type", "")
if entry_type in STRIPPABLE_TYPES:
continue
entry = TranscriptEntry(
type=data["type"],
uuid=data.get("uuid") or str(uuid4()),
parentUuid=data.get("parentUuid"),
message=data.get("message", {}),
)
self._entries.append(entry)
self._last_uuid = entry.uuid
logger.info(
"%s Loaded %d entries from previous transcript (last_uuid=%s)",
log_prefix,
len(self._entries),
self._last_uuid[:12] if self._last_uuid else None,
)
def append_user(self, content: str | list[dict], uuid: str | None = None) -> None:
"""Append a user entry."""
msg_uuid = uuid or str(uuid4())
self._entries.append(
TranscriptEntry(
type="user",
uuid=msg_uuid,
parentUuid=self._last_uuid,
message={"role": "user", "content": content},
)
)
self._last_uuid = msg_uuid
def append_tool_result(self, tool_use_id: str, content: str) -> None:
"""Append a tool result as a user entry (one per tool call)."""
self.append_user(
content=[
{"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
]
)
def append_assistant(
self,
content_blocks: list[dict],
model: str = "",
stop_reason: str | None = None,
) -> None:
"""Append an assistant entry.
Consecutive assistant entries automatically share the same message ID
so the CLI can merge them (thinking → text → tool_use) into a single
API message on ``--resume``. A new ID is assigned whenever an
assistant entry follows a non-assistant entry (user message or tool
result), because that marks the start of a new API response.
"""
message_id = (
self._last_message_id()
if self._last_is_assistant()
else f"msg_sdk_{uuid4().hex[:24]}"
)
msg_uuid = str(uuid4())
self._entries.append(
TranscriptEntry(
type="assistant",
uuid=msg_uuid,
parentUuid=self._last_uuid,
message={
"role": "assistant",
"model": model,
"id": message_id,
"type": "message",
"content": content_blocks,
"stop_reason": stop_reason,
"stop_sequence": None,
},
)
)
self._last_uuid = msg_uuid
def to_jsonl(self) -> str:
"""Export complete context as JSONL.
Consecutive assistant entries are kept separate to match the
native CLI format — the SDK merges them internally on resume.
Returns the FULL conversation state (all entries), not incremental.
This output REPLACES any previous transcript.
"""
if not self._entries:
return ""
lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
return "\n".join(lines) + "\n"
@property
def entry_count(self) -> int:
"""Total number of entries in the complete context."""
return len(self._entries)
@property
def is_empty(self) -> bool:
"""Whether this builder has any entries."""
return len(self._entries) == 0

View File

@@ -1,11 +1,11 @@
"""Unit tests for JSONL transcript management utilities."""
import json
import os
from backend.util import json
from .transcript import (
STRIPPABLE_TYPES,
read_transcript_file,
strip_progress_entries,
validate_transcript,
write_transcript_to_tempfile,
@@ -38,6 +38,49 @@ PROGRESS_ENTRY = {
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
# --- read_transcript_file ---
class TestReadTranscriptFile:
def test_returns_content_for_valid_file(self, tmp_path):
path = tmp_path / "session.jsonl"
path.write_text(VALID_TRANSCRIPT)
result = read_transcript_file(str(path))
assert result is not None
assert "user" in result
def test_returns_none_for_missing_file(self):
assert read_transcript_file("/nonexistent/path.jsonl") is None
def test_returns_none_for_empty_path(self):
assert read_transcript_file("") is None
def test_returns_none_for_empty_file(self, tmp_path):
path = tmp_path / "empty.jsonl"
path.write_text("")
assert read_transcript_file(str(path)) is None
def test_returns_none_for_metadata_only(self, tmp_path):
content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
path = tmp_path / "meta.jsonl"
path.write_text(content)
assert read_transcript_file(str(path)) is None
def test_returns_none_for_invalid_json(self, tmp_path):
path = tmp_path / "bad.jsonl"
path.write_text("not json\n{}\n{}\n")
assert read_transcript_file(str(path)) is None
def test_no_size_limit(self, tmp_path):
"""Large files are accepted — bucket storage has no size limit."""
big_content = {"type": "user", "uuid": "u9", "data": "x" * 1_000_000}
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, big_content, ASST_MSG)
path = tmp_path / "big.jsonl"
path.write_text(content)
result = read_transcript_file(str(path))
assert result is not None
# --- write_transcript_to_tempfile ---
@@ -112,56 +155,12 @@ class TestValidateTranscript:
assert validate_transcript(content) is False
def test_assistant_only_no_user(self):
"""With --resume the user message is a CLI query param, not a transcript entry.
A transcript with only assistant entries is valid."""
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, ASST_MSG)
assert validate_transcript(content) is True
def test_resume_transcript_without_user_entry(self):
"""Simulates a real --resume stop hook transcript: the CLI session file
has summary + assistant entries but no user entry."""
summary = {"type": "summary", "uuid": "s1", "text": "context..."}
asst1 = {
"type": "assistant",
"uuid": "a1",
"message": {"role": "assistant", "content": "Hello!"},
}
asst2 = {
"type": "assistant",
"uuid": "a2",
"parentUuid": "a1",
"message": {"role": "assistant", "content": "Sure, let me help."},
}
content = _make_jsonl(summary, asst1, asst2)
assert validate_transcript(content) is True
def test_single_assistant_entry(self):
"""A transcript with just one assistant line is valid — the CLI may
produce short transcripts for simple responses with no tool use."""
content = json.dumps(ASST_MSG) + "\n"
assert validate_transcript(content) is True
assert validate_transcript(content) is False
def test_invalid_json_returns_false(self):
assert validate_transcript("not json\n{}\n{}\n") is False
def test_malformed_json_after_valid_assistant_returns_false(self):
"""Validation must scan all lines - malformed JSON anywhere should fail."""
valid_asst = json.dumps(ASST_MSG)
malformed = "not valid json"
content = valid_asst + "\n" + malformed + "\n"
assert validate_transcript(content) is False
def test_blank_lines_are_skipped(self):
"""Transcripts with blank lines should be valid if they contain assistant entries."""
content = (
json.dumps(USER_MSG)
+ "\n\n" # blank line
+ json.dumps(ASST_MSG)
+ "\n"
+ "\n" # another blank line
)
assert validate_transcript(content) is True
# --- strip_progress_entries ---
@@ -254,31 +253,3 @@ class TestStripProgressEntries:
assert "queue-operation" not in result_types
assert "user" in result_types
assert "assistant" in result_types
def test_preserves_original_line_formatting(self):
"""Non-reparented entries keep their original JSON formatting."""
# orjson produces compact JSON - test that we preserve the exact input
# when no reparenting is needed (no re-serialization)
original_line = json.dumps(USER_MSG)
content = original_line + "\n" + json.dumps(ASST_MSG) + "\n"
result = strip_progress_entries(content)
result_lines = result.strip().split("\n")
# Original line should be byte-identical (not re-serialized)
assert result_lines[0] == original_line
def test_reparented_entries_are_reserialized(self):
"""Entries whose parentUuid changes must be re-serialized."""
progress = {"type": "progress", "uuid": "p1", "parentUuid": "u1"}
asst = {
"type": "assistant",
"uuid": "a1",
"parentUuid": "p1",
"message": {"role": "assistant", "content": "done"},
}
content = _make_jsonl(USER_MSG, progress, asst)
result = strip_progress_entries(content)
lines = result.strip().split("\n")
asst_entry = json.loads(lines[-1])
assert asst_entry["parentUuid"] == "u1" # reparented

File diff suppressed because it is too large Load Diff

View File

@@ -4,14 +4,75 @@ from os import getenv
import pytest
from . import service as chat_service
from .model import create_chat_session, get_chat_session, upsert_chat_session
from .response_model import StreamError, StreamTextDelta
from .response_model import StreamError, StreamTextDelta, StreamToolOutputAvailable
from .sdk import service as sdk_service
from .sdk.transcript import download_transcript
logger = logging.getLogger(__name__)
@pytest.mark.asyncio(loop_scope="session")
async def test_stream_chat_completion(setup_test_user, test_user_id):
"""
Test the stream_chat_completion function.
"""
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
has_errors = False
assistant_message = ""
async for chunk in chat_service.stream_chat_completion(
session.session_id, "Hello, how are you?", user_id=session.user_id
):
logger.info(chunk)
if isinstance(chunk, StreamError):
has_errors = True
if isinstance(chunk, StreamTextDelta):
assistant_message += chunk.delta
# StreamFinish is published by mark_session_completed (processor layer),
# not by the service. The generator completing means the stream ended.
assert not has_errors, "Error occurred while streaming chat completion"
assert assistant_message, "Assistant message is empty"
@pytest.mark.asyncio(loop_scope="session")
async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
"""
Test the stream_chat_completion function.
"""
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
has_errors = False
had_tool_calls = False
async for chunk in chat_service.stream_chat_completion(
session.session_id,
"Please find me an agent that can help me with my business. Use the query 'moneny printing agent'",
user_id=session.user_id,
):
logger.info(chunk)
if isinstance(chunk, StreamError):
has_errors = True
if isinstance(chunk, StreamToolOutputAvailable):
had_tool_calls = True
assert not has_errors, "Error occurred while streaming chat completion"
assert had_tool_calls, "Tool calls did not occur"
session = await get_chat_session(session.session_id)
assert session, "Session not found"
assert session.usage, "Usage is empty"
@pytest.mark.asyncio(loop_scope="session")
async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
"""Test that the SDK --resume path captures and uses transcripts across turns.

View File

@@ -733,10 +733,7 @@ async def mark_session_completed(
# This is the SINGLE place that publishes StreamFinish — services and
# the processor must NOT publish it themselves.
try:
await publish_chunk(
turn_id,
StreamFinish(),
)
await publish_chunk(turn_id, StreamFinish())
except Exception as e:
logger.error(
f"Failed to publish StreamFinish for session {session_id}: {e}. "

View File

@@ -1,14 +1,12 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from openai.types.chat import ChatCompletionToolParam
from backend.copilot.model import ChatSession
from backend.copilot.tracking import track_tool_called
from .add_understanding import AddUnderstandingTool
from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreenshotTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
@@ -19,23 +17,10 @@ from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsToo
from .find_agent import FindAgentTool
from .find_block import FindBlockTool
from .find_library_agent import FindLibraryAgentTool
from .fix_agent import FixAgentGraphTool
from .get_agent_building_guide import GetAgentBuildingGuideTool
from .get_doc_page import GetDocPageTool
from .get_mcp_guide import GetMCPGuideTool
from .manage_folders import (
CreateFolderTool,
DeleteFolderTool,
ListFoldersTool,
MoveAgentsToFolderTool,
MoveFolderTool,
UpdateFolderTool,
)
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .run_mcp_tool import RunMCPToolTool
from .search_docs import SearchDocsTool
from .validate_agent import ValidateAgentGraphTool
from .web_fetch import WebFetchTool
from .workspace_files import (
DeleteWorkspaceFileTool,
@@ -45,7 +30,6 @@ from .workspace_files import (
)
if TYPE_CHECKING:
from backend.copilot.model import ChatSession
from backend.copilot.response_model import StreamToolOutputAvailable
logger = logging.getLogger(__name__)
@@ -59,36 +43,19 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"find_agent": FindAgentTool(),
"find_block": FindBlockTool(),
"find_library_agent": FindLibraryAgentTool(),
# Folder management tools
"create_folder": CreateFolderTool(),
"list_folders": ListFoldersTool(),
"update_folder": UpdateFolderTool(),
"move_folder": MoveFolderTool(),
"delete_folder": DeleteFolderTool(),
"move_agents_to_folder": MoveAgentsToFolderTool(),
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"run_mcp_tool": RunMCPToolTool(),
"get_mcp_guide": GetMCPGuideTool(),
"view_agent_output": AgentOutputTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
"get_agent_building_guide": GetAgentBuildingGuideTool(),
# Web fetch for safe URL retrieval
"web_fetch": WebFetchTool(),
# Agent-browser multi-step automation (navigate, act, screenshot)
"browser_navigate": BrowserNavigateTool(),
"browser_act": BrowserActTool(),
"browser_screenshot": BrowserScreenshotTool(),
# Sandboxed code execution (bubblewrap)
"bash_exec": BashExecTool(),
# Persistent workspace tools (cloud storage, survives across sessions)
# Feature request tools
"search_feature_requests": SearchFeatureRequestsTool(),
"create_feature_request": CreateFeatureRequestTool(),
# Agent generation tools (local validation/fixing)
"validate_agent_graph": ValidateAgentGraphTool(),
"fix_agent_graph": FixAgentGraphTool(),
# Workspace tools for CoPilot file operations
"list_workspace_files": ListWorkspaceFilesTool(),
"read_workspace_file": ReadWorkspaceFileTool(),
@@ -100,17 +67,10 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
find_agent_tool = TOOL_REGISTRY["find_agent"]
run_agent_tool = TOOL_REGISTRY["run_agent"]
def get_available_tools() -> list[ChatCompletionToolParam]:
"""Return OpenAI tool schemas for tools available in the current environment.
Called per-request so that env-var or binary availability is evaluated
fresh each time (e.g. browser_* tools are excluded when agent-browser
CLI is not installed).
"""
return [
tool.as_openai_tool() for tool in TOOL_REGISTRY.values() if tool.is_available
]
# Generated from registry for OpenAI API
tools: list[ChatCompletionToolParam] = [
tool.as_openai_tool() for tool in TOOL_REGISTRY.values()
]
def get_tool(tool_name: str) -> BaseTool | None:

View File

@@ -151,8 +151,8 @@ async def setup_test_data(server):
unique_slug = f"test-agent-{str(uuid.uuid4())[:8]}"
store_submission = await store_db.create_store_submission(
user_id=user.id,
graph_id=created_graph.id,
graph_version=created_graph.version,
agent_id=created_graph.id,
agent_version=created_graph.version,
slug=unique_slug,
name="Test Agent",
description="A simple test agent",
@@ -161,10 +161,10 @@ async def setup_test_data(server):
image_urls=["https://example.com/image.jpg"],
)
assert store_submission.listing_version_id is not None
assert store_submission.store_listing_version_id is not None
# 4. Approve the store listing version
await store_db.review_store_submission(
store_listing_version_id=store_submission.listing_version_id,
store_listing_version_id=store_submission.store_listing_version_id,
is_approved=True,
external_comments="Approved for testing",
internal_comments="Test approval",
@@ -321,8 +321,8 @@ async def setup_llm_test_data(server):
unique_slug = f"llm-test-agent-{str(uuid.uuid4())[:8]}"
store_submission = await store_db.create_store_submission(
user_id=user.id,
graph_id=created_graph.id,
graph_version=created_graph.version,
agent_id=created_graph.id,
agent_version=created_graph.version,
slug=unique_slug,
name="LLM Test Agent",
description="An agent with LLM capabilities",
@@ -330,9 +330,9 @@ async def setup_llm_test_data(server):
categories=["testing", "ai"],
image_urls=["https://example.com/image.jpg"],
)
assert store_submission.listing_version_id is not None
assert store_submission.store_listing_version_id is not None
await store_db.review_store_submission(
store_listing_version_id=store_submission.listing_version_id,
store_listing_version_id=store_submission.store_listing_version_id,
is_approved=True,
external_comments="Approved for testing",
internal_comments="Test approval for LLM agent",
@@ -476,8 +476,8 @@ async def setup_firecrawl_test_data(server):
unique_slug = f"firecrawl-test-agent-{str(uuid.uuid4())[:8]}"
store_submission = await store_db.create_store_submission(
user_id=user.id,
graph_id=created_graph.id,
graph_version=created_graph.version,
agent_id=created_graph.id,
agent_version=created_graph.version,
slug=unique_slug,
name="Firecrawl Test Agent",
description="An agent with Firecrawl integration (no credentials)",
@@ -485,9 +485,9 @@ async def setup_firecrawl_test_data(server):
categories=["testing", "scraping"],
image_urls=["https://example.com/image.jpg"],
)
assert store_submission.listing_version_id is not None
assert store_submission.store_listing_version_id is not None
await store_db.review_store_submission(
store_listing_version_id=store_submission.listing_version_id,
store_listing_version_id=store_submission.store_listing_version_id,
is_approved=True,
external_comments="Approved for testing",
internal_comments="Test approval for Firecrawl agent",

View File

@@ -1,876 +0,0 @@
"""Agent-browser tools — multi-step browser automation for the Copilot.
Uses the agent-browser CLI (https://github.com/vercel-labs/agent-browser)
which runs a local Chromium instance managed by a persistent daemon.
- Runs locally — no cloud account required
- Full interaction support: click, fill, scroll, login flows, multi-step
- Session persistence via --session-name: cookies/auth carry across tool calls
within the same Copilot session, enabling login → navigate → extract workflows
- Screenshot with --annotate overlays @ref labels, saved to workspace for user
- The Claude Agent SDK's multi-turn loop handles orchestration — each tool call
is one browser action; the LLM chains them naturally
SSRF protection:
Uses the shared validate_url() from backend.util.request, which is the same
guard used by HTTP blocks and web_fetch. It resolves ALL DNS answers (not just
the first), blocks RFC 1918, loopback, link-local, 0.0.0.0/8, multicast,
and all relevant IPv6 ranges, and applies IDNA encoding to prevent Unicode
domain attacks.
Requires:
npm install -g agent-browser
agent-browser install (downloads Chromium, one-time per machine)
"""
import asyncio
import base64
import json
import logging
import os
import shutil
import tempfile
from typing import Any
from backend.copilot.model import ChatSession
from backend.util.request import validate_url_host
from .base import BaseTool
from .models import (
BrowserActResponse,
BrowserNavigateResponse,
BrowserScreenshotResponse,
ErrorResponse,
ToolResponseBase,
)
from .workspace_files import get_manager
logger = logging.getLogger(__name__)
# Per-command timeout (seconds). Navigation + networkidle wait can be slow.
_CMD_TIMEOUT = 45
# Accessibility tree can be very large; cap it to keep LLM context manageable.
_MAX_SNAPSHOT_CHARS = 20_000
# ---------------------------------------------------------------------------
# Subprocess helper
# ---------------------------------------------------------------------------
async def _run(
session_name: str,
*args: str,
timeout: int = _CMD_TIMEOUT,
) -> tuple[int, str, str]:
"""Run agent-browser for the given session and return (rc, stdout, stderr).
Uses both:
--session <name> → isolated Chromium context (no shared history/cookies
with other Copilot sessions — prevents cross-session
browser state leakage)
--session-name <name> → persist cookies/localStorage across tool calls within
the same session (enables login → navigate flows)
"""
cmd = [
"agent-browser",
"--session",
session_name,
"--session-name",
session_name,
*args,
]
proc = None
try:
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout)
return proc.returncode or 0, stdout.decode(), stderr.decode()
except asyncio.TimeoutError:
# Kill the orphaned subprocess so it does not linger in the process table.
if proc is not None and proc.returncode is None:
proc.kill()
try:
await proc.communicate()
except Exception:
pass # Best-effort reap; ignore errors during cleanup.
return 1, "", f"Command timed out after {timeout}s."
except FileNotFoundError:
return (
1,
"",
"agent-browser is not installed (run: npm install -g agent-browser && agent-browser install).",
)
async def _snapshot(session_name: str) -> str:
"""Return the current page's interactive accessibility tree, truncated."""
rc, stdout, stderr = await _run(session_name, "snapshot", "-i", "-c")
if rc != 0:
return f"[snapshot failed: {stderr[:300]}]"
text = stdout.strip()
if len(text) > _MAX_SNAPSHOT_CHARS:
suffix = "\n\n[Snapshot truncated — use browser_act to navigate further]"
keep = max(0, _MAX_SNAPSHOT_CHARS - len(suffix))
text = text[:keep] + suffix
return text
# ---------------------------------------------------------------------------
# Stateless session helpers — persist / restore browser state across pods
# ---------------------------------------------------------------------------
# Module-level cache of sessions known to be alive on this pod.
# Avoids the subprocess probe on every tool call within the same pod.
_alive_sessions: set[str] = set()
# Per-session locks to prevent concurrent _ensure_session calls from
# triggering duplicate _restore_browser_state for the same session.
# Protected by _session_locks_mutex to ensure setdefault/pop are not
# interleaved across await boundaries.
_session_locks: dict[str, asyncio.Lock] = {}
_session_locks_mutex = asyncio.Lock()
# Workspace filename for persisted browser state (auto-scoped to session).
# Dot-prefixed so it is hidden from user workspace listings.
_STATE_FILENAME = "._browser_state.json"
# Maximum concurrent subprocesses during cookie/storage restore.
_RESTORE_CONCURRENCY = 10
# Maximum cookies to restore per session. Pathological sites can accumulate
# thousands of cookies; restoring them all would be slow and is rarely useful.
_MAX_RESTORE_COOKIES = 100
# Background tasks for fire-and-forget state persistence.
# Prevents GC from collecting tasks before they complete.
_background_tasks: set[asyncio.Task] = set()
def _fire_and_forget_save(
session_name: str, user_id: str, session: ChatSession
) -> None:
"""Schedule state persistence as a background task (non-blocking).
State save is already best-effort (errors are swallowed), so running it
in the background avoids adding latency to tool responses.
"""
task = asyncio.create_task(_save_browser_state(session_name, user_id, session))
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
async def _has_local_session(session_name: str) -> bool:
"""Check if the local agent-browser daemon for this session is running."""
rc, _, _ = await _run(session_name, "get", "url", timeout=5)
return rc == 0
async def _save_browser_state(
session_name: str, user_id: str, session: ChatSession
) -> None:
"""Persist browser state (cookies, localStorage, URL) to workspace.
Best-effort: errors are logged but never propagate to the tool response.
"""
try:
# Gather state in parallel
(rc_url, url_out, _), (rc_ck, ck_out, _), (rc_ls, ls_out, _) = (
await asyncio.gather(
_run(session_name, "get", "url", timeout=10),
_run(session_name, "cookies", "get", "--json", timeout=10),
_run(session_name, "storage", "local", "--json", timeout=10),
)
)
state = {
"url": url_out.strip() if rc_url == 0 else "",
"cookies": (json.loads(ck_out) if rc_ck == 0 and ck_out.strip() else []),
"local_storage": (
json.loads(ls_out) if rc_ls == 0 and ls_out.strip() else {}
),
}
manager = await get_manager(user_id, session.session_id)
await manager.write_file(
content=json.dumps(state).encode("utf-8"),
filename=_STATE_FILENAME,
mime_type="application/json",
overwrite=True,
)
except Exception:
logger.warning(
"[browser] Failed to save browser state for session %s",
session_name,
exc_info=True,
)
async def _restore_browser_state(
session_name: str, user_id: str, session: ChatSession
) -> bool:
"""Restore browser state from workspace storage into a fresh daemon.
Best-effort: errors are logged but never propagate to the tool response.
Returns True on success (or no state to restore), False on failure.
"""
try:
manager = await get_manager(user_id, session.session_id)
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
if file_info is None:
return True # No saved state — first call or never saved
state_bytes = await manager.read_file(_STATE_FILENAME)
state = json.loads(state_bytes.decode("utf-8"))
url = state.get("url", "")
cookies = state.get("cookies", [])
local_storage = state.get("local_storage", {})
# Navigate first — starts daemon + sets the correct origin for cookies
if url:
# Validate the saved URL to prevent SSRF via stored redirect targets.
try:
await validate_url_host(url)
except ValueError:
logger.warning(
"[browser] State restore: blocked SSRF URL %s", url[:200]
)
return False
rc, _, stderr = await _run(session_name, "open", url)
if rc != 0:
logger.warning(
"[browser] State restore: failed to open %s: %s",
url,
stderr[:200],
)
return False
await _run(session_name, "wait", "--load", "load", timeout=15)
# Restore cookies and localStorage in parallel via asyncio.gather.
# Semaphore caps concurrent subprocess spawns so we don't overwhelm the
# system when a session has hundreds of cookies.
sem = asyncio.Semaphore(_RESTORE_CONCURRENCY)
# Guard against pathological sites with thousands of cookies.
if len(cookies) > _MAX_RESTORE_COOKIES:
logger.debug(
"[browser] State restore: capping cookies from %d to %d",
len(cookies),
_MAX_RESTORE_COOKIES,
)
cookies = cookies[:_MAX_RESTORE_COOKIES]
async def _set_cookie(c: dict[str, Any]) -> None:
name = c.get("name", "")
value = c.get("value", "")
domain = c.get("domain", "")
path = c.get("path", "/")
if not (name and domain):
return
async with sem:
rc, _, stderr = await _run(
session_name,
"cookies",
"set",
name,
value,
"--domain",
domain,
"--path",
path,
timeout=5,
)
if rc != 0:
logger.debug(
"[browser] State restore: cookie set failed for %s: %s",
name,
stderr[:100],
)
async def _set_storage(key: str, val: object) -> None:
async with sem:
rc, _, stderr = await _run(
session_name,
"storage",
"local",
"set",
key,
str(val),
timeout=5,
)
if rc != 0:
logger.debug(
"[browser] State restore: localStorage set failed for %s: %s",
key,
stderr[:100],
)
await asyncio.gather(
*[_set_cookie(c) for c in cookies],
*[_set_storage(k, v) for k, v in local_storage.items()],
)
return True
except Exception:
logger.warning(
"[browser] Failed to restore browser state for session %s",
session_name,
exc_info=True,
)
return False
async def _ensure_session(
session_name: str, user_id: str, session: ChatSession
) -> None:
"""Ensure the local browser daemon has state. Restore from cloud if needed."""
if session_name in _alive_sessions:
return
async with _session_locks_mutex:
lock = _session_locks.setdefault(session_name, asyncio.Lock())
async with lock:
# Double-check after acquiring lock — another coroutine may have restored.
if session_name in _alive_sessions:
return
if await _has_local_session(session_name):
_alive_sessions.add(session_name)
return
if await _restore_browser_state(session_name, user_id, session):
_alive_sessions.add(session_name)
async def close_browser_session(session_name: str, user_id: str | None = None) -> None:
"""Shut down the local agent-browser daemon and clean up stored state.
Deletes ``._browser_state.json`` from workspace storage so cookies and
other credentials do not linger after the session is deleted.
Best-effort: errors are logged but never raised.
"""
_alive_sessions.discard(session_name)
async with _session_locks_mutex:
_session_locks.pop(session_name, None)
# Delete persisted browser state (cookies, localStorage) from workspace.
if user_id:
try:
manager = await get_manager(user_id, session_name)
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
if file_info is not None:
await manager.delete_file(file_info.id)
except Exception:
logger.debug(
"[browser] Failed to delete state file for session %s",
session_name,
exc_info=True,
)
try:
rc, _, stderr = await _run(session_name, "close", timeout=10)
if rc != 0:
logger.debug(
"[browser] close failed for session %s: %s",
session_name,
stderr[:200],
)
except Exception:
logger.debug(
"[browser] Exception closing browser session %s",
session_name,
exc_info=True,
)
# ---------------------------------------------------------------------------
# Tool: browser_navigate
# ---------------------------------------------------------------------------
class BrowserNavigateTool(BaseTool):
"""Navigate to a URL and return the page's interactive elements.
The browser session persists across tool calls within this Copilot session
(keyed to session_id), so cookies and auth state carry over. This enables
full login flows: navigate to login page → browser_act to fill credentials
→ browser_act to submit → browser_navigate to the target page.
"""
@property
def name(self) -> str:
return "browser_navigate"
@property
def description(self) -> str:
return (
"Navigate to a URL using a real browser. Returns an accessibility "
"tree snapshot listing the page's interactive elements with @ref IDs "
"(e.g. @e3) that can be used with browser_act. "
"Session persists — cookies and login state carry over between calls. "
"Use this (with browser_act) for multi-step interaction: login flows, "
"form filling, button clicks, or anything requiring page interaction. "
"For plain static pages, prefer web_fetch — no browser overhead. "
"For authenticated pages: navigate to the login page first, use browser_act "
"to fill credentials and submit, then navigate to the target page. "
"Note: for slow SPAs, the returned snapshot may reflect a partially-loaded "
"state. If elements seem missing, use browser_act with action='wait' and a "
"CSS selector or millisecond delay, then take a browser_screenshot to verify."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The HTTP/HTTPS URL to navigate to.",
},
"wait_for": {
"type": "string",
"enum": ["networkidle", "load", "domcontentloaded"],
"default": "networkidle",
"description": "When to consider navigation complete. Use 'networkidle' for SPAs (default).",
},
},
"required": ["url"],
}
@property
def requires_auth(self) -> bool:
return True
@property
def is_available(self) -> bool:
return shutil.which("agent-browser") is not None
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
) -> ToolResponseBase:
"""Navigate to *url*, wait for the page to settle, and return a snapshot.
The snapshot is an accessibility-tree listing of interactive elements.
Note: for slow SPAs that never fully idle, the snapshot may reflect a
partially-loaded state (the wait is best-effort).
"""
url: str = (kwargs.get("url") or "").strip()
wait_for: str = kwargs.get("wait_for") or "networkidle"
session_name = session.session_id
if not url:
return ErrorResponse(
message="Please provide a URL to navigate to.",
error="missing_url",
session_id=session_name,
)
try:
await validate_url_host(url)
except ValueError as e:
return ErrorResponse(
message=str(e),
error="blocked_url",
session_id=session_name,
)
# Restore browser state from cloud if this is a different pod
if user_id:
await _ensure_session(session_name, user_id, session)
# Navigate
rc, _, stderr = await _run(session_name, "open", url)
if rc != 0:
logger.warning(
"[browser_navigate] open failed for %s: %s", url, stderr[:300]
)
return ErrorResponse(
message="Failed to navigate to URL.",
error="navigation_failed",
session_id=session_name,
)
# Wait for page to settle (best-effort: some SPAs never reach networkidle)
wait_rc, _, wait_err = await _run(session_name, "wait", "--load", wait_for)
if wait_rc != 0:
logger.warning(
"[browser_navigate] wait(%s) failed: %s", wait_for, wait_err[:300]
)
# Get current title and URL in parallel
(_, title_out, _), (_, url_out, _) = await asyncio.gather(
_run(session_name, "get", "title"),
_run(session_name, "get", "url"),
)
snapshot = await _snapshot(session_name)
result = BrowserNavigateResponse(
message=f"Navigated to {url}",
url=url_out.strip() or url,
title=title_out.strip(),
snapshot=snapshot,
session_id=session_name,
)
# Persist browser state to cloud for cross-pod continuity
if user_id:
_fire_and_forget_save(session_name, user_id, session)
return result
# ---------------------------------------------------------------------------
# Tool: browser_act
# ---------------------------------------------------------------------------
_NO_TARGET_ACTIONS = frozenset({"back", "forward", "reload"})
_SCROLL_ACTIONS = frozenset({"scroll"})
_TARGET_ONLY_ACTIONS = frozenset({"click", "dblclick", "hover", "check", "uncheck"})
_TARGET_VALUE_ACTIONS = frozenset({"fill", "type", "select"})
# wait <selector|ms>: waits for a DOM element or a fixed delay (e.g. "1000" for 1 s)
_WAIT_ACTIONS = frozenset({"wait"})
class BrowserActTool(BaseTool):
"""Perform an action on the current browser page and return the updated snapshot.
Use @ref IDs from the snapshot returned by browser_navigate (e.g. '@e3').
The LLM orchestrates multi-step flows by chaining browser_navigate and
browser_act calls across turns of the Claude Agent SDK conversation.
"""
@property
def name(self) -> str:
return "browser_act"
@property
def description(self) -> str:
return (
"Interact with the current browser page. Use @ref IDs from the "
"snapshot (e.g. '@e3') to target elements. Returns an updated snapshot. "
"Supported actions: click, dblclick, fill, type, scroll, hover, press, "
"check, uncheck, select, wait, back, forward, reload. "
"fill clears the field before typing; type appends without clearing. "
"wait accepts a CSS selector (waits for element) or milliseconds string (e.g. '1000'). "
"Example login flow: fill @e1 with email → fill @e2 with password → "
"click @e3 (submit) → browser_navigate to the target page."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"click",
"dblclick",
"fill",
"type",
"scroll",
"hover",
"press",
"check",
"uncheck",
"select",
"wait",
"back",
"forward",
"reload",
],
"description": "The action to perform.",
},
"target": {
"type": "string",
"description": (
"Element to target. Use @ref from snapshot (e.g. '@e3'), "
"a CSS selector, or a text description. "
"Required for: click, dblclick, fill, type, hover, check, uncheck, select. "
"For wait: a CSS selector to wait for, or milliseconds as a string (e.g. '1000')."
),
},
"value": {
"type": "string",
"description": (
"For fill/type: the text to enter. "
"For press: key name (e.g. 'Enter', 'Tab', 'Control+a'). "
"For select: the option value to select."
),
},
"direction": {
"type": "string",
"enum": ["up", "down", "left", "right"],
"default": "down",
"description": "For scroll: direction to scroll.",
},
},
"required": ["action"],
}
@property
def requires_auth(self) -> bool:
return True
@property
def is_available(self) -> bool:
return shutil.which("agent-browser") is not None
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
) -> ToolResponseBase:
"""Perform a browser action and return an updated page snapshot.
Validates the *action*/*target*/*value* combination, delegates to
``agent-browser``, waits for the page to settle, and returns the
accessibility-tree snapshot so the LLM can plan the next step.
"""
action: str = (kwargs.get("action") or "").strip()
target: str = (kwargs.get("target") or "").strip()
value: str = (kwargs.get("value") or "").strip()
direction: str = (kwargs.get("direction") or "down").strip()
session_name = session.session_id
if not action:
return ErrorResponse(
message="Please specify an action.",
error="missing_action",
session_id=session_name,
)
# Build the agent-browser command args
if action in _NO_TARGET_ACTIONS:
cmd_args = [action]
elif action in _SCROLL_ACTIONS:
cmd_args = ["scroll", direction]
elif action == "press":
if not value:
return ErrorResponse(
message="'press' requires a 'value' (key name, e.g. 'Enter').",
error="missing_value",
session_id=session_name,
)
cmd_args = ["press", value]
elif action in _TARGET_ONLY_ACTIONS:
if not target:
return ErrorResponse(
message=f"'{action}' requires a 'target' element.",
error="missing_target",
session_id=session_name,
)
cmd_args = [action, target]
elif action in _TARGET_VALUE_ACTIONS:
if not target or not value:
return ErrorResponse(
message=f"'{action}' requires both 'target' and 'value'.",
error="missing_params",
session_id=session_name,
)
cmd_args = [action, target, value]
elif action in _WAIT_ACTIONS:
if not target:
return ErrorResponse(
message=(
"'wait' requires a 'target': a CSS selector to wait for, "
"or milliseconds as a string (e.g. '1000')."
),
error="missing_target",
session_id=session_name,
)
cmd_args = ["wait", target]
else:
return ErrorResponse(
message=f"Unsupported action: {action}",
error="invalid_action",
session_id=session_name,
)
# Restore browser state from cloud if this is a different pod
if user_id:
await _ensure_session(session_name, user_id, session)
rc, _, stderr = await _run(session_name, *cmd_args)
if rc != 0:
logger.warning("[browser_act] %s failed: %s", action, stderr[:300])
return ErrorResponse(
message=f"Action '{action}' failed.",
error="action_failed",
session_id=session_name,
)
# Allow the page to settle after interaction (best-effort: SPAs may not idle)
settle_rc, _, settle_err = await _run(
session_name, "wait", "--load", "networkidle"
)
if settle_rc != 0:
logger.warning(
"[browser_act] post-action wait failed: %s", settle_err[:300]
)
snapshot = await _snapshot(session_name)
_, url_out, _ = await _run(session_name, "get", "url")
result = BrowserActResponse(
message=f"Performed '{action}'" + (f" on '{target}'" if target else ""),
action=action,
current_url=url_out.strip(),
snapshot=snapshot,
session_id=session_name,
)
# Persist browser state to cloud for cross-pod continuity
if user_id:
_fire_and_forget_save(session_name, user_id, session)
return result
# ---------------------------------------------------------------------------
# Tool: browser_screenshot
# ---------------------------------------------------------------------------
class BrowserScreenshotTool(BaseTool):
"""Capture a screenshot of the current browser page and save it to the workspace."""
@property
def name(self) -> str:
return "browser_screenshot"
@property
def description(self) -> str:
return (
"Take a screenshot of the current browser page and save it to the workspace. "
"IMPORTANT: After calling this tool, immediately call read_workspace_file "
"with the returned file_id to display the image inline to the user — "
"the screenshot is not visible until you do this. "
"With annotate=true (default), @ref labels are overlaid on interactive "
"elements, making it easy to see which @ref ID maps to which element on screen."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"annotate": {
"type": "boolean",
"default": True,
"description": "Overlay @ref labels on interactive elements (default: true).",
},
"filename": {
"type": "string",
"default": "screenshot.png",
"description": "Filename to save in the workspace.",
},
},
}
@property
def requires_auth(self) -> bool:
return True
@property
def is_available(self) -> bool:
return shutil.which("agent-browser") is not None
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
) -> ToolResponseBase:
"""Capture a PNG screenshot and upload it to the workspace.
Handles string-to-bool coercion for *annotate* (OpenAI function-call
payloads sometimes deliver ``"true"``/``"false"`` as strings).
Returns a :class:`BrowserScreenshotResponse` with the workspace
``file_id`` the LLM should pass to ``read_workspace_file``.
"""
raw_annotate = kwargs.get("annotate", True)
if isinstance(raw_annotate, str):
annotate = raw_annotate.strip().lower() in {"1", "true", "yes", "on"}
else:
annotate = bool(raw_annotate)
filename: str = (kwargs.get("filename") or "screenshot.png").strip()
session_name = session.session_id
# Restore browser state from cloud if this is a different pod
if user_id:
await _ensure_session(session_name, user_id, session)
tmp_fd, tmp_path = tempfile.mkstemp(suffix=".png")
os.close(tmp_fd)
try:
cmd_args = ["screenshot"]
if annotate:
cmd_args.append("--annotate")
cmd_args.append(tmp_path)
rc, _, stderr = await _run(session_name, *cmd_args)
if rc != 0:
logger.warning("[browser_screenshot] failed: %s", stderr[:300])
return ErrorResponse(
message="Failed to take screenshot.",
error="screenshot_failed",
session_id=session_name,
)
with open(tmp_path, "rb") as f:
png_bytes = f.read()
finally:
try:
os.unlink(tmp_path)
except OSError:
pass # Best-effort temp file cleanup; not critical if it fails.
# Upload to workspace so the user can view it
png_b64 = base64.b64encode(png_bytes).decode()
# Import here to avoid circular deps — workspace_files imports from .models
from .workspace_files import WorkspaceWriteResponse, WriteWorkspaceFileTool
write_resp = await WriteWorkspaceFileTool()._execute(
user_id=user_id,
session=session,
filename=filename,
content_base64=png_b64,
)
if not isinstance(write_resp, WorkspaceWriteResponse):
return ErrorResponse(
message="Screenshot taken but failed to save to workspace.",
error="workspace_write_failed",
session_id=session_name,
)
result = BrowserScreenshotResponse(
message=f"Screenshot saved to workspace as '{filename}'. Use read_workspace_file with file_id='{write_resp.file_id}' to retrieve it.",
file_id=write_resp.file_id,
filename=filename,
session_id=session_name,
)
# Persist browser state to cloud for cross-pod continuity
if user_id:
_fire_and_forget_save(session_name, user_id, session)
return result

View File

@@ -1,15 +1,20 @@
"""Agent generator package - Creates agents from natural language."""
from .core import (
AgentGeneratorNotConfiguredError,
AgentJsonValidationError,
AgentSummary,
DecompositionResult,
DecompositionStep,
LibraryAgentSummary,
MarketplaceAgentSummary,
customize_template,
decompose_goal,
enrich_library_agents_from_steps,
extract_search_terms_from_steps,
extract_uuids_from_text,
generate_agent,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_library_agent_by_graph_id,
@@ -22,20 +27,25 @@ from .core import (
search_marketplace_agents_for_generation,
)
from .errors import get_user_message_for_error
from .validation import AgentFixer, AgentValidator
from .service import health_check as check_external_service_health
from .service import is_external_service_configured
__all__ = [
"AgentFixer",
"AgentValidator",
"AgentGeneratorNotConfiguredError",
"AgentJsonValidationError",
"AgentSummary",
"DecompositionResult",
"DecompositionStep",
"LibraryAgentSummary",
"MarketplaceAgentSummary",
"check_external_service_health",
"customize_template",
"decompose_goal",
"enrich_library_agents_from_steps",
"extract_search_terms_from_steps",
"extract_uuids_from_text",
"generate_agent",
"generate_agent_patch",
"get_agent_as_json",
"get_all_relevant_agents_for_generation",
"get_library_agent_by_graph_id",
@@ -44,6 +54,7 @@ __all__ = [
"get_library_agents_for_generation",
"get_user_message_for_error",
"graph_to_json",
"is_external_service_configured",
"json_to_graph",
"save_agent_to_library",
"search_marketplace_agents_for_generation",

View File

@@ -1,66 +0,0 @@
"""Block management for agent generation.
Provides cached access to block metadata for validation and fixing.
"""
import logging
from typing import Any, Type
from backend.blocks import get_blocks as get_block_classes
from backend.blocks._base import Block
logger = logging.getLogger(__name__)
__all__ = ["get_blocks_as_dicts", "reset_block_caches"]
# ---------------------------------------------------------------------------
# Module-level caches
# ---------------------------------------------------------------------------
_blocks_cache: list[dict[str, Any]] | None = None
def reset_block_caches() -> None:
"""Reset all module-level caches (useful after updating block descriptions)."""
global _blocks_cache
_blocks_cache = None
# ---------------------------------------------------------------------------
# 1. get_blocks_as_dicts
# ---------------------------------------------------------------------------
def get_blocks_as_dicts() -> list[dict[str, Any]]:
"""Get all available blocks as dicts (cached after first call).
Each dict contains the keys returned by ``Block.get_info().model_dump()``:
id, name, description, inputSchema, outputSchema, categories,
staticOutput, costs, contributors, uiType.
Returns:
List of block info dicts.
"""
global _blocks_cache
if _blocks_cache is not None:
return _blocks_cache
block_classes: dict[str, Type[Block]] = get_block_classes() # type: ignore[assignment]
blocks: list[dict[str, Any]] = []
for block_cls in block_classes.values():
try:
instance = block_cls()
info = instance.get_info().model_dump()
# Use optimized description if available (loaded at startup)
if instance.optimized_description:
info["description"] = instance.optimized_description
blocks.append(info)
except Exception:
logger.warning(
"Failed to load block info for %s, skipping",
getattr(block_cls, "__name__", "unknown"),
exc_info=True,
)
_blocks_cache = blocks
logger.info("Cached %d block dicts", len(blocks))
return _blocks_cache

View File

@@ -10,7 +10,13 @@ from backend.data.db_accessors import graph_db, library_db, store_db
from backend.data.graph import Graph, Link, Node
from backend.util.exceptions import DatabaseError, NotFoundError
from .helpers import UUID_RE_STR
from .service import (
customize_template_external,
decompose_goal_external,
generate_agent_external,
generate_agent_patch_external,
is_external_service_configured,
)
logger = logging.getLogger(__name__)
@@ -72,7 +78,38 @@ class DecompositionResult(TypedDict, total=False):
AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
_UUID_PATTERN = re.compile(UUID_RE_STR, re.IGNORECASE)
def _to_dict_list(
agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None,
) -> list[dict[str, Any]] | None:
"""Convert typed agent summaries to plain dicts for external service calls."""
if agents is None:
return None
return [dict(a) for a in agents]
class AgentGeneratorNotConfiguredError(Exception):
"""Raised when the external Agent Generator service is not configured."""
pass
def _check_service_configured() -> None:
"""Check if the external Agent Generator service is configured.
Raises:
AgentGeneratorNotConfiguredError: If the service is not configured.
"""
if not is_external_service_configured():
raise AgentGeneratorNotConfiguredError(
"Agent Generator service is not configured. "
"Set AGENTGENERATOR_HOST environment variable to enable agent generation."
)
_UUID_PATTERN = re.compile(
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
re.IGNORECASE,
)
def extract_uuids_from_text(text: str) -> list[str]:
@@ -516,6 +553,69 @@ async def enrich_library_agents_from_steps(
return all_agents
async def decompose_goal(
description: str,
context: str = "",
library_agents: Sequence[AgentSummary] | None = None,
) -> DecompositionResult | None:
"""Break down a goal into steps or return clarifying questions.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
library_agents: User's library agents available for sub-agent composition
Returns:
DecompositionResult with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
Or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for decompose_goal")
result = await decompose_goal_external(
description, context, _to_dict_list(library_agents)
)
return result # type: ignore[return-value]
async def generate_agent(
instructions: DecompositionResult | dict[str, Any],
library_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
Returns:
Agent JSON dict, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(
dict(instructions), _to_dict_list(library_agents)
)
if result:
if isinstance(result, dict) and result.get("type") == "error":
return result
if "id" not in result:
result["id"] = str(uuid.uuid4())
if "version" not in result:
result["version"] = 1
if "is_active" not in result:
result["is_active"] = True
return result
class AgentJsonValidationError(Exception):
"""Raised when agent JSON is invalid or missing required fields."""
@@ -595,10 +695,7 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
async def save_agent_to_library(
agent_json: dict[str, Any],
user_id: str,
is_update: bool = False,
folder_id: str | None = None,
agent_json: dict[str, Any], user_id: str, is_update: bool = False
) -> tuple[Graph, Any]:
"""Save agent to database and user's library.
@@ -606,7 +703,6 @@ async def save_agent_to_library(
agent_json: Agent JSON dict
user_id: User ID
is_update: Whether this is an update to an existing agent
folder_id: Optional folder ID to place the agent in
Returns:
Tuple of (created Graph, LibraryAgent)
@@ -615,7 +711,7 @@ async def save_agent_to_library(
db = library_db()
if is_update:
return await db.update_graph_in_library(graph, user_id)
return await db.create_graph_in_library(graph, user_id, folder_id=folder_id)
return await db.create_graph_in_library(graph, user_id)
def graph_to_json(graph: Graph) -> dict[str, Any]:
@@ -692,3 +788,70 @@ async def get_agent_as_json(
return None
return graph_to_json(graph)
async def generate_agent_patch(
update_request: str,
current_agent: dict[str, Any],
library_agents: Sequence[AgentSummary] | None = None,
) -> dict[str, Any] | None:
"""Update an existing agent using natural language.
The external Agent Generator service handles:
- Generating the patch
- Applying the patch
- Fixing and validating the result
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
Returns:
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent_patch")
return await generate_agent_patch_external(
update_request,
current_agent,
_to_dict_list(library_agents),
)
async def customize_template(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any] | None:
"""Customize a template/marketplace agent using natural language.
This is used when users want to modify a template or marketplace agent
to fit their specific needs before adding it to their library.
The external Agent Generator service handles:
- Understanding the modification request
- Applying changes to the template
- Fixing and validating the result
Args:
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
Returns:
Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on unexpected error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for customize_template")
return await customize_template_external(
template_agent, modification_request, context
)

View File

@@ -0,0 +1,165 @@
"""Dummy Agent Generator for testing.
Returns mock responses matching the format expected from the external service.
Enable via AGENTGENERATOR_USE_DUMMY=true in settings.
WARNING: This is for testing only. Do not use in production.
"""
import asyncio
import logging
import uuid
from typing import Any
logger = logging.getLogger(__name__)
# Dummy decomposition result (instructions type)
DUMMY_DECOMPOSITION_RESULT: dict[str, Any] = {
"type": "instructions",
"steps": [
{
"description": "Get input from user",
"action": "input",
"block_name": "AgentInputBlock",
},
{
"description": "Process the input",
"action": "process",
"block_name": "TextFormatterBlock",
},
{
"description": "Return output to user",
"action": "output",
"block_name": "AgentOutputBlock",
},
],
}
# Block IDs from backend/blocks/io.py
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
def _generate_dummy_agent_json() -> dict[str, Any]:
"""Generate a minimal valid agent JSON for testing."""
input_node_id = str(uuid.uuid4())
output_node_id = str(uuid.uuid4())
return {
"id": str(uuid.uuid4()),
"version": 1,
"is_active": True,
"name": "Dummy Test Agent",
"description": "A dummy agent generated for testing purposes",
"nodes": [
{
"id": input_node_id,
"block_id": AGENT_INPUT_BLOCK_ID,
"input_default": {
"name": "input",
"title": "Input",
"description": "Enter your input",
"placeholder_values": [],
},
"metadata": {"position": {"x": 0, "y": 0}},
},
{
"id": output_node_id,
"block_id": AGENT_OUTPUT_BLOCK_ID,
"input_default": {
"name": "output",
"title": "Output",
"description": "Agent output",
"format": "{output}",
},
"metadata": {"position": {"x": 400, "y": 0}},
},
],
"links": [
{
"id": str(uuid.uuid4()),
"source_id": input_node_id,
"sink_id": output_node_id,
"source_name": "result",
"sink_name": "value",
"is_static": False,
},
],
}
async def decompose_goal_dummy(
description: str,
context: str = "",
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
"""Return dummy decomposition result."""
logger.info("Using dummy agent generator for decompose_goal")
return DUMMY_DECOMPOSITION_RESULT.copy()
async def generate_agent_dummy(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy agent synchronously (blocks for 30s, returns agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator (sync mode): returning agent JSON after 30s"
)
await asyncio.sleep(30)
return _generate_dummy_agent_json()
async def generate_agent_patch_dummy(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy patched agent synchronously (blocks for 30s, returns patched agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator patch (sync mode): returning patched agent after 30s"
)
await asyncio.sleep(30)
patched = current_agent.copy()
patched["description"] = (
f"{current_agent.get('description', '')} (updated: {update_request})"
)
return patched
async def customize_template_dummy(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any]:
"""Return dummy customized template (returns template with updated description)."""
logger.info("Using dummy agent generator for customize_template")
customized = template_agent.copy()
customized["description"] = (
f"{template_agent.get('description', '')} (customized: {modification_request})"
)
return customized
async def get_blocks_dummy() -> list[dict[str, Any]]:
"""Return dummy blocks list."""
logger.info("Using dummy agent generator for get_blocks")
return [
{"id": AGENT_INPUT_BLOCK_ID, "name": "AgentInputBlock"},
{"id": AGENT_OUTPUT_BLOCK_ID, "name": "AgentOutputBlock"},
]
async def health_check_dummy() -> bool:
"""Always returns healthy for dummy service."""
return True

Some files were not shown because too many files have changed in this diff Show More