Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into feat/copilot-kimi-k2-fast-model

This commit is contained in:
Zamil Majdy
2026-04-22 00:03:03 +07:00
64 changed files with 6415 additions and 237 deletions

View File

@@ -1,3 +1,6 @@
*.ignore.*
*.ign.*
.application.logs
# Claude Code local settings only — the rest of .claude/ is shared (skills etc.)
.claude/settings.local.json

View File

@@ -179,6 +179,9 @@ MEM0_API_KEY=
OPENWEATHERMAP_API_KEY=
GOOGLE_MAPS_API_KEY=
# Platform Bot Linking
PLATFORM_LINK_BASE_URL=http://localhost:3000/link
# Communication Services
DISCORD_BOT_TOKEN=
MEDIUM_API_KEY=

View File

@@ -0,0 +1 @@
"""Platform bot linking — user-facing REST routes."""

View File

@@ -0,0 +1,158 @@
"""User-facing platform_linking REST routes (JWT auth)."""
import logging
from typing import Annotated
from autogpt_libs import auth
from fastapi import APIRouter, HTTPException, Path, Security
from backend.data.db_accessors import platform_linking_db
from backend.platform_linking.models import (
ConfirmLinkResponse,
ConfirmUserLinkResponse,
DeleteLinkResponse,
LinkTokenInfoResponse,
PlatformLinkInfo,
PlatformUserLinkInfo,
)
from backend.util.exceptions import (
LinkAlreadyExistsError,
LinkFlowMismatchError,
LinkTokenExpiredError,
NotAuthorizedError,
NotFoundError,
)
logger = logging.getLogger(__name__)
router = APIRouter()
TokenPath = Annotated[
str,
Path(max_length=64, pattern=r"^[A-Za-z0-9_-]+$"),
]
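# Accepts e.g. secrets.token_urlsafe(32) output (43 chars of [A-Za-z0-9_-]);
# '/', '$', '.' or anything over 64 chars fails validation with a 422.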
def _translate(exc: Exception) -> HTTPException:
if isinstance(exc, NotFoundError):
return HTTPException(status_code=404, detail=str(exc))
if isinstance(exc, NotAuthorizedError):
return HTTPException(status_code=403, detail=str(exc))
if isinstance(exc, LinkAlreadyExistsError):
return HTTPException(status_code=409, detail=str(exc))
if isinstance(exc, LinkTokenExpiredError):
return HTTPException(status_code=410, detail=str(exc))
if isinstance(exc, LinkFlowMismatchError):
return HTTPException(status_code=400, detail=str(exc))
return HTTPException(status_code=500, detail="Internal error.")
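# e.g. _translate(NotFoundError("missing")) -> HTTPException(status_code=404, detail="missing")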
@router.get(
"/tokens/{token}/info",
response_model=LinkTokenInfoResponse,
dependencies=[Security(auth.requires_user)],
summary="Get display info for a link token",
)
async def get_link_token_info_route(token: TokenPath) -> LinkTokenInfoResponse:
try:
return await platform_linking_db().get_link_token_info(token)
except (NotFoundError, LinkTokenExpiredError) as exc:
raise _translate(exc) from exc
@router.post(
"/tokens/{token}/confirm",
response_model=ConfirmLinkResponse,
dependencies=[Security(auth.requires_user)],
summary="Confirm a SERVER link token (user must be authenticated)",
)
async def confirm_link_token(
token: TokenPath,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> ConfirmLinkResponse:
try:
return await platform_linking_db().confirm_server_link(token, user_id)
except (
NotFoundError,
LinkFlowMismatchError,
LinkTokenExpiredError,
LinkAlreadyExistsError,
) as exc:
raise _translate(exc) from exc
@router.post(
"/user-tokens/{token}/confirm",
response_model=ConfirmUserLinkResponse,
dependencies=[Security(auth.requires_user)],
summary="Confirm a USER link token (user must be authenticated)",
)
async def confirm_user_link_token(
token: TokenPath,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> ConfirmUserLinkResponse:
try:
return await platform_linking_db().confirm_user_link(token, user_id)
except (
NotFoundError,
LinkFlowMismatchError,
LinkTokenExpiredError,
LinkAlreadyExistsError,
) as exc:
raise _translate(exc) from exc
@router.get(
"/links",
response_model=list[PlatformLinkInfo],
dependencies=[Security(auth.requires_user)],
summary="List all platform servers linked to the authenticated user",
)
async def list_my_links(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> list[PlatformLinkInfo]:
return await platform_linking_db().list_server_links(user_id)
@router.get(
"/user-links",
response_model=list[PlatformUserLinkInfo],
dependencies=[Security(auth.requires_user)],
summary="List all DM links for the authenticated user",
)
async def list_my_user_links(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> list[PlatformUserLinkInfo]:
return await platform_linking_db().list_user_links(user_id)
@router.delete(
"/links/{link_id}",
response_model=DeleteLinkResponse,
dependencies=[Security(auth.requires_user)],
summary="Unlink a platform server",
)
async def delete_link(
link_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> DeleteLinkResponse:
try:
return await platform_linking_db().delete_server_link(link_id, user_id)
except (NotFoundError, NotAuthorizedError) as exc:
raise _translate(exc) from exc
@router.delete(
"/user-links/{link_id}",
response_model=DeleteLinkResponse,
dependencies=[Security(auth.requires_user)],
summary="Unlink a DM / user link",
)
async def delete_user_link_route(
link_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> DeleteLinkResponse:
try:
return await platform_linking_db().delete_user_link(link_id, user_id)
except (NotFoundError, NotAuthorizedError) as exc:
raise _translate(exc) from exc
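For illustration, a minimal client for the confirm flow exposed above. This is a sketch: the base URL and JWT are placeholders, not defined anywhere in this diff.

import httpx

async def confirm_server_link(token: str, jwt: str) -> dict:
    # POST /api/platform-linking/tokens/{token}/confirm with the user's JWT.
    # Domain errors surface as the statuses mapped by _translate():
    # 404 not found, 400 wrong flow, 410 expired/used, 409 already linked.
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.post(
            f"/api/platform-linking/tokens/{token}/confirm",
            headers={"Authorization": f"Bearer {jwt}"},
        )
        resp.raise_for_status()
        return resp.json()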

View File

@@ -0,0 +1,264 @@
"""Route tests: domain exceptions → HTTPException status codes."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException
from backend.util.exceptions import (
LinkAlreadyExistsError,
LinkFlowMismatchError,
LinkTokenExpiredError,
NotAuthorizedError,
NotFoundError,
)
def _db_mock(**method_configs):
"""Return a mock of the accessor's return value with the given AsyncMocks."""
db = MagicMock()
for name, mock in method_configs.items():
setattr(db, name, mock)
return db
class TestTokenInfoRouteTranslation:
@pytest.mark.asyncio
async def test_not_found_maps_to_404(self):
from backend.api.features.platform_linking.routes import (
get_link_token_info_route,
)
db = _db_mock(
get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as exc:
await get_link_token_info_route(token="abc")
assert exc.value.status_code == 404
@pytest.mark.asyncio
async def test_expired_maps_to_410(self):
from backend.api.features.platform_linking.routes import (
get_link_token_info_route,
)
db = _db_mock(
get_link_token_info=AsyncMock(side_effect=LinkTokenExpiredError("expired"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as exc:
await get_link_token_info_route(token="abc")
assert exc.value.status_code == 410
class TestConfirmLinkRouteTranslation:
@pytest.mark.asyncio
@pytest.mark.parametrize(
"exc,expected_status",
[
(NotFoundError("missing"), 404),
(LinkFlowMismatchError("wrong flow"), 400),
(LinkTokenExpiredError("expired"), 410),
(LinkAlreadyExistsError("already"), 409),
],
)
async def test_translation(self, exc: Exception, expected_status: int):
from backend.api.features.platform_linking.routes import confirm_link_token
db = _db_mock(confirm_server_link=AsyncMock(side_effect=exc))
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as ctx:
await confirm_link_token(token="abc", user_id="u1")
assert ctx.value.status_code == expected_status
class TestConfirmUserLinkRouteTranslation:
@pytest.mark.asyncio
@pytest.mark.parametrize(
"exc,expected_status",
[
(NotFoundError("missing"), 404),
(LinkFlowMismatchError("wrong flow"), 400),
(LinkTokenExpiredError("expired"), 410),
(LinkAlreadyExistsError("already"), 409),
],
)
async def test_translation(self, exc: Exception, expected_status: int):
from backend.api.features.platform_linking.routes import confirm_user_link_token
db = _db_mock(confirm_user_link=AsyncMock(side_effect=exc))
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as ctx:
await confirm_user_link_token(token="abc", user_id="u1")
assert ctx.value.status_code == expected_status
class TestDeleteLinkRouteTranslation:
@pytest.mark.asyncio
async def test_not_found_maps_to_404(self):
from backend.api.features.platform_linking.routes import delete_link
db = _db_mock(
delete_server_link=AsyncMock(side_effect=NotFoundError("missing"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as exc:
await delete_link(link_id="x", user_id="u1")
assert exc.value.status_code == 404
@pytest.mark.asyncio
async def test_not_owned_maps_to_403(self):
from backend.api.features.platform_linking.routes import delete_link
db = _db_mock(
delete_server_link=AsyncMock(side_effect=NotAuthorizedError("nope"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as exc:
await delete_link(link_id="x", user_id="u1")
assert exc.value.status_code == 403
class TestDeleteUserLinkRouteTranslation:
@pytest.mark.asyncio
async def test_not_found_maps_to_404(self):
from backend.api.features.platform_linking.routes import delete_user_link_route
db = _db_mock(delete_user_link=AsyncMock(side_effect=NotFoundError("missing")))
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as exc:
await delete_user_link_route(link_id="x", user_id="u1")
assert exc.value.status_code == 404
@pytest.mark.asyncio
async def test_not_owned_maps_to_403(self):
from backend.api.features.platform_linking.routes import delete_user_link_route
db = _db_mock(
delete_user_link=AsyncMock(side_effect=NotAuthorizedError("nope"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
with pytest.raises(HTTPException) as exc:
await delete_user_link_route(link_id="x", user_id="u1")
assert exc.value.status_code == 403
# ── Adversarial: malformed token path params ──────────────────────────
class TestAdversarialTokenPath:
# TokenPath enforces `^[A-Za-z0-9_-]+$` + max_length=64.
@pytest.fixture
def client(self):
import fastapi
from autogpt_libs.auth import get_user_id, requires_user
from fastapi.testclient import TestClient
import backend.api.features.platform_linking.routes as routes_mod
app = fastapi.FastAPI()
app.dependency_overrides[requires_user] = lambda: None
app.dependency_overrides[get_user_id] = lambda: "caller-user"
app.include_router(routes_mod.router, prefix="/api/platform-linking")
return TestClient(app)
def test_rejects_token_with_special_chars(self, client):
response = client.get("/api/platform-linking/tokens/bad%24token/info")
assert response.status_code == 422
def test_rejects_token_with_path_traversal(self, client):
for probe in ("..%2F..", "foo..bar", "foo%2Fbar"):
response = client.get(f"/api/platform-linking/tokens/{probe}/info")
assert response.status_code in (
404,
422,
), f"path-traversal probe {probe!r} returned {response.status_code}"
def test_rejects_token_too_long(self, client):
long_token = "a" * 65
response = client.get(f"/api/platform-linking/tokens/{long_token}/info")
assert response.status_code == 422
def test_accepts_token_at_max_length(self, client):
token = "a" * 64
db = _db_mock(
get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
response = client.get(f"/api/platform-linking/tokens/{token}/info")
assert response.status_code == 404
def test_accepts_urlsafe_b64_token_shape(self, client):
db = _db_mock(
get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
response = client.get("/api/platform-linking/tokens/abc-_XYZ123-_abc/info")
assert response.status_code == 404
def test_confirm_rejects_malformed_token(self, client):
response = client.post("/api/platform-linking/tokens/bad%24token/confirm")
assert response.status_code == 422
class TestAdversarialDeleteLinkId:
"""DELETE link_id has no regex — ensure weird values are handled via
NotFoundError (no crash, no cross-user leak)."""
@pytest.fixture
def client(self):
import fastapi
from autogpt_libs.auth import get_user_id, requires_user
from fastapi.testclient import TestClient
import backend.api.features.platform_linking.routes as routes_mod
app = fastapi.FastAPI()
app.dependency_overrides[requires_user] = lambda: None
app.dependency_overrides[get_user_id] = lambda: "caller-user"
app.include_router(routes_mod.router, prefix="/api/platform-linking")
return TestClient(app)
def test_weird_link_id_returns_404(self, client):
db = _db_mock(
delete_server_link=AsyncMock(side_effect=NotFoundError("missing"))
)
with patch(
"backend.api.features.platform_linking.routes.platform_linking_db",
return_value=db,
):
for link_id in ("'; DROP TABLE links;--", "../../etc/passwd", ""):
response = client.delete(f"/api/platform-linking/links/{link_id}")
assert response.status_code in (404, 405)

View File

@@ -30,6 +30,7 @@ from pydantic import BaseModel, Field
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict
from backend.api.features.workspace.routes import create_file_download_response
from backend.api.model import (
CreateAPIKeyRequest,
CreateAPIKeyResponse,
@@ -96,6 +97,7 @@ from backend.data.user import (
update_user_notification_preference,
update_user_timezone,
)
from backend.data.workspace import get_workspace_file_by_id
from backend.executor import scheduler
from backend.executor import utils as execution_utils
from backend.integrations.webhooks.graph_lifecycle_hooks import (
@@ -1703,6 +1705,10 @@ async def enable_execution_sharing(
# Generate a unique share token
share_token = str(uuid.uuid4())
# Remove stale allowlist records before updating the token — prevents a
# window where old records + new token could coexist.
await execution_db.delete_shared_execution_files(execution_id=graph_exec_id)
# Update the execution with share info
await execution_db.update_graph_execution_share_status(
execution_id=graph_exec_id,
@@ -1712,6 +1718,14 @@ async def enable_execution_sharing(
shared_at=datetime.now(timezone.utc),
)
# Create allowlist of workspace files referenced in outputs
await execution_db.create_shared_execution_files(
execution_id=graph_exec_id,
share_token=share_token,
user_id=user_id,
outputs=execution.outputs,
)
# Return the share URL
frontend_url = settings.config.frontend_base_url or "http://localhost:3000"
share_url = f"{frontend_url}/share/{share_token}"
@@ -1737,6 +1751,9 @@ async def disable_execution_sharing(
if not execution:
raise HTTPException(status_code=404, detail="Execution not found")
# Remove shared file allowlist records
await execution_db.delete_shared_execution_files(execution_id=graph_exec_id)
# Remove share info
await execution_db.update_graph_execution_share_status(
execution_id=graph_exec_id,
@@ -1762,6 +1779,43 @@ async def get_shared_execution(
return execution
@v1_router.get(
"/public/shared/{share_token}/files/{file_id}/download",
summary="Download a file from a shared execution",
operation_id="download_shared_file",
tags=["graphs"],
)
async def download_shared_file(
share_token: Annotated[
str,
Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"),
],
file_id: Annotated[
str,
Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"),
],
) -> Response:
"""Download a workspace file from a shared execution (no auth required).
Validates that the file was explicitly exposed when sharing was enabled.
Returns a uniform 404 for all failure modes to prevent enumeration attacks.
"""
# Single-query validation against the allowlist
execution_id = await execution_db.get_shared_execution_file(
share_token=share_token, file_id=file_id
)
if not execution_id:
raise HTTPException(status_code=404, detail="Not found")
# Look up the actual file (no workspace scoping needed — the allowlist
# already validated that this file belongs to the shared execution)
file = await get_workspace_file_by_id(file_id)
if not file:
raise HTTPException(status_code=404, detail="Not found")
return await create_file_download_response(file, inline=True)
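A sketch of the anonymous consumption path (the URL shape comes from this diff; the host is a placeholder):

import httpx

def fetch_shared_file(share_token: str, file_id: str) -> bytes:
    # No Authorization header: the allowlist lookup is the only gate, and
    # every failure mode collapses into a uniform 404.
    url = (
        "http://localhost:8000"  # placeholder host
        f"/api/public/shared/{share_token}/files/{file_id}/download"
    )
    resp = httpx.get(url, follow_redirects=True)  # GCS-backed files 302 to a signed URL
    resp.raise_for_status()
    return resp.content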
########################################################
##################### Schedules ########################
########################################################

View File

@@ -0,0 +1,157 @@
"""Tests for the public shared file download endpoint."""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from starlette.responses import Response
from backend.api.features.v1 import v1_router
from backend.data.workspace import WorkspaceFile
app = FastAPI()
app.include_router(v1_router, prefix="/api")
VALID_TOKEN = "550e8400-e29b-41d4-a716-446655440000"
VALID_FILE_ID = "6ba7b810-9dad-11d1-80b4-00c04fd430c8"
def _make_workspace_file(**overrides) -> WorkspaceFile:
defaults = {
"id": VALID_FILE_ID,
"workspace_id": "ws-001",
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
"updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
"name": "image.png",
"path": "/image.png",
"storage_path": "local://uploads/image.png",
"mime_type": "image/png",
"size_bytes": 4,
"checksum": None,
"is_deleted": False,
"deleted_at": None,
"metadata": {},
}
defaults.update(overrides)
return WorkspaceFile(**defaults)
def _mock_download_response():
"""Return an async handler that builds a Response with the requested disposition."""
async def _handler(file, *, inline=False):
return Response(
content=b"\x89PNG",
media_type="image/png",
headers={
"Content-Disposition": (
'inline; filename="image.png"'
if inline
else 'attachment; filename="image.png"'
),
"Content-Length": "4",
},
)
return _handler
class TestDownloadSharedFile:
"""Tests for GET /api/public/shared/{token}/files/{id}/download."""
@pytest.fixture(autouse=True)
def _client(self):
self.client = TestClient(app, raise_server_exceptions=False)
def test_valid_token_and_file_returns_inline_content(self):
with (
patch(
"backend.api.features.v1.execution_db.get_shared_execution_file",
new_callable=AsyncMock,
return_value="exec-123",
),
patch(
"backend.api.features.v1.get_workspace_file_by_id",
new_callable=AsyncMock,
return_value=_make_workspace_file(),
),
patch(
"backend.api.features.v1.create_file_download_response",
side_effect=_mock_download_response(),
),
):
response = self.client.get(
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
)
assert response.status_code == 200
assert response.content == b"\x89PNG"
assert "inline" in response.headers["Content-Disposition"]
def test_invalid_token_format_returns_422(self):
response = self.client.get(
f"/api/public/shared/not-a-uuid/files/{VALID_FILE_ID}/download"
)
assert response.status_code == 422
def test_token_not_in_allowlist_returns_404(self):
with patch(
"backend.api.features.v1.execution_db.get_shared_execution_file",
new_callable=AsyncMock,
return_value=None,
):
response = self.client.get(
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
)
assert response.status_code == 404
def test_file_missing_from_workspace_returns_404(self):
with (
patch(
"backend.api.features.v1.execution_db.get_shared_execution_file",
new_callable=AsyncMock,
return_value="exec-123",
),
patch(
"backend.api.features.v1.get_workspace_file_by_id",
new_callable=AsyncMock,
return_value=None,
),
):
response = self.client.get(
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
)
assert response.status_code == 404
def test_uniform_404_prevents_enumeration(self):
"""Both failure modes produce identical 404 — no information leak."""
with patch(
"backend.api.features.v1.execution_db.get_shared_execution_file",
new_callable=AsyncMock,
return_value=None,
):
resp_no_allow = self.client.get(
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
)
with (
patch(
"backend.api.features.v1.execution_db.get_shared_execution_file",
new_callable=AsyncMock,
return_value="exec-123",
),
patch(
"backend.api.features.v1.get_workspace_file_by_id",
new_callable=AsyncMock,
return_value=None,
),
):
resp_no_file = self.client.get(
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
)
assert resp_no_allow.status_code == 404
assert resp_no_file.status_code == 404
assert resp_no_allow.json() == resp_no_file.json()

View File

@@ -29,7 +29,9 @@ from backend.util.workspace import WorkspaceManager
from backend.util.workspace_storage import get_workspace_storage
def _sanitize_filename_for_header(filename: str) -> str:
def _sanitize_filename_for_header(
filename: str, disposition: str = "attachment"
) -> str:
"""
Sanitize filename for Content-Disposition header to prevent header injection.
@@ -44,11 +46,11 @@ def _sanitize_filename_for_header(filename: str) -> str:
# Check if filename has non-ASCII characters
try:
sanitized.encode("ascii")
return f'attachment; filename="{sanitized}"'
return f'{disposition}; filename="{sanitized}"'
except UnicodeEncodeError:
# Use RFC5987 encoding for UTF-8 filenames
encoded = quote(sanitized, safe="")
return f"attachment; filename*=UTF-8''{encoded}"
return f"{disposition}; filename*=UTF-8''{encoded}"
logger = logging.getLogger(__name__)
@@ -58,19 +60,26 @@ router = fastapi.APIRouter(
)
def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
def _create_streaming_response(
content: bytes, file: WorkspaceFile, *, inline: bool = False
) -> Response:
"""Create a streaming response for file content."""
disposition = _sanitize_filename_for_header(
file.name, disposition="inline" if inline else "attachment"
)
return Response(
content=content,
media_type=file.mime_type,
headers={
"Content-Disposition": _sanitize_filename_for_header(file.name),
"Content-Disposition": disposition,
"Content-Length": str(len(content)),
},
)
async def _create_file_download_response(file: WorkspaceFile) -> Response:
async def create_file_download_response(
file: WorkspaceFile, *, inline: bool = False
) -> Response:
"""
Create a download response for a workspace file.
@@ -82,7 +91,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
# For local storage, stream the file directly
if file.storage_path.startswith("local://"):
content = await storage.retrieve(file.storage_path)
return _create_streaming_response(content, file)
return _create_streaming_response(content, file, inline=inline)
# For GCS, try to redirect to signed URL, fall back to streaming
try:
@@ -90,7 +99,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
# If we got back an API path (fallback), stream directly instead
if url.startswith("/api/"):
content = await storage.retrieve(file.storage_path)
return _create_streaming_response(content, file)
return _create_streaming_response(content, file, inline=inline)
return fastapi.responses.RedirectResponse(url=url, status_code=302)
except Exception as e:
# Log the signed URL failure with context
@@ -102,7 +111,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
# Fall back to streaming directly from GCS
try:
content = await storage.retrieve(file.storage_path)
return _create_streaming_response(content, file)
return _create_streaming_response(content, file, inline=inline)
except Exception as fallback_error:
logger.error(
f"Fallback streaming also failed for file {file.id} "
@@ -169,7 +178,7 @@ async def download_file(
if file is None:
raise fastapi.HTTPException(status_code=404, detail="File not found")
return await _create_file_download_response(file)
return await create_file_download_response(file)
@router.delete(

View File

@@ -600,3 +600,221 @@ def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace):
mock_instance.list_files.assert_called_once_with(
limit=11, offset=50, include_all_sessions=True
)
# -- _sanitize_filename_for_header tests --
class TestSanitizeFilenameForHeader:
def test_simple_ascii_attachment(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
assert _sanitize_filename_for_header("report.pdf") == (
'attachment; filename="report.pdf"'
)
def test_inline_disposition(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
assert _sanitize_filename_for_header("image.png", disposition="inline") == (
'inline; filename="image.png"'
)
def test_strips_cr_lf_null(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
result = _sanitize_filename_for_header("a\rb\nc\x00d.txt")
assert "\r" not in result
assert "\n" not in result
assert "\x00" not in result
assert 'filename="abcd.txt"' in result
def test_escapes_quotes(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
result = _sanitize_filename_for_header('file"name.txt')
assert 'filename="file\\"name.txt"' in result
def test_header_injection_blocked(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
result = _sanitize_filename_for_header("evil.txt\r\nX-Injected: true")
# CR/LF stripped — the remaining text is safely inside the quoted value
assert "\r" not in result
assert "\n" not in result
assert result == 'attachment; filename="evil.txtX-Injected: true"'
def test_unicode_uses_rfc5987(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
result = _sanitize_filename_for_header("日本語.pdf")
assert "filename*=UTF-8''" in result
assert "attachment" in result
def test_unicode_inline(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
result = _sanitize_filename_for_header("图片.png", disposition="inline")
assert result.startswith("inline; filename*=UTF-8''")
def test_empty_filename(self):
from backend.api.features.workspace.routes import _sanitize_filename_for_header
result = _sanitize_filename_for_header("")
assert result == 'attachment; filename=""'
# -- _create_streaming_response tests --
class TestCreateStreamingResponse:
def test_attachment_disposition_by_default(self):
from backend.api.features.workspace.routes import _create_streaming_response
file = _make_file(name="data.bin", mime_type="application/octet-stream")
response = _create_streaming_response(b"binary-data", file)
assert (
response.headers["Content-Disposition"] == 'attachment; filename="data.bin"'
)
assert response.headers["Content-Type"] == "application/octet-stream"
assert response.headers["Content-Length"] == "11"
assert response.body == b"binary-data"
def test_inline_disposition(self):
from backend.api.features.workspace.routes import _create_streaming_response
file = _make_file(name="photo.png", mime_type="image/png")
response = _create_streaming_response(b"\x89PNG", file, inline=True)
assert response.headers["Content-Disposition"] == 'inline; filename="photo.png"'
assert response.headers["Content-Type"] == "image/png"
def test_inline_sanitizes_filename(self):
from backend.api.features.workspace.routes import _create_streaming_response
file = _make_file(name='evil"\r\n.txt', mime_type="text/plain")
response = _create_streaming_response(b"data", file, inline=True)
assert "\r" not in response.headers["Content-Disposition"]
assert "\n" not in response.headers["Content-Disposition"]
assert "inline" in response.headers["Content-Disposition"]
def test_content_length_matches_body(self):
from backend.api.features.workspace.routes import _create_streaming_response
content = b"x" * 1000
file = _make_file(name="big.bin", mime_type="application/octet-stream")
response = _create_streaming_response(content, file)
assert response.headers["Content-Length"] == "1000"
# -- create_file_download_response tests --
class TestCreateFileDownloadResponse:
@pytest.mark.asyncio
async def test_local_storage_returns_streaming_response(self, mocker):
from backend.api.features.workspace.routes import create_file_download_response
mock_storage = AsyncMock()
mock_storage.retrieve.return_value = b"file contents"
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage",
return_value=mock_storage,
)
file = _make_file(
storage_path="local://uploads/test.txt",
mime_type="text/plain",
)
response = await create_file_download_response(file)
assert response.status_code == 200
assert response.body == b"file contents"
assert "attachment" in response.headers["Content-Disposition"]
@pytest.mark.asyncio
async def test_local_storage_inline(self, mocker):
from backend.api.features.workspace.routes import create_file_download_response
mock_storage = AsyncMock()
mock_storage.retrieve.return_value = b"\x89PNG"
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage",
return_value=mock_storage,
)
file = _make_file(
storage_path="local://uploads/photo.png",
mime_type="image/png",
name="photo.png",
)
response = await create_file_download_response(file, inline=True)
assert "inline" in response.headers["Content-Disposition"]
@pytest.mark.asyncio
async def test_gcs_redirect(self, mocker):
from backend.api.features.workspace.routes import create_file_download_response
mock_storage = AsyncMock()
mock_storage.get_download_url.return_value = (
"https://storage.googleapis.com/signed-url"
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage",
return_value=mock_storage,
)
file = _make_file(storage_path="gcs://bucket/file.pdf")
response = await create_file_download_response(file)
assert response.status_code == 302
assert (
response.headers["location"] == "https://storage.googleapis.com/signed-url"
)
@pytest.mark.asyncio
async def test_gcs_api_fallback_streams_directly(self, mocker):
from backend.api.features.workspace.routes import create_file_download_response
mock_storage = AsyncMock()
mock_storage.get_download_url.return_value = "/api/fallback"
mock_storage.retrieve.return_value = b"fallback content"
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage",
return_value=mock_storage,
)
file = _make_file(storage_path="gcs://bucket/file.txt")
response = await create_file_download_response(file)
assert response.status_code == 200
assert response.body == b"fallback content"
@pytest.mark.asyncio
async def test_gcs_signed_url_failure_falls_back_to_streaming(self, mocker):
from backend.api.features.workspace.routes import create_file_download_response
mock_storage = AsyncMock()
mock_storage.get_download_url.side_effect = RuntimeError("GCS error")
mock_storage.retrieve.return_value = b"streamed"
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage",
return_value=mock_storage,
)
file = _make_file(storage_path="gcs://bucket/file.txt")
response = await create_file_download_response(file)
assert response.status_code == 200
assert response.body == b"streamed"
@pytest.mark.asyncio
async def test_gcs_total_failure_raises(self, mocker):
from backend.api.features.workspace.routes import create_file_download_response
mock_storage = AsyncMock()
mock_storage.get_download_url.side_effect = RuntimeError("GCS error")
mock_storage.retrieve.side_effect = RuntimeError("Also failed")
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_storage",
return_value=mock_storage,
)
file = _make_file(storage_path="gcs://bucket/file.txt")
with pytest.raises(RuntimeError, match="Also failed"):
await create_file_download_response(file)

View File

@@ -32,6 +32,7 @@ import backend.api.features.library.routes
import backend.api.features.mcp.routes as mcp_routes
import backend.api.features.oauth
import backend.api.features.otto.routes
import backend.api.features.platform_linking.routes
import backend.api.features.postmark.postmark
import backend.api.features.store.model
import backend.api.features.store.routes
@@ -378,6 +379,11 @@ app.include_router(
tags=["oauth"],
prefix="/api/oauth",
)
app.include_router(
backend.api.features.platform_linking.routes.router,
tags=["platform-linking"],
prefix="/api/platform-linking",
)
app.mount("/external-api", external_api)

View File

@@ -42,11 +42,13 @@ def main(**kwargs):
from backend.data.db_manager import DatabaseManager
from backend.executor import ExecutionManager, Scheduler
from backend.notifications import NotificationManager
from backend.platform_linking.manager import PlatformLinkingManager
run_processes(
DatabaseManager().set_log_level("warning"),
Scheduler(),
NotificationManager(),
PlatformLinkingManager(),
WebsocketServer(),
AgentServer(),
ExecutionManager(),

View File

@@ -155,3 +155,16 @@ def platform_cost_db():
platform_cost_db = get_database_manager_async_client()
return platform_cost_db
def platform_linking_db():
if db.is_connected():
from backend.platform_linking import db as _platform_linking_db
platform_linking_db = _platform_linking_db
else:
from backend.util.clients import get_database_manager_async_client
platform_linking_db = get_database_manager_async_client()
return platform_linking_db
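A sketch of the calling convention (names from this diff; the user ID is a placeholder):

from backend.data.db_accessors import platform_linking_db

async def print_server_links(user_id: str) -> None:
    # Resolved per call: the local Prisma-backed module when connected,
    # otherwise the DatabaseManager RPC client.
    db = platform_linking_db()
    for link in await db.list_server_links(user_id):
        print(link.platform, link.platform_server_id, link.server_name)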

View File

@@ -120,6 +120,7 @@ from backend.data.workspace import (
list_workspace_files,
soft_delete_workspace_file,
)
from backend.platform_linking import db as platform_linking_db
from backend.util.service import (
AppService,
AppServiceClient,
@@ -338,6 +339,22 @@ class DatabaseManager(AppService):
# ============ Platform Cost Tracking ============ #
log_platform_cost = _(log_platform_cost)
# ============ Platform Linking ============ #
find_server_link_owner = _(platform_linking_db.find_server_link_owner)
find_user_link_owner = _(platform_linking_db.find_user_link_owner)
resolve_server_link = _(platform_linking_db.resolve_server_link)
resolve_user_link = _(platform_linking_db.resolve_user_link)
create_server_link_token = _(platform_linking_db.create_server_link_token)
create_user_link_token = _(platform_linking_db.create_user_link_token)
get_link_token_status = _(platform_linking_db.get_link_token_status)
get_link_token_info = _(platform_linking_db.get_link_token_info)
confirm_server_link = _(platform_linking_db.confirm_server_link)
confirm_user_link = _(platform_linking_db.confirm_user_link)
list_server_links = _(platform_linking_db.list_server_links)
list_user_links = _(platform_linking_db.list_user_links)
delete_server_link = _(platform_linking_db.delete_server_link)
delete_user_link = _(platform_linking_db.delete_user_link)
# ============ CoPilot Chat Sessions ============ #
get_chat_session = _(chat_db.get_chat_session)
create_chat_session = _(chat_db.create_chat_session)
@@ -540,6 +557,22 @@ class DatabaseManagerAsyncClient(AppServiceClient):
# ============ Platform Cost Tracking ============ #
log_platform_cost = d.log_platform_cost
# ============ Platform Linking ============ #
find_server_link_owner = d.find_server_link_owner
find_user_link_owner = d.find_user_link_owner
resolve_server_link = d.resolve_server_link
resolve_user_link = d.resolve_user_link
create_server_link_token = d.create_server_link_token
create_user_link_token = d.create_user_link_token
get_link_token_status = d.get_link_token_status
get_link_token_info = d.get_link_token_info
confirm_server_link = d.confirm_server_link
confirm_user_link = d.confirm_user_link
list_server_links = d.list_server_links
list_user_links = d.list_user_links
delete_server_link = d.delete_server_link
delete_user_link = d.delete_user_link
# ============ CoPilot Chat Sessions ============ #
get_chat_session = d.get_chat_session
create_chat_session = d.create_chat_session

View File

@@ -19,11 +19,15 @@ from typing import (
from prisma import Json
from prisma.enums import AgentExecutionStatus
from prisma.errors import ForeignKeyViolationError, UniqueViolationError
from prisma.models import (
AgentGraphExecution,
AgentNodeExecution,
AgentNodeExecutionInputOutput,
AgentNodeExecutionKeyValueData,
SharedExecutionFile,
UserWorkspace,
UserWorkspaceFile,
)
from prisma.types import (
AgentGraphExecutionOrderByInput,
@@ -1602,6 +1606,121 @@ async def get_graph_execution_by_share_token(
)
def _extract_workspace_file_ids(outputs: CompletedBlockOutput) -> set[str]:
"""Extract workspace file IDs from execution outputs.
Scans all output values for workspace:// URI strings and extracts
the file IDs. Only matches values that are plain strings starting
with workspace://, not substrings within larger text.
"""
file_ids: set[str] = set()
def _scan(value: Any) -> None:
if isinstance(value, str) and value.startswith("workspace://"):
raw = value.removeprefix("workspace://")
file_ref = raw.split("#", 1)[0]
if file_ref and not file_ref.startswith("/"):
file_ids.add(file_ref)
elif isinstance(value, list):
for item in value:
_scan(item)
elif isinstance(value, dict):
for v in value.values():
_scan(v)
for output_values in outputs.values():
if isinstance(output_values, list):
for val in output_values:
_scan(val)
else:
_scan(output_values)
return file_ids
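# e.g. {"image": ["workspace://abc123#image/png"], "text": ["hello"]} -> {"abc123"}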
async def create_shared_execution_files(
execution_id: str,
share_token: str,
user_id: str,
outputs: CompletedBlockOutput,
) -> int:
"""Scan execution outputs for workspace files and create allowlist records.
Only files belonging to the user's workspace are allowlisted — prevents
cross-workspace file exposure via crafted outputs.
Returns the number of records created.
"""
file_ids = _extract_workspace_file_ids(outputs)
if not file_ids:
return 0
# Validate file IDs belong to the user's workspace
workspace = await UserWorkspace.prisma().find_unique(where={"userId": user_id})
if not workspace:
return 0
owned_files = await UserWorkspaceFile.prisma().find_many(
where={
"id": {"in": list(file_ids)},
"workspaceId": workspace.id,
"isDeleted": False,
}
)
owned_ids = {f.id for f in owned_files}
created = 0
for file_id in owned_ids:
try:
await SharedExecutionFile.prisma().create(
data={
"executionId": execution_id,
"fileId": file_id,
"shareToken": share_token,
}
)
created += 1
except UniqueViolationError:
logger.debug(
f"Skipping shared file record for {file_id}: record already exists"
)
except ForeignKeyViolationError:
logger.debug(
f"Skipping shared file record for {file_id}: file does not exist"
)
return created
async def delete_shared_execution_files(execution_id: str) -> int:
"""Delete all shared file records for an execution.
Returns the number of records deleted.
"""
result = await SharedExecutionFile.prisma().delete_many(
where={"executionId": execution_id}
)
return result
async def get_shared_execution_file(
share_token: str,
file_id: str,
) -> str | None:
"""Look up a file ID in the shared execution file allowlist.
Returns the execution ID if the file is in the allowlist, None otherwise.
Uses a single query so all failure modes share one code path and return a
uniform None, preventing enumeration via response or timing differences.
"""
record = await SharedExecutionFile.prisma().find_first(
where={
"shareToken": share_token,
"fileId": file_id,
}
)
return record.executionId if record else None
async def get_frequently_executed_graphs(
days_back: int = 30,
min_executions: int = 10,

View File

@@ -0,0 +1,72 @@
"""Tests for SharedExecutionFile workspace URI extraction logic."""
from backend.data.execution import _extract_workspace_file_ids
class TestExtractWorkspaceFileIds:
def test_extracts_simple_workspace_uri(self):
outputs = {"image": ["workspace://abc123"]}
assert _extract_workspace_file_ids(outputs) == {"abc123"}
def test_extracts_workspace_uri_with_mime_fragment(self):
outputs = {"image": ["workspace://abc123#image/png"]}
assert _extract_workspace_file_ids(outputs) == {"abc123"}
def test_extracts_multiple_files_from_multiple_outputs(self):
outputs = {
"images": ["workspace://file1#image/png", "workspace://file2#image/jpeg"],
"video": ["workspace://file3#video/mp4"],
}
assert _extract_workspace_file_ids(outputs) == {"file1", "file2", "file3"}
def test_ignores_non_workspace_strings(self):
outputs = {
"text": ["hello world"],
"url": ["https://example.com/image.png"],
"data": ["data:image/png;base64,abc"],
}
assert _extract_workspace_file_ids(outputs) == set()
def test_ignores_path_references(self):
"""workspace:///path/to/file is a path reference, not a file ID."""
outputs = {"file": ["workspace:///path/to/file.txt"]}
assert _extract_workspace_file_ids(outputs) == set()
def test_handles_nested_dicts_in_output_values(self):
outputs = {
"result": [{"url": "workspace://nested-file#image/png", "label": "test"}]
}
assert _extract_workspace_file_ids(outputs) == {"nested-file"}
def test_handles_nested_lists_in_output_values(self):
outputs = {"result": [["workspace://inner-file"]]}
assert _extract_workspace_file_ids(outputs) == {"inner-file"}
def test_handles_empty_outputs(self):
assert _extract_workspace_file_ids({}) == set()
def test_handles_non_string_values(self):
outputs = {"count": [42], "flag": [True], "empty": [None]}
assert _extract_workspace_file_ids(outputs) == set()
def test_deduplicates_repeated_file_ids(self):
outputs = {
"a": ["workspace://same-file#image/png"],
"b": ["workspace://same-file#image/jpeg"],
}
assert _extract_workspace_file_ids(outputs) == {"same-file"}
def test_does_not_match_workspace_substring_in_text(self):
"""Plain text that contains workspace:// as a substring should NOT be extracted
because the value itself must start with workspace://."""
outputs = {"text": ["check out workspace://fake-id for details"]}
# The string starts with "check out", not "workspace://", so no match
assert _extract_workspace_file_ids(outputs) == set()
def test_mixed_workspace_and_non_workspace_outputs(self):
outputs = {
"image": ["workspace://real-file#image/png"],
"text": ["just some text"],
"url": ["https://example.com"],
}
assert _extract_workspace_file_ids(outputs) == {"real-file"}

View File

@@ -204,6 +204,22 @@ async def get_workspace_file(
return WorkspaceFile.from_db(file) if file else None
async def get_workspace_file_by_id(
file_id: str,
) -> Optional[WorkspaceFile]:
"""
Get a workspace file by ID without workspace scoping.
Only use this when access has already been validated through another
mechanism (e.g. SharedExecutionFile allowlist). For user-facing
endpoints, use get_workspace_file() which enforces workspace scoping.
"""
file = await UserWorkspaceFile.prisma().find_first(
where={"id": file_id, "isDeleted": False}
)
return WorkspaceFile.from_db(file) if file else None
async def get_workspace_file_by_path(
workspace_id: str,
path: str,

View File

@@ -0,0 +1 @@
"""Platform bot linking: helpers, chat orchestration, and AppService."""

View File

@@ -0,0 +1,112 @@
"""Chat-turn orchestration for the platform bot bridge."""
import logging
from uuid import uuid4
from backend.copilot import stream_registry
from backend.copilot.executor.utils import enqueue_copilot_turn
from backend.copilot.model import (
ChatMessage,
append_and_save_message,
create_chat_session,
get_chat_session,
)
from backend.data.db_accessors import platform_linking_db
from backend.util.exceptions import DuplicateChatMessageError, NotFoundError
from .models import BotChatRequest, ChatTurnHandle
logger = logging.getLogger(__name__)
CHAT_TOOL_CALL_ID = "chat_stream"
CHAT_TOOL_NAME = "chat"
async def resolve_chat_owner(request: BotChatRequest) -> str:
"""Return the AutoGPT user ID that owns the platform conversation.
Server context → server owner. DM context → the DM-linked user.
"""
platform = request.platform.value
db = platform_linking_db()
if request.platform_server_id:
owner = await db.find_server_link_owner(platform, request.platform_server_id)
if owner is None:
raise NotFoundError("This server is not linked to an AutoGPT account.")
return owner
owner = await db.find_user_link_owner(platform, request.platform_user_id)
if owner is None:
raise NotFoundError("Your DMs are not linked to an AutoGPT account.")
return owner
async def start_chat_turn(request: BotChatRequest) -> ChatTurnHandle:
"""Prepare a copilot turn; caller subscribes via the returned handle.
``subscribe_from="0-0"`` on the handle means a late subscriber replays
the full stream (Redis Streams, not pub/sub).
"""
owner_user_id = await resolve_chat_owner(request)
session_id = request.session_id
if session_id:
session = await get_chat_session(session_id, owner_user_id)
if not session:
raise NotFoundError("Session not found.")
else:
session = await create_chat_session(owner_user_id, dry_run=False)
session_id = session.session_id
# Persist the user message before enqueueing, mirroring the REST chat
# endpoint — otherwise the executor runs against empty history.
is_duplicate = (
await append_and_save_message(
session_id, ChatMessage(role="user", content=request.message)
)
) is None
if is_duplicate:
# Matches REST chat behaviour: skip create_session + enqueue so we
# don't create an orphan stream with no producer. Caller subscribes
# to the in-flight turn via its own retry logic, or drops.
logger.info(
"Duplicate bot message for session %s (platform %s, user ...%s)",
session_id,
request.platform.value,
owner_user_id[-8:],
)
raise DuplicateChatMessageError("Message already in flight.")
turn_id = str(uuid4())
await stream_registry.create_session(
session_id=session_id,
user_id=owner_user_id,
tool_call_id=CHAT_TOOL_CALL_ID,
tool_name=CHAT_TOOL_NAME,
turn_id=turn_id,
)
await enqueue_copilot_turn(
session_id=session_id,
user_id=owner_user_id,
message=request.message,
turn_id=turn_id,
is_user_message=True,
)
logger.info(
"Bot chat turn started: %s (server %s, session %s, turn %s, owner ...%s)",
request.platform.value,
request.platform_server_id or "DM",
session_id,
turn_id,
owner_user_id[-8:],
)
return ChatTurnHandle(
session_id=session_id,
turn_id=turn_id,
user_id=owner_user_id,
)
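A sketch of how a bot bridge might drive this (request fields and handle attributes come from this diff; the platform-side plumbing is out of scope):

from backend.platform_linking.chat import start_chat_turn
from backend.platform_linking.models import BotChatRequest, Platform

async def on_discord_dm(platform_user_id: str, text: str) -> None:
    handle = await start_chat_turn(
        BotChatRequest(
            platform=Platform.DISCORD,
            platform_user_id=platform_user_id,
            message=text,
        )
    )
    # handle.subscribe_from == "0-0": replay the Redis Stream from the start,
    # so even a late subscriber sees the whole turn.
    print(handle.session_id, handle.turn_id, handle.subscribe_from)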

View File

@@ -0,0 +1,125 @@
"""Tests for chat-turn orchestration — esp. the duplicate-message guard."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.util.exceptions import DuplicateChatMessageError, NotFoundError
from .chat import start_chat_turn
from .models import BotChatRequest, Platform
def _request(**overrides) -> BotChatRequest:
defaults = dict(
platform=Platform.DISCORD,
platform_user_id="pu1",
message="hello",
)
defaults.update(overrides)
return BotChatRequest(**defaults)
class TestStartChatTurn:
@pytest.mark.asyncio
async def test_no_user_link_raises_not_found(self):
db_mock = MagicMock()
db_mock.find_user_link_owner = AsyncMock(return_value=None)
with patch(
"backend.platform_linking.chat.platform_linking_db",
return_value=db_mock,
):
with pytest.raises(NotFoundError):
await start_chat_turn(_request())
@pytest.mark.asyncio
async def test_duplicate_message_raises_and_skips_stream_create(self):
# append_and_save_message returns None → duplicate.
# Verify we raise and do NOT create a stream session.
db_mock = MagicMock()
db_mock.find_user_link_owner = AsyncMock(return_value="owner-1")
session = MagicMock(session_id="sess-existing")
with (
patch(
"backend.platform_linking.chat.platform_linking_db",
return_value=db_mock,
),
patch(
"backend.platform_linking.chat.create_chat_session",
new=AsyncMock(return_value=session),
),
patch(
"backend.platform_linking.chat.append_and_save_message",
new=AsyncMock(return_value=None),
),
patch(
"backend.platform_linking.chat.stream_registry"
) as mock_stream_registry,
patch(
"backend.platform_linking.chat.enqueue_copilot_turn",
new=AsyncMock(),
) as mock_enqueue,
):
mock_stream_registry.create_session = AsyncMock()
with pytest.raises(DuplicateChatMessageError):
await start_chat_turn(_request())
mock_stream_registry.create_session.assert_not_awaited()
mock_enqueue.assert_not_awaited()
@pytest.mark.asyncio
async def test_happy_path_creates_stream_and_enqueues(self):
db_mock = MagicMock()
db_mock.find_user_link_owner = AsyncMock(return_value="owner-1")
session = MagicMock(session_id="sess-new")
with (
patch(
"backend.platform_linking.chat.platform_linking_db",
return_value=db_mock,
),
patch(
"backend.platform_linking.chat.create_chat_session",
new=AsyncMock(return_value=session),
),
patch(
"backend.platform_linking.chat.append_and_save_message",
new=AsyncMock(return_value=MagicMock()),
),
patch(
"backend.platform_linking.chat.stream_registry"
) as mock_stream_registry,
patch(
"backend.platform_linking.chat.enqueue_copilot_turn",
new=AsyncMock(),
) as mock_enqueue,
):
mock_stream_registry.create_session = AsyncMock()
handle = await start_chat_turn(_request())
assert handle.session_id == "sess-new"
assert handle.user_id == "owner-1"
assert handle.turn_id
assert handle.subscribe_from == "0-0"
mock_stream_registry.create_session.assert_awaited_once()
mock_enqueue.assert_awaited_once()
@pytest.mark.asyncio
async def test_existing_session_id_wrong_user_raises_not_found(self):
db_mock = MagicMock()
db_mock.find_user_link_owner = AsyncMock(return_value="owner-1")
with (
patch(
"backend.platform_linking.chat.platform_linking_db",
return_value=db_mock,
),
patch(
"backend.platform_linking.chat.get_chat_session",
new=AsyncMock(return_value=None),
),
):
with pytest.raises(NotFoundError):
await start_chat_turn(_request(session_id="someone-elses"))

View File

@@ -0,0 +1,428 @@
"""Platform link DB operations.
Directly accessed by the ``AgentServer`` / ``DatabaseManager`` pods (which
hold the Prisma connection). Other services go through
``backend.data.db_accessors.platform_linking_db`` so calls are transparently
routed via ``DatabaseManagerAsyncClient`` when no local Prisma is available.
"""
import logging
import secrets
from datetime import datetime, timedelta, timezone
from prisma.errors import UniqueViolationError
from prisma.models import PlatformLink, PlatformLinkToken, PlatformUserLink
from backend.data.db import transaction
from backend.util.exceptions import (
LinkAlreadyExistsError,
LinkFlowMismatchError,
LinkTokenExpiredError,
NotAuthorizedError,
NotFoundError,
)
from backend.util.settings import Settings
from .models import (
ConfirmLinkResponse,
ConfirmUserLinkResponse,
CreateLinkTokenRequest,
CreateUserLinkTokenRequest,
DeleteLinkResponse,
LinkTokenInfoResponse,
LinkTokenResponse,
LinkTokenStatusResponse,
LinkType,
PlatformLinkInfo,
PlatformUserLinkInfo,
ResolveResponse,
)
logger = logging.getLogger(__name__)
LINK_TOKEN_EXPIRY_MINUTES = 30
def _link_base_url() -> str:
return Settings().config.platform_link_base_url
# ── Owner lookups ─────────────────────────────────────────────────────
# These return the owning AutoGPT user_id (or None). Using scalars instead
# of Prisma models keeps everything RPC-safe — Prisma objects are rejected
# by AppService's result validator.
async def find_server_link_owner(platform: str, platform_server_id: str) -> str | None:
link = await PlatformLink.prisma().find_first(
where={"platform": platform, "platformServerId": platform_server_id}
)
return link.userId if link else None
async def find_user_link_owner(platform: str, platform_user_id: str) -> str | None:
link = await PlatformUserLink.prisma().find_unique(
where={
"platform_platformUserId": {
"platform": platform,
"platformUserId": platform_user_id,
}
}
)
return link.userId if link else None
async def resolve_server_link(
platform: str, platform_server_id: str
) -> ResolveResponse:
owner = await find_server_link_owner(platform, platform_server_id)
return ResolveResponse(linked=owner is not None)
async def resolve_user_link(platform: str, platform_user_id: str) -> ResolveResponse:
owner = await find_user_link_owner(platform, platform_user_id)
return ResolveResponse(linked=owner is not None)
# ── Token creation ────────────────────────────────────────────────────
async def create_server_link_token(
request: CreateLinkTokenRequest,
) -> LinkTokenResponse:
platform = request.platform.value
if await find_server_link_owner(platform, request.platform_server_id):
raise LinkAlreadyExistsError(
"This server is already linked to an AutoGPT account."
)
token = secrets.token_urlsafe(32)
expires_at = datetime.now(timezone.utc) + timedelta(
minutes=LINK_TOKEN_EXPIRY_MINUTES
)
# Atomic: invalidate pending tokens + create the new one, so two racing
# create calls can't leave two valid tokens for the same target.
async with transaction() as tx:
await PlatformLinkToken.prisma(tx).update_many(
where={
"platform": platform,
"linkType": LinkType.SERVER.value,
"platformServerId": request.platform_server_id,
"usedAt": None,
},
data={"usedAt": datetime.now(timezone.utc)},
)
await PlatformLinkToken.prisma(tx).create(
data={
"token": token,
"platform": platform,
"linkType": LinkType.SERVER.value,
"platformServerId": request.platform_server_id,
"platformUserId": request.platform_user_id,
"platformUsername": request.platform_username,
"serverName": request.server_name,
"channelId": request.channel_id,
"expiresAt": expires_at,
}
)
logger.info(
"Created SERVER link token for %s server %s (expires %s)",
platform,
request.platform_server_id,
expires_at.isoformat(),
)
return LinkTokenResponse(
token=token,
expires_at=expires_at,
link_url=f"{_link_base_url()}/{token}?platform={platform}",
)
async def create_user_link_token(
request: CreateUserLinkTokenRequest,
) -> LinkTokenResponse:
platform = request.platform.value
if await find_user_link_owner(platform, request.platform_user_id):
raise LinkAlreadyExistsError(
"Your DMs with the bot are already linked to an AutoGPT account."
)
token = secrets.token_urlsafe(32)
expires_at = datetime.now(timezone.utc) + timedelta(
minutes=LINK_TOKEN_EXPIRY_MINUTES
)
async with transaction() as tx:
await PlatformLinkToken.prisma(tx).update_many(
where={
"platform": platform,
"linkType": LinkType.USER.value,
"platformUserId": request.platform_user_id,
"usedAt": None,
},
data={"usedAt": datetime.now(timezone.utc)},
)
await PlatformLinkToken.prisma(tx).create(
data={
"token": token,
"platform": platform,
"linkType": LinkType.USER.value,
"platformUserId": request.platform_user_id,
"platformUsername": request.platform_username,
"expiresAt": expires_at,
}
)
logger.info(
"Created USER link token for %s (expires %s)", platform, expires_at.isoformat()
)
return LinkTokenResponse(
token=token,
expires_at=expires_at,
link_url=f"{_link_base_url()}/{token}?platform={platform}",
)
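# With PLATFORM_LINK_BASE_URL=http://localhost:3000/link (see the .env diff),
# link_url looks like http://localhost:3000/link/<token>?platform=discord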
# ── Token status / info ───────────────────────────────────────────────
async def get_link_token_status(token: str) -> LinkTokenStatusResponse:
link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token})
if not link_token:
raise NotFoundError("Token not found.")
if link_token.usedAt is not None:
# A superseded token (invalidated by create_*_token) has usedAt set
# without a backing link row — report expired, not linked.
if link_token.linkType == LinkType.USER.value:
owner = await find_user_link_owner(
link_token.platform, link_token.platformUserId
)
else:
owner = (
await find_server_link_owner(
link_token.platform, link_token.platformServerId
)
if link_token.platformServerId
else None
)
return LinkTokenStatusResponse(status="linked" if owner else "expired")
if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
return LinkTokenStatusResponse(status="expired")
return LinkTokenStatusResponse(status="pending")
async def get_link_token_info(token: str) -> LinkTokenInfoResponse:
link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token})
if not link_token or link_token.usedAt is not None:
raise NotFoundError("Token not found.")
if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
raise LinkTokenExpiredError("Token expired.")
return LinkTokenInfoResponse(
platform=link_token.platform,
link_type=LinkType(link_token.linkType),
server_name=link_token.serverName,
)
# ── Confirmation (user-facing, JWT-authed) ────────────────────────────
async def confirm_server_link(token: str, user_id: str) -> ConfirmLinkResponse:
link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token})
if not link_token:
raise NotFoundError("Token not found.")
if link_token.linkType != LinkType.SERVER.value:
raise LinkFlowMismatchError("This link is for a different linking flow.")
if link_token.usedAt is not None:
raise LinkTokenExpiredError("This link has already been used.")
if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
raise LinkTokenExpiredError("This link has expired.")
if not link_token.platformServerId:
raise LinkFlowMismatchError("Server token missing server ID.")
owner = await find_server_link_owner(
link_token.platform, link_token.platformServerId
)
if owner:
detail = (
"This server is already linked to your account."
if owner == user_id
else "This server is already linked to another AutoGPT account."
)
raise LinkAlreadyExistsError(detail)
# Atomic consume + create so a failed create doesn't burn the token.
now = datetime.now(timezone.utc)
try:
async with transaction() as tx:
updated = await PlatformLinkToken.prisma(tx).update_many(
where={"token": token, "usedAt": None, "expiresAt": {"gt": now}},
data={"usedAt": now},
)
if updated == 0:
raise LinkTokenExpiredError("This link has already been used.")
await PlatformLink.prisma(tx).create(
data={
"userId": user_id,
"platform": link_token.platform,
"platformServerId": link_token.platformServerId,
"ownerPlatformUserId": link_token.platformUserId,
"serverName": link_token.serverName,
}
)
except UniqueViolationError as exc:
raise LinkAlreadyExistsError(
"This server was just linked by another request."
) from exc
logger.info(
"Linked %s server %s to user ...%s",
link_token.platform,
link_token.platformServerId,
user_id[-8:],
)
return ConfirmLinkResponse(
success=True,
platform=link_token.platform,
platform_server_id=link_token.platformServerId,
server_name=link_token.serverName,
)
async def confirm_user_link(token: str, user_id: str) -> ConfirmUserLinkResponse:
link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token})
if not link_token:
raise NotFoundError("Token not found.")
if link_token.linkType != LinkType.USER.value:
raise LinkFlowMismatchError("This link is for a different linking flow.")
if link_token.usedAt is not None:
raise LinkTokenExpiredError("This link has already been used.")
if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
raise LinkTokenExpiredError("This link has expired.")
owner = await find_user_link_owner(link_token.platform, link_token.platformUserId)
if owner:
detail = (
"Your DMs are already linked to your account."
if owner == user_id
else "This platform user is already linked to another AutoGPT account."
)
raise LinkAlreadyExistsError(detail)
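# Same atomic consume + create pattern as confirm_server_link above.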
now = datetime.now(timezone.utc)
try:
async with transaction() as tx:
updated = await PlatformLinkToken.prisma(tx).update_many(
where={"token": token, "usedAt": None, "expiresAt": {"gt": now}},
data={"usedAt": now},
)
if updated == 0:
raise LinkTokenExpiredError("This link has already been used.")
await PlatformUserLink.prisma(tx).create(
data={
"userId": user_id,
"platform": link_token.platform,
"platformUserId": link_token.platformUserId,
"platformUsername": link_token.platformUsername,
}
)
except UniqueViolationError as exc:
raise LinkAlreadyExistsError(
"Your DMs were just linked by another request."
) from exc
logger.info(
"Linked %s DMs to AutoGPT user ...%s", link_token.platform, user_id[-8:]
)
return ConfirmUserLinkResponse(
success=True,
platform=link_token.platform,
platform_user_id=link_token.platformUserId,
)
# ── Listing ───────────────────────────────────────────────────────────
async def list_server_links(user_id: str) -> list[PlatformLinkInfo]:
links = await PlatformLink.prisma().find_many(
where={"userId": user_id},
order={"linkedAt": "desc"},
)
return [
PlatformLinkInfo(
id=link.id,
platform=link.platform,
platform_server_id=link.platformServerId,
owner_platform_user_id=link.ownerPlatformUserId,
server_name=link.serverName,
linked_at=link.linkedAt,
)
for link in links
]
async def list_user_links(user_id: str) -> list[PlatformUserLinkInfo]:
links = await PlatformUserLink.prisma().find_many(
where={"userId": user_id},
order={"linkedAt": "desc"},
)
return [
PlatformUserLinkInfo(
id=link.id,
platform=link.platform,
platform_user_id=link.platformUserId,
platform_username=link.platformUsername,
linked_at=link.linkedAt,
)
for link in links
]
# ── Deletion ──────────────────────────────────────────────────────────
async def delete_server_link(link_id: str, user_id: str) -> DeleteLinkResponse:
link = await PlatformLink.prisma().find_unique(where={"id": link_id})
if not link:
raise NotFoundError("Link not found.")
if link.userId != user_id:
raise NotAuthorizedError("Not your link.")
await PlatformLink.prisma().delete(where={"id": link_id})
logger.info(
"Unlinked %s server %s from user ...%s",
link.platform,
link.platformServerId,
user_id[-8:],
)
return DeleteLinkResponse(success=True)
async def delete_user_link(link_id: str, user_id: str) -> DeleteLinkResponse:
link = await PlatformUserLink.prisma().find_unique(where={"id": link_id})
if not link:
raise NotFoundError("Link not found.")
if link.userId != user_id:
raise NotAuthorizedError("Not your link.")
await PlatformUserLink.prisma().delete(where={"id": link_id})
logger.info("Unlinked %s DMs from AutoGPT user ...%s", link.platform, user_id[-8:])
return DeleteLinkResponse(success=True)

View File

@@ -0,0 +1,481 @@
"""Unit tests for platform_linking DB operations."""
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.util.exceptions import (
LinkAlreadyExistsError,
LinkFlowMismatchError,
LinkTokenExpiredError,
NotAuthorizedError,
NotFoundError,
)
from .db import (
confirm_server_link,
confirm_user_link,
create_server_link_token,
create_user_link_token,
delete_server_link,
delete_user_link,
get_link_token_info,
get_link_token_status,
resolve_server_link,
resolve_user_link,
)
from .models import (
CreateLinkTokenRequest,
CreateUserLinkTokenRequest,
LinkType,
Platform,
)
@asynccontextmanager
async def _fake_transaction():
# Avoids Prisma's tx binding asyncio primitives to the wrong loop in tests.
yield MagicMock()
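# (The yielded MagicMock stands in for the Prisma tx handle passed to
# Model.prisma(tx); these tests never touch a real database, so no
# transactional behavior is needed.)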
# ── Resolve ──────────────────────────────────────────────────────────
class TestResolve:
@pytest.mark.asyncio
async def test_server_linked(self):
with patch("backend.platform_linking.db.PlatformLink") as mock_link:
mock_link.prisma.return_value.find_first = AsyncMock(
return_value=MagicMock(userId="u-123")
)
result = await resolve_server_link("DISCORD", "g1")
assert result.linked is True
@pytest.mark.asyncio
async def test_server_unlinked(self):
with patch("backend.platform_linking.db.PlatformLink") as mock_link:
mock_link.prisma.return_value.find_first = AsyncMock(return_value=None)
result = await resolve_server_link("DISCORD", "g1")
assert result.linked is False
@pytest.mark.asyncio
async def test_user_linked(self):
with patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link:
mock_user_link.prisma.return_value.find_unique = AsyncMock(
return_value=MagicMock(userId="u-xyz")
)
result = await resolve_user_link("DISCORD", "pu1")
assert result.linked is True
@pytest.mark.asyncio
async def test_user_unlinked(self):
with patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link:
mock_user_link.prisma.return_value.find_unique = AsyncMock(
return_value=None
)
result = await resolve_user_link("DISCORD", "pu1")
assert result.linked is False
# ── Token creation ───────────────────────────────────────────────────
class TestCreateServerLinkToken:
@pytest.mark.asyncio
async def test_creates_token_for_unlinked_server(self):
with (
patch("backend.platform_linking.db.PlatformLink") as mock_link,
patch(
"backend.platform_linking.db.transaction",
new=_fake_transaction,
),
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token_model,
):
mock_link.prisma.return_value.find_first = AsyncMock(return_value=None)
mock_token_model.prisma.return_value.update_many = AsyncMock(return_value=0)
mock_token_model.prisma.return_value.create = AsyncMock(
return_value=MagicMock()
)
result = await create_server_link_token(
CreateLinkTokenRequest(
platform=Platform.DISCORD,
platform_server_id="g1",
platform_user_id="u1",
server_name="Test",
),
)
assert result.token
assert result.token in result.link_url
assert "?platform=DISCORD" in result.link_url
@pytest.mark.asyncio
async def test_rejects_when_already_linked(self):
with patch("backend.platform_linking.db.PlatformLink") as mock_link:
mock_link.prisma.return_value.find_first = AsyncMock(
return_value=MagicMock(userId="u-owner")
)
with pytest.raises(LinkAlreadyExistsError):
await create_server_link_token(
CreateLinkTokenRequest(
platform=Platform.DISCORD,
platform_server_id="g1",
platform_user_id="u1",
),
)
class TestCreateUserLinkToken:
@pytest.mark.asyncio
async def test_creates_token_for_unlinked_user(self):
with (
patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link,
patch(
"backend.platform_linking.db.transaction",
new=_fake_transaction,
),
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token_model,
):
mock_user_link.prisma.return_value.find_unique = AsyncMock(
return_value=None
)
mock_token_model.prisma.return_value.update_many = AsyncMock(return_value=0)
mock_token_model.prisma.return_value.create = AsyncMock(
return_value=MagicMock()
)
result = await create_user_link_token(
CreateUserLinkTokenRequest(
platform=Platform.DISCORD,
platform_user_id="pu1",
platform_username="Bently",
),
)
assert result.token
assert result.token in result.link_url
@pytest.mark.asyncio
async def test_rejects_when_already_linked(self):
with patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link:
mock_user_link.prisma.return_value.find_unique = AsyncMock(
return_value=MagicMock(userId="u-owner")
)
with pytest.raises(LinkAlreadyExistsError):
await create_user_link_token(
CreateUserLinkTokenRequest(
platform=Platform.DISCORD,
platform_user_id="pu1",
),
)
# ── Token status / info ───────────────────────────────────────────────
class TestGetLinkTokenStatus:
@pytest.mark.asyncio
async def test_not_found(self):
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
mock_model.prisma.return_value.find_unique = AsyncMock(return_value=None)
with pytest.raises(NotFoundError):
await get_link_token_status("abc")
@pytest.mark.asyncio
async def test_pending(self):
future = datetime.now(timezone.utc) + timedelta(minutes=10)
fake_token = MagicMock(usedAt=None, expiresAt=future)
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
mock_model.prisma.return_value.find_unique = AsyncMock(
return_value=fake_token
)
result = await get_link_token_status("abc")
assert result.status == "pending"
@pytest.mark.asyncio
async def test_expired_by_time(self):
past = datetime.now(timezone.utc) - timedelta(minutes=10)
fake_token = MagicMock(usedAt=None, expiresAt=past)
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
mock_model.prisma.return_value.find_unique = AsyncMock(
return_value=fake_token
)
result = await get_link_token_status("abc")
assert result.status == "expired"
@pytest.mark.asyncio
async def test_used_with_user_link_reports_linked(self):
fake_token = MagicMock(
usedAt=datetime.now(timezone.utc),
linkType=LinkType.USER.value,
platform="DISCORD",
platformUserId="pu1",
)
with (
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token,
patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link,
):
mock_token.prisma.return_value.find_unique = AsyncMock(
return_value=fake_token
)
mock_user_link.prisma.return_value.find_unique = AsyncMock(
return_value=MagicMock(userId="u-owner")
)
result = await get_link_token_status("abc")
assert result.status == "linked"
@pytest.mark.asyncio
async def test_used_without_link_reports_expired(self):
# Superseded token: usedAt set, but no backing link row.
fake_token = MagicMock(
usedAt=datetime.now(timezone.utc),
linkType=LinkType.SERVER.value,
platform="DISCORD",
platformServerId="g1",
)
with (
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token,
patch("backend.platform_linking.db.PlatformLink") as mock_link,
):
mock_token.prisma.return_value.find_unique = AsyncMock(
return_value=fake_token
)
mock_link.prisma.return_value.find_first = AsyncMock(return_value=None)
result = await get_link_token_status("abc")
assert result.status == "expired"
class TestGetLinkTokenInfo:
@pytest.mark.asyncio
async def test_not_found(self):
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
mock_model.prisma.return_value.find_unique = AsyncMock(return_value=None)
with pytest.raises(NotFoundError):
await get_link_token_info("abc")
@pytest.mark.asyncio
async def test_used_returns_not_found(self):
fake_token = MagicMock(usedAt=datetime.now(timezone.utc))
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
mock_model.prisma.return_value.find_unique = AsyncMock(
return_value=fake_token
)
with pytest.raises(NotFoundError):
await get_link_token_info("abc")
@pytest.mark.asyncio
async def test_expired_raises_expired(self):
past = datetime.now(timezone.utc) - timedelta(minutes=5)
fake_token = MagicMock(usedAt=None, expiresAt=past)
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
mock_model.prisma.return_value.find_unique = AsyncMock(
return_value=fake_token
)
with pytest.raises(LinkTokenExpiredError):
await get_link_token_info("abc")
@pytest.mark.asyncio
async def test_success_returns_display_info(self):
future = datetime.now(timezone.utc) + timedelta(minutes=10)
fake_token = MagicMock(
usedAt=None,
expiresAt=future,
platform="DISCORD",
linkType=LinkType.SERVER.value,
serverName="My Server",
)
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
mock_model.prisma.return_value.find_unique = AsyncMock(
return_value=fake_token
)
result = await get_link_token_info("abc")
assert result.platform == "DISCORD"
assert result.link_type == LinkType.SERVER
assert result.server_name == "My Server"
# ── Confirmation ─────────────────────────────────────────────────────
class TestConfirmServerLink:
@pytest.mark.asyncio
async def test_not_found(self):
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
mock_model.prisma.return_value.find_unique = AsyncMock(return_value=None)
with pytest.raises(NotFoundError):
await confirm_server_link("abc", "u1")
@pytest.mark.asyncio
async def test_wrong_link_type_rejected(self):
fake_token = MagicMock(linkType=LinkType.USER.value)
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
mock_model.prisma.return_value.find_unique = AsyncMock(
return_value=fake_token
)
with pytest.raises(LinkFlowMismatchError):
await confirm_server_link("abc", "u1")
@pytest.mark.asyncio
async def test_already_used(self):
fake_token = MagicMock(
linkType=LinkType.SERVER.value, usedAt=datetime.now(timezone.utc)
)
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
mock_model.prisma.return_value.find_unique = AsyncMock(
return_value=fake_token
)
with pytest.raises(LinkTokenExpiredError):
await confirm_server_link("abc", "u1")
@pytest.mark.asyncio
async def test_expired_by_time(self):
fake_token = MagicMock(
linkType=LinkType.SERVER.value,
usedAt=None,
expiresAt=datetime.now(timezone.utc) - timedelta(minutes=5),
)
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
mock_model.prisma.return_value.find_unique = AsyncMock(
return_value=fake_token
)
with pytest.raises(LinkTokenExpiredError):
await confirm_server_link("abc", "u1")
@pytest.mark.asyncio
async def test_already_linked_to_same_user(self):
fake_token = MagicMock(
linkType=LinkType.SERVER.value,
usedAt=None,
expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10),
platform="DISCORD",
platformServerId="g1",
)
with (
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token,
patch("backend.platform_linking.db.PlatformLink") as mock_link,
):
mock_token.prisma.return_value.find_unique = AsyncMock(
return_value=fake_token
)
mock_link.prisma.return_value.find_first = AsyncMock(
return_value=MagicMock(userId="u1")
)
with pytest.raises(LinkAlreadyExistsError) as exc_info:
await confirm_server_link("abc", "u1")
assert "your account" in str(exc_info.value)
@pytest.mark.asyncio
async def test_already_linked_to_other_user(self):
fake_token = MagicMock(
linkType=LinkType.SERVER.value,
usedAt=None,
expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10),
platform="DISCORD",
platformServerId="g1",
)
with (
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token,
patch("backend.platform_linking.db.PlatformLink") as mock_link,
):
mock_token.prisma.return_value.find_unique = AsyncMock(
return_value=fake_token
)
mock_link.prisma.return_value.find_first = AsyncMock(
return_value=MagicMock(userId="other-user")
)
with pytest.raises(LinkAlreadyExistsError) as exc_info:
await confirm_server_link("abc", "u1")
assert "another" in str(exc_info.value)
class TestConfirmUserLink:
@pytest.mark.asyncio
async def test_not_found(self):
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
mock_model.prisma.return_value.find_unique = AsyncMock(return_value=None)
with pytest.raises(NotFoundError):
await confirm_user_link("abc", "u1")
@pytest.mark.asyncio
async def test_wrong_link_type_rejected(self):
fake_token = MagicMock(linkType=LinkType.SERVER.value)
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
mock_model.prisma.return_value.find_unique = AsyncMock(
return_value=fake_token
)
with pytest.raises(LinkFlowMismatchError):
await confirm_user_link("abc", "u1")
@pytest.mark.asyncio
async def test_expired_by_time(self):
fake_token = MagicMock(
linkType=LinkType.USER.value,
usedAt=None,
expiresAt=datetime.now(timezone.utc) - timedelta(minutes=5),
)
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
mock_model.prisma.return_value.find_unique = AsyncMock(
return_value=fake_token
)
with pytest.raises(LinkTokenExpiredError):
await confirm_user_link("abc", "u1")
@pytest.mark.asyncio
async def test_already_linked_to_other_user(self):
fake_token = MagicMock(
linkType=LinkType.USER.value,
usedAt=None,
expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10),
platform="DISCORD",
platformUserId="pu1",
)
with (
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token,
patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link,
):
mock_token.prisma.return_value.find_unique = AsyncMock(
return_value=fake_token
)
mock_user_link.prisma.return_value.find_unique = AsyncMock(
return_value=MagicMock(userId="other-user")
)
with pytest.raises(LinkAlreadyExistsError):
await confirm_user_link("abc", "u1")
# ── Delete (authz checks) ────────────────────────────────────────────
class TestDeleteLinks:
@pytest.mark.asyncio
async def test_delete_server_link_not_found(self):
with patch("backend.platform_linking.db.PlatformLink") as mock_model:
mock_model.prisma.return_value.find_unique = AsyncMock(return_value=None)
with pytest.raises(NotFoundError):
await delete_server_link("x", "u1")
@pytest.mark.asyncio
async def test_delete_server_link_not_owned(self):
link = MagicMock(userId="owner-A", platform="DISCORD", platformServerId="g1")
with patch("backend.platform_linking.db.PlatformLink") as mock_model:
mock_model.prisma.return_value.find_unique = AsyncMock(return_value=link)
with pytest.raises(NotAuthorizedError):
await delete_server_link("x", "u-other")
@pytest.mark.asyncio
async def test_delete_user_link_not_found(self):
with patch("backend.platform_linking.db.PlatformUserLink") as mock_model:
mock_model.prisma.return_value.find_unique = AsyncMock(return_value=None)
with pytest.raises(NotFoundError):
await delete_user_link("x", "u1")
@pytest.mark.asyncio
async def test_delete_user_link_not_owned(self):
link = MagicMock(userId="owner-A", platform="DISCORD")
with patch("backend.platform_linking.db.PlatformUserLink") as mock_model:
mock_model.prisma.return_value.find_unique = AsyncMock(return_value=link)
with pytest.raises(NotAuthorizedError):
await delete_user_link("x", "u-other")

View File

@@ -0,0 +1,82 @@
"""AppService exposing bot-facing platform_linking ops over internal RPC."""
import logging
from backend.data.db_accessors import platform_linking_db
from backend.util.service import AppService, AppServiceClient, endpoint_to_async, expose
from backend.util.settings import Settings
from .chat import start_chat_turn
from .models import (
BotChatRequest,
ChatTurnHandle,
CreateLinkTokenRequest,
CreateUserLinkTokenRequest,
LinkTokenResponse,
LinkTokenStatusResponse,
Platform,
ResolveResponse,
)
logger = logging.getLogger(__name__)
class PlatformLinkingManager(AppService):
@classmethod
def get_port(cls) -> int:
return Settings().config.platform_linking_service_port
@expose
async def resolve_server_link(
self, platform: Platform, platform_server_id: str
) -> ResolveResponse:
return await platform_linking_db().resolve_server_link(
platform.value, platform_server_id
)
@expose
async def resolve_user_link(
self, platform: Platform, platform_user_id: str
) -> ResolveResponse:
return await platform_linking_db().resolve_user_link(
platform.value, platform_user_id
)
@expose
async def create_server_link_token(
self, request: CreateLinkTokenRequest
) -> LinkTokenResponse:
return await platform_linking_db().create_server_link_token(request)
@expose
async def create_user_link_token(
self, request: CreateUserLinkTokenRequest
) -> LinkTokenResponse:
return await platform_linking_db().create_user_link_token(request)
@expose
async def get_link_token_status(self, token: str) -> LinkTokenStatusResponse:
return await platform_linking_db().get_link_token_status(token)
@expose
async def start_chat_turn(self, request: BotChatRequest) -> ChatTurnHandle:
return await start_chat_turn(request)
class PlatformLinkingManagerClient(AppServiceClient):
@classmethod
def get_service_type(cls):
return PlatformLinkingManager
resolve_server_link = endpoint_to_async(PlatformLinkingManager.resolve_server_link)
resolve_user_link = endpoint_to_async(PlatformLinkingManager.resolve_user_link)
create_server_link_token = endpoint_to_async(
PlatformLinkingManager.create_server_link_token
)
create_user_link_token = endpoint_to_async(
PlatformLinkingManager.create_user_link_token
)
get_link_token_status = endpoint_to_async(
PlatformLinkingManager.get_link_token_status
)
start_chat_turn = endpoint_to_async(PlatformLinkingManager.start_chat_turn)
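# Illustrative usage from a bot process (a sketch — assumes the thread-cached
# accessor defined in backend.data.db_accessors; guild_id and request are
# hypothetical locals):
#
#   client = get_platform_linking_manager_client()
#   resolved = await client.resolve_server_link(Platform.DISCORD, guild_id)
#   if not resolved.linked:
#       token = await client.create_server_link_token(request)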

View File

@@ -0,0 +1,346 @@
"""Tests for PlatformLinkingManager RPC wiring and confirm-token races."""
import asyncio
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.util.exceptions import LinkTokenExpiredError
from .db import confirm_server_link, confirm_user_link
from .manager import PlatformLinkingManager, PlatformLinkingManagerClient
from .models import (
BotChatRequest,
CreateLinkTokenRequest,
CreateUserLinkTokenRequest,
LinkType,
Platform,
ResolveResponse,
)
@asynccontextmanager
async def _fake_transaction():
yield MagicMock()
class TestManagerWiring:
def test_get_port(self):
assert PlatformLinkingManager.get_port() == 8009
def test_client_exposes_expected_rpc_surface(self):
service_type = PlatformLinkingManagerClient.get_service_type()
assert service_type is PlatformLinkingManager
expected = {
"resolve_server_link",
"resolve_user_link",
"create_server_link_token",
"create_user_link_token",
"get_link_token_status",
"start_chat_turn",
}
for name in expected:
assert hasattr(
PlatformLinkingManagerClient, name
), f"Client missing RPC stub: {name}"
for name in (
"confirm_server_link",
"confirm_user_link",
"list_server_links",
"list_user_links",
"delete_server_link",
"delete_user_link",
):
assert not hasattr(
PlatformLinkingManagerClient, name
), f"User-facing method leaked to bot client: {name}"
@pytest.mark.asyncio
async def test_resolve_server_link_delegates_to_accessor(self):
manager = PlatformLinkingManager()
db_mock = MagicMock()
db_mock.resolve_server_link = AsyncMock(
return_value=ResolveResponse(linked=True)
)
with patch(
"backend.platform_linking.manager.platform_linking_db",
return_value=db_mock,
):
result = await manager.resolve_server_link(Platform.DISCORD, "g1")
db_mock.resolve_server_link.assert_awaited_once_with("DISCORD", "g1")
assert result.linked is True
@pytest.mark.asyncio
async def test_resolve_user_link_delegates_to_accessor(self):
manager = PlatformLinkingManager()
db_mock = MagicMock()
db_mock.resolve_user_link = AsyncMock(
return_value=ResolveResponse(linked=False)
)
with patch(
"backend.platform_linking.manager.platform_linking_db",
return_value=db_mock,
):
result = await manager.resolve_user_link(Platform.DISCORD, "pu1")
db_mock.resolve_user_link.assert_awaited_once_with("DISCORD", "pu1")
assert result.linked is False
@pytest.mark.asyncio
async def test_create_server_link_token_delegates(self):
manager = PlatformLinkingManager()
req = CreateLinkTokenRequest(
platform=Platform.DISCORD,
platform_server_id="g1",
platform_user_id="u1",
)
fake_response = MagicMock()
db_mock = MagicMock()
db_mock.create_server_link_token = AsyncMock(return_value=fake_response)
with patch(
"backend.platform_linking.manager.platform_linking_db",
return_value=db_mock,
):
result = await manager.create_server_link_token(req)
db_mock.create_server_link_token.assert_awaited_once_with(req)
assert result is fake_response
@pytest.mark.asyncio
async def test_create_user_link_token_delegates(self):
manager = PlatformLinkingManager()
req = CreateUserLinkTokenRequest(
platform=Platform.DISCORD, platform_user_id="pu1"
)
fake_response = MagicMock()
db_mock = MagicMock()
db_mock.create_user_link_token = AsyncMock(return_value=fake_response)
with patch(
"backend.platform_linking.manager.platform_linking_db",
return_value=db_mock,
):
result = await manager.create_user_link_token(req)
db_mock.create_user_link_token.assert_awaited_once_with(req)
assert result is fake_response
@pytest.mark.asyncio
async def test_get_link_token_status_delegates(self):
manager = PlatformLinkingManager()
fake_response = MagicMock()
db_mock = MagicMock()
db_mock.get_link_token_status = AsyncMock(return_value=fake_response)
with patch(
"backend.platform_linking.manager.platform_linking_db",
return_value=db_mock,
):
result = await manager.get_link_token_status("tok")
db_mock.get_link_token_status.assert_awaited_once_with("tok")
assert result is fake_response
@pytest.mark.asyncio
async def test_start_chat_turn_delegates(self):
manager = PlatformLinkingManager()
req = BotChatRequest(
platform=Platform.DISCORD,
platform_user_id="pu1",
message="hi",
)
fake_response = MagicMock()
with patch(
"backend.platform_linking.manager.start_chat_turn",
new=AsyncMock(return_value=fake_response),
) as stub:
result = await manager.start_chat_turn(req)
stub.assert_awaited_once_with(req)
assert result is fake_response
class TestAdversarialConfirmRace:
"""Concurrent confirm of one token: exactly one winner via ``update_many``
guarded on ``usedAt = None``."""
@pytest.mark.asyncio
async def test_second_confirm_loses(self):
# update_many returns 0 → caller lost the race
fake_token = MagicMock(
linkType=LinkType.SERVER.value,
usedAt=None,
expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10),
platform="DISCORD",
platformServerId="g1",
)
with (
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token,
patch("backend.platform_linking.db.PlatformLink") as mock_link,
patch("backend.platform_linking.db.transaction", new=_fake_transaction),
):
mock_token.prisma.return_value.find_unique = AsyncMock(
return_value=fake_token
)
mock_link.prisma.return_value.find_first = AsyncMock(return_value=None)
mock_token.prisma.return_value.update_many = AsyncMock(return_value=0)
with pytest.raises(LinkTokenExpiredError):
await confirm_server_link("abc", "user-late")
@pytest.mark.asyncio
async def test_second_confirm_wins_when_update_many_returns_one(self):
fake_token = MagicMock(
linkType=LinkType.SERVER.value,
usedAt=None,
expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10),
platform="DISCORD",
platformServerId="g1",
platformUserId="pu1",
serverName="S1",
)
with (
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token,
patch("backend.platform_linking.db.PlatformLink") as mock_link,
patch("backend.platform_linking.db.transaction", new=_fake_transaction),
):
mock_token.prisma.return_value.find_unique = AsyncMock(
return_value=fake_token
)
mock_link.prisma.return_value.find_first = AsyncMock(return_value=None)
mock_token.prisma.return_value.update_many = AsyncMock(return_value=1)
mock_link.prisma.return_value.create = AsyncMock(return_value=MagicMock())
result = await confirm_server_link("abc", "user-winner")
assert result.success is True
assert result.platform_server_id == "g1"
@pytest.mark.asyncio
async def test_gather_confirm_same_user_one_winner(self):
fake_token = MagicMock(
linkType=LinkType.SERVER.value,
usedAt=None,
expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10),
platform="DISCORD",
platformServerId="g1",
platformUserId="pu1",
serverName="S1",
)
update_results = [1, 0]
async def flaky_update_many(*args, **kwargs):
return update_results.pop(0)
with (
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token,
patch("backend.platform_linking.db.PlatformLink") as mock_link,
patch("backend.platform_linking.db.transaction", new=_fake_transaction),
):
mock_token.prisma.return_value.find_unique = AsyncMock(
return_value=fake_token
)
mock_link.prisma.return_value.find_first = AsyncMock(return_value=None)
mock_token.prisma.return_value.update_many = flaky_update_many
mock_link.prisma.return_value.create = AsyncMock(return_value=MagicMock())
results = await asyncio.gather(
confirm_server_link("abc", "u1"),
confirm_server_link("abc", "u1"),
return_exceptions=True,
)
successes = [r for r in results if not isinstance(r, Exception)]
losses = [r for r in results if isinstance(r, LinkTokenExpiredError)]
assert len(successes) == 1
assert len(losses) == 1
@pytest.mark.asyncio
async def test_gather_confirm_different_users_one_winner_no_hijack(self):
# Different users racing the same token: still exactly one winner,
# and the other gets a clean LinkTokenExpiredError (no partial state).
fake_token = MagicMock(
linkType=LinkType.SERVER.value,
usedAt=None,
expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10),
platform="DISCORD",
platformServerId="g1",
platformUserId="pu1",
serverName="S1",
)
update_results = [1, 0]
async def flaky_update_many(*args, **kwargs):
return update_results.pop(0)
created_link_user_ids: list[str] = []
async def record_create(*, data):
created_link_user_ids.append(data["userId"])
return MagicMock()
with (
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token,
patch("backend.platform_linking.db.PlatformLink") as mock_link,
patch("backend.platform_linking.db.transaction", new=_fake_transaction),
):
mock_token.prisma.return_value.find_unique = AsyncMock(
return_value=fake_token
)
mock_link.prisma.return_value.find_first = AsyncMock(return_value=None)
mock_token.prisma.return_value.update_many = flaky_update_many
mock_link.prisma.return_value.create = record_create
results = await asyncio.gather(
confirm_server_link("abc", "user-a"),
confirm_server_link("abc", "user-b"),
return_exceptions=True,
)
successes = [r for r in results if not isinstance(r, Exception)]
losses = [r for r in results if isinstance(r, LinkTokenExpiredError)]
assert len(successes) == 1
assert len(losses) == 1
assert len(created_link_user_ids) == 1
assert created_link_user_ids[0] in ("user-a", "user-b")
@pytest.mark.asyncio
async def test_gather_confirm_user_link_one_winner(self):
fake_token = MagicMock(
linkType=LinkType.USER.value,
usedAt=None,
expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10),
platform="DISCORD",
platformUserId="pu1",
platformUsername="pu_name",
)
update_results = [1, 0]
async def flaky_update_many(*args, **kwargs):
return update_results.pop(0)
with (
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token,
patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link,
patch("backend.platform_linking.db.transaction", new=_fake_transaction),
):
mock_token.prisma.return_value.find_unique = AsyncMock(
return_value=fake_token
)
mock_user_link.prisma.return_value.find_unique = AsyncMock(
return_value=None
)
mock_token.prisma.return_value.update_many = flaky_update_many
mock_user_link.prisma.return_value.create = AsyncMock(
return_value=MagicMock()
)
results = await asyncio.gather(
confirm_user_link("abc", "user-a"),
confirm_user_link("abc", "user-b"),
return_exceptions=True,
)
successes = [r for r in results if not isinstance(r, Exception)]
losses = [r for r in results if isinstance(r, LinkTokenExpiredError)]
assert len(successes) == 1
assert len(losses) == 1

View File

@@ -0,0 +1,182 @@
"""Pydantic models for platform_linking requests and responses."""
from datetime import datetime
from enum import Enum
from typing import Literal
from pydantic import BaseModel, Field
class Platform(str, Enum):
"""Mirror of the Prisma PlatformType enum."""
DISCORD = "DISCORD"
TELEGRAM = "TELEGRAM"
SLACK = "SLACK"
TEAMS = "TEAMS"
WHATSAPP = "WHATSAPP"
GITHUB = "GITHUB"
LINEAR = "LINEAR"
class LinkType(str, Enum):
SERVER = "SERVER"
USER = "USER"
# ── Request Models ─────────────────────────────────────────────────────
class CreateLinkTokenRequest(BaseModel):
platform: Platform = Field(description="Platform name")
platform_server_id: str = Field(
description="Server/guild/group ID on the platform",
min_length=1,
max_length=255,
)
platform_user_id: str = Field(
description="Platform user ID of the person claiming ownership",
min_length=1,
max_length=255,
)
platform_username: str | None = Field(
default=None,
description="Display name of the person claiming ownership",
max_length=255,
)
server_name: str | None = Field(
default=None,
description="Display name of the server/group",
max_length=255,
)
channel_id: str | None = Field(
default=None,
description="Channel ID so the bot can send a confirmation message",
max_length=255,
)
class CreateUserLinkTokenRequest(BaseModel):
platform: Platform
platform_user_id: str = Field(
description="Platform user ID of the person linking their DMs",
min_length=1,
max_length=255,
)
platform_username: str | None = Field(
default=None,
description="Their display name (best-effort for audit)",
max_length=255,
)
class ResolveServerRequest(BaseModel):
platform: Platform
platform_server_id: str = Field(
description="Server/guild/group ID to look up",
min_length=1,
max_length=255,
)
class ResolveUserRequest(BaseModel):
platform: Platform
platform_user_id: str = Field(
description="Platform user ID to look up",
min_length=1,
max_length=255,
)
class BotChatRequest(BaseModel):
"""Bot message request. If ``platform_server_id`` is set, the turn is
billed to that server's owner; otherwise billed to ``platform_user_id``
(DM context)."""
platform: Platform
platform_server_id: str | None = Field(
default=None,
description="Server/guild/group ID — null for DM context",
min_length=1,
max_length=255,
)
platform_user_id: str = Field(
description="Platform user ID of the person who sent the message",
min_length=1,
max_length=255,
)
message: str = Field(
description="The user's message", min_length=1, max_length=32000
)
session_id: str | None = Field(
default=None,
description="Existing CoPilot session ID. If omitted, a new session is created.",
)
# ── Response Models ────────────────────────────────────────────────────
class LinkTokenResponse(BaseModel):
token: str
expires_at: datetime
link_url: str
class LinkTokenStatusResponse(BaseModel):
status: Literal["pending", "linked", "expired"]
class LinkTokenInfoResponse(BaseModel):
platform: str
link_type: LinkType
server_name: str | None = None
class ResolveResponse(BaseModel):
linked: bool
class PlatformLinkInfo(BaseModel):
id: str
platform: str
platform_server_id: str
owner_platform_user_id: str
server_name: str | None
linked_at: datetime
class PlatformUserLinkInfo(BaseModel):
id: str
platform: str
platform_user_id: str
platform_username: str | None
linked_at: datetime
class ConfirmLinkResponse(BaseModel):
success: bool
link_type: LinkType = LinkType.SERVER
platform: str
platform_server_id: str
server_name: str | None
class ConfirmUserLinkResponse(BaseModel):
success: bool
link_type: LinkType = LinkType.USER
platform: str
platform_user_id: str
class DeleteLinkResponse(BaseModel):
success: bool
class ChatTurnHandle(BaseModel):
"""Subscribe keys for a pending copilot turn."""
session_id: str
turn_id: str
user_id: str
subscribe_from: str = "0-0"
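# "0-0" is assumed to be the minimal Redis stream ID, i.e. subscribers
# replay the turn's event stream from the very beginning.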

View File

@@ -0,0 +1,178 @@
"""Schema validation tests for platform_linking Pydantic models."""
import pytest
from pydantic import ValidationError
from .models import (
BotChatRequest,
ConfirmLinkResponse,
CreateLinkTokenRequest,
DeleteLinkResponse,
LinkTokenStatusResponse,
Platform,
ResolveResponse,
ResolveServerRequest,
)
class TestPlatformEnum:
def test_all_platforms_exist(self):
assert Platform.DISCORD.value == "DISCORD"
assert Platform.TELEGRAM.value == "TELEGRAM"
assert Platform.SLACK.value == "SLACK"
assert Platform.TEAMS.value == "TEAMS"
assert Platform.WHATSAPP.value == "WHATSAPP"
assert Platform.GITHUB.value == "GITHUB"
assert Platform.LINEAR.value == "LINEAR"
class TestCreateLinkTokenRequest:
def test_valid_request(self):
req = CreateLinkTokenRequest(
platform=Platform.DISCORD,
platform_server_id="1126875755960336515",
platform_user_id="353922987235213313",
platform_username="Bently",
server_name="My Discord Server",
)
assert req.platform == Platform.DISCORD
assert req.platform_server_id == "1126875755960336515"
assert req.platform_user_id == "353922987235213313"
assert req.server_name == "My Discord Server"
def test_minimal_request(self):
req = CreateLinkTokenRequest(
platform=Platform.TELEGRAM,
platform_server_id="-100123456789",
platform_user_id="987654321",
)
assert req.server_name is None
assert req.platform_username is None
def test_empty_server_id_rejected(self):
with pytest.raises(ValidationError):
CreateLinkTokenRequest(
platform=Platform.DISCORD,
platform_server_id="",
platform_user_id="123",
)
def test_too_long_server_id_rejected(self):
with pytest.raises(ValidationError):
CreateLinkTokenRequest(
platform=Platform.DISCORD,
platform_server_id="x" * 256,
platform_user_id="123",
)
def test_invalid_platform_rejected(self):
with pytest.raises(ValidationError):
CreateLinkTokenRequest.model_validate(
{
"platform": "INVALID",
"platform_server_id": "123",
"platform_user_id": "456",
}
)
class TestResolveServerRequest:
def test_valid_request(self):
req = ResolveServerRequest(
platform=Platform.DISCORD,
platform_server_id="1126875755960336515",
)
assert req.platform == Platform.DISCORD
assert req.platform_server_id == "1126875755960336515"
def test_empty_server_id_rejected(self):
with pytest.raises(ValidationError):
ResolveServerRequest(
platform=Platform.SLACK,
platform_server_id="",
)
class TestBotChatRequest:
def test_server_context(self):
req = BotChatRequest(
platform=Platform.DISCORD,
platform_server_id="1126875755960336515",
platform_user_id="353922987235213313",
message="Hello CoPilot!",
)
assert req.platform == Platform.DISCORD
assert req.platform_server_id == "1126875755960336515"
assert req.session_id is None
def test_dm_context_omits_server_id(self):
req = BotChatRequest(
platform=Platform.DISCORD,
platform_user_id="353922987235213313",
message="Hello in DMs!",
)
assert req.platform_server_id is None
def test_with_session_id(self):
req = BotChatRequest(
platform=Platform.DISCORD,
platform_server_id="guild_123",
platform_user_id="user_456",
message="follow up",
session_id="session-uuid-here",
)
assert req.session_id == "session-uuid-here"
def test_empty_message_rejected(self):
with pytest.raises(ValidationError):
BotChatRequest(
platform=Platform.DISCORD,
platform_server_id="guild_123",
platform_user_id="user_456",
message="",
)
def test_empty_string_server_id_rejected(self):
with pytest.raises(ValidationError):
BotChatRequest(
platform=Platform.DISCORD,
platform_server_id="",
platform_user_id="user_456",
message="hi",
)
class TestResponseModels:
def test_link_token_status_pending(self):
resp = LinkTokenStatusResponse(status="pending")
assert resp.status == "pending"
def test_link_token_status_linked(self):
resp = LinkTokenStatusResponse(status="linked")
assert resp.status == "linked"
def test_link_token_status_expired(self):
resp = LinkTokenStatusResponse(status="expired")
assert resp.status == "expired"
def test_resolve_linked(self):
resp = ResolveResponse(linked=True)
assert resp.linked is True
def test_resolve_not_linked(self):
resp = ResolveResponse(linked=False)
assert resp.linked is False
def test_confirm_link_response(self):
resp = ConfirmLinkResponse(
success=True,
platform="DISCORD",
platform_server_id="1126875755960336515",
server_name="My Server",
)
assert resp.success is True
assert resp.server_name == "My Server"
def test_delete_link_response(self):
resp = DeleteLinkResponse(success=True)
assert resp.success is True

View File

@@ -0,0 +1,15 @@
from backend.app import run_processes
from backend.platform_linking.manager import PlatformLinkingManager
def main():
"""
Run the AutoGPT-server Platform Linking Manager service.
"""
run_processes(
PlatformLinkingManager(),
)
if __name__ == "__main__":
main()

View File

@@ -27,6 +27,7 @@ if TYPE_CHECKING:
from backend.executor.scheduler import SchedulerClient
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.notifications.notifications import NotificationManagerClient
from backend.platform_linking.manager import PlatformLinkingManagerClient
@thread_cached
@@ -67,6 +68,15 @@ def get_notification_manager_client() -> "NotificationManagerClient":
return get_service_client(NotificationManagerClient)
@thread_cached
def get_platform_linking_manager_client() -> "PlatformLinkingManagerClient":
"""Get a thread-cached PlatformLinkingManagerClient."""
from backend.platform_linking.manager import PlatformLinkingManagerClient
from backend.util.service import get_service_client
return get_service_client(PlatformLinkingManagerClient)
# ============ Execution Event Bus Helpers ============ #

View File

@@ -155,3 +155,19 @@ class RedisError(Exception):
"""Raised when there is an error interacting with Redis"""
pass
class LinkAlreadyExistsError(ValueError):
"""A platform_linking target (server or user) is already linked."""
class LinkTokenExpiredError(ValueError):
"""A platform_linking token has expired or been consumed."""
class LinkFlowMismatchError(ValueError):
"""A platform_linking token was used for the wrong flow (server vs user)."""
class DuplicateChatMessageError(ValueError):
"""The same user message is already in flight for this chat session."""

View File

@@ -252,6 +252,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="The port for notification service daemon to run on",
)
platform_linking_service_port: int = Field(
default=8009,
description="The port for the platform_linking manager daemon to run on",
)
otto_api_url: str = Field(
default="",
description="The URL for the Otto API service",
@@ -269,6 +274,13 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
"This value is then used to generate redirect URLs for OAuth flows.",
)
platform_link_base_url: str = Field(
default="https://platform.agpt.co/link",
description="Base URL the bot service prepends to one-time linking "
"tokens when it posts them to users ({base}/{token}?platform=...). "
"Should point at the frontend /link page.",
)
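# Example (hypothetical token value): with the default base this yields
# https://platform.agpt.co/link/aBc123?platform=DISCORD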
media_gcs_bucket_name: str = Field(
default="",
description="The name of the Google Cloud Storage bucket for media files",

View File

@@ -0,0 +1,55 @@
-- CreateEnum
CREATE TYPE "PlatformType" AS ENUM ('DISCORD', 'TELEGRAM', 'SLACK', 'TEAMS', 'WHATSAPP', 'GITHUB', 'LINEAR');
-- CreateTable
-- PlatformLink maps a platform server (Discord guild, Telegram group, etc.) to an AutoGPT
-- owner account. The first user to authenticate becomes the owner — all usage from that
-- server is billed to their account. Each user within the server gets their own CoPilot
-- session, all visible in the owner's AutoGPT account.
CREATE TABLE "PlatformLink" (
"id" TEXT NOT NULL,
"userId" TEXT NOT NULL,
"platform" "PlatformType" NOT NULL,
"platformServerId" TEXT NOT NULL,
"ownerPlatformUserId" TEXT NOT NULL,
"serverName" TEXT,
"linkedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT "PlatformLink_pkey" PRIMARY KEY ("id")
);
-- CreateTable
-- PlatformLinkToken is a one-time token for the server linking flow.
CREATE TABLE "PlatformLinkToken" (
"id" TEXT NOT NULL,
"token" TEXT NOT NULL,
"platform" "PlatformType" NOT NULL,
"platformServerId" TEXT NOT NULL,
"platformUserId" TEXT NOT NULL,
"platformUsername" TEXT,
"serverName" TEXT,
"channelId" TEXT,
"expiresAt" TIMESTAMP(3) NOT NULL,
"usedAt" TIMESTAMP(3),
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT "PlatformLinkToken_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "PlatformLink_platform_platformServerId_key" ON "PlatformLink"("platform", "platformServerId");
-- CreateIndex
CREATE INDEX "PlatformLink_userId_idx" ON "PlatformLink"("userId");
-- CreateIndex
CREATE UNIQUE INDEX "PlatformLinkToken_token_key" ON "PlatformLinkToken"("token");
-- CreateIndex
CREATE INDEX "PlatformLinkToken_platform_platformServerId_idx" ON "PlatformLinkToken"("platform", "platformServerId");
-- CreateIndex
CREATE INDEX "PlatformLinkToken_expiresAt_idx" ON "PlatformLinkToken"("expiresAt");
-- AddForeignKey
ALTER TABLE "PlatformLink" ADD CONSTRAINT "PlatformLink_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -0,0 +1,37 @@
-- CreateEnum
-- Server links (group chats / guilds) and user links (personal DMs) are
-- fully independent — a user who owns a linked server still has to link
-- their DMs separately.
CREATE TYPE "PlatformLinkType" AS ENUM ('SERVER', 'USER');
-- CreateTable
-- PlatformUserLink maps an individual platform user identity to an AutoGPT
-- account for 1:1 DMs with the bot. Independent from PlatformLink.
CREATE TABLE "PlatformUserLink" (
"id" TEXT NOT NULL,
"userId" TEXT NOT NULL,
"platform" "PlatformType" NOT NULL,
"platformUserId" TEXT NOT NULL,
"platformUsername" TEXT,
"linkedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT "PlatformUserLink_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "PlatformUserLink_platform_platformUserId_key" ON "PlatformUserLink"("platform", "platformUserId");
-- CreateIndex
CREATE INDEX "PlatformUserLink_userId_idx" ON "PlatformUserLink"("userId");
-- AddForeignKey
ALTER TABLE "PlatformUserLink" ADD CONSTRAINT "PlatformUserLink_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AlterTable: PlatformLinkToken now supports SERVER or USER tokens.
-- Existing rows are all server tokens, which the DEFAULT 'SERVER'
-- backfill preserves.
ALTER TABLE "PlatformLinkToken"
ADD COLUMN "linkType" "PlatformLinkType" NOT NULL DEFAULT 'SERVER',
ALTER COLUMN "platformServerId" DROP NOT NULL;
-- CreateIndex
CREATE INDEX "PlatformLinkToken_platform_platformUserId_idx" ON "PlatformLinkToken"("platform", "platformUserId");

View File

@@ -0,0 +1,25 @@
-- CreateTable
CREATE TABLE "SharedExecutionFile" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"executionId" TEXT NOT NULL,
"fileId" TEXT NOT NULL,
"shareToken" TEXT NOT NULL,
CONSTRAINT "SharedExecutionFile_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "SharedExecutionFile_shareToken_fileId_key" ON "SharedExecutionFile"("shareToken", "fileId");
-- CreateIndex
CREATE INDEX "SharedExecutionFile_shareToken_idx" ON "SharedExecutionFile"("shareToken");
-- CreateIndex
CREATE INDEX "SharedExecutionFile_executionId_idx" ON "SharedExecutionFile"("executionId");
-- AddForeignKey
ALTER TABLE "SharedExecutionFile" ADD CONSTRAINT "SharedExecutionFile_executionId_fkey" FOREIGN KEY ("executionId") REFERENCES "AgentGraphExecution"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "SharedExecutionFile" ADD CONSTRAINT "SharedExecutionFile_fileId_fkey" FOREIGN KEY ("fileId") REFERENCES "UserWorkspaceFile"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -129,6 +129,7 @@ db = "backend.db:main"
ws = "backend.ws:main"
scheduler = "backend.scheduler:main"
notification = "backend.notification:main"
platform-linking-manager = "backend.platform_linking_manager:main"
executor = "backend.exec:main"
analytics-setup = "scripts.generate_views:main_setup"
analytics-views = "scripts.generate_views:main_views"

View File

@@ -81,6 +81,10 @@ model User {
OAuthAuthorizationCodes OAuthAuthorizationCode[]
OAuthAccessTokens OAuthAccessToken[]
OAuthRefreshTokens OAuthRefreshToken[]
// Platform bot linking
PlatformLinks PlatformLink[]
PlatformUserLinks PlatformUserLink[]
}
enum SubscriptionTier {
@@ -200,10 +204,32 @@ model UserWorkspaceFile {
metadata Json @default("{}")
SharedExecutionFiles SharedExecutionFile[]
@@unique([workspaceId, path])
@@index([workspaceId, isDeleted])
}
// Tracks which workspace files are exposed via a shared execution.
// Created when sharing is enabled, deleted when sharing is disabled.
// The public file download endpoint validates against this table.
model SharedExecutionFile {
id String @id @default(uuid())
createdAt DateTime @default(now())
executionId String
Execution AgentGraphExecution @relation(fields: [executionId], references: [id], onDelete: Cascade)
fileId String
File UserWorkspaceFile @relation(fields: [fileId], references: [id], onDelete: Cascade)
shareToken String
@@unique([shareToken, fileId])
@@index([shareToken])
@@index([executionId])
}
model BuilderSearchHistory {
id String @id @default(uuid())
createdAt DateTime @default(now())
@@ -585,9 +611,10 @@ model AgentGraphExecution {
ChildExecutions AgentGraphExecution[] @relation("ParentChildExecution")
// Sharing fields
isShared Boolean @default(false)
shareToken String? @unique
sharedAt DateTime?
SharedExecutionFiles SharedExecutionFile[]
@@index([agentGraphId, agentGraphVersion])
@@index([userId, isDeleted, createdAt])
@@ -1366,3 +1393,84 @@ model OAuthRefreshToken {
@@index([userId, applicationId])
@@index([expiresAt]) // For cleanup
}
// ── Platform Bot Linking ──────────────────────────────────────────────
// Links external chat platform identities (Discord, Telegram, Slack, etc.)
// to AutoGPT user accounts, enabling the multi-platform CoPilot bot.
enum PlatformType {
DISCORD
TELEGRAM
SLACK
TEAMS
WHATSAPP
GITHUB
LINEAR
}
// Whether a linking token claims a server (group chat / guild) or a personal
// 1:1 user link (DMs). Server and user links are completely independent —
// linking a server does not grant DM access and vice versa.
enum PlatformLinkType {
SERVER
USER
}
// Maps a platform server (Discord guild, Telegram group, Slack workspace, etc.)
// to an AutoGPT user account. The user who first authenticates becomes the
// "owner" — all usage from that server is attributed to their account.
model PlatformLink {
id String @id @default(uuid())
userId String // AutoGPT user ID of the owner
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
platform PlatformType
platformServerId String // Server/guild/group ID on that platform
ownerPlatformUserId String // Platform user ID of the person who set it up
serverName String? // Display name of the server (best-effort, may go stale)
linkedAt DateTime @default(now())
@@unique([platform, platformServerId])
@@index([userId])
}
// Maps a platform user identity (a single Discord / Telegram / Slack user) to
// an AutoGPT account for 1:1 DM conversations with the bot. Independent from
// PlatformLink — a user who owns a linked server must still link their DMs
// separately.
model PlatformUserLink {
id String @id @default(uuid())
userId String // AutoGPT user ID
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
platform PlatformType
platformUserId String // Individual's user ID on the platform
platformUsername String? // Display name at link time (best-effort)
linkedAt DateTime @default(now())
@@unique([platform, platformUserId])
@@index([userId])
}
// One-time tokens for either the server linking flow or the DM (user) linking
// flow. linkType determines which fields are populated — SERVER tokens carry
// platformServerId, serverName, and channelId, while platformUserId (the
// person claiming the link) is set for both flows.
model PlatformLinkToken {
id String @id @default(uuid())
token String @unique
platform PlatformType
linkType PlatformLinkType @default(SERVER)
// SERVER token fields (null for USER tokens)
platformServerId String? // Server/guild/group ID being linked
serverName String? // Server display name
channelId String? // Channel to send confirmation back to
// Always set — platform user ID of the person who will claim ownership
platformUserId String
platformUsername String? // Their display name
expiresAt DateTime
usedAt DateTime?
createdAt DateTime @default(now())
@@index([platform, platformServerId])
@@index([platform, platformUserId])
@@index([expiresAt])
}

View File

@@ -92,11 +92,12 @@
"geist": "1.5.1",
"highlight.js": "11.11.1",
"jaro-winkler": "0.2.8",
"jszip": "3.10.1",
"katex": "0.16.25",
"launchdarkly-react-client-sdk": "3.9.0",
"lodash": "4.17.21",
"lucide-react": "0.552.0",
"next": "15.4.10",
"next": "15.4.11",
"next-themes": "0.4.6",
"nuqs": "2.7.2",
"posthog-js": "1.334.1",

View File

@@ -26,7 +26,7 @@ importers:
version: 5.2.2(react-hook-form@7.66.0(react@18.3.1))
'@next/third-parties':
specifier: 15.4.6
version: 15.4.6(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
version: 15.4.6(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
'@phosphor-icons/react':
specifier: 2.1.10
version: 2.1.10(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
@@ -107,7 +107,7 @@ importers:
version: 6.1.2(@rjsf/utils@6.1.2(react@18.3.1))
'@sentry/nextjs':
specifier: 10.27.0
version: 10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.104.1(esbuild@0.25.12))
version: 10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.104.1(esbuild@0.25.12))
'@streamdown/cjk':
specifier: 1.0.1
version: 1.0.1(@types/mdast@4.0.4)(micromark-util-types@2.0.2)(micromark@4.0.2)(react@18.3.1)(unified@11.0.5)
@@ -134,10 +134,10 @@ importers:
version: 0.2.4
'@vercel/analytics':
specifier: 1.5.0
version: 1.5.0(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
version: 1.5.0(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
'@vercel/speed-insights':
specifier: 1.2.0
version: 1.2.0(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
version: 1.2.0(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
'@xyflow/react':
specifier: 12.9.2
version: 12.9.2(@types/react@18.3.17)(immer@11.1.3)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
@@ -185,13 +185,16 @@ importers:
version: 12.23.24(@emotion/is-prop-valid@1.2.2)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
geist:
specifier: 1.5.1
version: 1.5.1(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))
version: 1.5.1(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))
highlight.js:
specifier: 11.11.1
version: 11.11.1
jaro-winkler:
specifier: 0.2.8
version: 0.2.8
jszip:
specifier: 3.10.1
version: 3.10.1
katex:
specifier: 0.16.25
version: 0.16.25
@@ -205,14 +208,14 @@ importers:
specifier: 0.552.0
version: 0.552.0(react@18.3.1)
next:
specifier: 15.4.10
version: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
specifier: 15.4.11
version: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
next-themes:
specifier: 0.4.6
version: 0.4.6(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
nuqs:
specifier: 2.7.2
version: 2.7.2(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
version: 2.7.2(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
posthog-js:
specifier: 1.334.1
version: 1.334.1
@@ -330,7 +333,7 @@ importers:
version: 9.1.5(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2)(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2)))
'@storybook/nextjs':
specifier: 9.1.5
version: 9.1.5(esbuild@0.25.12)(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2)(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2)))(type-fest@4.41.0)(typescript@5.9.3)(webpack-hot-middleware@2.26.1)(webpack@5.104.1(esbuild@0.25.12))
version: 9.1.5(esbuild@0.25.12)(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2)(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2)))(type-fest@4.41.0)(typescript@5.9.3)(webpack-hot-middleware@2.26.1)(webpack@5.104.1(esbuild@0.25.12))
'@tanstack/eslint-plugin-query':
specifier: 5.91.2
version: 5.91.2(eslint@8.57.1)(typescript@5.9.3)
@@ -1844,8 +1847,8 @@ packages:
'@neoconfetti/react@1.0.0':
resolution: {integrity: sha512-klcSooChXXOzIm+SE5IISIAn3bYzYfPjbX7D7HoqZL84oAfgREeSg5vSIaSFH+DaGzzvImTyWe1OyrJ67vik4A==}
'@next/env@15.4.10':
resolution: {integrity: sha512-knhmoJ0Vv7VRf6pZEPSnciUG1S4bIhWx+qTYBW/AjxEtlzsiNORPk8sFDCEvqLfmKuey56UB9FL1UdHEV3uBrg==}
'@next/env@15.4.11':
resolution: {integrity: sha512-mIYp/091eYfPFezKX7ZPTWqrmSXq+ih6+LcUyKvLmeLQGhlPtot33kuEOd4U+xAA7sFfj21+OtCpIZx0g5SpvQ==}
'@next/eslint-plugin-next@15.5.7':
resolution: {integrity: sha512-DtRU2N7BkGr8r+pExfuWHwMEPX5SD57FeA6pxdgCHODo+b/UgIgjE+rgWKtJAbEbGhVZ2jtHn4g3wNhWFoNBQQ==}
@@ -5919,6 +5922,9 @@ packages:
engines: {node: '>=16.x'}
hasBin: true
immediate@3.0.6:
resolution: {integrity: sha512-XXOFtyqDjNDAQxVfYxuF7g9Il/IbWmmlQg2MYKOH8ExIT1qg6xc4zyS3HaEEATgs1btfzxq15ciUiY7gjSXRGQ==}
immer@10.2.0:
resolution: {integrity: sha512-d/+XTN3zfODyjr89gM3mPq1WNX2B8pYsu7eORitdwyA2sBubnTl3laYlBk4sXY5FUa5qTZGBDPJICVbvqzjlbw==}
@@ -6279,6 +6285,9 @@ packages:
resolution: {integrity: sha512-ZZow9HBI5O6EPgSJLUb8n2NKgmVWTwCvHGwFuJlMjvLFqlGG6pjirPhtdsseaLZjSibD8eegzmYpUZwoIlj2cQ==}
engines: {node: '>=4.0'}
jszip@3.10.1:
resolution: {integrity: sha512-xXDvecyTpGLrqFrvkrUSoxxfJI5AH7U8zxxtVclpsUtMCq4JQ290LY8AW5c7Ggnr/Y/oK+bQMbqK2qmtk3pN4g==}
junit-report-builder@5.1.1:
resolution: {integrity: sha512-ZNOIIGMzqCGcHQEA2Q4rIQQ3Df6gSIfne+X9Rly9Bc2y55KxAZu8iGv+n2pP0bLf0XAOctJZgeloC54hWzCahQ==}
engines: {node: '>=16'}
@@ -6348,6 +6357,9 @@ packages:
resolution: {integrity: sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==}
engines: {node: '>= 0.8.0'}
lie@3.3.0:
resolution: {integrity: sha512-UaiMJzeWRlEujzAuw5LokY1L5ecNQYZKfmyZ9L7wDHb/p5etKaxXhohBcrw0EYby+G/NA52vRSN4N39dxHAIwQ==}
lilconfig@3.1.3:
resolution: {integrity: sha512-/vlFKAoH5Cgt3Ie+JLhRbwOsCQePABiU3tJ1egGvyQ+33R/vcwM2Zl2QR/LzjsBeItPt3oSVXapn+m4nQDvpzw==}
engines: {node: '>=14'}
@@ -6839,8 +6851,8 @@ packages:
react: ^16.8 || ^17 || ^18 || ^19 || ^19.0.0-rc
react-dom: ^16.8 || ^17 || ^18 || ^19 || ^19.0.0-rc
next@15.4.10:
resolution: {integrity: sha512-itVlc79QjpKMFMRhP+kbGKaSG/gZM6RCvwhEbwmCNF06CdDiNaoHcbeg0PqkEa2GOcn8KJ0nnc7+yL7EjoYLHQ==}
next@15.4.11:
resolution: {integrity: sha512-IJRyXal45mIsshZI5XJne/intjusslUP1F+FHVBIyMGEqbYtIq1Irdx5vdWBBg58smviPDycmDeV6txsfkv1RQ==}
engines: {node: ^18.18.0 || ^19.8.0 || >= 20.0.0}
hasBin: true
peerDependencies:
@@ -10423,7 +10435,7 @@ snapshots:
'@neoconfetti/react@1.0.0': {}
'@next/env@15.4.10': {}
'@next/env@15.4.11': {}
'@next/eslint-plugin-next@15.5.7':
dependencies:
@@ -10453,9 +10465,9 @@ snapshots:
'@next/swc-win32-x64-msvc@15.4.8':
optional: true
'@next/third-parties@15.4.6(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
'@next/third-parties@15.4.6(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
dependencies:
next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
react: 18.3.1
third-party-capital: 1.0.20
@@ -11770,7 +11782,7 @@ snapshots:
'@sentry/core@10.27.0': {}
'@sentry/nextjs@10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.104.1(esbuild@0.25.12))':
'@sentry/nextjs@10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.104.1(esbuild@0.25.12))':
dependencies:
'@opentelemetry/api': 1.9.0
'@opentelemetry/semantic-conventions': 1.38.0
@@ -11783,7 +11795,7 @@ snapshots:
'@sentry/react': 10.27.0(react@18.3.1)
'@sentry/vercel-edge': 10.27.0
'@sentry/webpack-plugin': 4.6.1(webpack@5.104.1(esbuild@0.25.12))
next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
resolve: 1.22.8
rollup: 4.55.1
stacktrace-parser: 0.1.11
@@ -12162,7 +12174,7 @@ snapshots:
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
'@storybook/nextjs@9.1.5(esbuild@0.25.12)(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2)(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2)))(type-fest@4.41.0)(typescript@5.9.3)(webpack-hot-middleware@2.26.1)(webpack@5.104.1(esbuild@0.25.12))':
'@storybook/nextjs@9.1.5(esbuild@0.25.12)(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2)(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2)))(type-fest@4.41.0)(typescript@5.9.3)(webpack-hot-middleware@2.26.1)(webpack@5.104.1(esbuild@0.25.12))':
dependencies:
'@babel/core': 7.28.5
'@babel/plugin-syntax-bigint': 7.8.3(@babel/core@7.28.5)
@@ -12186,7 +12198,7 @@ snapshots:
css-loader: 6.11.0(webpack@5.104.1(esbuild@0.25.12))
image-size: 2.0.2
loader-utils: 3.3.1
next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
node-polyfill-webpack-plugin: 2.0.1(webpack@5.104.1(esbuild@0.25.12))
postcss: 8.5.6
postcss-loader: 8.2.0(postcss@8.5.6)(typescript@5.9.3)(webpack@5.104.1(esbuild@0.25.12))
@@ -12872,16 +12884,16 @@ snapshots:
'@unrs/resolver-binding-win32-x64-msvc@1.11.1':
optional: true
'@vercel/analytics@1.5.0(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
'@vercel/analytics@1.5.0(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
optionalDependencies:
next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
react: 18.3.1
'@vercel/oidc@3.1.0': {}
'@vercel/speed-insights@1.2.0(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
'@vercel/speed-insights@1.2.0(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
optionalDependencies:
next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
react: 18.3.1
'@vitejs/plugin-react@5.1.2(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2))':
@@ -14449,8 +14461,8 @@ snapshots:
'@typescript-eslint/parser': 8.52.0(eslint@8.57.1)(typescript@5.9.3)
eslint: 8.57.1
eslint-import-resolver-node: 0.3.9
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1)
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1)
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
eslint-plugin-jsx-a11y: 6.10.2(eslint@8.57.1)
eslint-plugin-react: 7.37.5(eslint@8.57.1)
eslint-plugin-react-hooks: 5.2.0(eslint@8.57.1)
@@ -14469,7 +14481,7 @@ snapshots:
transitivePeerDependencies:
- supports-color
eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1):
eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1):
dependencies:
'@nolyfill/is-core-module': 1.0.39
debug: 4.4.3
@@ -14480,22 +14492,22 @@ snapshots:
tinyglobby: 0.2.15
unrs-resolver: 1.11.1
optionalDependencies:
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
transitivePeerDependencies:
- supports-color
eslint-module-utils@2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1):
eslint-module-utils@2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1):
dependencies:
debug: 3.2.7
optionalDependencies:
'@typescript-eslint/parser': 8.52.0(eslint@8.57.1)(typescript@5.9.3)
eslint: 8.57.1
eslint-import-resolver-node: 0.3.9
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1)
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1)
transitivePeerDependencies:
- supports-color
eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1):
eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1):
dependencies:
'@rtsao/scc': 1.1.0
array-includes: 3.1.9
@@ -14506,7 +14518,7 @@ snapshots:
doctrine: 2.1.0
eslint: 8.57.1
eslint-import-resolver-node: 0.3.9
eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
hasown: 2.0.2
is-core-module: 2.16.1
is-glob: 4.0.3
@@ -14877,9 +14889,9 @@ snapshots:
functions-have-names@1.2.3: {}
geist@1.5.1(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)):
geist@1.5.1(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)):
dependencies:
next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
generator-function@2.0.1: {}
@@ -15290,6 +15302,8 @@ snapshots:
image-size@2.0.2: {}
immediate@3.0.6: {}
immer@10.2.0: {}
immer@11.1.3: {}
@@ -15646,6 +15660,13 @@ snapshots:
object.assign: 4.1.7
object.values: 1.2.1
jszip@3.10.1:
dependencies:
lie: 3.3.0
pako: 1.0.11
readable-stream: 2.3.8
setimmediate: 1.0.5
junit-report-builder@5.1.1:
dependencies:
lodash: 4.17.21
@@ -15739,6 +15760,10 @@ snapshots:
prelude-ls: 1.2.1
type-check: 0.4.0
lie@3.3.0:
dependencies:
immediate: 3.0.6
lilconfig@3.1.3: {}
lines-and-columns@1.2.4: {}
@@ -16465,9 +16490,9 @@ snapshots:
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
dependencies:
'@next/env': 15.4.10
'@next/env': 15.4.11
'@swc/helpers': 0.5.15
caniuse-lite: 1.0.30001762
postcss: 8.4.31
@@ -16569,12 +16594,12 @@ snapshots:
dependencies:
boolbase: 1.0.0
nuqs@2.7.2(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1):
nuqs@2.7.2(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1):
dependencies:
'@standard-schema/spec': 1.0.0
react: 18.3.1
optionalDependencies:
next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
oas-kit-common@1.0.8:
dependencies:

View File

@@ -119,7 +119,7 @@ export default function SharePage() {
<CardTitle>Output</CardTitle>
</CardHeader>
<CardContent>
<RunOutputs outputs={executionData.outputs} />
<RunOutputs outputs={executionData.outputs} shareToken={token} />
</CardContent>
</Card>

View File

@@ -1,4 +1,6 @@
import type { Metadata } from "next";
import Image from "next/image";
import Link from "next/link";
export const metadata: Metadata = {
title: "Shared Agent Run - AutoGPT",
@@ -13,6 +15,27 @@ export default function ShareLayout({
}) {
return (
<div className="min-h-screen bg-background">
<header className="border-b border-border bg-background">
<div className="container mx-auto flex justify-center px-4 py-4">
<Link href="/" className="inline-block">
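{/* Two logo assets: the dark-background variant renders in dark mode, the light-background variant otherwise (per the dark: classes below). */}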
<Image
src="/autogpt-logo-dark-bg.png"
alt="AutoGPT"
width={120}
height={54}
className="hidden h-8 w-auto dark:block"
/>
<Image
src="/autogpt-logo-light-bg.png"
alt="AutoGPT"
width={120}
height={54}
className="block h-8 w-auto dark:hidden"
priority
/>
</Link>
</div>
</header>
<div className="container mx-auto px-4 py-8">{children}</div>
</div>
);

View File

@@ -6,9 +6,11 @@ import { Suspense, useState } from "react";
import { Skeleton } from "@/components/ui/skeleton";
import type { ArtifactRef } from "../../../store";
import type { ArtifactClassification } from "../helpers";
import { ArtifactErrorBoundary } from "./ArtifactErrorBoundary";
import { ArtifactReactPreview } from "./ArtifactReactPreview";
import { ArtifactSkeleton } from "./ArtifactSkeleton";
import {
FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
TAILWIND_CDN_URL,
wrapWithHeadInjection,
} from "@/lib/iframe-sandbox-csp";
@@ -53,20 +55,35 @@ function ArtifactContentLoader({
return (
<div ref={scrollRef} className="flex-1 overflow-y-auto">
<ArtifactRenderer
artifact={artifact}
content={content}
pdfUrl={pdfUrl}
isSourceView={isSourceView}
classification={classification}
/>
<ArtifactErrorBoundary
artifactID={artifact.id}
artifactTitle={artifact.title}
artifactType={classification.type}
>
<ArtifactRenderer
artifact={artifact}
content={content}
pdfUrl={pdfUrl}
isSourceView={isSourceView}
classification={classification}
/>
</ArtifactErrorBoundary>
</div>
);
}
function withCacheBust(src: string, nonce: number): string {
if (nonce === 0) return src;
const sep = src.includes("?") ? "&" : "?";
return `${src}${sep}_retry=${nonce}`;
}
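// Illustrative outputs, derived from the helper above (not part of the commit):
//   withCacheBust("/files/a.png", 0)     → "/files/a.png"             (first attempt untouched)
//   withCacheBust("/files/a.png", 2)     → "/files/a.png?_retry=2"
//   withCacheBust("/files/a.png?w=1", 1) → "/files/a.png?w=1&_retry=1" (respects an existing query)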
function ArtifactImage({ src, alt }: { src: string; alt: string }) {
const [loaded, setLoaded] = useState(false);
const [error, setError] = useState(false);
// Incremented on every Try Again so the URL changes and the browser
// can't reuse a negative-cached response (SECRT-2221).
const [retryNonce, setRetryNonce] = useState(0);
if (error) {
return (
@@ -80,6 +97,7 @@ function ArtifactImage({ src, alt }: { src: string; alt: string }) {
onClick={() => {
setError(false);
setLoaded(false);
setRetryNonce((n) => n + 1);
}}
className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
>
@@ -96,7 +114,7 @@ function ArtifactImage({ src, alt }: { src: string; alt: string }) {
)}
{/* eslint-disable-next-line @next/next/no-img-element */}
<img
src={src}
src={withCacheBust(src, retryNonce)}
alt={alt}
className={`max-h-full max-w-full object-contain transition-opacity ${loaded ? "opacity-100" : "opacity-0"}`}
onLoad={() => setLoaded(true)}
@@ -109,6 +127,7 @@ function ArtifactImage({ src, alt }: { src: string; alt: string }) {
function ArtifactVideo({ src }: { src: string }) {
const [loaded, setLoaded] = useState(false);
const [error, setError] = useState(false);
const [retryNonce, setRetryNonce] = useState(0);
if (error) {
return (
@@ -122,6 +141,7 @@ function ArtifactVideo({ src }: { src: string }) {
onClick={() => {
setError(false);
setLoaded(false);
setRetryNonce((n) => n + 1);
}}
className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
>
@@ -137,7 +157,7 @@ function ArtifactVideo({ src }: { src: string }) {
<Skeleton className="absolute inset-4 h-[calc(100%-2rem)] w-[calc(100%-2rem)] rounded-md" />
)}
<video
src={src}
src={withCacheBust(src, retryNonce)}
controls
preload="metadata"
className={`max-h-full max-w-full rounded-md transition-opacity ${loaded ? "opacity-100" : "opacity-0"}`}
@@ -200,7 +220,10 @@ function ArtifactRenderer({
if (classification.type === "html") {
// Inject Tailwind CDN — no CSP (see iframe-sandbox-csp.ts for why)
const tailwindScript = `<script src="${TAILWIND_CDN_URL}"></script>`;
const wrapped = wrapWithHeadInjection(content, tailwindScript);
const wrapped = wrapWithHeadInjection(
content,
tailwindScript + FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
);
return (
<iframe
sandbox="allow-scripts"

View File

@@ -27,6 +27,10 @@ export function ArtifactDragHandle({
minWidthRef.current = minWidth;
maxWidthPercentRef.current = maxWidthPercent;
// Track the captured pointer id so pointerup can release it even after
// React re-renders.
const pointerIdRef = useRef<number | null>(null);
// Attach document listeners only while dragging, and always tear them down
// on unmount — otherwise closing the panel mid-drag leaves listeners bound
// to a handler that calls setState on the unmounted component.
@@ -57,7 +61,7 @@ export function ArtifactDragHandle({
};
}, [isDragging]);
function handlePointerDown(e: React.PointerEvent) {
function handlePointerDown(e: React.PointerEvent<HTMLDivElement>) {
e.preventDefault();
startXRef.current = e.clientX;
@@ -67,9 +71,31 @@ export function ArtifactDragHandle({
) as HTMLElement | null;
startWidthRef.current = panel?.offsetWidth ?? DEFAULT_PANEL_WIDTH;
// Capture the pointer so pointermove/pointerup still reach us when the
// cursor drifts over sandboxed artifact iframes. Without this, the iframe
// eats the events and the drag gets stuck (SECRT-2256).
try {
e.currentTarget.setPointerCapture(e.pointerId);
pointerIdRef.current = e.pointerId;
} catch {
// Non-supporting environments (older test DOMs) — safe to ignore.
}
setIsDragging(true);
}
function handlePointerUp(e: React.PointerEvent<HTMLDivElement>) {
if (pointerIdRef.current != null) {
try {
e.currentTarget.releasePointerCapture(pointerIdRef.current);
} catch {
// Capture may already be released.
}
pointerIdRef.current = null;
}
setIsDragging(false);
}
return (
// 12px transparent hit target with the visible 1px line centered inside
// (WCAG-compliant, matches ~8-12px conventions of other resizable panels).
@@ -81,6 +107,9 @@ export function ArtifactDragHandle({
"group absolute -left-1.5 top-0 z-10 flex h-full w-3 cursor-col-resize items-stretch justify-center",
)}
onPointerDown={handlePointerDown}
onPointerUp={handlePointerUp}
onPointerCancel={handlePointerUp}
style={{ touchAction: "none" }}
>
<div
className={cn(

View File

@@ -0,0 +1,100 @@
"use client";
import * as Sentry from "@sentry/nextjs";
import { Component, type ErrorInfo, type ReactNode } from "react";
interface Props {
children: ReactNode;
artifactID: string;
artifactTitle: string;
artifactType: string;
}
interface State {
error: Error | null;
}
export class ArtifactErrorBoundary extends Component<Props, State> {
state: State = { error: null };
static getDerivedStateFromError(error: Error): State {
return { error };
}
componentDidCatch(error: Error, errorInfo: ErrorInfo) {
Sentry.captureException(error, {
contexts: {
react: { componentStack: errorInfo.componentStack },
},
tags: { errorBoundary: "true", context: "copilot-artifact" },
extra: {
artifactID: this.props.artifactID,
artifactTitle: this.props.artifactTitle,
artifactType: this.props.artifactType,
},
});
}
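// Clear the fallback whenever the boundary is handed a different artifact
// (id, title, or type change) so opening another artifact after a crash
// doesn't stay stuck on the previous one's error.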
componentDidUpdate(prevProps: Props) {
if (
this.state.error &&
(prevProps.artifactID !== this.props.artifactID ||
prevProps.artifactTitle !== this.props.artifactTitle ||
prevProps.artifactType !== this.props.artifactType)
) {
this.setState({ error: null });
}
}
handleCopy = () => {
const { error } = this.state;
if (!error) return;
const details = [
`Artifact: ${this.props.artifactTitle}`,
`ID: ${this.props.artifactID}`,
`Type: ${this.props.artifactType}`,
`Error: ${error.message}`,
error.stack ? `Stack:\n${error.stack}` : "",
]
.filter(Boolean)
.join("\n");
navigator.clipboard?.writeText(details).catch(() => {});
};
render() {
const { error } = this.state;
if (!error) return this.props.children;
const message = error.message || "Unknown rendering error";
return (
<div
role="alert"
className="flex h-full flex-col items-center justify-center gap-3 p-8 text-center"
>
<p className="text-sm font-medium text-zinc-700">
This artifact couldn&apos;t be rendered
</p>
<p className="max-w-md break-words text-xs text-zinc-500">
Something in{" "}
<span className="font-mono">{this.props.artifactTitle}</span> threw an
error while rendering. The chat and sidebar are still working.
</p>
<pre className="max-h-32 max-w-md overflow-auto whitespace-pre-wrap break-words rounded-md bg-zinc-100 px-3 py-2 text-left text-xs text-zinc-700">
{message}
</pre>
<button
type="button"
onClick={this.handleCopy}
className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
>
Copy error details
</button>
<p className="max-w-md text-xs text-zinc-400">
Paste this into the chat so the agent can regenerate a working
version.
</p>
</div>
);
}
}

View File

@@ -199,6 +199,82 @@ describe("ArtifactContent", () => {
});
});
// SECRT-2221 integration: the classification-level fix (hi-res PNGs stop
// being size-gated) only matters if the end-to-end rendering pipeline
// actually reaches the <img> path. Pass in the real classifyArtifact
// result for a 25 MB .png and assert the panel renders an img element
// rather than routing to the download-only surface.
it("renders a 25 MB PNG through the <img> path, not download-only (SECRT-2221)", () => {
const artifact = makeArtifact({
id: "hires-png-001",
title: "poster.png",
mimeType: "image/png",
sourceUrl: "/api/proxy/api/workspace/files/hires-png-001/download",
sizeBytes: 25 * 1024 * 1024,
});
const classification = classifyArtifact(
artifact.mimeType,
artifact.title,
artifact.sizeBytes,
);
expect(classification.type).toBe("image");
expect(classification.openable).toBe(true);
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const img = container.querySelector("img");
expect(img).toBeTruthy();
expect(img?.getAttribute("src")).toBe(artifact.sourceUrl);
});
// SECRT-2221: image retry appends a cache-busting query so the browser
// can't reuse a previously-failed response. Without this, a transient
// 5xx that gets negative-cached keeps showing "Failed to load image" no
// matter how many times the user clicks Try again.
it("image retry appends a cache-busting query so the browser re-fetches (SECRT-2221)", async () => {
const artifact = makeArtifact({
id: "img-cachebust",
title: "hires.png",
mimeType: "image/png",
sourceUrl: "/api/proxy/api/workspace/files/img-cachebust/download",
});
const classification = makeClassification({ type: "image" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const firstImg = container.querySelector("img");
const firstSrc = firstImg?.getAttribute("src");
expect(firstSrc).toBe(artifact.sourceUrl);
fireEvent.error(firstImg!);
await waitFor(() => {
expect(screen.queryByText("Failed to load image")).toBeTruthy();
});
fireEvent.click(screen.getByRole("button", { name: /try again/i }));
await waitFor(() => {
const nextImg = container.querySelector("img");
const nextSrc = nextImg?.getAttribute("src") ?? "";
expect(nextSrc).not.toBe(firstSrc);
expect(nextSrc.startsWith(artifact.sourceUrl)).toBe(true);
// Assert the specific cache-bust contract, not just that the URL
// changed — guards against accidental rewrites that drop the key.
expect(nextSrc).toContain("_retry=");
});
});
// ── Video ─────────────────────────────────────────────────────────
it("renders video artifact with video tag and controls", () => {
@@ -379,6 +455,117 @@ describe("ArtifactContent", () => {
expect(retryButtons.length).toBeGreaterThan(0);
});
// SECRT-2224: "try again doesn't do anything". The retry itself works — the
// user's complaint is that there's no visible feedback when the same error
// returns (e.g. a 404 for a deleted file). Clicking Try Again must flip the
// UI into the loading skeleton immediately so the user can tell their click
// registered, instead of the error UI re-flashing in place.
it("clicking Try Again shows the loading skeleton before the next fetch settles (SECRT-2224)", async () => {
let resolveSecond: (value: unknown) => void = () => {};
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
if (callCount === 1) {
return Promise.resolve({
ok: false,
status: 404,
text: () => Promise.resolve("Not found"),
});
}
return new Promise((resolve) => {
resolveSecond = resolve;
});
}),
);
const artifact = makeArtifact({
id: "retry-skeleton-001",
title: "flaky.html",
mimeType: "text/html",
});
const classification = makeClassification({ type: "html" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
await screen.findByText("Failed to load content");
fireEvent.click(screen.getByRole("button", { name: /try again/i }));
// Before the second fetch resolves, the error must be gone and a skeleton
// visible (animate-pulse is the Skeleton component's signature class).
await waitFor(() => {
expect(screen.queryByText("Failed to load content")).toBeNull();
expect(container.querySelector('[class*="animate-pulse"]')).toBeTruthy();
});
// Let the second fetch complete and wait for the recovered render so
// pending React updates can't leak into the next test.
resolveSecond({
ok: true,
text: () => Promise.resolve("<html><body>ok</body></html>"),
});
await screen.findByTitle("flaky.html");
});
// SECRT-2224 end-to-end: Try Again actually recovers when the next fetch
// succeeds. Covers the full click → re-fetch → iframe-render loop.
it("clicking Try Again re-fetches and renders recovered HTML content (SECRT-2224)", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
if (callCount === 1) {
return Promise.resolve({
ok: false,
status: 404,
text: () => Promise.resolve("Not found"),
});
}
return Promise.resolve({
ok: true,
text: () =>
Promise.resolve(
"<html><body><h1 id='ok'>recovered</h1></body></html>",
),
});
}),
);
const artifact = makeArtifact({
id: "retry-recover-001",
title: "flaky.html",
mimeType: "text/html",
});
const classification = makeClassification({ type: "html" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
await screen.findByText("Failed to load content");
fireEvent.click(screen.getByRole("button", { name: /try again/i }));
await waitFor(() => {
const iframe = container.querySelector("iframe");
expect(iframe).toBeTruthy();
expect(iframe?.getAttribute("srcdoc")).toContain("recovered");
});
expect(screen.queryByText("Failed to load content")).toBeNull();
expect(callCount).toBeGreaterThanOrEqual(2);
});
// ── HTML ──────────────────────────────────────────────────────────
it("renders HTML content in sandboxed iframe", async () => {
@@ -412,6 +599,41 @@ describe("ArtifactContent", () => {
expect(iframe?.getAttribute("sandbox")).toBe("allow-scripts");
});
it("injects the fragment-link interceptor into HTML artifact iframes (regression)", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () =>
Promise.resolve(
'<html><head></head><body><a href="#x">x</a><div id="x">x</div></body></html>',
),
}),
);
const { container } = render(
<ArtifactContent
artifact={makeArtifact({
id: "html-frag",
title: "page.html",
mimeType: "text/html",
})}
isSourceView={false}
classification={makeClassification({ type: "html" })}
/>,
);
await screen.findByTitle("page.html");
const srcdoc = container.querySelector("iframe")?.getAttribute("srcdoc");
expect(srcdoc).toBeTruthy();
// Markers unique to FRAGMENT_LINK_INTERCEPTOR_SCRIPT — if any of these
// disappear, the interceptor is no longer being injected and fragment
// links will navigate the parent URL again.
expect(srcdoc).toContain("__fragmentLinkInterceptor");
expect(srcdoc).toContain('a[href^="#"]');
expect(srcdoc).toContain("scrollIntoView");
});
// ── Source view ───────────────────────────────────────────────────
it("renders source view as pre tag", async () => {
@@ -923,6 +1145,239 @@ describe("ArtifactContent", () => {
},
);
// ── Error boundary ────────────────────────────────────────────────
it("shows a visible error instead of crashing when the renderer throws", async () => {
const consoleErr = vi.spyOn(console, "error").mockImplementation(() => {});
const originalImpl = vi
.mocked(ArtifactReactPreview)
.getMockImplementation();
vi.mocked(ArtifactReactPreview).mockImplementation(() => {
throw new Error("boom in renderer");
});
try {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve("source"),
}),
);
const artifact = makeArtifact({
id: "crash-001",
title: "broken.tsx",
mimeType: "text/tsx",
});
const classification = makeClassification({ type: "react" });
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
expect(
await screen.findByText(/This artifact couldn't be rendered/i),
).toBeTruthy();
expect(screen.getByText(/boom in renderer/)).toBeTruthy();
expect(
screen.getByRole("button", { name: /copy error details/i }),
).toBeTruthy();
} finally {
if (originalImpl) {
vi.mocked(ArtifactReactPreview).mockImplementation(originalImpl);
}
consoleErr.mockRestore();
}
});
it("copies artifact title, type, and error to the clipboard", async () => {
const consoleErr = vi.spyOn(console, "error").mockImplementation(() => {});
const writeText = vi.fn().mockResolvedValue(undefined);
Object.defineProperty(navigator, "clipboard", {
value: { writeText },
writable: true,
configurable: true,
});
const originalImpl = vi
.mocked(ArtifactReactPreview)
.getMockImplementation();
vi.mocked(ArtifactReactPreview).mockImplementation(() => {
throw new Error("jsx parse failed at line 42");
});
try {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve("source"),
}),
);
render(
<ArtifactContent
artifact={makeArtifact({
id: "crash-002",
title: "report.tsx",
mimeType: "text/tsx",
})}
isSourceView={false}
classification={makeClassification({ type: "react" })}
/>,
);
fireEvent.click(
await screen.findByRole("button", { name: /copy error details/i }),
);
await waitFor(() => {
expect(writeText).toHaveBeenCalled();
});
const payload = writeText.mock.calls[0]![0] as string;
expect(payload).toContain("report.tsx");
expect(payload).toContain("crash-002");
expect(payload).toContain("react");
expect(payload).toContain("jsx parse failed at line 42");
} finally {
if (originalImpl) {
vi.mocked(ArtifactReactPreview).mockImplementation(originalImpl);
}
consoleErr.mockRestore();
}
});
// Regression: two different artifacts can share the same title+type (e.g.
// two "App.tsx" files from different sessions). The boundary must reset
// when artifact.id changes, not only on title/type changes, otherwise
// opening a second artifact after a crash stays stuck on the first's error.
it("resets the error fallback when the artifact id changes (same title/type)", async () => {
const consoleErr = vi.spyOn(console, "error").mockImplementation(() => {});
const originalImpl = vi
.mocked(ArtifactReactPreview)
.getMockImplementation();
// First render: throws.
vi.mocked(ArtifactReactPreview).mockImplementation(() => {
throw new Error("first render boom");
});
try {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve("source"),
}),
);
const classification = makeClassification({ type: "react" });
const { rerender } = render(
<ArtifactContent
artifact={makeArtifact({
id: "id-one",
title: "App.tsx",
mimeType: "text/tsx",
})}
isSourceView={false}
classification={classification}
/>,
);
await screen.findByText(/This artifact couldn't be rendered/i);
// Swap in a working renderer and a different artifact id (same title/type).
if (originalImpl) {
vi.mocked(ArtifactReactPreview).mockImplementation(originalImpl);
}
rerender(
<ArtifactContent
artifact={makeArtifact({
id: "id-two",
title: "App.tsx",
mimeType: "text/tsx",
})}
isSourceView={false}
classification={classification}
/>,
);
await waitFor(() => {
expect(
screen.queryByText(/This artifact couldn't be rendered/i),
).toBeNull();
expect(screen.getByTestId("react-preview")).toBeTruthy();
});
} finally {
if (originalImpl) {
vi.mocked(ArtifactReactPreview).mockImplementation(originalImpl);
}
consoleErr.mockRestore();
}
});
it("renders the user-reported plotly HTML artifact into a sandboxed iframe", async () => {
const html = `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>AutoGPT Beta Launch Interactive Report</title>
<script src="https://cdn.plot.ly/plotly-2.27.0.min.js"></script>
<style>
:root { --bg: #f8f9fa; --primary: #6c5ce7; }
* { margin: 0; padding: 0; box-sizing: border-box; }
body { font-family: 'Segoe UI', system-ui, sans-serif; }
</style>
</head>
<body>
<header><h1>\u{1F4CA} AutoGPT Beta Launch Interactive Report</h1></header>
<div class="chart-container" id="globalActivationChart"></div>
<script>
function showTab(tabId, groupId) {
const group = document.getElementById(groupId);
group.querySelectorAll('.tab-content').forEach(t => t.classList.remove('active'));
document.getElementById(tabId).classList.add('active');
}
Plotly.newPlot('globalActivationChart', [{ type: 'pie', values: [1, 2] }], {});
</script>
</body>
</html>`;
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(html),
}),
);
const artifact = makeArtifact({
id: "html-big-report",
title: "report.html",
mimeType: "text/html",
});
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={makeClassification({ type: "html" })}
/>,
);
await screen.findByTitle("report.html");
const iframe = container.querySelector("iframe");
expect(iframe).toBeTruthy();
expect(iframe?.getAttribute("sandbox")).toBe("allow-scripts");
expect(screen.queryByText(/couldn't be rendered/i)).toBeNull();
});
it("falls back to pre tag when no renderer matches", async () => {
const { globalRegistry } = await import(
"@/components/contextual/OutputRenderers"

View File

@@ -0,0 +1,181 @@
import { cleanup, fireEvent, render } from "@testing-library/react";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { ArtifactDragHandle } from "../ArtifactDragHandle";
function renderHandle(onWidthChange = vi.fn(), panelWidth = 600) {
const utils = render(
<div
data-artifact-panel
style={{
width: `${panelWidth}px`,
height: "400px",
position: "relative",
}}
>
<ArtifactDragHandle onWidthChange={onWidthChange} />
</div>,
);
const panel = utils.container.querySelector(
"[data-artifact-panel]",
) as HTMLElement;
// happy-dom doesn't compute layout; stub offsetWidth so the handle reads
// the intended starting width.
Object.defineProperty(panel, "offsetWidth", {
value: panelWidth,
configurable: true,
});
const handle = utils.container.querySelector(
'[role="separator"]',
) as HTMLElement;
return { handle, onWidthChange, ...utils };
}
// jsdom/happy-dom don't implement pointer capture by default — spy on the
// prototype so vi.restoreAllMocks() can tear the spies down. We also seed
// no-op base implementations where the prototype lacks them so vi.spyOn has
// something to wrap. Both the seeded properties and window.innerWidth are
// manual mutations that vi.restoreAllMocks() won't undo, so capture their
// original descriptors and restore them in `restoreGlobals`.
function installPointerCaptureStub() {
const proto = HTMLElement.prototype as unknown as {
setPointerCapture?: (id: number) => void;
releasePointerCapture?: (id: number) => void;
};
const originalSetPointerCapture = Object.getOwnPropertyDescriptor(
proto,
"setPointerCapture",
);
const originalReleasePointerCapture = Object.getOwnPropertyDescriptor(
proto,
"releasePointerCapture",
);
if (!proto.setPointerCapture) proto.setPointerCapture = () => {};
if (!proto.releasePointerCapture) proto.releasePointerCapture = () => {};
const setPointerCapture = vi
.spyOn(HTMLElement.prototype, "setPointerCapture")
.mockImplementation(() => {});
const releasePointerCapture = vi
.spyOn(HTMLElement.prototype, "releasePointerCapture")
.mockImplementation(() => {});
function restoreGlobals() {
if (originalSetPointerCapture) {
Object.defineProperty(
proto,
"setPointerCapture",
originalSetPointerCapture,
);
} else {
delete proto.setPointerCapture;
}
if (originalReleasePointerCapture) {
Object.defineProperty(
proto,
"releasePointerCapture",
originalReleasePointerCapture,
);
} else {
delete proto.releasePointerCapture;
}
}
return { setPointerCapture, releasePointerCapture, restoreGlobals };
}
describe("ArtifactDragHandle", () => {
let spies: ReturnType<typeof installPointerCaptureStub>;
const originalInnerWidth = Object.getOwnPropertyDescriptor(
window,
"innerWidth",
);
beforeEach(() => {
spies = installPointerCaptureStub();
Object.defineProperty(window, "innerWidth", {
value: 1200,
writable: true,
configurable: true,
});
});
afterEach(() => {
cleanup();
vi.restoreAllMocks();
spies.restoreGlobals();
if (originalInnerWidth) {
Object.defineProperty(window, "innerWidth", originalInnerWidth);
}
});
// SECRT-2256: when the cursor drifts over a sandboxed iframe mid-drag, the
// iframe eats pointermove/pointerup and the drag gets stuck. setPointerCapture
// routes all subsequent pointer events to the handle regardless of what's
// under the cursor, which fixes both "can't drag right" and "drag doesn't
// stop on release".
it("captures the pointer on pointerdown so drags survive the cursor drifting over iframes (SECRT-2256)", () => {
const { handle } = renderHandle();
fireEvent.pointerDown(handle, { clientX: 500, pointerId: 7 });
expect(spies.setPointerCapture).toHaveBeenCalledWith(7);
});
it("releases the pointer capture when the drag ends", () => {
const { handle } = renderHandle();
fireEvent.pointerDown(handle, { clientX: 500, pointerId: 7 });
fireEvent.pointerUp(handle, { clientX: 400, pointerId: 7 });
expect(spies.releasePointerCapture).toHaveBeenCalledWith(7);
});
it("calls onWidthChange with the expanded width when dragging leftwards", () => {
const onWidthChange = vi.fn();
const { handle } = renderHandle(onWidthChange);
fireEvent.pointerDown(handle, { clientX: 800, pointerId: 1 });
fireEvent.pointerMove(document, { clientX: 700, pointerId: 1 });
// startWidth is 600 (container), delta = 800 - 700 = 100 → newWidth 700
expect(onWidthChange).toHaveBeenCalledWith(700);
});
it("calls onWidthChange with the shrunk width when dragging rightwards", () => {
const onWidthChange = vi.fn();
const { handle } = renderHandle(onWidthChange);
fireEvent.pointerDown(handle, { clientX: 800, pointerId: 1 });
fireEvent.pointerMove(document, { clientX: 900, pointerId: 1 });
// delta = -100 → newWidth 500
expect(onWidthChange).toHaveBeenCalledWith(500);
});
it("clamps to minWidth and maxWidth", () => {
const onWidthChange = vi.fn();
const { handle } = renderHandle(onWidthChange);
fireEvent.pointerDown(handle, { clientX: 800, pointerId: 1 });
// Drag way left → want huge width, should clamp at 85% of 1200 = 1020
fireEvent.pointerMove(document, { clientX: -5000, pointerId: 1 });
expect(onWidthChange).toHaveBeenLastCalledWith(1020);
// Drag way right → want tiny width, should clamp at minWidth 320
fireEvent.pointerMove(document, { clientX: 5000, pointerId: 1 });
expect(onWidthChange).toHaveBeenLastCalledWith(320);
});
it("stops dragging on pointerup so subsequent cursor moves don't resize", () => {
const onWidthChange = vi.fn();
const { handle } = renderHandle(onWidthChange);
fireEvent.pointerDown(handle, { clientX: 800, pointerId: 1 });
fireEvent.pointerUp(handle, { clientX: 800, pointerId: 1 });
onWidthChange.mockClear();
fireEvent.pointerMove(document, { clientX: 500, pointerId: 1 });
expect(onWidthChange).not.toHaveBeenCalled();
});
});

View File

@@ -3,7 +3,7 @@ import {
buildReactArtifactSrcDoc,
collectPreviewStyles,
escapeHtml,
} from "./reactArtifactPreview";
} from "../reactArtifactPreview";
describe("escapeHtml", () => {
it("escapes &, <, >, \", '", () => {
@@ -116,4 +116,11 @@ describe("buildReactArtifactSrcDoc", () => {
expect(doc).toContain("/^[A-Z]/.test(name)");
expect(doc).toContain("wrapWithProviders");
});
it("injects the fragment-link interceptor so #anchor clicks stay inside the iframe (regression)", () => {
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
expect(doc).toContain("__fragmentLinkInterceptor");
expect(doc).toContain('a[href^="#"]');
expect(doc).toContain("scrollIntoView");
});
});

View File

@@ -19,7 +19,10 @@
* React is loaded from unpkg with pinned version and SRI integrity hashes.
*/
import { TAILWIND_CDN_URL } from "@/lib/iframe-sandbox-csp";
import {
FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
TAILWIND_CDN_URL,
} from "@/lib/iframe-sandbox-csp";
export { transpileReactArtifactSource } from "./transpileReactArtifact";
@@ -95,6 +98,7 @@ export function buildReactArtifactSrcDoc(
}
</style>
<script src="${TAILWIND_CDN_URL}"></script>
${FRAGMENT_LINK_INTERCEPTOR_SCRIPT}
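<!-- FRAGMENT_LINK_INTERCEPTOR_SCRIPT (from iframe-sandbox-csp) catches clicks on a[href^="#"] and scrollIntoView()s the target, so fragment links scroll inside the sandbox instead of navigating the parent page -->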
<script crossorigin="anonymous" src="https://unpkg.com/react@18.3.1/umd/react.production.min.js" integrity="sha384-DGyLxAyjq0f9SPpVevD6IgztCFlnMF6oW/XQGmfe+IsZ8TqEiDrcHkMLKI6fiB/Z"></script><!-- pragma: allowlist secret -->
<script crossorigin="anonymous" src="https://unpkg.com/react-dom@18.3.1/umd/react-dom.production.min.js" integrity="sha384-gTGxhz21lVGYNMcdJOyq01Edg0jhn/c22nsx0kyqP0TxaV5WVdsSH1fSDUf5YJj1"></script><!-- pragma: allowlist secret -->
</head>

View File

@@ -141,6 +141,11 @@ export function useArtifactContent(
function retry() {
// Drop any cached failure/content for this id so we actually re-fetch.
contentCache.delete(artifact.id);
// Flip into loading + clear error synchronously with the click so the
// user always sees the skeleton (rather than the error UI re-flashing
// instantly for same-error retries). See SECRT-2224.
setIsLoading(true);
setError(null);
setRetryNonce((n) => n + 1);
}

View File

@@ -45,12 +45,33 @@ describe("classifyArtifact", () => {
expect(classifyArtifact("text/markdown", "x").type).toBe("markdown");
});
it("gates files > 10MB to download-only", () => {
it("gates text/code files > 10MB to download-only", () => {
const c = classifyArtifact("text/plain", "big.txt", 20 * 1024 * 1024);
expect(c.openable).toBe(false);
expect(c.type).toBe("download-only");
});
// SECRT-2221: large images (hi-res PNGs, etc.) were getting force-classified
// as download-only by the generic >10MB gate, so clicking them started a
// download instead of previewing — and the preview was "broken" in the
// sense that it never appeared. Images, videos, and PDFs are decoded
// natively by the browser and don't run through our JS render pipeline,
// so the size gate shouldn't apply to them.
it("does NOT size-gate large images, videos, or PDFs (SECRT-2221)", () => {
expect(
classifyArtifact("image/png", "hires.png", 25 * 1024 * 1024).type,
).toBe("image");
expect(
classifyArtifact("image/jpeg", "huge.jpg", 50 * 1024 * 1024).type,
).toBe("image");
expect(
classifyArtifact("video/mp4", "long.mp4", 500 * 1024 * 1024).type,
).toBe("video");
expect(
classifyArtifact("application/pdf", "book.pdf", 80 * 1024 * 1024).type,
).toBe("pdf");
});
it("treats binary/octet-stream MIME as download-only", () => {
expect(classifyArtifact("application/zip", "a.zip").openable).toBe(false);
expect(classifyArtifact("application/octet-stream", "x").openable).toBe(

View File

@@ -257,14 +257,34 @@ function getExtension(filename?: string): string {
return filename.slice(lastDot).toLowerCase();
}
// Types the browser renders natively — we don't run their bytes through our
// React/JS pipeline, so the size gate doesn't need to apply.
const NATIVELY_RENDERED = new Set<ArtifactClassification["type"]>([
"image",
"video",
"pdf",
]);
export function classifyArtifact(
mimeType: string | null,
filename?: string,
sizeBytes?: number,
): ArtifactClassification {
// Size gate: >10MB is download-only regardless of type.
if (sizeBytes && sizeBytes > TEN_MB) return KIND["download-only"];
const kind = classifyByTypeOnly(mimeType, filename);
// Size gate: >10MB is download-only, but only for content we actually
// render in JS. Images, videos, and PDFs are handled natively by the
// browser — gating them produced "broken previews" for hi-res files
// (SECRT-2221).
if (sizeBytes && sizeBytes > TEN_MB && !NATIVELY_RENDERED.has(kind.type)) {
return KIND["download-only"];
}
return kind;
}
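// Illustrative examples, mirroring the size-gate unit tests:
//   classifyArtifact("image/png", "hires.png", 25 * 1024 * 1024) → image (natively rendered, not gated)
//   classifyArtifact("text/plain", "big.txt", 20 * 1024 * 1024)  → download-only (JS-rendered, gated)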
function classifyByTypeOnly(
mimeType: string | null,
filename?: string,
): ArtifactClassification {
const basename = getBasename(filename);
const exactKind = EXACT_FILENAME_KIND[basename];
if (exactKind) return KIND[exactKind];

View File

@@ -1,5 +1,5 @@
import { act, renderHook } from "@testing-library/react";
import { beforeEach, describe, expect, it } from "vitest";
import { act, cleanup, renderHook } from "@testing-library/react";
import { afterEach, beforeEach, describe, expect, it } from "vitest";
import { useCopilotUIStore } from "../../store";
import { useAutoOpenArtifacts } from "./useAutoOpenArtifacts";
@@ -31,6 +31,11 @@ function resetStore() {
describe("useAutoOpenArtifacts", () => {
beforeEach(resetStore);
// Testing Library auto-cleanup isn't registered in our Vitest setup, so
// mounted `renderHook` instances (and their unmount cleanups) would leak
// between tests — here the unmount effect in useAutoOpenArtifacts would
// fire after the next test had already run and corrupt its assertions.
afterEach(cleanup);
it("does not auto-open on initial render", () => {
renderHook(() => useAutoOpenArtifacts({ sessionId: "s1" }));
@@ -88,4 +93,53 @@ describe("useAutoOpenArtifacts", () => {
expect(s.activeArtifact?.id).toBe("c");
expect(s.history).toEqual([]);
});
// SECRT-2254: "had agent panel open then went to profile then went to home
// and agent panel was still open". Nav-away unmounts the copilot page; if
// the panel state persists in the store, coming back re-renders it open.
it("closes the panel on unmount so nav-away → nav-back doesn't resurrect it (SECRT-2254)", () => {
useCopilotUIStore.getState().openArtifact(makeArtifact(A_ID, "a.txt"));
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(true);
const { unmount } = renderHook(() =>
useAutoOpenArtifacts({ sessionId: "s1" }),
);
act(() => {
unmount();
});
const s = useCopilotUIStore.getState().artifactPanel;
expect(s.isOpen).toBe(false);
expect(s.activeArtifact).toBeNull();
expect(s.history).toEqual([]);
});
// SECRT-2220: "keep closed by default" — a fresh mount (e.g. user returns to
// /copilot) must start with a closed panel even if the store somehow carries
// stale state from a prior life.
it("does not re-open a panel whose store state is stale on fresh mount (SECRT-2220)", () => {
// Simulate the store being left in an open state by a previous page life.
useCopilotUIStore.setState({
artifactPanel: {
isOpen: true,
isMinimized: false,
isMaximized: false,
width: 600,
activeArtifact: makeArtifact(A_ID, "stale.txt"),
history: [],
},
});
const { unmount } = renderHook(() =>
useAutoOpenArtifacts({ sessionId: "s1" }),
);
act(() => {
unmount();
});
// Next mount of the page should see a clean store.
renderHook(() => useAutoOpenArtifacts({ sessionId: "s1" }));
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
});
});

View File

@@ -26,4 +26,13 @@ export function useAutoOpenArtifacts({
resetArtifactPanel();
}
}, [sessionId, resetArtifactPanel]);
// Reset on unmount so navigating away from /copilot (to /profile, /home,
// etc.) can't leave the panel open in the Zustand store, which would then
// render the panel re-open when the user returns. See SECRT-2254/2220.
useEffect(() => {
return () => {
resetArtifactPanel();
};
}, [resetArtifactPanel]);
}

View File

@@ -1,74 +1,2 @@
import { OutputRenderer, OutputMetadata } from "../types";
export interface DownloadItem {
value: any;
metadata?: OutputMetadata;
renderer: OutputRenderer;
}
export async function downloadOutputs(items: DownloadItem[]) {
const concatenableTexts: string[] = [];
const nonConcatenableDownloads: Array<{ blob: Blob; filename: string }> = [];
for (const item of items) {
if (item.renderer.isConcatenable(item.value, item.metadata)) {
const copyContent = item.renderer.getCopyContent(
item.value,
item.metadata,
);
if (copyContent) {
// Extract text from CopyContent
let text: string;
if (typeof copyContent.data === "string") {
text = copyContent.data;
} else if (copyContent.fallbackText) {
text = copyContent.fallbackText;
} else {
continue;
}
concatenableTexts.push(text);
}
} else {
const downloadContent = item.renderer.getDownloadContent(
item.value,
item.metadata,
);
if (downloadContent) {
if (typeof downloadContent.data === "string") {
if (downloadContent.data.startsWith("http")) {
const link = document.createElement("a");
link.href = downloadContent.data;
link.download = downloadContent.filename;
link.click();
}
} else {
nonConcatenableDownloads.push({
blob: downloadContent.data as Blob,
filename: downloadContent.filename,
});
}
}
}
}
if (concatenableTexts.length > 0) {
const combinedText = concatenableTexts.join("\n\n---\n\n");
const blob = new Blob([combinedText], { type: "text/plain" });
downloadBlob(blob, "combined_output.txt");
}
for (const download of nonConcatenableDownloads) {
downloadBlob(download.blob, download.filename);
}
}
function downloadBlob(blob: Blob, filename: string) {
const url = URL.createObjectURL(blob);
const link = document.createElement("a");
link.href = url;
link.download = filename;
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
URL.revokeObjectURL(url);
}
export { downloadOutputs } from "@/lib/utils/download-outputs";
export type { DownloadItem } from "@/lib/utils/download-outputs";

View File

@@ -15,9 +15,10 @@ type OutputsRecord = Record<string, Array<unknown>>;
interface RunOutputsProps {
outputs: OutputsRecord;
shareToken?: string;
}
export function RunOutputs({ outputs }: RunOutputsProps) {
export function RunOutputs({ outputs, shareToken }: RunOutputsProps) {
const items = useMemo(() => {
const list: Array<{
key: string;
@@ -30,6 +31,7 @@ export function RunOutputs({ outputs }: RunOutputsProps) {
Object.entries(outputs || {}).forEach(([key, values]) => {
(values || []).forEach((value, index) => {
const metadata: OutputMetadata = {};
if (shareToken) metadata.shareToken = shareToken;
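// Forward the token so output renderers can fetch workspace files via the
// public shared-execution download route (no auth; added in this commit)
// instead of the authenticated workspace endpoints.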
if (
typeof value === "object" &&
value !== null &&
@@ -76,7 +78,7 @@ export function RunOutputs({ outputs }: RunOutputsProps) {
});
return list;
}, [outputs]);
}, [outputs, shareToken]);
if (!items.length) {
return <div className="text-neutral-600">No output from this run.</div>;

View File

@@ -6931,6 +6931,262 @@
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/platform-linking/links": {
"get": {
"tags": ["platform-linking"],
"summary": "List all platform servers linked to the authenticated user",
"operationId": "getPlatform-linkingList all platform servers linked to the authenticated user",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"items": { "$ref": "#/components/schemas/PlatformLinkInfo" },
"type": "array",
"title": "Response Getplatform-Linkinglist All Platform Servers Linked To The Authenticated User"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
}
},
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/platform-linking/links/{link_id}": {
"delete": {
"tags": ["platform-linking"],
"summary": "Unlink a platform server",
"operationId": "deletePlatform-linkingUnlink a platform server",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "link_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Link Id" }
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/DeleteLinkResponse" }
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/platform-linking/tokens/{token}/confirm": {
"post": {
"tags": ["platform-linking"],
"summary": "Confirm a SERVER link token (user must be authenticated)",
"operationId": "postPlatform-linkingConfirm a server link token (user must be authenticated)",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "token",
"in": "path",
"required": true,
"schema": {
"type": "string",
"maxLength": 64,
"pattern": "^[A-Za-z0-9_-]+$",
"title": "Token"
}
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/ConfirmLinkResponse" }
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/platform-linking/tokens/{token}/info": {
"get": {
"tags": ["platform-linking"],
"summary": "Get display info for a link token",
"operationId": "getPlatform-linkingGet display info for a link token",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "token",
"in": "path",
"required": true,
"schema": {
"type": "string",
"maxLength": 64,
"pattern": "^[A-Za-z0-9_-]+$",
"title": "Token"
}
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/LinkTokenInfoResponse"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/platform-linking/user-links": {
"get": {
"tags": ["platform-linking"],
"summary": "List all DM links for the authenticated user",
"operationId": "getPlatform-linkingList all dm links for the authenticated user",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"items": {
"$ref": "#/components/schemas/PlatformUserLinkInfo"
},
"type": "array",
"title": "Response Getplatform-Linkinglist All Dm Links For The Authenticated User"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
}
},
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/platform-linking/user-links/{link_id}": {
"delete": {
"tags": ["platform-linking"],
"summary": "Unlink a DM / user link",
"operationId": "deletePlatform-linkingUnlink a dm / user link",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "link_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Link Id" }
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/DeleteLinkResponse" }
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/platform-linking/user-tokens/{token}/confirm": {
"post": {
"tags": ["platform-linking"],
"summary": "Confirm a USER link token (user must be authenticated)",
"operationId": "postPlatform-linkingConfirm a user link token (user must be authenticated)",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "token",
"in": "path",
"required": true,
"schema": {
"type": "string",
"maxLength": 64,
"pattern": "^[A-Za-z0-9_-]+$",
"title": "Token"
}
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ConfirmUserLinkResponse"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/public/shared/{share_token}": {
"get": {
"tags": ["v1"],
@@ -6971,6 +7227,50 @@
}
}
},
"/api/public/shared/{share_token}/files/{file_id}/download": {
"get": {
"tags": ["v1", "graphs"],
"summary": "Download a file from a shared execution",
"description": "Download a workspace file from a shared execution (no auth required).\n\nValidates that the file was explicitly exposed when sharing was enabled.\nReturns a uniform 404 for all failure modes to prevent enumeration attacks.",
"operationId": "download_shared_file",
"parameters": [
{
"name": "share_token",
"in": "path",
"required": true,
"schema": {
"type": "string",
"pattern": "^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
"title": "Share Token"
}
},
{
"name": "file_id",
"in": "path",
"required": true,
"schema": {
"type": "string",
"pattern": "^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
"title": "File Id"
}
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": { "application/json": { "schema": {} } }
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/review/action": {
"post": {
"tags": ["v2", "executions", "review", "v2", "executions", "review"],
@@ -10029,6 +10329,46 @@
"title": "CoPilotUsagePublic",
"description": "Current usage status for a user — public (client-safe) shape."
},
"ConfirmLinkResponse": {
"properties": {
"success": { "type": "boolean", "title": "Success" },
"link_type": {
"$ref": "#/components/schemas/LinkType",
"default": "SERVER"
},
"platform": { "type": "string", "title": "Platform" },
"platform_server_id": {
"type": "string",
"title": "Platform Server Id"
},
"server_name": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Server Name"
}
},
"type": "object",
"required": [
"success",
"platform",
"platform_server_id",
"server_name"
],
"title": "ConfirmLinkResponse"
},
"ConfirmUserLinkResponse": {
"properties": {
"success": { "type": "boolean", "title": "Success" },
"link_type": {
"$ref": "#/components/schemas/LinkType",
"default": "USER"
},
"platform": { "type": "string", "title": "Platform" },
"platform_user_id": { "type": "string", "title": "Platform User Id" }
},
"type": "object",
"required": ["success", "platform", "platform_user_id"],
"title": "ConfirmUserLinkResponse"
},
"ContentType": {
"type": "string",
"enum": [
@@ -10396,6 +10736,12 @@
"required": ["version_counts"],
"title": "DeleteGraphResponse"
},
"DeleteLinkResponse": {
"properties": { "success": { "type": "boolean", "title": "Success" } },
"type": "object",
"required": ["success"],
"title": "DeleteLinkResponse"
},
"DiscoverToolsRequest": {
"properties": {
"server_url": {
@@ -12347,6 +12693,24 @@
"required": ["source_id", "sink_id", "source_name", "sink_name"],
"title": "Link"
},
"LinkTokenInfoResponse": {
"properties": {
"platform": { "type": "string", "title": "Platform" },
"link_type": { "$ref": "#/components/schemas/LinkType" },
"server_name": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Server Name"
}
},
"type": "object",
"required": ["platform", "link_type"],
"title": "LinkTokenInfoResponse"
},
"LinkType": {
"type": "string",
"enum": ["SERVER", "USER"],
"title": "LinkType"
},
"ListFilesResponse": {
"properties": {
"files": {
@@ -13491,6 +13855,64 @@
"required": ["logs", "pagination"],
"title": "PlatformCostLogsResponse"
},
"PlatformLinkInfo": {
"properties": {
"id": { "type": "string", "title": "Id" },
"platform": { "type": "string", "title": "Platform" },
"platform_server_id": {
"type": "string",
"title": "Platform Server Id"
},
"owner_platform_user_id": {
"type": "string",
"title": "Owner Platform User Id"
},
"server_name": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Server Name"
},
"linked_at": {
"type": "string",
"format": "date-time",
"title": "Linked At"
}
},
"type": "object",
"required": [
"id",
"platform",
"platform_server_id",
"owner_platform_user_id",
"server_name",
"linked_at"
],
"title": "PlatformLinkInfo"
},
"PlatformUserLinkInfo": {
"properties": {
"id": { "type": "string", "title": "Id" },
"platform": { "type": "string", "title": "Platform" },
"platform_user_id": { "type": "string", "title": "Platform User Id" },
"platform_username": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Platform Username"
},
"linked_at": {
"type": "string",
"format": "date-time",
"title": "Linked At"
}
},
"type": "object",
"required": [
"id",
"platform",
"platform_user_id",
"platform_username",
"linked_at"
],
"title": "PlatformUserLinkInfo"
},
"PostmarkBounceEnum": {
"type": "integer",
"enum": [

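For orientation, a minimal client sketch for the server-link confirmation endpoint defined above. The path, method, bearer-JWT auth scheme, and ConfirmLinkResponse shape come straight from the spec; the fetch wrapper, error handling, and how the JWT is obtained are illustrative assumptions, not shipped code.

async function confirmServerLink(
  token: string,
  jwt: string,
): Promise<{
  success: boolean;
  link_type: "SERVER" | "USER";
  platform: string;
  platform_server_id: string;
  server_name: string | null;
}> {
  // POST /api/platform-linking/tokens/{token}/confirm — requires HTTPBearerJWT
  const res = await fetch(`/api/platform-linking/tokens/${token}/confirm`, {
    method: "POST",
    headers: { Authorization: `Bearer ${jwt}` },
  });
  if (!res.ok) throw new Error(`Link confirmation failed: ${res.status}`);
  return res.json();
}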
View File

@@ -9,13 +9,16 @@ import {
} from "./route.helpers";
describe("isWorkspaceDownloadRequest", () => {
it("matches api/workspace/files/{id}/download pattern", () => {
const VALID_UUID = "550e8400-e29b-41d4-a716-446655440000";
const VALID_UUID_2 = "6ba7b810-9dad-11d1-80b4-00c04fd430c8";
it("matches api/workspace/files/{uuid}/download pattern", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
"abc-123",
VALID_UUID,
"download",
]),
).toBe(true);
@@ -30,7 +33,7 @@ describe("isWorkspaceDownloadRequest", () => {
"api",
"workspace",
"files",
"id",
VALID_UUID,
"download",
"extra",
]),
@@ -43,7 +46,7 @@ describe("isWorkspaceDownloadRequest", () => {
"v1",
"workspace",
"files",
"id",
VALID_UUID,
"download",
]),
).toBe(false);
@@ -55,11 +58,622 @@ describe("isWorkspaceDownloadRequest", () => {
"api",
"workspace",
"files",
"id",
VALID_UUID,
"metadata",
]),
).toBe(false);
});
it("matches api/public/shared/{uuid}/files/{uuid}/download pattern", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"public",
"shared",
VALID_UUID,
"files",
VALID_UUID_2,
"download",
]),
).toBe(true);
});
it("rejects public shared paths not ending with download", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"public",
"shared",
VALID_UUID,
"files",
VALID_UUID_2,
"metadata",
]),
).toBe(false);
});
it("rejects non-UUID file ID in workspace path", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
"not-a-uuid",
"download",
]),
).toBe(false);
});
it("rejects non-UUID token in public share path", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"public",
"shared",
"not-a-uuid",
"files",
VALID_UUID,
"download",
]),
).toBe(false);
});
it("rejects non-UUID file ID in public share path", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"public",
"shared",
VALID_UUID,
"files",
"not-a-uuid",
"download",
]),
).toBe(false);
});
it("accepts uppercase hex in UUIDs", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
"550E8400-E29B-41D4-A716-446655440000",
"download",
]),
).toBe(true);
});
describe("adversarial inputs", () => {
it("rejects empty path", () => {
expect(isWorkspaceDownloadRequest([])).toBe(false);
});
it("rejects single-segment path", () => {
expect(isWorkspaceDownloadRequest(["download"])).toBe(false);
});
it("rejects path traversal in file ID segment", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
"../../etc/passwd",
"download",
]),
).toBe(false);
});
it("rejects path traversal in token segment", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"public",
"shared",
"../../etc/passwd",
"files",
VALID_UUID,
"download",
]),
).toBe(false);
});
it("rejects path traversal replacing fixed segments", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"..",
"files",
VALID_UUID,
"download",
]),
).toBe(false);
expect(
isWorkspaceDownloadRequest([
"..",
"workspace",
"files",
VALID_UUID,
"download",
]),
).toBe(false);
});
it("rejects swapped workspace/public segments to confuse routing", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"public",
"files",
VALID_UUID,
"download",
]),
).toBe(false);
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"shared",
VALID_UUID,
"files",
VALID_UUID_2,
"download",
]),
).toBe(false);
});
it("rejects case variations on fixed segments", () => {
expect(
isWorkspaceDownloadRequest([
"API",
"workspace",
"files",
VALID_UUID,
"download",
]),
).toBe(false);
expect(
isWorkspaceDownloadRequest([
"api",
"Workspace",
"files",
VALID_UUID,
"download",
]),
).toBe(false);
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
VALID_UUID,
"DOWNLOAD",
]),
).toBe(false);
expect(
isWorkspaceDownloadRequest([
"api",
"PUBLIC",
"shared",
VALID_UUID,
"files",
VALID_UUID_2,
"download",
]),
).toBe(false);
expect(
isWorkspaceDownloadRequest([
"api",
"public",
"SHARED",
VALID_UUID,
"files",
VALID_UUID_2,
"download",
]),
).toBe(false);
});
it("rejects empty string in fixed segments", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"",
"files",
VALID_UUID,
"download",
]),
).toBe(false);
});
it("rejects empty token in public share path", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"public",
"shared",
"",
"files",
VALID_UUID,
"download",
]),
).toBe(false);
});
it("rejects empty file ID in public share path", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"public",
"shared",
VALID_UUID,
"files",
"",
"download",
]),
).toBe(false);
});
it("rejects empty file ID in workspace path", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
"",
"download",
]),
).toBe(false);
});
it("rejects UUID with null bytes injected", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
VALID_UUID + "\x00.jpg",
"download",
]),
).toBe(false);
});
it("rejects UUID with trailing garbage", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
VALID_UUID + "-extra",
"download",
]),
).toBe(false);
});
it("rejects UUID with leading garbage", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
"prefix-" + VALID_UUID,
"download",
]),
).toBe(false);
});
it("rejects truncated UUIDs", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
"550e8400-e29b-41d4",
"download",
]),
).toBe(false);
});
it("rejects UUID-length strings with wrong format", () => {
// Right length (36 chars) but missing hyphens
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
"550e8400e29b41d4a716446655440000xxxx",
"download",
]),
).toBe(false);
// Hyphens in wrong positions
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
"550e-8400e29b-41d4a716-44665544-0000",
"download",
]),
).toBe(false);
});
it("rejects UUID with non-hex characters", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
"550e8400-e29b-41d4-a716-44665544000g",
"download",
]),
).toBe(false);
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
"550e8400-e29b-41d4-a716-44665544000!",
"download",
]),
).toBe(false);
});
it("rejects SQL injection via ID segment", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
"'; DROP TABLE files;--",
"download",
]),
).toBe(false);
});
it("rejects padded segments with whitespace", () => {
expect(
isWorkspaceDownloadRequest([
"api",
" workspace",
"files",
VALID_UUID,
"download",
]),
).toBe(false);
expect(
isWorkspaceDownloadRequest([
"api",
"workspace ",
"files",
VALID_UUID,
"download",
]),
).toBe(false);
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
" " + VALID_UUID,
"download",
]),
).toBe(false);
});
it("rejects extra trailing segments after download", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
VALID_UUID,
"download",
"",
]),
).toBe(false);
expect(
isWorkspaceDownloadRequest([
"api",
"public",
"shared",
VALID_UUID,
"files",
VALID_UUID_2,
"download",
"extra",
]),
).toBe(false);
});
it("rejects extra leading segments before api", () => {
expect(
isWorkspaceDownloadRequest([
"prefix",
"api",
"workspace",
"files",
VALID_UUID,
"download",
]),
).toBe(false);
expect(
isWorkspaceDownloadRequest([
"",
"api",
"public",
"shared",
VALID_UUID,
"files",
VALID_UUID_2,
"download",
]),
).toBe(false);
});
it("rejects URL-encoded segment lookalikes", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"workspace%2Ffiles",
VALID_UUID,
"download",
]),
).toBe(false);
expect(
isWorkspaceDownloadRequest([
"api",
"public%2Fshared",
VALID_UUID,
"files",
VALID_UUID_2,
"download",
]),
).toBe(false);
});
it("rejects unicode homoglyph substitutions in fixed segments", () => {
// Cyrillic 'а' (U+0430) instead of Latin 'a'
expect(
isWorkspaceDownloadRequest([
"\u0430pi",
"workspace",
"files",
VALID_UUID,
"download",
]),
).toBe(false);
      // Fullwidth 'ａ' (U+FF41)
expect(
isWorkspaceDownloadRequest([
"\uff41pi",
"workspace",
"files",
VALID_UUID,
"download",
]),
).toBe(false);
});
it("rejects hybrid path mixing workspace and public patterns", () => {
// 5-segment but with public prefix
expect(
isWorkspaceDownloadRequest([
"api",
"public",
"shared",
VALID_UUID,
"download",
]),
).toBe(false);
// 7-segment but with workspace prefix
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
VALID_UUID,
"files",
VALID_UUID_2,
"download",
]),
).toBe(false);
});
it("rejects download appearing in non-terminal position", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"download",
"files",
VALID_UUID,
]),
).toBe(false);
expect(
isWorkspaceDownloadRequest([
"api",
"public",
"shared",
"download",
"files",
VALID_UUID,
"extra",
]),
).toBe(false);
});
it("rejects prototype pollution segment names as IDs", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
"__proto__",
"download",
]),
).toBe(false);
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
"constructor",
"download",
]),
).toBe(false);
});
it("rejects very long path segments (DoS vector)", () => {
const longId = "a".repeat(10000);
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
longId,
"download",
]),
).toBe(false);
});
it("rejects UUID with embedded path separators", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
"550e8400/e29b-41d4-a716-446655440000",
"download",
]),
).toBe(false);
});
it("rejects UUID-shaped strings with unicode hyphens", () => {
// EN DASH (U+2013) instead of HYPHEN-MINUS
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
"550e8400\u2013e29b\u201341d4\u2013a716\u2013446655440000",
"download",
]),
).toBe(false);
});
it("rejects SSRF-style payloads in ID position", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
"http://169.254.169.254",
"download",
]),
).toBe(false);
});
});
});
describe("isRedirectStatus", () => {

View File

@@ -1,11 +1,34 @@
const UUID_RE =
/^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i;
export function isWorkspaceDownloadRequest(path: string[]): boolean {
return (
path.length == 5 &&
// api/workspace/files/{id}/download
if (
path.length === 5 &&
path[0] === "api" &&
path[1] === "workspace" &&
path[2] === "files" &&
path[path.length - 1] === "download"
);
UUID_RE.test(path[3]) &&
path[4] === "download"
) {
return true;
}
// api/public/shared/{token}/files/{id}/download
if (
path.length === 7 &&
path[0] === "api" &&
path[1] === "public" &&
path[2] === "shared" &&
UUID_RE.test(path[3]) &&
path[4] === "files" &&
UUID_RE.test(path[5]) &&
path[6] === "download"
) {
return true;
}
return false;
}
export function isRedirectStatus(status: number): boolean {

View File

@@ -0,0 +1,54 @@
import { cleanup, render } from "@testing-library/react";
import { afterEach, describe, expect, it } from "vitest";
import { htmlRenderer } from "./HTMLRenderer";
describe("HTMLRenderer", () => {
afterEach(() => {
cleanup();
});
it("renders text/html content in a sandboxed iframe", () => {
const { container } = render(
<>
{htmlRenderer.render("<h1>Hi</h1>", {
mimeType: "text/html",
filename: "page.html",
})}
</>,
);
const iframe = container.querySelector("iframe");
expect(iframe).toBeTruthy();
expect(iframe?.getAttribute("sandbox")).toBe("allow-scripts");
});
it("injects the fragment-link interceptor into the srcDoc (regression)", () => {
const { container } = render(
<>
{htmlRenderer.render(
'<html><head></head><body><a href="#x">x</a><div id="x">x</div></body></html>',
{ mimeType: "text/html", filename: "page.html" },
)}
</>,
);
const srcdoc = container.querySelector("iframe")?.getAttribute("srcdoc");
expect(srcdoc).toBeTruthy();
expect(srcdoc).toContain("__fragmentLinkInterceptor");
expect(srcdoc).toContain('a[href^="#"]');
expect(srcdoc).toContain("scrollIntoView");
});
it("canRender recognises text/html mime type and .html/.htm filenames", () => {
expect(
htmlRenderer.canRender("<h1>Hi</h1>", { mimeType: "text/html" }),
).toBe(true);
expect(
htmlRenderer.canRender("<h1>Hi</h1>", { filename: "report.html" }),
).toBe(true);
expect(
htmlRenderer.canRender("<h1>Hi</h1>", { filename: "report.htm" }),
).toBe(true);
expect(
htmlRenderer.canRender("<h1>Hi</h1>", { mimeType: "text/plain" }),
).toBe(false);
});
});

View File

@@ -1,5 +1,6 @@
import React from "react";
import {
FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
TAILWIND_CDN_URL,
wrapWithHeadInjection,
} from "@/lib/iframe-sandbox-csp";
@@ -13,7 +14,10 @@ import {
function HTMLPreview({ value }: { value: string }) {
// Inject Tailwind CDN — no CSP (see iframe-sandbox-csp.ts for why)
const tailwindScript = `<script src="${TAILWIND_CDN_URL}"></script>`;
const srcDoc = wrapWithHeadInjection(value, tailwindScript);
const srcDoc = wrapWithHeadInjection(
value,
tailwindScript + FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
);
return (
<iframe
sandbox="allow-scripts"

View File

@@ -5,6 +5,7 @@ import {
isWorkspaceURI,
buildWorkspaceURI,
} from "@/lib/workspace-uri";
import { workspaceFileRenderer } from "./WorkspaceFileRenderer";
describe("parseWorkspaceURI", () => {
it("parses a full workspace URI with mime type", () => {
@@ -113,3 +114,26 @@ describe("buildWorkspaceURI", () => {
expect(parsed).toEqual({ fileID: "file-abc", mimeType: "text/plain" });
});
});
describe("workspaceFileRenderer.getDownloadContent", () => {
it("returns auth-proxied URL without share token", () => {
const result = workspaceFileRenderer.getDownloadContent(
"workspace://file-123#image/png",
);
expect(result).not.toBeNull();
expect(result!.data).toBe(
"/api/proxy/api/workspace/files/file-123/download",
);
});
it("returns public share URL when share token is in metadata", () => {
const result = workspaceFileRenderer.getDownloadContent(
"workspace://file-123#image/png",
{ shareToken: "abc-token-123" },
);
expect(result).not.toBeNull();
expect(result!.data).toBe(
"/api/proxy/api/public/shared/abc-token-123/files/file-123/download",
);
});
});

View File

@@ -37,7 +37,10 @@ const audioMimeTypes = [
"audio/flac",
];
function buildDownloadURL(fileID: string): string {
function buildDownloadURL(fileID: string, shareToken?: string): string {
if (shareToken) {
return `/api/proxy/api/public/shared/${shareToken}/files/${fileID}/download`;
}
return `/api/proxy/api/workspace/files/${fileID}/download`;
}
@@ -124,7 +127,7 @@ function renderWorkspaceFile(
const uri = parseWorkspaceURI(String(value));
if (!uri) return null;
const downloadURL = buildDownloadURL(uri.fileID);
const downloadURL = buildDownloadURL(uri.fileID, metadata?.shareToken);
const mimeType = uri.mimeType || metadata?.mimeType || null;
if (mimeType && imageMimeTypes.includes(mimeType)) {
@@ -174,7 +177,7 @@ function getCopyContentWorkspaceFile(
const uri = parseWorkspaceURI(String(value));
if (!uri) return null;
const downloadURL = buildDownloadURL(uri.fileID);
const downloadURL = buildDownloadURL(uri.fileID, metadata?.shareToken);
const mimeType =
uri.mimeType || metadata?.mimeType || "application/octet-stream";
@@ -205,7 +208,7 @@ function getDownloadContentWorkspaceFile(
const filename = metadata?.filename || `file.${ext}`;
return {
data: buildDownloadURL(uri.fileID),
data: buildDownloadURL(uri.fileID, metadata?.shareToken),
filename,
mimeType,
};

View File

@@ -1,74 +1,2 @@
import { OutputRenderer, OutputMetadata } from "../types";
export interface DownloadItem {
value: any;
metadata?: OutputMetadata;
renderer: OutputRenderer;
}
export async function downloadOutputs(items: DownloadItem[]) {
const concatenableTexts: string[] = [];
const nonConcatenableDownloads: Array<{ blob: Blob; filename: string }> = [];
for (const item of items) {
if (item.renderer.isConcatenable(item.value, item.metadata)) {
const copyContent = item.renderer.getCopyContent(
item.value,
item.metadata,
);
if (copyContent) {
// Extract text from CopyContent
let text: string;
if (typeof copyContent.data === "string") {
text = copyContent.data;
} else if (copyContent.fallbackText) {
text = copyContent.fallbackText;
} else {
continue;
}
concatenableTexts.push(text);
}
} else {
const downloadContent = item.renderer.getDownloadContent(
item.value,
item.metadata,
);
if (downloadContent) {
if (typeof downloadContent.data === "string") {
if (downloadContent.data.startsWith("http")) {
const link = document.createElement("a");
link.href = downloadContent.data;
link.download = downloadContent.filename;
link.click();
}
} else {
nonConcatenableDownloads.push({
blob: downloadContent.data as Blob,
filename: downloadContent.filename,
});
}
}
}
}
if (concatenableTexts.length > 0) {
const combinedText = concatenableTexts.join("\n\n---\n\n");
const blob = new Blob([combinedText], { type: "text/plain" });
downloadBlob(blob, "combined_output.txt");
}
for (const download of nonConcatenableDownloads) {
downloadBlob(download.blob, download.filename);
}
}
function downloadBlob(blob: Blob, filename: string) {
const url = URL.createObjectURL(blob);
const link = document.createElement("a");
link.href = url;
link.download = filename;
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
URL.revokeObjectURL(url);
}
export { downloadOutputs } from "@/lib/utils/download-outputs";
export type { DownloadItem } from "@/lib/utils/download-outputs";

View File

@@ -1,5 +1,9 @@
import { describe, expect, it } from "vitest";
import { TAILWIND_CDN_URL, wrapWithHeadInjection } from "../iframe-sandbox-csp";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import {
FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
TAILWIND_CDN_URL,
wrapWithHeadInjection,
} from "../iframe-sandbox-csp";
describe("wrapWithHeadInjection", () => {
const injection = '<script src="https://example.com/lib.js"></script>';
@@ -45,6 +49,142 @@ describe("TAILWIND_CDN_URL", () => {
});
});
describe("FRAGMENT_LINK_INTERCEPTOR_SCRIPT", () => {
// Evaluate the script body (without <script> tags) against the current
// document. Because sandboxed srcdoc iframes run their scripts in isolation
// anyway, the behavior we care about is just "this code, when executed in
// a document, intercepts #anchor clicks and calls scrollIntoView".
//
// Parse the exported <script> via the DOM rather than regex — CodeQL flags
// regex-based HTML stripping, and the test already runs in a DOM env.
function installInterceptor() {
const template = document.createElement("template");
template.innerHTML = FRAGMENT_LINK_INTERCEPTOR_SCRIPT;
const script = template.content.querySelector("script");
if (!script) throw new Error("Interceptor script tag not found");
new Function(script.textContent ?? "")();
}
let cleanup: (() => void) | null = null;
beforeEach(() => {
document.body.innerHTML = "";
});
afterEach(() => {
if (cleanup) cleanup();
cleanup = null;
document.body.innerHTML = "";
const doc = document as Document & {
__fragmentLinkInterceptor?: EventListener;
};
if (doc.__fragmentLinkInterceptor) {
document.removeEventListener("click", doc.__fragmentLinkInterceptor);
delete doc.__fragmentLinkInterceptor;
}
});
it("exports a <script> tag wrapping the interceptor", () => {
expect(FRAGMENT_LINK_INTERCEPTOR_SCRIPT.startsWith("<script>")).toBe(true);
expect(FRAGMENT_LINK_INTERCEPTOR_SCRIPT.endsWith("</script>")).toBe(true);
expect(FRAGMENT_LINK_INTERCEPTOR_SCRIPT).toContain("addEventListener");
expect(FRAGMENT_LINK_INTERCEPTOR_SCRIPT).toContain("scrollIntoView");
expect(FRAGMENT_LINK_INTERCEPTOR_SCRIPT).toContain('a[href^="#"]');
});
// Install the interceptor first, then a tail listener that records
// defaultPrevented. Listeners fire in registration order, so the tail
// sees the post-interceptor state.
function installWithObserver() {
installInterceptor();
const observed = { defaulted: false };
const listener = (e: Event) => {
observed.defaulted = e.defaultPrevented;
};
document.addEventListener("click", listener);
cleanup = () => document.removeEventListener("click", listener);
return observed;
}
it("intercepts fragment-link clicks, calls preventDefault, and scrolls the target into view", () => {
document.body.innerHTML = `
<nav><a id="nav-link" href="#activation">Activation</a></nav>
<section id="activation">Target</section>
`;
const scrollSpy = vi.fn();
document.getElementById("activation")!.scrollIntoView = scrollSpy;
const observed = installWithObserver();
document.getElementById("nav-link")!.click();
expect(scrollSpy).toHaveBeenCalledTimes(1);
expect(observed.defaulted).toBe(true);
});
it("does not intercept bare '#' links (no target id)", () => {
document.body.innerHTML = `<a id="top" href="#">Back to top</a>`;
const observed = installWithObserver();
document.getElementById("top")!.click();
expect(observed.defaulted).toBe(false);
});
it("does not intercept links with no matching target in the document", () => {
document.body.innerHTML = `<a id="dangle" href="#missing">Nowhere</a>`;
const observed = installWithObserver();
document.getElementById("dangle")!.click();
expect(observed.defaulted).toBe(false);
});
it("does not intercept non-fragment links", () => {
document.body.innerHTML = `<a id="ext" href="https://example.com/x">Ext</a>`;
installInterceptor();
const observed = { defaulted: false };
const listener = (e: Event) => {
observed.defaulted = e.defaultPrevented;
e.preventDefault();
};
document.addEventListener("click", listener);
cleanup = () => document.removeEventListener("click", listener);
document.getElementById("ext")!.click();
expect(observed.defaulted).toBe(false);
});
it("scrolls to target when click originates from a nested child of the anchor", () => {
document.body.innerHTML = `
<a id="outer" href="#costs"><span id="inner">💰 Costs</span></a>
<section id="costs">Target</section>
`;
const scrollSpy = vi.fn();
document.getElementById("costs")!.scrollIntoView = scrollSpy;
installInterceptor();
document.getElementById("inner")!.click();
expect(scrollSpy).toHaveBeenCalledTimes(1);
});
it("handles percent-encoded ids", () => {
document.body.innerHTML = `
<a id="enc" href="#top%20costs">Jump</a>
<section id="top costs">Target</section>
`;
const scrollSpy = vi.fn();
document.getElementById("top costs")!.scrollIntoView = scrollSpy;
installInterceptor();
document.getElementById("enc")!.click();
expect(scrollSpy).toHaveBeenCalledTimes(1);
});
});
describe("no CSP is exported", () => {
it("does not export ARTIFACT_IFRAME_CSP", async () => {
const mod = await import("../iframe-sandbox-csp");

View File

@@ -32,6 +32,38 @@
// changes (SRI is not possible because the JIT runtime is generated on demand).
export const TAILWIND_CDN_URL = "https://cdn.tailwindcss.com/3.4.16";
// Sandboxed srcdoc iframes without `allow-same-origin` resolve `href="#id"` links
// against the parent's URL as base. The default click then either navigates the
// iframe to `<parent-url>#id` (reloading our app inside the iframe) or updates
// the parent window's hash — both of which break the artifact preview.
//
// This script stays inside the iframe document and handles in-page anchor
// navigation locally by scrolling to the element with the matching id.
export const FRAGMENT_LINK_INTERCEPTOR_SCRIPT = `<script>
(function() {
if (document.__fragmentLinkInterceptor) return;
function handler(e) {
var t = e.target;
if (!t || typeof t.closest !== 'function') return;
var a = t.closest('a[href^="#"]');
if (!a) return;
var href = a.getAttribute('href');
if (!href || href === '#') return;
var id;
try { id = decodeURIComponent(href.slice(1)); } catch (_) { id = href.slice(1); }
if (!id) return;
var target = document.getElementById(id);
if (!target) return;
e.preventDefault();
if (typeof target.scrollIntoView === 'function') {
target.scrollIntoView({ behavior: 'smooth', block: 'start' });
}
}
document.__fragmentLinkInterceptor = handler;
document.addEventListener('click', handler);
})();
</script>`;
/**
* Inject content into the <head> of an HTML document string.
* If the content has no <head> tag, wraps it in a full document skeleton.

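The doc comment above pins down the contract of wrapWithHeadInjection. As a rough illustration, here is a minimal sketch that satisfies that contract; the shipped implementation may differ in details such as how the opening tag is matched.

// Hypothetical sketch of the contract described above — not the shipped code.
function wrapWithHeadInjectionSketch(html: string, headContent: string): string {
  const headOpen = html.match(/<head[^>]*>/i);
  if (headOpen && headOpen.index !== undefined) {
    // Inject immediately after the opening <head> tag.
    const insertAt = headOpen.index + headOpen[0].length;
    return html.slice(0, insertAt) + headContent + html.slice(insertAt);
  }
  // No <head>: wrap the content in a full document skeleton.
  return `<!DOCTYPE html><html><head>${headContent}</head><body>${html}</body></html>`;
}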
View File

@@ -0,0 +1,423 @@
import { describe, expect, it, vi, beforeEach } from "vitest";
import {
sanitizeFilename,
getUniqueFilename,
downloadOutputs,
} from "../download-outputs";
import type { DownloadItem } from "../download-outputs";
describe("sanitizeFilename", () => {
it("strips forward slashes", () => {
expect(sanitizeFilename("path/to/file.txt")).toBe("path_to_file.txt");
});
it("strips backslashes", () => {
expect(sanitizeFilename("path\\to\\file.txt")).toBe("path_to_file.txt");
});
it("replaces parent directory traversal", () => {
const result = sanitizeFilename("../../etc/passwd");
expect(result).not.toContain("/");
expect(result).not.toContain("\\");
expect(result).not.toContain("..");
expect(result).not.toMatch(/^\./);
});
it("strips leading dots", () => {
expect(sanitizeFilename(".gitignore")).toBe("gitignore");
expect(sanitizeFilename("..hidden")).toBe("hidden");
expect(sanitizeFilename("...triple")).toBe("triple");
});
it("returns 'file' for empty results", () => {
expect(sanitizeFilename("")).toBe("file");
expect(sanitizeFilename("...")).toBe("file");
expect(sanitizeFilename(".")).toBe("file");
});
it("leaves safe filenames unchanged", () => {
expect(sanitizeFilename("report.pdf")).toBe("report.pdf");
expect(sanitizeFilename("image_001.png")).toBe("image_001.png");
});
});
describe("getUniqueFilename", () => {
it("returns the filename when not already used", () => {
const used = new Set<string>();
expect(getUniqueFilename("file.txt", used)).toBe("file.txt");
expect(used.has("file.txt")).toBe(true);
});
it("appends a counter on collision", () => {
const used = new Set<string>(["file.txt"]);
expect(getUniqueFilename("file.txt", used)).toBe("file_1.txt");
expect(used.has("file_1.txt")).toBe(true);
});
it("increments counter until unique", () => {
const used = new Set<string>(["file.txt", "file_1.txt", "file_2.txt"]);
expect(getUniqueFilename("file.txt", used)).toBe("file_3.txt");
});
it("handles filenames without extensions", () => {
const used = new Set<string>(["README"]);
expect(getUniqueFilename("README", used)).toBe("README_1");
});
it("sanitizes the filename before deduplication", () => {
const used = new Set<string>();
expect(getUniqueFilename("../evil.txt", used)).toBe("_evil.txt");
});
it("handles dotfiles by stripping leading dots first", () => {
const used = new Set<string>();
expect(getUniqueFilename(".gitignore", used)).toBe("gitignore");
});
});
const mockZipFile = vi.fn();
const mockGenerateAsync = vi.fn();
let mockZipFiles: Record<string, { async: () => Promise<Blob> }> = {};
vi.mock("jszip", () => ({
default: class MockJSZip {
files = mockZipFiles;
file = (...args: unknown[]) => {
if (typeof args[0] === "string" && args[1] !== undefined) {
const content = args[1];
mockZipFiles[args[0] as string] = {
async: () =>
Promise.resolve(
content instanceof Blob ? content : new Blob([String(content)]),
),
};
}
mockZipFile(...args);
};
generateAsync = mockGenerateAsync;
},
}));
function makeRenderer(overrides: {
isConcatenable?: boolean;
copyData?: string;
downloadData?: Blob | string;
downloadFilename?: string;
}) {
return {
value: "test",
metadata: undefined,
renderer: {
name: "test",
priority: 1,
canRender: () => true,
render: () => null,
isConcatenable: () => overrides.isConcatenable ?? false,
getCopyContent: () =>
overrides.copyData
? { mimeType: "text/plain", data: overrides.copyData }
: null,
getDownloadContent: () =>
overrides.downloadData
? {
data: overrides.downloadData,
filename: overrides.downloadFilename ?? "file.bin",
mimeType: "application/octet-stream",
}
: null,
},
} satisfies DownloadItem;
}
describe("downloadOutputs", () => {
beforeEach(() => {
vi.clearAllMocks();
mockZipFiles = {};
mockGenerateAsync.mockResolvedValue(new Blob(["zip-content"]));
vi.stubGlobal(
"URL",
Object.assign(URL, {
createObjectURL: vi.fn(() => "blob:mock-url"),
revokeObjectURL: vi.fn(),
}),
);
});
it("creates a zip with concatenable text outputs", async () => {
const items = [
makeRenderer({ isConcatenable: true, copyData: "Hello" }),
makeRenderer({ isConcatenable: true, copyData: "World" }),
];
await downloadOutputs(items);
expect(mockZipFile).toHaveBeenCalledWith(
"combined_output.txt",
"Hello\n\n---\n\nWorld",
);
// Single file in zip → downloaded directly, no zip generation
expect(mockGenerateAsync).not.toHaveBeenCalled();
});
it("includes direct blob data in the zip", async () => {
const blob = new Blob(["binary data"]);
const items = [
makeRenderer({ downloadData: blob, downloadFilename: "image.png" }),
];
await downloadOutputs(items);
expect(mockZipFile).toHaveBeenCalledWith("image.png", blob);
});
it("skips blobs exceeding size limit", async () => {
const consoleSpy = vi.spyOn(console, "warn").mockImplementation(() => {});
const bigBlob = new Blob(["x".repeat(100)]);
Object.defineProperty(bigBlob, "size", { value: 60 * 1024 * 1024 });
const items = [
makeRenderer({ downloadData: bigBlob, downloadFilename: "huge.bin" }),
];
await downloadOutputs(items);
expect(consoleSpy).toHaveBeenCalledWith(
expect.stringContaining("blob too large"),
);
expect(mockZipFile).not.toHaveBeenCalledWith("huge.bin", expect.anything());
consoleSpy.mockRestore();
});
it("fetches http URLs and adds to zip", async () => {
const mockBlob = new Blob(["fetched"]);
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
headers: new Headers({ "content-length": "7" }),
blob: () => Promise.resolve(mockBlob),
}),
);
const items = [
makeRenderer({
downloadData: "https://example.com/file.pdf",
downloadFilename: "report.pdf",
}),
];
await downloadOutputs(items);
expect(fetch).toHaveBeenCalledWith("https://example.com/file.pdf", {
mode: "cors",
});
expect(mockZipFile).toHaveBeenCalledWith("report.pdf", mockBlob);
});
it("handles fetch failures gracefully and records unfetchable URLs", async () => {
const consoleSpy = vi.spyOn(console, "warn").mockImplementation(() => {});
vi.stubGlobal("fetch", vi.fn().mockRejectedValue(new Error("CORS error")));
const items = [
makeRenderer({ isConcatenable: true, copyData: "some text" }),
makeRenderer({
downloadData: "https://cors-blocked.com/file.bin",
downloadFilename: "blocked.bin",
}),
];
await downloadOutputs(items);
expect(mockZipFile).toHaveBeenCalledWith(
"unfetched_files.txt",
expect.stringContaining("cors-blocked.com"),
);
consoleSpy.mockRestore();
});
it("handles malformed data URLs with try-catch", async () => {
const consoleSpy = vi.spyOn(console, "warn").mockImplementation(() => {});
vi.stubGlobal(
"fetch",
vi.fn().mockRejectedValue(new Error("Invalid data URL")),
);
const items = [
makeRenderer({
downloadData: "data:invalid",
downloadFilename: "broken.bin",
}),
];
await downloadOutputs(items);
expect(consoleSpy).toHaveBeenCalledWith(
expect.stringContaining("malformed or unsupported format"),
);
consoleSpy.mockRestore();
});
it("skips unsupported URL formats with a warning", async () => {
const consoleSpy = vi.spyOn(console, "warn").mockImplementation(() => {});
const items = [
makeRenderer({
downloadData: "ftp://server/file.dat",
downloadFilename: "file.dat",
}),
];
await downloadOutputs(items);
expect(consoleSpy).toHaveBeenCalledWith(
expect.stringContaining("unsupported URL format"),
);
consoleSpy.mockRestore();
});
it("does nothing when items array is empty", async () => {
await downloadOutputs([]);
expect(mockZipFile).not.toHaveBeenCalled();
expect(mockGenerateAsync).not.toHaveBeenCalled();
});
it("fetches relative URLs (workspace files) and adds to zip", async () => {
const mockBlob = new Blob(["image-data"]);
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
headers: new Headers({ "content-length": "10" }),
blob: () => Promise.resolve(mockBlob),
}),
);
const items = [
makeRenderer({
downloadData: "/api/proxy/api/workspace/files/abc-123/download",
downloadFilename: "photo.png",
}),
];
await downloadOutputs(items);
expect(fetch).toHaveBeenCalledWith(
"/api/proxy/api/workspace/files/abc-123/download",
{ mode: "cors" },
);
expect(mockZipFile).toHaveBeenCalledWith("photo.png", mockBlob);
});
it("includes workspace images that renderers return as relative URLs", async () => {
const mockBlob = new Blob(["img"]);
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
headers: new Headers({ "content-length": "3" }),
blob: () => Promise.resolve(mockBlob),
}),
);
const items = [
makeRenderer({
downloadData: "/api/proxy/api/workspace/files/file-1/download",
downloadFilename: "image1.png",
}),
makeRenderer({
downloadData: "/api/proxy/api/workspace/files/file-2/download",
downloadFilename: "image2.jpg",
}),
];
await downloadOutputs(items);
expect(mockZipFile).toHaveBeenCalledWith("image1.png", mockBlob);
expect(mockZipFile).toHaveBeenCalledWith("image2.jpg", mockBlob);
});
it("fetches public share endpoint URLs for workspace files", async () => {
const mockBlob = new Blob(["shared-image-data"]);
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
headers: new Headers({ "content-length": "17" }),
blob: () => Promise.resolve(mockBlob),
}),
);
const items = [
makeRenderer({
downloadData:
"/api/proxy/api/public/shared/abc-token/files/file-123/download",
downloadFilename: "shared-image.png",
}),
];
await downloadOutputs(items);
expect(fetch).toHaveBeenCalledWith(
"/api/proxy/api/public/shared/abc-token/files/file-123/download",
{ mode: "cors" },
);
expect(mockZipFile).toHaveBeenCalledWith("shared-image.png", mockBlob);
});
it("rejects files over content-length before buffering", async () => {
const consoleSpy = vi.spyOn(console, "warn").mockImplementation(() => {});
const blobFn = vi.fn();
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
headers: new Headers({
"content-length": String(60 * 1024 * 1024),
}),
blob: blobFn,
}),
);
const items = [
makeRenderer({
downloadData: "https://example.com/huge.zip",
downloadFilename: "huge.zip",
}),
];
await downloadOutputs(items);
expect(blobFn).not.toHaveBeenCalled();
expect(consoleSpy).toHaveBeenCalledWith(
expect.stringContaining("file too large"),
);
consoleSpy.mockRestore();
});
it("downloads single file directly without zip wrapping", async () => {
const blob = new Blob(["single file"]);
const items = [
makeRenderer({ downloadData: blob, downloadFilename: "photo.png" }),
];
await downloadOutputs(items);
expect(mockZipFile).toHaveBeenCalledWith("photo.png", blob);
expect(mockGenerateAsync).not.toHaveBeenCalled();
expect(URL.createObjectURL).toHaveBeenCalled();
});
it("uses zip when multiple files are present", async () => {
const blob1 = new Blob(["file1"]);
const blob2 = new Blob(["file2"]);
const items = [
makeRenderer({ downloadData: blob1, downloadFilename: "a.png" }),
makeRenderer({ downloadData: blob2, downloadFilename: "b.png" }),
];
await downloadOutputs(items);
expect(mockGenerateAsync).toHaveBeenCalledWith({ type: "blob" });
});
});

View File

@@ -0,0 +1,282 @@
import type {
OutputRenderer,
OutputMetadata,
} from "@/components/contextual/OutputRenderers/types";
export interface DownloadItem {
value: unknown;
metadata?: OutputMetadata;
renderer: OutputRenderer;
}
/** Maximum individual file size for zip inclusion (50 MB) */
const MAX_FILE_SIZE_BYTES = 50 * 1024 * 1024;
/** Maximum total zip content size before generation (200 MB) */
const MAX_TOTAL_SIZE_BYTES = 200 * 1024 * 1024;
/** Maximum concurrent file fetches */
const FETCH_CONCURRENCY = 5;
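
/**
 * Fetch a URL into a Blob, enforcing MAX_FILE_SIZE_BYTES twice: first from
 * the content-length header (to avoid buffering oversized bodies), then on
 * the buffered blob (in case the header was absent or wrong). Returns null
 * on HTTP errors, oversize, or network/CORS failure.
 */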
async function fetchFileAsBlob(url: string): Promise<Blob | null> {
try {
const response = await fetch(url, { mode: "cors" });
if (!response.ok) {
console.error(`Failed to fetch ${url}: ${response.status}`);
return null;
}
const contentLength = Number(response.headers.get("content-length") ?? "0");
if (contentLength > MAX_FILE_SIZE_BYTES) {
console.warn(
`Skipping ${url}: file too large (${(contentLength / 1024 / 1024).toFixed(1)} MB, limit ${MAX_FILE_SIZE_BYTES / 1024 / 1024} MB)`,
);
return null;
}
const blob = await response.blob();
if (blob.size > MAX_FILE_SIZE_BYTES) {
console.warn(
`Skipping ${url}: file too large (${(blob.size / 1024 / 1024).toFixed(1)} MB, limit ${MAX_FILE_SIZE_BYTES / 1024 / 1024} MB)`,
);
return null;
}
return blob;
} catch (_error) {
console.warn(
`Could not fetch ${url} (likely CORS). Adding as link reference.`,
);
return null;
}
}
/** Strip path traversal components and unsafe characters from a filename. */
export function sanitizeFilename(filename: string): string {
const sanitized = filename
.replace(/[/\\]/g, "_")
.replace(/^\.+/, "")
.replace(/\.\./g, "_");
return sanitized || "file";
}
export function getUniqueFilename(
filename: string,
usedNames: Set<string>,
): string {
const safe = sanitizeFilename(filename);
if (!usedNames.has(safe)) {
usedNames.add(safe);
return safe;
}
const dotIndex = safe.lastIndexOf(".");
const baseName = dotIndex > 0 ? safe.slice(0, dotIndex) : safe;
const extension = dotIndex > 0 ? safe.slice(dotIndex) : "";
let counter = 1;
let newName = `${baseName}_${counter}${extension}`;
while (usedNames.has(newName)) {
counter++;
newName = `${baseName}_${counter}${extension}`;
}
usedNames.add(newName);
return newName;
}
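
/**
 * Run tasks with bounded concurrency: spawn up to `concurrency` workers that
 * each pull the next task index from a shared counter, preserving result
 * order by index.
 */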
async function fetchInParallel<T>(
tasks: (() => Promise<T>)[],
concurrency: number,
): Promise<T[]> {
const results: T[] = [];
let index = 0;
async function worker() {
while (index < tasks.length) {
const i = index++;
results[i] = await tasks[i]();
}
}
await Promise.all(
Array.from({ length: Math.min(concurrency, tasks.length) }, () => worker()),
);
return results;
}
type FetchResult = {
blob: Blob | null;
filename: string;
sourceUrl: string | null;
};
export async function downloadOutputs(items: DownloadItem[]) {
if (items.length === 0) return;
const { default: JSZip } = await import("jszip");
const zip = new JSZip();
const usedFilenames = new Set<string>();
let hasFiles = false;
let totalSize = 0;
const concatenableTexts: string[] = [];
const unfetchableUrls: string[] = [];
const fileItems: Array<{
downloadContent: { data: unknown; filename: string };
}> = [];
for (const item of items) {
if (item.renderer.isConcatenable(item.value, item.metadata)) {
const copyContent = item.renderer.getCopyContent(
item.value,
item.metadata,
);
if (copyContent) {
let text: string;
if (typeof copyContent.data === "string") {
text = copyContent.data;
} else if (copyContent.fallbackText) {
text = copyContent.fallbackText;
} else {
continue;
}
concatenableTexts.push(text);
}
} else {
const downloadContent = item.renderer.getDownloadContent(
item.value,
item.metadata,
);
if (downloadContent) {
fileItems.push({ downloadContent });
}
}
}
const fetchTasks = fileItems.map(
({ downloadContent }) =>
async (): Promise<FetchResult> => {
let blob: Blob | null = null;
const filename = downloadContent.filename;
let sourceUrl: string | null = null;
if (typeof downloadContent.data === "string") {
if (
downloadContent.data.startsWith("http://") ||
downloadContent.data.startsWith("https://") ||
downloadContent.data.startsWith("/")
) {
sourceUrl = downloadContent.data;
blob = await fetchFileAsBlob(downloadContent.data);
} else if (downloadContent.data.startsWith("data:")) {
try {
const dataBlob = await fetch(downloadContent.data).then((r) =>
r.blob(),
);
if (dataBlob.size <= MAX_FILE_SIZE_BYTES) {
blob = dataBlob;
} else {
console.warn(
`Skipping data URL: too large (${(dataBlob.size / 1024 / 1024).toFixed(1)} MB)`,
);
}
} catch (_error) {
console.warn(
`Failed to process data URL for ${filename}: malformed or unsupported format`,
);
}
} else {
console.warn(
`Skipping unsupported URL format: ${downloadContent.data.slice(0, 50)}...`,
);
}
} else {
const rawBlob = downloadContent.data as Blob;
if (rawBlob.size <= MAX_FILE_SIZE_BYTES) {
blob = rawBlob;
} else {
console.warn(
`Skipping ${filename}: blob too large (${(rawBlob.size / 1024 / 1024).toFixed(1)} MB)`,
);
}
}
return { blob, filename, sourceUrl };
},
);
const results = await fetchInParallel(fetchTasks, FETCH_CONCURRENCY);
for (const { blob, filename, sourceUrl } of results) {
if (blob) {
if (totalSize + blob.size > MAX_TOTAL_SIZE_BYTES) {
console.warn(
`Skipping ${filename}: would exceed total zip size limit (${MAX_TOTAL_SIZE_BYTES / 1024 / 1024} MB)`,
);
if (sourceUrl) unfetchableUrls.push(sourceUrl);
continue;
}
const uniqueFilename = getUniqueFilename(filename, usedFilenames);
zip.file(uniqueFilename, blob);
totalSize += blob.size;
hasFiles = true;
} else if (sourceUrl) {
unfetchableUrls.push(sourceUrl);
}
}
if (concatenableTexts.length > 0) {
const combinedText = concatenableTexts.join("\n\n---\n\n");
const textSize = new Blob([combinedText]).size;
if (totalSize + textSize <= MAX_TOTAL_SIZE_BYTES) {
const filename = getUniqueFilename("combined_output.txt", usedFilenames);
zip.file(filename, combinedText);
totalSize += textSize;
hasFiles = true;
}
}
if (unfetchableUrls.length > 0) {
const linksContent = unfetchableUrls
.map((url, i) => `${i + 1}. ${url}`)
.join("\n");
const manifest = `The following files could not be included in the zip (CORS restriction or size limit).\nYou can download them directly from these URLs:\n\n${linksContent}\n`;
const manifestSize = new Blob([manifest]).size;
if (totalSize + manifestSize <= MAX_TOTAL_SIZE_BYTES) {
const manifestFilename = getUniqueFilename(
"unfetched_files.txt",
usedFilenames,
);
zip.file(manifestFilename, manifest);
totalSize += manifestSize;
hasFiles = true;
}
}
if (!hasFiles) return;
// Single-file shortcut: download directly instead of wrapping in a zip
if (
zip.files &&
Object.keys(zip.files).length === 1 &&
unfetchableUrls.length === 0
) {
const onlyFilename = Object.keys(zip.files)[0];
const entry = zip.files[onlyFilename];
const content = await entry.async("blob");
downloadBlob(content, onlyFilename);
return;
}
const zipBlob = await zip.generateAsync({ type: "blob" });
downloadBlob(zipBlob, "outputs.zip");
}
function downloadBlob(blob: Blob, filename: string) {
const url = URL.createObjectURL(blob);
const link = document.createElement("a");
link.href = url;
link.download = filename;
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
URL.revokeObjectURL(url);
}
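
A hedged usage sketch to close: the renderer below follows the OutputRenderer shape exercised by the tests above (the real interface may carry more members), and all names here are illustrative.

// Illustrative only — `textRenderer` is a hypothetical minimal renderer.
const textRenderer = {
  name: "text",
  priority: 1,
  canRender: () => true,
  render: () => null,
  isConcatenable: () => true,
  getCopyContent: (value: unknown) => ({
    mimeType: "text/plain",
    data: String(value),
  }),
  getDownloadContent: () => null,
} as unknown as OutputRenderer;

await downloadOutputs([
  { value: "First output", renderer: textRenderer },
  { value: "Second output", renderer: textRenderer },
]);
// Both texts are joined with "\n\n---\n\n" into combined_output.txt; since the
// zip then holds a single file, it is downloaded directly instead of zipped.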