mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
7 Commits
hotfix/tra
...
swiftyos/t
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a2c97d428e | ||
|
|
78fee94569 | ||
|
|
b913e8f9de | ||
|
|
7f16a10e9e | ||
|
|
57f56c0caa | ||
|
|
e8b82cd268 | ||
|
|
4a108ad5d2 |
2
.github/workflows/platform-frontend-ci.yml
vendored
2
.github/workflows/platform-frontend-ci.yml
vendored
@@ -149,7 +149,7 @@ jobs:
|
||||
driver-opts: network=host
|
||||
|
||||
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
||||
uses: crazy-max/ghaction-github-runtime@v4
|
||||
uses: crazy-max/ghaction-github-runtime@v3
|
||||
|
||||
- name: Set up Platform - Build Docker images (with cache)
|
||||
working-directory: autogpt_platform
|
||||
|
||||
3
autogpt_platform/.gitignore
vendored
3
autogpt_platform/.gitignore
vendored
@@ -1,3 +1,2 @@
|
||||
*.ignore.*
|
||||
*.ign.*
|
||||
.application.logs
|
||||
*.ign.*
|
||||
@@ -65,6 +65,12 @@ LANGFUSE_PUBLIC_KEY=
|
||||
LANGFUSE_SECRET_KEY=
|
||||
LANGFUSE_HOST=https://cloud.langfuse.com
|
||||
|
||||
# OTLP Tracing
|
||||
# Base host for OTLP trace ingestion (for example Product Intelligence)
|
||||
OTLP_TRACING_HOST=
|
||||
# Bearer token for OTLP trace ingestion endpoint (optional)
|
||||
OTLP_TRACING_TOKEN=
|
||||
|
||||
# OAuth Credentials
|
||||
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
|
||||
# e.g. http://localhost:3000/auth/integrations/oauth_callback
|
||||
|
||||
@@ -95,7 +95,7 @@ ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
|
||||
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
|
||||
# for the bash_exec MCP tool (fallback when E2B is not configured).
|
||||
# for the bash_exec MCP tool.
|
||||
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.13 \
|
||||
@@ -111,29 +111,13 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
# Copy poetry (build-time only, for `poetry install --only-root` to create entry points)
|
||||
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
||||
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
|
||||
# Copy Node.js installation for Prisma and agent-browser.
|
||||
# npm/npx are symlinks in the builder (-> ../lib/node_modules/npm/bin/*-cli.js);
|
||||
# COPY resolves them to regular files, breaking require() paths. Recreate as
|
||||
# proper symlinks so npm/npx can find their modules.
|
||||
# Copy Node.js installation for Prisma
|
||||
COPY --from=builder /usr/bin/node /usr/bin/node
|
||||
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
||||
RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
|
||||
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
|
||||
COPY --from=builder /usr/bin/npm /usr/bin/npm
|
||||
COPY --from=builder /usr/bin/npx /usr/bin/npx
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
# Install agent-browser (Copilot browser tool) + Chromium runtime dependencies.
|
||||
# These are the runtime libraries Chromium/Playwright needs on Debian 13 (trixie).
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
|
||||
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
|
||||
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
|
||||
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
|
||||
fonts-liberation libfontconfig1 \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& npm install -g agent-browser \
|
||||
&& agent-browser install \
|
||||
&& rm -rf /tmp/* /root/.npm
|
||||
|
||||
WORKDIR /app/autogpt_platform/backend
|
||||
|
||||
# Copy only the .venv from builder (not the entire /app directory)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
from uuid import uuid4
|
||||
@@ -10,8 +9,7 @@ from uuid import uuid4
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from prisma.models import UserWorkspaceFile
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
@@ -25,7 +23,6 @@ from backend.copilot.model import (
|
||||
delete_chat_session,
|
||||
get_chat_session,
|
||||
get_user_sessions,
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.models import (
|
||||
@@ -43,8 +40,6 @@ from backend.copilot.tools.models import (
|
||||
ErrorResponse,
|
||||
ExecutionStartedResponse,
|
||||
InputValidationErrorResponse,
|
||||
MCPToolOutputResponse,
|
||||
MCPToolsDiscoveredResponse,
|
||||
NeedLoginResponse,
|
||||
NoResultsResponse,
|
||||
SetupRequirementsResponse,
|
||||
@@ -52,14 +47,10 @@ from backend.copilot.tools.models import (
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -88,9 +79,6 @@ class StreamChatRequest(BaseModel):
|
||||
message: str
|
||||
is_user_message: bool = True
|
||||
context: dict[str, str] | None = None # {url: str, content: str}
|
||||
file_ids: list[str] | None = Field(
|
||||
default=None, max_length=20
|
||||
) # Workspace file IDs attached to this message
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
@@ -142,20 +130,6 @@ class CancelSessionResponse(BaseModel):
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class UpdateSessionTitleRequest(BaseModel):
|
||||
"""Request model for updating a session's title."""
|
||||
|
||||
title: str
|
||||
|
||||
@field_validator("title")
|
||||
@classmethod
|
||||
def title_must_not_be_blank(cls, v: str) -> str:
|
||||
stripped = v.strip()
|
||||
if not stripped:
|
||||
raise ValueError("Title must not be blank")
|
||||
return stripped
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@@ -264,58 +238,9 @@ async def delete_session(
|
||||
detail=f"Session {session_id} not found or access denied",
|
||||
)
|
||||
|
||||
# Best-effort cleanup of the E2B sandbox (if any).
|
||||
config = ChatConfig()
|
||||
if config.use_e2b_sandbox and config.e2b_api_key:
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
|
||||
try:
|
||||
await kill_sandbox(session_id, config.e2b_api_key)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[E2B] Failed to kill sandbox for session %s", session_id[:12]
|
||||
)
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/sessions/{session_id}/title",
|
||||
summary="Update session title",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
status_code=200,
|
||||
responses={404: {"description": "Session not found or access denied"}},
|
||||
)
|
||||
async def update_session_title_route(
|
||||
session_id: str,
|
||||
request: UpdateSessionTitleRequest,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> dict:
|
||||
"""
|
||||
Update the title of a chat session.
|
||||
|
||||
Allows the user to rename their chat session.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to update.
|
||||
request: Request body containing the new title.
|
||||
user_id: The authenticated user's ID.
|
||||
|
||||
Returns:
|
||||
dict: Status of the update.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if session not found or not owned by user.
|
||||
"""
|
||||
success = await update_session_title(session_id, user_id, request.title)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Session {session_id} not found or access denied",
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}",
|
||||
)
|
||||
@@ -469,38 +394,6 @@ async def stream_chat_post(
|
||||
},
|
||||
)
|
||||
|
||||
# Enrich message with file metadata if file_ids are provided.
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
sanitized_file_ids: list[str] | None = None
|
||||
if request.file_ids and user_id:
|
||||
# Filter to valid UUIDs only to prevent DB abuse
|
||||
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
|
||||
|
||||
if valid_ids:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Batch query instead of N+1
|
||||
files = await UserWorkspaceFile.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": valid_ids},
|
||||
"workspaceId": workspace.id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
)
|
||||
# Only keep IDs that actually exist in the user's workspace
|
||||
sanitized_file_ids = [wf.id for wf in files] or None
|
||||
file_lines: list[str] = [
|
||||
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
|
||||
for wf in files
|
||||
]
|
||||
if file_lines:
|
||||
files_block = (
|
||||
"\n\n[Attached files]\n"
|
||||
+ "\n".join(file_lines)
|
||||
+ "\nUse read_workspace_file with the file_id to access file contents."
|
||||
)
|
||||
request.message += files_block
|
||||
|
||||
# Atomically append user message to session BEFORE creating task to avoid
|
||||
# race condition where GET_SESSION sees task as "running" but message isn't
|
||||
# saved yet. append_and_save_message re-fetches inside a lock to prevent
|
||||
@@ -552,7 +445,6 @@ async def stream_chat_post(
|
||||
turn_id=turn_id,
|
||||
is_user_message=request.is_user_message,
|
||||
context=request.context,
|
||||
file_ids=sanitized_file_ids,
|
||||
)
|
||||
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
@@ -908,8 +800,6 @@ ToolResponseUnion = (
|
||||
| BlockOutputResponse
|
||||
| DocSearchResultsResponse
|
||||
| DocPageResponse
|
||||
| MCPToolsDiscoveredResponse
|
||||
| MCPToolOutputResponse
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,251 +0,0 @@
|
||||
"""Tests for chat API routes: session title update and file attachment validation."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.api.features.chat import routes as chat_routes
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(chat_routes.router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _mock_update_session_title(
|
||||
mocker: pytest_mock.MockerFixture, *, success: bool = True
|
||||
):
|
||||
"""Mock update_session_title."""
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.update_session_title",
|
||||
new_callable=AsyncMock,
|
||||
return_value=success,
|
||||
)
|
||||
|
||||
|
||||
# ─── Update title: success ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_update_title_success(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_update = _mock_update_session_title(mocker, success=True)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": "My project"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
mock_update.assert_called_once_with("sess-1", test_user_id, "My project")
|
||||
|
||||
|
||||
def test_update_title_trims_whitespace(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_update = _mock_update_session_title(mocker, success=True)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": " trimmed "},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_update.assert_called_once_with("sess-1", test_user_id, "trimmed")
|
||||
|
||||
|
||||
# ─── Update title: blank / whitespace-only → 422 ──────────────────────
|
||||
|
||||
|
||||
def test_update_title_blank_rejected(
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Whitespace-only titles must be rejected before hitting the DB."""
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": " "},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_update_title_empty_rejected(
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": ""},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ─── Update title: session not found or wrong user → 404 ──────────────
|
||||
|
||||
|
||||
def test_update_title_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
_mock_update_session_title(mocker, success=False)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": "New name"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ─── file_ids Pydantic validation ─────────────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_rejects_too_many_file_ids():
|
||||
"""More than 20 file_ids should be rejected by Pydantic validation (422)."""
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={
|
||||
"message": "hello",
|
||||
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(21)],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def _mock_stream_internals(mocker: pytest_mock.MockFixture):
|
||||
"""Mock the async internals of stream_chat_post so tests can exercise
|
||||
validation and enrichment logic without needing Redis/RabbitMQ."""
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes._validate_and_get_session",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.append_and_save_message",
|
||||
return_value=None,
|
||||
)
|
||||
mock_registry = mocker.MagicMock()
|
||||
mock_registry.create_session = mocker.AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry",
|
||||
mock_registry,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.enqueue_copilot_turn",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.track_user_message",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
|
||||
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
|
||||
"""Exactly 20 file_ids should be accepted (not rejected by validation)."""
|
||||
_mock_stream_internals(mocker)
|
||||
# Patch workspace lookup as imported by the routes module
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={
|
||||
"message": "hello",
|
||||
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(20)],
|
||||
},
|
||||
)
|
||||
# Should get past validation — 200 streaming response expected
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ─── UUID format filtering ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
"""Non-UUID strings in file_ids should be silently filtered out
|
||||
and NOT passed to the database query."""
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
valid_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={
|
||||
"message": "hello",
|
||||
"file_ids": [
|
||||
valid_id,
|
||||
"not-a-uuid",
|
||||
"../../../etc/passwd",
|
||||
"",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
# The find_many call should only receive the one valid UUID
|
||||
mock_prisma.find_many.assert_called_once()
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["id"]["in"] == [valid_id]
|
||||
|
||||
|
||||
# ─── Cross-workspace file_ids ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
"""The batch query should scope to the user's workspace."""
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "my-workspace-id"})(),
|
||||
)
|
||||
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hi", "file_ids": [fid]},
|
||||
)
|
||||
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
|
||||
assert call_kwargs["where"]["isDeleted"] is False
|
||||
@@ -22,7 +22,6 @@ from backend.data.human_review import (
|
||||
)
|
||||
from backend.data.model import USER_TIMEZONE_NOT_SET
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.executor.utils import add_graph_execution
|
||||
|
||||
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
|
||||
@@ -322,13 +321,10 @@ async def process_review_action(
|
||||
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
||||
)
|
||||
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
execution_context = ExecutionContext(
|
||||
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
|
||||
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
|
||||
user_timezone=user_timezone,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
await add_graph_execution(
|
||||
|
||||
@@ -7,24 +7,20 @@ frontend can list available tools on an MCP server before placing a block.
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import fastapi
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from fastapi import Security
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.api.features.integrations.router import CredentialsMetaResponse
|
||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
||||
from backend.blocks.mcp.helpers import (
|
||||
auto_lookup_mcp_credential,
|
||||
normalize_mcp_url,
|
||||
server_host,
|
||||
)
|
||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import HTTPClientError, Requests, validate_url
|
||||
from backend.util.request import HTTPClientError, Requests
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -78,20 +74,32 @@ async def discover_tools(
|
||||
If the user has a stored MCP credential for this server URL, it will be
|
||||
used automatically — no need to pass an explicit auth token.
|
||||
"""
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
auth_token = request.auth_token
|
||||
|
||||
# Auto-use stored MCP credential when no explicit token is provided.
|
||||
if not auth_token:
|
||||
best_cred = await auto_lookup_mcp_credential(
|
||||
user_id, normalize_mcp_url(request.server_url)
|
||||
mcp_creds = await creds_manager.store.get_creds_by_provider(
|
||||
user_id, ProviderName.MCP.value
|
||||
)
|
||||
# Find the freshest credential for this server URL
|
||||
best_cred: OAuth2Credentials | None = None
|
||||
for cred in mcp_creds:
|
||||
if (
|
||||
isinstance(cred, OAuth2Credentials)
|
||||
and (cred.metadata or {}).get("mcp_server_url") == request.server_url
|
||||
):
|
||||
if best_cred is None or (
|
||||
(cred.access_token_expires_at or 0)
|
||||
> (best_cred.access_token_expires_at or 0)
|
||||
):
|
||||
best_cred = cred
|
||||
if best_cred:
|
||||
# Refresh the token if expired before using it
|
||||
best_cred = await creds_manager.refresh_if_needed(user_id, best_cred)
|
||||
logger.info(
|
||||
f"Using MCP credential {best_cred.id} for {request.server_url}, "
|
||||
f"expires_at={best_cred.access_token_expires_at}"
|
||||
)
|
||||
auth_token = best_cred.access_token.get_secret_value()
|
||||
|
||||
client = MCPClient(request.server_url, auth_token=auth_token)
|
||||
@@ -126,7 +134,7 @@ async def discover_tools(
|
||||
],
|
||||
server_name=(
|
||||
init_result.get("serverInfo", {}).get("name")
|
||||
or server_host(request.server_url)
|
||||
or urlparse(request.server_url).hostname
|
||||
or "MCP"
|
||||
),
|
||||
protocol_version=init_result.get("protocolVersion"),
|
||||
@@ -165,16 +173,7 @@ async def mcp_oauth_login(
|
||||
3. Performs Dynamic Client Registration (RFC 7591) if available
|
||||
4. Returns the authorization URL for the frontend to open in a popup
|
||||
"""
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
# Normalize the URL so that credentials stored here are matched consistently
|
||||
# by auto_lookup_mcp_credential (which also uses normalized URLs).
|
||||
server_url = normalize_mcp_url(request.server_url)
|
||||
client = MCPClient(server_url)
|
||||
client = MCPClient(request.server_url)
|
||||
|
||||
# Step 1: Discover protected-resource metadata (RFC 9728)
|
||||
protected_resource = await client.discover_auth()
|
||||
@@ -183,16 +182,7 @@ async def mcp_oauth_login(
|
||||
|
||||
if protected_resource and protected_resource.get("authorization_servers"):
|
||||
auth_server_url = protected_resource["authorization_servers"][0]
|
||||
resource_url = protected_resource.get("resource", server_url)
|
||||
|
||||
# Validate the auth server URL from metadata to prevent SSRF.
|
||||
try:
|
||||
await validate_url(auth_server_url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid authorization server URL in metadata: {e}",
|
||||
)
|
||||
resource_url = protected_resource.get("resource", request.server_url)
|
||||
|
||||
# Step 2a: Discover auth-server metadata (RFC 8414)
|
||||
metadata = await client.discover_auth_server_metadata(auth_server_url)
|
||||
@@ -202,7 +192,7 @@ async def mcp_oauth_login(
|
||||
# Don't assume a resource_url — omitting it lets the auth server choose
|
||||
# the correct audience for the token (RFC 8707 resource is optional).
|
||||
resource_url = None
|
||||
metadata = await client.discover_auth_server_metadata(server_url)
|
||||
metadata = await client.discover_auth_server_metadata(request.server_url)
|
||||
|
||||
if (
|
||||
not metadata
|
||||
@@ -232,18 +222,12 @@ async def mcp_oauth_login(
|
||||
client_id = ""
|
||||
client_secret = ""
|
||||
if registration_endpoint:
|
||||
# Validate the registration endpoint to prevent SSRF via metadata.
|
||||
try:
|
||||
await validate_url(registration_endpoint, trusted_origins=[])
|
||||
except ValueError:
|
||||
pass # Skip registration, fall back to default client_id
|
||||
else:
|
||||
reg_result = await _register_mcp_client(
|
||||
registration_endpoint, redirect_uri, server_url
|
||||
)
|
||||
if reg_result:
|
||||
client_id = reg_result.get("client_id", "")
|
||||
client_secret = reg_result.get("client_secret", "")
|
||||
reg_result = await _register_mcp_client(
|
||||
registration_endpoint, redirect_uri, request.server_url
|
||||
)
|
||||
if reg_result:
|
||||
client_id = reg_result.get("client_id", "")
|
||||
client_secret = reg_result.get("client_secret", "")
|
||||
|
||||
if not client_id:
|
||||
client_id = "autogpt-platform"
|
||||
@@ -261,7 +245,7 @@ async def mcp_oauth_login(
|
||||
"token_url": token_url,
|
||||
"revoke_url": revoke_url,
|
||||
"resource_url": resource_url,
|
||||
"server_url": server_url,
|
||||
"server_url": request.server_url,
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
},
|
||||
@@ -358,7 +342,7 @@ async def mcp_oauth_callback(
|
||||
credentials.metadata["mcp_token_url"] = meta["token_url"]
|
||||
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
|
||||
|
||||
hostname = server_host(meta["server_url"])
|
||||
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
|
||||
credentials.title = f"MCP: {hostname}"
|
||||
|
||||
# Remove old MCP credentials for the same server to prevent stale token buildup.
|
||||
@@ -373,9 +357,7 @@ async def mcp_oauth_callback(
|
||||
):
|
||||
await creds_manager.store.delete_creds_by_id(user_id, old.id)
|
||||
logger.info(
|
||||
"Removed old MCP credential %s for %s",
|
||||
old.id,
|
||||
server_host(meta["server_url"]),
|
||||
f"Removed old MCP credential {old.id} for {meta['server_url']}"
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Could not clean up old MCP credentials", exc_info=True)
|
||||
@@ -393,93 +375,6 @@ async def mcp_oauth_callback(
|
||||
)
|
||||
|
||||
|
||||
# ======================== Bearer Token ======================== #
|
||||
|
||||
|
||||
class MCPStoreTokenRequest(BaseModel):
|
||||
"""Request to store a bearer token for an MCP server that doesn't support OAuth."""
|
||||
|
||||
server_url: str = Field(
|
||||
description="MCP server URL the token authenticates against"
|
||||
)
|
||||
token: SecretStr = Field(
|
||||
min_length=1, description="Bearer token / API key for the MCP server"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/token",
|
||||
summary="Store a bearer token for an MCP server",
|
||||
)
|
||||
async def mcp_store_token(
|
||||
request: MCPStoreTokenRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> CredentialsMetaResponse:
|
||||
"""
|
||||
Store a manually provided bearer token as an MCP credential.
|
||||
|
||||
Used by the Copilot MCPSetupCard when the server doesn't support the MCP
|
||||
OAuth discovery flow (returns 400 from /oauth/login). Subsequent
|
||||
``run_mcp_tool`` calls will automatically pick up the token via
|
||||
``_auto_lookup_credential``.
|
||||
"""
|
||||
token = request.token.get_secret_value().strip()
|
||||
if not token:
|
||||
raise fastapi.HTTPException(status_code=422, detail="Token must not be blank.")
|
||||
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
# Normalize URL so trailing-slash variants match existing credentials.
|
||||
server_url = normalize_mcp_url(request.server_url)
|
||||
hostname = server_host(server_url)
|
||||
|
||||
# Collect IDs of old credentials to clean up after successful create.
|
||||
old_cred_ids: list[str] = []
|
||||
try:
|
||||
old_creds = await creds_manager.store.get_creds_by_provider(
|
||||
user_id, ProviderName.MCP.value
|
||||
)
|
||||
old_cred_ids = [
|
||||
old.id
|
||||
for old in old_creds
|
||||
if isinstance(old, OAuth2Credentials)
|
||||
and normalize_mcp_url((old.metadata or {}).get("mcp_server_url", ""))
|
||||
== server_url
|
||||
]
|
||||
except Exception:
|
||||
logger.debug("Could not query old MCP token credentials", exc_info=True)
|
||||
|
||||
credentials = OAuth2Credentials(
|
||||
provider=ProviderName.MCP.value,
|
||||
title=f"MCP: {hostname}",
|
||||
access_token=SecretStr(token),
|
||||
scopes=[],
|
||||
metadata={"mcp_server_url": server_url},
|
||||
)
|
||||
await creds_manager.create(user_id, credentials)
|
||||
|
||||
# Only delete old credentials after the new one is safely stored.
|
||||
for old_id in old_cred_ids:
|
||||
try:
|
||||
await creds_manager.store.delete_creds_by_id(user_id, old_id)
|
||||
except Exception:
|
||||
logger.debug("Could not clean up old MCP token credential", exc_info=True)
|
||||
|
||||
return CredentialsMetaResponse(
|
||||
id=credentials.id,
|
||||
provider=credentials.provider,
|
||||
type=credentials.type,
|
||||
title=credentials.title,
|
||||
scopes=credentials.scopes,
|
||||
username=credentials.username,
|
||||
host=hostname,
|
||||
)
|
||||
|
||||
|
||||
# ======================== Helpers ======================== #
|
||||
|
||||
|
||||
@@ -505,7 +400,5 @@ async def _register_mcp_client(
|
||||
return data
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Dynamic client registration failed for %s: %s", server_host(server_url), e
|
||||
)
|
||||
logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
|
||||
return None
|
||||
|
||||
@@ -11,11 +11,9 @@ import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.api.features.mcp.routes import router
|
||||
from backend.blocks.mcp.client import MCPClientError, MCPTool
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.util.request import HTTPClientError
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
@@ -30,16 +28,6 @@ async def client():
|
||||
yield c
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _bypass_ssrf_validation():
|
||||
"""Bypass validate_url in all route tests (test URLs don't resolve)."""
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestDiscoverTools:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_success(self, client):
|
||||
@@ -68,12 +56,9 @@ class TestDiscoverTools:
|
||||
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch(
|
||||
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
return_value={
|
||||
@@ -122,6 +107,10 @@ class TestDiscoverTools:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_auto_uses_stored_credential(self, client):
|
||||
"""When no explicit token is given, stored MCP credentials are used."""
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
stored_cred = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
title="MCP: example.com",
|
||||
@@ -135,12 +124,10 @@ class TestDiscoverTools:
|
||||
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch(
|
||||
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=stored_cred,
|
||||
),
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
|
||||
mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred)
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
|
||||
@@ -162,12 +149,9 @@ class TestDiscoverTools:
|
||||
async def test_discover_tools_mcp_error(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch(
|
||||
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
side_effect=MCPClientError("Connection refused")
|
||||
@@ -185,12 +169,9 @@ class TestDiscoverTools:
|
||||
async def test_discover_tools_generic_error(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch(
|
||||
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
|
||||
|
||||
@@ -206,12 +187,9 @@ class TestDiscoverTools:
|
||||
async def test_discover_tools_auth_required(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch(
|
||||
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
|
||||
@@ -229,12 +207,9 @@ class TestDiscoverTools:
|
||||
async def test_discover_tools_forbidden(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch(
|
||||
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
|
||||
@@ -356,6 +331,10 @@ class TestOAuthLogin:
|
||||
class TestOAuthCallback:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_callback_success(self, client):
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
mock_creds = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
title=None,
|
||||
@@ -455,118 +434,3 @@ class TestOAuthCallback:
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "token exchange failed" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
class TestStoreToken:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_token_success(self, client):
|
||||
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
mock_cm.create = AsyncMock()
|
||||
|
||||
response = await client.post(
|
||||
"/token",
|
||||
json={
|
||||
"server_url": "https://mcp.example.com/mcp",
|
||||
"token": "my-api-key-123",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["provider"] == "mcp"
|
||||
assert data["type"] == "oauth2"
|
||||
assert data["host"] == "mcp.example.com"
|
||||
mock_cm.create.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_token_blank_rejected(self, client):
|
||||
"""Blank token string (after stripping) should return 422."""
|
||||
response = await client.post(
|
||||
"/token",
|
||||
json={
|
||||
"server_url": "https://mcp.example.com/mcp",
|
||||
"token": " ",
|
||||
},
|
||||
)
|
||||
# Pydantic min_length=1 catches the whitespace-only token
|
||||
assert response.status_code == 422
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_token_replaces_old_credential(self, client):
|
||||
old_cred = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
title="MCP: mcp.example.com",
|
||||
access_token=SecretStr("old-token"),
|
||||
scopes=[],
|
||||
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
|
||||
)
|
||||
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[old_cred])
|
||||
mock_cm.create = AsyncMock()
|
||||
mock_cm.store.delete_creds_by_id = AsyncMock()
|
||||
|
||||
response = await client.post(
|
||||
"/token",
|
||||
json={
|
||||
"server_url": "https://mcp.example.com/mcp",
|
||||
"token": "new-token",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_cm.store.delete_creds_by_id.assert_called_once_with(
|
||||
"test-user-id", old_cred.id
|
||||
)
|
||||
|
||||
|
||||
class TestSSRFValidation:
|
||||
"""Verify that validate_url is enforced on all endpoints."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked loopback"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "http://localhost/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "blocked loopback" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_login_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked private IP"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/oauth/login",
|
||||
json={"server_url": "http://10.0.0.1/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "blocked private ip" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_token_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked loopback"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/token",
|
||||
json={
|
||||
"server_url": "http://127.0.0.1/mcp",
|
||||
"token": "some-token",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "blocked loopback" in response.json()["detail"].lower()
|
||||
|
||||
@@ -3,29 +3,15 @@ Workspace API routes for managing user file storage.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Annotated
|
||||
from urllib.parse import quote
|
||||
|
||||
import fastapi
|
||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||
from fastapi import Query, UploadFile
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.workspace import (
|
||||
WorkspaceFile,
|
||||
count_workspace_files,
|
||||
get_or_create_workspace,
|
||||
get_workspace,
|
||||
get_workspace_file,
|
||||
get_workspace_total_size,
|
||||
soft_delete_workspace_file,
|
||||
)
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
from backend.data.workspace import WorkspaceFile, get_workspace, get_workspace_file
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
|
||||
@@ -112,25 +98,6 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
|
||||
raise
|
||||
|
||||
|
||||
class UploadFileResponse(BaseModel):
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
|
||||
|
||||
class DeleteFileResponse(BaseModel):
|
||||
deleted: bool
|
||||
|
||||
|
||||
class StorageUsageResponse(BaseModel):
|
||||
used_bytes: int
|
||||
limit_bytes: int
|
||||
used_percent: float
|
||||
file_count: int
|
||||
|
||||
|
||||
@router.get(
|
||||
"/files/{file_id}/download",
|
||||
summary="Download file by ID",
|
||||
@@ -153,148 +120,3 @@ async def download_file(
|
||||
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
return await _create_file_download_response(file)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/files/{file_id}",
|
||||
summary="Delete a workspace file",
|
||||
)
|
||||
async def delete_workspace_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
file_id: str,
|
||||
) -> DeleteFileResponse:
|
||||
"""
|
||||
Soft-delete a workspace file and attempt to remove it from storage.
|
||||
|
||||
Used when a user clears a file input in the builder.
|
||||
"""
|
||||
workspace = await get_workspace(user_id)
|
||||
if workspace is None:
|
||||
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
|
||||
|
||||
manager = WorkspaceManager(user_id, workspace.id)
|
||||
deleted = await manager.delete_file(file_id)
|
||||
if not deleted:
|
||||
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
return DeleteFileResponse(deleted=True)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/files/upload",
|
||||
summary="Upload file to workspace",
|
||||
)
|
||||
async def upload_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
file: UploadFile,
|
||||
session_id: str | None = Query(default=None),
|
||||
) -> UploadFileResponse:
|
||||
"""
|
||||
Upload a file to the user's workspace.
|
||||
|
||||
Files are stored in session-scoped paths when session_id is provided,
|
||||
so the agent's session-scoped tools can discover them automatically.
|
||||
"""
|
||||
config = Config()
|
||||
|
||||
# Sanitize filename — strip any directory components
|
||||
filename = os.path.basename(file.filename or "upload") or "upload"
|
||||
|
||||
# Read file content with early abort on size limit
|
||||
max_file_bytes = config.max_file_size_mb * 1024 * 1024
|
||||
chunks: list[bytes] = []
|
||||
total_size = 0
|
||||
while chunk := await file.read(64 * 1024): # 64KB chunks
|
||||
total_size += len(chunk)
|
||||
if total_size > max_file_bytes:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File exceeds maximum size of {config.max_file_size_mb} MB",
|
||||
)
|
||||
chunks.append(chunk)
|
||||
content = b"".join(chunks)
|
||||
|
||||
# Get or create workspace
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
# Pre-write storage cap check (soft check — final enforcement is post-write)
|
||||
storage_limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
|
||||
current_usage = await get_workspace_total_size(workspace.id)
|
||||
if storage_limit_bytes and current_usage + len(content) > storage_limit_bytes:
|
||||
used_percent = (current_usage / storage_limit_bytes) * 100
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail={
|
||||
"message": "Storage limit exceeded",
|
||||
"used_bytes": current_usage,
|
||||
"limit_bytes": storage_limit_bytes,
|
||||
"used_percent": round(used_percent, 1),
|
||||
},
|
||||
)
|
||||
|
||||
# Warn at 80% usage
|
||||
if (
|
||||
storage_limit_bytes
|
||||
and (usage_ratio := (current_usage + len(content)) / storage_limit_bytes) >= 0.8
|
||||
):
|
||||
logger.warning(
|
||||
f"User {user_id} workspace storage at {usage_ratio * 100:.1f}% "
|
||||
f"({current_usage + len(content)} / {storage_limit_bytes} bytes)"
|
||||
)
|
||||
|
||||
# Virus scan
|
||||
await scan_content_safe(content, filename=filename)
|
||||
|
||||
# Write file via WorkspaceManager
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
try:
|
||||
workspace_file = await manager.write_file(content, filename)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
|
||||
|
||||
# Post-write storage check — eliminates TOCTOU race on the quota.
|
||||
# If a concurrent upload pushed us over the limit, undo this write.
|
||||
new_total = await get_workspace_total_size(workspace.id)
|
||||
if storage_limit_bytes and new_total > storage_limit_bytes:
|
||||
await soft_delete_workspace_file(workspace_file.id, workspace.id)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail={
|
||||
"message": "Storage limit exceeded (concurrent upload)",
|
||||
"used_bytes": new_total,
|
||||
"limit_bytes": storage_limit_bytes,
|
||||
},
|
||||
)
|
||||
|
||||
return UploadFileResponse(
|
||||
file_id=workspace_file.id,
|
||||
name=workspace_file.name,
|
||||
path=workspace_file.path,
|
||||
mime_type=workspace_file.mime_type,
|
||||
size_bytes=workspace_file.size_bytes,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/storage/usage",
|
||||
summary="Get workspace storage usage",
|
||||
)
|
||||
async def get_storage_usage(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
) -> StorageUsageResponse:
|
||||
"""
|
||||
Get storage usage information for the user's workspace.
|
||||
"""
|
||||
config = Config()
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
used_bytes = await get_workspace_total_size(workspace.id)
|
||||
file_count = await count_workspace_files(workspace.id)
|
||||
limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
|
||||
|
||||
return StorageUsageResponse(
|
||||
used_bytes=used_bytes,
|
||||
limit_bytes=limit_bytes,
|
||||
used_percent=round((used_bytes / limit_bytes) * 100, 1) if limit_bytes else 0,
|
||||
file_count=file_count,
|
||||
)
|
||||
|
||||
@@ -1,359 +0,0 @@
|
||||
"""Tests for workspace file upload and download routes."""
|
||||
|
||||
import io
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.api.features.workspace import routes as workspace_routes
|
||||
from backend.data.workspace import WorkspaceFile
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(workspace_routes.router)
|
||||
|
||||
|
||||
@app.exception_handler(ValueError)
|
||||
async def _value_error_handler(
|
||||
request: fastapi.Request, exc: ValueError
|
||||
) -> fastapi.responses.JSONResponse:
|
||||
"""Mirror the production ValueError → 400 mapping from rest_api.py."""
|
||||
return fastapi.responses.JSONResponse(status_code=400, content={"detail": str(exc)})
|
||||
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})()
|
||||
|
||||
_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
MOCK_FILE = WorkspaceFile(
|
||||
id="file-aaa-bbb",
|
||||
workspace_id="ws-1",
|
||||
created_at=_NOW,
|
||||
updated_at=_NOW,
|
||||
name="hello.txt",
|
||||
path="/session/hello.txt",
|
||||
mime_type="text/plain",
|
||||
size_bytes=13,
|
||||
storage_path="local://hello.txt",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _upload(
|
||||
filename: str = "hello.txt",
|
||||
content: bytes = b"Hello, world!",
|
||||
content_type: str = "text/plain",
|
||||
):
|
||||
"""Helper to POST a file upload."""
|
||||
return client.post(
|
||||
"/files/upload?session_id=sess-1",
|
||||
files={"file": (filename, io.BytesIO(content), content_type)},
|
||||
)
|
||||
|
||||
|
||||
# ---- Happy path ----
|
||||
|
||||
|
||||
def test_upload_happy_path(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["file_id"] == "file-aaa-bbb"
|
||||
assert data["name"] == "hello.txt"
|
||||
assert data["size_bytes"] == 13
|
||||
|
||||
|
||||
# ---- Per-file size limit ----
|
||||
|
||||
|
||||
def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
|
||||
"""Files larger than max_file_size_mb should be rejected with 413."""
|
||||
cfg = mocker.patch("backend.api.features.workspace.routes.Config")
|
||||
cfg.return_value.max_file_size_mb = 0 # 0 MB → any content is too big
|
||||
cfg.return_value.max_workspace_storage_mb = 500
|
||||
|
||||
response = _upload(content=b"x" * 1024)
|
||||
assert response.status_code == 413
|
||||
|
||||
|
||||
# ---- Storage quota exceeded ----
|
||||
|
||||
|
||||
def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
# Current usage already at limit
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=500 * 1024 * 1024,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 413
|
||||
assert "Storage limit exceeded" in response.text
|
||||
|
||||
|
||||
# ---- Post-write quota race (B2) ----
|
||||
|
||||
|
||||
def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
|
||||
"""If a concurrent upload tips the total over the limit after write,
|
||||
the file should be soft-deleted and 413 returned."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
# Pre-write check passes (under limit), but post-write check fails
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
side_effect=[0, 600 * 1024 * 1024], # first call OK, second over limit
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
mock_delete = mocker.patch(
|
||||
"backend.api.features.workspace.routes.soft_delete_workspace_file",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 413
|
||||
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-1")
|
||||
|
||||
|
||||
# ---- Any extension accepted (no allowlist) ----
|
||||
|
||||
|
||||
def test_upload_any_extension(mocker: pytest_mock.MockFixture):
|
||||
"""Any file extension should be accepted — ClamAV is the security layer."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload(filename="data.xyz", content=b"arbitrary")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ---- Virus scan rejection ----
|
||||
|
||||
|
||||
def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
|
||||
"""Files flagged by ClamAV should be rejected and never written to storage."""
|
||||
from backend.api.features.store.exceptions import VirusDetectedError
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
side_effect=VirusDetectedError("Eicar-Test-Signature"),
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
|
||||
assert response.status_code == 400
|
||||
assert "Virus detected" in response.text
|
||||
mock_manager.write_file.assert_not_called()
|
||||
|
||||
|
||||
# ---- No file extension ----
|
||||
|
||||
|
||||
def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
|
||||
"""Files without an extension should be accepted and stored as-is."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload(
|
||||
filename="Makefile",
|
||||
content=b"all:\n\techo hello",
|
||||
content_type="application/octet-stream",
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_manager.write_file.assert_called_once()
|
||||
assert mock_manager.write_file.call_args[0][1] == "Makefile"
|
||||
|
||||
|
||||
# ---- Filename sanitization (SF5) ----
|
||||
|
||||
|
||||
def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
|
||||
"""Path-traversal filenames should be reduced to their basename."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
# Filename with traversal
|
||||
_upload(filename="../../etc/passwd.txt")
|
||||
|
||||
# write_file should have been called with just the basename
|
||||
mock_manager.write_file.assert_called_once()
|
||||
call_args = mock_manager.write_file.call_args
|
||||
assert call_args[0][1] == "passwd.txt"
|
||||
|
||||
|
||||
# ---- Download ----
|
||||
|
||||
|
||||
def test_download_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_file",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.get("/files/some-file-id/download")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ---- Delete ----
|
||||
|
||||
|
||||
def test_delete_file_success(mocker: pytest_mock.MockFixture):
|
||||
"""Deleting an existing file should return {"deleted": true}."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.delete_file = mocker.AsyncMock(return_value=True)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = client.delete("/files/file-aaa-bbb")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"deleted": True}
|
||||
mock_manager.delete_file.assert_called_once_with("file-aaa-bbb")
|
||||
|
||||
|
||||
def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
"""Deleting a non-existent file should return 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.delete_file = mocker.AsyncMock(return_value=False)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = client.delete("/files/nonexistent-id")
|
||||
assert response.status_code == 404
|
||||
assert "File not found" in response.text
|
||||
|
||||
|
||||
def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
|
||||
"""Deleting when user has no workspace should return 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.delete("/files/file-aaa-bbb")
|
||||
assert response.status_code == 404
|
||||
assert "Workspace not found" in response.text
|
||||
@@ -116,7 +116,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
||||
CLAUDE_4_6_OPUS = "claude-opus-4-6"
|
||||
CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
|
||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||
# AI/ML API models
|
||||
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
||||
@@ -275,9 +274,6 @@ MODEL_METADATA = {
|
||||
LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
|
||||
), # claude-opus-4-6
|
||||
LlmModel.CLAUDE_4_6_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Sonnet 4.6", "Anthropic", "Anthropic", 3
|
||||
), # claude-sonnet-4-6
|
||||
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
|
||||
), # claude-opus-4-5-20251101
|
||||
|
||||
@@ -6,6 +6,7 @@ and execute them. Works like AgentExecutorBlock — the user selects a tool from
|
||||
dropdown and the input/output schema adapts dynamically.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
@@ -19,11 +20,6 @@ from backend.blocks._base import (
|
||||
BlockType,
|
||||
)
|
||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
||||
from backend.blocks.mcp.helpers import (
|
||||
auto_lookup_mcp_credential,
|
||||
normalize_mcp_url,
|
||||
parse_mcp_content,
|
||||
)
|
||||
from backend.data.block import BlockInput, BlockOutput
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
@@ -183,7 +179,31 @@ class MCPToolBlock(Block):
|
||||
f"{error_text or 'Unknown error'}"
|
||||
)
|
||||
|
||||
return parse_mcp_content(result.content)
|
||||
# Extract text content from the result
|
||||
output_parts = []
|
||||
for item in result.content:
|
||||
if item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
# Try to parse as JSON for structured output
|
||||
try:
|
||||
output_parts.append(json.loads(text))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
output_parts.append(text)
|
||||
elif item.get("type") == "image":
|
||||
output_parts.append(
|
||||
{
|
||||
"type": "image",
|
||||
"data": item.get("data"),
|
||||
"mimeType": item.get("mimeType"),
|
||||
}
|
||||
)
|
||||
elif item.get("type") == "resource":
|
||||
output_parts.append(item.get("resource", {}))
|
||||
|
||||
# If single result, unwrap
|
||||
if len(output_parts) == 1:
|
||||
return output_parts[0]
|
||||
return output_parts if output_parts else None
|
||||
|
||||
@staticmethod
|
||||
async def _auto_lookup_credential(
|
||||
@@ -191,10 +211,37 @@ class MCPToolBlock(Block):
|
||||
) -> "OAuth2Credentials | None":
|
||||
"""Auto-lookup stored MCP credential for a server URL.
|
||||
|
||||
Delegates to :func:`~backend.blocks.mcp.helpers.auto_lookup_mcp_credential`.
|
||||
The caller should pass a normalized URL.
|
||||
This is a fallback for nodes that don't have ``credentials`` explicitly
|
||||
set (e.g. nodes created before the credential field was wired up).
|
||||
"""
|
||||
return await auto_lookup_mcp_credential(user_id, server_url)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
try:
|
||||
mgr = IntegrationCredentialsManager()
|
||||
mcp_creds = await mgr.store.get_creds_by_provider(
|
||||
user_id, ProviderName.MCP.value
|
||||
)
|
||||
best: OAuth2Credentials | None = None
|
||||
for cred in mcp_creds:
|
||||
if (
|
||||
isinstance(cred, OAuth2Credentials)
|
||||
and (cred.metadata or {}).get("mcp_server_url") == server_url
|
||||
):
|
||||
if best is None or (
|
||||
(cred.access_token_expires_at or 0)
|
||||
> (best.access_token_expires_at or 0)
|
||||
):
|
||||
best = cred
|
||||
if best:
|
||||
best = await mgr.refresh_if_needed(user_id, best)
|
||||
logger.info(
|
||||
"Auto-resolved MCP credential %s for %s", best.id, server_url
|
||||
)
|
||||
return best
|
||||
except Exception:
|
||||
logger.warning("Auto-lookup MCP credential failed", exc_info=True)
|
||||
return None
|
||||
|
||||
async def run(
|
||||
self,
|
||||
@@ -231,7 +278,7 @@ class MCPToolBlock(Block):
|
||||
# the stored MCP credential for this server URL.
|
||||
if credentials is None:
|
||||
credentials = await self._auto_lookup_credential(
|
||||
user_id, normalize_mcp_url(input_data.server_url)
|
||||
user_id, input_data.server_url
|
||||
)
|
||||
|
||||
auth_token = (
|
||||
|
||||
@@ -55,9 +55,7 @@ class MCPClient:
|
||||
server_url: str,
|
||||
auth_token: str | None = None,
|
||||
):
|
||||
from backend.blocks.mcp.helpers import normalize_mcp_url
|
||||
|
||||
self.server_url = normalize_mcp_url(server_url)
|
||||
self.server_url = server_url.rstrip("/")
|
||||
self.auth_token = auth_token
|
||||
self._request_id = 0
|
||||
self._session_id: str | None = None
|
||||
|
||||
@@ -1,117 +0,0 @@
|
||||
"""Shared MCP helpers used by blocks, copilot tools, and API routes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def normalize_mcp_url(url: str) -> str:
|
||||
"""Normalize an MCP server URL for consistent credential matching.
|
||||
|
||||
Strips leading/trailing whitespace and a single trailing slash so that
|
||||
``https://mcp.example.com/`` and ``https://mcp.example.com`` resolve to
|
||||
the same stored credential.
|
||||
"""
|
||||
return url.strip().rstrip("/")
|
||||
|
||||
|
||||
def server_host(server_url: str) -> str:
|
||||
"""Extract the hostname from a server URL for display purposes.
|
||||
|
||||
Uses ``parsed.hostname`` (never ``netloc``) to strip any embedded
|
||||
username/password before surfacing the value in UI messages.
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(server_url)
|
||||
return parsed.hostname or server_url
|
||||
except Exception:
|
||||
return server_url
|
||||
|
||||
|
||||
def parse_mcp_content(content: list[dict[str, Any]]) -> Any:
|
||||
"""Parse MCP tool response content into a plain Python value.
|
||||
|
||||
- text items: parsed as JSON when possible, kept as str otherwise
|
||||
- image items: kept as ``{type, data, mimeType}`` dict for frontend rendering
|
||||
- resource items: unwrapped to their resource payload dict
|
||||
|
||||
Single-item responses are unwrapped from the list; multiple items are
|
||||
returned as a list; empty content returns ``None``.
|
||||
"""
|
||||
output_parts: list[Any] = []
|
||||
for item in content:
|
||||
item_type = item.get("type")
|
||||
if item_type == "text":
|
||||
text = item.get("text", "")
|
||||
try:
|
||||
output_parts.append(json.loads(text))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
output_parts.append(text)
|
||||
elif item_type == "image":
|
||||
output_parts.append(
|
||||
{
|
||||
"type": "image",
|
||||
"data": item.get("data"),
|
||||
"mimeType": item.get("mimeType"),
|
||||
}
|
||||
)
|
||||
elif item_type == "resource":
|
||||
output_parts.append(item.get("resource", {}))
|
||||
|
||||
if len(output_parts) == 1:
|
||||
return output_parts[0]
|
||||
return output_parts or None
|
||||
|
||||
|
||||
async def auto_lookup_mcp_credential(
|
||||
user_id: str, server_url: str
|
||||
) -> OAuth2Credentials | None:
|
||||
"""Look up the best stored MCP credential for *server_url*.
|
||||
|
||||
The caller should pass a **normalized** URL (via :func:`normalize_mcp_url`)
|
||||
so the comparison with ``mcp_server_url`` in credential metadata matches.
|
||||
|
||||
Returns the credential with the latest ``access_token_expires_at``, refreshed
|
||||
if needed, or ``None`` when no match is found.
|
||||
"""
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
try:
|
||||
mgr = IntegrationCredentialsManager()
|
||||
mcp_creds = await mgr.store.get_creds_by_provider(
|
||||
user_id, ProviderName.MCP.value
|
||||
)
|
||||
# Collect all matching credentials and pick the best one.
|
||||
# Primary sort: latest access_token_expires_at (tokens with expiry
|
||||
# are preferred over non-expiring ones). Secondary sort: last in
|
||||
# iteration order, which corresponds to the most recently created
|
||||
# row — this acts as a tiebreaker when multiple bearer tokens have
|
||||
# no expiry (e.g. after a failed old-credential cleanup).
|
||||
best: OAuth2Credentials | None = None
|
||||
for cred in mcp_creds:
|
||||
if (
|
||||
isinstance(cred, OAuth2Credentials)
|
||||
and (cred.metadata or {}).get("mcp_server_url") == server_url
|
||||
):
|
||||
if best is None or (
|
||||
(cred.access_token_expires_at or 0)
|
||||
>= (best.access_token_expires_at or 0)
|
||||
):
|
||||
best = cred
|
||||
if best:
|
||||
best = await mgr.refresh_if_needed(user_id, best)
|
||||
logger.info("Auto-resolved MCP credential %s for %s", best.id, server_url)
|
||||
return best
|
||||
except Exception:
|
||||
logger.warning("Auto-lookup MCP credential failed", exc_info=True)
|
||||
return None
|
||||
@@ -1,98 +0,0 @@
|
||||
"""Unit tests for the shared MCP helpers."""
|
||||
|
||||
from backend.blocks.mcp.helpers import normalize_mcp_url, parse_mcp_content, server_host
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# normalize_mcp_url
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_normalize_trailing_slash():
|
||||
assert normalize_mcp_url("https://mcp.example.com/") == "https://mcp.example.com"
|
||||
|
||||
|
||||
def test_normalize_whitespace():
|
||||
assert normalize_mcp_url(" https://mcp.example.com ") == "https://mcp.example.com"
|
||||
|
||||
|
||||
def test_normalize_both():
|
||||
assert (
|
||||
normalize_mcp_url(" https://mcp.example.com/ ") == "https://mcp.example.com"
|
||||
)
|
||||
|
||||
|
||||
def test_normalize_noop():
|
||||
assert normalize_mcp_url("https://mcp.example.com") == "https://mcp.example.com"
|
||||
|
||||
|
||||
def test_normalize_path_with_trailing_slash():
|
||||
assert (
|
||||
normalize_mcp_url("https://mcp.example.com/path/")
|
||||
== "https://mcp.example.com/path"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# server_host
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_server_host_standard_url():
|
||||
assert server_host("https://mcp.example.com/mcp") == "mcp.example.com"
|
||||
|
||||
|
||||
def test_server_host_strips_credentials():
|
||||
"""hostname must not expose user:pass."""
|
||||
assert server_host("https://user:secret@mcp.example.com/mcp") == "mcp.example.com"
|
||||
|
||||
|
||||
def test_server_host_with_port():
|
||||
"""Port should not appear in hostname (hostname strips it)."""
|
||||
assert server_host("https://mcp.example.com:8080/mcp") == "mcp.example.com"
|
||||
|
||||
|
||||
def test_server_host_fallback():
|
||||
"""Falls back to the raw string for un-parseable URLs."""
|
||||
assert server_host("not-a-url") == "not-a-url"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_mcp_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_text_plain():
|
||||
assert parse_mcp_content([{"type": "text", "text": "hello world"}]) == "hello world"
|
||||
|
||||
|
||||
def test_parse_text_json():
|
||||
content = [{"type": "text", "text": '{"status": "ok", "count": 42}'}]
|
||||
assert parse_mcp_content(content) == {"status": "ok", "count": 42}
|
||||
|
||||
|
||||
def test_parse_image():
|
||||
content = [{"type": "image", "data": "abc123==", "mimeType": "image/png"}]
|
||||
assert parse_mcp_content(content) == {
|
||||
"type": "image",
|
||||
"data": "abc123==",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
|
||||
|
||||
def test_parse_resource():
|
||||
content = [
|
||||
{"type": "resource", "resource": {"uri": "file:///tmp/out.txt", "text": "hi"}}
|
||||
]
|
||||
assert parse_mcp_content(content) == {"uri": "file:///tmp/out.txt", "text": "hi"}
|
||||
|
||||
|
||||
def test_parse_multi_item():
|
||||
content = [
|
||||
{"type": "text", "text": "first"},
|
||||
{"type": "text", "text": "second"},
|
||||
]
|
||||
assert parse_mcp_content(content) == ["first", "second"]
|
||||
|
||||
|
||||
def test_parse_empty():
|
||||
assert parse_mcp_content([]) is None
|
||||
@@ -83,8 +83,7 @@ class StagehandRecommendedLlmModel(str, Enum):
|
||||
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
|
||||
|
||||
# Anthropic
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929" # Keep for backwards compat
|
||||
CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
@@ -138,7 +137,7 @@ class StagehandObserveBlock(Block):
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
@@ -228,7 +227,7 @@ class StagehandActBlock(Block):
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
@@ -325,7 +324,7 @@ class StagehandExtractBlock(Block):
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from .service import stream_chat_completion_baseline
|
||||
|
||||
__all__ = ["stream_chat_completion_baseline"]
|
||||
@@ -1,420 +0,0 @@
|
||||
"""Baseline LLM fallback — OpenAI-compatible streaming with tool calling.
|
||||
|
||||
Used when ``CHAT_USE_CLAUDE_AGENT_SDK=false``, e.g. as a fallback when the
|
||||
Claude Agent SDK / Anthropic API is unavailable. Routes through any
|
||||
OpenAI-compatible provider (OpenRouter by default) and reuses the same
|
||||
shared tool registry as the SDK path.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from langfuse import propagate_attributes
|
||||
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import (
|
||||
_build_system_prompt,
|
||||
_generate_session_title,
|
||||
client,
|
||||
config,
|
||||
)
|
||||
from backend.copilot.tools import execute_tool, get_available_tools
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.prompt import compress_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Set to hold background tasks to prevent garbage collection
|
||||
_background_tasks: set[asyncio.Task[Any]] = set()
|
||||
|
||||
# Maximum number of tool-call rounds before forcing a text response.
|
||||
_MAX_TOOL_ROUNDS = 30
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
session_id: str, message: str, user_id: str | None
|
||||
) -> None:
|
||||
"""Generate and persist a session title in the background."""
|
||||
try:
|
||||
title = await _generate_session_title(message, user_id, session_id)
|
||||
if title and user_id:
|
||||
await update_session_title(session_id, user_id, title, only_if_empty=True)
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Failed to update session title: %s", e)
|
||||
|
||||
|
||||
async def _compress_session_messages(
|
||||
messages: list[ChatMessage],
|
||||
) -> list[ChatMessage]:
|
||||
"""Compress session messages if they exceed the model's token limit.
|
||||
|
||||
Uses the shared compress_context() utility which supports LLM-based
|
||||
summarization of older messages while keeping recent ones intact,
|
||||
with progressive truncation and middle-out deletion as fallbacks.
|
||||
"""
|
||||
messages_dict = []
|
||||
for msg in messages:
|
||||
msg_dict: dict[str, Any] = {"role": msg.role}
|
||||
if msg.content:
|
||||
msg_dict["content"] = msg.content
|
||||
messages_dict.append(msg_dict)
|
||||
|
||||
try:
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
model=config.model,
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
model=config.model,
|
||||
client=None,
|
||||
)
|
||||
|
||||
if result.was_compacted:
|
||||
logger.info(
|
||||
"[Baseline] Context compacted: %d -> %d tokens "
|
||||
"(%d summarized, %d dropped)",
|
||||
result.original_token_count,
|
||||
result.token_count,
|
||||
result.messages_summarized,
|
||||
result.messages_dropped,
|
||||
)
|
||||
return [
|
||||
ChatMessage(role=m["role"], content=m.get("content"))
|
||||
for m in result.messages
|
||||
]
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
async def stream_chat_completion_baseline(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
is_user_message: bool = True,
|
||||
user_id: str | None = None,
|
||||
session: ChatSession | None = None,
|
||||
**_kwargs: Any,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Baseline LLM with tool calling via OpenAI-compatible API.
|
||||
|
||||
Designed as a fallback when the Claude Agent SDK is unavailable.
|
||||
Uses the same tool registry as the SDK path but routes through any
|
||||
OpenAI-compatible provider (e.g. OpenRouter).
|
||||
|
||||
Flow: stream response -> if tool_calls, execute them -> feed results back -> repeat.
|
||||
"""
|
||||
if session is None:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(
|
||||
f"Session {session_id} not found. Please create a new session first."
|
||||
)
|
||||
|
||||
# Append user message
|
||||
new_role = "user" if is_user_message else "assistant"
|
||||
if message and (
|
||||
len(session.messages) == 0
|
||||
or not (
|
||||
session.messages[-1].role == new_role
|
||||
and session.messages[-1].content == message
|
||||
)
|
||||
):
|
||||
session.messages.append(ChatMessage(role=new_role, content=message))
|
||||
if is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(message),
|
||||
)
|
||||
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Generate title for new sessions
|
||||
if is_user_message and not session.title:
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
if len(user_messages) == 1:
|
||||
first_message = user_messages[0].content or message or ""
|
||||
if first_message:
|
||||
task = asyncio.create_task(
|
||||
_update_title_async(session_id, first_message, user_id)
|
||||
)
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
# Build system prompt only on the first turn to avoid mid-conversation
|
||||
# changes from concurrent chats updating business understanding.
|
||||
is_first_turn = len(session.messages) <= 1
|
||||
if is_first_turn:
|
||||
system_prompt, _ = await _build_system_prompt(
|
||||
user_id, has_conversation_history=False
|
||||
)
|
||||
else:
|
||||
system_prompt, _ = await _build_system_prompt(
|
||||
user_id=None, has_conversation_history=True
|
||||
)
|
||||
|
||||
# Compress context if approaching the model's token limit
|
||||
messages_for_context = await _compress_session_messages(session.messages)
|
||||
|
||||
# Build OpenAI message list from session history
|
||||
openai_messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": system_prompt}
|
||||
]
|
||||
for msg in messages_for_context:
|
||||
if msg.role in ("user", "assistant") and msg.content:
|
||||
openai_messages.append({"role": msg.role, "content": msg.content})
|
||||
|
||||
tools = get_available_tools()
|
||||
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
|
||||
# Propagate user/session context to Langfuse so all LLM calls within
|
||||
# this request are grouped under a single trace with proper attribution.
|
||||
_trace_ctx: Any = None
|
||||
try:
|
||||
_trace_ctx = propagate_attributes(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
trace_name="copilot-baseline",
|
||||
tags=["baseline"],
|
||||
)
|
||||
_trace_ctx.__enter__()
|
||||
except Exception:
|
||||
logger.warning("[Baseline] Langfuse trace context setup failed")
|
||||
|
||||
assistant_text = ""
|
||||
text_block_id = str(uuid.uuid4())
|
||||
text_started = False
|
||||
step_open = False
|
||||
try:
|
||||
for _round in range(_MAX_TOOL_ROUNDS):
|
||||
# Open a new step for each LLM round
|
||||
yield StreamStartStep()
|
||||
step_open = True
|
||||
|
||||
# Stream a response from the model
|
||||
create_kwargs: dict[str, Any] = dict(
|
||||
model=config.model,
|
||||
messages=openai_messages,
|
||||
stream=True,
|
||||
)
|
||||
if tools:
|
||||
create_kwargs["tools"] = tools
|
||||
response = await client.chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
|
||||
|
||||
# Accumulate streamed response (text + tool calls)
|
||||
round_text = ""
|
||||
tool_calls_by_index: dict[int, dict[str, str]] = {}
|
||||
|
||||
async for chunk in response:
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
if not delta:
|
||||
continue
|
||||
|
||||
# Text content
|
||||
if delta.content:
|
||||
if not text_started:
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
text_started = True
|
||||
round_text += delta.content
|
||||
yield StreamTextDelta(id=text_block_id, delta=delta.content)
|
||||
|
||||
# Tool call fragments (streamed incrementally)
|
||||
if delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
idx = tc.index
|
||||
if idx not in tool_calls_by_index:
|
||||
tool_calls_by_index[idx] = {
|
||||
"id": "",
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
}
|
||||
entry = tool_calls_by_index[idx]
|
||||
if tc.id:
|
||||
entry["id"] = tc.id
|
||||
if tc.function and tc.function.name:
|
||||
entry["name"] = tc.function.name
|
||||
if tc.function and tc.function.arguments:
|
||||
entry["arguments"] += tc.function.arguments
|
||||
|
||||
# Close text block if we had one this round
|
||||
if text_started:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
text_started = False
|
||||
text_block_id = str(uuid.uuid4())
|
||||
|
||||
# Accumulate text for session persistence
|
||||
assistant_text += round_text
|
||||
|
||||
# No tool calls -> model is done
|
||||
if not tool_calls_by_index:
|
||||
yield StreamFinishStep()
|
||||
step_open = False
|
||||
break
|
||||
|
||||
# Close step before tool execution
|
||||
yield StreamFinishStep()
|
||||
step_open = False
|
||||
|
||||
# Append the assistant message with tool_calls to context.
|
||||
assistant_msg: dict[str, Any] = {"role": "assistant"}
|
||||
if round_text:
|
||||
assistant_msg["content"] = round_text
|
||||
assistant_msg["tool_calls"] = [
|
||||
{
|
||||
"id": tc["id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc["name"],
|
||||
"arguments": tc["arguments"] or "{}",
|
||||
},
|
||||
}
|
||||
for tc in tool_calls_by_index.values()
|
||||
]
|
||||
openai_messages.append(assistant_msg)
|
||||
|
||||
# Execute each tool call and stream events
|
||||
for tc in tool_calls_by_index.values():
|
||||
tool_call_id = tc["id"]
|
||||
tool_name = tc["name"]
|
||||
raw_args = tc["arguments"] or "{}"
|
||||
try:
|
||||
tool_args = orjson.loads(raw_args)
|
||||
except orjson.JSONDecodeError as parse_err:
|
||||
parse_error = (
|
||||
f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
|
||||
)
|
||||
logger.warning("[Baseline] %s", parse_error)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=parse_error,
|
||||
success=False,
|
||||
)
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": parse_error,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
yield StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
input=tool_args,
|
||||
)
|
||||
|
||||
# Execute via shared tool registry
|
||||
try:
|
||||
result: StreamToolOutputAvailable = await execute_tool(
|
||||
tool_name=tool_name,
|
||||
parameters=tool_args,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
yield result
|
||||
tool_output = (
|
||||
result.output
|
||||
if isinstance(result.output, str)
|
||||
else str(result.output)
|
||||
)
|
||||
except Exception as e:
|
||||
error_output = f"Tool execution error: {e}"
|
||||
logger.error(
|
||||
"[Baseline] Tool %s failed: %s",
|
||||
tool_name,
|
||||
error_output,
|
||||
exc_info=True,
|
||||
)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=error_output,
|
||||
success=False,
|
||||
)
|
||||
tool_output = error_output
|
||||
|
||||
# Append tool result to context for next round
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": tool_output,
|
||||
}
|
||||
)
|
||||
else:
|
||||
# for-loop exhausted without break -> tool-round limit hit
|
||||
limit_msg = (
|
||||
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
|
||||
"without a final response."
|
||||
)
|
||||
logger.error("[Baseline] %s", limit_msg)
|
||||
yield StreamError(
|
||||
errorText=limit_msg,
|
||||
code="baseline_tool_round_limit",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e) or type(e).__name__
|
||||
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
|
||||
# Close any open text/step before emitting error
|
||||
if text_started:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
if step_open:
|
||||
yield StreamFinishStep()
|
||||
yield StreamError(errorText=error_msg, code="baseline_error")
|
||||
# Still persist whatever we got
|
||||
finally:
|
||||
# Close Langfuse trace context
|
||||
if _trace_ctx is not None:
|
||||
try:
|
||||
_trace_ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("[Baseline] Langfuse trace context teardown failed")
|
||||
|
||||
# Persist assistant response
|
||||
if assistant_text:
|
||||
session.messages.append(
|
||||
ChatMessage(role="assistant", content=assistant_text)
|
||||
)
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
except Exception as persist_err:
|
||||
logger.error("[Baseline] Failed to persist session: %s", persist_err)
|
||||
|
||||
yield StreamFinish()
|
||||
@@ -1,99 +0,0 @@
|
||||
import logging
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.baseline import stream_chat_completion_baseline
|
||||
from backend.copilot.model import (
|
||||
create_chat_session,
|
||||
get_chat_session,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.copilot.response_model import (
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamStart,
|
||||
StreamTextDelta,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_baseline_multi_turn(setup_test_user, test_user_id):
|
||||
"""Test that the baseline LLM path streams responses and maintains history.
|
||||
|
||||
Turn 1: Send a message with a unique keyword.
|
||||
Turn 2: Ask the model to recall the keyword — proving conversation history
|
||||
is correctly passed to the single-call LLM.
|
||||
"""
|
||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# --- Turn 1: send a message with a unique keyword ---
|
||||
keyword = "QUASAR99"
|
||||
turn1_msg = (
|
||||
f"Please remember this special keyword: {keyword}. "
|
||||
"Just confirm you've noted it, keep your response brief."
|
||||
)
|
||||
turn1_text = ""
|
||||
turn1_errors: list[str] = []
|
||||
got_start = False
|
||||
got_finish = False
|
||||
|
||||
async for chunk in stream_chat_completion_baseline(
|
||||
session.session_id,
|
||||
turn1_msg,
|
||||
user_id=test_user_id,
|
||||
):
|
||||
if isinstance(chunk, StreamStart):
|
||||
got_start = True
|
||||
elif isinstance(chunk, StreamTextDelta):
|
||||
turn1_text += chunk.delta
|
||||
elif isinstance(chunk, StreamError):
|
||||
turn1_errors.append(chunk.errorText)
|
||||
elif isinstance(chunk, StreamFinish):
|
||||
got_finish = True
|
||||
|
||||
assert got_start, "Turn 1 did not yield StreamStart"
|
||||
assert got_finish, "Turn 1 did not yield StreamFinish"
|
||||
assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
|
||||
assert turn1_text, "Turn 1 produced no text"
|
||||
logger.info(f"Turn 1 response: {turn1_text[:100]}")
|
||||
|
||||
# Reload session for turn 2
|
||||
session = await get_chat_session(session.session_id, test_user_id)
|
||||
assert session, "Session not found after turn 1"
|
||||
|
||||
# Verify messages were persisted (user + assistant)
|
||||
assert (
|
||||
len(session.messages) >= 2
|
||||
), f"Expected at least 2 messages after turn 1, got {len(session.messages)}"
|
||||
|
||||
# --- Turn 2: ask model to recall the keyword ---
|
||||
turn2_msg = "What was the special keyword I asked you to remember?"
|
||||
turn2_text = ""
|
||||
turn2_errors: list[str] = []
|
||||
|
||||
async for chunk in stream_chat_completion_baseline(
|
||||
session.session_id,
|
||||
turn2_msg,
|
||||
user_id=test_user_id,
|
||||
session=session,
|
||||
):
|
||||
if isinstance(chunk, StreamTextDelta):
|
||||
turn2_text += chunk.delta
|
||||
elif isinstance(chunk, StreamError):
|
||||
turn2_errors.append(chunk.errorText)
|
||||
|
||||
assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
|
||||
assert turn2_text, "Turn 2 produced no text"
|
||||
assert keyword in turn2_text, (
|
||||
f"Model did not recall keyword '{keyword}' in turn 2. "
|
||||
f"Response: {turn2_text[:200]}"
|
||||
)
|
||||
logger.info(f"Turn 2 recalled keyword successfully: {turn2_text[:100]}")
|
||||
@@ -26,6 +26,11 @@ class ChatConfig(BaseSettings):
|
||||
# Session TTL Configuration - 12 hours
|
||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||
|
||||
# Streaming Configuration
|
||||
max_retries: int = Field(
|
||||
default=3,
|
||||
description="Max retries for fallback path (SDK handles retries internally)",
|
||||
)
|
||||
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
||||
max_agent_schedules: int = Field(
|
||||
default=30, description="Maximum number of agent schedules"
|
||||
@@ -62,15 +67,11 @@ class ChatConfig(BaseSettings):
|
||||
default="CoPilot Prompt",
|
||||
description="Name of the prompt in Langfuse to fetch",
|
||||
)
|
||||
langfuse_prompt_cache_ttl: int = Field(
|
||||
default=300,
|
||||
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
|
||||
)
|
||||
|
||||
# Claude Agent SDK Configuration
|
||||
use_claude_agent_sdk: bool = Field(
|
||||
default=True,
|
||||
description="Use Claude Agent SDK (True) or OpenAI-compatible LLM baseline (False)",
|
||||
description="Use Claude Agent SDK for chat completions",
|
||||
)
|
||||
claude_agent_model: str | None = Field(
|
||||
default=None,
|
||||
@@ -91,53 +92,18 @@ class ChatConfig(BaseSettings):
|
||||
description="Use --resume for multi-turn conversations instead of "
|
||||
"history compression. Falls back to compression when unavailable.",
|
||||
)
|
||||
use_claude_code_subscription: bool = Field(
|
||||
default=False,
|
||||
description="For personal/dev use: use Claude Code CLI subscription auth instead of API keys. Requires `claude login` on the host. Only works with SDK mode.",
|
||||
)
|
||||
|
||||
# E2B Sandbox Configuration
|
||||
use_e2b_sandbox: bool = Field(
|
||||
# Extended thinking configuration for Claude models
|
||||
thinking_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Use E2B cloud sandboxes for persistent bash/python execution. "
|
||||
"When enabled, bash_exec routes commands to E2B and SDK file tools "
|
||||
"operate directly on the sandbox via E2B's filesystem API.",
|
||||
description="Enable adaptive thinking for Claude models via OpenRouter",
|
||||
)
|
||||
e2b_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="E2B API key. Falls back to E2B_API_KEY environment variable.",
|
||||
)
|
||||
e2b_sandbox_template: str = Field(
|
||||
default="base",
|
||||
description="E2B sandbox template to use for copilot sessions.",
|
||||
)
|
||||
e2b_sandbox_timeout: int = Field(
|
||||
default=43200, # 12 hours — same as session_ttl
|
||||
description="E2B sandbox keepalive timeout in seconds.",
|
||||
)
|
||||
|
||||
@field_validator("use_e2b_sandbox", mode="before")
|
||||
@classmethod
|
||||
def get_use_e2b_sandbox(cls, v):
|
||||
"""Get use_e2b_sandbox from environment if not provided."""
|
||||
env_val = os.getenv("CHAT_USE_E2B_SANDBOX", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
return True if v is None else v
|
||||
|
||||
@field_validator("e2b_api_key", mode="before")
|
||||
@classmethod
|
||||
def get_e2b_api_key(cls, v):
|
||||
"""Get E2B API key from environment if not provided."""
|
||||
if not v:
|
||||
v = os.getenv("CHAT_E2B_API_KEY") or os.getenv("E2B_API_KEY")
|
||||
return v
|
||||
|
||||
@field_validator("api_key", mode="before")
|
||||
@classmethod
|
||||
def get_api_key(cls, v):
|
||||
"""Get API key from environment if not provided."""
|
||||
if not v:
|
||||
if v is None:
|
||||
# Try to get from environment variables
|
||||
# First check for CHAT_API_KEY (Pydantic prefix)
|
||||
v = os.getenv("CHAT_API_KEY")
|
||||
@@ -147,16 +113,13 @@ class ChatConfig(BaseSettings):
|
||||
if not v:
|
||||
# Fall back to OPENAI_API_KEY
|
||||
v = os.getenv("OPENAI_API_KEY")
|
||||
# Note: ANTHROPIC_API_KEY is intentionally NOT included here.
|
||||
# The SDK CLI picks it up from the env directly. Including it
|
||||
# would pair it with the OpenRouter base_url, causing auth failures.
|
||||
return v
|
||||
|
||||
@field_validator("base_url", mode="before")
|
||||
@classmethod
|
||||
def get_base_url(cls, v):
|
||||
"""Get base URL from environment if not provided."""
|
||||
if not v:
|
||||
if v is None:
|
||||
# Check for OpenRouter or custom base URL
|
||||
v = os.getenv("CHAT_BASE_URL")
|
||||
if not v:
|
||||
@@ -178,15 +141,6 @@ class ChatConfig(BaseSettings):
|
||||
# Default to True (SDK enabled by default)
|
||||
return True if v is None else v
|
||||
|
||||
@field_validator("use_claude_code_subscription", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_code_subscription(cls, v):
|
||||
"""Get use_claude_code_subscription from environment if not provided."""
|
||||
env_val = os.getenv("CHAT_USE_CLAUDE_CODE_SUBSCRIPTION", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
return False if v is None else v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
"""Shared constants for the CoPilot module."""
|
||||
|
||||
# Special message prefixes for text-based markers (parsed by frontend).
|
||||
# The hex suffix makes accidental LLM generation of these strings virtually
|
||||
# impossible, avoiding false-positive marker detection in normal conversation.
|
||||
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
|
||||
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
|
||||
|
||||
# Compaction notice messages shown to users.
|
||||
COMPACTION_DONE_MSG = "Earlier messages were summarized to fit within context limits."
|
||||
COMPACTION_TOOL_NAME = "context_compaction"
|
||||
@@ -16,7 +16,7 @@ from prisma.types import (
|
||||
)
|
||||
|
||||
from backend.data import db
|
||||
from backend.util.json import SafeJson, sanitize_string
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
from .model import ChatMessage, ChatSession, ChatSessionInfo
|
||||
|
||||
@@ -81,35 +81,6 @@ async def update_chat_session(
|
||||
return ChatSession.from_db(session) if session else None
|
||||
|
||||
|
||||
async def update_chat_session_title(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
title: str,
|
||||
*,
|
||||
only_if_empty: bool = False,
|
||||
) -> bool:
|
||||
"""Update the title of a chat session, scoped to the owning user.
|
||||
|
||||
Always filters by (session_id, user_id) so callers cannot mutate another
|
||||
user's session even when they know the session_id.
|
||||
|
||||
Args:
|
||||
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
|
||||
guard so auto-generated titles never overwrite a user-set title.
|
||||
|
||||
Returns True if a row was updated, False otherwise (session not found,
|
||||
wrong user, or — when only_if_empty — title was already set).
|
||||
"""
|
||||
where: ChatSessionWhereInput = {"id": session_id, "userId": user_id}
|
||||
if only_if_empty:
|
||||
where["title"] = None
|
||||
result = await PrismaChatSession.prisma().update_many(
|
||||
where=where,
|
||||
data={"title": title, "updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
return result > 0
|
||||
|
||||
|
||||
async def add_chat_message(
|
||||
session_id: str,
|
||||
role: str,
|
||||
@@ -130,16 +101,15 @@ async def add_chat_message(
|
||||
"sequence": sequence,
|
||||
}
|
||||
|
||||
# Add optional string fields — sanitize to strip PostgreSQL-incompatible
|
||||
# control characters (null bytes etc.) that may appear in tool outputs.
|
||||
# Add optional string fields
|
||||
if content is not None:
|
||||
data["content"] = sanitize_string(content)
|
||||
data["content"] = content
|
||||
if name is not None:
|
||||
data["name"] = name
|
||||
if tool_call_id is not None:
|
||||
data["toolCallId"] = tool_call_id
|
||||
if refusal is not None:
|
||||
data["refusal"] = sanitize_string(refusal)
|
||||
data["refusal"] = refusal
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if tool_calls is not None:
|
||||
@@ -200,16 +170,15 @@ async def add_chat_messages_batch(
|
||||
"createdAt": now,
|
||||
}
|
||||
|
||||
# Add optional string fields — sanitize to strip
|
||||
# PostgreSQL-incompatible control characters.
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = sanitize_string(msg["content"])
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = sanitize_string(msg["refusal"])
|
||||
data["refusal"] = msg["refusal"]
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
@@ -343,7 +312,7 @@ async def update_tool_message_content(
|
||||
"toolCallId": tool_call_id,
|
||||
},
|
||||
data={
|
||||
"content": sanitize_string(new_content),
|
||||
"content": new_content,
|
||||
},
|
||||
)
|
||||
if result == 0:
|
||||
|
||||
@@ -6,13 +6,11 @@ in a thread-local context, following the graph executor pattern.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
|
||||
from backend.copilot import service as copilot_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.baseline import stream_chat_completion_baseline
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.response_model import StreamFinish
|
||||
from backend.copilot.sdk import service as sdk_service
|
||||
@@ -110,41 +108,8 @@ class CoPilotProcessor:
|
||||
)
|
||||
self.execution_thread.start()
|
||||
|
||||
# Skip the SDK's per-request CLI version check — the bundled CLI is
|
||||
# already version-matched to the SDK package.
|
||||
os.environ.setdefault("CLAUDE_AGENT_SDK_SKIP_VERSION_CHECK", "1")
|
||||
|
||||
# Pre-warm the bundled CLI binary so the OS page-caches the ~185 MB
|
||||
# executable. First spawn pays ~1.2 s; subsequent spawns ~0.65 s.
|
||||
self._prewarm_cli()
|
||||
|
||||
logger.info(f"[CoPilotExecutor] Worker {self.tid} started")
|
||||
|
||||
def _prewarm_cli(self) -> None:
|
||||
"""Run the bundled CLI binary once to warm OS page caches."""
|
||||
try:
|
||||
from claude_agent_sdk._internal.transport.subprocess_cli import (
|
||||
SubprocessCLITransport,
|
||||
)
|
||||
|
||||
cli_path = SubprocessCLITransport._find_bundled_cli(None) # type: ignore[arg-type]
|
||||
if cli_path:
|
||||
result = subprocess.run(
|
||||
[cli_path, "-v"],
|
||||
capture_output=True,
|
||||
timeout=10,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
logger.info(f"[CoPilotExecutor] CLI pre-warm done: {cli_path}")
|
||||
else:
|
||||
logger.warning(
|
||||
"[CoPilotExecutor] CLI pre-warm failed (rc=%d): %s",
|
||||
result.returncode, # type: ignore[reportCallIssue]
|
||||
cli_path,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"[CoPilotExecutor] CLI pre-warm skipped: {e}")
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up event-loop-bound resources before the loop is destroyed.
|
||||
|
||||
@@ -154,12 +119,12 @@ class CoPilotProcessor:
|
||||
"""
|
||||
from backend.util.workspace_storage import shutdown_workspace_storage
|
||||
|
||||
coro = shutdown_workspace_storage()
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self.execution_loop)
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
shutdown_workspace_storage(), self.execution_loop
|
||||
)
|
||||
future.result(timeout=5)
|
||||
except Exception as e:
|
||||
coro.close() # Prevent "coroutine was never awaited" warning
|
||||
error_msg = str(e) or type(e).__name__
|
||||
logger.warning(
|
||||
f"[CoPilotExecutor] Worker {self.tid} cleanup error: {error_msg}"
|
||||
@@ -229,7 +194,7 @@ class CoPilotProcessor:
|
||||
):
|
||||
"""Async execution logic for a CoPilot turn.
|
||||
|
||||
Calls the chat completion service (SDK or baseline) and publishes
|
||||
Calls the stream_chat_completion service function and publishes
|
||||
results to the stream registry.
|
||||
|
||||
Args:
|
||||
@@ -243,10 +208,9 @@ class CoPilotProcessor:
|
||||
error_msg = None
|
||||
|
||||
try:
|
||||
# Choose service based on LaunchDarkly flag.
|
||||
# Claude Code subscription forces SDK mode (CLI subprocess auth).
|
||||
# Choose service based on LaunchDarkly flag
|
||||
config = ChatConfig()
|
||||
use_sdk = config.use_claude_code_subscription or await is_feature_enabled(
|
||||
use_sdk = await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
@@ -254,9 +218,9 @@ class CoPilotProcessor:
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else stream_chat_completion_baseline
|
||||
else copilot_service.stream_chat_completion
|
||||
)
|
||||
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
|
||||
log.info(f"Using {'SDK' if use_sdk else 'standard'} service")
|
||||
|
||||
# Stream chat completion and publish chunks to Redis.
|
||||
async for chunk in stream_fn(
|
||||
@@ -265,7 +229,6 @@ class CoPilotProcessor:
|
||||
is_user_message=entry.is_user_message,
|
||||
user_id=entry.user_id,
|
||||
context=entry.context,
|
||||
file_ids=entry.file_ids,
|
||||
):
|
||||
if cancel.is_set():
|
||||
log.info("Cancel requested, breaking stream")
|
||||
|
||||
@@ -153,9 +153,6 @@ class CoPilotExecutionEntry(BaseModel):
|
||||
context: dict[str, str] | None = None
|
||||
"""Optional context for the message (e.g., {url: str, content: str})"""
|
||||
|
||||
file_ids: list[str] | None = None
|
||||
"""Workspace file IDs attached to the user's message"""
|
||||
|
||||
|
||||
class CancelCoPilotEvent(BaseModel):
|
||||
"""Event to cancel a CoPilot operation."""
|
||||
@@ -174,7 +171,6 @@ async def enqueue_copilot_turn(
|
||||
turn_id: str,
|
||||
is_user_message: bool = True,
|
||||
context: dict[str, str] | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Enqueue a CoPilot task for processing by the executor service.
|
||||
|
||||
@@ -185,7 +181,6 @@ async def enqueue_copilot_turn(
|
||||
turn_id: Per-turn UUID for Redis stream isolation
|
||||
is_user_message: Whether the message is from the user (vs system/assistant)
|
||||
context: Optional context for the message (e.g., {url: str, content: str})
|
||||
file_ids: Optional workspace file IDs attached to the user's message
|
||||
"""
|
||||
from backend.util.clients import get_async_copilot_queue
|
||||
|
||||
@@ -196,7 +191,6 @@ async def enqueue_copilot_turn(
|
||||
message=message,
|
||||
is_user_message=is_user_message,
|
||||
context=context,
|
||||
file_ids=file_ids,
|
||||
)
|
||||
|
||||
queue_client = await get_async_copilot_queue()
|
||||
|
||||
@@ -469,16 +469,8 @@ async def upsert_chat_session(
|
||||
)
|
||||
db_error = e
|
||||
|
||||
# Save to cache (best-effort, even if DB failed).
|
||||
# Title updates (update_session_title) run *outside* this lock because
|
||||
# they only touch the title field, not messages. So a concurrent rename
|
||||
# or auto-title may have written a newer title to Redis while this
|
||||
# upsert was in progress. Always prefer the cached title to avoid
|
||||
# overwriting it with the stale in-memory copy.
|
||||
# Save to cache (best-effort, even if DB failed)
|
||||
try:
|
||||
existing_cached = await _get_session_from_cache(session.session_id)
|
||||
if existing_cached and existing_cached.title:
|
||||
session = session.model_copy(update={"title": existing_cached.title})
|
||||
await cache_chat_session(session)
|
||||
except Exception as e:
|
||||
# If DB succeeded but cache failed, raise cache error
|
||||
@@ -680,47 +672,27 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
||||
async with _session_locks_mutex:
|
||||
_session_locks.pop(session_id, None)
|
||||
|
||||
# Shut down any local browser daemon for this session (best-effort).
|
||||
# Inline import required: all tool modules import ChatSession from this
|
||||
# module, so any top-level import from tools.* would create a cycle.
|
||||
try:
|
||||
from .tools.agent_browser import close_browser_session
|
||||
|
||||
await close_browser_session(session_id, user_id=user_id)
|
||||
except Exception as e:
|
||||
logger.debug(f"Browser cleanup for session {session_id}: {e}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def update_session_title(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
title: str,
|
||||
*,
|
||||
only_if_empty: bool = False,
|
||||
) -> bool:
|
||||
"""Update the title of a chat session, scoped to the owning user.
|
||||
async def update_session_title(session_id: str, title: str) -> bool:
|
||||
"""Update only the title of a chat session.
|
||||
|
||||
Lightweight operation that doesn't touch messages, avoiding race conditions
|
||||
with concurrent message updates.
|
||||
This is a lightweight operation that doesn't touch messages, avoiding
|
||||
race conditions with concurrent message updates. Use this for background
|
||||
title generation instead of upsert_chat_session.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to update.
|
||||
user_id: Owning user — the DB query filters on this.
|
||||
title: The new title to set.
|
||||
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
|
||||
so auto-generated titles never overwrite a user-set title.
|
||||
|
||||
Returns:
|
||||
True if updated successfully, False otherwise (not found, wrong user,
|
||||
or — when only_if_empty — title was already set).
|
||||
True if updated successfully, False otherwise.
|
||||
"""
|
||||
try:
|
||||
updated = await chat_db().update_chat_session_title(
|
||||
session_id, user_id, title, only_if_empty=only_if_empty
|
||||
)
|
||||
if not updated:
|
||||
result = await chat_db().update_chat_session(session_id=session_id, title=title)
|
||||
if result is None:
|
||||
logger.warning(f"Session {session_id} not found for title update")
|
||||
return False
|
||||
|
||||
# Update title in cache if it exists (instead of invalidating).
|
||||
@@ -732,8 +704,9 @@ async def update_session_title(
|
||||
cached.title = title
|
||||
await cache_chat_session(cached)
|
||||
except Exception as e:
|
||||
# Not critical - title will be correct on next full cache refresh
|
||||
logger.warning(
|
||||
f"Cache title update failed for session {session_id} (non-critical): {e}"
|
||||
f"Failed to update title in cache for session {session_id}: {e}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
422
autogpt_platform/backend/backend/copilot/otlp_trace.py
Normal file
422
autogpt_platform/backend/backend/copilot/otlp_trace.py
Normal file
@@ -0,0 +1,422 @@
|
||||
"""Lightweight OTLP JSON trace exporter for CoPilot LLM calls.
|
||||
|
||||
Sends spans to a remote OTLP-compatible endpoint (e.g. Product Intelligence)
|
||||
in the ExportTraceServiceRequest JSON format. Payload construction and the
|
||||
HTTP POST run in background asyncio tasks so streaming latency is unaffected.
|
||||
|
||||
Configuration (via backend.util.settings.Secrets):
|
||||
OTLP_TRACING_HOST – base URL of the trace ingestion service
|
||||
(e.g. "https://traces.example.com")
|
||||
OTLP_TRACING_TOKEN – optional Bearer token for authentication
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_settings = Settings()
|
||||
|
||||
# Resolve the endpoint once at import time.
|
||||
_TRACING_HOST = (_settings.secrets.otlp_tracing_host or "").rstrip("/")
|
||||
_TRACING_TOKEN = _settings.secrets.otlp_tracing_token.get_secret_value()
|
||||
_TRACING_ENABLED = bool(_TRACING_HOST)
|
||||
|
||||
# Shared async client — created lazily on first use.
|
||||
_client: httpx.AsyncClient | None = None
|
||||
|
||||
|
||||
def _get_client() -> httpx.AsyncClient:
    """Return the shared async HTTP client, creating it lazily on first use.

    The client is configured once with the JSON Content-Type header and, when
    a token is configured, a Bearer Authorization header, so every trace POST
    reuses a single connection pool.
    """
    global _client
    if _client is not None:
        return _client
    default_headers: dict[str, str] = {"Content-Type": "application/json"}
    if _TRACING_TOKEN:
        default_headers["Authorization"] = f"Bearer {_TRACING_TOKEN}"
    _client = httpx.AsyncClient(headers=default_headers, timeout=10.0)
    return _client
|
||||
|
||||
|
||||
def _nano(ts: float) -> str:
|
||||
"""Convert a ``time.time()`` float to nanosecond string for OTLP."""
|
||||
return str(int(ts * 1_000_000_000))
|
||||
|
||||
|
||||
def _kv(key: str, value: Any) -> dict | None:
|
||||
"""Build an OTLP KeyValue entry, returning None for missing values."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return {"key": key, "value": {"stringValue": value}}
|
||||
if isinstance(value, bool):
|
||||
return {"key": key, "value": {"stringValue": str(value).lower()}}
|
||||
if isinstance(value, int):
|
||||
return {"key": key, "value": {"intValue": str(value)}}
|
||||
if isinstance(value, float):
|
||||
return {"key": key, "value": {"doubleValue": value}}
|
||||
# Fallback: serialise as string
|
||||
return {"key": key, "value": {"stringValue": str(value)}}
|
||||
|
||||
|
||||
def _build_completion_text(
|
||||
assistant_content: str | None,
|
||||
tool_calls: list[dict[str, Any]] | None,
|
||||
) -> str | None:
|
||||
"""Build completion text that includes tool calls in the format
|
||||
the Product Intelligence system can parse: ``tool_name{json_args}``.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
if tool_calls:
|
||||
for tc in tool_calls:
|
||||
fn = tc.get("function", {})
|
||||
name = fn.get("name", "")
|
||||
args = fn.get("arguments", "{}")
|
||||
if name:
|
||||
parts.append(f"{name}{args}")
|
||||
if assistant_content:
|
||||
parts.append(assistant_content)
|
||||
return "\n".join(parts) if parts else None
|
||||
|
||||
|
||||
def _model_provider_slug(model: str) -> str:
|
||||
text = (model or "").strip().lower()
|
||||
if not text:
|
||||
return "unknown"
|
||||
return text.split("/", 1)[0]
|
||||
|
||||
|
||||
def _model_provider_name(slug: str) -> str:
|
||||
known = {
|
||||
"openai": "OpenAI",
|
||||
"anthropic": "Anthropic",
|
||||
"google": "Google",
|
||||
"meta": "Meta",
|
||||
"mistral": "Mistral",
|
||||
"deepseek": "DeepSeek",
|
||||
"x-ai": "xAI",
|
||||
"xai": "xAI",
|
||||
"qwen": "Qwen",
|
||||
"nvidia": "NVIDIA",
|
||||
"cohere": "Cohere",
|
||||
}
|
||||
return known.get(slug, slug)
|
||||
|
||||
|
||||
@dataclass
class TraceContext:
    """Accumulates trace data during LLM streaming for OTLP emission.

    Used by both SDK and non-SDK CoPilot paths to collect usage metrics,
    tool calls, and timing information in a consistent structure.
    """

    model: str = ""
    user_id: str | None = None
    session_id: str | None = None
    start_time: float = 0.0

    # Accumulated incrementally while the response streams.
    text_parts: list[str] = field(default_factory=list)
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    usage: dict[str, Any] = field(default_factory=dict)
    cost_usd: float | None = None

    def emit(
        self,
        *,
        finish_reason: str | None = None,
        messages: list[dict[str, Any]] | None = None,
    ) -> None:
        """Build and emit the trace as a fire-and-forget background task.

        When no explicit finish_reason is given, infers "tool_calls" if any
        tool calls were accumulated, else "stop". Usage keys are looked up
        under both the OpenRouter-style and Anthropic-style names.
        """
        usage = self.usage
        if finish_reason is None:
            finish_reason = "tool_calls" if self.tool_calls else "stop"
        emit_trace(
            model=self.model,
            messages=messages or [],
            assistant_content="".join(self.text_parts) or None,
            finish_reason=finish_reason,
            prompt_tokens=usage.get("prompt") or usage.get("input_tokens"),
            completion_tokens=usage.get("completion") or usage.get("output_tokens"),
            total_tokens=usage.get("total"),
            total_cost_usd=self.cost_usd,
            cache_creation_input_tokens=usage.get("cache_creation_input_tokens"),
            cache_read_input_tokens=(
                usage.get("cached") or usage.get("cache_read_input_tokens")
            ),
            reasoning_tokens=usage.get("reasoning"),
            user_id=self.user_id,
            session_id=self.session_id,
            tool_calls=self.tool_calls or None,
            start_time=self.start_time,
            end_time=time.time(),
        )
|
||||
|
||||
|
||||
def _build_otlp_payload(
    *,
    trace_id: str,
    model: str,
    messages: list[dict[str, Any]],
    assistant_content: str | None = None,
    finish_reason: str = "stop",
    prompt_tokens: int | None = None,
    completion_tokens: int | None = None,
    total_tokens: int | None = None,
    total_cost_usd: float | None = None,
    cache_creation_input_tokens: int | None = None,
    cache_read_input_tokens: int | None = None,
    reasoning_tokens: int | None = None,
    user_id: str | None = None,
    session_id: str | None = None,
    tool_calls: list[dict[str, Any]] | None = None,
    start_time: float | None = None,
    end_time: float | None = None,
) -> dict:
    """Build an ``ExportTraceServiceRequest`` JSON payload.

    Produces a single-span resource in the OTLP JSON encoding carrying
    gen_ai.* attributes plus OpenRouter-style trace metadata. Missing token
    counts and identifiers are simply omitted from the attribute list.
    """
    provider_slug = _model_provider_slug(model)
    provider_name = _model_provider_name(provider_slug)

    prompt_payload: str | None = (
        json.dumps({"messages": messages}, default=str) if messages else None
    )

    completion_payload: str | None = None
    completion_text = _build_completion_text(assistant_content, tool_calls)
    if completion_text is not None:
        completion_payload = json.dumps(
            {
                "completion": completion_text,
                "reasoning": None,
                "rawRequest": {
                    "model": model,
                    "stream": True,
                    "stream_options": {"include_usage": True},
                    "tool_choice": "auto",
                    "user": user_id,
                    "posthogDistinctId": user_id,
                    "session_id": session_id,
                },
            },
            default=str,
        )

    candidate_attrs = [
        _kv("trace.name", "OpenRouter Request"),
        _kv("span.type", "generation"),
        _kv("span.level", "DEFAULT"),
        _kv("gen_ai.operation.name", "chat"),
        _kv("gen_ai.system", provider_slug),
        _kv("gen_ai.provider.name", provider_slug),
        _kv("gen_ai.request.model", model),
        _kv("gen_ai.response.model", model),
        _kv("gen_ai.response.finish_reason", finish_reason),
        _kv("gen_ai.response.finish_reasons", json.dumps([finish_reason])),
        _kv("gen_ai.usage.input_tokens", prompt_tokens),
        _kv("gen_ai.usage.output_tokens", completion_tokens),
        _kv("gen_ai.usage.total_tokens", total_tokens),
        _kv("gen_ai.usage.input_tokens.cached", cache_read_input_tokens),
        _kv(
            "gen_ai.usage.input_tokens.cache_creation",
            cache_creation_input_tokens,
        ),
        _kv("gen_ai.usage.output_tokens.reasoning", reasoning_tokens),
        _kv("user.id", user_id),
        _kv("session.id", session_id),
        _kv("trace.metadata.openrouter.source", "openrouter"),
        _kv("trace.metadata.openrouter.user_id", user_id),
        _kv("gen_ai.usage.total_cost", total_cost_usd),
        _kv("trace.metadata.openrouter.provider_name", provider_name),
        _kv("trace.metadata.openrouter.provider_slug", provider_slug),
        _kv("trace.metadata.openrouter.finish_reason", finish_reason),
    ]
    # _kv returns None for absent values; keep only populated attributes.
    attrs: list[dict] = [kv for kv in candidate_attrs if kv is not None]

    # The ingestion side reads prompt/completion under several attribute
    # names; mirror each payload to all of them.
    if prompt_payload is not None:
        for attr_key in ("trace.input", "span.input", "gen_ai.prompt"):
            attrs.append({"key": attr_key, "value": {"stringValue": prompt_payload}})
    if completion_payload is not None:
        for attr_key in ("trace.output", "span.output", "gen_ai.completion"):
            attrs.append(
                {"key": attr_key, "value": {"stringValue": completion_payload}}
            )

    span = {
        "traceId": trace_id,
        "startTimeUnixNano": _nano(start_time or time.time()),
        "endTimeUnixNano": _nano(end_time or time.time()),
        "attributes": attrs,
    }

    return {
        "resourceSpans": [
            {
                "resource": {
                    "attributes": [
                        {
                            "key": "service.name",
                            "value": {"stringValue": "openrouter"},
                        },
                        {
                            "key": "openrouter.trace.id",
                            "value": {
                                "stringValue": (
                                    f"gen-{int(end_time or time.time())}"
                                    f"-{trace_id[:20]}"
                                )
                            },
                        },
                    ]
                },
                "scopeSpans": [{"spans": [span]}],
            }
        ]
    }
|
||||
|
||||
|
||||
async def _send_trace(payload: dict) -> None:
    """POST the OTLP payload to the configured tracing host.

    Failures are logged and swallowed — tracing must never break the caller.
    """
    endpoint = f"{_TRACING_HOST}/v1/traces"
    try:
        response = await _get_client().post(endpoint, json=payload)
        if response.status_code < 400:
            logger.debug("[OTLP] Trace sent successfully (%d)", response.status_code)
        else:
            logger.debug(
                "[OTLP] Trace POST returned %d: %s",
                response.status_code,
                response.text[:200],
            )
    except Exception as e:
        logger.warning("[OTLP] Failed to send trace: %s", e)
|
||||
|
||||
|
||||
# Background task set with backpressure cap.
|
||||
_bg_tasks: set[asyncio.Task[Any]] = set()
|
||||
_MAX_BG_TASKS = 64
|
||||
|
||||
|
||||
async def _build_and_send_trace(
    *,
    model: str,
    messages: list[dict[str, Any]],
    assistant_content: str | None,
    finish_reason: str,
    prompt_tokens: int | None,
    completion_tokens: int | None,
    total_tokens: int | None,
    total_cost_usd: float | None,
    cache_creation_input_tokens: int | None,
    cache_read_input_tokens: int | None,
    reasoning_tokens: int | None,
    user_id: str | None,
    session_id: str | None,
    tool_calls: list[dict[str, Any]] | None,
    start_time: float | None,
    end_time: float | None,
) -> None:
    """Build the OTLP payload and send it — runs entirely in a background task.

    A fresh random trace id is generated per call so every emission is an
    independent single-span trace.
    """
    payload = _build_otlp_payload(
        trace_id=uuid.uuid4().hex,
        model=model,
        messages=messages,
        assistant_content=assistant_content,
        finish_reason=finish_reason,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=total_tokens,
        total_cost_usd=total_cost_usd,
        cache_creation_input_tokens=cache_creation_input_tokens,
        cache_read_input_tokens=cache_read_input_tokens,
        reasoning_tokens=reasoning_tokens,
        user_id=user_id,
        session_id=session_id,
        tool_calls=tool_calls,
        start_time=start_time,
        end_time=end_time,
    )
    await _send_trace(payload)
|
||||
|
||||
|
||||
def emit_trace(
    *,
    model: str,
    messages: list[dict[str, Any]],
    assistant_content: str | None = None,
    finish_reason: str = "stop",
    prompt_tokens: int | None = None,
    completion_tokens: int | None = None,
    total_tokens: int | None = None,
    total_cost_usd: float | None = None,
    cache_creation_input_tokens: int | None = None,
    cache_read_input_tokens: int | None = None,
    reasoning_tokens: int | None = None,
    user_id: str | None = None,
    session_id: str | None = None,
    tool_calls: list[dict[str, Any]] | None = None,
    start_time: float | None = None,
    end_time: float | None = None,
) -> None:
    """Fire-and-forget: build and send an OTLP trace span.

    Safe to call from async context — both payload serialization and the
    HTTP POST run in a background task so they never block the event loop.
    The call is a no-op when tracing is disabled, and the trace is dropped
    (with a log) when the backpressure cap is hit or when no event loop is
    running (sync caller).
    """
    if not _TRACING_ENABLED:
        return

    # Backpressure: cap in-flight exporter tasks rather than queueing
    # unboundedly when the trace endpoint is slow or down.
    if len(_bg_tasks) >= _MAX_BG_TASKS:
        logger.warning(
            "[OTLP] Backpressure: dropping trace (%d tasks queued)",
            len(_bg_tasks),
        )
        return

    coro = _build_and_send_trace(
        model=model,
        messages=messages,
        assistant_content=assistant_content,
        finish_reason=finish_reason,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=total_tokens,
        total_cost_usd=total_cost_usd,
        cache_creation_input_tokens=cache_creation_input_tokens,
        cache_read_input_tokens=cache_read_input_tokens,
        reasoning_tokens=reasoning_tokens,
        user_id=user_id,
        session_id=session_id,
        tool_calls=tool_calls,
        start_time=start_time,
        end_time=end_time,
    )
    try:
        task = asyncio.create_task(coro)
    except RuntimeError:
        # No running event loop — tracing is strictly best-effort, so drop
        # the trace instead of letting the RuntimeError reach the caller.
        coro.close()  # prevent "coroutine was never awaited" warning
        logger.debug("[OTLP] No running event loop; trace dropped")
        return
    # Hold a strong reference until completion so the task isn't GC'd
    # mid-flight; the done-callback removes it from the set.
    _bg_tasks.add(task)
    task.add_done_callback(_bg_tasks.discard)
|
||||
@@ -0,0 +1,269 @@
|
||||
"""Tests for parallel tool call execution in CoPilot.
|
||||
|
||||
These tests mock _yield_tool_call to avoid importing the full copilot stack
|
||||
which requires Prisma, DB connections, etc.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_parallel_tool_calls_run_concurrently():
    """Multiple tool calls should complete in ~max(delays), not sum(delays)."""
    from backend.copilot.response_model import (
        StreamToolInputAvailable,
        StreamToolOutputAvailable,
    )
    from backend.copilot.service import _execute_tool_calls_parallel

    tool_count = 3
    per_tool_delay = 0.2
    calls = [
        {
            "id": f"call_{i}",
            "type": "function",
            "function": {"name": f"tool_{i}", "arguments": "{}"},
        }
        for i in range(tool_count)
    ]

    class FakeSession:
        session_id = "test"
        user_id = "test"

        def __init__(self):
            self.messages = []

    async def stub_yield(tc_list, idx, sess):
        # Emit input, sleep to simulate work, then emit output.
        call = tc_list[idx]
        yield StreamToolInputAvailable(
            toolCallId=call["id"],
            toolName=call["function"]["name"],
            input={},
        )
        await asyncio.sleep(per_tool_delay)
        yield StreamToolOutputAvailable(
            toolCallId=call["id"],
            toolName=call["function"]["name"],
            output="{}",
        )

    import backend.copilot.service as svc

    saved_yield = svc._yield_tool_call
    svc._yield_tool_call = stub_yield
    try:
        started = time.monotonic()
        collected = []
        async for event in _execute_tool_calls_parallel(
            calls, cast(Any, FakeSession())
        ):
            collected.append(event)
        elapsed = time.monotonic() - started
    finally:
        svc._yield_tool_call = saved_yield

    assert len(collected) == tool_count * 2
    # Parallel: should take ~delay, not ~n*delay
    assert elapsed < per_tool_delay * (
        tool_count - 0.5
    ), f"Took {elapsed:.2f}s, expected parallel (~{per_tool_delay}s)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_single_tool_call_works():
    """A lone tool call still yields its input/output event pair."""
    from backend.copilot.response_model import (
        StreamToolInputAvailable,
        StreamToolOutputAvailable,
    )
    from backend.copilot.service import _execute_tool_calls_parallel

    # One entry, in OpenAI tool-call shape.
    tool_calls = [
        {
            "id": "call_0",
            "type": "function",
            "function": {"name": "t", "arguments": "{}"},
        }
    ]

    class FakeSession:
        session_id = "test"
        user_id = "test"

        def __init__(self):
            self.messages = []

    async def fake_yield(tc_list, idx, sess):
        # Minimal two-event lifecycle for the single call.
        yield StreamToolInputAvailable(toolCallId="call_0", toolName="t", input={})
        yield StreamToolOutputAvailable(toolCallId="call_0", toolName="t", output="{}")

    import backend.copilot.service as svc

    saved_yield = svc._yield_tool_call
    svc._yield_tool_call = fake_yield
    try:
        collected = []
        async for evt in _execute_tool_calls_parallel(
            tool_calls, cast(Any, FakeSession())
        ):
            collected.append(evt)
    finally:
        svc._yield_tool_call = saved_yield

    assert len(collected) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_retryable_error_propagates():
    """Retryable errors should be raised after all tools finish."""
    from backend.copilot.response_model import StreamToolOutputAvailable
    from backend.copilot.service import _execute_tool_calls_parallel

    tool_calls = [
        {
            "id": f"call_{i}",
            "type": "function",
            "function": {"name": f"t_{i}", "arguments": "{}"},
        }
        for i in range(2)
    ]

    class FakeSession:
        # Minimal ChatSession stand-in.
        session_id = "test"
        user_id = "test"

        def __init__(self):
            self.messages = []

    async def fake_yield(tc_list, idx, sess):
        # Second tool fails; because this is an async generator the
        # KeyError surfaces when it is first iterated, not when called.
        if idx == 1:
            raise KeyError("bad")
        from backend.copilot.response_model import StreamToolInputAvailable

        yield StreamToolInputAvailable(
            toolCallId=tc_list[idx]["id"], toolName="t_0", input={}
        )
        await asyncio.sleep(0.05)
        yield StreamToolOutputAvailable(
            toolCallId=tc_list[idx]["id"], toolName="t_0", output="{}"
        )

    import backend.copilot.service as svc

    orig = svc._yield_tool_call
    svc._yield_tool_call = fake_yield
    try:
        events = []
        with pytest.raises(KeyError):
            async for event in _execute_tool_calls_parallel(
                tool_calls, cast(Any, FakeSession())
            ):
                events.append(event)
        # First tool's events should still be yielded
        assert any(isinstance(e, StreamToolOutputAvailable) for e in events)
    finally:
        svc._yield_tool_call = orig
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_session_shared_across_parallel_tools():
    """All parallel tools must see the very same session object."""
    from backend.copilot.response_model import (
        StreamToolInputAvailable,
        StreamToolOutputAvailable,
    )
    from backend.copilot.service import _execute_tool_calls_parallel

    tool_calls = [
        {
            "id": f"call_{i}",
            "type": "function",
            "function": {"name": f"t_{i}", "arguments": "{}"},
        }
        for i in range(3)
    ]

    class FakeSession:
        session_id = "test"
        user_id = "test"

        def __init__(self):
            self.messages = []

    seen = []

    async def fake_yield(tc_list, idx, sess):
        # Capture the session handed to each tool for the identity check.
        seen.append(sess)
        yield StreamToolInputAvailable(
            toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
        )
        yield StreamToolOutputAvailable(
            toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", output="{}"
        )

    import backend.copilot.service as svc

    saved = svc._yield_tool_call
    svc._yield_tool_call = fake_yield
    try:
        # Drain the stream; we only care about the captured sessions.
        async for _ in _execute_tool_calls_parallel(
            tool_calls, cast(Any, FakeSession())
        ):
            pass
    finally:
        svc._yield_tool_call = saved

    assert len(seen) == 3
    assert seen[0] is seen[1] is seen[2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cancellation_cleans_up():
    """Generator close should cancel in-flight tasks."""
    from backend.copilot.response_model import StreamToolInputAvailable
    from backend.copilot.service import _execute_tool_calls_parallel

    tool_calls = [
        {
            "id": f"call_{i}",
            "type": "function",
            "function": {"name": f"t_{i}", "arguments": "{}"},
        }
        for i in range(2)
    ]

    class FakeSession:
        # Minimal ChatSession stand-in.
        session_id = "test"
        user_id = "test"

        def __init__(self):
            self.messages = []

    # Signals that at least one fake tool has started its long sleep.
    started = asyncio.Event()

    async def fake_yield(tc_list, idx, sess):
        yield StreamToolInputAvailable(
            toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
        )
        started.set()
        await asyncio.sleep(10)  # simulate long-running

    import backend.copilot.service as svc

    orig = svc._yield_tool_call
    svc._yield_tool_call = fake_yield
    try:
        gen = _execute_tool_calls_parallel(tool_calls, cast(Any, FakeSession()))
        await gen.__anext__()  # get first event
        await started.wait()
        # aclose() must cancel the still-sleeping tool tasks promptly;
        # if it doesn't, this test hangs (far short of the 10s sleeps).
        await gen.aclose()  # close generator
    finally:
        svc._yield_tool_call = orig
    # If we get here without hanging, cleanup worked
|
||||
@@ -13,7 +13,6 @@ from typing import Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.util.json import dumps as json_dumps
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -151,9 +150,6 @@ class StreamToolInputAvailable(StreamBaseResponse):
|
||||
)
|
||||
|
||||
|
||||
_MAX_TOOL_OUTPUT_SIZE = 100_000 # ~100 KB; truncate to avoid bloating SSE/DB
|
||||
|
||||
|
||||
class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
"""Tool execution result."""
|
||||
|
||||
@@ -168,10 +164,6 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
default=True, description="Whether the tool execution succeeded"
|
||||
)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Truncate oversized outputs after construction."""
|
||||
self.output = truncate(self.output, _MAX_TOOL_OUTPUT_SIZE)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format, excluding non-spec fields."""
|
||||
data = {
|
||||
|
||||
@@ -1,239 +0,0 @@
|
||||
"""Compaction tracking for SDK-based chat sessions.
|
||||
|
||||
Encapsulates the state machine and event emission for context compaction,
|
||||
both pre-query (history compressed before SDK query) and SDK-internal
|
||||
(PreCompact hook fires mid-stream).
|
||||
|
||||
All compaction-related helpers live here: event builders, message filtering,
|
||||
persistence, and the ``CompactionTracker`` state machine.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
|
||||
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
|
||||
from ..model import ChatMessage, ChatSession
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamFinishStep,
|
||||
StreamStartStep,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Event builders (private — use CompactionTracker or compaction_events)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _start_events(tool_call_id: str) -> list[StreamBaseResponse]:
    """Return the three events that open a compaction tool call in the UI."""
    step_open = StreamStartStep()
    input_start = StreamToolInputStart(
        toolCallId=tool_call_id, toolName=COMPACTION_TOOL_NAME
    )
    input_ready = StreamToolInputAvailable(
        toolCallId=tool_call_id, toolName=COMPACTION_TOOL_NAME, input={}
    )
    return [step_open, input_start, input_ready]
|
||||
|
||||
|
||||
def _end_events(tool_call_id: str, message: str) -> list[StreamBaseResponse]:
    """Return the two events that close a compaction tool call in the UI."""
    output_event = StreamToolOutputAvailable(
        toolCallId=tool_call_id,
        toolName=COMPACTION_TOOL_NAME,
        output=message,
    )
    return [output_event, StreamFinishStep()]
|
||||
|
||||
|
||||
def _new_tool_call_id() -> str:
|
||||
return f"compaction-{uuid.uuid4().hex[:12]}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public event builder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def emit_compaction(session: ChatSession) -> list[StreamBaseResponse]:
    """Persist and return a complete (already-finished) compaction tool call.

    Convenience wrapper for callers that do not drive a
    ``CompactionTracker`` (e.g. the legacy non-SDK streaming path in
    ``service.py``).
    """
    call_id = _new_tool_call_id()
    _persist(session, call_id, COMPACTION_DONE_MSG)
    return compaction_events(COMPACTION_DONE_MSG, tool_call_id=call_id)
|
||||
|
||||
|
||||
def compaction_events(
    message: str, tool_call_id: str | None = None
) -> list[StreamBaseResponse]:
    """Build a self-contained, already-completed compaction tool call.

    Reuses *tool_call_id* when given (so persistence can match an
    already-streamed start event); otherwise a fresh ID is generated.
    """
    if not tool_call_id:
        tool_call_id = _new_tool_call_id()
    events: list[StreamBaseResponse] = []
    events.extend(_start_events(tool_call_id))
    events.extend(_end_events(tool_call_id, message))
    return events
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message filtering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def filter_compaction_messages(
    messages: list[ChatMessage],
) -> list[ChatMessage]:
    """Remove synthetic compaction tool-call messages (UI-only artifacts).

    Strips assistant messages whose only tool calls are compaction calls,
    and their corresponding tool-result messages.
    """
    # IDs of compaction tool calls seen so far; tool-result messages are
    # matched against this set. Relies on the assistant message preceding
    # its tool results in *messages*.
    compaction_ids: set[str] = set()
    filtered: list[ChatMessage] = []
    for msg in messages:
        if msg.role == "assistant" and msg.tool_calls:
            for tc in msg.tool_calls:
                if tc.get("function", {}).get("name") == COMPACTION_TOOL_NAME:
                    compaction_ids.add(tc.get("id", ""))
            # "Real" = non-compaction tool calls on this message.
            real_calls = [
                tc
                for tc in msg.tool_calls
                if tc.get("function", {}).get("name") != COMPACTION_TOOL_NAME
            ]
            # Drop the message only when it carries nothing but compaction
            # calls; kept messages retain their original tool_calls list
            # (compaction entries included) unmodified.
            if not real_calls and not msg.content:
                continue
        if msg.role == "tool" and msg.tool_call_id in compaction_ids:
            continue
        filtered.append(msg)
    return filtered
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _persist(session: ChatSession, tool_call_id: str, message: str) -> None:
    """Record a compaction tool call + its result on the session.

    Compaction events are synthetic (they bypass the normal adapter
    accumulation), so both halves are appended explicitly to survive a
    page refresh.
    """
    call_record = {
        "id": tool_call_id,
        "type": "function",
        "function": {
            "name": COMPACTION_TOOL_NAME,
            "arguments": "{}",
        },
    }
    assistant_msg = ChatMessage(
        role="assistant",
        content="",
        tool_calls=[call_record],
    )
    result_msg = ChatMessage(role="tool", content=message, tool_call_id=tool_call_id)
    session.messages.append(assistant_msg)
    session.messages.append(result_msg)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CompactionTracker — state machine for streaming sessions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CompactionTracker:
    """Tracks compaction state and yields UI events.

    Two compaction paths:

    1. **Pre-query** — history compressed before the SDK query starts.
       Call :meth:`emit_pre_query` to yield a self-contained tool call.

    2. **SDK-internal** — ``PreCompact`` hook fires mid-stream.
       Call :meth:`emit_start_if_ready` on heartbeat ticks and
       :meth:`emit_end_if_ready` when a message arrives.
    """

    def __init__(self) -> None:
        # Set by the PreCompact hook (possibly from another task).
        self._compact_start = asyncio.Event()
        # True once start events for the current compaction were streamed.
        self._start_emitted = False
        # True once a compaction was fully emitted for this query.
        self._done = False
        # Tool-call ID shared between the start and end events.
        self._tool_call_id = ""

    @property
    def on_compact(self) -> Callable[[], None]:
        """Callback for the PreCompact hook."""
        return self._compact_start.set

    # ------------------------------------------------------------------
    # Pre-query compaction
    # ------------------------------------------------------------------

    def emit_pre_query(self, session: ChatSession) -> list[StreamBaseResponse]:
        """Emit + persist a self-contained compaction tool call."""
        # Marking done suppresses any SDK-internal compaction this query.
        self._done = True
        return emit_compaction(session)

    # ------------------------------------------------------------------
    # SDK-internal compaction
    # ------------------------------------------------------------------

    def reset_for_query(self) -> None:
        """Reset per-query state before a new SDK query."""
        self._done = False
        self._start_emitted = False
        self._tool_call_id = ""

    def emit_start_if_ready(self) -> list[StreamBaseResponse]:
        """If the PreCompact hook fired, emit start events (spinning tool)."""
        if self._compact_start.is_set() and not self._start_emitted and not self._done:
            self._compact_start.clear()
            self._start_emitted = True
            self._tool_call_id = _new_tool_call_id()
            return _start_events(self._tool_call_id)
        return []

    async def emit_end_if_ready(self, session: ChatSession) -> list[StreamBaseResponse]:
        """If compaction is in progress, emit end events and persist."""
        # Yield so pending hook tasks can set compact_start
        await asyncio.sleep(0)

        if self._done:
            return []
        if not self._start_emitted and not self._compact_start.is_set():
            return []

        if self._start_emitted:
            # Close the open spinner
            done_events = _end_events(self._tool_call_id, COMPACTION_DONE_MSG)
            persist_id = self._tool_call_id
        else:
            # PreCompact fired but start never emitted — self-contained
            persist_id = _new_tool_call_id()
            done_events = compaction_events(
                COMPACTION_DONE_MSG, tool_call_id=persist_id
            )

        # Reset the trigger and latch done so this fires at most once
        # per query; the persisted ID matches whatever was streamed.
        self._compact_start.clear()
        self._start_emitted = False
        self._done = True
        _persist(session, persist_id, COMPACTION_DONE_MSG)
        return done_events
|
||||
@@ -1,291 +0,0 @@
|
||||
"""Tests for sdk/compaction.py — event builders, filtering, persistence, and
|
||||
CompactionTracker state machine."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.response_model import (
|
||||
StreamFinishStep,
|
||||
StreamStartStep,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.sdk.compaction import (
|
||||
CompactionTracker,
|
||||
compaction_events,
|
||||
emit_compaction,
|
||||
filter_compaction_messages,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_session() -> ChatSession:
    """Fresh, empty session fixture for the tests below."""
    session = ChatSession.new(user_id="test-user")
    return session
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compaction_events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactionEvents:
    """compaction_events builds a 5-event, self-contained tool call."""

    def test_returns_start_and_end_events(self):
        # Order: start-step, input-start, input-available, output, finish.
        evts = compaction_events("done")
        assert len(evts) == 5
        assert isinstance(evts[0], StreamStartStep)
        assert isinstance(evts[1], StreamToolInputStart)
        assert isinstance(evts[2], StreamToolInputAvailable)
        assert isinstance(evts[3], StreamToolOutputAvailable)
        assert isinstance(evts[4], StreamFinishStep)

    def test_uses_provided_tool_call_id(self):
        evts = compaction_events("msg", tool_call_id="my-id")
        tool_start = evts[1]
        assert isinstance(tool_start, StreamToolInputStart)
        assert tool_start.toolCallId == "my-id"

    def test_generates_id_when_not_provided(self):
        evts = compaction_events("msg")
        tool_start = evts[1]
        assert isinstance(tool_start, StreamToolInputStart)
        assert tool_start.toolCallId.startswith("compaction-")

    def test_tool_name_is_context_compaction(self):
        evts = compaction_events("msg")
        tool_start = evts[1]
        assert isinstance(tool_start, StreamToolInputStart)
        assert tool_start.toolName == COMPACTION_TOOL_NAME
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# emit_compaction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitCompaction:
    """emit_compaction both streams events and persists to the session."""

    def test_persists_to_session(self):
        session = _make_session()
        assert len(session.messages) == 0
        evts = emit_compaction(session)
        assert len(evts) == 5
        # Should have appended 2 messages (assistant tool call + tool result)
        assert len(session.messages) == 2
        assert session.messages[0].role == "assistant"
        assert session.messages[0].tool_calls is not None
        assert (
            session.messages[0].tool_calls[0]["function"]["name"]
            == COMPACTION_TOOL_NAME
        )
        assert session.messages[1].role == "tool"
        assert session.messages[1].content == COMPACTION_DONE_MSG
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# filter_compaction_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFilterCompactionMessages:
    """filter_compaction_messages strips only synthetic compaction traffic."""

    def test_removes_compaction_tool_calls(self):
        # Compaction-only assistant message and its tool result are dropped.
        msgs = [
            ChatMessage(role="user", content="hello"),
            ChatMessage(
                role="assistant",
                content="",
                tool_calls=[
                    {
                        "id": "comp-1",
                        "type": "function",
                        "function": {"name": COMPACTION_TOOL_NAME, "arguments": "{}"},
                    }
                ],
            ),
            ChatMessage(
                role="tool", content=COMPACTION_DONE_MSG, tool_call_id="comp-1"
            ),
            ChatMessage(role="assistant", content="world"),
        ]
        filtered = filter_compaction_messages(msgs)
        assert len(filtered) == 2
        assert filtered[0].content == "hello"
        assert filtered[1].content == "world"

    def test_keeps_non_compaction_tool_calls(self):
        # Real tool calls (and their results) must pass through untouched.
        msgs = [
            ChatMessage(
                role="assistant",
                content="",
                tool_calls=[
                    {
                        "id": "real-1",
                        "type": "function",
                        "function": {"name": "search", "arguments": "{}"},
                    }
                ],
            ),
            ChatMessage(role="tool", content="result", tool_call_id="real-1"),
        ]
        filtered = filter_compaction_messages(msgs)
        assert len(filtered) == 2

    def test_keeps_assistant_with_content_and_compaction_call(self):
        """If assistant message has both content and a compaction tool call,
        the message is kept (has real content)."""
        msgs = [
            ChatMessage(
                role="assistant",
                content="I have content",
                tool_calls=[
                    {
                        "id": "comp-1",
                        "type": "function",
                        "function": {"name": COMPACTION_TOOL_NAME, "arguments": "{}"},
                    }
                ],
            ),
        ]
        filtered = filter_compaction_messages(msgs)
        assert len(filtered) == 1

    def test_empty_list(self):
        assert filter_compaction_messages([]) == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CompactionTracker
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactionTracker:
    """State-machine behavior of CompactionTracker across a query."""

    def test_on_compact_sets_event(self):
        tracker = CompactionTracker()
        tracker.on_compact()
        assert tracker._compact_start.is_set()

    def test_emit_start_if_ready_no_event(self):
        tracker = CompactionTracker()
        assert tracker.emit_start_if_ready() == []

    def test_emit_start_if_ready_with_event(self):
        tracker = CompactionTracker()
        tracker.on_compact()
        evts = tracker.emit_start_if_ready()
        assert len(evts) == 3
        assert isinstance(evts[0], StreamStartStep)
        assert isinstance(evts[1], StreamToolInputStart)
        assert isinstance(evts[2], StreamToolInputAvailable)

    def test_emit_start_only_once(self):
        tracker = CompactionTracker()
        tracker.on_compact()
        evts1 = tracker.emit_start_if_ready()
        assert len(evts1) == 3
        # Second call should return empty
        evts2 = tracker.emit_start_if_ready()
        assert evts2 == []

    @pytest.mark.asyncio
    async def test_emit_end_after_start(self):
        tracker = CompactionTracker()
        session = _make_session()
        tracker.on_compact()
        tracker.emit_start_if_ready()
        evts = await tracker.emit_end_if_ready(session)
        assert len(evts) == 2
        assert isinstance(evts[0], StreamToolOutputAvailable)
        assert isinstance(evts[1], StreamFinishStep)
        # Should persist
        assert len(session.messages) == 2

    @pytest.mark.asyncio
    async def test_emit_end_without_start_self_contained(self):
        """If PreCompact fired but start was never emitted, emit_end
        produces a self-contained compaction event."""
        tracker = CompactionTracker()
        session = _make_session()
        tracker.on_compact()
        # Don't call emit_start_if_ready
        evts = await tracker.emit_end_if_ready(session)
        assert len(evts) == 5  # Full self-contained event
        assert isinstance(evts[0], StreamStartStep)
        assert len(session.messages) == 2

    @pytest.mark.asyncio
    async def test_emit_end_no_op_when_done(self):
        tracker = CompactionTracker()
        session = _make_session()
        tracker.on_compact()
        tracker.emit_start_if_ready()
        await tracker.emit_end_if_ready(session)
        # Second call should be no-op
        evts = await tracker.emit_end_if_ready(session)
        assert evts == []

    @pytest.mark.asyncio
    async def test_emit_end_no_op_when_nothing_happened(self):
        tracker = CompactionTracker()
        session = _make_session()
        evts = await tracker.emit_end_if_ready(session)
        assert evts == []

    def test_emit_pre_query(self):
        tracker = CompactionTracker()
        session = _make_session()
        evts = tracker.emit_pre_query(session)
        assert len(evts) == 5
        assert len(session.messages) == 2
        assert tracker._done is True

    def test_reset_for_query(self):
        tracker = CompactionTracker()
        tracker._done = True
        tracker._start_emitted = True
        tracker._tool_call_id = "old"
        tracker.reset_for_query()
        assert tracker._done is False
        assert tracker._start_emitted is False
        assert tracker._tool_call_id == ""

    @pytest.mark.asyncio
    async def test_pre_query_blocks_sdk_compaction(self):
        """After pre-query compaction, SDK compaction events are suppressed."""
        tracker = CompactionTracker()
        session = _make_session()
        tracker.emit_pre_query(session)
        tracker.on_compact()
        evts = tracker.emit_start_if_ready()
        assert evts == []  # _done blocks it

    @pytest.mark.asyncio
    async def test_reset_allows_new_compaction(self):
        """After reset_for_query, compaction can fire again."""
        tracker = CompactionTracker()
        session = _make_session()
        tracker.emit_pre_query(session)
        tracker.reset_for_query()
        tracker.on_compact()
        evts = tracker.emit_start_if_ready()
        assert len(evts) == 3  # Start events emitted

    @pytest.mark.asyncio
    async def test_tool_call_id_consistency(self):
        """Start and end events use the same tool_call_id."""
        tracker = CompactionTracker()
        session = _make_session()
        tracker.on_compact()
        start_evts = tracker.emit_start_if_ready()
        end_evts = await tracker.emit_end_if_ready(session)
        start_evt = start_evts[1]
        end_evt = end_evts[0]
        assert isinstance(start_evt, StreamToolInputStart)
        assert isinstance(end_evt, StreamToolOutputAvailable)
        assert start_evt.toolCallId == end_evt.toolCallId
        # Persisted ID should also match
        tool_calls = session.messages[0].tool_calls
        assert tool_calls is not None
        assert tool_calls[0]["id"] == start_evt.toolCallId
|
||||
@@ -10,7 +10,6 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from ..model import ChatSession
|
||||
from ..response_model import StreamBaseResponse, StreamStart, StreamTextDelta
|
||||
@@ -27,7 +26,6 @@ async def stream_chat_completion_dummy(
|
||||
retry_count: int = 0,
|
||||
session: ChatSession | None = None,
|
||||
context: dict[str, str] | None = None,
|
||||
**_kwargs: Any,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Stream dummy chat completion for testing.
|
||||
|
||||
|
||||
@@ -1,362 +0,0 @@
|
||||
"""MCP file-tool handlers that route to the E2B cloud sandbox.
|
||||
|
||||
When E2B is active, these tools replace the SDK built-in Read/Write/Edit/
|
||||
Glob/Grep so that all file operations share the same ``/home/user``
|
||||
filesystem as ``bash_exec``.
|
||||
|
||||
SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
|
||||
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
from typing import Any, Callable
|
||||
|
||||
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Lazy imports to break circular dependency with tool_adapter.
|
||||
|
||||
|
||||
def _get_sandbox():  # type: ignore[return]
    """Return the current E2B sandbox (lazy import breaks an import cycle)."""
    from .tool_adapter import get_current_sandbox as _current  # noqa: E402

    return _current()
|
||||
|
||||
|
||||
def _is_allowed_local(path: str) -> bool:
    """Check whether *path* may be served from the host filesystem
    (lazy import breaks an import cycle with tool_adapter)."""
    from .tool_adapter import is_allowed_local_path as _allowed  # noqa: E402

    return _allowed(path)
|
||||
|
||||
|
||||
def _resolve_remote(path: str) -> str:
    """Resolve *path* to an absolute sandbox path rooted at ``E2B_WORKDIR``.

    Raises:
        ValueError: if the normalised path would escape the sandbox root
            (e.g. via ``..`` segments).
    """
    if os.path.isabs(path):
        candidate = path
    else:
        candidate = os.path.join(E2B_WORKDIR, path)
    resolved = os.path.normpath(candidate)
    inside = resolved == E2B_WORKDIR or resolved.startswith(E2B_WORKDIR + "/")
    if not inside:
        raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
    return resolved
|
||||
|
||||
|
||||
def _mcp(text: str, *, error: bool = False) -> dict[str, Any]:
|
||||
if error:
|
||||
text = json.dumps({"error": text, "type": "error"})
|
||||
return {"content": [{"type": "text", "text": text}], "isError": error}
|
||||
|
||||
|
||||
def _get_sandbox_and_path(
    file_path: str,
) -> tuple[Any, str] | dict[str, Any]:
    """Shared preamble for file tools: sandbox handle plus resolved path.

    Instead of raising, returns an MCP error payload (a dict) when the
    sandbox is unavailable or the path is invalid; callers branch on the
    result type.
    """
    sandbox = _get_sandbox()
    if sandbox is None:
        return _mcp("No E2B sandbox available", error=True)
    try:
        resolved_path = _resolve_remote(file_path)
    except ValueError as exc:
        return _mcp(str(exc), error=True)
    return sandbox, resolved_path
|
||||
|
||||
|
||||
# Tool handlers
|
||||
|
||||
|
||||
async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
    """Read a file (host-local or sandbox) and return cat -n style output.

    Args keys: ``file_path`` (required), ``offset`` (first line, 0-based,
    clamped to >= 0), ``limit`` (max lines, clamped to >= 1, default 2000).
    """
    file_path: str = args.get("file_path", "")
    offset: int = max(0, int(args.get("offset", 0)))
    limit: int = max(1, int(args.get("limit", 2000)))

    if not file_path:
        return _mcp("file_path is required", error=True)

    # SDK-internal paths (tool-results, ephemeral working dir) stay on the host.
    if _is_allowed_local(file_path):
        return _read_local(file_path, offset, limit)

    result = _get_sandbox_and_path(file_path)
    if isinstance(result, dict):
        # Dict means an MCP error payload (no sandbox / bad path).
        return result
    sandbox, remote = result

    try:
        # Read raw bytes and decode leniently so binary junk can't crash us.
        raw: bytes = await sandbox.files.read(remote, format="bytes")
        content = raw.decode("utf-8", errors="replace")
    except Exception as exc:
        return _mcp(f"Failed to read {remote}: {exc}", error=True)

    # Apply offset/limit on lines, then number them 1-based like cat -n.
    lines = content.splitlines(keepends=True)
    selected = list(itertools.islice(lines, offset, offset + limit))
    numbered = "".join(
        f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
    )
    return _mcp(numbered)
|
||||
|
||||
|
||||
async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
    """Write ``content`` to ``file_path`` inside the E2B sandbox.

    Creates the parent directory first when it isn't the sandbox root.
    Returns an MCP payload: success message, or an error payload on
    missing sandbox / invalid path / write failure.
    """
    file_path: str = args.get("file_path", "")
    content: str = args.get("content", "")

    if not file_path:
        return _mcp("file_path is required", error=True)

    result = _get_sandbox_and_path(file_path)
    if isinstance(result, dict):
        # Dict means an MCP error payload (no sandbox / bad path).
        return result
    sandbox, remote = result

    try:
        parent = os.path.dirname(remote)
        if parent and parent != E2B_WORKDIR:
            # NOTE(review): assumes make_dir tolerates an existing
            # directory (mkdir -p semantics) — confirm against E2B SDK.
            await sandbox.files.make_dir(parent)
        await sandbox.files.write(remote, content)
    except Exception as exc:
        return _mcp(f"Failed to write {remote}: {exc}", error=True)

    return _mcp(f"Successfully wrote to {remote}")
|
||||
|
||||
|
||||
async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
    """Perform a targeted string replacement in a sandbox file.

    Replaces one occurrence of ``old_string`` by default; with
    ``replace_all`` every occurrence is rewritten. Ambiguous matches
    (multiple hits without replace_all) are rejected.
    """
    path: str = args.get("file_path", "")
    needle: str = args.get("old_string", "")
    replacement: str = args.get("new_string", "")
    everywhere: bool = args.get("replace_all", False)

    if not path:
        return _mcp("file_path is required", error=True)
    if not needle:
        return _mcp("old_string is required", error=True)

    resolved = _get_sandbox_and_path(path)
    if isinstance(resolved, dict):
        # Error payload from path resolution — propagate as-is.
        return resolved
    sandbox, remote = resolved

    try:
        data: bytes = await sandbox.files.read(remote, format="bytes")
        text = data.decode("utf-8", errors="replace")
    except Exception as exc:
        return _mcp(f"Failed to read {remote}: {exc}", error=True)

    occurrences = text.count(needle)
    if occurrences == 0:
        return _mcp(f"old_string not found in {path}", error=True)
    if occurrences > 1 and not everywhere:
        return _mcp(
            f"old_string appears {occurrences} times in {path}. "
            "Use replace_all=true or provide a more unique string.",
            error=True,
        )

    # -1 means "replace every occurrence" for str.replace.
    max_replacements = -1 if everywhere else 1
    updated = text.replace(needle, replacement, max_replacements)
    try:
        await sandbox.files.write(remote, updated)
    except Exception as exc:
        return _mcp(f"Failed to write {remote}: {exc}", error=True)

    return _mcp(
        f"Edited {remote} ({occurrences} replacement{'s' if occurrences > 1 else ''})"
    )
|
||||
|
||||
|
||||
async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
    """List sandbox files whose names match a glob pattern.

    Runs ``find`` inside the sandbox (capped at 500 results) and returns
    the matches as a JSON array.
    """
    pattern: str = args.get("pattern", "")
    base: str = args.get("path", "")

    if not pattern:
        return _mcp("pattern is required", error=True)

    sandbox = _get_sandbox()
    if sandbox is None:
        return _mcp("No E2B sandbox available", error=True)

    try:
        root = _resolve_remote(base) if base else E2B_WORKDIR
    except ValueError as exc:
        return _mcp(str(exc), error=True)

    cmd = (
        f"find {shlex.quote(root)} -name {shlex.quote(pattern)} "
        "-type f 2>/dev/null | head -500"
    )
    try:
        proc = await sandbox.commands.run(cmd, cwd=E2B_WORKDIR, timeout=10)
    except Exception as exc:
        return _mcp(f"Glob failed: {exc}", error=True)

    matches = [entry for entry in (proc.stdout or "").strip().splitlines() if entry]
    return _mcp(json.dumps(matches, indent=2))
|
||||
|
||||
|
||||
async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
    """Search file contents in the sandbox via recursive grep.

    Results are capped at 200 lines; an optional ``include`` glob
    restricts which files are searched.
    """
    pattern: str = args.get("pattern", "")
    base: str = args.get("path", "")
    file_filter: str = args.get("include", "")

    if not pattern:
        return _mcp("pattern is required", error=True)

    sandbox = _get_sandbox()
    if sandbox is None:
        return _mcp("No E2B sandbox available", error=True)

    try:
        root = _resolve_remote(base) if base else E2B_WORKDIR
    except ValueError as exc:
        return _mcp(str(exc), error=True)

    # Build the argv safely, then quote each token for shell execution.
    argv = ["grep", "-rn", "--color=never"]
    if file_filter:
        argv += ["--include", file_filter]
    argv += [pattern, root]
    cmd = " ".join(shlex.quote(token) for token in argv) + " 2>/dev/null | head -200"

    try:
        proc = await sandbox.commands.run(cmd, cwd=E2B_WORKDIR, timeout=15)
    except Exception as exc:
        return _mcp(f"Grep failed: {exc}", error=True)

    hits = (proc.stdout or "").strip()
    return _mcp(hits if hits else "No matches found.")
|
||||
|
||||
|
||||
# Local read (for SDK-internal paths)
|
||||
|
||||
|
||||
def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
    """Read from the host filesystem (defence-in-depth path check)."""
    # Re-validate even though callers already check: this function must be
    # safe to call directly with an arbitrary path.
    if not _is_allowed_local(file_path):
        return _mcp(f"Path not allowed: {file_path}", error=True)
    real_path = os.path.realpath(os.path.expanduser(file_path))
    try:
        with open(real_path) as handle:
            # Stream the requested line window and render cat -n style.
            window = itertools.islice(handle, offset, offset + limit)
            rendered = "".join(
                f"{lineno:>6}\t{line}"
                for lineno, line in enumerate(window, start=offset + 1)
            )
        return _mcp(rendered)
    except FileNotFoundError:
        return _mcp(f"File not found: {file_path}", error=True)
    except Exception as exc:
        return _mcp(f"Error reading {file_path}: {exc}", error=True)
|
||||
|
||||
|
||||
# Tool descriptors (name, description, schema, handler)
|
||||
|
||||
# Tool descriptors: (name, description, JSON argument schema, async handler).
# The descriptions and schemas are surfaced to the LLM; the handler is
# dispatched by tool name.
E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
    (
        "read_file",
        "Read a file from the cloud sandbox (/home/user). "
        "Use offset and limit for large files.",
        {
            "type": "object",
            "properties": {
                "file_path": {
                    "type": "string",
                    "description": "Path (relative to /home/user, or absolute).",
                },
                "offset": {
                    "type": "integer",
                    "description": "Line to start reading from (0-indexed). Default: 0.",
                },
                "limit": {
                    "type": "integer",
                    "description": "Number of lines to read. Default: 2000.",
                },
            },
            "required": ["file_path"],
        },
        _handle_read_file,
    ),
    (
        "write_file",
        "Write or create a file in the cloud sandbox (/home/user). "
        "Parent directories are created automatically. "
        "To copy a workspace file into the sandbox, use "
        "read_workspace_file with save_to_path instead.",
        {
            "type": "object",
            "properties": {
                "file_path": {
                    "type": "string",
                    "description": "Path (relative to /home/user, or absolute).",
                },
                "content": {"type": "string", "description": "Content to write."},
            },
            "required": ["file_path", "content"],
        },
        _handle_write_file,
    ),
    (
        "edit_file",
        "Targeted text replacement in a sandbox file. "
        "old_string must appear in the file and is replaced with new_string.",
        {
            "type": "object",
            "properties": {
                "file_path": {
                    "type": "string",
                    "description": "Path (relative to /home/user, or absolute).",
                },
                "old_string": {"type": "string", "description": "Text to find."},
                "new_string": {"type": "string", "description": "Replacement text."},
                "replace_all": {
                    "type": "boolean",
                    "description": "Replace all occurrences (default: false).",
                },
            },
            "required": ["file_path", "old_string", "new_string"],
        },
        _handle_edit_file,
    ),
    (
        "glob",
        "Search for files by name pattern in the cloud sandbox.",
        {
            "type": "object",
            "properties": {
                "pattern": {
                    "type": "string",
                    "description": "Glob pattern (e.g. *.py).",
                },
                "path": {
                    "type": "string",
                    "description": "Directory to search. Default: /home/user.",
                },
            },
            "required": ["pattern"],
        },
        _handle_glob,
    ),
    (
        "grep",
        "Search file contents by regex in the cloud sandbox.",
        {
            "type": "object",
            "properties": {
                "pattern": {"type": "string", "description": "Regex pattern."},
                "path": {
                    "type": "string",
                    "description": "File or directory. Default: /home/user.",
                },
                "include": {
                    "type": "string",
                    "description": "Glob to filter files (e.g. *.py).",
                },
            },
            "required": ["pattern"],
        },
        _handle_grep,
    ),
]

# Flat list of tool names, derived from the descriptor table above.
E2B_FILE_TOOL_NAMES: list[str] = [name for name, *_ in E2B_FILE_TOOLS]
|
||||
@@ -1,153 +0,0 @@
|
||||
"""Tests for E2B file-tool path validation and local read safety.
|
||||
|
||||
Pure unit tests with no external dependencies (no E2B, no sandbox).
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from .e2b_file_tools import _read_local, _resolve_remote
|
||||
from .tool_adapter import _current_project_dir
|
||||
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _resolve_remote — sandbox path normalisation & boundary enforcement
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveRemote:
    """Sandbox path normalisation: paths must resolve inside /home/user."""

    def test_relative_path_resolved(self):
        # Relative paths are anchored at the sandbox workdir.
        assert _resolve_remote("src/main.py") == "/home/user/src/main.py"

    def test_absolute_within_sandbox(self):
        assert _resolve_remote("/home/user/file.txt") == "/home/user/file.txt"

    def test_workdir_itself(self):
        assert _resolve_remote("/home/user") == "/home/user"

    def test_relative_dotslash(self):
        assert _resolve_remote("./README.md") == "/home/user/README.md"

    def test_traversal_blocked(self):
        # "../" escapes from a relative path must raise.
        with pytest.raises(ValueError, match="must be within /home/user"):
            _resolve_remote("../../etc/passwd")

    def test_absolute_traversal_blocked(self):
        with pytest.raises(ValueError, match="must be within /home/user"):
            _resolve_remote("/home/user/../../etc/passwd")

    def test_absolute_outside_sandbox_blocked(self):
        with pytest.raises(ValueError, match="must be within /home/user"):
            _resolve_remote("/etc/passwd")

    def test_root_blocked(self):
        with pytest.raises(ValueError, match="must be within /home/user"):
            _resolve_remote("/")

    def test_home_other_user_blocked(self):
        # Sibling home directories are outside the sandbox boundary.
        with pytest.raises(ValueError, match="must be within /home/user"):
            _resolve_remote("/home/other/file.txt")

    def test_deep_nested_allowed(self):
        assert _resolve_remote("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"

    def test_trailing_slash_normalised(self):
        assert _resolve_remote("src/") == "/home/user/src"

    def test_double_dots_within_sandbox_ok(self):
        """Path that resolves back within /home/user is allowed."""
        assert _resolve_remote("a/b/../c.txt") == "/home/user/a/c.txt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_local — host filesystem reads with allowlist enforcement
|
||||
#
|
||||
# In E2B mode, _read_local only allows tool-results paths (via
|
||||
# is_allowed_local_path without sdk_cwd). Regular files live on the
|
||||
# sandbox, not the host.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReadLocal:
    """Host-filesystem reads: only tool-results paths for the current
    session (set via _current_project_dir) are allowed."""

    def _make_tool_results_file(self, encoded: str, filename: str, content: str) -> str:
        """Create a tool-results file and return its path."""
        tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
        os.makedirs(tool_results_dir, exist_ok=True)
        filepath = os.path.join(tool_results_dir, filename)
        with open(filepath, "w") as f:
            f.write(content)
        return filepath

    def test_read_tool_results_file(self):
        """Reading a tool-results file should succeed."""
        encoded = "-tmp-copilot-e2b-test-read"
        filepath = self._make_tool_results_file(
            encoded, "result.txt", "line 1\nline 2\nline 3\n"
        )
        # The allowlist keys off the session's encoded project dir.
        token = _current_project_dir.set(encoded)
        try:
            result = _read_local(filepath, offset=0, limit=2000)
            assert result["isError"] is False
            assert "line 1" in result["content"][0]["text"]
            assert "line 2" in result["content"][0]["text"]
        finally:
            _current_project_dir.reset(token)
            os.unlink(filepath)

    def test_read_disallowed_path_blocked(self):
        """Reading /etc/passwd should be blocked by the allowlist."""
        result = _read_local("/etc/passwd", offset=0, limit=10)
        assert result["isError"] is True
        assert "not allowed" in result["content"][0]["text"].lower()

    def test_read_nonexistent_tool_results(self):
        """A tool-results path that doesn't exist returns FileNotFoundError."""
        encoded = "-tmp-copilot-e2b-test-nofile"
        tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
        os.makedirs(tool_results_dir, exist_ok=True)
        filepath = os.path.join(tool_results_dir, "nonexistent.txt")
        token = _current_project_dir.set(encoded)
        try:
            result = _read_local(filepath, offset=0, limit=10)
            assert result["isError"] is True
            assert "not found" in result["content"][0]["text"].lower()
        finally:
            _current_project_dir.reset(token)
            os.rmdir(tool_results_dir)

    def test_read_traversal_path_blocked(self):
        """A traversal attempt that escapes allowed directories is blocked."""
        result = _read_local("/tmp/copilot-abc/../../etc/shadow", offset=0, limit=10)
        assert result["isError"] is True
        assert "not allowed" in result["content"][0]["text"].lower()

    def test_read_arbitrary_host_path_blocked(self):
        """Arbitrary host paths are blocked even if they exist."""
        result = _read_local("/proc/self/environ", offset=0, limit=10)
        assert result["isError"] is True

    def test_read_with_offset_and_limit(self):
        """Offset and limit should control which lines are returned."""
        encoded = "-tmp-copilot-e2b-test-offset"
        content = "".join(f"line {i}\n" for i in range(10))
        filepath = self._make_tool_results_file(encoded, "lines.txt", content)
        token = _current_project_dir.set(encoded)
        try:
            # offset=3, limit=2 should yield exactly lines 3 and 4 (0-indexed).
            result = _read_local(filepath, offset=3, limit=2)
            assert result["isError"] is False
            text = result["content"][0]["text"]
            assert "line 3" in text
            assert "line 4" in text
            assert "line 2" not in text
            assert "line 5" not in text
        finally:
            _current_project_dir.reset(token)
            os.unlink(filepath)

    def test_read_without_project_dir_blocks_all(self):
        """Without _current_project_dir set, all paths are blocked."""
        result = _read_local("/tmp/anything.txt", offset=0, limit=10)
        assert result["isError"] is True
|
||||
@@ -1,172 +0,0 @@
|
||||
"""Tests for OTEL tracing setup in the SDK copilot path."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
class TestSetupLangfuseOtel:
    """Tests for _setup_langfuse_otel()."""

    def test_noop_when_langfuse_not_configured(self):
        """No env vars should be set when Langfuse credentials are missing."""
        with patch(
            "backend.copilot.sdk.service._is_langfuse_configured", return_value=False
        ):
            from backend.copilot.sdk.service import _setup_langfuse_otel

            # Clear any previously set env vars
            env_keys = [
                "LANGSMITH_OTEL_ENABLED",
                "LANGSMITH_OTEL_ONLY",
                "LANGSMITH_TRACING",
                "OTEL_EXPORTER_OTLP_ENDPOINT",
                "OTEL_EXPORTER_OTLP_HEADERS",
            ]
            # Save-and-restore so this test doesn't leak env state.
            saved = {k: os.environ.pop(k, None) for k in env_keys}
            try:
                _setup_langfuse_otel()
                for key in env_keys:
                    assert key not in os.environ, f"{key} should not be set"
            finally:
                for k, v in saved.items():
                    if v is not None:
                        os.environ[k] = v

    def test_sets_env_vars_when_langfuse_configured(self):
        """OTEL env vars should be set when Langfuse credentials exist."""
        mock_settings = MagicMock()
        mock_settings.secrets.langfuse_public_key = "pk-test-123"
        mock_settings.secrets.langfuse_secret_key = "sk-test-456"
        mock_settings.secrets.langfuse_host = "https://langfuse.example.com"
        mock_settings.secrets.langfuse_tracing_environment = "test"

        with (
            patch(
                "backend.copilot.sdk.service._is_langfuse_configured",
                return_value=True,
            ),
            patch("backend.copilot.sdk.service.Settings", return_value=mock_settings),
            patch(
                "backend.copilot.sdk.service.configure_claude_agent_sdk",
                return_value=True,
            ) as mock_configure,
        ):
            from backend.copilot.sdk.service import _setup_langfuse_otel

            # Clear env vars so setdefault works
            env_keys = [
                "LANGSMITH_OTEL_ENABLED",
                "LANGSMITH_OTEL_ONLY",
                "LANGSMITH_TRACING",
                "OTEL_EXPORTER_OTLP_ENDPOINT",
                "OTEL_EXPORTER_OTLP_HEADERS",
                "OTEL_RESOURCE_ATTRIBUTES",
            ]
            saved = {k: os.environ.pop(k, None) for k in env_keys}
            try:
                _setup_langfuse_otel()

                assert os.environ["LANGSMITH_OTEL_ENABLED"] == "true"
                assert os.environ["LANGSMITH_OTEL_ONLY"] == "true"
                assert os.environ["LANGSMITH_TRACING"] == "true"
                # Endpoint is derived from the configured Langfuse host.
                assert (
                    os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"]
                    == "https://langfuse.example.com/api/public/otel"
                )
                assert "Authorization=Basic" in os.environ["OTEL_EXPORTER_OTLP_HEADERS"]
                assert (
                    os.environ["OTEL_RESOURCE_ATTRIBUTES"]
                    == "langfuse.environment=test"
                )

                mock_configure.assert_called_once_with(tags=["sdk"])
            finally:
                # Restore prior values; also remove keys this test created.
                for k, v in saved.items():
                    if v is not None:
                        os.environ[k] = v
                    elif k in os.environ:
                        del os.environ[k]

    def test_existing_env_vars_not_overwritten(self):
        """Explicit env-var overrides should not be clobbered."""
        mock_settings = MagicMock()
        mock_settings.secrets.langfuse_public_key = "pk-test"
        mock_settings.secrets.langfuse_secret_key = "sk-test"
        mock_settings.secrets.langfuse_host = "https://langfuse.example.com"

        with (
            patch(
                "backend.copilot.sdk.service._is_langfuse_configured",
                return_value=True,
            ),
            patch("backend.copilot.sdk.service.Settings", return_value=mock_settings),
            patch(
                "backend.copilot.sdk.service.configure_claude_agent_sdk",
                return_value=True,
            ),
        ):
            from backend.copilot.sdk.service import _setup_langfuse_otel

            saved = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
            try:
                # Pre-set the endpoint; setup must leave it untouched.
                os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = "https://custom.endpoint/v1"
                _setup_langfuse_otel()
                assert (
                    os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"]
                    == "https://custom.endpoint/v1"
                )
            finally:
                if saved is not None:
                    os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = saved
                elif "OTEL_EXPORTER_OTLP_ENDPOINT" in os.environ:
                    del os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"]

    def test_graceful_failure_on_exception(self):
        """Setup should not raise even if internal code fails."""
        with (
            patch(
                "backend.copilot.sdk.service._is_langfuse_configured",
                return_value=True,
            ),
            patch(
                "backend.copilot.sdk.service.Settings",
                side_effect=RuntimeError("settings unavailable"),
            ),
        ):
            from backend.copilot.sdk.service import _setup_langfuse_otel

            # Should not raise — just logs and returns
            _setup_langfuse_otel()
|
||||
|
||||
|
||||
class TestPropagateAttributesImport:
    """Verify langfuse.propagate_attributes is available."""

    def test_propagate_attributes_is_importable(self):
        """The symbol imports cleanly and is callable."""
        from langfuse import propagate_attributes

        assert callable(propagate_attributes)

    def test_propagate_attributes_returns_context_manager(self):
        """Calling it yields an object usable in a ``with`` statement."""
        from langfuse import propagate_attributes

        cm = propagate_attributes(user_id="u1", session_id="s1", tags=["test"])
        assert hasattr(cm, "__enter__") and hasattr(cm, "__exit__")
|
||||
|
||||
|
||||
class TestReceiveResponseCompat:
    """Verify ClaudeSDKClient.receive_response() exists (langsmith patches it)."""

    def test_receive_response_exists(self):
        """The attribute must be present on the client class."""
        from claude_agent_sdk import ClaudeSDKClient

        assert hasattr(ClaudeSDKClient, "receive_response")

    def test_receive_response_is_async_generator(self):
        """The attribute is a plain or bound function (patchable by langsmith)."""
        import inspect

        from claude_agent_sdk import ClaudeSDKClient

        target = ClaudeSDKClient.receive_response
        assert inspect.isfunction(target) or inspect.ismethod(target)
|
||||
@@ -118,7 +118,7 @@ async def test_build_query_resume_up_to_date():
|
||||
ChatMessage(role="user", content="what's new?"),
|
||||
]
|
||||
)
|
||||
result, was_compacted = await _build_query_message(
|
||||
result = await _build_query_message(
|
||||
"what's new?",
|
||||
session,
|
||||
use_resume=True,
|
||||
@@ -127,7 +127,6 @@ async def test_build_query_resume_up_to_date():
|
||||
)
|
||||
# transcript_msg_count == msg_count - 1, so no gap
|
||||
assert result == "what's new?"
|
||||
assert was_compacted is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -142,7 +141,7 @@ async def test_build_query_resume_stale_transcript():
|
||||
ChatMessage(role="user", content="turn 3"),
|
||||
]
|
||||
)
|
||||
result, was_compacted = await _build_query_message(
|
||||
result = await _build_query_message(
|
||||
"turn 3",
|
||||
session,
|
||||
use_resume=True,
|
||||
@@ -153,7 +152,6 @@ async def test_build_query_resume_stale_transcript():
|
||||
assert "turn 2" in result
|
||||
assert "reply 2" in result
|
||||
assert "Now, the user says:\nturn 3" in result
|
||||
assert was_compacted is False # gap context does not compact
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -166,7 +164,7 @@ async def test_build_query_resume_zero_msg_count():
|
||||
ChatMessage(role="user", content="new msg"),
|
||||
]
|
||||
)
|
||||
result, was_compacted = await _build_query_message(
|
||||
result = await _build_query_message(
|
||||
"new msg",
|
||||
session,
|
||||
use_resume=True,
|
||||
@@ -174,14 +172,13 @@ async def test_build_query_resume_zero_msg_count():
|
||||
session_id="test-session",
|
||||
)
|
||||
assert result == "new msg"
|
||||
assert was_compacted is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_no_resume_single_message():
|
||||
"""Without --resume and only 1 message, return raw message."""
|
||||
session = _make_session([ChatMessage(role="user", content="first")])
|
||||
result, was_compacted = await _build_query_message(
|
||||
result = await _build_query_message(
|
||||
"first",
|
||||
session,
|
||||
use_resume=False,
|
||||
@@ -189,7 +186,6 @@ async def test_build_query_no_resume_single_message():
|
||||
session_id="test-session",
|
||||
)
|
||||
assert result == "first"
|
||||
assert was_compacted is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -203,16 +199,16 @@ async def test_build_query_no_resume_multi_message(monkeypatch):
|
||||
]
|
||||
)
|
||||
|
||||
# Mock _compress_messages to return the messages as-is
|
||||
async def _mock_compress(msgs):
|
||||
return msgs, False
|
||||
# Mock _compress_conversation_history to return the messages as-is
|
||||
async def _mock_compress(sess):
|
||||
return sess.messages[:-1]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.service._compress_messages",
|
||||
"backend.copilot.sdk.service._compress_conversation_history",
|
||||
_mock_compress,
|
||||
)
|
||||
|
||||
result, was_compacted = await _build_query_message(
|
||||
result = await _build_query_message(
|
||||
"new question",
|
||||
session,
|
||||
use_resume=False,
|
||||
@@ -223,33 +219,3 @@ async def test_build_query_no_resume_multi_message(monkeypatch):
|
||||
assert "older question" in result
|
||||
assert "older answer" in result
|
||||
assert "Now, the user says:\nnew question" in result
|
||||
assert was_compacted is False # mock returns False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_no_resume_multi_message_compacted(monkeypatch):
|
||||
"""When compression actually compacts, was_compacted should be True."""
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="old"),
|
||||
ChatMessage(role="assistant", content="reply"),
|
||||
ChatMessage(role="user", content="new"),
|
||||
]
|
||||
)
|
||||
|
||||
async def _mock_compress(msgs):
|
||||
return msgs, True # Simulate actual compaction
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.service._compress_messages",
|
||||
_mock_compress,
|
||||
)
|
||||
|
||||
result, was_compacted = await _build_query_message(
|
||||
"new",
|
||||
session,
|
||||
use_resume=False,
|
||||
transcript_msg_count=0,
|
||||
session_id="test-session",
|
||||
)
|
||||
assert was_compacted is True
|
||||
|
||||
@@ -6,6 +6,7 @@ ensuring multi-user isolation and preventing unauthorized operations.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
@@ -15,7 +16,6 @@ from .tool_adapter import (
|
||||
DANGEROUS_PATTERNS,
|
||||
MCP_TOOL_PREFIX,
|
||||
WORKSPACE_SCOPED_TOOLS,
|
||||
is_allowed_local_path,
|
||||
stash_pending_tool_output,
|
||||
)
|
||||
|
||||
@@ -38,20 +38,40 @@ def _validate_workspace_path(
|
||||
) -> dict[str, Any]:
|
||||
"""Validate that a workspace-scoped tool only accesses allowed paths.
|
||||
|
||||
Delegates to :func:`is_allowed_local_path` which permits:
|
||||
Allowed directories:
|
||||
- The SDK working directory (``/tmp/copilot-<session>/``)
|
||||
- The current session's tool-results directory
|
||||
(``~/.claude/projects/<encoded-cwd>/tool-results/``)
|
||||
- The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
|
||||
"""
|
||||
path = tool_input.get("file_path") or tool_input.get("path") or ""
|
||||
if not path:
|
||||
# Glob/Grep without a path default to cwd which is already sandboxed
|
||||
return {}
|
||||
|
||||
if is_allowed_local_path(path, sdk_cwd):
|
||||
# Resolve relative paths against sdk_cwd (the SDK sets cwd so the LLM
|
||||
# naturally uses relative paths like "test.txt" instead of absolute ones).
|
||||
# Tilde paths (~/) are home-dir references, not relative — expand first.
|
||||
if path.startswith("~"):
|
||||
resolved = os.path.realpath(os.path.expanduser(path))
|
||||
elif not os.path.isabs(path) and sdk_cwd:
|
||||
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
|
||||
else:
|
||||
resolved = os.path.realpath(path)
|
||||
|
||||
# Allow access within the SDK working directory
|
||||
if sdk_cwd:
|
||||
norm_cwd = os.path.realpath(sdk_cwd)
|
||||
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
|
||||
return {}
|
||||
|
||||
# Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
|
||||
claude_dir = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
tool_results_seg = os.sep + "tool-results" + os.sep
|
||||
if resolved.startswith(claude_dir + os.sep) and tool_results_seg in resolved:
|
||||
return {}
|
||||
|
||||
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
|
||||
logger.warning(
|
||||
f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
|
||||
)
|
||||
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
|
||||
return _deny(
|
||||
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
|
||||
@@ -126,7 +146,7 @@ def create_security_hooks(
|
||||
user_id: str | None,
|
||||
sdk_cwd: str | None = None,
|
||||
max_subtasks: int = 3,
|
||||
on_compact: Callable[[], None] | None = None,
|
||||
on_stop: Callable[[str, str], None] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create the security hooks configuration for Claude Agent SDK.
|
||||
|
||||
@@ -135,12 +155,15 @@ def create_security_hooks(
|
||||
- PostToolUse: Log successful tool executions
|
||||
- PostToolUseFailure: Log and handle failed tool executions
|
||||
- PreCompact: Log context compaction events (SDK handles compaction automatically)
|
||||
- Stop: Capture transcript path for stateless resume (when *on_stop* is provided)
|
||||
|
||||
Args:
|
||||
user_id: Current user ID for isolation validation
|
||||
sdk_cwd: SDK working directory for workspace-scoped tool validation
|
||||
max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
|
||||
on_compact: Callback invoked when SDK starts compacting context.
|
||||
on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when
|
||||
the SDK finishes processing — used to read the JSONL transcript
|
||||
before the CLI process exits.
|
||||
|
||||
Returns:
|
||||
Hooks configuration dict for ClaudeAgentOptions
|
||||
@@ -303,8 +326,30 @@ def create_security_hooks(
|
||||
logger.info(
|
||||
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
|
||||
)
|
||||
if on_compact is not None:
|
||||
on_compact()
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
# --- Stop hook: capture transcript path for stateless resume ---
|
||||
async def stop_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Capture transcript path when SDK finishes processing.
|
||||
|
||||
The Stop hook fires while the CLI process is still alive, giving us
|
||||
a reliable window to read the JSONL transcript before SIGTERM.
|
||||
"""
|
||||
_ = context, tool_use_id
|
||||
transcript_path = cast(str, input_data.get("transcript_path", ""))
|
||||
sdk_session_id = cast(str, input_data.get("session_id", ""))
|
||||
|
||||
if transcript_path and on_stop:
|
||||
logger.info(
|
||||
f"[SDK] Stop hook: transcript_path={transcript_path}, "
|
||||
f"sdk_session_id={sdk_session_id[:12]}..."
|
||||
)
|
||||
on_stop(transcript_path, sdk_session_id)
|
||||
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
hooks: dict[str, Any] = {
|
||||
@@ -316,6 +361,9 @@ def create_security_hooks(
|
||||
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
|
||||
}
|
||||
|
||||
if on_stop is not None:
|
||||
hooks["Stop"] = [HookMatcher(matcher=None, hooks=[stop_hook])]
|
||||
|
||||
return hooks
|
||||
except ImportError:
|
||||
# Fallback for when SDK isn't available - return empty hooks
|
||||
|
||||
@@ -120,31 +120,17 @@ def test_read_no_cwd_denies_absolute():
|
||||
|
||||
|
||||
def test_read_tool_results_allowed():
|
||||
from .tool_adapter import _current_project_dir
|
||||
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
|
||||
# is_allowed_local_path requires the session's encoded cwd to be set
|
||||
token = _current_project_dir.set("-tmp-copilot-abc123")
|
||||
try:
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_read_claude_projects_session_dir_allowed():
|
||||
"""Files within the current session's project dir are allowed."""
|
||||
from .tool_adapter import _current_project_dir
|
||||
|
||||
def test_read_claude_projects_without_tool_results_denied():
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
|
||||
token = _current_project_dir.set("-tmp-copilot-abc123")
|
||||
try:
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert not _is_denied(result)
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
# -- Built-in Bash is blocked (use bash_exec MCP tool instead) ---------------
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,147 +0,0 @@
|
||||
"""Tests for SDK service helpers."""
|
||||
|
||||
import base64
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .service import _prepare_file_attachments
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeFileInfo:
|
||||
id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
|
||||
|
||||
_PATCH_TARGET = "backend.copilot.sdk.service.get_manager"
|
||||
|
||||
|
||||
class TestPrepareFileAttachments:
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_list_returns_empty(self, tmp_path):
|
||||
result = await _prepare_file_attachments([], "u", "s", str(tmp_path))
|
||||
assert result.hint == ""
|
||||
assert result.image_blocks == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_embedded_as_vision_block(self, tmp_path):
|
||||
"""JPEG images should become vision content blocks, not files on disk."""
|
||||
raw = b"\xff\xd8\xff\xe0fake-jpeg"
|
||||
info = _FakeFileInfo(
|
||||
id="abc",
|
||||
name="photo.jpg",
|
||||
path="/photo.jpg",
|
||||
mime_type="image/jpeg",
|
||||
size_bytes=len(raw),
|
||||
)
|
||||
mgr = AsyncMock()
|
||||
mgr.get_file_info.return_value = info
|
||||
mgr.read_file_by_id.return_value = raw
|
||||
|
||||
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
|
||||
result = await _prepare_file_attachments(
|
||||
["abc"], "user1", "sess1", str(tmp_path)
|
||||
)
|
||||
|
||||
assert "1 file" in result.hint
|
||||
assert "photo.jpg" in result.hint
|
||||
assert "embedded as image" in result.hint
|
||||
assert len(result.image_blocks) == 1
|
||||
block = result.image_blocks[0]
|
||||
assert block["type"] == "image"
|
||||
assert block["source"]["media_type"] == "image/jpeg"
|
||||
assert block["source"]["data"] == base64.b64encode(raw).decode("ascii")
|
||||
# Image should NOT be written to disk (embedded instead)
|
||||
assert not os.path.exists(os.path.join(tmp_path, "photo.jpg"))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pdf_saved_to_disk(self, tmp_path):
|
||||
"""PDFs should be saved to disk for Read tool access, not embedded."""
|
||||
info = _FakeFileInfo("f1", "doc.pdf", "/doc.pdf", "application/pdf", 50)
|
||||
mgr = AsyncMock()
|
||||
mgr.get_file_info.return_value = info
|
||||
mgr.read_file_by_id.return_value = b"%PDF-1.4 fake"
|
||||
|
||||
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
|
||||
result = await _prepare_file_attachments(["f1"], "u", "s", str(tmp_path))
|
||||
|
||||
assert result.image_blocks == []
|
||||
saved = tmp_path / "doc.pdf"
|
||||
assert saved.exists()
|
||||
assert saved.read_bytes() == b"%PDF-1.4 fake"
|
||||
assert str(saved) in result.hint
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_images_and_files(self, tmp_path):
|
||||
"""Images become blocks, non-images go to disk."""
|
||||
infos = {
|
||||
"id1": _FakeFileInfo("id1", "a.png", "/a.png", "image/png", 4),
|
||||
"id2": _FakeFileInfo("id2", "b.pdf", "/b.pdf", "application/pdf", 4),
|
||||
"id3": _FakeFileInfo("id3", "c.txt", "/c.txt", "text/plain", 4),
|
||||
}
|
||||
mgr = AsyncMock()
|
||||
mgr.get_file_info.side_effect = lambda fid: infos[fid]
|
||||
mgr.read_file_by_id.return_value = b"data"
|
||||
|
||||
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
|
||||
result = await _prepare_file_attachments(
|
||||
["id1", "id2", "id3"], "u", "s", str(tmp_path)
|
||||
)
|
||||
|
||||
assert "3 files" in result.hint
|
||||
assert "a.png" in result.hint
|
||||
assert "b.pdf" in result.hint
|
||||
assert "c.txt" in result.hint
|
||||
# Only the image should be a vision block
|
||||
assert len(result.image_blocks) == 1
|
||||
assert result.image_blocks[0]["source"]["media_type"] == "image/png"
|
||||
# Non-image files should be on disk
|
||||
assert (tmp_path / "b.pdf").exists()
|
||||
assert (tmp_path / "c.txt").exists()
|
||||
# Read tool hint should appear (has non-image files)
|
||||
assert "Read tool" in result.hint
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_singular_noun(self, tmp_path):
|
||||
info = _FakeFileInfo("x", "only.txt", "/only.txt", "text/plain", 2)
|
||||
mgr = AsyncMock()
|
||||
mgr.get_file_info.return_value = info
|
||||
mgr.read_file_by_id.return_value = b"hi"
|
||||
|
||||
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
|
||||
result = await _prepare_file_attachments(["x"], "u", "s", str(tmp_path))
|
||||
|
||||
assert "1 file." in result.hint
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_file_skipped(self, tmp_path):
|
||||
mgr = AsyncMock()
|
||||
mgr.get_file_info.return_value = None
|
||||
|
||||
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
|
||||
result = await _prepare_file_attachments(
|
||||
["missing-id"], "u", "s", str(tmp_path)
|
||||
)
|
||||
|
||||
assert result.hint == ""
|
||||
assert result.image_blocks == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_only_no_read_hint(self, tmp_path):
|
||||
"""When all files are images, no Read tool hint should appear."""
|
||||
info = _FakeFileInfo("i1", "cat.png", "/cat.png", "image/png", 4)
|
||||
mgr = AsyncMock()
|
||||
mgr.get_file_info.return_value = info
|
||||
mgr.read_file_by_id.return_value = b"data"
|
||||
|
||||
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
|
||||
result = await _prepare_file_attachments(["i1"], "u", "s", str(tmp_path))
|
||||
|
||||
assert "Read tool" not in result.hint
|
||||
assert len(result.image_blocks) == 1
|
||||
@@ -9,84 +9,20 @@ import itertools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
from backend.copilot.tools.base import BaseTool
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
from .e2b_file_tools import E2B_FILE_TOOL_NAMES, E2B_FILE_TOOLS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
|
||||
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
|
||||
# in the path — prevents reading settings, credentials, or other sensitive files.
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
# Max MCP response size in chars — keeps tool output under the SDK's 10 MB JSON buffer.
|
||||
_MCP_MAX_CHARS = 500_000
|
||||
|
||||
# Context variable holding the encoded project directory name for the current
|
||||
# session (e.g. "-private-tmp-copilot-<uuid>"). Set by set_execution_context()
|
||||
# so that path validation can scope tool-results reads to the current session.
|
||||
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
|
||||
|
||||
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does.
|
||||
|
||||
The CLI replaces all non-alphanumeric characters with ``-``.
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
"""Check whether *path* is an allowed host-filesystem path.
|
||||
|
||||
Allowed:
|
||||
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/`` — the SDK's
|
||||
project directory for this session (tool-results, transcripts, etc.)
|
||||
|
||||
Both checks are scoped to the **current session** so sessions cannot
|
||||
read each other's data.
|
||||
"""
|
||||
if not path:
|
||||
return False
|
||||
|
||||
if path.startswith("~"):
|
||||
resolved = os.path.realpath(os.path.expanduser(path))
|
||||
elif not os.path.isabs(path) and sdk_cwd:
|
||||
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
|
||||
else:
|
||||
resolved = os.path.realpath(path)
|
||||
|
||||
# Allow access within the SDK working directory
|
||||
if sdk_cwd:
|
||||
norm_cwd = os.path.realpath(sdk_cwd)
|
||||
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
|
||||
return True
|
||||
|
||||
# Allow access within the current session's CLI project directory
|
||||
# (~/.claude/projects/<encoded-cwd>/).
|
||||
encoded = _current_project_dir.get("")
|
||||
if encoded:
|
||||
session_project = os.path.join(_SDK_PROJECTS_DIR, encoded)
|
||||
if resolved == session_project or resolved.startswith(session_project + os.sep):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
_SDK_PROJECTS_DIR = os.path.expanduser("~/.claude/projects/")
|
||||
|
||||
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
|
||||
MCP_SERVER_NAME = "copilot"
|
||||
@@ -97,15 +33,6 @@ _current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default
|
||||
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||
"current_session", default=None
|
||||
)
|
||||
# E2B cloud sandbox for the current turn (None when E2B is not configured).
|
||||
# Passed to bash_exec so commands run on E2B instead of the local bwrap sandbox.
|
||||
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
|
||||
"_current_sandbox", default=None
|
||||
)
|
||||
# Raw SDK working directory path (e.g. /tmp/copilot-<session_id>).
|
||||
# Used by workspace tools to save binary files for the CLI's built-in Read.
|
||||
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
|
||||
|
||||
# Stash for MCP tool outputs before the SDK potentially truncates them.
|
||||
# Keyed by tool_name → full output string. Consumed (popped) by the
|
||||
# response adapter when it builds StreamToolOutputAvailable.
|
||||
@@ -126,39 +53,22 @@ _stash_event: ContextVar[asyncio.Event | None] = ContextVar(
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
sandbox: "AsyncSandbox | None" = None,
|
||||
sdk_cwd: str | None = None,
|
||||
) -> None:
|
||||
"""Set the execution context for tool calls.
|
||||
|
||||
This must be called before streaming begins to ensure tools have access
|
||||
to user_id, session, and (optionally) an E2B sandbox for bash execution.
|
||||
to user_id and session information.
|
||||
|
||||
Args:
|
||||
user_id: Current user's ID.
|
||||
session: Current chat session.
|
||||
sandbox: Optional E2B sandbox; when set, bash_exec routes commands there.
|
||||
sdk_cwd: SDK working directory; used to scope tool-results reads.
|
||||
"""
|
||||
_current_user_id.set(user_id)
|
||||
_current_session.set(session)
|
||||
_current_sandbox.set(sandbox)
|
||||
_current_sdk_cwd.set(sdk_cwd or "")
|
||||
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
|
||||
_pending_tool_outputs.set({})
|
||||
_stash_event.set(asyncio.Event())
|
||||
|
||||
|
||||
def get_current_sandbox() -> "AsyncSandbox | None":
|
||||
"""Return the E2B sandbox for the current turn, or None."""
|
||||
return _current_sandbox.get()
|
||||
|
||||
|
||||
def get_sdk_cwd() -> str:
|
||||
"""Return the SDK ephemeral working directory for the current turn."""
|
||||
return _current_sdk_cwd.get()
|
||||
|
||||
|
||||
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||
"""Get the current execution context."""
|
||||
return (
|
||||
@@ -272,12 +182,66 @@ async def _execute_tool_sync(
|
||||
result.output if isinstance(result.output, str) else json.dumps(result.output)
|
||||
)
|
||||
|
||||
# Stash the full output before the SDK potentially truncates it.
|
||||
pending = _pending_tool_outputs.get(None)
|
||||
if pending is not None:
|
||||
pending.setdefault(base_tool.name, []).append(text)
|
||||
|
||||
content_blocks: list[dict[str, str]] = [{"type": "text", "text": text}]
|
||||
|
||||
# If the tool result contains inline image data, add an MCP image block
|
||||
# so Claude can "see" the image (e.g. read_workspace_file on a small PNG).
|
||||
image_block = _extract_image_block(text)
|
||||
if image_block:
|
||||
content_blocks.append(image_block)
|
||||
|
||||
return {
|
||||
"content": [{"type": "text", "text": text}],
|
||||
"content": content_blocks,
|
||||
"isError": not result.success,
|
||||
}
|
||||
|
||||
|
||||
# MIME types that Claude can process as image content blocks.
|
||||
_SUPPORTED_IMAGE_TYPES = frozenset(
|
||||
{"image/png", "image/jpeg", "image/gif", "image/webp"}
|
||||
)
|
||||
|
||||
|
||||
def _extract_image_block(text: str) -> dict[str, str] | None:
|
||||
"""Extract an MCP image content block from a tool result JSON string.
|
||||
|
||||
Detects workspace file responses with ``content_base64`` and an image
|
||||
MIME type, returning an MCP-format image block that allows Claude to
|
||||
"see" the image. Returns ``None`` if the result is not an inline image.
|
||||
"""
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
|
||||
mime_type = data.get("mime_type", "")
|
||||
base64_content = data.get("content_base64", "")
|
||||
|
||||
# Only inline small images — large ones would exceed Claude's limits.
|
||||
# 32 KB raw ≈ ~43 KB base64.
|
||||
_MAX_IMAGE_BASE64_BYTES = 43_000
|
||||
if (
|
||||
mime_type in _SUPPORTED_IMAGE_TYPES
|
||||
and base64_content
|
||||
and len(base64_content) <= _MAX_IMAGE_BASE64_BYTES
|
||||
):
|
||||
return {
|
||||
"type": "image",
|
||||
"data": base64_content,
|
||||
"mimeType": mime_type,
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _mcp_error(message: str) -> dict[str, Any]:
|
||||
return {
|
||||
"content": [
|
||||
@@ -320,32 +284,29 @@ def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
|
||||
|
||||
|
||||
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Read a local file with optional offset/limit.
|
||||
"""Read a file with optional offset/limit. Restricted to SDK working directory.
|
||||
|
||||
Only allows paths that pass :func:`is_allowed_local_path` — the current
|
||||
session's tool-results directory and ephemeral working directory.
|
||||
After reading, the file is deleted to prevent accumulation in long-running pods.
|
||||
"""
|
||||
file_path = args.get("file_path", "")
|
||||
offset = args.get("offset", 0)
|
||||
limit = args.get("limit", 2000)
|
||||
|
||||
if not is_allowed_local_path(file_path):
|
||||
# Security: only allow reads under ~/.claude/projects/**/tool-results/
|
||||
real_path = os.path.realpath(file_path)
|
||||
if not real_path.startswith(_SDK_PROJECTS_DIR) or "tool-results" not in real_path:
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
resolved = os.path.realpath(os.path.expanduser(file_path))
|
||||
try:
|
||||
with open(resolved) as f:
|
||||
with open(real_path) as f:
|
||||
selected = list(itertools.islice(f, offset, offset + limit))
|
||||
content = "".join(selected)
|
||||
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
|
||||
# don't delete here — the SDK may read in multiple chunks.
|
||||
return {
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"isError": False,
|
||||
}
|
||||
return {"content": [{"type": "text", "text": content}], "isError": False}
|
||||
except FileNotFoundError:
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
|
||||
@@ -383,86 +344,50 @@ _READ_TOOL_SCHEMA = {
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCP result helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _text_from_mcp_result(result: dict[str, Any]) -> str:
|
||||
"""Extract concatenated text from an MCP response's content blocks."""
|
||||
content = result.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
return ""
|
||||
return "".join(
|
||||
b.get("text", "")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
)
|
||||
|
||||
|
||||
def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
# Create the MCP server configuration
|
||||
def create_copilot_mcp_server():
|
||||
"""Create an in-process MCP server configuration for CoPilot tools.
|
||||
|
||||
When *use_e2b* is True, five additional MCP file tools are registered
|
||||
that route directly to the E2B sandbox filesystem, and the caller should
|
||||
disable the corresponding SDK built-in tools via
|
||||
:func:`get_sdk_disallowed_tools`.
|
||||
This can be passed to ClaudeAgentOptions.mcp_servers.
|
||||
|
||||
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
|
||||
package being available. This function returns the configuration that
|
||||
can be used with the SDK.
|
||||
"""
|
||||
try:
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
def _truncating(fn, tool_name: str):
|
||||
"""Wrap a tool handler so its response is truncated to stay under the
|
||||
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
|
||||
response adapter before the SDK can apply its own head-truncation.
|
||||
# Create decorated tool functions
|
||||
sdk_tools = []
|
||||
|
||||
Applied once to every registered tool."""
|
||||
|
||||
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
|
||||
result = await fn(args)
|
||||
truncated = truncate(result, _MCP_MAX_CHARS)
|
||||
|
||||
# Stash the text so the response adapter can forward our
|
||||
# middle-out truncated version to the frontend instead of the
|
||||
# SDK's head-truncated version (for outputs >~100 KB the SDK
|
||||
# persists to tool-results/ with a 2 KB head-only preview).
|
||||
if not truncated.get("isError"):
|
||||
text = _text_from_mcp_result(truncated)
|
||||
if text:
|
||||
stash_pending_tool_output(tool_name, text)
|
||||
|
||||
return truncated
|
||||
|
||||
return wrapper
|
||||
|
||||
sdk_tools = []
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
handler = create_tool_handler(base_tool)
|
||||
decorated = tool(
|
||||
tool_name,
|
||||
base_tool.description,
|
||||
_build_input_schema(base_tool),
|
||||
)(_truncating(handler, tool_name))
|
||||
sdk_tools.append(decorated)
|
||||
|
||||
# E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.
|
||||
if use_e2b:
|
||||
for name, desc, schema, handler in E2B_FILE_TOOLS:
|
||||
decorated = tool(name, desc, schema)(_truncating(handler, name))
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
handler = create_tool_handler(base_tool)
|
||||
decorated = tool(
|
||||
tool_name,
|
||||
base_tool.description,
|
||||
_build_input_schema(base_tool),
|
||||
)(handler)
|
||||
sdk_tools.append(decorated)
|
||||
|
||||
# Read tool for SDK-truncated tool results (always needed).
|
||||
read_tool = tool(
|
||||
_READ_TOOL_NAME,
|
||||
_READ_TOOL_DESCRIPTION,
|
||||
_READ_TOOL_SCHEMA,
|
||||
)(_truncating(_read_file_handler, _READ_TOOL_NAME))
|
||||
sdk_tools.append(read_tool)
|
||||
# Add the Read tool so the SDK can read back oversized tool results
|
||||
read_tool = tool(
|
||||
_READ_TOOL_NAME,
|
||||
_READ_TOOL_DESCRIPTION,
|
||||
_READ_TOOL_SCHEMA,
|
||||
)(_read_file_handler)
|
||||
sdk_tools.append(read_tool)
|
||||
|
||||
return create_sdk_mcp_server(
|
||||
name=MCP_SERVER_NAME,
|
||||
version="1.0.0",
|
||||
tools=sdk_tools,
|
||||
)
|
||||
server = create_sdk_mcp_server(
|
||||
name=MCP_SERVER_NAME,
|
||||
version="1.0.0",
|
||||
tools=sdk_tools,
|
||||
)
|
||||
|
||||
return server
|
||||
|
||||
except ImportError:
|
||||
# Let ImportError propagate so service.py handles the fallback
|
||||
raise
|
||||
|
||||
|
||||
# SDK built-in tools allowed within the workspace directory.
|
||||
@@ -472,11 +397,16 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
# Task allows spawning sub-agents (rate-limited by security hooks).
|
||||
# WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk.
|
||||
# TodoWrite manages the task checklist shown in the UI — no security concern.
|
||||
# In E2B mode, all five are disabled — MCP equivalents provide direct sandbox
|
||||
# access. read_file also handles local tool-results and ephemeral reads.
|
||||
_SDK_BUILTIN_FILE_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep"]
|
||||
_SDK_BUILTIN_ALWAYS = ["Task", "WebSearch", "TodoWrite"]
|
||||
_SDK_BUILTIN_TOOLS = [*_SDK_BUILTIN_FILE_TOOLS, *_SDK_BUILTIN_ALWAYS]
|
||||
_SDK_BUILTIN_TOOLS = [
|
||||
"Read",
|
||||
"Write",
|
||||
"Edit",
|
||||
"Glob",
|
||||
"Grep",
|
||||
"Task",
|
||||
"WebSearch",
|
||||
"TodoWrite",
|
||||
]
|
||||
|
||||
# SDK built-in tools that must be explicitly blocked.
|
||||
# Bash: dangerous — agent uses mcp__copilot__bash_exec with kernel-level
|
||||
@@ -484,10 +414,15 @@ _SDK_BUILTIN_TOOLS = [*_SDK_BUILTIN_FILE_TOOLS, *_SDK_BUILTIN_ALWAYS]
|
||||
# WebFetch: SSRF risk — can reach internal network (localhost, 10.x, etc.).
|
||||
# Agent uses the SSRF-protected mcp__copilot__web_fetch tool instead.
|
||||
# AskUserQuestion: interactive CLI tool — no terminal in copilot context.
|
||||
# ToolSearch: SDK bug — ToolSearch results are passed as raw match objects
|
||||
# instead of text content blocks, causing Anthropic API 400 errors.
|
||||
# All copilot tools are already explicitly listed in allowed_tools,
|
||||
# so dynamic tool search is unnecessary.
|
||||
SDK_DISALLOWED_TOOLS = [
|
||||
"Bash",
|
||||
"WebFetch",
|
||||
"AskUserQuestion",
|
||||
"ToolSearch",
|
||||
]
|
||||
|
||||
# Tools that are blocked entirely in security hooks (defence-in-depth).
|
||||
@@ -523,37 +458,11 @@ DANGEROUS_PATTERNS = [
|
||||
r"subprocess",
|
||||
]
|
||||
|
||||
# Static tool name list for the non-E2B case (backward compatibility).
|
||||
# List of tool names for allowed_tools configuration
|
||||
# Include MCP tools, the MCP Read tool for oversized results,
|
||||
# and SDK built-in file tools for workspace operations.
|
||||
COPILOT_TOOL_NAMES = [
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
|
||||
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
|
||||
*_SDK_BUILTIN_TOOLS,
|
||||
]
|
||||
|
||||
|
||||
def get_copilot_tool_names(*, use_e2b: bool = False) -> list[str]:
|
||||
"""Build the ``allowed_tools`` list for :class:`ClaudeAgentOptions`.
|
||||
|
||||
When *use_e2b* is True the SDK built-in file tools are replaced by MCP
|
||||
equivalents that route to the E2B sandbox.
|
||||
"""
|
||||
if not use_e2b:
|
||||
return list(COPILOT_TOOL_NAMES)
|
||||
|
||||
return [
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
|
||||
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in E2B_FILE_TOOL_NAMES],
|
||||
*_SDK_BUILTIN_ALWAYS,
|
||||
]
|
||||
|
||||
|
||||
def get_sdk_disallowed_tools(*, use_e2b: bool = False) -> list[str]:
|
||||
"""Build the ``disallowed_tools`` list for :class:`ClaudeAgentOptions`.
|
||||
|
||||
When *use_e2b* is True the SDK built-in file tools are also disabled
|
||||
because MCP equivalents provide direct sandbox access.
|
||||
"""
|
||||
if not use_e2b:
|
||||
return list(SDK_DISALLOWED_TOOLS)
|
||||
return [*SDK_DISALLOWED_TOOLS, *_SDK_BUILTIN_FILE_TOOLS]
|
||||
|
||||
@@ -1,170 +0,0 @@
|
||||
"""Tests for tool_adapter helpers: truncation, stash, context vars."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
from .tool_adapter import (
|
||||
_MCP_MAX_CHARS,
|
||||
_text_from_mcp_result,
|
||||
get_sdk_cwd,
|
||||
pop_pending_tool_output,
|
||||
set_execution_context,
|
||||
stash_pending_tool_output,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _text_from_mcp_result
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTextFromMcpResult:
|
||||
def test_single_text_block(self):
|
||||
result = {"content": [{"type": "text", "text": "hello"}]}
|
||||
assert _text_from_mcp_result(result) == "hello"
|
||||
|
||||
def test_multiple_text_blocks_concatenated(self):
|
||||
result = {
|
||||
"content": [
|
||||
{"type": "text", "text": "one"},
|
||||
{"type": "text", "text": "two"},
|
||||
]
|
||||
}
|
||||
assert _text_from_mcp_result(result) == "onetwo"
|
||||
|
||||
def test_non_text_blocks_ignored(self):
|
||||
result = {
|
||||
"content": [
|
||||
{"type": "image", "data": "..."},
|
||||
{"type": "text", "text": "only this"},
|
||||
]
|
||||
}
|
||||
assert _text_from_mcp_result(result) == "only this"
|
||||
|
||||
def test_empty_content_list(self):
|
||||
assert _text_from_mcp_result({"content": []}) == ""
|
||||
|
||||
def test_missing_content_key(self):
|
||||
assert _text_from_mcp_result({}) == ""
|
||||
|
||||
def test_non_list_content(self):
|
||||
assert _text_from_mcp_result({"content": "raw string"}) == ""
|
||||
|
||||
def test_missing_text_field(self):
|
||||
result = {"content": [{"type": "text"}]}
|
||||
assert _text_from_mcp_result(result) == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_sdk_cwd
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetSdkCwd:
|
||||
def test_returns_empty_string_by_default(self):
|
||||
set_execution_context(
|
||||
user_id="test",
|
||||
session=None, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
)
|
||||
assert get_sdk_cwd() == ""
|
||||
|
||||
def test_returns_set_value(self):
|
||||
set_execution_context(
|
||||
user_id="test",
|
||||
session=None, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
sdk_cwd="/tmp/copilot-test-123",
|
||||
)
|
||||
assert get_sdk_cwd() == "/tmp/copilot-test-123"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# stash / pop round-trip (the mechanism _truncating relies on)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToolOutputStash:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init_context(self):
|
||||
"""Initialise the context vars that stash_pending_tool_output needs."""
|
||||
set_execution_context(
|
||||
user_id="test",
|
||||
session=None, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
sdk_cwd="/tmp/test",
|
||||
)
|
||||
|
||||
def test_stash_and_pop(self):
|
||||
stash_pending_tool_output("my_tool", "output1")
|
||||
assert pop_pending_tool_output("my_tool") == "output1"
|
||||
|
||||
def test_pop_empty_returns_none(self):
|
||||
assert pop_pending_tool_output("nonexistent") is None
|
||||
|
||||
def test_fifo_order(self):
|
||||
stash_pending_tool_output("t", "first")
|
||||
stash_pending_tool_output("t", "second")
|
||||
assert pop_pending_tool_output("t") == "first"
|
||||
assert pop_pending_tool_output("t") == "second"
|
||||
assert pop_pending_tool_output("t") is None
|
||||
|
||||
def test_dict_serialised_to_json(self):
|
||||
stash_pending_tool_output("t", {"key": "value"})
|
||||
assert pop_pending_tool_output("t") == '{"key": "value"}'
|
||||
|
||||
def test_separate_tool_names(self):
|
||||
stash_pending_tool_output("a", "alpha")
|
||||
stash_pending_tool_output("b", "beta")
|
||||
assert pop_pending_tool_output("b") == "beta"
|
||||
assert pop_pending_tool_output("a") == "alpha"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _truncating wrapper (integration via create_copilot_mcp_server)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTruncationAndStashIntegration:
|
||||
"""Test truncation + stash behavior that _truncating relies on."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init_context(self):
|
||||
set_execution_context(
|
||||
user_id="test",
|
||||
session=None, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
sdk_cwd="/tmp/test",
|
||||
)
|
||||
|
||||
def test_small_output_stashed(self):
|
||||
"""Non-error output is stashed for the response adapter."""
|
||||
result = {
|
||||
"content": [{"type": "text", "text": "small output"}],
|
||||
"isError": False,
|
||||
}
|
||||
truncated = truncate(result, _MCP_MAX_CHARS)
|
||||
text = _text_from_mcp_result(truncated)
|
||||
assert text == "small output"
|
||||
stash_pending_tool_output("test_tool", text)
|
||||
assert pop_pending_tool_output("test_tool") == "small output"
|
||||
|
||||
def test_error_result_not_stashed(self):
|
||||
"""Error results should not be stashed."""
|
||||
result = {
|
||||
"content": [{"type": "text", "text": "error msg"}],
|
||||
"isError": True,
|
||||
}
|
||||
# _truncating only stashes when not result.get("isError")
|
||||
if not result.get("isError"):
|
||||
stash_pending_tool_output("err_tool", "should not happen")
|
||||
assert pop_pending_tool_output("err_tool") is None
|
||||
|
||||
def test_large_output_truncated(self):
|
||||
"""Output exceeding _MCP_MAX_CHARS is truncated before stashing."""
|
||||
big_text = "x" * (_MCP_MAX_CHARS + 100_000)
|
||||
result = {"content": [{"type": "text", "text": big_text}]}
|
||||
truncated = truncate(result, _MCP_MAX_CHARS)
|
||||
text = _text_from_mcp_result(truncated)
|
||||
assert len(text) < len(big_text)
|
||||
assert len(str(truncated)) <= _MCP_MAX_CHARS
|
||||
@@ -58,40 +58,41 @@ def strip_progress_entries(content: str) -> str:
|
||||
Removes entries whose ``type`` is in ``STRIPPABLE_TYPES`` and reparents
|
||||
any remaining child entries so the ``parentUuid`` chain stays intact.
|
||||
Typically reduces transcript size by ~30%.
|
||||
|
||||
Entries that are not stripped or reparented are kept as their original
|
||||
raw JSON line to avoid unnecessary re-serialization that changes
|
||||
whitespace or key ordering.
|
||||
"""
|
||||
lines = content.strip().split("\n")
|
||||
|
||||
# Parse entries, keeping the original line alongside the parsed dict.
|
||||
parsed: list[tuple[str, dict | None]] = []
|
||||
entries: list[dict] = []
|
||||
for line in lines:
|
||||
try:
|
||||
parsed.append((line, json.loads(line)))
|
||||
entries.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
parsed.append((line, None))
|
||||
# Keep unparseable lines as-is (safety)
|
||||
entries.append({"_raw": line})
|
||||
|
||||
# First pass: identify stripped UUIDs and build parent map.
|
||||
stripped_uuids: set[str] = set()
|
||||
uuid_to_parent: dict[str, str] = {}
|
||||
kept: list[dict] = []
|
||||
|
||||
for _line, entry in parsed:
|
||||
if entry is None:
|
||||
for entry in entries:
|
||||
if "_raw" in entry:
|
||||
kept.append(entry)
|
||||
continue
|
||||
uid = entry.get("uuid", "")
|
||||
parent = entry.get("parentUuid", "")
|
||||
entry_type = entry.get("type", "")
|
||||
|
||||
if uid:
|
||||
uuid_to_parent[uid] = parent
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and uid:
|
||||
stripped_uuids.add(uid)
|
||||
|
||||
# Second pass: keep non-stripped entries, reparenting where needed.
|
||||
# Preserve original line when no reparenting is required.
|
||||
reparented: set[str] = set()
|
||||
for _line, entry in parsed:
|
||||
if entry is None:
|
||||
if entry_type in STRIPPABLE_TYPES:
|
||||
if uid:
|
||||
stripped_uuids.add(uid)
|
||||
else:
|
||||
kept.append(entry)
|
||||
|
||||
# Reparent: walk up chain through stripped entries to find surviving ancestor
|
||||
for entry in kept:
|
||||
if "_raw" in entry:
|
||||
continue
|
||||
parent = entry.get("parentUuid", "")
|
||||
original_parent = parent
|
||||
@@ -99,32 +100,63 @@ def strip_progress_entries(content: str) -> str:
|
||||
parent = uuid_to_parent.get(parent, "")
|
||||
if parent != original_parent:
|
||||
entry["parentUuid"] = parent
|
||||
uid = entry.get("uuid", "")
|
||||
if uid:
|
||||
reparented.add(uid)
|
||||
|
||||
result_lines: list[str] = []
|
||||
for line, entry in parsed:
|
||||
if entry is None:
|
||||
result_lines.append(line)
|
||||
continue
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES:
|
||||
continue
|
||||
uid = entry.get("uuid", "")
|
||||
if uid in reparented:
|
||||
# Re-serialize only entries whose parentUuid was changed.
|
||||
result_lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
for entry in kept:
|
||||
if "_raw" in entry:
|
||||
result_lines.append(entry["_raw"])
|
||||
else:
|
||||
result_lines.append(line)
|
||||
result_lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
|
||||
return "\n".join(result_lines) + "\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local file I/O (write temp file for --resume)
|
||||
# Local file I/O (read from CLI's JSONL, write temp file for --resume)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def read_transcript_file(transcript_path: str) -> str | None:
|
||||
"""Read a JSONL transcript file from disk.
|
||||
|
||||
Returns the raw JSONL content, or ``None`` if the file is missing, empty,
|
||||
or only contains metadata (≤2 lines with no conversation messages).
|
||||
"""
|
||||
if not transcript_path or not os.path.isfile(transcript_path):
|
||||
logger.debug(f"[Transcript] File not found: {transcript_path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(transcript_path) as f:
|
||||
content = f.read()
|
||||
|
||||
if not content.strip():
|
||||
logger.debug("[Transcript] File is empty: %s", transcript_path)
|
||||
return None
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
|
||||
# Validate that the transcript has real conversation content
|
||||
# (not just metadata like queue-operation entries).
|
||||
if not validate_transcript(content):
|
||||
logger.debug(
|
||||
"[Transcript] No conversation content (%d lines) in %s",
|
||||
len(lines),
|
||||
transcript_path,
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"[Transcript] Read {len(lines)} lines, "
|
||||
f"{len(content)} bytes from {transcript_path}"
|
||||
)
|
||||
return content
|
||||
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.warning(f"[Transcript] Failed to read {transcript_path}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
||||
"""Sanitize an ID for safe use in file paths.
|
||||
|
||||
@@ -139,6 +171,14 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
||||
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
||||
|
||||
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does.
|
||||
|
||||
The CLI replaces all non-alphanumeric characters with ``-``.
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""Remove the CLI's project directory for a specific working directory.
|
||||
|
||||
@@ -148,8 +188,7 @@ def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""
|
||||
import shutil
|
||||
|
||||
# Encode cwd the same way CLI does (replaces non-alphanumeric with -)
|
||||
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
cwd_encoded = _encode_cwd_for_cli(sdk_cwd)
|
||||
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
|
||||
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
|
||||
@@ -209,30 +248,32 @@ def write_transcript_to_tempfile(
|
||||
def validate_transcript(content: str | None) -> bool:
|
||||
"""Check that a transcript has actual conversation messages.
|
||||
|
||||
A valid transcript needs at least one assistant message (not just
|
||||
queue-operation / file-history-snapshot metadata). We do NOT require
|
||||
a ``type: "user"`` entry because with ``--resume`` the user's message
|
||||
is passed as a CLI query parameter and does not appear in the
|
||||
transcript file.
|
||||
A valid transcript for resume needs at least one user message and one
|
||||
assistant message (not just queue-operation / file-history-snapshot
|
||||
metadata).
|
||||
"""
|
||||
if not content or not content.strip():
|
||||
return False
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
if len(lines) < 2:
|
||||
return False
|
||||
|
||||
has_user = False
|
||||
has_assistant = False
|
||||
|
||||
for line in lines:
|
||||
if not line.strip():
|
||||
continue
|
||||
try:
|
||||
entry = json.loads(line)
|
||||
if entry.get("type") == "assistant":
|
||||
msg_type = entry.get("type")
|
||||
if msg_type == "user":
|
||||
has_user = True
|
||||
elif msg_type == "assistant":
|
||||
has_assistant = True
|
||||
except json.JSONDecodeError:
|
||||
return False
|
||||
|
||||
return has_assistant
|
||||
return has_user and has_assistant
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -287,48 +328,45 @@ async def upload_transcript(
|
||||
session_id: str,
|
||||
content: str,
|
||||
message_count: int = 0,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> None:
|
||||
"""Strip progress entries and upload complete transcript.
|
||||
"""Strip progress entries and upload transcript to bucket storage.
|
||||
|
||||
The transcript represents the FULL active context (atomic).
|
||||
Each upload REPLACES the previous transcript entirely.
|
||||
|
||||
The executor holds a cluster lock per session, so concurrent uploads for
|
||||
the same session cannot happen.
|
||||
Safety: only overwrites when the new (stripped) transcript is larger than
|
||||
what is already stored. Since JSONL is append-only, the latest transcript
|
||||
is always the longest. This prevents a slow/stale background task from
|
||||
clobbering a newer upload from a concurrent turn.
|
||||
|
||||
Args:
|
||||
content: Complete JSONL transcript (from TranscriptBuilder).
|
||||
message_count: ``len(session.messages)`` at upload time.
|
||||
message_count: ``len(session.messages)`` at upload time — used by
|
||||
the next turn to detect staleness and compress only the gap.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
# Strip metadata entries (progress, file-history-snapshot, etc.)
|
||||
# Note: SDK-built transcripts shouldn't have these, but strip for safety
|
||||
stripped = strip_progress_entries(content)
|
||||
if not validate_transcript(stripped):
|
||||
# Log entry types for debugging — helps identify why validation failed
|
||||
entry_types: list[str] = []
|
||||
for line in stripped.strip().split("\n"):
|
||||
try:
|
||||
entry_types.append(json.loads(line).get("type", "?"))
|
||||
except json.JSONDecodeError:
|
||||
entry_types.append("INVALID_JSON")
|
||||
logger.warning(
|
||||
"%s Skipping upload — stripped content not valid "
|
||||
"(types=%s, stripped_len=%d, raw_len=%d)",
|
||||
log_prefix,
|
||||
entry_types,
|
||||
len(stripped),
|
||||
len(content),
|
||||
f"[Transcript] Skipping upload — stripped content not valid "
|
||||
f"for session {session_id}"
|
||||
)
|
||||
logger.debug("%s Raw content preview: %s", log_prefix, content[:500])
|
||||
logger.debug("%s Stripped content: %s", log_prefix, stripped[:500])
|
||||
return
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
wid, fid, fname = _storage_path_parts(user_id, session_id)
|
||||
encoded = stripped.encode("utf-8")
|
||||
new_size = len(encoded)
|
||||
|
||||
# Check existing transcript size to avoid overwriting newer with older
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
try:
|
||||
existing = await storage.retrieve(path)
|
||||
if len(existing) >= new_size:
|
||||
logger.info(
|
||||
f"[Transcript] Skipping upload — existing ({len(existing)}B) "
|
||||
f">= new ({new_size}B) for session {session_id}"
|
||||
)
|
||||
return
|
||||
except (FileNotFoundError, Exception):
|
||||
pass # No existing transcript or retrieval error — proceed with upload
|
||||
|
||||
await storage.store(
|
||||
workspace_id=wid,
|
||||
@@ -337,8 +375,11 @@ async def upload_transcript(
|
||||
content=encoded,
|
||||
)
|
||||
|
||||
# Update metadata so message_count stays current. The gap-fill logic
|
||||
# in _build_query_message relies on it to avoid re-compressing messages.
|
||||
# Store metadata alongside the transcript so the next turn can detect
|
||||
# staleness and only compress the gap instead of the full history.
|
||||
# Wrapped in try/except so a metadata write failure doesn't orphan
|
||||
# the already-uploaded transcript — the next turn will just fall back
|
||||
# to full gap fill (msg_count=0).
|
||||
try:
|
||||
meta = {"message_count": message_count, "uploaded_at": time.time()}
|
||||
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
|
||||
@@ -349,18 +390,17 @@ async def upload_transcript(
|
||||
content=json.dumps(meta).encode("utf-8"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
|
||||
logger.warning(f"[Transcript] Failed to write metadata for {session_id}: {e}")
|
||||
|
||||
logger.info(
|
||||
f"{log_prefix} Uploaded {len(encoded)}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count})"
|
||||
f"[Transcript] Uploaded {new_size}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count}) "
|
||||
f"for session {session_id}"
|
||||
)
|
||||
|
||||
|
||||
async def download_transcript(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
user_id: str, session_id: str
|
||||
) -> TranscriptDownload | None:
|
||||
"""Download transcript and metadata from bucket storage.
|
||||
|
||||
@@ -376,10 +416,10 @@ async def download_transcript(
|
||||
data = await storage.retrieve(path)
|
||||
content = data.decode("utf-8")
|
||||
except FileNotFoundError:
|
||||
logger.debug(f"{log_prefix} No transcript in storage")
|
||||
logger.debug(f"[Transcript] No transcript in storage for {session_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
|
||||
logger.warning(f"[Transcript] Failed to download transcript: {e}")
|
||||
return None
|
||||
|
||||
# Try to load metadata (best-effort — old transcripts won't have it)
|
||||
@@ -402,7 +442,10 @@ async def download_transcript(
|
||||
except (FileNotFoundError, json.JSONDecodeError, Exception):
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
|
||||
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
|
||||
logger.info(
|
||||
f"[Transcript] Downloaded {len(content)}B "
|
||||
f"(msg_count={message_count}) for session {session_id}"
|
||||
)
|
||||
return TranscriptDownload(
|
||||
content=content,
|
||||
message_count=message_count,
|
||||
|
||||
@@ -1,140 +0,0 @@
|
||||
"""Build complete JSONL transcript from SDK messages.
|
||||
|
||||
The transcript represents the FULL active context at any point in time.
|
||||
Each upload REPLACES the previous transcript atomically.
|
||||
|
||||
Flow:
|
||||
Turn 1: Upload [msg1, msg2]
|
||||
Turn 2: Download [msg1, msg2] → Upload [msg1, msg2, msg3, msg4] (REPLACE)
|
||||
Turn 3: Download [msg1, msg2, msg3, msg4] → Upload [all messages] (REPLACE)
|
||||
|
||||
The transcript is never incremental - always the complete atomic state.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TranscriptEntry(BaseModel):
|
||||
"""Single transcript entry (user or assistant turn)."""
|
||||
|
||||
type: str
|
||||
uuid: str
|
||||
parentUuid: str | None
|
||||
message: dict[str, Any]
|
||||
|
||||
|
||||
class TranscriptBuilder:
|
||||
"""Build complete JSONL transcript from SDK messages.
|
||||
|
||||
This builder maintains the FULL conversation state, not incremental changes.
|
||||
The output is always the complete active context.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._entries: list[TranscriptEntry] = []
|
||||
self._last_uuid: str | None = None
|
||||
|
||||
def load_previous(self, content: str) -> None:
|
||||
"""Load complete previous transcript.
|
||||
|
||||
This loads the FULL previous context. As new messages come in,
|
||||
we append to this state. The final output is the complete context
|
||||
(previous + new), not just the delta.
|
||||
"""
|
||||
if not content or not content.strip():
|
||||
return
|
||||
|
||||
for line in content.strip().split("\n"):
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
data = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Failed to parse transcript line: %s", line[:100])
|
||||
continue
|
||||
|
||||
# Only load conversation messages (user/assistant)
|
||||
# Skip metadata entries
|
||||
if data.get("type") not in ("user", "assistant"):
|
||||
continue
|
||||
|
||||
entry = TranscriptEntry(
|
||||
type=data["type"],
|
||||
uuid=data.get("uuid") or str(uuid4()),
|
||||
parentUuid=data.get("parentUuid"),
|
||||
message=data.get("message", {}),
|
||||
)
|
||||
self._entries.append(entry)
|
||||
self._last_uuid = entry.uuid
|
||||
|
||||
logger.info(
|
||||
"Loaded %d entries from previous transcript (last_uuid=%s)",
|
||||
len(self._entries),
|
||||
self._last_uuid[:12] if self._last_uuid else None,
|
||||
)
|
||||
|
||||
def add_user_message(
|
||||
self, content: str | list[dict], uuid: str | None = None
|
||||
) -> None:
|
||||
"""Add user message to the complete context."""
|
||||
msg_uuid = uuid or str(uuid4())
|
||||
|
||||
self._entries.append(
|
||||
TranscriptEntry(
|
||||
type="user",
|
||||
uuid=msg_uuid,
|
||||
parentUuid=self._last_uuid,
|
||||
message={"role": "user", "content": content},
|
||||
)
|
||||
)
|
||||
self._last_uuid = msg_uuid
|
||||
|
||||
def add_assistant_message(
|
||||
self, content_blocks: list[dict], model: str = ""
|
||||
) -> None:
|
||||
"""Add assistant message to the complete context."""
|
||||
msg_uuid = str(uuid4())
|
||||
|
||||
self._entries.append(
|
||||
TranscriptEntry(
|
||||
type="assistant",
|
||||
uuid=msg_uuid,
|
||||
parentUuid=self._last_uuid,
|
||||
message={
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": content_blocks,
|
||||
},
|
||||
)
|
||||
)
|
||||
self._last_uuid = msg_uuid
|
||||
|
||||
def to_jsonl(self) -> str:
|
||||
"""Export complete context as JSONL.
|
||||
|
||||
Returns the FULL conversation state (all entries), not incremental.
|
||||
This output REPLACES any previous transcript.
|
||||
"""
|
||||
if not self._entries:
|
||||
return ""
|
||||
|
||||
lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
@property
|
||||
def entry_count(self) -> int:
|
||||
"""Total number of entries in the complete context."""
|
||||
return len(self._entries)
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
"""Whether this builder has any entries."""
|
||||
return len(self._entries) == 0
|
||||
@@ -5,6 +5,7 @@ import os
|
||||
|
||||
from .transcript import (
|
||||
STRIPPABLE_TYPES,
|
||||
read_transcript_file,
|
||||
strip_progress_entries,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
@@ -37,6 +38,49 @@ PROGRESS_ENTRY = {
|
||||
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
|
||||
|
||||
|
||||
# --- read_transcript_file ---
|
||||
|
||||
|
||||
class TestReadTranscriptFile:
|
||||
def test_returns_content_for_valid_file(self, tmp_path):
|
||||
path = tmp_path / "session.jsonl"
|
||||
path.write_text(VALID_TRANSCRIPT)
|
||||
result = read_transcript_file(str(path))
|
||||
assert result is not None
|
||||
assert "user" in result
|
||||
|
||||
def test_returns_none_for_missing_file(self):
|
||||
assert read_transcript_file("/nonexistent/path.jsonl") is None
|
||||
|
||||
def test_returns_none_for_empty_path(self):
|
||||
assert read_transcript_file("") is None
|
||||
|
||||
def test_returns_none_for_empty_file(self, tmp_path):
|
||||
path = tmp_path / "empty.jsonl"
|
||||
path.write_text("")
|
||||
assert read_transcript_file(str(path)) is None
|
||||
|
||||
def test_returns_none_for_metadata_only(self, tmp_path):
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
|
||||
path = tmp_path / "meta.jsonl"
|
||||
path.write_text(content)
|
||||
assert read_transcript_file(str(path)) is None
|
||||
|
||||
def test_returns_none_for_invalid_json(self, tmp_path):
|
||||
path = tmp_path / "bad.jsonl"
|
||||
path.write_text("not json\n{}\n{}\n")
|
||||
assert read_transcript_file(str(path)) is None
|
||||
|
||||
def test_no_size_limit(self, tmp_path):
|
||||
"""Large files are accepted — bucket storage has no size limit."""
|
||||
big_content = {"type": "user", "uuid": "u9", "data": "x" * 1_000_000}
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, big_content, ASST_MSG)
|
||||
path = tmp_path / "big.jsonl"
|
||||
path.write_text(content)
|
||||
result = read_transcript_file(str(path))
|
||||
assert result is not None
|
||||
|
||||
|
||||
# --- write_transcript_to_tempfile ---
|
||||
|
||||
|
||||
@@ -111,56 +155,12 @@ class TestValidateTranscript:
|
||||
assert validate_transcript(content) is False
|
||||
|
||||
def test_assistant_only_no_user(self):
|
||||
"""With --resume the user message is a CLI query param, not a transcript entry.
|
||||
A transcript with only assistant entries is valid."""
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, ASST_MSG)
|
||||
assert validate_transcript(content) is True
|
||||
|
||||
def test_resume_transcript_without_user_entry(self):
|
||||
"""Simulates a real --resume stop hook transcript: the CLI session file
|
||||
has summary + assistant entries but no user entry."""
|
||||
summary = {"type": "summary", "uuid": "s1", "text": "context..."}
|
||||
asst1 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"message": {"role": "assistant", "content": "Hello!"},
|
||||
}
|
||||
asst2 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a2",
|
||||
"parentUuid": "a1",
|
||||
"message": {"role": "assistant", "content": "Sure, let me help."},
|
||||
}
|
||||
content = _make_jsonl(summary, asst1, asst2)
|
||||
assert validate_transcript(content) is True
|
||||
|
||||
def test_single_assistant_entry(self):
|
||||
"""A transcript with just one assistant line is valid — the CLI may
|
||||
produce short transcripts for simple responses with no tool use."""
|
||||
content = json.dumps(ASST_MSG) + "\n"
|
||||
assert validate_transcript(content) is True
|
||||
assert validate_transcript(content) is False
|
||||
|
||||
def test_invalid_json_returns_false(self):
|
||||
assert validate_transcript("not json\n{}\n{}\n") is False
|
||||
|
||||
def test_malformed_json_after_valid_assistant_returns_false(self):
|
||||
"""Validation must scan all lines - malformed JSON anywhere should fail."""
|
||||
valid_asst = json.dumps(ASST_MSG)
|
||||
malformed = "not valid json"
|
||||
content = valid_asst + "\n" + malformed + "\n"
|
||||
assert validate_transcript(content) is False
|
||||
|
||||
def test_blank_lines_are_skipped(self):
|
||||
"""Transcripts with blank lines should be valid if they contain assistant entries."""
|
||||
content = (
|
||||
json.dumps(USER_MSG)
|
||||
+ "\n\n" # blank line
|
||||
+ json.dumps(ASST_MSG)
|
||||
+ "\n"
|
||||
+ "\n" # another blank line
|
||||
)
|
||||
assert validate_transcript(content) is True
|
||||
|
||||
|
||||
# --- strip_progress_entries ---
|
||||
|
||||
@@ -253,32 +253,3 @@ class TestStripProgressEntries:
|
||||
assert "queue-operation" not in result_types
|
||||
assert "user" in result_types
|
||||
assert "assistant" in result_types
|
||||
|
||||
def test_preserves_original_line_formatting(self):
|
||||
"""Non-reparented entries keep their original JSON formatting."""
|
||||
# Use pretty-printed JSON with spaces (as the CLI produces)
|
||||
original_line = json.dumps(USER_MSG) # default formatting with spaces
|
||||
compact_line = json.dumps(USER_MSG, separators=(",", ":"))
|
||||
assert original_line != compact_line # precondition
|
||||
|
||||
content = original_line + "\n" + json.dumps(ASST_MSG) + "\n"
|
||||
result = strip_progress_entries(content)
|
||||
result_lines = result.strip().split("\n")
|
||||
|
||||
# Original line should be byte-identical (not re-serialized)
|
||||
assert result_lines[0] == original_line
|
||||
|
||||
def test_reparented_entries_are_reserialized(self):
|
||||
"""Entries whose parentUuid changes must be re-serialized."""
|
||||
progress = {"type": "progress", "uuid": "p1", "parentUuid": "u1"}
|
||||
asst = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "p1",
|
||||
"message": {"role": "assistant", "content": "done"},
|
||||
}
|
||||
content = _make_jsonl(USER_MSG, progress, asst)
|
||||
result = strip_progress_entries(content)
|
||||
lines = result.strip().split("\n")
|
||||
asst_entry = json.loads(lines[-1])
|
||||
assert asst_entry["parentUuid"] == "u1" # reparented
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,14 +4,75 @@ from os import getenv
|
||||
|
||||
import pytest
|
||||
|
||||
from . import service as chat_service
|
||||
from .model import create_chat_session, get_chat_session, upsert_chat_session
|
||||
from .response_model import StreamError, StreamTextDelta
|
||||
from .response_model import StreamError, StreamTextDelta, StreamToolOutputAvailable
|
||||
from .sdk import service as sdk_service
|
||||
from .sdk.transcript import download_transcript
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
|
||||
has_errors = False
|
||||
assistant_message = ""
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session.session_id, "Hello, how are you?", user_id=session.user_id
|
||||
):
|
||||
logger.info(chunk)
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
if isinstance(chunk, StreamTextDelta):
|
||||
assistant_message += chunk.delta
|
||||
|
||||
# StreamFinish is published by mark_session_completed (processor layer),
|
||||
# not by the service. The generator completing means the stream ended.
|
||||
assert not has_errors, "Error occurred while streaming chat completion"
|
||||
assert assistant_message, "Assistant message is empty"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
has_errors = False
|
||||
had_tool_calls = False
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session.session_id,
|
||||
"Please find me an agent that can help me with my business. Use the query 'moneny printing agent'",
|
||||
user_id=session.user_id,
|
||||
):
|
||||
logger.info(chunk)
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
if isinstance(chunk, StreamToolOutputAvailable):
|
||||
had_tool_calls = True
|
||||
|
||||
assert not has_errors, "Error occurred while streaming chat completion"
|
||||
assert had_tool_calls, "Tool calls did not occur"
|
||||
session = await get_chat_session(session.session_id)
|
||||
assert session, "Session not found"
|
||||
assert session.usage, "Usage is empty"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
|
||||
"""Test that the SDK --resume path captures and uses transcripts across turns.
|
||||
|
||||
@@ -733,10 +733,7 @@ async def mark_session_completed(
|
||||
# This is the SINGLE place that publishes StreamFinish — services and
|
||||
# the processor must NOT publish it themselves.
|
||||
try:
|
||||
await publish_chunk(
|
||||
turn_id,
|
||||
StreamFinish(),
|
||||
)
|
||||
await publish_chunk(turn_id, StreamFinish())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to publish StreamFinish for session {session_id}: {e}. "
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tracking import track_tool_called
|
||||
|
||||
from .add_understanding import AddUnderstandingTool
|
||||
from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreenshotTool
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
@@ -22,7 +20,6 @@ from .find_library_agent import FindLibraryAgentTool
|
||||
from .get_doc_page import GetDocPageTool
|
||||
from .run_agent import RunAgentTool
|
||||
from .run_block import RunBlockTool
|
||||
from .run_mcp_tool import RunMCPToolTool
|
||||
from .search_docs import SearchDocsTool
|
||||
from .web_fetch import WebFetchTool
|
||||
from .workspace_files import (
|
||||
@@ -33,7 +30,6 @@ from .workspace_files import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -49,16 +45,11 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"find_library_agent": FindLibraryAgentTool(),
|
||||
"run_agent": RunAgentTool(),
|
||||
"run_block": RunBlockTool(),
|
||||
"run_mcp_tool": RunMCPToolTool(),
|
||||
"view_agent_output": AgentOutputTool(),
|
||||
"search_docs": SearchDocsTool(),
|
||||
"get_doc_page": GetDocPageTool(),
|
||||
# Web fetch for safe URL retrieval
|
||||
"web_fetch": WebFetchTool(),
|
||||
# Agent-browser multi-step automation (navigate, act, screenshot)
|
||||
"browser_navigate": BrowserNavigateTool(),
|
||||
"browser_act": BrowserActTool(),
|
||||
"browser_screenshot": BrowserScreenshotTool(),
|
||||
# Sandboxed code execution (bubblewrap)
|
||||
"bash_exec": BashExecTool(),
|
||||
# Persistent workspace tools (cloud storage, survives across sessions)
|
||||
@@ -76,17 +67,10 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
find_agent_tool = TOOL_REGISTRY["find_agent"]
|
||||
run_agent_tool = TOOL_REGISTRY["run_agent"]
|
||||
|
||||
|
||||
def get_available_tools() -> list[ChatCompletionToolParam]:
|
||||
"""Return OpenAI tool schemas for tools available in the current environment.
|
||||
|
||||
Called per-request so that env-var or binary availability is evaluated
|
||||
fresh each time (e.g. browser_* tools are excluded when agent-browser
|
||||
CLI is not installed).
|
||||
"""
|
||||
return [
|
||||
tool.as_openai_tool() for tool in TOOL_REGISTRY.values() if tool.is_available
|
||||
]
|
||||
# Generated from registry for OpenAI API
|
||||
tools: list[ChatCompletionToolParam] = [
|
||||
tool.as_openai_tool() for tool in TOOL_REGISTRY.values()
|
||||
]
|
||||
|
||||
|
||||
def get_tool(tool_name: str) -> BaseTool | None:
|
||||
|
||||
@@ -1,876 +0,0 @@
|
||||
"""Agent-browser tools — multi-step browser automation for the Copilot.
|
||||
|
||||
Uses the agent-browser CLI (https://github.com/vercel-labs/agent-browser)
|
||||
which runs a local Chromium instance managed by a persistent daemon.
|
||||
|
||||
- Runs locally — no cloud account required
|
||||
- Full interaction support: click, fill, scroll, login flows, multi-step
|
||||
- Session persistence via --session-name: cookies/auth carry across tool calls
|
||||
within the same Copilot session, enabling login → navigate → extract workflows
|
||||
- Screenshot with --annotate overlays @ref labels, saved to workspace for user
|
||||
- The Claude Agent SDK's multi-turn loop handles orchestration — each tool call
|
||||
is one browser action; the LLM chains them naturally
|
||||
|
||||
SSRF protection:
|
||||
Uses the shared validate_url() from backend.util.request, which is the same
|
||||
guard used by HTTP blocks and web_fetch. It resolves ALL DNS answers (not just
|
||||
the first), blocks RFC 1918, loopback, link-local, 0.0.0.0/8, multicast,
|
||||
and all relevant IPv6 ranges, and applies IDNA encoding to prevent Unicode
|
||||
domain attacks.
|
||||
|
||||
Requires:
|
||||
npm install -g agent-browser
|
||||
agent-browser install (downloads Chromium, one-time per machine)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.util.request import validate_url
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
BrowserActResponse,
|
||||
BrowserNavigateResponse,
|
||||
BrowserScreenshotResponse,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .workspace_files import get_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Per-command timeout (seconds). Navigation + networkidle wait can be slow.
|
||||
_CMD_TIMEOUT = 45
|
||||
# Accessibility tree can be very large; cap it to keep LLM context manageable.
|
||||
_MAX_SNAPSHOT_CHARS = 20_000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subprocess helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _run(
|
||||
session_name: str,
|
||||
*args: str,
|
||||
timeout: int = _CMD_TIMEOUT,
|
||||
) -> tuple[int, str, str]:
|
||||
"""Run agent-browser for the given session and return (rc, stdout, stderr).
|
||||
|
||||
Uses both:
|
||||
--session <name> → isolated Chromium context (no shared history/cookies
|
||||
with other Copilot sessions — prevents cross-session
|
||||
browser state leakage)
|
||||
--session-name <name> → persist cookies/localStorage across tool calls within
|
||||
the same session (enables login → navigate flows)
|
||||
"""
|
||||
cmd = [
|
||||
"agent-browser",
|
||||
"--session",
|
||||
session_name,
|
||||
"--session-name",
|
||||
session_name,
|
||||
*args,
|
||||
]
|
||||
proc = None
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout)
|
||||
return proc.returncode or 0, stdout.decode(), stderr.decode()
|
||||
except asyncio.TimeoutError:
|
||||
# Kill the orphaned subprocess so it does not linger in the process table.
|
||||
if proc is not None and proc.returncode is None:
|
||||
proc.kill()
|
||||
try:
|
||||
await proc.communicate()
|
||||
except Exception:
|
||||
pass # Best-effort reap; ignore errors during cleanup.
|
||||
return 1, "", f"Command timed out after {timeout}s."
|
||||
except FileNotFoundError:
|
||||
return (
|
||||
1,
|
||||
"",
|
||||
"agent-browser is not installed (run: npm install -g agent-browser && agent-browser install).",
|
||||
)
|
||||
|
||||
|
||||
async def _snapshot(session_name: str) -> str:
    """Return the current page's interactive accessibility tree, truncated."""
    rc, out, err = await _run(session_name, "snapshot", "-i", "-c")
    if rc != 0:
        return f"[snapshot failed: {err[:300]}]"
    tree = out.strip()
    if len(tree) <= _MAX_SNAPSHOT_CHARS:
        return tree
    # Over budget: keep a prefix and append a pointer to further navigation.
    suffix = "\n\n[Snapshot truncated — use browser_act to navigate further]"
    budget = max(0, _MAX_SNAPSHOT_CHARS - len(suffix))
    return tree[:budget] + suffix
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stateless session helpers — persist / restore browser state across pods
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Module-level cache of sessions known to be alive on this pod.
# Avoids the subprocess probe on every tool call within the same pod.
_alive_sessions: set[str] = set()

# Per-session locks to prevent concurrent _ensure_session calls from
# triggering duplicate _restore_browser_state for the same session.
# Protected by _session_locks_mutex to ensure setdefault/pop are not
# interleaved across await boundaries.
# NOTE(review): both locks are created at import time; asyncio.Lock is
# loop-agnostic on Python 3.10+ — confirm the service never drives this
# module from more than one event loop.
_session_locks: dict[str, asyncio.Lock] = {}
_session_locks_mutex = asyncio.Lock()

# Workspace filename for persisted browser state (auto-scoped to session).
# Dot-prefixed so it is hidden from user workspace listings.
_STATE_FILENAME = "._browser_state.json"

# Maximum concurrent subprocesses during cookie/storage restore.
_RESTORE_CONCURRENCY = 10

# Maximum cookies to restore per session. Pathological sites can accumulate
# thousands of cookies; restoring them all would be slow and is rarely useful.
_MAX_RESTORE_COOKIES = 100

# Background tasks for fire-and-forget state persistence.
# Prevents GC from collecting tasks before they complete.
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
def _fire_and_forget_save(
    session_name: str, user_id: str, session: ChatSession
) -> None:
    """Kick off state persistence as a detached background task.

    _save_browser_state is already best-effort (it swallows its own errors),
    so detaching it simply removes latency from the tool response path.
    """
    save_task = asyncio.create_task(
        _save_browser_state(session_name, user_id, session)
    )
    # Hold a strong reference until completion so the task is not GC'd mid-flight.
    _background_tasks.add(save_task)
    save_task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
|
||||
async def _has_local_session(session_name: str) -> bool:
    """Probe whether this pod already has a live agent-browser daemon."""
    # A cheap read-only command succeeds iff the daemon is up.
    returncode, _, _ = await _run(session_name, "get", "url", timeout=5)
    return returncode == 0
|
||||
|
||||
|
||||
async def _save_browser_state(
    session_name: str, user_id: str, session: ChatSession
) -> None:
    """Persist browser state (cookies, localStorage, URL) to workspace.

    Best-effort: errors are logged but never propagate to the tool response.
    """
    try:
        # Collect URL, cookies, and localStorage concurrently.
        url_res, cookies_res, storage_res = await asyncio.gather(
            _run(session_name, "get", "url", timeout=10),
            _run(session_name, "cookies", "get", "--json", timeout=10),
            _run(session_name, "storage", "local", "--json", timeout=10),
        )
        rc_url, url_out, _ = url_res
        rc_ck, ck_out, _ = cookies_res
        rc_ls, ls_out, _ = storage_res

        # Any sub-command that failed contributes an empty default instead.
        state = {
            "url": url_out.strip() if rc_url == 0 else "",
            "cookies": (json.loads(ck_out) if rc_ck == 0 and ck_out.strip() else []),
            "local_storage": (
                json.loads(ls_out) if rc_ls == 0 and ls_out.strip() else {}
            ),
        }

        manager = await get_manager(user_id, session.session_id)
        await manager.write_file(
            content=json.dumps(state).encode("utf-8"),
            filename=_STATE_FILENAME,
            mime_type="application/json",
            overwrite=True,
        )
    except Exception:
        logger.warning(
            "[browser] Failed to save browser state for session %s",
            session_name,
            exc_info=True,
        )
|
||||
|
||||
|
||||
async def _restore_browser_state(
    session_name: str, user_id: str, session: ChatSession
) -> bool:
    """Restore browser state from workspace storage into a fresh daemon.

    Reads the JSON written by ``_save_browser_state`` (URL, cookies,
    localStorage), re-opens the saved URL, then replays cookies and
    localStorage entries via agent-browser subcommands.

    Best-effort: errors are logged but never propagate to the tool response.
    Returns True on success (or no state to restore), False on failure.
    """
    try:
        manager = await get_manager(user_id, session.session_id)

        file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
        if file_info is None:
            return True  # No saved state — first call or never saved

        state_bytes = await manager.read_file(_STATE_FILENAME)
        state = json.loads(state_bytes.decode("utf-8"))

        url = state.get("url", "")
        cookies = state.get("cookies", [])
        local_storage = state.get("local_storage", {})

        # Navigate first — starts daemon + sets the correct origin for cookies
        if url:
            # Validate the saved URL to prevent SSRF via stored redirect targets.
            try:
                await validate_url(url, trusted_origins=[])
            except ValueError:
                logger.warning(
                    "[browser] State restore: blocked SSRF URL %s", url[:200]
                )
                return False

            rc, _, stderr = await _run(session_name, "open", url)
            if rc != 0:
                logger.warning(
                    "[browser] State restore: failed to open %s: %s",
                    url,
                    stderr[:200],
                )
                return False
            # Best-effort wait for the page load before replaying state.
            await _run(session_name, "wait", "--load", "load", timeout=15)

        # Restore cookies and localStorage in parallel via asyncio.gather.
        # Semaphore caps concurrent subprocess spawns so we don't overwhelm the
        # system when a session has hundreds of cookies.
        sem = asyncio.Semaphore(_RESTORE_CONCURRENCY)

        # Guard against pathological sites with thousands of cookies.
        if len(cookies) > _MAX_RESTORE_COOKIES:
            logger.debug(
                "[browser] State restore: capping cookies from %d to %d",
                len(cookies),
                _MAX_RESTORE_COOKIES,
            )
            cookies = cookies[:_MAX_RESTORE_COOKIES]

        async def _set_cookie(c: dict[str, Any]) -> None:
            """Replay a single saved cookie; skips entries missing name/domain."""
            name = c.get("name", "")
            value = c.get("value", "")
            domain = c.get("domain", "")
            path = c.get("path", "/")
            if not (name and domain):
                return
            async with sem:
                rc, _, stderr = await _run(
                    session_name,
                    "cookies",
                    "set",
                    name,
                    value,
                    "--domain",
                    domain,
                    "--path",
                    path,
                    timeout=5,
                )
                if rc != 0:
                    # Individual cookie failures are non-fatal; log and continue.
                    logger.debug(
                        "[browser] State restore: cookie set failed for %s: %s",
                        name,
                        stderr[:100],
                    )

        async def _set_storage(key: str, val: object) -> None:
            """Replay one localStorage entry (value is stringified for the CLI)."""
            async with sem:
                rc, _, stderr = await _run(
                    session_name,
                    "storage",
                    "local",
                    "set",
                    key,
                    str(val),
                    timeout=5,
                )
                if rc != 0:
                    logger.debug(
                        "[browser] State restore: localStorage set failed for %s: %s",
                        key,
                        stderr[:100],
                    )

        await asyncio.gather(
            *[_set_cookie(c) for c in cookies],
            *[_set_storage(k, v) for k, v in local_storage.items()],
        )

        return True
    except Exception:
        logger.warning(
            "[browser] Failed to restore browser state for session %s",
            session_name,
            exc_info=True,
        )
        return False
|
||||
|
||||
|
||||
async def _ensure_session(
    session_name: str, user_id: str, session: ChatSession
) -> None:
    """Ensure the local browser daemon has state. Restore from cloud if needed."""
    if session_name in _alive_sessions:
        return  # Fast path: already confirmed alive on this pod.

    # Acquire this session's lock; the mutex keeps setdefault atomic.
    async with _session_locks_mutex:
        session_lock = _session_locks.setdefault(session_name, asyncio.Lock())
    async with session_lock:
        # Re-check under the lock — a concurrent call may have finished already.
        if session_name in _alive_sessions:
            return
        if await _has_local_session(session_name):
            _alive_sessions.add(session_name)
        elif await _restore_browser_state(session_name, user_id, session):
            _alive_sessions.add(session_name)
|
||||
|
||||
|
||||
async def close_browser_session(session_name: str, user_id: str | None = None) -> None:
    """Shut down the local agent-browser daemon and clean up stored state.

    Deletes ``._browser_state.json`` from workspace storage so cookies and
    other credentials do not linger after the session is deleted.

    Best-effort: errors are logged but never raised.
    """
    # Drop local bookkeeping first so subsequent calls re-probe the daemon.
    _alive_sessions.discard(session_name)
    async with _session_locks_mutex:
        _session_locks.pop(session_name, None)

    # Delete persisted browser state (cookies, localStorage) from workspace.
    if user_id:
        try:
            manager = await get_manager(user_id, session_name)
            info = await manager.get_file_info_by_path(_STATE_FILENAME)
            if info is not None:
                await manager.delete_file(info.id)
        except Exception:
            logger.debug(
                "[browser] Failed to delete state file for session %s",
                session_name,
                exc_info=True,
            )

    # Ask the daemon to shut down; any failure is logged, never raised.
    try:
        returncode, _, err = await _run(session_name, "close", timeout=10)
        if returncode != 0:
            logger.debug(
                "[browser] close failed for session %s: %s",
                session_name,
                err[:200],
            )
    except Exception:
        logger.debug(
            "[browser] Exception closing browser session %s",
            session_name,
            exc_info=True,
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool: browser_navigate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BrowserNavigateTool(BaseTool):
    """Navigate to a URL and return the page's interactive elements.

    The browser session persists across tool calls within this Copilot session
    (keyed to session_id), so cookies and auth state carry over. This enables
    full login flows: navigate to login page → browser_act to fill credentials
    → browser_act to submit → browser_navigate to the target page.
    """

    @property
    def name(self) -> str:
        """Tool identifier exposed to the LLM."""
        return "browser_navigate"

    @property
    def description(self) -> str:
        """Natural-language tool description shown to the LLM."""
        return (
            "Navigate to a URL using a real browser. Returns an accessibility "
            "tree snapshot listing the page's interactive elements with @ref IDs "
            "(e.g. @e3) that can be used with browser_act. "
            "Session persists — cookies and login state carry over between calls. "
            "Use this (with browser_act) for multi-step interaction: login flows, "
            "form filling, button clicks, or anything requiring page interaction. "
            "For plain static pages, prefer web_fetch — no browser overhead. "
            "For authenticated pages: navigate to the login page first, use browser_act "
            "to fill credentials and submit, then navigate to the target page. "
            "Note: for slow SPAs, the returned snapshot may reflect a partially-loaded "
            "state. If elements seem missing, use browser_act with action='wait' and a "
            "CSS selector or millisecond delay, then take a browser_screenshot to verify."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON Schema for the function-call arguments."""
        return {
            "type": "object",
            "properties": {
                "url": {
                    "type": "string",
                    "description": "The HTTP/HTTPS URL to navigate to.",
                },
                "wait_for": {
                    "type": "string",
                    "enum": ["networkidle", "load", "domcontentloaded"],
                    "default": "networkidle",
                    "description": "When to consider navigation complete. Use 'networkidle' for SPAs (default).",
                },
            },
            "required": ["url"],
        }

    @property
    def requires_auth(self) -> bool:
        """Browsing is only offered to authenticated users."""
        return True

    @property
    def is_available(self) -> bool:
        # Only advertise the tool when the agent-browser CLI is on PATH.
        return shutil.which("agent-browser") is not None

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Navigate to *url*, wait for the page to settle, and return a snapshot.

        The snapshot is an accessibility-tree listing of interactive elements.
        Note: for slow SPAs that never fully idle, the snapshot may reflect a
        partially-loaded state (the wait is best-effort).
        """
        url: str = (kwargs.get("url") or "").strip()
        wait_for: str = kwargs.get("wait_for") or "networkidle"
        session_name = session.session_id

        if not url:
            return ErrorResponse(
                message="Please provide a URL to navigate to.",
                error="missing_url",
                session_id=session_name,
            )

        # SSRF guard: reject private/blocked URLs before touching the browser.
        try:
            await validate_url(url, trusted_origins=[])
        except ValueError as e:
            return ErrorResponse(
                message=str(e),
                error="blocked_url",
                session_id=session_name,
            )

        # Restore browser state from cloud if this is a different pod
        if user_id:
            await _ensure_session(session_name, user_id, session)

        # Navigate
        rc, _, stderr = await _run(session_name, "open", url)
        if rc != 0:
            logger.warning(
                "[browser_navigate] open failed for %s: %s", url, stderr[:300]
            )
            return ErrorResponse(
                message="Failed to navigate to URL.",
                error="navigation_failed",
                session_id=session_name,
            )

        # Wait for page to settle (best-effort: some SPAs never reach networkidle)
        wait_rc, _, wait_err = await _run(session_name, "wait", "--load", wait_for)
        if wait_rc != 0:
            logger.warning(
                "[browser_navigate] wait(%s) failed: %s", wait_for, wait_err[:300]
            )

        # Get current title and URL in parallel
        (_, title_out, _), (_, url_out, _) = await asyncio.gather(
            _run(session_name, "get", "title"),
            _run(session_name, "get", "url"),
        )

        snapshot = await _snapshot(session_name)

        result = BrowserNavigateResponse(
            message=f"Navigated to {url}",
            # Prefer the browser-reported URL (post-redirect) over the input.
            url=url_out.strip() or url,
            title=title_out.strip(),
            snapshot=snapshot,
            session_id=session_name,
        )

        # Persist browser state to cloud for cross-pod continuity
        if user_id:
            _fire_and_forget_save(session_name, user_id, session)

        return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool: browser_act
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Action groups used by BrowserActTool to validate target/value requirements.
# Actions that take no target or value at all.
_NO_TARGET_ACTIONS = frozenset({"back", "forward", "reload"})
# Actions that take a scroll direction instead of a target.
_SCROLL_ACTIONS = frozenset({"scroll"})
# Actions requiring only a 'target' element.
_TARGET_ONLY_ACTIONS = frozenset({"click", "dblclick", "hover", "check", "uncheck"})
# Actions requiring both 'target' and 'value'.
_TARGET_VALUE_ACTIONS = frozenset({"fill", "type", "select"})
# wait <selector|ms>: waits for a DOM element or a fixed delay (e.g. "1000" for 1 s)
_WAIT_ACTIONS = frozenset({"wait"})
|
||||
|
||||
class BrowserActTool(BaseTool):
    """Perform an action on the current browser page and return the updated snapshot.

    Use @ref IDs from the snapshot returned by browser_navigate (e.g. '@e3').
    The LLM orchestrates multi-step flows by chaining browser_navigate and
    browser_act calls across turns of the Claude Agent SDK conversation.
    """

    @property
    def name(self) -> str:
        """Tool identifier exposed to the LLM."""
        return "browser_act"

    @property
    def description(self) -> str:
        """Natural-language tool description shown to the LLM."""
        return (
            "Interact with the current browser page. Use @ref IDs from the "
            "snapshot (e.g. '@e3') to target elements. Returns an updated snapshot. "
            "Supported actions: click, dblclick, fill, type, scroll, hover, press, "
            "check, uncheck, select, wait, back, forward, reload. "
            "fill clears the field before typing; type appends without clearing. "
            "wait accepts a CSS selector (waits for element) or milliseconds string (e.g. '1000'). "
            "Example login flow: fill @e1 with email → fill @e2 with password → "
            "click @e3 (submit) → browser_navigate to the target page."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON Schema for the function-call arguments."""
        return {
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": [
                        "click",
                        "dblclick",
                        "fill",
                        "type",
                        "scroll",
                        "hover",
                        "press",
                        "check",
                        "uncheck",
                        "select",
                        "wait",
                        "back",
                        "forward",
                        "reload",
                    ],
                    "description": "The action to perform.",
                },
                "target": {
                    "type": "string",
                    "description": (
                        "Element to target. Use @ref from snapshot (e.g. '@e3'), "
                        "a CSS selector, or a text description. "
                        "Required for: click, dblclick, fill, type, hover, check, uncheck, select. "
                        "For wait: a CSS selector to wait for, or milliseconds as a string (e.g. '1000')."
                    ),
                },
                "value": {
                    "type": "string",
                    "description": (
                        "For fill/type: the text to enter. "
                        "For press: key name (e.g. 'Enter', 'Tab', 'Control+a'). "
                        "For select: the option value to select."
                    ),
                },
                "direction": {
                    "type": "string",
                    "enum": ["up", "down", "left", "right"],
                    "default": "down",
                    "description": "For scroll: direction to scroll.",
                },
            },
            "required": ["action"],
        }

    @property
    def requires_auth(self) -> bool:
        """Browsing is only offered to authenticated users."""
        return True

    @property
    def is_available(self) -> bool:
        # Only advertise the tool when the agent-browser CLI is on PATH.
        return shutil.which("agent-browser") is not None

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Perform a browser action and return an updated page snapshot.

        Validates the *action*/*target*/*value* combination, delegates to
        ``agent-browser``, waits for the page to settle, and returns the
        accessibility-tree snapshot so the LLM can plan the next step.
        """
        action: str = (kwargs.get("action") or "").strip()
        target: str = (kwargs.get("target") or "").strip()
        value: str = (kwargs.get("value") or "").strip()
        direction: str = (kwargs.get("direction") or "down").strip()
        session_name = session.session_id

        if not action:
            return ErrorResponse(
                message="Please specify an action.",
                error="missing_action",
                session_id=session_name,
            )

        # Build the agent-browser command args
        if action in _NO_TARGET_ACTIONS:
            cmd_args = [action]

        elif action in _SCROLL_ACTIONS:
            cmd_args = ["scroll", direction]

        elif action == "press":
            # press is target-less but requires a key name in 'value'.
            if not value:
                return ErrorResponse(
                    message="'press' requires a 'value' (key name, e.g. 'Enter').",
                    error="missing_value",
                    session_id=session_name,
                )
            cmd_args = ["press", value]

        elif action in _TARGET_ONLY_ACTIONS:
            if not target:
                return ErrorResponse(
                    message=f"'{action}' requires a 'target' element.",
                    error="missing_target",
                    session_id=session_name,
                )
            cmd_args = [action, target]

        elif action in _TARGET_VALUE_ACTIONS:
            if not target or not value:
                return ErrorResponse(
                    message=f"'{action}' requires both 'target' and 'value'.",
                    error="missing_params",
                    session_id=session_name,
                )
            cmd_args = [action, target, value]

        elif action in _WAIT_ACTIONS:
            if not target:
                return ErrorResponse(
                    message=(
                        "'wait' requires a 'target': a CSS selector to wait for, "
                        "or milliseconds as a string (e.g. '1000')."
                    ),
                    error="missing_target",
                    session_id=session_name,
                )
            cmd_args = ["wait", target]

        else:
            return ErrorResponse(
                message=f"Unsupported action: {action}",
                error="invalid_action",
                session_id=session_name,
            )

        # Restore browser state from cloud if this is a different pod
        if user_id:
            await _ensure_session(session_name, user_id, session)

        rc, _, stderr = await _run(session_name, *cmd_args)
        if rc != 0:
            logger.warning("[browser_act] %s failed: %s", action, stderr[:300])
            return ErrorResponse(
                message=f"Action '{action}' failed.",
                error="action_failed",
                session_id=session_name,
            )

        # Allow the page to settle after interaction (best-effort: SPAs may not idle)
        settle_rc, _, settle_err = await _run(
            session_name, "wait", "--load", "networkidle"
        )
        if settle_rc != 0:
            logger.warning(
                "[browser_act] post-action wait failed: %s", settle_err[:300]
            )

        snapshot = await _snapshot(session_name)
        _, url_out, _ = await _run(session_name, "get", "url")

        result = BrowserActResponse(
            message=f"Performed '{action}'" + (f" on '{target}'" if target else ""),
            action=action,
            current_url=url_out.strip(),
            snapshot=snapshot,
            session_id=session_name,
        )

        # Persist browser state to cloud for cross-pod continuity
        if user_id:
            _fire_and_forget_save(session_name, user_id, session)

        return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool: browser_screenshot
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BrowserScreenshotTool(BaseTool):
    """Capture a screenshot of the current browser page and save it to the workspace."""

    @property
    def name(self) -> str:
        """Tool identifier exposed to the LLM."""
        return "browser_screenshot"

    @property
    def description(self) -> str:
        """Natural-language tool description shown to the LLM."""
        return (
            "Take a screenshot of the current browser page and save it to the workspace. "
            "IMPORTANT: After calling this tool, immediately call read_workspace_file "
            "with the returned file_id to display the image inline to the user — "
            "the screenshot is not visible until you do this. "
            "With annotate=true (default), @ref labels are overlaid on interactive "
            "elements, making it easy to see which @ref ID maps to which element on screen."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON Schema for the function-call arguments."""
        return {
            "type": "object",
            "properties": {
                "annotate": {
                    "type": "boolean",
                    "default": True,
                    "description": "Overlay @ref labels on interactive elements (default: true).",
                },
                "filename": {
                    "type": "string",
                    "default": "screenshot.png",
                    "description": "Filename to save in the workspace.",
                },
            },
        }

    @property
    def requires_auth(self) -> bool:
        """Browsing is only offered to authenticated users."""
        return True

    @property
    def is_available(self) -> bool:
        # Only advertise the tool when the agent-browser CLI is on PATH.
        return shutil.which("agent-browser") is not None

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Capture a PNG screenshot and upload it to the workspace.

        Handles string-to-bool coercion for *annotate* (OpenAI function-call
        payloads sometimes deliver ``"true"``/``"false"`` as strings).
        Returns a :class:`BrowserScreenshotResponse` with the workspace
        ``file_id`` the LLM should pass to ``read_workspace_file``.
        """
        raw_annotate = kwargs.get("annotate", True)
        if isinstance(raw_annotate, str):
            annotate = raw_annotate.strip().lower() in {"1", "true", "yes", "on"}
        else:
            annotate = bool(raw_annotate)
        filename: str = (kwargs.get("filename") or "screenshot.png").strip()
        session_name = session.session_id

        # Restore browser state from cloud if this is a different pod
        if user_id:
            await _ensure_session(session_name, user_id, session)

        tmp_fd, tmp_path = tempfile.mkstemp(suffix=".png")
        os.close(tmp_fd)
        try:
            cmd_args = ["screenshot"]
            if annotate:
                cmd_args.append("--annotate")
            cmd_args.append(tmp_path)

            rc, _, stderr = await _run(session_name, *cmd_args)
            if rc != 0:
                logger.warning("[browser_screenshot] failed: %s", stderr[:300])
                return ErrorResponse(
                    message="Failed to take screenshot.",
                    error="screenshot_failed",
                    session_id=session_name,
                )

            with open(tmp_path, "rb") as f:
                png_bytes = f.read()

        finally:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass  # Best-effort temp file cleanup; not critical if it fails.

        # Upload to workspace so the user can view it
        png_b64 = base64.b64encode(png_bytes).decode()

        # Import here to avoid circular deps — workspace_files imports from .models
        from .workspace_files import WorkspaceWriteResponse, WriteWorkspaceFileTool

        write_resp = await WriteWorkspaceFileTool()._execute(
            user_id=user_id,
            session=session,
            filename=filename,
            content_base64=png_b64,
        )

        if not isinstance(write_resp, WorkspaceWriteResponse):
            return ErrorResponse(
                message="Screenshot taken but failed to save to workspace.",
                error="workspace_write_failed",
                session_id=session_name,
            )

        result = BrowserScreenshotResponse(
            # Fix: the message previously contained a literal "(unknown)" placeholder
            # where the saved filename belongs; interpolate the actual filename,
            # matching the `filename` field returned below.
            message=(
                f"Screenshot saved to workspace as '{filename}'. "
                f"Use read_workspace_file with file_id='{write_resp.file_id}' to retrieve it."
            ),
            file_id=write_resp.file_id,
            filename=filename,
            session_id=session_name,
        )

        # Persist browser state to cloud for cross-pod continuity
        if user_id:
            _fire_and_forget_save(session_name, user_id, session)

        return result
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,5 @@
|
||||
"""Base classes and shared utilities for chat tools."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@@ -8,98 +7,11 @@ from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.truncate import truncate
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
|
||||
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Persist full tool output to workspace when it exceeds this threshold.
# Must be below _MAX_TOOL_OUTPUT_SIZE (100K) in response_model.py so we
# capture the data before model_post_init middle-out truncation discards it.
_LARGE_OUTPUT_THRESHOLD = 80_000

# Character budget for the middle-out preview. The total preview + wrapper
# must stay below BOTH:
#   - _MAX_TOOL_OUTPUT_SIZE (100K) in response_model.py (our own truncation)
#   - Claude SDK's ~100 KB tool-result spill-to-disk threshold
# to avoid double truncation/spilling. 95K + ~300 wrapper = ~95.3K, under both.
_PREVIEW_CHARS = 95_000
|
||||
|
||||
|
||||
# Fields whose values are binary/base64 data — truncating them produces
|
||||
# garbage, so we replace them with a human-readable size summary instead.
|
||||
_BINARY_FIELD_NAMES = {"content_base64"}
|
||||
|
||||
|
||||
def _summarize_binary_fields(raw_json: str) -> str:
|
||||
"""Replace known binary fields with a size summary so truncate() doesn't
|
||||
produce garbled base64 in the middle-out preview."""
|
||||
try:
|
||||
data = json.loads(raw_json)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return raw_json
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return raw_json
|
||||
|
||||
changed = False
|
||||
for key in _BINARY_FIELD_NAMES:
|
||||
if key in data and isinstance(data[key], str) and len(data[key]) > 1_000:
|
||||
byte_size = len(data[key]) * 3 // 4 # approximate decoded size
|
||||
data[key] = f"<binary, ~{byte_size:,} bytes>"
|
||||
changed = True
|
||||
|
||||
return json.dumps(data, ensure_ascii=False) if changed else raw_json
|
||||
|
||||
|
||||
async def _persist_and_summarize(
    raw_output: str,
    user_id: str,
    session_id: str,
    tool_call_id: str,
) -> str:
    """Persist full output to workspace and return a middle-out preview with retrieval instructions.

    On failure, returns the original ``raw_output`` unchanged so that the
    existing ``model_post_init`` middle-out truncation handles it as before.
    """
    # Stable per-call path so the LLM can retrieve the full output later.
    file_path = f"tool-outputs/{tool_call_id}.json"
    try:
        workspace = await workspace_db().get_or_create_workspace(user_id)
        manager = WorkspaceManager(user_id, workspace.id, session_id)
        await manager.write_file(
            content=raw_output.encode("utf-8"),
            filename=f"{tool_call_id}.json",
            path=file_path,
            mime_type="application/json",
            overwrite=True,
        )
    except Exception:
        logger.warning(
            "Failed to persist large tool output for %s",
            tool_call_id,
            exc_info=True,
        )
        return raw_output  # fall back to normal truncation

    total = len(raw_output)
    # Summarize binary fields first so the preview never shows garbled base64.
    preview = truncate(_summarize_binary_fields(raw_output), _PREVIEW_CHARS)
    retrieval = (
        f"\nFull output ({total:,} chars) saved to workspace. "
        f"Use read_workspace_file("
        f'path="{file_path}", offset=<char_offset>, length=50000) '
        f"to read any section."
    )
    return (
        f'<tool-output-truncated total_chars={total} path="{file_path}">\n'
        f"{preview}\n"
        f"{retrieval}\n"
        f"</tool-output-truncated>"
    )
|
||||
|
||||
|
||||
class BaseTool:
|
||||
"""Base class for all chat tools."""
|
||||
@@ -124,16 +36,6 @@ class BaseTool:
|
||||
"""Whether this tool requires authentication."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
"""Whether this tool is available in the current environment.
|
||||
|
||||
Override to check required env vars, binaries, or other dependencies.
|
||||
Unavailable tools are excluded from the LLM tool list so the model is
|
||||
never offered an option that will immediately fail.
|
||||
"""
|
||||
return True
|
||||
|
||||
def as_openai_tool(self) -> ChatCompletionToolParam:
|
||||
"""Convert to OpenAI tool format."""
|
||||
return ChatCompletionToolParam(
|
||||
@@ -155,7 +57,7 @@ class BaseTool:
|
||||
"""Execute the tool with authentication check.
|
||||
|
||||
Args:
|
||||
user_id: User ID (None for anonymous users)
|
||||
user_id: User ID (may be anonymous like "anon_123")
|
||||
session_id: Chat session ID
|
||||
**kwargs: Tool-specific parameters
|
||||
|
||||
@@ -179,21 +81,10 @@ class BaseTool:
|
||||
|
||||
try:
|
||||
result = await self._execute(user_id, session, **kwargs)
|
||||
raw_output = result.model_dump_json()
|
||||
|
||||
if (
|
||||
len(raw_output) > _LARGE_OUTPUT_THRESHOLD
|
||||
and user_id
|
||||
and session.session_id
|
||||
):
|
||||
raw_output = await _persist_and_summarize(
|
||||
raw_output, user_id, session.session_id, tool_call_id
|
||||
)
|
||||
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=self.name,
|
||||
output=raw_output,
|
||||
output=result.model_dump_json(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in {self.name}: {e}", exc_info=True)
|
||||
|
||||
@@ -1,194 +0,0 @@
|
||||
"""Tests for BaseTool large-output persistence in execute()."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.tools.base import (
|
||||
_LARGE_OUTPUT_THRESHOLD,
|
||||
BaseTool,
|
||||
_persist_and_summarize,
|
||||
_summarize_binary_fields,
|
||||
)
|
||||
from backend.copilot.tools.models import ResponseType, ToolResponseBase
|
||||
|
||||
|
||||
class _HugeOutputTool(BaseTool):
|
||||
"""Fake tool that returns an arbitrarily large output."""
|
||||
|
||||
def __init__(self, output_size: int) -> None:
|
||||
self._output_size = output_size
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "huge_output_tool"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Returns a huge output"
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict:
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
async def _execute(self, user_id, session, **kwargs) -> ToolResponseBase:
|
||||
return ToolResponseBase(
|
||||
type=ResponseType.ERROR,
|
||||
message="x" * self._output_size,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _persist_and_summarize
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPersistAndSummarize:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_middle_out_preview_with_retrieval_instructions(self):
|
||||
raw = "A" * 200_000
|
||||
|
||||
mock_workspace = MagicMock()
|
||||
mock_workspace.id = "ws-1"
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
|
||||
|
||||
mock_manager = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("backend.copilot.tools.base.workspace_db", return_value=mock_db),
|
||||
patch(
|
||||
"backend.copilot.tools.base.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
),
|
||||
):
|
||||
result = await _persist_and_summarize(raw, "user-1", "session-1", "tc-123")
|
||||
|
||||
assert "<tool-output-truncated" in result
|
||||
assert "</tool-output-truncated>" in result
|
||||
assert "total_chars=200000" in result
|
||||
assert 'path="tool-outputs/tc-123.json"' in result
|
||||
assert "read_workspace_file" in result
|
||||
# Middle-out sentinel from truncate()
|
||||
assert "omitted" in result
|
||||
# Total result is much shorter than the raw output
|
||||
assert len(result) < len(raw)
|
||||
|
||||
# Verify write_file was called with full content
|
||||
mock_manager.write_file.assert_awaited_once()
|
||||
call_kwargs = mock_manager.write_file.call_args
|
||||
assert call_kwargs.kwargs["content"] == raw.encode("utf-8")
|
||||
assert call_kwargs.kwargs["path"] == "tool-outputs/tc-123.json"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_on_workspace_error(self):
|
||||
"""If workspace write fails, return raw output for normal truncation."""
|
||||
raw = "B" * 200_000
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_or_create_workspace = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
with patch("backend.copilot.tools.base.workspace_db", return_value=mock_db):
|
||||
result = await _persist_and_summarize(raw, "user-1", "session-1", "tc-fail")
|
||||
|
||||
assert result == raw # unchanged — fallback to normal truncation
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BaseTool.execute — integration with persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBaseToolExecuteLargeOutput:
|
||||
@pytest.mark.asyncio
|
||||
async def test_small_output_not_persisted(self):
|
||||
"""Outputs under the threshold go through without persistence."""
|
||||
tool = _HugeOutputTool(output_size=100)
|
||||
session = MagicMock()
|
||||
session.session_id = "s-1"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.base._persist_and_summarize",
|
||||
new_callable=AsyncMock,
|
||||
) as persist_mock:
|
||||
result = await tool.execute("user-1", session, "tc-small")
|
||||
persist_mock.assert_not_awaited()
|
||||
assert "<tool-output-truncated" not in str(result.output)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_output_persisted(self):
|
||||
"""Outputs over the threshold trigger persistence + preview."""
|
||||
tool = _HugeOutputTool(output_size=_LARGE_OUTPUT_THRESHOLD + 10_000)
|
||||
session = MagicMock()
|
||||
session.session_id = "s-1"
|
||||
|
||||
mock_workspace = MagicMock()
|
||||
mock_workspace.id = "ws-1"
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
|
||||
mock_manager = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("backend.copilot.tools.base.workspace_db", return_value=mock_db),
|
||||
patch(
|
||||
"backend.copilot.tools.base.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
),
|
||||
):
|
||||
result = await tool.execute("user-1", session, "tc-big")
|
||||
|
||||
assert "<tool-output-truncated" in str(result.output)
|
||||
assert "read_workspace_file" in str(result.output)
|
||||
mock_manager.write_file.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_persistence_without_user_id(self):
|
||||
"""Anonymous users skip persistence (no workspace)."""
|
||||
tool = _HugeOutputTool(output_size=_LARGE_OUTPUT_THRESHOLD + 10_000)
|
||||
session = MagicMock()
|
||||
session.session_id = "s-1"
|
||||
|
||||
# user_id=None → should not attempt persistence
|
||||
with patch(
|
||||
"backend.copilot.tools.base._persist_and_summarize",
|
||||
new_callable=AsyncMock,
|
||||
) as persist_mock:
|
||||
result = await tool.execute(None, session, "tc-anon")
|
||||
persist_mock.assert_not_awaited()
|
||||
# Output is set but not wrapped in <tool-output-truncated> tags
|
||||
# (it will be middle-out truncated by model_post_init instead)
|
||||
assert "<tool-output-truncated" not in str(result.output)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _summarize_binary_fields
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSummarizeBinaryFields:
|
||||
def test_replaces_large_content_base64(self):
|
||||
import json
|
||||
|
||||
data = {"content_base64": "A" * 10_000, "name": "file.png"}
|
||||
result = json.loads(_summarize_binary_fields(json.dumps(data)))
|
||||
assert result["name"] == "file.png"
|
||||
assert "<binary" in result["content_base64"]
|
||||
assert "bytes>" in result["content_base64"]
|
||||
|
||||
def test_preserves_small_content_base64(self):
|
||||
import json
|
||||
|
||||
data = {"content_base64": "AQID", "name": "tiny.bin"}
|
||||
result_str = _summarize_binary_fields(json.dumps(data))
|
||||
result = json.loads(result_str)
|
||||
assert result["content_base64"] == "AQID" # unchanged
|
||||
|
||||
def test_non_json_passthrough(self):
|
||||
raw = "not json at all"
|
||||
assert _summarize_binary_fields(raw) == raw
|
||||
|
||||
def test_no_binary_fields_unchanged(self):
|
||||
import json
|
||||
|
||||
data = {"message": "hello", "type": "info"}
|
||||
raw = json.dumps(data)
|
||||
assert _summarize_binary_fields(raw) == raw
|
||||
@@ -1,30 +1,19 @@
|
||||
"""Bash execution tool — run shell commands on E2B or in a bubblewrap sandbox.
|
||||
"""Bash execution tool — run shell commands in a bubblewrap sandbox.
|
||||
|
||||
When an E2B sandbox is available in the current execution context the command
|
||||
runs directly on the remote E2B cloud environment. This means:
|
||||
Full Bash scripting is allowed (loops, conditionals, pipes, functions, etc.).
|
||||
Safety comes from OS-level isolation (bubblewrap): only system dirs visible
|
||||
read-only, writable workspace only, clean env, no network.
|
||||
|
||||
- **Persistent filesystem**: files survive across turns via HTTP-based sync
|
||||
with the sandbox's ``/home/user`` directory (E2B files API), shared with
|
||||
SDK Read/Write/Edit tools.
|
||||
- **Full internet access**: E2B sandboxes have unrestricted outbound network.
|
||||
- **Execution isolation**: E2B provides a fresh, containerised Linux environment.
|
||||
|
||||
When E2B is *not* configured the tool falls back to **bubblewrap** (bwrap):
|
||||
OS-level isolation with a whitelist-only filesystem, no network, and resource
|
||||
limits. Requires bubblewrap to be installed (Linux only).
|
||||
Requires bubblewrap (``bwrap``) — the tool is disabled when bwrap is not
|
||||
available (e.g. macOS development).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import shlex
|
||||
from typing import Any
|
||||
|
||||
from e2b import AsyncSandbox
|
||||
from e2b.exceptions import TimeoutException
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .e2b_sandbox import E2B_WORKDIR
|
||||
from .models import BashExecResponse, ErrorResponse, ToolResponseBase
|
||||
from .sandbox import get_workspace_dir, has_full_sandbox, run_sandboxed
|
||||
|
||||
@@ -32,7 +21,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BashExecTool(BaseTool):
|
||||
"""Execute Bash commands on E2B or in a bubblewrap sandbox."""
|
||||
"""Execute Bash commands in a bubblewrap sandbox."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -40,16 +29,28 @@ class BashExecTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
if not has_full_sandbox():
|
||||
return (
|
||||
"Bash execution is DISABLED — bubblewrap sandbox is not "
|
||||
"available on this platform. Do not call this tool."
|
||||
)
|
||||
return (
|
||||
"Execute a Bash command or script. "
|
||||
"Execute a Bash command or script in a bubblewrap sandbox. "
|
||||
"Full Bash scripting is supported (loops, conditionals, pipes, "
|
||||
"functions, etc.). "
|
||||
"The working directory is shared with the SDK Read/Write/Edit/Glob/Grep "
|
||||
"tools — files created by either are immediately visible to both. "
|
||||
"The sandbox shares the same working directory as the SDK Read/Write "
|
||||
"tools — files created by either are accessible to both. "
|
||||
"SECURITY: Only system directories (/usr, /bin, /lib, /etc) are "
|
||||
"visible read-only, the per-session workspace is the only writable "
|
||||
"path, environment variables are wiped (no secrets), all network "
|
||||
"access is blocked at the kernel level, and resource limits are "
|
||||
"enforced (max 64 processes, 512MB memory, 50MB file size). "
|
||||
"Application code, configs, and other directories are NOT accessible. "
|
||||
"To fetch web content, use the web_fetch tool instead. "
|
||||
"Execution is killed after the timeout (default 30s, max 120s). "
|
||||
"Returns stdout and stderr. "
|
||||
"Useful for file manipulation, data processing, running scripts, "
|
||||
"and installing packages."
|
||||
"Useful for file manipulation, data processing with Unix tools "
|
||||
"(grep, awk, sed, jq, etc.), and running shell scripts."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -84,8 +85,15 @@ class BashExecTool(BaseTool):
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not has_full_sandbox():
|
||||
return ErrorResponse(
|
||||
message="bash_exec requires bubblewrap sandbox (Linux only).",
|
||||
error="sandbox_unavailable",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
command: str = (kwargs.get("command") or "").strip()
|
||||
timeout: int = int(kwargs.get("timeout", 30))
|
||||
timeout: int = kwargs.get("timeout", 30)
|
||||
|
||||
if not command:
|
||||
return ErrorResponse(
|
||||
@@ -94,21 +102,6 @@ class BashExecTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# E2B path: run on remote cloud sandbox when available.
|
||||
from backend.copilot.sdk.tool_adapter import get_current_sandbox
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
return await self._execute_on_e2b(sandbox, command, timeout, session_id)
|
||||
|
||||
# Bubblewrap fallback: local isolated execution.
|
||||
if not has_full_sandbox():
|
||||
return ErrorResponse(
|
||||
message="bash_exec requires bubblewrap sandbox (Linux only).",
|
||||
error="sandbox_unavailable",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
workspace = get_workspace_dir(session_id or "default")
|
||||
|
||||
stdout, stderr, exit_code, timed_out = await run_sandboxed(
|
||||
@@ -129,43 +122,3 @@ class BashExecTool(BaseTool):
|
||||
timed_out=timed_out,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _execute_on_e2b(
|
||||
self,
|
||||
sandbox: AsyncSandbox,
|
||||
command: str,
|
||||
timeout: int,
|
||||
session_id: str | None,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute *command* on the E2B sandbox via commands.run()."""
|
||||
try:
|
||||
result = await sandbox.commands.run(
|
||||
f"bash -c {shlex.quote(command)}",
|
||||
cwd=E2B_WORKDIR,
|
||||
timeout=timeout,
|
||||
envs={"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"},
|
||||
)
|
||||
return BashExecResponse(
|
||||
message=f"Command executed on E2B (exit {result.exit_code})",
|
||||
stdout=result.stdout or "",
|
||||
stderr=result.stderr or "",
|
||||
exit_code=result.exit_code,
|
||||
timed_out=False,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
if isinstance(exc, TimeoutException):
|
||||
return BashExecResponse(
|
||||
message="Execution timed out",
|
||||
stdout="",
|
||||
stderr=f"Timed out after {timeout}s",
|
||||
exit_code=-1,
|
||||
timed_out=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
logger.error("[E2B] bash_exec failed: %s", exc, exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"E2B execution failed: {exc}",
|
||||
error="e2b_execution_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -1,170 +0,0 @@
|
||||
"""E2B sandbox lifecycle for CoPilot: persistent cloud execution.
|
||||
|
||||
Each session gets a long-lived E2B cloud sandbox. ``bash_exec`` runs commands
|
||||
directly on the sandbox via ``sandbox.commands.run()``. SDK file tools
|
||||
(read_file/write_file/edit_file/glob/grep) route to the sandbox's
|
||||
``/home/user`` directory via E2B's HTTP-based filesystem API — all tools
|
||||
share a single coherent filesystem with no local sync required.
|
||||
|
||||
Lifecycle
|
||||
---------
|
||||
1. **Turn start** – connect to the existing sandbox (sandbox_id in Redis) or
|
||||
create a new one via ``get_or_create_sandbox()``.
|
||||
2. **Execution** – ``bash_exec`` and MCP file tools operate directly on the
|
||||
sandbox's ``/home/user`` filesystem.
|
||||
3. **Session expiry** – E2B sandbox is killed by its own timeout (session_ttl).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SANDBOX_REDIS_PREFIX = "copilot:e2b:sandbox:"
|
||||
E2B_WORKDIR = "/home/user"
|
||||
_CREATING = "__creating__"
|
||||
_CREATION_LOCK_TTL = 60
|
||||
_MAX_WAIT_ATTEMPTS = 20 # 20 * 0.5s = 10s max wait
|
||||
|
||||
|
||||
async def _try_reconnect(
|
||||
sandbox_id: str, api_key: str, redis_key: str, timeout: int
|
||||
) -> "AsyncSandbox | None":
|
||||
"""Try to reconnect to an existing sandbox. Returns None on failure."""
|
||||
try:
|
||||
sandbox = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
|
||||
if await sandbox.is_running():
|
||||
redis = await get_redis_async()
|
||||
await redis.expire(redis_key, timeout)
|
||||
return sandbox
|
||||
except Exception as exc:
|
||||
logger.warning("[E2B] Reconnect to %.12s failed: %s", sandbox_id, exc)
|
||||
|
||||
# Stale — clear Redis so a new sandbox can be created.
|
||||
redis = await get_redis_async()
|
||||
await redis.delete(redis_key)
|
||||
return None
|
||||
|
||||
|
||||
async def get_or_create_sandbox(
|
||||
session_id: str,
|
||||
api_key: str,
|
||||
template: str = "base",
|
||||
timeout: int = 43200,
|
||||
) -> AsyncSandbox:
|
||||
"""Return the existing E2B sandbox for *session_id* or create a new one.
|
||||
|
||||
The sandbox_id is persisted in Redis so the same sandbox is reused
|
||||
across turns. Concurrent calls for the same session are serialised
|
||||
via a Redis ``SET NX`` creation lock.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
|
||||
|
||||
# 1. Try reconnecting to an existing sandbox.
|
||||
raw = await redis.get(redis_key)
|
||||
if raw:
|
||||
sandbox_id = raw if isinstance(raw, str) else raw.decode()
|
||||
if sandbox_id != _CREATING:
|
||||
sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
|
||||
if sandbox:
|
||||
logger.info(
|
||||
"[E2B] Reconnected to %.12s for session %.12s",
|
||||
sandbox_id,
|
||||
session_id,
|
||||
)
|
||||
return sandbox
|
||||
|
||||
# 2. Claim creation lock. If another request holds it, wait for the result.
|
||||
claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
|
||||
if not claimed:
|
||||
for _ in range(_MAX_WAIT_ATTEMPTS):
|
||||
await asyncio.sleep(0.5)
|
||||
raw = await redis.get(redis_key)
|
||||
if not raw:
|
||||
break # Lock expired — fall through to retry creation
|
||||
sandbox_id = raw if isinstance(raw, str) else raw.decode()
|
||||
if sandbox_id != _CREATING:
|
||||
sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
|
||||
if sandbox:
|
||||
return sandbox
|
||||
break # Stale sandbox cleared — fall through to create
|
||||
|
||||
# Try to claim creation lock again after waiting.
|
||||
claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
|
||||
if not claimed:
|
||||
# Another process may have created a sandbox — try to use it.
|
||||
raw = await redis.get(redis_key)
|
||||
if raw:
|
||||
sandbox_id = raw if isinstance(raw, str) else raw.decode()
|
||||
if sandbox_id != _CREATING:
|
||||
sandbox = await _try_reconnect(
|
||||
sandbox_id, api_key, redis_key, timeout
|
||||
)
|
||||
if sandbox:
|
||||
return sandbox
|
||||
raise RuntimeError(
|
||||
f"Could not acquire E2B creation lock for session {session_id[:12]}"
|
||||
)
|
||||
|
||||
# 3. Create a new sandbox.
|
||||
try:
|
||||
sandbox = await AsyncSandbox.create(
|
||||
template=template, api_key=api_key, timeout=timeout
|
||||
)
|
||||
except Exception:
|
||||
await redis.delete(redis_key)
|
||||
raise
|
||||
|
||||
await redis.setex(redis_key, timeout, sandbox.sandbox_id)
|
||||
logger.info(
|
||||
"[E2B] Created sandbox %.12s for session %.12s",
|
||||
sandbox.sandbox_id,
|
||||
session_id,
|
||||
)
|
||||
return sandbox
|
||||
|
||||
|
||||
async def kill_sandbox(session_id: str, api_key: str) -> bool:
|
||||
"""Kill the E2B sandbox for *session_id* and clean up its Redis entry.
|
||||
|
||||
Returns ``True`` if a sandbox was found and killed, ``False`` otherwise.
|
||||
Safe to call even when no sandbox exists for the session.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
|
||||
raw = await redis.get(redis_key)
|
||||
if not raw:
|
||||
return False
|
||||
|
||||
sandbox_id = raw if isinstance(raw, str) else raw.decode()
|
||||
await redis.delete(redis_key)
|
||||
|
||||
if sandbox_id == _CREATING:
|
||||
return False
|
||||
|
||||
try:
|
||||
|
||||
async def _connect_and_kill():
|
||||
sandbox = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
|
||||
await sandbox.kill()
|
||||
|
||||
await asyncio.wait_for(_connect_and_kill(), timeout=10)
|
||||
logger.info(
|
||||
"[E2B] Killed sandbox %.12s for session %.12s",
|
||||
sandbox_id,
|
||||
session_id,
|
||||
)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[E2B] Failed to kill sandbox %.12s for session %.12s: %s",
|
||||
sandbox_id,
|
||||
session_id,
|
||||
exc,
|
||||
)
|
||||
return False
|
||||
@@ -1,272 +0,0 @@
|
||||
"""Tests for e2b_sandbox: get_or_create_sandbox, _try_reconnect, kill_sandbox.
|
||||
|
||||
Uses mock Redis and mock AsyncSandbox — no external dependencies.
|
||||
Tests are synchronous (using asyncio.run) to avoid conflicts with the
|
||||
session-scoped event loop in conftest.py.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .e2b_sandbox import (
|
||||
_CREATING,
|
||||
_SANDBOX_REDIS_PREFIX,
|
||||
_try_reconnect,
|
||||
get_or_create_sandbox,
|
||||
kill_sandbox,
|
||||
)
|
||||
|
||||
_KEY = f"{_SANDBOX_REDIS_PREFIX}sess-123"
|
||||
_API_KEY = "test-api-key"
|
||||
_TIMEOUT = 300
|
||||
|
||||
|
||||
def _mock_sandbox(sandbox_id: str = "sb-abc", running: bool = True) -> MagicMock:
|
||||
sb = MagicMock()
|
||||
sb.sandbox_id = sandbox_id
|
||||
sb.is_running = AsyncMock(return_value=running)
|
||||
return sb
|
||||
|
||||
|
||||
def _mock_redis(get_val: str | bytes | None = None, set_nx_result: bool = True):
|
||||
r = AsyncMock()
|
||||
r.get = AsyncMock(return_value=get_val)
|
||||
r.set = AsyncMock(return_value=set_nx_result)
|
||||
r.setex = AsyncMock()
|
||||
r.delete = AsyncMock()
|
||||
r.expire = AsyncMock()
|
||||
return r
|
||||
|
||||
|
||||
def _patch_redis(redis):
|
||||
return patch(
|
||||
"backend.copilot.tools.e2b_sandbox.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
return_value=redis,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _try_reconnect
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTryReconnect:
|
||||
def test_reconnect_success(self):
|
||||
sb = _mock_sandbox()
|
||||
redis = _mock_redis()
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
|
||||
|
||||
assert result is sb
|
||||
redis.expire.assert_awaited_once_with(_KEY, _TIMEOUT)
|
||||
redis.delete.assert_not_awaited()
|
||||
|
||||
def test_reconnect_not_running_clears_key(self):
|
||||
sb = _mock_sandbox(running=False)
|
||||
redis = _mock_redis()
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
|
||||
|
||||
assert result is None
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
def test_reconnect_exception_clears_key(self):
|
||||
redis = _mock_redis()
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
|
||||
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
|
||||
|
||||
assert result is None
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_or_create_sandbox
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetOrCreateSandbox:
|
||||
def test_reconnect_existing(self):
|
||||
"""When Redis has a valid sandbox_id, reconnect to it."""
|
||||
sb = _mock_sandbox()
|
||||
redis = _mock_redis(get_val="sb-abc")
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is sb
|
||||
mock_cls.create.assert_not_called()
|
||||
|
||||
def test_create_new_when_no_key(self):
|
||||
"""When Redis is empty, claim lock and create a new sandbox."""
|
||||
sb = _mock_sandbox("sb-new")
|
||||
redis = _mock_redis(get_val=None, set_nx_result=True)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.create = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is sb
|
||||
redis.setex.assert_awaited_once_with(_KEY, _TIMEOUT, "sb-new")
|
||||
|
||||
def test_create_failure_clears_lock(self):
|
||||
"""If sandbox creation fails, the Redis lock is deleted."""
|
||||
redis = _mock_redis(get_val=None, set_nx_result=True)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.create = AsyncMock(side_effect=RuntimeError("quota"))
|
||||
with pytest.raises(RuntimeError, match="quota"):
|
||||
asyncio.run(
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
def test_wait_for_lock_then_reconnect(self):
|
||||
"""When another process holds the lock, wait and reconnect."""
|
||||
sb = _mock_sandbox("sb-other")
|
||||
redis = _mock_redis()
|
||||
redis.get = AsyncMock(side_effect=[_CREATING, "sb-other"])
|
||||
redis.set = AsyncMock(return_value=False)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
patch(
|
||||
"backend.copilot.tools.e2b_sandbox.asyncio.sleep",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is sb
|
||||
|
||||
def test_stale_reconnect_clears_and_creates(self):
|
||||
"""When stored sandbox is stale, clear key and create a new one."""
|
||||
stale_sb = _mock_sandbox("sb-stale", running=False)
|
||||
new_sb = _mock_sandbox("sb-fresh")
|
||||
redis = _mock_redis(get_val="sb-stale", set_nx_result=True)
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=stale_sb)
|
||||
mock_cls.create = AsyncMock(return_value=new_sb)
|
||||
result = asyncio.run(
|
||||
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
|
||||
)
|
||||
|
||||
assert result is new_sb
|
||||
redis.delete.assert_awaited()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# kill_sandbox
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestKillSandbox:
|
||||
def test_kill_existing_sandbox(self):
|
||||
"""Kill a running sandbox and clean up Redis."""
|
||||
sb = _mock_sandbox()
|
||||
sb.kill = AsyncMock()
|
||||
redis = _mock_redis(get_val="sb-abc")
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is True
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
sb.kill.assert_awaited_once()
|
||||
|
||||
def test_kill_no_sandbox(self):
|
||||
"""No-op when no sandbox exists in Redis."""
|
||||
redis = _mock_redis(get_val=None)
|
||||
with _patch_redis(redis):
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_not_awaited()
|
||||
|
||||
def test_kill_creating_state(self):
|
||||
"""Clears Redis key but returns False when sandbox is still being created."""
|
||||
redis = _mock_redis(get_val=_CREATING)
|
||||
with _patch_redis(redis):
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
def test_kill_connect_failure(self):
|
||||
"""Returns False and cleans Redis if connect/kill fails."""
|
||||
redis = _mock_redis(get_val="sb-abc")
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
|
||||
def test_kill_with_bytes_redis_value(self):
|
||||
"""Redis may return bytes — kill_sandbox should decode correctly."""
|
||||
sb = _mock_sandbox()
|
||||
sb.kill = AsyncMock()
|
||||
redis = _mock_redis(get_val=b"sb-abc")
|
||||
with (
|
||||
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
|
||||
_patch_redis(redis),
|
||||
):
|
||||
mock_cls.connect = AsyncMock(return_value=sb)
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is True
|
||||
sb.kill.assert_awaited_once()
|
||||
|
||||
def test_kill_timeout_returns_false(self):
|
||||
"""Returns False when E2B API calls exceed the 10s timeout."""
|
||||
redis = _mock_redis(get_val="sb-abc")
|
||||
with (
|
||||
_patch_redis(redis),
|
||||
patch(
|
||||
"backend.copilot.tools.e2b_sandbox.asyncio.wait_for",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=asyncio.TimeoutError,
|
||||
),
|
||||
):
|
||||
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
|
||||
|
||||
assert result is False
|
||||
redis.delete.assert_awaited_once_with(_KEY)
|
||||
@@ -32,7 +32,6 @@ COPILOT_EXCLUDED_BLOCK_TYPES = {
|
||||
BlockType.NOTE, # Visual annotation only - no runtime behavior
|
||||
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
|
||||
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
|
||||
BlockType.MCP_TOOL, # Has dedicated run_mcp_tool tool with proper discovery + auth flow
|
||||
}
|
||||
|
||||
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
|
||||
|
||||
@@ -41,10 +41,6 @@ class ResponseType(str, Enum):
|
||||
INPUT_VALIDATION_ERROR = "input_validation_error"
|
||||
# Web fetch
|
||||
WEB_FETCH = "web_fetch"
|
||||
# Agent-browser multi-step automation (navigate, act, screenshot)
|
||||
BROWSER_NAVIGATE = "browser_navigate"
|
||||
BROWSER_ACT = "browser_act"
|
||||
BROWSER_SCREENSHOT = "browser_screenshot"
|
||||
# Code execution
|
||||
BASH_EXEC = "bash_exec"
|
||||
# Feature request types
|
||||
@@ -52,9 +48,6 @@ class ResponseType(str, Enum):
|
||||
FEATURE_REQUEST_CREATED = "feature_request_created"
|
||||
# Goal refinement
|
||||
SUGGESTED_GOAL = "suggested_goal"
|
||||
# MCP tool types
|
||||
MCP_TOOLS_DISCOVERED = "mcp_tools_discovered"
|
||||
MCP_TOOL_OUTPUT = "mcp_tool_output"
|
||||
|
||||
|
||||
# Base response model
|
||||
@@ -483,59 +476,3 @@ class FeatureRequestCreatedResponse(ToolResponseBase):
|
||||
issue_url: str
|
||||
is_new_issue: bool # False if added to existing
|
||||
customer_name: str
|
||||
|
||||
|
||||
# MCP tool models
|
||||
class MCPToolInfo(BaseModel):
|
||||
"""Information about a single MCP tool discovered from a server."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
input_schema: dict[str, Any]
|
||||
|
||||
|
||||
class MCPToolsDiscoveredResponse(ToolResponseBase):
|
||||
"""Response when MCP tools are discovered from a server (agent-internal)."""
|
||||
|
||||
type: ResponseType = ResponseType.MCP_TOOLS_DISCOVERED
|
||||
server_url: str
|
||||
tools: list[MCPToolInfo]
|
||||
|
||||
|
||||
class MCPToolOutputResponse(ToolResponseBase):
|
||||
"""Response after executing an MCP tool."""
|
||||
|
||||
type: ResponseType = ResponseType.MCP_TOOL_OUTPUT
|
||||
server_url: str
|
||||
tool_name: str
|
||||
result: Any = None
|
||||
success: bool = True
|
||||
|
||||
|
||||
# Agent-browser multi-step automation models
|
||||
|
||||
|
||||
class BrowserNavigateResponse(ToolResponseBase):
|
||||
"""Response for browser_navigate tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BROWSER_NAVIGATE
|
||||
url: str
|
||||
title: str
|
||||
snapshot: str # Interactive accessibility tree with @ref IDs
|
||||
|
||||
|
||||
class BrowserActResponse(ToolResponseBase):
|
||||
"""Response for browser_act tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BROWSER_ACT
|
||||
action: str
|
||||
current_url: str = ""
|
||||
snapshot: str # Updated accessibility tree after the action
|
||||
|
||||
|
||||
class BrowserScreenshotResponse(ToolResponseBase):
|
||||
"""Response for browser_screenshot tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BROWSER_SCREENSHOT
|
||||
file_id: str # Workspace file ID — use read_workspace_file to retrieve
|
||||
filename: str
|
||||
|
||||
@@ -1,339 +0,0 @@
|
||||
"""Tool for discovering and executing MCP (Model Context Protocol) server tools."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from backend.blocks.mcp.block import MCPToolBlock
|
||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
||||
from backend.blocks.mcp.helpers import (
|
||||
auto_lookup_mcp_credential,
|
||||
normalize_mcp_url,
|
||||
parse_mcp_content,
|
||||
server_host,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.utils import build_missing_credentials_from_field_info
|
||||
from backend.util.request import HTTPClientError, validate_url
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
ErrorResponse,
|
||||
MCPToolInfo,
|
||||
MCPToolOutputResponse,
|
||||
MCPToolsDiscoveredResponse,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# HTTP status codes that indicate authentication is required
|
||||
_AUTH_STATUS_CODES = {401, 403}
|
||||
|
||||
|
||||
class RunMCPToolTool(BaseTool):
|
||||
"""
|
||||
Tool for discovering and executing tools on any MCP server.
|
||||
|
||||
Stage 1 — discovery: call with just server_url to get available tools.
|
||||
Stage 2 — execution: call with server_url + tool_name + tool_arguments.
|
||||
If the server requires OAuth credentials that the user hasn't connected yet,
|
||||
a SetupRequirementsResponse is returned so the frontend can render the
|
||||
same OAuth login UI as the graph builder.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "run_mcp_tool"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Connect to an MCP (Model Context Protocol) server to discover and execute its tools. "
|
||||
"Call with just `server_url` to see available tools. "
|
||||
"Then call again with `server_url`, `tool_name`, and `tool_arguments` to execute. "
|
||||
"If the server requires authentication, the user will be prompted to connect it. "
|
||||
"Find MCP servers at https://registry.modelcontextprotocol.io/ — hundreds of integrations "
|
||||
"including GitHub, Postgres, Slack, filesystem, and more."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"server_url": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"URL of the MCP server (Streamable HTTP endpoint), "
|
||||
"e.g. https://mcp.example.com/mcp"
|
||||
),
|
||||
},
|
||||
"tool_name": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Name of the MCP tool to execute. "
|
||||
"Omit on first call to discover available tools."
|
||||
),
|
||||
},
|
||||
"tool_arguments": {
|
||||
"type": "object",
|
||||
"description": (
|
||||
"Arguments to pass to the selected tool. "
|
||||
"Must match the tool's input schema returned during discovery."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["server_url"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
server_url: str = (kwargs.get("server_url") or "").strip()
|
||||
tool_name: str = (kwargs.get("tool_name") or "").strip()
|
||||
raw_tool_arguments = kwargs.get("tool_arguments")
|
||||
tool_arguments: dict[str, Any] = (
|
||||
raw_tool_arguments if isinstance(raw_tool_arguments, dict) else {}
|
||||
)
|
||||
session_id = session.session_id
|
||||
|
||||
if raw_tool_arguments is not None and not isinstance(raw_tool_arguments, dict):
|
||||
return ErrorResponse(
|
||||
message="tool_arguments must be a JSON object.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not server_url:
|
||||
return ErrorResponse(
|
||||
message="Please provide a server_url for the MCP server.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
_parsed = urlparse(server_url)
|
||||
if _parsed.username or _parsed.password:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Do not include credentials in server_url. "
|
||||
"Use the MCP credential setup flow instead."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
if _parsed.query or _parsed.fragment:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Do not include query parameters or fragments in server_url. "
|
||||
"Use the MCP credential setup flow instead."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges
|
||||
try:
|
||||
await validate_url(server_url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
msg = str(e)
|
||||
if "Unable to resolve" in msg or "No IP addresses" in msg:
|
||||
user_msg = (
|
||||
f"Hostname not found: {server_host(server_url)}. "
|
||||
"Please check the URL — the domain may not exist."
|
||||
)
|
||||
else:
|
||||
user_msg = f"Blocked server URL: {msg}"
|
||||
return ErrorResponse(message=user_msg, session_id=session_id)
|
||||
|
||||
# Fast DB lookup — no network call.
|
||||
# Normalize for matching because stored credentials use normalized URLs.
|
||||
creds = await auto_lookup_mcp_credential(user_id, normalize_mcp_url(server_url))
|
||||
auth_token = creds.access_token.get_secret_value() if creds else None
|
||||
|
||||
client = MCPClient(server_url, auth_token=auth_token)
|
||||
|
||||
try:
|
||||
await client.initialize()
|
||||
|
||||
if not tool_name:
|
||||
# Stage 1: Discover available tools
|
||||
return await self._discover_tools(client, server_url, session_id)
|
||||
else:
|
||||
# Stage 2: Execute the selected tool
|
||||
return await self._execute_tool(
|
||||
client, server_url, tool_name, tool_arguments, session_id
|
||||
)
|
||||
|
||||
except HTTPClientError as e:
|
||||
if e.status_code in _AUTH_STATUS_CODES and not creds:
|
||||
# Server requires auth and user has no stored credentials
|
||||
return self._build_setup_requirements(server_url, session_id)
|
||||
logger.warning("MCP HTTP error for %s: %s", server_host(server_url), e)
|
||||
return ErrorResponse(
|
||||
message=f"MCP server returned HTTP {e.status_code}: {e}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except MCPClientError as e:
|
||||
logger.warning("MCP client error for %s: %s", server_host(server_url), e)
|
||||
return ErrorResponse(
|
||||
message=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.error(
|
||||
"Unexpected error calling MCP server %s",
|
||||
server_host(server_url),
|
||||
exc_info=True,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message="An unexpected error occurred connecting to the MCP server. Please try again.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _discover_tools(
|
||||
self,
|
||||
client: MCPClient,
|
||||
server_url: str,
|
||||
session_id: str,
|
||||
) -> MCPToolsDiscoveredResponse:
|
||||
"""List available tools from an already-initialized MCPClient.
|
||||
|
||||
Called when the agent invokes run_mcp_tool with only server_url (no
|
||||
tool_name). Returns MCPToolsDiscoveredResponse so the agent can
|
||||
inspect tool schemas and choose one to execute in a follow-up call.
|
||||
"""
|
||||
tools = await client.list_tools()
|
||||
tool_infos = [
|
||||
MCPToolInfo(
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
input_schema=t.input_schema,
|
||||
)
|
||||
for t in tools
|
||||
]
|
||||
host = server_host(server_url)
|
||||
return MCPToolsDiscoveredResponse(
|
||||
message=(
|
||||
f"Discovered {len(tool_infos)} tool(s) on {host}. "
|
||||
"Call run_mcp_tool again with tool_name and tool_arguments to execute one."
|
||||
),
|
||||
server_url=server_url,
|
||||
tools=tool_infos,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _execute_tool(
|
||||
self,
|
||||
client: MCPClient,
|
||||
server_url: str,
|
||||
tool_name: str,
|
||||
tool_arguments: dict[str, Any],
|
||||
session_id: str,
|
||||
) -> MCPToolOutputResponse | ErrorResponse:
|
||||
"""Execute a specific tool on an already-initialized MCPClient.
|
||||
|
||||
Parses the MCP content response into a plain Python value:
|
||||
- text items: parsed as JSON when possible, kept as str otherwise
|
||||
- image items: kept as {type, data, mimeType} dict for frontend rendering
|
||||
- resource items: unwrapped to their resource payload dict
|
||||
Single-item responses are unwrapped from the list; multiple items are
|
||||
returned as a list; empty content returns None.
|
||||
"""
|
||||
result = await client.call_tool(tool_name, tool_arguments)
|
||||
|
||||
if result.is_error:
|
||||
error_text = " ".join(
|
||||
item.get("text", "")
|
||||
for item in result.content
|
||||
if item.get("type") == "text"
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=f"MCP tool '{tool_name}' returned an error: {error_text or 'Unknown error'}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
result_value = parse_mcp_content(result.content)
|
||||
|
||||
return MCPToolOutputResponse(
|
||||
message=f"MCP tool '{tool_name}' executed successfully.",
|
||||
server_url=server_url,
|
||||
tool_name=tool_name,
|
||||
result=result_value,
|
||||
success=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def _build_setup_requirements(
|
||||
self,
|
||||
server_url: str,
|
||||
session_id: str,
|
||||
) -> SetupRequirementsResponse | ErrorResponse:
|
||||
"""Build a SetupRequirementsResponse for a missing MCP server credential."""
|
||||
mcp_block = MCPToolBlock()
|
||||
credentials_fields_info = mcp_block.input_schema.get_credentials_fields_info()
|
||||
|
||||
# Apply the server_url discriminator value so the frontend's CredentialsGroupedView
|
||||
# can match the credential to the correct OAuth provider/server.
|
||||
for field_info in credentials_fields_info.values():
|
||||
if field_info.discriminator == "server_url":
|
||||
field_info.discriminator_values.add(server_url)
|
||||
|
||||
missing_creds_dict = build_missing_credentials_from_field_info(
|
||||
credentials_fields_info, matched_keys=set()
|
||||
)
|
||||
|
||||
if not missing_creds_dict:
|
||||
logger.error(
|
||||
"No credential requirements found for MCP server %s — "
|
||||
"MCPToolBlock may not have credentials configured",
|
||||
server_host(server_url),
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"The MCP server at {server_host(server_url)} requires authentication, "
|
||||
"but no credential configuration was found."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
missing_creds_list = list(missing_creds_dict.values())
|
||||
|
||||
host = server_host(server_url)
|
||||
return SetupRequirementsResponse(
|
||||
message=(
|
||||
f"The MCP server at {host} requires authentication. "
|
||||
"Please connect your credentials to continue."
|
||||
),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=server_url,
|
||||
agent_name=f"MCP: {host}",
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
missing_credentials=missing_creds_dict,
|
||||
ready_to_run=False,
|
||||
),
|
||||
requirements={
|
||||
"credentials": missing_creds_list,
|
||||
"inputs": [],
|
||||
"execution_modes": ["immediate"],
|
||||
},
|
||||
),
|
||||
graph_id=None,
|
||||
graph_version=None,
|
||||
)
|
||||
@@ -1,759 +0,0 @@
|
||||
"""Unit tests for the run_mcp_tool copilot tool."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.mcp.helpers import server_host
|
||||
|
||||
from ._test_data import make_session
|
||||
from .models import (
|
||||
ErrorResponse,
|
||||
MCPToolOutputResponse,
|
||||
MCPToolsDiscoveredResponse,
|
||||
SetupRequirementsResponse,
|
||||
)
|
||||
from .run_mcp_tool import RunMCPToolTool
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_USER_ID = "test-user-run-mcp-tool"
|
||||
_SERVER_URL = "https://remote.mcpservers.org/fetch/mcp"
|
||||
|
||||
|
||||
def _make_tool_list(*names: str):
|
||||
"""Build a list of mock MCPClientTool objects."""
|
||||
tools = []
|
||||
for name in names:
|
||||
t = MagicMock()
|
||||
t.name = name
|
||||
t.description = f"Description for {name}"
|
||||
t.input_schema = {"type": "object", "properties": {}, "required": []}
|
||||
tools.append(t)
|
||||
return tools
|
||||
|
||||
|
||||
def _make_call_result(content: list[dict], is_error: bool = False) -> MagicMock:
|
||||
result = MagicMock()
|
||||
result.is_error = is_error
|
||||
result.content = content
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# server_host helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_server_host_plain_url():
|
||||
assert server_host("https://mcp.example.com/mcp") == "mcp.example.com"
|
||||
|
||||
|
||||
def test_server_host_strips_credentials():
|
||||
"""netloc would expose user:pass — hostname must not."""
|
||||
assert server_host("https://user:secret@mcp.example.com/mcp") == "mcp.example.com"
|
||||
|
||||
|
||||
def test_server_host_with_port():
|
||||
"""Port should not appear in the returned hostname (hostname strips it)."""
|
||||
assert server_host("https://mcp.example.com:8080/mcp") == "mcp.example.com"
|
||||
|
||||
|
||||
def test_server_host_invalid_url():
|
||||
"""Falls back to the raw string for un-parseable URLs."""
|
||||
result = server_host("not-a-url")
|
||||
assert result == "not-a-url"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_missing_server_url_returns_error():
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
response = await tool._execute(user_id=_USER_ID, session=session)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "server_url" in response.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_missing_user_id_returns_error():
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
response = await tool._execute(
|
||||
user_id=None, session=session, server_url=_SERVER_URL
|
||||
)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "authentication" in response.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_ssrf_blocked_url_returns_error():
|
||||
"""Private/loopback URLs must be rejected before any network call."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked loopback"),
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID, session=session, server_url="http://localhost/mcp"
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert (
|
||||
"blocked" in response.message.lower() or "invalid" in response.message.lower()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_credential_bearing_url_returns_error():
|
||||
"""URLs with embedded user:pass@ must be rejected before any network call."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url="https://user:secret@mcp.example.com/mcp",
|
||||
)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert (
|
||||
"credential" in response.message.lower()
|
||||
or "do not include" in response.message.lower()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_non_dict_tool_arguments_returns_error():
|
||||
"""tool_arguments must be a JSON object — strings/arrays are rejected early."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
tool_name="fetch",
|
||||
tool_arguments=["this", "is", "a", "list"], # wrong type
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "json object" in response.message.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stage 1 — Discovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_returns_discovered_response():
|
||||
"""Calling with only server_url triggers discovery and returns tool list."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
mock_tools = _make_tool_list("fetch", "search")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tools)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
)
|
||||
|
||||
assert isinstance(response, MCPToolsDiscoveredResponse)
|
||||
assert len(response.tools) == 2
|
||||
assert response.tools[0].name == "fetch"
|
||||
assert response.tools[1].name == "search"
|
||||
assert response.server_url == _SERVER_URL
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_with_credentials():
|
||||
"""Stored credentials are passed as Bearer token to MCPClient."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
mock_creds = MagicMock()
|
||||
mock_creds.access_token = SecretStr("test-token-abc")
|
||||
mock_tools = _make_tool_list("push_notification")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_creds,
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tools)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
) as MockMCPClient:
|
||||
MockMCPClient.return_value = mock_client
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
)
|
||||
# Verify MCPClient was created with the resolved auth token
|
||||
MockMCPClient.assert_called_once_with(
|
||||
_SERVER_URL, auth_token="test-token-abc"
|
||||
)
|
||||
|
||||
assert isinstance(response, MCPToolsDiscoveredResponse)
|
||||
assert len(response.tools) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stage 2 — Execution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_execute_tool_returns_output_response():
|
||||
"""Calling with tool_name executes the tool and returns MCPToolOutputResponse."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
text_result = "# Example Domain\nThis domain is for examples."
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_result = _make_call_result([{"type": "text", "text": text_result}])
|
||||
mock_client = AsyncMock()
|
||||
mock_client.call_tool = AsyncMock(return_value=mock_result)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
tool_name="fetch",
|
||||
tool_arguments={"url": "https://example.com"},
|
||||
)
|
||||
|
||||
assert isinstance(response, MCPToolOutputResponse)
|
||||
assert response.tool_name == "fetch"
|
||||
assert response.server_url == _SERVER_URL
|
||||
assert response.success is True
|
||||
assert text_result in response.result
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_execute_tool_parses_json_result():
|
||||
"""JSON text content items are parsed into Python objects."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_result = _make_call_result(
|
||||
[{"type": "text", "text": '{"status": "ok", "count": 42}'}]
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.call_tool = AsyncMock(return_value=mock_result)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
tool_name="status",
|
||||
tool_arguments={},
|
||||
)
|
||||
|
||||
assert isinstance(response, MCPToolOutputResponse)
|
||||
assert response.result == {"status": "ok", "count": 42}
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_execute_tool_image_content():
|
||||
"""Image content items are returned as {type, data, mimeType} dicts."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_result = _make_call_result(
|
||||
[{"type": "image", "data": "abc123==", "mimeType": "image/png"}]
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.call_tool = AsyncMock(return_value=mock_result)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
tool_name="screenshot",
|
||||
tool_arguments={},
|
||||
)
|
||||
|
||||
assert isinstance(response, MCPToolOutputResponse)
|
||||
assert response.result == {
|
||||
"type": "image",
|
||||
"data": "abc123==",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_execute_tool_resource_content():
|
||||
"""Resource content items are unwrapped to their resource payload."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_result = _make_call_result(
|
||||
[
|
||||
{
|
||||
"type": "resource",
|
||||
"resource": {"uri": "file:///tmp/out.txt", "text": "hello"},
|
||||
}
|
||||
]
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.call_tool = AsyncMock(return_value=mock_result)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
tool_name="read_file",
|
||||
tool_arguments={},
|
||||
)
|
||||
|
||||
assert isinstance(response, MCPToolOutputResponse)
|
||||
assert response.result == {"uri": "file:///tmp/out.txt", "text": "hello"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_execute_tool_multi_item_content():
|
||||
"""Multiple content items are returned as a list."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_result = _make_call_result(
|
||||
[
|
||||
{"type": "text", "text": "part one"},
|
||||
{"type": "text", "text": "part two"},
|
||||
]
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.call_tool = AsyncMock(return_value=mock_result)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
tool_name="multi",
|
||||
tool_arguments={},
|
||||
)
|
||||
|
||||
assert isinstance(response, MCPToolOutputResponse)
|
||||
assert response.result == ["part one", "part two"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_execute_tool_empty_content_returns_none():
|
||||
"""Empty content list results in result=None."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_result = _make_call_result([])
|
||||
mock_client = AsyncMock()
|
||||
mock_client.call_tool = AsyncMock(return_value=mock_result)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
tool_name="ping",
|
||||
tool_arguments={},
|
||||
)
|
||||
|
||||
assert isinstance(response, MCPToolOutputResponse)
|
||||
assert response.result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_execute_tool_returns_error_on_tool_failure():
|
||||
"""When the MCP tool returns is_error=True, an ErrorResponse is returned."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_result = _make_call_result(
|
||||
[{"type": "text", "text": "Tool not found"}], is_error=True
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.call_tool = AsyncMock(return_value=mock_result)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
tool_name="nonexistent",
|
||||
tool_arguments={},
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "nonexistent" in response.message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth / credential flow
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_auth_required_without_creds_returns_setup_requirements():
|
||||
"""HTTP 401 from MCP with no stored creds → SetupRequirementsResponse."""
|
||||
from backend.util.request import HTTPClientError
|
||||
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None, # No stored credentials
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.initialize = AsyncMock(
|
||||
side_effect=HTTPClientError("Unauthorized", status_code=401)
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
with patch.object(
|
||||
RunMCPToolTool,
|
||||
"_build_setup_requirements",
|
||||
return_value=MagicMock(spec=SetupRequirementsResponse),
|
||||
) as mock_build:
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
)
|
||||
mock_build.assert_called_once()
|
||||
|
||||
# Should have returned what _build_setup_requirements returned
|
||||
assert response is mock_build.return_value
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_auth_error_with_existing_creds_returns_error():
|
||||
"""HTTP 403 when creds ARE present → generic ErrorResponse (not setup card)."""
|
||||
from backend.util.request import HTTPClientError
|
||||
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
mock_creds = MagicMock()
|
||||
mock_creds.access_token = SecretStr("stale-token")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_creds,
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.initialize = AsyncMock(
|
||||
side_effect=HTTPClientError("Forbidden", status_code=403)
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "403" in response.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_mcp_client_error_returns_error_response():
|
||||
"""MCPClientError (protocol-level) maps to a clean ErrorResponse."""
|
||||
from backend.blocks.mcp.client import MCPClientError
|
||||
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.initialize = AsyncMock(
|
||||
side_effect=MCPClientError("JSON-RPC protocol error")
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "JSON-RPC" in response.message or "protocol" in response.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unexpected_exception_returns_generic_error():
|
||||
"""Unhandled exceptions inside the MCP call don't leak traceback text to the user."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
# An unexpected error inside initialize (inside the try block)
|
||||
mock_client.initialize = AsyncMock(
|
||||
side_effect=ValueError("Unexpected internal error")
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=_SERVER_URL,
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
# Must not leak the raw exception message
|
||||
assert "Unexpected internal error" not in response.message
|
||||
assert (
|
||||
"unexpected" in response.message.lower() or "error" in response.message.lower()
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool metadata
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_tool_name():
|
||||
assert RunMCPToolTool().name == "run_mcp_tool"
|
||||
|
||||
|
||||
def test_tool_requires_auth():
|
||||
assert RunMCPToolTool().requires_auth is True
|
||||
|
||||
|
||||
def test_tool_parameters_schema():
|
||||
params = RunMCPToolTool().parameters
|
||||
assert params["type"] == "object"
|
||||
assert "server_url" in params["properties"]
|
||||
assert "tool_name" in params["properties"]
|
||||
assert "tool_arguments" in params["properties"]
|
||||
assert params["required"] == ["server_url"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Query/fragment rejection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_query_in_url_returns_error():
|
||||
"""server_url with query parameters must be rejected."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url="https://mcp.example.com/mcp?key=val",
|
||||
)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "query" in response.message.lower() or "fragment" in response.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_fragment_in_url_returns_error():
|
||||
"""server_url with a fragment must be rejected."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
response = await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url="https://mcp.example.com/mcp#section",
|
||||
)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "query" in response.message.lower() or "fragment" in response.message.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Credential lookup normalization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_credential_lookup_normalizes_trailing_slash():
|
||||
"""Credential lookup must normalize the URL (strip trailing slash)."""
|
||||
tool = RunMCPToolTool()
|
||||
session = make_session(_USER_ID)
|
||||
url_with_slash = "https://mcp.example.com/mcp/"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.validate_url", new_callable=AsyncMock
|
||||
):
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
) as mock_lookup:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=[])
|
||||
with patch(
|
||||
"backend.copilot.tools.run_mcp_tool.MCPClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
await tool._execute(
|
||||
user_id=_USER_ID,
|
||||
session=session,
|
||||
server_url=url_with_slash,
|
||||
)
|
||||
# Credential lookup should use the normalized URL (no trailing slash)
|
||||
mock_lookup.assert_called_once_with(_USER_ID, "https://mcp.example.com/mcp")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_setup_requirements
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_build_setup_requirements_returns_setup_response():
|
||||
"""_build_setup_requirements should return a SetupRequirementsResponse."""
|
||||
tool = RunMCPToolTool()
|
||||
result = tool._build_setup_requirements(
|
||||
server_url=_SERVER_URL,
|
||||
session_id="test-session",
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert result.setup_info.agent_id == _SERVER_URL
|
||||
assert "authentication" in result.message.lower()
|
||||
@@ -30,10 +30,6 @@ _TEXT_CONTENT_TYPES = {
|
||||
"application/xhtml+xml",
|
||||
"application/rss+xml",
|
||||
"application/atom+xml",
|
||||
# RFC 7807 — JSON problem details; used by many REST APIs for error responses
|
||||
"application/problem+json",
|
||||
"application/problem+xml",
|
||||
"application/ld+json",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ from typing import Any, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
|
||||
from backend.copilot.tools.sandbox import make_session_path
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.settings import Config
|
||||
@@ -21,7 +20,7 @@ from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _resolve_write_content(
|
||||
def _resolve_write_content(
|
||||
content_text: str | None,
|
||||
content_b64: str | None,
|
||||
source_path: str | None,
|
||||
@@ -31,9 +30,6 @@ async def _resolve_write_content(
|
||||
|
||||
Returns the raw bytes on success, or an ``ErrorResponse`` on validation
|
||||
failure (wrong number of sources, invalid path, file not found, etc.).
|
||||
|
||||
When an E2B sandbox is active, ``source_path`` reads from the sandbox
|
||||
filesystem instead of the local ephemeral directory.
|
||||
"""
|
||||
# Normalise empty strings to None so counting and dispatch stay in sync.
|
||||
if content_text is not None and content_text == "":
|
||||
@@ -58,7 +54,24 @@ async def _resolve_write_content(
|
||||
)
|
||||
|
||||
if source_path is not None:
|
||||
return await _read_source_path(source_path, session_id)
|
||||
validated = _validate_ephemeral_path(
|
||||
source_path, param_name="source_path", session_id=session_id
|
||||
)
|
||||
if isinstance(validated, ErrorResponse):
|
||||
return validated
|
||||
try:
|
||||
with open(validated, "rb") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
return ErrorResponse(
|
||||
message=f"Source file not found: {source_path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to read source file: {e}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if content_b64 is not None:
|
||||
try:
|
||||
@@ -78,106 +91,6 @@ async def _resolve_write_content(
|
||||
return content_text.encode("utf-8")
|
||||
|
||||
|
||||
def _resolve_sandbox_path(
|
||||
path: str, session_id: str | None, param_name: str
|
||||
) -> str | ErrorResponse:
|
||||
"""Normalize *path* to an absolute sandbox path under :data:`E2B_WORKDIR`.
|
||||
|
||||
Delegates to :func:`~backend.copilot.sdk.e2b_file_tools._resolve_remote`
|
||||
and wraps any ``ValueError`` into an :class:`ErrorResponse`.
|
||||
"""
|
||||
from backend.copilot.sdk.e2b_file_tools import _resolve_remote
|
||||
|
||||
try:
|
||||
return _resolve_remote(path)
|
||||
except ValueError:
|
||||
return ErrorResponse(
|
||||
message=f"{param_name} must be within {E2B_WORKDIR}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
async def _read_source_path(source_path: str, session_id: str) -> bytes | ErrorResponse:
|
||||
"""Read *source_path* from E2B sandbox or local ephemeral directory."""
|
||||
from backend.copilot.sdk.tool_adapter import get_current_sandbox
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
remote = _resolve_sandbox_path(source_path, session_id, "source_path")
|
||||
if isinstance(remote, ErrorResponse):
|
||||
return remote
|
||||
try:
|
||||
data = await sandbox.files.read(remote, format="bytes")
|
||||
return bytes(data)
|
||||
except Exception as exc:
|
||||
return ErrorResponse(
|
||||
message=f"Source file not found on sandbox: {source_path} ({exc})",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Local fallback: validate path stays within ephemeral directory.
|
||||
validated = _validate_ephemeral_path(
|
||||
source_path, param_name="source_path", session_id=session_id
|
||||
)
|
||||
if isinstance(validated, ErrorResponse):
|
||||
return validated
|
||||
try:
|
||||
with open(validated, "rb") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
return ErrorResponse(
|
||||
message=f"Source file not found: {source_path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to read source file: {e}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
async def _save_to_path(
|
||||
path: str, content: bytes, session_id: str
|
||||
) -> str | ErrorResponse:
|
||||
"""Write *content* to *path* on E2B sandbox or local ephemeral directory.
|
||||
|
||||
Returns the resolved path on success, or an ``ErrorResponse`` on failure.
|
||||
"""
|
||||
from backend.copilot.sdk.tool_adapter import get_current_sandbox
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
remote = _resolve_sandbox_path(path, session_id, "save_to_path")
|
||||
if isinstance(remote, ErrorResponse):
|
||||
return remote
|
||||
try:
|
||||
await sandbox.files.write(remote, content)
|
||||
except Exception as exc:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to write to sandbox: {path} ({exc})",
|
||||
session_id=session_id,
|
||||
)
|
||||
return remote
|
||||
|
||||
validated = _validate_ephemeral_path(
|
||||
path, param_name="save_to_path", session_id=session_id
|
||||
)
|
||||
if isinstance(validated, ErrorResponse):
|
||||
return validated
|
||||
try:
|
||||
dir_path = os.path.dirname(validated)
|
||||
if dir_path:
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
with open(validated, "wb") as f:
|
||||
f.write(content)
|
||||
except Exception as exc:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to write to local path: {path} ({exc})",
|
||||
session_id=session_id,
|
||||
)
|
||||
return validated
|
||||
|
||||
|
||||
def _validate_ephemeral_path(
|
||||
path: str, *, param_name: str, session_id: str
|
||||
) -> ErrorResponse | str:
|
||||
@@ -218,7 +131,7 @@ def _is_text_mime(mime_type: str) -> bool:
|
||||
return any(mime_type.startswith(t) for t in _TEXT_MIME_PREFIXES)
|
||||
|
||||
|
||||
async def get_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||
async def _get_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||
"""Create a session-scoped WorkspaceManager."""
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
return WorkspaceManager(user_id, workspace.id, session_id)
|
||||
@@ -386,7 +299,7 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||
|
||||
try:
|
||||
manager = await get_manager(user_id, session_id)
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
files = await manager.list_files(
|
||||
path=path_prefix, limit=limit, include_all_sessions=include_all_sessions
|
||||
)
|
||||
@@ -432,7 +345,7 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
class ReadWorkspaceFileTool(BaseTool):
|
||||
"""Tool for reading file content from workspace."""
|
||||
|
||||
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB for text/image files
|
||||
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
|
||||
PREVIEW_SIZE = 500
|
||||
|
||||
@property
|
||||
@@ -448,10 +361,8 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
"Specify either file_id or path to identify the file. "
|
||||
"For small text files, returns content directly. "
|
||||
"For large or binary files, returns metadata and a download URL. "
|
||||
"Use 'save_to_path' to copy the file to the working directory "
|
||||
"(sandbox or ephemeral) for processing with bash_exec or file tools. "
|
||||
"Use 'offset' and 'length' for paginated reads of large files "
|
||||
"(e.g., persisted tool outputs). "
|
||||
"Optionally use 'save_to_path' to copy the file to the ephemeral "
|
||||
"working directory for processing with bash_exec or SDK tools. "
|
||||
"Paths are scoped to the current session by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
)
|
||||
@@ -475,10 +386,9 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
"save_to_path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"If provided, save the file to this path in the working "
|
||||
"directory (cloud sandbox when E2B is active, or "
|
||||
"ephemeral dir otherwise) so it can be processed with "
|
||||
"bash_exec or file tools. "
|
||||
"If provided, save the file to this path in the ephemeral "
|
||||
"working directory (e.g., '/tmp/copilot-.../data.csv') "
|
||||
"so it can be processed with bash_exec or SDK tools. "
|
||||
"The file content is still returned in the response."
|
||||
),
|
||||
},
|
||||
@@ -489,20 +399,6 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
"Default is false (auto-selects based on file size/type)."
|
||||
),
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Character offset to start reading from (0-based). "
|
||||
"Use with 'length' for paginated reads of large files."
|
||||
),
|
||||
},
|
||||
"length": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Maximum number of characters to return. "
|
||||
"Defaults to full file. Use with 'offset' for paginated reads."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [], # At least one must be provided
|
||||
}
|
||||
@@ -527,16 +423,23 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
save_to_path: Optional[str] = kwargs.get("save_to_path")
|
||||
force_download_url: bool = kwargs.get("force_download_url", False)
|
||||
char_offset: int = max(0, kwargs.get("offset", 0))
|
||||
char_length: Optional[int] = kwargs.get("length")
|
||||
|
||||
if not file_id and not path:
|
||||
return ErrorResponse(
|
||||
message="Please provide either file_id or path", session_id=session_id
|
||||
)
|
||||
|
||||
# Validate and resolve save_to_path (use sanitized real path).
|
||||
if save_to_path:
|
||||
validated_save = _validate_ephemeral_path(
|
||||
save_to_path, param_name="save_to_path", session_id=session_id
|
||||
)
|
||||
if isinstance(validated_save, ErrorResponse):
|
||||
return validated_save
|
||||
save_to_path = validated_save
|
||||
|
||||
try:
|
||||
manager = await get_manager(user_id, session_id)
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
@@ -546,38 +449,11 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
cached_content: bytes | None = None
|
||||
if save_to_path:
|
||||
cached_content = await manager.read_file_by_id(target_file_id)
|
||||
result = await _save_to_path(save_to_path, cached_content, session_id)
|
||||
if isinstance(result, ErrorResponse):
|
||||
return result
|
||||
save_to_path = result
|
||||
|
||||
# Ranged read: return a character slice directly.
|
||||
if char_offset > 0 or char_length is not None:
|
||||
raw = cached_content or await manager.read_file_by_id(target_file_id)
|
||||
text = raw.decode("utf-8", errors="replace")
|
||||
total_chars = len(text)
|
||||
end = (
|
||||
char_offset + char_length
|
||||
if char_length is not None
|
||||
else total_chars
|
||||
)
|
||||
slice_text = text[char_offset:end]
|
||||
return WorkspaceFileContentResponse(
|
||||
file_id=file_info.id,
|
||||
name=file_info.name,
|
||||
path=file_info.path,
|
||||
mime_type="text/plain",
|
||||
content_base64=base64.b64encode(slice_text.encode("utf-8")).decode(
|
||||
"utf-8"
|
||||
),
|
||||
message=(
|
||||
f"Read chars {char_offset}–"
|
||||
f"{char_offset + len(slice_text)} "
|
||||
f"of {total_chars:,} total "
|
||||
f"from {file_info.name}"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
dir_path = os.path.dirname(save_to_path)
|
||||
if dir_path:
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
with open(save_to_path, "wb") as f:
|
||||
f.write(cached_content)
|
||||
|
||||
is_small = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES
|
||||
is_text = _is_text_mime(file_info.mime_type)
|
||||
@@ -753,7 +629,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
content_text: str | None = kwargs.get("content")
|
||||
content_b64: str | None = kwargs.get("content_base64")
|
||||
|
||||
resolved = await _resolve_write_content(
|
||||
resolved = _resolve_write_content(
|
||||
content_text,
|
||||
content_b64,
|
||||
source_path_arg,
|
||||
@@ -772,7 +648,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
|
||||
try:
|
||||
await scan_content_safe(content, filename=filename)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
rec = await manager.write_file(
|
||||
content=content,
|
||||
filename=filename,
|
||||
@@ -899,7 +775,7 @@ class DeleteWorkspaceFileTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
manager = await get_manager(user_id, session_id)
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
|
||||
@@ -102,68 +102,67 @@ class TestValidateEphemeralPath:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
class TestResolveWriteContent:
|
||||
async def test_no_sources_returns_error(self):
|
||||
def test_no_sources_returns_error(self):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
result = await _resolve_write_content(None, None, None, "s1")
|
||||
result = _resolve_write_content(None, None, None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
async def test_multiple_sources_returns_error(self):
|
||||
def test_multiple_sources_returns_error(self):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
result = await _resolve_write_content("text", "b64data", None, "s1")
|
||||
result = _resolve_write_content("text", "b64data", None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
async def test_plain_text_content(self):
|
||||
result = await _resolve_write_content("hello world", None, None, "s1")
|
||||
def test_plain_text_content(self):
|
||||
result = _resolve_write_content("hello world", None, None, "s1")
|
||||
assert result == b"hello world"
|
||||
|
||||
async def test_base64_content(self):
|
||||
def test_base64_content(self):
|
||||
raw = b"binary data"
|
||||
b64 = base64.b64encode(raw).decode()
|
||||
result = await _resolve_write_content(None, b64, None, "s1")
|
||||
result = _resolve_write_content(None, b64, None, "s1")
|
||||
assert result == raw
|
||||
|
||||
async def test_invalid_base64_returns_error(self):
|
||||
def test_invalid_base64_returns_error(self):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
result = await _resolve_write_content(None, "not-valid-b64!!!", None, "s1")
|
||||
result = _resolve_write_content(None, "not-valid-b64!!!", None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "base64" in result.message.lower()
|
||||
|
||||
async def test_source_path(self, ephemeral_dir):
|
||||
def test_source_path(self, ephemeral_dir):
|
||||
target = ephemeral_dir / "input.txt"
|
||||
target.write_bytes(b"file content")
|
||||
result = await _resolve_write_content(None, None, str(target), "s1")
|
||||
result = _resolve_write_content(None, None, str(target), "s1")
|
||||
assert result == b"file content"
|
||||
|
||||
async def test_source_path_not_found(self, ephemeral_dir):
|
||||
def test_source_path_not_found(self, ephemeral_dir):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
missing = str(ephemeral_dir / "nope.txt")
|
||||
result = await _resolve_write_content(None, None, missing, "s1")
|
||||
result = _resolve_write_content(None, None, missing, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
async def test_source_path_outside_ephemeral(self, ephemeral_dir, tmp_path):
|
||||
def test_source_path_outside_ephemeral(self, ephemeral_dir, tmp_path):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
outside = tmp_path / "outside.txt"
|
||||
outside.write_text("nope")
|
||||
result = await _resolve_write_content(None, None, str(outside), "s1")
|
||||
result = _resolve_write_content(None, None, str(outside), "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
async def test_empty_string_sources_treated_as_none(self):
|
||||
def test_empty_string_sources_treated_as_none(self):
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
# All empty strings → same as no sources
|
||||
result = await _resolve_write_content("", "", "", "s1")
|
||||
result = _resolve_write_content("", "", "", "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
async def test_empty_string_source_path_with_text(self):
|
||||
def test_empty_string_source_path_with_text(self):
|
||||
# source_path="" should be normalised to None, so only content counts
|
||||
result = await _resolve_write_content("hello", "", "", "s1")
|
||||
result = _resolve_write_content("hello", "", "", "s1")
|
||||
assert result == b"hello"
|
||||
|
||||
|
||||
@@ -236,65 +235,6 @@ async def test_workspace_file_round_trip(setup_test_data):
|
||||
assert not any(f.file_id == file_id for f in list_resp2.files)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ranged reads (offset / length)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_read_workspace_file_with_offset_and_length(setup_test_data):
|
||||
"""Read a slice of a text file using offset and length."""
|
||||
user = setup_test_data["user"]
|
||||
session = make_session(user.id)
|
||||
|
||||
# Write a known-content file
|
||||
content = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" * 100 # 2600 chars
|
||||
write_tool = WriteWorkspaceFileTool()
|
||||
write_resp = await write_tool._execute(
|
||||
user_id=user.id,
|
||||
session=session,
|
||||
filename="ranged_test.txt",
|
||||
content=content,
|
||||
)
|
||||
assert isinstance(write_resp, WorkspaceWriteResponse), write_resp.message
|
||||
file_id = write_resp.file_id
|
||||
|
||||
from backend.copilot.tools.workspace_files import WorkspaceFileContentResponse
|
||||
|
||||
read_tool = ReadWorkspaceFileTool()
|
||||
|
||||
# Read with offset=100, length=50
|
||||
resp = await read_tool._execute(
|
||||
user_id=user.id, session=session, file_id=file_id, offset=100, length=50
|
||||
)
|
||||
assert isinstance(resp, WorkspaceFileContentResponse), resp.message
|
||||
decoded = base64.b64decode(resp.content_base64).decode()
|
||||
assert decoded == content[100:150]
|
||||
assert "100" in resp.message
|
||||
assert "2,600" in resp.message # total chars (comma-formatted)
|
||||
|
||||
# Read with offset only (no length) — returns from offset to end
|
||||
resp2 = await read_tool._execute(
|
||||
user_id=user.id, session=session, file_id=file_id, offset=2500
|
||||
)
|
||||
assert isinstance(resp2, WorkspaceFileContentResponse)
|
||||
decoded2 = base64.b64decode(resp2.content_base64).decode()
|
||||
assert decoded2 == content[2500:]
|
||||
assert len(decoded2) == 100
|
||||
|
||||
# Read with offset beyond file length — returns empty string
|
||||
resp3 = await read_tool._execute(
|
||||
user_id=user.id, session=session, file_id=file_id, offset=9999, length=10
|
||||
)
|
||||
assert isinstance(resp3, WorkspaceFileContentResponse)
|
||||
decoded3 = base64.b64decode(resp3.content_base64).decode()
|
||||
assert decoded3 == ""
|
||||
|
||||
# Cleanup
|
||||
delete_tool = DeleteWorkspaceFileTool()
|
||||
await delete_tool._execute(user_id=user.id, session=session, file_id=file_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_write_workspace_file_source_path(setup_test_data):
|
||||
"""E2E: write a file from ephemeral source_path to workspace."""
|
||||
|
||||
@@ -81,7 +81,6 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.CLAUDE_4_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_SONNET: 5,
|
||||
LlmModel.CLAUDE_4_6_OPUS: 14,
|
||||
LlmModel.CLAUDE_4_6_SONNET: 9,
|
||||
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
||||
LlmModel.CLAUDE_4_5_OPUS: 14,
|
||||
LlmModel.CLAUDE_4_5_SONNET: 9,
|
||||
|
||||
@@ -305,7 +305,6 @@ class DatabaseManager(AppService):
|
||||
delete_chat_session = _(chat_db.delete_chat_session)
|
||||
get_next_sequence = _(chat_db.get_next_sequence)
|
||||
update_tool_message_content = _(chat_db.update_tool_message_content)
|
||||
update_chat_session_title = _(chat_db.update_chat_session_title)
|
||||
|
||||
|
||||
class DatabaseManagerClient(AppServiceClient):
|
||||
@@ -476,4 +475,3 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
delete_chat_session = d.delete_chat_session
|
||||
get_next_sequence = d.get_next_sequence
|
||||
update_tool_message_content = d.update_tool_message_content
|
||||
update_chat_session_title = d.update_chat_session_title
|
||||
|
||||
@@ -184,17 +184,17 @@ async def find_webhook_by_credentials_and_props(
|
||||
credentials_id: str,
|
||||
webhook_type: str,
|
||||
resource: str,
|
||||
events: list[str] | None = None,
|
||||
events: Optional[list[str]],
|
||||
) -> Webhook | None:
|
||||
where: IntegrationWebhookWhereInput = {
|
||||
"userId": user_id,
|
||||
"credentialsId": credentials_id,
|
||||
"webhookType": webhook_type,
|
||||
"resource": resource,
|
||||
}
|
||||
if events is not None:
|
||||
where["events"] = {"has_every": events}
|
||||
webhook = await IntegrationWebhook.prisma().find_first(where=where)
|
||||
webhook = await IntegrationWebhook.prisma().find_first(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"credentialsId": credentials_id,
|
||||
"webhookType": webhook_type,
|
||||
"resource": resource,
|
||||
**({"events": {"has_every": events}} if events else {}),
|
||||
},
|
||||
)
|
||||
return Webhook.from_db(webhook) if webhook else None
|
||||
|
||||
|
||||
|
||||
@@ -327,16 +327,11 @@ async def get_workspace_total_size(workspace_id: str) -> int:
|
||||
"""
|
||||
Get the total size of all files in a workspace.
|
||||
|
||||
Queries Prisma directly (skipping Pydantic model conversion) and only
|
||||
fetches the ``sizeBytes`` column to minimise data transfer.
|
||||
|
||||
Args:
|
||||
workspace_id: The workspace ID
|
||||
|
||||
Returns:
|
||||
Total size in bytes
|
||||
"""
|
||||
files = await UserWorkspaceFile.prisma().find_many(
|
||||
where={"workspaceId": workspace_id, "isDeleted": False},
|
||||
)
|
||||
return sum(f.sizeBytes for f in files)
|
||||
files = await list_workspace_files(workspace_id)
|
||||
return sum(file.size_bytes for file in files)
|
||||
|
||||
@@ -32,7 +32,6 @@ from backend.data.execution import (
|
||||
from backend.data.graph import GraphModel, Node
|
||||
from backend.data.model import USER_TIMEZONE_NOT_SET, CredentialsMetaInput, GraphInput
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
get_async_execution_queue,
|
||||
@@ -892,7 +891,6 @@ async def add_graph_execution(
|
||||
if execution_context is None:
|
||||
user = await udb.get_user_by_id(user_id)
|
||||
settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id)
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
execution_context = ExecutionContext(
|
||||
# Execution identity
|
||||
@@ -909,8 +907,6 @@ async def add_graph_execution(
|
||||
),
|
||||
# Execution hierarchy
|
||||
root_execution_id=graph_exec.id,
|
||||
# Workspace (enables workspace:// file resolution in blocks)
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -368,12 +368,6 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
mock_workspace = mocker.MagicMock()
|
||||
mock_workspace.id = "test-workspace-id"
|
||||
mocker.patch(
|
||||
"backend.executor.utils.get_or_create_workspace",
|
||||
new=mocker.AsyncMock(return_value=mock_workspace),
|
||||
)
|
||||
|
||||
# Setup mock returns
|
||||
# The function returns (graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip)
|
||||
@@ -649,12 +643,6 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
mock_workspace = mocker.MagicMock()
|
||||
mock_workspace.id = "test-workspace-id"
|
||||
mocker.patch(
|
||||
"backend.executor.utils.get_or_create_workspace",
|
||||
new=mocker.AsyncMock(return_value=mock_workspace),
|
||||
)
|
||||
|
||||
# Setup returns - include nodes_to_skip in the tuple
|
||||
mock_validate.return_value = (
|
||||
@@ -693,10 +681,6 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
||||
assert "nodes_to_skip" in captured_kwargs
|
||||
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip
|
||||
|
||||
# Verify workspace_id is set in the execution context
|
||||
assert "execution_context" in captured_kwargs
|
||||
assert captured_kwargs["execution_context"].workspace_id == "test-workspace-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_graph_execution_in_review_status_cancels_pending_reviews(
|
||||
|
||||
@@ -76,6 +76,7 @@ class TelegramWebhooksManager(BaseWebhooksManager):
|
||||
credentials_id=credentials.id,
|
||||
webhook_type=webhook_type,
|
||||
resource=resource,
|
||||
events=None, # Ignore events for this lookup
|
||||
):
|
||||
# Re-register with Telegram using the same URL but new allowed_updates
|
||||
ingress_url = webhook_ingress_url(self.PROVIDER_NAME, existing.id)
|
||||
@@ -142,6 +143,10 @@ class TelegramWebhooksManager(BaseWebhooksManager):
|
||||
elif "video" in message:
|
||||
event_type = "message.video"
|
||||
else:
|
||||
logger.warning(
|
||||
"Unknown Telegram webhook payload type; "
|
||||
f"message.keys() = {message.keys()}"
|
||||
)
|
||||
event_type = "message.other"
|
||||
elif "edited_message" in payload:
|
||||
event_type = "message.edited_message"
|
||||
|
||||
@@ -105,13 +105,8 @@ def validate_with_jsonschema(
|
||||
return str(e)
|
||||
|
||||
|
||||
def sanitize_string(value: str) -> str:
|
||||
"""Remove PostgreSQL-incompatible control characters from string.
|
||||
|
||||
Strips \\x00-\\x08, \\x0B-\\x0C, \\x0E-\\x1F, \\x7F while keeping tab,
|
||||
newline, and carriage return. Use this before inserting free-form text
|
||||
into PostgreSQL text/varchar columns.
|
||||
"""
|
||||
def _sanitize_string(value: str) -> str:
|
||||
"""Remove PostgreSQL-incompatible control characters from string."""
|
||||
return POSTGRES_CONTROL_CHARS.sub("", value)
|
||||
|
||||
|
||||
@@ -121,7 +116,7 @@ def sanitize_json(data: Any) -> Any:
|
||||
# 1. First convert to basic JSON-serializable types (handles Pydantic models)
|
||||
# 2. Then sanitize strings in the result
|
||||
basic_result = to_dict(data)
|
||||
return to_dict(basic_result, custom_encoder={str: sanitize_string})
|
||||
return to_dict(basic_result, custom_encoder={str: _sanitize_string})
|
||||
except Exception as e:
|
||||
# Log the failure and fall back to string representation
|
||||
logger.error(
|
||||
@@ -134,7 +129,7 @@ def sanitize_json(data: Any) -> Any:
|
||||
)
|
||||
|
||||
# Ultimate fallback: convert to string representation and sanitize
|
||||
return sanitize_string(str(data))
|
||||
return _sanitize_string(str(data))
|
||||
|
||||
|
||||
class SafeJson(Json):
|
||||
|
||||
@@ -219,11 +219,8 @@ def parse_url(url: str) -> URL:
|
||||
"""Canonicalizes and parses a URL string."""
|
||||
url = url.strip("/ ").replace("\\", "/")
|
||||
|
||||
# Ensure scheme is present for proper parsing.
|
||||
# Avoid regex to sidestep CodeQL py/polynomial-redos on user-controlled
|
||||
# input. We only need to detect "scheme://"; urlparse() and the
|
||||
# ALLOWED_SCHEMES check downstream handle full validation.
|
||||
if "://" not in url:
|
||||
# Ensure scheme is present for proper parsing
|
||||
if not re.match(r"[a-z0-9+.\-]+://", url):
|
||||
url = f"http://{url}"
|
||||
|
||||
return urlparse(url)
|
||||
|
||||
@@ -4,7 +4,14 @@ import re
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Generic, List, Set, Tuple, Type, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr, ValidationInfo, field_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
SecretStr,
|
||||
ValidationInfo,
|
||||
field_validator,
|
||||
)
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
JsonConfigSettingsSource,
|
||||
@@ -413,13 +420,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="Maximum file size in MB for workspace files (1-1024 MB)",
|
||||
)
|
||||
|
||||
max_workspace_storage_mb: int = Field(
|
||||
default=500,
|
||||
ge=1,
|
||||
le=10240,
|
||||
description="Maximum total workspace storage per user in MB.",
|
||||
)
|
||||
|
||||
# AutoMod configuration
|
||||
automod_enabled: bool = Field(
|
||||
default=False,
|
||||
@@ -728,12 +728,17 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||
langfuse_public_key: str = Field(default="", description="Langfuse public key")
|
||||
langfuse_secret_key: str = Field(default="", description="Langfuse secret key")
|
||||
langfuse_host: str = Field(
|
||||
default="https://cloud.langfuse.com", description="Langfuse host URL"
|
||||
default="https://cloud.langfuse.com",
|
||||
description="Langfuse host URL for fetching prompts",
|
||||
)
|
||||
langfuse_tracing_environment: str = Field(
|
||||
default="local", description="Tracing environment tag (local/dev/production)"
|
||||
otlp_tracing_host: str = Field(
|
||||
default="",
|
||||
description="OTLP trace ingestion host URL (for example, Product Intelligence).",
|
||||
)
|
||||
otlp_tracing_token: SecretStr = Field(
|
||||
default=SecretStr(""),
|
||||
description="Bearer token for OTLP trace ingestion endpoint.",
|
||||
)
|
||||
|
||||
# PostHog analytics
|
||||
posthog_api_key: str = Field(default="", description="PostHog API key")
|
||||
posthog_host: str = Field(
|
||||
|
||||
@@ -79,10 +79,6 @@ def truncate(value: Any, size_limit: int) -> Any:
|
||||
largest str_limit and list_limit that fit.
|
||||
"""
|
||||
|
||||
# Fast path: plain strings don't need the binary search machinery.
|
||||
if isinstance(value, str):
|
||||
return _truncate_string_middle(value, size_limit)
|
||||
|
||||
def measure(val):
|
||||
try:
|
||||
return len(str(val))
|
||||
@@ -90,7 +86,7 @@ def truncate(value: Any, size_limit: int) -> Any:
|
||||
return sys.getsizeof(val)
|
||||
|
||||
# Reasonable bounds for string and list limits
|
||||
STR_MIN, STR_MAX = min(8, size_limit), size_limit
|
||||
STR_MIN, STR_MAX = 8, 2**16
|
||||
LIST_MIN, LIST_MAX = 1, 2**12
|
||||
|
||||
# Binary search for the largest str_limit and list_limit that fit
|
||||
|
||||
@@ -228,43 +228,52 @@ class WorkspaceManager:
|
||||
|
||||
# Create database record - handle race condition where another request
|
||||
# created a file at the same path between our check and create
|
||||
async def _persist_db_record(
|
||||
retries: int = 2 if overwrite else 0,
|
||||
) -> WorkspaceFile:
|
||||
"""Create DB record, retrying on conflict if overwrite=True.
|
||||
|
||||
Cleans up the orphaned storage file on any failure.
|
||||
"""
|
||||
try:
|
||||
return await db.create_workspace_file(
|
||||
workspace_id=self.workspace_id,
|
||||
file_id=file_id,
|
||||
name=filename,
|
||||
path=path,
|
||||
storage_path=storage_path,
|
||||
mime_type=mime_type,
|
||||
size_bytes=len(content),
|
||||
checksum=checksum,
|
||||
)
|
||||
except UniqueViolationError:
|
||||
if retries > 0:
|
||||
# Delete conflicting file and retry
|
||||
existing = await db.get_workspace_file_by_path(
|
||||
self.workspace_id, path
|
||||
)
|
||||
if existing:
|
||||
await self.delete_file(existing.id)
|
||||
return await _persist_db_record(retries=retries - 1)
|
||||
if overwrite:
|
||||
raise ValueError(
|
||||
f"Unable to overwrite file at path: {path} "
|
||||
f"(concurrent write conflict)"
|
||||
) from None
|
||||
raise ValueError(f"File already exists at path: {path}")
|
||||
|
||||
try:
|
||||
file = await _persist_db_record()
|
||||
file = await db.create_workspace_file(
|
||||
workspace_id=self.workspace_id,
|
||||
file_id=file_id,
|
||||
name=filename,
|
||||
path=path,
|
||||
storage_path=storage_path,
|
||||
mime_type=mime_type,
|
||||
size_bytes=len(content),
|
||||
checksum=checksum,
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# Race condition: another request created a file at this path
|
||||
if overwrite:
|
||||
# Re-fetch and delete the conflicting file, then retry
|
||||
existing = await db.get_workspace_file_by_path(self.workspace_id, path)
|
||||
if existing:
|
||||
await self.delete_file(existing.id)
|
||||
# Retry the create - if this also fails, clean up storage file
|
||||
try:
|
||||
file = await db.create_workspace_file(
|
||||
workspace_id=self.workspace_id,
|
||||
file_id=file_id,
|
||||
name=filename,
|
||||
path=path,
|
||||
storage_path=storage_path,
|
||||
mime_type=mime_type,
|
||||
size_bytes=len(content),
|
||||
checksum=checksum,
|
||||
)
|
||||
except Exception:
|
||||
# Clean up orphaned storage file on retry failure
|
||||
try:
|
||||
await storage.delete(storage_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up orphaned storage file: {e}")
|
||||
raise
|
||||
else:
|
||||
# Clean up the orphaned storage file before raising
|
||||
try:
|
||||
await storage.delete(storage_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up orphaned storage file: {e}")
|
||||
raise ValueError(f"File already exists at path: {path}")
|
||||
except Exception:
|
||||
# Any other database error (connection, validation, etc.) - clean up storage
|
||||
try:
|
||||
await storage.delete(storage_path)
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,158 +0,0 @@
|
||||
"""
|
||||
Tests for WorkspaceManager.write_file UniqueViolationError handling.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.errors import UniqueViolationError
|
||||
|
||||
from backend.data.workspace import WorkspaceFile
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
|
||||
_NOW = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _make_workspace_file(
|
||||
id: str = "existing-file-id",
|
||||
workspace_id: str = "ws-123",
|
||||
name: str = "test.txt",
|
||||
path: str = "/test.txt",
|
||||
storage_path: str = "ws-123/existing-uuid/test.txt",
|
||||
mime_type: str = "text/plain",
|
||||
size_bytes: int = 5,
|
||||
checksum: str = "abc123",
|
||||
) -> WorkspaceFile:
|
||||
"""Create a mock WorkspaceFile with sensible defaults."""
|
||||
return WorkspaceFile(
|
||||
id=id,
|
||||
workspace_id=workspace_id,
|
||||
name=name,
|
||||
path=path,
|
||||
storage_path=storage_path,
|
||||
mime_type=mime_type,
|
||||
size_bytes=size_bytes,
|
||||
checksum=checksum,
|
||||
metadata={},
|
||||
created_at=_NOW,
|
||||
updated_at=_NOW,
|
||||
)
|
||||
|
||||
|
||||
def _unique_violation() -> UniqueViolationError:
|
||||
"""Create a UniqueViolationError for testing."""
|
||||
data = {
|
||||
"user_facing_error": {
|
||||
"message": "Unique constraint failed on the fields: (`path`)",
|
||||
}
|
||||
}
|
||||
return UniqueViolationError(data)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager():
|
||||
return WorkspaceManager(user_id="user-123", workspace_id="ws-123")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage():
|
||||
storage = AsyncMock()
|
||||
storage.store.return_value = "ws-123/some-uuid/test.txt"
|
||||
storage.delete = AsyncMock()
|
||||
return storage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db():
|
||||
"""Create a mock workspace_db() return value."""
|
||||
db = MagicMock()
|
||||
db.create_workspace_file = AsyncMock()
|
||||
db.get_workspace_file_by_path = AsyncMock()
|
||||
db.get_workspace_file = AsyncMock()
|
||||
return db
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_no_overwrite_unique_violation_raises_and_cleans_up(
|
||||
manager, mock_storage, mock_db
|
||||
):
|
||||
"""overwrite=False + UniqueViolationError → ValueError + storage cleanup."""
|
||||
mock_db.get_workspace_file_by_path.return_value = None
|
||||
mock_db.create_workspace_file.side_effect = _unique_violation()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
|
||||
):
|
||||
with pytest.raises(ValueError, match="File already exists"):
|
||||
await manager.write_file(
|
||||
filename="test.txt", content=b"hello", overwrite=False
|
||||
)
|
||||
|
||||
mock_storage.delete.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_overwrite_conflict_then_retry_succeeds(
|
||||
manager, mock_storage, mock_db
|
||||
):
|
||||
"""overwrite=True + conflict → delete existing → retry succeeds."""
|
||||
created_file = _make_workspace_file()
|
||||
existing_file = _make_workspace_file(id="old-id")
|
||||
|
||||
mock_db.get_workspace_file_by_path.return_value = existing_file
|
||||
mock_db.create_workspace_file.side_effect = [_unique_violation(), created_file]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
|
||||
patch.object(manager, "delete_file", new_callable=AsyncMock) as mock_delete,
|
||||
):
|
||||
result = await manager.write_file(
|
||||
filename="test.txt", content=b"hello", overwrite=True
|
||||
)
|
||||
|
||||
assert result == created_file
|
||||
mock_delete.assert_called_once_with("old-id")
|
||||
mock_storage.delete.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_overwrite_exhausted_retries_raises_and_cleans_up(
|
||||
manager, mock_storage, mock_db
|
||||
):
|
||||
"""overwrite=True + all retries exhausted → ValueError + cleanup."""
|
||||
existing_file = _make_workspace_file(id="old-id")
|
||||
|
||||
mock_db.get_workspace_file_by_path.return_value = existing_file
|
||||
# Initial + 2 retries = 3 UniqueViolationErrors
|
||||
mock_db.create_workspace_file.side_effect = [
|
||||
_unique_violation(),
|
||||
_unique_violation(),
|
||||
_unique_violation(),
|
||||
]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.util.workspace.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch("backend.util.workspace.workspace_db", return_value=mock_db),
|
||||
patch("backend.util.workspace.scan_content_safe", new_callable=AsyncMock),
|
||||
patch.object(manager, "delete_file", new_callable=AsyncMock),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Unable to overwrite.*concurrent write"):
|
||||
await manager.write_file(
|
||||
filename="test.txt", content=b"hello", overwrite=True
|
||||
)
|
||||
|
||||
mock_storage.delete.assert_called_once()
|
||||
229
autogpt_platform/backend/poetry.lock
generated
229
autogpt_platform/backend/poetry.lock
generated
@@ -899,17 +899,17 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "claude-agent-sdk"
|
||||
version = "0.1.45"
|
||||
version = "0.1.39"
|
||||
description = "Python SDK for Claude Code"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "claude_agent_sdk-0.1.45-py3-none-macosx_11_0_arm64.whl", hash = "sha256:26a5cc60c3a394f5b814f6b2f67650819cbcd38c405bbdc11582b3e097b3a770"},
|
||||
{file = "claude_agent_sdk-0.1.45-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:decc741b53e0b2c10a64fd84c15acca1102077d9f99941c54905172cd95160c9"},
|
||||
{file = "claude_agent_sdk-0.1.45-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:7d48dcf4178c704e4ccbf3f1f4ebf20b3de3f03d0592086c1f3abd16b8ca441e"},
|
||||
{file = "claude_agent_sdk-0.1.45-py3-none-win_amd64.whl", hash = "sha256:d1cf34995109c513d8daabcae7208edc260b553b53462a9ac06a7c40e240a288"},
|
||||
{file = "claude_agent_sdk-0.1.45.tar.gz", hash = "sha256:97c1e981431b5af1e08c34731906ab8d4a58fe0774a04df0ea9587dcabc85151"},
|
||||
{file = "claude_agent_sdk-0.1.39-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6ed6a79781f545b761b9fe467bc5ae213a103c9d3f0fe7a9dad3c01790ed58fa"},
|
||||
{file = "claude_agent_sdk-0.1.39-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:0c03b5a3772eaec42e29ea39240c7d24b760358082f2e36336db9e71dde3dda4"},
|
||||
{file = "claude_agent_sdk-0.1.39-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:d2665c9e87b6ffece590bcdd6eb9def47cde4809b0d2f66e0a61a719189be7c9"},
|
||||
{file = "claude_agent_sdk-0.1.39-py3-none-win_amd64.whl", hash = "sha256:d03324daf7076be79d2dd05944559aabf4cc11c98d3a574b992a442a7c7a26d6"},
|
||||
{file = "claude_agent_sdk-0.1.39.tar.gz", hash = "sha256:dcf0ebd5a638c9a7d9f3af7640932a9212b2705b7056e4f08bd3968a865b4268"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -3230,39 +3230,6 @@ pydantic = ">=1.10.7,<3.0"
|
||||
requests = ">=2,<3"
|
||||
wrapt = ">=1.14,<2.0"
|
||||
|
||||
[[package]]
|
||||
name = "langsmith"
|
||||
version = "0.7.7"
|
||||
description = "Client library to connect to the LangSmith Observability and Evaluation Platform."
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "langsmith-0.7.7-py3-none-any.whl", hash = "sha256:ef3d0aff77917bf3776368e90f387df5ffd7cb7cff11ece0ec4fd227e433b5de"},
|
||||
{file = "langsmith-0.7.7.tar.gz", hash = "sha256:2294d3c4a5a8205ef38880c1c412d85322e6055858ae999ef6641c815995d437"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = ">=0.23.0,<1"
|
||||
orjson = {version = ">=3.9.14", markers = "platform_python_implementation != \"PyPy\""}
|
||||
packaging = ">=23.2"
|
||||
pydantic = ">=2,<3"
|
||||
requests = ">=2.0.0"
|
||||
requests-toolbelt = ">=1.0.0"
|
||||
uuid-utils = ">=0.12.0,<1.0"
|
||||
xxhash = ">=3.0.0"
|
||||
zstandard = ">=0.23.0"
|
||||
|
||||
[package.extras]
|
||||
claude-agent-sdk = ["claude-agent-sdk (>=0.1.0) ; python_version >= \"3.10\""]
|
||||
google-adk = ["google-adk (>=1.0.0)", "wrapt (>=1.16.0)"]
|
||||
langsmith-pyo3 = ["langsmith-pyo3 (>=0.1.0rc2)"]
|
||||
openai-agents = ["openai-agents (>=0.0.3)"]
|
||||
otel = ["opentelemetry-api (>=1.30.0)", "opentelemetry-exporter-otlp-proto-http (>=1.30.0)", "opentelemetry-sdk (>=1.30.0)"]
|
||||
pytest = ["pytest (>=7.0.0)", "rich (>=13.9.4)", "vcrpy (>=7.0.0)"]
|
||||
sandbox = ["websockets (>=15.0)"]
|
||||
vcr = ["vcrpy (>=7.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "launchdarkly-eventsource"
|
||||
version = "1.5.1"
|
||||
@@ -7780,38 +7747,6 @@ h2 = ["h2 (>=4,<5)"]
|
||||
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
|
||||
zstd = ["backports-zstd (>=1.0.0) ; python_version < \"3.14\""]
|
||||
|
||||
[[package]]
|
||||
name = "uuid-utils"
|
||||
version = "0.14.1"
|
||||
description = "Fast, drop-in replacement for Python's uuid module, powered by Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "uuid_utils-0.14.1-cp39-abi3-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:93a3b5dc798a54a1feb693f2d1cb4cf08258c32ff05ae4929b5f0a2ca624a4f0"},
|
||||
{file = "uuid_utils-0.14.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:ccd65a4b8e83af23eae5e56d88034b2fe7264f465d3e830845f10d1591b81741"},
|
||||
{file = "uuid_utils-0.14.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b56b0cacd81583834820588378e432b0696186683b813058b707aedc1e16c4b1"},
|
||||
{file = "uuid_utils-0.14.1-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bb3cf14de789097320a3c56bfdfdd51b1225d11d67298afbedee7e84e3837c96"},
|
||||
{file = "uuid_utils-0.14.1-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60e0854a90d67f4b0cc6e54773deb8be618f4c9bad98d3326f081423b5d14fae"},
|
||||
{file = "uuid_utils-0.14.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce6743ba194de3910b5feb1a62590cd2587e33a73ab6af8a01b642ceb5055862"},
|
||||
{file = "uuid_utils-0.14.1-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:043fb58fde6cf1620a6c066382f04f87a8e74feb0f95a585e4ed46f5d44af57b"},
|
||||
{file = "uuid_utils-0.14.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:c915d53f22945e55fe0d3d3b0b87fd965a57f5fd15666fd92d6593a73b1dd297"},
|
||||
{file = "uuid_utils-0.14.1-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:0972488e3f9b449e83f006ead5a0e0a33ad4a13e4462e865b7c286ab7d7566a3"},
|
||||
{file = "uuid_utils-0.14.1-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:1c238812ae0c8ffe77d8d447a32c6dfd058ea4631246b08b5a71df586ff08531"},
|
||||
{file = "uuid_utils-0.14.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:bec8f8ef627af86abf8298e7ec50926627e29b34fa907fcfbedb45aaa72bca43"},
|
||||
{file = "uuid_utils-0.14.1-cp39-abi3-win32.whl", hash = "sha256:b54d6aa6252d96bac1fdbc80d26ba71bad9f220b2724d692ad2f2310c22ef523"},
|
||||
{file = "uuid_utils-0.14.1-cp39-abi3-win_amd64.whl", hash = "sha256:fc27638c2ce267a0ce3e06828aff786f91367f093c80625ee21dad0208e0f5ba"},
|
||||
{file = "uuid_utils-0.14.1-cp39-abi3-win_arm64.whl", hash = "sha256:b04cb49b42afbc4ff8dbc60cf054930afc479d6f4dd7f1ec3bbe5dbfdde06b7a"},
|
||||
{file = "uuid_utils-0.14.1-pp311-pypy311_pp73-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:b197cd5424cf89fb019ca7f53641d05bfe34b1879614bed111c9c313b5574cd8"},
|
||||
{file = "uuid_utils-0.14.1-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:12c65020ba6cb6abe1d57fcbfc2d0ea0506c67049ee031714057f5caf0f9bc9c"},
|
||||
{file = "uuid_utils-0.14.1-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b5d2ad28063d422ccc2c28d46471d47b61a58de885d35113a8f18cb547e25bf"},
|
||||
{file = "uuid_utils-0.14.1-pp311-pypy311_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:da2234387b45fde40b0fedfee64a0ba591caeea9c48c7698ab6e2d85c7991533"},
|
||||
{file = "uuid_utils-0.14.1-pp311-pypy311_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:50fffc2827348c1e48972eed3d1c698959e63f9d030aa5dd82ba451113158a62"},
|
||||
{file = "uuid_utils-0.14.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1dbe718765f70f5b7f9b7f66b6a937802941b1cc56bcf642ce0274169741e01"},
|
||||
{file = "uuid_utils-0.14.1-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:258186964039a8e36db10810c1ece879d229b01331e09e9030bc5dcabe231bd2"},
|
||||
{file = "uuid_utils-0.14.1.tar.gz", hash = "sha256:9bfc95f64af80ccf129c604fb6b8ca66c6f256451e32bc4570f760e4309c9b69"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "uvicorn"
|
||||
version = "0.40.0"
|
||||
@@ -8357,156 +8292,6 @@ cffi = ">=1.16.0"
|
||||
[package.extras]
|
||||
test = ["pytest"]
|
||||
|
||||
[[package]]
|
||||
name = "xxhash"
|
||||
version = "3.6.0"
|
||||
description = "Python binding for xxHash"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "xxhash-3.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:87ff03d7e35c61435976554477a7f4cd1704c3596a89a8300d5ce7fc83874a71"},
|
||||
{file = "xxhash-3.6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f572dfd3d0e2eb1a57511831cf6341242f5a9f8298a45862d085f5b93394a27d"},
|
||||
{file = "xxhash-3.6.0-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:89952ea539566b9fed2bbd94e589672794b4286f342254fad28b149f9615fef8"},
|
||||
{file = "xxhash-3.6.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:48e6f2ffb07a50b52465a1032c3cf1f4a5683f944acaca8a134a2f23674c2058"},
|
||||
{file = "xxhash-3.6.0-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:b5b848ad6c16d308c3ac7ad4ba6bede80ed5df2ba8ed382f8932df63158dd4b2"},
|
||||
{file = "xxhash-3.6.0-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a034590a727b44dd8ac5914236a7b8504144447a9682586c3327e935f33ec8cc"},
|
||||
{file = "xxhash-3.6.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8a8f1972e75ebdd161d7896743122834fe87378160c20e97f8b09166213bf8cc"},
|
||||
{file = "xxhash-3.6.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ee34327b187f002a596d7b167ebc59a1b729e963ce645964bbc050d2f1b73d07"},
|
||||
{file = "xxhash-3.6.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:339f518c3c7a850dd033ab416ea25a692759dc7478a71131fe8869010d2b75e4"},
|
||||
{file = "xxhash-3.6.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:bf48889c9630542d4709192578aebbd836177c9f7a4a2778a7d6340107c65f06"},
|
||||
{file = "xxhash-3.6.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:5576b002a56207f640636056b4160a378fe36a58db73ae5c27a7ec8db35f71d4"},
|
||||
{file = "xxhash-3.6.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:af1f3278bd02814d6dedc5dec397993b549d6f16c19379721e5a1d31e132c49b"},
|
||||
{file = "xxhash-3.6.0-cp310-cp310-win32.whl", hash = "sha256:aed058764db109dc9052720da65fafe84873b05eb8b07e5e653597951af57c3b"},
|
||||
{file = "xxhash-3.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:e82da5670f2d0d98950317f82a0e4a0197150ff19a6df2ba40399c2a3b9ae5fb"},
|
||||
{file = "xxhash-3.6.0-cp310-cp310-win_arm64.whl", hash = "sha256:4a082ffff8c6ac07707fb6b671caf7c6e020c75226c561830b73d862060f281d"},
|
||||
{file = "xxhash-3.6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b47bbd8cf2d72797f3c2772eaaac0ded3d3af26481a26d7d7d41dc2d3c46b04a"},
|
||||
{file = "xxhash-3.6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2b6821e94346f96db75abaa6e255706fb06ebd530899ed76d32cd99f20dc52fa"},
|
||||
{file = "xxhash-3.6.0-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:d0a9751f71a1a65ce3584e9cae4467651c7e70c9d31017fa57574583a4540248"},
|
||||
{file = "xxhash-3.6.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8b29ee68625ab37b04c0b40c3fafdf24d2f75ccd778333cfb698f65f6c463f62"},
|
||||
{file = "xxhash-3.6.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:6812c25fe0d6c36a46ccb002f40f27ac903bf18af9f6dd8f9669cb4d176ab18f"},
|
||||
{file = "xxhash-3.6.0-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:4ccbff013972390b51a18ef1255ef5ac125c92dc9143b2d1909f59abc765540e"},
|
||||
{file = "xxhash-3.6.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:297b7fbf86c82c550e12e8fb71968b3f033d27b874276ba3624ea868c11165a8"},
|
||||
{file = "xxhash-3.6.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dea26ae1eb293db089798d3973a5fc928a18fdd97cc8801226fae705b02b14b0"},
|
||||
{file = "xxhash-3.6.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:7a0b169aafb98f4284f73635a8e93f0735f9cbde17bd5ec332480484241aaa77"},
|
||||
{file = "xxhash-3.6.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:08d45aef063a4531b785cd72de4887766d01dc8f362a515693df349fdb825e0c"},
|
||||
{file = "xxhash-3.6.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:929142361a48ee07f09121fe9e96a84950e8d4df3bb298ca5d88061969f34d7b"},
|
||||
{file = "xxhash-3.6.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:51312c768403d8540487dbbfb557454cfc55589bbde6424456951f7fcd4facb3"},
|
||||
{file = "xxhash-3.6.0-cp311-cp311-win32.whl", hash = "sha256:d1927a69feddc24c987b337ce81ac15c4720955b667fe9b588e02254b80446fd"},
|
||||
{file = "xxhash-3.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:26734cdc2d4ffe449b41d186bbeac416f704a482ed835d375a5c0cb02bc63fef"},
|
||||
{file = "xxhash-3.6.0-cp311-cp311-win_arm64.whl", hash = "sha256:d72f67ef8bf36e05f5b6c65e8524f265bd61071471cd4cf1d36743ebeeeb06b7"},
|
||||
{file = "xxhash-3.6.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:01362c4331775398e7bb34e3ab403bc9ee9f7c497bc7dee6272114055277dd3c"},
|
||||
{file = "xxhash-3.6.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b7b2df81a23f8cb99656378e72501b2cb41b1827c0f5a86f87d6b06b69f9f204"},
|
||||
{file = "xxhash-3.6.0-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:dc94790144e66b14f67b10ac8ed75b39ca47536bf8800eb7c24b50271ea0c490"},
|
||||
{file = "xxhash-3.6.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:93f107c673bccf0d592cdba077dedaf52fe7f42dcd7676eba1f6d6f0c3efffd2"},
|
||||
{file = "xxhash-3.6.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2aa5ee3444c25b69813663c9f8067dcfaa2e126dc55e8dddf40f4d1c25d7effa"},
|
||||
{file = "xxhash-3.6.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:f7f99123f0e1194fa59cc69ad46dbae2e07becec5df50a0509a808f90a0f03f0"},
|
||||
{file = "xxhash-3.6.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:49e03e6fe2cac4a1bc64952dd250cf0dbc5ef4ebb7b8d96bce82e2de163c82a2"},
|
||||
{file = "xxhash-3.6.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bd17fede52a17a4f9a7bc4472a5867cb0b160deeb431795c0e4abe158bc784e9"},
|
||||
{file = "xxhash-3.6.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:6fb5f5476bef678f69db04f2bd1efbed3030d2aba305b0fc1773645f187d6a4e"},
|
||||
{file = "xxhash-3.6.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:843b52f6d88071f87eba1631b684fcb4b2068cd2180a0224122fe4ef011a9374"},
|
||||
{file = "xxhash-3.6.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:7d14a6cfaf03b1b6f5f9790f76880601ccc7896aff7ab9cd8978a939c1eb7e0d"},
|
||||
{file = "xxhash-3.6.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:418daf3db71e1413cfe211c2f9a528456936645c17f46b5204705581a45390ae"},
|
||||
{file = "xxhash-3.6.0-cp312-cp312-win32.whl", hash = "sha256:50fc255f39428a27299c20e280d6193d8b63b8ef8028995323bf834a026b4fbb"},
|
||||
{file = "xxhash-3.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:c0f2ab8c715630565ab8991b536ecded9416d615538be8ecddce43ccf26cbc7c"},
|
||||
{file = "xxhash-3.6.0-cp312-cp312-win_arm64.whl", hash = "sha256:eae5c13f3bc455a3bbb68bdc513912dc7356de7e2280363ea235f71f54064829"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:599e64ba7f67472481ceb6ee80fa3bd828fd61ba59fb11475572cc5ee52b89ec"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7d8b8aaa30fca4f16f0c84a5c8d7ddee0e25250ec2796c973775373257dde8f1"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:d597acf8506d6e7101a4a44a5e428977a51c0fadbbfd3c39650cca9253f6e5a6"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:858dc935963a33bc33490128edc1c12b0c14d9c7ebaa4e387a7869ecc4f3e263"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:ba284920194615cb8edf73bf52236ce2e1664ccd4a38fdb543506413529cc546"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:4b54219177f6c6674d5378bd862c6aedf64725f70dd29c472eaae154df1a2e89"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:42c36dd7dbad2f5238950c377fcbf6811b1cdb1c444fab447960030cea60504d"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f22927652cba98c44639ffdc7aaf35828dccf679b10b31c4ad72a5b530a18eb7"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b45fad44d9c5c119e9c6fbf2e1c656a46dc68e280275007bbfd3d572b21426db"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:6f2580ffab1a8b68ef2b901cde7e55fa8da5e4be0977c68f78fc80f3c143de42"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:40c391dd3cd041ebc3ffe6f2c862f402e306eb571422e0aa918d8070ba31da11"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f205badabde7aafd1a31e8ca2a3e5a763107a71c397c4481d6a804eb5063d8bd"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313-win32.whl", hash = "sha256:2577b276e060b73b73a53042ea5bd5203d3e6347ce0d09f98500f418a9fcf799"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313-win_amd64.whl", hash = "sha256:757320d45d2fbcce8f30c42a6b2f47862967aea7bf458b9625b4bbe7ee390392"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313-win_arm64.whl", hash = "sha256:457b8f85dec5825eed7b69c11ae86834a018b8e3df5e77783c999663da2f96d6"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:a42e633d75cdad6d625434e3468126c73f13f7584545a9cf34e883aa1710e702"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:568a6d743219e717b07b4e03b0a828ce593833e498c3b64752e0f5df6bfe84db"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:bec91b562d8012dae276af8025a55811b875baace6af510412a5e58e3121bc54"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:78e7f2f4c521c30ad5e786fdd6bae89d47a32672a80195467b5de0480aa97b1f"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:3ed0df1b11a79856df5ffcab572cbd6b9627034c1c748c5566fa79df9048a7c5"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:0e4edbfc7d420925b0dd5e792478ed393d6e75ff8fc219a6546fb446b6a417b1"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fba27a198363a7ef87f8c0f6b171ec36b674fe9053742c58dd7e3201c1ab30ee"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:794fe9145fe60191c6532fa95063765529770edcdd67b3d537793e8004cabbfd"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:6105ef7e62b5ac73a837778efc331a591d8442f8ef5c7e102376506cb4ae2729"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:f01375c0e55395b814a679b3eea205db7919ac2af213f4a6682e01220e5fe292"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:d706dca2d24d834a4661619dcacf51a75c16d65985718d6a7d73c1eeeb903ddf"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:5f059d9faeacd49c0215d66f4056e1326c80503f51a1532ca336a385edadd033"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313t-win32.whl", hash = "sha256:1244460adc3a9be84731d72b8e80625788e5815b68da3da8b83f78115a40a7ec"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313t-win_amd64.whl", hash = "sha256:b1e420ef35c503869c4064f4a2f2b08ad6431ab7b229a05cce39d74268bca6b8"},
|
||||
{file = "xxhash-3.6.0-cp313-cp313t-win_arm64.whl", hash = "sha256:ec44b73a4220623235f67a996c862049f375df3b1052d9899f40a6382c32d746"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:a40a3d35b204b7cc7643cbcf8c9976d818cb47befcfac8bbefec8038ac363f3e"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:a54844be970d3fc22630b32d515e79a90d0a3ddb2644d8d7402e3c4c8da61405"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:016e9190af8f0a4e3741343777710e3d5717427f175adfdc3e72508f59e2a7f3"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4f6f72232f849eb9d0141e2ebe2677ece15adfd0fa599bc058aad83c714bb2c6"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:63275a8aba7865e44b1813d2177e0f5ea7eadad3dd063a21f7cf9afdc7054063"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:3cd01fa2aa00d8b017c97eb46b9a794fbdca53fc14f845f5a328c71254b0abb7"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0226aa89035b62b6a86d3c68df4d7c1f47a342b8683da2b60cedcddb46c4d95b"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c6e193e9f56e4ca4923c61238cdaced324f0feac782544eb4c6d55ad5cc99ddd"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:9176dcaddf4ca963d4deb93866d739a343c01c969231dbe21680e13a5d1a5bf0"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:c1ce4009c97a752e682b897aa99aef84191077a9433eb237774689f14f8ec152"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:8cb2f4f679b01513b7adbb9b1b2f0f9cdc31b70007eaf9d59d0878809f385b11"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:653a91d7c2ab54a92c19ccf43508b6a555440b9be1bc8be553376778be7f20b5"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314-win32.whl", hash = "sha256:a756fe893389483ee8c394d06b5ab765d96e68fbbfe6fde7aa17e11f5720559f"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314-win_amd64.whl", hash = "sha256:39be8e4e142550ef69629c9cd71b88c90e9a5db703fecbcf265546d9536ca4ad"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314-win_arm64.whl", hash = "sha256:25915e6000338999236f1eb68a02a32c3275ac338628a7eaa5a269c401995679"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:c5294f596a9017ca5a3e3f8884c00b91ab2ad2933cf288f4923c3fd4346cf3d4"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:1cf9dcc4ab9cff01dfbba78544297a3a01dafd60f3bde4e2bfd016cf7e4ddc67"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:01262da8798422d0685f7cef03b2bd3f4f46511b02830861df548d7def4402ad"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:51a73fb7cb3a3ead9f7a8b583ffd9b8038e277cdb8cb87cf890e88b3456afa0b"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:b9c6df83594f7df8f7f708ce5ebeacfc69f72c9fbaaababf6cf4758eaada0c9b"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:627f0af069b0ea56f312fd5189001c24578868643203bca1abbc2c52d3a6f3ca"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:aa912c62f842dfd013c5f21a642c9c10cd9f4c4e943e0af83618b4a404d9091a"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:b465afd7909db30168ab62afe40b2fcf79eedc0b89a6c0ab3123515dc0df8b99"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:a881851cf38b0a70e7c4d3ce81fc7afd86fbc2a024f4cfb2a97cf49ce04b75d3"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:9b3222c686a919a0f3253cfc12bb118b8b103506612253b5baeaac10d8027cf6"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:c5aa639bc113e9286137cec8fadc20e9cd732b2cc385c0b7fa673b84fc1f2a93"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5c1343d49ac102799905e115aee590183c3921d475356cb24b4de29a4bc56518"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314t-win32.whl", hash = "sha256:5851f033c3030dd95c086b4a36a2683c2ff4a799b23af60977188b057e467119"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314t-win_amd64.whl", hash = "sha256:0444e7967dac37569052d2409b00a8860c2135cff05502df4da80267d384849f"},
|
||||
{file = "xxhash-3.6.0-cp314-cp314t-win_arm64.whl", hash = "sha256:bb79b1e63f6fd84ec778a4b1916dfe0a7c3fdb986c06addd5db3a0d413819d95"},
|
||||
{file = "xxhash-3.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7dac94fad14a3d1c92affb661021e1d5cbcf3876be5f5b4d90730775ccb7ac41"},
|
||||
{file = "xxhash-3.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6965e0e90f1f0e6cb78da568c13d4a348eeb7f40acfd6d43690a666a459458b8"},
|
||||
{file = "xxhash-3.6.0-cp38-cp38-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:2ab89a6b80f22214b43d98693c30da66af910c04f9858dd39c8e570749593d7e"},
|
||||
{file = "xxhash-3.6.0-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4903530e866b7a9c1eadfd3fa2fbe1b97d3aed4739a80abf506eb9318561c850"},
|
||||
{file = "xxhash-3.6.0-cp38-cp38-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4da8168ae52c01ac64c511d6f4a709479da8b7a4a1d7621ed51652f93747dffa"},
|
||||
{file = "xxhash-3.6.0-cp38-cp38-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:97460eec202017f719e839a0d3551fbc0b2fcc9c6c6ffaa5af85bbd5de432788"},
|
||||
{file = "xxhash-3.6.0-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:45aae0c9df92e7fa46fbb738737324a563c727990755ec1965a6a339ea10a1df"},
|
||||
{file = "xxhash-3.6.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:0d50101e57aad86f4344ca9b32d091a2135a9d0a4396f19133426c88025b09f1"},
|
||||
{file = "xxhash-3.6.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:9085e798c163ce310d91f8aa6b325dda3c2944c93c6ce1edb314030d4167cc65"},
|
||||
{file = "xxhash-3.6.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:a87f271a33fad0e5bf3be282be55d78df3a45ae457950deb5241998790326f87"},
|
||||
{file = "xxhash-3.6.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:9e040d3e762f84500961791fa3709ffa4784d4dcd7690afc655c095e02fff05f"},
|
||||
{file = "xxhash-3.6.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:b0359391c3dad6de872fefb0cf5b69d55b0655c55ee78b1bb7a568979b2ce96b"},
|
||||
{file = "xxhash-3.6.0-cp38-cp38-win32.whl", hash = "sha256:e4ff728a2894e7f436b9e94c667b0f426b9c74b71f900cf37d5468c6b5da0536"},
|
||||
{file = "xxhash-3.6.0-cp38-cp38-win_amd64.whl", hash = "sha256:01be0c5b500c5362871fc9cfdf58c69b3e5c4f531a82229ddb9eb1eb14138004"},
|
||||
{file = "xxhash-3.6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cc604dc06027dbeb8281aeac5899c35fcfe7c77b25212833709f0bff4ce74d2a"},
|
||||
{file = "xxhash-3.6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:277175a73900ad43a8caeb8b99b9604f21fe8d7c842f2f9061a364a7e220ddb7"},
|
||||
{file = "xxhash-3.6.0-cp39-cp39-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:cfbc5b91397c8c2972fdac13fb3e4ed2f7f8ccac85cd2c644887557780a9b6e2"},
|
||||
{file = "xxhash-3.6.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2762bfff264c4e73c0e507274b40634ff465e025f0eaf050897e88ec8367575d"},
|
||||
{file = "xxhash-3.6.0-cp39-cp39-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2f171a900d59d51511209f7476933c34a0c2c711078d3c80e74e0fe4f38680ec"},
|
||||
{file = "xxhash-3.6.0-cp39-cp39-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:780b90c313348f030b811efc37b0fa1431163cb8db8064cf88a7936b6ce5f222"},
|
||||
{file = "xxhash-3.6.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:18b242455eccdfcd1fa4134c431a30737d2b4f045770f8fe84356b3469d4b919"},
|
||||
{file = "xxhash-3.6.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:a75ffc1bd5def584129774c158e108e5d768e10b75813f2b32650bb041066ed6"},
|
||||
{file = "xxhash-3.6.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:1fc1ed882d1e8df932a66e2999429ba6cc4d5172914c904ab193381fba825360"},
|
||||
{file = "xxhash-3.6.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:44e342e8cc11b4e79dae5c57f2fb6360c3c20cc57d32049af8f567f5b4bcb5f4"},
|
||||
{file = "xxhash-3.6.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:c2f9ccd5c4be370939a2e17602fbc49995299203da72a3429db013d44d590e86"},
|
||||
{file = "xxhash-3.6.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:02ea4cb627c76f48cd9fb37cf7ab22bd51e57e1b519807234b473faebe526796"},
|
||||
{file = "xxhash-3.6.0-cp39-cp39-win32.whl", hash = "sha256:6551880383f0e6971dc23e512c9ccc986147ce7bfa1cd2e4b520b876c53e9f3d"},
|
||||
{file = "xxhash-3.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:7c35c4cdc65f2a29f34425c446f2f5cdcd0e3c34158931e1cc927ece925ab802"},
|
||||
{file = "xxhash-3.6.0-cp39-cp39-win_arm64.whl", hash = "sha256:ffc578717a347baf25be8397cb10d2528802d24f94cfc005c0e44fef44b5cdd6"},
|
||||
{file = "xxhash-3.6.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:0f7b7e2ec26c1666ad5fc9dbfa426a6a3367ceaf79db5dd76264659d509d73b0"},
|
||||
{file = "xxhash-3.6.0-pp311-pypy311_pp73-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:5dc1e14d14fa0f5789ec29a7062004b5933964bb9b02aae6622b8f530dc40296"},
|
||||
{file = "xxhash-3.6.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:881b47fc47e051b37d94d13e7455131054b56749b91b508b0907eb07900d1c13"},
|
||||
{file = "xxhash-3.6.0-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c6dc31591899f5e5666f04cc2e529e69b4072827085c1ef15294d91a004bc1bd"},
|
||||
{file = "xxhash-3.6.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:15e0dac10eb9309508bfc41f7f9deaa7755c69e35af835db9cb10751adebc35d"},
|
||||
{file = "xxhash-3.6.0.tar.gz", hash = "sha256:f0162a78b13a0d7617b2845b90c763339d1f1d82bb04a4b07f4ab535cc5e05d6"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yarl"
|
||||
version = "1.22.0"
|
||||
@@ -8840,4 +8625,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "7189c9725ca42dfe6672632fe801c61248d87d3dd1259747b0ed9579b19fe088"
|
||||
content-hash = "3869bc3fb8ea50e7101daffce13edbe563c8af568cb751adfa31fb9bb5c8318a"
|
||||
|
||||
@@ -16,7 +16,7 @@ anthropic = "^0.79.0"
|
||||
apscheduler = "^3.11.1"
|
||||
autogpt-libs = { path = "../autogpt_libs", develop = true }
|
||||
bleach = { extras = ["css"], version = "^6.2.0" }
|
||||
claude-agent-sdk = "0.1.45" # see copilot/sdk/sdk_compat_test.py for capability checks
|
||||
claude-agent-sdk = "^0.1.39" # see copilot/sdk/sdk_compat_test.py for capability checks
|
||||
click = "^8.2.0"
|
||||
cryptography = "^46.0"
|
||||
discord-py = "^2.5.2"
|
||||
@@ -90,7 +90,6 @@ stagehand = "^0.5.1"
|
||||
gravitas-md2gdocs = "^0.1.0"
|
||||
posthog = "^7.6.0"
|
||||
fpdf2 = "^2.8.6"
|
||||
langsmith = "^0.7.7"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
aiohappyeyeballs = "^2.6.1"
|
||||
|
||||
@@ -162,7 +162,7 @@ services:
|
||||
context: ../
|
||||
dockerfile: autogpt_platform/backend/Dockerfile
|
||||
target: server
|
||||
command: ["python", "-u", "-m", "backend.copilot.executor"]
|
||||
command: ["python", "-m", "backend.copilot.executor"]
|
||||
develop:
|
||||
watch:
|
||||
- path: ./
|
||||
@@ -182,7 +182,6 @@ services:
|
||||
<<: *backend-env-files
|
||||
environment:
|
||||
<<: *backend-env
|
||||
PYTHONUNBUFFERED: "1"
|
||||
ports:
|
||||
- "8008:8008"
|
||||
networks:
|
||||
|
||||
@@ -34,7 +34,6 @@ export const ContentRenderer: React.FC<{
|
||||
if (
|
||||
renderer?.name === "ImageRenderer" ||
|
||||
renderer?.name === "VideoRenderer" ||
|
||||
renderer?.name === "WorkspaceFileRenderer" ||
|
||||
!shortContent
|
||||
) {
|
||||
return (
|
||||
|
||||
@@ -7,9 +7,7 @@ import {
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/molecules/DropdownMenu/DropdownMenu";
|
||||
import { SidebarProvider } from "@/components/ui/sidebar";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { DotsThree, UploadSimple } from "@phosphor-icons/react";
|
||||
import { useCallback, useRef, useState } from "react";
|
||||
import { DotsThree } from "@phosphor-icons/react";
|
||||
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
|
||||
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
|
||||
import { DeleteChatDialog } from "./components/DeleteChatDialog/DeleteChatDialog";
|
||||
@@ -19,49 +17,6 @@ import { ScaleLoader } from "./components/ScaleLoader/ScaleLoader";
|
||||
import { useCopilotPage } from "./useCopilotPage";
|
||||
|
||||
export function CopilotPage() {
|
||||
const [isDragging, setIsDragging] = useState(false);
|
||||
const [droppedFiles, setDroppedFiles] = useState<File[]>([]);
|
||||
const dragCounter = useRef(0);
|
||||
|
||||
const handleDroppedFilesConsumed = useCallback(() => {
|
||||
setDroppedFiles([]);
|
||||
}, []);
|
||||
|
||||
function handleDragEnter(e: React.DragEvent) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
dragCounter.current += 1;
|
||||
if (e.dataTransfer.types.includes("Files")) {
|
||||
setIsDragging(true);
|
||||
}
|
||||
}
|
||||
|
||||
function handleDragOver(e: React.DragEvent) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
}
|
||||
|
||||
function handleDragLeave(e: React.DragEvent) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
dragCounter.current -= 1;
|
||||
if (dragCounter.current === 0) {
|
||||
setIsDragging(false);
|
||||
}
|
||||
}
|
||||
|
||||
function handleDrop(e: React.DragEvent) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
dragCounter.current = 0;
|
||||
setIsDragging(false);
|
||||
|
||||
const files = Array.from(e.dataTransfer.files);
|
||||
if (files.length > 0) {
|
||||
setDroppedFiles(files);
|
||||
}
|
||||
}
|
||||
|
||||
const {
|
||||
sessionId,
|
||||
messages,
|
||||
@@ -74,7 +29,6 @@ export function CopilotPage() {
|
||||
isLoadingSession,
|
||||
isSessionError,
|
||||
isCreatingSession,
|
||||
isUploadingFiles,
|
||||
isUserLoading,
|
||||
isLoggedIn,
|
||||
// Mobile drawer
|
||||
@@ -109,26 +63,8 @@ export function CopilotPage() {
|
||||
className="h-[calc(100vh-72px)] min-h-0"
|
||||
>
|
||||
{!isMobile && <ChatSidebar />}
|
||||
<div
|
||||
className="relative flex h-full w-full flex-col overflow-hidden bg-[#f8f8f9] px-0"
|
||||
onDragEnter={handleDragEnter}
|
||||
onDragOver={handleDragOver}
|
||||
onDragLeave={handleDragLeave}
|
||||
onDrop={handleDrop}
|
||||
>
|
||||
<div className="relative flex h-full w-full flex-col overflow-hidden bg-[#f8f8f9] px-0">
|
||||
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
|
||||
{/* Drop overlay */}
|
||||
<div
|
||||
className={cn(
|
||||
"pointer-events-none absolute inset-0 z-50 flex flex-col items-center justify-center gap-3 rounded-lg border-2 border-dashed border-violet-400 bg-violet-500/10 transition-opacity duration-150",
|
||||
isDragging ? "opacity-100" : "opacity-0",
|
||||
)}
|
||||
>
|
||||
<UploadSimple className="h-10 w-10 text-violet-500" weight="bold" />
|
||||
<span className="text-lg font-medium text-violet-600">
|
||||
Drop files here
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex-1 overflow-hidden">
|
||||
<ChatContainer
|
||||
messages={messages}
|
||||
@@ -142,9 +78,6 @@ export function CopilotPage() {
|
||||
onCreateSession={createSession}
|
||||
onSend={onSend}
|
||||
onStop={stop}
|
||||
isUploadingFiles={isUploadingFiles}
|
||||
droppedFiles={droppedFiles}
|
||||
onDroppedFilesConsumed={handleDroppedFilesConsumed}
|
||||
headerSlot={
|
||||
isMobile && sessionId ? (
|
||||
<div className="flex justify-end">
|
||||
|
||||
@@ -18,14 +18,9 @@ export interface ChatContainerProps {
|
||||
/** True when backend has an active stream but we haven't reconnected yet. */
|
||||
isReconnecting?: boolean;
|
||||
onCreateSession: () => void | Promise<string>;
|
||||
onSend: (message: string, files?: File[]) => void | Promise<void>;
|
||||
onSend: (message: string) => void | Promise<void>;
|
||||
onStop: () => void;
|
||||
isUploadingFiles?: boolean;
|
||||
headerSlot?: ReactNode;
|
||||
/** Files dropped onto the chat window. */
|
||||
droppedFiles?: File[];
|
||||
/** Called after droppedFiles have been consumed by ChatInput. */
|
||||
onDroppedFilesConsumed?: () => void;
|
||||
}
|
||||
export const ChatContainer = ({
|
||||
messages,
|
||||
@@ -39,10 +34,7 @@ export const ChatContainer = ({
|
||||
onCreateSession,
|
||||
onSend,
|
||||
onStop,
|
||||
isUploadingFiles,
|
||||
headerSlot,
|
||||
droppedFiles,
|
||||
onDroppedFilesConsumed,
|
||||
}: ChatContainerProps) => {
|
||||
const isBusy =
|
||||
status === "streaming" ||
|
||||
@@ -64,7 +56,6 @@ export const ChatContainer = ({
|
||||
error={error}
|
||||
isLoading={isLoadingSession}
|
||||
headerSlot={headerSlot}
|
||||
sessionID={sessionId}
|
||||
/>
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
@@ -78,11 +69,8 @@ export const ChatContainer = ({
|
||||
onSend={onSend}
|
||||
disabled={isBusy}
|
||||
isStreaming={isBusy}
|
||||
isUploadingFiles={isUploadingFiles}
|
||||
onStop={onStop}
|
||||
placeholder="What else can I help with?"
|
||||
droppedFiles={droppedFiles}
|
||||
onDroppedFilesConsumed={onDroppedFilesConsumed}
|
||||
/>
|
||||
</motion.div>
|
||||
</div>
|
||||
@@ -92,9 +80,6 @@ export const ChatContainer = ({
|
||||
isCreatingSession={isCreatingSession}
|
||||
onCreateSession={onCreateSession}
|
||||
onSend={onSend}
|
||||
isUploadingFiles={isUploadingFiles}
|
||||
droppedFiles={droppedFiles}
|
||||
onDroppedFilesConsumed={onDroppedFilesConsumed}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import {
|
||||
PromptInputBody,
|
||||
PromptInputButton,
|
||||
PromptInputFooter,
|
||||
PromptInputSubmit,
|
||||
PromptInputTextarea,
|
||||
@@ -7,67 +8,39 @@ import {
|
||||
} from "@/components/ai-elements/prompt-input";
|
||||
import { InputGroup } from "@/components/ui/input-group";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { ChangeEvent, useEffect, useState } from "react";
|
||||
import { AttachmentMenu } from "./components/AttachmentMenu";
|
||||
import { FileChips } from "./components/FileChips";
|
||||
import { RecordingButton } from "./components/RecordingButton";
|
||||
import { CircleNotchIcon, MicrophoneIcon } from "@phosphor-icons/react";
|
||||
import { ChangeEvent } from "react";
|
||||
import { RecordingIndicator } from "./components/RecordingIndicator";
|
||||
import { useChatInput } from "./useChatInput";
|
||||
import { useVoiceRecording } from "./useVoiceRecording";
|
||||
|
||||
export interface Props {
|
||||
onSend: (message: string, files?: File[]) => void | Promise<void>;
|
||||
onSend: (message: string) => void | Promise<void>;
|
||||
disabled?: boolean;
|
||||
isStreaming?: boolean;
|
||||
isUploadingFiles?: boolean;
|
||||
onStop?: () => void;
|
||||
placeholder?: string;
|
||||
className?: string;
|
||||
inputId?: string;
|
||||
/** Files dropped onto the chat window by the parent. */
|
||||
droppedFiles?: File[];
|
||||
/** Called after droppedFiles have been merged into internal state. */
|
||||
onDroppedFilesConsumed?: () => void;
|
||||
}
|
||||
|
||||
export function ChatInput({
|
||||
onSend,
|
||||
disabled = false,
|
||||
isStreaming = false,
|
||||
isUploadingFiles = false,
|
||||
onStop,
|
||||
placeholder = "Type your message...",
|
||||
className,
|
||||
inputId = "chat-input",
|
||||
droppedFiles,
|
||||
onDroppedFilesConsumed,
|
||||
}: Props) {
|
||||
const [files, setFiles] = useState<File[]>([]);
|
||||
|
||||
// Merge files dropped onto the chat window into internal state.
|
||||
useEffect(() => {
|
||||
if (droppedFiles && droppedFiles.length > 0) {
|
||||
setFiles((prev) => [...prev, ...droppedFiles]);
|
||||
onDroppedFilesConsumed?.();
|
||||
}
|
||||
}, [droppedFiles, onDroppedFilesConsumed]);
|
||||
|
||||
const hasFiles = files.length > 0;
|
||||
const isBusy = disabled || isStreaming || isUploadingFiles;
|
||||
|
||||
const {
|
||||
value,
|
||||
setValue,
|
||||
handleSubmit,
|
||||
handleChange: baseHandleChange,
|
||||
} = useChatInput({
|
||||
onSend: async (message: string) => {
|
||||
await onSend(message, hasFiles ? files : undefined);
|
||||
// Only clear files after successful send (onSend throws on failure)
|
||||
setFiles([]);
|
||||
},
|
||||
disabled: isBusy,
|
||||
canSendEmpty: hasFiles,
|
||||
onSend,
|
||||
disabled: disabled || isStreaming,
|
||||
inputId,
|
||||
});
|
||||
|
||||
@@ -82,7 +55,7 @@ export function ChatInput({
|
||||
audioStream,
|
||||
} = useVoiceRecording({
|
||||
setValue,
|
||||
disabled: isBusy,
|
||||
disabled: disabled || isStreaming,
|
||||
isStreaming,
|
||||
value,
|
||||
inputId,
|
||||
@@ -94,18 +67,7 @@ export function ChatInput({
|
||||
}
|
||||
|
||||
const canSend =
|
||||
!disabled &&
|
||||
(!!value.trim() || hasFiles) &&
|
||||
!isRecording &&
|
||||
!isTranscribing;
|
||||
|
||||
function handleFilesSelected(newFiles: File[]) {
|
||||
setFiles((prev) => [...prev, ...newFiles]);
|
||||
}
|
||||
|
||||
function handleRemoveFile(index: number) {
|
||||
setFiles((prev) => prev.filter((_, i) => i !== index));
|
||||
}
|
||||
!disabled && !!value.trim() && !isRecording && !isTranscribing;
|
||||
|
||||
return (
|
||||
<form onSubmit={handleSubmit} className={cn("relative flex-1", className)}>
|
||||
@@ -116,11 +78,6 @@ export function ChatInput({
|
||||
"border-red-400 ring-1 ring-red-400 has-[[data-slot=input-group-control]:focus-visible]:border-red-400 has-[[data-slot=input-group-control]:focus-visible]:ring-red-400",
|
||||
)}
|
||||
>
|
||||
<FileChips
|
||||
files={files}
|
||||
onRemove={handleRemoveFile}
|
||||
isUploading={isUploadingFiles}
|
||||
/>
|
||||
<PromptInputBody className="relative block w-full">
|
||||
<PromptInputTextarea
|
||||
id={inputId}
|
||||
@@ -147,18 +104,25 @@ export function ChatInput({
|
||||
|
||||
<PromptInputFooter>
|
||||
<PromptInputTools>
|
||||
<AttachmentMenu
|
||||
onFilesSelected={handleFilesSelected}
|
||||
disabled={isBusy}
|
||||
/>
|
||||
{showMicButton && (
|
||||
<RecordingButton
|
||||
isRecording={isRecording}
|
||||
isTranscribing={isTranscribing}
|
||||
isStreaming={isStreaming}
|
||||
disabled={disabled || isTranscribing || isStreaming}
|
||||
<PromptInputButton
|
||||
aria-label={isRecording ? "Stop recording" : "Start recording"}
|
||||
onClick={toggleRecording}
|
||||
/>
|
||||
disabled={disabled || isTranscribing || isStreaming}
|
||||
className={cn(
|
||||
"size-[2.625rem] rounded-[96px] border border-zinc-300 bg-transparent text-black hover:border-zinc-600 hover:bg-zinc-100",
|
||||
isRecording &&
|
||||
"animate-pulse border-red-500 bg-red-500 text-white hover:border-red-600 hover:bg-red-600",
|
||||
isTranscribing && "bg-zinc-100 text-zinc-400",
|
||||
isStreaming && "opacity-40",
|
||||
)}
|
||||
>
|
||||
{isTranscribing ? (
|
||||
<CircleNotchIcon className="h-4 w-4 animate-spin" />
|
||||
) : (
|
||||
<MicrophoneIcon className="h-4 w-4" weight="bold" />
|
||||
)}
|
||||
</PromptInputButton>
|
||||
)}
|
||||
</PromptInputTools>
|
||||
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Plus as PlusIcon } from "@phosphor-icons/react";
|
||||
import { useRef } from "react";
|
||||
|
||||
interface Props {
|
||||
onFilesSelected: (files: File[]) => void;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
export function AttachmentMenu({ onFilesSelected, disabled }: Props) {
|
||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
function handleClick() {
|
||||
fileInputRef.current?.click();
|
||||
}
|
||||
|
||||
function handleFileChange(e: React.ChangeEvent<HTMLInputElement>) {
|
||||
const files = Array.from(e.target.files ?? []);
|
||||
if (files.length > 0) {
|
||||
onFilesSelected(files);
|
||||
}
|
||||
// Reset so the same file can be re-selected
|
||||
e.target.value = "";
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
multiple
|
||||
className="hidden"
|
||||
onChange={handleFileChange}
|
||||
tabIndex={-1}
|
||||
/>
|
||||
<Button
|
||||
type="button"
|
||||
variant="icon"
|
||||
size="icon"
|
||||
aria-label="Attach file"
|
||||
disabled={disabled}
|
||||
onClick={handleClick}
|
||||
className={cn(
|
||||
"border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
|
||||
disabled && "opacity-40",
|
||||
)}
|
||||
>
|
||||
<PlusIcon className="h-4 w-4" weight="bold" />
|
||||
</Button>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
CircleNotch as CircleNotchIcon,
|
||||
X as XIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
|
||||
interface Props {
|
||||
files: File[];
|
||||
onRemove: (index: number) => void;
|
||||
isUploading?: boolean;
|
||||
}
|
||||
|
||||
export function FileChips({ files, onRemove, isUploading }: Props) {
|
||||
if (files.length === 0) return null;
|
||||
|
||||
return (
|
||||
<div className="flex w-full flex-wrap gap-2 px-3 pb-2 pt-2">
|
||||
{files.map((file, index) => (
|
||||
<span
|
||||
key={`${file.name}-${file.size}-${index}`}
|
||||
className={cn(
|
||||
"inline-flex items-center gap-1 rounded-full bg-zinc-100 px-3 py-1 text-sm text-zinc-700",
|
||||
isUploading && "opacity-70",
|
||||
)}
|
||||
>
|
||||
<span className="max-w-[160px] truncate">{file.name}</span>
|
||||
{isUploading ? (
|
||||
<CircleNotchIcon className="ml-0.5 h-3 w-3 animate-spin text-zinc-400" />
|
||||
) : (
|
||||
<button
|
||||
type="button"
|
||||
aria-label={`Remove ${file.name}`}
|
||||
onClick={() => onRemove(index)}
|
||||
className="ml-0.5 rounded-full p-0.5 text-zinc-400 transition-colors hover:bg-zinc-200 hover:text-zinc-600"
|
||||
>
|
||||
<XIcon className="h-3 w-3" weight="bold" />
|
||||
</button>
|
||||
)}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { CircleNotchIcon, MicrophoneIcon } from "@phosphor-icons/react";
|
||||
|
||||
interface Props {
|
||||
isRecording: boolean;
|
||||
isTranscribing: boolean;
|
||||
isStreaming: boolean;
|
||||
disabled: boolean;
|
||||
onClick: () => void;
|
||||
}
|
||||
|
||||
export function RecordingButton({
|
||||
isRecording,
|
||||
isTranscribing,
|
||||
isStreaming,
|
||||
disabled,
|
||||
onClick,
|
||||
}: Props) {
|
||||
return (
|
||||
<Button
|
||||
type="button"
|
||||
variant="icon"
|
||||
size="icon"
|
||||
aria-label={isRecording ? "Stop recording" : "Start recording"}
|
||||
disabled={disabled}
|
||||
onClick={onClick}
|
||||
className={cn(
|
||||
"border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
|
||||
disabled && "opacity-40",
|
||||
isRecording &&
|
||||
"animate-pulse border-red-500 bg-red-500 text-white hover:border-red-600 hover:bg-red-600",
|
||||
isTranscribing && "bg-zinc-100 text-zinc-400",
|
||||
isStreaming && "opacity-40",
|
||||
)}
|
||||
>
|
||||
{isTranscribing ? (
|
||||
<CircleNotchIcon className="h-4 w-4 animate-spin" weight="bold" />
|
||||
) : (
|
||||
<MicrophoneIcon className="h-4 w-4" weight="bold" />
|
||||
)}
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
@@ -3,15 +3,12 @@ import { ChangeEvent, FormEvent, useEffect, useState } from "react";
|
||||
interface Args {
|
||||
onSend: (message: string) => void;
|
||||
disabled?: boolean;
|
||||
/** Allow sending when text is empty (e.g. when files are attached). */
|
||||
canSendEmpty?: boolean;
|
||||
inputId?: string;
|
||||
}
|
||||
|
||||
export function useChatInput({
|
||||
onSend,
|
||||
disabled = false,
|
||||
canSendEmpty = false,
|
||||
inputId = "chat-input",
|
||||
}: Args) {
|
||||
const [value, setValue] = useState("");
|
||||
@@ -35,7 +32,7 @@ export function useChatInput({
|
||||
);
|
||||
|
||||
async function handleSend() {
|
||||
if (disabled || isSending || (!value.trim() && !canSendEmpty)) return;
|
||||
if (disabled || isSending || !value.trim()) return;
|
||||
|
||||
setIsSending(true);
|
||||
try {
|
||||
|
||||
@@ -5,12 +5,7 @@ import {
|
||||
} from "@/components/ai-elements/conversation";
|
||||
import { Message, MessageContent } from "@/components/ai-elements/message";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { FileUIPart, UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { TOOL_PART_PREFIX } from "../JobStatsBar/constants";
|
||||
import { TurnStatsBar } from "../JobStatsBar/TurnStatsBar";
|
||||
import { parseSpecialMarkers } from "./helpers";
|
||||
import { AssistantMessageActions } from "./components/AssistantMessageActions";
|
||||
import { MessageAttachments } from "./components/MessageAttachments";
|
||||
import { UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { MessagePartRenderer } from "./components/MessagePartRenderer";
|
||||
import { ThinkingIndicator } from "./components/ThinkingIndicator";
|
||||
|
||||
@@ -20,24 +15,6 @@ interface Props {
|
||||
error: Error | undefined;
|
||||
isLoading: boolean;
|
||||
headerSlot?: React.ReactNode;
|
||||
sessionID?: string | null;
|
||||
}
|
||||
|
||||
/** Collect all messages belonging to a turn: the user message + every
|
||||
* assistant message up to (but not including) the next user message. */
|
||||
function getTurnMessages(
|
||||
messages: UIMessage<unknown, UIDataTypes, UITools>[],
|
||||
lastAssistantIndex: number,
|
||||
): UIMessage<unknown, UIDataTypes, UITools>[] {
|
||||
const userIndex = messages.findLastIndex(
|
||||
(m, i) => i < lastAssistantIndex && m.role === "user",
|
||||
);
|
||||
const nextUserIndex = messages.findIndex(
|
||||
(m, i) => i > lastAssistantIndex && m.role === "user",
|
||||
);
|
||||
const start = userIndex >= 0 ? userIndex : lastAssistantIndex;
|
||||
const end = nextUserIndex >= 0 ? nextUserIndex : messages.length;
|
||||
return messages.slice(start, end);
|
||||
}
|
||||
|
||||
export function ChatMessagesContainer({
|
||||
@@ -46,10 +23,12 @@ export function ChatMessagesContainer({
|
||||
error,
|
||||
isLoading,
|
||||
headerSlot,
|
||||
sessionID,
|
||||
}: Props) {
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
|
||||
// Determine if something is visibly "in-flight" in the last assistant message:
|
||||
// - Text is actively streaming (last part is non-empty text)
|
||||
// - A tool call is pending (state is input-streaming or input-available)
|
||||
const hasInflight = (() => {
|
||||
if (lastMessage?.role !== "assistant") return false;
|
||||
const parts = lastMessage.parts;
|
||||
@@ -57,11 +36,13 @@ export function ChatMessagesContainer({
|
||||
|
||||
const lastPart = parts[parts.length - 1];
|
||||
|
||||
// Text is actively being written
|
||||
if (lastPart.type === "text" && lastPart.text.trim().length > 0)
|
||||
return true;
|
||||
|
||||
// A tool call is still pending (no output yet)
|
||||
if (
|
||||
lastPart.type.startsWith(TOOL_PART_PREFIX) &&
|
||||
lastPart.type.startsWith("tool-") &&
|
||||
"state" in lastPart &&
|
||||
(lastPart.state === "input-streaming" ||
|
||||
lastPart.state === "input-available")
|
||||
@@ -91,34 +72,6 @@ export function ChatMessagesContainer({
|
||||
messageIndex === messages.length - 1 &&
|
||||
message.role === "assistant";
|
||||
|
||||
const isCurrentlyStreaming =
|
||||
isLastAssistant &&
|
||||
(status === "streaming" || status === "submitted");
|
||||
|
||||
const isAssistant = message.role === "assistant";
|
||||
|
||||
const nextMessage = messages[messageIndex + 1];
|
||||
const isLastInTurn =
|
||||
isAssistant &&
|
||||
messageIndex <= messages.length - 1 &&
|
||||
(!nextMessage || nextMessage.role === "user");
|
||||
const textParts = message.parts.filter(
|
||||
(p): p is Extract<typeof p, { type: "text" }> => p.type === "text",
|
||||
);
|
||||
const lastTextPart = textParts[textParts.length - 1];
|
||||
const hasErrorMarker =
|
||||
lastTextPart !== undefined &&
|
||||
parseSpecialMarkers(lastTextPart.text).markerType === "error";
|
||||
const showActions =
|
||||
isLastInTurn &&
|
||||
!isCurrentlyStreaming &&
|
||||
textParts.length > 0 &&
|
||||
!hasErrorMarker;
|
||||
|
||||
const fileParts = message.parts.filter(
|
||||
(p): p is FileUIPart => p.type === "file",
|
||||
);
|
||||
|
||||
return (
|
||||
<Message from={message.role} key={message.id}>
|
||||
<MessageContent
|
||||
@@ -136,27 +89,10 @@ export function ChatMessagesContainer({
|
||||
partIndex={i}
|
||||
/>
|
||||
))}
|
||||
{isLastInTurn && !isCurrentlyStreaming && (
|
||||
<TurnStatsBar
|
||||
turnMessages={getTurnMessages(messages, messageIndex)}
|
||||
/>
|
||||
)}
|
||||
{isLastAssistant && showThinking && (
|
||||
<ThinkingIndicator active={showThinking} />
|
||||
)}
|
||||
</MessageContent>
|
||||
{fileParts.length > 0 && (
|
||||
<MessageAttachments
|
||||
files={fileParts}
|
||||
isUser={message.role === "user"}
|
||||
/>
|
||||
)}
|
||||
{showActions && (
|
||||
<AssistantMessageActions
|
||||
message={message}
|
||||
sessionID={sessionID ?? null}
|
||||
/>
|
||||
)}
|
||||
</Message>
|
||||
);
|
||||
})}
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
MessageAction,
|
||||
MessageActions,
|
||||
} from "@/components/ai-elements/message";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { CopySimple, ThumbsDown, ThumbsUp } from "@phosphor-icons/react";
|
||||
import { UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { useMessageFeedback } from "../useMessageFeedback";
|
||||
import { FeedbackModal } from "./FeedbackModal";
|
||||
import { TTSButton } from "./TTSButton";
|
||||
|
||||
interface Props {
|
||||
message: UIMessage<unknown, UIDataTypes, UITools>;
|
||||
sessionID: string | null;
|
||||
}
|
||||
|
||||
function extractTextFromParts(
|
||||
parts: UIMessage<unknown, UIDataTypes, UITools>["parts"],
|
||||
): string {
|
||||
return parts
|
||||
.filter((p) => p.type === "text")
|
||||
.map((p) => (p as { type: "text"; text: string }).text)
|
||||
.join("\n")
|
||||
.trim();
|
||||
}
|
||||
|
||||
export function AssistantMessageActions({ message, sessionID }: Props) {
|
||||
const {
|
||||
feedback,
|
||||
showFeedbackModal,
|
||||
handleCopy,
|
||||
handleUpvote,
|
||||
handleDownvoteClick,
|
||||
handleDownvoteSubmit,
|
||||
handleDownvoteCancel,
|
||||
} = useMessageFeedback({ sessionID, messageID: message.id });
|
||||
|
||||
const text = extractTextFromParts(message.parts);
|
||||
|
||||
return (
|
||||
<>
|
||||
<MessageActions className="mt-1">
|
||||
<MessageAction
|
||||
tooltip="Copy"
|
||||
onClick={() => handleCopy(text)}
|
||||
variant="ghost"
|
||||
size="icon-sm"
|
||||
>
|
||||
<CopySimple size={16} weight="regular" />
|
||||
</MessageAction>
|
||||
|
||||
<MessageAction
|
||||
tooltip="Good response"
|
||||
onClick={handleUpvote}
|
||||
variant="ghost"
|
||||
size="icon-sm"
|
||||
disabled={feedback === "downvote"}
|
||||
className={cn(
|
||||
feedback === "upvote" && "text-green-300 hover:text-green-300",
|
||||
feedback === "downvote" && "!opacity-20",
|
||||
)}
|
||||
>
|
||||
<ThumbsUp
|
||||
size={16}
|
||||
weight={feedback === "upvote" ? "fill" : "regular"}
|
||||
/>
|
||||
</MessageAction>
|
||||
|
||||
<MessageAction
|
||||
tooltip="Bad response"
|
||||
onClick={handleDownvoteClick}
|
||||
variant="ghost"
|
||||
size="icon-sm"
|
||||
disabled={feedback === "upvote"}
|
||||
className={cn(
|
||||
feedback === "downvote" && "text-red-300 hover:text-red-300",
|
||||
feedback === "upvote" && "!opacity-20",
|
||||
)}
|
||||
>
|
||||
<ThumbsDown
|
||||
size={16}
|
||||
weight={feedback === "downvote" ? "fill" : "regular"}
|
||||
/>
|
||||
</MessageAction>
|
||||
|
||||
<TTSButton text={text} />
|
||||
</MessageActions>
|
||||
|
||||
{showFeedbackModal && (
|
||||
<FeedbackModal
|
||||
isOpen={showFeedbackModal}
|
||||
onSubmit={handleDownvoteSubmit}
|
||||
onCancel={handleDownvoteCancel}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { MessageAction } from "@/components/ai-elements/message";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { Check, Copy } from "@phosphor-icons/react";
|
||||
import { useState } from "react";
|
||||
|
||||
interface Props {
|
||||
text: string;
|
||||
}
|
||||
|
||||
export function CopyButton({ text }: Props) {
|
||||
const [copied, setCopied] = useState(false);
|
||||
|
||||
if (!text.trim()) return null;
|
||||
|
||||
async function handleCopy() {
|
||||
try {
|
||||
await navigator.clipboard.writeText(text);
|
||||
setCopied(true);
|
||||
setTimeout(() => setCopied(false), 2000);
|
||||
} catch {
|
||||
toast({
|
||||
title: "Failed to copy",
|
||||
description:
|
||||
"Your browser may not support clipboard access, or something went wrong.",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<MessageAction
|
||||
tooltip={copied ? "Copied!" : "Copy to clipboard"}
|
||||
onClick={handleCopy}
|
||||
>
|
||||
{copied ? <Check size={16} /> : <Copy size={16} />}
|
||||
</MessageAction>
|
||||
);
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { useState } from "react";
|
||||
|
||||
interface Props {
|
||||
isOpen: boolean;
|
||||
onSubmit: (comment: string) => void;
|
||||
onCancel: () => void;
|
||||
}
|
||||
|
||||
export function FeedbackModal({ isOpen, onSubmit, onCancel }: Props) {
|
||||
const [comment, setComment] = useState("");
|
||||
|
||||
function handleSubmit() {
|
||||
onSubmit(comment);
|
||||
setComment("");
|
||||
}
|
||||
|
||||
function handleClose() {
|
||||
onCancel();
|
||||
setComment("");
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
title="What could have been better?"
|
||||
controlled={{
|
||||
isOpen,
|
||||
set: (open) => {
|
||||
if (!open) handleClose();
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div className="mx-auto w-[95%] space-y-4">
|
||||
<p className="text-sm text-slate-600">
|
||||
Your feedback helps us improve. Share details below.
|
||||
</p>
|
||||
<Textarea
|
||||
placeholder="Tell us what went wrong or could be improved..."
|
||||
value={comment}
|
||||
onChange={(e) => setComment(e.target.value)}
|
||||
rows={4}
|
||||
maxLength={2000}
|
||||
className="resize-none"
|
||||
/>
|
||||
<div className="flex items-center justify-between">
|
||||
<p className="text-xs text-slate-400">{comment.length}/2000</p>
|
||||
<div className="flex gap-2">
|
||||
<Button variant="outline" size="sm" onClick={handleClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button size="sm" onClick={handleSubmit}>
|
||||
Submit feedback
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,136 +0,0 @@
|
||||
import {
|
||||
FileText as FileTextIcon,
|
||||
DownloadSimple as DownloadIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import type { FileUIPart } from "ai";
|
||||
import {
|
||||
globalRegistry,
|
||||
OutputItem,
|
||||
} from "@/components/contextual/OutputRenderers";
|
||||
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
|
||||
import {
|
||||
ContentCard,
|
||||
ContentCardHeader,
|
||||
ContentCardTitle,
|
||||
ContentCardSubtitle,
|
||||
} from "../../ToolAccordion/AccordionContent";
|
||||
|
||||
interface Props {
|
||||
files: FileUIPart[];
|
||||
isUser?: boolean;
|
||||
}
|
||||
|
||||
function renderFileContent(file: FileUIPart): React.ReactNode | null {
|
||||
if (!file.url) return null;
|
||||
const metadata: OutputMetadata = {
|
||||
mimeType: file.mediaType,
|
||||
filename: file.filename,
|
||||
type: file.mediaType?.startsWith("image/")
|
||||
? "image"
|
||||
: file.mediaType?.startsWith("video/")
|
||||
? "video"
|
||||
: undefined,
|
||||
};
|
||||
const renderer = globalRegistry.getRenderer(file.url, metadata);
|
||||
if (!renderer) return null;
|
||||
return (
|
||||
<OutputItem value={file.url} metadata={metadata} renderer={renderer} />
|
||||
);
|
||||
}
|
||||
|
||||
export function MessageAttachments({ files, isUser }: Props) {
|
||||
if (files.length === 0) return null;
|
||||
|
||||
return (
|
||||
<div className="mt-2 flex flex-col gap-2">
|
||||
{files.map((file, i) => {
|
||||
const rendered = renderFileContent(file);
|
||||
return rendered ? (
|
||||
<div
|
||||
key={`${file.filename}-${i}`}
|
||||
className={`inline-block rounded-lg border p-1.5 ${
|
||||
isUser
|
||||
? "border-purple-300 bg-purple-50"
|
||||
: "border-neutral-200 bg-neutral-50"
|
||||
}`}
|
||||
>
|
||||
{rendered}
|
||||
<div
|
||||
className={`mt-1 flex items-center gap-1 px-0.5 text-xs ${
|
||||
isUser ? "text-zinc-600" : "text-neutral-500"
|
||||
}`}
|
||||
>
|
||||
<span className="truncate">{file.filename || "file"}</span>
|
||||
{file.url && (
|
||||
<a
|
||||
href={file.url}
|
||||
download
|
||||
aria-label="Download file"
|
||||
className="ml-auto shrink-0 opacity-50 hover:opacity-100"
|
||||
>
|
||||
<DownloadIcon className="h-3.5 w-3.5" />
|
||||
</a>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
) : isUser ? (
|
||||
<div
|
||||
key={`${file.filename}-${i}`}
|
||||
className="min-w-0 rounded-lg border border-purple-300 bg-purple-100 p-3"
|
||||
>
|
||||
<div className="flex items-start justify-between gap-2">
|
||||
<div className="flex min-w-0 items-center gap-2">
|
||||
<FileTextIcon className="h-5 w-5 shrink-0 text-neutral-400" />
|
||||
<div className="min-w-0">
|
||||
<p className="truncate text-sm font-medium text-zinc-800">
|
||||
{file.filename || "file"}
|
||||
</p>
|
||||
<p className="mt-0.5 truncate font-mono text-xs text-zinc-800">
|
||||
{file.mediaType || "file"}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
{file.url && (
|
||||
<a
|
||||
href={file.url}
|
||||
download
|
||||
aria-label="Download file"
|
||||
className="shrink-0 text-purple-400 hover:text-purple-600"
|
||||
>
|
||||
<DownloadIcon className="h-5 w-5" />
|
||||
</a>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<ContentCard key={`${file.filename}-${i}`}>
|
||||
<ContentCardHeader
|
||||
action={
|
||||
file.url ? (
|
||||
<a
|
||||
href={file.url}
|
||||
download
|
||||
aria-label="Download file"
|
||||
className="shrink-0 text-neutral-400 hover:text-neutral-600"
|
||||
>
|
||||
<DownloadIcon className="h-5 w-5" />
|
||||
</a>
|
||||
) : undefined
|
||||
}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<FileTextIcon className="h-5 w-5 shrink-0 text-neutral-400" />
|
||||
<div className="min-w-0">
|
||||
<ContentCardTitle>{file.filename || "file"}</ContentCardTitle>
|
||||
<ContentCardSubtitle>
|
||||
{file.mediaType || "file"}
|
||||
</ContentCardSubtitle>
|
||||
</div>
|
||||
</div>
|
||||
</ContentCardHeader>
|
||||
</ContentCard>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -14,7 +14,6 @@ import { FindBlocksTool } from "../../../tools/FindBlocks/FindBlocks";
|
||||
import { GenericTool } from "../../../tools/GenericTool/GenericTool";
|
||||
import { RunAgentTool } from "../../../tools/RunAgent/RunAgent";
|
||||
import { RunBlockTool } from "../../../tools/RunBlock/RunBlock";
|
||||
import { RunMCPToolComponent } from "../../../tools/RunMCPTool/RunMCPTool";
|
||||
import { SearchDocsTool } from "../../../tools/SearchDocs/SearchDocs";
|
||||
import { ViewAgentOutputTool } from "../../../tools/ViewAgentOutput/ViewAgentOutput";
|
||||
import { parseSpecialMarkers, resolveWorkspaceUrls } from "../helpers";
|
||||
@@ -130,8 +129,6 @@ export function MessagePartRenderer({ part, messageID, partIndex }: Props) {
|
||||
return <SearchDocsTool key={key} part={part as ToolUIPart} />;
|
||||
case "tool-run_block":
|
||||
return <RunBlockTool key={key} part={part as ToolUIPart} />;
|
||||
case "tool-run_mcp_tool":
|
||||
return <RunMCPToolComponent key={key} part={part as ToolUIPart} />;
|
||||
case "tool-run_agent":
|
||||
case "tool-schedule_agent":
|
||||
return <RunAgentTool key={key} part={part as ToolUIPart} />;
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { MessageAction } from "@/components/ai-elements/message";
|
||||
import { SpeakerHigh, Stop } from "@phosphor-icons/react";
|
||||
import { useTextToSpeech } from "@/components/contextual/Chat/components/ChatMessage/useTextToSpeech";
|
||||
import { useMemo } from "react";
|
||||
|
||||
// Unicode emoji pattern (covers most emoji ranges including modifiers and ZWJ sequences)
|
||||
const EMOJI_RE =
|
||||
/[\u{1F600}-\u{1F64F}\u{1F300}-\u{1F5FF}\u{1F680}-\u{1F6FF}\u{1F1E0}-\u{1F1FF}\u{2600}-\u{26FF}\u{2700}-\u{27BF}\u{FE00}-\u{FE0F}\u{1F900}-\u{1F9FF}\u{1FA00}-\u{1FA6F}\u{1FA70}-\u{1FAFF}\u{200D}\u{20E3}\u{E0020}-\u{E007F}]+/gu;
|
||||
|
||||
function stripMarkdownForSpeech(md: string): string {
|
||||
return (
|
||||
md
|
||||
// Code blocks (``` ... ```)
|
||||
.replace(/```[\s\S]*?```/g, "")
|
||||
// Inline code
|
||||
.replace(/`([^`]*)`/g, "$1")
|
||||
// Images 
|
||||
.replace(/!\[([^\]]*)\]\([^)]*\)/g, "$1")
|
||||
// Links [text](url)
|
||||
.replace(/\[([^\]]*)\]\([^)]*\)/g, "$1")
|
||||
// Bold/italic (***text***, **text**, *text*, ___text___, __text__, _text_)
|
||||
.replace(/\*{1,3}([^*]+)\*{1,3}/g, "$1")
|
||||
.replace(/_{1,3}([^_]+)_{1,3}/g, "$1")
|
||||
// Strikethrough
|
||||
.replace(/~~([^~]+)~~/g, "$1")
|
||||
// Headings (# ... ######)
|
||||
.replace(/^#{1,6}\s+/gm, "")
|
||||
// Horizontal rules
|
||||
.replace(/^[-*_]{3,}\s*$/gm, "")
|
||||
// Blockquotes
|
||||
.replace(/^>\s?/gm, "")
|
||||
// Unordered list markers
|
||||
.replace(/^[\s]*[-*+]\s+/gm, "")
|
||||
// Ordered list markers
|
||||
.replace(/^[\s]*\d+\.\s+/gm, "")
|
||||
// HTML tags
|
||||
.replace(/<[^>]+>/g, "")
|
||||
// Emoji
|
||||
.replace(EMOJI_RE, "")
|
||||
// Collapse multiple blank lines
|
||||
.replace(/\n{3,}/g, "\n\n")
|
||||
.trim()
|
||||
);
|
||||
}
|
||||
|
||||
interface Props {
|
||||
text: string;
|
||||
}
|
||||
|
||||
export function TTSButton({ text }: Props) {
|
||||
const cleanText = useMemo(() => stripMarkdownForSpeech(text), [text]);
|
||||
const { status, isSupported, toggle } = useTextToSpeech(cleanText);
|
||||
|
||||
if (!isSupported || !cleanText) return null;
|
||||
|
||||
const isPlaying = status === "playing";
|
||||
|
||||
return (
|
||||
<MessageAction
|
||||
tooltip={isPlaying ? "Stop reading" : "Read aloud"}
|
||||
onClick={toggle}
|
||||
>
|
||||
{isPlaying ? <Stop size={16} /> : <SpeakerHigh size={16} />}
|
||||
</MessageAction>
|
||||
);
|
||||
}
|
||||
@@ -1,120 +0,0 @@
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { getWebSocketToken } from "@/lib/supabase/actions";
|
||||
import { environment } from "@/services/environment";
|
||||
import { useState } from "react";
|
||||
|
||||
interface Args {
|
||||
sessionID: string | null;
|
||||
messageID: string;
|
||||
}
|
||||
|
||||
async function submitFeedbackToBackend(args: {
|
||||
sessionID: string;
|
||||
messageID: string;
|
||||
scoreName: string;
|
||||
scoreValue: number;
|
||||
comment?: string;
|
||||
}) {
|
||||
try {
|
||||
const { token } = await getWebSocketToken();
|
||||
if (!token) return;
|
||||
|
||||
await fetch(
|
||||
`${environment.getAGPTServerBaseUrl()}/api/chat/sessions/${args.sessionID}/feedback`,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
message_id: args.messageID,
|
||||
score_name: args.scoreName,
|
||||
score_value: args.scoreValue,
|
||||
comment: args.comment,
|
||||
}),
|
||||
},
|
||||
);
|
||||
} catch {
|
||||
// Feedback submission is best-effort; silently ignore failures
|
||||
}
|
||||
}
|
||||
|
||||
export function useMessageFeedback({ sessionID, messageID }: Args) {
|
||||
const [feedback, setFeedback] = useState<"upvote" | "downvote" | null>(null);
|
||||
const [showFeedbackModal, setShowFeedbackModal] = useState(false);
|
||||
|
||||
async function handleCopy(text: string) {
|
||||
try {
|
||||
await navigator.clipboard.writeText(text);
|
||||
toast({ title: "Copied!", variant: "success", duration: 2000 });
|
||||
} catch {
|
||||
toast({
|
||||
title: "Failed to copy",
|
||||
variant: "destructive",
|
||||
duration: 2000,
|
||||
});
|
||||
return;
|
||||
}
|
||||
if (sessionID) {
|
||||
submitFeedbackToBackend({
|
||||
sessionID,
|
||||
messageID,
|
||||
scoreName: "copy",
|
||||
scoreValue: 1,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function handleUpvote() {
|
||||
if (feedback) return;
|
||||
setFeedback("upvote");
|
||||
toast({
|
||||
title: "Thank you for your feedback!",
|
||||
variant: "success",
|
||||
duration: 3000,
|
||||
});
|
||||
if (sessionID) {
|
||||
submitFeedbackToBackend({
|
||||
sessionID,
|
||||
messageID,
|
||||
scoreName: "user-feedback",
|
||||
scoreValue: 1,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function handleDownvoteClick() {
|
||||
if (feedback) return;
|
||||
setFeedback("downvote");
|
||||
setShowFeedbackModal(true);
|
||||
}
|
||||
|
||||
function handleDownvoteSubmit(comment: string) {
|
||||
setShowFeedbackModal(false);
|
||||
if (sessionID) {
|
||||
submitFeedbackToBackend({
|
||||
sessionID,
|
||||
messageID,
|
||||
scoreName: "user-feedback",
|
||||
scoreValue: 0,
|
||||
comment: comment || undefined,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function handleDownvoteCancel() {
|
||||
setShowFeedbackModal(false);
|
||||
setFeedback(null);
|
||||
}
|
||||
|
||||
return {
|
||||
feedback,
|
||||
showFeedbackModal,
|
||||
handleCopy,
|
||||
handleUpvote,
|
||||
handleDownvoteClick,
|
||||
handleDownvoteSubmit,
|
||||
handleDownvoteCancel,
|
||||
};
|
||||
}
|
||||
@@ -3,7 +3,6 @@ import {
|
||||
getGetV2ListSessionsQueryKey,
|
||||
useDeleteV2DeleteSession,
|
||||
useGetV2ListSessions,
|
||||
usePatchV2UpdateSessionTitle,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
@@ -18,6 +17,7 @@ import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import {
|
||||
Sidebar,
|
||||
SidebarContent,
|
||||
SidebarFooter,
|
||||
SidebarHeader,
|
||||
SidebarTrigger,
|
||||
useSidebar,
|
||||
@@ -25,9 +25,8 @@ import {
|
||||
import { cn } from "@/lib/utils";
|
||||
import { DotsThree, PlusCircleIcon, PlusIcon } from "@phosphor-icons/react";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import { motion } from "framer-motion";
|
||||
import { parseAsString, useQueryState } from "nuqs";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
|
||||
|
||||
@@ -66,39 +65,6 @@ export function ChatSidebar() {
|
||||
},
|
||||
});
|
||||
|
||||
const [editingSessionId, setEditingSessionId] = useState<string | null>(null);
|
||||
const [editingTitle, setEditingTitle] = useState("");
|
||||
const renameInputRef = useRef<HTMLInputElement>(null);
|
||||
const renameCancelledRef = useRef(false);
|
||||
|
||||
const { mutate: renameSession } = usePatchV2UpdateSessionTitle({
|
||||
mutation: {
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
});
|
||||
setEditingSessionId(null);
|
||||
},
|
||||
onError: (error) => {
|
||||
toast({
|
||||
title: "Failed to rename chat",
|
||||
description:
|
||||
error instanceof Error ? error.message : "An error occurred",
|
||||
variant: "destructive",
|
||||
});
|
||||
setEditingSessionId(null);
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Auto-focus the rename input when editing starts
|
||||
useEffect(() => {
|
||||
if (editingSessionId && renameInputRef.current) {
|
||||
renameInputRef.current.focus();
|
||||
renameInputRef.current.select();
|
||||
}
|
||||
}, [editingSessionId]);
|
||||
|
||||
const sessions =
|
||||
sessionsResponse?.status === 200 ? sessionsResponse.data.sessions : [];
|
||||
|
||||
@@ -110,26 +76,6 @@ export function ChatSidebar() {
|
||||
setSessionId(id);
|
||||
}
|
||||
|
||||
function handleRenameClick(
|
||||
e: React.MouseEvent,
|
||||
id: string,
|
||||
title: string | null | undefined,
|
||||
) {
|
||||
e.stopPropagation();
|
||||
renameCancelledRef.current = false;
|
||||
setEditingSessionId(id);
|
||||
setEditingTitle(title || "");
|
||||
}
|
||||
|
||||
function handleRenameSubmit(id: string) {
|
||||
const trimmed = editingTitle.trim();
|
||||
if (trimmed) {
|
||||
renameSession({ sessionId: id, data: { title: trimmed } });
|
||||
} else {
|
||||
setEditingSessionId(null);
|
||||
}
|
||||
}
|
||||
|
||||
function handleDeleteClick(
|
||||
e: React.MouseEvent,
|
||||
id: string,
|
||||
@@ -214,42 +160,29 @@ export function ChatSidebar() {
|
||||
</motion.div>
|
||||
</SidebarHeader>
|
||||
)}
|
||||
{!isCollapsed && (
|
||||
<SidebarHeader className="shrink-0 px-4 pb-4 pt-4 shadow-[0_4px_6px_-1px_rgba(0,0,0,0.05)]">
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
transition={{ duration: 0.2, delay: 0.1 }}
|
||||
className="flex flex-col gap-3 px-3"
|
||||
>
|
||||
<div className="flex items-center justify-between">
|
||||
<Text variant="h3" size="body-medium">
|
||||
Your chats
|
||||
</Text>
|
||||
<div className="relative left-6">
|
||||
<SidebarTrigger />
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={handleNewChat}
|
||||
className="w-full"
|
||||
leftIcon={<PlusIcon className="h-4 w-4" weight="bold" />}
|
||||
>
|
||||
New Chat
|
||||
</Button>
|
||||
</motion.div>
|
||||
</SidebarHeader>
|
||||
)}
|
||||
|
||||
<SidebarContent className="gap-4 overflow-y-auto px-4 py-4 [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
|
||||
{!isCollapsed && (
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
transition={{ duration: 0.2, delay: 0.1 }}
|
||||
className="flex items-center justify-between px-3"
|
||||
>
|
||||
<Text variant="h3" size="body-medium">
|
||||
Your chats
|
||||
</Text>
|
||||
<div className="relative left-6">
|
||||
<SidebarTrigger />
|
||||
</div>
|
||||
</motion.div>
|
||||
)}
|
||||
|
||||
{!isCollapsed && (
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
transition={{ duration: 0.2, delay: 0.15 }}
|
||||
className="flex flex-col gap-1"
|
||||
className="mt-4 flex flex-col gap-1"
|
||||
>
|
||||
{isLoadingSessions ? (
|
||||
<div className="flex min-h-[30rem] items-center justify-center py-4">
|
||||
@@ -270,105 +203,76 @@ export function ChatSidebar() {
|
||||
: "hover:bg-zinc-50",
|
||||
)}
|
||||
>
|
||||
{editingSessionId === session.id ? (
|
||||
<div className="px-3 py-2.5">
|
||||
<input
|
||||
ref={renameInputRef}
|
||||
type="text"
|
||||
aria-label="Rename chat"
|
||||
value={editingTitle}
|
||||
onChange={(e) => setEditingTitle(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") {
|
||||
e.currentTarget.blur();
|
||||
} else if (e.key === "Escape") {
|
||||
renameCancelledRef.current = true;
|
||||
setEditingSessionId(null);
|
||||
}
|
||||
}}
|
||||
onBlur={() => {
|
||||
if (renameCancelledRef.current) {
|
||||
renameCancelledRef.current = false;
|
||||
return;
|
||||
}
|
||||
handleRenameSubmit(session.id);
|
||||
}}
|
||||
className="w-full rounded border border-zinc-300 bg-white px-2 py-1 text-sm text-zinc-800 outline-none focus:border-purple-500 focus:ring-1 focus:ring-purple-500"
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<button
|
||||
onClick={() => handleSelectSession(session.id)}
|
||||
className="w-full px-3 py-2.5 pr-10 text-left"
|
||||
>
|
||||
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
|
||||
<div className="min-w-0 max-w-full">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"truncate font-normal",
|
||||
session.id === sessionId
|
||||
? "text-zinc-600"
|
||||
: "text-zinc-800",
|
||||
)}
|
||||
>
|
||||
<AnimatePresence mode="wait" initial={false}>
|
||||
<motion.span
|
||||
key={session.title || "untitled"}
|
||||
initial={{ opacity: 0, y: 4 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{ opacity: 0, y: -4 }}
|
||||
transition={{ duration: 0.2 }}
|
||||
className="block truncate"
|
||||
>
|
||||
{session.title || "Untitled chat"}
|
||||
</motion.span>
|
||||
</AnimatePresence>
|
||||
</Text>
|
||||
</div>
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
{formatDate(session.updated_at)}
|
||||
<button
|
||||
onClick={() => handleSelectSession(session.id)}
|
||||
className="w-full px-3 py-2.5 pr-10 text-left"
|
||||
>
|
||||
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
|
||||
<div className="min-w-0 max-w-full">
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"truncate font-normal",
|
||||
session.id === sessionId
|
||||
? "text-zinc-600"
|
||||
: "text-zinc-800",
|
||||
)}
|
||||
>
|
||||
{session.title || `Untitled chat`}
|
||||
</Text>
|
||||
</div>
|
||||
</button>
|
||||
)}
|
||||
{editingSessionId !== session.id && (
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<button
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
className="absolute right-2 top-1/2 -translate-y-1/2 rounded-full p-1.5 text-zinc-600 transition-all hover:bg-neutral-100"
|
||||
aria-label="More actions"
|
||||
>
|
||||
<DotsThree className="h-4 w-4" />
|
||||
</button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
onClick={(e) =>
|
||||
handleRenameClick(e, session.id, session.title)
|
||||
}
|
||||
>
|
||||
Rename
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem
|
||||
onClick={(e) =>
|
||||
handleDeleteClick(e, session.id, session.title)
|
||||
}
|
||||
disabled={isDeleting}
|
||||
className="text-red-600 focus:bg-red-50 focus:text-red-600"
|
||||
>
|
||||
Delete chat
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
)}
|
||||
<Text variant="small" className="text-neutral-400">
|
||||
{formatDate(session.updated_at)}
|
||||
</Text>
|
||||
</div>
|
||||
</button>
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<button
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
className="absolute right-2 top-1/2 -translate-y-1/2 rounded-full p-1.5 text-zinc-600 transition-all hover:bg-neutral-100"
|
||||
aria-label="More actions"
|
||||
>
|
||||
<DotsThree className="h-4 w-4" />
|
||||
</button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
onClick={(e) =>
|
||||
handleDeleteClick(e, session.id, session.title)
|
||||
}
|
||||
disabled={isDeleting}
|
||||
className="text-red-600 focus:bg-red-50 focus:text-red-600"
|
||||
>
|
||||
Delete chat
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
))
|
||||
)}
|
||||
</motion.div>
|
||||
)}
|
||||
</SidebarContent>
|
||||
{!isCollapsed && sessionId && (
|
||||
<SidebarFooter className="shrink-0 bg-zinc-50 p-3 pb-1 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
transition={{ duration: 0.2, delay: 0.2 }}
|
||||
>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={handleNewChat}
|
||||
className="w-full"
|
||||
leftIcon={<PlusIcon className="h-4 w-4" weight="bold" />}
|
||||
>
|
||||
New Chat
|
||||
</Button>
|
||||
</motion.div>
|
||||
</SidebarFooter>
|
||||
)}
|
||||
</Sidebar>
|
||||
|
||||
<DeleteChatDialog
|
||||
|
||||
@@ -29,6 +29,7 @@ export function DeleteChatDialog({
|
||||
}
|
||||
},
|
||||
}}
|
||||
onClose={isDeleting ? undefined : onCancel}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<Text variant="body">
|
||||
|
||||
@@ -17,19 +17,13 @@ interface Props {
|
||||
inputLayoutId: string;
|
||||
isCreatingSession: boolean;
|
||||
onCreateSession: () => void | Promise<string>;
|
||||
onSend: (message: string, files?: File[]) => void | Promise<void>;
|
||||
isUploadingFiles?: boolean;
|
||||
droppedFiles?: File[];
|
||||
onDroppedFilesConsumed?: () => void;
|
||||
onSend: (message: string) => void | Promise<void>;
|
||||
}
|
||||
|
||||
export function EmptySession({
|
||||
inputLayoutId,
|
||||
isCreatingSession,
|
||||
onSend,
|
||||
isUploadingFiles,
|
||||
droppedFiles,
|
||||
onDroppedFilesConsumed,
|
||||
}: Props) {
|
||||
const { user } = useSupabase();
|
||||
const greetingName = getGreetingName(user);
|
||||
@@ -57,12 +51,12 @@ export function EmptySession({
|
||||
return (
|
||||
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-0 py-5 md:px-6 md:py-10">
|
||||
<motion.div
|
||||
className="w-full max-w-[52rem] text-center"
|
||||
className="w-full max-w-3xl text-center"
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
transition={{ duration: 0.3 }}
|
||||
>
|
||||
<div className="mx-auto max-w-[52rem]">
|
||||
<div className="mx-auto max-w-3xl">
|
||||
<Text variant="h3" className="mb-1 !text-[1.375rem] text-zinc-700">
|
||||
Hey, <span className="text-violet-600">{greetingName}</span>
|
||||
</Text>
|
||||
@@ -80,11 +74,8 @@ export function EmptySession({
|
||||
inputId="chat-input-empty"
|
||||
onSend={onSend}
|
||||
disabled={isCreatingSession}
|
||||
isUploadingFiles={isUploadingFiles}
|
||||
placeholder={inputPlaceholder}
|
||||
className="w-full"
|
||||
droppedFiles={droppedFiles}
|
||||
onDroppedFilesConsumed={onDroppedFilesConsumed}
|
||||
/>
|
||||
</motion.div>
|
||||
</div>
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
import type { UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { getWorkDoneCounters } from "./useWorkDoneCounters";
|
||||
|
||||
interface Props {
|
||||
turnMessages: UIMessage<unknown, UIDataTypes, UITools>[];
|
||||
}
|
||||
|
||||
export function TurnStatsBar({ turnMessages }: Props) {
|
||||
const { counters } = getWorkDoneCounters(turnMessages);
|
||||
|
||||
if (counters.length === 0) return null;
|
||||
|
||||
return (
|
||||
<div className="mt-2 flex items-center gap-1.5">
|
||||
{counters.map(function renderCounter(counter, index) {
|
||||
return (
|
||||
<span key={counter.category} className="flex items-center gap-1">
|
||||
{index > 0 && (
|
||||
<span className="text-xs text-neutral-300">·</span>
|
||||
)}
|
||||
<span className="text-[11px] tabular-nums text-neutral-500">
|
||||
{counter.count} {counter.label}
|
||||
</span>
|
||||
</span>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user