mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
12 Commits
feat/platf
...
feat/agent
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3db2a944f7 | ||
|
|
59192102a6 | ||
|
|
65cca9bef8 | ||
|
|
6b32e43d84 | ||
|
|
b73d05c23e | ||
|
|
8277cce835 | ||
|
|
57b17dc8e1 | ||
|
|
a20188ae59 | ||
|
|
c410be890e | ||
|
|
37d9863552 | ||
|
|
2f42ff9b47 | ||
|
|
914efc53e5 |
@@ -178,6 +178,7 @@ SMTP_USERNAME=
|
||||
SMTP_PASSWORD=
|
||||
|
||||
# Business & Marketing Tools
|
||||
AGENTMAIL_API_KEY=
|
||||
APOLLO_API_KEY=
|
||||
ENRICHLAYER_API_KEY=
|
||||
AYRSHARE_API_KEY=
|
||||
|
||||
@@ -31,7 +31,10 @@ from backend.data.model import (
|
||||
UserPasswordCredentials,
|
||||
is_sdk_default,
|
||||
)
|
||||
from backend.integrations.credentials_store import provider_matches
|
||||
from backend.integrations.credentials_store import (
|
||||
is_system_credential,
|
||||
provider_matches,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -618,6 +621,11 @@ async def delete_credential(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if is_system_credential(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="System-managed credentials cannot be deleted",
|
||||
)
|
||||
creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -40,11 +40,15 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
||||
from backend.data.user import get_user_integrations
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||
from backend.integrations.credentials_store import provider_matches
|
||||
from backend.integrations.credentials_store import (
|
||||
is_system_credential,
|
||||
provider_matches,
|
||||
)
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
create_mcp_oauth_handler,
|
||||
)
|
||||
from backend.integrations.managed_credentials import ensure_managed_credentials
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
@@ -110,6 +114,7 @@ class CredentialsMetaResponse(BaseModel):
|
||||
default=None,
|
||||
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
|
||||
)
|
||||
is_managed: bool = False
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -148,6 +153,7 @@ def to_meta_response(cred: Credentials) -> CredentialsMetaResponse:
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
is_managed=cred.is_managed,
|
||||
)
|
||||
|
||||
|
||||
@@ -224,6 +230,9 @@ async def callback(
|
||||
async def list_credentials(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
# Fire-and-forget: provision missing managed credentials in the background.
|
||||
# The credential appears on the next page load; listing is never blocked.
|
||||
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
|
||||
credentials = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
return [
|
||||
@@ -238,6 +247,7 @@ async def list_credentials_by_provider(
|
||||
],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
|
||||
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
|
||||
|
||||
return [
|
||||
@@ -332,6 +342,11 @@ async def delete_credentials(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if is_system_credential(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="System-managed credentials cannot be deleted",
|
||||
)
|
||||
creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(
|
||||
@@ -342,6 +357,11 @@ async def delete_credentials(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials not found",
|
||||
)
|
||||
if creds.is_managed:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="AutoGPT-managed credentials cannot be deleted",
|
||||
)
|
||||
|
||||
try:
|
||||
await remove_all_webhooks_for_credentials(user_id, creds, force)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Tests for credentials API security: no secret leakage, SDK defaults filtered."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
@@ -276,3 +277,294 @@ class TestCreateCredentialNoSecretInResponse:
|
||||
|
||||
assert resp.status_code == 403
|
||||
mock_mgr.create.assert_not_called()
|
||||
|
||||
|
||||
class TestManagedCredentials:
|
||||
"""AutoGPT-managed credentials cannot be deleted by users."""
|
||||
|
||||
def test_delete_is_managed_returns_403(self):
|
||||
cred = APIKeyCredentials(
|
||||
id="managed-cred-1",
|
||||
provider="agent_mail",
|
||||
title="AgentMail (managed by AutoGPT)",
|
||||
api_key=SecretStr("sk-managed-key"),
|
||||
is_managed=True,
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_creds_by_id = AsyncMock(return_value=cred)
|
||||
resp = client.request("DELETE", "/agent_mail/credentials/managed-cred-1")
|
||||
|
||||
assert resp.status_code == 403
|
||||
assert "AutoGPT-managed" in resp.json()["detail"]
|
||||
|
||||
def test_list_credentials_includes_is_managed_field(self):
|
||||
managed = APIKeyCredentials(
|
||||
id="managed-1",
|
||||
provider="agent_mail",
|
||||
title="AgentMail (managed)",
|
||||
api_key=SecretStr("sk-key"),
|
||||
is_managed=True,
|
||||
)
|
||||
regular = APIKeyCredentials(
|
||||
id="regular-1",
|
||||
provider="openai",
|
||||
title="My Key",
|
||||
api_key=SecretStr("sk-key"),
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_all_creds = AsyncMock(return_value=[managed, regular])
|
||||
resp = client.get("/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
managed_cred = next(c for c in data if c["id"] == "managed-1")
|
||||
regular_cred = next(c for c in data if c["id"] == "regular-1")
|
||||
assert managed_cred["is_managed"] is True
|
||||
assert regular_cred["is_managed"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Managed credential provisioning infrastructure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_managed_cred(
|
||||
provider: str = "agent_mail", pod_id: str = "pod-abc"
|
||||
) -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id="managed-auto",
|
||||
provider=provider,
|
||||
title="AgentMail (managed by AutoGPT)",
|
||||
api_key=SecretStr("sk-pod-key"),
|
||||
is_managed=True,
|
||||
metadata={"pod_id": pod_id},
|
||||
)
|
||||
|
||||
|
||||
def _make_store_mock(**kwargs) -> MagicMock:
|
||||
"""Create a store mock with a working async ``locks()`` context manager."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def _noop_locked(key):
|
||||
yield
|
||||
|
||||
locks_obj = MagicMock()
|
||||
locks_obj.locked = _noop_locked
|
||||
|
||||
store = MagicMock(**kwargs)
|
||||
store.locks = AsyncMock(return_value=locks_obj)
|
||||
return store
|
||||
|
||||
|
||||
class TestEnsureManagedCredentials:
|
||||
"""Unit tests for the ensure/cleanup helpers in managed_credentials.py."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provisions_when_missing(self):
|
||||
"""Provider.provision() is called when no managed credential exists."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
cred = _make_managed_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=True)
|
||||
provider.provision = AsyncMock(return_value=cred)
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock(return_value=False)
|
||||
store.add_managed_credential = AsyncMock()
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
provider.provision.assert_awaited_once_with("user-1")
|
||||
store.add_managed_credential.assert_awaited_once_with("user-1", cred)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_already_exists(self):
|
||||
"""Provider.provision() is NOT called when managed credential exists."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=True)
|
||||
provider.provision = AsyncMock()
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock(return_value=True)
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
provider.provision.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_unavailable(self):
|
||||
"""Provider.provision() is NOT called when provider is not available."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=False)
|
||||
provider.provision = AsyncMock()
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock()
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
provider.provision.assert_not_awaited()
|
||||
store.has_managed_credential.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provision_failure_does_not_propagate(self):
|
||||
"""A failed provision is logged but does not raise."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=True)
|
||||
provider.provision = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock(return_value=False)
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
# No exception raised — provisioning failure is swallowed.
|
||||
|
||||
|
||||
class TestCleanupManagedCredentials:
|
||||
"""Unit tests for cleanup_managed_credentials."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_deprovision_for_managed_creds(self):
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
cleanup_managed_credentials,
|
||||
)
|
||||
|
||||
cred = _make_managed_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "agent_mail"
|
||||
provider.deprovision = AsyncMock()
|
||||
|
||||
store = MagicMock()
|
||||
store.get_all_creds = AsyncMock(return_value=[cred])
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["agent_mail"] = provider
|
||||
try:
|
||||
await cleanup_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
|
||||
provider.deprovision.assert_awaited_once_with("user-1", cred)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_non_managed_creds(self):
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
cleanup_managed_credentials,
|
||||
)
|
||||
|
||||
regular = _make_api_key_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "openai"
|
||||
provider.deprovision = AsyncMock()
|
||||
|
||||
store = MagicMock()
|
||||
store.get_all_creds = AsyncMock(return_value=[regular])
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["openai"] = provider
|
||||
try:
|
||||
await cleanup_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
|
||||
provider.deprovision.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deprovision_failure_does_not_propagate(self):
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
cleanup_managed_credentials,
|
||||
)
|
||||
|
||||
cred = _make_managed_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "agent_mail"
|
||||
provider.deprovision = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
store = MagicMock()
|
||||
store.get_all_creds = AsyncMock(return_value=[cred])
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["agent_mail"] = provider
|
||||
try:
|
||||
await cleanup_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
|
||||
# No exception raised — cleanup failure is swallowed.
|
||||
|
||||
@@ -12,6 +12,7 @@ Tests cover:
|
||||
5. Complete OAuth flow end-to-end
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
@@ -58,14 +59,27 @@ async def test_user(server, test_user_id: str):
|
||||
|
||||
yield test_user_id
|
||||
|
||||
# Cleanup - delete in correct order due to foreign key constraints
|
||||
await PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id})
|
||||
await PrismaOAuthRefreshToken.prisma().delete_many(where={"userId": test_user_id})
|
||||
await PrismaOAuthAuthorizationCode.prisma().delete_many(
|
||||
where={"userId": test_user_id}
|
||||
)
|
||||
await PrismaOAuthApplication.prisma().delete_many(where={"ownerId": test_user_id})
|
||||
await PrismaUser.prisma().delete(where={"id": test_user_id})
|
||||
# Cleanup - delete in correct order due to foreign key constraints.
|
||||
# Wrap in try/except because the event loop or Prisma engine may already
|
||||
# be closed during session teardown on Python 3.12+.
|
||||
try:
|
||||
await asyncio.gather(
|
||||
PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id}),
|
||||
PrismaOAuthRefreshToken.prisma().delete_many(
|
||||
where={"userId": test_user_id}
|
||||
),
|
||||
PrismaOAuthAuthorizationCode.prisma().delete_many(
|
||||
where={"userId": test_user_id}
|
||||
),
|
||||
)
|
||||
await asyncio.gather(
|
||||
PrismaOAuthApplication.prisma().delete_many(
|
||||
where={"ownerId": test_user_id}
|
||||
),
|
||||
PrismaUser.prisma().delete(where={"id": test_user_id}),
|
||||
)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
# Platform bot linking API
|
||||
@@ -1,35 +0,0 @@
|
||||
"""Bot API key authentication for platform linking endpoints."""
|
||||
|
||||
import hmac
|
||||
import os
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
|
||||
async def get_bot_api_key(request: Request) -> str | None:
|
||||
"""Extract the bot API key from the X-Bot-API-Key header."""
|
||||
return request.headers.get("x-bot-api-key")
|
||||
|
||||
|
||||
def check_bot_api_key(api_key: str | None) -> None:
|
||||
"""Validate the bot API key. Uses constant-time comparison.
|
||||
|
||||
Reads the key from env on each call so rotated secrets take effect
|
||||
without restarting the process.
|
||||
"""
|
||||
configured_key = os.getenv("PLATFORM_BOT_API_KEY", "")
|
||||
|
||||
if not configured_key:
|
||||
settings = Settings()
|
||||
if settings.config.enable_auth:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Bot API key not configured.",
|
||||
)
|
||||
# Auth disabled (local dev) — allow without key
|
||||
return
|
||||
|
||||
if not api_key or not hmac.compare_digest(api_key, configured_key):
|
||||
raise HTTPException(status_code=401, detail="Invalid bot API key.")
|
||||
@@ -1,170 +0,0 @@
|
||||
"""
|
||||
Bot Chat Proxy endpoints.
|
||||
|
||||
Allows the bot service to send messages to CoPilot on behalf of
|
||||
linked users, authenticated via bot API key.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from prisma.models import PlatformLink
|
||||
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.executor.utils import enqueue_copilot_turn
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
append_and_save_message,
|
||||
create_chat_session,
|
||||
get_chat_session,
|
||||
)
|
||||
from backend.copilot.response_model import StreamFinish
|
||||
|
||||
from .auth import check_bot_api_key, get_bot_api_key
|
||||
from .models import BotChatRequest, BotChatSessionResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/chat/session",
|
||||
response_model=BotChatSessionResponse,
|
||||
summary="Create a CoPilot session for a linked user (bot-facing)",
|
||||
)
|
||||
async def bot_create_session(
|
||||
request: BotChatRequest,
|
||||
x_bot_api_key: str | None = Depends(get_bot_api_key),
|
||||
) -> BotChatSessionResponse:
|
||||
"""Creates a new CoPilot chat session on behalf of a linked user."""
|
||||
check_bot_api_key(x_bot_api_key)
|
||||
|
||||
link = await PlatformLink.prisma().find_first(where={"userId": request.user_id})
|
||||
if not link:
|
||||
raise HTTPException(status_code=404, detail="User has no platform links.")
|
||||
|
||||
session = await create_chat_session(request.user_id)
|
||||
|
||||
return BotChatSessionResponse(session_id=session.session_id)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/chat/stream",
|
||||
summary="Stream a CoPilot response for a linked user (bot-facing)",
|
||||
)
|
||||
async def bot_chat_stream(
|
||||
request: BotChatRequest,
|
||||
x_bot_api_key: str | None = Depends(get_bot_api_key),
|
||||
):
|
||||
"""
|
||||
Send a message to CoPilot on behalf of a linked user and stream
|
||||
the response back as Server-Sent Events.
|
||||
|
||||
The bot authenticates with its API key — no user JWT needed.
|
||||
"""
|
||||
check_bot_api_key(x_bot_api_key)
|
||||
|
||||
user_id = request.user_id
|
||||
|
||||
# Verify user has a platform link
|
||||
link = await PlatformLink.prisma().find_first(where={"userId": user_id})
|
||||
if not link:
|
||||
raise HTTPException(status_code=404, detail="User has no platform links.")
|
||||
|
||||
# Get or create session
|
||||
session_id = request.session_id
|
||||
if session_id:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found.")
|
||||
else:
|
||||
session = await create_chat_session(user_id)
|
||||
session_id = session.session_id
|
||||
|
||||
# Save user message
|
||||
message = ChatMessage(role="user", content=request.message)
|
||||
await append_and_save_message(session_id, message)
|
||||
|
||||
# Create a turn and enqueue
|
||||
turn_id = str(uuid4())
|
||||
|
||||
await stream_registry.create_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id="chat_stream",
|
||||
tool_name="chat",
|
||||
turn_id=turn_id,
|
||||
)
|
||||
|
||||
subscribe_from_id = "0-0"
|
||||
|
||||
await enqueue_copilot_turn(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=request.message,
|
||||
turn_id=turn_id,
|
||||
is_user_message=True,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Bot chat: user ...%s, session %s, turn %s",
|
||||
user_id[-8:],
|
||||
session_id,
|
||||
turn_id,
|
||||
)
|
||||
|
||||
async def event_generator():
|
||||
subscriber_queue = None
|
||||
try:
|
||||
subscriber_queue = await stream_registry.subscribe_to_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
last_message_id=subscribe_from_id,
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||
|
||||
if isinstance(chunk, str):
|
||||
yield chunk
|
||||
else:
|
||||
yield chunk.to_sse()
|
||||
|
||||
if isinstance(chunk, StreamFinish) or (
|
||||
isinstance(chunk, str) and "[DONE]" in chunk
|
||||
):
|
||||
break
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
yield ": keepalive\n\n"
|
||||
|
||||
except Exception:
|
||||
logger.exception("Bot chat stream error for session %s", session_id)
|
||||
yield 'data: {"type": "error", "content": "Stream error"}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
finally:
|
||||
if subscriber_queue is not None:
|
||||
await stream_registry.unsubscribe_from_session(
|
||||
session_id=session_id,
|
||||
subscriber_queue=subscriber_queue,
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Session-Id": session_id,
|
||||
},
|
||||
)
|
||||
@@ -1,107 +0,0 @@
|
||||
"""Pydantic models for the platform bot linking API."""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Platform(str, Enum):
|
||||
"""Supported platform types (mirrors Prisma PlatformType)."""
|
||||
|
||||
DISCORD = "DISCORD"
|
||||
TELEGRAM = "TELEGRAM"
|
||||
SLACK = "SLACK"
|
||||
TEAMS = "TEAMS"
|
||||
WHATSAPP = "WHATSAPP"
|
||||
GITHUB = "GITHUB"
|
||||
LINEAR = "LINEAR"
|
||||
|
||||
|
||||
# ── Request Models ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class CreateLinkTokenRequest(BaseModel):
|
||||
"""Request from the bot service to create a linking token."""
|
||||
|
||||
platform: Platform = Field(description="Platform name")
|
||||
platform_user_id: str = Field(
|
||||
description="The user's ID on the platform",
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
)
|
||||
platform_username: str | None = Field(
|
||||
default=None,
|
||||
description="Display name (best effort)",
|
||||
max_length=255,
|
||||
)
|
||||
channel_id: str | None = Field(
|
||||
default=None,
|
||||
description="Channel ID for sending confirmation back",
|
||||
max_length=255,
|
||||
)
|
||||
|
||||
|
||||
class ResolveRequest(BaseModel):
|
||||
"""Resolve a platform identity to an AutoGPT user."""
|
||||
|
||||
platform: Platform
|
||||
platform_user_id: str = Field(min_length=1, max_length=255)
|
||||
|
||||
|
||||
class BotChatRequest(BaseModel):
|
||||
"""Request from the bot to chat as a linked user."""
|
||||
|
||||
user_id: str = Field(description="The linked AutoGPT user ID")
|
||||
message: str = Field(
|
||||
description="The user's message", min_length=1, max_length=32000
|
||||
)
|
||||
session_id: str | None = Field(
|
||||
default=None,
|
||||
description="Existing chat session ID. If omitted, a new session is created.",
|
||||
)
|
||||
|
||||
|
||||
# ── Response Models ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class LinkTokenResponse(BaseModel):
|
||||
token: str
|
||||
expires_at: datetime
|
||||
link_url: str
|
||||
|
||||
|
||||
class LinkTokenStatusResponse(BaseModel):
|
||||
status: Literal["pending", "linked", "expired"]
|
||||
user_id: str | None = None
|
||||
|
||||
|
||||
class ResolveResponse(BaseModel):
|
||||
linked: bool
|
||||
user_id: str | None = None
|
||||
|
||||
|
||||
class PlatformLinkInfo(BaseModel):
|
||||
id: str
|
||||
platform: str
|
||||
platform_user_id: str
|
||||
platform_username: str | None
|
||||
linked_at: datetime
|
||||
|
||||
|
||||
class ConfirmLinkResponse(BaseModel):
|
||||
success: bool
|
||||
platform: str
|
||||
platform_user_id: str
|
||||
platform_username: str | None
|
||||
|
||||
|
||||
class DeleteLinkResponse(BaseModel):
|
||||
success: bool
|
||||
|
||||
|
||||
class BotChatSessionResponse(BaseModel):
|
||||
"""Returned when creating a new session via the bot proxy."""
|
||||
|
||||
session_id: str
|
||||
@@ -1,340 +0,0 @@
|
||||
"""
|
||||
Platform Bot Linking API routes.
|
||||
|
||||
Enables linking external chat platform identities (Discord, Telegram, Slack, etc.)
|
||||
to AutoGPT user accounts. Used by the multi-platform CoPilot bot.
|
||||
|
||||
Flow:
|
||||
1. Bot calls POST /api/platform-linking/tokens to create a link token
|
||||
for an unlinked platform user.
|
||||
2. Bot sends the user a link: {frontend}/link/{token}
|
||||
3. User clicks the link, logs in to AutoGPT, and the frontend calls
|
||||
POST /api/platform-linking/tokens/{token}/confirm to complete the link.
|
||||
4. Bot can poll GET /api/platform-linking/tokens/{token}/status or just
|
||||
check on next message via GET /api/platform-linking/resolve.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Annotated
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Security
|
||||
from prisma.models import PlatformLink, PlatformLinkToken
|
||||
|
||||
from .auth import check_bot_api_key, get_bot_api_key
|
||||
from .models import (
|
||||
ConfirmLinkResponse,
|
||||
CreateLinkTokenRequest,
|
||||
DeleteLinkResponse,
|
||||
LinkTokenResponse,
|
||||
LinkTokenStatusResponse,
|
||||
PlatformLinkInfo,
|
||||
ResolveRequest,
|
||||
ResolveResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
LINK_TOKEN_EXPIRY_MINUTES = 30
|
||||
|
||||
# Path parameter with validation for link tokens
|
||||
TokenPath = Annotated[
|
||||
str,
|
||||
Path(max_length=64, pattern=r"^[A-Za-z0-9_-]+$"),
|
||||
]
|
||||
|
||||
|
||||
# ── Bot-facing endpoints (API key auth) ───────────────────────────────
|
||||
|
||||
|
||||
@router.post(
|
||||
"/tokens",
|
||||
response_model=LinkTokenResponse,
|
||||
summary="Create a link token for an unlinked platform user",
|
||||
)
|
||||
async def create_link_token(
|
||||
request: CreateLinkTokenRequest,
|
||||
x_bot_api_key: str | None = Depends(get_bot_api_key),
|
||||
) -> LinkTokenResponse:
|
||||
"""
|
||||
Called by the bot service when it encounters an unlinked user.
|
||||
Generates a one-time token the user can use to link their account.
|
||||
"""
|
||||
check_bot_api_key(x_bot_api_key)
|
||||
|
||||
platform = request.platform.value
|
||||
|
||||
# Check if already linked
|
||||
existing = await PlatformLink.prisma().find_first(
|
||||
where={
|
||||
"platform": platform,
|
||||
"platformUserId": request.platform_user_id,
|
||||
}
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="This platform account is already linked.",
|
||||
)
|
||||
|
||||
# Invalidate any existing pending tokens for this user
|
||||
await PlatformLinkToken.prisma().update_many(
|
||||
where={
|
||||
"platform": platform,
|
||||
"platformUserId": request.platform_user_id,
|
||||
"usedAt": None,
|
||||
},
|
||||
data={"usedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
# Generate token
|
||||
token = secrets.token_urlsafe(32)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=LINK_TOKEN_EXPIRY_MINUTES
|
||||
)
|
||||
|
||||
await PlatformLinkToken.prisma().create(
|
||||
data={
|
||||
"token": token,
|
||||
"platform": platform,
|
||||
"platformUserId": request.platform_user_id,
|
||||
"platformUsername": request.platform_username,
|
||||
"channelId": request.channel_id,
|
||||
"expiresAt": expires_at,
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Created link token for %s (expires %s)",
|
||||
platform,
|
||||
expires_at.isoformat(),
|
||||
)
|
||||
|
||||
link_base_url = os.getenv(
|
||||
"PLATFORM_LINK_BASE_URL", "https://platform.agpt.co/link"
|
||||
)
|
||||
link_url = f"{link_base_url}/{token}"
|
||||
|
||||
return LinkTokenResponse(
|
||||
token=token,
|
||||
expires_at=expires_at,
|
||||
link_url=link_url,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tokens/{token}/status",
|
||||
response_model=LinkTokenStatusResponse,
|
||||
summary="Check if a link token has been consumed",
|
||||
)
|
||||
async def get_link_token_status(
|
||||
token: TokenPath,
|
||||
x_bot_api_key: str | None = Depends(get_bot_api_key),
|
||||
) -> LinkTokenStatusResponse:
|
||||
"""
|
||||
Called by the bot service to check if a user has completed linking.
|
||||
"""
|
||||
check_bot_api_key(x_bot_api_key)
|
||||
|
||||
link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token})
|
||||
|
||||
if not link_token:
|
||||
raise HTTPException(status_code=404, detail="Token not found")
|
||||
|
||||
if link_token.usedAt is not None:
|
||||
# Token was used — find the linked account
|
||||
link = await PlatformLink.prisma().find_first(
|
||||
where={
|
||||
"platform": link_token.platform,
|
||||
"platformUserId": link_token.platformUserId,
|
||||
}
|
||||
)
|
||||
return LinkTokenStatusResponse(
|
||||
status="linked",
|
||||
user_id=link.userId if link else None,
|
||||
)
|
||||
|
||||
if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
|
||||
return LinkTokenStatusResponse(status="expired")
|
||||
|
||||
return LinkTokenStatusResponse(status="pending")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/resolve",
|
||||
response_model=ResolveResponse,
|
||||
summary="Resolve a platform identity to an AutoGPT user",
|
||||
)
|
||||
async def resolve_platform_user(
|
||||
request: ResolveRequest,
|
||||
x_bot_api_key: str | None = Depends(get_bot_api_key),
|
||||
) -> ResolveResponse:
|
||||
"""
|
||||
Called by the bot service on every incoming message to check if
|
||||
the platform user has a linked AutoGPT account.
|
||||
"""
|
||||
check_bot_api_key(x_bot_api_key)
|
||||
|
||||
link = await PlatformLink.prisma().find_first(
|
||||
where={
|
||||
"platform": request.platform.value,
|
||||
"platformUserId": request.platform_user_id,
|
||||
}
|
||||
)
|
||||
|
||||
if not link:
|
||||
return ResolveResponse(linked=False)
|
||||
|
||||
return ResolveResponse(linked=True, user_id=link.userId)
|
||||
|
||||
|
||||
# ── User-facing endpoints (JWT auth) ──────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
|
||||
"/tokens/{token}/confirm",
|
||||
response_model=ConfirmLinkResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Confirm a link token (user must be authenticated)",
|
||||
)
|
||||
async def confirm_link_token(
|
||||
token: TokenPath,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> ConfirmLinkResponse:
|
||||
"""
|
||||
Called by the frontend when the user clicks the link and is logged in.
|
||||
Consumes the token and creates the platform link.
|
||||
Uses atomic update_many to prevent race conditions on double-click.
|
||||
"""
|
||||
link_token = await PlatformLinkToken.prisma().find_unique(where={"token": token})
|
||||
|
||||
if not link_token:
|
||||
raise HTTPException(status_code=404, detail="Token not found.")
|
||||
|
||||
if link_token.usedAt is not None:
|
||||
raise HTTPException(status_code=410, detail="This link has already been used.")
|
||||
|
||||
if link_token.expiresAt.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
|
||||
raise HTTPException(status_code=410, detail="This link has expired.")
|
||||
|
||||
# Atomically mark token as used (only if still unused)
|
||||
updated = await PlatformLinkToken.prisma().update_many(
|
||||
where={"token": token, "usedAt": None},
|
||||
data={"usedAt": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
if updated == 0:
|
||||
raise HTTPException(status_code=410, detail="This link has already been used.")
|
||||
|
||||
# Check if this platform identity is already linked
|
||||
existing = await PlatformLink.prisma().find_first(
|
||||
where={
|
||||
"platform": link_token.platform,
|
||||
"platformUserId": link_token.platformUserId,
|
||||
}
|
||||
)
|
||||
if existing:
|
||||
detail = (
|
||||
"This platform account is already linked to your account."
|
||||
if existing.userId == user_id
|
||||
else "This platform account is already linked to another user."
|
||||
)
|
||||
raise HTTPException(status_code=409, detail=detail)
|
||||
|
||||
# Create the link — catch unique constraint race condition
|
||||
try:
|
||||
await PlatformLink.prisma().create(
|
||||
data={
|
||||
"userId": user_id,
|
||||
"platform": link_token.platform,
|
||||
"platformUserId": link_token.platformUserId,
|
||||
"platformUsername": link_token.platformUsername,
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
if "unique" in str(exc).lower():
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="This platform account was just linked by another request.",
|
||||
) from exc
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
"Linked %s:%s to user ...%s",
|
||||
link_token.platform,
|
||||
link_token.platformUserId,
|
||||
user_id[-8:],
|
||||
)
|
||||
|
||||
return ConfirmLinkResponse(
|
||||
success=True,
|
||||
platform=link_token.platform,
|
||||
platform_user_id=link_token.platformUserId,
|
||||
platform_username=link_token.platformUsername,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/links",
|
||||
response_model=list[PlatformLinkInfo],
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="List all platform links for the authenticated user",
|
||||
)
|
||||
async def list_my_links(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> list[PlatformLinkInfo]:
|
||||
"""Returns all platform identities linked to the current user's account."""
|
||||
links = await PlatformLink.prisma().find_many(
|
||||
where={"userId": user_id},
|
||||
order={"linkedAt": "desc"},
|
||||
)
|
||||
|
||||
return [
|
||||
PlatformLinkInfo(
|
||||
id=link.id,
|
||||
platform=link.platform,
|
||||
platform_user_id=link.platformUserId,
|
||||
platform_username=link.platformUsername,
|
||||
linked_at=link.linkedAt,
|
||||
)
|
||||
for link in links
|
||||
]
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/links/{link_id}",
|
||||
response_model=DeleteLinkResponse,
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
summary="Unlink a platform identity",
|
||||
)
|
||||
async def delete_link(
|
||||
link_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> DeleteLinkResponse:
|
||||
"""
|
||||
Removes a platform link. The user will need to re-link if they
|
||||
want to use the bot on that platform again.
|
||||
"""
|
||||
link = await PlatformLink.prisma().find_unique(where={"id": link_id})
|
||||
|
||||
if not link:
|
||||
raise HTTPException(status_code=404, detail="Link not found.")
|
||||
|
||||
if link.userId != user_id:
|
||||
raise HTTPException(status_code=403, detail="Not your link.")
|
||||
|
||||
await PlatformLink.prisma().delete(where={"id": link_id})
|
||||
|
||||
logger.info(
|
||||
"Unlinked %s:%s from user ...%s",
|
||||
link.platform,
|
||||
link.platformUserId,
|
||||
user_id[-8:],
|
||||
)
|
||||
|
||||
return DeleteLinkResponse(success=True)
|
||||
@@ -1,138 +0,0 @@
|
||||
"""Tests for platform bot linking API routes."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from backend.api.features.platform_linking.auth import check_bot_api_key
|
||||
from backend.api.features.platform_linking.models import (
|
||||
ConfirmLinkResponse,
|
||||
CreateLinkTokenRequest,
|
||||
DeleteLinkResponse,
|
||||
LinkTokenStatusResponse,
|
||||
Platform,
|
||||
ResolveRequest,
|
||||
)
|
||||
|
||||
|
||||
class TestPlatformEnum:
|
||||
def test_all_platforms_exist(self):
|
||||
assert Platform.DISCORD.value == "DISCORD"
|
||||
assert Platform.TELEGRAM.value == "TELEGRAM"
|
||||
assert Platform.SLACK.value == "SLACK"
|
||||
assert Platform.TEAMS.value == "TEAMS"
|
||||
assert Platform.WHATSAPP.value == "WHATSAPP"
|
||||
assert Platform.GITHUB.value == "GITHUB"
|
||||
assert Platform.LINEAR.value == "LINEAR"
|
||||
|
||||
|
||||
class TestBotApiKeyAuth:
|
||||
@patch.dict("os.environ", {"PLATFORM_BOT_API_KEY": ""}, clear=False)
|
||||
@patch("backend.api.features.platform_linking.auth.Settings")
|
||||
def test_no_key_configured_allows_when_auth_disabled(self, mock_settings_cls):
|
||||
mock_settings_cls.return_value.config.enable_auth = False
|
||||
check_bot_api_key(None)
|
||||
|
||||
@patch.dict("os.environ", {"PLATFORM_BOT_API_KEY": ""}, clear=False)
|
||||
@patch("backend.api.features.platform_linking.auth.Settings")
|
||||
def test_no_key_configured_rejects_when_auth_enabled(self, mock_settings_cls):
|
||||
mock_settings_cls.return_value.config.enable_auth = True
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
check_bot_api_key(None)
|
||||
assert exc_info.value.status_code == 503
|
||||
|
||||
@patch.dict("os.environ", {"PLATFORM_BOT_API_KEY": "secret123"}, clear=False)
|
||||
def test_valid_key(self):
|
||||
check_bot_api_key("secret123")
|
||||
|
||||
@patch.dict("os.environ", {"PLATFORM_BOT_API_KEY": "secret123"}, clear=False)
|
||||
def test_invalid_key_rejected(self):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
check_bot_api_key("wrong")
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
@patch.dict("os.environ", {"PLATFORM_BOT_API_KEY": "secret123"}, clear=False)
|
||||
def test_missing_key_rejected(self):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
check_bot_api_key(None)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
class TestCreateLinkTokenRequest:
|
||||
def test_valid_request(self):
|
||||
req = CreateLinkTokenRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_user_id="353922987235213313",
|
||||
)
|
||||
assert req.platform == Platform.DISCORD
|
||||
assert req.platform_user_id == "353922987235213313"
|
||||
|
||||
def test_empty_platform_user_id_rejected(self):
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
CreateLinkTokenRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_user_id="",
|
||||
)
|
||||
|
||||
def test_too_long_platform_user_id_rejected(self):
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
CreateLinkTokenRequest(
|
||||
platform=Platform.DISCORD,
|
||||
platform_user_id="x" * 256,
|
||||
)
|
||||
|
||||
def test_invalid_platform_rejected(self):
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
CreateLinkTokenRequest.model_validate(
|
||||
{"platform": "INVALID", "platform_user_id": "123"}
|
||||
)
|
||||
|
||||
|
||||
class TestResolveRequest:
|
||||
def test_valid_request(self):
|
||||
req = ResolveRequest(
|
||||
platform=Platform.TELEGRAM,
|
||||
platform_user_id="123456789",
|
||||
)
|
||||
assert req.platform == Platform.TELEGRAM
|
||||
|
||||
def test_empty_id_rejected(self):
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
ResolveRequest(
|
||||
platform=Platform.SLACK,
|
||||
platform_user_id="",
|
||||
)
|
||||
|
||||
|
||||
class TestResponseModels:
|
||||
def test_link_token_status_literal(self):
|
||||
resp = LinkTokenStatusResponse(status="pending")
|
||||
assert resp.status == "pending"
|
||||
|
||||
resp = LinkTokenStatusResponse(status="linked", user_id="abc")
|
||||
assert resp.status == "linked"
|
||||
|
||||
resp = LinkTokenStatusResponse(status="expired")
|
||||
assert resp.status == "expired"
|
||||
|
||||
def test_confirm_link_response(self):
|
||||
resp = ConfirmLinkResponse(
|
||||
success=True,
|
||||
platform="DISCORD",
|
||||
platform_user_id="123",
|
||||
platform_username="testuser",
|
||||
)
|
||||
assert resp.success is True
|
||||
|
||||
def test_delete_link_response(self):
|
||||
resp = DeleteLinkResponse(success=True)
|
||||
assert resp.success is True
|
||||
@@ -30,8 +30,6 @@ import backend.api.features.library.routes
|
||||
import backend.api.features.mcp.routes as mcp_routes
|
||||
import backend.api.features.oauth
|
||||
import backend.api.features.otto.routes
|
||||
import backend.api.features.platform_linking.chat_proxy
|
||||
import backend.api.features.platform_linking.routes
|
||||
import backend.api.features.postmark.postmark
|
||||
import backend.api.features.store.model
|
||||
import backend.api.features.store.routes
|
||||
@@ -120,6 +118,11 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Register managed credential providers (e.g. AgentMail)
|
||||
from backend.integrations.managed_providers import register_all
|
||||
|
||||
register_all()
|
||||
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
@@ -363,16 +366,6 @@ app.include_router(
|
||||
tags=["oauth"],
|
||||
prefix="/api/oauth",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.platform_linking.routes.router,
|
||||
tags=["platform-linking"],
|
||||
prefix="/api/platform-linking",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.platform_linking.chat_proxy.router,
|
||||
tags=["platform-linking"],
|
||||
prefix="/api/platform-linking",
|
||||
)
|
||||
|
||||
app.mount("/external-api", external_api)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import re
|
||||
from abc import ABC
|
||||
from email import encoders
|
||||
from email.mime.base import MIMEBase
|
||||
@@ -8,7 +9,7 @@ from email.mime.text import MIMEText
|
||||
from email.policy import SMTP
|
||||
from email.utils import getaddresses, parseaddr
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional
|
||||
from typing import List, Literal, Optional, Protocol, runtime_checkable
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
@@ -42,8 +43,52 @@ NO_WRAP_POLICY = SMTP.clone(max_line_length=0)
|
||||
|
||||
|
||||
def serialize_email_recipients(recipients: list[str]) -> str:
|
||||
"""Serialize recipients list to comma-separated string."""
|
||||
return ", ".join(recipients)
|
||||
"""Serialize recipients list to comma-separated string.
|
||||
|
||||
Strips leading/trailing whitespace from each address to keep MIME
|
||||
headers clean (mirrors the strip done in ``validate_email_recipients``).
|
||||
"""
|
||||
return ", ".join(addr.strip() for addr in recipients)
|
||||
|
||||
|
||||
# RFC 5322 simplified pattern: local@domain where domain has at least one dot
|
||||
_EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
|
||||
|
||||
|
||||
def validate_email_recipients(recipients: list[str], field_name: str = "to") -> None:
|
||||
"""Validate that all recipients are plausible email addresses.
|
||||
|
||||
Raises ``ValueError`` with a user-friendly message listing every
|
||||
invalid entry so the caller (or LLM) can correct them in one pass.
|
||||
"""
|
||||
invalid = [addr for addr in recipients if not _EMAIL_RE.match(addr.strip())]
|
||||
if invalid:
|
||||
formatted = ", ".join(f"'{a}'" for a in invalid)
|
||||
raise ValueError(
|
||||
f"Invalid email address(es) in '{field_name}': {formatted}. "
|
||||
f"Each entry must be a valid email address (e.g. user@example.com)."
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class HasRecipients(Protocol):
|
||||
to: list[str]
|
||||
cc: list[str]
|
||||
bcc: list[str]
|
||||
|
||||
|
||||
def validate_all_recipients(input_data: HasRecipients) -> None:
|
||||
"""Validate to/cc/bcc recipient fields on an input namespace.
|
||||
|
||||
Calls ``validate_email_recipients`` for ``to`` (required) and
|
||||
``cc``/``bcc`` (when non-empty), raising ``ValueError`` on the
|
||||
first field that contains an invalid address.
|
||||
"""
|
||||
validate_email_recipients(input_data.to, "to")
|
||||
if input_data.cc:
|
||||
validate_email_recipients(input_data.cc, "cc")
|
||||
if input_data.bcc:
|
||||
validate_email_recipients(input_data.bcc, "bcc")
|
||||
|
||||
|
||||
def _make_mime_text(
|
||||
@@ -100,14 +145,16 @@ async def create_mime_message(
|
||||
) -> str:
|
||||
"""Create a MIME message with attachments and return base64-encoded raw message."""
|
||||
|
||||
validate_all_recipients(input_data)
|
||||
|
||||
message = MIMEMultipart()
|
||||
message["to"] = serialize_email_recipients(input_data.to)
|
||||
message["subject"] = input_data.subject
|
||||
|
||||
if input_data.cc:
|
||||
message["cc"] = ", ".join(input_data.cc)
|
||||
message["cc"] = serialize_email_recipients(input_data.cc)
|
||||
if input_data.bcc:
|
||||
message["bcc"] = ", ".join(input_data.bcc)
|
||||
message["bcc"] = serialize_email_recipients(input_data.bcc)
|
||||
|
||||
# Use the new helper function with content_type if available
|
||||
content_type = getattr(input_data, "content_type", None)
|
||||
@@ -1167,13 +1214,15 @@ async def _build_reply_message(
|
||||
references.append(headers["message-id"])
|
||||
|
||||
# Create MIME message
|
||||
validate_all_recipients(input_data)
|
||||
|
||||
msg = MIMEMultipart()
|
||||
if input_data.to:
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
msg["To"] = serialize_email_recipients(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
msg["Cc"] = serialize_email_recipients(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
if headers.get("message-id"):
|
||||
msg["In-Reply-To"] = headers["message-id"]
|
||||
@@ -1685,13 +1734,16 @@ To: {original_to}
|
||||
else:
|
||||
body = f"{forward_header}\n\n{original_body}"
|
||||
|
||||
# Validate all recipient lists before building the MIME message
|
||||
validate_all_recipients(input_data)
|
||||
|
||||
# Create MIME message
|
||||
msg = MIMEMultipart()
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
msg["To"] = serialize_email_recipients(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
msg["Cc"] = serialize_email_recipients(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
|
||||
# Add body with proper content type
|
||||
|
||||
@@ -724,6 +724,9 @@ def convert_openai_tool_fmt_to_anthropic(
|
||||
def extract_openai_reasoning(response) -> str | None:
|
||||
"""Extract reasoning from OpenAI-compatible response if available."""
|
||||
"""Note: This will likely not working since the reasoning is not present in another Response API"""
|
||||
if not response.choices:
|
||||
logger.warning("LLM response has empty choices in extract_openai_reasoning")
|
||||
return None
|
||||
reasoning = None
|
||||
choice = response.choices[0]
|
||||
if hasattr(choice, "reasoning") and getattr(choice, "reasoning", None):
|
||||
@@ -739,6 +742,9 @@ def extract_openai_reasoning(response) -> str | None:
|
||||
|
||||
def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
|
||||
"""Extract tool calls from OpenAI-compatible response."""
|
||||
if not response.choices:
|
||||
logger.warning("LLM response has empty choices in extract_openai_tool_calls")
|
||||
return None
|
||||
if response.choices[0].message.tool_calls:
|
||||
return [
|
||||
ToolContentBlock(
|
||||
@@ -972,6 +978,8 @@ async def llm_call(
|
||||
response_format=response_format, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
if not response.choices:
|
||||
raise ValueError("Groq returned empty choices in response")
|
||||
return LLMResponse(
|
||||
raw_response=response.choices[0].message,
|
||||
prompt=prompt,
|
||||
@@ -1031,12 +1039,8 @@ async def llm_call(
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
# If there's no response, raise an error
|
||||
if not response.choices:
|
||||
if response:
|
||||
raise ValueError(f"OpenRouter error: {response}")
|
||||
else:
|
||||
raise ValueError("No response from OpenRouter.")
|
||||
raise ValueError(f"OpenRouter returned empty choices: {response}")
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
@@ -1073,12 +1077,8 @@ async def llm_call(
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
# If there's no response, raise an error
|
||||
if not response.choices:
|
||||
if response:
|
||||
raise ValueError(f"Llama API error: {response}")
|
||||
else:
|
||||
raise ValueError("No response from Llama API.")
|
||||
raise ValueError(f"Llama API returned empty choices: {response}")
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
@@ -1108,6 +1108,8 @@ async def llm_call(
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
if not completion.choices:
|
||||
raise ValueError("AI/ML API returned empty choices in response")
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=completion.choices[0].message,
|
||||
@@ -1144,6 +1146,9 @@ async def llm_call(
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
if not response.choices:
|
||||
raise ValueError(f"v0 API returned empty choices: {response}")
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
|
||||
@@ -2011,6 +2016,19 @@ class AIConversationBlock(AIBlockBase):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
has_messages = any(
|
||||
isinstance(m, dict)
|
||||
and isinstance(m.get("content"), str)
|
||||
and bool(m["content"].strip())
|
||||
for m in (input_data.messages or [])
|
||||
)
|
||||
has_prompt = bool(input_data.prompt and input_data.prompt.strip())
|
||||
if not has_messages and not has_prompt:
|
||||
raise ValueError(
|
||||
"Cannot call LLM with no messages and no prompt. "
|
||||
"Provide at least one message or a non-empty prompt."
|
||||
)
|
||||
|
||||
response = await self.llm_call(
|
||||
AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt=input_data.prompt,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -488,6 +488,154 @@ class TestLLMStatsTracking:
|
||||
assert outputs["response"] == {"result": "test"}
|
||||
|
||||
|
||||
class TestAIConversationBlockValidation:
|
||||
"""Test that AIConversationBlock validates inputs before calling the LLM."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_messages_and_empty_prompt_raises_error(self):
|
||||
"""Empty messages with no prompt should raise ValueError, not a cryptic API error."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_messages_with_prompt_succeeds(self):
|
||||
"""Empty messages but a non-empty prompt should proceed without error."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
async def mock_llm_call(input_data, credentials):
|
||||
return {"response": "OK"}
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[],
|
||||
prompt="Hello, how are you?",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[name] = data
|
||||
|
||||
assert outputs["response"] == "OK"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonempty_messages_with_empty_prompt_succeeds(self):
|
||||
"""Non-empty messages with no prompt should proceed without error."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
async def mock_llm_call(input_data, credentials):
|
||||
return {"response": "response from conversation"}
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[name] = data
|
||||
|
||||
assert outputs["response"] == "response from conversation"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_empty_content_raises_error(self):
|
||||
"""Messages with empty content strings should be treated as no messages."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[{"role": "user", "content": ""}],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_whitespace_content_raises_error(self):
|
||||
"""Messages with whitespace-only content should be treated as no messages."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[{"role": "user", "content": " "}],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_none_entry_raises_error(self):
|
||||
"""Messages list containing None should be treated as no messages."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[None],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_empty_dict_raises_error(self):
|
||||
"""Messages list containing empty dict should be treated as no messages."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[{}],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_none_content_raises_error(self):
|
||||
"""Messages with content=None should not crash with AttributeError."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[{"role": "user", "content": None}],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
|
||||
class TestAITextSummarizerValidation:
|
||||
"""Test that AITextSummarizerBlock validates LLM responses are strings."""
|
||||
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
"""Tests for empty-choices guard in extract_openai_tool_calls() and extract_openai_reasoning()."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.llm import extract_openai_reasoning, extract_openai_tool_calls
|
||||
|
||||
|
||||
class TestExtractOpenaiToolCallsEmptyChoices:
|
||||
"""extract_openai_tool_calls() must return None when choices is empty."""
|
||||
|
||||
def test_returns_none_for_empty_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = []
|
||||
assert extract_openai_tool_calls(response) is None
|
||||
|
||||
def test_returns_none_for_none_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = None
|
||||
assert extract_openai_tool_calls(response) is None
|
||||
|
||||
def test_returns_tool_calls_when_choices_present(self):
|
||||
tool = MagicMock()
|
||||
tool.id = "call_1"
|
||||
tool.type = "function"
|
||||
tool.function.name = "my_func"
|
||||
tool.function.arguments = '{"a": 1}'
|
||||
|
||||
message = MagicMock()
|
||||
message.tool_calls = [tool]
|
||||
|
||||
choice = MagicMock()
|
||||
choice.message = message
|
||||
|
||||
response = MagicMock()
|
||||
response.choices = [choice]
|
||||
|
||||
result = extract_openai_tool_calls(response)
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0].function.name == "my_func"
|
||||
|
||||
def test_returns_none_when_no_tool_calls(self):
|
||||
message = MagicMock()
|
||||
message.tool_calls = None
|
||||
|
||||
choice = MagicMock()
|
||||
choice.message = message
|
||||
|
||||
response = MagicMock()
|
||||
response.choices = [choice]
|
||||
|
||||
assert extract_openai_tool_calls(response) is None
|
||||
|
||||
|
||||
class TestExtractOpenaiReasoningEmptyChoices:
|
||||
"""extract_openai_reasoning() must return None when choices is empty."""
|
||||
|
||||
def test_returns_none_for_empty_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = []
|
||||
assert extract_openai_reasoning(response) is None
|
||||
|
||||
def test_returns_none_for_none_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = None
|
||||
assert extract_openai_reasoning(response) is None
|
||||
|
||||
def test_returns_reasoning_from_choice(self):
|
||||
choice = MagicMock()
|
||||
choice.reasoning = "Step-by-step reasoning"
|
||||
choice.message = MagicMock(spec=[]) # no 'reasoning' attr on message
|
||||
|
||||
response = MagicMock(spec=[]) # no 'reasoning' attr on response
|
||||
response.choices = [choice]
|
||||
|
||||
result = extract_openai_reasoning(response)
|
||||
assert result == "Step-by-step reasoning"
|
||||
|
||||
def test_returns_none_when_no_reasoning(self):
|
||||
choice = MagicMock(spec=[]) # no 'reasoning' attr
|
||||
choice.message = MagicMock(spec=[]) # no 'reasoning' attr
|
||||
|
||||
response = MagicMock(spec=[]) # no 'reasoning' attr
|
||||
response.choices = [choice]
|
||||
|
||||
result = extract_openai_reasoning(response)
|
||||
assert result is None
|
||||
@@ -1074,6 +1074,7 @@ async def test_orchestrator_uses_customized_name_for_blocks():
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
|
||||
mock_node.block = StoreValueBlock()
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
@@ -1105,6 +1106,7 @@ async def test_orchestrator_falls_back_to_block_name():
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {} # No customized_name
|
||||
mock_node.block = StoreValueBlock()
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
|
||||
@@ -0,0 +1,202 @@
|
||||
"""Tests for ExecutionMode enum and provider validation in the orchestrator.
|
||||
|
||||
Covers:
|
||||
- ExecutionMode enum members exist and have stable values
|
||||
- EXTENDED_THINKING provider validation (anthropic/open_router allowed, others rejected)
|
||||
- EXTENDED_THINKING model-name validation (must start with "claude")
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.blocks.orchestrator import ExecutionMode, OrchestratorBlock
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExecutionMode enum integrity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExecutionModeEnum:
|
||||
"""Guard against accidental renames or removals of enum members."""
|
||||
|
||||
def test_built_in_exists(self):
|
||||
assert hasattr(ExecutionMode, "BUILT_IN")
|
||||
assert ExecutionMode.BUILT_IN.value == "built_in"
|
||||
|
||||
def test_extended_thinking_exists(self):
|
||||
assert hasattr(ExecutionMode, "EXTENDED_THINKING")
|
||||
assert ExecutionMode.EXTENDED_THINKING.value == "extended_thinking"
|
||||
|
||||
def test_exactly_two_members(self):
|
||||
"""If a new mode is added, this test should be updated intentionally."""
|
||||
assert set(ExecutionMode.__members__.keys()) == {
|
||||
"BUILT_IN",
|
||||
"EXTENDED_THINKING",
|
||||
}
|
||||
|
||||
def test_string_enum(self):
|
||||
"""ExecutionMode is a str enum so it serialises cleanly to JSON."""
|
||||
assert isinstance(ExecutionMode.BUILT_IN, str)
|
||||
assert isinstance(ExecutionMode.EXTENDED_THINKING, str)
|
||||
|
||||
def test_round_trip_from_value(self):
|
||||
"""Constructing from the string value should return the same member."""
|
||||
assert ExecutionMode("built_in") is ExecutionMode.BUILT_IN
|
||||
assert ExecutionMode("extended_thinking") is ExecutionMode.EXTENDED_THINKING
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider validation (inline in OrchestratorBlock.run)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_model_stub(provider: str, value: str):
|
||||
"""Create a lightweight stub that behaves like LlmModel for validation."""
|
||||
metadata = MagicMock()
|
||||
metadata.provider = provider
|
||||
stub = MagicMock()
|
||||
stub.metadata = metadata
|
||||
stub.value = value
|
||||
return stub
|
||||
|
||||
|
||||
class TestExtendedThinkingProviderValidation:
|
||||
"""The orchestrator rejects EXTENDED_THINKING for non-Anthropic providers."""
|
||||
|
||||
def test_anthropic_provider_accepted(self):
|
||||
"""provider='anthropic' + claude model should not raise."""
|
||||
model = _make_model_stub("anthropic", "claude-opus-4-6")
|
||||
provider = model.metadata.provider
|
||||
model_name = model.value
|
||||
assert provider in ("anthropic", "open_router")
|
||||
assert model_name.startswith("claude")
|
||||
|
||||
def test_open_router_provider_accepted(self):
|
||||
"""provider='open_router' + claude model should not raise."""
|
||||
model = _make_model_stub("open_router", "claude-sonnet-4-6")
|
||||
provider = model.metadata.provider
|
||||
model_name = model.value
|
||||
assert provider in ("anthropic", "open_router")
|
||||
assert model_name.startswith("claude")
|
||||
|
||||
def test_openai_provider_rejected(self):
|
||||
"""provider='openai' should be rejected for EXTENDED_THINKING."""
|
||||
model = _make_model_stub("openai", "gpt-4o")
|
||||
provider = model.metadata.provider
|
||||
assert provider not in ("anthropic", "open_router")
|
||||
|
||||
def test_groq_provider_rejected(self):
|
||||
model = _make_model_stub("groq", "llama-3.3-70b-versatile")
|
||||
provider = model.metadata.provider
|
||||
assert provider not in ("anthropic", "open_router")
|
||||
|
||||
def test_non_claude_model_rejected_even_if_anthropic_provider(self):
|
||||
"""A hypothetical non-Claude model with provider='anthropic' is rejected."""
|
||||
model = _make_model_stub("anthropic", "not-a-claude-model")
|
||||
model_name = model.value
|
||||
assert not model_name.startswith("claude")
|
||||
|
||||
def test_real_gpt4o_model_rejected(self):
|
||||
"""Verify a real LlmModel enum member (GPT4O) fails the provider check."""
|
||||
model = LlmModel.GPT4O
|
||||
provider = model.metadata.provider
|
||||
assert provider not in ("anthropic", "open_router")
|
||||
|
||||
def test_real_claude_model_passes(self):
|
||||
"""Verify a real LlmModel enum member (CLAUDE_4_6_SONNET) passes."""
|
||||
model = LlmModel.CLAUDE_4_6_SONNET
|
||||
provider = model.metadata.provider
|
||||
model_name = model.value
|
||||
assert provider in ("anthropic", "open_router")
|
||||
assert model_name.startswith("claude")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration-style: exercise the validation branch via OrchestratorBlock.run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_input_data(model, execution_mode=ExecutionMode.EXTENDED_THINKING):
|
||||
"""Build a minimal MagicMock that satisfies OrchestratorBlock.run's early path."""
|
||||
inp = MagicMock()
|
||||
inp.execution_mode = execution_mode
|
||||
inp.model = model
|
||||
inp.prompt = "test"
|
||||
inp.sys_prompt = ""
|
||||
inp.conversation_history = []
|
||||
inp.last_tool_output = None
|
||||
inp.prompt_values = {}
|
||||
return inp
|
||||
|
||||
|
||||
async def _collect_run_outputs(block, input_data, **kwargs):
|
||||
"""Exhaust the OrchestratorBlock.run async generator, collecting outputs."""
|
||||
outputs = []
|
||||
async for item in block.run(input_data, **kwargs):
|
||||
outputs.append(item)
|
||||
return outputs
|
||||
|
||||
|
||||
class TestExtendedThinkingValidationRaisesInBlock:
|
||||
"""Call OrchestratorBlock.run far enough to trigger the ValueError."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_anthropic_provider_raises_valueerror(self):
|
||||
"""EXTENDED_THINKING + openai provider raises ValueError."""
|
||||
block = OrchestratorBlock()
|
||||
input_data = _make_input_data(model=LlmModel.GPT4O)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
block,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
pytest.raises(ValueError, match="Anthropic-compatible"),
|
||||
):
|
||||
await _collect_run_outputs(
|
||||
block,
|
||||
input_data,
|
||||
credentials=MagicMock(),
|
||||
graph_id="g",
|
||||
node_id="n",
|
||||
graph_exec_id="ge",
|
||||
node_exec_id="ne",
|
||||
user_id="u",
|
||||
graph_version=1,
|
||||
execution_context=MagicMock(),
|
||||
execution_processor=MagicMock(),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_claude_model_with_anthropic_provider_raises(self):
|
||||
"""A model with anthropic provider but non-claude name raises ValueError."""
|
||||
block = OrchestratorBlock()
|
||||
fake_model = _make_model_stub("anthropic", "not-a-claude-model")
|
||||
input_data = _make_input_data(model=fake_model)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
block,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
pytest.raises(ValueError, match="only supports Claude models"),
|
||||
):
|
||||
await _collect_run_outputs(
|
||||
block,
|
||||
input_data,
|
||||
credentials=MagicMock(),
|
||||
graph_id="g",
|
||||
node_id="n",
|
||||
graph_exec_id="ge",
|
||||
node_exec_id="ne",
|
||||
user_id="u",
|
||||
graph_version=1,
|
||||
execution_context=MagicMock(),
|
||||
execution_processor=MagicMock(),
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,11 +9,14 @@ shared tool registry as the SDK path.
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Any, cast
|
||||
|
||||
import orjson
|
||||
from langfuse import propagate_attributes
|
||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
|
||||
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
@@ -48,7 +51,17 @@ from backend.copilot.token_tracking import persist_and_record_usage
|
||||
from backend.copilot.tools import execute_tool, get_available_tools
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.prompt import compress_context
|
||||
from backend.util.prompt import (
|
||||
compress_context,
|
||||
estimate_token_count,
|
||||
estimate_token_count_str,
|
||||
)
|
||||
from backend.util.tool_call_loop import (
|
||||
LLMLoopResponse,
|
||||
LLMToolCall,
|
||||
ToolCallResult,
|
||||
tool_call_loop,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -59,6 +72,247 @@ _background_tasks: set[asyncio.Task[Any]] = set()
|
||||
_MAX_TOOL_ROUNDS = 30
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BaselineStreamState:
|
||||
"""Mutable state shared between the tool-call loop callbacks.
|
||||
|
||||
Extracted from ``stream_chat_completion_baseline`` so that the callbacks
|
||||
can be module-level functions instead of deeply nested closures.
|
||||
"""
|
||||
|
||||
pending_events: list[StreamBaseResponse] = field(default_factory=list)
|
||||
assistant_text: str = ""
|
||||
text_block_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
text_started: bool = False
|
||||
turn_prompt_tokens: int = 0
|
||||
turn_completion_tokens: int = 0
|
||||
|
||||
|
||||
async def _baseline_llm_caller(
|
||||
messages: list[dict[str, Any]],
|
||||
tools: Sequence[Any],
|
||||
*,
|
||||
state: _BaselineStreamState,
|
||||
) -> LLMLoopResponse:
|
||||
"""Stream an OpenAI-compatible response and collect results.
|
||||
|
||||
Extracted from ``stream_chat_completion_baseline`` for readability.
|
||||
"""
|
||||
state.pending_events.append(StreamStartStep())
|
||||
|
||||
round_text = ""
|
||||
try:
|
||||
client = _get_openai_client()
|
||||
typed_messages = cast(list[ChatCompletionMessageParam], messages)
|
||||
if tools:
|
||||
typed_tools = cast(list[ChatCompletionToolParam], tools)
|
||||
response = await client.chat.completions.create(
|
||||
model=config.model,
|
||||
messages=typed_messages,
|
||||
tools=typed_tools,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
else:
|
||||
response = await client.chat.completions.create(
|
||||
model=config.model,
|
||||
messages=typed_messages,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
tool_calls_by_index: dict[int, dict[str, str]] = {}
|
||||
|
||||
async for chunk in response:
|
||||
if chunk.usage:
|
||||
state.turn_prompt_tokens += chunk.usage.prompt_tokens or 0
|
||||
state.turn_completion_tokens += chunk.usage.completion_tokens or 0
|
||||
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
if not delta:
|
||||
continue
|
||||
|
||||
if delta.content:
|
||||
if not state.text_started:
|
||||
state.pending_events.append(StreamTextStart(id=state.text_block_id))
|
||||
state.text_started = True
|
||||
round_text += delta.content
|
||||
state.pending_events.append(
|
||||
StreamTextDelta(id=state.text_block_id, delta=delta.content)
|
||||
)
|
||||
|
||||
if delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
idx = tc.index
|
||||
if idx not in tool_calls_by_index:
|
||||
tool_calls_by_index[idx] = {
|
||||
"id": "",
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
}
|
||||
entry = tool_calls_by_index[idx]
|
||||
if tc.id:
|
||||
entry["id"] = tc.id
|
||||
if tc.function and tc.function.name:
|
||||
entry["name"] = tc.function.name
|
||||
if tc.function and tc.function.arguments:
|
||||
entry["arguments"] += tc.function.arguments
|
||||
|
||||
# Close text block
|
||||
if state.text_started:
|
||||
state.pending_events.append(StreamTextEnd(id=state.text_block_id))
|
||||
state.text_started = False
|
||||
state.text_block_id = str(uuid.uuid4())
|
||||
finally:
|
||||
# Always persist partial text so the session history stays consistent,
|
||||
# even when the stream is interrupted by an exception.
|
||||
state.assistant_text += round_text
|
||||
# Always emit StreamFinishStep to match the StreamStartStep,
|
||||
# even if an exception occurred during streaming.
|
||||
state.pending_events.append(StreamFinishStep())
|
||||
|
||||
# Convert to shared format
|
||||
llm_tool_calls = [
|
||||
LLMToolCall(
|
||||
id=tc["id"],
|
||||
name=tc["name"],
|
||||
arguments=tc["arguments"] or "{}",
|
||||
)
|
||||
for tc in tool_calls_by_index.values()
|
||||
]
|
||||
|
||||
return LLMLoopResponse(
|
||||
response_text=round_text or None,
|
||||
tool_calls=llm_tool_calls,
|
||||
raw_response=None, # Not needed for baseline conversation updater
|
||||
prompt_tokens=0, # Tracked via state accumulators
|
||||
completion_tokens=0,
|
||||
)
|
||||
|
||||
|
||||
async def _baseline_tool_executor(
|
||||
tool_call: LLMToolCall,
|
||||
tools: Sequence[Any],
|
||||
*,
|
||||
state: _BaselineStreamState,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
) -> ToolCallResult:
|
||||
"""Execute a tool via the copilot tool registry.
|
||||
|
||||
Extracted from ``stream_chat_completion_baseline`` for readability.
|
||||
"""
|
||||
tool_call_id = tool_call.id
|
||||
tool_name = tool_call.name
|
||||
raw_args = tool_call.arguments or "{}"
|
||||
|
||||
try:
|
||||
tool_args = orjson.loads(raw_args)
|
||||
except orjson.JSONDecodeError as parse_err:
|
||||
parse_error = f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
|
||||
logger.warning("[Baseline] %s", parse_error)
|
||||
state.pending_events.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=parse_error,
|
||||
success=False,
|
||||
)
|
||||
)
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
content=parse_error,
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
state.pending_events.append(
|
||||
StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
|
||||
)
|
||||
state.pending_events.append(
|
||||
StreamToolInputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
input=tool_args,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
result: StreamToolOutputAvailable = await execute_tool(
|
||||
tool_name=tool_name,
|
||||
parameters=tool_args,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
state.pending_events.append(result)
|
||||
tool_output = (
|
||||
result.output if isinstance(result.output, str) else str(result.output)
|
||||
)
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
content=tool_output,
|
||||
)
|
||||
except Exception as e:
|
||||
error_output = f"Tool execution error: {e}"
|
||||
logger.error(
|
||||
"[Baseline] Tool %s failed: %s",
|
||||
tool_name,
|
||||
error_output,
|
||||
exc_info=True,
|
||||
)
|
||||
state.pending_events.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=error_output,
|
||||
success=False,
|
||||
)
|
||||
)
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
content=error_output,
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
|
||||
def _baseline_conversation_updater(
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None = None,
|
||||
) -> None:
|
||||
"""Update OpenAI message list with assistant response + tool results.
|
||||
|
||||
Extracted from ``stream_chat_completion_baseline`` for readability.
|
||||
"""
|
||||
if tool_results:
|
||||
# Build assistant message with tool_calls
|
||||
assistant_msg: dict[str, Any] = {"role": "assistant"}
|
||||
if response.response_text:
|
||||
assistant_msg["content"] = response.response_text
|
||||
assistant_msg["tool_calls"] = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {"name": tc.name, "arguments": tc.arguments},
|
||||
}
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
messages.append(assistant_msg)
|
||||
for tr in tool_results:
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.tool_call_id,
|
||||
"content": tr.content,
|
||||
}
|
||||
)
|
||||
else:
|
||||
if response.response_text:
|
||||
messages.append({"role": "assistant", "content": response.response_text})
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
session_id: str, message: str, user_id: str | None
|
||||
) -> None:
|
||||
@@ -219,191 +473,32 @@ async def stream_chat_completion_baseline(
|
||||
except Exception:
|
||||
logger.warning("[Baseline] Langfuse trace context setup failed")
|
||||
|
||||
assistant_text = ""
|
||||
text_block_id = str(uuid.uuid4())
|
||||
text_started = False
|
||||
step_open = False
|
||||
# Token usage accumulators — populated from streaming chunks
|
||||
turn_prompt_tokens = 0
|
||||
turn_completion_tokens = 0
|
||||
_stream_error = False # Track whether an error occurred during streaming
|
||||
state = _BaselineStreamState()
|
||||
|
||||
# Bind extracted module-level callbacks to this request's state/session
|
||||
# using functools.partial so they satisfy the Protocol signatures.
|
||||
_bound_llm_caller = partial(_baseline_llm_caller, state=state)
|
||||
_bound_tool_executor = partial(
|
||||
_baseline_tool_executor, state=state, user_id=user_id, session=session
|
||||
)
|
||||
|
||||
try:
|
||||
for _round in range(_MAX_TOOL_ROUNDS):
|
||||
# Open a new step for each LLM round
|
||||
yield StreamStartStep()
|
||||
step_open = True
|
||||
loop_result = None
|
||||
async for loop_result in tool_call_loop(
|
||||
messages=openai_messages,
|
||||
tools=tools,
|
||||
llm_call=_bound_llm_caller,
|
||||
execute_tool=_bound_tool_executor,
|
||||
update_conversation=_baseline_conversation_updater,
|
||||
max_iterations=_MAX_TOOL_ROUNDS,
|
||||
):
|
||||
# Drain buffered events after each iteration (real-time streaming)
|
||||
for evt in state.pending_events:
|
||||
yield evt
|
||||
state.pending_events.clear()
|
||||
|
||||
# Stream a response from the model
|
||||
create_kwargs: dict[str, Any] = dict(
|
||||
model=config.model,
|
||||
messages=openai_messages,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
if tools:
|
||||
create_kwargs["tools"] = tools
|
||||
response = await _get_openai_client().chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
|
||||
|
||||
# Accumulate streamed response (text + tool calls)
|
||||
round_text = ""
|
||||
tool_calls_by_index: dict[int, dict[str, str]] = {}
|
||||
|
||||
async for chunk in response:
|
||||
# Capture token usage from the streaming chunk.
|
||||
# OpenRouter normalises all providers into OpenAI format
|
||||
# where prompt_tokens already includes cached tokens
|
||||
# (unlike Anthropic's native API). Use += to sum all
|
||||
# tool-call rounds since each API call is independent.
|
||||
# NOTE: stream_options={"include_usage": True} is not
|
||||
# universally supported — some providers (Mistral, Llama
|
||||
# via OpenRouter) always return chunk.usage=None. When
|
||||
# that happens, tokens stay 0 and the tiktoken fallback
|
||||
# below activates. Fail-open: one round is estimated.
|
||||
if chunk.usage:
|
||||
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
|
||||
turn_completion_tokens += chunk.usage.completion_tokens or 0
|
||||
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
if not delta:
|
||||
continue
|
||||
|
||||
# Text content
|
||||
if delta.content:
|
||||
if not text_started:
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
text_started = True
|
||||
round_text += delta.content
|
||||
yield StreamTextDelta(id=text_block_id, delta=delta.content)
|
||||
|
||||
# Tool call fragments (streamed incrementally)
|
||||
if delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
idx = tc.index
|
||||
if idx not in tool_calls_by_index:
|
||||
tool_calls_by_index[idx] = {
|
||||
"id": "",
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
}
|
||||
entry = tool_calls_by_index[idx]
|
||||
if tc.id:
|
||||
entry["id"] = tc.id
|
||||
if tc.function and tc.function.name:
|
||||
entry["name"] = tc.function.name
|
||||
if tc.function and tc.function.arguments:
|
||||
entry["arguments"] += tc.function.arguments
|
||||
|
||||
# Close text block if we had one this round
|
||||
if text_started:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
text_started = False
|
||||
text_block_id = str(uuid.uuid4())
|
||||
|
||||
# Accumulate text for session persistence
|
||||
assistant_text += round_text
|
||||
|
||||
# No tool calls -> model is done
|
||||
if not tool_calls_by_index:
|
||||
yield StreamFinishStep()
|
||||
step_open = False
|
||||
break
|
||||
|
||||
# Close step before tool execution
|
||||
yield StreamFinishStep()
|
||||
step_open = False
|
||||
|
||||
# Append the assistant message with tool_calls to context.
|
||||
assistant_msg: dict[str, Any] = {"role": "assistant"}
|
||||
if round_text:
|
||||
assistant_msg["content"] = round_text
|
||||
assistant_msg["tool_calls"] = [
|
||||
{
|
||||
"id": tc["id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc["name"],
|
||||
"arguments": tc["arguments"] or "{}",
|
||||
},
|
||||
}
|
||||
for tc in tool_calls_by_index.values()
|
||||
]
|
||||
openai_messages.append(assistant_msg)
|
||||
|
||||
# Execute each tool call and stream events
|
||||
for tc in tool_calls_by_index.values():
|
||||
tool_call_id = tc["id"]
|
||||
tool_name = tc["name"]
|
||||
raw_args = tc["arguments"] or "{}"
|
||||
try:
|
||||
tool_args = orjson.loads(raw_args)
|
||||
except orjson.JSONDecodeError as parse_err:
|
||||
parse_error = (
|
||||
f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
|
||||
)
|
||||
logger.warning("[Baseline] %s", parse_error)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=parse_error,
|
||||
success=False,
|
||||
)
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": parse_error,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
yield StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
input=tool_args,
|
||||
)
|
||||
|
||||
# Execute via shared tool registry
|
||||
try:
|
||||
result: StreamToolOutputAvailable = await execute_tool(
|
||||
tool_name=tool_name,
|
||||
parameters=tool_args,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
yield result
|
||||
tool_output = (
|
||||
result.output
|
||||
if isinstance(result.output, str)
|
||||
else str(result.output)
|
||||
)
|
||||
except Exception as e:
|
||||
error_output = f"Tool execution error: {e}"
|
||||
logger.error(
|
||||
"[Baseline] Tool %s failed: %s",
|
||||
tool_name,
|
||||
error_output,
|
||||
exc_info=True,
|
||||
)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=error_output,
|
||||
success=False,
|
||||
)
|
||||
tool_output = error_output
|
||||
|
||||
# Append tool result to context for next round
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": tool_output,
|
||||
}
|
||||
)
|
||||
else:
|
||||
# for-loop exhausted without break -> tool-round limit hit
|
||||
if loop_result and not loop_result.finished_naturally:
|
||||
limit_msg = (
|
||||
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
|
||||
"without a final response."
|
||||
@@ -418,11 +513,28 @@ async def stream_chat_completion_baseline(
|
||||
_stream_error = True
|
||||
error_msg = str(e) or type(e).__name__
|
||||
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
|
||||
# Close any open text/step before emitting error
|
||||
if text_started:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
if step_open:
|
||||
yield StreamFinishStep()
|
||||
# Close any open text block. The llm_caller's finally block
|
||||
# already appended StreamFinishStep to pending_events, so we must
|
||||
# insert StreamTextEnd *before* StreamFinishStep to preserve the
|
||||
# protocol ordering:
|
||||
# StreamStartStep -> StreamTextStart -> ...deltas... ->
|
||||
# StreamTextEnd -> StreamFinishStep
|
||||
# Appending (or yielding directly) would place it after
|
||||
# StreamFinishStep, violating the protocol.
|
||||
if state.text_started:
|
||||
# Find the last StreamFinishStep and insert before it.
|
||||
insert_pos = len(state.pending_events)
|
||||
for i in range(len(state.pending_events) - 1, -1, -1):
|
||||
if isinstance(state.pending_events[i], StreamFinishStep):
|
||||
insert_pos = i
|
||||
break
|
||||
state.pending_events.insert(
|
||||
insert_pos, StreamTextEnd(id=state.text_block_id)
|
||||
)
|
||||
# Drain pending events in correct order
|
||||
for evt in state.pending_events:
|
||||
yield evt
|
||||
state.pending_events.clear()
|
||||
yield StreamError(errorText=error_msg, code="baseline_error")
|
||||
# Still persist whatever we got
|
||||
finally:
|
||||
@@ -442,26 +554,21 @@ async def stream_chat_completion_baseline(
|
||||
# Skip fallback when an error occurred and no output was produced —
|
||||
# charging rate-limit tokens for completely failed requests is unfair.
|
||||
if (
|
||||
turn_prompt_tokens == 0
|
||||
and turn_completion_tokens == 0
|
||||
and not (_stream_error and not assistant_text)
|
||||
state.turn_prompt_tokens == 0
|
||||
and state.turn_completion_tokens == 0
|
||||
and not (_stream_error and not state.assistant_text)
|
||||
):
|
||||
from backend.util.prompt import (
|
||||
estimate_token_count,
|
||||
estimate_token_count_str,
|
||||
)
|
||||
|
||||
turn_prompt_tokens = max(
|
||||
state.turn_prompt_tokens = max(
|
||||
estimate_token_count(openai_messages, model=config.model), 1
|
||||
)
|
||||
turn_completion_tokens = estimate_token_count_str(
|
||||
assistant_text, model=config.model
|
||||
state.turn_completion_tokens = estimate_token_count_str(
|
||||
state.assistant_text, model=config.model
|
||||
)
|
||||
logger.info(
|
||||
"[Baseline] No streaming usage reported; estimated tokens: "
|
||||
"prompt=%d, completion=%d",
|
||||
turn_prompt_tokens,
|
||||
turn_completion_tokens,
|
||||
state.turn_prompt_tokens,
|
||||
state.turn_completion_tokens,
|
||||
)
|
||||
|
||||
# Persist token usage to session and record for rate limiting.
|
||||
@@ -471,15 +578,15 @@ async def stream_chat_completion_baseline(
|
||||
await persist_and_record_usage(
|
||||
session=session,
|
||||
user_id=user_id,
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
prompt_tokens=state.turn_prompt_tokens,
|
||||
completion_tokens=state.turn_completion_tokens,
|
||||
log_prefix="[Baseline]",
|
||||
)
|
||||
|
||||
# Persist assistant response
|
||||
if assistant_text:
|
||||
if state.assistant_text:
|
||||
session.messages.append(
|
||||
ChatMessage(role="assistant", content=assistant_text)
|
||||
ChatMessage(role="assistant", content=state.assistant_text)
|
||||
)
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
@@ -491,11 +598,11 @@ async def stream_chat_completion_baseline(
|
||||
# aclose() — doing so raises RuntimeError on client disconnect.
|
||||
# On GeneratorExit the client is already gone, so unreachable yields
|
||||
# are harmless; on normal completion they reach the SSE stream.
|
||||
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
|
||||
if state.turn_prompt_tokens > 0 or state.turn_completion_tokens > 0:
|
||||
yield StreamUsage(
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
total_tokens=turn_prompt_tokens + turn_completion_tokens,
|
||||
prompt_tokens=state.turn_prompt_tokens,
|
||||
completion_tokens=state.turn_completion_tokens,
|
||||
total_tokens=state.turn_prompt_tokens + state.turn_completion_tokens,
|
||||
)
|
||||
|
||||
yield StreamFinish()
|
||||
|
||||
@@ -178,7 +178,7 @@ class ChatConfig(BaseSettings):
|
||||
|
||||
Single source of truth for "will the SDK route through OpenRouter?".
|
||||
Checks the flag *and* that ``api_key`` + a valid ``base_url`` are
|
||||
present — mirrors the fallback logic in ``_build_sdk_env``.
|
||||
present — mirrors the fallback logic in ``build_sdk_env``.
|
||||
"""
|
||||
if not self.use_openrouter:
|
||||
return False
|
||||
|
||||
68
autogpt_platform/backend/backend/copilot/sdk/env.py
Normal file
68
autogpt_platform/backend/backend/copilot/sdk/env.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""SDK environment variable builder — importable without circular deps.
|
||||
|
||||
Extracted from ``service.py`` so that ``backend.blocks.orchestrator``
|
||||
can reuse the same subscription / OpenRouter / direct-Anthropic logic
|
||||
without pulling in the full copilot service module (which would create a
|
||||
circular import through ``executor`` → ``credit`` → ``block_cost_config``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.sdk.subscription import validate_subscription
|
||||
|
||||
# ChatConfig is stateless (reads env vars) — a separate instance is fine.
|
||||
# A singleton would require importing service.py which causes the circular dep
|
||||
# this module was created to avoid.
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
def build_sdk_env(
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""Build env vars for the SDK CLI subprocess.
|
||||
|
||||
Three modes (checked in order):
|
||||
1. **Subscription** — clears all keys; CLI uses ``claude login`` auth.
|
||||
2. **Direct Anthropic** — returns ``{}``; subprocess inherits
|
||||
``ANTHROPIC_API_KEY`` from the parent environment.
|
||||
3. **OpenRouter** (default) — overrides base URL and auth token to
|
||||
route through the proxy, with Langfuse trace headers.
|
||||
"""
|
||||
# --- Mode 1: Claude Code subscription auth ---
|
||||
if config.use_claude_code_subscription:
|
||||
validate_subscription()
|
||||
return {
|
||||
"ANTHROPIC_API_KEY": "",
|
||||
"ANTHROPIC_AUTH_TOKEN": "",
|
||||
"ANTHROPIC_BASE_URL": "",
|
||||
}
|
||||
|
||||
# --- Mode 2: Direct Anthropic (no proxy hop) ---
|
||||
if not config.openrouter_active:
|
||||
return {}
|
||||
|
||||
# --- Mode 3: OpenRouter proxy ---
|
||||
base = (config.base_url or "").rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
env: dict[str, str] = {
|
||||
"ANTHROPIC_BASE_URL": base,
|
||||
"ANTHROPIC_AUTH_TOKEN": config.api_key or "",
|
||||
"ANTHROPIC_API_KEY": "", # force CLI to use AUTH_TOKEN
|
||||
}
|
||||
|
||||
# Inject broadcast headers so OpenRouter forwards traces to Langfuse.
|
||||
def _safe(v: str) -> str:
|
||||
return v.replace("\r", "").replace("\n", "").strip()[:128]
|
||||
|
||||
parts = []
|
||||
if session_id:
|
||||
parts.append(f"x-session-id: {_safe(session_id)}")
|
||||
if user_id:
|
||||
parts.append(f"x-user-id: {_safe(user_id)}")
|
||||
if parts:
|
||||
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
|
||||
|
||||
return env
|
||||
242
autogpt_platform/backend/backend/copilot/sdk/env_test.py
Normal file
242
autogpt_platform/backend/backend/copilot/sdk/env_test.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""Tests for build_sdk_env() — the SDK subprocess environment builder."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers — build a ChatConfig with explicit field values so tests don't
|
||||
# depend on real environment variables.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_config(**overrides) -> ChatConfig:
|
||||
"""Create a ChatConfig with safe defaults, applying *overrides*."""
|
||||
defaults = {
|
||||
"use_claude_code_subscription": False,
|
||||
"use_openrouter": False,
|
||||
"api_key": None,
|
||||
"base_url": None,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return ChatConfig(**defaults)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mode 1 — Subscription auth
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildSdkEnvSubscription:
    """When ``use_claude_code_subscription`` is True, keys are blanked."""

    @patch("backend.copilot.sdk.env.validate_subscription")
    def test_returns_blanked_keys(self, mock_validate):
        """Subscription mode clears API_KEY, AUTH_TOKEN, and BASE_URL."""
        subscription_cfg = _make_config(use_claude_code_subscription=True)
        with patch("backend.copilot.sdk.env.config", subscription_cfg):
            from backend.copilot.sdk.env import build_sdk_env

            env = build_sdk_env()

        blanked = {
            "ANTHROPIC_API_KEY": "",
            "ANTHROPIC_AUTH_TOKEN": "",
            "ANTHROPIC_BASE_URL": "",
        }
        assert env == blanked
        mock_validate.assert_called_once()

    @patch(
        "backend.copilot.sdk.env.validate_subscription",
        side_effect=RuntimeError("CLI not found"),
    )
    def test_propagates_validation_error(self, mock_validate):
        """If validate_subscription fails, the error bubbles up."""
        subscription_cfg = _make_config(use_claude_code_subscription=True)
        with patch("backend.copilot.sdk.env.config", subscription_cfg):
            from backend.copilot.sdk.env import build_sdk_env

            with pytest.raises(RuntimeError, match="CLI not found"):
                build_sdk_env()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mode 2 — Direct Anthropic (no OpenRouter)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildSdkEnvDirectAnthropic:
    """When OpenRouter is inactive, return empty dict (inherit parent env)."""

    def test_returns_empty_dict_when_openrouter_inactive(self):
        inactive_cfg = _make_config(use_openrouter=False)
        with patch("backend.copilot.sdk.env.config", inactive_cfg):
            from backend.copilot.sdk.env import build_sdk_env

            env = build_sdk_env()

        assert env == {}

    def test_returns_empty_dict_when_openrouter_flag_true_but_no_key(self):
        """OpenRouter flag is True but no api_key => openrouter_active is False."""
        flagged_cfg = _make_config(
            use_openrouter=True, base_url="https://openrouter.ai/api/v1"
        )
        # Force api_key to None after construction (field_validator may pick up env vars)
        object.__setattr__(flagged_cfg, "api_key", None)
        assert not flagged_cfg.openrouter_active
        with patch("backend.copilot.sdk.env.config", flagged_cfg):
            from backend.copilot.sdk.env import build_sdk_env

            env = build_sdk_env()

        assert env == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mode 3 — OpenRouter proxy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildSdkEnvOpenRouter:
    """When OpenRouter is active, return proxy env vars."""

    def _openrouter_config(self, **overrides):
        # Start from a fully-active OpenRouter configuration.
        base_fields = {
            "use_openrouter": True,
            "api_key": "sk-or-test-key",
            "base_url": "https://openrouter.ai/api/v1",
        }
        base_fields.update(overrides)
        return _make_config(**base_fields)

    def test_basic_openrouter_env(self):
        active_cfg = self._openrouter_config()
        with patch("backend.copilot.sdk.env.config", active_cfg):
            from backend.copilot.sdk.env import build_sdk_env

            env = build_sdk_env()

        assert env["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
        assert env["ANTHROPIC_AUTH_TOKEN"] == "sk-or-test-key"
        assert env["ANTHROPIC_API_KEY"] == ""
        assert "ANTHROPIC_CUSTOM_HEADERS" not in env

    def test_strips_trailing_v1(self):
        """The /v1 suffix is stripped from the base URL."""
        active_cfg = self._openrouter_config(base_url="https://openrouter.ai/api/v1")
        with patch("backend.copilot.sdk.env.config", active_cfg):
            from backend.copilot.sdk.env import build_sdk_env

            env = build_sdk_env()

        assert env["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"

    def test_strips_trailing_v1_and_slash(self):
        """Trailing slash before /v1 strip is handled."""
        active_cfg = self._openrouter_config(base_url="https://openrouter.ai/api/v1/")
        with patch("backend.copilot.sdk.env.config", active_cfg):
            from backend.copilot.sdk.env import build_sdk_env

            env = build_sdk_env()

        # rstrip("/") first, then remove /v1
        assert env["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"

    def test_no_v1_suffix_left_alone(self):
        """A base URL without /v1 is used as-is."""
        active_cfg = self._openrouter_config(base_url="https://custom-proxy.example.com")
        with patch("backend.copilot.sdk.env.config", active_cfg):
            from backend.copilot.sdk.env import build_sdk_env

            env = build_sdk_env()

        assert env["ANTHROPIC_BASE_URL"] == "https://custom-proxy.example.com"

    def test_session_id_header(self):
        active_cfg = self._openrouter_config()
        with patch("backend.copilot.sdk.env.config", active_cfg):
            from backend.copilot.sdk.env import build_sdk_env

            env = build_sdk_env(session_id="sess-123")

        assert "ANTHROPIC_CUSTOM_HEADERS" in env
        assert "x-session-id: sess-123" in env["ANTHROPIC_CUSTOM_HEADERS"]

    def test_user_id_header(self):
        active_cfg = self._openrouter_config()
        with patch("backend.copilot.sdk.env.config", active_cfg):
            from backend.copilot.sdk.env import build_sdk_env

            env = build_sdk_env(user_id="user-456")

        assert "x-user-id: user-456" in env["ANTHROPIC_CUSTOM_HEADERS"]

    def test_both_headers(self):
        active_cfg = self._openrouter_config()
        with patch("backend.copilot.sdk.env.config", active_cfg):
            from backend.copilot.sdk.env import build_sdk_env

            env = build_sdk_env(session_id="s1", user_id="u2")

        header_blob = env["ANTHROPIC_CUSTOM_HEADERS"]
        assert "x-session-id: s1" in header_blob
        assert "x-user-id: u2" in header_blob
        # They should be newline-separated
        assert "\n" in header_blob

    def test_header_sanitisation_strips_newlines(self):
        """Newlines/carriage-returns in header values are stripped."""
        active_cfg = self._openrouter_config()
        with patch("backend.copilot.sdk.env.config", active_cfg):
            from backend.copilot.sdk.env import build_sdk_env

            env = build_sdk_env(session_id="bad\r\nvalue")

        header_line = env["ANTHROPIC_CUSTOM_HEADERS"]
        # The _safe helper removes \r and \n
        assert "\r" not in header_line.split(": ", 1)[1]
        assert "badvalue" in header_line

    def test_header_value_truncated_to_128_chars(self):
        """Header values are truncated to 128 characters."""
        active_cfg = self._openrouter_config()
        with patch("backend.copilot.sdk.env.config", active_cfg):
            from backend.copilot.sdk.env import build_sdk_env

            oversized = "x" * 200
            env = build_sdk_env(session_id=oversized)

        # The value after "x-session-id: " should be at most 128 chars
        header_line = env["ANTHROPIC_CUSTOM_HEADERS"]
        truncated = header_line.split(": ", 1)[1]
        assert len(truncated) == 128
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mode priority
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildSdkEnvModePriority:
    """Subscription mode takes precedence over OpenRouter."""

    @patch("backend.copilot.sdk.env.validate_subscription")
    def test_subscription_overrides_openrouter(self, mock_validate):
        both_modes_cfg = _make_config(
            use_claude_code_subscription=True,
            use_openrouter=True,
            api_key="sk-or-key",
            base_url="https://openrouter.ai/api/v1",
        )
        with patch("backend.copilot.sdk.env.config", both_modes_cfg):
            from backend.copilot.sdk.env import build_sdk_env

            env = build_sdk_env()

        # Should get subscription result, not OpenRouter
        blanked = {
            "ANTHROPIC_API_KEY": "",
            "ANTHROPIC_AUTH_TOKEN": "",
            "ANTHROPIC_BASE_URL": "",
        }
        assert env == blanked
|
||||
@@ -1010,7 +1010,7 @@ def _make_sdk_patches(
|
||||
(f"{_SVC}.create_security_hooks", dict(return_value=MagicMock())),
|
||||
(f"{_SVC}.get_copilot_tool_names", dict(return_value=[])),
|
||||
(f"{_SVC}.get_sdk_disallowed_tools", dict(return_value=[])),
|
||||
(f"{_SVC}._build_sdk_env", dict(return_value=None)),
|
||||
(f"{_SVC}.build_sdk_env", dict(return_value=None)),
|
||||
(f"{_SVC}._resolve_sdk_model", dict(return_value=None)),
|
||||
(f"{_SVC}.set_execution_context", {}),
|
||||
(
|
||||
|
||||
@@ -77,9 +77,9 @@ from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
|
||||
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
||||
from ..tracking import track_user_message
|
||||
from .compaction import CompactionTracker, filter_compaction_messages
|
||||
from .env import build_sdk_env # noqa: F401 — re-export for backward compat
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .security_hooks import create_security_hooks
|
||||
from .subscription import validate_subscription as _validate_claude_code_subscription
|
||||
from .tool_adapter import (
|
||||
cancel_pending_tool_tasks,
|
||||
create_copilot_mcp_server,
|
||||
@@ -567,60 +567,6 @@ def _resolve_sdk_model() -> str | None:
|
||||
return model
|
||||
|
||||
|
||||
def _build_sdk_env(
    session_id: str | None = None,
    user_id: str | None = None,
) -> dict[str, str]:
    """Build env vars for the SDK CLI subprocess.

    Three modes (checked in order):
    1. **Subscription** — clears all keys; CLI uses `claude login` auth.
    2. **Direct Anthropic** — returns `{}`; subprocess inherits
       `ANTHROPIC_API_KEY` from the parent environment.
    3. **OpenRouter** (default) — overrides base URL and auth token to
       route through the proxy, with Langfuse trace headers.
    """
    # --- Mode 1: Claude Code subscription auth ---
    # Blank every credential var so the CLI falls back to its own login.
    if config.use_claude_code_subscription:
        _validate_claude_code_subscription()
        return {
            key: ""
            for key in (
                "ANTHROPIC_API_KEY",
                "ANTHROPIC_AUTH_TOKEN",
                "ANTHROPIC_BASE_URL",
            )
        }

    # --- Mode 2: Direct Anthropic (no proxy hop) ---
    # `openrouter_active` checks the flag *and* credential presence.
    if not config.openrouter_active:
        return {}

    # --- Mode 3: OpenRouter proxy ---
    # Strip /v1 suffix — SDK expects the base URL without a version path.
    proxy_base = (config.base_url or "").rstrip("/").removesuffix("/v1")

    env_vars: dict[str, str] = {
        "ANTHROPIC_BASE_URL": proxy_base,
        "ANTHROPIC_AUTH_TOKEN": config.api_key or "",
        "ANTHROPIC_API_KEY": "",  # force CLI to use AUTH_TOKEN
    }

    def _sanitize(raw: str) -> str:
        """Sanitise a header value: strip newlines/whitespace and cap length."""
        return raw.replace("\r", "").replace("\n", "").strip()[:128]

    # Inject broadcast headers so OpenRouter forwards traces to Langfuse.
    header_lines: list[str] = []
    if session_id:
        header_lines.append(f"x-session-id: {_sanitize(session_id)}")
    if user_id:
        header_lines.append(f"x-user-id: {_sanitize(user_id)}")
    if header_lines:
        env_vars["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(header_lines)

    return env_vars
|
||||
|
||||
|
||||
def _make_sdk_cwd(session_id: str) -> str:
|
||||
"""Create a safe, session-specific working directory path.
|
||||
|
||||
@@ -1867,7 +1813,7 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
|
||||
# Fail fast when no API credentials are available at all.
|
||||
sdk_env = _build_sdk_env(session_id=session_id, user_id=user_id)
|
||||
sdk_env = build_sdk_env(session_id=session_id, user_id=user_id)
|
||||
if not config.api_key and not config.use_claude_code_subscription:
|
||||
raise RuntimeError(
|
||||
"No API key configured. Set OPEN_ROUTER_API_KEY, "
|
||||
|
||||
@@ -325,6 +325,8 @@ class _BaseCredentials(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
provider: str
|
||||
title: Optional[str] = None
|
||||
is_managed: bool = False
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_serializer("*")
|
||||
def dump_secret_strings(value: Any, _info):
|
||||
@@ -344,7 +346,6 @@ class OAuth2Credentials(_BaseCredentials):
|
||||
refresh_token_expires_at: Optional[int] = None
|
||||
"""Unix timestamp (seconds) indicating when the refresh token expires (if at all)"""
|
||||
scopes: list[str]
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
def auth_header(self) -> str:
|
||||
return f"Bearer {self.access_token.get_secret_value()}"
|
||||
|
||||
@@ -3,7 +3,7 @@ import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, cast
|
||||
from typing import TYPE_CHECKING, Optional, cast
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from autogpt_libs.auth.models import DEFAULT_USER_ID
|
||||
@@ -21,6 +21,9 @@ from backend.util.exceptions import DatabaseError
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
@@ -453,6 +456,27 @@ async def unsubscribe_user_by_token(token: str) -> None:
|
||||
raise DatabaseError(f"Failed to unsubscribe user by token {token}: {e}") from e
|
||||
|
||||
|
||||
async def cleanup_user_managed_credentials(
    user_id: str,
    store: Optional["IntegrationCredentialsStore"] = None,
) -> None:
    """Revoke all externally-provisioned managed credentials for *user_id*.

    Call this before deleting a user account so that external resources
    (e.g. AgentMail pods, pod-scoped API keys) are properly cleaned up.
    The credential rows themselves are cascade-deleted with the User row.

    Pass an existing *store* for testability; when omitted a fresh instance
    is created.
    """
    # Imported lazily to avoid a circular import at module load time.
    from backend.integrations.credentials_store import IntegrationCredentialsStore
    from backend.integrations.managed_credentials import cleanup_managed_credentials

    target_store = store if store is not None else IntegrationCredentialsStore()
    await cleanup_managed_credentials(user_id, target_store)
|
||||
|
||||
|
||||
async def update_user_timezone(user_id: str, timezone: str) -> User:
|
||||
"""Update a user's timezone setting."""
|
||||
try:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timedelta, timezone
|
||||
@@ -21,6 +22,7 @@ from backend.data.redis_client import get_redis_async
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def provider_matches(stored: str, expected: str) -> bool:
|
||||
@@ -284,6 +286,7 @@ DEFAULT_CREDENTIALS = [
|
||||
elevenlabs_credentials,
|
||||
]
|
||||
|
||||
|
||||
SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS}
|
||||
|
||||
# Set of providers that have system credentials available
|
||||
@@ -323,20 +326,45 @@ class IntegrationCredentialsStore:
|
||||
return get_database_manager_async_client()
|
||||
|
||||
# =============== USER-MANAGED CREDENTIALS =============== #
|
||||
|
||||
async def _get_persisted_user_creds_unlocked(
    self, user_id: str
) -> list[Credentials]:
    """Return only the persisted (user-stored) credentials — no side effects.

    **Caller must already hold ``locked_user_integrations(user_id)``.**
    """
    integrations = await self._get_user_integrations(user_id)
    return [*integrations.credentials]
|
||||
|
||||
async def add_creds(self, user_id: str, credentials: Credentials) -> None:
|
||||
async with await self.locked_user_integrations(user_id):
|
||||
if await self.get_creds_by_id(user_id, credentials.id):
|
||||
# Check system/managed IDs without triggering provisioning
|
||||
if credentials.id in SYSTEM_CREDENTIAL_IDS:
|
||||
raise ValueError(
|
||||
f"Can not re-create existing credentials #{credentials.id} "
|
||||
f"for user #{user_id}"
|
||||
)
|
||||
await self._set_user_integration_creds(
|
||||
user_id, [*(await self.get_all_creds(user_id)), credentials]
|
||||
)
|
||||
persisted = await self._get_persisted_user_creds_unlocked(user_id)
|
||||
if any(c.id == credentials.id for c in persisted):
|
||||
raise ValueError(
|
||||
f"Can not re-create existing credentials #{credentials.id} "
|
||||
f"for user #{user_id}"
|
||||
)
|
||||
await self._set_user_integration_creds(user_id, [*persisted, credentials])
|
||||
|
||||
async def get_all_creds(self, user_id: str) -> list[Credentials]:
|
||||
users_credentials = (await self._get_user_integrations(user_id)).credentials
|
||||
all_credentials = users_credentials
|
||||
"""Public entry point — acquires lock, then delegates."""
|
||||
async with await self.locked_user_integrations(user_id):
|
||||
return await self._get_all_creds_unlocked(user_id)
|
||||
|
||||
async def _get_all_creds_unlocked(self, user_id: str) -> list[Credentials]:
|
||||
"""Return all credentials for *user_id*.
|
||||
|
||||
**Caller must already hold ``locked_user_integrations(user_id)``.**
|
||||
"""
|
||||
user_integrations = await self._get_user_integrations(user_id)
|
||||
all_credentials = list(user_integrations.credentials)
|
||||
|
||||
# These will always be added
|
||||
all_credentials.append(ollama_credentials)
|
||||
|
||||
@@ -417,13 +445,22 @@ class IntegrationCredentialsStore:
|
||||
return list(set(c.provider for c in credentials))
|
||||
|
||||
async def update_creds(self, user_id: str, updated: Credentials) -> None:
|
||||
if updated.id in SYSTEM_CREDENTIAL_IDS:
|
||||
raise ValueError(
|
||||
f"System credential #{updated.id} cannot be updated directly"
|
||||
)
|
||||
async with await self.locked_user_integrations(user_id):
|
||||
current = await self.get_creds_by_id(user_id, updated.id)
|
||||
persisted = await self._get_persisted_user_creds_unlocked(user_id)
|
||||
current = next((c for c in persisted if c.id == updated.id), None)
|
||||
if not current:
|
||||
raise ValueError(
|
||||
f"Credentials with ID {updated.id} "
|
||||
f"for user with ID {user_id} not found"
|
||||
)
|
||||
if current.is_managed:
|
||||
raise ValueError(
|
||||
f"AutoGPT-managed credential #{updated.id} cannot be updated"
|
||||
)
|
||||
if type(current) is not type(updated):
|
||||
raise TypeError(
|
||||
f"Can not update credentials with ID {updated.id} "
|
||||
@@ -443,22 +480,53 @@ class IntegrationCredentialsStore:
|
||||
f"to more restrictive set of scopes {updated.scopes}"
|
||||
)
|
||||
|
||||
# Update the credentials
|
||||
# Update only persisted credentials — no side-effectful provisioning
|
||||
updated_credentials_list = [
|
||||
updated if c.id == updated.id else c
|
||||
for c in await self.get_all_creds(user_id)
|
||||
updated if c.id == updated.id else c for c in persisted
|
||||
]
|
||||
await self._set_user_integration_creds(user_id, updated_credentials_list)
|
||||
|
||||
async def delete_creds_by_id(self, user_id: str, credentials_id: str) -> None:
|
||||
if credentials_id in SYSTEM_CREDENTIAL_IDS:
|
||||
raise ValueError(f"System credential #{credentials_id} cannot be deleted")
|
||||
async with await self.locked_user_integrations(user_id):
|
||||
filtered_credentials = [
|
||||
c for c in await self.get_all_creds(user_id) if c.id != credentials_id
|
||||
]
|
||||
persisted = await self._get_persisted_user_creds_unlocked(user_id)
|
||||
target = next((c for c in persisted if c.id == credentials_id), None)
|
||||
if target and target.is_managed:
|
||||
raise ValueError(
|
||||
f"AutoGPT-managed credential #{credentials_id} cannot be deleted"
|
||||
)
|
||||
filtered_credentials = [c for c in persisted if c.id != credentials_id]
|
||||
await self._set_user_integration_creds(user_id, filtered_credentials)
|
||||
|
||||
# ============== SYSTEM-MANAGED CREDENTIALS ============== #
|
||||
|
||||
async def has_managed_credential(self, user_id: str, provider: str) -> bool:
    """Check if a managed credential exists for *provider*."""
    stored = (await self._get_user_integrations(user_id)).credentials
    for cred in stored:
        # Only platform-provisioned (is_managed) entries count.
        if cred.is_managed and cred.provider == provider:
            return True
    return False
|
||||
|
||||
async def add_managed_credential(
    self, user_id: str, credential: Credentials
) -> None:
    """Upsert a managed credential.

    Removes any existing managed credential for the same provider,
    then appends the new one. The credential MUST have is_managed=True.
    """
    if not credential.is_managed:
        raise ValueError("credential.is_managed must be True")
    async with self.edit_user_integrations(user_id) as user_integrations:
        # Drop any prior managed entry for this provider, keep the rest.
        retained = [
            existing
            for existing in user_integrations.credentials
            if not (existing.provider == credential.provider and existing.is_managed)
        ]
        user_integrations.credentials = retained + [credential]
|
||||
|
||||
async def set_ayrshare_profile_key(self, user_id: str, profile_key: str) -> None:
|
||||
"""Set the Ayrshare profile key for a user.
|
||||
|
||||
|
||||
@@ -0,0 +1,188 @@
|
||||
"""Generic infrastructure for system-provided, per-user managed credentials.
|
||||
|
||||
Managed credentials are provisioned automatically by the platform (e.g. an
|
||||
AgentMail pod-scoped API key) and stored alongside regular user credentials
|
||||
with ``is_managed=True``. Users cannot update or delete them.
|
||||
|
||||
New integrations register a :class:`ManagedCredentialProvider` at import time;
|
||||
the two entry-points consumed by the rest of the application are:
|
||||
|
||||
* :func:`ensure_managed_credentials` – fired as a background task from the
|
||||
credential-listing endpoints (non-blocking).
|
||||
* :func:`cleanup_managed_credentials` – called during account deletion to
|
||||
revoke external resources (API keys, pods, etc.).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Abstract provider
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ManagedCredentialProvider(ABC):
    """Contract for integrations that auto-provision per-user credentials.

    Concrete subclasses set :attr:`provider_name` and implement the three
    lifecycle hooks below; they are registered via the module registry.
    """

    # Must match the ``provider`` field on the resulting credential.
    provider_name: str

    @abstractmethod
    async def is_available(self) -> bool:
        """Return ``True`` when the org-level configuration is present."""

    @abstractmethod
    async def provision(self, user_id: str) -> Credentials:
        """Create external resources and return a credential.

        The returned credential **must** have ``is_managed=True``.
        """

    @abstractmethod
    async def deprovision(self, user_id: str, credential: Credentials) -> None:
        """Revoke external resources during account deletion."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PROVIDERS: dict[str, ManagedCredentialProvider] = {}
|
||||
|
||||
# Users whose managed credentials have already been verified recently.
|
||||
# Avoids redundant DB checks on every GET /credentials call.
|
||||
# maxsize caps memory; TTL re-checks periodically (e.g. when new providers
|
||||
# are added). ~100K entries ≈ 4-8 MB.
|
||||
_provisioned_users: TTLCache[str, bool] = TTLCache(maxsize=100_000, ttl=3600)
|
||||
|
||||
|
||||
def register_managed_provider(provider: ManagedCredentialProvider) -> None:
    """Add (or replace) *provider* in the registry, keyed by its name."""
    registry_key = provider.provider_name
    _PROVIDERS[registry_key] = provider
|
||||
|
||||
|
||||
def get_managed_provider(name: str) -> ManagedCredentialProvider | None:
    """Look up a registered provider by name; ``None`` when absent."""
    try:
        return _PROVIDERS[name]
    except KeyError:
        return None
|
||||
|
||||
|
||||
def get_managed_providers() -> dict[str, ManagedCredentialProvider]:
    """Return a shallow copy of the registry (callers may mutate freely)."""
    return {**_PROVIDERS}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _ensure_one(
    user_id: str,
    store: IntegrationCredentialsStore,
    name: str,
    provider: ManagedCredentialProvider,
) -> bool:
    """Provision a single managed credential under a distributed Redis lock.

    Returns ``True`` if the credential already exists or was successfully
    provisioned, ``False`` on transient failure so the caller knows not to
    cache the user as fully provisioned.
    """
    try:
        if not await provider.is_available():
            # Org-level config absent — nothing to do, not a failure.
            return True
        # Use a distributed Redis lock so the check-then-provision operation
        # is atomic across all workers, preventing duplicate external
        # resource provisioning (e.g. AgentMail API keys).
        locks = await store.locks()
        lock_key = (f"user:{user_id}", f"managed-provision:{name}")
        async with locks.locked(lock_key):
            # Re-check under lock to avoid duplicate provisioning.
            if await store.has_managed_credential(user_id, name):
                return True
            new_credential = await provider.provision(user_id)
            await store.add_managed_credential(user_id, new_credential)
            logger.info(
                "Provisioned managed credential for provider=%s user=%s",
                name,
                user_id,
            )
            return True
    except Exception:
        logger.warning(
            "Failed to provision managed credential for provider=%s user=%s",
            name,
            user_id,
            exc_info=True,
        )
        return False
|
||||
|
||||
|
||||
async def ensure_managed_credentials(
    user_id: str,
    store: IntegrationCredentialsStore,
) -> None:
    """Provision missing managed credentials for *user_id*.

    Fired as a non-blocking background task from the credential-listing
    endpoints. Failures are logged but never propagated — the user simply
    will not see the managed credential until the next page load.

    Skips entirely if this user has been verified recently:
    ``_provisioned_users`` is an in-memory TTL cache, so each process
    re-checks a given user at most once per TTL window (1 hour) and the
    cache is lost on restart. Purely a performance optimisation, not a
    correctness guarantee.

    Providers are checked concurrently via ``asyncio.gather``.
    """
    if user_id in _provisioned_users:
        return

    results = await asyncio.gather(
        *(_ensure_one(user_id, store, n, p) for n, p in _PROVIDERS.items())
    )

    # Only cache the user as provisioned when every provider succeeded or
    # was already present. A transient failure (network timeout, Redis
    # blip) returns False, so the next page load will retry.
    if all(results):
        _provisioned_users[user_id] = True
|
||||
|
||||
|
||||
async def cleanup_managed_credentials(
    user_id: str,
    store: IntegrationCredentialsStore,
) -> None:
    """Revoke all external managed resources for a user being deleted."""
    every_credential = await store.get_all_creds(user_id)
    for cred in every_credential:
        if not cred.is_managed:
            continue
        handler = _PROVIDERS.get(cred.provider)
        if handler is None:
            logger.warning(
                "No managed provider registered for %s — skipping cleanup",
                cred.provider,
            )
            continue
        try:
            await handler.deprovision(user_id, cred)
            logger.info(
                "Deprovisioned managed credential for provider=%s user=%s",
                cred.provider,
                user_id,
            )
        except Exception:
            # Best-effort: one failed provider must not block account deletion.
            logger.error(
                "Failed to deprovision %s for user %s",
                cred.provider,
                user_id,
                exc_info=True,
            )
|
||||
@@ -0,0 +1,17 @@
|
||||
"""Managed credential providers.
|
||||
|
||||
Call :func:`register_all` at application startup (e.g. in ``rest_api.py``)
|
||||
to populate the provider registry before any requests are processed.
|
||||
"""
|
||||
|
||||
from backend.integrations.managed_credentials import (
|
||||
get_managed_provider,
|
||||
register_managed_provider,
|
||||
)
|
||||
from backend.integrations.managed_providers.agentmail import AgentMailManagedProvider
|
||||
|
||||
|
||||
def register_all() -> None:
    """Register every built-in managed credential provider (idempotent)."""
    agentmail_name = AgentMailManagedProvider.provider_name
    already_registered = get_managed_provider(agentmail_name) is not None
    if not already_registered:
        register_managed_provider(AgentMailManagedProvider())
|
||||
@@ -0,0 +1,90 @@
|
||||
"""AgentMail managed credential provider.
|
||||
|
||||
Uses the org-level AgentMail API key to create a per-user pod and a
|
||||
pod-scoped API key. The pod key is stored as an ``is_managed``
|
||||
credential so it appears automatically in block credential dropdowns.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials, Credentials
|
||||
from backend.integrations.managed_credentials import ManagedCredentialProvider
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
|
||||
class AgentMailManagedProvider(ManagedCredentialProvider):
    """Auto-provisions a per-user AgentMail pod plus a pod-scoped API key.

    The pod key is stored as an ``is_managed`` credential so it appears
    automatically in block credential dropdowns.
    """

    provider_name = "agent_mail"

    async def is_available(self) -> bool:
        # Only usable when the org-level AgentMail key is configured.
        return bool(settings.secrets.agentmail_api_key)

    async def provision(self, user_id: str) -> Credentials:
        from agentmail import AsyncAgentMail

        org_client = AsyncAgentMail(api_key=settings.secrets.agentmail_api_key)

        # client_id makes pod creation idempotent — if a pod already exists
        # for this user_id the SDK returns the existing pod.
        pod = await org_client.pods.create(client_id=user_id, name=f"{user_id}-pod")

        # NOTE: api_keys.create() is NOT idempotent. If the caller retries
        # after a partial failure (pod created, key created, but store write
        # failed), a second key will be created and the first becomes orphaned
        # on AgentMail's side. The double-check pattern in _ensure_one
        # (has_managed_credential under lock) prevents this in normal flow;
        # only a crash between key creation and store write can cause it.
        pod_key = await org_client.pods.api_keys.create(
            pod_id=pod.pod_id, name=f"{user_id}-agpt-managed"
        )

        return APIKeyCredentials(
            provider=self.provider_name,
            title="AgentMail (managed by AutoGPT)",
            api_key=SecretStr(pod_key.api_key),
            expires_at=None,
            is_managed=True,
            metadata={"pod_id": pod.pod_id},
        )

    async def deprovision(self, user_id: str, credential: Credentials) -> None:
        from agentmail import AsyncAgentMail

        pod_id = credential.metadata.get("pod_id")
        if not pod_id:
            logger.warning(
                "Managed credential for user %s has no pod_id in metadata — "
                "skipping AgentMail cleanup",
                user_id,
            )
            return

        org_client = AsyncAgentMail(api_key=settings.secrets.agentmail_api_key)
        try:
            # Verify the pod actually belongs to this user before deleting,
            # as a safety measure against cross-user deletion via the
            # org-level API key.
            pod = await org_client.pods.get(pod_id=pod_id)
            owner = getattr(pod, "client_id", None)
            if owner and owner != user_id:
                logger.error(
                    "Pod %s client_id=%s does not match user %s — "
                    "refusing to delete",
                    pod_id,
                    pod.client_id,
                    user_id,
                )
                return
            await org_client.pods.delete(pod_id=pod_id)
        except Exception:
            # Cleanup is best-effort; log and continue with account deletion.
            logger.warning(
                "Failed to delete AgentMail pod %s for user %s",
                pod_id,
                user_id,
                exc_info=True,
            )
|
||||
20
autogpt_platform/backend/backend/util/security.py
Normal file
20
autogpt_platform/backend/backend/util/security.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Shared security constants for field-level filtering.
|
||||
|
||||
Other modules (e.g. orchestrator, future blocks) import from here so the
|
||||
sensitive-field list stays in one place.
|
||||
"""
|
||||
|
||||
# Field names to exclude from hardcoded-defaults descriptions (case-insensitive).
# Kept alphabetically sorted for easy scanning and diffing.
SENSITIVE_FIELD_NAMES: frozenset[str] = frozenset(
    (
        "access_token",
        "api_key",
        "auth",
        "authorization",
        "credentials",
        "password",
        "refresh_token",
        "secret",
        "token",
    )
)
|
||||
@@ -708,6 +708,8 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||
description="The LaunchDarkly SDK key for feature flag management",
|
||||
)
|
||||
|
||||
agentmail_api_key: str = Field(default="", description="AgentMail API Key")
|
||||
|
||||
ayrshare_api_key: str = Field(default="", description="Ayrshare API Key")
|
||||
ayrshare_jwt_key: str = Field(default="", description="Ayrshare private Key")
|
||||
|
||||
|
||||
281
autogpt_platform/backend/backend/util/tool_call_loop.py
Normal file
281
autogpt_platform/backend/backend/util/tool_call_loop.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""Shared tool-calling conversation loop.
|
||||
|
||||
Provides a generic, provider-agnostic conversation loop that both
|
||||
the OrchestratorBlock and copilot baseline can use. The loop:
|
||||
|
||||
1. Calls the LLM with tool definitions
|
||||
2. Extracts tool calls from the response
|
||||
3. Executes tools via a caller-supplied callback
|
||||
4. Appends results to the conversation
|
||||
5. Repeats until no more tool calls or max iterations reached
|
||||
|
||||
Callers provide callbacks for LLM calling, tool execution, and
|
||||
conversation updating.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol, TypedDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Typed dict definitions for tool definitions and conversation messages.
|
||||
# These document the expected shapes and allow callers to pass TypedDict
|
||||
# subclasses (e.g. ``ChatCompletionToolParam``) without ``type: ignore``.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class FunctionParameters(TypedDict, total=False):
    """JSON Schema object describing a tool function's parameters."""

    # Standard JSON Schema keywords; all keys optional (total=False).
    type: str
    properties: dict[str, Any]
    required: list[str]
    additionalProperties: bool
|
||||
|
||||
|
||||
class FunctionDefinition(TypedDict, total=False):
    """Function definition within a tool definition."""

    name: str  # function name the model refers to in tool calls
    description: str  # human-readable description of what the tool does
    parameters: FunctionParameters  # JSON Schema for the call arguments
|
||||
|
||||
|
||||
class ToolDefinition(TypedDict):
    """OpenAI-compatible tool definition (function-calling format).

    Compatible with ``openai.types.chat.ChatCompletionToolParam`` and the
    dict-based tool definitions built by ``OrchestratorBlock``.
    """

    # Both keys are required (TypedDict is total by default).
    type: str  # discriminator; "function" in the function-calling format
    function: FunctionDefinition
|
||||
|
||||
|
||||
class ConversationMessage(TypedDict, total=False):
    """A single message in the conversation (OpenAI chat format).

    Primarily for documentation; at runtime plain dicts are used because
    messages from different providers carry varying keys.
    """

    role: str  # e.g. "system", "user", "assistant", "tool"
    content: str | list[Any] | None
    tool_calls: list[dict[str, Any]]  # assistant messages that invoke tools
    tool_call_id: str  # on "tool" messages: id of the call being answered
    name: str
|
||||
|
||||
|
||||
@dataclass
class ToolCallResult:
    """Result of a single tool execution."""

    tool_call_id: str  # matches LLMToolCall.id so consumers can correlate
    tool_name: str
    content: str  # tool output fed back into the conversation
    is_error: bool = False  # True when execution failed; content carries the error
|
||||
|
||||
|
||||
@dataclass
class LLMToolCall:
    """A tool call extracted from an LLM response."""

    id: str  # provider-assigned call id, echoed back in the tool result
    name: str  # name of the tool/function being invoked
    arguments: str  # JSON string (left unparsed; executors decode it)
|
||||
|
||||
|
||||
@dataclass
class LLMLoopResponse:
    """Response from a single LLM call in the loop.

    ``raw_response`` is typed as ``Any`` intentionally: the loop itself
    never inspects it — it is an opaque pass-through that the caller's
    ``ConversationUpdater`` uses to rebuild provider-specific message
    history (OpenAI ChatCompletion, Anthropic Message, Ollama str, etc.).
    """

    response_text: str | None  # assistant text, if any was produced
    tool_calls: list[LLMToolCall]  # empty when the model answered with text only
    raw_response: Any  # opaque provider object; see class docstring
    prompt_tokens: int = 0
    completion_tokens: int = 0
    reasoning: str | None = None  # provider reasoning/thinking text, if exposed
|
||||
|
||||
|
||||
class LLMCaller(Protocol):
    """Protocol for LLM call functions.

    Implementations call the provider with the current *messages* and
    *tools* and return the parsed ``LLMLoopResponse``.
    """

    async def __call__(
        self,
        messages: list[dict[str, Any]],
        tools: Sequence[Any],
    ) -> LLMLoopResponse: ...
|
||||
|
||||
|
||||
class ToolExecutor(Protocol):
    """Protocol for tool execution functions.

    Implementations run *tool_call* and return a ``ToolCallResult``;
    errors should be reported via ``is_error`` rather than raised.
    """

    async def __call__(
        self,
        tool_call: LLMToolCall,
        tools: Sequence[Any],
    ) -> ToolCallResult: ...
|
||||
|
||||
|
||||
class ConversationUpdater(Protocol):
    """Protocol for updating conversation history after an LLM response.

    Implementations mutate *messages* in place, appending the assistant
    response and (when present) the tool results in provider format.
    """

    def __call__(
        self,
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None: ...
|
||||
|
||||
|
||||
@dataclass
class ToolCallLoopResult:
    """Final result of the tool-calling loop."""

    response_text: str  # final assistant text; "" on intermediate yields
    messages: list[dict[str, Any]]  # alias of the live (mutated) conversation list
    total_prompt_tokens: int = 0
    total_completion_tokens: int = 0
    iterations: int = 0
    # True only on the final yield of a loop that ended with a text response;
    # False on intermediate (tool-call) yields and when max iterations is hit.
    finished_naturally: bool = True
    last_tool_calls: list[LLMToolCall] = field(default_factory=list)
|
||||
|
||||
|
||||
async def tool_call_loop(
    *,
    messages: list[dict[str, Any]],
    tools: Sequence[Any],
    llm_call: LLMCaller,
    execute_tool: ToolExecutor,
    update_conversation: ConversationUpdater,
    max_iterations: int = -1,
    last_iteration_message: str | None = None,
    parallel_tool_calls: bool = True,
) -> AsyncGenerator[ToolCallLoopResult, None]:
    """Run a tool-calling conversation loop as an async generator.

    Yields a ``ToolCallLoopResult`` after each iteration so callers can
    drain buffered events (e.g. streaming text deltas) between iterations.
    The **final** yielded result has ``finished_naturally`` set and contains
    the complete response text.

    Args:
        messages: Initial conversation messages (modified in-place).
        tools: Tool function definitions (OpenAI format). Accepts any
            sequence of tool dicts, including ``ChatCompletionToolParam``.
        llm_call: Async function to call the LLM. The callback can
            perform streaming internally (e.g. accumulate text deltas
            and collect events) — it just needs to return the final
            ``LLMLoopResponse`` with extracted tool calls.
        execute_tool: Async function to execute a tool call.
        update_conversation: Function to update messages with LLM
            response and tool results.
        max_iterations: Max iterations. -1 = infinite, 0 = no loop
            (immediately yields a "max reached" result).
        last_iteration_message: Optional message to append on the last
            iteration to encourage the model to finish.
        parallel_tool_calls: If True (default), execute multiple tool
            calls from a single LLM response concurrently via
            ``asyncio.gather``. Set to False when tool calls may have
            ordering dependencies or mutate shared state.

    Yields:
        ToolCallLoopResult after each iteration. Check ``finished_naturally``
        to determine if the loop completed or is still running.
    """
    total_prompt_tokens = 0
    total_completion_tokens = 0
    iteration = 0

    # Negative max_iterations means "loop until the model stops calling tools".
    while max_iterations < 0 or iteration < max_iterations:
        iteration += 1

        # On last iteration, add a hint to finish. Only copy the list
        # when the hint needs to be appended to avoid per-iteration overhead
        # on long conversations.
        # NB: ``is_last`` is a truthy value (may be the message string itself
        # or None/False), not strictly a bool — it is only used in a boolean
        # context. With max_iterations < 0 the hint is never appended.
        is_last = (
            last_iteration_message
            and max_iterations > 0
            and iteration == max_iterations
        )
        if is_last:
            # Copy so the hint is visible to this LLM call only and never
            # persisted into the caller's conversation history.
            iteration_messages = list(messages)
            iteration_messages.append(
                {"role": "system", "content": last_iteration_message}
            )
        else:
            iteration_messages = messages

        # Call LLM
        response = await llm_call(iteration_messages, tools)
        total_prompt_tokens += response.prompt_tokens
        total_completion_tokens += response.completion_tokens

        # No tool calls = done
        if not response.tool_calls:
            update_conversation(messages, response)
            yield ToolCallLoopResult(
                response_text=response.response_text or "",
                messages=messages,
                total_prompt_tokens=total_prompt_tokens,
                total_completion_tokens=total_completion_tokens,
                iterations=iteration,
                finished_naturally=True,
            )
            return

        # Execute tools — parallel or sequential depending on caller preference.
        # NOTE: asyncio.gather does not cancel sibling tasks when one raises.
        # Callers should handle errors inside execute_tool (return error
        # ToolCallResult) rather than letting exceptions propagate.
        if parallel_tool_calls and len(response.tool_calls) > 1:
            # Parallel: side-effects from different tool executors (e.g.
            # streaming events appended to a shared list) may interleave
            # nondeterministically. Each event carries its own tool-call
            # identifier, so consumers must correlate by ID.
            # gather preserves result order regardless of completion order.
            tool_results: list[ToolCallResult] = list(
                await asyncio.gather(
                    *(execute_tool(tc, tools) for tc in response.tool_calls)
                )
            )
        else:
            # Sequential: preserves ordering guarantees for callers that
            # need deterministic execution order.
            tool_results = [await execute_tool(tc, tools) for tc in response.tool_calls]

        # Update conversation with response + tool results
        update_conversation(messages, response, tool_results)

        # Yield a fresh result so callers can drain buffered events.
        # ``messages`` is the live list, not a snapshot — it keeps mutating
        # on subsequent iterations.
        yield ToolCallLoopResult(
            response_text="",
            messages=messages,
            total_prompt_tokens=total_prompt_tokens,
            total_completion_tokens=total_completion_tokens,
            iterations=iteration,
            finished_naturally=False,
            last_tool_calls=list(response.tool_calls),
        )

    # Hit max iterations
    yield ToolCallLoopResult(
        response_text=f"Completed after {max_iterations} iterations (limit reached)",
        messages=messages,
        total_prompt_tokens=total_prompt_tokens,
        total_completion_tokens=total_completion_tokens,
        iterations=iteration,
        finished_naturally=False,
    )
|
||||
554
autogpt_platform/backend/backend/util/tool_call_loop_test.py
Normal file
554
autogpt_platform/backend/backend/util/tool_call_loop_test.py
Normal file
@@ -0,0 +1,554 @@
|
||||
"""Unit tests for tool_call_loop shared abstraction.
|
||||
|
||||
Covers:
|
||||
- Happy path with tool calls (single and multi-round)
|
||||
- Final text response (no tool calls)
|
||||
- Max iterations reached
|
||||
- No tools scenario
|
||||
- Exception propagation from tool executor
|
||||
- Parallel tool execution
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.tool_call_loop import (
|
||||
LLMLoopResponse,
|
||||
LLMToolCall,
|
||||
ToolCallLoopResult,
|
||||
ToolCallResult,
|
||||
tool_call_loop,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Minimal single-tool definition (OpenAI function-calling format) reused by
# every test below; the mock LLM/tool callbacks never actually inspect it.
TOOL_DEFS: list[dict[str, Any]] = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]
|
||||
|
||||
|
||||
def _make_response(
    text: str | None = None,
    tool_calls: list[LLMToolCall] | None = None,
    prompt_tokens: int = 10,
    completion_tokens: int = 5,
) -> LLMLoopResponse:
    """Build a mock ``LLMLoopResponse`` with a placeholder raw payload."""
    return LLMLoopResponse(
        response_text=text,
        tool_calls=tool_calls or [],
        raw_response={"mock": True},
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_text_response_no_tool_calls():
    """LLM responds with text only -- loop should yield once and finish."""

    async def llm_call(
        messages: list[dict[str, Any]], tools: Sequence[Any]
    ) -> LLMLoopResponse:
        return _make_response(text="Hello world")

    async def execute_tool(
        tool_call: LLMToolCall, tools: Sequence[Any]
    ) -> ToolCallResult:
        # No tool calls are ever returned, so the executor must be unused.
        raise AssertionError("Should not be called")

    def update_conversation(
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None:
        messages.append({"role": "assistant", "content": response.response_text})

    msgs: list[dict[str, Any]] = [{"role": "user", "content": "Hi"}]
    results: list[ToolCallLoopResult] = []
    async for r in tool_call_loop(
        messages=msgs,
        tools=TOOL_DEFS,
        llm_call=llm_call,
        execute_tool=execute_tool,
        update_conversation=update_conversation,
    ):
        results.append(r)

    # Single final yield with the full text and one iteration's token counts.
    assert len(results) == 1
    assert results[0].finished_naturally is True
    assert results[0].response_text == "Hello world"
    assert results[0].iterations == 1
    assert results[0].total_prompt_tokens == 10
    assert results[0].total_completion_tokens == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_single_tool_call_then_text():
    """LLM makes one tool call, then responds with text on second round."""
    call_count = 0

    async def llm_call(
        messages: list[dict[str, Any]], tools: Sequence[Any]
    ) -> LLMLoopResponse:
        # Round 1: request a tool call; round 2: final text answer.
        nonlocal call_count
        call_count += 1
        if call_count == 1:
            return _make_response(
                tool_calls=[
                    LLMToolCall(
                        id="tc_1", name="get_weather", arguments='{"city":"NYC"}'
                    )
                ]
            )
        return _make_response(text="It's sunny in NYC")

    async def execute_tool(
        tool_call: LLMToolCall, tools: Sequence[Any]
    ) -> ToolCallResult:
        return ToolCallResult(
            tool_call_id=tool_call.id,
            tool_name=tool_call.name,
            content='{"temp": 72}',
        )

    def update_conversation(
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None:
        messages.append({"role": "assistant", "content": response.response_text})
        if tool_results:
            for tr in tool_results:
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": tr.tool_call_id,
                        "content": tr.content,
                    }
                )

    msgs: list[dict[str, Any]] = [{"role": "user", "content": "Weather?"}]
    results: list[ToolCallLoopResult] = []
    async for r in tool_call_loop(
        messages=msgs,
        tools=TOOL_DEFS,
        llm_call=llm_call,
        execute_tool=execute_tool,
        update_conversation=update_conversation,
    ):
        results.append(r)

    # First yield: tool call iteration (not finished)
    # Second yield: text response (finished)
    assert len(results) == 2
    assert results[0].finished_naturally is False
    assert results[0].iterations == 1
    assert len(results[0].last_tool_calls) == 1
    assert results[1].finished_naturally is True
    assert results[1].response_text == "It's sunny in NYC"
    assert results[1].iterations == 2
    # Two LLM rounds at 10/5 tokens each (defaults in _make_response).
    assert results[1].total_prompt_tokens == 20
    assert results[1].total_completion_tokens == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_max_iterations_reached():
    """Loop should stop after max_iterations even if LLM keeps calling tools."""

    async def llm_call(
        messages: list[dict[str, Any]], tools: Sequence[Any]
    ) -> LLMLoopResponse:
        # Always requests another tool call — never finishes naturally.
        return _make_response(
            tool_calls=[
                LLMToolCall(id="tc_x", name="get_weather", arguments='{"city":"X"}')
            ]
        )

    async def execute_tool(
        tool_call: LLMToolCall, tools: Sequence[Any]
    ) -> ToolCallResult:
        return ToolCallResult(
            tool_call_id=tool_call.id, tool_name=tool_call.name, content="result"
        )

    def update_conversation(
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None:
        pass

    msgs: list[dict[str, Any]] = [{"role": "user", "content": "Go"}]
    results: list[ToolCallLoopResult] = []
    async for r in tool_call_loop(
        messages=msgs,
        tools=TOOL_DEFS,
        llm_call=llm_call,
        execute_tool=execute_tool,
        update_conversation=update_conversation,
        max_iterations=3,
    ):
        results.append(r)

    # 3 tool-call iterations + 1 final "max reached"
    assert len(results) == 4
    for r in results[:3]:
        assert r.finished_naturally is False
    final = results[-1]
    assert final.finished_naturally is False
    assert "3 iterations" in final.response_text
    assert final.iterations == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_tools_first_response_text():
    """When LLM immediately responds with text (empty tools list), finishes."""

    async def llm_call(
        messages: list[dict[str, Any]], tools: Sequence[Any]
    ) -> LLMLoopResponse:
        return _make_response(text="No tools needed")

    async def execute_tool(
        tool_call: LLMToolCall, tools: Sequence[Any]
    ) -> ToolCallResult:
        raise AssertionError("Should not be called")

    def update_conversation(
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None:
        pass

    msgs: list[dict[str, Any]] = [{"role": "user", "content": "Hi"}]
    results: list[ToolCallLoopResult] = []
    # Note: tools=[] — the loop itself never requires a non-empty tool list.
    async for r in tool_call_loop(
        messages=msgs,
        tools=[],
        llm_call=llm_call,
        execute_tool=execute_tool,
        update_conversation=update_conversation,
    ):
        results.append(r)

    assert len(results) == 1
    assert results[0].finished_naturally is True
    assert results[0].response_text == "No tools needed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_executor_exception_propagates():
    """Exception in execute_tool should propagate out of the loop."""

    async def llm_call(
        messages: list[dict[str, Any]], tools: Sequence[Any]
    ) -> LLMLoopResponse:
        return _make_response(
            tool_calls=[LLMToolCall(id="tc_err", name="get_weather", arguments="{}")]
        )

    async def execute_tool(
        tool_call: LLMToolCall, tools: Sequence[Any]
    ) -> ToolCallResult:
        # The loop does not swallow executor exceptions — callers must.
        raise RuntimeError("Tool execution failed!")

    def update_conversation(
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None:
        pass

    msgs: list[dict[str, Any]] = [{"role": "user", "content": "Go"}]
    with pytest.raises(RuntimeError, match="Tool execution failed!"):
        async for _ in tool_call_loop(
            messages=msgs,
            tools=TOOL_DEFS,
            llm_call=llm_call,
            execute_tool=execute_tool,
            update_conversation=update_conversation,
        ):
            pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_parallel_tool_execution():
    """Multiple tool calls in one response should execute concurrently."""
    execution_order: list[str] = []

    async def llm_call(
        messages: list[dict[str, Any]], tools: Sequence[Any]
    ) -> LLMLoopResponse:
        # First round (only the initial user message present): two tool calls.
        if len(messages) == 1:
            return _make_response(
                tool_calls=[
                    LLMToolCall(id="tc_a", name="tool_a", arguments="{}"),
                    LLMToolCall(id="tc_b", name="tool_b", arguments="{}"),
                ]
            )
        return _make_response(text="Done")

    async def execute_tool(
        tool_call: LLMToolCall, tools: Sequence[Any]
    ) -> ToolCallResult:
        # tool_b starts instantly, tool_a has a small delay.
        # With parallel execution, both should overlap.
        if tool_call.name == "tool_a":
            await asyncio.sleep(0.05)
        execution_order.append(tool_call.name)
        return ToolCallResult(
            tool_call_id=tool_call.id, tool_name=tool_call.name, content="ok"
        )

    def update_conversation(
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None:
        messages.append({"role": "assistant", "content": "called tools"})
        if tool_results:
            for tr in tool_results:
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": tr.tool_call_id,
                        "content": tr.content,
                    }
                )

    msgs: list[dict[str, Any]] = [{"role": "user", "content": "Run both"}]
    async for _ in tool_call_loop(
        messages=msgs,
        tools=TOOL_DEFS,
        llm_call=llm_call,
        execute_tool=execute_tool,
        update_conversation=update_conversation,
    ):
        pass

    # With parallel execution, tool_b (no delay) finishes before tool_a
    assert execution_order == ["tool_b", "tool_a"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sequential_tool_execution():
    """With parallel_tool_calls=False, tools execute in order regardless of speed."""
    execution_order: list[str] = []

    async def llm_call(
        messages: list[dict[str, Any]], tools: Sequence[Any]
    ) -> LLMLoopResponse:
        # First round (only the initial user message present): two tool calls.
        if len(messages) == 1:
            return _make_response(
                tool_calls=[
                    LLMToolCall(id="tc_a", name="tool_a", arguments="{}"),
                    LLMToolCall(id="tc_b", name="tool_b", arguments="{}"),
                ]
            )
        return _make_response(text="Done")

    async def execute_tool(
        tool_call: LLMToolCall, tools: Sequence[Any]
    ) -> ToolCallResult:
        # tool_b would finish first if parallel, but sequential should keep order
        if tool_call.name == "tool_a":
            await asyncio.sleep(0.05)
        execution_order.append(tool_call.name)
        return ToolCallResult(
            tool_call_id=tool_call.id, tool_name=tool_call.name, content="ok"
        )

    def update_conversation(
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None:
        messages.append({"role": "assistant", "content": "called tools"})
        if tool_results:
            for tr in tool_results:
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": tr.tool_call_id,
                        "content": tr.content,
                    }
                )

    msgs: list[dict[str, Any]] = [{"role": "user", "content": "Run both"}]
    async for _ in tool_call_loop(
        messages=msgs,
        tools=TOOL_DEFS,
        llm_call=llm_call,
        execute_tool=execute_tool,
        update_conversation=update_conversation,
        parallel_tool_calls=False,
    ):
        pass

    # With sequential execution, tool_a runs first despite being slower
    assert execution_order == ["tool_a", "tool_b"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_last_iteration_message_appended():
    """On the final iteration, last_iteration_message should be appended."""
    # Snapshot of the messages passed to each LLM call, for later inspection.
    captured_messages: list[list[dict[str, Any]]] = []

    async def llm_call(
        messages: list[dict[str, Any]], tools: Sequence[Any]
    ) -> LLMLoopResponse:
        captured_messages.append(list(messages))
        return _make_response(
            tool_calls=[LLMToolCall(id="tc_1", name="get_weather", arguments="{}")]
        )

    async def execute_tool(
        tool_call: LLMToolCall, tools: Sequence[Any]
    ) -> ToolCallResult:
        return ToolCallResult(
            tool_call_id=tool_call.id, tool_name=tool_call.name, content="ok"
        )

    def update_conversation(
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None:
        pass

    msgs: list[dict[str, Any]] = [{"role": "user", "content": "Go"}]
    async for _ in tool_call_loop(
        messages=msgs,
        tools=TOOL_DEFS,
        llm_call=llm_call,
        execute_tool=execute_tool,
        update_conversation=update_conversation,
        max_iterations=2,
        last_iteration_message="Please finish now.",
    ):
        pass

    # First iteration: no extra message
    assert len(captured_messages[0]) == 1
    # Second (last) iteration: should have the hint appended
    last_call_msgs = captured_messages[1]
    assert any(
        m.get("role") == "system" and "Please finish now." in m.get("content", "")
        for m in last_call_msgs
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_token_accumulation():
    """Tokens should accumulate across iterations."""
    call_count = 0

    async def llm_call(
        messages: list[dict[str, Any]], tools: Sequence[Any]
    ) -> LLMLoopResponse:
        # Two tool-call rounds, then a final text round — 3 calls total.
        nonlocal call_count
        call_count += 1
        if call_count <= 2:
            return _make_response(
                tool_calls=[
                    LLMToolCall(
                        id=f"tc_{call_count}", name="get_weather", arguments="{}"
                    )
                ],
                prompt_tokens=100,
                completion_tokens=50,
            )
        return _make_response(text="Final", prompt_tokens=100, completion_tokens=50)

    async def execute_tool(
        tool_call: LLMToolCall, tools: Sequence[Any]
    ) -> ToolCallResult:
        return ToolCallResult(
            tool_call_id=tool_call.id, tool_name=tool_call.name, content="ok"
        )

    def update_conversation(
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None:
        pass

    msgs: list[dict[str, Any]] = [{"role": "user", "content": "Go"}]
    final_result = None
    async for r in tool_call_loop(
        messages=msgs,
        tools=TOOL_DEFS,
        llm_call=llm_call,
        execute_tool=execute_tool,
        update_conversation=update_conversation,
    ):
        final_result = r

    assert final_result is not None
    assert final_result.total_prompt_tokens == 300  # 3 calls * 100
    assert final_result.total_completion_tokens == 150  # 3 calls * 50
    assert final_result.iterations == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_max_iterations_zero_no_loop():
    """max_iterations=0 should immediately yield a 'max reached' result without calling LLM."""

    # All three callbacks assert-fail: with max_iterations=0 the while loop
    # body must never run, so none of them may be invoked.
    async def llm_call(
        messages: list[dict[str, Any]], tools: Sequence[Any]
    ) -> LLMLoopResponse:
        raise AssertionError("LLM should not be called when max_iterations=0")

    async def execute_tool(
        tool_call: LLMToolCall, tools: Sequence[Any]
    ) -> ToolCallResult:
        raise AssertionError("Tool should not be called when max_iterations=0")

    def update_conversation(
        messages: list[dict[str, Any]],
        response: LLMLoopResponse,
        tool_results: list[ToolCallResult] | None = None,
    ) -> None:
        raise AssertionError("Updater should not be called when max_iterations=0")

    msgs: list[dict[str, Any]] = [{"role": "user", "content": "Go"}]
    results: list[ToolCallLoopResult] = []
    async for r in tool_call_loop(
        messages=msgs,
        tools=TOOL_DEFS,
        llm_call=llm_call,
        execute_tool=execute_tool,
        update_conversation=update_conversation,
        max_iterations=0,
    ):
        results.append(r)

    assert len(results) == 1
    assert results[0].finished_naturally is False
    assert results[0].iterations == 0
    assert "0 iterations" in results[0].response_text
|
||||
@@ -1,44 +0,0 @@
|
||||
-- CreateEnum
|
||||
CREATE TYPE "PlatformType" AS ENUM ('DISCORD', 'TELEGRAM', 'SLACK', 'TEAMS', 'WHATSAPP', 'GITHUB', 'LINEAR');
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "PlatformLink" (
|
||||
"id" TEXT NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
"platform" "PlatformType" NOT NULL,
|
||||
"platformUserId" TEXT NOT NULL,
|
||||
"platformUsername" TEXT,
|
||||
"linkedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
CONSTRAINT "PlatformLink_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "PlatformLinkToken" (
|
||||
"id" TEXT NOT NULL,
|
||||
"token" TEXT NOT NULL,
|
||||
"platform" "PlatformType" NOT NULL,
|
||||
"platformUserId" TEXT NOT NULL,
|
||||
"platformUsername" TEXT,
|
||||
"channelId" TEXT,
|
||||
"expiresAt" TIMESTAMP(3) NOT NULL,
|
||||
"usedAt" TIMESTAMP(3),
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
CONSTRAINT "PlatformLinkToken_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "PlatformLink_platform_platformUserId_key" ON "PlatformLink"("platform", "platformUserId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "PlatformLink_userId_idx" ON "PlatformLink"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "PlatformLinkToken_token_key" ON "PlatformLinkToken"("token");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "PlatformLinkToken_expiresAt_idx" ON "PlatformLinkToken"("expiresAt");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "PlatformLink" ADD CONSTRAINT "PlatformLink_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
@@ -155,6 +155,7 @@ asyncio_default_fixture_loop_scope = "session"
|
||||
addopts = "-p no:syrupy"
|
||||
markers = [
|
||||
"supplementary: tests kept for coverage but superseded by integration tests",
|
||||
"integration: end-to-end tests that require a live database (skipped in CI)",
|
||||
]
|
||||
filterwarnings = [
|
||||
"ignore:'audioop' is deprecated:DeprecationWarning:discord.player",
|
||||
|
||||
@@ -71,9 +71,6 @@ model User {
|
||||
OAuthAuthorizationCodes OAuthAuthorizationCode[]
|
||||
OAuthAccessTokens OAuthAccessToken[]
|
||||
OAuthRefreshTokens OAuthRefreshToken[]
|
||||
|
||||
// Platform bot linking
|
||||
PlatformLinks PlatformLink[]
|
||||
}
|
||||
|
||||
enum OnboardingStep {
|
||||
@@ -1305,50 +1302,3 @@ model OAuthRefreshToken {
|
||||
@@index([userId, applicationId])
|
||||
@@index([expiresAt]) // For cleanup
|
||||
}
|
||||
|
||||
// ── Platform Bot Linking ──────────────────────────────────────────────
|
||||
// Links external chat platform identities (Discord, Telegram, Slack, etc.)
|
||||
// to AutoGPT user accounts, enabling the multi-platform CoPilot bot.
|
||||
|
||||
enum PlatformType {
|
||||
DISCORD
|
||||
TELEGRAM
|
||||
SLACK
|
||||
TEAMS
|
||||
WHATSAPP
|
||||
GITHUB
|
||||
LINEAR
|
||||
}
|
||||
|
||||
// Maps a platform user identity to an AutoGPT account.
|
||||
// One AutoGPT user can have multiple platform links (e.g. Discord + Telegram).
|
||||
// Each platform identity can only link to one AutoGPT account.
|
||||
model PlatformLink {
|
||||
id String @id @default(uuid())
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
platform PlatformType
|
||||
platformUserId String // The user's ID on that platform
|
||||
platformUsername String? // Display name (best-effort, may go stale)
|
||||
linkedAt DateTime @default(now())
|
||||
|
||||
@@unique([platform, platformUserId])
|
||||
@@index([userId])
|
||||
}
|
||||
|
||||
// One-time tokens for the account linking flow.
|
||||
// Generated when an unlinked user messages the bot; consumed when they
|
||||
// complete the link on the AutoGPT web app.
|
||||
model PlatformLinkToken {
|
||||
id String @id @default(uuid())
|
||||
token String @unique
|
||||
platform PlatformType
|
||||
platformUserId String
|
||||
platformUsername String?
|
||||
channelId String? // So the bot can send a confirmation message
|
||||
expiresAt DateTime
|
||||
usedAt DateTime?
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
@@index([expiresAt])
|
||||
}
|
||||
|
||||
@@ -1,9 +1,20 @@
|
||||
import base64
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.google.gmail import GmailReadBlock
|
||||
from backend.blocks.google.gmail import (
|
||||
GmailForwardBlock,
|
||||
GmailReadBlock,
|
||||
HasRecipients,
|
||||
_build_reply_message,
|
||||
create_mime_message,
|
||||
validate_all_recipients,
|
||||
validate_email_recipients,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
|
||||
class TestGmailReadBlock:
|
||||
@@ -250,3 +261,244 @@ class TestGmailReadBlock:
|
||||
|
||||
result = await self.gmail_block._get_email_body(msg, self.mock_service)
|
||||
assert result == "This email does not contain a readable body."
|
||||
|
||||
|
||||
class TestValidateEmailRecipients:
|
||||
"""Test cases for validate_email_recipients."""
|
||||
|
||||
def test_valid_single_email(self):
|
||||
validate_email_recipients(["user@example.com"])
|
||||
|
||||
def test_valid_multiple_emails(self):
|
||||
validate_email_recipients(["a@b.com", "x@y.org", "test@sub.domain.co"])
|
||||
|
||||
def test_invalid_missing_at(self):
|
||||
with pytest.raises(ValueError, match="Invalid email address"):
|
||||
validate_email_recipients(["not-an-email"])
|
||||
|
||||
def test_invalid_missing_domain_dot(self):
|
||||
with pytest.raises(ValueError, match="Invalid email address"):
|
||||
validate_email_recipients(["user@localhost"])
|
||||
|
||||
def test_invalid_empty_string(self):
|
||||
with pytest.raises(ValueError, match="Invalid email address"):
|
||||
validate_email_recipients([""])
|
||||
|
||||
def test_invalid_json_object_string(self):
|
||||
with pytest.raises(ValueError, match="Invalid email address"):
|
||||
validate_email_recipients(['{"email": "user@example.com"}'])
|
||||
|
||||
def test_mixed_valid_and_invalid(self):
|
||||
with pytest.raises(ValueError, match="'bad-addr'"):
|
||||
validate_email_recipients(["good@example.com", "bad-addr"])
|
||||
|
||||
def test_field_name_in_error(self):
|
||||
with pytest.raises(ValueError, match="'cc'"):
|
||||
validate_email_recipients(["nope"], field_name="cc")
|
||||
|
||||
def test_whitespace_trimmed(self):
|
||||
validate_email_recipients([" user@example.com "])
|
||||
|
||||
def test_empty_list_passes(self):
|
||||
validate_email_recipients([])
|
||||
|
||||
|
||||
class TestValidateAllRecipients:
|
||||
"""Test cases for validate_all_recipients."""
|
||||
|
||||
def test_valid_all_fields(self):
|
||||
data = cast(
|
||||
HasRecipients,
|
||||
SimpleNamespace(to=["a@b.com"], cc=["c@d.com"], bcc=["e@f.com"]),
|
||||
)
|
||||
validate_all_recipients(data)
|
||||
|
||||
def test_invalid_to_raises(self):
|
||||
data = cast(HasRecipients, SimpleNamespace(to=["bad"], cc=[], bcc=[]))
|
||||
with pytest.raises(ValueError, match="'to'"):
|
||||
validate_all_recipients(data)
|
||||
|
||||
def test_invalid_cc_raises(self):
|
||||
data = cast(HasRecipients, SimpleNamespace(to=["a@b.com"], cc=["bad"], bcc=[]))
|
||||
with pytest.raises(ValueError, match="'cc'"):
|
||||
validate_all_recipients(data)
|
||||
|
||||
def test_invalid_bcc_raises(self):
|
||||
data = cast(
|
||||
HasRecipients,
|
||||
SimpleNamespace(to=["a@b.com"], cc=["c@d.com"], bcc=["bad"]),
|
||||
)
|
||||
with pytest.raises(ValueError, match="'bcc'"):
|
||||
validate_all_recipients(data)
|
||||
|
||||
def test_empty_cc_bcc_skipped(self):
|
||||
data = cast(HasRecipients, SimpleNamespace(to=["a@b.com"], cc=[], bcc=[]))
|
||||
validate_all_recipients(data)
|
||||
|
||||
|
||||
class TestCreateMimeMessageValidation:
|
||||
"""Integration tests verifying validation hooks in create_mime_message()."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_to_raises_before_mime_construction(self):
|
||||
"""Invalid 'to' recipients should raise ValueError before any MIME work."""
|
||||
input_data = SimpleNamespace(
|
||||
to=["not-an-email"],
|
||||
cc=[],
|
||||
bcc=[],
|
||||
subject="Test",
|
||||
body="Hello",
|
||||
attachments=[],
|
||||
)
|
||||
exec_ctx = cast(ExecutionContext, SimpleNamespace(graph_exec_id="test-exec-id"))
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid email address"):
|
||||
await create_mime_message(input_data, exec_ctx)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_cc_raises_before_mime_construction(self):
|
||||
"""Invalid 'cc' recipients should raise ValueError."""
|
||||
input_data = SimpleNamespace(
|
||||
to=["valid@example.com"],
|
||||
cc=["bad-addr"],
|
||||
bcc=[],
|
||||
subject="Test",
|
||||
body="Hello",
|
||||
attachments=[],
|
||||
)
|
||||
exec_ctx = cast(ExecutionContext, SimpleNamespace(graph_exec_id="test-exec-id"))
|
||||
|
||||
with pytest.raises(ValueError, match="'cc'"):
|
||||
await create_mime_message(input_data, exec_ctx)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_recipients_passes_validation(self):
|
||||
"""Valid recipients should not raise during validation."""
|
||||
input_data = SimpleNamespace(
|
||||
to=["user@example.com"],
|
||||
cc=["other@example.com"],
|
||||
bcc=[],
|
||||
subject="Test",
|
||||
body="Hello",
|
||||
attachments=[],
|
||||
)
|
||||
exec_ctx = cast(ExecutionContext, SimpleNamespace(graph_exec_id="test-exec-id"))
|
||||
|
||||
# Should succeed without raising
|
||||
result = await create_mime_message(input_data, exec_ctx)
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
class TestBuildReplyMessageValidation:
|
||||
"""Integration tests verifying validation hooks in _build_reply_message()."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_to_raises_before_reply_construction(self):
|
||||
"""Invalid 'to' in reply should raise ValueError before MIME work."""
|
||||
mock_service = Mock()
|
||||
mock_parent = {
|
||||
"threadId": "thread-1",
|
||||
"payload": {
|
||||
"headers": [
|
||||
{"name": "Subject", "value": "Original"},
|
||||
{"name": "Message-ID", "value": "<msg@example.com>"},
|
||||
{"name": "From", "value": "sender@example.com"},
|
||||
]
|
||||
},
|
||||
}
|
||||
mock_service.users().messages().get().execute.return_value = mock_parent
|
||||
|
||||
input_data = SimpleNamespace(
|
||||
parentMessageId="msg-1",
|
||||
to=["not-valid"],
|
||||
cc=[],
|
||||
bcc=[],
|
||||
subject="",
|
||||
body="Reply body",
|
||||
replyAll=False,
|
||||
attachments=[],
|
||||
)
|
||||
exec_ctx = cast(ExecutionContext, SimpleNamespace(graph_exec_id="test-exec-id"))
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid email address"):
|
||||
await _build_reply_message(mock_service, input_data, exec_ctx)
|
||||
|
||||
|
||||
class TestForwardMessageValidation:
|
||||
"""Test that _forward_message() raises ValueError for invalid recipients."""
|
||||
|
||||
@staticmethod
|
||||
def _make_input(
|
||||
to: list[str] | None = None,
|
||||
cc: list[str] | None = None,
|
||||
bcc: list[str] | None = None,
|
||||
) -> "GmailForwardBlock.Input":
|
||||
mock = Mock(spec=GmailForwardBlock.Input)
|
||||
mock.messageId = "m1"
|
||||
mock.to = to or []
|
||||
mock.cc = cc or []
|
||||
mock.bcc = bcc or []
|
||||
mock.subject = ""
|
||||
mock.forwardMessage = "FYI"
|
||||
mock.includeAttachments = False
|
||||
mock.content_type = None
|
||||
mock.additionalAttachments = []
|
||||
mock.credentials = None
|
||||
return mock
|
||||
|
||||
@staticmethod
|
||||
def _exec_ctx():
|
||||
return ExecutionContext(user_id="u1", graph_exec_id="g1")
|
||||
|
||||
@staticmethod
|
||||
def _mock_service():
|
||||
"""Build a mock Gmail service that returns a parent message."""
|
||||
parent_message = {
|
||||
"id": "m1",
|
||||
"payload": {
|
||||
"headers": [
|
||||
{"name": "Subject", "value": "Original subject"},
|
||||
{"name": "From", "value": "sender@example.com"},
|
||||
{"name": "To", "value": "me@example.com"},
|
||||
{"name": "Date", "value": "Mon, 31 Mar 2026 00:00:00 +0000"},
|
||||
],
|
||||
"mimeType": "text/plain",
|
||||
"body": {
|
||||
"data": base64.urlsafe_b64encode(b"Hello world").decode(),
|
||||
},
|
||||
"parts": [],
|
||||
},
|
||||
}
|
||||
svc = Mock()
|
||||
svc.users().messages().get().execute.return_value = parent_message
|
||||
return svc
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_to_raises(self):
|
||||
block = GmailForwardBlock()
|
||||
with pytest.raises(ValueError, match="Invalid email address.*'to'"):
|
||||
await block._forward_message(
|
||||
self._mock_service(),
|
||||
self._make_input(to=["bad-addr"]),
|
||||
self._exec_ctx(),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_cc_raises(self):
|
||||
block = GmailForwardBlock()
|
||||
with pytest.raises(ValueError, match="Invalid email address.*'cc'"):
|
||||
await block._forward_message(
|
||||
self._mock_service(),
|
||||
self._make_input(to=["valid@example.com"], cc=["not-valid"]),
|
||||
self._exec_ctx(),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_bcc_raises(self):
|
||||
block = GmailForwardBlock()
|
||||
with pytest.raises(ValueError, match="Invalid email address.*'bcc'"):
|
||||
await block._forward_message(
|
||||
self._mock_service(),
|
||||
self._make_input(to=["valid@example.com"], bcc=["nope"]),
|
||||
self._exec_ctx(),
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@ images: {
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
@@ -569,6 +570,10 @@ async def main():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(
|
||||
os.getenv("CI") == "true",
|
||||
reason="Data seeding test requires a dedicated database; not for CI",
|
||||
)
|
||||
async def test_main_function_runs_without_errors():
|
||||
await main()
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Base stage for both dev and prod
|
||||
FROM node:21-alpine AS base
|
||||
FROM node:22.22-alpine3.23 AS base
|
||||
WORKDIR /app
|
||||
RUN corepack enable
|
||||
COPY autogpt_platform/frontend/package.json autogpt_platform/frontend/pnpm-lock.yaml ./
|
||||
@@ -33,7 +33,7 @@ ENV NEXT_PUBLIC_SOURCEMAPS="false"
|
||||
RUN if [ "$NEXT_PUBLIC_PW_TEST" = "true" ]; then NEXT_PUBLIC_PW_TEST=true NODE_OPTIONS="--max-old-space-size=8192" pnpm build; else NODE_OPTIONS="--max-old-space-size=8192" pnpm build; fi
|
||||
|
||||
# Prod stage - based on NextJS reference Dockerfile https://github.com/vercel/next.js/blob/64271354533ed16da51be5dce85f0dbd15f17517/examples/with-docker/Dockerfile
|
||||
FROM node:21-alpine AS prod
|
||||
FROM node:22.22-alpine3.23 AS prod
|
||||
ENV NODE_ENV=production
|
||||
ENV HOSTNAME=0.0.0.0
|
||||
WORKDIR /app
|
||||
|
||||
@@ -13,6 +13,8 @@ import {
|
||||
getSuggestionThemes,
|
||||
} from "./helpers";
|
||||
import { SuggestionThemes } from "./components/SuggestionThemes/SuggestionThemes";
|
||||
import { PulseChips } from "../PulseChips/PulseChips";
|
||||
import { usePulseChips } from "../PulseChips/usePulseChips";
|
||||
|
||||
interface Props {
|
||||
inputLayoutId: string;
|
||||
@@ -34,6 +36,7 @@ export function EmptySession({
|
||||
}: Props) {
|
||||
const { user } = useSupabase();
|
||||
const greetingName = getGreetingName(user);
|
||||
const pulseChips = usePulseChips();
|
||||
|
||||
const { data: suggestedPromptsResponse, isLoading: isLoadingPrompts } =
|
||||
useGetV2GetSuggestedPrompts({
|
||||
@@ -80,6 +83,8 @@ export function EmptySession({
|
||||
Tell me about your work — I'll find what to automate.
|
||||
</Text>
|
||||
|
||||
<PulseChips chips={pulseChips} onChipClick={onSend} />
|
||||
|
||||
<div className="mb-6">
|
||||
<motion.div
|
||||
layoutId={inputLayoutId}
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
"use client";
|
||||
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { ArrowRightIcon } from "@phosphor-icons/react";
|
||||
import NextLink from "next/link";
|
||||
import { StatusBadge } from "@/app/(platform)/library/components/StatusBadge/StatusBadge";
|
||||
import type { AgentStatus } from "@/app/(platform)/library/types";
|
||||
|
||||
export interface PulseChipData {
|
||||
id: string;
|
||||
name: string;
|
||||
status: AgentStatus;
|
||||
shortMessage: string;
|
||||
}
|
||||
|
||||
interface Props {
|
||||
chips: PulseChipData[];
|
||||
onChipClick?: (prompt: string) => void;
|
||||
}
|
||||
|
||||
export function PulseChips({ chips, onChipClick }: Props) {
|
||||
if (chips.length === 0) return null;
|
||||
|
||||
return (
|
||||
<div className="mb-6">
|
||||
<div className="mb-3 flex items-center justify-between">
|
||||
<Text variant="small-medium" className="text-zinc-600">
|
||||
What's happening with your agents
|
||||
</Text>
|
||||
<NextLink
|
||||
href="/library"
|
||||
className="flex items-center gap-1 text-xs text-zinc-500 hover:text-zinc-700"
|
||||
>
|
||||
View all <ArrowRightIcon size={12} />
|
||||
</NextLink>
|
||||
</div>
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{chips.map((chip) => (
|
||||
<PulseChip key={chip.id} chip={chip} onClick={onChipClick} />
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface ChipProps {
|
||||
chip: PulseChipData;
|
||||
onClick?: (prompt: string) => void;
|
||||
}
|
||||
|
||||
function PulseChip({ chip, onClick }: ChipProps) {
|
||||
function handleClick() {
|
||||
const prompt = buildChipPrompt(chip);
|
||||
onClick?.(prompt);
|
||||
}
|
||||
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleClick}
|
||||
className="flex items-center gap-2 rounded-medium border border-zinc-100 bg-white px-3 py-2 text-left transition-all hover:border-zinc-200 hover:shadow-sm"
|
||||
>
|
||||
<StatusBadge status={chip.status} />
|
||||
<div className="min-w-0">
|
||||
<Text variant="small-medium" className="truncate text-zinc-900">
|
||||
{chip.name}
|
||||
</Text>
|
||||
<Text variant="small" className="truncate text-zinc-500">
|
||||
{chip.shortMessage}
|
||||
</Text>
|
||||
</div>
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
function buildChipPrompt(chip: PulseChipData): string {
|
||||
switch (chip.status) {
|
||||
case "error":
|
||||
return `What happened with ${chip.name}? It has an error — can you check?`;
|
||||
case "running":
|
||||
return `Give me a status update on ${chip.name} — what has it done so far?`;
|
||||
case "idle":
|
||||
return `${chip.name} hasn't run recently. Should I keep it or update and re-run it?`;
|
||||
default:
|
||||
return `Tell me about ${chip.name} — what's its current status?`;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import type { PulseChipData } from "./PulseChips";
|
||||
import type { AgentStatus } from "@/app/(platform)/library/types";
|
||||
|
||||
/**
|
||||
* Provides a prioritised list of pulse chips for the Home empty state.
|
||||
* Errors → running → stale, max 5 chips.
|
||||
*
|
||||
* TODO: Replace with real API data from `GET /agents/summary` or similar.
|
||||
*/
|
||||
export function usePulseChips(): PulseChipData[] {
|
||||
const [chips] = useState<PulseChipData[]>(() => MOCK_CHIPS);
|
||||
return chips;
|
||||
}
|
||||
|
||||
const MOCK_CHIPS: PulseChipData[] = [
|
||||
{
|
||||
id: "chip-1",
|
||||
name: "Lead Finder",
|
||||
status: "error" as AgentStatus,
|
||||
shortMessage: "API rate limit hit",
|
||||
},
|
||||
{
|
||||
id: "chip-2",
|
||||
name: "CEO Finder",
|
||||
status: "running" as AgentStatus,
|
||||
shortMessage: "72% complete",
|
||||
},
|
||||
{
|
||||
id: "chip-3",
|
||||
name: "Cart Recovery",
|
||||
status: "idle" as AgentStatus,
|
||||
shortMessage: "No runs in 3 weeks",
|
||||
},
|
||||
{
|
||||
id: "chip-4",
|
||||
name: "Social Collector",
|
||||
status: "listening" as AgentStatus,
|
||||
shortMessage: "Waiting for trigger",
|
||||
},
|
||||
];
|
||||
@@ -2,14 +2,17 @@ import { Navbar } from "@/components/layout/Navbar/Navbar";
|
||||
import { NetworkStatusMonitor } from "@/services/network-status/NetworkStatusMonitor";
|
||||
import { ReactNode } from "react";
|
||||
import { AdminImpersonationBanner } from "./admin/components/AdminImpersonationBanner";
|
||||
import { AutoPilotBridgeProvider } from "@/contexts/AutoPilotBridgeContext";
|
||||
|
||||
export default function PlatformLayout({ children }: { children: ReactNode }) {
|
||||
return (
|
||||
<main className="flex h-screen w-full flex-col">
|
||||
<NetworkStatusMonitor />
|
||||
<Navbar />
|
||||
<AdminImpersonationBanner />
|
||||
<section className="flex-1">{children}</section>
|
||||
</main>
|
||||
<AutoPilotBridgeProvider>
|
||||
<main className="flex h-screen w-full flex-col">
|
||||
<NetworkStatusMonitor />
|
||||
<Navbar />
|
||||
<AdminImpersonationBanner />
|
||||
<section className="flex-1">{children}</section>
|
||||
</main>
|
||||
</AutoPilotBridgeProvider>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
"use client";
|
||||
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { CaretUpIcon, CaretDownIcon } from "@phosphor-icons/react";
|
||||
import { useState } from "react";
|
||||
import type { FleetSummary, AgentStatusFilter } from "../../types";
|
||||
import { SitrepList } from "../SitrepItem/SitrepList";
|
||||
import { StatsGrid } from "./StatsGrid";
|
||||
|
||||
interface Props {
|
||||
summary: FleetSummary;
|
||||
agentIDs: string[];
|
||||
onFilterChange?: (filter: AgentStatusFilter) => void;
|
||||
activeFilter?: AgentStatusFilter;
|
||||
}
|
||||
|
||||
export function AgentBriefingPanel({
|
||||
summary,
|
||||
agentIDs,
|
||||
onFilterChange,
|
||||
activeFilter = "all",
|
||||
}: Props) {
|
||||
const [isCollapsed, setIsCollapsed] = useState(false);
|
||||
|
||||
const totalAttention = summary.error;
|
||||
|
||||
const headerSummary = [
|
||||
summary.running > 0 && `${summary.running} running`,
|
||||
totalAttention > 0 && `${totalAttention} need attention`,
|
||||
summary.listening > 0 && `${summary.listening} listening`,
|
||||
]
|
||||
.filter(Boolean)
|
||||
.join(" · ");
|
||||
|
||||
return (
|
||||
<div className="rounded-large border border-zinc-100 bg-white p-5 shadow-sm">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-3">
|
||||
<Text variant="h5">Agent Briefing</Text>
|
||||
{headerSummary && (
|
||||
<Text variant="small" className="text-zinc-500">
|
||||
{headerSummary}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={() => setIsCollapsed(!isCollapsed)}
|
||||
aria-label={isCollapsed ? "Expand briefing" : "Collapse briefing"}
|
||||
>
|
||||
{isCollapsed ? (
|
||||
<CaretDownIcon size={16} />
|
||||
) : (
|
||||
<CaretUpIcon size={16} />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{!isCollapsed && (
|
||||
<div className="mt-4 space-y-5">
|
||||
<StatsGrid
|
||||
summary={summary}
|
||||
activeFilter={activeFilter}
|
||||
onFilterChange={onFilterChange}
|
||||
/>
|
||||
<SitrepList agentIDs={agentIDs} />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
"use client";
|
||||
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import {
|
||||
CurrencyDollarIcon,
|
||||
PlayCircleIcon,
|
||||
WarningCircleIcon,
|
||||
EarIcon,
|
||||
ClockIcon,
|
||||
PauseCircleIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import type { FleetSummary, AgentStatusFilter } from "../../types";
|
||||
|
||||
interface Props {
|
||||
summary: FleetSummary;
|
||||
activeFilter: AgentStatusFilter;
|
||||
onFilterChange?: (filter: AgentStatusFilter) => void;
|
||||
}
|
||||
|
||||
export function StatsGrid({ summary, activeFilter, onFilterChange }: Props) {
|
||||
const tiles = [
|
||||
{
|
||||
label: "Spend this month",
|
||||
value: `$${summary.monthlySpend.toLocaleString()}`,
|
||||
filter: "all" as AgentStatusFilter,
|
||||
icon: CurrencyDollarIcon,
|
||||
color: "text-zinc-700",
|
||||
},
|
||||
{
|
||||
label: "Running now",
|
||||
value: summary.running,
|
||||
filter: "running" as AgentStatusFilter,
|
||||
icon: PlayCircleIcon,
|
||||
color: "text-blue-600",
|
||||
},
|
||||
{
|
||||
label: "Needs attention",
|
||||
value: summary.error,
|
||||
filter: "attention" as AgentStatusFilter,
|
||||
icon: WarningCircleIcon,
|
||||
color: "text-red-500",
|
||||
},
|
||||
{
|
||||
label: "Listening",
|
||||
value: summary.listening,
|
||||
filter: "listening" as AgentStatusFilter,
|
||||
icon: EarIcon,
|
||||
color: "text-purple-500",
|
||||
},
|
||||
{
|
||||
label: "Scheduled",
|
||||
value: summary.scheduled,
|
||||
filter: "scheduled" as AgentStatusFilter,
|
||||
icon: ClockIcon,
|
||||
color: "text-yellow-600",
|
||||
},
|
||||
{
|
||||
label: "Idle",
|
||||
value: summary.idle,
|
||||
filter: "idle" as AgentStatusFilter,
|
||||
icon: PauseCircleIcon,
|
||||
color: "text-zinc-400",
|
||||
},
|
||||
];
|
||||
|
||||
return (
|
||||
<div className="grid grid-cols-2 gap-3 sm:grid-cols-3 lg:grid-cols-6">
|
||||
{tiles.map((tile) => {
|
||||
const Icon = tile.icon;
|
||||
const isActive = activeFilter === tile.filter;
|
||||
|
||||
return (
|
||||
<button
|
||||
key={tile.label}
|
||||
type="button"
|
||||
onClick={() => onFilterChange?.(tile.filter)}
|
||||
className={cn(
|
||||
"flex flex-col gap-1 rounded-medium border p-3 text-left transition-all hover:shadow-sm",
|
||||
isActive
|
||||
? "border-zinc-900 bg-zinc-50"
|
||||
: "border-zinc-100 bg-white",
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-1.5">
|
||||
<Icon size={14} className={tile.color} />
|
||||
<Text variant="small" className="text-zinc-500">
|
||||
{tile.label}
|
||||
</Text>
|
||||
</div>
|
||||
<Text variant="h4" className={tile.color}>
|
||||
{tile.value}
|
||||
</Text>
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectGroup,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/__legacy__/ui/select";
|
||||
import { FunnelIcon } from "@phosphor-icons/react";
|
||||
import type { AgentStatusFilter, FleetSummary } from "../../types";
|
||||
|
||||
interface Props {
|
||||
value: AgentStatusFilter;
|
||||
onChange: (value: AgentStatusFilter) => void;
|
||||
summary: FleetSummary;
|
||||
}
|
||||
|
||||
export function AgentFilterMenu({ value, onChange, summary }: Props) {
|
||||
function handleChange(val: string) {
|
||||
onChange(val as AgentStatusFilter);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex items-center" data-testid="agent-filter-dropdown">
|
||||
<span className="hidden whitespace-nowrap text-sm sm:inline">filter</span>
|
||||
<Select value={value} onValueChange={handleChange}>
|
||||
<SelectTrigger className="ml-1 w-fit space-x-1 border-none px-0 text-sm underline underline-offset-4 shadow-none">
|
||||
<FunnelIcon className="h-4 w-4 sm:hidden" />
|
||||
<SelectValue placeholder="All Agents" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectGroup>
|
||||
<SelectItem value="all">All Agents</SelectItem>
|
||||
<SelectItem value="running">Running ({summary.running})</SelectItem>
|
||||
<SelectItem value="attention">
|
||||
Needs Attention ({summary.error})
|
||||
</SelectItem>
|
||||
<SelectItem value="listening">
|
||||
Listening ({summary.listening})
|
||||
</SelectItem>
|
||||
<SelectItem value="scheduled">
|
||||
Scheduled ({summary.scheduled})
|
||||
</SelectItem>
|
||||
<SelectItem value="idle">Idle / Stale ({summary.idle})</SelectItem>
|
||||
<SelectItem value="healthy">Healthy</SelectItem>
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
EyeIcon,
|
||||
ArrowsClockwiseIcon,
|
||||
MonitorPlayIcon,
|
||||
PlayIcon,
|
||||
ArrowCounterClockwiseIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import type { AgentStatus } from "../../types";
|
||||
|
||||
interface Props {
|
||||
status: AgentStatus;
|
||||
agentID: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Renders the single most relevant action for an agent based on its status.
|
||||
*
|
||||
* | Status | Action | Behaviour (TODO: wire to real endpoints) |
|
||||
* |-----------|-----------------|------------------------------------------|
|
||||
* | error | View error | Opens error detail / run log |
|
||||
* | listening | Reconnect | Opens reconnection flow |
|
||||
* | running | Watch live | Opens real-time execution view |
|
||||
* | idle | Run now | Triggers immediate new run |
|
||||
* | scheduled | Run now | Triggers immediate new run |
|
||||
*/
|
||||
export function ContextualActionButton({ status, agentID, className }: Props) {
|
||||
const { toast } = useToast();
|
||||
|
||||
const config = ACTION_CONFIG[status];
|
||||
if (!config) return null;
|
||||
|
||||
const Icon = config.icon;
|
||||
|
||||
function handleClick(e: React.MouseEvent) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
// TODO: Replace with real API calls
|
||||
toast({
|
||||
title: config.label,
|
||||
description: `${config.label} triggered for agent ${agentID.slice(0, 8)}…`,
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<Button
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={handleClick}
|
||||
leftIcon={<Icon size={14} />}
|
||||
className={className}
|
||||
>
|
||||
{config.label}
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
|
||||
const ACTION_CONFIG: Record<
|
||||
AgentStatus,
|
||||
{ label: string; icon: typeof EyeIcon }
|
||||
> = {
|
||||
error: { label: "View error", icon: EyeIcon },
|
||||
listening: { label: "Reconnect", icon: ArrowsClockwiseIcon },
|
||||
running: { label: "Watch live", icon: MonitorPlayIcon },
|
||||
idle: { label: "Run now", icon: PlayIcon },
|
||||
scheduled: { label: "Run now", icon: ArrowCounterClockwiseIcon },
|
||||
};
|
||||
@@ -12,10 +12,16 @@ import Avatar, {
|
||||
AvatarImage,
|
||||
} from "@/components/atoms/Avatar/Avatar";
|
||||
import { Link } from "@/components/atoms/Link/Link";
|
||||
import { Progress } from "@/components/atoms/Progress/Progress";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { AgentCardMenu } from "./components/AgentCardMenu";
|
||||
import { FavoriteButton } from "./components/FavoriteButton";
|
||||
import { useLibraryAgentCard } from "./useLibraryAgentCard";
|
||||
import { useFavoriteAnimation } from "../../context/FavoriteAnimationContext";
|
||||
import { StatusBadge } from "../StatusBadge/StatusBadge";
|
||||
import { ContextualActionButton } from "../ContextualActionButton/ContextualActionButton";
|
||||
import { useAgentStatus } from "../../hooks/useAgentStatus";
|
||||
import { formatTimeAgo } from "../../helpers";
|
||||
|
||||
interface Props {
|
||||
agent: LibraryAgent;
|
||||
@@ -25,6 +31,7 @@ interface Props {
|
||||
export function LibraryAgentCard({ agent, draggable = true }: Props) {
|
||||
const { id, name, graph_id, can_access_graph, image_url } = agent;
|
||||
const { triggerFavoriteAnimation } = useFavoriteAnimation();
|
||||
const statusInfo = useAgentStatus(id);
|
||||
|
||||
function handleDragStart(e: React.DragEvent<HTMLDivElement>) {
|
||||
e.dataTransfer.setData("application/agent-id", id);
|
||||
@@ -42,6 +49,9 @@ export function LibraryAgentCard({ agent, draggable = true }: Props) {
|
||||
onFavoriteAdd: triggerFavoriteAnimation,
|
||||
});
|
||||
|
||||
const hasError = statusInfo.status === "error";
|
||||
const isRunning = statusInfo.status === "running";
|
||||
|
||||
return (
|
||||
<div
|
||||
draggable={draggable}
|
||||
@@ -52,7 +62,12 @@ export function LibraryAgentCard({ agent, draggable = true }: Props) {
|
||||
layoutId={`agent-card-${id}`}
|
||||
data-testid="library-agent-card"
|
||||
data-agent-id={id}
|
||||
className="group relative inline-flex h-[10.625rem] w-full max-w-[25rem] flex-col items-start justify-start gap-2.5 rounded-medium border border-zinc-100 bg-white hover:shadow-md"
|
||||
className={cn(
|
||||
"group relative inline-flex h-auto min-h-[10.625rem] w-full max-w-[25rem] flex-col items-start justify-start gap-2.5 rounded-medium border bg-white hover:shadow-md",
|
||||
hasError
|
||||
? "border-l-2 border-b-zinc-100 border-l-red-400 border-r-zinc-100 border-t-zinc-100"
|
||||
: "border-zinc-100",
|
||||
)}
|
||||
transition={{
|
||||
type: "spring",
|
||||
damping: 25,
|
||||
@@ -79,6 +94,7 @@ export function LibraryAgentCard({ agent, draggable = true }: Props) {
|
||||
>
|
||||
{isFromMarketplace ? "FROM MARKETPLACE" : "Built by you"}
|
||||
</Text>
|
||||
<StatusBadge status={statusInfo.status} className="ml-auto" />
|
||||
</div>
|
||||
</NextLink>
|
||||
<FavoriteButton
|
||||
@@ -128,26 +144,65 @@ export function LibraryAgentCard({ agent, draggable = true }: Props) {
|
||||
)}
|
||||
</Link>
|
||||
|
||||
<div className="mt-auto flex w-full justify-start gap-6 border-t border-zinc-100 pb-1 pt-3">
|
||||
<Link
|
||||
href={`/library/agents/${id}`}
|
||||
data-testid="library-agent-card-see-runs-link"
|
||||
className="flex items-center gap-1 text-[13px]"
|
||||
>
|
||||
See runs <CaretCircleRightIcon size={20} />
|
||||
</Link>
|
||||
{/* Status details: progress bar, error message, stats */}
|
||||
{isRunning && statusInfo.progress !== null && (
|
||||
<div className="mt-1 flex items-center gap-2">
|
||||
<Progress value={statusInfo.progress} className="h-1.5 flex-1" />
|
||||
<Text variant="small" className="text-blue-600">
|
||||
{statusInfo.progress}%
|
||||
</Text>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{can_access_graph && (
|
||||
<Link
|
||||
href={`/build?flowID=${graph_id}`}
|
||||
data-testid="library-agent-card-open-in-builder-link"
|
||||
className="flex items-center gap-1 text-[13px]"
|
||||
isExternal
|
||||
>
|
||||
Open in builder <CaretCircleRightIcon size={20} />
|
||||
</Link>
|
||||
{hasError && statusInfo.lastError && (
|
||||
<Text variant="small" className="mt-1 line-clamp-1 text-red-500">
|
||||
{statusInfo.lastError}
|
||||
</Text>
|
||||
)}
|
||||
|
||||
<div className="mt-1 flex items-center gap-3">
|
||||
<Text variant="small" className="text-zinc-400">
|
||||
{statusInfo.totalRuns} runs
|
||||
</Text>
|
||||
<Text variant="small" className="text-zinc-400">
|
||||
${statusInfo.monthlySpend}
|
||||
</Text>
|
||||
{statusInfo.lastRunAt && (
|
||||
<Text variant="small" className="text-zinc-400">
|
||||
{formatTimeAgo(statusInfo.lastRunAt)}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="mt-auto flex w-full items-center justify-between gap-2 border-t border-zinc-100 pb-1 pt-3">
|
||||
<div className="flex gap-6">
|
||||
<Link
|
||||
href={`/library/agents/${id}`}
|
||||
data-testid="library-agent-card-see-runs-link"
|
||||
className="flex items-center gap-1 text-[13px]"
|
||||
>
|
||||
See runs <CaretCircleRightIcon size={20} />
|
||||
</Link>
|
||||
|
||||
{can_access_graph && (
|
||||
<Link
|
||||
href={`/build?flowID=${graph_id}`}
|
||||
data-testid="library-agent-card-open-in-builder-link"
|
||||
className="flex items-center gap-1 text-[13px]"
|
||||
isExternal
|
||||
>
|
||||
Open in builder <CaretCircleRightIcon size={20} />
|
||||
</Link>
|
||||
)}
|
||||
</div>
|
||||
<div className="opacity-0 transition-opacity group-hover:opacity-100">
|
||||
<ContextualActionButton
|
||||
status={statusInfo.status}
|
||||
agentID={id}
|
||||
className="text-xs"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</motion.div>
|
||||
</div>
|
||||
|
||||
@@ -16,8 +16,9 @@ import {
|
||||
} from "framer-motion";
|
||||
import { LibraryFolderEditDialog } from "../LibraryFolderEditDialog/LibraryFolderEditDialog";
|
||||
import { LibraryFolderDeleteDialog } from "../LibraryFolderDeleteDialog/LibraryFolderDeleteDialog";
|
||||
import { LibraryTab } from "../../types";
|
||||
import type { LibraryTab, AgentStatusFilter, FleetSummary } from "../../types";
|
||||
import { useLibraryAgentList } from "./useLibraryAgentList";
|
||||
import { AgentBriefingPanel } from "../AgentBriefingPanel/AgentBriefingPanel";
|
||||
|
||||
// cancels the current spring and starts a new one from current state.
|
||||
const containerVariants = {
|
||||
@@ -70,6 +71,9 @@ interface Props {
|
||||
tabs: LibraryTab[];
|
||||
activeTab: string;
|
||||
onTabChange: (tabId: string) => void;
|
||||
statusFilter?: AgentStatusFilter;
|
||||
onStatusFilterChange?: (filter: AgentStatusFilter) => void;
|
||||
fleetSummary?: FleetSummary;
|
||||
}
|
||||
|
||||
export function LibraryAgentList({
|
||||
@@ -81,6 +85,9 @@ export function LibraryAgentList({
|
||||
tabs,
|
||||
activeTab,
|
||||
onTabChange,
|
||||
statusFilter = "all",
|
||||
onStatusFilterChange,
|
||||
fleetSummary,
|
||||
}: Props) {
|
||||
const shouldReduceMotion = useReducedMotion();
|
||||
const activeContainerVariants = shouldReduceMotion
|
||||
@@ -95,7 +102,8 @@ export function LibraryAgentList({
|
||||
const {
|
||||
isFavoritesTab,
|
||||
agentLoading,
|
||||
allAgentsCount,
|
||||
displayedCount,
|
||||
allAgentIDs,
|
||||
favoritesCount,
|
||||
agents,
|
||||
hasNextPage,
|
||||
@@ -116,18 +124,33 @@ export function LibraryAgentList({
|
||||
selectedFolderId,
|
||||
onFolderSelect,
|
||||
activeTab,
|
||||
statusFilter,
|
||||
});
|
||||
|
||||
return (
|
||||
<>
|
||||
{!selectedFolderId && fleetSummary && (
|
||||
<div className="mb-4">
|
||||
<AgentBriefingPanel
|
||||
summary={fleetSummary}
|
||||
agentIDs={allAgentIDs}
|
||||
onFilterChange={onStatusFilterChange}
|
||||
activeFilter={statusFilter}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!selectedFolderId && (
|
||||
<LibrarySubSection
|
||||
tabs={tabs}
|
||||
activeTab={activeTab}
|
||||
onTabChange={onTabChange}
|
||||
allCount={allAgentsCount}
|
||||
allCount={displayedCount}
|
||||
favoritesCount={favoritesCount}
|
||||
setLibrarySort={setLibrarySort}
|
||||
statusFilter={statusFilter}
|
||||
onStatusFilterChange={onStatusFilterChange}
|
||||
fleetSummary={fleetSummary}
|
||||
/>
|
||||
)}
|
||||
|
||||
|
||||
@@ -22,6 +22,10 @@ import { useFavoriteAgents } from "../../hooks/useFavoriteAgents";
|
||||
import { getQueryClient } from "@/lib/react-query/queryClient";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import type { AgentStatusFilter } from "../../types";
|
||||
import { mockStatusForAgent } from "../../hooks/useAgentStatus";
|
||||
|
||||
const FILTER_EXHAUST_THRESHOLD = 3;
|
||||
|
||||
interface Props {
|
||||
searchTerm: string;
|
||||
@@ -29,6 +33,7 @@ interface Props {
|
||||
selectedFolderId: string | null;
|
||||
onFolderSelect: (folderId: string | null) => void;
|
||||
activeTab: string;
|
||||
statusFilter?: AgentStatusFilter;
|
||||
}
|
||||
|
||||
export function useLibraryAgentList({
|
||||
@@ -37,12 +42,15 @@ export function useLibraryAgentList({
|
||||
selectedFolderId,
|
||||
onFolderSelect,
|
||||
activeTab,
|
||||
statusFilter = "all",
|
||||
}: Props) {
|
||||
const isFavoritesTab = activeTab === "favorites";
|
||||
const { toast } = useToast();
|
||||
const stableQueryClient = getQueryClient();
|
||||
const queryClient = useQueryClient();
|
||||
const prevSortRef = useRef<LibraryAgentSort | null>(null);
|
||||
const consecutiveEmptyPagesRef = useRef(0);
|
||||
const prevFilteredLengthRef = useRef(0);
|
||||
|
||||
const [editingFolder, setEditingFolder] = useState<LibraryFolder | null>(
|
||||
null,
|
||||
@@ -199,6 +207,50 @@ export function useLibraryAgentList({
|
||||
|
||||
const showFolders = !isFavoritesTab;
|
||||
|
||||
// All loaded agent IDs (unfiltered) — used by AgentBriefingPanel so the
|
||||
// sitrep always covers the full fleet, not just the currently filtered view.
|
||||
const allAgentIDs = agents.map((a) => a.id);
|
||||
|
||||
// Client-side filter by status using mock data until the real API supports it.
|
||||
const filteredAgents = filterAgentsByStatus(agents, statusFilter);
|
||||
|
||||
// Track consecutive pages that produced no new filtered items
|
||||
useEffect(() => {
|
||||
if (statusFilter === "all") {
|
||||
consecutiveEmptyPagesRef.current = 0;
|
||||
prevFilteredLengthRef.current = filteredAgents.length;
|
||||
return;
|
||||
}
|
||||
|
||||
const newFilteredCount = filteredAgents.length;
|
||||
const previousCount = prevFilteredLengthRef.current;
|
||||
|
||||
if (newFilteredCount > previousCount) {
|
||||
// New filtered items were added, reset counter
|
||||
consecutiveEmptyPagesRef.current = 0;
|
||||
} else if (!isFetchingNextPage && previousCount > 0) {
|
||||
// No new items and not currently fetching means last fetch was empty
|
||||
consecutiveEmptyPagesRef.current++;
|
||||
}
|
||||
|
||||
prevFilteredLengthRef.current = newFilteredCount;
|
||||
}, [filteredAgents.length, statusFilter, isFetchingNextPage]);
|
||||
|
||||
// Reset counter when statusFilter changes
|
||||
useEffect(() => {
|
||||
consecutiveEmptyPagesRef.current = 0;
|
||||
prevFilteredLengthRef.current = 0;
|
||||
}, [statusFilter]);
|
||||
|
||||
// Derive filteredExhausted: stop fetching when threshold reached
|
||||
const filteredExhausted =
|
||||
statusFilter !== "all" &&
|
||||
consecutiveEmptyPagesRef.current >= FILTER_EXHAUST_THRESHOLD;
|
||||
|
||||
// When a filter is active, show the filtered count instead of the API total.
|
||||
const displayedCount =
|
||||
statusFilter === "all" ? allAgentsCount : filteredAgents.length;
|
||||
|
||||
function handleFolderDeleted() {
|
||||
if (selectedFolderId === deletingFolder?.id) {
|
||||
onFolderSelect(null);
|
||||
@@ -210,9 +262,11 @@ export function useLibraryAgentList({
|
||||
agentLoading,
|
||||
agentCount,
|
||||
allAgentsCount,
|
||||
displayedCount,
|
||||
allAgentIDs,
|
||||
favoritesCount: favoriteAgentsData.agentCount,
|
||||
agents,
|
||||
hasNextPage: agentsHasNextPage,
|
||||
agents: filteredAgents,
|
||||
hasNextPage: agentsHasNextPage && !filteredExhausted,
|
||||
isFetchingNextPage: agentsIsFetchingNextPage,
|
||||
fetchNextPage: agentsFetchNextPage,
|
||||
foldersData,
|
||||
@@ -226,3 +280,16 @@ export function useLibraryAgentList({
|
||||
handleFolderDeleted,
|
||||
};
|
||||
}
|
||||
|
||||
function filterAgentsByStatus<T extends { id: string }>(
|
||||
agents: T[],
|
||||
statusFilter: AgentStatusFilter,
|
||||
): T[] {
|
||||
if (statusFilter === "all") return agents;
|
||||
return agents.filter((agent) => {
|
||||
const info = mockStatusForAgent(agent.id);
|
||||
if (statusFilter === "attention") return info.health === "attention";
|
||||
if (statusFilter === "healthy") return info.health === "good";
|
||||
return info.status === statusFilter;
|
||||
});
|
||||
}
|
||||
@@ -10,6 +10,8 @@ import {
|
||||
} from "./FolderIcon";
|
||||
import { useState } from "react";
|
||||
import { PencilSimpleIcon, TrashIcon } from "@phosphor-icons/react";
|
||||
import type { AgentStatus } from "../../types";
|
||||
import { StatusBadge } from "../StatusBadge/StatusBadge";
|
||||
|
||||
interface Props {
|
||||
id: string;
|
||||
@@ -21,6 +23,8 @@ interface Props {
|
||||
onDelete?: () => void;
|
||||
onAgentDrop?: (agentId: string, folderId: string) => void;
|
||||
onClick?: () => void;
|
||||
/** Worst status among child agents (optional, for status aggregation). */
|
||||
worstStatus?: AgentStatus;
|
||||
}
|
||||
|
||||
export function LibraryFolder({
|
||||
@@ -33,6 +37,7 @@ export function LibraryFolder({
|
||||
onDelete,
|
||||
onAgentDrop,
|
||||
onClick,
|
||||
worstStatus,
|
||||
}: Props) {
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
const [isDragOver, setIsDragOver] = useState(false);
|
||||
@@ -86,13 +91,18 @@ export function LibraryFolder({
|
||||
>
|
||||
{name}
|
||||
</Text>
|
||||
<Text
|
||||
variant="small"
|
||||
className="text-zinc-500"
|
||||
data-testid="library-folder-agent-count"
|
||||
>
|
||||
{agentCount} {agentCount === 1 ? "agent" : "agents"}
|
||||
</Text>
|
||||
<div className="flex items-center gap-2">
|
||||
<Text
|
||||
variant="small"
|
||||
className="text-zinc-500"
|
||||
data-testid="library-folder-agent-count"
|
||||
>
|
||||
{agentCount} {agentCount === 1 ? "agent" : "agents"}
|
||||
</Text>
|
||||
{worstStatus && worstStatus !== "idle" && (
|
||||
<StatusBadge status={worstStatus} />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Right side - Custom folder icon */}
|
||||
|
||||
@@ -6,9 +6,10 @@ import {
|
||||
} from "@/components/molecules/TabsLine/TabsLine";
|
||||
import { LibraryAgentSort } from "@/app/api/__generated__/models/libraryAgentSort";
|
||||
import { useFavoriteAnimation } from "../../context/FavoriteAnimationContext";
|
||||
import { LibraryTab } from "../../types";
|
||||
import type { LibraryTab, AgentStatusFilter, FleetSummary } from "../../types";
|
||||
import LibraryFolderCreationDialog from "../LibraryFolderCreationDialog/LibraryFolderCreationDialog";
|
||||
import { LibrarySortMenu } from "../LibrarySortMenu/LibrarySortMenu";
|
||||
import { AgentFilterMenu } from "../AgentFilterMenu/AgentFilterMenu";
|
||||
|
||||
interface Props {
|
||||
tabs: LibraryTab[];
|
||||
@@ -17,6 +18,9 @@ interface Props {
|
||||
allCount: number;
|
||||
favoritesCount: number;
|
||||
setLibrarySort: (value: LibraryAgentSort) => void;
|
||||
statusFilter?: AgentStatusFilter;
|
||||
onStatusFilterChange?: (filter: AgentStatusFilter) => void;
|
||||
fleetSummary?: FleetSummary;
|
||||
}
|
||||
|
||||
export function LibrarySubSection({
|
||||
@@ -26,6 +30,9 @@ export function LibrarySubSection({
|
||||
allCount,
|
||||
favoritesCount,
|
||||
setLibrarySort,
|
||||
statusFilter = "all",
|
||||
onStatusFilterChange,
|
||||
fleetSummary,
|
||||
}: Props) {
|
||||
const { registerFavoritesTabRef } = useFavoriteAnimation();
|
||||
const favoritesRef = useRef<HTMLButtonElement>(null);
|
||||
@@ -70,6 +77,13 @@ export function LibrarySubSection({
|
||||
</TabsLine>
|
||||
<div className="hidden items-center gap-6 md:flex">
|
||||
<LibraryFolderCreationDialog />
|
||||
{fleetSummary && onStatusFilterChange && (
|
||||
<AgentFilterMenu
|
||||
value={statusFilter}
|
||||
onChange={onStatusFilterChange}
|
||||
summary={fleetSummary}
|
||||
/>
|
||||
)}
|
||||
<LibrarySortMenu setLibrarySort={setLibrarySort} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -0,0 +1,118 @@
|
||||
"use client";
|
||||
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
WarningCircleIcon,
|
||||
PlayIcon,
|
||||
ClockCountdownIcon,
|
||||
CheckCircleIcon,
|
||||
ChatCircleDotsIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import type { AgentStatus } from "../../types";
|
||||
import { ContextualActionButton } from "../ContextualActionButton/ContextualActionButton";
|
||||
|
||||
export type SitrepPriority = "error" | "running" | "stale" | "success";
|
||||
|
||||
export interface SitrepItemData {
|
||||
id: string;
|
||||
agentID: string;
|
||||
agentName: string;
|
||||
priority: SitrepPriority;
|
||||
message: string;
|
||||
status: AgentStatus;
|
||||
}
|
||||
|
||||
interface Props {
|
||||
item: SitrepItemData;
|
||||
onAskAutoPilot?: (prompt: string) => void;
|
||||
}
|
||||
|
||||
const PRIORITY_CONFIG: Record<
|
||||
SitrepPriority,
|
||||
{ icon: typeof WarningCircleIcon; color: string; bg: string }
|
||||
> = {
|
||||
error: {
|
||||
icon: WarningCircleIcon,
|
||||
color: "text-red-500",
|
||||
bg: "bg-red-50",
|
||||
},
|
||||
running: {
|
||||
icon: PlayIcon,
|
||||
color: "text-blue-600",
|
||||
bg: "bg-blue-50",
|
||||
},
|
||||
stale: {
|
||||
icon: ClockCountdownIcon,
|
||||
color: "text-yellow-600",
|
||||
bg: "bg-yellow-50",
|
||||
},
|
||||
success: {
|
||||
icon: CheckCircleIcon,
|
||||
color: "text-green-600",
|
||||
bg: "bg-green-50",
|
||||
},
|
||||
};
|
||||
|
||||
export function SitrepItem({ item, onAskAutoPilot }: Props) {
|
||||
const config = PRIORITY_CONFIG[item.priority];
|
||||
const Icon = config.icon;
|
||||
|
||||
function handleAskAutoPilot() {
|
||||
const prompt = buildAutoPilotPrompt(item);
|
||||
onAskAutoPilot?.(prompt);
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"group flex items-start gap-3 rounded-medium border border-transparent p-3 transition-colors hover:border-zinc-100 hover:bg-zinc-50/50",
|
||||
)}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
"mt-0.5 flex h-6 w-6 flex-shrink-0 items-center justify-center rounded-full",
|
||||
config.bg,
|
||||
)}
|
||||
>
|
||||
<Icon size={14} className={config.color} weight="fill" />
|
||||
</div>
|
||||
|
||||
<div className="min-w-0 flex-1">
|
||||
<Text variant="small-medium" className="text-zinc-900">
|
||||
{item.agentName}
|
||||
</Text>
|
||||
<Text variant="small" className="mt-0.5 text-zinc-500">
|
||||
{item.message}
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-shrink-0 items-center gap-1.5 opacity-0 transition-opacity group-hover:opacity-100">
|
||||
<ContextualActionButton status={item.status} agentID={item.agentID} />
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
onClick={handleAskAutoPilot}
|
||||
leftIcon={<ChatCircleDotsIcon size={14} />}
|
||||
className="text-xs"
|
||||
>
|
||||
Ask AutoPilot
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function buildAutoPilotPrompt(item: SitrepItemData): string {
|
||||
switch (item.priority) {
|
||||
case "error":
|
||||
return `What happened with ${item.agentName}? It says "${item.message}" — can you check the logs and tell me what to fix?`;
|
||||
case "running":
|
||||
return `Give me a status update on the ${item.agentName} run — what has it found so far?`;
|
||||
case "stale":
|
||||
return `${item.agentName} hasn't run recently. Should I keep it or update and re-run it?`;
|
||||
case "success":
|
||||
return `How has ${item.agentName} been performing? Give me a quick summary of recent results.`;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
"use client";
|
||||
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { useSitrepItems } from "./useSitrepItems";
|
||||
import { SitrepItem } from "./SitrepItem";
|
||||
import { useAutoPilotBridge } from "@/contexts/AutoPilotBridgeContext";
|
||||
|
||||
interface Props {
|
||||
agentIDs: string[];
|
||||
maxItems?: number;
|
||||
}
|
||||
|
||||
export function SitrepList({ agentIDs, maxItems = 10 }: Props) {
|
||||
const items = useSitrepItems(agentIDs, maxItems);
|
||||
const { sendPrompt } = useAutoPilotBridge();
|
||||
|
||||
if (items.length === 0) {
|
||||
return (
|
||||
<div className="py-4 text-center">
|
||||
<Text variant="small" className="text-zinc-400">
|
||||
All agents are healthy — nothing to report.
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className="mb-2 flex items-center justify-between">
|
||||
<Text variant="small-medium" className="text-zinc-700">
|
||||
AI Summary
|
||||
</Text>
|
||||
<Text variant="small" className="text-zinc-400">
|
||||
Updated just now
|
||||
</Text>
|
||||
</div>
|
||||
<div className="space-y-1">
|
||||
{items.map((item) => (
|
||||
<SitrepItem key={item.id} item={item} onAskAutoPilot={sendPrompt} />
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
"use client";
|
||||
|
||||
import { useMemo } from "react";
|
||||
import { mockStatusForAgent } from "../../hooks/useAgentStatus";
|
||||
import type { SitrepItemData, SitrepPriority } from "./SitrepItem";
|
||||
import type { AgentStatus } from "../../types";
|
||||
|
||||
/**
|
||||
* Produce a prioritised list of sitrep items from agent IDs.
|
||||
* Priority order: error → running → stale → success.
|
||||
*
|
||||
* TODO: Replace with `GET /agents/sitrep` once the backend endpoint exists.
|
||||
*/
|
||||
export function useSitrepItems(
|
||||
agentIDs: string[],
|
||||
maxItems: number,
|
||||
): SitrepItemData[] {
|
||||
const items = useMemo<SitrepItemData[]>(() => {
|
||||
const raw: SitrepItemData[] = agentIDs.map((id) => {
|
||||
const info = mockStatusForAgent(id);
|
||||
return {
|
||||
id,
|
||||
agentID: id,
|
||||
agentName: `Agent ${id.slice(0, 6)}`,
|
||||
priority: toPriority(info.status, info.health === "stale"),
|
||||
message: buildMessage(info.status, info.lastError, info.progress),
|
||||
status: info.status,
|
||||
};
|
||||
});
|
||||
|
||||
const order: Record<SitrepPriority, number> = {
|
||||
error: 0,
|
||||
running: 1,
|
||||
stale: 2,
|
||||
success: 3,
|
||||
};
|
||||
raw.sort((a, b) => order[a.priority] - order[b.priority]);
|
||||
|
||||
return raw.slice(0, maxItems);
|
||||
}, [agentIDs, maxItems]);
|
||||
|
||||
return items;
|
||||
}
|
||||
|
||||
function toPriority(status: AgentStatus, isStale: boolean): SitrepPriority {
|
||||
if (status === "error") return "error";
|
||||
if (status === "running") return "running";
|
||||
if (isStale || status === "idle") return "stale";
|
||||
return "success";
|
||||
}
|
||||
|
||||
function buildMessage(
|
||||
status: AgentStatus,
|
||||
lastError: string | null,
|
||||
progress: number | null,
|
||||
): string {
|
||||
switch (status) {
|
||||
case "error":
|
||||
return lastError ?? "Unknown error occurred";
|
||||
case "running":
|
||||
return progress !== null
|
||||
? `${progress}% complete`
|
||||
: "Currently executing";
|
||||
case "idle":
|
||||
return "Hasn't run recently — still relevant?";
|
||||
case "listening":
|
||||
return "Waiting for trigger event";
|
||||
case "scheduled":
|
||||
return "Next run scheduled";
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
"use client";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import type { AgentStatus } from "../../types";
|
||||
|
||||
const STATUS_CONFIG: Record<
|
||||
AgentStatus,
|
||||
{ label: string; bg: string; text: string; pulse: boolean }
|
||||
> = {
|
||||
running: {
|
||||
label: "Running",
|
||||
bg: "bg-blue-50",
|
||||
text: "text-blue-600",
|
||||
pulse: true,
|
||||
},
|
||||
error: {
|
||||
label: "Error",
|
||||
bg: "bg-red-50",
|
||||
text: "text-red-500",
|
||||
pulse: false,
|
||||
},
|
||||
listening: {
|
||||
label: "Listening",
|
||||
bg: "bg-purple-50",
|
||||
text: "text-purple-500",
|
||||
pulse: true,
|
||||
},
|
||||
scheduled: {
|
||||
label: "Scheduled",
|
||||
bg: "bg-yellow-50",
|
||||
text: "text-yellow-600",
|
||||
pulse: false,
|
||||
},
|
||||
idle: {
|
||||
label: "Idle",
|
||||
bg: "bg-zinc-100",
|
||||
text: "text-zinc-500",
|
||||
pulse: false,
|
||||
},
|
||||
};
|
||||
|
||||
interface Props {
|
||||
status: AgentStatus;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function StatusBadge({ status, className }: Props) {
|
||||
const config = STATUS_CONFIG[status];
|
||||
|
||||
return (
|
||||
<span
|
||||
className={cn(
|
||||
"inline-flex items-center gap-1.5 rounded-full px-2 py-0.5 text-xs font-medium",
|
||||
config.bg,
|
||||
config.text,
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<span
|
||||
className={cn(
|
||||
"inline-block h-1.5 w-1.5 rounded-full",
|
||||
config.pulse && "animate-pulse",
|
||||
statusDotColor(status),
|
||||
)}
|
||||
/>
|
||||
{config.label}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
function statusDotColor(status: AgentStatus): string {
|
||||
switch (status) {
|
||||
case "running":
|
||||
return "bg-blue-500";
|
||||
case "error":
|
||||
return "bg-red-500";
|
||||
case "listening":
|
||||
return "bg-purple-500";
|
||||
case "scheduled":
|
||||
return "bg-yellow-500";
|
||||
case "idle":
|
||||
return "bg-zinc-400";
|
||||
}
|
||||
}
|
||||
|
||||
export { STATUS_CONFIG };
|
||||
@@ -0,0 +1,16 @@
|
||||
/**
|
||||
* Formats an ISO date string into a human-readable relative time string.
|
||||
* e.g. "3m ago", "2h ago", "5d ago".
|
||||
*/
|
||||
export function formatTimeAgo(isoDate: string): string {
|
||||
const parsed = new Date(isoDate).getTime();
|
||||
if (Number.isNaN(parsed)) return "unknown";
|
||||
const diff = Date.now() - parsed;
|
||||
if (diff < 0) return "just now";
|
||||
const minutes = Math.floor(diff / 60000);
|
||||
if (minutes < 60) return `${minutes}m ago`;
|
||||
const hours = Math.floor(minutes / 60);
|
||||
if (hours < 24) return `${hours}h ago`;
|
||||
const days = Math.floor(hours / 24);
|
||||
return `${days}d ago`;
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
"use client";
|
||||
|
||||
import { useMemo, useState } from "react";
|
||||
import type {
|
||||
AgentStatus,
|
||||
AgentHealth,
|
||||
AgentStatusInfo,
|
||||
FleetSummary,
|
||||
} from "../types";
|
||||
|
||||
/**
|
||||
* Derive health from status and recency.
|
||||
* TODO: Replace with real computation once backend provides the data.
|
||||
*/
|
||||
function deriveHealth(
|
||||
status: AgentStatus,
|
||||
lastRunAt: string | null,
|
||||
): AgentHealth {
|
||||
if (status === "error") return "attention";
|
||||
if (status === "idle" && lastRunAt) {
|
||||
const daysSince =
|
||||
(Date.now() - new Date(lastRunAt).getTime()) / (1000 * 60 * 60 * 24);
|
||||
if (daysSince > 14) return "stale";
|
||||
}
|
||||
return "good";
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate deterministic mock status for an agent based on its ID.
|
||||
* This allows the UI to render realistic data before the real API is built.
|
||||
* TODO: Replace with real API call `GET /agents/:id/status`.
|
||||
*/
|
||||
function mockStatusForAgent(agentID: string): AgentStatusInfo {
|
||||
const hash = simpleHash(agentID);
|
||||
const statuses: AgentStatus[] = [
|
||||
"running",
|
||||
"error",
|
||||
"listening",
|
||||
"scheduled",
|
||||
"idle",
|
||||
];
|
||||
const status = statuses[hash % statuses.length];
|
||||
const progress = status === "running" ? (hash * 17) % 100 : null;
|
||||
const totalRuns = (hash % 200) + 1;
|
||||
const daysAgo = (hash % 30) + 1;
|
||||
const lastRunAt = new Date(
|
||||
Date.now() - daysAgo * 24 * 60 * 60 * 1000,
|
||||
).toISOString();
|
||||
const lastError =
|
||||
status === "error" ? "API rate limit exceeded — paused" : null;
|
||||
const monthlySpend = Number(((hash % 5000) / 100).toFixed(2));
|
||||
|
||||
return {
|
||||
status,
|
||||
health: deriveHealth(status, lastRunAt),
|
||||
progress,
|
||||
totalRuns,
|
||||
lastRunAt,
|
||||
lastError,
|
||||
monthlySpend,
|
||||
nextScheduledRun:
|
||||
status === "scheduled"
|
||||
? new Date(Date.now() + 3600_000).toISOString()
|
||||
: null,
|
||||
triggerType: status === "listening" ? "webhook" : null,
|
||||
};
|
||||
}
|
||||
|
||||
function simpleHash(str: string): number {
|
||||
let h = 0;
|
||||
for (let i = 0; i < str.length; i++) {
|
||||
h = (h * 31 + str.charCodeAt(i)) >>> 0;
|
||||
}
|
||||
return h;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook returning status info for a single agent.
|
||||
* TODO: Wire to `GET /agents/:id/status` + WebSocket `/agents/live`.
|
||||
*/
|
||||
export function useAgentStatus(agentID: string): AgentStatusInfo {
|
||||
// NOTE: useState initializer runs once on mount; a new agentID prop will not
|
||||
// re-derive info. Replace with a real API call wired to the agentID param.
|
||||
const [info] = useState(() => mockStatusForAgent(agentID));
|
||||
return info;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook returning fleet-wide summary counts.
|
||||
* TODO: Wire to `GET /agents/summary`.
|
||||
*/
|
||||
export function useFleetSummary(agentIDs: string[]): FleetSummary {
|
||||
const summary = useMemo<FleetSummary>(() => {
|
||||
const counts: FleetSummary = {
|
||||
running: 0,
|
||||
error: 0,
|
||||
listening: 0,
|
||||
scheduled: 0,
|
||||
idle: 0,
|
||||
monthlySpend: 0,
|
||||
};
|
||||
for (const id of agentIDs) {
|
||||
const info = mockStatusForAgent(id);
|
||||
counts[info.status] += 1;
|
||||
counts.monthlySpend += info.monthlySpend;
|
||||
}
|
||||
counts.monthlySpend = Number(counts.monthlySpend.toFixed(2));
|
||||
return counts;
|
||||
}, [agentIDs]);
|
||||
return summary;
|
||||
}
|
||||
|
||||
export { mockStatusForAgent, deriveHealth };
|
||||
@@ -0,0 +1,25 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import type { FleetSummary } from "../types";
|
||||
|
||||
/**
|
||||
* Returns fleet-wide summary counts for the Agent Briefing Panel.
|
||||
*
|
||||
* TODO: Replace with a real `GET /agents/summary` API call once available.
|
||||
* For now, returns deterministic mock data so the UI renders correctly.
|
||||
*/
|
||||
export function useLibraryFleetSummary(): FleetSummary {
|
||||
// NOTE: useState initializer runs once on mount; the hardcoded mock values
|
||||
// will not update if the component re-renders. Replace with a real API call
|
||||
// once the backend endpoint is available.
|
||||
const [summary] = useState<FleetSummary>(() => ({
|
||||
running: 3,
|
||||
error: 2,
|
||||
listening: 4,
|
||||
scheduled: 5,
|
||||
idle: 8,
|
||||
monthlySpend: 127.45,
|
||||
}));
|
||||
return summary;
|
||||
}
|
||||
@@ -7,7 +7,8 @@ import { LibraryActionHeader } from "./components/LibraryActionHeader/LibraryAct
|
||||
import { LibraryAgentList } from "./components/LibraryAgentList/LibraryAgentList";
|
||||
import { useLibraryListPage } from "./components/useLibraryListPage";
|
||||
import { FavoriteAnimationProvider } from "./context/FavoriteAnimationContext";
|
||||
import { LibraryTab } from "./types";
|
||||
import type { LibraryTab, AgentStatusFilter } from "./types";
|
||||
import { useLibraryFleetSummary } from "./hooks/useLibraryFleetSummary";
|
||||
|
||||
const LIBRARY_TABS: LibraryTab[] = [
|
||||
{ id: "all", title: "All", icon: ListIcon },
|
||||
@@ -19,6 +20,8 @@ export default function LibraryPage() {
|
||||
useLibraryListPage();
|
||||
const [selectedFolderId, setSelectedFolderId] = useState<string | null>(null);
|
||||
const [activeTab, setActiveTab] = useState(LIBRARY_TABS[0].id);
|
||||
const [statusFilter, setStatusFilter] = useState<AgentStatusFilter>("all");
|
||||
const fleetSummary = useLibraryFleetSummary();
|
||||
|
||||
useEffect(() => {
|
||||
document.title = "Library – AutoGPT Platform";
|
||||
@@ -50,6 +53,9 @@ export default function LibraryPage() {
|
||||
tabs={LIBRARY_TABS}
|
||||
activeTab={activeTab}
|
||||
onTabChange={handleTabChange}
|
||||
statusFilter={statusFilter}
|
||||
onStatusFilterChange={setStatusFilter}
|
||||
fleetSummary={fleetSummary}
|
||||
/>
|
||||
</main>
|
||||
</FavoriteAnimationProvider>
|
||||
|
||||
@@ -1,7 +1,52 @@
|
||||
import { Icon } from "@phosphor-icons/react";
|
||||
import type { Icon } from "@phosphor-icons/react";
|
||||
|
||||
export interface LibraryTab {
|
||||
id: string;
|
||||
title: string;
|
||||
icon: Icon;
|
||||
}
|
||||
|
||||
/** Agent execution status — drives StatusBadge visuals & filtering. */
|
||||
export type AgentStatus =
|
||||
| "running"
|
||||
| "error"
|
||||
| "listening"
|
||||
| "scheduled"
|
||||
| "idle";
|
||||
|
||||
/** Derived health bucket for quick triage. */
|
||||
export type AgentHealth = "good" | "attention" | "stale";
|
||||
|
||||
/** Real-time metadata that powers the Intelligence Layer features. */
|
||||
export interface AgentStatusInfo {
|
||||
status: AgentStatus;
|
||||
health: AgentHealth;
|
||||
/** 0-100 progress for currently running agents. */
|
||||
progress: number | null;
|
||||
totalRuns: number;
|
||||
lastRunAt: string | null;
|
||||
lastError: string | null;
|
||||
monthlySpend: number;
|
||||
nextScheduledRun: string | null;
|
||||
triggerType: string | null;
|
||||
}
|
||||
|
||||
/** Fleet-wide aggregate counts used by the Briefing Panel stats grid. */
|
||||
export interface FleetSummary {
|
||||
running: number;
|
||||
error: number;
|
||||
listening: number;
|
||||
scheduled: number;
|
||||
idle: number;
|
||||
monthlySpend: number;
|
||||
}
|
||||
|
||||
/** Filter options for the agent filter dropdown. */
|
||||
export type AgentStatusFilter =
|
||||
| "all"
|
||||
| "running"
|
||||
| "attention"
|
||||
| "listening"
|
||||
| "scheduled"
|
||||
| "idle"
|
||||
| "healthy";
|
||||
|
||||
@@ -198,12 +198,14 @@ export default function UserIntegrationsPage() {
|
||||
</small>
|
||||
</TableCell>
|
||||
<TableCell className="w-0 whitespace-nowrap">
|
||||
<Button
|
||||
variant="destructive"
|
||||
onClick={() => removeCredentials(cred.provider, cred.id)}
|
||||
>
|
||||
<Trash2Icon className="mr-1.5 size-4" /> Delete
|
||||
</Button>
|
||||
{!cred.is_managed && (
|
||||
<Button
|
||||
variant="destructive"
|
||||
onClick={() => removeCredentials(cred.provider, cred.id)}
|
||||
>
|
||||
<Trash2Icon className="mr-1.5 size-4" /> Delete
|
||||
</Button>
|
||||
)}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
|
||||
@@ -5457,290 +5457,6 @@
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/platform-linking/chat/session": {
|
||||
"post": {
|
||||
"tags": ["platform-linking"],
|
||||
"summary": "Create a CoPilot session for a linked user (bot-facing)",
|
||||
"description": "Creates a new CoPilot chat session on behalf of a linked user.",
|
||||
"operationId": "postPlatform-linkingCreate a copilot session for a linked user (bot-facing)",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/BotChatRequest" }
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/BotChatSessionResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/platform-linking/chat/stream": {
|
||||
"post": {
|
||||
"tags": ["platform-linking"],
|
||||
"summary": "Stream a CoPilot response for a linked user (bot-facing)",
|
||||
"description": "Send a message to CoPilot on behalf of a linked user and stream\nthe response back as Server-Sent Events.\n\nThe bot authenticates with its API key — no user JWT needed.",
|
||||
"operationId": "postPlatform-linkingStream a copilot response for a linked user (bot-facing)",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/BotChatRequest" }
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": { "application/json": { "schema": {} } }
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/platform-linking/links": {
|
||||
"get": {
|
||||
"tags": ["platform-linking"],
|
||||
"summary": "List all platform links for the authenticated user",
|
||||
"description": "Returns all platform identities linked to the current user's account.",
|
||||
"operationId": "getPlatform-linkingList all platform links for the authenticated user",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"items": { "$ref": "#/components/schemas/PlatformLinkInfo" },
|
||||
"type": "array",
|
||||
"title": "Response Getplatform-Linkinglist All Platform Links For The Authenticated User"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
}
|
||||
},
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/platform-linking/links/{link_id}": {
|
||||
"delete": {
|
||||
"tags": ["platform-linking"],
|
||||
"summary": "Unlink a platform identity",
|
||||
"description": "Removes a platform link. The user will need to re-link if they\nwant to use the bot on that platform again.",
|
||||
"operationId": "deletePlatform-linkingUnlink a platform identity",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "link_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Link Id" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/DeleteLinkResponse" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/platform-linking/resolve": {
|
||||
"post": {
|
||||
"tags": ["platform-linking"],
|
||||
"summary": "Resolve a platform identity to an AutoGPT user",
|
||||
"description": "Called by the bot service on every incoming message to check if\nthe platform user has a linked AutoGPT account.",
|
||||
"operationId": "postPlatform-linkingResolve a platform identity to an autogpt user",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/ResolveRequest" }
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/ResolveResponse" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/platform-linking/tokens": {
|
||||
"post": {
|
||||
"tags": ["platform-linking"],
|
||||
"summary": "Create a link token for an unlinked platform user",
|
||||
"description": "Called by the bot service when it encounters an unlinked user.\nGenerates a one-time token the user can use to link their account.",
|
||||
"operationId": "postPlatform-linkingCreate a link token for an unlinked platform user",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/CreateLinkTokenRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/LinkTokenResponse" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/platform-linking/tokens/{token}/confirm": {
|
||||
"post": {
|
||||
"tags": ["platform-linking"],
|
||||
"summary": "Confirm a link token (user must be authenticated)",
|
||||
"description": "Called by the frontend when the user clicks the link and is logged in.\nConsumes the token and creates the platform link.\nUses atomic update_many to prevent race conditions on double-click.",
|
||||
"operationId": "postPlatform-linkingConfirm a link token (user must be authenticated)",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "token",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"maxLength": 64,
|
||||
"pattern": "^[A-Za-z0-9_-]+$",
|
||||
"title": "Token"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/ConfirmLinkResponse" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/platform-linking/tokens/{token}/status": {
|
||||
"get": {
|
||||
"tags": ["platform-linking"],
|
||||
"summary": "Check if a link token has been consumed",
|
||||
"description": "Called by the bot service to check if a user has completed linking.",
|
||||
"operationId": "getPlatform-linkingCheck if a link token has been consumed",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "token",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"maxLength": 64,
|
||||
"pattern": "^[A-Za-z0-9_-]+$",
|
||||
"title": "Token"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/LinkTokenStatusResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/public/shared/{share_token}": {
|
||||
"get": {
|
||||
"tags": ["v1"],
|
||||
@@ -7330,6 +7046,16 @@
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Title"
|
||||
},
|
||||
"is_managed": {
|
||||
"type": "boolean",
|
||||
"title": "Is Managed",
|
||||
"default": false
|
||||
},
|
||||
"metadata": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Metadata"
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "api_key",
|
||||
@@ -8596,40 +8322,6 @@
|
||||
"required": ["file"],
|
||||
"title": "Body_postWorkspaceUpload file to workspace"
|
||||
},
|
||||
"BotChatRequest": {
|
||||
"properties": {
|
||||
"user_id": {
|
||||
"type": "string",
|
||||
"title": "User Id",
|
||||
"description": "The linked AutoGPT user ID"
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"maxLength": 32000,
|
||||
"minLength": 1,
|
||||
"title": "Message",
|
||||
"description": "The user's message"
|
||||
},
|
||||
"session_id": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Session Id",
|
||||
"description": "Existing chat session ID. If omitted, a new session is created."
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["user_id", "message"],
|
||||
"title": "BotChatRequest",
|
||||
"description": "Request from the bot to chat as a linked user."
|
||||
},
|
||||
"BotChatSessionResponse": {
|
||||
"properties": {
|
||||
"session_id": { "type": "string", "title": "Session Id" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["session_id"],
|
||||
"title": "BotChatSessionResponse",
|
||||
"description": "Returned when creating a new session via the bot proxy."
|
||||
},
|
||||
"BulkMoveAgentsRequest": {
|
||||
"properties": {
|
||||
"agent_ids": {
|
||||
@@ -8745,25 +8437,6 @@
|
||||
"title": "CoPilotUsageStatus",
|
||||
"description": "Current usage status for a user across all windows."
|
||||
},
|
||||
"ConfirmLinkResponse": {
|
||||
"properties": {
|
||||
"success": { "type": "boolean", "title": "Success" },
|
||||
"platform": { "type": "string", "title": "Platform" },
|
||||
"platform_user_id": { "type": "string", "title": "Platform User Id" },
|
||||
"platform_username": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Platform Username"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"success",
|
||||
"platform",
|
||||
"platform_user_id",
|
||||
"platform_username"
|
||||
],
|
||||
"title": "ConfirmLinkResponse"
|
||||
},
|
||||
"ContentType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -8841,41 +8514,6 @@
|
||||
"required": ["graph"],
|
||||
"title": "CreateGraph"
|
||||
},
|
||||
"CreateLinkTokenRequest": {
|
||||
"properties": {
|
||||
"platform": {
|
||||
"$ref": "#/components/schemas/Platform",
|
||||
"description": "Platform name"
|
||||
},
|
||||
"platform_user_id": {
|
||||
"type": "string",
|
||||
"maxLength": 255,
|
||||
"minLength": 1,
|
||||
"title": "Platform User Id",
|
||||
"description": "The user's ID on the platform"
|
||||
},
|
||||
"platform_username": {
|
||||
"anyOf": [
|
||||
{ "type": "string", "maxLength": 255 },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Platform Username",
|
||||
"description": "Display name (best effort)"
|
||||
},
|
||||
"channel_id": {
|
||||
"anyOf": [
|
||||
{ "type": "string", "maxLength": 255 },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Channel Id",
|
||||
"description": "Channel ID for sending confirmation back"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["platform", "platform_user_id"],
|
||||
"title": "CreateLinkTokenRequest",
|
||||
"description": "Request from the bot service to create a linking token."
|
||||
},
|
||||
"CreateSessionResponse": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
@@ -9033,6 +8671,11 @@
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Host",
|
||||
"description": "Host pattern for host-scoped or MCP server URL for MCP credentials"
|
||||
},
|
||||
"is_managed": {
|
||||
"type": "boolean",
|
||||
"title": "Is Managed",
|
||||
"default": false
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
@@ -9058,12 +8701,6 @@
|
||||
"required": ["version_counts"],
|
||||
"title": "DeleteGraphResponse"
|
||||
},
|
||||
"DeleteLinkResponse": {
|
||||
"properties": { "success": { "type": "boolean", "title": "Success" } },
|
||||
"type": "object",
|
||||
"required": ["success"],
|
||||
"title": "DeleteLinkResponse"
|
||||
},
|
||||
"DiscoverToolsRequest": {
|
||||
"properties": {
|
||||
"server_url": {
|
||||
@@ -10270,6 +9907,16 @@
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Title"
|
||||
},
|
||||
"is_managed": {
|
||||
"type": "boolean",
|
||||
"title": "Is Managed",
|
||||
"default": false
|
||||
},
|
||||
"metadata": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Metadata"
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "host_scoped",
|
||||
@@ -10835,36 +10482,6 @@
|
||||
"required": ["source_id", "sink_id", "source_name", "sink_name"],
|
||||
"title": "Link"
|
||||
},
|
||||
"LinkTokenResponse": {
|
||||
"properties": {
|
||||
"token": { "type": "string", "title": "Token" },
|
||||
"expires_at": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"title": "Expires At"
|
||||
},
|
||||
"link_url": { "type": "string", "title": "Link Url" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["token", "expires_at", "link_url"],
|
||||
"title": "LinkTokenResponse"
|
||||
},
|
||||
"LinkTokenStatusResponse": {
|
||||
"properties": {
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["pending", "linked", "expired"],
|
||||
"title": "Status"
|
||||
},
|
||||
"user_id": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "User Id"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["status"],
|
||||
"title": "LinkTokenStatusResponse"
|
||||
},
|
||||
"ListSessionsResponse": {
|
||||
"properties": {
|
||||
"sessions": {
|
||||
@@ -11393,6 +11010,16 @@
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Title"
|
||||
},
|
||||
"is_managed": {
|
||||
"type": "boolean",
|
||||
"title": "Is Managed",
|
||||
"default": false
|
||||
},
|
||||
"metadata": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Metadata"
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "oauth2",
|
||||
@@ -11428,11 +11055,6 @@
|
||||
"items": { "type": "string" },
|
||||
"type": "array",
|
||||
"title": "Scopes"
|
||||
},
|
||||
"metadata": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Metadata"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
@@ -11704,45 +11326,6 @@
|
||||
"title": "PendingHumanReviewModel",
|
||||
"description": "Response model for pending human review data.\n\nRepresents a human review request that is awaiting user action.\nContains all necessary information for a user to review and approve\nor reject data from a Human-in-the-Loop block execution.\n\nAttributes:\n id: Unique identifier for the review record\n user_id: ID of the user who must perform the review\n node_exec_id: ID of the node execution that created this review\n node_id: ID of the node definition (for grouping reviews from same node)\n graph_exec_id: ID of the graph execution containing the node\n graph_id: ID of the graph template being executed\n graph_version: Version number of the graph template\n payload: The actual data payload awaiting review\n instructions: Instructions or message for the reviewer\n editable: Whether the reviewer can edit the data\n status: Current review status (WAITING, APPROVED, or REJECTED)\n review_message: Optional message from the reviewer\n created_at: Timestamp when review was created\n updated_at: Timestamp when review was last modified\n reviewed_at: Timestamp when review was completed (if applicable)"
|
||||
},
|
||||
"Platform": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"DISCORD",
|
||||
"TELEGRAM",
|
||||
"SLACK",
|
||||
"TEAMS",
|
||||
"WHATSAPP",
|
||||
"GITHUB",
|
||||
"LINEAR"
|
||||
],
|
||||
"title": "Platform",
|
||||
"description": "Supported platform types (mirrors Prisma PlatformType)."
|
||||
},
|
||||
"PlatformLinkInfo": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
"platform": { "type": "string", "title": "Platform" },
|
||||
"platform_user_id": { "type": "string", "title": "Platform User Id" },
|
||||
"platform_username": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Platform Username"
|
||||
},
|
||||
"linked_at": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"title": "Linked At"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"id",
|
||||
"platform",
|
||||
"platform_user_id",
|
||||
"platform_username",
|
||||
"linked_at"
|
||||
],
|
||||
"title": "PlatformLinkInfo"
|
||||
},
|
||||
"PostmarkBounceEnum": {
|
||||
"type": "integer",
|
||||
"enum": [
|
||||
@@ -12265,33 +11848,6 @@
|
||||
"required": ["credit_amount"],
|
||||
"title": "RequestTopUp"
|
||||
},
|
||||
"ResolveRequest": {
|
||||
"properties": {
|
||||
"platform": { "$ref": "#/components/schemas/Platform" },
|
||||
"platform_user_id": {
|
||||
"type": "string",
|
||||
"maxLength": 255,
|
||||
"minLength": 1,
|
||||
"title": "Platform User Id"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["platform", "platform_user_id"],
|
||||
"title": "ResolveRequest",
|
||||
"description": "Resolve a platform identity to an AutoGPT user."
|
||||
},
|
||||
"ResolveResponse": {
|
||||
"properties": {
|
||||
"linked": { "type": "boolean", "title": "Linked" },
|
||||
"user_id": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "User Id"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["linked"],
|
||||
"title": "ResolveResponse"
|
||||
},
|
||||
"ResponseType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -15159,6 +14715,16 @@
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Title"
|
||||
},
|
||||
"is_managed": {
|
||||
"type": "boolean",
|
||||
"title": "Is Managed",
|
||||
"default": false
|
||||
},
|
||||
"metadata": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Metadata"
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "user_password",
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
"use client";
|
||||
|
||||
import { createContext, useContext, useCallback, useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
const STORAGE_KEY = "autopilot_pending_prompt";
|
||||
|
||||
interface AutoPilotBridgeState {
|
||||
/** Pending prompt to be injected into AutoPilot chat. */
|
||||
pendingPrompt: string | null;
|
||||
/** Queue a prompt that the Home/Copilot tab will pick up. */
|
||||
sendPrompt: (prompt: string) => void;
|
||||
/** Consume and clear the pending prompt (called by the chat page). */
|
||||
consumePrompt: () => string | null;
|
||||
}
|
||||
|
||||
const AutoPilotBridgeContext = createContext<AutoPilotBridgeState | null>(null);
|
||||
|
||||
interface Props {
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
export function AutoPilotBridgeProvider({ children }: Props) {
|
||||
const router = useRouter();
|
||||
|
||||
// Hydrate from sessionStorage in case we just navigated here
|
||||
const [pendingPrompt, setPendingPrompt] = useState<string | null>(() => {
|
||||
if (typeof window === "undefined") return null;
|
||||
return sessionStorage.getItem(STORAGE_KEY);
|
||||
});
|
||||
|
||||
const sendPrompt = useCallback(
|
||||
(prompt: string) => {
|
||||
// Persist to sessionStorage so it survives client-side navigation
|
||||
sessionStorage.setItem(STORAGE_KEY, prompt);
|
||||
setPendingPrompt(prompt);
|
||||
// Use Next.js router for client-side navigation (preserves React tree)
|
||||
router.push("/");
|
||||
},
|
||||
[router],
|
||||
);
|
||||
|
||||
const consumePrompt = useCallback((): string | null => {
|
||||
const prompt = pendingPrompt ?? sessionStorage.getItem(STORAGE_KEY);
|
||||
if (prompt !== null) {
|
||||
sessionStorage.removeItem(STORAGE_KEY);
|
||||
setPendingPrompt(null);
|
||||
}
|
||||
return prompt;
|
||||
}, [pendingPrompt]);
|
||||
|
||||
return (
|
||||
<AutoPilotBridgeContext.Provider
|
||||
value={{ pendingPrompt, sendPrompt, consumePrompt }}
|
||||
>
|
||||
{children}
|
||||
</AutoPilotBridgeContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
export function useAutoPilotBridge(): AutoPilotBridgeState {
|
||||
const context = useContext(AutoPilotBridgeContext);
|
||||
if (!context) {
|
||||
// Return a no-op implementation when used outside the provider
|
||||
// (e.g. in tests or isolated component renders).
|
||||
return {
|
||||
pendingPrompt: null,
|
||||
sendPrompt: () => {},
|
||||
consumePrompt: () => null,
|
||||
};
|
||||
}
|
||||
return context;
|
||||
}
|
||||
@@ -614,6 +614,7 @@ export type CredentialsMetaResponse = {
|
||||
username?: string;
|
||||
host?: string;
|
||||
is_system?: boolean;
|
||||
is_managed?: boolean;
|
||||
};
|
||||
|
||||
/* Mirror of backend/api/features/integrations/router.py:CredentialsDeletionResponse */
|
||||
|
||||
@@ -731,6 +731,7 @@ _Add technical explanation here._
|
||||
| max_tokens | The maximum number of tokens to generate in the chat completion. | int | No |
|
||||
| ollama_host | Ollama host for local models | str | No |
|
||||
| agent_mode_max_iterations | Maximum iterations for agent mode. 0 = traditional mode (single LLM call, yield tool calls for external execution), -1 = infinite agent mode (loop until finished), 1+ = agent mode with max iterations limit. | int | No |
|
||||
| execution_mode | How tool calls are executed. 'built_in' uses the default tool-call loop (all providers). 'extended_thinking' delegates to an external Agent SDK for richer reasoning (currently Anthropic / OpenRouter only, requires API credentials, ignores 'Agent Mode Max Iterations'). | "built_in" \| "extended_thinking" | No |
|
||||
| conversation_compaction | Automatically compact the context window once it hits the limit | bool | No |
|
||||
|
||||
### Outputs
|
||||
|
||||
Reference in New Issue
Block a user