Merge branch 'dev' into ntindle/open-3018-google-drive-file-inputs-are-non-chainable-on-new-builder

This commit is contained in:
Nicholas Tindle
2026-03-03 21:47:09 -06:00
committed by GitHub
52 changed files with 5963 additions and 342 deletions

1
.nvmrc Normal file
View File

@@ -0,0 +1 @@
22

View File

@@ -1,2 +1,3 @@
*.ignore.*
*.ign.*
*.ign.*
.application.logs

View File

@@ -95,7 +95,7 @@ ENV DEBIAN_FRONTEND=noninteractive
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
# for the bash_exec MCP tool.
# for the bash_exec MCP tool (fallback when E2B is not configured).
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.13 \

View File

@@ -2,6 +2,7 @@
import asyncio
import logging
import re
from collections.abc import AsyncGenerator
from typing import Annotated
from uuid import uuid4
@@ -9,7 +10,8 @@ from uuid import uuid4
from autogpt_libs import auth
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, Field
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
@@ -47,10 +49,14 @@ from backend.copilot.tools.models import (
UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
from backend.data.workspace import get_or_create_workspace
from backend.util.exceptions import NotFoundError
config = ChatConfig()
_UUID_RE = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
)
logger = logging.getLogger(__name__)
@@ -79,6 +85,9 @@ class StreamChatRequest(BaseModel):
message: str
is_user_message: bool = True
context: dict[str, str] | None = None # {url: str, content: str}
file_ids: list[str] | None = Field(
default=None, max_length=20
) # Workspace file IDs attached to this message
class CreateSessionResponse(BaseModel):
@@ -238,6 +247,18 @@ async def delete_session(
detail=f"Session {session_id} not found or access denied",
)
# Best-effort cleanup of the E2B sandbox (if any).
config = ChatConfig()
if config.use_e2b_sandbox and config.e2b_api_key:
from backend.copilot.tools.e2b_sandbox import kill_sandbox
try:
await kill_sandbox(session_id, config.e2b_api_key)
except Exception:
logger.warning(
"[E2B] Failed to kill sandbox for session %s", session_id[:12]
)
return Response(status_code=204)
@@ -394,6 +415,38 @@ async def stream_chat_post(
},
)
# Enrich message with file metadata if file_ids are provided.
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
sanitized_file_ids: list[str] | None = None
if request.file_ids and user_id:
# Filter to valid UUIDs only to prevent DB abuse
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
if valid_ids:
workspace = await get_or_create_workspace(user_id)
# Batch query instead of N+1
files = await UserWorkspaceFile.prisma().find_many(
where={
"id": {"in": valid_ids},
"workspaceId": workspace.id,
"isDeleted": False,
}
)
# Only keep IDs that actually exist in the user's workspace
sanitized_file_ids = [wf.id for wf in files] or None
file_lines: list[str] = [
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
for wf in files
]
if file_lines:
files_block = (
"\n\n[Attached files]\n"
+ "\n".join(file_lines)
+ "\nUse read_workspace_file with the file_id to access file contents."
)
request.message += files_block
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message re-fetches inside a lock to prevent
@@ -445,6 +498,7 @@ async def stream_chat_post(
turn_id=turn_id,
is_user_message=request.is_user_message,
context=request.context,
file_ids=sanitized_file_ids,
)
setup_time = (time.perf_counter() - stream_start_time) * 1000

View File

@@ -0,0 +1,160 @@
"""Tests for chat route file_ids validation and enrichment."""
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from backend.api.features.chat import routes as chat_routes
app = fastapi.FastAPI()
app.include_router(chat_routes.router)
client = fastapi.testclient.TestClient(app)
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
    """Install the mocked JWT payload dependency for every test in this module."""
    from autogpt_libs.auth.jwt_utils import get_jwt_payload

    override = mock_jwt_user["get_jwt_payload"]
    app.dependency_overrides[get_jwt_payload] = override
    yield
    # Drop the override so auth state never leaks between tests.
    app.dependency_overrides.clear()
# ---- file_ids Pydantic validation (B1) ----
def test_stream_chat_rejects_too_many_file_ids():
    """More than 20 file_ids should be rejected by Pydantic validation (422)."""
    too_many = [f"00000000-0000-0000-0000-{i:012d}" for i in range(21)]
    resp = client.post(
        "/sessions/sess-1/stream",
        json={"message": "hello", "file_ids": too_many},
    )
    assert resp.status_code == 422
def _mock_stream_internals(mocker: pytest_mock.MockFixture):
    """Mock the async internals of stream_chat_post so tests can exercise
    validation and enrichment logic without needing Redis/RabbitMQ."""
    routes = "backend.api.features.chat.routes"
    # Plain no-op stubs for everything the endpoint awaits or calls.
    for target in (
        "_validate_and_get_session",
        "append_and_save_message",
        "enqueue_copilot_turn",
        "track_user_message",
    ):
        mocker.patch(f"{routes}.{target}", return_value=None)
    # The stream registry needs an awaitable create_session.
    registry = mocker.MagicMock()
    registry.create_session = mocker.AsyncMock(return_value=None)
    mocker.patch(f"{routes}.stream_registry", registry)
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
    """Exactly 20 file_ids should be accepted (not rejected by validation)."""
    _mock_stream_internals(mocker)
    # Patch workspace lookup as imported by the routes module.
    mocker.patch(
        "backend.api.features.chat.routes.get_or_create_workspace",
        return_value=type("W", (), {"id": "ws-1"})(),
    )
    prisma_stub = mocker.MagicMock()
    prisma_stub.find_many = mocker.AsyncMock(return_value=[])
    mocker.patch("prisma.models.UserWorkspaceFile.prisma", return_value=prisma_stub)

    ids = [f"00000000-0000-0000-0000-{i:012d}" for i in range(20)]
    resp = client.post(
        "/sessions/sess-1/stream",
        json={"message": "hello", "file_ids": ids},
    )
    # Should get past validation — 200 streaming response expected.
    assert resp.status_code == 200
# ---- UUID format filtering ----
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
    """Non-UUID strings in file_ids should be silently filtered out
    and NOT passed to the database query."""
    _mock_stream_internals(mocker)
    mocker.patch(
        "backend.api.features.chat.routes.get_or_create_workspace",
        return_value=type("W", (), {"id": "ws-1"})(),
    )
    prisma_stub = mocker.MagicMock()
    prisma_stub.find_many = mocker.AsyncMock(return_value=[])
    mocker.patch("prisma.models.UserWorkspaceFile.prisma", return_value=prisma_stub)

    valid_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
    bad_ids = ["not-a-uuid", "../../../etc/passwd", ""]
    client.post(
        "/sessions/sess-1/stream",
        json={"message": "hello", "file_ids": [valid_id, *bad_ids]},
    )

    # Only the single valid UUID should reach the batch query.
    prisma_stub.find_many.assert_called_once()
    where = prisma_stub.find_many.call_args[1]["where"]
    assert where["id"]["in"] == [valid_id]
# ---- Cross-workspace file_ids ----
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
    """The batch query should scope to the user's workspace."""
    _mock_stream_internals(mocker)
    mocker.patch(
        "backend.api.features.chat.routes.get_or_create_workspace",
        return_value=type("W", (), {"id": "my-workspace-id"})(),
    )
    prisma_stub = mocker.MagicMock()
    prisma_stub.find_many = mocker.AsyncMock(return_value=[])
    mocker.patch("prisma.models.UserWorkspaceFile.prisma", return_value=prisma_stub)

    client.post(
        "/sessions/sess-1/stream",
        json={
            "message": "hi",
            "file_ids": ["aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"],
        },
    )

    where = prisma_stub.find_many.call_args[1]["where"]
    # Both workspace scoping and the soft-delete filter must be present.
    assert where["workspaceId"] == "my-workspace-id"
    assert where["isDeleted"] is False

View File

@@ -3,15 +3,29 @@ Workspace API routes for managing user file storage.
"""
import logging
import os
import re
from typing import Annotated
from urllib.parse import quote
import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi import Query, UploadFile
from fastapi.responses import Response
from pydantic import BaseModel
from backend.data.workspace import WorkspaceFile, get_workspace, get_workspace_file
from backend.data.workspace import (
WorkspaceFile,
count_workspace_files,
get_or_create_workspace,
get_workspace,
get_workspace_file,
get_workspace_total_size,
soft_delete_workspace_file,
)
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
from backend.util.workspace_storage import get_workspace_storage
@@ -98,6 +112,21 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
raise
class UploadFileResponse(BaseModel):
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
class StorageUsageResponse(BaseModel):
used_bytes: int
limit_bytes: int
used_percent: float
file_count: int
@router.get(
"/files/{file_id}/download",
summary="Download file by ID",
@@ -120,3 +149,120 @@ async def download_file(
raise fastapi.HTTPException(status_code=404, detail="File not found")
return await _create_file_download_response(file)
@router.post(
    "/files/upload",
    summary="Upload file to workspace",
)
async def upload_file(
    user_id: Annotated[str, fastapi.Security(get_user_id)],
    file: UploadFile,
    session_id: str | None = Query(default=None),
) -> UploadFileResponse:
    """
    Upload a file to the user's workspace.

    Files are stored in session-scoped paths when session_id is provided,
    so the agent's session-scoped tools can discover them automatically.

    Raises:
        HTTPException(413): when the file exceeds the per-file size limit
            or would push the workspace over its storage quota.
    """
    config = Config()

    # Sanitize filename — normalize Windows-style separators before taking
    # the basename so names like "..\\..\\evil.txt" cannot carry directory
    # components through os.path.basename on a POSIX host.
    raw_name = (file.filename or "upload").replace("\\", "/")
    filename = os.path.basename(raw_name) or "upload"

    # Read file content with early abort on size limit.
    max_file_bytes = config.max_file_size_mb * 1024 * 1024
    chunks: list[bytes] = []
    total_size = 0
    while chunk := await file.read(64 * 1024):  # 64KB chunks
        total_size += len(chunk)
        if total_size > max_file_bytes:
            raise fastapi.HTTPException(
                status_code=413,
                detail=f"File exceeds maximum size of {config.max_file_size_mb} MB",
            )
        chunks.append(chunk)
    content = b"".join(chunks)

    # Get or create workspace.
    workspace = await get_or_create_workspace(user_id)

    # Pre-write storage cap check (soft check — final enforcement is post-write).
    storage_limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
    current_usage = await get_workspace_total_size(workspace.id)
    if storage_limit_bytes and current_usage + len(content) > storage_limit_bytes:
        used_percent = (current_usage / storage_limit_bytes) * 100
        raise fastapi.HTTPException(
            status_code=413,
            detail={
                "message": "Storage limit exceeded",
                "used_bytes": current_usage,
                "limit_bytes": storage_limit_bytes,
                "used_percent": round(used_percent, 1),
            },
        )

    # Warn at 80% usage so operators can see users approaching their quota.
    if (
        storage_limit_bytes
        and (usage_ratio := (current_usage + len(content)) / storage_limit_bytes) >= 0.8
    ):
        logger.warning(
            f"User {user_id} workspace storage at {usage_ratio * 100:.1f}% "
            f"({current_usage + len(content)} / {storage_limit_bytes} bytes)"
        )

    # Virus scan — raises on detection, before anything is written.
    await scan_content_safe(content, filename=filename)

    # Write file via WorkspaceManager.
    manager = WorkspaceManager(user_id, workspace.id, session_id)
    workspace_file = await manager.write_file(content, filename)

    # Post-write storage check — eliminates TOCTOU race on the quota.
    # If a concurrent upload pushed us over the limit, undo this write.
    new_total = await get_workspace_total_size(workspace.id)
    if storage_limit_bytes and new_total > storage_limit_bytes:
        await soft_delete_workspace_file(workspace_file.id, workspace.id)
        raise fastapi.HTTPException(
            status_code=413,
            detail={
                "message": "Storage limit exceeded (concurrent upload)",
                "used_bytes": new_total,
                "limit_bytes": storage_limit_bytes,
            },
        )

    return UploadFileResponse(
        file_id=workspace_file.id,
        name=workspace_file.name,
        path=workspace_file.path,
        mime_type=workspace_file.mime_type,
        size_bytes=workspace_file.size_bytes,
    )
@router.get(
    "/storage/usage",
    summary="Get workspace storage usage",
)
async def get_storage_usage(
    user_id: Annotated[str, fastapi.Security(get_user_id)],
) -> StorageUsageResponse:
    """
    Get storage usage information for the user's workspace.
    """
    config = Config()
    workspace = await get_or_create_workspace(user_id)
    used = await get_workspace_total_size(workspace.id)
    count = await count_workspace_files(workspace.id)
    limit = config.max_workspace_storage_mb * 1024 * 1024
    # Avoid ZeroDivisionError when no limit is configured.
    percent = round((used / limit) * 100, 1) if limit else 0
    return StorageUsageResponse(
        used_bytes=used,
        limit_bytes=limit,
        used_percent=percent,
        file_count=count,
    )

View File

@@ -0,0 +1,307 @@
"""Tests for workspace file upload and download routes."""
import io
from datetime import datetime, timezone
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from backend.api.features.workspace import routes as workspace_routes
from backend.data.workspace import WorkspaceFile
app = fastapi.FastAPI()
app.include_router(workspace_routes.router)
@app.exception_handler(ValueError)
async def _value_error_handler(
    request: fastapi.Request, exc: ValueError
) -> fastapi.responses.JSONResponse:
    """Mirror the production ValueError → 400 mapping from rest_api.py."""
    body = {"detail": str(exc)}
    return fastapi.responses.JSONResponse(status_code=400, content=body)
client = fastapi.testclient.TestClient(app)
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})()
_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)
MOCK_FILE = WorkspaceFile(
id="file-aaa-bbb",
workspace_id="ws-1",
created_at=_NOW,
updated_at=_NOW,
name="hello.txt",
path="/session/hello.txt",
mime_type="text/plain",
size_bytes=13,
storage_path="local://hello.txt",
)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
    """Install the mocked JWT payload dependency for every test in this module."""
    from autogpt_libs.auth.jwt_utils import get_jwt_payload

    override = mock_jwt_user["get_jwt_payload"]
    app.dependency_overrides[get_jwt_payload] = override
    yield
    # Drop the override so auth state never leaks between tests.
    app.dependency_overrides.clear()
def _upload(
    filename: str = "hello.txt",
    content: bytes = b"Hello, world!",
    content_type: str = "text/plain",
):
    """Helper to POST a file upload."""
    upload = {"file": (filename, io.BytesIO(content), content_type)}
    return client.post("/files/upload?session_id=sess-1", files=upload)
# ---- Happy path ----
def test_upload_happy_path(mocker: pytest_mock.MockFixture):
    """A small clean file should upload and echo its metadata back."""
    routes = "backend.api.features.workspace.routes"
    mocker.patch(f"{routes}.get_or_create_workspace", return_value=MOCK_WORKSPACE)
    mocker.patch(f"{routes}.get_workspace_total_size", return_value=0)
    mocker.patch(f"{routes}.scan_content_safe", return_value=None)
    manager = mocker.MagicMock()
    manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
    mocker.patch(f"{routes}.WorkspaceManager", return_value=manager)

    response = _upload()

    assert response.status_code == 200
    body = response.json()
    assert body["file_id"] == "file-aaa-bbb"
    assert body["name"] == "hello.txt"
    assert body["size_bytes"] == 13
# ---- Per-file size limit ----
def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
    """Files larger than max_file_size_mb should be rejected with 413."""
    cfg = mocker.patch("backend.api.features.workspace.routes.Config")
    cfg.return_value.max_file_size_mb = 0  # 0 MB → any content is too big
    cfg.return_value.max_workspace_storage_mb = 500

    assert _upload(content=b"x" * 1024).status_code == 413
# ---- Storage quota exceeded ----
def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
    """An upload that starts beyond the quota is refused up front."""
    routes = "backend.api.features.workspace.routes"
    mocker.patch(f"{routes}.get_or_create_workspace", return_value=MOCK_WORKSPACE)
    # Current usage already sits exactly at the 500 MB limit.
    mocker.patch(f"{routes}.get_workspace_total_size", return_value=500 * 1024 * 1024)

    response = _upload()

    assert response.status_code == 413
    assert "Storage limit exceeded" in response.text
# ---- Post-write quota race (B2) ----
def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
    """If a concurrent upload tips the total over the limit after write,
    the file should be soft-deleted and 413 returned."""
    routes = "backend.api.features.workspace.routes"
    mocker.patch(f"{routes}.get_or_create_workspace", return_value=MOCK_WORKSPACE)
    # First (pre-write) size check passes; second (post-write) is over limit.
    mocker.patch(
        f"{routes}.get_workspace_total_size",
        side_effect=[0, 600 * 1024 * 1024],
    )
    mocker.patch(f"{routes}.scan_content_safe", return_value=None)
    manager = mocker.MagicMock()
    manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
    mocker.patch(f"{routes}.WorkspaceManager", return_value=manager)
    deleter = mocker.patch(f"{routes}.soft_delete_workspace_file", return_value=None)

    response = _upload()

    assert response.status_code == 413
    deleter.assert_called_once_with("file-aaa-bbb", "ws-1")
# ---- Any extension accepted (no allowlist) ----
def test_upload_any_extension(mocker: pytest_mock.MockFixture):
    """Any file extension should be accepted — ClamAV is the security layer."""
    routes = "backend.api.features.workspace.routes"
    mocker.patch(f"{routes}.get_or_create_workspace", return_value=MOCK_WORKSPACE)
    mocker.patch(f"{routes}.get_workspace_total_size", return_value=0)
    mocker.patch(f"{routes}.scan_content_safe", return_value=None)
    manager = mocker.MagicMock()
    manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
    mocker.patch(f"{routes}.WorkspaceManager", return_value=manager)

    assert _upload(filename="data.xyz", content=b"arbitrary").status_code == 200
# ---- Virus scan rejection ----
def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
    """Files flagged by ClamAV should be rejected and never written to storage."""
    from backend.api.features.store.exceptions import VirusDetectedError

    routes = "backend.api.features.workspace.routes"
    mocker.patch(f"{routes}.get_or_create_workspace", return_value=MOCK_WORKSPACE)
    mocker.patch(f"{routes}.get_workspace_total_size", return_value=0)
    mocker.patch(
        f"{routes}.scan_content_safe",
        side_effect=VirusDetectedError("Eicar-Test-Signature"),
    )
    manager = mocker.MagicMock()
    manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
    mocker.patch(f"{routes}.WorkspaceManager", return_value=manager)

    response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")

    assert response.status_code == 400
    assert "Virus detected" in response.text
    manager.write_file.assert_not_called()
# ---- No file extension ----
def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
    """Files without an extension should be accepted and stored as-is."""
    routes = "backend.api.features.workspace.routes"
    mocker.patch(f"{routes}.get_or_create_workspace", return_value=MOCK_WORKSPACE)
    mocker.patch(f"{routes}.get_workspace_total_size", return_value=0)
    mocker.patch(f"{routes}.scan_content_safe", return_value=None)
    manager = mocker.MagicMock()
    manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
    mocker.patch(f"{routes}.WorkspaceManager", return_value=manager)

    response = _upload(
        filename="Makefile",
        content=b"all:\n\techo hello",
        content_type="application/octet-stream",
    )

    assert response.status_code == 200
    manager.write_file.assert_called_once()
    # Second positional arg to write_file is the stored filename.
    assert manager.write_file.call_args[0][1] == "Makefile"
# ---- Filename sanitization (SF5) ----
def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
    """Path-traversal filenames should be reduced to their basename."""
    routes = "backend.api.features.workspace.routes"
    mocker.patch(f"{routes}.get_or_create_workspace", return_value=MOCK_WORKSPACE)
    mocker.patch(f"{routes}.get_workspace_total_size", return_value=0)
    mocker.patch(f"{routes}.scan_content_safe", return_value=None)
    manager = mocker.MagicMock()
    manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
    mocker.patch(f"{routes}.WorkspaceManager", return_value=manager)

    # Filename with traversal components.
    _upload(filename="../../etc/passwd.txt")

    # write_file should have been called with just the basename.
    manager.write_file.assert_called_once()
    assert manager.write_file.call_args[0][1] == "passwd.txt"
# ---- Download ----
def test_download_file_not_found(mocker: pytest_mock.MockFixture):
    """Unknown file IDs yield a 404 from the download endpoint."""
    routes = "backend.api.features.workspace.routes"
    mocker.patch(f"{routes}.get_workspace", return_value=MOCK_WORKSPACE)
    mocker.patch(f"{routes}.get_workspace_file", return_value=None)

    assert client.get("/files/some-file-id/download").status_code == 404

View File

@@ -93,12 +93,49 @@ class ChatConfig(BaseSettings):
"history compression. Falls back to compression when unavailable.",
)
# E2B Sandbox Configuration
use_e2b_sandbox: bool = Field(
default=True,
description="Use E2B cloud sandboxes for persistent bash/python execution. "
"When enabled, bash_exec routes commands to E2B and SDK file tools "
"operate directly on the sandbox via E2B's filesystem API.",
)
e2b_api_key: str | None = Field(
default=None,
description="E2B API key. Falls back to E2B_API_KEY environment variable.",
)
e2b_sandbox_template: str = Field(
default="base",
description="E2B sandbox template to use for copilot sessions.",
)
e2b_sandbox_timeout: int = Field(
default=43200, # 12 hours — same as session_ttl
description="E2B sandbox keepalive timeout in seconds.",
)
# Extended thinking configuration for Claude models
thinking_enabled: bool = Field(
default=True,
description="Enable adaptive thinking for Claude models via OpenRouter",
)
@field_validator("use_e2b_sandbox", mode="before")
@classmethod
def get_use_e2b_sandbox(cls, v):
"""Get use_e2b_sandbox from environment if not provided."""
env_val = os.getenv("CHAT_USE_E2B_SANDBOX", "").lower()
if env_val:
return env_val in ("true", "1", "yes", "on")
return True if v is None else v
@field_validator("e2b_api_key", mode="before")
@classmethod
def get_e2b_api_key(cls, v):
"""Get E2B API key from environment if not provided."""
if v is None:
v = os.getenv("CHAT_E2B_API_KEY") or os.getenv("E2B_API_KEY")
return v
@field_validator("api_key", mode="before")
@classmethod
def get_api_key(cls, v):

View File

@@ -16,7 +16,7 @@ from prisma.types import (
)
from backend.data import db
from backend.util.json import SafeJson
from backend.util.json import SafeJson, sanitize_string
from .model import ChatMessage, ChatSession, ChatSessionInfo
@@ -101,15 +101,16 @@ async def add_chat_message(
"sequence": sequence,
}
# Add optional string fields
# Add optional string fields — sanitize to strip PostgreSQL-incompatible
# control characters (null bytes etc.) that may appear in tool outputs.
if content is not None:
data["content"] = content
data["content"] = sanitize_string(content)
if name is not None:
data["name"] = name
if tool_call_id is not None:
data["toolCallId"] = tool_call_id
if refusal is not None:
data["refusal"] = refusal
data["refusal"] = sanitize_string(refusal)
# Add optional JSON fields only when they have values
if tool_calls is not None:
@@ -170,15 +171,16 @@ async def add_chat_messages_batch(
"createdAt": now,
}
# Add optional string fields
# Add optional string fields — sanitize to strip
# PostgreSQL-incompatible control characters.
if msg.get("content") is not None:
data["content"] = msg["content"]
data["content"] = sanitize_string(msg["content"])
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = msg["refusal"]
data["refusal"] = sanitize_string(msg["refusal"])
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
@@ -312,7 +314,7 @@ async def update_tool_message_content(
"toolCallId": tool_call_id,
},
data={
"content": new_content,
"content": sanitize_string(new_content),
},
)
if result == 0:

View File

@@ -119,12 +119,12 @@ class CoPilotProcessor:
"""
from backend.util.workspace_storage import shutdown_workspace_storage
coro = shutdown_workspace_storage()
try:
future = asyncio.run_coroutine_threadsafe(
shutdown_workspace_storage(), self.execution_loop
)
future = asyncio.run_coroutine_threadsafe(coro, self.execution_loop)
future.result(timeout=5)
except Exception as e:
coro.close() # Prevent "coroutine was never awaited" warning
error_msg = str(e) or type(e).__name__
logger.warning(
f"[CoPilotExecutor] Worker {self.tid} cleanup error: {error_msg}"

View File

@@ -153,6 +153,9 @@ class CoPilotExecutionEntry(BaseModel):
context: dict[str, str] | None = None
"""Optional context for the message (e.g., {url: str, content: str})"""
file_ids: list[str] | None = None
"""Workspace file IDs attached to the user's message"""
class CancelCoPilotEvent(BaseModel):
"""Event to cancel a CoPilot operation."""
@@ -171,6 +174,7 @@ async def enqueue_copilot_turn(
turn_id: str,
is_user_message: bool = True,
context: dict[str, str] | None = None,
file_ids: list[str] | None = None,
) -> None:
"""Enqueue a CoPilot task for processing by the executor service.
@@ -181,6 +185,7 @@ async def enqueue_copilot_turn(
turn_id: Per-turn UUID for Redis stream isolation
is_user_message: Whether the message is from the user (vs system/assistant)
context: Optional context for the message (e.g., {url: str, content: str})
file_ids: Optional workspace file IDs attached to the user's message
"""
from backend.util.clients import get_async_copilot_queue
@@ -191,6 +196,7 @@ async def enqueue_copilot_turn(
message=message,
is_user_message=is_user_message,
context=context,
file_ids=file_ids,
)
queue_client = await get_async_copilot_queue()

View File

@@ -672,6 +672,16 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
async with _session_locks_mutex:
_session_locks.pop(session_id, None)
# Shut down any local browser daemon for this session (best-effort).
# Inline import required: all tool modules import ChatSession from this
# module, so any top-level import from tools.* would create a cycle.
try:
from .tools.agent_browser import close_browser_session
await close_browser_session(session_id, user_id=user_id)
except Exception as e:
logger.debug(f"Browser cleanup for session {session_id}: {e}")
return True

View File

@@ -13,6 +13,7 @@ from typing import Any
from pydantic import BaseModel, Field
from backend.util.json import dumps as json_dumps
from backend.util.truncate import truncate
logger = logging.getLogger(__name__)
@@ -150,6 +151,9 @@ class StreamToolInputAvailable(StreamBaseResponse):
)
_MAX_TOOL_OUTPUT_SIZE = 100_000 # ~100 KB; truncate to avoid bloating SSE/DB
class StreamToolOutputAvailable(StreamBaseResponse):
"""Tool execution result."""
@@ -164,6 +168,10 @@ class StreamToolOutputAvailable(StreamBaseResponse):
default=True, description="Whether the tool execution succeeded"
)
def model_post_init(self, __context: Any) -> None:
"""Truncate oversized outputs after construction."""
self.output = truncate(self.output, _MAX_TOOL_OUTPUT_SIZE)
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-spec fields."""
data = {

View File

@@ -0,0 +1,360 @@
"""MCP file-tool handlers that route to the E2B cloud sandbox.
When E2B is active, these tools replace the SDK built-in Read/Write/Edit/
Glob/Grep so that all file operations share the same ``/home/user``
filesystem as ``bash_exec``.
SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
"""
from __future__ import annotations
import itertools
import json
import logging
import os
import shlex
from typing import Any, Callable
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
logger = logging.getLogger(__name__)
# Lazy imports to break circular dependency with tool_adapter.
def _get_sandbox():  # type: ignore[return]
    """Return the current session's E2B sandbox (or None if there isn't one).

    Imported lazily: tool_adapter imports this module, so a top-level
    import would create a cycle.
    """
    from .tool_adapter import get_current_sandbox  # noqa: E402
    return get_current_sandbox()
def _is_allowed_local(path: str) -> bool:
    """Whether *path* is an SDK-internal path served from the host filesystem.

    Imported lazily to avoid the tool_adapter import cycle.
    """
    from .tool_adapter import is_allowed_local_path  # noqa: E402
    return is_allowed_local_path(path)
def _resolve_remote(path: str) -> str:
    """Normalise *path* to an absolute sandbox path under ``/home/user``.

    The sandbox is always a Linux host, so resolution must use POSIX
    semantics regardless of the platform this backend runs on —
    ``os.path`` would join/normalise with backslashes on Windows and
    defeat the containment check below.

    Raises :class:`ValueError` if the resolved path escapes the sandbox.
    """
    import posixpath

    candidate = path if posixpath.isabs(path) else posixpath.join(E2B_WORKDIR, path)
    normalized = posixpath.normpath(candidate)
    # Allow the workdir itself or anything strictly inside it; ".." escapes
    # are collapsed by normpath and then rejected by this prefix check.
    if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
        raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
    return normalized
def _mcp(text: str, *, error: bool = False) -> dict[str, Any]:
if error:
text = json.dumps({"error": text, "type": "error"})
return {"content": [{"type": "text", "text": text}], "isError": error}
def _get_sandbox_and_path(
    file_path: str,
) -> tuple[Any, str] | dict[str, Any]:
    """Common preamble: get the sandbox and resolve *file_path*.

    Returns a ``(sandbox, remote_path)`` pair on success, or a ready-made
    MCP error payload (a dict) when the sandbox is missing or the path
    escapes the workdir.
    """
    sandbox = _get_sandbox()
    if sandbox is None:
        return _mcp("No E2B sandbox available", error=True)
    try:
        resolved = _resolve_remote(file_path)
    except ValueError as exc:
        return _mcp(str(exc), error=True)
    return sandbox, resolved
# Tool handlers
async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
    """Read a file from the sandbox (or host, for SDK-internal paths),
    returning an ``offset``/``limit`` window of 1-indexed numbered lines."""
    file_path: str = args.get("file_path", "")
    offset: int = max(0, int(args.get("offset", 0)))
    limit: int = max(1, int(args.get("limit", 2000)))
    if not file_path:
        return _mcp("file_path is required", error=True)

    # SDK-internal paths (tool-results, ephemeral working dir) stay on the host.
    if _is_allowed_local(file_path):
        return _read_local(file_path, offset, limit)

    result = _get_sandbox_and_path(file_path)
    if isinstance(result, dict):
        return result
    sandbox, remote = result

    try:
        raw: bytes = await sandbox.files.read(remote, format="bytes")
        content = raw.decode("utf-8", errors="replace")
    except Exception as exc:
        return _mcp(f"Failed to read {remote}: {exc}", error=True)

    lines = content.splitlines(keepends=True)
    window = itertools.islice(lines, offset, offset + limit)
    numbered = "".join(
        f"{lineno:>6}\t{line}"
        for lineno, line in enumerate(window, start=offset + 1)
    )
    return _mcp(numbered)
async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
    """Create or overwrite a sandbox file, creating parent dirs as needed."""
    file_path: str = args.get("file_path", "")
    content: str = args.get("content", "")
    if not file_path:
        return _mcp("file_path is required", error=True)

    result = _get_sandbox_and_path(file_path)
    if isinstance(result, dict):
        return result
    sandbox, remote = result

    try:
        parent = os.path.dirname(remote)
        # The workdir itself always exists; only create nested parents.
        if parent and parent != E2B_WORKDIR:
            await sandbox.files.make_dir(parent)
        await sandbox.files.write(remote, content)
    except Exception as exc:
        return _mcp(f"Failed to write {remote}: {exc}", error=True)
    return _mcp(f"Successfully wrote to {remote}")
async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
    """Targeted text replacement in a sandbox file.

    Replaces ``old_string`` with ``new_string``. When the needle occurs
    more than once, the edit is refused unless ``replace_all`` is set.
    """
    file_path: str = args.get("file_path", "")
    old_string: str = args.get("old_string", "")
    new_string: str = args.get("new_string", "")
    replace_all: bool = args.get("replace_all", False)
    if not file_path:
        return _mcp("file_path is required", error=True)
    if not old_string:
        return _mcp("old_string is required", error=True)
    prepared = _get_sandbox_and_path(file_path)
    if isinstance(prepared, dict):
        return prepared
    sandbox, remote = prepared
    try:
        raw: bytes = await sandbox.files.read(remote, format="bytes")
        text = raw.decode("utf-8", errors="replace")
    except Exception as exc:
        return _mcp(f"Failed to read {remote}: {exc}", error=True)
    occurrences = text.count(old_string)
    if occurrences == 0:
        return _mcp(f"old_string not found in {file_path}", error=True)
    if occurrences > 1 and not replace_all:
        return _mcp(
            f"old_string appears {occurrences} times in {file_path}. "
            "Use replace_all=true or provide a more unique string.",
            error=True,
        )
    # str.replace with count=-1 replaces every occurrence; count=1 just the first.
    updated = text.replace(old_string, new_string, -1 if replace_all else 1)
    try:
        await sandbox.files.write(remote, updated)
    except Exception as exc:
        return _mcp(f"Failed to write {remote}: {exc}", error=True)
    plural = "s" if occurrences > 1 else ""
    return _mcp(f"Edited {remote} ({occurrences} replacement{plural})")
async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
    """Find files by name pattern on the sandbox using ``find``.

    Results are capped at 500 paths and returned as a JSON array.
    """
    pattern: str = args.get("pattern", "")
    path: str = args.get("path", "")
    if not pattern:
        return _mcp("pattern is required", error=True)
    sandbox = _get_sandbox()
    if sandbox is None:
        return _mcp("No E2B sandbox available", error=True)
    try:
        search_dir = _resolve_remote(path) if path else E2B_WORKDIR
    except ValueError as exc:
        return _mcp(str(exc), error=True)
    # shlex.quote prevents shell injection via pattern or directory.
    find_cmd = (
        f"find {shlex.quote(search_dir)} -name {shlex.quote(pattern)} "
        "-type f 2>/dev/null | head -500"
    )
    try:
        completed = await sandbox.commands.run(find_cmd, cwd=E2B_WORKDIR, timeout=10)
    except Exception as exc:
        return _mcp(f"Glob failed: {exc}", error=True)
    stdout = completed.stdout or ""
    matches = [entry for entry in stdout.strip().splitlines() if entry]
    return _mcp(json.dumps(matches, indent=2))
async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
    """Search file contents on the sandbox with recursive grep.

    Output is capped at 200 matching lines.
    """
    pattern: str = args.get("pattern", "")
    path: str = args.get("path", "")
    include: str = args.get("include", "")
    if not pattern:
        return _mcp("pattern is required", error=True)
    sandbox = _get_sandbox()
    if sandbox is None:
        return _mcp("No E2B sandbox available", error=True)
    try:
        search_dir = _resolve_remote(path) if path else E2B_WORKDIR
    except ValueError as exc:
        return _mcp(str(exc), error=True)
    argv = ["grep", "-rn", "--color=never"]
    if include:
        argv += ["--include", include]
    argv += [pattern, search_dir]
    # Each argv token is quoted individually to prevent shell injection.
    grep_cmd = " ".join(shlex.quote(token) for token in argv) + " 2>/dev/null | head -200"
    try:
        completed = await sandbox.commands.run(grep_cmd, cwd=E2B_WORKDIR, timeout=15)
    except Exception as exc:
        return _mcp(f"Grep failed: {exc}", error=True)
    matches = (completed.stdout or "").strip()
    return _mcp(matches or "No matches found.")
# Local read (for SDK-internal paths)
def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
    """Read from the host filesystem (defence-in-depth path check).

    Re-validates the allowlist even though callers are expected to have
    checked already, then returns line-numbered text for the requested
    ``offset``/``limit`` window.
    """
    if not _is_allowed_local(file_path):
        return _mcp(f"Path not allowed: {file_path}", error=True)
    resolved = os.path.realpath(os.path.expanduser(file_path))
    try:
        with open(resolved) as handle:
            window = list(itertools.islice(handle, offset, offset + limit))
    except FileNotFoundError:
        return _mcp(f"File not found: {file_path}", error=True)
    except Exception as exc:
        return _mcp(f"Error reading {file_path}: {exc}", error=True)
    numbered = "".join(
        f"{lineno + offset + 1:>6}\t{text}" for lineno, text in enumerate(window)
    )
    return _mcp(numbered)
# Tool descriptors: (tool_name, description, JSON-schema input spec, handler).
# Each entry is registered as an MCP tool when the E2B sandbox is active;
# the handlers above operate on the sandbox filesystem rooted at /home/user.
E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
    # read_file — paginated line-oriented reads from the sandbox.
    (
        "read_file",
        "Read a file from the cloud sandbox (/home/user). "
        "Use offset and limit for large files.",
        {
            "type": "object",
            "properties": {
                "file_path": {
                    "type": "string",
                    "description": "Path (relative to /home/user, or absolute).",
                },
                "offset": {
                    "type": "integer",
                    "description": "Line to start reading from (0-indexed). Default: 0.",
                },
                "limit": {
                    "type": "integer",
                    "description": "Number of lines to read. Default: 2000.",
                },
            },
            "required": ["file_path"],
        },
        _handle_read_file,
    ),
    # write_file — create/overwrite; parent dirs are auto-created by the handler.
    (
        "write_file",
        "Write or create a file in the cloud sandbox (/home/user). "
        "Parent directories are created automatically.",
        {
            "type": "object",
            "properties": {
                "file_path": {
                    "type": "string",
                    "description": "Path (relative to /home/user, or absolute).",
                },
                "content": {"type": "string", "description": "Content to write."},
            },
            "required": ["file_path", "content"],
        },
        _handle_write_file,
    ),
    # edit_file — exact-string replacement; refuses ambiguous multi-match edits
    # unless replace_all is set.
    (
        "edit_file",
        "Targeted text replacement in a sandbox file. "
        "old_string must appear in the file and is replaced with new_string.",
        {
            "type": "object",
            "properties": {
                "file_path": {
                    "type": "string",
                    "description": "Path (relative to /home/user, or absolute).",
                },
                "old_string": {"type": "string", "description": "Text to find."},
                "new_string": {"type": "string", "description": "Replacement text."},
                "replace_all": {
                    "type": "boolean",
                    "description": "Replace all occurrences (default: false).",
                },
            },
            "required": ["file_path", "old_string", "new_string"],
        },
        _handle_edit_file,
    ),
    # glob — filename-pattern search via `find` on the sandbox.
    (
        "glob",
        "Search for files by name pattern in the cloud sandbox.",
        {
            "type": "object",
            "properties": {
                "pattern": {
                    "type": "string",
                    "description": "Glob pattern (e.g. *.py).",
                },
                "path": {
                    "type": "string",
                    "description": "Directory to search. Default: /home/user.",
                },
            },
            "required": ["pattern"],
        },
        _handle_glob,
    ),
    # grep — content regex search via recursive grep on the sandbox.
    (
        "grep",
        "Search file contents by regex in the cloud sandbox.",
        {
            "type": "object",
            "properties": {
                "pattern": {"type": "string", "description": "Regex pattern."},
                "path": {
                    "type": "string",
                    "description": "File or directory. Default: /home/user.",
                },
                "include": {
                    "type": "string",
                    "description": "Glob to filter files (e.g. *.py).",
                },
            },
            "required": ["pattern"],
        },
        _handle_grep,
    ),
]
# Flat list of tool names, used by callers to build SDK allow/deny lists.
E2B_FILE_TOOL_NAMES: list[str] = [name for name, *_ in E2B_FILE_TOOLS]

View File

@@ -0,0 +1,153 @@
"""Tests for E2B file-tool path validation and local read safety.
Pure unit tests with no external dependencies (no E2B, no sandbox).
"""
import os
import pytest
from .e2b_file_tools import _read_local, _resolve_remote
from .tool_adapter import _current_project_dir
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# ---------------------------------------------------------------------------
# _resolve_remote — sandbox path normalisation & boundary enforcement
# ---------------------------------------------------------------------------
class TestResolveRemote:
    """Sandbox path normalisation and boundary enforcement for ``_resolve_remote``."""

    def _expect_blocked(self, candidate: str) -> None:
        """Assert that *candidate* is rejected by the sandbox boundary check."""
        with pytest.raises(ValueError, match="must be within /home/user"):
            _resolve_remote(candidate)

    def test_relative_path_resolved(self):
        assert _resolve_remote("src/main.py") == "/home/user/src/main.py"

    def test_absolute_within_sandbox(self):
        assert _resolve_remote("/home/user/file.txt") == "/home/user/file.txt"

    def test_workdir_itself(self):
        assert _resolve_remote("/home/user") == "/home/user"

    def test_relative_dotslash(self):
        assert _resolve_remote("./README.md") == "/home/user/README.md"

    def test_traversal_blocked(self):
        self._expect_blocked("../../etc/passwd")

    def test_absolute_traversal_blocked(self):
        self._expect_blocked("/home/user/../../etc/passwd")

    def test_absolute_outside_sandbox_blocked(self):
        self._expect_blocked("/etc/passwd")

    def test_root_blocked(self):
        self._expect_blocked("/")

    def test_home_other_user_blocked(self):
        self._expect_blocked("/home/other/file.txt")

    def test_deep_nested_allowed(self):
        assert _resolve_remote("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"

    def test_trailing_slash_normalised(self):
        assert _resolve_remote("src/") == "/home/user/src"

    def test_double_dots_within_sandbox_ok(self):
        """A path that resolves back within /home/user is allowed."""
        assert _resolve_remote("a/b/../c.txt") == "/home/user/a/c.txt"
# ---------------------------------------------------------------------------
# _read_local — host filesystem reads with allowlist enforcement
#
# In E2B mode, _read_local only allows tool-results paths (via
# is_allowed_local_path without sdk_cwd). Regular files live on the
# sandbox, not the host.
# ---------------------------------------------------------------------------
class TestReadLocal:
    """Host-filesystem reads are restricted to the session allowlist."""

    def _make_tool_results_file(self, encoded: str, filename: str, content: str) -> str:
        """Create a tool-results file under the encoded project dir; return its path."""
        results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
        os.makedirs(results_dir, exist_ok=True)
        target = os.path.join(results_dir, filename)
        with open(target, "w") as fh:
            fh.write(content)
        return target

    def test_read_tool_results_file(self):
        """Reading a tool-results file should succeed."""
        project = "-tmp-copilot-e2b-test-read"
        target = self._make_tool_results_file(
            project, "result.txt", "line 1\nline 2\nline 3\n"
        )
        token = _current_project_dir.set(project)
        try:
            response = _read_local(target, offset=0, limit=2000)
            body = response["content"][0]["text"]
            assert response["isError"] is False
            assert "line 1" in body
            assert "line 2" in body
        finally:
            _current_project_dir.reset(token)
            os.unlink(target)

    def test_read_disallowed_path_blocked(self):
        """Reading /etc/passwd should be blocked by the allowlist."""
        response = _read_local("/etc/passwd", offset=0, limit=10)
        assert response["isError"] is True
        assert "not allowed" in response["content"][0]["text"].lower()

    def test_read_nonexistent_tool_results(self):
        """A tool-results path that doesn't exist returns FileNotFoundError."""
        project = "-tmp-copilot-e2b-test-nofile"
        results_dir = os.path.join(_SDK_PROJECTS_DIR, project, "tool-results")
        os.makedirs(results_dir, exist_ok=True)
        missing = os.path.join(results_dir, "nonexistent.txt")
        token = _current_project_dir.set(project)
        try:
            response = _read_local(missing, offset=0, limit=10)
            assert response["isError"] is True
            assert "not found" in response["content"][0]["text"].lower()
        finally:
            _current_project_dir.reset(token)
            os.rmdir(results_dir)

    def test_read_traversal_path_blocked(self):
        """A traversal attempt that escapes allowed directories is blocked."""
        response = _read_local("/tmp/copilot-abc/../../etc/shadow", offset=0, limit=10)
        assert response["isError"] is True
        assert "not allowed" in response["content"][0]["text"].lower()

    def test_read_arbitrary_host_path_blocked(self):
        """Arbitrary host paths are blocked even if they exist."""
        response = _read_local("/proc/self/environ", offset=0, limit=10)
        assert response["isError"] is True

    def test_read_with_offset_and_limit(self):
        """Offset and limit should control which lines are returned."""
        project = "-tmp-copilot-e2b-test-offset"
        body = "".join(f"line {i}\n" for i in range(10))
        target = self._make_tool_results_file(project, "lines.txt", body)
        token = _current_project_dir.set(project)
        try:
            response = _read_local(target, offset=3, limit=2)
            assert response["isError"] is False
            window = response["content"][0]["text"]
            assert "line 3" in window
            assert "line 4" in window
            assert "line 2" not in window
            assert "line 5" not in window
        finally:
            _current_project_dir.reset(token)
            os.unlink(target)

    def test_read_without_project_dir_blocks_all(self):
        """Without _current_project_dir set, all paths are blocked."""
        response = _read_local("/tmp/anything.txt", offset=0, limit=10)
        assert response["isError"] is True

View File

@@ -6,7 +6,6 @@ ensuring multi-user isolation and preventing unauthorized operations.
import json
import logging
import os
import re
from collections.abc import Callable
from typing import Any, cast
@@ -16,6 +15,7 @@ from .tool_adapter import (
DANGEROUS_PATTERNS,
MCP_TOOL_PREFIX,
WORKSPACE_SCOPED_TOOLS,
is_allowed_local_path,
stash_pending_tool_output,
)
@@ -38,40 +38,20 @@ def _validate_workspace_path(
) -> dict[str, Any]:
"""Validate that a workspace-scoped tool only accesses allowed paths.
Allowed directories:
Delegates to :func:`is_allowed_local_path` which permits:
- The SDK working directory (``/tmp/copilot-<session>/``)
- The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
- The current session's tool-results directory
(``~/.claude/projects/<encoded-cwd>/tool-results/``)
"""
path = tool_input.get("file_path") or tool_input.get("path") or ""
if not path:
# Glob/Grep without a path default to cwd which is already sandboxed
return {}
# Resolve relative paths against sdk_cwd (the SDK sets cwd so the LLM
# naturally uses relative paths like "test.txt" instead of absolute ones).
# Tilde paths (~/) are home-dir references, not relative — expand first.
if path.startswith("~"):
resolved = os.path.realpath(os.path.expanduser(path))
elif not os.path.isabs(path) and sdk_cwd:
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
else:
resolved = os.path.realpath(path)
# Allow access within the SDK working directory
if sdk_cwd:
norm_cwd = os.path.realpath(sdk_cwd)
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
return {}
# Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
claude_dir = os.path.realpath(os.path.expanduser("~/.claude/projects"))
tool_results_seg = os.sep + "tool-results" + os.sep
if resolved.startswith(claude_dir + os.sep) and tool_results_seg in resolved:
if is_allowed_local_path(path, sdk_cwd):
return {}
logger.warning(
f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
)
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
return _deny(
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "

View File

@@ -120,17 +120,31 @@ def test_read_no_cwd_denies_absolute():
def test_read_tool_results_allowed():
from .tool_adapter import _current_project_dir
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert result == {}
# is_allowed_local_path requires the session's encoded cwd to be set
token = _current_project_dir.set("-tmp-copilot-abc123")
try:
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert result == {}
finally:
_current_project_dir.reset(token)
def test_read_claude_projects_without_tool_results_denied():
def test_read_claude_projects_session_dir_allowed():
"""Files within the current session's project dir are allowed."""
from .tool_adapter import _current_project_dir
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert _is_denied(result)
token = _current_project_dir.set("-tmp-copilot-abc123")
try:
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert not _is_denied(result)
finally:
_current_project_dir.reset(token)
# -- Built-in Bash is blocked (use bash_exec MCP tool instead) ---------------

View File

@@ -5,18 +5,28 @@ import base64
import json
import logging
import os
import shutil
import sys
import uuid
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any, cast
import openai
from claude_agent_sdk import (
AssistantMessage,
ClaudeAgentOptions,
ClaudeSDKClient,
ResultMessage,
ToolUseBlock,
)
from langfuse import propagate_attributes
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
from backend.data.redis_client import get_redis_async
from backend.executor.cluster_lock import AsyncClusterLock
from backend.util.exceptions import NotFoundError
from backend.util.prompt import compress_context
from backend.util.settings import Settings
from ..config import ChatConfig
@@ -42,14 +52,15 @@ from ..service import (
_generate_session_title,
_is_langfuse_configured,
)
from ..tools.e2b_sandbox import get_or_create_sandbox
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .tool_adapter import (
COPILOT_TOOL_NAMES,
SDK_DISALLOWED_TOOLS,
create_copilot_mcp_server,
get_copilot_tool_names,
get_sdk_disallowed_tools,
set_execution_context,
wait_for_stash,
)
@@ -148,9 +159,36 @@ _HEARTBEAT_INTERVAL = 3.0 # seconds
# Appended to the system prompt to inform the agent about available tools.
# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead,
# which has kernel-level network isolation (unshare --net).
def _build_sdk_tool_supplement(cwd: str) -> str:
"""Build the SDK tool supplement with the actual working directory injected."""
return f"""
_SHARED_TOOL_NOTES = """\
### Sharing files with the user
After saving a file to the persistent workspace with `write_workspace_file`,
share it with the user by embedding the `download_url` from the response in
your message as a Markdown link or image:
- **Any file** — shows as a clickable download link:
`[report.csv](workspace://file_id#text/csv)`
- **Image** — renders inline in chat:
`![chart](workspace://file_id#image/png)`
- **Video** — renders inline in chat with player controls:
`![recording](workspace://file_id#video/mp4)`
The `download_url` field in the `write_workspace_file` response is already
in the correct format — paste it directly after the `(` in the Markdown.
### Long-running tools
Long-running tools (create_agent, edit_agent, etc.) are handled
asynchronously. You will receive an immediate response; the actual result
is delivered to the user via a background stream.
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
All tasks must run in the foreground.
"""
_LOCAL_TOOL_SUPPLEMENT = (
"""
## Tool notes
@@ -192,31 +230,57 @@ When you create or modify important files (code, configs, outputs), you MUST:
1. Save them using `write_workspace_file` so they persist
2. At the start of a new turn, call `list_workspace_files` to see what files
are available from previous turns
### Sharing files with the user
After saving a file to the persistent workspace with `write_workspace_file`,
share it with the user by embedding the `download_url` from the response in
your message as a Markdown link or image:
- **Any file** — shows as a clickable download link:
`[report.csv](workspace://file_id#text/csv)`
- **Image** — renders inline in chat:
`![chart](workspace://file_id#image/png)`
- **Video** — renders inline in chat with player controls:
`![recording](workspace://file_id#video/mp4)`
The `download_url` field in the `write_workspace_file` response is already
in the correct format — paste it directly after the `(` in the Markdown.
### Long-running tools
Long-running tools (create_agent, edit_agent, etc.) are handled
asynchronously. You will receive an immediate response; the actual result
is delivered to the user via a background stream.
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
All tasks must run in the foreground.
"""
+ _SHARED_TOOL_NOTES
)
_E2B_TOOL_SUPPLEMENT = (
"""
## Tool notes
### Shell commands
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
for shell commands — it runs in a cloud sandbox with full internet access.
### Working directory
- Your working directory is: `/home/user` (cloud sandbox)
- All file tools (`read_file`, `write_file`, `edit_file`, `glob`, `grep`)
AND `bash_exec` operate on the **same cloud sandbox filesystem**.
- Files created by `bash_exec` are immediately visible to `read_file` and
vice-versa — they share one filesystem.
- Use relative paths (resolved from `/home/user`) or absolute paths.
### Two storage systems — CRITICAL to understand
1. **Cloud sandbox** (`/home/user`):
- Shared by all file tools AND `bash_exec` — same filesystem
- Files **persist across turns** within the current session
- Full Linux environment with internet access
- Lost when the session expires (12 h inactivity)
2. **Persistent workspace** (cloud storage):
- Files here **survive across sessions indefinitely**
- Use `write_workspace_file` to save important files permanently
- Use `read_workspace_file` to retrieve previously saved files
- Use `list_workspace_files` to see what files you've saved before
- Call `list_workspace_files(include_all_sessions=True)` to see files from
all sessions
### Moving files between sandbox and persistent storage
- **Sandbox → Persistent**: Use `write_workspace_file` with `source_path`
to copy from the sandbox to permanent storage
- **Persistent → Sandbox**: Use `read_workspace_file` with `save_to_path`
to download into the sandbox for processing
### File persistence workflow
Important files that must survive beyond this session should be saved with
`write_workspace_file`. Sandbox files persist across turns but are lost
when the session expires.
"""
+ _SHARED_TOOL_NOTES
)
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
@@ -291,8 +355,6 @@ def _cleanup_sdk_tool_results(cwd: str) -> None:
Security: *cwd* MUST be created by ``_make_sdk_cwd()`` which sanitizes
the session_id.
"""
import shutil
normalized = os.path.normpath(cwd)
if not normalized.startswith(_SDK_CWD_PREFIX):
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
@@ -324,8 +386,6 @@ async def _compress_conversation_history(
if len(messages) < 2:
return messages
from backend.util.prompt import compress_context
# Convert ChatMessages to dicts for compress_context
messages_dict = []
for msg in messages:
@@ -339,8 +399,6 @@ async def _compress_conversation_history(
messages_dict.append(msg_dict)
try:
import openai
async with openai.AsyncOpenAI(
api_key=config.api_key, base_url=config.base_url, timeout=30.0
) as client:
@@ -550,6 +608,7 @@ async def stream_chat_completion_sdk(
message_id = str(uuid.uuid4())
stream_id = str(uuid.uuid4())
stream_completed = False
e2b_sandbox = None
use_resume = False
resume_file: str | None = None
captured_transcript = CapturedTranscript()
@@ -579,7 +638,7 @@ async def stream_chat_completion_sdk(
# OTEL context manager — initialized inside the try and cleaned up in finally.
_otel_ctx: Any = None
# Make sure there is no more code between the lock acquitition and try-block.
# Make sure there is no more code between the lock acquisition and try-block.
try:
# Build system prompt (reuses non-SDK path with Langfuse support).
# Pre-compute the cwd here so the exact working directory path can be
@@ -597,17 +656,47 @@ async def stream_chat_completion_sdk(
code="sdk_cwd_error",
)
return
# Set up E2B sandbox for persistent cloud execution when configured.
# When active, MCP file tools route directly to the sandbox filesystem
# so bash_exec and file tools share the same /home/user directory.
if config.use_e2b_sandbox and not config.e2b_api_key:
logger.warning(
"[E2B] [%s] E2B sandbox enabled but no API key configured "
"(CHAT_E2B_API_KEY / E2B_API_KEY) — falling back to bubblewrap",
session_id[:12],
)
if config.use_e2b_sandbox and config.e2b_api_key:
try:
e2b_sandbox = await get_or_create_sandbox(
session_id,
api_key=config.e2b_api_key,
template=config.e2b_sandbox_template,
timeout=config.e2b_sandbox_timeout,
)
except Exception as e2b_err:
logger.error(
"[E2B] [%s] Setup failed: %s",
session_id[:12],
e2b_err,
exc_info=True,
)
e2b_sandbox = None
use_e2b = e2b_sandbox is not None
system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=has_history
)
system_prompt += _build_sdk_tool_supplement(sdk_cwd)
system_prompt += (
_E2B_TOOL_SUPPLEMENT
if use_e2b
else _LOCAL_TOOL_SUPPLEMENT.format(cwd=sdk_cwd)
)
yield StreamStart(messageId=message_id, sessionId=session_id)
set_execution_context(user_id, session)
set_execution_context(user_id, session, sandbox=e2b_sandbox, sdk_cwd=sdk_cwd)
try:
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
# Fail fast when no API credentials are available at all
sdk_env = _build_sdk_env()
if not sdk_env and not os.environ.get("ANTHROPIC_API_KEY"):
@@ -617,7 +706,7 @@ async def stream_chat_completion_sdk(
"or ANTHROPIC_API_KEY for direct Anthropic access."
)
mcp_server = create_copilot_mcp_server()
mcp_server = create_copilot_mcp_server(use_e2b=use_e2b)
sdk_model = _resolve_sdk_model()
@@ -678,11 +767,13 @@ async def stream_chat_completion_sdk(
f"({len(session.messages)} messages in session)"
)
allowed = get_copilot_tool_names(use_e2b=use_e2b)
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
sdk_options_kwargs: dict[str, Any] = {
"system_prompt": system_prompt,
"mcp_servers": {"copilot": mcp_server},
"allowed_tools": COPILOT_TOOL_NAMES,
"disallowed_tools": SDK_DISALLOWED_TOOLS,
"allowed_tools": allowed,
"disallowed_tools": disallowed,
"hooks": security_hooks,
"cwd": sdk_cwd,
"max_buffer_size": config.claude_agent_max_buffer_size,
@@ -827,12 +918,6 @@ async def stream_chat_completion_sdk(
# AssistantMessages (each containing only
# ToolUseBlocks), we must NOT wait/flush — the prior
# tools are still executing concurrently.
from claude_agent_sdk import (
AssistantMessage,
ResultMessage,
ToolUseBlock,
)
is_parallel_continuation = isinstance(
sdk_msg, AssistantMessage
) and all(isinstance(b, ToolUseBlock) for b in sdk_msg.content)

View File

@@ -9,20 +9,84 @@ import itertools
import json
import logging
import os
import re
import uuid
from contextvars import ContextVar
from typing import Any
from typing import TYPE_CHECKING, Any
from claude_agent_sdk import create_sdk_mcp_server, tool
from backend.copilot.model import ChatSession
from backend.copilot.tools import TOOL_REGISTRY
from backend.copilot.tools.base import BaseTool
from backend.util.truncate import truncate
from .e2b_file_tools import E2B_FILE_TOOL_NAMES, E2B_FILE_TOOLS
if TYPE_CHECKING:
from e2b import AsyncSandbox
logger = logging.getLogger(__name__)
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
# in the path — prevents reading settings, credentials, or other sensitive files.
_SDK_PROJECTS_DIR = os.path.expanduser("~/.claude/projects/")
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# Max MCP response size in chars — keeps tool output under the SDK's 10 MB JSON buffer.
_MCP_MAX_CHARS = 500_000
# Context variable holding the encoded project directory name for the current
# session (e.g. "-private-tmp-copilot-<uuid>"). Set by set_execution_context()
# so that path validation can scope tool-results reads to the current session.
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does.
The CLI replaces all non-alphanumeric characters with ``-``.
"""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
    """Check whether *path* is an allowed host-filesystem path.

    Allowed:
      - Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
      - Files under ``~/.claude/projects/<encoded-cwd>/`` — the SDK's
        project directory for this session (tool-results, transcripts, etc.)

    Both checks are scoped to the **current session** so sessions cannot
    read each other's data.
    """
    if not path:
        return False
    # Resolve relative paths against sdk_cwd; "~" is a home-dir reference,
    # not a relative path, so expand it first. realpath collapses symlinks
    # and ".." so prefix checks below can't be bypassed by traversal.
    if path.startswith("~"):
        resolved = os.path.realpath(os.path.expanduser(path))
    elif not os.path.isabs(path) and sdk_cwd:
        resolved = os.path.realpath(os.path.join(sdk_cwd, path))
    else:
        resolved = os.path.realpath(path)
    # Allow access within the SDK working directory
    if sdk_cwd:
        norm_cwd = os.path.realpath(sdk_cwd)
        # "+ os.sep" prevents sibling-prefix matches (/tmp/copilot-ab vs /tmp/copilot-abc).
        if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
            return True
    # Allow access within the current session's CLI project directory
    # (~/.claude/projects/<encoded-cwd>/). Empty when no session is active,
    # in which case everything outside sdk_cwd is denied.
    encoded = _current_project_dir.get("")
    if encoded:
        session_project = os.path.join(_SDK_PROJECTS_DIR, encoded)
        if resolved == session_project or resolved.startswith(session_project + os.sep):
            return True
    return False
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
MCP_SERVER_NAME = "copilot"
@@ -33,6 +97,12 @@ _current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
# E2B cloud sandbox for the current turn (None when E2B is not configured).
# Passed to bash_exec so commands run on E2B instead of the local bwrap sandbox.
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
"_current_sandbox", default=None
)
# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
@@ -53,22 +123,33 @@ _stash_event: ContextVar[asyncio.Event | None] = ContextVar(
def set_execution_context(
user_id: str | None,
session: ChatSession,
sandbox: "AsyncSandbox | None" = None,
sdk_cwd: str | None = None,
) -> None:
"""Set the execution context for tool calls.
This must be called before streaming begins to ensure tools have access
to user_id and session information.
to user_id, session, and (optionally) an E2B sandbox for bash execution.
Args:
user_id: Current user's ID.
session: Current chat session.
sandbox: Optional E2B sandbox; when set, bash_exec routes commands there.
sdk_cwd: SDK working directory; used to scope tool-results reads.
"""
_current_user_id.set(user_id)
_current_session.set(session)
_current_sandbox.set(sandbox)
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
_pending_tool_outputs.set({})
_stash_event.set(asyncio.Event())
def get_current_sandbox() -> "AsyncSandbox | None":
"""Return the E2B sandbox for the current turn, or None."""
return _current_sandbox.get()
def get_execution_context() -> tuple[str | None, ChatSession | None]:
"""Get the current execution context."""
return (
@@ -182,11 +263,6 @@ async def _execute_tool_sync(
result.output if isinstance(result.output, str) else json.dumps(result.output)
)
# Stash the full output before the SDK potentially truncates it.
pending = _pending_tool_outputs.get(None)
if pending is not None:
pending.setdefault(base_tool.name, []).append(text)
content_blocks: list[dict[str, str]] = [{"type": "text", "text": text}]
# If the tool result contains inline image data, add an MCP image block
@@ -284,29 +360,32 @@ def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Read a file with optional offset/limit. Restricted to SDK working directory.
"""Read a local file with optional offset/limit.
After reading, the file is deleted to prevent accumulation in long-running pods.
Only allows paths that pass :func:`is_allowed_local_path` — the current
session's tool-results directory and ephemeral working directory.
"""
file_path = args.get("file_path", "")
offset = args.get("offset", 0)
limit = args.get("limit", 2000)
# Security: only allow reads under ~/.claude/projects/**/tool-results/
real_path = os.path.realpath(file_path)
if not real_path.startswith(_SDK_PROJECTS_DIR) or "tool-results" not in real_path:
if not is_allowed_local_path(file_path):
return {
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
"isError": True,
}
resolved = os.path.realpath(os.path.expanduser(file_path))
try:
with open(real_path) as f:
with open(resolved) as f:
selected = list(itertools.islice(f, offset, offset + limit))
content = "".join(selected)
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
# don't delete here — the SDK may read in multiple chunks.
return {"content": [{"type": "text", "text": content}], "isError": False}
return {
"content": [{"type": "text", "text": content}],
"isError": False,
}
except FileNotFoundError:
return {
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
@@ -345,49 +424,82 @@ _READ_TOOL_SCHEMA = {
# Create the MCP server configuration
def create_copilot_mcp_server():
def _text_from_mcp_result(result: dict[str, Any]) -> str:
"""Extract concatenated text from an MCP response's content blocks."""
content = result.get("content", [])
if isinstance(content, list):
parts = [
b.get("text", "")
for b in content
if isinstance(b, dict) and b.get("type") == "text"
]
return "".join(parts)
return ""
def create_copilot_mcp_server(*, use_e2b: bool = False):
"""Create an in-process MCP server configuration for CoPilot tools.
This can be passed to ClaudeAgentOptions.mcp_servers.
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
package being available. This function returns the configuration that
can be used with the SDK.
When *use_e2b* is True, five additional MCP file tools are registered
that route directly to the E2B sandbox filesystem, and the caller should
disable the corresponding SDK built-in tools via
:func:`get_sdk_disallowed_tools`.
"""
try:
from claude_agent_sdk import create_sdk_mcp_server, tool
# Create decorated tool functions
sdk_tools = []
def _truncating(fn, tool_name: str):
"""Wrap a tool handler so its response is truncated to stay under the
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
response adapter before the SDK can apply its own head-truncation.
for tool_name, base_tool in TOOL_REGISTRY.items():
handler = create_tool_handler(base_tool)
decorated = tool(
tool_name,
base_tool.description,
_build_input_schema(base_tool),
)(handler)
Applied once to every registered tool."""
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
result = await fn(args)
truncated = truncate(result, _MCP_MAX_CHARS)
# Stash the text so the response adapter can forward our
# middle-out truncated version to the frontend instead of the
# SDK's head-truncated version (for outputs >~100 KB the SDK
# persists to tool-results/ with a 2 KB head-only preview).
if not truncated.get("isError"):
text = _text_from_mcp_result(truncated)
if text:
stash_pending_tool_output(tool_name, text)
return truncated
return wrapper
sdk_tools = []
for tool_name, base_tool in TOOL_REGISTRY.items():
handler = create_tool_handler(base_tool)
decorated = tool(
tool_name,
base_tool.description,
_build_input_schema(base_tool),
)(_truncating(handler, tool_name))
sdk_tools.append(decorated)
# E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.
if use_e2b:
for name, desc, schema, handler in E2B_FILE_TOOLS:
decorated = tool(name, desc, schema)(_truncating(handler, name))
sdk_tools.append(decorated)
# Add the Read tool so the SDK can read back oversized tool results
read_tool = tool(
_READ_TOOL_NAME,
_READ_TOOL_DESCRIPTION,
_READ_TOOL_SCHEMA,
)(_read_file_handler)
sdk_tools.append(read_tool)
# Read tool for SDK-truncated tool results (always needed).
read_tool = tool(
_READ_TOOL_NAME,
_READ_TOOL_DESCRIPTION,
_READ_TOOL_SCHEMA,
)(_truncating(_read_file_handler, _READ_TOOL_NAME))
sdk_tools.append(read_tool)
server = create_sdk_mcp_server(
name=MCP_SERVER_NAME,
version="1.0.0",
tools=sdk_tools,
)
return server
except ImportError:
# Let ImportError propagate so service.py handles the fallback
raise
return create_sdk_mcp_server(
name=MCP_SERVER_NAME,
version="1.0.0",
tools=sdk_tools,
)
# SDK built-in tools allowed within the workspace directory.
@@ -397,16 +509,11 @@ def create_copilot_mcp_server():
# Task allows spawning sub-agents (rate-limited by security hooks).
# WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk.
# TodoWrite manages the task checklist shown in the UI — no security concern.
_SDK_BUILTIN_TOOLS = [
"Read",
"Write",
"Edit",
"Glob",
"Grep",
"Task",
"WebSearch",
"TodoWrite",
]
# In E2B mode, all five are disabled — MCP equivalents provide direct sandbox
# access. read_file also handles local tool-results and ephemeral reads.
_SDK_BUILTIN_FILE_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep"]
_SDK_BUILTIN_ALWAYS = ["Task", "WebSearch", "TodoWrite"]
_SDK_BUILTIN_TOOLS = [*_SDK_BUILTIN_FILE_TOOLS, *_SDK_BUILTIN_ALWAYS]
# SDK built-in tools that must be explicitly blocked.
# Bash: dangerous — agent uses mcp__copilot__bash_exec with kernel-level
@@ -453,11 +560,37 @@ DANGEROUS_PATTERNS = [
r"subprocess",
]
# List of tool names for allowed_tools configuration
# Include MCP tools, the MCP Read tool for oversized results,
# and SDK built-in file tools for workspace operations.
# Static tool name list for the non-E2B case (backward compatibility).
COPILOT_TOOL_NAMES = [
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
*_SDK_BUILTIN_TOOLS,
]
def get_copilot_tool_names(*, use_e2b: bool = False) -> list[str]:
    """Build the ``allowed_tools`` list for :class:`ClaudeAgentOptions`.

    When *use_e2b* is True the SDK built-in file tools are replaced by MCP
    equivalents that route to the E2B sandbox.
    """
    if not use_e2b:
        return list(COPILOT_TOOL_NAMES)
    # E2B mode: all registry tools + the MCP Read tool + the E2B file tools,
    # then only the always-safe SDK built-ins (no SDK file tools).
    names = [f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY]
    names.append(f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}")
    names.extend(f"{MCP_TOOL_PREFIX}{name}" for name in E2B_FILE_TOOL_NAMES)
    names.extend(_SDK_BUILTIN_ALWAYS)
    return names
def get_sdk_disallowed_tools(*, use_e2b: bool = False) -> list[str]:
    """Build the ``disallowed_tools`` list for :class:`ClaudeAgentOptions`.

    When *use_e2b* is True the SDK built-in file tools are also disabled
    because MCP equivalents provide direct sandbox access.
    """
    if use_e2b:
        return [*SDK_DISALLOWED_TOOLS, *_SDK_BUILTIN_FILE_TOOLS]
    return list(SDK_DISALLOWED_TOOLS)

View File

@@ -0,0 +1,145 @@
"""Tests for tool_adapter helpers: _text_from_mcp_result, truncation stash."""
import pytest
from backend.util.truncate import truncate
from .tool_adapter import (
_MCP_MAX_CHARS,
_text_from_mcp_result,
pop_pending_tool_output,
set_execution_context,
stash_pending_tool_output,
)
# ---------------------------------------------------------------------------
# _text_from_mcp_result
# ---------------------------------------------------------------------------
class TestTextFromMcpResult:
    """Behavioral checks for ``_text_from_mcp_result`` text extraction."""

    def test_single_text_block(self):
        payload = {"content": [{"type": "text", "text": "hello"}]}
        assert _text_from_mcp_result(payload) == "hello"

    def test_multiple_text_blocks_concatenated(self):
        blocks = [
            {"type": "text", "text": "one"},
            {"type": "text", "text": "two"},
        ]
        assert _text_from_mcp_result({"content": blocks}) == "onetwo"

    def test_non_text_blocks_ignored(self):
        # Image (and other non-text) blocks contribute nothing.
        blocks = [
            {"type": "image", "data": "..."},
            {"type": "text", "text": "only this"},
        ]
        assert _text_from_mcp_result({"content": blocks}) == "only this"

    def test_empty_content_list(self):
        assert _text_from_mcp_result({"content": []}) == ""

    def test_missing_content_key(self):
        assert _text_from_mcp_result({}) == ""

    def test_non_list_content(self):
        # A raw string is not a content-block list → empty result.
        assert _text_from_mcp_result({"content": "raw string"}) == ""

    def test_missing_text_field(self):
        assert _text_from_mcp_result({"content": [{"type": "text"}]}) == ""
# ---------------------------------------------------------------------------
# stash / pop round-trip (the mechanism _truncating relies on)
# ---------------------------------------------------------------------------
class TestToolOutputStash:
    """Round-trip checks for the stash/pop mechanism ``_truncating`` relies on."""

    @pytest.fixture(autouse=True)
    def _init_context(self):
        """Initialise the context vars that stash_pending_tool_output needs."""
        set_execution_context(
            user_id="test",
            session=None,  # type: ignore[arg-type]
            sandbox=None,
            sdk_cwd="/tmp/test",
        )

    def test_stash_and_pop(self):
        stash_pending_tool_output("my_tool", "output1")
        assert pop_pending_tool_output("my_tool") == "output1"

    def test_pop_empty_returns_none(self):
        assert pop_pending_tool_output("nonexistent") is None

    def test_fifo_order(self):
        # Outputs for the same tool name come back oldest-first.
        for item in ("first", "second"):
            stash_pending_tool_output("t", item)
        assert pop_pending_tool_output("t") == "first"
        assert pop_pending_tool_output("t") == "second"
        assert pop_pending_tool_output("t") is None

    def test_dict_serialised_to_json(self):
        stash_pending_tool_output("t", {"key": "value"})
        assert pop_pending_tool_output("t") == '{"key": "value"}'

    def test_separate_tool_names(self):
        # Stashes are keyed per tool name; pop order across names is free.
        stash_pending_tool_output("a", "alpha")
        stash_pending_tool_output("b", "beta")
        assert pop_pending_tool_output("b") == "beta"
        assert pop_pending_tool_output("a") == "alpha"
# ---------------------------------------------------------------------------
# _truncating wrapper (integration via create_copilot_mcp_server)
# ---------------------------------------------------------------------------
class TestTruncationAndStashIntegration:
    """Test truncation + stash behavior that _truncating relies on."""

    @pytest.fixture(autouse=True)
    def _init_context(self):
        # stash_pending_tool_output requires an initialised execution context.
        set_execution_context(
            user_id="test",
            session=None,  # type: ignore[arg-type]
            sandbox=None,
            sdk_cwd="/tmp/test",
        )

    def test_small_output_stashed(self):
        """Non-error output is stashed for the response adapter."""
        payload = {
            "content": [{"type": "text", "text": "small output"}],
            "isError": False,
        }
        extracted = _text_from_mcp_result(truncate(payload, _MCP_MAX_CHARS))
        assert extracted == "small output"
        stash_pending_tool_output("test_tool", extracted)
        assert pop_pending_tool_output("test_tool") == "small output"

    def test_error_result_not_stashed(self):
        """Error results should not be stashed."""
        payload = {
            "content": [{"type": "text", "text": "error msg"}],
            "isError": True,
        }
        # _truncating only stashes when not result.get("isError")
        if not payload.get("isError"):
            stash_pending_tool_output("err_tool", "should not happen")
        assert pop_pending_tool_output("err_tool") is None

    def test_large_output_truncated(self):
        """Output exceeding _MCP_MAX_CHARS is truncated before stashing."""
        original = "x" * (_MCP_MAX_CHARS + 100_000)
        shrunk = truncate({"content": [{"type": "text", "text": original}]}, _MCP_MAX_CHARS)
        extracted = _text_from_mcp_result(shrunk)
        assert len(extracted) < len(original)
        assert len(str(shrunk)) <= _MCP_MAX_CHARS

View File

@@ -59,7 +59,7 @@ from .response_model import (
StreamToolOutputAvailable,
StreamUsage,
)
from .tools import execute_tool, tools
from .tools import execute_tool, get_available_tools
from .tools.models import ErrorResponse
from .tracking import track_user_message
@@ -514,7 +514,7 @@ async def stream_chat_completion(
)
async for chunk in _stream_chat_chunks(
session=session,
tools=tools,
tools=get_available_tools(),
system_prompt=system_prompt,
text_block_id=text_block_id,
):

View File

@@ -1,12 +1,14 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from openai.types.chat import ChatCompletionToolParam
from backend.copilot.model import ChatSession
from backend.copilot.tracking import track_tool_called
from .add_understanding import AddUnderstandingTool
from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreenshotTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
@@ -30,6 +32,7 @@ from .workspace_files import (
)
if TYPE_CHECKING:
from backend.copilot.model import ChatSession
from backend.copilot.response_model import StreamToolOutputAvailable
logger = logging.getLogger(__name__)
@@ -50,6 +53,10 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"get_doc_page": GetDocPageTool(),
# Web fetch for safe URL retrieval
"web_fetch": WebFetchTool(),
# Agent-browser multi-step automation (navigate, act, screenshot)
"browser_navigate": BrowserNavigateTool(),
"browser_act": BrowserActTool(),
"browser_screenshot": BrowserScreenshotTool(),
# Sandboxed code execution (bubblewrap)
"bash_exec": BashExecTool(),
# Persistent workspace tools (cloud storage, survives across sessions)
@@ -67,10 +74,17 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
find_agent_tool = TOOL_REGISTRY["find_agent"]
run_agent_tool = TOOL_REGISTRY["run_agent"]
# Generated from registry for OpenAI API
tools: list[ChatCompletionToolParam] = [
tool.as_openai_tool() for tool in TOOL_REGISTRY.values()
]
def get_available_tools() -> list[ChatCompletionToolParam]:
    """Return OpenAI tool schemas for tools available in the current environment.

    Called per-request so that env-var or binary availability is evaluated
    fresh each time (e.g. browser_* tools are excluded when agent-browser
    CLI is not installed).
    """
    usable = (t for t in TOOL_REGISTRY.values() if t.is_available)
    return [t.as_openai_tool() for t in usable]
def get_tool(tool_name: str) -> BaseTool | None:

View File

@@ -0,0 +1,876 @@
"""Agent-browser tools — multi-step browser automation for the Copilot.
Uses the agent-browser CLI (https://github.com/vercel-labs/agent-browser)
which runs a local Chromium instance managed by a persistent daemon.
- Runs locally — no cloud account required
- Full interaction support: click, fill, scroll, login flows, multi-step
- Session persistence via --session-name: cookies/auth carry across tool calls
within the same Copilot session, enabling login → navigate → extract workflows
- Screenshot with --annotate overlays @ref labels, saved to workspace for user
- The Claude Agent SDK's multi-turn loop handles orchestration — each tool call
is one browser action; the LLM chains them naturally
SSRF protection:
Uses the shared validate_url() from backend.util.request, which is the same
guard used by HTTP blocks and web_fetch. It resolves ALL DNS answers (not just
the first), blocks RFC 1918, loopback, link-local, 0.0.0.0/8, multicast,
and all relevant IPv6 ranges, and applies IDNA encoding to prevent Unicode
domain attacks.
Requires:
npm install -g agent-browser
agent-browser install (downloads Chromium, one-time per machine)
"""
import asyncio
import base64
import json
import logging
import os
import shutil
import tempfile
from typing import Any
from backend.copilot.model import ChatSession
from backend.util.request import validate_url
from .base import BaseTool
from .models import (
BrowserActResponse,
BrowserNavigateResponse,
BrowserScreenshotResponse,
ErrorResponse,
ToolResponseBase,
)
from .workspace_files import get_manager
logger = logging.getLogger(__name__)
# Per-command timeout (seconds). Navigation + networkidle wait can be slow.
# Individual call sites override this (e.g. 5 s probes, 10 s state reads).
_CMD_TIMEOUT: int = 45
# Accessibility tree can be very large; cap it to keep LLM context manageable.
_MAX_SNAPSHOT_CHARS: int = 20_000
# ---------------------------------------------------------------------------
# Subprocess helper
# ---------------------------------------------------------------------------
async def _run(
    session_name: str,
    *args: str,
    timeout: int = _CMD_TIMEOUT,
) -> tuple[int, str, str]:
    """Run agent-browser for the given session and return (rc, stdout, stderr).

    Two session flags are always passed:
      --session <name>       → isolated Chromium context (no shared history or
                               cookies with other Copilot sessions)
      --session-name <name>  → persist cookies/localStorage across tool calls
                               within the same session (login → navigate flows)

    Never raises: a timeout or a missing CLI binary is reported as a
    non-zero return code with an explanatory stderr string.
    """
    argv = [
        "agent-browser",
        "--session",
        session_name,
        "--session-name",
        session_name,
        *args,
    ]
    proc = None
    try:
        proc = await asyncio.create_subprocess_exec(
            *argv,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        out, err = await asyncio.wait_for(proc.communicate(), timeout=timeout)
    except asyncio.TimeoutError:
        # Reap the orphaned subprocess so it does not linger in the process table.
        if proc is not None and proc.returncode is None:
            proc.kill()
            try:
                await proc.communicate()
            except Exception:
                pass  # Best-effort reap; ignore errors during cleanup.
        return 1, "", f"Command timed out after {timeout}s."
    except FileNotFoundError:
        return (
            1,
            "",
            "agent-browser is not installed (run: npm install -g agent-browser && agent-browser install).",
        )
    return proc.returncode or 0, out.decode(), err.decode()
async def _snapshot(session_name: str) -> str:
    """Return the current page's interactive accessibility tree, truncated."""
    rc, out, err = await _run(session_name, "snapshot", "-i", "-c")
    if rc != 0:
        return f"[snapshot failed: {err[:300]}]"
    tree = out.strip()
    if len(tree) <= _MAX_SNAPSHOT_CHARS:
        return tree
    # Over budget: keep a head slice plus a hint for the LLM.
    suffix = "\n\n[Snapshot truncated — use browser_act to navigate further]"
    keep = max(0, _MAX_SNAPSHOT_CHARS - len(suffix))
    return tree[:keep] + suffix
# ---------------------------------------------------------------------------
# Stateless session helpers — persist / restore browser state across pods
# ---------------------------------------------------------------------------
# Module-level cache of sessions known to be alive on this pod.
# Avoids the subprocess probe on every tool call within the same pod.
# NOTE: per-pod only — on a different pod _ensure_session restores from cloud.
_alive_sessions: set[str] = set()
# Per-session locks to prevent concurrent _ensure_session calls from
# triggering duplicate _restore_browser_state for the same session.
# Protected by _session_locks_mutex to ensure setdefault/pop are not
# interleaved across await boundaries.
_session_locks: dict[str, asyncio.Lock] = {}
_session_locks_mutex = asyncio.Lock()
# Workspace filename for persisted browser state (auto-scoped to session).
# Dot-prefixed so it is hidden from user workspace listings.
_STATE_FILENAME: str = "._browser_state.json"
# Maximum concurrent subprocesses during cookie/storage restore.
_RESTORE_CONCURRENCY: int = 10
# Maximum cookies to restore per session. Pathological sites can accumulate
# thousands of cookies; restoring them all would be slow and is rarely useful.
_MAX_RESTORE_COOKIES: int = 100
# Background tasks for fire-and-forget state persistence.
# Prevents GC from collecting tasks before they complete.
_background_tasks: set[asyncio.Task] = set()
def _fire_and_forget_save(
    session_name: str, user_id: str, session: ChatSession
) -> None:
    """Schedule state persistence as a background task (non-blocking).

    Persistence is already best-effort (_save_browser_state swallows its own
    errors), so detaching it avoids adding latency to tool responses. The
    task is held in _background_tasks so GC cannot collect it mid-flight.
    """
    save_coro = _save_browser_state(session_name, user_id, session)
    task = asyncio.create_task(save_coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
async def _has_local_session(session_name: str) -> bool:
    """Check if the local agent-browser daemon for this session is running."""
    # A quick "get url" probe succeeds only when the daemon is alive.
    returncode, _, _ = await _run(session_name, "get", "url", timeout=5)
    return returncode == 0
async def _save_browser_state(
    session_name: str, user_id: str, session: ChatSession
) -> None:
    """Persist browser state (cookies, localStorage, URL) to workspace.

    Best-effort: errors are logged but never propagate to the tool response.
    """
    try:
        # Gather URL, cookies, and localStorage in parallel.
        url_res, cookie_res, storage_res = await asyncio.gather(
            _run(session_name, "get", "url", timeout=10),
            _run(session_name, "cookies", "get", "--json", timeout=10),
            _run(session_name, "storage", "local", "--json", timeout=10),
        )
        rc_url, url_out, _ = url_res
        rc_ck, ck_out, _ = cookie_res
        rc_ls, ls_out, _ = storage_res
        # Fall back to empty values for any command that failed or returned
        # nothing, matching the shape _restore_browser_state expects.
        state: dict[str, Any] = {
            "url": url_out.strip() if rc_url == 0 else "",
            "cookies": json.loads(ck_out) if rc_ck == 0 and ck_out.strip() else [],
            "local_storage": (
                json.loads(ls_out) if rc_ls == 0 and ls_out.strip() else {}
            ),
        }
        manager = await get_manager(user_id, session.session_id)
        await manager.write_file(
            content=json.dumps(state).encode("utf-8"),
            filename=_STATE_FILENAME,
            mime_type="application/json",
            overwrite=True,
        )
    except Exception:
        logger.warning(
            "[browser] Failed to save browser state for session %s",
            session_name,
            exc_info=True,
        )
async def _restore_browser_state(
    session_name: str, user_id: str, session: ChatSession
) -> bool:
    """Restore browser state from workspace storage into a fresh daemon.

    Reads the JSON blob written by :func:`_save_browser_state`, re-opens the
    saved URL (which also starts the daemon and fixes the cookie origin),
    then replays cookies and localStorage in parallel.

    Best-effort: errors are logged but never propagate to the tool response.
    Returns True on success (or no state to restore), False on failure.
    """
    try:
        manager = await get_manager(user_id, session.session_id)
        file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
        if file_info is None:
            return True  # No saved state — first call or never saved
        state_bytes = await manager.read_file(_STATE_FILENAME)
        state = json.loads(state_bytes.decode("utf-8"))
        url = state.get("url", "")
        cookies = state.get("cookies", [])
        local_storage = state.get("local_storage", {})
        # Navigate first — starts daemon + sets the correct origin for cookies
        if url:
            # Validate the saved URL to prevent SSRF via stored redirect targets.
            try:
                await validate_url(url, trusted_origins=[])
            except ValueError:
                logger.warning(
                    "[browser] State restore: blocked SSRF URL %s", url[:200]
                )
                return False
            rc, _, stderr = await _run(session_name, "open", url)
            if rc != 0:
                logger.warning(
                    "[browser] State restore: failed to open %s: %s",
                    url,
                    stderr[:200],
                )
                return False
            # Best-effort settle; failures here are non-fatal.
            await _run(session_name, "wait", "--load", "load", timeout=15)
        # Restore cookies and localStorage in parallel via asyncio.gather.
        # Semaphore caps concurrent subprocess spawns so we don't overwhelm the
        # system when a session has hundreds of cookies.
        sem = asyncio.Semaphore(_RESTORE_CONCURRENCY)
        # Guard against pathological sites with thousands of cookies.
        if len(cookies) > _MAX_RESTORE_COOKIES:
            logger.debug(
                "[browser] State restore: capping cookies from %d to %d",
                len(cookies),
                _MAX_RESTORE_COOKIES,
            )
            cookies = cookies[:_MAX_RESTORE_COOKIES]

        async def _set_cookie(c: dict[str, Any]) -> None:
            # Replay one cookie; skipped when name or domain is missing.
            name = c.get("name", "")
            value = c.get("value", "")
            domain = c.get("domain", "")
            path = c.get("path", "/")
            if not (name and domain):
                return
            async with sem:
                rc, _, stderr = await _run(
                    session_name,
                    "cookies",
                    "set",
                    name,
                    value,
                    "--domain",
                    domain,
                    "--path",
                    path,
                    timeout=5,
                )
                if rc != 0:
                    logger.debug(
                        "[browser] State restore: cookie set failed for %s: %s",
                        name,
                        stderr[:100],
                    )

        async def _set_storage(key: str, val: object) -> None:
            # Replay one localStorage entry; values are stringified.
            async with sem:
                rc, _, stderr = await _run(
                    session_name,
                    "storage",
                    "local",
                    "set",
                    key,
                    str(val),
                    timeout=5,
                )
                if rc != 0:
                    logger.debug(
                        "[browser] State restore: localStorage set failed for %s: %s",
                        key,
                        stderr[:100],
                    )

        await asyncio.gather(
            *[_set_cookie(c) for c in cookies],
            *[_set_storage(k, v) for k, v in local_storage.items()],
        )
        return True
    except Exception:
        logger.warning(
            "[browser] Failed to restore browser state for session %s",
            session_name,
            exc_info=True,
        )
        return False
async def _ensure_session(
    session_name: str, user_id: str, session: ChatSession
) -> None:
    """Ensure the local browser daemon has state. Restore from cloud if needed."""
    # Fast path: this pod already knows the session is alive.
    if session_name in _alive_sessions:
        return
    # Fetch (or create) the per-session lock under the registry mutex so
    # setdefault is not interleaved with close_browser_session's pop.
    async with _session_locks_mutex:
        session_lock = _session_locks.setdefault(session_name, asyncio.Lock())
    async with session_lock:
        # Re-check under the lock — another coroutine may have restored.
        if session_name in _alive_sessions:
            return
        alive = await _has_local_session(session_name)
        if not alive:
            alive = await _restore_browser_state(session_name, user_id, session)
        if alive:
            _alive_sessions.add(session_name)
async def close_browser_session(session_name: str, user_id: str | None = None) -> None:
    """Shut down the local agent-browser daemon and clean up stored state.

    Deletes ``._browser_state.json`` from workspace storage so cookies and
    other credentials do not linger after the session is deleted.
    Best-effort: errors are logged but never raised.
    """
    # Drop pod-local caches first so subsequent calls re-probe the daemon.
    _alive_sessions.discard(session_name)
    async with _session_locks_mutex:
        _session_locks.pop(session_name, None)
    # Delete persisted browser state (cookies, localStorage) from workspace.
    if user_id:
        try:
            manager = await get_manager(user_id, session_name)
            state_info = await manager.get_file_info_by_path(_STATE_FILENAME)
            if state_info is not None:
                await manager.delete_file(state_info.id)
        except Exception:
            logger.debug(
                "[browser] Failed to delete state file for session %s",
                session_name,
                exc_info=True,
            )
    # Finally, stop the daemon itself.
    try:
        rc, _, close_err = await _run(session_name, "close", timeout=10)
        if rc != 0:
            logger.debug(
                "[browser] close failed for session %s: %s",
                session_name,
                close_err[:200],
            )
    except Exception:
        logger.debug(
            "[browser] Exception closing browser session %s",
            session_name,
            exc_info=True,
        )
# ---------------------------------------------------------------------------
# Tool: browser_navigate
# ---------------------------------------------------------------------------
class BrowserNavigateTool(BaseTool):
    """Navigate to a URL and return the page's interactive elements.

    The browser session persists across tool calls within this Copilot session
    (keyed to session_id), so cookies and auth state carry over. This enables
    full login flows: navigate to login page → browser_act to fill credentials
    → browser_act to submit → browser_navigate to the target page.
    """

    @property
    def name(self) -> str:
        # Tool name as registered with the LLM / MCP server.
        return "browser_navigate"

    @property
    def description(self) -> str:
        # LLM-facing usage guidance. Runtime string — changing it changes
        # model behavior, not just docs.
        return (
            "Navigate to a URL using a real browser. Returns an accessibility "
            "tree snapshot listing the page's interactive elements with @ref IDs "
            "(e.g. @e3) that can be used with browser_act. "
            "Session persists — cookies and login state carry over between calls. "
            "Use this (with browser_act) for multi-step interaction: login flows, "
            "form filling, button clicks, or anything requiring page interaction. "
            "For plain static pages, prefer web_fetch — no browser overhead. "
            "For authenticated pages: navigate to the login page first, use browser_act "
            "to fill credentials and submit, then navigate to the target page. "
            "Note: for slow SPAs, the returned snapshot may reflect a partially-loaded "
            "state. If elements seem missing, use browser_act with action='wait' and a "
            "CSS selector or millisecond delay, then take a browser_screenshot to verify."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        # JSON Schema for the tool's arguments ('url' required).
        return {
            "type": "object",
            "properties": {
                "url": {
                    "type": "string",
                    "description": "The HTTP/HTTPS URL to navigate to.",
                },
                "wait_for": {
                    "type": "string",
                    "enum": ["networkidle", "load", "domcontentloaded"],
                    "default": "networkidle",
                    "description": "When to consider navigation complete. Use 'networkidle' for SPAs (default).",
                },
            },
            "required": ["url"],
        }

    @property
    def requires_auth(self) -> bool:
        # Browser sessions hold per-user state; an authenticated user is required.
        return True

    @property
    def is_available(self) -> bool:
        # Only offered when the agent-browser CLI is on PATH.
        return shutil.which("agent-browser") is not None

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Navigate to *url*, wait for the page to settle, and return a snapshot.

        The snapshot is an accessibility-tree listing of interactive elements.
        Note: for slow SPAs that never fully idle, the snapshot may reflect a
        partially-loaded state (the wait is best-effort).

        Returns a BrowserNavigateResponse on success, ErrorResponse on
        missing/blocked URL or navigation failure.
        """
        url: str = (kwargs.get("url") or "").strip()
        wait_for: str = kwargs.get("wait_for") or "networkidle"
        session_name = session.session_id
        if not url:
            return ErrorResponse(
                message="Please provide a URL to navigate to.",
                error="missing_url",
                session_id=session_name,
            )
        # SSRF guard: same validator used by HTTP blocks and web_fetch.
        try:
            await validate_url(url, trusted_origins=[])
        except ValueError as e:
            return ErrorResponse(
                message=str(e),
                error="blocked_url",
                session_id=session_name,
            )
        # Restore browser state from cloud if this is a different pod
        if user_id:
            await _ensure_session(session_name, user_id, session)
        # Navigate
        rc, _, stderr = await _run(session_name, "open", url)
        if rc != 0:
            logger.warning(
                "[browser_navigate] open failed for %s: %s", url, stderr[:300]
            )
            return ErrorResponse(
                message="Failed to navigate to URL.",
                error="navigation_failed",
                session_id=session_name,
            )
        # Wait for page to settle (best-effort: some SPAs never reach networkidle)
        wait_rc, _, wait_err = await _run(session_name, "wait", "--load", wait_for)
        if wait_rc != 0:
            logger.warning(
                "[browser_navigate] wait(%s) failed: %s", wait_for, wait_err[:300]
            )
        # Get current title and URL in parallel
        (_, title_out, _), (_, url_out, _) = await asyncio.gather(
            _run(session_name, "get", "title"),
            _run(session_name, "get", "url"),
        )
        snapshot = await _snapshot(session_name)
        result = BrowserNavigateResponse(
            message=f"Navigated to {url}",
            url=url_out.strip() or url,
            title=title_out.strip(),
            snapshot=snapshot,
            session_id=session_name,
        )
        # Persist browser state to cloud for cross-pod continuity
        if user_id:
            _fire_and_forget_save(session_name, user_id, session)
        return result
# ---------------------------------------------------------------------------
# Tool: browser_act
# ---------------------------------------------------------------------------
# Action-name groups used by BrowserActTool to validate argument combinations.
_NO_TARGET_ACTIONS: frozenset[str] = frozenset({"back", "forward", "reload"})
_SCROLL_ACTIONS: frozenset[str] = frozenset({"scroll"})
_TARGET_ONLY_ACTIONS: frozenset[str] = frozenset(
    {"click", "dblclick", "hover", "check", "uncheck"}
)
# These need both a target element and a value to enter/select.
_TARGET_VALUE_ACTIONS: frozenset[str] = frozenset({"fill", "type", "select"})
# wait <selector|ms>: waits for a DOM element or a fixed delay (e.g. "1000" for 1 s)
_WAIT_ACTIONS: frozenset[str] = frozenset({"wait"})
class BrowserActTool(BaseTool):
    """Perform an action on the current browser page and return the updated snapshot.

    Use @ref IDs from the snapshot returned by browser_navigate (e.g. '@e3').
    The LLM orchestrates multi-step flows by chaining browser_navigate and
    browser_act calls across turns of the Claude Agent SDK conversation.
    """

    @property
    def name(self) -> str:
        # Tool name as registered with the LLM / MCP server.
        return "browser_act"

    @property
    def description(self) -> str:
        # LLM-facing usage guidance. Runtime string — changing it changes
        # model behavior, not just docs.
        return (
            "Interact with the current browser page. Use @ref IDs from the "
            "snapshot (e.g. '@e3') to target elements. Returns an updated snapshot. "
            "Supported actions: click, dblclick, fill, type, scroll, hover, press, "
            "check, uncheck, select, wait, back, forward, reload. "
            "fill clears the field before typing; type appends without clearing. "
            "wait accepts a CSS selector (waits for element) or milliseconds string (e.g. '1000'). "
            "Example login flow: fill @e1 with email → fill @e2 with password → "
            "click @e3 (submit) → browser_navigate to the target page."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        # JSON Schema for the tool's arguments; only 'action' is required —
        # target/value requirements are validated per-action in _execute.
        return {
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": [
                        "click",
                        "dblclick",
                        "fill",
                        "type",
                        "scroll",
                        "hover",
                        "press",
                        "check",
                        "uncheck",
                        "select",
                        "wait",
                        "back",
                        "forward",
                        "reload",
                    ],
                    "description": "The action to perform.",
                },
                "target": {
                    "type": "string",
                    "description": (
                        "Element to target. Use @ref from snapshot (e.g. '@e3'), "
                        "a CSS selector, or a text description. "
                        "Required for: click, dblclick, fill, type, hover, check, uncheck, select. "
                        "For wait: a CSS selector to wait for, or milliseconds as a string (e.g. '1000')."
                    ),
                },
                "value": {
                    "type": "string",
                    "description": (
                        "For fill/type: the text to enter. "
                        "For press: key name (e.g. 'Enter', 'Tab', 'Control+a'). "
                        "For select: the option value to select."
                    ),
                },
                "direction": {
                    "type": "string",
                    "enum": ["up", "down", "left", "right"],
                    "default": "down",
                    "description": "For scroll: direction to scroll.",
                },
            },
            "required": ["action"],
        }

    @property
    def requires_auth(self) -> bool:
        # Browser sessions hold per-user state; an authenticated user is required.
        return True

    @property
    def is_available(self) -> bool:
        # Only offered when the agent-browser CLI is on PATH.
        return shutil.which("agent-browser") is not None

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Perform a browser action and return an updated page snapshot.

        Validates the *action*/*target*/*value* combination, delegates to
        ``agent-browser``, waits for the page to settle, and returns the
        accessibility-tree snapshot so the LLM can plan the next step.

        Returns a BrowserActResponse on success, ErrorResponse for invalid
        argument combinations or a failed action.
        """
        action: str = (kwargs.get("action") or "").strip()
        target: str = (kwargs.get("target") or "").strip()
        value: str = (kwargs.get("value") or "").strip()
        direction: str = (kwargs.get("direction") or "down").strip()
        session_name = session.session_id
        if not action:
            return ErrorResponse(
                message="Please specify an action.",
                error="missing_action",
                session_id=session_name,
            )
        # Build the agent-browser command args — each action group has its
        # own required-argument shape (see the frozensets above).
        if action in _NO_TARGET_ACTIONS:
            cmd_args = [action]
        elif action in _SCROLL_ACTIONS:
            cmd_args = ["scroll", direction]
        elif action == "press":
            if not value:
                return ErrorResponse(
                    message="'press' requires a 'value' (key name, e.g. 'Enter').",
                    error="missing_value",
                    session_id=session_name,
                )
            cmd_args = ["press", value]
        elif action in _TARGET_ONLY_ACTIONS:
            if not target:
                return ErrorResponse(
                    message=f"'{action}' requires a 'target' element.",
                    error="missing_target",
                    session_id=session_name,
                )
            cmd_args = [action, target]
        elif action in _TARGET_VALUE_ACTIONS:
            if not target or not value:
                return ErrorResponse(
                    message=f"'{action}' requires both 'target' and 'value'.",
                    error="missing_params",
                    session_id=session_name,
                )
            cmd_args = [action, target, value]
        elif action in _WAIT_ACTIONS:
            if not target:
                return ErrorResponse(
                    message=(
                        "'wait' requires a 'target': a CSS selector to wait for, "
                        "or milliseconds as a string (e.g. '1000')."
                    ),
                    error="missing_target",
                    session_id=session_name,
                )
            cmd_args = ["wait", target]
        else:
            return ErrorResponse(
                message=f"Unsupported action: {action}",
                error="invalid_action",
                session_id=session_name,
            )
        # Restore browser state from cloud if this is a different pod
        if user_id:
            await _ensure_session(session_name, user_id, session)
        rc, _, stderr = await _run(session_name, *cmd_args)
        if rc != 0:
            logger.warning("[browser_act] %s failed: %s", action, stderr[:300])
            return ErrorResponse(
                message=f"Action '{action}' failed.",
                error="action_failed",
                session_id=session_name,
            )
        # Allow the page to settle after interaction (best-effort: SPAs may not idle)
        settle_rc, _, settle_err = await _run(
            session_name, "wait", "--load", "networkidle"
        )
        if settle_rc != 0:
            logger.warning(
                "[browser_act] post-action wait failed: %s", settle_err[:300]
            )
        snapshot = await _snapshot(session_name)
        _, url_out, _ = await _run(session_name, "get", "url")
        result = BrowserActResponse(
            message=f"Performed '{action}'" + (f" on '{target}'" if target else ""),
            action=action,
            current_url=url_out.strip(),
            snapshot=snapshot,
            session_id=session_name,
        )
        # Persist browser state to cloud for cross-pod continuity
        if user_id:
            _fire_and_forget_save(session_name, user_id, session)
        return result
# ---------------------------------------------------------------------------
# Tool: browser_screenshot
# ---------------------------------------------------------------------------
class BrowserScreenshotTool(BaseTool):
    """Capture a screenshot of the current browser page and save it to the workspace."""

    @property
    def name(self) -> str:
        return "browser_screenshot"

    @property
    def description(self) -> str:
        return (
            "Take a screenshot of the current browser page and save it to the workspace. "
            "IMPORTANT: After calling this tool, immediately call read_workspace_file "
            "with the returned file_id to display the image inline to the user — "
            "the screenshot is not visible until you do this. "
            "With annotate=true (default), @ref labels are overlaid on interactive "
            "elements, making it easy to see which @ref ID maps to which element on screen."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "annotate": {
                    "type": "boolean",
                    "default": True,
                    "description": "Overlay @ref labels on interactive elements (default: true).",
                },
                "filename": {
                    "type": "string",
                    "default": "screenshot.png",
                    "description": "Filename to save in the workspace.",
                },
            },
        }

    @property
    def requires_auth(self) -> bool:
        return True

    @property
    def is_available(self) -> bool:
        # Only usable when the agent-browser CLI is installed on this host.
        return shutil.which("agent-browser") is not None

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Capture a PNG screenshot and upload it to the workspace.

        Handles string-to-bool coercion for *annotate* (OpenAI function-call
        payloads sometimes deliver ``"true"``/``"false"`` as strings).

        Returns a :class:`BrowserScreenshotResponse` with the workspace
        ``file_id`` the LLM should pass to ``read_workspace_file``, or an
        :class:`ErrorResponse` when the capture or workspace write fails.
        """
        raw_annotate = kwargs.get("annotate", True)
        if isinstance(raw_annotate, str):
            annotate = raw_annotate.strip().lower() in {"1", "true", "yes", "on"}
        else:
            annotate = bool(raw_annotate)
        # Fall back to the default name when the argument is missing, empty,
        # or whitespace-only (a bare `.strip()` could leave an empty filename).
        filename: str = (kwargs.get("filename") or "screenshot.png").strip() or "screenshot.png"
        session_name = session.session_id
        # Restore browser state from cloud if this is a different pod
        if user_id:
            await _ensure_session(session_name, user_id, session)
        tmp_fd, tmp_path = tempfile.mkstemp(suffix=".png")
        os.close(tmp_fd)
        try:
            cmd_args = ["screenshot"]
            if annotate:
                cmd_args.append("--annotate")
            cmd_args.append(tmp_path)
            rc, _, stderr = await _run(session_name, *cmd_args)
            if rc != 0:
                logger.warning("[browser_screenshot] failed: %s", stderr[:300])
                return ErrorResponse(
                    message="Failed to take screenshot.",
                    error="screenshot_failed",
                    session_id=session_name,
                )
            with open(tmp_path, "rb") as f:
                png_bytes = f.read()
        finally:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass  # Best-effort temp file cleanup; not critical if it fails.
        # Upload to workspace so the user can view it
        png_b64 = base64.b64encode(png_bytes).decode()
        # Import here to avoid circular deps — workspace_files imports from .models
        from .workspace_files import WorkspaceWriteResponse, WriteWorkspaceFileTool

        write_resp = await WriteWorkspaceFileTool()._execute(
            user_id=user_id,
            session=session,
            filename=filename,
            content_base64=png_b64,
        )
        if not isinstance(write_resp, WorkspaceWriteResponse):
            return ErrorResponse(
                message="Screenshot taken but failed to save to workspace.",
                error="workspace_write_failed",
                session_id=session_name,
            )
        # BUG FIX: the message previously embedded a literal placeholder
        # instead of the actual saved filename — interpolate `filename`.
        result = BrowserScreenshotResponse(
            message=f"Screenshot saved to workspace as '{filename}'. Use read_workspace_file with file_id='{write_resp.file_id}' to retrieve it.",
            file_id=write_resp.file_id,
            filename=filename,
            session_id=session_name,
        )
        # Persist browser state to cloud for cross-pod continuity
        if user_id:
            _fire_and_forget_save(session_name, user_id, session)
        return result

File diff suppressed because it is too large Load Diff

View File

@@ -36,6 +36,16 @@ class BaseTool:
"""Whether this tool requires authentication."""
return False
    @property
    def is_available(self) -> bool:
        """Whether this tool is available in the current environment.

        Override to check required env vars, binaries, or other dependencies.
        Unavailable tools are excluded from the LLM tool list so the model is
        never offered an option that will immediately fail.

        Returns:
            ``True`` by default; subclasses return ``False`` when a required
            dependency is missing (e.g. a CLI binary not on ``PATH``).
        """
        return True
def as_openai_tool(self) -> ChatCompletionToolParam:
"""Convert to OpenAI tool format."""
return ChatCompletionToolParam(

View File

@@ -1,19 +1,30 @@
"""Bash execution tool — run shell commands in a bubblewrap sandbox.
"""Bash execution tool — run shell commands on E2B or in a bubblewrap sandbox.
Full Bash scripting is allowed (loops, conditionals, pipes, functions, etc.).
Safety comes from OS-level isolation (bubblewrap): only system dirs visible
read-only, writable workspace only, clean env, no network.
When an E2B sandbox is available in the current execution context the command
runs directly on the remote E2B cloud environment. This means:
Requires bubblewrap (``bwrap``) — the tool is disabled when bwrap is not
available (e.g. macOS development).
- **Persistent filesystem**: files survive across turns via HTTP-based sync
with the sandbox's ``/home/user`` directory (E2B files API), shared with
SDK Read/Write/Edit tools.
- **Full internet access**: E2B sandboxes have unrestricted outbound network.
- **Execution isolation**: E2B provides a fresh, containerised Linux environment.
When E2B is *not* configured the tool falls back to **bubblewrap** (bwrap):
OS-level isolation with a whitelist-only filesystem, no network, and resource
limits. Requires bubblewrap to be installed (Linux only).
"""
import logging
import shlex
from typing import Any
from e2b import AsyncSandbox
from e2b.exceptions import TimeoutException
from backend.copilot.model import ChatSession
from .base import BaseTool
from .e2b_sandbox import E2B_WORKDIR
from .models import BashExecResponse, ErrorResponse, ToolResponseBase
from .sandbox import get_workspace_dir, has_full_sandbox, run_sandboxed
@@ -21,7 +32,7 @@ logger = logging.getLogger(__name__)
class BashExecTool(BaseTool):
"""Execute Bash commands in a bubblewrap sandbox."""
"""Execute Bash commands on E2B or in a bubblewrap sandbox."""
@property
def name(self) -> str:
@@ -29,28 +40,16 @@ class BashExecTool(BaseTool):
@property
def description(self) -> str:
if not has_full_sandbox():
return (
"Bash execution is DISABLED — bubblewrap sandbox is not "
"available on this platform. Do not call this tool."
)
return (
"Execute a Bash command or script in a bubblewrap sandbox. "
"Execute a Bash command or script. "
"Full Bash scripting is supported (loops, conditionals, pipes, "
"functions, etc.). "
"The sandbox shares the same working directory as the SDK Read/Write "
"tools — files created by either are accessible to both. "
"SECURITY: Only system directories (/usr, /bin, /lib, /etc) are "
"visible read-only, the per-session workspace is the only writable "
"path, environment variables are wiped (no secrets), all network "
"access is blocked at the kernel level, and resource limits are "
"enforced (max 64 processes, 512MB memory, 50MB file size). "
"Application code, configs, and other directories are NOT accessible. "
"To fetch web content, use the web_fetch tool instead. "
"The working directory is shared with the SDK Read/Write/Edit/Glob/Grep "
"tools — files created by either are immediately visible to both. "
"Execution is killed after the timeout (default 30s, max 120s). "
"Returns stdout and stderr. "
"Useful for file manipulation, data processing with Unix tools "
"(grep, awk, sed, jq, etc.), and running shell scripts."
"Useful for file manipulation, data processing, running scripts, "
"and installing packages."
)
@property
@@ -85,15 +84,8 @@ class BashExecTool(BaseTool):
) -> ToolResponseBase:
session_id = session.session_id if session else None
if not has_full_sandbox():
return ErrorResponse(
message="bash_exec requires bubblewrap sandbox (Linux only).",
error="sandbox_unavailable",
session_id=session_id,
)
command: str = (kwargs.get("command") or "").strip()
timeout: int = kwargs.get("timeout", 30)
timeout: int = int(kwargs.get("timeout", 30))
if not command:
return ErrorResponse(
@@ -102,6 +94,21 @@ class BashExecTool(BaseTool):
session_id=session_id,
)
# E2B path: run on remote cloud sandbox when available.
from backend.copilot.sdk.tool_adapter import get_current_sandbox
sandbox = get_current_sandbox()
if sandbox is not None:
return await self._execute_on_e2b(sandbox, command, timeout, session_id)
# Bubblewrap fallback: local isolated execution.
if not has_full_sandbox():
return ErrorResponse(
message="bash_exec requires bubblewrap sandbox (Linux only).",
error="sandbox_unavailable",
session_id=session_id,
)
workspace = get_workspace_dir(session_id or "default")
stdout, stderr, exit_code, timed_out = await run_sandboxed(
@@ -122,3 +129,43 @@ class BashExecTool(BaseTool):
timed_out=timed_out,
session_id=session_id,
)
async def _execute_on_e2b(
self,
sandbox: AsyncSandbox,
command: str,
timeout: int,
session_id: str | None,
) -> ToolResponseBase:
"""Execute *command* on the E2B sandbox via commands.run()."""
try:
result = await sandbox.commands.run(
f"bash -c {shlex.quote(command)}",
cwd=E2B_WORKDIR,
timeout=timeout,
envs={"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"},
)
return BashExecResponse(
message=f"Command executed on E2B (exit {result.exit_code})",
stdout=result.stdout or "",
stderr=result.stderr or "",
exit_code=result.exit_code,
timed_out=False,
session_id=session_id,
)
except Exception as exc:
if isinstance(exc, TimeoutException):
return BashExecResponse(
message="Execution timed out",
stdout="",
stderr=f"Timed out after {timeout}s",
exit_code=-1,
timed_out=True,
session_id=session_id,
)
logger.error("[E2B] bash_exec failed: %s", exc, exc_info=True)
return ErrorResponse(
message=f"E2B execution failed: {exc}",
error="e2b_execution_error",
session_id=session_id,
)

View File

@@ -0,0 +1,170 @@
"""E2B sandbox lifecycle for CoPilot: persistent cloud execution.
Each session gets a long-lived E2B cloud sandbox. ``bash_exec`` runs commands
directly on the sandbox via ``sandbox.commands.run()``. SDK file tools
(read_file/write_file/edit_file/glob/grep) route to the sandbox's
``/home/user`` directory via E2B's HTTP-based filesystem API — all tools
share a single coherent filesystem with no local sync required.
Lifecycle
---------
1. **Turn start** connect to the existing sandbox (sandbox_id in Redis) or
create a new one via ``get_or_create_sandbox()``.
2. **Execution** ``bash_exec`` and MCP file tools operate directly on the
sandbox's ``/home/user`` filesystem.
3. **Session expiry** E2B sandbox is killed by its own timeout (session_ttl).
"""
import asyncio
import logging
from e2b import AsyncSandbox
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
_SANDBOX_REDIS_PREFIX = "copilot:e2b:sandbox:"
E2B_WORKDIR = "/home/user"
_CREATING = "__creating__"
_CREATION_LOCK_TTL = 60
_MAX_WAIT_ATTEMPTS = 20 # 20 * 0.5s = 10s max wait
async def _try_reconnect(
    sandbox_id: str, api_key: str, redis_key: str, timeout: int
) -> "AsyncSandbox | None":
    """Attempt to re-attach to a previously created sandbox.

    On success the Redis TTL for *redis_key* is refreshed and the live
    sandbox handle is returned. Any failure — a connect error, or a sandbox
    that is no longer running — clears the Redis entry and returns ``None``
    so the caller can create a fresh sandbox.
    """
    try:
        candidate = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
        if await candidate.is_running():
            # Healthy — extend the mapping for another `timeout` seconds.
            redis = await get_redis_async()
            await redis.expire(redis_key, timeout)
            return candidate
    except Exception as exc:
        logger.warning("[E2B] Reconnect to %.12s failed: %s", sandbox_id, exc)
    # Dead or unreachable: drop the stale mapping so creation can proceed.
    redis = await get_redis_async()
    await redis.delete(redis_key)
    return None
async def get_or_create_sandbox(
    session_id: str,
    api_key: str,
    template: str = "base",
    timeout: int = 43200,
) -> AsyncSandbox:
    """Return the existing E2B sandbox for *session_id* or create a new one.

    The sandbox_id is persisted in Redis so the same sandbox is reused
    across turns. Concurrent calls for the same session are serialised
    via a Redis ``SET NX`` creation lock.

    Args:
        session_id: Chat session whose sandbox should be resolved.
        api_key: E2B API key used for connect/create calls.
        template: E2B sandbox template name (default ``"base"``).
        timeout: Sandbox lifetime and Redis-mapping TTL in seconds
            (default 43200 = 12 hours).

    Raises:
        RuntimeError: if the creation lock cannot be acquired after waiting.
    """
    redis = await get_redis_async()
    redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
    # 1. Try reconnecting to an existing sandbox.
    raw = await redis.get(redis_key)
    if raw:
        # Redis may hand back bytes or str depending on client config.
        sandbox_id = raw if isinstance(raw, str) else raw.decode()
        if sandbox_id != _CREATING:
            sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
            if sandbox:
                logger.info(
                    "[E2B] Reconnected to %.12s for session %.12s",
                    sandbox_id,
                    session_id,
                )
                return sandbox
    # 2. Claim creation lock. If another request holds it, wait for the result.
    claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
    if not claimed:
        # Another request is creating; poll until it publishes a sandbox_id,
        # the lock expires, or we exhaust _MAX_WAIT_ATTEMPTS (~10s).
        for _ in range(_MAX_WAIT_ATTEMPTS):
            await asyncio.sleep(0.5)
            raw = await redis.get(redis_key)
            if not raw:
                break  # Lock expired — fall through to retry creation
            sandbox_id = raw if isinstance(raw, str) else raw.decode()
            if sandbox_id != _CREATING:
                sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
                if sandbox:
                    return sandbox
                break  # Stale sandbox cleared — fall through to create
        # Try to claim creation lock again after waiting.
        claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
        if not claimed:
            # Another process may have created a sandbox — try to use it.
            raw = await redis.get(redis_key)
            if raw:
                sandbox_id = raw if isinstance(raw, str) else raw.decode()
                if sandbox_id != _CREATING:
                    sandbox = await _try_reconnect(
                        sandbox_id, api_key, redis_key, timeout
                    )
                    if sandbox:
                        return sandbox
            raise RuntimeError(
                f"Could not acquire E2B creation lock for session {session_id[:12]}"
            )
    # 3. Create a new sandbox.
    try:
        sandbox = await AsyncSandbox.create(
            template=template, api_key=api_key, timeout=timeout
        )
    except Exception:
        # Release the lock so a later request can retry creation.
        await redis.delete(redis_key)
        raise
    # Replace the lock value with the real sandbox_id for future turns.
    await redis.setex(redis_key, timeout, sandbox.sandbox_id)
    logger.info(
        "[E2B] Created sandbox %.12s for session %.12s",
        sandbox.sandbox_id,
        session_id,
    )
    return sandbox
async def kill_sandbox(session_id: str, api_key: str) -> bool:
    """Terminate the E2B sandbox bound to *session_id*, if one exists.

    The Redis mapping is removed as soon as it has been read. Returns
    ``True`` only when a real sandbox was found and killed; a missing
    entry, an in-flight creation marker, or an E2B API failure yields
    ``False``. Safe to call when no sandbox exists for the session.
    """
    redis = await get_redis_async()
    key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
    stored = await redis.get(key)
    if not stored:
        return False
    sb_id = stored if isinstance(stored, str) else stored.decode()
    await redis.delete(key)
    if sb_id == _CREATING:
        # Creation still in progress — no concrete sandbox to kill yet.
        return False

    async def _terminate() -> None:
        handle = await AsyncSandbox.connect(sb_id, api_key=api_key)
        await handle.kill()

    try:
        # Cap the connect+kill round-trip so a hung API cannot block us.
        await asyncio.wait_for(_terminate(), timeout=10)
        logger.info(
            "[E2B] Killed sandbox %.12s for session %.12s",
            sb_id,
            session_id,
        )
        return True
    except Exception as exc:
        logger.warning(
            "[E2B] Failed to kill sandbox %.12s for session %.12s: %s",
            sb_id,
            session_id,
            exc,
        )
        return False

View File

@@ -0,0 +1,272 @@
"""Tests for e2b_sandbox: get_or_create_sandbox, _try_reconnect, kill_sandbox.
Uses mock Redis and mock AsyncSandbox — no external dependencies.
Tests are synchronous (using asyncio.run) to avoid conflicts with the
session-scoped event loop in conftest.py.
"""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from .e2b_sandbox import (
_CREATING,
_SANDBOX_REDIS_PREFIX,
_try_reconnect,
get_or_create_sandbox,
kill_sandbox,
)
_KEY = f"{_SANDBOX_REDIS_PREFIX}sess-123"
_API_KEY = "test-api-key"
_TIMEOUT = 300
def _mock_sandbox(sandbox_id: str = "sb-abc", running: bool = True) -> MagicMock:
sb = MagicMock()
sb.sandbox_id = sandbox_id
sb.is_running = AsyncMock(return_value=running)
return sb
def _mock_redis(get_val: str | bytes | None = None, set_nx_result: bool = True):
r = AsyncMock()
r.get = AsyncMock(return_value=get_val)
r.set = AsyncMock(return_value=set_nx_result)
r.setex = AsyncMock()
r.delete = AsyncMock()
r.expire = AsyncMock()
return r
def _patch_redis(redis):
return patch(
"backend.copilot.tools.e2b_sandbox.get_redis_async",
new_callable=AsyncMock,
return_value=redis,
)
# ---------------------------------------------------------------------------
# _try_reconnect
# ---------------------------------------------------------------------------
class TestTryReconnect:
    """Unit tests for ``_try_reconnect``: happy path and stale-entry cleanup."""

    def test_reconnect_success(self):
        """A running sandbox is returned and its Redis TTL is refreshed."""
        sb = _mock_sandbox()
        redis = _mock_redis()
        with (
            patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
            _patch_redis(redis),
        ):
            mock_cls.connect = AsyncMock(return_value=sb)
            result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
            assert result is sb
            # TTL extended, mapping kept.
            redis.expire.assert_awaited_once_with(_KEY, _TIMEOUT)
            redis.delete.assert_not_awaited()

    def test_reconnect_not_running_clears_key(self):
        """A stopped sandbox yields None and clears the Redis mapping."""
        sb = _mock_sandbox(running=False)
        redis = _mock_redis()
        with (
            patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
            _patch_redis(redis),
        ):
            mock_cls.connect = AsyncMock(return_value=sb)
            result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
            assert result is None
            redis.delete.assert_awaited_once_with(_KEY)

    def test_reconnect_exception_clears_key(self):
        """A connect failure yields None and clears the Redis mapping."""
        redis = _mock_redis()
        with (
            patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
            _patch_redis(redis),
        ):
            mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
            result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
            assert result is None
            redis.delete.assert_awaited_once_with(_KEY)
# ---------------------------------------------------------------------------
# get_or_create_sandbox
# ---------------------------------------------------------------------------
class TestGetOrCreateSandbox:
    """Unit tests for ``get_or_create_sandbox``: reuse, creation, and locking."""

    def test_reconnect_existing(self):
        """When Redis has a valid sandbox_id, reconnect to it."""
        sb = _mock_sandbox()
        redis = _mock_redis(get_val="sb-abc")
        with (
            patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
            _patch_redis(redis),
        ):
            mock_cls.connect = AsyncMock(return_value=sb)
            result = asyncio.run(
                get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
            )
            assert result is sb
            # Reuse path must never create a new sandbox.
            mock_cls.create.assert_not_called()

    def test_create_new_when_no_key(self):
        """When Redis is empty, claim lock and create a new sandbox."""
        sb = _mock_sandbox("sb-new")
        redis = _mock_redis(get_val=None, set_nx_result=True)
        with (
            patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
            _patch_redis(redis),
        ):
            mock_cls.create = AsyncMock(return_value=sb)
            result = asyncio.run(
                get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
            )
            assert result is sb
            # Lock value replaced by the real sandbox_id with the full TTL.
            redis.setex.assert_awaited_once_with(_KEY, _TIMEOUT, "sb-new")

    def test_create_failure_clears_lock(self):
        """If sandbox creation fails, the Redis lock is deleted."""
        redis = _mock_redis(get_val=None, set_nx_result=True)
        with (
            patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
            _patch_redis(redis),
        ):
            mock_cls.create = AsyncMock(side_effect=RuntimeError("quota"))
            with pytest.raises(RuntimeError, match="quota"):
                asyncio.run(
                    get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
                )
            redis.delete.assert_awaited_once_with(_KEY)

    def test_wait_for_lock_then_reconnect(self):
        """When another process holds the lock, wait and reconnect."""
        sb = _mock_sandbox("sb-other")
        redis = _mock_redis()
        # First get() sees the in-progress marker; second sees the real id.
        redis.get = AsyncMock(side_effect=[_CREATING, "sb-other"])
        redis.set = AsyncMock(return_value=False)
        with (
            patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
            _patch_redis(redis),
            patch(
                "backend.copilot.tools.e2b_sandbox.asyncio.sleep",
                new_callable=AsyncMock,
            ),
        ):
            mock_cls.connect = AsyncMock(return_value=sb)
            result = asyncio.run(
                get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
            )
            assert result is sb

    def test_stale_reconnect_clears_and_creates(self):
        """When stored sandbox is stale, clear key and create a new one."""
        stale_sb = _mock_sandbox("sb-stale", running=False)
        new_sb = _mock_sandbox("sb-fresh")
        redis = _mock_redis(get_val="sb-stale", set_nx_result=True)
        with (
            patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
            _patch_redis(redis),
        ):
            mock_cls.connect = AsyncMock(return_value=stale_sb)
            mock_cls.create = AsyncMock(return_value=new_sb)
            result = asyncio.run(
                get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
            )
            assert result is new_sb
            redis.delete.assert_awaited()
# ---------------------------------------------------------------------------
# kill_sandbox
# ---------------------------------------------------------------------------
class TestKillSandbox:
    """Unit tests for ``kill_sandbox``: success, no-op, and failure paths."""

    def test_kill_existing_sandbox(self):
        """Kill a running sandbox and clean up Redis."""
        sb = _mock_sandbox()
        sb.kill = AsyncMock()
        redis = _mock_redis(get_val="sb-abc")
        with (
            patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
            _patch_redis(redis),
        ):
            mock_cls.connect = AsyncMock(return_value=sb)
            result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
            assert result is True
            redis.delete.assert_awaited_once_with(_KEY)
            sb.kill.assert_awaited_once()

    def test_kill_no_sandbox(self):
        """No-op when no sandbox exists in Redis."""
        redis = _mock_redis(get_val=None)
        with _patch_redis(redis):
            result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
            assert result is False
            redis.delete.assert_not_awaited()

    def test_kill_creating_state(self):
        """Clears Redis key but returns False when sandbox is still being created."""
        redis = _mock_redis(get_val=_CREATING)
        with _patch_redis(redis):
            result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
            assert result is False
            redis.delete.assert_awaited_once_with(_KEY)

    def test_kill_connect_failure(self):
        """Returns False and cleans Redis if connect/kill fails."""
        redis = _mock_redis(get_val="sb-abc")
        with (
            patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
            _patch_redis(redis),
        ):
            mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
            result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
            assert result is False
            redis.delete.assert_awaited_once_with(_KEY)

    def test_kill_with_bytes_redis_value(self):
        """Redis may return bytes — kill_sandbox should decode correctly."""
        sb = _mock_sandbox()
        sb.kill = AsyncMock()
        redis = _mock_redis(get_val=b"sb-abc")
        with (
            patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
            _patch_redis(redis),
        ):
            mock_cls.connect = AsyncMock(return_value=sb)
            result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
            assert result is True
            sb.kill.assert_awaited_once()

    def test_kill_timeout_returns_false(self):
        """Returns False when E2B API calls exceed the 10s timeout."""
        redis = _mock_redis(get_val="sb-abc")
        with (
            _patch_redis(redis),
            patch(
                "backend.copilot.tools.e2b_sandbox.asyncio.wait_for",
                new_callable=AsyncMock,
                side_effect=asyncio.TimeoutError,
            ),
        ):
            result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
            assert result is False
            redis.delete.assert_awaited_once_with(_KEY)

View File

@@ -41,6 +41,10 @@ class ResponseType(str, Enum):
INPUT_VALIDATION_ERROR = "input_validation_error"
# Web fetch
WEB_FETCH = "web_fetch"
# Agent-browser multi-step automation (navigate, act, screenshot)
BROWSER_NAVIGATE = "browser_navigate"
BROWSER_ACT = "browser_act"
BROWSER_SCREENSHOT = "browser_screenshot"
# Code execution
BASH_EXEC = "bash_exec"
# Feature request types
@@ -476,3 +480,32 @@ class FeatureRequestCreatedResponse(ToolResponseBase):
issue_url: str
is_new_issue: bool # False if added to existing
customer_name: str
# Agent-browser multi-step automation models
class BrowserNavigateResponse(ToolResponseBase):
    """Response for browser_navigate tool."""

    # Discriminator used for client-side rendering of this tool result.
    type: ResponseType = ResponseType.BROWSER_NAVIGATE
    # URL the browser navigated to.
    url: str
    # Page title after navigation.
    title: str
    snapshot: str  # Interactive accessibility tree with @ref IDs
class BrowserActResponse(ToolResponseBase):
    """Response for browser_act tool."""

    # Discriminator used for client-side rendering of this tool result.
    type: ResponseType = ResponseType.BROWSER_ACT
    # Echo of the action name that was performed.
    action: str
    # URL after the action; empty string when it could not be determined.
    current_url: str = ""
    snapshot: str  # Updated accessibility tree after the action
class BrowserScreenshotResponse(ToolResponseBase):
    """Response for browser_screenshot tool."""

    # Discriminator used for client-side rendering of this tool result.
    type: ResponseType = ResponseType.BROWSER_SCREENSHOT
    file_id: str  # Workspace file ID — use read_workspace_file to retrieve
    # Name the screenshot was saved under in the workspace.
    filename: str

View File

@@ -8,6 +8,7 @@ from typing import Any, Optional
from pydantic import BaseModel
from backend.copilot.model import ChatSession
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
from backend.copilot.tools.sandbox import make_session_path
from backend.data.db_accessors import workspace_db
from backend.util.settings import Config
@@ -20,7 +21,7 @@ from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
def _resolve_write_content(
async def _resolve_write_content(
content_text: str | None,
content_b64: str | None,
source_path: str | None,
@@ -30,6 +31,9 @@ def _resolve_write_content(
Returns the raw bytes on success, or an ``ErrorResponse`` on validation
failure (wrong number of sources, invalid path, file not found, etc.).
When an E2B sandbox is active, ``source_path`` reads from the sandbox
filesystem instead of the local ephemeral directory.
"""
# Normalise empty strings to None so counting and dispatch stay in sync.
if content_text is not None and content_text == "":
@@ -54,24 +58,7 @@ def _resolve_write_content(
)
if source_path is not None:
validated = _validate_ephemeral_path(
source_path, param_name="source_path", session_id=session_id
)
if isinstance(validated, ErrorResponse):
return validated
try:
with open(validated, "rb") as f:
return f.read()
except FileNotFoundError:
return ErrorResponse(
message=f"Source file not found: {source_path}",
session_id=session_id,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to read source file: {e}",
session_id=session_id,
)
return await _read_source_path(source_path, session_id)
if content_b64 is not None:
try:
@@ -91,6 +78,106 @@ def _resolve_write_content(
return content_text.encode("utf-8")
def _resolve_sandbox_path(
    path: str, session_id: str | None, param_name: str
) -> str | ErrorResponse:
    """Map *path* to an absolute sandbox location under :data:`E2B_WORKDIR`.

    Thin wrapper around ``_resolve_remote`` that converts its ``ValueError``
    (raised when the path escapes the workdir) into an ``ErrorResponse``
    naming the offending parameter.
    """
    # Local import — presumably avoids an import cycle with the SDK module.
    from backend.copilot.sdk.e2b_file_tools import _resolve_remote

    try:
        resolved = _resolve_remote(path)
    except ValueError:
        return ErrorResponse(
            message=f"{param_name} must be within {E2B_WORKDIR}",
            session_id=session_id,
        )
    return resolved
async def _read_source_path(source_path: str, session_id: str) -> bytes | ErrorResponse:
    """Read *source_path* from the E2B sandbox or the local ephemeral directory.

    Prefers the active sandbox's filesystem when one is attached to the
    current execution context; otherwise validates and reads a local path.
    Returns the file bytes, or an ``ErrorResponse`` on any failure.
    """
    from backend.copilot.sdk.tool_adapter import get_current_sandbox

    sandbox = get_current_sandbox()
    if sandbox is not None:
        # Remote branch: resolve under E2B_WORKDIR and read via the files API.
        remote = _resolve_sandbox_path(source_path, session_id, "source_path")
        if isinstance(remote, ErrorResponse):
            return remote
        try:
            payload = await sandbox.files.read(remote, format="bytes")
        except Exception as exc:
            return ErrorResponse(
                message=f"Source file not found on sandbox: {source_path} ({exc})",
                session_id=session_id,
            )
        return bytes(payload)

    # Local fallback: validate path stays within ephemeral directory.
    local = _validate_ephemeral_path(
        source_path, param_name="source_path", session_id=session_id
    )
    if isinstance(local, ErrorResponse):
        return local
    try:
        with open(local, "rb") as handle:
            return handle.read()
    except FileNotFoundError:
        return ErrorResponse(
            message=f"Source file not found: {source_path}",
            session_id=session_id,
        )
    except Exception as e:
        return ErrorResponse(
            message=f"Failed to read source file: {e}",
            session_id=session_id,
        )
async def _save_to_path(
    path: str, content: bytes, session_id: str
) -> str | ErrorResponse:
    """Persist *content* at *path*, preferring the active E2B sandbox.

    Returns the fully resolved destination path on success, or an
    ``ErrorResponse`` describing why the write failed.
    """
    from backend.copilot.sdk.tool_adapter import get_current_sandbox

    sandbox = get_current_sandbox()
    if sandbox is not None:
        # Remote branch: write through the E2B files API under E2B_WORKDIR.
        remote = _resolve_sandbox_path(path, session_id, "save_to_path")
        if isinstance(remote, ErrorResponse):
            return remote
        try:
            await sandbox.files.write(remote, content)
        except Exception as exc:
            return ErrorResponse(
                message=f"Failed to write to sandbox: {path} ({exc})",
                session_id=session_id,
            )
        return remote

    # Local branch: confine the write to the session's ephemeral directory.
    local = _validate_ephemeral_path(
        path, param_name="save_to_path", session_id=session_id
    )
    if isinstance(local, ErrorResponse):
        return local
    try:
        parent = os.path.dirname(local)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(local, "wb") as handle:
            handle.write(content)
    except Exception as exc:
        return ErrorResponse(
            message=f"Failed to write to local path: {path} ({exc})",
            session_id=session_id,
        )
    return local
def _validate_ephemeral_path(
path: str, *, param_name: str, session_id: str
) -> ErrorResponse | str:
@@ -131,7 +218,7 @@ def _is_text_mime(mime_type: str) -> bool:
return any(mime_type.startswith(t) for t in _TEXT_MIME_PREFIXES)
async def _get_manager(user_id: str, session_id: str) -> WorkspaceManager:
async def get_manager(user_id: str, session_id: str) -> WorkspaceManager:
    """Create a session-scoped WorkspaceManager.

    Resolves (or lazily creates) the user's workspace record, then binds a
    manager to that workspace id and the current chat session.
    """
    workspace = await workspace_db().get_or_create_workspace(user_id)
    return WorkspaceManager(user_id, workspace.id, session_id)
@@ -299,7 +386,7 @@ class ListWorkspaceFilesTool(BaseTool):
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
try:
manager = await _get_manager(user_id, session_id)
manager = await get_manager(user_id, session_id)
files = await manager.list_files(
path=path_prefix, limit=limit, include_all_sessions=include_all_sessions
)
@@ -429,17 +516,8 @@ class ReadWorkspaceFileTool(BaseTool):
message="Please provide either file_id or path", session_id=session_id
)
# Validate and resolve save_to_path (use sanitized real path).
if save_to_path:
validated_save = _validate_ephemeral_path(
save_to_path, param_name="save_to_path", session_id=session_id
)
if isinstance(validated_save, ErrorResponse):
return validated_save
save_to_path = validated_save
try:
manager = await _get_manager(user_id, session_id)
manager = await get_manager(user_id, session_id)
resolved = await _resolve_file(manager, file_id, path, session_id)
if isinstance(resolved, ErrorResponse):
return resolved
@@ -449,11 +527,10 @@ class ReadWorkspaceFileTool(BaseTool):
cached_content: bytes | None = None
if save_to_path:
cached_content = await manager.read_file_by_id(target_file_id)
dir_path = os.path.dirname(save_to_path)
if dir_path:
os.makedirs(dir_path, exist_ok=True)
with open(save_to_path, "wb") as f:
f.write(cached_content)
result = await _save_to_path(save_to_path, cached_content, session_id)
if isinstance(result, ErrorResponse):
return result
save_to_path = result
is_small = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES
is_text = _is_text_mime(file_info.mime_type)
@@ -629,7 +706,7 @@ class WriteWorkspaceFileTool(BaseTool):
content_text: str | None = kwargs.get("content")
content_b64: str | None = kwargs.get("content_base64")
resolved = _resolve_write_content(
resolved = await _resolve_write_content(
content_text,
content_b64,
source_path_arg,
@@ -648,7 +725,7 @@ class WriteWorkspaceFileTool(BaseTool):
try:
await scan_content_safe(content, filename=filename)
manager = await _get_manager(user_id, session_id)
manager = await get_manager(user_id, session_id)
rec = await manager.write_file(
content=content,
filename=filename,
@@ -775,7 +852,7 @@ class DeleteWorkspaceFileTool(BaseTool):
)
try:
manager = await _get_manager(user_id, session_id)
manager = await get_manager(user_id, session_id)
resolved = await _resolve_file(manager, file_id, path, session_id)
if isinstance(resolved, ErrorResponse):
return resolved

View File

@@ -102,67 +102,68 @@ class TestValidateEphemeralPath:
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
class TestResolveWriteContent:
def test_no_sources_returns_error(self):
async def test_no_sources_returns_error(self):
from backend.copilot.tools.models import ErrorResponse
result = _resolve_write_content(None, None, None, "s1")
result = await _resolve_write_content(None, None, None, "s1")
assert isinstance(result, ErrorResponse)
def test_multiple_sources_returns_error(self):
async def test_multiple_sources_returns_error(self):
from backend.copilot.tools.models import ErrorResponse
result = _resolve_write_content("text", "b64data", None, "s1")
result = await _resolve_write_content("text", "b64data", None, "s1")
assert isinstance(result, ErrorResponse)
def test_plain_text_content(self):
result = _resolve_write_content("hello world", None, None, "s1")
async def test_plain_text_content(self):
result = await _resolve_write_content("hello world", None, None, "s1")
assert result == b"hello world"
def test_base64_content(self):
async def test_base64_content(self):
raw = b"binary data"
b64 = base64.b64encode(raw).decode()
result = _resolve_write_content(None, b64, None, "s1")
result = await _resolve_write_content(None, b64, None, "s1")
assert result == raw
def test_invalid_base64_returns_error(self):
async def test_invalid_base64_returns_error(self):
from backend.copilot.tools.models import ErrorResponse
result = _resolve_write_content(None, "not-valid-b64!!!", None, "s1")
result = await _resolve_write_content(None, "not-valid-b64!!!", None, "s1")
assert isinstance(result, ErrorResponse)
assert "base64" in result.message.lower()
def test_source_path(self, ephemeral_dir):
async def test_source_path(self, ephemeral_dir):
target = ephemeral_dir / "input.txt"
target.write_bytes(b"file content")
result = _resolve_write_content(None, None, str(target), "s1")
result = await _resolve_write_content(None, None, str(target), "s1")
assert result == b"file content"
def test_source_path_not_found(self, ephemeral_dir):
async def test_source_path_not_found(self, ephemeral_dir):
from backend.copilot.tools.models import ErrorResponse
missing = str(ephemeral_dir / "nope.txt")
result = _resolve_write_content(None, None, missing, "s1")
result = await _resolve_write_content(None, None, missing, "s1")
assert isinstance(result, ErrorResponse)
def test_source_path_outside_ephemeral(self, ephemeral_dir, tmp_path):
async def test_source_path_outside_ephemeral(self, ephemeral_dir, tmp_path):
from backend.copilot.tools.models import ErrorResponse
outside = tmp_path / "outside.txt"
outside.write_text("nope")
result = _resolve_write_content(None, None, str(outside), "s1")
result = await _resolve_write_content(None, None, str(outside), "s1")
assert isinstance(result, ErrorResponse)
def test_empty_string_sources_treated_as_none(self):
async def test_empty_string_sources_treated_as_none(self):
from backend.copilot.tools.models import ErrorResponse
# All empty strings → same as no sources
result = _resolve_write_content("", "", "", "s1")
result = await _resolve_write_content("", "", "", "s1")
assert isinstance(result, ErrorResponse)
def test_empty_string_source_path_with_text(self):
async def test_empty_string_source_path_with_text(self):
# source_path="" should be normalised to None, so only content counts
result = _resolve_write_content("hello", "", "", "s1")
result = await _resolve_write_content("hello", "", "", "s1")
assert result == b"hello"

View File

@@ -327,11 +327,16 @@ async def get_workspace_total_size(workspace_id: str) -> int:
"""
Get the total size of all files in a workspace.
Queries Prisma directly (skipping Pydantic model conversion) and only
fetches the ``sizeBytes`` column to minimise data transfer.
Args:
workspace_id: The workspace ID
Returns:
Total size in bytes
"""
files = await list_workspace_files(workspace_id)
return sum(file.size_bytes for file in files)
files = await UserWorkspaceFile.prisma().find_many(
where={"workspaceId": workspace_id, "isDeleted": False},
)
return sum(f.sizeBytes for f in files)

View File

@@ -105,8 +105,13 @@ def validate_with_jsonschema(
return str(e)
def _sanitize_string(value: str) -> str:
"""Remove PostgreSQL-incompatible control characters from string."""
def sanitize_string(value: str) -> str:
    """Remove PostgreSQL-incompatible control characters from string.

    Strips \\x00-\\x08, \\x0B-\\x0C, \\x0E-\\x1F, \\x7F while keeping tab,
    newline, and carriage return. Use this before inserting free-form text
    into PostgreSQL text/varchar columns.
    """
    # POSTGRES_CONTROL_CHARS is a module-level compiled regex matching the
    # ranges listed above — presumably defined near the top of this module;
    # a single .sub() pass removes every occurrence.
    return POSTGRES_CONTROL_CHARS.sub("", value)
@@ -116,7 +121,7 @@ def sanitize_json(data: Any) -> Any:
# 1. First convert to basic JSON-serializable types (handles Pydantic models)
# 2. Then sanitize strings in the result
basic_result = to_dict(data)
return to_dict(basic_result, custom_encoder={str: _sanitize_string})
return to_dict(basic_result, custom_encoder={str: sanitize_string})
except Exception as e:
# Log the failure and fall back to string representation
logger.error(
@@ -129,7 +134,7 @@ def sanitize_json(data: Any) -> Any:
)
# Ultimate fallback: convert to string representation and sanitize
return _sanitize_string(str(data))
return sanitize_string(str(data))
class SafeJson(Json):

View File

@@ -413,6 +413,13 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="Maximum file size in MB for workspace files (1-1024 MB)",
)
max_workspace_storage_mb: int = Field(
default=500,
ge=1,
le=10240,
description="Maximum total workspace storage per user in MB.",
)
# AutoMod configuration
automod_enabled: bool = Field(
default=False,

View File

@@ -79,6 +79,10 @@ def truncate(value: Any, size_limit: int) -> Any:
largest str_limit and list_limit that fit.
"""
# Fast path: plain strings don't need the binary search machinery.
if isinstance(value, str):
return _truncate_string_middle(value, size_limit)
def measure(val):
try:
return len(str(val))
@@ -86,7 +90,7 @@ def truncate(value: Any, size_limit: int) -> Any:
return sys.getsizeof(val)
# Reasonable bounds for string and list limits
STR_MIN, STR_MAX = 8, 2**16
STR_MIN, STR_MAX = min(8, size_limit), size_limit
LIST_MIN, LIST_MAX = 1, 2**12
# Binary search for the largest str_limit and list_limit that fit

View File

@@ -162,7 +162,7 @@ services:
context: ../
dockerfile: autogpt_platform/backend/Dockerfile
target: server
command: ["python", "-m", "backend.copilot.executor"]
command: ["python", "-u", "-m", "backend.copilot.executor"]
develop:
watch:
- path: ./
@@ -182,6 +182,7 @@ services:
<<: *backend-env-files
environment:
<<: *backend-env
PYTHONUNBUFFERED: "1"
ports:
- "8008:8008"
networks:

View File

@@ -7,7 +7,9 @@ import {
DropdownMenuTrigger,
} from "@/components/molecules/DropdownMenu/DropdownMenu";
import { SidebarProvider } from "@/components/ui/sidebar";
import { DotsThree } from "@phosphor-icons/react";
import { cn } from "@/lib/utils";
import { DotsThree, UploadSimple } from "@phosphor-icons/react";
import { useCallback, useRef, useState } from "react";
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
import { DeleteChatDialog } from "./components/DeleteChatDialog/DeleteChatDialog";
@@ -17,6 +19,49 @@ import { ScaleLoader } from "./components/ScaleLoader/ScaleLoader";
import { useCopilotPage } from "./useCopilotPage";
export function CopilotPage() {
const [isDragging, setIsDragging] = useState(false);
const [droppedFiles, setDroppedFiles] = useState<File[]>([]);
const dragCounter = useRef(0);
const handleDroppedFilesConsumed = useCallback(() => {
setDroppedFiles([]);
}, []);
function handleDragEnter(e: React.DragEvent) {
  e.preventDefault();
  e.stopPropagation();
  // Enter/leave events fire for every nested element, so count the depth
  // instead of toggling a boolean directly.
  dragCounter.current += 1;
  const carriesFiles = e.dataTransfer.types.includes("Files");
  if (carriesFiles) setIsDragging(true);
}
function handleDragOver(e: React.DragEvent) {
  // preventDefault is required on dragover — otherwise the browser refuses
  // the drop and the drop event never fires on this element.
  e.preventDefault();
  e.stopPropagation();
}
function handleDragLeave(e: React.DragEvent) {
  e.preventDefault();
  e.stopPropagation();
  // Mirror of handleDragEnter: only hide the overlay once the cursor has
  // left the outermost drop target (depth counter back to zero).
  dragCounter.current -= 1;
  if (dragCounter.current !== 0) return;
  setIsDragging(false);
}
function handleDrop(e: React.DragEvent) {
  e.preventDefault();
  e.stopPropagation();
  // Reset drag tracking unconditionally — the gesture is over either way.
  dragCounter.current = 0;
  setIsDragging(false);
  const dropped = Array.from(e.dataTransfer.files);
  if (dropped.length === 0) return;
  setDroppedFiles(dropped);
}
const {
sessionId,
messages,
@@ -29,6 +74,7 @@ export function CopilotPage() {
isLoadingSession,
isSessionError,
isCreatingSession,
isUploadingFiles,
isUserLoading,
isLoggedIn,
// Mobile drawer
@@ -63,8 +109,26 @@ export function CopilotPage() {
className="h-[calc(100vh-72px)] min-h-0"
>
{!isMobile && <ChatSidebar />}
<div className="relative flex h-full w-full flex-col overflow-hidden bg-[#f8f8f9] px-0">
<div
className="relative flex h-full w-full flex-col overflow-hidden bg-[#f8f8f9] px-0"
onDragEnter={handleDragEnter}
onDragOver={handleDragOver}
onDragLeave={handleDragLeave}
onDrop={handleDrop}
>
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
{/* Drop overlay */}
<div
className={cn(
"pointer-events-none absolute inset-0 z-50 flex flex-col items-center justify-center gap-3 rounded-lg border-2 border-dashed border-violet-400 bg-violet-500/10 transition-opacity duration-150",
isDragging ? "opacity-100" : "opacity-0",
)}
>
<UploadSimple className="h-10 w-10 text-violet-500" weight="bold" />
<span className="text-lg font-medium text-violet-600">
Drop files here
</span>
</div>
<div className="flex-1 overflow-hidden">
<ChatContainer
messages={messages}
@@ -78,6 +142,9 @@ export function CopilotPage() {
onCreateSession={createSession}
onSend={onSend}
onStop={stop}
isUploadingFiles={isUploadingFiles}
droppedFiles={droppedFiles}
onDroppedFilesConsumed={handleDroppedFilesConsumed}
headerSlot={
isMobile && sessionId ? (
<div className="flex justify-end">

View File

@@ -18,9 +18,14 @@ export interface ChatContainerProps {
/** True when backend has an active stream but we haven't reconnected yet. */
isReconnecting?: boolean;
onCreateSession: () => void | Promise<string>;
onSend: (message: string) => void | Promise<void>;
onSend: (message: string, files?: File[]) => void | Promise<void>;
onStop: () => void;
isUploadingFiles?: boolean;
headerSlot?: ReactNode;
/** Files dropped onto the chat window. */
droppedFiles?: File[];
/** Called after droppedFiles have been consumed by ChatInput. */
onDroppedFilesConsumed?: () => void;
}
export const ChatContainer = ({
messages,
@@ -34,7 +39,10 @@ export const ChatContainer = ({
onCreateSession,
onSend,
onStop,
isUploadingFiles,
headerSlot,
droppedFiles,
onDroppedFilesConsumed,
}: ChatContainerProps) => {
const isBusy =
status === "streaming" ||
@@ -69,8 +77,11 @@ export const ChatContainer = ({
onSend={onSend}
disabled={isBusy}
isStreaming={isBusy}
isUploadingFiles={isUploadingFiles}
onStop={onStop}
placeholder="What else can I help with?"
droppedFiles={droppedFiles}
onDroppedFilesConsumed={onDroppedFilesConsumed}
/>
</motion.div>
</div>
@@ -80,6 +91,9 @@ export const ChatContainer = ({
isCreatingSession={isCreatingSession}
onCreateSession={onCreateSession}
onSend={onSend}
isUploadingFiles={isUploadingFiles}
droppedFiles={droppedFiles}
onDroppedFilesConsumed={onDroppedFilesConsumed}
/>
)}
</div>

View File

@@ -9,38 +9,66 @@ import {
import { InputGroup } from "@/components/ui/input-group";
import { cn } from "@/lib/utils";
import { CircleNotchIcon, MicrophoneIcon } from "@phosphor-icons/react";
import { ChangeEvent } from "react";
import { ChangeEvent, useEffect, useState } from "react";
import { AttachmentMenu } from "./components/AttachmentMenu";
import { FileChips } from "./components/FileChips";
import { RecordingIndicator } from "./components/RecordingIndicator";
import { useChatInput } from "./useChatInput";
import { useVoiceRecording } from "./useVoiceRecording";
export interface Props {
onSend: (message: string) => void | Promise<void>;
onSend: (message: string, files?: File[]) => void | Promise<void>;
disabled?: boolean;
isStreaming?: boolean;
isUploadingFiles?: boolean;
onStop?: () => void;
placeholder?: string;
className?: string;
inputId?: string;
/** Files dropped onto the chat window by the parent. */
droppedFiles?: File[];
/** Called after droppedFiles have been merged into internal state. */
onDroppedFilesConsumed?: () => void;
}
export function ChatInput({
onSend,
disabled = false,
isStreaming = false,
isUploadingFiles = false,
onStop,
placeholder = "Type your message...",
className,
inputId = "chat-input",
droppedFiles,
onDroppedFilesConsumed,
}: Props) {
const [files, setFiles] = useState<File[]>([]);
// Merge files dropped onto the chat window into internal state.
useEffect(() => {
if (droppedFiles && droppedFiles.length > 0) {
setFiles((prev) => [...prev, ...droppedFiles]);
onDroppedFilesConsumed?.();
}
}, [droppedFiles, onDroppedFilesConsumed]);
const hasFiles = files.length > 0;
const isBusy = disabled || isStreaming || isUploadingFiles;
const {
value,
setValue,
handleSubmit,
handleChange: baseHandleChange,
} = useChatInput({
onSend,
disabled: disabled || isStreaming,
onSend: async (message: string) => {
await onSend(message, hasFiles ? files : undefined);
// Only clear files after successful send (onSend throws on failure)
setFiles([]);
},
disabled: isBusy,
canSendEmpty: hasFiles,
inputId,
});
@@ -55,7 +83,7 @@ export function ChatInput({
audioStream,
} = useVoiceRecording({
setValue,
disabled: disabled || isStreaming,
disabled: isBusy,
isStreaming,
value,
inputId,
@@ -67,7 +95,18 @@ export function ChatInput({
}
const canSend =
!disabled && !!value.trim() && !isRecording && !isTranscribing;
!disabled &&
(!!value.trim() || hasFiles) &&
!isRecording &&
!isTranscribing;
function handleFilesSelected(newFiles: File[]) {
setFiles((prev) => [...prev, ...newFiles]);
}
function handleRemoveFile(index: number) {
setFiles((prev) => prev.filter((_, i) => i !== index));
}
return (
<form onSubmit={handleSubmit} className={cn("relative flex-1", className)}>
@@ -78,6 +117,11 @@ export function ChatInput({
"border-red-400 ring-1 ring-red-400 has-[[data-slot=input-group-control]:focus-visible]:border-red-400 has-[[data-slot=input-group-control]:focus-visible]:ring-red-400",
)}
>
<FileChips
files={files}
onRemove={handleRemoveFile}
isUploading={isUploadingFiles}
/>
<PromptInputBody className="relative block w-full">
<PromptInputTextarea
id={inputId}
@@ -104,6 +148,10 @@ export function ChatInput({
<PromptInputFooter>
<PromptInputTools>
<AttachmentMenu
onFilesSelected={handleFilesSelected}
disabled={isBusy}
/>
{showMicButton && (
<PromptInputButton
aria-label={isRecording ? "Stop recording" : "Start recording"}

View File

@@ -0,0 +1,55 @@
"use client";
import { Button } from "@/components/atoms/Button/Button";
import { cn } from "@/lib/utils";
import { Plus as PlusIcon } from "@phosphor-icons/react";
import { useRef } from "react";
interface Props {
onFilesSelected: (files: File[]) => void;
disabled?: boolean;
}
/**
 * Hidden multi-file `<input>` plus a "+" icon button that opens the native
 * file picker. Selected files are handed to the parent via onFilesSelected;
 * this component keeps no file state of its own.
 */
export function AttachmentMenu({ onFilesSelected, disabled }: Props) {
  const inputRef = useRef<HTMLInputElement>(null);

  const openFilePicker = () => inputRef.current?.click();

  const onInputChange = (e: React.ChangeEvent<HTMLInputElement>) => {
    const selected = Array.from(e.target.files ?? []);
    if (selected.length) onFilesSelected(selected);
    // Clear the input value so choosing the same file again still fires onChange.
    e.target.value = "";
  };

  return (
    <>
      <input
        ref={inputRef}
        type="file"
        multiple
        className="hidden"
        onChange={onInputChange}
        tabIndex={-1}
      />
      <Button
        type="button"
        variant="icon"
        size="icon"
        aria-label="Attach file"
        disabled={disabled}
        onClick={openFilePicker}
        className={cn(
          "border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
          disabled && "opacity-40",
        )}
      >
        <PlusIcon className="h-4 w-4" weight="bold" />
      </Button>
    </>
  );
}

View File

@@ -0,0 +1,45 @@
"use client";
import { cn } from "@/lib/utils";
import {
CircleNotch as CircleNotchIcon,
X as XIcon,
} from "@phosphor-icons/react";
interface Props {
files: File[];
onRemove: (index: number) => void;
isUploading?: boolean;
}
/**
 * Row of pill-shaped chips for the currently attached files. While an upload
 * is in flight each chip shows a spinner instead of its remove button.
 */
export function FileChips({ files, onRemove, isUploading }: Props) {
  // Render nothing at all when no files are attached.
  if (!files.length) return null;

  return (
    <div className="flex w-full flex-wrap gap-2 px-3 pb-2 pt-1">
      {files.map((file, idx) => {
        const spinner = (
          <CircleNotchIcon className="ml-0.5 h-3 w-3 animate-spin text-zinc-400" />
        );
        const removeButton = (
          <button
            type="button"
            aria-label={`Remove ${file.name}`}
            onClick={() => onRemove(idx)}
            className="ml-0.5 rounded-full p-0.5 text-zinc-400 transition-colors hover:bg-zinc-200 hover:text-zinc-600"
          >
            <XIcon className="h-3 w-3" weight="bold" />
          </button>
        );
        return (
          <span
            key={`${file.name}-${file.size}-${idx}`}
            className={cn(
              "inline-flex items-center gap-1 rounded-full bg-zinc-100 px-3 py-1 text-sm text-zinc-700",
              isUploading && "opacity-70",
            )}
          >
            <span className="max-w-[160px] truncate">{file.name}</span>
            {isUploading ? spinner : removeButton}
          </span>
        );
      })}
    </div>
  );
}

View File

@@ -3,12 +3,15 @@ import { ChangeEvent, FormEvent, useEffect, useState } from "react";
interface Args {
onSend: (message: string) => void;
disabled?: boolean;
/** Allow sending when text is empty (e.g. when files are attached). */
canSendEmpty?: boolean;
inputId?: string;
}
export function useChatInput({
onSend,
disabled = false,
canSendEmpty = false,
inputId = "chat-input",
}: Args) {
const [value, setValue] = useState("");
@@ -32,7 +35,7 @@ export function useChatInput({
);
async function handleSend() {
if (disabled || isSending || !value.trim()) return;
if (disabled || isSending || (!value.trim() && !canSendEmpty)) return;
setIsSending(true);
try {

View File

@@ -5,7 +5,8 @@ import {
} from "@/components/ai-elements/conversation";
import { Message, MessageContent } from "@/components/ai-elements/message";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { UIDataTypes, UIMessage, UITools } from "ai";
import { FileUIPart, UIDataTypes, UIMessage, UITools } from "ai";
import { MessageAttachments } from "./components/MessageAttachments";
import { MessagePartRenderer } from "./components/MessagePartRenderer";
import { ThinkingIndicator } from "./components/ThinkingIndicator";
@@ -72,6 +73,10 @@ export function ChatMessagesContainer({
messageIndex === messages.length - 1 &&
message.role === "assistant";
const fileParts = message.parts.filter(
(p): p is FileUIPart => p.type === "file",
);
return (
<Message from={message.role} key={message.id}>
<MessageContent
@@ -93,6 +98,12 @@ export function ChatMessagesContainer({
<ThinkingIndicator active={showThinking} />
)}
</MessageContent>
{fileParts.length > 0 && (
<MessageAttachments
files={fileParts}
isUser={message.role === "user"}
/>
)}
</Message>
);
})}

View File

@@ -0,0 +1,84 @@
import {
FileText as FileTextIcon,
DownloadSimple as DownloadIcon,
} from "@phosphor-icons/react";
import type { FileUIPart } from "ai";
import {
ContentCard,
ContentCardHeader,
ContentCardTitle,
ContentCardSubtitle,
} from "../../ToolAccordion/AccordionContent";
interface Props {
files: FileUIPart[];
isUser?: boolean;
}
/**
 * Renders file attachments below a chat message.
 *
 * User messages get a compact purple card; assistant messages reuse the
 * ContentCard primitives from the tool accordion. Each entry shows the
 * filename and media type, plus a download link when `file.url` is set.
 */
export function MessageAttachments({ files, isUser }: Props) {
  // No wrapper div at all when there is nothing to show.
  if (files.length === 0) return null;
  return (
    <div className="mt-2 flex flex-col gap-2">
      {files.map((file, i) =>
        isUser ? (
          <div
            key={`${file.filename}-${i}`}
            className="min-w-0 rounded-lg border border-purple-300 bg-purple-100 p-3"
          >
            <div className="flex items-start justify-between gap-2">
              <div className="flex min-w-0 items-center gap-2">
                <FileTextIcon className="h-5 w-5 shrink-0 text-neutral-400" />
                <div className="min-w-0">
                  <p className="truncate text-sm font-medium text-zinc-800">
                    {file.filename || "file"}
                  </p>
                  <p className="mt-0.5 truncate font-mono text-xs text-zinc-800">
                    {file.mediaType || "file"}
                  </p>
                </div>
              </div>
              {file.url && (
                <a
                  href={file.url}
                  download
                  aria-label="Download file"
                  className="shrink-0 text-purple-400 hover:text-purple-600"
                >
                  <DownloadIcon className="h-5 w-5" />
                </a>
              )}
            </div>
          </div>
        ) : (
          <ContentCard key={`${file.filename}-${i}`}>
            <ContentCardHeader
              action={
                file.url ? (
                  <a
                    href={file.url}
                    download
                    aria-label="Download file"
                    className="shrink-0 text-neutral-400 hover:text-neutral-600"
                  >
                    <DownloadIcon className="h-5 w-5" />
                  </a>
                ) : undefined
              }
            >
              <div className="flex items-center gap-2">
                <FileTextIcon className="h-5 w-5 shrink-0 text-neutral-400" />
                <div className="min-w-0">
                  <ContentCardTitle>{file.filename || "file"}</ContentCardTitle>
                  <ContentCardSubtitle>
                    {file.mediaType || "file"}
                  </ContentCardSubtitle>
                </div>
              </div>
            </ContentCardHeader>
          </ContentCard>
        ),
      )}
    </div>
  );
}

View File

@@ -17,13 +17,19 @@ interface Props {
inputLayoutId: string;
isCreatingSession: boolean;
onCreateSession: () => void | Promise<string>;
onSend: (message: string) => void | Promise<void>;
onSend: (message: string, files?: File[]) => void | Promise<void>;
isUploadingFiles?: boolean;
droppedFiles?: File[];
onDroppedFilesConsumed?: () => void;
}
export function EmptySession({
inputLayoutId,
isCreatingSession,
onSend,
isUploadingFiles,
droppedFiles,
onDroppedFilesConsumed,
}: Props) {
const { user } = useSupabase();
const greetingName = getGreetingName(user);
@@ -51,12 +57,12 @@ export function EmptySession({
return (
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-0 py-5 md:px-6 md:py-10">
<motion.div
className="w-full max-w-3xl text-center"
className="w-full max-w-[52rem] text-center"
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.3 }}
>
<div className="mx-auto max-w-3xl">
<div className="mx-auto max-w-[52rem]">
<Text variant="h3" className="mb-1 !text-[1.375rem] text-zinc-700">
Hey, <span className="text-violet-600">{greetingName}</span>
</Text>
@@ -74,8 +80,11 @@ export function EmptySession({
inputId="chat-input-empty"
onSend={onSend}
disabled={isCreatingSession}
isUploadingFiles={isUploadingFiles}
placeholder={inputPlaceholder}
className="w-full"
droppedFiles={droppedFiles}
onDroppedFilesConsumed={onDroppedFilesConsumed}
/>
</motion.div>
</div>

View File

@@ -1,4 +1,5 @@
import type { UIMessage, UIDataTypes, UITools } from "ai";
import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
import type { FileUIPart, UIMessage, UIDataTypes, UITools } from "ai";
interface SessionChatMessage {
role: string;
@@ -38,6 +39,48 @@ function coerceSessionChatMessages(
.filter((m): m is SessionChatMessage => m !== null);
}
/**
* Parse the `[Attached files]` block appended by the backend and return
* the cleaned text plus reconstructed FileUIPart objects.
*
* Backend format:
* ```
* \n\n[Attached files]
* - name.jpg (image/jpeg, 191.0 KB), file_id=<uuid>
* Use read_workspace_file with the file_id to access file contents.
* ```
*/
const ATTACHED_FILES_RE =
/\n?\n?\[Attached files\]\n([\s\S]*?)Use read_workspace_file with the file_id to access file contents\./;
const FILE_LINE_RE = /^- (.+) \(([^,]+),\s*[\d.]+ KB\), file_id=([0-9a-f-]+)$/;
// Split a message body into its visible text and the FileUIParts encoded in
// the backend's "[Attached files]" trailer (see ATTACHED_FILES_RE above).
function extractFileParts(content: string): {
  cleanText: string;
  fileParts: FileUIPart[];
} {
  const blockMatch = content.match(ATTACHED_FILES_RE);
  if (!blockMatch) {
    return { cleanText: content, fileParts: [] };
  }

  // Each well-formed line yields exactly one part; malformed lines yield none.
  const fileParts = blockMatch[1]
    .trim()
    .split("\n")
    .flatMap((rawLine): FileUIPart[] => {
      const parsed = rawLine.trim().match(FILE_LINE_RE);
      if (!parsed) return [];
      const [, filename, mimeType, fileId] = parsed;
      return [
        {
          type: "file",
          filename,
          mediaType: mimeType,
          url: `/api/proxy${getGetWorkspaceDownloadFileByIdUrl(fileId)}`,
        },
      ];
    });

  return {
    cleanText: content.replace(blockMatch[0], "").trim(),
    fileParts,
  };
}
function safeJsonParse(value: string): unknown {
try {
return JSON.parse(value) as unknown;
@@ -79,7 +122,17 @@ export function convertChatSessionMessagesToUiMessages(
const parts: UIMessage<unknown, UIDataTypes, UITools>["parts"] = [];
if (typeof msg.content === "string" && msg.content.trim()) {
parts.push({ type: "text", text: msg.content, state: "done" });
if (msg.role === "user") {
const { cleanText, fileParts } = extractFileParts(msg.content);
if (cleanText) {
parts.push({ type: "text", text: cleanText, state: "done" });
}
for (const fp of fileParts) {
parts.push(fp);
}
} else {
parts.push({ type: "text", text: msg.content, state: "done" });
}
}
if (msg.role === "assistant" && Array.isArray(msg.tool_calls)) {

View File

@@ -12,6 +12,7 @@ import {
GlobeIcon,
ListChecksIcon,
MagnifyingGlassIcon,
MonitorIcon,
PencilSimpleIcon,
TerminalIcon,
TrashIcon,
@@ -48,6 +49,7 @@ function formatToolName(name: string): string {
type ToolCategory =
| "bash"
| "web"
| "browser"
| "file-read"
| "file-write"
| "file-delete"
@@ -65,19 +67,28 @@ function getToolCategory(toolName: string): ToolCategory {
case "WebSearch":
case "WebFetch":
return "web";
case "browser_navigate":
case "browser_act":
case "browser_screenshot":
return "browser";
case "read_workspace_file":
case "read_file":
case "Read":
return "file-read";
case "write_workspace_file":
case "write_file":
case "Write":
return "file-write";
case "delete_workspace_file":
return "file-delete";
case "list_workspace_files":
case "glob":
case "Glob":
return "file-list";
case "grep":
case "Grep":
return "search";
case "edit_file":
case "Edit":
return "edit";
case "TodoWrite":
@@ -115,6 +126,8 @@ function ToolIcon({
return <TerminalIcon size={14} weight="regular" className={iconClass} />;
case "web":
return <GlobeIcon size={14} weight="regular" className={iconClass} />;
case "browser":
return <MonitorIcon size={14} weight="regular" className={iconClass} />;
case "file-read":
return <FileIcon size={14} weight="regular" className={iconClass} />;
case "file-write":
@@ -150,6 +163,8 @@ function AccordionIcon({ category }: { category: ToolCategory }) {
return <TerminalIcon size={32} weight="light" />;
case "web":
return <GlobeIcon size={32} weight="light" />;
case "browser":
return <MonitorIcon size={32} weight="light" />;
case "file-read":
case "file-write":
return <FileIcon size={32} weight="light" />;
@@ -184,13 +199,25 @@ function getInputSummary(toolName: string, input: unknown): string | null {
return typeof inp.url === "string" ? inp.url : null;
case "WebSearch":
return typeof inp.query === "string" ? inp.query : null;
case "browser_navigate":
return typeof inp.url === "string" ? inp.url : null;
case "browser_act":
return typeof inp.action === "string"
? inp.target
? `${inp.action} ${inp.target}`
: (inp.action as string)
: null;
case "browser_screenshot":
return null;
case "read_workspace_file":
case "read_file":
case "Read":
return (
(typeof inp.file_path === "string" ? inp.file_path : null) ??
(typeof inp.path === "string" ? inp.path : null)
);
case "write_workspace_file":
case "write_file":
case "Write":
return (
(typeof inp.file_path === "string" ? inp.file_path : null) ??
@@ -198,10 +225,13 @@ function getInputSummary(toolName: string, input: unknown): string | null {
);
case "delete_workspace_file":
return typeof inp.file_path === "string" ? inp.file_path : null;
case "glob":
case "Glob":
return typeof inp.pattern === "string" ? inp.pattern : null;
case "grep":
case "Grep":
return typeof inp.pattern === "string" ? inp.pattern : null;
case "edit_file":
case "Edit":
return typeof inp.file_path === "string" ? inp.file_path : null;
case "TodoWrite": {
@@ -249,6 +279,11 @@ function getAnimationText(part: ToolUIPart, category: ToolCategory): string {
return shortSummary
? `Fetching ${shortSummary}`
: "Fetching web content…";
case "browser":
if (toolName === "browser_screenshot") return "Taking screenshot…";
return shortSummary
? `Browsing ${shortSummary}`
: "Interacting with browser…";
case "file-read":
return shortSummary ? `Reading ${shortSummary}` : "Reading file…";
case "file-write":
@@ -287,6 +322,11 @@ function getAnimationText(part: ToolUIPart, category: ToolCategory): string {
return shortSummary
? `Fetched ${shortSummary}`
: "Fetched web content";
case "browser":
if (toolName === "browser_screenshot") return "Screenshot captured";
return shortSummary
? `Browsed ${shortSummary}`
: "Browser action completed";
case "file-read":
return shortSummary ? `Read ${shortSummary}` : "File read completed";
case "file-write":
@@ -313,6 +353,8 @@ function getAnimationText(part: ToolUIPart, category: ToolCategory): string {
return "Command failed";
case "web":
return toolName === "WebSearch" ? "Search failed" : "Fetch failed";
case "browser":
return "Browser action failed";
default:
return `${formatToolName(toolName)} failed`;
}
@@ -418,16 +460,22 @@ function getBashAccordionData(
description: truncate(command, 80),
content: (
<div className="space-y-2">
{command && (
<div>
<p className="mb-1 text-xs font-medium text-slate-500">command</p>
<ContentCodeBlock>{command}</ContentCodeBlock>
</div>
)}
{stdout && (
<div>
<p className="mb-1 text-xs font-medium text-slate-500">stdout</p>
<ContentCodeBlock>{truncate(stdout, 2000)}</ContentCodeBlock>
<ContentCodeBlock>{stdout}</ContentCodeBlock>
</div>
)}
{stderr && (
<div>
<p className="mb-1 text-xs font-medium text-slate-500">stderr</p>
<ContentCodeBlock>{truncate(stderr, 1000)}</ContentCodeBlock>
<ContentCodeBlock>{stderr}</ContentCodeBlock>
</div>
)}
{!stdout && !stderr && message && (
@@ -475,18 +523,55 @@ function getWebAccordionData(
: "Search results",
description: truncate(url, 80),
content: content ? (
<ContentCodeBlock>{truncate(content, 2000)}</ContentCodeBlock>
<ContentCodeBlock>{content}</ContentCodeBlock>
) : message ? (
<ContentMessage>{message}</ContentMessage>
) : Object.keys(output).length > 0 ? (
<ContentCodeBlock>
{truncate(JSON.stringify(output, null, 2), 2000)}
</ContentCodeBlock>
<ContentCodeBlock>{JSON.stringify(output, null, 2)}</ContentCodeBlock>
) : null,
};
}
/**
 * Build accordion display data for browser tool results
 * (browser_navigate / browser_act / browser_screenshot).
 *
 * `input` is unused here but kept so all get*AccordionData helpers share a
 * uniform (toolName, input, output) call shape at the dispatch site.
 */
function getBrowserAccordionData(
  toolName: string,
  input: unknown,
  output: Record<string, unknown>,
): AccordionData {
  const message = getStringField(output, "message");
  const snapshot = getStringField(output, "snapshot");

  // Screenshot tool: show the file_id so the user knows it was saved
  if (toolName === "browser_screenshot") {
    const fileId = getStringField(output, "file_id");
    const filename = getStringField(output, "filename");
    return {
      // Bug fix: interpolate the actual filename — this previously rendered
      // the literal text "$(unknown)" instead of a template substitution.
      title: filename ? `Screenshot: ${filename}` : "Screenshot captured",
      description: fileId ? `file_id: ${fileId}` : undefined,
      content: message ? <ContentMessage>{message}</ContentMessage> : null,
    };
  }

  // Navigate / act tools: show snapshot if available
  const title =
    toolName === "browser_navigate"
      ? (getStringField(output, "title") ?? "Page loaded")
      : (message ?? "Action completed");
  const url = getStringField(output, "url", "current_url");

  return {
    title,
    description: url ? truncate(url, 80) : undefined,
    content: snapshot ? (
      <ContentCodeBlock>{truncate(snapshot, 3000)}</ContentCodeBlock>
    ) : message ? (
      <ContentMessage>{message}</ContentMessage>
    ) : null,
  };
}
function getFileAccordionData(
category: ToolCategory,
input: unknown,
output: Record<string, unknown>,
): AccordionData {
@@ -529,6 +614,20 @@ function getFileAccordionData(
displayContent = extractMcpText(output);
}
// For edit: show old/new diff; for write: show written content if output is just a status
const oldString =
category === "edit"
? getStringField(inp as Record<string, unknown>, "old_string")
: null;
const newString =
category === "edit"
? getStringField(inp as Record<string, unknown>, "new_string")
: null;
const writtenContent =
category === "file-write"
? getStringField(inp as Record<string, unknown>, "content")
: null;
// For Glob/list results, try to show file list
// Files can be either strings (from Glob) or objects (from list_workspace_files)
const files = Array.isArray(output.files) ? output.files : null;
@@ -562,18 +661,33 @@ function getFileAccordionData(
fileListText = fileLines.join("\n");
}
const isWriteOrEdit = category === "file-write" || category === "edit";
return {
title: message ?? "File output",
title:
message ??
(isWriteOrEdit ? `Wrote ${truncate(filePath, 60)}` : "File output"),
description: truncate(filePath, 80),
content: (
<div className="space-y-2">
{displayContent && (
<ContentCodeBlock>{truncate(displayContent, 2000)}</ContentCodeBlock>
)}
{fileListText && (
<ContentCodeBlock>{truncate(fileListText, 2000)}</ContentCodeBlock>
)}
{!displayContent && !fileListText && message && (
{oldString && newString != null ? (
<>
<div>
<p className="mb-1 text-xs font-medium text-red-400">removed</p>
<ContentCodeBlock>{oldString}</ContentCodeBlock>
</div>
<div>
<p className="mb-1 text-xs font-medium text-green-400">added</p>
<ContentCodeBlock>{newString}</ContentCodeBlock>
</div>
</>
) : writtenContent ? (
<ContentCodeBlock>{writtenContent}</ContentCodeBlock>
) : displayContent ? (
<ContentCodeBlock>{displayContent}</ContentCodeBlock>
) : null}
{fileListText && <ContentCodeBlock>{fileListText}</ContentCodeBlock>}
{!displayContent && !fileListText && !writtenContent && message && (
<ContentMessage>{message}</ContentMessage>
)}
</div>
@@ -675,14 +789,13 @@ function getDefaultAccordionData(
return {
title: "Output",
description: message ?? undefined,
content: (
<ContentCodeBlock>{truncate(displayContent, 2000)}</ContentCodeBlock>
),
content: <ContentCodeBlock>{displayContent}</ContentCodeBlock>,
};
}
function getAccordionData(
category: ToolCategory,
toolName: string,
input: unknown,
output: Record<string, unknown>,
): AccordionData {
@@ -691,13 +804,15 @@ function getAccordionData(
return getBashAccordionData(input, output);
case "web":
return getWebAccordionData(input, output);
case "browser":
return getBrowserAccordionData(toolName, input, output);
case "file-read":
case "file-write":
case "file-delete":
case "file-list":
case "search":
case "edit":
return getFileAccordionData(input, output);
return getFileAccordionData(category, input, output);
case "todo":
return getTodoAccordionData(input);
default:
@@ -733,25 +848,23 @@ export function GenericTool({ part }: Props) {
const showAccordion = hasOutput || hasError || hasTodoInput;
const accordionData = showAccordion
? getAccordionData(category, part.input, output ?? {})
? getAccordionData(category, toolName, part.input, output ?? {})
: null;
return (
<div className="py-2">
{/* Only show loading text when NOT showing accordion */}
{!showAccordion && (
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon
category={category}
isStreaming={isStreaming}
isError={isError}
/>
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
)}
{/* Status line: always visible so the user sees what tool ran */}
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon
category={category}
isStreaming={isStreaming}
isError={isError}
/>
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
{showAccordion && accordionData ? (
<ToolAccordion

View File

@@ -5,15 +5,25 @@ import {
} from "@/app/api/__generated__/endpoints/chat/chat";
import { toast } from "@/components/molecules/Toast/use-toast";
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
import { getWebSocketToken } from "@/lib/supabase/actions";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { environment } from "@/services/environment";
import { useQueryClient } from "@tanstack/react-query";
import { useEffect, useState } from "react";
import type { FileUIPart } from "ai";
import { useEffect, useRef, useState } from "react";
import { useCopilotUIStore } from "./store";
import { useChatSession } from "./useChatSession";
import { useCopilotStream } from "./useCopilotStream";
// Metadata for a file that was successfully uploaded to the user's workspace,
// as returned by the backend's /api/workspace/files/upload endpoint.
// Used to build FileUIPart attachments for the outgoing chat message.
interface UploadedFile {
  file_id: string; // workspace file id assigned by the backend
  name: string; // filename (server-provided, falling back to the local File name)
  mime_type: string; // MIME type (falls back to application/octet-stream)
}
export function useCopilotPage() {
const { isUserLoading, isLoggedIn } = useSupabase();
const [isUploadingFiles, setIsUploadingFiles] = useState(false);
const [pendingMessage, setPendingMessage] = useState<string | null>(null);
const queryClient = useQueryClient();
@@ -77,26 +87,164 @@ export function useCopilotPage() {
const isMobile =
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
const pendingFilesRef = useRef<File[]>([]);
// --- Send pending message after session creation ---
useEffect(() => {
if (!sessionId || !pendingMessage) return;
if (!sessionId || pendingMessage === null) return;
const msg = pendingMessage;
const files = pendingFilesRef.current;
setPendingMessage(null);
sendMessage({ text: msg });
pendingFilesRef.current = [];
if (files.length > 0) {
setIsUploadingFiles(true);
void uploadFiles(files, sessionId)
.then((uploaded) => {
if (uploaded.length === 0) {
toast({
title: "File upload failed",
description: "Could not upload any files. Please try again.",
variant: "destructive",
});
return;
}
const fileParts = buildFileParts(uploaded);
sendMessage({
text: msg,
files: fileParts.length > 0 ? fileParts : undefined,
});
})
.finally(() => setIsUploadingFiles(false));
} else {
sendMessage({ text: msg });
}
}, [sessionId, pendingMessage, sendMessage]);
async function onSend(message: string) {
// Uploads the given files directly to the Python backend's workspace API,
// deliberately bypassing the Next.js serverless proxy: Vercel's 4.5 MB
// function payload limit would reject larger files routed through
// /api/workspace/files/upload.
//
// Returns metadata for every file that uploaded successfully. Per-file
// failures are toasted and dropped from the result; an auth failure aborts
// everything and returns an empty array.
async function uploadFiles(
  files: File[],
  sid: string,
): Promise<UploadedFile[]> {
  const { token, error: tokenError } = await getWebSocketToken();
  if (tokenError || !token) {
    toast({
      title: "Authentication error",
      description: "Please sign in again.",
      variant: "destructive",
    });
    return [];
  }

  const backendBase = environment.getAGPTServerBaseUrl();

  // Upload a single file; throws on any failure so allSettled can record it.
  const uploadOne = async (file: File): Promise<UploadedFile> => {
    const formData = new FormData();
    formData.append("file", file);

    const url = new URL("/api/workspace/files/upload", backendBase);
    url.searchParams.set("session_id", sid);

    const res = await fetch(url.toString(), {
      method: "POST",
      headers: { Authorization: `Bearer ${token}` },
      body: formData,
    });
    if (!res.ok) {
      const err = await res.text();
      console.error("File upload failed:", err);
      toast({
        title: "File upload failed",
        description: file.name,
        variant: "destructive",
      });
      throw new Error(err);
    }

    const data = await res.json();
    if (!data.file_id) throw new Error("No file_id returned");
    return {
      file_id: data.file_id,
      name: data.name || file.name,
      mime_type: data.mime_type || "application/octet-stream",
    };
  };

  // Run all uploads concurrently and keep only the ones that succeeded.
  const settled = await Promise.allSettled(files.map(uploadOne));
  const uploaded: UploadedFile[] = [];
  for (const result of settled) {
    if (result.status === "fulfilled") uploaded.push(result.value);
  }
  return uploaded;
}
// Converts uploaded-file metadata into the AI SDK's FileUIPart shape so the
// attachments travel with the chat message. The download URL embeds the
// workspace file id, which the stream layer later parses back out in order
// to send file_ids to the backend alongside the message text.
function buildFileParts(uploaded: UploadedFile[]): FileUIPart[] {
  const parts: FileUIPart[] = [];
  for (const f of uploaded) {
    parts.push({
      type: "file",
      mediaType: f.mime_type,
      filename: f.name,
      url: `/api/proxy/api/workspace/files/${f.file_id}/download`,
    });
  }
  return parts;
}
async function onSend(message: string, files?: File[]) {
const trimmed = message.trim();
if (!trimmed) return;
if (!trimmed && (!files || files.length === 0)) return;
// Client-side file limits
if (files && files.length > 0) {
const MAX_FILES = 10;
const MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024; // 100 MB
if (files.length > MAX_FILES) {
toast({
title: "Too many files",
description: `You can attach up to ${MAX_FILES} files at once.`,
variant: "destructive",
});
return;
}
const oversized = files.filter((f) => f.size > MAX_FILE_SIZE_BYTES);
if (oversized.length > 0) {
toast({
title: "File too large",
description: `${oversized[0].name} exceeds the 100 MB limit.`,
variant: "destructive",
});
return;
}
}
isUserStoppingRef.current = false;
if (sessionId) {
sendMessage({ text: trimmed });
if (files && files.length > 0) {
setIsUploadingFiles(true);
try {
const uploaded = await uploadFiles(files, sessionId);
if (uploaded.length === 0) {
// All uploads failed — abort send so chips revert to editable
throw new Error("All file uploads failed");
}
const fileParts = buildFileParts(uploaded);
sendMessage({
text: trimmed || "",
files: fileParts.length > 0 ? fileParts : undefined,
});
} finally {
setIsUploadingFiles(false);
}
} else {
sendMessage({ text: trimmed });
}
return;
}
setPendingMessage(trimmed);
setPendingMessage(trimmed || "");
if (files && files.length > 0) {
pendingFilesRef.current = files;
}
await createSession();
}
@@ -161,6 +309,7 @@ export function useCopilotPage() {
isLoadingSession,
isSessionError,
isCreatingSession,
isUploadingFiles,
isUserLoading,
isLoggedIn,
createSession,

View File

@@ -8,7 +8,7 @@ import { environment } from "@/services/environment";
import { useChat } from "@ai-sdk/react";
import { useQueryClient } from "@tanstack/react-query";
import { DefaultChatTransport } from "ai";
import type { UIMessage } from "ai";
import type { FileUIPart, UIMessage } from "ai";
import { useEffect, useMemo, useRef, useState } from "react";
import { deduplicateMessages, resolveInProgressTools } from "./helpers";
@@ -51,6 +51,15 @@ export function useCopilotStream({
api: `${environment.getAGPTServerBaseUrl()}/api/chat/sessions/${sessionId}/stream`,
prepareSendMessagesRequest: async ({ messages }) => {
const last = messages[messages.length - 1];
// Extract file_ids from FileUIPart entries on the message
const fileIds = last.parts
?.filter((p): p is FileUIPart => p.type === "file")
.map((p) => {
// URL is like /api/proxy/api/workspace/files/{id}/download
const match = p.url.match(/\/workspace\/files\/([^/]+)\//);
return match?.[1];
})
.filter(Boolean) as string[] | undefined;
return {
body: {
message: (
@@ -59,6 +68,7 @@ export function useCopilotStream({
).join(""),
is_user_message: last.role === "user",
context: null,
file_ids: fileIds && fileIds.length > 0 ? fileIds : null,
},
headers: await getAuthHeaders(),
};

View File

@@ -27,9 +27,9 @@ export async function POST(
try {
const body = await request.json();
const { message, is_user_message, context } = body;
const { message, is_user_message, context, file_ids } = body;
if (!message) {
if (message === undefined) {
return new Response(
JSON.stringify({ error: "Missing message parameter" }),
{ status: 400, headers: { "Content-Type": "application/json" } },
@@ -62,6 +62,7 @@ export async function POST(
message,
is_user_message: is_user_message ?? true,
context: context || null,
file_ids: file_ids || null,
}),
signal: debugSignal(),
});

View File

@@ -2039,7 +2039,9 @@
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/UploadFileResponse" }
"schema": {
"$ref": "#/components/schemas/backend__api__model__UploadFileResponse"
}
}
}
},
@@ -6497,6 +6499,59 @@
}
}
},
"/api/workspace/files/upload": {
"post": {
"tags": ["workspace"],
"summary": "Upload file to workspace",
"description": "Upload a file to the user's workspace.\n\nFiles are stored in session-scoped paths when session_id is provided,\nso the agent's session-scoped tools can discover them automatically.",
"operationId": "postWorkspaceUpload file to workspace",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "session_id",
"in": "query",
"required": false,
"schema": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Session Id"
}
}
],
"requestBody": {
"required": true,
"content": {
"multipart/form-data": {
"schema": {
"$ref": "#/components/schemas/Body_postWorkspaceUpload_file_to_workspace"
}
}
}
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/backend__api__features__workspace__routes__UploadFileResponse"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/workspace/files/{file_id}/download": {
"get": {
"tags": ["workspace"],
@@ -6531,6 +6586,30 @@
}
}
},
"/api/workspace/storage/usage": {
"get": {
"tags": ["workspace"],
"summary": "Get workspace storage usage",
"description": "Get storage usage information for the user's workspace.",
"operationId": "getWorkspaceGet workspace storage usage",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/StorageUsageResponse"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
}
},
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/health": {
"get": {
"tags": ["health"],
@@ -7768,6 +7847,14 @@
"required": ["file"],
"title": "Body_postV2Upload submission media"
},
"Body_postWorkspaceUpload_file_to_workspace": {
"properties": {
"file": { "type": "string", "format": "binary", "title": "File" }
},
"type": "object",
"required": ["file"],
"title": "Body_postWorkspaceUpload file to workspace"
},
"BulkMoveAgentsRequest": {
"properties": {
"agent_ids": {
@@ -11167,6 +11254,9 @@
"operation_in_progress",
"input_validation_error",
"web_fetch",
"browser_navigate",
"browser_act",
"browser_screenshot",
"bash_exec",
"feature_request_search",
"feature_request_created",
@@ -11592,6 +11682,17 @@
"type": "object",
"title": "Stats"
},
"StorageUsageResponse": {
"properties": {
"used_bytes": { "type": "integer", "title": "Used Bytes" },
"limit_bytes": { "type": "integer", "title": "Limit Bytes" },
"used_percent": { "type": "number", "title": "Used Percent" },
"file_count": { "type": "integer", "title": "File Count" }
},
"type": "object",
"required": ["used_bytes", "limit_bytes", "used_percent", "file_count"],
"title": "StorageUsageResponse"
},
"StoreAgent": {
"properties": {
"slug": { "type": "string", "title": "Slug" },
@@ -12039,6 +12140,17 @@
{ "type": "null" }
],
"title": "Context"
},
"file_ids": {
"anyOf": [
{
"items": { "type": "string" },
"type": "array",
"maxItems": 20
},
{ "type": "null" }
],
"title": "File Ids"
}
},
"type": "object",
@@ -13620,24 +13732,6 @@
"required": ["timezone"],
"title": "UpdateTimezoneRequest"
},
"UploadFileResponse": {
"properties": {
"file_uri": { "type": "string", "title": "File Uri" },
"file_name": { "type": "string", "title": "File Name" },
"size": { "type": "integer", "title": "Size" },
"content_type": { "type": "string", "title": "Content Type" },
"expires_in_hours": { "type": "integer", "title": "Expires In Hours" }
},
"type": "object",
"required": [
"file_uri",
"file_name",
"size",
"content_type",
"expires_in_hours"
],
"title": "UploadFileResponse"
},
"UserHistoryResponse": {
"properties": {
"history": {
@@ -13966,6 +14060,36 @@
"url"
],
"title": "Webhook"
},
"backend__api__features__workspace__routes__UploadFileResponse": {
"properties": {
"file_id": { "type": "string", "title": "File Id" },
"name": { "type": "string", "title": "Name" },
"path": { "type": "string", "title": "Path" },
"mime_type": { "type": "string", "title": "Mime Type" },
"size_bytes": { "type": "integer", "title": "Size Bytes" }
},
"type": "object",
"required": ["file_id", "name", "path", "mime_type", "size_bytes"],
"title": "UploadFileResponse"
},
"backend__api__model__UploadFileResponse": {
"properties": {
"file_uri": { "type": "string", "title": "File Uri" },
"file_name": { "type": "string", "title": "File Name" },
"size": { "type": "integer", "title": "Size" },
"content_type": { "type": "string", "title": "Content Type" },
"expires_in_hours": { "type": "integer", "title": "Expires In Hours" }
},
"type": "object",
"required": [
"file_uri",
"file_name",
"size",
"content_type",
"expires_in_hours"
],
"title": "UploadFileResponse"
}
},
"securitySchemes": {

View File

@@ -0,0 +1,48 @@
import { environment } from "@/services/environment";
import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
import { NextRequest, NextResponse } from "next/server";
/**
 * Proxies multipart file uploads from the browser to the Python backend's
 * /api/workspace/files/upload endpoint, attaching the server-side auth token.
 *
 * Query param `session_id` (optional) is forwarded so the backend can store
 * the file in a session-scoped path discoverable by the agent's tools.
 *
 * NOTE(review): serverless function payloads are capped (~4.5 MB on Vercel),
 * so large files should use the direct-to-backend upload path instead —
 * confirm which callers still route through this proxy.
 */
export async function POST(request: NextRequest) {
  try {
    // Re-parse the multipart body so fetch() re-encodes it for the backend.
    const formData = await request.formData();
    const sessionId = request.nextUrl.searchParams.get("session_id");

    const token = await getServerAuthToken();
    const backendUrl = environment.getAGPTServerBaseUrl();

    const uploadUrl = new URL("/api/workspace/files/upload", backendUrl);
    if (sessionId) {
      uploadUrl.searchParams.set("session_id", sessionId);
    }

    const headers: Record<string, string> = {};
    // "no-token-found" is a sentinel for a missing session; skip the header
    // and let the backend answer with its own 401.
    if (token && token !== "no-token-found") {
      headers["Authorization"] = `Bearer ${token}`;
    }

    const response = await fetch(uploadUrl.toString(), {
      method: "POST",
      headers,
      body: formData,
    });

    if (!response.ok) {
      // Forward the backend's content type as well as its status: without it,
      // JSON error bodies were previously served without a content type and
      // clients parsing by content type would mis-handle them.
      const errorText = await response.text();
      return new NextResponse(errorText, {
        status: response.status,
        headers: {
          "Content-Type":
            response.headers.get("content-type") ?? "application/json",
        },
      });
    }

    const data = await response.json();
    return NextResponse.json(data);
  } catch (error) {
    console.error("File upload proxy error:", error);
    return NextResponse.json(
      {
        error: "Failed to upload file",
        detail: error instanceof Error ? error.message : String(error),
      },
      { status: 500 },
    );
  }
}