mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
51 Commits
feat/brows
...
swiftyos/i
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
af7e8d19fe | ||
|
|
eadc68f2a5 | ||
|
|
eca7b5e793 | ||
|
|
c304a4937a | ||
|
|
8cfabcf4fd | ||
|
|
7bf407b66c | ||
|
|
7ead4c040f | ||
|
|
0f813f1bf9 | ||
|
|
aa08063939 | ||
|
|
bde6a4c0df | ||
|
|
d56452898a | ||
|
|
7507240177 | ||
|
|
d7c3f5b8fc | ||
|
|
3e108a813a | ||
|
|
08c49a78f8 | ||
|
|
5d56548e6b | ||
|
|
6ecf55d214 | ||
|
|
7c8c7bf395 | ||
|
|
0b9e0665dd | ||
|
|
be18436e8f | ||
|
|
f6f268a1f0 | ||
|
|
ea0333c1fc | ||
|
|
21c705af6e | ||
|
|
a576be9db2 | ||
|
|
5e90585f10 | ||
|
|
3e22a0e786 | ||
|
|
6abe39b33a | ||
|
|
476cf1c601 | ||
|
|
25022f2d1e | ||
|
|
ce1675cfc7 | ||
|
|
3d0ede9f34 | ||
|
|
5474f7c495 | ||
|
|
f1b771b7ee | ||
|
|
aa7a2f0a48 | ||
|
|
3722d05b9b | ||
|
|
592830ce9b | ||
|
|
6cc680f71c | ||
|
|
b342bfa3ba | ||
|
|
0215332386 | ||
|
|
160d6eddfb | ||
|
|
a5db9c05d0 | ||
|
|
b74d41d50c | ||
|
|
a897f9e124 | ||
|
|
7fd26d3554 | ||
|
|
b504cf9854 | ||
|
|
29da8db48e | ||
|
|
757ec1f064 | ||
|
|
9442c648a4 | ||
|
|
1c51dd18aa | ||
|
|
6f4f80871d | ||
|
|
e8cca6cd9a |
2
.github/workflows/platform-frontend-ci.yml
vendored
2
.github/workflows/platform-frontend-ci.yml
vendored
@@ -149,7 +149,7 @@ jobs:
|
||||
driver-opts: network=host
|
||||
|
||||
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
||||
uses: crazy-max/ghaction-github-runtime@v3
|
||||
uses: crazy-max/ghaction-github-runtime@v4
|
||||
|
||||
- name: Set up Platform - Build Docker images (with cache)
|
||||
working-directory: autogpt_platform
|
||||
|
||||
3
autogpt_platform/.gitignore
vendored
3
autogpt_platform/.gitignore
vendored
@@ -1,2 +1,3 @@
|
||||
*.ignore.*
|
||||
*.ign.*
|
||||
*.ign.*
|
||||
.application.logs
|
||||
|
||||
@@ -95,7 +95,7 @@ ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
|
||||
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
|
||||
# for the bash_exec MCP tool.
|
||||
# for the bash_exec MCP tool (fallback when E2B is not configured).
|
||||
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.13 \
|
||||
@@ -111,13 +111,29 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
# Copy poetry (build-time only, for `poetry install --only-root` to create entry points)
|
||||
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
||||
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
|
||||
# Copy Node.js installation for Prisma
|
||||
# Copy Node.js installation for Prisma and agent-browser.
|
||||
# npm/npx are symlinks in the builder (-> ../lib/node_modules/npm/bin/*-cli.js);
|
||||
# COPY resolves them to regular files, breaking require() paths. Recreate as
|
||||
# proper symlinks so npm/npx can find their modules.
|
||||
COPY --from=builder /usr/bin/node /usr/bin/node
|
||||
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
||||
COPY --from=builder /usr/bin/npm /usr/bin/npm
|
||||
COPY --from=builder /usr/bin/npx /usr/bin/npx
|
||||
RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
|
||||
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
# Install agent-browser (Copilot browser tool) + Chromium runtime dependencies.
|
||||
# These are the runtime libraries Chromium/Playwright needs on Debian 13 (trixie).
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
|
||||
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
|
||||
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
|
||||
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
|
||||
fonts-liberation libfontconfig1 \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& npm install -g agent-browser \
|
||||
&& agent-browser install \
|
||||
&& rm -rf /tmp/* /root/.npm
|
||||
|
||||
WORKDIR /app/autogpt_platform/backend
|
||||
|
||||
# Copy only the .venv from builder (not the entire /app directory)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import urllib.parse
|
||||
from collections import defaultdict
|
||||
from typing import Annotated, Any, Literal, Optional, Sequence
|
||||
from typing import Annotated, Any, Optional, Sequence
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Security
|
||||
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
||||
@@ -9,9 +9,10 @@ from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import backend.api.features.store.cache as store_cache
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
import backend.blocks
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.external.middleware import require_auth, require_permission
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data import user as user_db
|
||||
@@ -230,13 +231,13 @@ async def get_graph_execution_results(
|
||||
@v1_router.get(
|
||||
path="/store/agents",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
response_model=store_model.StoreAgentsResponse,
|
||||
)
|
||||
async def get_store_agents(
|
||||
featured: bool = False,
|
||||
creator: str | None = None,
|
||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
|
||||
sorted_by: store_db.StoreAgentsSortOptions | None = None,
|
||||
search_query: str | None = None,
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
@@ -278,7 +279,7 @@ async def get_store_agents(
|
||||
@v1_router.get(
|
||||
path="/store/agents/{username}/{agent_name}",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
response_model=store_model.StoreAgentDetails,
|
||||
)
|
||||
async def get_store_agent(
|
||||
@@ -306,13 +307,13 @@ async def get_store_agent(
|
||||
@v1_router.get(
|
||||
path="/store/creators",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
response_model=store_model.CreatorsResponse,
|
||||
)
|
||||
async def get_store_creators(
|
||||
featured: bool = False,
|
||||
search_query: str | None = None,
|
||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
|
||||
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> store_model.CreatorsResponse:
|
||||
@@ -348,7 +349,7 @@ async def get_store_creators(
|
||||
@v1_router.get(
|
||||
path="/store/creators/{username}",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
response_model=store_model.CreatorDetails,
|
||||
)
|
||||
async def get_store_creator(
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import prisma.enums
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
from backend.data.model import UserTransaction
|
||||
from backend.util.models import Pagination
|
||||
@@ -14,3 +18,42 @@ class UserHistoryResponse(BaseModel):
|
||||
class AddUserCreditsResponse(BaseModel):
|
||||
new_balance: int
|
||||
transaction_key: str
|
||||
|
||||
|
||||
class CreateInvitedUserRequest(BaseModel):
|
||||
email: EmailStr
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class InvitedUserResponse(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
status: prisma.enums.InvitedUserStatus
|
||||
auth_user_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tally_understanding: Optional[dict[str, Any]] = None
|
||||
tally_status: prisma.enums.TallyComputationStatus
|
||||
tally_computed_at: Optional[datetime] = None
|
||||
tally_error: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class InvitedUsersResponse(BaseModel):
|
||||
invited_users: list[InvitedUserResponse]
|
||||
|
||||
|
||||
class BulkInvitedUserRowResponse(BaseModel):
|
||||
row_number: int
|
||||
email: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
status: Literal["CREATED", "SKIPPED", "ERROR"]
|
||||
message: str
|
||||
invited_user: Optional[InvitedUserResponse] = None
|
||||
|
||||
|
||||
class BulkInvitedUsersResponse(BaseModel):
|
||||
created_count: int
|
||||
skipped_count: int
|
||||
error_count: int
|
||||
results: list[BulkInvitedUserRowResponse]
|
||||
|
||||
@@ -24,14 +24,13 @@ router = fastapi.APIRouter(
|
||||
@router.get(
|
||||
"/listings",
|
||||
summary="Get Admin Listings History",
|
||||
response_model=store_model.StoreListingsWithVersionsResponse,
|
||||
)
|
||||
async def get_admin_listings_with_versions(
|
||||
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
|
||||
search: typing.Optional[str] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
):
|
||||
) -> store_model.StoreListingsWithVersionsAdminViewResponse:
|
||||
"""
|
||||
Get store listings with their version history for admins.
|
||||
|
||||
@@ -45,36 +44,26 @@ async def get_admin_listings_with_versions(
|
||||
page_size: Number of items per page
|
||||
|
||||
Returns:
|
||||
StoreListingsWithVersionsResponse with listings and their versions
|
||||
Paginated listings with their versions
|
||||
"""
|
||||
try:
|
||||
listings = await store_db.get_admin_listings_with_versions(
|
||||
status=status,
|
||||
search_query=search,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return listings
|
||||
except Exception as e:
|
||||
logger.exception("Error getting admin listings with versions: %s", e)
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"detail": "An error occurred while retrieving listings with versions"
|
||||
},
|
||||
)
|
||||
listings = await store_db.get_admin_listings_with_versions(
|
||||
status=status,
|
||||
search_query=search,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return listings
|
||||
|
||||
|
||||
@router.post(
|
||||
"/submissions/{store_listing_version_id}/review",
|
||||
summary="Review Store Submission",
|
||||
response_model=store_model.StoreSubmission,
|
||||
)
|
||||
async def review_submission(
|
||||
store_listing_version_id: str,
|
||||
request: store_model.ReviewSubmissionRequest,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
) -> store_model.StoreSubmissionAdminView:
|
||||
"""
|
||||
Review a store listing submission.
|
||||
|
||||
@@ -84,31 +73,24 @@ async def review_submission(
|
||||
user_id: Authenticated admin user performing the review
|
||||
|
||||
Returns:
|
||||
StoreSubmission with updated review information
|
||||
StoreSubmissionAdminView with updated review information
|
||||
"""
|
||||
try:
|
||||
already_approved = await store_db.check_submission_already_approved(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
submission = await store_db.review_store_submission(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
is_approved=request.is_approved,
|
||||
external_comments=request.comments,
|
||||
internal_comments=request.internal_comments or "",
|
||||
reviewer_id=user_id,
|
||||
)
|
||||
already_approved = await store_db.check_submission_already_approved(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
submission = await store_db.review_store_submission(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
is_approved=request.is_approved,
|
||||
external_comments=request.comments,
|
||||
internal_comments=request.internal_comments or "",
|
||||
reviewer_id=user_id,
|
||||
)
|
||||
|
||||
state_changed = already_approved != request.is_approved
|
||||
# Clear caches when the request is approved as it updates what is shown on the store
|
||||
if state_changed:
|
||||
store_cache.clear_all_caches()
|
||||
return submission
|
||||
except Exception as e:
|
||||
logger.exception("Error reviewing submission: %s", e)
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "An error occurred while reviewing the submission"},
|
||||
)
|
||||
state_changed = already_approved != request.is_approved
|
||||
# Clear caches whenever approval state changes, since store visibility can change
|
||||
if state_changed:
|
||||
store_cache.clear_all_caches()
|
||||
return submission
|
||||
|
||||
|
||||
@router.get(
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
import logging
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, File, Security, UploadFile
|
||||
|
||||
from backend.data.invited_user import (
|
||||
BulkInvitedUsersResult,
|
||||
InvitedUserRecord,
|
||||
bulk_create_invited_users_from_file,
|
||||
create_invited_user,
|
||||
list_invited_users,
|
||||
retry_invited_user_tally,
|
||||
revoke_invited_user,
|
||||
)
|
||||
|
||||
from .model import (
|
||||
BulkInvitedUserRowResponse,
|
||||
BulkInvitedUsersResponse,
|
||||
CreateInvitedUserRequest,
|
||||
InvitedUserResponse,
|
||||
InvitedUsersResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
tags=["users", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
def _to_response(invited_user: InvitedUserRecord) -> InvitedUserResponse:
|
||||
return InvitedUserResponse(**invited_user.model_dump())
|
||||
|
||||
|
||||
def _to_bulk_response(result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
|
||||
return BulkInvitedUsersResponse(
|
||||
created_count=result.created_count,
|
||||
skipped_count=result.skipped_count,
|
||||
error_count=result.error_count,
|
||||
results=[
|
||||
BulkInvitedUserRowResponse(
|
||||
row_number=row.row_number,
|
||||
email=row.email,
|
||||
name=row.name,
|
||||
status=row.status,
|
||||
message=row.message,
|
||||
invited_user=(
|
||||
_to_response(row.invited_user)
|
||||
if row.invited_user is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
for row in result.results
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/invited-users",
|
||||
response_model=InvitedUsersResponse,
|
||||
summary="List Invited Users",
|
||||
)
|
||||
async def get_invited_users(
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUsersResponse:
|
||||
logger.info("Admin user %s requested invited users", admin_user_id)
|
||||
invited_users = await list_invited_users()
|
||||
return InvitedUsersResponse(
|
||||
invited_users=[_to_response(invited_user) for invited_user in invited_users]
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Create Invited User",
|
||||
)
|
||||
async def create_invited_user_route(
|
||||
request: CreateInvitedUserRequest,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s created invited user for %s",
|
||||
admin_user_id,
|
||||
request.email,
|
||||
)
|
||||
invited_user = await create_invited_user(request.email, request.name)
|
||||
return _to_response(invited_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/bulk",
|
||||
response_model=BulkInvitedUsersResponse,
|
||||
summary="Bulk Create Invited Users",
|
||||
operation_id="postV2BulkCreateInvitedUsers",
|
||||
)
|
||||
async def bulk_create_invited_users_route(
|
||||
file: UploadFile = File(...),
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> BulkInvitedUsersResponse:
|
||||
logger.info(
|
||||
"Admin user %s bulk invited users from %s",
|
||||
admin_user_id,
|
||||
file.filename or "<unnamed>",
|
||||
)
|
||||
content = await file.read()
|
||||
result = await bulk_create_invited_users_from_file(file.filename, content)
|
||||
return _to_bulk_response(result)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/{invited_user_id}/revoke",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Revoke Invited User",
|
||||
)
|
||||
async def revoke_invited_user_route(
|
||||
invited_user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
|
||||
invited_user = await revoke_invited_user(invited_user_id)
|
||||
return _to_response(invited_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/{invited_user_id}/retry-tally",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Retry Invited User Tally",
|
||||
)
|
||||
async def retry_invited_user_tally_route(
|
||||
invited_user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s retried Tally seed for invited user %s",
|
||||
admin_user_id,
|
||||
invited_user_id,
|
||||
)
|
||||
invited_user = await retry_invited_user_tally(invited_user_id)
|
||||
return _to_response(invited_user)
|
||||
@@ -0,0 +1,165 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import prisma.enums
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.data.invited_user import (
|
||||
BulkInvitedUserRowResult,
|
||||
BulkInvitedUsersResult,
|
||||
InvitedUserRecord,
|
||||
)
|
||||
|
||||
from .user_admin_routes import router as user_admin_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(user_admin_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _sample_invited_user() -> InvitedUserRecord:
|
||||
now = datetime.now(timezone.utc)
|
||||
return InvitedUserRecord(
|
||||
id="invite-1",
|
||||
email="invited@example.com",
|
||||
status=prisma.enums.InvitedUserStatus.INVITED,
|
||||
auth_user_id=None,
|
||||
name="Invited User",
|
||||
tally_understanding=None,
|
||||
tally_status=prisma.enums.TallyComputationStatus.PENDING,
|
||||
tally_computed_at=None,
|
||||
tally_error=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
|
||||
return BulkInvitedUsersResult(
|
||||
created_count=1,
|
||||
skipped_count=1,
|
||||
error_count=0,
|
||||
results=[
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=1,
|
||||
email="invited@example.com",
|
||||
name=None,
|
||||
status="CREATED",
|
||||
message="Invite created",
|
||||
invited_user=_sample_invited_user(),
|
||||
),
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=2,
|
||||
email="duplicate@example.com",
|
||||
name=None,
|
||||
status="SKIPPED",
|
||||
message="An invited user with this email already exists",
|
||||
invited_user=None,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_get_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.list_invited_users",
|
||||
AsyncMock(return_value=[_sample_invited_user()]),
|
||||
)
|
||||
|
||||
response = client.get("/admin/invited-users")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["invited_users"]) == 1
|
||||
assert data["invited_users"][0]["email"] == "invited@example.com"
|
||||
assert data["invited_users"][0]["status"] == "INVITED"
|
||||
|
||||
|
||||
def test_create_invited_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.create_invited_user",
|
||||
AsyncMock(return_value=_sample_invited_user()),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/invited-users",
|
||||
json={"email": "invited@example.com", "name": "Invited User"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == "invited@example.com"
|
||||
assert data["name"] == "Invited User"
|
||||
|
||||
|
||||
def test_bulk_create_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
|
||||
AsyncMock(return_value=_sample_bulk_invited_users_result()),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/invited-users/bulk",
|
||||
files={
|
||||
"file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["created_count"] == 1
|
||||
assert data["skipped_count"] == 1
|
||||
assert data["results"][0]["status"] == "CREATED"
|
||||
assert data["results"][1]["status"] == "SKIPPED"
|
||||
|
||||
|
||||
def test_revoke_invited_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
revoked = _sample_invited_user().model_copy(
|
||||
update={"status": prisma.enums.InvitedUserStatus.REVOKED}
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.revoke_invited_user",
|
||||
AsyncMock(return_value=revoked),
|
||||
)
|
||||
|
||||
response = client.post("/admin/invited-users/invite-1/revoke")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "REVOKED"
|
||||
|
||||
|
||||
def test_retry_invited_user_tally(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
retried = _sample_invited_user().model_copy(
|
||||
update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
|
||||
AsyncMock(return_value=retried),
|
||||
)
|
||||
|
||||
response = client.post("/admin/invited-users/invite-1/retry-tally")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["tally_status"] == "RUNNING"
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
from uuid import uuid4
|
||||
@@ -9,7 +10,8 @@ from uuid import uuid4
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from prisma.models import UserWorkspaceFile
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
@@ -23,6 +25,7 @@ from backend.copilot.model import (
|
||||
delete_chat_session,
|
||||
get_chat_session,
|
||||
get_user_sessions,
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.models import (
|
||||
@@ -40,6 +43,8 @@ from backend.copilot.tools.models import (
|
||||
ErrorResponse,
|
||||
ExecutionStartedResponse,
|
||||
InputValidationErrorResponse,
|
||||
MCPToolOutputResponse,
|
||||
MCPToolsDiscoveredResponse,
|
||||
NeedLoginResponse,
|
||||
NoResultsResponse,
|
||||
SetupRequirementsResponse,
|
||||
@@ -47,10 +52,14 @@ from backend.copilot.tools.models import (
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -79,6 +88,9 @@ class StreamChatRequest(BaseModel):
|
||||
message: str
|
||||
is_user_message: bool = True
|
||||
context: dict[str, str] | None = None # {url: str, content: str}
|
||||
file_ids: list[str] | None = Field(
|
||||
default=None, max_length=20
|
||||
) # Workspace file IDs attached to this message
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
@@ -130,6 +142,20 @@ class CancelSessionResponse(BaseModel):
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class UpdateSessionTitleRequest(BaseModel):
|
||||
"""Request model for updating a session's title."""
|
||||
|
||||
title: str
|
||||
|
||||
@field_validator("title")
|
||||
@classmethod
|
||||
def title_must_not_be_blank(cls, v: str) -> str:
|
||||
stripped = v.strip()
|
||||
if not stripped:
|
||||
raise ValueError("Title must not be blank")
|
||||
return stripped
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@@ -238,9 +264,58 @@ async def delete_session(
|
||||
detail=f"Session {session_id} not found or access denied",
|
||||
)
|
||||
|
||||
# Best-effort cleanup of the E2B sandbox (if any).
|
||||
config = ChatConfig()
|
||||
if config.use_e2b_sandbox and config.e2b_api_key:
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
|
||||
try:
|
||||
await kill_sandbox(session_id, config.e2b_api_key)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[E2B] Failed to kill sandbox for session %s", session_id[:12]
|
||||
)
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/sessions/{session_id}/title",
|
||||
summary="Update session title",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
status_code=200,
|
||||
responses={404: {"description": "Session not found or access denied"}},
|
||||
)
|
||||
async def update_session_title_route(
|
||||
session_id: str,
|
||||
request: UpdateSessionTitleRequest,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> dict:
|
||||
"""
|
||||
Update the title of a chat session.
|
||||
|
||||
Allows the user to rename their chat session.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to update.
|
||||
request: Request body containing the new title.
|
||||
user_id: The authenticated user's ID.
|
||||
|
||||
Returns:
|
||||
dict: Status of the update.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if session not found or not owned by user.
|
||||
"""
|
||||
success = await update_session_title(session_id, user_id, request.title)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Session {session_id} not found or access denied",
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}",
|
||||
)
|
||||
@@ -394,6 +469,38 @@ async def stream_chat_post(
|
||||
},
|
||||
)
|
||||
|
||||
# Enrich message with file metadata if file_ids are provided.
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
sanitized_file_ids: list[str] | None = None
|
||||
if request.file_ids and user_id:
|
||||
# Filter to valid UUIDs only to prevent DB abuse
|
||||
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
|
||||
|
||||
if valid_ids:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Batch query instead of N+1
|
||||
files = await UserWorkspaceFile.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": valid_ids},
|
||||
"workspaceId": workspace.id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
)
|
||||
# Only keep IDs that actually exist in the user's workspace
|
||||
sanitized_file_ids = [wf.id for wf in files] or None
|
||||
file_lines: list[str] = [
|
||||
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
|
||||
for wf in files
|
||||
]
|
||||
if file_lines:
|
||||
files_block = (
|
||||
"\n\n[Attached files]\n"
|
||||
+ "\n".join(file_lines)
|
||||
+ "\nUse read_workspace_file with the file_id to access file contents."
|
||||
)
|
||||
request.message += files_block
|
||||
|
||||
# Atomically append user message to session BEFORE creating task to avoid
|
||||
# race condition where GET_SESSION sees task as "running" but message isn't
|
||||
# saved yet. append_and_save_message re-fetches inside a lock to prevent
|
||||
@@ -445,6 +552,7 @@ async def stream_chat_post(
|
||||
turn_id=turn_id,
|
||||
is_user_message=request.is_user_message,
|
||||
context=request.context,
|
||||
file_ids=sanitized_file_ids,
|
||||
)
|
||||
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
@@ -487,7 +595,7 @@ async def stream_chat_post(
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
|
||||
chunks_yielded += 1
|
||||
|
||||
if not first_chunk_yielded:
|
||||
@@ -640,7 +748,7 @@ async def resume_session_stream(
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Resume stream chunk",
|
||||
@@ -697,7 +805,6 @@ async def resume_session_stream(
|
||||
@router.patch(
|
||||
"/sessions/{session_id}/assign-user",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
status_code=200,
|
||||
)
|
||||
async def session_assign_user(
|
||||
session_id: str,
|
||||
@@ -800,6 +907,8 @@ ToolResponseUnion = (
|
||||
| BlockOutputResponse
|
||||
| DocSearchResultsResponse
|
||||
| DocPageResponse
|
||||
| MCPToolsDiscoveredResponse
|
||||
| MCPToolOutputResponse
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,251 @@
|
||||
"""Tests for chat API routes: session title update and file attachment validation."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.api.features.chat import routes as chat_routes
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(chat_routes.router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _mock_update_session_title(
|
||||
mocker: pytest_mock.MockerFixture, *, success: bool = True
|
||||
):
|
||||
"""Mock update_session_title."""
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.update_session_title",
|
||||
new_callable=AsyncMock,
|
||||
return_value=success,
|
||||
)
|
||||
|
||||
|
||||
# ─── Update title: success ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_update_title_success(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_update = _mock_update_session_title(mocker, success=True)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": "My project"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
mock_update.assert_called_once_with("sess-1", test_user_id, "My project")
|
||||
|
||||
|
||||
def test_update_title_trims_whitespace(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_update = _mock_update_session_title(mocker, success=True)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": " trimmed "},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_update.assert_called_once_with("sess-1", test_user_id, "trimmed")
|
||||
|
||||
|
||||
# ─── Update title: blank / whitespace-only → 422 ──────────────────────
|
||||
|
||||
|
||||
def test_update_title_blank_rejected(
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Whitespace-only titles must be rejected before hitting the DB."""
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": " "},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_update_title_empty_rejected(
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": ""},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ─── Update title: session not found or wrong user → 404 ──────────────
|
||||
|
||||
|
||||
def test_update_title_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
_mock_update_session_title(mocker, success=False)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": "New name"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ─── file_ids Pydantic validation ─────────────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_rejects_too_many_file_ids():
|
||||
"""More than 20 file_ids should be rejected by Pydantic validation (422)."""
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={
|
||||
"message": "hello",
|
||||
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(21)],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def _mock_stream_internals(mocker: pytest_mock.MockFixture):
|
||||
"""Mock the async internals of stream_chat_post so tests can exercise
|
||||
validation and enrichment logic without needing Redis/RabbitMQ."""
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes._validate_and_get_session",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.append_and_save_message",
|
||||
return_value=None,
|
||||
)
|
||||
mock_registry = mocker.MagicMock()
|
||||
mock_registry.create_session = mocker.AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry",
|
||||
mock_registry,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.enqueue_copilot_turn",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.track_user_message",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
|
||||
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
|
||||
"""Exactly 20 file_ids should be accepted (not rejected by validation)."""
|
||||
_mock_stream_internals(mocker)
|
||||
# Patch workspace lookup as imported by the routes module
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={
|
||||
"message": "hello",
|
||||
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(20)],
|
||||
},
|
||||
)
|
||||
# Should get past validation — 200 streaming response expected
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ─── UUID format filtering ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
"""Non-UUID strings in file_ids should be silently filtered out
|
||||
and NOT passed to the database query."""
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
valid_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={
|
||||
"message": "hello",
|
||||
"file_ids": [
|
||||
valid_id,
|
||||
"not-a-uuid",
|
||||
"../../../etc/passwd",
|
||||
"",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
# The find_many call should only receive the one valid UUID
|
||||
mock_prisma.find_many.assert_called_once()
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["id"]["in"] == [valid_id]
|
||||
|
||||
|
||||
# ─── Cross-workspace file_ids ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
"""The batch query should scope to the user's workspace."""
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "my-workspace-id"})(),
|
||||
)
|
||||
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hi", "file_ids": [fid]},
|
||||
)
|
||||
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
|
||||
assert call_kwargs["where"]["isDeleted"] is False
|
||||
@@ -22,6 +22,7 @@ from backend.data.human_review import (
|
||||
)
|
||||
from backend.data.model import USER_TIMEZONE_NOT_SET
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.executor.utils import add_graph_execution
|
||||
|
||||
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
|
||||
@@ -321,10 +322,13 @@ async def process_review_action(
|
||||
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
||||
)
|
||||
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
execution_context = ExecutionContext(
|
||||
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
|
||||
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
|
||||
user_timezone=user_timezone,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
await add_graph_execution(
|
||||
|
||||
@@ -8,7 +8,6 @@ import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
import backend.api.features.store.exceptions as store_exceptions
|
||||
import backend.api.features.store.image_gen as store_image_gen
|
||||
import backend.api.features.store.media as store_media
|
||||
import backend.data.graph as graph_db
|
||||
@@ -251,7 +250,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
|
||||
The requested LibraryAgent.
|
||||
|
||||
Raises:
|
||||
AgentNotFoundError: If the specified agent does not exist.
|
||||
NotFoundError: If the specified agent does not exist.
|
||||
DatabaseError: If there's an error during retrieval.
|
||||
"""
|
||||
library_agent = await prisma.models.LibraryAgent.prisma().find_first(
|
||||
@@ -398,6 +397,7 @@ async def create_library_agent(
|
||||
hitl_safe_mode: bool = True,
|
||||
sensitive_action_safe_mode: bool = False,
|
||||
create_library_agents_for_sub_graphs: bool = True,
|
||||
folder_id: str | None = None,
|
||||
) -> list[library_model.LibraryAgent]:
|
||||
"""
|
||||
Adds an agent to the user's library (LibraryAgent table).
|
||||
@@ -414,12 +414,18 @@ async def create_library_agent(
|
||||
If the graph has sub-graphs, the parent graph will always be the first entry in the list.
|
||||
|
||||
Raises:
|
||||
AgentNotFoundError: If the specified agent does not exist.
|
||||
NotFoundError: If the specified agent does not exist.
|
||||
DatabaseError: If there's an error during creation or if image generation fails.
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating library agent for graph #{graph.id} v{graph.version}; user:<redacted>"
|
||||
)
|
||||
|
||||
# Authorization: FK only checks existence, not ownership.
|
||||
# Verify the folder belongs to this user to prevent cross-user nesting.
|
||||
if folder_id:
|
||||
await get_folder(folder_id, user_id)
|
||||
|
||||
graph_entries = (
|
||||
[graph, *graph.sub_graphs] if create_library_agents_for_sub_graphs else [graph]
|
||||
)
|
||||
@@ -432,7 +438,6 @@ async def create_library_agent(
|
||||
isCreatedByUser=(user_id == user_id),
|
||||
useGraphIsActiveVersion=True,
|
||||
User={"connect": {"id": user_id}},
|
||||
# Creator={"connect": {"id": user_id}},
|
||||
AgentGraph={
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
@@ -448,6 +453,11 @@ async def create_library_agent(
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
**(
|
||||
{"Folder": {"connect": {"id": folder_id}}}
|
||||
if folder_id and graph_entry is graph
|
||||
else {}
|
||||
),
|
||||
),
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
@@ -529,6 +539,7 @@ async def update_agent_version_in_library(
|
||||
async def create_graph_in_library(
|
||||
graph: graph_db.Graph,
|
||||
user_id: str,
|
||||
folder_id: str | None = None,
|
||||
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
|
||||
"""Create a new graph and add it to the user's library."""
|
||||
graph.version = 1
|
||||
@@ -542,6 +553,7 @@ async def create_graph_in_library(
|
||||
user_id=user_id,
|
||||
sensitive_action_safe_mode=True,
|
||||
create_library_agents_for_sub_graphs=False,
|
||||
folder_id=folder_id,
|
||||
)
|
||||
|
||||
if created_graph.is_active:
|
||||
@@ -817,7 +829,7 @@ async def add_store_agent_to_library(
|
||||
The newly created LibraryAgent if successfully added, the existing corresponding one if any.
|
||||
|
||||
Raises:
|
||||
AgentNotFoundError: If the store listing or associated agent is not found.
|
||||
NotFoundError: If the store listing or associated agent is not found.
|
||||
DatabaseError: If there's an issue creating the LibraryAgent record.
|
||||
"""
|
||||
logger.debug(
|
||||
@@ -832,7 +844,7 @@ async def add_store_agent_to_library(
|
||||
)
|
||||
if not store_listing_version or not store_listing_version.AgentGraph:
|
||||
logger.warning(f"Store listing version not found: {store_listing_version_id}")
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
raise NotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found or invalid"
|
||||
)
|
||||
|
||||
@@ -846,7 +858,7 @@ async def add_store_agent_to_library(
|
||||
include_subgraphs=False,
|
||||
)
|
||||
if not graph_model:
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
raise NotFoundError(
|
||||
f"Graph #{graph.id} v{graph.version} not found or accessible"
|
||||
)
|
||||
|
||||
@@ -1481,6 +1493,67 @@ async def bulk_move_agents_to_folder(
|
||||
return [library_model.LibraryAgent.from_db(agent) for agent in agents]
|
||||
|
||||
|
||||
def collect_tree_ids(
|
||||
nodes: list[library_model.LibraryFolderTree],
|
||||
visited: set[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Collect all folder IDs from a folder tree."""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
ids: list[str] = []
|
||||
for n in nodes:
|
||||
if n.id in visited:
|
||||
continue
|
||||
visited.add(n.id)
|
||||
ids.append(n.id)
|
||||
ids.extend(collect_tree_ids(n.children, visited))
|
||||
return ids
|
||||
|
||||
|
||||
async def get_folder_agent_summaries(
|
||||
user_id: str, folder_id: str
|
||||
) -> list[dict[str, str | None]]:
|
||||
"""Get a lightweight list of agents in a folder (id, name, description)."""
|
||||
all_agents: list[library_model.LibraryAgent] = []
|
||||
for page in itertools.count(1):
|
||||
resp = await list_library_agents(
|
||||
user_id=user_id, folder_id=folder_id, page=page
|
||||
)
|
||||
all_agents.extend(resp.agents)
|
||||
if page >= resp.pagination.total_pages:
|
||||
break
|
||||
return [
|
||||
{"id": a.id, "name": a.name, "description": a.description} for a in all_agents
|
||||
]
|
||||
|
||||
|
||||
async def get_root_agent_summaries(
|
||||
user_id: str,
|
||||
) -> list[dict[str, str | None]]:
|
||||
"""Get a lightweight list of root-level agents (folderId IS NULL)."""
|
||||
all_agents: list[library_model.LibraryAgent] = []
|
||||
for page in itertools.count(1):
|
||||
resp = await list_library_agents(
|
||||
user_id=user_id, include_root_only=True, page=page
|
||||
)
|
||||
all_agents.extend(resp.agents)
|
||||
if page >= resp.pagination.total_pages:
|
||||
break
|
||||
return [
|
||||
{"id": a.id, "name": a.name, "description": a.description} for a in all_agents
|
||||
]
|
||||
|
||||
|
||||
async def get_folder_agents_map(
|
||||
user_id: str, folder_ids: list[str]
|
||||
) -> dict[str, list[dict[str, str | None]]]:
|
||||
"""Get agent summaries for multiple folders concurrently."""
|
||||
results = await asyncio.gather(
|
||||
*(get_folder_agent_summaries(user_id, fid) for fid in folder_ids)
|
||||
)
|
||||
return dict(zip(folder_ids, results))
|
||||
|
||||
|
||||
##############################################
|
||||
########### Presets DB Functions #############
|
||||
##############################################
|
||||
|
||||
@@ -4,7 +4,6 @@ import prisma.enums
|
||||
import prisma.models
|
||||
import pytest
|
||||
|
||||
import backend.api.features.store.exceptions
|
||||
from backend.data.db import connect
|
||||
from backend.data.includes import library_agent_include
|
||||
|
||||
@@ -218,7 +217,7 @@ async def test_add_agent_to_library_not_found(mocker):
|
||||
)
|
||||
|
||||
# Call function and verify exception
|
||||
with pytest.raises(backend.api.features.store.exceptions.AgentNotFoundError):
|
||||
with pytest.raises(db.NotFoundError):
|
||||
await db.add_store_agent_to_library("version123", "test-user")
|
||||
|
||||
# Verify mock called correctly
|
||||
|
||||
@@ -7,20 +7,24 @@ frontend can list available tools on an MCP server before placing a block.
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import fastapi
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from fastapi import Security
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from backend.api.features.integrations.router import CredentialsMetaResponse
|
||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
||||
from backend.blocks.mcp.helpers import (
|
||||
auto_lookup_mcp_credential,
|
||||
normalize_mcp_url,
|
||||
server_host,
|
||||
)
|
||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import HTTPClientError, Requests
|
||||
from backend.util.request import HTTPClientError, Requests, validate_url
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -74,32 +78,20 @@ async def discover_tools(
|
||||
If the user has a stored MCP credential for this server URL, it will be
|
||||
used automatically — no need to pass an explicit auth token.
|
||||
"""
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
auth_token = request.auth_token
|
||||
|
||||
# Auto-use stored MCP credential when no explicit token is provided.
|
||||
if not auth_token:
|
||||
mcp_creds = await creds_manager.store.get_creds_by_provider(
|
||||
user_id, ProviderName.MCP.value
|
||||
best_cred = await auto_lookup_mcp_credential(
|
||||
user_id, normalize_mcp_url(request.server_url)
|
||||
)
|
||||
# Find the freshest credential for this server URL
|
||||
best_cred: OAuth2Credentials | None = None
|
||||
for cred in mcp_creds:
|
||||
if (
|
||||
isinstance(cred, OAuth2Credentials)
|
||||
and (cred.metadata or {}).get("mcp_server_url") == request.server_url
|
||||
):
|
||||
if best_cred is None or (
|
||||
(cred.access_token_expires_at or 0)
|
||||
> (best_cred.access_token_expires_at or 0)
|
||||
):
|
||||
best_cred = cred
|
||||
if best_cred:
|
||||
# Refresh the token if expired before using it
|
||||
best_cred = await creds_manager.refresh_if_needed(user_id, best_cred)
|
||||
logger.info(
|
||||
f"Using MCP credential {best_cred.id} for {request.server_url}, "
|
||||
f"expires_at={best_cred.access_token_expires_at}"
|
||||
)
|
||||
auth_token = best_cred.access_token.get_secret_value()
|
||||
|
||||
client = MCPClient(request.server_url, auth_token=auth_token)
|
||||
@@ -134,7 +126,7 @@ async def discover_tools(
|
||||
],
|
||||
server_name=(
|
||||
init_result.get("serverInfo", {}).get("name")
|
||||
or urlparse(request.server_url).hostname
|
||||
or server_host(request.server_url)
|
||||
or "MCP"
|
||||
),
|
||||
protocol_version=init_result.get("protocolVersion"),
|
||||
@@ -173,7 +165,16 @@ async def mcp_oauth_login(
|
||||
3. Performs Dynamic Client Registration (RFC 7591) if available
|
||||
4. Returns the authorization URL for the frontend to open in a popup
|
||||
"""
|
||||
client = MCPClient(request.server_url)
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
# Normalize the URL so that credentials stored here are matched consistently
|
||||
# by auto_lookup_mcp_credential (which also uses normalized URLs).
|
||||
server_url = normalize_mcp_url(request.server_url)
|
||||
client = MCPClient(server_url)
|
||||
|
||||
# Step 1: Discover protected-resource metadata (RFC 9728)
|
||||
protected_resource = await client.discover_auth()
|
||||
@@ -182,7 +183,16 @@ async def mcp_oauth_login(
|
||||
|
||||
if protected_resource and protected_resource.get("authorization_servers"):
|
||||
auth_server_url = protected_resource["authorization_servers"][0]
|
||||
resource_url = protected_resource.get("resource", request.server_url)
|
||||
resource_url = protected_resource.get("resource", server_url)
|
||||
|
||||
# Validate the auth server URL from metadata to prevent SSRF.
|
||||
try:
|
||||
await validate_url(auth_server_url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid authorization server URL in metadata: {e}",
|
||||
)
|
||||
|
||||
# Step 2a: Discover auth-server metadata (RFC 8414)
|
||||
metadata = await client.discover_auth_server_metadata(auth_server_url)
|
||||
@@ -192,7 +202,7 @@ async def mcp_oauth_login(
|
||||
# Don't assume a resource_url — omitting it lets the auth server choose
|
||||
# the correct audience for the token (RFC 8707 resource is optional).
|
||||
resource_url = None
|
||||
metadata = await client.discover_auth_server_metadata(request.server_url)
|
||||
metadata = await client.discover_auth_server_metadata(server_url)
|
||||
|
||||
if (
|
||||
not metadata
|
||||
@@ -222,12 +232,18 @@ async def mcp_oauth_login(
|
||||
client_id = ""
|
||||
client_secret = ""
|
||||
if registration_endpoint:
|
||||
reg_result = await _register_mcp_client(
|
||||
registration_endpoint, redirect_uri, request.server_url
|
||||
)
|
||||
if reg_result:
|
||||
client_id = reg_result.get("client_id", "")
|
||||
client_secret = reg_result.get("client_secret", "")
|
||||
# Validate the registration endpoint to prevent SSRF via metadata.
|
||||
try:
|
||||
await validate_url(registration_endpoint, trusted_origins=[])
|
||||
except ValueError:
|
||||
pass # Skip registration, fall back to default client_id
|
||||
else:
|
||||
reg_result = await _register_mcp_client(
|
||||
registration_endpoint, redirect_uri, server_url
|
||||
)
|
||||
if reg_result:
|
||||
client_id = reg_result.get("client_id", "")
|
||||
client_secret = reg_result.get("client_secret", "")
|
||||
|
||||
if not client_id:
|
||||
client_id = "autogpt-platform"
|
||||
@@ -245,7 +261,7 @@ async def mcp_oauth_login(
|
||||
"token_url": token_url,
|
||||
"revoke_url": revoke_url,
|
||||
"resource_url": resource_url,
|
||||
"server_url": request.server_url,
|
||||
"server_url": server_url,
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
},
|
||||
@@ -342,7 +358,7 @@ async def mcp_oauth_callback(
|
||||
credentials.metadata["mcp_token_url"] = meta["token_url"]
|
||||
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
|
||||
|
||||
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
|
||||
hostname = server_host(meta["server_url"])
|
||||
credentials.title = f"MCP: {hostname}"
|
||||
|
||||
# Remove old MCP credentials for the same server to prevent stale token buildup.
|
||||
@@ -357,7 +373,9 @@ async def mcp_oauth_callback(
|
||||
):
|
||||
await creds_manager.store.delete_creds_by_id(user_id, old.id)
|
||||
logger.info(
|
||||
f"Removed old MCP credential {old.id} for {meta['server_url']}"
|
||||
"Removed old MCP credential %s for %s",
|
||||
old.id,
|
||||
server_host(meta["server_url"]),
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Could not clean up old MCP credentials", exc_info=True)
|
||||
@@ -375,6 +393,93 @@ async def mcp_oauth_callback(
|
||||
)
|
||||
|
||||
|
||||
# ======================== Bearer Token ======================== #
|
||||
|
||||
|
||||
class MCPStoreTokenRequest(BaseModel):
|
||||
"""Request to store a bearer token for an MCP server that doesn't support OAuth."""
|
||||
|
||||
server_url: str = Field(
|
||||
description="MCP server URL the token authenticates against"
|
||||
)
|
||||
token: SecretStr = Field(
|
||||
min_length=1, description="Bearer token / API key for the MCP server"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/token",
|
||||
summary="Store a bearer token for an MCP server",
|
||||
)
|
||||
async def mcp_store_token(
|
||||
request: MCPStoreTokenRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> CredentialsMetaResponse:
|
||||
"""
|
||||
Store a manually provided bearer token as an MCP credential.
|
||||
|
||||
Used by the Copilot MCPSetupCard when the server doesn't support the MCP
|
||||
OAuth discovery flow (returns 400 from /oauth/login). Subsequent
|
||||
``run_mcp_tool`` calls will automatically pick up the token via
|
||||
``_auto_lookup_credential``.
|
||||
"""
|
||||
token = request.token.get_secret_value().strip()
|
||||
if not token:
|
||||
raise fastapi.HTTPException(status_code=422, detail="Token must not be blank.")
|
||||
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
# Normalize URL so trailing-slash variants match existing credentials.
|
||||
server_url = normalize_mcp_url(request.server_url)
|
||||
hostname = server_host(server_url)
|
||||
|
||||
# Collect IDs of old credentials to clean up after successful create.
|
||||
old_cred_ids: list[str] = []
|
||||
try:
|
||||
old_creds = await creds_manager.store.get_creds_by_provider(
|
||||
user_id, ProviderName.MCP.value
|
||||
)
|
||||
old_cred_ids = [
|
||||
old.id
|
||||
for old in old_creds
|
||||
if isinstance(old, OAuth2Credentials)
|
||||
and normalize_mcp_url((old.metadata or {}).get("mcp_server_url", ""))
|
||||
== server_url
|
||||
]
|
||||
except Exception:
|
||||
logger.debug("Could not query old MCP token credentials", exc_info=True)
|
||||
|
||||
credentials = OAuth2Credentials(
|
||||
provider=ProviderName.MCP.value,
|
||||
title=f"MCP: {hostname}",
|
||||
access_token=SecretStr(token),
|
||||
scopes=[],
|
||||
metadata={"mcp_server_url": server_url},
|
||||
)
|
||||
await creds_manager.create(user_id, credentials)
|
||||
|
||||
# Only delete old credentials after the new one is safely stored.
|
||||
for old_id in old_cred_ids:
|
||||
try:
|
||||
await creds_manager.store.delete_creds_by_id(user_id, old_id)
|
||||
except Exception:
|
||||
logger.debug("Could not clean up old MCP token credential", exc_info=True)
|
||||
|
||||
return CredentialsMetaResponse(
|
||||
id=credentials.id,
|
||||
provider=credentials.provider,
|
||||
type=credentials.type,
|
||||
title=credentials.title,
|
||||
scopes=credentials.scopes,
|
||||
username=credentials.username,
|
||||
host=hostname,
|
||||
)
|
||||
|
||||
|
||||
# ======================== Helpers ======================== #
|
||||
|
||||
|
||||
@@ -400,5 +505,7 @@ async def _register_mcp_client(
|
||||
return data
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
|
||||
logger.warning(
|
||||
"Dynamic client registration failed for %s: %s", server_host(server_url), e
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -11,9 +11,11 @@ import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.api.features.mcp.routes import router
|
||||
from backend.blocks.mcp.client import MCPClientError, MCPTool
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.util.request import HTTPClientError
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
@@ -28,6 +30,16 @@ async def client():
|
||||
yield c
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _bypass_ssrf_validation():
|
||||
"""Bypass validate_url in all route tests (test URLs don't resolve)."""
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestDiscoverTools:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_success(self, client):
|
||||
@@ -56,9 +68,12 @@ class TestDiscoverTools:
|
||||
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
patch(
|
||||
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
return_value={
|
||||
@@ -107,10 +122,6 @@ class TestDiscoverTools:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_auto_uses_stored_credential(self, client):
|
||||
"""When no explicit token is given, stored MCP credentials are used."""
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
stored_cred = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
title="MCP: example.com",
|
||||
@@ -124,10 +135,12 @@ class TestDiscoverTools:
|
||||
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
patch(
|
||||
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=stored_cred,
|
||||
),
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
|
||||
mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred)
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
|
||||
@@ -149,9 +162,12 @@ class TestDiscoverTools:
|
||||
async def test_discover_tools_mcp_error(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
patch(
|
||||
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
side_effect=MCPClientError("Connection refused")
|
||||
@@ -169,9 +185,12 @@ class TestDiscoverTools:
|
||||
async def test_discover_tools_generic_error(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
patch(
|
||||
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
|
||||
|
||||
@@ -187,9 +206,12 @@ class TestDiscoverTools:
|
||||
async def test_discover_tools_auth_required(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
patch(
|
||||
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
|
||||
@@ -207,9 +229,12 @@ class TestDiscoverTools:
|
||||
async def test_discover_tools_forbidden(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
patch(
|
||||
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
|
||||
@@ -331,10 +356,6 @@ class TestOAuthLogin:
|
||||
class TestOAuthCallback:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_callback_success(self, client):
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
mock_creds = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
title=None,
|
||||
@@ -434,3 +455,118 @@ class TestOAuthCallback:
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "token exchange failed" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
class TestStoreToken:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_token_success(self, client):
|
||||
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
mock_cm.create = AsyncMock()
|
||||
|
||||
response = await client.post(
|
||||
"/token",
|
||||
json={
|
||||
"server_url": "https://mcp.example.com/mcp",
|
||||
"token": "my-api-key-123",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["provider"] == "mcp"
|
||||
assert data["type"] == "oauth2"
|
||||
assert data["host"] == "mcp.example.com"
|
||||
mock_cm.create.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_token_blank_rejected(self, client):
|
||||
"""Blank token string (after stripping) should return 422."""
|
||||
response = await client.post(
|
||||
"/token",
|
||||
json={
|
||||
"server_url": "https://mcp.example.com/mcp",
|
||||
"token": " ",
|
||||
},
|
||||
)
|
||||
# Pydantic min_length=1 catches the whitespace-only token
|
||||
assert response.status_code == 422
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_token_replaces_old_credential(self, client):
|
||||
old_cred = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
title="MCP: mcp.example.com",
|
||||
access_token=SecretStr("old-token"),
|
||||
scopes=[],
|
||||
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
|
||||
)
|
||||
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[old_cred])
|
||||
mock_cm.create = AsyncMock()
|
||||
mock_cm.store.delete_creds_by_id = AsyncMock()
|
||||
|
||||
response = await client.post(
|
||||
"/token",
|
||||
json={
|
||||
"server_url": "https://mcp.example.com/mcp",
|
||||
"token": "new-token",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_cm.store.delete_creds_by_id.assert_called_once_with(
|
||||
"test-user-id", old_cred.id
|
||||
)
|
||||
|
||||
|
||||
class TestSSRFValidation:
|
||||
"""Verify that validate_url is enforced on all endpoints."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked loopback"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "http://localhost/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "blocked loopback" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_login_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked private IP"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/oauth/login",
|
||||
json={"server_url": "http://10.0.0.1/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "blocked private ip" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_token_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked loopback"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/token",
|
||||
json={
|
||||
"server_url": "http://127.0.0.1/mcp",
|
||||
"token": "some-token",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "blocked loopback" in response.json()["detail"].lower()
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Literal
|
||||
|
||||
from backend.util.cache import cached
|
||||
|
||||
from . import db as store_db
|
||||
@@ -23,7 +21,7 @@ def clear_all_caches():
|
||||
async def _get_cached_store_agents(
|
||||
featured: bool,
|
||||
creator: str | None,
|
||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None,
|
||||
sorted_by: store_db.StoreAgentsSortOptions | None,
|
||||
search_query: str | None,
|
||||
category: str | None,
|
||||
page: int,
|
||||
@@ -57,7 +55,7 @@ async def _get_cached_agent_details(
|
||||
async def _get_cached_store_creators(
|
||||
featured: bool,
|
||||
search_query: str | None,
|
||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None,
|
||||
sorted_by: store_db.StoreCreatorsSortOptions | None,
|
||||
page: int,
|
||||
page_size: int,
|
||||
):
|
||||
@@ -75,4 +73,4 @@ async def _get_cached_store_creators(
|
||||
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
|
||||
async def _get_cached_creator_details(username: str):
|
||||
"""Cached helper to get creator details."""
|
||||
return await store_db.get_store_creator_details(username=username.lower())
|
||||
return await store_db.get_store_creator(username=username.lower())
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -26,7 +26,7 @@ async def test_get_store_agents(mocker):
|
||||
mock_agents = [
|
||||
prisma.models.StoreAgent(
|
||||
listing_id="test-id",
|
||||
storeListingVersionId="version123",
|
||||
listing_version_id="version123",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_video=None,
|
||||
@@ -40,11 +40,11 @@ async def test_get_store_agents(mocker):
|
||||
runs=10,
|
||||
rating=4.5,
|
||||
versions=["1.0"],
|
||||
agentGraphVersions=["1"],
|
||||
agentGraphId="test-graph-id",
|
||||
graph_id="test-graph-id",
|
||||
graph_versions=["1"],
|
||||
updated_at=datetime.now(),
|
||||
is_available=False,
|
||||
useForOnboarding=False,
|
||||
use_for_onboarding=False,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -68,10 +68,10 @@ async def test_get_store_agents(mocker):
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_agent_details(mocker):
|
||||
# Mock data
|
||||
# Mock data - StoreAgent view already contains the active version data
|
||||
mock_agent = prisma.models.StoreAgent(
|
||||
listing_id="test-id",
|
||||
storeListingVersionId="version123",
|
||||
listing_version_id="version123",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_video="video.mp4",
|
||||
@@ -85,102 +85,38 @@ async def test_get_store_agent_details(mocker):
|
||||
runs=10,
|
||||
rating=4.5,
|
||||
versions=["1.0"],
|
||||
agentGraphVersions=["1"],
|
||||
agentGraphId="test-graph-id",
|
||||
updated_at=datetime.now(),
|
||||
is_available=False,
|
||||
useForOnboarding=False,
|
||||
)
|
||||
|
||||
# Mock active version agent (what we want to return for active version)
|
||||
mock_active_agent = prisma.models.StoreAgent(
|
||||
listing_id="test-id",
|
||||
storeListingVersionId="active-version-id",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent Active",
|
||||
agent_video="active_video.mp4",
|
||||
agent_image=["active_image.jpg"],
|
||||
featured=False,
|
||||
creator_username="creator",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test heading active",
|
||||
description="Test description active",
|
||||
categories=["test"],
|
||||
runs=15,
|
||||
rating=4.8,
|
||||
versions=["1.0", "2.0"],
|
||||
agentGraphVersions=["1", "2"],
|
||||
agentGraphId="test-graph-id-active",
|
||||
graph_id="test-graph-id",
|
||||
graph_versions=["1"],
|
||||
updated_at=datetime.now(),
|
||||
is_available=True,
|
||||
useForOnboarding=False,
|
||||
use_for_onboarding=False,
|
||||
)
|
||||
|
||||
# Create a mock StoreListing result
|
||||
mock_store_listing = mocker.MagicMock()
|
||||
mock_store_listing.activeVersionId = "active-version-id"
|
||||
mock_store_listing.hasApprovedVersion = True
|
||||
mock_store_listing.ActiveVersion = mocker.MagicMock()
|
||||
mock_store_listing.ActiveVersion.recommendedScheduleCron = None
|
||||
|
||||
# Mock StoreAgent prisma call - need to handle multiple calls
|
||||
# Mock StoreAgent prisma call
|
||||
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
|
||||
|
||||
# Set up side_effect to return different results for different calls
|
||||
def mock_find_first_side_effect(*args, **kwargs):
|
||||
where_clause = kwargs.get("where", {})
|
||||
if "storeListingVersionId" in where_clause:
|
||||
# Second call for active version
|
||||
return mock_active_agent
|
||||
else:
|
||||
# First call for initial lookup
|
||||
return mock_agent
|
||||
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(
|
||||
side_effect=mock_find_first_side_effect
|
||||
)
|
||||
|
||||
# Mock Profile prisma call
|
||||
mock_profile = mocker.MagicMock()
|
||||
mock_profile.userId = "user-id-123"
|
||||
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
|
||||
mock_profile_db.return_value.find_first = mocker.AsyncMock(
|
||||
return_value=mock_profile
|
||||
)
|
||||
|
||||
# Mock StoreListing prisma call
|
||||
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
|
||||
return_value=mock_store_listing
|
||||
)
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||
|
||||
# Call function
|
||||
result = await db.get_store_agent_details("creator", "test-agent")
|
||||
|
||||
# Verify results - should use active version data
|
||||
# Verify results - constructed from the StoreAgent view
|
||||
assert result.slug == "test-agent"
|
||||
assert result.agent_name == "Test Agent Active" # From active version
|
||||
assert result.active_version_id == "active-version-id"
|
||||
assert result.agent_name == "Test Agent"
|
||||
assert result.active_version_id == "version123"
|
||||
assert result.has_approved_version is True
|
||||
assert (
|
||||
result.store_listing_version_id == "active-version-id"
|
||||
) # Should be active version ID
|
||||
assert result.store_listing_version_id == "version123"
|
||||
assert result.graph_id == "test-graph-id"
|
||||
assert result.runs == 10
|
||||
assert result.rating == 4.5
|
||||
|
||||
# Verify mocks called correctly - now expecting 2 calls
|
||||
assert mock_store_agent.return_value.find_first.call_count == 2
|
||||
|
||||
# Check the specific calls
|
||||
calls = mock_store_agent.return_value.find_first.call_args_list
|
||||
assert calls[0] == mocker.call(
|
||||
# Verify single StoreAgent lookup
|
||||
mock_store_agent.return_value.find_first.assert_called_once_with(
|
||||
where={"creator_username": "creator", "slug": "test-agent"}
|
||||
)
|
||||
assert calls[1] == mocker.call(where={"storeListingVersionId": "active-version-id"})
|
||||
|
||||
mock_store_listing_db.return_value.find_first.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_creator_details(mocker):
|
||||
async def test_get_store_creator(mocker):
|
||||
# Mock data
|
||||
mock_creator_data = prisma.models.Creator(
|
||||
name="Test Creator",
|
||||
@@ -202,7 +138,7 @@ async def test_get_store_creator_details(mocker):
|
||||
mock_creator.return_value.find_unique.return_value = mock_creator_data
|
||||
|
||||
# Call function
|
||||
result = await db.get_store_creator_details("creator")
|
||||
result = await db.get_store_creator("creator")
|
||||
|
||||
# Verify results
|
||||
assert result.username == "creator"
|
||||
@@ -218,61 +154,110 @@ async def test_get_store_creator_details(mocker):
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_create_store_submission(mocker):
|
||||
# Mock data
|
||||
now = datetime.now()
|
||||
|
||||
# Mock agent graph (with no pending submissions) and user with profile
|
||||
mock_profile = prisma.models.Profile(
|
||||
id="profile-id",
|
||||
userId="user-id",
|
||||
name="Test User",
|
||||
username="testuser",
|
||||
description="Test",
|
||||
isFeatured=False,
|
||||
links=[],
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
)
|
||||
mock_user = prisma.models.User(
|
||||
id="user-id",
|
||||
email="test@example.com",
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
Profile=[mock_profile],
|
||||
emailVerified=True,
|
||||
metadata="{}", # type: ignore[reportArgumentType]
|
||||
integrations="",
|
||||
maxEmailsPerDay=1,
|
||||
notifyOnAgentRun=True,
|
||||
notifyOnZeroBalance=True,
|
||||
notifyOnLowBalance=True,
|
||||
notifyOnBlockExecutionFailed=True,
|
||||
notifyOnContinuousAgentError=True,
|
||||
notifyOnDailySummary=True,
|
||||
notifyOnWeeklySummary=True,
|
||||
notifyOnMonthlySummary=True,
|
||||
notifyOnAgentApproved=True,
|
||||
notifyOnAgentRejected=True,
|
||||
timezone="Europe/Delft",
|
||||
)
|
||||
mock_agent = prisma.models.AgentGraph(
|
||||
id="agent-id",
|
||||
version=1,
|
||||
userId="user-id",
|
||||
createdAt=datetime.now(),
|
||||
createdAt=now,
|
||||
isActive=True,
|
||||
StoreListingVersions=[],
|
||||
User=mock_user,
|
||||
)
|
||||
|
||||
mock_listing = prisma.models.StoreListing(
|
||||
# Mock the created StoreListingVersion (returned by create)
|
||||
mock_store_listing_obj = prisma.models.StoreListing(
|
||||
id="listing-id",
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
isDeleted=False,
|
||||
hasApprovedVersion=False,
|
||||
slug="test-agent",
|
||||
agentGraphId="agent-id",
|
||||
agentGraphVersion=1,
|
||||
owningUserId="user-id",
|
||||
Versions=[
|
||||
prisma.models.StoreListingVersion(
|
||||
id="version-id",
|
||||
agentGraphId="agent-id",
|
||||
agentGraphVersion=1,
|
||||
name="Test Agent",
|
||||
description="Test description",
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
subHeading="Test heading",
|
||||
imageUrls=["image.jpg"],
|
||||
categories=["test"],
|
||||
isFeatured=False,
|
||||
isDeleted=False,
|
||||
version=1,
|
||||
storeListingId="listing-id",
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
isAvailable=True,
|
||||
)
|
||||
],
|
||||
useForOnboarding=False,
|
||||
)
|
||||
mock_version = prisma.models.StoreListingVersion(
|
||||
id="version-id",
|
||||
agentGraphId="agent-id",
|
||||
agentGraphVersion=1,
|
||||
name="Test Agent",
|
||||
description="Test description",
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
subHeading="",
|
||||
imageUrls=[],
|
||||
categories=[],
|
||||
isFeatured=False,
|
||||
isDeleted=False,
|
||||
version=1,
|
||||
storeListingId="listing-id",
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
isAvailable=True,
|
||||
submittedAt=now,
|
||||
StoreListing=mock_store_listing_obj,
|
||||
)
|
||||
|
||||
# Mock prisma calls
|
||||
mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
|
||||
mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||
|
||||
mock_store_listing = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_store_listing.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_store_listing.return_value.create = mocker.AsyncMock(return_value=mock_listing)
|
||||
# Mock transaction context manager
|
||||
mock_tx = mocker.MagicMock()
|
||||
mocker.patch(
|
||||
"backend.api.features.store.db.transaction",
|
||||
return_value=mocker.AsyncMock(
|
||||
__aenter__=mocker.AsyncMock(return_value=mock_tx),
|
||||
__aexit__=mocker.AsyncMock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
mock_sl = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_sl.return_value.find_unique = mocker.AsyncMock(return_value=None)
|
||||
|
||||
mock_slv = mocker.patch("prisma.models.StoreListingVersion.prisma")
|
||||
mock_slv.return_value.create = mocker.AsyncMock(return_value=mock_version)
|
||||
|
||||
# Call function
|
||||
result = await db.create_store_submission(
|
||||
user_id="user-id",
|
||||
agent_id="agent-id",
|
||||
agent_version=1,
|
||||
graph_id="agent-id",
|
||||
graph_version=1,
|
||||
slug="test-agent",
|
||||
name="Test Agent",
|
||||
description="Test description",
|
||||
@@ -281,11 +266,11 @@ async def test_create_store_submission(mocker):
|
||||
# Verify results
|
||||
assert result.name == "Test Agent"
|
||||
assert result.description == "Test description"
|
||||
assert result.store_listing_version_id == "version-id"
|
||||
assert result.listing_version_id == "version-id"
|
||||
|
||||
# Verify mocks called correctly
|
||||
mock_agent_graph.return_value.find_first.assert_called_once()
|
||||
mock_store_listing.return_value.create.assert_called_once()
|
||||
mock_slv.return_value.create.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@@ -318,7 +303,6 @@ async def test_update_profile(mocker):
|
||||
description="Test description",
|
||||
links=["link1"],
|
||||
avatar_url="avatar.jpg",
|
||||
is_featured=False,
|
||||
)
|
||||
|
||||
# Call function
|
||||
@@ -389,7 +373,7 @@ async def test_get_store_agents_with_search_and_filters_parameterized():
|
||||
creators=["creator1'; DROP TABLE Users; --", "creator2"],
|
||||
category="AI'; DELETE FROM StoreAgent; --",
|
||||
featured=True,
|
||||
sorted_by="rating",
|
||||
sorted_by=db.StoreAgentsSortOptions.RATING,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
@@ -57,12 +57,6 @@ class StoreError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class AgentNotFoundError(NotFoundError):
|
||||
"""Raised when an agent is not found"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CreatorNotFoundError(NotFoundError):
|
||||
"""Raised when a creator is not found"""
|
||||
|
||||
|
||||
@@ -568,7 +568,7 @@ async def hybrid_search(
|
||||
SELECT uce."contentId" as "storeListingVersionId"
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON uce."contentId" = sa."storeListingVersionId"
|
||||
ON uce."contentId" = sa.listing_version_id
|
||||
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
AND uce."userId" IS NULL
|
||||
AND uce.search @@ plainto_tsquery('english', {query_param})
|
||||
@@ -582,7 +582,7 @@ async def hybrid_search(
|
||||
SELECT uce."contentId", uce.embedding
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON uce."contentId" = sa."storeListingVersionId"
|
||||
ON uce."contentId" = sa.listing_version_id
|
||||
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
AND uce."userId" IS NULL
|
||||
AND {where_clause}
|
||||
@@ -605,7 +605,7 @@ async def hybrid_search(
|
||||
sa.featured,
|
||||
sa.is_available,
|
||||
sa.updated_at,
|
||||
sa."agentGraphId",
|
||||
sa.graph_id,
|
||||
-- Searchable text for BM25 reranking
|
||||
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
|
||||
-- Semantic score
|
||||
@@ -627,9 +627,9 @@ async def hybrid_search(
|
||||
sa.runs as popularity_raw
|
||||
FROM candidates c
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON c."storeListingVersionId" = sa."storeListingVersionId"
|
||||
ON c."storeListingVersionId" = sa.listing_version_id
|
||||
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
ON sa."storeListingVersionId" = uce."contentId"
|
||||
ON sa.listing_version_id = uce."contentId"
|
||||
AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
),
|
||||
max_vals AS (
|
||||
@@ -665,7 +665,7 @@ async def hybrid_search(
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
"agentGraphId",
|
||||
graph_id,
|
||||
searchable_text,
|
||||
semantic_score,
|
||||
lexical_score,
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import datetime
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING, List, Self
|
||||
|
||||
import prisma.enums
|
||||
import pydantic
|
||||
|
||||
from backend.util.models import Pagination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import prisma.models
|
||||
|
||||
|
||||
class ChangelogEntry(pydantic.BaseModel):
|
||||
version: str
|
||||
@@ -13,9 +16,9 @@ class ChangelogEntry(pydantic.BaseModel):
|
||||
date: datetime.datetime
|
||||
|
||||
|
||||
class MyAgent(pydantic.BaseModel):
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
class MyUnpublishedAgent(pydantic.BaseModel):
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
agent_name: str
|
||||
agent_image: str | None = None
|
||||
description: str
|
||||
@@ -23,8 +26,8 @@ class MyAgent(pydantic.BaseModel):
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
|
||||
class MyAgentsResponse(pydantic.BaseModel):
|
||||
agents: list[MyAgent]
|
||||
class MyUnpublishedAgentsResponse(pydantic.BaseModel):
|
||||
agents: list[MyUnpublishedAgent]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
@@ -40,6 +43,21 @@ class StoreAgent(pydantic.BaseModel):
|
||||
rating: float
|
||||
agent_graph_id: str
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgent":
|
||||
return cls(
|
||||
slug=agent.slug,
|
||||
agent_name=agent.agent_name,
|
||||
agent_image=agent.agent_image[0] if agent.agent_image else "",
|
||||
creator=agent.creator_username or "Needs Profile",
|
||||
creator_avatar=agent.creator_avatar or "",
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
agent_graph_id=agent.graph_id,
|
||||
)
|
||||
|
||||
|
||||
class StoreAgentsResponse(pydantic.BaseModel):
|
||||
agents: list[StoreAgent]
|
||||
@@ -62,81 +80,192 @@ class StoreAgentDetails(pydantic.BaseModel):
|
||||
runs: int
|
||||
rating: float
|
||||
versions: list[str]
|
||||
agentGraphVersions: list[str]
|
||||
agentGraphId: str
|
||||
graph_id: str
|
||||
graph_versions: list[str]
|
||||
last_updated: datetime.datetime
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
active_version_id: str | None = None
|
||||
has_approved_version: bool = False
|
||||
active_version_id: str
|
||||
has_approved_version: bool
|
||||
|
||||
# Optional changelog data when include_changelog=True
|
||||
changelog: list[ChangelogEntry] | None = None
|
||||
|
||||
|
||||
class Creator(pydantic.BaseModel):
|
||||
name: str
|
||||
username: str
|
||||
description: str
|
||||
avatar_url: str
|
||||
num_agents: int
|
||||
agent_rating: float
|
||||
agent_runs: int
|
||||
is_featured: bool
|
||||
|
||||
|
||||
class CreatorsResponse(pydantic.BaseModel):
|
||||
creators: List[Creator]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class CreatorDetails(pydantic.BaseModel):
|
||||
name: str
|
||||
username: str
|
||||
description: str
|
||||
links: list[str]
|
||||
avatar_url: str
|
||||
agent_rating: float
|
||||
agent_runs: int
|
||||
top_categories: list[str]
|
||||
@classmethod
|
||||
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgentDetails":
|
||||
return cls(
|
||||
store_listing_version_id=agent.listing_version_id,
|
||||
slug=agent.slug,
|
||||
agent_name=agent.agent_name,
|
||||
agent_video=agent.agent_video or "",
|
||||
agent_output_demo=agent.agent_output_demo or "",
|
||||
agent_image=agent.agent_image,
|
||||
creator=agent.creator_username or "",
|
||||
creator_avatar=agent.creator_avatar or "",
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
categories=agent.categories,
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
versions=agent.versions,
|
||||
graph_id=agent.graph_id,
|
||||
graph_versions=agent.graph_versions,
|
||||
last_updated=agent.updated_at,
|
||||
recommended_schedule_cron=agent.recommended_schedule_cron,
|
||||
active_version_id=agent.listing_version_id,
|
||||
has_approved_version=True, # StoreAgent view only has approved agents
|
||||
)
|
||||
|
||||
|
||||
class Profile(pydantic.BaseModel):
|
||||
name: str
|
||||
"""Marketplace user profile (only attributes that the user can update)"""
|
||||
|
||||
username: str
|
||||
name: str
|
||||
description: str
|
||||
avatar_url: str | None
|
||||
links: list[str]
|
||||
avatar_url: str
|
||||
is_featured: bool = False
|
||||
|
||||
|
||||
class ProfileDetails(Profile):
|
||||
"""Marketplace user profile (including read-only fields)"""
|
||||
|
||||
is_featured: bool
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, profile: "prisma.models.Profile") -> "ProfileDetails":
|
||||
return cls(
|
||||
name=profile.name,
|
||||
username=profile.username,
|
||||
avatar_url=profile.avatarUrl,
|
||||
description=profile.description,
|
||||
links=profile.links,
|
||||
is_featured=profile.isFeatured,
|
||||
)
|
||||
|
||||
|
||||
class CreatorDetails(ProfileDetails):
|
||||
"""Marketplace creator profile details, including aggregated stats"""
|
||||
|
||||
num_agents: int
|
||||
agent_runs: int
|
||||
agent_rating: float
|
||||
top_categories: list[str]
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, creator: "prisma.models.Creator") -> "CreatorDetails": # type: ignore[override]
|
||||
return cls(
|
||||
name=creator.name,
|
||||
username=creator.username,
|
||||
avatar_url=creator.avatar_url,
|
||||
description=creator.description,
|
||||
links=creator.links,
|
||||
is_featured=creator.is_featured,
|
||||
num_agents=creator.num_agents,
|
||||
agent_runs=creator.agent_runs,
|
||||
agent_rating=creator.agent_rating,
|
||||
top_categories=creator.top_categories,
|
||||
)
|
||||
|
||||
|
||||
class CreatorsResponse(pydantic.BaseModel):
|
||||
creators: List[CreatorDetails]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreSubmission(pydantic.BaseModel):
|
||||
# From StoreListing:
|
||||
listing_id: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
user_id: str
|
||||
slug: str
|
||||
|
||||
# From StoreListingVersion:
|
||||
listing_version_id: str
|
||||
listing_version: int
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
name: str
|
||||
sub_heading: str
|
||||
slug: str
|
||||
description: str
|
||||
instructions: str | None = None
|
||||
instructions: str | None
|
||||
categories: list[str]
|
||||
image_urls: list[str]
|
||||
date_submitted: datetime.datetime
|
||||
status: prisma.enums.SubmissionStatus
|
||||
runs: int
|
||||
rating: float
|
||||
store_listing_version_id: str | None = None
|
||||
version: int | None = None # Actual version number from the database
|
||||
video_url: str | None
|
||||
agent_output_demo_url: str | None
|
||||
|
||||
submitted_at: datetime.datetime | None
|
||||
changes_summary: str | None
|
||||
status: prisma.enums.SubmissionStatus
|
||||
reviewed_at: datetime.datetime | None = None
|
||||
reviewer_id: str | None = None
|
||||
review_comments: str | None = None # External comments visible to creator
|
||||
internal_comments: str | None = None # Private notes for admin use only
|
||||
reviewed_at: datetime.datetime | None = None
|
||||
changes_summary: str | None = None
|
||||
|
||||
# Additional fields for editing
|
||||
video_url: str | None = None
|
||||
agent_output_demo_url: str | None = None
|
||||
categories: list[str] = []
|
||||
# Aggregated from AgentGraphExecutions and StoreListingReviews:
|
||||
run_count: int = 0
|
||||
review_count: int = 0
|
||||
review_avg_rating: float = 0.0
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
|
||||
"""Construct from the StoreSubmission Prisma view."""
|
||||
return cls(
|
||||
listing_id=_sub.listing_id,
|
||||
user_id=_sub.user_id,
|
||||
slug=_sub.slug,
|
||||
listing_version_id=_sub.listing_version_id,
|
||||
listing_version=_sub.listing_version,
|
||||
graph_id=_sub.graph_id,
|
||||
graph_version=_sub.graph_version,
|
||||
name=_sub.name,
|
||||
sub_heading=_sub.sub_heading,
|
||||
description=_sub.description,
|
||||
instructions=_sub.instructions,
|
||||
categories=_sub.categories,
|
||||
image_urls=_sub.image_urls,
|
||||
video_url=_sub.video_url,
|
||||
agent_output_demo_url=_sub.agent_output_demo_url,
|
||||
submitted_at=_sub.submitted_at,
|
||||
changes_summary=_sub.changes_summary,
|
||||
status=_sub.status,
|
||||
reviewed_at=_sub.reviewed_at,
|
||||
reviewer_id=_sub.reviewer_id,
|
||||
review_comments=_sub.review_comments,
|
||||
run_count=_sub.run_count,
|
||||
review_count=_sub.review_count,
|
||||
review_avg_rating=_sub.review_avg_rating,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
|
||||
"""
|
||||
Construct from the StoreListingVersion Prisma model (with StoreListing included)
|
||||
"""
|
||||
if not (_l := _lv.StoreListing):
|
||||
raise ValueError("StoreListingVersion must have included StoreListing")
|
||||
|
||||
return cls(
|
||||
listing_id=_l.id,
|
||||
user_id=_l.owningUserId,
|
||||
slug=_l.slug,
|
||||
listing_version_id=_lv.id,
|
||||
listing_version=_lv.version,
|
||||
graph_id=_lv.agentGraphId,
|
||||
graph_version=_lv.agentGraphVersion,
|
||||
name=_lv.name,
|
||||
sub_heading=_lv.subHeading,
|
||||
description=_lv.description,
|
||||
instructions=_lv.instructions,
|
||||
categories=_lv.categories,
|
||||
image_urls=_lv.imageUrls,
|
||||
video_url=_lv.videoUrl,
|
||||
agent_output_demo_url=_lv.agentOutputDemoUrl,
|
||||
submitted_at=_lv.submittedAt,
|
||||
changes_summary=_lv.changesSummary,
|
||||
status=_lv.submissionStatus,
|
||||
reviewed_at=_lv.reviewedAt,
|
||||
reviewer_id=_lv.reviewerId,
|
||||
review_comments=_lv.reviewComments,
|
||||
)
|
||||
|
||||
|
||||
class StoreSubmissionsResponse(pydantic.BaseModel):
|
||||
@@ -144,33 +273,12 @@ class StoreSubmissionsResponse(pydantic.BaseModel):
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreListingWithVersions(pydantic.BaseModel):
|
||||
"""A store listing with its version history"""
|
||||
|
||||
listing_id: str
|
||||
slug: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
active_version_id: str | None = None
|
||||
has_approved_version: bool = False
|
||||
creator_email: str | None = None
|
||||
latest_version: StoreSubmission | None = None
|
||||
versions: list[StoreSubmission] = []
|
||||
|
||||
|
||||
class StoreListingsWithVersionsResponse(pydantic.BaseModel):
|
||||
"""Response model for listings with version history"""
|
||||
|
||||
listings: list[StoreListingWithVersions]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreSubmissionRequest(pydantic.BaseModel):
|
||||
agent_id: str = pydantic.Field(
|
||||
..., min_length=1, description="Agent ID cannot be empty"
|
||||
graph_id: str = pydantic.Field(
|
||||
..., min_length=1, description="Graph ID cannot be empty"
|
||||
)
|
||||
agent_version: int = pydantic.Field(
|
||||
..., gt=0, description="Agent version must be greater than 0"
|
||||
graph_version: int = pydantic.Field(
|
||||
..., gt=0, description="Graph version must be greater than 0"
|
||||
)
|
||||
slug: str
|
||||
name: str
|
||||
@@ -198,12 +306,42 @@ class StoreSubmissionEditRequest(pydantic.BaseModel):
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
|
||||
class ProfileDetails(pydantic.BaseModel):
|
||||
name: str
|
||||
username: str
|
||||
description: str
|
||||
links: list[str]
|
||||
avatar_url: str | None = None
|
||||
class StoreSubmissionAdminView(StoreSubmission):
|
||||
internal_comments: str | None # Private admin notes
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
|
||||
return cls(
|
||||
**StoreSubmission.from_db(_sub).model_dump(),
|
||||
internal_comments=_sub.internal_comments,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
|
||||
return cls(
|
||||
**StoreSubmission.from_listing_version(_lv).model_dump(),
|
||||
internal_comments=_lv.internalComments,
|
||||
)
|
||||
|
||||
|
||||
class StoreListingWithVersionsAdminView(pydantic.BaseModel):
|
||||
"""A store listing with its version history"""
|
||||
|
||||
listing_id: str
|
||||
graph_id: str
|
||||
slug: str
|
||||
active_listing_version_id: str | None = None
|
||||
has_approved_version: bool = False
|
||||
creator_email: str | None = None
|
||||
latest_version: StoreSubmissionAdminView | None = None
|
||||
versions: list[StoreSubmissionAdminView] = []
|
||||
|
||||
|
||||
class StoreListingsWithVersionsAdminViewResponse(pydantic.BaseModel):
|
||||
"""Response model for listings with version history"""
|
||||
|
||||
listings: list[StoreListingWithVersionsAdminView]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreReview(pydantic.BaseModel):
|
||||
|
||||
@@ -1,203 +0,0 @@
|
||||
import datetime
|
||||
|
||||
import prisma.enums
|
||||
|
||||
from . import model as store_model
|
||||
|
||||
|
||||
def test_pagination():
|
||||
pagination = store_model.Pagination(
|
||||
total_items=100, total_pages=5, current_page=2, page_size=20
|
||||
)
|
||||
assert pagination.total_items == 100
|
||||
assert pagination.total_pages == 5
|
||||
assert pagination.current_page == 2
|
||||
assert pagination.page_size == 20
|
||||
|
||||
|
||||
def test_store_agent():
|
||||
agent = store_model.StoreAgent(
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_image="test.jpg",
|
||||
creator="creator1",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test subheading",
|
||||
description="Test description",
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
agent_graph_id="test-graph-id",
|
||||
)
|
||||
assert agent.slug == "test-agent"
|
||||
assert agent.agent_name == "Test Agent"
|
||||
assert agent.runs == 50
|
||||
assert agent.rating == 4.5
|
||||
assert agent.agent_graph_id == "test-graph-id"
|
||||
|
||||
|
||||
def test_store_agents_response():
|
||||
response = store_model.StoreAgentsResponse(
|
||||
agents=[
|
||||
store_model.StoreAgent(
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_image="test.jpg",
|
||||
creator="creator1",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test subheading",
|
||||
description="Test description",
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
agent_graph_id="test-graph-id",
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
||||
),
|
||||
)
|
||||
assert len(response.agents) == 1
|
||||
assert response.pagination.total_items == 1
|
||||
|
||||
|
||||
def test_store_agent_details():
|
||||
details = store_model.StoreAgentDetails(
|
||||
store_listing_version_id="version123",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_video="video.mp4",
|
||||
agent_output_demo="demo.mp4",
|
||||
agent_image=["image1.jpg", "image2.jpg"],
|
||||
creator="creator1",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test subheading",
|
||||
description="Test description",
|
||||
categories=["cat1", "cat2"],
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
versions=["1.0", "2.0"],
|
||||
agentGraphVersions=["1", "2"],
|
||||
agentGraphId="test-graph-id",
|
||||
last_updated=datetime.datetime.now(),
|
||||
)
|
||||
assert details.slug == "test-agent"
|
||||
assert len(details.agent_image) == 2
|
||||
assert len(details.categories) == 2
|
||||
assert len(details.versions) == 2
|
||||
|
||||
|
||||
def test_creator():
|
||||
creator = store_model.Creator(
|
||||
agent_rating=4.8,
|
||||
agent_runs=1000,
|
||||
name="Test Creator",
|
||||
username="creator1",
|
||||
description="Test description",
|
||||
avatar_url="avatar.jpg",
|
||||
num_agents=5,
|
||||
is_featured=False,
|
||||
)
|
||||
assert creator.name == "Test Creator"
|
||||
assert creator.num_agents == 5
|
||||
|
||||
|
||||
def test_creators_response():
|
||||
response = store_model.CreatorsResponse(
|
||||
creators=[
|
||||
store_model.Creator(
|
||||
agent_rating=4.8,
|
||||
agent_runs=1000,
|
||||
name="Test Creator",
|
||||
username="creator1",
|
||||
description="Test description",
|
||||
avatar_url="avatar.jpg",
|
||||
num_agents=5,
|
||||
is_featured=False,
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
||||
),
|
||||
)
|
||||
assert len(response.creators) == 1
|
||||
assert response.pagination.total_items == 1
|
||||
|
||||
|
||||
def test_creator_details():
|
||||
details = store_model.CreatorDetails(
|
||||
name="Test Creator",
|
||||
username="creator1",
|
||||
description="Test description",
|
||||
links=["link1.com", "link2.com"],
|
||||
avatar_url="avatar.jpg",
|
||||
agent_rating=4.8,
|
||||
agent_runs=1000,
|
||||
top_categories=["cat1", "cat2"],
|
||||
)
|
||||
assert details.name == "Test Creator"
|
||||
assert len(details.links) == 2
|
||||
assert details.agent_rating == 4.8
|
||||
assert len(details.top_categories) == 2
|
||||
|
||||
|
||||
def test_store_submission():
|
||||
submission = store_model.StoreSubmission(
|
||||
listing_id="listing123",
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
sub_heading="Test subheading",
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
description="Test description",
|
||||
image_urls=["image1.jpg", "image2.jpg"],
|
||||
date_submitted=datetime.datetime(2023, 1, 1),
|
||||
status=prisma.enums.SubmissionStatus.PENDING,
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
)
|
||||
assert submission.name == "Test Agent"
|
||||
assert len(submission.image_urls) == 2
|
||||
assert submission.status == prisma.enums.SubmissionStatus.PENDING
|
||||
|
||||
|
||||
def test_store_submissions_response():
|
||||
response = store_model.StoreSubmissionsResponse(
|
||||
submissions=[
|
||||
store_model.StoreSubmission(
|
||||
listing_id="listing123",
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
sub_heading="Test subheading",
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
description="Test description",
|
||||
image_urls=["image1.jpg"],
|
||||
date_submitted=datetime.datetime(2023, 1, 1),
|
||||
status=prisma.enums.SubmissionStatus.PENDING,
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
||||
),
|
||||
)
|
||||
assert len(response.submissions) == 1
|
||||
assert response.pagination.total_items == 1
|
||||
|
||||
|
||||
def test_store_submission_request():
|
||||
request = store_model.StoreSubmissionRequest(
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
slug="test-agent",
|
||||
name="Test Agent",
|
||||
sub_heading="Test subheading",
|
||||
video_url="video.mp4",
|
||||
image_urls=["image1.jpg", "image2.jpg"],
|
||||
description="Test description",
|
||||
categories=["cat1", "cat2"],
|
||||
)
|
||||
assert request.agent_id == "agent123"
|
||||
assert request.agent_version == 1
|
||||
assert len(request.image_urls) == 2
|
||||
assert len(request.categories) == 2
|
||||
@@ -1,16 +1,17 @@
|
||||
import logging
|
||||
import tempfile
|
||||
import typing
|
||||
import urllib.parse
|
||||
from typing import Literal
|
||||
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
import prisma.enums
|
||||
from fastapi import Query, Security
|
||||
from pydantic import BaseModel
|
||||
|
||||
import backend.data.graph
|
||||
import backend.util.json
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from . import cache as store_cache
|
||||
@@ -34,22 +35,15 @@ router = fastapi.APIRouter()
|
||||
"/profile",
|
||||
summary="Get user profile",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.ProfileDetails,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_profile(
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Get the profile details for the authenticated user.
|
||||
Cached for 1 hour per user.
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.ProfileDetails:
|
||||
"""Get the profile details for the authenticated user."""
|
||||
profile = await store_db.get_user_profile(user_id)
|
||||
if profile is None:
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=404,
|
||||
content={"detail": "Profile not found"},
|
||||
)
|
||||
raise NotFoundError("User does not have a profile yet")
|
||||
return profile
|
||||
|
||||
|
||||
@@ -57,98 +51,17 @@ async def get_profile(
|
||||
"/profile",
|
||||
summary="Update user profile",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.CreatorDetails,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def update_or_create_profile(
|
||||
profile: store_model.Profile,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Update the store profile for the authenticated user.
|
||||
|
||||
Args:
|
||||
profile (Profile): The updated profile details
|
||||
user_id (str): ID of the authenticated user
|
||||
|
||||
Returns:
|
||||
CreatorDetails: The updated profile
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error updating the profile
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.ProfileDetails:
|
||||
"""Update the store profile for the authenticated user."""
|
||||
updated_profile = await store_db.update_profile(user_id=user_id, profile=profile)
|
||||
return updated_profile
|
||||
|
||||
|
||||
##############################################
|
||||
############### Agent Endpoints ##############
|
||||
##############################################
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents",
|
||||
summary="List store agents",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.StoreAgentsResponse,
|
||||
)
|
||||
async def get_agents(
|
||||
featured: bool = False,
|
||||
creator: str | None = None,
|
||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
|
||||
search_query: str | None = None,
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
):
|
||||
"""
|
||||
Get a paginated list of agents from the store with optional filtering and sorting.
|
||||
|
||||
Args:
|
||||
featured (bool, optional): Filter to only show featured agents. Defaults to False.
|
||||
creator (str | None, optional): Filter agents by creator username. Defaults to None.
|
||||
sorted_by (str | None, optional): Sort agents by "runs" or "rating". Defaults to None.
|
||||
search_query (str | None, optional): Search agents by name, subheading and description. Defaults to None.
|
||||
category (str | None, optional): Filter agents by category. Defaults to None.
|
||||
page (int, optional): Page number for pagination. Defaults to 1.
|
||||
page_size (int, optional): Number of agents per page. Defaults to 20.
|
||||
|
||||
Returns:
|
||||
StoreAgentsResponse: Paginated list of agents matching the filters
|
||||
|
||||
Raises:
|
||||
HTTPException: If page or page_size are less than 1
|
||||
|
||||
Used for:
|
||||
- Home Page Featured Agents
|
||||
- Home Page Top Agents
|
||||
- Search Results
|
||||
- Agent Details - Other Agents By Creator
|
||||
- Agent Details - Similar Agents
|
||||
- Creator Details - Agents By Creator
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
|
||||
agents = await store_cache._get_cached_store_agents(
|
||||
featured=featured,
|
||||
creator=creator,
|
||||
sorted_by=sorted_by,
|
||||
search_query=search_query,
|
||||
category=category,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return agents
|
||||
|
||||
|
||||
##############################################
|
||||
############### Search Endpoints #############
|
||||
##############################################
|
||||
@@ -158,60 +71,30 @@ async def get_agents(
|
||||
"/search",
|
||||
summary="Unified search across all content types",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.UnifiedSearchResponse,
|
||||
)
|
||||
async def unified_search(
|
||||
query: str,
|
||||
content_types: list[str] | None = fastapi.Query(
|
||||
content_types: list[prisma.enums.ContentType] | None = Query(
|
||||
default=None,
|
||||
description="Content types to search: STORE_AGENT, BLOCK, DOCUMENTATION. If not specified, searches all.",
|
||||
description="Content types to search. If not specified, searches all.",
|
||||
),
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
user_id: str | None = fastapi.Security(
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
user_id: str | None = Security(
|
||||
autogpt_libs.auth.get_optional_user_id, use_cache=False
|
||||
),
|
||||
):
|
||||
) -> store_model.UnifiedSearchResponse:
|
||||
"""
|
||||
Search across all content types (store agents, blocks, documentation) using hybrid search.
|
||||
Search across all content types (marketplace agents, blocks, documentation)
|
||||
using hybrid search.
|
||||
|
||||
Combines semantic (embedding-based) and lexical (text-based) search for best results.
|
||||
|
||||
Args:
|
||||
query: The search query string
|
||||
content_types: Optional list of content types to filter by (STORE_AGENT, BLOCK, DOCUMENTATION)
|
||||
page: Page number for pagination (default 1)
|
||||
page_size: Number of results per page (default 20)
|
||||
user_id: Optional authenticated user ID (for user-scoped content in future)
|
||||
|
||||
Returns:
|
||||
UnifiedSearchResponse: Paginated list of search results with relevance scores
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
|
||||
# Convert string content types to enum
|
||||
content_type_enums: list[prisma.enums.ContentType] | None = None
|
||||
if content_types:
|
||||
try:
|
||||
content_type_enums = [prisma.enums.ContentType(ct) for ct in content_types]
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Invalid content type. Valid values: STORE_AGENT, BLOCK, DOCUMENTATION. Error: {e}",
|
||||
)
|
||||
|
||||
# Perform unified hybrid search
|
||||
results, total = await store_hybrid_search.unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=content_type_enums,
|
||||
content_types=content_types,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
@@ -245,22 +128,69 @@ async def unified_search(
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
############### Agent Endpoints ##############
|
||||
##############################################
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents",
|
||||
summary="List store agents",
|
||||
tags=["store", "public"],
|
||||
)
|
||||
async def get_agents(
|
||||
featured: bool = Query(
|
||||
default=False, description="Filter to only show featured agents"
|
||||
),
|
||||
creator: str | None = Query(
|
||||
default=None, description="Filter agents by creator username"
|
||||
),
|
||||
category: str | None = Query(default=None, description="Filter agents by category"),
|
||||
search_query: str | None = Query(
|
||||
default=None, description="Literal + semantic search on names and descriptions"
|
||||
),
|
||||
sorted_by: store_db.StoreAgentsSortOptions | None = Query(
|
||||
default=None,
|
||||
description="Property to sort results by. Ignored if search_query is provided.",
|
||||
),
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
) -> store_model.StoreAgentsResponse:
|
||||
"""
|
||||
Get a paginated list of agents from the marketplace,
|
||||
with optional filtering and sorting.
|
||||
|
||||
Used for:
|
||||
- Home Page Featured Agents
|
||||
- Home Page Top Agents
|
||||
- Search Results
|
||||
- Agent Details - Other Agents By Creator
|
||||
- Agent Details - Similar Agents
|
||||
- Creator Details - Agents By Creator
|
||||
"""
|
||||
agents = await store_cache._get_cached_store_agents(
|
||||
featured=featured,
|
||||
creator=creator,
|
||||
sorted_by=sorted_by,
|
||||
search_query=search_query,
|
||||
category=category,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return agents
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents/{username}/{agent_name}",
|
||||
summary="Get specific agent",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.StoreAgentDetails,
|
||||
)
|
||||
async def get_agent(
|
||||
async def get_agent_by_name(
|
||||
username: str,
|
||||
agent_name: str,
|
||||
include_changelog: bool = fastapi.Query(default=False),
|
||||
):
|
||||
"""
|
||||
This is only used on the AgentDetails Page.
|
||||
|
||||
It returns the store listing agents details.
|
||||
"""
|
||||
include_changelog: bool = Query(default=False),
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""Get details of a marketplace agent"""
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
# URL decode the agent name since it comes from the URL path
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
@@ -270,76 +200,82 @@ async def get_agent(
|
||||
return agent
|
||||
|
||||
|
||||
@router.get(
|
||||
"/graph/{store_listing_version_id}",
|
||||
summary="Get agent graph",
|
||||
tags=["store"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_graph_meta_by_store_listing_version_id(
|
||||
store_listing_version_id: str,
|
||||
) -> backend.data.graph.GraphModelWithoutNodes:
|
||||
"""
|
||||
Get Agent Graph from Store Listing Version ID.
|
||||
"""
|
||||
graph = await store_db.get_available_graph(store_listing_version_id)
|
||||
return graph
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents/{store_listing_version_id}",
|
||||
summary="Get agent by version",
|
||||
tags=["store"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreAgentDetails,
|
||||
)
|
||||
async def get_store_agent(store_listing_version_id: str):
|
||||
"""
|
||||
Get Store Agent Details from Store Listing Version ID.
|
||||
"""
|
||||
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@router.post(
|
||||
"/agents/{username}/{agent_name}/review",
|
||||
summary="Create agent review",
|
||||
tags=["store"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreReview,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def create_review(
|
||||
async def post_user_review_for_agent(
|
||||
username: str,
|
||||
agent_name: str,
|
||||
review: store_model.StoreReviewCreate,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Create a review for a store agent.
|
||||
|
||||
Args:
|
||||
username: Creator's username
|
||||
agent_name: Name/slug of the agent
|
||||
review: Review details including score and optional comments
|
||||
user_id: ID of authenticated user creating the review
|
||||
|
||||
Returns:
|
||||
The created review
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.StoreReview:
|
||||
"""Post a user review on a marketplace agent listing"""
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
# Create the review
|
||||
|
||||
created_review = await store_db.create_store_review(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=review.store_listing_version_id,
|
||||
score=review.score,
|
||||
comments=review.comments,
|
||||
)
|
||||
|
||||
return created_review
|
||||
|
||||
|
||||
@router.get(
|
||||
"/listings/versions/{store_listing_version_id}",
|
||||
summary="Get agent by version",
|
||||
tags=["store"],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_agent_by_listing_version(
|
||||
store_listing_version_id: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
|
||||
return agent
|
||||
|
||||
|
||||
@router.get(
|
||||
"/listings/versions/{store_listing_version_id}/graph",
|
||||
summary="Get agent graph",
|
||||
tags=["store"],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_graph_meta_by_store_listing_version_id(
|
||||
store_listing_version_id: str,
|
||||
) -> backend.data.graph.GraphModelWithoutNodes:
|
||||
"""Get outline of graph belonging to a specific marketplace listing version"""
|
||||
graph = await store_db.get_available_graph(store_listing_version_id)
|
||||
return graph
|
||||
|
||||
|
||||
@router.get(
|
||||
"/listings/versions/{store_listing_version_id}/graph/download",
|
||||
summary="Download agent file",
|
||||
tags=["store", "public"],
|
||||
)
|
||||
async def download_agent_file(
|
||||
store_listing_version_id: str,
|
||||
) -> fastapi.responses.FileResponse:
|
||||
"""Download agent graph file for a specific marketplace listing version"""
|
||||
graph_data = await store_db.get_agent(store_listing_version_id)
|
||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||
|
||||
# Sending graph as a stream (similar to marketplace v1)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(backend.util.json.dumps(graph_data))
|
||||
tmp_file.flush()
|
||||
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
############# Creator Endpoints #############
|
||||
##############################################
|
||||
@@ -349,37 +285,19 @@ async def create_review(
|
||||
"/creators",
|
||||
summary="List store creators",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.CreatorsResponse,
|
||||
)
|
||||
async def get_creators(
|
||||
featured: bool = False,
|
||||
search_query: str | None = None,
|
||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
):
|
||||
"""
|
||||
This is needed for:
|
||||
- Home Page Featured Creators
|
||||
- Search Results Page
|
||||
|
||||
---
|
||||
|
||||
To support this functionality we need:
|
||||
- featured: bool - to limit the list to just featured agents
|
||||
- search_query: str - vector search based on the creators profile description.
|
||||
- sorted_by: [agent_rating, agent_runs] -
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
|
||||
featured: bool = Query(
|
||||
default=False, description="Filter to only show featured creators"
|
||||
),
|
||||
search_query: str | None = Query(
|
||||
default=None, description="Literal + semantic search on names and descriptions"
|
||||
),
|
||||
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
) -> store_model.CreatorsResponse:
|
||||
"""List or search marketplace creators"""
|
||||
creators = await store_cache._get_cached_store_creators(
|
||||
featured=featured,
|
||||
search_query=search_query,
|
||||
@@ -391,18 +309,12 @@ async def get_creators(
|
||||
|
||||
|
||||
@router.get(
|
||||
"/creator/{username}",
|
||||
"/creators/{username}",
|
||||
summary="Get creator details",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.CreatorDetails,
|
||||
)
|
||||
async def get_creator(
|
||||
username: str,
|
||||
):
|
||||
"""
|
||||
Get the details of a creator.
|
||||
- Creator Details Page
|
||||
"""
|
||||
async def get_creator(username: str) -> store_model.CreatorDetails:
|
||||
"""Get details on a marketplace creator"""
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
creator = await store_cache._get_cached_creator_details(username=username)
|
||||
return creator
|
||||
@@ -414,20 +326,17 @@ async def get_creator(
|
||||
|
||||
|
||||
@router.get(
|
||||
"/myagents",
|
||||
"/my-unpublished-agents",
|
||||
summary="Get my agents",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.MyAgentsResponse,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_my_agents(
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
page: typing.Annotated[int, fastapi.Query(ge=1)] = 1,
|
||||
page_size: typing.Annotated[int, fastapi.Query(ge=1)] = 20,
|
||||
):
|
||||
"""
|
||||
Get user's own agents.
|
||||
"""
|
||||
async def get_my_unpublished_agents(
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
) -> store_model.MyUnpublishedAgentsResponse:
|
||||
"""List the authenticated user's unpublished agents"""
|
||||
agents = await store_db.get_my_agents(user_id, page=page, page_size=page_size)
|
||||
return agents
|
||||
|
||||
@@ -436,28 +345,17 @@ async def get_my_agents(
|
||||
"/submissions/{submission_id}",
|
||||
summary="Delete store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=bool,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def delete_submission(
|
||||
submission_id: str,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Delete a store listing submission.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the authenticated user
|
||||
submission_id (str): ID of the submission to be deleted
|
||||
|
||||
Returns:
|
||||
bool: True if the submission was successfully deleted, False otherwise
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> bool:
|
||||
"""Delete a marketplace listing submission"""
|
||||
result = await store_db.delete_store_submission(
|
||||
user_id=user_id,
|
||||
submission_id=submission_id,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -465,37 +363,14 @@ async def delete_submission(
|
||||
"/submissions",
|
||||
summary="List my submissions",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreSubmissionsResponse,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_submissions(
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
):
|
||||
"""
|
||||
Get a paginated list of store submissions for the authenticated user.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the authenticated user
|
||||
page (int, optional): Page number for pagination. Defaults to 1.
|
||||
page_size (int, optional): Number of submissions per page. Defaults to 20.
|
||||
|
||||
Returns:
|
||||
StoreListingsResponse: Paginated list of store submissions
|
||||
|
||||
Raises:
|
||||
HTTPException: If page or page_size are less than 1
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
) -> store_model.StoreSubmissionsResponse:
|
||||
"""List the authenticated user's marketplace listing submissions"""
|
||||
listings = await store_db.get_store_submissions(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
@@ -508,30 +383,17 @@ async def get_submissions(
|
||||
"/submissions",
|
||||
summary="Create store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreSubmission,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def create_submission(
|
||||
submission_request: store_model.StoreSubmissionRequest,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Create a new store listing submission.
|
||||
|
||||
Args:
|
||||
submission_request (StoreSubmissionRequest): The submission details
|
||||
user_id (str): ID of the authenticated user submitting the listing
|
||||
|
||||
Returns:
|
||||
StoreSubmission: The created store submission
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error creating the submission
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.StoreSubmission:
|
||||
"""Submit a new marketplace listing for review"""
|
||||
result = await store_db.create_store_submission(
|
||||
user_id=user_id,
|
||||
agent_id=submission_request.agent_id,
|
||||
agent_version=submission_request.agent_version,
|
||||
graph_id=submission_request.graph_id,
|
||||
graph_version=submission_request.graph_version,
|
||||
slug=submission_request.slug,
|
||||
name=submission_request.name,
|
||||
video_url=submission_request.video_url,
|
||||
@@ -544,7 +406,6 @@ async def create_submission(
|
||||
changes_summary=submission_request.changes_summary or "Initial Submission",
|
||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -552,28 +413,14 @@ async def create_submission(
|
||||
"/submissions/{store_listing_version_id}",
|
||||
summary="Edit store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreSubmission,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def edit_submission(
|
||||
store_listing_version_id: str,
|
||||
submission_request: store_model.StoreSubmissionEditRequest,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Edit an existing store listing submission.
|
||||
|
||||
Args:
|
||||
store_listing_version_id (str): ID of the store listing version to edit
|
||||
submission_request (StoreSubmissionRequest): The updated submission details
|
||||
user_id (str): ID of the authenticated user editing the listing
|
||||
|
||||
Returns:
|
||||
StoreSubmission: The updated store submission
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error editing the submission
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.StoreSubmission:
|
||||
"""Update a pending marketplace listing submission"""
|
||||
result = await store_db.edit_store_submission(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
@@ -588,7 +435,6 @@ async def edit_submission(
|
||||
changes_summary=submission_request.changes_summary,
|
||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -596,115 +442,61 @@ async def edit_submission(
|
||||
"/submissions/media",
|
||||
summary="Upload submission media",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def upload_submission_media(
|
||||
file: fastapi.UploadFile,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Upload media (images/videos) for a store listing submission.
|
||||
|
||||
Args:
|
||||
file (UploadFile): The media file to upload
|
||||
user_id (str): ID of the authenticated user uploading the media
|
||||
|
||||
Returns:
|
||||
str: URL of the uploaded media file
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error uploading the media
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> str:
|
||||
"""Upload media for a marketplace listing submission"""
|
||||
media_url = await store_media.upload_media(user_id=user_id, file=file)
|
||||
return media_url
|
||||
|
||||
|
||||
class ImageURLResponse(BaseModel):
|
||||
image_url: str
|
||||
|
||||
|
||||
@router.post(
|
||||
"/submissions/generate_image",
|
||||
summary="Generate submission image",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def generate_image(
|
||||
agent_id: str,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
) -> fastapi.responses.Response:
|
||||
graph_id: str,
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> ImageURLResponse:
|
||||
"""
|
||||
Generate an image for a store listing submission.
|
||||
|
||||
Args:
|
||||
agent_id (str): ID of the agent to generate an image for
|
||||
user_id (str): ID of the authenticated user
|
||||
|
||||
Returns:
|
||||
JSONResponse: JSON containing the URL of the generated image
|
||||
Generate an image for a marketplace listing submission based on the properties
|
||||
of a given graph.
|
||||
"""
|
||||
agent = await backend.data.graph.get_graph(
|
||||
graph_id=agent_id, version=None, user_id=user_id
|
||||
graph = await backend.data.graph.get_graph(
|
||||
graph_id=graph_id, version=None, user_id=user_id
|
||||
)
|
||||
|
||||
if not agent:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404, detail=f"Agent with ID {agent_id} not found"
|
||||
)
|
||||
if not graph:
|
||||
raise NotFoundError(f"Agent graph #{graph_id} not found")
|
||||
# Use .jpeg here since we are generating JPEG images
|
||||
filename = f"agent_{agent_id}.jpeg"
|
||||
filename = f"agent_{graph_id}.jpeg"
|
||||
|
||||
existing_url = await store_media.check_media_exists(user_id, filename)
|
||||
if existing_url:
|
||||
logger.info(f"Using existing image for agent {agent_id}")
|
||||
return fastapi.responses.JSONResponse(content={"image_url": existing_url})
|
||||
logger.info(f"Using existing image for agent graph {graph_id}")
|
||||
return ImageURLResponse(image_url=existing_url)
|
||||
# Generate agent image as JPEG
|
||||
image = await store_image_gen.generate_agent_image(agent=agent)
|
||||
image = await store_image_gen.generate_agent_image(agent=graph)
|
||||
|
||||
# Create UploadFile with the correct filename and content_type
|
||||
image_file = fastapi.UploadFile(
|
||||
file=image,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
image_url = await store_media.upload_media(
|
||||
user_id=user_id, file=image_file, use_file_name=True
|
||||
)
|
||||
|
||||
return fastapi.responses.JSONResponse(content={"image_url": image_url})
|
||||
|
||||
|
||||
@router.get(
|
||||
"/download/agents/{store_listing_version_id}",
|
||||
summary="Download agent file",
|
||||
tags=["store", "public"],
|
||||
)
|
||||
async def download_agent_file(
|
||||
store_listing_version_id: str = fastapi.Path(
|
||||
..., description="The ID of the agent to download"
|
||||
),
|
||||
) -> fastapi.responses.FileResponse:
|
||||
"""
|
||||
Download the agent file by streaming its content.
|
||||
|
||||
Args:
|
||||
store_listing_version_id (str): The ID of the agent to download
|
||||
|
||||
Returns:
|
||||
StreamingResponse: A streaming response containing the agent's graph data.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the agent is not found or an unexpected error occurs.
|
||||
"""
|
||||
graph_data = await store_db.get_agent(store_listing_version_id)
|
||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||
|
||||
# Sending graph as a stream (similar to marketplace v1)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(backend.util.json.dumps(graph_data))
|
||||
tmp_file.flush()
|
||||
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
return ImageURLResponse(image_url=image_url)
|
||||
|
||||
|
||||
##############################################
|
||||
|
||||
@@ -8,6 +8,8 @@ import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.api.features.store.db import StoreAgentsSortOptions
|
||||
|
||||
from . import model as store_model
|
||||
from . import routes as store_routes
|
||||
|
||||
@@ -196,7 +198,7 @@ def test_get_agents_sorted(
|
||||
mock_db_call.assert_called_once_with(
|
||||
featured=False,
|
||||
creators=None,
|
||||
sorted_by="runs",
|
||||
sorted_by=StoreAgentsSortOptions.RUNS,
|
||||
search_query=None,
|
||||
category=None,
|
||||
page=1,
|
||||
@@ -380,9 +382,11 @@ def test_get_agent_details(
|
||||
runs=100,
|
||||
rating=4.5,
|
||||
versions=["1.0.0", "1.1.0"],
|
||||
agentGraphVersions=["1", "2"],
|
||||
agentGraphId="test-graph-id",
|
||||
graph_versions=["1", "2"],
|
||||
graph_id="test-graph-id",
|
||||
last_updated=FIXED_NOW,
|
||||
active_version_id="test-version-id",
|
||||
has_approved_version=True,
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agent_details")
|
||||
mock_db_call.return_value = mocked_value
|
||||
@@ -435,15 +439,17 @@ def test_get_creators_pagination(
|
||||
) -> None:
|
||||
mocked_value = store_model.CreatorsResponse(
|
||||
creators=[
|
||||
store_model.Creator(
|
||||
store_model.CreatorDetails(
|
||||
name=f"Creator {i}",
|
||||
username=f"creator{i}",
|
||||
description=f"Creator {i} description",
|
||||
avatar_url=f"avatar{i}.jpg",
|
||||
num_agents=1,
|
||||
agent_rating=4.5,
|
||||
agent_runs=100,
|
||||
description=f"Creator {i} description",
|
||||
links=[f"user{i}.link.com"],
|
||||
is_featured=False,
|
||||
num_agents=1,
|
||||
agent_runs=100,
|
||||
agent_rating=4.5,
|
||||
top_categories=["cat1", "cat2", "cat3"],
|
||||
)
|
||||
for i in range(5)
|
||||
],
|
||||
@@ -496,19 +502,19 @@ def test_get_creator_details(
|
||||
mocked_value = store_model.CreatorDetails(
|
||||
name="Test User",
|
||||
username="creator1",
|
||||
avatar_url="avatar.jpg",
|
||||
description="Test creator description",
|
||||
links=["link1.com", "link2.com"],
|
||||
avatar_url="avatar.jpg",
|
||||
agent_rating=4.8,
|
||||
is_featured=True,
|
||||
num_agents=5,
|
||||
agent_runs=1000,
|
||||
agent_rating=4.8,
|
||||
top_categories=["category1", "category2"],
|
||||
)
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.api.features.store.db.get_store_creator_details"
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_creator")
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
response = client.get("/creator/creator1")
|
||||
response = client.get("/creators/creator1")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = store_model.CreatorDetails.model_validate(response.json())
|
||||
@@ -528,19 +534,26 @@ def test_get_submissions_success(
|
||||
submissions=[
|
||||
store_model.StoreSubmission(
|
||||
listing_id="test-listing-id",
|
||||
name="Test Agent",
|
||||
description="Test agent description",
|
||||
image_urls=["test.jpg"],
|
||||
date_submitted=FIXED_NOW,
|
||||
status=prisma.enums.SubmissionStatus.APPROVED,
|
||||
runs=50,
|
||||
rating=4.2,
|
||||
agent_id="test-agent-id",
|
||||
agent_version=1,
|
||||
sub_heading="Test agent subheading",
|
||||
user_id="test-user-id",
|
||||
slug="test-agent",
|
||||
video_url="test.mp4",
|
||||
listing_version_id="test-version-id",
|
||||
listing_version=1,
|
||||
graph_id="test-agent-id",
|
||||
graph_version=1,
|
||||
name="Test Agent",
|
||||
sub_heading="Test agent subheading",
|
||||
description="Test agent description",
|
||||
instructions="Click the button!",
|
||||
categories=["test-category"],
|
||||
image_urls=["test.jpg"],
|
||||
video_url="test.mp4",
|
||||
agent_output_demo_url="demo_video.mp4",
|
||||
submitted_at=FIXED_NOW,
|
||||
changes_summary="Initial Submission",
|
||||
status=prisma.enums.SubmissionStatus.APPROVED,
|
||||
run_count=50,
|
||||
review_count=5,
|
||||
review_avg_rating=4.2,
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
|
||||
@@ -11,6 +11,7 @@ import pytest
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from . import cache as store_cache
|
||||
from .db import StoreAgentsSortOptions
|
||||
from .model import StoreAgent, StoreAgentsResponse
|
||||
|
||||
|
||||
@@ -215,7 +216,7 @@ class TestCacheDeletion:
|
||||
await store_cache._get_cached_store_agents(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by="rating",
|
||||
sorted_by=StoreAgentsSortOptions.RATING,
|
||||
search_query="AI assistant",
|
||||
category="productivity",
|
||||
page=2,
|
||||
@@ -227,7 +228,7 @@ class TestCacheDeletion:
|
||||
deleted = store_cache._get_cached_store_agents.cache_delete(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by="rating",
|
||||
sorted_by=StoreAgentsSortOptions.RATING,
|
||||
search_query="AI assistant",
|
||||
category="productivity",
|
||||
page=2,
|
||||
@@ -239,7 +240,7 @@ class TestCacheDeletion:
|
||||
deleted = store_cache._get_cached_store_agents.cache_delete(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by="rating",
|
||||
sorted_by=StoreAgentsSortOptions.RATING,
|
||||
search_query="AI assistant",
|
||||
category="productivity",
|
||||
page=2,
|
||||
|
||||
@@ -55,6 +55,7 @@ from backend.data.credit import (
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.graph import GraphSettings
|
||||
from backend.data.invited_user import get_or_activate_user
|
||||
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.data.onboarding import (
|
||||
@@ -70,7 +71,6 @@ from backend.data.onboarding import (
|
||||
update_user_onboarding,
|
||||
)
|
||||
from backend.data.user import (
|
||||
get_or_create_user,
|
||||
get_user_by_id,
|
||||
get_user_notification_preference,
|
||||
update_user_email,
|
||||
@@ -136,12 +136,10 @@ _tally_background_tasks: set[asyncio.Task] = set()
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
user = await get_or_create_user(user_data)
|
||||
user = await get_or_activate_user(user_data)
|
||||
|
||||
# Fire-and-forget: populate business understanding from Tally form.
|
||||
# We use created_at proximity instead of an is_new flag because
|
||||
# get_or_create_user is cached — a separate is_new return value would be
|
||||
# unreliable on repeated calls within the cache TTL.
|
||||
# Fire-and-forget: backfill Tally understanding when invite pre-seeding did
|
||||
# not produce a stored result before first activation.
|
||||
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
|
||||
if age_seconds < 30:
|
||||
try:
|
||||
@@ -165,8 +163,11 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_user_email_route(
|
||||
user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
email: str = Body(...),
|
||||
user_data: dict = Security(get_jwt_payload),
|
||||
) -> dict[str, str]:
|
||||
await get_or_activate_user(user_data)
|
||||
await update_user_email(user_id, email)
|
||||
|
||||
return {"email": email}
|
||||
@@ -182,7 +183,7 @@ async def get_user_timezone_route(
|
||||
user_data: dict = Security(get_jwt_payload),
|
||||
) -> TimezoneResponse:
|
||||
"""Get user timezone setting."""
|
||||
user = await get_or_create_user(user_data)
|
||||
user = await get_or_activate_user(user_data)
|
||||
return TimezoneResponse(timezone=user.timezone)
|
||||
|
||||
|
||||
@@ -193,9 +194,12 @@ async def get_user_timezone_route(
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_user_timezone_route(
|
||||
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
request: UpdateTimezoneRequest,
|
||||
user_data: dict = Security(get_jwt_payload),
|
||||
) -> TimezoneResponse:
|
||||
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
|
||||
await get_or_activate_user(user_data)
|
||||
user = await update_user_timezone(user_id, str(request.timezone))
|
||||
return TimezoneResponse(timezone=user.timezone)
|
||||
|
||||
@@ -208,7 +212,9 @@ async def update_user_timezone_route(
|
||||
)
|
||||
async def get_preferences(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
user_data: dict = Security(get_jwt_payload),
|
||||
) -> NotificationPreference:
|
||||
await get_or_activate_user(user_data)
|
||||
preferences = await get_user_notification_preference(user_id)
|
||||
return preferences
|
||||
|
||||
@@ -222,7 +228,9 @@ async def get_preferences(
|
||||
async def update_preferences(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
preferences: NotificationPreferenceDTO = Body(...),
|
||||
user_data: dict = Security(get_jwt_payload),
|
||||
) -> NotificationPreference:
|
||||
await get_or_activate_user(user_data)
|
||||
output = await update_user_notification_preference(user_id, preferences)
|
||||
return output
|
||||
|
||||
@@ -449,7 +457,6 @@ async def execute_graph_block(
|
||||
async def upload_file(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
file: UploadFile = File(...),
|
||||
provider: str = "gcs",
|
||||
expiration_hours: int = 24,
|
||||
) -> UploadFileResponse:
|
||||
"""
|
||||
@@ -512,7 +519,6 @@ async def upload_file(
|
||||
storage_path = await cloud_storage.store_file(
|
||||
content=content,
|
||||
filename=file_name,
|
||||
provider=provider,
|
||||
expiration_hours=expiration_hours,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@@ -51,7 +51,7 @@ def test_get_or_create_user_route(
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_or_create_user",
|
||||
"backend.api.features.v1.get_or_activate_user",
|
||||
return_value=mock_user,
|
||||
)
|
||||
|
||||
@@ -515,7 +515,6 @@ async def test_upload_file_success(test_user_id: str):
|
||||
result = await upload_file(
|
||||
file=upload_file_mock,
|
||||
user_id=test_user_id,
|
||||
provider="gcs",
|
||||
expiration_hours=24,
|
||||
)
|
||||
|
||||
@@ -533,7 +532,6 @@ async def test_upload_file_success(test_user_id: str):
|
||||
mock_handler.store_file.assert_called_once_with(
|
||||
content=file_content,
|
||||
filename="test.txt",
|
||||
provider="gcs",
|
||||
expiration_hours=24,
|
||||
user_id=test_user_id,
|
||||
)
|
||||
|
||||
@@ -3,15 +3,29 @@ Workspace API routes for managing user file storage.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Annotated
|
||||
from urllib.parse import quote
|
||||
|
||||
import fastapi
|
||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||
from fastapi import Query, UploadFile
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.workspace import WorkspaceFile, get_workspace, get_workspace_file
|
||||
from backend.data.workspace import (
|
||||
WorkspaceFile,
|
||||
count_workspace_files,
|
||||
get_or_create_workspace,
|
||||
get_workspace,
|
||||
get_workspace_file,
|
||||
get_workspace_total_size,
|
||||
soft_delete_workspace_file,
|
||||
)
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
|
||||
@@ -98,6 +112,25 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
|
||||
raise
|
||||
|
||||
|
||||
class UploadFileResponse(BaseModel):
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
|
||||
|
||||
class DeleteFileResponse(BaseModel):
|
||||
deleted: bool
|
||||
|
||||
|
||||
class StorageUsageResponse(BaseModel):
|
||||
used_bytes: int
|
||||
limit_bytes: int
|
||||
used_percent: float
|
||||
file_count: int
|
||||
|
||||
|
||||
@router.get(
|
||||
"/files/{file_id}/download",
|
||||
summary="Download file by ID",
|
||||
@@ -120,3 +153,148 @@ async def download_file(
|
||||
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
return await _create_file_download_response(file)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/files/{file_id}",
|
||||
summary="Delete a workspace file",
|
||||
)
|
||||
async def delete_workspace_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
file_id: str,
|
||||
) -> DeleteFileResponse:
|
||||
"""
|
||||
Soft-delete a workspace file and attempt to remove it from storage.
|
||||
|
||||
Used when a user clears a file input in the builder.
|
||||
"""
|
||||
workspace = await get_workspace(user_id)
|
||||
if workspace is None:
|
||||
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
|
||||
|
||||
manager = WorkspaceManager(user_id, workspace.id)
|
||||
deleted = await manager.delete_file(file_id)
|
||||
if not deleted:
|
||||
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
return DeleteFileResponse(deleted=True)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/files/upload",
|
||||
summary="Upload file to workspace",
|
||||
)
|
||||
async def upload_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
file: UploadFile,
|
||||
session_id: str | None = Query(default=None),
|
||||
) -> UploadFileResponse:
|
||||
"""
|
||||
Upload a file to the user's workspace.
|
||||
|
||||
Files are stored in session-scoped paths when session_id is provided,
|
||||
so the agent's session-scoped tools can discover them automatically.
|
||||
"""
|
||||
config = Config()
|
||||
|
||||
# Sanitize filename — strip any directory components
|
||||
filename = os.path.basename(file.filename or "upload") or "upload"
|
||||
|
||||
# Read file content with early abort on size limit
|
||||
max_file_bytes = config.max_file_size_mb * 1024 * 1024
|
||||
chunks: list[bytes] = []
|
||||
total_size = 0
|
||||
while chunk := await file.read(64 * 1024): # 64KB chunks
|
||||
total_size += len(chunk)
|
||||
if total_size > max_file_bytes:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File exceeds maximum size of {config.max_file_size_mb} MB",
|
||||
)
|
||||
chunks.append(chunk)
|
||||
content = b"".join(chunks)
|
||||
|
||||
# Get or create workspace
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
# Pre-write storage cap check (soft check — final enforcement is post-write)
|
||||
storage_limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
|
||||
current_usage = await get_workspace_total_size(workspace.id)
|
||||
if storage_limit_bytes and current_usage + len(content) > storage_limit_bytes:
|
||||
used_percent = (current_usage / storage_limit_bytes) * 100
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail={
|
||||
"message": "Storage limit exceeded",
|
||||
"used_bytes": current_usage,
|
||||
"limit_bytes": storage_limit_bytes,
|
||||
"used_percent": round(used_percent, 1),
|
||||
},
|
||||
)
|
||||
|
||||
# Warn at 80% usage
|
||||
if (
|
||||
storage_limit_bytes
|
||||
and (usage_ratio := (current_usage + len(content)) / storage_limit_bytes) >= 0.8
|
||||
):
|
||||
logger.warning(
|
||||
f"User {user_id} workspace storage at {usage_ratio * 100:.1f}% "
|
||||
f"({current_usage + len(content)} / {storage_limit_bytes} bytes)"
|
||||
)
|
||||
|
||||
# Virus scan
|
||||
await scan_content_safe(content, filename=filename)
|
||||
|
||||
# Write file via WorkspaceManager
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
try:
|
||||
workspace_file = await manager.write_file(content, filename)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
|
||||
|
||||
# Post-write storage check — eliminates TOCTOU race on the quota.
|
||||
# If a concurrent upload pushed us over the limit, undo this write.
|
||||
new_total = await get_workspace_total_size(workspace.id)
|
||||
if storage_limit_bytes and new_total > storage_limit_bytes:
|
||||
await soft_delete_workspace_file(workspace_file.id, workspace.id)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail={
|
||||
"message": "Storage limit exceeded (concurrent upload)",
|
||||
"used_bytes": new_total,
|
||||
"limit_bytes": storage_limit_bytes,
|
||||
},
|
||||
)
|
||||
|
||||
return UploadFileResponse(
|
||||
file_id=workspace_file.id,
|
||||
name=workspace_file.name,
|
||||
path=workspace_file.path,
|
||||
mime_type=workspace_file.mime_type,
|
||||
size_bytes=workspace_file.size_bytes,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/storage/usage",
|
||||
summary="Get workspace storage usage",
|
||||
)
|
||||
async def get_storage_usage(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
) -> StorageUsageResponse:
|
||||
"""
|
||||
Get storage usage information for the user's workspace.
|
||||
"""
|
||||
config = Config()
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
used_bytes = await get_workspace_total_size(workspace.id)
|
||||
file_count = await count_workspace_files(workspace.id)
|
||||
limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
|
||||
|
||||
return StorageUsageResponse(
|
||||
used_bytes=used_bytes,
|
||||
limit_bytes=limit_bytes,
|
||||
used_percent=round((used_bytes / limit_bytes) * 100, 1) if limit_bytes else 0,
|
||||
file_count=file_count,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,359 @@
|
||||
"""Tests for workspace file upload and download routes."""
|
||||
|
||||
import io
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.api.features.workspace import routes as workspace_routes
|
||||
from backend.data.workspace import WorkspaceFile
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(workspace_routes.router)
|
||||
|
||||
|
||||
@app.exception_handler(ValueError)
|
||||
async def _value_error_handler(
|
||||
request: fastapi.Request, exc: ValueError
|
||||
) -> fastapi.responses.JSONResponse:
|
||||
"""Mirror the production ValueError → 400 mapping from rest_api.py."""
|
||||
return fastapi.responses.JSONResponse(status_code=400, content={"detail": str(exc)})
|
||||
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})()
|
||||
|
||||
_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
MOCK_FILE = WorkspaceFile(
|
||||
id="file-aaa-bbb",
|
||||
workspace_id="ws-1",
|
||||
created_at=_NOW,
|
||||
updated_at=_NOW,
|
||||
name="hello.txt",
|
||||
path="/session/hello.txt",
|
||||
mime_type="text/plain",
|
||||
size_bytes=13,
|
||||
storage_path="local://hello.txt",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _upload(
|
||||
filename: str = "hello.txt",
|
||||
content: bytes = b"Hello, world!",
|
||||
content_type: str = "text/plain",
|
||||
):
|
||||
"""Helper to POST a file upload."""
|
||||
return client.post(
|
||||
"/files/upload?session_id=sess-1",
|
||||
files={"file": (filename, io.BytesIO(content), content_type)},
|
||||
)
|
||||
|
||||
|
||||
# ---- Happy path ----
|
||||
|
||||
|
||||
def test_upload_happy_path(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["file_id"] == "file-aaa-bbb"
|
||||
assert data["name"] == "hello.txt"
|
||||
assert data["size_bytes"] == 13
|
||||
|
||||
|
||||
# ---- Per-file size limit ----
|
||||
|
||||
|
||||
def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
|
||||
"""Files larger than max_file_size_mb should be rejected with 413."""
|
||||
cfg = mocker.patch("backend.api.features.workspace.routes.Config")
|
||||
cfg.return_value.max_file_size_mb = 0 # 0 MB → any content is too big
|
||||
cfg.return_value.max_workspace_storage_mb = 500
|
||||
|
||||
response = _upload(content=b"x" * 1024)
|
||||
assert response.status_code == 413
|
||||
|
||||
|
||||
# ---- Storage quota exceeded ----
|
||||
|
||||
|
||||
def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
# Current usage already at limit
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=500 * 1024 * 1024,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 413
|
||||
assert "Storage limit exceeded" in response.text
|
||||
|
||||
|
||||
# ---- Post-write quota race (B2) ----
|
||||
|
||||
|
||||
def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
|
||||
"""If a concurrent upload tips the total over the limit after write,
|
||||
the file should be soft-deleted and 413 returned."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
# Pre-write check passes (under limit), but post-write check fails
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
side_effect=[0, 600 * 1024 * 1024], # first call OK, second over limit
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
mock_delete = mocker.patch(
|
||||
"backend.api.features.workspace.routes.soft_delete_workspace_file",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 413
|
||||
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-1")
|
||||
|
||||
|
||||
# ---- Any extension accepted (no allowlist) ----
|
||||
|
||||
|
||||
def test_upload_any_extension(mocker: pytest_mock.MockFixture):
|
||||
"""Any file extension should be accepted — ClamAV is the security layer."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload(filename="data.xyz", content=b"arbitrary")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ---- Virus scan rejection ----
|
||||
|
||||
|
||||
def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
|
||||
"""Files flagged by ClamAV should be rejected and never written to storage."""
|
||||
from backend.api.features.store.exceptions import VirusDetectedError
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
side_effect=VirusDetectedError("Eicar-Test-Signature"),
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
|
||||
assert response.status_code == 400
|
||||
assert "Virus detected" in response.text
|
||||
mock_manager.write_file.assert_not_called()
|
||||
|
||||
|
||||
# ---- No file extension ----
|
||||
|
||||
|
||||
def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
|
||||
"""Files without an extension should be accepted and stored as-is."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload(
|
||||
filename="Makefile",
|
||||
content=b"all:\n\techo hello",
|
||||
content_type="application/octet-stream",
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_manager.write_file.assert_called_once()
|
||||
assert mock_manager.write_file.call_args[0][1] == "Makefile"
|
||||
|
||||
|
||||
# ---- Filename sanitization (SF5) ----
|
||||
|
||||
|
||||
def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
|
||||
"""Path-traversal filenames should be reduced to their basename."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
# Filename with traversal
|
||||
_upload(filename="../../etc/passwd.txt")
|
||||
|
||||
# write_file should have been called with just the basename
|
||||
mock_manager.write_file.assert_called_once()
|
||||
call_args = mock_manager.write_file.call_args
|
||||
assert call_args[0][1] == "passwd.txt"
|
||||
|
||||
|
||||
# ---- Download ----
|
||||
|
||||
|
||||
def test_download_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_file",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.get("/files/some-file-id/download")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ---- Delete ----
|
||||
|
||||
|
||||
def test_delete_file_success(mocker: pytest_mock.MockFixture):
|
||||
"""Deleting an existing file should return {"deleted": true}."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.delete_file = mocker.AsyncMock(return_value=True)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = client.delete("/files/file-aaa-bbb")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"deleted": True}
|
||||
mock_manager.delete_file.assert_called_once_with("file-aaa-bbb")
|
||||
|
||||
|
||||
def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
"""Deleting a non-existent file should return 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.delete_file = mocker.AsyncMock(return_value=False)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = client.delete("/files/nonexistent-id")
|
||||
assert response.status_code == 404
|
||||
assert "File not found" in response.text
|
||||
|
||||
|
||||
def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
|
||||
"""Deleting when user has no workspace should return 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.delete("/files/file-aaa-bbb")
|
||||
assert response.status_code == 404
|
||||
assert "Workspace not found" in response.text
|
||||
@@ -19,6 +19,7 @@ from prisma.errors import PrismaError
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.admin.user_admin_routes
|
||||
import backend.api.features.builder
|
||||
import backend.api.features.builder.routes
|
||||
import backend.api.features.chat.routes as chat_routes
|
||||
@@ -55,6 +56,7 @@ from backend.util.exceptions import (
|
||||
MissingConfigError,
|
||||
NotAuthorizedError,
|
||||
NotFoundError,
|
||||
PreconditionFailed,
|
||||
)
|
||||
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
|
||||
from backend.util.service import UnhealthyServiceError
|
||||
@@ -275,6 +277,7 @@ app.add_exception_handler(RequestValidationError, validation_error_handler)
|
||||
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
|
||||
app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
|
||||
app.add_exception_handler(ValueError, handle_internal_http_error(400))
|
||||
app.add_exception_handler(PreconditionFailed, handle_internal_http_error(428))
|
||||
app.add_exception_handler(Exception, handle_internal_http_error(500))
|
||||
|
||||
app.include_router(backend.api.features.v1.v1_router, tags=["v1"], prefix="/api")
|
||||
@@ -309,6 +312,11 @@ app.include_router(
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/executions",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.user_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/users",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.executions.review.routes.router,
|
||||
tags=["v2", "executions", "review"],
|
||||
|
||||
@@ -116,6 +116,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
||||
CLAUDE_4_6_OPUS = "claude-opus-4-6"
|
||||
CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
|
||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||
# AI/ML API models
|
||||
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
||||
@@ -274,6 +275,9 @@ MODEL_METADATA = {
|
||||
LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
|
||||
), # claude-opus-4-6
|
||||
LlmModel.CLAUDE_4_6_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Sonnet 4.6", "Anthropic", "Anthropic", 3
|
||||
), # claude-sonnet-4-6
|
||||
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
|
||||
), # claude-opus-4-5-20251101
|
||||
|
||||
@@ -6,7 +6,6 @@ and execute them. Works like AgentExecutorBlock — the user selects a tool from
|
||||
dropdown and the input/output schema adapts dynamically.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
@@ -20,6 +19,11 @@ from backend.blocks._base import (
|
||||
BlockType,
|
||||
)
|
||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
||||
from backend.blocks.mcp.helpers import (
|
||||
auto_lookup_mcp_credential,
|
||||
normalize_mcp_url,
|
||||
parse_mcp_content,
|
||||
)
|
||||
from backend.data.block import BlockInput, BlockOutput
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
@@ -179,31 +183,7 @@ class MCPToolBlock(Block):
|
||||
f"{error_text or 'Unknown error'}"
|
||||
)
|
||||
|
||||
# Extract text content from the result
|
||||
output_parts = []
|
||||
for item in result.content:
|
||||
if item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
# Try to parse as JSON for structured output
|
||||
try:
|
||||
output_parts.append(json.loads(text))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
output_parts.append(text)
|
||||
elif item.get("type") == "image":
|
||||
output_parts.append(
|
||||
{
|
||||
"type": "image",
|
||||
"data": item.get("data"),
|
||||
"mimeType": item.get("mimeType"),
|
||||
}
|
||||
)
|
||||
elif item.get("type") == "resource":
|
||||
output_parts.append(item.get("resource", {}))
|
||||
|
||||
# If single result, unwrap
|
||||
if len(output_parts) == 1:
|
||||
return output_parts[0]
|
||||
return output_parts if output_parts else None
|
||||
return parse_mcp_content(result.content)
|
||||
|
||||
@staticmethod
|
||||
async def _auto_lookup_credential(
|
||||
@@ -211,37 +191,10 @@ class MCPToolBlock(Block):
|
||||
) -> "OAuth2Credentials | None":
|
||||
"""Auto-lookup stored MCP credential for a server URL.
|
||||
|
||||
This is a fallback for nodes that don't have ``credentials`` explicitly
|
||||
set (e.g. nodes created before the credential field was wired up).
|
||||
Delegates to :func:`~backend.blocks.mcp.helpers.auto_lookup_mcp_credential`.
|
||||
The caller should pass a normalized URL.
|
||||
"""
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
try:
|
||||
mgr = IntegrationCredentialsManager()
|
||||
mcp_creds = await mgr.store.get_creds_by_provider(
|
||||
user_id, ProviderName.MCP.value
|
||||
)
|
||||
best: OAuth2Credentials | None = None
|
||||
for cred in mcp_creds:
|
||||
if (
|
||||
isinstance(cred, OAuth2Credentials)
|
||||
and (cred.metadata or {}).get("mcp_server_url") == server_url
|
||||
):
|
||||
if best is None or (
|
||||
(cred.access_token_expires_at or 0)
|
||||
> (best.access_token_expires_at or 0)
|
||||
):
|
||||
best = cred
|
||||
if best:
|
||||
best = await mgr.refresh_if_needed(user_id, best)
|
||||
logger.info(
|
||||
"Auto-resolved MCP credential %s for %s", best.id, server_url
|
||||
)
|
||||
return best
|
||||
except Exception:
|
||||
logger.warning("Auto-lookup MCP credential failed", exc_info=True)
|
||||
return None
|
||||
return await auto_lookup_mcp_credential(user_id, server_url)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
@@ -278,7 +231,7 @@ class MCPToolBlock(Block):
|
||||
# the stored MCP credential for this server URL.
|
||||
if credentials is None:
|
||||
credentials = await self._auto_lookup_credential(
|
||||
user_id, input_data.server_url
|
||||
user_id, normalize_mcp_url(input_data.server_url)
|
||||
)
|
||||
|
||||
auth_token = (
|
||||
|
||||
@@ -55,7 +55,9 @@ class MCPClient:
|
||||
server_url: str,
|
||||
auth_token: str | None = None,
|
||||
):
|
||||
self.server_url = server_url.rstrip("/")
|
||||
from backend.blocks.mcp.helpers import normalize_mcp_url
|
||||
|
||||
self.server_url = normalize_mcp_url(server_url)
|
||||
self.auth_token = auth_token
|
||||
self._request_id = 0
|
||||
self._session_id: str | None = None
|
||||
|
||||
117
autogpt_platform/backend/backend/blocks/mcp/helpers.py
Normal file
117
autogpt_platform/backend/backend/blocks/mcp/helpers.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""Shared MCP helpers used by blocks, copilot tools, and API routes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def normalize_mcp_url(url: str) -> str:
|
||||
"""Normalize an MCP server URL for consistent credential matching.
|
||||
|
||||
Strips leading/trailing whitespace and a single trailing slash so that
|
||||
``https://mcp.example.com/`` and ``https://mcp.example.com`` resolve to
|
||||
the same stored credential.
|
||||
"""
|
||||
return url.strip().rstrip("/")
|
||||
|
||||
|
||||
def server_host(server_url: str) -> str:
|
||||
"""Extract the hostname from a server URL for display purposes.
|
||||
|
||||
Uses ``parsed.hostname`` (never ``netloc``) to strip any embedded
|
||||
username/password before surfacing the value in UI messages.
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(server_url)
|
||||
return parsed.hostname or server_url
|
||||
except Exception:
|
||||
return server_url
|
||||
|
||||
|
||||
def parse_mcp_content(content: list[dict[str, Any]]) -> Any:
|
||||
"""Parse MCP tool response content into a plain Python value.
|
||||
|
||||
- text items: parsed as JSON when possible, kept as str otherwise
|
||||
- image items: kept as ``{type, data, mimeType}`` dict for frontend rendering
|
||||
- resource items: unwrapped to their resource payload dict
|
||||
|
||||
Single-item responses are unwrapped from the list; multiple items are
|
||||
returned as a list; empty content returns ``None``.
|
||||
"""
|
||||
output_parts: list[Any] = []
|
||||
for item in content:
|
||||
item_type = item.get("type")
|
||||
if item_type == "text":
|
||||
text = item.get("text", "")
|
||||
try:
|
||||
output_parts.append(json.loads(text))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
output_parts.append(text)
|
||||
elif item_type == "image":
|
||||
output_parts.append(
|
||||
{
|
||||
"type": "image",
|
||||
"data": item.get("data"),
|
||||
"mimeType": item.get("mimeType"),
|
||||
}
|
||||
)
|
||||
elif item_type == "resource":
|
||||
output_parts.append(item.get("resource", {}))
|
||||
|
||||
if len(output_parts) == 1:
|
||||
return output_parts[0]
|
||||
return output_parts or None
|
||||
|
||||
|
||||
async def auto_lookup_mcp_credential(
|
||||
user_id: str, server_url: str
|
||||
) -> OAuth2Credentials | None:
|
||||
"""Look up the best stored MCP credential for *server_url*.
|
||||
|
||||
The caller should pass a **normalized** URL (via :func:`normalize_mcp_url`)
|
||||
so the comparison with ``mcp_server_url`` in credential metadata matches.
|
||||
|
||||
Returns the credential with the latest ``access_token_expires_at``, refreshed
|
||||
if needed, or ``None`` when no match is found.
|
||||
"""
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
try:
|
||||
mgr = IntegrationCredentialsManager()
|
||||
mcp_creds = await mgr.store.get_creds_by_provider(
|
||||
user_id, ProviderName.MCP.value
|
||||
)
|
||||
# Collect all matching credentials and pick the best one.
|
||||
# Primary sort: latest access_token_expires_at (tokens with expiry
|
||||
# are preferred over non-expiring ones). Secondary sort: last in
|
||||
# iteration order, which corresponds to the most recently created
|
||||
# row — this acts as a tiebreaker when multiple bearer tokens have
|
||||
# no expiry (e.g. after a failed old-credential cleanup).
|
||||
best: OAuth2Credentials | None = None
|
||||
for cred in mcp_creds:
|
||||
if (
|
||||
isinstance(cred, OAuth2Credentials)
|
||||
and (cred.metadata or {}).get("mcp_server_url") == server_url
|
||||
):
|
||||
if best is None or (
|
||||
(cred.access_token_expires_at or 0)
|
||||
>= (best.access_token_expires_at or 0)
|
||||
):
|
||||
best = cred
|
||||
if best:
|
||||
best = await mgr.refresh_if_needed(user_id, best)
|
||||
logger.info("Auto-resolved MCP credential %s for %s", best.id, server_url)
|
||||
return best
|
||||
except Exception:
|
||||
logger.warning("Auto-lookup MCP credential failed", exc_info=True)
|
||||
return None
|
||||
98
autogpt_platform/backend/backend/blocks/mcp/test_helpers.py
Normal file
98
autogpt_platform/backend/backend/blocks/mcp/test_helpers.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Unit tests for the shared MCP helpers."""
|
||||
|
||||
from backend.blocks.mcp.helpers import normalize_mcp_url, parse_mcp_content, server_host
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# normalize_mcp_url
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_normalize_trailing_slash():
|
||||
assert normalize_mcp_url("https://mcp.example.com/") == "https://mcp.example.com"
|
||||
|
||||
|
||||
def test_normalize_whitespace():
|
||||
assert normalize_mcp_url(" https://mcp.example.com ") == "https://mcp.example.com"
|
||||
|
||||
|
||||
def test_normalize_both():
|
||||
assert (
|
||||
normalize_mcp_url(" https://mcp.example.com/ ") == "https://mcp.example.com"
|
||||
)
|
||||
|
||||
|
||||
def test_normalize_noop():
|
||||
assert normalize_mcp_url("https://mcp.example.com") == "https://mcp.example.com"
|
||||
|
||||
|
||||
def test_normalize_path_with_trailing_slash():
|
||||
assert (
|
||||
normalize_mcp_url("https://mcp.example.com/path/")
|
||||
== "https://mcp.example.com/path"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# server_host
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_server_host_standard_url():
|
||||
assert server_host("https://mcp.example.com/mcp") == "mcp.example.com"
|
||||
|
||||
|
||||
def test_server_host_strips_credentials():
|
||||
"""hostname must not expose user:pass."""
|
||||
assert server_host("https://user:secret@mcp.example.com/mcp") == "mcp.example.com"
|
||||
|
||||
|
||||
def test_server_host_with_port():
|
||||
"""Port should not appear in hostname (hostname strips it)."""
|
||||
assert server_host("https://mcp.example.com:8080/mcp") == "mcp.example.com"
|
||||
|
||||
|
||||
def test_server_host_fallback():
|
||||
"""Falls back to the raw string for un-parseable URLs."""
|
||||
assert server_host("not-a-url") == "not-a-url"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_mcp_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_text_plain():
|
||||
assert parse_mcp_content([{"type": "text", "text": "hello world"}]) == "hello world"
|
||||
|
||||
|
||||
def test_parse_text_json():
|
||||
content = [{"type": "text", "text": '{"status": "ok", "count": 42}'}]
|
||||
assert parse_mcp_content(content) == {"status": "ok", "count": 42}
|
||||
|
||||
|
||||
def test_parse_image():
|
||||
content = [{"type": "image", "data": "abc123==", "mimeType": "image/png"}]
|
||||
assert parse_mcp_content(content) == {
|
||||
"type": "image",
|
||||
"data": "abc123==",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
|
||||
|
||||
def test_parse_resource():
|
||||
content = [
|
||||
{"type": "resource", "resource": {"uri": "file:///tmp/out.txt", "text": "hi"}}
|
||||
]
|
||||
assert parse_mcp_content(content) == {"uri": "file:///tmp/out.txt", "text": "hi"}
|
||||
|
||||
|
||||
def test_parse_multi_item():
|
||||
content = [
|
||||
{"type": "text", "text": "first"},
|
||||
{"type": "text", "text": "second"},
|
||||
]
|
||||
assert parse_mcp_content(content) == ["first", "second"]
|
||||
|
||||
|
||||
def test_parse_empty():
|
||||
assert parse_mcp_content([]) is None
|
||||
@@ -83,7 +83,8 @@ class StagehandRecommendedLlmModel(str, Enum):
|
||||
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
|
||||
|
||||
# Anthropic
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929" # Keep for backwards compat
|
||||
CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
@@ -137,7 +138,7 @@ class StagehandObserveBlock(Block):
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
@@ -227,7 +228,7 @@ class StagehandActBlock(Block):
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
@@ -324,7 +325,7 @@ class StagehandExtractBlock(Block):
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.store.db import StoreAgentsSortOptions
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -176,8 +176,8 @@ class SearchStoreAgentsBlock(Block):
|
||||
category: str | None = SchemaField(
|
||||
description="Filter by category", default=None
|
||||
)
|
||||
sort_by: Literal["rating", "runs", "name", "updated_at"] = SchemaField(
|
||||
description="How to sort the results", default="rating"
|
||||
sort_by: StoreAgentsSortOptions = SchemaField(
|
||||
description="How to sort the results", default=StoreAgentsSortOptions.RATING
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of results to return", default=10, ge=1, le=100
|
||||
@@ -278,7 +278,7 @@ class SearchStoreAgentsBlock(Block):
|
||||
self,
|
||||
query: str | None = None,
|
||||
category: str | None = None,
|
||||
sort_by: Literal["rating", "runs", "name", "updated_at"] = "rating",
|
||||
sort_by: StoreAgentsSortOptions = StoreAgentsSortOptions.RATING,
|
||||
limit: int = 10,
|
||||
) -> SearchAgentsResponse:
|
||||
"""
|
||||
|
||||
@@ -2,6 +2,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.store.db import StoreAgentsSortOptions
|
||||
from backend.blocks.system.library_operations import (
|
||||
AddToLibraryFromStoreBlock,
|
||||
LibraryAgent,
|
||||
@@ -121,7 +122,10 @@ async def test_search_store_agents_block(mocker):
|
||||
)
|
||||
|
||||
input_data = block.Input(
|
||||
query="test", category="productivity", sort_by="rating", limit=10
|
||||
query="test",
|
||||
category="productivity",
|
||||
sort_by=StoreAgentsSortOptions.RATING, # type: ignore[reportArgumentType]
|
||||
limit=10,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .service import stream_chat_completion_baseline
|
||||
|
||||
__all__ = ["stream_chat_completion_baseline"]
|
||||
424
autogpt_platform/backend/backend/copilot/baseline/service.py
Normal file
424
autogpt_platform/backend/backend/copilot/baseline/service.py
Normal file
@@ -0,0 +1,424 @@
|
||||
"""Baseline LLM fallback — OpenAI-compatible streaming with tool calling.
|
||||
|
||||
Used when ``CHAT_USE_CLAUDE_AGENT_SDK=false``, e.g. as a fallback when the
|
||||
Claude Agent SDK / Anthropic API is unavailable. Routes through any
|
||||
OpenAI-compatible provider (OpenRouter by default) and reuses the same
|
||||
shared tool registry as the SDK path.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from langfuse import propagate_attributes
|
||||
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import (
|
||||
_build_system_prompt,
|
||||
_generate_session_title,
|
||||
client,
|
||||
config,
|
||||
)
|
||||
from backend.copilot.tools import execute_tool, get_available_tools
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.prompt import compress_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Set to hold background tasks to prevent garbage collection
|
||||
_background_tasks: set[asyncio.Task[Any]] = set()
|
||||
|
||||
# Maximum number of tool-call rounds before forcing a text response.
|
||||
_MAX_TOOL_ROUNDS = 30
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
session_id: str, message: str, user_id: str | None
|
||||
) -> None:
|
||||
"""Generate and persist a session title in the background."""
|
||||
try:
|
||||
title = await _generate_session_title(message, user_id, session_id)
|
||||
if title and user_id:
|
||||
await update_session_title(session_id, user_id, title, only_if_empty=True)
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Failed to update session title: %s", e)
|
||||
|
||||
|
||||
async def _compress_session_messages(
|
||||
messages: list[ChatMessage],
|
||||
) -> list[ChatMessage]:
|
||||
"""Compress session messages if they exceed the model's token limit.
|
||||
|
||||
Uses the shared compress_context() utility which supports LLM-based
|
||||
summarization of older messages while keeping recent ones intact,
|
||||
with progressive truncation and middle-out deletion as fallbacks.
|
||||
"""
|
||||
messages_dict = []
|
||||
for msg in messages:
|
||||
msg_dict: dict[str, Any] = {"role": msg.role}
|
||||
if msg.content:
|
||||
msg_dict["content"] = msg.content
|
||||
messages_dict.append(msg_dict)
|
||||
|
||||
try:
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
model=config.model,
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
model=config.model,
|
||||
client=None,
|
||||
)
|
||||
|
||||
if result.was_compacted:
|
||||
logger.info(
|
||||
"[Baseline] Context compacted: %d -> %d tokens "
|
||||
"(%d summarized, %d dropped)",
|
||||
result.original_token_count,
|
||||
result.token_count,
|
||||
result.messages_summarized,
|
||||
result.messages_dropped,
|
||||
)
|
||||
return [
|
||||
ChatMessage(role=m["role"], content=m.get("content"))
|
||||
for m in result.messages
|
||||
]
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
async def stream_chat_completion_baseline(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
is_user_message: bool = True,
|
||||
user_id: str | None = None,
|
||||
session: ChatSession | None = None,
|
||||
**_kwargs: Any,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Baseline LLM with tool calling via OpenAI-compatible API.
|
||||
|
||||
Designed as a fallback when the Claude Agent SDK is unavailable.
|
||||
Uses the same tool registry as the SDK path but routes through any
|
||||
OpenAI-compatible provider (e.g. OpenRouter).
|
||||
|
||||
Flow: stream response -> if tool_calls, execute them -> feed results back -> repeat.
|
||||
"""
|
||||
if session is None:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(
|
||||
f"Session {session_id} not found. Please create a new session first."
|
||||
)
|
||||
|
||||
# Append user message
|
||||
new_role = "user" if is_user_message else "assistant"
|
||||
if message and (
|
||||
len(session.messages) == 0
|
||||
or not (
|
||||
session.messages[-1].role == new_role
|
||||
and session.messages[-1].content == message
|
||||
)
|
||||
):
|
||||
session.messages.append(ChatMessage(role=new_role, content=message))
|
||||
if is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(message),
|
||||
)
|
||||
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Generate title for new sessions
|
||||
if is_user_message and not session.title:
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
if len(user_messages) == 1:
|
||||
first_message = user_messages[0].content or message or ""
|
||||
if first_message:
|
||||
task = asyncio.create_task(
|
||||
_update_title_async(session_id, first_message, user_id)
|
||||
)
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
# Build system prompt only on the first turn to avoid mid-conversation
|
||||
# changes from concurrent chats updating business understanding.
|
||||
is_first_turn = len(session.messages) <= 1
|
||||
if is_first_turn:
|
||||
base_system_prompt, _ = await _build_system_prompt(
|
||||
user_id, has_conversation_history=False
|
||||
)
|
||||
else:
|
||||
base_system_prompt, _ = await _build_system_prompt(
|
||||
user_id=None, has_conversation_history=True
|
||||
)
|
||||
|
||||
# Append tool documentation and technical notes
|
||||
system_prompt = base_system_prompt + get_baseline_supplement()
|
||||
|
||||
# Compress context if approaching the model's token limit
|
||||
messages_for_context = await _compress_session_messages(session.messages)
|
||||
|
||||
# Build OpenAI message list from session history
|
||||
openai_messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": system_prompt}
|
||||
]
|
||||
for msg in messages_for_context:
|
||||
if msg.role in ("user", "assistant") and msg.content:
|
||||
openai_messages.append({"role": msg.role, "content": msg.content})
|
||||
|
||||
tools = get_available_tools()
|
||||
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
|
||||
# Propagate user/session context to Langfuse so all LLM calls within
|
||||
# this request are grouped under a single trace with proper attribution.
|
||||
_trace_ctx: Any = None
|
||||
try:
|
||||
_trace_ctx = propagate_attributes(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
trace_name="copilot-baseline",
|
||||
tags=["baseline"],
|
||||
)
|
||||
_trace_ctx.__enter__()
|
||||
except Exception:
|
||||
logger.warning("[Baseline] Langfuse trace context setup failed")
|
||||
|
||||
assistant_text = ""
|
||||
text_block_id = str(uuid.uuid4())
|
||||
text_started = False
|
||||
step_open = False
|
||||
try:
|
||||
for _round in range(_MAX_TOOL_ROUNDS):
|
||||
# Open a new step for each LLM round
|
||||
yield StreamStartStep()
|
||||
step_open = True
|
||||
|
||||
# Stream a response from the model
|
||||
create_kwargs: dict[str, Any] = dict(
|
||||
model=config.model,
|
||||
messages=openai_messages,
|
||||
stream=True,
|
||||
)
|
||||
if tools:
|
||||
create_kwargs["tools"] = tools
|
||||
response = await client.chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
|
||||
|
||||
# Accumulate streamed response (text + tool calls)
|
||||
round_text = ""
|
||||
tool_calls_by_index: dict[int, dict[str, str]] = {}
|
||||
|
||||
async for chunk in response:
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
if not delta:
|
||||
continue
|
||||
|
||||
# Text content
|
||||
if delta.content:
|
||||
if not text_started:
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
text_started = True
|
||||
round_text += delta.content
|
||||
yield StreamTextDelta(id=text_block_id, delta=delta.content)
|
||||
|
||||
# Tool call fragments (streamed incrementally)
|
||||
if delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
idx = tc.index
|
||||
if idx not in tool_calls_by_index:
|
||||
tool_calls_by_index[idx] = {
|
||||
"id": "",
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
}
|
||||
entry = tool_calls_by_index[idx]
|
||||
if tc.id:
|
||||
entry["id"] = tc.id
|
||||
if tc.function and tc.function.name:
|
||||
entry["name"] = tc.function.name
|
||||
if tc.function and tc.function.arguments:
|
||||
entry["arguments"] += tc.function.arguments
|
||||
|
||||
# Close text block if we had one this round
|
||||
if text_started:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
text_started = False
|
||||
text_block_id = str(uuid.uuid4())
|
||||
|
||||
# Accumulate text for session persistence
|
||||
assistant_text += round_text
|
||||
|
||||
# No tool calls -> model is done
|
||||
if not tool_calls_by_index:
|
||||
yield StreamFinishStep()
|
||||
step_open = False
|
||||
break
|
||||
|
||||
# Close step before tool execution
|
||||
yield StreamFinishStep()
|
||||
step_open = False
|
||||
|
||||
# Append the assistant message with tool_calls to context.
|
||||
assistant_msg: dict[str, Any] = {"role": "assistant"}
|
||||
if round_text:
|
||||
assistant_msg["content"] = round_text
|
||||
assistant_msg["tool_calls"] = [
|
||||
{
|
||||
"id": tc["id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc["name"],
|
||||
"arguments": tc["arguments"] or "{}",
|
||||
},
|
||||
}
|
||||
for tc in tool_calls_by_index.values()
|
||||
]
|
||||
openai_messages.append(assistant_msg)
|
||||
|
||||
# Execute each tool call and stream events
|
||||
for tc in tool_calls_by_index.values():
|
||||
tool_call_id = tc["id"]
|
||||
tool_name = tc["name"]
|
||||
raw_args = tc["arguments"] or "{}"
|
||||
try:
|
||||
tool_args = orjson.loads(raw_args)
|
||||
except orjson.JSONDecodeError as parse_err:
|
||||
parse_error = (
|
||||
f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
|
||||
)
|
||||
logger.warning("[Baseline] %s", parse_error)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=parse_error,
|
||||
success=False,
|
||||
)
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": parse_error,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
yield StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
input=tool_args,
|
||||
)
|
||||
|
||||
# Execute via shared tool registry
|
||||
try:
|
||||
result: StreamToolOutputAvailable = await execute_tool(
|
||||
tool_name=tool_name,
|
||||
parameters=tool_args,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
yield result
|
||||
tool_output = (
|
||||
result.output
|
||||
if isinstance(result.output, str)
|
||||
else str(result.output)
|
||||
)
|
||||
except Exception as e:
|
||||
error_output = f"Tool execution error: {e}"
|
||||
logger.error(
|
||||
"[Baseline] Tool %s failed: %s",
|
||||
tool_name,
|
||||
error_output,
|
||||
exc_info=True,
|
||||
)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=error_output,
|
||||
success=False,
|
||||
)
|
||||
tool_output = error_output
|
||||
|
||||
# Append tool result to context for next round
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": tool_output,
|
||||
}
|
||||
)
|
||||
else:
|
||||
# for-loop exhausted without break -> tool-round limit hit
|
||||
limit_msg = (
|
||||
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
|
||||
"without a final response."
|
||||
)
|
||||
logger.error("[Baseline] %s", limit_msg)
|
||||
yield StreamError(
|
||||
errorText=limit_msg,
|
||||
code="baseline_tool_round_limit",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e) or type(e).__name__
|
||||
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
|
||||
# Close any open text/step before emitting error
|
||||
if text_started:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
if step_open:
|
||||
yield StreamFinishStep()
|
||||
yield StreamError(errorText=error_msg, code="baseline_error")
|
||||
# Still persist whatever we got
|
||||
finally:
|
||||
# Close Langfuse trace context
|
||||
if _trace_ctx is not None:
|
||||
try:
|
||||
_trace_ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("[Baseline] Langfuse trace context teardown failed")
|
||||
|
||||
# Persist assistant response
|
||||
if assistant_text:
|
||||
session.messages.append(
|
||||
ChatMessage(role="assistant", content=assistant_text)
|
||||
)
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
except Exception as persist_err:
|
||||
logger.error("[Baseline] Failed to persist session: %s", persist_err)
|
||||
|
||||
yield StreamFinish()
|
||||
@@ -0,0 +1,99 @@
|
||||
import logging
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.baseline import stream_chat_completion_baseline
|
||||
from backend.copilot.model import (
|
||||
create_chat_session,
|
||||
get_chat_session,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.copilot.response_model import (
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamStart,
|
||||
StreamTextDelta,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_baseline_multi_turn(setup_test_user, test_user_id):
|
||||
"""Test that the baseline LLM path streams responses and maintains history.
|
||||
|
||||
Turn 1: Send a message with a unique keyword.
|
||||
Turn 2: Ask the model to recall the keyword — proving conversation history
|
||||
is correctly passed to the single-call LLM.
|
||||
"""
|
||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# --- Turn 1: send a message with a unique keyword ---
|
||||
keyword = "QUASAR99"
|
||||
turn1_msg = (
|
||||
f"Please remember this special keyword: {keyword}. "
|
||||
"Just confirm you've noted it, keep your response brief."
|
||||
)
|
||||
turn1_text = ""
|
||||
turn1_errors: list[str] = []
|
||||
got_start = False
|
||||
got_finish = False
|
||||
|
||||
async for chunk in stream_chat_completion_baseline(
|
||||
session.session_id,
|
||||
turn1_msg,
|
||||
user_id=test_user_id,
|
||||
):
|
||||
if isinstance(chunk, StreamStart):
|
||||
got_start = True
|
||||
elif isinstance(chunk, StreamTextDelta):
|
||||
turn1_text += chunk.delta
|
||||
elif isinstance(chunk, StreamError):
|
||||
turn1_errors.append(chunk.errorText)
|
||||
elif isinstance(chunk, StreamFinish):
|
||||
got_finish = True
|
||||
|
||||
assert got_start, "Turn 1 did not yield StreamStart"
|
||||
assert got_finish, "Turn 1 did not yield StreamFinish"
|
||||
assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
|
||||
assert turn1_text, "Turn 1 produced no text"
|
||||
logger.info(f"Turn 1 response: {turn1_text[:100]}")
|
||||
|
||||
# Reload session for turn 2
|
||||
session = await get_chat_session(session.session_id, test_user_id)
|
||||
assert session, "Session not found after turn 1"
|
||||
|
||||
# Verify messages were persisted (user + assistant)
|
||||
assert (
|
||||
len(session.messages) >= 2
|
||||
), f"Expected at least 2 messages after turn 1, got {len(session.messages)}"
|
||||
|
||||
# --- Turn 2: ask model to recall the keyword ---
|
||||
turn2_msg = "What was the special keyword I asked you to remember?"
|
||||
turn2_text = ""
|
||||
turn2_errors: list[str] = []
|
||||
|
||||
async for chunk in stream_chat_completion_baseline(
|
||||
session.session_id,
|
||||
turn2_msg,
|
||||
user_id=test_user_id,
|
||||
session=session,
|
||||
):
|
||||
if isinstance(chunk, StreamTextDelta):
|
||||
turn2_text += chunk.delta
|
||||
elif isinstance(chunk, StreamError):
|
||||
turn2_errors.append(chunk.errorText)
|
||||
|
||||
assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
|
||||
assert turn2_text, "Turn 2 produced no text"
|
||||
assert keyword in turn2_text, (
|
||||
f"Model did not recall keyword '{keyword}' in turn 2. "
|
||||
f"Response: {turn2_text[:200]}"
|
||||
)
|
||||
logger.info(f"Turn 2 recalled keyword successfully: {turn2_text[:100]}")
|
||||
@@ -26,11 +26,6 @@ class ChatConfig(BaseSettings):
|
||||
# Session TTL Configuration - 12 hours
|
||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||
|
||||
# Streaming Configuration
|
||||
max_retries: int = Field(
|
||||
default=3,
|
||||
description="Max retries for fallback path (SDK handles retries internally)",
|
||||
)
|
||||
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
||||
max_agent_schedules: int = Field(
|
||||
default=30, description="Maximum number of agent schedules"
|
||||
@@ -67,11 +62,15 @@ class ChatConfig(BaseSettings):
|
||||
default="CoPilot Prompt",
|
||||
description="Name of the prompt in Langfuse to fetch",
|
||||
)
|
||||
langfuse_prompt_cache_ttl: int = Field(
|
||||
default=300,
|
||||
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
|
||||
)
|
||||
|
||||
# Claude Agent SDK Configuration
|
||||
use_claude_agent_sdk: bool = Field(
|
||||
default=True,
|
||||
description="Use Claude Agent SDK for chat completions",
|
||||
description="Use Claude Agent SDK (True) or OpenAI-compatible LLM baseline (False)",
|
||||
)
|
||||
claude_agent_model: str | None = Field(
|
||||
default=None,
|
||||
@@ -92,18 +91,53 @@ class ChatConfig(BaseSettings):
|
||||
description="Use --resume for multi-turn conversations instead of "
|
||||
"history compression. Falls back to compression when unavailable.",
|
||||
)
|
||||
|
||||
# Extended thinking configuration for Claude models
|
||||
thinking_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable adaptive thinking for Claude models via OpenRouter",
|
||||
use_claude_code_subscription: bool = Field(
|
||||
default=False,
|
||||
description="For personal/dev use: use Claude Code CLI subscription auth instead of API keys. Requires `claude login` on the host. Only works with SDK mode.",
|
||||
)
|
||||
|
||||
# E2B Sandbox Configuration
|
||||
use_e2b_sandbox: bool = Field(
|
||||
default=True,
|
||||
description="Use E2B cloud sandboxes for persistent bash/python execution. "
|
||||
"When enabled, bash_exec routes commands to E2B and SDK file tools "
|
||||
"operate directly on the sandbox via E2B's filesystem API.",
|
||||
)
|
||||
e2b_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="E2B API key. Falls back to E2B_API_KEY environment variable.",
|
||||
)
|
||||
e2b_sandbox_template: str = Field(
|
||||
default="base",
|
||||
description="E2B sandbox template to use for copilot sessions.",
|
||||
)
|
||||
e2b_sandbox_timeout: int = Field(
|
||||
default=43200, # 12 hours — same as session_ttl
|
||||
description="E2B sandbox keepalive timeout in seconds.",
|
||||
)
|
||||
|
||||
@field_validator("use_e2b_sandbox", mode="before")
|
||||
@classmethod
|
||||
def get_use_e2b_sandbox(cls, v):
|
||||
"""Get use_e2b_sandbox from environment if not provided."""
|
||||
env_val = os.getenv("CHAT_USE_E2B_SANDBOX", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
return True if v is None else v
|
||||
|
||||
@field_validator("e2b_api_key", mode="before")
|
||||
@classmethod
|
||||
def get_e2b_api_key(cls, v):
|
||||
"""Get E2B API key from environment if not provided."""
|
||||
if not v:
|
||||
v = os.getenv("CHAT_E2B_API_KEY") or os.getenv("E2B_API_KEY")
|
||||
return v
|
||||
|
||||
@field_validator("api_key", mode="before")
|
||||
@classmethod
|
||||
def get_api_key(cls, v):
|
||||
"""Get API key from environment if not provided."""
|
||||
if v is None:
|
||||
if not v:
|
||||
# Try to get from environment variables
|
||||
# First check for CHAT_API_KEY (Pydantic prefix)
|
||||
v = os.getenv("CHAT_API_KEY")
|
||||
@@ -113,13 +147,16 @@ class ChatConfig(BaseSettings):
|
||||
if not v:
|
||||
# Fall back to OPENAI_API_KEY
|
||||
v = os.getenv("OPENAI_API_KEY")
|
||||
# Note: ANTHROPIC_API_KEY is intentionally NOT included here.
|
||||
# The SDK CLI picks it up from the env directly. Including it
|
||||
# would pair it with the OpenRouter base_url, causing auth failures.
|
||||
return v
|
||||
|
||||
@field_validator("base_url", mode="before")
|
||||
@classmethod
|
||||
def get_base_url(cls, v):
|
||||
"""Get base URL from environment if not provided."""
|
||||
if v is None:
|
||||
if not v:
|
||||
# Check for OpenRouter or custom base URL
|
||||
v = os.getenv("CHAT_BASE_URL")
|
||||
if not v:
|
||||
@@ -141,6 +178,15 @@ class ChatConfig(BaseSettings):
|
||||
# Default to True (SDK enabled by default)
|
||||
return True if v is None else v
|
||||
|
||||
@field_validator("use_claude_code_subscription", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_code_subscription(cls, v):
|
||||
"""Get use_claude_code_subscription from environment if not provided."""
|
||||
env_val = os.getenv("CHAT_USE_CLAUDE_CODE_SUBSCRIPTION", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
return False if v is None else v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
|
||||
11
autogpt_platform/backend/backend/copilot/constants.py
Normal file
11
autogpt_platform/backend/backend/copilot/constants.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""Shared constants for the CoPilot module."""
|
||||
|
||||
# Special message prefixes for text-based markers (parsed by frontend).
|
||||
# The hex suffix makes accidental LLM generation of these strings virtually
|
||||
# impossible, avoiding false-positive marker detection in normal conversation.
|
||||
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
|
||||
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
|
||||
|
||||
# Compaction notice messages shown to users.
|
||||
COMPACTION_DONE_MSG = "Earlier messages were summarized to fit within context limits."
|
||||
COMPACTION_TOOL_NAME = "context_compaction"
|
||||
@@ -16,7 +16,7 @@ from prisma.types import (
|
||||
)
|
||||
|
||||
from backend.data import db
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.json import SafeJson, sanitize_string
|
||||
|
||||
from .model import ChatMessage, ChatSession, ChatSessionInfo
|
||||
|
||||
@@ -81,6 +81,35 @@ async def update_chat_session(
|
||||
return ChatSession.from_db(session) if session else None
|
||||
|
||||
|
||||
async def update_chat_session_title(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
title: str,
|
||||
*,
|
||||
only_if_empty: bool = False,
|
||||
) -> bool:
|
||||
"""Update the title of a chat session, scoped to the owning user.
|
||||
|
||||
Always filters by (session_id, user_id) so callers cannot mutate another
|
||||
user's session even when they know the session_id.
|
||||
|
||||
Args:
|
||||
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
|
||||
guard so auto-generated titles never overwrite a user-set title.
|
||||
|
||||
Returns True if a row was updated, False otherwise (session not found,
|
||||
wrong user, or — when only_if_empty — title was already set).
|
||||
"""
|
||||
where: ChatSessionWhereInput = {"id": session_id, "userId": user_id}
|
||||
if only_if_empty:
|
||||
where["title"] = None
|
||||
result = await PrismaChatSession.prisma().update_many(
|
||||
where=where,
|
||||
data={"title": title, "updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
return result > 0
|
||||
|
||||
|
||||
async def add_chat_message(
|
||||
session_id: str,
|
||||
role: str,
|
||||
@@ -101,15 +130,16 @@ async def add_chat_message(
|
||||
"sequence": sequence,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
# Add optional string fields — sanitize to strip PostgreSQL-incompatible
|
||||
# control characters (null bytes etc.) that may appear in tool outputs.
|
||||
if content is not None:
|
||||
data["content"] = content
|
||||
data["content"] = sanitize_string(content)
|
||||
if name is not None:
|
||||
data["name"] = name
|
||||
if tool_call_id is not None:
|
||||
data["toolCallId"] = tool_call_id
|
||||
if refusal is not None:
|
||||
data["refusal"] = refusal
|
||||
data["refusal"] = sanitize_string(refusal)
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if tool_calls is not None:
|
||||
@@ -170,15 +200,16 @@ async def add_chat_messages_batch(
|
||||
"createdAt": now,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
# Add optional string fields — sanitize to strip
|
||||
# PostgreSQL-incompatible control characters.
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
data["content"] = sanitize_string(msg["content"])
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
data["refusal"] = sanitize_string(msg["refusal"])
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
@@ -312,7 +343,7 @@ async def update_tool_message_content(
|
||||
"toolCallId": tool_call_id,
|
||||
},
|
||||
data={
|
||||
"content": new_content,
|
||||
"content": sanitize_string(new_content),
|
||||
},
|
||||
)
|
||||
if result == 0:
|
||||
|
||||
@@ -6,11 +6,13 @@ in a thread-local context, following the graph executor pattern.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
|
||||
from backend.copilot import service as copilot_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.baseline import stream_chat_completion_baseline
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.response_model import StreamFinish
|
||||
from backend.copilot.sdk import service as sdk_service
|
||||
@@ -108,8 +110,41 @@ class CoPilotProcessor:
|
||||
)
|
||||
self.execution_thread.start()
|
||||
|
||||
# Skip the SDK's per-request CLI version check — the bundled CLI is
|
||||
# already version-matched to the SDK package.
|
||||
os.environ.setdefault("CLAUDE_AGENT_SDK_SKIP_VERSION_CHECK", "1")
|
||||
|
||||
# Pre-warm the bundled CLI binary so the OS page-caches the ~185 MB
|
||||
# executable. First spawn pays ~1.2 s; subsequent spawns ~0.65 s.
|
||||
self._prewarm_cli()
|
||||
|
||||
logger.info(f"[CoPilotExecutor] Worker {self.tid} started")
|
||||
|
||||
def _prewarm_cli(self) -> None:
|
||||
"""Run the bundled CLI binary once to warm OS page caches."""
|
||||
try:
|
||||
from claude_agent_sdk._internal.transport.subprocess_cli import (
|
||||
SubprocessCLITransport,
|
||||
)
|
||||
|
||||
cli_path = SubprocessCLITransport._find_bundled_cli(None) # type: ignore[arg-type]
|
||||
if cli_path:
|
||||
result = subprocess.run(
|
||||
[cli_path, "-v"],
|
||||
capture_output=True,
|
||||
timeout=10,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
logger.info(f"[CoPilotExecutor] CLI pre-warm done: {cli_path}")
|
||||
else:
|
||||
logger.warning(
|
||||
"[CoPilotExecutor] CLI pre-warm failed (rc=%d): %s",
|
||||
result.returncode, # type: ignore[reportCallIssue]
|
||||
cli_path,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"[CoPilotExecutor] CLI pre-warm skipped: {e}")
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up event-loop-bound resources before the loop is destroyed.
|
||||
|
||||
@@ -119,12 +154,12 @@ class CoPilotProcessor:
|
||||
"""
|
||||
from backend.util.workspace_storage import shutdown_workspace_storage
|
||||
|
||||
coro = shutdown_workspace_storage()
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
shutdown_workspace_storage(), self.execution_loop
|
||||
)
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self.execution_loop)
|
||||
future.result(timeout=5)
|
||||
except Exception as e:
|
||||
coro.close() # Prevent "coroutine was never awaited" warning
|
||||
error_msg = str(e) or type(e).__name__
|
||||
logger.warning(
|
||||
f"[CoPilotExecutor] Worker {self.tid} cleanup error: {error_msg}"
|
||||
@@ -194,7 +229,7 @@ class CoPilotProcessor:
|
||||
):
|
||||
"""Async execution logic for a CoPilot turn.
|
||||
|
||||
Calls the stream_chat_completion service function and publishes
|
||||
Calls the chat completion service (SDK or baseline) and publishes
|
||||
results to the stream registry.
|
||||
|
||||
Args:
|
||||
@@ -208,9 +243,10 @@ class CoPilotProcessor:
|
||||
error_msg = None
|
||||
|
||||
try:
|
||||
# Choose service based on LaunchDarkly flag
|
||||
# Choose service based on LaunchDarkly flag.
|
||||
# Claude Code subscription forces SDK mode (CLI subprocess auth).
|
||||
config = ChatConfig()
|
||||
use_sdk = await is_feature_enabled(
|
||||
use_sdk = config.use_claude_code_subscription or await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
@@ -218,9 +254,9 @@ class CoPilotProcessor:
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else copilot_service.stream_chat_completion
|
||||
else stream_chat_completion_baseline
|
||||
)
|
||||
log.info(f"Using {'SDK' if use_sdk else 'standard'} service")
|
||||
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
|
||||
|
||||
# Stream chat completion and publish chunks to Redis.
|
||||
async for chunk in stream_fn(
|
||||
@@ -229,6 +265,7 @@ class CoPilotProcessor:
|
||||
is_user_message=entry.is_user_message,
|
||||
user_id=entry.user_id,
|
||||
context=entry.context,
|
||||
file_ids=entry.file_ids,
|
||||
):
|
||||
if cancel.is_set():
|
||||
log.info("Cancel requested, breaking stream")
|
||||
|
||||
@@ -153,6 +153,9 @@ class CoPilotExecutionEntry(BaseModel):
|
||||
context: dict[str, str] | None = None
|
||||
"""Optional context for the message (e.g., {url: str, content: str})"""
|
||||
|
||||
file_ids: list[str] | None = None
|
||||
"""Workspace file IDs attached to the user's message"""
|
||||
|
||||
|
||||
class CancelCoPilotEvent(BaseModel):
|
||||
"""Event to cancel a CoPilot operation."""
|
||||
@@ -171,6 +174,7 @@ async def enqueue_copilot_turn(
|
||||
turn_id: str,
|
||||
is_user_message: bool = True,
|
||||
context: dict[str, str] | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Enqueue a CoPilot task for processing by the executor service.
|
||||
|
||||
@@ -181,6 +185,7 @@ async def enqueue_copilot_turn(
|
||||
turn_id: Per-turn UUID for Redis stream isolation
|
||||
is_user_message: Whether the message is from the user (vs system/assistant)
|
||||
context: Optional context for the message (e.g., {url: str, content: str})
|
||||
file_ids: Optional workspace file IDs attached to the user's message
|
||||
"""
|
||||
from backend.util.clients import get_async_copilot_queue
|
||||
|
||||
@@ -191,6 +196,7 @@ async def enqueue_copilot_turn(
|
||||
message=message,
|
||||
is_user_message=is_user_message,
|
||||
context=context,
|
||||
file_ids=file_ids,
|
||||
)
|
||||
|
||||
queue_client = await get_async_copilot_queue()
|
||||
|
||||
@@ -469,8 +469,16 @@ async def upsert_chat_session(
|
||||
)
|
||||
db_error = e
|
||||
|
||||
# Save to cache (best-effort, even if DB failed)
|
||||
# Save to cache (best-effort, even if DB failed).
|
||||
# Title updates (update_session_title) run *outside* this lock because
|
||||
# they only touch the title field, not messages. So a concurrent rename
|
||||
# or auto-title may have written a newer title to Redis while this
|
||||
# upsert was in progress. Always prefer the cached title to avoid
|
||||
# overwriting it with the stale in-memory copy.
|
||||
try:
|
||||
existing_cached = await _get_session_from_cache(session.session_id)
|
||||
if existing_cached and existing_cached.title:
|
||||
session = session.model_copy(update={"title": existing_cached.title})
|
||||
await cache_chat_session(session)
|
||||
except Exception as e:
|
||||
# If DB succeeded but cache failed, raise cache error
|
||||
@@ -672,27 +680,47 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
||||
async with _session_locks_mutex:
|
||||
_session_locks.pop(session_id, None)
|
||||
|
||||
# Shut down any local browser daemon for this session (best-effort).
|
||||
# Inline import required: all tool modules import ChatSession from this
|
||||
# module, so any top-level import from tools.* would create a cycle.
|
||||
try:
|
||||
from .tools.agent_browser import close_browser_session
|
||||
|
||||
await close_browser_session(session_id, user_id=user_id)
|
||||
except Exception as e:
|
||||
logger.debug(f"Browser cleanup for session {session_id}: {e}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def update_session_title(session_id: str, title: str) -> bool:
|
||||
"""Update only the title of a chat session.
|
||||
async def update_session_title(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
title: str,
|
||||
*,
|
||||
only_if_empty: bool = False,
|
||||
) -> bool:
|
||||
"""Update the title of a chat session, scoped to the owning user.
|
||||
|
||||
This is a lightweight operation that doesn't touch messages, avoiding
|
||||
race conditions with concurrent message updates. Use this for background
|
||||
title generation instead of upsert_chat_session.
|
||||
Lightweight operation that doesn't touch messages, avoiding race conditions
|
||||
with concurrent message updates.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to update.
|
||||
user_id: Owning user — the DB query filters on this.
|
||||
title: The new title to set.
|
||||
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
|
||||
so auto-generated titles never overwrite a user-set title.
|
||||
|
||||
Returns:
|
||||
True if updated successfully, False otherwise.
|
||||
True if updated successfully, False otherwise (not found, wrong user,
|
||||
or — when only_if_empty — title was already set).
|
||||
"""
|
||||
try:
|
||||
result = await chat_db().update_chat_session(session_id=session_id, title=title)
|
||||
if result is None:
|
||||
logger.warning(f"Session {session_id} not found for title update")
|
||||
updated = await chat_db().update_chat_session_title(
|
||||
session_id, user_id, title, only_if_empty=only_if_empty
|
||||
)
|
||||
if not updated:
|
||||
return False
|
||||
|
||||
# Update title in cache if it exists (instead of invalidating).
|
||||
@@ -704,9 +732,8 @@ async def update_session_title(session_id: str, title: str) -> bool:
|
||||
cached.title = title
|
||||
await cache_chat_session(cached)
|
||||
except Exception as e:
|
||||
# Not critical - title will be correct on next full cache refresh
|
||||
logger.warning(
|
||||
f"Failed to update title in cache for session {session_id}: {e}"
|
||||
f"Cache title update failed for session {session_id} (non-critical): {e}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@@ -1,269 +0,0 @@
|
||||
"""Tests for parallel tool call execution in CoPilot.
|
||||
|
||||
These tests mock _yield_tool_call to avoid importing the full copilot stack
|
||||
which requires Prisma, DB connections, etc.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_tool_calls_run_concurrently():
|
||||
"""Multiple tool calls should complete in ~max(delays), not sum(delays)."""
|
||||
from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
n_tools = 3
|
||||
delay_per_tool = 0.2
|
||||
tool_calls = [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": f"tool_{i}", "arguments": "{}"},
|
||||
}
|
||||
for i in range(n_tools)
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
original_yield = None
|
||||
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"],
|
||||
toolName=tc_list[idx]["function"]["name"],
|
||||
input={},
|
||||
)
|
||||
await asyncio.sleep(delay_per_tool)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tc_list[idx]["id"],
|
||||
toolName=tc_list[idx]["function"]["name"],
|
||||
output="{}",
|
||||
)
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
original_yield = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
start = time.monotonic()
|
||||
events = []
|
||||
async for event in _execute_tool_calls_parallel(
|
||||
tool_calls, cast(Any, FakeSession())
|
||||
):
|
||||
events.append(event)
|
||||
elapsed = time.monotonic() - start
|
||||
finally:
|
||||
svc._yield_tool_call = original_yield
|
||||
|
||||
assert len(events) == n_tools * 2
|
||||
# Parallel: should take ~delay, not ~n*delay
|
||||
assert elapsed < delay_per_tool * (
|
||||
n_tools - 0.5
|
||||
), f"Took {elapsed:.2f}s, expected parallel (~{delay_per_tool}s)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_tool_call_works():
|
||||
"""Single tool call should work identically."""
|
||||
from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": "call_0",
|
||||
"type": "function",
|
||||
"function": {"name": "t", "arguments": "{}"},
|
||||
}
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
yield StreamToolInputAvailable(toolCallId="call_0", toolName="t", input={})
|
||||
yield StreamToolOutputAvailable(toolCallId="call_0", toolName="t", output="{}")
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
orig = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
events = [
|
||||
e
|
||||
async for e in _execute_tool_calls_parallel(
|
||||
tool_calls, cast(Any, FakeSession())
|
||||
)
|
||||
]
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
|
||||
assert len(events) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retryable_error_propagates():
|
||||
"""Retryable errors should be raised after all tools finish."""
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": f"t_{i}", "arguments": "{}"},
|
||||
}
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
if idx == 1:
|
||||
raise KeyError("bad")
|
||||
from backend.copilot.response_model import StreamToolInputAvailable
|
||||
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName="t_0", input={}
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName="t_0", output="{}"
|
||||
)
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
orig = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
events = []
|
||||
with pytest.raises(KeyError):
|
||||
async for event in _execute_tool_calls_parallel(
|
||||
tool_calls, cast(Any, FakeSession())
|
||||
):
|
||||
events.append(event)
|
||||
# First tool's events should still be yielded
|
||||
assert any(isinstance(e, StreamToolOutputAvailable) for e in events)
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_shared_across_parallel_tools():
|
||||
"""All parallel tools should receive the same session instance."""
|
||||
from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": f"t_{i}", "arguments": "{}"},
|
||||
}
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
observed_sessions = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
observed_sessions.append(sess)
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
|
||||
)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", output="{}"
|
||||
)
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
orig = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
async for _ in _execute_tool_calls_parallel(
|
||||
tool_calls, cast(Any, FakeSession())
|
||||
):
|
||||
pass
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
|
||||
assert len(observed_sessions) == 3
|
||||
assert observed_sessions[0] is observed_sessions[1] is observed_sessions[2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancellation_cleans_up():
|
||||
"""Generator close should cancel in-flight tasks."""
|
||||
from backend.copilot.response_model import StreamToolInputAvailable
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": f"t_{i}", "arguments": "{}"},
|
||||
}
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
|
||||
)
|
||||
started.set()
|
||||
await asyncio.sleep(10) # simulate long-running
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
orig = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
gen = _execute_tool_calls_parallel(tool_calls, cast(Any, FakeSession()))
|
||||
await gen.__anext__() # get first event
|
||||
await started.wait()
|
||||
await gen.aclose() # close generator
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
# If we get here without hanging, cleanup worked
|
||||
191
autogpt_platform/backend/backend/copilot/prompting.py
Normal file
191
autogpt_platform/backend/backend/copilot/prompting.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Centralized prompt building logic for CoPilot.
|
||||
|
||||
This module contains all prompt construction functions and constants,
|
||||
handling the distinction between:
|
||||
- SDK mode vs Baseline mode (tool documentation needs)
|
||||
- Local mode vs E2B mode (storage/filesystem differences)
|
||||
"""
|
||||
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
# Shared technical notes that apply to both SDK and baseline modes
|
||||
_SHARED_TOOL_NOTES = """\
|
||||
|
||||
### Sharing files with the user
|
||||
After saving a file to the persistent workspace with `write_workspace_file`,
|
||||
share it with the user by embedding the `download_url` from the response in
|
||||
your message as a Markdown link or image:
|
||||
|
||||
- **Any file** — shows as a clickable download link:
|
||||
`[report.csv](workspace://file_id#text/csv)`
|
||||
- **Image** — renders inline in chat:
|
||||
``
|
||||
- **Video** — renders inline in chat with player controls:
|
||||
``
|
||||
|
||||
The `download_url` field in the `write_workspace_file` response is already
|
||||
in the correct format — paste it directly after the `(` in the Markdown.
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
"""
|
||||
|
||||
|
||||
# Environment-specific supplement templates
|
||||
def _build_storage_supplement(
|
||||
working_dir: str,
|
||||
sandbox_type: str,
|
||||
storage_system_1_name: str,
|
||||
storage_system_1_characteristics: list[str],
|
||||
storage_system_1_persistence: list[str],
|
||||
file_move_name_1_to_2: str,
|
||||
file_move_name_2_to_1: str,
|
||||
) -> str:
|
||||
"""Build storage/filesystem supplement for a specific environment.
|
||||
|
||||
Template function handles all formatting (bullets, indentation, markdown).
|
||||
Callers provide clean data as lists of strings.
|
||||
|
||||
Args:
|
||||
working_dir: Working directory path
|
||||
sandbox_type: Description of bash_exec sandbox
|
||||
storage_system_1_name: Name of primary storage (ephemeral or cloud)
|
||||
storage_system_1_characteristics: List of characteristic descriptions
|
||||
storage_system_1_persistence: List of persistence behavior descriptions
|
||||
file_move_name_1_to_2: Direction label for primary→persistent
|
||||
file_move_name_2_to_1: Direction label for persistent→primary
|
||||
"""
|
||||
# Format lists as bullet points with proper indentation
|
||||
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
|
||||
persistence = "\n".join(f" - {p}" for p in storage_system_1_persistence)
|
||||
|
||||
return f"""
|
||||
|
||||
## Tool notes
|
||||
|
||||
### Shell commands
|
||||
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
|
||||
for shell commands — it runs {sandbox_type}.
|
||||
|
||||
### Working directory
|
||||
- Your working directory is: `{working_dir}`
|
||||
- All SDK file tools AND `bash_exec` operate on the same filesystem
|
||||
- Use relative paths or absolute paths under `{working_dir}` for all file operations
|
||||
|
||||
### Two storage systems — CRITICAL to understand
|
||||
|
||||
1. **{storage_system_1_name}** (`{working_dir}`):
|
||||
{characteristics}
|
||||
{persistence}
|
||||
|
||||
2. **Persistent workspace** (cloud storage):
|
||||
- Files here **survive across sessions indefinitely**
|
||||
|
||||
### Moving files between storages
|
||||
- **{file_move_name_1_to_2}**: Copy to persistent workspace
|
||||
- **{file_move_name_2_to_1}**: Download for processing
|
||||
|
||||
### File persistence
|
||||
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
|
||||
{_SHARED_TOOL_NOTES}"""
|
||||
|
||||
|
||||
# Pre-built supplements for common environments
|
||||
def _get_local_storage_supplement(cwd: str) -> str:
|
||||
"""Local ephemeral storage (files lost between turns)."""
|
||||
return _build_storage_supplement(
|
||||
working_dir=cwd,
|
||||
sandbox_type="in a network-isolated sandbox",
|
||||
storage_system_1_name="Ephemeral working directory",
|
||||
storage_system_1_characteristics=[
|
||||
"Shared by SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec`",
|
||||
],
|
||||
storage_system_1_persistence=[
|
||||
"Files here are **lost between turns** — do NOT rely on them persisting",
|
||||
"Use for temporary work: running scripts, processing data, etc.",
|
||||
],
|
||||
file_move_name_1_to_2="Ephemeral → Persistent",
|
||||
file_move_name_2_to_1="Persistent → Ephemeral",
|
||||
)
|
||||
|
||||
|
||||
def _get_cloud_sandbox_supplement() -> str:
|
||||
"""Cloud persistent sandbox (files survive across turns in session)."""
|
||||
return _build_storage_supplement(
|
||||
working_dir="/home/user",
|
||||
sandbox_type="in a cloud sandbox with full internet access",
|
||||
storage_system_1_name="Cloud sandbox",
|
||||
storage_system_1_characteristics=[
|
||||
"Shared by all file tools AND `bash_exec` — same filesystem",
|
||||
"Full Linux environment with internet access",
|
||||
],
|
||||
storage_system_1_persistence=[
|
||||
"Files **persist across turns** within the current session",
|
||||
"Lost when the session expires (12 h inactivity)",
|
||||
],
|
||||
file_move_name_1_to_2="Sandbox → Persistent",
|
||||
file_move_name_2_to_1="Persistent → Sandbox",
|
||||
)
|
||||
|
||||
|
||||
def _generate_tool_documentation() -> str:
|
||||
"""Auto-generate tool documentation from TOOL_REGISTRY.
|
||||
|
||||
NOTE: This is ONLY used in baseline mode (direct OpenAI API).
|
||||
SDK mode doesn't need it since Claude gets tool schemas automatically.
|
||||
|
||||
This generates a complete list of available tools with their descriptions,
|
||||
ensuring the documentation stays in sync with the actual tool implementations.
|
||||
All workflow guidance is now embedded in individual tool descriptions.
|
||||
|
||||
Only documents tools that are available in the current environment
|
||||
(checked via tool.is_available property).
|
||||
"""
|
||||
docs = "\n## AVAILABLE TOOLS\n\n"
|
||||
|
||||
# Sort tools alphabetically for consistent output
|
||||
# Filter by is_available to match get_available_tools() behavior
|
||||
for name in sorted(TOOL_REGISTRY.keys()):
|
||||
tool = TOOL_REGISTRY[name]
|
||||
if not tool.is_available:
|
||||
continue
|
||||
schema = tool.as_openai_tool()
|
||||
desc = schema["function"].get("description", "No description available")
|
||||
# Format as bullet list with tool name in code style
|
||||
docs += f"- **`{name}`**: {desc}\n"
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
def get_sdk_supplement(use_e2b: bool, cwd: str = "") -> str:
|
||||
"""Get the supplement for SDK mode (Claude Agent SDK).
|
||||
|
||||
SDK mode does NOT include tool documentation because Claude automatically
|
||||
receives tool schemas from the SDK. Only includes technical notes about
|
||||
storage systems and execution environment.
|
||||
|
||||
Args:
|
||||
use_e2b: Whether E2B cloud sandbox is being used
|
||||
cwd: Current working directory (only used in local_storage mode)
|
||||
|
||||
Returns:
|
||||
The supplement string to append to the system prompt
|
||||
"""
|
||||
if use_e2b:
|
||||
return _get_cloud_sandbox_supplement()
|
||||
return _get_local_storage_supplement(cwd)
|
||||
|
||||
|
||||
def get_baseline_supplement() -> str:
|
||||
"""Get the supplement for baseline mode (direct OpenAI API).
|
||||
|
||||
Baseline mode INCLUDES auto-generated tool documentation because the
|
||||
direct API doesn't automatically provide tool schemas to Claude.
|
||||
Also includes shared technical notes (but NOT SDK-specific environment details).
|
||||
|
||||
Returns:
|
||||
The supplement string to append to the system prompt
|
||||
"""
|
||||
tool_docs = _generate_tool_documentation()
|
||||
return tool_docs + _SHARED_TOOL_NOTES
|
||||
@@ -13,6 +13,7 @@ from typing import Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.util.json import dumps as json_dumps
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -150,6 +151,9 @@ class StreamToolInputAvailable(StreamBaseResponse):
|
||||
)
|
||||
|
||||
|
||||
_MAX_TOOL_OUTPUT_SIZE = 100_000 # ~100 KB; truncate to avoid bloating SSE/DB
|
||||
|
||||
|
||||
class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
"""Tool execution result."""
|
||||
|
||||
@@ -164,6 +168,10 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
default=True, description="Whether the tool execution succeeded"
|
||||
)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Truncate oversized outputs after construction."""
|
||||
self.output = truncate(self.output, _MAX_TOOL_OUTPUT_SIZE)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format, excluding non-spec fields."""
|
||||
data = {
|
||||
|
||||
239
autogpt_platform/backend/backend/copilot/sdk/compaction.py
Normal file
239
autogpt_platform/backend/backend/copilot/sdk/compaction.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""Compaction tracking for SDK-based chat sessions.
|
||||
|
||||
Encapsulates the state machine and event emission for context compaction,
|
||||
both pre-query (history compressed before SDK query) and SDK-internal
|
||||
(PreCompact hook fires mid-stream).
|
||||
|
||||
All compaction-related helpers live here: event builders, message filtering,
|
||||
persistence, and the ``CompactionTracker`` state machine.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
|
||||
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
|
||||
from ..model import ChatMessage, ChatSession
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamFinishStep,
|
||||
StreamStartStep,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Event builders (private — use CompactionTracker or compaction_events)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _start_events(tool_call_id: str) -> list[StreamBaseResponse]:
|
||||
"""Build the opening events for a compaction tool call."""
|
||||
return [
|
||||
StreamStartStep(),
|
||||
StreamToolInputStart(toolCallId=tool_call_id, toolName=COMPACTION_TOOL_NAME),
|
||||
StreamToolInputAvailable(
|
||||
toolCallId=tool_call_id, toolName=COMPACTION_TOOL_NAME, input={}
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _end_events(tool_call_id: str, message: str) -> list[StreamBaseResponse]:
|
||||
"""Build the closing events for a compaction tool call."""
|
||||
return [
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=COMPACTION_TOOL_NAME,
|
||||
output=message,
|
||||
),
|
||||
StreamFinishStep(),
|
||||
]
|
||||
|
||||
|
||||
def _new_tool_call_id() -> str:
|
||||
return f"compaction-{uuid.uuid4().hex[:12]}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public event builder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def emit_compaction(session: ChatSession) -> list[StreamBaseResponse]:
|
||||
"""Create, persist, and return a self-contained compaction tool call.
|
||||
|
||||
Convenience for callers that don't use ``CompactionTracker`` (e.g. the
|
||||
legacy non-SDK streaming path in ``service.py``).
|
||||
"""
|
||||
tc_id = _new_tool_call_id()
|
||||
evts = compaction_events(COMPACTION_DONE_MSG, tool_call_id=tc_id)
|
||||
_persist(session, tc_id, COMPACTION_DONE_MSG)
|
||||
return evts
|
||||
|
||||
|
||||
def compaction_events(
|
||||
message: str, tool_call_id: str | None = None
|
||||
) -> list[StreamBaseResponse]:
|
||||
"""Emit a self-contained compaction tool call (already completed).
|
||||
|
||||
When *tool_call_id* is provided it is reused (e.g. for persistence that
|
||||
must match an already-streamed start event). Otherwise a new ID is
|
||||
generated.
|
||||
"""
|
||||
tc_id = tool_call_id or _new_tool_call_id()
|
||||
return _start_events(tc_id) + _end_events(tc_id, message)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message filtering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def filter_compaction_messages(
|
||||
messages: list[ChatMessage],
|
||||
) -> list[ChatMessage]:
|
||||
"""Remove synthetic compaction tool-call messages (UI-only artifacts).
|
||||
|
||||
Strips assistant messages whose only tool calls are compaction calls,
|
||||
and their corresponding tool-result messages.
|
||||
"""
|
||||
compaction_ids: set[str] = set()
|
||||
filtered: list[ChatMessage] = []
|
||||
for msg in messages:
|
||||
if msg.role == "assistant" and msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
if tc.get("function", {}).get("name") == COMPACTION_TOOL_NAME:
|
||||
compaction_ids.add(tc.get("id", ""))
|
||||
real_calls = [
|
||||
tc
|
||||
for tc in msg.tool_calls
|
||||
if tc.get("function", {}).get("name") != COMPACTION_TOOL_NAME
|
||||
]
|
||||
if not real_calls and not msg.content:
|
||||
continue
|
||||
if msg.role == "tool" and msg.tool_call_id in compaction_ids:
|
||||
continue
|
||||
filtered.append(msg)
|
||||
return filtered
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _persist(session: ChatSession, tool_call_id: str, message: str) -> None:
|
||||
"""Append compaction tool-call + result to session messages.
|
||||
|
||||
Compaction events are synthetic so they bypass the normal adapter
|
||||
accumulation. This explicitly records them so they survive a page refresh.
|
||||
"""
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": tool_call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": COMPACTION_TOOL_NAME,
|
||||
"arguments": "{}",
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
)
|
||||
session.messages.append(
|
||||
ChatMessage(role="tool", content=message, tool_call_id=tool_call_id)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CompactionTracker — state machine for streaming sessions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CompactionTracker:
|
||||
"""Tracks compaction state and yields UI events.
|
||||
|
||||
Two compaction paths:
|
||||
|
||||
1. **Pre-query** — history compressed before the SDK query starts.
|
||||
Call :meth:`emit_pre_query` to yield a self-contained tool call.
|
||||
|
||||
2. **SDK-internal** — ``PreCompact`` hook fires mid-stream.
|
||||
Call :meth:`emit_start_if_ready` on heartbeat ticks and
|
||||
:meth:`emit_end_if_ready` when a message arrives.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._compact_start = asyncio.Event()
|
||||
self._start_emitted = False
|
||||
self._done = False
|
||||
self._tool_call_id = ""
|
||||
|
||||
@property
|
||||
def on_compact(self) -> Callable[[], None]:
|
||||
"""Callback for the PreCompact hook."""
|
||||
return self._compact_start.set
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pre-query compaction
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def emit_pre_query(self, session: ChatSession) -> list[StreamBaseResponse]:
|
||||
"""Emit + persist a self-contained compaction tool call."""
|
||||
self._done = True
|
||||
return emit_compaction(session)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# SDK-internal compaction
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def reset_for_query(self) -> None:
|
||||
"""Reset per-query state before a new SDK query."""
|
||||
self._done = False
|
||||
self._start_emitted = False
|
||||
self._tool_call_id = ""
|
||||
|
||||
def emit_start_if_ready(self) -> list[StreamBaseResponse]:
|
||||
"""If the PreCompact hook fired, emit start events (spinning tool)."""
|
||||
if self._compact_start.is_set() and not self._start_emitted and not self._done:
|
||||
self._compact_start.clear()
|
||||
self._start_emitted = True
|
||||
self._tool_call_id = _new_tool_call_id()
|
||||
return _start_events(self._tool_call_id)
|
||||
return []
|
||||
|
||||
async def emit_end_if_ready(self, session: ChatSession) -> list[StreamBaseResponse]:
|
||||
"""If compaction is in progress, emit end events and persist."""
|
||||
# Yield so pending hook tasks can set compact_start
|
||||
await asyncio.sleep(0)
|
||||
|
||||
if self._done:
|
||||
return []
|
||||
if not self._start_emitted and not self._compact_start.is_set():
|
||||
return []
|
||||
|
||||
if self._start_emitted:
|
||||
# Close the open spinner
|
||||
done_events = _end_events(self._tool_call_id, COMPACTION_DONE_MSG)
|
||||
persist_id = self._tool_call_id
|
||||
else:
|
||||
# PreCompact fired but start never emitted — self-contained
|
||||
persist_id = _new_tool_call_id()
|
||||
done_events = compaction_events(
|
||||
COMPACTION_DONE_MSG, tool_call_id=persist_id
|
||||
)
|
||||
|
||||
self._compact_start.clear()
|
||||
self._start_emitted = False
|
||||
self._done = True
|
||||
_persist(session, persist_id, COMPACTION_DONE_MSG)
|
||||
return done_events
|
||||
291
autogpt_platform/backend/backend/copilot/sdk/compaction_test.py
Normal file
291
autogpt_platform/backend/backend/copilot/sdk/compaction_test.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""Tests for sdk/compaction.py — event builders, filtering, persistence, and
|
||||
CompactionTracker state machine."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.response_model import (
|
||||
StreamFinishStep,
|
||||
StreamStartStep,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.sdk.compaction import (
|
||||
CompactionTracker,
|
||||
compaction_events,
|
||||
emit_compaction,
|
||||
filter_compaction_messages,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_session() -> ChatSession:
|
||||
return ChatSession.new(user_id="test-user")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compaction_events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactionEvents:
|
||||
def test_returns_start_and_end_events(self):
|
||||
evts = compaction_events("done")
|
||||
assert len(evts) == 5
|
||||
assert isinstance(evts[0], StreamStartStep)
|
||||
assert isinstance(evts[1], StreamToolInputStart)
|
||||
assert isinstance(evts[2], StreamToolInputAvailable)
|
||||
assert isinstance(evts[3], StreamToolOutputAvailable)
|
||||
assert isinstance(evts[4], StreamFinishStep)
|
||||
|
||||
def test_uses_provided_tool_call_id(self):
|
||||
evts = compaction_events("msg", tool_call_id="my-id")
|
||||
tool_start = evts[1]
|
||||
assert isinstance(tool_start, StreamToolInputStart)
|
||||
assert tool_start.toolCallId == "my-id"
|
||||
|
||||
def test_generates_id_when_not_provided(self):
|
||||
evts = compaction_events("msg")
|
||||
tool_start = evts[1]
|
||||
assert isinstance(tool_start, StreamToolInputStart)
|
||||
assert tool_start.toolCallId.startswith("compaction-")
|
||||
|
||||
def test_tool_name_is_context_compaction(self):
|
||||
evts = compaction_events("msg")
|
||||
tool_start = evts[1]
|
||||
assert isinstance(tool_start, StreamToolInputStart)
|
||||
assert tool_start.toolName == COMPACTION_TOOL_NAME
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# emit_compaction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitCompaction:
|
||||
def test_persists_to_session(self):
|
||||
session = _make_session()
|
||||
assert len(session.messages) == 0
|
||||
evts = emit_compaction(session)
|
||||
assert len(evts) == 5
|
||||
# Should have appended 2 messages (assistant tool call + tool result)
|
||||
assert len(session.messages) == 2
|
||||
assert session.messages[0].role == "assistant"
|
||||
assert session.messages[0].tool_calls is not None
|
||||
assert (
|
||||
session.messages[0].tool_calls[0]["function"]["name"]
|
||||
== COMPACTION_TOOL_NAME
|
||||
)
|
||||
assert session.messages[1].role == "tool"
|
||||
assert session.messages[1].content == COMPACTION_DONE_MSG
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# filter_compaction_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFilterCompactionMessages:
|
||||
def test_removes_compaction_tool_calls(self):
|
||||
msgs = [
|
||||
ChatMessage(role="user", content="hello"),
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "comp-1",
|
||||
"type": "function",
|
||||
"function": {"name": COMPACTION_TOOL_NAME, "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
),
|
||||
ChatMessage(
|
||||
role="tool", content=COMPACTION_DONE_MSG, tool_call_id="comp-1"
|
||||
),
|
||||
ChatMessage(role="assistant", content="world"),
|
||||
]
|
||||
filtered = filter_compaction_messages(msgs)
|
||||
assert len(filtered) == 2
|
||||
assert filtered[0].content == "hello"
|
||||
assert filtered[1].content == "world"
|
||||
|
||||
def test_keeps_non_compaction_tool_calls(self):
|
||||
msgs = [
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "real-1",
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
),
|
||||
ChatMessage(role="tool", content="result", tool_call_id="real-1"),
|
||||
]
|
||||
filtered = filter_compaction_messages(msgs)
|
||||
assert len(filtered) == 2
|
||||
|
||||
def test_keeps_assistant_with_content_and_compaction_call(self):
|
||||
"""If assistant message has both content and a compaction tool call,
|
||||
the message is kept (has real content)."""
|
||||
msgs = [
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="I have content",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "comp-1",
|
||||
"type": "function",
|
||||
"function": {"name": COMPACTION_TOOL_NAME, "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
filtered = filter_compaction_messages(msgs)
|
||||
assert len(filtered) == 1
|
||||
|
||||
def test_empty_list(self):
|
||||
assert filter_compaction_messages([]) == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CompactionTracker
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactionTracker:
|
||||
def test_on_compact_sets_event(self):
|
||||
tracker = CompactionTracker()
|
||||
tracker.on_compact()
|
||||
assert tracker._compact_start.is_set()
|
||||
|
||||
def test_emit_start_if_ready_no_event(self):
|
||||
tracker = CompactionTracker()
|
||||
assert tracker.emit_start_if_ready() == []
|
||||
|
||||
def test_emit_start_if_ready_with_event(self):
|
||||
tracker = CompactionTracker()
|
||||
tracker.on_compact()
|
||||
evts = tracker.emit_start_if_ready()
|
||||
assert len(evts) == 3
|
||||
assert isinstance(evts[0], StreamStartStep)
|
||||
assert isinstance(evts[1], StreamToolInputStart)
|
||||
assert isinstance(evts[2], StreamToolInputAvailable)
|
||||
|
||||
def test_emit_start_only_once(self):
|
||||
tracker = CompactionTracker()
|
||||
tracker.on_compact()
|
||||
evts1 = tracker.emit_start_if_ready()
|
||||
assert len(evts1) == 3
|
||||
# Second call should return empty
|
||||
evts2 = tracker.emit_start_if_ready()
|
||||
assert evts2 == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_after_start(self):
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.on_compact()
|
||||
tracker.emit_start_if_ready()
|
||||
evts = await tracker.emit_end_if_ready(session)
|
||||
assert len(evts) == 2
|
||||
assert isinstance(evts[0], StreamToolOutputAvailable)
|
||||
assert isinstance(evts[1], StreamFinishStep)
|
||||
# Should persist
|
||||
assert len(session.messages) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_without_start_self_contained(self):
|
||||
"""If PreCompact fired but start was never emitted, emit_end
|
||||
produces a self-contained compaction event."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.on_compact()
|
||||
# Don't call emit_start_if_ready
|
||||
evts = await tracker.emit_end_if_ready(session)
|
||||
assert len(evts) == 5 # Full self-contained event
|
||||
assert isinstance(evts[0], StreamStartStep)
|
||||
assert len(session.messages) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_no_op_when_done(self):
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.on_compact()
|
||||
tracker.emit_start_if_ready()
|
||||
await tracker.emit_end_if_ready(session)
|
||||
# Second call should be no-op
|
||||
evts = await tracker.emit_end_if_ready(session)
|
||||
assert evts == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_no_op_when_nothing_happened(self):
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
evts = await tracker.emit_end_if_ready(session)
|
||||
assert evts == []
|
||||
|
||||
def test_emit_pre_query(self):
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
evts = tracker.emit_pre_query(session)
|
||||
assert len(evts) == 5
|
||||
assert len(session.messages) == 2
|
||||
assert tracker._done is True
|
||||
|
||||
def test_reset_for_query(self):
|
||||
tracker = CompactionTracker()
|
||||
tracker._done = True
|
||||
tracker._start_emitted = True
|
||||
tracker._tool_call_id = "old"
|
||||
tracker.reset_for_query()
|
||||
assert tracker._done is False
|
||||
assert tracker._start_emitted is False
|
||||
assert tracker._tool_call_id == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pre_query_blocks_sdk_compaction(self):
|
||||
"""After pre-query compaction, SDK compaction events are suppressed."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.emit_pre_query(session)
|
||||
tracker.on_compact()
|
||||
evts = tracker.emit_start_if_ready()
|
||||
assert evts == [] # _done blocks it
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_allows_new_compaction(self):
|
||||
"""After reset_for_query, compaction can fire again."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.emit_pre_query(session)
|
||||
tracker.reset_for_query()
|
||||
tracker.on_compact()
|
||||
evts = tracker.emit_start_if_ready()
|
||||
assert len(evts) == 3 # Start events emitted
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_id_consistency(self):
|
||||
"""Start and end events use the same tool_call_id."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.on_compact()
|
||||
start_evts = tracker.emit_start_if_ready()
|
||||
end_evts = await tracker.emit_end_if_ready(session)
|
||||
start_evt = start_evts[1]
|
||||
end_evt = end_evts[0]
|
||||
assert isinstance(start_evt, StreamToolInputStart)
|
||||
assert isinstance(end_evt, StreamToolOutputAvailable)
|
||||
assert start_evt.toolCallId == end_evt.toolCallId
|
||||
# Persisted ID should also match
|
||||
tool_calls = session.messages[0].tool_calls
|
||||
assert tool_calls is not None
|
||||
assert tool_calls[0]["id"] == start_evt.toolCallId
|
||||
@@ -10,6 +10,7 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from ..model import ChatSession
|
||||
from ..response_model import StreamBaseResponse, StreamStart, StreamTextDelta
|
||||
@@ -26,6 +27,7 @@ async def stream_chat_completion_dummy(
|
||||
retry_count: int = 0,
|
||||
session: ChatSession | None = None,
|
||||
context: dict[str, str] | None = None,
|
||||
**_kwargs: Any,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Stream dummy chat completion for testing.
|
||||
|
||||
|
||||
362
autogpt_platform/backend/backend/copilot/sdk/e2b_file_tools.py
Normal file
362
autogpt_platform/backend/backend/copilot/sdk/e2b_file_tools.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""MCP file-tool handlers that route to the E2B cloud sandbox.
|
||||
|
||||
When E2B is active, these tools replace the SDK built-in Read/Write/Edit/
|
||||
Glob/Grep so that all file operations share the same ``/home/user``
|
||||
filesystem as ``bash_exec``.
|
||||
|
||||
SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
|
||||
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
from typing import Any, Callable
|
||||
|
||||
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Lazy imports to break circular dependency with tool_adapter.
|
||||
|
||||
|
||||
def _get_sandbox(): # type: ignore[return]
|
||||
from .tool_adapter import get_current_sandbox # noqa: E402
|
||||
|
||||
return get_current_sandbox()
|
||||
|
||||
|
||||
def _is_allowed_local(path: str) -> bool:
|
||||
from .tool_adapter import is_allowed_local_path # noqa: E402
|
||||
|
||||
return is_allowed_local_path(path)
|
||||
|
||||
|
||||
def _resolve_remote(path: str) -> str:
|
||||
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
|
||||
|
||||
Raises :class:`ValueError` if the resolved path escapes the sandbox.
|
||||
"""
|
||||
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
|
||||
normalized = os.path.normpath(candidate)
|
||||
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
|
||||
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
|
||||
return normalized
|
||||
|
||||
|
||||
def _mcp(text: str, *, error: bool = False) -> dict[str, Any]:
|
||||
if error:
|
||||
text = json.dumps({"error": text, "type": "error"})
|
||||
return {"content": [{"type": "text", "text": text}], "isError": error}
|
||||
|
||||
|
||||
def _get_sandbox_and_path(
|
||||
file_path: str,
|
||||
) -> tuple[Any, str] | dict[str, Any]:
|
||||
"""Common preamble: get sandbox + resolve path, or return MCP error."""
|
||||
sandbox = _get_sandbox()
|
||||
if sandbox is None:
|
||||
return _mcp("No E2B sandbox available", error=True)
|
||||
try:
|
||||
remote = _resolve_remote(file_path)
|
||||
except ValueError as exc:
|
||||
return _mcp(str(exc), error=True)
|
||||
return sandbox, remote
|
||||
|
||||
|
||||
# Tool handlers
|
||||
|
||||
|
||||
async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
file_path: str = args.get("file_path", "")
|
||||
offset: int = max(0, int(args.get("offset", 0)))
|
||||
limit: int = max(1, int(args.get("limit", 2000)))
|
||||
|
||||
if not file_path:
|
||||
return _mcp("file_path is required", error=True)
|
||||
|
||||
# SDK-internal paths (tool-results, ephemeral working dir) stay on the host.
|
||||
if _is_allowed_local(file_path):
|
||||
return _read_local(file_path, offset, limit)
|
||||
|
||||
result = _get_sandbox_and_path(file_path)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
sandbox, remote = result
|
||||
|
||||
try:
|
||||
raw: bytes = await sandbox.files.read(remote, format="bytes")
|
||||
content = raw.decode("utf-8", errors="replace")
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to read {remote}: {exc}", error=True)
|
||||
|
||||
lines = content.splitlines(keepends=True)
|
||||
selected = list(itertools.islice(lines, offset, offset + limit))
|
||||
numbered = "".join(
|
||||
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
|
||||
)
|
||||
return _mcp(numbered)
|
||||
|
||||
|
||||
async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
file_path: str = args.get("file_path", "")
|
||||
content: str = args.get("content", "")
|
||||
|
||||
if not file_path:
|
||||
return _mcp("file_path is required", error=True)
|
||||
|
||||
result = _get_sandbox_and_path(file_path)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
sandbox, remote = result
|
||||
|
||||
try:
|
||||
parent = os.path.dirname(remote)
|
||||
if parent and parent != E2B_WORKDIR:
|
||||
await sandbox.files.make_dir(parent)
|
||||
await sandbox.files.write(remote, content)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to write {remote}: {exc}", error=True)
|
||||
|
||||
return _mcp(f"Successfully wrote to {remote}")
|
||||
|
||||
|
||||
async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
file_path: str = args.get("file_path", "")
|
||||
old_string: str = args.get("old_string", "")
|
||||
new_string: str = args.get("new_string", "")
|
||||
replace_all: bool = args.get("replace_all", False)
|
||||
|
||||
if not file_path:
|
||||
return _mcp("file_path is required", error=True)
|
||||
if not old_string:
|
||||
return _mcp("old_string is required", error=True)
|
||||
|
||||
result = _get_sandbox_and_path(file_path)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
sandbox, remote = result
|
||||
|
||||
try:
|
||||
raw: bytes = await sandbox.files.read(remote, format="bytes")
|
||||
content = raw.decode("utf-8", errors="replace")
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to read {remote}: {exc}", error=True)
|
||||
|
||||
count = content.count(old_string)
|
||||
if count == 0:
|
||||
return _mcp(f"old_string not found in {file_path}", error=True)
|
||||
if count > 1 and not replace_all:
|
||||
return _mcp(
|
||||
f"old_string appears {count} times in {file_path}. "
|
||||
"Use replace_all=true or provide a more unique string.",
|
||||
error=True,
|
||||
)
|
||||
|
||||
updated = (
|
||||
content.replace(old_string, new_string)
|
||||
if replace_all
|
||||
else content.replace(old_string, new_string, 1)
|
||||
)
|
||||
try:
|
||||
await sandbox.files.write(remote, updated)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to write {remote}: {exc}", error=True)
|
||||
|
||||
return _mcp(f"Edited {remote} ({count} replacement{'s' if count > 1 else ''})")
|
||||
|
||||
|
||||
async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
|
||||
pattern: str = args.get("pattern", "")
|
||||
path: str = args.get("path", "")
|
||||
|
||||
if not pattern:
|
||||
return _mcp("pattern is required", error=True)
|
||||
|
||||
sandbox = _get_sandbox()
|
||||
if sandbox is None:
|
||||
return _mcp("No E2B sandbox available", error=True)
|
||||
|
||||
try:
|
||||
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
|
||||
except ValueError as exc:
|
||||
return _mcp(str(exc), error=True)
|
||||
|
||||
cmd = f"find {shlex.quote(search_dir)} -name {shlex.quote(pattern)} -type f 2>/dev/null | head -500"
|
||||
try:
|
||||
result = await sandbox.commands.run(cmd, cwd=E2B_WORKDIR, timeout=10)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Glob failed: {exc}", error=True)
|
||||
|
||||
files = [line for line in (result.stdout or "").strip().splitlines() if line]
|
||||
return _mcp(json.dumps(files, indent=2))
|
||||
|
||||
|
||||
async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
|
||||
pattern: str = args.get("pattern", "")
|
||||
path: str = args.get("path", "")
|
||||
include: str = args.get("include", "")
|
||||
|
||||
if not pattern:
|
||||
return _mcp("pattern is required", error=True)
|
||||
|
||||
sandbox = _get_sandbox()
|
||||
if sandbox is None:
|
||||
return _mcp("No E2B sandbox available", error=True)
|
||||
|
||||
try:
|
||||
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
|
||||
except ValueError as exc:
|
||||
return _mcp(str(exc), error=True)
|
||||
|
||||
parts = ["grep", "-rn", "--color=never"]
|
||||
if include:
|
||||
parts.extend(["--include", include])
|
||||
parts.extend([pattern, search_dir])
|
||||
cmd = " ".join(shlex.quote(p) for p in parts) + " 2>/dev/null | head -200"
|
||||
|
||||
try:
|
||||
result = await sandbox.commands.run(cmd, cwd=E2B_WORKDIR, timeout=15)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Grep failed: {exc}", error=True)
|
||||
|
||||
output = (result.stdout or "").strip()
|
||||
return _mcp(output if output else "No matches found.")
|
||||
|
||||
|
||||
# Local read (for SDK-internal paths)
|
||||
|
||||
|
||||
def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
|
||||
"""Read from the host filesystem (defence-in-depth path check)."""
|
||||
if not _is_allowed_local(file_path):
|
||||
return _mcp(f"Path not allowed: {file_path}", error=True)
|
||||
expanded = os.path.realpath(os.path.expanduser(file_path))
|
||||
try:
|
||||
with open(expanded) as fh:
|
||||
selected = list(itertools.islice(fh, offset, offset + limit))
|
||||
numbered = "".join(
|
||||
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
|
||||
)
|
||||
return _mcp(numbered)
|
||||
except FileNotFoundError:
|
||||
return _mcp(f"File not found: {file_path}", error=True)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Error reading {file_path}: {exc}", error=True)
|
||||
|
||||
|
||||
# Tool descriptors (name, description, schema, handler)
|
||||
|
||||
E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
|
||||
(
|
||||
"read_file",
|
||||
"Read a file from the cloud sandbox (/home/user). "
|
||||
"Use offset and limit for large files.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path (relative to /home/user, or absolute).",
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line to start reading from (0-indexed). Default: 0.",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Number of lines to read. Default: 2000.",
|
||||
},
|
||||
},
|
||||
"required": ["file_path"],
|
||||
},
|
||||
_handle_read_file,
|
||||
),
|
||||
(
|
||||
"write_file",
|
||||
"Write or create a file in the cloud sandbox (/home/user). "
|
||||
"Parent directories are created automatically. "
|
||||
"To copy a workspace file into the sandbox, use "
|
||||
"read_workspace_file with save_to_path instead.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path (relative to /home/user, or absolute).",
|
||||
},
|
||||
"content": {"type": "string", "description": "Content to write."},
|
||||
},
|
||||
"required": ["file_path", "content"],
|
||||
},
|
||||
_handle_write_file,
|
||||
),
|
||||
(
|
||||
"edit_file",
|
||||
"Targeted text replacement in a sandbox file. "
|
||||
"old_string must appear in the file and is replaced with new_string.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path (relative to /home/user, or absolute).",
|
||||
},
|
||||
"old_string": {"type": "string", "description": "Text to find."},
|
||||
"new_string": {"type": "string", "description": "Replacement text."},
|
||||
"replace_all": {
|
||||
"type": "boolean",
|
||||
"description": "Replace all occurrences (default: false).",
|
||||
},
|
||||
},
|
||||
"required": ["file_path", "old_string", "new_string"],
|
||||
},
|
||||
_handle_edit_file,
|
||||
),
|
||||
(
|
||||
"glob",
|
||||
"Search for files by name pattern in the cloud sandbox.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Glob pattern (e.g. *.py).",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Directory to search. Default: /home/user.",
|
||||
},
|
||||
},
|
||||
"required": ["pattern"],
|
||||
},
|
||||
_handle_glob,
|
||||
),
|
||||
(
|
||||
"grep",
|
||||
"Search file contents by regex in the cloud sandbox.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {"type": "string", "description": "Regex pattern."},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File or directory. Default: /home/user.",
|
||||
},
|
||||
"include": {
|
||||
"type": "string",
|
||||
"description": "Glob to filter files (e.g. *.py).",
|
||||
},
|
||||
},
|
||||
"required": ["pattern"],
|
||||
},
|
||||
_handle_grep,
|
||||
),
|
||||
]
|
||||
|
||||
E2B_FILE_TOOL_NAMES: list[str] = [name for name, *_ in E2B_FILE_TOOLS]
|
||||
@@ -0,0 +1,153 @@
|
||||
"""Tests for E2B file-tool path validation and local read safety.
|
||||
|
||||
Pure unit tests with no external dependencies (no E2B, no sandbox).
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from .e2b_file_tools import _read_local, _resolve_remote
|
||||
from .tool_adapter import _current_project_dir
|
||||
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _resolve_remote — sandbox path normalisation & boundary enforcement
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveRemote:
|
||||
def test_relative_path_resolved(self):
|
||||
assert _resolve_remote("src/main.py") == "/home/user/src/main.py"
|
||||
|
||||
def test_absolute_within_sandbox(self):
|
||||
assert _resolve_remote("/home/user/file.txt") == "/home/user/file.txt"
|
||||
|
||||
def test_workdir_itself(self):
|
||||
assert _resolve_remote("/home/user") == "/home/user"
|
||||
|
||||
def test_relative_dotslash(self):
|
||||
assert _resolve_remote("./README.md") == "/home/user/README.md"
|
||||
|
||||
def test_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
_resolve_remote("../../etc/passwd")
|
||||
|
||||
def test_absolute_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
_resolve_remote("/home/user/../../etc/passwd")
|
||||
|
||||
def test_absolute_outside_sandbox_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
_resolve_remote("/etc/passwd")
|
||||
|
||||
def test_root_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
_resolve_remote("/")
|
||||
|
||||
def test_home_other_user_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
_resolve_remote("/home/other/file.txt")
|
||||
|
||||
def test_deep_nested_allowed(self):
|
||||
assert _resolve_remote("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
|
||||
|
||||
def test_trailing_slash_normalised(self):
|
||||
assert _resolve_remote("src/") == "/home/user/src"
|
||||
|
||||
def test_double_dots_within_sandbox_ok(self):
|
||||
"""Path that resolves back within /home/user is allowed."""
|
||||
assert _resolve_remote("a/b/../c.txt") == "/home/user/a/c.txt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_local — host filesystem reads with allowlist enforcement
|
||||
#
|
||||
# In E2B mode, _read_local only allows tool-results paths (via
|
||||
# is_allowed_local_path without sdk_cwd). Regular files live on the
|
||||
# sandbox, not the host.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReadLocal:
|
||||
def _make_tool_results_file(self, encoded: str, filename: str, content: str) -> str:
|
||||
"""Create a tool-results file and return its path."""
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
os.makedirs(tool_results_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_results_dir, filename)
|
||||
with open(filepath, "w") as f:
|
||||
f.write(content)
|
||||
return filepath
|
||||
|
||||
def test_read_tool_results_file(self):
|
||||
"""Reading a tool-results file should succeed."""
|
||||
encoded = "-tmp-copilot-e2b-test-read"
|
||||
filepath = self._make_tool_results_file(
|
||||
encoded, "result.txt", "line 1\nline 2\nline 3\n"
|
||||
)
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local(filepath, offset=0, limit=2000)
|
||||
assert result["isError"] is False
|
||||
assert "line 1" in result["content"][0]["text"]
|
||||
assert "line 2" in result["content"][0]["text"]
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
os.unlink(filepath)
|
||||
|
||||
def test_read_disallowed_path_blocked(self):
|
||||
"""Reading /etc/passwd should be blocked by the allowlist."""
|
||||
result = _read_local("/etc/passwd", offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
assert "not allowed" in result["content"][0]["text"].lower()
|
||||
|
||||
def test_read_nonexistent_tool_results(self):
|
||||
"""A tool-results path that doesn't exist returns FileNotFoundError."""
|
||||
encoded = "-tmp-copilot-e2b-test-nofile"
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
os.makedirs(tool_results_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_results_dir, "nonexistent.txt")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local(filepath, offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
assert "not found" in result["content"][0]["text"].lower()
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
os.rmdir(tool_results_dir)
|
||||
|
||||
def test_read_traversal_path_blocked(self):
|
||||
"""A traversal attempt that escapes allowed directories is blocked."""
|
||||
result = _read_local("/tmp/copilot-abc/../../etc/shadow", offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
assert "not allowed" in result["content"][0]["text"].lower()
|
||||
|
||||
def test_read_arbitrary_host_path_blocked(self):
|
||||
"""Arbitrary host paths are blocked even if they exist."""
|
||||
result = _read_local("/proc/self/environ", offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
|
||||
def test_read_with_offset_and_limit(self):
|
||||
"""Offset and limit should control which lines are returned."""
|
||||
encoded = "-tmp-copilot-e2b-test-offset"
|
||||
content = "".join(f"line {i}\n" for i in range(10))
|
||||
filepath = self._make_tool_results_file(encoded, "lines.txt", content)
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local(filepath, offset=3, limit=2)
|
||||
assert result["isError"] is False
|
||||
text = result["content"][0]["text"]
|
||||
assert "line 3" in text
|
||||
assert "line 4" in text
|
||||
assert "line 2" not in text
|
||||
assert "line 5" not in text
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
os.unlink(filepath)
|
||||
|
||||
def test_read_without_project_dir_blocks_all(self):
|
||||
"""Without _current_project_dir set, all paths are blocked."""
|
||||
result = _read_local("/tmp/anything.txt", offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
172
autogpt_platform/backend/backend/copilot/sdk/otel_setup_test.py
Normal file
172
autogpt_platform/backend/backend/copilot/sdk/otel_setup_test.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""Tests for OTEL tracing setup in the SDK copilot path."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
class TestSetupLangfuseOtel:
|
||||
"""Tests for _setup_langfuse_otel()."""
|
||||
|
||||
def test_noop_when_langfuse_not_configured(self):
|
||||
"""No env vars should be set when Langfuse credentials are missing."""
|
||||
with patch(
|
||||
"backend.copilot.sdk.service._is_langfuse_configured", return_value=False
|
||||
):
|
||||
from backend.copilot.sdk.service import _setup_langfuse_otel
|
||||
|
||||
# Clear any previously set env vars
|
||||
env_keys = [
|
||||
"LANGSMITH_OTEL_ENABLED",
|
||||
"LANGSMITH_OTEL_ONLY",
|
||||
"LANGSMITH_TRACING",
|
||||
"OTEL_EXPORTER_OTLP_ENDPOINT",
|
||||
"OTEL_EXPORTER_OTLP_HEADERS",
|
||||
]
|
||||
saved = {k: os.environ.pop(k, None) for k in env_keys}
|
||||
try:
|
||||
_setup_langfuse_otel()
|
||||
for key in env_keys:
|
||||
assert key not in os.environ, f"{key} should not be set"
|
||||
finally:
|
||||
for k, v in saved.items():
|
||||
if v is not None:
|
||||
os.environ[k] = v
|
||||
|
||||
def test_sets_env_vars_when_langfuse_configured(self):
|
||||
"""OTEL env vars should be set when Langfuse credentials exist."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.secrets.langfuse_public_key = "pk-test-123"
|
||||
mock_settings.secrets.langfuse_secret_key = "sk-test-456"
|
||||
mock_settings.secrets.langfuse_host = "https://langfuse.example.com"
|
||||
mock_settings.secrets.langfuse_tracing_environment = "test"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service._is_langfuse_configured",
|
||||
return_value=True,
|
||||
),
|
||||
patch("backend.copilot.sdk.service.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.configure_claude_agent_sdk",
|
||||
return_value=True,
|
||||
) as mock_configure,
|
||||
):
|
||||
from backend.copilot.sdk.service import _setup_langfuse_otel
|
||||
|
||||
# Clear env vars so setdefault works
|
||||
env_keys = [
|
||||
"LANGSMITH_OTEL_ENABLED",
|
||||
"LANGSMITH_OTEL_ONLY",
|
||||
"LANGSMITH_TRACING",
|
||||
"OTEL_EXPORTER_OTLP_ENDPOINT",
|
||||
"OTEL_EXPORTER_OTLP_HEADERS",
|
||||
"OTEL_RESOURCE_ATTRIBUTES",
|
||||
]
|
||||
saved = {k: os.environ.pop(k, None) for k in env_keys}
|
||||
try:
|
||||
_setup_langfuse_otel()
|
||||
|
||||
assert os.environ["LANGSMITH_OTEL_ENABLED"] == "true"
|
||||
assert os.environ["LANGSMITH_OTEL_ONLY"] == "true"
|
||||
assert os.environ["LANGSMITH_TRACING"] == "true"
|
||||
assert (
|
||||
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"]
|
||||
== "https://langfuse.example.com/api/public/otel"
|
||||
)
|
||||
assert "Authorization=Basic" in os.environ["OTEL_EXPORTER_OTLP_HEADERS"]
|
||||
assert (
|
||||
os.environ["OTEL_RESOURCE_ATTRIBUTES"]
|
||||
== "langfuse.environment=test"
|
||||
)
|
||||
|
||||
mock_configure.assert_called_once_with(tags=["sdk"])
|
||||
finally:
|
||||
for k, v in saved.items():
|
||||
if v is not None:
|
||||
os.environ[k] = v
|
||||
elif k in os.environ:
|
||||
del os.environ[k]
|
||||
|
||||
def test_existing_env_vars_not_overwritten(self):
|
||||
"""Explicit env-var overrides should not be clobbered."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.secrets.langfuse_public_key = "pk-test"
|
||||
mock_settings.secrets.langfuse_secret_key = "sk-test"
|
||||
mock_settings.secrets.langfuse_host = "https://langfuse.example.com"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service._is_langfuse_configured",
|
||||
return_value=True,
|
||||
),
|
||||
patch("backend.copilot.sdk.service.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.configure_claude_agent_sdk",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
from backend.copilot.sdk.service import _setup_langfuse_otel
|
||||
|
||||
saved = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
|
||||
try:
|
||||
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = "https://custom.endpoint/v1"
|
||||
_setup_langfuse_otel()
|
||||
assert (
|
||||
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"]
|
||||
== "https://custom.endpoint/v1"
|
||||
)
|
||||
finally:
|
||||
if saved is not None:
|
||||
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = saved
|
||||
elif "OTEL_EXPORTER_OTLP_ENDPOINT" in os.environ:
|
||||
del os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"]
|
||||
|
||||
def test_graceful_failure_on_exception(self):
|
||||
"""Setup should not raise even if internal code fails."""
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service._is_langfuse_configured",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.Settings",
|
||||
side_effect=RuntimeError("settings unavailable"),
|
||||
),
|
||||
):
|
||||
from backend.copilot.sdk.service import _setup_langfuse_otel
|
||||
|
||||
# Should not raise — just logs and returns
|
||||
_setup_langfuse_otel()
|
||||
|
||||
|
||||
class TestPropagateAttributesImport:
|
||||
"""Verify langfuse.propagate_attributes is available."""
|
||||
|
||||
def test_propagate_attributes_is_importable(self):
|
||||
from langfuse import propagate_attributes
|
||||
|
||||
assert callable(propagate_attributes)
|
||||
|
||||
def test_propagate_attributes_returns_context_manager(self):
|
||||
from langfuse import propagate_attributes
|
||||
|
||||
ctx = propagate_attributes(user_id="u1", session_id="s1", tags=["test"])
|
||||
assert hasattr(ctx, "__enter__")
|
||||
assert hasattr(ctx, "__exit__")
|
||||
|
||||
|
||||
class TestReceiveResponseCompat:
|
||||
"""Verify ClaudeSDKClient.receive_response() exists (langsmith patches it)."""
|
||||
|
||||
def test_receive_response_exists(self):
|
||||
from claude_agent_sdk import ClaudeSDKClient
|
||||
|
||||
assert hasattr(ClaudeSDKClient, "receive_response")
|
||||
|
||||
def test_receive_response_is_async_generator(self):
|
||||
import inspect
|
||||
|
||||
from claude_agent_sdk import ClaudeSDKClient
|
||||
|
||||
method = getattr(ClaudeSDKClient, "receive_response")
|
||||
assert inspect.isfunction(method) or inspect.ismethod(method)
|
||||
@@ -118,7 +118,7 @@ async def test_build_query_resume_up_to_date():
|
||||
ChatMessage(role="user", content="what's new?"),
|
||||
]
|
||||
)
|
||||
result = await _build_query_message(
|
||||
result, was_compacted = await _build_query_message(
|
||||
"what's new?",
|
||||
session,
|
||||
use_resume=True,
|
||||
@@ -127,6 +127,7 @@ async def test_build_query_resume_up_to_date():
|
||||
)
|
||||
# transcript_msg_count == msg_count - 1, so no gap
|
||||
assert result == "what's new?"
|
||||
assert was_compacted is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -141,7 +142,7 @@ async def test_build_query_resume_stale_transcript():
|
||||
ChatMessage(role="user", content="turn 3"),
|
||||
]
|
||||
)
|
||||
result = await _build_query_message(
|
||||
result, was_compacted = await _build_query_message(
|
||||
"turn 3",
|
||||
session,
|
||||
use_resume=True,
|
||||
@@ -152,6 +153,7 @@ async def test_build_query_resume_stale_transcript():
|
||||
assert "turn 2" in result
|
||||
assert "reply 2" in result
|
||||
assert "Now, the user says:\nturn 3" in result
|
||||
assert was_compacted is False # gap context does not compact
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -164,7 +166,7 @@ async def test_build_query_resume_zero_msg_count():
|
||||
ChatMessage(role="user", content="new msg"),
|
||||
]
|
||||
)
|
||||
result = await _build_query_message(
|
||||
result, was_compacted = await _build_query_message(
|
||||
"new msg",
|
||||
session,
|
||||
use_resume=True,
|
||||
@@ -172,13 +174,14 @@ async def test_build_query_resume_zero_msg_count():
|
||||
session_id="test-session",
|
||||
)
|
||||
assert result == "new msg"
|
||||
assert was_compacted is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_no_resume_single_message():
|
||||
"""Without --resume and only 1 message, return raw message."""
|
||||
session = _make_session([ChatMessage(role="user", content="first")])
|
||||
result = await _build_query_message(
|
||||
result, was_compacted = await _build_query_message(
|
||||
"first",
|
||||
session,
|
||||
use_resume=False,
|
||||
@@ -186,6 +189,7 @@ async def test_build_query_no_resume_single_message():
|
||||
session_id="test-session",
|
||||
)
|
||||
assert result == "first"
|
||||
assert was_compacted is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -199,16 +203,16 @@ async def test_build_query_no_resume_multi_message(monkeypatch):
|
||||
]
|
||||
)
|
||||
|
||||
# Mock _compress_conversation_history to return the messages as-is
|
||||
async def _mock_compress(sess):
|
||||
return sess.messages[:-1]
|
||||
# Mock _compress_messages to return the messages as-is
|
||||
async def _mock_compress(msgs):
|
||||
return msgs, False
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.service._compress_conversation_history",
|
||||
"backend.copilot.sdk.service._compress_messages",
|
||||
_mock_compress,
|
||||
)
|
||||
|
||||
result = await _build_query_message(
|
||||
result, was_compacted = await _build_query_message(
|
||||
"new question",
|
||||
session,
|
||||
use_resume=False,
|
||||
@@ -219,3 +223,33 @@ async def test_build_query_no_resume_multi_message(monkeypatch):
|
||||
assert "older question" in result
|
||||
assert "older answer" in result
|
||||
assert "Now, the user says:\nnew question" in result
|
||||
assert was_compacted is False # mock returns False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_no_resume_multi_message_compacted(monkeypatch):
|
||||
"""When compression actually compacts, was_compacted should be True."""
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="old"),
|
||||
ChatMessage(role="assistant", content="reply"),
|
||||
ChatMessage(role="user", content="new"),
|
||||
]
|
||||
)
|
||||
|
||||
async def _mock_compress(msgs):
|
||||
return msgs, True # Simulate actual compaction
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.service._compress_messages",
|
||||
_mock_compress,
|
||||
)
|
||||
|
||||
result, was_compacted = await _build_query_message(
|
||||
"new",
|
||||
session,
|
||||
use_resume=False,
|
||||
transcript_msg_count=0,
|
||||
session_id="test-session",
|
||||
)
|
||||
assert was_compacted is True
|
||||
|
||||
@@ -6,7 +6,6 @@ ensuring multi-user isolation and preventing unauthorized operations.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
@@ -16,6 +15,7 @@ from .tool_adapter import (
|
||||
DANGEROUS_PATTERNS,
|
||||
MCP_TOOL_PREFIX,
|
||||
WORKSPACE_SCOPED_TOOLS,
|
||||
is_allowed_local_path,
|
||||
stash_pending_tool_output,
|
||||
)
|
||||
|
||||
@@ -38,40 +38,20 @@ def _validate_workspace_path(
|
||||
) -> dict[str, Any]:
|
||||
"""Validate that a workspace-scoped tool only accesses allowed paths.
|
||||
|
||||
Allowed directories:
|
||||
Delegates to :func:`is_allowed_local_path` which permits:
|
||||
- The SDK working directory (``/tmp/copilot-<session>/``)
|
||||
- The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
|
||||
- The current session's tool-results directory
|
||||
(``~/.claude/projects/<encoded-cwd>/tool-results/``)
|
||||
"""
|
||||
path = tool_input.get("file_path") or tool_input.get("path") or ""
|
||||
if not path:
|
||||
# Glob/Grep without a path default to cwd which is already sandboxed
|
||||
return {}
|
||||
|
||||
# Resolve relative paths against sdk_cwd (the SDK sets cwd so the LLM
|
||||
# naturally uses relative paths like "test.txt" instead of absolute ones).
|
||||
# Tilde paths (~/) are home-dir references, not relative — expand first.
|
||||
if path.startswith("~"):
|
||||
resolved = os.path.realpath(os.path.expanduser(path))
|
||||
elif not os.path.isabs(path) and sdk_cwd:
|
||||
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
|
||||
else:
|
||||
resolved = os.path.realpath(path)
|
||||
|
||||
# Allow access within the SDK working directory
|
||||
if sdk_cwd:
|
||||
norm_cwd = os.path.realpath(sdk_cwd)
|
||||
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
|
||||
return {}
|
||||
|
||||
# Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
|
||||
claude_dir = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
tool_results_seg = os.sep + "tool-results" + os.sep
|
||||
if resolved.startswith(claude_dir + os.sep) and tool_results_seg in resolved:
|
||||
if is_allowed_local_path(path, sdk_cwd):
|
||||
return {}
|
||||
|
||||
logger.warning(
|
||||
f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
|
||||
)
|
||||
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
|
||||
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
|
||||
return _deny(
|
||||
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
|
||||
@@ -146,7 +126,7 @@ def create_security_hooks(
|
||||
user_id: str | None,
|
||||
sdk_cwd: str | None = None,
|
||||
max_subtasks: int = 3,
|
||||
on_stop: Callable[[str, str], None] | None = None,
|
||||
on_compact: Callable[[], None] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create the security hooks configuration for Claude Agent SDK.
|
||||
|
||||
@@ -155,15 +135,12 @@ def create_security_hooks(
|
||||
- PostToolUse: Log successful tool executions
|
||||
- PostToolUseFailure: Log and handle failed tool executions
|
||||
- PreCompact: Log context compaction events (SDK handles compaction automatically)
|
||||
- Stop: Capture transcript path for stateless resume (when *on_stop* is provided)
|
||||
|
||||
Args:
|
||||
user_id: Current user ID for isolation validation
|
||||
sdk_cwd: SDK working directory for workspace-scoped tool validation
|
||||
max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
|
||||
on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when
|
||||
the SDK finishes processing — used to read the JSONL transcript
|
||||
before the CLI process exits.
|
||||
on_compact: Callback invoked when SDK starts compacting context.
|
||||
|
||||
Returns:
|
||||
Hooks configuration dict for ClaudeAgentOptions
|
||||
@@ -326,30 +303,8 @@ def create_security_hooks(
|
||||
logger.info(
|
||||
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
|
||||
)
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
# --- Stop hook: capture transcript path for stateless resume ---
|
||||
async def stop_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Capture transcript path when SDK finishes processing.
|
||||
|
||||
The Stop hook fires while the CLI process is still alive, giving us
|
||||
a reliable window to read the JSONL transcript before SIGTERM.
|
||||
"""
|
||||
_ = context, tool_use_id
|
||||
transcript_path = cast(str, input_data.get("transcript_path", ""))
|
||||
sdk_session_id = cast(str, input_data.get("session_id", ""))
|
||||
|
||||
if transcript_path and on_stop:
|
||||
logger.info(
|
||||
f"[SDK] Stop hook: transcript_path={transcript_path}, "
|
||||
f"sdk_session_id={sdk_session_id[:12]}..."
|
||||
)
|
||||
on_stop(transcript_path, sdk_session_id)
|
||||
|
||||
if on_compact is not None:
|
||||
on_compact()
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
hooks: dict[str, Any] = {
|
||||
@@ -361,9 +316,6 @@ def create_security_hooks(
|
||||
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
|
||||
}
|
||||
|
||||
if on_stop is not None:
|
||||
hooks["Stop"] = [HookMatcher(matcher=None, hooks=[stop_hook])]
|
||||
|
||||
return hooks
|
||||
except ImportError:
|
||||
# Fallback for when SDK isn't available - return empty hooks
|
||||
|
||||
@@ -120,17 +120,31 @@ def test_read_no_cwd_denies_absolute():
|
||||
|
||||
|
||||
def test_read_tool_results_allowed():
|
||||
from .tool_adapter import _current_project_dir
|
||||
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
# is_allowed_local_path requires the session's encoded cwd to be set
|
||||
token = _current_project_dir.set("-tmp-copilot-abc123")
|
||||
try:
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
|
||||
|
||||
def test_read_claude_projects_without_tool_results_denied():
|
||||
def test_read_claude_projects_session_dir_allowed():
|
||||
"""Files within the current session's project dir are allowed."""
|
||||
from .tool_adapter import _current_project_dir
|
||||
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert _is_denied(result)
|
||||
token = _current_project_dir.set("-tmp-copilot-abc123")
|
||||
try:
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert not _is_denied(result)
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
|
||||
|
||||
# -- Built-in Bash is blocked (use bash_exec MCP tool instead) ---------------
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
247
autogpt_platform/backend/backend/copilot/sdk/service_test.py
Normal file
247
autogpt_platform/backend/backend/copilot/sdk/service_test.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""Tests for SDK service helpers."""
|
||||
|
||||
import base64
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .service import _prepare_file_attachments
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeFileInfo:
|
||||
id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
|
||||
|
||||
_PATCH_TARGET = "backend.copilot.sdk.service.get_manager"
|
||||
|
||||
|
||||
class TestPrepareFileAttachments:
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_list_returns_empty(self, tmp_path):
|
||||
result = await _prepare_file_attachments([], "u", "s", str(tmp_path))
|
||||
assert result.hint == ""
|
||||
assert result.image_blocks == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_embedded_as_vision_block(self, tmp_path):
|
||||
"""JPEG images should become vision content blocks, not files on disk."""
|
||||
raw = b"\xff\xd8\xff\xe0fake-jpeg"
|
||||
info = _FakeFileInfo(
|
||||
id="abc",
|
||||
name="photo.jpg",
|
||||
path="/photo.jpg",
|
||||
mime_type="image/jpeg",
|
||||
size_bytes=len(raw),
|
||||
)
|
||||
mgr = AsyncMock()
|
||||
mgr.get_file_info.return_value = info
|
||||
mgr.read_file_by_id.return_value = raw
|
||||
|
||||
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
|
||||
result = await _prepare_file_attachments(
|
||||
["abc"], "user1", "sess1", str(tmp_path)
|
||||
)
|
||||
|
||||
assert "1 file" in result.hint
|
||||
assert "photo.jpg" in result.hint
|
||||
assert "embedded as image" in result.hint
|
||||
assert len(result.image_blocks) == 1
|
||||
block = result.image_blocks[0]
|
||||
assert block["type"] == "image"
|
||||
assert block["source"]["media_type"] == "image/jpeg"
|
||||
assert block["source"]["data"] == base64.b64encode(raw).decode("ascii")
|
||||
# Image should NOT be written to disk (embedded instead)
|
||||
assert not os.path.exists(os.path.join(tmp_path, "photo.jpg"))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pdf_saved_to_disk(self, tmp_path):
|
||||
"""PDFs should be saved to disk for Read tool access, not embedded."""
|
||||
info = _FakeFileInfo("f1", "doc.pdf", "/doc.pdf", "application/pdf", 50)
|
||||
mgr = AsyncMock()
|
||||
mgr.get_file_info.return_value = info
|
||||
mgr.read_file_by_id.return_value = b"%PDF-1.4 fake"
|
||||
|
||||
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
|
||||
result = await _prepare_file_attachments(["f1"], "u", "s", str(tmp_path))
|
||||
|
||||
assert result.image_blocks == []
|
||||
saved = tmp_path / "doc.pdf"
|
||||
assert saved.exists()
|
||||
assert saved.read_bytes() == b"%PDF-1.4 fake"
|
||||
assert str(saved) in result.hint
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_images_and_files(self, tmp_path):
|
||||
"""Images become blocks, non-images go to disk."""
|
||||
infos = {
|
||||
"id1": _FakeFileInfo("id1", "a.png", "/a.png", "image/png", 4),
|
||||
"id2": _FakeFileInfo("id2", "b.pdf", "/b.pdf", "application/pdf", 4),
|
||||
"id3": _FakeFileInfo("id3", "c.txt", "/c.txt", "text/plain", 4),
|
||||
}
|
||||
mgr = AsyncMock()
|
||||
mgr.get_file_info.side_effect = lambda fid: infos[fid]
|
||||
mgr.read_file_by_id.return_value = b"data"
|
||||
|
||||
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
|
||||
result = await _prepare_file_attachments(
|
||||
["id1", "id2", "id3"], "u", "s", str(tmp_path)
|
||||
)
|
||||
|
||||
assert "3 files" in result.hint
|
||||
assert "a.png" in result.hint
|
||||
assert "b.pdf" in result.hint
|
||||
assert "c.txt" in result.hint
|
||||
# Only the image should be a vision block
|
||||
assert len(result.image_blocks) == 1
|
||||
assert result.image_blocks[0]["source"]["media_type"] == "image/png"
|
||||
# Non-image files should be on disk
|
||||
assert (tmp_path / "b.pdf").exists()
|
||||
assert (tmp_path / "c.txt").exists()
|
||||
# Read tool hint should appear (has non-image files)
|
||||
assert "Read tool" in result.hint
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_singular_noun(self, tmp_path):
|
||||
info = _FakeFileInfo("x", "only.txt", "/only.txt", "text/plain", 2)
|
||||
mgr = AsyncMock()
|
||||
mgr.get_file_info.return_value = info
|
||||
mgr.read_file_by_id.return_value = b"hi"
|
||||
|
||||
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
|
||||
result = await _prepare_file_attachments(["x"], "u", "s", str(tmp_path))
|
||||
|
||||
assert "1 file." in result.hint
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_file_skipped(self, tmp_path):
|
||||
mgr = AsyncMock()
|
||||
mgr.get_file_info.return_value = None
|
||||
|
||||
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
|
||||
result = await _prepare_file_attachments(
|
||||
["missing-id"], "u", "s", str(tmp_path)
|
||||
)
|
||||
|
||||
assert result.hint == ""
|
||||
assert result.image_blocks == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_only_no_read_hint(self, tmp_path):
|
||||
"""When all files are images, no Read tool hint should appear."""
|
||||
info = _FakeFileInfo("i1", "cat.png", "/cat.png", "image/png", 4)
|
||||
mgr = AsyncMock()
|
||||
mgr.get_file_info.return_value = info
|
||||
mgr.read_file_by_id.return_value = b"data"
|
||||
|
||||
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
|
||||
result = await _prepare_file_attachments(["i1"], "u", "s", str(tmp_path))
|
||||
|
||||
assert "Read tool" not in result.hint
|
||||
assert len(result.image_blocks) == 1
|
||||
|
||||
|
||||
class TestPromptSupplement:
|
||||
"""Tests for centralized prompt supplement generation."""
|
||||
|
||||
def test_sdk_supplement_excludes_tool_docs(self):
|
||||
"""SDK mode should NOT include tool documentation (Claude gets schemas automatically)."""
|
||||
from backend.copilot.prompting import get_sdk_supplement
|
||||
|
||||
# Test both local and E2B modes
|
||||
local_supplement = get_sdk_supplement(use_e2b=False, cwd="/tmp/test")
|
||||
e2b_supplement = get_sdk_supplement(use_e2b=True, cwd="")
|
||||
|
||||
# Should NOT have tool list section
|
||||
assert "## AVAILABLE TOOLS" not in local_supplement
|
||||
assert "## AVAILABLE TOOLS" not in e2b_supplement
|
||||
|
||||
# Should still have technical notes
|
||||
assert "## Tool notes" in local_supplement
|
||||
assert "## Tool notes" in e2b_supplement
|
||||
|
||||
def test_baseline_supplement_includes_tool_docs(self):
|
||||
"""Baseline mode MUST include tool documentation (direct API needs it)."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
|
||||
supplement = get_baseline_supplement()
|
||||
|
||||
# MUST have tool list section
|
||||
assert "## AVAILABLE TOOLS" in supplement
|
||||
|
||||
# Should NOT have environment-specific notes (SDK-only)
|
||||
assert "## Tool notes" not in supplement
|
||||
|
||||
def test_baseline_supplement_includes_key_tools(self):
|
||||
"""Baseline supplement should document all essential tools."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
docs = get_baseline_supplement()
|
||||
|
||||
# Core agent workflow tools (always available)
|
||||
assert "`create_agent`" in docs
|
||||
assert "`run_agent`" in docs
|
||||
assert "`find_library_agent`" in docs
|
||||
assert "`edit_agent`" in docs
|
||||
|
||||
# MCP integration (always available)
|
||||
assert "`run_mcp_tool`" in docs
|
||||
|
||||
# Folder management (always available)
|
||||
assert "`create_folder`" in docs
|
||||
|
||||
# Browser tools only if available (Playwright may not be installed in CI)
|
||||
if (
|
||||
TOOL_REGISTRY.get("browser_navigate")
|
||||
and TOOL_REGISTRY["browser_navigate"].is_available
|
||||
):
|
||||
assert "`browser_navigate`" in docs
|
||||
|
||||
def test_baseline_supplement_includes_workflows(self):
|
||||
"""Baseline supplement should include workflow guidance in tool descriptions."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
|
||||
docs = get_baseline_supplement()
|
||||
|
||||
# Workflows are now in individual tool descriptions (not separate sections)
|
||||
# Check that key workflow concepts appear in tool descriptions
|
||||
assert "suggested_goal" in docs or "clarifying_questions" in docs
|
||||
assert "run_mcp_tool" in docs
|
||||
|
||||
def test_baseline_supplement_completeness(self):
|
||||
"""All available tools from TOOL_REGISTRY should appear in baseline supplement."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
docs = get_baseline_supplement()
|
||||
|
||||
# Verify each available registered tool is documented
|
||||
# (matches _generate_tool_documentation which filters by is_available)
|
||||
for tool_name, tool in TOOL_REGISTRY.items():
|
||||
if not tool.is_available:
|
||||
continue
|
||||
assert (
|
||||
f"`{tool_name}`" in docs
|
||||
), f"Tool '{tool_name}' missing from baseline supplement"
|
||||
|
||||
def test_baseline_supplement_no_duplicate_tools(self):
|
||||
"""No tool should appear multiple times in baseline supplement."""
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
docs = get_baseline_supplement()
|
||||
|
||||
# Count occurrences of each available tool in the entire supplement
|
||||
for tool_name, tool in TOOL_REGISTRY.items():
|
||||
if not tool.is_available:
|
||||
continue
|
||||
# Count how many times this tool appears as a bullet point
|
||||
count = docs.count(f"- **`{tool_name}`**")
|
||||
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"
|
||||
@@ -9,20 +9,84 @@ import itertools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
from backend.copilot.tools.base import BaseTool
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
from .e2b_file_tools import E2B_FILE_TOOL_NAMES, E2B_FILE_TOOLS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
|
||||
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
|
||||
# in the path — prevents reading settings, credentials, or other sensitive files.
|
||||
_SDK_PROJECTS_DIR = os.path.expanduser("~/.claude/projects/")
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
# Max MCP response size in chars — keeps tool output under the SDK's 10 MB JSON buffer.
|
||||
_MCP_MAX_CHARS = 500_000
|
||||
|
||||
# Context variable holding the encoded project directory name for the current
|
||||
# session (e.g. "-private-tmp-copilot-<uuid>"). Set by set_execution_context()
|
||||
# so that path validation can scope tool-results reads to the current session.
|
||||
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
|
||||
|
||||
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does.
|
||||
|
||||
The CLI replaces all non-alphanumeric characters with ``-``.
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
"""Check whether *path* is an allowed host-filesystem path.
|
||||
|
||||
Allowed:
|
||||
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/`` — the SDK's
|
||||
project directory for this session (tool-results, transcripts, etc.)
|
||||
|
||||
Both checks are scoped to the **current session** so sessions cannot
|
||||
read each other's data.
|
||||
"""
|
||||
if not path:
|
||||
return False
|
||||
|
||||
if path.startswith("~"):
|
||||
resolved = os.path.realpath(os.path.expanduser(path))
|
||||
elif not os.path.isabs(path) and sdk_cwd:
|
||||
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
|
||||
else:
|
||||
resolved = os.path.realpath(path)
|
||||
|
||||
# Allow access within the SDK working directory
|
||||
if sdk_cwd:
|
||||
norm_cwd = os.path.realpath(sdk_cwd)
|
||||
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
|
||||
return True
|
||||
|
||||
# Allow access within the current session's CLI project directory
|
||||
# (~/.claude/projects/<encoded-cwd>/).
|
||||
encoded = _current_project_dir.get("")
|
||||
if encoded:
|
||||
session_project = os.path.join(_SDK_PROJECTS_DIR, encoded)
|
||||
if resolved == session_project or resolved.startswith(session_project + os.sep):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
|
||||
MCP_SERVER_NAME = "copilot"
|
||||
@@ -33,6 +97,15 @@ _current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default
|
||||
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||
"current_session", default=None
|
||||
)
|
||||
# E2B cloud sandbox for the current turn (None when E2B is not configured).
|
||||
# Passed to bash_exec so commands run on E2B instead of the local bwrap sandbox.
|
||||
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
|
||||
"_current_sandbox", default=None
|
||||
)
|
||||
# Raw SDK working directory path (e.g. /tmp/copilot-<session_id>).
|
||||
# Used by workspace tools to save binary files for the CLI's built-in Read.
|
||||
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
|
||||
|
||||
# Stash for MCP tool outputs before the SDK potentially truncates them.
|
||||
# Keyed by tool_name → full output string. Consumed (popped) by the
|
||||
# response adapter when it builds StreamToolOutputAvailable.
|
||||
@@ -53,22 +126,39 @@ _stash_event: ContextVar[asyncio.Event | None] = ContextVar(
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
sandbox: "AsyncSandbox | None" = None,
|
||||
sdk_cwd: str | None = None,
|
||||
) -> None:
|
||||
"""Set the execution context for tool calls.
|
||||
|
||||
This must be called before streaming begins to ensure tools have access
|
||||
to user_id and session information.
|
||||
to user_id, session, and (optionally) an E2B sandbox for bash execution.
|
||||
|
||||
Args:
|
||||
user_id: Current user's ID.
|
||||
session: Current chat session.
|
||||
sandbox: Optional E2B sandbox; when set, bash_exec routes commands there.
|
||||
sdk_cwd: SDK working directory; used to scope tool-results reads.
|
||||
"""
|
||||
_current_user_id.set(user_id)
|
||||
_current_session.set(session)
|
||||
_current_sandbox.set(sandbox)
|
||||
_current_sdk_cwd.set(sdk_cwd or "")
|
||||
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
|
||||
_pending_tool_outputs.set({})
|
||||
_stash_event.set(asyncio.Event())
|
||||
|
||||
|
||||
def get_current_sandbox() -> "AsyncSandbox | None":
|
||||
"""Return the E2B sandbox for the current turn, or None."""
|
||||
return _current_sandbox.get()
|
||||
|
||||
|
||||
def get_sdk_cwd() -> str:
|
||||
"""Return the SDK ephemeral working directory for the current turn."""
|
||||
return _current_sdk_cwd.get()
|
||||
|
||||
|
||||
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||
"""Get the current execution context."""
|
||||
return (
|
||||
@@ -182,66 +272,12 @@ async def _execute_tool_sync(
|
||||
result.output if isinstance(result.output, str) else json.dumps(result.output)
|
||||
)
|
||||
|
||||
# Stash the full output before the SDK potentially truncates it.
|
||||
pending = _pending_tool_outputs.get(None)
|
||||
if pending is not None:
|
||||
pending.setdefault(base_tool.name, []).append(text)
|
||||
|
||||
content_blocks: list[dict[str, str]] = [{"type": "text", "text": text}]
|
||||
|
||||
# If the tool result contains inline image data, add an MCP image block
|
||||
# so Claude can "see" the image (e.g. read_workspace_file on a small PNG).
|
||||
image_block = _extract_image_block(text)
|
||||
if image_block:
|
||||
content_blocks.append(image_block)
|
||||
|
||||
return {
|
||||
"content": content_blocks,
|
||||
"content": [{"type": "text", "text": text}],
|
||||
"isError": not result.success,
|
||||
}
|
||||
|
||||
|
||||
# MIME types that Claude can process as image content blocks.
|
||||
_SUPPORTED_IMAGE_TYPES = frozenset(
|
||||
{"image/png", "image/jpeg", "image/gif", "image/webp"}
|
||||
)
|
||||
|
||||
|
||||
def _extract_image_block(text: str) -> dict[str, str] | None:
|
||||
"""Extract an MCP image content block from a tool result JSON string.
|
||||
|
||||
Detects workspace file responses with ``content_base64`` and an image
|
||||
MIME type, returning an MCP-format image block that allows Claude to
|
||||
"see" the image. Returns ``None`` if the result is not an inline image.
|
||||
"""
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
|
||||
mime_type = data.get("mime_type", "")
|
||||
base64_content = data.get("content_base64", "")
|
||||
|
||||
# Only inline small images — large ones would exceed Claude's limits.
|
||||
# 32 KB raw ≈ ~43 KB base64.
|
||||
_MAX_IMAGE_BASE64_BYTES = 43_000
|
||||
if (
|
||||
mime_type in _SUPPORTED_IMAGE_TYPES
|
||||
and base64_content
|
||||
and len(base64_content) <= _MAX_IMAGE_BASE64_BYTES
|
||||
):
|
||||
return {
|
||||
"type": "image",
|
||||
"data": base64_content,
|
||||
"mimeType": mime_type,
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _mcp_error(message: str) -> dict[str, Any]:
|
||||
return {
|
||||
"content": [
|
||||
@@ -284,29 +320,32 @@ def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
|
||||
|
||||
|
||||
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Read a file with optional offset/limit. Restricted to SDK working directory.
|
||||
"""Read a local file with optional offset/limit.
|
||||
|
||||
After reading, the file is deleted to prevent accumulation in long-running pods.
|
||||
Only allows paths that pass :func:`is_allowed_local_path` — the current
|
||||
session's tool-results directory and ephemeral working directory.
|
||||
"""
|
||||
file_path = args.get("file_path", "")
|
||||
offset = args.get("offset", 0)
|
||||
limit = args.get("limit", 2000)
|
||||
|
||||
# Security: only allow reads under ~/.claude/projects/**/tool-results/
|
||||
real_path = os.path.realpath(file_path)
|
||||
if not real_path.startswith(_SDK_PROJECTS_DIR) or "tool-results" not in real_path:
|
||||
if not is_allowed_local_path(file_path):
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
resolved = os.path.realpath(os.path.expanduser(file_path))
|
||||
try:
|
||||
with open(real_path) as f:
|
||||
with open(resolved) as f:
|
||||
selected = list(itertools.islice(f, offset, offset + limit))
|
||||
content = "".join(selected)
|
||||
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
|
||||
# don't delete here — the SDK may read in multiple chunks.
|
||||
return {"content": [{"type": "text", "text": content}], "isError": False}
|
||||
return {
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"isError": False,
|
||||
}
|
||||
except FileNotFoundError:
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
|
||||
@@ -344,50 +383,86 @@ _READ_TOOL_SCHEMA = {
|
||||
}
|
||||
|
||||
|
||||
# Create the MCP server configuration
|
||||
def create_copilot_mcp_server():
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCP result helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _text_from_mcp_result(result: dict[str, Any]) -> str:
|
||||
"""Extract concatenated text from an MCP response's content blocks."""
|
||||
content = result.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
return ""
|
||||
return "".join(
|
||||
b.get("text", "")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
)
|
||||
|
||||
|
||||
def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
"""Create an in-process MCP server configuration for CoPilot tools.
|
||||
|
||||
This can be passed to ClaudeAgentOptions.mcp_servers.
|
||||
|
||||
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
|
||||
package being available. This function returns the configuration that
|
||||
can be used with the SDK.
|
||||
When *use_e2b* is True, five additional MCP file tools are registered
|
||||
that route directly to the E2B sandbox filesystem, and the caller should
|
||||
disable the corresponding SDK built-in tools via
|
||||
:func:`get_sdk_disallowed_tools`.
|
||||
"""
|
||||
try:
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
# Create decorated tool functions
|
||||
sdk_tools = []
|
||||
def _truncating(fn, tool_name: str):
|
||||
"""Wrap a tool handler so its response is truncated to stay under the
|
||||
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
|
||||
response adapter before the SDK can apply its own head-truncation.
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
handler = create_tool_handler(base_tool)
|
||||
decorated = tool(
|
||||
tool_name,
|
||||
base_tool.description,
|
||||
_build_input_schema(base_tool),
|
||||
)(handler)
|
||||
Applied once to every registered tool."""
|
||||
|
||||
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
|
||||
result = await fn(args)
|
||||
truncated = truncate(result, _MCP_MAX_CHARS)
|
||||
|
||||
# Stash the text so the response adapter can forward our
|
||||
# middle-out truncated version to the frontend instead of the
|
||||
# SDK's head-truncated version (for outputs >~100 KB the SDK
|
||||
# persists to tool-results/ with a 2 KB head-only preview).
|
||||
if not truncated.get("isError"):
|
||||
text = _text_from_mcp_result(truncated)
|
||||
if text:
|
||||
stash_pending_tool_output(tool_name, text)
|
||||
|
||||
return truncated
|
||||
|
||||
return wrapper
|
||||
|
||||
sdk_tools = []
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
handler = create_tool_handler(base_tool)
|
||||
decorated = tool(
|
||||
tool_name,
|
||||
base_tool.description,
|
||||
_build_input_schema(base_tool),
|
||||
)(_truncating(handler, tool_name))
|
||||
sdk_tools.append(decorated)
|
||||
|
||||
# E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.
|
||||
if use_e2b:
|
||||
for name, desc, schema, handler in E2B_FILE_TOOLS:
|
||||
decorated = tool(name, desc, schema)(_truncating(handler, name))
|
||||
sdk_tools.append(decorated)
|
||||
|
||||
# Add the Read tool so the SDK can read back oversized tool results
|
||||
read_tool = tool(
|
||||
_READ_TOOL_NAME,
|
||||
_READ_TOOL_DESCRIPTION,
|
||||
_READ_TOOL_SCHEMA,
|
||||
)(_read_file_handler)
|
||||
sdk_tools.append(read_tool)
|
||||
# Read tool for SDK-truncated tool results (always needed).
|
||||
read_tool = tool(
|
||||
_READ_TOOL_NAME,
|
||||
_READ_TOOL_DESCRIPTION,
|
||||
_READ_TOOL_SCHEMA,
|
||||
)(_truncating(_read_file_handler, _READ_TOOL_NAME))
|
||||
sdk_tools.append(read_tool)
|
||||
|
||||
server = create_sdk_mcp_server(
|
||||
name=MCP_SERVER_NAME,
|
||||
version="1.0.0",
|
||||
tools=sdk_tools,
|
||||
)
|
||||
|
||||
return server
|
||||
|
||||
except ImportError:
|
||||
# Let ImportError propagate so service.py handles the fallback
|
||||
raise
|
||||
return create_sdk_mcp_server(
|
||||
name=MCP_SERVER_NAME,
|
||||
version="1.0.0",
|
||||
tools=sdk_tools,
|
||||
)
|
||||
|
||||
|
||||
# SDK built-in tools allowed within the workspace directory.
|
||||
@@ -397,16 +472,11 @@ def create_copilot_mcp_server():
|
||||
# Task allows spawning sub-agents (rate-limited by security hooks).
|
||||
# WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk.
|
||||
# TodoWrite manages the task checklist shown in the UI — no security concern.
|
||||
_SDK_BUILTIN_TOOLS = [
|
||||
"Read",
|
||||
"Write",
|
||||
"Edit",
|
||||
"Glob",
|
||||
"Grep",
|
||||
"Task",
|
||||
"WebSearch",
|
||||
"TodoWrite",
|
||||
]
|
||||
# In E2B mode, all five are disabled — MCP equivalents provide direct sandbox
|
||||
# access. read_file also handles local tool-results and ephemeral reads.
|
||||
_SDK_BUILTIN_FILE_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep"]
|
||||
_SDK_BUILTIN_ALWAYS = ["Task", "WebSearch", "TodoWrite"]
|
||||
_SDK_BUILTIN_TOOLS = [*_SDK_BUILTIN_FILE_TOOLS, *_SDK_BUILTIN_ALWAYS]
|
||||
|
||||
# SDK built-in tools that must be explicitly blocked.
|
||||
# Bash: dangerous — agent uses mcp__copilot__bash_exec with kernel-level
|
||||
@@ -453,11 +523,37 @@ DANGEROUS_PATTERNS = [
|
||||
r"subprocess",
|
||||
]
|
||||
|
||||
# List of tool names for allowed_tools configuration
|
||||
# Include MCP tools, the MCP Read tool for oversized results,
|
||||
# and SDK built-in file tools for workspace operations.
|
||||
# Static tool name list for the non-E2B case (backward compatibility).
|
||||
COPILOT_TOOL_NAMES = [
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
|
||||
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
|
||||
*_SDK_BUILTIN_TOOLS,
|
||||
]
|
||||
|
||||
|
||||
def get_copilot_tool_names(*, use_e2b: bool = False) -> list[str]:
|
||||
"""Build the ``allowed_tools`` list for :class:`ClaudeAgentOptions`.
|
||||
|
||||
When *use_e2b* is True the SDK built-in file tools are replaced by MCP
|
||||
equivalents that route to the E2B sandbox.
|
||||
"""
|
||||
if not use_e2b:
|
||||
return list(COPILOT_TOOL_NAMES)
|
||||
|
||||
return [
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
|
||||
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in E2B_FILE_TOOL_NAMES],
|
||||
*_SDK_BUILTIN_ALWAYS,
|
||||
]
|
||||
|
||||
|
||||
def get_sdk_disallowed_tools(*, use_e2b: bool = False) -> list[str]:
|
||||
"""Build the ``disallowed_tools`` list for :class:`ClaudeAgentOptions`.
|
||||
|
||||
When *use_e2b* is True the SDK built-in file tools are also disabled
|
||||
because MCP equivalents provide direct sandbox access.
|
||||
"""
|
||||
if not use_e2b:
|
||||
return list(SDK_DISALLOWED_TOOLS)
|
||||
return [*SDK_DISALLOWED_TOOLS, *_SDK_BUILTIN_FILE_TOOLS]
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
"""Tests for tool_adapter helpers: truncation, stash, context vars."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
from .tool_adapter import (
|
||||
_MCP_MAX_CHARS,
|
||||
_text_from_mcp_result,
|
||||
get_sdk_cwd,
|
||||
pop_pending_tool_output,
|
||||
set_execution_context,
|
||||
stash_pending_tool_output,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _text_from_mcp_result
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTextFromMcpResult:
|
||||
def test_single_text_block(self):
|
||||
result = {"content": [{"type": "text", "text": "hello"}]}
|
||||
assert _text_from_mcp_result(result) == "hello"
|
||||
|
||||
def test_multiple_text_blocks_concatenated(self):
|
||||
result = {
|
||||
"content": [
|
||||
{"type": "text", "text": "one"},
|
||||
{"type": "text", "text": "two"},
|
||||
]
|
||||
}
|
||||
assert _text_from_mcp_result(result) == "onetwo"
|
||||
|
||||
def test_non_text_blocks_ignored(self):
|
||||
result = {
|
||||
"content": [
|
||||
{"type": "image", "data": "..."},
|
||||
{"type": "text", "text": "only this"},
|
||||
]
|
||||
}
|
||||
assert _text_from_mcp_result(result) == "only this"
|
||||
|
||||
def test_empty_content_list(self):
|
||||
assert _text_from_mcp_result({"content": []}) == ""
|
||||
|
||||
def test_missing_content_key(self):
|
||||
assert _text_from_mcp_result({}) == ""
|
||||
|
||||
def test_non_list_content(self):
|
||||
assert _text_from_mcp_result({"content": "raw string"}) == ""
|
||||
|
||||
def test_missing_text_field(self):
|
||||
result = {"content": [{"type": "text"}]}
|
||||
assert _text_from_mcp_result(result) == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_sdk_cwd
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetSdkCwd:
|
||||
def test_returns_empty_string_by_default(self):
|
||||
set_execution_context(
|
||||
user_id="test",
|
||||
session=None, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
)
|
||||
assert get_sdk_cwd() == ""
|
||||
|
||||
def test_returns_set_value(self):
|
||||
set_execution_context(
|
||||
user_id="test",
|
||||
session=None, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
sdk_cwd="/tmp/copilot-test-123",
|
||||
)
|
||||
assert get_sdk_cwd() == "/tmp/copilot-test-123"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# stash / pop round-trip (the mechanism _truncating relies on)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToolOutputStash:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init_context(self):
|
||||
"""Initialise the context vars that stash_pending_tool_output needs."""
|
||||
set_execution_context(
|
||||
user_id="test",
|
||||
session=None, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
sdk_cwd="/tmp/test",
|
||||
)
|
||||
|
||||
def test_stash_and_pop(self):
|
||||
stash_pending_tool_output("my_tool", "output1")
|
||||
assert pop_pending_tool_output("my_tool") == "output1"
|
||||
|
||||
def test_pop_empty_returns_none(self):
|
||||
assert pop_pending_tool_output("nonexistent") is None
|
||||
|
||||
def test_fifo_order(self):
|
||||
stash_pending_tool_output("t", "first")
|
||||
stash_pending_tool_output("t", "second")
|
||||
assert pop_pending_tool_output("t") == "first"
|
||||
assert pop_pending_tool_output("t") == "second"
|
||||
assert pop_pending_tool_output("t") is None
|
||||
|
||||
def test_dict_serialised_to_json(self):
|
||||
stash_pending_tool_output("t", {"key": "value"})
|
||||
assert pop_pending_tool_output("t") == '{"key": "value"}'
|
||||
|
||||
def test_separate_tool_names(self):
|
||||
stash_pending_tool_output("a", "alpha")
|
||||
stash_pending_tool_output("b", "beta")
|
||||
assert pop_pending_tool_output("b") == "beta"
|
||||
assert pop_pending_tool_output("a") == "alpha"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _truncating wrapper (integration via create_copilot_mcp_server)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTruncationAndStashIntegration:
|
||||
"""Test truncation + stash behavior that _truncating relies on."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init_context(self):
|
||||
set_execution_context(
|
||||
user_id="test",
|
||||
session=None, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
sdk_cwd="/tmp/test",
|
||||
)
|
||||
|
||||
def test_small_output_stashed(self):
|
||||
"""Non-error output is stashed for the response adapter."""
|
||||
result = {
|
||||
"content": [{"type": "text", "text": "small output"}],
|
||||
"isError": False,
|
||||
}
|
||||
truncated = truncate(result, _MCP_MAX_CHARS)
|
||||
text = _text_from_mcp_result(truncated)
|
||||
assert text == "small output"
|
||||
stash_pending_tool_output("test_tool", text)
|
||||
assert pop_pending_tool_output("test_tool") == "small output"
|
||||
|
||||
def test_error_result_not_stashed(self):
|
||||
"""Error results should not be stashed."""
|
||||
result = {
|
||||
"content": [{"type": "text", "text": "error msg"}],
|
||||
"isError": True,
|
||||
}
|
||||
# _truncating only stashes when not result.get("isError")
|
||||
if not result.get("isError"):
|
||||
stash_pending_tool_output("err_tool", "should not happen")
|
||||
assert pop_pending_tool_output("err_tool") is None
|
||||
|
||||
def test_large_output_truncated(self):
|
||||
"""Output exceeding _MCP_MAX_CHARS is truncated before stashing."""
|
||||
big_text = "x" * (_MCP_MAX_CHARS + 100_000)
|
||||
result = {"content": [{"type": "text", "text": big_text}]}
|
||||
truncated = truncate(result, _MCP_MAX_CHARS)
|
||||
text = _text_from_mcp_result(truncated)
|
||||
assert len(text) < len(big_text)
|
||||
assert len(str(truncated)) <= _MCP_MAX_CHARS
|
||||
@@ -10,13 +10,14 @@ Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
|
||||
filesystem for self-hosted) — no DB column needed.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from backend.util import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
|
||||
@@ -58,41 +59,37 @@ def strip_progress_entries(content: str) -> str:
|
||||
Removes entries whose ``type`` is in ``STRIPPABLE_TYPES`` and reparents
|
||||
any remaining child entries so the ``parentUuid`` chain stays intact.
|
||||
Typically reduces transcript size by ~30%.
|
||||
|
||||
Entries that are not stripped or reparented are kept as their original
|
||||
raw JSON line to avoid unnecessary re-serialization that changes
|
||||
whitespace or key ordering.
|
||||
"""
|
||||
lines = content.strip().split("\n")
|
||||
|
||||
entries: list[dict] = []
|
||||
# Parse entries, keeping the original line alongside the parsed dict.
|
||||
parsed: list[tuple[str, dict | None]] = []
|
||||
for line in lines:
|
||||
try:
|
||||
entries.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
# Keep unparseable lines as-is (safety)
|
||||
entries.append({"_raw": line})
|
||||
parsed.append((line, json.loads(line, fallback=None)))
|
||||
|
||||
# First pass: identify stripped UUIDs and build parent map.
|
||||
stripped_uuids: set[str] = set()
|
||||
uuid_to_parent: dict[str, str] = {}
|
||||
kept: list[dict] = []
|
||||
|
||||
for entry in entries:
|
||||
if "_raw" in entry:
|
||||
kept.append(entry)
|
||||
for _line, entry in parsed:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
uid = entry.get("uuid", "")
|
||||
parent = entry.get("parentUuid", "")
|
||||
entry_type = entry.get("type", "")
|
||||
|
||||
if uid:
|
||||
uuid_to_parent[uid] = parent
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and uid:
|
||||
stripped_uuids.add(uid)
|
||||
|
||||
if entry_type in STRIPPABLE_TYPES:
|
||||
if uid:
|
||||
stripped_uuids.add(uid)
|
||||
else:
|
||||
kept.append(entry)
|
||||
|
||||
# Reparent: walk up chain through stripped entries to find surviving ancestor
|
||||
for entry in kept:
|
||||
if "_raw" in entry:
|
||||
# Second pass: keep non-stripped entries, reparenting where needed.
|
||||
# Preserve original line when no reparenting is required.
|
||||
reparented: set[str] = set()
|
||||
for _line, entry in parsed:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
parent = entry.get("parentUuid", "")
|
||||
original_parent = parent
|
||||
@@ -100,63 +97,32 @@ def strip_progress_entries(content: str) -> str:
|
||||
parent = uuid_to_parent.get(parent, "")
|
||||
if parent != original_parent:
|
||||
entry["parentUuid"] = parent
|
||||
uid = entry.get("uuid", "")
|
||||
if uid:
|
||||
reparented.add(uid)
|
||||
|
||||
result_lines: list[str] = []
|
||||
for entry in kept:
|
||||
if "_raw" in entry:
|
||||
result_lines.append(entry["_raw"])
|
||||
else:
|
||||
for line, entry in parsed:
|
||||
if not isinstance(entry, dict):
|
||||
result_lines.append(line)
|
||||
continue
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES:
|
||||
continue
|
||||
uid = entry.get("uuid", "")
|
||||
if uid in reparented:
|
||||
# Re-serialize only entries whose parentUuid was changed.
|
||||
result_lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
else:
|
||||
result_lines.append(line)
|
||||
|
||||
return "\n".join(result_lines) + "\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local file I/O (read from CLI's JSONL, write temp file for --resume)
|
||||
# Local file I/O (write temp file for --resume)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def read_transcript_file(transcript_path: str) -> str | None:
|
||||
"""Read a JSONL transcript file from disk.
|
||||
|
||||
Returns the raw JSONL content, or ``None`` if the file is missing, empty,
|
||||
or only contains metadata (≤2 lines with no conversation messages).
|
||||
"""
|
||||
if not transcript_path or not os.path.isfile(transcript_path):
|
||||
logger.debug(f"[Transcript] File not found: {transcript_path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(transcript_path) as f:
|
||||
content = f.read()
|
||||
|
||||
if not content.strip():
|
||||
logger.debug("[Transcript] File is empty: %s", transcript_path)
|
||||
return None
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
|
||||
# Validate that the transcript has real conversation content
|
||||
# (not just metadata like queue-operation entries).
|
||||
if not validate_transcript(content):
|
||||
logger.debug(
|
||||
"[Transcript] No conversation content (%d lines) in %s",
|
||||
len(lines),
|
||||
transcript_path,
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"[Transcript] Read {len(lines)} lines, "
|
||||
f"{len(content)} bytes from {transcript_path}"
|
||||
)
|
||||
return content
|
||||
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.warning(f"[Transcript] Failed to read {transcript_path}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
||||
"""Sanitize an ID for safe use in file paths.
|
||||
|
||||
@@ -171,14 +137,6 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
||||
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
||||
|
||||
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does.
|
||||
|
||||
The CLI replaces all non-alphanumeric characters with ``-``.
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""Remove the CLI's project directory for a specific working directory.
|
||||
|
||||
@@ -188,7 +146,8 @@ def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""
|
||||
import shutil
|
||||
|
||||
cwd_encoded = _encode_cwd_for_cli(sdk_cwd)
|
||||
# Encode cwd the same way CLI does (replaces non-alphanumeric with -)
|
||||
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
|
||||
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
|
||||
@@ -248,32 +207,29 @@ def write_transcript_to_tempfile(
|
||||
def validate_transcript(content: str | None) -> bool:
|
||||
"""Check that a transcript has actual conversation messages.
|
||||
|
||||
A valid transcript for resume needs at least one user message and one
|
||||
assistant message (not just queue-operation / file-history-snapshot
|
||||
metadata).
|
||||
A valid transcript needs at least one assistant message (not just
|
||||
queue-operation / file-history-snapshot metadata). We do NOT require
|
||||
a ``type: "user"`` entry because with ``--resume`` the user's message
|
||||
is passed as a CLI query parameter and does not appear in the
|
||||
transcript file.
|
||||
"""
|
||||
if not content or not content.strip():
|
||||
return False
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
if len(lines) < 2:
|
||||
return False
|
||||
|
||||
has_user = False
|
||||
has_assistant = False
|
||||
|
||||
for line in lines:
|
||||
try:
|
||||
entry = json.loads(line)
|
||||
msg_type = entry.get("type")
|
||||
if msg_type == "user":
|
||||
has_user = True
|
||||
elif msg_type == "assistant":
|
||||
has_assistant = True
|
||||
except json.JSONDecodeError:
|
||||
if not line.strip():
|
||||
continue
|
||||
entry = json.loads(line, fallback=None)
|
||||
if not isinstance(entry, dict):
|
||||
return False
|
||||
if entry.get("type") == "assistant":
|
||||
has_assistant = True
|
||||
|
||||
return has_user and has_assistant
|
||||
return has_assistant
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -328,45 +284,46 @@ async def upload_transcript(
|
||||
session_id: str,
|
||||
content: str,
|
||||
message_count: int = 0,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> None:
|
||||
"""Strip progress entries and upload transcript to bucket storage.
|
||||
"""Strip progress entries and upload complete transcript.
|
||||
|
||||
Safety: only overwrites when the new (stripped) transcript is larger than
|
||||
what is already stored. Since JSONL is append-only, the latest transcript
|
||||
is always the longest. This prevents a slow/stale background task from
|
||||
clobbering a newer upload from a concurrent turn.
|
||||
The transcript represents the FULL active context (atomic).
|
||||
Each upload REPLACES the previous transcript entirely.
|
||||
|
||||
The executor holds a cluster lock per session, so concurrent uploads for
|
||||
the same session cannot happen.
|
||||
|
||||
Args:
|
||||
message_count: ``len(session.messages)`` at upload time — used by
|
||||
the next turn to detect staleness and compress only the gap.
|
||||
content: Complete JSONL transcript (from TranscriptBuilder).
|
||||
message_count: ``len(session.messages)`` at upload time.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
# Strip metadata entries (progress, file-history-snapshot, etc.)
|
||||
# Note: SDK-built transcripts shouldn't have these, but strip for safety
|
||||
stripped = strip_progress_entries(content)
|
||||
if not validate_transcript(stripped):
|
||||
# Log entry types for debugging — helps identify why validation failed
|
||||
entry_types: list[str] = []
|
||||
for line in stripped.strip().split("\n"):
|
||||
entry = json.loads(line, fallback={"type": "INVALID_JSON"})
|
||||
entry_types.append(entry.get("type", "?"))
|
||||
logger.warning(
|
||||
f"[Transcript] Skipping upload — stripped content not valid "
|
||||
f"for session {session_id}"
|
||||
"%s Skipping upload — stripped content not valid "
|
||||
"(types=%s, stripped_len=%d, raw_len=%d)",
|
||||
log_prefix,
|
||||
entry_types,
|
||||
len(stripped),
|
||||
len(content),
|
||||
)
|
||||
logger.debug("%s Raw content preview: %s", log_prefix, content[:500])
|
||||
logger.debug("%s Stripped content: %s", log_prefix, stripped[:500])
|
||||
return
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
wid, fid, fname = _storage_path_parts(user_id, session_id)
|
||||
encoded = stripped.encode("utf-8")
|
||||
new_size = len(encoded)
|
||||
|
||||
# Check existing transcript size to avoid overwriting newer with older
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
try:
|
||||
existing = await storage.retrieve(path)
|
||||
if len(existing) >= new_size:
|
||||
logger.info(
|
||||
f"[Transcript] Skipping upload — existing ({len(existing)}B) "
|
||||
f">= new ({new_size}B) for session {session_id}"
|
||||
)
|
||||
return
|
||||
except (FileNotFoundError, Exception):
|
||||
pass # No existing transcript or retrieval error — proceed with upload
|
||||
|
||||
await storage.store(
|
||||
workspace_id=wid,
|
||||
@@ -375,11 +332,8 @@ async def upload_transcript(
|
||||
content=encoded,
|
||||
)
|
||||
|
||||
# Store metadata alongside the transcript so the next turn can detect
|
||||
# staleness and only compress the gap instead of the full history.
|
||||
# Wrapped in try/except so a metadata write failure doesn't orphan
|
||||
# the already-uploaded transcript — the next turn will just fall back
|
||||
# to full gap fill (msg_count=0).
|
||||
# Update metadata so message_count stays current. The gap-fill logic
|
||||
# in _build_query_message relies on it to avoid re-compressing messages.
|
||||
try:
|
||||
meta = {"message_count": message_count, "uploaded_at": time.time()}
|
||||
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
|
||||
@@ -390,17 +344,18 @@ async def upload_transcript(
|
||||
content=json.dumps(meta).encode("utf-8"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transcript] Failed to write metadata for {session_id}: {e}")
|
||||
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
|
||||
|
||||
logger.info(
|
||||
f"[Transcript] Uploaded {new_size}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count}) "
|
||||
f"for session {session_id}"
|
||||
f"{log_prefix} Uploaded {len(encoded)}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count})"
|
||||
)
|
||||
|
||||
|
||||
async def download_transcript(
|
||||
user_id: str, session_id: str
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> TranscriptDownload | None:
|
||||
"""Download transcript and metadata from bucket storage.
|
||||
|
||||
@@ -416,10 +371,10 @@ async def download_transcript(
|
||||
data = await storage.retrieve(path)
|
||||
content = data.decode("utf-8")
|
||||
except FileNotFoundError:
|
||||
logger.debug(f"[Transcript] No transcript in storage for {session_id}")
|
||||
logger.debug(f"{log_prefix} No transcript in storage")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transcript] Failed to download transcript: {e}")
|
||||
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
|
||||
return None
|
||||
|
||||
# Try to load metadata (best-effort — old transcripts won't have it)
|
||||
@@ -436,16 +391,13 @@ async def download_transcript(
|
||||
meta_path = f"local://{mwid}/{mfid}/{mfname}"
|
||||
|
||||
meta_data = await storage.retrieve(meta_path)
|
||||
meta = json.loads(meta_data.decode("utf-8"))
|
||||
meta = json.loads(meta_data.decode("utf-8"), fallback={})
|
||||
message_count = meta.get("message_count", 0)
|
||||
uploaded_at = meta.get("uploaded_at", 0.0)
|
||||
except (FileNotFoundError, json.JSONDecodeError, Exception):
|
||||
except (FileNotFoundError, Exception):
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
|
||||
logger.info(
|
||||
f"[Transcript] Downloaded {len(content)}B "
|
||||
f"(msg_count={message_count}) for session {session_id}"
|
||||
)
|
||||
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
|
||||
return TranscriptDownload(
|
||||
content=content,
|
||||
message_count=message_count,
|
||||
|
||||
@@ -0,0 +1,188 @@
|
||||
"""Build complete JSONL transcript from SDK messages.
|
||||
|
||||
The transcript represents the FULL active context at any point in time.
|
||||
Each upload REPLACES the previous transcript atomically.
|
||||
|
||||
Flow:
|
||||
Turn 1: Upload [msg1, msg2]
|
||||
Turn 2: Download [msg1, msg2] → Upload [msg1, msg2, msg3, msg4] (REPLACE)
|
||||
Turn 3: Download [msg1, msg2, msg3, msg4] → Upload [all messages] (REPLACE)
|
||||
|
||||
The transcript is never incremental - always the complete atomic state.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .transcript import STRIPPABLE_TYPES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TranscriptEntry(BaseModel):
|
||||
"""Single transcript entry (user or assistant turn)."""
|
||||
|
||||
type: str
|
||||
uuid: str
|
||||
parentUuid: str | None
|
||||
message: dict[str, Any]
|
||||
|
||||
|
||||
class TranscriptBuilder:
|
||||
"""Build complete JSONL transcript from SDK messages.
|
||||
|
||||
This builder maintains the FULL conversation state, not incremental changes.
|
||||
The output is always the complete active context.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._entries: list[TranscriptEntry] = []
|
||||
self._last_uuid: str | None = None
|
||||
|
||||
def _last_is_assistant(self) -> bool:
|
||||
return bool(self._entries) and self._entries[-1].type == "assistant"
|
||||
|
||||
def _last_message_id(self) -> str:
|
||||
"""Return the message.id of the last entry, or '' if none."""
|
||||
if self._entries:
|
||||
return self._entries[-1].message.get("id", "")
|
||||
return ""
|
||||
|
||||
def load_previous(self, content: str, log_prefix: str = "[Transcript]") -> None:
|
||||
"""Load complete previous transcript.
|
||||
|
||||
This loads the FULL previous context. As new messages come in,
|
||||
we append to this state. The final output is the complete context
|
||||
(previous + new), not just the delta.
|
||||
"""
|
||||
if not content or not content.strip():
|
||||
return
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
for line_num, line in enumerate(lines, 1):
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
data = json.loads(line, fallback=None)
|
||||
if data is None:
|
||||
logger.warning(
|
||||
"%s Failed to parse transcript line %d/%d",
|
||||
log_prefix,
|
||||
line_num,
|
||||
len(lines),
|
||||
)
|
||||
continue
|
||||
|
||||
# Load all non-strippable entries (user/assistant/system/etc.)
|
||||
# Skip only STRIPPABLE_TYPES to match strip_progress_entries() behavior
|
||||
entry_type = data.get("type", "")
|
||||
if entry_type in STRIPPABLE_TYPES:
|
||||
continue
|
||||
|
||||
entry = TranscriptEntry(
|
||||
type=data["type"],
|
||||
uuid=data.get("uuid") or str(uuid4()),
|
||||
parentUuid=data.get("parentUuid"),
|
||||
message=data.get("message", {}),
|
||||
)
|
||||
self._entries.append(entry)
|
||||
self._last_uuid = entry.uuid
|
||||
|
||||
logger.info(
|
||||
"%s Loaded %d entries from previous transcript (last_uuid=%s)",
|
||||
log_prefix,
|
||||
len(self._entries),
|
||||
self._last_uuid[:12] if self._last_uuid else None,
|
||||
)
|
||||
|
||||
def append_user(self, content: str | list[dict], uuid: str | None = None) -> None:
|
||||
"""Append a user entry."""
|
||||
msg_uuid = uuid or str(uuid4())
|
||||
|
||||
self._entries.append(
|
||||
TranscriptEntry(
|
||||
type="user",
|
||||
uuid=msg_uuid,
|
||||
parentUuid=self._last_uuid,
|
||||
message={"role": "user", "content": content},
|
||||
)
|
||||
)
|
||||
self._last_uuid = msg_uuid
|
||||
|
||||
def append_tool_result(self, tool_use_id: str, content: str) -> None:
|
||||
"""Append a tool result as a user entry (one per tool call)."""
|
||||
self.append_user(
|
||||
content=[
|
||||
{"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
|
||||
]
|
||||
)
|
||||
|
||||
def append_assistant(
|
||||
self,
|
||||
content_blocks: list[dict],
|
||||
model: str = "",
|
||||
stop_reason: str | None = None,
|
||||
) -> None:
|
||||
"""Append an assistant entry.
|
||||
|
||||
Consecutive assistant entries automatically share the same message ID
|
||||
so the CLI can merge them (thinking → text → tool_use) into a single
|
||||
API message on ``--resume``. A new ID is assigned whenever an
|
||||
assistant entry follows a non-assistant entry (user message or tool
|
||||
result), because that marks the start of a new API response.
|
||||
"""
|
||||
message_id = (
|
||||
self._last_message_id()
|
||||
if self._last_is_assistant()
|
||||
else f"msg_sdk_{uuid4().hex[:24]}"
|
||||
)
|
||||
|
||||
msg_uuid = str(uuid4())
|
||||
|
||||
self._entries.append(
|
||||
TranscriptEntry(
|
||||
type="assistant",
|
||||
uuid=msg_uuid,
|
||||
parentUuid=self._last_uuid,
|
||||
message={
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"id": message_id,
|
||||
"type": "message",
|
||||
"content": content_blocks,
|
||||
"stop_reason": stop_reason,
|
||||
"stop_sequence": None,
|
||||
},
|
||||
)
|
||||
)
|
||||
self._last_uuid = msg_uuid
|
||||
|
||||
def to_jsonl(self) -> str:
|
||||
"""Export complete context as JSONL.
|
||||
|
||||
Consecutive assistant entries are kept separate to match the
|
||||
native CLI format — the SDK merges them internally on resume.
|
||||
|
||||
Returns the FULL conversation state (all entries), not incremental.
|
||||
This output REPLACES any previous transcript.
|
||||
"""
|
||||
if not self._entries:
|
||||
return ""
|
||||
|
||||
lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
@property
|
||||
def entry_count(self) -> int:
|
||||
"""Total number of entries in the complete context."""
|
||||
return len(self._entries)
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
"""Whether this builder has any entries."""
|
||||
return len(self._entries) == 0
|
||||
@@ -1,11 +1,11 @@
|
||||
"""Unit tests for JSONL transcript management utilities."""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .transcript import (
|
||||
STRIPPABLE_TYPES,
|
||||
read_transcript_file,
|
||||
strip_progress_entries,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
@@ -38,49 +38,6 @@ PROGRESS_ENTRY = {
|
||||
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
|
||||
|
||||
|
||||
# --- read_transcript_file ---
|
||||
|
||||
|
||||
class TestReadTranscriptFile:
|
||||
def test_returns_content_for_valid_file(self, tmp_path):
|
||||
path = tmp_path / "session.jsonl"
|
||||
path.write_text(VALID_TRANSCRIPT)
|
||||
result = read_transcript_file(str(path))
|
||||
assert result is not None
|
||||
assert "user" in result
|
||||
|
||||
def test_returns_none_for_missing_file(self):
|
||||
assert read_transcript_file("/nonexistent/path.jsonl") is None
|
||||
|
||||
def test_returns_none_for_empty_path(self):
|
||||
assert read_transcript_file("") is None
|
||||
|
||||
def test_returns_none_for_empty_file(self, tmp_path):
|
||||
path = tmp_path / "empty.jsonl"
|
||||
path.write_text("")
|
||||
assert read_transcript_file(str(path)) is None
|
||||
|
||||
def test_returns_none_for_metadata_only(self, tmp_path):
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
|
||||
path = tmp_path / "meta.jsonl"
|
||||
path.write_text(content)
|
||||
assert read_transcript_file(str(path)) is None
|
||||
|
||||
def test_returns_none_for_invalid_json(self, tmp_path):
|
||||
path = tmp_path / "bad.jsonl"
|
||||
path.write_text("not json\n{}\n{}\n")
|
||||
assert read_transcript_file(str(path)) is None
|
||||
|
||||
def test_no_size_limit(self, tmp_path):
|
||||
"""Large files are accepted — bucket storage has no size limit."""
|
||||
big_content = {"type": "user", "uuid": "u9", "data": "x" * 1_000_000}
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, big_content, ASST_MSG)
|
||||
path = tmp_path / "big.jsonl"
|
||||
path.write_text(content)
|
||||
result = read_transcript_file(str(path))
|
||||
assert result is not None
|
||||
|
||||
|
||||
# --- write_transcript_to_tempfile ---
|
||||
|
||||
|
||||
@@ -155,12 +112,56 @@ class TestValidateTranscript:
|
||||
assert validate_transcript(content) is False
|
||||
|
||||
def test_assistant_only_no_user(self):
|
||||
"""With --resume the user message is a CLI query param, not a transcript entry.
|
||||
A transcript with only assistant entries is valid."""
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, ASST_MSG)
|
||||
assert validate_transcript(content) is False
|
||||
assert validate_transcript(content) is True
|
||||
|
||||
def test_resume_transcript_without_user_entry(self):
|
||||
"""Simulates a real --resume stop hook transcript: the CLI session file
|
||||
has summary + assistant entries but no user entry."""
|
||||
summary = {"type": "summary", "uuid": "s1", "text": "context..."}
|
||||
asst1 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"message": {"role": "assistant", "content": "Hello!"},
|
||||
}
|
||||
asst2 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a2",
|
||||
"parentUuid": "a1",
|
||||
"message": {"role": "assistant", "content": "Sure, let me help."},
|
||||
}
|
||||
content = _make_jsonl(summary, asst1, asst2)
|
||||
assert validate_transcript(content) is True
|
||||
|
||||
def test_single_assistant_entry(self):
|
||||
"""A transcript with just one assistant line is valid — the CLI may
|
||||
produce short transcripts for simple responses with no tool use."""
|
||||
content = json.dumps(ASST_MSG) + "\n"
|
||||
assert validate_transcript(content) is True
|
||||
|
||||
def test_invalid_json_returns_false(self):
|
||||
assert validate_transcript("not json\n{}\n{}\n") is False
|
||||
|
||||
def test_malformed_json_after_valid_assistant_returns_false(self):
|
||||
"""Validation must scan all lines - malformed JSON anywhere should fail."""
|
||||
valid_asst = json.dumps(ASST_MSG)
|
||||
malformed = "not valid json"
|
||||
content = valid_asst + "\n" + malformed + "\n"
|
||||
assert validate_transcript(content) is False
|
||||
|
||||
def test_blank_lines_are_skipped(self):
|
||||
"""Transcripts with blank lines should be valid if they contain assistant entries."""
|
||||
content = (
|
||||
json.dumps(USER_MSG)
|
||||
+ "\n\n" # blank line
|
||||
+ json.dumps(ASST_MSG)
|
||||
+ "\n"
|
||||
+ "\n" # another blank line
|
||||
)
|
||||
assert validate_transcript(content) is True
|
||||
|
||||
|
||||
# --- strip_progress_entries ---
|
||||
|
||||
@@ -253,3 +254,31 @@ class TestStripProgressEntries:
|
||||
assert "queue-operation" not in result_types
|
||||
assert "user" in result_types
|
||||
assert "assistant" in result_types
|
||||
|
||||
def test_preserves_original_line_formatting(self):
|
||||
"""Non-reparented entries keep their original JSON formatting."""
|
||||
# orjson produces compact JSON - test that we preserve the exact input
|
||||
# when no reparenting is needed (no re-serialization)
|
||||
original_line = json.dumps(USER_MSG)
|
||||
|
||||
content = original_line + "\n" + json.dumps(ASST_MSG) + "\n"
|
||||
result = strip_progress_entries(content)
|
||||
result_lines = result.strip().split("\n")
|
||||
|
||||
# Original line should be byte-identical (not re-serialized)
|
||||
assert result_lines[0] == original_line
|
||||
|
||||
def test_reparented_entries_are_reserialized(self):
|
||||
"""Entries whose parentUuid changes must be re-serialized."""
|
||||
progress = {"type": "progress", "uuid": "p1", "parentUuid": "u1"}
|
||||
asst = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "p1",
|
||||
"message": {"role": "assistant", "content": "done"},
|
||||
}
|
||||
content = _make_jsonl(USER_MSG, progress, asst)
|
||||
result = strip_progress_entries(content)
|
||||
lines = result.strip().split("\n")
|
||||
asst_entry = json.loads(lines[-1])
|
||||
assert asst_entry["parentUuid"] == "u1" # reparented
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,75 +4,14 @@ from os import getenv
|
||||
|
||||
import pytest
|
||||
|
||||
from . import service as chat_service
|
||||
from .model import create_chat_session, get_chat_session, upsert_chat_session
|
||||
from .response_model import StreamError, StreamTextDelta, StreamToolOutputAvailable
|
||||
from .response_model import StreamError, StreamTextDelta
|
||||
from .sdk import service as sdk_service
|
||||
from .sdk.transcript import download_transcript
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
|
||||
has_errors = False
|
||||
assistant_message = ""
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session.session_id, "Hello, how are you?", user_id=session.user_id
|
||||
):
|
||||
logger.info(chunk)
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
if isinstance(chunk, StreamTextDelta):
|
||||
assistant_message += chunk.delta
|
||||
|
||||
# StreamFinish is published by mark_session_completed (processor layer),
|
||||
# not by the service. The generator completing means the stream ended.
|
||||
assert not has_errors, "Error occurred while streaming chat completion"
|
||||
assert assistant_message, "Assistant message is empty"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
has_errors = False
|
||||
had_tool_calls = False
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session.session_id,
|
||||
"Please find me an agent that can help me with my business. Use the query 'moneny printing agent'",
|
||||
user_id=session.user_id,
|
||||
):
|
||||
logger.info(chunk)
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
if isinstance(chunk, StreamToolOutputAvailable):
|
||||
had_tool_calls = True
|
||||
|
||||
assert not has_errors, "Error occurred while streaming chat completion"
|
||||
assert had_tool_calls, "Tool calls did not occur"
|
||||
session = await get_chat_session(session.session_id)
|
||||
assert session, "Session not found"
|
||||
assert session.usage, "Usage is empty"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
|
||||
"""Test that the SDK --resume path captures and uses transcripts across turns.
|
||||
|
||||
@@ -733,7 +733,10 @@ async def mark_session_completed(
|
||||
# This is the SINGLE place that publishes StreamFinish — services and
|
||||
# the processor must NOT publish it themselves.
|
||||
try:
|
||||
await publish_chunk(turn_id, StreamFinish())
|
||||
await publish_chunk(
|
||||
turn_id,
|
||||
StreamFinish(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to publish StreamFinish for session {session_id}: {e}. "
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tracking import track_tool_called
|
||||
|
||||
from .add_understanding import AddUnderstandingTool
|
||||
from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreenshotTool
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
from .browse_web import BrowseWebTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
from .edit_agent import EditAgentTool
|
||||
@@ -19,8 +20,17 @@ from .find_agent import FindAgentTool
|
||||
from .find_block import FindBlockTool
|
||||
from .find_library_agent import FindLibraryAgentTool
|
||||
from .get_doc_page import GetDocPageTool
|
||||
from .manage_folders import (
|
||||
CreateFolderTool,
|
||||
DeleteFolderTool,
|
||||
ListFoldersTool,
|
||||
MoveAgentsToFolderTool,
|
||||
MoveFolderTool,
|
||||
UpdateFolderTool,
|
||||
)
|
||||
from .run_agent import RunAgentTool
|
||||
from .run_block import RunBlockTool
|
||||
from .run_mcp_tool import RunMCPToolTool
|
||||
from .search_docs import SearchDocsTool
|
||||
from .web_fetch import WebFetchTool
|
||||
from .workspace_files import (
|
||||
@@ -31,6 +41,7 @@ from .workspace_files import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -44,15 +55,25 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"find_agent": FindAgentTool(),
|
||||
"find_block": FindBlockTool(),
|
||||
"find_library_agent": FindLibraryAgentTool(),
|
||||
# Folder management tools
|
||||
"create_folder": CreateFolderTool(),
|
||||
"list_folders": ListFoldersTool(),
|
||||
"update_folder": UpdateFolderTool(),
|
||||
"move_folder": MoveFolderTool(),
|
||||
"delete_folder": DeleteFolderTool(),
|
||||
"move_agents_to_folder": MoveAgentsToFolderTool(),
|
||||
"run_agent": RunAgentTool(),
|
||||
"run_block": RunBlockTool(),
|
||||
"run_mcp_tool": RunMCPToolTool(),
|
||||
"view_agent_output": AgentOutputTool(),
|
||||
"search_docs": SearchDocsTool(),
|
||||
"get_doc_page": GetDocPageTool(),
|
||||
# Web fetch for safe URL retrieval
|
||||
"web_fetch": WebFetchTool(),
|
||||
# Browser-based browsing for JS-rendered pages (Stagehand + Browserbase)
|
||||
"browse_web": BrowseWebTool(),
|
||||
# Agent-browser multi-step automation (navigate, act, screenshot)
|
||||
"browser_navigate": BrowserNavigateTool(),
|
||||
"browser_act": BrowserActTool(),
|
||||
"browser_screenshot": BrowserScreenshotTool(),
|
||||
# Sandboxed code execution (bubblewrap)
|
||||
"bash_exec": BashExecTool(),
|
||||
# Persistent workspace tools (cloud storage, survives across sessions)
|
||||
@@ -70,10 +91,17 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
find_agent_tool = TOOL_REGISTRY["find_agent"]
|
||||
run_agent_tool = TOOL_REGISTRY["run_agent"]
|
||||
|
||||
# Generated from registry for OpenAI API
|
||||
tools: list[ChatCompletionToolParam] = [
|
||||
tool.as_openai_tool() for tool in TOOL_REGISTRY.values()
|
||||
]
|
||||
|
||||
def get_available_tools() -> list[ChatCompletionToolParam]:
|
||||
"""Return OpenAI tool schemas for tools available in the current environment.
|
||||
|
||||
Called per-request so that env-var or binary availability is evaluated
|
||||
fresh each time (e.g. browser_* tools are excluded when agent-browser
|
||||
CLI is not installed).
|
||||
"""
|
||||
return [
|
||||
tool.as_openai_tool() for tool in TOOL_REGISTRY.values() if tool.is_available
|
||||
]
|
||||
|
||||
|
||||
def get_tool(tool_name: str) -> BaseTool | None:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from os import getenv
|
||||
@@ -12,12 +13,34 @@ from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data import db as db_module
|
||||
from backend.data.db import prisma
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
from backend.data.model import APIKeyCredentials
|
||||
from backend.data.user import get_or_create_user
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _ensure_db_connected() -> None:
|
||||
"""Ensure the Prisma connection is alive on the current event loop.
|
||||
|
||||
On Python 3.11, the httpx transport inside Prisma can reference a stale
|
||||
(closed) event loop when session-scoped async fixtures are evaluated long
|
||||
after the initial ``server`` fixture connected Prisma. A cheap health-check
|
||||
followed by a reconnect fixes this without affecting other fixtures.
|
||||
"""
|
||||
try:
|
||||
await prisma.query_raw("SELECT 1")
|
||||
except Exception:
|
||||
_logger.info("Prisma connection stale – reconnecting")
|
||||
try:
|
||||
await db_module.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
await db_module.connect()
|
||||
|
||||
|
||||
def make_session(user_id: str):
|
||||
return ChatSession(
|
||||
@@ -43,6 +66,8 @@ async def setup_test_data(server):
|
||||
|
||||
Depends on ``server`` to ensure Prisma is connected.
|
||||
"""
|
||||
await _ensure_db_connected()
|
||||
|
||||
# 1. Create a test user
|
||||
user_data = {
|
||||
"sub": f"test-user-{uuid.uuid4()}",
|
||||
@@ -126,8 +151,8 @@ async def setup_test_data(server):
|
||||
unique_slug = f"test-agent-{str(uuid.uuid4())[:8]}"
|
||||
store_submission = await store_db.create_store_submission(
|
||||
user_id=user.id,
|
||||
agent_id=created_graph.id,
|
||||
agent_version=created_graph.version,
|
||||
graph_id=created_graph.id,
|
||||
graph_version=created_graph.version,
|
||||
slug=unique_slug,
|
||||
name="Test Agent",
|
||||
description="A simple test agent",
|
||||
@@ -136,10 +161,10 @@ async def setup_test_data(server):
|
||||
image_urls=["https://example.com/image.jpg"],
|
||||
)
|
||||
|
||||
assert store_submission.store_listing_version_id is not None
|
||||
assert store_submission.listing_version_id is not None
|
||||
# 4. Approve the store listing version
|
||||
await store_db.review_store_submission(
|
||||
store_listing_version_id=store_submission.store_listing_version_id,
|
||||
store_listing_version_id=store_submission.listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Approved for testing",
|
||||
internal_comments="Test approval",
|
||||
@@ -164,6 +189,8 @@ async def setup_llm_test_data(server):
|
||||
|
||||
Depends on ``server`` to ensure Prisma is connected.
|
||||
"""
|
||||
await _ensure_db_connected()
|
||||
|
||||
key = getenv("OPENAI_API_KEY")
|
||||
if not key:
|
||||
return pytest.skip("OPENAI_API_KEY is not set")
|
||||
@@ -294,8 +321,8 @@ async def setup_llm_test_data(server):
|
||||
unique_slug = f"llm-test-agent-{str(uuid.uuid4())[:8]}"
|
||||
store_submission = await store_db.create_store_submission(
|
||||
user_id=user.id,
|
||||
agent_id=created_graph.id,
|
||||
agent_version=created_graph.version,
|
||||
graph_id=created_graph.id,
|
||||
graph_version=created_graph.version,
|
||||
slug=unique_slug,
|
||||
name="LLM Test Agent",
|
||||
description="An agent with LLM capabilities",
|
||||
@@ -303,9 +330,9 @@ async def setup_llm_test_data(server):
|
||||
categories=["testing", "ai"],
|
||||
image_urls=["https://example.com/image.jpg"],
|
||||
)
|
||||
assert store_submission.store_listing_version_id is not None
|
||||
assert store_submission.listing_version_id is not None
|
||||
await store_db.review_store_submission(
|
||||
store_listing_version_id=store_submission.store_listing_version_id,
|
||||
store_listing_version_id=store_submission.listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Approved for testing",
|
||||
internal_comments="Test approval for LLM agent",
|
||||
@@ -330,6 +357,8 @@ async def setup_firecrawl_test_data(server):
|
||||
|
||||
Depends on ``server`` to ensure Prisma is connected.
|
||||
"""
|
||||
await _ensure_db_connected()
|
||||
|
||||
# 1. Create a test user
|
||||
user_data = {
|
||||
"sub": f"test-user-{uuid.uuid4()}",
|
||||
@@ -447,8 +476,8 @@ async def setup_firecrawl_test_data(server):
|
||||
unique_slug = f"firecrawl-test-agent-{str(uuid.uuid4())[:8]}"
|
||||
store_submission = await store_db.create_store_submission(
|
||||
user_id=user.id,
|
||||
agent_id=created_graph.id,
|
||||
agent_version=created_graph.version,
|
||||
graph_id=created_graph.id,
|
||||
graph_version=created_graph.version,
|
||||
slug=unique_slug,
|
||||
name="Firecrawl Test Agent",
|
||||
description="An agent with Firecrawl integration (no credentials)",
|
||||
@@ -456,9 +485,9 @@ async def setup_firecrawl_test_data(server):
|
||||
categories=["testing", "scraping"],
|
||||
image_urls=["https://example.com/image.jpg"],
|
||||
)
|
||||
assert store_submission.store_listing_version_id is not None
|
||||
assert store_submission.listing_version_id is not None
|
||||
await store_db.review_store_submission(
|
||||
store_listing_version_id=store_submission.store_listing_version_id,
|
||||
store_listing_version_id=store_submission.listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Approved for testing",
|
||||
internal_comments="Test approval for Firecrawl agent",
|
||||
|
||||
876
autogpt_platform/backend/backend/copilot/tools/agent_browser.py
Normal file
876
autogpt_platform/backend/backend/copilot/tools/agent_browser.py
Normal file
@@ -0,0 +1,876 @@
|
||||
"""Agent-browser tools — multi-step browser automation for the Copilot.
|
||||
|
||||
Uses the agent-browser CLI (https://github.com/vercel-labs/agent-browser)
|
||||
which runs a local Chromium instance managed by a persistent daemon.
|
||||
|
||||
- Runs locally — no cloud account required
|
||||
- Full interaction support: click, fill, scroll, login flows, multi-step
|
||||
- Session persistence via --session-name: cookies/auth carry across tool calls
|
||||
within the same Copilot session, enabling login → navigate → extract workflows
|
||||
- Screenshot with --annotate overlays @ref labels, saved to workspace for user
|
||||
- The Claude Agent SDK's multi-turn loop handles orchestration — each tool call
|
||||
is one browser action; the LLM chains them naturally
|
||||
|
||||
SSRF protection:
|
||||
Uses the shared validate_url() from backend.util.request, which is the same
|
||||
guard used by HTTP blocks and web_fetch. It resolves ALL DNS answers (not just
|
||||
the first), blocks RFC 1918, loopback, link-local, 0.0.0.0/8, multicast,
|
||||
and all relevant IPv6 ranges, and applies IDNA encoding to prevent Unicode
|
||||
domain attacks.
|
||||
|
||||
Requires:
|
||||
npm install -g agent-browser
|
||||
agent-browser install (downloads Chromium, one-time per machine)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.util.request import validate_url
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
BrowserActResponse,
|
||||
BrowserNavigateResponse,
|
||||
BrowserScreenshotResponse,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .workspace_files import get_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Per-command timeout (seconds). Navigation + networkidle wait can be slow.
|
||||
_CMD_TIMEOUT = 45
|
||||
# Accessibility tree can be very large; cap it to keep LLM context manageable.
|
||||
_MAX_SNAPSHOT_CHARS = 20_000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subprocess helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _run(
|
||||
session_name: str,
|
||||
*args: str,
|
||||
timeout: int = _CMD_TIMEOUT,
|
||||
) -> tuple[int, str, str]:
|
||||
"""Run agent-browser for the given session and return (rc, stdout, stderr).
|
||||
|
||||
Uses both:
|
||||
--session <name> → isolated Chromium context (no shared history/cookies
|
||||
with other Copilot sessions — prevents cross-session
|
||||
browser state leakage)
|
||||
--session-name <name> → persist cookies/localStorage across tool calls within
|
||||
the same session (enables login → navigate flows)
|
||||
"""
|
||||
cmd = [
|
||||
"agent-browser",
|
||||
"--session",
|
||||
session_name,
|
||||
"--session-name",
|
||||
session_name,
|
||||
*args,
|
||||
]
|
||||
proc = None
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout)
|
||||
return proc.returncode or 0, stdout.decode(), stderr.decode()
|
||||
except asyncio.TimeoutError:
|
||||
# Kill the orphaned subprocess so it does not linger in the process table.
|
||||
if proc is not None and proc.returncode is None:
|
||||
proc.kill()
|
||||
try:
|
||||
await proc.communicate()
|
||||
except Exception:
|
||||
pass # Best-effort reap; ignore errors during cleanup.
|
||||
return 1, "", f"Command timed out after {timeout}s."
|
||||
except FileNotFoundError:
|
||||
return (
|
||||
1,
|
||||
"",
|
||||
"agent-browser is not installed (run: npm install -g agent-browser && agent-browser install).",
|
||||
)
|
||||
|
||||
|
||||
async def _snapshot(session_name: str) -> str:
|
||||
"""Return the current page's interactive accessibility tree, truncated."""
|
||||
rc, stdout, stderr = await _run(session_name, "snapshot", "-i", "-c")
|
||||
if rc != 0:
|
||||
return f"[snapshot failed: {stderr[:300]}]"
|
||||
text = stdout.strip()
|
||||
if len(text) > _MAX_SNAPSHOT_CHARS:
|
||||
suffix = "\n\n[Snapshot truncated — use browser_act to navigate further]"
|
||||
keep = max(0, _MAX_SNAPSHOT_CHARS - len(suffix))
|
||||
text = text[:keep] + suffix
|
||||
return text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stateless session helpers — persist / restore browser state across pods
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Module-level cache of sessions known to be alive on this pod.
|
||||
# Avoids the subprocess probe on every tool call within the same pod.
|
||||
_alive_sessions: set[str] = set()
|
||||
|
||||
# Per-session locks to prevent concurrent _ensure_session calls from
|
||||
# triggering duplicate _restore_browser_state for the same session.
|
||||
# Protected by _session_locks_mutex to ensure setdefault/pop are not
|
||||
# interleaved across await boundaries.
|
||||
_session_locks: dict[str, asyncio.Lock] = {}
|
||||
_session_locks_mutex = asyncio.Lock()
|
||||
|
||||
# Workspace filename for persisted browser state (auto-scoped to session).
|
||||
# Dot-prefixed so it is hidden from user workspace listings.
|
||||
_STATE_FILENAME = "._browser_state.json"
|
||||
|
||||
# Maximum concurrent subprocesses during cookie/storage restore.
|
||||
_RESTORE_CONCURRENCY = 10
|
||||
|
||||
# Maximum cookies to restore per session. Pathological sites can accumulate
|
||||
# thousands of cookies; restoring them all would be slow and is rarely useful.
|
||||
_MAX_RESTORE_COOKIES = 100
|
||||
|
||||
# Background tasks for fire-and-forget state persistence.
|
||||
# Prevents GC from collecting tasks before they complete.
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
def _fire_and_forget_save(
|
||||
session_name: str, user_id: str, session: ChatSession
|
||||
) -> None:
|
||||
"""Schedule state persistence as a background task (non-blocking).
|
||||
|
||||
State save is already best-effort (errors are swallowed), so running it
|
||||
in the background avoids adding latency to tool responses.
|
||||
"""
|
||||
task = asyncio.create_task(_save_browser_state(session_name, user_id, session))
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
|
||||
async def _has_local_session(session_name: str) -> bool:
|
||||
"""Check if the local agent-browser daemon for this session is running."""
|
||||
rc, _, _ = await _run(session_name, "get", "url", timeout=5)
|
||||
return rc == 0
|
||||
|
||||
|
||||
async def _save_browser_state(
|
||||
session_name: str, user_id: str, session: ChatSession
|
||||
) -> None:
|
||||
"""Persist browser state (cookies, localStorage, URL) to workspace.
|
||||
|
||||
Best-effort: errors are logged but never propagate to the tool response.
|
||||
"""
|
||||
try:
|
||||
# Gather state in parallel
|
||||
(rc_url, url_out, _), (rc_ck, ck_out, _), (rc_ls, ls_out, _) = (
|
||||
await asyncio.gather(
|
||||
_run(session_name, "get", "url", timeout=10),
|
||||
_run(session_name, "cookies", "get", "--json", timeout=10),
|
||||
_run(session_name, "storage", "local", "--json", timeout=10),
|
||||
)
|
||||
)
|
||||
|
||||
state = {
|
||||
"url": url_out.strip() if rc_url == 0 else "",
|
||||
"cookies": (json.loads(ck_out) if rc_ck == 0 and ck_out.strip() else []),
|
||||
"local_storage": (
|
||||
json.loads(ls_out) if rc_ls == 0 and ls_out.strip() else {}
|
||||
),
|
||||
}
|
||||
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
await manager.write_file(
|
||||
content=json.dumps(state).encode("utf-8"),
|
||||
filename=_STATE_FILENAME,
|
||||
mime_type="application/json",
|
||||
overwrite=True,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[browser] Failed to save browser state for session %s",
|
||||
session_name,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def _restore_browser_state(
|
||||
session_name: str, user_id: str, session: ChatSession
|
||||
) -> bool:
|
||||
"""Restore browser state from workspace storage into a fresh daemon.
|
||||
|
||||
Best-effort: errors are logged but never propagate to the tool response.
|
||||
Returns True on success (or no state to restore), False on failure.
|
||||
"""
|
||||
try:
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
|
||||
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
|
||||
if file_info is None:
|
||||
return True # No saved state — first call or never saved
|
||||
|
||||
state_bytes = await manager.read_file(_STATE_FILENAME)
|
||||
state = json.loads(state_bytes.decode("utf-8"))
|
||||
|
||||
url = state.get("url", "")
|
||||
cookies = state.get("cookies", [])
|
||||
local_storage = state.get("local_storage", {})
|
||||
|
||||
# Navigate first — starts daemon + sets the correct origin for cookies
|
||||
if url:
|
||||
# Validate the saved URL to prevent SSRF via stored redirect targets.
|
||||
try:
|
||||
await validate_url(url, trusted_origins=[])
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"[browser] State restore: blocked SSRF URL %s", url[:200]
|
||||
)
|
||||
return False
|
||||
|
||||
rc, _, stderr = await _run(session_name, "open", url)
|
||||
if rc != 0:
|
||||
logger.warning(
|
||||
"[browser] State restore: failed to open %s: %s",
|
||||
url,
|
||||
stderr[:200],
|
||||
)
|
||||
return False
|
||||
await _run(session_name, "wait", "--load", "load", timeout=15)
|
||||
|
||||
# Restore cookies and localStorage in parallel via asyncio.gather.
|
||||
# Semaphore caps concurrent subprocess spawns so we don't overwhelm the
|
||||
# system when a session has hundreds of cookies.
|
||||
sem = asyncio.Semaphore(_RESTORE_CONCURRENCY)
|
||||
|
||||
# Guard against pathological sites with thousands of cookies.
|
||||
if len(cookies) > _MAX_RESTORE_COOKIES:
|
||||
logger.debug(
|
||||
"[browser] State restore: capping cookies from %d to %d",
|
||||
len(cookies),
|
||||
_MAX_RESTORE_COOKIES,
|
||||
)
|
||||
cookies = cookies[:_MAX_RESTORE_COOKIES]
|
||||
|
||||
async def _set_cookie(c: dict[str, Any]) -> None:
|
||||
name = c.get("name", "")
|
||||
value = c.get("value", "")
|
||||
domain = c.get("domain", "")
|
||||
path = c.get("path", "/")
|
||||
if not (name and domain):
|
||||
return
|
||||
async with sem:
|
||||
rc, _, stderr = await _run(
|
||||
session_name,
|
||||
"cookies",
|
||||
"set",
|
||||
name,
|
||||
value,
|
||||
"--domain",
|
||||
domain,
|
||||
"--path",
|
||||
path,
|
||||
timeout=5,
|
||||
)
|
||||
if rc != 0:
|
||||
logger.debug(
|
||||
"[browser] State restore: cookie set failed for %s: %s",
|
||||
name,
|
||||
stderr[:100],
|
||||
)
|
||||
|
||||
async def _set_storage(key: str, val: object) -> None:
|
||||
async with sem:
|
||||
rc, _, stderr = await _run(
|
||||
session_name,
|
||||
"storage",
|
||||
"local",
|
||||
"set",
|
||||
key,
|
||||
str(val),
|
||||
timeout=5,
|
||||
)
|
||||
if rc != 0:
|
||||
logger.debug(
|
||||
"[browser] State restore: localStorage set failed for %s: %s",
|
||||
key,
|
||||
stderr[:100],
|
||||
)
|
||||
|
||||
await asyncio.gather(
|
||||
*[_set_cookie(c) for c in cookies],
|
||||
*[_set_storage(k, v) for k, v in local_storage.items()],
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[browser] Failed to restore browser state for session %s",
|
||||
session_name,
|
||||
exc_info=True,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def _ensure_session(
|
||||
session_name: str, user_id: str, session: ChatSession
|
||||
) -> None:
|
||||
"""Ensure the local browser daemon has state. Restore from cloud if needed."""
|
||||
if session_name in _alive_sessions:
|
||||
return
|
||||
async with _session_locks_mutex:
|
||||
lock = _session_locks.setdefault(session_name, asyncio.Lock())
|
||||
async with lock:
|
||||
# Double-check after acquiring lock — another coroutine may have restored.
|
||||
if session_name in _alive_sessions:
|
||||
return
|
||||
if await _has_local_session(session_name):
|
||||
_alive_sessions.add(session_name)
|
||||
return
|
||||
if await _restore_browser_state(session_name, user_id, session):
|
||||
_alive_sessions.add(session_name)
|
||||
|
||||
|
||||
async def close_browser_session(session_name: str, user_id: str | None = None) -> None:
|
||||
"""Shut down the local agent-browser daemon and clean up stored state.
|
||||
|
||||
Deletes ``._browser_state.json`` from workspace storage so cookies and
|
||||
other credentials do not linger after the session is deleted.
|
||||
|
||||
Best-effort: errors are logged but never raised.
|
||||
"""
|
||||
_alive_sessions.discard(session_name)
|
||||
async with _session_locks_mutex:
|
||||
_session_locks.pop(session_name, None)
|
||||
|
||||
# Delete persisted browser state (cookies, localStorage) from workspace.
|
||||
if user_id:
|
||||
try:
|
||||
manager = await get_manager(user_id, session_name)
|
||||
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
|
||||
if file_info is not None:
|
||||
await manager.delete_file(file_info.id)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"[browser] Failed to delete state file for session %s",
|
||||
session_name,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
rc, _, stderr = await _run(session_name, "close", timeout=10)
|
||||
if rc != 0:
|
||||
logger.debug(
|
||||
"[browser] close failed for session %s: %s",
|
||||
session_name,
|
||||
stderr[:200],
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"[browser] Exception closing browser session %s",
|
||||
session_name,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool: browser_navigate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BrowserNavigateTool(BaseTool):
|
||||
"""Navigate to a URL and return the page's interactive elements.
|
||||
|
||||
The browser session persists across tool calls within this Copilot session
|
||||
(keyed to session_id), so cookies and auth state carry over. This enables
|
||||
full login flows: navigate to login page → browser_act to fill credentials
|
||||
→ browser_act to submit → browser_navigate to the target page.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "browser_navigate"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Navigate to a URL using a real browser. Returns an accessibility "
|
||||
"tree snapshot listing the page's interactive elements with @ref IDs "
|
||||
"(e.g. @e3) that can be used with browser_act. "
|
||||
"Session persists — cookies and login state carry over between calls. "
|
||||
"Use this (with browser_act) for multi-step interaction: login flows, "
|
||||
"form filling, button clicks, or anything requiring page interaction. "
|
||||
"For plain static pages, prefer web_fetch — no browser overhead. "
|
||||
"For authenticated pages: navigate to the login page first, use browser_act "
|
||||
"to fill credentials and submit, then navigate to the target page. "
|
||||
"Note: for slow SPAs, the returned snapshot may reflect a partially-loaded "
|
||||
"state. If elements seem missing, use browser_act with action='wait' and a "
|
||||
"CSS selector or millisecond delay, then take a browser_screenshot to verify."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The HTTP/HTTPS URL to navigate to.",
|
||||
},
|
||||
"wait_for": {
|
||||
"type": "string",
|
||||
"enum": ["networkidle", "load", "domcontentloaded"],
|
||||
"default": "networkidle",
|
||||
"description": "When to consider navigation complete. Use 'networkidle' for SPAs (default).",
|
||||
},
|
||||
},
|
||||
"required": ["url"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
return shutil.which("agent-browser") is not None
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
"""Navigate to *url*, wait for the page to settle, and return a snapshot.
|
||||
|
||||
The snapshot is an accessibility-tree listing of interactive elements.
|
||||
Note: for slow SPAs that never fully idle, the snapshot may reflect a
|
||||
partially-loaded state (the wait is best-effort).
|
||||
"""
|
||||
url: str = (kwargs.get("url") or "").strip()
|
||||
wait_for: str = kwargs.get("wait_for") or "networkidle"
|
||||
session_name = session.session_id
|
||||
|
||||
if not url:
|
||||
return ErrorResponse(
|
||||
message="Please provide a URL to navigate to.",
|
||||
error="missing_url",
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
try:
|
||||
await validate_url(url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
return ErrorResponse(
|
||||
message=str(e),
|
||||
error="blocked_url",
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
# Restore browser state from cloud if this is a different pod
|
||||
if user_id:
|
||||
await _ensure_session(session_name, user_id, session)
|
||||
|
||||
# Navigate
|
||||
rc, _, stderr = await _run(session_name, "open", url)
|
||||
if rc != 0:
|
||||
logger.warning(
|
||||
"[browser_navigate] open failed for %s: %s", url, stderr[:300]
|
||||
)
|
||||
return ErrorResponse(
|
||||
message="Failed to navigate to URL.",
|
||||
error="navigation_failed",
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
# Wait for page to settle (best-effort: some SPAs never reach networkidle)
|
||||
wait_rc, _, wait_err = await _run(session_name, "wait", "--load", wait_for)
|
||||
if wait_rc != 0:
|
||||
logger.warning(
|
||||
"[browser_navigate] wait(%s) failed: %s", wait_for, wait_err[:300]
|
||||
)
|
||||
|
||||
# Get current title and URL in parallel
|
||||
(_, title_out, _), (_, url_out, _) = await asyncio.gather(
|
||||
_run(session_name, "get", "title"),
|
||||
_run(session_name, "get", "url"),
|
||||
)
|
||||
|
||||
snapshot = await _snapshot(session_name)
|
||||
|
||||
result = BrowserNavigateResponse(
|
||||
message=f"Navigated to {url}",
|
||||
url=url_out.strip() or url,
|
||||
title=title_out.strip(),
|
||||
snapshot=snapshot,
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
# Persist browser state to cloud for cross-pod continuity
|
||||
if user_id:
|
||||
_fire_and_forget_save(session_name, user_id, session)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool: browser_act
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_NO_TARGET_ACTIONS = frozenset({"back", "forward", "reload"})
|
||||
_SCROLL_ACTIONS = frozenset({"scroll"})
|
||||
_TARGET_ONLY_ACTIONS = frozenset({"click", "dblclick", "hover", "check", "uncheck"})
|
||||
_TARGET_VALUE_ACTIONS = frozenset({"fill", "type", "select"})
|
||||
# wait <selector|ms>: waits for a DOM element or a fixed delay (e.g. "1000" for 1 s)
|
||||
_WAIT_ACTIONS = frozenset({"wait"})
|
||||
|
||||
|
||||
class BrowserActTool(BaseTool):
|
||||
"""Perform an action on the current browser page and return the updated snapshot.
|
||||
|
||||
Use @ref IDs from the snapshot returned by browser_navigate (e.g. '@e3').
|
||||
The LLM orchestrates multi-step flows by chaining browser_navigate and
|
||||
browser_act calls across turns of the Claude Agent SDK conversation.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "browser_act"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Interact with the current browser page. Use @ref IDs from the "
|
||||
"snapshot (e.g. '@e3') to target elements. Returns an updated snapshot. "
|
||||
"Supported actions: click, dblclick, fill, type, scroll, hover, press, "
|
||||
"check, uncheck, select, wait, back, forward, reload. "
|
||||
"fill clears the field before typing; type appends without clearing. "
|
||||
"wait accepts a CSS selector (waits for element) or milliseconds string (e.g. '1000'). "
|
||||
"Example login flow: fill @e1 with email → fill @e2 with password → "
|
||||
"click @e3 (submit) → browser_navigate to the target page."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"click",
|
||||
"dblclick",
|
||||
"fill",
|
||||
"type",
|
||||
"scroll",
|
||||
"hover",
|
||||
"press",
|
||||
"check",
|
||||
"uncheck",
|
||||
"select",
|
||||
"wait",
|
||||
"back",
|
||||
"forward",
|
||||
"reload",
|
||||
],
|
||||
"description": "The action to perform.",
|
||||
},
|
||||
"target": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Element to target. Use @ref from snapshot (e.g. '@e3'), "
|
||||
"a CSS selector, or a text description. "
|
||||
"Required for: click, dblclick, fill, type, hover, check, uncheck, select. "
|
||||
"For wait: a CSS selector to wait for, or milliseconds as a string (e.g. '1000')."
|
||||
),
|
||||
},
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"For fill/type: the text to enter. "
|
||||
"For press: key name (e.g. 'Enter', 'Tab', 'Control+a'). "
|
||||
"For select: the option value to select."
|
||||
),
|
||||
},
|
||||
"direction": {
|
||||
"type": "string",
|
||||
"enum": ["up", "down", "left", "right"],
|
||||
"default": "down",
|
||||
"description": "For scroll: direction to scroll.",
|
||||
},
|
||||
},
|
||||
"required": ["action"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
return shutil.which("agent-browser") is not None
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
"""Perform a browser action and return an updated page snapshot.
|
||||
|
||||
Validates the *action*/*target*/*value* combination, delegates to
|
||||
``agent-browser``, waits for the page to settle, and returns the
|
||||
accessibility-tree snapshot so the LLM can plan the next step.
|
||||
"""
|
||||
action: str = (kwargs.get("action") or "").strip()
|
||||
target: str = (kwargs.get("target") or "").strip()
|
||||
value: str = (kwargs.get("value") or "").strip()
|
||||
direction: str = (kwargs.get("direction") or "down").strip()
|
||||
session_name = session.session_id
|
||||
|
||||
if not action:
|
||||
return ErrorResponse(
|
||||
message="Please specify an action.",
|
||||
error="missing_action",
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
# Build the agent-browser command args
|
||||
if action in _NO_TARGET_ACTIONS:
|
||||
cmd_args = [action]
|
||||
|
||||
elif action in _SCROLL_ACTIONS:
|
||||
cmd_args = ["scroll", direction]
|
||||
|
||||
elif action == "press":
|
||||
if not value:
|
||||
return ErrorResponse(
|
||||
message="'press' requires a 'value' (key name, e.g. 'Enter').",
|
||||
error="missing_value",
|
||||
session_id=session_name,
|
||||
)
|
||||
cmd_args = ["press", value]
|
||||
|
||||
elif action in _TARGET_ONLY_ACTIONS:
|
||||
if not target:
|
||||
return ErrorResponse(
|
||||
message=f"'{action}' requires a 'target' element.",
|
||||
error="missing_target",
|
||||
session_id=session_name,
|
||||
)
|
||||
cmd_args = [action, target]
|
||||
|
||||
elif action in _TARGET_VALUE_ACTIONS:
|
||||
if not target or not value:
|
||||
return ErrorResponse(
|
||||
message=f"'{action}' requires both 'target' and 'value'.",
|
||||
error="missing_params",
|
||||
session_id=session_name,
|
||||
)
|
||||
cmd_args = [action, target, value]
|
||||
|
||||
elif action in _WAIT_ACTIONS:
|
||||
if not target:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"'wait' requires a 'target': a CSS selector to wait for, "
|
||||
"or milliseconds as a string (e.g. '1000')."
|
||||
),
|
||||
error="missing_target",
|
||||
session_id=session_name,
|
||||
)
|
||||
cmd_args = ["wait", target]
|
||||
|
||||
else:
|
||||
return ErrorResponse(
|
||||
message=f"Unsupported action: {action}",
|
||||
error="invalid_action",
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
# Restore browser state from cloud if this is a different pod
|
||||
if user_id:
|
||||
await _ensure_session(session_name, user_id, session)
|
||||
|
||||
rc, _, stderr = await _run(session_name, *cmd_args)
|
||||
if rc != 0:
|
||||
logger.warning("[browser_act] %s failed: %s", action, stderr[:300])
|
||||
return ErrorResponse(
|
||||
message=f"Action '{action}' failed.",
|
||||
error="action_failed",
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
# Allow the page to settle after interaction (best-effort: SPAs may not idle)
|
||||
settle_rc, _, settle_err = await _run(
|
||||
session_name, "wait", "--load", "networkidle"
|
||||
)
|
||||
if settle_rc != 0:
|
||||
logger.warning(
|
||||
"[browser_act] post-action wait failed: %s", settle_err[:300]
|
||||
)
|
||||
|
||||
snapshot = await _snapshot(session_name)
|
||||
_, url_out, _ = await _run(session_name, "get", "url")
|
||||
|
||||
result = BrowserActResponse(
|
||||
message=f"Performed '{action}'" + (f" on '{target}'" if target else ""),
|
||||
action=action,
|
||||
current_url=url_out.strip(),
|
||||
snapshot=snapshot,
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
# Persist browser state to cloud for cross-pod continuity
|
||||
if user_id:
|
||||
_fire_and_forget_save(session_name, user_id, session)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool: browser_screenshot
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BrowserScreenshotTool(BaseTool):
|
||||
"""Capture a screenshot of the current browser page and save it to the workspace."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "browser_screenshot"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Take a screenshot of the current browser page and save it to the workspace. "
|
||||
"IMPORTANT: After calling this tool, immediately call read_workspace_file "
|
||||
"with the returned file_id to display the image inline to the user — "
|
||||
"the screenshot is not visible until you do this. "
|
||||
"With annotate=true (default), @ref labels are overlaid on interactive "
|
||||
"elements, making it easy to see which @ref ID maps to which element on screen."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"annotate": {
|
||||
"type": "boolean",
|
||||
"default": True,
|
||||
"description": "Overlay @ref labels on interactive elements (default: true).",
|
||||
},
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"default": "screenshot.png",
|
||||
"description": "Filename to save in the workspace.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
return shutil.which("agent-browser") is not None
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
"""Capture a PNG screenshot and upload it to the workspace.
|
||||
|
||||
Handles string-to-bool coercion for *annotate* (OpenAI function-call
|
||||
payloads sometimes deliver ``"true"``/``"false"`` as strings).
|
||||
Returns a :class:`BrowserScreenshotResponse` with the workspace
|
||||
``file_id`` the LLM should pass to ``read_workspace_file``.
|
||||
"""
|
||||
raw_annotate = kwargs.get("annotate", True)
|
||||
if isinstance(raw_annotate, str):
|
||||
annotate = raw_annotate.strip().lower() in {"1", "true", "yes", "on"}
|
||||
else:
|
||||
annotate = bool(raw_annotate)
|
||||
filename: str = (kwargs.get("filename") or "screenshot.png").strip()
|
||||
session_name = session.session_id
|
||||
|
||||
# Restore browser state from cloud if this is a different pod
|
||||
if user_id:
|
||||
await _ensure_session(session_name, user_id, session)
|
||||
|
||||
tmp_fd, tmp_path = tempfile.mkstemp(suffix=".png")
|
||||
os.close(tmp_fd)
|
||||
try:
|
||||
cmd_args = ["screenshot"]
|
||||
if annotate:
|
||||
cmd_args.append("--annotate")
|
||||
cmd_args.append(tmp_path)
|
||||
|
||||
rc, _, stderr = await _run(session_name, *cmd_args)
|
||||
if rc != 0:
|
||||
logger.warning("[browser_screenshot] failed: %s", stderr[:300])
|
||||
return ErrorResponse(
|
||||
message="Failed to take screenshot.",
|
||||
error="screenshot_failed",
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
with open(tmp_path, "rb") as f:
|
||||
png_bytes = f.read()
|
||||
|
||||
finally:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass # Best-effort temp file cleanup; not critical if it fails.
|
||||
|
||||
# Upload to workspace so the user can view it
|
||||
png_b64 = base64.b64encode(png_bytes).decode()
|
||||
|
||||
# Import here to avoid circular deps — workspace_files imports from .models
|
||||
from .workspace_files import WorkspaceWriteResponse, WriteWorkspaceFileTool
|
||||
|
||||
write_resp = await WriteWorkspaceFileTool()._execute(
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
filename=filename,
|
||||
content_base64=png_b64,
|
||||
)
|
||||
|
||||
if not isinstance(write_resp, WorkspaceWriteResponse):
|
||||
return ErrorResponse(
|
||||
message="Screenshot taken but failed to save to workspace.",
|
||||
error="workspace_write_failed",
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
result = BrowserScreenshotResponse(
|
||||
message=f"Screenshot saved to workspace as '{filename}'. Use read_workspace_file with file_id='{write_resp.file_id}' to retrieve it.",
|
||||
file_id=write_resp.file_id,
|
||||
filename=filename,
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
# Persist browser state to cloud for cross-pod continuity
|
||||
if user_id:
|
||||
_fire_and_forget_save(session_name, user_id, session)
|
||||
|
||||
return result
|
||||
1663
autogpt_platform/backend/backend/copilot/tools/agent_browser_test.py
Normal file
1663
autogpt_platform/backend/backend/copilot/tools/agent_browser_test.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -695,7 +695,10 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||
|
||||
|
||||
async def save_agent_to_library(
|
||||
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
||||
agent_json: dict[str, Any],
|
||||
user_id: str,
|
||||
is_update: bool = False,
|
||||
folder_id: str | None = None,
|
||||
) -> tuple[Graph, Any]:
|
||||
"""Save agent to database and user's library.
|
||||
|
||||
@@ -703,6 +706,7 @@ async def save_agent_to_library(
|
||||
agent_json: Agent JSON dict
|
||||
user_id: User ID
|
||||
is_update: Whether this is an update to an existing agent
|
||||
folder_id: Optional folder ID to place the agent in
|
||||
|
||||
Returns:
|
||||
Tuple of (created Graph, LibraryAgent)
|
||||
@@ -711,7 +715,7 @@ async def save_agent_to_library(
|
||||
db = library_db()
|
||||
if is_update:
|
||||
return await db.update_graph_in_library(graph, user_id)
|
||||
return await db.create_graph_in_library(graph, user_id)
|
||||
return await db.create_graph_in_library(graph, user_id, folder_id=folder_id)
|
||||
|
||||
|
||||
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Base classes and shared utilities for chat tools."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@@ -7,11 +8,98 @@ from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.truncate import truncate
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
|
||||
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Persist full tool output to workspace when it exceeds this threshold.
|
||||
# Must be below _MAX_TOOL_OUTPUT_SIZE (100K) in response_model.py so we
|
||||
# capture the data before model_post_init middle-out truncation discards it.
|
||||
_LARGE_OUTPUT_THRESHOLD = 80_000
|
||||
|
||||
# Character budget for the middle-out preview. The total preview + wrapper
|
||||
# must stay below BOTH:
|
||||
# - _MAX_TOOL_OUTPUT_SIZE (100K) in response_model.py (our own truncation)
|
||||
# - Claude SDK's ~100 KB tool-result spill-to-disk threshold
|
||||
# to avoid double truncation/spilling. 95K + ~300 wrapper = ~95.3K, under both.
|
||||
_PREVIEW_CHARS = 95_000
|
||||
|
||||
|
||||
# Fields whose values are binary/base64 data — truncating them produces
|
||||
# garbage, so we replace them with a human-readable size summary instead.
|
||||
_BINARY_FIELD_NAMES = {"content_base64"}
|
||||
|
||||
|
||||
def _summarize_binary_fields(raw_json: str) -> str:
|
||||
"""Replace known binary fields with a size summary so truncate() doesn't
|
||||
produce garbled base64 in the middle-out preview."""
|
||||
try:
|
||||
data = json.loads(raw_json)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return raw_json
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return raw_json
|
||||
|
||||
changed = False
|
||||
for key in _BINARY_FIELD_NAMES:
|
||||
if key in data and isinstance(data[key], str) and len(data[key]) > 1_000:
|
||||
byte_size = len(data[key]) * 3 // 4 # approximate decoded size
|
||||
data[key] = f"<binary, ~{byte_size:,} bytes>"
|
||||
changed = True
|
||||
|
||||
return json.dumps(data, ensure_ascii=False) if changed else raw_json
|
||||
|
||||
|
||||
async def _persist_and_summarize(
|
||||
raw_output: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
) -> str:
|
||||
"""Persist full output to workspace and return a middle-out preview with retrieval instructions.
|
||||
|
||||
On failure, returns the original ``raw_output`` unchanged so that the
|
||||
existing ``model_post_init`` middle-out truncation handles it as before.
|
||||
"""
|
||||
file_path = f"tool-outputs/{tool_call_id}.json"
|
||||
try:
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
await manager.write_file(
|
||||
content=raw_output.encode("utf-8"),
|
||||
filename=f"{tool_call_id}.json",
|
||||
path=file_path,
|
||||
mime_type="application/json",
|
||||
overwrite=True,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist large tool output for %s",
|
||||
tool_call_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return raw_output # fall back to normal truncation
|
||||
|
||||
total = len(raw_output)
|
||||
preview = truncate(_summarize_binary_fields(raw_output), _PREVIEW_CHARS)
|
||||
retrieval = (
|
||||
f"\nFull output ({total:,} chars) saved to workspace. "
|
||||
f"Use read_workspace_file("
|
||||
f'path="{file_path}", offset=<char_offset>, length=50000) '
|
||||
f"to read any section."
|
||||
)
|
||||
return (
|
||||
f'<tool-output-truncated total_chars={total} path="{file_path}">\n'
|
||||
f"{preview}\n"
|
||||
f"{retrieval}\n"
|
||||
f"</tool-output-truncated>"
|
||||
)
|
||||
|
||||
|
||||
class BaseTool:
|
||||
"""Base class for all chat tools."""
|
||||
@@ -36,6 +124,16 @@ class BaseTool:
|
||||
"""Whether this tool requires authentication."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
"""Whether this tool is available in the current environment.
|
||||
|
||||
Override to check required env vars, binaries, or other dependencies.
|
||||
Unavailable tools are excluded from the LLM tool list so the model is
|
||||
never offered an option that will immediately fail.
|
||||
"""
|
||||
return True
|
||||
|
||||
def as_openai_tool(self) -> ChatCompletionToolParam:
|
||||
"""Convert to OpenAI tool format."""
|
||||
return ChatCompletionToolParam(
|
||||
@@ -57,7 +155,7 @@ class BaseTool:
|
||||
"""Execute the tool with authentication check.
|
||||
|
||||
Args:
|
||||
user_id: User ID (may be anonymous like "anon_123")
|
||||
user_id: User ID (None for anonymous users)
|
||||
session_id: Chat session ID
|
||||
**kwargs: Tool-specific parameters
|
||||
|
||||
@@ -81,10 +179,21 @@ class BaseTool:
|
||||
|
||||
try:
|
||||
result = await self._execute(user_id, session, **kwargs)
|
||||
raw_output = result.model_dump_json()
|
||||
|
||||
if (
|
||||
len(raw_output) > _LARGE_OUTPUT_THRESHOLD
|
||||
and user_id
|
||||
and session.session_id
|
||||
):
|
||||
raw_output = await _persist_and_summarize(
|
||||
raw_output, user_id, session.session_id, tool_call_id
|
||||
)
|
||||
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=self.name,
|
||||
output=result.model_dump_json(),
|
||||
output=raw_output,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in {self.name}: {e}", exc_info=True)
|
||||
|
||||
194
autogpt_platform/backend/backend/copilot/tools/base_test.py
Normal file
194
autogpt_platform/backend/backend/copilot/tools/base_test.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Tests for BaseTool large-output persistence in execute()."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.tools.base import (
|
||||
_LARGE_OUTPUT_THRESHOLD,
|
||||
BaseTool,
|
||||
_persist_and_summarize,
|
||||
_summarize_binary_fields,
|
||||
)
|
||||
from backend.copilot.tools.models import ResponseType, ToolResponseBase
|
||||
|
||||
|
||||
class _HugeOutputTool(BaseTool):
|
||||
"""Fake tool that returns an arbitrarily large output."""
|
||||
|
||||
def __init__(self, output_size: int) -> None:
|
||||
self._output_size = output_size
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "huge_output_tool"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Returns a huge output"
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict:
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
async def _execute(self, user_id, session, **kwargs) -> ToolResponseBase:
|
||||
return ToolResponseBase(
|
||||
type=ResponseType.ERROR,
|
||||
message="x" * self._output_size,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _persist_and_summarize
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPersistAndSummarize:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_middle_out_preview_with_retrieval_instructions(self):
|
||||
raw = "A" * 200_000
|
||||
|
||||
mock_workspace = MagicMock()
|
||||
mock_workspace.id = "ws-1"
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
|
||||
|
||||
mock_manager = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("backend.copilot.tools.base.workspace_db", return_value=mock_db),
|
||||
patch(
|
||||
"backend.copilot.tools.base.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
),
|
||||
):
|
||||
result = await _persist_and_summarize(raw, "user-1", "session-1", "tc-123")
|
||||
|
||||
assert "<tool-output-truncated" in result
|
||||
assert "</tool-output-truncated>" in result
|
||||
assert "total_chars=200000" in result
|
||||
assert 'path="tool-outputs/tc-123.json"' in result
|
||||
assert "read_workspace_file" in result
|
||||
# Middle-out sentinel from truncate()
|
||||
assert "omitted" in result
|
||||
# Total result is much shorter than the raw output
|
||||
assert len(result) < len(raw)
|
||||
|
||||
# Verify write_file was called with full content
|
||||
mock_manager.write_file.assert_awaited_once()
|
||||
call_kwargs = mock_manager.write_file.call_args
|
||||
assert call_kwargs.kwargs["content"] == raw.encode("utf-8")
|
||||
assert call_kwargs.kwargs["path"] == "tool-outputs/tc-123.json"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_on_workspace_error(self):
|
||||
"""If workspace write fails, return raw output for normal truncation."""
|
||||
raw = "B" * 200_000
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_or_create_workspace = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
with patch("backend.copilot.tools.base.workspace_db", return_value=mock_db):
|
||||
result = await _persist_and_summarize(raw, "user-1", "session-1", "tc-fail")
|
||||
|
||||
assert result == raw # unchanged — fallback to normal truncation
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BaseTool.execute — integration with persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBaseToolExecuteLargeOutput:
|
||||
@pytest.mark.asyncio
|
||||
async def test_small_output_not_persisted(self):
|
||||
"""Outputs under the threshold go through without persistence."""
|
||||
tool = _HugeOutputTool(output_size=100)
|
||||
session = MagicMock()
|
||||
session.session_id = "s-1"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.base._persist_and_summarize",
|
||||
new_callable=AsyncMock,
|
||||
) as persist_mock:
|
||||
result = await tool.execute("user-1", session, "tc-small")
|
||||
persist_mock.assert_not_awaited()
|
||||
assert "<tool-output-truncated" not in str(result.output)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_output_persisted(self):
|
||||
"""Outputs over the threshold trigger persistence + preview."""
|
||||
tool = _HugeOutputTool(output_size=_LARGE_OUTPUT_THRESHOLD + 10_000)
|
||||
session = MagicMock()
|
||||
session.session_id = "s-1"
|
||||
|
||||
mock_workspace = MagicMock()
|
||||
mock_workspace.id = "ws-1"
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
|
||||
mock_manager = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("backend.copilot.tools.base.workspace_db", return_value=mock_db),
|
||||
patch(
|
||||
"backend.copilot.tools.base.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
),
|
||||
):
|
||||
result = await tool.execute("user-1", session, "tc-big")
|
||||
|
||||
assert "<tool-output-truncated" in str(result.output)
|
||||
assert "read_workspace_file" in str(result.output)
|
||||
mock_manager.write_file.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_persistence_without_user_id(self):
|
||||
"""Anonymous users skip persistence (no workspace)."""
|
||||
tool = _HugeOutputTool(output_size=_LARGE_OUTPUT_THRESHOLD + 10_000)
|
||||
session = MagicMock()
|
||||
session.session_id = "s-1"
|
||||
|
||||
# user_id=None → should not attempt persistence
|
||||
with patch(
|
||||
"backend.copilot.tools.base._persist_and_summarize",
|
||||
new_callable=AsyncMock,
|
||||
) as persist_mock:
|
||||
result = await tool.execute(None, session, "tc-anon")
|
||||
persist_mock.assert_not_awaited()
|
||||
# Output is set but not wrapped in <tool-output-truncated> tags
|
||||
# (it will be middle-out truncated by model_post_init instead)
|
||||
assert "<tool-output-truncated" not in str(result.output)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _summarize_binary_fields
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSummarizeBinaryFields:
|
||||
def test_replaces_large_content_base64(self):
|
||||
import json
|
||||
|
||||
data = {"content_base64": "A" * 10_000, "name": "file.png"}
|
||||
result = json.loads(_summarize_binary_fields(json.dumps(data)))
|
||||
assert result["name"] == "file.png"
|
||||
assert "<binary" in result["content_base64"]
|
||||
assert "bytes>" in result["content_base64"]
|
||||
|
||||
def test_preserves_small_content_base64(self):
|
||||
import json
|
||||
|
||||
data = {"content_base64": "AQID", "name": "tiny.bin"}
|
||||
result_str = _summarize_binary_fields(json.dumps(data))
|
||||
result = json.loads(result_str)
|
||||
assert result["content_base64"] == "AQID" # unchanged
|
||||
|
||||
def test_non_json_passthrough(self):
|
||||
raw = "not json at all"
|
||||
assert _summarize_binary_fields(raw) == raw
|
||||
|
||||
def test_no_binary_fields_unchanged(self):
|
||||
import json
|
||||
|
||||
data = {"message": "hello", "type": "info"}
|
||||
raw = json.dumps(data)
|
||||
assert _summarize_binary_fields(raw) == raw
|
||||
@@ -1,19 +1,30 @@
|
||||
"""Bash execution tool — run shell commands in a bubblewrap sandbox.
|
||||
"""Bash execution tool — run shell commands on E2B or in a bubblewrap sandbox.
|
||||
|
||||
Full Bash scripting is allowed (loops, conditionals, pipes, functions, etc.).
|
||||
Safety comes from OS-level isolation (bubblewrap): only system dirs visible
|
||||
read-only, writable workspace only, clean env, no network.
|
||||
When an E2B sandbox is available in the current execution context the command
|
||||
runs directly on the remote E2B cloud environment. This means:
|
||||
|
||||
Requires bubblewrap (``bwrap``) — the tool is disabled when bwrap is not
|
||||
available (e.g. macOS development).
|
||||
- **Persistent filesystem**: files survive across turns via HTTP-based sync
|
||||
with the sandbox's ``/home/user`` directory (E2B files API), shared with
|
||||
SDK Read/Write/Edit tools.
|
||||
- **Full internet access**: E2B sandboxes have unrestricted outbound network.
|
||||
- **Execution isolation**: E2B provides a fresh, containerised Linux environment.
|
||||
|
||||
When E2B is *not* configured the tool falls back to **bubblewrap** (bwrap):
|
||||
OS-level isolation with a whitelist-only filesystem, no network, and resource
|
||||
limits. Requires bubblewrap to be installed (Linux only).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import shlex
|
||||
from typing import Any
|
||||
|
||||
from e2b import AsyncSandbox
|
||||
from e2b.exceptions import TimeoutException
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .e2b_sandbox import E2B_WORKDIR
|
||||
from .models import BashExecResponse, ErrorResponse, ToolResponseBase
|
||||
from .sandbox import get_workspace_dir, has_full_sandbox, run_sandboxed
|
||||
|
||||
@@ -21,7 +32,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BashExecTool(BaseTool):
|
||||
"""Execute Bash commands in a bubblewrap sandbox."""
|
||||
"""Execute Bash commands on E2B or in a bubblewrap sandbox."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -29,28 +40,16 @@ class BashExecTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
if not has_full_sandbox():
|
||||
return (
|
||||
"Bash execution is DISABLED — bubblewrap sandbox is not "
|
||||
"available on this platform. Do not call this tool."
|
||||
)
|
||||
return (
|
||||
"Execute a Bash command or script in a bubblewrap sandbox. "
|
||||
"Execute a Bash command or script. "
|
||||
"Full Bash scripting is supported (loops, conditionals, pipes, "
|
||||
"functions, etc.). "
|
||||
"The sandbox shares the same working directory as the SDK Read/Write "
|
||||
"tools — files created by either are accessible to both. "
|
||||
"SECURITY: Only system directories (/usr, /bin, /lib, /etc) are "
|
||||
"visible read-only, the per-session workspace is the only writable "
|
||||
"path, environment variables are wiped (no secrets), all network "
|
||||
"access is blocked at the kernel level, and resource limits are "
|
||||
"enforced (max 64 processes, 512MB memory, 50MB file size). "
|
||||
"Application code, configs, and other directories are NOT accessible. "
|
||||
"To fetch web content, use the web_fetch tool instead. "
|
||||
"The working directory is shared with the SDK Read/Write/Edit/Glob/Grep "
|
||||
"tools — files created by either are immediately visible to both. "
|
||||
"Execution is killed after the timeout (default 30s, max 120s). "
|
||||
"Returns stdout and stderr. "
|
||||
"Useful for file manipulation, data processing with Unix tools "
|
||||
"(grep, awk, sed, jq, etc.), and running shell scripts."
|
||||
"Useful for file manipulation, data processing, running scripts, "
|
||||
"and installing packages."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -85,15 +84,8 @@ class BashExecTool(BaseTool):
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not has_full_sandbox():
|
||||
return ErrorResponse(
|
||||
message="bash_exec requires bubblewrap sandbox (Linux only).",
|
||||
error="sandbox_unavailable",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
command: str = (kwargs.get("command") or "").strip()
|
||||
timeout: int = kwargs.get("timeout", 30)
|
||||
timeout: int = int(kwargs.get("timeout", 30))
|
||||
|
||||
if not command:
|
||||
return ErrorResponse(
|
||||
@@ -102,6 +94,21 @@ class BashExecTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# E2B path: run on remote cloud sandbox when available.
|
||||
from backend.copilot.sdk.tool_adapter import get_current_sandbox
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
return await self._execute_on_e2b(sandbox, command, timeout, session_id)
|
||||
|
||||
# Bubblewrap fallback: local isolated execution.
|
||||
if not has_full_sandbox():
|
||||
return ErrorResponse(
|
||||
message="bash_exec requires bubblewrap sandbox (Linux only).",
|
||||
error="sandbox_unavailable",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
workspace = get_workspace_dir(session_id or "default")
|
||||
|
||||
stdout, stderr, exit_code, timed_out = await run_sandboxed(
|
||||
@@ -122,3 +129,43 @@ class BashExecTool(BaseTool):
|
||||
timed_out=timed_out,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _execute_on_e2b(
|
||||
self,
|
||||
sandbox: AsyncSandbox,
|
||||
command: str,
|
||||
timeout: int,
|
||||
session_id: str | None,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute *command* on the E2B sandbox via commands.run()."""
|
||||
try:
|
||||
result = await sandbox.commands.run(
|
||||
f"bash -c {shlex.quote(command)}",
|
||||
cwd=E2B_WORKDIR,
|
||||
timeout=timeout,
|
||||
envs={"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"},
|
||||
)
|
||||
return BashExecResponse(
|
||||
message=f"Command executed on E2B (exit {result.exit_code})",
|
||||
stdout=result.stdout or "",
|
||||
stderr=result.stderr or "",
|
||||
exit_code=result.exit_code,
|
||||
timed_out=False,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
if isinstance(exc, TimeoutException):
|
||||
return BashExecResponse(
|
||||
message="Execution timed out",
|
||||
stdout="",
|
||||
stderr=f"Timed out after {timeout}s",
|
||||
exit_code=-1,
|
||||
timed_out=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
logger.error("[E2B] bash_exec failed: %s", exc, exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"E2B execution failed: {exc}",
|
||||
error="e2b_execution_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -1,227 +0,0 @@
|
||||
"""Web browsing tool — navigate real browser sessions to extract page content.
|
||||
|
||||
Uses Stagehand + Browserbase for cloud-based browser execution. Handles
|
||||
JS-rendered pages, SPAs, and dynamic content that web_fetch cannot reach.
|
||||
|
||||
Requires environment variables:
|
||||
STAGEHAND_API_KEY — Browserbase API key
|
||||
STAGEHAND_PROJECT_ID — Browserbase project ID
|
||||
ANTHROPIC_API_KEY — LLM key used by Stagehand for extraction
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import BrowseWebResponse, ErrorResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Stagehand uses the LLM internally for natural-language extraction/actions.
|
||||
_STAGEHAND_MODEL = "anthropic/claude-sonnet-4-5-20250929"
|
||||
# Hard cap on extracted content returned to the LLM context.
|
||||
_MAX_CONTENT_CHARS = 50_000
|
||||
# Explicit timeouts for Stagehand browser operations (milliseconds).
|
||||
_GOTO_TIMEOUT_MS = 30_000 # page navigation
|
||||
_EXTRACT_TIMEOUT_MS = 60_000 # LLM extraction
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread-safety patch for Stagehand signal handlers (applied lazily, once).
|
||||
#
|
||||
# Stagehand calls signal.signal() during __init__, which raises ValueError
|
||||
# when called from a non-main thread (e.g. the CoPilot executor thread pool).
|
||||
# We patch _register_signal_handlers to be a no-op outside the main thread.
|
||||
# The patch is applied exactly once per process via double-checked locking.
|
||||
# ---------------------------------------------------------------------------
|
||||
_stagehand_patched = False
|
||||
_patch_lock = threading.Lock()
|
||||
|
||||
|
||||
def _patch_stagehand_once() -> None:
|
||||
"""Monkey-patch Stagehand signal handler registration to be thread-safe.
|
||||
|
||||
Must be called after ``import stagehand.main`` has succeeded.
|
||||
Safe to call from multiple threads — applies the patch at most once.
|
||||
"""
|
||||
global _stagehand_patched
|
||||
if _stagehand_patched:
|
||||
return
|
||||
with _patch_lock:
|
||||
if _stagehand_patched:
|
||||
return
|
||||
import stagehand.main # noqa: PLC0415
|
||||
|
||||
_original = stagehand.main.Stagehand._register_signal_handlers
|
||||
|
||||
def _safe_register(self: Any) -> None:
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
_original(self)
|
||||
|
||||
stagehand.main.Stagehand._register_signal_handlers = _safe_register
|
||||
_stagehand_patched = True
|
||||
|
||||
|
||||
class BrowseWebTool(BaseTool):
|
||||
"""Navigate a URL with a real browser and extract its content.
|
||||
|
||||
Use this instead of ``web_fetch`` when the page requires JavaScript
|
||||
to render (SPAs, dashboards, paywalled content with JS checks, etc.).
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "browse_web"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Navigate to a URL using a real browser and extract content. "
|
||||
"Handles JavaScript-rendered pages and dynamic content that "
|
||||
"web_fetch cannot reach. "
|
||||
"Specify exactly what to extract via the `instruction` parameter."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The HTTP/HTTPS URL to navigate to.",
|
||||
},
|
||||
"instruction": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"What to extract from the page. Be specific — e.g. "
|
||||
"'Extract all pricing plans with features and prices', "
|
||||
"'Get the main article text and author', "
|
||||
"'List all navigation links'. "
|
||||
"Defaults to extracting the main page content."
|
||||
),
|
||||
"default": "Extract the main content of this page.",
|
||||
},
|
||||
},
|
||||
"required": ["url"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None, # noqa: ARG002
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
"""Navigate to a URL with a real browser and return extracted content."""
|
||||
url: str = (kwargs.get("url") or "").strip()
|
||||
instruction: str = (
|
||||
kwargs.get("instruction") or "Extract the main content of this page."
|
||||
)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not url:
|
||||
return ErrorResponse(
|
||||
message="Please provide a URL to browse.",
|
||||
error="missing_url",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not url.startswith(("http://", "https://")):
|
||||
return ErrorResponse(
|
||||
message="Only HTTP/HTTPS URLs are supported.",
|
||||
error="invalid_url",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
api_key = os.environ.get("STAGEHAND_API_KEY")
|
||||
project_id = os.environ.get("STAGEHAND_PROJECT_ID")
|
||||
model_api_key = os.environ.get("ANTHROPIC_API_KEY")
|
||||
|
||||
if not api_key or not project_id:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Web browsing is not configured on this platform. "
|
||||
"STAGEHAND_API_KEY and STAGEHAND_PROJECT_ID are required."
|
||||
),
|
||||
error="not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not model_api_key:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Web browsing is not configured: ANTHROPIC_API_KEY is required "
|
||||
"for Stagehand's extraction model."
|
||||
),
|
||||
error="not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Lazy import — Stagehand is an optional heavy dependency.
|
||||
# Importing here scopes any ImportError to this tool only, so other
|
||||
# tools continue to register and work normally if Stagehand is absent.
|
||||
try:
|
||||
from stagehand import Stagehand # noqa: PLC0415
|
||||
except ImportError:
|
||||
return ErrorResponse(
|
||||
message="Web browsing is not available: Stagehand is not installed.",
|
||||
error="not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Apply the signal handler patch now that we know stagehand is present.
|
||||
_patch_stagehand_once()
|
||||
|
||||
client: Any | None = None
|
||||
try:
|
||||
client = Stagehand(
|
||||
api_key=api_key,
|
||||
project_id=project_id,
|
||||
model_name=_STAGEHAND_MODEL,
|
||||
model_api_key=model_api_key,
|
||||
)
|
||||
await client.init()
|
||||
|
||||
page = client.page
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
await page.goto(url, timeoutMs=_GOTO_TIMEOUT_MS)
|
||||
result = await page.extract(instruction, timeoutMs=_EXTRACT_TIMEOUT_MS)
|
||||
|
||||
# Extract the text content from the Pydantic result model.
|
||||
raw = result.model_dump().get("extraction", "")
|
||||
content = str(raw) if raw else ""
|
||||
|
||||
truncated = len(content) > _MAX_CONTENT_CHARS
|
||||
if truncated:
|
||||
suffix = "\n\n[Content truncated]"
|
||||
keep = max(0, _MAX_CONTENT_CHARS - len(suffix))
|
||||
content = content[:keep] + suffix
|
||||
|
||||
return BrowseWebResponse(
|
||||
message=f"Browsed {url}",
|
||||
url=url,
|
||||
content=content,
|
||||
truncated=truncated,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("[browse_web] Failed for %s", url)
|
||||
return ErrorResponse(
|
||||
message="Failed to browse URL.",
|
||||
error="browse_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
finally:
|
||||
if client is not None:
|
||||
try:
|
||||
await client.close()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -1,486 +0,0 @@
|
||||
"""Unit tests for BrowseWebTool.
|
||||
|
||||
All tests run without a running server / database. External dependencies
|
||||
(Stagehand, Browserbase) are mocked via sys.modules injection so the suite
|
||||
stays fast and deterministic.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import threading
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import backend.copilot.tools.browse_web as _browse_web_mod
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.browse_web import (
|
||||
_MAX_CONTENT_CHARS,
|
||||
BrowseWebTool,
|
||||
_patch_stagehand_once,
|
||||
)
|
||||
from backend.copilot.tools.models import BrowseWebResponse, ErrorResponse, ResponseType
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_session(user_id: str = "test-user") -> ChatSession:
|
||||
return ChatSession(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
messages=[],
|
||||
usage=[],
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
successful_agent_runs={},
|
||||
successful_agent_schedules={},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_stagehand_patch():
|
||||
"""Reset the process-level _stagehand_patched flag before every test."""
|
||||
_browse_web_mod._stagehand_patched = False
|
||||
yield
|
||||
_browse_web_mod._stagehand_patched = False
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def env_vars(monkeypatch):
|
||||
"""Inject the three env vars required by BrowseWebTool."""
|
||||
monkeypatch.setenv("STAGEHAND_API_KEY", "test-api-key")
|
||||
monkeypatch.setenv("STAGEHAND_PROJECT_ID", "test-project-id")
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-anthropic-key")
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def stagehand_mocks(monkeypatch):
|
||||
"""Inject mock stagehand + stagehand.main into sys.modules.
|
||||
|
||||
Returns a dict with the mock objects so individual tests can
|
||||
assert on calls or inject side-effects.
|
||||
"""
|
||||
# --- mock page ---
|
||||
mock_result = MagicMock()
|
||||
mock_result.model_dump.return_value = {"extraction": "Page content here"}
|
||||
|
||||
mock_page = AsyncMock()
|
||||
mock_page.goto = AsyncMock(return_value=None)
|
||||
mock_page.extract = AsyncMock(return_value=mock_result)
|
||||
|
||||
# --- mock client ---
|
||||
mock_client = AsyncMock()
|
||||
mock_client.page = mock_page
|
||||
mock_client.init = AsyncMock(return_value=None)
|
||||
mock_client.close = AsyncMock(return_value=None)
|
||||
|
||||
MockStagehand = MagicMock(return_value=mock_client)
|
||||
|
||||
# --- stagehand top-level module ---
|
||||
mock_stagehand = MagicMock()
|
||||
mock_stagehand.Stagehand = MockStagehand
|
||||
|
||||
# --- stagehand.main (needed by _patch_stagehand_once) ---
|
||||
mock_main = MagicMock()
|
||||
mock_main.Stagehand = MagicMock()
|
||||
mock_main.Stagehand._register_signal_handlers = MagicMock()
|
||||
|
||||
monkeypatch.setitem(sys.modules, "stagehand", mock_stagehand)
|
||||
monkeypatch.setitem(sys.modules, "stagehand.main", mock_main)
|
||||
|
||||
return {
|
||||
"client": mock_client,
|
||||
"page": mock_page,
|
||||
"result": mock_result,
|
||||
"MockStagehand": MockStagehand,
|
||||
"mock_main": mock_main,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. Tool metadata
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBrowseWebToolMetadata:
|
||||
def test_name(self):
|
||||
assert BrowseWebTool().name == "browse_web"
|
||||
|
||||
def test_requires_auth(self):
|
||||
assert BrowseWebTool().requires_auth is True
|
||||
|
||||
def test_url_is_required_parameter(self):
|
||||
params = BrowseWebTool().parameters
|
||||
assert "url" in params["properties"]
|
||||
assert "url" in params["required"]
|
||||
|
||||
def test_instruction_is_optional(self):
|
||||
params = BrowseWebTool().parameters
|
||||
assert "instruction" in params["properties"]
|
||||
assert "instruction" not in params.get("required", [])
|
||||
|
||||
def test_registered_in_tool_registry(self):
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
assert "browse_web" in TOOL_REGISTRY
|
||||
assert isinstance(TOOL_REGISTRY["browse_web"], BrowseWebTool)
|
||||
|
||||
def test_response_type_enum_value(self):
|
||||
assert ResponseType.BROWSE_WEB == "browse_web"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Input validation (no external deps)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInputValidation:
|
||||
async def test_missing_url_returns_error(self):
|
||||
result = await BrowseWebTool()._execute(user_id="u1", session=make_session())
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "url" in result.message.lower()
|
||||
|
||||
async def test_empty_url_returns_error(self):
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url=""
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
async def test_ftp_url_rejected(self):
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="ftp://example.com/file"
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "http" in result.message.lower()
|
||||
|
||||
async def test_file_url_rejected(self):
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="file:///etc/passwd"
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
async def test_javascript_url_rejected(self):
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="javascript:alert(1)"
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Environment variable checks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnvVarChecks:
|
||||
async def test_missing_api_key(self, monkeypatch):
|
||||
monkeypatch.delenv("STAGEHAND_API_KEY", raising=False)
|
||||
monkeypatch.setenv("STAGEHAND_PROJECT_ID", "proj")
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "key")
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="https://example.com"
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "not_configured"
|
||||
|
||||
async def test_missing_project_id(self, monkeypatch):
|
||||
monkeypatch.setenv("STAGEHAND_API_KEY", "key")
|
||||
monkeypatch.delenv("STAGEHAND_PROJECT_ID", raising=False)
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "key")
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="https://example.com"
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "not_configured"
|
||||
|
||||
async def test_missing_anthropic_key(self, monkeypatch):
|
||||
monkeypatch.setenv("STAGEHAND_API_KEY", "key")
|
||||
monkeypatch.setenv("STAGEHAND_PROJECT_ID", "proj")
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="https://example.com"
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "not_configured"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Stagehand absent (ImportError path)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStagehandAbsent:
|
||||
async def test_returns_not_configured_error(self, env_vars, monkeypatch):
|
||||
"""Blocking the stagehand import must return a graceful ErrorResponse."""
|
||||
# sys.modules entry set to None → Python raises ImportError on import
|
||||
monkeypatch.setitem(sys.modules, "stagehand", None)
|
||||
monkeypatch.setitem(sys.modules, "stagehand.main", None)
|
||||
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="https://example.com"
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "not_configured"
|
||||
assert "not available" in result.message or "not installed" in result.message
|
||||
|
||||
async def test_other_tools_unaffected_when_stagehand_absent(
|
||||
self, env_vars, monkeypatch
|
||||
):
|
||||
"""Registry import must not raise even when stagehand is blocked."""
|
||||
monkeypatch.setitem(sys.modules, "stagehand", None)
|
||||
# This import already happened at module load; just verify the registry exists
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
assert "browse_web" in TOOL_REGISTRY
|
||||
assert "web_fetch" in TOOL_REGISTRY # unrelated tool still present
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Successful browse
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSuccessfulBrowse:
|
||||
async def test_returns_browse_web_response(self, env_vars, stagehand_mocks):
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="https://example.com"
|
||||
)
|
||||
assert isinstance(result, BrowseWebResponse)
|
||||
assert result.url == "https://example.com"
|
||||
assert result.content == "Page content here"
|
||||
assert result.truncated is False
|
||||
|
||||
async def test_http_url_accepted(self, env_vars, stagehand_mocks):
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="http://example.com"
|
||||
)
|
||||
assert isinstance(result, BrowseWebResponse)
|
||||
|
||||
async def test_session_id_propagated(self, env_vars, stagehand_mocks):
|
||||
session = make_session()
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=session, url="https://example.com"
|
||||
)
|
||||
assert isinstance(result, BrowseWebResponse)
|
||||
assert result.session_id == session.session_id
|
||||
|
||||
async def test_custom_instruction_forwarded_to_extract(
|
||||
self, env_vars, stagehand_mocks
|
||||
):
|
||||
await BrowseWebTool()._execute(
|
||||
user_id="u1",
|
||||
session=make_session(),
|
||||
url="https://example.com",
|
||||
instruction="Extract all pricing plans",
|
||||
)
|
||||
stagehand_mocks["page"].extract.assert_awaited_once()
|
||||
first_arg = stagehand_mocks["page"].extract.call_args[0][0]
|
||||
assert first_arg == "Extract all pricing plans"
|
||||
|
||||
async def test_default_instruction_used_when_omitted(
|
||||
self, env_vars, stagehand_mocks
|
||||
):
|
||||
await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="https://example.com"
|
||||
)
|
||||
first_arg = stagehand_mocks["page"].extract.call_args[0][0]
|
||||
assert "main content" in first_arg.lower()
|
||||
|
||||
async def test_explicit_timeouts_passed_to_stagehand(
|
||||
self, env_vars, stagehand_mocks
|
||||
):
|
||||
from backend.copilot.tools.browse_web import (
|
||||
_EXTRACT_TIMEOUT_MS,
|
||||
_GOTO_TIMEOUT_MS,
|
||||
)
|
||||
|
||||
await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="https://example.com"
|
||||
)
|
||||
goto_kwargs = stagehand_mocks["page"].goto.call_args[1]
|
||||
extract_kwargs = stagehand_mocks["page"].extract.call_args[1]
|
||||
assert goto_kwargs.get("timeoutMs") == _GOTO_TIMEOUT_MS
|
||||
assert extract_kwargs.get("timeoutMs") == _EXTRACT_TIMEOUT_MS
|
||||
|
||||
async def test_client_closed_after_success(self, env_vars, stagehand_mocks):
|
||||
await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="https://example.com"
|
||||
)
|
||||
stagehand_mocks["client"].close.assert_awaited_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. Truncation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTruncation:
|
||||
async def test_short_content_not_truncated(self, env_vars, stagehand_mocks):
|
||||
stagehand_mocks["result"].model_dump.return_value = {"extraction": "short"}
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="https://example.com"
|
||||
)
|
||||
assert isinstance(result, BrowseWebResponse)
|
||||
assert result.truncated is False
|
||||
assert result.content == "short"
|
||||
|
||||
async def test_oversized_content_is_truncated(self, env_vars, stagehand_mocks):
|
||||
big = "a" * (_MAX_CONTENT_CHARS + 1000)
|
||||
stagehand_mocks["result"].model_dump.return_value = {"extraction": big}
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="https://example.com"
|
||||
)
|
||||
assert isinstance(result, BrowseWebResponse)
|
||||
assert result.truncated is True
|
||||
assert result.content.endswith("[Content truncated]")
|
||||
|
||||
async def test_truncated_content_never_exceeds_cap(self, env_vars, stagehand_mocks):
|
||||
"""The final string must be ≤ _MAX_CONTENT_CHARS regardless of input size."""
|
||||
big = "b" * (_MAX_CONTENT_CHARS * 3)
|
||||
stagehand_mocks["result"].model_dump.return_value = {"extraction": big}
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="https://example.com"
|
||||
)
|
||||
assert isinstance(result, BrowseWebResponse)
|
||||
assert len(result.content) == _MAX_CONTENT_CHARS
|
||||
|
||||
async def test_content_exactly_at_limit_not_truncated(
|
||||
self, env_vars, stagehand_mocks
|
||||
):
|
||||
exact = "c" * _MAX_CONTENT_CHARS
|
||||
stagehand_mocks["result"].model_dump.return_value = {"extraction": exact}
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="https://example.com"
|
||||
)
|
||||
assert isinstance(result, BrowseWebResponse)
|
||||
assert result.truncated is False
|
||||
assert len(result.content) == _MAX_CONTENT_CHARS
|
||||
|
||||
async def test_empty_extraction_returns_empty_content(
|
||||
self, env_vars, stagehand_mocks
|
||||
):
|
||||
stagehand_mocks["result"].model_dump.return_value = {"extraction": ""}
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="https://example.com"
|
||||
)
|
||||
assert isinstance(result, BrowseWebResponse)
|
||||
assert result.content == ""
|
||||
assert result.truncated is False
|
||||
|
||||
async def test_none_extraction_returns_empty_content(
|
||||
self, env_vars, stagehand_mocks
|
||||
):
|
||||
stagehand_mocks["result"].model_dump.return_value = {"extraction": None}
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="https://example.com"
|
||||
)
|
||||
assert isinstance(result, BrowseWebResponse)
|
||||
assert result.content == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. Error handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
async def test_stagehand_init_exception_returns_generic_error(
|
||||
self, env_vars, stagehand_mocks
|
||||
):
|
||||
stagehand_mocks["client"].init.side_effect = RuntimeError("Connection refused")
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="https://example.com"
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "browse_failed"
|
||||
|
||||
async def test_raw_exception_text_not_leaked_to_user(
|
||||
self, env_vars, stagehand_mocks
|
||||
):
|
||||
"""Internal error details must not appear in the user-facing message."""
|
||||
stagehand_mocks["client"].init.side_effect = RuntimeError("SECRET_TOKEN_abc123")
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="https://example.com"
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "SECRET_TOKEN_abc123" not in result.message
|
||||
assert result.message == "Failed to browse URL."
|
||||
|
||||
async def test_goto_timeout_returns_error(self, env_vars, stagehand_mocks):
|
||||
stagehand_mocks["page"].goto.side_effect = TimeoutError("Navigation timed out")
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="https://example.com"
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "browse_failed"
|
||||
|
||||
async def test_client_closed_after_exception(self, env_vars, stagehand_mocks):
|
||||
stagehand_mocks["page"].goto.side_effect = RuntimeError("boom")
|
||||
await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="https://example.com"
|
||||
)
|
||||
stagehand_mocks["client"].close.assert_awaited_once()
|
||||
|
||||
async def test_close_failure_does_not_propagate(self, env_vars, stagehand_mocks):
|
||||
"""If close() itself raises, the tool must still return ErrorResponse."""
|
||||
stagehand_mocks["client"].init.side_effect = RuntimeError("init failed")
|
||||
stagehand_mocks["client"].close.side_effect = RuntimeError("close also failed")
|
||||
result = await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="https://example.com"
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. Thread-safety of _patch_stagehand_once
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPatchStagehandOnce:
|
||||
def test_idempotent_double_call(self, stagehand_mocks):
|
||||
"""_stagehand_patched transitions False→True exactly once."""
|
||||
assert _browse_web_mod._stagehand_patched is False
|
||||
_patch_stagehand_once()
|
||||
assert _browse_web_mod._stagehand_patched is True
|
||||
_patch_stagehand_once() # second call — still True, not re-patched
|
||||
assert _browse_web_mod._stagehand_patched is True
|
||||
|
||||
def test_safe_register_is_noop_in_worker_thread(self, stagehand_mocks):
|
||||
"""The patched handler must silently do nothing when called from a worker."""
|
||||
_patch_stagehand_once()
|
||||
mock_main = sys.modules["stagehand.main"]
|
||||
safe_register = mock_main.Stagehand._register_signal_handlers
|
||||
|
||||
errors: list[Exception] = []
|
||||
|
||||
def run():
|
||||
try:
|
||||
safe_register(MagicMock())
|
||||
except Exception as exc:
|
||||
errors.append(exc)
|
||||
|
||||
t = threading.Thread(target=run)
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
assert errors == [], f"Worker thread raised: {errors}"
|
||||
|
||||
def test_patched_flag_set_after_execution(self, env_vars, stagehand_mocks):
|
||||
"""After a successful browse, _stagehand_patched must be True."""
|
||||
|
||||
async def _run():
|
||||
return await BrowseWebTool()._execute(
|
||||
user_id="u1", session=make_session(), url="https://example.com"
|
||||
)
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(_run())
|
||||
assert _browse_web_mod._stagehand_patched is True
|
||||
@@ -39,9 +39,13 @@ class CreateAgentTool(BaseTool):
|
||||
return (
|
||||
"Create a new agent workflow from a natural language description. "
|
||||
"First generates a preview, then saves to library if save=true. "
|
||||
"\n\nIMPORTANT: Before calling this tool, search for relevant existing agents "
|
||||
"using find_library_agent that could be used as building blocks. "
|
||||
"Pass their IDs in the library_agent_ids parameter so the generator can compose them."
|
||||
"\n\nWorkflow: (1) Always check find_library_agent first for existing building blocks. "
|
||||
"(2) Call create_agent with description and library_agent_ids. "
|
||||
"(3) If response contains suggested_goal: Present to user, ask for confirmation, "
|
||||
"then call again with the suggested goal if accepted. "
|
||||
"(4) If response contains clarifying_questions: Present to user, collect answers, "
|
||||
"then call again with original description AND answers in the context parameter. "
|
||||
"\n\nThis feedback loop ensures the generated agent matches user intent."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -84,6 +88,14 @@ class CreateAgentTool(BaseTool):
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
"folder_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional folder ID to save the agent into. "
|
||||
"If not provided, the agent is saved at root level. "
|
||||
"Use list_folders to find available folders."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["description"],
|
||||
}
|
||||
@@ -105,6 +117,7 @@ class CreateAgentTool(BaseTool):
|
||||
context = kwargs.get("context", "")
|
||||
library_agent_ids = kwargs.get("library_agent_ids", [])
|
||||
save = kwargs.get("save", True)
|
||||
folder_id = kwargs.get("folder_id")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
logger.info(
|
||||
@@ -336,7 +349,7 @@ class CreateAgentTool(BaseTool):
|
||||
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
agent_json, user_id
|
||||
agent_json, user_id, folder_id=folder_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.store.exceptions import AgentNotFoundError
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import store_db as get_store_db
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
@@ -80,6 +80,14 @@ class CustomizeAgentTool(BaseTool):
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
"folder_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional folder ID to save the agent into. "
|
||||
"If not provided, the agent is saved at root level. "
|
||||
"Use list_folders to find available folders."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["agent_id", "modifications"],
|
||||
}
|
||||
@@ -102,6 +110,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
modifications = kwargs.get("modifications", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
folder_id = kwargs.get("folder_id")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_id:
|
||||
@@ -140,7 +149,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
agent_details = await store_db.get_store_agent_details(
|
||||
username=creator_username, agent_name=agent_slug
|
||||
)
|
||||
except AgentNotFoundError:
|
||||
except NotFoundError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Could not find marketplace agent '{agent_id}'. "
|
||||
@@ -310,7 +319,7 @@ class CustomizeAgentTool(BaseTool):
|
||||
# Save to user's library
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
customized_agent, user_id, is_update=False
|
||||
customized_agent, user_id, is_update=False, folder_id=folder_id
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
|
||||
170
autogpt_platform/backend/backend/copilot/tools/e2b_sandbox.py
Normal file
170
autogpt_platform/backend/backend/copilot/tools/e2b_sandbox.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""E2B sandbox lifecycle for CoPilot: persistent cloud execution.
|
||||
|
||||
Each session gets a long-lived E2B cloud sandbox. ``bash_exec`` runs commands
|
||||
directly on the sandbox via ``sandbox.commands.run()``. SDK file tools
|
||||
(read_file/write_file/edit_file/glob/grep) route to the sandbox's
|
||||
``/home/user`` directory via E2B's HTTP-based filesystem API — all tools
|
||||
share a single coherent filesystem with no local sync required.
|
||||
|
||||
Lifecycle
|
||||
---------
|
||||
1. **Turn start** – connect to the existing sandbox (sandbox_id in Redis) or
|
||||
create a new one via ``get_or_create_sandbox()``.
|
||||
2. **Execution** – ``bash_exec`` and MCP file tools operate directly on the
|
||||
sandbox's ``/home/user`` filesystem.
|
||||
3. **Session expiry** – E2B sandbox is killed by its own timeout (session_ttl).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SANDBOX_REDIS_PREFIX = "copilot:e2b:sandbox:"
|
||||
E2B_WORKDIR = "/home/user"
|
||||
_CREATING = "__creating__"
|
||||
_CREATION_LOCK_TTL = 60
|
||||
_MAX_WAIT_ATTEMPTS = 20 # 20 * 0.5s = 10s max wait
|
||||
|
||||
|
||||
async def _try_reconnect(
|
||||
sandbox_id: str, api_key: str, redis_key: str, timeout: int
|
||||
) -> "AsyncSandbox | None":
|
||||
"""Try to reconnect to an existing sandbox. Returns None on failure."""
|
||||
try:
|
||||
sandbox = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
|
||||
if await sandbox.is_running():
|
||||
redis = await get_redis_async()
|
||||
await redis.expire(redis_key, timeout)
|
||||
return sandbox
|
||||
except Exception as exc:
|
||||
logger.warning("[E2B] Reconnect to %.12s failed: %s", sandbox_id, exc)
|
||||
|
||||
# Stale — clear Redis so a new sandbox can be created.
|
||||
redis = await get_redis_async()
|
||||
await redis.delete(redis_key)
|
||||
return None
|
||||
|
||||
|
||||
async def get_or_create_sandbox(
|
||||
session_id: str,
|
||||
api_key: str,
|
||||
template: str = "base",
|
||||
timeout: int = 43200,
|
||||
) -> AsyncSandbox:
|
||||
"""Return the existing E2B sandbox for *session_id* or create a new one.
|
||||
|
||||
The sandbox_id is persisted in Redis so the same sandbox is reused
|
||||
across turns. Concurrent calls for the same session are serialised
|
||||
via a Redis ``SET NX`` creation lock.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
|
||||
|
||||
# 1. Try reconnecting to an existing sandbox.
|
||||
raw = await redis.get(redis_key)
|
||||
if raw:
|
||||
sandbox_id = raw if isinstance(raw, str) else raw.decode()
|
||||
if sandbox_id != _CREATING:
|
||||
sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
|
||||
if sandbox:
|
||||
logger.info(
|
||||
"[E2B] Reconnected to %.12s for session %.12s",
|
||||
sandbox_id,
|
||||
session_id,
|
||||
)
|
||||
return sandbox
|
||||
|
||||
# 2. Claim creation lock. If another request holds it, wait for the result.
|
||||
claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
|
||||
if not claimed:
|
||||
for _ in range(_MAX_WAIT_ATTEMPTS):
|
||||
await asyncio.sleep(0.5)
|
||||
raw = await redis.get(redis_key)
|
||||
if not raw:
|
||||
break # Lock expired — fall through to retry creation
|
||||
sandbox_id = raw if isinstance(raw, str) else raw.decode()
|
||||
if sandbox_id != _CREATING:
|
||||
sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
|
||||
if sandbox:
|
||||
return sandbox
|
||||
break # Stale sandbox cleared — fall through to create
|
||||
|
||||
# Try to claim creation lock again after waiting.
|
||||
claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
|
||||
if not claimed:
|
||||
# Another process may have created a sandbox — try to use it.
|
||||
raw = await redis.get(redis_key)
|
||||
if raw:
|
||||
sandbox_id = raw if isinstance(raw, str) else raw.decode()
|
||||
if sandbox_id != _CREATING:
|
||||
sandbox = await _try_reconnect(
|
||||
sandbox_id, api_key, redis_key, timeout
|
||||
)
|
||||
if sandbox:
|
||||
return sandbox
|
||||
raise RuntimeError(
|
||||
f"Could not acquire E2B creation lock for session {session_id[:12]}"
|
||||
)
|
||||
|
||||
# 3. Create a new sandbox.
|
||||
try:
|
||||
sandbox = await AsyncSandbox.create(
|
||||
template=template, api_key=api_key, timeout=timeout
|
||||
)
|
||||
except Exception:
|
||||
await redis.delete(redis_key)
|
||||
raise
|
||||
|
||||
await redis.setex(redis_key, timeout, sandbox.sandbox_id)
|
||||
logger.info(
|
||||
"[E2B] Created sandbox %.12s for session %.12s",
|
||||
sandbox.sandbox_id,
|
||||
session_id,
|
||||
)
|
||||
return sandbox
|
||||
|
||||
|
||||
async def kill_sandbox(session_id: str, api_key: str) -> bool:
|
||||
"""Kill the E2B sandbox for *session_id* and clean up its Redis entry.
|
||||
|
||||
Returns ``True`` if a sandbox was found and killed, ``False`` otherwise.
|
||||
Safe to call even when no sandbox exists for the session.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
|
||||
raw = await redis.get(redis_key)
|
||||
if not raw:
|
||||
return False
|
||||
|
||||
sandbox_id = raw if isinstance(raw, str) else raw.decode()
|
||||
await redis.delete(redis_key)
|
||||
|
||||
if sandbox_id == _CREATING:
|
||||
return False
|
||||
|
||||
try:
|
||||
|
||||
async def _connect_and_kill():
|
||||
sandbox = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
|
||||
await sandbox.kill()
|
||||
|
||||
await asyncio.wait_for(_connect_and_kill(), timeout=10)
|
||||
logger.info(
|
||||
"[E2B] Killed sandbox %.12s for session %.12s",
|
||||
sandbox_id,
|
||||
session_id,
|
||||
)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[E2B] Failed to kill sandbox %.12s for session %.12s: %s",
|
||||
sandbox_id,
|
||||
session_id,
|
||||
exc,
|
||||
)
|
||||
return False
|
||||
@@ -0,0 +1,272 @@
|
||||
"""Tests for e2b_sandbox: get_or_create_sandbox, _try_reconnect, kill_sandbox.
|
||||
|
||||
Uses mock Redis and mock AsyncSandbox — no external dependencies.
|
||||
Tests are synchronous (using asyncio.run) to avoid conflicts with the
|
||||
session-scoped event loop in conftest.py.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .e2b_sandbox import (
|
||||
_CREATING,
|
||||
_SANDBOX_REDIS_PREFIX,
|
||||
_try_reconnect,
|
||||
get_or_create_sandbox,
|
||||
kill_sandbox,
|
||||
)
|
||||
|
||||
_KEY = f"{_SANDBOX_REDIS_PREFIX}sess-123"
|
||||
_API_KEY = "test-api-key"
|
||||
_TIMEOUT = 300
|
||||
|
||||
|
||||
def _mock_sandbox(sandbox_id: str = "sb-abc", running: bool = True) -> MagicMock:
|
||||
sb = MagicMock()
|
||||
sb.sandbox_id = sandbox_id
|
||||
sb.is_running = AsyncMock(return_value=running)
|
||||
return sb
|
||||
|
||||
|
||||
def _mock_redis(get_val: str | bytes | None = None, set_nx_result: bool = True):
|
||||
r = AsyncMock()
|
||||
r.get = AsyncMock(return_value=get_val)
|
||||
r.set = AsyncMock(return_value=set_nx_result)
|
||||
r.setex = AsyncMock()
|
||||
r.delete = AsyncMock()
|
||||
r.expire = AsyncMock()
|
||||
return r
|
||||
|
||||
|
||||
def _patch_redis(redis):
|
||||
return patch(
|
||||
"backend.copilot.tools.e2b_sandbox.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
return_value=redis,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _try_reconnect
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTryReconnect:
|
||||
def test_reconnect_success(self):
|
||||
sb = _mock_sandbox()
|
||||
redis = _mock_redis()
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
|
||||
|
||||
assert result is sb
|
||||
redis.expire.assert_awaited_once_with(_KEY, _TIMEOUT)
|
||||
redis.delete.assert_not_awaited()
|
||||
|
||||
def test_reconnect_not_running_clears_key(self):
|
||||
sb = _mock_sandbox(running=False)
|
||||
redis = _mock_redis()
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
|
||||
|
||||
assert result is None
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
def test_reconnect_exception_clears_key(self):
|
||||
redis = _mock_redis()
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
|
||||
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
|
||||
|
||||
assert result is None
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_or_create_sandbox
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetOrCreateSandbox:
|
||||
def test_reconnect_existing(self):
|
||||
"""When Redis has a valid sandbox_id, reconnect to it."""
|
||||
sb = _mock_sandbox()
|
||||
redis = _mock_redis(get_val="sb-abc")
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is sb
|
||||
mock_cls.create.assert_not_called()
|
||||
|
||||
def test_create_new_when_no_key(self):
|
||||
"""When Redis is empty, claim lock and create a new sandbox."""
|
||||
sb = _mock_sandbox("sb-new")
|
||||
redis = _mock_redis(get_val=None, set_nx_result=True)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.create = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is sb
|
||||
redis.setex.assert_awaited_once_with(_KEY, _TIMEOUT, "sb-new")
|
||||
|
||||
def test_create_failure_clears_lock(self):
|
||||
"""If sandbox creation fails, the Redis lock is deleted."""
|
||||
redis = _mock_redis(get_val=None, set_nx_result=True)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.create = AsyncMock(side_effect=RuntimeError("quota"))
|
||||
with pytest.raises(RuntimeError, match="quota"):
|
||||
asyncio.run(
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
def test_wait_for_lock_then_reconnect(self):
|
||||
"""When another process holds the lock, wait and reconnect."""
|
||||
sb = _mock_sandbox("sb-other")
|
||||
redis = _mock_redis()
|
||||
redis.get = AsyncMock(side_effect=[_CREATING, "sb-other"])
|
||||
redis.set = AsyncMock(return_value=False)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
patch(
|
||||
"backend.copilot.tools.e2b_sandbox.asyncio.sleep",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is sb
|
||||
|
||||
def test_stale_reconnect_clears_and_creates(self):
|
||||
"""When stored sandbox is stale, clear key and create a new one."""
|
||||
stale_sb = _mock_sandbox("sb-stale", running=False)
|
||||
new_sb = _mock_sandbox("sb-fresh")
|
||||
redis = _mock_redis(get_val="sb-stale", set_nx_result=True)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=stale_sb)
|
||||
mock_cls.create = AsyncMock(return_value=new_sb)
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is new_sb
|
||||
redis.delete.assert_awaited()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# kill_sandbox
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestKillSandbox:
|
||||
def test_kill_existing_sandbox(self):
|
||||
"""Kill a running sandbox and clean up Redis."""
|
||||
sb = _mock_sandbox()
|
||||
sb.kill = AsyncMock()
|
||||
redis = _mock_redis(get_val="sb-abc")
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is True
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
sb.kill.assert_awaited_once()
|
||||
|
||||
def test_kill_no_sandbox(self):
|
||||
"""No-op when no sandbox exists in Redis."""
|
||||
redis = _mock_redis(get_val=None)
|
||||
with _patch_redis(redis):
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_not_awaited()
|
||||
|
||||
def test_kill_creating_state(self):
|
||||
"""Clears Redis key but returns False when sandbox is still being created."""
|
||||
redis = _mock_redis(get_val=_CREATING)
|
||||
with _patch_redis(redis):
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
def test_kill_connect_failure(self):
|
||||
"""Returns False and cleans Redis if connect/kill fails."""
|
||||
redis = _mock_redis(get_val="sb-abc")
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
def test_kill_with_bytes_redis_value(self):
|
||||
"""Redis may return bytes — kill_sandbox should decode correctly."""
|
||||
sb = _mock_sandbox()
|
||||
sb.kill = AsyncMock()
|
||||
redis = _mock_redis(get_val=b"sb-abc")
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is True
|
||||
sb.kill.assert_awaited_once()
|
||||
|
||||
def test_kill_timeout_returns_false(self):
|
||||
"""Returns False when E2B API calls exceed the 10s timeout."""
|
||||
redis = _mock_redis(get_val="sb-abc")
|
||||
with (
|
||||
_patch_redis(redis),
|
||||
patch(
|
||||
"backend.copilot.tools.e2b_sandbox.asyncio.wait_for",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=asyncio.TimeoutError,
|
||||
),
|
||||
):
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
@@ -32,6 +32,7 @@ COPILOT_EXCLUDED_BLOCK_TYPES = {
|
||||
BlockType.NOTE, # Visual annotation only - no runtime behavior
|
||||
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
|
||||
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
|
||||
BlockType.MCP_TOOL, # Has dedicated run_mcp_tool tool with proper discovery + auth flow
|
||||
}
|
||||
|
||||
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
|
||||
|
||||
573
autogpt_platform/backend/backend/copilot/tools/manage_folders.py
Normal file
573
autogpt_platform/backend/backend/copilot/tools/manage_folders.py
Normal file
@@ -0,0 +1,573 @@
|
||||
"""Folder management tools for the copilot."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.library import model as library_model
|
||||
from backend.api.features.library.db import collect_tree_ids
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import library_db
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentsMovedToFolderResponse,
|
||||
ErrorResponse,
|
||||
FolderAgentSummary,
|
||||
FolderCreatedResponse,
|
||||
FolderDeletedResponse,
|
||||
FolderInfo,
|
||||
FolderListResponse,
|
||||
FolderMovedResponse,
|
||||
FolderTreeInfo,
|
||||
FolderUpdatedResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
|
||||
def _folder_to_info(
|
||||
folder: library_model.LibraryFolder,
|
||||
agents: list[FolderAgentSummary] | None = None,
|
||||
) -> FolderInfo:
|
||||
"""Convert a LibraryFolder DB model to a FolderInfo response model."""
|
||||
return FolderInfo(
|
||||
id=folder.id,
|
||||
name=folder.name,
|
||||
parent_id=folder.parent_id,
|
||||
icon=folder.icon,
|
||||
color=folder.color,
|
||||
agent_count=folder.agent_count,
|
||||
subfolder_count=folder.subfolder_count,
|
||||
agents=agents,
|
||||
)
|
||||
|
||||
|
||||
def _tree_to_info(
|
||||
tree: library_model.LibraryFolderTree,
|
||||
agents_map: dict[str, list[FolderAgentSummary]] | None = None,
|
||||
) -> FolderTreeInfo:
|
||||
"""Recursively convert a LibraryFolderTree to a FolderTreeInfo response."""
|
||||
return FolderTreeInfo(
|
||||
id=tree.id,
|
||||
name=tree.name,
|
||||
parent_id=tree.parent_id,
|
||||
icon=tree.icon,
|
||||
color=tree.color,
|
||||
agent_count=tree.agent_count,
|
||||
subfolder_count=tree.subfolder_count,
|
||||
children=[_tree_to_info(child, agents_map) for child in tree.children],
|
||||
agents=agents_map.get(tree.id) if agents_map else None,
|
||||
)
|
||||
|
||||
|
||||
def _to_agent_summaries(
|
||||
raw: list[dict[str, str | None]],
|
||||
) -> list[FolderAgentSummary]:
|
||||
"""Convert raw agent dicts to typed FolderAgentSummary models."""
|
||||
return [
|
||||
FolderAgentSummary(
|
||||
id=a["id"] or "",
|
||||
name=a["name"] or "",
|
||||
description=a["description"] or "",
|
||||
)
|
||||
for a in raw
|
||||
]
|
||||
|
||||
|
||||
def _to_agent_summaries_map(
|
||||
raw: dict[str, list[dict[str, str | None]]],
|
||||
) -> dict[str, list[FolderAgentSummary]]:
|
||||
"""Convert a folder-id-keyed dict of raw agents to typed summaries."""
|
||||
return {fid: _to_agent_summaries(agents) for fid, agents in raw.items()}
|
||||
|
||||
|
||||
class CreateFolderTool(BaseTool):
|
||||
"""Tool for creating a library folder."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "create_folder"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Create a new folder in the user's library to organize agents. "
|
||||
"Optionally nest it inside an existing folder using parent_id."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Name for the new folder (max 100 chars).",
|
||||
},
|
||||
"parent_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"ID of the parent folder to nest inside. "
|
||||
"Omit to create at root level."
|
||||
),
|
||||
},
|
||||
"icon": {
|
||||
"type": "string",
|
||||
"description": "Optional icon identifier for the folder.",
|
||||
},
|
||||
"color": {
|
||||
"type": "string",
|
||||
"description": "Optional hex color code (#RRGGBB).",
|
||||
},
|
||||
},
|
||||
"required": ["name"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""Create a folder with the given name and optional parent/icon/color."""
|
||||
assert user_id is not None # guaranteed by requires_auth
|
||||
name = (kwargs.get("name") or "").strip()
|
||||
parent_id = kwargs.get("parent_id")
|
||||
icon = kwargs.get("icon")
|
||||
color = kwargs.get("color")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not name:
|
||||
return ErrorResponse(
|
||||
message="Please provide a folder name.",
|
||||
error="missing_name",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
folder = await library_db().create_folder(
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
parent_id=parent_id,
|
||||
icon=icon,
|
||||
color=color,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to create folder: {e}",
|
||||
error="create_folder_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return FolderCreatedResponse(
|
||||
message=f"Folder '{folder.name}' created successfully!",
|
||||
folder=_folder_to_info(folder),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class ListFoldersTool(BaseTool):
|
||||
"""Tool for listing library folders."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "list_folders"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"List the user's library folders. "
|
||||
"Omit parent_id to get the full folder tree. "
|
||||
"Provide parent_id to list only direct children of that folder. "
|
||||
"Set include_agents=true to also return the agents inside each folder "
|
||||
"and root-level agents not in any folder. Always set include_agents=true "
|
||||
"when the user asks about agents, wants to see what's in their folders, "
|
||||
"or mentions agents alongside folders."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"parent_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"List children of this folder. "
|
||||
"Omit to get the full folder tree."
|
||||
),
|
||||
},
|
||||
"include_agents": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Whether to include the list of agents inside each folder. "
|
||||
"Defaults to false."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""List folders as a flat list (by parent) or full tree."""
|
||||
assert user_id is not None # guaranteed by requires_auth
|
||||
parent_id = kwargs.get("parent_id")
|
||||
include_agents = kwargs.get("include_agents", False)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
try:
|
||||
if parent_id:
|
||||
folders = await library_db().list_folders(
|
||||
user_id=user_id, parent_id=parent_id
|
||||
)
|
||||
raw_map = (
|
||||
await library_db().get_folder_agents_map(
|
||||
user_id, [f.id for f in folders]
|
||||
)
|
||||
if include_agents
|
||||
else None
|
||||
)
|
||||
agents_map = _to_agent_summaries_map(raw_map) if raw_map else None
|
||||
return FolderListResponse(
|
||||
message=f"Found {len(folders)} folder(s).",
|
||||
folders=[
|
||||
_folder_to_info(f, agents_map.get(f.id) if agents_map else None)
|
||||
for f in folders
|
||||
],
|
||||
count=len(folders),
|
||||
session_id=session_id,
|
||||
)
|
||||
else:
|
||||
tree = await library_db().get_folder_tree(user_id=user_id)
|
||||
all_ids = collect_tree_ids(tree)
|
||||
agents_map = None
|
||||
root_agents = None
|
||||
if include_agents:
|
||||
raw_map = await library_db().get_folder_agents_map(user_id, all_ids)
|
||||
agents_map = _to_agent_summaries_map(raw_map)
|
||||
root_agents = _to_agent_summaries(
|
||||
await library_db().get_root_agent_summaries(user_id)
|
||||
)
|
||||
return FolderListResponse(
|
||||
message=f"Found {len(all_ids)} folder(s) in your library.",
|
||||
tree=[_tree_to_info(t, agents_map) for t in tree],
|
||||
root_agents=root_agents,
|
||||
count=len(all_ids),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to list folders: {e}",
|
||||
error="list_folders_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class UpdateFolderTool(BaseTool):
|
||||
"""Tool for updating a folder's properties."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "update_folder"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Update a folder's name, icon, or color."
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"folder_id": {
|
||||
"type": "string",
|
||||
"description": "ID of the folder to update.",
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "New name for the folder.",
|
||||
},
|
||||
"icon": {
|
||||
"type": "string",
|
||||
"description": "New icon identifier.",
|
||||
},
|
||||
"color": {
|
||||
"type": "string",
|
||||
"description": "New hex color code (#RRGGBB).",
|
||||
},
|
||||
},
|
||||
"required": ["folder_id"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""Update a folder's name, icon, or color."""
|
||||
assert user_id is not None # guaranteed by requires_auth
|
||||
folder_id = (kwargs.get("folder_id") or "").strip()
|
||||
name = kwargs.get("name")
|
||||
icon = kwargs.get("icon")
|
||||
color = kwargs.get("color")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not folder_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide a folder_id.",
|
||||
error="missing_folder_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
folder = await library_db().update_folder(
|
||||
folder_id=folder_id,
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
icon=icon,
|
||||
color=color,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to update folder: {e}",
|
||||
error="update_folder_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return FolderUpdatedResponse(
|
||||
message=f"Folder updated to '{folder.name}'.",
|
||||
folder=_folder_to_info(folder),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class MoveFolderTool(BaseTool):
|
||||
"""Tool for moving a folder to a new parent."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "move_folder"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Move a folder to a different parent folder. "
|
||||
"Set target_parent_id to null to move to root level."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"folder_id": {
|
||||
"type": "string",
|
||||
"description": "ID of the folder to move.",
|
||||
},
|
||||
"target_parent_id": {
|
||||
"type": ["string", "null"],
|
||||
"description": (
|
||||
"ID of the new parent folder. "
|
||||
"Use null to move to root level."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["folder_id"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""Move a folder to a new parent or to root level."""
|
||||
assert user_id is not None # guaranteed by requires_auth
|
||||
folder_id = (kwargs.get("folder_id") or "").strip()
|
||||
target_parent_id = kwargs.get("target_parent_id")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not folder_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide a folder_id.",
|
||||
error="missing_folder_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
folder = await library_db().move_folder(
|
||||
folder_id=folder_id,
|
||||
user_id=user_id,
|
||||
target_parent_id=target_parent_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to move folder: {e}",
|
||||
error="move_folder_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
dest = "a subfolder" if target_parent_id else "root level"
|
||||
return FolderMovedResponse(
|
||||
message=f"Folder '{folder.name}' moved to {dest}.",
|
||||
folder=_folder_to_info(folder),
|
||||
target_parent_id=target_parent_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class DeleteFolderTool(BaseTool):
|
||||
"""Tool for deleting a folder."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "delete_folder"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Delete a folder from the user's library. "
|
||||
"Agents inside the folder are moved to root level (not deleted)."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"folder_id": {
|
||||
"type": "string",
|
||||
"description": "ID of the folder to delete.",
|
||||
},
|
||||
},
|
||||
"required": ["folder_id"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""Soft-delete a folder; agents inside are moved to root level."""
|
||||
assert user_id is not None # guaranteed by requires_auth
|
||||
folder_id = (kwargs.get("folder_id") or "").strip()
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not folder_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide a folder_id.",
|
||||
error="missing_folder_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
await library_db().delete_folder(
|
||||
folder_id=folder_id,
|
||||
user_id=user_id,
|
||||
soft_delete=True,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to delete folder: {e}",
|
||||
error="delete_folder_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return FolderDeletedResponse(
|
||||
message="Folder deleted. Any agents inside were moved to root level.",
|
||||
folder_id=folder_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class MoveAgentsToFolderTool(BaseTool):
|
||||
"""Tool for moving agents into a folder."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "move_agents_to_folder"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Move one or more agents to a folder. "
|
||||
"Set folder_id to null to move agents to root level."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "List of library agent IDs to move.",
|
||||
},
|
||||
"folder_id": {
|
||||
"type": ["string", "null"],
|
||||
"description": (
|
||||
"Target folder ID. Use null to move to root level."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["agent_ids"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""Move one or more agents to a folder or to root level."""
|
||||
assert user_id is not None # guaranteed by requires_auth
|
||||
agent_ids = kwargs.get("agent_ids", [])
|
||||
folder_id = kwargs.get("folder_id")
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_ids:
|
||||
return ErrorResponse(
|
||||
message="Please provide at least one agent ID.",
|
||||
error="missing_agent_ids",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
moved = await library_db().bulk_move_agents_to_folder(
|
||||
agent_ids=agent_ids,
|
||||
folder_id=folder_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to move agents: {e}",
|
||||
error="move_agents_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
moved_ids = [a.id for a in moved]
|
||||
agent_names = [a.name for a in moved]
|
||||
dest = "the folder" if folder_id else "root level"
|
||||
names_str = (
|
||||
", ".join(agent_names) if agent_names else f"{len(agent_ids)} agent(s)"
|
||||
)
|
||||
return AgentsMovedToFolderResponse(
|
||||
message=f"Moved {names_str} to {dest}.",
|
||||
agent_ids=moved_ids,
|
||||
agent_names=agent_names,
|
||||
folder_id=folder_id,
|
||||
count=len(moved),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -0,0 +1,455 @@
|
||||
"""Tests for folder management copilot tools."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.library import model as library_model
|
||||
from backend.copilot.tools.manage_folders import (
|
||||
CreateFolderTool,
|
||||
DeleteFolderTool,
|
||||
ListFoldersTool,
|
||||
MoveAgentsToFolderTool,
|
||||
MoveFolderTool,
|
||||
UpdateFolderTool,
|
||||
)
|
||||
from backend.copilot.tools.models import (
|
||||
AgentsMovedToFolderResponse,
|
||||
ErrorResponse,
|
||||
FolderCreatedResponse,
|
||||
FolderDeletedResponse,
|
||||
FolderListResponse,
|
||||
FolderMovedResponse,
|
||||
FolderUpdatedResponse,
|
||||
)
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
_TEST_USER_ID = "test-user-folders"
|
||||
_NOW = datetime.now(UTC)
|
||||
|
||||
|
||||
def _make_folder(
|
||||
id: str = "folder-1",
|
||||
name: str = "My Folder",
|
||||
parent_id: str | None = None,
|
||||
icon: str | None = None,
|
||||
color: str | None = None,
|
||||
agent_count: int = 0,
|
||||
subfolder_count: int = 0,
|
||||
) -> library_model.LibraryFolder:
|
||||
return library_model.LibraryFolder(
|
||||
id=id,
|
||||
user_id=_TEST_USER_ID,
|
||||
name=name,
|
||||
icon=icon,
|
||||
color=color,
|
||||
parent_id=parent_id,
|
||||
created_at=_NOW,
|
||||
updated_at=_NOW,
|
||||
agent_count=agent_count,
|
||||
subfolder_count=subfolder_count,
|
||||
)
|
||||
|
||||
|
||||
def _make_tree(
|
||||
id: str = "folder-1",
|
||||
name: str = "Root",
|
||||
children: list[library_model.LibraryFolderTree] | None = None,
|
||||
) -> library_model.LibraryFolderTree:
|
||||
return library_model.LibraryFolderTree(
|
||||
id=id,
|
||||
user_id=_TEST_USER_ID,
|
||||
name=name,
|
||||
created_at=_NOW,
|
||||
updated_at=_NOW,
|
||||
children=children or [],
|
||||
)
|
||||
|
||||
|
||||
def _make_library_agent(id: str = "agent-1", name: str = "Test Agent"):
|
||||
agent = MagicMock()
|
||||
agent.id = id
|
||||
agent.name = name
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session():
|
||||
return make_session(_TEST_USER_ID)
|
||||
|
||||
|
||||
# ── CreateFolderTool ──
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_tool():
|
||||
return CreateFolderTool()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_folder_missing_name(create_tool, session):
|
||||
result = await create_tool._execute(user_id=_TEST_USER_ID, session=session, name="")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_name"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_folder_none_name(create_tool, session):
|
||||
result = await create_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, name=None
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_name"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_folder_success(create_tool, session):
|
||||
folder = _make_folder(name="New Folder")
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.create_folder = AsyncMock(return_value=folder)
|
||||
result = await create_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, name="New Folder"
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderCreatedResponse)
|
||||
assert result.folder.name == "New Folder"
|
||||
assert "New Folder" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_folder_db_error(create_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.create_folder = AsyncMock(
|
||||
side_effect=Exception("db down")
|
||||
)
|
||||
result = await create_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, name="Folder"
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "create_folder_failed"
|
||||
|
||||
|
||||
# ── ListFoldersTool ──
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def list_tool():
|
||||
return ListFoldersTool()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_folders_by_parent(list_tool, session):
|
||||
folders = [_make_folder(id="f1", name="A"), _make_folder(id="f2", name="B")]
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.list_folders = AsyncMock(return_value=folders)
|
||||
result = await list_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, parent_id="parent-1"
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderListResponse)
|
||||
assert result.count == 2
|
||||
assert len(result.folders) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_folders_tree(list_tool, session):
|
||||
tree = [
|
||||
_make_tree(id="r1", name="Root", children=[_make_tree(id="c1", name="Child")])
|
||||
]
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.get_folder_tree = AsyncMock(return_value=tree)
|
||||
result = await list_tool._execute(user_id=_TEST_USER_ID, session=session)
|
||||
|
||||
assert isinstance(result, FolderListResponse)
|
||||
assert result.count == 2 # root + child
|
||||
assert result.tree is not None
|
||||
assert len(result.tree) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_folders_tree_with_agents_includes_root(list_tool, session):
|
||||
tree = [_make_tree(id="r1", name="Root")]
|
||||
raw_map = {"r1": [{"id": "a1", "name": "Foldered", "description": "In folder"}]}
|
||||
root_raw = [{"id": "a2", "name": "Loose Agent", "description": "At root"}]
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.get_folder_tree = AsyncMock(return_value=tree)
|
||||
mock_lib.return_value.get_folder_agents_map = AsyncMock(return_value=raw_map)
|
||||
mock_lib.return_value.get_root_agent_summaries = AsyncMock(
|
||||
return_value=root_raw
|
||||
)
|
||||
result = await list_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, include_agents=True
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderListResponse)
|
||||
assert result.root_agents is not None
|
||||
assert len(result.root_agents) == 1
|
||||
assert result.root_agents[0].name == "Loose Agent"
|
||||
assert result.tree is not None
|
||||
assert result.tree[0].agents is not None
|
||||
assert result.tree[0].agents[0].name == "Foldered"
|
||||
mock_lib.return_value.get_root_agent_summaries.assert_awaited_once_with(
|
||||
_TEST_USER_ID
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_folders_tree_without_agents_no_root(list_tool, session):
|
||||
tree = [_make_tree(id="r1", name="Root")]
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.get_folder_tree = AsyncMock(return_value=tree)
|
||||
result = await list_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, include_agents=False
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderListResponse)
|
||||
assert result.root_agents is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_folders_db_error(list_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.get_folder_tree = AsyncMock(
|
||||
side_effect=Exception("timeout")
|
||||
)
|
||||
result = await list_tool._execute(user_id=_TEST_USER_ID, session=session)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "list_folders_failed"
|
||||
|
||||
|
||||
# ── UpdateFolderTool ──
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def update_tool():
|
||||
return UpdateFolderTool()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_folder_missing_id(update_tool, session):
|
||||
result = await update_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id=""
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_folder_id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_folder_none_id(update_tool, session):
|
||||
result = await update_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id=None
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_folder_id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_folder_success(update_tool, session):
|
||||
folder = _make_folder(name="Renamed")
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.update_folder = AsyncMock(return_value=folder)
|
||||
result = await update_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id="folder-1", name="Renamed"
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderUpdatedResponse)
|
||||
assert result.folder.name == "Renamed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_folder_db_error(update_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.update_folder = AsyncMock(
|
||||
side_effect=Exception("not found")
|
||||
)
|
||||
result = await update_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id="folder-1", name="X"
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "update_folder_failed"
|
||||
|
||||
|
||||
# ── MoveFolderTool ──
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def move_tool():
|
||||
return MoveFolderTool()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_folder_missing_id(move_tool, session):
|
||||
result = await move_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id=""
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_folder_id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_folder_to_parent(move_tool, session):
|
||||
folder = _make_folder(name="Moved")
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.move_folder = AsyncMock(return_value=folder)
|
||||
result = await move_tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
folder_id="folder-1",
|
||||
target_parent_id="parent-1",
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderMovedResponse)
|
||||
assert "subfolder" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_folder_to_root(move_tool, session):
|
||||
folder = _make_folder(name="Moved")
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.move_folder = AsyncMock(return_value=folder)
|
||||
result = await move_tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
folder_id="folder-1",
|
||||
target_parent_id=None,
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderMovedResponse)
|
||||
assert "root level" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_folder_db_error(move_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.move_folder = AsyncMock(side_effect=Exception("circular"))
|
||||
result = await move_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id="folder-1"
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "move_folder_failed"
|
||||
|
||||
|
||||
# ── DeleteFolderTool ──
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def delete_tool():
|
||||
return DeleteFolderTool()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_folder_missing_id(delete_tool, session):
|
||||
result = await delete_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id=""
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_folder_id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_folder_success(delete_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.delete_folder = AsyncMock(return_value=None)
|
||||
result = await delete_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id="folder-1"
|
||||
)
|
||||
|
||||
assert isinstance(result, FolderDeletedResponse)
|
||||
assert result.folder_id == "folder-1"
|
||||
assert "root level" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_folder_db_error(delete_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.delete_folder = AsyncMock(
|
||||
side_effect=Exception("permission denied")
|
||||
)
|
||||
result = await delete_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, folder_id="folder-1"
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "delete_folder_failed"
|
||||
|
||||
|
||||
# ── MoveAgentsToFolderTool ──
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def move_agents_tool():
|
||||
return MoveAgentsToFolderTool()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_agents_missing_ids(move_agents_tool, session):
|
||||
result = await move_agents_tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, agent_ids=[]
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_agent_ids"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_agents_success(move_agents_tool, session):
|
||||
agents = [
|
||||
_make_library_agent(id="a1", name="Agent Alpha"),
|
||||
_make_library_agent(id="a2", name="Agent Beta"),
|
||||
]
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.bulk_move_agents_to_folder = AsyncMock(
|
||||
return_value=agents
|
||||
)
|
||||
result = await move_agents_tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_ids=["a1", "a2"],
|
||||
folder_id="folder-1",
|
||||
)
|
||||
|
||||
assert isinstance(result, AgentsMovedToFolderResponse)
|
||||
assert result.count == 2
|
||||
assert result.agent_names == ["Agent Alpha", "Agent Beta"]
|
||||
assert "Agent Alpha" in result.message
|
||||
assert "Agent Beta" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_agents_to_root(move_agents_tool, session):
|
||||
agents = [_make_library_agent(id="a1", name="Agent One")]
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.bulk_move_agents_to_folder = AsyncMock(
|
||||
return_value=agents
|
||||
)
|
||||
result = await move_agents_tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_ids=["a1"],
|
||||
folder_id=None,
|
||||
)
|
||||
|
||||
assert isinstance(result, AgentsMovedToFolderResponse)
|
||||
assert "root level" in result.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_agents_db_error(move_agents_tool, session):
|
||||
with patch("backend.copilot.tools.manage_folders.library_db") as mock_lib:
|
||||
mock_lib.return_value.bulk_move_agents_to_folder = AsyncMock(
|
||||
side_effect=Exception("folder not found")
|
||||
)
|
||||
result = await move_agents_tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
agent_ids=["a1"],
|
||||
folder_id="bad-folder",
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "move_agents_failed"
|
||||
@@ -41,8 +41,10 @@ class ResponseType(str, Enum):
|
||||
INPUT_VALIDATION_ERROR = "input_validation_error"
|
||||
# Web fetch
|
||||
WEB_FETCH = "web_fetch"
|
||||
# Browser-based web browsing (JS-rendered pages)
|
||||
BROWSE_WEB = "browse_web"
|
||||
# Agent-browser multi-step automation (navigate, act, screenshot)
|
||||
BROWSER_NAVIGATE = "browser_navigate"
|
||||
BROWSER_ACT = "browser_act"
|
||||
BROWSER_SCREENSHOT = "browser_screenshot"
|
||||
# Code execution
|
||||
BASH_EXEC = "bash_exec"
|
||||
# Feature request types
|
||||
@@ -50,6 +52,16 @@ class ResponseType(str, Enum):
|
||||
FEATURE_REQUEST_CREATED = "feature_request_created"
|
||||
# Goal refinement
|
||||
SUGGESTED_GOAL = "suggested_goal"
|
||||
# MCP tool types
|
||||
MCP_TOOLS_DISCOVERED = "mcp_tools_discovered"
|
||||
MCP_TOOL_OUTPUT = "mcp_tool_output"
|
||||
# Folder management types
|
||||
FOLDER_CREATED = "folder_created"
|
||||
FOLDER_LIST = "folder_list"
|
||||
FOLDER_UPDATED = "folder_updated"
|
||||
FOLDER_MOVED = "folder_moved"
|
||||
FOLDER_DELETED = "folder_deleted"
|
||||
AGENTS_MOVED_TO_FOLDER = "agents_moved_to_folder"
|
||||
|
||||
|
||||
# Base response model
|
||||
@@ -440,15 +452,6 @@ class WebFetchResponse(ToolResponseBase):
|
||||
truncated: bool = False
|
||||
|
||||
|
||||
class BrowseWebResponse(ToolResponseBase):
|
||||
"""Response for browse_web tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BROWSE_WEB
|
||||
url: str
|
||||
content: str
|
||||
truncated: bool = False
|
||||
|
||||
|
||||
class BashExecResponse(ToolResponseBase):
|
||||
"""Response for bash_exec tool."""
|
||||
|
||||
@@ -487,3 +490,138 @@ class FeatureRequestCreatedResponse(ToolResponseBase):
|
||||
issue_url: str
|
||||
is_new_issue: bool # False if added to existing
|
||||
customer_name: str
|
||||
|
||||
|
||||
# MCP tool models
|
||||
class MCPToolInfo(BaseModel):
|
||||
"""Information about a single MCP tool discovered from a server."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
input_schema: dict[str, Any]
|
||||
|
||||
|
||||
class MCPToolsDiscoveredResponse(ToolResponseBase):
|
||||
"""Response when MCP tools are discovered from a server (agent-internal)."""
|
||||
|
||||
type: ResponseType = ResponseType.MCP_TOOLS_DISCOVERED
|
||||
server_url: str
|
||||
tools: list[MCPToolInfo]
|
||||
|
||||
|
||||
class MCPToolOutputResponse(ToolResponseBase):
|
||||
"""Response after executing an MCP tool."""
|
||||
|
||||
type: ResponseType = ResponseType.MCP_TOOL_OUTPUT
|
||||
server_url: str
|
||||
tool_name: str
|
||||
result: Any = None
|
||||
success: bool = True
|
||||
|
||||
|
||||
# Agent-browser multi-step automation models
|
||||
|
||||
|
||||
class BrowserNavigateResponse(ToolResponseBase):
|
||||
"""Response for browser_navigate tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BROWSER_NAVIGATE
|
||||
url: str
|
||||
title: str
|
||||
snapshot: str # Interactive accessibility tree with @ref IDs
|
||||
|
||||
|
||||
class BrowserActResponse(ToolResponseBase):
|
||||
"""Response for browser_act tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BROWSER_ACT
|
||||
action: str
|
||||
current_url: str = ""
|
||||
snapshot: str # Updated accessibility tree after the action
|
||||
|
||||
|
||||
class BrowserScreenshotResponse(ToolResponseBase):
|
||||
"""Response for browser_screenshot tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BROWSER_SCREENSHOT
|
||||
file_id: str # Workspace file ID — use read_workspace_file to retrieve
|
||||
filename: str
|
||||
|
||||
|
||||
# Folder management models
|
||||
|
||||
|
||||
class FolderAgentSummary(BaseModel):
|
||||
"""Lightweight agent info for folder listings."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str = ""
|
||||
|
||||
|
||||
class FolderInfo(BaseModel):
|
||||
"""Information about a folder."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
parent_id: str | None = None
|
||||
icon: str | None = None
|
||||
color: str | None = None
|
||||
agent_count: int = 0
|
||||
subfolder_count: int = 0
|
||||
agents: list[FolderAgentSummary] | None = None
|
||||
|
||||
|
||||
class FolderTreeInfo(FolderInfo):
|
||||
"""Folder with nested children for tree display."""
|
||||
|
||||
children: list["FolderTreeInfo"] = []
|
||||
|
||||
|
||||
class FolderCreatedResponse(ToolResponseBase):
|
||||
"""Response when a folder is created."""
|
||||
|
||||
type: ResponseType = ResponseType.FOLDER_CREATED
|
||||
folder: FolderInfo
|
||||
|
||||
|
||||
class FolderListResponse(ToolResponseBase):
|
||||
"""Response for listing folders."""
|
||||
|
||||
type: ResponseType = ResponseType.FOLDER_LIST
|
||||
folders: list[FolderInfo] = Field(default_factory=list)
|
||||
tree: list[FolderTreeInfo] | None = None
|
||||
root_agents: list[FolderAgentSummary] | None = None
|
||||
count: int = 0
|
||||
|
||||
|
||||
class FolderUpdatedResponse(ToolResponseBase):
|
||||
"""Response when a folder is updated."""
|
||||
|
||||
type: ResponseType = ResponseType.FOLDER_UPDATED
|
||||
folder: FolderInfo
|
||||
|
||||
|
||||
class FolderMovedResponse(ToolResponseBase):
|
||||
"""Response when a folder is moved."""
|
||||
|
||||
type: ResponseType = ResponseType.FOLDER_MOVED
|
||||
folder: FolderInfo
|
||||
target_parent_id: str | None = None
|
||||
|
||||
|
||||
class FolderDeletedResponse(ToolResponseBase):
|
||||
"""Response when a folder is deleted."""
|
||||
|
||||
type: ResponseType = ResponseType.FOLDER_DELETED
|
||||
folder_id: str
|
||||
|
||||
|
||||
class AgentsMovedToFolderResponse(ToolResponseBase):
|
||||
"""Response when agents are moved to a folder."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENTS_MOVED_TO_FOLDER
|
||||
agent_ids: list[str]
|
||||
agent_names: list[str] = []
|
||||
folder_id: str | None = None
|
||||
count: int = 0
|
||||
|
||||
343
autogpt_platform/backend/backend/copilot/tools/run_mcp_tool.py
Normal file
343
autogpt_platform/backend/backend/copilot/tools/run_mcp_tool.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""Tool for discovering and executing MCP (Model Context Protocol) server tools."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from backend.blocks.mcp.block import MCPToolBlock
|
||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
||||
from backend.blocks.mcp.helpers import (
|
||||
auto_lookup_mcp_credential,
|
||||
normalize_mcp_url,
|
||||
parse_mcp_content,
|
||||
server_host,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.utils import build_missing_credentials_from_field_info
|
||||
from backend.util.request import HTTPClientError, validate_url
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
ErrorResponse,
|
||||
MCPToolInfo,
|
||||
MCPToolOutputResponse,
|
||||
MCPToolsDiscoveredResponse,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# HTTP status codes that indicate authentication is required
|
||||
_AUTH_STATUS_CODES = {401, 403}
|
||||
|
||||
|
||||
class RunMCPToolTool(BaseTool):
|
||||
"""
|
||||
Tool for discovering and executing tools on any MCP server.
|
||||
|
||||
Stage 1 — discovery: call with just server_url to get available tools.
|
||||
Stage 2 — execution: call with server_url + tool_name + tool_arguments.
|
||||
If the server requires OAuth credentials that the user hasn't connected yet,
|
||||
a SetupRequirementsResponse is returned so the frontend can render the
|
||||
same OAuth login UI as the graph builder.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "run_mcp_tool"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Connect to an MCP (Model Context Protocol) server to discover and execute its tools. "
|
||||
"Two-step workflow: (1) Call with just `server_url` to discover available tools. "
|
||||
"(2) Call again with `server_url`, `tool_name`, and `tool_arguments` to execute. "
|
||||
"Known hosted servers (use directly): Notion (https://mcp.notion.com/mcp), "
|
||||
"Linear (https://mcp.linear.app/mcp), Stripe (https://mcp.stripe.com), "
|
||||
"Intercom (https://mcp.intercom.com/mcp), Cloudflare (https://mcp.cloudflare.com/mcp), "
|
||||
"Atlassian/Jira (https://mcp.atlassian.com/mcp). "
|
||||
"For other services, search the MCP registry at https://registry.modelcontextprotocol.io/. "
|
||||
"Authentication: If the server requires credentials, user will be prompted to complete the MCP credential setup flow."
|
||||
"Once connected and user confirms, retry the same call immediately."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"server_url": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"URL of the MCP server (Streamable HTTP endpoint), "
|
||||
"e.g. https://mcp.example.com/mcp"
|
||||
),
|
||||
},
|
||||
"tool_name": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Name of the MCP tool to execute. "
|
||||
"Omit on first call to discover available tools."
|
||||
),
|
||||
},
|
||||
"tool_arguments": {
|
||||
"type": "object",
|
||||
"description": (
|
||||
"Arguments to pass to the selected tool. "
|
||||
"Must match the tool's input schema returned during discovery."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["server_url"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
server_url: str = (kwargs.get("server_url") or "").strip()
|
||||
tool_name: str = (kwargs.get("tool_name") or "").strip()
|
||||
raw_tool_arguments = kwargs.get("tool_arguments")
|
||||
tool_arguments: dict[str, Any] = (
|
||||
raw_tool_arguments if isinstance(raw_tool_arguments, dict) else {}
|
||||
)
|
||||
session_id = session.session_id
|
||||
|
||||
if raw_tool_arguments is not None and not isinstance(raw_tool_arguments, dict):
|
||||
return ErrorResponse(
|
||||
message="tool_arguments must be a JSON object.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not server_url:
|
||||
return ErrorResponse(
|
||||
message="Please provide a server_url for the MCP server.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
_parsed = urlparse(server_url)
|
||||
if _parsed.username or _parsed.password:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Do not include credentials in server_url. "
|
||||
"Use the MCP credential setup flow instead."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
if _parsed.query or _parsed.fragment:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Do not include query parameters or fragments in server_url. "
|
||||
"Use the MCP credential setup flow instead."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges
|
||||
try:
|
||||
await validate_url(server_url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
msg = str(e)
|
||||
if "Unable to resolve" in msg or "No IP addresses" in msg:
|
||||
user_msg = (
|
||||
f"Hostname not found: {server_host(server_url)}. "
|
||||
"Please check the URL — the domain may not exist."
|
||||
)
|
||||
else:
|
||||
user_msg = f"Blocked server URL: {msg}"
|
||||
return ErrorResponse(message=user_msg, session_id=session_id)
|
||||
|
||||
# Fast DB lookup — no network call.
|
||||
# Normalize for matching because stored credentials use normalized URLs.
|
||||
creds = await auto_lookup_mcp_credential(user_id, normalize_mcp_url(server_url))
|
||||
auth_token = creds.access_token.get_secret_value() if creds else None
|
||||
|
||||
client = MCPClient(server_url, auth_token=auth_token)
|
||||
|
||||
try:
|
||||
await client.initialize()
|
||||
|
||||
if not tool_name:
|
||||
# Stage 1: Discover available tools
|
||||
return await self._discover_tools(client, server_url, session_id)
|
||||
else:
|
||||
# Stage 2: Execute the selected tool
|
||||
return await self._execute_tool(
|
||||
client, server_url, tool_name, tool_arguments, session_id
|
||||
)
|
||||
|
||||
except HTTPClientError as e:
|
||||
if e.status_code in _AUTH_STATUS_CODES and not creds:
|
||||
# Server requires auth and user has no stored credentials
|
||||
return self._build_setup_requirements(server_url, session_id)
|
||||
logger.warning("MCP HTTP error for %s: %s", server_host(server_url), e)
|
||||
return ErrorResponse(
|
||||
message=f"MCP server returned HTTP {e.status_code}: {e}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except MCPClientError as e:
|
||||
logger.warning("MCP client error for %s: %s", server_host(server_url), e)
|
||||
return ErrorResponse(
|
||||
message=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.error(
|
||||
"Unexpected error calling MCP server %s",
|
||||
server_host(server_url),
|
||||
exc_info=True,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message="An unexpected error occurred connecting to the MCP server. Please try again.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _discover_tools(
|
||||
self,
|
||||
client: MCPClient,
|
||||
server_url: str,
|
||||
session_id: str,
|
||||
) -> MCPToolsDiscoveredResponse:
|
||||
"""List available tools from an already-initialized MCPClient.
|
||||
|
||||
Called when the agent invokes run_mcp_tool with only server_url (no
|
||||
tool_name). Returns MCPToolsDiscoveredResponse so the agent can
|
||||
inspect tool schemas and choose one to execute in a follow-up call.
|
||||
"""
|
||||
tools = await client.list_tools()
|
||||
tool_infos = [
|
||||
MCPToolInfo(
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
input_schema=t.input_schema,
|
||||
)
|
||||
for t in tools
|
||||
]
|
||||
host = server_host(server_url)
|
||||
return MCPToolsDiscoveredResponse(
|
||||
message=(
|
||||
f"Discovered {len(tool_infos)} tool(s) on {host}. "
|
||||
"Call run_mcp_tool again with tool_name and tool_arguments to execute one."
|
||||
),
|
||||
server_url=server_url,
|
||||
tools=tool_infos,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _execute_tool(
|
||||
self,
|
||||
client: MCPClient,
|
||||
server_url: str,
|
||||
tool_name: str,
|
||||
tool_arguments: dict[str, Any],
|
||||
session_id: str,
|
||||
) -> MCPToolOutputResponse | ErrorResponse:
|
||||
"""Execute a specific tool on an already-initialized MCPClient.
|
||||
|
||||
Parses the MCP content response into a plain Python value:
|
||||
- text items: parsed as JSON when possible, kept as str otherwise
|
||||
- image items: kept as {type, data, mimeType} dict for frontend rendering
|
||||
- resource items: unwrapped to their resource payload dict
|
||||
Single-item responses are unwrapped from the list; multiple items are
|
||||
returned as a list; empty content returns None.
|
||||
"""
|
||||
result = await client.call_tool(tool_name, tool_arguments)
|
||||
|
||||
if result.is_error:
|
||||
error_text = " ".join(
|
||||
item.get("text", "")
|
||||
for item in result.content
|
||||
if item.get("type") == "text"
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=f"MCP tool '{tool_name}' returned an error: {error_text or 'Unknown error'}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
result_value = parse_mcp_content(result.content)
|
||||
|
||||
return MCPToolOutputResponse(
|
||||
message=f"MCP tool '{tool_name}' executed successfully.",
|
||||
server_url=server_url,
|
||||
tool_name=tool_name,
|
||||
result=result_value,
|
||||
success=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def _build_setup_requirements(
|
||||
self,
|
||||
server_url: str,
|
||||
session_id: str,
|
||||
) -> SetupRequirementsResponse | ErrorResponse:
|
||||
"""Build a SetupRequirementsResponse for a missing MCP server credential."""
|
||||
mcp_block = MCPToolBlock()
|
||||
credentials_fields_info = mcp_block.input_schema.get_credentials_fields_info()
|
||||
|
||||
# Apply the server_url discriminator value so the frontend's CredentialsGroupedView
|
||||
# can match the credential to the correct OAuth provider/server.
|
||||
for field_info in credentials_fields_info.values():
|
||||
if field_info.discriminator == "server_url":
|
||||
field_info.discriminator_values.add(server_url)
|
||||
|
||||
missing_creds_dict = build_missing_credentials_from_field_info(
|
||||
credentials_fields_info, matched_keys=set()
|
||||
)
|
||||
|
||||
if not missing_creds_dict:
|
||||
logger.error(
|
||||
"No credential requirements found for MCP server %s — "
|
||||
"MCPToolBlock may not have credentials configured",
|
||||
server_host(server_url),
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"The MCP server at {server_host(server_url)} requires authentication, "
|
||||
"but no credential configuration was found."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
missing_creds_list = list(missing_creds_dict.values())
|
||||
|
||||
host = server_host(server_url)
|
||||
return SetupRequirementsResponse(
|
||||
message=(
|
||||
f"The MCP server at {host} requires authentication. "
|
||||
"Please connect your credentials to continue."
|
||||
),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=server_url,
|
||||
agent_name=f"MCP: {host}",
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
missing_credentials=missing_creds_dict,
|
||||
ready_to_run=False,
|
||||
),
|
||||
requirements={
|
||||
"credentials": missing_creds_list,
|
||||
"inputs": [],
|
||||
"execution_modes": ["immediate"],
|
||||
},
|
||||
),
|
||||
graph_id=None,
|
||||
graph_version=None,
|
||||
)
|
||||
@@ -0,0 +1,759 @@
|
||||
"""Unit tests for the run_mcp_tool copilot tool."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.mcp.helpers import server_host
|
||||
|
||||
from ._test_data import make_session
|
||||
from .models import (
|
||||
ErrorResponse,
|
||||
MCPToolOutputResponse,
|
||||
MCPToolsDiscoveredResponse,
|
||||
SetupRequirementsResponse,
|
||||
)
|
||||
from .run_mcp_tool import RunMCPToolTool
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_USER_ID = "test-user-run-mcp-tool"
|
||||
_SERVER_URL = "https://remote.mcpservers.org/fetch/mcp"
|
||||
|
||||
|
||||
def _make_tool_list(*names: str):
|
||||
"""Build a list of mock MCPClientTool objects."""
|
||||
tools = []
|
||||
for name in names:
|
||||
t = MagicMock()
|
||||
t.name = name
|
||||
t.description = f"Description for {name}"
|
||||
t.input_schema = {"type": "object", "properties": {}, "required": []}
|
||||
tools.append(t)
|
||||
return tools
|
||||
|
||||
|
||||
def _make_call_result(content: list[dict], is_error: bool = False) -> MagicMock:
|
||||
result = MagicMock()
|
||||
result.is_error = is_error
|
||||
result.content = content
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# server_host helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_server_host_plain_url():
|
||||
assert server_host("https://mcp.example.com/mcp") == "mcp.example.com"
|
||||
|
||||
|
||||
def test_server_host_strips_credentials():
|
||||
"""netloc would expose user:pass — hostname must not."""
|
||||
assert server_host("https://user:secret@mcp.example.com/mcp") == "mcp.example.com"
|
||||
|
||||
|
||||
def test_server_host_with_port():
|
||||
"""Port should not appear in the returned hostname (hostname strips it)."""
|
||||
assert server_host("https://mcp.example.com:8080/mcp") == "mcp.example.com"
|
||||
|
||||
|
||||
def test_server_host_invalid_url():
|
||||
"""Falls back to the raw string for un-parseable URLs."""
|
||||
result = server_host("not-a-url")
|
||||
assert result == "not-a-url"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_missing_server_url_returns_error():
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
response = await tool._execute(user_id=_USER_ID, session=session)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "server_url" in response.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_missing_user_id_returns_error():
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
response = await tool._execute(
|
||||
user_id=None, session=session, server_url=_SERVER_URL
|
||||
)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "authentication" in response.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_ssrf_blocked_url_returns_error():
|
||||
"""Private/loopback URLs must be rejected before any network call."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked loopback"),
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID, session=session, server_url="http://localhost/mcp"
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert (
|
||||
"blocked" in response.message.lower() or "invalid" in response.message.lower()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_credential_bearing_url_returns_error():
|
||||
"""URLs with embedded user:pass@ must be rejected before any network call."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url="https://user:secret@mcp.example.com/mcp",
|
||||
)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert (
|
||||
"credential" in response.message.lower()
|
||||
or "do not include" in response.message.lower()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_non_dict_tool_arguments_returns_error():
|
||||
"""tool_arguments must be a JSON object — strings/arrays are rejected early."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
tool_name="fetch",
|
||||
tool_arguments=["this", "is", "a", "list"], # wrong type
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "json object" in response.message.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stage 1 — Discovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_returns_discovered_response():
|
||||
"""Calling with only server_url triggers discovery and returns tool list."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
mock_tools = _make_tool_list("fetch", "search")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tools)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
)
|
||||
|
||||
assert isinstance(response, MCPToolsDiscoveredResponse)
|
||||
assert len(response.tools) == 2
|
||||
assert response.tools[0].name == "fetch"
|
||||
assert response.tools[1].name == "search"
|
||||
assert response.server_url == _SERVER_URL
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_with_credentials():
|
||||
"""Stored credentials are passed as Bearer token to MCPClient."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
mock_creds = MagicMock()
|
||||
mock_creds.access_token = SecretStr("test-token-abc")
|
||||
mock_tools = _make_tool_list("push_notification")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_creds,
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tools)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
) as MockMCPClient:
|
||||
MockMCPClient.return_value = mock_client
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
)
|
||||
# Verify MCPClient was created with the resolved auth token
|
||||
MockMCPClient.assert_called_once_with(
|
||||
_SERVER_URL, auth_token="test-token-abc"
|
||||
)
|
||||
|
||||
assert isinstance(response, MCPToolsDiscoveredResponse)
|
||||
assert len(response.tools) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stage 2 — Execution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_execute_tool_returns_output_response():
|
||||
"""Calling with tool_name executes the tool and returns MCPToolOutputResponse."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
text_result = "# Example Domain\nThis domain is for examples."
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_result = _make_call_result([{"type": "text", "text": text_result}])
|
||||
mock_client = AsyncMock()
|
||||
mock_client.call_tool = AsyncMock(return_value=mock_result)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
tool_name="fetch",
|
||||
tool_arguments={"url": "https://example.com"},
|
||||
)
|
||||
|
||||
assert isinstance(response, MCPToolOutputResponse)
|
||||
assert response.tool_name == "fetch"
|
||||
assert response.server_url == _SERVER_URL
|
||||
assert response.success is True
|
||||
assert text_result in response.result
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_execute_tool_parses_json_result():
|
||||
"""JSON text content items are parsed into Python objects."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_result = _make_call_result(
|
||||
[{"type": "text", "text": '{"status": "ok", "count": 42}'}]
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.call_tool = AsyncMock(return_value=mock_result)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
tool_name="status",
|
||||
tool_arguments={},
|
||||
)
|
||||
|
||||
assert isinstance(response, MCPToolOutputResponse)
|
||||
assert response.result == {"status": "ok", "count": 42}
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_execute_tool_image_content():
|
||||
"""Image content items are returned as {type, data, mimeType} dicts."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_result = _make_call_result(
|
||||
[{"type": "image", "data": "abc123==", "mimeType": "image/png"}]
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.call_tool = AsyncMock(return_value=mock_result)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
tool_name="screenshot",
|
||||
tool_arguments={},
|
||||
)
|
||||
|
||||
assert isinstance(response, MCPToolOutputResponse)
|
||||
assert response.result == {
|
||||
"type": "image",
|
||||
"data": "abc123==",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_execute_tool_resource_content():
|
||||
"""Resource content items are unwrapped to their resource payload."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_result = _make_call_result(
|
||||
[
|
||||
{
|
||||
"type": "resource",
|
||||
"resource": {"uri": "file:///tmp/out.txt", "text": "hello"},
|
||||
}
|
||||
]
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.call_tool = AsyncMock(return_value=mock_result)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
tool_name="read_file",
|
||||
tool_arguments={},
|
||||
)
|
||||
|
||||
assert isinstance(response, MCPToolOutputResponse)
|
||||
assert response.result == {"uri": "file:///tmp/out.txt", "text": "hello"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_execute_tool_multi_item_content():
|
||||
"""Multiple content items are returned as a list."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_result = _make_call_result(
|
||||
[
|
||||
{"type": "text", "text": "part one"},
|
||||
{"type": "text", "text": "part two"},
|
||||
]
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.call_tool = AsyncMock(return_value=mock_result)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
tool_name="multi",
|
||||
tool_arguments={},
|
||||
)
|
||||
|
||||
assert isinstance(response, MCPToolOutputResponse)
|
||||
assert response.result == ["part one", "part two"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_execute_tool_empty_content_returns_none():
|
||||
"""Empty content list results in result=None."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_result = _make_call_result([])
|
||||
mock_client = AsyncMock()
|
||||
mock_client.call_tool = AsyncMock(return_value=mock_result)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
tool_name="ping",
|
||||
tool_arguments={},
|
||||
)
|
||||
|
||||
assert isinstance(response, MCPToolOutputResponse)
|
||||
assert response.result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_execute_tool_returns_error_on_tool_failure():
|
||||
"""When the MCP tool returns is_error=True, an ErrorResponse is returned."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_result = _make_call_result(
|
||||
[{"type": "text", "text": "Tool not found"}], is_error=True
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.call_tool = AsyncMock(return_value=mock_result)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
tool_name="nonexistent",
|
||||
tool_arguments={},
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "nonexistent" in response.message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth / credential flow
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_auth_required_without_creds_returns_setup_requirements():
|
||||
"""HTTP 401 from MCP with no stored creds → SetupRequirementsResponse."""
|
||||
from backend.util.request import HTTPClientError
|
||||
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None, # No stored credentials
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.initialize = AsyncMock(
|
||||
side_effect=HTTPClientError("Unauthorized", status_code=401)
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
with patch.object(
|
||||
RunMCPToolTool,
|
||||
"_build_setup_requirements",
|
||||
return_value=MagicMock(spec=SetupRequirementsResponse),
|
||||
) as mock_build:
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
)
|
||||
mock_build.assert_called_once()
|
||||
|
||||
# Should have returned what _build_setup_requirements returned
|
||||
assert response is mock_build.return_value
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_auth_error_with_existing_creds_returns_error():
|
||||
"""HTTP 403 when creds ARE present → generic ErrorResponse (not setup card)."""
|
||||
from backend.util.request import HTTPClientError
|
||||
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
mock_creds = MagicMock()
|
||||
mock_creds.access_token = SecretStr("stale-token")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_creds,
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.initialize = AsyncMock(
|
||||
side_effect=HTTPClientError("Forbidden", status_code=403)
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "403" in response.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_mcp_client_error_returns_error_response():
|
||||
"""MCPClientError (protocol-level) maps to a clean ErrorResponse."""
|
||||
from backend.blocks.mcp.client import MCPClientError
|
||||
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.initialize = AsyncMock(
|
||||
side_effect=MCPClientError("JSON-RPC protocol error")
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "JSON-RPC" in response.message or "protocol" in response.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unexpected_exception_returns_generic_error():
|
||||
"""Unhandled exceptions inside the MCP call don't leak traceback text to the user."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
# An unexpected error inside initialize (inside the try block)
|
||||
mock_client.initialize = AsyncMock(
|
||||
side_effect=ValueError("Unexpected internal error")
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
# Must not leak the raw exception message
|
||||
assert "Unexpected internal error" not in response.message
|
||||
assert (
|
||||
"unexpected" in response.message.lower() or "error" in response.message.lower()
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool metadata
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_tool_name():
|
||||
assert RunMCPToolTool().name == "run_mcp_tool"
|
||||
|
||||
|
||||
def test_tool_requires_auth():
|
||||
assert RunMCPToolTool().requires_auth is True
|
||||
|
||||
|
||||
def test_tool_parameters_schema():
|
||||
params = RunMCPToolTool().parameters
|
||||
assert params["type"] == "object"
|
||||
assert "server_url" in params["properties"]
|
||||
assert "tool_name" in params["properties"]
|
||||
assert "tool_arguments" in params["properties"]
|
||||
assert params["required"] == ["server_url"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Query/fragment rejection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_query_in_url_returns_error():
|
||||
"""server_url with query parameters must be rejected."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url="https://mcp.example.com/mcp?key=val",
|
||||
)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "query" in response.message.lower() or "fragment" in response.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_fragment_in_url_returns_error():
|
||||
"""server_url with a fragment must be rejected."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url="https://mcp.example.com/mcp#section",
|
||||
)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "query" in response.message.lower() or "fragment" in response.message.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Credential lookup normalization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_credential_lookup_normalizes_trailing_slash():
|
||||
"""Credential lookup must normalize the URL (strip trailing slash)."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
url_with_slash = "https://mcp.example.com/mcp/"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
) as mock_lookup:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=[])
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=url_with_slash,
|
||||
)
|
||||
# Credential lookup should use the normalized URL (no trailing slash)
|
||||
mock_lookup.assert_called_once_with(_USER_ID, "https://mcp.example.com/mcp")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_setup_requirements
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_build_setup_requirements_returns_setup_response():
|
||||
"""_build_setup_requirements should return a SetupRequirementsResponse."""
|
||||
tool = RunMCPToolTool()
|
||||
result = tool._build_setup_requirements(
|
||||
server_url=_SERVER_URL,
|
||||
session_id="test-session",
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert result.setup_info.agent_id == _SERVER_URL
|
||||
assert "authentication" in result.message.lower()
|
||||
@@ -30,6 +30,10 @@ _TEXT_CONTENT_TYPES = {
|
||||
"application/xhtml+xml",
|
||||
"application/rss+xml",
|
||||
"application/atom+xml",
|
||||
# RFC 7807 — JSON problem details; used by many REST APIs for error responses
|
||||
"application/problem+json",
|
||||
"application/problem+xml",
|
||||
"application/ld+json",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
|
||||
from backend.copilot.tools.sandbox import make_session_path
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.settings import Config
|
||||
@@ -20,7 +21,7 @@ from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_write_content(
|
||||
async def _resolve_write_content(
|
||||
content_text: str | None,
|
||||
content_b64: str | None,
|
||||
source_path: str | None,
|
||||
@@ -30,6 +31,9 @@ def _resolve_write_content(
|
||||
|
||||
Returns the raw bytes on success, or an ``ErrorResponse`` on validation
|
||||
failure (wrong number of sources, invalid path, file not found, etc.).
|
||||
|
||||
When an E2B sandbox is active, ``source_path`` reads from the sandbox
|
||||
filesystem instead of the local ephemeral directory.
|
||||
"""
|
||||
# Normalise empty strings to None so counting and dispatch stay in sync.
|
||||
if content_text is not None and content_text == "":
|
||||
@@ -54,24 +58,7 @@ def _resolve_write_content(
|
||||
)
|
||||
|
||||
if source_path is not None:
|
||||
validated = _validate_ephemeral_path(
|
||||
source_path, param_name="source_path", session_id=session_id
|
||||
)
|
||||
if isinstance(validated, ErrorResponse):
|
||||
return validated
|
||||
try:
|
||||
with open(validated, "rb") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
return ErrorResponse(
|
||||
message=f"Source file not found: {source_path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to read source file: {e}",
|
||||
session_id=session_id,
|
||||
)
|
||||
return await _read_source_path(source_path, session_id)
|
||||
|
||||
if content_b64 is not None:
|
||||
try:
|
||||
@@ -91,6 +78,106 @@ def _resolve_write_content(
|
||||
return content_text.encode("utf-8")
|
||||
|
||||
|
||||
def _resolve_sandbox_path(
|
||||
path: str, session_id: str | None, param_name: str
|
||||
) -> str | ErrorResponse:
|
||||
"""Normalize *path* to an absolute sandbox path under :data:`E2B_WORKDIR`.
|
||||
|
||||
Delegates to :func:`~backend.copilot.sdk.e2b_file_tools._resolve_remote`
|
||||
and wraps any ``ValueError`` into an :class:`ErrorResponse`.
|
||||
"""
|
||||
from backend.copilot.sdk.e2b_file_tools import _resolve_remote
|
||||
|
||||
try:
|
||||
return _resolve_remote(path)
|
||||
except ValueError:
|
||||
return ErrorResponse(
|
||||
message=f"{param_name} must be within {E2B_WORKDIR}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
async def _read_source_path(source_path: str, session_id: str) -> bytes | ErrorResponse:
|
||||
"""Read *source_path* from E2B sandbox or local ephemeral directory."""
|
||||
from backend.copilot.sdk.tool_adapter import get_current_sandbox
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
remote = _resolve_sandbox_path(source_path, session_id, "source_path")
|
||||
if isinstance(remote, ErrorResponse):
|
||||
return remote
|
||||
try:
|
||||
data = await sandbox.files.read(remote, format="bytes")
|
||||
return bytes(data)
|
||||
except Exception as exc:
|
||||
return ErrorResponse(
|
||||
message=f"Source file not found on sandbox: {source_path} ({exc})",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Local fallback: validate path stays within ephemeral directory.
|
||||
validated = _validate_ephemeral_path(
|
||||
source_path, param_name="source_path", session_id=session_id
|
||||
)
|
||||
if isinstance(validated, ErrorResponse):
|
||||
return validated
|
||||
try:
|
||||
with open(validated, "rb") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
return ErrorResponse(
|
||||
message=f"Source file not found: {source_path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to read source file: {e}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
async def _save_to_path(
|
||||
path: str, content: bytes, session_id: str
|
||||
) -> str | ErrorResponse:
|
||||
"""Write *content* to *path* on E2B sandbox or local ephemeral directory.
|
||||
|
||||
Returns the resolved path on success, or an ``ErrorResponse`` on failure.
|
||||
"""
|
||||
from backend.copilot.sdk.tool_adapter import get_current_sandbox
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
remote = _resolve_sandbox_path(path, session_id, "save_to_path")
|
||||
if isinstance(remote, ErrorResponse):
|
||||
return remote
|
||||
try:
|
||||
await sandbox.files.write(remote, content)
|
||||
except Exception as exc:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to write to sandbox: {path} ({exc})",
|
||||
session_id=session_id,
|
||||
)
|
||||
return remote
|
||||
|
||||
validated = _validate_ephemeral_path(
|
||||
path, param_name="save_to_path", session_id=session_id
|
||||
)
|
||||
if isinstance(validated, ErrorResponse):
|
||||
return validated
|
||||
try:
|
||||
dir_path = os.path.dirname(validated)
|
||||
if dir_path:
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
with open(validated, "wb") as f:
|
||||
f.write(content)
|
||||
except Exception as exc:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to write to local path: {path} ({exc})",
|
||||
session_id=session_id,
|
||||
)
|
||||
return validated
|
||||
|
||||
|
||||
def _validate_ephemeral_path(
|
||||
path: str, *, param_name: str, session_id: str
|
||||
) -> ErrorResponse | str:
|
||||
@@ -131,7 +218,7 @@ def _is_text_mime(mime_type: str) -> bool:
|
||||
return any(mime_type.startswith(t) for t in _TEXT_MIME_PREFIXES)
|
||||
|
||||
|
||||
async def _get_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||
async def get_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||
"""Create a session-scoped WorkspaceManager."""
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
return WorkspaceManager(user_id, workspace.id, session_id)
|
||||
@@ -299,7 +386,7 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||
|
||||
try:
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
files = await manager.list_files(
|
||||
path=path_prefix, limit=limit, include_all_sessions=include_all_sessions
|
||||
)
|
||||
@@ -345,7 +432,7 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
class ReadWorkspaceFileTool(BaseTool):
|
||||
"""Tool for reading file content from workspace."""
|
||||
|
||||
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
|
||||
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB for text/image files
|
||||
PREVIEW_SIZE = 500
|
||||
|
||||
@property
|
||||
@@ -361,8 +448,10 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
"Specify either file_id or path to identify the file. "
|
||||
"For small text files, returns content directly. "
|
||||
"For large or binary files, returns metadata and a download URL. "
|
||||
"Optionally use 'save_to_path' to copy the file to the ephemeral "
|
||||
"working directory for processing with bash_exec or SDK tools. "
|
||||
"Use 'save_to_path' to copy the file to the working directory "
|
||||
"(sandbox or ephemeral) for processing with bash_exec or file tools. "
|
||||
"Use 'offset' and 'length' for paginated reads of large files "
|
||||
"(e.g., persisted tool outputs). "
|
||||
"Paths are scoped to the current session by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
)
|
||||
@@ -386,9 +475,10 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
"save_to_path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"If provided, save the file to this path in the ephemeral "
|
||||
"working directory (e.g., '/tmp/copilot-.../data.csv') "
|
||||
"so it can be processed with bash_exec or SDK tools. "
|
||||
"If provided, save the file to this path in the working "
|
||||
"directory (cloud sandbox when E2B is active, or "
|
||||
"ephemeral dir otherwise) so it can be processed with "
|
||||
"bash_exec or file tools. "
|
||||
"The file content is still returned in the response."
|
||||
),
|
||||
},
|
||||
@@ -399,6 +489,20 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
"Default is false (auto-selects based on file size/type)."
|
||||
),
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Character offset to start reading from (0-based). "
|
||||
"Use with 'length' for paginated reads of large files."
|
||||
),
|
||||
},
|
||||
"length": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Maximum number of characters to return. "
|
||||
"Defaults to full file. Use with 'offset' for paginated reads."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [], # At least one must be provided
|
||||
}
|
||||
@@ -423,23 +527,16 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
save_to_path: Optional[str] = kwargs.get("save_to_path")
|
||||
force_download_url: bool = kwargs.get("force_download_url", False)
|
||||
char_offset: int = max(0, kwargs.get("offset", 0))
|
||||
char_length: Optional[int] = kwargs.get("length")
|
||||
|
||||
if not file_id and not path:
|
||||
return ErrorResponse(
|
||||
message="Please provide either file_id or path", session_id=session_id
|
||||
)
|
||||
|
||||
# Validate and resolve save_to_path (use sanitized real path).
|
||||
if save_to_path:
|
||||
validated_save = _validate_ephemeral_path(
|
||||
save_to_path, param_name="save_to_path", session_id=session_id
|
||||
)
|
||||
if isinstance(validated_save, ErrorResponse):
|
||||
return validated_save
|
||||
save_to_path = validated_save
|
||||
|
||||
try:
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
@@ -449,11 +546,38 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
cached_content: bytes | None = None
|
||||
if save_to_path:
|
||||
cached_content = await manager.read_file_by_id(target_file_id)
|
||||
dir_path = os.path.dirname(save_to_path)
|
||||
if dir_path:
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
with open(save_to_path, "wb") as f:
|
||||
f.write(cached_content)
|
||||
result = await _save_to_path(save_to_path, cached_content, session_id)
|
||||
if isinstance(result, ErrorResponse):
|
||||
return result
|
||||
save_to_path = result
|
||||
|
||||
# Ranged read: return a character slice directly.
|
||||
if char_offset > 0 or char_length is not None:
|
||||
raw = cached_content or await manager.read_file_by_id(target_file_id)
|
||||
text = raw.decode("utf-8", errors="replace")
|
||||
total_chars = len(text)
|
||||
end = (
|
||||
char_offset + char_length
|
||||
if char_length is not None
|
||||
else total_chars
|
||||
)
|
||||
slice_text = text[char_offset:end]
|
||||
return WorkspaceFileContentResponse(
|
||||
file_id=file_info.id,
|
||||
name=file_info.name,
|
||||
path=file_info.path,
|
||||
mime_type="text/plain",
|
||||
content_base64=base64.b64encode(slice_text.encode("utf-8")).decode(
|
||||
"utf-8"
|
||||
),
|
||||
message=(
|
||||
f"Read chars {char_offset}–"
|
||||
f"{char_offset + len(slice_text)} "
|
||||
f"of {total_chars:,} total "
|
||||
f"from {file_info.name}"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
is_small = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES
|
||||
is_text = _is_text_mime(file_info.mime_type)
|
||||
@@ -629,7 +753,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
content_text: str | None = kwargs.get("content")
|
||||
content_b64: str | None = kwargs.get("content_base64")
|
||||
|
||||
resolved = _resolve_write_content(
|
||||
resolved = await _resolve_write_content(
|
||||
content_text,
|
||||
content_b64,
|
||||
source_path_arg,
|
||||
@@ -648,7 +772,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
|
||||
try:
|
||||
await scan_content_safe(content, filename=filename)
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
rec = await manager.write_file(
|
||||
content=content,
|
||||
filename=filename,
|
||||
@@ -775,7 +899,7 @@ class DeleteWorkspaceFileTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
|
||||
@@ -102,67 +102,68 @@ class TestValidateEphemeralPath:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
class TestResolveWriteContent:
|
||||
def test_no_sources_returns_error(self):
|
||||
async def test_no_sources_returns_error(self):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
result = _resolve_write_content(None, None, None, "s1")
|
||||
result = await _resolve_write_content(None, None, None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
def test_multiple_sources_returns_error(self):
|
||||
async def test_multiple_sources_returns_error(self):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
result = _resolve_write_content("text", "b64data", None, "s1")
|
||||
result = await _resolve_write_content("text", "b64data", None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
def test_plain_text_content(self):
|
||||
result = _resolve_write_content("hello world", None, None, "s1")
|
||||
async def test_plain_text_content(self):
|
||||
result = await _resolve_write_content("hello world", None, None, "s1")
|
||||
assert result == b"hello world"
|
||||
|
||||
def test_base64_content(self):
|
||||
async def test_base64_content(self):
|
||||
raw = b"binary data"
|
||||
b64 = base64.b64encode(raw).decode()
|
||||
result = _resolve_write_content(None, b64, None, "s1")
|
||||
result = await _resolve_write_content(None, b64, None, "s1")
|
||||
assert result == raw
|
||||
|
||||
def test_invalid_base64_returns_error(self):
|
||||
async def test_invalid_base64_returns_error(self):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
result = _resolve_write_content(None, "not-valid-b64!!!", None, "s1")
|
||||
result = await _resolve_write_content(None, "not-valid-b64!!!", None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "base64" in result.message.lower()
|
||||
|
||||
def test_source_path(self, ephemeral_dir):
|
||||
async def test_source_path(self, ephemeral_dir):
|
||||
target = ephemeral_dir / "input.txt"
|
||||
target.write_bytes(b"file content")
|
||||
result = _resolve_write_content(None, None, str(target), "s1")
|
||||
result = await _resolve_write_content(None, None, str(target), "s1")
|
||||
assert result == b"file content"
|
||||
|
||||
def test_source_path_not_found(self, ephemeral_dir):
|
||||
async def test_source_path_not_found(self, ephemeral_dir):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
missing = str(ephemeral_dir / "nope.txt")
|
||||
result = _resolve_write_content(None, None, missing, "s1")
|
||||
result = await _resolve_write_content(None, None, missing, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
def test_source_path_outside_ephemeral(self, ephemeral_dir, tmp_path):
|
||||
async def test_source_path_outside_ephemeral(self, ephemeral_dir, tmp_path):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
outside = tmp_path / "outside.txt"
|
||||
outside.write_text("nope")
|
||||
result = _resolve_write_content(None, None, str(outside), "s1")
|
||||
result = await _resolve_write_content(None, None, str(outside), "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
def test_empty_string_sources_treated_as_none(self):
|
||||
async def test_empty_string_sources_treated_as_none(self):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
# All empty strings → same as no sources
|
||||
result = _resolve_write_content("", "", "", "s1")
|
||||
result = await _resolve_write_content("", "", "", "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
def test_empty_string_source_path_with_text(self):
|
||||
async def test_empty_string_source_path_with_text(self):
|
||||
# source_path="" should be normalised to None, so only content counts
|
||||
result = _resolve_write_content("hello", "", "", "s1")
|
||||
result = await _resolve_write_content("hello", "", "", "s1")
|
||||
assert result == b"hello"
|
||||
|
||||
|
||||
@@ -235,6 +236,65 @@ async def test_workspace_file_round_trip(setup_test_data):
|
||||
assert not any(f.file_id == file_id for f in list_resp2.files)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ranged reads (offset / length)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_read_workspace_file_with_offset_and_length(setup_test_data):
|
||||
"""Read a slice of a text file using offset and length."""
|
||||
user = setup_test_data["user"]
|
||||
session = make_session(user.id)
|
||||
|
||||
# Write a known-content file
|
||||
content = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" * 100 # 2600 chars
|
||||
write_tool = WriteWorkspaceFileTool()
|
||||
write_resp = await write_tool._execute(
|
||||
user_id=user.id,
|
||||
session=session,
|
||||
filename="ranged_test.txt",
|
||||
content=content,
|
||||
)
|
||||
assert isinstance(write_resp, WorkspaceWriteResponse), write_resp.message
|
||||
file_id = write_resp.file_id
|
||||
|
||||
from backend.copilot.tools.workspace_files import WorkspaceFileContentResponse
|
||||
|
||||
read_tool = ReadWorkspaceFileTool()
|
||||
|
||||
# Read with offset=100, length=50
|
||||
resp = await read_tool._execute(
|
||||
user_id=user.id, session=session, file_id=file_id, offset=100, length=50
|
||||
)
|
||||
assert isinstance(resp, WorkspaceFileContentResponse), resp.message
|
||||
decoded = base64.b64decode(resp.content_base64).decode()
|
||||
assert decoded == content[100:150]
|
||||
assert "100" in resp.message
|
||||
assert "2,600" in resp.message # total chars (comma-formatted)
|
||||
|
||||
# Read with offset only (no length) — returns from offset to end
|
||||
resp2 = await read_tool._execute(
|
||||
user_id=user.id, session=session, file_id=file_id, offset=2500
|
||||
)
|
||||
assert isinstance(resp2, WorkspaceFileContentResponse)
|
||||
decoded2 = base64.b64decode(resp2.content_base64).decode()
|
||||
assert decoded2 == content[2500:]
|
||||
assert len(decoded2) == 100
|
||||
|
||||
# Read with offset beyond file length — returns empty string
|
||||
resp3 = await read_tool._execute(
|
||||
user_id=user.id, session=session, file_id=file_id, offset=9999, length=10
|
||||
)
|
||||
assert isinstance(resp3, WorkspaceFileContentResponse)
|
||||
decoded3 = base64.b64decode(resp3.content_base64).decode()
|
||||
assert decoded3 == ""
|
||||
|
||||
# Cleanup
|
||||
delete_tool = DeleteWorkspaceFileTool()
|
||||
await delete_tool._execute(user_id=user.id, session=session, file_id=file_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_write_workspace_file_source_path(setup_test_data):
|
||||
"""E2E: write a file from ephemeral source_path to workspace."""
|
||||
|
||||
@@ -81,6 +81,7 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.CLAUDE_4_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_SONNET: 5,
|
||||
LlmModel.CLAUDE_4_6_OPUS: 14,
|
||||
LlmModel.CLAUDE_4_6_SONNET: 9,
|
||||
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
||||
LlmModel.CLAUDE_4_5_OPUS: 14,
|
||||
LlmModel.CLAUDE_4_5_SONNET: 9,
|
||||
|
||||
@@ -178,9 +178,13 @@ async def test_block_credit_reset(server: SpinTestServer):
|
||||
assert month2_balance == 1100 # Balance persists, no reset
|
||||
|
||||
# Now test the refill behavior when balance is low
|
||||
# Set balance below refill threshold
|
||||
# Set balance below refill threshold and backdate updatedAt to month2 so
|
||||
# the month3 refill check sees a different (month2 → month3) transition.
|
||||
# Without the explicit updatedAt, Prisma sets it to real-world NOW which
|
||||
# may share the same calendar month as the mocked month3, suppressing refill.
|
||||
await UserBalance.prisma().update(
|
||||
where={"userId": DEFAULT_USER_ID}, data={"balance": 400}
|
||||
where={"userId": DEFAULT_USER_ID},
|
||||
data={"balance": 400, "updatedAt": month2},
|
||||
)
|
||||
|
||||
# Create a month 2 transaction to update the last transaction time
|
||||
|
||||
@@ -4,11 +4,20 @@ from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cas
|
||||
|
||||
from backend.api.features.library.db import (
|
||||
add_store_agent_to_library,
|
||||
bulk_move_agents_to_folder,
|
||||
create_folder,
|
||||
create_graph_in_library,
|
||||
create_library_agent,
|
||||
delete_folder,
|
||||
get_folder_agents_map,
|
||||
get_folder_tree,
|
||||
get_library_agent,
|
||||
get_library_agent_by_graph_id,
|
||||
get_root_agent_summaries,
|
||||
list_folders,
|
||||
list_library_agents,
|
||||
move_folder,
|
||||
update_folder,
|
||||
update_graph_in_library,
|
||||
)
|
||||
from backend.api.features.store.db import (
|
||||
@@ -260,6 +269,16 @@ class DatabaseManager(AppService):
|
||||
update_graph_in_library = _(update_graph_in_library)
|
||||
validate_graph_execution_permissions = _(validate_graph_execution_permissions)
|
||||
|
||||
create_folder = _(create_folder)
|
||||
list_folders = _(list_folders)
|
||||
get_folder_tree = _(get_folder_tree)
|
||||
update_folder = _(update_folder)
|
||||
move_folder = _(move_folder)
|
||||
delete_folder = _(delete_folder)
|
||||
bulk_move_agents_to_folder = _(bulk_move_agents_to_folder)
|
||||
get_folder_agents_map = _(get_folder_agents_map)
|
||||
get_root_agent_summaries = _(get_root_agent_summaries)
|
||||
|
||||
# ============ Onboarding ============ #
|
||||
increment_onboarding_runs = _(increment_onboarding_runs)
|
||||
|
||||
@@ -305,6 +324,7 @@ class DatabaseManager(AppService):
|
||||
delete_chat_session = _(chat_db.delete_chat_session)
|
||||
get_next_sequence = _(chat_db.get_next_sequence)
|
||||
update_tool_message_content = _(chat_db.update_tool_message_content)
|
||||
update_chat_session_title = _(chat_db.update_chat_session_title)
|
||||
|
||||
|
||||
class DatabaseManagerClient(AppServiceClient):
|
||||
@@ -433,6 +453,17 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
update_graph_in_library = d.update_graph_in_library
|
||||
validate_graph_execution_permissions = d.validate_graph_execution_permissions
|
||||
|
||||
# ============ Library Folders ============ #
|
||||
create_folder = d.create_folder
|
||||
list_folders = d.list_folders
|
||||
get_folder_tree = d.get_folder_tree
|
||||
update_folder = d.update_folder
|
||||
move_folder = d.move_folder
|
||||
delete_folder = d.delete_folder
|
||||
bulk_move_agents_to_folder = d.bulk_move_agents_to_folder
|
||||
get_folder_agents_map = d.get_folder_agents_map
|
||||
get_root_agent_summaries = d.get_root_agent_summaries
|
||||
|
||||
# ============ Onboarding ============ #
|
||||
increment_onboarding_runs = d.increment_onboarding_runs
|
||||
|
||||
@@ -475,3 +506,4 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
delete_chat_session = d.delete_chat_session
|
||||
get_next_sequence = d.get_next_sequence
|
||||
update_tool_message_content = d.update_tool_message_content
|
||||
update_chat_session_title = d.update_chat_session_title
|
||||
|
||||
@@ -344,7 +344,7 @@ class GraphExecution(GraphExecutionMeta):
|
||||
),
|
||||
**{
|
||||
# input from webhook-triggered block
|
||||
"payload": exec.input_data["payload"]
|
||||
"payload": exec.input_data.get("payload")
|
||||
for exec in complete_node_executions
|
||||
if (
|
||||
(block := get_block(exec.block_id))
|
||||
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import fastapi.exceptions
|
||||
import prisma
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
@@ -250,8 +251,8 @@ async def test_clean_graph(server: SpinTestServer):
|
||||
"_test_id": "node_with_secrets",
|
||||
"input": "normal_value",
|
||||
"control_test_input": "should be preserved",
|
||||
"api_key": "secret_api_key_123", # Should be filtered
|
||||
"password": "secret_password_456", # Should be filtered
|
||||
"api_key": "secret_api_key_123", # Should be filtered # pragma: allowlist secret # noqa
|
||||
"password": "secret_password_456", # Should be filtered # pragma: allowlist secret # noqa
|
||||
"token": "secret_token_789", # Should be filtered
|
||||
"credentials": { # Should be filtered
|
||||
"id": "fake-github-credentials-id",
|
||||
@@ -354,9 +355,24 @@ async def test_access_store_listing_graph(server: SpinTestServer):
|
||||
create_graph, DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
# Ensure the default user has a Profile (required for store submissions)
|
||||
existing_profile = await prisma.models.Profile.prisma().find_first(
|
||||
where={"userId": DEFAULT_USER_ID}
|
||||
)
|
||||
if not existing_profile:
|
||||
await prisma.models.Profile.prisma().create(
|
||||
data=prisma.types.ProfileCreateInput(
|
||||
userId=DEFAULT_USER_ID,
|
||||
name="Default User",
|
||||
username=f"default-user-{DEFAULT_USER_ID[:8]}",
|
||||
description="Default test user profile",
|
||||
links=[],
|
||||
)
|
||||
)
|
||||
|
||||
store_submission_request = store.StoreSubmissionRequest(
|
||||
agent_id=created_graph.id,
|
||||
agent_version=created_graph.version,
|
||||
graph_id=created_graph.id,
|
||||
graph_version=created_graph.version,
|
||||
slug=created_graph.id,
|
||||
name="Test name",
|
||||
sub_heading="Test sub heading",
|
||||
@@ -385,8 +401,8 @@ async def test_access_store_listing_graph(server: SpinTestServer):
|
||||
assert False, "Failed to create store listing"
|
||||
|
||||
slv_id = (
|
||||
store_listing.store_listing_version_id
|
||||
if store_listing.store_listing_version_id is not None
|
||||
store_listing.listing_version_id
|
||||
if store_listing.listing_version_id is not None
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
@@ -184,17 +184,17 @@ async def find_webhook_by_credentials_and_props(
|
||||
credentials_id: str,
|
||||
webhook_type: str,
|
||||
resource: str,
|
||||
events: Optional[list[str]],
|
||||
events: list[str] | None = None,
|
||||
) -> Webhook | None:
|
||||
webhook = await IntegrationWebhook.prisma().find_first(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"credentialsId": credentials_id,
|
||||
"webhookType": webhook_type,
|
||||
"resource": resource,
|
||||
**({"events": {"has_every": events}} if events else {}),
|
||||
},
|
||||
)
|
||||
where: IntegrationWebhookWhereInput = {
|
||||
"userId": user_id,
|
||||
"credentialsId": credentials_id,
|
||||
"webhookType": webhook_type,
|
||||
"resource": resource,
|
||||
}
|
||||
if events is not None:
|
||||
where["events"] = {"has_every": events}
|
||||
webhook = await IntegrationWebhook.prisma().find_first(where=where)
|
||||
return Webhook.from_db(webhook) if webhook else None
|
||||
|
||||
|
||||
|
||||
601
autogpt_platform/backend/backend/data/invited_user.py
Normal file
601
autogpt_platform/backend/backend/data/invited_user.py
Normal file
@@ -0,0 +1,601 @@
|
||||
import asyncio
|
||||
import csv
|
||||
import io
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
from pydantic import BaseModel, EmailStr, TypeAdapter, ValidationError
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.data.model import User
|
||||
from backend.data.tally import get_business_understanding_input_from_tally
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
merge_business_understanding_data,
|
||||
)
|
||||
from backend.util.exceptions import (
|
||||
NotAuthorizedError,
|
||||
NotFoundError,
|
||||
PreconditionFailed,
|
||||
)
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_tally_seed_tasks: set[asyncio.Task] = set()
|
||||
_email_adapter = TypeAdapter(EmailStr)
|
||||
|
||||
MAX_BULK_INVITE_FILE_BYTES = 1024 * 1024
|
||||
MAX_BULK_INVITE_ROWS = 500
|
||||
|
||||
|
||||
class InvitedUserRecord(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
status: prisma.enums.InvitedUserStatus
|
||||
auth_user_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tally_understanding: Optional[dict[str, Any]] = None
|
||||
tally_status: prisma.enums.TallyComputationStatus
|
||||
tally_computed_at: Optional[datetime] = None
|
||||
tally_error: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, invited_user: "prisma.models.InvitedUser") -> "InvitedUserRecord":
|
||||
payload = (
|
||||
invited_user.tallyUnderstanding
|
||||
if isinstance(invited_user.tallyUnderstanding, dict)
|
||||
else None
|
||||
)
|
||||
return cls(
|
||||
id=invited_user.id,
|
||||
email=invited_user.email,
|
||||
status=invited_user.status,
|
||||
auth_user_id=invited_user.authUserId,
|
||||
name=invited_user.name,
|
||||
tally_understanding=payload,
|
||||
tally_status=invited_user.tallyStatus,
|
||||
tally_computed_at=invited_user.tallyComputedAt,
|
||||
tally_error=invited_user.tallyError,
|
||||
created_at=invited_user.createdAt,
|
||||
updated_at=invited_user.updatedAt,
|
||||
)
|
||||
|
||||
|
||||
class BulkInvitedUserRowResult(BaseModel):
|
||||
row_number: int
|
||||
email: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
status: Literal["CREATED", "SKIPPED", "ERROR"]
|
||||
message: str
|
||||
invited_user: Optional[InvitedUserRecord] = None
|
||||
|
||||
|
||||
class BulkInvitedUsersResult(BaseModel):
|
||||
created_count: int
|
||||
skipped_count: int
|
||||
error_count: int
|
||||
results: list[BulkInvitedUserRowResult]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ParsedInviteRow:
|
||||
row_number: int
|
||||
email: str
|
||||
name: Optional[str]
|
||||
|
||||
|
||||
def normalize_email(email: str) -> str:
|
||||
return email.strip().lower()
|
||||
|
||||
|
||||
def _normalize_name(name: Optional[str]) -> Optional[str]:
|
||||
if name is None:
|
||||
return None
|
||||
normalized = name.strip()
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _default_profile_name(email: str, preferred_name: Optional[str]) -> str:
|
||||
if preferred_name:
|
||||
return preferred_name
|
||||
local_part = email.split("@", 1)[0].strip()
|
||||
return local_part or "user"
|
||||
|
||||
|
||||
def _sanitize_username_base(email: str) -> str:
|
||||
local_part = email.split("@", 1)[0].lower()
|
||||
sanitized = re.sub(r"[^a-z0-9-]", "", local_part)
|
||||
sanitized = sanitized.strip("-")
|
||||
return sanitized[:40] or "user"
|
||||
|
||||
|
||||
async def _generate_unique_profile_username(email: str, tx) -> str:
|
||||
base = _sanitize_username_base(email)
|
||||
|
||||
for attempt in range(10):
|
||||
candidate = base if attempt == 0 else f"{base}-{uuid4().hex[:6]}"
|
||||
existing = await prisma.models.Profile.prisma(tx).find_unique(
|
||||
where={"username": candidate}
|
||||
)
|
||||
if existing is None:
|
||||
return candidate
|
||||
|
||||
raise RuntimeError(f"Unable to generate unique username for {email}")
|
||||
|
||||
|
||||
async def _ensure_default_profile(
|
||||
user_id: str,
|
||||
email: str,
|
||||
preferred_name: Optional[str],
|
||||
tx,
|
||||
) -> None:
|
||||
existing_profile = await prisma.models.Profile.prisma(tx).find_unique(
|
||||
where={"userId": user_id}
|
||||
)
|
||||
if existing_profile is not None:
|
||||
return
|
||||
|
||||
username = await _generate_unique_profile_username(email, tx)
|
||||
await prisma.models.Profile.prisma(tx).create(
|
||||
data=prisma.types.ProfileCreateInput(
|
||||
userId=user_id,
|
||||
name=_default_profile_name(email, preferred_name),
|
||||
username=username,
|
||||
description="I'm new here",
|
||||
links=[],
|
||||
avatarUrl="",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _ensure_default_onboarding(user_id: str, tx) -> None:
|
||||
await prisma.models.UserOnboarding.prisma(tx).upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": prisma.types.UserOnboardingCreateInput(userId=user_id),
|
||||
"update": {},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _apply_tally_understanding(
|
||||
user_id: str,
|
||||
invited_user: "prisma.models.InvitedUser",
|
||||
tx,
|
||||
) -> None:
|
||||
if not isinstance(invited_user.tallyUnderstanding, dict):
|
||||
return
|
||||
|
||||
input_data = BusinessUnderstandingInput.model_validate(
|
||||
invited_user.tallyUnderstanding
|
||||
)
|
||||
payload = merge_business_understanding_data({}, input_data)
|
||||
await prisma.models.CoPilotUnderstanding.prisma(tx).upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "data": SafeJson(payload)},
|
||||
"update": {"data": SafeJson(payload)},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def list_invited_users() -> list[InvitedUserRecord]:
|
||||
invited_users = await prisma.models.InvitedUser.prisma().find_many(
|
||||
order={"createdAt": "desc"}
|
||||
)
|
||||
return [InvitedUserRecord.from_db(invited_user) for invited_user in invited_users]
|
||||
|
||||
|
||||
async def create_invited_user(
|
||||
email: str, name: Optional[str] = None
|
||||
) -> InvitedUserRecord:
|
||||
normalized_email = normalize_email(email)
|
||||
normalized_name = _normalize_name(name)
|
||||
|
||||
existing_user = await prisma.models.User.prisma().find_unique(
|
||||
where={"email": normalized_email}
|
||||
)
|
||||
if existing_user is not None:
|
||||
raise PreconditionFailed("An active user with this email already exists")
|
||||
|
||||
existing_invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"email": normalized_email}
|
||||
)
|
||||
if existing_invited_user is not None:
|
||||
raise PreconditionFailed("An invited user with this email already exists")
|
||||
|
||||
invited_user = await prisma.models.InvitedUser.prisma().create(
|
||||
data={
|
||||
"email": normalized_email,
|
||||
"name": normalized_name,
|
||||
"status": prisma.enums.InvitedUserStatus.INVITED,
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
|
||||
}
|
||||
)
|
||||
schedule_invited_user_tally_precompute(invited_user.id)
|
||||
return InvitedUserRecord.from_db(invited_user)
|
||||
|
||||
|
||||
async def revoke_invited_user(invited_user_id: str) -> InvitedUserRecord:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"id": invited_user_id}
|
||||
)
|
||||
if invited_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.CLAIMED:
|
||||
raise PreconditionFailed("Claimed invited users cannot be revoked")
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
||||
return InvitedUserRecord.from_db(invited_user)
|
||||
|
||||
revoked_user = await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={"status": prisma.enums.InvitedUserStatus.REVOKED},
|
||||
)
|
||||
if revoked_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
return InvitedUserRecord.from_db(revoked_user)
|
||||
|
||||
|
||||
async def retry_invited_user_tally(invited_user_id: str) -> InvitedUserRecord:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"id": invited_user_id}
|
||||
)
|
||||
if invited_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
||||
raise PreconditionFailed("Revoked invited users cannot retry Tally seeding")
|
||||
|
||||
refreshed_user = await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyUnderstanding": None,
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.PENDING,
|
||||
"tallyComputedAt": None,
|
||||
"tallyError": None,
|
||||
},
|
||||
)
|
||||
if refreshed_user is None:
|
||||
raise NotFoundError(f"Invited user {invited_user_id} not found")
|
||||
schedule_invited_user_tally_precompute(invited_user_id)
|
||||
return InvitedUserRecord.from_db(refreshed_user)
|
||||
|
||||
|
||||
def _decode_bulk_invite_file(content: bytes) -> str:
|
||||
if len(content) > MAX_BULK_INVITE_FILE_BYTES:
|
||||
raise ValueError("Invite file exceeds the maximum size of 1 MB")
|
||||
|
||||
try:
|
||||
return content.decode("utf-8-sig")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise ValueError("Invite file must be UTF-8 encoded") from exc
|
||||
|
||||
|
||||
def _parse_bulk_invite_csv(text: str) -> list[_ParsedInviteRow]:
|
||||
indexed_rows: list[tuple[int, list[str]]] = []
|
||||
|
||||
for row_number, row in enumerate(csv.reader(io.StringIO(text)), start=1):
|
||||
normalized_row = [cell.strip() for cell in row]
|
||||
if any(normalized_row):
|
||||
indexed_rows.append((row_number, normalized_row))
|
||||
|
||||
if not indexed_rows:
|
||||
return []
|
||||
|
||||
header = [cell.lower() for cell in indexed_rows[0][1]]
|
||||
has_header = "email" in header
|
||||
email_index = header.index("email") if has_header else 0
|
||||
name_index = header.index("name") if has_header and "name" in header else 1
|
||||
data_rows = indexed_rows[1:] if has_header else indexed_rows
|
||||
|
||||
parsed_rows: list[_ParsedInviteRow] = []
|
||||
for row_number, row in data_rows:
|
||||
email = row[email_index].strip() if len(row) > email_index else ""
|
||||
name = row[name_index].strip() if len(row) > name_index else ""
|
||||
parsed_rows.append(
|
||||
_ParsedInviteRow(
|
||||
row_number=row_number,
|
||||
email=email,
|
||||
name=name or None,
|
||||
)
|
||||
)
|
||||
|
||||
return parsed_rows
|
||||
|
||||
|
||||
def _parse_bulk_invite_text(text: str) -> list[_ParsedInviteRow]:
|
||||
parsed_rows: list[_ParsedInviteRow] = []
|
||||
|
||||
for row_number, raw_line in enumerate(text.splitlines(), start=1):
|
||||
line = raw_line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
parsed_rows.append(
|
||||
_ParsedInviteRow(
|
||||
row_number=row_number,
|
||||
email=line,
|
||||
name=None,
|
||||
)
|
||||
)
|
||||
|
||||
return parsed_rows
|
||||
|
||||
|
||||
def _parse_bulk_invite_file(
|
||||
filename: Optional[str],
|
||||
content: bytes,
|
||||
) -> list[_ParsedInviteRow]:
|
||||
text = _decode_bulk_invite_file(content)
|
||||
file_name = filename.lower() if filename else ""
|
||||
parsed_rows = (
|
||||
_parse_bulk_invite_csv(text)
|
||||
if file_name.endswith(".csv")
|
||||
else _parse_bulk_invite_text(text)
|
||||
)
|
||||
|
||||
if not parsed_rows:
|
||||
raise ValueError("Invite file did not contain any emails")
|
||||
|
||||
if len(parsed_rows) > MAX_BULK_INVITE_ROWS:
|
||||
raise ValueError(
|
||||
f"Invite file contains too many rows. Maximum supported rows: {MAX_BULK_INVITE_ROWS}"
|
||||
)
|
||||
|
||||
return parsed_rows
|
||||
|
||||
|
||||
async def bulk_create_invited_users_from_file(
|
||||
filename: Optional[str],
|
||||
content: bytes,
|
||||
) -> BulkInvitedUsersResult:
|
||||
parsed_rows = _parse_bulk_invite_file(filename, content)
|
||||
|
||||
created_count = 0
|
||||
skipped_count = 0
|
||||
error_count = 0
|
||||
results: list[BulkInvitedUserRowResult] = []
|
||||
seen_emails: set[str] = set()
|
||||
|
||||
for row in parsed_rows:
|
||||
row_name = _normalize_name(row.name)
|
||||
|
||||
try:
|
||||
validated_email = _email_adapter.validate_python(row.email)
|
||||
except ValidationError:
|
||||
error_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=row.email or None,
|
||||
name=row_name,
|
||||
status="ERROR",
|
||||
message="Invalid email address",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
normalized_email = normalize_email(str(validated_email))
|
||||
if normalized_email in seen_emails:
|
||||
skipped_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="SKIPPED",
|
||||
message="Duplicate email in upload file",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
seen_emails.add(normalized_email)
|
||||
|
||||
try:
|
||||
invited_user = await create_invited_user(normalized_email, row_name)
|
||||
except PreconditionFailed as exc:
|
||||
skipped_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="SKIPPED",
|
||||
message=str(exc),
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to create bulk invite for %s from row %s",
|
||||
normalized_email,
|
||||
row.row_number,
|
||||
)
|
||||
error_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="ERROR",
|
||||
message="Unexpected error creating invite",
|
||||
)
|
||||
)
|
||||
else:
|
||||
created_count += 1
|
||||
results.append(
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=row.row_number,
|
||||
email=normalized_email,
|
||||
name=row_name,
|
||||
status="CREATED",
|
||||
message="Invite created",
|
||||
invited_user=invited_user,
|
||||
)
|
||||
)
|
||||
|
||||
return BulkInvitedUsersResult(
|
||||
created_count=created_count,
|
||||
skipped_count=skipped_count,
|
||||
error_count=error_count,
|
||||
results=results,
|
||||
)
|
||||
|
||||
|
||||
async def _compute_invited_user_tally_seed(invited_user_id: str) -> None:
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"id": invited_user_id}
|
||||
)
|
||||
if invited_user is None:
|
||||
return
|
||||
|
||||
if invited_user.status == prisma.enums.InvitedUserStatus.REVOKED:
|
||||
return
|
||||
|
||||
await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.RUNNING,
|
||||
"tallyError": None,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
input_data = await get_business_understanding_input_from_tally(
|
||||
invited_user.email,
|
||||
require_api_key=True,
|
||||
)
|
||||
payload = (
|
||||
SafeJson(input_data.model_dump(exclude_none=True))
|
||||
if input_data is not None
|
||||
else None
|
||||
)
|
||||
await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyUnderstanding": payload,
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.READY,
|
||||
"tallyComputedAt": datetime.now(timezone.utc),
|
||||
"tallyError": None,
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"Failed to compute Tally understanding for invited user %s",
|
||||
invited_user_id,
|
||||
)
|
||||
await prisma.models.InvitedUser.prisma().update(
|
||||
where={"id": invited_user_id},
|
||||
data={
|
||||
"tallyStatus": prisma.enums.TallyComputationStatus.FAILED,
|
||||
"tallyError": str(exc),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def schedule_invited_user_tally_precompute(invited_user_id: str) -> None:
|
||||
task = asyncio.create_task(_compute_invited_user_tally_seed(invited_user_id))
|
||||
_tally_seed_tasks.add(task)
|
||||
task.add_done_callback(_tally_seed_tasks.discard)
|
||||
|
||||
|
||||
async def get_or_activate_user(user_data: dict) -> User:
|
||||
auth_user_id = user_data.get("sub")
|
||||
if not auth_user_id:
|
||||
raise NotAuthorizedError("User ID not found in token")
|
||||
|
||||
auth_email = user_data.get("email")
|
||||
if not auth_email:
|
||||
raise NotAuthorizedError("Email not found in token")
|
||||
|
||||
normalized_email = normalize_email(auth_email)
|
||||
user_metadata = user_data.get("user_metadata")
|
||||
metadata_name = (
|
||||
user_metadata.get("name") if isinstance(user_metadata, dict) else None
|
||||
)
|
||||
|
||||
existing_user = await prisma.models.User.prisma().find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if existing_user is not None:
|
||||
return User.from_db(existing_user)
|
||||
|
||||
invited_user = await prisma.models.InvitedUser.prisma().find_unique(
|
||||
where={"email": normalized_email}
|
||||
)
|
||||
if invited_user is None:
|
||||
raise NotAuthorizedError("Your email is not allowed to access the platform")
|
||||
|
||||
if invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
|
||||
raise NotAuthorizedError("Your invitation is no longer active")
|
||||
|
||||
async with transaction() as tx:
|
||||
current_user = await prisma.models.User.prisma(tx).find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if current_user is not None:
|
||||
return User.from_db(current_user)
|
||||
|
||||
current_invited_user = await prisma.models.InvitedUser.prisma(tx).find_unique(
|
||||
where={"email": normalized_email}
|
||||
)
|
||||
if current_invited_user is None:
|
||||
raise NotAuthorizedError("Your email is not allowed to access the platform")
|
||||
|
||||
if current_invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
|
||||
raise NotAuthorizedError("Your invitation is no longer active")
|
||||
|
||||
if current_invited_user.authUserId not in (None, auth_user_id):
|
||||
raise NotAuthorizedError("Your invitation has already been claimed")
|
||||
|
||||
preferred_name = current_invited_user.name or _normalize_name(metadata_name)
|
||||
await prisma.models.User.prisma(tx).create(
|
||||
data=prisma.types.UserCreateInput(
|
||||
id=auth_user_id,
|
||||
email=normalized_email,
|
||||
name=preferred_name,
|
||||
)
|
||||
)
|
||||
|
||||
await prisma.models.InvitedUser.prisma(tx).update(
|
||||
where={"id": current_invited_user.id},
|
||||
data={
|
||||
"status": prisma.enums.InvitedUserStatus.CLAIMED,
|
||||
"authUserId": auth_user_id,
|
||||
},
|
||||
)
|
||||
|
||||
await _ensure_default_profile(
|
||||
auth_user_id,
|
||||
normalized_email,
|
||||
preferred_name,
|
||||
tx,
|
||||
)
|
||||
await _ensure_default_onboarding(auth_user_id, tx)
|
||||
await _apply_tally_understanding(auth_user_id, current_invited_user, tx)
|
||||
|
||||
from backend.data.user import get_user_by_email, get_user_by_id
|
||||
|
||||
get_user_by_id.cache_delete(auth_user_id)
|
||||
get_user_by_email.cache_delete(normalized_email)
|
||||
|
||||
activated_user = await prisma.models.User.prisma().find_unique(
|
||||
where={"id": auth_user_id}
|
||||
)
|
||||
if activated_user is None:
|
||||
raise RuntimeError(
|
||||
f"Activated user {auth_user_id} was not found after creation"
|
||||
)
|
||||
|
||||
return User.from_db(activated_user)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user