mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-13 08:14:58 -05:00
Compare commits
10 Commits
feat/mcp-b
...
fix/claude
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0b2fb655bc | ||
|
|
99f8bf5f0c | ||
|
|
3f76f1318b | ||
|
|
b011289dd2 | ||
|
|
49c2f578b4 | ||
|
|
7150b7768d | ||
|
|
8c95b03636 | ||
|
|
4a8368887f | ||
|
|
d46e5e6b6a | ||
|
|
4e632bbd60 |
@@ -10,7 +10,7 @@ from typing_extensions import TypedDict
|
|||||||
|
|
||||||
import backend.api.features.store.cache as store_cache
|
import backend.api.features.store.cache as store_cache
|
||||||
import backend.api.features.store.model as store_model
|
import backend.api.features.store.model as store_model
|
||||||
import backend.blocks
|
import backend.data.block
|
||||||
from backend.api.external.middleware import require_permission
|
from backend.api.external.middleware import require_permission
|
||||||
from backend.data import execution as execution_db
|
from backend.data import execution as execution_db
|
||||||
from backend.data import graph as graph_db
|
from backend.data import graph as graph_db
|
||||||
@@ -67,7 +67,7 @@ async def get_user_info(
|
|||||||
dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
|
dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
|
||||||
)
|
)
|
||||||
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||||
blocks = [block() for block in backend.blocks.get_blocks().values()]
|
blocks = [block() for block in backend.data.block.get_blocks().values()]
|
||||||
return [b.to_dict() for b in blocks if not b.disabled]
|
return [b.to_dict() for b in blocks if not b.disabled]
|
||||||
|
|
||||||
|
|
||||||
@@ -83,7 +83,7 @@ async def execute_graph_block(
|
|||||||
require_permission(APIKeyPermission.EXECUTE_BLOCK)
|
require_permission(APIKeyPermission.EXECUTE_BLOCK)
|
||||||
),
|
),
|
||||||
) -> CompletedBlockOutput:
|
) -> CompletedBlockOutput:
|
||||||
obj = backend.blocks.get_block(block_id)
|
obj = backend.data.block.get_block(block_id)
|
||||||
if not obj:
|
if not obj:
|
||||||
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
||||||
if obj.disabled:
|
if obj.disabled:
|
||||||
|
|||||||
@@ -10,15 +10,10 @@ import backend.api.features.library.db as library_db
|
|||||||
import backend.api.features.library.model as library_model
|
import backend.api.features.library.model as library_model
|
||||||
import backend.api.features.store.db as store_db
|
import backend.api.features.store.db as store_db
|
||||||
import backend.api.features.store.model as store_model
|
import backend.api.features.store.model as store_model
|
||||||
|
import backend.data.block
|
||||||
from backend.blocks import load_all_blocks
|
from backend.blocks import load_all_blocks
|
||||||
from backend.blocks._base import (
|
|
||||||
AnyBlockSchema,
|
|
||||||
BlockCategory,
|
|
||||||
BlockInfo,
|
|
||||||
BlockSchema,
|
|
||||||
BlockType,
|
|
||||||
)
|
|
||||||
from backend.blocks.llm import LlmModel
|
from backend.blocks.llm import LlmModel
|
||||||
|
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
|
||||||
from backend.data.db import query_raw_with_schema
|
from backend.data.db import query_raw_with_schema
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.cache import cached
|
from backend.util.cache import cached
|
||||||
@@ -27,7 +22,7 @@ from backend.util.models import Pagination
|
|||||||
from .model import (
|
from .model import (
|
||||||
BlockCategoryResponse,
|
BlockCategoryResponse,
|
||||||
BlockResponse,
|
BlockResponse,
|
||||||
BlockTypeFilter,
|
BlockType,
|
||||||
CountResponse,
|
CountResponse,
|
||||||
FilterType,
|
FilterType,
|
||||||
Provider,
|
Provider,
|
||||||
@@ -93,7 +88,7 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
|
|||||||
def get_blocks(
|
def get_blocks(
|
||||||
*,
|
*,
|
||||||
category: str | None = None,
|
category: str | None = None,
|
||||||
type: BlockTypeFilter | None = None,
|
type: BlockType | None = None,
|
||||||
provider: ProviderName | None = None,
|
provider: ProviderName | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 50,
|
page_size: int = 50,
|
||||||
@@ -674,9 +669,9 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
|
|||||||
for block_type in load_all_blocks().values():
|
for block_type in load_all_blocks().values():
|
||||||
block: AnyBlockSchema = block_type()
|
block: AnyBlockSchema = block_type()
|
||||||
if block.disabled or block.block_type in (
|
if block.disabled or block.block_type in (
|
||||||
BlockType.INPUT,
|
backend.data.block.BlockType.INPUT,
|
||||||
BlockType.OUTPUT,
|
backend.data.block.BlockType.OUTPUT,
|
||||||
BlockType.AGENT,
|
backend.data.block.BlockType.AGENT,
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
# Find the execution count for this block
|
# Find the execution count for this block
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
import backend.api.features.library.model as library_model
|
import backend.api.features.library.model as library_model
|
||||||
import backend.api.features.store.model as store_model
|
import backend.api.features.store.model as store_model
|
||||||
from backend.blocks._base import BlockInfo
|
from backend.data.block import BlockInfo
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
|
|
||||||
@@ -15,7 +15,7 @@ FilterType = Literal[
|
|||||||
"my_agents",
|
"my_agents",
|
||||||
]
|
]
|
||||||
|
|
||||||
BlockTypeFilter = Literal["all", "input", "action", "output"]
|
BlockType = Literal["all", "input", "action", "output"]
|
||||||
|
|
||||||
|
|
||||||
class SearchEntry(BaseModel):
|
class SearchEntry(BaseModel):
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ async def get_block_categories(
|
|||||||
)
|
)
|
||||||
async def get_blocks(
|
async def get_blocks(
|
||||||
category: Annotated[str | None, fastapi.Query()] = None,
|
category: Annotated[str | None, fastapi.Query()] = None,
|
||||||
type: Annotated[builder_model.BlockTypeFilter | None, fastapi.Query()] = None,
|
type: Annotated[builder_model.BlockType | None, fastapi.Query()] = None,
|
||||||
provider: Annotated[ProviderName | None, fastapi.Query()] = None,
|
provider: Annotated[ProviderName | None, fastapi.Query()] = None,
|
||||||
page: Annotated[int, fastapi.Query()] = 1,
|
page: Annotated[int, fastapi.Query()] = 1,
|
||||||
page_size: Annotated[int, fastapi.Query()] = 50,
|
page_size: Annotated[int, fastapi.Query()] = 50,
|
||||||
|
|||||||
@@ -13,8 +13,7 @@ from backend.api.features.chat.tools.models import (
|
|||||||
NoResultsResponse,
|
NoResultsResponse,
|
||||||
)
|
)
|
||||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||||
from backend.blocks import get_block
|
from backend.data.block import BlockType, get_block
|
||||||
from backend.blocks._base import BlockType
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from backend.api.features.chat.tools.find_block import (
|
|||||||
FindBlockTool,
|
FindBlockTool,
|
||||||
)
|
)
|
||||||
from backend.api.features.chat.tools.models import BlockListResponse
|
from backend.api.features.chat.tools.models import BlockListResponse
|
||||||
from backend.blocks._base import BlockType
|
from backend.data.block import BlockType
|
||||||
|
|
||||||
from ._test_data import make_session
|
from ._test_data import make_session
|
||||||
|
|
||||||
|
|||||||
@@ -12,8 +12,7 @@ from backend.api.features.chat.tools.find_block import (
|
|||||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||||
)
|
)
|
||||||
from backend.blocks import get_block
|
from backend.data.block import AnyBlockSchema, get_block
|
||||||
from backend.blocks._base import AnyBlockSchema
|
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||||
from backend.data.workspace import get_or_create_workspace
|
from backend.data.workspace import get_or_create_workspace
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import pytest
|
|||||||
|
|
||||||
from backend.api.features.chat.tools.models import ErrorResponse
|
from backend.api.features.chat.tools.models import ErrorResponse
|
||||||
from backend.api.features.chat.tools.run_block import RunBlockTool
|
from backend.api.features.chat.tools.run_block import RunBlockTool
|
||||||
from backend.blocks._base import BlockType
|
from backend.data.block import BlockType
|
||||||
|
|
||||||
from ._test_data import make_session
|
from ._test_data import make_session
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ from backend.data.model import (
|
|||||||
OAuth2Credentials,
|
OAuth2Credentials,
|
||||||
)
|
)
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -360,7 +359,7 @@ async def match_user_credentials_to_graph(
|
|||||||
_,
|
_,
|
||||||
_,
|
_,
|
||||||
) in aggregated_creds.items():
|
) in aggregated_creds.items():
|
||||||
# Find first matching credential by provider, type, scopes, and host/URL
|
# Find first matching credential by provider, type, and scopes
|
||||||
matching_cred = next(
|
matching_cred = next(
|
||||||
(
|
(
|
||||||
cred
|
cred
|
||||||
@@ -375,10 +374,6 @@ async def match_user_credentials_to_graph(
|
|||||||
cred.type != "host_scoped"
|
cred.type != "host_scoped"
|
||||||
or _credential_is_for_host(cred, credential_requirements)
|
or _credential_is_for_host(cred, credential_requirements)
|
||||||
)
|
)
|
||||||
and (
|
|
||||||
cred.provider != ProviderName.MCP
|
|
||||||
or _credential_is_for_mcp_server(cred, credential_requirements)
|
|
||||||
)
|
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
@@ -449,22 +444,6 @@ def _credential_is_for_host(
|
|||||||
return credential.matches_url(list(requirements.discriminator_values)[0])
|
return credential.matches_url(list(requirements.discriminator_values)[0])
|
||||||
|
|
||||||
|
|
||||||
def _credential_is_for_mcp_server(
|
|
||||||
credential: Credentials,
|
|
||||||
requirements: CredentialsFieldInfo,
|
|
||||||
) -> bool:
|
|
||||||
"""Check if an MCP OAuth credential matches the required server URL."""
|
|
||||||
if not requirements.discriminator_values:
|
|
||||||
return True
|
|
||||||
|
|
||||||
server_url = (
|
|
||||||
credential.metadata.get("mcp_server_url")
|
|
||||||
if isinstance(credential, OAuth2Credentials)
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
return server_url in requirements.discriminator_values if server_url else False
|
|
||||||
|
|
||||||
|
|
||||||
async def check_user_has_required_credentials(
|
async def check_user_has_required_credentials(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
required_credentials: list[CredentialsMetaInput],
|
required_credentials: list[CredentialsMetaInput],
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import TYPE_CHECKING, Annotated, Any, List, Literal
|
from typing import TYPE_CHECKING, Annotated, List, Literal
|
||||||
|
|
||||||
from autogpt_libs.auth import get_user_id
|
from autogpt_libs.auth import get_user_id
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
@@ -14,7 +14,7 @@ from fastapi import (
|
|||||||
Security,
|
Security,
|
||||||
status,
|
status,
|
||||||
)
|
)
|
||||||
from pydantic import BaseModel, Field, SecretStr, model_validator
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
|
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
|
||||||
|
|
||||||
from backend.api.features.library.db import set_preset_webhook, update_preset
|
from backend.api.features.library.db import set_preset_webhook, update_preset
|
||||||
@@ -39,11 +39,7 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
|||||||
from backend.data.user import get_user_integrations
|
from backend.data.user import get_user_integrations
|
||||||
from backend.executor.utils import add_graph_execution
|
from backend.executor.utils import add_graph_execution
|
||||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||||
from backend.integrations.credentials_store import provider_matches
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.integrations.creds_manager import (
|
|
||||||
IntegrationCredentialsManager,
|
|
||||||
create_mcp_oauth_handler,
|
|
||||||
)
|
|
||||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.integrations.webhooks import get_webhook_manager
|
from backend.integrations.webhooks import get_webhook_manager
|
||||||
@@ -106,37 +102,9 @@ class CredentialsMetaResponse(BaseModel):
|
|||||||
scopes: list[str] | None
|
scopes: list[str] | None
|
||||||
username: str | None
|
username: str | None
|
||||||
host: str | None = Field(
|
host: str | None = Field(
|
||||||
default=None,
|
default=None, description="Host pattern for host-scoped credentials"
|
||||||
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def _normalize_provider(cls, data: Any) -> Any:
|
|
||||||
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug."""
|
|
||||||
if isinstance(data, dict):
|
|
||||||
prov = data.get("provider", "")
|
|
||||||
if isinstance(prov, str) and prov.startswith("ProviderName."):
|
|
||||||
member = prov.removeprefix("ProviderName.")
|
|
||||||
try:
|
|
||||||
data = {**data, "provider": ProviderName[member].value}
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
return data
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_host(cred: Credentials) -> str | None:
|
|
||||||
"""Extract host from credential: HostScoped host or MCP server URL."""
|
|
||||||
if isinstance(cred, HostScopedCredentials):
|
|
||||||
return cred.host
|
|
||||||
if isinstance(cred, OAuth2Credentials) and cred.provider in (
|
|
||||||
ProviderName.MCP,
|
|
||||||
ProviderName.MCP.value,
|
|
||||||
"ProviderName.MCP",
|
|
||||||
):
|
|
||||||
return (cred.metadata or {}).get("mcp_server_url")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
|
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
|
||||||
async def callback(
|
async def callback(
|
||||||
@@ -211,7 +179,9 @@ async def callback(
|
|||||||
title=credentials.title,
|
title=credentials.title,
|
||||||
scopes=credentials.scopes,
|
scopes=credentials.scopes,
|
||||||
username=credentials.username,
|
username=credentials.username,
|
||||||
host=(CredentialsMetaResponse.get_host(credentials)),
|
host=(
|
||||||
|
credentials.host if isinstance(credentials, HostScopedCredentials) else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -229,7 +199,7 @@ async def list_credentials(
|
|||||||
title=cred.title,
|
title=cred.title,
|
||||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||||
host=CredentialsMetaResponse.get_host(cred),
|
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||||
)
|
)
|
||||||
for cred in credentials
|
for cred in credentials
|
||||||
]
|
]
|
||||||
@@ -252,7 +222,7 @@ async def list_credentials_by_provider(
|
|||||||
title=cred.title,
|
title=cred.title,
|
||||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||||
host=CredentialsMetaResponse.get_host(cred),
|
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||||
)
|
)
|
||||||
for cred in credentials
|
for cred in credentials
|
||||||
]
|
]
|
||||||
@@ -352,11 +322,7 @@ async def delete_credentials(
|
|||||||
|
|
||||||
tokens_revoked = None
|
tokens_revoked = None
|
||||||
if isinstance(creds, OAuth2Credentials):
|
if isinstance(creds, OAuth2Credentials):
|
||||||
if provider_matches(provider.value, ProviderName.MCP.value):
|
handler = _get_provider_oauth_handler(request, provider)
|
||||||
# MCP uses dynamic per-server OAuth — create handler from metadata
|
|
||||||
handler = create_mcp_oauth_handler(creds)
|
|
||||||
else:
|
|
||||||
handler = _get_provider_oauth_handler(request, provider)
|
|
||||||
tokens_revoked = await handler.revoke_tokens(creds)
|
tokens_revoked = await handler.revoke_tokens(creds)
|
||||||
|
|
||||||
return CredentialsDeletionResponse(revoked=tokens_revoked)
|
return CredentialsDeletionResponse(revoked=tokens_revoked)
|
||||||
|
|||||||
@@ -12,11 +12,12 @@ import backend.api.features.store.image_gen as store_image_gen
|
|||||||
import backend.api.features.store.media as store_media
|
import backend.api.features.store.media as store_media
|
||||||
import backend.data.graph as graph_db
|
import backend.data.graph as graph_db
|
||||||
import backend.data.integrations as integrations_db
|
import backend.data.integrations as integrations_db
|
||||||
|
from backend.data.block import BlockInput
|
||||||
from backend.data.db import transaction
|
from backend.data.db import transaction
|
||||||
from backend.data.execution import get_graph_execution
|
from backend.data.execution import get_graph_execution
|
||||||
from backend.data.graph import GraphSettings
|
from backend.data.graph import GraphSettings
|
||||||
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
||||||
from backend.data.model import CredentialsMetaInput, GraphInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
||||||
on_graph_activate,
|
on_graph_activate,
|
||||||
@@ -1129,7 +1130,7 @@ async def create_preset_from_graph_execution(
|
|||||||
async def update_preset(
|
async def update_preset(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
preset_id: str,
|
preset_id: str,
|
||||||
inputs: Optional[GraphInput] = None,
|
inputs: Optional[BlockInput] = None,
|
||||||
credentials: Optional[dict[str, CredentialsMetaInput]] = None,
|
credentials: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
description: Optional[str] = None,
|
description: Optional[str] = None,
|
||||||
|
|||||||
@@ -6,12 +6,9 @@ import prisma.enums
|
|||||||
import prisma.models
|
import prisma.models
|
||||||
import pydantic
|
import pydantic
|
||||||
|
|
||||||
|
from backend.data.block import BlockInput
|
||||||
from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
|
from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
|
||||||
from backend.data.model import (
|
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
|
||||||
CredentialsMetaInput,
|
|
||||||
GraphInput,
|
|
||||||
is_credentials_field_name,
|
|
||||||
)
|
|
||||||
from backend.util.json import loads as json_loads
|
from backend.util.json import loads as json_loads
|
||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
|
|
||||||
@@ -326,7 +323,7 @@ class LibraryAgentPresetCreatable(pydantic.BaseModel):
|
|||||||
graph_id: str
|
graph_id: str
|
||||||
graph_version: int
|
graph_version: int
|
||||||
|
|
||||||
inputs: GraphInput
|
inputs: BlockInput
|
||||||
credentials: dict[str, CredentialsMetaInput]
|
credentials: dict[str, CredentialsMetaInput]
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
@@ -355,7 +352,7 @@ class LibraryAgentPresetUpdatable(pydantic.BaseModel):
|
|||||||
Request model used when updating a preset for a library agent.
|
Request model used when updating a preset for a library agent.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
inputs: Optional[GraphInput] = None
|
inputs: Optional[BlockInput] = None
|
||||||
credentials: Optional[dict[str, CredentialsMetaInput]] = None
|
credentials: Optional[dict[str, CredentialsMetaInput]] = None
|
||||||
|
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
@@ -398,7 +395,7 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
|
|||||||
"Webhook must be included in AgentPreset query when webhookId is set"
|
"Webhook must be included in AgentPreset query when webhookId is set"
|
||||||
)
|
)
|
||||||
|
|
||||||
input_data: GraphInput = {}
|
input_data: BlockInput = {}
|
||||||
input_credentials: dict[str, CredentialsMetaInput] = {}
|
input_credentials: dict[str, CredentialsMetaInput] = {}
|
||||||
|
|
||||||
for preset_input in preset.InputPresets:
|
for preset_input in preset.InputPresets:
|
||||||
|
|||||||
@@ -1,404 +0,0 @@
|
|||||||
"""
|
|
||||||
MCP (Model Context Protocol) API routes.
|
|
||||||
|
|
||||||
Provides endpoints for MCP tool discovery and OAuth authentication so the
|
|
||||||
frontend can list available tools on an MCP server before placing a block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Annotated, Any
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
import fastapi
|
|
||||||
from autogpt_libs.auth import get_user_id
|
|
||||||
from fastapi import Security
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from backend.api.features.integrations.router import CredentialsMetaResponse
|
|
||||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
|
||||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
from backend.util.request import HTTPClientError, Requests
|
|
||||||
from backend.util.settings import Settings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
settings = Settings()
|
|
||||||
router = fastapi.APIRouter(tags=["mcp"])
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Tool Discovery ====================== #
|
|
||||||
|
|
||||||
|
|
||||||
class DiscoverToolsRequest(BaseModel):
|
|
||||||
"""Request to discover tools on an MCP server."""
|
|
||||||
|
|
||||||
server_url: str = Field(description="URL of the MCP server")
|
|
||||||
auth_token: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="Optional Bearer token for authenticated MCP servers",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MCPToolResponse(BaseModel):
|
|
||||||
"""A single MCP tool returned by discovery."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
input_schema: dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class DiscoverToolsResponse(BaseModel):
|
|
||||||
"""Response containing the list of tools available on an MCP server."""
|
|
||||||
|
|
||||||
tools: list[MCPToolResponse]
|
|
||||||
server_name: str | None = None
|
|
||||||
protocol_version: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/discover-tools",
|
|
||||||
summary="Discover available tools on an MCP server",
|
|
||||||
response_model=DiscoverToolsResponse,
|
|
||||||
)
|
|
||||||
async def discover_tools(
|
|
||||||
request: DiscoverToolsRequest,
|
|
||||||
user_id: Annotated[str, Security(get_user_id)],
|
|
||||||
) -> DiscoverToolsResponse:
|
|
||||||
"""
|
|
||||||
Connect to an MCP server and return its available tools.
|
|
||||||
|
|
||||||
If the user has a stored MCP credential for this server URL, it will be
|
|
||||||
used automatically — no need to pass an explicit auth token.
|
|
||||||
"""
|
|
||||||
auth_token = request.auth_token
|
|
||||||
|
|
||||||
# Auto-use stored MCP credential when no explicit token is provided.
|
|
||||||
if not auth_token:
|
|
||||||
mcp_creds = await creds_manager.store.get_creds_by_provider(
|
|
||||||
user_id, ProviderName.MCP.value
|
|
||||||
)
|
|
||||||
# Find the freshest credential for this server URL
|
|
||||||
best_cred: OAuth2Credentials | None = None
|
|
||||||
for cred in mcp_creds:
|
|
||||||
if (
|
|
||||||
isinstance(cred, OAuth2Credentials)
|
|
||||||
and (cred.metadata or {}).get("mcp_server_url") == request.server_url
|
|
||||||
):
|
|
||||||
if best_cred is None or (
|
|
||||||
(cred.access_token_expires_at or 0)
|
|
||||||
> (best_cred.access_token_expires_at or 0)
|
|
||||||
):
|
|
||||||
best_cred = cred
|
|
||||||
if best_cred:
|
|
||||||
# Refresh the token if expired before using it
|
|
||||||
best_cred = await creds_manager.refresh_if_needed(user_id, best_cred)
|
|
||||||
logger.info(
|
|
||||||
f"Using MCP credential {best_cred.id} for {request.server_url}, "
|
|
||||||
f"expires_at={best_cred.access_token_expires_at}"
|
|
||||||
)
|
|
||||||
auth_token = best_cred.access_token.get_secret_value()
|
|
||||||
|
|
||||||
client = MCPClient(request.server_url, auth_token=auth_token)
|
|
||||||
|
|
||||||
try:
|
|
||||||
init_result = await client.initialize()
|
|
||||||
tools = await client.list_tools()
|
|
||||||
except HTTPClientError as e:
|
|
||||||
if e.status_code in (401, 403):
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=401,
|
|
||||||
detail="This MCP server requires authentication. "
|
|
||||||
"Please provide a valid auth token.",
|
|
||||||
)
|
|
||||||
raise fastapi.HTTPException(status_code=502, detail=str(e))
|
|
||||||
except MCPClientError as e:
|
|
||||||
raise fastapi.HTTPException(status_code=502, detail=str(e))
|
|
||||||
except Exception as e:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=502,
|
|
||||||
detail=f"Failed to connect to MCP server: {e}",
|
|
||||||
)
|
|
||||||
|
|
||||||
return DiscoverToolsResponse(
|
|
||||||
tools=[
|
|
||||||
MCPToolResponse(
|
|
||||||
name=t.name,
|
|
||||||
description=t.description,
|
|
||||||
input_schema=t.input_schema,
|
|
||||||
)
|
|
||||||
for t in tools
|
|
||||||
],
|
|
||||||
server_name=(
|
|
||||||
init_result.get("serverInfo", {}).get("name")
|
|
||||||
or urlparse(request.server_url).hostname
|
|
||||||
or "MCP"
|
|
||||||
),
|
|
||||||
protocol_version=init_result.get("protocolVersion"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ======================== OAuth Flow ======================== #
|
|
||||||
|
|
||||||
|
|
||||||
class MCPOAuthLoginRequest(BaseModel):
|
|
||||||
"""Request to start an OAuth flow for an MCP server."""
|
|
||||||
|
|
||||||
server_url: str = Field(description="URL of the MCP server that requires OAuth")
|
|
||||||
|
|
||||||
|
|
||||||
class MCPOAuthLoginResponse(BaseModel):
|
|
||||||
"""Response with the OAuth login URL for the user to authenticate."""
|
|
||||||
|
|
||||||
login_url: str
|
|
||||||
state_token: str
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/oauth/login",
|
|
||||||
summary="Initiate OAuth login for an MCP server",
|
|
||||||
)
|
|
||||||
async def mcp_oauth_login(
|
|
||||||
request: MCPOAuthLoginRequest,
|
|
||||||
user_id: Annotated[str, Security(get_user_id)],
|
|
||||||
) -> MCPOAuthLoginResponse:
|
|
||||||
"""
|
|
||||||
Discover OAuth metadata from the MCP server and return a login URL.
|
|
||||||
|
|
||||||
1. Discovers the protected-resource metadata (RFC 9728)
|
|
||||||
2. Fetches the authorization server metadata (RFC 8414)
|
|
||||||
3. Performs Dynamic Client Registration (RFC 7591) if available
|
|
||||||
4. Returns the authorization URL for the frontend to open in a popup
|
|
||||||
"""
|
|
||||||
client = MCPClient(request.server_url)
|
|
||||||
|
|
||||||
# Step 1: Discover protected-resource metadata (RFC 9728)
|
|
||||||
protected_resource = await client.discover_auth()
|
|
||||||
|
|
||||||
metadata: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
if protected_resource and protected_resource.get("authorization_servers"):
|
|
||||||
auth_server_url = protected_resource["authorization_servers"][0]
|
|
||||||
resource_url = protected_resource.get("resource", request.server_url)
|
|
||||||
|
|
||||||
# Step 2a: Discover auth-server metadata (RFC 8414)
|
|
||||||
metadata = await client.discover_auth_server_metadata(auth_server_url)
|
|
||||||
else:
|
|
||||||
# Fallback: Some MCP servers (e.g. Linear) are their own auth server
|
|
||||||
# and serve OAuth metadata directly without protected-resource metadata.
|
|
||||||
# Don't assume a resource_url — omitting it lets the auth server choose
|
|
||||||
# the correct audience for the token (RFC 8707 resource is optional).
|
|
||||||
resource_url = None
|
|
||||||
metadata = await client.discover_auth_server_metadata(request.server_url)
|
|
||||||
|
|
||||||
if (
|
|
||||||
not metadata
|
|
||||||
or "authorization_endpoint" not in metadata
|
|
||||||
or "token_endpoint" not in metadata
|
|
||||||
):
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="This MCP server does not advertise OAuth support. "
|
|
||||||
"You may need to provide an auth token manually.",
|
|
||||||
)
|
|
||||||
|
|
||||||
authorize_url = metadata["authorization_endpoint"]
|
|
||||||
token_url = metadata["token_endpoint"]
|
|
||||||
registration_endpoint = metadata.get("registration_endpoint")
|
|
||||||
revoke_url = metadata.get("revocation_endpoint")
|
|
||||||
|
|
||||||
# Step 3: Dynamic Client Registration (RFC 7591) if available
|
|
||||||
frontend_base_url = settings.config.frontend_base_url
|
|
||||||
if not frontend_base_url:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail="Frontend base URL is not configured.",
|
|
||||||
)
|
|
||||||
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
|
|
||||||
|
|
||||||
client_id = ""
|
|
||||||
client_secret = ""
|
|
||||||
if registration_endpoint:
|
|
||||||
reg_result = await _register_mcp_client(
|
|
||||||
registration_endpoint, redirect_uri, request.server_url
|
|
||||||
)
|
|
||||||
if reg_result:
|
|
||||||
client_id = reg_result.get("client_id", "")
|
|
||||||
client_secret = reg_result.get("client_secret", "")
|
|
||||||
|
|
||||||
if not client_id:
|
|
||||||
client_id = "autogpt-platform"
|
|
||||||
|
|
||||||
# Step 4: Store state token with OAuth metadata for the callback
|
|
||||||
scopes = (protected_resource or {}).get("scopes_supported") or metadata.get(
|
|
||||||
"scopes_supported", []
|
|
||||||
)
|
|
||||||
state_token, code_challenge = await creds_manager.store.store_state_token(
|
|
||||||
user_id,
|
|
||||||
ProviderName.MCP.value,
|
|
||||||
scopes,
|
|
||||||
state_metadata={
|
|
||||||
"authorize_url": authorize_url,
|
|
||||||
"token_url": token_url,
|
|
||||||
"revoke_url": revoke_url,
|
|
||||||
"resource_url": resource_url,
|
|
||||||
"server_url": request.server_url,
|
|
||||||
"client_id": client_id,
|
|
||||||
"client_secret": client_secret,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 5: Build and return the login URL
|
|
||||||
handler = MCPOAuthHandler(
|
|
||||||
client_id=client_id,
|
|
||||||
client_secret=client_secret,
|
|
||||||
redirect_uri=redirect_uri,
|
|
||||||
authorize_url=authorize_url,
|
|
||||||
token_url=token_url,
|
|
||||||
resource_url=resource_url,
|
|
||||||
)
|
|
||||||
login_url = handler.get_login_url(
|
|
||||||
scopes, state_token, code_challenge=code_challenge
|
|
||||||
)
|
|
||||||
|
|
||||||
return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token)
|
|
||||||
|
|
||||||
|
|
||||||
class MCPOAuthCallbackRequest(BaseModel):
|
|
||||||
"""Request to exchange an OAuth code for tokens."""
|
|
||||||
|
|
||||||
code: str = Field(description="Authorization code from OAuth callback")
|
|
||||||
state_token: str = Field(description="State token for CSRF verification")
|
|
||||||
|
|
||||||
|
|
||||||
class MCPOAuthCallbackResponse(BaseModel):
|
|
||||||
"""Response after successfully storing OAuth credentials."""
|
|
||||||
|
|
||||||
credential_id: str
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/oauth/callback",
|
|
||||||
summary="Exchange OAuth code for MCP tokens",
|
|
||||||
)
|
|
||||||
async def mcp_oauth_callback(
|
|
||||||
request: MCPOAuthCallbackRequest,
|
|
||||||
user_id: Annotated[str, Security(get_user_id)],
|
|
||||||
) -> CredentialsMetaResponse:
|
|
||||||
"""
|
|
||||||
Exchange the authorization code for tokens and store the credential.
|
|
||||||
|
|
||||||
The frontend calls this after receiving the OAuth code from the popup.
|
|
||||||
On success, subsequent ``/discover-tools`` calls for the same server URL
|
|
||||||
will automatically use the stored credential.
|
|
||||||
"""
|
|
||||||
valid_state = await creds_manager.store.verify_state_token(
|
|
||||||
user_id, request.state_token, ProviderName.MCP.value
|
|
||||||
)
|
|
||||||
if not valid_state:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Invalid or expired state token.",
|
|
||||||
)
|
|
||||||
|
|
||||||
meta = valid_state.state_metadata
|
|
||||||
frontend_base_url = settings.config.frontend_base_url
|
|
||||||
if not frontend_base_url:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail="Frontend base URL is not configured.",
|
|
||||||
)
|
|
||||||
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
|
|
||||||
|
|
||||||
handler = MCPOAuthHandler(
|
|
||||||
client_id=meta["client_id"],
|
|
||||||
client_secret=meta.get("client_secret", ""),
|
|
||||||
redirect_uri=redirect_uri,
|
|
||||||
authorize_url=meta["authorize_url"],
|
|
||||||
token_url=meta["token_url"],
|
|
||||||
revoke_url=meta.get("revoke_url"),
|
|
||||||
resource_url=meta.get("resource_url"),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
credentials = await handler.exchange_code_for_tokens(
|
|
||||||
request.code, valid_state.scopes, valid_state.code_verifier
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=f"OAuth token exchange failed: {e}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Enrich credential metadata for future lookup and token refresh
|
|
||||||
if credentials.metadata is None:
|
|
||||||
credentials.metadata = {}
|
|
||||||
credentials.metadata["mcp_server_url"] = meta["server_url"]
|
|
||||||
credentials.metadata["mcp_client_id"] = meta["client_id"]
|
|
||||||
credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "")
|
|
||||||
credentials.metadata["mcp_token_url"] = meta["token_url"]
|
|
||||||
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
|
|
||||||
|
|
||||||
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
|
|
||||||
credentials.title = f"MCP: {hostname}"
|
|
||||||
|
|
||||||
# Remove old MCP credentials for the same server to prevent stale token buildup.
|
|
||||||
try:
|
|
||||||
old_creds = await creds_manager.store.get_creds_by_provider(
|
|
||||||
user_id, ProviderName.MCP.value
|
|
||||||
)
|
|
||||||
for old in old_creds:
|
|
||||||
if (
|
|
||||||
isinstance(old, OAuth2Credentials)
|
|
||||||
and (old.metadata or {}).get("mcp_server_url") == meta["server_url"]
|
|
||||||
):
|
|
||||||
await creds_manager.store.delete_creds_by_id(user_id, old.id)
|
|
||||||
logger.info(
|
|
||||||
f"Removed old MCP credential {old.id} for {meta['server_url']}"
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Could not clean up old MCP credentials", exc_info=True)
|
|
||||||
|
|
||||||
await creds_manager.create(user_id, credentials)
|
|
||||||
|
|
||||||
return CredentialsMetaResponse(
|
|
||||||
id=credentials.id,
|
|
||||||
provider=credentials.provider,
|
|
||||||
type=credentials.type,
|
|
||||||
title=credentials.title,
|
|
||||||
scopes=credentials.scopes,
|
|
||||||
username=credentials.username,
|
|
||||||
host=credentials.metadata.get("mcp_server_url"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ======================== Helpers ======================== #
|
|
||||||
|
|
||||||
|
|
||||||
async def _register_mcp_client(
|
|
||||||
registration_endpoint: str,
|
|
||||||
redirect_uri: str,
|
|
||||||
server_url: str,
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server."""
|
|
||||||
try:
|
|
||||||
response = await Requests(raise_for_status=True).post(
|
|
||||||
registration_endpoint,
|
|
||||||
json={
|
|
||||||
"client_name": "AutoGPT Platform",
|
|
||||||
"redirect_uris": [redirect_uri],
|
|
||||||
"grant_types": ["authorization_code"],
|
|
||||||
"response_types": ["code"],
|
|
||||||
"token_endpoint_auth_method": "client_secret_post",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
data = response.json()
|
|
||||||
if isinstance(data, dict) and "client_id" in data:
|
|
||||||
return data
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
|
|
||||||
return None
|
|
||||||
@@ -1,436 +0,0 @@
|
|||||||
"""Tests for MCP API routes.
|
|
||||||
|
|
||||||
Uses httpx.AsyncClient with ASGITransport instead of fastapi.testclient.TestClient
|
|
||||||
to avoid creating blocking portals that can corrupt pytest-asyncio's session event loop.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, patch
|
|
||||||
|
|
||||||
import fastapi
|
|
||||||
import httpx
|
|
||||||
import pytest
|
|
||||||
import pytest_asyncio
|
|
||||||
from autogpt_libs.auth import get_user_id
|
|
||||||
|
|
||||||
from backend.api.features.mcp.routes import router
|
|
||||||
from backend.blocks.mcp.client import MCPClientError, MCPTool
|
|
||||||
from backend.util.request import HTTPClientError
|
|
||||||
|
|
||||||
app = fastapi.FastAPI()
|
|
||||||
app.include_router(router)
|
|
||||||
app.dependency_overrides[get_user_id] = lambda: "test-user-id"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="module")
|
|
||||||
async def client():
|
|
||||||
transport = httpx.ASGITransport(app=app)
|
|
||||||
async with httpx.AsyncClient(transport=transport, base_url="http://test") as c:
|
|
||||||
yield c
|
|
||||||
|
|
||||||
|
|
||||||
class TestDiscoverTools:
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_tools_success(self, client):
|
|
||||||
mock_tools = [
|
|
||||||
MCPTool(
|
|
||||||
name="get_weather",
|
|
||||||
description="Get weather for a city",
|
|
||||||
input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"city": {"type": "string"}},
|
|
||||||
"required": ["city"],
|
|
||||||
},
|
|
||||||
),
|
|
||||||
MCPTool(
|
|
||||||
name="add_numbers",
|
|
||||||
description="Add two numbers",
|
|
||||||
input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"a": {"type": "number"},
|
|
||||||
"b": {"type": "number"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
):
|
|
||||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(
|
|
||||||
return_value={
|
|
||||||
"protocolVersion": "2025-03-26",
|
|
||||||
"serverInfo": {"name": "test-server"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
instance.list_tools = AsyncMock(return_value=mock_tools)
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={"server_url": "https://mcp.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert len(data["tools"]) == 2
|
|
||||||
assert data["tools"][0]["name"] == "get_weather"
|
|
||||||
assert data["tools"][1]["name"] == "add_numbers"
|
|
||||||
assert data["server_name"] == "test-server"
|
|
||||||
assert data["protocol_version"] == "2025-03-26"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_tools_with_auth_token(self, client):
|
|
||||||
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(
|
|
||||||
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
|
|
||||||
)
|
|
||||||
instance.list_tools = AsyncMock(return_value=[])
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={
|
|
||||||
"server_url": "https://mcp.example.com/mcp",
|
|
||||||
"auth_token": "my-secret-token",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
MockClient.assert_called_once_with(
|
|
||||||
"https://mcp.example.com/mcp",
|
|
||||||
auth_token="my-secret-token",
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_tools_auto_uses_stored_credential(self, client):
|
|
||||||
"""When no explicit token is given, stored MCP credentials are used."""
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
|
|
||||||
stored_cred = OAuth2Credentials(
|
|
||||||
provider="mcp",
|
|
||||||
title="MCP: example.com",
|
|
||||||
access_token=SecretStr("stored-token-123"),
|
|
||||||
refresh_token=None,
|
|
||||||
access_token_expires_at=None,
|
|
||||||
refresh_token_expires_at=None,
|
|
||||||
scopes=[],
|
|
||||||
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
):
|
|
||||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
|
|
||||||
mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred)
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(
|
|
||||||
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
|
|
||||||
)
|
|
||||||
instance.list_tools = AsyncMock(return_value=[])
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={"server_url": "https://mcp.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
MockClient.assert_called_once_with(
|
|
||||||
"https://mcp.example.com/mcp",
|
|
||||||
auth_token="stored-token-123",
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_tools_mcp_error(self, client):
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
):
|
|
||||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(
|
|
||||||
side_effect=MCPClientError("Connection refused")
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={"server_url": "https://bad-server.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 502
|
|
||||||
assert "Connection refused" in response.json()["detail"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_tools_generic_error(self, client):
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
):
|
|
||||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={"server_url": "https://timeout.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 502
|
|
||||||
assert "Failed to connect" in response.json()["detail"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_tools_auth_required(self, client):
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
):
|
|
||||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(
|
|
||||||
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={"server_url": "https://auth-server.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 401
|
|
||||||
assert "requires authentication" in response.json()["detail"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_tools_forbidden(self, client):
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
):
|
|
||||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(
|
|
||||||
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={"server_url": "https://auth-server.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 401
|
|
||||||
assert "requires authentication" in response.json()["detail"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_tools_missing_url(self, client):
|
|
||||||
response = await client.post("/discover-tools", json={})
|
|
||||||
assert response.status_code == 422
|
|
||||||
|
|
||||||
|
|
||||||
class TestOAuthLogin:
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_oauth_login_success(self, client):
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
|
||||||
patch(
|
|
||||||
"backend.api.features.mcp.routes._register_mcp_client"
|
|
||||||
) as mock_register,
|
|
||||||
):
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.discover_auth = AsyncMock(
|
|
||||||
return_value={
|
|
||||||
"authorization_servers": ["https://auth.sentry.io"],
|
|
||||||
"resource": "https://mcp.sentry.dev/mcp",
|
|
||||||
"scopes_supported": ["openid"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
instance.discover_auth_server_metadata = AsyncMock(
|
|
||||||
return_value={
|
|
||||||
"authorization_endpoint": "https://auth.sentry.io/authorize",
|
|
||||||
"token_endpoint": "https://auth.sentry.io/token",
|
|
||||||
"registration_endpoint": "https://auth.sentry.io/register",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
mock_register.return_value = {
|
|
||||||
"client_id": "registered-client-id",
|
|
||||||
"client_secret": "registered-secret",
|
|
||||||
}
|
|
||||||
mock_cm.store.store_state_token = AsyncMock(
|
|
||||||
return_value=("state-token-123", "code-challenge-abc")
|
|
||||||
)
|
|
||||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/oauth/login",
|
|
||||||
json={"server_url": "https://mcp.sentry.dev/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert "login_url" in data
|
|
||||||
assert data["state_token"] == "state-token-123"
|
|
||||||
assert "auth.sentry.io/authorize" in data["login_url"]
|
|
||||||
assert "registered-client-id" in data["login_url"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_oauth_login_no_oauth_support(self, client):
|
|
||||||
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.discover_auth = AsyncMock(return_value=None)
|
|
||||||
instance.discover_auth_server_metadata = AsyncMock(return_value=None)
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/oauth/login",
|
|
||||||
json={"server_url": "https://simple-server.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert "does not advertise OAuth" in response.json()["detail"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_oauth_login_fallback_to_public_client(self, client):
|
|
||||||
"""When DCR is unavailable, falls back to default public client ID."""
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
|
||||||
):
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.discover_auth = AsyncMock(
|
|
||||||
return_value={
|
|
||||||
"authorization_servers": ["https://auth.example.com"],
|
|
||||||
"resource": "https://mcp.example.com/mcp",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
instance.discover_auth_server_metadata = AsyncMock(
|
|
||||||
return_value={
|
|
||||||
"authorization_endpoint": "https://auth.example.com/authorize",
|
|
||||||
"token_endpoint": "https://auth.example.com/token",
|
|
||||||
# No registration_endpoint
|
|
||||||
}
|
|
||||||
)
|
|
||||||
mock_cm.store.store_state_token = AsyncMock(
|
|
||||||
return_value=("state-abc", "challenge-xyz")
|
|
||||||
)
|
|
||||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/oauth/login",
|
|
||||||
json={"server_url": "https://mcp.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert "autogpt-platform" in data["login_url"]
|
|
||||||
|
|
||||||
|
|
||||||
class TestOAuthCallback:
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_oauth_callback_success(self, client):
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
|
|
||||||
mock_creds = OAuth2Credentials(
|
|
||||||
provider="mcp",
|
|
||||||
title=None,
|
|
||||||
access_token=SecretStr("access-token-xyz"),
|
|
||||||
refresh_token=None,
|
|
||||||
access_token_expires_at=None,
|
|
||||||
refresh_token_expires_at=None,
|
|
||||||
scopes=[],
|
|
||||||
metadata={
|
|
||||||
"mcp_token_url": "https://auth.sentry.io/token",
|
|
||||||
"mcp_resource_url": "https://mcp.sentry.dev/mcp",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
|
||||||
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
|
|
||||||
):
|
|
||||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
|
||||||
|
|
||||||
# Mock state verification
|
|
||||||
mock_state = AsyncMock()
|
|
||||||
mock_state.state_metadata = {
|
|
||||||
"authorize_url": "https://auth.sentry.io/authorize",
|
|
||||||
"token_url": "https://auth.sentry.io/token",
|
|
||||||
"client_id": "test-client-id",
|
|
||||||
"client_secret": "test-secret",
|
|
||||||
"server_url": "https://mcp.sentry.dev/mcp",
|
|
||||||
}
|
|
||||||
mock_state.scopes = ["openid"]
|
|
||||||
mock_state.code_verifier = "verifier-123"
|
|
||||||
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
|
|
||||||
mock_cm.create = AsyncMock()
|
|
||||||
|
|
||||||
handler_instance = MockHandler.return_value
|
|
||||||
handler_instance.exchange_code_for_tokens = AsyncMock(
|
|
||||||
return_value=mock_creds
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock old credential cleanup
|
|
||||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/oauth/callback",
|
|
||||||
json={"code": "auth-code-abc", "state_token": "state-token-123"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert "id" in data
|
|
||||||
assert data["provider"] == "mcp"
|
|
||||||
assert data["type"] == "oauth2"
|
|
||||||
mock_cm.create.assert_called_once()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_oauth_callback_invalid_state(self, client):
|
|
||||||
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
|
|
||||||
mock_cm.store.verify_state_token = AsyncMock(return_value=None)
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/oauth/callback",
|
|
||||||
json={"code": "auth-code", "state_token": "bad-state"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert "Invalid or expired" in response.json()["detail"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_oauth_callback_token_exchange_fails(self, client):
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
|
||||||
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
|
|
||||||
):
|
|
||||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
|
||||||
mock_state = AsyncMock()
|
|
||||||
mock_state.state_metadata = {
|
|
||||||
"authorize_url": "https://auth.example.com/authorize",
|
|
||||||
"token_url": "https://auth.example.com/token",
|
|
||||||
"client_id": "cid",
|
|
||||||
"server_url": "https://mcp.example.com/mcp",
|
|
||||||
}
|
|
||||||
mock_state.scopes = []
|
|
||||||
mock_state.code_verifier = "v"
|
|
||||||
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
|
|
||||||
|
|
||||||
handler_instance = MockHandler.return_value
|
|
||||||
handler_instance.exchange_code_for_tokens = AsyncMock(
|
|
||||||
side_effect=RuntimeError("Token exchange failed")
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/oauth/callback",
|
|
||||||
json={"code": "bad-code", "state_token": "state"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert "token exchange failed" in response.json()["detail"].lower()
|
|
||||||
@@ -5,8 +5,8 @@ from typing import Optional
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from backend.blocks import get_block
|
|
||||||
from backend.data import graph as graph_db
|
from backend.data import graph as graph_db
|
||||||
|
from backend.data.block import get_block
|
||||||
from backend.util.settings import Settings
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
from .models import ApiResponse, ChatRequest, GraphData
|
from .models import ApiResponse, ChatRequest, GraphData
|
||||||
|
|||||||
@@ -152,7 +152,7 @@ class BlockHandler(ContentHandler):
|
|||||||
|
|
||||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||||
"""Fetch blocks without embeddings."""
|
"""Fetch blocks without embeddings."""
|
||||||
from backend.blocks import get_blocks
|
from backend.data.block import get_blocks
|
||||||
|
|
||||||
# Get all available blocks
|
# Get all available blocks
|
||||||
all_blocks = get_blocks()
|
all_blocks = get_blocks()
|
||||||
@@ -249,7 +249,7 @@ class BlockHandler(ContentHandler):
|
|||||||
|
|
||||||
async def get_stats(self) -> dict[str, int]:
|
async def get_stats(self) -> dict[str, int]:
|
||||||
"""Get statistics about block embedding coverage."""
|
"""Get statistics about block embedding coverage."""
|
||||||
from backend.blocks import get_blocks
|
from backend.data.block import get_blocks
|
||||||
|
|
||||||
all_blocks = get_blocks()
|
all_blocks = get_blocks()
|
||||||
|
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ async def test_block_handler_get_missing_items(mocker):
|
|||||||
mock_existing = []
|
mock_existing = []
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"backend.blocks.get_blocks",
|
"backend.data.block.get_blocks",
|
||||||
return_value=mock_blocks,
|
return_value=mock_blocks,
|
||||||
):
|
):
|
||||||
with patch(
|
with patch(
|
||||||
@@ -135,7 +135,7 @@ async def test_block_handler_get_stats(mocker):
|
|||||||
mock_embedded = [{"count": 2}]
|
mock_embedded = [{"count": 2}]
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"backend.blocks.get_blocks",
|
"backend.data.block.get_blocks",
|
||||||
return_value=mock_blocks,
|
return_value=mock_blocks,
|
||||||
):
|
):
|
||||||
with patch(
|
with patch(
|
||||||
@@ -327,7 +327,7 @@ async def test_block_handler_handles_missing_attributes():
|
|||||||
mock_blocks = {"block-minimal": mock_block_class}
|
mock_blocks = {"block-minimal": mock_block_class}
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"backend.blocks.get_blocks",
|
"backend.data.block.get_blocks",
|
||||||
return_value=mock_blocks,
|
return_value=mock_blocks,
|
||||||
):
|
):
|
||||||
with patch(
|
with patch(
|
||||||
@@ -360,7 +360,7 @@ async def test_block_handler_skips_failed_blocks():
|
|||||||
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
|
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"backend.blocks.get_blocks",
|
"backend.data.block.get_blocks",
|
||||||
return_value=mock_blocks,
|
return_value=mock_blocks,
|
||||||
):
|
):
|
||||||
with patch(
|
with patch(
|
||||||
|
|||||||
@@ -662,7 +662,7 @@ async def cleanup_orphaned_embeddings() -> dict[str, Any]:
|
|||||||
)
|
)
|
||||||
current_ids = {row["id"] for row in valid_agents}
|
current_ids = {row["id"] for row in valid_agents}
|
||||||
elif content_type == ContentType.BLOCK:
|
elif content_type == ContentType.BLOCK:
|
||||||
from backend.blocks import get_blocks
|
from backend.data.block import get_blocks
|
||||||
|
|
||||||
current_ids = set(get_blocks().keys())
|
current_ids = set(get_blocks().keys())
|
||||||
elif content_type == ContentType.DOCUMENTATION:
|
elif content_type == ContentType.DOCUMENTATION:
|
||||||
|
|||||||
@@ -7,6 +7,15 @@ from replicate.client import Client as ReplicateClient
|
|||||||
from replicate.exceptions import ReplicateError
|
from replicate.exceptions import ReplicateError
|
||||||
from replicate.helpers import FileOutput
|
from replicate.helpers import FileOutput
|
||||||
|
|
||||||
|
from backend.blocks.ideogram import (
|
||||||
|
AspectRatio,
|
||||||
|
ColorPalettePreset,
|
||||||
|
IdeogramModelBlock,
|
||||||
|
IdeogramModelName,
|
||||||
|
MagicPromptOption,
|
||||||
|
StyleType,
|
||||||
|
UpscaleOption,
|
||||||
|
)
|
||||||
from backend.data.graph import GraphBaseMeta
|
from backend.data.graph import GraphBaseMeta
|
||||||
from backend.data.model import CredentialsMetaInput, ProviderName
|
from backend.data.model import CredentialsMetaInput, ProviderName
|
||||||
from backend.integrations.credentials_store import ideogram_credentials
|
from backend.integrations.credentials_store import ideogram_credentials
|
||||||
@@ -41,16 +50,6 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
|
|||||||
if not ideogram_credentials.api_key:
|
if not ideogram_credentials.api_key:
|
||||||
raise ValueError("Missing Ideogram API key")
|
raise ValueError("Missing Ideogram API key")
|
||||||
|
|
||||||
from backend.blocks.ideogram import (
|
|
||||||
AspectRatio,
|
|
||||||
ColorPalettePreset,
|
|
||||||
IdeogramModelBlock,
|
|
||||||
IdeogramModelName,
|
|
||||||
MagicPromptOption,
|
|
||||||
StyleType,
|
|
||||||
UpscaleOption,
|
|
||||||
)
|
|
||||||
|
|
||||||
name = graph.name
|
name = graph.name
|
||||||
description = f"{name} ({graph.description})" if graph.description else name
|
description = f"{name} ({graph.description})" if graph.description else name
|
||||||
|
|
||||||
|
|||||||
@@ -40,11 +40,10 @@ from backend.api.model import (
|
|||||||
UpdateTimezoneRequest,
|
UpdateTimezoneRequest,
|
||||||
UploadFileResponse,
|
UploadFileResponse,
|
||||||
)
|
)
|
||||||
from backend.blocks import get_block, get_blocks
|
|
||||||
from backend.data import execution as execution_db
|
from backend.data import execution as execution_db
|
||||||
from backend.data import graph as graph_db
|
from backend.data import graph as graph_db
|
||||||
from backend.data.auth import api_key as api_key_db
|
from backend.data.auth import api_key as api_key_db
|
||||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
|
||||||
from backend.data.credit import (
|
from backend.data.credit import (
|
||||||
AutoTopUpConfig,
|
AutoTopUpConfig,
|
||||||
RefundRequest,
|
RefundRequest,
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ import backend.api.features.executions.review.routes
|
|||||||
import backend.api.features.library.db
|
import backend.api.features.library.db
|
||||||
import backend.api.features.library.model
|
import backend.api.features.library.model
|
||||||
import backend.api.features.library.routes
|
import backend.api.features.library.routes
|
||||||
import backend.api.features.mcp.routes as mcp_routes
|
|
||||||
import backend.api.features.oauth
|
import backend.api.features.oauth
|
||||||
import backend.api.features.otto.routes
|
import backend.api.features.otto.routes
|
||||||
import backend.api.features.postmark.postmark
|
import backend.api.features.postmark.postmark
|
||||||
@@ -344,11 +343,6 @@ app.include_router(
|
|||||||
tags=["workspace"],
|
tags=["workspace"],
|
||||||
prefix="/api/workspace",
|
prefix="/api/workspace",
|
||||||
)
|
)
|
||||||
app.include_router(
|
|
||||||
mcp_routes.router,
|
|
||||||
tags=["v2", "mcp"],
|
|
||||||
prefix="/api/mcp",
|
|
||||||
)
|
|
||||||
app.include_router(
|
app.include_router(
|
||||||
backend.api.features.oauth.router,
|
backend.api.features.oauth.router,
|
||||||
tags=["oauth"],
|
tags=["oauth"],
|
||||||
|
|||||||
@@ -3,19 +3,22 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Sequence, Type, TypeVar
|
from typing import TYPE_CHECKING, TypeVar
|
||||||
|
|
||||||
from backend.blocks._base import AnyBlockSchema, BlockType
|
|
||||||
from backend.util.cache import cached
|
from backend.util.cache import cached
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from backend.data.block import Block
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl_seconds=3600)
|
@cached(ttl_seconds=3600)
|
||||||
def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
|
def load_all_blocks() -> dict[str, type["Block"]]:
|
||||||
from backend.blocks._base import Block
|
from backend.data.block import Block
|
||||||
from backend.util.settings import Config
|
from backend.util.settings import Config
|
||||||
|
|
||||||
# Check if example blocks should be loaded from settings
|
# Check if example blocks should be loaded from settings
|
||||||
@@ -47,8 +50,8 @@ def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
|
|||||||
importlib.import_module(f".{module}", package=__name__)
|
importlib.import_module(f".{module}", package=__name__)
|
||||||
|
|
||||||
# Load all Block instances from the available modules
|
# Load all Block instances from the available modules
|
||||||
available_blocks: dict[str, type["AnyBlockSchema"]] = {}
|
available_blocks: dict[str, type["Block"]] = {}
|
||||||
for block_cls in _all_subclasses(Block):
|
for block_cls in all_subclasses(Block):
|
||||||
class_name = block_cls.__name__
|
class_name = block_cls.__name__
|
||||||
|
|
||||||
if class_name.endswith("Base"):
|
if class_name.endswith("Base"):
|
||||||
@@ -61,7 +64,7 @@ def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
|
|||||||
"please name the class with 'Base' at the end"
|
"please name the class with 'Base' at the end"
|
||||||
)
|
)
|
||||||
|
|
||||||
block = block_cls() # pyright: ignore[reportAbstractUsage]
|
block = block_cls.create()
|
||||||
|
|
||||||
if not isinstance(block.id, str) or len(block.id) != 36:
|
if not isinstance(block.id, str) or len(block.id) != 36:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -102,7 +105,7 @@ def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
|
|||||||
available_blocks[block.id] = block_cls
|
available_blocks[block.id] = block_cls
|
||||||
|
|
||||||
# Filter out blocks with incomplete auth configs, e.g. missing OAuth server secrets
|
# Filter out blocks with incomplete auth configs, e.g. missing OAuth server secrets
|
||||||
from ._utils import is_block_auth_configured
|
from backend.data.block import is_block_auth_configured
|
||||||
|
|
||||||
filtered_blocks = {}
|
filtered_blocks = {}
|
||||||
for block_id, block_cls in available_blocks.items():
|
for block_id, block_cls in available_blocks.items():
|
||||||
@@ -112,48 +115,11 @@ def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
|
|||||||
return filtered_blocks
|
return filtered_blocks
|
||||||
|
|
||||||
|
|
||||||
def _all_subclasses(cls: type[T]) -> list[type[T]]:
|
__all__ = ["load_all_blocks"]
|
||||||
|
|
||||||
|
|
||||||
|
def all_subclasses(cls: type[T]) -> list[type[T]]:
|
||||||
subclasses = cls.__subclasses__()
|
subclasses = cls.__subclasses__()
|
||||||
for subclass in subclasses:
|
for subclass in subclasses:
|
||||||
subclasses += _all_subclasses(subclass)
|
subclasses += all_subclasses(subclass)
|
||||||
return subclasses
|
return subclasses
|
||||||
|
|
||||||
|
|
||||||
# ============== Block access helper functions ============== #
|
|
||||||
|
|
||||||
|
|
||||||
def get_blocks() -> dict[str, Type["AnyBlockSchema"]]:
|
|
||||||
return load_all_blocks()
|
|
||||||
|
|
||||||
|
|
||||||
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
|
|
||||||
def get_block(block_id: str) -> "AnyBlockSchema | None":
|
|
||||||
cls = get_blocks().get(block_id)
|
|
||||||
return cls() if cls else None
|
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl_seconds=3600)
|
|
||||||
def get_webhook_block_ids() -> Sequence[str]:
|
|
||||||
return [
|
|
||||||
id
|
|
||||||
for id, B in get_blocks().items()
|
|
||||||
if B().block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl_seconds=3600)
|
|
||||||
def get_io_block_ids() -> Sequence[str]:
|
|
||||||
return [
|
|
||||||
id
|
|
||||||
for id, B in get_blocks().items()
|
|
||||||
if B().block_type in (BlockType.INPUT, BlockType.OUTPUT)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl_seconds=3600)
|
|
||||||
def get_human_in_the_loop_block_ids() -> Sequence[str]:
|
|
||||||
return [
|
|
||||||
id
|
|
||||||
for id, B in get_blocks().items()
|
|
||||||
if B().block_type == BlockType.HUMAN_IN_THE_LOOP
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -1,740 +0,0 @@
|
|||||||
import inspect
|
|
||||||
import logging
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from enum import Enum
|
|
||||||
from typing import (
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
ClassVar,
|
|
||||||
Generic,
|
|
||||||
Optional,
|
|
||||||
Type,
|
|
||||||
TypeAlias,
|
|
||||||
TypeVar,
|
|
||||||
cast,
|
|
||||||
get_origin,
|
|
||||||
)
|
|
||||||
|
|
||||||
import jsonref
|
|
||||||
import jsonschema
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry
|
|
||||||
from backend.data.model import (
|
|
||||||
Credentials,
|
|
||||||
CredentialsFieldInfo,
|
|
||||||
CredentialsMetaInput,
|
|
||||||
SchemaField,
|
|
||||||
is_credentials_field_name,
|
|
||||||
)
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
from backend.util import json
|
|
||||||
from backend.util.exceptions import (
|
|
||||||
BlockError,
|
|
||||||
BlockExecutionError,
|
|
||||||
BlockInputError,
|
|
||||||
BlockOutputError,
|
|
||||||
BlockUnknownError,
|
|
||||||
)
|
|
||||||
from backend.util.settings import Config
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import ContributorDetails, NodeExecutionStats
|
|
||||||
|
|
||||||
from ..data.graph import Link
|
|
||||||
|
|
||||||
app_config = Config()
|
|
||||||
|
|
||||||
|
|
||||||
BlockTestOutput = BlockOutputEntry | tuple[str, Callable[[Any], bool]]
|
|
||||||
|
|
||||||
|
|
||||||
class BlockType(Enum):
|
|
||||||
STANDARD = "Standard"
|
|
||||||
INPUT = "Input"
|
|
||||||
OUTPUT = "Output"
|
|
||||||
NOTE = "Note"
|
|
||||||
WEBHOOK = "Webhook"
|
|
||||||
WEBHOOK_MANUAL = "Webhook (manual)"
|
|
||||||
AGENT = "Agent"
|
|
||||||
AI = "AI"
|
|
||||||
AYRSHARE = "Ayrshare"
|
|
||||||
HUMAN_IN_THE_LOOP = "Human In The Loop"
|
|
||||||
MCP_TOOL = "MCP Tool"
|
|
||||||
|
|
||||||
|
|
||||||
class BlockCategory(Enum):
|
|
||||||
AI = "Block that leverages AI to perform a task."
|
|
||||||
SOCIAL = "Block that interacts with social media platforms."
|
|
||||||
TEXT = "Block that processes text data."
|
|
||||||
SEARCH = "Block that searches or extracts information from the internet."
|
|
||||||
BASIC = "Block that performs basic operations."
|
|
||||||
INPUT = "Block that interacts with input of the graph."
|
|
||||||
OUTPUT = "Block that interacts with output of the graph."
|
|
||||||
LOGIC = "Programming logic to control the flow of your agent"
|
|
||||||
COMMUNICATION = "Block that interacts with communication platforms."
|
|
||||||
DEVELOPER_TOOLS = "Developer tools such as GitHub blocks."
|
|
||||||
DATA = "Block that interacts with structured data."
|
|
||||||
HARDWARE = "Block that interacts with hardware."
|
|
||||||
AGENT = "Block that interacts with other agents."
|
|
||||||
CRM = "Block that interacts with CRM services."
|
|
||||||
SAFETY = (
|
|
||||||
"Block that provides AI safety mechanisms such as detecting harmful content"
|
|
||||||
)
|
|
||||||
PRODUCTIVITY = "Block that helps with productivity"
|
|
||||||
ISSUE_TRACKING = "Block that helps with issue tracking"
|
|
||||||
MULTIMEDIA = "Block that interacts with multimedia content"
|
|
||||||
MARKETING = "Block that helps with marketing"
|
|
||||||
|
|
||||||
def dict(self) -> dict[str, str]:
|
|
||||||
return {"category": self.name, "description": self.value}
|
|
||||||
|
|
||||||
|
|
||||||
class BlockCostType(str, Enum):
|
|
||||||
RUN = "run" # cost X credits per run
|
|
||||||
BYTE = "byte" # cost X credits per byte
|
|
||||||
SECOND = "second" # cost X credits per second
|
|
||||||
|
|
||||||
|
|
||||||
class BlockCost(BaseModel):
|
|
||||||
cost_amount: int
|
|
||||||
cost_filter: BlockInput
|
|
||||||
cost_type: BlockCostType
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
cost_amount: int,
|
|
||||||
cost_type: BlockCostType = BlockCostType.RUN,
|
|
||||||
cost_filter: Optional[BlockInput] = None,
|
|
||||||
**data: Any,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(
|
|
||||||
cost_amount=cost_amount,
|
|
||||||
cost_filter=cost_filter or {},
|
|
||||||
cost_type=cost_type,
|
|
||||||
**data,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BlockInfo(BaseModel):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
inputSchema: dict[str, Any]
|
|
||||||
outputSchema: dict[str, Any]
|
|
||||||
costs: list[BlockCost]
|
|
||||||
description: str
|
|
||||||
categories: list[dict[str, str]]
|
|
||||||
contributors: list[dict[str, Any]]
|
|
||||||
staticOutput: bool
|
|
||||||
uiType: str
|
|
||||||
|
|
||||||
|
|
||||||
class BlockSchema(BaseModel):
|
|
||||||
cached_jsonschema: ClassVar[dict[str, Any]]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def jsonschema(cls) -> dict[str, Any]:
|
|
||||||
if cls.cached_jsonschema:
|
|
||||||
return cls.cached_jsonschema
|
|
||||||
|
|
||||||
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
|
|
||||||
|
|
||||||
def ref_to_dict(obj):
|
|
||||||
if isinstance(obj, dict):
|
|
||||||
# OpenAPI <3.1 does not support sibling fields that has a $ref key
|
|
||||||
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
|
|
||||||
keys = {"allOf", "anyOf", "oneOf"}
|
|
||||||
one_key = next((k for k in keys if k in obj and len(obj[k]) == 1), None)
|
|
||||||
if one_key:
|
|
||||||
obj.update(obj[one_key][0])
|
|
||||||
|
|
||||||
return {
|
|
||||||
key: ref_to_dict(value)
|
|
||||||
for key, value in obj.items()
|
|
||||||
if not key.startswith("$") and key != one_key
|
|
||||||
}
|
|
||||||
elif isinstance(obj, list):
|
|
||||||
return [ref_to_dict(item) for item in obj]
|
|
||||||
|
|
||||||
return obj
|
|
||||||
|
|
||||||
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
|
|
||||||
|
|
||||||
return cls.cached_jsonschema
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def validate_data(cls, data: BlockInput) -> str | None:
|
|
||||||
return json.validate_with_jsonschema(
|
|
||||||
schema=cls.jsonschema(),
|
|
||||||
data={k: v for k, v in data.items() if v is not None},
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
|
||||||
return cls.validate_data(data)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_field_schema(cls, field_name: str) -> dict[str, Any]:
|
|
||||||
model_schema = cls.jsonschema().get("properties", {})
|
|
||||||
if not model_schema:
|
|
||||||
raise ValueError(f"Invalid model schema {cls}")
|
|
||||||
|
|
||||||
property_schema = model_schema.get(field_name)
|
|
||||||
if not property_schema:
|
|
||||||
raise ValueError(f"Invalid property name {field_name}")
|
|
||||||
|
|
||||||
return property_schema
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def validate_field(cls, field_name: str, data: BlockInput) -> str | None:
|
|
||||||
"""
|
|
||||||
Validate the data against a specific property (one of the input/output name).
|
|
||||||
Returns the validation error message if the data does not match the schema.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
property_schema = cls.get_field_schema(field_name)
|
|
||||||
jsonschema.validate(json.to_dict(data), property_schema)
|
|
||||||
return None
|
|
||||||
except jsonschema.ValidationError as e:
|
|
||||||
return str(e)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_fields(cls) -> set[str]:
|
|
||||||
return set(cls.model_fields.keys())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_required_fields(cls) -> set[str]:
|
|
||||||
return {
|
|
||||||
field
|
|
||||||
for field, field_info in cls.model_fields.items()
|
|
||||||
if field_info.is_required()
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __pydantic_init_subclass__(cls, **kwargs):
|
|
||||||
"""Validates the schema definition. Rules:
|
|
||||||
- Fields with annotation `CredentialsMetaInput` MUST be
|
|
||||||
named `credentials` or `*_credentials`
|
|
||||||
- Fields named `credentials` or `*_credentials` MUST be
|
|
||||||
of type `CredentialsMetaInput`
|
|
||||||
"""
|
|
||||||
super().__pydantic_init_subclass__(**kwargs)
|
|
||||||
|
|
||||||
# Reset cached JSON schema to prevent inheriting it from parent class
|
|
||||||
cls.cached_jsonschema = {}
|
|
||||||
|
|
||||||
credentials_fields = cls.get_credentials_fields()
|
|
||||||
|
|
||||||
for field_name in cls.get_fields():
|
|
||||||
if is_credentials_field_name(field_name):
|
|
||||||
if field_name not in credentials_fields:
|
|
||||||
raise TypeError(
|
|
||||||
f"Credentials field '{field_name}' on {cls.__qualname__} "
|
|
||||||
f"is not of type {CredentialsMetaInput.__name__}"
|
|
||||||
)
|
|
||||||
|
|
||||||
CredentialsMetaInput.validate_credentials_field_schema(
|
|
||||||
cls.get_field_schema(field_name), field_name
|
|
||||||
)
|
|
||||||
|
|
||||||
elif field_name in credentials_fields:
|
|
||||||
raise KeyError(
|
|
||||||
f"Credentials field '{field_name}' on {cls.__qualname__} "
|
|
||||||
"has invalid name: must be 'credentials' or *_credentials"
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_credentials_fields(cls) -> dict[str, type[CredentialsMetaInput]]:
|
|
||||||
return {
|
|
||||||
field_name: info.annotation
|
|
||||||
for field_name, info in cls.model_fields.items()
|
|
||||||
if (
|
|
||||||
inspect.isclass(info.annotation)
|
|
||||||
and issubclass(
|
|
||||||
get_origin(info.annotation) or info.annotation,
|
|
||||||
CredentialsMetaInput,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_auto_credentials_fields(cls) -> dict[str, dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Get fields that have auto_credentials metadata (e.g., GoogleDriveFileInput).
|
|
||||||
|
|
||||||
Returns a dict mapping kwarg_name -> {field_name, auto_credentials_config}
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If multiple fields have the same kwarg_name, as this would
|
|
||||||
cause silent overwriting and only the last field would be processed.
|
|
||||||
"""
|
|
||||||
result: dict[str, dict[str, Any]] = {}
|
|
||||||
schema = cls.jsonschema()
|
|
||||||
properties = schema.get("properties", {})
|
|
||||||
|
|
||||||
for field_name, field_schema in properties.items():
|
|
||||||
auto_creds = field_schema.get("auto_credentials")
|
|
||||||
if auto_creds:
|
|
||||||
kwarg_name = auto_creds.get("kwarg_name", "credentials")
|
|
||||||
if kwarg_name in result:
|
|
||||||
raise ValueError(
|
|
||||||
f"Duplicate auto_credentials kwarg_name '{kwarg_name}' "
|
|
||||||
f"in fields '{result[kwarg_name]['field_name']}' and "
|
|
||||||
f"'{field_name}' on {cls.__qualname__}"
|
|
||||||
)
|
|
||||||
result[kwarg_name] = {
|
|
||||||
"field_name": field_name,
|
|
||||||
"config": auto_creds,
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_credentials_fields_info(cls) -> dict[str, CredentialsFieldInfo]:
|
|
||||||
result = {}
|
|
||||||
|
|
||||||
# Regular credentials fields
|
|
||||||
for field_name in cls.get_credentials_fields().keys():
|
|
||||||
result[field_name] = CredentialsFieldInfo.model_validate(
|
|
||||||
cls.get_field_schema(field_name), by_alias=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Auto-generated credentials fields (from GoogleDriveFileInput etc.)
|
|
||||||
for kwarg_name, info in cls.get_auto_credentials_fields().items():
|
|
||||||
config = info["config"]
|
|
||||||
# Build a schema-like dict that CredentialsFieldInfo can parse
|
|
||||||
auto_schema = {
|
|
||||||
"credentials_provider": [config.get("provider", "google")],
|
|
||||||
"credentials_types": [config.get("type", "oauth2")],
|
|
||||||
"credentials_scopes": config.get("scopes"),
|
|
||||||
}
|
|
||||||
result[kwarg_name] = CredentialsFieldInfo.model_validate(
|
|
||||||
auto_schema, by_alias=True
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
|
||||||
return data # Return as is, by default.
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
|
|
||||||
input_fields_from_nodes = {link.sink_name for link in links}
|
|
||||||
return input_fields_from_nodes - set(data)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
|
||||||
return cls.get_required_fields() - set(data)
|
|
||||||
|
|
||||||
|
|
||||||
class BlockSchemaInput(BlockSchema):
|
|
||||||
"""
|
|
||||||
Base schema class for block inputs.
|
|
||||||
All block input schemas should extend this class for consistency.
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class BlockSchemaOutput(BlockSchema):
|
|
||||||
"""
|
|
||||||
Base schema class for block outputs that includes a standard error field.
|
|
||||||
All block output schemas should extend this class to ensure consistent error handling.
|
|
||||||
"""
|
|
||||||
|
|
||||||
error: str = SchemaField(
|
|
||||||
description="Error message if the operation failed", default=""
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchemaInput)
|
|
||||||
BlockSchemaOutputType = TypeVar("BlockSchemaOutputType", bound=BlockSchemaOutput)
|
|
||||||
|
|
||||||
|
|
||||||
class EmptyInputSchema(BlockSchemaInput):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class EmptyOutputSchema(BlockSchemaOutput):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# For backward compatibility - will be deprecated
|
|
||||||
EmptySchema = EmptyOutputSchema
|
|
||||||
|
|
||||||
|
|
||||||
# --8<-- [start:BlockWebhookConfig]
|
|
||||||
class BlockManualWebhookConfig(BaseModel):
|
|
||||||
"""
|
|
||||||
Configuration model for webhook-triggered blocks on which
|
|
||||||
the user has to manually set up the webhook at the provider.
|
|
||||||
"""
|
|
||||||
|
|
||||||
provider: ProviderName
|
|
||||||
"""The service provider that the webhook connects to"""
|
|
||||||
|
|
||||||
webhook_type: str
|
|
||||||
"""
|
|
||||||
Identifier for the webhook type. E.g. GitHub has repo and organization level hooks.
|
|
||||||
|
|
||||||
Only for use in the corresponding `WebhooksManager`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
event_filter_input: str = ""
|
|
||||||
"""
|
|
||||||
Name of the block's event filter input.
|
|
||||||
Leave empty if the corresponding webhook doesn't have distinct event/payload types.
|
|
||||||
"""
|
|
||||||
|
|
||||||
event_format: str = "{event}"
|
|
||||||
"""
|
|
||||||
Template string for the event(s) that a block instance subscribes to.
|
|
||||||
Applied individually to each event selected in the event filter input.
|
|
||||||
|
|
||||||
Example: `"pull_request.{event}"` -> `"pull_request.opened"`
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class BlockWebhookConfig(BlockManualWebhookConfig):
|
|
||||||
"""
|
|
||||||
Configuration model for webhook-triggered blocks for which
|
|
||||||
the webhook can be automatically set up through the provider's API.
|
|
||||||
"""
|
|
||||||
|
|
||||||
resource_format: str
|
|
||||||
"""
|
|
||||||
Template string for the resource that a block instance subscribes to.
|
|
||||||
Fields will be filled from the block's inputs (except `payload`).
|
|
||||||
|
|
||||||
Example: `f"{repo}/pull_requests"` (note: not how it's actually implemented)
|
|
||||||
|
|
||||||
Only for use in the corresponding `WebhooksManager`.
|
|
||||||
"""
|
|
||||||
# --8<-- [end:BlockWebhookConfig]
|
|
||||||
|
|
||||||
|
|
||||||
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
id: str = "",
|
|
||||||
description: str = "",
|
|
||||||
contributors: list["ContributorDetails"] = [],
|
|
||||||
categories: set[BlockCategory] | None = None,
|
|
||||||
input_schema: Type[BlockSchemaInputType] = EmptyInputSchema,
|
|
||||||
output_schema: Type[BlockSchemaOutputType] = EmptyOutputSchema,
|
|
||||||
test_input: BlockInput | list[BlockInput] | None = None,
|
|
||||||
test_output: BlockTestOutput | list[BlockTestOutput] | None = None,
|
|
||||||
test_mock: dict[str, Any] | None = None,
|
|
||||||
test_credentials: Optional[Credentials | dict[str, Credentials]] = None,
|
|
||||||
disabled: bool = False,
|
|
||||||
static_output: bool = False,
|
|
||||||
block_type: BlockType = BlockType.STANDARD,
|
|
||||||
webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
|
|
||||||
is_sensitive_action: bool = False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize the block with the given schema.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
id: The unique identifier for the block, this value will be persisted in the
|
|
||||||
DB. So it should be a unique and constant across the application run.
|
|
||||||
Use the UUID format for the ID.
|
|
||||||
description: The description of the block, explaining what the block does.
|
|
||||||
contributors: The list of contributors who contributed to the block.
|
|
||||||
input_schema: The schema, defined as a Pydantic model, for the input data.
|
|
||||||
output_schema: The schema, defined as a Pydantic model, for the output data.
|
|
||||||
test_input: The list or single sample input data for the block, for testing.
|
|
||||||
test_output: The list or single expected output if the test_input is run.
|
|
||||||
test_mock: function names on the block implementation to mock on test run.
|
|
||||||
disabled: If the block is disabled, it will not be available for execution.
|
|
||||||
static_output: Whether the output links of the block are static by default.
|
|
||||||
"""
|
|
||||||
from backend.data.model import NodeExecutionStats
|
|
||||||
|
|
||||||
self.id = id
|
|
||||||
self.input_schema = input_schema
|
|
||||||
self.output_schema = output_schema
|
|
||||||
self.test_input = test_input
|
|
||||||
self.test_output = test_output
|
|
||||||
self.test_mock = test_mock
|
|
||||||
self.test_credentials = test_credentials
|
|
||||||
self.description = description
|
|
||||||
self.categories = categories or set()
|
|
||||||
self.contributors = contributors or set()
|
|
||||||
self.disabled = disabled
|
|
||||||
self.static_output = static_output
|
|
||||||
self.block_type = block_type
|
|
||||||
self.webhook_config = webhook_config
|
|
||||||
self.is_sensitive_action = is_sensitive_action
|
|
||||||
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
|
|
||||||
|
|
||||||
if self.webhook_config:
|
|
||||||
if isinstance(self.webhook_config, BlockWebhookConfig):
|
|
||||||
# Enforce presence of credentials field on auto-setup webhook blocks
|
|
||||||
if not (cred_fields := self.input_schema.get_credentials_fields()):
|
|
||||||
raise TypeError(
|
|
||||||
"credentials field is required on auto-setup webhook blocks"
|
|
||||||
)
|
|
||||||
# Disallow multiple credentials inputs on webhook blocks
|
|
||||||
elif len(cred_fields) > 1:
|
|
||||||
raise ValueError(
|
|
||||||
"Multiple credentials inputs not supported on webhook blocks"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.block_type = BlockType.WEBHOOK
|
|
||||||
else:
|
|
||||||
self.block_type = BlockType.WEBHOOK_MANUAL
|
|
||||||
|
|
||||||
# Enforce shape of webhook event filter, if present
|
|
||||||
if self.webhook_config.event_filter_input:
|
|
||||||
event_filter_field = self.input_schema.model_fields[
|
|
||||||
self.webhook_config.event_filter_input
|
|
||||||
]
|
|
||||||
if not (
|
|
||||||
isinstance(event_filter_field.annotation, type)
|
|
||||||
and issubclass(event_filter_field.annotation, BaseModel)
|
|
||||||
and all(
|
|
||||||
field.annotation is bool
|
|
||||||
for field in event_filter_field.annotation.model_fields.values()
|
|
||||||
)
|
|
||||||
):
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"{self.name} has an invalid webhook event selector: "
|
|
||||||
"field must be a BaseModel and all its fields must be boolean"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Enforce presence of 'payload' input
|
|
||||||
if "payload" not in self.input_schema.model_fields:
|
|
||||||
raise TypeError(
|
|
||||||
f"{self.name} is webhook-triggered but has no 'payload' input"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Disable webhook-triggered block if webhook functionality not available
|
|
||||||
if not app_config.platform_base_url:
|
|
||||||
self.disabled = True
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def run(self, input_data: BlockSchemaInputType, **kwargs) -> BlockOutput:
|
|
||||||
"""
|
|
||||||
Run the block with the given input data.
|
|
||||||
Args:
|
|
||||||
input_data: The input data with the structure of input_schema.
|
|
||||||
|
|
||||||
Kwargs: Currently 14/02/2025 these include
|
|
||||||
graph_id: The ID of the graph.
|
|
||||||
node_id: The ID of the node.
|
|
||||||
graph_exec_id: The ID of the graph execution.
|
|
||||||
node_exec_id: The ID of the node execution.
|
|
||||||
user_id: The ID of the user.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A Generator that yields (output_name, output_data).
|
|
||||||
output_name: One of the output name defined in Block's output_schema.
|
|
||||||
output_data: The data for the output_name, matching the defined schema.
|
|
||||||
"""
|
|
||||||
# --- satisfy the type checker, never executed -------------
|
|
||||||
if False: # noqa: SIM115
|
|
||||||
yield "name", "value" # pyright: ignore[reportMissingYield]
|
|
||||||
raise NotImplementedError(f"{self.name} does not implement the run method.")
|
|
||||||
|
|
||||||
async def run_once(
|
|
||||||
self, input_data: BlockSchemaInputType, output: str, **kwargs
|
|
||||||
) -> Any:
|
|
||||||
async for item in self.run(input_data, **kwargs):
|
|
||||||
name, data = item
|
|
||||||
if name == output:
|
|
||||||
return data
|
|
||||||
raise ValueError(f"{self.name} did not produce any output for {output}")
|
|
||||||
|
|
||||||
def merge_stats(self, stats: "NodeExecutionStats") -> "NodeExecutionStats":
|
|
||||||
self.execution_stats += stats
|
|
||||||
return self.execution_stats
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
return self.__class__.__name__
|
|
||||||
|
|
||||||
def to_dict(self):
|
|
||||||
return {
|
|
||||||
"id": self.id,
|
|
||||||
"name": self.name,
|
|
||||||
"inputSchema": self.input_schema.jsonschema(),
|
|
||||||
"outputSchema": self.output_schema.jsonschema(),
|
|
||||||
"description": self.description,
|
|
||||||
"categories": [category.dict() for category in self.categories],
|
|
||||||
"contributors": [
|
|
||||||
contributor.model_dump() for contributor in self.contributors
|
|
||||||
],
|
|
||||||
"staticOutput": self.static_output,
|
|
||||||
"uiType": self.block_type.value,
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_info(self) -> BlockInfo:
|
|
||||||
from backend.data.credit import get_block_cost
|
|
||||||
|
|
||||||
return BlockInfo(
|
|
||||||
id=self.id,
|
|
||||||
name=self.name,
|
|
||||||
inputSchema=self.input_schema.jsonschema(),
|
|
||||||
outputSchema=self.output_schema.jsonschema(),
|
|
||||||
costs=get_block_cost(self),
|
|
||||||
description=self.description,
|
|
||||||
categories=[category.dict() for category in self.categories],
|
|
||||||
contributors=[
|
|
||||||
contributor.model_dump() for contributor in self.contributors
|
|
||||||
],
|
|
||||||
staticOutput=self.static_output,
|
|
||||||
uiType=self.block_type.value,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
|
||||||
try:
|
|
||||||
async for output_name, output_data in self._execute(input_data, **kwargs):
|
|
||||||
yield output_name, output_data
|
|
||||||
except Exception as ex:
|
|
||||||
if isinstance(ex, BlockError):
|
|
||||||
raise ex
|
|
||||||
else:
|
|
||||||
raise (
|
|
||||||
BlockExecutionError
|
|
||||||
if isinstance(ex, ValueError)
|
|
||||||
else BlockUnknownError
|
|
||||||
)(
|
|
||||||
message=str(ex),
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=self.id,
|
|
||||||
) from ex
|
|
||||||
|
|
||||||
async def is_block_exec_need_review(
|
|
||||||
self,
|
|
||||||
input_data: BlockInput,
|
|
||||||
*,
|
|
||||||
user_id: str,
|
|
||||||
node_id: str,
|
|
||||||
node_exec_id: str,
|
|
||||||
graph_exec_id: str,
|
|
||||||
graph_id: str,
|
|
||||||
graph_version: int,
|
|
||||||
execution_context: "ExecutionContext",
|
|
||||||
**kwargs,
|
|
||||||
) -> tuple[bool, BlockInput]:
|
|
||||||
"""
|
|
||||||
Check if this block execution needs human review and handle the review process.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (should_pause, input_data_to_use)
|
|
||||||
- should_pause: True if execution should be paused for review
|
|
||||||
- input_data_to_use: The input data to use (may be modified by reviewer)
|
|
||||||
"""
|
|
||||||
if not (
|
|
||||||
self.is_sensitive_action and execution_context.sensitive_action_safe_mode
|
|
||||||
):
|
|
||||||
return False, input_data
|
|
||||||
|
|
||||||
from backend.blocks.helpers.review import HITLReviewHelper
|
|
||||||
|
|
||||||
# Handle the review request and get decision
|
|
||||||
decision = await HITLReviewHelper.handle_review_decision(
|
|
||||||
input_data=input_data,
|
|
||||||
user_id=user_id,
|
|
||||||
node_id=node_id,
|
|
||||||
node_exec_id=node_exec_id,
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
graph_id=graph_id,
|
|
||||||
graph_version=graph_version,
|
|
||||||
block_name=self.name,
|
|
||||||
editable=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if decision is None:
|
|
||||||
# We're awaiting review - pause execution
|
|
||||||
return True, input_data
|
|
||||||
|
|
||||||
if not decision.should_proceed:
|
|
||||||
# Review was rejected, raise an error to stop execution
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"Block execution rejected by reviewer: {decision.message}",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=self.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Review was approved - use the potentially modified data
|
|
||||||
# ReviewResult.data must be a dict for block inputs
|
|
||||||
reviewed_data = decision.review_result.data
|
|
||||||
if not isinstance(reviewed_data, dict):
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=self.id,
|
|
||||||
)
|
|
||||||
return False, reviewed_data
|
|
||||||
|
|
||||||
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
|
||||||
# Check for review requirement only if running within a graph execution context
|
|
||||||
# Direct block execution (e.g., from chat) skips the review process
|
|
||||||
has_graph_context = all(
|
|
||||||
key in kwargs
|
|
||||||
for key in (
|
|
||||||
"node_exec_id",
|
|
||||||
"graph_exec_id",
|
|
||||||
"graph_id",
|
|
||||||
"execution_context",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if has_graph_context:
|
|
||||||
should_pause, input_data = await self.is_block_exec_need_review(
|
|
||||||
input_data, **kwargs
|
|
||||||
)
|
|
||||||
if should_pause:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Validate the input data (original or reviewer-modified) once
|
|
||||||
if error := self.input_schema.validate_data(input_data):
|
|
||||||
raise BlockInputError(
|
|
||||||
message=f"Unable to execute block with invalid input data: {error}",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=self.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use the validated input data
|
|
||||||
async for output_name, output_data in self.run(
|
|
||||||
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
if output_name == "error":
|
|
||||||
raise BlockExecutionError(
|
|
||||||
message=output_data, block_name=self.name, block_id=self.id
|
|
||||||
)
|
|
||||||
if self.block_type == BlockType.STANDARD and (
|
|
||||||
error := self.output_schema.validate_field(output_name, output_data)
|
|
||||||
):
|
|
||||||
raise BlockOutputError(
|
|
||||||
message=f"Block produced an invalid output data: {error}",
|
|
||||||
block_name=self.name,
|
|
||||||
block_id=self.id,
|
|
||||||
)
|
|
||||||
yield output_name, output_data
|
|
||||||
|
|
||||||
def is_triggered_by_event_type(
|
|
||||||
self, trigger_config: dict[str, Any], event_type: str
|
|
||||||
) -> bool:
|
|
||||||
if not self.webhook_config:
|
|
||||||
raise TypeError("This method can't be used on non-trigger blocks")
|
|
||||||
if not self.webhook_config.event_filter_input:
|
|
||||||
return True
|
|
||||||
event_filter = trigger_config.get(self.webhook_config.event_filter_input)
|
|
||||||
if not event_filter:
|
|
||||||
raise ValueError("Event filter is not configured on trigger")
|
|
||||||
return event_type in [
|
|
||||||
self.webhook_config.event_format.format(event=k)
|
|
||||||
for k in event_filter
|
|
||||||
if event_filter[k] is True
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# Type alias for any block with standard input/output schemas
|
|
||||||
AnyBlockSchema: TypeAlias = Block[BlockSchemaInput, BlockSchemaOutput]
|
|
||||||
@@ -1,122 +0,0 @@
|
|||||||
import logging
|
|
||||||
import os
|
|
||||||
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
|
|
||||||
from ._base import AnyBlockSchema
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def is_block_auth_configured(
|
|
||||||
block_cls: type[AnyBlockSchema],
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Check if a block has a valid authentication method configured at runtime.
|
|
||||||
|
|
||||||
For example if a block is an OAuth-only block and there env vars are not set,
|
|
||||||
do not show it in the UI.
|
|
||||||
|
|
||||||
"""
|
|
||||||
from backend.sdk.registry import AutoRegistry
|
|
||||||
|
|
||||||
# Create an instance to access input_schema
|
|
||||||
try:
|
|
||||||
block = block_cls()
|
|
||||||
except Exception as e:
|
|
||||||
# If we can't create a block instance, assume it's not OAuth-only
|
|
||||||
logger.error(f"Error creating block instance for {block_cls.__name__}: {e}")
|
|
||||||
return True
|
|
||||||
logger.debug(
|
|
||||||
f"Checking if block {block_cls.__name__} has a valid provider configured"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get all credential inputs from input schema
|
|
||||||
credential_inputs = block.input_schema.get_credentials_fields_info()
|
|
||||||
required_inputs = block.input_schema.get_required_fields()
|
|
||||||
if not credential_inputs:
|
|
||||||
logger.debug(
|
|
||||||
f"Block {block_cls.__name__} has no credential inputs - Treating as valid"
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Check credential inputs
|
|
||||||
if len(required_inputs.intersection(credential_inputs.keys())) == 0:
|
|
||||||
logger.debug(
|
|
||||||
f"Block {block_cls.__name__} has only optional credential inputs"
|
|
||||||
" - will work without credentials configured"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if the credential inputs for this block are correctly configured
|
|
||||||
for field_name, field_info in credential_inputs.items():
|
|
||||||
provider_names = field_info.provider
|
|
||||||
if not provider_names:
|
|
||||||
logger.warning(
|
|
||||||
f"Block {block_cls.__name__} "
|
|
||||||
f"has credential input '{field_name}' with no provider options"
|
|
||||||
" - Disabling"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
# If a field has multiple possible providers, each one needs to be usable to
|
|
||||||
# prevent breaking the UX
|
|
||||||
for _provider_name in provider_names:
|
|
||||||
provider_name = _provider_name.value
|
|
||||||
if provider_name in ProviderName.__members__.values():
|
|
||||||
logger.debug(
|
|
||||||
f"Block {block_cls.__name__} credential input '{field_name}' "
|
|
||||||
f"provider '{provider_name}' is part of the legacy provider system"
|
|
||||||
" - Treating as valid"
|
|
||||||
)
|
|
||||||
break
|
|
||||||
|
|
||||||
provider = AutoRegistry.get_provider(provider_name)
|
|
||||||
if not provider:
|
|
||||||
logger.warning(
|
|
||||||
f"Block {block_cls.__name__} credential input '{field_name}' "
|
|
||||||
f"refers to unknown provider '{provider_name}' - Disabling"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check the provider's supported auth types
|
|
||||||
if field_info.supported_types != provider.supported_auth_types:
|
|
||||||
logger.warning(
|
|
||||||
f"Block {block_cls.__name__} credential input '{field_name}' "
|
|
||||||
f"has mismatched supported auth types (field <> Provider): "
|
|
||||||
f"{field_info.supported_types} != {provider.supported_auth_types}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not (supported_auth_types := provider.supported_auth_types):
|
|
||||||
# No auth methods are been configured for this provider
|
|
||||||
logger.warning(
|
|
||||||
f"Block {block_cls.__name__} credential input '{field_name}' "
|
|
||||||
f"provider '{provider_name}' "
|
|
||||||
"has no authentication methods configured - Disabling"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check if provider supports OAuth
|
|
||||||
if "oauth2" in supported_auth_types:
|
|
||||||
# Check if OAuth environment variables are set
|
|
||||||
if (oauth_config := provider.oauth_config) and bool(
|
|
||||||
os.getenv(oauth_config.client_id_env_var)
|
|
||||||
and os.getenv(oauth_config.client_secret_env_var)
|
|
||||||
):
|
|
||||||
logger.debug(
|
|
||||||
f"Block {block_cls.__name__} credential input '{field_name}' "
|
|
||||||
f"provider '{provider_name}' is configured for OAuth"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.error(
|
|
||||||
f"Block {block_cls.__name__} credential input '{field_name}' "
|
|
||||||
f"provider '{provider_name}' "
|
|
||||||
"is missing OAuth client ID or secret - Disabling"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"Block {block_cls.__name__} credential input '{field_name}' is valid; "
|
|
||||||
f"supported credential types: {', '.join(field_info.supported_types)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockInput,
|
BlockInput,
|
||||||
@@ -9,15 +9,13 @@ from backend.blocks._base import (
|
|||||||
BlockSchema,
|
BlockSchema,
|
||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockType,
|
BlockType,
|
||||||
|
get_block,
|
||||||
)
|
)
|
||||||
from backend.data.execution import ExecutionContext, ExecutionStatus, NodesInputMasks
|
from backend.data.execution import ExecutionContext, ExecutionStatus, NodesInputMasks
|
||||||
from backend.data.model import NodeExecutionStats, SchemaField
|
from backend.data.model import NodeExecutionStats, SchemaField
|
||||||
from backend.util.json import validate_with_jsonschema
|
from backend.util.json import validate_with_jsonschema
|
||||||
from backend.util.retry import func_retry
|
from backend.util.retry import func_retry
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from backend.executor.utils import LogMetadata
|
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -126,10 +124,9 @@ class AgentExecutorBlock(Block):
|
|||||||
graph_version: int,
|
graph_version: int,
|
||||||
graph_exec_id: str,
|
graph_exec_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
logger: "LogMetadata",
|
logger,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
|
|
||||||
from backend.blocks import get_block
|
|
||||||
from backend.data.execution import ExecutionEventType
|
from backend.data.execution import ExecutionEventType
|
||||||
from backend.executor import utils as execution_utils
|
from backend.executor import utils as execution_utils
|
||||||
|
|
||||||
@@ -201,7 +198,7 @@ class AgentExecutorBlock(Block):
|
|||||||
self,
|
self,
|
||||||
graph_exec_id: str,
|
graph_exec_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
logger: "LogMetadata",
|
logger,
|
||||||
) -> None:
|
) -> None:
|
||||||
from backend.executor import utils as execution_utils
|
from backend.executor import utils as execution_utils
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,5 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.blocks._base import (
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.blocks.llm import (
|
from backend.blocks.llm import (
|
||||||
DEFAULT_LLM_MODEL,
|
DEFAULT_LLM_MODEL,
|
||||||
TEST_CREDENTIALS,
|
TEST_CREDENTIALS,
|
||||||
@@ -17,6 +11,12 @@ from backend.blocks.llm import (
|
|||||||
LLMResponse,
|
LLMResponse,
|
||||||
llm_call,
|
llm_call,
|
||||||
)
|
)
|
||||||
|
from backend.data.block import (
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
|
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from pydantic import SecretStr
|
|||||||
from replicate.client import Client as ReplicateClient
|
from replicate.client import Client as ReplicateClient
|
||||||
from replicate.helpers import FileOutput
|
from replicate.helpers import FileOutput
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -5,12 +5,7 @@ from pydantic import SecretStr
|
|||||||
from replicate.client import Client as ReplicateClient
|
from replicate.client import Client as ReplicateClient
|
||||||
from replicate.helpers import FileOutput
|
from replicate.helpers import FileOutput
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Literal
|
|||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
from replicate.client import Client as ReplicateClient
|
from replicate.client import Client as ReplicateClient
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Literal
|
|||||||
|
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,10 +1,3 @@
|
|||||||
from backend.blocks._base import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.blocks.apollo._api import ApolloClient
|
from backend.blocks.apollo._api import ApolloClient
|
||||||
from backend.blocks.apollo._auth import (
|
from backend.blocks.apollo._auth import (
|
||||||
TEST_CREDENTIALS,
|
TEST_CREDENTIALS,
|
||||||
@@ -17,6 +10,13 @@ from backend.blocks.apollo.models import (
|
|||||||
PrimaryPhone,
|
PrimaryPhone,
|
||||||
SearchOrganizationsRequest,
|
SearchOrganizationsRequest,
|
||||||
)
|
)
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
from backend.data.model import CredentialsField, SchemaField
|
from backend.data.model import CredentialsField, SchemaField
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from backend.blocks._base import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.blocks.apollo._api import ApolloClient
|
from backend.blocks.apollo._api import ApolloClient
|
||||||
from backend.blocks.apollo._auth import (
|
from backend.blocks.apollo._auth import (
|
||||||
TEST_CREDENTIALS,
|
TEST_CREDENTIALS,
|
||||||
@@ -21,6 +14,13 @@ from backend.blocks.apollo.models import (
|
|||||||
SearchPeopleRequest,
|
SearchPeopleRequest,
|
||||||
SenorityLevels,
|
SenorityLevels,
|
||||||
)
|
)
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
from backend.data.model import CredentialsField, SchemaField
|
from backend.data.model import CredentialsField, SchemaField
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,3 @@
|
|||||||
from backend.blocks._base import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.blocks.apollo._api import ApolloClient
|
from backend.blocks.apollo._api import ApolloClient
|
||||||
from backend.blocks.apollo._auth import (
|
from backend.blocks.apollo._auth import (
|
||||||
TEST_CREDENTIALS,
|
TEST_CREDENTIALS,
|
||||||
@@ -13,6 +6,13 @@ from backend.blocks.apollo._auth import (
|
|||||||
ApolloCredentialsInput,
|
ApolloCredentialsInput,
|
||||||
)
|
)
|
||||||
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
|
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
from backend.data.model import CredentialsField, SchemaField
|
from backend.data.model import CredentialsField, SchemaField
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from backend.blocks._base import BlockSchemaInput
|
from backend.data.block import BlockSchemaInput
|
||||||
from backend.data.model import SchemaField, UserIntegrations
|
from backend.data.model import SchemaField, UserIntegrations
|
||||||
from backend.integrations.ayrshare import AyrshareClient
|
from backend.integrations.ayrshare import AyrshareClient
|
||||||
from backend.util.clients import get_database_manager_async_client
|
from backend.util.clients import get_database_manager_async_client
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import enum
|
import enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
|
import base64
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import shlex
|
import shlex
|
||||||
import uuid
|
import uuid
|
||||||
from typing import TYPE_CHECKING, Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
from e2b import AsyncSandbox as BaseAsyncSandbox
|
from e2b import AsyncSandbox as BaseAsyncSandbox
|
||||||
from pydantic import SecretStr
|
from pydantic import BaseModel, SecretStr
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
@@ -20,13 +22,11 @@ from backend.data.model import (
|
|||||||
SchemaField,
|
SchemaField,
|
||||||
)
|
)
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.sandbox_files import (
|
|
||||||
SandboxFileOutput,
|
|
||||||
extract_and_store_sandbox_files,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
logger = logging.getLogger(__name__)
|
||||||
from backend.executor.utils import ExecutionContext
|
|
||||||
|
# Maximum size for binary files to extract (50MB)
|
||||||
|
MAX_BINARY_FILE_SIZE = 50 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
class ClaudeCodeExecutionError(Exception):
|
class ClaudeCodeExecutionError(Exception):
|
||||||
@@ -181,15 +181,27 @@ class ClaudeCodeBlock(Block):
|
|||||||
advanced=True,
|
advanced=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class FileOutput(BaseModel):
|
||||||
|
"""A file extracted from the sandbox."""
|
||||||
|
|
||||||
|
path: str
|
||||||
|
relative_path: str # Path relative to working directory (for GitHub, etc.)
|
||||||
|
name: str
|
||||||
|
content: str # Text content for text files, empty string for binary files
|
||||||
|
is_binary: bool = False # True if this is a binary file
|
||||||
|
content_base64: Optional[str] = None # Base64-encoded content for binary files
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
class Output(BlockSchemaOutput):
|
||||||
response: str = SchemaField(
|
response: str = SchemaField(
|
||||||
description="The output/response from Claude Code execution"
|
description="The output/response from Claude Code execution"
|
||||||
)
|
)
|
||||||
files: list[SandboxFileOutput] = SchemaField(
|
files: list["ClaudeCodeBlock.FileOutput"] = SchemaField(
|
||||||
description=(
|
description=(
|
||||||
"List of text files created/modified by Claude Code during this execution. "
|
"List of files created/modified by Claude Code during this execution. "
|
||||||
"Each file has 'path', 'relative_path', 'name', 'content', and 'workspace_ref' fields. "
|
"Each file has 'path', 'relative_path', 'name', 'content', 'is_binary', "
|
||||||
"workspace_ref contains a workspace:// URI if the file was stored to workspace."
|
"and 'content_base64' fields. For text files, 'content' contains the text "
|
||||||
|
"and 'is_binary' is False. For binary files (PDFs, images, etc.), "
|
||||||
|
"'is_binary' is True and 'content_base64' contains the base64-encoded data."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
conversation_history: str = SchemaField(
|
conversation_history: str = SchemaField(
|
||||||
@@ -252,7 +264,8 @@ class ClaudeCodeBlock(Block):
|
|||||||
"relative_path": "index.html",
|
"relative_path": "index.html",
|
||||||
"name": "index.html",
|
"name": "index.html",
|
||||||
"content": "<html>Hello World</html>",
|
"content": "<html>Hello World</html>",
|
||||||
"workspace_ref": None,
|
"is_binary": False,
|
||||||
|
"content_base64": None,
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
@@ -268,12 +281,13 @@ class ClaudeCodeBlock(Block):
|
|||||||
"execute_claude_code": lambda *args, **kwargs: (
|
"execute_claude_code": lambda *args, **kwargs: (
|
||||||
"Created index.html with hello world content", # response
|
"Created index.html with hello world content", # response
|
||||||
[
|
[
|
||||||
SandboxFileOutput(
|
ClaudeCodeBlock.FileOutput(
|
||||||
path="/home/user/index.html",
|
path="/home/user/index.html",
|
||||||
relative_path="index.html",
|
relative_path="index.html",
|
||||||
name="index.html",
|
name="index.html",
|
||||||
content="<html>Hello World</html>",
|
content="<html>Hello World</html>",
|
||||||
workspace_ref=None,
|
is_binary=False,
|
||||||
|
content_base64=None,
|
||||||
)
|
)
|
||||||
], # files
|
], # files
|
||||||
"User: Create a hello world HTML file\n"
|
"User: Create a hello world HTML file\n"
|
||||||
@@ -296,8 +310,7 @@ class ClaudeCodeBlock(Block):
|
|||||||
existing_sandbox_id: str,
|
existing_sandbox_id: str,
|
||||||
conversation_history: str,
|
conversation_history: str,
|
||||||
dispose_sandbox: bool,
|
dispose_sandbox: bool,
|
||||||
execution_context: "ExecutionContext",
|
) -> tuple[str, list["ClaudeCodeBlock.FileOutput"], str, str, str]:
|
||||||
) -> tuple[str, list[SandboxFileOutput], str, str, str]:
|
|
||||||
"""
|
"""
|
||||||
Execute Claude Code in an E2B sandbox.
|
Execute Claude Code in an E2B sandbox.
|
||||||
|
|
||||||
@@ -452,18 +465,14 @@ class ClaudeCodeBlock(Block):
|
|||||||
else:
|
else:
|
||||||
new_conversation_history = turn_entry
|
new_conversation_history = turn_entry
|
||||||
|
|
||||||
# Extract files created/modified during this run and store to workspace
|
# Extract files created/modified during this run
|
||||||
sandbox_files = await extract_and_store_sandbox_files(
|
files = await self._extract_files(
|
||||||
sandbox=sandbox,
|
sandbox, working_directory, start_timestamp
|
||||||
working_directory=working_directory,
|
|
||||||
execution_context=execution_context,
|
|
||||||
since_timestamp=start_timestamp,
|
|
||||||
text_only=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
response,
|
response,
|
||||||
sandbox_files, # Already SandboxFileOutput objects
|
files,
|
||||||
new_conversation_history,
|
new_conversation_history,
|
||||||
current_session_id,
|
current_session_id,
|
||||||
sandbox_id,
|
sandbox_id,
|
||||||
@@ -478,6 +487,233 @@ class ClaudeCodeBlock(Block):
|
|||||||
if dispose_sandbox and sandbox:
|
if dispose_sandbox and sandbox:
|
||||||
await sandbox.kill()
|
await sandbox.kill()
|
||||||
|
|
||||||
|
async def _extract_files(
|
||||||
|
self,
|
||||||
|
sandbox: BaseAsyncSandbox,
|
||||||
|
working_directory: str,
|
||||||
|
since_timestamp: str | None = None,
|
||||||
|
) -> list["ClaudeCodeBlock.FileOutput"]:
|
||||||
|
"""
|
||||||
|
Extract text files created/modified during this Claude Code execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sandbox: The E2B sandbox instance
|
||||||
|
working_directory: Directory to search for files
|
||||||
|
since_timestamp: ISO timestamp - only return files modified after this time
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of FileOutput objects with path, relative_path, name, and content
|
||||||
|
"""
|
||||||
|
files: list[ClaudeCodeBlock.FileOutput] = []
|
||||||
|
|
||||||
|
# Text file extensions we can safely read as text
|
||||||
|
text_extensions = {
|
||||||
|
".txt",
|
||||||
|
".md",
|
||||||
|
".html",
|
||||||
|
".htm",
|
||||||
|
".css",
|
||||||
|
".js",
|
||||||
|
".ts",
|
||||||
|
".jsx",
|
||||||
|
".tsx",
|
||||||
|
".json",
|
||||||
|
".xml",
|
||||||
|
".yaml",
|
||||||
|
".yml",
|
||||||
|
".toml",
|
||||||
|
".ini",
|
||||||
|
".cfg",
|
||||||
|
".conf",
|
||||||
|
".py",
|
||||||
|
".rb",
|
||||||
|
".php",
|
||||||
|
".java",
|
||||||
|
".c",
|
||||||
|
".cpp",
|
||||||
|
".h",
|
||||||
|
".hpp",
|
||||||
|
".cs",
|
||||||
|
".go",
|
||||||
|
".rs",
|
||||||
|
".swift",
|
||||||
|
".kt",
|
||||||
|
".scala",
|
||||||
|
".sh",
|
||||||
|
".bash",
|
||||||
|
".zsh",
|
||||||
|
".sql",
|
||||||
|
".graphql",
|
||||||
|
".env",
|
||||||
|
".gitignore",
|
||||||
|
".dockerfile",
|
||||||
|
".vue",
|
||||||
|
".svelte",
|
||||||
|
".astro",
|
||||||
|
".mdx",
|
||||||
|
".rst",
|
||||||
|
".tex",
|
||||||
|
".csv",
|
||||||
|
".log",
|
||||||
|
".svg", # SVG is XML-based text
|
||||||
|
}
|
||||||
|
|
||||||
|
# Binary file extensions we can read and base64-encode
|
||||||
|
binary_extensions = {
|
||||||
|
# Images
|
||||||
|
".png",
|
||||||
|
".jpg",
|
||||||
|
".jpeg",
|
||||||
|
".gif",
|
||||||
|
".webp",
|
||||||
|
".ico",
|
||||||
|
".bmp",
|
||||||
|
".tiff",
|
||||||
|
".tif",
|
||||||
|
# Documents
|
||||||
|
".pdf",
|
||||||
|
# Archives (useful for downloads)
|
||||||
|
".zip",
|
||||||
|
".tar",
|
||||||
|
".gz",
|
||||||
|
".7z",
|
||||||
|
# Audio/Video (if small enough)
|
||||||
|
".mp3",
|
||||||
|
".wav",
|
||||||
|
".mp4",
|
||||||
|
".webm",
|
||||||
|
# Other binary formats
|
||||||
|
".woff",
|
||||||
|
".woff2",
|
||||||
|
".ttf",
|
||||||
|
".otf",
|
||||||
|
".eot",
|
||||||
|
".bin",
|
||||||
|
".exe",
|
||||||
|
".dll",
|
||||||
|
".so",
|
||||||
|
".dylib",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# List files recursively using find command
|
||||||
|
# Exclude node_modules and .git directories, but allow hidden files
|
||||||
|
# like .env and .gitignore (they're filtered by text_extensions later)
|
||||||
|
# Filter by timestamp to only get files created/modified during this run
|
||||||
|
safe_working_dir = shlex.quote(working_directory)
|
||||||
|
timestamp_filter = ""
|
||||||
|
if since_timestamp:
|
||||||
|
timestamp_filter = f"-newermt {shlex.quote(since_timestamp)} "
|
||||||
|
find_result = await sandbox.commands.run(
|
||||||
|
f"find {safe_working_dir} -type f "
|
||||||
|
f"{timestamp_filter}"
|
||||||
|
f"-not -path '*/node_modules/*' "
|
||||||
|
f"-not -path '*/.git/*' "
|
||||||
|
f"2>/dev/null"
|
||||||
|
)
|
||||||
|
|
||||||
|
if find_result.stdout:
|
||||||
|
for file_path in find_result.stdout.strip().split("\n"):
|
||||||
|
if not file_path:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if it's a text file we can read (case-insensitive)
|
||||||
|
file_path_lower = file_path.lower()
|
||||||
|
is_text = any(
|
||||||
|
file_path_lower.endswith(ext) for ext in text_extensions
|
||||||
|
) or file_path_lower.endswith("dockerfile")
|
||||||
|
|
||||||
|
# Check if it's a binary file we should extract
|
||||||
|
is_binary = any(
|
||||||
|
file_path_lower.endswith(ext) for ext in binary_extensions
|
||||||
|
)
|
||||||
|
|
||||||
|
# Helper to extract filename and relative path
|
||||||
|
def get_file_info(path: str, work_dir: str) -> tuple[str, str]:
|
||||||
|
name = path.split("/")[-1]
|
||||||
|
rel_path = path
|
||||||
|
if path.startswith(work_dir):
|
||||||
|
rel_path = path[len(work_dir) :]
|
||||||
|
if rel_path.startswith("/"):
|
||||||
|
rel_path = rel_path[1:]
|
||||||
|
return name, rel_path
|
||||||
|
|
||||||
|
if is_text:
|
||||||
|
try:
|
||||||
|
content = await sandbox.files.read(file_path)
|
||||||
|
# Handle bytes or string
|
||||||
|
if isinstance(content, bytes):
|
||||||
|
content = content.decode("utf-8", errors="replace")
|
||||||
|
|
||||||
|
file_name, relative_path = get_file_info(
|
||||||
|
file_path, working_directory
|
||||||
|
)
|
||||||
|
files.append(
|
||||||
|
ClaudeCodeBlock.FileOutput(
|
||||||
|
path=file_path,
|
||||||
|
relative_path=relative_path,
|
||||||
|
name=file_name,
|
||||||
|
content=content,
|
||||||
|
is_binary=False,
|
||||||
|
content_base64=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to read text file {file_path}: {e}")
|
||||||
|
elif is_binary:
|
||||||
|
try:
|
||||||
|
# Check file size before reading to avoid OOM
|
||||||
|
stat_result = await sandbox.commands.run(
|
||||||
|
f"stat -c %s {shlex.quote(file_path)} 2>/dev/null"
|
||||||
|
)
|
||||||
|
if stat_result.exit_code != 0 or not stat_result.stdout:
|
||||||
|
logger.warning(
|
||||||
|
f"Skipping binary file {file_path}: "
|
||||||
|
f"could not determine file size"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
file_size = int(stat_result.stdout.strip())
|
||||||
|
if file_size > MAX_BINARY_FILE_SIZE:
|
||||||
|
logger.warning(
|
||||||
|
f"Skipping binary file {file_path}: "
|
||||||
|
f"size {file_size} exceeds limit "
|
||||||
|
f"{MAX_BINARY_FILE_SIZE}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Read binary file as bytes using format="bytes"
|
||||||
|
content_bytes = await sandbox.files.read(
|
||||||
|
file_path, format="bytes"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Base64 encode the binary content
|
||||||
|
content_b64 = base64.b64encode(content_bytes).decode(
|
||||||
|
"ascii"
|
||||||
|
)
|
||||||
|
|
||||||
|
file_name, relative_path = get_file_info(
|
||||||
|
file_path, working_directory
|
||||||
|
)
|
||||||
|
files.append(
|
||||||
|
ClaudeCodeBlock.FileOutput(
|
||||||
|
path=file_path,
|
||||||
|
relative_path=relative_path,
|
||||||
|
name=file_name,
|
||||||
|
content="", # Empty for binary files
|
||||||
|
is_binary=True,
|
||||||
|
content_base64=content_b64,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to read binary file {file_path}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"File extraction failed: {e}")
|
||||||
|
|
||||||
|
return files
|
||||||
|
|
||||||
def _escape_prompt(self, prompt: str) -> str:
|
def _escape_prompt(self, prompt: str) -> str:
|
||||||
"""Escape the prompt for safe shell execution."""
|
"""Escape the prompt for safe shell execution."""
|
||||||
# Use single quotes and escape any single quotes in the prompt
|
# Use single quotes and escape any single quotes in the prompt
|
||||||
@@ -490,7 +726,6 @@ class ClaudeCodeBlock(Block):
|
|||||||
*,
|
*,
|
||||||
e2b_credentials: APIKeyCredentials,
|
e2b_credentials: APIKeyCredentials,
|
||||||
anthropic_credentials: APIKeyCredentials,
|
anthropic_credentials: APIKeyCredentials,
|
||||||
execution_context: "ExecutionContext",
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
try:
|
||||||
@@ -511,7 +746,6 @@ class ClaudeCodeBlock(Block):
|
|||||||
existing_sandbox_id=input_data.sandbox_id,
|
existing_sandbox_id=input_data.sandbox_id,
|
||||||
conversation_history=input_data.conversation_history,
|
conversation_history=input_data.conversation_history,
|
||||||
dispose_sandbox=input_data.dispose_sandbox,
|
dispose_sandbox=input_data.dispose_sandbox,
|
||||||
execution_context=execution_context,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
yield "response", response
|
yield "response", response
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
from e2b_code_interpreter import AsyncSandbox
|
from e2b_code_interpreter import AsyncSandbox
|
||||||
from e2b_code_interpreter import Result as E2BExecutionResult
|
from e2b_code_interpreter import Result as E2BExecutionResult
|
||||||
from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
|
from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
|
||||||
from pydantic import BaseModel, Field, JsonValue, SecretStr
|
from pydantic import BaseModel, Field, JsonValue, SecretStr
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
@@ -20,13 +20,6 @@ from backend.data.model import (
|
|||||||
SchemaField,
|
SchemaField,
|
||||||
)
|
)
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.sandbox_files import (
|
|
||||||
SandboxFileOutput,
|
|
||||||
extract_and_store_sandbox_files,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from backend.executor.utils import ExecutionContext
|
|
||||||
|
|
||||||
TEST_CREDENTIALS = APIKeyCredentials(
|
TEST_CREDENTIALS = APIKeyCredentials(
|
||||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
@@ -92,9 +85,6 @@ class CodeExecutionResult(MainCodeExecutionResult):
|
|||||||
class BaseE2BExecutorMixin:
|
class BaseE2BExecutorMixin:
|
||||||
"""Shared implementation methods for E2B executor blocks."""
|
"""Shared implementation methods for E2B executor blocks."""
|
||||||
|
|
||||||
# Default working directory in E2B sandboxes
|
|
||||||
WORKING_DIR = "/home/user"
|
|
||||||
|
|
||||||
async def execute_code(
|
async def execute_code(
|
||||||
self,
|
self,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
@@ -105,21 +95,14 @@ class BaseE2BExecutorMixin:
|
|||||||
timeout: Optional[int] = None,
|
timeout: Optional[int] = None,
|
||||||
sandbox_id: Optional[str] = None,
|
sandbox_id: Optional[str] = None,
|
||||||
dispose_sandbox: bool = False,
|
dispose_sandbox: bool = False,
|
||||||
execution_context: Optional["ExecutionContext"] = None,
|
|
||||||
extract_files: bool = False,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Unified code execution method that handles all three use cases:
|
Unified code execution method that handles all three use cases:
|
||||||
1. Create new sandbox and execute (ExecuteCodeBlock)
|
1. Create new sandbox and execute (ExecuteCodeBlock)
|
||||||
2. Create new sandbox, execute, and return sandbox_id (InstantiateCodeSandboxBlock)
|
2. Create new sandbox, execute, and return sandbox_id (InstantiateCodeSandboxBlock)
|
||||||
3. Connect to existing sandbox and execute (ExecuteCodeStepBlock)
|
3. Connect to existing sandbox and execute (ExecuteCodeStepBlock)
|
||||||
|
|
||||||
Args:
|
|
||||||
extract_files: If True and execution_context provided, extract files
|
|
||||||
created/modified during execution and store to workspace.
|
|
||||||
""" # noqa
|
""" # noqa
|
||||||
sandbox = None
|
sandbox = None
|
||||||
files: list[SandboxFileOutput] = []
|
|
||||||
try:
|
try:
|
||||||
if sandbox_id:
|
if sandbox_id:
|
||||||
# Connect to existing sandbox (ExecuteCodeStepBlock case)
|
# Connect to existing sandbox (ExecuteCodeStepBlock case)
|
||||||
@@ -135,12 +118,6 @@ class BaseE2BExecutorMixin:
|
|||||||
for cmd in setup_commands:
|
for cmd in setup_commands:
|
||||||
await sandbox.commands.run(cmd)
|
await sandbox.commands.run(cmd)
|
||||||
|
|
||||||
# Capture timestamp before execution to scope file extraction
|
|
||||||
start_timestamp = None
|
|
||||||
if extract_files:
|
|
||||||
ts_result = await sandbox.commands.run("date -u +%Y-%m-%dT%H:%M:%S")
|
|
||||||
start_timestamp = ts_result.stdout.strip() if ts_result.stdout else None
|
|
||||||
|
|
||||||
# Execute the code
|
# Execute the code
|
||||||
execution = await sandbox.run_code(
|
execution = await sandbox.run_code(
|
||||||
code,
|
code,
|
||||||
@@ -156,24 +133,7 @@ class BaseE2BExecutorMixin:
|
|||||||
stdout_logs = "".join(execution.logs.stdout)
|
stdout_logs = "".join(execution.logs.stdout)
|
||||||
stderr_logs = "".join(execution.logs.stderr)
|
stderr_logs = "".join(execution.logs.stderr)
|
||||||
|
|
||||||
# Extract files created/modified during this execution
|
return results, text_output, stdout_logs, stderr_logs, sandbox.sandbox_id
|
||||||
if extract_files and execution_context:
|
|
||||||
files = await extract_and_store_sandbox_files(
|
|
||||||
sandbox=sandbox,
|
|
||||||
working_directory=self.WORKING_DIR,
|
|
||||||
execution_context=execution_context,
|
|
||||||
since_timestamp=start_timestamp,
|
|
||||||
text_only=False, # Include binary files too
|
|
||||||
)
|
|
||||||
|
|
||||||
return (
|
|
||||||
results,
|
|
||||||
text_output,
|
|
||||||
stdout_logs,
|
|
||||||
stderr_logs,
|
|
||||||
sandbox.sandbox_id,
|
|
||||||
files,
|
|
||||||
)
|
|
||||||
finally:
|
finally:
|
||||||
# Dispose of sandbox if requested to reduce usage costs
|
# Dispose of sandbox if requested to reduce usage costs
|
||||||
if dispose_sandbox and sandbox:
|
if dispose_sandbox and sandbox:
|
||||||
@@ -278,12 +238,6 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
|
|||||||
description="Standard output logs from execution"
|
description="Standard output logs from execution"
|
||||||
)
|
)
|
||||||
stderr_logs: str = SchemaField(description="Standard error logs from execution")
|
stderr_logs: str = SchemaField(description="Standard error logs from execution")
|
||||||
files: list[SandboxFileOutput] = SchemaField(
|
|
||||||
description=(
|
|
||||||
"Files created or modified during execution. "
|
|
||||||
"Each file has path, name, content, and workspace_ref (if stored)."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -305,30 +259,23 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
|
|||||||
("results", []),
|
("results", []),
|
||||||
("response", "Hello World"),
|
("response", "Hello World"),
|
||||||
("stdout_logs", "Hello World\n"),
|
("stdout_logs", "Hello World\n"),
|
||||||
("files", []),
|
|
||||||
],
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
"execute_code": lambda api_key, code, language, template_id, setup_commands, timeout, dispose_sandbox, execution_context, extract_files: ( # noqa
|
"execute_code": lambda api_key, code, language, template_id, setup_commands, timeout, dispose_sandbox: ( # noqa
|
||||||
[], # results
|
[], # results
|
||||||
"Hello World", # text_output
|
"Hello World", # text_output
|
||||||
"Hello World\n", # stdout_logs
|
"Hello World\n", # stdout_logs
|
||||||
"", # stderr_logs
|
"", # stderr_logs
|
||||||
"sandbox_id", # sandbox_id
|
"sandbox_id", # sandbox_id
|
||||||
[], # files
|
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
credentials: APIKeyCredentials,
|
|
||||||
execution_context: "ExecutionContext",
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
try:
|
||||||
results, text_output, stdout, stderr, _, files = await self.execute_code(
|
results, text_output, stdout, stderr, _ = await self.execute_code(
|
||||||
api_key=credentials.api_key.get_secret_value(),
|
api_key=credentials.api_key.get_secret_value(),
|
||||||
code=input_data.code,
|
code=input_data.code,
|
||||||
language=input_data.language,
|
language=input_data.language,
|
||||||
@@ -336,8 +283,6 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
|
|||||||
setup_commands=input_data.setup_commands,
|
setup_commands=input_data.setup_commands,
|
||||||
timeout=input_data.timeout,
|
timeout=input_data.timeout,
|
||||||
dispose_sandbox=input_data.dispose_sandbox,
|
dispose_sandbox=input_data.dispose_sandbox,
|
||||||
execution_context=execution_context,
|
|
||||||
extract_files=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Determine result object shape & filter out empty formats
|
# Determine result object shape & filter out empty formats
|
||||||
@@ -351,8 +296,6 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
|
|||||||
yield "stdout_logs", stdout
|
yield "stdout_logs", stdout
|
||||||
if stderr:
|
if stderr:
|
||||||
yield "stderr_logs", stderr
|
yield "stderr_logs", stderr
|
||||||
# Always yield files (empty list if none)
|
|
||||||
yield "files", [f.model_dump() for f in files]
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield "error", str(e)
|
yield "error", str(e)
|
||||||
|
|
||||||
@@ -450,7 +393,6 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
|
|||||||
"Hello World\n", # stdout_logs
|
"Hello World\n", # stdout_logs
|
||||||
"", # stderr_logs
|
"", # stderr_logs
|
||||||
"sandbox_id", # sandbox_id
|
"sandbox_id", # sandbox_id
|
||||||
[], # files
|
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -459,7 +401,7 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
|
|||||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
try:
|
||||||
_, text_output, stdout, stderr, sandbox_id, _ = await self.execute_code(
|
_, text_output, stdout, stderr, sandbox_id = await self.execute_code(
|
||||||
api_key=credentials.api_key.get_secret_value(),
|
api_key=credentials.api_key.get_secret_value(),
|
||||||
code=input_data.setup_code,
|
code=input_data.setup_code,
|
||||||
language=input_data.language,
|
language=input_data.language,
|
||||||
@@ -558,7 +500,6 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
|
|||||||
"Hello World\n", # stdout_logs
|
"Hello World\n", # stdout_logs
|
||||||
"", # stderr_logs
|
"", # stderr_logs
|
||||||
sandbox_id, # sandbox_id
|
sandbox_id, # sandbox_id
|
||||||
[], # files
|
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -567,7 +508,7 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
|
|||||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
try:
|
||||||
results, text_output, stdout, stderr, _, _ = await self.execute_code(
|
results, text_output, stdout, stderr, _ = await self.execute_code(
|
||||||
api_key=credentials.api_key.get_secret_value(),
|
api_key=credentials.api_key.get_secret_value(),
|
||||||
code=input_data.step_code,
|
code=input_data.step_code,
|
||||||
language=input_data.language,
|
language=input_data.language,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from openai import AsyncOpenAI
|
|||||||
from openai.types.responses import Response as OpenAIResponse
|
from openai.types.responses import Response as OpenAIResponse
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockManualWebhookConfig,
|
BlockManualWebhookConfig,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import codecs
|
import codecs
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from typing import Any, Literal, cast
|
|||||||
import discord
|
import discord
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
Discord OAuth-based blocks.
|
Discord OAuth-based blocks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import Literal
|
|||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, SecretStr
|
from pydantic import BaseModel, ConfigDict, SecretStr
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import codecs
|
import codecs
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ which provides access to LinkedIn profile data and related information.
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -3,13 +3,6 @@ import logging
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.blocks._base import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.blocks.fal._auth import (
|
from backend.blocks.fal._auth import (
|
||||||
TEST_CREDENTIALS,
|
TEST_CREDENTIALS,
|
||||||
TEST_CREDENTIALS_INPUT,
|
TEST_CREDENTIALS_INPUT,
|
||||||
@@ -17,6 +10,13 @@ from backend.blocks.fal._auth import (
|
|||||||
FalCredentialsField,
|
FalCredentialsField,
|
||||||
FalCredentialsInput,
|
FalCredentialsInput,
|
||||||
)
|
)
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.file import store_media_file
|
from backend.util.file import store_media_file
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from pydantic import SecretStr
|
|||||||
from replicate.client import Client as ReplicateClient
|
from replicate.client import Client as ReplicateClient
|
||||||
from replicate.helpers import FileOutput
|
from replicate.helpers import FileOutput
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from urllib.parse import urlparse
|
|||||||
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import re
|
|||||||
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import base64
|
|||||||
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from typing import Any, List, Optional
|
|||||||
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from google.oauth2.credentials import Credentials
|
|||||||
from googleapiclient.discovery import build
|
from googleapiclient.discovery import build
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -7,14 +7,14 @@ from google.oauth2.credentials import Credentials
|
|||||||
from googleapiclient.discovery import build
|
from googleapiclient.discovery import build
|
||||||
from gravitas_md2gdocs import to_requests
|
from gravitas_md2gdocs import to_requests
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
|
||||||
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
|
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.settings import Settings
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from google.oauth2.credentials import Credentials
|
|||||||
from googleapiclient.discovery import build
|
from googleapiclient.discovery import build
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -7,14 +7,14 @@ from enum import Enum
|
|||||||
from google.oauth2.credentials import Credentials
|
from google.oauth2.credentials import Credentials
|
||||||
from googleapiclient.discovery import build
|
from googleapiclient.discovery import build
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
|
||||||
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
|
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.settings import Settings
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import Literal
|
|||||||
import googlemaps
|
import googlemaps
|
||||||
from pydantic import BaseModel, SecretStr
|
from pydantic import BaseModel, SecretStr
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -9,7 +9,9 @@ from typing import Any, Optional
|
|||||||
from prisma.enums import ReviewStatus
|
from prisma.enums import ReviewStatus
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from backend.data.execution import ExecutionStatus
|
||||||
from backend.data.human_review import ReviewResult
|
from backend.data.human_review import ReviewResult
|
||||||
|
from backend.executor.manager import async_update_node_execution_status
|
||||||
from backend.util.clients import get_database_manager_async_client
|
from backend.util.clients import get_database_manager_async_client
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -41,8 +43,6 @@ class HITLReviewHelper:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
async def update_node_execution_status(**kwargs) -> None:
|
async def update_node_execution_status(**kwargs) -> None:
|
||||||
"""Update the execution status of a node."""
|
"""Update the execution status of a node."""
|
||||||
from backend.executor.manager import async_update_node_execution_status
|
|
||||||
|
|
||||||
await async_update_node_execution_status(
|
await async_update_node_execution_status(
|
||||||
db_client=get_database_manager_async_client(), **kwargs
|
db_client=get_database_manager_async_client(), **kwargs
|
||||||
)
|
)
|
||||||
@@ -88,13 +88,12 @@ class HITLReviewHelper:
|
|||||||
Raises:
|
Raises:
|
||||||
Exception: If review creation or status update fails
|
Exception: If review creation or status update fails
|
||||||
"""
|
"""
|
||||||
from backend.data.execution import ExecutionStatus
|
|
||||||
|
|
||||||
# Note: Safe mode checks (human_in_the_loop_safe_mode, sensitive_action_safe_mode)
|
# Note: Safe mode checks (human_in_the_loop_safe_mode, sensitive_action_safe_mode)
|
||||||
# are handled by the caller:
|
# are handled by the caller:
|
||||||
# - HITL blocks check human_in_the_loop_safe_mode in their run() method
|
# - HITL blocks check human_in_the_loop_safe_mode in their run() method
|
||||||
# - Sensitive action blocks check sensitive_action_safe_mode in is_block_exec_need_review()
|
# - Sensitive action blocks check sensitive_action_safe_mode in is_block_exec_need_review()
|
||||||
# This function only handles checking for existing approvals.
|
# This function only handles checking for existing approvals.
|
||||||
|
|
||||||
# Check if this node has already been approved (normal or auto-approval)
|
# Check if this node has already been approved (normal or auto-approval)
|
||||||
if approval_result := await HITLReviewHelper.check_approval(
|
if approval_result := await HITLReviewHelper.check_approval(
|
||||||
node_exec_id=node_exec_id,
|
node_exec_id=node_exec_id,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from typing import Literal
|
|||||||
import aiofiles
|
import aiofiles
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
from backend.blocks._base import (
|
from backend.blocks.hubspot._auth import (
|
||||||
|
HubSpotCredentials,
|
||||||
|
HubSpotCredentialsField,
|
||||||
|
HubSpotCredentialsInput,
|
||||||
|
)
|
||||||
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
from backend.blocks.hubspot._auth import (
|
|
||||||
HubSpotCredentials,
|
|
||||||
HubSpotCredentialsField,
|
|
||||||
HubSpotCredentialsInput,
|
|
||||||
)
|
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
from backend.blocks._base import (
|
from backend.blocks.hubspot._auth import (
|
||||||
|
HubSpotCredentials,
|
||||||
|
HubSpotCredentialsField,
|
||||||
|
HubSpotCredentialsInput,
|
||||||
|
)
|
||||||
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
from backend.blocks.hubspot._auth import (
|
|
||||||
HubSpotCredentials,
|
|
||||||
HubSpotCredentialsField,
|
|
||||||
HubSpotCredentialsInput,
|
|
||||||
)
|
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
|
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.blocks.hubspot._auth import (
|
||||||
|
HubSpotCredentials,
|
||||||
|
HubSpotCredentialsField,
|
||||||
|
HubSpotCredentialsInput,
|
||||||
|
)
|
||||||
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
from backend.blocks.hubspot._auth import (
|
|
||||||
HubSpotCredentials,
|
|
||||||
HubSpotCredentialsField,
|
|
||||||
HubSpotCredentialsInput,
|
|
||||||
)
|
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,8 @@ from typing import Any
|
|||||||
|
|
||||||
from prisma.enums import ReviewStatus
|
from prisma.enums import ReviewStatus
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.blocks.helpers.review import HITLReviewHelper
|
||||||
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
@@ -11,7 +12,6 @@ from backend.blocks._base import (
|
|||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
BlockType,
|
BlockType,
|
||||||
)
|
)
|
||||||
from backend.blocks.helpers.review import HITLReviewHelper
|
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.human_review import ReviewResult
|
from backend.data.human_review import ReviewResult
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import Any, Dict, Literal, Optional
|
|||||||
|
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ import copy
|
|||||||
from datetime import date, time
|
from datetime import date, time
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from backend.blocks._base import (
|
# Import for Google Drive file input block
|
||||||
|
from backend.blocks.google._drive import AttachmentView, GoogleDriveFile
|
||||||
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
@@ -10,9 +12,6 @@ from backend.blocks._base import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockType,
|
BlockType,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Import for Google Drive file input block
|
|
||||||
from backend.blocks.google._drive import AttachmentView, GoogleDriveFile
|
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.file import store_media_file
|
from backend.util.file import store_media_file
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
from backend.blocks._base import (
|
from backend.blocks.jina._auth import (
|
||||||
|
JinaCredentials,
|
||||||
|
JinaCredentialsField,
|
||||||
|
JinaCredentialsInput,
|
||||||
|
)
|
||||||
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
from backend.blocks.jina._auth import (
|
|
||||||
JinaCredentials,
|
|
||||||
JinaCredentialsField,
|
|
||||||
JinaCredentialsInput,
|
|
||||||
)
|
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
from backend.blocks._base import (
|
from backend.blocks.jina._auth import (
|
||||||
|
JinaCredentials,
|
||||||
|
JinaCredentialsField,
|
||||||
|
JinaCredentialsInput,
|
||||||
|
)
|
||||||
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
from backend.blocks.jina._auth import (
|
|
||||||
JinaCredentials,
|
|
||||||
JinaCredentialsField,
|
|
||||||
JinaCredentialsInput,
|
|
||||||
)
|
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
|
|
||||||
|
|||||||
@@ -3,18 +3,18 @@ from urllib.parse import quote
|
|||||||
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.blocks.jina._auth import (
|
||||||
|
JinaCredentials,
|
||||||
|
JinaCredentialsField,
|
||||||
|
JinaCredentialsInput,
|
||||||
|
)
|
||||||
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
from backend.blocks.jina._auth import (
|
|
||||||
JinaCredentials,
|
|
||||||
JinaCredentialsField,
|
|
||||||
JinaCredentialsInput,
|
|
||||||
)
|
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,5 @@
|
|||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
from backend.blocks._base import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.blocks.jina._auth import (
|
from backend.blocks.jina._auth import (
|
||||||
TEST_CREDENTIALS,
|
TEST_CREDENTIALS,
|
||||||
TEST_CREDENTIALS_INPUT,
|
TEST_CREDENTIALS_INPUT,
|
||||||
@@ -15,6 +8,13 @@ from backend.blocks.jina._auth import (
|
|||||||
JinaCredentialsInput,
|
JinaCredentialsInput,
|
||||||
)
|
)
|
||||||
from backend.blocks.search import GetRequest
|
from backend.blocks.search import GetRequest
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.exceptions import BlockExecutionError
|
from backend.util.exceptions import BlockExecutionError
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from anthropic.types import ToolParam
|
|||||||
from groq import AsyncGroq
|
from groq import AsyncGroq
|
||||||
from pydantic import BaseModel, SecretStr
|
from pydantic import BaseModel, SecretStr
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import operator
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,300 +0,0 @@
|
|||||||
"""
|
|
||||||
MCP (Model Context Protocol) Tool Block.
|
|
||||||
|
|
||||||
A single dynamic block that can connect to any MCP server, discover available tools,
|
|
||||||
and execute them. Works like AgentExecutorBlock — the user selects a tool from a
|
|
||||||
dropdown and the input/output schema adapts dynamically.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.blocks._base import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
BlockType,
|
|
||||||
)
|
|
||||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
|
||||||
from backend.data.block import BlockInput, BlockOutput
|
|
||||||
from backend.data.model import (
|
|
||||||
CredentialsField,
|
|
||||||
CredentialsMetaInput,
|
|
||||||
OAuth2Credentials,
|
|
||||||
SchemaField,
|
|
||||||
)
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
from backend.util.json import validate_with_jsonschema
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
TEST_CREDENTIALS = OAuth2Credentials(
|
|
||||||
id="test-mcp-cred",
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("mock-mcp-token"),
|
|
||||||
refresh_token=SecretStr("mock-refresh"),
|
|
||||||
scopes=[],
|
|
||||||
title="Mock MCP credential",
|
|
||||||
)
|
|
||||||
TEST_CREDENTIALS_INPUT = {
|
|
||||||
"provider": TEST_CREDENTIALS.provider,
|
|
||||||
"id": TEST_CREDENTIALS.id,
|
|
||||||
"type": TEST_CREDENTIALS.type,
|
|
||||||
"title": TEST_CREDENTIALS.title,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
MCPCredentials = CredentialsMetaInput[Literal[ProviderName.MCP], Literal["oauth2"]]
|
|
||||||
|
|
||||||
|
|
||||||
class MCPToolBlock(Block):
|
|
||||||
"""
|
|
||||||
A block that connects to an MCP server, lets the user pick a tool,
|
|
||||||
and executes it with dynamic input/output schema.
|
|
||||||
|
|
||||||
The flow:
|
|
||||||
1. User provides an MCP server URL (and optional credentials)
|
|
||||||
2. Frontend calls the backend to get tool list from that URL
|
|
||||||
3. User selects a tool from a dropdown (available_tools)
|
|
||||||
4. The block's input schema updates to reflect the selected tool's parameters
|
|
||||||
5. On execution, the block calls the MCP server to run the tool
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
server_url: str = SchemaField(
|
|
||||||
description="URL of the MCP server (Streamable HTTP endpoint)",
|
|
||||||
placeholder="https://mcp.example.com/mcp",
|
|
||||||
)
|
|
||||||
credentials: MCPCredentials = CredentialsField(
|
|
||||||
discriminator="server_url",
|
|
||||||
description="MCP server OAuth credentials",
|
|
||||||
default={},
|
|
||||||
)
|
|
||||||
selected_tool: str = SchemaField(
|
|
||||||
description="The MCP tool to execute",
|
|
||||||
placeholder="Select a tool",
|
|
||||||
default="",
|
|
||||||
)
|
|
||||||
tool_input_schema: dict[str, Any] = SchemaField(
|
|
||||||
description="JSON Schema for the selected tool's input parameters. "
|
|
||||||
"Populated automatically when a tool is selected.",
|
|
||||||
default={},
|
|
||||||
hidden=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_arguments: dict[str, Any] = SchemaField(
|
|
||||||
description="Arguments to pass to the selected MCP tool. "
|
|
||||||
"The fields here are defined by the tool's input schema.",
|
|
||||||
default={},
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
|
|
||||||
"""Return the tool's input schema so the builder UI renders dynamic fields."""
|
|
||||||
return data.get("tool_input_schema", {})
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
|
||||||
"""Return the current tool_arguments as defaults for the dynamic fields."""
|
|
||||||
return data.get("tool_arguments", {})
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
|
||||||
"""Check which required tool arguments are missing."""
|
|
||||||
required_fields = cls.get_input_schema(data).get("required", [])
|
|
||||||
tool_arguments = data.get("tool_arguments", {})
|
|
||||||
return set(required_fields) - set(tool_arguments)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
|
||||||
"""Validate tool_arguments against the tool's input schema."""
|
|
||||||
tool_schema = cls.get_input_schema(data)
|
|
||||||
if not tool_schema:
|
|
||||||
return None
|
|
||||||
tool_arguments = data.get("tool_arguments", {})
|
|
||||||
return validate_with_jsonschema(tool_schema, tool_arguments)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
result: Any = SchemaField(description="The result returned by the MCP tool")
|
|
||||||
error: str = SchemaField(description="Error message if the tool call failed")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4",
|
|
||||||
description="Connect to any MCP server and execute its tools. "
|
|
||||||
"Provide a server URL, select a tool, and pass arguments dynamically.",
|
|
||||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
|
||||||
input_schema=MCPToolBlock.Input,
|
|
||||||
output_schema=MCPToolBlock.Output,
|
|
||||||
block_type=BlockType.MCP_TOOL,
|
|
||||||
test_credentials=TEST_CREDENTIALS,
|
|
||||||
test_input={
|
|
||||||
"server_url": "https://mcp.example.com/mcp",
|
|
||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
|
||||||
"selected_tool": "get_weather",
|
|
||||||
"tool_input_schema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"city": {"type": "string"}},
|
|
||||||
"required": ["city"],
|
|
||||||
},
|
|
||||||
"tool_arguments": {"city": "London"},
|
|
||||||
},
|
|
||||||
test_output=[
|
|
||||||
(
|
|
||||||
"result",
|
|
||||||
{"weather": "sunny", "temperature": 20},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
test_mock={
|
|
||||||
"_call_mcp_tool": lambda *a, **kw: {
|
|
||||||
"weather": "sunny",
|
|
||||||
"temperature": 20,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _call_mcp_tool(
|
|
||||||
self,
|
|
||||||
server_url: str,
|
|
||||||
tool_name: str,
|
|
||||||
arguments: dict[str, Any],
|
|
||||||
auth_token: str | None = None,
|
|
||||||
) -> Any:
|
|
||||||
"""Call a tool on the MCP server. Extracted for easy mocking in tests."""
|
|
||||||
client = MCPClient(server_url, auth_token=auth_token)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool(tool_name, arguments)
|
|
||||||
|
|
||||||
if result.is_error:
|
|
||||||
error_text = ""
|
|
||||||
for item in result.content:
|
|
||||||
if item.get("type") == "text":
|
|
||||||
error_text += item.get("text", "")
|
|
||||||
raise MCPClientError(
|
|
||||||
f"MCP tool '{tool_name}' returned an error: "
|
|
||||||
f"{error_text or 'Unknown error'}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract text content from the result
|
|
||||||
output_parts = []
|
|
||||||
for item in result.content:
|
|
||||||
if item.get("type") == "text":
|
|
||||||
text = item.get("text", "")
|
|
||||||
# Try to parse as JSON for structured output
|
|
||||||
try:
|
|
||||||
output_parts.append(json.loads(text))
|
|
||||||
except (json.JSONDecodeError, ValueError):
|
|
||||||
output_parts.append(text)
|
|
||||||
elif item.get("type") == "image":
|
|
||||||
output_parts.append(
|
|
||||||
{
|
|
||||||
"type": "image",
|
|
||||||
"data": item.get("data"),
|
|
||||||
"mimeType": item.get("mimeType"),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
elif item.get("type") == "resource":
|
|
||||||
output_parts.append(item.get("resource", {}))
|
|
||||||
|
|
||||||
# If single result, unwrap
|
|
||||||
if len(output_parts) == 1:
|
|
||||||
return output_parts[0]
|
|
||||||
return output_parts if output_parts else None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _auto_lookup_credential(
|
|
||||||
user_id: str, server_url: str
|
|
||||||
) -> "OAuth2Credentials | None":
|
|
||||||
"""Auto-lookup stored MCP credential for a server URL.
|
|
||||||
|
|
||||||
This is a fallback for nodes that don't have ``credentials`` explicitly
|
|
||||||
set (e.g. nodes created before the credential field was wired up).
|
|
||||||
"""
|
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
|
|
||||||
try:
|
|
||||||
mgr = IntegrationCredentialsManager()
|
|
||||||
mcp_creds = await mgr.store.get_creds_by_provider(
|
|
||||||
user_id, ProviderName.MCP.value
|
|
||||||
)
|
|
||||||
best: OAuth2Credentials | None = None
|
|
||||||
for cred in mcp_creds:
|
|
||||||
if (
|
|
||||||
isinstance(cred, OAuth2Credentials)
|
|
||||||
and (cred.metadata or {}).get("mcp_server_url") == server_url
|
|
||||||
):
|
|
||||||
if best is None or (
|
|
||||||
(cred.access_token_expires_at or 0)
|
|
||||||
> (best.access_token_expires_at or 0)
|
|
||||||
):
|
|
||||||
best = cred
|
|
||||||
if best:
|
|
||||||
best = await mgr.refresh_if_needed(user_id, best)
|
|
||||||
logger.info(
|
|
||||||
"Auto-resolved MCP credential %s for %s", best.id, server_url
|
|
||||||
)
|
|
||||||
return best
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Auto-lookup MCP credential failed", exc_info=True)
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
user_id: str,
|
|
||||||
credentials: OAuth2Credentials | None = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
if not input_data.server_url:
|
|
||||||
yield "error", "MCP server URL is required"
|
|
||||||
return
|
|
||||||
|
|
||||||
if not input_data.selected_tool:
|
|
||||||
yield "error", "No tool selected. Please select a tool from the dropdown."
|
|
||||||
return
|
|
||||||
|
|
||||||
# Validate required tool arguments before calling the server.
|
|
||||||
# The executor-level validation is bypassed for MCP blocks because
|
|
||||||
# get_input_defaults() flattens tool_arguments, stripping tool_input_schema
|
|
||||||
# from the validation context.
|
|
||||||
required = set(input_data.tool_input_schema.get("required", []))
|
|
||||||
if required:
|
|
||||||
missing = required - set(input_data.tool_arguments.keys())
|
|
||||||
if missing:
|
|
||||||
yield "error", (
|
|
||||||
f"Missing required argument(s): {', '.join(sorted(missing))}. "
|
|
||||||
f"Please fill in all required fields marked with * in the block form."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# If no credentials were injected by the executor (e.g. legacy nodes
|
|
||||||
# that don't have the credentials field set), try to auto-lookup
|
|
||||||
# the stored MCP credential for this server URL.
|
|
||||||
if credentials is None:
|
|
||||||
credentials = await self._auto_lookup_credential(
|
|
||||||
user_id, input_data.server_url
|
|
||||||
)
|
|
||||||
|
|
||||||
auth_token = (
|
|
||||||
credentials.access_token.get_secret_value() if credentials else None
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await self._call_mcp_tool(
|
|
||||||
server_url=input_data.server_url,
|
|
||||||
tool_name=input_data.selected_tool,
|
|
||||||
arguments=input_data.tool_arguments,
|
|
||||||
auth_token=auth_token,
|
|
||||||
)
|
|
||||||
yield "result", result
|
|
||||||
except MCPClientError as e:
|
|
||||||
yield "error", str(e)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"MCP tool call failed: {e}")
|
|
||||||
yield "error", f"MCP tool call failed: {str(e)}"
|
|
||||||
@@ -1,323 +0,0 @@
|
|||||||
"""
|
|
||||||
MCP (Model Context Protocol) HTTP client.
|
|
||||||
|
|
||||||
Implements the MCP Streamable HTTP transport for listing tools and calling tools
|
|
||||||
on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST.
|
|
||||||
|
|
||||||
Handles both JSON and SSE (text/event-stream) response formats per the MCP spec.
|
|
||||||
|
|
||||||
Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from backend.util.request import Requests
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MCPTool:
|
|
||||||
"""Represents an MCP tool discovered from a server."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
input_schema: dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MCPCallResult:
|
|
||||||
"""Result from calling an MCP tool."""
|
|
||||||
|
|
||||||
content: list[dict[str, Any]] = field(default_factory=list)
|
|
||||||
is_error: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class MCPClientError(Exception):
|
|
||||||
"""Raised when an MCP protocol error occurs."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MCPClient:
|
|
||||||
"""
|
|
||||||
Async HTTP client for the MCP Streamable HTTP transport.
|
|
||||||
|
|
||||||
Communicates with MCP servers using JSON-RPC 2.0 over HTTP POST.
|
|
||||||
Supports optional Bearer token authentication.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
server_url: str,
|
|
||||||
auth_token: str | None = None,
|
|
||||||
):
|
|
||||||
self.server_url = server_url.rstrip("/")
|
|
||||||
self.auth_token = auth_token
|
|
||||||
self._request_id = 0
|
|
||||||
self._session_id: str | None = None
|
|
||||||
|
|
||||||
def _next_id(self) -> int:
|
|
||||||
self._request_id += 1
|
|
||||||
return self._request_id
|
|
||||||
|
|
||||||
def _build_headers(self) -> dict[str, str]:
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Accept": "application/json, text/event-stream",
|
|
||||||
}
|
|
||||||
if self.auth_token:
|
|
||||||
headers["Authorization"] = f"Bearer {self.auth_token}"
|
|
||||||
if self._session_id:
|
|
||||||
headers["Mcp-Session-Id"] = self._session_id
|
|
||||||
return headers
|
|
||||||
|
|
||||||
def _build_jsonrpc_request(
|
|
||||||
self, method: str, params: dict[str, Any] | None = None
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
req: dict[str, Any] = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"method": method,
|
|
||||||
"id": self._next_id(),
|
|
||||||
}
|
|
||||||
if params is not None:
|
|
||||||
req["params"] = params
|
|
||||||
return req
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _parse_sse_response(text: str) -> dict[str, Any]:
|
|
||||||
"""Parse an SSE (text/event-stream) response body into JSON-RPC data.
|
|
||||||
|
|
||||||
MCP servers may return responses as SSE with format:
|
|
||||||
event: message
|
|
||||||
data: {"jsonrpc":"2.0","result":{...},"id":1}
|
|
||||||
|
|
||||||
We extract the last `data:` line that contains a JSON-RPC response
|
|
||||||
(i.e. has an "id" field), which is the reply to our request.
|
|
||||||
"""
|
|
||||||
last_data: dict[str, Any] | None = None
|
|
||||||
for line in text.splitlines():
|
|
||||||
stripped = line.strip()
|
|
||||||
if stripped.startswith("data:"):
|
|
||||||
payload = stripped[len("data:") :].strip()
|
|
||||||
if not payload:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
parsed = json.loads(payload)
|
|
||||||
# Only keep JSON-RPC responses (have "id"), skip notifications
|
|
||||||
if isinstance(parsed, dict) and "id" in parsed:
|
|
||||||
last_data = parsed
|
|
||||||
except (json.JSONDecodeError, ValueError):
|
|
||||||
continue
|
|
||||||
if last_data is None:
|
|
||||||
raise MCPClientError("No JSON-RPC response found in SSE stream")
|
|
||||||
return last_data
|
|
||||||
|
|
||||||
async def _send_request(
|
|
||||||
self, method: str, params: dict[str, Any] | None = None
|
|
||||||
) -> Any:
|
|
||||||
"""Send a JSON-RPC request to the MCP server and return the result.
|
|
||||||
|
|
||||||
Handles both ``application/json`` and ``text/event-stream`` responses
|
|
||||||
as required by the MCP Streamable HTTP transport specification.
|
|
||||||
"""
|
|
||||||
payload = self._build_jsonrpc_request(method, params)
|
|
||||||
headers = self._build_headers()
|
|
||||||
|
|
||||||
requests = Requests(
|
|
||||||
raise_for_status=True,
|
|
||||||
extra_headers=headers,
|
|
||||||
)
|
|
||||||
response = await requests.post(self.server_url, json=payload)
|
|
||||||
|
|
||||||
# Capture session ID from response (MCP Streamable HTTP transport)
|
|
||||||
session_id = response.headers.get("Mcp-Session-Id")
|
|
||||||
if session_id:
|
|
||||||
self._session_id = session_id
|
|
||||||
|
|
||||||
content_type = response.headers.get("content-type", "")
|
|
||||||
if "text/event-stream" in content_type:
|
|
||||||
body = self._parse_sse_response(response.text())
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
body = response.json()
|
|
||||||
except Exception as e:
|
|
||||||
raise MCPClientError(
|
|
||||||
f"MCP server returned non-JSON response: {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
if not isinstance(body, dict):
|
|
||||||
raise MCPClientError(
|
|
||||||
f"MCP server returned unexpected JSON type: {type(body).__name__}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle JSON-RPC error
|
|
||||||
if "error" in body:
|
|
||||||
error = body["error"]
|
|
||||||
if isinstance(error, dict):
|
|
||||||
raise MCPClientError(
|
|
||||||
f"MCP server error [{error.get('code', '?')}]: "
|
|
||||||
f"{error.get('message', 'Unknown error')}"
|
|
||||||
)
|
|
||||||
raise MCPClientError(f"MCP server error: {error}")
|
|
||||||
|
|
||||||
return body.get("result")
|
|
||||||
|
|
||||||
async def _send_notification(self, method: str) -> None:
|
|
||||||
"""Send a JSON-RPC notification (no id, no response expected)."""
|
|
||||||
headers = self._build_headers()
|
|
||||||
notification = {"jsonrpc": "2.0", "method": method}
|
|
||||||
requests = Requests(
|
|
||||||
raise_for_status=False,
|
|
||||||
extra_headers=headers,
|
|
||||||
)
|
|
||||||
await requests.post(self.server_url, json=notification)
|
|
||||||
|
|
||||||
async def discover_auth(self) -> dict[str, Any] | None:
|
|
||||||
"""Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec).
|
|
||||||
|
|
||||||
Returns ``None`` if the server doesn't require auth, otherwise returns
|
|
||||||
a dict with:
|
|
||||||
- ``authorization_servers``: list of authorization server URLs
|
|
||||||
- ``resource``: the resource indicator URL (usually the MCP endpoint)
|
|
||||||
- ``scopes_supported``: optional list of supported scopes
|
|
||||||
|
|
||||||
The caller can then fetch the authorization server metadata to get
|
|
||||||
``authorization_endpoint``, ``token_endpoint``, etc.
|
|
||||||
"""
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
parsed = urlparse(self.server_url)
|
|
||||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
|
||||||
|
|
||||||
# Build candidates for protected-resource metadata (per RFC 9728)
|
|
||||||
path = parsed.path.rstrip("/")
|
|
||||||
candidates = []
|
|
||||||
if path and path != "/":
|
|
||||||
candidates.append(f"{base}/.well-known/oauth-protected-resource{path}")
|
|
||||||
candidates.append(f"{base}/.well-known/oauth-protected-resource")
|
|
||||||
|
|
||||||
requests = Requests(
|
|
||||||
raise_for_status=False,
|
|
||||||
)
|
|
||||||
for url in candidates:
|
|
||||||
try:
|
|
||||||
resp = await requests.get(url)
|
|
||||||
if resp.status == 200:
|
|
||||||
data = resp.json()
|
|
||||||
if isinstance(data, dict) and "authorization_servers" in data:
|
|
||||||
return data
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def discover_auth_server_metadata(
|
|
||||||
self, auth_server_url: str
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Fetch the OAuth Authorization Server Metadata (RFC 8414).
|
|
||||||
|
|
||||||
Given an authorization server URL, returns a dict with:
|
|
||||||
- ``authorization_endpoint``
|
|
||||||
- ``token_endpoint``
|
|
||||||
- ``registration_endpoint`` (for dynamic client registration)
|
|
||||||
- ``scopes_supported``
|
|
||||||
- ``code_challenge_methods_supported``
|
|
||||||
- etc.
|
|
||||||
"""
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
parsed = urlparse(auth_server_url)
|
|
||||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
|
||||||
path = parsed.path.rstrip("/")
|
|
||||||
|
|
||||||
# Try standard metadata endpoints (RFC 8414 and OpenID Connect)
|
|
||||||
candidates = []
|
|
||||||
if path and path != "/":
|
|
||||||
candidates.append(f"{base}/.well-known/oauth-authorization-server{path}")
|
|
||||||
candidates.append(f"{base}/.well-known/oauth-authorization-server")
|
|
||||||
candidates.append(f"{base}/.well-known/openid-configuration")
|
|
||||||
|
|
||||||
requests = Requests(
|
|
||||||
raise_for_status=False,
|
|
||||||
)
|
|
||||||
for url in candidates:
|
|
||||||
try:
|
|
||||||
resp = await requests.get(url)
|
|
||||||
if resp.status == 200:
|
|
||||||
data = resp.json()
|
|
||||||
if isinstance(data, dict) and "authorization_endpoint" in data:
|
|
||||||
return data
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def initialize(self) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Send the MCP initialize request.
|
|
||||||
|
|
||||||
This is required by the MCP protocol before any other requests.
|
|
||||||
Returns the server's capabilities.
|
|
||||||
"""
|
|
||||||
result = await self._send_request(
|
|
||||||
"initialize",
|
|
||||||
{
|
|
||||||
"protocolVersion": "2025-03-26",
|
|
||||||
"capabilities": {},
|
|
||||||
"clientInfo": {"name": "AutoGPT-Platform", "version": "1.0.0"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# Send initialized notification (no response expected)
|
|
||||||
await self._send_notification("notifications/initialized")
|
|
||||||
|
|
||||||
return result or {}
|
|
||||||
|
|
||||||
async def list_tools(self) -> list[MCPTool]:
|
|
||||||
"""
|
|
||||||
Discover available tools from the MCP server.
|
|
||||||
|
|
||||||
Returns a list of MCPTool objects with name, description, and input schema.
|
|
||||||
"""
|
|
||||||
result = await self._send_request("tools/list")
|
|
||||||
if not result or "tools" not in result:
|
|
||||||
return []
|
|
||||||
|
|
||||||
tools = []
|
|
||||||
for tool_data in result["tools"]:
|
|
||||||
tools.append(
|
|
||||||
MCPTool(
|
|
||||||
name=tool_data.get("name", ""),
|
|
||||||
description=tool_data.get("description", ""),
|
|
||||||
input_schema=tool_data.get("inputSchema", {}),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return tools
|
|
||||||
|
|
||||||
async def call_tool(
|
|
||||||
self, tool_name: str, arguments: dict[str, Any]
|
|
||||||
) -> MCPCallResult:
|
|
||||||
"""
|
|
||||||
Call a tool on the MCP server.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_name: The name of the tool to call.
|
|
||||||
arguments: The arguments to pass to the tool.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
MCPCallResult with the tool's response content.
|
|
||||||
"""
|
|
||||||
result = await self._send_request(
|
|
||||||
"tools/call",
|
|
||||||
{"name": tool_name, "arguments": arguments},
|
|
||||||
)
|
|
||||||
if not result:
|
|
||||||
return MCPCallResult(is_error=True)
|
|
||||||
|
|
||||||
return MCPCallResult(
|
|
||||||
content=result.get("content", []),
|
|
||||||
is_error=result.get("isError", False),
|
|
||||||
)
|
|
||||||
@@ -1,204 +0,0 @@
|
|||||||
"""
|
|
||||||
MCP OAuth handler for MCP servers that use OAuth 2.1 authorization.
|
|
||||||
|
|
||||||
Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed,
|
|
||||||
MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata.
|
|
||||||
This handler accepts those endpoints at construction time.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
import urllib.parse
|
|
||||||
from typing import ClassVar, Optional
|
|
||||||
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
from backend.util.request import Requests
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class MCPOAuthHandler(BaseOAuthHandler):
|
|
||||||
"""
|
|
||||||
OAuth handler for MCP servers with dynamically-discovered endpoints.
|
|
||||||
|
|
||||||
Construction requires the authorization and token endpoint URLs,
|
|
||||||
which are obtained via MCP OAuth metadata discovery
|
|
||||||
(``MCPClient.discover_auth`` + ``discover_auth_server_metadata``).
|
|
||||||
"""
|
|
||||||
|
|
||||||
PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP
|
|
||||||
DEFAULT_SCOPES: ClassVar[list[str]] = []
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
client_id: str,
|
|
||||||
client_secret: str,
|
|
||||||
redirect_uri: str,
|
|
||||||
*,
|
|
||||||
authorize_url: str,
|
|
||||||
token_url: str,
|
|
||||||
revoke_url: str | None = None,
|
|
||||||
resource_url: str | None = None,
|
|
||||||
):
|
|
||||||
self.client_id = client_id
|
|
||||||
self.client_secret = client_secret
|
|
||||||
self.redirect_uri = redirect_uri
|
|
||||||
self.authorize_url = authorize_url
|
|
||||||
self.token_url = token_url
|
|
||||||
self.revoke_url = revoke_url
|
|
||||||
self.resource_url = resource_url
|
|
||||||
|
|
||||||
def get_login_url(
|
|
||||||
self,
|
|
||||||
scopes: list[str],
|
|
||||||
state: str,
|
|
||||||
code_challenge: Optional[str],
|
|
||||||
) -> str:
|
|
||||||
scopes = self.handle_default_scopes(scopes)
|
|
||||||
|
|
||||||
params: dict[str, str] = {
|
|
||||||
"response_type": "code",
|
|
||||||
"client_id": self.client_id,
|
|
||||||
"redirect_uri": self.redirect_uri,
|
|
||||||
"state": state,
|
|
||||||
}
|
|
||||||
if scopes:
|
|
||||||
params["scope"] = " ".join(scopes)
|
|
||||||
# PKCE (S256) — included when the caller provides a code_challenge
|
|
||||||
if code_challenge:
|
|
||||||
params["code_challenge"] = code_challenge
|
|
||||||
params["code_challenge_method"] = "S256"
|
|
||||||
# MCP spec requires resource indicator (RFC 8707)
|
|
||||||
if self.resource_url:
|
|
||||||
params["resource"] = self.resource_url
|
|
||||||
|
|
||||||
return f"{self.authorize_url}?{urllib.parse.urlencode(params)}"
|
|
||||||
|
|
||||||
async def exchange_code_for_tokens(
|
|
||||||
self,
|
|
||||||
code: str,
|
|
||||||
scopes: list[str],
|
|
||||||
code_verifier: Optional[str],
|
|
||||||
) -> OAuth2Credentials:
|
|
||||||
data: dict[str, str] = {
|
|
||||||
"grant_type": "authorization_code",
|
|
||||||
"code": code,
|
|
||||||
"redirect_uri": self.redirect_uri,
|
|
||||||
"client_id": self.client_id,
|
|
||||||
}
|
|
||||||
if self.client_secret:
|
|
||||||
data["client_secret"] = self.client_secret
|
|
||||||
if code_verifier:
|
|
||||||
data["code_verifier"] = code_verifier
|
|
||||||
if self.resource_url:
|
|
||||||
data["resource"] = self.resource_url
|
|
||||||
|
|
||||||
response = await Requests(raise_for_status=True).post(
|
|
||||||
self.token_url,
|
|
||||||
data=data,
|
|
||||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
||||||
)
|
|
||||||
tokens = response.json()
|
|
||||||
|
|
||||||
if "error" in tokens:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Token exchange failed: {tokens.get('error_description', tokens['error'])}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if "access_token" not in tokens:
|
|
||||||
raise RuntimeError("OAuth token response missing 'access_token' field")
|
|
||||||
|
|
||||||
now = int(time.time())
|
|
||||||
expires_in = tokens.get("expires_in")
|
|
||||||
|
|
||||||
return OAuth2Credentials(
|
|
||||||
provider=self.PROVIDER_NAME,
|
|
||||||
title=None,
|
|
||||||
access_token=SecretStr(tokens["access_token"]),
|
|
||||||
refresh_token=(
|
|
||||||
SecretStr(tokens["refresh_token"])
|
|
||||||
if tokens.get("refresh_token")
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
access_token_expires_at=now + expires_in if expires_in else None,
|
|
||||||
refresh_token_expires_at=None,
|
|
||||||
scopes=scopes,
|
|
||||||
metadata={
|
|
||||||
"mcp_token_url": self.token_url,
|
|
||||||
"mcp_resource_url": self.resource_url,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _refresh_tokens(
|
|
||||||
self, credentials: OAuth2Credentials
|
|
||||||
) -> OAuth2Credentials:
|
|
||||||
if not credentials.refresh_token:
|
|
||||||
raise ValueError("No refresh token available for MCP OAuth credentials")
|
|
||||||
|
|
||||||
data: dict[str, str] = {
|
|
||||||
"grant_type": "refresh_token",
|
|
||||||
"refresh_token": credentials.refresh_token.get_secret_value(),
|
|
||||||
"client_id": self.client_id,
|
|
||||||
}
|
|
||||||
if self.client_secret:
|
|
||||||
data["client_secret"] = self.client_secret
|
|
||||||
if self.resource_url:
|
|
||||||
data["resource"] = self.resource_url
|
|
||||||
|
|
||||||
response = await Requests(raise_for_status=True).post(
|
|
||||||
self.token_url,
|
|
||||||
data=data,
|
|
||||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
||||||
)
|
|
||||||
tokens = response.json()
|
|
||||||
|
|
||||||
if "error" in tokens:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Token refresh failed: {tokens.get('error_description', tokens['error'])}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if "access_token" not in tokens:
|
|
||||||
raise RuntimeError("OAuth refresh response missing 'access_token' field")
|
|
||||||
|
|
||||||
now = int(time.time())
|
|
||||||
expires_in = tokens.get("expires_in")
|
|
||||||
|
|
||||||
return OAuth2Credentials(
|
|
||||||
id=credentials.id,
|
|
||||||
provider=self.PROVIDER_NAME,
|
|
||||||
title=credentials.title,
|
|
||||||
access_token=SecretStr(tokens["access_token"]),
|
|
||||||
refresh_token=(
|
|
||||||
SecretStr(tokens["refresh_token"])
|
|
||||||
if tokens.get("refresh_token")
|
|
||||||
else credentials.refresh_token
|
|
||||||
),
|
|
||||||
access_token_expires_at=now + expires_in if expires_in else None,
|
|
||||||
refresh_token_expires_at=credentials.refresh_token_expires_at,
|
|
||||||
scopes=credentials.scopes,
|
|
||||||
metadata=credentials.metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
|
||||||
if not self.revoke_url:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = {
|
|
||||||
"token": credentials.access_token.get_secret_value(),
|
|
||||||
"token_type_hint": "access_token",
|
|
||||||
"client_id": self.client_id,
|
|
||||||
}
|
|
||||||
await Requests().post(
|
|
||||||
self.revoke_url,
|
|
||||||
data=data,
|
|
||||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True)
|
|
||||||
return False
|
|
||||||
@@ -1,109 +0,0 @@
|
|||||||
"""
|
|
||||||
End-to-end tests against a real public MCP server.
|
|
||||||
|
|
||||||
These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp)
|
|
||||||
which is publicly accessible without authentication and returns SSE responses.
|
|
||||||
|
|
||||||
Mark: These are tagged with ``@pytest.mark.e2e`` so they can be run/skipped
|
|
||||||
independently of the rest of the test suite (they require network access).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from backend.blocks.mcp.client import MCPClient
|
|
||||||
|
|
||||||
# Public MCP server that requires no authentication
|
|
||||||
OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp"
|
|
||||||
|
|
||||||
# Skip all tests in this module unless RUN_E2E env var is set
|
|
||||||
pytestmark = pytest.mark.skipif(
|
|
||||||
not os.environ.get("RUN_E2E"), reason="set RUN_E2E=1 to run e2e tests"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestRealMCPServer:
|
|
||||||
"""Tests against the live OpenAI docs MCP server."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_initialize(self):
|
|
||||||
"""Verify we can complete the MCP handshake with a real server."""
|
|
||||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
|
||||||
result = await client.initialize()
|
|
||||||
|
|
||||||
assert result["protocolVersion"] == "2025-03-26"
|
|
||||||
assert "serverInfo" in result
|
|
||||||
assert result["serverInfo"]["name"] == "openai-docs-mcp"
|
|
||||||
assert "tools" in result.get("capabilities", {})
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_list_tools(self):
|
|
||||||
"""Verify we can discover tools from a real MCP server."""
|
|
||||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
|
||||||
await client.initialize()
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
assert len(tools) >= 3 # server has at least 5 tools as of writing
|
|
||||||
|
|
||||||
tool_names = {t.name for t in tools}
|
|
||||||
# These tools are documented and should be stable
|
|
||||||
assert "search_openai_docs" in tool_names
|
|
||||||
assert "list_openai_docs" in tool_names
|
|
||||||
assert "fetch_openai_doc" in tool_names
|
|
||||||
|
|
||||||
# Verify schema structure
|
|
||||||
search_tool = next(t for t in tools if t.name == "search_openai_docs")
|
|
||||||
assert "query" in search_tool.input_schema.get("properties", {})
|
|
||||||
assert "query" in search_tool.input_schema.get("required", [])
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_tool_list_api_endpoints(self):
|
|
||||||
"""Call the list_api_endpoints tool and verify we get real data."""
|
|
||||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool("list_api_endpoints", {})
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
assert len(result.content) >= 1
|
|
||||||
assert result.content[0]["type"] == "text"
|
|
||||||
|
|
||||||
data = json.loads(result.content[0]["text"])
|
|
||||||
assert "paths" in data or "urls" in data
|
|
||||||
# The OpenAI API should have many endpoints
|
|
||||||
total = data.get("total", len(data.get("paths", [])))
|
|
||||||
assert total > 50
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_tool_search(self):
|
|
||||||
"""Search for docs and verify we get results."""
|
|
||||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool(
|
|
||||||
"search_openai_docs", {"query": "chat completions", "limit": 3}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
assert len(result.content) >= 1
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_sse_response_handling(self):
|
|
||||||
"""Verify the client correctly handles SSE responses from a real server.
|
|
||||||
|
|
||||||
This is the key test — our local test server returns JSON,
|
|
||||||
but real MCP servers typically return SSE. This proves the
|
|
||||||
SSE parsing works end-to-end.
|
|
||||||
"""
|
|
||||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
|
||||||
# initialize() internally calls _send_request which must parse SSE
|
|
||||||
result = await client.initialize()
|
|
||||||
|
|
||||||
# If we got here without error, SSE parsing works
|
|
||||||
assert isinstance(result, dict)
|
|
||||||
assert "protocolVersion" in result
|
|
||||||
|
|
||||||
# Also verify list_tools works (another SSE response)
|
|
||||||
tools = await client.list_tools()
|
|
||||||
assert len(tools) > 0
|
|
||||||
assert all(hasattr(t, "name") for t in tools)
|
|
||||||
@@ -1,389 +0,0 @@
|
|||||||
"""
|
|
||||||
Integration tests for MCP client and MCPToolBlock against a real HTTP server.
|
|
||||||
|
|
||||||
These tests spin up a local MCP test server and run the full client/block flow
|
|
||||||
against it — no mocking, real HTTP requests.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import threading
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from aiohttp import web
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.blocks.mcp.block import MCPToolBlock
|
|
||||||
from backend.blocks.mcp.client import MCPClient
|
|
||||||
from backend.blocks.mcp.test_server import create_test_mcp_app
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
|
|
||||||
MOCK_USER_ID = "test-user-integration"
|
|
||||||
|
|
||||||
|
|
||||||
class _MCPTestServer:
|
|
||||||
"""
|
|
||||||
Run an MCP test server in a background thread with its own event loop.
|
|
||||||
This avoids event loop conflicts with pytest-asyncio.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, auth_token: str | None = None):
|
|
||||||
self.auth_token = auth_token
|
|
||||||
self.url: str = ""
|
|
||||||
self._runner: web.AppRunner | None = None
|
|
||||||
self._loop: asyncio.AbstractEventLoop | None = None
|
|
||||||
self._thread: threading.Thread | None = None
|
|
||||||
self._started = threading.Event()
|
|
||||||
|
|
||||||
def _run(self):
|
|
||||||
self._loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(self._loop)
|
|
||||||
self._loop.run_until_complete(self._start())
|
|
||||||
self._started.set()
|
|
||||||
self._loop.run_forever()
|
|
||||||
|
|
||||||
async def _start(self):
|
|
||||||
app = create_test_mcp_app(auth_token=self.auth_token)
|
|
||||||
self._runner = web.AppRunner(app)
|
|
||||||
await self._runner.setup()
|
|
||||||
site = web.TCPSite(self._runner, "127.0.0.1", 0)
|
|
||||||
await site.start()
|
|
||||||
port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr]
|
|
||||||
self.url = f"http://127.0.0.1:{port}/mcp"
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
|
||||||
self._thread.start()
|
|
||||||
if not self._started.wait(timeout=5):
|
|
||||||
raise RuntimeError("MCP test server failed to start within 5 seconds")
|
|
||||||
return self
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
if self._loop and self._runner:
|
|
||||||
asyncio.run_coroutine_threadsafe(self._runner.cleanup(), self._loop).result(
|
|
||||||
timeout=5
|
|
||||||
)
|
|
||||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
|
||||||
if self._thread:
|
|
||||||
self._thread.join(timeout=5)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def mcp_server():
|
|
||||||
"""Start a local MCP test server in a background thread."""
|
|
||||||
server = _MCPTestServer()
|
|
||||||
server.start()
|
|
||||||
yield server.url
|
|
||||||
server.stop()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def mcp_server_with_auth():
|
|
||||||
"""Start a local MCP test server with auth in a background thread."""
|
|
||||||
server = _MCPTestServer(auth_token="test-secret-token")
|
|
||||||
server.start()
|
|
||||||
yield server.url, "test-secret-token"
|
|
||||||
server.stop()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def _allow_localhost():
|
|
||||||
"""
|
|
||||||
Allow 127.0.0.1 through SSRF protection for integration tests.
|
|
||||||
|
|
||||||
The Requests class blocks private IPs by default. We patch the Requests
|
|
||||||
constructor to always include 127.0.0.1 as a trusted origin so the local
|
|
||||||
test server is reachable.
|
|
||||||
"""
|
|
||||||
from backend.util.request import Requests
|
|
||||||
|
|
||||||
original_init = Requests.__init__
|
|
||||||
|
|
||||||
def patched_init(self, *args, **kwargs):
|
|
||||||
trusted = list(kwargs.get("trusted_origins") or [])
|
|
||||||
trusted.append("http://127.0.0.1")
|
|
||||||
kwargs["trusted_origins"] = trusted
|
|
||||||
original_init(self, *args, **kwargs)
|
|
||||||
|
|
||||||
with patch.object(Requests, "__init__", patched_init):
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
def _make_client(url: str, auth_token: str | None = None) -> MCPClient:
|
|
||||||
"""Create an MCPClient for integration tests."""
|
|
||||||
return MCPClient(url, auth_token=auth_token)
|
|
||||||
|
|
||||||
|
|
||||||
# ── MCPClient integration tests ──────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPClientIntegration:
|
|
||||||
"""Test MCPClient against a real local MCP server."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_initialize(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
result = await client.initialize()
|
|
||||||
|
|
||||||
assert result["protocolVersion"] == "2025-03-26"
|
|
||||||
assert result["serverInfo"]["name"] == "test-mcp-server"
|
|
||||||
assert "tools" in result["capabilities"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_list_tools(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
assert len(tools) == 3
|
|
||||||
|
|
||||||
tool_names = {t.name for t in tools}
|
|
||||||
assert tool_names == {"get_weather", "add_numbers", "echo"}
|
|
||||||
|
|
||||||
# Check get_weather schema
|
|
||||||
weather = next(t for t in tools if t.name == "get_weather")
|
|
||||||
assert weather.description == "Get current weather for a city"
|
|
||||||
assert "city" in weather.input_schema["properties"]
|
|
||||||
assert weather.input_schema["required"] == ["city"]
|
|
||||||
|
|
||||||
# Check add_numbers schema
|
|
||||||
add = next(t for t in tools if t.name == "add_numbers")
|
|
||||||
assert "a" in add.input_schema["properties"]
|
|
||||||
assert "b" in add.input_schema["properties"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_tool_get_weather(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool("get_weather", {"city": "London"})
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
assert len(result.content) == 1
|
|
||||||
assert result.content[0]["type"] == "text"
|
|
||||||
|
|
||||||
data = json.loads(result.content[0]["text"])
|
|
||||||
assert data["city"] == "London"
|
|
||||||
assert data["temperature"] == 22
|
|
||||||
assert data["condition"] == "sunny"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_tool_add_numbers(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool("add_numbers", {"a": 3, "b": 7})
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
data = json.loads(result.content[0]["text"])
|
|
||||||
assert data["result"] == 10
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_tool_echo(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool("echo", {"message": "Hello MCP!"})
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
assert result.content[0]["text"] == "Hello MCP!"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_unknown_tool(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool("nonexistent_tool", {})
|
|
||||||
|
|
||||||
assert result.is_error
|
|
||||||
assert "Unknown tool" in result.content[0]["text"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_auth_success(self, mcp_server_with_auth):
|
|
||||||
url, token = mcp_server_with_auth
|
|
||||||
client = _make_client(url, auth_token=token)
|
|
||||||
result = await client.initialize()
|
|
||||||
|
|
||||||
assert result["protocolVersion"] == "2025-03-26"
|
|
||||||
|
|
||||||
tools = await client.list_tools()
|
|
||||||
assert len(tools) == 3
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_auth_failure(self, mcp_server_with_auth):
|
|
||||||
url, _ = mcp_server_with_auth
|
|
||||||
client = _make_client(url, auth_token="wrong-token")
|
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
await client.initialize()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_auth_missing(self, mcp_server_with_auth):
|
|
||||||
url, _ = mcp_server_with_auth
|
|
||||||
client = _make_client(url)
|
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
await client.initialize()
|
|
||||||
|
|
||||||
|
|
||||||
# ── MCPToolBlock integration tests ───────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPToolBlockIntegration:
|
|
||||||
"""Test MCPToolBlock end-to-end against a real local MCP server."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_full_flow_get_weather(self, mcp_server):
|
|
||||||
"""Full flow: discover tools, select one, execute it."""
|
|
||||||
# Step 1: Discover tools (simulating what the frontend/API would do)
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
tools = await client.list_tools()
|
|
||||||
assert len(tools) == 3
|
|
||||||
|
|
||||||
# Step 2: User selects "get_weather" and we get its schema
|
|
||||||
weather_tool = next(t for t in tools if t.name == "get_weather")
|
|
||||||
|
|
||||||
# Step 3: Execute the block — no credentials (public server)
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=mcp_server,
|
|
||||||
selected_tool="get_weather",
|
|
||||||
tool_input_schema=weather_tool.input_schema,
|
|
||||||
tool_arguments={"city": "Paris"},
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
result = outputs[0][1]
|
|
||||||
assert result["city"] == "Paris"
|
|
||||||
assert result["temperature"] == 22
|
|
||||||
assert result["condition"] == "sunny"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_full_flow_add_numbers(self, mcp_server):
|
|
||||||
"""Full flow for add_numbers tool."""
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
tools = await client.list_tools()
|
|
||||||
add_tool = next(t for t in tools if t.name == "add_numbers")
|
|
||||||
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=mcp_server,
|
|
||||||
selected_tool="add_numbers",
|
|
||||||
tool_input_schema=add_tool.input_schema,
|
|
||||||
tool_arguments={"a": 42, "b": 58},
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
assert outputs[0][1]["result"] == 100
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_full_flow_echo_plain_text(self, mcp_server):
|
|
||||||
"""Verify plain text (non-JSON) responses work."""
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=mcp_server,
|
|
||||||
selected_tool="echo",
|
|
||||||
tool_input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"message": {"type": "string"}},
|
|
||||||
"required": ["message"],
|
|
||||||
},
|
|
||||||
tool_arguments={"message": "Hello from AutoGPT!"},
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
assert outputs[0][1] == "Hello from AutoGPT!"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_full_flow_unknown_tool_yields_error(self, mcp_server):
|
|
||||||
"""Calling an unknown tool should yield an error output."""
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=mcp_server,
|
|
||||||
selected_tool="nonexistent_tool",
|
|
||||||
tool_arguments={},
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "error"
|
|
||||||
assert "returned an error" in outputs[0][1]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_full_flow_with_auth(self, mcp_server_with_auth):
|
|
||||||
"""Full flow with authentication via credentials kwarg."""
|
|
||||||
url, token = mcp_server_with_auth
|
|
||||||
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=url,
|
|
||||||
selected_tool="echo",
|
|
||||||
tool_input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"message": {"type": "string"}},
|
|
||||||
"required": ["message"],
|
|
||||||
},
|
|
||||||
tool_arguments={"message": "Authenticated!"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Pass credentials via the standard kwarg (as the executor would)
|
|
||||||
test_creds = OAuth2Credentials(
|
|
||||||
id="test-cred",
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr(token),
|
|
||||||
refresh_token=SecretStr(""),
|
|
||||||
scopes=[],
|
|
||||||
title="Test MCP credential",
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(
|
|
||||||
input_data, user_id=MOCK_USER_ID, credentials=test_creds
|
|
||||||
):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
assert outputs[0][1] == "Authenticated!"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_no_credentials_runs_without_auth(self, mcp_server):
|
|
||||||
"""Block runs without auth when no credentials are provided."""
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=mcp_server,
|
|
||||||
selected_tool="echo",
|
|
||||||
tool_input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"message": {"type": "string"}},
|
|
||||||
"required": ["message"],
|
|
||||||
},
|
|
||||||
tool_arguments={"message": "No auth needed"},
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(
|
|
||||||
input_data, user_id=MOCK_USER_ID, credentials=None
|
|
||||||
):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
assert outputs[0][1] == "No auth needed"
|
|
||||||
@@ -1,619 +0,0 @@
|
|||||||
"""
|
|
||||||
Tests for MCP client and MCPToolBlock.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from unittest.mock import AsyncMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from backend.blocks.mcp.block import MCPToolBlock
|
|
||||||
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError
|
|
||||||
from backend.util.test import execute_block_test
|
|
||||||
|
|
||||||
# ── SSE parsing unit tests ───────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestSSEParsing:
|
|
||||||
"""Tests for SSE (text/event-stream) response parsing."""
|
|
||||||
|
|
||||||
def test_parse_sse_simple(self):
|
|
||||||
sse = (
|
|
||||||
"event: message\n"
|
|
||||||
'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n'
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
body = MCPClient._parse_sse_response(sse)
|
|
||||||
assert body["result"] == {"tools": []}
|
|
||||||
assert body["id"] == 1
|
|
||||||
|
|
||||||
def test_parse_sse_with_notifications(self):
|
|
||||||
"""SSE streams can contain notifications (no id) before the response."""
|
|
||||||
sse = (
|
|
||||||
"event: message\n"
|
|
||||||
'data: {"jsonrpc":"2.0","method":"some/notification"}\n'
|
|
||||||
"\n"
|
|
||||||
"event: message\n"
|
|
||||||
'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n'
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
body = MCPClient._parse_sse_response(sse)
|
|
||||||
assert body["result"] == {"ok": True}
|
|
||||||
assert body["id"] == 2
|
|
||||||
|
|
||||||
def test_parse_sse_error_response(self):
|
|
||||||
sse = (
|
|
||||||
"event: message\n"
|
|
||||||
'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n'
|
|
||||||
)
|
|
||||||
body = MCPClient._parse_sse_response(sse)
|
|
||||||
assert "error" in body
|
|
||||||
assert body["error"]["code"] == -32600
|
|
||||||
|
|
||||||
def test_parse_sse_no_data_raises(self):
|
|
||||||
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
|
|
||||||
MCPClient._parse_sse_response("event: message\n\n")
|
|
||||||
|
|
||||||
def test_parse_sse_empty_raises(self):
|
|
||||||
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
|
|
||||||
MCPClient._parse_sse_response("")
|
|
||||||
|
|
||||||
def test_parse_sse_ignores_non_data_lines(self):
|
|
||||||
sse = (
|
|
||||||
": comment line\n"
|
|
||||||
"event: message\n"
|
|
||||||
"id: 123\n"
|
|
||||||
'data: {"jsonrpc":"2.0","result":"ok","id":1}\n'
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
body = MCPClient._parse_sse_response(sse)
|
|
||||||
assert body["result"] == "ok"
|
|
||||||
|
|
||||||
def test_parse_sse_uses_last_response(self):
|
|
||||||
"""If multiple responses exist, use the last one."""
|
|
||||||
sse = (
|
|
||||||
'data: {"jsonrpc":"2.0","result":"first","id":1}\n'
|
|
||||||
"\n"
|
|
||||||
'data: {"jsonrpc":"2.0","result":"second","id":2}\n'
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
body = MCPClient._parse_sse_response(sse)
|
|
||||||
assert body["result"] == "second"
|
|
||||||
|
|
||||||
|
|
||||||
# ── MCPClient unit tests ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPClient:
|
|
||||||
"""Tests for the MCP HTTP client."""
|
|
||||||
|
|
||||||
def test_build_headers_without_auth(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
headers = client._build_headers()
|
|
||||||
assert "Authorization" not in headers
|
|
||||||
assert headers["Content-Type"] == "application/json"
|
|
||||||
|
|
||||||
def test_build_headers_with_auth(self):
|
|
||||||
client = MCPClient("https://mcp.example.com", auth_token="my-token")
|
|
||||||
headers = client._build_headers()
|
|
||||||
assert headers["Authorization"] == "Bearer my-token"
|
|
||||||
|
|
||||||
def test_build_jsonrpc_request(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
req = client._build_jsonrpc_request("tools/list")
|
|
||||||
assert req["jsonrpc"] == "2.0"
|
|
||||||
assert req["method"] == "tools/list"
|
|
||||||
assert "id" in req
|
|
||||||
assert "params" not in req
|
|
||||||
|
|
||||||
def test_build_jsonrpc_request_with_params(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
req = client._build_jsonrpc_request(
|
|
||||||
"tools/call", {"name": "test", "arguments": {"x": 1}}
|
|
||||||
)
|
|
||||||
assert req["params"] == {"name": "test", "arguments": {"x": 1}}
|
|
||||||
|
|
||||||
def test_request_id_increments(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
req1 = client._build_jsonrpc_request("tools/list")
|
|
||||||
req2 = client._build_jsonrpc_request("tools/list")
|
|
||||||
assert req2["id"] > req1["id"]
|
|
||||||
|
|
||||||
def test_server_url_trailing_slash_stripped(self):
|
|
||||||
client = MCPClient("https://mcp.example.com/mcp/")
|
|
||||||
assert client.server_url == "https://mcp.example.com/mcp"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_send_request_success(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
mock_response = AsyncMock()
|
|
||||||
mock_response.json.return_value = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"result": {"tools": []},
|
|
||||||
"id": 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value={"tools": []}):
|
|
||||||
result = await client._send_request("tools/list")
|
|
||||||
assert result == {"tools": []}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_send_request_error(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
async def mock_send(*args, **kwargs):
|
|
||||||
raise MCPClientError("MCP server error [-32600]: Invalid Request")
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", side_effect=mock_send):
|
|
||||||
with pytest.raises(MCPClientError, match="Invalid Request"):
|
|
||||||
await client._send_request("tools/list")
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_list_tools(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
mock_result = {
|
|
||||||
"tools": [
|
|
||||||
{
|
|
||||||
"name": "get_weather",
|
|
||||||
"description": "Get current weather for a city",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"city": {"type": "string"}},
|
|
||||||
"required": ["city"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "search",
|
|
||||||
"description": "Search the web",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"query": {"type": "string"}},
|
|
||||||
"required": ["query"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value=mock_result):
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
assert len(tools) == 2
|
|
||||||
assert tools[0].name == "get_weather"
|
|
||||||
assert tools[0].description == "Get current weather for a city"
|
|
||||||
assert tools[0].input_schema["properties"]["city"]["type"] == "string"
|
|
||||||
assert tools[1].name == "search"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_list_tools_empty(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value={"tools": []}):
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
assert tools == []
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_list_tools_none_result(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value=None):
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
assert tools == []
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_tool_success(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
mock_result = {
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": json.dumps({"temp": 20, "city": "London"})}
|
|
||||||
],
|
|
||||||
"isError": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value=mock_result):
|
|
||||||
result = await client.call_tool("get_weather", {"city": "London"})
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
assert len(result.content) == 1
|
|
||||||
assert result.content[0]["type"] == "text"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_tool_error(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
mock_result = {
|
|
||||||
"content": [{"type": "text", "text": "City not found"}],
|
|
||||||
"isError": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value=mock_result):
|
|
||||||
result = await client.call_tool("get_weather", {"city": "???"})
|
|
||||||
|
|
||||||
assert result.is_error
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_tool_none_result(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value=None):
|
|
||||||
result = await client.call_tool("get_weather", {"city": "London"})
|
|
||||||
|
|
||||||
assert result.is_error
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_initialize(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
mock_result = {
|
|
||||||
"protocolVersion": "2025-03-26",
|
|
||||||
"capabilities": {"tools": {}},
|
|
||||||
"serverInfo": {"name": "test-server", "version": "1.0.0"},
|
|
||||||
}
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(client, "_send_request", return_value=mock_result) as mock_req,
|
|
||||||
patch.object(client, "_send_notification") as mock_notif,
|
|
||||||
):
|
|
||||||
result = await client.initialize()
|
|
||||||
|
|
||||||
mock_req.assert_called_once()
|
|
||||||
mock_notif.assert_called_once_with("notifications/initialized")
|
|
||||||
assert result["protocolVersion"] == "2025-03-26"
|
|
||||||
|
|
||||||
|
|
||||||
# ── MCPToolBlock unit tests ──────────────────────────────────────────
|
|
||||||
|
|
||||||
MOCK_USER_ID = "test-user-123"
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPToolBlock:
|
|
||||||
"""Tests for the MCPToolBlock."""
|
|
||||||
|
|
||||||
def test_block_instantiation(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
assert block.id == "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
|
|
||||||
assert block.name == "MCPToolBlock"
|
|
||||||
|
|
||||||
def test_input_schema_has_required_fields(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
schema = block.input_schema.jsonschema()
|
|
||||||
props = schema.get("properties", {})
|
|
||||||
assert "server_url" in props
|
|
||||||
assert "selected_tool" in props
|
|
||||||
assert "tool_arguments" in props
|
|
||||||
assert "credentials" in props
|
|
||||||
|
|
||||||
def test_output_schema(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
schema = block.output_schema.jsonschema()
|
|
||||||
props = schema.get("properties", {})
|
|
||||||
assert "result" in props
|
|
||||||
assert "error" in props
|
|
||||||
|
|
||||||
def test_get_input_schema_with_tool_schema(self):
|
|
||||||
tool_schema = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"query": {"type": "string"}},
|
|
||||||
"required": ["query"],
|
|
||||||
}
|
|
||||||
data = {"tool_input_schema": tool_schema}
|
|
||||||
result = MCPToolBlock.Input.get_input_schema(data)
|
|
||||||
assert result == tool_schema
|
|
||||||
|
|
||||||
def test_get_input_schema_without_tool_schema(self):
|
|
||||||
result = MCPToolBlock.Input.get_input_schema({})
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
def test_get_input_defaults(self):
|
|
||||||
data = {"tool_arguments": {"city": "London"}}
|
|
||||||
result = MCPToolBlock.Input.get_input_defaults(data)
|
|
||||||
assert result == {"city": "London"}
|
|
||||||
|
|
||||||
def test_get_missing_input(self):
|
|
||||||
data = {
|
|
||||||
"tool_input_schema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"city": {"type": "string"},
|
|
||||||
"units": {"type": "string"},
|
|
||||||
},
|
|
||||||
"required": ["city", "units"],
|
|
||||||
},
|
|
||||||
"tool_arguments": {"city": "London"},
|
|
||||||
}
|
|
||||||
missing = MCPToolBlock.Input.get_missing_input(data)
|
|
||||||
assert missing == {"units"}
|
|
||||||
|
|
||||||
def test_get_missing_input_all_present(self):
|
|
||||||
data = {
|
|
||||||
"tool_input_schema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"city": {"type": "string"}},
|
|
||||||
"required": ["city"],
|
|
||||||
},
|
|
||||||
"tool_arguments": {"city": "London"},
|
|
||||||
}
|
|
||||||
missing = MCPToolBlock.Input.get_missing_input(data)
|
|
||||||
assert missing == set()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_run_with_mock(self):
|
|
||||||
"""Test the block using the built-in test infrastructure."""
|
|
||||||
block = MCPToolBlock()
|
|
||||||
await execute_block_test(block)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_run_missing_server_url(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="",
|
|
||||||
selected_tool="test",
|
|
||||||
)
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
assert outputs == [("error", "MCP server URL is required")]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_run_missing_tool(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="https://mcp.example.com/mcp",
|
|
||||||
selected_tool="",
|
|
||||||
)
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
assert outputs == [
|
|
||||||
("error", "No tool selected. Please select a tool from the dropdown.")
|
|
||||||
]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_run_success(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="https://mcp.example.com/mcp",
|
|
||||||
selected_tool="get_weather",
|
|
||||||
tool_input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"city": {"type": "string"}},
|
|
||||||
},
|
|
||||||
tool_arguments={"city": "London"},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_call(*args, **kwargs):
|
|
||||||
return {"temp": 20, "city": "London"}
|
|
||||||
|
|
||||||
block._call_mcp_tool = mock_call # type: ignore
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
assert outputs[0][1] == {"temp": 20, "city": "London"}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_run_mcp_error(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="https://mcp.example.com/mcp",
|
|
||||||
selected_tool="bad_tool",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_call(*args, **kwargs):
|
|
||||||
raise MCPClientError("Tool not found")
|
|
||||||
|
|
||||||
block._call_mcp_tool = mock_call # type: ignore
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert outputs[0][0] == "error"
|
|
||||||
assert "Tool not found" in outputs[0][1]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_mcp_tool_parses_json_text(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
|
|
||||||
mock_result = MCPCallResult(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": '{"temp": 20}'},
|
|
||||||
],
|
|
||||||
is_error=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_init(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def mock_call(self, name, args):
|
|
||||||
return mock_result
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MCPClient, "initialize", mock_init),
|
|
||||||
patch.object(MCPClient, "call_tool", mock_call),
|
|
||||||
):
|
|
||||||
result = await block._call_mcp_tool(
|
|
||||||
"https://mcp.example.com", "test_tool", {}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == {"temp": 20}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_mcp_tool_plain_text(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
|
|
||||||
mock_result = MCPCallResult(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": "Hello, world!"},
|
|
||||||
],
|
|
||||||
is_error=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_init(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def mock_call(self, name, args):
|
|
||||||
return mock_result
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MCPClient, "initialize", mock_init),
|
|
||||||
patch.object(MCPClient, "call_tool", mock_call),
|
|
||||||
):
|
|
||||||
result = await block._call_mcp_tool(
|
|
||||||
"https://mcp.example.com", "test_tool", {}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == "Hello, world!"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_mcp_tool_multiple_content(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
|
|
||||||
mock_result = MCPCallResult(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": "Part 1"},
|
|
||||||
{"type": "text", "text": '{"part": 2}'},
|
|
||||||
],
|
|
||||||
is_error=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_init(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def mock_call(self, name, args):
|
|
||||||
return mock_result
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MCPClient, "initialize", mock_init),
|
|
||||||
patch.object(MCPClient, "call_tool", mock_call),
|
|
||||||
):
|
|
||||||
result = await block._call_mcp_tool(
|
|
||||||
"https://mcp.example.com", "test_tool", {}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == ["Part 1", {"part": 2}]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_mcp_tool_error_result(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
|
|
||||||
mock_result = MCPCallResult(
|
|
||||||
content=[{"type": "text", "text": "Something went wrong"}],
|
|
||||||
is_error=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_init(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def mock_call(self, name, args):
|
|
||||||
return mock_result
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MCPClient, "initialize", mock_init),
|
|
||||||
patch.object(MCPClient, "call_tool", mock_call),
|
|
||||||
):
|
|
||||||
with pytest.raises(MCPClientError, match="returned an error"):
|
|
||||||
await block._call_mcp_tool("https://mcp.example.com", "test_tool", {})
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_mcp_tool_image_content(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
|
|
||||||
mock_result = MCPCallResult(
|
|
||||||
content=[
|
|
||||||
{
|
|
||||||
"type": "image",
|
|
||||||
"data": "base64data==",
|
|
||||||
"mimeType": "image/png",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
is_error=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_init(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def mock_call(self, name, args):
|
|
||||||
return mock_result
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MCPClient, "initialize", mock_init),
|
|
||||||
patch.object(MCPClient, "call_tool", mock_call),
|
|
||||||
):
|
|
||||||
result = await block._call_mcp_tool(
|
|
||||||
"https://mcp.example.com", "test_tool", {}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == {
|
|
||||||
"type": "image",
|
|
||||||
"data": "base64data==",
|
|
||||||
"mimeType": "image/png",
|
|
||||||
}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_run_with_credentials(self):
|
|
||||||
"""Verify the block uses OAuth2Credentials and passes auth token."""
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="https://mcp.example.com/mcp",
|
|
||||||
selected_tool="test_tool",
|
|
||||||
)
|
|
||||||
|
|
||||||
captured_tokens: list[str | None] = []
|
|
||||||
|
|
||||||
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
|
||||||
captured_tokens.append(auth_token)
|
|
||||||
return "ok"
|
|
||||||
|
|
||||||
block._call_mcp_tool = mock_call # type: ignore
|
|
||||||
|
|
||||||
test_creds = OAuth2Credentials(
|
|
||||||
id="cred-123",
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("resolved-token"),
|
|
||||||
refresh_token=SecretStr(""),
|
|
||||||
scopes=[],
|
|
||||||
title="Test MCP credential",
|
|
||||||
)
|
|
||||||
|
|
||||||
async for _ in block.run(
|
|
||||||
input_data, user_id=MOCK_USER_ID, credentials=test_creds
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
|
|
||||||
assert captured_tokens == ["resolved-token"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_run_without_credentials(self):
|
|
||||||
"""Verify the block works without credentials (public server)."""
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="https://mcp.example.com/mcp",
|
|
||||||
selected_tool="test_tool",
|
|
||||||
)
|
|
||||||
|
|
||||||
captured_tokens: list[str | None] = []
|
|
||||||
|
|
||||||
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
|
||||||
captured_tokens.append(auth_token)
|
|
||||||
return "ok"
|
|
||||||
|
|
||||||
block._call_mcp_tool = mock_call # type: ignore
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert captured_tokens == [None]
|
|
||||||
assert outputs == [("result", "ok")]
|
|
||||||
@@ -1,242 +0,0 @@
|
|||||||
"""
|
|
||||||
Tests for MCP OAuth handler.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.blocks.mcp.client import MCPClient
|
|
||||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
|
|
||||||
|
|
||||||
def _mock_response(json_data: dict, status: int = 200) -> MagicMock:
|
|
||||||
"""Create a mock Response with synchronous json() (matching Requests.Response)."""
|
|
||||||
resp = MagicMock()
|
|
||||||
resp.status = status
|
|
||||||
resp.ok = 200 <= status < 300
|
|
||||||
resp.json.return_value = json_data
|
|
||||||
return resp
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPOAuthHandler:
|
|
||||||
"""Tests for the MCPOAuthHandler."""
|
|
||||||
|
|
||||||
def _make_handler(self, **overrides) -> MCPOAuthHandler:
|
|
||||||
defaults = {
|
|
||||||
"client_id": "test-client-id",
|
|
||||||
"client_secret": "test-client-secret",
|
|
||||||
"redirect_uri": "https://app.example.com/callback",
|
|
||||||
"authorize_url": "https://auth.example.com/authorize",
|
|
||||||
"token_url": "https://auth.example.com/token",
|
|
||||||
}
|
|
||||||
defaults.update(overrides)
|
|
||||||
return MCPOAuthHandler(**defaults)
|
|
||||||
|
|
||||||
def test_get_login_url_basic(self):
|
|
||||||
handler = self._make_handler()
|
|
||||||
url = handler.get_login_url(
|
|
||||||
scopes=["read", "write"],
|
|
||||||
state="random-state-token",
|
|
||||||
code_challenge="S256-challenge-value",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "https://auth.example.com/authorize?" in url
|
|
||||||
assert "response_type=code" in url
|
|
||||||
assert "client_id=test-client-id" in url
|
|
||||||
assert "state=random-state-token" in url
|
|
||||||
assert "code_challenge=S256-challenge-value" in url
|
|
||||||
assert "code_challenge_method=S256" in url
|
|
||||||
assert "scope=read+write" in url
|
|
||||||
|
|
||||||
def test_get_login_url_with_resource(self):
|
|
||||||
handler = self._make_handler(resource_url="https://mcp.example.com/mcp")
|
|
||||||
url = handler.get_login_url(
|
|
||||||
scopes=[], state="state", code_challenge="challenge"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "resource=https" in url
|
|
||||||
|
|
||||||
def test_get_login_url_without_pkce(self):
|
|
||||||
handler = self._make_handler()
|
|
||||||
url = handler.get_login_url(scopes=["read"], state="state", code_challenge=None)
|
|
||||||
|
|
||||||
assert "code_challenge" not in url
|
|
||||||
assert "code_challenge_method" not in url
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_exchange_code_for_tokens(self):
|
|
||||||
handler = self._make_handler()
|
|
||||||
|
|
||||||
resp = _mock_response(
|
|
||||||
{
|
|
||||||
"access_token": "new-access-token",
|
|
||||||
"refresh_token": "new-refresh-token",
|
|
||||||
"expires_in": 3600,
|
|
||||||
"token_type": "Bearer",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.post = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
creds = await handler.exchange_code_for_tokens(
|
|
||||||
code="auth-code",
|
|
||||||
scopes=["read"],
|
|
||||||
code_verifier="pkce-verifier",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(creds, OAuth2Credentials)
|
|
||||||
assert creds.access_token.get_secret_value() == "new-access-token"
|
|
||||||
assert creds.refresh_token is not None
|
|
||||||
assert creds.refresh_token.get_secret_value() == "new-refresh-token"
|
|
||||||
assert creds.scopes == ["read"]
|
|
||||||
assert creds.access_token_expires_at is not None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_refresh_tokens(self):
|
|
||||||
handler = self._make_handler()
|
|
||||||
|
|
||||||
existing_creds = OAuth2Credentials(
|
|
||||||
id="existing-id",
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("old-token"),
|
|
||||||
refresh_token=SecretStr("old-refresh"),
|
|
||||||
scopes=["read"],
|
|
||||||
title="test",
|
|
||||||
)
|
|
||||||
|
|
||||||
resp = _mock_response(
|
|
||||||
{
|
|
||||||
"access_token": "refreshed-token",
|
|
||||||
"refresh_token": "new-refresh",
|
|
||||||
"expires_in": 3600,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.post = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
refreshed = await handler._refresh_tokens(existing_creds)
|
|
||||||
|
|
||||||
assert refreshed.id == "existing-id"
|
|
||||||
assert refreshed.access_token.get_secret_value() == "refreshed-token"
|
|
||||||
assert refreshed.refresh_token is not None
|
|
||||||
assert refreshed.refresh_token.get_secret_value() == "new-refresh"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_refresh_tokens_no_refresh_token(self):
|
|
||||||
handler = self._make_handler()
|
|
||||||
|
|
||||||
creds = OAuth2Credentials(
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("token"),
|
|
||||||
scopes=["read"],
|
|
||||||
title="test",
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="No refresh token"):
|
|
||||||
await handler._refresh_tokens(creds)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_revoke_tokens_no_url(self):
|
|
||||||
handler = self._make_handler(revoke_url=None)
|
|
||||||
|
|
||||||
creds = OAuth2Credentials(
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("token"),
|
|
||||||
scopes=[],
|
|
||||||
title="test",
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await handler.revoke_tokens(creds)
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_revoke_tokens_with_url(self):
|
|
||||||
handler = self._make_handler(revoke_url="https://auth.example.com/revoke")
|
|
||||||
|
|
||||||
creds = OAuth2Credentials(
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("token"),
|
|
||||||
scopes=[],
|
|
||||||
title="test",
|
|
||||||
)
|
|
||||||
|
|
||||||
resp = _mock_response({}, status=200)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.post = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
result = await handler.revoke_tokens(creds)
|
|
||||||
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPClientDiscovery:
|
|
||||||
"""Tests for MCPClient OAuth metadata discovery."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_auth_found(self):
|
|
||||||
client = MCPClient("https://mcp.example.com/mcp")
|
|
||||||
|
|
||||||
metadata = {
|
|
||||||
"authorization_servers": ["https://auth.example.com"],
|
|
||||||
"resource": "https://mcp.example.com/mcp",
|
|
||||||
}
|
|
||||||
|
|
||||||
resp = _mock_response(metadata, status=200)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.get = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
result = await client.discover_auth()
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
assert result["authorization_servers"] == ["https://auth.example.com"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_auth_not_found(self):
|
|
||||||
client = MCPClient("https://mcp.example.com/mcp")
|
|
||||||
|
|
||||||
resp = _mock_response({}, status=404)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.get = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
result = await client.discover_auth()
|
|
||||||
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_auth_server_metadata(self):
|
|
||||||
client = MCPClient("https://mcp.example.com/mcp")
|
|
||||||
|
|
||||||
server_metadata = {
|
|
||||||
"issuer": "https://auth.example.com",
|
|
||||||
"authorization_endpoint": "https://auth.example.com/authorize",
|
|
||||||
"token_endpoint": "https://auth.example.com/token",
|
|
||||||
"registration_endpoint": "https://auth.example.com/register",
|
|
||||||
"code_challenge_methods_supported": ["S256"],
|
|
||||||
}
|
|
||||||
|
|
||||||
resp = _mock_response(server_metadata, status=200)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.get = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
result = await client.discover_auth_server_metadata(
|
|
||||||
"https://auth.example.com"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
assert result["authorization_endpoint"] == "https://auth.example.com/authorize"
|
|
||||||
assert result["token_endpoint"] == "https://auth.example.com/token"
|
|
||||||
@@ -1,162 +0,0 @@
|
|||||||
"""
|
|
||||||
Minimal MCP server for integration testing.
|
|
||||||
|
|
||||||
Implements the MCP Streamable HTTP transport (JSON-RPC 2.0 over HTTP POST)
|
|
||||||
with a few sample tools. Runs on localhost with a random available port.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from aiohttp import web
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Sample tools this test server exposes
|
|
||||||
TEST_TOOLS = [
|
|
||||||
{
|
|
||||||
"name": "get_weather",
|
|
||||||
"description": "Get current weather for a city",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"city": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "City name",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["city"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "add_numbers",
|
|
||||||
"description": "Add two numbers together",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"a": {"type": "number", "description": "First number"},
|
|
||||||
"b": {"type": "number", "description": "Second number"},
|
|
||||||
},
|
|
||||||
"required": ["a", "b"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "echo",
|
|
||||||
"description": "Echo back the input message",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"message": {"type": "string", "description": "Message to echo"},
|
|
||||||
},
|
|
||||||
"required": ["message"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_initialize(params: dict) -> dict:
|
|
||||||
return {
|
|
||||||
"protocolVersion": "2025-03-26",
|
|
||||||
"capabilities": {"tools": {"listChanged": False}},
|
|
||||||
"serverInfo": {"name": "test-mcp-server", "version": "1.0.0"},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_tools_list(params: dict) -> dict:
|
|
||||||
return {"tools": TEST_TOOLS}
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_tools_call(params: dict) -> dict:
|
|
||||||
tool_name = params.get("name", "")
|
|
||||||
arguments = params.get("arguments", {})
|
|
||||||
|
|
||||||
if tool_name == "get_weather":
|
|
||||||
city = arguments.get("city", "Unknown")
|
|
||||||
return {
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": json.dumps(
|
|
||||||
{"city": city, "temperature": 22, "condition": "sunny"}
|
|
||||||
),
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
elif tool_name == "add_numbers":
|
|
||||||
a = arguments.get("a", 0)
|
|
||||||
b = arguments.get("b", 0)
|
|
||||||
return {
|
|
||||||
"content": [{"type": "text", "text": json.dumps({"result": a + b})}],
|
|
||||||
}
|
|
||||||
|
|
||||||
elif tool_name == "echo":
|
|
||||||
message = arguments.get("message", "")
|
|
||||||
return {
|
|
||||||
"content": [{"type": "text", "text": message}],
|
|
||||||
}
|
|
||||||
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}],
|
|
||||||
"isError": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
HANDLERS = {
|
|
||||||
"initialize": _handle_initialize,
|
|
||||||
"tools/list": _handle_tools_list,
|
|
||||||
"tools/call": _handle_tools_call,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_mcp_request(request: web.Request) -> web.Response:
|
|
||||||
"""Handle incoming MCP JSON-RPC 2.0 requests."""
|
|
||||||
# Check auth if configured
|
|
||||||
expected_token = request.app.get("auth_token")
|
|
||||||
if expected_token:
|
|
||||||
auth_header = request.headers.get("Authorization", "")
|
|
||||||
if auth_header != f"Bearer {expected_token}":
|
|
||||||
return web.json_response(
|
|
||||||
{
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"error": {"code": -32001, "message": "Unauthorized"},
|
|
||||||
"id": None,
|
|
||||||
},
|
|
||||||
status=401,
|
|
||||||
)
|
|
||||||
|
|
||||||
body = await request.json()
|
|
||||||
|
|
||||||
# Handle notifications (no id field) — just acknowledge
|
|
||||||
if "id" not in body:
|
|
||||||
return web.Response(status=202)
|
|
||||||
|
|
||||||
method = body.get("method", "")
|
|
||||||
params = body.get("params", {})
|
|
||||||
request_id = body.get("id")
|
|
||||||
|
|
||||||
handler = HANDLERS.get(method)
|
|
||||||
if not handler:
|
|
||||||
return web.json_response(
|
|
||||||
{
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"error": {
|
|
||||||
"code": -32601,
|
|
||||||
"message": f"Method not found: {method}",
|
|
||||||
},
|
|
||||||
"id": request_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
result = handler(params)
|
|
||||||
return web.json_response({"jsonrpc": "2.0", "result": result, "id": request_id})
|
|
||||||
|
|
||||||
|
|
||||||
def create_test_mcp_app(auth_token: str | None = None) -> web.Application:
|
|
||||||
"""Create an aiohttp app that acts as an MCP server."""
|
|
||||||
app = web.Application()
|
|
||||||
app.router.add_post("/mcp", handle_mcp_request)
|
|
||||||
if auth_token:
|
|
||||||
app["auth_token"] = auth_token
|
|
||||||
return app
|
|
||||||
@@ -3,7 +3,7 @@ from typing import List, Literal
|
|||||||
|
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import Any, Literal, Optional, Union
|
|||||||
from mem0 import MemoryClient
|
from mem0 import MemoryClient
|
||||||
from pydantic import BaseModel, SecretStr
|
from pydantic import BaseModel, SecretStr
|
||||||
|
|
||||||
from backend.blocks._base import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
|
from backend.data.block import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
from pydantic import model_validator
|
from pydantic import model_validator
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from typing import List, Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
from backend.blocks._base import (
|
from backend.blocks.nvidia._auth import (
|
||||||
|
NvidiaCredentials,
|
||||||
|
NvidiaCredentialsField,
|
||||||
|
NvidiaCredentialsInput,
|
||||||
|
)
|
||||||
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
from backend.blocks.nvidia._auth import (
|
|
||||||
NvidiaCredentials,
|
|
||||||
NvidiaCredentialsField,
|
|
||||||
NvidiaCredentialsInput,
|
|
||||||
)
|
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
from backend.util.type import MediaFileType
|
from backend.util.type import MediaFileType
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Any, Literal
|
|||||||
import openai
|
import openai
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.data.block import (
|
||||||
Block,
|
Block,
|
||||||
BlockCategory,
|
BlockCategory,
|
||||||
BlockOutput,
|
BlockOutput,
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user