mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Merge branch 'dev' into ntindle/open-3018-google-drive-file-inputs-are-non-chainable-on-new-builder
This commit is contained in:
3
autogpt_platform/.gitignore
vendored
3
autogpt_platform/.gitignore
vendored
@@ -1,2 +1,3 @@
|
||||
*.ignore.*
|
||||
*.ign.*
|
||||
*.ign.*
|
||||
.application.logs
|
||||
|
||||
@@ -95,7 +95,7 @@ ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
|
||||
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
|
||||
# for the bash_exec MCP tool.
|
||||
# for the bash_exec MCP tool (fallback when E2B is not configured).
|
||||
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.13 \
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
from uuid import uuid4
|
||||
@@ -9,7 +10,8 @@ from uuid import uuid4
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from prisma.models import UserWorkspaceFile
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
@@ -47,10 +49,14 @@ from backend.copilot.tools.models import (
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -79,6 +85,9 @@ class StreamChatRequest(BaseModel):
|
||||
message: str
|
||||
is_user_message: bool = True
|
||||
context: dict[str, str] | None = None # {url: str, content: str}
|
||||
file_ids: list[str] | None = Field(
|
||||
default=None, max_length=20
|
||||
) # Workspace file IDs attached to this message
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
@@ -238,6 +247,18 @@ async def delete_session(
|
||||
detail=f"Session {session_id} not found or access denied",
|
||||
)
|
||||
|
||||
# Best-effort cleanup of the E2B sandbox (if any).
|
||||
config = ChatConfig()
|
||||
if config.use_e2b_sandbox and config.e2b_api_key:
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
|
||||
try:
|
||||
await kill_sandbox(session_id, config.e2b_api_key)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[E2B] Failed to kill sandbox for session %s", session_id[:12]
|
||||
)
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@@ -394,6 +415,38 @@ async def stream_chat_post(
|
||||
},
|
||||
)
|
||||
|
||||
# Enrich message with file metadata if file_ids are provided.
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
sanitized_file_ids: list[str] | None = None
|
||||
if request.file_ids and user_id:
|
||||
# Filter to valid UUIDs only to prevent DB abuse
|
||||
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
|
||||
|
||||
if valid_ids:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Batch query instead of N+1
|
||||
files = await UserWorkspaceFile.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": valid_ids},
|
||||
"workspaceId": workspace.id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
)
|
||||
# Only keep IDs that actually exist in the user's workspace
|
||||
sanitized_file_ids = [wf.id for wf in files] or None
|
||||
file_lines: list[str] = [
|
||||
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
|
||||
for wf in files
|
||||
]
|
||||
if file_lines:
|
||||
files_block = (
|
||||
"\n\n[Attached files]\n"
|
||||
+ "\n".join(file_lines)
|
||||
+ "\nUse read_workspace_file with the file_id to access file contents."
|
||||
)
|
||||
request.message += files_block
|
||||
|
||||
# Atomically append user message to session BEFORE creating task to avoid
|
||||
# race condition where GET_SESSION sees task as "running" but message isn't
|
||||
# saved yet. append_and_save_message re-fetches inside a lock to prevent
|
||||
@@ -445,6 +498,7 @@ async def stream_chat_post(
|
||||
turn_id=turn_id,
|
||||
is_user_message=request.is_user_message,
|
||||
context=request.context,
|
||||
file_ids=sanitized_file_ids,
|
||||
)
|
||||
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
"""Tests for chat route file_ids validation and enrichment."""
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.api.features.chat import routes as chat_routes
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(chat_routes.router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ---- file_ids Pydantic validation (B1) ----
|
||||
|
||||
|
||||
def test_stream_chat_rejects_too_many_file_ids():
|
||||
"""More than 20 file_ids should be rejected by Pydantic validation (422)."""
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={
|
||||
"message": "hello",
|
||||
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(21)],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def _mock_stream_internals(mocker: pytest_mock.MockFixture):
|
||||
"""Mock the async internals of stream_chat_post so tests can exercise
|
||||
validation and enrichment logic without needing Redis/RabbitMQ."""
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes._validate_and_get_session",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.append_and_save_message",
|
||||
return_value=None,
|
||||
)
|
||||
mock_registry = mocker.MagicMock()
|
||||
mock_registry.create_session = mocker.AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry",
|
||||
mock_registry,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.enqueue_copilot_turn",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.track_user_message",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
|
||||
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
|
||||
"""Exactly 20 file_ids should be accepted (not rejected by validation)."""
|
||||
_mock_stream_internals(mocker)
|
||||
# Patch workspace lookup as imported by the routes module
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={
|
||||
"message": "hello",
|
||||
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(20)],
|
||||
},
|
||||
)
|
||||
# Should get past validation — 200 streaming response expected
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ---- UUID format filtering ----
|
||||
|
||||
|
||||
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
"""Non-UUID strings in file_ids should be silently filtered out
|
||||
and NOT passed to the database query."""
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
valid_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={
|
||||
"message": "hello",
|
||||
"file_ids": [
|
||||
valid_id,
|
||||
"not-a-uuid",
|
||||
"../../../etc/passwd",
|
||||
"",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
# The find_many call should only receive the one valid UUID
|
||||
mock_prisma.find_many.assert_called_once()
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["id"]["in"] == [valid_id]
|
||||
|
||||
|
||||
# ---- Cross-workspace file_ids ----
|
||||
|
||||
|
||||
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
"""The batch query should scope to the user's workspace."""
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "my-workspace-id"})(),
|
||||
)
|
||||
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hi", "file_ids": [fid]},
|
||||
)
|
||||
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
|
||||
assert call_kwargs["where"]["isDeleted"] is False
|
||||
@@ -3,15 +3,29 @@ Workspace API routes for managing user file storage.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Annotated
|
||||
from urllib.parse import quote
|
||||
|
||||
import fastapi
|
||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||
from fastapi import Query, UploadFile
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.workspace import WorkspaceFile, get_workspace, get_workspace_file
|
||||
from backend.data.workspace import (
|
||||
WorkspaceFile,
|
||||
count_workspace_files,
|
||||
get_or_create_workspace,
|
||||
get_workspace,
|
||||
get_workspace_file,
|
||||
get_workspace_total_size,
|
||||
soft_delete_workspace_file,
|
||||
)
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
|
||||
@@ -98,6 +112,21 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
|
||||
raise
|
||||
|
||||
|
||||
class UploadFileResponse(BaseModel):
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
|
||||
|
||||
class StorageUsageResponse(BaseModel):
|
||||
used_bytes: int
|
||||
limit_bytes: int
|
||||
used_percent: float
|
||||
file_count: int
|
||||
|
||||
|
||||
@router.get(
|
||||
"/files/{file_id}/download",
|
||||
summary="Download file by ID",
|
||||
@@ -120,3 +149,120 @@ async def download_file(
|
||||
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
return await _create_file_download_response(file)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/files/upload",
|
||||
summary="Upload file to workspace",
|
||||
)
|
||||
async def upload_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
file: UploadFile,
|
||||
session_id: str | None = Query(default=None),
|
||||
) -> UploadFileResponse:
|
||||
"""
|
||||
Upload a file to the user's workspace.
|
||||
|
||||
Files are stored in session-scoped paths when session_id is provided,
|
||||
so the agent's session-scoped tools can discover them automatically.
|
||||
"""
|
||||
config = Config()
|
||||
|
||||
# Sanitize filename — strip any directory components
|
||||
filename = os.path.basename(file.filename or "upload") or "upload"
|
||||
|
||||
# Read file content with early abort on size limit
|
||||
max_file_bytes = config.max_file_size_mb * 1024 * 1024
|
||||
chunks: list[bytes] = []
|
||||
total_size = 0
|
||||
while chunk := await file.read(64 * 1024): # 64KB chunks
|
||||
total_size += len(chunk)
|
||||
if total_size > max_file_bytes:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File exceeds maximum size of {config.max_file_size_mb} MB",
|
||||
)
|
||||
chunks.append(chunk)
|
||||
content = b"".join(chunks)
|
||||
|
||||
# Get or create workspace
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
# Pre-write storage cap check (soft check — final enforcement is post-write)
|
||||
storage_limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
|
||||
current_usage = await get_workspace_total_size(workspace.id)
|
||||
if storage_limit_bytes and current_usage + len(content) > storage_limit_bytes:
|
||||
used_percent = (current_usage / storage_limit_bytes) * 100
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail={
|
||||
"message": "Storage limit exceeded",
|
||||
"used_bytes": current_usage,
|
||||
"limit_bytes": storage_limit_bytes,
|
||||
"used_percent": round(used_percent, 1),
|
||||
},
|
||||
)
|
||||
|
||||
# Warn at 80% usage
|
||||
if (
|
||||
storage_limit_bytes
|
||||
and (usage_ratio := (current_usage + len(content)) / storage_limit_bytes) >= 0.8
|
||||
):
|
||||
logger.warning(
|
||||
f"User {user_id} workspace storage at {usage_ratio * 100:.1f}% "
|
||||
f"({current_usage + len(content)} / {storage_limit_bytes} bytes)"
|
||||
)
|
||||
|
||||
# Virus scan
|
||||
await scan_content_safe(content, filename=filename)
|
||||
|
||||
# Write file via WorkspaceManager
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
workspace_file = await manager.write_file(content, filename)
|
||||
|
||||
# Post-write storage check — eliminates TOCTOU race on the quota.
|
||||
# If a concurrent upload pushed us over the limit, undo this write.
|
||||
new_total = await get_workspace_total_size(workspace.id)
|
||||
if storage_limit_bytes and new_total > storage_limit_bytes:
|
||||
await soft_delete_workspace_file(workspace_file.id, workspace.id)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail={
|
||||
"message": "Storage limit exceeded (concurrent upload)",
|
||||
"used_bytes": new_total,
|
||||
"limit_bytes": storage_limit_bytes,
|
||||
},
|
||||
)
|
||||
|
||||
return UploadFileResponse(
|
||||
file_id=workspace_file.id,
|
||||
name=workspace_file.name,
|
||||
path=workspace_file.path,
|
||||
mime_type=workspace_file.mime_type,
|
||||
size_bytes=workspace_file.size_bytes,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/storage/usage",
|
||||
summary="Get workspace storage usage",
|
||||
)
|
||||
async def get_storage_usage(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
) -> StorageUsageResponse:
|
||||
"""
|
||||
Get storage usage information for the user's workspace.
|
||||
"""
|
||||
config = Config()
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
used_bytes = await get_workspace_total_size(workspace.id)
|
||||
file_count = await count_workspace_files(workspace.id)
|
||||
limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
|
||||
|
||||
return StorageUsageResponse(
|
||||
used_bytes=used_bytes,
|
||||
limit_bytes=limit_bytes,
|
||||
used_percent=round((used_bytes / limit_bytes) * 100, 1) if limit_bytes else 0,
|
||||
file_count=file_count,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,307 @@
|
||||
"""Tests for workspace file upload and download routes."""
|
||||
|
||||
import io
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.api.features.workspace import routes as workspace_routes
|
||||
from backend.data.workspace import WorkspaceFile
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(workspace_routes.router)
|
||||
|
||||
|
||||
@app.exception_handler(ValueError)
|
||||
async def _value_error_handler(
|
||||
request: fastapi.Request, exc: ValueError
|
||||
) -> fastapi.responses.JSONResponse:
|
||||
"""Mirror the production ValueError → 400 mapping from rest_api.py."""
|
||||
return fastapi.responses.JSONResponse(status_code=400, content={"detail": str(exc)})
|
||||
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})()
|
||||
|
||||
_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
MOCK_FILE = WorkspaceFile(
|
||||
id="file-aaa-bbb",
|
||||
workspace_id="ws-1",
|
||||
created_at=_NOW,
|
||||
updated_at=_NOW,
|
||||
name="hello.txt",
|
||||
path="/session/hello.txt",
|
||||
mime_type="text/plain",
|
||||
size_bytes=13,
|
||||
storage_path="local://hello.txt",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _upload(
|
||||
filename: str = "hello.txt",
|
||||
content: bytes = b"Hello, world!",
|
||||
content_type: str = "text/plain",
|
||||
):
|
||||
"""Helper to POST a file upload."""
|
||||
return client.post(
|
||||
"/files/upload?session_id=sess-1",
|
||||
files={"file": (filename, io.BytesIO(content), content_type)},
|
||||
)
|
||||
|
||||
|
||||
# ---- Happy path ----
|
||||
|
||||
|
||||
def test_upload_happy_path(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["file_id"] == "file-aaa-bbb"
|
||||
assert data["name"] == "hello.txt"
|
||||
assert data["size_bytes"] == 13
|
||||
|
||||
|
||||
# ---- Per-file size limit ----
|
||||
|
||||
|
||||
def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
|
||||
"""Files larger than max_file_size_mb should be rejected with 413."""
|
||||
cfg = mocker.patch("backend.api.features.workspace.routes.Config")
|
||||
cfg.return_value.max_file_size_mb = 0 # 0 MB → any content is too big
|
||||
cfg.return_value.max_workspace_storage_mb = 500
|
||||
|
||||
response = _upload(content=b"x" * 1024)
|
||||
assert response.status_code == 413
|
||||
|
||||
|
||||
# ---- Storage quota exceeded ----
|
||||
|
||||
|
||||
def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
# Current usage already at limit
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=500 * 1024 * 1024,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 413
|
||||
assert "Storage limit exceeded" in response.text
|
||||
|
||||
|
||||
# ---- Post-write quota race (B2) ----
|
||||
|
||||
|
||||
def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
|
||||
"""If a concurrent upload tips the total over the limit after write,
|
||||
the file should be soft-deleted and 413 returned."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
# Pre-write check passes (under limit), but post-write check fails
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
side_effect=[0, 600 * 1024 * 1024], # first call OK, second over limit
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
mock_delete = mocker.patch(
|
||||
"backend.api.features.workspace.routes.soft_delete_workspace_file",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 413
|
||||
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-1")
|
||||
|
||||
|
||||
# ---- Any extension accepted (no allowlist) ----
|
||||
|
||||
|
||||
def test_upload_any_extension(mocker: pytest_mock.MockFixture):
|
||||
"""Any file extension should be accepted — ClamAV is the security layer."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload(filename="data.xyz", content=b"arbitrary")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ---- Virus scan rejection ----
|
||||
|
||||
|
||||
def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
|
||||
"""Files flagged by ClamAV should be rejected and never written to storage."""
|
||||
from backend.api.features.store.exceptions import VirusDetectedError
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
side_effect=VirusDetectedError("Eicar-Test-Signature"),
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
|
||||
assert response.status_code == 400
|
||||
assert "Virus detected" in response.text
|
||||
mock_manager.write_file.assert_not_called()
|
||||
|
||||
|
||||
# ---- No file extension ----
|
||||
|
||||
|
||||
def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
|
||||
"""Files without an extension should be accepted and stored as-is."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload(
|
||||
filename="Makefile",
|
||||
content=b"all:\n\techo hello",
|
||||
content_type="application/octet-stream",
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_manager.write_file.assert_called_once()
|
||||
assert mock_manager.write_file.call_args[0][1] == "Makefile"
|
||||
|
||||
|
||||
# ---- Filename sanitization (SF5) ----
|
||||
|
||||
|
||||
def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
|
||||
"""Path-traversal filenames should be reduced to their basename."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
# Filename with traversal
|
||||
_upload(filename="../../etc/passwd.txt")
|
||||
|
||||
# write_file should have been called with just the basename
|
||||
mock_manager.write_file.assert_called_once()
|
||||
call_args = mock_manager.write_file.call_args
|
||||
assert call_args[0][1] == "passwd.txt"
|
||||
|
||||
|
||||
# ---- Download ----
|
||||
|
||||
|
||||
def test_download_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_file",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.get("/files/some-file-id/download")
|
||||
assert response.status_code == 404
|
||||
@@ -93,12 +93,49 @@ class ChatConfig(BaseSettings):
|
||||
"history compression. Falls back to compression when unavailable.",
|
||||
)
|
||||
|
||||
# E2B Sandbox Configuration
|
||||
use_e2b_sandbox: bool = Field(
|
||||
default=True,
|
||||
description="Use E2B cloud sandboxes for persistent bash/python execution. "
|
||||
"When enabled, bash_exec routes commands to E2B and SDK file tools "
|
||||
"operate directly on the sandbox via E2B's filesystem API.",
|
||||
)
|
||||
e2b_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="E2B API key. Falls back to E2B_API_KEY environment variable.",
|
||||
)
|
||||
e2b_sandbox_template: str = Field(
|
||||
default="base",
|
||||
description="E2B sandbox template to use for copilot sessions.",
|
||||
)
|
||||
e2b_sandbox_timeout: int = Field(
|
||||
default=43200, # 12 hours — same as session_ttl
|
||||
description="E2B sandbox keepalive timeout in seconds.",
|
||||
)
|
||||
|
||||
# Extended thinking configuration for Claude models
|
||||
thinking_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable adaptive thinking for Claude models via OpenRouter",
|
||||
)
|
||||
|
||||
@field_validator("use_e2b_sandbox", mode="before")
|
||||
@classmethod
|
||||
def get_use_e2b_sandbox(cls, v):
|
||||
"""Get use_e2b_sandbox from environment if not provided."""
|
||||
env_val = os.getenv("CHAT_USE_E2B_SANDBOX", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
return True if v is None else v
|
||||
|
||||
@field_validator("e2b_api_key", mode="before")
|
||||
@classmethod
|
||||
def get_e2b_api_key(cls, v):
|
||||
"""Get E2B API key from environment if not provided."""
|
||||
if v is None:
|
||||
v = os.getenv("CHAT_E2B_API_KEY") or os.getenv("E2B_API_KEY")
|
||||
return v
|
||||
|
||||
@field_validator("api_key", mode="before")
|
||||
@classmethod
|
||||
def get_api_key(cls, v):
|
||||
|
||||
@@ -16,7 +16,7 @@ from prisma.types import (
|
||||
)
|
||||
|
||||
from backend.data import db
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.json import SafeJson, sanitize_string
|
||||
|
||||
from .model import ChatMessage, ChatSession, ChatSessionInfo
|
||||
|
||||
@@ -101,15 +101,16 @@ async def add_chat_message(
|
||||
"sequence": sequence,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
# Add optional string fields — sanitize to strip PostgreSQL-incompatible
|
||||
# control characters (null bytes etc.) that may appear in tool outputs.
|
||||
if content is not None:
|
||||
data["content"] = content
|
||||
data["content"] = sanitize_string(content)
|
||||
if name is not None:
|
||||
data["name"] = name
|
||||
if tool_call_id is not None:
|
||||
data["toolCallId"] = tool_call_id
|
||||
if refusal is not None:
|
||||
data["refusal"] = refusal
|
||||
data["refusal"] = sanitize_string(refusal)
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if tool_calls is not None:
|
||||
@@ -170,15 +171,16 @@ async def add_chat_messages_batch(
|
||||
"createdAt": now,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
# Add optional string fields — sanitize to strip
|
||||
# PostgreSQL-incompatible control characters.
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
data["content"] = sanitize_string(msg["content"])
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
data["refusal"] = sanitize_string(msg["refusal"])
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
@@ -312,7 +314,7 @@ async def update_tool_message_content(
|
||||
"toolCallId": tool_call_id,
|
||||
},
|
||||
data={
|
||||
"content": new_content,
|
||||
"content": sanitize_string(new_content),
|
||||
},
|
||||
)
|
||||
if result == 0:
|
||||
|
||||
@@ -119,12 +119,12 @@ class CoPilotProcessor:
|
||||
"""
|
||||
from backend.util.workspace_storage import shutdown_workspace_storage
|
||||
|
||||
coro = shutdown_workspace_storage()
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
shutdown_workspace_storage(), self.execution_loop
|
||||
)
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self.execution_loop)
|
||||
future.result(timeout=5)
|
||||
except Exception as e:
|
||||
coro.close() # Prevent "coroutine was never awaited" warning
|
||||
error_msg = str(e) or type(e).__name__
|
||||
logger.warning(
|
||||
f"[CoPilotExecutor] Worker {self.tid} cleanup error: {error_msg}"
|
||||
|
||||
@@ -153,6 +153,9 @@ class CoPilotExecutionEntry(BaseModel):
|
||||
context: dict[str, str] | None = None
|
||||
"""Optional context for the message (e.g., {url: str, content: str})"""
|
||||
|
||||
file_ids: list[str] | None = None
|
||||
"""Workspace file IDs attached to the user's message"""
|
||||
|
||||
|
||||
class CancelCoPilotEvent(BaseModel):
|
||||
"""Event to cancel a CoPilot operation."""
|
||||
@@ -171,6 +174,7 @@ async def enqueue_copilot_turn(
|
||||
turn_id: str,
|
||||
is_user_message: bool = True,
|
||||
context: dict[str, str] | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Enqueue a CoPilot task for processing by the executor service.
|
||||
|
||||
@@ -181,6 +185,7 @@ async def enqueue_copilot_turn(
|
||||
turn_id: Per-turn UUID for Redis stream isolation
|
||||
is_user_message: Whether the message is from the user (vs system/assistant)
|
||||
context: Optional context for the message (e.g., {url: str, content: str})
|
||||
file_ids: Optional workspace file IDs attached to the user's message
|
||||
"""
|
||||
from backend.util.clients import get_async_copilot_queue
|
||||
|
||||
@@ -191,6 +196,7 @@ async def enqueue_copilot_turn(
|
||||
message=message,
|
||||
is_user_message=is_user_message,
|
||||
context=context,
|
||||
file_ids=file_ids,
|
||||
)
|
||||
|
||||
queue_client = await get_async_copilot_queue()
|
||||
|
||||
@@ -672,6 +672,16 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
||||
async with _session_locks_mutex:
|
||||
_session_locks.pop(session_id, None)
|
||||
|
||||
# Shut down any local browser daemon for this session (best-effort).
|
||||
# Inline import required: all tool modules import ChatSession from this
|
||||
# module, so any top-level import from tools.* would create a cycle.
|
||||
try:
|
||||
from .tools.agent_browser import close_browser_session
|
||||
|
||||
await close_browser_session(session_id, user_id=user_id)
|
||||
except Exception as e:
|
||||
logger.debug(f"Browser cleanup for session {session_id}: {e}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing import Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.util.json import dumps as json_dumps
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -150,6 +151,9 @@ class StreamToolInputAvailable(StreamBaseResponse):
|
||||
)
|
||||
|
||||
|
||||
_MAX_TOOL_OUTPUT_SIZE = 100_000 # ~100 KB; truncate to avoid bloating SSE/DB
|
||||
|
||||
|
||||
class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
"""Tool execution result."""
|
||||
|
||||
@@ -164,6 +168,10 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
default=True, description="Whether the tool execution succeeded"
|
||||
)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Truncate oversized outputs after construction."""
|
||||
self.output = truncate(self.output, _MAX_TOOL_OUTPUT_SIZE)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format, excluding non-spec fields."""
|
||||
data = {
|
||||
|
||||
360
autogpt_platform/backend/backend/copilot/sdk/e2b_file_tools.py
Normal file
360
autogpt_platform/backend/backend/copilot/sdk/e2b_file_tools.py
Normal file
@@ -0,0 +1,360 @@
|
||||
"""MCP file-tool handlers that route to the E2B cloud sandbox.
|
||||
|
||||
When E2B is active, these tools replace the SDK built-in Read/Write/Edit/
|
||||
Glob/Grep so that all file operations share the same ``/home/user``
|
||||
filesystem as ``bash_exec``.
|
||||
|
||||
SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
|
||||
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
from typing import Any, Callable
|
||||
|
||||
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Lazy imports to break circular dependency with tool_adapter.
|
||||
|
||||
|
||||
def _get_sandbox(): # type: ignore[return]
|
||||
from .tool_adapter import get_current_sandbox # noqa: E402
|
||||
|
||||
return get_current_sandbox()
|
||||
|
||||
|
||||
def _is_allowed_local(path: str) -> bool:
|
||||
from .tool_adapter import is_allowed_local_path # noqa: E402
|
||||
|
||||
return is_allowed_local_path(path)
|
||||
|
||||
|
||||
def _resolve_remote(path: str) -> str:
|
||||
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
|
||||
|
||||
Raises :class:`ValueError` if the resolved path escapes the sandbox.
|
||||
"""
|
||||
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
|
||||
normalized = os.path.normpath(candidate)
|
||||
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
|
||||
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
|
||||
return normalized
|
||||
|
||||
|
||||
def _mcp(text: str, *, error: bool = False) -> dict[str, Any]:
|
||||
if error:
|
||||
text = json.dumps({"error": text, "type": "error"})
|
||||
return {"content": [{"type": "text", "text": text}], "isError": error}
|
||||
|
||||
|
||||
def _get_sandbox_and_path(
|
||||
file_path: str,
|
||||
) -> tuple[Any, str] | dict[str, Any]:
|
||||
"""Common preamble: get sandbox + resolve path, or return MCP error."""
|
||||
sandbox = _get_sandbox()
|
||||
if sandbox is None:
|
||||
return _mcp("No E2B sandbox available", error=True)
|
||||
try:
|
||||
remote = _resolve_remote(file_path)
|
||||
except ValueError as exc:
|
||||
return _mcp(str(exc), error=True)
|
||||
return sandbox, remote
|
||||
|
||||
|
||||
# Tool handlers
|
||||
|
||||
|
||||
async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
file_path: str = args.get("file_path", "")
|
||||
offset: int = max(0, int(args.get("offset", 0)))
|
||||
limit: int = max(1, int(args.get("limit", 2000)))
|
||||
|
||||
if not file_path:
|
||||
return _mcp("file_path is required", error=True)
|
||||
|
||||
# SDK-internal paths (tool-results, ephemeral working dir) stay on the host.
|
||||
if _is_allowed_local(file_path):
|
||||
return _read_local(file_path, offset, limit)
|
||||
|
||||
result = _get_sandbox_and_path(file_path)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
sandbox, remote = result
|
||||
|
||||
try:
|
||||
raw: bytes = await sandbox.files.read(remote, format="bytes")
|
||||
content = raw.decode("utf-8", errors="replace")
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to read {remote}: {exc}", error=True)
|
||||
|
||||
lines = content.splitlines(keepends=True)
|
||||
selected = list(itertools.islice(lines, offset, offset + limit))
|
||||
numbered = "".join(
|
||||
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
|
||||
)
|
||||
return _mcp(numbered)
|
||||
|
||||
|
||||
async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
file_path: str = args.get("file_path", "")
|
||||
content: str = args.get("content", "")
|
||||
|
||||
if not file_path:
|
||||
return _mcp("file_path is required", error=True)
|
||||
|
||||
result = _get_sandbox_and_path(file_path)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
sandbox, remote = result
|
||||
|
||||
try:
|
||||
parent = os.path.dirname(remote)
|
||||
if parent and parent != E2B_WORKDIR:
|
||||
await sandbox.files.make_dir(parent)
|
||||
await sandbox.files.write(remote, content)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to write {remote}: {exc}", error=True)
|
||||
|
||||
return _mcp(f"Successfully wrote to {remote}")
|
||||
|
||||
|
||||
async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
file_path: str = args.get("file_path", "")
|
||||
old_string: str = args.get("old_string", "")
|
||||
new_string: str = args.get("new_string", "")
|
||||
replace_all: bool = args.get("replace_all", False)
|
||||
|
||||
if not file_path:
|
||||
return _mcp("file_path is required", error=True)
|
||||
if not old_string:
|
||||
return _mcp("old_string is required", error=True)
|
||||
|
||||
result = _get_sandbox_and_path(file_path)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
sandbox, remote = result
|
||||
|
||||
try:
|
||||
raw: bytes = await sandbox.files.read(remote, format="bytes")
|
||||
content = raw.decode("utf-8", errors="replace")
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to read {remote}: {exc}", error=True)
|
||||
|
||||
count = content.count(old_string)
|
||||
if count == 0:
|
||||
return _mcp(f"old_string not found in {file_path}", error=True)
|
||||
if count > 1 and not replace_all:
|
||||
return _mcp(
|
||||
f"old_string appears {count} times in {file_path}. "
|
||||
"Use replace_all=true or provide a more unique string.",
|
||||
error=True,
|
||||
)
|
||||
|
||||
updated = (
|
||||
content.replace(old_string, new_string)
|
||||
if replace_all
|
||||
else content.replace(old_string, new_string, 1)
|
||||
)
|
||||
try:
|
||||
await sandbox.files.write(remote, updated)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to write {remote}: {exc}", error=True)
|
||||
|
||||
return _mcp(f"Edited {remote} ({count} replacement{'s' if count > 1 else ''})")
|
||||
|
||||
|
||||
async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
|
||||
pattern: str = args.get("pattern", "")
|
||||
path: str = args.get("path", "")
|
||||
|
||||
if not pattern:
|
||||
return _mcp("pattern is required", error=True)
|
||||
|
||||
sandbox = _get_sandbox()
|
||||
if sandbox is None:
|
||||
return _mcp("No E2B sandbox available", error=True)
|
||||
|
||||
try:
|
||||
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
|
||||
except ValueError as exc:
|
||||
return _mcp(str(exc), error=True)
|
||||
|
||||
cmd = f"find {shlex.quote(search_dir)} -name {shlex.quote(pattern)} -type f 2>/dev/null | head -500"
|
||||
try:
|
||||
result = await sandbox.commands.run(cmd, cwd=E2B_WORKDIR, timeout=10)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Glob failed: {exc}", error=True)
|
||||
|
||||
files = [line for line in (result.stdout or "").strip().splitlines() if line]
|
||||
return _mcp(json.dumps(files, indent=2))
|
||||
|
||||
|
||||
async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
|
||||
pattern: str = args.get("pattern", "")
|
||||
path: str = args.get("path", "")
|
||||
include: str = args.get("include", "")
|
||||
|
||||
if not pattern:
|
||||
return _mcp("pattern is required", error=True)
|
||||
|
||||
sandbox = _get_sandbox()
|
||||
if sandbox is None:
|
||||
return _mcp("No E2B sandbox available", error=True)
|
||||
|
||||
try:
|
||||
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
|
||||
except ValueError as exc:
|
||||
return _mcp(str(exc), error=True)
|
||||
|
||||
parts = ["grep", "-rn", "--color=never"]
|
||||
if include:
|
||||
parts.extend(["--include", include])
|
||||
parts.extend([pattern, search_dir])
|
||||
cmd = " ".join(shlex.quote(p) for p in parts) + " 2>/dev/null | head -200"
|
||||
|
||||
try:
|
||||
result = await sandbox.commands.run(cmd, cwd=E2B_WORKDIR, timeout=15)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Grep failed: {exc}", error=True)
|
||||
|
||||
output = (result.stdout or "").strip()
|
||||
return _mcp(output if output else "No matches found.")
|
||||
|
||||
|
||||
# Local read (for SDK-internal paths)
|
||||
|
||||
|
||||
def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
|
||||
"""Read from the host filesystem (defence-in-depth path check)."""
|
||||
if not _is_allowed_local(file_path):
|
||||
return _mcp(f"Path not allowed: {file_path}", error=True)
|
||||
expanded = os.path.realpath(os.path.expanduser(file_path))
|
||||
try:
|
||||
with open(expanded) as fh:
|
||||
selected = list(itertools.islice(fh, offset, offset + limit))
|
||||
numbered = "".join(
|
||||
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
|
||||
)
|
||||
return _mcp(numbered)
|
||||
except FileNotFoundError:
|
||||
return _mcp(f"File not found: {file_path}", error=True)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Error reading {file_path}: {exc}", error=True)
|
||||
|
||||
|
||||
# Tool descriptors (name, description, schema, handler)
|
||||
|
||||
E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
|
||||
(
|
||||
"read_file",
|
||||
"Read a file from the cloud sandbox (/home/user). "
|
||||
"Use offset and limit for large files.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path (relative to /home/user, or absolute).",
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line to start reading from (0-indexed). Default: 0.",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Number of lines to read. Default: 2000.",
|
||||
},
|
||||
},
|
||||
"required": ["file_path"],
|
||||
},
|
||||
_handle_read_file,
|
||||
),
|
||||
(
|
||||
"write_file",
|
||||
"Write or create a file in the cloud sandbox (/home/user). "
|
||||
"Parent directories are created automatically.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path (relative to /home/user, or absolute).",
|
||||
},
|
||||
"content": {"type": "string", "description": "Content to write."},
|
||||
},
|
||||
"required": ["file_path", "content"],
|
||||
},
|
||||
_handle_write_file,
|
||||
),
|
||||
(
|
||||
"edit_file",
|
||||
"Targeted text replacement in a sandbox file. "
|
||||
"old_string must appear in the file and is replaced with new_string.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path (relative to /home/user, or absolute).",
|
||||
},
|
||||
"old_string": {"type": "string", "description": "Text to find."},
|
||||
"new_string": {"type": "string", "description": "Replacement text."},
|
||||
"replace_all": {
|
||||
"type": "boolean",
|
||||
"description": "Replace all occurrences (default: false).",
|
||||
},
|
||||
},
|
||||
"required": ["file_path", "old_string", "new_string"],
|
||||
},
|
||||
_handle_edit_file,
|
||||
),
|
||||
(
|
||||
"glob",
|
||||
"Search for files by name pattern in the cloud sandbox.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Glob pattern (e.g. *.py).",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Directory to search. Default: /home/user.",
|
||||
},
|
||||
},
|
||||
"required": ["pattern"],
|
||||
},
|
||||
_handle_glob,
|
||||
),
|
||||
(
|
||||
"grep",
|
||||
"Search file contents by regex in the cloud sandbox.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {"type": "string", "description": "Regex pattern."},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File or directory. Default: /home/user.",
|
||||
},
|
||||
"include": {
|
||||
"type": "string",
|
||||
"description": "Glob to filter files (e.g. *.py).",
|
||||
},
|
||||
},
|
||||
"required": ["pattern"],
|
||||
},
|
||||
_handle_grep,
|
||||
),
|
||||
]
|
||||
|
||||
E2B_FILE_TOOL_NAMES: list[str] = [name for name, *_ in E2B_FILE_TOOLS]
|
||||
@@ -0,0 +1,153 @@
|
||||
"""Tests for E2B file-tool path validation and local read safety.
|
||||
|
||||
Pure unit tests with no external dependencies (no E2B, no sandbox).
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from .e2b_file_tools import _read_local, _resolve_remote
|
||||
from .tool_adapter import _current_project_dir
|
||||
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _resolve_remote — sandbox path normalisation & boundary enforcement
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveRemote:
|
||||
def test_relative_path_resolved(self):
|
||||
assert _resolve_remote("src/main.py") == "/home/user/src/main.py"
|
||||
|
||||
def test_absolute_within_sandbox(self):
|
||||
assert _resolve_remote("/home/user/file.txt") == "/home/user/file.txt"
|
||||
|
||||
def test_workdir_itself(self):
|
||||
assert _resolve_remote("/home/user") == "/home/user"
|
||||
|
||||
def test_relative_dotslash(self):
|
||||
assert _resolve_remote("./README.md") == "/home/user/README.md"
|
||||
|
||||
def test_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
_resolve_remote("../../etc/passwd")
|
||||
|
||||
def test_absolute_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
_resolve_remote("/home/user/../../etc/passwd")
|
||||
|
||||
def test_absolute_outside_sandbox_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
_resolve_remote("/etc/passwd")
|
||||
|
||||
def test_root_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
_resolve_remote("/")
|
||||
|
||||
def test_home_other_user_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
_resolve_remote("/home/other/file.txt")
|
||||
|
||||
def test_deep_nested_allowed(self):
|
||||
assert _resolve_remote("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
|
||||
|
||||
def test_trailing_slash_normalised(self):
|
||||
assert _resolve_remote("src/") == "/home/user/src"
|
||||
|
||||
def test_double_dots_within_sandbox_ok(self):
|
||||
"""Path that resolves back within /home/user is allowed."""
|
||||
assert _resolve_remote("a/b/../c.txt") == "/home/user/a/c.txt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_local — host filesystem reads with allowlist enforcement
|
||||
#
|
||||
# In E2B mode, _read_local only allows tool-results paths (via
|
||||
# is_allowed_local_path without sdk_cwd). Regular files live on the
|
||||
# sandbox, not the host.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReadLocal:
|
||||
def _make_tool_results_file(self, encoded: str, filename: str, content: str) -> str:
|
||||
"""Create a tool-results file and return its path."""
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
os.makedirs(tool_results_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_results_dir, filename)
|
||||
with open(filepath, "w") as f:
|
||||
f.write(content)
|
||||
return filepath
|
||||
|
||||
def test_read_tool_results_file(self):
|
||||
"""Reading a tool-results file should succeed."""
|
||||
encoded = "-tmp-copilot-e2b-test-read"
|
||||
filepath = self._make_tool_results_file(
|
||||
encoded, "result.txt", "line 1\nline 2\nline 3\n"
|
||||
)
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local(filepath, offset=0, limit=2000)
|
||||
assert result["isError"] is False
|
||||
assert "line 1" in result["content"][0]["text"]
|
||||
assert "line 2" in result["content"][0]["text"]
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
os.unlink(filepath)
|
||||
|
||||
def test_read_disallowed_path_blocked(self):
|
||||
"""Reading /etc/passwd should be blocked by the allowlist."""
|
||||
result = _read_local("/etc/passwd", offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
assert "not allowed" in result["content"][0]["text"].lower()
|
||||
|
||||
def test_read_nonexistent_tool_results(self):
|
||||
"""A tool-results path that doesn't exist returns FileNotFoundError."""
|
||||
encoded = "-tmp-copilot-e2b-test-nofile"
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
os.makedirs(tool_results_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_results_dir, "nonexistent.txt")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local(filepath, offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
assert "not found" in result["content"][0]["text"].lower()
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
os.rmdir(tool_results_dir)
|
||||
|
||||
def test_read_traversal_path_blocked(self):
|
||||
"""A traversal attempt that escapes allowed directories is blocked."""
|
||||
result = _read_local("/tmp/copilot-abc/../../etc/shadow", offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
assert "not allowed" in result["content"][0]["text"].lower()
|
||||
|
||||
def test_read_arbitrary_host_path_blocked(self):
|
||||
"""Arbitrary host paths are blocked even if they exist."""
|
||||
result = _read_local("/proc/self/environ", offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
|
||||
def test_read_with_offset_and_limit(self):
|
||||
"""Offset and limit should control which lines are returned."""
|
||||
encoded = "-tmp-copilot-e2b-test-offset"
|
||||
content = "".join(f"line {i}\n" for i in range(10))
|
||||
filepath = self._make_tool_results_file(encoded, "lines.txt", content)
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local(filepath, offset=3, limit=2)
|
||||
assert result["isError"] is False
|
||||
text = result["content"][0]["text"]
|
||||
assert "line 3" in text
|
||||
assert "line 4" in text
|
||||
assert "line 2" not in text
|
||||
assert "line 5" not in text
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
os.unlink(filepath)
|
||||
|
||||
def test_read_without_project_dir_blocks_all(self):
|
||||
"""Without _current_project_dir set, all paths are blocked."""
|
||||
result = _read_local("/tmp/anything.txt", offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
@@ -6,7 +6,6 @@ ensuring multi-user isolation and preventing unauthorized operations.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
@@ -16,6 +15,7 @@ from .tool_adapter import (
|
||||
DANGEROUS_PATTERNS,
|
||||
MCP_TOOL_PREFIX,
|
||||
WORKSPACE_SCOPED_TOOLS,
|
||||
is_allowed_local_path,
|
||||
stash_pending_tool_output,
|
||||
)
|
||||
|
||||
@@ -38,40 +38,20 @@ def _validate_workspace_path(
|
||||
) -> dict[str, Any]:
|
||||
"""Validate that a workspace-scoped tool only accesses allowed paths.
|
||||
|
||||
Allowed directories:
|
||||
Delegates to :func:`is_allowed_local_path` which permits:
|
||||
- The SDK working directory (``/tmp/copilot-<session>/``)
|
||||
- The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
|
||||
- The current session's tool-results directory
|
||||
(``~/.claude/projects/<encoded-cwd>/tool-results/``)
|
||||
"""
|
||||
path = tool_input.get("file_path") or tool_input.get("path") or ""
|
||||
if not path:
|
||||
# Glob/Grep without a path default to cwd which is already sandboxed
|
||||
return {}
|
||||
|
||||
# Resolve relative paths against sdk_cwd (the SDK sets cwd so the LLM
|
||||
# naturally uses relative paths like "test.txt" instead of absolute ones).
|
||||
# Tilde paths (~/) are home-dir references, not relative — expand first.
|
||||
if path.startswith("~"):
|
||||
resolved = os.path.realpath(os.path.expanduser(path))
|
||||
elif not os.path.isabs(path) and sdk_cwd:
|
||||
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
|
||||
else:
|
||||
resolved = os.path.realpath(path)
|
||||
|
||||
# Allow access within the SDK working directory
|
||||
if sdk_cwd:
|
||||
norm_cwd = os.path.realpath(sdk_cwd)
|
||||
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
|
||||
return {}
|
||||
|
||||
# Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
|
||||
claude_dir = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
tool_results_seg = os.sep + "tool-results" + os.sep
|
||||
if resolved.startswith(claude_dir + os.sep) and tool_results_seg in resolved:
|
||||
if is_allowed_local_path(path, sdk_cwd):
|
||||
return {}
|
||||
|
||||
logger.warning(
|
||||
f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
|
||||
)
|
||||
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
|
||||
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
|
||||
return _deny(
|
||||
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
|
||||
|
||||
@@ -120,17 +120,31 @@ def test_read_no_cwd_denies_absolute():
|
||||
|
||||
|
||||
def test_read_tool_results_allowed():
|
||||
from .tool_adapter import _current_project_dir
|
||||
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
# is_allowed_local_path requires the session's encoded cwd to be set
|
||||
token = _current_project_dir.set("-tmp-copilot-abc123")
|
||||
try:
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
|
||||
|
||||
def test_read_claude_projects_without_tool_results_denied():
|
||||
def test_read_claude_projects_session_dir_allowed():
|
||||
"""Files within the current session's project dir are allowed."""
|
||||
from .tool_adapter import _current_project_dir
|
||||
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert _is_denied(result)
|
||||
token = _current_project_dir.set("-tmp-copilot-abc123")
|
||||
try:
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert not _is_denied(result)
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
|
||||
|
||||
# -- Built-in Bash is blocked (use bash_exec MCP tool instead) ---------------
|
||||
|
||||
@@ -5,18 +5,28 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
import openai
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
ClaudeAgentOptions,
|
||||
ClaudeSDKClient,
|
||||
ResultMessage,
|
||||
ToolUseBlock,
|
||||
)
|
||||
from langfuse import propagate_attributes
|
||||
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.executor.cluster_lock import AsyncClusterLock
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.prompt import compress_context
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from ..config import ChatConfig
|
||||
@@ -42,14 +52,15 @@ from ..service import (
|
||||
_generate_session_title,
|
||||
_is_langfuse_configured,
|
||||
)
|
||||
from ..tools.e2b_sandbox import get_or_create_sandbox
|
||||
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
||||
from ..tracking import track_user_message
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .security_hooks import create_security_hooks
|
||||
from .tool_adapter import (
|
||||
COPILOT_TOOL_NAMES,
|
||||
SDK_DISALLOWED_TOOLS,
|
||||
create_copilot_mcp_server,
|
||||
get_copilot_tool_names,
|
||||
get_sdk_disallowed_tools,
|
||||
set_execution_context,
|
||||
wait_for_stash,
|
||||
)
|
||||
@@ -148,9 +159,36 @@ _HEARTBEAT_INTERVAL = 3.0 # seconds
|
||||
# Appended to the system prompt to inform the agent about available tools.
|
||||
# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead,
|
||||
# which has kernel-level network isolation (unshare --net).
|
||||
def _build_sdk_tool_supplement(cwd: str) -> str:
|
||||
"""Build the SDK tool supplement with the actual working directory injected."""
|
||||
return f"""
|
||||
_SHARED_TOOL_NOTES = """\
|
||||
|
||||
### Sharing files with the user
|
||||
After saving a file to the persistent workspace with `write_workspace_file`,
|
||||
share it with the user by embedding the `download_url` from the response in
|
||||
your message as a Markdown link or image:
|
||||
|
||||
- **Any file** — shows as a clickable download link:
|
||||
`[report.csv](workspace://file_id#text/csv)`
|
||||
- **Image** — renders inline in chat:
|
||||
``
|
||||
- **Video** — renders inline in chat with player controls:
|
||||
``
|
||||
|
||||
The `download_url` field in the `write_workspace_file` response is already
|
||||
in the correct format — paste it directly after the `(` in the Markdown.
|
||||
|
||||
### Long-running tools
|
||||
Long-running tools (create_agent, edit_agent, etc.) are handled
|
||||
asynchronously. You will receive an immediate response; the actual result
|
||||
is delivered to the user via a background stream.
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
"""
|
||||
|
||||
|
||||
_LOCAL_TOOL_SUPPLEMENT = (
|
||||
"""
|
||||
|
||||
## Tool notes
|
||||
|
||||
@@ -192,31 +230,57 @@ When you create or modify important files (code, configs, outputs), you MUST:
|
||||
1. Save them using `write_workspace_file` so they persist
|
||||
2. At the start of a new turn, call `list_workspace_files` to see what files
|
||||
are available from previous turns
|
||||
|
||||
### Sharing files with the user
|
||||
After saving a file to the persistent workspace with `write_workspace_file`,
|
||||
share it with the user by embedding the `download_url` from the response in
|
||||
your message as a Markdown link or image:
|
||||
|
||||
- **Any file** — shows as a clickable download link:
|
||||
`[report.csv](workspace://file_id#text/csv)`
|
||||
- **Image** — renders inline in chat:
|
||||
``
|
||||
- **Video** — renders inline in chat with player controls:
|
||||
``
|
||||
|
||||
The `download_url` field in the `write_workspace_file` response is already
|
||||
in the correct format — paste it directly after the `(` in the Markdown.
|
||||
|
||||
### Long-running tools
|
||||
Long-running tools (create_agent, edit_agent, etc.) are handled
|
||||
asynchronously. You will receive an immediate response; the actual result
|
||||
is delivered to the user via a background stream.
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
"""
|
||||
+ _SHARED_TOOL_NOTES
|
||||
)
|
||||
|
||||
|
||||
_E2B_TOOL_SUPPLEMENT = (
|
||||
"""
|
||||
|
||||
## Tool notes
|
||||
|
||||
### Shell commands
|
||||
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
|
||||
for shell commands — it runs in a cloud sandbox with full internet access.
|
||||
|
||||
### Working directory
|
||||
- Your working directory is: `/home/user` (cloud sandbox)
|
||||
- All file tools (`read_file`, `write_file`, `edit_file`, `glob`, `grep`)
|
||||
AND `bash_exec` operate on the **same cloud sandbox filesystem**.
|
||||
- Files created by `bash_exec` are immediately visible to `read_file` and
|
||||
vice-versa — they share one filesystem.
|
||||
- Use relative paths (resolved from `/home/user`) or absolute paths.
|
||||
|
||||
### Two storage systems — CRITICAL to understand
|
||||
|
||||
1. **Cloud sandbox** (`/home/user`):
|
||||
- Shared by all file tools AND `bash_exec` — same filesystem
|
||||
- Files **persist across turns** within the current session
|
||||
- Full Linux environment with internet access
|
||||
- Lost when the session expires (12 h inactivity)
|
||||
|
||||
2. **Persistent workspace** (cloud storage):
|
||||
- Files here **survive across sessions indefinitely**
|
||||
- Use `write_workspace_file` to save important files permanently
|
||||
- Use `read_workspace_file` to retrieve previously saved files
|
||||
- Use `list_workspace_files` to see what files you've saved before
|
||||
- Call `list_workspace_files(include_all_sessions=True)` to see files from
|
||||
all sessions
|
||||
|
||||
### Moving files between sandbox and persistent storage
|
||||
- **Sandbox → Persistent**: Use `write_workspace_file` with `source_path`
|
||||
to copy from the sandbox to permanent storage
|
||||
- **Persistent → Sandbox**: Use `read_workspace_file` with `save_to_path`
|
||||
to download into the sandbox for processing
|
||||
|
||||
### File persistence workflow
|
||||
Important files that must survive beyond this session should be saved with
|
||||
`write_workspace_file`. Sandbox files persist across turns but are lost
|
||||
when the session expires.
|
||||
"""
|
||||
+ _SHARED_TOOL_NOTES
|
||||
)
|
||||
|
||||
|
||||
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
|
||||
@@ -291,8 +355,6 @@ def _cleanup_sdk_tool_results(cwd: str) -> None:
|
||||
Security: *cwd* MUST be created by ``_make_sdk_cwd()`` which sanitizes
|
||||
the session_id.
|
||||
"""
|
||||
import shutil
|
||||
|
||||
normalized = os.path.normpath(cwd)
|
||||
if not normalized.startswith(_SDK_CWD_PREFIX):
|
||||
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
|
||||
@@ -324,8 +386,6 @@ async def _compress_conversation_history(
|
||||
if len(messages) < 2:
|
||||
return messages
|
||||
|
||||
from backend.util.prompt import compress_context
|
||||
|
||||
# Convert ChatMessages to dicts for compress_context
|
||||
messages_dict = []
|
||||
for msg in messages:
|
||||
@@ -339,8 +399,6 @@ async def _compress_conversation_history(
|
||||
messages_dict.append(msg_dict)
|
||||
|
||||
try:
|
||||
import openai
|
||||
|
||||
async with openai.AsyncOpenAI(
|
||||
api_key=config.api_key, base_url=config.base_url, timeout=30.0
|
||||
) as client:
|
||||
@@ -550,6 +608,7 @@ async def stream_chat_completion_sdk(
|
||||
message_id = str(uuid.uuid4())
|
||||
stream_id = str(uuid.uuid4())
|
||||
stream_completed = False
|
||||
e2b_sandbox = None
|
||||
use_resume = False
|
||||
resume_file: str | None = None
|
||||
captured_transcript = CapturedTranscript()
|
||||
@@ -579,7 +638,7 @@ async def stream_chat_completion_sdk(
|
||||
# OTEL context manager — initialized inside the try and cleaned up in finally.
|
||||
_otel_ctx: Any = None
|
||||
|
||||
# Make sure there is no more code between the lock acquitition and try-block.
|
||||
# Make sure there is no more code between the lock acquisition and try-block.
|
||||
try:
|
||||
# Build system prompt (reuses non-SDK path with Langfuse support).
|
||||
# Pre-compute the cwd here so the exact working directory path can be
|
||||
@@ -597,17 +656,47 @@ async def stream_chat_completion_sdk(
|
||||
code="sdk_cwd_error",
|
||||
)
|
||||
return
|
||||
# Set up E2B sandbox for persistent cloud execution when configured.
|
||||
# When active, MCP file tools route directly to the sandbox filesystem
|
||||
# so bash_exec and file tools share the same /home/user directory.
|
||||
if config.use_e2b_sandbox and not config.e2b_api_key:
|
||||
logger.warning(
|
||||
"[E2B] [%s] E2B sandbox enabled but no API key configured "
|
||||
"(CHAT_E2B_API_KEY / E2B_API_KEY) — falling back to bubblewrap",
|
||||
session_id[:12],
|
||||
)
|
||||
if config.use_e2b_sandbox and config.e2b_api_key:
|
||||
try:
|
||||
e2b_sandbox = await get_or_create_sandbox(
|
||||
session_id,
|
||||
api_key=config.e2b_api_key,
|
||||
template=config.e2b_sandbox_template,
|
||||
timeout=config.e2b_sandbox_timeout,
|
||||
)
|
||||
except Exception as e2b_err:
|
||||
logger.error(
|
||||
"[E2B] [%s] Setup failed: %s",
|
||||
session_id[:12],
|
||||
e2b_err,
|
||||
exc_info=True,
|
||||
)
|
||||
e2b_sandbox = None
|
||||
|
||||
use_e2b = e2b_sandbox is not None
|
||||
|
||||
system_prompt, _ = await _build_system_prompt(
|
||||
user_id, has_conversation_history=has_history
|
||||
)
|
||||
system_prompt += _build_sdk_tool_supplement(sdk_cwd)
|
||||
system_prompt += (
|
||||
_E2B_TOOL_SUPPLEMENT
|
||||
if use_e2b
|
||||
else _LOCAL_TOOL_SUPPLEMENT.format(cwd=sdk_cwd)
|
||||
)
|
||||
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
|
||||
set_execution_context(user_id, session)
|
||||
set_execution_context(user_id, session, sandbox=e2b_sandbox, sdk_cwd=sdk_cwd)
|
||||
try:
|
||||
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
|
||||
|
||||
# Fail fast when no API credentials are available at all
|
||||
sdk_env = _build_sdk_env()
|
||||
if not sdk_env and not os.environ.get("ANTHROPIC_API_KEY"):
|
||||
@@ -617,7 +706,7 @@ async def stream_chat_completion_sdk(
|
||||
"or ANTHROPIC_API_KEY for direct Anthropic access."
|
||||
)
|
||||
|
||||
mcp_server = create_copilot_mcp_server()
|
||||
mcp_server = create_copilot_mcp_server(use_e2b=use_e2b)
|
||||
|
||||
sdk_model = _resolve_sdk_model()
|
||||
|
||||
@@ -678,11 +767,13 @@ async def stream_chat_completion_sdk(
|
||||
f"({len(session.messages)} messages in session)"
|
||||
)
|
||||
|
||||
allowed = get_copilot_tool_names(use_e2b=use_e2b)
|
||||
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
|
||||
sdk_options_kwargs: dict[str, Any] = {
|
||||
"system_prompt": system_prompt,
|
||||
"mcp_servers": {"copilot": mcp_server},
|
||||
"allowed_tools": COPILOT_TOOL_NAMES,
|
||||
"disallowed_tools": SDK_DISALLOWED_TOOLS,
|
||||
"allowed_tools": allowed,
|
||||
"disallowed_tools": disallowed,
|
||||
"hooks": security_hooks,
|
||||
"cwd": sdk_cwd,
|
||||
"max_buffer_size": config.claude_agent_max_buffer_size,
|
||||
@@ -827,12 +918,6 @@ async def stream_chat_completion_sdk(
|
||||
# AssistantMessages (each containing only
|
||||
# ToolUseBlocks), we must NOT wait/flush — the prior
|
||||
# tools are still executing concurrently.
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
ResultMessage,
|
||||
ToolUseBlock,
|
||||
)
|
||||
|
||||
is_parallel_continuation = isinstance(
|
||||
sdk_msg, AssistantMessage
|
||||
) and all(isinstance(b, ToolUseBlock) for b in sdk_msg.content)
|
||||
|
||||
@@ -9,20 +9,84 @@ import itertools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
from backend.copilot.tools.base import BaseTool
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
from .e2b_file_tools import E2B_FILE_TOOL_NAMES, E2B_FILE_TOOLS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
|
||||
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
|
||||
# in the path — prevents reading settings, credentials, or other sensitive files.
|
||||
_SDK_PROJECTS_DIR = os.path.expanduser("~/.claude/projects/")
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
# Max MCP response size in chars — keeps tool output under the SDK's 10 MB JSON buffer.
|
||||
_MCP_MAX_CHARS = 500_000
|
||||
|
||||
# Context variable holding the encoded project directory name for the current
|
||||
# session (e.g. "-private-tmp-copilot-<uuid>"). Set by set_execution_context()
|
||||
# so that path validation can scope tool-results reads to the current session.
|
||||
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
|
||||
|
||||
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does.
|
||||
|
||||
The CLI replaces all non-alphanumeric characters with ``-``.
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
"""Check whether *path* is an allowed host-filesystem path.
|
||||
|
||||
Allowed:
|
||||
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/`` — the SDK's
|
||||
project directory for this session (tool-results, transcripts, etc.)
|
||||
|
||||
Both checks are scoped to the **current session** so sessions cannot
|
||||
read each other's data.
|
||||
"""
|
||||
if not path:
|
||||
return False
|
||||
|
||||
if path.startswith("~"):
|
||||
resolved = os.path.realpath(os.path.expanduser(path))
|
||||
elif not os.path.isabs(path) and sdk_cwd:
|
||||
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
|
||||
else:
|
||||
resolved = os.path.realpath(path)
|
||||
|
||||
# Allow access within the SDK working directory
|
||||
if sdk_cwd:
|
||||
norm_cwd = os.path.realpath(sdk_cwd)
|
||||
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
|
||||
return True
|
||||
|
||||
# Allow access within the current session's CLI project directory
|
||||
# (~/.claude/projects/<encoded-cwd>/).
|
||||
encoded = _current_project_dir.get("")
|
||||
if encoded:
|
||||
session_project = os.path.join(_SDK_PROJECTS_DIR, encoded)
|
||||
if resolved == session_project or resolved.startswith(session_project + os.sep):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
|
||||
MCP_SERVER_NAME = "copilot"
|
||||
@@ -33,6 +97,12 @@ _current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default
|
||||
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||
"current_session", default=None
|
||||
)
|
||||
# E2B cloud sandbox for the current turn (None when E2B is not configured).
|
||||
# Passed to bash_exec so commands run on E2B instead of the local bwrap sandbox.
|
||||
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
|
||||
"_current_sandbox", default=None
|
||||
)
|
||||
|
||||
# Stash for MCP tool outputs before the SDK potentially truncates them.
|
||||
# Keyed by tool_name → full output string. Consumed (popped) by the
|
||||
# response adapter when it builds StreamToolOutputAvailable.
|
||||
@@ -53,22 +123,33 @@ _stash_event: ContextVar[asyncio.Event | None] = ContextVar(
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
sandbox: "AsyncSandbox | None" = None,
|
||||
sdk_cwd: str | None = None,
|
||||
) -> None:
|
||||
"""Set the execution context for tool calls.
|
||||
|
||||
This must be called before streaming begins to ensure tools have access
|
||||
to user_id and session information.
|
||||
to user_id, session, and (optionally) an E2B sandbox for bash execution.
|
||||
|
||||
Args:
|
||||
user_id: Current user's ID.
|
||||
session: Current chat session.
|
||||
sandbox: Optional E2B sandbox; when set, bash_exec routes commands there.
|
||||
sdk_cwd: SDK working directory; used to scope tool-results reads.
|
||||
"""
|
||||
_current_user_id.set(user_id)
|
||||
_current_session.set(session)
|
||||
_current_sandbox.set(sandbox)
|
||||
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
|
||||
_pending_tool_outputs.set({})
|
||||
_stash_event.set(asyncio.Event())
|
||||
|
||||
|
||||
def get_current_sandbox() -> "AsyncSandbox | None":
|
||||
"""Return the E2B sandbox for the current turn, or None."""
|
||||
return _current_sandbox.get()
|
||||
|
||||
|
||||
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||
"""Get the current execution context."""
|
||||
return (
|
||||
@@ -182,11 +263,6 @@ async def _execute_tool_sync(
|
||||
result.output if isinstance(result.output, str) else json.dumps(result.output)
|
||||
)
|
||||
|
||||
# Stash the full output before the SDK potentially truncates it.
|
||||
pending = _pending_tool_outputs.get(None)
|
||||
if pending is not None:
|
||||
pending.setdefault(base_tool.name, []).append(text)
|
||||
|
||||
content_blocks: list[dict[str, str]] = [{"type": "text", "text": text}]
|
||||
|
||||
# If the tool result contains inline image data, add an MCP image block
|
||||
@@ -284,29 +360,32 @@ def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
|
||||
|
||||
|
||||
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Read a file with optional offset/limit. Restricted to SDK working directory.
|
||||
"""Read a local file with optional offset/limit.
|
||||
|
||||
After reading, the file is deleted to prevent accumulation in long-running pods.
|
||||
Only allows paths that pass :func:`is_allowed_local_path` — the current
|
||||
session's tool-results directory and ephemeral working directory.
|
||||
"""
|
||||
file_path = args.get("file_path", "")
|
||||
offset = args.get("offset", 0)
|
||||
limit = args.get("limit", 2000)
|
||||
|
||||
# Security: only allow reads under ~/.claude/projects/**/tool-results/
|
||||
real_path = os.path.realpath(file_path)
|
||||
if not real_path.startswith(_SDK_PROJECTS_DIR) or "tool-results" not in real_path:
|
||||
if not is_allowed_local_path(file_path):
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
resolved = os.path.realpath(os.path.expanduser(file_path))
|
||||
try:
|
||||
with open(real_path) as f:
|
||||
with open(resolved) as f:
|
||||
selected = list(itertools.islice(f, offset, offset + limit))
|
||||
content = "".join(selected)
|
||||
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
|
||||
# don't delete here — the SDK may read in multiple chunks.
|
||||
return {"content": [{"type": "text", "text": content}], "isError": False}
|
||||
return {
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"isError": False,
|
||||
}
|
||||
except FileNotFoundError:
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
|
||||
@@ -345,49 +424,82 @@ _READ_TOOL_SCHEMA = {
|
||||
|
||||
|
||||
# Create the MCP server configuration
|
||||
def create_copilot_mcp_server():
|
||||
def _text_from_mcp_result(result: dict[str, Any]) -> str:
|
||||
"""Extract concatenated text from an MCP response's content blocks."""
|
||||
content = result.get("content", [])
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
b.get("text", "")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
]
|
||||
return "".join(parts)
|
||||
return ""
|
||||
|
||||
|
||||
def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
"""Create an in-process MCP server configuration for CoPilot tools.
|
||||
|
||||
This can be passed to ClaudeAgentOptions.mcp_servers.
|
||||
|
||||
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
|
||||
package being available. This function returns the configuration that
|
||||
can be used with the SDK.
|
||||
When *use_e2b* is True, five additional MCP file tools are registered
|
||||
that route directly to the E2B sandbox filesystem, and the caller should
|
||||
disable the corresponding SDK built-in tools via
|
||||
:func:`get_sdk_disallowed_tools`.
|
||||
"""
|
||||
try:
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
# Create decorated tool functions
|
||||
sdk_tools = []
|
||||
def _truncating(fn, tool_name: str):
|
||||
"""Wrap a tool handler so its response is truncated to stay under the
|
||||
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
|
||||
response adapter before the SDK can apply its own head-truncation.
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
handler = create_tool_handler(base_tool)
|
||||
decorated = tool(
|
||||
tool_name,
|
||||
base_tool.description,
|
||||
_build_input_schema(base_tool),
|
||||
)(handler)
|
||||
Applied once to every registered tool."""
|
||||
|
||||
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
|
||||
result = await fn(args)
|
||||
truncated = truncate(result, _MCP_MAX_CHARS)
|
||||
|
||||
# Stash the text so the response adapter can forward our
|
||||
# middle-out truncated version to the frontend instead of the
|
||||
# SDK's head-truncated version (for outputs >~100 KB the SDK
|
||||
# persists to tool-results/ with a 2 KB head-only preview).
|
||||
if not truncated.get("isError"):
|
||||
text = _text_from_mcp_result(truncated)
|
||||
if text:
|
||||
stash_pending_tool_output(tool_name, text)
|
||||
|
||||
return truncated
|
||||
|
||||
return wrapper
|
||||
|
||||
sdk_tools = []
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
handler = create_tool_handler(base_tool)
|
||||
decorated = tool(
|
||||
tool_name,
|
||||
base_tool.description,
|
||||
_build_input_schema(base_tool),
|
||||
)(_truncating(handler, tool_name))
|
||||
sdk_tools.append(decorated)
|
||||
|
||||
# E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.
|
||||
if use_e2b:
|
||||
for name, desc, schema, handler in E2B_FILE_TOOLS:
|
||||
decorated = tool(name, desc, schema)(_truncating(handler, name))
|
||||
sdk_tools.append(decorated)
|
||||
|
||||
# Add the Read tool so the SDK can read back oversized tool results
|
||||
read_tool = tool(
|
||||
_READ_TOOL_NAME,
|
||||
_READ_TOOL_DESCRIPTION,
|
||||
_READ_TOOL_SCHEMA,
|
||||
)(_read_file_handler)
|
||||
sdk_tools.append(read_tool)
|
||||
# Read tool for SDK-truncated tool results (always needed).
|
||||
read_tool = tool(
|
||||
_READ_TOOL_NAME,
|
||||
_READ_TOOL_DESCRIPTION,
|
||||
_READ_TOOL_SCHEMA,
|
||||
)(_truncating(_read_file_handler, _READ_TOOL_NAME))
|
||||
sdk_tools.append(read_tool)
|
||||
|
||||
server = create_sdk_mcp_server(
|
||||
name=MCP_SERVER_NAME,
|
||||
version="1.0.0",
|
||||
tools=sdk_tools,
|
||||
)
|
||||
|
||||
return server
|
||||
|
||||
except ImportError:
|
||||
# Let ImportError propagate so service.py handles the fallback
|
||||
raise
|
||||
return create_sdk_mcp_server(
|
||||
name=MCP_SERVER_NAME,
|
||||
version="1.0.0",
|
||||
tools=sdk_tools,
|
||||
)
|
||||
|
||||
|
||||
# SDK built-in tools allowed within the workspace directory.
|
||||
@@ -397,16 +509,11 @@ def create_copilot_mcp_server():
|
||||
# Task allows spawning sub-agents (rate-limited by security hooks).
|
||||
# WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk.
|
||||
# TodoWrite manages the task checklist shown in the UI — no security concern.
|
||||
_SDK_BUILTIN_TOOLS = [
|
||||
"Read",
|
||||
"Write",
|
||||
"Edit",
|
||||
"Glob",
|
||||
"Grep",
|
||||
"Task",
|
||||
"WebSearch",
|
||||
"TodoWrite",
|
||||
]
|
||||
# In E2B mode, all five are disabled — MCP equivalents provide direct sandbox
|
||||
# access. read_file also handles local tool-results and ephemeral reads.
|
||||
_SDK_BUILTIN_FILE_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep"]
|
||||
_SDK_BUILTIN_ALWAYS = ["Task", "WebSearch", "TodoWrite"]
|
||||
_SDK_BUILTIN_TOOLS = [*_SDK_BUILTIN_FILE_TOOLS, *_SDK_BUILTIN_ALWAYS]
|
||||
|
||||
# SDK built-in tools that must be explicitly blocked.
|
||||
# Bash: dangerous — agent uses mcp__copilot__bash_exec with kernel-level
|
||||
@@ -453,11 +560,37 @@ DANGEROUS_PATTERNS = [
|
||||
r"subprocess",
|
||||
]
|
||||
|
||||
# List of tool names for allowed_tools configuration
|
||||
# Include MCP tools, the MCP Read tool for oversized results,
|
||||
# and SDK built-in file tools for workspace operations.
|
||||
# Static tool name list for the non-E2B case (backward compatibility).
|
||||
COPILOT_TOOL_NAMES = [
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
|
||||
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
|
||||
*_SDK_BUILTIN_TOOLS,
|
||||
]
|
||||
|
||||
|
||||
def get_copilot_tool_names(*, use_e2b: bool = False) -> list[str]:
|
||||
"""Build the ``allowed_tools`` list for :class:`ClaudeAgentOptions`.
|
||||
|
||||
When *use_e2b* is True the SDK built-in file tools are replaced by MCP
|
||||
equivalents that route to the E2B sandbox.
|
||||
"""
|
||||
if not use_e2b:
|
||||
return list(COPILOT_TOOL_NAMES)
|
||||
|
||||
return [
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
|
||||
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in E2B_FILE_TOOL_NAMES],
|
||||
*_SDK_BUILTIN_ALWAYS,
|
||||
]
|
||||
|
||||
|
||||
def get_sdk_disallowed_tools(*, use_e2b: bool = False) -> list[str]:
|
||||
"""Build the ``disallowed_tools`` list for :class:`ClaudeAgentOptions`.
|
||||
|
||||
When *use_e2b* is True the SDK built-in file tools are also disabled
|
||||
because MCP equivalents provide direct sandbox access.
|
||||
"""
|
||||
if not use_e2b:
|
||||
return list(SDK_DISALLOWED_TOOLS)
|
||||
return [*SDK_DISALLOWED_TOOLS, *_SDK_BUILTIN_FILE_TOOLS]
|
||||
|
||||
@@ -0,0 +1,145 @@
|
||||
"""Tests for tool_adapter helpers: _text_from_mcp_result, truncation stash."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
from .tool_adapter import (
|
||||
_MCP_MAX_CHARS,
|
||||
_text_from_mcp_result,
|
||||
pop_pending_tool_output,
|
||||
set_execution_context,
|
||||
stash_pending_tool_output,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _text_from_mcp_result
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTextFromMcpResult:
|
||||
def test_single_text_block(self):
|
||||
result = {"content": [{"type": "text", "text": "hello"}]}
|
||||
assert _text_from_mcp_result(result) == "hello"
|
||||
|
||||
def test_multiple_text_blocks_concatenated(self):
|
||||
result = {
|
||||
"content": [
|
||||
{"type": "text", "text": "one"},
|
||||
{"type": "text", "text": "two"},
|
||||
]
|
||||
}
|
||||
assert _text_from_mcp_result(result) == "onetwo"
|
||||
|
||||
def test_non_text_blocks_ignored(self):
|
||||
result = {
|
||||
"content": [
|
||||
{"type": "image", "data": "..."},
|
||||
{"type": "text", "text": "only this"},
|
||||
]
|
||||
}
|
||||
assert _text_from_mcp_result(result) == "only this"
|
||||
|
||||
def test_empty_content_list(self):
|
||||
assert _text_from_mcp_result({"content": []}) == ""
|
||||
|
||||
def test_missing_content_key(self):
|
||||
assert _text_from_mcp_result({}) == ""
|
||||
|
||||
def test_non_list_content(self):
|
||||
assert _text_from_mcp_result({"content": "raw string"}) == ""
|
||||
|
||||
def test_missing_text_field(self):
|
||||
result = {"content": [{"type": "text"}]}
|
||||
assert _text_from_mcp_result(result) == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# stash / pop round-trip (the mechanism _truncating relies on)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToolOutputStash:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init_context(self):
|
||||
"""Initialise the context vars that stash_pending_tool_output needs."""
|
||||
set_execution_context(
|
||||
user_id="test",
|
||||
session=None, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
sdk_cwd="/tmp/test",
|
||||
)
|
||||
|
||||
def test_stash_and_pop(self):
|
||||
stash_pending_tool_output("my_tool", "output1")
|
||||
assert pop_pending_tool_output("my_tool") == "output1"
|
||||
|
||||
def test_pop_empty_returns_none(self):
|
||||
assert pop_pending_tool_output("nonexistent") is None
|
||||
|
||||
def test_fifo_order(self):
|
||||
stash_pending_tool_output("t", "first")
|
||||
stash_pending_tool_output("t", "second")
|
||||
assert pop_pending_tool_output("t") == "first"
|
||||
assert pop_pending_tool_output("t") == "second"
|
||||
assert pop_pending_tool_output("t") is None
|
||||
|
||||
def test_dict_serialised_to_json(self):
|
||||
stash_pending_tool_output("t", {"key": "value"})
|
||||
assert pop_pending_tool_output("t") == '{"key": "value"}'
|
||||
|
||||
def test_separate_tool_names(self):
|
||||
stash_pending_tool_output("a", "alpha")
|
||||
stash_pending_tool_output("b", "beta")
|
||||
assert pop_pending_tool_output("b") == "beta"
|
||||
assert pop_pending_tool_output("a") == "alpha"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _truncating wrapper (integration via create_copilot_mcp_server)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTruncationAndStashIntegration:
|
||||
"""Test truncation + stash behavior that _truncating relies on."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init_context(self):
|
||||
set_execution_context(
|
||||
user_id="test",
|
||||
session=None, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
sdk_cwd="/tmp/test",
|
||||
)
|
||||
|
||||
def test_small_output_stashed(self):
|
||||
"""Non-error output is stashed for the response adapter."""
|
||||
result = {
|
||||
"content": [{"type": "text", "text": "small output"}],
|
||||
"isError": False,
|
||||
}
|
||||
truncated = truncate(result, _MCP_MAX_CHARS)
|
||||
text = _text_from_mcp_result(truncated)
|
||||
assert text == "small output"
|
||||
stash_pending_tool_output("test_tool", text)
|
||||
assert pop_pending_tool_output("test_tool") == "small output"
|
||||
|
||||
def test_error_result_not_stashed(self):
|
||||
"""Error results should not be stashed."""
|
||||
result = {
|
||||
"content": [{"type": "text", "text": "error msg"}],
|
||||
"isError": True,
|
||||
}
|
||||
# _truncating only stashes when not result.get("isError")
|
||||
if not result.get("isError"):
|
||||
stash_pending_tool_output("err_tool", "should not happen")
|
||||
assert pop_pending_tool_output("err_tool") is None
|
||||
|
||||
def test_large_output_truncated(self):
|
||||
"""Output exceeding _MCP_MAX_CHARS is truncated before stashing."""
|
||||
big_text = "x" * (_MCP_MAX_CHARS + 100_000)
|
||||
result = {"content": [{"type": "text", "text": big_text}]}
|
||||
truncated = truncate(result, _MCP_MAX_CHARS)
|
||||
text = _text_from_mcp_result(truncated)
|
||||
assert len(text) < len(big_text)
|
||||
assert len(str(truncated)) <= _MCP_MAX_CHARS
|
||||
@@ -59,7 +59,7 @@ from .response_model import (
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from .tools import execute_tool, tools
|
||||
from .tools import execute_tool, get_available_tools
|
||||
from .tools.models import ErrorResponse
|
||||
from .tracking import track_user_message
|
||||
|
||||
@@ -514,7 +514,7 @@ async def stream_chat_completion(
|
||||
)
|
||||
async for chunk in _stream_chat_chunks(
|
||||
session=session,
|
||||
tools=tools,
|
||||
tools=get_available_tools(),
|
||||
system_prompt=system_prompt,
|
||||
text_block_id=text_block_id,
|
||||
):
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tracking import track_tool_called
|
||||
|
||||
from .add_understanding import AddUnderstandingTool
|
||||
from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreenshotTool
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
@@ -30,6 +32,7 @@ from .workspace_files import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -50,6 +53,10 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"get_doc_page": GetDocPageTool(),
|
||||
# Web fetch for safe URL retrieval
|
||||
"web_fetch": WebFetchTool(),
|
||||
# Agent-browser multi-step automation (navigate, act, screenshot)
|
||||
"browser_navigate": BrowserNavigateTool(),
|
||||
"browser_act": BrowserActTool(),
|
||||
"browser_screenshot": BrowserScreenshotTool(),
|
||||
# Sandboxed code execution (bubblewrap)
|
||||
"bash_exec": BashExecTool(),
|
||||
# Persistent workspace tools (cloud storage, survives across sessions)
|
||||
@@ -67,10 +74,17 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
find_agent_tool = TOOL_REGISTRY["find_agent"]
|
||||
run_agent_tool = TOOL_REGISTRY["run_agent"]
|
||||
|
||||
# Generated from registry for OpenAI API
|
||||
tools: list[ChatCompletionToolParam] = [
|
||||
tool.as_openai_tool() for tool in TOOL_REGISTRY.values()
|
||||
]
|
||||
|
||||
def get_available_tools() -> list[ChatCompletionToolParam]:
|
||||
"""Return OpenAI tool schemas for tools available in the current environment.
|
||||
|
||||
Called per-request so that env-var or binary availability is evaluated
|
||||
fresh each time (e.g. browser_* tools are excluded when agent-browser
|
||||
CLI is not installed).
|
||||
"""
|
||||
return [
|
||||
tool.as_openai_tool() for tool in TOOL_REGISTRY.values() if tool.is_available
|
||||
]
|
||||
|
||||
|
||||
def get_tool(tool_name: str) -> BaseTool | None:
|
||||
|
||||
876
autogpt_platform/backend/backend/copilot/tools/agent_browser.py
Normal file
876
autogpt_platform/backend/backend/copilot/tools/agent_browser.py
Normal file
@@ -0,0 +1,876 @@
|
||||
"""Agent-browser tools — multi-step browser automation for the Copilot.
|
||||
|
||||
Uses the agent-browser CLI (https://github.com/vercel-labs/agent-browser)
|
||||
which runs a local Chromium instance managed by a persistent daemon.
|
||||
|
||||
- Runs locally — no cloud account required
|
||||
- Full interaction support: click, fill, scroll, login flows, multi-step
|
||||
- Session persistence via --session-name: cookies/auth carry across tool calls
|
||||
within the same Copilot session, enabling login → navigate → extract workflows
|
||||
- Screenshot with --annotate overlays @ref labels, saved to workspace for user
|
||||
- The Claude Agent SDK's multi-turn loop handles orchestration — each tool call
|
||||
is one browser action; the LLM chains them naturally
|
||||
|
||||
SSRF protection:
|
||||
Uses the shared validate_url() from backend.util.request, which is the same
|
||||
guard used by HTTP blocks and web_fetch. It resolves ALL DNS answers (not just
|
||||
the first), blocks RFC 1918, loopback, link-local, 0.0.0.0/8, multicast,
|
||||
and all relevant IPv6 ranges, and applies IDNA encoding to prevent Unicode
|
||||
domain attacks.
|
||||
|
||||
Requires:
|
||||
npm install -g agent-browser
|
||||
agent-browser install (downloads Chromium, one-time per machine)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.util.request import validate_url
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
BrowserActResponse,
|
||||
BrowserNavigateResponse,
|
||||
BrowserScreenshotResponse,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .workspace_files import get_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Per-command timeout (seconds). Navigation + networkidle wait can be slow.
|
||||
_CMD_TIMEOUT = 45
|
||||
# Accessibility tree can be very large; cap it to keep LLM context manageable.
|
||||
_MAX_SNAPSHOT_CHARS = 20_000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subprocess helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _run(
|
||||
session_name: str,
|
||||
*args: str,
|
||||
timeout: int = _CMD_TIMEOUT,
|
||||
) -> tuple[int, str, str]:
|
||||
"""Run agent-browser for the given session and return (rc, stdout, stderr).
|
||||
|
||||
Uses both:
|
||||
--session <name> → isolated Chromium context (no shared history/cookies
|
||||
with other Copilot sessions — prevents cross-session
|
||||
browser state leakage)
|
||||
--session-name <name> → persist cookies/localStorage across tool calls within
|
||||
the same session (enables login → navigate flows)
|
||||
"""
|
||||
cmd = [
|
||||
"agent-browser",
|
||||
"--session",
|
||||
session_name,
|
||||
"--session-name",
|
||||
session_name,
|
||||
*args,
|
||||
]
|
||||
proc = None
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout)
|
||||
return proc.returncode or 0, stdout.decode(), stderr.decode()
|
||||
except asyncio.TimeoutError:
|
||||
# Kill the orphaned subprocess so it does not linger in the process table.
|
||||
if proc is not None and proc.returncode is None:
|
||||
proc.kill()
|
||||
try:
|
||||
await proc.communicate()
|
||||
except Exception:
|
||||
pass # Best-effort reap; ignore errors during cleanup.
|
||||
return 1, "", f"Command timed out after {timeout}s."
|
||||
except FileNotFoundError:
|
||||
return (
|
||||
1,
|
||||
"",
|
||||
"agent-browser is not installed (run: npm install -g agent-browser && agent-browser install).",
|
||||
)
|
||||
|
||||
|
||||
async def _snapshot(session_name: str) -> str:
|
||||
"""Return the current page's interactive accessibility tree, truncated."""
|
||||
rc, stdout, stderr = await _run(session_name, "snapshot", "-i", "-c")
|
||||
if rc != 0:
|
||||
return f"[snapshot failed: {stderr[:300]}]"
|
||||
text = stdout.strip()
|
||||
if len(text) > _MAX_SNAPSHOT_CHARS:
|
||||
suffix = "\n\n[Snapshot truncated — use browser_act to navigate further]"
|
||||
keep = max(0, _MAX_SNAPSHOT_CHARS - len(suffix))
|
||||
text = text[:keep] + suffix
|
||||
return text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stateless session helpers — persist / restore browser state across pods
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Module-level cache of sessions known to be alive on this pod.
|
||||
# Avoids the subprocess probe on every tool call within the same pod.
|
||||
_alive_sessions: set[str] = set()
|
||||
|
||||
# Per-session locks to prevent concurrent _ensure_session calls from
|
||||
# triggering duplicate _restore_browser_state for the same session.
|
||||
# Protected by _session_locks_mutex to ensure setdefault/pop are not
|
||||
# interleaved across await boundaries.
|
||||
_session_locks: dict[str, asyncio.Lock] = {}
|
||||
_session_locks_mutex = asyncio.Lock()
|
||||
|
||||
# Workspace filename for persisted browser state (auto-scoped to session).
|
||||
# Dot-prefixed so it is hidden from user workspace listings.
|
||||
_STATE_FILENAME = "._browser_state.json"
|
||||
|
||||
# Maximum concurrent subprocesses during cookie/storage restore.
|
||||
_RESTORE_CONCURRENCY = 10
|
||||
|
||||
# Maximum cookies to restore per session. Pathological sites can accumulate
|
||||
# thousands of cookies; restoring them all would be slow and is rarely useful.
|
||||
_MAX_RESTORE_COOKIES = 100
|
||||
|
||||
# Background tasks for fire-and-forget state persistence.
|
||||
# Prevents GC from collecting tasks before they complete.
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
def _fire_and_forget_save(
|
||||
session_name: str, user_id: str, session: ChatSession
|
||||
) -> None:
|
||||
"""Schedule state persistence as a background task (non-blocking).
|
||||
|
||||
State save is already best-effort (errors are swallowed), so running it
|
||||
in the background avoids adding latency to tool responses.
|
||||
"""
|
||||
task = asyncio.create_task(_save_browser_state(session_name, user_id, session))
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
|
||||
async def _has_local_session(session_name: str) -> bool:
|
||||
"""Check if the local agent-browser daemon for this session is running."""
|
||||
rc, _, _ = await _run(session_name, "get", "url", timeout=5)
|
||||
return rc == 0
|
||||
|
||||
|
||||
async def _save_browser_state(
|
||||
session_name: str, user_id: str, session: ChatSession
|
||||
) -> None:
|
||||
"""Persist browser state (cookies, localStorage, URL) to workspace.
|
||||
|
||||
Best-effort: errors are logged but never propagate to the tool response.
|
||||
"""
|
||||
try:
|
||||
# Gather state in parallel
|
||||
(rc_url, url_out, _), (rc_ck, ck_out, _), (rc_ls, ls_out, _) = (
|
||||
await asyncio.gather(
|
||||
_run(session_name, "get", "url", timeout=10),
|
||||
_run(session_name, "cookies", "get", "--json", timeout=10),
|
||||
_run(session_name, "storage", "local", "--json", timeout=10),
|
||||
)
|
||||
)
|
||||
|
||||
state = {
|
||||
"url": url_out.strip() if rc_url == 0 else "",
|
||||
"cookies": (json.loads(ck_out) if rc_ck == 0 and ck_out.strip() else []),
|
||||
"local_storage": (
|
||||
json.loads(ls_out) if rc_ls == 0 and ls_out.strip() else {}
|
||||
),
|
||||
}
|
||||
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
await manager.write_file(
|
||||
content=json.dumps(state).encode("utf-8"),
|
||||
filename=_STATE_FILENAME,
|
||||
mime_type="application/json",
|
||||
overwrite=True,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[browser] Failed to save browser state for session %s",
|
||||
session_name,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def _restore_browser_state(
|
||||
session_name: str, user_id: str, session: ChatSession
|
||||
) -> bool:
|
||||
"""Restore browser state from workspace storage into a fresh daemon.
|
||||
|
||||
Best-effort: errors are logged but never propagate to the tool response.
|
||||
Returns True on success (or no state to restore), False on failure.
|
||||
"""
|
||||
try:
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
|
||||
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
|
||||
if file_info is None:
|
||||
return True # No saved state — first call or never saved
|
||||
|
||||
state_bytes = await manager.read_file(_STATE_FILENAME)
|
||||
state = json.loads(state_bytes.decode("utf-8"))
|
||||
|
||||
url = state.get("url", "")
|
||||
cookies = state.get("cookies", [])
|
||||
local_storage = state.get("local_storage", {})
|
||||
|
||||
# Navigate first — starts daemon + sets the correct origin for cookies
|
||||
if url:
|
||||
# Validate the saved URL to prevent SSRF via stored redirect targets.
|
||||
try:
|
||||
await validate_url(url, trusted_origins=[])
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"[browser] State restore: blocked SSRF URL %s", url[:200]
|
||||
)
|
||||
return False
|
||||
|
||||
rc, _, stderr = await _run(session_name, "open", url)
|
||||
if rc != 0:
|
||||
logger.warning(
|
||||
"[browser] State restore: failed to open %s: %s",
|
||||
url,
|
||||
stderr[:200],
|
||||
)
|
||||
return False
|
||||
await _run(session_name, "wait", "--load", "load", timeout=15)
|
||||
|
||||
# Restore cookies and localStorage in parallel via asyncio.gather.
|
||||
# Semaphore caps concurrent subprocess spawns so we don't overwhelm the
|
||||
# system when a session has hundreds of cookies.
|
||||
sem = asyncio.Semaphore(_RESTORE_CONCURRENCY)
|
||||
|
||||
# Guard against pathological sites with thousands of cookies.
|
||||
if len(cookies) > _MAX_RESTORE_COOKIES:
|
||||
logger.debug(
|
||||
"[browser] State restore: capping cookies from %d to %d",
|
||||
len(cookies),
|
||||
_MAX_RESTORE_COOKIES,
|
||||
)
|
||||
cookies = cookies[:_MAX_RESTORE_COOKIES]
|
||||
|
||||
async def _set_cookie(c: dict[str, Any]) -> None:
|
||||
name = c.get("name", "")
|
||||
value = c.get("value", "")
|
||||
domain = c.get("domain", "")
|
||||
path = c.get("path", "/")
|
||||
if not (name and domain):
|
||||
return
|
||||
async with sem:
|
||||
rc, _, stderr = await _run(
|
||||
session_name,
|
||||
"cookies",
|
||||
"set",
|
||||
name,
|
||||
value,
|
||||
"--domain",
|
||||
domain,
|
||||
"--path",
|
||||
path,
|
||||
timeout=5,
|
||||
)
|
||||
if rc != 0:
|
||||
logger.debug(
|
||||
"[browser] State restore: cookie set failed for %s: %s",
|
||||
name,
|
||||
stderr[:100],
|
||||
)
|
||||
|
||||
async def _set_storage(key: str, val: object) -> None:
|
||||
async with sem:
|
||||
rc, _, stderr = await _run(
|
||||
session_name,
|
||||
"storage",
|
||||
"local",
|
||||
"set",
|
||||
key,
|
||||
str(val),
|
||||
timeout=5,
|
||||
)
|
||||
if rc != 0:
|
||||
logger.debug(
|
||||
"[browser] State restore: localStorage set failed for %s: %s",
|
||||
key,
|
||||
stderr[:100],
|
||||
)
|
||||
|
||||
await asyncio.gather(
|
||||
*[_set_cookie(c) for c in cookies],
|
||||
*[_set_storage(k, v) for k, v in local_storage.items()],
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[browser] Failed to restore browser state for session %s",
|
||||
session_name,
|
||||
exc_info=True,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def _ensure_session(
|
||||
session_name: str, user_id: str, session: ChatSession
|
||||
) -> None:
|
||||
"""Ensure the local browser daemon has state. Restore from cloud if needed."""
|
||||
if session_name in _alive_sessions:
|
||||
return
|
||||
async with _session_locks_mutex:
|
||||
lock = _session_locks.setdefault(session_name, asyncio.Lock())
|
||||
async with lock:
|
||||
# Double-check after acquiring lock — another coroutine may have restored.
|
||||
if session_name in _alive_sessions:
|
||||
return
|
||||
if await _has_local_session(session_name):
|
||||
_alive_sessions.add(session_name)
|
||||
return
|
||||
if await _restore_browser_state(session_name, user_id, session):
|
||||
_alive_sessions.add(session_name)
|
||||
|
||||
|
||||
async def close_browser_session(session_name: str, user_id: str | None = None) -> None:
|
||||
"""Shut down the local agent-browser daemon and clean up stored state.
|
||||
|
||||
Deletes ``._browser_state.json`` from workspace storage so cookies and
|
||||
other credentials do not linger after the session is deleted.
|
||||
|
||||
Best-effort: errors are logged but never raised.
|
||||
"""
|
||||
_alive_sessions.discard(session_name)
|
||||
async with _session_locks_mutex:
|
||||
_session_locks.pop(session_name, None)
|
||||
|
||||
# Delete persisted browser state (cookies, localStorage) from workspace.
|
||||
if user_id:
|
||||
try:
|
||||
manager = await get_manager(user_id, session_name)
|
||||
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
|
||||
if file_info is not None:
|
||||
await manager.delete_file(file_info.id)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"[browser] Failed to delete state file for session %s",
|
||||
session_name,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
rc, _, stderr = await _run(session_name, "close", timeout=10)
|
||||
if rc != 0:
|
||||
logger.debug(
|
||||
"[browser] close failed for session %s: %s",
|
||||
session_name,
|
||||
stderr[:200],
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"[browser] Exception closing browser session %s",
|
||||
session_name,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool: browser_navigate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BrowserNavigateTool(BaseTool):
|
||||
"""Navigate to a URL and return the page's interactive elements.
|
||||
|
||||
The browser session persists across tool calls within this Copilot session
|
||||
(keyed to session_id), so cookies and auth state carry over. This enables
|
||||
full login flows: navigate to login page → browser_act to fill credentials
|
||||
→ browser_act to submit → browser_navigate to the target page.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "browser_navigate"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Navigate to a URL using a real browser. Returns an accessibility "
|
||||
"tree snapshot listing the page's interactive elements with @ref IDs "
|
||||
"(e.g. @e3) that can be used with browser_act. "
|
||||
"Session persists — cookies and login state carry over between calls. "
|
||||
"Use this (with browser_act) for multi-step interaction: login flows, "
|
||||
"form filling, button clicks, or anything requiring page interaction. "
|
||||
"For plain static pages, prefer web_fetch — no browser overhead. "
|
||||
"For authenticated pages: navigate to the login page first, use browser_act "
|
||||
"to fill credentials and submit, then navigate to the target page. "
|
||||
"Note: for slow SPAs, the returned snapshot may reflect a partially-loaded "
|
||||
"state. If elements seem missing, use browser_act with action='wait' and a "
|
||||
"CSS selector or millisecond delay, then take a browser_screenshot to verify."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The HTTP/HTTPS URL to navigate to.",
|
||||
},
|
||||
"wait_for": {
|
||||
"type": "string",
|
||||
"enum": ["networkidle", "load", "domcontentloaded"],
|
||||
"default": "networkidle",
|
||||
"description": "When to consider navigation complete. Use 'networkidle' for SPAs (default).",
|
||||
},
|
||||
},
|
||||
"required": ["url"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
return shutil.which("agent-browser") is not None
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
"""Navigate to *url*, wait for the page to settle, and return a snapshot.
|
||||
|
||||
The snapshot is an accessibility-tree listing of interactive elements.
|
||||
Note: for slow SPAs that never fully idle, the snapshot may reflect a
|
||||
partially-loaded state (the wait is best-effort).
|
||||
"""
|
||||
url: str = (kwargs.get("url") or "").strip()
|
||||
wait_for: str = kwargs.get("wait_for") or "networkidle"
|
||||
session_name = session.session_id
|
||||
|
||||
if not url:
|
||||
return ErrorResponse(
|
||||
message="Please provide a URL to navigate to.",
|
||||
error="missing_url",
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
try:
|
||||
await validate_url(url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
return ErrorResponse(
|
||||
message=str(e),
|
||||
error="blocked_url",
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
# Restore browser state from cloud if this is a different pod
|
||||
if user_id:
|
||||
await _ensure_session(session_name, user_id, session)
|
||||
|
||||
# Navigate
|
||||
rc, _, stderr = await _run(session_name, "open", url)
|
||||
if rc != 0:
|
||||
logger.warning(
|
||||
"[browser_navigate] open failed for %s: %s", url, stderr[:300]
|
||||
)
|
||||
return ErrorResponse(
|
||||
message="Failed to navigate to URL.",
|
||||
error="navigation_failed",
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
# Wait for page to settle (best-effort: some SPAs never reach networkidle)
|
||||
wait_rc, _, wait_err = await _run(session_name, "wait", "--load", wait_for)
|
||||
if wait_rc != 0:
|
||||
logger.warning(
|
||||
"[browser_navigate] wait(%s) failed: %s", wait_for, wait_err[:300]
|
||||
)
|
||||
|
||||
# Get current title and URL in parallel
|
||||
(_, title_out, _), (_, url_out, _) = await asyncio.gather(
|
||||
_run(session_name, "get", "title"),
|
||||
_run(session_name, "get", "url"),
|
||||
)
|
||||
|
||||
snapshot = await _snapshot(session_name)
|
||||
|
||||
result = BrowserNavigateResponse(
|
||||
message=f"Navigated to {url}",
|
||||
url=url_out.strip() or url,
|
||||
title=title_out.strip(),
|
||||
snapshot=snapshot,
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
# Persist browser state to cloud for cross-pod continuity
|
||||
if user_id:
|
||||
_fire_and_forget_save(session_name, user_id, session)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool: browser_act
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_NO_TARGET_ACTIONS = frozenset({"back", "forward", "reload"})
|
||||
_SCROLL_ACTIONS = frozenset({"scroll"})
|
||||
_TARGET_ONLY_ACTIONS = frozenset({"click", "dblclick", "hover", "check", "uncheck"})
|
||||
_TARGET_VALUE_ACTIONS = frozenset({"fill", "type", "select"})
|
||||
# wait <selector|ms>: waits for a DOM element or a fixed delay (e.g. "1000" for 1 s)
|
||||
_WAIT_ACTIONS = frozenset({"wait"})
|
||||
|
||||
|
||||
class BrowserActTool(BaseTool):
|
||||
"""Perform an action on the current browser page and return the updated snapshot.
|
||||
|
||||
Use @ref IDs from the snapshot returned by browser_navigate (e.g. '@e3').
|
||||
The LLM orchestrates multi-step flows by chaining browser_navigate and
|
||||
browser_act calls across turns of the Claude Agent SDK conversation.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "browser_act"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Interact with the current browser page. Use @ref IDs from the "
|
||||
"snapshot (e.g. '@e3') to target elements. Returns an updated snapshot. "
|
||||
"Supported actions: click, dblclick, fill, type, scroll, hover, press, "
|
||||
"check, uncheck, select, wait, back, forward, reload. "
|
||||
"fill clears the field before typing; type appends without clearing. "
|
||||
"wait accepts a CSS selector (waits for element) or milliseconds string (e.g. '1000'). "
|
||||
"Example login flow: fill @e1 with email → fill @e2 with password → "
|
||||
"click @e3 (submit) → browser_navigate to the target page."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"click",
|
||||
"dblclick",
|
||||
"fill",
|
||||
"type",
|
||||
"scroll",
|
||||
"hover",
|
||||
"press",
|
||||
"check",
|
||||
"uncheck",
|
||||
"select",
|
||||
"wait",
|
||||
"back",
|
||||
"forward",
|
||||
"reload",
|
||||
],
|
||||
"description": "The action to perform.",
|
||||
},
|
||||
"target": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Element to target. Use @ref from snapshot (e.g. '@e3'), "
|
||||
"a CSS selector, or a text description. "
|
||||
"Required for: click, dblclick, fill, type, hover, check, uncheck, select. "
|
||||
"For wait: a CSS selector to wait for, or milliseconds as a string (e.g. '1000')."
|
||||
),
|
||||
},
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"For fill/type: the text to enter. "
|
||||
"For press: key name (e.g. 'Enter', 'Tab', 'Control+a'). "
|
||||
"For select: the option value to select."
|
||||
),
|
||||
},
|
||||
"direction": {
|
||||
"type": "string",
|
||||
"enum": ["up", "down", "left", "right"],
|
||||
"default": "down",
|
||||
"description": "For scroll: direction to scroll.",
|
||||
},
|
||||
},
|
||||
"required": ["action"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
return shutil.which("agent-browser") is not None
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
"""Perform a browser action and return an updated page snapshot.
|
||||
|
||||
Validates the *action*/*target*/*value* combination, delegates to
|
||||
``agent-browser``, waits for the page to settle, and returns the
|
||||
accessibility-tree snapshot so the LLM can plan the next step.
|
||||
"""
|
||||
action: str = (kwargs.get("action") or "").strip()
|
||||
target: str = (kwargs.get("target") or "").strip()
|
||||
value: str = (kwargs.get("value") or "").strip()
|
||||
direction: str = (kwargs.get("direction") or "down").strip()
|
||||
session_name = session.session_id
|
||||
|
||||
if not action:
|
||||
return ErrorResponse(
|
||||
message="Please specify an action.",
|
||||
error="missing_action",
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
# Build the agent-browser command args
|
||||
if action in _NO_TARGET_ACTIONS:
|
||||
cmd_args = [action]
|
||||
|
||||
elif action in _SCROLL_ACTIONS:
|
||||
cmd_args = ["scroll", direction]
|
||||
|
||||
elif action == "press":
|
||||
if not value:
|
||||
return ErrorResponse(
|
||||
message="'press' requires a 'value' (key name, e.g. 'Enter').",
|
||||
error="missing_value",
|
||||
session_id=session_name,
|
||||
)
|
||||
cmd_args = ["press", value]
|
||||
|
||||
elif action in _TARGET_ONLY_ACTIONS:
|
||||
if not target:
|
||||
return ErrorResponse(
|
||||
message=f"'{action}' requires a 'target' element.",
|
||||
error="missing_target",
|
||||
session_id=session_name,
|
||||
)
|
||||
cmd_args = [action, target]
|
||||
|
||||
elif action in _TARGET_VALUE_ACTIONS:
|
||||
if not target or not value:
|
||||
return ErrorResponse(
|
||||
message=f"'{action}' requires both 'target' and 'value'.",
|
||||
error="missing_params",
|
||||
session_id=session_name,
|
||||
)
|
||||
cmd_args = [action, target, value]
|
||||
|
||||
elif action in _WAIT_ACTIONS:
|
||||
if not target:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"'wait' requires a 'target': a CSS selector to wait for, "
|
||||
"or milliseconds as a string (e.g. '1000')."
|
||||
),
|
||||
error="missing_target",
|
||||
session_id=session_name,
|
||||
)
|
||||
cmd_args = ["wait", target]
|
||||
|
||||
else:
|
||||
return ErrorResponse(
|
||||
message=f"Unsupported action: {action}",
|
||||
error="invalid_action",
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
# Restore browser state from cloud if this is a different pod
|
||||
if user_id:
|
||||
await _ensure_session(session_name, user_id, session)
|
||||
|
||||
rc, _, stderr = await _run(session_name, *cmd_args)
|
||||
if rc != 0:
|
||||
logger.warning("[browser_act] %s failed: %s", action, stderr[:300])
|
||||
return ErrorResponse(
|
||||
message=f"Action '{action}' failed.",
|
||||
error="action_failed",
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
# Allow the page to settle after interaction (best-effort: SPAs may not idle)
|
||||
settle_rc, _, settle_err = await _run(
|
||||
session_name, "wait", "--load", "networkidle"
|
||||
)
|
||||
if settle_rc != 0:
|
||||
logger.warning(
|
||||
"[browser_act] post-action wait failed: %s", settle_err[:300]
|
||||
)
|
||||
|
||||
snapshot = await _snapshot(session_name)
|
||||
_, url_out, _ = await _run(session_name, "get", "url")
|
||||
|
||||
result = BrowserActResponse(
|
||||
message=f"Performed '{action}'" + (f" on '{target}'" if target else ""),
|
||||
action=action,
|
||||
current_url=url_out.strip(),
|
||||
snapshot=snapshot,
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
# Persist browser state to cloud for cross-pod continuity
|
||||
if user_id:
|
||||
_fire_and_forget_save(session_name, user_id, session)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool: browser_screenshot
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BrowserScreenshotTool(BaseTool):
|
||||
"""Capture a screenshot of the current browser page and save it to the workspace."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "browser_screenshot"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Take a screenshot of the current browser page and save it to the workspace. "
|
||||
"IMPORTANT: After calling this tool, immediately call read_workspace_file "
|
||||
"with the returned file_id to display the image inline to the user — "
|
||||
"the screenshot is not visible until you do this. "
|
||||
"With annotate=true (default), @ref labels are overlaid on interactive "
|
||||
"elements, making it easy to see which @ref ID maps to which element on screen."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"annotate": {
|
||||
"type": "boolean",
|
||||
"default": True,
|
||||
"description": "Overlay @ref labels on interactive elements (default: true).",
|
||||
},
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"default": "screenshot.png",
|
||||
"description": "Filename to save in the workspace.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
return shutil.which("agent-browser") is not None
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
"""Capture a PNG screenshot and upload it to the workspace.
|
||||
|
||||
Handles string-to-bool coercion for *annotate* (OpenAI function-call
|
||||
payloads sometimes deliver ``"true"``/``"false"`` as strings).
|
||||
Returns a :class:`BrowserScreenshotResponse` with the workspace
|
||||
``file_id`` the LLM should pass to ``read_workspace_file``.
|
||||
"""
|
||||
raw_annotate = kwargs.get("annotate", True)
|
||||
if isinstance(raw_annotate, str):
|
||||
annotate = raw_annotate.strip().lower() in {"1", "true", "yes", "on"}
|
||||
else:
|
||||
annotate = bool(raw_annotate)
|
||||
filename: str = (kwargs.get("filename") or "screenshot.png").strip()
|
||||
session_name = session.session_id
|
||||
|
||||
# Restore browser state from cloud if this is a different pod
|
||||
if user_id:
|
||||
await _ensure_session(session_name, user_id, session)
|
||||
|
||||
tmp_fd, tmp_path = tempfile.mkstemp(suffix=".png")
|
||||
os.close(tmp_fd)
|
||||
try:
|
||||
cmd_args = ["screenshot"]
|
||||
if annotate:
|
||||
cmd_args.append("--annotate")
|
||||
cmd_args.append(tmp_path)
|
||||
|
||||
rc, _, stderr = await _run(session_name, *cmd_args)
|
||||
if rc != 0:
|
||||
logger.warning("[browser_screenshot] failed: %s", stderr[:300])
|
||||
return ErrorResponse(
|
||||
message="Failed to take screenshot.",
|
||||
error="screenshot_failed",
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
with open(tmp_path, "rb") as f:
|
||||
png_bytes = f.read()
|
||||
|
||||
finally:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass # Best-effort temp file cleanup; not critical if it fails.
|
||||
|
||||
# Upload to workspace so the user can view it
|
||||
png_b64 = base64.b64encode(png_bytes).decode()
|
||||
|
||||
# Import here to avoid circular deps — workspace_files imports from .models
|
||||
from .workspace_files import WorkspaceWriteResponse, WriteWorkspaceFileTool
|
||||
|
||||
write_resp = await WriteWorkspaceFileTool()._execute(
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
filename=filename,
|
||||
content_base64=png_b64,
|
||||
)
|
||||
|
||||
if not isinstance(write_resp, WorkspaceWriteResponse):
|
||||
return ErrorResponse(
|
||||
message="Screenshot taken but failed to save to workspace.",
|
||||
error="workspace_write_failed",
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
result = BrowserScreenshotResponse(
|
||||
message=f"Screenshot saved to workspace as '{filename}'. Use read_workspace_file with file_id='{write_resp.file_id}' to retrieve it.",
|
||||
file_id=write_resp.file_id,
|
||||
filename=filename,
|
||||
session_id=session_name,
|
||||
)
|
||||
|
||||
# Persist browser state to cloud for cross-pod continuity
|
||||
if user_id:
|
||||
_fire_and_forget_save(session_name, user_id, session)
|
||||
|
||||
return result
|
||||
1663
autogpt_platform/backend/backend/copilot/tools/agent_browser_test.py
Normal file
1663
autogpt_platform/backend/backend/copilot/tools/agent_browser_test.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -36,6 +36,16 @@ class BaseTool:
|
||||
"""Whether this tool requires authentication."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
"""Whether this tool is available in the current environment.
|
||||
|
||||
Override to check required env vars, binaries, or other dependencies.
|
||||
Unavailable tools are excluded from the LLM tool list so the model is
|
||||
never offered an option that will immediately fail.
|
||||
"""
|
||||
return True
|
||||
|
||||
def as_openai_tool(self) -> ChatCompletionToolParam:
|
||||
"""Convert to OpenAI tool format."""
|
||||
return ChatCompletionToolParam(
|
||||
|
||||
@@ -1,19 +1,30 @@
|
||||
"""Bash execution tool — run shell commands in a bubblewrap sandbox.
|
||||
"""Bash execution tool — run shell commands on E2B or in a bubblewrap sandbox.
|
||||
|
||||
Full Bash scripting is allowed (loops, conditionals, pipes, functions, etc.).
|
||||
Safety comes from OS-level isolation (bubblewrap): only system dirs visible
|
||||
read-only, writable workspace only, clean env, no network.
|
||||
When an E2B sandbox is available in the current execution context the command
|
||||
runs directly on the remote E2B cloud environment. This means:
|
||||
|
||||
Requires bubblewrap (``bwrap``) — the tool is disabled when bwrap is not
|
||||
available (e.g. macOS development).
|
||||
- **Persistent filesystem**: files survive across turns via HTTP-based sync
|
||||
with the sandbox's ``/home/user`` directory (E2B files API), shared with
|
||||
SDK Read/Write/Edit tools.
|
||||
- **Full internet access**: E2B sandboxes have unrestricted outbound network.
|
||||
- **Execution isolation**: E2B provides a fresh, containerised Linux environment.
|
||||
|
||||
When E2B is *not* configured the tool falls back to **bubblewrap** (bwrap):
|
||||
OS-level isolation with a whitelist-only filesystem, no network, and resource
|
||||
limits. Requires bubblewrap to be installed (Linux only).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import shlex
|
||||
from typing import Any
|
||||
|
||||
from e2b import AsyncSandbox
|
||||
from e2b.exceptions import TimeoutException
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .e2b_sandbox import E2B_WORKDIR
|
||||
from .models import BashExecResponse, ErrorResponse, ToolResponseBase
|
||||
from .sandbox import get_workspace_dir, has_full_sandbox, run_sandboxed
|
||||
|
||||
@@ -21,7 +32,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BashExecTool(BaseTool):
|
||||
"""Execute Bash commands in a bubblewrap sandbox."""
|
||||
"""Execute Bash commands on E2B or in a bubblewrap sandbox."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -29,28 +40,16 @@ class BashExecTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
if not has_full_sandbox():
|
||||
return (
|
||||
"Bash execution is DISABLED — bubblewrap sandbox is not "
|
||||
"available on this platform. Do not call this tool."
|
||||
)
|
||||
return (
|
||||
"Execute a Bash command or script in a bubblewrap sandbox. "
|
||||
"Execute a Bash command or script. "
|
||||
"Full Bash scripting is supported (loops, conditionals, pipes, "
|
||||
"functions, etc.). "
|
||||
"The sandbox shares the same working directory as the SDK Read/Write "
|
||||
"tools — files created by either are accessible to both. "
|
||||
"SECURITY: Only system directories (/usr, /bin, /lib, /etc) are "
|
||||
"visible read-only, the per-session workspace is the only writable "
|
||||
"path, environment variables are wiped (no secrets), all network "
|
||||
"access is blocked at the kernel level, and resource limits are "
|
||||
"enforced (max 64 processes, 512MB memory, 50MB file size). "
|
||||
"Application code, configs, and other directories are NOT accessible. "
|
||||
"To fetch web content, use the web_fetch tool instead. "
|
||||
"The working directory is shared with the SDK Read/Write/Edit/Glob/Grep "
|
||||
"tools — files created by either are immediately visible to both. "
|
||||
"Execution is killed after the timeout (default 30s, max 120s). "
|
||||
"Returns stdout and stderr. "
|
||||
"Useful for file manipulation, data processing with Unix tools "
|
||||
"(grep, awk, sed, jq, etc.), and running shell scripts."
|
||||
"Useful for file manipulation, data processing, running scripts, "
|
||||
"and installing packages."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -85,15 +84,8 @@ class BashExecTool(BaseTool):
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not has_full_sandbox():
|
||||
return ErrorResponse(
|
||||
message="bash_exec requires bubblewrap sandbox (Linux only).",
|
||||
error="sandbox_unavailable",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
command: str = (kwargs.get("command") or "").strip()
|
||||
timeout: int = kwargs.get("timeout", 30)
|
||||
timeout: int = int(kwargs.get("timeout", 30))
|
||||
|
||||
if not command:
|
||||
return ErrorResponse(
|
||||
@@ -102,6 +94,21 @@ class BashExecTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# E2B path: run on remote cloud sandbox when available.
|
||||
from backend.copilot.sdk.tool_adapter import get_current_sandbox
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
return await self._execute_on_e2b(sandbox, command, timeout, session_id)
|
||||
|
||||
# Bubblewrap fallback: local isolated execution.
|
||||
if not has_full_sandbox():
|
||||
return ErrorResponse(
|
||||
message="bash_exec requires bubblewrap sandbox (Linux only).",
|
||||
error="sandbox_unavailable",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
workspace = get_workspace_dir(session_id or "default")
|
||||
|
||||
stdout, stderr, exit_code, timed_out = await run_sandboxed(
|
||||
@@ -122,3 +129,43 @@ class BashExecTool(BaseTool):
|
||||
timed_out=timed_out,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _execute_on_e2b(
|
||||
self,
|
||||
sandbox: AsyncSandbox,
|
||||
command: str,
|
||||
timeout: int,
|
||||
session_id: str | None,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute *command* on the E2B sandbox via commands.run()."""
|
||||
try:
|
||||
result = await sandbox.commands.run(
|
||||
f"bash -c {shlex.quote(command)}",
|
||||
cwd=E2B_WORKDIR,
|
||||
timeout=timeout,
|
||||
envs={"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"},
|
||||
)
|
||||
return BashExecResponse(
|
||||
message=f"Command executed on E2B (exit {result.exit_code})",
|
||||
stdout=result.stdout or "",
|
||||
stderr=result.stderr or "",
|
||||
exit_code=result.exit_code,
|
||||
timed_out=False,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
if isinstance(exc, TimeoutException):
|
||||
return BashExecResponse(
|
||||
message="Execution timed out",
|
||||
stdout="",
|
||||
stderr=f"Timed out after {timeout}s",
|
||||
exit_code=-1,
|
||||
timed_out=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
logger.error("[E2B] bash_exec failed: %s", exc, exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"E2B execution failed: {exc}",
|
||||
error="e2b_execution_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
170
autogpt_platform/backend/backend/copilot/tools/e2b_sandbox.py
Normal file
170
autogpt_platform/backend/backend/copilot/tools/e2b_sandbox.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""E2B sandbox lifecycle for CoPilot: persistent cloud execution.
|
||||
|
||||
Each session gets a long-lived E2B cloud sandbox. ``bash_exec`` runs commands
|
||||
directly on the sandbox via ``sandbox.commands.run()``. SDK file tools
|
||||
(read_file/write_file/edit_file/glob/grep) route to the sandbox's
|
||||
``/home/user`` directory via E2B's HTTP-based filesystem API — all tools
|
||||
share a single coherent filesystem with no local sync required.
|
||||
|
||||
Lifecycle
|
||||
---------
|
||||
1. **Turn start** – connect to the existing sandbox (sandbox_id in Redis) or
|
||||
create a new one via ``get_or_create_sandbox()``.
|
||||
2. **Execution** – ``bash_exec`` and MCP file tools operate directly on the
|
||||
sandbox's ``/home/user`` filesystem.
|
||||
3. **Session expiry** – E2B sandbox is killed by its own timeout (session_ttl).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SANDBOX_REDIS_PREFIX = "copilot:e2b:sandbox:"
|
||||
E2B_WORKDIR = "/home/user"
|
||||
_CREATING = "__creating__"
|
||||
_CREATION_LOCK_TTL = 60
|
||||
_MAX_WAIT_ATTEMPTS = 20 # 20 * 0.5s = 10s max wait
|
||||
|
||||
|
||||
async def _try_reconnect(
|
||||
sandbox_id: str, api_key: str, redis_key: str, timeout: int
|
||||
) -> "AsyncSandbox | None":
|
||||
"""Try to reconnect to an existing sandbox. Returns None on failure."""
|
||||
try:
|
||||
sandbox = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
|
||||
if await sandbox.is_running():
|
||||
redis = await get_redis_async()
|
||||
await redis.expire(redis_key, timeout)
|
||||
return sandbox
|
||||
except Exception as exc:
|
||||
logger.warning("[E2B] Reconnect to %.12s failed: %s", sandbox_id, exc)
|
||||
|
||||
# Stale — clear Redis so a new sandbox can be created.
|
||||
redis = await get_redis_async()
|
||||
await redis.delete(redis_key)
|
||||
return None
|
||||
|
||||
|
||||
async def get_or_create_sandbox(
|
||||
session_id: str,
|
||||
api_key: str,
|
||||
template: str = "base",
|
||||
timeout: int = 43200,
|
||||
) -> AsyncSandbox:
|
||||
"""Return the existing E2B sandbox for *session_id* or create a new one.
|
||||
|
||||
The sandbox_id is persisted in Redis so the same sandbox is reused
|
||||
across turns. Concurrent calls for the same session are serialised
|
||||
via a Redis ``SET NX`` creation lock.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
|
||||
|
||||
# 1. Try reconnecting to an existing sandbox.
|
||||
raw = await redis.get(redis_key)
|
||||
if raw:
|
||||
sandbox_id = raw if isinstance(raw, str) else raw.decode()
|
||||
if sandbox_id != _CREATING:
|
||||
sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
|
||||
if sandbox:
|
||||
logger.info(
|
||||
"[E2B] Reconnected to %.12s for session %.12s",
|
||||
sandbox_id,
|
||||
session_id,
|
||||
)
|
||||
return sandbox
|
||||
|
||||
# 2. Claim creation lock. If another request holds it, wait for the result.
|
||||
claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
|
||||
if not claimed:
|
||||
for _ in range(_MAX_WAIT_ATTEMPTS):
|
||||
await asyncio.sleep(0.5)
|
||||
raw = await redis.get(redis_key)
|
||||
if not raw:
|
||||
break # Lock expired — fall through to retry creation
|
||||
sandbox_id = raw if isinstance(raw, str) else raw.decode()
|
||||
if sandbox_id != _CREATING:
|
||||
sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
|
||||
if sandbox:
|
||||
return sandbox
|
||||
break # Stale sandbox cleared — fall through to create
|
||||
|
||||
# Try to claim creation lock again after waiting.
|
||||
claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
|
||||
if not claimed:
|
||||
# Another process may have created a sandbox — try to use it.
|
||||
raw = await redis.get(redis_key)
|
||||
if raw:
|
||||
sandbox_id = raw if isinstance(raw, str) else raw.decode()
|
||||
if sandbox_id != _CREATING:
|
||||
sandbox = await _try_reconnect(
|
||||
sandbox_id, api_key, redis_key, timeout
|
||||
)
|
||||
if sandbox:
|
||||
return sandbox
|
||||
raise RuntimeError(
|
||||
f"Could not acquire E2B creation lock for session {session_id[:12]}"
|
||||
)
|
||||
|
||||
# 3. Create a new sandbox.
|
||||
try:
|
||||
sandbox = await AsyncSandbox.create(
|
||||
template=template, api_key=api_key, timeout=timeout
|
||||
)
|
||||
except Exception:
|
||||
await redis.delete(redis_key)
|
||||
raise
|
||||
|
||||
await redis.setex(redis_key, timeout, sandbox.sandbox_id)
|
||||
logger.info(
|
||||
"[E2B] Created sandbox %.12s for session %.12s",
|
||||
sandbox.sandbox_id,
|
||||
session_id,
|
||||
)
|
||||
return sandbox
|
||||
|
||||
|
||||
async def kill_sandbox(session_id: str, api_key: str) -> bool:
|
||||
"""Kill the E2B sandbox for *session_id* and clean up its Redis entry.
|
||||
|
||||
Returns ``True`` if a sandbox was found and killed, ``False`` otherwise.
|
||||
Safe to call even when no sandbox exists for the session.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
|
||||
raw = await redis.get(redis_key)
|
||||
if not raw:
|
||||
return False
|
||||
|
||||
sandbox_id = raw if isinstance(raw, str) else raw.decode()
|
||||
await redis.delete(redis_key)
|
||||
|
||||
if sandbox_id == _CREATING:
|
||||
return False
|
||||
|
||||
try:
|
||||
|
||||
async def _connect_and_kill():
|
||||
sandbox = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
|
||||
await sandbox.kill()
|
||||
|
||||
await asyncio.wait_for(_connect_and_kill(), timeout=10)
|
||||
logger.info(
|
||||
"[E2B] Killed sandbox %.12s for session %.12s",
|
||||
sandbox_id,
|
||||
session_id,
|
||||
)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[E2B] Failed to kill sandbox %.12s for session %.12s: %s",
|
||||
sandbox_id,
|
||||
session_id,
|
||||
exc,
|
||||
)
|
||||
return False
|
||||
@@ -0,0 +1,272 @@
|
||||
"""Tests for e2b_sandbox: get_or_create_sandbox, _try_reconnect, kill_sandbox.
|
||||
|
||||
Uses mock Redis and mock AsyncSandbox — no external dependencies.
|
||||
Tests are synchronous (using asyncio.run) to avoid conflicts with the
|
||||
session-scoped event loop in conftest.py.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .e2b_sandbox import (
|
||||
_CREATING,
|
||||
_SANDBOX_REDIS_PREFIX,
|
||||
_try_reconnect,
|
||||
get_or_create_sandbox,
|
||||
kill_sandbox,
|
||||
)
|
||||
|
||||
_KEY = f"{_SANDBOX_REDIS_PREFIX}sess-123"
|
||||
_API_KEY = "test-api-key"
|
||||
_TIMEOUT = 300
|
||||
|
||||
|
||||
def _mock_sandbox(sandbox_id: str = "sb-abc", running: bool = True) -> MagicMock:
|
||||
sb = MagicMock()
|
||||
sb.sandbox_id = sandbox_id
|
||||
sb.is_running = AsyncMock(return_value=running)
|
||||
return sb
|
||||
|
||||
|
||||
def _mock_redis(get_val: str | bytes | None = None, set_nx_result: bool = True):
|
||||
r = AsyncMock()
|
||||
r.get = AsyncMock(return_value=get_val)
|
||||
r.set = AsyncMock(return_value=set_nx_result)
|
||||
r.setex = AsyncMock()
|
||||
r.delete = AsyncMock()
|
||||
r.expire = AsyncMock()
|
||||
return r
|
||||
|
||||
|
||||
def _patch_redis(redis):
|
||||
return patch(
|
||||
"backend.copilot.tools.e2b_sandbox.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
return_value=redis,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _try_reconnect
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTryReconnect:
|
||||
def test_reconnect_success(self):
|
||||
sb = _mock_sandbox()
|
||||
redis = _mock_redis()
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
|
||||
|
||||
assert result is sb
|
||||
redis.expire.assert_awaited_once_with(_KEY, _TIMEOUT)
|
||||
redis.delete.assert_not_awaited()
|
||||
|
||||
def test_reconnect_not_running_clears_key(self):
|
||||
sb = _mock_sandbox(running=False)
|
||||
redis = _mock_redis()
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
|
||||
|
||||
assert result is None
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
def test_reconnect_exception_clears_key(self):
|
||||
redis = _mock_redis()
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
|
||||
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
|
||||
|
||||
assert result is None
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_or_create_sandbox
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetOrCreateSandbox:
|
||||
def test_reconnect_existing(self):
|
||||
"""When Redis has a valid sandbox_id, reconnect to it."""
|
||||
sb = _mock_sandbox()
|
||||
redis = _mock_redis(get_val="sb-abc")
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is sb
|
||||
mock_cls.create.assert_not_called()
|
||||
|
||||
def test_create_new_when_no_key(self):
|
||||
"""When Redis is empty, claim lock and create a new sandbox."""
|
||||
sb = _mock_sandbox("sb-new")
|
||||
redis = _mock_redis(get_val=None, set_nx_result=True)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.create = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is sb
|
||||
redis.setex.assert_awaited_once_with(_KEY, _TIMEOUT, "sb-new")
|
||||
|
||||
def test_create_failure_clears_lock(self):
|
||||
"""If sandbox creation fails, the Redis lock is deleted."""
|
||||
redis = _mock_redis(get_val=None, set_nx_result=True)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.create = AsyncMock(side_effect=RuntimeError("quota"))
|
||||
with pytest.raises(RuntimeError, match="quota"):
|
||||
asyncio.run(
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
def test_wait_for_lock_then_reconnect(self):
|
||||
"""When another process holds the lock, wait and reconnect."""
|
||||
sb = _mock_sandbox("sb-other")
|
||||
redis = _mock_redis()
|
||||
redis.get = AsyncMock(side_effect=[_CREATING, "sb-other"])
|
||||
redis.set = AsyncMock(return_value=False)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
patch(
|
||||
"backend.copilot.tools.e2b_sandbox.asyncio.sleep",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is sb
|
||||
|
||||
def test_stale_reconnect_clears_and_creates(self):
|
||||
"""When stored sandbox is stale, clear key and create a new one."""
|
||||
stale_sb = _mock_sandbox("sb-stale", running=False)
|
||||
new_sb = _mock_sandbox("sb-fresh")
|
||||
redis = _mock_redis(get_val="sb-stale", set_nx_result=True)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=stale_sb)
|
||||
mock_cls.create = AsyncMock(return_value=new_sb)
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is new_sb
|
||||
redis.delete.assert_awaited()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# kill_sandbox
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestKillSandbox:
|
||||
def test_kill_existing_sandbox(self):
|
||||
"""Kill a running sandbox and clean up Redis."""
|
||||
sb = _mock_sandbox()
|
||||
sb.kill = AsyncMock()
|
||||
redis = _mock_redis(get_val="sb-abc")
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is True
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
sb.kill.assert_awaited_once()
|
||||
|
||||
def test_kill_no_sandbox(self):
|
||||
"""No-op when no sandbox exists in Redis."""
|
||||
redis = _mock_redis(get_val=None)
|
||||
with _patch_redis(redis):
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_not_awaited()
|
||||
|
||||
def test_kill_creating_state(self):
|
||||
"""Clears Redis key but returns False when sandbox is still being created."""
|
||||
redis = _mock_redis(get_val=_CREATING)
|
||||
with _patch_redis(redis):
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
def test_kill_connect_failure(self):
|
||||
"""Returns False and cleans Redis if connect/kill fails."""
|
||||
redis = _mock_redis(get_val="sb-abc")
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
def test_kill_with_bytes_redis_value(self):
|
||||
"""Redis may return bytes — kill_sandbox should decode correctly."""
|
||||
sb = _mock_sandbox()
|
||||
sb.kill = AsyncMock()
|
||||
redis = _mock_redis(get_val=b"sb-abc")
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is True
|
||||
sb.kill.assert_awaited_once()
|
||||
|
||||
def test_kill_timeout_returns_false(self):
|
||||
"""Returns False when E2B API calls exceed the 10s timeout."""
|
||||
redis = _mock_redis(get_val="sb-abc")
|
||||
with (
|
||||
_patch_redis(redis),
|
||||
patch(
|
||||
"backend.copilot.tools.e2b_sandbox.asyncio.wait_for",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=asyncio.TimeoutError,
|
||||
),
|
||||
):
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
@@ -41,6 +41,10 @@ class ResponseType(str, Enum):
|
||||
INPUT_VALIDATION_ERROR = "input_validation_error"
|
||||
# Web fetch
|
||||
WEB_FETCH = "web_fetch"
|
||||
# Agent-browser multi-step automation (navigate, act, screenshot)
|
||||
BROWSER_NAVIGATE = "browser_navigate"
|
||||
BROWSER_ACT = "browser_act"
|
||||
BROWSER_SCREENSHOT = "browser_screenshot"
|
||||
# Code execution
|
||||
BASH_EXEC = "bash_exec"
|
||||
# Feature request types
|
||||
@@ -476,3 +480,32 @@ class FeatureRequestCreatedResponse(ToolResponseBase):
|
||||
issue_url: str
|
||||
is_new_issue: bool # False if added to existing
|
||||
customer_name: str
|
||||
|
||||
|
||||
# Agent-browser multi-step automation models
|
||||
|
||||
|
||||
class BrowserNavigateResponse(ToolResponseBase):
|
||||
"""Response for browser_navigate tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BROWSER_NAVIGATE
|
||||
url: str
|
||||
title: str
|
||||
snapshot: str # Interactive accessibility tree with @ref IDs
|
||||
|
||||
|
||||
class BrowserActResponse(ToolResponseBase):
|
||||
"""Response for browser_act tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BROWSER_ACT
|
||||
action: str
|
||||
current_url: str = ""
|
||||
snapshot: str # Updated accessibility tree after the action
|
||||
|
||||
|
||||
class BrowserScreenshotResponse(ToolResponseBase):
|
||||
"""Response for browser_screenshot tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BROWSER_SCREENSHOT
|
||||
file_id: str # Workspace file ID — use read_workspace_file to retrieve
|
||||
filename: str
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
|
||||
from backend.copilot.tools.sandbox import make_session_path
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.settings import Config
|
||||
@@ -20,7 +21,7 @@ from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_write_content(
|
||||
async def _resolve_write_content(
|
||||
content_text: str | None,
|
||||
content_b64: str | None,
|
||||
source_path: str | None,
|
||||
@@ -30,6 +31,9 @@ def _resolve_write_content(
|
||||
|
||||
Returns the raw bytes on success, or an ``ErrorResponse`` on validation
|
||||
failure (wrong number of sources, invalid path, file not found, etc.).
|
||||
|
||||
When an E2B sandbox is active, ``source_path`` reads from the sandbox
|
||||
filesystem instead of the local ephemeral directory.
|
||||
"""
|
||||
# Normalise empty strings to None so counting and dispatch stay in sync.
|
||||
if content_text is not None and content_text == "":
|
||||
@@ -54,24 +58,7 @@ def _resolve_write_content(
|
||||
)
|
||||
|
||||
if source_path is not None:
|
||||
validated = _validate_ephemeral_path(
|
||||
source_path, param_name="source_path", session_id=session_id
|
||||
)
|
||||
if isinstance(validated, ErrorResponse):
|
||||
return validated
|
||||
try:
|
||||
with open(validated, "rb") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
return ErrorResponse(
|
||||
message=f"Source file not found: {source_path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to read source file: {e}",
|
||||
session_id=session_id,
|
||||
)
|
||||
return await _read_source_path(source_path, session_id)
|
||||
|
||||
if content_b64 is not None:
|
||||
try:
|
||||
@@ -91,6 +78,106 @@ def _resolve_write_content(
|
||||
return content_text.encode("utf-8")
|
||||
|
||||
|
||||
def _resolve_sandbox_path(
|
||||
path: str, session_id: str | None, param_name: str
|
||||
) -> str | ErrorResponse:
|
||||
"""Normalize *path* to an absolute sandbox path under :data:`E2B_WORKDIR`.
|
||||
|
||||
Delegates to :func:`~backend.copilot.sdk.e2b_file_tools._resolve_remote`
|
||||
and wraps any ``ValueError`` into an :class:`ErrorResponse`.
|
||||
"""
|
||||
from backend.copilot.sdk.e2b_file_tools import _resolve_remote
|
||||
|
||||
try:
|
||||
return _resolve_remote(path)
|
||||
except ValueError:
|
||||
return ErrorResponse(
|
||||
message=f"{param_name} must be within {E2B_WORKDIR}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
async def _read_source_path(source_path: str, session_id: str) -> bytes | ErrorResponse:
|
||||
"""Read *source_path* from E2B sandbox or local ephemeral directory."""
|
||||
from backend.copilot.sdk.tool_adapter import get_current_sandbox
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
remote = _resolve_sandbox_path(source_path, session_id, "source_path")
|
||||
if isinstance(remote, ErrorResponse):
|
||||
return remote
|
||||
try:
|
||||
data = await sandbox.files.read(remote, format="bytes")
|
||||
return bytes(data)
|
||||
except Exception as exc:
|
||||
return ErrorResponse(
|
||||
message=f"Source file not found on sandbox: {source_path} ({exc})",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Local fallback: validate path stays within ephemeral directory.
|
||||
validated = _validate_ephemeral_path(
|
||||
source_path, param_name="source_path", session_id=session_id
|
||||
)
|
||||
if isinstance(validated, ErrorResponse):
|
||||
return validated
|
||||
try:
|
||||
with open(validated, "rb") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
return ErrorResponse(
|
||||
message=f"Source file not found: {source_path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to read source file: {e}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
async def _save_to_path(
|
||||
path: str, content: bytes, session_id: str
|
||||
) -> str | ErrorResponse:
|
||||
"""Write *content* to *path* on E2B sandbox or local ephemeral directory.
|
||||
|
||||
Returns the resolved path on success, or an ``ErrorResponse`` on failure.
|
||||
"""
|
||||
from backend.copilot.sdk.tool_adapter import get_current_sandbox
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
remote = _resolve_sandbox_path(path, session_id, "save_to_path")
|
||||
if isinstance(remote, ErrorResponse):
|
||||
return remote
|
||||
try:
|
||||
await sandbox.files.write(remote, content)
|
||||
except Exception as exc:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to write to sandbox: {path} ({exc})",
|
||||
session_id=session_id,
|
||||
)
|
||||
return remote
|
||||
|
||||
validated = _validate_ephemeral_path(
|
||||
path, param_name="save_to_path", session_id=session_id
|
||||
)
|
||||
if isinstance(validated, ErrorResponse):
|
||||
return validated
|
||||
try:
|
||||
dir_path = os.path.dirname(validated)
|
||||
if dir_path:
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
with open(validated, "wb") as f:
|
||||
f.write(content)
|
||||
except Exception as exc:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to write to local path: {path} ({exc})",
|
||||
session_id=session_id,
|
||||
)
|
||||
return validated
|
||||
|
||||
|
||||
def _validate_ephemeral_path(
|
||||
path: str, *, param_name: str, session_id: str
|
||||
) -> ErrorResponse | str:
|
||||
@@ -131,7 +218,7 @@ def _is_text_mime(mime_type: str) -> bool:
|
||||
return any(mime_type.startswith(t) for t in _TEXT_MIME_PREFIXES)
|
||||
|
||||
|
||||
async def _get_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||
async def get_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||
"""Create a session-scoped WorkspaceManager."""
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
return WorkspaceManager(user_id, workspace.id, session_id)
|
||||
@@ -299,7 +386,7 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||
|
||||
try:
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
files = await manager.list_files(
|
||||
path=path_prefix, limit=limit, include_all_sessions=include_all_sessions
|
||||
)
|
||||
@@ -429,17 +516,8 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
message="Please provide either file_id or path", session_id=session_id
|
||||
)
|
||||
|
||||
# Validate and resolve save_to_path (use sanitized real path).
|
||||
if save_to_path:
|
||||
validated_save = _validate_ephemeral_path(
|
||||
save_to_path, param_name="save_to_path", session_id=session_id
|
||||
)
|
||||
if isinstance(validated_save, ErrorResponse):
|
||||
return validated_save
|
||||
save_to_path = validated_save
|
||||
|
||||
try:
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
@@ -449,11 +527,10 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
cached_content: bytes | None = None
|
||||
if save_to_path:
|
||||
cached_content = await manager.read_file_by_id(target_file_id)
|
||||
dir_path = os.path.dirname(save_to_path)
|
||||
if dir_path:
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
with open(save_to_path, "wb") as f:
|
||||
f.write(cached_content)
|
||||
result = await _save_to_path(save_to_path, cached_content, session_id)
|
||||
if isinstance(result, ErrorResponse):
|
||||
return result
|
||||
save_to_path = result
|
||||
|
||||
is_small = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES
|
||||
is_text = _is_text_mime(file_info.mime_type)
|
||||
@@ -629,7 +706,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
content_text: str | None = kwargs.get("content")
|
||||
content_b64: str | None = kwargs.get("content_base64")
|
||||
|
||||
resolved = _resolve_write_content(
|
||||
resolved = await _resolve_write_content(
|
||||
content_text,
|
||||
content_b64,
|
||||
source_path_arg,
|
||||
@@ -648,7 +725,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
|
||||
try:
|
||||
await scan_content_safe(content, filename=filename)
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
rec = await manager.write_file(
|
||||
content=content,
|
||||
filename=filename,
|
||||
@@ -775,7 +852,7 @@ class DeleteWorkspaceFileTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
|
||||
@@ -102,67 +102,68 @@ class TestValidateEphemeralPath:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
class TestResolveWriteContent:
|
||||
def test_no_sources_returns_error(self):
|
||||
async def test_no_sources_returns_error(self):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
result = _resolve_write_content(None, None, None, "s1")
|
||||
result = await _resolve_write_content(None, None, None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
def test_multiple_sources_returns_error(self):
|
||||
async def test_multiple_sources_returns_error(self):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
result = _resolve_write_content("text", "b64data", None, "s1")
|
||||
result = await _resolve_write_content("text", "b64data", None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
def test_plain_text_content(self):
|
||||
result = _resolve_write_content("hello world", None, None, "s1")
|
||||
async def test_plain_text_content(self):
|
||||
result = await _resolve_write_content("hello world", None, None, "s1")
|
||||
assert result == b"hello world"
|
||||
|
||||
def test_base64_content(self):
|
||||
async def test_base64_content(self):
|
||||
raw = b"binary data"
|
||||
b64 = base64.b64encode(raw).decode()
|
||||
result = _resolve_write_content(None, b64, None, "s1")
|
||||
result = await _resolve_write_content(None, b64, None, "s1")
|
||||
assert result == raw
|
||||
|
||||
def test_invalid_base64_returns_error(self):
|
||||
async def test_invalid_base64_returns_error(self):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
result = _resolve_write_content(None, "not-valid-b64!!!", None, "s1")
|
||||
result = await _resolve_write_content(None, "not-valid-b64!!!", None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "base64" in result.message.lower()
|
||||
|
||||
def test_source_path(self, ephemeral_dir):
|
||||
async def test_source_path(self, ephemeral_dir):
|
||||
target = ephemeral_dir / "input.txt"
|
||||
target.write_bytes(b"file content")
|
||||
result = _resolve_write_content(None, None, str(target), "s1")
|
||||
result = await _resolve_write_content(None, None, str(target), "s1")
|
||||
assert result == b"file content"
|
||||
|
||||
def test_source_path_not_found(self, ephemeral_dir):
|
||||
async def test_source_path_not_found(self, ephemeral_dir):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
missing = str(ephemeral_dir / "nope.txt")
|
||||
result = _resolve_write_content(None, None, missing, "s1")
|
||||
result = await _resolve_write_content(None, None, missing, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
def test_source_path_outside_ephemeral(self, ephemeral_dir, tmp_path):
|
||||
async def test_source_path_outside_ephemeral(self, ephemeral_dir, tmp_path):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
outside = tmp_path / "outside.txt"
|
||||
outside.write_text("nope")
|
||||
result = _resolve_write_content(None, None, str(outside), "s1")
|
||||
result = await _resolve_write_content(None, None, str(outside), "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
def test_empty_string_sources_treated_as_none(self):
|
||||
async def test_empty_string_sources_treated_as_none(self):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
# All empty strings → same as no sources
|
||||
result = _resolve_write_content("", "", "", "s1")
|
||||
result = await _resolve_write_content("", "", "", "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
def test_empty_string_source_path_with_text(self):
|
||||
async def test_empty_string_source_path_with_text(self):
|
||||
# source_path="" should be normalised to None, so only content counts
|
||||
result = _resolve_write_content("hello", "", "", "s1")
|
||||
result = await _resolve_write_content("hello", "", "", "s1")
|
||||
assert result == b"hello"
|
||||
|
||||
|
||||
|
||||
@@ -327,11 +327,16 @@ async def get_workspace_total_size(workspace_id: str) -> int:
|
||||
"""
|
||||
Get the total size of all files in a workspace.
|
||||
|
||||
Queries Prisma directly (skipping Pydantic model conversion) and only
|
||||
fetches the ``sizeBytes`` column to minimise data transfer.
|
||||
|
||||
Args:
|
||||
workspace_id: The workspace ID
|
||||
|
||||
Returns:
|
||||
Total size in bytes
|
||||
"""
|
||||
files = await list_workspace_files(workspace_id)
|
||||
return sum(file.size_bytes for file in files)
|
||||
files = await UserWorkspaceFile.prisma().find_many(
|
||||
where={"workspaceId": workspace_id, "isDeleted": False},
|
||||
)
|
||||
return sum(f.sizeBytes for f in files)
|
||||
|
||||
@@ -105,8 +105,13 @@ def validate_with_jsonschema(
|
||||
return str(e)
|
||||
|
||||
|
||||
def _sanitize_string(value: str) -> str:
|
||||
"""Remove PostgreSQL-incompatible control characters from string."""
|
||||
def sanitize_string(value: str) -> str:
|
||||
"""Remove PostgreSQL-incompatible control characters from string.
|
||||
|
||||
Strips \\x00-\\x08, \\x0B-\\x0C, \\x0E-\\x1F, \\x7F while keeping tab,
|
||||
newline, and carriage return. Use this before inserting free-form text
|
||||
into PostgreSQL text/varchar columns.
|
||||
"""
|
||||
return POSTGRES_CONTROL_CHARS.sub("", value)
|
||||
|
||||
|
||||
@@ -116,7 +121,7 @@ def sanitize_json(data: Any) -> Any:
|
||||
# 1. First convert to basic JSON-serializable types (handles Pydantic models)
|
||||
# 2. Then sanitize strings in the result
|
||||
basic_result = to_dict(data)
|
||||
return to_dict(basic_result, custom_encoder={str: _sanitize_string})
|
||||
return to_dict(basic_result, custom_encoder={str: sanitize_string})
|
||||
except Exception as e:
|
||||
# Log the failure and fall back to string representation
|
||||
logger.error(
|
||||
@@ -129,7 +134,7 @@ def sanitize_json(data: Any) -> Any:
|
||||
)
|
||||
|
||||
# Ultimate fallback: convert to string representation and sanitize
|
||||
return _sanitize_string(str(data))
|
||||
return sanitize_string(str(data))
|
||||
|
||||
|
||||
class SafeJson(Json):
|
||||
|
||||
@@ -413,6 +413,13 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="Maximum file size in MB for workspace files (1-1024 MB)",
|
||||
)
|
||||
|
||||
max_workspace_storage_mb: int = Field(
|
||||
default=500,
|
||||
ge=1,
|
||||
le=10240,
|
||||
description="Maximum total workspace storage per user in MB.",
|
||||
)
|
||||
|
||||
# AutoMod configuration
|
||||
automod_enabled: bool = Field(
|
||||
default=False,
|
||||
|
||||
@@ -79,6 +79,10 @@ def truncate(value: Any, size_limit: int) -> Any:
|
||||
largest str_limit and list_limit that fit.
|
||||
"""
|
||||
|
||||
# Fast path: plain strings don't need the binary search machinery.
|
||||
if isinstance(value, str):
|
||||
return _truncate_string_middle(value, size_limit)
|
||||
|
||||
def measure(val):
|
||||
try:
|
||||
return len(str(val))
|
||||
@@ -86,7 +90,7 @@ def truncate(value: Any, size_limit: int) -> Any:
|
||||
return sys.getsizeof(val)
|
||||
|
||||
# Reasonable bounds for string and list limits
|
||||
STR_MIN, STR_MAX = 8, 2**16
|
||||
STR_MIN, STR_MAX = min(8, size_limit), size_limit
|
||||
LIST_MIN, LIST_MAX = 1, 2**12
|
||||
|
||||
# Binary search for the largest str_limit and list_limit that fit
|
||||
|
||||
@@ -162,7 +162,7 @@ services:
|
||||
context: ../
|
||||
dockerfile: autogpt_platform/backend/Dockerfile
|
||||
target: server
|
||||
command: ["python", "-m", "backend.copilot.executor"]
|
||||
command: ["python", "-u", "-m", "backend.copilot.executor"]
|
||||
develop:
|
||||
watch:
|
||||
- path: ./
|
||||
@@ -182,6 +182,7 @@ services:
|
||||
<<: *backend-env-files
|
||||
environment:
|
||||
<<: *backend-env
|
||||
PYTHONUNBUFFERED: "1"
|
||||
ports:
|
||||
- "8008:8008"
|
||||
networks:
|
||||
|
||||
@@ -7,7 +7,9 @@ import {
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/molecules/DropdownMenu/DropdownMenu";
|
||||
import { SidebarProvider } from "@/components/ui/sidebar";
|
||||
import { DotsThree } from "@phosphor-icons/react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { DotsThree, UploadSimple } from "@phosphor-icons/react";
|
||||
import { useCallback, useRef, useState } from "react";
|
||||
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
|
||||
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
|
||||
import { DeleteChatDialog } from "./components/DeleteChatDialog/DeleteChatDialog";
|
||||
@@ -17,6 +19,49 @@ import { ScaleLoader } from "./components/ScaleLoader/ScaleLoader";
|
||||
import { useCopilotPage } from "./useCopilotPage";
|
||||
|
||||
export function CopilotPage() {
|
||||
const [isDragging, setIsDragging] = useState(false);
|
||||
const [droppedFiles, setDroppedFiles] = useState<File[]>([]);
|
||||
const dragCounter = useRef(0);
|
||||
|
||||
const handleDroppedFilesConsumed = useCallback(() => {
|
||||
setDroppedFiles([]);
|
||||
}, []);
|
||||
|
||||
function handleDragEnter(e: React.DragEvent) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
dragCounter.current += 1;
|
||||
if (e.dataTransfer.types.includes("Files")) {
|
||||
setIsDragging(true);
|
||||
}
|
||||
}
|
||||
|
||||
function handleDragOver(e: React.DragEvent) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
}
|
||||
|
||||
function handleDragLeave(e: React.DragEvent) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
dragCounter.current -= 1;
|
||||
if (dragCounter.current === 0) {
|
||||
setIsDragging(false);
|
||||
}
|
||||
}
|
||||
|
||||
function handleDrop(e: React.DragEvent) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
dragCounter.current = 0;
|
||||
setIsDragging(false);
|
||||
|
||||
const files = Array.from(e.dataTransfer.files);
|
||||
if (files.length > 0) {
|
||||
setDroppedFiles(files);
|
||||
}
|
||||
}
|
||||
|
||||
const {
|
||||
sessionId,
|
||||
messages,
|
||||
@@ -29,6 +74,7 @@ export function CopilotPage() {
|
||||
isLoadingSession,
|
||||
isSessionError,
|
||||
isCreatingSession,
|
||||
isUploadingFiles,
|
||||
isUserLoading,
|
||||
isLoggedIn,
|
||||
// Mobile drawer
|
||||
@@ -63,8 +109,26 @@ export function CopilotPage() {
|
||||
className="h-[calc(100vh-72px)] min-h-0"
|
||||
>
|
||||
{!isMobile && <ChatSidebar />}
|
||||
<div className="relative flex h-full w-full flex-col overflow-hidden bg-[#f8f8f9] px-0">
|
||||
<div
|
||||
className="relative flex h-full w-full flex-col overflow-hidden bg-[#f8f8f9] px-0"
|
||||
onDragEnter={handleDragEnter}
|
||||
onDragOver={handleDragOver}
|
||||
onDragLeave={handleDragLeave}
|
||||
onDrop={handleDrop}
|
||||
>
|
||||
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
|
||||
{/* Drop overlay */}
|
||||
<div
|
||||
className={cn(
|
||||
"pointer-events-none absolute inset-0 z-50 flex flex-col items-center justify-center gap-3 rounded-lg border-2 border-dashed border-violet-400 bg-violet-500/10 transition-opacity duration-150",
|
||||
isDragging ? "opacity-100" : "opacity-0",
|
||||
)}
|
||||
>
|
||||
<UploadSimple className="h-10 w-10 text-violet-500" weight="bold" />
|
||||
<span className="text-lg font-medium text-violet-600">
|
||||
Drop files here
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex-1 overflow-hidden">
|
||||
<ChatContainer
|
||||
messages={messages}
|
||||
@@ -78,6 +142,9 @@ export function CopilotPage() {
|
||||
onCreateSession={createSession}
|
||||
onSend={onSend}
|
||||
onStop={stop}
|
||||
isUploadingFiles={isUploadingFiles}
|
||||
droppedFiles={droppedFiles}
|
||||
onDroppedFilesConsumed={handleDroppedFilesConsumed}
|
||||
headerSlot={
|
||||
isMobile && sessionId ? (
|
||||
<div className="flex justify-end">
|
||||
|
||||
@@ -18,9 +18,14 @@ export interface ChatContainerProps {
|
||||
/** True when backend has an active stream but we haven't reconnected yet. */
|
||||
isReconnecting?: boolean;
|
||||
onCreateSession: () => void | Promise<string>;
|
||||
onSend: (message: string) => void | Promise<void>;
|
||||
onSend: (message: string, files?: File[]) => void | Promise<void>;
|
||||
onStop: () => void;
|
||||
isUploadingFiles?: boolean;
|
||||
headerSlot?: ReactNode;
|
||||
/** Files dropped onto the chat window. */
|
||||
droppedFiles?: File[];
|
||||
/** Called after droppedFiles have been consumed by ChatInput. */
|
||||
onDroppedFilesConsumed?: () => void;
|
||||
}
|
||||
export const ChatContainer = ({
|
||||
messages,
|
||||
@@ -34,7 +39,10 @@ export const ChatContainer = ({
|
||||
onCreateSession,
|
||||
onSend,
|
||||
onStop,
|
||||
isUploadingFiles,
|
||||
headerSlot,
|
||||
droppedFiles,
|
||||
onDroppedFilesConsumed,
|
||||
}: ChatContainerProps) => {
|
||||
const isBusy =
|
||||
status === "streaming" ||
|
||||
@@ -69,8 +77,11 @@ export const ChatContainer = ({
|
||||
onSend={onSend}
|
||||
disabled={isBusy}
|
||||
isStreaming={isBusy}
|
||||
isUploadingFiles={isUploadingFiles}
|
||||
onStop={onStop}
|
||||
placeholder="What else can I help with?"
|
||||
droppedFiles={droppedFiles}
|
||||
onDroppedFilesConsumed={onDroppedFilesConsumed}
|
||||
/>
|
||||
</motion.div>
|
||||
</div>
|
||||
@@ -80,6 +91,9 @@ export const ChatContainer = ({
|
||||
isCreatingSession={isCreatingSession}
|
||||
onCreateSession={onCreateSession}
|
||||
onSend={onSend}
|
||||
isUploadingFiles={isUploadingFiles}
|
||||
droppedFiles={droppedFiles}
|
||||
onDroppedFilesConsumed={onDroppedFilesConsumed}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -9,38 +9,66 @@ import {
|
||||
import { InputGroup } from "@/components/ui/input-group";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { CircleNotchIcon, MicrophoneIcon } from "@phosphor-icons/react";
|
||||
import { ChangeEvent } from "react";
|
||||
import { ChangeEvent, useEffect, useState } from "react";
|
||||
import { AttachmentMenu } from "./components/AttachmentMenu";
|
||||
import { FileChips } from "./components/FileChips";
|
||||
import { RecordingIndicator } from "./components/RecordingIndicator";
|
||||
import { useChatInput } from "./useChatInput";
|
||||
import { useVoiceRecording } from "./useVoiceRecording";
|
||||
|
||||
export interface Props {
|
||||
onSend: (message: string) => void | Promise<void>;
|
||||
onSend: (message: string, files?: File[]) => void | Promise<void>;
|
||||
disabled?: boolean;
|
||||
isStreaming?: boolean;
|
||||
isUploadingFiles?: boolean;
|
||||
onStop?: () => void;
|
||||
placeholder?: string;
|
||||
className?: string;
|
||||
inputId?: string;
|
||||
/** Files dropped onto the chat window by the parent. */
|
||||
droppedFiles?: File[];
|
||||
/** Called after droppedFiles have been merged into internal state. */
|
||||
onDroppedFilesConsumed?: () => void;
|
||||
}
|
||||
|
||||
export function ChatInput({
|
||||
onSend,
|
||||
disabled = false,
|
||||
isStreaming = false,
|
||||
isUploadingFiles = false,
|
||||
onStop,
|
||||
placeholder = "Type your message...",
|
||||
className,
|
||||
inputId = "chat-input",
|
||||
droppedFiles,
|
||||
onDroppedFilesConsumed,
|
||||
}: Props) {
|
||||
const [files, setFiles] = useState<File[]>([]);
|
||||
|
||||
// Merge files dropped onto the chat window into internal state.
|
||||
useEffect(() => {
|
||||
if (droppedFiles && droppedFiles.length > 0) {
|
||||
setFiles((prev) => [...prev, ...droppedFiles]);
|
||||
onDroppedFilesConsumed?.();
|
||||
}
|
||||
}, [droppedFiles, onDroppedFilesConsumed]);
|
||||
|
||||
const hasFiles = files.length > 0;
|
||||
const isBusy = disabled || isStreaming || isUploadingFiles;
|
||||
|
||||
const {
|
||||
value,
|
||||
setValue,
|
||||
handleSubmit,
|
||||
handleChange: baseHandleChange,
|
||||
} = useChatInput({
|
||||
onSend,
|
||||
disabled: disabled || isStreaming,
|
||||
onSend: async (message: string) => {
|
||||
await onSend(message, hasFiles ? files : undefined);
|
||||
// Only clear files after successful send (onSend throws on failure)
|
||||
setFiles([]);
|
||||
},
|
||||
disabled: isBusy,
|
||||
canSendEmpty: hasFiles,
|
||||
inputId,
|
||||
});
|
||||
|
||||
@@ -55,7 +83,7 @@ export function ChatInput({
|
||||
audioStream,
|
||||
} = useVoiceRecording({
|
||||
setValue,
|
||||
disabled: disabled || isStreaming,
|
||||
disabled: isBusy,
|
||||
isStreaming,
|
||||
value,
|
||||
inputId,
|
||||
@@ -67,7 +95,18 @@ export function ChatInput({
|
||||
}
|
||||
|
||||
const canSend =
|
||||
!disabled && !!value.trim() && !isRecording && !isTranscribing;
|
||||
!disabled &&
|
||||
(!!value.trim() || hasFiles) &&
|
||||
!isRecording &&
|
||||
!isTranscribing;
|
||||
|
||||
function handleFilesSelected(newFiles: File[]) {
|
||||
setFiles((prev) => [...prev, ...newFiles]);
|
||||
}
|
||||
|
||||
function handleRemoveFile(index: number) {
|
||||
setFiles((prev) => prev.filter((_, i) => i !== index));
|
||||
}
|
||||
|
||||
return (
|
||||
<form onSubmit={handleSubmit} className={cn("relative flex-1", className)}>
|
||||
@@ -78,6 +117,11 @@ export function ChatInput({
|
||||
"border-red-400 ring-1 ring-red-400 has-[[data-slot=input-group-control]:focus-visible]:border-red-400 has-[[data-slot=input-group-control]:focus-visible]:ring-red-400",
|
||||
)}
|
||||
>
|
||||
<FileChips
|
||||
files={files}
|
||||
onRemove={handleRemoveFile}
|
||||
isUploading={isUploadingFiles}
|
||||
/>
|
||||
<PromptInputBody className="relative block w-full">
|
||||
<PromptInputTextarea
|
||||
id={inputId}
|
||||
@@ -104,6 +148,10 @@ export function ChatInput({
|
||||
|
||||
<PromptInputFooter>
|
||||
<PromptInputTools>
|
||||
<AttachmentMenu
|
||||
onFilesSelected={handleFilesSelected}
|
||||
disabled={isBusy}
|
||||
/>
|
||||
{showMicButton && (
|
||||
<PromptInputButton
|
||||
aria-label={isRecording ? "Stop recording" : "Start recording"}
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Plus as PlusIcon } from "@phosphor-icons/react";
|
||||
import { useRef } from "react";
|
||||
|
||||
interface Props {
|
||||
onFilesSelected: (files: File[]) => void;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
export function AttachmentMenu({ onFilesSelected, disabled }: Props) {
|
||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
function handleClick() {
|
||||
fileInputRef.current?.click();
|
||||
}
|
||||
|
||||
function handleFileChange(e: React.ChangeEvent<HTMLInputElement>) {
|
||||
const files = Array.from(e.target.files ?? []);
|
||||
if (files.length > 0) {
|
||||
onFilesSelected(files);
|
||||
}
|
||||
// Reset so the same file can be re-selected
|
||||
e.target.value = "";
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
multiple
|
||||
className="hidden"
|
||||
onChange={handleFileChange}
|
||||
tabIndex={-1}
|
||||
/>
|
||||
<Button
|
||||
type="button"
|
||||
variant="icon"
|
||||
size="icon"
|
||||
aria-label="Attach file"
|
||||
disabled={disabled}
|
||||
onClick={handleClick}
|
||||
className={cn(
|
||||
"border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
|
||||
disabled && "opacity-40",
|
||||
)}
|
||||
>
|
||||
<PlusIcon className="h-4 w-4" weight="bold" />
|
||||
</Button>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
"use client";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
CircleNotch as CircleNotchIcon,
|
||||
X as XIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
|
||||
interface Props {
|
||||
files: File[];
|
||||
onRemove: (index: number) => void;
|
||||
isUploading?: boolean;
|
||||
}
|
||||
|
||||
export function FileChips({ files, onRemove, isUploading }: Props) {
|
||||
if (files.length === 0) return null;
|
||||
|
||||
return (
|
||||
<div className="flex w-full flex-wrap gap-2 px-3 pb-2 pt-1">
|
||||
{files.map((file, index) => (
|
||||
<span
|
||||
key={`${file.name}-${file.size}-${index}`}
|
||||
className={cn(
|
||||
"inline-flex items-center gap-1 rounded-full bg-zinc-100 px-3 py-1 text-sm text-zinc-700",
|
||||
isUploading && "opacity-70",
|
||||
)}
|
||||
>
|
||||
<span className="max-w-[160px] truncate">{file.name}</span>
|
||||
{isUploading ? (
|
||||
<CircleNotchIcon className="ml-0.5 h-3 w-3 animate-spin text-zinc-400" />
|
||||
) : (
|
||||
<button
|
||||
type="button"
|
||||
aria-label={`Remove ${file.name}`}
|
||||
onClick={() => onRemove(index)}
|
||||
className="ml-0.5 rounded-full p-0.5 text-zinc-400 transition-colors hover:bg-zinc-200 hover:text-zinc-600"
|
||||
>
|
||||
<XIcon className="h-3 w-3" weight="bold" />
|
||||
</button>
|
||||
)}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -3,12 +3,15 @@ import { ChangeEvent, FormEvent, useEffect, useState } from "react";
|
||||
interface Args {
|
||||
onSend: (message: string) => void;
|
||||
disabled?: boolean;
|
||||
/** Allow sending when text is empty (e.g. when files are attached). */
|
||||
canSendEmpty?: boolean;
|
||||
inputId?: string;
|
||||
}
|
||||
|
||||
export function useChatInput({
|
||||
onSend,
|
||||
disabled = false,
|
||||
canSendEmpty = false,
|
||||
inputId = "chat-input",
|
||||
}: Args) {
|
||||
const [value, setValue] = useState("");
|
||||
@@ -32,7 +35,7 @@ export function useChatInput({
|
||||
);
|
||||
|
||||
async function handleSend() {
|
||||
if (disabled || isSending || !value.trim()) return;
|
||||
if (disabled || isSending || (!value.trim() && !canSendEmpty)) return;
|
||||
|
||||
setIsSending(true);
|
||||
try {
|
||||
|
||||
@@ -5,7 +5,8 @@ import {
|
||||
} from "@/components/ai-elements/conversation";
|
||||
import { Message, MessageContent } from "@/components/ai-elements/message";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { FileUIPart, UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { MessageAttachments } from "./components/MessageAttachments";
|
||||
import { MessagePartRenderer } from "./components/MessagePartRenderer";
|
||||
import { ThinkingIndicator } from "./components/ThinkingIndicator";
|
||||
|
||||
@@ -72,6 +73,10 @@ export function ChatMessagesContainer({
|
||||
messageIndex === messages.length - 1 &&
|
||||
message.role === "assistant";
|
||||
|
||||
const fileParts = message.parts.filter(
|
||||
(p): p is FileUIPart => p.type === "file",
|
||||
);
|
||||
|
||||
return (
|
||||
<Message from={message.role} key={message.id}>
|
||||
<MessageContent
|
||||
@@ -93,6 +98,12 @@ export function ChatMessagesContainer({
|
||||
<ThinkingIndicator active={showThinking} />
|
||||
)}
|
||||
</MessageContent>
|
||||
{fileParts.length > 0 && (
|
||||
<MessageAttachments
|
||||
files={fileParts}
|
||||
isUser={message.role === "user"}
|
||||
/>
|
||||
)}
|
||||
</Message>
|
||||
);
|
||||
})}
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
import {
|
||||
FileText as FileTextIcon,
|
||||
DownloadSimple as DownloadIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import type { FileUIPart } from "ai";
|
||||
import {
|
||||
ContentCard,
|
||||
ContentCardHeader,
|
||||
ContentCardTitle,
|
||||
ContentCardSubtitle,
|
||||
} from "../../ToolAccordion/AccordionContent";
|
||||
|
||||
interface Props {
|
||||
files: FileUIPart[];
|
||||
isUser?: boolean;
|
||||
}
|
||||
|
||||
export function MessageAttachments({ files, isUser }: Props) {
|
||||
if (files.length === 0) return null;
|
||||
|
||||
return (
|
||||
<div className="mt-2 flex flex-col gap-2">
|
||||
{files.map((file, i) =>
|
||||
isUser ? (
|
||||
<div
|
||||
key={`${file.filename}-${i}`}
|
||||
className="min-w-0 rounded-lg border border-purple-300 bg-purple-100 p-3"
|
||||
>
|
||||
<div className="flex items-start justify-between gap-2">
|
||||
<div className="flex min-w-0 items-center gap-2">
|
||||
<FileTextIcon className="h-5 w-5 shrink-0 text-neutral-400" />
|
||||
<div className="min-w-0">
|
||||
<p className="truncate text-sm font-medium text-zinc-800">
|
||||
{file.filename || "file"}
|
||||
</p>
|
||||
<p className="mt-0.5 truncate font-mono text-xs text-zinc-800">
|
||||
{file.mediaType || "file"}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
{file.url && (
|
||||
<a
|
||||
href={file.url}
|
||||
download
|
||||
aria-label="Download file"
|
||||
className="shrink-0 text-purple-400 hover:text-purple-600"
|
||||
>
|
||||
<DownloadIcon className="h-5 w-5" />
|
||||
</a>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<ContentCard key={`${file.filename}-${i}`}>
|
||||
<ContentCardHeader
|
||||
action={
|
||||
file.url ? (
|
||||
<a
|
||||
href={file.url}
|
||||
download
|
||||
aria-label="Download file"
|
||||
className="shrink-0 text-neutral-400 hover:text-neutral-600"
|
||||
>
|
||||
<DownloadIcon className="h-5 w-5" />
|
||||
</a>
|
||||
) : undefined
|
||||
}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<FileTextIcon className="h-5 w-5 shrink-0 text-neutral-400" />
|
||||
<div className="min-w-0">
|
||||
<ContentCardTitle>{file.filename || "file"}</ContentCardTitle>
|
||||
<ContentCardSubtitle>
|
||||
{file.mediaType || "file"}
|
||||
</ContentCardSubtitle>
|
||||
</div>
|
||||
</div>
|
||||
</ContentCardHeader>
|
||||
</ContentCard>
|
||||
),
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -17,13 +17,19 @@ interface Props {
|
||||
inputLayoutId: string;
|
||||
isCreatingSession: boolean;
|
||||
onCreateSession: () => void | Promise<string>;
|
||||
onSend: (message: string) => void | Promise<void>;
|
||||
onSend: (message: string, files?: File[]) => void | Promise<void>;
|
||||
isUploadingFiles?: boolean;
|
||||
droppedFiles?: File[];
|
||||
onDroppedFilesConsumed?: () => void;
|
||||
}
|
||||
|
||||
export function EmptySession({
|
||||
inputLayoutId,
|
||||
isCreatingSession,
|
||||
onSend,
|
||||
isUploadingFiles,
|
||||
droppedFiles,
|
||||
onDroppedFilesConsumed,
|
||||
}: Props) {
|
||||
const { user } = useSupabase();
|
||||
const greetingName = getGreetingName(user);
|
||||
@@ -51,12 +57,12 @@ export function EmptySession({
|
||||
return (
|
||||
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-0 py-5 md:px-6 md:py-10">
|
||||
<motion.div
|
||||
className="w-full max-w-3xl text-center"
|
||||
className="w-full max-w-[52rem] text-center"
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
transition={{ duration: 0.3 }}
|
||||
>
|
||||
<div className="mx-auto max-w-3xl">
|
||||
<div className="mx-auto max-w-[52rem]">
|
||||
<Text variant="h3" className="mb-1 !text-[1.375rem] text-zinc-700">
|
||||
Hey, <span className="text-violet-600">{greetingName}</span>
|
||||
</Text>
|
||||
@@ -74,8 +80,11 @@ export function EmptySession({
|
||||
inputId="chat-input-empty"
|
||||
onSend={onSend}
|
||||
disabled={isCreatingSession}
|
||||
isUploadingFiles={isUploadingFiles}
|
||||
placeholder={inputPlaceholder}
|
||||
className="w-full"
|
||||
droppedFiles={droppedFiles}
|
||||
onDroppedFilesConsumed={onDroppedFilesConsumed}
|
||||
/>
|
||||
</motion.div>
|
||||
</div>
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import type { UIMessage, UIDataTypes, UITools } from "ai";
|
||||
import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
|
||||
import type { FileUIPart, UIMessage, UIDataTypes, UITools } from "ai";
|
||||
|
||||
interface SessionChatMessage {
|
||||
role: string;
|
||||
@@ -38,6 +39,48 @@ function coerceSessionChatMessages(
|
||||
.filter((m): m is SessionChatMessage => m !== null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse the `[Attached files]` block appended by the backend and return
|
||||
* the cleaned text plus reconstructed FileUIPart objects.
|
||||
*
|
||||
* Backend format:
|
||||
* ```
|
||||
* \n\n[Attached files]
|
||||
* - name.jpg (image/jpeg, 191.0 KB), file_id=<uuid>
|
||||
* Use read_workspace_file with the file_id to access file contents.
|
||||
* ```
|
||||
*/
|
||||
const ATTACHED_FILES_RE =
|
||||
/\n?\n?\[Attached files\]\n([\s\S]*?)Use read_workspace_file with the file_id to access file contents\./;
|
||||
const FILE_LINE_RE = /^- (.+) \(([^,]+),\s*[\d.]+ KB\), file_id=([0-9a-f-]+)$/;
|
||||
|
||||
function extractFileParts(content: string): {
|
||||
cleanText: string;
|
||||
fileParts: FileUIPart[];
|
||||
} {
|
||||
const match = content.match(ATTACHED_FILES_RE);
|
||||
if (!match) return { cleanText: content, fileParts: [] };
|
||||
|
||||
const cleanText = content.replace(match[0], "").trim();
|
||||
const lines = match[1].trim().split("\n");
|
||||
const fileParts: FileUIPart[] = [];
|
||||
|
||||
for (const line of lines) {
|
||||
const m = line.trim().match(FILE_LINE_RE);
|
||||
if (!m) continue;
|
||||
const [, filename, mimeType, fileId] = m;
|
||||
const apiPath = getGetWorkspaceDownloadFileByIdUrl(fileId);
|
||||
fileParts.push({
|
||||
type: "file",
|
||||
filename,
|
||||
mediaType: mimeType,
|
||||
url: `/api/proxy${apiPath}`,
|
||||
});
|
||||
}
|
||||
|
||||
return { cleanText, fileParts };
|
||||
}
|
||||
|
||||
function safeJsonParse(value: string): unknown {
|
||||
try {
|
||||
return JSON.parse(value) as unknown;
|
||||
@@ -79,7 +122,17 @@ export function convertChatSessionMessagesToUiMessages(
|
||||
const parts: UIMessage<unknown, UIDataTypes, UITools>["parts"] = [];
|
||||
|
||||
if (typeof msg.content === "string" && msg.content.trim()) {
|
||||
parts.push({ type: "text", text: msg.content, state: "done" });
|
||||
if (msg.role === "user") {
|
||||
const { cleanText, fileParts } = extractFileParts(msg.content);
|
||||
if (cleanText) {
|
||||
parts.push({ type: "text", text: cleanText, state: "done" });
|
||||
}
|
||||
for (const fp of fileParts) {
|
||||
parts.push(fp);
|
||||
}
|
||||
} else {
|
||||
parts.push({ type: "text", text: msg.content, state: "done" });
|
||||
}
|
||||
}
|
||||
|
||||
if (msg.role === "assistant" && Array.isArray(msg.tool_calls)) {
|
||||
|
||||
@@ -12,6 +12,7 @@ import {
|
||||
GlobeIcon,
|
||||
ListChecksIcon,
|
||||
MagnifyingGlassIcon,
|
||||
MonitorIcon,
|
||||
PencilSimpleIcon,
|
||||
TerminalIcon,
|
||||
TrashIcon,
|
||||
@@ -48,6 +49,7 @@ function formatToolName(name: string): string {
|
||||
type ToolCategory =
|
||||
| "bash"
|
||||
| "web"
|
||||
| "browser"
|
||||
| "file-read"
|
||||
| "file-write"
|
||||
| "file-delete"
|
||||
@@ -65,19 +67,28 @@ function getToolCategory(toolName: string): ToolCategory {
|
||||
case "WebSearch":
|
||||
case "WebFetch":
|
||||
return "web";
|
||||
case "browser_navigate":
|
||||
case "browser_act":
|
||||
case "browser_screenshot":
|
||||
return "browser";
|
||||
case "read_workspace_file":
|
||||
case "read_file":
|
||||
case "Read":
|
||||
return "file-read";
|
||||
case "write_workspace_file":
|
||||
case "write_file":
|
||||
case "Write":
|
||||
return "file-write";
|
||||
case "delete_workspace_file":
|
||||
return "file-delete";
|
||||
case "list_workspace_files":
|
||||
case "glob":
|
||||
case "Glob":
|
||||
return "file-list";
|
||||
case "grep":
|
||||
case "Grep":
|
||||
return "search";
|
||||
case "edit_file":
|
||||
case "Edit":
|
||||
return "edit";
|
||||
case "TodoWrite":
|
||||
@@ -115,6 +126,8 @@ function ToolIcon({
|
||||
return <TerminalIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "web":
|
||||
return <GlobeIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "browser":
|
||||
return <MonitorIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "file-read":
|
||||
return <FileIcon size={14} weight="regular" className={iconClass} />;
|
||||
case "file-write":
|
||||
@@ -150,6 +163,8 @@ function AccordionIcon({ category }: { category: ToolCategory }) {
|
||||
return <TerminalIcon size={32} weight="light" />;
|
||||
case "web":
|
||||
return <GlobeIcon size={32} weight="light" />;
|
||||
case "browser":
|
||||
return <MonitorIcon size={32} weight="light" />;
|
||||
case "file-read":
|
||||
case "file-write":
|
||||
return <FileIcon size={32} weight="light" />;
|
||||
@@ -184,13 +199,25 @@ function getInputSummary(toolName: string, input: unknown): string | null {
|
||||
return typeof inp.url === "string" ? inp.url : null;
|
||||
case "WebSearch":
|
||||
return typeof inp.query === "string" ? inp.query : null;
|
||||
case "browser_navigate":
|
||||
return typeof inp.url === "string" ? inp.url : null;
|
||||
case "browser_act":
|
||||
return typeof inp.action === "string"
|
||||
? inp.target
|
||||
? `${inp.action} ${inp.target}`
|
||||
: (inp.action as string)
|
||||
: null;
|
||||
case "browser_screenshot":
|
||||
return null;
|
||||
case "read_workspace_file":
|
||||
case "read_file":
|
||||
case "Read":
|
||||
return (
|
||||
(typeof inp.file_path === "string" ? inp.file_path : null) ??
|
||||
(typeof inp.path === "string" ? inp.path : null)
|
||||
);
|
||||
case "write_workspace_file":
|
||||
case "write_file":
|
||||
case "Write":
|
||||
return (
|
||||
(typeof inp.file_path === "string" ? inp.file_path : null) ??
|
||||
@@ -198,10 +225,13 @@ function getInputSummary(toolName: string, input: unknown): string | null {
|
||||
);
|
||||
case "delete_workspace_file":
|
||||
return typeof inp.file_path === "string" ? inp.file_path : null;
|
||||
case "glob":
|
||||
case "Glob":
|
||||
return typeof inp.pattern === "string" ? inp.pattern : null;
|
||||
case "grep":
|
||||
case "Grep":
|
||||
return typeof inp.pattern === "string" ? inp.pattern : null;
|
||||
case "edit_file":
|
||||
case "Edit":
|
||||
return typeof inp.file_path === "string" ? inp.file_path : null;
|
||||
case "TodoWrite": {
|
||||
@@ -249,6 +279,11 @@ function getAnimationText(part: ToolUIPart, category: ToolCategory): string {
|
||||
return shortSummary
|
||||
? `Fetching ${shortSummary}`
|
||||
: "Fetching web content…";
|
||||
case "browser":
|
||||
if (toolName === "browser_screenshot") return "Taking screenshot…";
|
||||
return shortSummary
|
||||
? `Browsing ${shortSummary}`
|
||||
: "Interacting with browser…";
|
||||
case "file-read":
|
||||
return shortSummary ? `Reading ${shortSummary}` : "Reading file…";
|
||||
case "file-write":
|
||||
@@ -287,6 +322,11 @@ function getAnimationText(part: ToolUIPart, category: ToolCategory): string {
|
||||
return shortSummary
|
||||
? `Fetched ${shortSummary}`
|
||||
: "Fetched web content";
|
||||
case "browser":
|
||||
if (toolName === "browser_screenshot") return "Screenshot captured";
|
||||
return shortSummary
|
||||
? `Browsed ${shortSummary}`
|
||||
: "Browser action completed";
|
||||
case "file-read":
|
||||
return shortSummary ? `Read ${shortSummary}` : "File read completed";
|
||||
case "file-write":
|
||||
@@ -313,6 +353,8 @@ function getAnimationText(part: ToolUIPart, category: ToolCategory): string {
|
||||
return "Command failed";
|
||||
case "web":
|
||||
return toolName === "WebSearch" ? "Search failed" : "Fetch failed";
|
||||
case "browser":
|
||||
return "Browser action failed";
|
||||
default:
|
||||
return `${formatToolName(toolName)} failed`;
|
||||
}
|
||||
@@ -418,16 +460,22 @@ function getBashAccordionData(
|
||||
description: truncate(command, 80),
|
||||
content: (
|
||||
<div className="space-y-2">
|
||||
{command && (
|
||||
<div>
|
||||
<p className="mb-1 text-xs font-medium text-slate-500">command</p>
|
||||
<ContentCodeBlock>{command}</ContentCodeBlock>
|
||||
</div>
|
||||
)}
|
||||
{stdout && (
|
||||
<div>
|
||||
<p className="mb-1 text-xs font-medium text-slate-500">stdout</p>
|
||||
<ContentCodeBlock>{truncate(stdout, 2000)}</ContentCodeBlock>
|
||||
<ContentCodeBlock>{stdout}</ContentCodeBlock>
|
||||
</div>
|
||||
)}
|
||||
{stderr && (
|
||||
<div>
|
||||
<p className="mb-1 text-xs font-medium text-slate-500">stderr</p>
|
||||
<ContentCodeBlock>{truncate(stderr, 1000)}</ContentCodeBlock>
|
||||
<ContentCodeBlock>{stderr}</ContentCodeBlock>
|
||||
</div>
|
||||
)}
|
||||
{!stdout && !stderr && message && (
|
||||
@@ -475,18 +523,55 @@ function getWebAccordionData(
|
||||
: "Search results",
|
||||
description: truncate(url, 80),
|
||||
content: content ? (
|
||||
<ContentCodeBlock>{truncate(content, 2000)}</ContentCodeBlock>
|
||||
<ContentCodeBlock>{content}</ContentCodeBlock>
|
||||
) : message ? (
|
||||
<ContentMessage>{message}</ContentMessage>
|
||||
) : Object.keys(output).length > 0 ? (
|
||||
<ContentCodeBlock>
|
||||
{truncate(JSON.stringify(output, null, 2), 2000)}
|
||||
</ContentCodeBlock>
|
||||
<ContentCodeBlock>{JSON.stringify(output, null, 2)}</ContentCodeBlock>
|
||||
) : null,
|
||||
};
|
||||
}
|
||||
|
||||
function getBrowserAccordionData(
|
||||
toolName: string,
|
||||
input: unknown,
|
||||
output: Record<string, unknown>,
|
||||
): AccordionData {
|
||||
const message = getStringField(output, "message");
|
||||
const snapshot = getStringField(output, "snapshot");
|
||||
|
||||
// Screenshot tool: show the file_id so the user knows it was saved
|
||||
if (toolName === "browser_screenshot") {
|
||||
const fileId = getStringField(output, "file_id");
|
||||
const filename = getStringField(output, "filename");
|
||||
return {
|
||||
title: filename ? `Screenshot: ${filename}` : "Screenshot captured",
|
||||
description: fileId ? `file_id: ${fileId}` : undefined,
|
||||
content: message ? <ContentMessage>{message}</ContentMessage> : null,
|
||||
};
|
||||
}
|
||||
|
||||
// Navigate / act tools: show snapshot if available
|
||||
const title =
|
||||
toolName === "browser_navigate"
|
||||
? (getStringField(output, "title") ?? "Page loaded")
|
||||
: (message ?? "Action completed");
|
||||
|
||||
const url = getStringField(output, "url", "current_url");
|
||||
|
||||
return {
|
||||
title,
|
||||
description: url ? truncate(url, 80) : undefined,
|
||||
content: snapshot ? (
|
||||
<ContentCodeBlock>{truncate(snapshot, 3000)}</ContentCodeBlock>
|
||||
) : message ? (
|
||||
<ContentMessage>{message}</ContentMessage>
|
||||
) : null,
|
||||
};
|
||||
}
|
||||
|
||||
function getFileAccordionData(
|
||||
category: ToolCategory,
|
||||
input: unknown,
|
||||
output: Record<string, unknown>,
|
||||
): AccordionData {
|
||||
@@ -529,6 +614,20 @@ function getFileAccordionData(
|
||||
displayContent = extractMcpText(output);
|
||||
}
|
||||
|
||||
// For edit: show old/new diff; for write: show written content if output is just a status
|
||||
const oldString =
|
||||
category === "edit"
|
||||
? getStringField(inp as Record<string, unknown>, "old_string")
|
||||
: null;
|
||||
const newString =
|
||||
category === "edit"
|
||||
? getStringField(inp as Record<string, unknown>, "new_string")
|
||||
: null;
|
||||
const writtenContent =
|
||||
category === "file-write"
|
||||
? getStringField(inp as Record<string, unknown>, "content")
|
||||
: null;
|
||||
|
||||
// For Glob/list results, try to show file list
|
||||
// Files can be either strings (from Glob) or objects (from list_workspace_files)
|
||||
const files = Array.isArray(output.files) ? output.files : null;
|
||||
@@ -562,18 +661,33 @@ function getFileAccordionData(
|
||||
fileListText = fileLines.join("\n");
|
||||
}
|
||||
|
||||
const isWriteOrEdit = category === "file-write" || category === "edit";
|
||||
|
||||
return {
|
||||
title: message ?? "File output",
|
||||
title:
|
||||
message ??
|
||||
(isWriteOrEdit ? `Wrote ${truncate(filePath, 60)}` : "File output"),
|
||||
description: truncate(filePath, 80),
|
||||
content: (
|
||||
<div className="space-y-2">
|
||||
{displayContent && (
|
||||
<ContentCodeBlock>{truncate(displayContent, 2000)}</ContentCodeBlock>
|
||||
)}
|
||||
{fileListText && (
|
||||
<ContentCodeBlock>{truncate(fileListText, 2000)}</ContentCodeBlock>
|
||||
)}
|
||||
{!displayContent && !fileListText && message && (
|
||||
{oldString && newString != null ? (
|
||||
<>
|
||||
<div>
|
||||
<p className="mb-1 text-xs font-medium text-red-400">removed</p>
|
||||
<ContentCodeBlock>{oldString}</ContentCodeBlock>
|
||||
</div>
|
||||
<div>
|
||||
<p className="mb-1 text-xs font-medium text-green-400">added</p>
|
||||
<ContentCodeBlock>{newString}</ContentCodeBlock>
|
||||
</div>
|
||||
</>
|
||||
) : writtenContent ? (
|
||||
<ContentCodeBlock>{writtenContent}</ContentCodeBlock>
|
||||
) : displayContent ? (
|
||||
<ContentCodeBlock>{displayContent}</ContentCodeBlock>
|
||||
) : null}
|
||||
{fileListText && <ContentCodeBlock>{fileListText}</ContentCodeBlock>}
|
||||
{!displayContent && !fileListText && !writtenContent && message && (
|
||||
<ContentMessage>{message}</ContentMessage>
|
||||
)}
|
||||
</div>
|
||||
@@ -675,14 +789,13 @@ function getDefaultAccordionData(
|
||||
return {
|
||||
title: "Output",
|
||||
description: message ?? undefined,
|
||||
content: (
|
||||
<ContentCodeBlock>{truncate(displayContent, 2000)}</ContentCodeBlock>
|
||||
),
|
||||
content: <ContentCodeBlock>{displayContent}</ContentCodeBlock>,
|
||||
};
|
||||
}
|
||||
|
||||
function getAccordionData(
|
||||
category: ToolCategory,
|
||||
toolName: string,
|
||||
input: unknown,
|
||||
output: Record<string, unknown>,
|
||||
): AccordionData {
|
||||
@@ -691,13 +804,15 @@ function getAccordionData(
|
||||
return getBashAccordionData(input, output);
|
||||
case "web":
|
||||
return getWebAccordionData(input, output);
|
||||
case "browser":
|
||||
return getBrowserAccordionData(toolName, input, output);
|
||||
case "file-read":
|
||||
case "file-write":
|
||||
case "file-delete":
|
||||
case "file-list":
|
||||
case "search":
|
||||
case "edit":
|
||||
return getFileAccordionData(input, output);
|
||||
return getFileAccordionData(category, input, output);
|
||||
case "todo":
|
||||
return getTodoAccordionData(input);
|
||||
default:
|
||||
@@ -733,25 +848,23 @@ export function GenericTool({ part }: Props) {
|
||||
|
||||
const showAccordion = hasOutput || hasError || hasTodoInput;
|
||||
const accordionData = showAccordion
|
||||
? getAccordionData(category, part.input, output ?? {})
|
||||
? getAccordionData(category, toolName, part.input, output ?? {})
|
||||
: null;
|
||||
|
||||
return (
|
||||
<div className="py-2">
|
||||
{/* Only show loading text when NOT showing accordion */}
|
||||
{!showAccordion && (
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<ToolIcon
|
||||
category={category}
|
||||
isStreaming={isStreaming}
|
||||
isError={isError}
|
||||
/>
|
||||
<MorphingTextAnimation
|
||||
text={text}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{/* Status line: always visible so the user sees what tool ran */}
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<ToolIcon
|
||||
category={category}
|
||||
isStreaming={isStreaming}
|
||||
isError={isError}
|
||||
/>
|
||||
<MorphingTextAnimation
|
||||
text={text}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{showAccordion && accordionData ? (
|
||||
<ToolAccordion
|
||||
|
||||
@@ -5,15 +5,25 @@ import {
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
||||
import { getWebSocketToken } from "@/lib/supabase/actions";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { environment } from "@/services/environment";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useEffect, useState } from "react";
|
||||
import type { FileUIPart } from "ai";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useCopilotUIStore } from "./store";
|
||||
import { useChatSession } from "./useChatSession";
|
||||
import { useCopilotStream } from "./useCopilotStream";
|
||||
|
||||
interface UploadedFile {
|
||||
file_id: string;
|
||||
name: string;
|
||||
mime_type: string;
|
||||
}
|
||||
|
||||
export function useCopilotPage() {
|
||||
const { isUserLoading, isLoggedIn } = useSupabase();
|
||||
const [isUploadingFiles, setIsUploadingFiles] = useState(false);
|
||||
const [pendingMessage, setPendingMessage] = useState<string | null>(null);
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
@@ -77,26 +87,164 @@ export function useCopilotPage() {
|
||||
const isMobile =
|
||||
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
|
||||
|
||||
const pendingFilesRef = useRef<File[]>([]);
|
||||
|
||||
// --- Send pending message after session creation ---
|
||||
useEffect(() => {
|
||||
if (!sessionId || !pendingMessage) return;
|
||||
if (!sessionId || pendingMessage === null) return;
|
||||
const msg = pendingMessage;
|
||||
const files = pendingFilesRef.current;
|
||||
setPendingMessage(null);
|
||||
sendMessage({ text: msg });
|
||||
pendingFilesRef.current = [];
|
||||
|
||||
if (files.length > 0) {
|
||||
setIsUploadingFiles(true);
|
||||
void uploadFiles(files, sessionId)
|
||||
.then((uploaded) => {
|
||||
if (uploaded.length === 0) {
|
||||
toast({
|
||||
title: "File upload failed",
|
||||
description: "Could not upload any files. Please try again.",
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
const fileParts = buildFileParts(uploaded);
|
||||
sendMessage({
|
||||
text: msg,
|
||||
files: fileParts.length > 0 ? fileParts : undefined,
|
||||
});
|
||||
})
|
||||
.finally(() => setIsUploadingFiles(false));
|
||||
} else {
|
||||
sendMessage({ text: msg });
|
||||
}
|
||||
}, [sessionId, pendingMessage, sendMessage]);
|
||||
|
||||
async function onSend(message: string) {
|
||||
async function uploadFiles(
|
||||
files: File[],
|
||||
sid: string,
|
||||
): Promise<UploadedFile[]> {
|
||||
// Upload directly to the Python backend, bypassing the Next.js serverless
|
||||
// proxy. Vercel's 4.5 MB function payload limit would reject larger files
|
||||
// when routed through /api/workspace/files/upload.
|
||||
const { token, error: tokenError } = await getWebSocketToken();
|
||||
if (tokenError || !token) {
|
||||
toast({
|
||||
title: "Authentication error",
|
||||
description: "Please sign in again.",
|
||||
variant: "destructive",
|
||||
});
|
||||
return [];
|
||||
}
|
||||
|
||||
const backendBase = environment.getAGPTServerBaseUrl();
|
||||
|
||||
const results = await Promise.allSettled(
|
||||
files.map(async (file) => {
|
||||
const formData = new FormData();
|
||||
formData.append("file", file);
|
||||
const url = new URL("/api/workspace/files/upload", backendBase);
|
||||
url.searchParams.set("session_id", sid);
|
||||
const res = await fetch(url.toString(), {
|
||||
method: "POST",
|
||||
headers: { Authorization: `Bearer ${token}` },
|
||||
body: formData,
|
||||
});
|
||||
if (!res.ok) {
|
||||
const err = await res.text();
|
||||
console.error("File upload failed:", err);
|
||||
toast({
|
||||
title: "File upload failed",
|
||||
description: file.name,
|
||||
variant: "destructive",
|
||||
});
|
||||
throw new Error(err);
|
||||
}
|
||||
const data = await res.json();
|
||||
if (!data.file_id) throw new Error("No file_id returned");
|
||||
return {
|
||||
file_id: data.file_id,
|
||||
name: data.name || file.name,
|
||||
mime_type: data.mime_type || "application/octet-stream",
|
||||
} as UploadedFile;
|
||||
}),
|
||||
);
|
||||
return results
|
||||
.filter(
|
||||
(r): r is PromiseFulfilledResult<UploadedFile> =>
|
||||
r.status === "fulfilled",
|
||||
)
|
||||
.map((r) => r.value);
|
||||
}
|
||||
|
||||
function buildFileParts(uploaded: UploadedFile[]): FileUIPart[] {
|
||||
return uploaded.map((f) => ({
|
||||
type: "file" as const,
|
||||
mediaType: f.mime_type,
|
||||
filename: f.name,
|
||||
url: `/api/proxy/api/workspace/files/${f.file_id}/download`,
|
||||
}));
|
||||
}
|
||||
|
||||
async function onSend(message: string, files?: File[]) {
|
||||
const trimmed = message.trim();
|
||||
if (!trimmed) return;
|
||||
if (!trimmed && (!files || files.length === 0)) return;
|
||||
|
||||
// Client-side file limits
|
||||
if (files && files.length > 0) {
|
||||
const MAX_FILES = 10;
|
||||
const MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024; // 100 MB
|
||||
|
||||
if (files.length > MAX_FILES) {
|
||||
toast({
|
||||
title: "Too many files",
|
||||
description: `You can attach up to ${MAX_FILES} files at once.`,
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const oversized = files.filter((f) => f.size > MAX_FILE_SIZE_BYTES);
|
||||
if (oversized.length > 0) {
|
||||
toast({
|
||||
title: "File too large",
|
||||
description: `${oversized[0].name} exceeds the 100 MB limit.`,
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
isUserStoppingRef.current = false;
|
||||
|
||||
if (sessionId) {
|
||||
sendMessage({ text: trimmed });
|
||||
if (files && files.length > 0) {
|
||||
setIsUploadingFiles(true);
|
||||
try {
|
||||
const uploaded = await uploadFiles(files, sessionId);
|
||||
if (uploaded.length === 0) {
|
||||
// All uploads failed — abort send so chips revert to editable
|
||||
throw new Error("All file uploads failed");
|
||||
}
|
||||
const fileParts = buildFileParts(uploaded);
|
||||
sendMessage({
|
||||
text: trimmed || "",
|
||||
files: fileParts.length > 0 ? fileParts : undefined,
|
||||
});
|
||||
} finally {
|
||||
setIsUploadingFiles(false);
|
||||
}
|
||||
} else {
|
||||
sendMessage({ text: trimmed });
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
setPendingMessage(trimmed);
|
||||
setPendingMessage(trimmed || "");
|
||||
if (files && files.length > 0) {
|
||||
pendingFilesRef.current = files;
|
||||
}
|
||||
await createSession();
|
||||
}
|
||||
|
||||
@@ -161,6 +309,7 @@ export function useCopilotPage() {
|
||||
isLoadingSession,
|
||||
isSessionError,
|
||||
isCreatingSession,
|
||||
isUploadingFiles,
|
||||
isUserLoading,
|
||||
isLoggedIn,
|
||||
createSession,
|
||||
|
||||
@@ -8,7 +8,7 @@ import { environment } from "@/services/environment";
|
||||
import { useChat } from "@ai-sdk/react";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { DefaultChatTransport } from "ai";
|
||||
import type { UIMessage } from "ai";
|
||||
import type { FileUIPart, UIMessage } from "ai";
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
import { deduplicateMessages, resolveInProgressTools } from "./helpers";
|
||||
|
||||
@@ -51,6 +51,15 @@ export function useCopilotStream({
|
||||
api: `${environment.getAGPTServerBaseUrl()}/api/chat/sessions/${sessionId}/stream`,
|
||||
prepareSendMessagesRequest: async ({ messages }) => {
|
||||
const last = messages[messages.length - 1];
|
||||
// Extract file_ids from FileUIPart entries on the message
|
||||
const fileIds = last.parts
|
||||
?.filter((p): p is FileUIPart => p.type === "file")
|
||||
.map((p) => {
|
||||
// URL is like /api/proxy/api/workspace/files/{id}/download
|
||||
const match = p.url.match(/\/workspace\/files\/([^/]+)\//);
|
||||
return match?.[1];
|
||||
})
|
||||
.filter(Boolean) as string[] | undefined;
|
||||
return {
|
||||
body: {
|
||||
message: (
|
||||
@@ -59,6 +68,7 @@ export function useCopilotStream({
|
||||
).join(""),
|
||||
is_user_message: last.role === "user",
|
||||
context: null,
|
||||
file_ids: fileIds && fileIds.length > 0 ? fileIds : null,
|
||||
},
|
||||
headers: await getAuthHeaders(),
|
||||
};
|
||||
|
||||
@@ -27,9 +27,9 @@ export async function POST(
|
||||
|
||||
try {
|
||||
const body = await request.json();
|
||||
const { message, is_user_message, context } = body;
|
||||
const { message, is_user_message, context, file_ids } = body;
|
||||
|
||||
if (!message) {
|
||||
if (message === undefined) {
|
||||
return new Response(
|
||||
JSON.stringify({ error: "Missing message parameter" }),
|
||||
{ status: 400, headers: { "Content-Type": "application/json" } },
|
||||
@@ -62,6 +62,7 @@ export async function POST(
|
||||
message,
|
||||
is_user_message: is_user_message ?? true,
|
||||
context: context || null,
|
||||
file_ids: file_ids || null,
|
||||
}),
|
||||
signal: debugSignal(),
|
||||
});
|
||||
|
||||
@@ -2039,7 +2039,9 @@
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/UploadFileResponse" }
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/backend__api__model__UploadFileResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -6497,6 +6499,59 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/workspace/files/upload": {
|
||||
"post": {
|
||||
"tags": ["workspace"],
|
||||
"summary": "Upload file to workspace",
|
||||
"description": "Upload a file to the user's workspace.\n\nFiles are stored in session-scoped paths when session_id is provided,\nso the agent's session-scoped tools can discover them automatically.",
|
||||
"operationId": "postWorkspaceUpload file to workspace",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "session_id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Session Id"
|
||||
}
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
"required": true,
|
||||
"content": {
|
||||
"multipart/form-data": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Body_postWorkspaceUpload_file_to_workspace"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/backend__api__features__workspace__routes__UploadFileResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/workspace/files/{file_id}/download": {
|
||||
"get": {
|
||||
"tags": ["workspace"],
|
||||
@@ -6531,6 +6586,30 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/workspace/storage/usage": {
|
||||
"get": {
|
||||
"tags": ["workspace"],
|
||||
"summary": "Get workspace storage usage",
|
||||
"description": "Get storage usage information for the user's workspace.",
|
||||
"operationId": "getWorkspaceGet workspace storage usage",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/StorageUsageResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
}
|
||||
},
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/health": {
|
||||
"get": {
|
||||
"tags": ["health"],
|
||||
@@ -7768,6 +7847,14 @@
|
||||
"required": ["file"],
|
||||
"title": "Body_postV2Upload submission media"
|
||||
},
|
||||
"Body_postWorkspaceUpload_file_to_workspace": {
|
||||
"properties": {
|
||||
"file": { "type": "string", "format": "binary", "title": "File" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["file"],
|
||||
"title": "Body_postWorkspaceUpload file to workspace"
|
||||
},
|
||||
"BulkMoveAgentsRequest": {
|
||||
"properties": {
|
||||
"agent_ids": {
|
||||
@@ -11167,6 +11254,9 @@
|
||||
"operation_in_progress",
|
||||
"input_validation_error",
|
||||
"web_fetch",
|
||||
"browser_navigate",
|
||||
"browser_act",
|
||||
"browser_screenshot",
|
||||
"bash_exec",
|
||||
"feature_request_search",
|
||||
"feature_request_created",
|
||||
@@ -11592,6 +11682,17 @@
|
||||
"type": "object",
|
||||
"title": "Stats"
|
||||
},
|
||||
"StorageUsageResponse": {
|
||||
"properties": {
|
||||
"used_bytes": { "type": "integer", "title": "Used Bytes" },
|
||||
"limit_bytes": { "type": "integer", "title": "Limit Bytes" },
|
||||
"used_percent": { "type": "number", "title": "Used Percent" },
|
||||
"file_count": { "type": "integer", "title": "File Count" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["used_bytes", "limit_bytes", "used_percent", "file_count"],
|
||||
"title": "StorageUsageResponse"
|
||||
},
|
||||
"StoreAgent": {
|
||||
"properties": {
|
||||
"slug": { "type": "string", "title": "Slug" },
|
||||
@@ -12039,6 +12140,17 @@
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Context"
|
||||
},
|
||||
"file_ids": {
|
||||
"anyOf": [
|
||||
{
|
||||
"items": { "type": "string" },
|
||||
"type": "array",
|
||||
"maxItems": 20
|
||||
},
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "File Ids"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
@@ -13620,24 +13732,6 @@
|
||||
"required": ["timezone"],
|
||||
"title": "UpdateTimezoneRequest"
|
||||
},
|
||||
"UploadFileResponse": {
|
||||
"properties": {
|
||||
"file_uri": { "type": "string", "title": "File Uri" },
|
||||
"file_name": { "type": "string", "title": "File Name" },
|
||||
"size": { "type": "integer", "title": "Size" },
|
||||
"content_type": { "type": "string", "title": "Content Type" },
|
||||
"expires_in_hours": { "type": "integer", "title": "Expires In Hours" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"file_uri",
|
||||
"file_name",
|
||||
"size",
|
||||
"content_type",
|
||||
"expires_in_hours"
|
||||
],
|
||||
"title": "UploadFileResponse"
|
||||
},
|
||||
"UserHistoryResponse": {
|
||||
"properties": {
|
||||
"history": {
|
||||
@@ -13966,6 +14060,36 @@
|
||||
"url"
|
||||
],
|
||||
"title": "Webhook"
|
||||
},
|
||||
"backend__api__features__workspace__routes__UploadFileResponse": {
|
||||
"properties": {
|
||||
"file_id": { "type": "string", "title": "File Id" },
|
||||
"name": { "type": "string", "title": "Name" },
|
||||
"path": { "type": "string", "title": "Path" },
|
||||
"mime_type": { "type": "string", "title": "Mime Type" },
|
||||
"size_bytes": { "type": "integer", "title": "Size Bytes" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["file_id", "name", "path", "mime_type", "size_bytes"],
|
||||
"title": "UploadFileResponse"
|
||||
},
|
||||
"backend__api__model__UploadFileResponse": {
|
||||
"properties": {
|
||||
"file_uri": { "type": "string", "title": "File Uri" },
|
||||
"file_name": { "type": "string", "title": "File Name" },
|
||||
"size": { "type": "integer", "title": "Size" },
|
||||
"content_type": { "type": "string", "title": "Content Type" },
|
||||
"expires_in_hours": { "type": "integer", "title": "Expires In Hours" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"file_uri",
|
||||
"file_name",
|
||||
"size",
|
||||
"content_type",
|
||||
"expires_in_hours"
|
||||
],
|
||||
"title": "UploadFileResponse"
|
||||
}
|
||||
},
|
||||
"securitySchemes": {
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
import { environment } from "@/services/environment";
|
||||
import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
|
||||
import { NextRequest, NextResponse } from "next/server";
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
try {
|
||||
const formData = await request.formData();
|
||||
const sessionId = request.nextUrl.searchParams.get("session_id");
|
||||
|
||||
const token = await getServerAuthToken();
|
||||
const backendUrl = environment.getAGPTServerBaseUrl();
|
||||
|
||||
const uploadUrl = new URL("/api/workspace/files/upload", backendUrl);
|
||||
if (sessionId) {
|
||||
uploadUrl.searchParams.set("session_id", sessionId);
|
||||
}
|
||||
|
||||
const headers: Record<string, string> = {};
|
||||
if (token && token !== "no-token-found") {
|
||||
headers["Authorization"] = `Bearer ${token}`;
|
||||
}
|
||||
|
||||
const response = await fetch(uploadUrl.toString(), {
|
||||
method: "POST",
|
||||
headers,
|
||||
body: formData,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
return new NextResponse(errorText, {
|
||||
status: response.status,
|
||||
});
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
return NextResponse.json(data);
|
||||
} catch (error) {
|
||||
console.error("File upload proxy error:", error);
|
||||
return NextResponse.json(
|
||||
{
|
||||
error: "Failed to upload file",
|
||||
detail: error instanceof Error ? error.message : String(error),
|
||||
},
|
||||
{ status: 500 },
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user