mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into feat/copilot-kimi-k2-fast-model
This commit is contained in:
3
autogpt_platform/.gitignore
vendored
3
autogpt_platform/.gitignore
vendored
@@ -1,3 +1,6 @@
|
||||
*.ignore.*
|
||||
*.ign.*
|
||||
.application.logs
|
||||
|
||||
# Claude Code local settings only — the rest of .claude/ is shared (skills etc.)
|
||||
.claude/settings.local.json
|
||||
|
||||
@@ -179,6 +179,9 @@ MEM0_API_KEY=
|
||||
OPENWEATHERMAP_API_KEY=
|
||||
GOOGLE_MAPS_API_KEY=
|
||||
|
||||
# Platform Bot Linking
|
||||
PLATFORM_LINK_BASE_URL=http://localhost:3000/link
|
||||
|
||||
# Communication Services
|
||||
DISCORD_BOT_TOKEN=
|
||||
MEDIUM_API_KEY=
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Platform bot linking — user-facing REST routes."""
|
||||
@@ -0,0 +1,158 @@
|
||||
"""User-facing platform_linking REST routes (JWT auth)."""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, HTTPException, Path, Security
|
||||
|
||||
from backend.data.db_accessors import platform_linking_db
|
||||
from backend.platform_linking.models import (
|
||||
ConfirmLinkResponse,
|
||||
ConfirmUserLinkResponse,
|
||||
DeleteLinkResponse,
|
||||
LinkTokenInfoResponse,
|
||||
PlatformLinkInfo,
|
||||
PlatformUserLinkInfo,
|
||||
)
|
||||
from backend.util.exceptions import (
|
||||
LinkAlreadyExistsError,
|
||||
LinkFlowMismatchError,
|
||||
LinkTokenExpiredError,
|
||||
NotAuthorizedError,
|
||||
NotFoundError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
TokenPath = Annotated[
|
||||
str,
|
||||
Path(max_length=64, pattern=r"^[A-Za-z0-9_-]+$"),
|
||||
]
|
||||
|
||||
|
||||
def _translate(exc: Exception) -> HTTPException:
|
||||
if isinstance(exc, NotFoundError):
|
||||
return HTTPException(status_code=404, detail=str(exc))
|
||||
if isinstance(exc, NotAuthorizedError):
|
||||
return HTTPException(status_code=403, detail=str(exc))
|
||||
if isinstance(exc, LinkAlreadyExistsError):
|
||||
return HTTPException(status_code=409, detail=str(exc))
|
||||
if isinstance(exc, LinkTokenExpiredError):
|
||||
return HTTPException(status_code=410, detail=str(exc))
|
||||
if isinstance(exc, LinkFlowMismatchError):
|
||||
return HTTPException(status_code=400, detail=str(exc))
|
||||
return HTTPException(status_code=500, detail="Internal error.")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tokens/{token}/info",
|
||||
response_model=LinkTokenInfoResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Get display info for a link token",
|
||||
)
|
||||
async def get_link_token_info_route(token: TokenPath) -> LinkTokenInfoResponse:
|
||||
try:
|
||||
return await platform_linking_db().get_link_token_info(token)
|
||||
except (NotFoundError, LinkTokenExpiredError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/tokens/{token}/confirm",
|
||||
response_model=ConfirmLinkResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Confirm a SERVER link token (user must be authenticated)",
|
||||
)
|
||||
async def confirm_link_token(
|
||||
token: TokenPath,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> ConfirmLinkResponse:
|
||||
try:
|
||||
return await platform_linking_db().confirm_server_link(token, user_id)
|
||||
except (
|
||||
NotFoundError,
|
||||
LinkFlowMismatchError,
|
||||
LinkTokenExpiredError,
|
||||
LinkAlreadyExistsError,
|
||||
) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/user-tokens/{token}/confirm",
|
||||
response_model=ConfirmUserLinkResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Confirm a USER link token (user must be authenticated)",
|
||||
)
|
||||
async def confirm_user_link_token(
|
||||
token: TokenPath,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> ConfirmUserLinkResponse:
|
||||
try:
|
||||
return await platform_linking_db().confirm_user_link(token, user_id)
|
||||
except (
|
||||
NotFoundError,
|
||||
LinkFlowMismatchError,
|
||||
LinkTokenExpiredError,
|
||||
LinkAlreadyExistsError,
|
||||
) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/links",
|
||||
response_model=list[PlatformLinkInfo],
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="List all platform servers linked to the authenticated user",
|
||||
)
|
||||
async def list_my_links(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> list[PlatformLinkInfo]:
|
||||
return await platform_linking_db().list_server_links(user_id)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/user-links",
|
||||
response_model=list[PlatformUserLinkInfo],
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="List all DM links for the authenticated user",
|
||||
)
|
||||
async def list_my_user_links(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> list[PlatformUserLinkInfo]:
|
||||
return await platform_linking_db().list_user_links(user_id)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/links/{link_id}",
|
||||
response_model=DeleteLinkResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Unlink a platform server",
|
||||
)
|
||||
async def delete_link(
|
||||
link_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> DeleteLinkResponse:
|
||||
try:
|
||||
return await platform_linking_db().delete_server_link(link_id, user_id)
|
||||
except (NotFoundError, NotAuthorizedError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/user-links/{link_id}",
|
||||
response_model=DeleteLinkResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Unlink a DM / user link",
|
||||
)
|
||||
async def delete_user_link_route(
|
||||
link_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> DeleteLinkResponse:
|
||||
try:
|
||||
return await platform_linking_db().delete_user_link(link_id, user_id)
|
||||
except (NotFoundError, NotAuthorizedError) as exc:
|
||||
raise _translate(exc) from exc
|
||||
@@ -0,0 +1,264 @@
|
||||
"""Route tests: domain exceptions → HTTPException status codes."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from backend.util.exceptions import (
|
||||
LinkAlreadyExistsError,
|
||||
LinkFlowMismatchError,
|
||||
LinkTokenExpiredError,
|
||||
NotAuthorizedError,
|
||||
NotFoundError,
|
||||
)
|
||||
|
||||
|
||||
def _db_mock(**method_configs):
|
||||
"""Return a mock of the accessor's return value with the given AsyncMocks."""
|
||||
db = MagicMock()
|
||||
for name, mock in method_configs.items():
|
||||
setattr(db, name, mock)
|
||||
return db
|
||||
|
||||
|
||||
class TestTokenInfoRouteTranslation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found_maps_to_404(self):
|
||||
from backend.api.features.platform_linking.routes import (
|
||||
get_link_token_info_route,
|
||||
)
|
||||
|
||||
db = _db_mock(
|
||||
get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await get_link_token_info_route(token="abc")
|
||||
assert exc.value.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_maps_to_410(self):
|
||||
from backend.api.features.platform_linking.routes import (
|
||||
get_link_token_info_route,
|
||||
)
|
||||
|
||||
db = _db_mock(
|
||||
get_link_token_info=AsyncMock(side_effect=LinkTokenExpiredError("expired"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await get_link_token_info_route(token="abc")
|
||||
assert exc.value.status_code == 410
|
||||
|
||||
|
||||
class TestConfirmLinkRouteTranslation:
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"exc,expected_status",
|
||||
[
|
||||
(NotFoundError("missing"), 404),
|
||||
(LinkFlowMismatchError("wrong flow"), 400),
|
||||
(LinkTokenExpiredError("expired"), 410),
|
||||
(LinkAlreadyExistsError("already"), 409),
|
||||
],
|
||||
)
|
||||
async def test_translation(self, exc: Exception, expected_status: int):
|
||||
from backend.api.features.platform_linking.routes import confirm_link_token
|
||||
|
||||
db = _db_mock(confirm_server_link=AsyncMock(side_effect=exc))
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as ctx:
|
||||
await confirm_link_token(token="abc", user_id="u1")
|
||||
assert ctx.value.status_code == expected_status
|
||||
|
||||
|
||||
class TestConfirmUserLinkRouteTranslation:
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"exc,expected_status",
|
||||
[
|
||||
(NotFoundError("missing"), 404),
|
||||
(LinkFlowMismatchError("wrong flow"), 400),
|
||||
(LinkTokenExpiredError("expired"), 410),
|
||||
(LinkAlreadyExistsError("already"), 409),
|
||||
],
|
||||
)
|
||||
async def test_translation(self, exc: Exception, expected_status: int):
|
||||
from backend.api.features.platform_linking.routes import confirm_user_link_token
|
||||
|
||||
db = _db_mock(confirm_user_link=AsyncMock(side_effect=exc))
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as ctx:
|
||||
await confirm_user_link_token(token="abc", user_id="u1")
|
||||
assert ctx.value.status_code == expected_status
|
||||
|
||||
|
||||
class TestDeleteLinkRouteTranslation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found_maps_to_404(self):
|
||||
from backend.api.features.platform_linking.routes import delete_link
|
||||
|
||||
db = _db_mock(
|
||||
delete_server_link=AsyncMock(side_effect=NotFoundError("missing"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await delete_link(link_id="x", user_id="u1")
|
||||
assert exc.value.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_owned_maps_to_403(self):
|
||||
from backend.api.features.platform_linking.routes import delete_link
|
||||
|
||||
db = _db_mock(
|
||||
delete_server_link=AsyncMock(side_effect=NotAuthorizedError("nope"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await delete_link(link_id="x", user_id="u1")
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
class TestDeleteUserLinkRouteTranslation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found_maps_to_404(self):
|
||||
from backend.api.features.platform_linking.routes import delete_user_link_route
|
||||
|
||||
db = _db_mock(delete_user_link=AsyncMock(side_effect=NotFoundError("missing")))
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await delete_user_link_route(link_id="x", user_id="u1")
|
||||
assert exc.value.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_owned_maps_to_403(self):
|
||||
from backend.api.features.platform_linking.routes import delete_user_link_route
|
||||
|
||||
db = _db_mock(
|
||||
delete_user_link=AsyncMock(side_effect=NotAuthorizedError("nope"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await delete_user_link_route(link_id="x", user_id="u1")
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
# ── Adversarial: malformed token path params ──────────────────────────
|
||||
|
||||
|
||||
class TestAdversarialTokenPath:
|
||||
# TokenPath enforces `^[A-Za-z0-9_-]+$` + max_length=64.
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
import fastapi
|
||||
from autogpt_libs.auth import get_user_id, requires_user
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import backend.api.features.platform_linking.routes as routes_mod
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.dependency_overrides[requires_user] = lambda: None
|
||||
app.dependency_overrides[get_user_id] = lambda: "caller-user"
|
||||
app.include_router(routes_mod.router, prefix="/api/platform-linking")
|
||||
return TestClient(app)
|
||||
|
||||
def test_rejects_token_with_special_chars(self, client):
|
||||
response = client.get("/api/platform-linking/tokens/bad%24token/info")
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_rejects_token_with_path_traversal(self, client):
|
||||
for probe in ("..%2F..", "foo..bar", "foo%2Fbar"):
|
||||
response = client.get(f"/api/platform-linking/tokens/{probe}/info")
|
||||
assert response.status_code in (
|
||||
404,
|
||||
422,
|
||||
), f"path-traversal probe {probe!r} returned {response.status_code}"
|
||||
|
||||
def test_rejects_token_too_long(self, client):
|
||||
long_token = "a" * 65
|
||||
response = client.get(f"/api/platform-linking/tokens/{long_token}/info")
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_accepts_token_at_max_length(self, client):
|
||||
token = "a" * 64
|
||||
db = _db_mock(
|
||||
get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
response = client.get(f"/api/platform-linking/tokens/{token}/info")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_accepts_urlsafe_b64_token_shape(self, client):
|
||||
db = _db_mock(
|
||||
get_link_token_info=AsyncMock(side_effect=NotFoundError("missing"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
response = client.get("/api/platform-linking/tokens/abc-_XYZ123-_abc/info")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_confirm_rejects_malformed_token(self, client):
|
||||
response = client.post("/api/platform-linking/tokens/bad%24token/confirm")
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
class TestAdversarialDeleteLinkId:
|
||||
"""DELETE link_id has no regex — ensure weird values are handled via
|
||||
NotFoundError (no crash, no cross-user leak)."""
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
import fastapi
|
||||
from autogpt_libs.auth import get_user_id, requires_user
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import backend.api.features.platform_linking.routes as routes_mod
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.dependency_overrides[requires_user] = lambda: None
|
||||
app.dependency_overrides[get_user_id] = lambda: "caller-user"
|
||||
app.include_router(routes_mod.router, prefix="/api/platform-linking")
|
||||
return TestClient(app)
|
||||
|
||||
def test_weird_link_id_returns_404(self, client):
|
||||
db = _db_mock(
|
||||
delete_server_link=AsyncMock(side_effect=NotFoundError("missing"))
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.platform_linking.routes.platform_linking_db",
|
||||
return_value=db,
|
||||
):
|
||||
for link_id in ("'; DROP TABLE links;--", "../../etc/passwd", ""):
|
||||
response = client.delete(f"/api/platform-linking/links/{link_id}")
|
||||
assert response.status_code in (404, 405)
|
||||
@@ -30,6 +30,7 @@ from pydantic import BaseModel, Field
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
from typing_extensions import Optional, TypedDict
|
||||
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
from backend.api.model import (
|
||||
CreateAPIKeyRequest,
|
||||
CreateAPIKeyResponse,
|
||||
@@ -96,6 +97,7 @@ from backend.data.user import (
|
||||
update_user_notification_preference,
|
||||
update_user_timezone,
|
||||
)
|
||||
from backend.data.workspace import get_workspace_file_by_id
|
||||
from backend.executor import scheduler
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
||||
@@ -1703,6 +1705,10 @@ async def enable_execution_sharing(
|
||||
# Generate a unique share token
|
||||
share_token = str(uuid.uuid4())
|
||||
|
||||
# Remove stale allowlist records before updating the token — prevents a
|
||||
# window where old records + new token could coexist.
|
||||
await execution_db.delete_shared_execution_files(execution_id=graph_exec_id)
|
||||
|
||||
# Update the execution with share info
|
||||
await execution_db.update_graph_execution_share_status(
|
||||
execution_id=graph_exec_id,
|
||||
@@ -1712,6 +1718,14 @@ async def enable_execution_sharing(
|
||||
shared_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Create allowlist of workspace files referenced in outputs
|
||||
await execution_db.create_shared_execution_files(
|
||||
execution_id=graph_exec_id,
|
||||
share_token=share_token,
|
||||
user_id=user_id,
|
||||
outputs=execution.outputs,
|
||||
)
|
||||
|
||||
# Return the share URL
|
||||
frontend_url = settings.config.frontend_base_url or "http://localhost:3000"
|
||||
share_url = f"{frontend_url}/share/{share_token}"
|
||||
@@ -1737,6 +1751,9 @@ async def disable_execution_sharing(
|
||||
if not execution:
|
||||
raise HTTPException(status_code=404, detail="Execution not found")
|
||||
|
||||
# Remove shared file allowlist records
|
||||
await execution_db.delete_shared_execution_files(execution_id=graph_exec_id)
|
||||
|
||||
# Remove share info
|
||||
await execution_db.update_graph_execution_share_status(
|
||||
execution_id=graph_exec_id,
|
||||
@@ -1762,6 +1779,43 @@ async def get_shared_execution(
|
||||
return execution
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/public/shared/{share_token}/files/{file_id}/download",
|
||||
summary="Download a file from a shared execution",
|
||||
operation_id="download_shared_file",
|
||||
tags=["graphs"],
|
||||
)
|
||||
async def download_shared_file(
|
||||
share_token: Annotated[
|
||||
str,
|
||||
Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"),
|
||||
],
|
||||
file_id: Annotated[
|
||||
str,
|
||||
Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"),
|
||||
],
|
||||
) -> Response:
|
||||
"""Download a workspace file from a shared execution (no auth required).
|
||||
|
||||
Validates that the file was explicitly exposed when sharing was enabled.
|
||||
Returns a uniform 404 for all failure modes to prevent enumeration attacks.
|
||||
"""
|
||||
# Single-query validation against the allowlist
|
||||
execution_id = await execution_db.get_shared_execution_file(
|
||||
share_token=share_token, file_id=file_id
|
||||
)
|
||||
if not execution_id:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
|
||||
# Look up the actual file (no workspace scoping needed — the allowlist
|
||||
# already validated that this file belongs to the shared execution)
|
||||
file = await get_workspace_file_by_id(file_id)
|
||||
if not file:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
|
||||
return await create_file_download_response(file, inline=True)
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Schedules ########################
|
||||
########################################################
|
||||
|
||||
157
autogpt_platform/backend/backend/api/features/v1_share_test.py
Normal file
157
autogpt_platform/backend/backend/api/features/v1_share_test.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""Tests for the public shared file download endpoint."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.responses import Response
|
||||
|
||||
from backend.api.features.v1 import v1_router
|
||||
from backend.data.workspace import WorkspaceFile
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(v1_router, prefix="/api")
|
||||
|
||||
VALID_TOKEN = "550e8400-e29b-41d4-a716-446655440000"
|
||||
VALID_FILE_ID = "6ba7b810-9dad-11d1-80b4-00c04fd430c8"
|
||||
|
||||
|
||||
def _make_workspace_file(**overrides) -> WorkspaceFile:
|
||||
defaults = {
|
||||
"id": VALID_FILE_ID,
|
||||
"workspace_id": "ws-001",
|
||||
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
"updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
"name": "image.png",
|
||||
"path": "/image.png",
|
||||
"storage_path": "local://uploads/image.png",
|
||||
"mime_type": "image/png",
|
||||
"size_bytes": 4,
|
||||
"checksum": None,
|
||||
"is_deleted": False,
|
||||
"deleted_at": None,
|
||||
"metadata": {},
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return WorkspaceFile(**defaults)
|
||||
|
||||
|
||||
def _mock_download_response(**kwargs):
|
||||
"""Return an AsyncMock that resolves to a Response with inline disposition."""
|
||||
|
||||
async def _handler(file, *, inline=False):
|
||||
return Response(
|
||||
content=b"\x89PNG",
|
||||
media_type="image/png",
|
||||
headers={
|
||||
"Content-Disposition": (
|
||||
'inline; filename="image.png"'
|
||||
if inline
|
||||
else 'attachment; filename="image.png"'
|
||||
),
|
||||
"Content-Length": "4",
|
||||
},
|
||||
)
|
||||
|
||||
return _handler
|
||||
|
||||
|
||||
class TestDownloadSharedFile:
|
||||
"""Tests for GET /api/public/shared/{token}/files/{id}/download."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _client(self):
|
||||
self.client = TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
def test_valid_token_and_file_returns_inline_content(self):
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.v1.execution_db.get_shared_execution_file",
|
||||
new_callable=AsyncMock,
|
||||
return_value="exec-123",
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.v1.get_workspace_file_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_make_workspace_file(),
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.v1.create_file_download_response",
|
||||
side_effect=_mock_download_response(),
|
||||
),
|
||||
):
|
||||
response = self.client.get(
|
||||
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"\x89PNG"
|
||||
assert "inline" in response.headers["Content-Disposition"]
|
||||
|
||||
def test_invalid_token_format_returns_422(self):
|
||||
response = self.client.get(
|
||||
f"/api/public/shared/not-a-uuid/files/{VALID_FILE_ID}/download"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_token_not_in_allowlist_returns_404(self):
|
||||
with patch(
|
||||
"backend.api.features.v1.execution_db.get_shared_execution_file",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
response = self.client.get(
|
||||
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_file_missing_from_workspace_returns_404(self):
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.v1.execution_db.get_shared_execution_file",
|
||||
new_callable=AsyncMock,
|
||||
return_value="exec-123",
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.v1.get_workspace_file_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = self.client.get(
|
||||
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_uniform_404_prevents_enumeration(self):
|
||||
"""Both failure modes produce identical 404 — no information leak."""
|
||||
with patch(
|
||||
"backend.api.features.v1.execution_db.get_shared_execution_file",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
resp_no_allow = self.client.get(
|
||||
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.v1.execution_db.get_shared_execution_file",
|
||||
new_callable=AsyncMock,
|
||||
return_value="exec-123",
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.v1.get_workspace_file_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
resp_no_file = self.client.get(
|
||||
f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download"
|
||||
)
|
||||
|
||||
assert resp_no_allow.status_code == 404
|
||||
assert resp_no_file.status_code == 404
|
||||
assert resp_no_allow.json() == resp_no_file.json()
|
||||
@@ -29,7 +29,9 @@ from backend.util.workspace import WorkspaceManager
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
|
||||
def _sanitize_filename_for_header(filename: str) -> str:
|
||||
def _sanitize_filename_for_header(
|
||||
filename: str, disposition: str = "attachment"
|
||||
) -> str:
|
||||
"""
|
||||
Sanitize filename for Content-Disposition header to prevent header injection.
|
||||
|
||||
@@ -44,11 +46,11 @@ def _sanitize_filename_for_header(filename: str) -> str:
|
||||
# Check if filename has non-ASCII characters
|
||||
try:
|
||||
sanitized.encode("ascii")
|
||||
return f'attachment; filename="{sanitized}"'
|
||||
return f'{disposition}; filename="{sanitized}"'
|
||||
except UnicodeEncodeError:
|
||||
# Use RFC5987 encoding for UTF-8 filenames
|
||||
encoded = quote(sanitized, safe="")
|
||||
return f"attachment; filename*=UTF-8''{encoded}"
|
||||
return f"{disposition}; filename*=UTF-8''{encoded}"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -58,19 +60,26 @@ router = fastapi.APIRouter(
|
||||
)
|
||||
|
||||
|
||||
def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
|
||||
def _create_streaming_response(
|
||||
content: bytes, file: WorkspaceFile, *, inline: bool = False
|
||||
) -> Response:
|
||||
"""Create a streaming response for file content."""
|
||||
disposition = _sanitize_filename_for_header(
|
||||
file.name, disposition="inline" if inline else "attachment"
|
||||
)
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=file.mime_type,
|
||||
headers={
|
||||
"Content-Disposition": _sanitize_filename_for_header(file.name),
|
||||
"Content-Disposition": disposition,
|
||||
"Content-Length": str(len(content)),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _create_file_download_response(file: WorkspaceFile) -> Response:
|
||||
async def create_file_download_response(
|
||||
file: WorkspaceFile, *, inline: bool = False
|
||||
) -> Response:
|
||||
"""
|
||||
Create a download response for a workspace file.
|
||||
|
||||
@@ -82,7 +91,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
|
||||
# For local storage, stream the file directly
|
||||
if file.storage_path.startswith("local://"):
|
||||
content = await storage.retrieve(file.storage_path)
|
||||
return _create_streaming_response(content, file)
|
||||
return _create_streaming_response(content, file, inline=inline)
|
||||
|
||||
# For GCS, try to redirect to signed URL, fall back to streaming
|
||||
try:
|
||||
@@ -90,7 +99,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
|
||||
# If we got back an API path (fallback), stream directly instead
|
||||
if url.startswith("/api/"):
|
||||
content = await storage.retrieve(file.storage_path)
|
||||
return _create_streaming_response(content, file)
|
||||
return _create_streaming_response(content, file, inline=inline)
|
||||
return fastapi.responses.RedirectResponse(url=url, status_code=302)
|
||||
except Exception as e:
|
||||
# Log the signed URL failure with context
|
||||
@@ -102,7 +111,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
|
||||
# Fall back to streaming directly from GCS
|
||||
try:
|
||||
content = await storage.retrieve(file.storage_path)
|
||||
return _create_streaming_response(content, file)
|
||||
return _create_streaming_response(content, file, inline=inline)
|
||||
except Exception as fallback_error:
|
||||
logger.error(
|
||||
f"Fallback streaming also failed for file {file.id} "
|
||||
@@ -169,7 +178,7 @@ async def download_file(
|
||||
if file is None:
|
||||
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
return await _create_file_download_response(file)
|
||||
return await create_file_download_response(file)
|
||||
|
||||
|
||||
@router.delete(
|
||||
|
||||
@@ -600,3 +600,221 @@ def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace):
|
||||
mock_instance.list_files.assert_called_once_with(
|
||||
limit=11, offset=50, include_all_sessions=True
|
||||
)
|
||||
|
||||
|
||||
# -- _sanitize_filename_for_header tests --
|
||||
|
||||
|
||||
class TestSanitizeFilenameForHeader:
|
||||
def test_simple_ascii_attachment(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
assert _sanitize_filename_for_header("report.pdf") == (
|
||||
'attachment; filename="report.pdf"'
|
||||
)
|
||||
|
||||
def test_inline_disposition(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
assert _sanitize_filename_for_header("image.png", disposition="inline") == (
|
||||
'inline; filename="image.png"'
|
||||
)
|
||||
|
||||
def test_strips_cr_lf_null(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
result = _sanitize_filename_for_header("a\rb\nc\x00d.txt")
|
||||
assert "\r" not in result
|
||||
assert "\n" not in result
|
||||
assert "\x00" not in result
|
||||
assert 'filename="abcd.txt"' in result
|
||||
|
||||
def test_escapes_quotes(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
result = _sanitize_filename_for_header('file"name.txt')
|
||||
assert 'filename="file\\"name.txt"' in result
|
||||
|
||||
def test_header_injection_blocked(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
result = _sanitize_filename_for_header("evil.txt\r\nX-Injected: true")
|
||||
# CR/LF stripped — the remaining text is safely inside the quoted value
|
||||
assert "\r" not in result
|
||||
assert "\n" not in result
|
||||
assert result == 'attachment; filename="evil.txtX-Injected: true"'
|
||||
|
||||
def test_unicode_uses_rfc5987(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
result = _sanitize_filename_for_header("日本語.pdf")
|
||||
assert "filename*=UTF-8''" in result
|
||||
assert "attachment" in result
|
||||
|
||||
def test_unicode_inline(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
result = _sanitize_filename_for_header("图片.png", disposition="inline")
|
||||
assert result.startswith("inline; filename*=UTF-8''")
|
||||
|
||||
def test_empty_filename(self):
|
||||
from backend.api.features.workspace.routes import _sanitize_filename_for_header
|
||||
|
||||
result = _sanitize_filename_for_header("")
|
||||
assert result == 'attachment; filename=""'
|
||||
|
||||
|
||||
# -- _create_streaming_response tests --
|
||||
|
||||
|
||||
class TestCreateStreamingResponse:
|
||||
def test_attachment_disposition_by_default(self):
|
||||
from backend.api.features.workspace.routes import _create_streaming_response
|
||||
|
||||
file = _make_file(name="data.bin", mime_type="application/octet-stream")
|
||||
response = _create_streaming_response(b"binary-data", file)
|
||||
assert (
|
||||
response.headers["Content-Disposition"] == 'attachment; filename="data.bin"'
|
||||
)
|
||||
assert response.headers["Content-Type"] == "application/octet-stream"
|
||||
assert response.headers["Content-Length"] == "11"
|
||||
assert response.body == b"binary-data"
|
||||
|
||||
def test_inline_disposition(self):
|
||||
from backend.api.features.workspace.routes import _create_streaming_response
|
||||
|
||||
file = _make_file(name="photo.png", mime_type="image/png")
|
||||
response = _create_streaming_response(b"\x89PNG", file, inline=True)
|
||||
assert response.headers["Content-Disposition"] == 'inline; filename="photo.png"'
|
||||
assert response.headers["Content-Type"] == "image/png"
|
||||
|
||||
def test_inline_sanitizes_filename(self):
|
||||
from backend.api.features.workspace.routes import _create_streaming_response
|
||||
|
||||
file = _make_file(name='evil"\r\n.txt', mime_type="text/plain")
|
||||
response = _create_streaming_response(b"data", file, inline=True)
|
||||
assert "\r" not in response.headers["Content-Disposition"]
|
||||
assert "\n" not in response.headers["Content-Disposition"]
|
||||
assert "inline" in response.headers["Content-Disposition"]
|
||||
|
||||
def test_content_length_matches_body(self):
|
||||
from backend.api.features.workspace.routes import _create_streaming_response
|
||||
|
||||
content = b"x" * 1000
|
||||
file = _make_file(name="big.bin", mime_type="application/octet-stream")
|
||||
response = _create_streaming_response(content, file)
|
||||
assert response.headers["Content-Length"] == "1000"
|
||||
|
||||
|
||||
# -- create_file_download_response tests --
|
||||
|
||||
|
||||
class TestCreateFileDownloadResponse:
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_storage_returns_streaming_response(self, mocker):
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.return_value = b"file contents"
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
)
|
||||
|
||||
file = _make_file(
|
||||
storage_path="local://uploads/test.txt",
|
||||
mime_type="text/plain",
|
||||
)
|
||||
response = await create_file_download_response(file)
|
||||
assert response.status_code == 200
|
||||
assert response.body == b"file contents"
|
||||
assert "attachment" in response.headers["Content-Disposition"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_storage_inline(self, mocker):
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.return_value = b"\x89PNG"
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
)
|
||||
|
||||
file = _make_file(
|
||||
storage_path="local://uploads/photo.png",
|
||||
mime_type="image/png",
|
||||
name="photo.png",
|
||||
)
|
||||
response = await create_file_download_response(file, inline=True)
|
||||
assert "inline" in response.headers["Content-Disposition"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gcs_redirect(self, mocker):
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.get_download_url.return_value = (
|
||||
"https://storage.googleapis.com/signed-url"
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
)
|
||||
|
||||
file = _make_file(storage_path="gcs://bucket/file.pdf")
|
||||
response = await create_file_download_response(file)
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers["location"] == "https://storage.googleapis.com/signed-url"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gcs_api_fallback_streams_directly(self, mocker):
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.get_download_url.return_value = "/api/fallback"
|
||||
mock_storage.retrieve.return_value = b"fallback content"
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
)
|
||||
|
||||
file = _make_file(storage_path="gcs://bucket/file.txt")
|
||||
response = await create_file_download_response(file)
|
||||
assert response.status_code == 200
|
||||
assert response.body == b"fallback content"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gcs_signed_url_failure_falls_back_to_streaming(self, mocker):
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.get_download_url.side_effect = RuntimeError("GCS error")
|
||||
mock_storage.retrieve.return_value = b"streamed"
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
)
|
||||
|
||||
file = _make_file(storage_path="gcs://bucket/file.txt")
|
||||
response = await create_file_download_response(file)
|
||||
assert response.status_code == 200
|
||||
assert response.body == b"streamed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gcs_total_failure_raises(self, mocker):
|
||||
from backend.api.features.workspace.routes import create_file_download_response
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.get_download_url.side_effect = RuntimeError("GCS error")
|
||||
mock_storage.retrieve.side_effect = RuntimeError("Also failed")
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_storage",
|
||||
return_value=mock_storage,
|
||||
)
|
||||
|
||||
file = _make_file(storage_path="gcs://bucket/file.txt")
|
||||
with pytest.raises(RuntimeError, match="Also failed"):
|
||||
await create_file_download_response(file)
|
||||
|
||||
@@ -32,6 +32,7 @@ import backend.api.features.library.routes
|
||||
import backend.api.features.mcp.routes as mcp_routes
|
||||
import backend.api.features.oauth
|
||||
import backend.api.features.otto.routes
|
||||
import backend.api.features.platform_linking.routes
|
||||
import backend.api.features.postmark.postmark
|
||||
import backend.api.features.store.model
|
||||
import backend.api.features.store.routes
|
||||
@@ -378,6 +379,11 @@ app.include_router(
|
||||
tags=["oauth"],
|
||||
prefix="/api/oauth",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.platform_linking.routes.router,
|
||||
tags=["platform-linking"],
|
||||
prefix="/api/platform-linking",
|
||||
)
|
||||
|
||||
app.mount("/external-api", external_api)
|
||||
|
||||
|
||||
@@ -42,11 +42,13 @@ def main(**kwargs):
|
||||
from backend.data.db_manager import DatabaseManager
|
||||
from backend.executor import ExecutionManager, Scheduler
|
||||
from backend.notifications import NotificationManager
|
||||
from backend.platform_linking.manager import PlatformLinkingManager
|
||||
|
||||
run_processes(
|
||||
DatabaseManager().set_log_level("warning"),
|
||||
Scheduler(),
|
||||
NotificationManager(),
|
||||
PlatformLinkingManager(),
|
||||
WebsocketServer(),
|
||||
AgentServer(),
|
||||
ExecutionManager(),
|
||||
|
||||
@@ -155,3 +155,16 @@ def platform_cost_db():
|
||||
platform_cost_db = get_database_manager_async_client()
|
||||
|
||||
return platform_cost_db
|
||||
|
||||
|
||||
def platform_linking_db():
|
||||
if db.is_connected():
|
||||
from backend.platform_linking import db as _platform_linking_db
|
||||
|
||||
platform_linking_db = _platform_linking_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
platform_linking_db = get_database_manager_async_client()
|
||||
|
||||
return platform_linking_db
|
||||
|
||||
@@ -120,6 +120,7 @@ from backend.data.workspace import (
|
||||
list_workspace_files,
|
||||
soft_delete_workspace_file,
|
||||
)
|
||||
from backend.platform_linking import db as platform_linking_db
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
@@ -338,6 +339,22 @@ class DatabaseManager(AppService):
|
||||
# ============ Platform Cost Tracking ============ #
|
||||
log_platform_cost = _(log_platform_cost)
|
||||
|
||||
# ============ Platform Linking ============ #
|
||||
find_server_link_owner = _(platform_linking_db.find_server_link_owner)
|
||||
find_user_link_owner = _(platform_linking_db.find_user_link_owner)
|
||||
resolve_server_link = _(platform_linking_db.resolve_server_link)
|
||||
resolve_user_link = _(platform_linking_db.resolve_user_link)
|
||||
create_server_link_token = _(platform_linking_db.create_server_link_token)
|
||||
create_user_link_token = _(platform_linking_db.create_user_link_token)
|
||||
get_link_token_status = _(platform_linking_db.get_link_token_status)
|
||||
get_link_token_info = _(platform_linking_db.get_link_token_info)
|
||||
confirm_server_link = _(platform_linking_db.confirm_server_link)
|
||||
confirm_user_link = _(platform_linking_db.confirm_user_link)
|
||||
list_server_links = _(platform_linking_db.list_server_links)
|
||||
list_user_links = _(platform_linking_db.list_user_links)
|
||||
delete_server_link = _(platform_linking_db.delete_server_link)
|
||||
delete_user_link = _(platform_linking_db.delete_user_link)
|
||||
|
||||
# ============ CoPilot Chat Sessions ============ #
|
||||
get_chat_session = _(chat_db.get_chat_session)
|
||||
create_chat_session = _(chat_db.create_chat_session)
|
||||
@@ -540,6 +557,22 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
# ============ Platform Cost Tracking ============ #
|
||||
log_platform_cost = d.log_platform_cost
|
||||
|
||||
# ============ Platform Linking ============ #
|
||||
find_server_link_owner = d.find_server_link_owner
|
||||
find_user_link_owner = d.find_user_link_owner
|
||||
resolve_server_link = d.resolve_server_link
|
||||
resolve_user_link = d.resolve_user_link
|
||||
create_server_link_token = d.create_server_link_token
|
||||
create_user_link_token = d.create_user_link_token
|
||||
get_link_token_status = d.get_link_token_status
|
||||
get_link_token_info = d.get_link_token_info
|
||||
confirm_server_link = d.confirm_server_link
|
||||
confirm_user_link = d.confirm_user_link
|
||||
list_server_links = d.list_server_links
|
||||
list_user_links = d.list_user_links
|
||||
delete_server_link = d.delete_server_link
|
||||
delete_user_link = d.delete_user_link
|
||||
|
||||
# ============ CoPilot Chat Sessions ============ #
|
||||
get_chat_session = d.get_chat_session
|
||||
create_chat_session = d.create_chat_session
|
||||
|
||||
@@ -19,11 +19,15 @@ from typing import (
|
||||
|
||||
from prisma import Json
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
from prisma.errors import ForeignKeyViolationError, UniqueViolationError
|
||||
from prisma.models import (
|
||||
AgentGraphExecution,
|
||||
AgentNodeExecution,
|
||||
AgentNodeExecutionInputOutput,
|
||||
AgentNodeExecutionKeyValueData,
|
||||
SharedExecutionFile,
|
||||
UserWorkspace,
|
||||
UserWorkspaceFile,
|
||||
)
|
||||
from prisma.types import (
|
||||
AgentGraphExecutionOrderByInput,
|
||||
@@ -1602,6 +1606,121 @@ async def get_graph_execution_by_share_token(
|
||||
)
|
||||
|
||||
|
||||
def _extract_workspace_file_ids(outputs: CompletedBlockOutput) -> set[str]:
|
||||
"""Extract workspace file IDs from execution outputs.
|
||||
|
||||
Scans all output values for workspace:// URI strings and extracts
|
||||
the file IDs. Only matches values that are plain strings starting
|
||||
with workspace://, not substrings within larger text.
|
||||
"""
|
||||
file_ids: set[str] = set()
|
||||
|
||||
def _scan(value: Any) -> None:
|
||||
if isinstance(value, str) and value.startswith("workspace://"):
|
||||
raw = value.removeprefix("workspace://")
|
||||
file_ref = raw.split("#", 1)[0] if "#" in raw else raw
|
||||
if file_ref and not file_ref.startswith("/"):
|
||||
file_ids.add(file_ref)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_scan(item)
|
||||
elif isinstance(value, dict):
|
||||
for v in value.values():
|
||||
_scan(v)
|
||||
|
||||
for output_values in outputs.values():
|
||||
if isinstance(output_values, list):
|
||||
for val in output_values:
|
||||
_scan(val)
|
||||
else:
|
||||
_scan(output_values)
|
||||
|
||||
return file_ids
|
||||
|
||||
|
||||
async def create_shared_execution_files(
|
||||
execution_id: str,
|
||||
share_token: str,
|
||||
user_id: str,
|
||||
outputs: CompletedBlockOutput,
|
||||
) -> int:
|
||||
"""Scan execution outputs for workspace files and create allowlist records.
|
||||
|
||||
Only files belonging to the user's workspace are allowlisted — prevents
|
||||
cross-workspace file exposure via crafted outputs.
|
||||
|
||||
Returns the number of records created.
|
||||
"""
|
||||
file_ids = _extract_workspace_file_ids(outputs)
|
||||
if not file_ids:
|
||||
return 0
|
||||
|
||||
# Validate file IDs belong to the user's workspace
|
||||
workspace = await UserWorkspace.prisma().find_unique(where={"userId": user_id})
|
||||
if not workspace:
|
||||
return 0
|
||||
|
||||
owned_files = await UserWorkspaceFile.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": list(file_ids)},
|
||||
"workspaceId": workspace.id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
)
|
||||
owned_ids = {f.id for f in owned_files}
|
||||
|
||||
created = 0
|
||||
for file_id in owned_ids:
|
||||
try:
|
||||
await SharedExecutionFile.prisma().create(
|
||||
data={
|
||||
"executionId": execution_id,
|
||||
"fileId": file_id,
|
||||
"shareToken": share_token,
|
||||
}
|
||||
)
|
||||
created += 1
|
||||
except UniqueViolationError:
|
||||
logger.debug(
|
||||
f"Skipping shared file record for {file_id}: " f"record already exists"
|
||||
)
|
||||
except ForeignKeyViolationError:
|
||||
logger.debug(
|
||||
f"Skipping shared file record for {file_id}: " f"file does not exist"
|
||||
)
|
||||
return created
|
||||
|
||||
|
||||
async def delete_shared_execution_files(execution_id: str) -> int:
|
||||
"""Delete all shared file records for an execution.
|
||||
|
||||
Returns the number of records deleted.
|
||||
"""
|
||||
result = await SharedExecutionFile.prisma().delete_many(
|
||||
where={"executionId": execution_id}
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def get_shared_execution_file(
|
||||
share_token: str,
|
||||
file_id: str,
|
||||
) -> str | None:
|
||||
"""Look up a file ID in the shared execution file allowlist.
|
||||
|
||||
Returns the execution ID if the file is in the allowlist, None otherwise.
|
||||
Uses a single query and returns a uniform None for all failure modes
|
||||
to prevent timing-based enumeration attacks.
|
||||
"""
|
||||
record = await SharedExecutionFile.prisma().find_first(
|
||||
where={
|
||||
"shareToken": share_token,
|
||||
"fileId": file_id,
|
||||
}
|
||||
)
|
||||
return record.executionId if record else None
|
||||
|
||||
|
||||
async def get_frequently_executed_graphs(
|
||||
days_back: int = 30,
|
||||
min_executions: int = 10,
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
"""Tests for SharedExecutionFile workspace URI extraction logic."""
|
||||
|
||||
from backend.data.execution import _extract_workspace_file_ids
|
||||
|
||||
|
||||
class TestExtractWorkspaceFileIds:
|
||||
def test_extracts_simple_workspace_uri(self):
|
||||
outputs = {"image": ["workspace://abc123"]}
|
||||
assert _extract_workspace_file_ids(outputs) == {"abc123"}
|
||||
|
||||
def test_extracts_workspace_uri_with_mime_fragment(self):
|
||||
outputs = {"image": ["workspace://abc123#image/png"]}
|
||||
assert _extract_workspace_file_ids(outputs) == {"abc123"}
|
||||
|
||||
def test_extracts_multiple_files_from_multiple_outputs(self):
|
||||
outputs = {
|
||||
"images": ["workspace://file1#image/png", "workspace://file2#image/jpeg"],
|
||||
"video": ["workspace://file3#video/mp4"],
|
||||
}
|
||||
assert _extract_workspace_file_ids(outputs) == {"file1", "file2", "file3"}
|
||||
|
||||
def test_ignores_non_workspace_strings(self):
|
||||
outputs = {
|
||||
"text": ["hello world"],
|
||||
"url": ["https://example.com/image.png"],
|
||||
"data": ["data:image/png;base64,abc"],
|
||||
}
|
||||
assert _extract_workspace_file_ids(outputs) == set()
|
||||
|
||||
def test_ignores_path_references(self):
|
||||
"""workspace:///path/to/file is a path reference, not a file ID."""
|
||||
outputs = {"file": ["workspace:///path/to/file.txt"]}
|
||||
assert _extract_workspace_file_ids(outputs) == set()
|
||||
|
||||
def test_handles_nested_dicts_in_output_values(self):
|
||||
outputs = {
|
||||
"result": [{"url": "workspace://nested-file#image/png", "label": "test"}]
|
||||
}
|
||||
assert _extract_workspace_file_ids(outputs) == {"nested-file"}
|
||||
|
||||
def test_handles_nested_lists_in_output_values(self):
|
||||
outputs = {"result": [["workspace://inner-file"]]}
|
||||
assert _extract_workspace_file_ids(outputs) == {"inner-file"}
|
||||
|
||||
def test_handles_empty_outputs(self):
|
||||
assert _extract_workspace_file_ids({}) == set()
|
||||
|
||||
def test_handles_non_string_values(self):
|
||||
outputs = {"count": [42], "flag": [True], "empty": [None]}
|
||||
assert _extract_workspace_file_ids(outputs) == set()
|
||||
|
||||
def test_deduplicates_repeated_file_ids(self):
|
||||
outputs = {
|
||||
"a": ["workspace://same-file#image/png"],
|
||||
"b": ["workspace://same-file#image/jpeg"],
|
||||
}
|
||||
assert _extract_workspace_file_ids(outputs) == {"same-file"}
|
||||
|
||||
def test_does_not_match_workspace_substring_in_text(self):
|
||||
"""Plain text that contains workspace:// as a substring should NOT be extracted
|
||||
because the value itself must start with workspace://."""
|
||||
outputs = {"text": ["check out workspace://fake-id for details"]}
|
||||
# The string starts with "check out", not "workspace://", so no match
|
||||
assert _extract_workspace_file_ids(outputs) == set()
|
||||
|
||||
def test_mixed_workspace_and_non_workspace_outputs(self):
|
||||
outputs = {
|
||||
"image": ["workspace://real-file#image/png"],
|
||||
"text": ["just some text"],
|
||||
"url": ["https://example.com"],
|
||||
}
|
||||
assert _extract_workspace_file_ids(outputs) == {"real-file"}
|
||||
@@ -204,6 +204,22 @@ async def get_workspace_file(
|
||||
return WorkspaceFile.from_db(file) if file else None
|
||||
|
||||
|
||||
async def get_workspace_file_by_id(
|
||||
file_id: str,
|
||||
) -> Optional[WorkspaceFile]:
|
||||
"""
|
||||
Get a workspace file by ID without workspace scoping.
|
||||
|
||||
Only use this when access has already been validated through another
|
||||
mechanism (e.g. SharedExecutionFile allowlist). For user-facing
|
||||
endpoints, use get_workspace_file() which enforces workspace scoping.
|
||||
"""
|
||||
file = await UserWorkspaceFile.prisma().find_first(
|
||||
where={"id": file_id, "isDeleted": False}
|
||||
)
|
||||
return WorkspaceFile.from_db(file) if file else None
|
||||
|
||||
|
||||
async def get_workspace_file_by_path(
|
||||
workspace_id: str,
|
||||
path: str,
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Platform bot linking: helpers, chat orchestration, and AppService."""
|
||||
112
autogpt_platform/backend/backend/platform_linking/chat.py
Normal file
112
autogpt_platform/backend/backend/platform_linking/chat.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""Chat-turn orchestration for the platform bot bridge."""
|
||||
|
||||
import logging
|
||||
from uuid import uuid4
|
||||
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.executor.utils import enqueue_copilot_turn
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
append_and_save_message,
|
||||
create_chat_session,
|
||||
get_chat_session,
|
||||
)
|
||||
from backend.data.db_accessors import platform_linking_db
|
||||
from backend.util.exceptions import DuplicateChatMessageError, NotFoundError
|
||||
|
||||
from .models import BotChatRequest, ChatTurnHandle
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CHAT_TOOL_CALL_ID = "chat_stream"
|
||||
CHAT_TOOL_NAME = "chat"
|
||||
|
||||
|
||||
async def resolve_chat_owner(request: BotChatRequest) -> str:
|
||||
"""Return the AutoGPT user ID that owns the platform conversation.
|
||||
|
||||
Server context → server owner. DM context → the DM-linked user.
|
||||
"""
|
||||
platform = request.platform.value
|
||||
db = platform_linking_db()
|
||||
|
||||
if request.platform_server_id:
|
||||
owner = await db.find_server_link_owner(platform, request.platform_server_id)
|
||||
if owner is None:
|
||||
raise NotFoundError("This server is not linked to an AutoGPT account.")
|
||||
return owner
|
||||
|
||||
owner = await db.find_user_link_owner(platform, request.platform_user_id)
|
||||
if owner is None:
|
||||
raise NotFoundError("Your DMs are not linked to an AutoGPT account.")
|
||||
return owner
|
||||
|
||||
|
||||
async def start_chat_turn(request: BotChatRequest) -> ChatTurnHandle:
|
||||
"""Prepare a copilot turn; caller subscribes via the returned handle.
|
||||
|
||||
``subscribe_from="0-0"`` on the handle means a late subscriber replays
|
||||
the full stream (Redis Streams, not pub/sub).
|
||||
"""
|
||||
owner_user_id = await resolve_chat_owner(request)
|
||||
|
||||
session_id = request.session_id
|
||||
if session_id:
|
||||
session = await get_chat_session(session_id, owner_user_id)
|
||||
if not session:
|
||||
raise NotFoundError("Session not found.")
|
||||
else:
|
||||
session = await create_chat_session(owner_user_id, dry_run=False)
|
||||
session_id = session.session_id
|
||||
|
||||
# Persist the user message before enqueueing, mirroring the REST chat
|
||||
# endpoint — otherwise the executor runs against empty history.
|
||||
is_duplicate = (
|
||||
await append_and_save_message(
|
||||
session_id, ChatMessage(role="user", content=request.message)
|
||||
)
|
||||
) is None
|
||||
if is_duplicate:
|
||||
# Matches REST chat behaviour: skip create_session + enqueue so we
|
||||
# don't create an orphan stream with no producer. Caller subscribes
|
||||
# to the in-flight turn via its own retry logic, or drops.
|
||||
logger.info(
|
||||
"Duplicate bot message for session %s (platform %s, user ...%s)",
|
||||
session_id,
|
||||
request.platform.value,
|
||||
owner_user_id[-8:],
|
||||
)
|
||||
raise DuplicateChatMessageError("Message already in flight.")
|
||||
|
||||
turn_id = str(uuid4())
|
||||
|
||||
await stream_registry.create_session(
|
||||
session_id=session_id,
|
||||
user_id=owner_user_id,
|
||||
tool_call_id=CHAT_TOOL_CALL_ID,
|
||||
tool_name=CHAT_TOOL_NAME,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
|
||||
await enqueue_copilot_turn(
|
||||
session_id=session_id,
|
||||
user_id=owner_user_id,
|
||||
message=request.message,
|
||||
turn_id=turn_id,
|
||||
is_user_message=True,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Bot chat turn started: %s (server %s, session %s, turn %s, owner ...%s)",
|
||||
request.platform.value,
|
||||
request.platform_server_id or "DM",
|
||||
session_id,
|
||||
turn_id,
|
||||
owner_user_id[-8:],
|
||||
)
|
||||
|
||||
return ChatTurnHandle(
|
||||
session_id=session_id,
|
||||
turn_id=turn_id,
|
||||
user_id=owner_user_id,
|
||||
)
|
||||
125
autogpt_platform/backend/backend/platform_linking/chat_test.py
Normal file
125
autogpt_platform/backend/backend/platform_linking/chat_test.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Tests for chat-turn orchestration — esp. the duplicate-message guard."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.exceptions import DuplicateChatMessageError, NotFoundError
|
||||
|
||||
from .chat import start_chat_turn
|
||||
from .models import BotChatRequest, Platform
|
||||
|
||||
|
||||
def _request(**overrides) -> BotChatRequest:
|
||||
defaults = dict(
|
||||
platform=Platform.DISCORD,
|
||||
platform_user_id="pu1",
|
||||
message="hello",
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return BotChatRequest(**defaults)
|
||||
|
||||
|
||||
class TestStartChatTurn:
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_user_link_raises_not_found(self):
|
||||
db_mock = MagicMock()
|
||||
db_mock.find_user_link_owner = AsyncMock(return_value=None)
|
||||
with patch(
|
||||
"backend.platform_linking.chat.platform_linking_db",
|
||||
return_value=db_mock,
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
await start_chat_turn(_request())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_duplicate_message_raises_and_skips_stream_create(self):
|
||||
# append_and_save_message returns None → duplicate.
|
||||
# Verify we raise and do NOT create a stream session.
|
||||
db_mock = MagicMock()
|
||||
db_mock.find_user_link_owner = AsyncMock(return_value="owner-1")
|
||||
session = MagicMock(session_id="sess-existing")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.platform_linking.chat.platform_linking_db",
|
||||
return_value=db_mock,
|
||||
),
|
||||
patch(
|
||||
"backend.platform_linking.chat.create_chat_session",
|
||||
new=AsyncMock(return_value=session),
|
||||
),
|
||||
patch(
|
||||
"backend.platform_linking.chat.append_and_save_message",
|
||||
new=AsyncMock(return_value=None),
|
||||
),
|
||||
patch(
|
||||
"backend.platform_linking.chat.stream_registry"
|
||||
) as mock_stream_registry,
|
||||
patch(
|
||||
"backend.platform_linking.chat.enqueue_copilot_turn",
|
||||
new=AsyncMock(),
|
||||
) as mock_enqueue,
|
||||
):
|
||||
mock_stream_registry.create_session = AsyncMock()
|
||||
|
||||
with pytest.raises(DuplicateChatMessageError):
|
||||
await start_chat_turn(_request())
|
||||
|
||||
mock_stream_registry.create_session.assert_not_awaited()
|
||||
mock_enqueue.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_happy_path_creates_stream_and_enqueues(self):
|
||||
db_mock = MagicMock()
|
||||
db_mock.find_user_link_owner = AsyncMock(return_value="owner-1")
|
||||
session = MagicMock(session_id="sess-new")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.platform_linking.chat.platform_linking_db",
|
||||
return_value=db_mock,
|
||||
),
|
||||
patch(
|
||||
"backend.platform_linking.chat.create_chat_session",
|
||||
new=AsyncMock(return_value=session),
|
||||
),
|
||||
patch(
|
||||
"backend.platform_linking.chat.append_and_save_message",
|
||||
new=AsyncMock(return_value=MagicMock()),
|
||||
),
|
||||
patch(
|
||||
"backend.platform_linking.chat.stream_registry"
|
||||
) as mock_stream_registry,
|
||||
patch(
|
||||
"backend.platform_linking.chat.enqueue_copilot_turn",
|
||||
new=AsyncMock(),
|
||||
) as mock_enqueue,
|
||||
):
|
||||
mock_stream_registry.create_session = AsyncMock()
|
||||
handle = await start_chat_turn(_request())
|
||||
|
||||
assert handle.session_id == "sess-new"
|
||||
assert handle.user_id == "owner-1"
|
||||
assert handle.turn_id
|
||||
assert handle.subscribe_from == "0-0"
|
||||
mock_stream_registry.create_session.assert_awaited_once()
|
||||
mock_enqueue.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_existing_session_id_wrong_user_raises_not_found(self):
|
||||
db_mock = MagicMock()
|
||||
db_mock.find_user_link_owner = AsyncMock(return_value="owner-1")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.platform_linking.chat.platform_linking_db",
|
||||
return_value=db_mock,
|
||||
),
|
||||
patch(
|
||||
"backend.platform_linking.chat.get_chat_session",
|
||||
new=AsyncMock(return_value=None),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
await start_chat_turn(_request(session_id="someone-elses"))
|
||||
428
autogpt_platform/backend/backend/platform_linking/db.py
Normal file
428
autogpt_platform/backend/backend/platform_linking/db.py
Normal file
@@ -0,0 +1,428 @@
|
||||
"""Platform link DB operations.
|
||||
|
||||
Directly accessed by the ``AgentServer`` / ``DatabaseManager`` pods (which
|
||||
hold the Prisma connection). Other services go through
|
||||
``backend.data.db_accessors.platform_linking_db`` so calls are transparently
|
||||
routed via ``DatabaseManagerAsyncClient`` when no local Prisma is available.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import PlatformLink, PlatformLinkToken, PlatformUserLink
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.util.exceptions import (
|
||||
LinkAlreadyExistsError,
|
||||
LinkFlowMismatchError,
|
||||
LinkTokenExpiredError,
|
||||
NotAuthorizedError,
|
||||
NotFoundError,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .models import (
|
||||
ConfirmLinkResponse,
|
||||
ConfirmUserLinkResponse,
|
||||
CreateLinkTokenRequest,
|
||||
CreateUserLinkTokenRequest,
|
||||
DeleteLinkResponse,
|
||||
LinkTokenInfoResponse,
|
||||
LinkTokenResponse,
|
||||
LinkTokenStatusResponse,
|
||||
LinkType,
|
||||
PlatformLinkInfo,
|
||||
PlatformUserLinkInfo,
|
||||
ResolveResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LINK_TOKEN_EXPIRY_MINUTES = 30
|
||||
|
||||
|
||||
def _link_base_url() -> str:
|
||||
return Settings().config.platform_link_base_url
|
||||
|
||||
|
||||
# ── Owner lookups ─────────────────────────────────────────────────────
|
||||
# These return the owning AutoGPT user_id (or None). Using scalars instead
|
||||
# of Prisma models keeps everything RPC-safe — Prisma objects are rejected
|
||||
# by AppService's result validator.
|
||||
|
||||
|
||||
async def find_server_link_owner(platform: str, platform_server_id: str) -> str | None:
|
||||
link = await PlatformLink.prisma().find_first(
|
||||
where={"platform": platform, "platformServerId": platform_server_id}
|
||||
)
|
||||
return link.userId if link else None
|
||||
|
||||
|
||||
async def find_user_link_owner(platform: str, platform_user_id: str) -> str | None:
|
||||
link = await PlatformUserLink.prisma().find_unique(
|
||||
where={
|
||||
"platform_platformUserId": {
|
||||
"platform": platform,
|
||||
"platformUserId": platform_user_id,
|
||||
}
|
||||
}
|
||||
)
|
||||
return link.userId if link else None
|
||||
|
||||
|
||||
async def resolve_server_link(
|
||||
platform: str, platform_server_id: str
|
||||
) -> ResolveResponse:
|
||||
owner = await find_server_link_owner(platform, platform_server_id)
|
||||
return ResolveResponse(linked=owner is not None)
|
||||
|
||||
|
||||
async def resolve_user_link(platform: str, platform_user_id: str) -> ResolveResponse:
|
||||
owner = await find_user_link_owner(platform, platform_user_id)
|
||||
return ResolveResponse(linked=owner is not None)
|
||||
|
||||
|
||||
# ── Token creation ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def create_server_link_token(
|
||||
request: CreateLinkTokenRequest,
|
||||
) -> LinkTokenResponse:
|
||||
platform = request.platform.value
|
||||
|
||||
if await find_server_link_owner(platform, request.platform_server_id):
|
||||
raise LinkAlreadyExistsError(
|
||||
"This server is already linked to an AutoGPT account."
|
||||
)
|
||||
|
||||
token = secrets.token_urlsafe(32)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=LINK_TOKEN_EXPIRY_MINUTES
|
||||
)
|
||||
|
||||
# Atomic: invalidate pending tokens + create the new one, so two racing
|
||||
# create calls can't leave two valid tokens for the same target.
|
||||
async with transaction() as tx:
|
||||
await PlatformLinkToken.prisma(tx).update_many(
|
||||
where={
|
||||
"platform": platform,
|
||||
"linkType": LinkType.SERVER.value,
|
||||
"platformServerId": request.platform_server_id,
|
||||
"usedAt": None,
|
||||
},
|
||||
data={"usedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
await PlatformLinkToken.prisma(tx).create(
|
||||
data={
|
||||
"token": token,
|
||||
"platform": platform,
|
||||
"linkType": LinkType.SERVER.value,
|
||||
"platformServerId": request.platform_server_id,
|
||||
"platformUserId": request.platform_user_id,
|
||||
"platformUsername": request.platform_username,
|
||||
"serverName": request.server_name,
|
||||
"channelId": request.channel_id,
|
||||
"expiresAt": expires_at,
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Created SERVER link token for %s server %s (expires %s)",
|
||||
platform,
|
||||
request.platform_server_id,
|
||||
expires_at.isoformat(),
|
||||
)
|
||||
|
||||
return LinkTokenResponse(
|
||||
token=token,
|
||||
expires_at=expires_at,
|
||||
link_url=f"{_link_base_url()}/{token}?platform={platform}",
|
||||
)
|
||||
|
||||
|
||||
async def create_user_link_token(
|
||||
request: CreateUserLinkTokenRequest,
|
||||
) -> LinkTokenResponse:
|
||||
platform = request.platform.value
|
||||
|
||||
if await find_user_link_owner(platform, request.platform_user_id):
|
||||
raise LinkAlreadyExistsError(
|
||||
"Your DMs with the bot are already linked to an AutoGPT account."
|
||||
)
|
||||
|
||||
token = secrets.token_urlsafe(32)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=LINK_TOKEN_EXPIRY_MINUTES
|
||||
)
|
||||
|
||||
async with transaction() as tx:
|
||||
await PlatformLinkToken.prisma(tx).update_many(
|
||||
where={
|
||||
"platform": platform,
|
||||
"linkType": LinkType.USER.value,
|
||||
"platformUserId": request.platform_user_id,
|
||||
"usedAt": None,
|
||||
},
|
||||
data={"usedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
await PlatformLinkToken.prisma(tx).create(
|
||||
data={
|
||||
"token": token,
|
||||
"platform": platform,
|
||||
"linkType": LinkType.USER.value,
|
||||
"platformUserId": request.platform_user_id,
|
||||
"platformUsername": request.platform_username,
|
||||
"expiresAt": expires_at,
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Created USER link token for %s (expires %s)", platform, expires_at.isoformat()
|
||||
)
|
||||
|
||||
return LinkTokenResponse(
|
||||
token=token,
|
||||
expires_at=expires_at,
|
||||
link_url=f"{_link_base_url()}/{token}?platform={platform}",
|
||||
)
|
||||
|
||||
|
||||
# ── Token status / info ───────────────────────────────────────────────
|
||||
|
||||
|
||||
async def get_link_token_status(token: str) -> LinkTokenStatusResponse:
|
||||
link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token})
|
||||
|
||||
if not link_token:
|
||||
raise NotFoundError("Token not found.")
|
||||
|
||||
if link_token.usedAt is not None:
|
||||
# A superseded token (invalidated by create_*_token) has usedAt set
|
||||
# without a backing link row — report expired, not linked.
|
||||
if link_token.linkType == LinkType.USER.value:
|
||||
owner = await find_user_link_owner(
|
||||
link_token.platform, link_token.platformUserId
|
||||
)
|
||||
else:
|
||||
owner = (
|
||||
await find_server_link_owner(
|
||||
link_token.platform, link_token.platformServerId
|
||||
)
|
||||
if link_token.platformServerId
|
||||
else None
|
||||
)
|
||||
return LinkTokenStatusResponse(status="linked" if owner else "expired")
|
||||
|
||||
if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
|
||||
return LinkTokenStatusResponse(status="expired")
|
||||
|
||||
return LinkTokenStatusResponse(status="pending")
|
||||
|
||||
|
||||
async def get_link_token_info(token: str) -> LinkTokenInfoResponse:
|
||||
link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token})
|
||||
|
||||
if not link_token or link_token.usedAt is not None:
|
||||
raise NotFoundError("Token not found.")
|
||||
|
||||
if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
|
||||
raise LinkTokenExpiredError("Token expired.")
|
||||
|
||||
return LinkTokenInfoResponse(
|
||||
platform=link_token.platform,
|
||||
link_type=LinkType(link_token.linkType),
|
||||
server_name=link_token.serverName,
|
||||
)
|
||||
|
||||
|
||||
# ── Confirmation (user-facing, JWT-authed) ────────────────────────────
|
||||
|
||||
|
||||
async def confirm_server_link(token: str, user_id: str) -> ConfirmLinkResponse:
|
||||
link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token})
|
||||
|
||||
if not link_token:
|
||||
raise NotFoundError("Token not found.")
|
||||
if link_token.linkType != LinkType.SERVER.value:
|
||||
raise LinkFlowMismatchError("This link is for a different linking flow.")
|
||||
if link_token.usedAt is not None:
|
||||
raise LinkTokenExpiredError("This link has already been used.")
|
||||
if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
|
||||
raise LinkTokenExpiredError("This link has expired.")
|
||||
if not link_token.platformServerId:
|
||||
raise LinkFlowMismatchError("Server token missing server ID.")
|
||||
|
||||
owner = await find_server_link_owner(
|
||||
link_token.platform, link_token.platformServerId
|
||||
)
|
||||
if owner:
|
||||
detail = (
|
||||
"This server is already linked to your account."
|
||||
if owner == user_id
|
||||
else "This server is already linked to another AutoGPT account."
|
||||
)
|
||||
raise LinkAlreadyExistsError(detail)
|
||||
|
||||
# Atomic consume + create so a failed create doesn't burn the token.
|
||||
now = datetime.now(timezone.utc)
|
||||
try:
|
||||
async with transaction() as tx:
|
||||
updated = await PlatformLinkToken.prisma(tx).update_many(
|
||||
where={"token": token, "usedAt": None, "expiresAt": {"gt": now}},
|
||||
data={"usedAt": now},
|
||||
)
|
||||
if updated == 0:
|
||||
raise LinkTokenExpiredError("This link has already been used.")
|
||||
await PlatformLink.prisma(tx).create(
|
||||
data={
|
||||
"userId": user_id,
|
||||
"platform": link_token.platform,
|
||||
"platformServerId": link_token.platformServerId,
|
||||
"ownerPlatformUserId": link_token.platformUserId,
|
||||
"serverName": link_token.serverName,
|
||||
}
|
||||
)
|
||||
except UniqueViolationError as exc:
|
||||
raise LinkAlreadyExistsError(
|
||||
"This server was just linked by another request."
|
||||
) from exc
|
||||
|
||||
logger.info(
|
||||
"Linked %s server %s to user ...%s",
|
||||
link_token.platform,
|
||||
link_token.platformServerId,
|
||||
user_id[-8:],
|
||||
)
|
||||
|
||||
return ConfirmLinkResponse(
|
||||
success=True,
|
||||
platform=link_token.platform,
|
||||
platform_server_id=link_token.platformServerId,
|
||||
server_name=link_token.serverName,
|
||||
)
|
||||
|
||||
|
||||
async def confirm_user_link(token: str, user_id: str) -> ConfirmUserLinkResponse:
|
||||
link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token})
|
||||
|
||||
if not link_token:
|
||||
raise NotFoundError("Token not found.")
|
||||
if link_token.linkType != LinkType.USER.value:
|
||||
raise LinkFlowMismatchError("This link is for a different linking flow.")
|
||||
if link_token.usedAt is not None:
|
||||
raise LinkTokenExpiredError("This link has already been used.")
|
||||
if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
|
||||
raise LinkTokenExpiredError("This link has expired.")
|
||||
|
||||
owner = await find_user_link_owner(link_token.platform, link_token.platformUserId)
|
||||
if owner:
|
||||
detail = (
|
||||
"Your DMs are already linked to your account."
|
||||
if owner == user_id
|
||||
else "This platform user is already linked to another AutoGPT account."
|
||||
)
|
||||
raise LinkAlreadyExistsError(detail)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
try:
|
||||
async with transaction() as tx:
|
||||
updated = await PlatformLinkToken.prisma(tx).update_many(
|
||||
where={"token": token, "usedAt": None, "expiresAt": {"gt": now}},
|
||||
data={"usedAt": now},
|
||||
)
|
||||
if updated == 0:
|
||||
raise LinkTokenExpiredError("This link has already been used.")
|
||||
await PlatformUserLink.prisma(tx).create(
|
||||
data={
|
||||
"userId": user_id,
|
||||
"platform": link_token.platform,
|
||||
"platformUserId": link_token.platformUserId,
|
||||
"platformUsername": link_token.platformUsername,
|
||||
}
|
||||
)
|
||||
except UniqueViolationError as exc:
|
||||
raise LinkAlreadyExistsError(
|
||||
"Your DMs were just linked by another request."
|
||||
) from exc
|
||||
|
||||
logger.info(
|
||||
"Linked %s DMs to AutoGPT user ...%s", link_token.platform, user_id[-8:]
|
||||
)
|
||||
|
||||
return ConfirmUserLinkResponse(
|
||||
success=True,
|
||||
platform=link_token.platform,
|
||||
platform_user_id=link_token.platformUserId,
|
||||
)
|
||||
|
||||
|
||||
# ── Listing ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def list_server_links(user_id: str) -> list[PlatformLinkInfo]:
|
||||
links = await PlatformLink.prisma().find_many(
|
||||
where={"userId": user_id},
|
||||
order={"linkedAt": "desc"},
|
||||
)
|
||||
return [
|
||||
PlatformLinkInfo(
|
||||
id=link.id,
|
||||
platform=link.platform,
|
||||
platform_server_id=link.platformServerId,
|
||||
owner_platform_user_id=link.ownerPlatformUserId,
|
||||
server_name=link.serverName,
|
||||
linked_at=link.linkedAt,
|
||||
)
|
||||
for link in links
|
||||
]
|
||||
|
||||
|
||||
async def list_user_links(user_id: str) -> list[PlatformUserLinkInfo]:
|
||||
links = await PlatformUserLink.prisma().find_many(
|
||||
where={"userId": user_id},
|
||||
order={"linkedAt": "desc"},
|
||||
)
|
||||
return [
|
||||
PlatformUserLinkInfo(
|
||||
id=link.id,
|
||||
platform=link.platform,
|
||||
platform_user_id=link.platformUserId,
|
||||
platform_username=link.platformUsername,
|
||||
linked_at=link.linkedAt,
|
||||
)
|
||||
for link in links
|
||||
]
|
||||
|
||||
|
||||
# ── Deletion ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def delete_server_link(link_id: str, user_id: str) -> DeleteLinkResponse:
|
||||
link = await PlatformLink.prisma().find_unique(where={"id": link_id})
|
||||
if not link:
|
||||
raise NotFoundError("Link not found.")
|
||||
if link.userId != user_id:
|
||||
raise NotAuthorizedError("Not your link.")
|
||||
|
||||
await PlatformLink.prisma().delete(where={"id": link_id})
|
||||
logger.info(
|
||||
"Unlinked %s server %s from user ...%s",
|
||||
link.platform,
|
||||
link.platformServerId,
|
||||
user_id[-8:],
|
||||
)
|
||||
return DeleteLinkResponse(success=True)
|
||||
|
||||
|
||||
async def delete_user_link(link_id: str, user_id: str) -> DeleteLinkResponse:
|
||||
link = await PlatformUserLink.prisma().find_unique(where={"id": link_id})
|
||||
if not link:
|
||||
raise NotFoundError("Link not found.")
|
||||
if link.userId != user_id:
|
||||
raise NotAuthorizedError("Not your link.")
|
||||
|
||||
await PlatformUserLink.prisma().delete(where={"id": link_id})
|
||||
logger.info("Unlinked %s DMs from AutoGPT user ...%s", link.platform, user_id[-8:])
|
||||
return DeleteLinkResponse(success=True)
|
||||
481
autogpt_platform/backend/backend/platform_linking/db_test.py
Normal file
481
autogpt_platform/backend/backend/platform_linking/db_test.py
Normal file
@@ -0,0 +1,481 @@
|
||||
"""Unit tests for platform_linking DB operations."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.exceptions import (
|
||||
LinkAlreadyExistsError,
|
||||
LinkFlowMismatchError,
|
||||
LinkTokenExpiredError,
|
||||
NotAuthorizedError,
|
||||
NotFoundError,
|
||||
)
|
||||
|
||||
from .db import (
|
||||
confirm_server_link,
|
||||
confirm_user_link,
|
||||
create_server_link_token,
|
||||
create_user_link_token,
|
||||
delete_server_link,
|
||||
delete_user_link,
|
||||
get_link_token_info,
|
||||
get_link_token_status,
|
||||
resolve_server_link,
|
||||
resolve_user_link,
|
||||
)
|
||||
from .models import (
|
||||
CreateLinkTokenRequest,
|
||||
CreateUserLinkTokenRequest,
|
||||
LinkType,
|
||||
Platform,
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _fake_transaction():
|
||||
# Avoids Prisma's tx binding asyncio primitives to the wrong loop in tests.
|
||||
yield MagicMock()
|
||||
|
||||
|
||||
# ── Resolve ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestResolve:
|
||||
@pytest.mark.asyncio
|
||||
async def test_server_linked(self):
|
||||
with patch("backend.platform_linking.db.PlatformLink") as mock_link:
|
||||
mock_link.prisma.return_value.find_first = AsyncMock(
|
||||
return_value=MagicMock(userId="u-123")
|
||||
)
|
||||
result = await resolve_server_link("DISCORD", "g1")
|
||||
assert result.linked is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_server_unlinked(self):
|
||||
with patch("backend.platform_linking.db.PlatformLink") as mock_link:
|
||||
mock_link.prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
result = await resolve_server_link("DISCORD", "g1")
|
||||
assert result.linked is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_linked(self):
|
||||
with patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link:
|
||||
mock_user_link.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=MagicMock(userId="u-xyz")
|
||||
)
|
||||
result = await resolve_user_link("DISCORD", "pu1")
|
||||
assert result.linked is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_unlinked(self):
|
||||
with patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link:
|
||||
mock_user_link.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=None
|
||||
)
|
||||
result = await resolve_user_link("DISCORD", "pu1")
|
||||
assert result.linked is False
|
||||
|
||||
|
||||
# ── Token creation ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCreateServerLinkToken:
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_token_for_unlinked_server(self):
|
||||
with (
|
||||
patch("backend.platform_linking.db.PlatformLink") as mock_link,
|
||||
patch(
|
||||
"backend.platform_linking.db.transaction",
|
||||
new=_fake_transaction,
|
||||
),
|
||||
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token_model,
|
||||
):
|
||||
mock_link.prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
mock_token_model.prisma.return_value.update_many = AsyncMock(return_value=0)
|
||||
mock_token_model.prisma.return_value.create = AsyncMock(
|
||||
return_value=MagicMock()
|
||||
)
|
||||
|
||||
result = await create_server_link_token(
|
||||
CreateLinkTokenRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_server_id="g1",
|
||||
platform_user_id="u1",
|
||||
server_name="Test",
|
||||
),
|
||||
)
|
||||
|
||||
assert result.token
|
||||
assert result.token in result.link_url
|
||||
assert "?platform=DISCORD" in result.link_url
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_when_already_linked(self):
|
||||
with patch("backend.platform_linking.db.PlatformLink") as mock_link:
|
||||
mock_link.prisma.return_value.find_first = AsyncMock(
|
||||
return_value=MagicMock(userId="u-owner")
|
||||
)
|
||||
with pytest.raises(LinkAlreadyExistsError):
|
||||
await create_server_link_token(
|
||||
CreateLinkTokenRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_server_id="g1",
|
||||
platform_user_id="u1",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestCreateUserLinkToken:
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_token_for_unlinked_user(self):
|
||||
with (
|
||||
patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link,
|
||||
patch(
|
||||
"backend.platform_linking.db.transaction",
|
||||
new=_fake_transaction,
|
||||
),
|
||||
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token_model,
|
||||
):
|
||||
mock_user_link.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=None
|
||||
)
|
||||
mock_token_model.prisma.return_value.update_many = AsyncMock(return_value=0)
|
||||
mock_token_model.prisma.return_value.create = AsyncMock(
|
||||
return_value=MagicMock()
|
||||
)
|
||||
|
||||
result = await create_user_link_token(
|
||||
CreateUserLinkTokenRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_user_id="pu1",
|
||||
platform_username="Bently",
|
||||
),
|
||||
)
|
||||
|
||||
assert result.token
|
||||
assert result.token in result.link_url
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_when_already_linked(self):
|
||||
with patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link:
|
||||
mock_user_link.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=MagicMock(userId="u-owner")
|
||||
)
|
||||
with pytest.raises(LinkAlreadyExistsError):
|
||||
await create_user_link_token(
|
||||
CreateUserLinkTokenRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_user_id="pu1",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ── Token status / info ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetLinkTokenStatus:
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found(self):
|
||||
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
|
||||
mock_model.prisma.return_value.find_unique = AsyncMock(return_value=None)
|
||||
with pytest.raises(NotFoundError):
|
||||
await get_link_token_status("abc")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pending(self):
|
||||
future = datetime.now(timezone.utc) + timedelta(minutes=10)
|
||||
fake_token = MagicMock(usedAt=None, expiresAt=future)
|
||||
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
|
||||
mock_model.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=fake_token
|
||||
)
|
||||
result = await get_link_token_status("abc")
|
||||
assert result.status == "pending"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_by_time(self):
|
||||
past = datetime.now(timezone.utc) - timedelta(minutes=10)
|
||||
fake_token = MagicMock(usedAt=None, expiresAt=past)
|
||||
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
|
||||
mock_model.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=fake_token
|
||||
)
|
||||
result = await get_link_token_status("abc")
|
||||
assert result.status == "expired"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_used_with_user_link_reports_linked(self):
|
||||
fake_token = MagicMock(
|
||||
usedAt=datetime.now(timezone.utc),
|
||||
linkType=LinkType.USER.value,
|
||||
platform="DISCORD",
|
||||
platformUserId="pu1",
|
||||
)
|
||||
with (
|
||||
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token,
|
||||
patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link,
|
||||
):
|
||||
mock_token.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=fake_token
|
||||
)
|
||||
mock_user_link.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=MagicMock(userId="u-owner")
|
||||
)
|
||||
result = await get_link_token_status("abc")
|
||||
assert result.status == "linked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_used_without_link_reports_expired(self):
|
||||
# Superseded token: usedAt set, but no backing link row.
|
||||
fake_token = MagicMock(
|
||||
usedAt=datetime.now(timezone.utc),
|
||||
linkType=LinkType.SERVER.value,
|
||||
platform="DISCORD",
|
||||
platformServerId="g1",
|
||||
)
|
||||
with (
|
||||
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token,
|
||||
patch("backend.platform_linking.db.PlatformLink") as mock_link,
|
||||
):
|
||||
mock_token.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=fake_token
|
||||
)
|
||||
mock_link.prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
result = await get_link_token_status("abc")
|
||||
assert result.status == "expired"
|
||||
|
||||
|
||||
class TestGetLinkTokenInfo:
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found(self):
|
||||
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
|
||||
mock_model.prisma.return_value.find_unique = AsyncMock(return_value=None)
|
||||
with pytest.raises(NotFoundError):
|
||||
await get_link_token_info("abc")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_used_returns_not_found(self):
|
||||
fake_token = MagicMock(usedAt=datetime.now(timezone.utc))
|
||||
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
|
||||
mock_model.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=fake_token
|
||||
)
|
||||
with pytest.raises(NotFoundError):
|
||||
await get_link_token_info("abc")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_raises_expired(self):
|
||||
past = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
fake_token = MagicMock(usedAt=None, expiresAt=past)
|
||||
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
|
||||
mock_model.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=fake_token
|
||||
)
|
||||
with pytest.raises(LinkTokenExpiredError):
|
||||
await get_link_token_info("abc")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_returns_display_info(self):
|
||||
future = datetime.now(timezone.utc) + timedelta(minutes=10)
|
||||
fake_token = MagicMock(
|
||||
usedAt=None,
|
||||
expiresAt=future,
|
||||
platform="DISCORD",
|
||||
linkType=LinkType.SERVER.value,
|
||||
serverName="My Server",
|
||||
)
|
||||
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
|
||||
mock_model.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=fake_token
|
||||
)
|
||||
result = await get_link_token_info("abc")
|
||||
assert result.platform == "DISCORD"
|
||||
assert result.link_type == LinkType.SERVER
|
||||
assert result.server_name == "My Server"
|
||||
|
||||
|
||||
# ── Confirmation ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestConfirmServerLink:
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found(self):
|
||||
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
|
||||
mock_model.prisma.return_value.find_unique = AsyncMock(return_value=None)
|
||||
with pytest.raises(NotFoundError):
|
||||
await confirm_server_link("abc", "u1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wrong_link_type_rejected(self):
|
||||
fake_token = MagicMock(linkType=LinkType.USER.value)
|
||||
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
|
||||
mock_model.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=fake_token
|
||||
)
|
||||
with pytest.raises(LinkFlowMismatchError):
|
||||
await confirm_server_link("abc", "u1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_already_used(self):
|
||||
fake_token = MagicMock(
|
||||
linkType=LinkType.SERVER.value, usedAt=datetime.now(timezone.utc)
|
||||
)
|
||||
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
|
||||
mock_model.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=fake_token
|
||||
)
|
||||
with pytest.raises(LinkTokenExpiredError):
|
||||
await confirm_server_link("abc", "u1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_by_time(self):
|
||||
fake_token = MagicMock(
|
||||
linkType=LinkType.SERVER.value,
|
||||
usedAt=None,
|
||||
expiresAt=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||
)
|
||||
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
|
||||
mock_model.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=fake_token
|
||||
)
|
||||
with pytest.raises(LinkTokenExpiredError):
|
||||
await confirm_server_link("abc", "u1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_already_linked_to_same_user(self):
|
||||
fake_token = MagicMock(
|
||||
linkType=LinkType.SERVER.value,
|
||||
usedAt=None,
|
||||
expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10),
|
||||
platform="DISCORD",
|
||||
platformServerId="g1",
|
||||
)
|
||||
with (
|
||||
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token,
|
||||
patch("backend.platform_linking.db.PlatformLink") as mock_link,
|
||||
):
|
||||
mock_token.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=fake_token
|
||||
)
|
||||
mock_link.prisma.return_value.find_first = AsyncMock(
|
||||
return_value=MagicMock(userId="u1")
|
||||
)
|
||||
with pytest.raises(LinkAlreadyExistsError) as exc_info:
|
||||
await confirm_server_link("abc", "u1")
|
||||
assert "your account" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_already_linked_to_other_user(self):
|
||||
fake_token = MagicMock(
|
||||
linkType=LinkType.SERVER.value,
|
||||
usedAt=None,
|
||||
expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10),
|
||||
platform="DISCORD",
|
||||
platformServerId="g1",
|
||||
)
|
||||
with (
|
||||
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token,
|
||||
patch("backend.platform_linking.db.PlatformLink") as mock_link,
|
||||
):
|
||||
mock_token.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=fake_token
|
||||
)
|
||||
mock_link.prisma.return_value.find_first = AsyncMock(
|
||||
return_value=MagicMock(userId="other-user")
|
||||
)
|
||||
with pytest.raises(LinkAlreadyExistsError) as exc_info:
|
||||
await confirm_server_link("abc", "u1")
|
||||
assert "another" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestConfirmUserLink:
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found(self):
|
||||
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
|
||||
mock_model.prisma.return_value.find_unique = AsyncMock(return_value=None)
|
||||
with pytest.raises(NotFoundError):
|
||||
await confirm_user_link("abc", "u1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wrong_link_type_rejected(self):
|
||||
fake_token = MagicMock(linkType=LinkType.SERVER.value)
|
||||
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
|
||||
mock_model.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=fake_token
|
||||
)
|
||||
with pytest.raises(LinkFlowMismatchError):
|
||||
await confirm_user_link("abc", "u1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_by_time(self):
|
||||
fake_token = MagicMock(
|
||||
linkType=LinkType.USER.value,
|
||||
usedAt=None,
|
||||
expiresAt=datetime.now(timezone.utc) - timedelta(minutes=5),
|
||||
)
|
||||
with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model:
|
||||
mock_model.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=fake_token
|
||||
)
|
||||
with pytest.raises(LinkTokenExpiredError):
|
||||
await confirm_user_link("abc", "u1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_already_linked_to_other_user(self):
|
||||
fake_token = MagicMock(
|
||||
linkType=LinkType.USER.value,
|
||||
usedAt=None,
|
||||
expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10),
|
||||
platform="DISCORD",
|
||||
platformUserId="pu1",
|
||||
)
|
||||
with (
|
||||
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token,
|
||||
patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link,
|
||||
):
|
||||
mock_token.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=fake_token
|
||||
)
|
||||
mock_user_link.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=MagicMock(userId="other-user")
|
||||
)
|
||||
with pytest.raises(LinkAlreadyExistsError):
|
||||
await confirm_user_link("abc", "u1")
|
||||
|
||||
|
||||
# ── Delete (authz checks) ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDeleteLinks:
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_server_link_not_found(self):
|
||||
with patch("backend.platform_linking.db.PlatformLink") as mock_model:
|
||||
mock_model.prisma.return_value.find_unique = AsyncMock(return_value=None)
|
||||
with pytest.raises(NotFoundError):
|
||||
await delete_server_link("x", "u1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_server_link_not_owned(self):
|
||||
link = MagicMock(userId="owner-A", platform="DISCORD", platformServerId="g1")
|
||||
with patch("backend.platform_linking.db.PlatformLink") as mock_model:
|
||||
mock_model.prisma.return_value.find_unique = AsyncMock(return_value=link)
|
||||
with pytest.raises(NotAuthorizedError):
|
||||
await delete_server_link("x", "u-other")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_user_link_not_found(self):
|
||||
with patch("backend.platform_linking.db.PlatformUserLink") as mock_model:
|
||||
mock_model.prisma.return_value.find_unique = AsyncMock(return_value=None)
|
||||
with pytest.raises(NotFoundError):
|
||||
await delete_user_link("x", "u1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_user_link_not_owned(self):
|
||||
link = MagicMock(userId="owner-A", platform="DISCORD")
|
||||
with patch("backend.platform_linking.db.PlatformUserLink") as mock_model:
|
||||
mock_model.prisma.return_value.find_unique = AsyncMock(return_value=link)
|
||||
with pytest.raises(NotAuthorizedError):
|
||||
await delete_user_link("x", "u-other")
|
||||
82
autogpt_platform/backend/backend/platform_linking/manager.py
Normal file
82
autogpt_platform/backend/backend/platform_linking/manager.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""AppService exposing bot-facing platform_linking ops over internal RPC."""
|
||||
|
||||
import logging
|
||||
|
||||
from backend.data.db_accessors import platform_linking_db
|
||||
from backend.util.service import AppService, AppServiceClient, endpoint_to_async, expose
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .chat import start_chat_turn
|
||||
from .models import (
|
||||
BotChatRequest,
|
||||
ChatTurnHandle,
|
||||
CreateLinkTokenRequest,
|
||||
CreateUserLinkTokenRequest,
|
||||
LinkTokenResponse,
|
||||
LinkTokenStatusResponse,
|
||||
Platform,
|
||||
ResolveResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PlatformLinkingManager(AppService):
|
||||
@classmethod
|
||||
def get_port(cls) -> int:
|
||||
return Settings().config.platform_linking_service_port
|
||||
|
||||
@expose
|
||||
async def resolve_server_link(
|
||||
self, platform: Platform, platform_server_id: str
|
||||
) -> ResolveResponse:
|
||||
return await platform_linking_db().resolve_server_link(
|
||||
platform.value, platform_server_id
|
||||
)
|
||||
|
||||
@expose
|
||||
async def resolve_user_link(
|
||||
self, platform: Platform, platform_user_id: str
|
||||
) -> ResolveResponse:
|
||||
return await platform_linking_db().resolve_user_link(
|
||||
platform.value, platform_user_id
|
||||
)
|
||||
|
||||
@expose
|
||||
async def create_server_link_token(
|
||||
self, request: CreateLinkTokenRequest
|
||||
) -> LinkTokenResponse:
|
||||
return await platform_linking_db().create_server_link_token(request)
|
||||
|
||||
@expose
|
||||
async def create_user_link_token(
|
||||
self, request: CreateUserLinkTokenRequest
|
||||
) -> LinkTokenResponse:
|
||||
return await platform_linking_db().create_user_link_token(request)
|
||||
|
||||
@expose
|
||||
async def get_link_token_status(self, token: str) -> LinkTokenStatusResponse:
|
||||
return await platform_linking_db().get_link_token_status(token)
|
||||
|
||||
@expose
|
||||
async def start_chat_turn(self, request: BotChatRequest) -> ChatTurnHandle:
|
||||
return await start_chat_turn(request)
|
||||
|
||||
|
||||
class PlatformLinkingManagerClient(AppServiceClient):
|
||||
@classmethod
|
||||
def get_service_type(cls):
|
||||
return PlatformLinkingManager
|
||||
|
||||
resolve_server_link = endpoint_to_async(PlatformLinkingManager.resolve_server_link)
|
||||
resolve_user_link = endpoint_to_async(PlatformLinkingManager.resolve_user_link)
|
||||
create_server_link_token = endpoint_to_async(
|
||||
PlatformLinkingManager.create_server_link_token
|
||||
)
|
||||
create_user_link_token = endpoint_to_async(
|
||||
PlatformLinkingManager.create_user_link_token
|
||||
)
|
||||
get_link_token_status = endpoint_to_async(
|
||||
PlatformLinkingManager.get_link_token_status
|
||||
)
|
||||
start_chat_turn = endpoint_to_async(PlatformLinkingManager.start_chat_turn)
|
||||
@@ -0,0 +1,346 @@
|
||||
"""Tests for PlatformLinkingManager RPC wiring and confirm-token races."""
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.exceptions import LinkTokenExpiredError
|
||||
|
||||
from .db import confirm_server_link, confirm_user_link
|
||||
from .manager import PlatformLinkingManager, PlatformLinkingManagerClient
|
||||
from .models import (
|
||||
BotChatRequest,
|
||||
CreateLinkTokenRequest,
|
||||
CreateUserLinkTokenRequest,
|
||||
LinkType,
|
||||
Platform,
|
||||
ResolveResponse,
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _fake_transaction():
|
||||
yield MagicMock()
|
||||
|
||||
|
||||
class TestManagerWiring:
|
||||
def test_get_port(self):
|
||||
assert PlatformLinkingManager.get_port() == 8009
|
||||
|
||||
def test_client_exposes_expected_rpc_surface(self):
|
||||
service_type = PlatformLinkingManagerClient.get_service_type()
|
||||
assert service_type is PlatformLinkingManager
|
||||
|
||||
expected = {
|
||||
"resolve_server_link",
|
||||
"resolve_user_link",
|
||||
"create_server_link_token",
|
||||
"create_user_link_token",
|
||||
"get_link_token_status",
|
||||
"start_chat_turn",
|
||||
}
|
||||
for name in expected:
|
||||
assert hasattr(
|
||||
PlatformLinkingManagerClient, name
|
||||
), f"Client missing RPC stub: {name}"
|
||||
|
||||
for name in (
|
||||
"confirm_server_link",
|
||||
"confirm_user_link",
|
||||
"list_server_links",
|
||||
"list_user_links",
|
||||
"delete_server_link",
|
||||
"delete_user_link",
|
||||
):
|
||||
assert not hasattr(
|
||||
PlatformLinkingManagerClient, name
|
||||
), f"User-facing method leaked to bot client: {name}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_server_link_delegates_to_accessor(self):
|
||||
manager = PlatformLinkingManager()
|
||||
db_mock = MagicMock()
|
||||
db_mock.resolve_server_link = AsyncMock(
|
||||
return_value=ResolveResponse(linked=True)
|
||||
)
|
||||
with patch(
|
||||
"backend.platform_linking.manager.platform_linking_db",
|
||||
return_value=db_mock,
|
||||
):
|
||||
result = await manager.resolve_server_link(Platform.DISCORD, "g1")
|
||||
db_mock.resolve_server_link.assert_awaited_once_with("DISCORD", "g1")
|
||||
assert result.linked is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_user_link_delegates_to_accessor(self):
|
||||
manager = PlatformLinkingManager()
|
||||
db_mock = MagicMock()
|
||||
db_mock.resolve_user_link = AsyncMock(
|
||||
return_value=ResolveResponse(linked=False)
|
||||
)
|
||||
with patch(
|
||||
"backend.platform_linking.manager.platform_linking_db",
|
||||
return_value=db_mock,
|
||||
):
|
||||
result = await manager.resolve_user_link(Platform.DISCORD, "pu1")
|
||||
db_mock.resolve_user_link.assert_awaited_once_with("DISCORD", "pu1")
|
||||
assert result.linked is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_server_link_token_delegates(self):
|
||||
manager = PlatformLinkingManager()
|
||||
req = CreateLinkTokenRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_server_id="g1",
|
||||
platform_user_id="u1",
|
||||
)
|
||||
fake_response = MagicMock()
|
||||
db_mock = MagicMock()
|
||||
db_mock.create_server_link_token = AsyncMock(return_value=fake_response)
|
||||
with patch(
|
||||
"backend.platform_linking.manager.platform_linking_db",
|
||||
return_value=db_mock,
|
||||
):
|
||||
result = await manager.create_server_link_token(req)
|
||||
db_mock.create_server_link_token.assert_awaited_once_with(req)
|
||||
assert result is fake_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_link_token_delegates(self):
|
||||
manager = PlatformLinkingManager()
|
||||
req = CreateUserLinkTokenRequest(
|
||||
platform=Platform.DISCORD, platform_user_id="pu1"
|
||||
)
|
||||
fake_response = MagicMock()
|
||||
db_mock = MagicMock()
|
||||
db_mock.create_user_link_token = AsyncMock(return_value=fake_response)
|
||||
with patch(
|
||||
"backend.platform_linking.manager.platform_linking_db",
|
||||
return_value=db_mock,
|
||||
):
|
||||
result = await manager.create_user_link_token(req)
|
||||
db_mock.create_user_link_token.assert_awaited_once_with(req)
|
||||
assert result is fake_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_link_token_status_delegates(self):
|
||||
manager = PlatformLinkingManager()
|
||||
fake_response = MagicMock()
|
||||
db_mock = MagicMock()
|
||||
db_mock.get_link_token_status = AsyncMock(return_value=fake_response)
|
||||
with patch(
|
||||
"backend.platform_linking.manager.platform_linking_db",
|
||||
return_value=db_mock,
|
||||
):
|
||||
result = await manager.get_link_token_status("tok")
|
||||
db_mock.get_link_token_status.assert_awaited_once_with("tok")
|
||||
assert result is fake_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_chat_turn_delegates(self):
|
||||
manager = PlatformLinkingManager()
|
||||
req = BotChatRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_user_id="pu1",
|
||||
message="hi",
|
||||
)
|
||||
fake_response = MagicMock()
|
||||
with patch(
|
||||
"backend.platform_linking.manager.start_chat_turn",
|
||||
new=AsyncMock(return_value=fake_response),
|
||||
) as stub:
|
||||
result = await manager.start_chat_turn(req)
|
||||
stub.assert_awaited_once_with(req)
|
||||
assert result is fake_response
|
||||
|
||||
|
||||
class TestAdversarialConfirmRace:
|
||||
"""Concurrent confirm of one token: exactly one winner via ``update_many``
|
||||
guarded on ``usedAt = None``."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_second_confirm_loses(self):
|
||||
# update_many returns 0 → caller lost the race
|
||||
fake_token = MagicMock(
|
||||
linkType=LinkType.SERVER.value,
|
||||
usedAt=None,
|
||||
expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10),
|
||||
platform="DISCORD",
|
||||
platformServerId="g1",
|
||||
)
|
||||
|
||||
with (
|
||||
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token,
|
||||
patch("backend.platform_linking.db.PlatformLink") as mock_link,
|
||||
patch("backend.platform_linking.db.transaction", new=_fake_transaction),
|
||||
):
|
||||
mock_token.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=fake_token
|
||||
)
|
||||
mock_link.prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
mock_token.prisma.return_value.update_many = AsyncMock(return_value=0)
|
||||
|
||||
with pytest.raises(LinkTokenExpiredError):
|
||||
await confirm_server_link("abc", "user-late")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_second_confirm_wins_when_update_many_returns_one(self):
|
||||
fake_token = MagicMock(
|
||||
linkType=LinkType.SERVER.value,
|
||||
usedAt=None,
|
||||
expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10),
|
||||
platform="DISCORD",
|
||||
platformServerId="g1",
|
||||
platformUserId="pu1",
|
||||
serverName="S1",
|
||||
)
|
||||
|
||||
with (
|
||||
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token,
|
||||
patch("backend.platform_linking.db.PlatformLink") as mock_link,
|
||||
patch("backend.platform_linking.db.transaction", new=_fake_transaction),
|
||||
):
|
||||
mock_token.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=fake_token
|
||||
)
|
||||
mock_link.prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
mock_token.prisma.return_value.update_many = AsyncMock(return_value=1)
|
||||
mock_link.prisma.return_value.create = AsyncMock(return_value=MagicMock())
|
||||
|
||||
result = await confirm_server_link("abc", "user-winner")
|
||||
|
||||
assert result.success is True
|
||||
assert result.platform_server_id == "g1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gather_confirm_same_user_one_winner(self):
|
||||
fake_token = MagicMock(
|
||||
linkType=LinkType.SERVER.value,
|
||||
usedAt=None,
|
||||
expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10),
|
||||
platform="DISCORD",
|
||||
platformServerId="g1",
|
||||
platformUserId="pu1",
|
||||
serverName="S1",
|
||||
)
|
||||
update_results = [1, 0]
|
||||
|
||||
async def flaky_update_many(*args, **kwargs):
|
||||
return update_results.pop(0)
|
||||
|
||||
with (
|
||||
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token,
|
||||
patch("backend.platform_linking.db.PlatformLink") as mock_link,
|
||||
patch("backend.platform_linking.db.transaction", new=_fake_transaction),
|
||||
):
|
||||
mock_token.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=fake_token
|
||||
)
|
||||
mock_link.prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
mock_token.prisma.return_value.update_many = flaky_update_many
|
||||
mock_link.prisma.return_value.create = AsyncMock(return_value=MagicMock())
|
||||
|
||||
results = await asyncio.gather(
|
||||
confirm_server_link("abc", "u1"),
|
||||
confirm_server_link("abc", "u1"),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
successes = [r for r in results if not isinstance(r, Exception)]
|
||||
losses = [r for r in results if isinstance(r, LinkTokenExpiredError)]
|
||||
assert len(successes) == 1
|
||||
assert len(losses) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gather_confirm_different_users_one_winner_no_hijack(self):
|
||||
# Different users racing the same token: still exactly one winner,
|
||||
# and the other gets a clean LinkTokenExpiredError (no partial state).
|
||||
fake_token = MagicMock(
|
||||
linkType=LinkType.SERVER.value,
|
||||
usedAt=None,
|
||||
expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10),
|
||||
platform="DISCORD",
|
||||
platformServerId="g1",
|
||||
platformUserId="pu1",
|
||||
serverName="S1",
|
||||
)
|
||||
update_results = [1, 0]
|
||||
|
||||
async def flaky_update_many(*args, **kwargs):
|
||||
return update_results.pop(0)
|
||||
|
||||
created_link_user_ids: list[str] = []
|
||||
|
||||
async def record_create(*, data):
|
||||
created_link_user_ids.append(data["userId"])
|
||||
return MagicMock()
|
||||
|
||||
with (
|
||||
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token,
|
||||
patch("backend.platform_linking.db.PlatformLink") as mock_link,
|
||||
patch("backend.platform_linking.db.transaction", new=_fake_transaction),
|
||||
):
|
||||
mock_token.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=fake_token
|
||||
)
|
||||
mock_link.prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
mock_token.prisma.return_value.update_many = flaky_update_many
|
||||
mock_link.prisma.return_value.create = record_create
|
||||
|
||||
results = await asyncio.gather(
|
||||
confirm_server_link("abc", "user-a"),
|
||||
confirm_server_link("abc", "user-b"),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
successes = [r for r in results if not isinstance(r, Exception)]
|
||||
losses = [r for r in results if isinstance(r, LinkTokenExpiredError)]
|
||||
assert len(successes) == 1
|
||||
assert len(losses) == 1
|
||||
assert len(created_link_user_ids) == 1
|
||||
assert created_link_user_ids[0] in ("user-a", "user-b")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gather_confirm_user_link_one_winner(self):
|
||||
fake_token = MagicMock(
|
||||
linkType=LinkType.USER.value,
|
||||
usedAt=None,
|
||||
expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10),
|
||||
platform="DISCORD",
|
||||
platformUserId="pu1",
|
||||
platformUsername="pu_name",
|
||||
)
|
||||
update_results = [1, 0]
|
||||
|
||||
async def flaky_update_many(*args, **kwargs):
|
||||
return update_results.pop(0)
|
||||
|
||||
with (
|
||||
patch("backend.platform_linking.db.PlatformLinkToken") as mock_token,
|
||||
patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link,
|
||||
patch("backend.platform_linking.db.transaction", new=_fake_transaction),
|
||||
):
|
||||
mock_token.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=fake_token
|
||||
)
|
||||
mock_user_link.prisma.return_value.find_unique = AsyncMock(
|
||||
return_value=None
|
||||
)
|
||||
mock_token.prisma.return_value.update_many = flaky_update_many
|
||||
mock_user_link.prisma.return_value.create = AsyncMock(
|
||||
return_value=MagicMock()
|
||||
)
|
||||
|
||||
results = await asyncio.gather(
|
||||
confirm_user_link("abc", "user-a"),
|
||||
confirm_user_link("abc", "user-b"),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
successes = [r for r in results if not isinstance(r, Exception)]
|
||||
losses = [r for r in results if isinstance(r, LinkTokenExpiredError)]
|
||||
assert len(successes) == 1
|
||||
assert len(losses) == 1
|
||||
182
autogpt_platform/backend/backend/platform_linking/models.py
Normal file
182
autogpt_platform/backend/backend/platform_linking/models.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""Pydantic models for platform_linking requests and responses."""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Platform(str, Enum):
|
||||
"""Mirror of the Prisma PlatformType enum."""
|
||||
|
||||
DISCORD = "DISCORD"
|
||||
TELEGRAM = "TELEGRAM"
|
||||
SLACK = "SLACK"
|
||||
TEAMS = "TEAMS"
|
||||
WHATSAPP = "WHATSAPP"
|
||||
GITHUB = "GITHUB"
|
||||
LINEAR = "LINEAR"
|
||||
|
||||
|
||||
class LinkType(str, Enum):
|
||||
SERVER = "SERVER"
|
||||
USER = "USER"
|
||||
|
||||
|
||||
# ── Request Models ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class CreateLinkTokenRequest(BaseModel):
|
||||
platform: Platform = Field(description="Platform name")
|
||||
platform_server_id: str = Field(
|
||||
description="Server/guild/group ID on the platform",
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
)
|
||||
platform_user_id: str = Field(
|
||||
description="Platform user ID of the person claiming ownership",
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
)
|
||||
platform_username: str | None = Field(
|
||||
default=None,
|
||||
description="Display name of the person claiming ownership",
|
||||
max_length=255,
|
||||
)
|
||||
server_name: str | None = Field(
|
||||
default=None,
|
||||
description="Display name of the server/group",
|
||||
max_length=255,
|
||||
)
|
||||
channel_id: str | None = Field(
|
||||
default=None,
|
||||
description="Channel ID so the bot can send a confirmation message",
|
||||
max_length=255,
|
||||
)
|
||||
|
||||
|
||||
class CreateUserLinkTokenRequest(BaseModel):
|
||||
platform: Platform
|
||||
platform_user_id: str = Field(
|
||||
description="Platform user ID of the person linking their DMs",
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
)
|
||||
platform_username: str | None = Field(
|
||||
default=None,
|
||||
description="Their display name (best-effort for audit)",
|
||||
max_length=255,
|
||||
)
|
||||
|
||||
|
||||
class ResolveServerRequest(BaseModel):
|
||||
platform: Platform
|
||||
platform_server_id: str = Field(
|
||||
description="Server/guild/group ID to look up",
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
)
|
||||
|
||||
|
||||
class ResolveUserRequest(BaseModel):
|
||||
platform: Platform
|
||||
platform_user_id: str = Field(
|
||||
description="Platform user ID to look up",
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
)
|
||||
|
||||
|
||||
class BotChatRequest(BaseModel):
|
||||
"""Bot message request. If ``platform_server_id`` is set, the turn is
|
||||
billed to that server's owner; otherwise billed to ``platform_user_id``
|
||||
(DM context)."""
|
||||
|
||||
platform: Platform
|
||||
platform_server_id: str | None = Field(
|
||||
default=None,
|
||||
description="Server/guild/group ID — null for DM context",
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
)
|
||||
platform_user_id: str = Field(
|
||||
description="Platform user ID of the person who sent the message",
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
)
|
||||
message: str = Field(
|
||||
description="The user's message", min_length=1, max_length=32000
|
||||
)
|
||||
session_id: str | None = Field(
|
||||
default=None,
|
||||
description="Existing CoPilot session ID. If omitted, a new session is created.",
|
||||
)
|
||||
|
||||
|
||||
# ── Response Models ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class LinkTokenResponse(BaseModel):
|
||||
token: str
|
||||
expires_at: datetime
|
||||
link_url: str
|
||||
|
||||
|
||||
class LinkTokenStatusResponse(BaseModel):
|
||||
status: Literal["pending", "linked", "expired"]
|
||||
|
||||
|
||||
class LinkTokenInfoResponse(BaseModel):
|
||||
platform: str
|
||||
link_type: LinkType
|
||||
server_name: str | None = None
|
||||
|
||||
|
||||
class ResolveResponse(BaseModel):
|
||||
linked: bool
|
||||
|
||||
|
||||
class PlatformLinkInfo(BaseModel):
|
||||
id: str
|
||||
platform: str
|
||||
platform_server_id: str
|
||||
owner_platform_user_id: str
|
||||
server_name: str | None
|
||||
linked_at: datetime
|
||||
|
||||
|
||||
class PlatformUserLinkInfo(BaseModel):
|
||||
id: str
|
||||
platform: str
|
||||
platform_user_id: str
|
||||
platform_username: str | None
|
||||
linked_at: datetime
|
||||
|
||||
|
||||
class ConfirmLinkResponse(BaseModel):
|
||||
success: bool
|
||||
link_type: LinkType = LinkType.SERVER
|
||||
platform: str
|
||||
platform_server_id: str
|
||||
server_name: str | None
|
||||
|
||||
|
||||
class ConfirmUserLinkResponse(BaseModel):
|
||||
success: bool
|
||||
link_type: LinkType = LinkType.USER
|
||||
platform: str
|
||||
platform_user_id: str
|
||||
|
||||
|
||||
class DeleteLinkResponse(BaseModel):
|
||||
success: bool
|
||||
|
||||
|
||||
class ChatTurnHandle(BaseModel):
|
||||
"""Subscribe keys for a pending copilot turn."""
|
||||
|
||||
session_id: str
|
||||
turn_id: str
|
||||
user_id: str
|
||||
subscribe_from: str = "0-0"
|
||||
178
autogpt_platform/backend/backend/platform_linking/models_test.py
Normal file
178
autogpt_platform/backend/backend/platform_linking/models_test.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Schema validation tests for platform_linking Pydantic models."""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from .models import (
|
||||
BotChatRequest,
|
||||
ConfirmLinkResponse,
|
||||
CreateLinkTokenRequest,
|
||||
DeleteLinkResponse,
|
||||
LinkTokenStatusResponse,
|
||||
Platform,
|
||||
ResolveResponse,
|
||||
ResolveServerRequest,
|
||||
)
|
||||
|
||||
|
||||
class TestPlatformEnum:
|
||||
def test_all_platforms_exist(self):
|
||||
assert Platform.DISCORD.value == "DISCORD"
|
||||
assert Platform.TELEGRAM.value == "TELEGRAM"
|
||||
assert Platform.SLACK.value == "SLACK"
|
||||
assert Platform.TEAMS.value == "TEAMS"
|
||||
assert Platform.WHATSAPP.value == "WHATSAPP"
|
||||
assert Platform.GITHUB.value == "GITHUB"
|
||||
assert Platform.LINEAR.value == "LINEAR"
|
||||
|
||||
|
||||
class TestCreateLinkTokenRequest:
|
||||
def test_valid_request(self):
|
||||
req = CreateLinkTokenRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_server_id="1126875755960336515",
|
||||
platform_user_id="353922987235213313",
|
||||
platform_username="Bently",
|
||||
server_name="My Discord Server",
|
||||
)
|
||||
assert req.platform == Platform.DISCORD
|
||||
assert req.platform_server_id == "1126875755960336515"
|
||||
assert req.platform_user_id == "353922987235213313"
|
||||
assert req.server_name == "My Discord Server"
|
||||
|
||||
def test_minimal_request(self):
|
||||
req = CreateLinkTokenRequest(
|
||||
platform=Platform.TELEGRAM,
|
||||
platform_server_id="-100123456789",
|
||||
platform_user_id="987654321",
|
||||
)
|
||||
assert req.server_name is None
|
||||
assert req.platform_username is None
|
||||
|
||||
def test_empty_server_id_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
CreateLinkTokenRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_server_id="",
|
||||
platform_user_id="123",
|
||||
)
|
||||
|
||||
def test_too_long_server_id_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
CreateLinkTokenRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_server_id="x" * 256,
|
||||
platform_user_id="123",
|
||||
)
|
||||
|
||||
def test_invalid_platform_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
CreateLinkTokenRequest.model_validate(
|
||||
{
|
||||
"platform": "INVALID",
|
||||
"platform_server_id": "123",
|
||||
"platform_user_id": "456",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestResolveServerRequest:
|
||||
def test_valid_request(self):
|
||||
req = ResolveServerRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_server_id="1126875755960336515",
|
||||
)
|
||||
assert req.platform == Platform.DISCORD
|
||||
assert req.platform_server_id == "1126875755960336515"
|
||||
|
||||
def test_empty_server_id_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
ResolveServerRequest(
|
||||
platform=Platform.SLACK,
|
||||
platform_server_id="",
|
||||
)
|
||||
|
||||
|
||||
class TestBotChatRequest:
|
||||
def test_server_context(self):
|
||||
req = BotChatRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_server_id="1126875755960336515",
|
||||
platform_user_id="353922987235213313",
|
||||
message="Hello CoPilot!",
|
||||
)
|
||||
assert req.platform == Platform.DISCORD
|
||||
assert req.platform_server_id == "1126875755960336515"
|
||||
assert req.session_id is None
|
||||
|
||||
def test_dm_context_omits_server_id(self):
|
||||
req = BotChatRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_user_id="353922987235213313",
|
||||
message="Hello in DMs!",
|
||||
)
|
||||
assert req.platform_server_id is None
|
||||
|
||||
def test_with_session_id(self):
|
||||
req = BotChatRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_server_id="guild_123",
|
||||
platform_user_id="user_456",
|
||||
message="follow up",
|
||||
session_id="session-uuid-here",
|
||||
)
|
||||
assert req.session_id == "session-uuid-here"
|
||||
|
||||
def test_empty_message_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
BotChatRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_server_id="guild_123",
|
||||
platform_user_id="user_456",
|
||||
message="",
|
||||
)
|
||||
|
||||
def test_empty_string_server_id_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
BotChatRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_server_id="",
|
||||
platform_user_id="user_456",
|
||||
message="hi",
|
||||
)
|
||||
|
||||
|
||||
class TestResponseModels:
|
||||
def test_link_token_status_pending(self):
|
||||
resp = LinkTokenStatusResponse(status="pending")
|
||||
assert resp.status == "pending"
|
||||
|
||||
def test_link_token_status_linked(self):
|
||||
resp = LinkTokenStatusResponse(status="linked")
|
||||
assert resp.status == "linked"
|
||||
|
||||
def test_link_token_status_expired(self):
|
||||
resp = LinkTokenStatusResponse(status="expired")
|
||||
assert resp.status == "expired"
|
||||
|
||||
def test_resolve_linked(self):
|
||||
resp = ResolveResponse(linked=True)
|
||||
assert resp.linked is True
|
||||
|
||||
def test_resolve_not_linked(self):
|
||||
resp = ResolveResponse(linked=False)
|
||||
assert resp.linked is False
|
||||
|
||||
def test_confirm_link_response(self):
|
||||
resp = ConfirmLinkResponse(
|
||||
success=True,
|
||||
platform="DISCORD",
|
||||
platform_server_id="1126875755960336515",
|
||||
server_name="My Server",
|
||||
)
|
||||
assert resp.success is True
|
||||
assert resp.server_name == "My Server"
|
||||
|
||||
def test_delete_link_response(self):
|
||||
resp = DeleteLinkResponse(success=True)
|
||||
assert resp.success is True
|
||||
15
autogpt_platform/backend/backend/platform_linking_manager.py
Normal file
15
autogpt_platform/backend/backend/platform_linking_manager.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from backend.app import run_processes
|
||||
from backend.platform_linking.manager import PlatformLinkingManager
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Run the AutoGPT-server Platform Linking Manager service.
|
||||
"""
|
||||
run_processes(
|
||||
PlatformLinkingManager(),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -27,6 +27,7 @@ if TYPE_CHECKING:
|
||||
from backend.executor.scheduler import SchedulerClient
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
from backend.platform_linking.manager import PlatformLinkingManagerClient
|
||||
|
||||
|
||||
@thread_cached
|
||||
@@ -67,6 +68,15 @@ def get_notification_manager_client() -> "NotificationManagerClient":
|
||||
return get_service_client(NotificationManagerClient)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_platform_linking_manager_client() -> "PlatformLinkingManagerClient":
|
||||
"""Get a thread-cached PlatformLinkingManagerClient."""
|
||||
from backend.platform_linking.manager import PlatformLinkingManagerClient
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(PlatformLinkingManagerClient)
|
||||
|
||||
|
||||
# ============ Execution Event Bus Helpers ============ #
|
||||
|
||||
|
||||
|
||||
@@ -155,3 +155,19 @@ class RedisError(Exception):
|
||||
"""Raised when there is an error interacting with Redis"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LinkAlreadyExistsError(ValueError):
|
||||
"""A platform_linking target (server or user) is already linked."""
|
||||
|
||||
|
||||
class LinkTokenExpiredError(ValueError):
|
||||
"""A platform_linking token has expired or been consumed."""
|
||||
|
||||
|
||||
class LinkFlowMismatchError(ValueError):
|
||||
"""A platform_linking token was used for the wrong flow (server vs user)."""
|
||||
|
||||
|
||||
class DuplicateChatMessageError(ValueError):
|
||||
"""The same user message is already in flight for this chat session."""
|
||||
|
||||
@@ -252,6 +252,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="The port for notification service daemon to run on",
|
||||
)
|
||||
|
||||
platform_linking_service_port: int = Field(
|
||||
default=8009,
|
||||
description="The port for the platform_linking manager daemon to run on",
|
||||
)
|
||||
|
||||
otto_api_url: str = Field(
|
||||
default="",
|
||||
description="The URL for the Otto API service",
|
||||
@@ -269,6 +274,13 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
"This value is then used to generate redirect URLs for OAuth flows.",
|
||||
)
|
||||
|
||||
platform_link_base_url: str = Field(
|
||||
default="https://platform.agpt.co/link",
|
||||
description="Base URL the bot service prepends to one-time linking "
|
||||
"tokens when it posts them to users ({base}/{token}?platform=...). "
|
||||
"Should point at the frontend /link page.",
|
||||
)
|
||||
|
||||
media_gcs_bucket_name: str = Field(
|
||||
default="",
|
||||
description="The name of the Google Cloud Storage bucket for media files",
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
-- CreateEnum
|
||||
CREATE TYPE "PlatformType" AS ENUM ('DISCORD', 'TELEGRAM', 'SLACK', 'TEAMS', 'WHATSAPP', 'GITHUB', 'LINEAR');
|
||||
|
||||
-- CreateTable
|
||||
-- PlatformLink maps a platform server (Discord guild, Telegram group, etc.) to an AutoGPT
|
||||
-- owner account. The first user to authenticate becomes the owner — all usage from that
|
||||
-- server is billed to their account. Each user within the server gets their own CoPilot
|
||||
-- session, all visible in the owner's AutoGPT account.
|
||||
CREATE TABLE "PlatformLink" (
|
||||
"id" TEXT NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
"platform" "PlatformType" NOT NULL,
|
||||
"platformServerId" TEXT NOT NULL,
|
||||
"ownerPlatformUserId" TEXT NOT NULL,
|
||||
"serverName" TEXT,
|
||||
"linkedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
CONSTRAINT "PlatformLink_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
-- PlatformLinkToken is a one-time token for the server linking flow.
|
||||
CREATE TABLE "PlatformLinkToken" (
|
||||
"id" TEXT NOT NULL,
|
||||
"token" TEXT NOT NULL,
|
||||
"platform" "PlatformType" NOT NULL,
|
||||
"platformServerId" TEXT NOT NULL,
|
||||
"platformUserId" TEXT NOT NULL,
|
||||
"platformUsername" TEXT,
|
||||
"serverName" TEXT,
|
||||
"channelId" TEXT,
|
||||
"expiresAt" TIMESTAMP(3) NOT NULL,
|
||||
"usedAt" TIMESTAMP(3),
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
CONSTRAINT "PlatformLinkToken_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "PlatformLink_platform_platformServerId_key" ON "PlatformLink"("platform", "platformServerId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "PlatformLink_userId_idx" ON "PlatformLink"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "PlatformLinkToken_token_key" ON "PlatformLinkToken"("token");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "PlatformLinkToken_platform_platformServerId_idx" ON "PlatformLinkToken"("platform", "platformServerId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "PlatformLinkToken_expiresAt_idx" ON "PlatformLinkToken"("expiresAt");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "PlatformLink" ADD CONSTRAINT "PlatformLink_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
@@ -0,0 +1,37 @@
|
||||
-- CreateEnum
|
||||
-- Server links (group chats / guilds) and user links (personal DMs) are
|
||||
-- fully independent — a user who owns a linked server still has to link
|
||||
-- their DMs separately.
|
||||
CREATE TYPE "PlatformLinkType" AS ENUM ('SERVER', 'USER');
|
||||
|
||||
-- CreateTable
|
||||
-- PlatformUserLink maps an individual platform user identity to an AutoGPT
|
||||
-- account for 1:1 DMs with the bot. Independent from PlatformLink.
|
||||
CREATE TABLE "PlatformUserLink" (
|
||||
"id" TEXT NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
"platform" "PlatformType" NOT NULL,
|
||||
"platformUserId" TEXT NOT NULL,
|
||||
"platformUsername" TEXT,
|
||||
"linkedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
CONSTRAINT "PlatformUserLink_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "PlatformUserLink_platform_platformUserId_key" ON "PlatformUserLink"("platform", "platformUserId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "PlatformUserLink_userId_idx" ON "PlatformUserLink"("userId");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "PlatformUserLink" ADD CONSTRAINT "PlatformUserLink_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AlterTable: PlatformLinkToken now supports SERVER or USER tokens.
|
||||
-- Existing rows are all SERVER (default matches the column default).
|
||||
ALTER TABLE "PlatformLinkToken"
|
||||
ADD COLUMN "linkType" "PlatformLinkType" NOT NULL DEFAULT 'SERVER',
|
||||
ALTER COLUMN "platformServerId" DROP NOT NULL;
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "PlatformLinkToken_platform_platformUserId_idx" ON "PlatformLinkToken"("platform", "platformUserId");
|
||||
@@ -0,0 +1,25 @@
|
||||
-- CreateTable
|
||||
CREATE TABLE "SharedExecutionFile" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"executionId" TEXT NOT NULL,
|
||||
"fileId" TEXT NOT NULL,
|
||||
"shareToken" TEXT NOT NULL,
|
||||
|
||||
CONSTRAINT "SharedExecutionFile_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "SharedExecutionFile_shareToken_fileId_key" ON "SharedExecutionFile"("shareToken", "fileId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "SharedExecutionFile_shareToken_idx" ON "SharedExecutionFile"("shareToken");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "SharedExecutionFile_executionId_idx" ON "SharedExecutionFile"("executionId");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "SharedExecutionFile" ADD CONSTRAINT "SharedExecutionFile_executionId_fkey" FOREIGN KEY ("executionId") REFERENCES "AgentGraphExecution"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "SharedExecutionFile" ADD CONSTRAINT "SharedExecutionFile_fileId_fkey" FOREIGN KEY ("fileId") REFERENCES "UserWorkspaceFile"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
@@ -129,6 +129,7 @@ db = "backend.db:main"
|
||||
ws = "backend.ws:main"
|
||||
scheduler = "backend.scheduler:main"
|
||||
notification = "backend.notification:main"
|
||||
platform-linking-manager = "backend.platform_linking_manager:main"
|
||||
executor = "backend.exec:main"
|
||||
analytics-setup = "scripts.generate_views:main_setup"
|
||||
analytics-views = "scripts.generate_views:main_views"
|
||||
|
||||
@@ -81,6 +81,10 @@ model User {
|
||||
OAuthAuthorizationCodes OAuthAuthorizationCode[]
|
||||
OAuthAccessTokens OAuthAccessToken[]
|
||||
OAuthRefreshTokens OAuthRefreshToken[]
|
||||
|
||||
// Platform bot linking
|
||||
PlatformLinks PlatformLink[]
|
||||
PlatformUserLinks PlatformUserLink[]
|
||||
}
|
||||
|
||||
enum SubscriptionTier {
|
||||
@@ -200,10 +204,32 @@ model UserWorkspaceFile {
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
SharedExecutionFiles SharedExecutionFile[]
|
||||
|
||||
@@unique([workspaceId, path])
|
||||
@@index([workspaceId, isDeleted])
|
||||
}
|
||||
|
||||
// Tracks which workspace files are exposed via a shared execution.
|
||||
// Created when sharing is enabled, deleted when sharing is disabled.
|
||||
// The public file download endpoint validates against this table.
|
||||
model SharedExecutionFile {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
executionId String
|
||||
Execution AgentGraphExecution @relation(fields: [executionId], references: [id], onDelete: Cascade)
|
||||
|
||||
fileId String
|
||||
File UserWorkspaceFile @relation(fields: [fileId], references: [id], onDelete: Cascade)
|
||||
|
||||
shareToken String
|
||||
|
||||
@@unique([shareToken, fileId])
|
||||
@@index([shareToken])
|
||||
@@index([executionId])
|
||||
}
|
||||
|
||||
model BuilderSearchHistory {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
@@ -585,9 +611,10 @@ model AgentGraphExecution {
|
||||
ChildExecutions AgentGraphExecution[] @relation("ParentChildExecution")
|
||||
|
||||
// Sharing fields
|
||||
isShared Boolean @default(false)
|
||||
shareToken String? @unique
|
||||
sharedAt DateTime?
|
||||
isShared Boolean @default(false)
|
||||
shareToken String? @unique
|
||||
sharedAt DateTime?
|
||||
SharedExecutionFiles SharedExecutionFile[]
|
||||
|
||||
@@index([agentGraphId, agentGraphVersion])
|
||||
@@index([userId, isDeleted, createdAt])
|
||||
@@ -1366,3 +1393,84 @@ model OAuthRefreshToken {
|
||||
@@index([userId, applicationId])
|
||||
@@index([expiresAt]) // For cleanup
|
||||
}
|
||||
|
||||
// ── Platform Bot Linking ──────────────────────────────────────────────
|
||||
// Links external chat platform identities (Discord, Telegram, Slack, etc.)
|
||||
// to AutoGPT user accounts, enabling the multi-platform CoPilot bot.
|
||||
|
||||
enum PlatformType {
|
||||
DISCORD
|
||||
TELEGRAM
|
||||
SLACK
|
||||
TEAMS
|
||||
WHATSAPP
|
||||
GITHUB
|
||||
LINEAR
|
||||
}
|
||||
|
||||
// Whether a linking token claims a server (group chat / guild) or a personal
|
||||
// 1:1 user link (DMs). Server and user links are completely independent —
|
||||
// linking a server does not grant DM access and vice versa.
|
||||
enum PlatformLinkType {
|
||||
SERVER
|
||||
USER
|
||||
}
|
||||
|
||||
// Maps a platform server (Discord guild, Telegram group, Slack workspace, etc.)
|
||||
// to an AutoGPT user account. The user who first authenticates becomes the
|
||||
// "owner" — all usage from that server is attributed to their account.
|
||||
model PlatformLink {
|
||||
id String @id @default(uuid())
|
||||
userId String // AutoGPT user ID of the owner
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
platform PlatformType
|
||||
platformServerId String // Server/guild/group ID on that platform
|
||||
ownerPlatformUserId String // Platform user ID of the person who set it up
|
||||
serverName String? // Display name of the server (best-effort, may go stale)
|
||||
linkedAt DateTime @default(now())
|
||||
|
||||
@@unique([platform, platformServerId])
|
||||
@@index([userId])
|
||||
}
|
||||
|
||||
// Maps a platform user identity (a single Discord / Telegram / Slack user) to
|
||||
// an AutoGPT account for 1:1 DM conversations with the bot. Independent from
|
||||
// PlatformLink — a user who owns a linked server must still link their DMs
|
||||
// separately.
|
||||
model PlatformUserLink {
|
||||
id String @id @default(uuid())
|
||||
userId String // AutoGPT user ID
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
platform PlatformType
|
||||
platformUserId String // Individual's user ID on the platform
|
||||
platformUsername String? // Display name at link time (best-effort)
|
||||
linkedAt DateTime @default(now())
|
||||
|
||||
@@unique([platform, platformUserId])
|
||||
@@index([userId])
|
||||
}
|
||||
|
||||
// One-time tokens for either the server linking flow or the DM (user) linking
|
||||
// flow. linkType determines which target is populated — SERVER tokens carry
|
||||
// platformServerId + serverName + ownerPlatformUserId, USER tokens carry
|
||||
// platformUserId only.
|
||||
model PlatformLinkToken {
|
||||
id String @id @default(uuid())
|
||||
token String @unique
|
||||
platform PlatformType
|
||||
linkType PlatformLinkType @default(SERVER)
|
||||
// SERVER token fields (null for USER tokens)
|
||||
platformServerId String? // Server/guild/group ID being linked
|
||||
serverName String? // Server display name
|
||||
channelId String? // Channel to send confirmation back to
|
||||
// Always set — platform user ID of the person who will claim ownership
|
||||
platformUserId String
|
||||
platformUsername String? // Their display name
|
||||
expiresAt DateTime
|
||||
usedAt DateTime?
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
@@index([platform, platformServerId])
|
||||
@@index([platform, platformUserId])
|
||||
@@index([expiresAt])
|
||||
}
|
||||
|
||||
@@ -92,11 +92,12 @@
|
||||
"geist": "1.5.1",
|
||||
"highlight.js": "11.11.1",
|
||||
"jaro-winkler": "0.2.8",
|
||||
"jszip": "3.10.1",
|
||||
"katex": "0.16.25",
|
||||
"launchdarkly-react-client-sdk": "3.9.0",
|
||||
"lodash": "4.17.21",
|
||||
"lucide-react": "0.552.0",
|
||||
"next": "15.4.10",
|
||||
"next": "15.4.11",
|
||||
"next-themes": "0.4.6",
|
||||
"nuqs": "2.7.2",
|
||||
"posthog-js": "1.334.1",
|
||||
|
||||
101
autogpt_platform/frontend/pnpm-lock.yaml
generated
101
autogpt_platform/frontend/pnpm-lock.yaml
generated
@@ -26,7 +26,7 @@ importers:
|
||||
version: 5.2.2(react-hook-form@7.66.0(react@18.3.1))
|
||||
'@next/third-parties':
|
||||
specifier: 15.4.6
|
||||
version: 15.4.6(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
|
||||
version: 15.4.6(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
|
||||
'@phosphor-icons/react':
|
||||
specifier: 2.1.10
|
||||
version: 2.1.10(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
@@ -107,7 +107,7 @@ importers:
|
||||
version: 6.1.2(@rjsf/utils@6.1.2(react@18.3.1))
|
||||
'@sentry/nextjs':
|
||||
specifier: 10.27.0
|
||||
version: 10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.104.1(esbuild@0.25.12))
|
||||
version: 10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.104.1(esbuild@0.25.12))
|
||||
'@streamdown/cjk':
|
||||
specifier: 1.0.1
|
||||
version: 1.0.1(@types/mdast@4.0.4)(micromark-util-types@2.0.2)(micromark@4.0.2)(react@18.3.1)(unified@11.0.5)
|
||||
@@ -134,10 +134,10 @@ importers:
|
||||
version: 0.2.4
|
||||
'@vercel/analytics':
|
||||
specifier: 1.5.0
|
||||
version: 1.5.0(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
|
||||
version: 1.5.0(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
|
||||
'@vercel/speed-insights':
|
||||
specifier: 1.2.0
|
||||
version: 1.2.0(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
|
||||
version: 1.2.0(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
|
||||
'@xyflow/react':
|
||||
specifier: 12.9.2
|
||||
version: 12.9.2(@types/react@18.3.17)(immer@11.1.3)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
@@ -185,13 +185,16 @@ importers:
|
||||
version: 12.23.24(@emotion/is-prop-valid@1.2.2)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
geist:
|
||||
specifier: 1.5.1
|
||||
version: 1.5.1(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))
|
||||
version: 1.5.1(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))
|
||||
highlight.js:
|
||||
specifier: 11.11.1
|
||||
version: 11.11.1
|
||||
jaro-winkler:
|
||||
specifier: 0.2.8
|
||||
version: 0.2.8
|
||||
jszip:
|
||||
specifier: 3.10.1
|
||||
version: 3.10.1
|
||||
katex:
|
||||
specifier: 0.16.25
|
||||
version: 0.16.25
|
||||
@@ -205,14 +208,14 @@ importers:
|
||||
specifier: 0.552.0
|
||||
version: 0.552.0(react@18.3.1)
|
||||
next:
|
||||
specifier: 15.4.10
|
||||
version: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
specifier: 15.4.11
|
||||
version: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
next-themes:
|
||||
specifier: 0.4.6
|
||||
version: 0.4.6(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
nuqs:
|
||||
specifier: 2.7.2
|
||||
version: 2.7.2(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
|
||||
version: 2.7.2(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
|
||||
posthog-js:
|
||||
specifier: 1.334.1
|
||||
version: 1.334.1
|
||||
@@ -330,7 +333,7 @@ importers:
|
||||
version: 9.1.5(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2)(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2)))
|
||||
'@storybook/nextjs':
|
||||
specifier: 9.1.5
|
||||
version: 9.1.5(esbuild@0.25.12)(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2)(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2)))(type-fest@4.41.0)(typescript@5.9.3)(webpack-hot-middleware@2.26.1)(webpack@5.104.1(esbuild@0.25.12))
|
||||
version: 9.1.5(esbuild@0.25.12)(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2)(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2)))(type-fest@4.41.0)(typescript@5.9.3)(webpack-hot-middleware@2.26.1)(webpack@5.104.1(esbuild@0.25.12))
|
||||
'@tanstack/eslint-plugin-query':
|
||||
specifier: 5.91.2
|
||||
version: 5.91.2(eslint@8.57.1)(typescript@5.9.3)
|
||||
@@ -1844,8 +1847,8 @@ packages:
|
||||
'@neoconfetti/react@1.0.0':
|
||||
resolution: {integrity: sha512-klcSooChXXOzIm+SE5IISIAn3bYzYfPjbX7D7HoqZL84oAfgREeSg5vSIaSFH+DaGzzvImTyWe1OyrJ67vik4A==}
|
||||
|
||||
'@next/env@15.4.10':
|
||||
resolution: {integrity: sha512-knhmoJ0Vv7VRf6pZEPSnciUG1S4bIhWx+qTYBW/AjxEtlzsiNORPk8sFDCEvqLfmKuey56UB9FL1UdHEV3uBrg==}
|
||||
'@next/env@15.4.11':
|
||||
resolution: {integrity: sha512-mIYp/091eYfPFezKX7ZPTWqrmSXq+ih6+LcUyKvLmeLQGhlPtot33kuEOd4U+xAA7sFfj21+OtCpIZx0g5SpvQ==}
|
||||
|
||||
'@next/eslint-plugin-next@15.5.7':
|
||||
resolution: {integrity: sha512-DtRU2N7BkGr8r+pExfuWHwMEPX5SD57FeA6pxdgCHODo+b/UgIgjE+rgWKtJAbEbGhVZ2jtHn4g3wNhWFoNBQQ==}
|
||||
@@ -5919,6 +5922,9 @@ packages:
|
||||
engines: {node: '>=16.x'}
|
||||
hasBin: true
|
||||
|
||||
immediate@3.0.6:
|
||||
resolution: {integrity: sha512-XXOFtyqDjNDAQxVfYxuF7g9Il/IbWmmlQg2MYKOH8ExIT1qg6xc4zyS3HaEEATgs1btfzxq15ciUiY7gjSXRGQ==}
|
||||
|
||||
immer@10.2.0:
|
||||
resolution: {integrity: sha512-d/+XTN3zfODyjr89gM3mPq1WNX2B8pYsu7eORitdwyA2sBubnTl3laYlBk4sXY5FUa5qTZGBDPJICVbvqzjlbw==}
|
||||
|
||||
@@ -6279,6 +6285,9 @@ packages:
|
||||
resolution: {integrity: sha512-ZZow9HBI5O6EPgSJLUb8n2NKgmVWTwCvHGwFuJlMjvLFqlGG6pjirPhtdsseaLZjSibD8eegzmYpUZwoIlj2cQ==}
|
||||
engines: {node: '>=4.0'}
|
||||
|
||||
jszip@3.10.1:
|
||||
resolution: {integrity: sha512-xXDvecyTpGLrqFrvkrUSoxxfJI5AH7U8zxxtVclpsUtMCq4JQ290LY8AW5c7Ggnr/Y/oK+bQMbqK2qmtk3pN4g==}
|
||||
|
||||
junit-report-builder@5.1.1:
|
||||
resolution: {integrity: sha512-ZNOIIGMzqCGcHQEA2Q4rIQQ3Df6gSIfne+X9Rly9Bc2y55KxAZu8iGv+n2pP0bLf0XAOctJZgeloC54hWzCahQ==}
|
||||
engines: {node: '>=16'}
|
||||
@@ -6348,6 +6357,9 @@ packages:
|
||||
resolution: {integrity: sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==}
|
||||
engines: {node: '>= 0.8.0'}
|
||||
|
||||
lie@3.3.0:
|
||||
resolution: {integrity: sha512-UaiMJzeWRlEujzAuw5LokY1L5ecNQYZKfmyZ9L7wDHb/p5etKaxXhohBcrw0EYby+G/NA52vRSN4N39dxHAIwQ==}
|
||||
|
||||
lilconfig@3.1.3:
|
||||
resolution: {integrity: sha512-/vlFKAoH5Cgt3Ie+JLhRbwOsCQePABiU3tJ1egGvyQ+33R/vcwM2Zl2QR/LzjsBeItPt3oSVXapn+m4nQDvpzw==}
|
||||
engines: {node: '>=14'}
|
||||
@@ -6839,8 +6851,8 @@ packages:
|
||||
react: ^16.8 || ^17 || ^18 || ^19 || ^19.0.0-rc
|
||||
react-dom: ^16.8 || ^17 || ^18 || ^19 || ^19.0.0-rc
|
||||
|
||||
next@15.4.10:
|
||||
resolution: {integrity: sha512-itVlc79QjpKMFMRhP+kbGKaSG/gZM6RCvwhEbwmCNF06CdDiNaoHcbeg0PqkEa2GOcn8KJ0nnc7+yL7EjoYLHQ==}
|
||||
next@15.4.11:
|
||||
resolution: {integrity: sha512-IJRyXal45mIsshZI5XJne/intjusslUP1F+FHVBIyMGEqbYtIq1Irdx5vdWBBg58smviPDycmDeV6txsfkv1RQ==}
|
||||
engines: {node: ^18.18.0 || ^19.8.0 || >= 20.0.0}
|
||||
hasBin: true
|
||||
peerDependencies:
|
||||
@@ -10423,7 +10435,7 @@ snapshots:
|
||||
|
||||
'@neoconfetti/react@1.0.0': {}
|
||||
|
||||
'@next/env@15.4.10': {}
|
||||
'@next/env@15.4.11': {}
|
||||
|
||||
'@next/eslint-plugin-next@15.5.7':
|
||||
dependencies:
|
||||
@@ -10453,9 +10465,9 @@ snapshots:
|
||||
'@next/swc-win32-x64-msvc@15.4.8':
|
||||
optional: true
|
||||
|
||||
'@next/third-parties@15.4.6(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
|
||||
'@next/third-parties@15.4.6(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
|
||||
dependencies:
|
||||
next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
react: 18.3.1
|
||||
third-party-capital: 1.0.20
|
||||
|
||||
@@ -11770,7 +11782,7 @@ snapshots:
|
||||
|
||||
'@sentry/core@10.27.0': {}
|
||||
|
||||
'@sentry/nextjs@10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.104.1(esbuild@0.25.12))':
|
||||
'@sentry/nextjs@10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.104.1(esbuild@0.25.12))':
|
||||
dependencies:
|
||||
'@opentelemetry/api': 1.9.0
|
||||
'@opentelemetry/semantic-conventions': 1.38.0
|
||||
@@ -11783,7 +11795,7 @@ snapshots:
|
||||
'@sentry/react': 10.27.0(react@18.3.1)
|
||||
'@sentry/vercel-edge': 10.27.0
|
||||
'@sentry/webpack-plugin': 4.6.1(webpack@5.104.1(esbuild@0.25.12))
|
||||
next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
resolve: 1.22.8
|
||||
rollup: 4.55.1
|
||||
stacktrace-parser: 0.1.11
|
||||
@@ -12162,7 +12174,7 @@ snapshots:
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
|
||||
'@storybook/nextjs@9.1.5(esbuild@0.25.12)(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2)(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2)))(type-fest@4.41.0)(typescript@5.9.3)(webpack-hot-middleware@2.26.1)(webpack@5.104.1(esbuild@0.25.12))':
|
||||
'@storybook/nextjs@9.1.5(esbuild@0.25.12)(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2)(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2)))(type-fest@4.41.0)(typescript@5.9.3)(webpack-hot-middleware@2.26.1)(webpack@5.104.1(esbuild@0.25.12))':
|
||||
dependencies:
|
||||
'@babel/core': 7.28.5
|
||||
'@babel/plugin-syntax-bigint': 7.8.3(@babel/core@7.28.5)
|
||||
@@ -12186,7 +12198,7 @@ snapshots:
|
||||
css-loader: 6.11.0(webpack@5.104.1(esbuild@0.25.12))
|
||||
image-size: 2.0.2
|
||||
loader-utils: 3.3.1
|
||||
next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
node-polyfill-webpack-plugin: 2.0.1(webpack@5.104.1(esbuild@0.25.12))
|
||||
postcss: 8.5.6
|
||||
postcss-loader: 8.2.0(postcss@8.5.6)(typescript@5.9.3)(webpack@5.104.1(esbuild@0.25.12))
|
||||
@@ -12872,16 +12884,16 @@ snapshots:
|
||||
'@unrs/resolver-binding-win32-x64-msvc@1.11.1':
|
||||
optional: true
|
||||
|
||||
'@vercel/analytics@1.5.0(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
|
||||
'@vercel/analytics@1.5.0(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
|
||||
optionalDependencies:
|
||||
next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
react: 18.3.1
|
||||
|
||||
'@vercel/oidc@3.1.0': {}
|
||||
|
||||
'@vercel/speed-insights@1.2.0(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
|
||||
'@vercel/speed-insights@1.2.0(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
|
||||
optionalDependencies:
|
||||
next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
react: 18.3.1
|
||||
|
||||
'@vitejs/plugin-react@5.1.2(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2))':
|
||||
@@ -14449,8 +14461,8 @@ snapshots:
|
||||
'@typescript-eslint/parser': 8.52.0(eslint@8.57.1)(typescript@5.9.3)
|
||||
eslint: 8.57.1
|
||||
eslint-import-resolver-node: 0.3.9
|
||||
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1)
|
||||
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
|
||||
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1)
|
||||
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
|
||||
eslint-plugin-jsx-a11y: 6.10.2(eslint@8.57.1)
|
||||
eslint-plugin-react: 7.37.5(eslint@8.57.1)
|
||||
eslint-plugin-react-hooks: 5.2.0(eslint@8.57.1)
|
||||
@@ -14469,7 +14481,7 @@ snapshots:
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1):
|
||||
eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1):
|
||||
dependencies:
|
||||
'@nolyfill/is-core-module': 1.0.39
|
||||
debug: 4.4.3
|
||||
@@ -14480,22 +14492,22 @@ snapshots:
|
||||
tinyglobby: 0.2.15
|
||||
unrs-resolver: 1.11.1
|
||||
optionalDependencies:
|
||||
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
|
||||
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
eslint-module-utils@2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1):
|
||||
eslint-module-utils@2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1):
|
||||
dependencies:
|
||||
debug: 3.2.7
|
||||
optionalDependencies:
|
||||
'@typescript-eslint/parser': 8.52.0(eslint@8.57.1)(typescript@5.9.3)
|
||||
eslint: 8.57.1
|
||||
eslint-import-resolver-node: 0.3.9
|
||||
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1)
|
||||
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1):
|
||||
eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1):
|
||||
dependencies:
|
||||
'@rtsao/scc': 1.1.0
|
||||
array-includes: 3.1.9
|
||||
@@ -14506,7 +14518,7 @@ snapshots:
|
||||
doctrine: 2.1.0
|
||||
eslint: 8.57.1
|
||||
eslint-import-resolver-node: 0.3.9
|
||||
eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
|
||||
eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
|
||||
hasown: 2.0.2
|
||||
is-core-module: 2.16.1
|
||||
is-glob: 4.0.3
|
||||
@@ -14877,9 +14889,9 @@ snapshots:
|
||||
|
||||
functions-have-names@1.2.3: {}
|
||||
|
||||
geist@1.5.1(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)):
|
||||
geist@1.5.1(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)):
|
||||
dependencies:
|
||||
next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
|
||||
generator-function@2.0.1: {}
|
||||
|
||||
@@ -15290,6 +15302,8 @@ snapshots:
|
||||
|
||||
image-size@2.0.2: {}
|
||||
|
||||
immediate@3.0.6: {}
|
||||
|
||||
immer@10.2.0: {}
|
||||
|
||||
immer@11.1.3: {}
|
||||
@@ -15646,6 +15660,13 @@ snapshots:
|
||||
object.assign: 4.1.7
|
||||
object.values: 1.2.1
|
||||
|
||||
jszip@3.10.1:
|
||||
dependencies:
|
||||
lie: 3.3.0
|
||||
pako: 1.0.11
|
||||
readable-stream: 2.3.8
|
||||
setimmediate: 1.0.5
|
||||
|
||||
junit-report-builder@5.1.1:
|
||||
dependencies:
|
||||
lodash: 4.17.21
|
||||
@@ -15739,6 +15760,10 @@ snapshots:
|
||||
prelude-ls: 1.2.1
|
||||
type-check: 0.4.0
|
||||
|
||||
lie@3.3.0:
|
||||
dependencies:
|
||||
immediate: 3.0.6
|
||||
|
||||
lilconfig@3.1.3: {}
|
||||
|
||||
lines-and-columns@1.2.4: {}
|
||||
@@ -16465,9 +16490,9 @@ snapshots:
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
|
||||
next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
|
||||
next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
|
||||
dependencies:
|
||||
'@next/env': 15.4.10
|
||||
'@next/env': 15.4.11
|
||||
'@swc/helpers': 0.5.15
|
||||
caniuse-lite: 1.0.30001762
|
||||
postcss: 8.4.31
|
||||
@@ -16569,12 +16594,12 @@ snapshots:
|
||||
dependencies:
|
||||
boolbase: 1.0.0
|
||||
|
||||
nuqs@2.7.2(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1):
|
||||
nuqs@2.7.2(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1):
|
||||
dependencies:
|
||||
'@standard-schema/spec': 1.0.0
|
||||
react: 18.3.1
|
||||
optionalDependencies:
|
||||
next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
|
||||
oas-kit-common@1.0.8:
|
||||
dependencies:
|
||||
|
||||
@@ -119,7 +119,7 @@ export default function SharePage() {
|
||||
<CardTitle>Output</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<RunOutputs outputs={executionData.outputs} />
|
||||
<RunOutputs outputs={executionData.outputs} shareToken={token} />
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import type { Metadata } from "next";
|
||||
import Image from "next/image";
|
||||
import Link from "next/link";
|
||||
|
||||
export const metadata: Metadata = {
|
||||
title: "Shared Agent Run - AutoGPT",
|
||||
@@ -13,6 +15,27 @@ export default function ShareLayout({
|
||||
}) {
|
||||
return (
|
||||
<div className="min-h-screen bg-background">
|
||||
<header className="border-b border-border bg-background">
|
||||
<div className="container mx-auto flex justify-center px-4 py-4">
|
||||
<Link href="/" className="inline-block">
|
||||
<Image
|
||||
src="/autogpt-logo-dark-bg.png"
|
||||
alt="AutoGPT"
|
||||
width={120}
|
||||
height={54}
|
||||
className="hidden h-8 w-auto dark:block"
|
||||
/>
|
||||
<Image
|
||||
src="/autogpt-logo-light-bg.png"
|
||||
alt="AutoGPT"
|
||||
width={120}
|
||||
height={54}
|
||||
className="block h-8 w-auto dark:hidden"
|
||||
priority
|
||||
/>
|
||||
</Link>
|
||||
</div>
|
||||
</header>
|
||||
<div className="container mx-auto px-4 py-8">{children}</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -6,9 +6,11 @@ import { Suspense, useState } from "react";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import type { ArtifactRef } from "../../../store";
|
||||
import type { ArtifactClassification } from "../helpers";
|
||||
import { ArtifactErrorBoundary } from "./ArtifactErrorBoundary";
|
||||
import { ArtifactReactPreview } from "./ArtifactReactPreview";
|
||||
import { ArtifactSkeleton } from "./ArtifactSkeleton";
|
||||
import {
|
||||
FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
|
||||
TAILWIND_CDN_URL,
|
||||
wrapWithHeadInjection,
|
||||
} from "@/lib/iframe-sandbox-csp";
|
||||
@@ -53,20 +55,35 @@ function ArtifactContentLoader({
|
||||
|
||||
return (
|
||||
<div ref={scrollRef} className="flex-1 overflow-y-auto">
|
||||
<ArtifactRenderer
|
||||
artifact={artifact}
|
||||
content={content}
|
||||
pdfUrl={pdfUrl}
|
||||
isSourceView={isSourceView}
|
||||
classification={classification}
|
||||
/>
|
||||
<ArtifactErrorBoundary
|
||||
artifactID={artifact.id}
|
||||
artifactTitle={artifact.title}
|
||||
artifactType={classification.type}
|
||||
>
|
||||
<ArtifactRenderer
|
||||
artifact={artifact}
|
||||
content={content}
|
||||
pdfUrl={pdfUrl}
|
||||
isSourceView={isSourceView}
|
||||
classification={classification}
|
||||
/>
|
||||
</ArtifactErrorBoundary>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function withCacheBust(src: string, nonce: number): string {
|
||||
if (nonce === 0) return src;
|
||||
const sep = src.includes("?") ? "&" : "?";
|
||||
return `${src}${sep}_retry=${nonce}`;
|
||||
}
|
||||
|
||||
function ArtifactImage({ src, alt }: { src: string; alt: string }) {
|
||||
const [loaded, setLoaded] = useState(false);
|
||||
const [error, setError] = useState(false);
|
||||
// Incremented on every Try Again so the URL changes and the browser
|
||||
// can't reuse a negative-cached response (SECRT-2221).
|
||||
const [retryNonce, setRetryNonce] = useState(0);
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
@@ -80,6 +97,7 @@ function ArtifactImage({ src, alt }: { src: string; alt: string }) {
|
||||
onClick={() => {
|
||||
setError(false);
|
||||
setLoaded(false);
|
||||
setRetryNonce((n) => n + 1);
|
||||
}}
|
||||
className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
|
||||
>
|
||||
@@ -96,7 +114,7 @@ function ArtifactImage({ src, alt }: { src: string; alt: string }) {
|
||||
)}
|
||||
{/* eslint-disable-next-line @next/next/no-img-element */}
|
||||
<img
|
||||
src={src}
|
||||
src={withCacheBust(src, retryNonce)}
|
||||
alt={alt}
|
||||
className={`max-h-full max-w-full object-contain transition-opacity ${loaded ? "opacity-100" : "opacity-0"}`}
|
||||
onLoad={() => setLoaded(true)}
|
||||
@@ -109,6 +127,7 @@ function ArtifactImage({ src, alt }: { src: string; alt: string }) {
|
||||
function ArtifactVideo({ src }: { src: string }) {
|
||||
const [loaded, setLoaded] = useState(false);
|
||||
const [error, setError] = useState(false);
|
||||
const [retryNonce, setRetryNonce] = useState(0);
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
@@ -122,6 +141,7 @@ function ArtifactVideo({ src }: { src: string }) {
|
||||
onClick={() => {
|
||||
setError(false);
|
||||
setLoaded(false);
|
||||
setRetryNonce((n) => n + 1);
|
||||
}}
|
||||
className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
|
||||
>
|
||||
@@ -137,7 +157,7 @@ function ArtifactVideo({ src }: { src: string }) {
|
||||
<Skeleton className="absolute inset-4 h-[calc(100%-2rem)] w-[calc(100%-2rem)] rounded-md" />
|
||||
)}
|
||||
<video
|
||||
src={src}
|
||||
src={withCacheBust(src, retryNonce)}
|
||||
controls
|
||||
preload="metadata"
|
||||
className={`max-h-full max-w-full rounded-md transition-opacity ${loaded ? "opacity-100" : "opacity-0"}`}
|
||||
@@ -200,7 +220,10 @@ function ArtifactRenderer({
|
||||
if (classification.type === "html") {
|
||||
// Inject Tailwind CDN — no CSP (see iframe-sandbox-csp.ts for why)
|
||||
const tailwindScript = `<script src="${TAILWIND_CDN_URL}"></script>`;
|
||||
const wrapped = wrapWithHeadInjection(content, tailwindScript);
|
||||
const wrapped = wrapWithHeadInjection(
|
||||
content,
|
||||
tailwindScript + FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
|
||||
);
|
||||
return (
|
||||
<iframe
|
||||
sandbox="allow-scripts"
|
||||
|
||||
@@ -27,6 +27,10 @@ export function ArtifactDragHandle({
|
||||
minWidthRef.current = minWidth;
|
||||
maxWidthPercentRef.current = maxWidthPercent;
|
||||
|
||||
// Track the captured pointer id so pointerup can release it even after
|
||||
// React re-renders.
|
||||
const pointerIdRef = useRef<number | null>(null);
|
||||
|
||||
// Attach document listeners only while dragging, and always tear them down
|
||||
// on unmount — otherwise closing the panel mid-drag leaves listeners bound
|
||||
// to a handler that calls setState on the unmounted component.
|
||||
@@ -57,7 +61,7 @@ export function ArtifactDragHandle({
|
||||
};
|
||||
}, [isDragging]);
|
||||
|
||||
function handlePointerDown(e: React.PointerEvent) {
|
||||
function handlePointerDown(e: React.PointerEvent<HTMLDivElement>) {
|
||||
e.preventDefault();
|
||||
startXRef.current = e.clientX;
|
||||
|
||||
@@ -67,9 +71,31 @@ export function ArtifactDragHandle({
|
||||
) as HTMLElement | null;
|
||||
startWidthRef.current = panel?.offsetWidth ?? DEFAULT_PANEL_WIDTH;
|
||||
|
||||
// Capture the pointer so pointermove/pointerup still reach us when the
|
||||
// cursor drifts over sandboxed artifact iframes. Without this, the iframe
|
||||
// eats the events and the drag gets stuck (SECRT-2256).
|
||||
try {
|
||||
e.currentTarget.setPointerCapture(e.pointerId);
|
||||
pointerIdRef.current = e.pointerId;
|
||||
} catch {
|
||||
// Non-supporting environments (older test DOMs) — safe to ignore.
|
||||
}
|
||||
|
||||
setIsDragging(true);
|
||||
}
|
||||
|
||||
function handlePointerUp(e: React.PointerEvent<HTMLDivElement>) {
|
||||
if (pointerIdRef.current != null) {
|
||||
try {
|
||||
e.currentTarget.releasePointerCapture(pointerIdRef.current);
|
||||
} catch {
|
||||
// Capture may already be released.
|
||||
}
|
||||
pointerIdRef.current = null;
|
||||
}
|
||||
setIsDragging(false);
|
||||
}
|
||||
|
||||
return (
|
||||
// 12px transparent hit target with the visible 1px line centered inside
|
||||
// (WCAG-compliant, matches ~8-12px conventions of other resizable panels).
|
||||
@@ -81,6 +107,9 @@ export function ArtifactDragHandle({
|
||||
"group absolute -left-1.5 top-0 z-10 flex h-full w-3 cursor-col-resize items-stretch justify-center",
|
||||
)}
|
||||
onPointerDown={handlePointerDown}
|
||||
onPointerUp={handlePointerUp}
|
||||
onPointerCancel={handlePointerUp}
|
||||
style={{ touchAction: "none" }}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
"use client";
|
||||
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import { Component, type ErrorInfo, type ReactNode } from "react";
|
||||
|
||||
interface Props {
|
||||
children: ReactNode;
|
||||
artifactID: string;
|
||||
artifactTitle: string;
|
||||
artifactType: string;
|
||||
}
|
||||
|
||||
interface State {
|
||||
error: Error | null;
|
||||
}
|
||||
|
||||
export class ArtifactErrorBoundary extends Component<Props, State> {
|
||||
state: State = { error: null };
|
||||
|
||||
static getDerivedStateFromError(error: Error): State {
|
||||
return { error };
|
||||
}
|
||||
|
||||
componentDidCatch(error: Error, errorInfo: ErrorInfo) {
|
||||
Sentry.captureException(error, {
|
||||
contexts: {
|
||||
react: { componentStack: errorInfo.componentStack },
|
||||
},
|
||||
tags: { errorBoundary: "true", context: "copilot-artifact" },
|
||||
extra: {
|
||||
artifactID: this.props.artifactID,
|
||||
artifactTitle: this.props.artifactTitle,
|
||||
artifactType: this.props.artifactType,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
componentDidUpdate(prevProps: Props) {
|
||||
if (
|
||||
this.state.error &&
|
||||
(prevProps.artifactID !== this.props.artifactID ||
|
||||
prevProps.artifactTitle !== this.props.artifactTitle ||
|
||||
prevProps.artifactType !== this.props.artifactType)
|
||||
) {
|
||||
this.setState({ error: null });
|
||||
}
|
||||
}
|
||||
|
||||
handleCopy = () => {
|
||||
const { error } = this.state;
|
||||
if (!error) return;
|
||||
const details = [
|
||||
`Artifact: ${this.props.artifactTitle}`,
|
||||
`ID: ${this.props.artifactID}`,
|
||||
`Type: ${this.props.artifactType}`,
|
||||
`Error: ${error.message}`,
|
||||
error.stack ? `Stack:\n${error.stack}` : "",
|
||||
]
|
||||
.filter(Boolean)
|
||||
.join("\n");
|
||||
navigator.clipboard?.writeText(details).catch(() => {});
|
||||
};
|
||||
|
||||
render() {
|
||||
const { error } = this.state;
|
||||
if (!error) return this.props.children;
|
||||
|
||||
const message = error.message || "Unknown rendering error";
|
||||
|
||||
return (
|
||||
<div
|
||||
role="alert"
|
||||
className="flex h-full flex-col items-center justify-center gap-3 p-8 text-center"
|
||||
>
|
||||
<p className="text-sm font-medium text-zinc-700">
|
||||
This artifact couldn't be rendered
|
||||
</p>
|
||||
<p className="max-w-md break-words text-xs text-zinc-500">
|
||||
Something in{" "}
|
||||
<span className="font-mono">{this.props.artifactTitle}</span> threw an
|
||||
error while rendering. The chat and sidebar are still working.
|
||||
</p>
|
||||
<pre className="max-h-32 max-w-md overflow-auto whitespace-pre-wrap break-words rounded-md bg-zinc-100 px-3 py-2 text-left text-xs text-zinc-700">
|
||||
{message}
|
||||
</pre>
|
||||
<button
|
||||
type="button"
|
||||
onClick={this.handleCopy}
|
||||
className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
|
||||
>
|
||||
Copy error details
|
||||
</button>
|
||||
<p className="max-w-md text-xs text-zinc-400">
|
||||
Paste this into the chat so the agent can regenerate a working
|
||||
version.
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -199,6 +199,82 @@ describe("ArtifactContent", () => {
|
||||
});
|
||||
});
|
||||
|
||||
// SECRT-2221 integration: the classification-level fix (hi-res PNGs stop
|
||||
// being size-gated) only matters if the end-to-end rendering pipeline
|
||||
// actually reaches the <img> path. Pass in the real classifyArtifact
|
||||
// result for a 25 MB .png and assert the panel renders an img element
|
||||
// rather than routing to the download-only surface.
|
||||
it("renders a 25 MB PNG through the <img> path, not download-only (SECRT-2221)", () => {
|
||||
const artifact = makeArtifact({
|
||||
id: "hires-png-001",
|
||||
title: "poster.png",
|
||||
mimeType: "image/png",
|
||||
sourceUrl: "/api/proxy/api/workspace/files/hires-png-001/download",
|
||||
sizeBytes: 25 * 1024 * 1024,
|
||||
});
|
||||
const classification = classifyArtifact(
|
||||
artifact.mimeType,
|
||||
artifact.title,
|
||||
artifact.sizeBytes,
|
||||
);
|
||||
expect(classification.type).toBe("image");
|
||||
expect(classification.openable).toBe(true);
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const img = container.querySelector("img");
|
||||
expect(img).toBeTruthy();
|
||||
expect(img?.getAttribute("src")).toBe(artifact.sourceUrl);
|
||||
});
|
||||
|
||||
// SECRT-2221: image retry appends a cache-busting query so the browser
|
||||
// can't reuse a previously-failed response. Without this, a transient
|
||||
// 5xx that gets negative-cached keeps showing "Failed to load image" no
|
||||
// matter how many times the user clicks Try again.
|
||||
it("image retry appends a cache-busting query so the browser re-fetches (SECRT-2221)", async () => {
|
||||
const artifact = makeArtifact({
|
||||
id: "img-cachebust",
|
||||
title: "hires.png",
|
||||
mimeType: "image/png",
|
||||
sourceUrl: "/api/proxy/api/workspace/files/img-cachebust/download",
|
||||
});
|
||||
const classification = makeClassification({ type: "image" });
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const firstImg = container.querySelector("img");
|
||||
const firstSrc = firstImg?.getAttribute("src");
|
||||
expect(firstSrc).toBe(artifact.sourceUrl);
|
||||
|
||||
fireEvent.error(firstImg!);
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText("Failed to load image")).toBeTruthy();
|
||||
});
|
||||
fireEvent.click(screen.getByRole("button", { name: /try again/i }));
|
||||
|
||||
await waitFor(() => {
|
||||
const nextImg = container.querySelector("img");
|
||||
const nextSrc = nextImg?.getAttribute("src") ?? "";
|
||||
expect(nextSrc).not.toBe(firstSrc);
|
||||
expect(nextSrc.startsWith(artifact.sourceUrl)).toBe(true);
|
||||
// Assert the specific cache-bust contract, not just that the URL
|
||||
// changed — guards against accidental rewrites that drop the key.
|
||||
expect(nextSrc).toContain("_retry=");
|
||||
});
|
||||
});
|
||||
|
||||
// ── Video ─────────────────────────────────────────────────────────
|
||||
|
||||
it("renders video artifact with video tag and controls", () => {
|
||||
@@ -379,6 +455,117 @@ describe("ArtifactContent", () => {
|
||||
expect(retryButtons.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
// SECRT-2224: "try again doesn't do anything". The retry itself works — the
|
||||
// user's complaint is that there's no visible feedback when the same error
|
||||
// returns (e.g. a 404 for a deleted file). Clicking Try Again must flip the
|
||||
// UI into the loading skeleton immediately so the user can tell their click
|
||||
// registered, instead of the error UI re-flashing in place.
|
||||
it("clicking Try Again shows the loading skeleton before the next fetch settles (SECRT-2224)", async () => {
|
||||
let resolveSecond: (value: unknown) => void = () => {};
|
||||
let callCount = 0;
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockImplementation(() => {
|
||||
callCount++;
|
||||
if (callCount === 1) {
|
||||
return Promise.resolve({
|
||||
ok: false,
|
||||
status: 404,
|
||||
text: () => Promise.resolve("Not found"),
|
||||
});
|
||||
}
|
||||
return new Promise((resolve) => {
|
||||
resolveSecond = resolve;
|
||||
});
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "retry-skeleton-001",
|
||||
title: "flaky.html",
|
||||
mimeType: "text/html",
|
||||
});
|
||||
const classification = makeClassification({ type: "html" });
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
await screen.findByText("Failed to load content");
|
||||
fireEvent.click(screen.getByRole("button", { name: /try again/i }));
|
||||
|
||||
// Before the second fetch resolves, the error must be gone and a skeleton
|
||||
// visible (animate-pulse is the Skeleton component's signature class).
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText("Failed to load content")).toBeNull();
|
||||
expect(container.querySelector('[class*="animate-pulse"]')).toBeTruthy();
|
||||
});
|
||||
|
||||
// Let the second fetch complete and wait for the recovered render so
|
||||
// pending React updates can't leak into the next test.
|
||||
resolveSecond({
|
||||
ok: true,
|
||||
text: () => Promise.resolve("<html><body>ok</body></html>"),
|
||||
});
|
||||
await screen.findByTitle("flaky.html");
|
||||
});
|
||||
|
||||
// SECRT-2224 end-to-end: Try Again actually recovers when the next fetch
|
||||
// succeeds. Covers the full click → re-fetch → iframe-render loop.
|
||||
it("clicking Try Again re-fetches and renders recovered HTML content (SECRT-2224)", async () => {
|
||||
let callCount = 0;
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockImplementation(() => {
|
||||
callCount++;
|
||||
if (callCount === 1) {
|
||||
return Promise.resolve({
|
||||
ok: false,
|
||||
status: 404,
|
||||
text: () => Promise.resolve("Not found"),
|
||||
});
|
||||
}
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
text: () =>
|
||||
Promise.resolve(
|
||||
"<html><body><h1 id='ok'>recovered</h1></body></html>",
|
||||
),
|
||||
});
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "retry-recover-001",
|
||||
title: "flaky.html",
|
||||
mimeType: "text/html",
|
||||
});
|
||||
const classification = makeClassification({ type: "html" });
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
await screen.findByText("Failed to load content");
|
||||
fireEvent.click(screen.getByRole("button", { name: /try again/i }));
|
||||
|
||||
await waitFor(() => {
|
||||
const iframe = container.querySelector("iframe");
|
||||
expect(iframe).toBeTruthy();
|
||||
expect(iframe?.getAttribute("srcdoc")).toContain("recovered");
|
||||
});
|
||||
expect(screen.queryByText("Failed to load content")).toBeNull();
|
||||
expect(callCount).toBeGreaterThanOrEqual(2);
|
||||
});
|
||||
|
||||
// ── HTML ──────────────────────────────────────────────────────────
|
||||
|
||||
it("renders HTML content in sandboxed iframe", async () => {
|
||||
@@ -412,6 +599,41 @@ describe("ArtifactContent", () => {
|
||||
expect(iframe?.getAttribute("sandbox")).toBe("allow-scripts");
|
||||
});
|
||||
|
||||
it("injects the fragment-link interceptor into HTML artifact iframes (regression)", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () =>
|
||||
Promise.resolve(
|
||||
'<html><head></head><body><a href="#x">x</a><div id="x">x</div></body></html>',
|
||||
),
|
||||
}),
|
||||
);
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={makeArtifact({
|
||||
id: "html-frag",
|
||||
title: "page.html",
|
||||
mimeType: "text/html",
|
||||
})}
|
||||
isSourceView={false}
|
||||
classification={makeClassification({ type: "html" })}
|
||||
/>,
|
||||
);
|
||||
|
||||
await screen.findByTitle("page.html");
|
||||
const srcdoc = container.querySelector("iframe")?.getAttribute("srcdoc");
|
||||
expect(srcdoc).toBeTruthy();
|
||||
// Markers unique to FRAGMENT_LINK_INTERCEPTOR_SCRIPT — if any of these
|
||||
// disappear, the interceptor is no longer being injected and fragment
|
||||
// links will navigate the parent URL again.
|
||||
expect(srcdoc).toContain("__fragmentLinkInterceptor");
|
||||
expect(srcdoc).toContain('a[href^="#"]');
|
||||
expect(srcdoc).toContain("scrollIntoView");
|
||||
});
|
||||
|
||||
// ── Source view ───────────────────────────────────────────────────
|
||||
|
||||
it("renders source view as pre tag", async () => {
|
||||
@@ -923,6 +1145,239 @@ describe("ArtifactContent", () => {
|
||||
},
|
||||
);
|
||||
|
||||
// ── Error boundary ────────────────────────────────────────────────
|
||||
|
||||
it("shows a visible error instead of crashing when the renderer throws", async () => {
|
||||
const consoleErr = vi.spyOn(console, "error").mockImplementation(() => {});
|
||||
const originalImpl = vi
|
||||
.mocked(ArtifactReactPreview)
|
||||
.getMockImplementation();
|
||||
vi.mocked(ArtifactReactPreview).mockImplementation(() => {
|
||||
throw new Error("boom in renderer");
|
||||
});
|
||||
|
||||
try {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve("source"),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "crash-001",
|
||||
title: "broken.tsx",
|
||||
mimeType: "text/tsx",
|
||||
});
|
||||
const classification = makeClassification({ type: "react" });
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(
|
||||
await screen.findByText(/This artifact couldn't be rendered/i),
|
||||
).toBeTruthy();
|
||||
expect(screen.getByText(/boom in renderer/)).toBeTruthy();
|
||||
expect(
|
||||
screen.getByRole("button", { name: /copy error details/i }),
|
||||
).toBeTruthy();
|
||||
} finally {
|
||||
if (originalImpl) {
|
||||
vi.mocked(ArtifactReactPreview).mockImplementation(originalImpl);
|
||||
}
|
||||
consoleErr.mockRestore();
|
||||
}
|
||||
});
|
||||
|
||||
it("copies artifact title, type, and error to the clipboard", async () => {
|
||||
const consoleErr = vi.spyOn(console, "error").mockImplementation(() => {});
|
||||
const writeText = vi.fn().mockResolvedValue(undefined);
|
||||
Object.defineProperty(navigator, "clipboard", {
|
||||
value: { writeText },
|
||||
writable: true,
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
const originalImpl = vi
|
||||
.mocked(ArtifactReactPreview)
|
||||
.getMockImplementation();
|
||||
vi.mocked(ArtifactReactPreview).mockImplementation(() => {
|
||||
throw new Error("jsx parse failed at line 42");
|
||||
});
|
||||
|
||||
try {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve("source"),
|
||||
}),
|
||||
);
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={makeArtifact({
|
||||
id: "crash-002",
|
||||
title: "report.tsx",
|
||||
mimeType: "text/tsx",
|
||||
})}
|
||||
isSourceView={false}
|
||||
classification={makeClassification({ type: "react" })}
|
||||
/>,
|
||||
);
|
||||
|
||||
fireEvent.click(
|
||||
await screen.findByRole("button", { name: /copy error details/i }),
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(writeText).toHaveBeenCalled();
|
||||
});
|
||||
const payload = writeText.mock.calls[0]![0] as string;
|
||||
expect(payload).toContain("report.tsx");
|
||||
expect(payload).toContain("crash-002");
|
||||
expect(payload).toContain("react");
|
||||
expect(payload).toContain("jsx parse failed at line 42");
|
||||
} finally {
|
||||
if (originalImpl) {
|
||||
vi.mocked(ArtifactReactPreview).mockImplementation(originalImpl);
|
||||
}
|
||||
consoleErr.mockRestore();
|
||||
}
|
||||
});
|
||||
|
||||
// Regression: two different artifacts can share the same title+type (e.g.
|
||||
// two "App.tsx" files from different sessions). The boundary must reset
|
||||
// when artifact.id changes, not only on title/type changes, otherwise
|
||||
// opening a second artifact after a crash stays stuck on the first's error.
|
||||
it("resets the error fallback when the artifact id changes (same title/type)", async () => {
|
||||
const consoleErr = vi.spyOn(console, "error").mockImplementation(() => {});
|
||||
const originalImpl = vi
|
||||
.mocked(ArtifactReactPreview)
|
||||
.getMockImplementation();
|
||||
|
||||
// First render: throws.
|
||||
vi.mocked(ArtifactReactPreview).mockImplementation(() => {
|
||||
throw new Error("first render boom");
|
||||
});
|
||||
|
||||
try {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve("source"),
|
||||
}),
|
||||
);
|
||||
const classification = makeClassification({ type: "react" });
|
||||
|
||||
const { rerender } = render(
|
||||
<ArtifactContent
|
||||
artifact={makeArtifact({
|
||||
id: "id-one",
|
||||
title: "App.tsx",
|
||||
mimeType: "text/tsx",
|
||||
})}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
await screen.findByText(/This artifact couldn't be rendered/i);
|
||||
|
||||
// Swap in a working renderer and a different artifact id (same title/type).
|
||||
if (originalImpl) {
|
||||
vi.mocked(ArtifactReactPreview).mockImplementation(originalImpl);
|
||||
}
|
||||
|
||||
rerender(
|
||||
<ArtifactContent
|
||||
artifact={makeArtifact({
|
||||
id: "id-two",
|
||||
title: "App.tsx",
|
||||
mimeType: "text/tsx",
|
||||
})}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.queryByText(/This artifact couldn't be rendered/i),
|
||||
).toBeNull();
|
||||
expect(screen.getByTestId("react-preview")).toBeTruthy();
|
||||
});
|
||||
} finally {
|
||||
if (originalImpl) {
|
||||
vi.mocked(ArtifactReactPreview).mockImplementation(originalImpl);
|
||||
}
|
||||
consoleErr.mockRestore();
|
||||
}
|
||||
});
|
||||
|
||||
it("renders the user-reported plotly HTML artifact into a sandboxed iframe", async () => {
|
||||
const html = `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>AutoGPT Beta Launch Interactive Report</title>
|
||||
<script src="https://cdn.plot.ly/plotly-2.27.0.min.js"></script>
|
||||
<style>
|
||||
:root { --bg: #f8f9fa; --primary: #6c5ce7; }
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
body { font-family: 'Segoe UI', system-ui, sans-serif; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header><h1>\u{1F4CA} AutoGPT Beta Launch Interactive Report</h1></header>
|
||||
<div class="chart-container" id="globalActivationChart"></div>
|
||||
<script>
|
||||
function showTab(tabId, groupId) {
|
||||
const group = document.getElementById(groupId);
|
||||
group.querySelectorAll('.tab-content').forEach(t => t.classList.remove('active'));
|
||||
document.getElementById(tabId).classList.add('active');
|
||||
}
|
||||
Plotly.newPlot('globalActivationChart', [{ type: 'pie', values: [1, 2] }], {});
|
||||
</script>
|
||||
</body>
|
||||
</html>`;
|
||||
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(html),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "html-big-report",
|
||||
title: "report.html",
|
||||
mimeType: "text/html",
|
||||
});
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={makeClassification({ type: "html" })}
|
||||
/>,
|
||||
);
|
||||
|
||||
await screen.findByTitle("report.html");
|
||||
const iframe = container.querySelector("iframe");
|
||||
expect(iframe).toBeTruthy();
|
||||
expect(iframe?.getAttribute("sandbox")).toBe("allow-scripts");
|
||||
expect(screen.queryByText(/couldn't be rendered/i)).toBeNull();
|
||||
});
|
||||
|
||||
it("falls back to pre tag when no renderer matches", async () => {
|
||||
const { globalRegistry } = await import(
|
||||
"@/components/contextual/OutputRenderers"
|
||||
|
||||
@@ -0,0 +1,181 @@
|
||||
import { cleanup, fireEvent, render } from "@testing-library/react";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { ArtifactDragHandle } from "../ArtifactDragHandle";
|
||||
|
||||
function renderHandle(onWidthChange = vi.fn(), panelWidth = 600) {
|
||||
const utils = render(
|
||||
<div
|
||||
data-artifact-panel
|
||||
style={{
|
||||
width: `${panelWidth}px`,
|
||||
height: "400px",
|
||||
position: "relative",
|
||||
}}
|
||||
>
|
||||
<ArtifactDragHandle onWidthChange={onWidthChange} />
|
||||
</div>,
|
||||
);
|
||||
const panel = utils.container.querySelector(
|
||||
"[data-artifact-panel]",
|
||||
) as HTMLElement;
|
||||
// happy-dom doesn't compute layout; stub offsetWidth so the handle reads
|
||||
// the intended starting width.
|
||||
Object.defineProperty(panel, "offsetWidth", {
|
||||
value: panelWidth,
|
||||
configurable: true,
|
||||
});
|
||||
const handle = utils.container.querySelector(
|
||||
'[role="separator"]',
|
||||
) as HTMLElement;
|
||||
return { handle, onWidthChange, ...utils };
|
||||
}
|
||||
|
||||
// jsdom/happy-dom don't implement pointer capture by default — spy on the
|
||||
// prototype so vi.restoreAllMocks() can tear the spies down. We also seed
|
||||
// no-op base implementations where the prototype lacks them so vi.spyOn has
|
||||
// something to wrap. Both the seeded properties and window.innerWidth are
|
||||
// manual mutations that vi.restoreAllMocks() won't undo, so capture their
|
||||
// original descriptors and restore them in `restoreGlobals`.
|
||||
function installPointerCaptureStub() {
|
||||
const proto = HTMLElement.prototype as unknown as {
|
||||
setPointerCapture?: (id: number) => void;
|
||||
releasePointerCapture?: (id: number) => void;
|
||||
};
|
||||
const originalSetPointerCapture = Object.getOwnPropertyDescriptor(
|
||||
proto,
|
||||
"setPointerCapture",
|
||||
);
|
||||
const originalReleasePointerCapture = Object.getOwnPropertyDescriptor(
|
||||
proto,
|
||||
"releasePointerCapture",
|
||||
);
|
||||
|
||||
if (!proto.setPointerCapture) proto.setPointerCapture = () => {};
|
||||
if (!proto.releasePointerCapture) proto.releasePointerCapture = () => {};
|
||||
const setPointerCapture = vi
|
||||
.spyOn(HTMLElement.prototype, "setPointerCapture")
|
||||
.mockImplementation(() => {});
|
||||
const releasePointerCapture = vi
|
||||
.spyOn(HTMLElement.prototype, "releasePointerCapture")
|
||||
.mockImplementation(() => {});
|
||||
|
||||
function restoreGlobals() {
|
||||
if (originalSetPointerCapture) {
|
||||
Object.defineProperty(
|
||||
proto,
|
||||
"setPointerCapture",
|
||||
originalSetPointerCapture,
|
||||
);
|
||||
} else {
|
||||
delete proto.setPointerCapture;
|
||||
}
|
||||
if (originalReleasePointerCapture) {
|
||||
Object.defineProperty(
|
||||
proto,
|
||||
"releasePointerCapture",
|
||||
originalReleasePointerCapture,
|
||||
);
|
||||
} else {
|
||||
delete proto.releasePointerCapture;
|
||||
}
|
||||
}
|
||||
|
||||
return { setPointerCapture, releasePointerCapture, restoreGlobals };
|
||||
}
|
||||
|
||||
describe("ArtifactDragHandle", () => {
|
||||
let spies: ReturnType<typeof installPointerCaptureStub>;
|
||||
const originalInnerWidth = Object.getOwnPropertyDescriptor(
|
||||
window,
|
||||
"innerWidth",
|
||||
);
|
||||
|
||||
beforeEach(() => {
|
||||
spies = installPointerCaptureStub();
|
||||
Object.defineProperty(window, "innerWidth", {
|
||||
value: 1200,
|
||||
writable: true,
|
||||
configurable: true,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
vi.restoreAllMocks();
|
||||
spies.restoreGlobals();
|
||||
if (originalInnerWidth) {
|
||||
Object.defineProperty(window, "innerWidth", originalInnerWidth);
|
||||
}
|
||||
});
|
||||
|
||||
// SECRT-2256: when the cursor drifts over a sandboxed iframe mid-drag, the
|
||||
// iframe eats pointermove/pointerup and the drag gets stuck. setPointerCapture
|
||||
// routes all subsequent pointer events to the handle regardless of what's
|
||||
// under the cursor, which fixes both "can't drag right" and "drag doesn't
|
||||
// stop on release".
|
||||
it("captures the pointer on pointerdown so drags survive the cursor drifting over iframes (SECRT-2256)", () => {
|
||||
const { handle } = renderHandle();
|
||||
|
||||
fireEvent.pointerDown(handle, { clientX: 500, pointerId: 7 });
|
||||
|
||||
expect(spies.setPointerCapture).toHaveBeenCalledWith(7);
|
||||
});
|
||||
|
||||
it("releases the pointer capture when the drag ends", () => {
|
||||
const { handle } = renderHandle();
|
||||
|
||||
fireEvent.pointerDown(handle, { clientX: 500, pointerId: 7 });
|
||||
fireEvent.pointerUp(handle, { clientX: 400, pointerId: 7 });
|
||||
|
||||
expect(spies.releasePointerCapture).toHaveBeenCalledWith(7);
|
||||
});
|
||||
|
||||
it("calls onWidthChange with the expanded width when dragging leftwards", () => {
|
||||
const onWidthChange = vi.fn();
|
||||
const { handle } = renderHandle(onWidthChange);
|
||||
|
||||
fireEvent.pointerDown(handle, { clientX: 800, pointerId: 1 });
|
||||
fireEvent.pointerMove(document, { clientX: 700, pointerId: 1 });
|
||||
|
||||
// startWidth is 600 (container), delta = 800 - 700 = 100 → newWidth 700
|
||||
expect(onWidthChange).toHaveBeenCalledWith(700);
|
||||
});
|
||||
|
||||
it("calls onWidthChange with the shrunk width when dragging rightwards", () => {
|
||||
const onWidthChange = vi.fn();
|
||||
const { handle } = renderHandle(onWidthChange);
|
||||
|
||||
fireEvent.pointerDown(handle, { clientX: 800, pointerId: 1 });
|
||||
fireEvent.pointerMove(document, { clientX: 900, pointerId: 1 });
|
||||
|
||||
// delta = -100 → newWidth 500
|
||||
expect(onWidthChange).toHaveBeenCalledWith(500);
|
||||
});
|
||||
|
||||
it("clamps to minWidth and maxWidth", () => {
|
||||
const onWidthChange = vi.fn();
|
||||
const { handle } = renderHandle(onWidthChange);
|
||||
|
||||
fireEvent.pointerDown(handle, { clientX: 800, pointerId: 1 });
|
||||
|
||||
// Drag way left → want huge width, should clamp at 85% of 1200 = 1020
|
||||
fireEvent.pointerMove(document, { clientX: -5000, pointerId: 1 });
|
||||
expect(onWidthChange).toHaveBeenLastCalledWith(1020);
|
||||
|
||||
// Drag way right → want tiny width, should clamp at minWidth 320
|
||||
fireEvent.pointerMove(document, { clientX: 5000, pointerId: 1 });
|
||||
expect(onWidthChange).toHaveBeenLastCalledWith(320);
|
||||
});
|
||||
|
||||
it("stops dragging on pointerup so subsequent cursor moves don't resize", () => {
|
||||
const onWidthChange = vi.fn();
|
||||
const { handle } = renderHandle(onWidthChange);
|
||||
|
||||
fireEvent.pointerDown(handle, { clientX: 800, pointerId: 1 });
|
||||
fireEvent.pointerUp(handle, { clientX: 800, pointerId: 1 });
|
||||
onWidthChange.mockClear();
|
||||
|
||||
fireEvent.pointerMove(document, { clientX: 500, pointerId: 1 });
|
||||
expect(onWidthChange).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
@@ -3,7 +3,7 @@ import {
|
||||
buildReactArtifactSrcDoc,
|
||||
collectPreviewStyles,
|
||||
escapeHtml,
|
||||
} from "./reactArtifactPreview";
|
||||
} from "../reactArtifactPreview";
|
||||
|
||||
describe("escapeHtml", () => {
|
||||
it("escapes &, <, >, \", '", () => {
|
||||
@@ -116,4 +116,11 @@ describe("buildReactArtifactSrcDoc", () => {
|
||||
expect(doc).toContain("/^[A-Z]/.test(name)");
|
||||
expect(doc).toContain("wrapWithProviders");
|
||||
});
|
||||
|
||||
it("injects the fragment-link interceptor so #anchor clicks stay inside the iframe (regression)", () => {
|
||||
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
|
||||
expect(doc).toContain("__fragmentLinkInterceptor");
|
||||
expect(doc).toContain('a[href^="#"]');
|
||||
expect(doc).toContain("scrollIntoView");
|
||||
});
|
||||
});
|
||||
@@ -19,7 +19,10 @@
|
||||
* React is loaded from unpkg with pinned version and SRI integrity hashes.
|
||||
*/
|
||||
|
||||
import { TAILWIND_CDN_URL } from "@/lib/iframe-sandbox-csp";
|
||||
import {
|
||||
FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
|
||||
TAILWIND_CDN_URL,
|
||||
} from "@/lib/iframe-sandbox-csp";
|
||||
|
||||
export { transpileReactArtifactSource } from "./transpileReactArtifact";
|
||||
|
||||
@@ -95,6 +98,7 @@ export function buildReactArtifactSrcDoc(
|
||||
}
|
||||
</style>
|
||||
<script src="${TAILWIND_CDN_URL}"></script>
|
||||
${FRAGMENT_LINK_INTERCEPTOR_SCRIPT}
|
||||
<script crossorigin="anonymous" src="https://unpkg.com/react@18.3.1/umd/react.production.min.js" integrity="sha384-DGyLxAyjq0f9SPpVevD6IgztCFlnMF6oW/XQGmfe+IsZ8TqEiDrcHkMLKI6fiB/Z"></script><!-- pragma: allowlist secret -->
|
||||
<script crossorigin="anonymous" src="https://unpkg.com/react-dom@18.3.1/umd/react-dom.production.min.js" integrity="sha384-gTGxhz21lVGYNMcdJOyq01Edg0jhn/c22nsx0kyqP0TxaV5WVdsSH1fSDUf5YJj1"></script><!-- pragma: allowlist secret -->
|
||||
</head>
|
||||
|
||||
@@ -141,6 +141,11 @@ export function useArtifactContent(
|
||||
function retry() {
|
||||
// Drop any cached failure/content for this id so we actually re-fetch.
|
||||
contentCache.delete(artifact.id);
|
||||
// Flip into loading + clear error synchronously with the click so the
|
||||
// user always sees the skeleton (rather than the error UI re-flashing
|
||||
// instantly for same-error retries). See SECRT-2224.
|
||||
setIsLoading(true);
|
||||
setError(null);
|
||||
setRetryNonce((n) => n + 1);
|
||||
}
|
||||
|
||||
|
||||
@@ -45,12 +45,33 @@ describe("classifyArtifact", () => {
|
||||
expect(classifyArtifact("text/markdown", "x").type).toBe("markdown");
|
||||
});
|
||||
|
||||
it("gates files > 10MB to download-only", () => {
|
||||
it("gates text/code files > 10MB to download-only", () => {
|
||||
const c = classifyArtifact("text/plain", "big.txt", 20 * 1024 * 1024);
|
||||
expect(c.openable).toBe(false);
|
||||
expect(c.type).toBe("download-only");
|
||||
});
|
||||
|
||||
// SECRT-2221: large images (hi-res PNGs, etc.) were getting force-classified
|
||||
// as download-only by the generic >10MB gate, so clicking them started a
|
||||
// download instead of previewing — and the preview was "broken" in the
|
||||
// sense that it never appeared. Images, videos, and PDFs are decoded
|
||||
// natively by the browser and don't run through our JS render pipeline,
|
||||
// so the size gate shouldn't apply to them.
|
||||
it("does NOT size-gate large images, videos, or PDFs (SECRT-2221)", () => {
|
||||
expect(
|
||||
classifyArtifact("image/png", "hires.png", 25 * 1024 * 1024).type,
|
||||
).toBe("image");
|
||||
expect(
|
||||
classifyArtifact("image/jpeg", "huge.jpg", 50 * 1024 * 1024).type,
|
||||
).toBe("image");
|
||||
expect(
|
||||
classifyArtifact("video/mp4", "long.mp4", 500 * 1024 * 1024).type,
|
||||
).toBe("video");
|
||||
expect(
|
||||
classifyArtifact("application/pdf", "book.pdf", 80 * 1024 * 1024).type,
|
||||
).toBe("pdf");
|
||||
});
|
||||
|
||||
it("treats binary/octet-stream MIME as download-only", () => {
|
||||
expect(classifyArtifact("application/zip", "a.zip").openable).toBe(false);
|
||||
expect(classifyArtifact("application/octet-stream", "x").openable).toBe(
|
||||
|
||||
@@ -257,14 +257,34 @@ function getExtension(filename?: string): string {
|
||||
return filename.slice(lastDot).toLowerCase();
|
||||
}
|
||||
|
||||
// Types the browser renders natively — we don't run their bytes through our
|
||||
// React/JS pipeline, so the size gate doesn't need to apply.
|
||||
const NATIVELY_RENDERED = new Set<ArtifactClassification["type"]>([
|
||||
"image",
|
||||
"video",
|
||||
"pdf",
|
||||
]);
|
||||
|
||||
export function classifyArtifact(
|
||||
mimeType: string | null,
|
||||
filename?: string,
|
||||
sizeBytes?: number,
|
||||
): ArtifactClassification {
|
||||
// Size gate: >10MB is download-only regardless of type.
|
||||
if (sizeBytes && sizeBytes > TEN_MB) return KIND["download-only"];
|
||||
const kind = classifyByTypeOnly(mimeType, filename);
|
||||
// Size gate: >10MB is download-only, but only for content we actually
|
||||
// render in JS. Images, videos, and PDFs are handled natively by the
|
||||
// browser — gating them produced "broken previews" for hi-res files
|
||||
// (SECRT-2221).
|
||||
if (sizeBytes && sizeBytes > TEN_MB && !NATIVELY_RENDERED.has(kind.type)) {
|
||||
return KIND["download-only"];
|
||||
}
|
||||
return kind;
|
||||
}
|
||||
|
||||
function classifyByTypeOnly(
|
||||
mimeType: string | null,
|
||||
filename?: string,
|
||||
): ArtifactClassification {
|
||||
const basename = getBasename(filename);
|
||||
const exactKind = EXACT_FILENAME_KIND[basename];
|
||||
if (exactKind) return KIND[exactKind];
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { act, renderHook } from "@testing-library/react";
|
||||
import { beforeEach, describe, expect, it } from "vitest";
|
||||
import { act, cleanup, renderHook } from "@testing-library/react";
|
||||
import { afterEach, beforeEach, describe, expect, it } from "vitest";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
import { useAutoOpenArtifacts } from "./useAutoOpenArtifacts";
|
||||
|
||||
@@ -31,6 +31,11 @@ function resetStore() {
|
||||
|
||||
describe("useAutoOpenArtifacts", () => {
|
||||
beforeEach(resetStore);
|
||||
// Testing Library auto-cleanup isn't registered in our Vitest setup, so
|
||||
// mounted `renderHook` instances (and their unmount cleanups) would leak
|
||||
// between tests — here the unmount effect in useAutoOpenArtifacts would
|
||||
// fire after the next test had already run and corrupt its assertions.
|
||||
afterEach(cleanup);
|
||||
|
||||
it("does not auto-open on initial render", () => {
|
||||
renderHook(() => useAutoOpenArtifacts({ sessionId: "s1" }));
|
||||
@@ -88,4 +93,53 @@ describe("useAutoOpenArtifacts", () => {
|
||||
expect(s.activeArtifact?.id).toBe("c");
|
||||
expect(s.history).toEqual([]);
|
||||
});
|
||||
|
||||
// SECRT-2254: "had agent panel open then went to profile then went to home
|
||||
// and agent panel was still open". Nav-away unmounts the copilot page; if
|
||||
// the panel state persists in the store, coming back re-renders it open.
|
||||
it("closes the panel on unmount so nav-away → nav-back doesn't resurrect it (SECRT-2254)", () => {
|
||||
useCopilotUIStore.getState().openArtifact(makeArtifact(A_ID, "a.txt"));
|
||||
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(true);
|
||||
|
||||
const { unmount } = renderHook(() =>
|
||||
useAutoOpenArtifacts({ sessionId: "s1" }),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
unmount();
|
||||
});
|
||||
|
||||
const s = useCopilotUIStore.getState().artifactPanel;
|
||||
expect(s.isOpen).toBe(false);
|
||||
expect(s.activeArtifact).toBeNull();
|
||||
expect(s.history).toEqual([]);
|
||||
});
|
||||
|
||||
// SECRT-2220: "keep closed by default" — a fresh mount (e.g. user returns to
|
||||
// /copilot) must start with a closed panel even if the store somehow carries
|
||||
// stale state from a prior life.
|
||||
it("does not re-open a panel whose store state is stale on fresh mount (SECRT-2220)", () => {
|
||||
// Simulate the store being left in an open state by a previous page life.
|
||||
useCopilotUIStore.setState({
|
||||
artifactPanel: {
|
||||
isOpen: true,
|
||||
isMinimized: false,
|
||||
isMaximized: false,
|
||||
width: 600,
|
||||
activeArtifact: makeArtifact(A_ID, "stale.txt"),
|
||||
history: [],
|
||||
},
|
||||
});
|
||||
|
||||
const { unmount } = renderHook(() =>
|
||||
useAutoOpenArtifacts({ sessionId: "s1" }),
|
||||
);
|
||||
act(() => {
|
||||
unmount();
|
||||
});
|
||||
|
||||
// Next mount of the page should see a clean store.
|
||||
renderHook(() => useAutoOpenArtifacts({ sessionId: "s1" }));
|
||||
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -26,4 +26,13 @@ export function useAutoOpenArtifacts({
|
||||
resetArtifactPanel();
|
||||
}
|
||||
}, [sessionId, resetArtifactPanel]);
|
||||
|
||||
// Reset on unmount so navigating away from /copilot (to /profile, /home,
|
||||
// etc.) can't leave the panel open in the Zustand store, which would then
|
||||
// render the panel re-open when the user returns. See SECRT-2254/2220.
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
resetArtifactPanel();
|
||||
};
|
||||
}, [resetArtifactPanel]);
|
||||
}
|
||||
|
||||
@@ -1,74 +1,2 @@
|
||||
import { OutputRenderer, OutputMetadata } from "../types";
|
||||
|
||||
export interface DownloadItem {
|
||||
value: any;
|
||||
metadata?: OutputMetadata;
|
||||
renderer: OutputRenderer;
|
||||
}
|
||||
|
||||
export async function downloadOutputs(items: DownloadItem[]) {
|
||||
const concatenableTexts: string[] = [];
|
||||
const nonConcatenableDownloads: Array<{ blob: Blob; filename: string }> = [];
|
||||
|
||||
for (const item of items) {
|
||||
if (item.renderer.isConcatenable(item.value, item.metadata)) {
|
||||
const copyContent = item.renderer.getCopyContent(
|
||||
item.value,
|
||||
item.metadata,
|
||||
);
|
||||
if (copyContent) {
|
||||
// Extract text from CopyContent
|
||||
let text: string;
|
||||
if (typeof copyContent.data === "string") {
|
||||
text = copyContent.data;
|
||||
} else if (copyContent.fallbackText) {
|
||||
text = copyContent.fallbackText;
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
concatenableTexts.push(text);
|
||||
}
|
||||
} else {
|
||||
const downloadContent = item.renderer.getDownloadContent(
|
||||
item.value,
|
||||
item.metadata,
|
||||
);
|
||||
if (downloadContent) {
|
||||
if (typeof downloadContent.data === "string") {
|
||||
if (downloadContent.data.startsWith("http")) {
|
||||
const link = document.createElement("a");
|
||||
link.href = downloadContent.data;
|
||||
link.download = downloadContent.filename;
|
||||
link.click();
|
||||
}
|
||||
} else {
|
||||
nonConcatenableDownloads.push({
|
||||
blob: downloadContent.data as Blob,
|
||||
filename: downloadContent.filename,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (concatenableTexts.length > 0) {
|
||||
const combinedText = concatenableTexts.join("\n\n---\n\n");
|
||||
const blob = new Blob([combinedText], { type: "text/plain" });
|
||||
downloadBlob(blob, "combined_output.txt");
|
||||
}
|
||||
|
||||
for (const download of nonConcatenableDownloads) {
|
||||
downloadBlob(download.blob, download.filename);
|
||||
}
|
||||
}
|
||||
|
||||
function downloadBlob(blob: Blob, filename: string) {
|
||||
const url = URL.createObjectURL(blob);
|
||||
const link = document.createElement("a");
|
||||
link.href = url;
|
||||
link.download = filename;
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
document.body.removeChild(link);
|
||||
URL.revokeObjectURL(url);
|
||||
}
|
||||
export { downloadOutputs } from "@/lib/utils/download-outputs";
|
||||
export type { DownloadItem } from "@/lib/utils/download-outputs";
|
||||
|
||||
@@ -15,9 +15,10 @@ type OutputsRecord = Record<string, Array<unknown>>;
|
||||
|
||||
interface RunOutputsProps {
|
||||
outputs: OutputsRecord;
|
||||
shareToken?: string;
|
||||
}
|
||||
|
||||
export function RunOutputs({ outputs }: RunOutputsProps) {
|
||||
export function RunOutputs({ outputs, shareToken }: RunOutputsProps) {
|
||||
const items = useMemo(() => {
|
||||
const list: Array<{
|
||||
key: string;
|
||||
@@ -30,6 +31,7 @@ export function RunOutputs({ outputs }: RunOutputsProps) {
|
||||
Object.entries(outputs || {}).forEach(([key, values]) => {
|
||||
(values || []).forEach((value, index) => {
|
||||
const metadata: OutputMetadata = {};
|
||||
if (shareToken) metadata.shareToken = shareToken;
|
||||
if (
|
||||
typeof value === "object" &&
|
||||
value !== null &&
|
||||
@@ -76,7 +78,7 @@ export function RunOutputs({ outputs }: RunOutputsProps) {
|
||||
});
|
||||
|
||||
return list;
|
||||
}, [outputs]);
|
||||
}, [outputs, shareToken]);
|
||||
|
||||
if (!items.length) {
|
||||
return <div className="text-neutral-600">No output from this run.</div>;
|
||||
|
||||
@@ -6931,6 +6931,262 @@
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/platform-linking/links": {
|
||||
"get": {
|
||||
"tags": ["platform-linking"],
|
||||
"summary": "List all platform servers linked to the authenticated user",
|
||||
"operationId": "getPlatform-linkingList all platform servers linked to the authenticated user",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"items": { "$ref": "#/components/schemas/PlatformLinkInfo" },
|
||||
"type": "array",
|
||||
"title": "Response Getplatform-Linkinglist All Platform Servers Linked To The Authenticated User"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
}
|
||||
},
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/platform-linking/links/{link_id}": {
|
||||
"delete": {
|
||||
"tags": ["platform-linking"],
|
||||
"summary": "Unlink a platform server",
|
||||
"operationId": "deletePlatform-linkingUnlink a platform server",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "link_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Link Id" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/DeleteLinkResponse" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/platform-linking/tokens/{token}/confirm": {
|
||||
"post": {
|
||||
"tags": ["platform-linking"],
|
||||
"summary": "Confirm a SERVER link token (user must be authenticated)",
|
||||
"operationId": "postPlatform-linkingConfirm a server link token (user must be authenticated)",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "token",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"maxLength": 64,
|
||||
"pattern": "^[A-Za-z0-9_-]+$",
|
||||
"title": "Token"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/ConfirmLinkResponse" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/platform-linking/tokens/{token}/info": {
|
||||
"get": {
|
||||
"tags": ["platform-linking"],
|
||||
"summary": "Get display info for a link token",
|
||||
"operationId": "getPlatform-linkingGet display info for a link token",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "token",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"maxLength": 64,
|
||||
"pattern": "^[A-Za-z0-9_-]+$",
|
||||
"title": "Token"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/LinkTokenInfoResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/platform-linking/user-links": {
|
||||
"get": {
|
||||
"tags": ["platform-linking"],
|
||||
"summary": "List all DM links for the authenticated user",
|
||||
"operationId": "getPlatform-linkingList all dm links for the authenticated user",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/PlatformUserLinkInfo"
|
||||
},
|
||||
"type": "array",
|
||||
"title": "Response Getplatform-Linkinglist All Dm Links For The Authenticated User"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
}
|
||||
},
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/platform-linking/user-links/{link_id}": {
|
||||
"delete": {
|
||||
"tags": ["platform-linking"],
|
||||
"summary": "Unlink a DM / user link",
|
||||
"operationId": "deletePlatform-linkingUnlink a dm / user link",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "link_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Link Id" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/DeleteLinkResponse" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/platform-linking/user-tokens/{token}/confirm": {
|
||||
"post": {
|
||||
"tags": ["platform-linking"],
|
||||
"summary": "Confirm a USER link token (user must be authenticated)",
|
||||
"operationId": "postPlatform-linkingConfirm a user link token (user must be authenticated)",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "token",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"maxLength": 64,
|
||||
"pattern": "^[A-Za-z0-9_-]+$",
|
||||
"title": "Token"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ConfirmUserLinkResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/public/shared/{share_token}": {
|
||||
"get": {
|
||||
"tags": ["v1"],
|
||||
@@ -6971,6 +7227,50 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/public/shared/{share_token}/files/{file_id}/download": {
|
||||
"get": {
|
||||
"tags": ["v1", "graphs"],
|
||||
"summary": "Download a file from a shared execution",
|
||||
"description": "Download a workspace file from a shared execution (no auth required).\n\nValidates that the file was explicitly exposed when sharing was enabled.\nReturns a uniform 404 for all failure modes to prevent enumeration attacks.",
|
||||
"operationId": "download_shared_file",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "share_token",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"pattern": "^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
|
||||
"title": "Share Token"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "file_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"pattern": "^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
|
||||
"title": "File Id"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": { "application/json": { "schema": {} } }
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/review/action": {
|
||||
"post": {
|
||||
"tags": ["v2", "executions", "review", "v2", "executions", "review"],
|
||||
@@ -10029,6 +10329,46 @@
|
||||
"title": "CoPilotUsagePublic",
|
||||
"description": "Current usage status for a user — public (client-safe) shape."
|
||||
},
|
||||
"ConfirmLinkResponse": {
|
||||
"properties": {
|
||||
"success": { "type": "boolean", "title": "Success" },
|
||||
"link_type": {
|
||||
"$ref": "#/components/schemas/LinkType",
|
||||
"default": "SERVER"
|
||||
},
|
||||
"platform": { "type": "string", "title": "Platform" },
|
||||
"platform_server_id": {
|
||||
"type": "string",
|
||||
"title": "Platform Server Id"
|
||||
},
|
||||
"server_name": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Server Name"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"success",
|
||||
"platform",
|
||||
"platform_server_id",
|
||||
"server_name"
|
||||
],
|
||||
"title": "ConfirmLinkResponse"
|
||||
},
|
||||
"ConfirmUserLinkResponse": {
|
||||
"properties": {
|
||||
"success": { "type": "boolean", "title": "Success" },
|
||||
"link_type": {
|
||||
"$ref": "#/components/schemas/LinkType",
|
||||
"default": "USER"
|
||||
},
|
||||
"platform": { "type": "string", "title": "Platform" },
|
||||
"platform_user_id": { "type": "string", "title": "Platform User Id" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["success", "platform", "platform_user_id"],
|
||||
"title": "ConfirmUserLinkResponse"
|
||||
},
|
||||
"ContentType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -10396,6 +10736,12 @@
|
||||
"required": ["version_counts"],
|
||||
"title": "DeleteGraphResponse"
|
||||
},
|
||||
"DeleteLinkResponse": {
|
||||
"properties": { "success": { "type": "boolean", "title": "Success" } },
|
||||
"type": "object",
|
||||
"required": ["success"],
|
||||
"title": "DeleteLinkResponse"
|
||||
},
|
||||
"DiscoverToolsRequest": {
|
||||
"properties": {
|
||||
"server_url": {
|
||||
@@ -12347,6 +12693,24 @@
|
||||
"required": ["source_id", "sink_id", "source_name", "sink_name"],
|
||||
"title": "Link"
|
||||
},
|
||||
"LinkTokenInfoResponse": {
|
||||
"properties": {
|
||||
"platform": { "type": "string", "title": "Platform" },
|
||||
"link_type": { "$ref": "#/components/schemas/LinkType" },
|
||||
"server_name": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Server Name"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["platform", "link_type"],
|
||||
"title": "LinkTokenInfoResponse"
|
||||
},
|
||||
"LinkType": {
|
||||
"type": "string",
|
||||
"enum": ["SERVER", "USER"],
|
||||
"title": "LinkType"
|
||||
},
|
||||
"ListFilesResponse": {
|
||||
"properties": {
|
||||
"files": {
|
||||
@@ -13491,6 +13855,64 @@
|
||||
"required": ["logs", "pagination"],
|
||||
"title": "PlatformCostLogsResponse"
|
||||
},
|
||||
"PlatformLinkInfo": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
"platform": { "type": "string", "title": "Platform" },
|
||||
"platform_server_id": {
|
||||
"type": "string",
|
||||
"title": "Platform Server Id"
|
||||
},
|
||||
"owner_platform_user_id": {
|
||||
"type": "string",
|
||||
"title": "Owner Platform User Id"
|
||||
},
|
||||
"server_name": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Server Name"
|
||||
},
|
||||
"linked_at": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"title": "Linked At"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"id",
|
||||
"platform",
|
||||
"platform_server_id",
|
||||
"owner_platform_user_id",
|
||||
"server_name",
|
||||
"linked_at"
|
||||
],
|
||||
"title": "PlatformLinkInfo"
|
||||
},
|
||||
"PlatformUserLinkInfo": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
"platform": { "type": "string", "title": "Platform" },
|
||||
"platform_user_id": { "type": "string", "title": "Platform User Id" },
|
||||
"platform_username": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Platform Username"
|
||||
},
|
||||
"linked_at": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"title": "Linked At"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"id",
|
||||
"platform",
|
||||
"platform_user_id",
|
||||
"platform_username",
|
||||
"linked_at"
|
||||
],
|
||||
"title": "PlatformUserLinkInfo"
|
||||
},
|
||||
"PostmarkBounceEnum": {
|
||||
"type": "integer",
|
||||
"enum": [
|
||||
|
||||
@@ -9,13 +9,16 @@ import {
|
||||
} from "./route.helpers";
|
||||
|
||||
describe("isWorkspaceDownloadRequest", () => {
|
||||
it("matches api/workspace/files/{id}/download pattern", () => {
|
||||
const VALID_UUID = "550e8400-e29b-41d4-a716-446655440000";
|
||||
const VALID_UUID_2 = "6ba7b810-9dad-11d1-80b4-00c04fd430c8";
|
||||
|
||||
it("matches api/workspace/files/{uuid}/download pattern", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"abc-123",
|
||||
VALID_UUID,
|
||||
"download",
|
||||
]),
|
||||
).toBe(true);
|
||||
@@ -30,7 +33,7 @@ describe("isWorkspaceDownloadRequest", () => {
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"id",
|
||||
VALID_UUID,
|
||||
"download",
|
||||
"extra",
|
||||
]),
|
||||
@@ -43,7 +46,7 @@ describe("isWorkspaceDownloadRequest", () => {
|
||||
"v1",
|
||||
"workspace",
|
||||
"files",
|
||||
"id",
|
||||
VALID_UUID,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
@@ -55,11 +58,622 @@ describe("isWorkspaceDownloadRequest", () => {
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"id",
|
||||
VALID_UUID,
|
||||
"metadata",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("matches api/public/shared/{uuid}/files/{uuid}/download pattern", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"public",
|
||||
"shared",
|
||||
VALID_UUID,
|
||||
"files",
|
||||
VALID_UUID_2,
|
||||
"download",
|
||||
]),
|
||||
).toBe(true);
|
||||
});
|
||||
|
||||
it("rejects public shared paths not ending with download", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"public",
|
||||
"shared",
|
||||
VALID_UUID,
|
||||
"files",
|
||||
VALID_UUID_2,
|
||||
"metadata",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects non-UUID file ID in workspace path", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"not-a-uuid",
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects non-UUID token in public share path", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"public",
|
||||
"shared",
|
||||
"not-a-uuid",
|
||||
"files",
|
||||
VALID_UUID,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects non-UUID file ID in public share path", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"public",
|
||||
"shared",
|
||||
VALID_UUID,
|
||||
"files",
|
||||
"not-a-uuid",
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("accepts uppercase hex in UUIDs", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"550E8400-E29B-41D4-A716-446655440000",
|
||||
"download",
|
||||
]),
|
||||
).toBe(true);
|
||||
});
|
||||
|
||||
describe("adversarial inputs", () => {
|
||||
it("rejects empty path", () => {
|
||||
expect(isWorkspaceDownloadRequest([])).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects single-segment path", () => {
|
||||
expect(isWorkspaceDownloadRequest(["download"])).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects path traversal in file ID segment", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"../../etc/passwd",
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects path traversal in token segment", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"public",
|
||||
"shared",
|
||||
"../../etc/passwd",
|
||||
"files",
|
||||
VALID_UUID,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects path traversal replacing fixed segments", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"..",
|
||||
"files",
|
||||
VALID_UUID,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"..",
|
||||
"workspace",
|
||||
"files",
|
||||
VALID_UUID,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects swapped workspace/public segments to confuse routing", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"public",
|
||||
"files",
|
||||
VALID_UUID,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"shared",
|
||||
VALID_UUID,
|
||||
"files",
|
||||
VALID_UUID_2,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects case variations on fixed segments", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"API",
|
||||
"workspace",
|
||||
"files",
|
||||
VALID_UUID,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"Workspace",
|
||||
"files",
|
||||
VALID_UUID,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
VALID_UUID,
|
||||
"DOWNLOAD",
|
||||
]),
|
||||
).toBe(false);
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"PUBLIC",
|
||||
"shared",
|
||||
VALID_UUID,
|
||||
"files",
|
||||
VALID_UUID_2,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"public",
|
||||
"SHARED",
|
||||
VALID_UUID,
|
||||
"files",
|
||||
VALID_UUID_2,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects empty string in fixed segments", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"",
|
||||
"files",
|
||||
VALID_UUID,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects empty token in public share path", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"public",
|
||||
"shared",
|
||||
"",
|
||||
"files",
|
||||
VALID_UUID,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects empty file ID in public share path", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"public",
|
||||
"shared",
|
||||
VALID_UUID,
|
||||
"files",
|
||||
"",
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects empty file ID in workspace path", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"",
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects UUID with null bytes injected", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
VALID_UUID + "\x00.jpg",
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects UUID with trailing garbage", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
VALID_UUID + "-extra",
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects UUID with leading garbage", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"prefix-" + VALID_UUID,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects truncated UUIDs", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"550e8400-e29b-41d4",
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects UUID-length strings with wrong format", () => {
|
||||
// Right length (36 chars) but missing hyphens
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"550e8400e29b41d4a716446655440000xxxx",
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
// Hyphens in wrong positions
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"550e-8400e29b-41d4a716-44665544-0000",
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects UUID with non-hex characters", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"550e8400-e29b-41d4-a716-44665544000g",
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"550e8400-e29b-41d4-a716-44665544000!",
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects SQL injection via ID segment", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"'; DROP TABLE files;--",
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects padded segments with whitespace", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
" workspace",
|
||||
"files",
|
||||
VALID_UUID,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace ",
|
||||
"files",
|
||||
VALID_UUID,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
" " + VALID_UUID,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects extra trailing segments after download", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
VALID_UUID,
|
||||
"download",
|
||||
"",
|
||||
]),
|
||||
).toBe(false);
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"public",
|
||||
"shared",
|
||||
VALID_UUID,
|
||||
"files",
|
||||
VALID_UUID_2,
|
||||
"download",
|
||||
"extra",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects extra leading segments before api", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"prefix",
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
VALID_UUID,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"",
|
||||
"api",
|
||||
"public",
|
||||
"shared",
|
||||
VALID_UUID,
|
||||
"files",
|
||||
VALID_UUID_2,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects URL-encoded segment lookalikes", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace%2Ffiles",
|
||||
VALID_UUID,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"public%2Fshared",
|
||||
VALID_UUID,
|
||||
"files",
|
||||
VALID_UUID_2,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects unicode homoglyph substitutions in fixed segments", () => {
|
||||
// Cyrillic 'а' (U+0430) instead of Latin 'a'
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"\u0430pi",
|
||||
"workspace",
|
||||
"files",
|
||||
VALID_UUID,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
// Fullwidth 'a' (U+FF41)
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"\uff41pi",
|
||||
"workspace",
|
||||
"files",
|
||||
VALID_UUID,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects hybrid path mixing workspace and public patterns", () => {
|
||||
// 5-segment but with public prefix
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"public",
|
||||
"shared",
|
||||
VALID_UUID,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
// 7-segment but with workspace prefix
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
VALID_UUID,
|
||||
"files",
|
||||
VALID_UUID_2,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects download appearing in non-terminal position", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"download",
|
||||
"files",
|
||||
VALID_UUID,
|
||||
]),
|
||||
).toBe(false);
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"public",
|
||||
"shared",
|
||||
"download",
|
||||
"files",
|
||||
VALID_UUID,
|
||||
"extra",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects prototype pollution segment names as IDs", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"__proto__",
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"constructor",
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects very long path segments (DoS vector)", () => {
|
||||
const longId = "a".repeat(10000);
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
longId,
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects UUID with embedded path separators", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"550e8400/e29b-41d4-a716-446655440000",
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects UUID-shaped strings with unicode hyphens", () => {
|
||||
// EN DASH (U+2013) instead of HYPHEN-MINUS
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"550e8400\u2013e29b\u201341d4\u2013a716\u2013446655440000",
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects SSRF-style payloads in ID position", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"http://169.254.169.254",
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("isRedirectStatus", () => {
|
||||
|
||||
@@ -1,11 +1,34 @@
|
||||
const UUID_RE =
|
||||
/^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i;
|
||||
|
||||
export function isWorkspaceDownloadRequest(path: string[]): boolean {
|
||||
return (
|
||||
path.length == 5 &&
|
||||
// api/workspace/files/{id}/download
|
||||
if (
|
||||
path.length === 5 &&
|
||||
path[0] === "api" &&
|
||||
path[1] === "workspace" &&
|
||||
path[2] === "files" &&
|
||||
path[path.length - 1] === "download"
|
||||
);
|
||||
UUID_RE.test(path[3]) &&
|
||||
path[4] === "download"
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// api/public/shared/{token}/files/{id}/download
|
||||
if (
|
||||
path.length === 7 &&
|
||||
path[0] === "api" &&
|
||||
path[1] === "public" &&
|
||||
path[2] === "shared" &&
|
||||
UUID_RE.test(path[3]) &&
|
||||
path[4] === "files" &&
|
||||
UUID_RE.test(path[5]) &&
|
||||
path[6] === "download"
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
export function isRedirectStatus(status: number): boolean {
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
import { cleanup, render } from "@testing-library/react";
|
||||
import { afterEach, describe, expect, it } from "vitest";
|
||||
import { htmlRenderer } from "./HTMLRenderer";
|
||||
|
||||
describe("HTMLRenderer", () => {
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
});
|
||||
|
||||
it("renders text/html content in a sandboxed iframe", () => {
|
||||
const { container } = render(
|
||||
<>
|
||||
{htmlRenderer.render("<h1>Hi</h1>", {
|
||||
mimeType: "text/html",
|
||||
filename: "page.html",
|
||||
})}
|
||||
</>,
|
||||
);
|
||||
const iframe = container.querySelector("iframe");
|
||||
expect(iframe).toBeTruthy();
|
||||
expect(iframe?.getAttribute("sandbox")).toBe("allow-scripts");
|
||||
});
|
||||
|
||||
it("injects the fragment-link interceptor into the srcDoc (regression)", () => {
|
||||
const { container } = render(
|
||||
<>
|
||||
{htmlRenderer.render(
|
||||
'<html><head></head><body><a href="#x">x</a><div id="x">x</div></body></html>',
|
||||
{ mimeType: "text/html", filename: "page.html" },
|
||||
)}
|
||||
</>,
|
||||
);
|
||||
const srcdoc = container.querySelector("iframe")?.getAttribute("srcdoc");
|
||||
expect(srcdoc).toBeTruthy();
|
||||
expect(srcdoc).toContain("__fragmentLinkInterceptor");
|
||||
expect(srcdoc).toContain('a[href^="#"]');
|
||||
expect(srcdoc).toContain("scrollIntoView");
|
||||
});
|
||||
|
||||
it("canRender recognises text/html mime type and .html/.htm filenames", () => {
|
||||
expect(
|
||||
htmlRenderer.canRender("<h1>Hi</h1>", { mimeType: "text/html" }),
|
||||
).toBe(true);
|
||||
expect(
|
||||
htmlRenderer.canRender("<h1>Hi</h1>", { filename: "report.html" }),
|
||||
).toBe(true);
|
||||
expect(
|
||||
htmlRenderer.canRender("<h1>Hi</h1>", { filename: "report.htm" }),
|
||||
).toBe(true);
|
||||
expect(
|
||||
htmlRenderer.canRender("<h1>Hi</h1>", { mimeType: "text/plain" }),
|
||||
).toBe(false);
|
||||
});
|
||||
});
|
||||
@@ -1,5 +1,6 @@
|
||||
import React from "react";
|
||||
import {
|
||||
FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
|
||||
TAILWIND_CDN_URL,
|
||||
wrapWithHeadInjection,
|
||||
} from "@/lib/iframe-sandbox-csp";
|
||||
@@ -13,7 +14,10 @@ import {
|
||||
function HTMLPreview({ value }: { value: string }) {
|
||||
// Inject Tailwind CDN — no CSP (see iframe-sandbox-csp.ts for why)
|
||||
const tailwindScript = `<script src="${TAILWIND_CDN_URL}"></script>`;
|
||||
const srcDoc = wrapWithHeadInjection(value, tailwindScript);
|
||||
const srcDoc = wrapWithHeadInjection(
|
||||
value,
|
||||
tailwindScript + FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
|
||||
);
|
||||
return (
|
||||
<iframe
|
||||
sandbox="allow-scripts"
|
||||
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
isWorkspaceURI,
|
||||
buildWorkspaceURI,
|
||||
} from "@/lib/workspace-uri";
|
||||
import { workspaceFileRenderer } from "./WorkspaceFileRenderer";
|
||||
|
||||
describe("parseWorkspaceURI", () => {
|
||||
it("parses a full workspace URI with mime type", () => {
|
||||
@@ -113,3 +114,26 @@ describe("buildWorkspaceURI", () => {
|
||||
expect(parsed).toEqual({ fileID: "file-abc", mimeType: "text/plain" });
|
||||
});
|
||||
});
|
||||
|
||||
describe("workspaceFileRenderer.getDownloadContent", () => {
|
||||
it("returns auth-proxied URL without share token", () => {
|
||||
const result = workspaceFileRenderer.getDownloadContent(
|
||||
"workspace://file-123#image/png",
|
||||
);
|
||||
expect(result).not.toBeNull();
|
||||
expect(result!.data).toBe(
|
||||
"/api/proxy/api/workspace/files/file-123/download",
|
||||
);
|
||||
});
|
||||
|
||||
it("returns public share URL when share token is in metadata", () => {
|
||||
const result = workspaceFileRenderer.getDownloadContent(
|
||||
"workspace://file-123#image/png",
|
||||
{ shareToken: "abc-token-123" },
|
||||
);
|
||||
expect(result).not.toBeNull();
|
||||
expect(result!.data).toBe(
|
||||
"/api/proxy/api/public/shared/abc-token-123/files/file-123/download",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -37,7 +37,10 @@ const audioMimeTypes = [
|
||||
"audio/flac",
|
||||
];
|
||||
|
||||
function buildDownloadURL(fileID: string): string {
|
||||
function buildDownloadURL(fileID: string, shareToken?: string): string {
|
||||
if (shareToken) {
|
||||
return `/api/proxy/api/public/shared/${shareToken}/files/${fileID}/download`;
|
||||
}
|
||||
return `/api/proxy/api/workspace/files/${fileID}/download`;
|
||||
}
|
||||
|
||||
@@ -124,7 +127,7 @@ function renderWorkspaceFile(
|
||||
const uri = parseWorkspaceURI(String(value));
|
||||
if (!uri) return null;
|
||||
|
||||
const downloadURL = buildDownloadURL(uri.fileID);
|
||||
const downloadURL = buildDownloadURL(uri.fileID, metadata?.shareToken);
|
||||
const mimeType = uri.mimeType || metadata?.mimeType || null;
|
||||
|
||||
if (mimeType && imageMimeTypes.includes(mimeType)) {
|
||||
@@ -174,7 +177,7 @@ function getCopyContentWorkspaceFile(
|
||||
const uri = parseWorkspaceURI(String(value));
|
||||
if (!uri) return null;
|
||||
|
||||
const downloadURL = buildDownloadURL(uri.fileID);
|
||||
const downloadURL = buildDownloadURL(uri.fileID, metadata?.shareToken);
|
||||
const mimeType =
|
||||
uri.mimeType || metadata?.mimeType || "application/octet-stream";
|
||||
|
||||
@@ -205,7 +208,7 @@ function getDownloadContentWorkspaceFile(
|
||||
const filename = metadata?.filename || `file.${ext}`;
|
||||
|
||||
return {
|
||||
data: buildDownloadURL(uri.fileID),
|
||||
data: buildDownloadURL(uri.fileID, metadata?.shareToken),
|
||||
filename,
|
||||
mimeType,
|
||||
};
|
||||
|
||||
@@ -1,74 +1,2 @@
|
||||
import { OutputRenderer, OutputMetadata } from "../types";
|
||||
|
||||
export interface DownloadItem {
|
||||
value: any;
|
||||
metadata?: OutputMetadata;
|
||||
renderer: OutputRenderer;
|
||||
}
|
||||
|
||||
export async function downloadOutputs(items: DownloadItem[]) {
|
||||
const concatenableTexts: string[] = [];
|
||||
const nonConcatenableDownloads: Array<{ blob: Blob; filename: string }> = [];
|
||||
|
||||
for (const item of items) {
|
||||
if (item.renderer.isConcatenable(item.value, item.metadata)) {
|
||||
const copyContent = item.renderer.getCopyContent(
|
||||
item.value,
|
||||
item.metadata,
|
||||
);
|
||||
if (copyContent) {
|
||||
// Extract text from CopyContent
|
||||
let text: string;
|
||||
if (typeof copyContent.data === "string") {
|
||||
text = copyContent.data;
|
||||
} else if (copyContent.fallbackText) {
|
||||
text = copyContent.fallbackText;
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
concatenableTexts.push(text);
|
||||
}
|
||||
} else {
|
||||
const downloadContent = item.renderer.getDownloadContent(
|
||||
item.value,
|
||||
item.metadata,
|
||||
);
|
||||
if (downloadContent) {
|
||||
if (typeof downloadContent.data === "string") {
|
||||
if (downloadContent.data.startsWith("http")) {
|
||||
const link = document.createElement("a");
|
||||
link.href = downloadContent.data;
|
||||
link.download = downloadContent.filename;
|
||||
link.click();
|
||||
}
|
||||
} else {
|
||||
nonConcatenableDownloads.push({
|
||||
blob: downloadContent.data as Blob,
|
||||
filename: downloadContent.filename,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (concatenableTexts.length > 0) {
|
||||
const combinedText = concatenableTexts.join("\n\n---\n\n");
|
||||
const blob = new Blob([combinedText], { type: "text/plain" });
|
||||
downloadBlob(blob, "combined_output.txt");
|
||||
}
|
||||
|
||||
for (const download of nonConcatenableDownloads) {
|
||||
downloadBlob(download.blob, download.filename);
|
||||
}
|
||||
}
|
||||
|
||||
function downloadBlob(blob: Blob, filename: string) {
|
||||
const url = URL.createObjectURL(blob);
|
||||
const link = document.createElement("a");
|
||||
link.href = url;
|
||||
link.download = filename;
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
document.body.removeChild(link);
|
||||
URL.revokeObjectURL(url);
|
||||
}
|
||||
export { downloadOutputs } from "@/lib/utils/download-outputs";
|
||||
export type { DownloadItem } from "@/lib/utils/download-outputs";
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { TAILWIND_CDN_URL, wrapWithHeadInjection } from "../iframe-sandbox-csp";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import {
|
||||
FRAGMENT_LINK_INTERCEPTOR_SCRIPT,
|
||||
TAILWIND_CDN_URL,
|
||||
wrapWithHeadInjection,
|
||||
} from "../iframe-sandbox-csp";
|
||||
|
||||
describe("wrapWithHeadInjection", () => {
|
||||
const injection = '<script src="https://example.com/lib.js"></script>';
|
||||
@@ -45,6 +49,142 @@ describe("TAILWIND_CDN_URL", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("FRAGMENT_LINK_INTERCEPTOR_SCRIPT", () => {
|
||||
// Evaluate the script body (without <script> tags) against the current
|
||||
// document. Because sandboxed srcdoc iframes run their scripts in isolation
|
||||
// anyway, the behavior we care about is just "this code, when executed in
|
||||
// a document, intercepts #anchor clicks and calls scrollIntoView".
|
||||
//
|
||||
// Parse the exported <script> via the DOM rather than regex — CodeQL flags
|
||||
// regex-based HTML stripping, and the test already runs in a DOM env.
|
||||
function installInterceptor() {
|
||||
const template = document.createElement("template");
|
||||
template.innerHTML = FRAGMENT_LINK_INTERCEPTOR_SCRIPT;
|
||||
const script = template.content.querySelector("script");
|
||||
if (!script) throw new Error("Interceptor script tag not found");
|
||||
new Function(script.textContent ?? "")();
|
||||
}
|
||||
|
||||
let cleanup: (() => void) | null = null;
|
||||
|
||||
beforeEach(() => {
|
||||
document.body.innerHTML = "";
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
if (cleanup) cleanup();
|
||||
cleanup = null;
|
||||
document.body.innerHTML = "";
|
||||
const doc = document as Document & {
|
||||
__fragmentLinkInterceptor?: EventListener;
|
||||
};
|
||||
if (doc.__fragmentLinkInterceptor) {
|
||||
document.removeEventListener("click", doc.__fragmentLinkInterceptor);
|
||||
delete doc.__fragmentLinkInterceptor;
|
||||
}
|
||||
});
|
||||
|
||||
it("exports a <script> tag wrapping the interceptor", () => {
|
||||
expect(FRAGMENT_LINK_INTERCEPTOR_SCRIPT.startsWith("<script>")).toBe(true);
|
||||
expect(FRAGMENT_LINK_INTERCEPTOR_SCRIPT.endsWith("</script>")).toBe(true);
|
||||
expect(FRAGMENT_LINK_INTERCEPTOR_SCRIPT).toContain("addEventListener");
|
||||
expect(FRAGMENT_LINK_INTERCEPTOR_SCRIPT).toContain("scrollIntoView");
|
||||
expect(FRAGMENT_LINK_INTERCEPTOR_SCRIPT).toContain('a[href^="#"]');
|
||||
});
|
||||
|
||||
// Install the interceptor first, then a tail listener that records
|
||||
// defaultPrevented. Listeners fire in registration order, so the tail
|
||||
// sees the post-interceptor state.
|
||||
function installWithObserver() {
|
||||
installInterceptor();
|
||||
const observed = { defaulted: false };
|
||||
const listener = (e: Event) => {
|
||||
observed.defaulted = e.defaultPrevented;
|
||||
};
|
||||
document.addEventListener("click", listener);
|
||||
cleanup = () => document.removeEventListener("click", listener);
|
||||
return observed;
|
||||
}
|
||||
|
||||
it("intercepts fragment-link clicks, calls preventDefault, and scrolls the target into view", () => {
|
||||
document.body.innerHTML = `
|
||||
<nav><a id="nav-link" href="#activation">Activation</a></nav>
|
||||
<section id="activation">Target</section>
|
||||
`;
|
||||
const scrollSpy = vi.fn();
|
||||
document.getElementById("activation")!.scrollIntoView = scrollSpy;
|
||||
|
||||
const observed = installWithObserver();
|
||||
|
||||
document.getElementById("nav-link")!.click();
|
||||
|
||||
expect(scrollSpy).toHaveBeenCalledTimes(1);
|
||||
expect(observed.defaulted).toBe(true);
|
||||
});
|
||||
|
||||
it("does not intercept bare '#' links (no target id)", () => {
|
||||
document.body.innerHTML = `<a id="top" href="#">Back to top</a>`;
|
||||
const observed = installWithObserver();
|
||||
|
||||
document.getElementById("top")!.click();
|
||||
|
||||
expect(observed.defaulted).toBe(false);
|
||||
});
|
||||
|
||||
it("does not intercept links with no matching target in the document", () => {
|
||||
document.body.innerHTML = `<a id="dangle" href="#missing">Nowhere</a>`;
|
||||
const observed = installWithObserver();
|
||||
|
||||
document.getElementById("dangle")!.click();
|
||||
|
||||
expect(observed.defaulted).toBe(false);
|
||||
});
|
||||
|
||||
it("does not intercept non-fragment links", () => {
|
||||
document.body.innerHTML = `<a id="ext" href="https://example.com/x">Ext</a>`;
|
||||
installInterceptor();
|
||||
const observed = { defaulted: false };
|
||||
const listener = (e: Event) => {
|
||||
observed.defaulted = e.defaultPrevented;
|
||||
e.preventDefault();
|
||||
};
|
||||
document.addEventListener("click", listener);
|
||||
cleanup = () => document.removeEventListener("click", listener);
|
||||
|
||||
document.getElementById("ext")!.click();
|
||||
|
||||
expect(observed.defaulted).toBe(false);
|
||||
});
|
||||
|
||||
it("scrolls to target when click originates from a nested child of the anchor", () => {
|
||||
document.body.innerHTML = `
|
||||
<a id="outer" href="#costs"><span id="inner">💰 Costs</span></a>
|
||||
<section id="costs">Target</section>
|
||||
`;
|
||||
const scrollSpy = vi.fn();
|
||||
document.getElementById("costs")!.scrollIntoView = scrollSpy;
|
||||
|
||||
installInterceptor();
|
||||
document.getElementById("inner")!.click();
|
||||
|
||||
expect(scrollSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("handles percent-encoded ids", () => {
|
||||
document.body.innerHTML = `
|
||||
<a id="enc" href="#top%20costs">Jump</a>
|
||||
<section id="top costs">Target</section>
|
||||
`;
|
||||
const scrollSpy = vi.fn();
|
||||
document.getElementById("top costs")!.scrollIntoView = scrollSpy;
|
||||
|
||||
installInterceptor();
|
||||
document.getElementById("enc")!.click();
|
||||
|
||||
expect(scrollSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe("no CSP is exported", () => {
|
||||
it("does not export ARTIFACT_IFRAME_CSP", async () => {
|
||||
const mod = await import("../iframe-sandbox-csp");
|
||||
|
||||
@@ -32,6 +32,38 @@
|
||||
// changes (SRI is not possible because the JIT runtime is generated on demand).
|
||||
export const TAILWIND_CDN_URL = "https://cdn.tailwindcss.com/3.4.16";
|
||||
|
||||
// Sandboxed srcdoc iframes without `allow-same-origin` resolve `href="#id"` links
|
||||
// against the parent's URL as base. The default click then either navigates the
|
||||
// iframe to `<parent-url>#id` (reloading our app inside the iframe) or updates
|
||||
// the parent window's hash — both of which break the artifact preview.
|
||||
//
|
||||
// This script stays inside the iframe document and handles in-page anchor
|
||||
// navigation locally by scrolling to the element with the matching id.
|
||||
export const FRAGMENT_LINK_INTERCEPTOR_SCRIPT = `<script>
|
||||
(function() {
|
||||
if (document.__fragmentLinkInterceptor) return;
|
||||
function handler(e) {
|
||||
var t = e.target;
|
||||
if (!t || typeof t.closest !== 'function') return;
|
||||
var a = t.closest('a[href^="#"]');
|
||||
if (!a) return;
|
||||
var href = a.getAttribute('href');
|
||||
if (!href || href === '#') return;
|
||||
var id;
|
||||
try { id = decodeURIComponent(href.slice(1)); } catch (_) { id = href.slice(1); }
|
||||
if (!id) return;
|
||||
var target = document.getElementById(id);
|
||||
if (!target) return;
|
||||
e.preventDefault();
|
||||
if (typeof target.scrollIntoView === 'function') {
|
||||
target.scrollIntoView({ behavior: 'smooth', block: 'start' });
|
||||
}
|
||||
}
|
||||
document.__fragmentLinkInterceptor = handler;
|
||||
document.addEventListener('click', handler);
|
||||
})();
|
||||
</script>`;
|
||||
|
||||
/**
|
||||
* Inject content into the <head> of an HTML document string.
|
||||
* If the content has no <head> tag, wraps it in a full document skeleton.
|
||||
|
||||
@@ -0,0 +1,423 @@
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import {
|
||||
sanitizeFilename,
|
||||
getUniqueFilename,
|
||||
downloadOutputs,
|
||||
} from "../download-outputs";
|
||||
import type { DownloadItem } from "../download-outputs";
|
||||
|
||||
describe("sanitizeFilename", () => {
|
||||
it("strips forward slashes", () => {
|
||||
expect(sanitizeFilename("path/to/file.txt")).toBe("path_to_file.txt");
|
||||
});
|
||||
|
||||
it("strips backslashes", () => {
|
||||
expect(sanitizeFilename("path\\to\\file.txt")).toBe("path_to_file.txt");
|
||||
});
|
||||
|
||||
it("replaces parent directory traversal", () => {
|
||||
const result = sanitizeFilename("../../etc/passwd");
|
||||
expect(result).not.toContain("/");
|
||||
expect(result).not.toContain("\\");
|
||||
expect(result).not.toContain("..");
|
||||
expect(result).not.toMatch(/^\./);
|
||||
});
|
||||
|
||||
it("strips leading dots", () => {
|
||||
expect(sanitizeFilename(".gitignore")).toBe("gitignore");
|
||||
expect(sanitizeFilename("..hidden")).toBe("hidden");
|
||||
expect(sanitizeFilename("...triple")).toBe("triple");
|
||||
});
|
||||
|
||||
it("returns 'file' for empty results", () => {
|
||||
expect(sanitizeFilename("")).toBe("file");
|
||||
expect(sanitizeFilename("...")).toBe("file");
|
||||
expect(sanitizeFilename(".")).toBe("file");
|
||||
});
|
||||
|
||||
it("leaves safe filenames unchanged", () => {
|
||||
expect(sanitizeFilename("report.pdf")).toBe("report.pdf");
|
||||
expect(sanitizeFilename("image_001.png")).toBe("image_001.png");
|
||||
});
|
||||
});
|
||||
|
||||
describe("getUniqueFilename", () => {
|
||||
it("returns the filename when not already used", () => {
|
||||
const used = new Set<string>();
|
||||
expect(getUniqueFilename("file.txt", used)).toBe("file.txt");
|
||||
expect(used.has("file.txt")).toBe(true);
|
||||
});
|
||||
|
||||
it("appends a counter on collision", () => {
|
||||
const used = new Set<string>(["file.txt"]);
|
||||
expect(getUniqueFilename("file.txt", used)).toBe("file_1.txt");
|
||||
expect(used.has("file_1.txt")).toBe(true);
|
||||
});
|
||||
|
||||
it("increments counter until unique", () => {
|
||||
const used = new Set<string>(["file.txt", "file_1.txt", "file_2.txt"]);
|
||||
expect(getUniqueFilename("file.txt", used)).toBe("file_3.txt");
|
||||
});
|
||||
|
||||
it("handles filenames without extensions", () => {
|
||||
const used = new Set<string>(["README"]);
|
||||
expect(getUniqueFilename("README", used)).toBe("README_1");
|
||||
});
|
||||
|
||||
it("sanitizes the filename before deduplication", () => {
|
||||
const used = new Set<string>();
|
||||
expect(getUniqueFilename("../evil.txt", used)).toBe("_evil.txt");
|
||||
});
|
||||
|
||||
it("handles dotfiles by stripping leading dots first", () => {
|
||||
const used = new Set<string>();
|
||||
expect(getUniqueFilename(".gitignore", used)).toBe("gitignore");
|
||||
});
|
||||
});
|
||||
|
||||
const mockZipFile = vi.fn();
|
||||
const mockGenerateAsync = vi.fn();
|
||||
let mockZipFiles: Record<string, { async: () => Promise<Blob> }> = {};
|
||||
|
||||
vi.mock("jszip", () => ({
|
||||
default: class MockJSZip {
|
||||
files = mockZipFiles;
|
||||
file = (...args: unknown[]) => {
|
||||
if (typeof args[0] === "string" && args[1] !== undefined) {
|
||||
const content = args[1];
|
||||
mockZipFiles[args[0] as string] = {
|
||||
async: () =>
|
||||
Promise.resolve(
|
||||
content instanceof Blob ? content : new Blob([String(content)]),
|
||||
),
|
||||
};
|
||||
}
|
||||
mockZipFile(...args);
|
||||
};
|
||||
generateAsync = mockGenerateAsync;
|
||||
},
|
||||
}));
|
||||
|
||||
function makeRenderer(overrides: {
|
||||
isConcatenable?: boolean;
|
||||
copyData?: string;
|
||||
downloadData?: Blob | string;
|
||||
downloadFilename?: string;
|
||||
}) {
|
||||
return {
|
||||
value: "test",
|
||||
metadata: undefined,
|
||||
renderer: {
|
||||
name: "test",
|
||||
priority: 1,
|
||||
canRender: () => true,
|
||||
render: () => null,
|
||||
isConcatenable: () => overrides.isConcatenable ?? false,
|
||||
getCopyContent: () =>
|
||||
overrides.copyData
|
||||
? { mimeType: "text/plain", data: overrides.copyData }
|
||||
: null,
|
||||
getDownloadContent: () =>
|
||||
overrides.downloadData
|
||||
? {
|
||||
data: overrides.downloadData,
|
||||
filename: overrides.downloadFilename ?? "file.bin",
|
||||
mimeType: "application/octet-stream",
|
||||
}
|
||||
: null,
|
||||
},
|
||||
} satisfies DownloadItem;
|
||||
}
|
||||
|
||||
describe("downloadOutputs", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockZipFiles = {};
|
||||
mockGenerateAsync.mockResolvedValue(new Blob(["zip-content"]));
|
||||
vi.stubGlobal(
|
||||
"URL",
|
||||
Object.assign(URL, {
|
||||
createObjectURL: vi.fn(() => "blob:mock-url"),
|
||||
revokeObjectURL: vi.fn(),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("creates a zip with concatenable text outputs", async () => {
|
||||
const items = [
|
||||
makeRenderer({ isConcatenable: true, copyData: "Hello" }),
|
||||
makeRenderer({ isConcatenable: true, copyData: "World" }),
|
||||
];
|
||||
|
||||
await downloadOutputs(items);
|
||||
|
||||
expect(mockZipFile).toHaveBeenCalledWith(
|
||||
"combined_output.txt",
|
||||
"Hello\n\n---\n\nWorld",
|
||||
);
|
||||
// Single file in zip → downloaded directly, no zip generation
|
||||
expect(mockGenerateAsync).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("includes direct blob data in the zip", async () => {
|
||||
const blob = new Blob(["binary data"]);
|
||||
const items = [
|
||||
makeRenderer({ downloadData: blob, downloadFilename: "image.png" }),
|
||||
];
|
||||
|
||||
await downloadOutputs(items);
|
||||
|
||||
expect(mockZipFile).toHaveBeenCalledWith("image.png", blob);
|
||||
});
|
||||
|
||||
it("skips blobs exceeding size limit", async () => {
|
||||
const consoleSpy = vi.spyOn(console, "warn").mockImplementation(() => {});
|
||||
const bigBlob = new Blob(["x".repeat(100)]);
|
||||
Object.defineProperty(bigBlob, "size", { value: 60 * 1024 * 1024 });
|
||||
|
||||
const items = [
|
||||
makeRenderer({ downloadData: bigBlob, downloadFilename: "huge.bin" }),
|
||||
];
|
||||
|
||||
await downloadOutputs(items);
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining("blob too large"),
|
||||
);
|
||||
expect(mockZipFile).not.toHaveBeenCalledWith("huge.bin", expect.anything());
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
|
||||
it("fetches http URLs and adds to zip", async () => {
|
||||
const mockBlob = new Blob(["fetched"]);
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
headers: new Headers({ "content-length": "7" }),
|
||||
blob: () => Promise.resolve(mockBlob),
|
||||
}),
|
||||
);
|
||||
|
||||
const items = [
|
||||
makeRenderer({
|
||||
downloadData: "https://example.com/file.pdf",
|
||||
downloadFilename: "report.pdf",
|
||||
}),
|
||||
];
|
||||
|
||||
await downloadOutputs(items);
|
||||
|
||||
expect(fetch).toHaveBeenCalledWith("https://example.com/file.pdf", {
|
||||
mode: "cors",
|
||||
});
|
||||
expect(mockZipFile).toHaveBeenCalledWith("report.pdf", mockBlob);
|
||||
});
|
||||
|
||||
it("handles fetch failures gracefully and records unfetchable URLs", async () => {
|
||||
const consoleSpy = vi.spyOn(console, "warn").mockImplementation(() => {});
|
||||
vi.stubGlobal("fetch", vi.fn().mockRejectedValue(new Error("CORS error")));
|
||||
|
||||
const items = [
|
||||
makeRenderer({ isConcatenable: true, copyData: "some text" }),
|
||||
makeRenderer({
|
||||
downloadData: "https://cors-blocked.com/file.bin",
|
||||
downloadFilename: "blocked.bin",
|
||||
}),
|
||||
];
|
||||
|
||||
await downloadOutputs(items);
|
||||
|
||||
expect(mockZipFile).toHaveBeenCalledWith(
|
||||
"unfetched_files.txt",
|
||||
expect.stringContaining("cors-blocked.com"),
|
||||
);
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
|
||||
it("handles malformed data URLs with try-catch", async () => {
|
||||
const consoleSpy = vi.spyOn(console, "warn").mockImplementation(() => {});
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockRejectedValue(new Error("Invalid data URL")),
|
||||
);
|
||||
|
||||
const items = [
|
||||
makeRenderer({
|
||||
downloadData: "data:invalid",
|
||||
downloadFilename: "broken.bin",
|
||||
}),
|
||||
];
|
||||
|
||||
await downloadOutputs(items);
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining("malformed or unsupported format"),
|
||||
);
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
|
||||
it("skips unsupported URL formats with a warning", async () => {
|
||||
const consoleSpy = vi.spyOn(console, "warn").mockImplementation(() => {});
|
||||
|
||||
const items = [
|
||||
makeRenderer({
|
||||
downloadData: "ftp://server/file.dat",
|
||||
downloadFilename: "file.dat",
|
||||
}),
|
||||
];
|
||||
|
||||
await downloadOutputs(items);
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining("unsupported URL format"),
|
||||
);
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
|
||||
it("does nothing when items array is empty", async () => {
|
||||
await downloadOutputs([]);
|
||||
|
||||
expect(mockZipFile).not.toHaveBeenCalled();
|
||||
expect(mockGenerateAsync).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("fetches relative URLs (workspace files) and adds to zip", async () => {
|
||||
const mockBlob = new Blob(["image-data"]);
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
headers: new Headers({ "content-length": "10" }),
|
||||
blob: () => Promise.resolve(mockBlob),
|
||||
}),
|
||||
);
|
||||
|
||||
const items = [
|
||||
makeRenderer({
|
||||
downloadData: "/api/proxy/api/workspace/files/abc-123/download",
|
||||
downloadFilename: "photo.png",
|
||||
}),
|
||||
];
|
||||
|
||||
await downloadOutputs(items);
|
||||
|
||||
expect(fetch).toHaveBeenCalledWith(
|
||||
"/api/proxy/api/workspace/files/abc-123/download",
|
||||
{ mode: "cors" },
|
||||
);
|
||||
expect(mockZipFile).toHaveBeenCalledWith("photo.png", mockBlob);
|
||||
});
|
||||
|
||||
it("includes workspace images that renderers return as relative URLs", async () => {
|
||||
const mockBlob = new Blob(["img"]);
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
headers: new Headers({ "content-length": "3" }),
|
||||
blob: () => Promise.resolve(mockBlob),
|
||||
}),
|
||||
);
|
||||
|
||||
const items = [
|
||||
makeRenderer({
|
||||
downloadData: "/api/proxy/api/workspace/files/file-1/download",
|
||||
downloadFilename: "image1.png",
|
||||
}),
|
||||
makeRenderer({
|
||||
downloadData: "/api/proxy/api/workspace/files/file-2/download",
|
||||
downloadFilename: "image2.jpg",
|
||||
}),
|
||||
];
|
||||
|
||||
await downloadOutputs(items);
|
||||
|
||||
expect(mockZipFile).toHaveBeenCalledWith("image1.png", mockBlob);
|
||||
expect(mockZipFile).toHaveBeenCalledWith("image2.jpg", mockBlob);
|
||||
});
|
||||
|
||||
it("fetches public share endpoint URLs for workspace files", async () => {
|
||||
const mockBlob = new Blob(["shared-image-data"]);
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
headers: new Headers({ "content-length": "17" }),
|
||||
blob: () => Promise.resolve(mockBlob),
|
||||
}),
|
||||
);
|
||||
|
||||
const items = [
|
||||
makeRenderer({
|
||||
downloadData:
|
||||
"/api/proxy/api/public/shared/abc-token/files/file-123/download",
|
||||
downloadFilename: "shared-image.png",
|
||||
}),
|
||||
];
|
||||
|
||||
await downloadOutputs(items);
|
||||
|
||||
expect(fetch).toHaveBeenCalledWith(
|
||||
"/api/proxy/api/public/shared/abc-token/files/file-123/download",
|
||||
{ mode: "cors" },
|
||||
);
|
||||
expect(mockZipFile).toHaveBeenCalledWith("shared-image.png", mockBlob);
|
||||
});
|
||||
|
||||
it("rejects files over content-length before buffering", async () => {
|
||||
const consoleSpy = vi.spyOn(console, "warn").mockImplementation(() => {});
|
||||
const blobFn = vi.fn();
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
headers: new Headers({
|
||||
"content-length": String(60 * 1024 * 1024),
|
||||
}),
|
||||
blob: blobFn,
|
||||
}),
|
||||
);
|
||||
|
||||
const items = [
|
||||
makeRenderer({
|
||||
downloadData: "https://example.com/huge.zip",
|
||||
downloadFilename: "huge.zip",
|
||||
}),
|
||||
];
|
||||
|
||||
await downloadOutputs(items);
|
||||
|
||||
expect(blobFn).not.toHaveBeenCalled();
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining("file too large"),
|
||||
);
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
|
||||
it("downloads single file directly without zip wrapping", async () => {
|
||||
const blob = new Blob(["single file"]);
|
||||
const items = [
|
||||
makeRenderer({ downloadData: blob, downloadFilename: "photo.png" }),
|
||||
];
|
||||
|
||||
await downloadOutputs(items);
|
||||
|
||||
expect(mockZipFile).toHaveBeenCalledWith("photo.png", blob);
|
||||
expect(mockGenerateAsync).not.toHaveBeenCalled();
|
||||
expect(URL.createObjectURL).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("uses zip when multiple files are present", async () => {
|
||||
const blob1 = new Blob(["file1"]);
|
||||
const blob2 = new Blob(["file2"]);
|
||||
const items = [
|
||||
makeRenderer({ downloadData: blob1, downloadFilename: "a.png" }),
|
||||
makeRenderer({ downloadData: blob2, downloadFilename: "b.png" }),
|
||||
];
|
||||
|
||||
await downloadOutputs(items);
|
||||
|
||||
expect(mockGenerateAsync).toHaveBeenCalledWith({ type: "blob" });
|
||||
});
|
||||
});
|
||||
282
autogpt_platform/frontend/src/lib/utils/download-outputs.ts
Normal file
282
autogpt_platform/frontend/src/lib/utils/download-outputs.ts
Normal file
@@ -0,0 +1,282 @@
|
||||
import type {
|
||||
OutputRenderer,
|
||||
OutputMetadata,
|
||||
} from "@/components/contextual/OutputRenderers/types";
|
||||
|
||||
export interface DownloadItem {
|
||||
value: unknown;
|
||||
metadata?: OutputMetadata;
|
||||
renderer: OutputRenderer;
|
||||
}
|
||||
|
||||
/** Maximum individual file size for zip inclusion (50 MB) */
|
||||
const MAX_FILE_SIZE_BYTES = 50 * 1024 * 1024;
|
||||
|
||||
/** Maximum total zip content size before generation (200 MB) */
|
||||
const MAX_TOTAL_SIZE_BYTES = 200 * 1024 * 1024;
|
||||
|
||||
/** Maximum concurrent file fetches */
|
||||
const FETCH_CONCURRENCY = 5;
|
||||
|
||||
async function fetchFileAsBlob(url: string): Promise<Blob | null> {
|
||||
try {
|
||||
const response = await fetch(url, { mode: "cors" });
|
||||
if (!response.ok) {
|
||||
console.error(`Failed to fetch ${url}: ${response.status}`);
|
||||
return null;
|
||||
}
|
||||
const contentLength = Number(response.headers.get("content-length") ?? "0");
|
||||
if (contentLength > MAX_FILE_SIZE_BYTES) {
|
||||
console.warn(
|
||||
`Skipping ${url}: file too large (${(contentLength / 1024 / 1024).toFixed(1)} MB, limit ${MAX_FILE_SIZE_BYTES / 1024 / 1024} MB)`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
const blob = await response.blob();
|
||||
if (blob.size > MAX_FILE_SIZE_BYTES) {
|
||||
console.warn(
|
||||
`Skipping ${url}: file too large (${(blob.size / 1024 / 1024).toFixed(1)} MB, limit ${MAX_FILE_SIZE_BYTES / 1024 / 1024} MB)`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
return blob;
|
||||
} catch (_error) {
|
||||
console.warn(
|
||||
`Could not fetch ${url} (likely CORS). Adding as link reference.`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/** Strip path traversal components and unsafe characters from a filename. */
|
||||
export function sanitizeFilename(filename: string): string {
|
||||
const sanitized = filename
|
||||
.replace(/[/\\]/g, "_")
|
||||
.replace(/^\.+/, "")
|
||||
.replace(/\.\./g, "_");
|
||||
return sanitized || "file";
|
||||
}
|
||||
|
||||
export function getUniqueFilename(
|
||||
filename: string,
|
||||
usedNames: Set<string>,
|
||||
): string {
|
||||
const safe = sanitizeFilename(filename);
|
||||
if (!usedNames.has(safe)) {
|
||||
usedNames.add(safe);
|
||||
return safe;
|
||||
}
|
||||
|
||||
const dotIndex = safe.lastIndexOf(".");
|
||||
const baseName = dotIndex > 0 ? safe.slice(0, dotIndex) : safe;
|
||||
const extension = dotIndex > 0 ? safe.slice(dotIndex) : "";
|
||||
|
||||
let counter = 1;
|
||||
let newName = `${baseName}_${counter}${extension}`;
|
||||
while (usedNames.has(newName)) {
|
||||
counter++;
|
||||
newName = `${baseName}_${counter}${extension}`;
|
||||
}
|
||||
usedNames.add(newName);
|
||||
return newName;
|
||||
}
|
||||
|
||||
async function fetchInParallel<T>(
|
||||
tasks: (() => Promise<T>)[],
|
||||
concurrency: number,
|
||||
): Promise<T[]> {
|
||||
const results: T[] = [];
|
||||
let index = 0;
|
||||
|
||||
async function worker() {
|
||||
while (index < tasks.length) {
|
||||
const i = index++;
|
||||
results[i] = await tasks[i]();
|
||||
}
|
||||
}
|
||||
|
||||
await Promise.all(
|
||||
Array.from({ length: Math.min(concurrency, tasks.length) }, () => worker()),
|
||||
);
|
||||
return results;
|
||||
}
|
||||
|
||||
type FetchResult = {
|
||||
blob: Blob | null;
|
||||
filename: string;
|
||||
sourceUrl: string | null;
|
||||
};
|
||||
|
||||
export async function downloadOutputs(items: DownloadItem[]) {
|
||||
if (items.length === 0) return;
|
||||
|
||||
const { default: JSZip } = await import("jszip");
|
||||
const zip = new JSZip();
|
||||
const usedFilenames = new Set<string>();
|
||||
let hasFiles = false;
|
||||
let totalSize = 0;
|
||||
|
||||
const concatenableTexts: string[] = [];
|
||||
const unfetchableUrls: string[] = [];
|
||||
|
||||
const fileItems: Array<{
|
||||
downloadContent: { data: unknown; filename: string };
|
||||
}> = [];
|
||||
|
||||
for (const item of items) {
|
||||
if (item.renderer.isConcatenable(item.value, item.metadata)) {
|
||||
const copyContent = item.renderer.getCopyContent(
|
||||
item.value,
|
||||
item.metadata,
|
||||
);
|
||||
if (copyContent) {
|
||||
let text: string;
|
||||
if (typeof copyContent.data === "string") {
|
||||
text = copyContent.data;
|
||||
} else if (copyContent.fallbackText) {
|
||||
text = copyContent.fallbackText;
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
concatenableTexts.push(text);
|
||||
}
|
||||
} else {
|
||||
const downloadContent = item.renderer.getDownloadContent(
|
||||
item.value,
|
||||
item.metadata,
|
||||
);
|
||||
if (downloadContent) {
|
||||
fileItems.push({ downloadContent });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const fetchTasks = fileItems.map(
|
||||
({ downloadContent }) =>
|
||||
async (): Promise<FetchResult> => {
|
||||
let blob: Blob | null = null;
|
||||
const filename = downloadContent.filename;
|
||||
let sourceUrl: string | null = null;
|
||||
|
||||
if (typeof downloadContent.data === "string") {
|
||||
if (
|
||||
downloadContent.data.startsWith("http://") ||
|
||||
downloadContent.data.startsWith("https://") ||
|
||||
downloadContent.data.startsWith("/")
|
||||
) {
|
||||
sourceUrl = downloadContent.data;
|
||||
blob = await fetchFileAsBlob(downloadContent.data);
|
||||
} else if (downloadContent.data.startsWith("data:")) {
|
||||
try {
|
||||
const dataBlob = await fetch(downloadContent.data).then((r) =>
|
||||
r.blob(),
|
||||
);
|
||||
if (dataBlob.size <= MAX_FILE_SIZE_BYTES) {
|
||||
blob = dataBlob;
|
||||
} else {
|
||||
console.warn(
|
||||
`Skipping data URL: too large (${(dataBlob.size / 1024 / 1024).toFixed(1)} MB)`,
|
||||
);
|
||||
}
|
||||
} catch (_error) {
|
||||
console.warn(
|
||||
`Failed to process data URL for ${filename}: malformed or unsupported format`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
console.warn(
|
||||
`Skipping unsupported URL format: ${downloadContent.data.slice(0, 50)}...`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
const rawBlob = downloadContent.data as Blob;
|
||||
if (rawBlob.size <= MAX_FILE_SIZE_BYTES) {
|
||||
blob = rawBlob;
|
||||
} else {
|
||||
console.warn(
|
||||
`Skipping ${filename}: blob too large (${(rawBlob.size / 1024 / 1024).toFixed(1)} MB)`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return { blob, filename, sourceUrl };
|
||||
},
|
||||
);
|
||||
|
||||
const results = await fetchInParallel(fetchTasks, FETCH_CONCURRENCY);
|
||||
|
||||
for (const { blob, filename, sourceUrl } of results) {
|
||||
if (blob) {
|
||||
if (totalSize + blob.size > MAX_TOTAL_SIZE_BYTES) {
|
||||
console.warn(
|
||||
`Skipping ${filename}: would exceed total zip size limit (${MAX_TOTAL_SIZE_BYTES / 1024 / 1024} MB)`,
|
||||
);
|
||||
if (sourceUrl) unfetchableUrls.push(sourceUrl);
|
||||
continue;
|
||||
}
|
||||
const uniqueFilename = getUniqueFilename(filename, usedFilenames);
|
||||
zip.file(uniqueFilename, blob);
|
||||
totalSize += blob.size;
|
||||
hasFiles = true;
|
||||
} else if (sourceUrl) {
|
||||
unfetchableUrls.push(sourceUrl);
|
||||
}
|
||||
}
|
||||
|
||||
if (concatenableTexts.length > 0) {
|
||||
const combinedText = concatenableTexts.join("\n\n---\n\n");
|
||||
const textSize = new Blob([combinedText]).size;
|
||||
if (totalSize + textSize <= MAX_TOTAL_SIZE_BYTES) {
|
||||
const filename = getUniqueFilename("combined_output.txt", usedFilenames);
|
||||
zip.file(filename, combinedText);
|
||||
totalSize += textSize;
|
||||
hasFiles = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (unfetchableUrls.length > 0) {
|
||||
const linksContent = unfetchableUrls
|
||||
.map((url, i) => `${i + 1}. ${url}`)
|
||||
.join("\n");
|
||||
const manifest = `The following files could not be included in the zip (CORS restriction or size limit).\nYou can download them directly from these URLs:\n\n${linksContent}\n`;
|
||||
const manifestSize = new Blob([manifest]).size;
|
||||
if (totalSize + manifestSize <= MAX_TOTAL_SIZE_BYTES) {
|
||||
const manifestFilename = getUniqueFilename(
|
||||
"unfetched_files.txt",
|
||||
usedFilenames,
|
||||
);
|
||||
zip.file(manifestFilename, manifest);
|
||||
totalSize += manifestSize;
|
||||
hasFiles = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!hasFiles) return;
|
||||
|
||||
// Single-file shortcut: download directly instead of wrapping in a zip
|
||||
if (
|
||||
zip.files &&
|
||||
Object.keys(zip.files).length === 1 &&
|
||||
unfetchableUrls.length === 0
|
||||
) {
|
||||
const onlyFilename = Object.keys(zip.files)[0];
|
||||
const entry = zip.files[onlyFilename];
|
||||
const content = await entry.async("blob");
|
||||
downloadBlob(content, onlyFilename);
|
||||
return;
|
||||
}
|
||||
|
||||
const zipBlob = await zip.generateAsync({ type: "blob" });
|
||||
downloadBlob(zipBlob, "outputs.zip");
|
||||
}
|
||||
|
||||
function downloadBlob(blob: Blob, filename: string) {
|
||||
const url = URL.createObjectURL(blob);
|
||||
const link = document.createElement("a");
|
||||
link.href = url;
|
||||
link.download = filename;
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
document.body.removeChild(link);
|
||||
URL.revokeObjectURL(url);
|
||||
}
|
||||
Reference in New Issue
Block a user