Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-09 06:15:41 -05:00)

Compare commits: fix/sentry...feat/mcp-b (31 commits)
Commits:

54375065d5
d62fde9445
03487f7b4d
df41d02fce
7c9e47ba76
e59e8dd9a9
7aab2eb1d5
5ab28ccda2
4fe0f05980
19b3373052
7db3f12876
e9b996abb0
9b972389a0
cd64562e1b
8fddc9d71f
3d1cd03fc8
e7ebe42306
e0fab7e34e
29ee85c86f
85b6520710
bfa942e032
11256076d8
3ca2387631
ed07f02738
b121030c94
c22c18374d
e40233a3ac
3ae5eabf9d
a077ba9f03
5401d54eaa
5ac89d7c0b
.github/workflows/platform-frontend-ci.yml (vendored, 16 changes)

@@ -27,11 +27,20 @@ jobs:
     runs-on: ubuntu-latest
     outputs:
       cache-key: ${{ steps.cache-key.outputs.key }}
+      components-changed: ${{ steps.filter.outputs.components }}

     steps:
       - name: Checkout repository
         uses: actions/checkout@v4

+      - name: Check for component changes
+        uses: dorny/paths-filter@v3
+        id: filter
+        with:
+          filters: |
+            components:
+              - 'autogpt_platform/frontend/src/components/**'
+
       - name: Set up Node.js
         uses: actions/setup-node@v4
         with:

@@ -90,8 +99,11 @@ jobs:
   chromatic:
     runs-on: ubuntu-latest
     needs: setup
-    # Only run on dev branch pushes or PRs targeting dev
-    if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'
+    # Disabled: to re-enable, remove 'false &&' from the condition below
+    if: >-
+      false
+      && (github.ref == 'refs/heads/dev' || github.base_ref == 'dev')
+      && needs.setup.outputs.components-changed == 'true'

     steps:
       - name: Checkout repository
autogpt_platform/autogpt_libs/poetry.lock (generated, 1328 changes)

File diff suppressed because it is too large.
@@ -11,15 +11,15 @@ python = ">=3.10,<4.0"
 colorama = "^0.4.6"
 cryptography = "^45.0"
 expiringdict = "^1.2.2"
-fastapi = "^0.116.1"
-google-cloud-logging = "^3.12.1"
-launchdarkly-server-sdk = "^9.12.0"
-pydantic = "^2.11.7"
-pydantic-settings = "^2.10.1"
-pyjwt = { version = "^2.10.1", extras = ["crypto"] }
+fastapi = "^0.128.0"
+google-cloud-logging = "^3.13.0"
+launchdarkly-server-sdk = "^9.14.1"
+pydantic = "^2.12.5"
+pydantic-settings = "^2.12.0"
+pyjwt = { version = "^2.11.0", extras = ["crypto"] }
 redis = "^6.2.0"
-supabase = "^2.16.0"
-uvicorn = "^0.35.0"
+supabase = "^2.27.2"
+uvicorn = "^0.40.0"

 [tool.poetry.group.dev.dependencies]
 pyright = "^1.1.404"
@@ -152,6 +152,7 @@ REPLICATE_API_KEY=
 REVID_API_KEY=
 SCREENSHOTONE_API_KEY=
 UNREAL_SPEECH_API_KEY=
+ELEVENLABS_API_KEY=

 # Data & Search Services
 E2B_API_KEY=
autogpt_platform/backend/.gitignore (vendored, 3 changes)

@@ -19,3 +19,6 @@ load-tests/*.json
 load-tests/*.log
 load-tests/node_modules/*
 migrations/*/rollback*.sql
+
+# Workspace files
+workspaces/
@@ -62,10 +62,12 @@ ENV POETRY_HOME=/opt/poetry \
     DEBIAN_FRONTEND=noninteractive
 ENV PATH=/opt/poetry/bin:$PATH

-# Install Python without upgrading system-managed packages
+# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
 RUN apt-get update && apt-get install -y \
     python3.13 \
     python3-pip \
+    ffmpeg \
+    imagemagick \
     && rm -rf /var/lib/apt/lists/*

 # Copy only necessary files from builder
autogpt_platform/backend/MCP_BLOCK_IMPLEMENTATION.md (new file, 76 lines)

@@ -0,0 +1,76 @@
+# MCP Block Implementation Plan
+
+## Overview
+
+Create a single **MCPBlock** that dynamically integrates with any MCP (Model Context Protocol)
+server. Users provide a server URL, the block discovers available tools, presents them as a
+dropdown, and dynamically adjusts input/output schema based on the selected tool — exactly like
+`AgentExecutorBlock` handles dynamic schemas.
+
+## Architecture
+
+```
+User provides MCP server URL + credentials
+        ↓
+MCPBlock fetches tools via MCP protocol (tools/list)
+        ↓
+User selects tool from dropdown (stored in constantInput)
+        ↓
+Input schema dynamically updates based on selected tool's inputSchema
+        ↓
+On execution: MCPBlock calls the tool via MCP protocol (tools/call)
+        ↓
+Result yielded as block output
+```
+
+## Design Decisions
+
+1. **Single block, not many blocks** — One `MCPBlock` handles all MCP servers/tools
+2. **Dynamic schema via AgentExecutorBlock pattern** — Override `get_input_schema()`,
+   `get_input_defaults()`, `get_missing_input()` on the Input class
+3. **Auth via API key or OAuth2 credentials** — Use existing `APIKeyCredentials` or
+   `OAuth2Credentials` with `ProviderName.MCP` provider. API keys are sent as Bearer tokens;
+   OAuth2 uses the access token.
+4. **HTTP-based MCP client** — Use `aiohttp` (already a dependency) to implement MCP Streamable
+   HTTP transport directly. No need for the `mcp` Python SDK — the protocol is simple JSON-RPC
+   over HTTP. Handles both JSON and SSE response formats.
+5. **No new DB tables** — Everything fits in existing `AgentBlock` + `AgentNode` tables
+
+## Implementation Files
+
+### New Files
+
+- `backend/blocks/mcp/` — MCP block package
+  - `__init__.py`
+  - `block.py` — MCPToolBlock implementation
+  - `client.py` — MCP HTTP client (list_tools, call_tool)
+  - `oauth.py` — MCP OAuth handler for dynamic endpoint discovery
+  - `test_mcp.py` — Unit tests
+  - `test_oauth.py` — OAuth handler tests
+  - `test_integration.py` — Integration tests with local test server
+  - `test_e2e.py` — E2E tests against real MCP servers
+
+### Modified Files
+
+- `backend/integrations/providers.py` — Add `MCP = "mcp"` to ProviderName
+
+## Dev Loop
+
+```bash
+cd autogpt_platform/backend
+poetry run pytest backend/blocks/mcp/test_mcp.py -xvs          # Unit tests
+poetry run pytest backend/blocks/mcp/test_oauth.py -xvs        # OAuth tests
+poetry run pytest backend/blocks/mcp/test_integration.py -xvs  # Integration tests
+poetry run pytest backend/blocks/mcp/ -xvs                     # All MCP tests
+```
+
+## Status
+
+- [x] Research & Design
+- [x] Add ProviderName.MCP
+- [x] Implement MCP client (client.py)
+- [x] Implement MCPToolBlock (block.py)
+- [x] Add OAuth2 support (oauth.py)
+- [x] Write unit tests
+- [x] Write integration tests
+- [x] Write E2E tests
+- [x] Run tests & fix issues
+- [x] Create PR
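Design decision 4 above says the MCP wire format is plain JSON-RPC 2.0 over HTTP. A minimal sketch of that exchange with `aiohttp` follows, for orientation only: the method name `tools/list` comes from the MCP spec, but this is not the repo's `client.py` (which this compare view does not include), so the helper name and header details here are assumptions.

```python
import aiohttp


async def mcp_list_tools(server_url: str, auth_token: str | None = None) -> list[dict]:
    # MCP Streamable HTTP: a JSON-RPC 2.0 request POSTed to the server URL.
    # A real client first performs an "initialize" handshake; omitted here.
    headers = {"Accept": "application/json, text/event-stream"}
    if auth_token:
        headers["Authorization"] = f"Bearer {auth_token}"  # API key sent as Bearer
    payload = {"jsonrpc": "2.0", "id": 1, "method": "tools/list"}
    async with aiohttp.ClientSession(headers=headers) as session:
        async with session.post(server_url, json=payload) as resp:
            resp.raise_for_status()
            body = await resp.json()  # servers may also answer as SSE; not handled here
    # Each tool carries a name, a description, and a JSON Schema "inputSchema",
    # which is what lets the block rebuild its input schema dynamically.
    return body["result"]["tools"]
```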
@@ -11,7 +11,7 @@ class ChatConfig(BaseSettings):

     # OpenAI API Configuration
     model: str = Field(
-        default="anthropic/claude-opus-4.5", description="Default model to use"
+        default="anthropic/claude-opus-4.6", description="Default model to use"
     )
     title_model: str = Field(
         default="openai/gpt-4o-mini",
@@ -33,7 +33,7 @@ from backend.data.understanding import (
     get_business_understanding,
 )
 from backend.util.exceptions import NotFoundError
-from backend.util.settings import Settings
+from backend.util.settings import AppEnvironment, Settings

 from . import db as chat_db
 from . import stream_registry

@@ -222,8 +222,18 @@ async def _get_system_prompt_template(context: str) -> str:
     try:
         # cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
         # Use asyncio.to_thread to avoid blocking the event loop
+        # In non-production environments, fetch the latest prompt version
+        # instead of the production-labeled version for easier testing
+        label = (
+            None
+            if settings.config.app_env == AppEnvironment.PRODUCTION
+            else "latest"
+        )
         prompt = await asyncio.to_thread(
-            langfuse.get_prompt, config.langfuse_prompt_name, cache_ttl_seconds=0
+            langfuse.get_prompt,
+            config.langfuse_prompt_name,
+            label=label,
+            cache_ttl_seconds=0,
         )
         return prompt.compile(users_information=context)
     except Exception as e:

@@ -618,6 +628,9 @@ async def stream_chat_completion(
                     total_tokens=chunk.totalTokens,
                 )
             )
+        elif isinstance(chunk, StreamHeartbeat):
+            # Pass through heartbeat to keep SSE connection alive
+            yield chunk
         else:
             logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
@@ -7,15 +7,7 @@ from typing import Any, NotRequired, TypedDict

 from backend.api.features.library import db as library_db
 from backend.api.features.store import db as store_db
-from backend.data.graph import (
-    Graph,
-    Link,
-    Node,
-    create_graph,
-    get_graph,
-    get_graph_all_versions,
-    get_store_listed_graphs,
-)
+from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
 from backend.util.exceptions import DatabaseError, NotFoundError

 from .service import (

@@ -28,8 +20,6 @@ from .service import (

 logger = logging.getLogger(__name__)

-AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
-

 class ExecutionSummary(TypedDict):
     """Summary of a single execution for quality assessment."""

@@ -669,45 +659,6 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
     )


-def _reassign_node_ids(graph: Graph) -> None:
-    """Reassign all node and link IDs to new UUIDs.
-
-    This is needed when creating a new version to avoid unique constraint violations.
-    """
-    id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
-
-    for node in graph.nodes:
-        node.id = id_map[node.id]
-
-    for link in graph.links:
-        link.id = str(uuid.uuid4())
-        if link.source_id in id_map:
-            link.source_id = id_map[link.source_id]
-        if link.sink_id in id_map:
-            link.sink_id = id_map[link.sink_id]
-
-
-def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None:
-    """Populate user_id in AgentExecutorBlock nodes.
-
-    The external agent generator creates AgentExecutorBlock nodes with empty user_id.
-    This function fills in the actual user_id so sub-agents run with correct permissions.
-
-    Args:
-        agent_json: Agent JSON dict (modified in place)
-        user_id: User ID to set
-    """
-    for node in agent_json.get("nodes", []):
-        if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
-            input_default = node.get("input_default") or {}
-            if not input_default.get("user_id"):
-                input_default["user_id"] = user_id
-                node["input_default"] = input_default
-                logger.debug(
-                    f"Set user_id for AgentExecutorBlock node {node.get('id')}"
-                )
-
-
 async def save_agent_to_library(
     agent_json: dict[str, Any], user_id: str, is_update: bool = False
 ) -> tuple[Graph, Any]:

@@ -721,35 +672,10 @@ async def save_agent_to_library(
     Returns:
         Tuple of (created Graph, LibraryAgent)
     """
-    # Populate user_id in AgentExecutorBlock nodes before conversion
-    _populate_agent_executor_user_ids(agent_json, user_id)
-
     graph = json_to_graph(agent_json)

     if is_update:
-        if graph.id:
-            existing_versions = await get_graph_all_versions(graph.id, user_id)
-            if existing_versions:
-                latest_version = max(v.version for v in existing_versions)
-                graph.version = latest_version + 1
-            _reassign_node_ids(graph)
-            logger.info(f"Updating agent {graph.id} to version {graph.version}")
-    else:
-        graph.id = str(uuid.uuid4())
-        graph.version = 1
-        _reassign_node_ids(graph)
-        logger.info(f"Creating new agent with ID {graph.id}")
-
-    created_graph = await create_graph(graph, user_id)
-
-    library_agents = await library_db.create_library_agent(
-        graph=created_graph,
-        user_id=user_id,
-        sensitive_action_safe_mode=True,
-        create_library_agents_for_sub_graphs=False,
-    )
-
-    return created_graph, library_agents[0]
+        return await library_db.update_graph_in_library(graph, user_id)
+    return await library_db.create_graph_in_library(graph, user_id)


 def graph_to_json(graph: Graph) -> dict[str, Any]:
@@ -206,9 +206,9 @@ async def search_agents(
         ]
     )
     no_results_msg = (
-        f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
+        f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs."
         if source == "marketplace"
-        else f"No agents matching '{query}' found in your library."
+        else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs."
     )
     return NoResultsResponse(
         message=no_results_msg, session_id=session_id, suggestions=suggestions

@@ -224,10 +224,10 @@ async def search_agents(
     message = (
         "Now you have found some options for the user to choose from. "
         "You can add a link to a recommended agent at: /marketplace/agent/agent_id "
-        "Please ask the user if they would like to use any of these agents."
+        "Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs."
         if source == "marketplace"
         else "Found agents in the user's library. You can provide a link to view an agent at: "
-        "/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
+        "/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs."
     )

     return AgentsFoundResponse(
@@ -6,7 +6,6 @@ from typing import Any

 from backend.api.features.library import db as library_db
 from backend.api.features.library import model as library_model
 from backend.api.features.store import db as store_db
-from backend.data import graph as graph_db
 from backend.data.graph import GraphModel
 from backend.data.model import (
     CredentialsFieldInfo,

@@ -44,14 +43,8 @@ async def fetch_graph_from_store_slug(
         return None, None

     # Get the graph from store listing version
-    graph_meta = await store_db.get_available_graph(
-        store_agent.store_listing_version_id
-    )
-    graph = await graph_db.get_graph(
-        graph_id=graph_meta.id,
-        version=graph_meta.version,
-        user_id=None,  # Public access
-        include_subgraphs=True,
+    graph = await store_db.get_available_graph(
+        store_agent.store_listing_version_id, hide_nodes=False
     )
     return graph, store_agent

@@ -128,7 +121,7 @@ def build_missing_credentials_from_graph(

     return {
         field_key: _serialize_missing_credential(field_key, field_info)
-        for field_key, (field_info, _node_fields) in aggregated_fields.items()
+        for field_key, (field_info, _, _) in aggregated_fields.items()
         if field_key not in matched_keys
     }

@@ -269,7 +262,8 @@ async def match_user_credentials_to_graph(
     # provider is in the set of acceptable providers.
     for credential_field_name, (
         credential_requirements,
-        _node_fields,
+        _,
+        _,
     ) in aggregated_creds.items():
         # Find first matching credential by provider, type, and scopes
         matching_cred = next(
@@ -19,7 +19,10 @@ from backend.data.graph import GraphSettings
 from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
 from backend.data.model import CredentialsMetaInput
 from backend.integrations.creds_manager import IntegrationCredentialsManager
-from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
+from backend.integrations.webhooks.graph_lifecycle_hooks import (
+    on_graph_activate,
+    on_graph_deactivate,
+)
 from backend.util.clients import get_scheduler_client
 from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
 from backend.util.json import SafeJson

@@ -371,7 +374,7 @@ async def get_library_agent_by_graph_id(


 async def add_generated_agent_image(
-    graph: graph_db.BaseGraph,
+    graph: graph_db.GraphBaseMeta,
     user_id: str,
     library_agent_id: str,
 ) -> Optional[prisma.models.LibraryAgent]:

@@ -537,6 +540,92 @@ async def update_agent_version_in_library(
     return library_model.LibraryAgent.from_db(lib)


+async def create_graph_in_library(
+    graph: graph_db.Graph,
+    user_id: str,
+) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
+    """Create a new graph and add it to the user's library."""
+    graph.version = 1
+    graph_model = graph_db.make_graph_model(graph, user_id)
+    graph_model.reassign_ids(user_id=user_id, reassign_graph_id=True)
+
+    created_graph = await graph_db.create_graph(graph_model, user_id)
+
+    library_agents = await create_library_agent(
+        graph=created_graph,
+        user_id=user_id,
+        sensitive_action_safe_mode=True,
+        create_library_agents_for_sub_graphs=False,
+    )
+
+    if created_graph.is_active:
+        created_graph = await on_graph_activate(created_graph, user_id=user_id)
+
+    return created_graph, library_agents[0]
+
+
+async def update_graph_in_library(
+    graph: graph_db.Graph,
+    user_id: str,
+) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
+    """Create a new version of an existing graph and update the library entry."""
+    existing_versions = await graph_db.get_graph_all_versions(graph.id, user_id)
+    current_active_version = (
+        next((v for v in existing_versions if v.is_active), None)
+        if existing_versions
+        else None
+    )
+    graph.version = (
+        max(v.version for v in existing_versions) + 1 if existing_versions else 1
+    )
+
+    graph_model = graph_db.make_graph_model(graph, user_id)
+    graph_model.reassign_ids(user_id=user_id, reassign_graph_id=False)
+
+    created_graph = await graph_db.create_graph(graph_model, user_id)
+
+    library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
+    if not library_agent:
+        raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
+
+    library_agent = await update_library_agent_version_and_settings(
+        user_id, created_graph
+    )
+
+    if created_graph.is_active:
+        created_graph = await on_graph_activate(created_graph, user_id=user_id)
+        await graph_db.set_graph_active_version(
+            graph_id=created_graph.id,
+            version=created_graph.version,
+            user_id=user_id,
+        )
+        if current_active_version:
+            await on_graph_deactivate(current_active_version, user_id=user_id)
+
+    return created_graph, library_agent
+
+
+async def update_library_agent_version_and_settings(
+    user_id: str, agent_graph: graph_db.GraphModel
+) -> library_model.LibraryAgent:
+    """Update library agent to point to new graph version and sync settings."""
+    library = await update_agent_version_in_library(
+        user_id, agent_graph.id, agent_graph.version
+    )
+    updated_settings = GraphSettings.from_graph(
+        graph=agent_graph,
+        hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
+        sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
+    )
+    if updated_settings != library.settings:
+        library = await update_library_agent(
+            library_agent_id=library.id,
+            user_id=user_id,
+            settings=updated_settings,
+        )
+    return library
+
+
 async def update_library_agent(
     library_agent_id: str,
     user_id: str,
autogpt_platform/backend/backend/api/features/mcp/routes.py (new file, 365 lines)

@@ -0,0 +1,365 @@
+"""
+MCP (Model Context Protocol) API routes.
+
+Provides endpoints for MCP tool discovery and OAuth authentication so the
+frontend can list available tools on an MCP server before placing a block.
+"""
+
+import logging
+from typing import Annotated, Any
+from urllib.parse import urlparse
+
+import fastapi
+from autogpt_libs.auth import get_user_id
+from fastapi import Security
+from pydantic import BaseModel, Field
+
+from backend.blocks.mcp.client import MCPClient, MCPClientError
+from backend.blocks.mcp.oauth import MCPOAuthHandler
+from backend.data.model import OAuth2Credentials
+from backend.integrations.creds_manager import IntegrationCredentialsManager
+from backend.integrations.providers import ProviderName
+from backend.util.request import HTTPClientError, Requests
+from backend.util.settings import Settings
+
+logger = logging.getLogger(__name__)
+
+settings = Settings()
+router = fastapi.APIRouter(tags=["mcp"])
+creds_manager = IntegrationCredentialsManager()
+
+
+# ====================== Tool Discovery ====================== #
+
+
+class DiscoverToolsRequest(BaseModel):
+    """Request to discover tools on an MCP server."""
+
+    server_url: str = Field(description="URL of the MCP server")
+    auth_token: str | None = Field(
+        default=None,
+        description="Optional Bearer token for authenticated MCP servers",
+    )
+
+
+class MCPToolResponse(BaseModel):
+    """A single MCP tool returned by discovery."""
+
+    name: str
+    description: str
+    input_schema: dict[str, Any]
+
+
+class DiscoverToolsResponse(BaseModel):
+    """Response containing the list of tools available on an MCP server."""
+
+    tools: list[MCPToolResponse]
+    server_name: str | None = None
+    protocol_version: str | None = None
+
+
+@router.post(
+    "/discover-tools",
+    summary="Discover available tools on an MCP server",
+    response_model=DiscoverToolsResponse,
+)
+async def discover_tools(
+    request: DiscoverToolsRequest,
+    user_id: Annotated[str, Security(get_user_id)],
+) -> DiscoverToolsResponse:
+    """
+    Connect to an MCP server and return its available tools.
+
+    If the user has a stored MCP credential for this server URL, it will be
+    used automatically — no need to pass an explicit auth token.
+    """
+    auth_token = request.auth_token
+
+    # Auto-use stored MCP credential when no explicit token is provided
+    if not auth_token:
+        try:
+            mcp_creds = await creds_manager.store.get_creds_by_provider(
+                user_id, str(ProviderName.MCP)
+            )
+            for cred in mcp_creds:
+                if (
+                    isinstance(cred, OAuth2Credentials)
+                    and cred.metadata.get("mcp_server_url") == request.server_url
+                ):
+                    auth_token = cred.access_token.get_secret_value()
+                    break
+        except Exception:
+            logger.debug("Could not look up stored MCP credentials", exc_info=True)
+
+    try:
+        client = MCPClient(
+            request.server_url,
+            auth_token=auth_token,
+            trusted_origins=[request.server_url],
+        )
+
+        init_result = await client.initialize()
+        tools = await client.list_tools()
+
+        return DiscoverToolsResponse(
+            tools=[
+                MCPToolResponse(
+                    name=t.name,
+                    description=t.description,
+                    input_schema=t.input_schema,
+                )
+                for t in tools
+            ],
+            server_name=init_result.get("serverInfo", {}).get("name"),
+            protocol_version=init_result.get("protocolVersion"),
+        )
+    except HTTPClientError as e:
+        if e.status_code in (401, 403):
+            raise fastapi.HTTPException(
+                status_code=401,
+                detail="This MCP server requires authentication. "
+                "Please provide a valid auth token.",
+            )
+        raise fastapi.HTTPException(status_code=502, detail=str(e))
+    except MCPClientError as e:
+        raise fastapi.HTTPException(status_code=502, detail=str(e))
+    except Exception as e:
+        logger.exception("MCP tool discovery failed")
+        raise fastapi.HTTPException(
+            status_code=502,
+            detail=f"Failed to connect to MCP server: {str(e)}",
+        )
+
+
+# ======================== OAuth Flow ======================== #
+
+
+class MCPOAuthLoginRequest(BaseModel):
+    """Request to start an OAuth flow for an MCP server."""
+
+    server_url: str = Field(description="URL of the MCP server that requires OAuth")
+
+
+class MCPOAuthLoginResponse(BaseModel):
+    """Response with the OAuth login URL for the user to authenticate."""
+
+    login_url: str
+    state_token: str
+
+
+@router.post(
+    "/oauth/login",
+    summary="Initiate OAuth login for an MCP server",
+)
+async def mcp_oauth_login(
+    request: MCPOAuthLoginRequest,
+    user_id: Annotated[str, Security(get_user_id)],
+) -> MCPOAuthLoginResponse:
+    """
+    Discover OAuth metadata from the MCP server and return a login URL.
+
+    1. Discovers the protected-resource metadata (RFC 9728)
+    2. Fetches the authorization server metadata (RFC 8414)
+    3. Performs Dynamic Client Registration (RFC 7591) if available
+    4. Returns the authorization URL for the frontend to open in a popup
+    """
+    client = MCPClient(request.server_url, trusted_origins=[request.server_url])
+
+    # Step 1: Discover protected-resource metadata (RFC 9728)
+    try:
+        protected_resource = await client.discover_auth()
+    except Exception as e:
+        raise fastapi.HTTPException(
+            status_code=502,
+            detail=f"Failed to discover OAuth metadata: {e}",
+        )
+
+    if not protected_resource or "authorization_servers" not in protected_resource:
+        raise fastapi.HTTPException(
+            status_code=400,
+            detail="This MCP server does not advertise OAuth support. "
+            "You may need to provide an auth token manually.",
+        )
+
+    auth_server_url = protected_resource["authorization_servers"][0]
+    resource_url = protected_resource.get("resource", request.server_url)
+
+    # Step 2: Discover auth-server metadata (RFC 8414)
+    try:
+        metadata = await client.discover_auth_server_metadata(auth_server_url)
+    except Exception as e:
+        raise fastapi.HTTPException(
+            status_code=502,
+            detail=f"Failed to discover authorization server metadata: {e}",
+        )
+
+    if not metadata or "authorization_endpoint" not in metadata:
+        raise fastapi.HTTPException(
+            status_code=502,
+            detail="Authorization server metadata is missing required endpoints.",
+        )
+
+    authorize_url = metadata["authorization_endpoint"]
+    token_url = metadata["token_endpoint"]
+    registration_endpoint = metadata.get("registration_endpoint")
+    revoke_url = metadata.get("revocation_endpoint")
+
+    # Step 3: Dynamic Client Registration (RFC 7591) if available
+    frontend_base_url = settings.config.frontend_base_url
+    if not frontend_base_url:
+        raise fastapi.HTTPException(
+            status_code=500,
+            detail="Frontend base URL is not configured.",
+        )
+    redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
+
+    client_id = ""
+    client_secret = ""
+    if registration_endpoint:
+        reg_result = await _register_mcp_client(
+            registration_endpoint, redirect_uri, request.server_url
+        )
+        if reg_result:
+            client_id = reg_result.get("client_id", "")
+            client_secret = reg_result.get("client_secret", "")
+
+    if not client_id:
+        client_id = "autogpt-platform"
+
+    # Step 4: Store state token with OAuth metadata for the callback
+    scopes = protected_resource.get("scopes_supported", [])
+    state_token, code_challenge = await creds_manager.store.store_state_token(
+        user_id,
+        str(ProviderName.MCP),
+        scopes,
+        state_metadata={
+            "authorize_url": authorize_url,
+            "token_url": token_url,
+            "revoke_url": revoke_url,
+            "resource_url": resource_url,
+            "server_url": request.server_url,
+            "client_id": client_id,
+            "client_secret": client_secret,
+        },
+    )
+
+    # Step 5: Build and return the login URL
+    handler = MCPOAuthHandler(
+        client_id=client_id,
+        client_secret=client_secret,
+        redirect_uri=redirect_uri,
+        authorize_url=authorize_url,
+        token_url=token_url,
+        resource_url=resource_url,
+    )
+    login_url = handler.get_login_url(
+        scopes, state_token, code_challenge=code_challenge
+    )
+
+    return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token)
+
+
+class MCPOAuthCallbackRequest(BaseModel):
+    """Request to exchange an OAuth code for tokens."""
+
+    code: str = Field(description="Authorization code from OAuth callback")
+    state_token: str = Field(description="State token for CSRF verification")
+
+
+class MCPOAuthCallbackResponse(BaseModel):
+    """Response after successfully storing OAuth credentials."""
+
+    credential_id: str
+
+
+@router.post(
+    "/oauth/callback",
+    summary="Exchange OAuth code for MCP tokens",
+)
+async def mcp_oauth_callback(
+    request: MCPOAuthCallbackRequest,
+    user_id: Annotated[str, Security(get_user_id)],
+) -> MCPOAuthCallbackResponse:
+    """
+    Exchange the authorization code for tokens and store the credential.
+
+    The frontend calls this after receiving the OAuth code from the popup.
+    On success, subsequent ``/discover-tools`` calls for the same server URL
+    will automatically use the stored credential.
+    """
+    valid_state = await creds_manager.store.verify_state_token(
+        user_id, request.state_token, str(ProviderName.MCP)
+    )
+    if not valid_state:
+        raise fastapi.HTTPException(
+            status_code=400,
+            detail="Invalid or expired state token.",
+        )
+
+    meta = valid_state.state_metadata
+    frontend_base_url = settings.config.frontend_base_url
+    redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
+
+    handler = MCPOAuthHandler(
+        client_id=meta["client_id"],
+        client_secret=meta.get("client_secret", ""),
+        redirect_uri=redirect_uri,
+        authorize_url=meta["authorize_url"],
+        token_url=meta["token_url"],
+        revoke_url=meta.get("revoke_url"),
+        resource_url=meta.get("resource_url"),
+    )
+
+    try:
+        credentials = await handler.exchange_code_for_tokens(
+            request.code, valid_state.scopes, valid_state.code_verifier
+        )
+    except Exception as e:
+        logger.exception("MCP OAuth token exchange failed")
+        raise fastapi.HTTPException(
+            status_code=400,
+            detail=f"OAuth token exchange failed: {e}",
+        )
+
+    # Enrich credential metadata for future lookup and token refresh
+    if credentials.metadata is None:
+        credentials.metadata = {}
+    credentials.metadata["mcp_server_url"] = meta["server_url"]
+    credentials.metadata["mcp_client_id"] = meta["client_id"]
+    credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "")
+
+    hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
+    credentials.title = f"MCP: {hostname}"
+
+    await creds_manager.create(user_id, credentials)
+
+    return MCPOAuthCallbackResponse(credential_id=credentials.id)
+
+
+# ======================== Helpers ======================== #
+
+
+async def _register_mcp_client(
+    registration_endpoint: str,
+    redirect_uri: str,
+    server_url: str,
+) -> dict[str, Any] | None:
+    """Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server."""
+    try:
+        response = await Requests(raise_for_status=True).post(
+            registration_endpoint,
+            json={
+                "client_name": "AutoGPT Platform",
+                "redirect_uris": [redirect_uri],
+                "grant_types": ["authorization_code"],
+                "response_types": ["code"],
+                "token_endpoint_auth_method": "client_secret_post",
+            },
+        )
+        data = response.json()
+        if isinstance(data, dict) and "client_id" in data:
+            return data
+        return None
+    except Exception as e:
+        logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
+        return None
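For orientation, here is a compressed sketch of the two discovery fetches that `mcp_oauth_login` above performs before building the login URL. The well-known paths come from RFC 9728 and RFC 8414; the helper below uses `aiohttp` directly and is illustrative, not the repo's `MCPClient.discover_auth*` implementation (which this compare view does not show). Note one simplifying assumption: RFC 8414 technically inserts the well-known segment between host and path, so the plain concatenation below only matches authorization-server URLs without a path component.

```python
from urllib.parse import urlparse

import aiohttp


async def discover_oauth_endpoints(server_url: str) -> dict:
    origin = "{0.scheme}://{0.netloc}".format(urlparse(server_url))
    async with aiohttp.ClientSession() as session:
        # Step 1 (RFC 9728): ask the resource which authorization server protects it.
        async with session.get(
            f"{origin}/.well-known/oauth-protected-resource"
        ) as resp:
            resp.raise_for_status()
            protected = await resp.json()
        auth_server = protected["authorization_servers"][0]
        # Step 2 (RFC 8414): fetch that server's metadata, which carries
        # authorization_endpoint, token_endpoint, and (optionally) a
        # registration_endpoint for Dynamic Client Registration (RFC 7591).
        async with session.get(
            f"{auth_server}/.well-known/oauth-authorization-server"
        ) as resp:
            resp.raise_for_status()
            return await resp.json()
```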
autogpt_platform/backend/backend/api/features/mcp/test_routes.py (new file, 385 lines)

@@ -0,0 +1,385 @@
+"""Tests for MCP API routes."""
+
+from unittest.mock import AsyncMock, patch
+
+import fastapi
+import fastapi.testclient
+from autogpt_libs.auth import get_user_id
+
+from backend.api.features.mcp.routes import router
+from backend.blocks.mcp.client import MCPClientError, MCPTool
+from backend.util.request import HTTPClientError
+
+app = fastapi.FastAPI()
+app.include_router(router)
+app.dependency_overrides[get_user_id] = lambda: "test-user-id"
+client = fastapi.testclient.TestClient(app)
+
+
+class TestDiscoverTools:
+    def test_discover_tools_success(self):
+        mock_tools = [
+            MCPTool(
+                name="get_weather",
+                description="Get weather for a city",
+                input_schema={
+                    "type": "object",
+                    "properties": {"city": {"type": "string"}},
+                    "required": ["city"],
+                },
+            ),
+            MCPTool(
+                name="add_numbers",
+                description="Add two numbers",
+                input_schema={
+                    "type": "object",
+                    "properties": {
+                        "a": {"type": "number"},
+                        "b": {"type": "number"},
+                    },
+                },
+            ),
+        ]
+
+        with (patch("backend.api.features.mcp.routes.MCPClient") as MockClient,):
+            instance = MockClient.return_value
+            instance.initialize = AsyncMock(
+                return_value={
+                    "protocolVersion": "2025-03-26",
+                    "serverInfo": {"name": "test-server"},
+                }
+            )
+            instance.list_tools = AsyncMock(return_value=mock_tools)
+
+            response = client.post(
+                "/discover-tools",
+                json={"server_url": "https://mcp.example.com/mcp"},
+            )
+
+            assert response.status_code == 200
+            data = response.json()
+            assert len(data["tools"]) == 2
+            assert data["tools"][0]["name"] == "get_weather"
+            assert data["tools"][1]["name"] == "add_numbers"
+            assert data["server_name"] == "test-server"
+            assert data["protocol_version"] == "2025-03-26"
+
+    def test_discover_tools_with_auth_token(self):
+        with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
+            instance = MockClient.return_value
+            instance.initialize = AsyncMock(
+                return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
+            )
+            instance.list_tools = AsyncMock(return_value=[])
+
+            response = client.post(
+                "/discover-tools",
+                json={
+                    "server_url": "https://mcp.example.com/mcp",
+                    "auth_token": "my-secret-token",
+                },
+            )
+
+            assert response.status_code == 200
+            MockClient.assert_called_once_with(
+                "https://mcp.example.com/mcp",
+                auth_token="my-secret-token",
+                trusted_origins=["https://mcp.example.com/mcp"],
+            )
+
+    def test_discover_tools_auto_uses_stored_credential(self):
+        """When no explicit token is given, stored MCP credentials are used."""
+        from pydantic import SecretStr
+
+        from backend.data.model import OAuth2Credentials
+
+        stored_cred = OAuth2Credentials(
+            provider="mcp",
+            title="MCP: example.com",
+            access_token=SecretStr("stored-token-123"),
+            refresh_token=None,
+            access_token_expires_at=None,
+            refresh_token_expires_at=None,
+            scopes=[],
+            metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
+        )
+
+        with (
+            patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
+            patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
+        ):
+            mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
+            instance = MockClient.return_value
+            instance.initialize = AsyncMock(
+                return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
+            )
+            instance.list_tools = AsyncMock(return_value=[])
+
+            response = client.post(
+                "/discover-tools",
+                json={"server_url": "https://mcp.example.com/mcp"},
+            )
+
+            assert response.status_code == 200
+            MockClient.assert_called_once_with(
+                "https://mcp.example.com/mcp",
+                auth_token="stored-token-123",
+                trusted_origins=["https://mcp.example.com/mcp"],
+            )
+
+    def test_discover_tools_mcp_error(self):
+        with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
+            instance = MockClient.return_value
+            instance.initialize = AsyncMock(
+                side_effect=MCPClientError("Connection refused")
+            )
+
+            response = client.post(
+                "/discover-tools",
+                json={"server_url": "https://bad-server.example.com/mcp"},
+            )
+
+            assert response.status_code == 502
+            assert "Connection refused" in response.json()["detail"]
+
+    def test_discover_tools_generic_error(self):
+        with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
+            instance = MockClient.return_value
+            instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
+
+            response = client.post(
+                "/discover-tools",
+                json={"server_url": "https://timeout.example.com/mcp"},
+            )
+
+            assert response.status_code == 502
+            assert "Failed to connect" in response.json()["detail"]
+
+    def test_discover_tools_auth_required(self):
+        with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
+            instance = MockClient.return_value
+            instance.initialize = AsyncMock(
+                side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
+            )
+
+            response = client.post(
+                "/discover-tools",
+                json={"server_url": "https://auth-server.example.com/mcp"},
+            )
+
+            assert response.status_code == 401
+            assert "requires authentication" in response.json()["detail"]
+
+    def test_discover_tools_forbidden(self):
+        with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
+            instance = MockClient.return_value
+            instance.initialize = AsyncMock(
+                side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
+            )
+
+            response = client.post(
+                "/discover-tools",
+                json={"server_url": "https://auth-server.example.com/mcp"},
+            )
+
+            assert response.status_code == 401
+            assert "requires authentication" in response.json()["detail"]
+
+    def test_discover_tools_missing_url(self):
+        response = client.post("/discover-tools", json={})
+        assert response.status_code == 422
+
+
+class TestOAuthLogin:
+    def test_oauth_login_success(self):
+        with (
+            patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
+            patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
+            patch("backend.api.features.mcp.routes.settings") as mock_settings,
+            patch(
+                "backend.api.features.mcp.routes._register_mcp_client"
+            ) as mock_register,
+        ):
+            instance = MockClient.return_value
+            instance.discover_auth = AsyncMock(
+                return_value={
+                    "authorization_servers": ["https://auth.sentry.io"],
+                    "resource": "https://mcp.sentry.dev/mcp",
+                    "scopes_supported": ["openid"],
+                }
+            )
+            instance.discover_auth_server_metadata = AsyncMock(
+                return_value={
+                    "authorization_endpoint": "https://auth.sentry.io/authorize",
+                    "token_endpoint": "https://auth.sentry.io/token",
+                    "registration_endpoint": "https://auth.sentry.io/register",
+                }
+            )
+            mock_register.return_value = {
+                "client_id": "registered-client-id",
+                "client_secret": "registered-secret",
+            }
+            mock_cm.store.store_state_token = AsyncMock(
+                return_value=("state-token-123", "code-challenge-abc")
+            )
+            mock_settings.config.frontend_base_url = "http://localhost:3000"
+
+            response = client.post(
+                "/oauth/login",
+                json={"server_url": "https://mcp.sentry.dev/mcp"},
+            )
+
+            assert response.status_code == 200
+            data = response.json()
+            assert "login_url" in data
+            assert data["state_token"] == "state-token-123"
+            assert "auth.sentry.io/authorize" in data["login_url"]
+            assert "registered-client-id" in data["login_url"]
+
+    def test_oauth_login_no_oauth_support(self):
+        with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
+            instance = MockClient.return_value
+            instance.discover_auth = AsyncMock(return_value=None)
+
+            response = client.post(
+                "/oauth/login",
+                json={"server_url": "https://simple-server.example.com/mcp"},
+            )
+
+            assert response.status_code == 400
+            assert "does not advertise OAuth" in response.json()["detail"]
+
+    def test_oauth_login_fallback_to_public_client(self):
+        """When DCR is unavailable, falls back to default public client ID."""
+        with (
+            patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
+            patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
+            patch("backend.api.features.mcp.routes.settings") as mock_settings,
+        ):
+            instance = MockClient.return_value
+            instance.discover_auth = AsyncMock(
+                return_value={
+                    "authorization_servers": ["https://auth.example.com"],
+                    "resource": "https://mcp.example.com/mcp",
+                }
+            )
+            instance.discover_auth_server_metadata = AsyncMock(
+                return_value={
+                    "authorization_endpoint": "https://auth.example.com/authorize",
+                    "token_endpoint": "https://auth.example.com/token",
+                    # No registration_endpoint
+                }
+            )
+            mock_cm.store.store_state_token = AsyncMock(
+                return_value=("state-abc", "challenge-xyz")
+            )
+            mock_settings.config.frontend_base_url = "http://localhost:3000"
+
+            response = client.post(
+                "/oauth/login",
+                json={"server_url": "https://mcp.example.com/mcp"},
+            )
+
+            assert response.status_code == 200
+            data = response.json()
+            assert "autogpt-platform" in data["login_url"]
+
+
+class TestOAuthCallback:
+    def test_oauth_callback_success(self):
+        from pydantic import SecretStr
+
+        from backend.data.model import OAuth2Credentials
+
+        mock_creds = OAuth2Credentials(
+            provider="mcp",
+            title=None,
+            access_token=SecretStr("access-token-xyz"),
+            refresh_token=None,
+            access_token_expires_at=None,
+            refresh_token_expires_at=None,
+            scopes=[],
+            metadata={
+                "mcp_token_url": "https://auth.sentry.io/token",
+                "mcp_resource_url": "https://mcp.sentry.dev/mcp",
+            },
+        )
+
+        with (
+            patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
+            patch("backend.api.features.mcp.routes.settings") as mock_settings,
+            patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
+        ):
+            mock_settings.config.frontend_base_url = "http://localhost:3000"
+
+            # Mock state verification
+            mock_state = AsyncMock()
+            mock_state.state_metadata = {
+                "authorize_url": "https://auth.sentry.io/authorize",
+                "token_url": "https://auth.sentry.io/token",
+                "client_id": "test-client-id",
+                "client_secret": "test-secret",
+                "server_url": "https://mcp.sentry.dev/mcp",
+            }
+            mock_state.scopes = ["openid"]
+            mock_state.code_verifier = "verifier-123"
+            mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
+            mock_cm.create = AsyncMock()
+
+            handler_instance = MockHandler.return_value
+            handler_instance.exchange_code_for_tokens = AsyncMock(
+                return_value=mock_creds
+            )
+
+            response = client.post(
+                "/oauth/callback",
+                json={"code": "auth-code-abc", "state_token": "state-token-123"},
+            )
+
+            assert response.status_code == 200
+            data = response.json()
+            assert "credential_id" in data
+            mock_cm.create.assert_called_once()
+
+    def test_oauth_callback_invalid_state(self):
+        with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
+            mock_cm.store.verify_state_token = AsyncMock(return_value=None)
+
+            response = client.post(
+                "/oauth/callback",
+                json={"code": "auth-code", "state_token": "bad-state"},
+            )
+
+            assert response.status_code == 400
+            assert "Invalid or expired" in response.json()["detail"]
+
+    def test_oauth_callback_token_exchange_fails(self):
+        with (
+            patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
+            patch("backend.api.features.mcp.routes.settings") as mock_settings,
+            patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
+        ):
+            mock_settings.config.frontend_base_url = "http://localhost:3000"
+            mock_state = AsyncMock()
+            mock_state.state_metadata = {
+                "authorize_url": "https://auth.example.com/authorize",
+                "token_url": "https://auth.example.com/token",
+                "client_id": "cid",
+                "server_url": "https://mcp.example.com/mcp",
+            }
+            mock_state.scopes = []
+            mock_state.code_verifier = "v"
+            mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
+
+            handler_instance = MockHandler.return_value
+            handler_instance.exchange_code_for_tokens = AsyncMock(
+                side_effect=RuntimeError("Token exchange failed")
+            )
+
+            response = client.post(
+                "/oauth/callback",
+                json={"code": "bad-code", "state_token": "state"},
+            )
+
+            assert response.status_code == 400
+            assert "token exchange failed" in response.json()["detail"].lower()
@@ -1,7 +1,7 @@
import asyncio
import logging
from datetime import datetime, timezone
from typing import Any, Literal
from typing import Any, Literal, overload

import fastapi
import prisma.enums
@@ -11,8 +11,8 @@ import prisma.types

from backend.data.db import transaction
from backend.data.graph import (
    GraphMeta,
    GraphModel,
    GraphModelWithoutNodes,
    get_graph,
    get_graph_as_admin,
    get_sub_graphs,
@@ -334,7 +334,22 @@ async def get_store_agent_details(
        raise DatabaseError("Failed to fetch agent details") from e


async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
@overload
async def get_available_graph(
    store_listing_version_id: str, hide_nodes: Literal[False]
) -> GraphModel: ...


@overload
async def get_available_graph(
    store_listing_version_id: str, hide_nodes: Literal[True] = True
) -> GraphModelWithoutNodes: ...


async def get_available_graph(
    store_listing_version_id: str,
    hide_nodes: bool = True,
) -> GraphModelWithoutNodes | GraphModel:
    try:
        # Get available, non-deleted store listing version
        store_listing_version = (
@@ -344,7 +359,7 @@ async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
                    "isAvailable": True,
                    "isDeleted": False,
                },
                include={"AgentGraph": {"include": {"Nodes": True}}},
                include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
            )
        )
@@ -354,7 +369,7 @@ async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
                detail=f"Store listing version {store_listing_version_id} not found",
            )

        return GraphModel.from_db(store_listing_version.AgentGraph).meta()
        return (GraphModelWithoutNodes if hide_nodes else GraphModel).from_db(
            store_listing_version.AgentGraph
        )

    except Exception as e:
        logger.error(f"Error getting agent: {e}")
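Side note on the overload pattern above: a minimal, self-contained sketch (with stand-in classes, not the platform's real models) of how a `Literal`-typed `hide_nodes` parameter lets a type checker pick the precise return type at each call site:

import asyncio
from typing import Literal, overload


class Meta:  # stand-in for GraphModelWithoutNodes
    pass


class Full(Meta):  # stand-in for GraphModel
    pass


@overload
async def load(version_id: str, hide_nodes: Literal[False]) -> Full: ...
@overload
async def load(version_id: str, hide_nodes: Literal[True] = True) -> Meta: ...


async def load(version_id: str, hide_nodes: bool = True) -> Meta | Full:
    # Single runtime implementation; the overloads only refine the types.
    return Meta() if hide_nodes else Full()


async def demo() -> None:
    meta = await load("v1")  # type checkers infer Meta here
    full = await load("v1", hide_nodes=False)  # ...and Full here
    assert isinstance(meta, Meta) and isinstance(full, Full)


asyncio.run(demo())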
@@ -16,7 +16,7 @@ from backend.blocks.ideogram import (
    StyleType,
    UpscaleOption,
)
from backend.data.graph import BaseGraph
from backend.data.graph import GraphBaseMeta
from backend.data.model import CredentialsMetaInput, ProviderName
from backend.integrations.credentials_store import ideogram_credentials
from backend.util.request import Requests
@@ -34,14 +34,14 @@ class ImageStyle(str, Enum):
    DIGITAL_ART = "digital art"


async def generate_agent_image(agent: BaseGraph | AgentGraph) -> io.BytesIO:
async def generate_agent_image(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
    if settings.config.use_agent_image_generation_v2:
        return await generate_agent_image_v2(graph=agent)
    else:
        return await generate_agent_image_v1(agent=agent)


async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.BytesIO:
    """
    Generate an image for an agent using Ideogram model.
    Returns:
@@ -54,14 +54,17 @@ async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
    description = f"{name} ({graph.description})" if graph.description else name

    prompt = (
        f"Create a visually striking retro-futuristic vector pop art illustration prominently featuring "
        f'"{name}" in bold typography. The image clearly and literally depicts a {description}, '
        f"along with recognizable objects directly associated with the primary function of a {name}. "
        f"Ensure the imagery is concrete, intuitive, and immediately understandable, clearly conveying the "
        f"purpose of a {name}. Maintain vibrant, limited-palette colors, sharp vector lines, geometric "
        f"shapes, flat illustration techniques, and solid colors without gradients or shading. Preserve a "
        f"retro-futuristic aesthetic influenced by mid-century futurism and 1960s psychedelia, "
        f"prioritizing clear visual storytelling and thematic clarity above all else."
        "Create a visually striking retro-futuristic vector pop art illustration "
        f'prominently featuring "{name}" in bold typography. The image clearly and '
        f"literally depicts a {description}, along with recognizable objects directly "
        f"associated with the primary function of a {name}. "
        f"Ensure the imagery is concrete, intuitive, and immediately understandable, "
        f"clearly conveying the purpose of a {name}. "
        "Maintain vibrant, limited-palette colors, sharp vector lines, "
        "geometric shapes, flat illustration techniques, and solid colors "
        "without gradients or shading. Preserve a retro-futuristic aesthetic "
        "influenced by mid-century futurism and 1960s psychedelia, "
        "prioritizing clear visual storytelling and thematic clarity above all else."
    )

    custom_colors = [
@@ -99,12 +102,12 @@ async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
        return io.BytesIO(response.content)


async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO:
async def generate_agent_image_v1(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
    """
    Generate an image for an agent using Flux model via Replicate API.

    Args:
        agent (Graph): The agent to generate an image for
        agent (GraphBaseMeta | AgentGraph): The agent to generate an image for

    Returns:
        io.BytesIO: The generated image as bytes
@@ -114,7 +117,13 @@ async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO:
        raise ValueError("Missing Replicate API key in settings")

    # Construct prompt from agent details
    prompt = f"Create a visually engaging app store thumbnail for the AI agent that highlights what it does in a clear and captivating way:\n- **Name**: {agent.name}\n- **Description**: {agent.description}\nFocus on showcasing its core functionality with an appealing design."
    prompt = (
        "Create a visually engaging app store thumbnail for the AI agent "
        "that highlights what it does in a clear and captivating way:\n"
        f"- **Name**: {agent.name}\n"
        f"- **Description**: {agent.description}\n"
        f"Focus on showcasing its core functionality with an appealing design."
    )

    # Set up Replicate client
    client = ReplicateClient(api_token=settings.secrets.replicate_api_key)
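The prompt rewrites in this file only re-wrap long f-strings; adjacent string literals are concatenated at compile time, so the built prompts are unchanged. A quick self-contained check (sample values are illustrative):

name, description = "Demo Agent", "an agent that summarizes RSS feeds"

single_line = f"Create a thumbnail for {name}: {description}."
wrapped = (
    "Create a thumbnail for "
    f"{name}: {description}."
)
assert single_line == wrapped  # identical strings, just friendlier source layout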
@@ -278,7 +278,7 @@ async def get_agent(
)
async def get_graph_meta_by_store_listing_version_id(
    store_listing_version_id: str,
) -> backend.data.graph.GraphMeta:
) -> backend.data.graph.GraphModelWithoutNodes:
    """
    Get Agent Graph from Store Listing Version ID.
    """
@@ -101,7 +101,6 @@ from backend.util.timezone_utils import (
from backend.util.virus_scanner import scan_content_safe

from .library import db as library_db
from .library import model as library_model
from .store.model import StoreAgentDetails


@@ -823,18 +822,16 @@ async def update_graph(
    graph: graph_db.Graph,
    user_id: Annotated[str, Security(get_user_id)],
) -> graph_db.GraphModel:
    # Sanity check
    if graph.id and graph.id != graph_id:
        raise HTTPException(400, detail="Graph ID does not match ID in URI")

    # Determine new version
    existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
    if not existing_versions:
        raise HTTPException(404, detail=f"Graph #{graph_id} not found")
    latest_version_number = max(g.version for g in existing_versions)
    graph.version = latest_version_number + 1

    graph.version = max(g.version for g in existing_versions) + 1
    current_active_version = next((v for v in existing_versions if v.is_active), None)

    graph = graph_db.make_graph_model(graph, user_id)
    graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
    graph.validate_graph(for_run=False)
@@ -842,27 +839,23 @@ async def update_graph(
    new_graph_version = await graph_db.create_graph(graph, user_id=user_id)

    if new_graph_version.is_active:
        # Keep the library agent up to date with the new active version
        await _update_library_agent_version_and_settings(user_id, new_graph_version)
        await library_db.update_library_agent_version_and_settings(
            user_id, new_graph_version
        )
        # Handle activation of the new graph first to ensure continuity
        new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
        # Ensure new version is the only active version
        await graph_db.set_graph_active_version(
            graph_id=graph_id, version=new_graph_version.version, user_id=user_id
        )
        if current_active_version:
            # Handle deactivation of the previously active version
            await on_graph_deactivate(current_active_version, user_id=user_id)

    # Fetch new graph version *with sub-graphs* (needed for credentials input schema)
    new_graph_version_with_subgraphs = await graph_db.get_graph(
        graph_id,
        new_graph_version.version,
        user_id=user_id,
        include_subgraphs=True,
    )
    assert new_graph_version_with_subgraphs  # make type checker happy
    assert new_graph_version_with_subgraphs
    return new_graph_version_with_subgraphs


@@ -900,33 +893,15 @@ async def set_graph_active_version(
    )

    # Keep the library agent up to date with the new active version
    await _update_library_agent_version_and_settings(user_id, new_active_graph)
    await library_db.update_library_agent_version_and_settings(
        user_id, new_active_graph
    )

    if current_active_graph and current_active_graph.version != new_active_version:
        # Handle deactivation of the previously active version
        await on_graph_deactivate(current_active_graph, user_id=user_id)


async def _update_library_agent_version_and_settings(
    user_id: str, agent_graph: graph_db.GraphModel
) -> library_model.LibraryAgent:
    library = await library_db.update_agent_version_in_library(
        user_id, agent_graph.id, agent_graph.version
    )
    updated_settings = GraphSettings.from_graph(
        graph=agent_graph,
        hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
        sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
    )
    if updated_settings != library.settings:
        library = await library_db.update_library_agent(
            library_agent_id=library.id,
            user_id=user_id,
            settings=updated_settings,
        )
    return library


@v1_router.patch(
    path="/graphs/{graph_id}/settings",
    summary="Update graph settings",
@@ -26,6 +26,7 @@ import backend.api.features.executions.review.routes
import backend.api.features.library.db
import backend.api.features.library.model
import backend.api.features.library.routes
import backend.api.features.mcp.routes as mcp_routes
import backend.api.features.oauth
import backend.api.features.otto.routes
import backend.api.features.postmark.postmark
@@ -343,6 +344,11 @@ app.include_router(
    tags=["workspace"],
    prefix="/api/workspace",
)
app.include_router(
    mcp_routes.router,
    tags=["v2", "mcp"],
    prefix="/api/mcp",
)
app.include_router(
    backend.api.features.oauth.router,
    tags=["oauth"],
28  autogpt_platform/backend/backend/blocks/elevenlabs/_auth.py  Normal file
@@ -0,0 +1,28 @@
"""ElevenLabs integration blocks - test credentials and shared utilities."""

from typing import Literal

from pydantic import SecretStr

from backend.data.model import APIKeyCredentials, CredentialsMetaInput
from backend.integrations.providers import ProviderName

TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="elevenlabs",
    api_key=SecretStr("mock-elevenlabs-api-key"),
    title="Mock ElevenLabs API key",
    expires_at=None,
)

TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}

ElevenLabsCredentials = APIKeyCredentials
ElevenLabsCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.ELEVENLABS], Literal["api_key"]
]
77  autogpt_platform/backend/backend/blocks/encoder_block.py  Normal file
@@ -0,0 +1,77 @@
"""Text encoding block for converting special characters to escape sequences."""

import codecs

from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.model import SchemaField


class TextEncoderBlock(Block):
    """
    Encodes a string by converting special characters into escape sequences.

    This block is the inverse of TextDecoderBlock. It takes text containing
    special characters (like newlines, tabs, etc.) and converts them into
    their escape sequence representations (e.g., newline becomes \\n).
    """

    class Input(BlockSchemaInput):
        """Input schema for TextEncoderBlock."""

        text: str = SchemaField(
            description="A string containing special characters to be encoded",
            placeholder="Your text with newlines and quotes to encode",
        )

    class Output(BlockSchemaOutput):
        """Output schema for TextEncoderBlock."""

        encoded_text: str = SchemaField(
            description="The encoded text with special characters converted to escape sequences"
        )
        error: str = SchemaField(description="Error message if encoding fails")

    def __init__(self):
        super().__init__(
            id="5185f32e-4b65-4ecf-8fbb-873f003f09d6",
            description="Encodes a string by converting special characters into escape sequences",
            categories={BlockCategory.TEXT},
            input_schema=TextEncoderBlock.Input,
            output_schema=TextEncoderBlock.Output,
            test_input={
                "text": """Hello
World!
This is a "quoted" string."""
            },
            test_output=[
                (
                    "encoded_text",
                    """Hello\\nWorld!\\nThis is a "quoted" string.""",
                )
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        """
        Encode the input text by converting special characters to escape sequences.

        Args:
            input_data: The input containing the text to encode.
            **kwargs: Additional keyword arguments (unused).

        Yields:
            The encoded text with escape sequences, or an error message if encoding fails.
        """
        try:
            encoded_text = codecs.encode(input_data.text, "unicode_escape").decode(
                "utf-8"
            )
            yield "encoded_text", encoded_text
        except Exception as e:
            yield "error", f"Encoding error: {str(e)}"
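For reference, the unicode_escape codec used by TextEncoderBlock round-trips with the decoding direction; a small sketch of what the block's run method does under the hood:

import codecs

original = 'Hello\nWorld!\nThis is a "quoted" string.'

# Encoding renders control characters as visible escape sequences.
encoded = codecs.encode(original, "unicode_escape").decode("utf-8")
print(encoded)  # Hello\nWorld!\nThis is a "quoted" string.  (one line)

# Decoding reverses it — the TextDecoderBlock direction.
assert codecs.decode(encoded, "unicode_escape") == original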
@@ -478,7 +478,7 @@ class ExaCreateOrFindWebsetBlock(Block):
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        try:
            webset = aexa.websets.get(id=input_data.external_id)
            webset = await aexa.websets.get(id=input_data.external_id)
            webset_result = Webset.model_validate(webset.model_dump(by_alias=True))

            yield "webset", webset_result
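This hunk and the many that follow share one shape: the blocks already construct an AsyncExa client, but the SDK calls were never awaited, so they returned coroutine objects instead of responses. A minimal illustration of the failure mode (stand-in client, not the Exa SDK):

import asyncio


class FakeAsyncClient:
    async def get(self, id: str) -> dict:
        return {"id": id, "status": "idle"}


async def main() -> None:
    client = FakeAsyncClient()
    pending = client.get("ws_123")  # missing await: a coroutine, not a dict
    assert asyncio.iscoroutine(pending)
    result = await pending  # awaiting yields the actual response
    assert result["status"] == "idle"


asyncio.run(main())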
@@ -494,7 +494,7 @@ class ExaCreateOrFindWebsetBlock(Block):
                count=input_data.search_count,
            )

            webset = aexa.websets.create(
            webset = await aexa.websets.create(
                params=CreateWebsetParameters(
                    search=search_params,
                    external_id=input_data.external_id,
@@ -554,7 +554,7 @@ class ExaUpdateWebsetBlock(Block):
        if input_data.metadata is not None:
            payload["metadata"] = input_data.metadata

        sdk_webset = aexa.websets.update(id=input_data.webset_id, params=payload)
        sdk_webset = await aexa.websets.update(id=input_data.webset_id, params=payload)

        status_str = (
            sdk_webset.status.value
@@ -617,7 +617,7 @@ class ExaListWebsetsBlock(Block):
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        response = aexa.websets.list(
        response = await aexa.websets.list(
            cursor=input_data.cursor,
            limit=input_data.limit,
        )
@@ -678,7 +678,7 @@ class ExaGetWebsetBlock(Block):
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_webset = aexa.websets.get(id=input_data.webset_id)
        sdk_webset = await aexa.websets.get(id=input_data.webset_id)

        status_str = (
            sdk_webset.status.value
@@ -748,7 +748,7 @@ class ExaDeleteWebsetBlock(Block):
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        deleted_webset = aexa.websets.delete(id=input_data.webset_id)
        deleted_webset = await aexa.websets.delete(id=input_data.webset_id)

        status_str = (
            deleted_webset.status.value
@@ -798,7 +798,7 @@ class ExaCancelWebsetBlock(Block):
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        canceled_webset = aexa.websets.cancel(id=input_data.webset_id)
        canceled_webset = await aexa.websets.cancel(id=input_data.webset_id)

        status_str = (
            canceled_webset.status.value
@@ -968,7 +968,7 @@ class ExaPreviewWebsetBlock(Block):
            entity["description"] = input_data.entity_description
            payload["entity"] = entity

        sdk_preview = aexa.websets.preview(params=payload)
        sdk_preview = await aexa.websets.preview(params=payload)

        preview = PreviewWebsetModel.from_sdk(sdk_preview)

@@ -1051,7 +1051,7 @@ class ExaWebsetStatusBlock(Block):
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        webset = aexa.websets.get(id=input_data.webset_id)
        webset = await aexa.websets.get(id=input_data.webset_id)

        status = (
            webset.status.value
@@ -1185,7 +1185,7 @@ class ExaWebsetSummaryBlock(Block):
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        webset = aexa.websets.get(id=input_data.webset_id)
        webset = await aexa.websets.get(id=input_data.webset_id)

        # Extract basic info
        webset_id = webset.id
@@ -1211,7 +1211,7 @@ class ExaWebsetSummaryBlock(Block):
        total_items = 0

        if input_data.include_sample_items and input_data.sample_size > 0:
            items_response = aexa.websets.items.list(
            items_response = await aexa.websets.items.list(
                webset_id=input_data.webset_id, limit=input_data.sample_size
            )
            sample_items_data = [
@@ -1362,7 +1362,7 @@ class ExaWebsetReadyCheckBlock(Block):
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        # Get webset details
        webset = aexa.websets.get(id=input_data.webset_id)
        webset = await aexa.websets.get(id=input_data.webset_id)

        status = (
            webset.status.value
@@ -202,7 +202,7 @@ class ExaCreateEnrichmentBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_enrichment = aexa.websets.enrichments.create(
        sdk_enrichment = await aexa.websets.enrichments.create(
            webset_id=input_data.webset_id, params=payload
        )
@@ -223,7 +223,7 @@ class ExaCreateEnrichmentBlock(Block):
            items_enriched = 0

            while time.time() - poll_start < input_data.polling_timeout:
                current_enrich = aexa.websets.enrichments.get(
                current_enrich = await aexa.websets.enrichments.get(
                    webset_id=input_data.webset_id, id=enrichment_id
                )
                current_status = (
@@ -234,7 +234,7 @@ class ExaCreateEnrichmentBlock(Block):

                if current_status in ["completed", "failed", "cancelled"]:
                    # Estimate items from webset searches
                    webset = aexa.websets.get(id=input_data.webset_id)
                    webset = await aexa.websets.get(id=input_data.webset_id)
                    if webset.searches:
                        for search in webset.searches:
                            if search.progress:
@@ -329,7 +329,7 @@ class ExaGetEnrichmentBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_enrichment = aexa.websets.enrichments.get(
        sdk_enrichment = await aexa.websets.enrichments.get(
            webset_id=input_data.webset_id, id=input_data.enrichment_id
        )
@@ -474,7 +474,7 @@ class ExaDeleteEnrichmentBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        deleted_enrichment = aexa.websets.enrichments.delete(
        deleted_enrichment = await aexa.websets.enrichments.delete(
            webset_id=input_data.webset_id, id=input_data.enrichment_id
        )
@@ -525,13 +525,13 @@ class ExaCancelEnrichmentBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        canceled_enrichment = aexa.websets.enrichments.cancel(
        canceled_enrichment = await aexa.websets.enrichments.cancel(
            webset_id=input_data.webset_id, id=input_data.enrichment_id
        )

        # Try to estimate how many items were enriched before cancellation
        items_enriched = 0
        items_response = aexa.websets.items.list(
        items_response = await aexa.websets.items.list(
            webset_id=input_data.webset_id, limit=100
        )
@@ -222,7 +222,7 @@ class ExaCreateImportBlock(Block):
    def _create_test_mock():
        """Create test mocks for the AsyncExa SDK."""
        from datetime import datetime
        from unittest.mock import MagicMock
        from unittest.mock import AsyncMock, MagicMock

        # Create mock SDK import object
        mock_import = MagicMock()
@@ -247,7 +247,7 @@ class ExaCreateImportBlock(Block):
        return {
            "_get_client": lambda *args, **kwargs: MagicMock(
                websets=MagicMock(
                    imports=MagicMock(create=lambda *args, **kwargs: mock_import)
                    imports=MagicMock(create=AsyncMock(return_value=mock_import))
                )
            )
        }
@@ -294,7 +294,7 @@ class ExaCreateImportBlock(Block):
        if input_data.metadata:
            payload["metadata"] = input_data.metadata

        sdk_import = aexa.websets.imports.create(
        sdk_import = await aexa.websets.imports.create(
            params=payload, csv_data=input_data.csv_data
        )
@@ -360,7 +360,7 @@ class ExaGetImportBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_import = aexa.websets.imports.get(import_id=input_data.import_id)
        sdk_import = await aexa.websets.imports.get(import_id=input_data.import_id)

        import_obj = ImportModel.from_sdk(sdk_import)
@@ -426,7 +426,7 @@ class ExaListImportsBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        response = aexa.websets.imports.list(
        response = await aexa.websets.imports.list(
            cursor=input_data.cursor,
            limit=input_data.limit,
        )
@@ -474,7 +474,9 @@ class ExaDeleteImportBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        deleted_import = aexa.websets.imports.delete(import_id=input_data.import_id)
        deleted_import = await aexa.websets.imports.delete(
            import_id=input_data.import_id
        )

        yield "import_id", deleted_import.id
        yield "success", "true"
@@ -573,14 +575,14 @@ class ExaExportWebsetBlock(Block):
            }
        )

        # Create mock iterator
        mock_items = [mock_item1, mock_item2]
        # Create async iterator for list_all
        async def async_item_iterator(*args, **kwargs):
            for item in [mock_item1, mock_item2]:
                yield item

        return {
            "_get_client": lambda *args, **kwargs: MagicMock(
                websets=MagicMock(
                    items=MagicMock(list_all=lambda *args, **kwargs: iter(mock_items))
                )
                websets=MagicMock(items=MagicMock(list_all=async_item_iterator))
            )
        }
@@ -602,7 +604,7 @@ class ExaExportWebsetBlock(Block):
            webset_id=input_data.webset_id, limit=input_data.max_items
        )

        for sdk_item in item_iterator:
        async for sdk_item in item_iterator:
            if len(all_items) >= input_data.max_items:
                break
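The list_all mock above leans on the fact that an async generator function is a drop-in stand-in for a method returning an async iterator. A self-contained sketch of the pattern (names are illustrative):

import asyncio
from unittest.mock import MagicMock


async def fake_list_all(*args, **kwargs):
    # Calling an async generator function returns an async iterator,
    # which is exactly what `async for` consumes.
    for item in ("item-1", "item-2"):
        yield item


client = MagicMock()
client.websets.items.list_all = fake_list_all


async def main() -> None:
    collected = [item async for item in client.websets.items.list_all("ws_1")]
    assert collected == ["item-1", "item-2"]


asyncio.run(main())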
@@ -178,7 +178,7 @@ class ExaGetWebsetItemBlock(Block):
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_item = aexa.websets.items.get(
        sdk_item = await aexa.websets.items.get(
            webset_id=input_data.webset_id, id=input_data.item_id
        )
@@ -269,7 +269,7 @@ class ExaListWebsetItemsBlock(Block):
            response = None

            while time.time() - start_time < input_data.wait_timeout:
                response = aexa.websets.items.list(
                response = await aexa.websets.items.list(
                    webset_id=input_data.webset_id,
                    cursor=input_data.cursor,
                    limit=input_data.limit,
@@ -282,13 +282,13 @@ class ExaListWebsetItemsBlock(Block):
                interval = min(interval * 1.2, 10)

            if not response:
                response = aexa.websets.items.list(
                response = await aexa.websets.items.list(
                    webset_id=input_data.webset_id,
                    cursor=input_data.cursor,
                    limit=input_data.limit,
                )
        else:
            response = aexa.websets.items.list(
            response = await aexa.websets.items.list(
                webset_id=input_data.webset_id,
                cursor=input_data.cursor,
                limit=input_data.limit,
@@ -340,7 +340,7 @@ class ExaDeleteWebsetItemBlock(Block):
    ) -> BlockOutput:
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        deleted_item = aexa.websets.items.delete(
        deleted_item = await aexa.websets.items.delete(
            webset_id=input_data.webset_id, id=input_data.item_id
        )
@@ -408,7 +408,7 @@ class ExaBulkWebsetItemsBlock(Block):
            webset_id=input_data.webset_id, limit=input_data.max_items
        )

        for sdk_item in item_iterator:
        async for sdk_item in item_iterator:
            if len(all_items) >= input_data.max_items:
                break
@@ -475,7 +475,7 @@ class ExaWebsetItemsSummaryBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        webset = aexa.websets.get(id=input_data.webset_id)
        webset = await aexa.websets.get(id=input_data.webset_id)

        entity_type = "unknown"
        if webset.searches:
@@ -495,7 +495,7 @@ class ExaWebsetItemsSummaryBlock(Block):
        # Get sample items if requested
        sample_items: List[WebsetItemModel] = []
        if input_data.sample_size > 0:
            items_response = aexa.websets.items.list(
            items_response = await aexa.websets.items.list(
                webset_id=input_data.webset_id, limit=input_data.sample_size
            )
            # Convert to our stable models
@@ -569,7 +569,7 @@ class ExaGetNewItemsBlock(Block):
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        # Get items starting from cursor
        response = aexa.websets.items.list(
        response = await aexa.websets.items.list(
            webset_id=input_data.webset_id,
            cursor=input_data.since_cursor,
            limit=input_data.max_items,
@@ -233,7 +233,7 @@ class ExaCreateMonitorBlock(Block):
    def _create_test_mock():
        """Create test mocks for the AsyncExa SDK."""
        from datetime import datetime
        from unittest.mock import MagicMock
        from unittest.mock import AsyncMock, MagicMock

        # Create mock SDK monitor object
        mock_monitor = MagicMock()
@@ -263,7 +263,7 @@ class ExaCreateMonitorBlock(Block):
        return {
            "_get_client": lambda *args, **kwargs: MagicMock(
                websets=MagicMock(
                    monitors=MagicMock(create=lambda *args, **kwargs: mock_monitor)
                    monitors=MagicMock(create=AsyncMock(return_value=mock_monitor))
                )
            )
        }
@@ -320,7 +320,7 @@ class ExaCreateMonitorBlock(Block):
        if input_data.metadata:
            payload["metadata"] = input_data.metadata

        sdk_monitor = aexa.websets.monitors.create(params=payload)
        sdk_monitor = await aexa.websets.monitors.create(params=payload)

        monitor = MonitorModel.from_sdk(sdk_monitor)
@@ -384,7 +384,7 @@ class ExaGetMonitorBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_monitor = aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
        sdk_monitor = await aexa.websets.monitors.get(monitor_id=input_data.monitor_id)

        monitor = MonitorModel.from_sdk(sdk_monitor)
@@ -476,7 +476,7 @@ class ExaUpdateMonitorBlock(Block):
        if input_data.metadata is not None:
            payload["metadata"] = input_data.metadata

        sdk_monitor = aexa.websets.monitors.update(
        sdk_monitor = await aexa.websets.monitors.update(
            monitor_id=input_data.monitor_id, params=payload
        )
@@ -522,7 +522,9 @@ class ExaDeleteMonitorBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        deleted_monitor = aexa.websets.monitors.delete(monitor_id=input_data.monitor_id)
        deleted_monitor = await aexa.websets.monitors.delete(
            monitor_id=input_data.monitor_id
        )

        yield "monitor_id", deleted_monitor.id
        yield "success", "true"
@@ -579,7 +581,7 @@ class ExaListMonitorsBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        response = aexa.websets.monitors.list(
        response = await aexa.websets.monitors.list(
            cursor=input_data.cursor,
            limit=input_data.limit,
            webset_id=input_data.webset_id,
@@ -121,7 +121,7 @@ class ExaWaitForWebsetBlock(Block):
            WebsetTargetStatus.IDLE,
            WebsetTargetStatus.ANY_COMPLETE,
        ]:
            final_webset = aexa.websets.wait_until_idle(
            final_webset = await aexa.websets.wait_until_idle(
                id=input_data.webset_id,
                timeout=input_data.timeout,
                poll_interval=input_data.check_interval,
@@ -164,7 +164,7 @@ class ExaWaitForWebsetBlock(Block):
        interval = input_data.check_interval
        while time.time() - start_time < input_data.timeout:
            # Get current webset status
            webset = aexa.websets.get(id=input_data.webset_id)
            webset = await aexa.websets.get(id=input_data.webset_id)
            current_status = (
                webset.status.value
                if hasattr(webset.status, "value")
@@ -209,7 +209,7 @@ class ExaWaitForWebsetBlock(Block):

        # Timeout reached
        elapsed = time.time() - start_time
        webset = aexa.websets.get(id=input_data.webset_id)
        webset = await aexa.websets.get(id=input_data.webset_id)
        final_status = (
            webset.status.value
            if hasattr(webset.status, "value")
@@ -345,7 +345,7 @@ class ExaWaitForSearchBlock(Block):
        try:
            while time.time() - start_time < input_data.timeout:
                # Get current search status using SDK
                search = aexa.websets.searches.get(
                search = await aexa.websets.searches.get(
                    webset_id=input_data.webset_id, id=input_data.search_id
                )
@@ -401,7 +401,7 @@ class ExaWaitForSearchBlock(Block):
            elapsed = time.time() - start_time

            # Get last known status
            search = aexa.websets.searches.get(
            search = await aexa.websets.searches.get(
                webset_id=input_data.webset_id, id=input_data.search_id
            )
            final_status = (
@@ -503,7 +503,7 @@ class ExaWaitForEnrichmentBlock(Block):
        try:
            while time.time() - start_time < input_data.timeout:
                # Get current enrichment status using SDK
                enrichment = aexa.websets.enrichments.get(
                enrichment = await aexa.websets.enrichments.get(
                    webset_id=input_data.webset_id, id=input_data.enrichment_id
                )
@@ -548,7 +548,7 @@ class ExaWaitForEnrichmentBlock(Block):
            elapsed = time.time() - start_time

            # Get last known status
            enrichment = aexa.websets.enrichments.get(
            enrichment = await aexa.websets.enrichments.get(
                webset_id=input_data.webset_id, id=input_data.enrichment_id
            )
            final_status = (
@@ -575,7 +575,7 @@ class ExaWaitForEnrichmentBlock(Block):
    ) -> tuple[list[SampleEnrichmentModel], int]:
        """Get sample enriched data and count."""
        # Get a few items to see enrichment results using SDK
        response = aexa.websets.items.list(webset_id=webset_id, limit=5)
        response = await aexa.websets.items.list(webset_id=webset_id, limit=5)

        sample_data: list[SampleEnrichmentModel] = []
        enriched_count = 0
@@ -317,7 +317,7 @@ class ExaCreateWebsetSearchBlock(Block):

        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_search = aexa.websets.searches.create(
        sdk_search = await aexa.websets.searches.create(
            webset_id=input_data.webset_id, params=payload
        )
@@ -350,7 +350,7 @@ class ExaCreateWebsetSearchBlock(Block):
            poll_start = time.time()

            while time.time() - poll_start < input_data.polling_timeout:
                current_search = aexa.websets.searches.get(
                current_search = await aexa.websets.searches.get(
                    webset_id=input_data.webset_id, id=search_id
                )
                current_status = (
@@ -442,7 +442,7 @@ class ExaGetWebsetSearchBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_search = aexa.websets.searches.get(
        sdk_search = await aexa.websets.searches.get(
            webset_id=input_data.webset_id, id=input_data.search_id
        )
@@ -523,7 +523,7 @@ class ExaCancelWebsetSearchBlock(Block):
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        canceled_search = aexa.websets.searches.cancel(
        canceled_search = await aexa.websets.searches.cancel(
            webset_id=input_data.webset_id, id=input_data.search_id
        )
@@ -604,7 +604,7 @@ class ExaFindOrCreateSearchBlock(Block):
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        # Get webset to check existing searches
        webset = aexa.websets.get(id=input_data.webset_id)
        webset = await aexa.websets.get(id=input_data.webset_id)

        # Look for existing search with same query
        existing_search = None
@@ -636,7 +636,7 @@ class ExaFindOrCreateSearchBlock(Block):
        if input_data.entity_type != SearchEntityType.AUTO:
            payload["entity"] = {"type": input_data.entity_type.value}

        sdk_search = aexa.websets.searches.create(
        sdk_search = await aexa.websets.searches.create(
            webset_id=input_data.webset_id, params=payload
        )
@@ -115,6 +115,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
    CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
    CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
    CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
    CLAUDE_4_6_OPUS = "claude-opus-4-6"
    CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
    # AI/ML API models
    AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
@@ -270,6 +271,9 @@ MODEL_METADATA = {
    LlmModel.CLAUDE_4_SONNET: ModelMetadata(
        "anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
    ),  # claude-4-sonnet-20250514
    LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
        "anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
    ),  # claude-opus-4-6
    LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
        "anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
    ),  # claude-opus-4-5-20251101
@@ -592,10 +596,10 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:

def get_parallel_tool_calls_param(
    llm_model: LlmModel, parallel_tool_calls: bool | None
):
) -> bool | openai.Omit:
    """Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
    if llm_model.startswith("o") or parallel_tool_calls is None:
        return openai.NOT_GIVEN
        return openai.omit
    return parallel_tool_calls
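The NOT_GIVEN → omit switch tracks the newer openai-python client, where openai.omit / openai.Omit mark a field to be left out of the request body entirely. A hedged sketch of the same dispatch logic on plain strings (model names are examples, not an exhaustive rule):

import openai


def resolve_parallel_tool_calls(
    model_name: str, parallel_tool_calls: bool | None
) -> bool | openai.Omit:
    # Reasoning ("o*") models reject the parameter, so it must be omitted
    # from the request rather than sent as None.
    if model_name.startswith("o") or parallel_tool_calls is None:
        return openai.omit
    return parallel_tool_calls


assert resolve_parallel_tool_calls("o3-mini", True) is openai.omit
assert resolve_parallel_tool_calls("gpt-4.1", True) is True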
|
|||||||
238
autogpt_platform/backend/backend/blocks/mcp/block.py
Normal file
@@ -0,0 +1,238 @@
"""
MCP (Model Context Protocol) Tool Block.

A single dynamic block that can connect to any MCP server, discover available tools,
and execute them. Works like AgentExecutorBlock — the user selects a tool from a
dropdown and the input/output schema adapts dynamically.
"""

import json
import logging
from typing import Any

from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.data.block import (
    Block,
    BlockCategory,
    BlockInput,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    BlockType,
)
from backend.data.model import OAuth2Credentials, SchemaField
from backend.util.json import validate_with_jsonschema

logger = logging.getLogger(__name__)


class MCPToolBlock(Block):
    """
    A block that connects to an MCP server, lets the user pick a tool,
    and executes it with dynamic input/output schema.

    The flow:
    1. User provides an MCP server URL (and optional credentials)
    2. Frontend calls the backend to get tool list from that URL
    3. User selects a tool from a dropdown (available_tools)
    4. The block's input schema updates to reflect the selected tool's parameters
    5. On execution, the block calls the MCP server to run the tool
    """

    class Input(BlockSchemaInput):
        server_url: str = SchemaField(
            description="URL of the MCP server (Streamable HTTP endpoint)",
            placeholder="https://mcp.example.com/mcp",
        )
        credential_id: str = SchemaField(
            description="Credential ID from OAuth flow (empty for public servers)",
            default="",
            hidden=True,
        )
        available_tools: dict[str, Any] = SchemaField(
            description="Available tools on the MCP server. "
            "This is populated automatically when a server URL is provided.",
            default={},
            hidden=True,
        )
        selected_tool: str = SchemaField(
            description="The MCP tool to execute",
            placeholder="Select a tool",
            default="",
        )
        tool_input_schema: dict[str, Any] = SchemaField(
            description="JSON Schema for the selected tool's input parameters. "
            "Populated automatically when a tool is selected.",
            default={},
            hidden=True,
        )

        tool_arguments: dict[str, Any] = SchemaField(
            description="Arguments to pass to the selected MCP tool. "
            "The fields here are defined by the tool's input schema.",
            default={},
        )

        @classmethod
        def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
            """Return the tool's input schema so the builder UI renders dynamic fields."""
            return data.get("tool_input_schema", {})

        @classmethod
        def get_input_defaults(cls, data: BlockInput) -> BlockInput:
            """Return the current tool_arguments as defaults for the dynamic fields."""
            return data.get("tool_arguments", {})

        @classmethod
        def get_missing_input(cls, data: BlockInput) -> set[str]:
            """Check which required tool arguments are missing."""
            required_fields = cls.get_input_schema(data).get("required", [])
            tool_arguments = data.get("tool_arguments", {})
            return set(required_fields) - set(tool_arguments)

        @classmethod
        def get_mismatch_error(cls, data: BlockInput) -> str | None:
            """Validate tool_arguments against the tool's input schema."""
            tool_schema = cls.get_input_schema(data)
            if not tool_schema:
                return None
            tool_arguments = data.get("tool_arguments", {})
            return validate_with_jsonschema(tool_schema, tool_arguments)

    class Output(BlockSchemaOutput):
        result: Any = SchemaField(description="The result returned by the MCP tool")
        error: str = SchemaField(description="Error message if the tool call failed")

    def __init__(self):
        super().__init__(
            id="a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4",
            description="Connect to any MCP server and execute its tools. "
            "Provide a server URL, select a tool, and pass arguments dynamically.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=MCPToolBlock.Input,
            output_schema=MCPToolBlock.Output,
            block_type=BlockType.STANDARD,
            test_input={
                "server_url": "https://mcp.example.com/mcp",
                "selected_tool": "get_weather",
                "tool_input_schema": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
                "tool_arguments": {"city": "London"},
            },
            test_output=[
                (
                    "result",
                    {"weather": "sunny", "temperature": 20},
                ),
            ],
            test_mock={
                "_call_mcp_tool": lambda *a, **kw: {
                    "weather": "sunny",
                    "temperature": 20,
                },
            },
        )

    async def _call_mcp_tool(
        self,
        server_url: str,
        tool_name: str,
        arguments: dict[str, Any],
        auth_token: str | None = None,
    ) -> Any:
        """Call a tool on the MCP server. Extracted for easy mocking in tests."""
        # Trust the user-configured server URL to allow internal/localhost servers
        client = MCPClient(
            server_url,
            auth_token=auth_token,
            trusted_origins=[server_url],
        )
        await client.initialize()
        result = await client.call_tool(tool_name, arguments)

        if result.is_error:
            error_text = ""
            for item in result.content:
                if item.get("type") == "text":
                    error_text += item.get("text", "")
            raise MCPClientError(
                f"MCP tool '{tool_name}' returned an error: "
                f"{error_text or 'Unknown error'}"
            )

        # Extract text content from the result
        output_parts = []
        for item in result.content:
            if item.get("type") == "text":
                text = item.get("text", "")
                # Try to parse as JSON for structured output
                try:
                    output_parts.append(json.loads(text))
                except (json.JSONDecodeError, ValueError):
                    output_parts.append(text)
            elif item.get("type") == "image":
                output_parts.append(
                    {
                        "type": "image",
                        "data": item.get("data"),
                        "mimeType": item.get("mimeType"),
                    }
                )
            elif item.get("type") == "resource":
                output_parts.append(item.get("resource", {}))

        # If single result, unwrap
        if len(output_parts) == 1:
            return output_parts[0]
        return output_parts if output_parts else None

    async def _resolve_auth_token(self, credential_id: str, user_id: str) -> str | None:
        """Resolve a Bearer token from a stored credential ID."""
        if not credential_id:
            return None
        from backend.util.clients import get_integration_credentials_store

        store = get_integration_credentials_store()
        creds = await store.get_creds_by_id(user_id, credential_id)
        if not creds:
            logger.warning(f"Credential {credential_id} not found")
            return None
        if isinstance(creds, OAuth2Credentials):
            return creds.access_token.get_secret_value()
        if hasattr(creds, "api_key") and creds.api_key:
            return creds.api_key.get_secret_value() or None
        return None

    async def run(
        self,
        input_data: Input,
        *,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        if not input_data.server_url:
            yield "error", "MCP server URL is required"
            return

        if not input_data.selected_tool:
            yield "error", "No tool selected. Please select a tool from the dropdown."
            return

        auth_token = await self._resolve_auth_token(input_data.credential_id, user_id)

        try:
            result = await self._call_mcp_tool(
                server_url=input_data.server_url,
                tool_name=input_data.selected_tool,
                arguments=input_data.tool_arguments,
                auth_token=auth_token,
            )
            yield "result", result
        except MCPClientError as e:
            yield "error", str(e)
        except Exception as e:
            logger.exception(f"MCP tool call failed: {e}")
            yield "error", f"MCP tool call failed: {str(e)}"
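Aside: a minimal sketch of how this block is exercised end to end, following the same pattern as the integration tests later in this diff. The server URL, tool name, schema, and user id below are placeholder values, not part of the change itself:

import asyncio

from backend.blocks.mcp.block import MCPToolBlock

async def demo() -> None:
    block = MCPToolBlock()
    input_data = MCPToolBlock.Input(
        server_url="https://mcp.example.com/mcp",  # placeholder endpoint
        selected_tool="get_weather",               # placeholder tool name
        tool_input_schema={
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
        tool_arguments={"city": "London"},
    )
    # run() yields (output_name, value) pairs: "result" on success, "error" on failure
    async for name, data in block.run(input_data, user_id="user-123"):
        print(name, data)

asyncio.run(demo())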
316
autogpt_platform/backend/backend/blocks/mcp/client.py
Normal file
@@ -0,0 +1,316 @@
"""
MCP (Model Context Protocol) HTTP client.

Implements the MCP Streamable HTTP transport for listing tools and calling tools
on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST.

Handles both JSON and SSE (text/event-stream) response formats per the MCP spec.

Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
"""

import json
import logging
from dataclasses import dataclass, field
from typing import Any

from backend.util.request import Requests

logger = logging.getLogger(__name__)


@dataclass
class MCPTool:
    """Represents an MCP tool discovered from a server."""

    name: str
    description: str
    input_schema: dict[str, Any]


@dataclass
class MCPCallResult:
    """Result from calling an MCP tool."""

    content: list[dict[str, Any]] = field(default_factory=list)
    is_error: bool = False


class MCPClientError(Exception):
    """Raised when an MCP protocol error occurs."""

    pass


class MCPClient:
    """
    Async HTTP client for the MCP Streamable HTTP transport.

    Communicates with MCP servers using JSON-RPC 2.0 over HTTP POST.
    Supports optional Bearer token authentication.
    """

    def __init__(
        self,
        server_url: str,
        auth_token: str | None = None,
        trusted_origins: list[str] | None = None,
    ):
        self.server_url = server_url.rstrip("/")
        self.auth_token = auth_token
        self.trusted_origins = trusted_origins or []
        self._request_id = 0

    def _next_id(self) -> int:
        self._request_id += 1
        return self._request_id

    def _build_headers(self) -> dict[str, str]:
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json, text/event-stream",
        }
        if self.auth_token:
            headers["Authorization"] = f"Bearer {self.auth_token}"
        return headers

    def _build_jsonrpc_request(
        self, method: str, params: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        req: dict[str, Any] = {
            "jsonrpc": "2.0",
            "method": method,
            "id": self._next_id(),
        }
        if params is not None:
            req["params"] = params
        return req

    @staticmethod
    def _parse_sse_response(text: str) -> dict[str, Any]:
        """Parse an SSE (text/event-stream) response body into JSON-RPC data.

        MCP servers may return responses as SSE with format:
            event: message
            data: {"jsonrpc":"2.0","result":{...},"id":1}

        We extract the last `data:` line that contains a JSON-RPC response
        (i.e. has an "id" field), which is the reply to our request.
        """
        last_data: dict[str, Any] | None = None
        for line in text.splitlines():
            stripped = line.strip()
            if stripped.startswith("data:"):
                payload = stripped[len("data:") :].strip()
                if not payload:
                    continue
                try:
                    parsed = json.loads(payload)
                    # Only keep JSON-RPC responses (have "id"), skip notifications
                    if isinstance(parsed, dict) and "id" in parsed:
                        last_data = parsed
                except (json.JSONDecodeError, ValueError):
                    continue
        if last_data is None:
            raise MCPClientError("No JSON-RPC response found in SSE stream")
        return last_data

    async def _send_request(
        self, method: str, params: dict[str, Any] | None = None
    ) -> Any:
        """Send a JSON-RPC request to the MCP server and return the result.

        Handles both ``application/json`` and ``text/event-stream`` responses
        as required by the MCP Streamable HTTP transport specification.
        """
        payload = self._build_jsonrpc_request(method, params)
        headers = self._build_headers()

        requests = Requests(
            raise_for_status=True,
            extra_headers=headers,
            trusted_origins=self.trusted_origins,
        )
        response = await requests.post(self.server_url, json=payload)

        content_type = response.headers.get("content-type", "")
        if "text/event-stream" in content_type:
            body = self._parse_sse_response(response.text())
        else:
            try:
                body = response.json()
            except Exception as e:
                raise MCPClientError(
                    f"MCP server returned non-JSON response: {e}"
                ) from e

        # Handle JSON-RPC error
        if "error" in body:
            error = body["error"]
            if isinstance(error, dict):
                raise MCPClientError(
                    f"MCP server error [{error.get('code', '?')}]: "
                    f"{error.get('message', 'Unknown error')}"
                )
            raise MCPClientError(f"MCP server error: {error}")

        return body.get("result")

    async def _send_notification(self, method: str) -> None:
        """Send a JSON-RPC notification (no id, no response expected)."""
        headers = self._build_headers()
        notification = {"jsonrpc": "2.0", "method": method}
        requests = Requests(
            raise_for_status=False,
            extra_headers=headers,
            trusted_origins=self.trusted_origins,
        )
        await requests.post(self.server_url, json=notification)

    async def discover_auth(self) -> dict[str, Any] | None:
        """Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec).

        Returns ``None`` if the server doesn't require auth, otherwise returns
        a dict with:
        - ``authorization_servers``: list of authorization server URLs
        - ``resource``: the resource indicator URL (usually the MCP endpoint)
        - ``scopes_supported``: optional list of supported scopes

        The caller can then fetch the authorization server metadata to get
        ``authorization_endpoint``, ``token_endpoint``, etc.
        """
        from urllib.parse import urlparse

        parsed = urlparse(self.server_url)
        base = f"{parsed.scheme}://{parsed.netloc}"

        # Build candidates for protected-resource metadata (per RFC 9728)
        path = parsed.path.rstrip("/")
        candidates = []
        if path and path != "/":
            candidates.append(f"{base}/.well-known/oauth-protected-resource{path}")
        candidates.append(f"{base}/.well-known/oauth-protected-resource")

        requests = Requests(
            raise_for_status=False,
            trusted_origins=self.trusted_origins,
        )
        for url in candidates:
            try:
                resp = await requests.get(url)
                if resp.status == 200:
                    data = resp.json()
                    if isinstance(data, dict) and "authorization_servers" in data:
                        return data
            except Exception:
                continue

        return None

    async def discover_auth_server_metadata(
        self, auth_server_url: str
    ) -> dict[str, Any] | None:
        """Fetch the OAuth Authorization Server Metadata (RFC 8414).

        Given an authorization server URL, returns a dict with:
        - ``authorization_endpoint``
        - ``token_endpoint``
        - ``registration_endpoint`` (for dynamic client registration)
        - ``scopes_supported``
        - ``code_challenge_methods_supported``
        - etc.
        """
        from urllib.parse import urlparse

        parsed = urlparse(auth_server_url)
        base = f"{parsed.scheme}://{parsed.netloc}"
        path = parsed.path.rstrip("/")

        # Try standard metadata endpoints (RFC 8414 and OpenID Connect)
        candidates = []
        if path and path != "/":
            candidates.append(f"{base}/.well-known/oauth-authorization-server{path}")
        candidates.append(f"{base}/.well-known/oauth-authorization-server")
        candidates.append(f"{base}/.well-known/openid-configuration")

        requests = Requests(
            raise_for_status=False,
            trusted_origins=self.trusted_origins,
        )
        for url in candidates:
            try:
                resp = await requests.get(url)
                if resp.status == 200:
                    data = resp.json()
                    if isinstance(data, dict) and "authorization_endpoint" in data:
                        return data
            except Exception:
                continue

        return None

    async def initialize(self) -> dict[str, Any]:
        """
        Send the MCP initialize request.

        This is required by the MCP protocol before any other requests.
        Returns the server's capabilities.
        """
        result = await self._send_request(
            "initialize",
            {
                "protocolVersion": "2025-03-26",
                "capabilities": {},
                "clientInfo": {"name": "AutoGPT-Platform", "version": "1.0.0"},
            },
        )
        # Send initialized notification (no response expected)
        await self._send_notification("notifications/initialized")

        return result or {}

    async def list_tools(self) -> list[MCPTool]:
        """
        Discover available tools from the MCP server.

        Returns a list of MCPTool objects with name, description, and input schema.
        """
        result = await self._send_request("tools/list")
        if not result or "tools" not in result:
            return []

        tools = []
        for tool_data in result["tools"]:
            tools.append(
                MCPTool(
                    name=tool_data.get("name", ""),
                    description=tool_data.get("description", ""),
                    input_schema=tool_data.get("inputSchema", {}),
                )
            )
        return tools

    async def call_tool(
        self, tool_name: str, arguments: dict[str, Any]
    ) -> MCPCallResult:
        """
        Call a tool on the MCP server.

        Args:
            tool_name: The name of the tool to call.
            arguments: The arguments to pass to the tool.

        Returns:
            MCPCallResult with the tool's response content.
        """
        result = await self._send_request(
            "tools/call",
            {"name": tool_name, "arguments": arguments},
        )
        if not result:
            return MCPCallResult(is_error=True)

        return MCPCallResult(
            content=result.get("content", []),
            is_error=result.get("isError", False),
        )
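Aside: a minimal usage sketch for this client, assuming a reachable Streamable HTTP endpoint. The URL and tool name below are placeholders; the initialize, then list_tools, then call_tool order is what the protocol requires:

import asyncio

from backend.blocks.mcp.client import MCPClient

async def demo() -> None:
    url = "https://mcp.example.com/mcp"  # placeholder endpoint
    client = MCPClient(url, trusted_origins=[url])
    await client.initialize()  # MCP handshake; required before any other request
    for tool in await client.list_tools():
        print(tool.name, "-", tool.description)
    # placeholder tool name and arguments; real names come from list_tools()
    result = await client.call_tool("get_weather", {"city": "London"})
    if not result.is_error:
        print(result.content)

asyncio.run(demo())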
42
autogpt_platform/backend/backend/blocks/mcp/conftest.py
Normal file
@@ -0,0 +1,42 @@
"""
Conftest for MCP block tests.

Override the session-scoped server and graph_cleanup fixtures from
backend/conftest.py so that MCP integration tests don't spin up the
full SpinTestServer infrastructure.
"""

import pytest


def pytest_configure(config: pytest.Config) -> None:
    config.addinivalue_line("markers", "e2e: end-to-end tests requiring network")


def pytest_collection_modifyitems(
    config: pytest.Config, items: list[pytest.Item]
) -> None:
    """Skip e2e tests unless --run-e2e is passed."""
    if not config.getoption("--run-e2e", default=False):
        skip_e2e = pytest.mark.skip(reason="need --run-e2e option to run")
        for item in items:
            if "e2e" in item.keywords:
                item.add_marker(skip_e2e)


def pytest_addoption(parser: pytest.Parser) -> None:
    parser.addoption(
        "--run-e2e", action="store_true", default=False, help="run e2e tests"
    )


@pytest.fixture(scope="session")
def server():
    """No-op override — MCP tests don't need the full platform server."""
    yield None


@pytest.fixture(scope="session", autouse=True)
def graph_cleanup(server):
    """No-op override — MCP tests don't create graphs."""
    yield
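Aside: with these overrides in place, the network-dependent tests stay skipped by default; an invocation along these lines (exact command depends on the environment) opts in to them:

poetry run pytest backend/blocks/mcp --run-e2e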
198
autogpt_platform/backend/backend/blocks/mcp/oauth.py
Normal file
@@ -0,0 +1,198 @@
"""
MCP OAuth handler for MCP servers that use OAuth 2.1 authorization.

Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed,
MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata.
This handler accepts those endpoints at construction time.
"""

import logging
import time
import urllib.parse
from typing import ClassVar, Optional

from pydantic import SecretStr

from backend.data.model import OAuth2Credentials
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.providers import ProviderName
from backend.util.request import Requests

logger = logging.getLogger(__name__)


class MCPOAuthHandler(BaseOAuthHandler):
    """
    OAuth handler for MCP servers with dynamically-discovered endpoints.

    Construction requires the authorization and token endpoint URLs,
    which are obtained via MCP OAuth metadata discovery
    (``MCPClient.discover_auth`` + ``discover_auth_server_metadata``).
    """

    PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP
    DEFAULT_SCOPES: ClassVar[list[str]] = []

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        redirect_uri: str,
        *,
        authorize_url: str,
        token_url: str,
        revoke_url: str | None = None,
        resource_url: str | None = None,
    ):
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uri = redirect_uri
        self.authorize_url = authorize_url
        self.token_url = token_url
        self.revoke_url = revoke_url
        self.resource_url = resource_url

    def get_login_url(
        self,
        scopes: list[str],
        state: str,
        code_challenge: Optional[str],
    ) -> str:
        scopes = self.handle_default_scopes(scopes)

        params: dict[str, str] = {
            "response_type": "code",
            "client_id": self.client_id,
            "redirect_uri": self.redirect_uri,
            "state": state,
        }
        if scopes:
            params["scope"] = " ".join(scopes)
        # PKCE (S256) — included when the caller provides a code_challenge
        if code_challenge:
            params["code_challenge"] = code_challenge
            params["code_challenge_method"] = "S256"
        # MCP spec requires resource indicator (RFC 8707)
        if self.resource_url:
            params["resource"] = self.resource_url

        return f"{self.authorize_url}?{urllib.parse.urlencode(params)}"

    async def exchange_code_for_tokens(
        self,
        code: str,
        scopes: list[str],
        code_verifier: Optional[str],
    ) -> OAuth2Credentials:
        data: dict[str, str] = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": self.redirect_uri,
            "client_id": self.client_id,
        }
        if self.client_secret:
            data["client_secret"] = self.client_secret
        if code_verifier:
            data["code_verifier"] = code_verifier
        if self.resource_url:
            data["resource"] = self.resource_url

        response = await Requests(raise_for_status=True).post(
            self.token_url,
            data=data,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )
        tokens = response.json()

        if "error" in tokens:
            raise RuntimeError(
                f"Token exchange failed: {tokens.get('error_description', tokens['error'])}"
            )

        now = int(time.time())
        expires_in = tokens.get("expires_in")

        return OAuth2Credentials(
            provider=str(self.PROVIDER_NAME),
            title=None,
            access_token=SecretStr(tokens["access_token"]),
            refresh_token=(
                SecretStr(tokens["refresh_token"])
                if tokens.get("refresh_token")
                else None
            ),
            access_token_expires_at=now + expires_in if expires_in else None,
            refresh_token_expires_at=None,
            scopes=scopes,
            metadata={
                "mcp_token_url": self.token_url,
                "mcp_resource_url": self.resource_url,
            },
        )

    async def _refresh_tokens(
        self, credentials: OAuth2Credentials
    ) -> OAuth2Credentials:
        if not credentials.refresh_token:
            raise ValueError("No refresh token available for MCP OAuth credentials")

        data: dict[str, str] = {
            "grant_type": "refresh_token",
            "refresh_token": credentials.refresh_token.get_secret_value(),
            "client_id": self.client_id,
        }
        if self.client_secret:
            data["client_secret"] = self.client_secret
        if self.resource_url:
            data["resource"] = self.resource_url

        response = await Requests(raise_for_status=True).post(
            self.token_url,
            data=data,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )
        tokens = response.json()

        if "error" in tokens:
            raise RuntimeError(
                f"Token refresh failed: {tokens.get('error_description', tokens['error'])}"
            )

        now = int(time.time())
        expires_in = tokens.get("expires_in")

        return OAuth2Credentials(
            id=credentials.id,
            provider=str(self.PROVIDER_NAME),
            title=credentials.title,
            access_token=SecretStr(tokens["access_token"]),
            refresh_token=(
                SecretStr(tokens["refresh_token"])
                if tokens.get("refresh_token")
                else credentials.refresh_token
            ),
            access_token_expires_at=now + expires_in if expires_in else None,
            refresh_token_expires_at=credentials.refresh_token_expires_at,
            scopes=credentials.scopes,
            metadata=credentials.metadata,
        )

    async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
        if not self.revoke_url:
            return False

        try:
            data = {
                "token": credentials.access_token.get_secret_value(),
                "token_type_hint": "access_token",
                "client_id": self.client_id,
            }
            await Requests().post(
                self.revoke_url,
                data=data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
            )
            return True
        except Exception:
            logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True)
            return False
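Aside: a hedged sketch of how the discovery helpers and this handler compose. The client_id and redirect_uri below are placeholders, since client registration (dynamic or otherwise) is outside this snippet:

from backend.blocks.mcp.client import MCPClient
from backend.blocks.mcp.oauth import MCPOAuthHandler

async def build_handler(server_url: str) -> MCPOAuthHandler | None:
    client = MCPClient(server_url)
    resource_meta = await client.discover_auth()  # RFC 9728 probe
    if not resource_meta:
        return None  # server is public, no OAuth needed
    auth_server = resource_meta["authorization_servers"][0]
    meta = await client.discover_auth_server_metadata(auth_server)  # RFC 8414
    if not meta:
        return None
    return MCPOAuthHandler(
        client_id="<registered-client-id>",  # placeholder
        client_secret="",
        redirect_uri="https://platform.example.com/oauth/callback",  # placeholder
        authorize_url=meta["authorization_endpoint"],
        token_url=meta["token_endpoint"],
        resource_url=resource_meta.get("resource", server_url),
    )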
104
autogpt_platform/backend/backend/blocks/mcp/test_e2e.py
Normal file
@@ -0,0 +1,104 @@
"""
End-to-end tests against a real public MCP server.

These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp)
which is publicly accessible without authentication and returns SSE responses.

Mark: These are tagged with ``@pytest.mark.e2e`` so they can be run/skipped
independently of the rest of the test suite (they require network access).
"""

import json

import pytest

from backend.blocks.mcp.client import MCPClient

# Public MCP server that requires no authentication
OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp"


@pytest.mark.e2e
class TestRealMCPServer:
    """Tests against the live OpenAI docs MCP server."""

    @pytest.mark.asyncio
    async def test_initialize(self):
        """Verify we can complete the MCP handshake with a real server."""
        client = MCPClient(OPENAI_DOCS_MCP_URL)
        result = await client.initialize()

        assert result["protocolVersion"] == "2025-03-26"
        assert "serverInfo" in result
        assert result["serverInfo"]["name"] == "openai-docs-mcp"
        assert "tools" in result.get("capabilities", {})

    @pytest.mark.asyncio
    async def test_list_tools(self):
        """Verify we can discover tools from a real MCP server."""
        client = MCPClient(OPENAI_DOCS_MCP_URL)
        await client.initialize()
        tools = await client.list_tools()

        assert len(tools) >= 3  # server has at least 5 tools as of writing

        tool_names = {t.name for t in tools}
        # These tools are documented and should be stable
        assert "search_openai_docs" in tool_names
        assert "list_openai_docs" in tool_names
        assert "fetch_openai_doc" in tool_names

        # Verify schema structure
        search_tool = next(t for t in tools if t.name == "search_openai_docs")
        assert "query" in search_tool.input_schema.get("properties", {})
        assert "query" in search_tool.input_schema.get("required", [])

    @pytest.mark.asyncio
    async def test_call_tool_list_api_endpoints(self):
        """Call the list_api_endpoints tool and verify we get real data."""
        client = MCPClient(OPENAI_DOCS_MCP_URL)
        await client.initialize()
        result = await client.call_tool("list_api_endpoints", {})

        assert not result.is_error
        assert len(result.content) >= 1
        assert result.content[0]["type"] == "text"

        data = json.loads(result.content[0]["text"])
        assert "paths" in data or "urls" in data
        # The OpenAI API should have many endpoints
        total = data.get("total", len(data.get("paths", [])))
        assert total > 50

    @pytest.mark.asyncio
    async def test_call_tool_search(self):
        """Search for docs and verify we get results."""
        client = MCPClient(OPENAI_DOCS_MCP_URL)
        await client.initialize()
        result = await client.call_tool(
            "search_openai_docs", {"query": "chat completions", "limit": 3}
        )

        assert not result.is_error
        assert len(result.content) >= 1

    @pytest.mark.asyncio
    async def test_sse_response_handling(self):
        """Verify the client correctly handles SSE responses from a real server.

        This is the key test — our local test server returns JSON,
        but real MCP servers typically return SSE. This proves the
        SSE parsing works end-to-end.
        """
        client = MCPClient(OPENAI_DOCS_MCP_URL)
        # initialize() internally calls _send_request which must parse SSE
        result = await client.initialize()

        # If we got here without error, SSE parsing works
        assert isinstance(result, dict)
        assert "protocolVersion" in result

        # Also verify list_tools works (another SSE response)
        tools = await client.list_tools()
        assert len(tools) > 0
        assert all(hasattr(t, "name") for t in tools)
367
autogpt_platform/backend/backend/blocks/mcp/test_integration.py
Normal file
@@ -0,0 +1,367 @@
"""
Integration tests for MCP client and MCPToolBlock against a real HTTP server.

These tests spin up a local MCP test server and run the full client/block flow
against it — no mocking, real HTTP requests.
"""

import asyncio
import json
import threading

import pytest
from aiohttp import web
from pydantic import SecretStr

from backend.blocks.mcp.block import MCPToolBlock
from backend.blocks.mcp.client import MCPClient
from backend.blocks.mcp.test_server import create_test_mcp_app
from backend.data.model import APIKeyCredentials


class _MCPTestServer:
    """
    Run an MCP test server in a background thread with its own event loop.
    This avoids event loop conflicts with pytest-asyncio.
    """

    def __init__(self, auth_token: str | None = None):
        self.auth_token = auth_token
        self.url: str = ""
        self._runner: web.AppRunner | None = None
        self._loop: asyncio.AbstractEventLoop | None = None
        self._thread: threading.Thread | None = None
        self._started = threading.Event()

    def _run(self):
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)
        self._loop.run_until_complete(self._start())
        self._started.set()
        self._loop.run_forever()

    async def _start(self):
        app = create_test_mcp_app(auth_token=self.auth_token)
        self._runner = web.AppRunner(app)
        await self._runner.setup()
        site = web.TCPSite(self._runner, "127.0.0.1", 0)
        await site.start()
        port = site._server.sockets[0].getsockname()[1]  # type: ignore[union-attr]
        self.url = f"http://127.0.0.1:{port}/mcp"

    def start(self):
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        if not self._started.wait(timeout=5):
            raise RuntimeError("MCP test server failed to start within 5 seconds")
        return self

    def stop(self):
        if self._loop and self._runner:
            asyncio.run_coroutine_threadsafe(self._runner.cleanup(), self._loop).result(
                timeout=5
            )
            self._loop.call_soon_threadsafe(self._loop.stop)
        if self._thread:
            self._thread.join(timeout=5)


@pytest.fixture(scope="module")
def mcp_server():
    """Start a local MCP test server in a background thread."""
    server = _MCPTestServer()
    server.start()
    yield server.url
    server.stop()


@pytest.fixture(scope="module")
def mcp_server_with_auth():
    """Start a local MCP test server with auth in a background thread."""
    server = _MCPTestServer(auth_token="test-secret-token")
    server.start()
    yield server.url, "test-secret-token"
    server.stop()


def _make_client(url: str, auth_token: str | None = None) -> MCPClient:
    """Create an MCPClient with localhost trusted for integration tests."""
    return MCPClient(url, auth_token=auth_token, trusted_origins=[url])


def _make_fake_creds(api_key: str = "FAKE_API_KEY") -> APIKeyCredentials:
    return APIKeyCredentials(
        id="test-integration",
        provider="mcp",
        api_key=SecretStr(api_key),
        title="test",
    )


# ── MCPClient integration tests ──────────────────────────────────────


class TestMCPClientIntegration:
    """Test MCPClient against a real local MCP server."""

    @pytest.mark.asyncio
    async def test_initialize(self, mcp_server):
        client = _make_client(mcp_server)
        result = await client.initialize()

        assert result["protocolVersion"] == "2025-03-26"
        assert result["serverInfo"]["name"] == "test-mcp-server"
        assert "tools" in result["capabilities"]

    @pytest.mark.asyncio
    async def test_list_tools(self, mcp_server):
        client = _make_client(mcp_server)
        await client.initialize()
        tools = await client.list_tools()

        assert len(tools) == 3

        tool_names = {t.name for t in tools}
        assert tool_names == {"get_weather", "add_numbers", "echo"}

        # Check get_weather schema
        weather = next(t for t in tools if t.name == "get_weather")
        assert weather.description == "Get current weather for a city"
        assert "city" in weather.input_schema["properties"]
        assert weather.input_schema["required"] == ["city"]

        # Check add_numbers schema
        add = next(t for t in tools if t.name == "add_numbers")
        assert "a" in add.input_schema["properties"]
        assert "b" in add.input_schema["properties"]

    @pytest.mark.asyncio
    async def test_call_tool_get_weather(self, mcp_server):
        client = _make_client(mcp_server)
        await client.initialize()
        result = await client.call_tool("get_weather", {"city": "London"})

        assert not result.is_error
        assert len(result.content) == 1
        assert result.content[0]["type"] == "text"

        data = json.loads(result.content[0]["text"])
        assert data["city"] == "London"
        assert data["temperature"] == 22
        assert data["condition"] == "sunny"

    @pytest.mark.asyncio
    async def test_call_tool_add_numbers(self, mcp_server):
        client = _make_client(mcp_server)
        await client.initialize()
        result = await client.call_tool("add_numbers", {"a": 3, "b": 7})

        assert not result.is_error
        data = json.loads(result.content[0]["text"])
        assert data["result"] == 10

    @pytest.mark.asyncio
    async def test_call_tool_echo(self, mcp_server):
        client = _make_client(mcp_server)
        await client.initialize()
        result = await client.call_tool("echo", {"message": "Hello MCP!"})

        assert not result.is_error
        assert result.content[0]["text"] == "Hello MCP!"

    @pytest.mark.asyncio
    async def test_call_unknown_tool(self, mcp_server):
        client = _make_client(mcp_server)
        await client.initialize()
        result = await client.call_tool("nonexistent_tool", {})

        assert result.is_error
        assert "Unknown tool" in result.content[0]["text"]

    @pytest.mark.asyncio
    async def test_auth_success(self, mcp_server_with_auth):
        url, token = mcp_server_with_auth
        client = _make_client(url, auth_token=token)
        result = await client.initialize()

        assert result["protocolVersion"] == "2025-03-26"

        tools = await client.list_tools()
        assert len(tools) == 3

    @pytest.mark.asyncio
    async def test_auth_failure(self, mcp_server_with_auth):
        url, _ = mcp_server_with_auth
        client = _make_client(url, auth_token="wrong-token")

        with pytest.raises(Exception):
            await client.initialize()

    @pytest.mark.asyncio
    async def test_auth_missing(self, mcp_server_with_auth):
        url, _ = mcp_server_with_auth
        client = _make_client(url)

        with pytest.raises(Exception):
            await client.initialize()


# ── MCPToolBlock integration tests ───────────────────────────────────


class TestMCPToolBlockIntegration:
    """Test MCPToolBlock end-to-end against a real local MCP server."""

    @pytest.mark.asyncio
    async def test_full_flow_get_weather(self, mcp_server):
        """Full flow: discover tools, select one, execute it."""
        # Step 1: Discover tools (simulating what the frontend/API would do)
        client = _make_client(mcp_server)
        await client.initialize()
        tools = await client.list_tools()
        assert len(tools) == 3

        # Step 2: User selects "get_weather" and we get its schema
        weather_tool = next(t for t in tools if t.name == "get_weather")

        # Step 3: Execute the block with the selected tool
        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url=mcp_server,
            selected_tool="get_weather",
            tool_input_schema=weather_tool.input_schema,
            tool_arguments={"city": "Paris"},
            credentials={  # type: ignore
                "provider": "mcp",
                "id": "test",
                "type": "api_key",
                "title": "test",
            },
        )

        outputs = []
        async for name, data in block.run(input_data, credentials=_make_fake_creds()):
            outputs.append((name, data))

        assert len(outputs) == 1
        assert outputs[0][0] == "result"
        result = outputs[0][1]
        assert result["city"] == "Paris"
        assert result["temperature"] == 22
        assert result["condition"] == "sunny"

    @pytest.mark.asyncio
    async def test_full_flow_add_numbers(self, mcp_server):
        """Full flow for add_numbers tool."""
        client = _make_client(mcp_server)
        await client.initialize()
        tools = await client.list_tools()
        add_tool = next(t for t in tools if t.name == "add_numbers")

        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url=mcp_server,
            selected_tool="add_numbers",
            tool_input_schema=add_tool.input_schema,
            tool_arguments={"a": 42, "b": 58},
            credentials={  # type: ignore
                "provider": "mcp",
                "id": "test",
                "type": "api_key",
                "title": "test",
            },
        )

        outputs = []
        async for name, data in block.run(input_data, credentials=_make_fake_creds()):
            outputs.append((name, data))

        assert len(outputs) == 1
        assert outputs[0][0] == "result"
        assert outputs[0][1]["result"] == 100

    @pytest.mark.asyncio
    async def test_full_flow_echo_plain_text(self, mcp_server):
        """Verify plain text (non-JSON) responses work."""
        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url=mcp_server,
            selected_tool="echo",
            tool_input_schema={
                "type": "object",
                "properties": {"message": {"type": "string"}},
                "required": ["message"],
            },
            tool_arguments={"message": "Hello from AutoGPT!"},
            credentials={  # type: ignore
                "provider": "mcp",
                "id": "test",
                "type": "api_key",
                "title": "test",
            },
        )

        outputs = []
        async for name, data in block.run(input_data, credentials=_make_fake_creds()):
            outputs.append((name, data))

        assert len(outputs) == 1
        assert outputs[0][0] == "result"
        assert outputs[0][1] == "Hello from AutoGPT!"

    @pytest.mark.asyncio
    async def test_full_flow_unknown_tool_yields_error(self, mcp_server):
        """Calling an unknown tool should yield an error output."""
        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url=mcp_server,
            selected_tool="nonexistent_tool",
            tool_arguments={},
            credentials={  # type: ignore
                "provider": "mcp",
                "id": "test",
                "type": "api_key",
                "title": "test",
            },
        )

        outputs = []
        async for name, data in block.run(input_data, credentials=_make_fake_creds()):
            outputs.append((name, data))

        assert len(outputs) == 1
        assert outputs[0][0] == "error"
        assert "returned an error" in outputs[0][1]

    @pytest.mark.asyncio
    async def test_full_flow_with_auth(self, mcp_server_with_auth):
        """Full flow with authentication."""
        url, token = mcp_server_with_auth

        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url=url,
            selected_tool="echo",
            tool_input_schema={
                "type": "object",
                "properties": {"message": {"type": "string"}},
                "required": ["message"],
            },
            tool_arguments={"message": "Authenticated!"},
            credentials={  # type: ignore
                "provider": "mcp",
                "id": "test",
                "type": "api_key",
                "title": "test",
            },
        )

        outputs = []
        async for name, data in block.run(
            input_data, credentials=_make_fake_creds(api_key=token)
        ):
            outputs.append((name, data))

        assert len(outputs) == 1
        assert outputs[0][0] == "result"
        assert outputs[0][1] == "Authenticated!"
609
autogpt_platform/backend/backend/blocks/mcp/test_mcp.py
Normal file
@@ -0,0 +1,609 @@
"""
Tests for MCP client and MCPToolBlock.
"""

import json
from unittest.mock import AsyncMock, patch

import pytest

from backend.blocks.mcp.block import MCPToolBlock
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError
from backend.util.test import execute_block_test

# ── SSE parsing unit tests ───────────────────────────────────────────


class TestSSEParsing:
    """Tests for SSE (text/event-stream) response parsing."""

    def test_parse_sse_simple(self):
        sse = (
            "event: message\n"
            'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n'
            "\n"
        )
        body = MCPClient._parse_sse_response(sse)
        assert body["result"] == {"tools": []}
        assert body["id"] == 1

    def test_parse_sse_with_notifications(self):
        """SSE streams can contain notifications (no id) before the response."""
        sse = (
            "event: message\n"
            'data: {"jsonrpc":"2.0","method":"some/notification"}\n'
            "\n"
            "event: message\n"
            'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n'
            "\n"
        )
        body = MCPClient._parse_sse_response(sse)
        assert body["result"] == {"ok": True}
        assert body["id"] == 2

    def test_parse_sse_error_response(self):
        sse = (
            "event: message\n"
            'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n'
        )
        body = MCPClient._parse_sse_response(sse)
        assert "error" in body
        assert body["error"]["code"] == -32600

    def test_parse_sse_no_data_raises(self):
        with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
            MCPClient._parse_sse_response("event: message\n\n")

    def test_parse_sse_empty_raises(self):
        with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
            MCPClient._parse_sse_response("")

    def test_parse_sse_ignores_non_data_lines(self):
        sse = (
            ": comment line\n"
            "event: message\n"
            "id: 123\n"
            'data: {"jsonrpc":"2.0","result":"ok","id":1}\n'
            "\n"
        )
        body = MCPClient._parse_sse_response(sse)
        assert body["result"] == "ok"

    def test_parse_sse_uses_last_response(self):
        """If multiple responses exist, use the last one."""
        sse = (
            'data: {"jsonrpc":"2.0","result":"first","id":1}\n'
            "\n"
            'data: {"jsonrpc":"2.0","result":"second","id":2}\n'
            "\n"
        )
        body = MCPClient._parse_sse_response(sse)
        assert body["result"] == "second"


# ── MCPClient unit tests ─────────────────────────────────────────────


class TestMCPClient:
    """Tests for the MCP HTTP client."""

    def test_build_headers_without_auth(self):
        client = MCPClient("https://mcp.example.com")
        headers = client._build_headers()
        assert "Authorization" not in headers
        assert headers["Content-Type"] == "application/json"

    def test_build_headers_with_auth(self):
        client = MCPClient("https://mcp.example.com", auth_token="my-token")
        headers = client._build_headers()
        assert headers["Authorization"] == "Bearer my-token"

    def test_build_jsonrpc_request(self):
        client = MCPClient("https://mcp.example.com")
        req = client._build_jsonrpc_request("tools/list")
        assert req["jsonrpc"] == "2.0"
        assert req["method"] == "tools/list"
        assert "id" in req
        assert "params" not in req

    def test_build_jsonrpc_request_with_params(self):
        client = MCPClient("https://mcp.example.com")
        req = client._build_jsonrpc_request(
            "tools/call", {"name": "test", "arguments": {"x": 1}}
        )
        assert req["params"] == {"name": "test", "arguments": {"x": 1}}

    def test_request_id_increments(self):
        client = MCPClient("https://mcp.example.com")
        req1 = client._build_jsonrpc_request("tools/list")
        req2 = client._build_jsonrpc_request("tools/list")
        assert req2["id"] > req1["id"]

    def test_server_url_trailing_slash_stripped(self):
        client = MCPClient("https://mcp.example.com/mcp/")
        assert client.server_url == "https://mcp.example.com/mcp"

    @pytest.mark.asyncio
    async def test_send_request_success(self):
        client = MCPClient("https://mcp.example.com")

        mock_response = AsyncMock()
        mock_response.json.return_value = {
            "jsonrpc": "2.0",
            "result": {"tools": []},
            "id": 1,
        }

        with patch.object(client, "_send_request", return_value={"tools": []}):
            result = await client._send_request("tools/list")
            assert result == {"tools": []}

    @pytest.mark.asyncio
    async def test_send_request_error(self):
        client = MCPClient("https://mcp.example.com")

        async def mock_send(*args, **kwargs):
            raise MCPClientError("MCP server error [-32600]: Invalid Request")

        with patch.object(client, "_send_request", side_effect=mock_send):
            with pytest.raises(MCPClientError, match="Invalid Request"):
                await client._send_request("tools/list")

    @pytest.mark.asyncio
    async def test_list_tools(self):
        client = MCPClient("https://mcp.example.com")

        mock_result = {
            "tools": [
                {
                    "name": "get_weather",
                    "description": "Get current weather for a city",
                    "inputSchema": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                        "required": ["city"],
                    },
                },
                {
                    "name": "search",
                    "description": "Search the web",
                    "inputSchema": {
                        "type": "object",
                        "properties": {"query": {"type": "string"}},
                        "required": ["query"],
                    },
                },
            ]
        }

        with patch.object(client, "_send_request", return_value=mock_result):
            tools = await client.list_tools()

        assert len(tools) == 2
        assert tools[0].name == "get_weather"
        assert tools[0].description == "Get current weather for a city"
        assert tools[0].input_schema["properties"]["city"]["type"] == "string"
        assert tools[1].name == "search"

    @pytest.mark.asyncio
    async def test_list_tools_empty(self):
        client = MCPClient("https://mcp.example.com")

        with patch.object(client, "_send_request", return_value={"tools": []}):
            tools = await client.list_tools()

        assert tools == []

    @pytest.mark.asyncio
    async def test_list_tools_none_result(self):
        client = MCPClient("https://mcp.example.com")

        with patch.object(client, "_send_request", return_value=None):
            tools = await client.list_tools()

        assert tools == []

    @pytest.mark.asyncio
    async def test_call_tool_success(self):
        client = MCPClient("https://mcp.example.com")

        mock_result = {
            "content": [
                {"type": "text", "text": json.dumps({"temp": 20, "city": "London"})}
            ],
            "isError": False,
        }

        with patch.object(client, "_send_request", return_value=mock_result):
            result = await client.call_tool("get_weather", {"city": "London"})

        assert not result.is_error
        assert len(result.content) == 1
        assert result.content[0]["type"] == "text"

    @pytest.mark.asyncio
    async def test_call_tool_error(self):
        client = MCPClient("https://mcp.example.com")

        mock_result = {
            "content": [{"type": "text", "text": "City not found"}],
            "isError": True,
        }

        with patch.object(client, "_send_request", return_value=mock_result):
            result = await client.call_tool("get_weather", {"city": "???"})

        assert result.is_error

    @pytest.mark.asyncio
    async def test_call_tool_none_result(self):
        client = MCPClient("https://mcp.example.com")

        with patch.object(client, "_send_request", return_value=None):
            result = await client.call_tool("get_weather", {"city": "London"})

        assert result.is_error

    @pytest.mark.asyncio
    async def test_initialize(self):
        client = MCPClient("https://mcp.example.com")

        mock_result = {
            "protocolVersion": "2025-03-26",
            "capabilities": {"tools": {}},
            "serverInfo": {"name": "test-server", "version": "1.0.0"},
        }

        with (
            patch.object(client, "_send_request", return_value=mock_result) as mock_req,
            patch.object(client, "_send_notification") as mock_notif,
        ):
            result = await client.initialize()

        mock_req.assert_called_once()
        mock_notif.assert_called_once_with("notifications/initialized")
        assert result["protocolVersion"] == "2025-03-26"


# ── MCPToolBlock unit tests ──────────────────────────────────────────

MOCK_USER_ID = "test-user-123"


class TestMCPToolBlock:
    """Tests for the MCPToolBlock."""

    def test_block_instantiation(self):
        block = MCPToolBlock()
        assert block.id == "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
        assert block.name == "MCPToolBlock"

    def test_input_schema_has_required_fields(self):
        block = MCPToolBlock()
        schema = block.input_schema.jsonschema()
        props = schema.get("properties", {})
        assert "server_url" in props
        assert "selected_tool" in props
        assert "tool_arguments" in props
        assert "credential_id" in props

    def test_output_schema(self):
        block = MCPToolBlock()
        schema = block.output_schema.jsonschema()
        props = schema.get("properties", {})
        assert "result" in props
        assert "error" in props

    def test_get_input_schema_with_tool_schema(self):
        tool_schema = {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        }
        data = {"tool_input_schema": tool_schema}
        result = MCPToolBlock.Input.get_input_schema(data)
        assert result == tool_schema

    def test_get_input_schema_without_tool_schema(self):
        result = MCPToolBlock.Input.get_input_schema({})
        assert result == {}

    def test_get_input_defaults(self):
        data = {"tool_arguments": {"city": "London"}}
        result = MCPToolBlock.Input.get_input_defaults(data)
        assert result == {"city": "London"}

    def test_get_missing_input(self):
        data = {
            "tool_input_schema": {
                "type": "object",
                "properties": {
                    "city": {"type": "string"},
                    "units": {"type": "string"},
                },
                "required": ["city", "units"],
            },
            "tool_arguments": {"city": "London"},
        }
        missing = MCPToolBlock.Input.get_missing_input(data)
        assert missing == {"units"}

    def test_get_missing_input_all_present(self):
        data = {
            "tool_input_schema": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
            "tool_arguments": {"city": "London"},
        }
        missing = MCPToolBlock.Input.get_missing_input(data)
        assert missing == set()

    @pytest.mark.asyncio
    async def test_run_with_mock(self):
        """Test the block using the built-in test infrastructure."""
        block = MCPToolBlock()
        await execute_block_test(block)

    @pytest.mark.asyncio
    async def test_run_missing_server_url(self):
        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url="",
            selected_tool="test",
        )
        outputs = []
        async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
            outputs.append((name, data))
        assert outputs == [("error", "MCP server URL is required")]

    @pytest.mark.asyncio
    async def test_run_missing_tool(self):
        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url="https://mcp.example.com/mcp",
            selected_tool="",
        )
        outputs = []
        async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
            outputs.append((name, data))
        assert outputs == [
            ("error", "No tool selected. Please select a tool from the dropdown.")
        ]

    @pytest.mark.asyncio
    async def test_run_success(self):
        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
|
server_url="https://mcp.example.com/mcp",
|
||||||
|
selected_tool="get_weather",
|
||||||
|
tool_input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"city": {"type": "string"}},
|
||||||
|
},
|
||||||
|
tool_arguments={"city": "London"},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_call(*args, **kwargs):
|
||||||
|
return {"temp": 20, "city": "London"}
|
||||||
|
|
||||||
|
block._call_mcp_tool = mock_call # type: ignore
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||||
|
outputs.append((name, data))
|
||||||
|
|
||||||
|
assert len(outputs) == 1
|
||||||
|
assert outputs[0][0] == "result"
|
||||||
|
assert outputs[0][1] == {"temp": 20, "city": "London"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_mcp_error(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
input_data = MCPToolBlock.Input(
|
||||||
|
server_url="https://mcp.example.com/mcp",
|
||||||
|
selected_tool="bad_tool",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_call(*args, **kwargs):
|
||||||
|
raise MCPClientError("Tool not found")
|
||||||
|
|
||||||
|
block._call_mcp_tool = mock_call # type: ignore
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||||
|
outputs.append((name, data))
|
||||||
|
|
||||||
|
assert outputs[0][0] == "error"
|
||||||
|
assert "Tool not found" in outputs[0][1]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_mcp_tool_parses_json_text(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
|
||||||
|
mock_result = MCPCallResult(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": '{"temp": 20}'},
|
||||||
|
],
|
||||||
|
is_error=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_init(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def mock_call(self, name, args):
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MCPClient, "initialize", mock_init),
|
||||||
|
patch.object(MCPClient, "call_tool", mock_call),
|
||||||
|
):
|
||||||
|
result = await block._call_mcp_tool(
|
||||||
|
"https://mcp.example.com", "test_tool", {}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {"temp": 20}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_mcp_tool_plain_text(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
|
||||||
|
mock_result = MCPCallResult(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "Hello, world!"},
|
||||||
|
],
|
||||||
|
is_error=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_init(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def mock_call(self, name, args):
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MCPClient, "initialize", mock_init),
|
||||||
|
patch.object(MCPClient, "call_tool", mock_call),
|
||||||
|
):
|
||||||
|
result = await block._call_mcp_tool(
|
||||||
|
"https://mcp.example.com", "test_tool", {}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == "Hello, world!"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_mcp_tool_multiple_content(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
|
||||||
|
mock_result = MCPCallResult(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "Part 1"},
|
||||||
|
{"type": "text", "text": '{"part": 2}'},
|
||||||
|
],
|
||||||
|
is_error=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_init(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def mock_call(self, name, args):
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MCPClient, "initialize", mock_init),
|
||||||
|
patch.object(MCPClient, "call_tool", mock_call),
|
||||||
|
):
|
||||||
|
result = await block._call_mcp_tool(
|
||||||
|
"https://mcp.example.com", "test_tool", {}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == ["Part 1", {"part": 2}]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_mcp_tool_error_result(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
|
||||||
|
mock_result = MCPCallResult(
|
||||||
|
content=[{"type": "text", "text": "Something went wrong"}],
|
||||||
|
is_error=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_init(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def mock_call(self, name, args):
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MCPClient, "initialize", mock_init),
|
||||||
|
patch.object(MCPClient, "call_tool", mock_call),
|
||||||
|
):
|
||||||
|
with pytest.raises(MCPClientError, match="returned an error"):
|
||||||
|
await block._call_mcp_tool("https://mcp.example.com", "test_tool", {})
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_mcp_tool_image_content(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
|
||||||
|
mock_result = MCPCallResult(
|
||||||
|
content=[
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"data": "base64data==",
|
||||||
|
"mimeType": "image/png",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
is_error=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_init(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def mock_call(self, name, args):
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MCPClient, "initialize", mock_init),
|
||||||
|
patch.object(MCPClient, "call_tool", mock_call),
|
||||||
|
):
|
||||||
|
result = await block._call_mcp_tool(
|
||||||
|
"https://mcp.example.com", "test_tool", {}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {
|
||||||
|
"type": "image",
|
||||||
|
"data": "base64data==",
|
||||||
|
"mimeType": "image/png",
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_with_credential_id(self):
|
||||||
|
"""Verify the block resolves credential_id and passes auth token."""
|
||||||
|
block = MCPToolBlock()
|
||||||
|
input_data = MCPToolBlock.Input(
|
||||||
|
server_url="https://mcp.example.com/mcp",
|
||||||
|
selected_tool="test_tool",
|
||||||
|
credential_id="cred-123",
|
||||||
|
)
|
||||||
|
|
||||||
|
captured_tokens = []
|
||||||
|
|
||||||
|
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
||||||
|
captured_tokens.append(auth_token)
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
async def mock_resolve(self, cred_id, uid):
|
||||||
|
return "resolved-token"
|
||||||
|
|
||||||
|
block._call_mcp_tool = mock_call # type: ignore
|
||||||
|
|
||||||
|
with patch.object(MCPToolBlock, "_resolve_auth_token", mock_resolve):
|
||||||
|
async for _ in block.run(input_data, user_id=MOCK_USER_ID):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert captured_tokens == ["resolved-token"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_without_credential_id(self):
|
||||||
|
"""Verify the block works without credentials (public server)."""
|
||||||
|
block = MCPToolBlock()
|
||||||
|
input_data = MCPToolBlock.Input(
|
||||||
|
server_url="https://mcp.example.com/mcp",
|
||||||
|
selected_tool="test_tool",
|
||||||
|
)
|
||||||
|
|
||||||
|
captured_tokens = []
|
||||||
|
|
||||||
|
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
||||||
|
captured_tokens.append(auth_token)
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
block._call_mcp_tool = mock_call # type: ignore
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||||
|
outputs.append((name, data))
|
||||||
|
|
||||||
|
assert captured_tokens == [None]
|
||||||
|
assert outputs == [("result", "ok")]
|
||||||
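Context for the mocked `_send_request` calls above: on the wire they stand in for JSON-RPC 2.0 messages sent over HTTP POST. A minimal sketch of the `tools/call` round-trip, inferred from the minimal test server later in this diff rather than from MCPClient internals, so treat the exact shapes as illustrative:

import json

# Illustrative request the client would send for call_tool("get_weather", ...):
request = {
    "jsonrpc": "2.0",
    "id": 1,
    "method": "tools/call",
    "params": {"name": "get_weather", "arguments": {"city": "London"}},
}
# Illustrative server reply; the tests model the "result" part as MCPCallResult:
response = {
    "jsonrpc": "2.0",
    "id": 1,
    "result": {
        "content": [{"type": "text", "text": json.dumps({"temp": 20})}],
        "isError": False,
    },
}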
242
autogpt_platform/backend/backend/blocks/mcp/test_oauth.py
Normal file
@@ -0,0 +1,242 @@
"""
Tests for MCP OAuth handler.
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from pydantic import SecretStr

from backend.blocks.mcp.client import MCPClient
from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials


def _mock_response(json_data: dict, status: int = 200) -> MagicMock:
    """Create a mock Response with synchronous json() (matching Requests.Response)."""
    resp = MagicMock()
    resp.status = status
    resp.ok = 200 <= status < 300
    resp.json.return_value = json_data
    return resp


class TestMCPOAuthHandler:
    """Tests for the MCPOAuthHandler."""

    def _make_handler(self, **overrides) -> MCPOAuthHandler:
        defaults = {
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "redirect_uri": "https://app.example.com/callback",
            "authorize_url": "https://auth.example.com/authorize",
            "token_url": "https://auth.example.com/token",
        }
        defaults.update(overrides)
        return MCPOAuthHandler(**defaults)

    def test_get_login_url_basic(self):
        handler = self._make_handler()
        url = handler.get_login_url(
            scopes=["read", "write"],
            state="random-state-token",
            code_challenge="S256-challenge-value",
        )

        assert "https://auth.example.com/authorize?" in url
        assert "response_type=code" in url
        assert "client_id=test-client-id" in url
        assert "state=random-state-token" in url
        assert "code_challenge=S256-challenge-value" in url
        assert "code_challenge_method=S256" in url
        assert "scope=read+write" in url

    def test_get_login_url_with_resource(self):
        handler = self._make_handler(resource_url="https://mcp.example.com/mcp")
        url = handler.get_login_url(
            scopes=[], state="state", code_challenge="challenge"
        )

        assert "resource=https" in url

    def test_get_login_url_without_pkce(self):
        handler = self._make_handler()
        url = handler.get_login_url(scopes=["read"], state="state", code_challenge=None)

        assert "code_challenge" not in url
        assert "code_challenge_method" not in url

    @pytest.mark.asyncio
    async def test_exchange_code_for_tokens(self):
        handler = self._make_handler()

        resp = _mock_response(
            {
                "access_token": "new-access-token",
                "refresh_token": "new-refresh-token",
                "expires_in": 3600,
                "token_type": "Bearer",
            }
        )

        with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
            instance = MockRequests.return_value
            instance.post = AsyncMock(return_value=resp)

            creds = await handler.exchange_code_for_tokens(
                code="auth-code",
                scopes=["read"],
                code_verifier="pkce-verifier",
            )

        assert isinstance(creds, OAuth2Credentials)
        assert creds.access_token.get_secret_value() == "new-access-token"
        assert creds.refresh_token is not None
        assert creds.refresh_token.get_secret_value() == "new-refresh-token"
        assert creds.scopes == ["read"]
        assert creds.access_token_expires_at is not None

    @pytest.mark.asyncio
    async def test_refresh_tokens(self):
        handler = self._make_handler()

        existing_creds = OAuth2Credentials(
            id="existing-id",
            provider="mcp",
            access_token=SecretStr("old-token"),
            refresh_token=SecretStr("old-refresh"),
            scopes=["read"],
            title="test",
        )

        resp = _mock_response(
            {
                "access_token": "refreshed-token",
                "refresh_token": "new-refresh",
                "expires_in": 3600,
            }
        )

        with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
            instance = MockRequests.return_value
            instance.post = AsyncMock(return_value=resp)

            refreshed = await handler._refresh_tokens(existing_creds)

        assert refreshed.id == "existing-id"
        assert refreshed.access_token.get_secret_value() == "refreshed-token"
        assert refreshed.refresh_token is not None
        assert refreshed.refresh_token.get_secret_value() == "new-refresh"

    @pytest.mark.asyncio
    async def test_refresh_tokens_no_refresh_token(self):
        handler = self._make_handler()

        creds = OAuth2Credentials(
            provider="mcp",
            access_token=SecretStr("token"),
            scopes=["read"],
            title="test",
        )

        with pytest.raises(ValueError, match="No refresh token"):
            await handler._refresh_tokens(creds)

    @pytest.mark.asyncio
    async def test_revoke_tokens_no_url(self):
        handler = self._make_handler(revoke_url=None)

        creds = OAuth2Credentials(
            provider="mcp",
            access_token=SecretStr("token"),
            scopes=[],
            title="test",
        )

        result = await handler.revoke_tokens(creds)
        assert result is False

    @pytest.mark.asyncio
    async def test_revoke_tokens_with_url(self):
        handler = self._make_handler(revoke_url="https://auth.example.com/revoke")

        creds = OAuth2Credentials(
            provider="mcp",
            access_token=SecretStr("token"),
            scopes=[],
            title="test",
        )

        resp = _mock_response({}, status=200)

        with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
            instance = MockRequests.return_value
            instance.post = AsyncMock(return_value=resp)

            result = await handler.revoke_tokens(creds)

        assert result is True


class TestMCPClientDiscovery:
    """Tests for MCPClient OAuth metadata discovery."""

    @pytest.mark.asyncio
    async def test_discover_auth_found(self):
        client = MCPClient("https://mcp.example.com/mcp")

        metadata = {
            "authorization_servers": ["https://auth.example.com"],
            "resource": "https://mcp.example.com/mcp",
        }

        resp = _mock_response(metadata, status=200)

        with patch("backend.blocks.mcp.client.Requests") as MockRequests:
            instance = MockRequests.return_value
            instance.get = AsyncMock(return_value=resp)

            result = await client.discover_auth()

        assert result is not None
        assert result["authorization_servers"] == ["https://auth.example.com"]

    @pytest.mark.asyncio
    async def test_discover_auth_not_found(self):
        client = MCPClient("https://mcp.example.com/mcp")

        resp = _mock_response({}, status=404)

        with patch("backend.blocks.mcp.client.Requests") as MockRequests:
            instance = MockRequests.return_value
            instance.get = AsyncMock(return_value=resp)

            result = await client.discover_auth()

        assert result is None

    @pytest.mark.asyncio
    async def test_discover_auth_server_metadata(self):
        client = MCPClient("https://mcp.example.com/mcp")

        server_metadata = {
            "issuer": "https://auth.example.com",
            "authorization_endpoint": "https://auth.example.com/authorize",
            "token_endpoint": "https://auth.example.com/token",
            "registration_endpoint": "https://auth.example.com/register",
            "code_challenge_methods_supported": ["S256"],
        }

        resp = _mock_response(server_metadata, status=200)

        with patch("backend.blocks.mcp.client.Requests") as MockRequests:
            instance = MockRequests.return_value
            instance.get = AsyncMock(return_value=resp)

            result = await client.discover_auth_server_metadata(
                "https://auth.example.com"
            )

        assert result is not None
        assert result["authorization_endpoint"] == "https://auth.example.com/authorize"
        assert result["token_endpoint"] == "https://auth.example.com/token"
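The S256 assertions above follow RFC 7636 (PKCE). A minimal sketch of how a caller could derive the `code_verifier`/`code_challenge` pair the handler consumes; this generation code is illustrative and not part of MCPOAuthHandler:

import base64
import hashlib
import secrets

# PKCE S256: challenge = base64url(sha256(verifier)) with padding stripped (RFC 7636)
code_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).rstrip(b"=").decode()
code_challenge = (
    base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
    .rstrip(b"=")
    .decode()
)
# code_challenge feeds get_login_url(...); code_verifier feeds exchange_code_for_tokens(...)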
162
autogpt_platform/backend/backend/blocks/mcp/test_server.py
Normal file
@@ -0,0 +1,162 @@
"""
Minimal MCP server for integration testing.

Implements the MCP Streamable HTTP transport (JSON-RPC 2.0 over HTTP POST)
with a few sample tools. Runs on localhost with a random available port.
"""

import json
import logging

from aiohttp import web

logger = logging.getLogger(__name__)

# Sample tools this test server exposes
TEST_TOOLS = [
    {
        "name": "get_weather",
        "description": "Get current weather for a city",
        "inputSchema": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "City name",
                },
            },
            "required": ["city"],
        },
    },
    {
        "name": "add_numbers",
        "description": "Add two numbers together",
        "inputSchema": {
            "type": "object",
            "properties": {
                "a": {"type": "number", "description": "First number"},
                "b": {"type": "number", "description": "Second number"},
            },
            "required": ["a", "b"],
        },
    },
    {
        "name": "echo",
        "description": "Echo back the input message",
        "inputSchema": {
            "type": "object",
            "properties": {
                "message": {"type": "string", "description": "Message to echo"},
            },
            "required": ["message"],
        },
    },
]


def _handle_initialize(params: dict) -> dict:
    return {
        "protocolVersion": "2025-03-26",
        "capabilities": {"tools": {"listChanged": False}},
        "serverInfo": {"name": "test-mcp-server", "version": "1.0.0"},
    }


def _handle_tools_list(params: dict) -> dict:
    return {"tools": TEST_TOOLS}


def _handle_tools_call(params: dict) -> dict:
    tool_name = params.get("name", "")
    arguments = params.get("arguments", {})

    if tool_name == "get_weather":
        city = arguments.get("city", "Unknown")
        return {
            "content": [
                {
                    "type": "text",
                    "text": json.dumps(
                        {"city": city, "temperature": 22, "condition": "sunny"}
                    ),
                }
            ],
        }

    elif tool_name == "add_numbers":
        a = arguments.get("a", 0)
        b = arguments.get("b", 0)
        return {
            "content": [{"type": "text", "text": json.dumps({"result": a + b})}],
        }

    elif tool_name == "echo":
        message = arguments.get("message", "")
        return {
            "content": [{"type": "text", "text": message}],
        }

    else:
        return {
            "content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}],
            "isError": True,
        }


HANDLERS = {
    "initialize": _handle_initialize,
    "tools/list": _handle_tools_list,
    "tools/call": _handle_tools_call,
}


async def handle_mcp_request(request: web.Request) -> web.Response:
    """Handle incoming MCP JSON-RPC 2.0 requests."""
    # Check auth if configured
    expected_token = request.app.get("auth_token")
    if expected_token:
        auth_header = request.headers.get("Authorization", "")
        if auth_header != f"Bearer {expected_token}":
            return web.json_response(
                {
                    "jsonrpc": "2.0",
                    "error": {"code": -32001, "message": "Unauthorized"},
                    "id": None,
                },
                status=401,
            )

    body = await request.json()

    # Handle notifications (no id field) — just acknowledge
    if "id" not in body:
        return web.Response(status=202)

    method = body.get("method", "")
    params = body.get("params", {})
    request_id = body.get("id")

    handler = HANDLERS.get(method)
    if not handler:
        return web.json_response(
            {
                "jsonrpc": "2.0",
                "error": {
                    "code": -32601,
                    "message": f"Method not found: {method}",
                },
                "id": request_id,
            }
        )

    result = handler(params)
    return web.json_response({"jsonrpc": "2.0", "result": result, "id": request_id})


def create_test_mcp_app(auth_token: str | None = None) -> web.Application:
    """Create an aiohttp app that acts as an MCP server."""
    app = web.Application()
    app.router.add_post("/mcp", handle_mcp_request)
    if auth_token:
        app["auth_token"] = auth_token
    return app
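A minimal sketch of serving `create_test_mcp_app` on an ephemeral localhost port with aiohttp's runner API, e.g. from a test fixture. The helper name and the port lookup are illustrative; `TCPSite` with port 0 lets the OS choose a free port, which matches the "random available port" behavior described in the module docstring:

async def run_test_server():
    app = create_test_mcp_app(auth_token="secret-token")
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "127.0.0.1", 0)  # port 0 -> OS picks a free port
    await site.start()
    host, port = runner.addresses[0][:2]  # bound address of the started site
    return runner, f"http://{host}:{port}/mcp"
    # caller should `await runner.cleanup()` when done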
@@ -1,246 +0,0 @@
import os
import tempfile
from typing import Optional

from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.fx.Loop import Loop
from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class MediaDurationBlock(Block):

    class Input(BlockSchemaInput):
        media_in: MediaFileType = SchemaField(
            description="Media input (URL, data URI, or local path)."
        )
        is_video: bool = SchemaField(
            description="Whether the media is a video (True) or audio (False).",
            default=True,
        )

    class Output(BlockSchemaOutput):
        duration: float = SchemaField(
            description="Duration of the media file (in seconds)."
        )

    def __init__(self):
        super().__init__(
            id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
            description="Block to get the duration of a media file.",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=MediaDurationBlock.Input,
            output_schema=MediaDurationBlock.Output,
        )

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        **kwargs,
    ) -> BlockOutput:
        # 1) Store the input media locally
        local_media_path = await store_media_file(
            file=input_data.media_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )
        assert execution_context.graph_exec_id is not None
        media_abspath = get_exec_file_path(
            execution_context.graph_exec_id, local_media_path
        )

        # 2) Load the clip
        if input_data.is_video:
            clip = VideoFileClip(media_abspath)
        else:
            clip = AudioFileClip(media_abspath)

        yield "duration", clip.duration


class LoopVideoBlock(Block):
    """
    Block for looping (repeating) a video clip until a given duration or number of loops.
    """

    class Input(BlockSchemaInput):
        video_in: MediaFileType = SchemaField(
            description="The input video (can be a URL, data URI, or local path)."
        )
        # Provide EITHER a `duration` or `n_loops` or both. We'll demonstrate `duration`.
        duration: Optional[float] = SchemaField(
            description="Target duration (in seconds) to loop the video to. If omitted, defaults to no looping.",
            default=None,
            ge=0.0,
        )
        n_loops: Optional[int] = SchemaField(
            description="Number of times to repeat the video. If omitted, defaults to 1 (no repeat).",
            default=None,
            ge=1,
        )

    class Output(BlockSchemaOutput):
        video_out: str = SchemaField(
            description="Looped video returned either as a relative path or a data URI."
        )

    def __init__(self):
        super().__init__(
            id="8bf9eef6-5451-4213-b265-25306446e94b",
            description="Block to loop a video to a given duration or number of repeats.",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=LoopVideoBlock.Input,
            output_schema=LoopVideoBlock.Output,
        )

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        **kwargs,
    ) -> BlockOutput:
        assert execution_context.graph_exec_id is not None
        assert execution_context.node_exec_id is not None
        graph_exec_id = execution_context.graph_exec_id
        node_exec_id = execution_context.node_exec_id

        # 1) Store the input video locally
        local_video_path = await store_media_file(
            file=input_data.video_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )
        input_abspath = get_exec_file_path(graph_exec_id, local_video_path)

        # 2) Load the clip
        clip = VideoFileClip(input_abspath)

        # 3) Apply the loop effect
        looped_clip = clip
        if input_data.duration:
            # Loop until we reach the specified duration
            looped_clip = looped_clip.with_effects([Loop(duration=input_data.duration)])
        elif input_data.n_loops:
            looped_clip = looped_clip.with_effects([Loop(n=input_data.n_loops)])
        else:
            raise ValueError("Either 'duration' or 'n_loops' must be provided.")

        assert isinstance(looped_clip, VideoFileClip)

        # 4) Save the looped output
        output_filename = MediaFileType(
            f"{node_exec_id}_looped_{os.path.basename(local_video_path)}"
        )
        output_abspath = get_exec_file_path(graph_exec_id, output_filename)

        looped_clip = looped_clip.with_audio(clip.audio)
        looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")

        # Return output - for_block_output returns workspace:// if available, else data URI
        video_out = await store_media_file(
            file=output_filename,
            execution_context=execution_context,
            return_format="for_block_output",
        )

        yield "video_out", video_out


class AddAudioToVideoBlock(Block):
    """
    Block that adds (attaches) an audio track to an existing video.
    Optionally scale the volume of the new track.
    """

    class Input(BlockSchemaInput):
        video_in: MediaFileType = SchemaField(
            description="Video input (URL, data URI, or local path)."
        )
        audio_in: MediaFileType = SchemaField(
            description="Audio input (URL, data URI, or local path)."
        )
        volume: float = SchemaField(
            description="Volume scale for the newly attached audio track (1.0 = original).",
            default=1.0,
        )

    class Output(BlockSchemaOutput):
        video_out: MediaFileType = SchemaField(
            description="Final video (with attached audio), as a path or data URI."
        )

    def __init__(self):
        super().__init__(
            id="3503748d-62b6-4425-91d6-725b064af509",
            description="Block to attach an audio file to a video file using moviepy.",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=AddAudioToVideoBlock.Input,
            output_schema=AddAudioToVideoBlock.Output,
        )

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        **kwargs,
    ) -> BlockOutput:
        assert execution_context.graph_exec_id is not None
        assert execution_context.node_exec_id is not None
        graph_exec_id = execution_context.graph_exec_id
        node_exec_id = execution_context.node_exec_id

        # 1) Store the inputs locally
        local_video_path = await store_media_file(
            file=input_data.video_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )
        local_audio_path = await store_media_file(
            file=input_data.audio_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )

        abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
        video_abspath = os.path.join(abs_temp_dir, local_video_path)
        audio_abspath = os.path.join(abs_temp_dir, local_audio_path)

        # 2) Load video + audio with moviepy
        video_clip = VideoFileClip(video_abspath)
        audio_clip = AudioFileClip(audio_abspath)
        # Optionally scale volume
        if input_data.volume != 1.0:
            audio_clip = audio_clip.with_volume_scaled(input_data.volume)

        # 3) Attach the new audio track
        final_clip = video_clip.with_audio(audio_clip)

        # 4) Write to output file
        output_filename = MediaFileType(
            f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}"
        )
        output_abspath = os.path.join(abs_temp_dir, output_filename)
        final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")

        # 5) Return output - for_block_output returns workspace:// if available, else data URI
        video_out = await store_media_file(
            file=output_filename,
            execution_context=execution_context,
            return_format="for_block_output",
        )

        yield "video_out", video_out
@@ -0,0 +1,77 @@
import pytest

from backend.blocks.encoder_block import TextEncoderBlock


@pytest.mark.asyncio
async def test_text_encoder_basic():
    """Test basic encoding of newlines and special characters."""
    block = TextEncoderBlock()
    result = []
    async for output in block.run(TextEncoderBlock.Input(text="Hello\nWorld")):
        result.append(output)

    assert len(result) == 1
    assert result[0][0] == "encoded_text"
    assert result[0][1] == "Hello\\nWorld"


@pytest.mark.asyncio
async def test_text_encoder_multiple_escapes():
    """Test encoding of multiple escape sequences."""
    block = TextEncoderBlock()
    result = []
    async for output in block.run(
        TextEncoderBlock.Input(text="Line1\nLine2\tTabbed\rCarriage")
    ):
        result.append(output)

    assert len(result) == 1
    assert result[0][0] == "encoded_text"
    assert "\\n" in result[0][1]
    assert "\\t" in result[0][1]
    assert "\\r" in result[0][1]


@pytest.mark.asyncio
async def test_text_encoder_unicode():
    """Test that unicode characters are handled correctly."""
    block = TextEncoderBlock()
    result = []
    async for output in block.run(TextEncoderBlock.Input(text="Hello 世界\n")):
        result.append(output)

    assert len(result) == 1
    assert result[0][0] == "encoded_text"
    # Unicode characters should be escaped as \uXXXX sequences
    assert "\\n" in result[0][1]


@pytest.mark.asyncio
async def test_text_encoder_empty_string():
    """Test encoding of an empty string."""
    block = TextEncoderBlock()
    result = []
    async for output in block.run(TextEncoderBlock.Input(text="")):
        result.append(output)

    assert len(result) == 1
    assert result[0][0] == "encoded_text"
    assert result[0][1] == ""


@pytest.mark.asyncio
async def test_text_encoder_error_handling():
    """Test that encoding errors are handled gracefully."""
    from unittest.mock import patch

    block = TextEncoderBlock()
    result = []

    with patch("codecs.encode", side_effect=Exception("Mocked encoding error")):
        async for output in block.run(TextEncoderBlock.Input(text="test")):
            result.append(output)

    assert len(result) == 1
    assert result[0][0] == "error"
    assert "Mocked encoding error" in result[0][1]
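The error-handling test patches `codecs.encode`, which suggests the block delegates to Python's `unicode_escape` codec. Under that assumption (inferred from the patch target, not confirmed by the block source shown here), the expected outputs can be reproduced directly:

import codecs

# Assumed core of TextEncoderBlock, inferred from the codecs.encode patch above:
encoded = codecs.encode("Hello 世界\n", "unicode_escape").decode("ascii")
print(encoded)  # Hello \u4e16\u754c\n  (newline escaped, CJK as \uXXXX)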
37
autogpt_platform/backend/backend/blocks/video/__init__.py
Normal file
@@ -0,0 +1,37 @@
"""Video editing blocks for AutoGPT Platform.

This module provides blocks for:
- Downloading videos from URLs (YouTube, Vimeo, news sites, direct links)
- Clipping/trimming video segments
- Concatenating multiple videos
- Adding text overlays
- Adding AI-generated narration
- Getting media duration
- Looping videos
- Adding audio to videos

Dependencies:
- yt-dlp: For video downloading
- moviepy: For video editing operations
- elevenlabs: For AI narration (optional)
"""

from backend.blocks.video.add_audio import AddAudioToVideoBlock
from backend.blocks.video.clip import VideoClipBlock
from backend.blocks.video.concat import VideoConcatBlock
from backend.blocks.video.download import VideoDownloadBlock
from backend.blocks.video.duration import MediaDurationBlock
from backend.blocks.video.loop import LoopVideoBlock
from backend.blocks.video.narration import VideoNarrationBlock
from backend.blocks.video.text_overlay import VideoTextOverlayBlock

__all__ = [
    "AddAudioToVideoBlock",
    "LoopVideoBlock",
    "MediaDurationBlock",
    "VideoClipBlock",
    "VideoConcatBlock",
    "VideoDownloadBlock",
    "VideoNarrationBlock",
    "VideoTextOverlayBlock",
]
131
autogpt_platform/backend/backend/blocks/video/_utils.py
Normal file
@@ -0,0 +1,131 @@
"""Shared utilities for video blocks."""

from __future__ import annotations

import logging
import os
import re
import subprocess
from pathlib import Path

logger = logging.getLogger(__name__)

# Known operation tags added by video blocks
_VIDEO_OPS = (
    r"(?:clip|overlay|narrated|looped|concat|audio_attached|with_audio|narration)"
)

# Matches: {node_exec_id}_{operation}_ where node_exec_id contains a UUID
_BLOCK_PREFIX_RE = re.compile(
    r"^[a-zA-Z0-9_-]*"
    r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
    r"[a-zA-Z0-9_-]*"
    r"_" + _VIDEO_OPS + r"_"
)

# Matches: a lone {node_exec_id}_ prefix (no operation keyword, e.g. download output)
_UUID_PREFIX_RE = re.compile(
    r"^[a-zA-Z0-9_-]*"
    r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
    r"[a-zA-Z0-9_-]*_"
)


def extract_source_name(input_path: str, max_length: int = 50) -> str:
    """Extract the original source filename by stripping block-generated prefixes.

    Iteratively removes {node_exec_id}_{operation}_ prefixes that accumulate
    when chaining video blocks, recovering the original human-readable name.

    Safe for plain filenames (no UUID -> no stripping).
    Falls back to "video" if everything is stripped.
    """
    stem = Path(input_path).stem

    # Pass 1: strip {node_exec_id}_{operation}_ prefixes iteratively
    while _BLOCK_PREFIX_RE.match(stem):
        stem = _BLOCK_PREFIX_RE.sub("", stem, count=1)

    # Pass 2: strip a lone {node_exec_id}_ prefix (e.g. from download block)
    if _UUID_PREFIX_RE.match(stem):
        stem = _UUID_PREFIX_RE.sub("", stem, count=1)

    if not stem:
        return "video"

    return stem[:max_length]
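For example, a name produced by chaining blocks unwinds back to its source; the UUID below is made up for illustration:

# Hypothetical chained output name: {node_exec_id}_clip_{original}
extract_source_name("node-123e4567-e89b-4d3a-a456-426614174000_clip_holiday.mp4")
# -> "holiday"

# Plain filenames pass through untouched (no UUID, no stripping):
extract_source_name("holiday.mp4")  # -> "holiday"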
def get_video_codecs(output_path: str) -> tuple[str, str]:
    """Get appropriate video and audio codecs based on output file extension.

    Args:
        output_path: Path to the output file (used to determine extension)

    Returns:
        Tuple of (video_codec, audio_codec)

    Codec mappings:
    - .mp4: H.264 + AAC (universal compatibility)
    - .webm: VP8 + Vorbis (web streaming)
    - .mkv: H.264 + AAC (container supports many codecs)
    - .mov: H.264 + AAC (Apple QuickTime, widely compatible)
    - .m4v: H.264 + AAC (Apple iTunes/devices)
    - .avi: MPEG-4 + MP3 (legacy Windows)
    """
    ext = os.path.splitext(output_path)[1].lower()

    codec_map: dict[str, tuple[str, str]] = {
        ".mp4": ("libx264", "aac"),
        ".webm": ("libvpx", "libvorbis"),
        ".mkv": ("libx264", "aac"),
        ".mov": ("libx264", "aac"),
        ".m4v": ("libx264", "aac"),
        ".avi": ("mpeg4", "libmp3lame"),
    }

    return codec_map.get(ext, ("libx264", "aac"))
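Usage is a plain extension-to-codec lookup, with H.264 + AAC as the fallback for unknown extensions, e.g.:

get_video_codecs("out.webm")  # -> ("libvpx", "libvorbis")
get_video_codecs("out.xyz")   # -> ("libx264", "aac")  (fallback)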
def strip_chapters_inplace(video_path: str) -> None:
    """Strip chapter metadata from a media file in-place using ffmpeg.

    MoviePy 2.x crashes with IndexError when parsing files with embedded
    chapter metadata (https://github.com/Zulko/moviepy/issues/2419).
    This strips chapters without re-encoding.

    Args:
        video_path: Absolute path to the media file to strip chapters from.
    """
    base, ext = os.path.splitext(video_path)
    tmp_path = base + ".tmp" + ext
    try:
        result = subprocess.run(
            [
                "ffmpeg",
                "-y",
                "-i",
                video_path,
                "-map_chapters",
                "-1",
                "-codec",
                "copy",
                tmp_path,
            ],
            capture_output=True,
            text=True,
            timeout=300,
        )
        if result.returncode != 0:
            logger.warning(
                "ffmpeg chapter strip failed (rc=%d): %s",
                result.returncode,
                result.stderr,
            )
            return
        os.replace(tmp_path, video_path)
    except FileNotFoundError:
        logger.warning("ffmpeg not found; skipping chapter strip")
    finally:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
113
autogpt_platform/backend/backend/blocks/video/add_audio.py
Normal file
@@ -0,0 +1,113 @@
"""AddAudioToVideoBlock - Attach an audio track to a video file."""

from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace
from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class AddAudioToVideoBlock(Block):
    """Add (attach) an audio track to an existing video."""

    class Input(BlockSchemaInput):
        video_in: MediaFileType = SchemaField(
            description="Video input (URL, data URI, or local path)."
        )
        audio_in: MediaFileType = SchemaField(
            description="Audio input (URL, data URI, or local path)."
        )
        volume: float = SchemaField(
            description="Volume scale for the newly attached audio track (1.0 = original).",
            default=1.0,
        )

    class Output(BlockSchemaOutput):
        video_out: MediaFileType = SchemaField(
            description="Final video (with attached audio), as a path or data URI."
        )

    def __init__(self):
        super().__init__(
            id="3503748d-62b6-4425-91d6-725b064af509",
            description="Block to attach an audio file to a video file using moviepy.",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=AddAudioToVideoBlock.Input,
            output_schema=AddAudioToVideoBlock.Output,
        )

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        **kwargs,
    ) -> BlockOutput:
        assert execution_context.graph_exec_id is not None
        assert execution_context.node_exec_id is not None
        graph_exec_id = execution_context.graph_exec_id
        node_exec_id = execution_context.node_exec_id

        # 1) Store the inputs locally
        local_video_path = await store_media_file(
            file=input_data.video_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )
        local_audio_path = await store_media_file(
            file=input_data.audio_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )

        video_abspath = get_exec_file_path(graph_exec_id, local_video_path)
        audio_abspath = get_exec_file_path(graph_exec_id, local_audio_path)

        # 2) Load video + audio with moviepy
        strip_chapters_inplace(video_abspath)
        strip_chapters_inplace(audio_abspath)
        video_clip = None
        audio_clip = None
        final_clip = None
        try:
            video_clip = VideoFileClip(video_abspath)
            audio_clip = AudioFileClip(audio_abspath)
            # Optionally scale volume
            if input_data.volume != 1.0:
                audio_clip = audio_clip.with_volume_scaled(input_data.volume)

            # 3) Attach the new audio track
            final_clip = video_clip.with_audio(audio_clip)

            # 4) Write to output file
            source = extract_source_name(local_video_path)
            output_filename = MediaFileType(f"{node_exec_id}_with_audio_{source}.mp4")
            output_abspath = get_exec_file_path(graph_exec_id, output_filename)
            final_clip.write_videofile(
                output_abspath, codec="libx264", audio_codec="aac"
            )
        finally:
            if final_clip:
                final_clip.close()
            if audio_clip:
                audio_clip.close()
            if video_clip:
                video_clip.close()

        # 5) Return output - for_block_output returns workspace:// if available, else data URI
        video_out = await store_media_file(
            file=output_filename,
            execution_context=execution_context,
            return_format="for_block_output",
        )

        yield "video_out", video_out
167
autogpt_platform/backend/backend/blocks/video/clip.py
Normal file
@@ -0,0 +1,167 @@
"""VideoClipBlock - Extract a segment from a video file."""

from typing import Literal

from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.blocks.video._utils import (
    extract_source_name,
    get_video_codecs,
    strip_chapters_inplace,
)
from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class VideoClipBlock(Block):
    """Extract a time segment from a video."""

    class Input(BlockSchemaInput):
        video_in: MediaFileType = SchemaField(
            description="Input video (URL, data URI, or local path)"
        )
        start_time: float = SchemaField(description="Start time in seconds", ge=0.0)
        end_time: float = SchemaField(description="End time in seconds", ge=0.0)
        output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
            description="Output format", default="mp4", advanced=True
        )

    class Output(BlockSchemaOutput):
        video_out: MediaFileType = SchemaField(
            description="Clipped video file (path or data URI)"
        )
        duration: float = SchemaField(description="Clip duration in seconds")

    def __init__(self):
        super().__init__(
            id="8f539119-e580-4d86-ad41-86fbcb22abb1",
            description="Extract a time segment from a video",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "video_in": "/tmp/test.mp4",
                "start_time": 0.0,
                "end_time": 10.0,
            },
            test_output=[("video_out", str), ("duration", float)],
            test_mock={
                "_clip_video": lambda *args: 10.0,
                "_store_input_video": lambda *args, **kwargs: "test.mp4",
                "_store_output_video": lambda *args, **kwargs: "clip_test.mp4",
            },
        )

    async def _store_input_video(
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store input video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_local_processing",
        )

    async def _store_output_video(
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store output video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_block_output",
        )

    def _clip_video(
        self,
        video_abspath: str,
        output_abspath: str,
        start_time: float,
        end_time: float,
    ) -> float:
        """Extract a clip from a video. Extracted for testability."""
        clip = None
        subclip = None
        try:
            strip_chapters_inplace(video_abspath)
            clip = VideoFileClip(video_abspath)
            subclip = clip.subclipped(start_time, end_time)
            video_codec, audio_codec = get_video_codecs(output_abspath)
            subclip.write_videofile(
                output_abspath, codec=video_codec, audio_codec=audio_codec
            )
            return subclip.duration
        finally:
            if subclip:
                subclip.close()
            if clip:
                clip.close()

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        node_exec_id: str,
        **kwargs,
    ) -> BlockOutput:
        # Validate time range
        if input_data.end_time <= input_data.start_time:
            raise BlockExecutionError(
                message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
                block_name=self.name,
                block_id=str(self.id),
            )

        try:
            assert execution_context.graph_exec_id is not None

            # Store the input video locally
            local_video_path = await self._store_input_video(
                execution_context, input_data.video_in
            )
            video_abspath = get_exec_file_path(
                execution_context.graph_exec_id, local_video_path
            )

            # Build output path
            source = extract_source_name(local_video_path)
            output_filename = MediaFileType(
                f"{node_exec_id}_clip_{source}.{input_data.output_format}"
            )
            output_abspath = get_exec_file_path(
                execution_context.graph_exec_id, output_filename
            )

            duration = self._clip_video(
                video_abspath,
                output_abspath,
                input_data.start_time,
                input_data.end_time,
            )

            # Return as workspace path or data URI based on context
            video_out = await self._store_output_video(
                execution_context, output_filename
            )

            yield "video_out", video_out
            yield "duration", duration

        except BlockExecutionError:
            raise
        except Exception as e:
            raise BlockExecutionError(
                message=f"Failed to clip video: {e}",
                block_name=self.name,
                block_id=str(self.id),
            ) from e
227
autogpt_platform/backend/backend/blocks/video/concat.py
Normal file
@@ -0,0 +1,227 @@
"""VideoConcatBlock - Concatenate multiple video clips into one."""

from typing import Literal

from moviepy import concatenate_videoclips
from moviepy.video.fx import CrossFadeIn, CrossFadeOut, FadeIn, FadeOut
from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.blocks.video._utils import (
    extract_source_name,
    get_video_codecs,
    strip_chapters_inplace,
)
from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class VideoConcatBlock(Block):
    """Merge multiple video clips into one continuous video."""

    class Input(BlockSchemaInput):
        videos: list[MediaFileType] = SchemaField(
            description="List of video files to concatenate (in order)"
        )
        transition: Literal["none", "crossfade", "fade_black"] = SchemaField(
            description="Transition between clips", default="none"
        )
        transition_duration: int = SchemaField(
            description="Transition duration in seconds",
            default=1,
            ge=0,
            advanced=True,
        )
        output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
            description="Output format", default="mp4", advanced=True
        )

    class Output(BlockSchemaOutput):
        video_out: MediaFileType = SchemaField(
            description="Concatenated video file (path or data URI)"
        )
        total_duration: float = SchemaField(description="Total duration in seconds")

    def __init__(self):
        super().__init__(
            id="9b0f531a-1118-487f-aeec-3fa63ea8900a",
            description="Merge multiple video clips into one continuous video",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "videos": ["/tmp/a.mp4", "/tmp/b.mp4"],
            },
            test_output=[
                ("video_out", str),
                ("total_duration", float),
            ],
            test_mock={
                "_concat_videos": lambda *args: 20.0,
                "_store_input_video": lambda *args, **kwargs: "test.mp4",
                "_store_output_video": lambda *args, **kwargs: "concat_test.mp4",
            },
        )

    async def _store_input_video(
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store input video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_local_processing",
        )

    async def _store_output_video(
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store output video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_block_output",
        )

    def _concat_videos(
        self,
        video_abspaths: list[str],
        output_abspath: str,
        transition: str,
        transition_duration: int,
    ) -> float:
        """Concatenate videos. Extracted for testability.

        Returns:
            Total duration of the concatenated video.
        """
        clips = []
        faded_clips = []
        final = None
        try:
            # Load clips
            for v in video_abspaths:
                strip_chapters_inplace(v)
                clips.append(VideoFileClip(v))

            # Validate transition_duration against shortest clip
            if transition in {"crossfade", "fade_black"} and transition_duration > 0:
                min_duration = min(c.duration for c in clips)
                if transition_duration >= min_duration:
                    raise BlockExecutionError(
                        message=(
                            f"transition_duration ({transition_duration}s) must be "
                            f"shorter than the shortest clip ({min_duration:.2f}s)"
                        ),
                        block_name=self.name,
                        block_id=str(self.id),
                    )

            if transition == "crossfade":
                for i, clip in enumerate(clips):
                    effects = []
                    if i > 0:
                        effects.append(CrossFadeIn(transition_duration))
                    if i < len(clips) - 1:
                        effects.append(CrossFadeOut(transition_duration))
                    if effects:
                        clip = clip.with_effects(effects)
                    faded_clips.append(clip)
                final = concatenate_videoclips(
                    faded_clips,
                    method="compose",
                    padding=-transition_duration,
                )
            elif transition == "fade_black":
                for clip in clips:
                    faded = clip.with_effects(
                        [FadeIn(transition_duration), FadeOut(transition_duration)]
                    )
                    faded_clips.append(faded)
                final = concatenate_videoclips(faded_clips)
            else:
                final = concatenate_videoclips(clips)

            video_codec, audio_codec = get_video_codecs(output_abspath)
            final.write_videofile(
                output_abspath, codec=video_codec, audio_codec=audio_codec
            )

            return final.duration
        finally:
            if final:
                final.close()
            for clip in faded_clips:
                clip.close()
            for clip in clips:
                clip.close()

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        node_exec_id: str,
        **kwargs,
    ) -> BlockOutput:
        # Validate minimum clips
        if len(input_data.videos) < 2:
            raise BlockExecutionError(
                message="At least 2 videos are required for concatenation",
                block_name=self.name,
                block_id=str(self.id),
            )

        try:
            assert execution_context.graph_exec_id is not None

            # Store all input videos locally
            video_abspaths = []
            for video in input_data.videos:
                local_path = await self._store_input_video(execution_context, video)
                video_abspaths.append(
                    get_exec_file_path(execution_context.graph_exec_id, local_path)
                )

            # Build output path
            source = (
                extract_source_name(video_abspaths[0]) if video_abspaths else "video"
            )
            output_filename = MediaFileType(
                f"{node_exec_id}_concat_{source}.{input_data.output_format}"
            )
            output_abspath = get_exec_file_path(
                execution_context.graph_exec_id, output_filename
            )

            total_duration = self._concat_videos(
                video_abspaths,
                output_abspath,
                input_data.transition,
                input_data.transition_duration,
            )

            # Return as workspace path or data URI based on context
            video_out = await self._store_output_video(
                execution_context, output_filename
            )

            yield "video_out", video_out
            yield "total_duration", total_duration

        except BlockExecutionError:
            raise
        except Exception as e:
            raise BlockExecutionError(
                message=f"Failed to concatenate videos: {e}",
                block_name=self.name,
                block_id=str(self.id),
            ) from e
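Worth spelling out for reviewers: with `method="compose"` and a negative `padding`, consecutive clips overlap by `transition_duration` seconds, so the expected output length of a crossfade concat can be checked arithmetically. A small sketch with hypothetical clip durations (not values from the tests):

# Expected duration of a crossfaded concat: clips overlap pairwise by
# `transition_duration`, so each of the (n - 1) joins removes that much time.
durations = [8.0, 6.5, 10.0]   # hypothetical clip lengths in seconds
transition_duration = 1.0

expected = sum(durations) - (len(durations) - 1) * transition_duration
print(expected)  # 22.5 -> 24.5 total minus 2 joins * 1.0s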
172 autogpt_platform/backend/backend/blocks/video/download.py Normal file
@@ -0,0 +1,172 @@
"""VideoDownloadBlock - Download video from URL (YouTube, Vimeo, news sites, direct links)."""

import os
import typing
from typing import Literal

import yt_dlp

if typing.TYPE_CHECKING:
    from yt_dlp import _Params

from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class VideoDownloadBlock(Block):
    """Download video from URL using yt-dlp."""

    class Input(BlockSchemaInput):
        url: str = SchemaField(
            description="URL of the video to download (YouTube, Vimeo, direct link, etc.)",
            placeholder="https://www.youtube.com/watch?v=...",
        )
        quality: Literal["best", "1080p", "720p", "480p", "audio_only"] = SchemaField(
            description="Video quality preference", default="720p"
        )
        output_format: Literal["mp4", "webm", "mkv"] = SchemaField(
            description="Output video format", default="mp4", advanced=True
        )

    class Output(BlockSchemaOutput):
        video_file: MediaFileType = SchemaField(
            description="Downloaded video (path or data URI)"
        )
        duration: float = SchemaField(description="Video duration in seconds")
        title: str = SchemaField(description="Video title from source")
        source_url: str = SchemaField(description="Original source URL")

    def __init__(self):
        super().__init__(
            id="c35daabb-cd60-493b-b9ad-51f1fe4b50c4",
            description="Download video from URL (YouTube, Vimeo, news sites, direct links)",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=self.Input,
            output_schema=self.Output,
            disabled=True,  # Disable until we can sandbox yt-dlp and handle security implications
            test_input={
                "url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
                "quality": "480p",
            },
            test_output=[
                ("video_file", str),
                ("duration", float),
                ("title", str),
                ("source_url", str),
            ],
            test_mock={
                "_download_video": lambda *args: (
                    "video.mp4",
                    212.0,
                    "Test Video",
                ),
                "_store_output_video": lambda *args, **kwargs: "video.mp4",
            },
        )

    async def _store_output_video(
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store output video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_block_output",
        )

    def _get_format_string(self, quality: str) -> str:
        formats = {
            "best": "bestvideo+bestaudio/best",
            "1080p": "bestvideo[height<=1080]+bestaudio/best[height<=1080]",
            "720p": "bestvideo[height<=720]+bestaudio/best[height<=720]",
            "480p": "bestvideo[height<=480]+bestaudio/best[height<=480]",
            "audio_only": "bestaudio/best",
        }
        return formats.get(quality, formats["720p"])

    def _download_video(
        self,
        url: str,
        quality: str,
        output_format: str,
        output_dir: str,
        node_exec_id: str,
    ) -> tuple[str, float, str]:
        """Download video. Extracted for testability."""
        output_template = os.path.join(
            output_dir, f"{node_exec_id}_%(title).50s.%(ext)s"
        )

        ydl_opts: "_Params" = {
            "format": f"{self._get_format_string(quality)}/best",
            "outtmpl": output_template,
            "merge_output_format": output_format,
            "quiet": True,
            "no_warnings": True,
        }

        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info = ydl.extract_info(url, download=True)
            video_path = ydl.prepare_filename(info)

        # Handle format conversion in filename
        if not video_path.endswith(f".{output_format}"):
            video_path = video_path.rsplit(".", 1)[0] + f".{output_format}"

        # Return just the filename, not the full path
        filename = os.path.basename(video_path)

        return (
            filename,
            info.get("duration") or 0.0,
            info.get("title") or "Unknown",
        )

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        node_exec_id: str,
        **kwargs,
    ) -> BlockOutput:
        try:
            assert execution_context.graph_exec_id is not None

            # Get the exec file directory
            output_dir = get_exec_file_path(execution_context.graph_exec_id, "")
            os.makedirs(output_dir, exist_ok=True)

            filename, duration, title = self._download_video(
                input_data.url,
                input_data.quality,
                input_data.output_format,
                output_dir,
                node_exec_id,
            )

            # Return as workspace path or data URI based on context
            video_out = await self._store_output_video(
                execution_context, MediaFileType(filename)
            )

            yield "video_file", video_out
            yield "duration", duration
            yield "title", title
            yield "source_url", input_data.url

        except Exception as e:
            raise BlockExecutionError(
                message=f"Failed to download video: {e}",
                block_name=self.name,
                block_id=str(self.id),
            ) from e
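As a quick sanity check on the format strings above, yt-dlp can report which format it would pick for a given URL without downloading anything. A hedged sketch (the URL is a placeholder and the probe assumes a single-video page, not a playlist):

# Probe which format yt-dlp would select for the "720p" preference,
# without downloading. The URL is a placeholder only.
import yt_dlp

fmt = "bestvideo[height<=720]+bestaudio/best[height<=720]/best"
with yt_dlp.YoutubeDL({"format": fmt, "quiet": True}) as ydl:
    info = ydl.extract_info("https://www.youtube.com/watch?v=...", download=False)
    print(info.get("format"))  # human-readable description of the chosen format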
77 autogpt_platform/backend/backend/blocks/video/duration.py Normal file
@@ -0,0 +1,77 @@
"""MediaDurationBlock - Get the duration of a media file."""

from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.blocks.video._utils import strip_chapters_inplace
from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class MediaDurationBlock(Block):
    """Get the duration of a media file (video or audio)."""

    class Input(BlockSchemaInput):
        media_in: MediaFileType = SchemaField(
            description="Media input (URL, data URI, or local path)."
        )
        is_video: bool = SchemaField(
            description="Whether the media is a video (True) or audio (False).",
            default=True,
        )

    class Output(BlockSchemaOutput):
        duration: float = SchemaField(
            description="Duration of the media file (in seconds)."
        )

    def __init__(self):
        super().__init__(
            id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
            description="Block to get the duration of a media file.",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=MediaDurationBlock.Input,
            output_schema=MediaDurationBlock.Output,
        )

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        **kwargs,
    ) -> BlockOutput:
        # 1) Store the input media locally
        local_media_path = await store_media_file(
            file=input_data.media_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )
        assert execution_context.graph_exec_id is not None
        media_abspath = get_exec_file_path(
            execution_context.graph_exec_id, local_media_path
        )

        # 2) Strip chapters to avoid MoviePy crash, then load the clip
        strip_chapters_inplace(media_abspath)
        clip = None
        try:
            if input_data.is_video:
                clip = VideoFileClip(media_abspath)
            else:
                clip = AudioFileClip(media_abspath)

            duration = clip.duration
        finally:
            if clip:
                clip.close()

        yield "duration", duration
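The same duration read can be reproduced directly. A minimal sketch assuming a local audio file (the path is illustrative; AudioFileClip can also open the audio track of a video container, which is why the block asks the caller which reader to use via `is_video`):

# Read a media file's duration the way the block does, using the
# audio reader. The file path is illustrative only.
from moviepy.audio.io.AudioFileClip import AudioFileClip

clip = AudioFileClip("narration.mp3")
try:
    print(f"{clip.duration:.2f}s")  # duration in seconds, as a float
finally:
    clip.close()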
115 autogpt_platform/backend/backend/blocks/video/loop.py Normal file
@@ -0,0 +1,115 @@
"""LoopVideoBlock - Loop a video to a given duration or number of repeats."""

from typing import Optional

from moviepy.video.fx.Loop import Loop
from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace
from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class LoopVideoBlock(Block):
    """Loop (repeat) a video clip until a given duration or number of loops."""

    class Input(BlockSchemaInput):
        video_in: MediaFileType = SchemaField(
            description="The input video (can be a URL, data URI, or local path)."
        )
        duration: Optional[float] = SchemaField(
            description="Target duration (in seconds) to loop the video to. Either duration or n_loops must be provided.",
            default=None,
            ge=0.0,
            le=3600.0,  # Max 1 hour to prevent disk exhaustion
        )
        n_loops: Optional[int] = SchemaField(
            description="Number of times to repeat the video. Either n_loops or duration must be provided.",
            default=None,
            ge=1,
            le=10,  # Max 10 loops to prevent disk exhaustion
        )

    class Output(BlockSchemaOutput):
        video_out: MediaFileType = SchemaField(
            description="Looped video returned either as a relative path or a data URI."
        )

    def __init__(self):
        super().__init__(
            id="8bf9eef6-5451-4213-b265-25306446e94b",
            description="Block to loop a video to a given duration or number of repeats.",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=LoopVideoBlock.Input,
            output_schema=LoopVideoBlock.Output,
        )

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        **kwargs,
    ) -> BlockOutput:
        assert execution_context.graph_exec_id is not None
        assert execution_context.node_exec_id is not None
        graph_exec_id = execution_context.graph_exec_id
        node_exec_id = execution_context.node_exec_id

        # 1) Store the input video locally
        local_video_path = await store_media_file(
            file=input_data.video_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )
        input_abspath = get_exec_file_path(graph_exec_id, local_video_path)

        # 2) Load the clip
        strip_chapters_inplace(input_abspath)
        clip = None
        looped_clip = None
        try:
            clip = VideoFileClip(input_abspath)

            # 3) Apply the loop effect
            if input_data.duration:
                # Loop until we reach the specified duration
                looped_clip = clip.with_effects([Loop(duration=input_data.duration)])
            elif input_data.n_loops:
                looped_clip = clip.with_effects([Loop(n=input_data.n_loops)])
            else:
                raise ValueError("Either 'duration' or 'n_loops' must be provided.")

            assert isinstance(looped_clip, VideoFileClip)

            # 4) Save the looped output
            source = extract_source_name(local_video_path)
            output_filename = MediaFileType(f"{node_exec_id}_looped_{source}.mp4")
            output_abspath = get_exec_file_path(graph_exec_id, output_filename)

            looped_clip = looped_clip.with_audio(clip.audio)
            looped_clip.write_videofile(
                output_abspath, codec="libx264", audio_codec="aac"
            )
        finally:
            if looped_clip:
                looped_clip.close()
            if clip:
                clip.close()

        # Return output - for_block_output returns workspace:// if available, else data URI
        video_out = await store_media_file(
            file=output_filename,
            execution_context=execution_context,
            return_format="for_block_output",
        )

        yield "video_out", video_out
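The two `Loop` modes above are mutually exclusive in the block. A sketch of what each produces, assuming MoviePy v2 and a hypothetical 4-second source clip:

# Loop semantics, assuming a 4-second input clip (illustrative path).
from moviepy.video.fx.Loop import Loop
from moviepy.video.io.VideoFileClip import VideoFileClip

clip = VideoFileClip("short.mp4")  # assume duration == 4.0
by_count = clip.with_effects([Loop(n=3)])               # 3 repeats -> 12.0s
by_duration = clip.with_effects([Loop(duration=10.0)])  # cuts mid-repeat -> 10.0s
print(by_count.duration, by_duration.duration)
for c in (by_duration, by_count, clip):
    c.close()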
267 autogpt_platform/backend/backend/blocks/video/narration.py Normal file
@@ -0,0 +1,267 @@
"""VideoNarrationBlock - Generate AI voice narration and add to video."""

import os
from typing import Literal

from elevenlabs import ElevenLabs
from moviepy import CompositeAudioClip
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.blocks.elevenlabs._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    ElevenLabsCredentials,
    ElevenLabsCredentialsInput,
)
from backend.blocks.video._utils import (
    extract_source_name,
    get_video_codecs,
    strip_chapters_inplace,
)
from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsField, SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class VideoNarrationBlock(Block):
    """Generate AI narration and add to video."""

    class Input(BlockSchemaInput):
        credentials: ElevenLabsCredentialsInput = CredentialsField(
            description="ElevenLabs API key for voice synthesis"
        )
        video_in: MediaFileType = SchemaField(
            description="Input video (URL, data URI, or local path)"
        )
        script: str = SchemaField(description="Narration script text")
        voice_id: str = SchemaField(
            description="ElevenLabs voice ID", default="21m00Tcm4TlvDq8ikWAM"  # Rachel
        )
        model_id: Literal[
            "eleven_multilingual_v2",
            "eleven_flash_v2_5",
            "eleven_turbo_v2_5",
            "eleven_turbo_v2",
        ] = SchemaField(
            description="ElevenLabs TTS model",
            default="eleven_multilingual_v2",
        )
        mix_mode: Literal["replace", "mix", "ducking"] = SchemaField(
            description="How to combine with original audio. 'ducking' applies stronger attenuation than 'mix'.",
            default="ducking",
        )
        narration_volume: float = SchemaField(
            description="Narration volume (0.0 to 2.0)",
            default=1.0,
            ge=0.0,
            le=2.0,
            advanced=True,
        )
        original_volume: float = SchemaField(
            description="Original audio volume when mixing (0.0 to 1.0)",
            default=0.3,
            ge=0.0,
            le=1.0,
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        video_out: MediaFileType = SchemaField(
            description="Video with narration (path or data URI)"
        )
        audio_file: MediaFileType = SchemaField(
            description="Generated audio file (path or data URI)"
        )

    def __init__(self):
        super().__init__(
            id="3d036b53-859c-4b17-9826-ca340f736e0e",
            description="Generate AI narration and add to video",
            categories={BlockCategory.MULTIMEDIA, BlockCategory.AI},
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "video_in": "/tmp/test.mp4",
                "script": "Hello world",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[("video_out", str), ("audio_file", str)],
            test_mock={
                "_generate_narration_audio": lambda *args: b"mock audio content",
                "_add_narration_to_video": lambda *args: None,
                "_store_input_video": lambda *args, **kwargs: "test.mp4",
                "_store_output_video": lambda *args, **kwargs: "narrated_test.mp4",
            },
        )

    async def _store_input_video(
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store input video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_local_processing",
        )

    async def _store_output_video(
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store output video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_block_output",
        )

    def _generate_narration_audio(
        self, api_key: str, script: str, voice_id: str, model_id: str
    ) -> bytes:
        """Generate narration audio via ElevenLabs API."""
        client = ElevenLabs(api_key=api_key)
        audio_generator = client.text_to_speech.convert(
            voice_id=voice_id,
            text=script,
            model_id=model_id,
        )
        # The SDK returns a generator; collect all chunks
        return b"".join(audio_generator)

    def _add_narration_to_video(
        self,
        video_abspath: str,
        audio_abspath: str,
        output_abspath: str,
        mix_mode: str,
        narration_volume: float,
        original_volume: float,
    ) -> None:
        """Add narration audio to video. Extracted for testability."""
        video = None
        final = None
        narration_original = None
        narration_scaled = None
        original = None

        try:
            strip_chapters_inplace(video_abspath)
            video = VideoFileClip(video_abspath)
            narration_original = AudioFileClip(audio_abspath)
            narration_scaled = narration_original.with_volume_scaled(narration_volume)
            narration = narration_scaled

            if mix_mode == "replace":
                final_audio = narration
            elif mix_mode == "mix":
                if video.audio:
                    original = video.audio.with_volume_scaled(original_volume)
                    final_audio = CompositeAudioClip([original, narration])
                else:
                    final_audio = narration
            else:  # ducking - apply stronger attenuation
                if video.audio:
                    # Ducking uses a much lower volume for original audio
                    ducking_volume = original_volume * 0.3
                    original = video.audio.with_volume_scaled(ducking_volume)
                    final_audio = CompositeAudioClip([original, narration])
                else:
                    final_audio = narration

            final = video.with_audio(final_audio)
            video_codec, audio_codec = get_video_codecs(output_abspath)
            final.write_videofile(
                output_abspath, codec=video_codec, audio_codec=audio_codec
            )

        finally:
            if original:
                original.close()
            if narration_scaled:
                narration_scaled.close()
            if narration_original:
                narration_original.close()
            if final:
                final.close()
            if video:
                video.close()

    async def run(
        self,
        input_data: Input,
        *,
        credentials: ElevenLabsCredentials,
        execution_context: ExecutionContext,
        node_exec_id: str,
        **kwargs,
    ) -> BlockOutput:
        try:
            assert execution_context.graph_exec_id is not None

            # Store the input video locally
            local_video_path = await self._store_input_video(
                execution_context, input_data.video_in
            )
            video_abspath = get_exec_file_path(
                execution_context.graph_exec_id, local_video_path
            )

            # Generate narration audio via ElevenLabs
            audio_content = self._generate_narration_audio(
                credentials.api_key.get_secret_value(),
                input_data.script,
                input_data.voice_id,
                input_data.model_id,
            )

            # Save audio to exec file path
            audio_filename = MediaFileType(f"{node_exec_id}_narration.mp3")
            audio_abspath = get_exec_file_path(
                execution_context.graph_exec_id, audio_filename
            )
            os.makedirs(os.path.dirname(audio_abspath), exist_ok=True)
            with open(audio_abspath, "wb") as f:
                f.write(audio_content)

            # Add narration to video
            source = extract_source_name(local_video_path)
            output_filename = MediaFileType(f"{node_exec_id}_narrated_{source}.mp4")
            output_abspath = get_exec_file_path(
                execution_context.graph_exec_id, output_filename
            )

            self._add_narration_to_video(
                video_abspath,
                audio_abspath,
                output_abspath,
                input_data.mix_mode,
                input_data.narration_volume,
                input_data.original_volume,
            )

            # Return as workspace path or data URI based on context
            video_out = await self._store_output_video(
                execution_context, output_filename
            )
            audio_out = await self._store_output_video(
                execution_context, audio_filename
            )

            yield "video_out", video_out
            yield "audio_file", audio_out

        except Exception as e:
            raise BlockExecutionError(
                message=f"Failed to add narration: {e}",
                block_name=self.name,
                block_id=str(self.id),
            ) from e
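One subtlety reviewers may want spelled out: in "ducking" mode the original track ends up at `original_volume * 0.3`, i.e. 9% of full volume with the default `original_volume=0.3`. A small sketch of the effective gain per mode, mirroring the logic in `_add_narration_to_video` above:

# Effective gain applied to the original audio track per mix mode,
# mirroring the branches of _add_narration_to_video.
def original_gain(mix_mode: str, original_volume: float = 0.3) -> float:
    if mix_mode == "replace":
        return 0.0                   # original track is dropped entirely
    if mix_mode == "mix":
        return original_volume       # 0.3 with the default
    return original_volume * 0.3     # "ducking": 0.09 with the default

for mode in ("replace", "mix", "ducking"):
    print(mode, original_gain(mode))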
231 autogpt_platform/backend/backend/blocks/video/text_overlay.py Normal file
@@ -0,0 +1,231 @@
"""VideoTextOverlayBlock - Add text overlay to video."""

from typing import Literal

from moviepy import CompositeVideoClip, TextClip
from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.blocks.video._utils import (
    extract_source_name,
    get_video_codecs,
    strip_chapters_inplace,
)
from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class VideoTextOverlayBlock(Block):
    """Add text overlay/caption to video."""

    class Input(BlockSchemaInput):
        video_in: MediaFileType = SchemaField(
            description="Input video (URL, data URI, or local path)"
        )
        text: str = SchemaField(description="Text to overlay on video")
        position: Literal[
            "top",
            "center",
            "bottom",
            "top-left",
            "top-right",
            "bottom-left",
            "bottom-right",
        ] = SchemaField(description="Position of text on screen", default="bottom")
        start_time: float | None = SchemaField(
            description="When to show text (seconds). None = entire video",
            default=None,
            advanced=True,
        )
        end_time: float | None = SchemaField(
            description="When to hide text (seconds). None = until end",
            default=None,
            advanced=True,
        )
        font_size: int = SchemaField(
            description="Font size", default=48, ge=12, le=200, advanced=True
        )
        font_color: str = SchemaField(
            description="Font color (hex or name)", default="white", advanced=True
        )
        bg_color: str | None = SchemaField(
            description="Background color behind text (None for transparent)",
            default=None,
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        video_out: MediaFileType = SchemaField(
            description="Video with text overlay (path or data URI)"
        )

    def __init__(self):
        super().__init__(
            id="8ef14de6-cc90-430a-8cfa-3a003be92454",
            description="Add text overlay/caption to video",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=self.Input,
            output_schema=self.Output,
            disabled=True,  # Disable until we can lock down the ImageMagick security policy
            test_input={"video_in": "/tmp/test.mp4", "text": "Hello World"},
            test_output=[("video_out", str)],
            test_mock={
                "_add_text_overlay": lambda *args: None,
                "_store_input_video": lambda *args, **kwargs: "test.mp4",
                "_store_output_video": lambda *args, **kwargs: "overlay_test.mp4",
            },
        )

    async def _store_input_video(
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store input video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_local_processing",
        )

    async def _store_output_video(
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store output video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_block_output",
        )

    def _add_text_overlay(
        self,
        video_abspath: str,
        output_abspath: str,
        text: str,
        position: str,
        start_time: float | None,
        end_time: float | None,
        font_size: int,
        font_color: str,
        bg_color: str | None,
    ) -> None:
        """Add text overlay to video. Extracted for testability."""
        video = None
        final = None
        txt_clip = None
        try:
            strip_chapters_inplace(video_abspath)
            video = VideoFileClip(video_abspath)

            txt_clip = TextClip(
                text=text,
                font_size=font_size,
                color=font_color,
                bg_color=bg_color,
            )

            # Position mapping
            pos_map = {
                "top": ("center", "top"),
                "center": ("center", "center"),
                "bottom": ("center", "bottom"),
                "top-left": ("left", "top"),
                "top-right": ("right", "top"),
                "bottom-left": ("left", "bottom"),
                "bottom-right": ("right", "bottom"),
            }

            txt_clip = txt_clip.with_position(pos_map[position])

            # Set timing
            start = start_time or 0
            end = end_time or video.duration
            duration = max(0, end - start)
            txt_clip = txt_clip.with_start(start).with_end(end).with_duration(duration)

            final = CompositeVideoClip([video, txt_clip])
            video_codec, audio_codec = get_video_codecs(output_abspath)
            final.write_videofile(
                output_abspath, codec=video_codec, audio_codec=audio_codec
            )

        finally:
            if txt_clip:
                txt_clip.close()
            if final:
                final.close()
            if video:
                video.close()

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        node_exec_id: str,
        **kwargs,
    ) -> BlockOutput:
        # Validate time range if both are provided
        if (
            input_data.start_time is not None
            and input_data.end_time is not None
            and input_data.end_time <= input_data.start_time
        ):
            raise BlockExecutionError(
                message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
                block_name=self.name,
                block_id=str(self.id),
            )

        try:
            assert execution_context.graph_exec_id is not None

            # Store the input video locally
            local_video_path = await self._store_input_video(
                execution_context, input_data.video_in
            )
            video_abspath = get_exec_file_path(
                execution_context.graph_exec_id, local_video_path
            )

            # Build output path
            source = extract_source_name(local_video_path)
            output_filename = MediaFileType(f"{node_exec_id}_overlay_{source}.mp4")
            output_abspath = get_exec_file_path(
                execution_context.graph_exec_id, output_filename
            )

            self._add_text_overlay(
                video_abspath,
                output_abspath,
                input_data.text,
                input_data.position,
                input_data.start_time,
                input_data.end_time,
                input_data.font_size,
                input_data.font_color,
                input_data.bg_color,
            )

            # Return as workspace path or data URI based on context
            video_out = await self._store_output_video(
                execution_context, output_filename
            )

            yield "video_out", video_out

        except BlockExecutionError:
            raise
        except Exception as e:
            raise BlockExecutionError(
                message=f"Failed to add text overlay: {e}",
                block_name=self.name,
                block_id=str(self.id),
            ) from e
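The overlay timing above defaults the window to the whole clip when either bound is omitted. A pure-Python mirror of that resolution logic, with hypothetical values:

# How the overlay timing defaults resolve, mirroring _add_text_overlay.
def overlay_window(
    start_time: float | None, end_time: float | None, video_duration: float
) -> tuple[float, float, float]:
    start = start_time or 0
    end = end_time or video_duration
    return start, end, max(0, end - start)

print(overlay_window(None, None, 30.0))  # (0, 30.0, 30.0) -> whole video
print(overlay_window(5.0, 12.0, 30.0))   # (5.0, 12.0, 7.0)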
@@ -165,10 +165,13 @@ class TranscribeYoutubeVideoBlock(Block):
         credentials: WebshareProxyCredentials,
         **kwargs,
     ) -> BlockOutput:
-        video_id = self.extract_video_id(input_data.youtube_url)
-        yield "video_id", video_id
-
-        transcript = self.get_transcript(video_id, credentials)
-        transcript_text = self.format_transcript(transcript=transcript)
-
-        yield "transcript", transcript_text
+        try:
+            video_id = self.extract_video_id(input_data.youtube_url)
+            transcript = self.get_transcript(video_id, credentials)
+            transcript_text = self.format_transcript(transcript=transcript)
+
+            # Only yield after all operations succeed
+            yield "video_id", video_id
+            yield "transcript", transcript_text
+        except Exception as e:
+            yield "error", str(e)
@@ -246,7 +246,9 @@ class BlockSchema(BaseModel):
                     f"is not of type {CredentialsMetaInput.__name__}"
                 )
 
-            credentials_fields[field_name].validate_credentials_field_schema(cls)
+            CredentialsMetaInput.validate_credentials_field_schema(
+                cls.get_field_schema(field_name), field_name
+            )
 
         elif field_name in credentials_fields:
             raise KeyError(
@@ -36,12 +36,14 @@ from backend.blocks.replicate.replicate_block import ReplicateModelBlock
 from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
 from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
 from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
+from backend.blocks.video.narration import VideoNarrationBlock
 from backend.data.block import Block, BlockCost, BlockCostType
 from backend.integrations.credentials_store import (
     aiml_api_credentials,
     anthropic_credentials,
     apollo_credentials,
     did_credentials,
+    elevenlabs_credentials,
     enrichlayer_credentials,
     groq_credentials,
     ideogram_credentials,
@@ -78,6 +80,7 @@ MODEL_COST: dict[LlmModel, int] = {
     LlmModel.CLAUDE_4_1_OPUS: 21,
     LlmModel.CLAUDE_4_OPUS: 21,
     LlmModel.CLAUDE_4_SONNET: 5,
+    LlmModel.CLAUDE_4_6_OPUS: 14,
     LlmModel.CLAUDE_4_5_HAIKU: 4,
     LlmModel.CLAUDE_4_5_OPUS: 14,
     LlmModel.CLAUDE_4_5_SONNET: 9,
@@ -639,4 +642,16 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
             },
         ),
     ],
+    VideoNarrationBlock: [
+        BlockCost(
+            cost_amount=5,  # ElevenLabs TTS cost
+            cost_filter={
+                "credentials": {
+                    "id": elevenlabs_credentials.id,
+                    "provider": elevenlabs_credentials.provider,
+                    "type": elevenlabs_credentials.type,
+                }
+            },
+        )
+    ],
 }
@@ -134,6 +134,16 @@ async def test_block_credit_reset(server: SpinTestServer):
     month1 = datetime.now(timezone.utc).replace(month=1, day=1)
     user_credit.time_now = lambda: month1
 
+    # IMPORTANT: Set updatedAt to December of previous year to ensure it's
+    # in a different month than month1 (January). This fixes a timing bug
+    # where if the test runs in early February, 35 days ago would be January,
+    # matching the mocked month1 and preventing the refill from triggering.
+    dec_previous_year = month1.replace(year=month1.year - 1, month=12, day=15)
+    await UserBalance.prisma().update(
+        where={"userId": DEFAULT_USER_ID},
+        data={"updatedAt": dec_previous_year},
+    )
+
     # First call in month 1 should trigger refill
     balance = await user_credit.get_credits(DEFAULT_USER_ID)
     assert balance == REFILL_VALUE  # Should get 1000 credits
@@ -3,7 +3,7 @@ import logging
 import uuid
 from collections import defaultdict
 from datetime import datetime, timezone
-from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, cast
+from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Self, cast
 
 from prisma.enums import SubmissionStatus
 from prisma.models import (
@@ -20,7 +20,7 @@ from prisma.types import (
     AgentNodeLinkCreateInput,
     StoreListingVersionWhereInput,
 )
-from pydantic import BaseModel, BeforeValidator, Field, create_model
+from pydantic import BaseModel, BeforeValidator, Field
 from pydantic.fields import computed_field
 
 from backend.blocks.agent import AgentExecutorBlock
@@ -30,7 +30,6 @@ from backend.data.db import prisma as db
 from backend.data.dynamic_fields import is_tool_pin, sanitize_pin_name
 from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
 from backend.data.model import (
-    CredentialsField,
     CredentialsFieldInfo,
     CredentialsMetaInput,
     is_credentials_field_name,
@@ -45,7 +44,6 @@ from .block import (
     AnyBlockSchema,
     Block,
     BlockInput,
-    BlockSchema,
     BlockType,
     EmptySchema,
     get_block,
@@ -113,10 +111,12 @@ class Link(BaseDbModel):
 
 class Node(BaseDbModel):
     block_id: str
-    input_default: BlockInput = {}  # dict[input_name, default_value]
-    metadata: dict[str, Any] = {}
-    input_links: list[Link] = []
-    output_links: list[Link] = []
+    input_default: BlockInput = Field(  # dict[input_name, default_value]
+        default_factory=dict
+    )
+    metadata: dict[str, Any] = Field(default_factory=dict)
+    input_links: list[Link] = Field(default_factory=list)
+    output_links: list[Link] = Field(default_factory=list)
 
     @property
     def credentials_optional(self) -> bool:
@@ -221,18 +221,33 @@ class NodeModel(Node):
         return result
 
 
-class BaseGraph(BaseDbModel):
+class GraphBaseMeta(BaseDbModel):
+    """
+    Shared base for `GraphMeta` and `BaseGraph`, with core graph metadata fields.
+    """
+
     version: int = 1
     is_active: bool = True
     name: str
     description: str
     instructions: str | None = None
     recommended_schedule_cron: str | None = None
-    nodes: list[Node] = []
-    links: list[Link] = []
     forked_from_id: str | None = None
     forked_from_version: int | None = None
 
+
+class BaseGraph(GraphBaseMeta):
+    """
+    Graph with nodes, links, and computed I/O schema fields.
+
+    Used to represent sub-graphs within a `Graph`. Contains the full graph
+    structure including nodes and links, plus computed fields for schemas
+    and trigger info. Does NOT include user_id or created_at (see GraphModel).
+    """
+
+    nodes: list[Node] = Field(default_factory=list)
+    links: list[Link] = Field(default_factory=list)
+
     @computed_field
     @property
     def input_schema(self) -> dict[str, Any]:
@@ -361,44 +376,79 @@ class GraphTriggerInfo(BaseModel):
 
 
 class Graph(BaseGraph):
-    sub_graphs: list[BaseGraph] = []  # Flattened sub-graphs
+    """Creatable graph model used in API create/update endpoints."""
+
+    sub_graphs: list[BaseGraph] = Field(default_factory=list)  # Flattened sub-graphs
+
+
+class GraphMeta(GraphBaseMeta):
+    """
+    Lightweight graph metadata model representing an existing graph from the database,
+    for use in listings and summaries.
+
+    Lacks `GraphModel`'s nodes, links, and expensive computed fields.
+    Use for list endpoints where full graph data is not needed and performance matters.
+    """
+
+    id: str  # type: ignore
+    version: int  # type: ignore
+    user_id: str
+    created_at: datetime
+
+    @classmethod
+    def from_db(cls, graph: "AgentGraph") -> Self:
+        return cls(
+            id=graph.id,
+            version=graph.version,
+            is_active=graph.isActive,
+            name=graph.name or "",
+            description=graph.description or "",
+            instructions=graph.instructions,
+            recommended_schedule_cron=graph.recommendedScheduleCron,
+            forked_from_id=graph.forkedFromId,
+            forked_from_version=graph.forkedFromVersion,
+            user_id=graph.userId,
+            created_at=graph.createdAt,
+        )
+
+
+class GraphModel(Graph, GraphMeta):
+    """
+    Full graph model representing an existing graph from the database.
+
+    This is the primary model for working with persisted graphs. Includes all
+    graph data (nodes, links, sub_graphs) plus user ownership and timestamps.
+    Provides computed fields (input_schema, output_schema, etc.) used during
+    set-up (frontend) and execution (backend).
+
+    Inherits from:
+    - `Graph`: provides structure (nodes, links, sub_graphs) and computed schemas
+    - `GraphMeta`: provides user_id, created_at for database records
+    """
+
+    nodes: list[NodeModel] = Field(default_factory=list)  # type: ignore
+
+    @property
+    def starting_nodes(self) -> list[NodeModel]:
+        outbound_nodes = {link.sink_id for link in self.links}
+        input_nodes = {
+            node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
+        }
+        return [
+            node
+            for node in self.nodes
+            if node.id not in outbound_nodes or node.id in input_nodes
+        ]
+
+    @property
+    def webhook_input_node(self) -> NodeModel | None:  # type: ignore
+        return cast(NodeModel, super().webhook_input_node)
 
     @computed_field
     @property
     def credentials_input_schema(self) -> dict[str, Any]:
-        schema = self._credentials_input_schema.jsonschema()
-
-        # Determine which credential fields are required based on credentials_optional metadata
         graph_credentials_inputs = self.aggregate_credentials_inputs()
-        required_fields = []
-
-        # Build a map of node_id -> node for quick lookup
-        all_nodes = {node.id: node for node in self.nodes}
-        for sub_graph in self.sub_graphs:
-            for node in sub_graph.nodes:
-                all_nodes[node.id] = node
-
-        for field_key, (
-            _field_info,
-            node_field_pairs,
-        ) in graph_credentials_inputs.items():
-            # A field is required if ANY node using it has credentials_optional=False
-            is_required = False
-            for node_id, _field_name in node_field_pairs:
-                node = all_nodes.get(node_id)
-                if node and not node.credentials_optional:
-                    is_required = True
-                    break
-
-            if is_required:
-                required_fields.append(field_key)
-
-        schema["required"] = required_fields
-        return schema
-
-    @property
-    def _credentials_input_schema(self) -> type[BlockSchema]:
-        graph_credentials_inputs = self.aggregate_credentials_inputs()
         logger.debug(
             f"Combined credentials input fields for graph #{self.id} ({self.name}): "
             f"{graph_credentials_inputs}"
@@ -406,8 +456,8 @@ class Graph(BaseGraph):
 
         # Warn if same-provider credentials inputs can't be combined (= bad UX)
         graph_cred_fields = list(graph_credentials_inputs.values())
-        for i, (field, keys) in enumerate(graph_cred_fields):
-            for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
+        for i, (field, keys, _) in enumerate(graph_cred_fields):
+            for other_field, other_keys, _ in list(graph_cred_fields)[i + 1 :]:
                 if field.provider != other_field.provider:
                     continue
                 if ProviderName.HTTP in field.provider:
@@ -423,31 +473,78 @@ class Graph(BaseGraph):
                     f"keys: {keys} <> {other_keys}."
                 )

-        fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = {
-            agg_field_key: (
-                CredentialsMetaInput[
-                    Literal[tuple(field_info.provider)],  # type: ignore
-                    Literal[tuple(field_info.supported_types)],  # type: ignore
-                ],
-                CredentialsField(
-                    required_scopes=set(field_info.required_scopes or []),
-                    discriminator=field_info.discriminator,
-                    discriminator_mapping=field_info.discriminator_mapping,
-                    discriminator_values=field_info.discriminator_values,
-                ),
-            )
-            for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
-        }
-
-        return create_model(
-            self.name.replace(" ", "") + "CredentialsInputSchema",
-            __base__=BlockSchema,
-            **fields,  # type: ignore
-        )
+        # Build JSON schema directly to avoid expensive create_model + validation overhead
+        properties = {}
+        required_fields = []
+
+        for agg_field_key, (
+            field_info,
+            _,
+            is_required,
+        ) in graph_credentials_inputs.items():
+            providers = list(field_info.provider)
+            cred_types = list(field_info.supported_types)
+
+            field_schema: dict[str, Any] = {
+                "credentials_provider": providers,
+                "credentials_types": cred_types,
+                "type": "object",
+                "properties": {
+                    "id": {"title": "Id", "type": "string"},
+                    "title": {
+                        "anyOf": [{"type": "string"}, {"type": "null"}],
+                        "default": None,
+                        "title": "Title",
+                    },
+                    "provider": {
+                        "title": "Provider",
+                        "type": "string",
+                        **(
+                            {"enum": providers}
+                            if len(providers) > 1
+                            else {"const": providers[0]}
+                        ),
+                    },
+                    "type": {
+                        "title": "Type",
+                        "type": "string",
+                        **(
+                            {"enum": cred_types}
+                            if len(cred_types) > 1
+                            else {"const": cred_types[0]}
+                        ),
+                    },
+                },
+                "required": ["id", "provider", "type"],
+            }
+
+            # Add other (optional) field info items
+            field_schema.update(
+                field_info.model_dump(
+                    by_alias=True,
+                    exclude_defaults=True,
+                    exclude={"provider", "supported_types"},  # already included above
+                )
+            )
+
+            # Ensure field schema is well-formed
+            CredentialsMetaInput.validate_credentials_field_schema(
+                field_schema, agg_field_key
+            )
+
+            properties[agg_field_key] = field_schema
+            if is_required:
+                required_fields.append(agg_field_key)
+
+        return {
+            "type": "object",
+            "properties": properties,
+            "required": required_fields,
+        }

     def aggregate_credentials_inputs(
         self,
-    ) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]:
+    ) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
         """
         Returns:
             dict[aggregated_field_key, tuple(
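The `**({"enum": ...} if len(...) > 1 else {"const": ...})` spread above is what keeps single-value `provider`/`type` sub-schemas strict while multi-value ones stay permissive. A minimal sketch of just that selection rule (the helper name and provider values are made up for illustration):

def choice_schema(values: list[str]) -> dict:
    # JSON Schema convention: one allowed value -> "const", several -> "enum"
    schema = {"type": "string"}
    schema.update({"enum": values} if len(values) > 1 else {"const": values[0]})
    return schema

assert choice_schema(["github"]) == {"type": "string", "const": "github"}
assert choice_schema(["github", "google"]) == {
    "type": "string",
    "enum": ["github", "google"],
}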
@@ -455,13 +552,19 @@ class Graph(BaseGraph):
                 (now includes discriminator_values from matching nodes)
                 set[(node_id, field_name)]: Node credentials fields that are
                     compatible with this aggregated field spec
+                bool: True if the field is required (any node has credentials_optional=False)
             )]
         """
         # First collect all credential field data with input defaults
-        node_credential_data = []
+        # Track (field_info, (node_id, field_name), is_required) for each credential field
+        node_credential_data: list[tuple[CredentialsFieldInfo, tuple[str, str]]] = []
+        node_required_map: dict[str, bool] = {}  # node_id -> is_required
+
         for graph in [self] + self.sub_graphs:
             for node in graph.nodes:
+                # Track if this node requires credentials (credentials_optional=False means required)
+                node_required_map[node.id] = not node.credentials_optional
+
                 for (
                     field_name,
                     field_info,
@@ -485,37 +588,21 @@ class Graph(BaseGraph):
                 )

         # Combine credential field info (this will merge discriminator_values automatically)
-        return CredentialsFieldInfo.combine(*node_credential_data)
-
-
-class GraphModel(Graph):
-    user_id: str
-    nodes: list[NodeModel] = []  # type: ignore
-
-    created_at: datetime
-
-    @property
-    def starting_nodes(self) -> list[NodeModel]:
-        outbound_nodes = {link.sink_id for link in self.links}
-        input_nodes = {
-            node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
-        }
-        return [
-            node
-            for node in self.nodes
-            if node.id not in outbound_nodes or node.id in input_nodes
-        ]
-
-    @property
-    def webhook_input_node(self) -> NodeModel | None:  # type: ignore
-        return cast(NodeModel, super().webhook_input_node)
-
-    def meta(self) -> "GraphMeta":
-        """
-        Returns a GraphMeta object with metadata about the graph.
-        This is used to return metadata about the graph without exposing nodes and links.
-        """
-        return GraphMeta.from_graph(self)
+        combined = CredentialsFieldInfo.combine(*node_credential_data)
+
+        # Add is_required flag to each aggregated field
+        # A field is required if ANY node using it has credentials_optional=False
+        return {
+            key: (
+                field_info,
+                node_field_pairs,
+                any(
+                    node_required_map.get(node_id, True)
+                    for node_id, _ in node_field_pairs
+                ),
+            )
+            for key, (field_info, node_field_pairs) in combined.items()
+        }

     def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
         """
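To make the new third tuple element concrete: an aggregated credentials field stays optional only when every node consuming it has credentials_optional=True, and nodes missing from the map default to required. A short sketch of that rule with made-up node IDs (not taken from the diff):

node_required_map = {"n1": False, "n2": True}  # node_id -> requires credentials?
node_field_pairs = {("n1", "credentials"), ("n2", "credentials")}

is_required = any(
    node_required_map.get(node_id, True)  # unknown nodes count as required
    for node_id, _ in node_field_pairs
)
assert is_required  # "n2" still requires credentials, so the field is required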
@@ -799,13 +886,14 @@ class GraphModel(Graph):
         if is_static_output_block(link.source_id):
             link.is_static = True  # Each value block output should be static.

-    @staticmethod
-    def from_db(
+    @classmethod
+    def from_db(  # type: ignore[reportIncompatibleMethodOverride]
+        cls,
         graph: AgentGraph,
         for_export: bool = False,
         sub_graphs: list[AgentGraph] | None = None,
-    ) -> "GraphModel":
-        return GraphModel(
+    ) -> Self:
+        return cls(
             id=graph.id,
             user_id=graph.userId if not for_export else "",
             version=graph.version,
@@ -831,17 +919,28 @@ class GraphModel(Graph):
             ],
         )

+    def hide_nodes(self) -> "GraphModelWithoutNodes":
+        """
+        Returns a copy of the `GraphModel` with nodes, links, and sub-graphs hidden
+        (excluded from serialization). They are still present in the model instance
+        so all computed fields (e.g. `credentials_input_schema`) still work.
+        """
+        return GraphModelWithoutNodes.model_validate(self, from_attributes=True)

-class GraphMeta(Graph):
-    user_id: str
-
-    # Easy work-around to prevent exposing nodes and links in the API response
-    nodes: list[NodeModel] = Field(default=[], exclude=True)  # type: ignore
-    links: list[Link] = Field(default=[], exclude=True)
-
-    @staticmethod
-    def from_graph(graph: GraphModel) -> "GraphMeta":
-        return GraphMeta(**graph.model_dump())
+
+class GraphModelWithoutNodes(GraphModel):
+    """
+    GraphModel variant that excludes nodes, links, and sub-graphs from serialization.
+
+    Used in contexts like the store where exposing internal graph structure
+    is not desired. Inherits all computed fields from GraphModel but marks
+    nodes and links as excluded from JSON output.
+    """
+
+    nodes: list[NodeModel] = Field(default_factory=list, exclude=True)
+    links: list[Link] = Field(default_factory=list, exclude=True)
+
+    sub_graphs: list[BaseGraph] = Field(default_factory=list, exclude=True)


 class GraphsPaginated(BaseModel):
@@ -912,21 +1011,11 @@ async def list_graphs_paginated(
         where=where_clause,
         distinct=["id"],
         order={"version": "desc"},
-        include=AGENT_GRAPH_INCLUDE,
         skip=offset,
         take=page_size,
     )

-    graph_models: list[GraphMeta] = []
-    for graph in graphs:
-        try:
-            graph_meta = GraphModel.from_db(graph).meta()
-            # Trigger serialization to validate that the graph is well formed
-            graph_meta.model_dump()
-            graph_models.append(graph_meta)
-        except Exception as e:
-            logger.error(f"Error processing graph {graph.id}: {e}")
-            continue
+    graph_models = [GraphMeta.from_db(graph) for graph in graphs]

     return GraphsPaginated(
         graphs=graph_models,
@@ -163,7 +163,6 @@ class User(BaseModel):
 if TYPE_CHECKING:
     from prisma.models import User as PrismaUser

-    from backend.data.block import BlockSchema

 T = TypeVar("T")
 logger = logging.getLogger(__name__)
@@ -508,15 +507,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
     def allowed_cred_types(cls) -> tuple[CredentialsType, ...]:
         return get_args(cls.model_fields["type"].annotation)

-    @classmethod
-    def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
+    @staticmethod
+    def validate_credentials_field_schema(
+        field_schema: dict[str, Any], field_name: str
+    ):
         """Validates the schema of a credentials input field"""
-        field_name = next(
-            name for name, type in model.get_credentials_fields().items() if type is cls
-        )
-        field_schema = model.jsonschema()["properties"][field_name]
         try:
-            schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
+            field_info = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
         except ValidationError as e:
             if "Field required [type=missing" not in str(e):
                 raise
@@ -526,11 +523,11 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
                 f"{field_schema}"
             ) from e

-        providers = cls.allowed_providers()
+        providers = field_info.provider
         if (
             providers is not None
             and len(providers) > 1
-            and not schema_extra.discriminator
+            and not field_info.discriminator
         ):
             raise TypeError(
                 f"Multi-provider CredentialsField '{field_name}' "
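Because the validator is now a static function over a plain schema dict, it can be exercised without building a BlockSchema subclass. A sketch of the minimal dict it accepts, mirroring the builder in `credentials_input_schema` above (all values are illustrative; a well-formed single-provider field should pass):

field_schema = {
    "credentials_provider": ["github"],
    "credentials_types": ["api_key"],
    "type": "object",
    "properties": {
        "id": {"title": "Id", "type": "string"},
        "provider": {"title": "Provider", "type": "string", "const": "github"},
        "type": {"title": "Type", "type": "string", "const": "api_key"},
    },
    "required": ["id", "provider", "type"],
}
# Raises TypeError for a multi-provider field without a discriminator; passes here.
CredentialsMetaInput.validate_credentials_field_schema(field_schema, "credentials")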
@@ -18,6 +18,7 @@ from redis.asyncio.lock import Lock as AsyncRedisLock

 from backend.blocks.agent import AgentExecutorBlock
 from backend.blocks.io import AgentOutputBlock
+from backend.blocks.mcp.block import MCPToolBlock
 from backend.data import redis_client as redis
 from backend.data.block import (
     BlockInput,
@@ -229,6 +230,10 @@ async def execute_node(
         _input_data.nodes_input_masks = nodes_input_masks
         _input_data.user_id = user_id
         input_data = _input_data.model_dump()
+    elif isinstance(node_block, MCPToolBlock):
+        _mcp_data = MCPToolBlock.Input(**node.input_default)
+        _mcp_data.tool_arguments = input_data
+        input_data = _mcp_data.model_dump()
     data.inputs = input_data

     # Execute the node
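The new `elif` branch means an MCP node's wired-in inputs are the selected tool's parameters, which get folded back into the block's `tool_arguments` field on top of the preconfigured server settings from `input_default`. A sketch with a stub input model (the stub class and its exact field set are assumptions; only `tool_arguments`, `server_url`, and `selected_tool` appear in this diff):

from pydantic import BaseModel


class StubMCPInput(BaseModel):  # stand-in for MCPToolBlock.Input (assumed shape)
    server_url: str = ""
    selected_tool: str = ""
    tool_arguments: dict = {}


input_default = {"server_url": "https://mcp.example.com", "selected_tool": "search"}
resolved_inputs = {"query": "autogpt", "limit": 5}  # values wired into the node

_mcp_data = StubMCPInput(**input_default)
_mcp_data.tool_arguments = resolved_inputs
assert _mcp_data.model_dump()["tool_arguments"] == {"query": "autogpt", "limit": 5}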
@@ -373,7 +373,7 @@ def make_node_credentials_input_map(
     # Get aggregated credentials fields for the graph
     graph_cred_inputs = graph.aggregate_credentials_inputs()

-    for graph_input_name, (_, compatible_node_fields) in graph_cred_inputs.items():
+    for graph_input_name, (_, compatible_node_fields, _) in graph_cred_inputs.items():
         # Best-effort map: skip missing items
         if graph_input_name not in graph_credentials_input:
             continue
@@ -224,6 +224,14 @@ openweathermap_credentials = APIKeyCredentials(
     expires_at=None,
 )

+elevenlabs_credentials = APIKeyCredentials(
+    id="f4a8b6c2-3d1e-4f5a-9b8c-7d6e5f4a3b2c",
+    provider="elevenlabs",
+    api_key=SecretStr(settings.secrets.elevenlabs_api_key),
+    title="Use Credits for ElevenLabs",
+    expires_at=None,
+)
+
 DEFAULT_CREDENTIALS = [
     ollama_credentials,
     revid_credentials,
@@ -252,6 +260,7 @@ DEFAULT_CREDENTIALS = [
     v0_credentials,
     webshare_proxy_credentials,
     openweathermap_credentials,
+    elevenlabs_credentials,
 ]

 SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS}
@@ -366,6 +375,8 @@ class IntegrationCredentialsStore:
             all_credentials.append(webshare_proxy_credentials)
         if settings.secrets.openweathermap_api_key:
             all_credentials.append(openweathermap_credentials)
+        if settings.secrets.elevenlabs_api_key:
+            all_credentials.append(elevenlabs_credentials)
         return all_credentials

     async def get_creds_by_id(
@@ -18,6 +18,7 @@ class ProviderName(str, Enum):
     DISCORD = "discord"
     D_ID = "d_id"
     E2B = "e2b"
+    ELEVENLABS = "elevenlabs"
     FAL = "fal"
     GITHUB = "github"
     GOOGLE = "google"
@@ -29,6 +30,7 @@ class ProviderName(str, Enum):
     IDEOGRAM = "ideogram"
     JINA = "jina"
     LLAMA_API = "llama_api"
+    MCP = "mcp"
     MEDIUM = "medium"
     MEM0 = "mem0"
     NOTION = "notion"
@@ -8,6 +8,8 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Literal
 from urllib.parse import urlparse

+from pydantic import BaseModel
+
 from backend.util.cloud_storage import get_cloud_storage_handler
 from backend.util.request import Requests
 from backend.util.settings import Config
@@ -17,6 +19,35 @@ from backend.util.virus_scanner import scan_content_safe
 if TYPE_CHECKING:
     from backend.data.execution import ExecutionContext


+class WorkspaceUri(BaseModel):
+    """Parsed workspace:// URI."""
+
+    file_ref: str  # File ID or path (e.g. "abc123" or "/path/to/file.txt")
+    mime_type: str | None = None  # MIME type from fragment (e.g. "video/mp4")
+    is_path: bool = False  # True if file_ref is a path (starts with "/")
+
+
+def parse_workspace_uri(uri: str) -> WorkspaceUri:
+    """Parse a workspace:// URI into its components.
+
+    Examples:
+        "workspace://abc123" → WorkspaceUri(file_ref="abc123", mime_type=None, is_path=False)
+        "workspace://abc123#video/mp4" → WorkspaceUri(file_ref="abc123", mime_type="video/mp4", is_path=False)
+        "workspace:///path/to/file.txt" → WorkspaceUri(file_ref="/path/to/file.txt", mime_type=None, is_path=True)
+    """
+    raw = uri.removeprefix("workspace://")
+    mime_type: str | None = None
+    if "#" in raw:
+        raw, fragment = raw.split("#", 1)
+        mime_type = fragment or None
+    return WorkspaceUri(
+        file_ref=raw,
+        mime_type=mime_type,
+        is_path=raw.startswith("/"),
+    )
+
+
 # Return format options for store_media_file
 # - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
 # - "for_external_api": Returns data URI (base64) - use when sending content to external APIs
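Since the MIME type rides in the URI fragment, callers can resolve a reference without a workspace lookup. Restating the docstring's examples as executable checks:

ws = parse_workspace_uri("workspace://abc123#video/mp4")
assert (ws.file_ref, ws.mime_type, ws.is_path) == ("abc123", "video/mp4", False)

ws = parse_workspace_uri("workspace:///path/to/file.txt")
assert (ws.file_ref, ws.mime_type, ws.is_path) == ("/path/to/file.txt", None, True)

# An empty fragment is normalized to "no MIME type" by the `fragment or None` guard:
assert parse_workspace_uri("workspace://abc123#").mime_type is None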
@@ -183,22 +214,20 @@ async def store_media_file(
                 "This file type is only available in CoPilot sessions."
             )

-        # Parse workspace reference
-        # workspace://abc123 - by file ID
-        # workspace:///path/to/file.txt - by virtual path
-        file_ref = file[12:]  # Remove "workspace://"
-
-        if file_ref.startswith("/"):
-            # Path reference
-            workspace_content = await workspace_manager.read_file(file_ref)
-            file_info = await workspace_manager.get_file_info_by_path(file_ref)
+        # Parse workspace reference (strips #mimeType fragment from file ID)
+        ws = parse_workspace_uri(file)
+
+        if ws.is_path:
+            # Path reference: workspace:///path/to/file.txt
+            workspace_content = await workspace_manager.read_file(ws.file_ref)
+            file_info = await workspace_manager.get_file_info_by_path(ws.file_ref)
             filename = sanitize_filename(
                 file_info.name if file_info else f"{uuid.uuid4()}.bin"
             )
         else:
-            # ID reference
-            workspace_content = await workspace_manager.read_file_by_id(file_ref)
-            file_info = await workspace_manager.get_file_info(file_ref)
+            # ID reference: workspace://abc123 or workspace://abc123#video/mp4
+            workspace_content = await workspace_manager.read_file_by_id(ws.file_ref)
+            file_info = await workspace_manager.get_file_info(ws.file_ref)
             filename = sanitize_filename(
                 file_info.name if file_info else f"{uuid.uuid4()}.bin"
             )
@@ -334,7 +363,21 @@ async def store_media_file(

         # Don't re-save if input was already from workspace
         if is_from_workspace:
-            # Return original workspace reference
+            # Return original workspace reference, ensuring MIME type fragment
+            ws = parse_workspace_uri(file)
+            if not ws.mime_type:
+                # Add MIME type fragment if missing (older refs without it)
+                try:
+                    if ws.is_path:
+                        info = await workspace_manager.get_file_info_by_path(
+                            ws.file_ref
+                        )
+                    else:
+                        info = await workspace_manager.get_file_info(ws.file_ref)
+                    if info:
+                        return MediaFileType(f"{file}#{info.mimeType}")
+                except Exception:
+                    pass
             return MediaFileType(file)

         # Save new content to workspace
@@ -346,7 +389,7 @@ async def store_media_file(
             filename=filename,
             overwrite=True,
         )
-        return MediaFileType(f"workspace://{file_record.id}")
+        return MediaFileType(f"workspace://{file_record.id}#{file_record.mimeType}")

     else:
         raise ValueError(f"Invalid return_format: {return_format}")
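Net effect: every reference returned by store_media_file is now self-describing, so downstream consumers can recover the MIME type from the fragment alone and only fall back to a workspace lookup for pre-existing refs. A small round-trip sketch (the file ID and type are illustrative):

stored_ref = "workspace://abc123#video/mp4"  # shape returned by store_media_file

ws = parse_workspace_uri(stored_ref)
mime = ws.mime_type or "application/octet-stream"  # fallback for older refs
assert (ws.file_ref, mime) == ("abc123", "video/mp4")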
@@ -656,6 +656,7 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
     e2b_api_key: str = Field(default="", description="E2B API key")
     nvidia_api_key: str = Field(default="", description="Nvidia API key")
     mem0_api_key: str = Field(default="", description="Mem0 API key")
+    elevenlabs_api_key: str = Field(default="", description="ElevenLabs API key")

     linear_client_id: str = Field(default="", description="Linear client ID")
     linear_client_secret: str = Field(default="", description="Linear client secret")
@@ -22,6 +22,7 @@ from backend.data.workspace import (
     soft_delete_workspace_file,
 )
 from backend.util.settings import Config
+from backend.util.virus_scanner import scan_content_safe
 from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage

 logger = logging.getLogger(__name__)
@@ -187,6 +188,9 @@ class WorkspaceManager:
                 f"{Config().max_file_size_mb}MB limit"
             )

+        # Virus scan content before persisting (defense in depth)
+        await scan_content_safe(content, filename=filename)
+
         # Determine path with session scoping
         if path is None:
             path = f"/{filename}"
6890
autogpt_platform/backend/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -20,7 +20,8 @@ click = "^8.2.0"
 cryptography = "^45.0"
 discord-py = "^2.5.2"
 e2b-code-interpreter = "^1.5.2"
-fastapi = "^0.116.1"
+elevenlabs = "^1.50.0"
+fastapi = "^0.128.0"
 feedparser = "^6.0.11"
 flake8 = "^7.3.0"
 google-api-python-client = "^2.177.0"
@@ -34,7 +35,7 @@ jinja2 = "^3.1.6"
 jsonref = "^1.1.0"
 jsonschema = "^4.25.0"
 langfuse = "^3.11.0"
-launchdarkly-server-sdk = "^9.12.0"
+launchdarkly-server-sdk = "^9.14.1"
 mem0ai = "^0.1.115"
 moviepy = "^2.1.2"
 ollama = "^0.5.1"
@@ -51,8 +52,8 @@ prometheus-client = "^0.22.1"
 prometheus-fastapi-instrumentator = "^7.0.0"
 psutil = "^7.0.0"
 psycopg2-binary = "^2.9.10"
-pydantic = { extras = ["email"], version = "^2.11.7" }
-pydantic-settings = "^2.10.1"
+pydantic = { extras = ["email"], version = "^2.12.5" }
+pydantic-settings = "^2.12.0"
 pytest = "^8.4.1"
 pytest-asyncio = "^1.1.0"
 python-dotenv = "^1.1.1"
@@ -64,13 +65,14 @@ sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlal
 sqlalchemy = "^2.0.40"
 strenum = "^0.4.9"
 stripe = "^11.5.0"
-supabase = "2.17.0"
+supabase = "2.27.2"
 tenacity = "^9.1.2"
 todoist-api-python = "^2.1.7"
 tweepy = "^4.16.0"
-uvicorn = { extras = ["standard"], version = "^0.35.0" }
+uvicorn = { extras = ["standard"], version = "^0.40.0" }
 websockets = "^15.0"
 youtube-transcript-api = "^1.2.1"
+yt-dlp = "2025.12.08"
 zerobouncesdk = "^1.1.2"
 # NOTE: please insert new dependencies in their alphabetical location
 pytest-snapshot = "^0.9.0"
@@ -3,7 +3,6 @@
     "credentials_input_schema": {
       "properties": {},
       "required": [],
-      "title": "TestGraphCredentialsInputSchema",
       "type": "object"
     },
     "description": "A test graph",
@@ -1,34 +1,14 @@
 [
   {
-    "credentials_input_schema": {
-      "properties": {},
-      "required": [],
-      "title": "TestGraphCredentialsInputSchema",
-      "type": "object"
-    },
+    "created_at": "2025-09-04T13:37:00",
     "description": "A test graph",
     "forked_from_id": null,
     "forked_from_version": null,
-    "has_external_trigger": false,
-    "has_human_in_the_loop": false,
-    "has_sensitive_action": false,
     "id": "graph-123",
-    "input_schema": {
-      "properties": {},
-      "required": [],
-      "type": "object"
-    },
     "instructions": null,
     "is_active": true,
     "name": "Test Graph",
-    "output_schema": {
-      "properties": {},
-      "required": [],
-      "type": "object"
-    },
     "recommended_schedule_cron": null,
-    "sub_graphs": [],
-    "trigger_setup_info": null,
     "user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
     "version": 1
   }
@@ -1,5 +1,5 @@
 import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
-import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
+import { GraphModel } from "@/app/api/__generated__/models/graphModel";
 import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput";
 import { useState } from "react";
 import { getSchemaDefaultCredentials } from "../../helpers";
@@ -9,7 +9,7 @@ type Credential = CredentialsMetaInput | undefined;
 type Credentials = Record<string, Credential>;

 type Props = {
-  agent: GraphMeta | null;
+  agent: GraphModel | null;
   siblingInputs?: Record<string, any>;
   onCredentialsChange: (
     credentials: Record<string, CredentialsMetaInput>,
@@ -1,9 +1,9 @@
 import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
-import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
+import { GraphModel } from "@/app/api/__generated__/models/graphModel";
 import { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api/types";

 export function getCredentialFields(
-  agent: GraphMeta | null,
+  agent: GraphModel | null,
 ): AgentCredentialsFields {
   if (!agent) return {};
@@ -3,10 +3,10 @@ import type {
   CredentialsMetaInput,
 } from "@/lib/autogpt-server-api/types";
 import type { InputValues } from "./types";
-import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
+import { GraphModel } from "@/app/api/__generated__/models/graphModel";

 export function computeInitialAgentInputs(
-  agent: GraphMeta | null,
+  agent: GraphModel | null,
   existingInputs?: InputValues | null,
 ): InputValues {
   const properties = agent?.input_schema?.properties || {};
@@ -29,7 +29,7 @@ export function computeInitialAgentInputs(
 }

 type IsRunDisabledParams = {
-  agent: GraphMeta | null;
+  agent: GraphModel | null;
   isRunning: boolean;
   agentInputs: InputValues | null | undefined;
 };
@@ -0,0 +1,90 @@
+import { NextResponse } from "next/server";
+
+/**
+ * Safely encode a value as JSON for embedding in a script tag.
+ * Escapes characters that could break out of the script context to prevent XSS.
+ */
+function safeJsonStringify(value: unknown): string {
+  return JSON.stringify(value)
+    .replace(/</g, "\\u003c")
+    .replace(/>/g, "\\u003e")
+    .replace(/&/g, "\\u0026");
+}
+
+// MCP-specific OAuth callback route.
+//
+// Unlike the generic oauth_callback which relies on window.opener.postMessage,
+// this route uses BroadcastChannel as the PRIMARY communication method.
+// This is critical because cross-origin OAuth flows (e.g. Sentry → localhost)
+// often lose window.opener due to COOP (Cross-Origin-Opener-Policy) headers.
+//
+// BroadcastChannel works across all same-origin tabs/popups regardless of opener.
+export async function GET(request: Request) {
+  const { searchParams } = new URL(request.url);
+  const code = searchParams.get("code");
+  const state = searchParams.get("state");
+
+  const success = Boolean(code && state);
+  const message = success
+    ? { success: true, code, state }
+    : {
+        success: false,
+        message: `Missing parameters: ${searchParams.toString()}`,
+      };
+
+  return new NextResponse(
+    `<!DOCTYPE html>
+<html>
+<head><title>MCP Sign-in</title></head>
+<body style="font-family: system-ui, -apple-system, sans-serif; display: flex; align-items: center; justify-content: center; min-height: 100vh; margin: 0; background: #f9fafb;">
+  <div style="text-align: center; max-width: 400px; padding: 2rem;">
+    <div id="spinner" style="margin: 0 auto 1rem; width: 32px; height: 32px; border: 3px solid #e5e7eb; border-top-color: #3b82f6; border-radius: 50%; animation: spin 0.8s linear infinite;"></div>
+    <p id="status" style="color: #374151; font-size: 16px;">Completing sign-in...</p>
+  </div>
+  <style>@keyframes spin { to { transform: rotate(360deg); } }</style>
+  <script>
+    (function() {
+      var msg = ${safeJsonStringify(message)};
+      var sent = false;
+
+      // Method 1: BroadcastChannel (reliable across tabs/popups, no opener needed)
+      try {
+        var bc = new BroadcastChannel("mcp_oauth");
+        bc.postMessage({ type: "mcp_oauth_result", success: msg.success, code: msg.code, state: msg.state, message: msg.message });
+        bc.close();
+        sent = true;
+      } catch(e) { console.warn("BroadcastChannel failed:", e); }
+
+      // Method 2: window.opener.postMessage (fallback for same-origin popups)
+      try {
+        if (window.opener && !window.opener.closed) {
+          window.opener.postMessage(
+            { message_type: "mcp_oauth_result", success: msg.success, code: msg.code, state: msg.state, message: msg.message },
+            window.location.origin
+          );
+          sent = true;
+        }
+      } catch(e) { console.warn("postMessage failed:", e); }
+
+      var statusEl = document.getElementById("status");
+      var spinnerEl = document.getElementById("spinner");
+      spinnerEl.style.display = "none";
+
+      if (msg.success && sent) {
+        statusEl.textContent = "Sign-in complete! This window will close.";
+        statusEl.style.color = "#059669";
+        setTimeout(function() { window.close(); }, 1500);
+      } else if (msg.success) {
+        statusEl.textContent = "Sign-in successful! You can close this tab and return to the builder.";
+        statusEl.style.color = "#059669";
+      } else {
+        statusEl.textContent = "Sign-in failed: " + (msg.message || "Unknown error");
+        statusEl.style.color = "#dc2626";
+      }
+    })();
+  </script>
+</body>
+</html>`,
+    { headers: { "Content-Type": "text/html" } },
+  );
+}
@@ -47,7 +47,10 @@ export type CustomNode = XYNode<CustomNodeData, "custom">;

 export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
   ({ data, id: nodeId, selected }) => {
-    const { inputSchema, outputSchema } = useCustomNode({ data, nodeId });
+    const { inputSchema, outputSchema, isMCPWithTool } = useCustomNode({
+      data,
+      nodeId,
+    });

     const isAgent = data.uiType === BlockUIType.AGENT;

@@ -98,6 +101,7 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
           jsonSchema={preprocessInputSchema(inputSchema)}
           nodeId={nodeId}
           uiType={data.uiType}
+          isMCPWithTool={isMCPWithTool}
           className={cn(
             "bg-white px-4",
             isWebhook && "pointer-events-none opacity-50",
@@ -6,6 +6,7 @@ import {
   TooltipProvider,
   TooltipTrigger,
 } from "@/components/atoms/Tooltip/BaseTooltip";
+import { SpecialBlockID } from "@/lib/autogpt-server-api";
 import { beautifyString, cn } from "@/lib/utils";
 import { useState } from "react";
 import { CustomNodeData } from "../CustomNode";
@@ -20,8 +21,15 @@ type Props = {

 export const NodeHeader = ({ data, nodeId }: Props) => {
   const updateNodeData = useNodeStore((state) => state.updateNodeData);
+  const isMCPWithTool =
+    data.block_id === SpecialBlockID.MCP_TOOL &&
+    !!data.hardcodedValues?.selected_tool;
+
   const title =
     (data.metadata?.customized_name as string) ||
+    (isMCPWithTool
+      ? `${data.hardcodedValues.server_name || "MCP"}: ${beautifyString(data.hardcodedValues.selected_tool)}`
+      : null) ||
     data.hardcodedValues?.agent_name ||
     data.title;

@@ -3,6 +3,30 @@ import { CustomNodeData } from "./CustomNode";
 import { BlockUIType } from "../../../types";
 import { useMemo } from "react";
 import { mergeSchemaForResolution } from "./helpers";
+import { SpecialBlockID } from "@/lib/autogpt-server-api";
+
+/**
+ * Build a dynamic input schema for MCP blocks.
+ *
+ * When a tool has been selected (tool_input_schema is populated), the block
+ * renders only the selected tool's input parameters. Credentials are NOT
+ * included because authentication is already handled by the MCP dialog's
+ * OAuth flow and stored server-side.
+ *
+ * Static fields like server_url, selected_tool, available_tools, and
+ * tool_arguments are hidden because they're pre-configured from the dialog.
+ */
+function buildMCPInputSchema(
+  toolInputSchema: Record<string, any>,
+): Record<string, any> {
+  return {
+    type: "object",
+    properties: {
+      ...(toolInputSchema.properties ?? {}),
+    },
+    required: [...(toolInputSchema.required ?? [])],
+  };
+}

 export const useCustomNode = ({
   data,
@@ -19,10 +43,15 @@ export const useCustomNode = ({
   );

   const isAgent = data.uiType === BlockUIType.AGENT;
+  const isMCPWithTool =
+    data.block_id === SpecialBlockID.MCP_TOOL &&
+    !!data.hardcodedValues?.tool_input_schema?.properties;

   const currentInputSchema = isAgent
     ? (data.hardcodedValues.input_schema ?? {})
-    : data.inputSchema;
+    : isMCPWithTool
+      ? buildMCPInputSchema(data.hardcodedValues.tool_input_schema)
+      : data.inputSchema;
   const currentOutputSchema = isAgent
     ? (data.hardcodedValues.output_schema ?? {})
     : data.outputSchema;
@@ -54,5 +83,6 @@
   return {
     inputSchema,
     outputSchema,
+    isMCPWithTool,
   };
 };
@@ -9,39 +9,63 @@ interface FormCreatorProps {
   jsonSchema: RJSFSchema;
   nodeId: string;
   uiType: BlockUIType;
+  /** When true the block is an MCP Tool with a selected tool. */
+  isMCPWithTool?: boolean;
   showHandles?: boolean;
   className?: string;
 }

 export const FormCreator: React.FC<FormCreatorProps> = React.memo(
-  ({ jsonSchema, nodeId, uiType, showHandles = true, className }) => {
+  ({
+    jsonSchema,
+    nodeId,
+    uiType,
+    isMCPWithTool = false,
+    showHandles = true,
+    className,
+  }) => {
     const updateNodeData = useNodeStore((state) => state.updateNodeData);
     const getHardCodedValues = useNodeStore(
       (state) => state.getHardCodedValues,
     );

+    const isAgent = uiType === BlockUIType.AGENT;
+
     const handleChange = ({ formData }: any) => {
       if ("credentials" in formData && !formData.credentials?.id) {
         delete formData.credentials;
       }

-      const updatedValues =
-        uiType === BlockUIType.AGENT
-          ? {
-              ...getHardCodedValues(nodeId),
-              inputs: formData,
-            }
-          : formData;
+      let updatedValues;
+      if (isAgent) {
+        updatedValues = {
+          ...getHardCodedValues(nodeId),
+          inputs: formData,
+        };
+      } else if (isMCPWithTool) {
+        // All form fields are tool arguments (credentials handled by dialog)
+        updatedValues = {
+          ...getHardCodedValues(nodeId),
+          tool_arguments: formData,
+        };
+      } else {
+        updatedValues = formData;
+      }
+
       updateNodeData(nodeId, { hardcodedValues: updatedValues });
     };

     const hardcodedValues = getHardCodedValues(nodeId);
-    const initialValues =
-      uiType === BlockUIType.AGENT
-        ? (hardcodedValues.inputs ?? {})
-        : hardcodedValues;
+    let initialValues;
+    if (isAgent) {
+      initialValues = hardcodedValues.inputs ?? {};
+    } else if (isMCPWithTool) {
+      initialValues = hardcodedValues.tool_arguments ?? {};
+    } else {
+      initialValues = hardcodedValues;
+    }

     return (
       <div
@@ -1,7 +1,7 @@
 import { Button } from "@/components/__legacy__/ui/button";
 import { Skeleton } from "@/components/__legacy__/ui/skeleton";
 import { beautifyString, cn } from "@/lib/utils";
-import React, { ButtonHTMLAttributes } from "react";
+import React, { ButtonHTMLAttributes, useCallback, useState } from "react";
 import { highlightText } from "./helpers";
 import { PlusIcon } from "@phosphor-icons/react";
 import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
@@ -9,6 +9,12 @@ import { useControlPanelStore } from "../../../stores/controlPanelStore";
 import { blockDragPreviewStyle } from "./style";
 import { useReactFlow } from "@xyflow/react";
 import { useNodeStore } from "../../../stores/nodeStore";
+import { SpecialBlockID } from "@/lib/autogpt-server-api";
+import {
+  MCPToolDialog,
+  type MCPToolDialogResult,
+} from "@/app/(platform)/build/components/legacy-builder/MCPToolDialog";

 interface Props extends ButtonHTMLAttributes<HTMLButtonElement> {
   title?: string;
   description?: string;
@@ -33,22 +39,52 @@ export const Block: BlockComponent = ({
   );
   const { setViewport } = useReactFlow();
   const { addBlock } = useNodeStore();
+  const [mcpDialogOpen, setMcpDialogOpen] = useState(false);
+
+  const isMCPBlock = blockData.id === SpecialBlockID.MCP_TOOL;
+
+  const addBlockAndCenter = useCallback(
+    (block: BlockInfo, hardcodedValues?: Record<string, any>) => {
+      const customNode = addBlock(block, hardcodedValues);
+      setTimeout(() => {
+        setViewport(
+          {
+            x: -customNode.position.x * 0.8 + window.innerWidth / 2,
+            y: -customNode.position.y * 0.8 + (window.innerHeight - 400) / 2,
+            zoom: 0.8,
+          },
+          { duration: 500 },
+        );
+      }, 50);
+    },
+    [addBlock, setViewport],
+  );
+
+  const handleMCPToolConfirm = useCallback(
+    (result: MCPToolDialogResult) => {
+      addBlockAndCenter(blockData, {
+        server_url: result.serverUrl,
+        server_name: result.serverName,
+        selected_tool: result.selectedTool,
+        tool_input_schema: result.toolInputSchema,
+        available_tools: result.availableTools,
+        credential_id: result.credentialId ?? "",
+      });
+      setMcpDialogOpen(false);
+    },
+    [addBlockAndCenter, blockData],
+  );

   const handleClick = () => {
-    const customNode = addBlock(blockData);
-    setTimeout(() => {
-      setViewport(
-        {
-          x: -customNode.position.x * 0.8 + window.innerWidth / 2,
-          y: -customNode.position.y * 0.8 + (window.innerHeight - 400) / 2,
-          zoom: 0.8,
-        },
-        { duration: 500 },
-      );
-    }, 50);
+    if (isMCPBlock) {
+      setMcpDialogOpen(true);
+      return;
+    }
+    addBlockAndCenter(blockData);
   };

   const handleDragStart = (e: React.DragEvent<HTMLButtonElement>) => {
+    if (isMCPBlock) return;
     e.dataTransfer.effectAllowed = "copy";
     e.dataTransfer.setData("application/reactflow", JSON.stringify(blockData));

@@ -71,46 +107,56 @@ export const Block: BlockComponent = ({
     : undefined;

   return (
-    <Button
-      draggable={true}
-      data-id={blockDataId}
-      className={cn(
-        "group flex h-16 w-full min-w-[7.5rem] items-center justify-start space-x-3 whitespace-normal rounded-[0.75rem] bg-zinc-50 px-[0.875rem] py-[0.625rem] text-start shadow-none",
-        "hover:cursor-default hover:bg-zinc-100 focus:ring-0 active:bg-zinc-100 active:ring-1 active:ring-zinc-300 disabled:cursor-not-allowed",
-        className,
-      )}
-      onDragStart={handleDragStart}
-      onClick={handleClick}
-      {...rest}
-    >
-      <div className="flex flex-1 flex-col items-start gap-0.5">
-        {title && (
-          <span
-            className={cn(
-              "line-clamp-1 font-sans text-sm font-medium leading-[1.375rem] text-zinc-800 group-disabled:text-zinc-400",
-            )}
-          >
-            {highlightText(beautifyString(title), highlightedText)}
-          </span>
-        )}
-        {description && (
-          <span
-            className={cn(
-              "line-clamp-1 font-sans text-xs font-normal leading-5 text-zinc-500 group-disabled:text-zinc-400",
-            )}
-          >
-            {highlightText(description, highlightedText)}
-          </span>
-        )}
-      </div>
-      <div
-        className={cn(
-          "flex h-7 w-7 items-center justify-center rounded-[0.5rem] bg-zinc-700 group-disabled:bg-zinc-400",
-        )}
-      >
-        <PlusIcon className="h-5 w-5 text-zinc-50" />
-      </div>
-    </Button>
+    <>
+      <Button
+        draggable={!isMCPBlock}
+        data-id={blockDataId}
+        className={cn(
+          "group flex h-16 w-full min-w-[7.5rem] items-center justify-start space-x-3 whitespace-normal rounded-[0.75rem] bg-zinc-50 px-[0.875rem] py-[0.625rem] text-start shadow-none",
+          "hover:cursor-default hover:bg-zinc-100 focus:ring-0 active:bg-zinc-100 active:ring-1 active:ring-zinc-300 disabled:cursor-not-allowed",
+          isMCPBlock && "hover:cursor-pointer",
+          className,
+        )}
+        onDragStart={handleDragStart}
+        onClick={handleClick}
+        {...rest}
+      >
+        <div className="flex flex-1 flex-col items-start gap-0.5">
+          {title && (
+            <span
+              className={cn(
+                "line-clamp-1 font-sans text-sm font-medium leading-[1.375rem] text-zinc-800 group-disabled:text-zinc-400",
+              )}
+            >
+              {highlightText(beautifyString(title), highlightedText)}
+            </span>
+          )}
+          {description && (
+            <span
+              className={cn(
+                "line-clamp-1 font-sans text-xs font-normal leading-5 text-zinc-500 group-disabled:text-zinc-400",
+              )}
+            >
+              {highlightText(description, highlightedText)}
+            </span>
+          )}
+        </div>
+        <div
+          className={cn(
+            "flex h-7 w-7 items-center justify-center rounded-[0.5rem] bg-zinc-700 group-disabled:bg-zinc-400",
+          )}
+        >
+          <PlusIcon className="h-5 w-5 text-zinc-50" />
+        </div>
+      </Button>
+      {isMCPBlock && (
+        <MCPToolDialog
+          open={mcpDialogOpen}
+          onClose={() => setMcpDialogOpen(false)}
+          onConfirm={handleMCPToolConfirm}
+        />
+      )}
+    </>
   );
 };

@@ -29,7 +29,13 @@ import {
|
|||||||
TooltipTrigger,
|
TooltipTrigger,
|
||||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||||
import { GraphMeta } from "@/lib/autogpt-server-api";
|
import { GraphMeta } from "@/lib/autogpt-server-api";
|
||||||
|
import {
|
||||||
|
MCPToolDialog,
|
||||||
|
type MCPToolDialogResult,
|
||||||
|
} from "@/app/(platform)/build/components/legacy-builder/MCPToolDialog";
|
||||||
import jaro from "jaro-winkler";
|
import jaro from "jaro-winkler";
|
||||||
|
import { getV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||||
|
import { okData } from "@/app/api/helpers";
|
||||||
|
|
||||||
type _Block = Omit<Block, "inputSchema" | "outputSchema"> & {
|
type _Block = Omit<Block, "inputSchema" | "outputSchema"> & {
|
||||||
uiKey?: string;
|
uiKey?: string;
|
||||||
@@ -92,6 +98,7 @@ export function BlocksControl({
|
|||||||
const [searchQuery, setSearchQuery] = useState("");
|
const [searchQuery, setSearchQuery] = useState("");
|
||||||
const deferredSearchQuery = useDeferredValue(searchQuery);
|
const deferredSearchQuery = useDeferredValue(searchQuery);
|
||||||
const [selectedCategory, setSelectedCategory] = useState<string | null>(null);
|
const [selectedCategory, setSelectedCategory] = useState<string | null>(null);
|
||||||
|
const [mcpDialogOpen, setMcpDialogOpen] = useState(false);
|
||||||
|
|
||||||
const blocks = useSearchableBlocks(_blocks);
|
const blocks = useSearchableBlocks(_blocks);
|
||||||
|
|
||||||
@@ -107,6 +114,8 @@ export function BlocksControl({
|
|||||||
.filter((b) => b.uiType !== BlockUIType.AGENT)
|
.filter((b) => b.uiType !== BlockUIType.AGENT)
|
||||||
.sort((a, b) => a.name.localeCompare(b.name));
|
.sort((a, b) => a.name.localeCompare(b.name));
|
||||||
|
|
||||||
|
// Agent blocks are created from GraphMeta which doesn't include schemas.
|
||||||
|
// Schemas will be fetched on-demand when the block is actually added.
|
||||||
const agentBlockList = flows
|
const agentBlockList = flows
|
||||||
.map((flow): _Block => {
|
.map((flow): _Block => {
|
||||||
return {
|
return {
|
||||||
@@ -116,8 +125,9 @@ export function BlocksControl({
|
|||||||
`Ver.${flow.version}` +
|
`Ver.${flow.version}` +
|
||||||
(flow.description ? ` | ${flow.description}` : ""),
|
(flow.description ? ` | ${flow.description}` : ""),
|
||||||
categories: [{ category: "AGENT", description: "" }],
|
categories: [{ category: "AGENT", description: "" }],
|
||||||
inputSchema: flow.input_schema,
|
// Empty schemas - will be populated when block is added
|
||||||
outputSchema: flow.output_schema,
|
inputSchema: { type: "object", properties: {} },
|
||||||
|
outputSchema: { type: "object", properties: {} },
|
||||||
staticOutput: false,
|
staticOutput: false,
|
||||||
uiType: BlockUIType.AGENT,
|
uiType: BlockUIType.AGENT,
|
||||||
costs: [],
|
costs: [],
|
||||||
@@ -125,8 +135,7 @@ export function BlocksControl({
|
|||||||
hardcodedValues: {
|
hardcodedValues: {
|
||||||
graph_id: flow.id,
|
graph_id: flow.id,
|
||||||
graph_version: flow.version,
|
graph_version: flow.version,
|
||||||
input_schema: flow.input_schema,
|
// Schemas will be fetched on-demand when block is added
|
||||||
output_schema: flow.output_schema,
|
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
})
|
})
|
||||||
@@ -182,6 +191,58 @@ export function BlocksControl({
     setSelectedCategory(null);
   }, []);

+  const handleMCPToolConfirm = useCallback(
+    (result: MCPToolDialogResult) => {
+      addBlock(SpecialBlockID.MCP_TOOL, "MCPToolBlock", {
+        server_url: result.serverUrl,
+        server_name: result.serverName,
+        selected_tool: result.selectedTool,
+        tool_input_schema: result.toolInputSchema,
+        available_tools: result.availableTools,
+        credential_id: result.credentialId ?? "",
+      });
+      setMcpDialogOpen(false);
+    },
+    [addBlock],
+  );
+
+  // Handler to add a block, fetching graph data on-demand for agent blocks
+  const handleAddBlock = useCallback(
+    async (block: _Block & { notAvailable: string | null }) => {
+      if (block.notAvailable) return;
+
+      // For MCP blocks, open the configuration dialog instead of placing directly
+      if (block.id === SpecialBlockID.MCP_TOOL) {
+        setMcpDialogOpen(true);
+        return;
+      }
+
+      // For agent blocks, fetch the full graph to get schemas
+      if (block.uiType === BlockUIType.AGENT && block.hardcodedValues) {
+        const graphID = block.hardcodedValues.graph_id as string;
+        const graphVersion = block.hardcodedValues.graph_version as number;
+        const graphData = okData(
+          await getV1GetSpecificGraph(graphID, { version: graphVersion }),
+        );
+
+        if (graphData) {
+          addBlock(block.id, block.name, {
+            ...block.hardcodedValues,
+            input_schema: graphData.input_schema,
+            output_schema: graphData.output_schema,
+          });
+        } else {
+          // Fallback: add without schemas (will be incomplete)
+          console.error("Failed to fetch graph data for agent block");
+          addBlock(block.id, block.name, block.hardcodedValues || {});
+        }
+      } else {
+        addBlock(block.id, block.name, block.hardcodedValues || {});
+      }
+    },
+    [addBlock],
+  );
+
   // Extract unique categories from blocks
   const categories = useMemo(() => {
     return Array.from(
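The handler above defers schema loading until a block is actually placed. A minimal sketch of that fetch-and-fallback step, assuming `okData` unwraps the generated client's response envelope to its payload (or `undefined` on error):

```ts
// Sketch only: getV1GetSpecificGraph and okData are the names used in this
// diff; the envelope-unwrapping behaviour of okData is an assumption.
async function fetchAgentSchemas(graphID: string, graphVersion: number) {
  const graph = okData(
    await getV1GetSpecificGraph(graphID, { version: graphVersion }),
  );
  // Mirror the placeholder schemas the block list now ships with, so the
  // node can still be rendered when the fetch fails.
  return {
    input_schema: graph?.input_schema ?? { type: "object", properties: {} },
    output_schema: graph?.output_schema ?? { type: "object", properties: {} },
  };
}
```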
@@ -195,165 +256,179 @@ export function BlocksControl({
   }, [blocks]);

   return (
-    <Popover
-      open={pinBlocksPopover ? true : undefined}
-      onOpenChange={(open) => open || resetFilters()}
-    >
-      <Tooltip delayDuration={500}>
-        <TooltipTrigger asChild>
-          <PopoverTrigger asChild>
-            <Button
-              variant="ghost"
-              size="icon"
-              data-id="blocks-control-popover-trigger"
-              data-testid="blocks-control-blocks-button"
-              name="Blocks"
-              className="dark:hover:bg-slate-800"
-            >
-              <IconToyBrick />
-            </Button>
-          </PopoverTrigger>
-        </TooltipTrigger>
-        <TooltipContent side="right">Blocks</TooltipContent>
-      </Tooltip>
-      <PopoverContent
-        side="right"
-        sideOffset={22}
-        align="start"
-        className="absolute -top-3 w-[17rem] rounded-xl border-none p-0 shadow-none md:w-[30rem]"
-        data-id="blocks-control-popover-content"
-      >
-        <Card className="p-3 pb-0 dark:bg-slate-900">
-          <CardHeader className="flex flex-col gap-x-8 gap-y-1 p-3 px-2">
-            <div className="items-center justify-between">
-              <Label
-                htmlFor="search-blocks"
-                className="whitespace-nowrap text-base font-bold text-black dark:text-white 2xl:text-xl"
-                data-id="blocks-control-label"
-                data-testid="blocks-control-blocks-label"
-              >
-                Blocks
-              </Label>
-            </div>
-            <div className="relative flex items-center">
-              <MagnifyingGlassIcon className="absolute m-2 h-5 w-5 text-gray-500 dark:text-gray-400" />
-              <Input
-                id="search-blocks"
-                type="text"
-                placeholder="Search blocks"
-                value={searchQuery}
-                onChange={(e) => setSearchQuery(e.target.value)}
-                className="rounded-lg px-8 py-5 dark:bg-slate-800 dark:text-white"
-                data-id="blocks-control-search-input"
-                autoComplete="off"
-              />
-            </div>
-            <div
-              className="mt-2 flex flex-wrap gap-2"
-              data-testid="blocks-categories-list"
-            >
-              {categories.map((category) => {
-                const color = getPrimaryCategoryColor([
-                  { category: category || "All", description: "" },
-                ]);
-                const colorClass =
-                  selectedCategory === category ? `${color}` : "";
-                return (
-                  <div
-                    key={category}
-                    data-testid="blocks-category"
-                    role="button"
-                    className={`cursor-pointer rounded-xl border px-2 py-2 text-xs font-medium dark:border-slate-700 dark:text-white ${colorClass}`}
-                    onClick={() =>
-                      setSelectedCategory(
-                        selectedCategory === category ? null : category,
-                      )
-                    }
-                  >
-                    {beautifyString((category || "All").toLowerCase())}
-                  </div>
-                );
-              })}
-            </div>
-          </CardHeader>
-          <CardContent className="overflow-scroll border-t border-t-gray-200 p-0 dark:border-t-slate-700">
-            <ScrollArea
-              className="h-[60vh] w-full"
-              data-id="blocks-control-scroll-area"
-            >
-              {filteredAvailableBlocks.map((block) => (
-                <Card
-                  key={block.uiKey || block.id}
-                  className={`m-2 my-4 flex h-20 shadow-none dark:border-slate-700 dark:bg-slate-800 dark:text-slate-100 dark:hover:bg-slate-700 ${
-                    block.notAvailable
-                      ? "cursor-not-allowed opacity-50"
-                      : "cursor-move hover:shadow-lg"
-                  }`}
-                  data-id={`block-card-${block.id}`}
-                  draggable={!block.notAvailable}
-                  onDragStart={(e) => {
-                    if (block.notAvailable) return;
-                    e.dataTransfer.effectAllowed = "copy";
-                    e.dataTransfer.setData(
-                      "application/reactflow",
-                      JSON.stringify({
-                        blockId: block.id,
-                        blockName: block.name,
-                        hardcodedValues: block?.hardcodedValues || {},
-                      }),
-                    );
-                  }}
-                  onClick={() =>
-                    !block.notAvailable &&
-                    addBlock(block.id, block.name, block?.hardcodedValues || {})
-                  }
-                  title={block.notAvailable ?? undefined}
-                >
-                  <div
-                    className={`-ml-px h-full w-3 rounded-l-xl ${getPrimaryCategoryColor(block.categories)}`}
-                  ></div>
-
-                  <div className="mx-3 flex flex-1 items-center justify-between">
-                    <div className="mr-2 min-w-0">
-                      <span
-                        className="block truncate pb-1 text-sm font-semibold dark:text-white"
-                        data-id={`block-name-${block.id}`}
-                        data-type={block.uiType}
-                        data-testid={`block-name-${block.id}`}
-                      >
-                        <TextRenderer
-                          value={beautifyString(block.name).replace(
-                            / Block$/,
-                            "",
-                          )}
-                          truncateLengthLimit={45}
-                        />
-                      </span>
-                      <span
-                        className="block break-all text-xs font-normal text-gray-500 dark:text-gray-400"
-                        data-testid={`block-description-${block.id}`}
-                      >
-                        <TextRenderer
-                          value={block.description}
-                          truncateLengthLimit={165}
-                        />
-                      </span>
-                    </div>
-                    <div
-                      className="flex flex-shrink-0 items-center gap-1"
-                      data-id={`block-tooltip-${block.id}`}
-                      data-testid={`block-add`}
-                    >
-                      <PlusIcon className="h-6 w-6 rounded-lg bg-gray-200 stroke-black stroke-[0.5px] p-1 dark:bg-gray-700 dark:stroke-white" />
-                    </div>
-                  </div>
-                </Card>
-              ))}
-            </ScrollArea>
-          </CardContent>
-        </Card>
-      </PopoverContent>
-    </Popover>
+    <>
+      <Popover
+        open={pinBlocksPopover ? true : undefined}
+        onOpenChange={(open) => open || resetFilters()}
+      >
+        <Tooltip delayDuration={500}>
+          <TooltipTrigger asChild>
+            <PopoverTrigger asChild>
+              <Button
+                variant="ghost"
+                size="icon"
+                data-id="blocks-control-popover-trigger"
+                data-testid="blocks-control-blocks-button"
+                name="Blocks"
+                className="dark:hover:bg-slate-800"
+              >
+                <IconToyBrick />
+              </Button>
+            </PopoverTrigger>
+          </TooltipTrigger>
+          <TooltipContent side="right">Blocks</TooltipContent>
+        </Tooltip>
+        <PopoverContent
+          side="right"
+          sideOffset={22}
+          align="start"
+          className="absolute -top-3 w-[17rem] rounded-xl border-none p-0 shadow-none md:w-[30rem]"
+          data-id="blocks-control-popover-content"
+        >
+          <Card className="p-3 pb-0 dark:bg-slate-900">
+            <CardHeader className="flex flex-col gap-x-8 gap-y-1 p-3 px-2">
+              <div className="items-center justify-between">
+                <Label
+                  htmlFor="search-blocks"
+                  className="whitespace-nowrap text-base font-bold text-black dark:text-white 2xl:text-xl"
+                  data-id="blocks-control-label"
+                  data-testid="blocks-control-blocks-label"
+                >
+                  Blocks
+                </Label>
+              </div>
+              <div className="relative flex items-center">
+                <MagnifyingGlassIcon className="absolute m-2 h-5 w-5 text-gray-500 dark:text-gray-400" />
+                <Input
+                  id="search-blocks"
+                  type="text"
+                  placeholder="Search blocks"
+                  value={searchQuery}
+                  onChange={(e) => setSearchQuery(e.target.value)}
+                  className="rounded-lg px-8 py-5 dark:bg-slate-800 dark:text-white"
+                  data-id="blocks-control-search-input"
+                  autoComplete="off"
+                />
+              </div>
+              <div
+                className="mt-2 flex flex-wrap gap-2"
+                data-testid="blocks-categories-list"
+              >
+                {categories.map((category) => {
+                  const color = getPrimaryCategoryColor([
+                    { category: category || "All", description: "" },
+                  ]);
+                  const colorClass =
+                    selectedCategory === category ? `${color}` : "";
+                  return (
+                    <div
+                      key={category}
+                      data-testid="blocks-category"
+                      role="button"
+                      className={`cursor-pointer rounded-xl border px-2 py-2 text-xs font-medium dark:border-slate-700 dark:text-white ${colorClass}`}
+                      onClick={() =>
+                        setSelectedCategory(
+                          selectedCategory === category ? null : category,
+                        )
+                      }
+                    >
+                      {beautifyString((category || "All").toLowerCase())}
+                    </div>
+                  );
+                })}
+              </div>
+            </CardHeader>
+            <CardContent className="overflow-scroll border-t border-t-gray-200 p-0 dark:border-t-slate-700">
+              <ScrollArea
+                className="h-[60vh] w-full"
+                data-id="blocks-control-scroll-area"
+              >
+                {filteredAvailableBlocks.map((block) => (
+                  <Card
+                    key={block.uiKey || block.id}
+                    className={`m-2 my-4 flex h-20 shadow-none dark:border-slate-700 dark:bg-slate-800 dark:text-slate-100 dark:hover:bg-slate-700 ${
+                      block.notAvailable
+                        ? "cursor-not-allowed opacity-50"
+                        : block.id === SpecialBlockID.MCP_TOOL
+                          ? "cursor-pointer hover:shadow-lg"
+                          : "cursor-move hover:shadow-lg"
+                    }`}
+                    data-id={`block-card-${block.id}`}
+                    draggable={
+                      !block.notAvailable &&
+                      block.id !== SpecialBlockID.MCP_TOOL
+                    }
+                    onDragStart={(e) => {
+                      if (
+                        block.notAvailable ||
+                        block.id === SpecialBlockID.MCP_TOOL
+                      )
+                        return;
+                      e.dataTransfer.effectAllowed = "copy";
+                      e.dataTransfer.setData(
+                        "application/reactflow",
+                        JSON.stringify({
+                          blockId: block.id,
+                          blockName: block.name,
+                          hardcodedValues: block?.hardcodedValues || {},
+                        }),
+                      );
+                    }}
+                    onClick={() => handleAddBlock(block)}
+                    title={block.notAvailable ?? undefined}
+                  >
+                    <div
+                      className={`-ml-px h-full w-3 rounded-l-xl ${getPrimaryCategoryColor(block.categories)}`}
+                    ></div>
+
+                    <div className="mx-3 flex flex-1 items-center justify-between">
+                      <div className="mr-2 min-w-0">
+                        <span
+                          className="block truncate pb-1 text-sm font-semibold dark:text-white"
+                          data-id={`block-name-${block.id}`}
+                          data-type={block.uiType}
+                          data-testid={`block-name-${block.id}`}
+                        >
+                          <TextRenderer
+                            value={beautifyString(block.name).replace(
+                              / Block$/,
+                              "",
+                            )}
+                            truncateLengthLimit={45}
+                          />
+                        </span>
+                        <span
+                          className="block break-all text-xs font-normal text-gray-500 dark:text-gray-400"
+                          data-testid={`block-description-${block.id}`}
+                        >
+                          <TextRenderer
+                            value={block.description}
+                            truncateLengthLimit={165}
+                          />
+                        </span>
+                      </div>
+                      <div
+                        className="flex flex-shrink-0 items-center gap-1"
+                        data-id={`block-tooltip-${block.id}`}
+                        data-testid={`block-add`}
+                      >
+                        <PlusIcon className="h-6 w-6 rounded-lg bg-gray-200 stroke-black stroke-[0.5px] p-1 dark:bg-gray-700 dark:stroke-white" />
+                      </div>
+                    </div>
+                  </Card>
+                ))}
+              </ScrollArea>
+            </CardContent>
+          </Card>
+        </PopoverContent>
+      </Popover>
+
+      <MCPToolDialog
+        open={mcpDialogOpen}
+        onClose={() => setMcpDialogOpen(false)}
+        onConfirm={handleMCPToolConfirm}
+      />
+    </>
   );
 }

@@ -21,6 +21,7 @@ import {
   GraphInputSchema,
   GraphOutputSchema,
   NodeExecutionResult,
+  SpecialBlockID,
 } from "@/lib/autogpt-server-api";
 import {
   beautifyString,
@@ -215,6 +216,26 @@ export const CustomNode = React.memo(
       }
     }

+    // MCP Tool block: display the selected tool's dynamic schema
+    const isMCPWithTool =
+      data.block_id === SpecialBlockID.MCP_TOOL &&
+      !!data.hardcodedValues?.tool_input_schema?.properties;
+
+    if (isMCPWithTool) {
+      // Show only the tool's input parameters. Credentials are NOT included
+      // because authentication is handled by the MCP dialog's OAuth flow
+      // and stored server-side.
+      const toolSchema = data.hardcodedValues.tool_input_schema;
+
+      data.inputSchema = {
+        type: "object",
+        properties: {
+          ...(toolSchema.properties ?? {}),
+        },
+        required: [...(toolSchema.required ?? [])],
+      } as BlockIORootSchema;
+    }
+
     const setHardcodedValues = useCallback(
       (values: any) => {
         updateNodeData(id, { hardcodedValues: values });
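Extracted as a standalone function, the override above amounts to copying the tool's JSON Schema into the node's root input schema. A hypothetical pure version (the exact `BlockIORootSchema` shape is assumed here to be a plain JSON-Schema-like object):

```ts
// Hypothetical helper; not part of the diff.
interface ObjectSchema {
  type: "object";
  properties: Record<string, unknown>;
  required?: string[];
}

function toNodeInputSchema(
  toolInputSchema: Partial<ObjectSchema>,
): ObjectSchema {
  return {
    type: "object",
    properties: { ...(toolInputSchema.properties ?? {}) },
    required: [...(toolInputSchema.required ?? [])],
  };
}
```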
@@ -375,7 +396,9 @@ export const CustomNode = React.memo(

     const displayTitle =
       customTitle ||
-      beautifyString(data.blockType?.replace(/Block$/, "") || data.title);
+      (isMCPWithTool
+        ? `${data.hardcodedValues.server_name || "MCP"}: ${beautifyString(data.hardcodedValues.selected_tool || "")}`
+        : beautifyString(data.blockType?.replace(/Block$/, "") || data.title));

     useEffect(() => {
       isInitialSetup.current = false;
@@ -389,6 +412,15 @@ export const CustomNode = React.memo(
             data.inputSchema,
           ),
         });
+      } else if (isMCPWithTool) {
+        // MCP dialog already configured server_url, selected_tool, etc.
+        // Just ensure tool_arguments is initialized.
+        if (!data.hardcodedValues.tool_arguments) {
+          setHardcodedValues({
+            ...data.hardcodedValues,
+            tool_arguments: {},
+          });
+        }
       } else {
         setHardcodedValues(
           fillObjectDefaultsFromSchema(data.hardcodedValues, data.inputSchema),
@@ -525,8 +557,11 @@ export const CustomNode = React.memo(
         );

       default:
-        const getInputPropKey = (key: string) =>
-          nodeType == BlockUIType.AGENT ? `inputs.${key}` : key;
+        const getInputPropKey = (key: string) => {
+          if (nodeType == BlockUIType.AGENT) return `inputs.${key}`;
+          if (isMCPWithTool) return `tool_arguments.${key}`;
+          return key;
+        };

         return keys.map(([propKey, propSchema]) => {
           const isRequired = data.inputSchema.required?.includes(propKey);
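The key prefixing above means MCP tool parameters are stored under `tool_arguments.<param>` and sub-agent inputs under `inputs.<param>` within `hardcodedValues`. A small sketch of how such dotted keys resolve to nested values (`setNestedValue` is hypothetical, for illustration only):

```ts
// Hypothetical helper illustrating the dotted-key convention; not in the diff.
function setNestedValue(
  obj: Record<string, any>,
  path: string,
  value: unknown,
): void {
  const keys = path.split(".");
  let cursor = obj;
  for (const key of keys.slice(0, -1)) {
    cursor = cursor[key] ??= {};
  }
  cursor[keys[keys.length - 1]] = value;
}

const hardcodedValues: Record<string, any> = {};
setNestedValue(hardcodedValues, "tool_arguments.query", "hello");
// => { tool_arguments: { query: "hello" } }
```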
@@ -1,6 +1,6 @@
 import { beautifyString } from "@/lib/utils";
 import { Clipboard, Maximize2 } from "lucide-react";
-import React, { useState } from "react";
+import React, { useMemo, useState } from "react";
 import { Button } from "../../../../../components/__legacy__/ui/button";
 import { ContentRenderer } from "../../../../../components/__legacy__/ui/render";
 import {
@@ -11,6 +11,12 @@ import {
   TableHeader,
   TableRow,
 } from "../../../../../components/__legacy__/ui/table";
+import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
+import {
+  globalRegistry,
+  OutputItem,
+} from "@/components/contextual/OutputRenderers";
+import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
 import { useToast } from "../../../../../components/molecules/Toast/use-toast";
 import ExpandableOutputDialog from "./ExpandableOutputDialog";

@@ -26,6 +32,9 @@ export default function DataTable({
   data,
 }: DataTableProps) {
   const { toast } = useToast();
+  const enableEnhancedOutputHandling = useGetFlag(
+    Flag.ENABLE_ENHANCED_OUTPUT_HANDLING,
+  );
   const [expandedDialog, setExpandedDialog] = useState<{
     isOpen: boolean;
     execId: string;
@@ -33,6 +42,15 @@ export default function DataTable({
     data: any[];
   } | null>(null);

+  // Prepare renderers for each item when enhanced mode is enabled
+  const getItemRenderer = useMemo(() => {
+    if (!enableEnhancedOutputHandling) return null;
+    return (item: unknown) => {
+      const metadata: OutputMetadata = {};
+      return globalRegistry.getRenderer(item, metadata);
+    };
+  }, [enableEnhancedOutputHandling]);
+
   const copyData = (pin: string, data: string) => {
     navigator.clipboard.writeText(data).then(() => {
       toast({
@@ -102,15 +120,31 @@ export default function DataTable({
                 <Clipboard size={18} />
               </Button>
             </div>
-            {value.map((item, index) => (
-              <React.Fragment key={index}>
-                <ContentRenderer
-                  value={item}
-                  truncateLongData={truncateLongData}
-                />
-                {index < value.length - 1 && ", "}
-              </React.Fragment>
-            ))}
+            {value.map((item, index) => {
+              const renderer = getItemRenderer?.(item);
+              if (enableEnhancedOutputHandling && renderer) {
+                const metadata: OutputMetadata = {};
+                return (
+                  <React.Fragment key={index}>
+                    <OutputItem
+                      value={item}
+                      metadata={metadata}
+                      renderer={renderer}
+                    />
+                    {index < value.length - 1 && ", "}
+                  </React.Fragment>
+                );
+              }
+              return (
+                <React.Fragment key={index}>
+                  <ContentRenderer
+                    value={item}
+                    truncateLongData={truncateLongData}
+                  />
+                  {index < value.length - 1 && ", "}
+                </React.Fragment>
+              );
+            })}
           </div>
         </TableCell>
       </TableRow>
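Both this component and NodeOutputs (further down in this diff) gate the new renderer path behind a feature flag and fall back to the legacy ContentRenderer. Reduced to its essentials, the pattern looks roughly like this (imports as named in the diff; treat it as a sketch, not the exact implementation):

```tsx
import { ContentRenderer } from "../../../../../components/__legacy__/ui/render";
import {
  globalRegistry,
  OutputItem,
  type OutputMetadata,
} from "@/components/contextual/OutputRenderers";

function renderValue(item: unknown, enhanced: boolean) {
  const metadata: OutputMetadata = {};
  // Ask the registry for a matching renderer only when the flag is on.
  const renderer = enhanced ? globalRegistry.getRenderer(item, metadata) : null;
  return renderer ? (
    <OutputItem value={item} metadata={metadata} renderer={renderer} />
  ) : (
    <ContentRenderer value={item} truncateLongData />
  );
}
```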
@@ -29,13 +29,17 @@ import "@xyflow/react/dist/style.css";
 import { ConnectedEdge, CustomNode } from "../CustomNode/CustomNode";
 import "./flow.css";
 import {
+  BlockIORootSchema,
   BlockUIType,
   formatEdgeID,
   GraphExecutionID,
   GraphID,
   GraphMeta,
   LibraryAgent,
+  SpecialBlockID,
 } from "@/lib/autogpt-server-api";
+import { getV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
+import { okData } from "@/app/api/helpers";
 import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
 import { Key, storage } from "@/services/storage/local-storage";
 import { findNewlyAddedBlockCoordinates, getTypeColor } from "@/lib/utils";
@@ -687,8 +691,94 @@ const FlowEditor: React.FC<{
     [getNode, updateNode, nodes],
   );

+  /* Shared helper to create and add a node */
+  const createAndAddNode = useCallback(
+    async (
+      blockID: string,
+      blockName: string,
+      hardcodedValues: Record<string, any>,
+      position: { x: number; y: number },
+    ): Promise<CustomNode | null> => {
+      const nodeSchema = availableBlocks.find((node) => node.id === blockID);
+      if (!nodeSchema) {
+        console.error(`Schema not found for block ID: ${blockID}`);
+        return null;
+      }
+
+      // For agent blocks, fetch the full graph to get schemas
+      let inputSchema: BlockIORootSchema = nodeSchema.inputSchema;
+      let outputSchema: BlockIORootSchema = nodeSchema.outputSchema;
+      let finalHardcodedValues = hardcodedValues;
+
+      if (blockID === SpecialBlockID.AGENT) {
+        const graphID = hardcodedValues.graph_id as string;
+        const graphVersion = hardcodedValues.graph_version as number;
+        const graphData = okData(
+          await getV1GetSpecificGraph(graphID, { version: graphVersion }),
+        );
+
+        if (graphData) {
+          inputSchema = graphData.input_schema as BlockIORootSchema;
+          outputSchema = graphData.output_schema as BlockIORootSchema;
+          finalHardcodedValues = {
+            ...hardcodedValues,
+            input_schema: graphData.input_schema,
+            output_schema: graphData.output_schema,
+          };
+        } else {
+          console.error("Failed to fetch graph data for agent block");
+        }
+      }
+
+      const newNode: CustomNode = {
+        id: nodeId.toString(),
+        type: "custom",
+        position,
+        data: {
+          blockType: blockName,
+          blockCosts: nodeSchema.costs || [],
+          title: `${blockName} ${nodeId}`,
+          description: nodeSchema.description,
+          categories: nodeSchema.categories,
+          inputSchema: inputSchema,
+          outputSchema: outputSchema,
+          hardcodedValues: finalHardcodedValues,
+          connections: [],
+          isOutputOpen: false,
+          block_id: blockID,
+          isOutputStatic: nodeSchema.staticOutput,
+          uiType: nodeSchema.uiType,
+        },
+      };
+
+      addNodes(newNode);
+      setNodeId((prevId) => prevId + 1);
+      clearNodesStatusAndOutput();
+
+      history.push({
+        type: "ADD_NODE",
+        payload: { node: { ...newNode, ...newNode.data } },
+        undo: () => deleteElements({ nodes: [{ id: newNode.id }] }),
+        redo: () => addNodes(newNode),
+      });
+
+      return newNode;
+    },
+    [
+      availableBlocks,
+      nodeId,
+      addNodes,
+      deleteElements,
+      clearNodesStatusAndOutput,
+    ],
+  );
+
   const addNode = useCallback(
-    (blockId: string, nodeType: string, hardcodedValues: any = {}) => {
+    async (
+      blockId: string,
+      nodeType: string,
+      hardcodedValues: Record<string, any> = {},
+    ) => {
       const nodeSchema = availableBlocks.find((node) => node.id === blockId);
       if (!nodeSchema) {
         console.error(`Schema not found for block ID: ${blockId}`);
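With the factory in place, both entry points reduce to computing a position and delegating. A brief usage sketch (names from the diff above; the position values are illustrative):

```ts
// Click-to-add path: place the node near the centre of the viewport.
const node = await createAndAddNode(blockId, blockName, hardcodedValues, {
  x: window.innerWidth / 2,
  y: window.innerHeight / 2,
});
if (!node) {
  // createAndAddNode already logged the missing-schema error.
  return;
}
```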
@@ -707,73 +797,42 @@ const FlowEditor: React.FC<{
       // Alternative: We could also use D3 force, Intersection for this (React flow Pro examples)

       const { x, y } = getViewport();
-      const viewportCoordinates =
+      const position =
         nodeDimensions && Object.keys(nodeDimensions).length > 0
-          ? // we will get all the dimension of nodes, then store
-            findNewlyAddedBlockCoordinates(
+          ? findNewlyAddedBlockCoordinates(
               nodeDimensions,
               nodeSchema.uiType == BlockUIType.NOTE ? 300 : 500,
               60,
               1.0,
             )
-          : // we will get all the dimension of nodes, then store
-            {
+          : {
               x: window.innerWidth / 2 - x,
               y: window.innerHeight / 2 - y,
             };

-      const newNode: CustomNode = {
-        id: nodeId.toString(),
-        type: "custom",
-        position: viewportCoordinates, // Set the position to the calculated viewport center
-        data: {
-          blockType: nodeType,
-          blockCosts: nodeSchema.costs,
-          title: `${nodeType} ${nodeId}`,
-          description: nodeSchema.description,
-          categories: nodeSchema.categories,
-          inputSchema: nodeSchema.inputSchema,
-          outputSchema: nodeSchema.outputSchema,
-          hardcodedValues: hardcodedValues,
-          connections: [],
-          isOutputOpen: false,
-          block_id: blockId,
-          isOutputStatic: nodeSchema.staticOutput,
-          uiType: nodeSchema.uiType,
-        },
-      };
-
-      addNodes(newNode);
-      setNodeId((prevId) => prevId + 1);
-      clearNodesStatusAndOutput(); // Clear status and output when a new node is added
+      const newNode = await createAndAddNode(
+        blockId,
+        nodeType,
+        hardcodedValues,
+        position,
+      );
+      if (!newNode) return;

       setViewport(
         {
-          // Rough estimate of the dimension of the node is: 500x400px.
-          // Though we skip shifting the X, considering the block menu side-bar.
-          x: -viewportCoordinates.x * 0.8 + (window.innerWidth - 0.0) / 2,
-          y: -viewportCoordinates.y * 0.8 + (window.innerHeight - 400) / 2,
+          x: -position.x * 0.8 + (window.innerWidth - 0.0) / 2,
+          y: -position.y * 0.8 + (window.innerHeight - 400) / 2,
           zoom: 0.8,
         },
         { duration: 500 },
       );
-
-      history.push({
-        type: "ADD_NODE",
-        payload: { node: { ...newNode, ...newNode.data } },
-        undo: () => deleteElements({ nodes: [{ id: newNode.id }] }),
-        redo: () => addNodes(newNode),
-      });
     },
     [
-      nodeId,
       getViewport,
       setViewport,
       availableBlocks,
-      addNodes,
       nodeDimensions,
-      deleteElements,
-      clearNodesStatusAndOutput,
+      createAndAddNode,
     ],
   );

@@ -920,7 +979,7 @@ const FlowEditor: React.FC<{
   }, []);

   const onDrop = useCallback(
-    (event: React.DragEvent) => {
+    async (event: React.DragEvent) => {
       event.preventDefault();

       const blockData = event.dataTransfer.getData("application/reactflow");
@@ -935,62 +994,17 @@ const FlowEditor: React.FC<{
           y: event.clientY,
         });

-        // Find the block schema
-        const nodeSchema = availableBlocks.find((node) => node.id === blockId);
-        if (!nodeSchema) {
-          console.error(`Schema not found for block ID: ${blockId}`);
-          return;
-        }
-
-        // Create the new node at the drop position
-        const newNode: CustomNode = {
-          id: nodeId.toString(),
-          type: "custom",
+        await createAndAddNode(
+          blockId,
+          blockName,
+          hardcodedValues || {},
           position,
-          data: {
-            blockType: blockName,
-            blockCosts: nodeSchema.costs || [],
-            title: `${blockName} ${nodeId}`,
-            description: nodeSchema.description,
-            categories: nodeSchema.categories,
-            inputSchema: nodeSchema.inputSchema,
-            outputSchema: nodeSchema.outputSchema,
-            hardcodedValues: hardcodedValues,
-            connections: [],
-            isOutputOpen: false,
-            block_id: blockId,
-            uiType: nodeSchema.uiType,
-          },
-        };
-
-        history.push({
-          type: "ADD_NODE",
-          payload: { node: { ...newNode, ...newNode.data } },
-          undo: () => {
-            deleteElements({ nodes: [{ id: newNode.id } as any], edges: [] });
-          },
-          redo: () => {
-            addNodes([newNode]);
-          },
-        });
-        addNodes([newNode]);
-        clearNodesStatusAndOutput();
-
-        setNodeId((prevId) => prevId + 1);
+        );
       } catch (error) {
         console.error("Failed to drop block:", error);
       }
     },
-    [
-      nodeId,
-      availableBlocks,
-      nodes,
-      edges,
-      addNodes,
-      screenToFlowPosition,
-      deleteElements,
-      clearNodesStatusAndOutput,
-    ],
+    [screenToFlowPosition, createAndAddNode],
   );

   const buildContextValue: BuilderContextType = useMemo(
@@ -0,0 +1,606 @@
+"use client";
+
+import React, { useState, useCallback, useRef, useEffect } from "react";
+import {
+  Dialog,
+  DialogContent,
+  DialogDescription,
+  DialogFooter,
+  DialogHeader,
+  DialogTitle,
+} from "@/components/__legacy__/ui/dialog";
+import { Button } from "@/components/__legacy__/ui/button";
+import { Input } from "@/components/__legacy__/ui/input";
+import { Label } from "@/components/__legacy__/ui/label";
+import { LoadingSpinner } from "@/components/__legacy__/ui/loading";
+import { Badge } from "@/components/__legacy__/ui/badge";
+import { ScrollArea } from "@/components/__legacy__/ui/scroll-area";
+import { useBackendAPI } from "@/lib/autogpt-server-api/context";
+import type { MCPTool } from "@/lib/autogpt-server-api";
+import { CaretDown } from "@phosphor-icons/react";
+
+export type MCPToolDialogResult = {
+  serverUrl: string;
+  serverName: string | null;
+  selectedTool: string;
+  toolInputSchema: Record<string, any>;
+  availableTools: Record<string, any>;
+  /** Credential ID from OAuth flow, null for public servers. */
+  credentialId: string | null;
+};
+
+interface MCPToolDialogProps {
+  open: boolean;
+  onClose: () => void;
+  onConfirm: (result: MCPToolDialogResult) => void;
+}
+
+type DialogStep = "url" | "tool";
+
+const OAUTH_TIMEOUT_MS = 5 * 60 * 1000; // 5 minutes
+const STORAGE_KEY = "mcp_last_server_url";
+
+export function MCPToolDialog({
+  open,
+  onClose,
+  onConfirm,
+}: MCPToolDialogProps) {
+  const api = useBackendAPI();
+
+  const [step, setStep] = useState<DialogStep>("url");
+  const [serverUrl, setServerUrl] = useState("");
+  const [tools, setTools] = useState<MCPTool[]>([]);
+  const [serverName, setServerName] = useState<string | null>(null);
+  const [loading, setLoading] = useState(false);
+  const [error, setError] = useState<string | null>(null);
+  const [authRequired, setAuthRequired] = useState(false);
+  const [oauthLoading, setOauthLoading] = useState(false);
+  const [showManualToken, setShowManualToken] = useState(false);
+  const [manualToken, setManualToken] = useState("");
+  const [selectedTool, setSelectedTool] = useState<MCPTool | null>(null);
+  const [credentialId, setCredentialId] = useState<string | null>(null);
+
+  const oauthLoadingRef = useRef(false);
+  const stateTokenRef = useRef<string | null>(null);
+  const broadcastChannelRef = useRef<BroadcastChannel | null>(null);
+  const messageHandlerRef = useRef<((event: MessageEvent) => void) | null>(
+    null,
+  );
+  const oauthHandledRef = useRef(false);
+  const autoConnectAttemptedRef = useRef(false);
+
+  // Pre-fill last used server URL when dialog opens (without auto-connecting)
+  useEffect(() => {
+    if (!open) {
+      autoConnectAttemptedRef.current = false;
+      return;
+    }
+
+    if (autoConnectAttemptedRef.current) return;
+    autoConnectAttemptedRef.current = true;
+
+    const lastUrl = localStorage.getItem(STORAGE_KEY);
+    if (lastUrl) {
+      setServerUrl(lastUrl);
+    }
+  }, [open]);
+
+  // Clean up listeners on unmount
+  useEffect(() => {
+    return () => {
+      if (messageHandlerRef.current) {
+        window.removeEventListener("message", messageHandlerRef.current);
+      }
+      if (broadcastChannelRef.current) {
+        broadcastChannelRef.current.close();
+      }
+    };
+  }, []);
+
+  const cleanupOAuthListeners = useCallback(() => {
+    if (messageHandlerRef.current) {
+      window.removeEventListener("message", messageHandlerRef.current);
+      messageHandlerRef.current = null;
+    }
+    if (broadcastChannelRef.current) {
+      broadcastChannelRef.current.close();
+      broadcastChannelRef.current = null;
+    }
+    setOauthLoading(false);
+    oauthLoadingRef.current = false;
+    oauthHandledRef.current = false;
+  }, []);
+
+  const reset = useCallback(() => {
+    cleanupOAuthListeners();
+    setStep("url");
+    setServerUrl("");
+    setManualToken("");
+    setTools([]);
+    setServerName(null);
+    setLoading(false);
+    setError(null);
+    setAuthRequired(false);
+    setShowManualToken(false);
+    setSelectedTool(null);
+    setCredentialId(null);
+    stateTokenRef.current = null;
+  }, [cleanupOAuthListeners]);
+
+  const handleClose = useCallback(() => {
+    reset();
+    onClose();
+  }, [reset, onClose]);
+
+  const discoverTools = useCallback(
+    async (url: string, authToken?: string) => {
+      setLoading(true);
+      setError(null);
+      try {
+        const result = await api.mcpDiscoverTools(url, authToken);
+        localStorage.setItem(STORAGE_KEY, url);
+        setTools(result.tools);
+        setServerName(result.server_name);
+        setAuthRequired(false);
+        setShowManualToken(false);
+        setStep("tool");
+      } catch (e: any) {
+        if (e?.status === 401 || e?.status === 403) {
+          setAuthRequired(true);
+          setError(null);
+        } else {
+          const message =
+            e?.message || e?.detail || "Failed to connect to MCP server";
+          setError(
+            typeof message === "string" ? message : JSON.stringify(message),
+          );
+        }
+      } finally {
+        setLoading(false);
+      }
+    },
+    [api],
+  );
+
+  const handleDiscoverTools = useCallback(() => {
+    if (!serverUrl.trim()) return;
+    discoverTools(serverUrl.trim(), manualToken.trim() || undefined);
+  }, [serverUrl, manualToken, discoverTools]);
+
+  const handleOAuthResult = useCallback(
+    async (data: {
+      success: boolean;
+      code?: string;
+      state?: string;
+      message?: string;
+    }) => {
+      // Prevent double-handling (BroadcastChannel + postMessage may both fire)
+      if (oauthHandledRef.current) return;
+      oauthHandledRef.current = true;
+
+      if (!data.success) {
+        setError(data.message || "OAuth authentication failed.");
+        cleanupOAuthListeners();
+        return;
+      }
+
+      cleanupOAuthListeners();
+      setAuthRequired(false);
+
+      // Exchange code for tokens (stored server-side)
+      setLoading(true);
+      try {
+        const callbackResult = await api.mcpOAuthCallback(
+          data.code!,
+          stateTokenRef.current!,
+        );
+        setCredentialId(callbackResult.credential_id);
+        const result = await api.mcpDiscoverTools(serverUrl.trim());
+        localStorage.setItem(STORAGE_KEY, serverUrl.trim());
+        setTools(result.tools);
+        setServerName(result.server_name);
+        setStep("tool");
+      } catch (e: any) {
+        const message = e?.message || e?.detail || "Failed to complete sign-in";
+        setError(
+          typeof message === "string" ? message : JSON.stringify(message),
+        );
+      } finally {
+        setLoading(false);
+      }
+    },
+    [api, serverUrl, cleanupOAuthListeners],
+  );
+
+  const handleOAuthSignIn = useCallback(async () => {
+    if (!serverUrl.trim()) return;
+    setError(null);
+    oauthHandledRef.current = false;
+
+    // Open popup SYNCHRONOUSLY (before async call) to avoid browser popup blockers
+    const width = 500;
+    const height = 700;
+    const left = window.screenX + (window.outerWidth - width) / 2;
+    const top = window.screenY + (window.outerHeight - height) / 2;
+    const popup = window.open(
+      "about:blank",
+      "mcp_oauth",
+      `width=${width},height=${height},left=${left},top=${top},scrollbars=yes`,
+    );
+
+    setOauthLoading(true);
+    oauthLoadingRef.current = true;
+
+    try {
+      const { login_url, state_token } = await api.mcpOAuthLogin(
+        serverUrl.trim(),
+      );
+      stateTokenRef.current = state_token;
+
+      if (popup && !popup.closed) {
+        popup.location.href = login_url;
+      } else {
+        // Popup was blocked — open in new tab as fallback
+        window.open(login_url, "_blank");
+      }
+
+      // Listener 1: BroadcastChannel (works even when window.opener is null)
+      const bc = new BroadcastChannel("mcp_oauth");
+      bc.onmessage = (event) => {
+        if (event.data?.type === "mcp_oauth_result") {
+          handleOAuthResult(event.data);
+        }
+      };
+      broadcastChannelRef.current = bc;
+
+      // Listener 2: window.postMessage (fallback)
+      const handleMessage = (event: MessageEvent) => {
+        if (event.origin !== window.location.origin) return;
+        if (event.data?.message_type === "mcp_oauth_result") {
+          handleOAuthResult(event.data);
+        }
+      };
+      messageHandlerRef.current = handleMessage;
+      window.addEventListener("message", handleMessage);
+
+      // Timeout
+      setTimeout(() => {
+        if (oauthLoadingRef.current) {
+          cleanupOAuthListeners();
+          setError("OAuth sign-in timed out. Please try again.");
+        }
+      }, OAUTH_TIMEOUT_MS);
+    } catch (e: any) {
+      if (popup && !popup.closed) popup.close();
+
+      // If server doesn't support OAuth → show manual token entry
+      if (e?.status === 400) {
+        setShowManualToken(true);
+        setError(
+          "This server does not support OAuth sign-in. Please enter a token manually.",
+        );
+      } else {
+        const message = e?.message || "Failed to initiate sign-in";
+        setError(
+          typeof message === "string" ? message : JSON.stringify(message),
+        );
+      }
+      cleanupOAuthListeners();
+    }
+  }, [api, serverUrl, handleOAuthResult, cleanupOAuthListeners]);
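Note that the two listeners above expect slightly different discriminators: the BroadcastChannel message is matched on `type`, the postMessage payload on `message_type`. A hypothetical counterpart for the popup's callback page (not part of this diff) would therefore set both fields:

```ts
// Hypothetical sketch of the OAuth callback page's notification step.
function notifyOpener(result: {
  success: boolean;
  code?: string;
  state?: string;
  message?: string;
}) {
  const payload = {
    type: "mcp_oauth_result",
    message_type: "mcp_oauth_result",
    ...result,
  };
  // Reach the dialog both ways, matching the listeners it registers.
  new BroadcastChannel("mcp_oauth").postMessage(payload);
  window.opener?.postMessage(payload, window.location.origin);
  window.close();
}
```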
+
+  const handleConfirm = useCallback(() => {
+    if (!selectedTool) return;
+
+    const availableTools: Record<string, any> = {};
+    for (const t of tools) {
+      availableTools[t.name] = {
+        description: t.description,
+        input_schema: t.input_schema,
+      };
+    }
+
+    onConfirm({
+      serverUrl: serverUrl.trim(),
+      serverName,
+      selectedTool: selectedTool.name,
+      toolInputSchema: selectedTool.input_schema,
+      availableTools,
+      credentialId,
+    });
+    reset();
+  }, [selectedTool, tools, serverUrl, credentialId, onConfirm, reset]);
+
+  return (
+    <Dialog open={open} onOpenChange={(isOpen) => !isOpen && handleClose()}>
+      <DialogContent className="max-w-lg">
+        <DialogHeader>
+          <DialogTitle>
+            {step === "url"
+              ? "Connect to MCP Server"
+              : `Select a Tool${serverName ? ` — ${serverName}` : ""}`}
+          </DialogTitle>
+          <DialogDescription>
+            {step === "url"
+              ? "Enter the URL of an MCP server to discover its available tools."
+              : `Found ${tools.length} tool${tools.length !== 1 ? "s" : ""}. Select one to add to your agent.`}
+          </DialogDescription>
+        </DialogHeader>
+
+        {step === "url" && (
+          <div className="flex flex-col gap-4 py-2">
+            <div className="flex flex-col gap-2">
+              <Label htmlFor="mcp-server-url">Server URL</Label>
+              <Input
+                id="mcp-server-url"
+                type="url"
+                placeholder="https://mcp.example.com/mcp"
+                value={serverUrl}
+                onChange={(e) => setServerUrl(e.target.value)}
+                onKeyDown={(e) => e.key === "Enter" && handleDiscoverTools()}
+                autoFocus
+              />
+            </div>
+
+            {/* Auth required: show sign-in panel */}
+            {authRequired && (
+              <div className="flex flex-col items-center gap-3 rounded-lg border border-amber-200 bg-amber-50 p-4 dark:border-amber-800 dark:bg-amber-950">
+                <p className="text-sm font-medium text-amber-700 dark:text-amber-300">
+                  This server requires authentication
+                </p>
+                <Button
+                  onClick={handleOAuthSignIn}
+                  disabled={oauthLoading || loading}
+                  className="w-full"
+                >
+                  {oauthLoading ? (
+                    <span className="flex items-center gap-2">
+                      <LoadingSpinner className="size-4" />
+                      Waiting for sign-in...
+                    </span>
+                  ) : (
+                    "Sign in"
+                  )}
+                </Button>
+                {!showManualToken && (
+                  <button
+                    onClick={() => setShowManualToken(true)}
+                    className="text-xs text-gray-500 underline hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-300"
+                  >
+                    or enter a token manually
+                  </button>
+                )}
+              </div>
+            )}
+
+            {/* Manual token entry — only visible when expanded */}
+            {showManualToken && (
+              <div className="flex flex-col gap-2">
+                <Label htmlFor="mcp-auth-token" className="text-sm">
+                  Bearer Token
+                </Label>
+                <Input
+                  id="mcp-auth-token"
+                  type="password"
+                  placeholder="Paste your auth token here"
+                  value={manualToken}
+                  onChange={(e) => setManualToken(e.target.value)}
+                  onKeyDown={(e) => e.key === "Enter" && handleDiscoverTools()}
+                  autoFocus
+                />
+              </div>
+            )}
+
+            {error && <p className="text-sm text-red-500">{error}</p>}
+          </div>
+        )}
+
+        {step === "tool" && (
+          <ScrollArea className="max-h-[50vh] py-2">
+            <div className="flex flex-col gap-2 pr-3">
+              {tools.map((tool) => (
+                <MCPToolCard
+                  key={tool.name}
+                  tool={tool}
+                  selected={selectedTool?.name === tool.name}
+                  onSelect={() => setSelectedTool(tool)}
+                />
+              ))}
+            </div>
+          </ScrollArea>
+        )}
+
+        <DialogFooter>
+          {step === "tool" && (
+            <Button
+              variant="outline"
+              onClick={() => {
+                setStep("url");
+                setSelectedTool(null);
+              }}
+            >
+              Back
+            </Button>
+          )}
+          <Button variant="outline" onClick={handleClose}>
+            Cancel
+          </Button>
+          {step === "url" && (
+            <Button
+              onClick={handleDiscoverTools}
+              disabled={!serverUrl.trim() || loading || oauthLoading}
+            >
+              {loading ? (
+                <span className="flex items-center gap-2">
+                  <LoadingSpinner className="size-4" />
+                  Connecting...
+                </span>
+              ) : (
+                "Discover Tools"
+              )}
+            </Button>
+          )}
+          {step === "tool" && (
+            <Button onClick={handleConfirm} disabled={!selectedTool}>
+              Add Block
+            </Button>
+          )}
+        </DialogFooter>
+      </DialogContent>
+    </Dialog>
+  );
+}
+
+// --------------- Tool Card Component --------------- //
+
+/** Truncate a description to a reasonable length for the collapsed view. */
+function truncateDescription(text: string, maxLen = 120): string {
+  if (text.length <= maxLen) return text;
+  return text.slice(0, maxLen).trimEnd() + "…";
+}
+
+/** Pretty-print a JSON Schema type for a parameter. */
+function schemaTypeLabel(schema: Record<string, any>): string {
+  if (schema.type) return schema.type;
+  if (schema.anyOf)
+    return schema.anyOf.map((s: any) => s.type ?? "any").join(" | ");
+  if (schema.oneOf)
+    return schema.oneOf.map((s: any) => s.type ?? "any").join(" | ");
+  return "any";
+}
+
+function MCPToolCard({
+  tool,
+  selected,
+  onSelect,
+}: {
+  tool: MCPTool;
+  selected: boolean;
+  onSelect: () => void;
+}) {
+  const [expanded, setExpanded] = useState(false);
+  const properties = tool.input_schema?.properties ?? {};
+  const required = new Set<string>(tool.input_schema?.required ?? []);
+  const paramNames = Object.keys(properties);
+
+  // Strip XML-like tags and hints from description for cleaner display
+  const cleanDescription = (tool.description ?? "")
+    .replace(/<[^>]+>[^<]*<\/[^>]+>/g, "")
+    .replace(/<[^>]+>/g, "")
+    .trim();
+
+  return (
+    <button
+      onClick={onSelect}
+      className={`group flex flex-col rounded-lg border text-left transition-colors ${
+        selected
+          ? "border-blue-500 bg-blue-50 dark:border-blue-400 dark:bg-blue-950"
+          : "border-gray-200 hover:border-gray-300 hover:bg-gray-50 dark:border-slate-700 dark:hover:border-slate-600 dark:hover:bg-slate-800"
+      }`}
+    >
+      {/* Header */}
+      <div className="flex items-center gap-2 px-3 pb-1 pt-3">
+        <span className="flex-1 text-sm font-semibold dark:text-white">
+          {tool.name}
+        </span>
+        {paramNames.length > 0 && (
+          <Badge variant="secondary" className="text-[10px]">
+            {paramNames.length} param{paramNames.length !== 1 ? "s" : ""}
+          </Badge>
+        )}
+      </div>
+
+      {/* Description (collapsed: truncated) */}
+      {cleanDescription && (
+        <p className="px-3 pb-1 text-xs leading-relaxed text-gray-500 dark:text-gray-400">
+          {expanded ? cleanDescription : truncateDescription(cleanDescription)}
+        </p>
+      )}
+
+      {/* Parameter badges (collapsed view) */}
+      {!expanded && paramNames.length > 0 && (
+        <div className="flex flex-wrap gap-1 px-3 pb-2">
+          {paramNames.slice(0, 6).map((name) => (
+            <Badge
+              key={name}
+              variant="outline"
+              className="text-[10px] font-normal"
+            >
+              {name}
+              {required.has(name) && (
+                <span className="ml-0.5 text-red-400">*</span>
+              )}
+            </Badge>
+          ))}
+          {paramNames.length > 6 && (
+            <Badge variant="outline" className="text-[10px] font-normal">
+              +{paramNames.length - 6} more
+            </Badge>
+          )}
+        </div>
+      )}
+
+      {/* Expanded: full parameter details */}
+      {expanded && paramNames.length > 0 && (
+        <div className="mx-3 mb-2 rounded border border-gray-100 bg-gray-50/50 dark:border-slate-700 dark:bg-slate-800/50">
+          <table className="w-full text-xs">
+            <thead>
+              <tr className="border-b border-gray-100 dark:border-slate-700">
+                <th className="px-2 py-1 text-left font-medium text-gray-500 dark:text-gray-400">
+                  Parameter
+                </th>
+                <th className="px-2 py-1 text-left font-medium text-gray-500 dark:text-gray-400">
+                  Type
+                </th>
+                <th className="px-2 py-1 text-left font-medium text-gray-500 dark:text-gray-400">
+                  Description
+                </th>
+              </tr>
+            </thead>
+            <tbody>
+              {paramNames.map((name) => {
+                const prop = properties[name] ?? {};
+                return (
+                  <tr
+                    key={name}
+                    className="border-b border-gray-50 last:border-0 dark:border-slate-700/50"
+                  >
+                    <td className="px-2 py-1 font-mono text-[11px] text-gray-700 dark:text-gray-300">
+                      {name}
+                      {required.has(name) && (
+                        <span className="ml-0.5 text-red-400">*</span>
+                      )}
+                    </td>
+                    <td className="px-2 py-1 text-gray-500 dark:text-gray-400">
+                      {schemaTypeLabel(prop)}
+                    </td>
+                    <td className="max-w-[200px] truncate px-2 py-1 text-gray-500 dark:text-gray-400">
+                      {prop.description ?? "—"}
+                    </td>
+                  </tr>
+                );
+              })}
+            </tbody>
+          </table>
+        </div>
+      )}
+
+      {/* Toggle details */}
+      {(paramNames.length > 0 || cleanDescription.length > 120) && (
+        <button
+          type="button"
+          onClick={(e) => {
+            e.stopPropagation();
+            setExpanded((prev) => !prev);
+          }}
+          className="flex w-full items-center justify-center gap-1 border-t border-gray-100 py-1.5 text-[10px] text-gray-400 hover:text-gray-600 dark:border-slate-700 dark:text-gray-500 dark:hover:text-gray-300"
+        >
+          {expanded ? "Hide details" : "Show details"}
+          <CaretDown
+            className={`h-3 w-3 transition-transform ${expanded ? "rotate-180" : ""}`}
+          />
+        </button>
+      )}
+    </button>
+  );
+}
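For reference, `schemaTypeLabel` above flattens simple `anyOf`/`oneOf` unions into a display string:

```ts
schemaTypeLabel({ type: "string" }); // "string"
schemaTypeLabel({ anyOf: [{ type: "string" }, { type: "number" }] }); // "string | number"
schemaTypeLabel({}); // "any"
```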
@@ -1,8 +1,14 @@
-import React, { useContext, useState } from "react";
+import React, { useContext, useMemo, useState } from "react";
 import { Button } from "@/components/__legacy__/ui/button";
 import { Maximize2 } from "lucide-react";
 import * as Separator from "@radix-ui/react-separator";
 import { ContentRenderer } from "@/components/__legacy__/ui/render";
+import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
+import {
+  globalRegistry,
+  OutputItem,
+} from "@/components/contextual/OutputRenderers";
+import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";

 import { beautifyString } from "@/lib/utils";

@@ -21,6 +27,9 @@ export default function NodeOutputs({
   data,
 }: NodeOutputsProps) {
   const builderContext = useContext(BuilderContext);
+  const enableEnhancedOutputHandling = useGetFlag(
+    Flag.ENABLE_ENHANCED_OUTPUT_HANDLING,
+  );

   const [expandedDialog, setExpandedDialog] = useState<{
     isOpen: boolean;
@@ -37,6 +46,15 @@ export default function NodeOutputs({

   const { getNodeTitle } = builderContext;

+  // Prepare renderers for each item when enhanced mode is enabled
+  const getItemRenderer = useMemo(() => {
+    if (!enableEnhancedOutputHandling) return null;
+    return (item: unknown) => {
+      const metadata: OutputMetadata = {};
+      return globalRegistry.getRenderer(item, metadata);
+    };
+  }, [enableEnhancedOutputHandling]);
+
   const getBeautifiedPinName = (pin: string) => {
     if (!pin.startsWith("tools_^_")) {
       return beautifyString(pin);
@@ -87,15 +105,31 @@ export default function NodeOutputs({
|
|||||||
<div className="mt-2">
|
<div className="mt-2">
|
||||||
<strong className="mr-2">Data:</strong>
|
<strong className="mr-2">Data:</strong>
|
||||||
<div className="mt-1">
|
<div className="mt-1">
|
||||||
{dataArray.slice(0, 10).map((item, index) => (
|
{dataArray.slice(0, 10).map((item, index) => {
|
||||||
<React.Fragment key={index}>
|
const renderer = getItemRenderer?.(item);
|
||||||
<ContentRenderer
|
if (enableEnhancedOutputHandling && renderer) {
|
||||||
value={item}
|
const metadata: OutputMetadata = {};
|
||||||
truncateLongData={truncateLongData}
|
return (
|
||||||
/>
|
<React.Fragment key={index}>
|
||||||
{index < Math.min(dataArray.length, 10) - 1 && ", "}
|
<OutputItem
|
||||||
</React.Fragment>
|
value={item}
|
||||||
))}
|
metadata={metadata}
|
||||||
|
renderer={renderer}
|
||||||
|
/>
|
||||||
|
{index < Math.min(dataArray.length, 10) - 1 && ", "}
|
||||||
|
</React.Fragment>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<React.Fragment key={index}>
|
||||||
|
<ContentRenderer
|
||||||
|
value={item}
|
||||||
|
truncateLongData={truncateLongData}
|
||||||
|
/>
|
||||||
|
{index < Math.min(dataArray.length, 10) - 1 && ", "}
|
||||||
|
</React.Fragment>
|
||||||
|
);
|
||||||
|
})}
|
||||||
{dataArray.length > 10 && (
|
{dataArray.length > 10 && (
|
||||||
<span style={{ color: "#888" }}>
|
<span style={{ color: "#888" }}>
|
||||||
<br />
|
<br />
|
||||||
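
The hunk above gates per-item output rendering behind a feature flag and a renderer registry, falling back to the legacy ContentRenderer when no specialized renderer claims the value. The registry itself is not part of this diff; the following is a hedged sketch of the lookup pattern implied by the `globalRegistry.getRenderer(item, metadata)` call — the internals and interface names beyond that call are assumptions.

// Hypothetical minimal renderer registry matching the getRenderer(item, metadata)
// call shape used above. The real OutputRenderers module may differ.
type OutputMetadata = Record<string, unknown>;

interface OutputRenderer {
  name: string;
  canRender(value: unknown, metadata: OutputMetadata): boolean;
  render(value: unknown, metadata: OutputMetadata): string;
}

class RendererRegistry {
  private renderers: OutputRenderer[] = [];

  register(renderer: OutputRenderer): void {
    this.renderers.push(renderer);
  }

  // Returns the first renderer that claims the value, or null so callers
  // (like the hunk above) can fall back to the legacy ContentRenderer.
  getRenderer(value: unknown, metadata: OutputMetadata): OutputRenderer | null {
    return this.renderers.find((r) => r.canRender(value, metadata)) ?? null;
  }
}

const registry = new RendererRegistry();
registry.register({
  name: "json",
  canRender: (v) => typeof v === "object" && v !== null,
  render: (v) => JSON.stringify(v, null, 2),
});

console.log(registry.getRenderer({ a: 1 }, {})?.name); // "json"
console.log(registry.getRenderer("plain text", {}));   // null → legacy fallback
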

@@ -4,13 +4,13 @@ import { AgentRunDraftView } from "@/app/(platform)/library/agents/[id]/componen
 import { Dialog } from "@/components/molecules/Dialog/Dialog";
 import type {
   CredentialsMetaInput,
-  GraphMeta,
+  Graph,
 } from "@/lib/autogpt-server-api/types";
 
 interface RunInputDialogProps {
   isOpen: boolean;
   doClose: () => void;
-  graph: GraphMeta;
+  graph: Graph;
   doRun?: (
     inputs: Record<string, any>,
     credentialsInputs: Record<string, CredentialsMetaInput>,

@@ -9,13 +9,13 @@ import { CustomNodeData } from "@/app/(platform)/build/components/legacy-builder
 import {
   BlockUIType,
   CredentialsMetaInput,
-  GraphMeta,
+  Graph,
 } from "@/lib/autogpt-server-api/types";
 import RunnerOutputUI, { OutputNodeInfo } from "./RunnerOutputUI";
 import { RunnerInputDialog } from "./RunnerInputUI";
 
 interface RunnerUIWrapperProps {
-  graph: GraphMeta;
+  graph: Graph;
   nodes: Node<CustomNodeData>[];
   graphExecutionError?: string | null;
   saveAndRun: (

@@ -1,5 +1,5 @@
 import { GraphInputSchema } from "@/lib/autogpt-server-api";
-import { GraphMetaLike, IncompatibilityInfo } from "./types";
+import { GraphLike, IncompatibilityInfo } from "./types";
 
 // Helper type for schema properties - the generated types are too loose
 type SchemaProperties = Record<string, GraphInputSchema["properties"][string]>;
@@ -36,7 +36,7 @@ export function getSchemaRequired(schema: unknown): SchemaRequired {
  */
 export function createUpdatedAgentNodeInputs(
   currentInputs: Record<string, unknown>,
-  latestSubGraphVersion: GraphMetaLike,
+  latestSubGraphVersion: GraphLike,
 ): Record<string, unknown> {
   return {
     ...currentInputs,

@@ -1,7 +1,11 @@
-import type { GraphMeta as LegacyGraphMeta } from "@/lib/autogpt-server-api";
+import type {
+  Graph as LegacyGraph,
+  GraphMeta as LegacyGraphMeta,
+} from "@/lib/autogpt-server-api";
+import type { GraphModel as GeneratedGraph } from "@/app/api/__generated__/models/graphModel";
 import type { GraphMeta as GeneratedGraphMeta } from "@/app/api/__generated__/models/graphMeta";
 
-export type SubAgentUpdateInfo<T extends GraphMetaLike = GraphMetaLike> = {
+export type SubAgentUpdateInfo<T extends GraphLike = GraphLike> = {
   hasUpdate: boolean;
   currentVersion: number;
   latestVersion: number;
@@ -10,7 +14,10 @@ export type SubAgentUpdateInfo<T extends GraphMetaLike = GraphMetaLike> = {
   incompatibilities: IncompatibilityInfo | null;
 };
 
-// Union type for GraphMeta that works with both legacy and new builder
+// Union type for Graph (with schemas) that works with both legacy and new builder
+export type GraphLike = LegacyGraph | GeneratedGraph;
+
+// Union type for GraphMeta (without schemas) for version detection
 export type GraphMetaLike = LegacyGraphMeta | GeneratedGraphMeta;
 
 export type IncompatibilityInfo = {
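
The split introduced here — `GraphLike` (full graph, with schemas) versus `GraphMetaLike` (metadata only) — matters because version detection only needs the lightweight shape, while compatibility checks need the schemas. A hedged illustration of the distinction; field names beyond `id` and `version` are assumptions, and the real union members live in the imports above:

// Sketch only: stand-ins for the Legacy/Generated types imported above.
type GraphMetaLike = { id: string; version: number };
type GraphLike = GraphMetaLike & {
  input_schema: Record<string, unknown>;
  output_schema: Record<string, unknown>;
};

// Version detection needs only metadata...
function hasNewerVersion(current: number, candidates: GraphMetaLike[], id: string): boolean {
  const latest = candidates.find((g) => g.id === id);
  return latest !== undefined && latest.version > current;
}

// ...while compatibility checks need the full schemas (crude comparison for illustration).
function sameInputSchema(a: GraphLike, b: GraphLike): boolean {
  return JSON.stringify(a.input_schema) === JSON.stringify(b.input_schema);
}
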

@@ -1,5 +1,11 @@
 import { useMemo } from "react";
-import { GraphInputSchema, GraphOutputSchema } from "@/lib/autogpt-server-api";
+import type {
+  GraphInputSchema,
+  GraphOutputSchema,
+} from "@/lib/autogpt-server-api";
+import type { GraphModel } from "@/app/api/__generated__/models/graphModel";
+import { useGetV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
+import { okData } from "@/app/api/helpers";
 import { getEffectiveType } from "@/lib/utils";
 import { EdgeLike, getSchemaProperties, getSchemaRequired } from "./helpers";
 import {
@@ -11,26 +17,38 @@ import {
 /**
  * Checks if a newer version of a sub-agent is available and determines compatibility
  */
-export function useSubAgentUpdate<T extends GraphMetaLike>(
+export function useSubAgentUpdate(
   nodeID: string,
   graphID: string | undefined,
   graphVersion: number | undefined,
   currentInputSchema: GraphInputSchema | undefined,
   currentOutputSchema: GraphOutputSchema | undefined,
   connections: EdgeLike[],
-  availableGraphs: T[],
-): SubAgentUpdateInfo<T> {
+  availableGraphs: GraphMetaLike[],
+): SubAgentUpdateInfo<GraphModel> {
   // Find the latest version of the same graph
-  const latestGraph = useMemo(() => {
+  const latestGraphInfo = useMemo(() => {
     if (!graphID) return null;
     return availableGraphs.find((graph) => graph.id === graphID) || null;
   }, [graphID, availableGraphs]);
 
-  // Check if there's an update available
+  // Check if there's a newer version available
   const hasUpdate = useMemo(() => {
-    if (!latestGraph || graphVersion === undefined) return false;
-    return latestGraph.version! > graphVersion;
-  }, [latestGraph, graphVersion]);
+    if (!latestGraphInfo || graphVersion === undefined) return false;
+    return latestGraphInfo.version! > graphVersion;
+  }, [latestGraphInfo, graphVersion]);
 
+  // Fetch full graph IF an update is detected
+  const { data: latestGraph } = useGetV1GetSpecificGraph(
+    graphID ?? "",
+    { version: latestGraphInfo?.version },
+    {
+      query: {
+        enabled: hasUpdate && !!graphID && !!latestGraphInfo?.version,
+        select: okData,
+      },
+    },
+  );
+
   // Get connected input and output handles for this specific node
   const connectedHandles = useMemo(() => {
@@ -152,8 +170,8 @@ export function useSubAgentUpdate<T extends GraphMetaLike>(
   return {
     hasUpdate,
     currentVersion: graphVersion || 0,
-    latestVersion: latestGraph?.version || 0,
+    latestVersion: latestGraphInfo?.version || 0,
-    latestGraph,
+    latestGraph: latestGraph || null,
     isCompatible: compatibilityResult.isCompatible,
     incompatibilities: compatibilityResult.incompatibilities,
   };
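
The key change in this hook: the full graph is no longer taken from `availableGraphs` (which now carries only metadata) but is fetched on demand, and only once `hasUpdate` is true. The `enabled` flag is the standard TanStack Query idiom for such conditional fetches. A self-contained sketch of the same pattern — the generated endpoint wrapper and `okData` selector above are project-specific, and the URL shape below is an assumption:

import { useQuery } from "@tanstack/react-query";

// Sketch: fetch the full graph only when an update was detected from metadata.
function useFullGraph(
  graphId: string | undefined,
  version: number | undefined,
  hasUpdate: boolean,
) {
  return useQuery({
    queryKey: ["graph", graphId, version],
    queryFn: async () => {
      const res = await fetch(`/api/graphs/${graphId}?version=${version}`);
      if (!res.ok) throw new Error(`Failed to fetch graph: ${res.status}`);
      return res.json();
    },
    // The query never fires until all preconditions hold — this is what keeps
    // the expensive full-graph fetch off the common no-update path.
    enabled: hasUpdate && !!graphId && version !== undefined,
  });
}
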

@@ -18,7 +18,7 @@ interface GraphStore {
     outputSchema: Record<string, any> | null,
   ) => void;
 
-  // Available graphs; used for sub-graph updates
+  // Available graphs; used for sub-graph updated version detection
   availableSubGraphs: GraphMeta[];
   setAvailableSubGraphs: (graphs: GraphMeta[]) => void;
 

@@ -10,8 +10,8 @@ import React, {
 import {
   CredentialsMetaInput,
   CredentialsType,
+  Graph,
   GraphExecutionID,
-  GraphMeta,
   LibraryAgentPreset,
   LibraryAgentPresetID,
   LibraryAgentPresetUpdatable,
@@ -69,7 +69,7 @@ export function AgentRunDraftView({
   className,
   recommendedScheduleCron,
 }: {
-  graph: GraphMeta;
+  graph: Graph;
   agentActions?: ButtonAction[];
   recommendedScheduleCron?: string | null;
   doRun?: (

@@ -2,8 +2,8 @@
 import React, { useCallback, useMemo } from "react";
 
 import {
+  Graph,
   GraphExecutionID,
-  GraphMeta,
   Schedule,
   ScheduleID,
 } from "@/lib/autogpt-server-api";
@@ -35,7 +35,7 @@ export function AgentScheduleDetailsView({
   onForcedRun,
   doDeleteSchedule,
 }: {
-  graph: GraphMeta;
+  graph: Graph;
   schedule: Schedule;
   agentActions: ButtonAction[];
   onForcedRun: (runID: GraphExecutionID) => void;

@@ -4237,6 +4237,128 @@
         }
       }
     },
+    "/api/mcp/discover-tools": {
+      "post": {
+        "tags": ["v2", "mcp", "mcp"],
+        "summary": "Discover available tools on an MCP server",
+        "description": "Connect to an MCP server and return its available tools.\n\nIf the user has a stored MCP credential for this server URL, it will be\nused automatically — no need to pass an explicit auth token.",
+        "operationId": "postV2Discover available tools on an mcp server",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": { "$ref": "#/components/schemas/DiscoverToolsRequest" }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Successful Response",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/DiscoverToolsResponse"
+                }
+              }
+            }
+          },
+          "401": {
+            "$ref": "#/components/responses/HTTP401NotAuthenticatedError"
+          },
+          "422": {
+            "description": "Validation Error",
+            "content": {
+              "application/json": {
+                "schema": { "$ref": "#/components/schemas/HTTPValidationError" }
+              }
+            }
+          }
+        },
+        "security": [{ "HTTPBearerJWT": [] }]
+      }
+    },
+    "/api/mcp/oauth/callback": {
+      "post": {
+        "tags": ["v2", "mcp", "mcp"],
+        "summary": "Exchange OAuth code for MCP tokens",
+        "description": "Exchange the authorization code for tokens and store the credential.\n\nThe frontend calls this after receiving the OAuth code from the popup.\nOn success, subsequent ``/discover-tools`` calls for the same server URL\nwill automatically use the stored credential.",
+        "operationId": "postV2Exchange oauth code for mcp tokens",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/MCPOAuthCallbackRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Successful Response",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/MCPOAuthCallbackResponse"
+                }
+              }
+            }
+          },
+          "401": {
+            "$ref": "#/components/responses/HTTP401NotAuthenticatedError"
+          },
+          "422": {
+            "description": "Validation Error",
+            "content": {
+              "application/json": {
+                "schema": { "$ref": "#/components/schemas/HTTPValidationError" }
+              }
+            }
+          }
+        },
+        "security": [{ "HTTPBearerJWT": [] }]
+      }
+    },
+    "/api/mcp/oauth/login": {
+      "post": {
+        "tags": ["v2", "mcp", "mcp"],
+        "summary": "Initiate OAuth login for an MCP server",
+        "description": "Discover OAuth metadata from the MCP server and return a login URL.\n\n1. Discovers the protected-resource metadata (RFC 9728)\n2. Fetches the authorization server metadata (RFC 8414)\n3. Performs Dynamic Client Registration (RFC 7591) if available\n4. Returns the authorization URL for the frontend to open in a popup",
+        "operationId": "postV2Initiate oauth login for an mcp server",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": { "$ref": "#/components/schemas/MCPOAuthLoginRequest" }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Successful Response",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/MCPOAuthLoginResponse"
+                }
+              }
+            }
+          },
+          "401": {
+            "$ref": "#/components/responses/HTTP401NotAuthenticatedError"
+          },
+          "422": {
+            "description": "Validation Error",
+            "content": {
+              "application/json": {
+                "schema": { "$ref": "#/components/schemas/HTTPValidationError" }
+              }
+            }
+          }
+        },
+        "security": [{ "HTTPBearerJWT": [] }]
+      }
+    },
     "/api/oauth/app/{client_id}": {
       "get": {
         "tags": ["oauth"],
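
Going by the request/response schemas referenced above (`DiscoverToolsRequest` and `DiscoverToolsResponse`, both defined later in this spec), a client call might look like the following sketch. The base URL and error handling are assumptions; only the route, method, body fields, and JWT security requirement come from the spec itself.

interface MCPTool {
  name: string;
  description: string;
  input_schema: Record<string, unknown>;
}

interface DiscoverToolsResponse {
  tools: MCPTool[];
  server_name?: string | null;
  protocol_version?: string | null;
}

async function discoverMcpTools(
  serverUrl: string,
  jwt: string,
  authToken?: string,
): Promise<DiscoverToolsResponse> {
  const res = await fetch("/api/mcp/discover-tools", {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      // The spec requires HTTPBearerJWT security on this route.
      Authorization: `Bearer ${jwt}`,
    },
    // auth_token is optional: a stored MCP credential for this server URL is
    // used automatically when present, per the endpoint description above.
    body: JSON.stringify({ server_url: serverUrl, auth_token: authToken ?? null }),
  });
  if (!res.ok) throw new Error(`discover-tools failed: ${res.status}`);
  return res.json();
}
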
@@ -5629,7 +5751,9 @@
|
|||||||
"description": "Successful Response",
|
"description": "Successful Response",
|
||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"schema": { "$ref": "#/components/schemas/GraphMeta" }
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/GraphModelWithoutNodes"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -6495,18 +6619,6 @@
|
|||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
"title": "Recommended Schedule Cron"
|
"title": "Recommended Schedule Cron"
|
||||||
},
|
},
|
||||||
"nodes": {
|
|
||||||
"items": { "$ref": "#/components/schemas/Node" },
|
|
||||||
"type": "array",
|
|
||||||
"title": "Nodes",
|
|
||||||
"default": []
|
|
||||||
},
|
|
||||||
"links": {
|
|
||||||
"items": { "$ref": "#/components/schemas/Link" },
|
|
||||||
"type": "array",
|
|
||||||
"title": "Links",
|
|
||||||
"default": []
|
|
||||||
},
|
|
||||||
"forked_from_id": {
|
"forked_from_id": {
|
||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
"title": "Forked From Id"
|
"title": "Forked From Id"
|
||||||
@@ -6514,11 +6626,22 @@
|
|||||||
"forked_from_version": {
|
"forked_from_version": {
|
||||||
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
||||||
"title": "Forked From Version"
|
"title": "Forked From Version"
|
||||||
|
},
|
||||||
|
"nodes": {
|
||||||
|
"items": { "$ref": "#/components/schemas/Node" },
|
||||||
|
"type": "array",
|
||||||
|
"title": "Nodes"
|
||||||
|
},
|
||||||
|
"links": {
|
||||||
|
"items": { "$ref": "#/components/schemas/Link" },
|
||||||
|
"type": "array",
|
||||||
|
"title": "Links"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["name", "description"],
|
"required": ["name", "description"],
|
||||||
"title": "BaseGraph"
|
"title": "BaseGraph",
|
||||||
|
"description": "Graph with nodes, links, and computed I/O schema fields.\n\nUsed to represent sub-graphs within a `Graph`. Contains the full graph\nstructure including nodes and links, plus computed fields for schemas\nand trigger info. Does NOT include user_id or created_at (see GraphModel)."
|
||||||
},
|
},
|
||||||
"BaseGraph-Output": {
|
"BaseGraph-Output": {
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -6539,18 +6662,6 @@
|
|||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
"title": "Recommended Schedule Cron"
|
"title": "Recommended Schedule Cron"
|
||||||
},
|
},
|
||||||
"nodes": {
|
|
||||||
"items": { "$ref": "#/components/schemas/Node" },
|
|
||||||
"type": "array",
|
|
||||||
"title": "Nodes",
|
|
||||||
"default": []
|
|
||||||
},
|
|
||||||
"links": {
|
|
||||||
"items": { "$ref": "#/components/schemas/Link" },
|
|
||||||
"type": "array",
|
|
||||||
"title": "Links",
|
|
||||||
"default": []
|
|
||||||
},
|
|
||||||
"forked_from_id": {
|
"forked_from_id": {
|
||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
"title": "Forked From Id"
|
"title": "Forked From Id"
|
||||||
@@ -6559,6 +6670,16 @@
|
|||||||
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
||||||
"title": "Forked From Version"
|
"title": "Forked From Version"
|
||||||
},
|
},
|
||||||
|
"nodes": {
|
||||||
|
"items": { "$ref": "#/components/schemas/Node" },
|
||||||
|
"type": "array",
|
||||||
|
"title": "Nodes"
|
||||||
|
},
|
||||||
|
"links": {
|
||||||
|
"items": { "$ref": "#/components/schemas/Link" },
|
||||||
|
"type": "array",
|
||||||
|
"title": "Links"
|
||||||
|
},
|
||||||
"input_schema": {
|
"input_schema": {
|
||||||
"additionalProperties": true,
|
"additionalProperties": true,
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@@ -6605,7 +6726,8 @@
|
|||||||
"has_sensitive_action",
|
"has_sensitive_action",
|
||||||
"trigger_setup_info"
|
"trigger_setup_info"
|
||||||
],
|
],
|
||||||
"title": "BaseGraph"
|
"title": "BaseGraph",
|
||||||
|
"description": "Graph with nodes, links, and computed I/O schema fields.\n\nUsed to represent sub-graphs within a `Graph`. Contains the full graph\nstructure including nodes and links, plus computed fields for schemas\nand trigger info. Does NOT include user_id or created_at (see GraphModel)."
|
||||||
},
|
},
|
||||||
"BlockCategoryResponse": {
|
"BlockCategoryResponse": {
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -7195,6 +7317,45 @@
|
|||||||
"required": ["version_counts"],
|
"required": ["version_counts"],
|
||||||
"title": "DeleteGraphResponse"
|
"title": "DeleteGraphResponse"
|
||||||
},
|
},
|
||||||
|
"DiscoverToolsRequest": {
|
||||||
|
"properties": {
|
||||||
|
"server_url": {
|
||||||
|
"type": "string",
|
||||||
|
"title": "Server Url",
|
||||||
|
"description": "URL of the MCP server"
|
||||||
|
},
|
||||||
|
"auth_token": {
|
||||||
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
|
"title": "Auth Token",
|
||||||
|
"description": "Optional Bearer token for authenticated MCP servers"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": "object",
|
||||||
|
"required": ["server_url"],
|
||||||
|
"title": "DiscoverToolsRequest",
|
||||||
|
"description": "Request to discover tools on an MCP server."
|
||||||
|
},
|
||||||
|
"DiscoverToolsResponse": {
|
||||||
|
"properties": {
|
||||||
|
"tools": {
|
||||||
|
"items": { "$ref": "#/components/schemas/MCPToolResponse" },
|
||||||
|
"type": "array",
|
||||||
|
"title": "Tools"
|
||||||
|
},
|
||||||
|
"server_name": {
|
||||||
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
|
"title": "Server Name"
|
||||||
|
},
|
||||||
|
"protocol_version": {
|
||||||
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
|
"title": "Protocol Version"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": "object",
|
||||||
|
"required": ["tools"],
|
||||||
|
"title": "DiscoverToolsResponse",
|
||||||
|
"description": "Response containing the list of tools available on an MCP server."
|
||||||
|
},
|
||||||
"Document": {
|
"Document": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"url": { "type": "string", "title": "Url" },
|
"url": { "type": "string", "title": "Url" },
|
||||||
@@ -7399,18 +7560,6 @@
|
|||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
"title": "Recommended Schedule Cron"
|
"title": "Recommended Schedule Cron"
|
||||||
},
|
},
|
||||||
"nodes": {
|
|
||||||
"items": { "$ref": "#/components/schemas/Node" },
|
|
||||||
"type": "array",
|
|
||||||
"title": "Nodes",
|
|
||||||
"default": []
|
|
||||||
},
|
|
||||||
"links": {
|
|
||||||
"items": { "$ref": "#/components/schemas/Link" },
|
|
||||||
"type": "array",
|
|
||||||
"title": "Links",
|
|
||||||
"default": []
|
|
||||||
},
|
|
||||||
"forked_from_id": {
|
"forked_from_id": {
|
||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
"title": "Forked From Id"
|
"title": "Forked From Id"
|
||||||
@@ -7419,16 +7568,26 @@
|
|||||||
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
||||||
"title": "Forked From Version"
|
"title": "Forked From Version"
|
||||||
},
|
},
|
||||||
|
"nodes": {
|
||||||
|
"items": { "$ref": "#/components/schemas/Node" },
|
||||||
|
"type": "array",
|
||||||
|
"title": "Nodes"
|
||||||
|
},
|
||||||
|
"links": {
|
||||||
|
"items": { "$ref": "#/components/schemas/Link" },
|
||||||
|
"type": "array",
|
||||||
|
"title": "Links"
|
||||||
|
},
|
||||||
"sub_graphs": {
|
"sub_graphs": {
|
||||||
"items": { "$ref": "#/components/schemas/BaseGraph-Input" },
|
"items": { "$ref": "#/components/schemas/BaseGraph-Input" },
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"title": "Sub Graphs",
|
"title": "Sub Graphs"
|
||||||
"default": []
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["name", "description"],
|
"required": ["name", "description"],
|
||||||
"title": "Graph"
|
"title": "Graph",
|
||||||
|
"description": "Creatable graph model used in API create/update endpoints."
|
||||||
},
|
},
|
||||||
"GraphExecution": {
|
"GraphExecution": {
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -7778,6 +7937,52 @@
|
|||||||
"description": "Response schema for paginated graph executions."
|
"description": "Response schema for paginated graph executions."
|
||||||
},
|
},
|
||||||
"GraphMeta": {
|
"GraphMeta": {
|
||||||
|
"properties": {
|
||||||
|
"id": { "type": "string", "title": "Id" },
|
||||||
|
"version": { "type": "integer", "title": "Version" },
|
||||||
|
"is_active": {
|
||||||
|
"type": "boolean",
|
||||||
|
"title": "Is Active",
|
||||||
|
"default": true
|
||||||
|
},
|
||||||
|
"name": { "type": "string", "title": "Name" },
|
||||||
|
"description": { "type": "string", "title": "Description" },
|
||||||
|
"instructions": {
|
||||||
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
|
"title": "Instructions"
|
||||||
|
},
|
||||||
|
"recommended_schedule_cron": {
|
||||||
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
|
"title": "Recommended Schedule Cron"
|
||||||
|
},
|
||||||
|
"forked_from_id": {
|
||||||
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
|
"title": "Forked From Id"
|
||||||
|
},
|
||||||
|
"forked_from_version": {
|
||||||
|
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
||||||
|
"title": "Forked From Version"
|
||||||
|
},
|
||||||
|
"user_id": { "type": "string", "title": "User Id" },
|
||||||
|
"created_at": {
|
||||||
|
"type": "string",
|
||||||
|
"format": "date-time",
|
||||||
|
"title": "Created At"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"id",
|
||||||
|
"version",
|
||||||
|
"name",
|
||||||
|
"description",
|
||||||
|
"user_id",
|
||||||
|
"created_at"
|
||||||
|
],
|
||||||
|
"title": "GraphMeta",
|
||||||
|
"description": "Lightweight graph metadata model representing an existing graph from the database,\nfor use in listings and summaries.\n\nLacks `GraphModel`'s nodes, links, and expensive computed fields.\nUse for list endpoints where full graph data is not needed and performance matters."
|
||||||
|
},
|
||||||
|
"GraphModel": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"id": { "type": "string", "title": "Id" },
|
"id": { "type": "string", "title": "Id" },
|
||||||
"version": { "type": "integer", "title": "Version", "default": 1 },
|
"version": { "type": "integer", "title": "Version", "default": 1 },
|
||||||
@@ -7804,13 +8009,27 @@
|
|||||||
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
||||||
"title": "Forked From Version"
|
"title": "Forked From Version"
|
||||||
},
|
},
|
||||||
|
"user_id": { "type": "string", "title": "User Id" },
|
||||||
|
"created_at": {
|
||||||
|
"type": "string",
|
||||||
|
"format": "date-time",
|
||||||
|
"title": "Created At"
|
||||||
|
},
|
||||||
|
"nodes": {
|
||||||
|
"items": { "$ref": "#/components/schemas/NodeModel" },
|
||||||
|
"type": "array",
|
||||||
|
"title": "Nodes"
|
||||||
|
},
|
||||||
|
"links": {
|
||||||
|
"items": { "$ref": "#/components/schemas/Link" },
|
||||||
|
"type": "array",
|
||||||
|
"title": "Links"
|
||||||
|
},
|
||||||
"sub_graphs": {
|
"sub_graphs": {
|
||||||
"items": { "$ref": "#/components/schemas/BaseGraph-Output" },
|
"items": { "$ref": "#/components/schemas/BaseGraph-Output" },
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"title": "Sub Graphs",
|
"title": "Sub Graphs"
|
||||||
"default": []
|
|
||||||
},
|
},
|
||||||
"user_id": { "type": "string", "title": "User Id" },
|
|
||||||
"input_schema": {
|
"input_schema": {
|
||||||
"additionalProperties": true,
|
"additionalProperties": true,
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@@ -7857,6 +8076,7 @@
|
|||||||
"name",
|
"name",
|
||||||
"description",
|
"description",
|
||||||
"user_id",
|
"user_id",
|
||||||
|
"created_at",
|
||||||
"input_schema",
|
"input_schema",
|
||||||
"output_schema",
|
"output_schema",
|
||||||
"has_external_trigger",
|
"has_external_trigger",
|
||||||
@@ -7865,9 +8085,10 @@
|
|||||||
"trigger_setup_info",
|
"trigger_setup_info",
|
||||||
"credentials_input_schema"
|
"credentials_input_schema"
|
||||||
],
|
],
|
||||||
"title": "GraphMeta"
|
"title": "GraphModel",
|
||||||
|
"description": "Full graph model representing an existing graph from the database.\n\nThis is the primary model for working with persisted graphs. Includes all\ngraph data (nodes, links, sub_graphs) plus user ownership and timestamps.\nProvides computed fields (input_schema, output_schema, etc.) used during\nset-up (frontend) and execution (backend).\n\nInherits from:\n- `Graph`: provides structure (nodes, links, sub_graphs) and computed schemas\n- `GraphMeta`: provides user_id, created_at for database records"
|
||||||
},
|
},
|
||||||
"GraphModel": {
|
"GraphModelWithoutNodes": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"id": { "type": "string", "title": "Id" },
|
"id": { "type": "string", "title": "Id" },
|
||||||
"version": { "type": "integer", "title": "Version", "default": 1 },
|
"version": { "type": "integer", "title": "Version", "default": 1 },
|
||||||
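
The reshuffle above splits what was one schema into three: `GraphMeta` (listing metadata only), `GraphModel` (full graph with nodes, links, and computed schemas), and, just below, `GraphModelWithoutNodes`. In generated-client terms the practical difference is roughly the following sketch (a field subset only; exact generated shapes come from the schemas above):

// Sketch of the generated shapes implied by the schemas above.
interface GraphMeta {
  id: string;
  version: number;
  name: string;
  description: string;
  user_id: string;
  created_at: string; // date-time
}

interface GraphModel extends GraphMeta {
  nodes: unknown[]; // NodeModel[]
  links: unknown[]; // Link[]
  input_schema: Record<string, unknown>;
  output_schema: Record<string, unknown>;
}

// List endpoints can return GraphMeta cheaply; only detail views pay for
// nodes, links, and the computed schema fields on GraphModel.
function summarize(graphs: GraphMeta[]): string[] {
  return graphs.map((g) => `${g.name} v${g.version}`);
}
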
@@ -7886,18 +8107,6 @@
|
|||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
"title": "Recommended Schedule Cron"
|
"title": "Recommended Schedule Cron"
|
||||||
},
|
},
|
||||||
"nodes": {
|
|
||||||
"items": { "$ref": "#/components/schemas/NodeModel" },
|
|
||||||
"type": "array",
|
|
||||||
"title": "Nodes",
|
|
||||||
"default": []
|
|
||||||
},
|
|
||||||
"links": {
|
|
||||||
"items": { "$ref": "#/components/schemas/Link" },
|
|
||||||
"type": "array",
|
|
||||||
"title": "Links",
|
|
||||||
"default": []
|
|
||||||
},
|
|
||||||
"forked_from_id": {
|
"forked_from_id": {
|
||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
"title": "Forked From Id"
|
"title": "Forked From Id"
|
||||||
@@ -7906,12 +8115,6 @@
|
|||||||
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
||||||
"title": "Forked From Version"
|
"title": "Forked From Version"
|
||||||
},
|
},
|
||||||
"sub_graphs": {
|
|
||||||
"items": { "$ref": "#/components/schemas/BaseGraph-Output" },
|
|
||||||
"type": "array",
|
|
||||||
"title": "Sub Graphs",
|
|
||||||
"default": []
|
|
||||||
},
|
|
||||||
"user_id": { "type": "string", "title": "User Id" },
|
"user_id": { "type": "string", "title": "User Id" },
|
||||||
"created_at": {
|
"created_at": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -7973,7 +8176,8 @@
|
|||||||
"trigger_setup_info",
|
"trigger_setup_info",
|
||||||
"credentials_input_schema"
|
"credentials_input_schema"
|
||||||
],
|
],
|
||||||
"title": "GraphModel"
|
"title": "GraphModelWithoutNodes",
|
||||||
|
"description": "GraphModel variant that excludes nodes, links, and sub-graphs from serialization.\n\nUsed in contexts like the store where exposing internal graph structure\nis not desired. Inherits all computed fields from GraphModel but marks\nnodes and links as excluded from JSON output."
|
||||||
},
|
},
|
||||||
"GraphSettings": {
|
"GraphSettings": {
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -8519,6 +8723,71 @@
|
|||||||
"required": ["login_url", "state_token"],
|
"required": ["login_url", "state_token"],
|
||||||
"title": "LoginResponse"
|
"title": "LoginResponse"
|
||||||
},
|
},
|
||||||
|
"MCPOAuthCallbackRequest": {
|
||||||
|
"properties": {
|
||||||
|
"code": {
|
||||||
|
"type": "string",
|
||||||
|
"title": "Code",
|
||||||
|
"description": "Authorization code from OAuth callback"
|
||||||
|
},
|
||||||
|
"state_token": {
|
||||||
|
"type": "string",
|
||||||
|
"title": "State Token",
|
||||||
|
"description": "State token for CSRF verification"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": "object",
|
||||||
|
"required": ["code", "state_token"],
|
||||||
|
"title": "MCPOAuthCallbackRequest",
|
||||||
|
"description": "Request to exchange an OAuth code for tokens."
|
||||||
|
},
|
||||||
|
"MCPOAuthCallbackResponse": {
|
||||||
|
"properties": {
|
||||||
|
"credential_id": { "type": "string", "title": "Credential Id" }
|
||||||
|
},
|
||||||
|
"type": "object",
|
||||||
|
"required": ["credential_id"],
|
||||||
|
"title": "MCPOAuthCallbackResponse",
|
||||||
|
"description": "Response after successfully storing OAuth credentials."
|
||||||
|
},
|
||||||
|
"MCPOAuthLoginRequest": {
|
||||||
|
"properties": {
|
||||||
|
"server_url": {
|
||||||
|
"type": "string",
|
||||||
|
"title": "Server Url",
|
||||||
|
"description": "URL of the MCP server that requires OAuth"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": "object",
|
||||||
|
"required": ["server_url"],
|
||||||
|
"title": "MCPOAuthLoginRequest",
|
||||||
|
"description": "Request to start an OAuth flow for an MCP server."
|
||||||
|
},
|
||||||
|
"MCPOAuthLoginResponse": {
|
||||||
|
"properties": {
|
||||||
|
"login_url": { "type": "string", "title": "Login Url" },
|
||||||
|
"state_token": { "type": "string", "title": "State Token" }
|
||||||
|
},
|
||||||
|
"type": "object",
|
||||||
|
"required": ["login_url", "state_token"],
|
||||||
|
"title": "MCPOAuthLoginResponse",
|
||||||
|
"description": "Response with the OAuth login URL for the user to authenticate."
|
||||||
|
},
|
||||||
|
"MCPToolResponse": {
|
||||||
|
"properties": {
|
||||||
|
"name": { "type": "string", "title": "Name" },
|
||||||
|
"description": { "type": "string", "title": "Description" },
|
||||||
|
"input_schema": {
|
||||||
|
"additionalProperties": true,
|
||||||
|
"type": "object",
|
||||||
|
"title": "Input Schema"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": "object",
|
||||||
|
"required": ["name", "description", "input_schema"],
|
||||||
|
"title": "MCPToolResponse",
|
||||||
|
"description": "A single MCP tool returned by discovery."
|
||||||
|
},
|
||||||
"MarketplaceListing": {
|
"MarketplaceListing": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"id": { "type": "string", "title": "Id" },
|
"id": { "type": "string", "title": "Id" },
|
||||||
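
Taken together with the `/api/mcp/oauth/login` and `/api/mcp/oauth/callback` operations earlier in this diff, these schemas imply a popup flow roughly like the sketch below. The popup plumbing and message passing are assumptions — only the two endpoint shapes come from the spec:

async function connectMcpServer(serverUrl: string, jwt: string): Promise<string> {
  const headers = {
    "Content-Type": "application/json",
    Authorization: `Bearer ${jwt}`,
  };

  // 1. Ask the backend to discover OAuth metadata and build a login URL.
  const loginRes = await fetch("/api/mcp/oauth/login", {
    method: "POST",
    headers,
    body: JSON.stringify({ server_url: serverUrl }),
  });
  const { login_url, state_token } = await loginRes.json();

  // 2. Open the provider's consent screen in a popup and wait for the code.
  //    How the code travels back (postMessage, redirect page, etc.) is
  //    app-specific and not part of this diff.
  const code = await openPopupAndAwaitCode(login_url);

  // 3. Exchange the code; the backend stores the credential, and later
  //    /discover-tools calls for this server URL use it automatically.
  const cbRes = await fetch("/api/mcp/oauth/callback", {
    method: "POST",
    headers,
    body: JSON.stringify({ code, state_token }),
  });
  const { credential_id } = await cbRes.json();
  return credential_id;
}

// Placeholder for the app-specific popup handshake.
declare function openPopupAndAwaitCode(loginUrl: string): Promise<string>;
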
@@ -8613,26 +8882,22 @@
|
|||||||
"input_default": {
|
"input_default": {
|
||||||
"additionalProperties": true,
|
"additionalProperties": true,
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"title": "Input Default",
|
"title": "Input Default"
|
||||||
"default": {}
|
|
||||||
},
|
},
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"additionalProperties": true,
|
"additionalProperties": true,
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"title": "Metadata",
|
"title": "Metadata"
|
||||||
"default": {}
|
|
||||||
},
|
},
|
||||||
"input_links": {
|
"input_links": {
|
||||||
"items": { "$ref": "#/components/schemas/Link" },
|
"items": { "$ref": "#/components/schemas/Link" },
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"title": "Input Links",
|
"title": "Input Links"
|
||||||
"default": []
|
|
||||||
},
|
},
|
||||||
"output_links": {
|
"output_links": {
|
||||||
"items": { "$ref": "#/components/schemas/Link" },
|
"items": { "$ref": "#/components/schemas/Link" },
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"title": "Output Links",
|
"title": "Output Links"
|
||||||
"default": []
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@@ -8712,26 +8977,22 @@
|
|||||||
"input_default": {
|
"input_default": {
|
||||||
"additionalProperties": true,
|
"additionalProperties": true,
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"title": "Input Default",
|
"title": "Input Default"
|
||||||
"default": {}
|
|
||||||
},
|
},
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"additionalProperties": true,
|
"additionalProperties": true,
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"title": "Metadata",
|
"title": "Metadata"
|
||||||
"default": {}
|
|
||||||
},
|
},
|
||||||
"input_links": {
|
"input_links": {
|
||||||
"items": { "$ref": "#/components/schemas/Link" },
|
"items": { "$ref": "#/components/schemas/Link" },
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"title": "Input Links",
|
"title": "Input Links"
|
||||||
"default": []
|
|
||||||
},
|
},
|
||||||
"output_links": {
|
"output_links": {
|
||||||
"items": { "$ref": "#/components/schemas/Link" },
|
"items": { "$ref": "#/components/schemas/Link" },
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"title": "Output Links",
|
"title": "Output Links"
|
||||||
"default": []
|
|
||||||
},
|
},
|
||||||
"graph_id": { "type": "string", "title": "Graph Id" },
|
"graph_id": { "type": "string", "title": "Graph Id" },
|
||||||
"graph_version": { "type": "integer", "title": "Graph Version" },
|
"graph_version": { "type": "integer", "title": "Graph Version" },
|
||||||
@@ -12272,7 +12533,9 @@
|
|||||||
"title": "Location"
|
"title": "Location"
|
||||||
},
|
},
|
||||||
"msg": { "type": "string", "title": "Message" },
|
"msg": { "type": "string", "title": "Message" },
|
||||||
"type": { "type": "string", "title": "Error Type" }
|
"type": { "type": "string", "title": "Error Type" },
|
||||||
|
"input": { "title": "Input" },
|
||||||
|
"ctx": { "type": "object", "title": "Context" }
|
||||||
},
|
},
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["loc", "msg", "type"],
|
"required": ["loc", "msg", "type"],
|
||||||

@@ -22,7 +22,7 @@ const isValidVideoUrl = (url: string): boolean => {
   if (url.startsWith("data:video")) {
     return true;
   }
-  const videoExtensions = /\.(mp4|webm|ogg)$/i;
+  const videoExtensions = /\.(mp4|webm|ogg|mov|avi|mkv|m4v)$/i;
   const youtubeRegex = /^(https?:\/\/)?(www\.)?(youtube\.com|youtu\.?be)\/.+$/;
   const cleanedUrl = url.split("?")[0];
   return (
@@ -44,11 +44,29 @@ const isValidAudioUrl = (url: string): boolean => {
   if (url.startsWith("data:audio")) {
     return true;
   }
-  const audioExtensions = /\.(mp3|wav)$/i;
+  const audioExtensions = /\.(mp3|wav|ogg|m4a|aac|flac)$/i;
   const cleanedUrl = url.split("?")[0];
   return isValidMediaUri(url) && audioExtensions.test(cleanedUrl);
 };
 
+const getVideoMimeType = (url: string): string => {
+  if (url.startsWith("data:video/")) {
+    const match = url.match(/^data:(video\/[^;]+)/);
+    return match?.[1] || "video/mp4";
+  }
+  const extension = url.split("?")[0].split(".").pop()?.toLowerCase();
+  const mimeMap: Record<string, string> = {
+    mp4: "video/mp4",
+    webm: "video/webm",
+    ogg: "video/ogg",
+    mov: "video/quicktime",
+    avi: "video/x-msvideo",
+    mkv: "video/x-matroska",
+    m4v: "video/mp4",
+  };
+  return mimeMap[extension || ""] || "video/mp4";
+};
+
 const VideoRenderer: React.FC<{ videoUrl: string }> = ({ videoUrl }) => {
   const videoId = getYouTubeVideoId(videoUrl);
   return (
@@ -63,7 +81,7 @@ const VideoRenderer: React.FC<{ videoUrl: string }> = ({ videoUrl }) => {
     ></iframe>
   ) : (
     <video controls width="100%" height="315">
-      <source src={videoUrl} type="video/mp4" />
+      <source src={videoUrl} type={getVideoMimeType(videoUrl)} />
       Your browser does not support the video tag.
     </video>
   )}
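
For reference, the new `getVideoMimeType` resolves like so (results follow directly from the mapping table in the hunk; the URLs are illustrative):

getVideoMimeType("https://example.com/clip.mov?sig=abc"); // "video/quicktime" — query string stripped first
getVideoMimeType("data:video/webm;base64,GkXf...");       // "video/webm" — taken from the data URI itself
getVideoMimeType("https://example.com/clip.unknown");     // "video/mp4" — fallback for unrecognized extensions
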

@@ -102,18 +102,6 @@ export function ChatMessage({
     }
   }
 
-  function handleClarificationAnswers(answers: Record<string, string>) {
-    if (onSendMessage) {
-      const contextMessage = Object.entries(answers)
-        .map(([keyword, answer]) => `${keyword}: ${answer}`)
-        .join("\n");
-
-      onSendMessage(
-        `I have the answers to your questions:\n\n${contextMessage}\n\nPlease proceed with creating the agent.`,
-      );
-    }
-  }
-
   const handleCopy = useCallback(
     async function handleCopy() {
       if (message.type !== "message") return;
@@ -162,6 +150,22 @@ export function ChatMessage({
       .slice(index + 1)
       .some((m) => m.type === "message" && m.role === "user");
 
+    const handleClarificationAnswers = (answers: Record<string, string>) => {
+      if (onSendMessage) {
+        // Iterate over questions (preserves original order) instead of answers
+        const contextMessage = message.questions
+          .map((q) => {
+            const answer = answers[q.keyword] || "";
+            return `> ${q.question}\n\n${answer}`;
+          })
+          .join("\n\n");
+
+        onSendMessage(
+          `**Here are my answers:**\n\n${contextMessage}\n\nPlease proceed with creating the agent.`,
+        );
+      }
+    };
+
     return (
       <ClarificationQuestionsWidget
         questions={message.questions}
@@ -346,6 +350,7 @@ export function ChatMessage({
           toolId={message.toolId}
          toolName={message.toolName}
           result={message.result}
+          onSendMessage={onSendMessage}
         />
       </div>
     );
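
Beyond relocating `handleClarificationAnswers` into the clarification branch, this also changes the outgoing message format: answers are now emitted in the original question order, with each question quoted. For two hypothetical questions the sent string would render as:

**Here are my answers:**

> What should the agent be called?

DailyDigest

> How often should it run?

Every morning at 8am

Please proceed with creating the agent.
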

@@ -3,7 +3,7 @@
 import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
 import { cn } from "@/lib/utils";
 import { EyeSlash } from "@phosphor-icons/react";
-import React from "react";
+import React, { useState } from "react";
 import ReactMarkdown from "react-markdown";
 import remarkGfm from "remark-gfm";
 
@@ -48,7 +48,9 @@ interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> {
  */
 function resolveWorkspaceUrl(src: string): string {
   if (src.startsWith("workspace://")) {
-    const fileId = src.replace("workspace://", "");
+    // Strip MIME type fragment if present (e.g., workspace://abc123#video/mp4 → abc123)
+    const withoutPrefix = src.replace("workspace://", "");
+    const fileId = withoutPrefix.split("#")[0];
     // Use the generated API URL helper to get the correct path
     const apiPath = getGetWorkspaceDownloadFileByIdUrl(fileId);
     // Route through the Next.js proxy (same pattern as customMutator for client-side)
@@ -65,13 +67,49 @@ function isWorkspaceImage(src: string | undefined): boolean {
   return src?.includes("/workspace/files/") ?? false;
 }
 
+/**
+ * Renders a workspace video with controls and an optional "AI cannot see" badge.
+ */
+function WorkspaceVideo({
+  src,
+  aiCannotSee,
+}: {
+  src: string;
+  aiCannotSee: boolean;
+}) {
+  return (
+    <span className="relative my-2 inline-block">
+      <video
+        controls
+        className="h-auto max-w-full rounded-md border border-zinc-200"
+        preload="metadata"
+      >
+        <source src={src} />
+        Your browser does not support the video tag.
+      </video>
+      {aiCannotSee && (
+        <span
+          className="absolute bottom-2 right-2 flex items-center gap-1 rounded bg-black/70 px-2 py-1 text-xs text-white"
+          title="The AI cannot see this video"
+        >
+          <EyeSlash size={14} />
+          <span>AI cannot see this video</span>
+        </span>
+      )}
+    </span>
+  );
+}
+
 /**
  * Custom image component that shows an indicator when the AI cannot see the image.
+ * Also handles the "video:" alt-text prefix convention to render <video> elements.
+ * For workspace files with unknown types, falls back to <video> if <img> fails.
  * Note: src is already transformed by urlTransform, so workspace:// is now /api/workspace/...
  */
 function MarkdownImage(props: Record<string, unknown>) {
   const src = props.src as string | undefined;
   const alt = props.alt as string | undefined;
+  const [imgFailed, setImgFailed] = useState(false);
 
   const aiCannotSee = isWorkspaceImage(src);
 
@@ -84,6 +122,18 @@ function MarkdownImage(props: Record<string, unknown>) {
     );
   }
 
+  // Detect video: prefix in alt text (set by formatOutputValue in helpers.ts)
+  if (alt?.startsWith("video:")) {
+    return <WorkspaceVideo src={src} aiCannotSee={aiCannotSee} />;
+  }
+
+  // If the <img> failed to load and this is a workspace file, try as video.
+  // This handles generic output keys like "file_out" where the MIME type
+  // isn't known from the key name alone.
+  if (imgFailed && aiCannotSee) {
+    return <WorkspaceVideo src={src} aiCannotSee={aiCannotSee} />;
+  }
+
   return (
     <span className="relative my-2 inline-block">
       {/* eslint-disable-next-line @next/next/no-img-element */}
@@ -92,6 +142,9 @@ function MarkdownImage(props: Record<string, unknown>) {
         alt={alt || "Image"}
         className="h-auto max-w-full rounded-md border border-zinc-200"
         loading="lazy"
+        onError={() => {
+          if (aiCannotSee) setImgFailed(true);
+        }}
       />
       {aiCannotSee && (
         <span
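
The fragment-stripping change above lets a MIME-type hint ride along on workspace URIs without breaking file-ID resolution; the exact download path still comes from the generated `getGetWorkspaceDownloadFileByIdUrl` helper. A standalone re-statement of just the fileId extraction (IDs are made up):

// workspace://<fileId>[#<mimeType>] → fileId, with the optional fragment dropped
function extractWorkspaceFileId(src: string): string | null {
  if (!src.startsWith("workspace://")) return null;
  return src.replace("workspace://", "").split("#")[0];
}

extractWorkspaceFileId("workspace://abc123#video/mp4"); // "abc123" — "#video/mp4" is only a type hint
extractWorkspaceFileId("workspace://abc123");           // "abc123"
extractWorkspaceFileId("https://example.com/x.png");    // null — not a workspace URI
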

@@ -73,6 +73,7 @@ export function MessageList({
           key={index}
           message={message}
           prevMessage={messages[index - 1]}
+          onSendMessage={onSendMessage}
         />
       );
     }

@@ -5,11 +5,13 @@ import { shouldSkipAgentOutput } from "../../helpers";
 export interface LastToolResponseProps {
   message: ChatMessageData;
   prevMessage: ChatMessageData | undefined;
+  onSendMessage?: (content: string) => void;
 }
 
 export function LastToolResponse({
   message,
   prevMessage,
+  onSendMessage,
 }: LastToolResponseProps) {
   if (message.type !== "tool_response") return null;
 
@@ -21,6 +23,7 @@ export function LastToolResponse({
           toolId={message.toolId}
           toolName={message.toolName}
           result={message.result}
+          onSendMessage={onSendMessage}
         />
       </div>
     );

@@ -1,6 +1,8 @@
+import { Progress } from "@/components/atoms/Progress/Progress";
 import { cn } from "@/lib/utils";
 import { useEffect, useRef, useState } from "react";
 import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
+import { useAsymptoticProgress } from "../ToolCallMessage/useAsymptoticProgress";
 
 export interface ThinkingMessageProps {
   className?: string;
@@ -11,6 +13,7 @@ export function ThinkingMessage({ className }: ThinkingMessageProps) {
   const [showCoffeeMessage, setShowCoffeeMessage] = useState(false);
   const timerRef = useRef<NodeJS.Timeout | null>(null);
   const coffeeTimerRef = useRef<NodeJS.Timeout | null>(null);
+  const progress = useAsymptoticProgress(showCoffeeMessage);
 
   useEffect(() => {
     if (timerRef.current === null) {
@@ -49,9 +52,18 @@ export function ThinkingMessage({ className }: ThinkingMessageProps) {
       <AIChatBubble>
         <div className="transition-all duration-500 ease-in-out">
           {showCoffeeMessage ? (
-            <span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
-              This could take a few minutes, grab a coffee ☕️
-            </span>
+            <div className="flex flex-col items-center gap-3">
+              <div className="flex w-full max-w-[280px] flex-col gap-1.5">
+                <div className="flex items-center justify-between text-xs text-neutral-500">
+                  <span>Working on it...</span>
+                  <span>{Math.round(progress)}%</span>
+                </div>
+                <Progress value={progress} className="h-2 w-full" />
+              </div>
+              <span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
+                This could take a few minutes, grab a coffee ☕️
+              </span>
+            </div>
           ) : showSlowLoader ? (
             <span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
               Taking a bit more time...
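
`useAsymptoticProgress` is imported here but implemented outside this diff. The name suggests the familiar fake-progress pattern: advance quickly at first, then decay toward (but never reach) a ceiling until real completion. A hedged sketch of what such a hook typically looks like — the real implementation in ToolCallMessage may differ in constants and shape:

import { useEffect, useRef, useState } from "react";

// Sketch only. Each tick moves progress a fixed fraction of the remaining
// distance to the ceiling, so it approaches but never reaches it while
// `active` is true; when `active` turns false the value resets.
function useAsymptoticProgress(active: boolean, ceiling = 95, tickMs = 500): number {
  const [progress, setProgress] = useState(0);
  const intervalRef = useRef<ReturnType<typeof setInterval> | null>(null);

  useEffect(() => {
    if (!active) {
      if (intervalRef.current) clearInterval(intervalRef.current);
      setProgress(0);
      return;
    }
    intervalRef.current = setInterval(() => {
      setProgress((p) => p + (ceiling - p) * 0.05);
    }, tickMs);
    return () => {
      if (intervalRef.current) clearInterval(intervalRef.current);
    };
  }, [active, ceiling, tickMs]);

  return progress;
}
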

Some files were not shown because too many files have changed in this diff.