diff --git a/autogpt_platform/.gitignore b/autogpt_platform/.gitignore index 3e31a9970e..bc70dc96bc 100644 --- a/autogpt_platform/.gitignore +++ b/autogpt_platform/.gitignore @@ -1,3 +1,6 @@ *.ignore.* *.ign.* .application.logs + +# Claude Code local settings only — the rest of .claude/ is shared (skills etc.) +.claude/settings.local.json diff --git a/autogpt_platform/backend/.env.default b/autogpt_platform/backend/.env.default index e731f9f9bf..67444c2e36 100644 --- a/autogpt_platform/backend/.env.default +++ b/autogpt_platform/backend/.env.default @@ -179,6 +179,9 @@ MEM0_API_KEY= OPENWEATHERMAP_API_KEY= GOOGLE_MAPS_API_KEY= +# Platform Bot Linking +PLATFORM_LINK_BASE_URL=http://localhost:3000/link + # Communication Services DISCORD_BOT_TOKEN= MEDIUM_API_KEY= diff --git a/autogpt_platform/backend/backend/api/features/platform_linking/__init__.py b/autogpt_platform/backend/backend/api/features/platform_linking/__init__.py new file mode 100644 index 0000000000..7764686098 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/platform_linking/__init__.py @@ -0,0 +1 @@ +"""Platform bot linking — user-facing REST routes.""" diff --git a/autogpt_platform/backend/backend/api/features/platform_linking/routes.py b/autogpt_platform/backend/backend/api/features/platform_linking/routes.py new file mode 100644 index 0000000000..7b0f845c01 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/platform_linking/routes.py @@ -0,0 +1,158 @@ +"""User-facing platform_linking REST routes (JWT auth).""" + +import logging +from typing import Annotated + +from autogpt_libs import auth +from fastapi import APIRouter, HTTPException, Path, Security + +from backend.data.db_accessors import platform_linking_db +from backend.platform_linking.models import ( + ConfirmLinkResponse, + ConfirmUserLinkResponse, + DeleteLinkResponse, + LinkTokenInfoResponse, + PlatformLinkInfo, + PlatformUserLinkInfo, +) +from backend.util.exceptions import ( + LinkAlreadyExistsError, + LinkFlowMismatchError, + LinkTokenExpiredError, + NotAuthorizedError, + NotFoundError, +) + +logger = logging.getLogger(__name__) + +router = APIRouter() + +TokenPath = Annotated[ + str, + Path(max_length=64, pattern=r"^[A-Za-z0-9_-]+$"), +] + + +def _translate(exc: Exception) -> HTTPException: + if isinstance(exc, NotFoundError): + return HTTPException(status_code=404, detail=str(exc)) + if isinstance(exc, NotAuthorizedError): + return HTTPException(status_code=403, detail=str(exc)) + if isinstance(exc, LinkAlreadyExistsError): + return HTTPException(status_code=409, detail=str(exc)) + if isinstance(exc, LinkTokenExpiredError): + return HTTPException(status_code=410, detail=str(exc)) + if isinstance(exc, LinkFlowMismatchError): + return HTTPException(status_code=400, detail=str(exc)) + return HTTPException(status_code=500, detail="Internal error.") + + +@router.get( + "/tokens/{token}/info", + response_model=LinkTokenInfoResponse, + dependencies=[Security(auth.requires_user)], + summary="Get display info for a link token", +) +async def get_link_token_info_route(token: TokenPath) -> LinkTokenInfoResponse: + try: + return await platform_linking_db().get_link_token_info(token) + except (NotFoundError, LinkTokenExpiredError) as exc: + raise _translate(exc) from exc + + +@router.post( + "/tokens/{token}/confirm", + response_model=ConfirmLinkResponse, + dependencies=[Security(auth.requires_user)], + summary="Confirm a SERVER link token (user must be authenticated)", +) +async def confirm_link_token( + token: TokenPath, + user_id: Annotated[str, Security(auth.get_user_id)], +) -> ConfirmLinkResponse: + try: + return await platform_linking_db().confirm_server_link(token, user_id) + except ( + NotFoundError, + LinkFlowMismatchError, + LinkTokenExpiredError, + LinkAlreadyExistsError, + ) as exc: + raise _translate(exc) from exc + + +@router.post( + "/user-tokens/{token}/confirm", + response_model=ConfirmUserLinkResponse, + dependencies=[Security(auth.requires_user)], + summary="Confirm a USER link token (user must be authenticated)", +) +async def confirm_user_link_token( + token: TokenPath, + user_id: Annotated[str, Security(auth.get_user_id)], +) -> ConfirmUserLinkResponse: + try: + return await platform_linking_db().confirm_user_link(token, user_id) + except ( + NotFoundError, + LinkFlowMismatchError, + LinkTokenExpiredError, + LinkAlreadyExistsError, + ) as exc: + raise _translate(exc) from exc + + +@router.get( + "/links", + response_model=list[PlatformLinkInfo], + dependencies=[Security(auth.requires_user)], + summary="List all platform servers linked to the authenticated user", +) +async def list_my_links( + user_id: Annotated[str, Security(auth.get_user_id)], +) -> list[PlatformLinkInfo]: + return await platform_linking_db().list_server_links(user_id) + + +@router.get( + "/user-links", + response_model=list[PlatformUserLinkInfo], + dependencies=[Security(auth.requires_user)], + summary="List all DM links for the authenticated user", +) +async def list_my_user_links( + user_id: Annotated[str, Security(auth.get_user_id)], +) -> list[PlatformUserLinkInfo]: + return await platform_linking_db().list_user_links(user_id) + + +@router.delete( + "/links/{link_id}", + response_model=DeleteLinkResponse, + dependencies=[Security(auth.requires_user)], + summary="Unlink a platform server", +) +async def delete_link( + link_id: str, + user_id: Annotated[str, Security(auth.get_user_id)], +) -> DeleteLinkResponse: + try: + return await platform_linking_db().delete_server_link(link_id, user_id) + except (NotFoundError, NotAuthorizedError) as exc: + raise _translate(exc) from exc + + +@router.delete( + "/user-links/{link_id}", + response_model=DeleteLinkResponse, + dependencies=[Security(auth.requires_user)], + summary="Unlink a DM / user link", +) +async def delete_user_link_route( + link_id: str, + user_id: Annotated[str, Security(auth.get_user_id)], +) -> DeleteLinkResponse: + try: + return await platform_linking_db().delete_user_link(link_id, user_id) + except (NotFoundError, NotAuthorizedError) as exc: + raise _translate(exc) from exc diff --git a/autogpt_platform/backend/backend/api/features/platform_linking/routes_test.py b/autogpt_platform/backend/backend/api/features/platform_linking/routes_test.py new file mode 100644 index 0000000000..944ef8eb6a --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/platform_linking/routes_test.py @@ -0,0 +1,264 @@ +"""Route tests: domain exceptions → HTTPException status codes.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import HTTPException + +from backend.util.exceptions import ( + LinkAlreadyExistsError, + LinkFlowMismatchError, + LinkTokenExpiredError, + NotAuthorizedError, + NotFoundError, +) + + +def _db_mock(**method_configs): + """Return a mock of the accessor's return value with the given AsyncMocks.""" + db = MagicMock() + for name, mock in method_configs.items(): + setattr(db, name, mock) + return db + + +class TestTokenInfoRouteTranslation: + @pytest.mark.asyncio + async def test_not_found_maps_to_404(self): + from backend.api.features.platform_linking.routes import ( + get_link_token_info_route, + ) + + db = _db_mock( + get_link_token_info=AsyncMock(side_effect=NotFoundError("missing")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as exc: + await get_link_token_info_route(token="abc") + assert exc.value.status_code == 404 + + @pytest.mark.asyncio + async def test_expired_maps_to_410(self): + from backend.api.features.platform_linking.routes import ( + get_link_token_info_route, + ) + + db = _db_mock( + get_link_token_info=AsyncMock(side_effect=LinkTokenExpiredError("expired")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as exc: + await get_link_token_info_route(token="abc") + assert exc.value.status_code == 410 + + +class TestConfirmLinkRouteTranslation: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "exc,expected_status", + [ + (NotFoundError("missing"), 404), + (LinkFlowMismatchError("wrong flow"), 400), + (LinkTokenExpiredError("expired"), 410), + (LinkAlreadyExistsError("already"), 409), + ], + ) + async def test_translation(self, exc: Exception, expected_status: int): + from backend.api.features.platform_linking.routes import confirm_link_token + + db = _db_mock(confirm_server_link=AsyncMock(side_effect=exc)) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as ctx: + await confirm_link_token(token="abc", user_id="u1") + assert ctx.value.status_code == expected_status + + +class TestConfirmUserLinkRouteTranslation: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "exc,expected_status", + [ + (NotFoundError("missing"), 404), + (LinkFlowMismatchError("wrong flow"), 400), + (LinkTokenExpiredError("expired"), 410), + (LinkAlreadyExistsError("already"), 409), + ], + ) + async def test_translation(self, exc: Exception, expected_status: int): + from backend.api.features.platform_linking.routes import confirm_user_link_token + + db = _db_mock(confirm_user_link=AsyncMock(side_effect=exc)) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as ctx: + await confirm_user_link_token(token="abc", user_id="u1") + assert ctx.value.status_code == expected_status + + +class TestDeleteLinkRouteTranslation: + @pytest.mark.asyncio + async def test_not_found_maps_to_404(self): + from backend.api.features.platform_linking.routes import delete_link + + db = _db_mock( + delete_server_link=AsyncMock(side_effect=NotFoundError("missing")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as exc: + await delete_link(link_id="x", user_id="u1") + assert exc.value.status_code == 404 + + @pytest.mark.asyncio + async def test_not_owned_maps_to_403(self): + from backend.api.features.platform_linking.routes import delete_link + + db = _db_mock( + delete_server_link=AsyncMock(side_effect=NotAuthorizedError("nope")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as exc: + await delete_link(link_id="x", user_id="u1") + assert exc.value.status_code == 403 + + +class TestDeleteUserLinkRouteTranslation: + @pytest.mark.asyncio + async def test_not_found_maps_to_404(self): + from backend.api.features.platform_linking.routes import delete_user_link_route + + db = _db_mock(delete_user_link=AsyncMock(side_effect=NotFoundError("missing"))) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as exc: + await delete_user_link_route(link_id="x", user_id="u1") + assert exc.value.status_code == 404 + + @pytest.mark.asyncio + async def test_not_owned_maps_to_403(self): + from backend.api.features.platform_linking.routes import delete_user_link_route + + db = _db_mock( + delete_user_link=AsyncMock(side_effect=NotAuthorizedError("nope")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + with pytest.raises(HTTPException) as exc: + await delete_user_link_route(link_id="x", user_id="u1") + assert exc.value.status_code == 403 + + +# ── Adversarial: malformed token path params ────────────────────────── + + +class TestAdversarialTokenPath: + # TokenPath enforces `^[A-Za-z0-9_-]+$` + max_length=64. + + @pytest.fixture + def client(self): + import fastapi + from autogpt_libs.auth import get_user_id, requires_user + from fastapi.testclient import TestClient + + import backend.api.features.platform_linking.routes as routes_mod + + app = fastapi.FastAPI() + app.dependency_overrides[requires_user] = lambda: None + app.dependency_overrides[get_user_id] = lambda: "caller-user" + app.include_router(routes_mod.router, prefix="/api/platform-linking") + return TestClient(app) + + def test_rejects_token_with_special_chars(self, client): + response = client.get("/api/platform-linking/tokens/bad%24token/info") + assert response.status_code == 422 + + def test_rejects_token_with_path_traversal(self, client): + for probe in ("..%2F..", "foo..bar", "foo%2Fbar"): + response = client.get(f"/api/platform-linking/tokens/{probe}/info") + assert response.status_code in ( + 404, + 422, + ), f"path-traversal probe {probe!r} returned {response.status_code}" + + def test_rejects_token_too_long(self, client): + long_token = "a" * 65 + response = client.get(f"/api/platform-linking/tokens/{long_token}/info") + assert response.status_code == 422 + + def test_accepts_token_at_max_length(self, client): + token = "a" * 64 + db = _db_mock( + get_link_token_info=AsyncMock(side_effect=NotFoundError("missing")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + response = client.get(f"/api/platform-linking/tokens/{token}/info") + assert response.status_code == 404 + + def test_accepts_urlsafe_b64_token_shape(self, client): + db = _db_mock( + get_link_token_info=AsyncMock(side_effect=NotFoundError("missing")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + response = client.get("/api/platform-linking/tokens/abc-_XYZ123-_abc/info") + assert response.status_code == 404 + + def test_confirm_rejects_malformed_token(self, client): + response = client.post("/api/platform-linking/tokens/bad%24token/confirm") + assert response.status_code == 422 + + +class TestAdversarialDeleteLinkId: + """DELETE link_id has no regex — ensure weird values are handled via + NotFoundError (no crash, no cross-user leak).""" + + @pytest.fixture + def client(self): + import fastapi + from autogpt_libs.auth import get_user_id, requires_user + from fastapi.testclient import TestClient + + import backend.api.features.platform_linking.routes as routes_mod + + app = fastapi.FastAPI() + app.dependency_overrides[requires_user] = lambda: None + app.dependency_overrides[get_user_id] = lambda: "caller-user" + app.include_router(routes_mod.router, prefix="/api/platform-linking") + return TestClient(app) + + def test_weird_link_id_returns_404(self, client): + db = _db_mock( + delete_server_link=AsyncMock(side_effect=NotFoundError("missing")) + ) + with patch( + "backend.api.features.platform_linking.routes.platform_linking_db", + return_value=db, + ): + for link_id in ("'; DROP TABLE links;--", "../../etc/passwd", ""): + response = client.delete(f"/api/platform-linking/links/{link_id}") + assert response.status_code in (404, 405) diff --git a/autogpt_platform/backend/backend/api/features/v1.py b/autogpt_platform/backend/backend/api/features/v1.py index 3559071043..12a31e6bd1 100644 --- a/autogpt_platform/backend/backend/api/features/v1.py +++ b/autogpt_platform/backend/backend/api/features/v1.py @@ -30,6 +30,7 @@ from pydantic import BaseModel, Field from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND from typing_extensions import Optional, TypedDict +from backend.api.features.workspace.routes import create_file_download_response from backend.api.model import ( CreateAPIKeyRequest, CreateAPIKeyResponse, @@ -96,6 +97,7 @@ from backend.data.user import ( update_user_notification_preference, update_user_timezone, ) +from backend.data.workspace import get_workspace_file_by_id from backend.executor import scheduler from backend.executor import utils as execution_utils from backend.integrations.webhooks.graph_lifecycle_hooks import ( @@ -1703,6 +1705,10 @@ async def enable_execution_sharing( # Generate a unique share token share_token = str(uuid.uuid4()) + # Remove stale allowlist records before updating the token — prevents a + # window where old records + new token could coexist. + await execution_db.delete_shared_execution_files(execution_id=graph_exec_id) + # Update the execution with share info await execution_db.update_graph_execution_share_status( execution_id=graph_exec_id, @@ -1712,6 +1718,14 @@ async def enable_execution_sharing( shared_at=datetime.now(timezone.utc), ) + # Create allowlist of workspace files referenced in outputs + await execution_db.create_shared_execution_files( + execution_id=graph_exec_id, + share_token=share_token, + user_id=user_id, + outputs=execution.outputs, + ) + # Return the share URL frontend_url = settings.config.frontend_base_url or "http://localhost:3000" share_url = f"{frontend_url}/share/{share_token}" @@ -1737,6 +1751,9 @@ async def disable_execution_sharing( if not execution: raise HTTPException(status_code=404, detail="Execution not found") + # Remove shared file allowlist records + await execution_db.delete_shared_execution_files(execution_id=graph_exec_id) + # Remove share info await execution_db.update_graph_execution_share_status( execution_id=graph_exec_id, @@ -1762,6 +1779,43 @@ async def get_shared_execution( return execution +@v1_router.get( + "/public/shared/{share_token}/files/{file_id}/download", + summary="Download a file from a shared execution", + operation_id="download_shared_file", + tags=["graphs"], +) +async def download_shared_file( + share_token: Annotated[ + str, + Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"), + ], + file_id: Annotated[ + str, + Path(pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"), + ], +) -> Response: + """Download a workspace file from a shared execution (no auth required). + + Validates that the file was explicitly exposed when sharing was enabled. + Returns a uniform 404 for all failure modes to prevent enumeration attacks. + """ + # Single-query validation against the allowlist + execution_id = await execution_db.get_shared_execution_file( + share_token=share_token, file_id=file_id + ) + if not execution_id: + raise HTTPException(status_code=404, detail="Not found") + + # Look up the actual file (no workspace scoping needed — the allowlist + # already validated that this file belongs to the shared execution) + file = await get_workspace_file_by_id(file_id) + if not file: + raise HTTPException(status_code=404, detail="Not found") + + return await create_file_download_response(file, inline=True) + + ######################################################## ##################### Schedules ######################## ######################################################## diff --git a/autogpt_platform/backend/backend/api/features/v1_share_test.py b/autogpt_platform/backend/backend/api/features/v1_share_test.py new file mode 100644 index 0000000000..de5d14ad80 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/v1_share_test.py @@ -0,0 +1,157 @@ +"""Tests for the public shared file download endpoint.""" + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from starlette.responses import Response + +from backend.api.features.v1 import v1_router +from backend.data.workspace import WorkspaceFile + +app = FastAPI() +app.include_router(v1_router, prefix="/api") + +VALID_TOKEN = "550e8400-e29b-41d4-a716-446655440000" +VALID_FILE_ID = "6ba7b810-9dad-11d1-80b4-00c04fd430c8" + + +def _make_workspace_file(**overrides) -> WorkspaceFile: + defaults = { + "id": VALID_FILE_ID, + "workspace_id": "ws-001", + "created_at": datetime(2026, 1, 1, tzinfo=timezone.utc), + "updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc), + "name": "image.png", + "path": "/image.png", + "storage_path": "local://uploads/image.png", + "mime_type": "image/png", + "size_bytes": 4, + "checksum": None, + "is_deleted": False, + "deleted_at": None, + "metadata": {}, + } + defaults.update(overrides) + return WorkspaceFile(**defaults) + + +def _mock_download_response(**kwargs): + """Return an AsyncMock that resolves to a Response with inline disposition.""" + + async def _handler(file, *, inline=False): + return Response( + content=b"\x89PNG", + media_type="image/png", + headers={ + "Content-Disposition": ( + 'inline; filename="image.png"' + if inline + else 'attachment; filename="image.png"' + ), + "Content-Length": "4", + }, + ) + + return _handler + + +class TestDownloadSharedFile: + """Tests for GET /api/public/shared/{token}/files/{id}/download.""" + + @pytest.fixture(autouse=True) + def _client(self): + self.client = TestClient(app, raise_server_exceptions=False) + + def test_valid_token_and_file_returns_inline_content(self): + with ( + patch( + "backend.api.features.v1.execution_db.get_shared_execution_file", + new_callable=AsyncMock, + return_value="exec-123", + ), + patch( + "backend.api.features.v1.get_workspace_file_by_id", + new_callable=AsyncMock, + return_value=_make_workspace_file(), + ), + patch( + "backend.api.features.v1.create_file_download_response", + side_effect=_mock_download_response(), + ), + ): + response = self.client.get( + f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download" + ) + + assert response.status_code == 200 + assert response.content == b"\x89PNG" + assert "inline" in response.headers["Content-Disposition"] + + def test_invalid_token_format_returns_422(self): + response = self.client.get( + f"/api/public/shared/not-a-uuid/files/{VALID_FILE_ID}/download" + ) + assert response.status_code == 422 + + def test_token_not_in_allowlist_returns_404(self): + with patch( + "backend.api.features.v1.execution_db.get_shared_execution_file", + new_callable=AsyncMock, + return_value=None, + ): + response = self.client.get( + f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download" + ) + assert response.status_code == 404 + + def test_file_missing_from_workspace_returns_404(self): + with ( + patch( + "backend.api.features.v1.execution_db.get_shared_execution_file", + new_callable=AsyncMock, + return_value="exec-123", + ), + patch( + "backend.api.features.v1.get_workspace_file_by_id", + new_callable=AsyncMock, + return_value=None, + ), + ): + response = self.client.get( + f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download" + ) + assert response.status_code == 404 + + def test_uniform_404_prevents_enumeration(self): + """Both failure modes produce identical 404 — no information leak.""" + with patch( + "backend.api.features.v1.execution_db.get_shared_execution_file", + new_callable=AsyncMock, + return_value=None, + ): + resp_no_allow = self.client.get( + f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download" + ) + + with ( + patch( + "backend.api.features.v1.execution_db.get_shared_execution_file", + new_callable=AsyncMock, + return_value="exec-123", + ), + patch( + "backend.api.features.v1.get_workspace_file_by_id", + new_callable=AsyncMock, + return_value=None, + ), + ): + resp_no_file = self.client.get( + f"/api/public/shared/{VALID_TOKEN}/files/{VALID_FILE_ID}/download" + ) + + assert resp_no_allow.status_code == 404 + assert resp_no_file.status_code == 404 + assert resp_no_allow.json() == resp_no_file.json() diff --git a/autogpt_platform/backend/backend/api/features/workspace/routes.py b/autogpt_platform/backend/backend/api/features/workspace/routes.py index 39bcc6c7c4..c22cc445c4 100644 --- a/autogpt_platform/backend/backend/api/features/workspace/routes.py +++ b/autogpt_platform/backend/backend/api/features/workspace/routes.py @@ -29,7 +29,9 @@ from backend.util.workspace import WorkspaceManager from backend.util.workspace_storage import get_workspace_storage -def _sanitize_filename_for_header(filename: str) -> str: +def _sanitize_filename_for_header( + filename: str, disposition: str = "attachment" +) -> str: """ Sanitize filename for Content-Disposition header to prevent header injection. @@ -44,11 +46,11 @@ def _sanitize_filename_for_header(filename: str) -> str: # Check if filename has non-ASCII characters try: sanitized.encode("ascii") - return f'attachment; filename="{sanitized}"' + return f'{disposition}; filename="{sanitized}"' except UnicodeEncodeError: # Use RFC5987 encoding for UTF-8 filenames encoded = quote(sanitized, safe="") - return f"attachment; filename*=UTF-8''{encoded}" + return f"{disposition}; filename*=UTF-8''{encoded}" logger = logging.getLogger(__name__) @@ -58,19 +60,26 @@ router = fastapi.APIRouter( ) -def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response: +def _create_streaming_response( + content: bytes, file: WorkspaceFile, *, inline: bool = False +) -> Response: """Create a streaming response for file content.""" + disposition = _sanitize_filename_for_header( + file.name, disposition="inline" if inline else "attachment" + ) return Response( content=content, media_type=file.mime_type, headers={ - "Content-Disposition": _sanitize_filename_for_header(file.name), + "Content-Disposition": disposition, "Content-Length": str(len(content)), }, ) -async def _create_file_download_response(file: WorkspaceFile) -> Response: +async def create_file_download_response( + file: WorkspaceFile, *, inline: bool = False +) -> Response: """ Create a download response for a workspace file. @@ -82,7 +91,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response: # For local storage, stream the file directly if file.storage_path.startswith("local://"): content = await storage.retrieve(file.storage_path) - return _create_streaming_response(content, file) + return _create_streaming_response(content, file, inline=inline) # For GCS, try to redirect to signed URL, fall back to streaming try: @@ -90,7 +99,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response: # If we got back an API path (fallback), stream directly instead if url.startswith("/api/"): content = await storage.retrieve(file.storage_path) - return _create_streaming_response(content, file) + return _create_streaming_response(content, file, inline=inline) return fastapi.responses.RedirectResponse(url=url, status_code=302) except Exception as e: # Log the signed URL failure with context @@ -102,7 +111,7 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response: # Fall back to streaming directly from GCS try: content = await storage.retrieve(file.storage_path) - return _create_streaming_response(content, file) + return _create_streaming_response(content, file, inline=inline) except Exception as fallback_error: logger.error( f"Fallback streaming also failed for file {file.id} " @@ -169,7 +178,7 @@ async def download_file( if file is None: raise fastapi.HTTPException(status_code=404, detail="File not found") - return await _create_file_download_response(file) + return await create_file_download_response(file) @router.delete( diff --git a/autogpt_platform/backend/backend/api/features/workspace/routes_test.py b/autogpt_platform/backend/backend/api/features/workspace/routes_test.py index 42726ba051..ffc712014f 100644 --- a/autogpt_platform/backend/backend/api/features/workspace/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/workspace/routes_test.py @@ -600,3 +600,221 @@ def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace): mock_instance.list_files.assert_called_once_with( limit=11, offset=50, include_all_sessions=True ) + + +# -- _sanitize_filename_for_header tests -- + + +class TestSanitizeFilenameForHeader: + def test_simple_ascii_attachment(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + assert _sanitize_filename_for_header("report.pdf") == ( + 'attachment; filename="report.pdf"' + ) + + def test_inline_disposition(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + assert _sanitize_filename_for_header("image.png", disposition="inline") == ( + 'inline; filename="image.png"' + ) + + def test_strips_cr_lf_null(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + result = _sanitize_filename_for_header("a\rb\nc\x00d.txt") + assert "\r" not in result + assert "\n" not in result + assert "\x00" not in result + assert 'filename="abcd.txt"' in result + + def test_escapes_quotes(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + result = _sanitize_filename_for_header('file"name.txt') + assert 'filename="file\\"name.txt"' in result + + def test_header_injection_blocked(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + result = _sanitize_filename_for_header("evil.txt\r\nX-Injected: true") + # CR/LF stripped — the remaining text is safely inside the quoted value + assert "\r" not in result + assert "\n" not in result + assert result == 'attachment; filename="evil.txtX-Injected: true"' + + def test_unicode_uses_rfc5987(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + result = _sanitize_filename_for_header("日本語.pdf") + assert "filename*=UTF-8''" in result + assert "attachment" in result + + def test_unicode_inline(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + result = _sanitize_filename_for_header("图片.png", disposition="inline") + assert result.startswith("inline; filename*=UTF-8''") + + def test_empty_filename(self): + from backend.api.features.workspace.routes import _sanitize_filename_for_header + + result = _sanitize_filename_for_header("") + assert result == 'attachment; filename=""' + + +# -- _create_streaming_response tests -- + + +class TestCreateStreamingResponse: + def test_attachment_disposition_by_default(self): + from backend.api.features.workspace.routes import _create_streaming_response + + file = _make_file(name="data.bin", mime_type="application/octet-stream") + response = _create_streaming_response(b"binary-data", file) + assert ( + response.headers["Content-Disposition"] == 'attachment; filename="data.bin"' + ) + assert response.headers["Content-Type"] == "application/octet-stream" + assert response.headers["Content-Length"] == "11" + assert response.body == b"binary-data" + + def test_inline_disposition(self): + from backend.api.features.workspace.routes import _create_streaming_response + + file = _make_file(name="photo.png", mime_type="image/png") + response = _create_streaming_response(b"\x89PNG", file, inline=True) + assert response.headers["Content-Disposition"] == 'inline; filename="photo.png"' + assert response.headers["Content-Type"] == "image/png" + + def test_inline_sanitizes_filename(self): + from backend.api.features.workspace.routes import _create_streaming_response + + file = _make_file(name='evil"\r\n.txt', mime_type="text/plain") + response = _create_streaming_response(b"data", file, inline=True) + assert "\r" not in response.headers["Content-Disposition"] + assert "\n" not in response.headers["Content-Disposition"] + assert "inline" in response.headers["Content-Disposition"] + + def test_content_length_matches_body(self): + from backend.api.features.workspace.routes import _create_streaming_response + + content = b"x" * 1000 + file = _make_file(name="big.bin", mime_type="application/octet-stream") + response = _create_streaming_response(content, file) + assert response.headers["Content-Length"] == "1000" + + +# -- create_file_download_response tests -- + + +class TestCreateFileDownloadResponse: + @pytest.mark.asyncio + async def test_local_storage_returns_streaming_response(self, mocker): + from backend.api.features.workspace.routes import create_file_download_response + + mock_storage = AsyncMock() + mock_storage.retrieve.return_value = b"file contents" + mocker.patch( + "backend.api.features.workspace.routes.get_workspace_storage", + return_value=mock_storage, + ) + + file = _make_file( + storage_path="local://uploads/test.txt", + mime_type="text/plain", + ) + response = await create_file_download_response(file) + assert response.status_code == 200 + assert response.body == b"file contents" + assert "attachment" in response.headers["Content-Disposition"] + + @pytest.mark.asyncio + async def test_local_storage_inline(self, mocker): + from backend.api.features.workspace.routes import create_file_download_response + + mock_storage = AsyncMock() + mock_storage.retrieve.return_value = b"\x89PNG" + mocker.patch( + "backend.api.features.workspace.routes.get_workspace_storage", + return_value=mock_storage, + ) + + file = _make_file( + storage_path="local://uploads/photo.png", + mime_type="image/png", + name="photo.png", + ) + response = await create_file_download_response(file, inline=True) + assert "inline" in response.headers["Content-Disposition"] + + @pytest.mark.asyncio + async def test_gcs_redirect(self, mocker): + from backend.api.features.workspace.routes import create_file_download_response + + mock_storage = AsyncMock() + mock_storage.get_download_url.return_value = ( + "https://storage.googleapis.com/signed-url" + ) + mocker.patch( + "backend.api.features.workspace.routes.get_workspace_storage", + return_value=mock_storage, + ) + + file = _make_file(storage_path="gcs://bucket/file.pdf") + response = await create_file_download_response(file) + assert response.status_code == 302 + assert ( + response.headers["location"] == "https://storage.googleapis.com/signed-url" + ) + + @pytest.mark.asyncio + async def test_gcs_api_fallback_streams_directly(self, mocker): + from backend.api.features.workspace.routes import create_file_download_response + + mock_storage = AsyncMock() + mock_storage.get_download_url.return_value = "/api/fallback" + mock_storage.retrieve.return_value = b"fallback content" + mocker.patch( + "backend.api.features.workspace.routes.get_workspace_storage", + return_value=mock_storage, + ) + + file = _make_file(storage_path="gcs://bucket/file.txt") + response = await create_file_download_response(file) + assert response.status_code == 200 + assert response.body == b"fallback content" + + @pytest.mark.asyncio + async def test_gcs_signed_url_failure_falls_back_to_streaming(self, mocker): + from backend.api.features.workspace.routes import create_file_download_response + + mock_storage = AsyncMock() + mock_storage.get_download_url.side_effect = RuntimeError("GCS error") + mock_storage.retrieve.return_value = b"streamed" + mocker.patch( + "backend.api.features.workspace.routes.get_workspace_storage", + return_value=mock_storage, + ) + + file = _make_file(storage_path="gcs://bucket/file.txt") + response = await create_file_download_response(file) + assert response.status_code == 200 + assert response.body == b"streamed" + + @pytest.mark.asyncio + async def test_gcs_total_failure_raises(self, mocker): + from backend.api.features.workspace.routes import create_file_download_response + + mock_storage = AsyncMock() + mock_storage.get_download_url.side_effect = RuntimeError("GCS error") + mock_storage.retrieve.side_effect = RuntimeError("Also failed") + mocker.patch( + "backend.api.features.workspace.routes.get_workspace_storage", + return_value=mock_storage, + ) + + file = _make_file(storage_path="gcs://bucket/file.txt") + with pytest.raises(RuntimeError, match="Also failed"): + await create_file_download_response(file) diff --git a/autogpt_platform/backend/backend/api/rest_api.py b/autogpt_platform/backend/backend/api/rest_api.py index b4fc2da4e9..abe261b725 100644 --- a/autogpt_platform/backend/backend/api/rest_api.py +++ b/autogpt_platform/backend/backend/api/rest_api.py @@ -32,6 +32,7 @@ import backend.api.features.library.routes import backend.api.features.mcp.routes as mcp_routes import backend.api.features.oauth import backend.api.features.otto.routes +import backend.api.features.platform_linking.routes import backend.api.features.postmark.postmark import backend.api.features.store.model import backend.api.features.store.routes @@ -378,6 +379,11 @@ app.include_router( tags=["oauth"], prefix="/api/oauth", ) +app.include_router( + backend.api.features.platform_linking.routes.router, + tags=["platform-linking"], + prefix="/api/platform-linking", +) app.mount("/external-api", external_api) diff --git a/autogpt_platform/backend/backend/app.py b/autogpt_platform/backend/backend/app.py index 236f098761..534f385009 100644 --- a/autogpt_platform/backend/backend/app.py +++ b/autogpt_platform/backend/backend/app.py @@ -42,11 +42,13 @@ def main(**kwargs): from backend.data.db_manager import DatabaseManager from backend.executor import ExecutionManager, Scheduler from backend.notifications import NotificationManager + from backend.platform_linking.manager import PlatformLinkingManager run_processes( DatabaseManager().set_log_level("warning"), Scheduler(), NotificationManager(), + PlatformLinkingManager(), WebsocketServer(), AgentServer(), ExecutionManager(), diff --git a/autogpt_platform/backend/backend/data/db_accessors.py b/autogpt_platform/backend/backend/data/db_accessors.py index 743e3c778c..8598fe9d6f 100644 --- a/autogpt_platform/backend/backend/data/db_accessors.py +++ b/autogpt_platform/backend/backend/data/db_accessors.py @@ -155,3 +155,16 @@ def platform_cost_db(): platform_cost_db = get_database_manager_async_client() return platform_cost_db + + +def platform_linking_db(): + if db.is_connected(): + from backend.platform_linking import db as _platform_linking_db + + platform_linking_db = _platform_linking_db + else: + from backend.util.clients import get_database_manager_async_client + + platform_linking_db = get_database_manager_async_client() + + return platform_linking_db diff --git a/autogpt_platform/backend/backend/data/db_manager.py b/autogpt_platform/backend/backend/data/db_manager.py index 842b49a262..e06fec1b58 100644 --- a/autogpt_platform/backend/backend/data/db_manager.py +++ b/autogpt_platform/backend/backend/data/db_manager.py @@ -120,6 +120,7 @@ from backend.data.workspace import ( list_workspace_files, soft_delete_workspace_file, ) +from backend.platform_linking import db as platform_linking_db from backend.util.service import ( AppService, AppServiceClient, @@ -338,6 +339,22 @@ class DatabaseManager(AppService): # ============ Platform Cost Tracking ============ # log_platform_cost = _(log_platform_cost) + # ============ Platform Linking ============ # + find_server_link_owner = _(platform_linking_db.find_server_link_owner) + find_user_link_owner = _(platform_linking_db.find_user_link_owner) + resolve_server_link = _(platform_linking_db.resolve_server_link) + resolve_user_link = _(platform_linking_db.resolve_user_link) + create_server_link_token = _(platform_linking_db.create_server_link_token) + create_user_link_token = _(platform_linking_db.create_user_link_token) + get_link_token_status = _(platform_linking_db.get_link_token_status) + get_link_token_info = _(platform_linking_db.get_link_token_info) + confirm_server_link = _(platform_linking_db.confirm_server_link) + confirm_user_link = _(platform_linking_db.confirm_user_link) + list_server_links = _(platform_linking_db.list_server_links) + list_user_links = _(platform_linking_db.list_user_links) + delete_server_link = _(platform_linking_db.delete_server_link) + delete_user_link = _(platform_linking_db.delete_user_link) + # ============ CoPilot Chat Sessions ============ # get_chat_session = _(chat_db.get_chat_session) create_chat_session = _(chat_db.create_chat_session) @@ -540,6 +557,22 @@ class DatabaseManagerAsyncClient(AppServiceClient): # ============ Platform Cost Tracking ============ # log_platform_cost = d.log_platform_cost + # ============ Platform Linking ============ # + find_server_link_owner = d.find_server_link_owner + find_user_link_owner = d.find_user_link_owner + resolve_server_link = d.resolve_server_link + resolve_user_link = d.resolve_user_link + create_server_link_token = d.create_server_link_token + create_user_link_token = d.create_user_link_token + get_link_token_status = d.get_link_token_status + get_link_token_info = d.get_link_token_info + confirm_server_link = d.confirm_server_link + confirm_user_link = d.confirm_user_link + list_server_links = d.list_server_links + list_user_links = d.list_user_links + delete_server_link = d.delete_server_link + delete_user_link = d.delete_user_link + # ============ CoPilot Chat Sessions ============ # get_chat_session = d.get_chat_session create_chat_session = d.create_chat_session diff --git a/autogpt_platform/backend/backend/data/execution.py b/autogpt_platform/backend/backend/data/execution.py index 4403a59080..cd50d7df3c 100644 --- a/autogpt_platform/backend/backend/data/execution.py +++ b/autogpt_platform/backend/backend/data/execution.py @@ -19,11 +19,15 @@ from typing import ( from prisma import Json from prisma.enums import AgentExecutionStatus +from prisma.errors import ForeignKeyViolationError, UniqueViolationError from prisma.models import ( AgentGraphExecution, AgentNodeExecution, AgentNodeExecutionInputOutput, AgentNodeExecutionKeyValueData, + SharedExecutionFile, + UserWorkspace, + UserWorkspaceFile, ) from prisma.types import ( AgentGraphExecutionOrderByInput, @@ -1602,6 +1606,121 @@ async def get_graph_execution_by_share_token( ) +def _extract_workspace_file_ids(outputs: CompletedBlockOutput) -> set[str]: + """Extract workspace file IDs from execution outputs. + + Scans all output values for workspace:// URI strings and extracts + the file IDs. Only matches values that are plain strings starting + with workspace://, not substrings within larger text. + """ + file_ids: set[str] = set() + + def _scan(value: Any) -> None: + if isinstance(value, str) and value.startswith("workspace://"): + raw = value.removeprefix("workspace://") + file_ref = raw.split("#", 1)[0] if "#" in raw else raw + if file_ref and not file_ref.startswith("/"): + file_ids.add(file_ref) + elif isinstance(value, list): + for item in value: + _scan(item) + elif isinstance(value, dict): + for v in value.values(): + _scan(v) + + for output_values in outputs.values(): + if isinstance(output_values, list): + for val in output_values: + _scan(val) + else: + _scan(output_values) + + return file_ids + + +async def create_shared_execution_files( + execution_id: str, + share_token: str, + user_id: str, + outputs: CompletedBlockOutput, +) -> int: + """Scan execution outputs for workspace files and create allowlist records. + + Only files belonging to the user's workspace are allowlisted — prevents + cross-workspace file exposure via crafted outputs. + + Returns the number of records created. + """ + file_ids = _extract_workspace_file_ids(outputs) + if not file_ids: + return 0 + + # Validate file IDs belong to the user's workspace + workspace = await UserWorkspace.prisma().find_unique(where={"userId": user_id}) + if not workspace: + return 0 + + owned_files = await UserWorkspaceFile.prisma().find_many( + where={ + "id": {"in": list(file_ids)}, + "workspaceId": workspace.id, + "isDeleted": False, + } + ) + owned_ids = {f.id for f in owned_files} + + created = 0 + for file_id in owned_ids: + try: + await SharedExecutionFile.prisma().create( + data={ + "executionId": execution_id, + "fileId": file_id, + "shareToken": share_token, + } + ) + created += 1 + except UniqueViolationError: + logger.debug( + f"Skipping shared file record for {file_id}: " f"record already exists" + ) + except ForeignKeyViolationError: + logger.debug( + f"Skipping shared file record for {file_id}: " f"file does not exist" + ) + return created + + +async def delete_shared_execution_files(execution_id: str) -> int: + """Delete all shared file records for an execution. + + Returns the number of records deleted. + """ + result = await SharedExecutionFile.prisma().delete_many( + where={"executionId": execution_id} + ) + return result + + +async def get_shared_execution_file( + share_token: str, + file_id: str, +) -> str | None: + """Look up a file ID in the shared execution file allowlist. + + Returns the execution ID if the file is in the allowlist, None otherwise. + Uses a single query and returns a uniform None for all failure modes + to prevent timing-based enumeration attacks. + """ + record = await SharedExecutionFile.prisma().find_first( + where={ + "shareToken": share_token, + "fileId": file_id, + } + ) + return record.executionId if record else None + + async def get_frequently_executed_graphs( days_back: int = 30, min_executions: int = 10, diff --git a/autogpt_platform/backend/backend/data/shared_execution_file_test.py b/autogpt_platform/backend/backend/data/shared_execution_file_test.py new file mode 100644 index 0000000000..e9beed280c --- /dev/null +++ b/autogpt_platform/backend/backend/data/shared_execution_file_test.py @@ -0,0 +1,72 @@ +"""Tests for SharedExecutionFile workspace URI extraction logic.""" + +from backend.data.execution import _extract_workspace_file_ids + + +class TestExtractWorkspaceFileIds: + def test_extracts_simple_workspace_uri(self): + outputs = {"image": ["workspace://abc123"]} + assert _extract_workspace_file_ids(outputs) == {"abc123"} + + def test_extracts_workspace_uri_with_mime_fragment(self): + outputs = {"image": ["workspace://abc123#image/png"]} + assert _extract_workspace_file_ids(outputs) == {"abc123"} + + def test_extracts_multiple_files_from_multiple_outputs(self): + outputs = { + "images": ["workspace://file1#image/png", "workspace://file2#image/jpeg"], + "video": ["workspace://file3#video/mp4"], + } + assert _extract_workspace_file_ids(outputs) == {"file1", "file2", "file3"} + + def test_ignores_non_workspace_strings(self): + outputs = { + "text": ["hello world"], + "url": ["https://example.com/image.png"], + "data": ["data:image/png;base64,abc"], + } + assert _extract_workspace_file_ids(outputs) == set() + + def test_ignores_path_references(self): + """workspace:///path/to/file is a path reference, not a file ID.""" + outputs = {"file": ["workspace:///path/to/file.txt"]} + assert _extract_workspace_file_ids(outputs) == set() + + def test_handles_nested_dicts_in_output_values(self): + outputs = { + "result": [{"url": "workspace://nested-file#image/png", "label": "test"}] + } + assert _extract_workspace_file_ids(outputs) == {"nested-file"} + + def test_handles_nested_lists_in_output_values(self): + outputs = {"result": [["workspace://inner-file"]]} + assert _extract_workspace_file_ids(outputs) == {"inner-file"} + + def test_handles_empty_outputs(self): + assert _extract_workspace_file_ids({}) == set() + + def test_handles_non_string_values(self): + outputs = {"count": [42], "flag": [True], "empty": [None]} + assert _extract_workspace_file_ids(outputs) == set() + + def test_deduplicates_repeated_file_ids(self): + outputs = { + "a": ["workspace://same-file#image/png"], + "b": ["workspace://same-file#image/jpeg"], + } + assert _extract_workspace_file_ids(outputs) == {"same-file"} + + def test_does_not_match_workspace_substring_in_text(self): + """Plain text that contains workspace:// as a substring should NOT be extracted + because the value itself must start with workspace://.""" + outputs = {"text": ["check out workspace://fake-id for details"]} + # The string starts with "check out", not "workspace://", so no match + assert _extract_workspace_file_ids(outputs) == set() + + def test_mixed_workspace_and_non_workspace_outputs(self): + outputs = { + "image": ["workspace://real-file#image/png"], + "text": ["just some text"], + "url": ["https://example.com"], + } + assert _extract_workspace_file_ids(outputs) == {"real-file"} diff --git a/autogpt_platform/backend/backend/data/workspace.py b/autogpt_platform/backend/backend/data/workspace.py index 43e328813b..62220b45fe 100644 --- a/autogpt_platform/backend/backend/data/workspace.py +++ b/autogpt_platform/backend/backend/data/workspace.py @@ -204,6 +204,22 @@ async def get_workspace_file( return WorkspaceFile.from_db(file) if file else None +async def get_workspace_file_by_id( + file_id: str, +) -> Optional[WorkspaceFile]: + """ + Get a workspace file by ID without workspace scoping. + + Only use this when access has already been validated through another + mechanism (e.g. SharedExecutionFile allowlist). For user-facing + endpoints, use get_workspace_file() which enforces workspace scoping. + """ + file = await UserWorkspaceFile.prisma().find_first( + where={"id": file_id, "isDeleted": False} + ) + return WorkspaceFile.from_db(file) if file else None + + async def get_workspace_file_by_path( workspace_id: str, path: str, diff --git a/autogpt_platform/backend/backend/platform_linking/__init__.py b/autogpt_platform/backend/backend/platform_linking/__init__.py new file mode 100644 index 0000000000..64834840d3 --- /dev/null +++ b/autogpt_platform/backend/backend/platform_linking/__init__.py @@ -0,0 +1 @@ +"""Platform bot linking: helpers, chat orchestration, and AppService.""" diff --git a/autogpt_platform/backend/backend/platform_linking/chat.py b/autogpt_platform/backend/backend/platform_linking/chat.py new file mode 100644 index 0000000000..1d71029759 --- /dev/null +++ b/autogpt_platform/backend/backend/platform_linking/chat.py @@ -0,0 +1,112 @@ +"""Chat-turn orchestration for the platform bot bridge.""" + +import logging +from uuid import uuid4 + +from backend.copilot import stream_registry +from backend.copilot.executor.utils import enqueue_copilot_turn +from backend.copilot.model import ( + ChatMessage, + append_and_save_message, + create_chat_session, + get_chat_session, +) +from backend.data.db_accessors import platform_linking_db +from backend.util.exceptions import DuplicateChatMessageError, NotFoundError + +from .models import BotChatRequest, ChatTurnHandle + +logger = logging.getLogger(__name__) + +CHAT_TOOL_CALL_ID = "chat_stream" +CHAT_TOOL_NAME = "chat" + + +async def resolve_chat_owner(request: BotChatRequest) -> str: + """Return the AutoGPT user ID that owns the platform conversation. + + Server context → server owner. DM context → the DM-linked user. + """ + platform = request.platform.value + db = platform_linking_db() + + if request.platform_server_id: + owner = await db.find_server_link_owner(platform, request.platform_server_id) + if owner is None: + raise NotFoundError("This server is not linked to an AutoGPT account.") + return owner + + owner = await db.find_user_link_owner(platform, request.platform_user_id) + if owner is None: + raise NotFoundError("Your DMs are not linked to an AutoGPT account.") + return owner + + +async def start_chat_turn(request: BotChatRequest) -> ChatTurnHandle: + """Prepare a copilot turn; caller subscribes via the returned handle. + + ``subscribe_from="0-0"`` on the handle means a late subscriber replays + the full stream (Redis Streams, not pub/sub). + """ + owner_user_id = await resolve_chat_owner(request) + + session_id = request.session_id + if session_id: + session = await get_chat_session(session_id, owner_user_id) + if not session: + raise NotFoundError("Session not found.") + else: + session = await create_chat_session(owner_user_id, dry_run=False) + session_id = session.session_id + + # Persist the user message before enqueueing, mirroring the REST chat + # endpoint — otherwise the executor runs against empty history. + is_duplicate = ( + await append_and_save_message( + session_id, ChatMessage(role="user", content=request.message) + ) + ) is None + if is_duplicate: + # Matches REST chat behaviour: skip create_session + enqueue so we + # don't create an orphan stream with no producer. Caller subscribes + # to the in-flight turn via its own retry logic, or drops. + logger.info( + "Duplicate bot message for session %s (platform %s, user ...%s)", + session_id, + request.platform.value, + owner_user_id[-8:], + ) + raise DuplicateChatMessageError("Message already in flight.") + + turn_id = str(uuid4()) + + await stream_registry.create_session( + session_id=session_id, + user_id=owner_user_id, + tool_call_id=CHAT_TOOL_CALL_ID, + tool_name=CHAT_TOOL_NAME, + turn_id=turn_id, + ) + + await enqueue_copilot_turn( + session_id=session_id, + user_id=owner_user_id, + message=request.message, + turn_id=turn_id, + is_user_message=True, + ) + + logger.info( + "Bot chat turn started: %s (server %s, session %s, turn %s, owner ...%s)", + request.platform.value, + request.platform_server_id or "DM", + session_id, + turn_id, + owner_user_id[-8:], + ) + + return ChatTurnHandle( + session_id=session_id, + turn_id=turn_id, + user_id=owner_user_id, + ) diff --git a/autogpt_platform/backend/backend/platform_linking/chat_test.py b/autogpt_platform/backend/backend/platform_linking/chat_test.py new file mode 100644 index 0000000000..ebc41ee6f8 --- /dev/null +++ b/autogpt_platform/backend/backend/platform_linking/chat_test.py @@ -0,0 +1,125 @@ +"""Tests for chat-turn orchestration — esp. the duplicate-message guard.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from backend.util.exceptions import DuplicateChatMessageError, NotFoundError + +from .chat import start_chat_turn +from .models import BotChatRequest, Platform + + +def _request(**overrides) -> BotChatRequest: + defaults = dict( + platform=Platform.DISCORD, + platform_user_id="pu1", + message="hello", + ) + defaults.update(overrides) + return BotChatRequest(**defaults) + + +class TestStartChatTurn: + @pytest.mark.asyncio + async def test_no_user_link_raises_not_found(self): + db_mock = MagicMock() + db_mock.find_user_link_owner = AsyncMock(return_value=None) + with patch( + "backend.platform_linking.chat.platform_linking_db", + return_value=db_mock, + ): + with pytest.raises(NotFoundError): + await start_chat_turn(_request()) + + @pytest.mark.asyncio + async def test_duplicate_message_raises_and_skips_stream_create(self): + # append_and_save_message returns None → duplicate. + # Verify we raise and do NOT create a stream session. + db_mock = MagicMock() + db_mock.find_user_link_owner = AsyncMock(return_value="owner-1") + session = MagicMock(session_id="sess-existing") + + with ( + patch( + "backend.platform_linking.chat.platform_linking_db", + return_value=db_mock, + ), + patch( + "backend.platform_linking.chat.create_chat_session", + new=AsyncMock(return_value=session), + ), + patch( + "backend.platform_linking.chat.append_and_save_message", + new=AsyncMock(return_value=None), + ), + patch( + "backend.platform_linking.chat.stream_registry" + ) as mock_stream_registry, + patch( + "backend.platform_linking.chat.enqueue_copilot_turn", + new=AsyncMock(), + ) as mock_enqueue, + ): + mock_stream_registry.create_session = AsyncMock() + + with pytest.raises(DuplicateChatMessageError): + await start_chat_turn(_request()) + + mock_stream_registry.create_session.assert_not_awaited() + mock_enqueue.assert_not_awaited() + + @pytest.mark.asyncio + async def test_happy_path_creates_stream_and_enqueues(self): + db_mock = MagicMock() + db_mock.find_user_link_owner = AsyncMock(return_value="owner-1") + session = MagicMock(session_id="sess-new") + + with ( + patch( + "backend.platform_linking.chat.platform_linking_db", + return_value=db_mock, + ), + patch( + "backend.platform_linking.chat.create_chat_session", + new=AsyncMock(return_value=session), + ), + patch( + "backend.platform_linking.chat.append_and_save_message", + new=AsyncMock(return_value=MagicMock()), + ), + patch( + "backend.platform_linking.chat.stream_registry" + ) as mock_stream_registry, + patch( + "backend.platform_linking.chat.enqueue_copilot_turn", + new=AsyncMock(), + ) as mock_enqueue, + ): + mock_stream_registry.create_session = AsyncMock() + handle = await start_chat_turn(_request()) + + assert handle.session_id == "sess-new" + assert handle.user_id == "owner-1" + assert handle.turn_id + assert handle.subscribe_from == "0-0" + mock_stream_registry.create_session.assert_awaited_once() + mock_enqueue.assert_awaited_once() + + @pytest.mark.asyncio + async def test_existing_session_id_wrong_user_raises_not_found(self): + db_mock = MagicMock() + db_mock.find_user_link_owner = AsyncMock(return_value="owner-1") + + with ( + patch( + "backend.platform_linking.chat.platform_linking_db", + return_value=db_mock, + ), + patch( + "backend.platform_linking.chat.get_chat_session", + new=AsyncMock(return_value=None), + ), + ): + with pytest.raises(NotFoundError): + await start_chat_turn(_request(session_id="someone-elses")) diff --git a/autogpt_platform/backend/backend/platform_linking/db.py b/autogpt_platform/backend/backend/platform_linking/db.py new file mode 100644 index 0000000000..8e419fba72 --- /dev/null +++ b/autogpt_platform/backend/backend/platform_linking/db.py @@ -0,0 +1,428 @@ +"""Platform link DB operations. + +Directly accessed by the ``AgentServer`` / ``DatabaseManager`` pods (which +hold the Prisma connection). Other services go through +``backend.data.db_accessors.platform_linking_db`` so calls are transparently +routed via ``DatabaseManagerAsyncClient`` when no local Prisma is available. +""" + +import logging +import secrets +from datetime import datetime, timedelta, timezone + +from prisma.errors import UniqueViolationError +from prisma.models import PlatformLink, PlatformLinkToken, PlatformUserLink + +from backend.data.db import transaction +from backend.util.exceptions import ( + LinkAlreadyExistsError, + LinkFlowMismatchError, + LinkTokenExpiredError, + NotAuthorizedError, + NotFoundError, +) +from backend.util.settings import Settings + +from .models import ( + ConfirmLinkResponse, + ConfirmUserLinkResponse, + CreateLinkTokenRequest, + CreateUserLinkTokenRequest, + DeleteLinkResponse, + LinkTokenInfoResponse, + LinkTokenResponse, + LinkTokenStatusResponse, + LinkType, + PlatformLinkInfo, + PlatformUserLinkInfo, + ResolveResponse, +) + +logger = logging.getLogger(__name__) + +LINK_TOKEN_EXPIRY_MINUTES = 30 + + +def _link_base_url() -> str: + return Settings().config.platform_link_base_url + + +# ── Owner lookups ───────────────────────────────────────────────────── +# These return the owning AutoGPT user_id (or None). Using scalars instead +# of Prisma models keeps everything RPC-safe — Prisma objects are rejected +# by AppService's result validator. + + +async def find_server_link_owner(platform: str, platform_server_id: str) -> str | None: + link = await PlatformLink.prisma().find_first( + where={"platform": platform, "platformServerId": platform_server_id} + ) + return link.userId if link else None + + +async def find_user_link_owner(platform: str, platform_user_id: str) -> str | None: + link = await PlatformUserLink.prisma().find_unique( + where={ + "platform_platformUserId": { + "platform": platform, + "platformUserId": platform_user_id, + } + } + ) + return link.userId if link else None + + +async def resolve_server_link( + platform: str, platform_server_id: str +) -> ResolveResponse: + owner = await find_server_link_owner(platform, platform_server_id) + return ResolveResponse(linked=owner is not None) + + +async def resolve_user_link(platform: str, platform_user_id: str) -> ResolveResponse: + owner = await find_user_link_owner(platform, platform_user_id) + return ResolveResponse(linked=owner is not None) + + +# ── Token creation ──────────────────────────────────────────────────── + + +async def create_server_link_token( + request: CreateLinkTokenRequest, +) -> LinkTokenResponse: + platform = request.platform.value + + if await find_server_link_owner(platform, request.platform_server_id): + raise LinkAlreadyExistsError( + "This server is already linked to an AutoGPT account." + ) + + token = secrets.token_urlsafe(32) + expires_at = datetime.now(timezone.utc) + timedelta( + minutes=LINK_TOKEN_EXPIRY_MINUTES + ) + + # Atomic: invalidate pending tokens + create the new one, so two racing + # create calls can't leave two valid tokens for the same target. + async with transaction() as tx: + await PlatformLinkToken.prisma(tx).update_many( + where={ + "platform": platform, + "linkType": LinkType.SERVER.value, + "platformServerId": request.platform_server_id, + "usedAt": None, + }, + data={"usedAt": datetime.now(timezone.utc)}, + ) + await PlatformLinkToken.prisma(tx).create( + data={ + "token": token, + "platform": platform, + "linkType": LinkType.SERVER.value, + "platformServerId": request.platform_server_id, + "platformUserId": request.platform_user_id, + "platformUsername": request.platform_username, + "serverName": request.server_name, + "channelId": request.channel_id, + "expiresAt": expires_at, + } + ) + + logger.info( + "Created SERVER link token for %s server %s (expires %s)", + platform, + request.platform_server_id, + expires_at.isoformat(), + ) + + return LinkTokenResponse( + token=token, + expires_at=expires_at, + link_url=f"{_link_base_url()}/{token}?platform={platform}", + ) + + +async def create_user_link_token( + request: CreateUserLinkTokenRequest, +) -> LinkTokenResponse: + platform = request.platform.value + + if await find_user_link_owner(platform, request.platform_user_id): + raise LinkAlreadyExistsError( + "Your DMs with the bot are already linked to an AutoGPT account." + ) + + token = secrets.token_urlsafe(32) + expires_at = datetime.now(timezone.utc) + timedelta( + minutes=LINK_TOKEN_EXPIRY_MINUTES + ) + + async with transaction() as tx: + await PlatformLinkToken.prisma(tx).update_many( + where={ + "platform": platform, + "linkType": LinkType.USER.value, + "platformUserId": request.platform_user_id, + "usedAt": None, + }, + data={"usedAt": datetime.now(timezone.utc)}, + ) + await PlatformLinkToken.prisma(tx).create( + data={ + "token": token, + "platform": platform, + "linkType": LinkType.USER.value, + "platformUserId": request.platform_user_id, + "platformUsername": request.platform_username, + "expiresAt": expires_at, + } + ) + + logger.info( + "Created USER link token for %s (expires %s)", platform, expires_at.isoformat() + ) + + return LinkTokenResponse( + token=token, + expires_at=expires_at, + link_url=f"{_link_base_url()}/{token}?platform={platform}", + ) + + +# ── Token status / info ─────────────────────────────────────────────── + + +async def get_link_token_status(token: str) -> LinkTokenStatusResponse: + link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token}) + + if not link_token: + raise NotFoundError("Token not found.") + + if link_token.usedAt is not None: + # A superseded token (invalidated by create_*_token) has usedAt set + # without a backing link row — report expired, not linked. + if link_token.linkType == LinkType.USER.value: + owner = await find_user_link_owner( + link_token.platform, link_token.platformUserId + ) + else: + owner = ( + await find_server_link_owner( + link_token.platform, link_token.platformServerId + ) + if link_token.platformServerId + else None + ) + return LinkTokenStatusResponse(status="linked" if owner else "expired") + + if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc): + return LinkTokenStatusResponse(status="expired") + + return LinkTokenStatusResponse(status="pending") + + +async def get_link_token_info(token: str) -> LinkTokenInfoResponse: + link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token}) + + if not link_token or link_token.usedAt is not None: + raise NotFoundError("Token not found.") + + if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc): + raise LinkTokenExpiredError("Token expired.") + + return LinkTokenInfoResponse( + platform=link_token.platform, + link_type=LinkType(link_token.linkType), + server_name=link_token.serverName, + ) + + +# ── Confirmation (user-facing, JWT-authed) ──────────────────────────── + + +async def confirm_server_link(token: str, user_id: str) -> ConfirmLinkResponse: + link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token}) + + if not link_token: + raise NotFoundError("Token not found.") + if link_token.linkType != LinkType.SERVER.value: + raise LinkFlowMismatchError("This link is for a different linking flow.") + if link_token.usedAt is not None: + raise LinkTokenExpiredError("This link has already been used.") + if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc): + raise LinkTokenExpiredError("This link has expired.") + if not link_token.platformServerId: + raise LinkFlowMismatchError("Server token missing server ID.") + + owner = await find_server_link_owner( + link_token.platform, link_token.platformServerId + ) + if owner: + detail = ( + "This server is already linked to your account." + if owner == user_id + else "This server is already linked to another AutoGPT account." + ) + raise LinkAlreadyExistsError(detail) + + # Atomic consume + create so a failed create doesn't burn the token. + now = datetime.now(timezone.utc) + try: + async with transaction() as tx: + updated = await PlatformLinkToken.prisma(tx).update_many( + where={"token": token, "usedAt": None, "expiresAt": {"gt": now}}, + data={"usedAt": now}, + ) + if updated == 0: + raise LinkTokenExpiredError("This link has already been used.") + await PlatformLink.prisma(tx).create( + data={ + "userId": user_id, + "platform": link_token.platform, + "platformServerId": link_token.platformServerId, + "ownerPlatformUserId": link_token.platformUserId, + "serverName": link_token.serverName, + } + ) + except UniqueViolationError as exc: + raise LinkAlreadyExistsError( + "This server was just linked by another request." + ) from exc + + logger.info( + "Linked %s server %s to user ...%s", + link_token.platform, + link_token.platformServerId, + user_id[-8:], + ) + + return ConfirmLinkResponse( + success=True, + platform=link_token.platform, + platform_server_id=link_token.platformServerId, + server_name=link_token.serverName, + ) + + +async def confirm_user_link(token: str, user_id: str) -> ConfirmUserLinkResponse: + link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token}) + + if not link_token: + raise NotFoundError("Token not found.") + if link_token.linkType != LinkType.USER.value: + raise LinkFlowMismatchError("This link is for a different linking flow.") + if link_token.usedAt is not None: + raise LinkTokenExpiredError("This link has already been used.") + if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc): + raise LinkTokenExpiredError("This link has expired.") + + owner = await find_user_link_owner(link_token.platform, link_token.platformUserId) + if owner: + detail = ( + "Your DMs are already linked to your account." + if owner == user_id + else "This platform user is already linked to another AutoGPT account." + ) + raise LinkAlreadyExistsError(detail) + + now = datetime.now(timezone.utc) + try: + async with transaction() as tx: + updated = await PlatformLinkToken.prisma(tx).update_many( + where={"token": token, "usedAt": None, "expiresAt": {"gt": now}}, + data={"usedAt": now}, + ) + if updated == 0: + raise LinkTokenExpiredError("This link has already been used.") + await PlatformUserLink.prisma(tx).create( + data={ + "userId": user_id, + "platform": link_token.platform, + "platformUserId": link_token.platformUserId, + "platformUsername": link_token.platformUsername, + } + ) + except UniqueViolationError as exc: + raise LinkAlreadyExistsError( + "Your DMs were just linked by another request." + ) from exc + + logger.info( + "Linked %s DMs to AutoGPT user ...%s", link_token.platform, user_id[-8:] + ) + + return ConfirmUserLinkResponse( + success=True, + platform=link_token.platform, + platform_user_id=link_token.platformUserId, + ) + + +# ── Listing ─────────────────────────────────────────────────────────── + + +async def list_server_links(user_id: str) -> list[PlatformLinkInfo]: + links = await PlatformLink.prisma().find_many( + where={"userId": user_id}, + order={"linkedAt": "desc"}, + ) + return [ + PlatformLinkInfo( + id=link.id, + platform=link.platform, + platform_server_id=link.platformServerId, + owner_platform_user_id=link.ownerPlatformUserId, + server_name=link.serverName, + linked_at=link.linkedAt, + ) + for link in links + ] + + +async def list_user_links(user_id: str) -> list[PlatformUserLinkInfo]: + links = await PlatformUserLink.prisma().find_many( + where={"userId": user_id}, + order={"linkedAt": "desc"}, + ) + return [ + PlatformUserLinkInfo( + id=link.id, + platform=link.platform, + platform_user_id=link.platformUserId, + platform_username=link.platformUsername, + linked_at=link.linkedAt, + ) + for link in links + ] + + +# ── Deletion ────────────────────────────────────────────────────────── + + +async def delete_server_link(link_id: str, user_id: str) -> DeleteLinkResponse: + link = await PlatformLink.prisma().find_unique(where={"id": link_id}) + if not link: + raise NotFoundError("Link not found.") + if link.userId != user_id: + raise NotAuthorizedError("Not your link.") + + await PlatformLink.prisma().delete(where={"id": link_id}) + logger.info( + "Unlinked %s server %s from user ...%s", + link.platform, + link.platformServerId, + user_id[-8:], + ) + return DeleteLinkResponse(success=True) + + +async def delete_user_link(link_id: str, user_id: str) -> DeleteLinkResponse: + link = await PlatformUserLink.prisma().find_unique(where={"id": link_id}) + if not link: + raise NotFoundError("Link not found.") + if link.userId != user_id: + raise NotAuthorizedError("Not your link.") + + await PlatformUserLink.prisma().delete(where={"id": link_id}) + logger.info("Unlinked %s DMs from AutoGPT user ...%s", link.platform, user_id[-8:]) + return DeleteLinkResponse(success=True) diff --git a/autogpt_platform/backend/backend/platform_linking/db_test.py b/autogpt_platform/backend/backend/platform_linking/db_test.py new file mode 100644 index 0000000000..b02679103f --- /dev/null +++ b/autogpt_platform/backend/backend/platform_linking/db_test.py @@ -0,0 +1,481 @@ +"""Unit tests for platform_linking DB operations.""" + +from contextlib import asynccontextmanager +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from backend.util.exceptions import ( + LinkAlreadyExistsError, + LinkFlowMismatchError, + LinkTokenExpiredError, + NotAuthorizedError, + NotFoundError, +) + +from .db import ( + confirm_server_link, + confirm_user_link, + create_server_link_token, + create_user_link_token, + delete_server_link, + delete_user_link, + get_link_token_info, + get_link_token_status, + resolve_server_link, + resolve_user_link, +) +from .models import ( + CreateLinkTokenRequest, + CreateUserLinkTokenRequest, + LinkType, + Platform, +) + + +@asynccontextmanager +async def _fake_transaction(): + # Avoids Prisma's tx binding asyncio primitives to the wrong loop in tests. + yield MagicMock() + + +# ── Resolve ────────────────────────────────────────────────────────── + + +class TestResolve: + @pytest.mark.asyncio + async def test_server_linked(self): + with patch("backend.platform_linking.db.PlatformLink") as mock_link: + mock_link.prisma.return_value.find_first = AsyncMock( + return_value=MagicMock(userId="u-123") + ) + result = await resolve_server_link("DISCORD", "g1") + assert result.linked is True + + @pytest.mark.asyncio + async def test_server_unlinked(self): + with patch("backend.platform_linking.db.PlatformLink") as mock_link: + mock_link.prisma.return_value.find_first = AsyncMock(return_value=None) + result = await resolve_server_link("DISCORD", "g1") + assert result.linked is False + + @pytest.mark.asyncio + async def test_user_linked(self): + with patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link: + mock_user_link.prisma.return_value.find_unique = AsyncMock( + return_value=MagicMock(userId="u-xyz") + ) + result = await resolve_user_link("DISCORD", "pu1") + assert result.linked is True + + @pytest.mark.asyncio + async def test_user_unlinked(self): + with patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link: + mock_user_link.prisma.return_value.find_unique = AsyncMock( + return_value=None + ) + result = await resolve_user_link("DISCORD", "pu1") + assert result.linked is False + + +# ── Token creation ─────────────────────────────────────────────────── + + +class TestCreateServerLinkToken: + @pytest.mark.asyncio + async def test_creates_token_for_unlinked_server(self): + with ( + patch("backend.platform_linking.db.PlatformLink") as mock_link, + patch( + "backend.platform_linking.db.transaction", + new=_fake_transaction, + ), + patch("backend.platform_linking.db.PlatformLinkToken") as mock_token_model, + ): + mock_link.prisma.return_value.find_first = AsyncMock(return_value=None) + mock_token_model.prisma.return_value.update_many = AsyncMock(return_value=0) + mock_token_model.prisma.return_value.create = AsyncMock( + return_value=MagicMock() + ) + + result = await create_server_link_token( + CreateLinkTokenRequest( + platform=Platform.DISCORD, + platform_server_id="g1", + platform_user_id="u1", + server_name="Test", + ), + ) + + assert result.token + assert result.token in result.link_url + assert "?platform=DISCORD" in result.link_url + + @pytest.mark.asyncio + async def test_rejects_when_already_linked(self): + with patch("backend.platform_linking.db.PlatformLink") as mock_link: + mock_link.prisma.return_value.find_first = AsyncMock( + return_value=MagicMock(userId="u-owner") + ) + with pytest.raises(LinkAlreadyExistsError): + await create_server_link_token( + CreateLinkTokenRequest( + platform=Platform.DISCORD, + platform_server_id="g1", + platform_user_id="u1", + ), + ) + + +class TestCreateUserLinkToken: + @pytest.mark.asyncio + async def test_creates_token_for_unlinked_user(self): + with ( + patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link, + patch( + "backend.platform_linking.db.transaction", + new=_fake_transaction, + ), + patch("backend.platform_linking.db.PlatformLinkToken") as mock_token_model, + ): + mock_user_link.prisma.return_value.find_unique = AsyncMock( + return_value=None + ) + mock_token_model.prisma.return_value.update_many = AsyncMock(return_value=0) + mock_token_model.prisma.return_value.create = AsyncMock( + return_value=MagicMock() + ) + + result = await create_user_link_token( + CreateUserLinkTokenRequest( + platform=Platform.DISCORD, + platform_user_id="pu1", + platform_username="Bently", + ), + ) + + assert result.token + assert result.token in result.link_url + + @pytest.mark.asyncio + async def test_rejects_when_already_linked(self): + with patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link: + mock_user_link.prisma.return_value.find_unique = AsyncMock( + return_value=MagicMock(userId="u-owner") + ) + with pytest.raises(LinkAlreadyExistsError): + await create_user_link_token( + CreateUserLinkTokenRequest( + platform=Platform.DISCORD, + platform_user_id="pu1", + ), + ) + + +# ── Token status / info ─────────────────────────────────────────────── + + +class TestGetLinkTokenStatus: + @pytest.mark.asyncio + async def test_not_found(self): + with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model: + mock_model.prisma.return_value.find_unique = AsyncMock(return_value=None) + with pytest.raises(NotFoundError): + await get_link_token_status("abc") + + @pytest.mark.asyncio + async def test_pending(self): + future = datetime.now(timezone.utc) + timedelta(minutes=10) + fake_token = MagicMock(usedAt=None, expiresAt=future) + with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model: + mock_model.prisma.return_value.find_unique = AsyncMock( + return_value=fake_token + ) + result = await get_link_token_status("abc") + assert result.status == "pending" + + @pytest.mark.asyncio + async def test_expired_by_time(self): + past = datetime.now(timezone.utc) - timedelta(minutes=10) + fake_token = MagicMock(usedAt=None, expiresAt=past) + with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model: + mock_model.prisma.return_value.find_unique = AsyncMock( + return_value=fake_token + ) + result = await get_link_token_status("abc") + assert result.status == "expired" + + @pytest.mark.asyncio + async def test_used_with_user_link_reports_linked(self): + fake_token = MagicMock( + usedAt=datetime.now(timezone.utc), + linkType=LinkType.USER.value, + platform="DISCORD", + platformUserId="pu1", + ) + with ( + patch("backend.platform_linking.db.PlatformLinkToken") as mock_token, + patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link, + ): + mock_token.prisma.return_value.find_unique = AsyncMock( + return_value=fake_token + ) + mock_user_link.prisma.return_value.find_unique = AsyncMock( + return_value=MagicMock(userId="u-owner") + ) + result = await get_link_token_status("abc") + assert result.status == "linked" + + @pytest.mark.asyncio + async def test_used_without_link_reports_expired(self): + # Superseded token: usedAt set, but no backing link row. + fake_token = MagicMock( + usedAt=datetime.now(timezone.utc), + linkType=LinkType.SERVER.value, + platform="DISCORD", + platformServerId="g1", + ) + with ( + patch("backend.platform_linking.db.PlatformLinkToken") as mock_token, + patch("backend.platform_linking.db.PlatformLink") as mock_link, + ): + mock_token.prisma.return_value.find_unique = AsyncMock( + return_value=fake_token + ) + mock_link.prisma.return_value.find_first = AsyncMock(return_value=None) + result = await get_link_token_status("abc") + assert result.status == "expired" + + +class TestGetLinkTokenInfo: + @pytest.mark.asyncio + async def test_not_found(self): + with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model: + mock_model.prisma.return_value.find_unique = AsyncMock(return_value=None) + with pytest.raises(NotFoundError): + await get_link_token_info("abc") + + @pytest.mark.asyncio + async def test_used_returns_not_found(self): + fake_token = MagicMock(usedAt=datetime.now(timezone.utc)) + with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model: + mock_model.prisma.return_value.find_unique = AsyncMock( + return_value=fake_token + ) + with pytest.raises(NotFoundError): + await get_link_token_info("abc") + + @pytest.mark.asyncio + async def test_expired_raises_expired(self): + past = datetime.now(timezone.utc) - timedelta(minutes=5) + fake_token = MagicMock(usedAt=None, expiresAt=past) + with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model: + mock_model.prisma.return_value.find_unique = AsyncMock( + return_value=fake_token + ) + with pytest.raises(LinkTokenExpiredError): + await get_link_token_info("abc") + + @pytest.mark.asyncio + async def test_success_returns_display_info(self): + future = datetime.now(timezone.utc) + timedelta(minutes=10) + fake_token = MagicMock( + usedAt=None, + expiresAt=future, + platform="DISCORD", + linkType=LinkType.SERVER.value, + serverName="My Server", + ) + with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model: + mock_model.prisma.return_value.find_unique = AsyncMock( + return_value=fake_token + ) + result = await get_link_token_info("abc") + assert result.platform == "DISCORD" + assert result.link_type == LinkType.SERVER + assert result.server_name == "My Server" + + +# ── Confirmation ───────────────────────────────────────────────────── + + +class TestConfirmServerLink: + @pytest.mark.asyncio + async def test_not_found(self): + with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model: + mock_model.prisma.return_value.find_unique = AsyncMock(return_value=None) + with pytest.raises(NotFoundError): + await confirm_server_link("abc", "u1") + + @pytest.mark.asyncio + async def test_wrong_link_type_rejected(self): + fake_token = MagicMock(linkType=LinkType.USER.value) + with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model: + mock_model.prisma.return_value.find_unique = AsyncMock( + return_value=fake_token + ) + with pytest.raises(LinkFlowMismatchError): + await confirm_server_link("abc", "u1") + + @pytest.mark.asyncio + async def test_already_used(self): + fake_token = MagicMock( + linkType=LinkType.SERVER.value, usedAt=datetime.now(timezone.utc) + ) + with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model: + mock_model.prisma.return_value.find_unique = AsyncMock( + return_value=fake_token + ) + with pytest.raises(LinkTokenExpiredError): + await confirm_server_link("abc", "u1") + + @pytest.mark.asyncio + async def test_expired_by_time(self): + fake_token = MagicMock( + linkType=LinkType.SERVER.value, + usedAt=None, + expiresAt=datetime.now(timezone.utc) - timedelta(minutes=5), + ) + with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model: + mock_model.prisma.return_value.find_unique = AsyncMock( + return_value=fake_token + ) + with pytest.raises(LinkTokenExpiredError): + await confirm_server_link("abc", "u1") + + @pytest.mark.asyncio + async def test_already_linked_to_same_user(self): + fake_token = MagicMock( + linkType=LinkType.SERVER.value, + usedAt=None, + expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10), + platform="DISCORD", + platformServerId="g1", + ) + with ( + patch("backend.platform_linking.db.PlatformLinkToken") as mock_token, + patch("backend.platform_linking.db.PlatformLink") as mock_link, + ): + mock_token.prisma.return_value.find_unique = AsyncMock( + return_value=fake_token + ) + mock_link.prisma.return_value.find_first = AsyncMock( + return_value=MagicMock(userId="u1") + ) + with pytest.raises(LinkAlreadyExistsError) as exc_info: + await confirm_server_link("abc", "u1") + assert "your account" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_already_linked_to_other_user(self): + fake_token = MagicMock( + linkType=LinkType.SERVER.value, + usedAt=None, + expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10), + platform="DISCORD", + platformServerId="g1", + ) + with ( + patch("backend.platform_linking.db.PlatformLinkToken") as mock_token, + patch("backend.platform_linking.db.PlatformLink") as mock_link, + ): + mock_token.prisma.return_value.find_unique = AsyncMock( + return_value=fake_token + ) + mock_link.prisma.return_value.find_first = AsyncMock( + return_value=MagicMock(userId="other-user") + ) + with pytest.raises(LinkAlreadyExistsError) as exc_info: + await confirm_server_link("abc", "u1") + assert "another" in str(exc_info.value) + + +class TestConfirmUserLink: + @pytest.mark.asyncio + async def test_not_found(self): + with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model: + mock_model.prisma.return_value.find_unique = AsyncMock(return_value=None) + with pytest.raises(NotFoundError): + await confirm_user_link("abc", "u1") + + @pytest.mark.asyncio + async def test_wrong_link_type_rejected(self): + fake_token = MagicMock(linkType=LinkType.SERVER.value) + with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model: + mock_model.prisma.return_value.find_unique = AsyncMock( + return_value=fake_token + ) + with pytest.raises(LinkFlowMismatchError): + await confirm_user_link("abc", "u1") + + @pytest.mark.asyncio + async def test_expired_by_time(self): + fake_token = MagicMock( + linkType=LinkType.USER.value, + usedAt=None, + expiresAt=datetime.now(timezone.utc) - timedelta(minutes=5), + ) + with patch("backend.platform_linking.db.PlatformLinkToken") as mock_model: + mock_model.prisma.return_value.find_unique = AsyncMock( + return_value=fake_token + ) + with pytest.raises(LinkTokenExpiredError): + await confirm_user_link("abc", "u1") + + @pytest.mark.asyncio + async def test_already_linked_to_other_user(self): + fake_token = MagicMock( + linkType=LinkType.USER.value, + usedAt=None, + expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10), + platform="DISCORD", + platformUserId="pu1", + ) + with ( + patch("backend.platform_linking.db.PlatformLinkToken") as mock_token, + patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link, + ): + mock_token.prisma.return_value.find_unique = AsyncMock( + return_value=fake_token + ) + mock_user_link.prisma.return_value.find_unique = AsyncMock( + return_value=MagicMock(userId="other-user") + ) + with pytest.raises(LinkAlreadyExistsError): + await confirm_user_link("abc", "u1") + + +# ── Delete (authz checks) ──────────────────────────────────────────── + + +class TestDeleteLinks: + @pytest.mark.asyncio + async def test_delete_server_link_not_found(self): + with patch("backend.platform_linking.db.PlatformLink") as mock_model: + mock_model.prisma.return_value.find_unique = AsyncMock(return_value=None) + with pytest.raises(NotFoundError): + await delete_server_link("x", "u1") + + @pytest.mark.asyncio + async def test_delete_server_link_not_owned(self): + link = MagicMock(userId="owner-A", platform="DISCORD", platformServerId="g1") + with patch("backend.platform_linking.db.PlatformLink") as mock_model: + mock_model.prisma.return_value.find_unique = AsyncMock(return_value=link) + with pytest.raises(NotAuthorizedError): + await delete_server_link("x", "u-other") + + @pytest.mark.asyncio + async def test_delete_user_link_not_found(self): + with patch("backend.platform_linking.db.PlatformUserLink") as mock_model: + mock_model.prisma.return_value.find_unique = AsyncMock(return_value=None) + with pytest.raises(NotFoundError): + await delete_user_link("x", "u1") + + @pytest.mark.asyncio + async def test_delete_user_link_not_owned(self): + link = MagicMock(userId="owner-A", platform="DISCORD") + with patch("backend.platform_linking.db.PlatformUserLink") as mock_model: + mock_model.prisma.return_value.find_unique = AsyncMock(return_value=link) + with pytest.raises(NotAuthorizedError): + await delete_user_link("x", "u-other") diff --git a/autogpt_platform/backend/backend/platform_linking/manager.py b/autogpt_platform/backend/backend/platform_linking/manager.py new file mode 100644 index 0000000000..c8c7fdbd3a --- /dev/null +++ b/autogpt_platform/backend/backend/platform_linking/manager.py @@ -0,0 +1,82 @@ +"""AppService exposing bot-facing platform_linking ops over internal RPC.""" + +import logging + +from backend.data.db_accessors import platform_linking_db +from backend.util.service import AppService, AppServiceClient, endpoint_to_async, expose +from backend.util.settings import Settings + +from .chat import start_chat_turn +from .models import ( + BotChatRequest, + ChatTurnHandle, + CreateLinkTokenRequest, + CreateUserLinkTokenRequest, + LinkTokenResponse, + LinkTokenStatusResponse, + Platform, + ResolveResponse, +) + +logger = logging.getLogger(__name__) + + +class PlatformLinkingManager(AppService): + @classmethod + def get_port(cls) -> int: + return Settings().config.platform_linking_service_port + + @expose + async def resolve_server_link( + self, platform: Platform, platform_server_id: str + ) -> ResolveResponse: + return await platform_linking_db().resolve_server_link( + platform.value, platform_server_id + ) + + @expose + async def resolve_user_link( + self, platform: Platform, platform_user_id: str + ) -> ResolveResponse: + return await platform_linking_db().resolve_user_link( + platform.value, platform_user_id + ) + + @expose + async def create_server_link_token( + self, request: CreateLinkTokenRequest + ) -> LinkTokenResponse: + return await platform_linking_db().create_server_link_token(request) + + @expose + async def create_user_link_token( + self, request: CreateUserLinkTokenRequest + ) -> LinkTokenResponse: + return await platform_linking_db().create_user_link_token(request) + + @expose + async def get_link_token_status(self, token: str) -> LinkTokenStatusResponse: + return await platform_linking_db().get_link_token_status(token) + + @expose + async def start_chat_turn(self, request: BotChatRequest) -> ChatTurnHandle: + return await start_chat_turn(request) + + +class PlatformLinkingManagerClient(AppServiceClient): + @classmethod + def get_service_type(cls): + return PlatformLinkingManager + + resolve_server_link = endpoint_to_async(PlatformLinkingManager.resolve_server_link) + resolve_user_link = endpoint_to_async(PlatformLinkingManager.resolve_user_link) + create_server_link_token = endpoint_to_async( + PlatformLinkingManager.create_server_link_token + ) + create_user_link_token = endpoint_to_async( + PlatformLinkingManager.create_user_link_token + ) + get_link_token_status = endpoint_to_async( + PlatformLinkingManager.get_link_token_status + ) + start_chat_turn = endpoint_to_async(PlatformLinkingManager.start_chat_turn) diff --git a/autogpt_platform/backend/backend/platform_linking/manager_test.py b/autogpt_platform/backend/backend/platform_linking/manager_test.py new file mode 100644 index 0000000000..a768c08dac --- /dev/null +++ b/autogpt_platform/backend/backend/platform_linking/manager_test.py @@ -0,0 +1,346 @@ +"""Tests for PlatformLinkingManager RPC wiring and confirm-token races.""" + +import asyncio +from contextlib import asynccontextmanager +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from backend.util.exceptions import LinkTokenExpiredError + +from .db import confirm_server_link, confirm_user_link +from .manager import PlatformLinkingManager, PlatformLinkingManagerClient +from .models import ( + BotChatRequest, + CreateLinkTokenRequest, + CreateUserLinkTokenRequest, + LinkType, + Platform, + ResolveResponse, +) + + +@asynccontextmanager +async def _fake_transaction(): + yield MagicMock() + + +class TestManagerWiring: + def test_get_port(self): + assert PlatformLinkingManager.get_port() == 8009 + + def test_client_exposes_expected_rpc_surface(self): + service_type = PlatformLinkingManagerClient.get_service_type() + assert service_type is PlatformLinkingManager + + expected = { + "resolve_server_link", + "resolve_user_link", + "create_server_link_token", + "create_user_link_token", + "get_link_token_status", + "start_chat_turn", + } + for name in expected: + assert hasattr( + PlatformLinkingManagerClient, name + ), f"Client missing RPC stub: {name}" + + for name in ( + "confirm_server_link", + "confirm_user_link", + "list_server_links", + "list_user_links", + "delete_server_link", + "delete_user_link", + ): + assert not hasattr( + PlatformLinkingManagerClient, name + ), f"User-facing method leaked to bot client: {name}" + + @pytest.mark.asyncio + async def test_resolve_server_link_delegates_to_accessor(self): + manager = PlatformLinkingManager() + db_mock = MagicMock() + db_mock.resolve_server_link = AsyncMock( + return_value=ResolveResponse(linked=True) + ) + with patch( + "backend.platform_linking.manager.platform_linking_db", + return_value=db_mock, + ): + result = await manager.resolve_server_link(Platform.DISCORD, "g1") + db_mock.resolve_server_link.assert_awaited_once_with("DISCORD", "g1") + assert result.linked is True + + @pytest.mark.asyncio + async def test_resolve_user_link_delegates_to_accessor(self): + manager = PlatformLinkingManager() + db_mock = MagicMock() + db_mock.resolve_user_link = AsyncMock( + return_value=ResolveResponse(linked=False) + ) + with patch( + "backend.platform_linking.manager.platform_linking_db", + return_value=db_mock, + ): + result = await manager.resolve_user_link(Platform.DISCORD, "pu1") + db_mock.resolve_user_link.assert_awaited_once_with("DISCORD", "pu1") + assert result.linked is False + + @pytest.mark.asyncio + async def test_create_server_link_token_delegates(self): + manager = PlatformLinkingManager() + req = CreateLinkTokenRequest( + platform=Platform.DISCORD, + platform_server_id="g1", + platform_user_id="u1", + ) + fake_response = MagicMock() + db_mock = MagicMock() + db_mock.create_server_link_token = AsyncMock(return_value=fake_response) + with patch( + "backend.platform_linking.manager.platform_linking_db", + return_value=db_mock, + ): + result = await manager.create_server_link_token(req) + db_mock.create_server_link_token.assert_awaited_once_with(req) + assert result is fake_response + + @pytest.mark.asyncio + async def test_create_user_link_token_delegates(self): + manager = PlatformLinkingManager() + req = CreateUserLinkTokenRequest( + platform=Platform.DISCORD, platform_user_id="pu1" + ) + fake_response = MagicMock() + db_mock = MagicMock() + db_mock.create_user_link_token = AsyncMock(return_value=fake_response) + with patch( + "backend.platform_linking.manager.platform_linking_db", + return_value=db_mock, + ): + result = await manager.create_user_link_token(req) + db_mock.create_user_link_token.assert_awaited_once_with(req) + assert result is fake_response + + @pytest.mark.asyncio + async def test_get_link_token_status_delegates(self): + manager = PlatformLinkingManager() + fake_response = MagicMock() + db_mock = MagicMock() + db_mock.get_link_token_status = AsyncMock(return_value=fake_response) + with patch( + "backend.platform_linking.manager.platform_linking_db", + return_value=db_mock, + ): + result = await manager.get_link_token_status("tok") + db_mock.get_link_token_status.assert_awaited_once_with("tok") + assert result is fake_response + + @pytest.mark.asyncio + async def test_start_chat_turn_delegates(self): + manager = PlatformLinkingManager() + req = BotChatRequest( + platform=Platform.DISCORD, + platform_user_id="pu1", + message="hi", + ) + fake_response = MagicMock() + with patch( + "backend.platform_linking.manager.start_chat_turn", + new=AsyncMock(return_value=fake_response), + ) as stub: + result = await manager.start_chat_turn(req) + stub.assert_awaited_once_with(req) + assert result is fake_response + + +class TestAdversarialConfirmRace: + """Concurrent confirm of one token: exactly one winner via ``update_many`` + guarded on ``usedAt = None``.""" + + @pytest.mark.asyncio + async def test_second_confirm_loses(self): + # update_many returns 0 → caller lost the race + fake_token = MagicMock( + linkType=LinkType.SERVER.value, + usedAt=None, + expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10), + platform="DISCORD", + platformServerId="g1", + ) + + with ( + patch("backend.platform_linking.db.PlatformLinkToken") as mock_token, + patch("backend.platform_linking.db.PlatformLink") as mock_link, + patch("backend.platform_linking.db.transaction", new=_fake_transaction), + ): + mock_token.prisma.return_value.find_unique = AsyncMock( + return_value=fake_token + ) + mock_link.prisma.return_value.find_first = AsyncMock(return_value=None) + mock_token.prisma.return_value.update_many = AsyncMock(return_value=0) + + with pytest.raises(LinkTokenExpiredError): + await confirm_server_link("abc", "user-late") + + @pytest.mark.asyncio + async def test_second_confirm_wins_when_update_many_returns_one(self): + fake_token = MagicMock( + linkType=LinkType.SERVER.value, + usedAt=None, + expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10), + platform="DISCORD", + platformServerId="g1", + platformUserId="pu1", + serverName="S1", + ) + + with ( + patch("backend.platform_linking.db.PlatformLinkToken") as mock_token, + patch("backend.platform_linking.db.PlatformLink") as mock_link, + patch("backend.platform_linking.db.transaction", new=_fake_transaction), + ): + mock_token.prisma.return_value.find_unique = AsyncMock( + return_value=fake_token + ) + mock_link.prisma.return_value.find_first = AsyncMock(return_value=None) + mock_token.prisma.return_value.update_many = AsyncMock(return_value=1) + mock_link.prisma.return_value.create = AsyncMock(return_value=MagicMock()) + + result = await confirm_server_link("abc", "user-winner") + + assert result.success is True + assert result.platform_server_id == "g1" + + @pytest.mark.asyncio + async def test_gather_confirm_same_user_one_winner(self): + fake_token = MagicMock( + linkType=LinkType.SERVER.value, + usedAt=None, + expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10), + platform="DISCORD", + platformServerId="g1", + platformUserId="pu1", + serverName="S1", + ) + update_results = [1, 0] + + async def flaky_update_many(*args, **kwargs): + return update_results.pop(0) + + with ( + patch("backend.platform_linking.db.PlatformLinkToken") as mock_token, + patch("backend.platform_linking.db.PlatformLink") as mock_link, + patch("backend.platform_linking.db.transaction", new=_fake_transaction), + ): + mock_token.prisma.return_value.find_unique = AsyncMock( + return_value=fake_token + ) + mock_link.prisma.return_value.find_first = AsyncMock(return_value=None) + mock_token.prisma.return_value.update_many = flaky_update_many + mock_link.prisma.return_value.create = AsyncMock(return_value=MagicMock()) + + results = await asyncio.gather( + confirm_server_link("abc", "u1"), + confirm_server_link("abc", "u1"), + return_exceptions=True, + ) + + successes = [r for r in results if not isinstance(r, Exception)] + losses = [r for r in results if isinstance(r, LinkTokenExpiredError)] + assert len(successes) == 1 + assert len(losses) == 1 + + @pytest.mark.asyncio + async def test_gather_confirm_different_users_one_winner_no_hijack(self): + # Different users racing the same token: still exactly one winner, + # and the other gets a clean LinkTokenExpiredError (no partial state). + fake_token = MagicMock( + linkType=LinkType.SERVER.value, + usedAt=None, + expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10), + platform="DISCORD", + platformServerId="g1", + platformUserId="pu1", + serverName="S1", + ) + update_results = [1, 0] + + async def flaky_update_many(*args, **kwargs): + return update_results.pop(0) + + created_link_user_ids: list[str] = [] + + async def record_create(*, data): + created_link_user_ids.append(data["userId"]) + return MagicMock() + + with ( + patch("backend.platform_linking.db.PlatformLinkToken") as mock_token, + patch("backend.platform_linking.db.PlatformLink") as mock_link, + patch("backend.platform_linking.db.transaction", new=_fake_transaction), + ): + mock_token.prisma.return_value.find_unique = AsyncMock( + return_value=fake_token + ) + mock_link.prisma.return_value.find_first = AsyncMock(return_value=None) + mock_token.prisma.return_value.update_many = flaky_update_many + mock_link.prisma.return_value.create = record_create + + results = await asyncio.gather( + confirm_server_link("abc", "user-a"), + confirm_server_link("abc", "user-b"), + return_exceptions=True, + ) + + successes = [r for r in results if not isinstance(r, Exception)] + losses = [r for r in results if isinstance(r, LinkTokenExpiredError)] + assert len(successes) == 1 + assert len(losses) == 1 + assert len(created_link_user_ids) == 1 + assert created_link_user_ids[0] in ("user-a", "user-b") + + @pytest.mark.asyncio + async def test_gather_confirm_user_link_one_winner(self): + fake_token = MagicMock( + linkType=LinkType.USER.value, + usedAt=None, + expiresAt=datetime.now(timezone.utc) + timedelta(minutes=10), + platform="DISCORD", + platformUserId="pu1", + platformUsername="pu_name", + ) + update_results = [1, 0] + + async def flaky_update_many(*args, **kwargs): + return update_results.pop(0) + + with ( + patch("backend.platform_linking.db.PlatformLinkToken") as mock_token, + patch("backend.platform_linking.db.PlatformUserLink") as mock_user_link, + patch("backend.platform_linking.db.transaction", new=_fake_transaction), + ): + mock_token.prisma.return_value.find_unique = AsyncMock( + return_value=fake_token + ) + mock_user_link.prisma.return_value.find_unique = AsyncMock( + return_value=None + ) + mock_token.prisma.return_value.update_many = flaky_update_many + mock_user_link.prisma.return_value.create = AsyncMock( + return_value=MagicMock() + ) + + results = await asyncio.gather( + confirm_user_link("abc", "user-a"), + confirm_user_link("abc", "user-b"), + return_exceptions=True, + ) + + successes = [r for r in results if not isinstance(r, Exception)] + losses = [r for r in results if isinstance(r, LinkTokenExpiredError)] + assert len(successes) == 1 + assert len(losses) == 1 diff --git a/autogpt_platform/backend/backend/platform_linking/models.py b/autogpt_platform/backend/backend/platform_linking/models.py new file mode 100644 index 0000000000..fa17871b7f --- /dev/null +++ b/autogpt_platform/backend/backend/platform_linking/models.py @@ -0,0 +1,182 @@ +"""Pydantic models for platform_linking requests and responses.""" + +from datetime import datetime +from enum import Enum +from typing import Literal + +from pydantic import BaseModel, Field + + +class Platform(str, Enum): + """Mirror of the Prisma PlatformType enum.""" + + DISCORD = "DISCORD" + TELEGRAM = "TELEGRAM" + SLACK = "SLACK" + TEAMS = "TEAMS" + WHATSAPP = "WHATSAPP" + GITHUB = "GITHUB" + LINEAR = "LINEAR" + + +class LinkType(str, Enum): + SERVER = "SERVER" + USER = "USER" + + +# ── Request Models ───────────────────────────────────────────────────── + + +class CreateLinkTokenRequest(BaseModel): + platform: Platform = Field(description="Platform name") + platform_server_id: str = Field( + description="Server/guild/group ID on the platform", + min_length=1, + max_length=255, + ) + platform_user_id: str = Field( + description="Platform user ID of the person claiming ownership", + min_length=1, + max_length=255, + ) + platform_username: str | None = Field( + default=None, + description="Display name of the person claiming ownership", + max_length=255, + ) + server_name: str | None = Field( + default=None, + description="Display name of the server/group", + max_length=255, + ) + channel_id: str | None = Field( + default=None, + description="Channel ID so the bot can send a confirmation message", + max_length=255, + ) + + +class CreateUserLinkTokenRequest(BaseModel): + platform: Platform + platform_user_id: str = Field( + description="Platform user ID of the person linking their DMs", + min_length=1, + max_length=255, + ) + platform_username: str | None = Field( + default=None, + description="Their display name (best-effort for audit)", + max_length=255, + ) + + +class ResolveServerRequest(BaseModel): + platform: Platform + platform_server_id: str = Field( + description="Server/guild/group ID to look up", + min_length=1, + max_length=255, + ) + + +class ResolveUserRequest(BaseModel): + platform: Platform + platform_user_id: str = Field( + description="Platform user ID to look up", + min_length=1, + max_length=255, + ) + + +class BotChatRequest(BaseModel): + """Bot message request. If ``platform_server_id`` is set, the turn is + billed to that server's owner; otherwise billed to ``platform_user_id`` + (DM context).""" + + platform: Platform + platform_server_id: str | None = Field( + default=None, + description="Server/guild/group ID — null for DM context", + min_length=1, + max_length=255, + ) + platform_user_id: str = Field( + description="Platform user ID of the person who sent the message", + min_length=1, + max_length=255, + ) + message: str = Field( + description="The user's message", min_length=1, max_length=32000 + ) + session_id: str | None = Field( + default=None, + description="Existing CoPilot session ID. If omitted, a new session is created.", + ) + + +# ── Response Models ──────────────────────────────────────────────────── + + +class LinkTokenResponse(BaseModel): + token: str + expires_at: datetime + link_url: str + + +class LinkTokenStatusResponse(BaseModel): + status: Literal["pending", "linked", "expired"] + + +class LinkTokenInfoResponse(BaseModel): + platform: str + link_type: LinkType + server_name: str | None = None + + +class ResolveResponse(BaseModel): + linked: bool + + +class PlatformLinkInfo(BaseModel): + id: str + platform: str + platform_server_id: str + owner_platform_user_id: str + server_name: str | None + linked_at: datetime + + +class PlatformUserLinkInfo(BaseModel): + id: str + platform: str + platform_user_id: str + platform_username: str | None + linked_at: datetime + + +class ConfirmLinkResponse(BaseModel): + success: bool + link_type: LinkType = LinkType.SERVER + platform: str + platform_server_id: str + server_name: str | None + + +class ConfirmUserLinkResponse(BaseModel): + success: bool + link_type: LinkType = LinkType.USER + platform: str + platform_user_id: str + + +class DeleteLinkResponse(BaseModel): + success: bool + + +class ChatTurnHandle(BaseModel): + """Subscribe keys for a pending copilot turn.""" + + session_id: str + turn_id: str + user_id: str + subscribe_from: str = "0-0" diff --git a/autogpt_platform/backend/backend/platform_linking/models_test.py b/autogpt_platform/backend/backend/platform_linking/models_test.py new file mode 100644 index 0000000000..0f5a0918be --- /dev/null +++ b/autogpt_platform/backend/backend/platform_linking/models_test.py @@ -0,0 +1,178 @@ +"""Schema validation tests for platform_linking Pydantic models.""" + +import pytest +from pydantic import ValidationError + +from .models import ( + BotChatRequest, + ConfirmLinkResponse, + CreateLinkTokenRequest, + DeleteLinkResponse, + LinkTokenStatusResponse, + Platform, + ResolveResponse, + ResolveServerRequest, +) + + +class TestPlatformEnum: + def test_all_platforms_exist(self): + assert Platform.DISCORD.value == "DISCORD" + assert Platform.TELEGRAM.value == "TELEGRAM" + assert Platform.SLACK.value == "SLACK" + assert Platform.TEAMS.value == "TEAMS" + assert Platform.WHATSAPP.value == "WHATSAPP" + assert Platform.GITHUB.value == "GITHUB" + assert Platform.LINEAR.value == "LINEAR" + + +class TestCreateLinkTokenRequest: + def test_valid_request(self): + req = CreateLinkTokenRequest( + platform=Platform.DISCORD, + platform_server_id="1126875755960336515", + platform_user_id="353922987235213313", + platform_username="Bently", + server_name="My Discord Server", + ) + assert req.platform == Platform.DISCORD + assert req.platform_server_id == "1126875755960336515" + assert req.platform_user_id == "353922987235213313" + assert req.server_name == "My Discord Server" + + def test_minimal_request(self): + req = CreateLinkTokenRequest( + platform=Platform.TELEGRAM, + platform_server_id="-100123456789", + platform_user_id="987654321", + ) + assert req.server_name is None + assert req.platform_username is None + + def test_empty_server_id_rejected(self): + with pytest.raises(ValidationError): + CreateLinkTokenRequest( + platform=Platform.DISCORD, + platform_server_id="", + platform_user_id="123", + ) + + def test_too_long_server_id_rejected(self): + with pytest.raises(ValidationError): + CreateLinkTokenRequest( + platform=Platform.DISCORD, + platform_server_id="x" * 256, + platform_user_id="123", + ) + + def test_invalid_platform_rejected(self): + with pytest.raises(ValidationError): + CreateLinkTokenRequest.model_validate( + { + "platform": "INVALID", + "platform_server_id": "123", + "platform_user_id": "456", + } + ) + + +class TestResolveServerRequest: + def test_valid_request(self): + req = ResolveServerRequest( + platform=Platform.DISCORD, + platform_server_id="1126875755960336515", + ) + assert req.platform == Platform.DISCORD + assert req.platform_server_id == "1126875755960336515" + + def test_empty_server_id_rejected(self): + with pytest.raises(ValidationError): + ResolveServerRequest( + platform=Platform.SLACK, + platform_server_id="", + ) + + +class TestBotChatRequest: + def test_server_context(self): + req = BotChatRequest( + platform=Platform.DISCORD, + platform_server_id="1126875755960336515", + platform_user_id="353922987235213313", + message="Hello CoPilot!", + ) + assert req.platform == Platform.DISCORD + assert req.platform_server_id == "1126875755960336515" + assert req.session_id is None + + def test_dm_context_omits_server_id(self): + req = BotChatRequest( + platform=Platform.DISCORD, + platform_user_id="353922987235213313", + message="Hello in DMs!", + ) + assert req.platform_server_id is None + + def test_with_session_id(self): + req = BotChatRequest( + platform=Platform.DISCORD, + platform_server_id="guild_123", + platform_user_id="user_456", + message="follow up", + session_id="session-uuid-here", + ) + assert req.session_id == "session-uuid-here" + + def test_empty_message_rejected(self): + with pytest.raises(ValidationError): + BotChatRequest( + platform=Platform.DISCORD, + platform_server_id="guild_123", + platform_user_id="user_456", + message="", + ) + + def test_empty_string_server_id_rejected(self): + with pytest.raises(ValidationError): + BotChatRequest( + platform=Platform.DISCORD, + platform_server_id="", + platform_user_id="user_456", + message="hi", + ) + + +class TestResponseModels: + def test_link_token_status_pending(self): + resp = LinkTokenStatusResponse(status="pending") + assert resp.status == "pending" + + def test_link_token_status_linked(self): + resp = LinkTokenStatusResponse(status="linked") + assert resp.status == "linked" + + def test_link_token_status_expired(self): + resp = LinkTokenStatusResponse(status="expired") + assert resp.status == "expired" + + def test_resolve_linked(self): + resp = ResolveResponse(linked=True) + assert resp.linked is True + + def test_resolve_not_linked(self): + resp = ResolveResponse(linked=False) + assert resp.linked is False + + def test_confirm_link_response(self): + resp = ConfirmLinkResponse( + success=True, + platform="DISCORD", + platform_server_id="1126875755960336515", + server_name="My Server", + ) + assert resp.success is True + assert resp.server_name == "My Server" + + def test_delete_link_response(self): + resp = DeleteLinkResponse(success=True) + assert resp.success is True diff --git a/autogpt_platform/backend/backend/platform_linking_manager.py b/autogpt_platform/backend/backend/platform_linking_manager.py new file mode 100644 index 0000000000..1c36efd29c --- /dev/null +++ b/autogpt_platform/backend/backend/platform_linking_manager.py @@ -0,0 +1,15 @@ +from backend.app import run_processes +from backend.platform_linking.manager import PlatformLinkingManager + + +def main(): + """ + Run the AutoGPT-server Platform Linking Manager service. + """ + run_processes( + PlatformLinkingManager(), + ) + + +if __name__ == "__main__": + main() diff --git a/autogpt_platform/backend/backend/util/clients.py b/autogpt_platform/backend/backend/util/clients.py index 6d23313f02..391142b214 100644 --- a/autogpt_platform/backend/backend/util/clients.py +++ b/autogpt_platform/backend/backend/util/clients.py @@ -27,6 +27,7 @@ if TYPE_CHECKING: from backend.executor.scheduler import SchedulerClient from backend.integrations.credentials_store import IntegrationCredentialsStore from backend.notifications.notifications import NotificationManagerClient + from backend.platform_linking.manager import PlatformLinkingManagerClient @thread_cached @@ -67,6 +68,15 @@ def get_notification_manager_client() -> "NotificationManagerClient": return get_service_client(NotificationManagerClient) +@thread_cached +def get_platform_linking_manager_client() -> "PlatformLinkingManagerClient": + """Get a thread-cached PlatformLinkingManagerClient.""" + from backend.platform_linking.manager import PlatformLinkingManagerClient + from backend.util.service import get_service_client + + return get_service_client(PlatformLinkingManagerClient) + + # ============ Execution Event Bus Helpers ============ # diff --git a/autogpt_platform/backend/backend/util/exceptions.py b/autogpt_platform/backend/backend/util/exceptions.py index 04172465b9..69d3396789 100644 --- a/autogpt_platform/backend/backend/util/exceptions.py +++ b/autogpt_platform/backend/backend/util/exceptions.py @@ -155,3 +155,19 @@ class RedisError(Exception): """Raised when there is an error interacting with Redis""" pass + + +class LinkAlreadyExistsError(ValueError): + """A platform_linking target (server or user) is already linked.""" + + +class LinkTokenExpiredError(ValueError): + """A platform_linking token has expired or been consumed.""" + + +class LinkFlowMismatchError(ValueError): + """A platform_linking token was used for the wrong flow (server vs user).""" + + +class DuplicateChatMessageError(ValueError): + """The same user message is already in flight for this chat session.""" diff --git a/autogpt_platform/backend/backend/util/settings.py b/autogpt_platform/backend/backend/util/settings.py index 736219ea9b..5c831b2a34 100644 --- a/autogpt_platform/backend/backend/util/settings.py +++ b/autogpt_platform/backend/backend/util/settings.py @@ -252,6 +252,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings): description="The port for notification service daemon to run on", ) + platform_linking_service_port: int = Field( + default=8009, + description="The port for the platform_linking manager daemon to run on", + ) + otto_api_url: str = Field( default="", description="The URL for the Otto API service", @@ -269,6 +274,13 @@ class Config(UpdateTrackingModel["Config"], BaseSettings): "This value is then used to generate redirect URLs for OAuth flows.", ) + platform_link_base_url: str = Field( + default="https://platform.agpt.co/link", + description="Base URL the bot service prepends to one-time linking " + "tokens when it posts them to users ({base}/{token}?platform=...). " + "Should point at the frontend /link page.", + ) + media_gcs_bucket_name: str = Field( default="", description="The name of the Google Cloud Storage bucket for media files", diff --git a/autogpt_platform/backend/migrations/20260331120000_add_platform_bot_linking/migration.sql b/autogpt_platform/backend/migrations/20260331120000_add_platform_bot_linking/migration.sql new file mode 100644 index 0000000000..2704daeedf --- /dev/null +++ b/autogpt_platform/backend/migrations/20260331120000_add_platform_bot_linking/migration.sql @@ -0,0 +1,55 @@ +-- CreateEnum +CREATE TYPE "PlatformType" AS ENUM ('DISCORD', 'TELEGRAM', 'SLACK', 'TEAMS', 'WHATSAPP', 'GITHUB', 'LINEAR'); + +-- CreateTable +-- PlatformLink maps a platform server (Discord guild, Telegram group, etc.) to an AutoGPT +-- owner account. The first user to authenticate becomes the owner — all usage from that +-- server is billed to their account. Each user within the server gets their own CoPilot +-- session, all visible in the owner's AutoGPT account. +CREATE TABLE "PlatformLink" ( + "id" TEXT NOT NULL, + "userId" TEXT NOT NULL, + "platform" "PlatformType" NOT NULL, + "platformServerId" TEXT NOT NULL, + "ownerPlatformUserId" TEXT NOT NULL, + "serverName" TEXT, + "linkedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + + CONSTRAINT "PlatformLink_pkey" PRIMARY KEY ("id") +); + +-- CreateTable +-- PlatformLinkToken is a one-time token for the server linking flow. +CREATE TABLE "PlatformLinkToken" ( + "id" TEXT NOT NULL, + "token" TEXT NOT NULL, + "platform" "PlatformType" NOT NULL, + "platformServerId" TEXT NOT NULL, + "platformUserId" TEXT NOT NULL, + "platformUsername" TEXT, + "serverName" TEXT, + "channelId" TEXT, + "expiresAt" TIMESTAMP(3) NOT NULL, + "usedAt" TIMESTAMP(3), + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + + CONSTRAINT "PlatformLinkToken_pkey" PRIMARY KEY ("id") +); + +-- CreateIndex +CREATE UNIQUE INDEX "PlatformLink_platform_platformServerId_key" ON "PlatformLink"("platform", "platformServerId"); + +-- CreateIndex +CREATE INDEX "PlatformLink_userId_idx" ON "PlatformLink"("userId"); + +-- CreateIndex +CREATE UNIQUE INDEX "PlatformLinkToken_token_key" ON "PlatformLinkToken"("token"); + +-- CreateIndex +CREATE INDEX "PlatformLinkToken_platform_platformServerId_idx" ON "PlatformLinkToken"("platform", "platformServerId"); + +-- CreateIndex +CREATE INDEX "PlatformLinkToken_expiresAt_idx" ON "PlatformLinkToken"("expiresAt"); + +-- AddForeignKey +ALTER TABLE "PlatformLink" ADD CONSTRAINT "PlatformLink_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/autogpt_platform/backend/migrations/20260414160000_add_platform_user_links/migration.sql b/autogpt_platform/backend/migrations/20260414160000_add_platform_user_links/migration.sql new file mode 100644 index 0000000000..deb098a288 --- /dev/null +++ b/autogpt_platform/backend/migrations/20260414160000_add_platform_user_links/migration.sql @@ -0,0 +1,37 @@ +-- CreateEnum +-- Server links (group chats / guilds) and user links (personal DMs) are +-- fully independent — a user who owns a linked server still has to link +-- their DMs separately. +CREATE TYPE "PlatformLinkType" AS ENUM ('SERVER', 'USER'); + +-- CreateTable +-- PlatformUserLink maps an individual platform user identity to an AutoGPT +-- account for 1:1 DMs with the bot. Independent from PlatformLink. +CREATE TABLE "PlatformUserLink" ( + "id" TEXT NOT NULL, + "userId" TEXT NOT NULL, + "platform" "PlatformType" NOT NULL, + "platformUserId" TEXT NOT NULL, + "platformUsername" TEXT, + "linkedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + + CONSTRAINT "PlatformUserLink_pkey" PRIMARY KEY ("id") +); + +-- CreateIndex +CREATE UNIQUE INDEX "PlatformUserLink_platform_platformUserId_key" ON "PlatformUserLink"("platform", "platformUserId"); + +-- CreateIndex +CREATE INDEX "PlatformUserLink_userId_idx" ON "PlatformUserLink"("userId"); + +-- AddForeignKey +ALTER TABLE "PlatformUserLink" ADD CONSTRAINT "PlatformUserLink_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AlterTable: PlatformLinkToken now supports SERVER or USER tokens. +-- Existing rows are all SERVER (default matches the column default). +ALTER TABLE "PlatformLinkToken" + ADD COLUMN "linkType" "PlatformLinkType" NOT NULL DEFAULT 'SERVER', + ALTER COLUMN "platformServerId" DROP NOT NULL; + +-- CreateIndex +CREATE INDEX "PlatformLinkToken_platform_platformUserId_idx" ON "PlatformLinkToken"("platform", "platformUserId"); diff --git a/autogpt_platform/backend/migrations/20260417000000_add_shared_execution_file/migration.sql b/autogpt_platform/backend/migrations/20260417000000_add_shared_execution_file/migration.sql new file mode 100644 index 0000000000..ad8b67647e --- /dev/null +++ b/autogpt_platform/backend/migrations/20260417000000_add_shared_execution_file/migration.sql @@ -0,0 +1,25 @@ +-- CreateTable +CREATE TABLE "SharedExecutionFile" ( + "id" TEXT NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "executionId" TEXT NOT NULL, + "fileId" TEXT NOT NULL, + "shareToken" TEXT NOT NULL, + + CONSTRAINT "SharedExecutionFile_pkey" PRIMARY KEY ("id") +); + +-- CreateIndex +CREATE UNIQUE INDEX "SharedExecutionFile_shareToken_fileId_key" ON "SharedExecutionFile"("shareToken", "fileId"); + +-- CreateIndex +CREATE INDEX "SharedExecutionFile_shareToken_idx" ON "SharedExecutionFile"("shareToken"); + +-- CreateIndex +CREATE INDEX "SharedExecutionFile_executionId_idx" ON "SharedExecutionFile"("executionId"); + +-- AddForeignKey +ALTER TABLE "SharedExecutionFile" ADD CONSTRAINT "SharedExecutionFile_executionId_fkey" FOREIGN KEY ("executionId") REFERENCES "AgentGraphExecution"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "SharedExecutionFile" ADD CONSTRAINT "SharedExecutionFile_fileId_fkey" FOREIGN KEY ("fileId") REFERENCES "UserWorkspaceFile"("id") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/autogpt_platform/backend/pyproject.toml b/autogpt_platform/backend/pyproject.toml index 6e7003a65d..eb9c26c5dd 100644 --- a/autogpt_platform/backend/pyproject.toml +++ b/autogpt_platform/backend/pyproject.toml @@ -129,6 +129,7 @@ db = "backend.db:main" ws = "backend.ws:main" scheduler = "backend.scheduler:main" notification = "backend.notification:main" +platform-linking-manager = "backend.platform_linking_manager:main" executor = "backend.exec:main" analytics-setup = "scripts.generate_views:main_setup" analytics-views = "scripts.generate_views:main_views" diff --git a/autogpt_platform/backend/schema.prisma b/autogpt_platform/backend/schema.prisma index e224be7d5f..b6ddc7cad0 100644 --- a/autogpt_platform/backend/schema.prisma +++ b/autogpt_platform/backend/schema.prisma @@ -81,6 +81,10 @@ model User { OAuthAuthorizationCodes OAuthAuthorizationCode[] OAuthAccessTokens OAuthAccessToken[] OAuthRefreshTokens OAuthRefreshToken[] + + // Platform bot linking + PlatformLinks PlatformLink[] + PlatformUserLinks PlatformUserLink[] } enum SubscriptionTier { @@ -200,10 +204,32 @@ model UserWorkspaceFile { metadata Json @default("{}") + SharedExecutionFiles SharedExecutionFile[] + @@unique([workspaceId, path]) @@index([workspaceId, isDeleted]) } +// Tracks which workspace files are exposed via a shared execution. +// Created when sharing is enabled, deleted when sharing is disabled. +// The public file download endpoint validates against this table. +model SharedExecutionFile { + id String @id @default(uuid()) + createdAt DateTime @default(now()) + + executionId String + Execution AgentGraphExecution @relation(fields: [executionId], references: [id], onDelete: Cascade) + + fileId String + File UserWorkspaceFile @relation(fields: [fileId], references: [id], onDelete: Cascade) + + shareToken String + + @@unique([shareToken, fileId]) + @@index([shareToken]) + @@index([executionId]) +} + model BuilderSearchHistory { id String @id @default(uuid()) createdAt DateTime @default(now()) @@ -585,9 +611,10 @@ model AgentGraphExecution { ChildExecutions AgentGraphExecution[] @relation("ParentChildExecution") // Sharing fields - isShared Boolean @default(false) - shareToken String? @unique - sharedAt DateTime? + isShared Boolean @default(false) + shareToken String? @unique + sharedAt DateTime? + SharedExecutionFiles SharedExecutionFile[] @@index([agentGraphId, agentGraphVersion]) @@index([userId, isDeleted, createdAt]) @@ -1366,3 +1393,84 @@ model OAuthRefreshToken { @@index([userId, applicationId]) @@index([expiresAt]) // For cleanup } + +// ── Platform Bot Linking ────────────────────────────────────────────── +// Links external chat platform identities (Discord, Telegram, Slack, etc.) +// to AutoGPT user accounts, enabling the multi-platform CoPilot bot. + +enum PlatformType { + DISCORD + TELEGRAM + SLACK + TEAMS + WHATSAPP + GITHUB + LINEAR +} + +// Whether a linking token claims a server (group chat / guild) or a personal +// 1:1 user link (DMs). Server and user links are completely independent — +// linking a server does not grant DM access and vice versa. +enum PlatformLinkType { + SERVER + USER +} + +// Maps a platform server (Discord guild, Telegram group, Slack workspace, etc.) +// to an AutoGPT user account. The user who first authenticates becomes the +// "owner" — all usage from that server is attributed to their account. +model PlatformLink { + id String @id @default(uuid()) + userId String // AutoGPT user ID of the owner + User User @relation(fields: [userId], references: [id], onDelete: Cascade) + platform PlatformType + platformServerId String // Server/guild/group ID on that platform + ownerPlatformUserId String // Platform user ID of the person who set it up + serverName String? // Display name of the server (best-effort, may go stale) + linkedAt DateTime @default(now()) + + @@unique([platform, platformServerId]) + @@index([userId]) +} + +// Maps a platform user identity (a single Discord / Telegram / Slack user) to +// an AutoGPT account for 1:1 DM conversations with the bot. Independent from +// PlatformLink — a user who owns a linked server must still link their DMs +// separately. +model PlatformUserLink { + id String @id @default(uuid()) + userId String // AutoGPT user ID + User User @relation(fields: [userId], references: [id], onDelete: Cascade) + platform PlatformType + platformUserId String // Individual's user ID on the platform + platformUsername String? // Display name at link time (best-effort) + linkedAt DateTime @default(now()) + + @@unique([platform, platformUserId]) + @@index([userId]) +} + +// One-time tokens for either the server linking flow or the DM (user) linking +// flow. linkType determines which target is populated — SERVER tokens carry +// platformServerId + serverName + ownerPlatformUserId, USER tokens carry +// platformUserId only. +model PlatformLinkToken { + id String @id @default(uuid()) + token String @unique + platform PlatformType + linkType PlatformLinkType @default(SERVER) + // SERVER token fields (null for USER tokens) + platformServerId String? // Server/guild/group ID being linked + serverName String? // Server display name + channelId String? // Channel to send confirmation back to + // Always set — platform user ID of the person who will claim ownership + platformUserId String + platformUsername String? // Their display name + expiresAt DateTime + usedAt DateTime? + createdAt DateTime @default(now()) + + @@index([platform, platformServerId]) + @@index([platform, platformUserId]) + @@index([expiresAt]) +} diff --git a/autogpt_platform/frontend/package.json b/autogpt_platform/frontend/package.json index 292e64e8dd..68815bcf79 100644 --- a/autogpt_platform/frontend/package.json +++ b/autogpt_platform/frontend/package.json @@ -92,11 +92,12 @@ "geist": "1.5.1", "highlight.js": "11.11.1", "jaro-winkler": "0.2.8", + "jszip": "3.10.1", "katex": "0.16.25", "launchdarkly-react-client-sdk": "3.9.0", "lodash": "4.17.21", "lucide-react": "0.552.0", - "next": "15.4.10", + "next": "15.4.11", "next-themes": "0.4.6", "nuqs": "2.7.2", "posthog-js": "1.334.1", diff --git a/autogpt_platform/frontend/pnpm-lock.yaml b/autogpt_platform/frontend/pnpm-lock.yaml index ad6429ac52..82a76350af 100644 --- a/autogpt_platform/frontend/pnpm-lock.yaml +++ b/autogpt_platform/frontend/pnpm-lock.yaml @@ -26,7 +26,7 @@ importers: version: 5.2.2(react-hook-form@7.66.0(react@18.3.1)) '@next/third-parties': specifier: 15.4.6 - version: 15.4.6(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1) + version: 15.4.6(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1) '@phosphor-icons/react': specifier: 2.1.10 version: 2.1.10(react-dom@18.3.1(react@18.3.1))(react@18.3.1) @@ -107,7 +107,7 @@ importers: version: 6.1.2(@rjsf/utils@6.1.2(react@18.3.1)) '@sentry/nextjs': specifier: 10.27.0 - version: 10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.104.1(esbuild@0.25.12)) + version: 10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.104.1(esbuild@0.25.12)) '@streamdown/cjk': specifier: 1.0.1 version: 1.0.1(@types/mdast@4.0.4)(micromark-util-types@2.0.2)(micromark@4.0.2)(react@18.3.1)(unified@11.0.5) @@ -134,10 +134,10 @@ importers: version: 0.2.4 '@vercel/analytics': specifier: 1.5.0 - version: 1.5.0(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1) + version: 1.5.0(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1) '@vercel/speed-insights': specifier: 1.2.0 - version: 1.2.0(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1) + version: 1.2.0(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1) '@xyflow/react': specifier: 12.9.2 version: 12.9.2(@types/react@18.3.17)(immer@11.1.3)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) @@ -185,13 +185,16 @@ importers: version: 12.23.24(@emotion/is-prop-valid@1.2.2)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) geist: specifier: 1.5.1 - version: 1.5.1(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)) + version: 1.5.1(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)) highlight.js: specifier: 11.11.1 version: 11.11.1 jaro-winkler: specifier: 0.2.8 version: 0.2.8 + jszip: + specifier: 3.10.1 + version: 3.10.1 katex: specifier: 0.16.25 version: 0.16.25 @@ -205,14 +208,14 @@ importers: specifier: 0.552.0 version: 0.552.0(react@18.3.1) next: - specifier: 15.4.10 - version: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + specifier: 15.4.11 + version: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) next-themes: specifier: 0.4.6 version: 0.4.6(react-dom@18.3.1(react@18.3.1))(react@18.3.1) nuqs: specifier: 2.7.2 - version: 2.7.2(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1) + version: 2.7.2(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1) posthog-js: specifier: 1.334.1 version: 1.334.1 @@ -330,7 +333,7 @@ importers: version: 9.1.5(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2)(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2))) '@storybook/nextjs': specifier: 9.1.5 - version: 9.1.5(esbuild@0.25.12)(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2)(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2)))(type-fest@4.41.0)(typescript@5.9.3)(webpack-hot-middleware@2.26.1)(webpack@5.104.1(esbuild@0.25.12)) + version: 9.1.5(esbuild@0.25.12)(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2)(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2)))(type-fest@4.41.0)(typescript@5.9.3)(webpack-hot-middleware@2.26.1)(webpack@5.104.1(esbuild@0.25.12)) '@tanstack/eslint-plugin-query': specifier: 5.91.2 version: 5.91.2(eslint@8.57.1)(typescript@5.9.3) @@ -1844,8 +1847,8 @@ packages: '@neoconfetti/react@1.0.0': resolution: {integrity: sha512-klcSooChXXOzIm+SE5IISIAn3bYzYfPjbX7D7HoqZL84oAfgREeSg5vSIaSFH+DaGzzvImTyWe1OyrJ67vik4A==} - '@next/env@15.4.10': - resolution: {integrity: sha512-knhmoJ0Vv7VRf6pZEPSnciUG1S4bIhWx+qTYBW/AjxEtlzsiNORPk8sFDCEvqLfmKuey56UB9FL1UdHEV3uBrg==} + '@next/env@15.4.11': + resolution: {integrity: sha512-mIYp/091eYfPFezKX7ZPTWqrmSXq+ih6+LcUyKvLmeLQGhlPtot33kuEOd4U+xAA7sFfj21+OtCpIZx0g5SpvQ==} '@next/eslint-plugin-next@15.5.7': resolution: {integrity: sha512-DtRU2N7BkGr8r+pExfuWHwMEPX5SD57FeA6pxdgCHODo+b/UgIgjE+rgWKtJAbEbGhVZ2jtHn4g3wNhWFoNBQQ==} @@ -5919,6 +5922,9 @@ packages: engines: {node: '>=16.x'} hasBin: true + immediate@3.0.6: + resolution: {integrity: sha512-XXOFtyqDjNDAQxVfYxuF7g9Il/IbWmmlQg2MYKOH8ExIT1qg6xc4zyS3HaEEATgs1btfzxq15ciUiY7gjSXRGQ==} + immer@10.2.0: resolution: {integrity: sha512-d/+XTN3zfODyjr89gM3mPq1WNX2B8pYsu7eORitdwyA2sBubnTl3laYlBk4sXY5FUa5qTZGBDPJICVbvqzjlbw==} @@ -6279,6 +6285,9 @@ packages: resolution: {integrity: sha512-ZZow9HBI5O6EPgSJLUb8n2NKgmVWTwCvHGwFuJlMjvLFqlGG6pjirPhtdsseaLZjSibD8eegzmYpUZwoIlj2cQ==} engines: {node: '>=4.0'} + jszip@3.10.1: + resolution: {integrity: sha512-xXDvecyTpGLrqFrvkrUSoxxfJI5AH7U8zxxtVclpsUtMCq4JQ290LY8AW5c7Ggnr/Y/oK+bQMbqK2qmtk3pN4g==} + junit-report-builder@5.1.1: resolution: {integrity: sha512-ZNOIIGMzqCGcHQEA2Q4rIQQ3Df6gSIfne+X9Rly9Bc2y55KxAZu8iGv+n2pP0bLf0XAOctJZgeloC54hWzCahQ==} engines: {node: '>=16'} @@ -6348,6 +6357,9 @@ packages: resolution: {integrity: sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==} engines: {node: '>= 0.8.0'} + lie@3.3.0: + resolution: {integrity: sha512-UaiMJzeWRlEujzAuw5LokY1L5ecNQYZKfmyZ9L7wDHb/p5etKaxXhohBcrw0EYby+G/NA52vRSN4N39dxHAIwQ==} + lilconfig@3.1.3: resolution: {integrity: sha512-/vlFKAoH5Cgt3Ie+JLhRbwOsCQePABiU3tJ1egGvyQ+33R/vcwM2Zl2QR/LzjsBeItPt3oSVXapn+m4nQDvpzw==} engines: {node: '>=14'} @@ -6839,8 +6851,8 @@ packages: react: ^16.8 || ^17 || ^18 || ^19 || ^19.0.0-rc react-dom: ^16.8 || ^17 || ^18 || ^19 || ^19.0.0-rc - next@15.4.10: - resolution: {integrity: sha512-itVlc79QjpKMFMRhP+kbGKaSG/gZM6RCvwhEbwmCNF06CdDiNaoHcbeg0PqkEa2GOcn8KJ0nnc7+yL7EjoYLHQ==} + next@15.4.11: + resolution: {integrity: sha512-IJRyXal45mIsshZI5XJne/intjusslUP1F+FHVBIyMGEqbYtIq1Irdx5vdWBBg58smviPDycmDeV6txsfkv1RQ==} engines: {node: ^18.18.0 || ^19.8.0 || >= 20.0.0} hasBin: true peerDependencies: @@ -10423,7 +10435,7 @@ snapshots: '@neoconfetti/react@1.0.0': {} - '@next/env@15.4.10': {} + '@next/env@15.4.11': {} '@next/eslint-plugin-next@15.5.7': dependencies: @@ -10453,9 +10465,9 @@ snapshots: '@next/swc-win32-x64-msvc@15.4.8': optional: true - '@next/third-parties@15.4.6(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)': + '@next/third-parties@15.4.6(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)': dependencies: - next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) react: 18.3.1 third-party-capital: 1.0.20 @@ -11770,7 +11782,7 @@ snapshots: '@sentry/core@10.27.0': {} - '@sentry/nextjs@10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.104.1(esbuild@0.25.12))': + '@sentry/nextjs@10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.104.1(esbuild@0.25.12))': dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/semantic-conventions': 1.38.0 @@ -11783,7 +11795,7 @@ snapshots: '@sentry/react': 10.27.0(react@18.3.1) '@sentry/vercel-edge': 10.27.0 '@sentry/webpack-plugin': 4.6.1(webpack@5.104.1(esbuild@0.25.12)) - next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) resolve: 1.22.8 rollup: 4.55.1 stacktrace-parser: 0.1.11 @@ -12162,7 +12174,7 @@ snapshots: react: 18.3.1 react-dom: 18.3.1(react@18.3.1) - '@storybook/nextjs@9.1.5(esbuild@0.25.12)(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2)(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2)))(type-fest@4.41.0)(typescript@5.9.3)(webpack-hot-middleware@2.26.1)(webpack@5.104.1(esbuild@0.25.12))': + '@storybook/nextjs@9.1.5(esbuild@0.25.12)(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2)(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2)))(type-fest@4.41.0)(typescript@5.9.3)(webpack-hot-middleware@2.26.1)(webpack@5.104.1(esbuild@0.25.12))': dependencies: '@babel/core': 7.28.5 '@babel/plugin-syntax-bigint': 7.8.3(@babel/core@7.28.5) @@ -12186,7 +12198,7 @@ snapshots: css-loader: 6.11.0(webpack@5.104.1(esbuild@0.25.12)) image-size: 2.0.2 loader-utils: 3.3.1 - next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) node-polyfill-webpack-plugin: 2.0.1(webpack@5.104.1(esbuild@0.25.12)) postcss: 8.5.6 postcss-loader: 8.2.0(postcss@8.5.6)(typescript@5.9.3)(webpack@5.104.1(esbuild@0.25.12)) @@ -12872,16 +12884,16 @@ snapshots: '@unrs/resolver-binding-win32-x64-msvc@1.11.1': optional: true - '@vercel/analytics@1.5.0(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)': + '@vercel/analytics@1.5.0(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)': optionalDependencies: - next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) react: 18.3.1 '@vercel/oidc@3.1.0': {} - '@vercel/speed-insights@1.2.0(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)': + '@vercel/speed-insights@1.2.0(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)': optionalDependencies: - next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) react: 18.3.1 '@vitejs/plugin-react@5.1.2(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2))': @@ -14449,8 +14461,8 @@ snapshots: '@typescript-eslint/parser': 8.52.0(eslint@8.57.1)(typescript@5.9.3) eslint: 8.57.1 eslint-import-resolver-node: 0.3.9 - eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1) - eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1) + eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1) + eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1) eslint-plugin-jsx-a11y: 6.10.2(eslint@8.57.1) eslint-plugin-react: 7.37.5(eslint@8.57.1) eslint-plugin-react-hooks: 5.2.0(eslint@8.57.1) @@ -14469,7 +14481,7 @@ snapshots: transitivePeerDependencies: - supports-color - eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1): + eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1): dependencies: '@nolyfill/is-core-module': 1.0.39 debug: 4.4.3 @@ -14480,22 +14492,22 @@ snapshots: tinyglobby: 0.2.15 unrs-resolver: 1.11.1 optionalDependencies: - eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1) + eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1) transitivePeerDependencies: - supports-color - eslint-module-utils@2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1): + eslint-module-utils@2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1): dependencies: debug: 3.2.7 optionalDependencies: '@typescript-eslint/parser': 8.52.0(eslint@8.57.1)(typescript@5.9.3) eslint: 8.57.1 eslint-import-resolver-node: 0.3.9 - eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1) + eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1) transitivePeerDependencies: - supports-color - eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1): + eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1): dependencies: '@rtsao/scc': 1.1.0 array-includes: 3.1.9 @@ -14506,7 +14518,7 @@ snapshots: doctrine: 2.1.0 eslint: 8.57.1 eslint-import-resolver-node: 0.3.9 - eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1) + eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1) hasown: 2.0.2 is-core-module: 2.16.1 is-glob: 4.0.3 @@ -14877,9 +14889,9 @@ snapshots: functions-have-names@1.2.3: {} - geist@1.5.1(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)): + geist@1.5.1(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)): dependencies: - next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) generator-function@2.0.1: {} @@ -15290,6 +15302,8 @@ snapshots: image-size@2.0.2: {} + immediate@3.0.6: {} + immer@10.2.0: {} immer@11.1.3: {} @@ -15646,6 +15660,13 @@ snapshots: object.assign: 4.1.7 object.values: 1.2.1 + jszip@3.10.1: + dependencies: + lie: 3.3.0 + pako: 1.0.11 + readable-stream: 2.3.8 + setimmediate: 1.0.5 + junit-report-builder@5.1.1: dependencies: lodash: 4.17.21 @@ -15739,6 +15760,10 @@ snapshots: prelude-ls: 1.2.1 type-check: 0.4.0 + lie@3.3.0: + dependencies: + immediate: 3.0.6 + lilconfig@3.1.3: {} lines-and-columns@1.2.4: {} @@ -16465,9 +16490,9 @@ snapshots: react: 18.3.1 react-dom: 18.3.1(react@18.3.1) - next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1): dependencies: - '@next/env': 15.4.10 + '@next/env': 15.4.11 '@swc/helpers': 0.5.15 caniuse-lite: 1.0.30001762 postcss: 8.4.31 @@ -16569,12 +16594,12 @@ snapshots: dependencies: boolbase: 1.0.0 - nuqs@2.7.2(next@15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1): + nuqs@2.7.2(next@15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1): dependencies: '@standard-schema/spec': 1.0.0 react: 18.3.1 optionalDependencies: - next: 15.4.10(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + next: 15.4.11(@babel/core@7.28.5)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) oas-kit-common@1.0.8: dependencies: diff --git a/autogpt_platform/frontend/src/app/(no-navbar)/share/[token]/page.tsx b/autogpt_platform/frontend/src/app/(no-navbar)/share/[token]/page.tsx index 1c37c6c72f..3db91f411b 100644 --- a/autogpt_platform/frontend/src/app/(no-navbar)/share/[token]/page.tsx +++ b/autogpt_platform/frontend/src/app/(no-navbar)/share/[token]/page.tsx @@ -119,7 +119,7 @@ export default function SharePage() { Output - + diff --git a/autogpt_platform/frontend/src/app/(no-navbar)/share/layout.tsx b/autogpt_platform/frontend/src/app/(no-navbar)/share/layout.tsx index 3b79d323c0..a0d4654ff8 100644 --- a/autogpt_platform/frontend/src/app/(no-navbar)/share/layout.tsx +++ b/autogpt_platform/frontend/src/app/(no-navbar)/share/layout.tsx @@ -1,4 +1,6 @@ import type { Metadata } from "next"; +import Image from "next/image"; +import Link from "next/link"; export const metadata: Metadata = { title: "Shared Agent Run - AutoGPT", @@ -13,6 +15,27 @@ export default function ShareLayout({ }) { return (
+
+
+ + AutoGPT + AutoGPT + +
+
{children}
); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ArtifactPanel/components/ArtifactContent.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ArtifactPanel/components/ArtifactContent.tsx index 506cbc3b60..7a65188b86 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ArtifactPanel/components/ArtifactContent.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ArtifactPanel/components/ArtifactContent.tsx @@ -6,9 +6,11 @@ import { Suspense, useState } from "react"; import { Skeleton } from "@/components/ui/skeleton"; import type { ArtifactRef } from "../../../store"; import type { ArtifactClassification } from "../helpers"; +import { ArtifactErrorBoundary } from "./ArtifactErrorBoundary"; import { ArtifactReactPreview } from "./ArtifactReactPreview"; import { ArtifactSkeleton } from "./ArtifactSkeleton"; import { + FRAGMENT_LINK_INTERCEPTOR_SCRIPT, TAILWIND_CDN_URL, wrapWithHeadInjection, } from "@/lib/iframe-sandbox-csp"; @@ -53,20 +55,35 @@ function ArtifactContentLoader({ return (
- + + +
); } +function withCacheBust(src: string, nonce: number): string { + if (nonce === 0) return src; + const sep = src.includes("?") ? "&" : "?"; + return `${src}${sep}_retry=${nonce}`; +} + function ArtifactImage({ src, alt }: { src: string; alt: string }) { const [loaded, setLoaded] = useState(false); const [error, setError] = useState(false); + // Incremented on every Try Again so the URL changes and the browser + // can't reuse a negative-cached response (SECRT-2221). + const [retryNonce, setRetryNonce] = useState(0); if (error) { return ( @@ -80,6 +97,7 @@ function ArtifactImage({ src, alt }: { src: string; alt: string }) { onClick={() => { setError(false); setLoaded(false); + setRetryNonce((n) => n + 1); }} className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400" > @@ -96,7 +114,7 @@ function ArtifactImage({ src, alt }: { src: string; alt: string }) { )} {/* eslint-disable-next-line @next/next/no-img-element */} {alt} setLoaded(true)} @@ -109,6 +127,7 @@ function ArtifactImage({ src, alt }: { src: string; alt: string }) { function ArtifactVideo({ src }: { src: string }) { const [loaded, setLoaded] = useState(false); const [error, setError] = useState(false); + const [retryNonce, setRetryNonce] = useState(0); if (error) { return ( @@ -122,6 +141,7 @@ function ArtifactVideo({ src }: { src: string }) { onClick={() => { setError(false); setLoaded(false); + setRetryNonce((n) => n + 1); }} className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400" > @@ -137,7 +157,7 @@ function ArtifactVideo({ src }: { src: string }) { )}