mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-13 00:05:02 -05:00
Compare commits
28 Commits
feat/mcp-b
...
pwuts/open
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
746a36822d | ||
|
|
d95aef7665 | ||
|
|
2a46d3fbf4 | ||
|
|
ab25516a46 | ||
|
|
6e2f595c7d | ||
|
|
e523eb62b5 | ||
|
|
97ff65ef6a | ||
|
|
e8b81f71ef | ||
|
|
d652821ed5 | ||
|
|
80659d90e4 | ||
|
|
eef892893c | ||
|
|
23175708e6 | ||
|
|
f02c00374e | ||
|
|
2fa166d839 | ||
|
|
d927e4b611 | ||
|
|
6591b2171c | ||
|
|
85d97a9d5c | ||
|
|
16c8b2a6e3 | ||
|
|
cad54a9f3e | ||
|
|
ca0620b102 | ||
|
|
7a4cf4e186 | ||
|
|
fe9debd80f | ||
|
|
7083dcf226 | ||
|
|
ee2805d14c | ||
|
|
f15362d619 | ||
|
|
6c2374593f | ||
|
|
0f4c33308f | ||
|
|
ecb9fdae25 |
@@ -1,4 +1,9 @@
|
||||
"""Common test fixtures for server tests."""
|
||||
"""Common test fixtures for server tests.
|
||||
|
||||
Note: Common fixtures like test_user_id, admin_user_id, target_user_id,
|
||||
setup_test_user, and setup_admin_user are defined in the parent conftest.py
|
||||
(backend/conftest.py) and are available here automatically.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
@@ -11,54 +16,6 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
|
||||
return snapshot
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_id() -> str:
|
||||
"""Test user ID fixture."""
|
||||
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user_id() -> str:
|
||||
"""Admin user ID fixture."""
|
||||
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_user_id() -> str:
|
||||
"""Target user ID fixture."""
|
||||
return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_test_user(test_user_id):
|
||||
"""Create test user in database before tests."""
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Create the test user in the database using JWT token format
|
||||
user_data = {
|
||||
"sub": test_user_id,
|
||||
"email": "test@example.com",
|
||||
"user_metadata": {"name": "Test User"},
|
||||
}
|
||||
await get_or_create_user(user_data)
|
||||
return test_user_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_admin_user(admin_user_id):
|
||||
"""Create admin user in database before tests."""
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Create the admin user in the database using JWT token format
|
||||
user_data = {
|
||||
"sub": admin_user_id,
|
||||
"email": "test-admin@example.com",
|
||||
"user_metadata": {"name": "Test Admin"},
|
||||
}
|
||||
await get_or_create_user(user_data)
|
||||
return admin_user_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_jwt_user(test_user_id):
|
||||
"""Provide mock JWT payload for regular user testing."""
|
||||
|
||||
@@ -15,9 +15,9 @@ from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
|
||||
from backend.api.features.chat.tools.models import ToolResponseBase
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools import find_agent_tool, run_agent_tool
|
||||
from backend.copilot.tools.models import ToolResponseBase
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -10,15 +10,22 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response,
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||
from .response_model import StreamFinish, StreamHeartbeat
|
||||
from .tools.models import (
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.completion_handler import (
|
||||
process_operation_failure,
|
||||
process_operation_success,
|
||||
)
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.executor.utils import enqueue_copilot_task
|
||||
from backend.copilot.model import (
|
||||
ChatSession,
|
||||
create_chat_session,
|
||||
get_chat_session,
|
||||
get_user_sessions,
|
||||
)
|
||||
from backend.copilot.response_model import StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.models import (
|
||||
AgentDetailsResponse,
|
||||
AgentOutputResponse,
|
||||
AgentPreviewResponse,
|
||||
@@ -40,6 +47,7 @@ from .tools.models import (
|
||||
SetupRequirementsResponse,
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
@@ -301,7 +309,7 @@ async def stream_chat_post(
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
_session = await _validate_and_get_session(session_id, user_id) # noqa: F841
|
||||
logger.info(
|
||||
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
|
||||
extra={
|
||||
@@ -336,82 +344,20 @@ async def stream_chat_post(
|
||||
},
|
||||
)
|
||||
|
||||
# Background task that runs the AI generation independently of SSE connection
|
||||
async def run_ai_generation():
|
||||
import time as time_module
|
||||
# Enqueue the task to RabbitMQ for processing by the CoPilot executor
|
||||
await enqueue_copilot_task(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
operation_id=operation_id,
|
||||
message=request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
context=request.context,
|
||||
)
|
||||
|
||||
gen_start_time = time_module.perf_counter()
|
||||
logger.info(
|
||||
f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
first_chunk_time, ttfc = None, None
|
||||
chunk_count = 0
|
||||
try:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
_task_id=task_id, # Pass task_id so service emits start with taskId for reconnection
|
||||
):
|
||||
chunk_count += 1
|
||||
if first_chunk_time is None:
|
||||
first_chunk_time = time_module.perf_counter()
|
||||
ttfc = first_chunk_time - gen_start_time
|
||||
logger.info(
|
||||
f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunk_type": type(chunk).__name__,
|
||||
"time_to_first_chunk_ms": ttfc * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
# Write to Redis (subscribers will receive via XREAD)
|
||||
await stream_registry.publish_chunk(task_id, chunk)
|
||||
|
||||
gen_end_time = time_module.perf_counter()
|
||||
total_time = (gen_end_time - gen_start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
|
||||
f"task={task_id}, session={session_id}, "
|
||||
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time,
|
||||
"time_to_first_chunk_ms": (
|
||||
ttfc * 1000 if ttfc is not None else None
|
||||
),
|
||||
"n_chunks": chunk_count,
|
||||
}
|
||||
},
|
||||
)
|
||||
await stream_registry.mark_task_completed(task_id, "completed")
|
||||
except Exception as e:
|
||||
elapsed = time_module.perf_counter() - gen_start_time
|
||||
logger.error(
|
||||
f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"elapsed_ms": elapsed * 1000,
|
||||
"error": str(e),
|
||||
}
|
||||
},
|
||||
)
|
||||
await stream_registry.mark_task_completed(task_id, "failed")
|
||||
|
||||
# Start the AI generation in a background task
|
||||
bg_task = asyncio.create_task(run_ai_generation())
|
||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Background task started, setup={setup_time:.1f}ms",
|
||||
f"[TIMING] Task enqueued to RabbitMQ, setup={setup_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Annotated, Any, List, Literal
|
||||
from typing import TYPE_CHECKING, Annotated, List, Literal
|
||||
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from fastapi import (
|
||||
@@ -14,7 +14,7 @@ from fastapi import (
|
||||
Security,
|
||||
status,
|
||||
)
|
||||
from pydantic import BaseModel, Field, SecretStr, model_validator
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
|
||||
|
||||
from backend.api.features.library.db import set_preset_webhook, update_preset
|
||||
@@ -39,11 +39,7 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
||||
from backend.data.user import get_user_integrations
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||
from backend.integrations.credentials_store import provider_matches
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
create_mcp_oauth_handler,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
@@ -106,37 +102,9 @@ class CredentialsMetaResponse(BaseModel):
|
||||
scopes: list[str] | None
|
||||
username: str | None
|
||||
host: str | None = Field(
|
||||
default=None,
|
||||
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
|
||||
default=None, description="Host pattern for host-scoped credentials"
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_provider(cls, data: Any) -> Any:
|
||||
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug."""
|
||||
if isinstance(data, dict):
|
||||
prov = data.get("provider", "")
|
||||
if isinstance(prov, str) and prov.startswith("ProviderName."):
|
||||
member = prov.removeprefix("ProviderName.")
|
||||
try:
|
||||
data = {**data, "provider": ProviderName[member].value}
|
||||
except KeyError:
|
||||
pass
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def get_host(cred: Credentials) -> str | None:
|
||||
"""Extract host from credential: HostScoped host or MCP server URL."""
|
||||
if isinstance(cred, HostScopedCredentials):
|
||||
return cred.host
|
||||
if isinstance(cred, OAuth2Credentials) and cred.provider in (
|
||||
ProviderName.MCP,
|
||||
ProviderName.MCP.value,
|
||||
"ProviderName.MCP",
|
||||
):
|
||||
return (cred.metadata or {}).get("mcp_server_url")
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
|
||||
async def callback(
|
||||
@@ -211,7 +179,9 @@ async def callback(
|
||||
title=credentials.title,
|
||||
scopes=credentials.scopes,
|
||||
username=credentials.username,
|
||||
host=(CredentialsMetaResponse.get_host(credentials)),
|
||||
host=(
|
||||
credentials.host if isinstance(credentials, HostScopedCredentials) else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -229,7 +199,7 @@ async def list_credentials(
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
]
|
||||
@@ -252,7 +222,7 @@ async def list_credentials_by_provider(
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
]
|
||||
@@ -352,11 +322,7 @@ async def delete_credentials(
|
||||
|
||||
tokens_revoked = None
|
||||
if isinstance(creds, OAuth2Credentials):
|
||||
if provider_matches(provider.value, ProviderName.MCP.value):
|
||||
# MCP uses dynamic per-server OAuth — create handler from metadata
|
||||
handler = create_mcp_oauth_handler(creds)
|
||||
else:
|
||||
handler = _get_provider_oauth_handler(request, provider)
|
||||
handler = _get_provider_oauth_handler(request, provider)
|
||||
tokens_revoked = await handler.revoke_tokens(creds)
|
||||
|
||||
return CredentialsDeletionResponse(revoked=tokens_revoked)
|
||||
|
||||
@@ -1,404 +0,0 @@
|
||||
"""
|
||||
MCP (Model Context Protocol) API routes.
|
||||
|
||||
Provides endpoints for MCP tool discovery and OAuth authentication so the
|
||||
frontend can list available tools on an MCP server before placing a block.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import fastapi
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from fastapi import Security
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.api.features.integrations.router import CredentialsMetaResponse
|
||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import HTTPClientError, Requests
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
settings = Settings()
|
||||
router = fastapi.APIRouter(tags=["mcp"])
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
# ====================== Tool Discovery ====================== #
|
||||
|
||||
|
||||
class DiscoverToolsRequest(BaseModel):
|
||||
"""Request to discover tools on an MCP server."""
|
||||
|
||||
server_url: str = Field(description="URL of the MCP server")
|
||||
auth_token: str | None = Field(
|
||||
default=None,
|
||||
description="Optional Bearer token for authenticated MCP servers",
|
||||
)
|
||||
|
||||
|
||||
class MCPToolResponse(BaseModel):
|
||||
"""A single MCP tool returned by discovery."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
input_schema: dict[str, Any]
|
||||
|
||||
|
||||
class DiscoverToolsResponse(BaseModel):
|
||||
"""Response containing the list of tools available on an MCP server."""
|
||||
|
||||
tools: list[MCPToolResponse]
|
||||
server_name: str | None = None
|
||||
protocol_version: str | None = None
|
||||
|
||||
|
||||
@router.post(
|
||||
"/discover-tools",
|
||||
summary="Discover available tools on an MCP server",
|
||||
response_model=DiscoverToolsResponse,
|
||||
)
|
||||
async def discover_tools(
|
||||
request: DiscoverToolsRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> DiscoverToolsResponse:
|
||||
"""
|
||||
Connect to an MCP server and return its available tools.
|
||||
|
||||
If the user has a stored MCP credential for this server URL, it will be
|
||||
used automatically — no need to pass an explicit auth token.
|
||||
"""
|
||||
auth_token = request.auth_token
|
||||
|
||||
# Auto-use stored MCP credential when no explicit token is provided.
|
||||
if not auth_token:
|
||||
mcp_creds = await creds_manager.store.get_creds_by_provider(
|
||||
user_id, ProviderName.MCP.value
|
||||
)
|
||||
# Find the freshest credential for this server URL
|
||||
best_cred: OAuth2Credentials | None = None
|
||||
for cred in mcp_creds:
|
||||
if (
|
||||
isinstance(cred, OAuth2Credentials)
|
||||
and (cred.metadata or {}).get("mcp_server_url") == request.server_url
|
||||
):
|
||||
if best_cred is None or (
|
||||
(cred.access_token_expires_at or 0)
|
||||
> (best_cred.access_token_expires_at or 0)
|
||||
):
|
||||
best_cred = cred
|
||||
if best_cred:
|
||||
# Refresh the token if expired before using it
|
||||
best_cred = await creds_manager.refresh_if_needed(user_id, best_cred)
|
||||
logger.info(
|
||||
f"Using MCP credential {best_cred.id} for {request.server_url}, "
|
||||
f"expires_at={best_cred.access_token_expires_at}"
|
||||
)
|
||||
auth_token = best_cred.access_token.get_secret_value()
|
||||
|
||||
client = MCPClient(request.server_url, auth_token=auth_token)
|
||||
|
||||
try:
|
||||
init_result = await client.initialize()
|
||||
tools = await client.list_tools()
|
||||
except HTTPClientError as e:
|
||||
if e.status_code in (401, 403):
|
||||
raise fastapi.HTTPException(
|
||||
status_code=401,
|
||||
detail="This MCP server requires authentication. "
|
||||
"Please provide a valid auth token.",
|
||||
)
|
||||
raise fastapi.HTTPException(status_code=502, detail=str(e))
|
||||
except MCPClientError as e:
|
||||
raise fastapi.HTTPException(status_code=502, detail=str(e))
|
||||
except Exception as e:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Failed to connect to MCP server: {e}",
|
||||
)
|
||||
|
||||
return DiscoverToolsResponse(
|
||||
tools=[
|
||||
MCPToolResponse(
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
input_schema=t.input_schema,
|
||||
)
|
||||
for t in tools
|
||||
],
|
||||
server_name=(
|
||||
init_result.get("serverInfo", {}).get("name")
|
||||
or urlparse(request.server_url).hostname
|
||||
or "MCP"
|
||||
),
|
||||
protocol_version=init_result.get("protocolVersion"),
|
||||
)
|
||||
|
||||
|
||||
# ======================== OAuth Flow ======================== #
|
||||
|
||||
|
||||
class MCPOAuthLoginRequest(BaseModel):
|
||||
"""Request to start an OAuth flow for an MCP server."""
|
||||
|
||||
server_url: str = Field(description="URL of the MCP server that requires OAuth")
|
||||
|
||||
|
||||
class MCPOAuthLoginResponse(BaseModel):
|
||||
"""Response with the OAuth login URL for the user to authenticate."""
|
||||
|
||||
login_url: str
|
||||
state_token: str
|
||||
|
||||
|
||||
@router.post(
|
||||
"/oauth/login",
|
||||
summary="Initiate OAuth login for an MCP server",
|
||||
)
|
||||
async def mcp_oauth_login(
|
||||
request: MCPOAuthLoginRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> MCPOAuthLoginResponse:
|
||||
"""
|
||||
Discover OAuth metadata from the MCP server and return a login URL.
|
||||
|
||||
1. Discovers the protected-resource metadata (RFC 9728)
|
||||
2. Fetches the authorization server metadata (RFC 8414)
|
||||
3. Performs Dynamic Client Registration (RFC 7591) if available
|
||||
4. Returns the authorization URL for the frontend to open in a popup
|
||||
"""
|
||||
client = MCPClient(request.server_url)
|
||||
|
||||
# Step 1: Discover protected-resource metadata (RFC 9728)
|
||||
protected_resource = await client.discover_auth()
|
||||
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
if protected_resource and protected_resource.get("authorization_servers"):
|
||||
auth_server_url = protected_resource["authorization_servers"][0]
|
||||
resource_url = protected_resource.get("resource", request.server_url)
|
||||
|
||||
# Step 2a: Discover auth-server metadata (RFC 8414)
|
||||
metadata = await client.discover_auth_server_metadata(auth_server_url)
|
||||
else:
|
||||
# Fallback: Some MCP servers (e.g. Linear) are their own auth server
|
||||
# and serve OAuth metadata directly without protected-resource metadata.
|
||||
# Don't assume a resource_url — omitting it lets the auth server choose
|
||||
# the correct audience for the token (RFC 8707 resource is optional).
|
||||
resource_url = None
|
||||
metadata = await client.discover_auth_server_metadata(request.server_url)
|
||||
|
||||
if (
|
||||
not metadata
|
||||
or "authorization_endpoint" not in metadata
|
||||
or "token_endpoint" not in metadata
|
||||
):
|
||||
raise fastapi.HTTPException(
|
||||
status_code=400,
|
||||
detail="This MCP server does not advertise OAuth support. "
|
||||
"You may need to provide an auth token manually.",
|
||||
)
|
||||
|
||||
authorize_url = metadata["authorization_endpoint"]
|
||||
token_url = metadata["token_endpoint"]
|
||||
registration_endpoint = metadata.get("registration_endpoint")
|
||||
revoke_url = metadata.get("revocation_endpoint")
|
||||
|
||||
# Step 3: Dynamic Client Registration (RFC 7591) if available
|
||||
frontend_base_url = settings.config.frontend_base_url
|
||||
if not frontend_base_url:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Frontend base URL is not configured.",
|
||||
)
|
||||
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
|
||||
|
||||
client_id = ""
|
||||
client_secret = ""
|
||||
if registration_endpoint:
|
||||
reg_result = await _register_mcp_client(
|
||||
registration_endpoint, redirect_uri, request.server_url
|
||||
)
|
||||
if reg_result:
|
||||
client_id = reg_result.get("client_id", "")
|
||||
client_secret = reg_result.get("client_secret", "")
|
||||
|
||||
if not client_id:
|
||||
client_id = "autogpt-platform"
|
||||
|
||||
# Step 4: Store state token with OAuth metadata for the callback
|
||||
scopes = (protected_resource or {}).get("scopes_supported") or metadata.get(
|
||||
"scopes_supported", []
|
||||
)
|
||||
state_token, code_challenge = await creds_manager.store.store_state_token(
|
||||
user_id,
|
||||
ProviderName.MCP.value,
|
||||
scopes,
|
||||
state_metadata={
|
||||
"authorize_url": authorize_url,
|
||||
"token_url": token_url,
|
||||
"revoke_url": revoke_url,
|
||||
"resource_url": resource_url,
|
||||
"server_url": request.server_url,
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
},
|
||||
)
|
||||
|
||||
# Step 5: Build and return the login URL
|
||||
handler = MCPOAuthHandler(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
redirect_uri=redirect_uri,
|
||||
authorize_url=authorize_url,
|
||||
token_url=token_url,
|
||||
resource_url=resource_url,
|
||||
)
|
||||
login_url = handler.get_login_url(
|
||||
scopes, state_token, code_challenge=code_challenge
|
||||
)
|
||||
|
||||
return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token)
|
||||
|
||||
|
||||
class MCPOAuthCallbackRequest(BaseModel):
|
||||
"""Request to exchange an OAuth code for tokens."""
|
||||
|
||||
code: str = Field(description="Authorization code from OAuth callback")
|
||||
state_token: str = Field(description="State token for CSRF verification")
|
||||
|
||||
|
||||
class MCPOAuthCallbackResponse(BaseModel):
|
||||
"""Response after successfully storing OAuth credentials."""
|
||||
|
||||
credential_id: str
|
||||
|
||||
|
||||
@router.post(
|
||||
"/oauth/callback",
|
||||
summary="Exchange OAuth code for MCP tokens",
|
||||
)
|
||||
async def mcp_oauth_callback(
|
||||
request: MCPOAuthCallbackRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> CredentialsMetaResponse:
|
||||
"""
|
||||
Exchange the authorization code for tokens and store the credential.
|
||||
|
||||
The frontend calls this after receiving the OAuth code from the popup.
|
||||
On success, subsequent ``/discover-tools`` calls for the same server URL
|
||||
will automatically use the stored credential.
|
||||
"""
|
||||
valid_state = await creds_manager.store.verify_state_token(
|
||||
user_id, request.state_token, ProviderName.MCP.value
|
||||
)
|
||||
if not valid_state:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid or expired state token.",
|
||||
)
|
||||
|
||||
meta = valid_state.state_metadata
|
||||
frontend_base_url = settings.config.frontend_base_url
|
||||
if not frontend_base_url:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Frontend base URL is not configured.",
|
||||
)
|
||||
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
|
||||
|
||||
handler = MCPOAuthHandler(
|
||||
client_id=meta["client_id"],
|
||||
client_secret=meta.get("client_secret", ""),
|
||||
redirect_uri=redirect_uri,
|
||||
authorize_url=meta["authorize_url"],
|
||||
token_url=meta["token_url"],
|
||||
revoke_url=meta.get("revoke_url"),
|
||||
resource_url=meta.get("resource_url"),
|
||||
)
|
||||
|
||||
try:
|
||||
credentials = await handler.exchange_code_for_tokens(
|
||||
request.code, valid_state.scopes, valid_state.code_verifier
|
||||
)
|
||||
except Exception as e:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=400,
|
||||
detail=f"OAuth token exchange failed: {e}",
|
||||
)
|
||||
|
||||
# Enrich credential metadata for future lookup and token refresh
|
||||
if credentials.metadata is None:
|
||||
credentials.metadata = {}
|
||||
credentials.metadata["mcp_server_url"] = meta["server_url"]
|
||||
credentials.metadata["mcp_client_id"] = meta["client_id"]
|
||||
credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "")
|
||||
credentials.metadata["mcp_token_url"] = meta["token_url"]
|
||||
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
|
||||
|
||||
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
|
||||
credentials.title = f"MCP: {hostname}"
|
||||
|
||||
# Remove old MCP credentials for the same server to prevent stale token buildup.
|
||||
try:
|
||||
old_creds = await creds_manager.store.get_creds_by_provider(
|
||||
user_id, ProviderName.MCP.value
|
||||
)
|
||||
for old in old_creds:
|
||||
if (
|
||||
isinstance(old, OAuth2Credentials)
|
||||
and (old.metadata or {}).get("mcp_server_url") == meta["server_url"]
|
||||
):
|
||||
await creds_manager.store.delete_creds_by_id(user_id, old.id)
|
||||
logger.info(
|
||||
f"Removed old MCP credential {old.id} for {meta['server_url']}"
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Could not clean up old MCP credentials", exc_info=True)
|
||||
|
||||
await creds_manager.create(user_id, credentials)
|
||||
|
||||
return CredentialsMetaResponse(
|
||||
id=credentials.id,
|
||||
provider=credentials.provider,
|
||||
type=credentials.type,
|
||||
title=credentials.title,
|
||||
scopes=credentials.scopes,
|
||||
username=credentials.username,
|
||||
host=credentials.metadata.get("mcp_server_url"),
|
||||
)
|
||||
|
||||
|
||||
# ======================== Helpers ======================== #
|
||||
|
||||
|
||||
async def _register_mcp_client(
|
||||
registration_endpoint: str,
|
||||
redirect_uri: str,
|
||||
server_url: str,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server."""
|
||||
try:
|
||||
response = await Requests(raise_for_status=True).post(
|
||||
registration_endpoint,
|
||||
json={
|
||||
"client_name": "AutoGPT Platform",
|
||||
"redirect_uris": [redirect_uri],
|
||||
"grant_types": ["authorization_code"],
|
||||
"response_types": ["code"],
|
||||
"token_endpoint_auth_method": "client_secret_post",
|
||||
},
|
||||
)
|
||||
data = response.json()
|
||||
if isinstance(data, dict) and "client_id" in data:
|
||||
return data
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
|
||||
return None
|
||||
@@ -1,436 +0,0 @@
|
||||
"""Tests for MCP API routes.
|
||||
|
||||
Uses httpx.AsyncClient with ASGITransport instead of fastapi.testclient.TestClient
|
||||
to avoid creating blocking portals that can corrupt pytest-asyncio's session event loop.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import fastapi
|
||||
import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from autogpt_libs.auth import get_user_id
|
||||
|
||||
from backend.api.features.mcp.routes import router
|
||||
from backend.blocks.mcp.client import MCPClientError, MCPTool
|
||||
from backend.util.request import HTTPClientError
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
app.dependency_overrides[get_user_id] = lambda: "test-user-id"
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="module")
|
||||
async def client():
|
||||
transport = httpx.ASGITransport(app=app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://test") as c:
|
||||
yield c
|
||||
|
||||
|
||||
class TestDiscoverTools:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_success(self, client):
|
||||
mock_tools = [
|
||||
MCPTool(
|
||||
name="get_weather",
|
||||
description="Get weather for a city",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
),
|
||||
MCPTool(
|
||||
name="add_numbers",
|
||||
description="Add two numbers",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "number"},
|
||||
"b": {"type": "number"},
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
return_value={
|
||||
"protocolVersion": "2025-03-26",
|
||||
"serverInfo": {"name": "test-server"},
|
||||
}
|
||||
)
|
||||
instance.list_tools = AsyncMock(return_value=mock_tools)
|
||||
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "https://mcp.example.com/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["tools"]) == 2
|
||||
assert data["tools"][0]["name"] == "get_weather"
|
||||
assert data["tools"][1]["name"] == "add_numbers"
|
||||
assert data["server_name"] == "test-server"
|
||||
assert data["protocol_version"] == "2025-03-26"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_with_auth_token(self, client):
|
||||
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
|
||||
)
|
||||
instance.list_tools = AsyncMock(return_value=[])
|
||||
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={
|
||||
"server_url": "https://mcp.example.com/mcp",
|
||||
"auth_token": "my-secret-token",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
MockClient.assert_called_once_with(
|
||||
"https://mcp.example.com/mcp",
|
||||
auth_token="my-secret-token",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_auto_uses_stored_credential(self, client):
|
||||
"""When no explicit token is given, stored MCP credentials are used."""
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
stored_cred = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
title="MCP: example.com",
|
||||
access_token=SecretStr("stored-token-123"),
|
||||
refresh_token=None,
|
||||
access_token_expires_at=None,
|
||||
refresh_token_expires_at=None,
|
||||
scopes=[],
|
||||
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
|
||||
)
|
||||
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
|
||||
mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred)
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
|
||||
)
|
||||
instance.list_tools = AsyncMock(return_value=[])
|
||||
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "https://mcp.example.com/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
MockClient.assert_called_once_with(
|
||||
"https://mcp.example.com/mcp",
|
||||
auth_token="stored-token-123",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_mcp_error(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
side_effect=MCPClientError("Connection refused")
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "https://bad-server.example.com/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 502
|
||||
assert "Connection refused" in response.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_generic_error(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
|
||||
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "https://timeout.example.com/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 502
|
||||
assert "Failed to connect" in response.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_auth_required(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "https://auth-server.example.com/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert "requires authentication" in response.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_forbidden(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "https://auth-server.example.com/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert "requires authentication" in response.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_missing_url(self, client):
|
||||
response = await client.post("/discover-tools", json={})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
class TestOAuthLogin:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_login_success(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
||||
patch(
|
||||
"backend.api.features.mcp.routes._register_mcp_client"
|
||||
) as mock_register,
|
||||
):
|
||||
instance = MockClient.return_value
|
||||
instance.discover_auth = AsyncMock(
|
||||
return_value={
|
||||
"authorization_servers": ["https://auth.sentry.io"],
|
||||
"resource": "https://mcp.sentry.dev/mcp",
|
||||
"scopes_supported": ["openid"],
|
||||
}
|
||||
)
|
||||
instance.discover_auth_server_metadata = AsyncMock(
|
||||
return_value={
|
||||
"authorization_endpoint": "https://auth.sentry.io/authorize",
|
||||
"token_endpoint": "https://auth.sentry.io/token",
|
||||
"registration_endpoint": "https://auth.sentry.io/register",
|
||||
}
|
||||
)
|
||||
mock_register.return_value = {
|
||||
"client_id": "registered-client-id",
|
||||
"client_secret": "registered-secret",
|
||||
}
|
||||
mock_cm.store.store_state_token = AsyncMock(
|
||||
return_value=("state-token-123", "code-challenge-abc")
|
||||
)
|
||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
||||
|
||||
response = await client.post(
|
||||
"/oauth/login",
|
||||
json={"server_url": "https://mcp.sentry.dev/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "login_url" in data
|
||||
assert data["state_token"] == "state-token-123"
|
||||
assert "auth.sentry.io/authorize" in data["login_url"]
|
||||
assert "registered-client-id" in data["login_url"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_login_no_oauth_support(self, client):
|
||||
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
||||
instance = MockClient.return_value
|
||||
instance.discover_auth = AsyncMock(return_value=None)
|
||||
instance.discover_auth_server_metadata = AsyncMock(return_value=None)
|
||||
|
||||
response = await client.post(
|
||||
"/oauth/login",
|
||||
json={"server_url": "https://simple-server.example.com/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "does not advertise OAuth" in response.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_login_fallback_to_public_client(self, client):
|
||||
"""When DCR is unavailable, falls back to default public client ID."""
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
||||
):
|
||||
instance = MockClient.return_value
|
||||
instance.discover_auth = AsyncMock(
|
||||
return_value={
|
||||
"authorization_servers": ["https://auth.example.com"],
|
||||
"resource": "https://mcp.example.com/mcp",
|
||||
}
|
||||
)
|
||||
instance.discover_auth_server_metadata = AsyncMock(
|
||||
return_value={
|
||||
"authorization_endpoint": "https://auth.example.com/authorize",
|
||||
"token_endpoint": "https://auth.example.com/token",
|
||||
# No registration_endpoint
|
||||
}
|
||||
)
|
||||
mock_cm.store.store_state_token = AsyncMock(
|
||||
return_value=("state-abc", "challenge-xyz")
|
||||
)
|
||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
||||
|
||||
response = await client.post(
|
||||
"/oauth/login",
|
||||
json={"server_url": "https://mcp.example.com/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "autogpt-platform" in data["login_url"]
|
||||
|
||||
|
||||
class TestOAuthCallback:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_callback_success(self, client):
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
mock_creds = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
title=None,
|
||||
access_token=SecretStr("access-token-xyz"),
|
||||
refresh_token=None,
|
||||
access_token_expires_at=None,
|
||||
refresh_token_expires_at=None,
|
||||
scopes=[],
|
||||
metadata={
|
||||
"mcp_token_url": "https://auth.sentry.io/token",
|
||||
"mcp_resource_url": "https://mcp.sentry.dev/mcp",
|
||||
},
|
||||
)
|
||||
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
||||
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
|
||||
):
|
||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
||||
|
||||
# Mock state verification
|
||||
mock_state = AsyncMock()
|
||||
mock_state.state_metadata = {
|
||||
"authorize_url": "https://auth.sentry.io/authorize",
|
||||
"token_url": "https://auth.sentry.io/token",
|
||||
"client_id": "test-client-id",
|
||||
"client_secret": "test-secret",
|
||||
"server_url": "https://mcp.sentry.dev/mcp",
|
||||
}
|
||||
mock_state.scopes = ["openid"]
|
||||
mock_state.code_verifier = "verifier-123"
|
||||
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
|
||||
mock_cm.create = AsyncMock()
|
||||
|
||||
handler_instance = MockHandler.return_value
|
||||
handler_instance.exchange_code_for_tokens = AsyncMock(
|
||||
return_value=mock_creds
|
||||
)
|
||||
|
||||
# Mock old credential cleanup
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
|
||||
response = await client.post(
|
||||
"/oauth/callback",
|
||||
json={"code": "auth-code-abc", "state_token": "state-token-123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "id" in data
|
||||
assert data["provider"] == "mcp"
|
||||
assert data["type"] == "oauth2"
|
||||
mock_cm.create.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_callback_invalid_state(self, client):
|
||||
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
|
||||
mock_cm.store.verify_state_token = AsyncMock(return_value=None)
|
||||
|
||||
response = await client.post(
|
||||
"/oauth/callback",
|
||||
json={"code": "auth-code", "state_token": "bad-state"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Invalid or expired" in response.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_callback_token_exchange_fails(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
||||
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
|
||||
):
|
||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
||||
mock_state = AsyncMock()
|
||||
mock_state.state_metadata = {
|
||||
"authorize_url": "https://auth.example.com/authorize",
|
||||
"token_url": "https://auth.example.com/token",
|
||||
"client_id": "cid",
|
||||
"server_url": "https://mcp.example.com/mcp",
|
||||
}
|
||||
mock_state.scopes = []
|
||||
mock_state.code_verifier = "v"
|
||||
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
|
||||
|
||||
handler_instance = MockHandler.return_value
|
||||
handler_instance.exchange_code_for_tokens = AsyncMock(
|
||||
side_effect=RuntimeError("Token exchange failed")
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
"/oauth/callback",
|
||||
json={"code": "bad-code", "state_token": "state"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "token exchange failed" in response.json()["detail"].lower()
|
||||
@@ -26,7 +26,6 @@ import backend.api.features.executions.review.routes
|
||||
import backend.api.features.library.db
|
||||
import backend.api.features.library.model
|
||||
import backend.api.features.library.routes
|
||||
import backend.api.features.mcp.routes as mcp_routes
|
||||
import backend.api.features.oauth
|
||||
import backend.api.features.otto.routes
|
||||
import backend.api.features.postmark.postmark
|
||||
@@ -41,11 +40,11 @@ import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.api.features.chat.completion_consumer import (
|
||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||
from backend.copilot.completion_consumer import (
|
||||
start_completion_consumer,
|
||||
stop_completion_consumer,
|
||||
)
|
||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
@@ -344,11 +343,6 @@ app.include_router(
|
||||
tags=["workspace"],
|
||||
prefix="/api/workspace",
|
||||
)
|
||||
app.include_router(
|
||||
mcp_routes.router,
|
||||
tags=["v2", "mcp"],
|
||||
prefix="/api/mcp",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.oauth.router,
|
||||
tags=["oauth"],
|
||||
|
||||
@@ -38,7 +38,9 @@ def main(**kwargs):
|
||||
|
||||
from backend.api.rest_api import AgentServer
|
||||
from backend.api.ws_api import WebsocketServer
|
||||
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
|
||||
from backend.copilot.executor.manager import CoPilotExecutor
|
||||
from backend.data.db_manager import DatabaseManager
|
||||
from backend.executor import ExecutionManager, Scheduler
|
||||
from backend.notifications import NotificationManager
|
||||
|
||||
run_processes(
|
||||
@@ -48,6 +50,7 @@ def main(**kwargs):
|
||||
WebsocketServer(),
|
||||
AgentServer(),
|
||||
ExecutionManager(),
|
||||
CoPilotExecutor(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -64,7 +64,6 @@ class BlockType(Enum):
|
||||
AI = "AI"
|
||||
AYRSHARE = "Ayrshare"
|
||||
HUMAN_IN_THE_LOOP = "Human In The Loop"
|
||||
MCP_TOOL = "MCP Tool"
|
||||
|
||||
|
||||
class BlockCategory(Enum):
|
||||
|
||||
@@ -1,300 +0,0 @@
|
||||
"""
|
||||
MCP (Model Context Protocol) Tool Block.
|
||||
|
||||
A single dynamic block that can connect to any MCP server, discover available tools,
|
||||
and execute them. Works like AgentExecutorBlock — the user selects a tool from a
|
||||
dropdown and the input/output schema adapts dynamically.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
)
|
||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
||||
from backend.data.block import BlockInput, BlockOutput
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
OAuth2Credentials,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.json import validate_with_jsonschema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TEST_CREDENTIALS = OAuth2Credentials(
|
||||
id="test-mcp-cred",
|
||||
provider="mcp",
|
||||
access_token=SecretStr("mock-mcp-token"),
|
||||
refresh_token=SecretStr("mock-refresh"),
|
||||
scopes=[],
|
||||
title="Mock MCP credential",
|
||||
)
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
MCPCredentials = CredentialsMetaInput[Literal[ProviderName.MCP], Literal["oauth2"]]
|
||||
|
||||
|
||||
class MCPToolBlock(Block):
|
||||
"""
|
||||
A block that connects to an MCP server, lets the user pick a tool,
|
||||
and executes it with dynamic input/output schema.
|
||||
|
||||
The flow:
|
||||
1. User provides an MCP server URL (and optional credentials)
|
||||
2. Frontend calls the backend to get tool list from that URL
|
||||
3. User selects a tool from a dropdown (available_tools)
|
||||
4. The block's input schema updates to reflect the selected tool's parameters
|
||||
5. On execution, the block calls the MCP server to run the tool
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
server_url: str = SchemaField(
|
||||
description="URL of the MCP server (Streamable HTTP endpoint)",
|
||||
placeholder="https://mcp.example.com/mcp",
|
||||
)
|
||||
credentials: MCPCredentials = CredentialsField(
|
||||
discriminator="server_url",
|
||||
description="MCP server OAuth credentials",
|
||||
default={},
|
||||
)
|
||||
selected_tool: str = SchemaField(
|
||||
description="The MCP tool to execute",
|
||||
placeholder="Select a tool",
|
||||
default="",
|
||||
)
|
||||
tool_input_schema: dict[str, Any] = SchemaField(
|
||||
description="JSON Schema for the selected tool's input parameters. "
|
||||
"Populated automatically when a tool is selected.",
|
||||
default={},
|
||||
hidden=True,
|
||||
)
|
||||
|
||||
tool_arguments: dict[str, Any] = SchemaField(
|
||||
description="Arguments to pass to the selected MCP tool. "
|
||||
"The fields here are defined by the tool's input schema.",
|
||||
default={},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
|
||||
"""Return the tool's input schema so the builder UI renders dynamic fields."""
|
||||
return data.get("tool_input_schema", {})
|
||||
|
||||
@classmethod
|
||||
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
||||
"""Return the current tool_arguments as defaults for the dynamic fields."""
|
||||
return data.get("tool_arguments", {})
|
||||
|
||||
@classmethod
|
||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||
"""Check which required tool arguments are missing."""
|
||||
required_fields = cls.get_input_schema(data).get("required", [])
|
||||
tool_arguments = data.get("tool_arguments", {})
|
||||
return set(required_fields) - set(tool_arguments)
|
||||
|
||||
@classmethod
|
||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||
"""Validate tool_arguments against the tool's input schema."""
|
||||
tool_schema = cls.get_input_schema(data)
|
||||
if not tool_schema:
|
||||
return None
|
||||
tool_arguments = data.get("tool_arguments", {})
|
||||
return validate_with_jsonschema(tool_schema, tool_arguments)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
result: Any = SchemaField(description="The result returned by the MCP tool")
|
||||
error: str = SchemaField(description="Error message if the tool call failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4",
|
||||
description="Connect to any MCP server and execute its tools. "
|
||||
"Provide a server URL, select a tool, and pass arguments dynamically.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=MCPToolBlock.Input,
|
||||
output_schema=MCPToolBlock.Output,
|
||||
block_type=BlockType.MCP_TOOL,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"server_url": "https://mcp.example.com/mcp",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"selected_tool": "get_weather",
|
||||
"tool_input_schema": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
"tool_arguments": {"city": "London"},
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"result",
|
||||
{"weather": "sunny", "temperature": 20},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"_call_mcp_tool": lambda *a, **kw: {
|
||||
"weather": "sunny",
|
||||
"temperature": 20,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
async def _call_mcp_tool(
|
||||
self,
|
||||
server_url: str,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
auth_token: str | None = None,
|
||||
) -> Any:
|
||||
"""Call a tool on the MCP server. Extracted for easy mocking in tests."""
|
||||
client = MCPClient(server_url, auth_token=auth_token)
|
||||
await client.initialize()
|
||||
result = await client.call_tool(tool_name, arguments)
|
||||
|
||||
if result.is_error:
|
||||
error_text = ""
|
||||
for item in result.content:
|
||||
if item.get("type") == "text":
|
||||
error_text += item.get("text", "")
|
||||
raise MCPClientError(
|
||||
f"MCP tool '{tool_name}' returned an error: "
|
||||
f"{error_text or 'Unknown error'}"
|
||||
)
|
||||
|
||||
# Extract text content from the result
|
||||
output_parts = []
|
||||
for item in result.content:
|
||||
if item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
# Try to parse as JSON for structured output
|
||||
try:
|
||||
output_parts.append(json.loads(text))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
output_parts.append(text)
|
||||
elif item.get("type") == "image":
|
||||
output_parts.append(
|
||||
{
|
||||
"type": "image",
|
||||
"data": item.get("data"),
|
||||
"mimeType": item.get("mimeType"),
|
||||
}
|
||||
)
|
||||
elif item.get("type") == "resource":
|
||||
output_parts.append(item.get("resource", {}))
|
||||
|
||||
# If single result, unwrap
|
||||
if len(output_parts) == 1:
|
||||
return output_parts[0]
|
||||
return output_parts if output_parts else None
|
||||
|
||||
@staticmethod
|
||||
async def _auto_lookup_credential(
|
||||
user_id: str, server_url: str
|
||||
) -> "OAuth2Credentials | None":
|
||||
"""Auto-lookup stored MCP credential for a server URL.
|
||||
|
||||
This is a fallback for nodes that don't have ``credentials`` explicitly
|
||||
set (e.g. nodes created before the credential field was wired up).
|
||||
"""
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
try:
|
||||
mgr = IntegrationCredentialsManager()
|
||||
mcp_creds = await mgr.store.get_creds_by_provider(
|
||||
user_id, ProviderName.MCP.value
|
||||
)
|
||||
best: OAuth2Credentials | None = None
|
||||
for cred in mcp_creds:
|
||||
if (
|
||||
isinstance(cred, OAuth2Credentials)
|
||||
and (cred.metadata or {}).get("mcp_server_url") == server_url
|
||||
):
|
||||
if best is None or (
|
||||
(cred.access_token_expires_at or 0)
|
||||
> (best.access_token_expires_at or 0)
|
||||
):
|
||||
best = cred
|
||||
if best:
|
||||
best = await mgr.refresh_if_needed(user_id, best)
|
||||
logger.info(
|
||||
"Auto-resolved MCP credential %s for %s", best.id, server_url
|
||||
)
|
||||
return best
|
||||
except Exception:
|
||||
logger.warning("Auto-lookup MCP credential failed", exc_info=True)
|
||||
return None
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
user_id: str,
|
||||
credentials: OAuth2Credentials | None = None,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
if not input_data.server_url:
|
||||
yield "error", "MCP server URL is required"
|
||||
return
|
||||
|
||||
if not input_data.selected_tool:
|
||||
yield "error", "No tool selected. Please select a tool from the dropdown."
|
||||
return
|
||||
|
||||
# Validate required tool arguments before calling the server.
|
||||
# The executor-level validation is bypassed for MCP blocks because
|
||||
# get_input_defaults() flattens tool_arguments, stripping tool_input_schema
|
||||
# from the validation context.
|
||||
required = set(input_data.tool_input_schema.get("required", []))
|
||||
if required:
|
||||
missing = required - set(input_data.tool_arguments.keys())
|
||||
if missing:
|
||||
yield "error", (
|
||||
f"Missing required argument(s): {', '.join(sorted(missing))}. "
|
||||
f"Please fill in all required fields marked with * in the block form."
|
||||
)
|
||||
return
|
||||
|
||||
# If no credentials were injected by the executor (e.g. legacy nodes
|
||||
# that don't have the credentials field set), try to auto-lookup
|
||||
# the stored MCP credential for this server URL.
|
||||
if credentials is None:
|
||||
credentials = await self._auto_lookup_credential(
|
||||
user_id, input_data.server_url
|
||||
)
|
||||
|
||||
auth_token = (
|
||||
credentials.access_token.get_secret_value() if credentials else None
|
||||
)
|
||||
|
||||
try:
|
||||
result = await self._call_mcp_tool(
|
||||
server_url=input_data.server_url,
|
||||
tool_name=input_data.selected_tool,
|
||||
arguments=input_data.tool_arguments,
|
||||
auth_token=auth_token,
|
||||
)
|
||||
yield "result", result
|
||||
except MCPClientError as e:
|
||||
yield "error", str(e)
|
||||
except Exception as e:
|
||||
logger.exception(f"MCP tool call failed: {e}")
|
||||
yield "error", f"MCP tool call failed: {str(e)}"
|
||||
@@ -1,323 +0,0 @@
|
||||
"""
|
||||
MCP (Model Context Protocol) HTTP client.
|
||||
|
||||
Implements the MCP Streamable HTTP transport for listing tools and calling tools
|
||||
on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST.
|
||||
|
||||
Handles both JSON and SSE (text/event-stream) response formats per the MCP spec.
|
||||
|
||||
Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from backend.util.request import Requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPTool:
|
||||
"""Represents an MCP tool discovered from a server."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
input_schema: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPCallResult:
|
||||
"""Result from calling an MCP tool."""
|
||||
|
||||
content: list[dict[str, Any]] = field(default_factory=list)
|
||||
is_error: bool = False
|
||||
|
||||
|
||||
class MCPClientError(Exception):
|
||||
"""Raised when an MCP protocol error occurs."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""
|
||||
Async HTTP client for the MCP Streamable HTTP transport.
|
||||
|
||||
Communicates with MCP servers using JSON-RPC 2.0 over HTTP POST.
|
||||
Supports optional Bearer token authentication.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
auth_token: str | None = None,
|
||||
):
|
||||
self.server_url = server_url.rstrip("/")
|
||||
self.auth_token = auth_token
|
||||
self._request_id = 0
|
||||
self._session_id: str | None = None
|
||||
|
||||
def _next_id(self) -> int:
|
||||
self._request_id += 1
|
||||
return self._request_id
|
||||
|
||||
def _build_headers(self) -> dict[str, str]:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
}
|
||||
if self.auth_token:
|
||||
headers["Authorization"] = f"Bearer {self.auth_token}"
|
||||
if self._session_id:
|
||||
headers["Mcp-Session-Id"] = self._session_id
|
||||
return headers
|
||||
|
||||
def _build_jsonrpc_request(
|
||||
self, method: str, params: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
req: dict[str, Any] = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"id": self._next_id(),
|
||||
}
|
||||
if params is not None:
|
||||
req["params"] = params
|
||||
return req
|
||||
|
||||
@staticmethod
|
||||
def _parse_sse_response(text: str) -> dict[str, Any]:
|
||||
"""Parse an SSE (text/event-stream) response body into JSON-RPC data.
|
||||
|
||||
MCP servers may return responses as SSE with format:
|
||||
event: message
|
||||
data: {"jsonrpc":"2.0","result":{...},"id":1}
|
||||
|
||||
We extract the last `data:` line that contains a JSON-RPC response
|
||||
(i.e. has an "id" field), which is the reply to our request.
|
||||
"""
|
||||
last_data: dict[str, Any] | None = None
|
||||
for line in text.splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped.startswith("data:"):
|
||||
payload = stripped[len("data:") :].strip()
|
||||
if not payload:
|
||||
continue
|
||||
try:
|
||||
parsed = json.loads(payload)
|
||||
# Only keep JSON-RPC responses (have "id"), skip notifications
|
||||
if isinstance(parsed, dict) and "id" in parsed:
|
||||
last_data = parsed
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
continue
|
||||
if last_data is None:
|
||||
raise MCPClientError("No JSON-RPC response found in SSE stream")
|
||||
return last_data
|
||||
|
||||
async def _send_request(
|
||||
self, method: str, params: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Send a JSON-RPC request to the MCP server and return the result.
|
||||
|
||||
Handles both ``application/json`` and ``text/event-stream`` responses
|
||||
as required by the MCP Streamable HTTP transport specification.
|
||||
"""
|
||||
payload = self._build_jsonrpc_request(method, params)
|
||||
headers = self._build_headers()
|
||||
|
||||
requests = Requests(
|
||||
raise_for_status=True,
|
||||
extra_headers=headers,
|
||||
)
|
||||
response = await requests.post(self.server_url, json=payload)
|
||||
|
||||
# Capture session ID from response (MCP Streamable HTTP transport)
|
||||
session_id = response.headers.get("Mcp-Session-Id")
|
||||
if session_id:
|
||||
self._session_id = session_id
|
||||
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "text/event-stream" in content_type:
|
||||
body = self._parse_sse_response(response.text())
|
||||
else:
|
||||
try:
|
||||
body = response.json()
|
||||
except Exception as e:
|
||||
raise MCPClientError(
|
||||
f"MCP server returned non-JSON response: {e}"
|
||||
) from e
|
||||
|
||||
if not isinstance(body, dict):
|
||||
raise MCPClientError(
|
||||
f"MCP server returned unexpected JSON type: {type(body).__name__}"
|
||||
)
|
||||
|
||||
# Handle JSON-RPC error
|
||||
if "error" in body:
|
||||
error = body["error"]
|
||||
if isinstance(error, dict):
|
||||
raise MCPClientError(
|
||||
f"MCP server error [{error.get('code', '?')}]: "
|
||||
f"{error.get('message', 'Unknown error')}"
|
||||
)
|
||||
raise MCPClientError(f"MCP server error: {error}")
|
||||
|
||||
return body.get("result")
|
||||
|
||||
async def _send_notification(self, method: str) -> None:
|
||||
"""Send a JSON-RPC notification (no id, no response expected)."""
|
||||
headers = self._build_headers()
|
||||
notification = {"jsonrpc": "2.0", "method": method}
|
||||
requests = Requests(
|
||||
raise_for_status=False,
|
||||
extra_headers=headers,
|
||||
)
|
||||
await requests.post(self.server_url, json=notification)
|
||||
|
||||
async def discover_auth(self) -> dict[str, Any] | None:
|
||||
"""Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec).
|
||||
|
||||
Returns ``None`` if the server doesn't require auth, otherwise returns
|
||||
a dict with:
|
||||
- ``authorization_servers``: list of authorization server URLs
|
||||
- ``resource``: the resource indicator URL (usually the MCP endpoint)
|
||||
- ``scopes_supported``: optional list of supported scopes
|
||||
|
||||
The caller can then fetch the authorization server metadata to get
|
||||
``authorization_endpoint``, ``token_endpoint``, etc.
|
||||
"""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(self.server_url)
|
||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
# Build candidates for protected-resource metadata (per RFC 9728)
|
||||
path = parsed.path.rstrip("/")
|
||||
candidates = []
|
||||
if path and path != "/":
|
||||
candidates.append(f"{base}/.well-known/oauth-protected-resource{path}")
|
||||
candidates.append(f"{base}/.well-known/oauth-protected-resource")
|
||||
|
||||
requests = Requests(
|
||||
raise_for_status=False,
|
||||
)
|
||||
for url in candidates:
|
||||
try:
|
||||
resp = await requests.get(url)
|
||||
if resp.status == 200:
|
||||
data = resp.json()
|
||||
if isinstance(data, dict) and "authorization_servers" in data:
|
||||
return data
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
async def discover_auth_server_metadata(
|
||||
self, auth_server_url: str
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch the OAuth Authorization Server Metadata (RFC 8414).
|
||||
|
||||
Given an authorization server URL, returns a dict with:
|
||||
- ``authorization_endpoint``
|
||||
- ``token_endpoint``
|
||||
- ``registration_endpoint`` (for dynamic client registration)
|
||||
- ``scopes_supported``
|
||||
- ``code_challenge_methods_supported``
|
||||
- etc.
|
||||
"""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(auth_server_url)
|
||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||
path = parsed.path.rstrip("/")
|
||||
|
||||
# Try standard metadata endpoints (RFC 8414 and OpenID Connect)
|
||||
candidates = []
|
||||
if path and path != "/":
|
||||
candidates.append(f"{base}/.well-known/oauth-authorization-server{path}")
|
||||
candidates.append(f"{base}/.well-known/oauth-authorization-server")
|
||||
candidates.append(f"{base}/.well-known/openid-configuration")
|
||||
|
||||
requests = Requests(
|
||||
raise_for_status=False,
|
||||
)
|
||||
for url in candidates:
|
||||
try:
|
||||
resp = await requests.get(url)
|
||||
if resp.status == 200:
|
||||
data = resp.json()
|
||||
if isinstance(data, dict) and "authorization_endpoint" in data:
|
||||
return data
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
async def initialize(self) -> dict[str, Any]:
|
||||
"""
|
||||
Send the MCP initialize request.
|
||||
|
||||
This is required by the MCP protocol before any other requests.
|
||||
Returns the server's capabilities.
|
||||
"""
|
||||
result = await self._send_request(
|
||||
"initialize",
|
||||
{
|
||||
"protocolVersion": "2025-03-26",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "AutoGPT-Platform", "version": "1.0.0"},
|
||||
},
|
||||
)
|
||||
# Send initialized notification (no response expected)
|
||||
await self._send_notification("notifications/initialized")
|
||||
|
||||
return result or {}
|
||||
|
||||
async def list_tools(self) -> list[MCPTool]:
|
||||
"""
|
||||
Discover available tools from the MCP server.
|
||||
|
||||
Returns a list of MCPTool objects with name, description, and input schema.
|
||||
"""
|
||||
result = await self._send_request("tools/list")
|
||||
if not result or "tools" not in result:
|
||||
return []
|
||||
|
||||
tools = []
|
||||
for tool_data in result["tools"]:
|
||||
tools.append(
|
||||
MCPTool(
|
||||
name=tool_data.get("name", ""),
|
||||
description=tool_data.get("description", ""),
|
||||
input_schema=tool_data.get("inputSchema", {}),
|
||||
)
|
||||
)
|
||||
return tools
|
||||
|
||||
async def call_tool(
|
||||
self, tool_name: str, arguments: dict[str, Any]
|
||||
) -> MCPCallResult:
|
||||
"""
|
||||
Call a tool on the MCP server.
|
||||
|
||||
Args:
|
||||
tool_name: The name of the tool to call.
|
||||
arguments: The arguments to pass to the tool.
|
||||
|
||||
Returns:
|
||||
MCPCallResult with the tool's response content.
|
||||
"""
|
||||
result = await self._send_request(
|
||||
"tools/call",
|
||||
{"name": tool_name, "arguments": arguments},
|
||||
)
|
||||
if not result:
|
||||
return MCPCallResult(is_error=True)
|
||||
|
||||
return MCPCallResult(
|
||||
content=result.get("content", []),
|
||||
is_error=result.get("isError", False),
|
||||
)
|
||||
@@ -1,204 +0,0 @@
|
||||
"""
|
||||
MCP OAuth handler for MCP servers that use OAuth 2.1 authorization.
|
||||
|
||||
Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed,
|
||||
MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata.
|
||||
This handler accepts those endpoints at construction time.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import Requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPOAuthHandler(BaseOAuthHandler):
|
||||
"""
|
||||
OAuth handler for MCP servers with dynamically-discovered endpoints.
|
||||
|
||||
Construction requires the authorization and token endpoint URLs,
|
||||
which are obtained via MCP OAuth metadata discovery
|
||||
(``MCPClient.discover_auth`` + ``discover_auth_server_metadata``).
|
||||
"""
|
||||
|
||||
PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP
|
||||
DEFAULT_SCOPES: ClassVar[list[str]] = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
redirect_uri: str,
|
||||
*,
|
||||
authorize_url: str,
|
||||
token_url: str,
|
||||
revoke_url: str | None = None,
|
||||
resource_url: str | None = None,
|
||||
):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
self.authorize_url = authorize_url
|
||||
self.token_url = token_url
|
||||
self.revoke_url = revoke_url
|
||||
self.resource_url = resource_url
|
||||
|
||||
def get_login_url(
|
||||
self,
|
||||
scopes: list[str],
|
||||
state: str,
|
||||
code_challenge: Optional[str],
|
||||
) -> str:
|
||||
scopes = self.handle_default_scopes(scopes)
|
||||
|
||||
params: dict[str, str] = {
|
||||
"response_type": "code",
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"state": state,
|
||||
}
|
||||
if scopes:
|
||||
params["scope"] = " ".join(scopes)
|
||||
# PKCE (S256) — included when the caller provides a code_challenge
|
||||
if code_challenge:
|
||||
params["code_challenge"] = code_challenge
|
||||
params["code_challenge_method"] = "S256"
|
||||
# MCP spec requires resource indicator (RFC 8707)
|
||||
if self.resource_url:
|
||||
params["resource"] = self.resource_url
|
||||
|
||||
return f"{self.authorize_url}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
async def exchange_code_for_tokens(
|
||||
self,
|
||||
code: str,
|
||||
scopes: list[str],
|
||||
code_verifier: Optional[str],
|
||||
) -> OAuth2Credentials:
|
||||
data: dict[str, str] = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"client_id": self.client_id,
|
||||
}
|
||||
if self.client_secret:
|
||||
data["client_secret"] = self.client_secret
|
||||
if code_verifier:
|
||||
data["code_verifier"] = code_verifier
|
||||
if self.resource_url:
|
||||
data["resource"] = self.resource_url
|
||||
|
||||
response = await Requests(raise_for_status=True).post(
|
||||
self.token_url,
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
tokens = response.json()
|
||||
|
||||
if "error" in tokens:
|
||||
raise RuntimeError(
|
||||
f"Token exchange failed: {tokens.get('error_description', tokens['error'])}"
|
||||
)
|
||||
|
||||
if "access_token" not in tokens:
|
||||
raise RuntimeError("OAuth token response missing 'access_token' field")
|
||||
|
||||
now = int(time.time())
|
||||
expires_in = tokens.get("expires_in")
|
||||
|
||||
return OAuth2Credentials(
|
||||
provider=self.PROVIDER_NAME,
|
||||
title=None,
|
||||
access_token=SecretStr(tokens["access_token"]),
|
||||
refresh_token=(
|
||||
SecretStr(tokens["refresh_token"])
|
||||
if tokens.get("refresh_token")
|
||||
else None
|
||||
),
|
||||
access_token_expires_at=now + expires_in if expires_in else None,
|
||||
refresh_token_expires_at=None,
|
||||
scopes=scopes,
|
||||
metadata={
|
||||
"mcp_token_url": self.token_url,
|
||||
"mcp_resource_url": self.resource_url,
|
||||
},
|
||||
)
|
||||
|
||||
async def _refresh_tokens(
|
||||
self, credentials: OAuth2Credentials
|
||||
) -> OAuth2Credentials:
|
||||
if not credentials.refresh_token:
|
||||
raise ValueError("No refresh token available for MCP OAuth credentials")
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": credentials.refresh_token.get_secret_value(),
|
||||
"client_id": self.client_id,
|
||||
}
|
||||
if self.client_secret:
|
||||
data["client_secret"] = self.client_secret
|
||||
if self.resource_url:
|
||||
data["resource"] = self.resource_url
|
||||
|
||||
response = await Requests(raise_for_status=True).post(
|
||||
self.token_url,
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
tokens = response.json()
|
||||
|
||||
if "error" in tokens:
|
||||
raise RuntimeError(
|
||||
f"Token refresh failed: {tokens.get('error_description', tokens['error'])}"
|
||||
)
|
||||
|
||||
if "access_token" not in tokens:
|
||||
raise RuntimeError("OAuth refresh response missing 'access_token' field")
|
||||
|
||||
now = int(time.time())
|
||||
expires_in = tokens.get("expires_in")
|
||||
|
||||
return OAuth2Credentials(
|
||||
id=credentials.id,
|
||||
provider=self.PROVIDER_NAME,
|
||||
title=credentials.title,
|
||||
access_token=SecretStr(tokens["access_token"]),
|
||||
refresh_token=(
|
||||
SecretStr(tokens["refresh_token"])
|
||||
if tokens.get("refresh_token")
|
||||
else credentials.refresh_token
|
||||
),
|
||||
access_token_expires_at=now + expires_in if expires_in else None,
|
||||
refresh_token_expires_at=credentials.refresh_token_expires_at,
|
||||
scopes=credentials.scopes,
|
||||
metadata=credentials.metadata,
|
||||
)
|
||||
|
||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
||||
if not self.revoke_url:
|
||||
return False
|
||||
|
||||
try:
|
||||
data = {
|
||||
"token": credentials.access_token.get_secret_value(),
|
||||
"token_type_hint": "access_token",
|
||||
"client_id": self.client_id,
|
||||
}
|
||||
await Requests().post(
|
||||
self.revoke_url,
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True)
|
||||
return False
|
||||
@@ -1,109 +0,0 @@
|
||||
"""
|
||||
End-to-end tests against a real public MCP server.
|
||||
|
||||
These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp)
|
||||
which is publicly accessible without authentication and returns SSE responses.
|
||||
|
||||
Mark: These are tagged with ``@pytest.mark.e2e`` so they can be run/skipped
|
||||
independently of the rest of the test suite (they require network access).
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.mcp.client import MCPClient
|
||||
|
||||
# Public MCP server that requires no authentication
|
||||
OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp"
|
||||
|
||||
# Skip all tests in this module unless RUN_E2E env var is set
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not os.environ.get("RUN_E2E"), reason="set RUN_E2E=1 to run e2e tests"
|
||||
)
|
||||
|
||||
|
||||
class TestRealMCPServer:
|
||||
"""Tests against the live OpenAI docs MCP server."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_initialize(self):
|
||||
"""Verify we can complete the MCP handshake with a real server."""
|
||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
||||
result = await client.initialize()
|
||||
|
||||
assert result["protocolVersion"] == "2025-03-26"
|
||||
assert "serverInfo" in result
|
||||
assert result["serverInfo"]["name"] == "openai-docs-mcp"
|
||||
assert "tools" in result.get("capabilities", {})
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_list_tools(self):
|
||||
"""Verify we can discover tools from a real MCP server."""
|
||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
||||
await client.initialize()
|
||||
tools = await client.list_tools()
|
||||
|
||||
assert len(tools) >= 3 # server has at least 5 tools as of writing
|
||||
|
||||
tool_names = {t.name for t in tools}
|
||||
# These tools are documented and should be stable
|
||||
assert "search_openai_docs" in tool_names
|
||||
assert "list_openai_docs" in tool_names
|
||||
assert "fetch_openai_doc" in tool_names
|
||||
|
||||
# Verify schema structure
|
||||
search_tool = next(t for t in tools if t.name == "search_openai_docs")
|
||||
assert "query" in search_tool.input_schema.get("properties", {})
|
||||
assert "query" in search_tool.input_schema.get("required", [])
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_tool_list_api_endpoints(self):
|
||||
"""Call the list_api_endpoints tool and verify we get real data."""
|
||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
||||
await client.initialize()
|
||||
result = await client.call_tool("list_api_endpoints", {})
|
||||
|
||||
assert not result.is_error
|
||||
assert len(result.content) >= 1
|
||||
assert result.content[0]["type"] == "text"
|
||||
|
||||
data = json.loads(result.content[0]["text"])
|
||||
assert "paths" in data or "urls" in data
|
||||
# The OpenAI API should have many endpoints
|
||||
total = data.get("total", len(data.get("paths", [])))
|
||||
assert total > 50
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_tool_search(self):
|
||||
"""Search for docs and verify we get results."""
|
||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
||||
await client.initialize()
|
||||
result = await client.call_tool(
|
||||
"search_openai_docs", {"query": "chat completions", "limit": 3}
|
||||
)
|
||||
|
||||
assert not result.is_error
|
||||
assert len(result.content) >= 1
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_sse_response_handling(self):
|
||||
"""Verify the client correctly handles SSE responses from a real server.
|
||||
|
||||
This is the key test — our local test server returns JSON,
|
||||
but real MCP servers typically return SSE. This proves the
|
||||
SSE parsing works end-to-end.
|
||||
"""
|
||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
||||
# initialize() internally calls _send_request which must parse SSE
|
||||
result = await client.initialize()
|
||||
|
||||
# If we got here without error, SSE parsing works
|
||||
assert isinstance(result, dict)
|
||||
assert "protocolVersion" in result
|
||||
|
||||
# Also verify list_tools works (another SSE response)
|
||||
tools = await client.list_tools()
|
||||
assert len(tools) > 0
|
||||
assert all(hasattr(t, "name") for t in tools)
|
||||
@@ -1,389 +0,0 @@
|
||||
"""
|
||||
Integration tests for MCP client and MCPToolBlock against a real HTTP server.
|
||||
|
||||
These tests spin up a local MCP test server and run the full client/block flow
|
||||
against it — no mocking, real HTTP requests.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.mcp.block import MCPToolBlock
|
||||
from backend.blocks.mcp.client import MCPClient
|
||||
from backend.blocks.mcp.test_server import create_test_mcp_app
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
MOCK_USER_ID = "test-user-integration"
|
||||
|
||||
|
||||
class _MCPTestServer:
|
||||
"""
|
||||
Run an MCP test server in a background thread with its own event loop.
|
||||
This avoids event loop conflicts with pytest-asyncio.
|
||||
"""
|
||||
|
||||
def __init__(self, auth_token: str | None = None):
|
||||
self.auth_token = auth_token
|
||||
self.url: str = ""
|
||||
self._runner: web.AppRunner | None = None
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._thread: threading.Thread | None = None
|
||||
self._started = threading.Event()
|
||||
|
||||
def _run(self):
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._loop.run_until_complete(self._start())
|
||||
self._started.set()
|
||||
self._loop.run_forever()
|
||||
|
||||
async def _start(self):
|
||||
app = create_test_mcp_app(auth_token=self.auth_token)
|
||||
self._runner = web.AppRunner(app)
|
||||
await self._runner.setup()
|
||||
site = web.TCPSite(self._runner, "127.0.0.1", 0)
|
||||
await site.start()
|
||||
port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr]
|
||||
self.url = f"http://127.0.0.1:{port}/mcp"
|
||||
|
||||
def start(self):
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
if not self._started.wait(timeout=5):
|
||||
raise RuntimeError("MCP test server failed to start within 5 seconds")
|
||||
return self
|
||||
|
||||
def stop(self):
|
||||
if self._loop and self._runner:
|
||||
asyncio.run_coroutine_threadsafe(self._runner.cleanup(), self._loop).result(
|
||||
timeout=5
|
||||
)
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
if self._thread:
|
||||
self._thread.join(timeout=5)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mcp_server():
|
||||
"""Start a local MCP test server in a background thread."""
|
||||
server = _MCPTestServer()
|
||||
server.start()
|
||||
yield server.url
|
||||
server.stop()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mcp_server_with_auth():
|
||||
"""Start a local MCP test server with auth in a background thread."""
|
||||
server = _MCPTestServer(auth_token="test-secret-token")
|
||||
server.start()
|
||||
yield server.url, "test-secret-token"
|
||||
server.stop()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _allow_localhost():
|
||||
"""
|
||||
Allow 127.0.0.1 through SSRF protection for integration tests.
|
||||
|
||||
The Requests class blocks private IPs by default. We patch the Requests
|
||||
constructor to always include 127.0.0.1 as a trusted origin so the local
|
||||
test server is reachable.
|
||||
"""
|
||||
from backend.util.request import Requests
|
||||
|
||||
original_init = Requests.__init__
|
||||
|
||||
def patched_init(self, *args, **kwargs):
|
||||
trusted = list(kwargs.get("trusted_origins") or [])
|
||||
trusted.append("http://127.0.0.1")
|
||||
kwargs["trusted_origins"] = trusted
|
||||
original_init(self, *args, **kwargs)
|
||||
|
||||
with patch.object(Requests, "__init__", patched_init):
|
||||
yield
|
||||
|
||||
|
||||
def _make_client(url: str, auth_token: str | None = None) -> MCPClient:
|
||||
"""Create an MCPClient for integration tests."""
|
||||
return MCPClient(url, auth_token=auth_token)
|
||||
|
||||
|
||||
# ── MCPClient integration tests ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestMCPClientIntegration:
|
||||
"""Test MCPClient against a real local MCP server."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_initialize(self, mcp_server):
|
||||
client = _make_client(mcp_server)
|
||||
result = await client.initialize()
|
||||
|
||||
assert result["protocolVersion"] == "2025-03-26"
|
||||
assert result["serverInfo"]["name"] == "test-mcp-server"
|
||||
assert "tools" in result["capabilities"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_list_tools(self, mcp_server):
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
tools = await client.list_tools()
|
||||
|
||||
assert len(tools) == 3
|
||||
|
||||
tool_names = {t.name for t in tools}
|
||||
assert tool_names == {"get_weather", "add_numbers", "echo"}
|
||||
|
||||
# Check get_weather schema
|
||||
weather = next(t for t in tools if t.name == "get_weather")
|
||||
assert weather.description == "Get current weather for a city"
|
||||
assert "city" in weather.input_schema["properties"]
|
||||
assert weather.input_schema["required"] == ["city"]
|
||||
|
||||
# Check add_numbers schema
|
||||
add = next(t for t in tools if t.name == "add_numbers")
|
||||
assert "a" in add.input_schema["properties"]
|
||||
assert "b" in add.input_schema["properties"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_tool_get_weather(self, mcp_server):
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
result = await client.call_tool("get_weather", {"city": "London"})
|
||||
|
||||
assert not result.is_error
|
||||
assert len(result.content) == 1
|
||||
assert result.content[0]["type"] == "text"
|
||||
|
||||
data = json.loads(result.content[0]["text"])
|
||||
assert data["city"] == "London"
|
||||
assert data["temperature"] == 22
|
||||
assert data["condition"] == "sunny"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_tool_add_numbers(self, mcp_server):
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
result = await client.call_tool("add_numbers", {"a": 3, "b": 7})
|
||||
|
||||
assert not result.is_error
|
||||
data = json.loads(result.content[0]["text"])
|
||||
assert data["result"] == 10
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_tool_echo(self, mcp_server):
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
result = await client.call_tool("echo", {"message": "Hello MCP!"})
|
||||
|
||||
assert not result.is_error
|
||||
assert result.content[0]["text"] == "Hello MCP!"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_unknown_tool(self, mcp_server):
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
result = await client.call_tool("nonexistent_tool", {})
|
||||
|
||||
assert result.is_error
|
||||
assert "Unknown tool" in result.content[0]["text"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_auth_success(self, mcp_server_with_auth):
|
||||
url, token = mcp_server_with_auth
|
||||
client = _make_client(url, auth_token=token)
|
||||
result = await client.initialize()
|
||||
|
||||
assert result["protocolVersion"] == "2025-03-26"
|
||||
|
||||
tools = await client.list_tools()
|
||||
assert len(tools) == 3
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_auth_failure(self, mcp_server_with_auth):
|
||||
url, _ = mcp_server_with_auth
|
||||
client = _make_client(url, auth_token="wrong-token")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await client.initialize()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_auth_missing(self, mcp_server_with_auth):
|
||||
url, _ = mcp_server_with_auth
|
||||
client = _make_client(url)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await client.initialize()
|
||||
|
||||
|
||||
# ── MCPToolBlock integration tests ───────────────────────────────────
|
||||
|
||||
|
||||
class TestMCPToolBlockIntegration:
|
||||
"""Test MCPToolBlock end-to-end against a real local MCP server."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_full_flow_get_weather(self, mcp_server):
|
||||
"""Full flow: discover tools, select one, execute it."""
|
||||
# Step 1: Discover tools (simulating what the frontend/API would do)
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
tools = await client.list_tools()
|
||||
assert len(tools) == 3
|
||||
|
||||
# Step 2: User selects "get_weather" and we get its schema
|
||||
weather_tool = next(t for t in tools if t.name == "get_weather")
|
||||
|
||||
# Step 3: Execute the block — no credentials (public server)
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url=mcp_server,
|
||||
selected_tool="get_weather",
|
||||
tool_input_schema=weather_tool.input_schema,
|
||||
tool_arguments={"city": "Paris"},
|
||||
)
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "result"
|
||||
result = outputs[0][1]
|
||||
assert result["city"] == "Paris"
|
||||
assert result["temperature"] == 22
|
||||
assert result["condition"] == "sunny"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_full_flow_add_numbers(self, mcp_server):
|
||||
"""Full flow for add_numbers tool."""
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
tools = await client.list_tools()
|
||||
add_tool = next(t for t in tools if t.name == "add_numbers")
|
||||
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url=mcp_server,
|
||||
selected_tool="add_numbers",
|
||||
tool_input_schema=add_tool.input_schema,
|
||||
tool_arguments={"a": 42, "b": 58},
|
||||
)
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "result"
|
||||
assert outputs[0][1]["result"] == 100
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_full_flow_echo_plain_text(self, mcp_server):
|
||||
"""Verify plain text (non-JSON) responses work."""
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url=mcp_server,
|
||||
selected_tool="echo",
|
||||
tool_input_schema={
|
||||
"type": "object",
|
||||
"properties": {"message": {"type": "string"}},
|
||||
"required": ["message"],
|
||||
},
|
||||
tool_arguments={"message": "Hello from AutoGPT!"},
|
||||
)
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "result"
|
||||
assert outputs[0][1] == "Hello from AutoGPT!"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_full_flow_unknown_tool_yields_error(self, mcp_server):
|
||||
"""Calling an unknown tool should yield an error output."""
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url=mcp_server,
|
||||
selected_tool="nonexistent_tool",
|
||||
tool_arguments={},
|
||||
)
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "error"
|
||||
assert "returned an error" in outputs[0][1]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_full_flow_with_auth(self, mcp_server_with_auth):
|
||||
"""Full flow with authentication via credentials kwarg."""
|
||||
url, token = mcp_server_with_auth
|
||||
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url=url,
|
||||
selected_tool="echo",
|
||||
tool_input_schema={
|
||||
"type": "object",
|
||||
"properties": {"message": {"type": "string"}},
|
||||
"required": ["message"],
|
||||
},
|
||||
tool_arguments={"message": "Authenticated!"},
|
||||
)
|
||||
|
||||
# Pass credentials via the standard kwarg (as the executor would)
|
||||
test_creds = OAuth2Credentials(
|
||||
id="test-cred",
|
||||
provider="mcp",
|
||||
access_token=SecretStr(token),
|
||||
refresh_token=SecretStr(""),
|
||||
scopes=[],
|
||||
title="Test MCP credential",
|
||||
)
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(
|
||||
input_data, user_id=MOCK_USER_ID, credentials=test_creds
|
||||
):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "result"
|
||||
assert outputs[0][1] == "Authenticated!"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_credentials_runs_without_auth(self, mcp_server):
|
||||
"""Block runs without auth when no credentials are provided."""
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url=mcp_server,
|
||||
selected_tool="echo",
|
||||
tool_input_schema={
|
||||
"type": "object",
|
||||
"properties": {"message": {"type": "string"}},
|
||||
"required": ["message"],
|
||||
},
|
||||
tool_arguments={"message": "No auth needed"},
|
||||
)
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(
|
||||
input_data, user_id=MOCK_USER_ID, credentials=None
|
||||
):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "result"
|
||||
assert outputs[0][1] == "No auth needed"
|
||||
@@ -1,619 +0,0 @@
|
||||
"""
|
||||
Tests for MCP client and MCPToolBlock.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.mcp.block import MCPToolBlock
|
||||
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError
|
||||
from backend.util.test import execute_block_test
|
||||
|
||||
# ── SSE parsing unit tests ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSSEParsing:
|
||||
"""Tests for SSE (text/event-stream) response parsing."""
|
||||
|
||||
def test_parse_sse_simple(self):
|
||||
sse = (
|
||||
"event: message\n"
|
||||
'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n'
|
||||
"\n"
|
||||
)
|
||||
body = MCPClient._parse_sse_response(sse)
|
||||
assert body["result"] == {"tools": []}
|
||||
assert body["id"] == 1
|
||||
|
||||
def test_parse_sse_with_notifications(self):
|
||||
"""SSE streams can contain notifications (no id) before the response."""
|
||||
sse = (
|
||||
"event: message\n"
|
||||
'data: {"jsonrpc":"2.0","method":"some/notification"}\n'
|
||||
"\n"
|
||||
"event: message\n"
|
||||
'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n'
|
||||
"\n"
|
||||
)
|
||||
body = MCPClient._parse_sse_response(sse)
|
||||
assert body["result"] == {"ok": True}
|
||||
assert body["id"] == 2
|
||||
|
||||
def test_parse_sse_error_response(self):
|
||||
sse = (
|
||||
"event: message\n"
|
||||
'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n'
|
||||
)
|
||||
body = MCPClient._parse_sse_response(sse)
|
||||
assert "error" in body
|
||||
assert body["error"]["code"] == -32600
|
||||
|
||||
def test_parse_sse_no_data_raises(self):
|
||||
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
|
||||
MCPClient._parse_sse_response("event: message\n\n")
|
||||
|
||||
def test_parse_sse_empty_raises(self):
|
||||
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
|
||||
MCPClient._parse_sse_response("")
|
||||
|
||||
def test_parse_sse_ignores_non_data_lines(self):
|
||||
sse = (
|
||||
": comment line\n"
|
||||
"event: message\n"
|
||||
"id: 123\n"
|
||||
'data: {"jsonrpc":"2.0","result":"ok","id":1}\n'
|
||||
"\n"
|
||||
)
|
||||
body = MCPClient._parse_sse_response(sse)
|
||||
assert body["result"] == "ok"
|
||||
|
||||
def test_parse_sse_uses_last_response(self):
|
||||
"""If multiple responses exist, use the last one."""
|
||||
sse = (
|
||||
'data: {"jsonrpc":"2.0","result":"first","id":1}\n'
|
||||
"\n"
|
||||
'data: {"jsonrpc":"2.0","result":"second","id":2}\n'
|
||||
"\n"
|
||||
)
|
||||
body = MCPClient._parse_sse_response(sse)
|
||||
assert body["result"] == "second"
|
||||
|
||||
|
||||
# ── MCPClient unit tests ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestMCPClient:
|
||||
"""Tests for the MCP HTTP client."""
|
||||
|
||||
def test_build_headers_without_auth(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
headers = client._build_headers()
|
||||
assert "Authorization" not in headers
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
|
||||
def test_build_headers_with_auth(self):
|
||||
client = MCPClient("https://mcp.example.com", auth_token="my-token")
|
||||
headers = client._build_headers()
|
||||
assert headers["Authorization"] == "Bearer my-token"
|
||||
|
||||
def test_build_jsonrpc_request(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
req = client._build_jsonrpc_request("tools/list")
|
||||
assert req["jsonrpc"] == "2.0"
|
||||
assert req["method"] == "tools/list"
|
||||
assert "id" in req
|
||||
assert "params" not in req
|
||||
|
||||
def test_build_jsonrpc_request_with_params(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
req = client._build_jsonrpc_request(
|
||||
"tools/call", {"name": "test", "arguments": {"x": 1}}
|
||||
)
|
||||
assert req["params"] == {"name": "test", "arguments": {"x": 1}}
|
||||
|
||||
def test_request_id_increments(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
req1 = client._build_jsonrpc_request("tools/list")
|
||||
req2 = client._build_jsonrpc_request("tools/list")
|
||||
assert req2["id"] > req1["id"]
|
||||
|
||||
def test_server_url_trailing_slash_stripped(self):
|
||||
client = MCPClient("https://mcp.example.com/mcp/")
|
||||
assert client.server_url == "https://mcp.example.com/mcp"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_send_request_success(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.json.return_value = {
|
||||
"jsonrpc": "2.0",
|
||||
"result": {"tools": []},
|
||||
"id": 1,
|
||||
}
|
||||
|
||||
with patch.object(client, "_send_request", return_value={"tools": []}):
|
||||
result = await client._send_request("tools/list")
|
||||
assert result == {"tools": []}
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_send_request_error(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
async def mock_send(*args, **kwargs):
|
||||
raise MCPClientError("MCP server error [-32600]: Invalid Request")
|
||||
|
||||
with patch.object(client, "_send_request", side_effect=mock_send):
|
||||
with pytest.raises(MCPClientError, match="Invalid Request"):
|
||||
await client._send_request("tools/list")
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_list_tools(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
mock_result = {
|
||||
"tools": [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get current weather for a city",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "search",
|
||||
"description": "Search the web",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with patch.object(client, "_send_request", return_value=mock_result):
|
||||
tools = await client.list_tools()
|
||||
|
||||
assert len(tools) == 2
|
||||
assert tools[0].name == "get_weather"
|
||||
assert tools[0].description == "Get current weather for a city"
|
||||
assert tools[0].input_schema["properties"]["city"]["type"] == "string"
|
||||
assert tools[1].name == "search"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_list_tools_empty(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
with patch.object(client, "_send_request", return_value={"tools": []}):
|
||||
tools = await client.list_tools()
|
||||
|
||||
assert tools == []
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_list_tools_none_result(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
with patch.object(client, "_send_request", return_value=None):
|
||||
tools = await client.list_tools()
|
||||
|
||||
assert tools == []
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_tool_success(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
mock_result = {
|
||||
"content": [
|
||||
{"type": "text", "text": json.dumps({"temp": 20, "city": "London"})}
|
||||
],
|
||||
"isError": False,
|
||||
}
|
||||
|
||||
with patch.object(client, "_send_request", return_value=mock_result):
|
||||
result = await client.call_tool("get_weather", {"city": "London"})
|
||||
|
||||
assert not result.is_error
|
||||
assert len(result.content) == 1
|
||||
assert result.content[0]["type"] == "text"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_tool_error(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
mock_result = {
|
||||
"content": [{"type": "text", "text": "City not found"}],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
with patch.object(client, "_send_request", return_value=mock_result):
|
||||
result = await client.call_tool("get_weather", {"city": "???"})
|
||||
|
||||
assert result.is_error
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_tool_none_result(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
with patch.object(client, "_send_request", return_value=None):
|
||||
result = await client.call_tool("get_weather", {"city": "London"})
|
||||
|
||||
assert result.is_error
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_initialize(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
mock_result = {
|
||||
"protocolVersion": "2025-03-26",
|
||||
"capabilities": {"tools": {}},
|
||||
"serverInfo": {"name": "test-server", "version": "1.0.0"},
|
||||
}
|
||||
|
||||
with (
|
||||
patch.object(client, "_send_request", return_value=mock_result) as mock_req,
|
||||
patch.object(client, "_send_notification") as mock_notif,
|
||||
):
|
||||
result = await client.initialize()
|
||||
|
||||
mock_req.assert_called_once()
|
||||
mock_notif.assert_called_once_with("notifications/initialized")
|
||||
assert result["protocolVersion"] == "2025-03-26"
|
||||
|
||||
|
||||
# ── MCPToolBlock unit tests ──────────────────────────────────────────
|
||||
|
||||
MOCK_USER_ID = "test-user-123"
|
||||
|
||||
|
||||
class TestMCPToolBlock:
|
||||
"""Tests for the MCPToolBlock."""
|
||||
|
||||
def test_block_instantiation(self):
|
||||
block = MCPToolBlock()
|
||||
assert block.id == "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
|
||||
assert block.name == "MCPToolBlock"
|
||||
|
||||
def test_input_schema_has_required_fields(self):
|
||||
block = MCPToolBlock()
|
||||
schema = block.input_schema.jsonschema()
|
||||
props = schema.get("properties", {})
|
||||
assert "server_url" in props
|
||||
assert "selected_tool" in props
|
||||
assert "tool_arguments" in props
|
||||
assert "credentials" in props
|
||||
|
||||
def test_output_schema(self):
|
||||
block = MCPToolBlock()
|
||||
schema = block.output_schema.jsonschema()
|
||||
props = schema.get("properties", {})
|
||||
assert "result" in props
|
||||
assert "error" in props
|
||||
|
||||
def test_get_input_schema_with_tool_schema(self):
|
||||
tool_schema = {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"],
|
||||
}
|
||||
data = {"tool_input_schema": tool_schema}
|
||||
result = MCPToolBlock.Input.get_input_schema(data)
|
||||
assert result == tool_schema
|
||||
|
||||
def test_get_input_schema_without_tool_schema(self):
|
||||
result = MCPToolBlock.Input.get_input_schema({})
|
||||
assert result == {}
|
||||
|
||||
def test_get_input_defaults(self):
|
||||
data = {"tool_arguments": {"city": "London"}}
|
||||
result = MCPToolBlock.Input.get_input_defaults(data)
|
||||
assert result == {"city": "London"}
|
||||
|
||||
def test_get_missing_input(self):
|
||||
data = {
|
||||
"tool_input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"units": {"type": "string"},
|
||||
},
|
||||
"required": ["city", "units"],
|
||||
},
|
||||
"tool_arguments": {"city": "London"},
|
||||
}
|
||||
missing = MCPToolBlock.Input.get_missing_input(data)
|
||||
assert missing == {"units"}
|
||||
|
||||
def test_get_missing_input_all_present(self):
|
||||
data = {
|
||||
"tool_input_schema": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
"tool_arguments": {"city": "London"},
|
||||
}
|
||||
missing = MCPToolBlock.Input.get_missing_input(data)
|
||||
assert missing == set()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_with_mock(self):
|
||||
"""Test the block using the built-in test infrastructure."""
|
||||
block = MCPToolBlock()
|
||||
await execute_block_test(block)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_missing_server_url(self):
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="",
|
||||
selected_tool="test",
|
||||
)
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
assert outputs == [("error", "MCP server URL is required")]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_missing_tool(self):
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
selected_tool="",
|
||||
)
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
assert outputs == [
|
||||
("error", "No tool selected. Please select a tool from the dropdown.")
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_success(self):
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
selected_tool="get_weather",
|
||||
tool_input_schema={
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
},
|
||||
tool_arguments={"city": "London"},
|
||||
)
|
||||
|
||||
async def mock_call(*args, **kwargs):
|
||||
return {"temp": 20, "city": "London"}
|
||||
|
||||
block._call_mcp_tool = mock_call # type: ignore
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "result"
|
||||
assert outputs[0][1] == {"temp": 20, "city": "London"}
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_mcp_error(self):
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
selected_tool="bad_tool",
|
||||
)
|
||||
|
||||
async def mock_call(*args, **kwargs):
|
||||
raise MCPClientError("Tool not found")
|
||||
|
||||
block._call_mcp_tool = mock_call # type: ignore
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert outputs[0][0] == "error"
|
||||
assert "Tool not found" in outputs[0][1]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_mcp_tool_parses_json_text(self):
|
||||
block = MCPToolBlock()
|
||||
|
||||
mock_result = MCPCallResult(
|
||||
content=[
|
||||
{"type": "text", "text": '{"temp": 20}'},
|
||||
],
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
async def mock_init(self):
|
||||
return {}
|
||||
|
||||
async def mock_call(self, name, args):
|
||||
return mock_result
|
||||
|
||||
with (
|
||||
patch.object(MCPClient, "initialize", mock_init),
|
||||
patch.object(MCPClient, "call_tool", mock_call),
|
||||
):
|
||||
result = await block._call_mcp_tool(
|
||||
"https://mcp.example.com", "test_tool", {}
|
||||
)
|
||||
|
||||
assert result == {"temp": 20}
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_mcp_tool_plain_text(self):
|
||||
block = MCPToolBlock()
|
||||
|
||||
mock_result = MCPCallResult(
|
||||
content=[
|
||||
{"type": "text", "text": "Hello, world!"},
|
||||
],
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
async def mock_init(self):
|
||||
return {}
|
||||
|
||||
async def mock_call(self, name, args):
|
||||
return mock_result
|
||||
|
||||
with (
|
||||
patch.object(MCPClient, "initialize", mock_init),
|
||||
patch.object(MCPClient, "call_tool", mock_call),
|
||||
):
|
||||
result = await block._call_mcp_tool(
|
||||
"https://mcp.example.com", "test_tool", {}
|
||||
)
|
||||
|
||||
assert result == "Hello, world!"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_mcp_tool_multiple_content(self):
|
||||
block = MCPToolBlock()
|
||||
|
||||
mock_result = MCPCallResult(
|
||||
content=[
|
||||
{"type": "text", "text": "Part 1"},
|
||||
{"type": "text", "text": '{"part": 2}'},
|
||||
],
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
async def mock_init(self):
|
||||
return {}
|
||||
|
||||
async def mock_call(self, name, args):
|
||||
return mock_result
|
||||
|
||||
with (
|
||||
patch.object(MCPClient, "initialize", mock_init),
|
||||
patch.object(MCPClient, "call_tool", mock_call),
|
||||
):
|
||||
result = await block._call_mcp_tool(
|
||||
"https://mcp.example.com", "test_tool", {}
|
||||
)
|
||||
|
||||
assert result == ["Part 1", {"part": 2}]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_mcp_tool_error_result(self):
|
||||
block = MCPToolBlock()
|
||||
|
||||
mock_result = MCPCallResult(
|
||||
content=[{"type": "text", "text": "Something went wrong"}],
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
async def mock_init(self):
|
||||
return {}
|
||||
|
||||
async def mock_call(self, name, args):
|
||||
return mock_result
|
||||
|
||||
with (
|
||||
patch.object(MCPClient, "initialize", mock_init),
|
||||
patch.object(MCPClient, "call_tool", mock_call),
|
||||
):
|
||||
with pytest.raises(MCPClientError, match="returned an error"):
|
||||
await block._call_mcp_tool("https://mcp.example.com", "test_tool", {})
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_mcp_tool_image_content(self):
|
||||
block = MCPToolBlock()
|
||||
|
||||
mock_result = MCPCallResult(
|
||||
content=[
|
||||
{
|
||||
"type": "image",
|
||||
"data": "base64data==",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
],
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
async def mock_init(self):
|
||||
return {}
|
||||
|
||||
async def mock_call(self, name, args):
|
||||
return mock_result
|
||||
|
||||
with (
|
||||
patch.object(MCPClient, "initialize", mock_init),
|
||||
patch.object(MCPClient, "call_tool", mock_call),
|
||||
):
|
||||
result = await block._call_mcp_tool(
|
||||
"https://mcp.example.com", "test_tool", {}
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"type": "image",
|
||||
"data": "base64data==",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_with_credentials(self):
|
||||
"""Verify the block uses OAuth2Credentials and passes auth token."""
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
selected_tool="test_tool",
|
||||
)
|
||||
|
||||
captured_tokens: list[str | None] = []
|
||||
|
||||
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
||||
captured_tokens.append(auth_token)
|
||||
return "ok"
|
||||
|
||||
block._call_mcp_tool = mock_call # type: ignore
|
||||
|
||||
test_creds = OAuth2Credentials(
|
||||
id="cred-123",
|
||||
provider="mcp",
|
||||
access_token=SecretStr("resolved-token"),
|
||||
refresh_token=SecretStr(""),
|
||||
scopes=[],
|
||||
title="Test MCP credential",
|
||||
)
|
||||
|
||||
async for _ in block.run(
|
||||
input_data, user_id=MOCK_USER_ID, credentials=test_creds
|
||||
):
|
||||
pass
|
||||
|
||||
assert captured_tokens == ["resolved-token"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_without_credentials(self):
|
||||
"""Verify the block works without credentials (public server)."""
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
selected_tool="test_tool",
|
||||
)
|
||||
|
||||
captured_tokens: list[str | None] = []
|
||||
|
||||
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
||||
captured_tokens.append(auth_token)
|
||||
return "ok"
|
||||
|
||||
block._call_mcp_tool = mock_call # type: ignore
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert captured_tokens == [None]
|
||||
assert outputs == [("result", "ok")]
|
||||
@@ -1,242 +0,0 @@
|
||||
"""
|
||||
Tests for MCP OAuth handler.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.mcp.client import MCPClient
|
||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
|
||||
def _mock_response(json_data: dict, status: int = 200) -> MagicMock:
|
||||
"""Create a mock Response with synchronous json() (matching Requests.Response)."""
|
||||
resp = MagicMock()
|
||||
resp.status = status
|
||||
resp.ok = 200 <= status < 300
|
||||
resp.json.return_value = json_data
|
||||
return resp
|
||||
|
||||
|
||||
class TestMCPOAuthHandler:
|
||||
"""Tests for the MCPOAuthHandler."""
|
||||
|
||||
def _make_handler(self, **overrides) -> MCPOAuthHandler:
|
||||
defaults = {
|
||||
"client_id": "test-client-id",
|
||||
"client_secret": "test-client-secret",
|
||||
"redirect_uri": "https://app.example.com/callback",
|
||||
"authorize_url": "https://auth.example.com/authorize",
|
||||
"token_url": "https://auth.example.com/token",
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return MCPOAuthHandler(**defaults)
|
||||
|
||||
def test_get_login_url_basic(self):
|
||||
handler = self._make_handler()
|
||||
url = handler.get_login_url(
|
||||
scopes=["read", "write"],
|
||||
state="random-state-token",
|
||||
code_challenge="S256-challenge-value",
|
||||
)
|
||||
|
||||
assert "https://auth.example.com/authorize?" in url
|
||||
assert "response_type=code" in url
|
||||
assert "client_id=test-client-id" in url
|
||||
assert "state=random-state-token" in url
|
||||
assert "code_challenge=S256-challenge-value" in url
|
||||
assert "code_challenge_method=S256" in url
|
||||
assert "scope=read+write" in url
|
||||
|
||||
def test_get_login_url_with_resource(self):
|
||||
handler = self._make_handler(resource_url="https://mcp.example.com/mcp")
|
||||
url = handler.get_login_url(
|
||||
scopes=[], state="state", code_challenge="challenge"
|
||||
)
|
||||
|
||||
assert "resource=https" in url
|
||||
|
||||
def test_get_login_url_without_pkce(self):
|
||||
handler = self._make_handler()
|
||||
url = handler.get_login_url(scopes=["read"], state="state", code_challenge=None)
|
||||
|
||||
assert "code_challenge" not in url
|
||||
assert "code_challenge_method" not in url
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_exchange_code_for_tokens(self):
|
||||
handler = self._make_handler()
|
||||
|
||||
resp = _mock_response(
|
||||
{
|
||||
"access_token": "new-access-token",
|
||||
"refresh_token": "new-refresh-token",
|
||||
"expires_in": 3600,
|
||||
"token_type": "Bearer",
|
||||
}
|
||||
)
|
||||
|
||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
||||
instance = MockRequests.return_value
|
||||
instance.post = AsyncMock(return_value=resp)
|
||||
|
||||
creds = await handler.exchange_code_for_tokens(
|
||||
code="auth-code",
|
||||
scopes=["read"],
|
||||
code_verifier="pkce-verifier",
|
||||
)
|
||||
|
||||
assert isinstance(creds, OAuth2Credentials)
|
||||
assert creds.access_token.get_secret_value() == "new-access-token"
|
||||
assert creds.refresh_token is not None
|
||||
assert creds.refresh_token.get_secret_value() == "new-refresh-token"
|
||||
assert creds.scopes == ["read"]
|
||||
assert creds.access_token_expires_at is not None
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_refresh_tokens(self):
|
||||
handler = self._make_handler()
|
||||
|
||||
existing_creds = OAuth2Credentials(
|
||||
id="existing-id",
|
||||
provider="mcp",
|
||||
access_token=SecretStr("old-token"),
|
||||
refresh_token=SecretStr("old-refresh"),
|
||||
scopes=["read"],
|
||||
title="test",
|
||||
)
|
||||
|
||||
resp = _mock_response(
|
||||
{
|
||||
"access_token": "refreshed-token",
|
||||
"refresh_token": "new-refresh",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
)
|
||||
|
||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
||||
instance = MockRequests.return_value
|
||||
instance.post = AsyncMock(return_value=resp)
|
||||
|
||||
refreshed = await handler._refresh_tokens(existing_creds)
|
||||
|
||||
assert refreshed.id == "existing-id"
|
||||
assert refreshed.access_token.get_secret_value() == "refreshed-token"
|
||||
assert refreshed.refresh_token is not None
|
||||
assert refreshed.refresh_token.get_secret_value() == "new-refresh"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_refresh_tokens_no_refresh_token(self):
|
||||
handler = self._make_handler()
|
||||
|
||||
creds = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=["read"],
|
||||
title="test",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="No refresh token"):
|
||||
await handler._refresh_tokens(creds)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_revoke_tokens_no_url(self):
|
||||
handler = self._make_handler(revoke_url=None)
|
||||
|
||||
creds = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=[],
|
||||
title="test",
|
||||
)
|
||||
|
||||
result = await handler.revoke_tokens(creds)
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_revoke_tokens_with_url(self):
|
||||
handler = self._make_handler(revoke_url="https://auth.example.com/revoke")
|
||||
|
||||
creds = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=[],
|
||||
title="test",
|
||||
)
|
||||
|
||||
resp = _mock_response({}, status=200)
|
||||
|
||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
||||
instance = MockRequests.return_value
|
||||
instance.post = AsyncMock(return_value=resp)
|
||||
|
||||
result = await handler.revoke_tokens(creds)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestMCPClientDiscovery:
|
||||
"""Tests for MCPClient OAuth metadata discovery."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_auth_found(self):
|
||||
client = MCPClient("https://mcp.example.com/mcp")
|
||||
|
||||
metadata = {
|
||||
"authorization_servers": ["https://auth.example.com"],
|
||||
"resource": "https://mcp.example.com/mcp",
|
||||
}
|
||||
|
||||
resp = _mock_response(metadata, status=200)
|
||||
|
||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
||||
instance = MockRequests.return_value
|
||||
instance.get = AsyncMock(return_value=resp)
|
||||
|
||||
result = await client.discover_auth()
|
||||
|
||||
assert result is not None
|
||||
assert result["authorization_servers"] == ["https://auth.example.com"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_auth_not_found(self):
|
||||
client = MCPClient("https://mcp.example.com/mcp")
|
||||
|
||||
resp = _mock_response({}, status=404)
|
||||
|
||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
||||
instance = MockRequests.return_value
|
||||
instance.get = AsyncMock(return_value=resp)
|
||||
|
||||
result = await client.discover_auth()
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_auth_server_metadata(self):
|
||||
client = MCPClient("https://mcp.example.com/mcp")
|
||||
|
||||
server_metadata = {
|
||||
"issuer": "https://auth.example.com",
|
||||
"authorization_endpoint": "https://auth.example.com/authorize",
|
||||
"token_endpoint": "https://auth.example.com/token",
|
||||
"registration_endpoint": "https://auth.example.com/register",
|
||||
"code_challenge_methods_supported": ["S256"],
|
||||
}
|
||||
|
||||
resp = _mock_response(server_metadata, status=200)
|
||||
|
||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
||||
instance = MockRequests.return_value
|
||||
instance.get = AsyncMock(return_value=resp)
|
||||
|
||||
result = await client.discover_auth_server_metadata(
|
||||
"https://auth.example.com"
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result["authorization_endpoint"] == "https://auth.example.com/authorize"
|
||||
assert result["token_endpoint"] == "https://auth.example.com/token"
|
||||
@@ -1,162 +0,0 @@
|
||||
"""
|
||||
Minimal MCP server for integration testing.
|
||||
|
||||
Implements the MCP Streamable HTTP transport (JSON-RPC 2.0 over HTTP POST)
|
||||
with a few sample tools. Runs on localhost with a random available port.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Sample tools this test server exposes
|
||||
TEST_TOOLS = [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get current weather for a city",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City name",
|
||||
},
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "add_numbers",
|
||||
"description": "Add two numbers together",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "number", "description": "First number"},
|
||||
"b": {"type": "number", "description": "Second number"},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "echo",
|
||||
"description": "Echo back the input message",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {"type": "string", "description": "Message to echo"},
|
||||
},
|
||||
"required": ["message"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _handle_initialize(params: dict) -> dict:
|
||||
return {
|
||||
"protocolVersion": "2025-03-26",
|
||||
"capabilities": {"tools": {"listChanged": False}},
|
||||
"serverInfo": {"name": "test-mcp-server", "version": "1.0.0"},
|
||||
}
|
||||
|
||||
|
||||
def _handle_tools_list(params: dict) -> dict:
|
||||
return {"tools": TEST_TOOLS}
|
||||
|
||||
|
||||
def _handle_tools_call(params: dict) -> dict:
|
||||
tool_name = params.get("name", "")
|
||||
arguments = params.get("arguments", {})
|
||||
|
||||
if tool_name == "get_weather":
|
||||
city = arguments.get("city", "Unknown")
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(
|
||||
{"city": city, "temperature": 22, "condition": "sunny"}
|
||||
),
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
elif tool_name == "add_numbers":
|
||||
a = arguments.get("a", 0)
|
||||
b = arguments.get("b", 0)
|
||||
return {
|
||||
"content": [{"type": "text", "text": json.dumps({"result": a + b})}],
|
||||
}
|
||||
|
||||
elif tool_name == "echo":
|
||||
message = arguments.get("message", "")
|
||||
return {
|
||||
"content": [{"type": "text", "text": message}],
|
||||
}
|
||||
|
||||
else:
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
|
||||
HANDLERS = {
|
||||
"initialize": _handle_initialize,
|
||||
"tools/list": _handle_tools_list,
|
||||
"tools/call": _handle_tools_call,
|
||||
}
|
||||
|
||||
|
||||
async def handle_mcp_request(request: web.Request) -> web.Response:
|
||||
"""Handle incoming MCP JSON-RPC 2.0 requests."""
|
||||
# Check auth if configured
|
||||
expected_token = request.app.get("auth_token")
|
||||
if expected_token:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header != f"Bearer {expected_token}":
|
||||
return web.json_response(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32001, "message": "Unauthorized"},
|
||||
"id": None,
|
||||
},
|
||||
status=401,
|
||||
)
|
||||
|
||||
body = await request.json()
|
||||
|
||||
# Handle notifications (no id field) — just acknowledge
|
||||
if "id" not in body:
|
||||
return web.Response(status=202)
|
||||
|
||||
method = body.get("method", "")
|
||||
params = body.get("params", {})
|
||||
request_id = body.get("id")
|
||||
|
||||
handler = HANDLERS.get(method)
|
||||
if not handler:
|
||||
return web.json_response(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"error": {
|
||||
"code": -32601,
|
||||
"message": f"Method not found: {method}",
|
||||
},
|
||||
"id": request_id,
|
||||
}
|
||||
)
|
||||
|
||||
result = handler(params)
|
||||
return web.json_response({"jsonrpc": "2.0", "result": result, "id": request_id})
|
||||
|
||||
|
||||
def create_test_mcp_app(auth_token: str | None = None) -> web.Application:
|
||||
"""Create an aiohttp app that acts as an MCP server."""
|
||||
app = web.Application()
|
||||
app.router.add_post("/mcp", handle_mcp_request)
|
||||
if auth_token:
|
||||
app["auth_token"] = auth_token
|
||||
return app
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -27,6 +28,54 @@ async def server():
|
||||
yield server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_id() -> str:
|
||||
"""Test user ID fixture."""
|
||||
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user_id() -> str:
|
||||
"""Admin user ID fixture."""
|
||||
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_user_id() -> str:
|
||||
"""Target user ID fixture."""
|
||||
return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_test_user(test_user_id):
|
||||
"""Create test user in database before tests."""
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Create the test user in the database using JWT token format
|
||||
user_data = {
|
||||
"sub": test_user_id,
|
||||
"email": "test@example.com",
|
||||
"user_metadata": {"name": "Test User"},
|
||||
}
|
||||
await get_or_create_user(user_data)
|
||||
return test_user_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_admin_user(admin_user_id):
|
||||
"""Create admin user in database before tests."""
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Create the admin user in the database using JWT token format
|
||||
user_data = {
|
||||
"sub": admin_user_id,
|
||||
"email": "test-admin@example.com",
|
||||
"user_metadata": {"name": "Test Admin"},
|
||||
}
|
||||
await get_or_create_user(user_data)
|
||||
return admin_user_id
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
|
||||
async def graph_cleanup(server):
|
||||
created_graph_ids = []
|
||||
|
||||
8
autogpt_platform/backend/backend/copilot/__init__.py
Normal file
8
autogpt_platform/backend/backend/copilot/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""CoPilot module - AI assistant for AutoGPT platform.
|
||||
|
||||
This module contains the core CoPilot functionality including:
|
||||
- AI generation service (LLM calls)
|
||||
- Tool execution
|
||||
- Session management
|
||||
- Stream registry for SSE reconnection
|
||||
"""
|
||||
@@ -119,8 +119,9 @@ class ChatCompletionConsumer:
|
||||
"""Lazily initialize Prisma client on first use."""
|
||||
if self._prisma is None:
|
||||
database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
|
||||
self._prisma = Prisma(datasource={"url": database_url})
|
||||
await self._prisma.connect()
|
||||
prisma = Prisma(datasource={"url": database_url})
|
||||
await prisma.connect()
|
||||
self._prisma = prisma
|
||||
logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
|
||||
return self._prisma
|
||||
|
||||
@@ -14,7 +14,7 @@ from prisma.types import (
|
||||
ChatSessionWhereInput,
|
||||
)
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.data import db
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -147,7 +147,7 @@ async def add_chat_messages_batch(
|
||||
|
||||
created_messages = []
|
||||
|
||||
async with transaction() as tx:
|
||||
async with db.transaction() as tx:
|
||||
for i, msg in enumerate(messages):
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput
|
||||
# directly because Prisma's TypedDict validation rejects optional fields
|
||||
@@ -0,0 +1,5 @@
|
||||
"""CoPilot Executor - Dedicated service for AI generation and tool execution.
|
||||
|
||||
This module contains the executor service that processes CoPilot tasks
|
||||
from RabbitMQ, following the graph executor pattern.
|
||||
"""
|
||||
@@ -0,0 +1,18 @@
|
||||
"""Entry point for running the CoPilot Executor service.
|
||||
|
||||
Usage:
|
||||
python -m backend.copilot.executor
|
||||
"""
|
||||
|
||||
from backend.app import run_processes
|
||||
|
||||
from .manager import CoPilotExecutor
|
||||
|
||||
|
||||
def main():
|
||||
"""Run the CoPilot Executor service."""
|
||||
run_processes(CoPilotExecutor())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
508
autogpt_platform/backend/backend/copilot/executor/manager.py
Normal file
508
autogpt_platform/backend/backend/copilot/executor/manager.py
Normal file
@@ -0,0 +1,508 @@
|
||||
"""CoPilot Executor Manager - main service for CoPilot task execution.
|
||||
|
||||
This module contains the CoPilotExecutor class that consumes chat tasks from
|
||||
RabbitMQ and processes them using a thread pool, following the graph executor pattern.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
|
||||
from pika.adapters.blocking_connection import BlockingChannel
|
||||
from pika.exceptions import AMQPChannelError, AMQPConnectionError
|
||||
from pika.spec import Basic, BasicProperties
|
||||
from prometheus_client import Gauge, start_http_server
|
||||
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.cluster_lock import ClusterLock
|
||||
from backend.util.decorator import error_logged
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.process import AppProcess
|
||||
from backend.util.retry import continuous_retry
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .processor import execute_copilot_task, init_worker
|
||||
from .utils import (
|
||||
COPILOT_CANCEL_QUEUE_NAME,
|
||||
COPILOT_EXECUTION_QUEUE_NAME,
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
||||
CancelCoPilotEvent,
|
||||
CoPilotExecutionEntry,
|
||||
create_copilot_queue_config,
|
||||
)
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
|
||||
settings = Settings()
|
||||
|
||||
# Prometheus metrics
|
||||
active_tasks_gauge = Gauge(
|
||||
"copilot_executor_active_tasks",
|
||||
"Number of active CoPilot tasks",
|
||||
)
|
||||
pool_size_gauge = Gauge(
|
||||
"copilot_executor_pool_size",
|
||||
"Maximum number of CoPilot executor workers",
|
||||
)
|
||||
utilization_gauge = Gauge(
|
||||
"copilot_executor_utilization_ratio",
|
||||
"Ratio of active tasks to pool size",
|
||||
)
|
||||
|
||||
|
||||
class CoPilotExecutor(AppProcess):
|
||||
"""CoPilot Executor service for processing chat generation tasks.
|
||||
|
||||
This service consumes tasks from RabbitMQ, processes them using a thread pool,
|
||||
and publishes results to Redis Streams. It follows the graph executor pattern
|
||||
for reliable message handling and graceful shutdown.
|
||||
|
||||
Key features:
|
||||
- RabbitMQ-based task distribution with manual acknowledgment
|
||||
- Thread pool executor for concurrent task processing
|
||||
- Cluster lock for duplicate prevention across pods
|
||||
- Graceful shutdown with timeout for in-flight tasks
|
||||
- FANOUT exchange for cancellation broadcast
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.pool_size = settings.config.num_copilot_workers
|
||||
self.active_tasks: dict[str, tuple[Future, threading.Event]] = {}
|
||||
self.executor_id = str(uuid.uuid4())
|
||||
|
||||
self._executor = None
|
||||
self._stop_consuming = None
|
||||
|
||||
self._cancel_thread = None
|
||||
self._cancel_client = None
|
||||
self._run_thread = None
|
||||
self._run_client = None
|
||||
|
||||
self._task_locks: dict[str, ClusterLock] = {}
|
||||
|
||||
# ============ Main Entry Points (AppProcess interface) ============ #
|
||||
|
||||
def run(self):
|
||||
"""Main service loop - consume from RabbitMQ."""
|
||||
logger.info(f"Pod assigned executor_id: {self.executor_id}")
|
||||
logger.info(f"Spawn max-{self.pool_size} workers...")
|
||||
|
||||
pool_size_gauge.set(self.pool_size)
|
||||
self._update_metrics()
|
||||
start_http_server(settings.config.copilot_executor_port)
|
||||
|
||||
self.cancel_thread.start()
|
||||
self.run_thread.start()
|
||||
|
||||
while True:
|
||||
time.sleep(1e5)
|
||||
|
||||
def cleanup(self):
|
||||
"""Graceful shutdown with active execution waiting."""
|
||||
pid = os.getpid()
|
||||
logger.info(f"[cleanup {pid}] Starting graceful shutdown...")
|
||||
|
||||
# Signal the consumer thread to stop
|
||||
try:
|
||||
self.stop_consuming.set()
|
||||
run_channel = self.run_client.get_channel()
|
||||
run_channel.connection.add_callback_threadsafe(
|
||||
lambda: run_channel.stop_consuming()
|
||||
)
|
||||
logger.info(f"[cleanup {pid}] Consumer has been signaled to stop")
|
||||
except Exception as e:
|
||||
logger.error(f"[cleanup {pid}] Error stopping consumer: {e}")
|
||||
|
||||
# Wait for active executions to complete
|
||||
if self.active_tasks:
|
||||
logger.info(
|
||||
f"[cleanup {pid}] Waiting for {len(self.active_tasks)} active tasks to complete (timeout: {GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s)..."
|
||||
)
|
||||
|
||||
start_time = time.monotonic()
|
||||
last_refresh = start_time
|
||||
lock_refresh_interval = settings.config.cluster_lock_timeout / 10
|
||||
|
||||
while (
|
||||
self.active_tasks
|
||||
and (time.monotonic() - start_time) < GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS
|
||||
):
|
||||
self._cleanup_completed_tasks()
|
||||
if not self.active_tasks:
|
||||
break
|
||||
|
||||
# Refresh cluster locks periodically
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_refresh >= lock_refresh_interval:
|
||||
for lock in self._task_locks.values():
|
||||
try:
|
||||
lock.refresh()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[cleanup {pid}] Failed to refresh lock: {e}"
|
||||
)
|
||||
last_refresh = current_time
|
||||
|
||||
logger.info(
|
||||
f"[cleanup {pid}] {len(self.active_tasks)} tasks still active, waiting..."
|
||||
)
|
||||
time.sleep(10.0)
|
||||
|
||||
# Stop message consumers
|
||||
if self._run_thread:
|
||||
self._stop_message_consumers(
|
||||
self._run_thread, self.run_client, "[cleanup][run]"
|
||||
)
|
||||
if self._cancel_thread:
|
||||
self._stop_message_consumers(
|
||||
self._cancel_thread, self.cancel_client, "[cleanup][cancel]"
|
||||
)
|
||||
|
||||
# Shutdown executor
|
||||
if self._executor:
|
||||
logger.info(f"[cleanup {pid}] Shutting down executor...")
|
||||
self._executor.shutdown(wait=False)
|
||||
|
||||
# Release any remaining locks
|
||||
for task_id, lock in list(self._task_locks.items()):
|
||||
try:
|
||||
lock.release()
|
||||
logger.info(f"[cleanup {pid}] Released lock for {task_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[cleanup {pid}] Failed to release lock for {task_id}: {e}"
|
||||
)
|
||||
|
||||
logger.info(f"[cleanup {pid}] Graceful shutdown completed")
|
||||
|
||||
# ============ RabbitMQ Consumer Methods ============ #
|
||||
|
||||
@continuous_retry()
|
||||
def _consume_cancel(self):
|
||||
"""Consume cancellation messages from FANOUT exchange."""
|
||||
if self.stop_consuming.is_set() and not self.active_tasks:
|
||||
logger.info("Stop reconnecting cancel consumer - service cleaned up")
|
||||
return
|
||||
|
||||
if not self.cancel_client.is_ready:
|
||||
self.cancel_client.disconnect()
|
||||
self.cancel_client.connect()
|
||||
|
||||
# Check again after connect - shutdown may have been requested
|
||||
if self.stop_consuming.is_set() and not self.active_tasks:
|
||||
logger.info("Stop consuming requested during reconnect - disconnecting")
|
||||
self.cancel_client.disconnect()
|
||||
return
|
||||
|
||||
cancel_channel = self.cancel_client.get_channel()
|
||||
cancel_channel.basic_consume(
|
||||
queue=COPILOT_CANCEL_QUEUE_NAME,
|
||||
on_message_callback=self._handle_cancel_message,
|
||||
auto_ack=True,
|
||||
)
|
||||
logger.info("Starting cancel message consumer...")
|
||||
cancel_channel.start_consuming()
|
||||
if not self.stop_consuming.is_set() or self.active_tasks:
|
||||
raise RuntimeError("Cancel message consumer stopped unexpectedly")
|
||||
logger.info("Cancel message consumer stopped gracefully")
|
||||
|
||||
@continuous_retry()
|
||||
def _consume_run(self):
|
||||
"""Consume run messages from DIRECT exchange."""
|
||||
if self.stop_consuming.is_set():
|
||||
logger.info("Stop reconnecting run consumer - service cleaned up")
|
||||
return
|
||||
|
||||
if not self.run_client.is_ready:
|
||||
self.run_client.disconnect()
|
||||
self.run_client.connect()
|
||||
|
||||
# Check again after connect - shutdown may have been requested
|
||||
if self.stop_consuming.is_set():
|
||||
logger.info("Stop consuming requested during reconnect - disconnecting")
|
||||
self.run_client.disconnect()
|
||||
return
|
||||
|
||||
run_channel = self.run_client.get_channel()
|
||||
run_channel.basic_qos(prefetch_count=self.pool_size)
|
||||
|
||||
run_channel.basic_consume(
|
||||
queue=COPILOT_EXECUTION_QUEUE_NAME,
|
||||
on_message_callback=self._handle_run_message,
|
||||
auto_ack=False,
|
||||
consumer_tag="copilot_execution_consumer",
|
||||
)
|
||||
logger.info("Starting to consume run messages...")
|
||||
run_channel.start_consuming()
|
||||
if not self.stop_consuming.is_set():
|
||||
raise RuntimeError("Run message consumer stopped unexpectedly")
|
||||
logger.info("Run message consumer stopped gracefully")
|
||||
|
||||
# ============ Message Handlers ============ #
|
||||
|
||||
@error_logged(swallow=True)
|
||||
def _handle_cancel_message(
|
||||
self,
|
||||
_channel: BlockingChannel,
|
||||
_method: Basic.Deliver,
|
||||
_properties: BasicProperties,
|
||||
body: bytes,
|
||||
):
|
||||
"""Handle cancel message from FANOUT exchange."""
|
||||
request = CancelCoPilotEvent.model_validate_json(body)
|
||||
task_id = request.task_id
|
||||
if not task_id:
|
||||
logger.warning("Cancel message missing 'task_id'")
|
||||
return
|
||||
if task_id not in self.active_tasks:
|
||||
logger.debug(f"Cancel received for {task_id} but not active")
|
||||
return
|
||||
|
||||
_, cancel_event = self.active_tasks[task_id]
|
||||
logger.info(f"Received cancel for {task_id}")
|
||||
if not cancel_event.is_set():
|
||||
cancel_event.set()
|
||||
else:
|
||||
logger.debug(f"Cancel already set for {task_id}")
|
||||
|
||||
def _handle_run_message(
|
||||
self,
|
||||
_channel: BlockingChannel,
|
||||
method: Basic.Deliver,
|
||||
_properties: BasicProperties,
|
||||
body: bytes,
|
||||
):
|
||||
"""Handle run message from DIRECT exchange."""
|
||||
delivery_tag = method.delivery_tag
|
||||
# Capture the channel used at message delivery time to ensure we ack
|
||||
# on the correct channel. Delivery tags are channel-scoped and become
|
||||
# invalid if the channel is recreated after reconnection.
|
||||
delivery_channel = _channel
|
||||
|
||||
def ack_message(reject: bool, requeue: bool):
|
||||
"""Acknowledge or reject the message.
|
||||
|
||||
Uses the channel from the original message delivery. If the channel
|
||||
is no longer open (e.g., after reconnection), logs a warning and
|
||||
skips the ack - RabbitMQ will redeliver the message automatically.
|
||||
"""
|
||||
try:
|
||||
if not delivery_channel.is_open:
|
||||
logger.warning(
|
||||
f"Channel closed, cannot ack delivery_tag={delivery_tag}. "
|
||||
"Message will be redelivered by RabbitMQ."
|
||||
)
|
||||
return
|
||||
|
||||
if reject:
|
||||
delivery_channel.connection.add_callback_threadsafe(
|
||||
lambda: delivery_channel.basic_nack(
|
||||
delivery_tag, requeue=requeue
|
||||
)
|
||||
)
|
||||
else:
|
||||
delivery_channel.connection.add_callback_threadsafe(
|
||||
lambda: delivery_channel.basic_ack(delivery_tag)
|
||||
)
|
||||
except (AMQPChannelError, AMQPConnectionError) as e:
|
||||
# Channel/connection errors indicate stale delivery tag - don't retry
|
||||
logger.warning(
|
||||
f"Cannot ack delivery_tag={delivery_tag} due to channel/connection "
|
||||
f"error: {e}. Message will be redelivered by RabbitMQ."
|
||||
)
|
||||
except Exception as e:
|
||||
# Other errors might be transient, but log and skip to avoid blocking
|
||||
logger.error(
|
||||
f"Unexpected error acking delivery_tag={delivery_tag}: {e}"
|
||||
)
|
||||
|
||||
# Check if we're shutting down
|
||||
if self.stop_consuming.is_set():
|
||||
logger.info("Rejecting new task during shutdown")
|
||||
ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
# Check if we can accept more tasks
|
||||
self._cleanup_completed_tasks()
|
||||
if len(self.active_tasks) >= self.pool_size:
|
||||
ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
try:
|
||||
entry = CoPilotExecutionEntry.model_validate_json(body)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not parse run message: {e}, body={body}")
|
||||
ack_message(reject=True, requeue=False)
|
||||
return
|
||||
|
||||
task_id = entry.task_id
|
||||
|
||||
# Check for local duplicate - task is already running on this executor
|
||||
if task_id in self.active_tasks:
|
||||
logger.warning(
|
||||
f"Task {task_id} already running locally, rejecting duplicate"
|
||||
)
|
||||
ack_message(reject=True, requeue=False)
|
||||
return
|
||||
|
||||
# Try to acquire cluster-wide lock
|
||||
cluster_lock = ClusterLock(
|
||||
redis=redis.get_redis(),
|
||||
key=f"copilot:task:{task_id}:lock",
|
||||
owner_id=self.executor_id,
|
||||
timeout=settings.config.cluster_lock_timeout,
|
||||
)
|
||||
current_owner = cluster_lock.try_acquire()
|
||||
if current_owner != self.executor_id:
|
||||
if current_owner is not None:
|
||||
logger.warning(f"Task {task_id} already running on pod {current_owner}")
|
||||
ack_message(reject=True, requeue=False)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Could not acquire lock for {task_id} - Redis unavailable"
|
||||
)
|
||||
ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
# Execute the task
|
||||
try:
|
||||
self._task_locks[task_id] = cluster_lock
|
||||
|
||||
logger.info(
|
||||
f"Acquired cluster lock for {task_id}, executor_id={self.executor_id}"
|
||||
)
|
||||
|
||||
cancel_event = threading.Event()
|
||||
future = self.executor.submit(
|
||||
execute_copilot_task, entry, cancel_event, cluster_lock
|
||||
)
|
||||
self.active_tasks[task_id] = (future, cancel_event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to setup execution for {task_id}: {e}")
|
||||
cluster_lock.release()
|
||||
if task_id in self._task_locks:
|
||||
del self._task_locks[task_id]
|
||||
ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
self._update_metrics()
|
||||
|
||||
def on_run_done(f: Future):
|
||||
logger.info(f"Run completed for {task_id}")
|
||||
try:
|
||||
if exec_error := f.exception():
|
||||
logger.error(f"Execution for {task_id} failed: {exec_error}")
|
||||
# Don't requeue failed tasks - they've been marked as failed
|
||||
# in the stream registry. Requeuing would cause infinite retries
|
||||
# for deterministic failures.
|
||||
ack_message(reject=True, requeue=False)
|
||||
else:
|
||||
ack_message(reject=False, requeue=False)
|
||||
except BaseException as e:
|
||||
logger.exception(f"Error in run completion callback: {e}")
|
||||
finally:
|
||||
# Release the cluster lock
|
||||
if task_id in self._task_locks:
|
||||
logger.info(f"Releasing cluster lock for {task_id}")
|
||||
self._task_locks[task_id].release()
|
||||
del self._task_locks[task_id]
|
||||
self._cleanup_completed_tasks()
|
||||
|
||||
future.add_done_callback(on_run_done)
|
||||
|
||||
# ============ Helper Methods ============ #
|
||||
|
||||
def _cleanup_completed_tasks(self) -> list[str]:
|
||||
"""Remove completed futures from active_tasks and update metrics."""
|
||||
completed_tasks = []
|
||||
for task_id, (future, _) in self.active_tasks.items():
|
||||
if future.done():
|
||||
completed_tasks.append(task_id)
|
||||
|
||||
for task_id in completed_tasks:
|
||||
logger.info(f"Cleaned up completed task {task_id}")
|
||||
self.active_tasks.pop(task_id, None)
|
||||
|
||||
self._update_metrics()
|
||||
return completed_tasks
|
||||
|
||||
def _update_metrics(self):
|
||||
"""Update Prometheus metrics."""
|
||||
active_count = len(self.active_tasks)
|
||||
active_tasks_gauge.set(active_count)
|
||||
if self.stop_consuming.is_set():
|
||||
utilization_gauge.set(1.0)
|
||||
else:
|
||||
utilization_gauge.set(
|
||||
active_count / self.pool_size if self.pool_size > 0 else 0
|
||||
)
|
||||
|
||||
def _stop_message_consumers(
|
||||
self, thread: threading.Thread, client: SyncRabbitMQ, prefix: str
|
||||
):
|
||||
"""Stop a message consumer thread."""
|
||||
try:
|
||||
channel = client.get_channel()
|
||||
channel.connection.add_callback_threadsafe(lambda: channel.stop_consuming())
|
||||
|
||||
thread.join(timeout=300)
|
||||
if thread.is_alive():
|
||||
logger.error(
|
||||
f"{prefix} Thread did not finish in time, forcing disconnect"
|
||||
)
|
||||
|
||||
client.disconnect()
|
||||
logger.info(f"{prefix} Client disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"{prefix} Error disconnecting client: {e}")
|
||||
|
||||
# ============ Lazy-initialized Properties ============ #
|
||||
|
||||
@property
|
||||
def cancel_thread(self) -> threading.Thread:
|
||||
if self._cancel_thread is None:
|
||||
self._cancel_thread = threading.Thread(
|
||||
target=lambda: self._consume_cancel(),
|
||||
daemon=True,
|
||||
)
|
||||
return self._cancel_thread
|
||||
|
||||
@property
|
||||
def run_thread(self) -> threading.Thread:
|
||||
if self._run_thread is None:
|
||||
self._run_thread = threading.Thread(
|
||||
target=lambda: self._consume_run(),
|
||||
daemon=True,
|
||||
)
|
||||
return self._run_thread
|
||||
|
||||
@property
|
||||
def stop_consuming(self) -> threading.Event:
|
||||
if self._stop_consuming is None:
|
||||
self._stop_consuming = threading.Event()
|
||||
return self._stop_consuming
|
||||
|
||||
@property
|
||||
def executor(self) -> ThreadPoolExecutor:
|
||||
if self._executor is None:
|
||||
self._executor = ThreadPoolExecutor(
|
||||
max_workers=self.pool_size,
|
||||
initializer=init_worker,
|
||||
)
|
||||
return self._executor
|
||||
|
||||
@property
|
||||
def cancel_client(self) -> SyncRabbitMQ:
|
||||
if self._cancel_client is None:
|
||||
self._cancel_client = SyncRabbitMQ(create_copilot_queue_config())
|
||||
return self._cancel_client
|
||||
|
||||
@property
|
||||
def run_client(self) -> SyncRabbitMQ:
|
||||
if self._run_client is None:
|
||||
self._run_client = SyncRabbitMQ(create_copilot_queue_config())
|
||||
return self._run_client
|
||||
237
autogpt_platform/backend/backend/copilot/executor/processor.py
Normal file
237
autogpt_platform/backend/backend/copilot/executor/processor.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""CoPilot execution processor - per-worker execution logic.
|
||||
|
||||
This module contains the processor class that handles CoPilot task execution
|
||||
in a thread-local context, following the graph executor pattern.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
from backend.copilot import service as copilot_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamFinishStep
|
||||
from backend.executor.cluster_lock import ClusterLock
|
||||
from backend.util.decorator import error_logged
|
||||
from backend.util.logging import TruncatedLogger, configure_logging
|
||||
from backend.util.process import set_service_name
|
||||
from backend.util.retry import func_retry
|
||||
|
||||
from .utils import CoPilotExecutionEntry, CoPilotLogMetadata
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
|
||||
|
||||
|
||||
# ============ Module Entry Points ============ #
|
||||
|
||||
# Thread-local storage for processor instances
|
||||
_tls = threading.local()
|
||||
|
||||
|
||||
def execute_copilot_task(
|
||||
entry: CoPilotExecutionEntry,
|
||||
cancel: threading.Event,
|
||||
cluster_lock: ClusterLock,
|
||||
):
|
||||
"""Execute a CoPilot task using the thread-local processor.
|
||||
|
||||
This function is the entry point called by the thread pool executor.
|
||||
|
||||
Args:
|
||||
entry: The task payload
|
||||
cancel: Threading event to signal cancellation
|
||||
cluster_lock: Distributed lock for this execution
|
||||
"""
|
||||
processor: CoPilotProcessor = _tls.processor
|
||||
return processor.execute(entry, cancel, cluster_lock)
|
||||
|
||||
|
||||
def init_worker():
|
||||
"""Initialize the processor for the current worker thread.
|
||||
|
||||
This function is called by the thread pool executor when a new worker
|
||||
thread is created. It ensures each worker has its own processor instance.
|
||||
"""
|
||||
_tls.processor = CoPilotProcessor()
|
||||
_tls.processor.on_executor_start()
|
||||
|
||||
|
||||
# ============ Processor Class ============ #
|
||||
|
||||
|
||||
class CoPilotProcessor:
|
||||
"""Per-worker execution logic for CoPilot tasks.
|
||||
|
||||
This class is instantiated once per worker thread and handles the execution
|
||||
of CoPilot chat generation tasks. It maintains an async event loop for
|
||||
running the async service code.
|
||||
|
||||
The execution flow:
|
||||
1. CoPilot task is picked from RabbitMQ queue
|
||||
2. Manager submits task to thread pool
|
||||
3. Processor executes the task in its event loop
|
||||
4. Results are published to Redis Streams
|
||||
"""
|
||||
|
||||
@func_retry
|
||||
def on_executor_start(self):
|
||||
"""Initialize the processor when the worker thread starts.
|
||||
|
||||
This method is called once per worker thread to set up the async event
|
||||
loop and initialize any required resources.
|
||||
|
||||
Database is accessed only through DatabaseManager, so we don't need to connect
|
||||
to Prisma directly.
|
||||
"""
|
||||
configure_logging()
|
||||
set_service_name("CoPilotExecutor")
|
||||
self.tid = threading.get_ident()
|
||||
self.execution_loop = asyncio.new_event_loop()
|
||||
self.execution_thread = threading.Thread(
|
||||
target=self.execution_loop.run_forever, daemon=True
|
||||
)
|
||||
self.execution_thread.start()
|
||||
|
||||
logger.info(f"[CoPilotExecutor] Worker {self.tid} started")
|
||||
|
||||
@error_logged(swallow=False)
|
||||
def execute(
|
||||
self,
|
||||
entry: CoPilotExecutionEntry,
|
||||
cancel: threading.Event,
|
||||
cluster_lock: ClusterLock,
|
||||
):
|
||||
"""Execute a CoPilot task.
|
||||
|
||||
This is the main entry point for task execution. It runs the async
|
||||
execution logic in the worker's event loop and handles errors.
|
||||
|
||||
Args:
|
||||
entry: The task payload containing session and message info
|
||||
cancel: Threading event to signal cancellation
|
||||
cluster_lock: Distributed lock to prevent duplicate execution
|
||||
"""
|
||||
log = CoPilotLogMetadata(
|
||||
logging.getLogger(__name__),
|
||||
task_id=entry.task_id,
|
||||
session_id=entry.session_id,
|
||||
user_id=entry.user_id,
|
||||
)
|
||||
log.info("Starting execution")
|
||||
|
||||
start_time = time.monotonic()
|
||||
|
||||
try:
|
||||
# Run the async execution in our event loop
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._execute_async(entry, cancel, cluster_lock, log),
|
||||
self.execution_loop,
|
||||
)
|
||||
|
||||
# Wait for completion, checking cancel periodically
|
||||
while not future.done():
|
||||
try:
|
||||
future.result(timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
if cancel.is_set():
|
||||
log.info("Cancellation requested")
|
||||
future.cancel()
|
||||
break
|
||||
# Refresh cluster lock to maintain ownership
|
||||
cluster_lock.refresh()
|
||||
|
||||
if not future.cancelled():
|
||||
# Get result to propagate any exceptions
|
||||
future.result()
|
||||
|
||||
elapsed = time.monotonic() - start_time
|
||||
log.info(f"Execution completed in {elapsed:.2f}s")
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.monotonic() - start_time
|
||||
log.error(f"Execution failed after {elapsed:.2f}s: {e}")
|
||||
# Note: _execute_async already marks the task as failed before re-raising,
|
||||
# so we don't call _mark_task_failed here to avoid duplicate error events.
|
||||
raise
|
||||
|
||||
async def _execute_async(
|
||||
self,
|
||||
entry: CoPilotExecutionEntry,
|
||||
cancel: threading.Event,
|
||||
cluster_lock: ClusterLock,
|
||||
log: CoPilotLogMetadata,
|
||||
):
|
||||
"""Async execution logic for CoPilot task.
|
||||
|
||||
This method calls the existing stream_chat_completion service function
|
||||
and publishes results to the stream registry.
|
||||
|
||||
Args:
|
||||
entry: The task payload
|
||||
cancel: Threading event to signal cancellation
|
||||
cluster_lock: Distributed lock for refresh
|
||||
log: Structured logger for this task
|
||||
"""
|
||||
last_refresh = time.monotonic()
|
||||
refresh_interval = 30.0 # Refresh lock every 30 seconds
|
||||
|
||||
try:
|
||||
# Stream chat completion and publish chunks to Redis
|
||||
async for chunk in copilot_service.stream_chat_completion(
|
||||
session_id=entry.session_id,
|
||||
message=entry.message if entry.message else None,
|
||||
is_user_message=entry.is_user_message,
|
||||
user_id=entry.user_id,
|
||||
context=entry.context,
|
||||
_task_id=entry.task_id,
|
||||
):
|
||||
# Check for cancellation
|
||||
if cancel.is_set():
|
||||
log.info("Cancelled during streaming")
|
||||
await stream_registry.publish_chunk(
|
||||
entry.task_id, StreamError(errorText="Operation cancelled")
|
||||
)
|
||||
await stream_registry.publish_chunk(
|
||||
entry.task_id, StreamFinishStep()
|
||||
)
|
||||
await stream_registry.publish_chunk(entry.task_id, StreamFinish())
|
||||
await stream_registry.mark_task_completed(
|
||||
entry.task_id, status="failed"
|
||||
)
|
||||
return
|
||||
|
||||
# Refresh cluster lock periodically
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_refresh >= refresh_interval:
|
||||
cluster_lock.refresh()
|
||||
last_refresh = current_time
|
||||
|
||||
# Publish chunk to stream registry
|
||||
await stream_registry.publish_chunk(entry.task_id, chunk)
|
||||
|
||||
# Mark task as completed
|
||||
await stream_registry.mark_task_completed(entry.task_id, status="completed")
|
||||
log.info("Task completed successfully")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
log.info("Task cancelled")
|
||||
await stream_registry.mark_task_completed(entry.task_id, status="failed")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Task failed: {e}")
|
||||
await self._mark_task_failed(entry.task_id, str(e))
|
||||
raise
|
||||
|
||||
async def _mark_task_failed(self, task_id: str, error_message: str):
|
||||
"""Mark a task as failed and publish error to stream registry."""
|
||||
try:
|
||||
await stream_registry.publish_chunk(
|
||||
task_id, StreamError(errorText=error_message)
|
||||
)
|
||||
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||
await stream_registry.mark_task_completed(task_id, status="failed")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to mark task {task_id} as failed: {e}")
|
||||
207
autogpt_platform/backend/backend/copilot/executor/utils.py
Normal file
207
autogpt_platform/backend/backend/copilot/executor/utils.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""RabbitMQ queue configuration for CoPilot executor.
|
||||
|
||||
Defines two exchanges and queues following the graph executor pattern:
|
||||
- 'copilot_execution' (DIRECT) for chat generation tasks
|
||||
- 'copilot_cancel' (FANOUT) for cancellation requests
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============ Logging Helper ============ #
|
||||
|
||||
|
||||
class CoPilotLogMetadata(TruncatedLogger):
|
||||
"""Structured logging helper for CoPilot executor.
|
||||
|
||||
In cloud environments (structured logging enabled), uses a simple prefix
|
||||
and passes metadata via json_fields. In local environments, uses a detailed
|
||||
prefix with all metadata key-value pairs for easier debugging.
|
||||
|
||||
Args:
|
||||
logger: The underlying logger instance
|
||||
max_length: Maximum log message length before truncation
|
||||
**kwargs: Metadata key-value pairs (e.g., task_id="abc", session_id="xyz")
|
||||
These are added to json_fields in cloud mode, or to the prefix in local mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logger: logging.Logger,
|
||||
max_length: int = 1000,
|
||||
**kwargs: str | None,
|
||||
):
|
||||
# Filter out None values
|
||||
metadata = {k: v for k, v in kwargs.items() if v is not None}
|
||||
metadata["component"] = "CoPilotExecutor"
|
||||
|
||||
if is_structured_logging_enabled():
|
||||
prefix = "[CoPilotExecutor]"
|
||||
else:
|
||||
# Build prefix from metadata key-value pairs
|
||||
meta_parts = "|".join(
|
||||
f"{k}:{v}" for k, v in metadata.items() if k != "component"
|
||||
)
|
||||
prefix = (
|
||||
f"[CoPilotExecutor|{meta_parts}]" if meta_parts else "[CoPilotExecutor]"
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
logger,
|
||||
max_length=max_length,
|
||||
prefix=prefix,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
# ============ Exchange and Queue Configuration ============ #
|
||||
|
||||
COPILOT_EXECUTION_EXCHANGE = Exchange(
|
||||
name="copilot_execution",
|
||||
type=ExchangeType.DIRECT,
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
COPILOT_EXECUTION_QUEUE_NAME = "copilot_execution_queue"
|
||||
COPILOT_EXECUTION_ROUTING_KEY = "copilot.run"
|
||||
|
||||
COPILOT_CANCEL_EXCHANGE = Exchange(
|
||||
name="copilot_cancel",
|
||||
type=ExchangeType.FANOUT,
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
COPILOT_CANCEL_QUEUE_NAME = "copilot_cancel_queue"
|
||||
|
||||
# CoPilot operations can include extended thinking and agent generation
|
||||
# which may take 30+ minutes to complete
|
||||
COPILOT_CONSUMER_TIMEOUT_SECONDS = 60 * 60 # 1 hour
|
||||
|
||||
# Graceful shutdown timeout - allow in-flight operations to complete
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def create_copilot_queue_config() -> RabbitMQConfig:
|
||||
"""Create RabbitMQ configuration for CoPilot executor.
|
||||
|
||||
Defines two exchanges and queues:
|
||||
- 'copilot_execution' (DIRECT) for chat generation tasks
|
||||
- 'copilot_cancel' (FANOUT) for cancellation requests
|
||||
|
||||
Returns:
|
||||
RabbitMQConfig with exchanges and queues defined
|
||||
"""
|
||||
run_queue = Queue(
|
||||
name=COPILOT_EXECUTION_QUEUE_NAME,
|
||||
exchange=COPILOT_EXECUTION_EXCHANGE,
|
||||
routing_key=COPILOT_EXECUTION_ROUTING_KEY,
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
arguments={
|
||||
# Extended consumer timeout for long-running LLM operations
|
||||
# Default 30-minute timeout is insufficient for extended thinking
|
||||
# and agent generation which can take 30+ minutes
|
||||
"x-consumer-timeout": COPILOT_CONSUMER_TIMEOUT_SECONDS
|
||||
* 1000,
|
||||
},
|
||||
)
|
||||
cancel_queue = Queue(
|
||||
name=COPILOT_CANCEL_QUEUE_NAME,
|
||||
exchange=COPILOT_CANCEL_EXCHANGE,
|
||||
routing_key="", # not used for FANOUT
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
return RabbitMQConfig(
|
||||
vhost="/",
|
||||
exchanges=[COPILOT_EXECUTION_EXCHANGE, COPILOT_CANCEL_EXCHANGE],
|
||||
queues=[run_queue, cancel_queue],
|
||||
)
|
||||
|
||||
|
||||
# ============ Message Models ============ #
|
||||
|
||||
|
||||
class CoPilotExecutionEntry(BaseModel):
|
||||
"""Task payload for CoPilot AI generation.
|
||||
|
||||
This model represents a chat generation task to be processed by the executor.
|
||||
"""
|
||||
|
||||
task_id: str
|
||||
"""Unique identifier for this task (used for stream registry)"""
|
||||
|
||||
session_id: str
|
||||
"""Chat session ID"""
|
||||
|
||||
user_id: str | None
|
||||
"""User ID (may be None for anonymous users)"""
|
||||
|
||||
operation_id: str
|
||||
"""Operation ID for webhook callbacks and completion tracking"""
|
||||
|
||||
message: str
|
||||
"""User's message to process"""
|
||||
|
||||
is_user_message: bool = True
|
||||
"""Whether the message is from the user (vs system/assistant)"""
|
||||
|
||||
context: dict[str, str] | None = None
|
||||
"""Optional context for the message (e.g., {url: str, content: str})"""
|
||||
|
||||
|
||||
class CancelCoPilotEvent(BaseModel):
|
||||
"""Event to cancel a CoPilot operation."""
|
||||
|
||||
task_id: str
|
||||
"""Task ID to cancel"""
|
||||
|
||||
|
||||
# ============ Queue Publishing Helpers ============ #
|
||||
|
||||
|
||||
async def enqueue_copilot_task(
|
||||
task_id: str,
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
operation_id: str,
|
||||
message: str,
|
||||
is_user_message: bool = True,
|
||||
context: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Enqueue a CoPilot task for processing by the executor service.
|
||||
|
||||
Args:
|
||||
task_id: Unique identifier for this task (used for stream registry)
|
||||
session_id: Chat session ID
|
||||
user_id: User ID (may be None for anonymous users)
|
||||
operation_id: Operation ID for webhook callbacks and completion tracking
|
||||
message: User's message to process
|
||||
is_user_message: Whether the message is from the user (vs system/assistant)
|
||||
context: Optional context for the message (e.g., {url: str, content: str})
|
||||
"""
|
||||
from backend.util.clients import get_async_copilot_queue
|
||||
|
||||
entry = CoPilotExecutionEntry(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
operation_id=operation_id,
|
||||
message=message,
|
||||
is_user_message=is_user_message,
|
||||
context=context,
|
||||
)
|
||||
|
||||
queue_client = await get_async_copilot_queue()
|
||||
await queue_client.publish_message(
|
||||
routing_key=COPILOT_EXECUTION_ROUTING_KEY,
|
||||
message=entry.model_dump_json(),
|
||||
exchange=COPILOT_EXECUTION_EXCHANGE,
|
||||
)
|
||||
@@ -23,26 +23,17 @@ from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.db_accessors import chat_db
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util import json
|
||||
from backend.util.exceptions import DatabaseError, RedisError
|
||||
|
||||
from . import db as chat_db
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
|
||||
"""Parse a JSON field that may be stored as string or already parsed."""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, str):
|
||||
return json.loads(value)
|
||||
return value
|
||||
|
||||
|
||||
# Redis cache key prefix for chat sessions
|
||||
CHAT_SESSION_CACHE_PREFIX = "chat:session:"
|
||||
|
||||
@@ -52,28 +43,7 @@ def _get_session_cache_key(session_id: str) -> str:
|
||||
return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"
|
||||
|
||||
|
||||
# Session-level locks to prevent race conditions during concurrent upserts.
|
||||
# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
|
||||
# preventing unbounded memory growth while maintaining lock semantics for active sessions.
|
||||
# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
|
||||
# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
|
||||
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
|
||||
_session_locks_mutex = asyncio.Lock()
|
||||
|
||||
|
||||
async def _get_session_lock(session_id: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific session to prevent concurrent upserts.
|
||||
|
||||
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
|
||||
when no coroutine holds a reference to them, preventing memory leaks from
|
||||
unbounded growth of session locks.
|
||||
"""
|
||||
async with _session_locks_mutex:
|
||||
lock = _session_locks.get(session_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_session_locks[session_id] = lock
|
||||
return lock
|
||||
# ===================== Chat data models ===================== #
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
@@ -322,38 +292,26 @@ class ChatSession(BaseModel):
|
||||
return self._merge_consecutive_assistant_messages(messages)
|
||||
|
||||
|
||||
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from Redis cache."""
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
raw_session: bytes | None = await async_redis.get(redis_key)
|
||||
|
||||
if raw_session is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
session = ChatSession.model_validate_json(raw_session)
|
||||
logger.info(
|
||||
f"Loading session {session_id} from cache: "
|
||||
f"message_count={len(session.messages)}, "
|
||||
f"roles={[m.role for m in session.messages]}"
|
||||
)
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
|
||||
raise RedisError(f"Corrupted session data for {session_id}") from e
|
||||
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
|
||||
"""Parse a JSON field that may be stored as string or already parsed."""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, str):
|
||||
return json.loads(value)
|
||||
return value
|
||||
|
||||
|
||||
async def _cache_session(session: ChatSession) -> None:
|
||||
"""Cache a chat session in Redis."""
|
||||
redis_key = _get_session_cache_key(session.session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
|
||||
# ================ Chat cache + DB operations ================ #
|
||||
|
||||
# NOTE: Database calls are automatically routed through DatabaseManager if Prisma is not
|
||||
# connected directly.
|
||||
|
||||
|
||||
async def cache_chat_session(session: ChatSession) -> None:
|
||||
"""Cache a chat session without persisting to the database."""
|
||||
await _cache_session(session)
|
||||
"""Cache a chat session in Redis (without persisting to the database)."""
|
||||
redis_key = _get_session_cache_key(session.session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
|
||||
|
||||
|
||||
async def invalidate_session_cache(session_id: str) -> None:
|
||||
@@ -371,80 +329,6 @@ async def invalidate_session_cache(session_id: str) -> None:
|
||||
logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")
|
||||
|
||||
|
||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from the database."""
|
||||
prisma_session = await chat_db.get_chat_session(session_id)
|
||||
if not prisma_session:
|
||||
return None
|
||||
|
||||
messages = prisma_session.Messages
|
||||
logger.info(
|
||||
f"Loading session {session_id} from DB: "
|
||||
f"has_messages={messages is not None}, "
|
||||
f"message_count={len(messages) if messages else 0}, "
|
||||
f"roles={[m.role for m in messages] if messages else []}"
|
||||
)
|
||||
|
||||
return ChatSession.from_db(prisma_session, messages)
|
||||
|
||||
|
||||
async def _save_session_to_db(
|
||||
session: ChatSession, existing_message_count: int
|
||||
) -> None:
|
||||
"""Save or update a chat session in the database."""
|
||||
# Check if session exists in DB
|
||||
existing = await chat_db.get_chat_session(session.session_id)
|
||||
|
||||
if not existing:
|
||||
# Create new session
|
||||
await chat_db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
existing_message_count = 0
|
||||
|
||||
# Calculate total tokens from usage
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
|
||||
# Update session metadata
|
||||
await chat_db.update_chat_session(
|
||||
session_id=session.session_id,
|
||||
credentials=session.credentials,
|
||||
successful_agent_runs=session.successful_agent_runs,
|
||||
successful_agent_schedules=session.successful_agent_schedules,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
)
|
||||
|
||||
# Add new messages (only those after existing count)
|
||||
new_messages = session.messages[existing_message_count:]
|
||||
if new_messages:
|
||||
messages_data = []
|
||||
for msg in new_messages:
|
||||
messages_data.append(
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"name": msg.name,
|
||||
"tool_call_id": msg.tool_call_id,
|
||||
"refusal": msg.refusal,
|
||||
"tool_calls": msg.tool_calls,
|
||||
"function_call": msg.function_call,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
||||
f"roles={[m['role'] for m in messages_data]}, "
|
||||
f"start_sequence={existing_message_count}"
|
||||
)
|
||||
await chat_db.add_chat_messages_batch(
|
||||
session_id=session.session_id,
|
||||
messages=messages_data,
|
||||
start_sequence=existing_message_count,
|
||||
)
|
||||
|
||||
|
||||
async def get_chat_session(
|
||||
session_id: str,
|
||||
user_id: str | None = None,
|
||||
@@ -492,7 +376,7 @@ async def get_chat_session(
|
||||
|
||||
# Cache the session from DB
|
||||
try:
|
||||
await _cache_session(session)
|
||||
await cache_chat_session(session)
|
||||
logger.info(f"Cached session {session_id} from database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache session {session_id}: {e}")
|
||||
@@ -500,6 +384,45 @@ async def get_chat_session(
|
||||
return session
|
||||
|
||||
|
||||
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from Redis cache."""
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
raw_session: bytes | None = await async_redis.get(redis_key)
|
||||
|
||||
if raw_session is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
session = ChatSession.model_validate_json(raw_session)
|
||||
logger.info(
|
||||
f"Loading session {session_id} from cache: "
|
||||
f"message_count={len(session.messages)}, "
|
||||
f"roles={[m.role for m in session.messages]}"
|
||||
)
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
|
||||
raise RedisError(f"Corrupted session data for {session_id}") from e
|
||||
|
||||
|
||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from the database."""
|
||||
prisma_session = await chat_db().get_chat_session(session_id)
|
||||
if not prisma_session:
|
||||
return None
|
||||
|
||||
messages = prisma_session.Messages
|
||||
logger.info(
|
||||
f"Loading session {session_id} from DB: "
|
||||
f"has_messages={messages is not None}, "
|
||||
f"message_count={len(messages) if messages else 0}, "
|
||||
f"roles={[m.role for m in messages] if messages else []}"
|
||||
)
|
||||
|
||||
return ChatSession.from_db(prisma_session, messages)
|
||||
|
||||
|
||||
async def upsert_chat_session(
|
||||
session: ChatSession,
|
||||
) -> ChatSession:
|
||||
@@ -520,7 +443,7 @@ async def upsert_chat_session(
|
||||
|
||||
async with lock:
|
||||
# Get existing message count from DB for incremental saves
|
||||
existing_message_count = await chat_db.get_chat_session_message_count(
|
||||
existing_message_count = await chat_db().get_chat_session_message_count(
|
||||
session.session_id
|
||||
)
|
||||
|
||||
@@ -537,7 +460,7 @@ async def upsert_chat_session(
|
||||
|
||||
# Save to cache (best-effort, even if DB failed)
|
||||
try:
|
||||
await _cache_session(session)
|
||||
await cache_chat_session(session)
|
||||
except Exception as e:
|
||||
# If DB succeeded but cache failed, raise cache error
|
||||
if db_error is None:
|
||||
@@ -558,6 +481,65 @@ async def upsert_chat_session(
|
||||
return session
|
||||
|
||||
|
||||
async def _save_session_to_db(
|
||||
session: ChatSession, existing_message_count: int
|
||||
) -> None:
|
||||
"""Save or update a chat session in the database."""
|
||||
db = chat_db()
|
||||
|
||||
# Check if session exists in DB
|
||||
existing = await db.get_chat_session(session.session_id)
|
||||
|
||||
if not existing:
|
||||
# Create new session
|
||||
await db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
existing_message_count = 0
|
||||
|
||||
# Calculate total tokens from usage
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
|
||||
# Update session metadata
|
||||
await db.update_chat_session(
|
||||
session_id=session.session_id,
|
||||
credentials=session.credentials,
|
||||
successful_agent_runs=session.successful_agent_runs,
|
||||
successful_agent_schedules=session.successful_agent_schedules,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
)
|
||||
|
||||
# Add new messages (only those after existing count)
|
||||
new_messages = session.messages[existing_message_count:]
|
||||
if new_messages:
|
||||
messages_data = []
|
||||
for msg in new_messages:
|
||||
messages_data.append(
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"name": msg.name,
|
||||
"tool_call_id": msg.tool_call_id,
|
||||
"refusal": msg.refusal,
|
||||
"tool_calls": msg.tool_calls,
|
||||
"function_call": msg.function_call,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
||||
f"roles={[m['role'] for m in messages_data]}, "
|
||||
f"start_sequence={existing_message_count}"
|
||||
)
|
||||
await db.add_chat_messages_batch(
|
||||
session_id=session.session_id,
|
||||
messages=messages_data,
|
||||
start_sequence=existing_message_count,
|
||||
)
|
||||
|
||||
|
||||
async def create_chat_session(user_id: str) -> ChatSession:
|
||||
"""Create a new chat session and persist it.
|
||||
|
||||
@@ -570,7 +552,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
|
||||
|
||||
# Create in database first - fail fast if this fails
|
||||
try:
|
||||
await chat_db.create_chat_session(
|
||||
await chat_db().create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
@@ -582,7 +564,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
|
||||
|
||||
# Cache the session (best-effort optimization, DB is source of truth)
|
||||
try:
|
||||
await _cache_session(session)
|
||||
await cache_chat_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache new session {session.session_id}: {e}")
|
||||
|
||||
@@ -600,8 +582,9 @@ async def get_user_sessions(
|
||||
A tuple of (sessions, total_count) where total_count is the overall
|
||||
number of sessions for the user (not just the current page).
|
||||
"""
|
||||
prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
|
||||
total_count = await chat_db.get_user_session_count(user_id)
|
||||
db = chat_db()
|
||||
prisma_sessions = await db.get_user_chat_sessions(user_id, limit, offset)
|
||||
total_count = await db.get_user_session_count(user_id)
|
||||
|
||||
sessions = []
|
||||
for prisma_session in prisma_sessions:
|
||||
@@ -624,7 +607,7 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
||||
"""
|
||||
# Delete from database first (with optional user_id validation)
|
||||
# This confirms ownership before invalidating cache
|
||||
deleted = await chat_db.delete_chat_session(session_id, user_id)
|
||||
deleted = await chat_db().delete_chat_session(session_id, user_id)
|
||||
|
||||
if not deleted:
|
||||
return False
|
||||
@@ -659,7 +642,7 @@ async def update_session_title(session_id: str, title: str) -> bool:
|
||||
True if updated successfully, False otherwise.
|
||||
"""
|
||||
try:
|
||||
result = await chat_db.update_chat_session(session_id=session_id, title=title)
|
||||
result = await chat_db().update_chat_session(session_id=session_id, title=title)
|
||||
if result is None:
|
||||
logger.warning(f"Session {session_id} not found for title update")
|
||||
return False
|
||||
@@ -676,3 +659,29 @@ async def update_session_title(session_id: str, title: str) -> bool:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update title for session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# ==================== Chat session locks ==================== #
|
||||
|
||||
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
|
||||
_session_locks_mutex = asyncio.Lock()
|
||||
|
||||
|
||||
async def _get_session_lock(session_id: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific session to prevent concurrent upserts.
|
||||
|
||||
This was originally added to solve the specific problem of race conditions between
|
||||
the session title thread and the conversation thread, which always occurs on the
|
||||
same instance as we prevent rapid request sends on the frontend.
|
||||
|
||||
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
|
||||
when no coroutine holds a reference to them, preventing memory leaks from
|
||||
unbounded growth of session locks. Explicit cleanup also occurs
|
||||
in `delete_chat_session()`.
|
||||
"""
|
||||
async with _session_locks_mutex:
|
||||
lock = _session_locks.get(session_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_session_locks[session_id] = lock
|
||||
return lock
|
||||
@@ -27,6 +27,7 @@ from openai.types.chat import (
|
||||
ChatCompletionToolParam,
|
||||
)
|
||||
|
||||
from backend.data.db_accessors import chat_db
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import (
|
||||
format_understanding_for_prompt,
|
||||
@@ -35,7 +36,6 @@ from backend.data.understanding import (
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.settings import AppEnvironment, Settings
|
||||
|
||||
from . import db as chat_db
|
||||
from . import stream_registry
|
||||
from .config import ChatConfig
|
||||
from .model import (
|
||||
@@ -1744,7 +1744,7 @@ async def _update_pending_operation(
|
||||
This is called by background tasks when long-running operations complete.
|
||||
"""
|
||||
# Update the message in database
|
||||
updated = await chat_db.update_tool_message_content(
|
||||
updated = await chat_db().update_tool_message_content(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
new_content=result,
|
||||
@@ -3,8 +3,8 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tracking import track_tool_called
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tracking import track_tool_called
|
||||
|
||||
from .add_understanding import AddUnderstandingTool
|
||||
from .agent_output import AgentOutputTool
|
||||
@@ -27,7 +27,7 @@ from .workspace_files import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -6,11 +6,11 @@ import pytest
|
||||
from prisma.types import ProfileCreateInput
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db import prisma
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
from backend.data.model import APIKeyCredentials
|
||||
@@ -3,11 +3,9 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import understanding_db
|
||||
from backend.data.understanding import BusinessUnderstandingInput
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
|
||||
@@ -99,7 +97,9 @@ and automations for the user's specific needs."""
|
||||
]
|
||||
|
||||
# Upsert with merge
|
||||
understanding = await upsert_business_understanding(user_id, input_data)
|
||||
understanding = await understanding_db().upsert_business_understanding(
|
||||
user_id, input_data
|
||||
)
|
||||
|
||||
# Build current understanding summary (filter out empty values)
|
||||
current_understanding = {
|
||||
@@ -5,9 +5,8 @@ import re
|
||||
import uuid
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
|
||||
from backend.data.db_accessors import graph_db, library_db, store_db
|
||||
from backend.data.graph import Graph, Link, Node
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .service import (
|
||||
@@ -145,8 +144,9 @@ async def get_library_agent_by_id(
|
||||
Returns:
|
||||
LibraryAgentSummary if found, None otherwise
|
||||
"""
|
||||
db = library_db()
|
||||
try:
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
agent = await db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||
return LibraryAgentSummary(
|
||||
@@ -163,7 +163,7 @@ async def get_library_agent_by_id(
|
||||
logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")
|
||||
|
||||
try:
|
||||
agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
agent = await db.get_library_agent(agent_id, user_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||
return LibraryAgentSummary(
|
||||
@@ -215,7 +215,7 @@ async def get_library_agents_for_generation(
|
||||
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
|
||||
"""
|
||||
try:
|
||||
response = await library_db.list_library_agents(
|
||||
response = await library_db().list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=search_query,
|
||||
page=1,
|
||||
@@ -272,7 +272,7 @@ async def search_marketplace_agents_for_generation(
|
||||
List of LibraryAgentSummary with full input/output schemas
|
||||
"""
|
||||
try:
|
||||
response = await store_db.get_store_agents(
|
||||
response = await store_db().get_store_agents(
|
||||
search_query=search_query,
|
||||
page=1,
|
||||
page_size=max_results,
|
||||
@@ -286,7 +286,7 @@ async def search_marketplace_agents_for_generation(
|
||||
return []
|
||||
|
||||
graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
|
||||
graphs = await get_store_listed_graphs(*graph_ids)
|
||||
graphs = await graph_db().get_store_listed_graphs(*graph_ids)
|
||||
|
||||
results: list[LibraryAgentSummary] = []
|
||||
for agent in agents_with_graphs:
|
||||
@@ -673,9 +673,10 @@ async def save_agent_to_library(
|
||||
Tuple of (created Graph, LibraryAgent)
|
||||
"""
|
||||
graph = json_to_graph(agent_json)
|
||||
db = library_db()
|
||||
if is_update:
|
||||
return await library_db.update_graph_in_library(graph, user_id)
|
||||
return await library_db.create_graph_in_library(graph, user_id)
|
||||
return await db.update_graph_in_library(graph, user_id)
|
||||
return await db.create_graph_in_library(graph, user_id)
|
||||
|
||||
|
||||
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
||||
@@ -735,12 +736,14 @@ async def get_agent_as_json(
|
||||
Returns:
|
||||
Agent as JSON dict or None if not found
|
||||
"""
|
||||
graph = await get_graph(agent_id, version=None, user_id=user_id)
|
||||
db = graph_db()
|
||||
|
||||
graph = await db.get_graph(agent_id, version=None, user_id=user_id)
|
||||
|
||||
if not graph and user_id:
|
||||
try:
|
||||
library_agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
graph = await get_graph(
|
||||
library_agent = await library_db().get_library_agent(agent_id, user_id)
|
||||
graph = await db.get_graph(
|
||||
library_agent.graph_id, version=None, user_id=user_id
|
||||
)
|
||||
except NotFoundError:
|
||||
@@ -0,0 +1,154 @@
|
||||
"""Dummy Agent Generator for testing.
|
||||
|
||||
Returns mock responses matching the format expected from the external service.
|
||||
Enable via AGENTGENERATOR_USE_DUMMY=true in settings.
|
||||
|
||||
WARNING: This is for testing only. Do not use in production.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Dummy decomposition result (instructions type)
|
||||
DUMMY_DECOMPOSITION_RESULT: dict[str, Any] = {
|
||||
"type": "instructions",
|
||||
"steps": [
|
||||
{
|
||||
"description": "Get input from user",
|
||||
"action": "input",
|
||||
"block_name": "AgentInputBlock",
|
||||
},
|
||||
{
|
||||
"description": "Process the input",
|
||||
"action": "process",
|
||||
"block_name": "TextFormatterBlock",
|
||||
},
|
||||
{
|
||||
"description": "Return output to user",
|
||||
"action": "output",
|
||||
"block_name": "AgentOutputBlock",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# Block IDs from backend/blocks/io.py
|
||||
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
|
||||
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
|
||||
|
||||
|
||||
def _generate_dummy_agent_json() -> dict[str, Any]:
|
||||
"""Generate a minimal valid agent JSON for testing."""
|
||||
input_node_id = str(uuid.uuid4())
|
||||
output_node_id = str(uuid.uuid4())
|
||||
|
||||
return {
|
||||
"id": str(uuid.uuid4()),
|
||||
"version": 1,
|
||||
"is_active": True,
|
||||
"name": "Dummy Test Agent",
|
||||
"description": "A dummy agent generated for testing purposes",
|
||||
"nodes": [
|
||||
{
|
||||
"id": input_node_id,
|
||||
"block_id": AGENT_INPUT_BLOCK_ID,
|
||||
"input_default": {
|
||||
"name": "input",
|
||||
"title": "Input",
|
||||
"description": "Enter your input",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
"metadata": {"position": {"x": 0, "y": 0}},
|
||||
},
|
||||
{
|
||||
"id": output_node_id,
|
||||
"block_id": AGENT_OUTPUT_BLOCK_ID,
|
||||
"input_default": {
|
||||
"name": "output",
|
||||
"title": "Output",
|
||||
"description": "Agent output",
|
||||
"format": "{output}",
|
||||
},
|
||||
"metadata": {"position": {"x": 400, "y": 0}},
|
||||
},
|
||||
],
|
||||
"links": [
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": input_node_id,
|
||||
"sink_id": output_node_id,
|
||||
"source_name": "result",
|
||||
"sink_name": "value",
|
||||
"is_static": False,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
async def decompose_goal_dummy(
|
||||
description: str,
|
||||
context: str = "",
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Return dummy decomposition result."""
|
||||
logger.info("Using dummy agent generator for decompose_goal")
|
||||
return DUMMY_DECOMPOSITION_RESULT.copy()
|
||||
|
||||
|
||||
async def generate_agent_dummy(
|
||||
instructions: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Return dummy agent JSON after a simulated delay."""
|
||||
logger.info("Using dummy agent generator for generate_agent (30s delay)")
|
||||
await asyncio.sleep(30)
|
||||
return _generate_dummy_agent_json()
|
||||
|
||||
|
||||
async def generate_agent_patch_dummy(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Return dummy patched agent (returns the current agent with updated description)."""
|
||||
logger.info("Using dummy agent generator for generate_agent_patch")
|
||||
patched = current_agent.copy()
|
||||
patched["description"] = (
|
||||
f"{current_agent.get('description', '')} (updated: {update_request})"
|
||||
)
|
||||
return patched
|
||||
|
||||
|
||||
async def customize_template_dummy(
|
||||
template_agent: dict[str, Any],
|
||||
modification_request: str,
|
||||
context: str = "",
|
||||
) -> dict[str, Any]:
|
||||
"""Return dummy customized template (returns template with updated description)."""
|
||||
logger.info("Using dummy agent generator for customize_template")
|
||||
customized = template_agent.copy()
|
||||
customized["description"] = (
|
||||
f"{template_agent.get('description', '')} (customized: {modification_request})"
|
||||
)
|
||||
return customized
|
||||
|
||||
|
||||
async def get_blocks_dummy() -> list[dict[str, Any]]:
|
||||
"""Return dummy blocks list."""
|
||||
logger.info("Using dummy agent generator for get_blocks")
|
||||
return [
|
||||
{"id": AGENT_INPUT_BLOCK_ID, "name": "AgentInputBlock"},
|
||||
{"id": AGENT_OUTPUT_BLOCK_ID, "name": "AgentOutputBlock"},
|
||||
]
|
||||
|
||||
|
||||
async def health_check_dummy() -> bool:
|
||||
"""Always returns healthy for dummy service."""
|
||||
return True
|
||||
@@ -12,8 +12,19 @@ import httpx
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .dummy import (
|
||||
customize_template_dummy,
|
||||
decompose_goal_dummy,
|
||||
generate_agent_dummy,
|
||||
generate_agent_patch_dummy,
|
||||
get_blocks_dummy,
|
||||
health_check_dummy,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_dummy_mode_warned = False
|
||||
|
||||
|
||||
def _create_error_response(
|
||||
error_message: str,
|
||||
@@ -90,10 +101,26 @@ def _get_settings() -> Settings:
|
||||
return _settings
|
||||
|
||||
|
||||
def is_external_service_configured() -> bool:
|
||||
"""Check if external Agent Generator service is configured."""
|
||||
def _is_dummy_mode() -> bool:
|
||||
"""Check if dummy mode is enabled for testing."""
|
||||
global _dummy_mode_warned
|
||||
settings = _get_settings()
|
||||
return bool(settings.config.agentgenerator_host)
|
||||
is_dummy = bool(settings.config.agentgenerator_use_dummy)
|
||||
if is_dummy and not _dummy_mode_warned:
|
||||
logger.warning(
|
||||
"Agent Generator running in DUMMY MODE - returning mock responses. "
|
||||
"Do not use in production!"
|
||||
)
|
||||
_dummy_mode_warned = True
|
||||
return is_dummy
|
||||
|
||||
|
||||
def is_external_service_configured() -> bool:
|
||||
"""Check if external Agent Generator service is configured (or dummy mode)."""
|
||||
settings = _get_settings()
|
||||
return bool(settings.config.agentgenerator_host) or bool(
|
||||
settings.config.agentgenerator_use_dummy
|
||||
)
|
||||
|
||||
|
||||
def _get_base_url() -> str:
|
||||
@@ -137,6 +164,9 @@ async def decompose_goal_external(
|
||||
- {"type": "error", "error": "...", "error_type": "..."} on error
|
||||
Or None on unexpected error
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await decompose_goal_dummy(description, context, library_agents)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
if context:
|
||||
@@ -226,6 +256,11 @@ async def generate_agent_external(
|
||||
Returns:
|
||||
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await generate_agent_dummy(
|
||||
instructions, library_agents, operation_id, task_id
|
||||
)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
@@ -297,6 +332,11 @@ async def generate_agent_patch_external(
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await generate_agent_patch_dummy(
|
||||
update_request, current_agent, library_agents, operation_id, task_id
|
||||
)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
@@ -383,6 +423,11 @@ async def customize_template_external(
|
||||
Returns:
|
||||
Customized agent JSON, clarifying questions dict, or error dict on error
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await customize_template_dummy(
|
||||
template_agent, modification_request, context
|
||||
)
|
||||
|
||||
client = _get_client()
|
||||
|
||||
request = modification_request
|
||||
@@ -445,6 +490,9 @@ async def get_blocks_external() -> list[dict[str, Any]] | None:
|
||||
Returns:
|
||||
List of block info dicts or None on error
|
||||
"""
|
||||
if _is_dummy_mode():
|
||||
return await get_blocks_dummy()
|
||||
|
||||
client = _get_client()
|
||||
|
||||
try:
|
||||
@@ -478,6 +526,9 @@ async def health_check() -> bool:
|
||||
if not is_external_service_configured():
|
||||
return False
|
||||
|
||||
if _is_dummy_mode():
|
||||
return await health_check_dummy()
|
||||
|
||||
client = _get_client()
|
||||
|
||||
try:
|
||||
@@ -7,10 +7,9 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.library.model import LibraryAgent
|
||||
from backend.data import execution as execution_db
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import execution_db, library_db
|
||||
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
|
||||
|
||||
from .base import BaseTool
|
||||
@@ -165,10 +164,12 @@ class AgentOutputTool(BaseTool):
|
||||
Resolve agent from provided identifiers.
|
||||
Returns (library_agent, error_message).
|
||||
"""
|
||||
lib_db = library_db()
|
||||
|
||||
# Priority 1: Exact library agent ID
|
||||
if library_agent_id:
|
||||
try:
|
||||
agent = await library_db.get_library_agent(library_agent_id, user_id)
|
||||
agent = await lib_db.get_library_agent(library_agent_id, user_id)
|
||||
return agent, None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get library agent by ID: {e}")
|
||||
@@ -182,7 +183,7 @@ class AgentOutputTool(BaseTool):
|
||||
return None, f"Agent '{store_slug}' not found in marketplace"
|
||||
|
||||
# Find in user's library by graph_id
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
|
||||
agent = await lib_db.get_library_agent_by_graph_id(user_id, graph.id)
|
||||
if not agent:
|
||||
return (
|
||||
None,
|
||||
@@ -194,7 +195,7 @@ class AgentOutputTool(BaseTool):
|
||||
# Priority 3: Fuzzy name search in library
|
||||
if agent_name:
|
||||
try:
|
||||
response = await library_db.list_library_agents(
|
||||
response = await lib_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=agent_name,
|
||||
page_size=5,
|
||||
@@ -228,9 +229,11 @@ class AgentOutputTool(BaseTool):
|
||||
Fetch execution(s) based on filters.
|
||||
Returns (single_execution, available_executions_meta, error_message).
|
||||
"""
|
||||
exec_db = execution_db()
|
||||
|
||||
# If specific execution_id provided, fetch it directly
|
||||
if execution_id:
|
||||
execution = await execution_db.get_graph_execution(
|
||||
execution = await exec_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
@@ -240,7 +243,7 @@ class AgentOutputTool(BaseTool):
|
||||
return execution, [], None
|
||||
|
||||
# Get completed executions with time filters
|
||||
executions = await execution_db.get_graph_executions(
|
||||
executions = await exec_db.get_graph_executions(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
statuses=[ExecutionStatus.COMPLETED],
|
||||
@@ -254,7 +257,7 @@ class AgentOutputTool(BaseTool):
|
||||
|
||||
# If only one execution, fetch full details
|
||||
if len(executions) == 1:
|
||||
full_execution = await execution_db.get_graph_execution(
|
||||
full_execution = await exec_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
@@ -262,7 +265,7 @@ class AgentOutputTool(BaseTool):
|
||||
return full_execution, [], None
|
||||
|
||||
# Multiple executions - return latest with full details, plus list of available
|
||||
full_execution = await execution_db.get_graph_execution(
|
||||
full_execution = await exec_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
@@ -380,7 +383,7 @@ class AgentOutputTool(BaseTool):
|
||||
and not input_data.store_slug
|
||||
):
|
||||
# Fetch execution directly to get graph_id
|
||||
execution = await execution_db.get_graph_execution(
|
||||
execution = await execution_db().get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=input_data.execution_id,
|
||||
include_node_executions=False,
|
||||
@@ -392,7 +395,7 @@ class AgentOutputTool(BaseTool):
|
||||
)
|
||||
|
||||
# Find library agent by graph_id
|
||||
agent = await library_db.get_library_agent_by_graph_id(
|
||||
agent = await library_db().get_library_agent_by_graph_id(
|
||||
user_id, execution.graph_id
|
||||
)
|
||||
if not agent:
|
||||
@@ -4,8 +4,7 @@ import logging
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data.db_accessors import library_db, store_db
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .models import (
|
||||
@@ -45,8 +44,10 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
|
||||
Returns:
|
||||
AgentInfo if found, None otherwise
|
||||
"""
|
||||
lib_db = library_db()
|
||||
|
||||
try:
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||
return AgentInfo(
|
||||
@@ -71,7 +72,7 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
|
||||
)
|
||||
|
||||
try:
|
||||
agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
agent = await lib_db.get_library_agent(agent_id, user_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||
return AgentInfo(
|
||||
@@ -133,7 +134,7 @@ async def search_agents(
|
||||
try:
|
||||
if source == "marketplace":
|
||||
logger.info(f"Searching marketplace for: {query}")
|
||||
results = await store_db.get_store_agents(search_query=query, page_size=5)
|
||||
results = await store_db().get_store_agents(search_query=query, page_size=5)
|
||||
for agent in results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
@@ -159,7 +160,7 @@ async def search_agents(
|
||||
|
||||
if not agents:
|
||||
logger.info(f"Searching user library for: {query}")
|
||||
results = await library_db.list_library_agents(
|
||||
results = await library_db().list_library_agents(
|
||||
user_id=user_id, # type: ignore[arg-type]
|
||||
search_term=query,
|
||||
page_size=10,
|
||||
@@ -5,8 +5,8 @@ from typing import Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
|
||||
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
@@ -3,9 +3,9 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.api.features.store.exceptions import AgentNotFoundError
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import store_db as get_store_db
|
||||
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
@@ -137,6 +137,8 @@ class CustomizeAgentTool(BaseTool):
|
||||
|
||||
creator_username, agent_slug = parts
|
||||
|
||||
store_db = get_store_db()
|
||||
|
||||
# Fetch the marketplace agent details
|
||||
try:
|
||||
agent_details = await store_db.get_store_agent_details(
|
||||
@@ -3,7 +3,7 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .agent_search import search_agents
|
||||
from .base import BaseTool
|
||||
@@ -3,18 +3,18 @@ from typing import Any
|
||||
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
|
||||
from backend.api.features.chat.tools.models import (
|
||||
from backend.blocks import get_block
|
||||
from backend.blocks._base import BlockType
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.base import BaseTool, ToolResponseBase
|
||||
from backend.copilot.tools.models import (
|
||||
BlockInfoSummary,
|
||||
BlockInputFieldInfo,
|
||||
BlockListResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
)
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
from backend.blocks import get_block
|
||||
from backend.blocks._base import BlockType
|
||||
from backend.data.db_accessors import search
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -107,7 +107,7 @@ class FindBlockTool(BaseTool):
|
||||
|
||||
try:
|
||||
# Search for blocks using hybrid search
|
||||
results, total = await unified_hybrid_search(
|
||||
results, total = await search().unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
@@ -4,13 +4,13 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.chat.tools.find_block import (
|
||||
from backend.blocks._base import BlockType
|
||||
from backend.copilot.tools.find_block import (
|
||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||
FindBlockTool,
|
||||
)
|
||||
from backend.api.features.chat.tools.models import BlockListResponse
|
||||
from backend.blocks._base import BlockType
|
||||
from backend.copilot.tools.models import BlockListResponse
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
@@ -75,13 +75,17 @@ class TestFindBlockFiltering:
|
||||
"standard-block-id": standard_block,
|
||||
}.get(block_id)
|
||||
|
||||
mock_search_db = MagicMock()
|
||||
mock_search_db.unified_hybrid_search = AsyncMock(
|
||||
return_value=(search_results, 2)
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(search_results, 2),
|
||||
"backend.copilot.tools.find_block.search",
|
||||
return_value=mock_search_db,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.find_block.get_block",
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
side_effect=mock_get_block,
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
@@ -119,13 +123,17 @@ class TestFindBlockFiltering:
|
||||
"normal-block-id": normal_block,
|
||||
}.get(block_id)
|
||||
|
||||
mock_search_db = MagicMock()
|
||||
mock_search_db.unified_hybrid_search = AsyncMock(
|
||||
return_value=(search_results, 2)
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(search_results, 2),
|
||||
"backend.copilot.tools.find_block.search",
|
||||
return_value=mock_search_db,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.find_block.get_block",
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
side_effect=mock_get_block,
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .agent_search import search_agents
|
||||
from .base import BaseTool
|
||||
@@ -4,9 +4,9 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
from backend.api.features.chat.tools.models import (
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.base import BaseTool
|
||||
from backend.copilot.tools.models import (
|
||||
DocPageResponse,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
@@ -5,16 +5,12 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.api.features.chat.config import ChatConfig
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tracking import (
|
||||
track_agent_run_success,
|
||||
track_agent_scheduled,
|
||||
)
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tracking import track_agent_run_success, track_agent_scheduled
|
||||
from backend.data.db_accessors import graph_db, library_db, user_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
@@ -200,7 +196,7 @@ class RunAgentTool(BaseTool):
|
||||
|
||||
# Priority: library_agent_id if provided
|
||||
if has_library_id:
|
||||
library_agent = await library_db.get_library_agent(
|
||||
library_agent = await library_db().get_library_agent(
|
||||
params.library_agent_id, user_id
|
||||
)
|
||||
if not library_agent:
|
||||
@@ -209,9 +205,7 @@ class RunAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
# Get the graph from the library agent
|
||||
from backend.data.graph import get_graph
|
||||
|
||||
graph = await get_graph(
|
||||
graph = await graph_db().get_graph(
|
||||
library_agent.graph_id,
|
||||
library_agent.graph_version,
|
||||
user_id=user_id,
|
||||
@@ -522,7 +516,7 @@ class RunAgentTool(BaseTool):
|
||||
library_agent = await get_or_create_library_agent(graph, user_id)
|
||||
|
||||
# Get user timezone
|
||||
user = await get_user_by_id(user_id)
|
||||
user = await user_db().get_user_by_id(user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else timezone)
|
||||
|
||||
# Create schedule
|
||||
@@ -7,16 +7,16 @@ from typing import Any
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.find_block import (
|
||||
from backend.blocks import get_block
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.find_block import (
|
||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||
)
|
||||
from backend.blocks import get_block
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
@@ -190,7 +190,7 @@ class RunBlockTool(BaseTool):
|
||||
|
||||
try:
|
||||
# Get or create user's workspace for CoPilot file operations
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
|
||||
# Generate synthetic IDs for CoPilot context
|
||||
# Each chat session is treated as its own agent with one continuous run
|
||||
@@ -4,9 +4,9 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.chat.tools.models import ErrorResponse
|
||||
from backend.api.features.chat.tools.run_block import RunBlockTool
|
||||
from backend.blocks._base import BlockType
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
from backend.copilot.tools.run_block import RunBlockTool
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
@@ -39,7 +39,7 @@ class TestRunBlockFiltering:
|
||||
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.run_block.get_block",
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=input_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
@@ -65,7 +65,7 @@ class TestRunBlockFiltering:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.run_block.get_block",
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=smart_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
@@ -89,7 +89,7 @@ class TestRunBlockFiltering:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.run_block.get_block",
|
||||
"backend.copilot.tools.run_block.get_block",
|
||||
return_value=standard_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
@@ -5,16 +5,16 @@ from typing import Any
|
||||
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
from backend.api.features.chat.tools.models import (
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.base import BaseTool
|
||||
from backend.copilot.tools.models import (
|
||||
DocSearchResult,
|
||||
DocSearchResultsResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
from backend.data.db_accessors import search
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -117,7 +117,7 @@ class SearchDocsTool(BaseTool):
|
||||
|
||||
try:
|
||||
# Search using hybrid search for DOCUMENTATION content type only
|
||||
results, total = await unified_hybrid_search(
|
||||
results, total = await search().unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=[ContentType.DOCUMENTATION],
|
||||
page=1,
|
||||
@@ -3,9 +3,8 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.library import model as library_model
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data.db_accessors import library_db, store_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import (
|
||||
Credentials,
|
||||
@@ -15,7 +14,6 @@ from backend.data.model import (
|
||||
OAuth2Credentials,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -39,13 +37,14 @@ async def fetch_graph_from_store_slug(
|
||||
Raises:
|
||||
DatabaseError: If there's a database error during lookup.
|
||||
"""
|
||||
sdb = store_db()
|
||||
try:
|
||||
store_agent = await store_db.get_store_agent_details(username, agent_name)
|
||||
store_agent = await sdb.get_store_agent_details(username, agent_name)
|
||||
except NotFoundError:
|
||||
return None, None
|
||||
|
||||
# Get the graph from store listing version
|
||||
graph = await store_db.get_available_graph(
|
||||
graph = await sdb.get_available_graph(
|
||||
store_agent.store_listing_version_id, hide_nodes=False
|
||||
)
|
||||
return graph, store_agent
|
||||
@@ -210,13 +209,13 @@ async def get_or_create_library_agent(
|
||||
Returns:
|
||||
LibraryAgent instance
|
||||
"""
|
||||
existing = await library_db.get_library_agent_by_graph_id(
|
||||
existing = await library_db().get_library_agent_by_graph_id(
|
||||
graph_id=graph.id, user_id=user_id
|
||||
)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
library_agents = await library_db.create_library_agent(
|
||||
library_agents = await library_db().create_library_agent(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
create_library_agents_for_sub_graphs=False,
|
||||
@@ -360,7 +359,7 @@ async def match_user_credentials_to_graph(
|
||||
_,
|
||||
_,
|
||||
) in aggregated_creds.items():
|
||||
# Find first matching credential by provider, type, scopes, and host/URL
|
||||
# Find first matching credential by provider, type, and scopes
|
||||
matching_cred = next(
|
||||
(
|
||||
cred
|
||||
@@ -375,10 +374,6 @@ async def match_user_credentials_to_graph(
|
||||
cred.type != "host_scoped"
|
||||
or _credential_is_for_host(cred, credential_requirements)
|
||||
)
|
||||
and (
|
||||
cred.provider != ProviderName.MCP
|
||||
or _credential_is_for_mcp_server(cred, credential_requirements)
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
@@ -449,22 +444,6 @@ def _credential_is_for_host(
|
||||
return credential.matches_url(list(requirements.discriminator_values)[0])
|
||||
|
||||
|
||||
def _credential_is_for_mcp_server(
|
||||
credential: Credentials,
|
||||
requirements: CredentialsFieldInfo,
|
||||
) -> bool:
|
||||
"""Check if an MCP OAuth credential matches the required server URL."""
|
||||
if not requirements.discriminator_values:
|
||||
return True
|
||||
|
||||
server_url = (
|
||||
credential.metadata.get("mcp_server_url")
|
||||
if isinstance(credential, OAuth2Credentials)
|
||||
else None
|
||||
)
|
||||
return server_url in requirements.discriminator_values if server_url else False
|
||||
|
||||
|
||||
async def check_user_has_required_credentials(
|
||||
user_id: str,
|
||||
required_credentials: list[CredentialsMetaInput],
|
||||
@@ -6,8 +6,8 @@ from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
@@ -146,7 +146,7 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||
|
||||
try:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
@@ -280,7 +280,7 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
@@ -478,7 +478,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
# Virus scan
|
||||
await scan_content_safe(content, filename=filename)
|
||||
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
@@ -577,7 +577,7 @@ class DeleteWorkspaceFileTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
118
autogpt_platform/backend/backend/data/db_accessors.py
Normal file
118
autogpt_platform/backend/backend/data/db_accessors.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from backend.data import db
|
||||
|
||||
|
||||
def chat_db():
|
||||
if db.is_connected():
|
||||
from backend.copilot import db as _chat_db
|
||||
|
||||
chat_db = _chat_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
chat_db = get_database_manager_async_client()
|
||||
|
||||
return chat_db
|
||||
|
||||
|
||||
def graph_db():
|
||||
if db.is_connected():
|
||||
from backend.data import graph as _graph_db
|
||||
|
||||
graph_db = _graph_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
graph_db = get_database_manager_async_client()
|
||||
|
||||
return graph_db
|
||||
|
||||
|
||||
def library_db():
|
||||
if db.is_connected():
|
||||
from backend.api.features.library import db as _library_db
|
||||
|
||||
library_db = _library_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
library_db = get_database_manager_async_client()
|
||||
|
||||
return library_db
|
||||
|
||||
|
||||
def store_db():
|
||||
if db.is_connected():
|
||||
from backend.api.features.store import db as _store_db
|
||||
|
||||
store_db = _store_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
store_db = get_database_manager_async_client()
|
||||
|
||||
return store_db
|
||||
|
||||
|
||||
def search():
|
||||
if db.is_connected():
|
||||
from backend.api.features.store import hybrid_search as _search
|
||||
|
||||
search = _search
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
search = get_database_manager_async_client()
|
||||
|
||||
return search
|
||||
|
||||
|
||||
def execution_db():
|
||||
if db.is_connected():
|
||||
from backend.data import execution as _execution_db
|
||||
|
||||
execution_db = _execution_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
execution_db = get_database_manager_async_client()
|
||||
|
||||
return execution_db
|
||||
|
||||
|
||||
def user_db():
|
||||
if db.is_connected():
|
||||
from backend.data import user as _user_db
|
||||
|
||||
user_db = _user_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
user_db = get_database_manager_async_client()
|
||||
|
||||
return user_db
|
||||
|
||||
|
||||
def understanding_db():
|
||||
if db.is_connected():
|
||||
from backend.data import understanding as _understanding_db
|
||||
|
||||
understanding_db = _understanding_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
understanding_db = get_database_manager_async_client()
|
||||
|
||||
return understanding_db
|
||||
|
||||
|
||||
def workspace_db():
|
||||
if db.is_connected():
|
||||
from backend.data import workspace as _workspace_db
|
||||
|
||||
workspace_db = _workspace_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
workspace_db = get_database_manager_async_client()
|
||||
|
||||
return workspace_db
|
||||
@@ -4,14 +4,26 @@ from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cas
|
||||
|
||||
from backend.api.features.library.db import (
|
||||
add_store_agent_to_library,
|
||||
create_graph_in_library,
|
||||
create_library_agent,
|
||||
get_library_agent,
|
||||
get_library_agent_by_graph_id,
|
||||
list_library_agents,
|
||||
update_graph_in_library,
|
||||
)
|
||||
from backend.api.features.store.db import (
|
||||
get_agent,
|
||||
get_available_graph,
|
||||
get_store_agent_details,
|
||||
get_store_agents,
|
||||
)
|
||||
from backend.api.features.store.db import get_store_agent_details, get_store_agents
|
||||
from backend.api.features.store.embeddings import (
|
||||
backfill_missing_embeddings,
|
||||
cleanup_orphaned_embeddings,
|
||||
get_embedding_stats,
|
||||
)
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
from backend.copilot import db as chat_db
|
||||
from backend.data import db
|
||||
from backend.data.analytics import (
|
||||
get_accuracy_trends_and_alerts,
|
||||
@@ -48,6 +60,7 @@ from backend.data.graph import (
|
||||
get_graph_metadata,
|
||||
get_graph_settings,
|
||||
get_node,
|
||||
get_store_listed_graphs,
|
||||
validate_graph_execution_permissions,
|
||||
)
|
||||
from backend.data.human_review import (
|
||||
@@ -67,6 +80,10 @@ from backend.data.notifications import (
|
||||
remove_notifications_from_batch,
|
||||
)
|
||||
from backend.data.onboarding import increment_onboarding_runs
|
||||
from backend.data.understanding import (
|
||||
get_business_understanding,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
from backend.data.user import (
|
||||
get_active_user_ids_in_timerange,
|
||||
get_user_by_id,
|
||||
@@ -76,6 +93,7 @@ from backend.data.user import (
|
||||
get_user_notification_preference,
|
||||
update_user_integrations,
|
||||
)
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
@@ -107,6 +125,13 @@ async def _get_credits(user_id: str) -> int:
|
||||
|
||||
|
||||
class DatabaseManager(AppService):
|
||||
"""Database connection pooling service.
|
||||
|
||||
This service connects to the Prisma engine and exposes database
|
||||
operations via RPC endpoints. It acts as a centralized connection pool
|
||||
for all services that need database access.
|
||||
"""
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(self, app: "FastAPI"):
|
||||
async with super().lifespan(app):
|
||||
@@ -142,11 +167,15 @@ class DatabaseManager(AppService):
|
||||
def _(
|
||||
f: Callable[P, R], name: str | None = None
|
||||
) -> Callable[Concatenate[object, P], R]:
|
||||
"""
|
||||
Exposes a function as an RPC endpoint, and adds a virtual `self` param
|
||||
to the function's type so it can be bound as a method.
|
||||
"""
|
||||
if name is not None:
|
||||
f.__name__ = name
|
||||
return cast(Callable[Concatenate[object, P], R], expose(f))
|
||||
|
||||
# Executions
|
||||
# ============ Graph Executions ============ #
|
||||
get_child_graph_executions = _(get_child_graph_executions)
|
||||
get_graph_executions = _(get_graph_executions)
|
||||
get_graph_executions_count = _(get_graph_executions_count)
|
||||
@@ -170,36 +199,37 @@ class DatabaseManager(AppService):
|
||||
get_frequently_executed_graphs = _(get_frequently_executed_graphs)
|
||||
get_marketplace_graphs_for_monitoring = _(get_marketplace_graphs_for_monitoring)
|
||||
|
||||
# Graphs
|
||||
# ============ Graphs ============ #
|
||||
get_node = _(get_node)
|
||||
get_graph = _(get_graph)
|
||||
get_connected_output_nodes = _(get_connected_output_nodes)
|
||||
get_graph_metadata = _(get_graph_metadata)
|
||||
get_graph_settings = _(get_graph_settings)
|
||||
get_store_listed_graphs = _(get_store_listed_graphs)
|
||||
|
||||
# Credits
|
||||
# ============ Credits ============ #
|
||||
spend_credits = _(_spend_credits, name="spend_credits")
|
||||
get_credits = _(_get_credits, name="get_credits")
|
||||
|
||||
# User + User Metadata + User Integrations
|
||||
# ============ User + Integrations ============ #
|
||||
get_user_by_id = _(get_user_by_id)
|
||||
get_user_integrations = _(get_user_integrations)
|
||||
update_user_integrations = _(update_user_integrations)
|
||||
|
||||
# User Comms - async
|
||||
# ============ User Comms ============ #
|
||||
get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
|
||||
get_user_by_id = _(get_user_by_id)
|
||||
get_user_email_by_id = _(get_user_email_by_id)
|
||||
get_user_email_verification = _(get_user_email_verification)
|
||||
get_user_notification_preference = _(get_user_notification_preference)
|
||||
|
||||
# Human In The Loop
|
||||
# ============ Human In The Loop ============ #
|
||||
cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
|
||||
check_approval = _(check_approval)
|
||||
get_or_create_human_review = _(get_or_create_human_review)
|
||||
has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec)
|
||||
update_review_processed_status = _(update_review_processed_status)
|
||||
|
||||
# Notifications - async
|
||||
# ============ Notifications ============ #
|
||||
clear_all_user_notification_batches = _(clear_all_user_notification_batches)
|
||||
create_or_add_to_user_notification_batch = _(
|
||||
create_or_add_to_user_notification_batch
|
||||
@@ -212,29 +242,56 @@ class DatabaseManager(AppService):
|
||||
get_user_notification_oldest_message_in_batch
|
||||
)
|
||||
|
||||
# Library
|
||||
# ============ Library ============ #
|
||||
list_library_agents = _(list_library_agents)
|
||||
add_store_agent_to_library = _(add_store_agent_to_library)
|
||||
create_graph_in_library = _(create_graph_in_library)
|
||||
create_library_agent = _(create_library_agent)
|
||||
get_library_agent = _(get_library_agent)
|
||||
get_library_agent_by_graph_id = _(get_library_agent_by_graph_id)
|
||||
update_graph_in_library = _(update_graph_in_library)
|
||||
validate_graph_execution_permissions = _(validate_graph_execution_permissions)
|
||||
|
||||
# Onboarding
|
||||
# ============ Onboarding ============ #
|
||||
increment_onboarding_runs = _(increment_onboarding_runs)
|
||||
|
||||
# OAuth
|
||||
# ============ OAuth ============ #
|
||||
cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens)
|
||||
|
||||
# Store
|
||||
# ============ Store ============ #
|
||||
get_store_agents = _(get_store_agents)
|
||||
get_store_agent_details = _(get_store_agent_details)
|
||||
get_agent = _(get_agent)
|
||||
get_available_graph = _(get_available_graph)
|
||||
|
||||
# Store Embeddings
|
||||
# ============ Search ============ #
|
||||
get_embedding_stats = _(get_embedding_stats)
|
||||
backfill_missing_embeddings = _(backfill_missing_embeddings)
|
||||
cleanup_orphaned_embeddings = _(cleanup_orphaned_embeddings)
|
||||
unified_hybrid_search = _(unified_hybrid_search)
|
||||
|
||||
# Summary data - async
|
||||
# ============ Summary Data ============ #
|
||||
get_user_execution_summary_data = _(get_user_execution_summary_data)
|
||||
|
||||
# ============ Workspace ============ #
|
||||
get_or_create_workspace = _(get_or_create_workspace)
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = _(get_business_understanding)
|
||||
upsert_business_understanding = _(upsert_business_understanding)
|
||||
|
||||
# ============ CoPilot Chat Sessions ============ #
|
||||
get_chat_session = _(chat_db.get_chat_session)
|
||||
create_chat_session = _(chat_db.create_chat_session)
|
||||
update_chat_session = _(chat_db.update_chat_session)
|
||||
add_chat_message = _(chat_db.add_chat_message)
|
||||
add_chat_messages_batch = _(chat_db.add_chat_messages_batch)
|
||||
get_user_chat_sessions = _(chat_db.get_user_chat_sessions)
|
||||
get_user_session_count = _(chat_db.get_user_session_count)
|
||||
delete_chat_session = _(chat_db.delete_chat_session)
|
||||
get_chat_session_message_count = _(chat_db.get_chat_session_message_count)
|
||||
update_tool_message_content = _(chat_db.update_tool_message_content)
|
||||
|
||||
|
||||
class DatabaseManagerClient(AppServiceClient):
|
||||
d = DatabaseManager
|
||||
@@ -296,43 +353,50 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
def get_service_type(cls):
|
||||
return DatabaseManager
|
||||
|
||||
# ============ Graph Executions ============ #
|
||||
create_graph_execution = d.create_graph_execution
|
||||
get_child_graph_executions = d.get_child_graph_executions
|
||||
get_connected_output_nodes = d.get_connected_output_nodes
|
||||
get_latest_node_execution = d.get_latest_node_execution
|
||||
get_graph = d.get_graph
|
||||
get_graph_metadata = d.get_graph_metadata
|
||||
get_graph_settings = d.get_graph_settings
|
||||
get_graph_execution = d.get_graph_execution
|
||||
get_graph_execution_meta = d.get_graph_execution_meta
|
||||
get_node = d.get_node
|
||||
get_graph_executions = d.get_graph_executions
|
||||
get_node_execution = d.get_node_execution
|
||||
get_node_executions = d.get_node_executions
|
||||
get_user_by_id = d.get_user_by_id
|
||||
get_user_integrations = d.get_user_integrations
|
||||
upsert_execution_input = d.upsert_execution_input
|
||||
upsert_execution_output = d.upsert_execution_output
|
||||
get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
|
||||
update_graph_execution_stats = d.update_graph_execution_stats
|
||||
update_node_execution_status = d.update_node_execution_status
|
||||
update_node_execution_status_batch = d.update_node_execution_status_batch
|
||||
update_user_integrations = d.update_user_integrations
|
||||
upsert_execution_input = d.upsert_execution_input
|
||||
upsert_execution_output = d.upsert_execution_output
|
||||
get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
|
||||
get_execution_kv_data = d.get_execution_kv_data
|
||||
set_execution_kv_data = d.set_execution_kv_data
|
||||
|
||||
# Human In The Loop
|
||||
# ============ Graphs ============ #
|
||||
get_graph = d.get_graph
|
||||
get_graph_metadata = d.get_graph_metadata
|
||||
get_graph_settings = d.get_graph_settings
|
||||
get_node = d.get_node
|
||||
get_store_listed_graphs = d.get_store_listed_graphs
|
||||
|
||||
# ============ User + Integrations ============ #
|
||||
get_user_by_id = d.get_user_by_id
|
||||
get_user_integrations = d.get_user_integrations
|
||||
update_user_integrations = d.update_user_integrations
|
||||
|
||||
# ============ Human In The Loop ============ #
|
||||
cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
|
||||
check_approval = d.check_approval
|
||||
get_or_create_human_review = d.get_or_create_human_review
|
||||
update_review_processed_status = d.update_review_processed_status
|
||||
|
||||
# User Comms
|
||||
# ============ User Comms ============ #
|
||||
get_active_user_ids_in_timerange = d.get_active_user_ids_in_timerange
|
||||
get_user_email_by_id = d.get_user_email_by_id
|
||||
get_user_email_verification = d.get_user_email_verification
|
||||
get_user_notification_preference = d.get_user_notification_preference
|
||||
|
||||
# Notifications
|
||||
# ============ Notifications ============ #
|
||||
clear_all_user_notification_batches = d.clear_all_user_notification_batches
|
||||
create_or_add_to_user_notification_batch = (
|
||||
d.create_or_add_to_user_notification_batch
|
||||
@@ -345,20 +409,49 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
d.get_user_notification_oldest_message_in_batch
|
||||
)
|
||||
|
||||
# Library
|
||||
# ============ Library ============ #
|
||||
list_library_agents = d.list_library_agents
|
||||
add_store_agent_to_library = d.add_store_agent_to_library
|
||||
create_graph_in_library = d.create_graph_in_library
|
||||
create_library_agent = d.create_library_agent
|
||||
get_library_agent = d.get_library_agent
|
||||
get_library_agent_by_graph_id = d.get_library_agent_by_graph_id
|
||||
update_graph_in_library = d.update_graph_in_library
|
||||
validate_graph_execution_permissions = d.validate_graph_execution_permissions
|
||||
|
||||
# Onboarding
|
||||
# ============ Onboarding ============ #
|
||||
increment_onboarding_runs = d.increment_onboarding_runs
|
||||
|
||||
# OAuth
|
||||
# ============ OAuth ============ #
|
||||
cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens
|
||||
|
||||
# Store
|
||||
# ============ Store ============ #
|
||||
get_store_agents = d.get_store_agents
|
||||
get_store_agent_details = d.get_store_agent_details
|
||||
get_agent = d.get_agent
|
||||
get_available_graph = d.get_available_graph
|
||||
|
||||
# Summary data
|
||||
# ============ Search ============ #
|
||||
unified_hybrid_search = d.unified_hybrid_search
|
||||
|
||||
# ============ Summary Data ============ #
|
||||
get_user_execution_summary_data = d.get_user_execution_summary_data
|
||||
|
||||
# ============ Workspace ============ #
|
||||
get_or_create_workspace = d.get_or_create_workspace
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = d.get_business_understanding
|
||||
upsert_business_understanding = d.upsert_business_understanding
|
||||
|
||||
# ============ CoPilot Chat Sessions ============ #
|
||||
get_chat_session = d.get_chat_session
|
||||
create_chat_session = d.create_chat_session
|
||||
update_chat_session = d.update_chat_session
|
||||
add_chat_message = d.add_chat_message
|
||||
add_chat_messages_batch = d.add_chat_messages_batch
|
||||
get_user_chat_sessions = d.get_user_chat_sessions
|
||||
get_user_session_count = d.get_user_session_count
|
||||
delete_chat_session = d.delete_chat_session
|
||||
get_chat_session_message_count = d.get_chat_session_message_count
|
||||
update_tool_message_content = d.update_tool_message_content
|
||||
@@ -33,7 +33,6 @@ from backend.util import type as type_utils
|
||||
from backend.util.exceptions import GraphNotAccessibleError, GraphNotInLibraryError
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
from backend.util.request import parse_url
|
||||
|
||||
from .block import BlockInput
|
||||
from .db import BaseDbModel
|
||||
@@ -450,9 +449,6 @@ class GraphModel(Graph, GraphMeta):
|
||||
continue
|
||||
if ProviderName.HTTP in field.provider:
|
||||
continue
|
||||
# MCP credentials are intentionally split by server URL
|
||||
if ProviderName.MCP in field.provider:
|
||||
continue
|
||||
|
||||
# If this happens, that means a block implementation probably needs
|
||||
# to be updated.
|
||||
@@ -509,18 +505,6 @@ class GraphModel(Graph, GraphMeta):
|
||||
"required": ["id", "provider", "type"],
|
||||
}
|
||||
|
||||
# Add a descriptive display title when URL-based discriminator values
|
||||
# are present (e.g. "mcp.sentry.dev" instead of just "Mcp")
|
||||
if (
|
||||
field_info.discriminator
|
||||
and not field_info.discriminator_mapping
|
||||
and field_info.discriminator_values
|
||||
):
|
||||
hostnames = sorted(
|
||||
parse_url(str(v)).netloc for v in field_info.discriminator_values
|
||||
)
|
||||
field_schema["display_name"] = ", ".join(hostnames)
|
||||
|
||||
# Add other (optional) field info items
|
||||
field_schema.update(
|
||||
field_info.model_dump(
|
||||
@@ -565,17 +549,8 @@ class GraphModel(Graph, GraphMeta):
|
||||
|
||||
for graph in [self] + self.sub_graphs:
|
||||
for node in graph.nodes:
|
||||
# A node's credentials are optional if either:
|
||||
# 1. The node metadata says so (credentials_optional=True), or
|
||||
# 2. All credential fields on the block have defaults (not required by schema)
|
||||
block_required = node.block.input_schema.get_required_fields()
|
||||
creds_required_by_schema = any(
|
||||
fname in block_required
|
||||
for fname in node.block.input_schema.get_credentials_fields()
|
||||
)
|
||||
node_required_map[node.id] = (
|
||||
not node.credentials_optional and creds_required_by_schema
|
||||
)
|
||||
# Track if this node requires credentials (credentials_optional=False means required)
|
||||
node_required_map[node.id] = not node.credentials_optional
|
||||
|
||||
for (
|
||||
field_name,
|
||||
@@ -801,19 +776,6 @@ class GraphModel(Graph, GraphMeta):
|
||||
"'credentials' and `*_credentials` are reserved"
|
||||
)
|
||||
|
||||
# Check custom block-level validation (e.g., MCP dynamic tool arguments).
|
||||
# Blocks can override get_missing_input to report additional missing fields
|
||||
# beyond the standard top-level required fields.
|
||||
if for_run:
|
||||
credential_fields = InputSchema.get_credentials_fields()
|
||||
custom_missing = InputSchema.get_missing_input(node.input_default)
|
||||
for field_name in custom_missing:
|
||||
if (
|
||||
field_name not in provided_inputs
|
||||
and field_name not in credential_fields
|
||||
):
|
||||
node_errors[node.id][field_name] = "This field is required"
|
||||
|
||||
# Get input schema properties and check dependencies
|
||||
input_fields = InputSchema.model_fields
|
||||
|
||||
|
||||
@@ -462,120 +462,3 @@ def test_node_credentials_optional_with_other_metadata():
|
||||
assert node.credentials_optional is True
|
||||
assert node.metadata["position"] == {"x": 100, "y": 200}
|
||||
assert node.metadata["customized_name"] == "My Custom Node"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for MCP Credential Deduplication
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_mcp_credential_combine_different_servers():
|
||||
"""Two MCP credential fields with different server URLs should produce
|
||||
separate entries when combined (not merged into one)."""
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsType
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
|
||||
|
||||
field_sentry = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.MCP]),
|
||||
credentials_types=oauth2_types,
|
||||
credentials_scopes=None,
|
||||
discriminator="server_url",
|
||||
discriminator_values={"https://mcp.sentry.dev/mcp"},
|
||||
)
|
||||
field_linear = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.MCP]),
|
||||
credentials_types=oauth2_types,
|
||||
credentials_scopes=None,
|
||||
discriminator="server_url",
|
||||
discriminator_values={"https://mcp.linear.app/mcp"},
|
||||
)
|
||||
|
||||
combined = CredentialsFieldInfo.combine(
|
||||
(field_sentry, ("node-sentry", "credentials")),
|
||||
(field_linear, ("node-linear", "credentials")),
|
||||
)
|
||||
|
||||
# Should produce 2 separate credential entries
|
||||
assert len(combined) == 2, (
|
||||
f"Expected 2 credential entries for 2 MCP blocks with different servers, "
|
||||
f"got {len(combined)}: {list(combined.keys())}"
|
||||
)
|
||||
|
||||
# Each entry should contain the server hostname in its key
|
||||
keys = list(combined.keys())
|
||||
assert any(
|
||||
"mcp.sentry.dev" in k for k in keys
|
||||
), f"Expected 'mcp.sentry.dev' in one key, got {keys}"
|
||||
assert any(
|
||||
"mcp.linear.app" in k for k in keys
|
||||
), f"Expected 'mcp.linear.app' in one key, got {keys}"
|
||||
|
||||
|
||||
def test_mcp_credential_combine_same_server():
|
||||
"""Two MCP credential fields with the same server URL should be combined
|
||||
into one credential entry."""
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsType
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
|
||||
|
||||
field_a = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.MCP]),
|
||||
credentials_types=oauth2_types,
|
||||
credentials_scopes=None,
|
||||
discriminator="server_url",
|
||||
discriminator_values={"https://mcp.sentry.dev/mcp"},
|
||||
)
|
||||
field_b = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.MCP]),
|
||||
credentials_types=oauth2_types,
|
||||
credentials_scopes=None,
|
||||
discriminator="server_url",
|
||||
discriminator_values={"https://mcp.sentry.dev/mcp"},
|
||||
)
|
||||
|
||||
combined = CredentialsFieldInfo.combine(
|
||||
(field_a, ("node-a", "credentials")),
|
||||
(field_b, ("node-b", "credentials")),
|
||||
)
|
||||
|
||||
# Should produce 1 credential entry (same server URL)
|
||||
assert len(combined) == 1, (
|
||||
f"Expected 1 credential entry for 2 MCP blocks with same server, "
|
||||
f"got {len(combined)}: {list(combined.keys())}"
|
||||
)
|
||||
|
||||
|
||||
def test_mcp_credential_combine_no_discriminator_values():
|
||||
"""MCP credential fields without discriminator_values should be merged
|
||||
into a single entry (backwards compat for blocks without server_url set)."""
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsType
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
|
||||
|
||||
field_a = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.MCP]),
|
||||
credentials_types=oauth2_types,
|
||||
credentials_scopes=None,
|
||||
discriminator="server_url",
|
||||
)
|
||||
field_b = CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([ProviderName.MCP]),
|
||||
credentials_types=oauth2_types,
|
||||
credentials_scopes=None,
|
||||
discriminator="server_url",
|
||||
)
|
||||
|
||||
combined = CredentialsFieldInfo.combine(
|
||||
(field_a, ("node-a", "credentials")),
|
||||
(field_b, ("node-b", "credentials")),
|
||||
)
|
||||
|
||||
# Should produce 1 entry (no URL differentiation)
|
||||
assert len(combined) == 1, (
|
||||
f"Expected 1 credential entry for MCP blocks without discriminator_values, "
|
||||
f"got {len(combined)}: {list(combined.keys())}"
|
||||
)
|
||||
|
||||
@@ -29,7 +29,6 @@ from pydantic import (
|
||||
GetCoreSchemaHandler,
|
||||
SecretStr,
|
||||
field_serializer,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic_core import (
|
||||
CoreSchema,
|
||||
@@ -503,25 +502,6 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
provider: CP
|
||||
type: CT
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_legacy_provider(cls, data: Any) -> Any:
|
||||
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug.
|
||||
|
||||
Python 3.13 changed ``str(StrEnum)`` to return ``"ClassName.MEMBER"``
|
||||
instead of the plain value. Old stored credential references may have
|
||||
``provider: "ProviderName.MCP"`` instead of ``"mcp"``.
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
prov = data.get("provider", "")
|
||||
if isinstance(prov, str) and prov.startswith("ProviderName."):
|
||||
member = prov.removeprefix("ProviderName.")
|
||||
try:
|
||||
data = {**data, "provider": ProviderName[member].value}
|
||||
except KeyError:
|
||||
pass
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def allowed_providers(cls) -> tuple[ProviderName, ...] | None:
|
||||
return get_args(cls.model_fields["provider"].annotation)
|
||||
@@ -626,18 +606,11 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
] = defaultdict(list)
|
||||
|
||||
for field, key in fields:
|
||||
if (
|
||||
field.discriminator
|
||||
and not field.discriminator_mapping
|
||||
and field.discriminator_values
|
||||
):
|
||||
# URL-based discrimination (e.g. HTTP host-scoped, MCP server URL):
|
||||
# Each unique host gets its own credential entry.
|
||||
provider_prefix = next(iter(field.provider))
|
||||
# Use .value for enum types to get the plain string (e.g. "mcp" not "ProviderName.MCP")
|
||||
prefix_str = getattr(provider_prefix, "value", str(provider_prefix))
|
||||
if field.provider == frozenset([ProviderName.HTTP]):
|
||||
# HTTP host-scoped credentials can have different hosts that reqires different credential sets.
|
||||
# Group by host extracted from the URL
|
||||
providers = frozenset(
|
||||
[cast(CP, prefix_str)]
|
||||
[cast(CP, "http")]
|
||||
+ [
|
||||
cast(CP, parse_url(str(value)).netloc)
|
||||
for value in field.discriminator_values
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from backend.app import run_processes
|
||||
from backend.executor import DatabaseManager
|
||||
from backend.data.db_manager import DatabaseManager
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
from .database import DatabaseManager, DatabaseManagerAsyncClient, DatabaseManagerClient
|
||||
from .manager import ExecutionManager
|
||||
from .scheduler import Scheduler
|
||||
|
||||
__all__ = [
|
||||
"DatabaseManager",
|
||||
"DatabaseManagerClient",
|
||||
"DatabaseManagerAsyncClient",
|
||||
"ExecutionManager",
|
||||
"Scheduler",
|
||||
]
|
||||
|
||||
@@ -22,7 +22,7 @@ from backend.util.settings import Settings
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
from backend.data.db_manager import DatabaseManagerAsyncClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
from backend.data.db_manager import DatabaseManagerAsyncClient
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Redis-based distributed locking for cluster coordination."""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -19,6 +20,7 @@ class ClusterLock:
|
||||
self.owner_id = owner_id
|
||||
self.timeout = timeout
|
||||
self._last_refresh = 0.0
|
||||
self._refresh_lock = threading.Lock()
|
||||
|
||||
def try_acquire(self) -> str | None:
|
||||
"""Try to acquire the lock.
|
||||
@@ -31,7 +33,8 @@ class ClusterLock:
|
||||
try:
|
||||
success = self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout)
|
||||
if success:
|
||||
self._last_refresh = time.time()
|
||||
with self._refresh_lock:
|
||||
self._last_refresh = time.time()
|
||||
return self.owner_id # Successfully acquired
|
||||
|
||||
# Failed to acquire, get current owner
|
||||
@@ -57,23 +60,27 @@ class ClusterLock:
|
||||
Rate limited to at most once every timeout/10 seconds (minimum 1 second).
|
||||
During rate limiting, still verifies lock existence but skips TTL extension.
|
||||
Setting _last_refresh to 0 bypasses rate limiting for testing.
|
||||
|
||||
Thread-safe: uses _refresh_lock to protect _last_refresh access.
|
||||
"""
|
||||
# Calculate refresh interval: max(timeout // 10, 1)
|
||||
refresh_interval = max(self.timeout // 10, 1)
|
||||
current_time = time.time()
|
||||
|
||||
# Check if we're within the rate limit period
|
||||
# Check if we're within the rate limit period (thread-safe read)
|
||||
# _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
|
||||
with self._refresh_lock:
|
||||
last_refresh = self._last_refresh
|
||||
is_rate_limited = (
|
||||
self._last_refresh > 0
|
||||
and (current_time - self._last_refresh) < refresh_interval
|
||||
last_refresh > 0 and (current_time - last_refresh) < refresh_interval
|
||||
)
|
||||
|
||||
try:
|
||||
# Always verify lock existence, even during rate limiting
|
||||
current_value = self.redis.get(self.key)
|
||||
if not current_value:
|
||||
self._last_refresh = 0
|
||||
with self._refresh_lock:
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
stored_owner = (
|
||||
@@ -82,7 +89,8 @@ class ClusterLock:
|
||||
else str(current_value)
|
||||
)
|
||||
if stored_owner != self.owner_id:
|
||||
self._last_refresh = 0
|
||||
with self._refresh_lock:
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
# If rate limited, return True but don't update TTL or timestamp
|
||||
@@ -91,25 +99,30 @@ class ClusterLock:
|
||||
|
||||
# Perform actual refresh
|
||||
if self.redis.expire(self.key, self.timeout):
|
||||
self._last_refresh = current_time
|
||||
with self._refresh_lock:
|
||||
self._last_refresh = current_time
|
||||
return True
|
||||
|
||||
self._last_refresh = 0
|
||||
with self._refresh_lock:
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ClusterLock.refresh failed for key {self.key}: {e}")
|
||||
self._last_refresh = 0
|
||||
with self._refresh_lock:
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
def release(self):
|
||||
"""Release the lock."""
|
||||
if self._last_refresh == 0:
|
||||
return
|
||||
with self._refresh_lock:
|
||||
if self._last_refresh == 0:
|
||||
return
|
||||
|
||||
try:
|
||||
self.redis.delete(self.key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self._last_refresh = 0.0
|
||||
with self._refresh_lock:
|
||||
self._last_refresh = 0.0
|
||||
|
||||
@@ -20,7 +20,6 @@ from backend.blocks import get_block
|
||||
from backend.blocks._base import BlockSchema
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
from backend.blocks.mcp.block import MCPToolBlock
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
@@ -93,7 +92,10 @@ from .utils import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
|
||||
from backend.data.db_manager import (
|
||||
DatabaseManagerAsyncClient,
|
||||
DatabaseManagerClient,
|
||||
)
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
@@ -229,18 +231,6 @@ async def execute_node(
|
||||
_input_data.nodes_input_masks = nodes_input_masks
|
||||
_input_data.user_id = user_id
|
||||
input_data = _input_data.model_dump()
|
||||
elif isinstance(node_block, MCPToolBlock):
|
||||
_mcp_data = MCPToolBlock.Input(**node.input_default)
|
||||
# Dynamic tool fields are flattened to top-level by validate_exec
|
||||
# (via get_input_defaults). Collect them back into tool_arguments.
|
||||
tool_schema = _mcp_data.tool_input_schema
|
||||
tool_props = set(tool_schema.get("properties", {}).keys())
|
||||
merged_args = {**_mcp_data.tool_arguments}
|
||||
for key in tool_props:
|
||||
if key in input_data:
|
||||
merged_args[key] = input_data[key]
|
||||
_mcp_data.tool_arguments = merged_args
|
||||
input_data = _mcp_data.model_dump()
|
||||
data.inputs = input_data
|
||||
|
||||
# Execute the node
|
||||
@@ -277,34 +267,8 @@ async def execute_node(
|
||||
|
||||
# Handle regular credentials fields
|
||||
for field_name, input_type in input_model.get_credentials_fields().items():
|
||||
field_value = input_data.get(field_name)
|
||||
if not field_value or (
|
||||
isinstance(field_value, dict) and not field_value.get("id")
|
||||
):
|
||||
# No credentials configured — nullify so JSON schema validation
|
||||
# doesn't choke on the empty default `{}`.
|
||||
input_data[field_name] = None
|
||||
continue # Block runs without credentials
|
||||
|
||||
credentials_meta = input_type(**field_value)
|
||||
# Write normalized values back so JSON schema validation also passes
|
||||
# (model_validator may have fixed legacy formats like "ProviderName.MCP")
|
||||
input_data[field_name] = credentials_meta.model_dump(mode="json")
|
||||
try:
|
||||
credentials, lock = await creds_manager.acquire(
|
||||
user_id, credentials_meta.id
|
||||
)
|
||||
except ValueError:
|
||||
# Credential was deleted or doesn't exist.
|
||||
# If the field has a default, run without credentials.
|
||||
if input_model.model_fields[field_name].default is not None:
|
||||
log_metadata.warning(
|
||||
f"Credentials #{credentials_meta.id} not found, "
|
||||
"running without (field has default)"
|
||||
)
|
||||
input_data[field_name] = None
|
||||
continue
|
||||
raise
|
||||
credentials_meta = input_type(**input_data[field_name])
|
||||
credentials, lock = await creds_manager.acquire(user_id, credentials_meta.id)
|
||||
creds_locks.append(lock)
|
||||
extra_exec_kwargs[field_name] = credentials
|
||||
|
||||
|
||||
@@ -260,13 +260,7 @@ async def _validate_node_input_credentials(
|
||||
# Track if any credential field is missing for this node
|
||||
has_missing_credentials = False
|
||||
|
||||
# A credential field is optional if the node metadata says so, or if
|
||||
# the block schema declares a default for the field.
|
||||
required_fields = block.input_schema.get_required_fields()
|
||||
is_creds_optional = node.credentials_optional
|
||||
|
||||
for field_name, credentials_meta_type in credentials_fields.items():
|
||||
field_is_optional = is_creds_optional or field_name not in required_fields
|
||||
try:
|
||||
# Check nodes_input_masks first, then input_default
|
||||
field_value = None
|
||||
@@ -279,7 +273,7 @@ async def _validate_node_input_credentials(
|
||||
elif field_name in node.input_default:
|
||||
# For optional credentials, don't use input_default - treat as missing
|
||||
# This prevents stale credential IDs from failing validation
|
||||
if field_is_optional:
|
||||
if node.credentials_optional:
|
||||
field_value = None
|
||||
else:
|
||||
field_value = node.input_default[field_name]
|
||||
@@ -289,8 +283,8 @@ async def _validate_node_input_credentials(
|
||||
isinstance(field_value, dict) and not field_value.get("id")
|
||||
):
|
||||
has_missing_credentials = True
|
||||
# If credential field is optional, skip instead of error
|
||||
if field_is_optional:
|
||||
# If node has credentials_optional flag, mark for skipping instead of error
|
||||
if node.credentials_optional:
|
||||
continue # Don't add error, will be marked for skip after loop
|
||||
else:
|
||||
credential_errors[node.id][
|
||||
@@ -340,16 +334,16 @@ async def _validate_node_input_credentials(
|
||||
] = "Invalid credentials: type/provider mismatch"
|
||||
continue
|
||||
|
||||
# If node has optional credentials and any are missing, allow running without.
|
||||
# The executor will pass credentials=None to the block's run().
|
||||
# If node has optional credentials and any are missing, mark for skipping
|
||||
# But only if there are no other errors for this node
|
||||
if (
|
||||
has_missing_credentials
|
||||
and is_creds_optional
|
||||
and node.credentials_optional
|
||||
and node.id not in credential_errors
|
||||
):
|
||||
nodes_to_skip.add(node.id)
|
||||
logger.info(
|
||||
f"Node #{node.id}: optional credentials not configured, "
|
||||
"running without"
|
||||
f"Node #{node.id} will be skipped: optional credentials not configured"
|
||||
)
|
||||
|
||||
return credential_errors, nodes_to_skip
|
||||
|
||||
@@ -495,7 +495,6 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip(
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||
"credentials": mock_credentials_field_type
|
||||
}
|
||||
mock_block.input_schema.get_required_fields.return_value = {"credentials"}
|
||||
mock_node.block = mock_block
|
||||
|
||||
# Create mock graph
|
||||
@@ -509,8 +508,8 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip(
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
# Node should NOT be in nodes_to_skip (runs without credentials) and not in errors
|
||||
assert mock_node.id not in nodes_to_skip
|
||||
# Node should be in nodes_to_skip, not in errors
|
||||
assert mock_node.id in nodes_to_skip
|
||||
assert mock_node.id not in errors
|
||||
|
||||
|
||||
@@ -536,7 +535,6 @@ async def test_validate_node_input_credentials_required_missing_creds_error(
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||
"credentials": mock_credentials_field_type
|
||||
}
|
||||
mock_block.input_schema.get_required_fields.return_value = {"credentials"}
|
||||
mock_node.block = mock_block
|
||||
|
||||
# Create mock graph
|
||||
|
||||
@@ -22,27 +22,6 @@ from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
def provider_matches(stored: str, expected: str) -> bool:
|
||||
"""Compare provider strings, handling Python 3.13 ``str(StrEnum)`` bug.
|
||||
|
||||
On Python 3.13, ``str(ProviderName.MCP)`` returns ``"ProviderName.MCP"``
|
||||
instead of ``"mcp"``. OAuth states persisted with the buggy format need
|
||||
to match when ``expected`` is the canonical value (e.g. ``"mcp"``).
|
||||
"""
|
||||
if stored == expected:
|
||||
return True
|
||||
if stored.startswith("ProviderName."):
|
||||
member = stored.removeprefix("ProviderName.")
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
try:
|
||||
return ProviderName[member].value == expected
|
||||
except KeyError:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
# This is an overrride since ollama doesn't actually require an API key, but the creddential system enforces one be attached
|
||||
ollama_credentials = APIKeyCredentials(
|
||||
id="744fdc56-071a-4761-b5a5-0af0ce10a2b5",
|
||||
@@ -410,7 +389,7 @@ class IntegrationCredentialsStore:
|
||||
self, user_id: str, provider: str
|
||||
) -> list[Credentials]:
|
||||
credentials = await self.get_all_creds(user_id)
|
||||
return [c for c in credentials if provider_matches(c.provider, provider)]
|
||||
return [c for c in credentials if c.provider == provider]
|
||||
|
||||
async def get_authorized_providers(self, user_id: str) -> list[str]:
|
||||
credentials = await self.get_all_creds(user_id)
|
||||
@@ -506,6 +485,17 @@ class IntegrationCredentialsStore:
|
||||
async with self.edit_user_integrations(user_id) as user_integrations:
|
||||
user_integrations.oauth_states.append(state)
|
||||
|
||||
async with await self.locked_user_integrations(user_id):
|
||||
|
||||
user_integrations = await self._get_user_integrations(user_id)
|
||||
oauth_states = user_integrations.oauth_states
|
||||
oauth_states.append(state)
|
||||
user_integrations.oauth_states = oauth_states
|
||||
|
||||
await self.db_manager.update_user_integrations(
|
||||
user_id=user_id, data=user_integrations
|
||||
)
|
||||
|
||||
return token, code_challenge
|
||||
|
||||
def _generate_code_challenge(self) -> tuple[str, str]:
|
||||
@@ -531,7 +521,7 @@ class IntegrationCredentialsStore:
|
||||
state
|
||||
for state in oauth_states
|
||||
if secrets.compare_digest(state.token, token)
|
||||
and provider_matches(state.provider, provider)
|
||||
and state.provider == provider
|
||||
and state.expires_at > now.timestamp()
|
||||
),
|
||||
None,
|
||||
|
||||
@@ -9,10 +9,7 @@ from redis.asyncio.lock import Lock as AsyncRedisLock
|
||||
|
||||
from backend.data.model import Credentials, OAuth2Credentials
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.integrations.credentials_store import (
|
||||
IntegrationCredentialsStore,
|
||||
provider_matches,
|
||||
)
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
@@ -140,10 +137,7 @@ class IntegrationCredentialsManager:
|
||||
self, user_id: str, credentials: OAuth2Credentials, lock: bool = True
|
||||
) -> OAuth2Credentials:
|
||||
async with self._locked(user_id, credentials.id, "refresh"):
|
||||
if provider_matches(credentials.provider, ProviderName.MCP.value):
|
||||
oauth_handler = create_mcp_oauth_handler(credentials)
|
||||
else:
|
||||
oauth_handler = await _get_provider_oauth_handler(credentials.provider)
|
||||
oauth_handler = await _get_provider_oauth_handler(credentials.provider)
|
||||
if oauth_handler.needs_refresh(credentials):
|
||||
logger.debug(
|
||||
f"Refreshing '{credentials.provider}' "
|
||||
@@ -242,31 +236,3 @@ async def _get_provider_oauth_handler(provider_name_str: str) -> "BaseOAuthHandl
|
||||
client_secret=client_secret,
|
||||
redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
|
||||
)
|
||||
|
||||
|
||||
def create_mcp_oauth_handler(
|
||||
credentials: OAuth2Credentials,
|
||||
) -> "BaseOAuthHandler":
|
||||
"""Create an MCPOAuthHandler from credential metadata for token refresh.
|
||||
|
||||
MCP OAuth handlers have dynamic endpoints discovered per-server, so they
|
||||
can't be registered as singletons in HANDLERS_BY_NAME. Instead, the handler
|
||||
is reconstructed from metadata stored on the credential during initial auth.
|
||||
"""
|
||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||
|
||||
meta = credentials.metadata or {}
|
||||
token_url = meta.get("mcp_token_url", "")
|
||||
if not token_url:
|
||||
raise ValueError(
|
||||
f"MCP credential {credentials.id} is missing 'mcp_token_url' metadata; "
|
||||
"cannot refresh tokens"
|
||||
)
|
||||
return MCPOAuthHandler(
|
||||
client_id=meta.get("mcp_client_id", ""),
|
||||
client_secret=meta.get("mcp_client_secret", ""),
|
||||
redirect_uri="", # Not needed for token refresh
|
||||
authorize_url="", # Not needed for token refresh
|
||||
token_url=token_url,
|
||||
resource_url=meta.get("mcp_resource_url"),
|
||||
)
|
||||
|
||||
@@ -30,7 +30,6 @@ class ProviderName(str, Enum):
|
||||
IDEOGRAM = "ideogram"
|
||||
JINA = "jina"
|
||||
LLAMA_API = "llama_api"
|
||||
MCP = "mcp"
|
||||
MEDIUM = "medium"
|
||||
MEM0 = "mem0"
|
||||
NOTION = "notion"
|
||||
|
||||
@@ -51,21 +51,6 @@ async def _on_graph_activate(graph: "BaseGraph | GraphModel", user_id: str):
|
||||
if (
|
||||
creds_meta := new_node.input_default.get(creds_field_name)
|
||||
) and not await get_credentials(creds_meta["id"]):
|
||||
# If the credential field is optional (has a default in the
|
||||
# schema, or node metadata marks it optional), clear the stale
|
||||
# reference instead of blocking the save.
|
||||
creds_field_optional = (
|
||||
new_node.credentials_optional
|
||||
or creds_field_name not in block_input_schema.get_required_fields()
|
||||
)
|
||||
if creds_field_optional:
|
||||
new_node.input_default[creds_field_name] = {}
|
||||
logger.warning(
|
||||
f"Node #{new_node.id}: cleared stale optional "
|
||||
f"credentials #{creds_meta['id']} for "
|
||||
f"'{creds_field_name}'"
|
||||
)
|
||||
continue
|
||||
raise ValueError(
|
||||
f"Node #{new_node.id} input '{creds_field_name}' updated with "
|
||||
f"non-existent credentials #{creds_meta['id']}"
|
||||
|
||||
@@ -13,12 +13,15 @@ if TYPE_CHECKING:
|
||||
from openai import AsyncOpenAI
|
||||
from supabase import AClient, Client
|
||||
|
||||
from backend.data.db_manager import (
|
||||
DatabaseManagerAsyncClient,
|
||||
DatabaseManagerClient,
|
||||
)
|
||||
from backend.data.execution import (
|
||||
AsyncRedisExecutionEventBus,
|
||||
RedisExecutionEventBus,
|
||||
)
|
||||
from backend.data.rabbitmq import AsyncRabbitMQ, SyncRabbitMQ
|
||||
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
|
||||
from backend.executor.scheduler import SchedulerClient
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
@@ -27,7 +30,7 @@ if TYPE_CHECKING:
|
||||
@thread_cached
|
||||
def get_database_manager_client() -> "DatabaseManagerClient":
|
||||
"""Get a thread-cached DatabaseManagerClient with request retry enabled."""
|
||||
from backend.executor import DatabaseManagerClient
|
||||
from backend.data.db_manager import DatabaseManagerClient
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManagerClient, request_retry=True)
|
||||
@@ -38,7 +41,7 @@ def get_database_manager_async_client(
|
||||
should_retry: bool = True,
|
||||
) -> "DatabaseManagerAsyncClient":
|
||||
"""Get a thread-cached DatabaseManagerAsyncClient with request retry enabled."""
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
from backend.data.db_manager import DatabaseManagerAsyncClient
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManagerAsyncClient, request_retry=should_retry)
|
||||
@@ -106,6 +109,20 @@ async def get_async_execution_queue() -> "AsyncRabbitMQ":
|
||||
return client
|
||||
|
||||
|
||||
# ============ CoPilot Queue Helpers ============ #
|
||||
|
||||
|
||||
@thread_cached
|
||||
async def get_async_copilot_queue() -> "AsyncRabbitMQ":
|
||||
"""Get a thread-cached AsyncRabbitMQ CoPilot queue client."""
|
||||
from backend.copilot.executor.utils import create_copilot_queue_config
|
||||
from backend.data.rabbitmq import AsyncRabbitMQ
|
||||
|
||||
client = AsyncRabbitMQ(create_copilot_queue_config())
|
||||
await client.connect()
|
||||
return client
|
||||
|
||||
|
||||
# ============ Integration Credentials Store ============ #
|
||||
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ class HostResolver(abc.AbstractResolver):
|
||||
def __init__(self, ssl_hostname: str, ip_addresses: list[str]):
|
||||
self.ssl_hostname = ssl_hostname
|
||||
self.ip_addresses = ip_addresses
|
||||
self._default = aiohttp.ThreadedResolver()
|
||||
self._default = aiohttp.AsyncResolver()
|
||||
|
||||
async def resolve(self, host, port=0, family=socket.AF_INET):
|
||||
if host == self.ssl_hostname:
|
||||
@@ -467,7 +467,7 @@ class Requests:
|
||||
resolver = HostResolver(ssl_hostname=hostname, ip_addresses=ip_addresses)
|
||||
ssl_context = ssl.create_default_context()
|
||||
connector = aiohttp.TCPConnector(resolver=resolver, ssl=ssl_context)
|
||||
session_kwargs: dict = {}
|
||||
session_kwargs = {}
|
||||
if connector:
|
||||
session_kwargs["connector"] = connector
|
||||
|
||||
|
||||
@@ -211,16 +211,23 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="The port for execution manager daemon to run on",
|
||||
)
|
||||
|
||||
num_copilot_workers: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Number of concurrent CoPilot executor workers",
|
||||
)
|
||||
|
||||
copilot_executor_port: int = Field(
|
||||
default=8008,
|
||||
description="The port for CoPilot executor daemon to run on",
|
||||
)
|
||||
|
||||
execution_scheduler_port: int = Field(
|
||||
default=8003,
|
||||
description="The port for execution scheduler daemon to run on",
|
||||
)
|
||||
|
||||
agent_server_port: int = Field(
|
||||
default=8004,
|
||||
description="The port for agent server daemon to run on",
|
||||
)
|
||||
|
||||
database_api_port: int = Field(
|
||||
default=8005,
|
||||
description="The port for database server API to run on",
|
||||
@@ -368,6 +375,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default=600,
|
||||
description="The timeout in seconds for Agent Generator service requests (includes retries for rate limits)",
|
||||
)
|
||||
agentgenerator_use_dummy: bool = Field(
|
||||
default=False,
|
||||
description="Use dummy agent generator responses for testing (bypasses external service)",
|
||||
)
|
||||
|
||||
enable_example_blocks: bool = Field(
|
||||
default=False,
|
||||
|
||||
@@ -11,6 +11,7 @@ from backend.api.rest_api import AgentServer
|
||||
from backend.blocks._base import Block, BlockSchema
|
||||
from backend.data import db
|
||||
from backend.data.block import initialize_blocks
|
||||
from backend.data.db_manager import DatabaseManager
|
||||
from backend.data.execution import (
|
||||
ExecutionContext,
|
||||
ExecutionStatus,
|
||||
@@ -19,7 +20,7 @@ from backend.data.execution import (
|
||||
)
|
||||
from backend.data.model import _BaseCredentials
|
||||
from backend.data.user import create_default_user
|
||||
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
|
||||
from backend.executor import ExecutionManager, Scheduler
|
||||
from backend.notifications.notifications import NotificationManager
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -116,6 +116,7 @@ ws = "backend.ws:main"
|
||||
scheduler = "backend.scheduler:main"
|
||||
notification = "backend.notification:main"
|
||||
executor = "backend.exec:main"
|
||||
copilot-executor = "backend.copilot.executor.__main__:main"
|
||||
cli = "backend.cli:main"
|
||||
format = "linter:format"
|
||||
lint = "linter:lint"
|
||||
|
||||
@@ -9,10 +9,8 @@ from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.chat.tools.agent_generator import core
|
||||
from backend.api.features.chat.tools.agent_generator.core import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
)
|
||||
from backend.copilot.tools.agent_generator import core
|
||||
from backend.copilot.tools.agent_generator.core import AgentGeneratorNotConfiguredError
|
||||
|
||||
|
||||
class TestServiceNotConfigured:
|
||||
|
||||
@@ -9,7 +9,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.chat.tools.agent_generator import core
|
||||
from backend.copilot.tools.agent_generator import core
|
||||
|
||||
|
||||
class TestGetLibraryAgentsForGeneration:
|
||||
@@ -31,18 +31,20 @@ class TestGetLibraryAgentsForGeneration:
|
||||
mock_response = MagicMock()
|
||||
mock_response.agents = [mock_agent]
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.list_library_agents = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(
|
||||
core.library_db,
|
||||
"list_library_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_list:
|
||||
core,
|
||||
"library_db",
|
||||
return_value=mock_db,
|
||||
):
|
||||
result = await core.get_library_agents_for_generation(
|
||||
user_id="user-123",
|
||||
search_query="send email",
|
||||
)
|
||||
|
||||
mock_list.assert_called_once_with(
|
||||
mock_db.list_library_agents.assert_called_once_with(
|
||||
user_id="user-123",
|
||||
search_term="send email",
|
||||
page=1,
|
||||
@@ -80,11 +82,13 @@ class TestGetLibraryAgentsForGeneration:
|
||||
),
|
||||
]
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.list_library_agents = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(
|
||||
core.library_db,
|
||||
"list_library_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
core,
|
||||
"library_db",
|
||||
return_value=mock_db,
|
||||
):
|
||||
result = await core.get_library_agents_for_generation(
|
||||
user_id="user-123",
|
||||
@@ -101,18 +105,20 @@ class TestGetLibraryAgentsForGeneration:
|
||||
mock_response = MagicMock()
|
||||
mock_response.agents = []
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.list_library_agents = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(
|
||||
core.library_db,
|
||||
"list_library_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_list:
|
||||
core,
|
||||
"library_db",
|
||||
return_value=mock_db,
|
||||
):
|
||||
await core.get_library_agents_for_generation(
|
||||
user_id="user-123",
|
||||
max_results=5,
|
||||
)
|
||||
|
||||
mock_list.assert_called_once_with(
|
||||
mock_db.list_library_agents.assert_called_once_with(
|
||||
user_id="user-123",
|
||||
search_term=None,
|
||||
page=1,
|
||||
@@ -144,24 +150,24 @@ class TestSearchMarketplaceAgentsForGeneration:
|
||||
mock_graph.input_schema = {"type": "object"}
|
||||
mock_graph.output_schema = {"type": "object"}
|
||||
|
||||
mock_store_db = MagicMock()
|
||||
mock_store_db.get_store_agents = AsyncMock(return_value=mock_response)
|
||||
|
||||
mock_graph_db = MagicMock()
|
||||
mock_graph_db.get_store_listed_graphs = AsyncMock(
|
||||
return_value={"graph-123": mock_graph}
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.store.db.get_store_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_search,
|
||||
patch(
|
||||
"backend.api.features.chat.tools.agent_generator.core.get_store_listed_graphs",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"graph-123": mock_graph},
|
||||
),
|
||||
patch.object(core, "store_db", return_value=mock_store_db),
|
||||
patch.object(core, "graph_db", return_value=mock_graph_db),
|
||||
):
|
||||
result = await core.search_marketplace_agents_for_generation(
|
||||
search_query="automation",
|
||||
max_results=10,
|
||||
)
|
||||
|
||||
mock_search.assert_called_once_with(
|
||||
mock_store_db.get_store_agents.assert_called_once_with(
|
||||
search_query="automation",
|
||||
page=1,
|
||||
page_size=10,
|
||||
@@ -707,7 +713,7 @@ class TestExtractUuidsFromText:
|
||||
|
||||
|
||||
class TestGetLibraryAgentById:
|
||||
"""Test get_library_agent_by_id function (and its alias get_library_agent_by_graph_id)."""
|
||||
"""Test get_library_agent_by_id function (alias: get_library_agent_by_graph_id)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_agent_when_found_by_graph_id(self):
|
||||
@@ -720,12 +726,10 @@ class TestGetLibraryAgentById:
|
||||
mock_agent.input_schema = {"properties": {}}
|
||||
mock_agent.output_schema = {"properties": {}}
|
||||
|
||||
with patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent_by_graph_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_agent,
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
|
||||
|
||||
with patch.object(core, "library_db", return_value=mock_db):
|
||||
result = await core.get_library_agent_by_id("user-123", "agent-123")
|
||||
|
||||
assert result is not None
|
||||
@@ -743,20 +747,11 @@ class TestGetLibraryAgentById:
|
||||
mock_agent.input_schema = {"properties": {}}
|
||||
mock_agent.output_schema = {"properties": {}}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent_by_graph_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None, # Not found by graph_id
|
||||
),
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_agent, # Found by library ID
|
||||
),
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=None)
|
||||
mock_db.get_library_agent = AsyncMock(return_value=mock_agent)
|
||||
|
||||
with patch.object(core, "library_db", return_value=mock_db):
|
||||
result = await core.get_library_agent_by_id("user-123", "library-id-123")
|
||||
|
||||
assert result is not None
|
||||
@@ -766,20 +761,13 @@ class TestGetLibraryAgentById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_not_found_by_either_method(self):
|
||||
"""Test that None is returned when agent not found by either method."""
|
||||
with (
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent_by_graph_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=core.NotFoundError("Not found"),
|
||||
),
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=None)
|
||||
mock_db.get_library_agent = AsyncMock(
|
||||
side_effect=core.NotFoundError("Not found")
|
||||
)
|
||||
|
||||
with patch.object(core, "library_db", return_value=mock_db):
|
||||
result = await core.get_library_agent_by_id("user-123", "nonexistent")
|
||||
|
||||
assert result is None
|
||||
@@ -787,27 +775,20 @@ class TestGetLibraryAgentById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_exception(self):
|
||||
"""Test that None is returned when exception occurs in both lookups."""
|
||||
with (
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent_by_graph_id",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Database error"),
|
||||
),
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Database error"),
|
||||
),
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_library_agent_by_graph_id = AsyncMock(
|
||||
side_effect=Exception("Database error")
|
||||
)
|
||||
mock_db.get_library_agent = AsyncMock(side_effect=Exception("Database error"))
|
||||
|
||||
with patch.object(core, "library_db", return_value=mock_db):
|
||||
result = await core.get_library_agent_by_id("user-123", "agent-123")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alias_works(self):
|
||||
"""Test that get_library_agent_by_graph_id is an alias for get_library_agent_by_id."""
|
||||
"""Test that get_library_agent_by_graph_id is an alias."""
|
||||
assert core.get_library_agent_by_graph_id is core.get_library_agent_by_id
|
||||
|
||||
|
||||
@@ -828,20 +809,11 @@ class TestGetAllRelevantAgentsWithUuids:
|
||||
mock_response = MagicMock()
|
||||
mock_response.agents = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent_by_graph_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_agent,
|
||||
),
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"list_library_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
),
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
|
||||
mock_db.list_library_agents = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(core, "library_db", return_value=mock_db):
|
||||
result = await core.get_all_relevant_agents_for_generation(
|
||||
user_id="user-123",
|
||||
search_query="Use agent 46631191-e8a8-486f-ad90-84f89738321d",
|
||||
|
||||
@@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from backend.api.features.chat.tools.agent_generator import service
|
||||
from backend.copilot.tools.agent_generator import service
|
||||
|
||||
|
||||
class TestServiceConfiguration:
|
||||
@@ -25,6 +25,7 @@ class TestServiceConfiguration:
|
||||
"""Test that external service is not configured when host is empty."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.config.agentgenerator_host = ""
|
||||
mock_settings.config.agentgenerator_use_dummy = False
|
||||
|
||||
with patch.object(service, "_get_settings", return_value=mock_settings):
|
||||
assert service.is_external_service_configured() is False
|
||||
|
||||
@@ -158,6 +158,41 @@ services:
|
||||
max-size: "10m"
|
||||
max-file: "3"
|
||||
|
||||
copilot_executor:
|
||||
build:
|
||||
context: ../
|
||||
dockerfile: autogpt_platform/backend/Dockerfile
|
||||
target: server
|
||||
command: ["python", "-m", "backend.copilot.executor"]
|
||||
develop:
|
||||
watch:
|
||||
- path: ./
|
||||
target: autogpt_platform/backend/
|
||||
action: rebuild
|
||||
depends_on:
|
||||
redis:
|
||||
condition: service_healthy
|
||||
rabbitmq:
|
||||
condition: service_healthy
|
||||
db:
|
||||
condition: service_healthy
|
||||
migrate:
|
||||
condition: service_completed_successfully
|
||||
database_manager:
|
||||
condition: service_started
|
||||
<<: *backend-env-files
|
||||
environment:
|
||||
<<: *backend-env
|
||||
ports:
|
||||
- "8008:8008"
|
||||
networks:
|
||||
- app-network
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "10m"
|
||||
max-file: "3"
|
||||
|
||||
websocket_server:
|
||||
build:
|
||||
context: ../
|
||||
|
||||
@@ -53,6 +53,12 @@ services:
|
||||
file: ./docker-compose.platform.yml
|
||||
service: executor
|
||||
|
||||
copilot_executor:
|
||||
<<: *agpt-services
|
||||
extends:
|
||||
file: ./docker-compose.platform.yml
|
||||
service: copilot_executor
|
||||
|
||||
websocket_server:
|
||||
<<: *agpt-services
|
||||
extends:
|
||||
@@ -174,5 +180,6 @@ services:
|
||||
- deps
|
||||
- rest_server
|
||||
- executor
|
||||
- copilot_executor
|
||||
- websocket_server
|
||||
- database_manager
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
import { NextResponse } from "next/server";
|
||||
|
||||
/**
|
||||
* Safely encode a value as JSON for embedding in a script tag.
|
||||
* Escapes characters that could break out of the script context to prevent XSS.
|
||||
*/
|
||||
function safeJsonStringify(value: unknown): string {
|
||||
return JSON.stringify(value)
|
||||
.replace(/</g, "\\u003c")
|
||||
.replace(/>/g, "\\u003e")
|
||||
.replace(/&/g, "\\u0026");
|
||||
}
|
||||
|
||||
// MCP-specific OAuth callback route.
|
||||
//
|
||||
// Unlike the generic oauth_callback which relies on window.opener.postMessage,
|
||||
// this route uses BroadcastChannel as the PRIMARY communication method.
|
||||
// This is critical because cross-origin OAuth flows (e.g. Sentry → localhost)
|
||||
// often lose window.opener due to COOP (Cross-Origin-Opener-Policy) headers.
|
||||
//
|
||||
// BroadcastChannel works across all same-origin tabs/popups regardless of opener.
|
||||
export async function GET(request: Request) {
|
||||
const { searchParams } = new URL(request.url);
|
||||
const code = searchParams.get("code");
|
||||
const state = searchParams.get("state");
|
||||
|
||||
const success = Boolean(code && state);
|
||||
const message = success
|
||||
? { success: true, code, state }
|
||||
: {
|
||||
success: false,
|
||||
message: `Missing parameters: ${searchParams.toString()}`,
|
||||
};
|
||||
|
||||
return new NextResponse(
|
||||
`<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>MCP Sign-in</title></head>
|
||||
<body style="font-family: system-ui, -apple-system, sans-serif; display: flex; align-items: center; justify-content: center; min-height: 100vh; margin: 0; background: #f9fafb;">
|
||||
<div style="text-align: center; max-width: 400px; padding: 2rem;">
|
||||
<div id="spinner" style="margin: 0 auto 1rem; width: 32px; height: 32px; border: 3px solid #e5e7eb; border-top-color: #3b82f6; border-radius: 50%; animation: spin 0.8s linear infinite;"></div>
|
||||
<p id="status" style="color: #374151; font-size: 16px;">Completing sign-in...</p>
|
||||
</div>
|
||||
<style>@keyframes spin { to { transform: rotate(360deg); } }</style>
|
||||
<script>
|
||||
(function() {
|
||||
var msg = ${safeJsonStringify(message)};
|
||||
var sent = false;
|
||||
|
||||
// Method 1: BroadcastChannel (reliable across tabs/popups, no opener needed)
|
||||
try {
|
||||
var bc = new BroadcastChannel("mcp_oauth");
|
||||
bc.postMessage({ type: "mcp_oauth_result", success: msg.success, code: msg.code, state: msg.state, message: msg.message });
|
||||
bc.close();
|
||||
sent = true;
|
||||
} catch(e) { /* BroadcastChannel not supported */ }
|
||||
|
||||
// Method 2: window.opener.postMessage (fallback for same-origin popups)
|
||||
try {
|
||||
if (window.opener && !window.opener.closed) {
|
||||
window.opener.postMessage(
|
||||
{ message_type: "mcp_oauth_result", success: msg.success, code: msg.code, state: msg.state, message: msg.message },
|
||||
window.location.origin
|
||||
);
|
||||
sent = true;
|
||||
}
|
||||
} catch(e) { /* opener not available (COOP) */ }
|
||||
|
||||
// Method 3: localStorage (most reliable cross-tab fallback)
|
||||
try {
|
||||
localStorage.setItem("mcp_oauth_result", JSON.stringify(msg));
|
||||
sent = true;
|
||||
} catch(e) { /* localStorage not available */ }
|
||||
|
||||
var statusEl = document.getElementById("status");
|
||||
var spinnerEl = document.getElementById("spinner");
|
||||
spinnerEl.style.display = "none";
|
||||
|
||||
if (msg.success && sent) {
|
||||
statusEl.textContent = "Sign-in complete! This window will close.";
|
||||
statusEl.style.color = "#059669";
|
||||
setTimeout(function() { window.close(); }, 1500);
|
||||
} else if (msg.success) {
|
||||
statusEl.textContent = "Sign-in successful! You can close this tab and return to the builder.";
|
||||
statusEl.style.color = "#059669";
|
||||
} else {
|
||||
statusEl.textContent = "Sign-in failed: " + (msg.message || "Unknown error");
|
||||
statusEl.style.color = "#dc2626";
|
||||
}
|
||||
})();
|
||||
</script>
|
||||
</body>
|
||||
</html>`,
|
||||
{ headers: { "Content-Type": "text/html" } },
|
||||
);
|
||||
}
|
||||
@@ -47,10 +47,7 @@ export type CustomNode = XYNode<CustomNodeData, "custom">;
|
||||
|
||||
export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||
({ data, id: nodeId, selected }) => {
|
||||
const { inputSchema, outputSchema, isMCPWithTool } = useCustomNode({
|
||||
data,
|
||||
nodeId,
|
||||
});
|
||||
const { inputSchema, outputSchema } = useCustomNode({ data, nodeId });
|
||||
|
||||
const isAgent = data.uiType === BlockUIType.AGENT;
|
||||
|
||||
@@ -101,7 +98,6 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||
jsonSchema={preprocessInputSchema(inputSchema)}
|
||||
nodeId={nodeId}
|
||||
uiType={data.uiType}
|
||||
isMCPWithTool={isMCPWithTool}
|
||||
className={cn(
|
||||
"bg-white px-4",
|
||||
isWebhook && "pointer-events-none opacity-50",
|
||||
|
||||
@@ -20,8 +20,10 @@ type Props = {
|
||||
|
||||
export const NodeHeader = ({ data, nodeId }: Props) => {
|
||||
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
||||
|
||||
const title = (data.metadata?.customized_name as string) || data.title;
|
||||
const title =
|
||||
(data.metadata?.customized_name as string) ||
|
||||
data.hardcodedValues?.agent_name ||
|
||||
data.title;
|
||||
|
||||
const [isEditingTitle, setIsEditingTitle] = useState(false);
|
||||
const [editedTitle, setEditedTitle] = useState(title);
|
||||
|
||||
@@ -3,34 +3,6 @@ import { CustomNodeData } from "./CustomNode";
|
||||
import { BlockUIType } from "../../../types";
|
||||
import { useMemo } from "react";
|
||||
import { mergeSchemaForResolution } from "./helpers";
|
||||
/**
|
||||
* Build a dynamic input schema for MCP blocks.
|
||||
*
|
||||
* When a tool has been selected (tool_input_schema is populated), the block
|
||||
* renders the selected tool's input parameters *plus* the credentials field
|
||||
* so users can select/change the OAuth credential used for execution.
|
||||
*
|
||||
* Static fields like server_url, selected_tool, available_tools, and
|
||||
* tool_arguments are hidden because they're pre-configured from the dialog.
|
||||
*/
|
||||
function buildMCPInputSchema(
|
||||
toolInputSchema: Record<string, any>,
|
||||
blockInputSchema: Record<string, any>,
|
||||
): Record<string, any> {
|
||||
// Extract the credentials field from the block's original input schema
|
||||
const credentialsSchema =
|
||||
blockInputSchema?.properties?.credentials ?? undefined;
|
||||
|
||||
return {
|
||||
type: "object",
|
||||
properties: {
|
||||
// Credentials field first so the dropdown appears at the top
|
||||
...(credentialsSchema ? { credentials: credentialsSchema } : {}),
|
||||
...(toolInputSchema.properties ?? {}),
|
||||
},
|
||||
required: [...(toolInputSchema.required ?? [])],
|
||||
};
|
||||
}
|
||||
|
||||
export const useCustomNode = ({
|
||||
data,
|
||||
@@ -47,18 +19,10 @@ export const useCustomNode = ({
|
||||
);
|
||||
|
||||
const isAgent = data.uiType === BlockUIType.AGENT;
|
||||
const isMCPWithTool =
|
||||
data.uiType === BlockUIType.MCP_TOOL &&
|
||||
!!data.hardcodedValues?.tool_input_schema?.properties;
|
||||
|
||||
const currentInputSchema = isAgent
|
||||
? (data.hardcodedValues.input_schema ?? {})
|
||||
: isMCPWithTool
|
||||
? buildMCPInputSchema(
|
||||
data.hardcodedValues.tool_input_schema,
|
||||
data.inputSchema,
|
||||
)
|
||||
: data.inputSchema;
|
||||
: data.inputSchema;
|
||||
const currentOutputSchema = isAgent
|
||||
? (data.hardcodedValues.output_schema ?? {})
|
||||
: data.outputSchema;
|
||||
@@ -90,6 +54,5 @@ export const useCustomNode = ({
|
||||
return {
|
||||
inputSchema,
|
||||
outputSchema,
|
||||
isMCPWithTool,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -9,72 +9,39 @@ interface FormCreatorProps {
|
||||
jsonSchema: RJSFSchema;
|
||||
nodeId: string;
|
||||
uiType: BlockUIType;
|
||||
/** When true the block is an MCP Tool with a selected tool. */
|
||||
isMCPWithTool?: boolean;
|
||||
showHandles?: boolean;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export const FormCreator: React.FC<FormCreatorProps> = React.memo(
|
||||
({
|
||||
jsonSchema,
|
||||
nodeId,
|
||||
uiType,
|
||||
isMCPWithTool = false,
|
||||
showHandles = true,
|
||||
className,
|
||||
}) => {
|
||||
({ jsonSchema, nodeId, uiType, showHandles = true, className }) => {
|
||||
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
||||
|
||||
const getHardCodedValues = useNodeStore(
|
||||
(state) => state.getHardCodedValues,
|
||||
);
|
||||
|
||||
const isAgent = uiType === BlockUIType.AGENT;
|
||||
|
||||
const handleChange = ({ formData }: any) => {
|
||||
if ("credentials" in formData && !formData.credentials?.id) {
|
||||
delete formData.credentials;
|
||||
}
|
||||
|
||||
let updatedValues;
|
||||
if (isAgent) {
|
||||
updatedValues = {
|
||||
...getHardCodedValues(nodeId),
|
||||
inputs: formData,
|
||||
};
|
||||
} else if (isMCPWithTool) {
|
||||
// Separate credentials from tool arguments — credentials are stored
|
||||
// at the top level of hardcodedValues, not inside tool_arguments.
|
||||
const { credentials, ...toolArgs } = formData;
|
||||
updatedValues = {
|
||||
...getHardCodedValues(nodeId),
|
||||
tool_arguments: toolArgs,
|
||||
...(credentials?.id ? { credentials } : {}),
|
||||
};
|
||||
} else {
|
||||
updatedValues = formData;
|
||||
}
|
||||
const updatedValues =
|
||||
uiType === BlockUIType.AGENT
|
||||
? {
|
||||
...getHardCodedValues(nodeId),
|
||||
inputs: formData,
|
||||
}
|
||||
: formData;
|
||||
|
||||
updateNodeData(nodeId, { hardcodedValues: updatedValues });
|
||||
};
|
||||
|
||||
const hardcodedValues = getHardCodedValues(nodeId);
|
||||
|
||||
let initialValues;
|
||||
if (isAgent) {
|
||||
initialValues = hardcodedValues.inputs ?? {};
|
||||
} else if (isMCPWithTool) {
|
||||
// Merge tool arguments with credentials for the form
|
||||
initialValues = {
|
||||
...(hardcodedValues.tool_arguments ?? {}),
|
||||
...(hardcodedValues.credentials?.id
|
||||
? { credentials: hardcodedValues.credentials }
|
||||
: {}),
|
||||
};
|
||||
} else {
|
||||
initialValues = hardcodedValues;
|
||||
}
|
||||
const initialValues =
|
||||
uiType === BlockUIType.AGENT
|
||||
? (hardcodedValues.inputs ?? {})
|
||||
: hardcodedValues;
|
||||
|
||||
return (
|
||||
<div
|
||||
|
||||
@@ -1,558 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import React, {
|
||||
useState,
|
||||
useCallback,
|
||||
useRef,
|
||||
useEffect,
|
||||
useContext,
|
||||
} from "react";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/__legacy__/ui/dialog";
|
||||
import { Button } from "@/components/__legacy__/ui/button";
|
||||
import { Input } from "@/components/__legacy__/ui/input";
|
||||
import { Label } from "@/components/__legacy__/ui/label";
|
||||
import { LoadingSpinner } from "@/components/__legacy__/ui/loading";
|
||||
import { Badge } from "@/components/__legacy__/ui/badge";
|
||||
import { ScrollArea } from "@/components/__legacy__/ui/scroll-area";
|
||||
import type { CredentialsMetaInput } from "@/lib/autogpt-server-api";
|
||||
import type { MCPToolResponse } from "@/app/api/__generated__/models/mCPToolResponse";
|
||||
import {
|
||||
postV2DiscoverAvailableToolsOnAnMcpServer,
|
||||
postV2InitiateOauthLoginForAnMcpServer,
|
||||
postV2ExchangeOauthCodeForMcpTokens,
|
||||
} from "@/app/api/__generated__/endpoints/mcp/mcp";
|
||||
import { CaretDown } from "@phosphor-icons/react";
|
||||
import { openOAuthPopup } from "@/lib/oauth-popup";
|
||||
import { CredentialsProvidersContext } from "@/providers/agent-credentials/credentials-provider";
|
||||
|
||||
export type MCPToolDialogResult = {
|
||||
serverUrl: string;
|
||||
serverName: string | null;
|
||||
selectedTool: string;
|
||||
toolInputSchema: Record<string, any>;
|
||||
availableTools: Record<string, any>;
|
||||
/** Credentials meta from OAuth flow, null for public servers. */
|
||||
credentials: CredentialsMetaInput | null;
|
||||
};
|
||||
|
||||
interface MCPToolDialogProps {
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
onConfirm: (result: MCPToolDialogResult) => void;
|
||||
}
|
||||
|
||||
type DialogStep = "url" | "tool";
|
||||
|
||||
export function MCPToolDialog({
|
||||
open,
|
||||
onClose,
|
||||
onConfirm,
|
||||
}: MCPToolDialogProps) {
|
||||
const allProviders = useContext(CredentialsProvidersContext);
|
||||
|
||||
const [step, setStep] = useState<DialogStep>("url");
|
||||
const [serverUrl, setServerUrl] = useState("");
|
||||
const [tools, setTools] = useState<MCPToolResponse[]>([]);
|
||||
const [serverName, setServerName] = useState<string | null>(null);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [authRequired, setAuthRequired] = useState(false);
|
||||
const [oauthLoading, setOauthLoading] = useState(false);
|
||||
const [showManualToken, setShowManualToken] = useState(false);
|
||||
const [manualToken, setManualToken] = useState("");
|
||||
const [selectedTool, setSelectedTool] = useState<MCPToolResponse | null>(
|
||||
null,
|
||||
);
|
||||
const [credentials, setCredentials] = useState<CredentialsMetaInput | null>(
|
||||
null,
|
||||
);
|
||||
|
||||
const startOAuthRef = useRef(false);
|
||||
const oauthAbortRef = useRef<((reason?: string) => void) | null>(null);
|
||||
|
||||
// Clean up on unmount
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
oauthAbortRef.current?.();
|
||||
};
|
||||
}, []);
|
||||
|
||||
const reset = useCallback(() => {
|
||||
oauthAbortRef.current?.();
|
||||
oauthAbortRef.current = null;
|
||||
setStep("url");
|
||||
setServerUrl("");
|
||||
setManualToken("");
|
||||
setTools([]);
|
||||
setServerName(null);
|
||||
setLoading(false);
|
||||
setError(null);
|
||||
setAuthRequired(false);
|
||||
setOauthLoading(false);
|
||||
setShowManualToken(false);
|
||||
setSelectedTool(null);
|
||||
setCredentials(null);
|
||||
}, []);
|
||||
|
||||
const handleClose = useCallback(() => {
|
||||
reset();
|
||||
onClose();
|
||||
}, [reset, onClose]);
|
||||
|
||||
const discoverTools = useCallback(async (url: string, authToken?: string) => {
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
try {
|
||||
const response = await postV2DiscoverAvailableToolsOnAnMcpServer({
|
||||
server_url: url,
|
||||
auth_token: authToken || null,
|
||||
});
|
||||
if (response.status !== 200) throw response.data;
|
||||
setTools(response.data.tools);
|
||||
setServerName(response.data.server_name ?? null);
|
||||
setAuthRequired(false);
|
||||
setShowManualToken(false);
|
||||
setStep("tool");
|
||||
} catch (e: any) {
|
||||
if (e?.status === 401 || e?.status === 403) {
|
||||
setAuthRequired(true);
|
||||
setError(null);
|
||||
// Automatically start OAuth sign-in instead of requiring a second click
|
||||
setLoading(false);
|
||||
startOAuthRef.current = true;
|
||||
return;
|
||||
} else {
|
||||
const message =
|
||||
e?.message || e?.detail || "Failed to connect to MCP server";
|
||||
setError(
|
||||
typeof message === "string" ? message : JSON.stringify(message),
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, []);
|
||||
|
||||
const handleDiscoverTools = useCallback(() => {
|
||||
if (!serverUrl.trim()) return;
|
||||
discoverTools(serverUrl.trim(), manualToken.trim() || undefined);
|
||||
}, [serverUrl, manualToken, discoverTools]);
|
||||
|
||||
const handleOAuthSignIn = useCallback(async () => {
|
||||
if (!serverUrl.trim()) return;
|
||||
setError(null);
|
||||
|
||||
// Abort any previous OAuth flow
|
||||
oauthAbortRef.current?.();
|
||||
|
||||
setOauthLoading(true);
|
||||
|
||||
try {
|
||||
const loginResponse = await postV2InitiateOauthLoginForAnMcpServer({
|
||||
server_url: serverUrl.trim(),
|
||||
});
|
||||
if (loginResponse.status !== 200) throw loginResponse.data;
|
||||
const { login_url, state_token } = loginResponse.data;
|
||||
|
||||
const { promise, cleanup } = openOAuthPopup(login_url, {
|
||||
stateToken: state_token,
|
||||
useCrossOriginListeners: true,
|
||||
});
|
||||
oauthAbortRef.current = cleanup.abort;
|
||||
|
||||
const result = await promise;
|
||||
|
||||
// Exchange code for tokens via the credentials provider (updates cache)
|
||||
setLoading(true);
|
||||
setOauthLoading(false);
|
||||
|
||||
const mcpProvider = allProviders?.["mcp"];
|
||||
let callbackResult;
|
||||
if (mcpProvider) {
|
||||
callbackResult = await mcpProvider.mcpOAuthCallback(
|
||||
result.code,
|
||||
state_token,
|
||||
);
|
||||
} else {
|
||||
const cbResponse = await postV2ExchangeOauthCodeForMcpTokens({
|
||||
code: result.code,
|
||||
state_token,
|
||||
});
|
||||
if (cbResponse.status !== 200) throw cbResponse.data;
|
||||
callbackResult = cbResponse.data;
|
||||
}
|
||||
|
||||
setCredentials({
|
||||
id: callbackResult.id,
|
||||
provider: callbackResult.provider,
|
||||
type: callbackResult.type,
|
||||
title: callbackResult.title,
|
||||
});
|
||||
setAuthRequired(false);
|
||||
|
||||
// Discover tools now that we're authenticated
|
||||
const toolsResponse = await postV2DiscoverAvailableToolsOnAnMcpServer({
|
||||
server_url: serverUrl.trim(),
|
||||
});
|
||||
if (toolsResponse.status !== 200) throw toolsResponse.data;
|
||||
setTools(toolsResponse.data.tools);
|
||||
setServerName(toolsResponse.data.server_name ?? null);
|
||||
setStep("tool");
|
||||
} catch (e: any) {
|
||||
// If server doesn't support OAuth → show manual token entry
|
||||
if (e?.status === 400) {
|
||||
setShowManualToken(true);
|
||||
setError(
|
||||
"This server does not support OAuth sign-in. Please enter a token manually.",
|
||||
);
|
||||
} else if (e?.message === "OAuth flow timed out") {
|
||||
setError("OAuth sign-in timed out. Please try again.");
|
||||
} else {
|
||||
const status = e?.status;
|
||||
let message: string;
|
||||
if (status === 401 || status === 403) {
|
||||
message =
|
||||
"Authentication succeeded but the server still rejected the request. " +
|
||||
"The token audience may not match. Please try again.";
|
||||
} else {
|
||||
message = e?.message || e?.detail || "Failed to complete sign-in";
|
||||
}
|
||||
setError(
|
||||
typeof message === "string" ? message : JSON.stringify(message),
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
setOauthLoading(false);
|
||||
setLoading(false);
|
||||
oauthAbortRef.current = null;
|
||||
}
|
||||
}, [serverUrl, allProviders]);
|
||||
|
||||
// Auto-start OAuth sign-in when server returns 401/403
|
||||
useEffect(() => {
|
||||
if (authRequired && startOAuthRef.current) {
|
||||
startOAuthRef.current = false;
|
||||
handleOAuthSignIn();
|
||||
}
|
||||
}, [authRequired, handleOAuthSignIn]);
|
||||
|
||||
const handleConfirm = useCallback(() => {
|
||||
if (!selectedTool) return;
|
||||
|
||||
const availableTools: Record<string, any> = {};
|
||||
for (const t of tools) {
|
||||
availableTools[t.name] = {
|
||||
description: t.description,
|
||||
input_schema: t.input_schema,
|
||||
};
|
||||
}
|
||||
|
||||
onConfirm({
|
||||
serverUrl: serverUrl.trim(),
|
||||
serverName,
|
||||
selectedTool: selectedTool.name,
|
||||
toolInputSchema: selectedTool.input_schema,
|
||||
availableTools,
|
||||
credentials,
|
||||
});
|
||||
reset();
|
||||
}, [
|
||||
selectedTool,
|
||||
tools,
|
||||
serverUrl,
|
||||
serverName,
|
||||
credentials,
|
||||
onConfirm,
|
||||
reset,
|
||||
]);
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={(isOpen) => !isOpen && handleClose()}>
|
||||
<DialogContent className="max-w-lg">
|
||||
<DialogHeader>
|
||||
<DialogTitle>
|
||||
{step === "url"
|
||||
? "Connect to MCP Server"
|
||||
: `Select a Tool${serverName ? ` — ${serverName}` : ""}`}
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
{step === "url"
|
||||
? "Enter the URL of an MCP server to discover its available tools."
|
||||
: `Found ${tools.length} tool${tools.length !== 1 ? "s" : ""}. Select one to add to your agent.`}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
{step === "url" && (
|
||||
<div className="flex flex-col gap-4 py-2">
|
||||
<div className="flex flex-col gap-2">
|
||||
<Label htmlFor="mcp-server-url">Server URL</Label>
|
||||
<Input
|
||||
id="mcp-server-url"
|
||||
type="url"
|
||||
placeholder="https://mcp.example.com/mcp"
|
||||
value={serverUrl}
|
||||
onChange={(e) => setServerUrl(e.target.value)}
|
||||
onKeyDown={(e) => e.key === "Enter" && handleDiscoverTools()}
|
||||
autoFocus
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Auth required: show manual token option */}
|
||||
{authRequired && !showManualToken && (
|
||||
<button
|
||||
onClick={() => setShowManualToken(true)}
|
||||
className="text-xs text-gray-500 underline hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-300"
|
||||
>
|
||||
or enter a token manually
|
||||
</button>
|
||||
)}
|
||||
|
||||
{/* Manual token entry — only visible when expanded */}
|
||||
{showManualToken && (
|
||||
<div className="flex flex-col gap-2">
|
||||
<Label htmlFor="mcp-auth-token" className="text-sm">
|
||||
Bearer Token
|
||||
</Label>
|
||||
<Input
|
||||
id="mcp-auth-token"
|
||||
type="password"
|
||||
placeholder="Paste your auth token here"
|
||||
value={manualToken}
|
||||
onChange={(e) => setManualToken(e.target.value)}
|
||||
onKeyDown={(e) => e.key === "Enter" && handleDiscoverTools()}
|
||||
autoFocus
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{error && <p className="text-sm text-red-500">{error}</p>}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{step === "tool" && (
|
||||
<ScrollArea className="max-h-[50vh] py-2">
|
||||
<div className="flex flex-col gap-2 pr-3">
|
||||
{tools.map((tool) => (
|
||||
<MCPToolCard
|
||||
key={tool.name}
|
||||
tool={tool}
|
||||
selected={selectedTool?.name === tool.name}
|
||||
onSelect={() => setSelectedTool(tool)}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
)}
|
||||
|
||||
<DialogFooter>
|
||||
{step === "tool" && (
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={() => {
|
||||
setStep("url");
|
||||
setSelectedTool(null);
|
||||
}}
|
||||
>
|
||||
Back
|
||||
</Button>
|
||||
)}
|
||||
<Button variant="outline" onClick={handleClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
{step === "url" && (
|
||||
<Button
|
||||
onClick={
|
||||
authRequired && !showManualToken
|
||||
? handleOAuthSignIn
|
||||
: handleDiscoverTools
|
||||
}
|
||||
disabled={!serverUrl.trim() || loading || oauthLoading}
|
||||
>
|
||||
{loading || oauthLoading ? (
|
||||
<span className="flex items-center gap-2">
|
||||
<LoadingSpinner className="size-4" />
|
||||
{oauthLoading ? "Waiting for sign-in..." : "Connecting..."}
|
||||
</span>
|
||||
) : authRequired && !showManualToken ? (
|
||||
"Sign in & Connect"
|
||||
) : (
|
||||
"Discover Tools"
|
||||
)}
|
||||
</Button>
|
||||
)}
|
||||
{step === "tool" && (
|
||||
<Button onClick={handleConfirm} disabled={!selectedTool}>
|
||||
Add Block
|
||||
</Button>
|
||||
)}
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
|
||||
// --------------- Tool Card Component --------------- //
|
||||
|
||||
/** Truncate a description to a reasonable length for the collapsed view. */
|
||||
function truncateDescription(text: string, maxLen = 120): string {
|
||||
if (text.length <= maxLen) return text;
|
||||
return text.slice(0, maxLen).trimEnd() + "…";
|
||||
}
|
||||
|
||||
/** Pretty-print a JSON Schema type for a parameter. */
|
||||
function schemaTypeLabel(schema: Record<string, any>): string {
|
||||
if (schema.type) return schema.type;
|
||||
if (schema.anyOf)
|
||||
return schema.anyOf.map((s: any) => s.type ?? "any").join(" | ");
|
||||
if (schema.oneOf)
|
||||
return schema.oneOf.map((s: any) => s.type ?? "any").join(" | ");
|
||||
return "any";
|
||||
}
|
||||
|
||||
function MCPToolCard({
|
||||
tool,
|
||||
selected,
|
||||
onSelect,
|
||||
}: {
|
||||
tool: MCPToolResponse;
|
||||
selected: boolean;
|
||||
onSelect: () => void;
|
||||
}) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const schema = tool.input_schema as Record<string, any>;
|
||||
const properties = schema?.properties ?? {};
|
||||
const required = new Set<string>(schema?.required ?? []);
|
||||
const paramNames = Object.keys(properties);
|
||||
|
||||
// Strip XML-like tags from description for cleaner display.
|
||||
// Loop to handle nested tags like <scr<script>ipt> (CodeQL fix).
|
||||
let cleanDescription = tool.description ?? "";
|
||||
let prev = "";
|
||||
while (prev !== cleanDescription) {
|
||||
prev = cleanDescription;
|
||||
cleanDescription = cleanDescription.replace(/<[^>]*>/g, "");
|
||||
}
|
||||
cleanDescription = cleanDescription.trim();
|
||||
|
||||
return (
|
||||
<button
|
||||
onClick={onSelect}
|
||||
className={`group flex flex-col rounded-lg border text-left transition-colors ${
|
||||
selected
|
||||
? "border-blue-500 bg-blue-50 dark:border-blue-400 dark:bg-blue-950"
|
||||
: "border-gray-200 hover:border-gray-300 hover:bg-gray-50 dark:border-slate-700 dark:hover:border-slate-600 dark:hover:bg-slate-800"
|
||||
}`}
|
||||
>
|
||||
{/* Header */}
|
||||
<div className="flex items-center gap-2 px-3 pb-1 pt-3">
|
||||
<span className="flex-1 text-sm font-semibold dark:text-white">
|
||||
{tool.name}
|
||||
</span>
|
||||
{paramNames.length > 0 && (
|
||||
<Badge variant="secondary" className="text-[10px]">
|
||||
{paramNames.length} param{paramNames.length !== 1 ? "s" : ""}
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Description (collapsed: truncated) */}
|
||||
{cleanDescription && (
|
||||
<p className="px-3 pb-1 text-xs leading-relaxed text-gray-500 dark:text-gray-400">
|
||||
{expanded ? cleanDescription : truncateDescription(cleanDescription)}
|
||||
</p>
|
||||
)}
|
||||
|
||||
{/* Parameter badges (collapsed view) */}
|
||||
{!expanded && paramNames.length > 0 && (
|
||||
<div className="flex flex-wrap gap-1 px-3 pb-2">
|
||||
{paramNames.slice(0, 6).map((name) => (
|
||||
<Badge
|
||||
key={name}
|
||||
variant="outline"
|
||||
className="text-[10px] font-normal"
|
||||
>
|
||||
{name}
|
||||
{required.has(name) && (
|
||||
<span className="ml-0.5 text-red-400">*</span>
|
||||
)}
|
||||
</Badge>
|
||||
))}
|
||||
{paramNames.length > 6 && (
|
||||
<Badge variant="outline" className="text-[10px] font-normal">
|
||||
+{paramNames.length - 6} more
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Expanded: full parameter details */}
|
||||
{expanded && paramNames.length > 0 && (
|
||||
<div className="mx-3 mb-2 rounded border border-gray-100 bg-gray-50/50 dark:border-slate-700 dark:bg-slate-800/50">
|
||||
<table className="w-full text-xs">
|
||||
<thead>
|
||||
<tr className="border-b border-gray-100 dark:border-slate-700">
|
||||
<th className="px-2 py-1 text-left font-medium text-gray-500 dark:text-gray-400">
|
||||
Parameter
|
||||
</th>
|
||||
<th className="px-2 py-1 text-left font-medium text-gray-500 dark:text-gray-400">
|
||||
Type
|
||||
</th>
|
||||
<th className="px-2 py-1 text-left font-medium text-gray-500 dark:text-gray-400">
|
||||
Description
|
||||
</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{paramNames.map((name) => {
|
||||
const prop = properties[name] ?? {};
|
||||
return (
|
||||
<tr
|
||||
key={name}
|
||||
className="border-b border-gray-50 last:border-0 dark:border-slate-700/50"
|
||||
>
|
||||
<td className="px-2 py-1 font-mono text-[11px] text-gray-700 dark:text-gray-300">
|
||||
{name}
|
||||
{required.has(name) && (
|
||||
<span className="ml-0.5 text-red-400">*</span>
|
||||
)}
|
||||
</td>
|
||||
<td className="px-2 py-1 text-gray-500 dark:text-gray-400">
|
||||
{schemaTypeLabel(prop)}
|
||||
</td>
|
||||
<td className="max-w-[200px] truncate px-2 py-1 text-gray-500 dark:text-gray-400">
|
||||
{prop.description ?? "—"}
|
||||
</td>
|
||||
</tr>
|
||||
);
|
||||
})}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Toggle details */}
|
||||
{(paramNames.length > 0 || cleanDescription.length > 120) && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setExpanded((prev) => !prev);
|
||||
}}
|
||||
className="flex w-full items-center justify-center gap-1 border-t border-gray-100 py-1.5 text-[10px] text-gray-400 hover:text-gray-600 dark:border-slate-700 dark:text-gray-500 dark:hover:text-gray-300"
|
||||
>
|
||||
{expanded ? "Hide details" : "Show details"}
|
||||
<CaretDown
|
||||
className={`h-3 w-3 transition-transform ${expanded ? "rotate-180" : ""}`}
|
||||
/>
|
||||
</button>
|
||||
)}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Button } from "@/components/__legacy__/ui/button";
|
||||
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
|
||||
import { beautifyString, cn } from "@/lib/utils";
|
||||
import React, { ButtonHTMLAttributes, useCallback, useState } from "react";
|
||||
import React, { ButtonHTMLAttributes } from "react";
|
||||
import { highlightText } from "./helpers";
|
||||
import { PlusIcon } from "@phosphor-icons/react";
|
||||
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
|
||||
@@ -9,12 +9,6 @@ import { useControlPanelStore } from "../../../stores/controlPanelStore";
|
||||
import { blockDragPreviewStyle } from "./style";
|
||||
import { useReactFlow } from "@xyflow/react";
|
||||
import { useNodeStore } from "../../../stores/nodeStore";
|
||||
import { BlockUIType, SpecialBlockID } from "@/lib/autogpt-server-api";
|
||||
import {
|
||||
MCPToolDialog,
|
||||
type MCPToolDialogResult,
|
||||
} from "@/app/(platform)/build/components/MCPToolDialog";
|
||||
|
||||
interface Props extends ButtonHTMLAttributes<HTMLButtonElement> {
|
||||
title?: string;
|
||||
description?: string;
|
||||
@@ -39,86 +33,22 @@ export const Block: BlockComponent = ({
|
||||
);
|
||||
const { setViewport } = useReactFlow();
|
||||
const { addBlock } = useNodeStore();
|
||||
const [mcpDialogOpen, setMcpDialogOpen] = useState(false);
|
||||
|
||||
const isMCPBlock = blockData.uiType === BlockUIType.MCP_TOOL;
|
||||
|
||||
const addBlockAndCenter = useCallback(
|
||||
(block: BlockInfo, hardcodedValues?: Record<string, any>) => {
|
||||
const customNode = addBlock(block, hardcodedValues);
|
||||
setTimeout(() => {
|
||||
setViewport(
|
||||
{
|
||||
x: -customNode.position.x * 0.8 + window.innerWidth / 2,
|
||||
y: -customNode.position.y * 0.8 + (window.innerHeight - 400) / 2,
|
||||
zoom: 0.8,
|
||||
},
|
||||
{ duration: 500 },
|
||||
);
|
||||
}, 50);
|
||||
return customNode;
|
||||
},
|
||||
[addBlock, setViewport],
|
||||
);
|
||||
|
||||
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
||||
|
||||
const handleMCPToolConfirm = useCallback(
|
||||
(result: MCPToolDialogResult) => {
|
||||
// Derive a display label: prefer server name, fall back to URL hostname.
|
||||
let serverLabel = result.serverName;
|
||||
if (!serverLabel) {
|
||||
try {
|
||||
serverLabel = new URL(result.serverUrl).hostname;
|
||||
} catch {
|
||||
serverLabel = "MCP";
|
||||
}
|
||||
}
|
||||
|
||||
const customNode = addBlockAndCenter(blockData, {
|
||||
server_url: result.serverUrl,
|
||||
server_name: serverLabel,
|
||||
selected_tool: result.selectedTool,
|
||||
tool_input_schema: result.toolInputSchema,
|
||||
available_tools: result.availableTools,
|
||||
credentials: result.credentials ?? undefined,
|
||||
});
|
||||
if (customNode) {
|
||||
const title = result.selectedTool
|
||||
? `${serverLabel}: ${beautifyString(result.selectedTool)}`
|
||||
: undefined;
|
||||
updateNodeData(customNode.id, {
|
||||
metadata: {
|
||||
...customNode.data.metadata,
|
||||
credentials_optional: true,
|
||||
...(title && { customized_name: title }),
|
||||
},
|
||||
});
|
||||
}
|
||||
setMcpDialogOpen(false);
|
||||
},
|
||||
[addBlockAndCenter, blockData, updateNodeData],
|
||||
);
|
||||
|
||||
const handleClick = () => {
|
||||
if (isMCPBlock) {
|
||||
setMcpDialogOpen(true);
|
||||
return;
|
||||
}
|
||||
const customNode = addBlockAndCenter(blockData);
|
||||
// Set customized_name for agent blocks so the agent's name persists
|
||||
if (customNode && blockData.id === SpecialBlockID.AGENT) {
|
||||
updateNodeData(customNode.id, {
|
||||
metadata: {
|
||||
...customNode.data.metadata,
|
||||
customized_name: blockData.name,
|
||||
const customNode = addBlock(blockData);
|
||||
setTimeout(() => {
|
||||
setViewport(
|
||||
{
|
||||
x: -customNode.position.x * 0.8 + window.innerWidth / 2,
|
||||
y: -customNode.position.y * 0.8 + (window.innerHeight - 400) / 2,
|
||||
zoom: 0.8,
|
||||
},
|
||||
});
|
||||
}
|
||||
{ duration: 500 },
|
||||
);
|
||||
}, 50);
|
||||
};
|
||||
|
||||
const handleDragStart = (e: React.DragEvent<HTMLButtonElement>) => {
|
||||
if (isMCPBlock) return;
|
||||
e.dataTransfer.effectAllowed = "copy";
|
||||
e.dataTransfer.setData("application/reactflow", JSON.stringify(blockData));
|
||||
|
||||
@@ -141,56 +71,46 @@ export const Block: BlockComponent = ({
|
||||
: undefined;
|
||||
|
||||
return (
|
||||
<>
|
||||
<Button
|
||||
draggable={!isMCPBlock}
|
||||
data-id={blockDataId}
|
||||
className={cn(
|
||||
"group flex h-16 w-full min-w-[7.5rem] items-center justify-start space-x-3 whitespace-normal rounded-[0.75rem] bg-zinc-50 px-[0.875rem] py-[0.625rem] text-start shadow-none",
|
||||
"hover:cursor-default hover:bg-zinc-100 focus:ring-0 active:bg-zinc-100 active:ring-1 active:ring-zinc-300 disabled:cursor-not-allowed",
|
||||
isMCPBlock && "hover:cursor-pointer",
|
||||
className,
|
||||
)}
|
||||
onDragStart={handleDragStart}
|
||||
onClick={handleClick}
|
||||
{...rest}
|
||||
>
|
||||
<div className="flex flex-1 flex-col items-start gap-0.5">
|
||||
{title && (
|
||||
<span
|
||||
className={cn(
|
||||
"line-clamp-1 font-sans text-sm font-medium leading-[1.375rem] text-zinc-800 group-disabled:text-zinc-400",
|
||||
)}
|
||||
>
|
||||
{highlightText(beautifyString(title), highlightedText)}
|
||||
</span>
|
||||
)}
|
||||
{description && (
|
||||
<span
|
||||
className={cn(
|
||||
"line-clamp-1 font-sans text-xs font-normal leading-5 text-zinc-500 group-disabled:text-zinc-400",
|
||||
)}
|
||||
>
|
||||
{highlightText(description, highlightedText)}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<div
|
||||
className={cn(
|
||||
"flex h-7 w-7 items-center justify-center rounded-[0.5rem] bg-zinc-700 group-disabled:bg-zinc-400",
|
||||
)}
|
||||
>
|
||||
<PlusIcon className="h-5 w-5 text-zinc-50" />
|
||||
</div>
|
||||
</Button>
|
||||
{isMCPBlock && (
|
||||
<MCPToolDialog
|
||||
open={mcpDialogOpen}
|
||||
onClose={() => setMcpDialogOpen(false)}
|
||||
onConfirm={handleMCPToolConfirm}
|
||||
/>
|
||||
<Button
|
||||
draggable={true}
|
||||
data-id={blockDataId}
|
||||
className={cn(
|
||||
"group flex h-16 w-full min-w-[7.5rem] items-center justify-start space-x-3 whitespace-normal rounded-[0.75rem] bg-zinc-50 px-[0.875rem] py-[0.625rem] text-start shadow-none",
|
||||
"hover:cursor-default hover:bg-zinc-100 focus:ring-0 active:bg-zinc-100 active:ring-1 active:ring-zinc-300 disabled:cursor-not-allowed",
|
||||
className,
|
||||
)}
|
||||
</>
|
||||
onDragStart={handleDragStart}
|
||||
onClick={handleClick}
|
||||
{...rest}
|
||||
>
|
||||
<div className="flex flex-1 flex-col items-start gap-0.5">
|
||||
{title && (
|
||||
<span
|
||||
className={cn(
|
||||
"line-clamp-1 font-sans text-sm font-medium leading-[1.375rem] text-zinc-800 group-disabled:text-zinc-400",
|
||||
)}
|
||||
>
|
||||
{highlightText(beautifyString(title), highlightedText)}
|
||||
</span>
|
||||
)}
|
||||
{description && (
|
||||
<span
|
||||
className={cn(
|
||||
"line-clamp-1 font-sans text-xs font-normal leading-5 text-zinc-500 group-disabled:text-zinc-400",
|
||||
)}
|
||||
>
|
||||
{highlightText(description, highlightedText)}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<div
|
||||
className={cn(
|
||||
"flex h-7 w-7 items-center justify-center rounded-[0.5rem] bg-zinc-700 group-disabled:bg-zinc-400",
|
||||
)}
|
||||
>
|
||||
<PlusIcon className="h-5 w-5 text-zinc-50" />
|
||||
</div>
|
||||
</Button>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user