Compare commits

85 Commits

Author SHA1 Message Date
Zamil Majdy
8ad4add219 Merge branch 'dev' into feat/mcp-blocks 2026-02-12 21:23:03 +04:00
Zamil Majdy
3ce375a38c fix(backend/mcp): Collect flattened dynamic fields back into tool_arguments
validate_exec merges get_input_defaults() (which flattens tool_arguments
to top-level) with execution data. The MCP reshaping code then reads
input_data.get("tool_arguments") which misses the flattened values.

Now we explicitly collect tool schema properties from top-level input_data
back into tool_arguments before passing to the block.
2026-02-12 21:06:46 +04:00
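A minimal sketch of the collection step described in 3ce375a38c, assuming the tool schema is a JSON-Schema-style dict with a `properties` key. The function name `collect_tool_arguments` and the sample values are hypothetical, and letting top-level values override nested ones is an assumption rather than the confirmed behavior.

```python
from typing import Any


def collect_tool_arguments(
    input_data: dict[str, Any], tool_input_schema: dict[str, Any]
) -> dict[str, Any]:
    """Gather schema-declared fields from the top level into tool_arguments."""
    # Start from any arguments that are still nested under "tool_arguments".
    tool_arguments = dict(input_data.get("tool_arguments") or {})
    for name in tool_input_schema.get("properties", {}):
        # Flattened values sit at the top level of the merged input
        # (assumption: they take precedence over nested ones).
        if name in input_data:
            tool_arguments[name] = input_data[name]
    return tool_arguments


# Hypothetical merged input: block fields and flattened tool fields side by side.
merged = {"server_url": "https://mcp.example.com", "query": "open issues", "limit": 5}
schema = {"properties": {"query": {"type": "string"}, "limit": {"type": "integer"}}}
print(collect_tool_arguments(merged, schema))
```

Only keys named in the tool schema are collected, so block-level fields such as `server_url` stay out of `tool_arguments`.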
Nicholas Tindle
cb166dd6fb feat(blocks): Store sandbox files to workspace (#12073)
Store files created by sandbox blocks (Claude Code, Code Executor) to
the user's workspace for persistence across runs.

### Changes 🏗️

- **New `sandbox_files.py` utility** (`backend/util/sandbox_files.py`)
  - Shared module for extracting files from E2B sandboxes
  - Stores files to workspace via `store_media_file()` (includes virus scanning, size limits)
  - Returns `SandboxFileOutput` with path, content, and `workspace_ref`

- **Claude Code block** (`backend/blocks/claude_code.py`)
  - Added `workspace_ref` field to `FileOutput` schema
  - Replaced inline `_extract_files()` with shared utility
  - Files from working directory now stored to workspace automatically

- **Code Executor block** (`backend/blocks/code_executor.py`)
  - Added `files` output field to `ExecuteCodeBlock.Output`
  - Creates `/output` directory in sandbox before execution
  - Extracts all files (text + binary) from `/output` after execution
  - Updated `execute_code()` to support file extraction with `extract_files` param

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Create agent with Claude Code block, have it create a file, verify `workspace_ref` in output
  - [x] Create agent with Code Executor block, write file to `/output`, verify `workspace_ref` in output
  - [x] Verify files persist in workspace after sandbox disposal
  - [x] Verify binary files (images, etc.) work correctly in Code Executor
  - [x] Verify existing graphs using `content` field still work (backward compat)

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

No configuration changes required - this is purely additive backend
code.

---

**Related:** Closes SECRT-1931

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Adds automatic extraction and workspace storage of sandbox-written
files (including binaries for code execution), which can affect output
payload size, performance, and file-handling edge cases.
> 
> **Overview**
> **Sandbox blocks now persist generated files to workspace.** A new
shared utility (`backend/util/sandbox_files.py`) extracts files from an
E2B sandbox (scoped by a start timestamp) and stores them via
`store_media_file`, returning `SandboxFileOutput` with `workspace_ref`.
> 
> `ClaudeCodeBlock` replaces its inline file-scraping logic with this
utility and updates the `files` output schema to include
`workspace_ref`.
> 
> `ExecuteCodeBlock` adds a `files` output and extends the executor
mixin to optionally extract/store files (text + binary) when an
`execution_context` is provided; related mocks/tests and docs are
updated accordingly.
> 
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 15:56:59 +00:00
Zamil Majdy
2e4fd05864 fix(backend/mcp): Add null metadata check in _auto_lookup_credential
Same defensive pattern as c8de1d6 — use (cred.metadata or {}).get(...)
to avoid AttributeError when credential metadata is None.
2026-02-12 19:54:22 +04:00
Zamil Majdy
d1a9db75f1 Merge branch 'dev' into feat/mcp-blocks 2026-02-12 19:40:39 +04:00
Swifty
3d31f62bf1 Revert "added feature request tooling"
This reverts commit b8b6c9de23.
2026-02-12 16:39:24 +01:00
Swifty
b8b6c9de23 added feature request tooling 2026-02-12 16:38:17 +01:00
Zamil Majdy
c8de1d6dd9 fix(backend/mcp): Add defensive null check for credential metadata access
Guard against NoneType AttributeError when cred.metadata is None by
using (cred.metadata or {}).get() pattern consistently.
2026-02-12 19:38:00 +04:00
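The `(cred.metadata or {}).get()` pattern from c8de1d6dd9 can be illustrated as follows. `StoredCredential` is a hypothetical stand-in for the real credential model, and `server_url_of` is an illustrative helper. The `mcp_server_url` key is the one mentioned elsewhere in this log.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class StoredCredential:  # hypothetical stand-in, not the real model
    id: str
    metadata: Optional[dict[str, Any]] = None


def server_url_of(cred: StoredCredential) -> Optional[str]:
    # `or {}` swaps a None metadata for an empty dict, so .get() never
    # raises AttributeError; a missing key simply yields None.
    return (cred.metadata or {}).get("mcp_server_url")


print(server_url_of(StoredCredential(id="no-metadata")))
print(server_url_of(StoredCredential(id="ok", metadata={"mcp_server_url": "https://mcp.example.com"})))
```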
Zamil Majdy
8663d7a5ba fix(frontend): Add response status narrowing for generated MCP API types
The generated API functions return discriminated union types (success |
error). Add status checks to narrow the type before accessing
success-specific fields, and convert null to undefined for
CredentialsMetaResponse compatibility.
2026-02-12 19:28:57 +04:00
Zamil Majdy
2dd09943cb fix: Restore openapi.json with MCP endpoints
The previous regeneration fetched from a dev-branch server (port 8006)
that didn't have MCP routes, removing MCP endpoints from the spec.
2026-02-12 18:38:50 +04:00
Zamil Majdy
03522395c0 chore: Regenerate OpenAPI spec and apply formatting 2026-02-12 18:32:41 +04:00
Zamil Majdy
5ca2e4cd47 refactor(frontend): Migrate MCP API calls to generated queries/types
Replace manual BackendAPI.mcpDiscoverTools, mcpOAuthLogin, and
mcpOAuthCallback methods with Orval-generated functions from
__generated__/endpoints/mcp/mcp.ts. Remove manual MCPTool and
MCPDiscoverToolsResponse types from types.ts in favor of generated
MCPToolResponse and DiscoverToolsResponse models.
2026-02-12 18:24:49 +04:00
Zamil Majdy
a4d194cb07 fix(frontend): Address PR review comments
- Remove legacy-builder changes: revert BlocksControl.tsx, CustomNode.tsx,
  and Flow.tsx to dev state
- Move MCPToolDialog.tsx from legacy-builder/ to components/ and update
  import path in NewControlPanel/NewBlockMenu/Block.tsx
- Revert out-of-scope CredentialsSelect auto-defaulting behavior
- Remove `as any` cast in CredentialsInput.tsx display name
2026-02-12 18:08:44 +04:00
Zamil Majdy
3ed4d6e56b Merge branch 'dev' into feat/mcp-blocks
Resolve conflict in backend/data/block.py caused by the circular
import refactor (113e87a23) that moved Block classes to
backend/blocks/_base.py. Added MCP_TOOL enum member to BlockType in
its new location and updated mcp/block.py imports accordingly.
2026-02-12 17:47:27 +04:00
Abhimanyu Yadav
4f6055f494 refactor(frontend): remove default expiration date from API key credentials form (#12092)
### Changes 🏗️

Removed the default expiration date for API keys in the credentials
modal. Previously, API keys were set to expire the next day by default,
but now the expiration date field starts empty, allowing users to
explicitly choose whether they want to set an expiration date.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Open the API key credentials modal and verify the expiration date field is empty by default
  - [x] Test creating an API key with and without an expiration date
  - [x] Verify both scenarios work correctly

<!-- greptile_comment -->

<h2>Greptile Overview</h2>

<details><summary><h3>Greptile Summary</h3></summary>

Removed the default expiration date for API key credentials in the
credentials modal. Previously, API keys were automatically set to expire
the next day at midnight. Now the expiration date field starts empty,
allowing users to explicitly choose whether to set an expiration.

- Removed `getDefaultExpirationDate()` helper function that calculated
tomorrow's date
- Changed default `expiresAt` value from calculated date to empty string
- Backend already supports optional expiration (`expires_at?: number`),
so no backend changes needed
- Form submission correctly handles empty expiration by passing
`undefined` to the API
</details>


<details><summary><h3>Confidence Score: 5/5</h3></summary>

- This PR is safe to merge with minimal risk
- The changes are straightforward and well-contained. The refactor
removes a helper function and changes a default value. The backend API
already supports optional expiration dates, and the form submission
logic correctly handles empty values by passing undefined. The change
improves UX by not forcing a default expiration date on users.
- No files require special attention
</details>


<!-- greptile_other_comments_section -->

<!-- /greptile_comment -->
2026-02-12 12:57:06 +00:00
Otto
695a185fa1 fix(frontend): remove fixed min-height from CoPilot message container (#12091)
## Summary

Removes the `min-h-screen` class from `ConversationContent` in
ChatMessagesContainer, which was causing fixed height layout issues in
the CoPilot chat interface.

## Changes

- Removed `min-h-screen` from ConversationContent className

## Linear

Fixes [SECRT-1944](https://linear.app/autogpt/issue/SECRT-1944)

<!-- greptile_comment -->

<h2>Greptile Overview</h2>

<details><summary><h3>Greptile Summary</h3></summary>

Removes the `min-h-screen` (100vh) class from `ConversationContent` that
was causing the chat message container to enforce a minimum viewport
height. The parent container already handles height constraints with
`h-full min-h-0` and flexbox layout, so the fixed minimum height was
creating layout conflicts. The component now properly grows within its
flex container using `flex-1`.
</details>


<details><summary><h3>Confidence Score: 5/5</h3></summary>

- This PR is safe to merge with minimal risk
- The change removes a single problematic CSS class that was causing
fixed height layout issues. The parent container already handles height
constraints properly with flexbox, and removing min-h-screen allows the
component to size correctly within its flex parent. This is a targeted,
low-risk bug fix with no logic changes.
- No files require special attention
</details>


<!-- greptile_other_comments_section -->

<!-- /greptile_comment -->
2026-02-12 12:46:29 +00:00
Reinier van der Leer
113e87a23c refactor(backend): Reduce circular imports (#12068)
I'm getting circular import issues because there is a lot of
cross-importing between `backend.data`, `backend.blocks`, and other
modules. This change reduces block-related cross-imports and thus the
risk of circular-import breakage.

### Changes 🏗️

- Strip down `backend.data.block`
  - Move `Block` base class and related class/enum defs to `backend.blocks._base`
  - Move `is_block_auth_configured` to `backend.blocks._utils`
  - Move `get_blocks()`, `get_io_block_ids()` etc. to `backend.blocks` (`__init__.py`)
  - Update imports everywhere
- Remove unused and poorly typed `Block.create()`
  - Change usages from `block_cls.create()` to `block_cls()`
- Improve typing of `load_all_blocks` and `get_blocks`
- Move cross-import of `backend.api.features.library.model` from
`backend/data/__init__.py` to `backend/data/integrations.py`
- Remove deprecated attribute `NodeModel.webhook`
  - Re-generate OpenAPI spec and fix frontend usage
- Eliminate module-level `backend.blocks` import from `blocks/agent.py`
- Eliminate module-level `backend.data.execution` and
`backend.executor.manager` imports from `blocks/helpers/review.py`
- Replace `BlockInput` with `GraphInput` for graph inputs

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - CI static type-checking + tests should be sufficient for this
2026-02-12 12:07:49 +00:00
Abhimanyu Yadav
d09f1532a4 feat(frontend): replace legacy builder with new flow editor (#12081)

### Changes 🏗️

This PR completes the migration from the legacy builder to the new Flow
editor by removing all legacy code and feature flags.

**Removed:**
- Old builder view toggle functionality (`BuilderViewTabs.tsx`)
- Legacy debug panel (`RightSidebar.tsx`)
- Feature flags: `NEW_FLOW_EDITOR` and `BUILDER_VIEW_SWITCH`
- `useBuilderView` hook and related view-switching logic

**Updated:**
- Simplified `build/page.tsx` to always render the new Flow editor
- Added CSS styling (`flow.css`) to properly render Phosphor icons in
React Flow handles

**Tests:**
- Skipped e2e test suite in `build.spec.ts` (legacy builder tests)
- Follow-up PR (#12082) will add new e2e tests for the Flow editor

### Checklist 📋

#### For code changes:

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
    - [x] Create a new flow and verify it loads correctly
    - [x] Add nodes and connections to verify basic functionality works
    - [x] Verify that node handles render correctly with the new CSS
    - [x] Check that the UI is clean without the old debug panel or view toggles

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
2026-02-12 11:16:01 +00:00
Zamil Majdy
cd242dcda7 Merge branch 'dev' into feat/mcp-blocks 2026-02-12 11:51:06 +04:00
Zamil Majdy
19d8ba941b Merge branch 'feat/mcp-blocks' of github.com:Significant-Gravitas/AutoGPT into feat/mcp-blocks 2026-02-11 22:02:30 +04:00
Zamil Majdy
db760a7ce9 fix(mcp): Remove silent exception swallowing in credential lookup
Let credential lookup errors propagate in discover_tools endpoint
instead of silently catching all exceptions. Upgrade block.py
auto-lookup logging from debug to warning. Update tests to mock
creds_manager so they don't hit the database.
2026-02-11 07:23:52 +04:00
Zamil Majdy
ab0ec3cfe7 Merge branch 'dev' into feat/mcp-blocks 2026-02-11 07:18:05 +04:00
Zamil Majdy
3263d50f4b refactor(mcp): Use BlockUIType.MCP_TOOL instead of SpecialBlockID checks
Add MCP_TOOL to backend BlockType enum and frontend BlockUIType enum,
matching the existing pattern used by AGENT blocks. Replace all
SpecialBlockID.MCP_TOOL type-checks with uiType-based checks.
2026-02-11 07:10:22 +04:00
Zamil Majdy
99a3891bb6 fix(backend/mcp): Add defensive validation for OAuth callback and token refresh
- Validate frontend_base_url in mcp_oauth_callback (matching login route)
- Validate mcp_token_url metadata before attempting token refresh
2026-02-11 06:49:04 +04:00
Zamil Majdy
ade1072e4a fix(backend/executor): Nullify credentials when deleted instead of using field default
When a configured credential is deleted, set input_data to None (consistent
with the "no credentials" path at line 279) instead of the field's raw
default value ({}), which would fail CredentialsMetaInput validation.
2026-02-11 06:41:41 +04:00
Zamil Majdy
23f092d65c fix(backend/mcp): Use httpx.AsyncClient in test_routes to prevent event loop corruption
Replace fastapi.testclient.TestClient with httpx.AsyncClient + ASGITransport.

TestClient creates a new anyio blocking portal per request. When 11+ portals
are created and destroyed in a session that also has pytest-asyncio session-scoped
async fixtures, the session event loop gets corrupted, causing
"RuntimeError: Event loop is closed" in subsequent async tests.

AsyncClient with ASGITransport runs the ASGI app directly in the current
event loop without creating blocking portals.
2026-02-11 06:39:14 +04:00
Zamil Majdy
4ac025da09 fix(backend/mcp): Fix test_routes.py tests and refine exception handling
- discover_tools: Keep generic Exception catch → 502 for MCP connectivity errors
- mcp_oauth_callback: Add targeted exception handler for token exchange → 400
- test_routes: Fix discover_auth_server_metadata mock for no-OAuth-support case
- oauth.py: Auto-formatting only
2026-02-10 22:22:38 +04:00
Zamil Majdy
fbe4c740cb refactor(backend/mcp): Remove useless catch-reraise exception patterns in routes
Drop the broad `except Exception` catch-and-reraise-as-HTTPException
blocks. Keep only the meaningful error handlers (HTTPClientError for
401/403, MCPClientError for 502). Unhandled exceptions now propagate
naturally to FastAPI's default 500 handler.
2026-02-10 22:09:42 +04:00
Zamil Majdy
ff48f4335b fix(backend/mcp): Set loop_scope=session on all async MCP tests
Matches the pattern used by oauth_test.py to prevent event loop
conflicts with session-scoped fixtures (server, graph_cleanup).
2026-02-10 22:02:48 +04:00
Zamil Majdy
11fbb51a70 fix(tests): Remove MCP conftest.py to fix session event loop conflict
The MCP conftest.py with pytest hooks (pytest_addoption,
pytest_collection_modifyitems) was disrupting pytest-asyncio's session
event loop lifecycle, causing the SpinTestServer to be torn down before
session-scoped oauth tests could run.

Replace the conftest-based e2e gating with a simple pytestmark skipif
in the test file itself.
2026-02-10 20:48:30 +04:00
Zamil Majdy
bb8b56c7de fix(mcp): Validate JSON-RPC response is a dict before accessing keys
Add isinstance check after response.json() to prevent TypeError/
AttributeError if an MCP server returns a non-object JSON response.
2026-02-10 20:31:37 +04:00
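A sketch of the type check described in bb8b56c7de, assuming the response body has already been read as text. The function name `parse_jsonrpc_body` and the choice of `ValueError` are illustrative, not taken from the MCP client code.

```python
import json
from typing import Any


def parse_jsonrpc_body(raw: str) -> dict[str, Any]:
    """Decode a JSON-RPC response and insist that it is a JSON object."""
    body = json.loads(raw)
    # A server may legally return any JSON value (list, string, number);
    # reject anything but an object before indexing into it.
    if not isinstance(body, dict):
        raise ValueError(f"expected a JSON object, got {type(body).__name__}")
    return body


print(parse_jsonrpc_body('{"jsonrpc": "2.0", "id": 1, "result": {}}')["id"])
```

Without the check, a body such as `[1, 2]` would only fail later, at the first key access, with a less helpful `TypeError` or `AttributeError`.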
Zamil Majdy
6ebd97f874 fix(executor): Extract only tool_arguments for MCPToolBlock input reshaping
The entire merged input_data dict (containing server_url, credentials,
selected_tool, etc.) was being assigned to tool_arguments instead of
just the tool_arguments sub-dict. This would cause validation failures
or MCP server rejections.
2026-02-10 20:22:47 +04:00
Zamil Majdy
e934f0d0c2 fix(tests): Remove session-scoped fixture overrides from MCP conftest
The MCP conftest.py was overriding session-scoped `server` and
`graph_cleanup` fixtures with no-op versions. Having two session-scoped
fixtures with the same name at different directory levels caused
pytest-asyncio event loop conflicts, making all oauth_test.py tests
fail with "Event loop is closed".

Since these fixtures are session-scoped and shared across the entire
test run, the override was unnecessary — the SpinTestServer is already
created for other tests.

Also adds defensive `access_token` key validation in MCP OAuth token
exchange and refresh to prevent KeyError on malformed responses.
2026-02-10 20:17:07 +04:00
Zamil Majdy
3e38b141dd fix(tests): Use async pytest_asyncio fixtures in MCP conftest
The MCP conftest's sync server/graph_cleanup fixtures must match the
parent conftest's async pytest_asyncio fixtures to avoid disrupting
the session event loop management, which caused "Event loop is closed"
errors in oauth_test.py tests.
2026-02-10 18:40:22 +04:00
Zamil Majdy
ec72e7eb7b fix(tests): Restore pytest_asyncio.fixture with loop_scope for session fixtures
The server and graph_cleanup fixtures in conftest.py require explicit
pytest_asyncio.fixture(loop_scope="session") to properly manage the
session event loop. Using plain pytest.fixture causes "Event loop is
closed" errors in all oauth_test.py tests.
2026-02-10 18:16:36 +04:00
Zamil Majdy
db038cd0e0 fix(tests): Revert oauth_test.py fixture scope changes to fix event loop error
Restores session-scoped fixtures and pytest_asyncio decorators that were
accidentally changed, causing "RuntimeError: Event loop is closed" in
test_authorize_creates_code_in_database. Also regenerates openapi.json.
2026-02-10 17:32:52 +04:00
Zamil Majdy
6805a4f3c5 Merge branch 'dev' into feat/mcp-blocks 2026-02-10 17:27:23 +04:00
Zamil Majdy
cb7a0cbdd7 refactor(mcp): Make create_mcp_oauth_handler public and use top-level import 2026-02-10 17:08:03 +04:00
Zamil Majdy
79d6e8e2d7 fix(frontend/credentials): Handle stale credential IDs in CredentialsSelect
When a credential is deleted but the node still references its ID,
CredentialsSelect now treats the stale ID as unselected and falls
back to the first available credential instead of showing the raw ID.
2026-02-10 17:02:08 +04:00
Zamil Majdy
472117a872 fix(mcp): Handle MCP credential deletion with dynamic OAuth handler
MCP credentials use per-server dynamic OAuth handlers, not a static
handler registered in HANDLERS_BY_NAME. The delete endpoint now
creates a dynamic handler from credential metadata for token
revocation instead of failing with "Provider 'mcp' does not support
OAuth".
2026-02-10 16:47:36 +04:00
Zamil Majdy
75a7ccf36e fix(mcp): Address PR review comments - defensive checks and docs
- Validate token_endpoint in OAuth metadata before accessing it
- Check authorization_servers list is non-empty before indexing
- Use provider_matches() (renamed from private _provider_matches) in
  creds_manager for Python 3.13 StrEnum compatibility
- Fill in MCP block documentation with technical explanation and use cases
2026-02-10 16:42:29 +04:00
Zamil Majdy
84809f4b94 fix(mcp): Set customized_name at block creation and add pre-run validation
- Set customized_name in metadata when MCP and Agent blocks are created
  (both legacy and new builder) so titles persist through save/load
- Remove convoluted agent_name fallback from NodeHeader and getNodeTitle
- Add custom block-level validation in graph pre-run checks so MCP tool
  arguments are validated before execution
- Fix server_name fallback to URL hostname in discover_tools endpoint
2026-02-10 16:34:19 +04:00
Zamil Majdy
4364a771d4 fix(mcp): Validate required tool args and fix title fallback for existing blocks
Backend: Add required-field validation in MCPToolBlock.run() before
calling the MCP server. The executor-level validation is bypassed for
MCP blocks because get_input_defaults() flattens tool_arguments,
stripping tool_input_schema from the validation context.

Frontend: NodeHeader now derives the MCP server label from the server
URL hostname when server_name is missing (pruned by pruneEmptyValues).
This fixes the title for existing blocks that don't have customized_name
in metadata.
2026-02-10 15:42:51 +04:00
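The backend half of 4364a771d4 amounts to checking the tool schema's required fields before calling the server. A minimal sketch, assuming a JSON-Schema-style `required` list. The helper name is hypothetical, and treating `None` as missing is an assumption.

```python
from typing import Any


def missing_required_args(
    tool_arguments: dict[str, Any], tool_input_schema: dict[str, Any]
) -> list[str]:
    """Return the names of required tool arguments that were not supplied."""
    required = tool_input_schema.get("required", [])
    # Assumption: an explicit None counts as "not supplied".
    return [name for name in required if tool_arguments.get(name) is None]


schema = {"properties": {"query": {"type": "string"}}, "required": ["query"]}
print(missing_required_args({}, schema))
print(missing_required_args({"query": "open issues"}, schema))
```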
Zamil Majdy
4d4ed562f0 fix(frontend/mcp): Ensure MCP block title persists across save/refresh
When the MCP server returns a null server_name, fall back to the URL
hostname so customized_name is always set in metadata. This prevents
the title from degrading to "MCP:" after save and reload.
2026-02-10 15:31:29 +04:00
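The hostname fallback in 4d4ed562f0 lives in the frontend. The same idea in Python, as a rough sketch: `display_name` is a hypothetical helper, and the `MCP: <label>` title format is inferred from the display names quoted elsewhere in this log.

```python
from typing import Optional
from urllib.parse import urlparse


def display_name(server_name: Optional[str], server_url: str) -> str:
    """Build a block title, falling back to the URL hostname."""
    # Use the server-provided name when present, else the hostname,
    # else the raw URL if it cannot be parsed into a hostname.
    label = server_name or urlparse(server_url).hostname or server_url
    return f"MCP: {label}"


print(display_name(None, "https://mcp.sentry.dev/mcp"))  # MCP: mcp.sentry.dev
```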
Zamil Majdy
8bea7cf875 chore(mcp): Remove dev artifacts and simplify credential lookup
- Remove MCP_BLOCK_IMPLEMENTATION.md development doc
- Remove console.log debug statements from OAuth callback
- Simplify credential lookup to single call (get_creds_by_provider
  already handles Python 3.13 StrEnum bug via _provider_matches)
- Remove unused Credentials import from routes.py
2026-02-10 15:18:55 +04:00
Zamil Majdy
c1c269c4a9 fix(frontend/credentials): Auto-select first credential and persist MCP block title
- CredentialsSelect: default to first available credential instead of
  "None" when credentials exist, reorder options to show credentials
  before the "None" option, and notify parent on auto-select
- Revert CredentialsGroupedView user auto-select effect (now handled
  at the CredentialsSelect level)
- Block.tsx: persist MCP block title as customized_name in metadata
  so it survives save/load
2026-02-10 14:56:19 +04:00
Zamil Majdy
65987ff15e fix(frontend/credentials): Auto-select user credentials in run dialog
Add auto-selection for user credentials (like MCP OAuth) in the
CredentialsGroupedView run dialog. When exactly one credential matches
the provider, type, and discriminator values (e.g. MCP server URL),
it is pre-selected instead of defaulting to "None (skip this credential)".
2026-02-10 14:21:23 +04:00
Zamil Majdy
ed50f7f87d fix(mcp): Wire credentials into MCP block form and add auto-lookup fallback
Frontend: Include credentials field in MCP block's dynamic input schema
so users can select OAuth credentials from the node form. Separate
credentials from tool_arguments in FormCreator to store them at the
correct level in hardcodedValues.

Backend: Add _auto_lookup_credential fallback in MCPToolBlock.run() for
legacy nodes that don't have credentials explicitly set. This resolves
the credential by matching mcp_server_url in stored OAuth metadata.
2026-02-10 14:05:09 +04:00
Zamil Majdy
c03fb170e0 fix(backend/credentials): Handle Python 3.13 str(StrEnum) bug in OAuth state verification
verify_state_token and get_creds_by_provider compared provider strings
with ==, which failed when OAuth states were stored with the buggy
"ProviderName.MCP" format from Python 3.13's str(Enum) behavior.

Also fix double-append in store_state_token where the state was written
once via edit_user_integrations and again via a redundant manual block.
2026-02-10 13:32:38 +04:00
Zamil Majdy
8a2f98b23c fix(mcp): Fix discover_tools test mock and credential auto-unselect
- Add missing `refresh_if_needed` mock to test_discover_tools_auto_uses_stored_credential
  so it returns the stored credential instead of a MagicMock
- Fix credential auto-unselect clearing MCP credentials on initial render:
  skip the "unselect if not available" check when the saved credentials
  list is empty (empty list means not loaded yet, not invalid)
2026-02-10 12:55:35 +04:00
Zamil Majdy
5e2ae3cec5 Merge branch 'dev' into feat/mcp-blocks 2026-02-10 12:45:34 +04:00
Zamil Majdy
f8771484fe fix(mcp): Fix CI failures and credential validation issues
- Fix pyright errors in graph_test.py by properly typing frozenset[CredentialsType]
- Fix executor validation crash when credentials is empty {} by nullifying
  the field before JSON schema validation
- Exclude MCP Tool block from e2e block discovery test (requires dialog)
- Normalize provider string in CredentialsMetaResponse to handle Python 3.13
  str(Enum) bug for stored credentials
- Fix get_host() to match MCP provider regardless of enum string format
2026-02-10 12:34:00 +04:00
Zamil Majdy
81e4f0a4b0 fix(mcp): Refresh expired OAuth tokens before tool discovery
The discover_tools endpoint was reading raw access tokens from stored
credentials without checking if they had expired. This caused users
to be prompted to re-authenticate every time the token expired (~1h).

Now uses creds_manager.refresh_if_needed() to transparently refresh
expired tokens before using them.
2026-02-10 12:18:43 +04:00
Zamil Majdy
66aada30f0 fix(tests): Revert conftest.py and oauth_test.py fixture scope changes
The pytest_asyncio fixture changes with loop_scope="session" caused
"Event loop is closed" errors in all 31 oauth_test.py tests on CI.
MCP tests have their own conftest override and don't need these changes.
2026-02-10 11:56:28 +04:00
Zamil Majdy
74e04f71f4 fix(tests): Properly mock get_required_fields in credential validation tests
The tests used MagicMock for block.input_schema but didn't mock
get_required_fields(), causing the "required missing creds" test to
silently treat all credentials as optional.
2026-02-10 11:44:12 +04:00
Zamil Majdy
4db27ca112 fix(mcp): Auto-select credential and gracefully handle stale IDs
- Auto-select credential when exactly one match exists (even for
  optional fields). Only skip auto-select for optional fields with
  multiple choices.
- In executor, catch ValueError from creds_manager.acquire() for
  optional credential fields — fall back to running without credentials
  instead of crashing when stale IDs reference deleted credentials.
2026-02-10 11:25:06 +04:00
Zamil Majdy
27ba4e8e93 fix(frontend): Remove credential field reordering on selection
The sortByUnsetFirst comparator in splitCredentialFieldsBySystem
caused credential inputs to jump positions every time a credential
was selected (set fields moved to bottom, unset moved to top).
Remove the sort to keep stable ordering.
2026-02-10 11:18:42 +04:00
Zamil Majdy
1a1985186a fix(mcp): Normalize credential input_data for JSON schema validation
The model_validator on CredentialsMetaInput normalizes legacy
"ProviderName.MCP" format for Pydantic validation, but validate_data()
uses raw JSON schema which bypasses Pydantic. Write normalized values
back to input_data after Pydantic processes them so both validation
paths see correct data.
2026-02-10 11:15:10 +04:00
Zamil Majdy
8fd13ade74 fix(mcp): Normalize legacy ProviderName format and fix credential optionality
- Add model_validator on CredentialsMetaInput to auto-normalize old
  "ProviderName.MCP" format to "mcp" at the model level, eliminating
  the need for string parsing hacks in every consumer.

- Fix aggregate_credentials_inputs to check block schema defaults when
  determining if credentials are required, not just node metadata.
  MCP blocks with default={} are always optional regardless of metadata.
2026-02-10 09:43:28 +04:00
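The normalization in 8fd13ade74 is implemented as a Pydantic `model_validator`. The core string handling might look like the sketch below. The function name is hypothetical, and lower-casing the member name to obtain the stored value is an assumption that holds for `MCP` -> `mcp` but is not confirmed for other providers.

```python
LEGACY_PREFIX = "ProviderName."


def normalize_provider(provider: str) -> str:
    """Map the legacy "ProviderName.MCP" form to the plain value "mcp"."""
    if provider.startswith(LEGACY_PREFIX):
        # Assumption: the enum value is the lower-cased member name.
        return provider[len(LEGACY_PREFIX):].lower()
    return provider


print(normalize_provider("ProviderName.MCP"))  # mcp
print(normalize_provider("mcp"))  # mcp
```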
Zamil Majdy
88ee4b3a11 fix(mcp): Clean up old credentials stored with wrong provider string
Also search for credentials stored with "ProviderName.MCP" (from the
Python 3.13 str(Enum) bug) during both discover-tools auto-lookup and
OAuth callback cleanup. Remove the temporary debug endpoint.
2026-02-10 09:22:50 +04:00
Zamil Majdy
8eed4ad653 fix(mcp): Use ProviderName enum directly instead of str() for credential provider
Python 3.13 changed str(StrEnum) to return "ClassName.MEMBER" instead of
the plain value. This caused MCP credentials to be stored with provider
"ProviderName.MCP" instead of "mcp", leading to type/provider mismatch
errors during graph validation and execution.

Fix: Pass the enum directly to Pydantic (which extracts .value automatically),
matching the pattern used by all other OAuth handlers. Use .value explicitly
only in non-Pydantic contexts (string comparisons, API calls).
2026-02-10 09:06:29 +04:00
Zamil Majdy
7744b89e96 fix(mcp): Use ProviderName.MCP.value instead of str() for credential provider
Python 3.13 changed str(StrEnum) to return "ClassName.MEMBER" instead of
the plain value. This caused MCP credentials to be stored with provider
"ProviderName.MCP" instead of "mcp", leading to type/provider mismatch
errors during graph validation and execution.
2026-02-10 09:04:38 +04:00
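A small illustration of why 7744b89e96 reads `.value` instead of calling `str()`. It assumes the provider enum is a `str`-mixin `Enum`; the log does not show its definition, so the class below is a hypothetical stand-in with one member.

```python
from enum import Enum


class ProviderName(str, Enum):  # stand-in; the real enum has more members
    MCP = "mcp"


# For a str-mixin Enum, str() yields the qualified member name,
# while .value yields the plain string meant for storage.
print(str(ProviderName.MCP))
print(ProviderName.MCP.value)
```

Storing `.value` keeps the persisted provider string independent of how the enum renders itself.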
Zamil Majdy
4c02cd8f2f fix(mcp): Handle optional credentials in graph save and execution validation
- _on_graph_activate: Clear stale credential references for optional
  fields instead of blocking the save. Checks both node metadata
  (credentials_optional) and block schema (field not in required_fields).
- _validate_node_input_credentials: Use block schema's required_fields
  as fallback for credentials_optional check, so MCP blocks with
  default={} credentials are properly treated as optional.
- Set credentials_optional metadata on new MCP nodes in the frontend.
2026-02-10 08:53:15 +04:00
Zamil Majdy
909f313e1e fix(frontend/mcp): Filter credential auto-select by server URL discriminator
Prevent MCP credential cross-contamination where a credential for one
server (e.g. Sentry) fills credential fields for other servers (e.g.
Linear). Adds matchesDiscriminatorValues() to match credentials by host
against discriminator_values from the schema.
2026-02-10 07:39:44 +04:00
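`matchesDiscriminatorValues()` in 909f313e1e is a frontend function. The host comparison it performs could be sketched in Python as below. The signature, the "empty list matches anything" rule, and the Linear hostname are assumptions made for illustration.

```python
from urllib.parse import urlparse


def matches_discriminator_values(cred_host: str, discriminator_values: list[str]) -> bool:
    """Check whether a credential's host is among the schema's discriminator values."""
    if not discriminator_values:
        # Assumption: with no discriminator, any credential is acceptable.
        return True
    # Values may be full URLs or bare hosts; reduce both to a hostname.
    hosts = {urlparse(value).hostname or value for value in discriminator_values}
    return cred_host in hosts


sentry_only = ["https://mcp.sentry.dev/mcp"]
print(matches_discriminator_values("mcp.sentry.dev", sentry_only))
print(matches_discriminator_values("mcp.linear.app", sentry_only))
```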
Zamil Majdy
edd9a90903 refactor(mcp): Share OAuth popup logic and fix credential persistence
- Extract shared OAuth popup utility (oauth-popup.ts) used by both
  MCPToolDialog and useCredentialsInput, eliminating ~200 lines of
  duplicated BroadcastChannel/postMessage/localStorage listener code
- Add mcpOAuthCallback to credentials provider so MCP credentials
  are added to the in-memory cache after OAuth (fixes credentials not
  appearing in the credential picker after OAuth via MCPToolDialog)
- Fix oauth_test.py async fixtures missing loop_scope="session"
- Add MCP token refresh handler in creds_manager for dynamic endpoints
- Fix enum string representation in CredentialsFieldInfo.combine()
2026-02-10 07:17:05 +04:00
Zamil Majdy
ba031329e9 fix(mcp): Integrate MCPToolBlock with standard credentials system
- Replace manual credential_id field with CredentialsMetaInput pattern
- Fix credential deduplication so different MCP server URLs get separate
  credential entries in the task credentials panel
- Add descriptive display names (e.g. "MCP: mcp.sentry.dev")
- Fix OAuth popup callback by adding mcp_callback route to middleware
  exclusion list and adding localStorage polling fallback
- Fix SSRF test fixture to patch Requests constructor directly
- Add MCP server URL matching for credential auto-assignment
- Return CredentialsMetaResponse from MCP OAuth callback
- Support MCP-specific OAuth flow in frontend credential input
- Filter MCP credentials by server URL in frontend
- Add test coverage for credential deduplication logic
2026-02-09 20:59:37 +04:00
Zamil Majdy
6ab1a6867e fix(backend/mcp): Fix pyright errors and formatting in MCP block and tests
- Use isinstance(creds, APIKeyCredentials) instead of hasattr check
- Rewrite integration tests to use user_id param and mock _resolve_auth_token
- Fix f-string and line-length formatting issues in routes.py
2026-02-09 19:16:17 +04:00
Zamil Majdy
d9269310cc fix(frontend/mcp): Loop HTML tag stripping to prevent XSS bypass
The single-pass regex `/<[^>]+>/g` can be bypassed with nested tags
like `<scr<script>ipt>`. Loop until no more tags are found.
Note: React auto-escapes JSX so this is defense-in-depth.
2026-02-09 19:10:17 +04:00
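
The loop-until-stable idea from this commit can be sketched in Python. This is an illustration only: the actual fix is in the frontend TypeScript, which is not shown here, and the name `strip_tags` is hypothetical.

```python
import re

# Same pattern as quoted in the commit message, written as a Python regex.
TAG_RE = re.compile(r"<[^>]+>")


def strip_tags(text: str) -> str:
    """Remove tag-like spans repeatedly until the text stops changing."""
    previous = None
    while previous != text:
        previous = text
        text = TAG_RE.sub("", text)
    return text


cleaned = strip_tags("Hello <b>world</b><scr<script>ipt>")
# After the loop finishes, no span matching the pattern is left.
assert TAG_RE.search(cleaned) is None
```

Running to a fixed point means the result does not depend on how many nested layers the input contains; the loop ends as soon as one pass changes nothing.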
Zamil Majdy
fe70b6929f fix(mcp): Remove trusted_origins to prevent SSRF on user-provided URLs
User-provided MCP server URLs should not bypass SSRF IP-blocking
validation. Remove trusted_origins from all MCP code so that
private/internal IPs are properly blocked. Keep ThreadedResolver
in HostResolver fallback for DNS reliability in subprocess
environments.
2026-02-09 18:55:17 +04:00
Zamil Majdy
340520ba85 fix(mcp): OAuth discovery fallback, session ID, credential lookup, and DNS reliability
- Support MCP servers that serve OAuth metadata directly without
  protected-resource metadata (e.g. Linear) by falling back to
  discover_auth_server_metadata on the server's own origin
- Omit resource_url when no protected-resource metadata exists to
  avoid token audience mismatch errors (RFC 8707 resource is optional)
- Add Mcp-Session-Id header tracking per MCP Streamable HTTP spec
- Fall back to server_url credential lookup when credential_id is
  empty (pruneEmptyValues strips it from saved graphs)
- Use ThreadedResolver instead of c-ares AsyncResolver to avoid DNS
  failures in forked subprocess environments
- Simplify OAuth UX: single "Sign in & Connect" button on 401,
  remove sticky localStorage URL prefill
- Clean up stale MCP credentials on re-authentication
2026-02-09 18:51:53 +04:00
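
The discovery fallback described in the first two bullets can be sketched as a pure function. This is a minimal sketch, assuming the protected-resource metadata arrives as a plain dict; `plan_oauth_discovery` is a hypothetical helper, not a function from the PR.

```python
from typing import Any, Optional


def plan_oauth_discovery(
    protected_resource: Optional[dict[str, Any]],
    server_url: str,
) -> tuple[str, Optional[str]]:
    """Return (url to fetch auth-server metadata from, resource_url or None)."""
    if protected_resource and protected_resource.get("authorization_servers"):
        # Protected-resource metadata exists: use the first advertised
        # authorization server and the declared resource.
        auth_server_url = protected_resource["authorization_servers"][0]
        resource_url = protected_resource.get("resource", server_url)
        return auth_server_url, resource_url
    # Fallback: treat the MCP server as its own authorization server and
    # omit the resource so the server picks the token audience itself.
    return server_url, None


print(plan_oauth_discovery(None, "https://mcp.example.com/mcp"))
```

Returning `None` for the resource in the fallback branch mirrors the commit's point that the RFC 8707 resource parameter is optional.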
Zamil Majdy
6c2791b00b fix(frontend/mcp): Robust OAuth callback with localStorage fallback and popup close detection
BroadcastChannel can silently fail in some browser scenarios. Added:
- localStorage as third communication method in callback page
- storage event listener in dialog
- Popup close detection that checks localStorage directly
- Cleaned up auth-required box styling (gray instead of amber)
2026-02-09 17:52:02 +04:00
Zamil Majdy
7decc20a32 fix(backend/mcp): Auto-refresh expired OAuth tokens before MCP tool calls
_resolve_auth_token now checks token expiry and refreshes using
MCPOAuthHandler with metadata (token_url, client_id, client_secret)
stored during the OAuth callback flow.
2026-02-09 17:37:24 +04:00
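
A rough sketch of the expiry check that would precede a refresh. The body of `_resolve_auth_token` is not shown in this view, so the function name, the Unix-timestamp representation, and the 300-second leeway are all assumptions made for illustration.

```python
import time
from typing import Optional


def needs_refresh(
    expires_at: Optional[int],
    now: Optional[float] = None,
    leeway_seconds: int = 300,  # assumed safety margin, not taken from the PR
) -> bool:
    """Return True when a token is expired or about to expire."""
    if expires_at is None:
        # No recorded expiry: treat the token as non-expiring.
        return False
    current = time.time() if now is None else now
    return expires_at - current <= leeway_seconds


print(needs_refresh(1_000, now=900.0))
```

Refreshing slightly before the recorded expiry avoids sending a token that lapses while the tool call is in flight.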
Zamil Majdy
54375065d5 fix(mcp): Reshape execution input for MCPToolBlock like AgentExecutorBlock
The dynamic get_input_defaults returns only tool_arguments, so the
execution engine loses block-level fields like server_url. Reconstruct
the full Input from node.input_default and set tool_arguments from the
resolved dynamic input, matching the AgentExecutorBlock pattern.
2026-02-09 14:51:49 +04:00
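
The reshaping step can be illustrated with plain dicts. This is a minimal sketch under the assumption that `node.input_default` and the resolved dynamic input are both dictionaries; `reshape_mcp_input` is a hypothetical name.

```python
from typing import Any


def reshape_mcp_input(
    input_default: dict[str, Any],
    resolved_input: dict[str, Any],
) -> dict[str, Any]:
    """Rebuild the full block input and nest resolved values under tool_arguments."""
    full_input = dict(input_default)  # keeps block-level fields such as server_url
    full_input["tool_arguments"] = dict(resolved_input)
    return full_input


node_defaults = {"server_url": "https://mcp.example.com/mcp", "selected_tool": "search"}
print(reshape_mcp_input(node_defaults, {"query": "hello"}))
```

Copying both dictionaries keeps the node's stored defaults untouched between executions.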
Zamil Majdy
d62fde9445 fix(mcp): Use manual credential resolution instead of CredentialsField
The block framework's CredentialsField requires credentials to always be
present, which doesn't work for public MCP servers. Replace it with a
plain credential_id field and manual resolution from the credential store,
allowing both authenticated and public MCP servers to work seamlessly.
2026-02-09 14:41:14 +04:00
Zamil Majdy
03487f7b4d fix(frontend/mcp): Remove broken credentials widget and disable auto-connect
- Remove credentials field from MCP dynamic schema since auth is handled
  by the dialog's OAuth flow (the standard credentials widget doesn't
  support MCP as a provider and fails with 404)
- Simplify FormCreator MCP handling — all form fields are tool arguments
- Disable auto-connect on dialog open; pre-fill last URL instead so user
  can edit before connecting
2026-02-09 14:27:36 +04:00
Zamil Majdy
df41d02fce feat(frontend/mcp): Add MCP tool discovery UI, OAuth flow, and dynamic block schema
- Add MCPToolDialog with tool discovery, OAuth sign-in, and card-based tool selection
- Add OAuth callback route using BroadcastChannel API for popup communication
- Add API client methods for MCP discovery, OAuth login, and callback
- Register MCP API routes on the backend REST API
- Render dynamic input schema for MCP blocks (credentials + tool params)
  in both legacy and new builder CustomNode components
- Nest MCP tool argument values under tool_arguments in hardcodedValues
- Display tool name with server name prefix in block header
- Add backend route tests for discovery, OAuth login, and callback endpoints
2026-02-09 14:18:59 +04:00
Otto
7c9e47ba76 fix(mcp): Remove redundant exception handling and unnecessary str() cast
- client.py: except (ValueError, Exception) → except Exception
  (Exception already catches ValueError, so it's redundant)
- oauth.py: SecretStr(str(tokens[...])) → SecretStr(tokens[...])
  (refresh_token is already a string, no cast needed)
2026-02-09 08:40:58 +00:00
Zamil Majdy
e59e8dd9a9 fix(mcp): Skip e2e tests in CI unless --run-e2e is passed
E2e tests hit a real external MCP server and are inherently flaky.
Skip them by default, require --run-e2e flag to opt in.
2026-02-09 10:14:35 +04:00
Zamil Majdy
7aab2eb1d5 style(mcp): Apply linter formatting 2026-02-08 20:00:44 +04:00
Zamil Majdy
5ab28ccda2 fix(mcp): Fix pyright errors in test files
- Add type: ignore for aiohttp private _server.sockets access
- Add assert not None before accessing Optional refresh_token
2026-02-08 19:52:52 +04:00
Zamil Majdy
4fe0f05980 docs: Sync block documentation for MCP Tool block 2026-02-08 19:37:49 +04:00
Zamil Majdy
19b3373052 fix(mcp): Address PR review comments
- Fix get_missing_input/get_mismatch_error to validate tool_arguments
  dict instead of the entire BlockInput data (critical bug)
- Add type check for non-dict JSON-RPC error field in client.py
- Add try/catch for non-JSON responses in client.py
- Add raise_for_status and error payload checks to OAuth token requests
- Remove hardcoded placeholder skip-list from _extract_auth_token
- Fix server start timeout check in integration tests
- Remove unused MCPTool import, move execute_block_test to top-level
- Update tests to match fixed validation behavior
- Fix MCP_BLOCK_IMPLEMENTATION.md (remove duplicate section, local path)
- Soften PKCE comment in oauth.py
2026-02-08 19:34:28 +04:00
Zamil Majdy
7db3f12876 feat(backend/blocks/mcp): Add SSE support, OAuth auth, and e2e tests
- Handle text/event-stream (SSE) responses from real MCP servers
  (MCPClient._parse_sse_response) alongside plain JSON responses
- Add e2e tests against OpenAI docs MCP server (developers.openai.com/mcp)
  verifying SSE parsing, tool discovery, and tool execution work with a
  real production MCP server
- Support both api_key and oauth2 credential types on MCPToolBlock
  (MCPCredentials union type, _extract_auth_token helper)
- Add MCPOAuthHandler implementing BaseOAuthHandler with dynamic
  endpoints (authorize_url, token_url) for MCP OAuth 2.1 with PKCE
- Add OAuth metadata discovery to MCPClient (discover_auth,
  discover_auth_server_metadata) per RFC 9728 / RFC 8414
- 76 total tests: 46 unit, 11 OAuth, 14 integration, 5 e2e
2026-02-08 16:32:50 +04:00
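
For readers unfamiliar with `text/event-stream` bodies, the parsing step can be sketched as follows. This is not the PR's `MCPClient._parse_sse_response`, whose body is not shown here; `parse_sse_jsonrpc` is a hypothetical helper that assumes the server sends each JSON-RPC message in the `data:` field of an event.

```python
import json
from typing import Any, Optional


def parse_sse_jsonrpc(body: str) -> Optional[dict[str, Any]]:
    """Return the first JSON-RPC response (result or error) in an SSE body."""
    data_lines: list[str] = []
    # The extra "" flushes a final event that lacks a trailing blank line.
    for raw_line in body.split("\n") + [""]:
        line = raw_line.rstrip("\r")
        if line.startswith("data:"):
            value = line[len("data:"):]
            # A single leading space after the colon is not part of the data.
            data_lines.append(value[1:] if value.startswith(" ") else value)
        elif line == "" and data_lines:
            payload = "\n".join(data_lines)
            data_lines = []
            try:
                message = json.loads(payload)
            except json.JSONDecodeError:
                continue  # skip events that do not carry JSON
            if isinstance(message, dict) and ("result" in message or "error" in message):
                return message
    return None


body = 'event: message\ndata: {"jsonrpc": "2.0", "id": 1, "result": {"tools": []}}\n\n'
print(parse_sse_jsonrpc(body))
```

Skipping events without `result` or `error` lets notifications pass through without being mistaken for the response.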
Zamil Majdy
e9b996abb0 feat(backend/blocks): Add integration tests and trusted_origins support
- Add a test MCP server (test_server.py) for integration testing
- Add 14 integration tests that hit a real local MCP server over HTTP
- Add trusted_origins support to MCPClient for localhost/internal servers
- MCPToolBlock now trusts the user-configured server URL by default
- Add local conftest.py to avoid SpinTestServer overhead for MCP tests

Test results: 34 unit tests + 14 integration tests = 48 total, all passing
2026-02-08 13:49:44 +04:00
Zamil Majdy
9b972389a0 feat(backend/blocks): Add MCP (Model Context Protocol) tool block
Add a dynamic MCPToolBlock that can connect to any MCP server, discover
available tools, and execute them with dynamically generated input/output
schemas. This follows the same pattern as AgentExecutorBlock for dynamic
schema handling.

New files:
- backend/blocks/mcp/client.py — MCP Streamable HTTP client (JSON-RPC 2.0)
- backend/blocks/mcp/block.py — MCPToolBlock with dynamic schema
- backend/blocks/mcp/test_mcp.py — 34 tests covering client + block
- MCP_BLOCK_IMPLEMENTATION.md — Design document

Modified files:
- backend/integrations/providers.py — Add MCP provider name
2026-02-08 12:49:28 +04:00
230 changed files with 6994 additions and 2112 deletions

View File

@@ -10,7 +10,7 @@ from typing_extensions import TypedDict
 import backend.api.features.store.cache as store_cache
 import backend.api.features.store.model as store_model
-import backend.data.block
+import backend.blocks
 from backend.api.external.middleware import require_permission
 from backend.data import execution as execution_db
 from backend.data import graph as graph_db
@@ -67,7 +67,7 @@ async def get_user_info(
     dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
 )
 async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
-    blocks = [block() for block in backend.data.block.get_blocks().values()]
+    blocks = [block() for block in backend.blocks.get_blocks().values()]
     return [b.to_dict() for b in blocks if not b.disabled]
@@ -83,7 +83,7 @@ async def execute_graph_block(
         require_permission(APIKeyPermission.EXECUTE_BLOCK)
     ),
 ) -> CompletedBlockOutput:
-    obj = backend.data.block.get_block(block_id)
+    obj = backend.blocks.get_block(block_id)
     if not obj:
         raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
     if obj.disabled:

View File

@@ -10,10 +10,15 @@ import backend.api.features.library.db as library_db
 import backend.api.features.library.model as library_model
 import backend.api.features.store.db as store_db
 import backend.api.features.store.model as store_model
-import backend.data.block
 from backend.blocks import load_all_blocks
+from backend.blocks._base import (
+    AnyBlockSchema,
+    BlockCategory,
+    BlockInfo,
+    BlockSchema,
+    BlockType,
+)
 from backend.blocks.llm import LlmModel
-from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
 from backend.data.db import query_raw_with_schema
 from backend.integrations.providers import ProviderName
 from backend.util.cache import cached
@@ -22,7 +27,7 @@ from backend.util.models import Pagination
 from .model import (
     BlockCategoryResponse,
     BlockResponse,
-    BlockType,
+    BlockTypeFilter,
     CountResponse,
     FilterType,
     Provider,
@@ -88,7 +93,7 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
 def get_blocks(
     *,
     category: str | None = None,
-    type: BlockType | None = None,
+    type: BlockTypeFilter | None = None,
     provider: ProviderName | None = None,
     page: int = 1,
     page_size: int = 50,
@@ -669,9 +674,9 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
     for block_type in load_all_blocks().values():
         block: AnyBlockSchema = block_type()
         if block.disabled or block.block_type in (
-            backend.data.block.BlockType.INPUT,
-            backend.data.block.BlockType.OUTPUT,
-            backend.data.block.BlockType.AGENT,
+            BlockType.INPUT,
+            BlockType.OUTPUT,
+            BlockType.AGENT,
         ):
             continue
         # Find the execution count for this block

View File

@@ -4,7 +4,7 @@ from pydantic import BaseModel
 import backend.api.features.library.model as library_model
 import backend.api.features.store.model as store_model
-from backend.data.block import BlockInfo
+from backend.blocks._base import BlockInfo
 from backend.integrations.providers import ProviderName
 from backend.util.models import Pagination
@@ -15,7 +15,7 @@ FilterType = Literal[
     "my_agents",
 ]
-BlockType = Literal["all", "input", "action", "output"]
+BlockTypeFilter = Literal["all", "input", "action", "output"]
 class SearchEntry(BaseModel):

View File

@@ -88,7 +88,7 @@ async def get_block_categories(
 )
 async def get_blocks(
     category: Annotated[str | None, fastapi.Query()] = None,
-    type: Annotated[builder_model.BlockType | None, fastapi.Query()] = None,
+    type: Annotated[builder_model.BlockTypeFilter | None, fastapi.Query()] = None,
     provider: Annotated[ProviderName | None, fastapi.Query()] = None,
     page: Annotated[int, fastapi.Query()] = 1,
     page_size: Annotated[int, fastapi.Query()] = 50,

View File

@@ -13,7 +13,8 @@ from backend.api.features.chat.tools.models import (
     NoResultsResponse,
 )
 from backend.api.features.store.hybrid_search import unified_hybrid_search
-from backend.data.block import BlockType, get_block
+from backend.blocks import get_block
+from backend.blocks._base import BlockType
 logger = logging.getLogger(__name__)

View File

@@ -10,7 +10,7 @@ from backend.api.features.chat.tools.find_block import (
     FindBlockTool,
 )
 from backend.api.features.chat.tools.models import BlockListResponse
-from backend.data.block import BlockType
+from backend.blocks._base import BlockType
 from ._test_data import make_session

View File

@@ -12,7 +12,8 @@ from backend.api.features.chat.tools.find_block import (
     COPILOT_EXCLUDED_BLOCK_IDS,
     COPILOT_EXCLUDED_BLOCK_TYPES,
 )
-from backend.data.block import AnyBlockSchema, get_block
+from backend.blocks import get_block
+from backend.blocks._base import AnyBlockSchema
 from backend.data.execution import ExecutionContext
 from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
 from backend.data.workspace import get_or_create_workspace

View File

@@ -6,7 +6,7 @@ import pytest
 from backend.api.features.chat.tools.models import ErrorResponse
 from backend.api.features.chat.tools.run_block import RunBlockTool
-from backend.data.block import BlockType
+from backend.blocks._base import BlockType
 from ._test_data import make_session

View File

@@ -15,6 +15,7 @@ from backend.data.model import (
     OAuth2Credentials,
 )
 from backend.integrations.creds_manager import IntegrationCredentialsManager
+from backend.integrations.providers import ProviderName
 from backend.util.exceptions import NotFoundError
 logger = logging.getLogger(__name__)
@@ -359,7 +360,7 @@ async def match_user_credentials_to_graph(
         _,
         _,
     ) in aggregated_creds.items():
-        # Find first matching credential by provider, type, and scopes
+        # Find first matching credential by provider, type, scopes, and host/URL
         matching_cred = next(
             (
                 cred
@@ -374,6 +375,10 @@ async def match_user_credentials_to_graph(
                     cred.type != "host_scoped"
                     or _credential_is_for_host(cred, credential_requirements)
                 )
+                and (
+                    cred.provider != ProviderName.MCP
+                    or _credential_is_for_mcp_server(cred, credential_requirements)
+                )
             ),
             None,
         )
@@ -444,6 +449,22 @@ def _credential_is_for_host(
     return credential.matches_url(list(requirements.discriminator_values)[0])
+def _credential_is_for_mcp_server(
+    credential: Credentials,
+    requirements: CredentialsFieldInfo,
+) -> bool:
+    """Check if an MCP OAuth credential matches the required server URL."""
+    if not requirements.discriminator_values:
+        return True
+    server_url = (
+        credential.metadata.get("mcp_server_url")
+        if isinstance(credential, OAuth2Credentials)
+        else None
+    )
+    return server_url in requirements.discriminator_values if server_url else False
 async def check_user_has_required_credentials(
     user_id: str,
     required_credentials: list[CredentialsMetaInput],
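
The server-URL check added in `_credential_is_for_mcp_server` can be exercised without the backend models. This is a minimal sketch that takes the credential metadata as a plain dict; `credential_matches_server` is a hypothetical stand-in, not the function from the diff.

```python
from typing import Optional


def credential_matches_server(
    cred_metadata: Optional[dict[str, str]],
    discriminator_values: set[str],
) -> bool:
    """Return True when the credential may be used for the required MCP server."""
    if not discriminator_values:
        # The field does not pin a server URL, so any MCP credential fits.
        return True
    server_url = (cred_metadata or {}).get("mcp_server_url")
    return bool(server_url) and server_url in discriminator_values


sentry = {"mcp_server_url": "https://mcp.sentry.dev/mcp"}
print(credential_matches_server(sentry, {"https://mcp.linear.app/mcp"}))
```

Matching on the stored server URL is what keeps a credential issued for one server from being auto-assigned to a block that targets another.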

View File

@@ -1,7 +1,7 @@
 import asyncio
 import logging
 from datetime import datetime, timedelta, timezone
-from typing import TYPE_CHECKING, Annotated, List, Literal
+from typing import TYPE_CHECKING, Annotated, Any, List, Literal
 from autogpt_libs.auth import get_user_id
 from fastapi import (
@@ -14,7 +14,7 @@ from fastapi import (
     Security,
     status,
 )
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import BaseModel, Field, SecretStr, model_validator
 from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
 from backend.api.features.library.db import set_preset_webhook, update_preset
@@ -39,7 +39,11 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
 from backend.data.user import get_user_integrations
 from backend.executor.utils import add_graph_execution
 from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
-from backend.integrations.creds_manager import IntegrationCredentialsManager
+from backend.integrations.credentials_store import provider_matches
+from backend.integrations.creds_manager import (
+    IntegrationCredentialsManager,
+    create_mcp_oauth_handler,
+)
 from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
 from backend.integrations.providers import ProviderName
 from backend.integrations.webhooks import get_webhook_manager
@@ -102,9 +106,37 @@ class CredentialsMetaResponse(BaseModel):
     scopes: list[str] | None
     username: str | None
     host: str | None = Field(
-        default=None, description="Host pattern for host-scoped credentials"
+        default=None,
+        description="Host pattern for host-scoped or MCP server URL for MCP credentials",
     )
+    @model_validator(mode="before")
+    @classmethod
+    def _normalize_provider(cls, data: Any) -> Any:
+        """Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug."""
+        if isinstance(data, dict):
+            prov = data.get("provider", "")
+            if isinstance(prov, str) and prov.startswith("ProviderName."):
+                member = prov.removeprefix("ProviderName.")
+                try:
+                    data = {**data, "provider": ProviderName[member].value}
+                except KeyError:
+                    pass
+        return data
+    @staticmethod
+    def get_host(cred: Credentials) -> str | None:
+        """Extract host from credential: HostScoped host or MCP server URL."""
+        if isinstance(cred, HostScopedCredentials):
+            return cred.host
+        if isinstance(cred, OAuth2Credentials) and cred.provider in (
+            ProviderName.MCP,
+            ProviderName.MCP.value,
+            "ProviderName.MCP",
+        ):
+            return (cred.metadata or {}).get("mcp_server_url")
+        return None
 @router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
 async def callback(
@@ -179,9 +211,7 @@ async def callback(
         title=credentials.title,
         scopes=credentials.scopes,
         username=credentials.username,
-        host=(
-            credentials.host if isinstance(credentials, HostScopedCredentials) else None
-        ),
+        host=(CredentialsMetaResponse.get_host(credentials)),
     )
@@ -199,7 +229,7 @@ async def list_credentials(
             title=cred.title,
             scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
             username=cred.username if isinstance(cred, OAuth2Credentials) else None,
-            host=cred.host if isinstance(cred, HostScopedCredentials) else None,
+            host=CredentialsMetaResponse.get_host(cred),
         )
         for cred in credentials
     ]
@@ -222,7 +252,7 @@ async def list_credentials_by_provider(
             title=cred.title,
             scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
             username=cred.username if isinstance(cred, OAuth2Credentials) else None,
-            host=cred.host if isinstance(cred, HostScopedCredentials) else None,
+            host=CredentialsMetaResponse.get_host(cred),
         )
         for cred in credentials
     ]
@@ -322,6 +352,10 @@ async def delete_credentials(
     tokens_revoked = None
     if isinstance(creds, OAuth2Credentials):
+        if provider_matches(provider.value, ProviderName.MCP.value):
+            # MCP uses dynamic per-server OAuth — create handler from metadata
+            handler = create_mcp_oauth_handler(creds)
+        else:
-        handler = _get_provider_oauth_handler(request, provider)
+            handler = _get_provider_oauth_handler(request, provider)
         tokens_revoked = await handler.revoke_tokens(creds)
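
The provider normalization added to `CredentialsMetaResponse` in this file can be tried in isolation with a stand-in enum. This is a sketch only: the `ProviderName` below is a two-member substitute for the real enum, and `normalize_provider` is a hypothetical free function rather than the Pydantic validator from the diff.

```python
from enum import Enum


class ProviderName(str, Enum):
    # Stand-in members for illustration; the real enum is defined in the backend.
    MCP = "mcp"
    GITHUB = "github"


def normalize_provider(raw: str) -> str:
    """Map 'ProviderName.MCP' style strings back to the enum value."""
    prefix = "ProviderName."
    if raw.startswith(prefix):
        member = raw[len(prefix):]
        try:
            return ProviderName[member].value
        except KeyError:
            return raw  # unknown member: leave the input untouched
    return raw


print(normalize_provider("ProviderName.MCP"))
```

Leaving unknown members unchanged follows the diff, where the `KeyError` is swallowed and the data passes through as is.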

View File

@@ -12,12 +12,11 @@ import backend.api.features.store.image_gen as store_image_gen
 import backend.api.features.store.media as store_media
 import backend.data.graph as graph_db
 import backend.data.integrations as integrations_db
-from backend.data.block import BlockInput
 from backend.data.db import transaction
 from backend.data.execution import get_graph_execution
 from backend.data.graph import GraphSettings
 from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
-from backend.data.model import CredentialsMetaInput
+from backend.data.model import CredentialsMetaInput, GraphInput
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.integrations.webhooks.graph_lifecycle_hooks import (
     on_graph_activate,
@@ -1130,7 +1129,7 @@ async def create_preset_from_graph_execution(
 async def update_preset(
     user_id: str,
     preset_id: str,
-    inputs: Optional[BlockInput] = None,
+    inputs: Optional[GraphInput] = None,
     credentials: Optional[dict[str, CredentialsMetaInput]] = None,
     name: Optional[str] = None,
     description: Optional[str] = None,

View File

@@ -6,9 +6,12 @@ import prisma.enums
 import prisma.models
 import pydantic
-from backend.data.block import BlockInput
 from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
-from backend.data.model import CredentialsMetaInput, is_credentials_field_name
+from backend.data.model import (
+    CredentialsMetaInput,
+    GraphInput,
+    is_credentials_field_name,
+)
 from backend.util.json import loads as json_loads
 from backend.util.models import Pagination
@@ -323,7 +326,7 @@ class LibraryAgentPresetCreatable(pydantic.BaseModel):
     graph_id: str
     graph_version: int
-    inputs: BlockInput
+    inputs: GraphInput
     credentials: dict[str, CredentialsMetaInput]
     name: str
@@ -352,7 +355,7 @@ class LibraryAgentPresetUpdatable(pydantic.BaseModel):
     Request model used when updating a preset for a library agent.
     """
-    inputs: Optional[BlockInput] = None
+    inputs: Optional[GraphInput] = None
     credentials: Optional[dict[str, CredentialsMetaInput]] = None
     name: Optional[str] = None
@@ -395,7 +398,7 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
                 "Webhook must be included in AgentPreset query when webhookId is set"
             )
-        input_data: BlockInput = {}
+        input_data: GraphInput = {}
         input_credentials: dict[str, CredentialsMetaInput] = {}
         for preset_input in preset.InputPresets:

View File

@@ -0,0 +1,404 @@
"""
MCP (Model Context Protocol) API routes.
Provides endpoints for MCP tool discovery and OAuth authentication so the
frontend can list available tools on an MCP server before placing a block.
"""
import logging
from typing import Annotated, Any
from urllib.parse import urlparse
import fastapi
from autogpt_libs.auth import get_user_id
from fastapi import Security
from pydantic import BaseModel, Field
from backend.api.features.integrations.router import CredentialsMetaResponse
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
router = fastapi.APIRouter(tags=["mcp"])
creds_manager = IntegrationCredentialsManager()
# ====================== Tool Discovery ====================== #
class DiscoverToolsRequest(BaseModel):
"""Request to discover tools on an MCP server."""
server_url: str = Field(description="URL of the MCP server")
auth_token: str | None = Field(
default=None,
description="Optional Bearer token for authenticated MCP servers",
)
class MCPToolResponse(BaseModel):
"""A single MCP tool returned by discovery."""
name: str
description: str
input_schema: dict[str, Any]
class DiscoverToolsResponse(BaseModel):
"""Response containing the list of tools available on an MCP server."""
tools: list[MCPToolResponse]
server_name: str | None = None
protocol_version: str | None = None
@router.post(
"/discover-tools",
summary="Discover available tools on an MCP server",
response_model=DiscoverToolsResponse,
)
async def discover_tools(
request: DiscoverToolsRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> DiscoverToolsResponse:
"""
Connect to an MCP server and return its available tools.
If the user has a stored MCP credential for this server URL, it will be
used automatically — no need to pass an explicit auth token.
"""
auth_token = request.auth_token
# Auto-use stored MCP credential when no explicit token is provided.
if not auth_token:
mcp_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
# Find the freshest credential for this server URL
best_cred: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and (cred.metadata or {}).get("mcp_server_url") == request.server_url
):
if best_cred is None or (
(cred.access_token_expires_at or 0)
> (best_cred.access_token_expires_at or 0)
):
best_cred = cred
if best_cred:
# Refresh the token if expired before using it
best_cred = await creds_manager.refresh_if_needed(user_id, best_cred)
logger.info(
f"Using MCP credential {best_cred.id} for {request.server_url}, "
f"expires_at={best_cred.access_token_expires_at}"
)
auth_token = best_cred.access_token.get_secret_value()
client = MCPClient(request.server_url, auth_token=auth_token)
try:
init_result = await client.initialize()
tools = await client.list_tools()
except HTTPClientError as e:
if e.status_code in (401, 403):
raise fastapi.HTTPException(
status_code=401,
detail="This MCP server requires authentication. "
"Please provide a valid auth token.",
)
raise fastapi.HTTPException(status_code=502, detail=str(e))
except MCPClientError as e:
raise fastapi.HTTPException(status_code=502, detail=str(e))
except Exception as e:
raise fastapi.HTTPException(
status_code=502,
detail=f"Failed to connect to MCP server: {e}",
)
return DiscoverToolsResponse(
tools=[
MCPToolResponse(
name=t.name,
description=t.description,
input_schema=t.input_schema,
)
for t in tools
],
server_name=(
init_result.get("serverInfo", {}).get("name")
or urlparse(request.server_url).hostname
or "MCP"
),
protocol_version=init_result.get("protocolVersion"),
)
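Review note: the credential lookup in the handler above amounts to a filter on `mcp_server_url` followed by a max over the expiry timestamp. A minimal sketch of that reading, assuming a hypothetical `StoredCred` stand-in for `OAuth2Credentials` and a hypothetical helper name `pick_freshest` (neither appears in this PR):

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class StoredCred:
    """Hypothetical stand-in for OAuth2Credentials; only the fields used here."""

    id: str
    access_token_expires_at: Optional[int] = None
    metadata: dict = field(default_factory=dict)


def pick_freshest(creds: list[StoredCred], server_url: str) -> Optional[StoredCred]:
    """Return the credential for server_url with the latest expiry, or None."""
    matching = [
        c for c in creds if (c.metadata or {}).get("mcp_server_url") == server_url
    ]
    if not matching:
        return None
    # A missing expiry is treated as 0, mirroring the `or 0` in the handler,
    # so a credential without an expiry loses to any credential that has one.
    return max(matching, key=lambda c: c.access_token_expires_at or 0)
```

One consequence worth checking during review: a non-expiring token (expiry `None`) is ranked below every token that carries an expiry.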
# ======================== OAuth Flow ======================== #
class MCPOAuthLoginRequest(BaseModel):
"""Request to start an OAuth flow for an MCP server."""
server_url: str = Field(description="URL of the MCP server that requires OAuth")
class MCPOAuthLoginResponse(BaseModel):
"""Response with the OAuth login URL for the user to authenticate."""
login_url: str
state_token: str
@router.post(
"/oauth/login",
summary="Initiate OAuth login for an MCP server",
)
async def mcp_oauth_login(
request: MCPOAuthLoginRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> MCPOAuthLoginResponse:
"""
Discover OAuth metadata from the MCP server and return a login URL.
1. Discovers the protected-resource metadata (RFC 9728)
2. Fetches the authorization server metadata (RFC 8414)
3. Performs Dynamic Client Registration (RFC 7591) if available
4. Returns the authorization URL for the frontend to open in a popup
"""
client = MCPClient(request.server_url)
# Step 1: Discover protected-resource metadata (RFC 9728)
protected_resource = await client.discover_auth()
metadata: dict[str, Any] | None = None
if protected_resource and protected_resource.get("authorization_servers"):
auth_server_url = protected_resource["authorization_servers"][0]
resource_url = protected_resource.get("resource", request.server_url)
# Step 2a: Discover auth-server metadata (RFC 8414)
metadata = await client.discover_auth_server_metadata(auth_server_url)
else:
# Fallback: Some MCP servers (e.g. Linear) are their own auth server
# and serve OAuth metadata directly without protected-resource metadata.
# Don't assume a resource_url — omitting it lets the auth server choose
# the correct audience for the token (RFC 8707 resource is optional).
resource_url = None
metadata = await client.discover_auth_server_metadata(request.server_url)
if (
not metadata
or "authorization_endpoint" not in metadata
or "token_endpoint" not in metadata
):
raise fastapi.HTTPException(
status_code=400,
detail="This MCP server does not advertise OAuth support. "
"You may need to provide an auth token manually.",
)
authorize_url = metadata["authorization_endpoint"]
token_url = metadata["token_endpoint"]
registration_endpoint = metadata.get("registration_endpoint")
revoke_url = metadata.get("revocation_endpoint")
# Step 3: Dynamic Client Registration (RFC 7591) if available
frontend_base_url = settings.config.frontend_base_url
if not frontend_base_url:
raise fastapi.HTTPException(
status_code=500,
detail="Frontend base URL is not configured.",
)
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
client_id = ""
client_secret = ""
if registration_endpoint:
reg_result = await _register_mcp_client(
registration_endpoint, redirect_uri, request.server_url
)
if reg_result:
client_id = reg_result.get("client_id", "")
client_secret = reg_result.get("client_secret", "")
if not client_id:
client_id = "autogpt-platform"
# Step 4: Store state token with OAuth metadata for the callback
scopes = (protected_resource or {}).get("scopes_supported") or metadata.get(
"scopes_supported", []
)
state_token, code_challenge = await creds_manager.store.store_state_token(
user_id,
ProviderName.MCP.value,
scopes,
state_metadata={
"authorize_url": authorize_url,
"token_url": token_url,
"revoke_url": revoke_url,
"resource_url": resource_url,
"server_url": request.server_url,
"client_id": client_id,
"client_secret": client_secret,
},
)
# Step 5: Build and return the login URL
handler = MCPOAuthHandler(
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
authorize_url=authorize_url,
token_url=token_url,
resource_url=resource_url,
)
login_url = handler.get_login_url(
scopes, state_token, code_challenge=code_challenge
)
return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token)
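Review note: `MCPOAuthHandler.get_login_url` is not shown in this part of the diff. As a rough sketch of what an authorization URL with PKCE and an optional RFC 8707 `resource` parameter usually looks like, assuming the S256 challenge method and a hypothetical helper name `build_login_url` (this is not the handler's actual implementation):

```python
from typing import Optional
from urllib.parse import urlencode


def build_login_url(
    authorize_url: str,
    client_id: str,
    redirect_uri: str,
    scopes: list[str],
    state: str,
    code_challenge: str,
    resource_url: Optional[str] = None,
) -> str:
    """Assemble an OAuth 2.0 authorization-code URL with PKCE parameters."""
    params = {
        "response_type": "code",
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "state": state,
        "code_challenge": code_challenge,
        "code_challenge_method": "S256",  # assumption: S256 rather than "plain"
    }
    if scopes:
        params["scope"] = " ".join(scopes)  # scopes are space-delimited
    if resource_url:
        # Omitted when None, matching the fallback branch above that leaves
        # the audience choice to the authorization server.
        params["resource"] = resource_url
    separator = "&" if "?" in authorize_url else "?"
    return f"{authorize_url}{separator}{urlencode(params)}"
```

Leaving `resource` out when it is `None` keeps the fallback path (servers that act as their own auth server) from sending an empty audience.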
class MCPOAuthCallbackRequest(BaseModel):
"""Request to exchange an OAuth code for tokens."""
code: str = Field(description="Authorization code from OAuth callback")
state_token: str = Field(description="State token for CSRF verification")
class MCPOAuthCallbackResponse(BaseModel):
"""Response after successfully storing OAuth credentials."""
credential_id: str
@router.post(
"/oauth/callback",
summary="Exchange OAuth code for MCP tokens",
)
async def mcp_oauth_callback(
request: MCPOAuthCallbackRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> CredentialsMetaResponse:
"""
Exchange the authorization code for tokens and store the credential.
The frontend calls this after receiving the OAuth code from the popup.
On success, subsequent ``/discover-tools`` calls for the same server URL
will automatically use the stored credential.
"""
valid_state = await creds_manager.store.verify_state_token(
user_id, request.state_token, ProviderName.MCP.value
)
if not valid_state:
raise fastapi.HTTPException(
status_code=400,
detail="Invalid or expired state token.",
)
meta = valid_state.state_metadata
frontend_base_url = settings.config.frontend_base_url
if not frontend_base_url:
raise fastapi.HTTPException(
status_code=500,
detail="Frontend base URL is not configured.",
)
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
handler = MCPOAuthHandler(
client_id=meta["client_id"],
client_secret=meta.get("client_secret", ""),
redirect_uri=redirect_uri,
authorize_url=meta["authorize_url"],
token_url=meta["token_url"],
revoke_url=meta.get("revoke_url"),
resource_url=meta.get("resource_url"),
)
try:
credentials = await handler.exchange_code_for_tokens(
request.code, valid_state.scopes, valid_state.code_verifier
)
except Exception as e:
raise fastapi.HTTPException(
status_code=400,
detail=f"OAuth token exchange failed: {e}",
)
# Enrich credential metadata for future lookup and token refresh
if credentials.metadata is None:
credentials.metadata = {}
credentials.metadata["mcp_server_url"] = meta["server_url"]
credentials.metadata["mcp_client_id"] = meta["client_id"]
credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "")
credentials.metadata["mcp_token_url"] = meta["token_url"]
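# The login route stores "resource_url" even when it is None, so a .get() default
# would never apply; fall back to "" explicitly.
credentials.metadata["mcp_resource_url"] = meta.get("resource_url") or ""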
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
credentials.title = f"MCP: {hostname}"
# Remove old MCP credentials for the same server to prevent stale token buildup.
try:
old_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
for old in old_creds:
if (
isinstance(old, OAuth2Credentials)
and (old.metadata or {}).get("mcp_server_url") == meta["server_url"]
):
await creds_manager.store.delete_creds_by_id(user_id, old.id)
logger.info(
f"Removed old MCP credential {old.id} for {meta['server_url']}"
)
except Exception:
logger.debug("Could not clean up old MCP credentials", exc_info=True)
await creds_manager.create(user_id, credentials)
return CredentialsMetaResponse(
id=credentials.id,
provider=credentials.provider,
type=credentials.type,
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=credentials.metadata.get("mcp_server_url"),
)
# ======================== Helpers ======================== #
async def _register_mcp_client(
registration_endpoint: str,
redirect_uri: str,
server_url: str,
) -> dict[str, Any] | None:
"""Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server."""
try:
response = await Requests(raise_for_status=True).post(
registration_endpoint,
json={
"client_name": "AutoGPT Platform",
"redirect_uris": [redirect_uri],
"grant_types": ["authorization_code"],
"response_types": ["code"],
"token_endpoint_auth_method": "client_secret_post",
},
)
data = response.json()
if isinstance(data, dict) and "client_id" in data:
return data
return None
except Exception as e:
logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
return None
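Review note: the login route combines the result of `_register_mcp_client` with a hard-coded public client ID. A small sketch of that fallback as a pure function, assuming a hypothetical helper name `resolve_client` (the route inlines this logic):

```python
from typing import Any, Optional

DEFAULT_CLIENT_ID = "autogpt-platform"  # same literal the route falls back to


def resolve_client(reg_result: Optional[dict[str, Any]]) -> tuple[str, str]:
    """Return (client_id, client_secret), using the default public client
    when registration was skipped, failed, or returned no client_id."""
    client_id = ""
    client_secret = ""
    if reg_result:
        client_id = reg_result.get("client_id", "")
        client_secret = reg_result.get("client_secret", "")
    if not client_id:
        client_id = DEFAULT_CLIENT_ID
    return client_id, client_secret
```

Because registration errors are swallowed and logged as warnings, a failed registration is indistinguishable here from a server without a registration endpoint; both end up on the default client ID.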


@@ -0,0 +1,436 @@
"""Tests for MCP API routes.
Uses httpx.AsyncClient with ASGITransport instead of fastapi.testclient.TestClient
to avoid creating blocking portals that can corrupt pytest-asyncio's session event loop.
"""
from unittest.mock import AsyncMock, patch
import fastapi
import httpx
import pytest
import pytest_asyncio
from autogpt_libs.auth import get_user_id
from backend.api.features.mcp.routes import router
from backend.blocks.mcp.client import MCPClientError, MCPTool
from backend.util.request import HTTPClientError
app = fastapi.FastAPI()
app.include_router(router)
app.dependency_overrides[get_user_id] = lambda: "test-user-id"
@pytest_asyncio.fixture(scope="module")
async def client():
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as c:
yield c
class TestDiscoverTools:
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_success(self, client):
mock_tools = [
MCPTool(
name="get_weather",
description="Get weather for a city",
input_schema={
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
),
MCPTool(
name="add_numbers",
description="Add two numbers",
input_schema={
"type": "object",
"properties": {
"a": {"type": "number"},
"b": {"type": "number"},
},
},
),
]
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={
"protocolVersion": "2025-03-26",
"serverInfo": {"name": "test-server"},
}
)
instance.list_tools = AsyncMock(return_value=mock_tools)
response = await client.post(
"/discover-tools",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
data = response.json()
assert len(data["tools"]) == 2
assert data["tools"][0]["name"] == "get_weather"
assert data["tools"][1]["name"] == "add_numbers"
assert data["server_name"] == "test-server"
assert data["protocol_version"] == "2025-03-26"
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_with_auth_token(self, client):
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
)
instance.list_tools = AsyncMock(return_value=[])
response = await client.post(
"/discover-tools",
json={
"server_url": "https://mcp.example.com/mcp",
"auth_token": "my-secret-token",
},
)
assert response.status_code == 200
MockClient.assert_called_once_with(
"https://mcp.example.com/mcp",
auth_token="my-secret-token",
)
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_auto_uses_stored_credential(self, client):
"""When no explicit token is given, stored MCP credentials are used."""
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
stored_cred = OAuth2Credentials(
provider="mcp",
title="MCP: example.com",
access_token=SecretStr("stored-token-123"),
refresh_token=None,
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=[],
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
)
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred)
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
)
instance.list_tools = AsyncMock(return_value=[])
response = await client.post(
"/discover-tools",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
MockClient.assert_called_once_with(
"https://mcp.example.com/mcp",
auth_token="stored-token-123",
)
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_mcp_error(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=MCPClientError("Connection refused")
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://bad-server.example.com/mcp"},
)
assert response.status_code == 502
assert "Connection refused" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_generic_error(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
response = await client.post(
"/discover-tools",
json={"server_url": "https://timeout.example.com/mcp"},
)
assert response.status_code == 502
assert "Failed to connect" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_auth_required(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://auth-server.example.com/mcp"},
)
assert response.status_code == 401
assert "requires authentication" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_forbidden(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://auth-server.example.com/mcp"},
)
assert response.status_code == 401
assert "requires authentication" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_missing_url(self, client):
response = await client.post("/discover-tools", json={})
assert response.status_code == 422
class TestOAuthLogin:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_success(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch(
"backend.api.features.mcp.routes._register_mcp_client"
) as mock_register,
):
instance = MockClient.return_value
instance.discover_auth = AsyncMock(
return_value={
"authorization_servers": ["https://auth.sentry.io"],
"resource": "https://mcp.sentry.dev/mcp",
"scopes_supported": ["openid"],
}
)
instance.discover_auth_server_metadata = AsyncMock(
return_value={
"authorization_endpoint": "https://auth.sentry.io/authorize",
"token_endpoint": "https://auth.sentry.io/token",
"registration_endpoint": "https://auth.sentry.io/register",
}
)
mock_register.return_value = {
"client_id": "registered-client-id",
"client_secret": "registered-secret",
}
mock_cm.store.store_state_token = AsyncMock(
return_value=("state-token-123", "code-challenge-abc")
)
mock_settings.config.frontend_base_url = "http://localhost:3000"
response = await client.post(
"/oauth/login",
json={"server_url": "https://mcp.sentry.dev/mcp"},
)
assert response.status_code == 200
data = response.json()
assert "login_url" in data
assert data["state_token"] == "state-token-123"
assert "auth.sentry.io/authorize" in data["login_url"]
assert "registered-client-id" in data["login_url"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_no_oauth_support(self, client):
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
instance = MockClient.return_value
instance.discover_auth = AsyncMock(return_value=None)
instance.discover_auth_server_metadata = AsyncMock(return_value=None)
response = await client.post(
"/oauth/login",
json={"server_url": "https://simple-server.example.com/mcp"},
)
assert response.status_code == 400
assert "does not advertise OAuth" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_fallback_to_public_client(self, client):
"""When DCR is unavailable, falls back to default public client ID."""
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
):
instance = MockClient.return_value
instance.discover_auth = AsyncMock(
return_value={
"authorization_servers": ["https://auth.example.com"],
"resource": "https://mcp.example.com/mcp",
}
)
instance.discover_auth_server_metadata = AsyncMock(
return_value={
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
# No registration_endpoint
}
)
mock_cm.store.store_state_token = AsyncMock(
return_value=("state-abc", "challenge-xyz")
)
mock_settings.config.frontend_base_url = "http://localhost:3000"
response = await client.post(
"/oauth/login",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
data = response.json()
assert "autogpt-platform" in data["login_url"]
class TestOAuthCallback:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_success(self, client):
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
mock_creds = OAuth2Credentials(
provider="mcp",
title=None,
access_token=SecretStr("access-token-xyz"),
refresh_token=None,
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=[],
metadata={
"mcp_token_url": "https://auth.sentry.io/token",
"mcp_resource_url": "https://mcp.sentry.dev/mcp",
},
)
with (
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
):
mock_settings.config.frontend_base_url = "http://localhost:3000"
# Mock state verification
mock_state = AsyncMock()
mock_state.state_metadata = {
"authorize_url": "https://auth.sentry.io/authorize",
"token_url": "https://auth.sentry.io/token",
"client_id": "test-client-id",
"client_secret": "test-secret",
"server_url": "https://mcp.sentry.dev/mcp",
}
mock_state.scopes = ["openid"]
mock_state.code_verifier = "verifier-123"
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
mock_cm.create = AsyncMock()
handler_instance = MockHandler.return_value
handler_instance.exchange_code_for_tokens = AsyncMock(
return_value=mock_creds
)
# Mock old credential cleanup
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
response = await client.post(
"/oauth/callback",
json={"code": "auth-code-abc", "state_token": "state-token-123"},
)
assert response.status_code == 200
data = response.json()
assert "id" in data
assert data["provider"] == "mcp"
assert data["type"] == "oauth2"
mock_cm.create.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_invalid_state(self, client):
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
mock_cm.store.verify_state_token = AsyncMock(return_value=None)
response = await client.post(
"/oauth/callback",
json={"code": "auth-code", "state_token": "bad-state"},
)
assert response.status_code == 400
assert "Invalid or expired" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_token_exchange_fails(self, client):
with (
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
):
mock_settings.config.frontend_base_url = "http://localhost:3000"
mock_state = AsyncMock()
mock_state.state_metadata = {
"authorize_url": "https://auth.example.com/authorize",
"token_url": "https://auth.example.com/token",
"client_id": "cid",
"server_url": "https://mcp.example.com/mcp",
}
mock_state.scopes = []
mock_state.code_verifier = "v"
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
handler_instance = MockHandler.return_value
handler_instance.exchange_code_for_tokens = AsyncMock(
side_effect=RuntimeError("Token exchange failed")
)
response = await client.post(
"/oauth/callback",
json={"code": "bad-code", "state_token": "state"},
)
assert response.status_code == 400
assert "token exchange failed" in response.json()["detail"].lower()
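Review note: the tests above rely on patching a client class and replacing its async methods with `AsyncMock`. A stripped-down, standard-library-only sketch of the same pattern, where `discover` is a hypothetical stand-in for the route body and takes the client factory as an argument instead of being patched:

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


async def discover(client_factory, server_url, auth_token=None):
    """Hypothetical stand-in for the /discover-tools handler body."""
    client = client_factory(server_url, auth_token=auth_token)
    init_result = await client.initialize()
    tools = await client.list_tools()
    return {
        "server_name": init_result.get("serverInfo", {}).get("name") or "MCP",
        "tools": tools,
    }


def run_example():
    factory = MagicMock()  # plays the role of the patched MCPClient class
    instance = factory.return_value  # what MCPClient(...) returns
    instance.initialize = AsyncMock(
        return_value={"serverInfo": {"name": "test-server"}}
    )
    instance.list_tools = AsyncMock(return_value=[])

    result = asyncio.run(
        discover(factory, "https://mcp.example.com/mcp", auth_token="t")
    )

    # Same style of assertion the tests use on the patched class.
    factory.assert_called_once_with("https://mcp.example.com/mcp", auth_token="t")
    instance.initialize.assert_awaited_once()
    return result
```

`MagicMock` is used for the class and `AsyncMock` only for the awaited methods, so calling the factory stays synchronous, as a real constructor call would be.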


@@ -5,8 +5,8 @@ from typing import Optional
 import aiohttp
 from fastapi import HTTPException
+from backend.blocks import get_block
 from backend.data import graph as graph_db
-from backend.data.block import get_block
 from backend.util.settings import Settings
 from .models import ApiResponse, ChatRequest, GraphData


@@ -152,7 +152,7 @@ class BlockHandler(ContentHandler):
 async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
 """Fetch blocks without embeddings."""
-from backend.data.block import get_blocks
+from backend.blocks import get_blocks
 # Get all available blocks
 all_blocks = get_blocks()
@@ -249,7 +249,7 @@ class BlockHandler(ContentHandler):
 async def get_stats(self) -> dict[str, int]:
 """Get statistics about block embedding coverage."""
-from backend.data.block import get_blocks
+from backend.blocks import get_blocks
 all_blocks = get_blocks()


@@ -93,7 +93,7 @@ async def test_block_handler_get_missing_items(mocker):
 mock_existing = []
 with patch(
-"backend.data.block.get_blocks",
+"backend.blocks.get_blocks",
 return_value=mock_blocks,
 ):
 with patch(
@@ -135,7 +135,7 @@ async def test_block_handler_get_stats(mocker):
 mock_embedded = [{"count": 2}]
 with patch(
-"backend.data.block.get_blocks",
+"backend.blocks.get_blocks",
 return_value=mock_blocks,
 ):
 with patch(
@@ -327,7 +327,7 @@ async def test_block_handler_handles_missing_attributes():
 mock_blocks = {"block-minimal": mock_block_class}
 with patch(
-"backend.data.block.get_blocks",
+"backend.blocks.get_blocks",
 return_value=mock_blocks,
 ):
 with patch(
@@ -360,7 +360,7 @@ async def test_block_handler_skips_failed_blocks():
 mock_blocks = {"good-block": good_block, "bad-block": bad_block}
 with patch(
-"backend.data.block.get_blocks",
+"backend.blocks.get_blocks",
 return_value=mock_blocks,
 ):
 with patch(


@@ -662,7 +662,7 @@ async def cleanup_orphaned_embeddings() -> dict[str, Any]:
 )
 current_ids = {row["id"] for row in valid_agents}
 elif content_type == ContentType.BLOCK:
-from backend.data.block import get_blocks
+from backend.blocks import get_blocks
 current_ids = set(get_blocks().keys())
 elif content_type == ContentType.DOCUMENTATION:


@@ -7,15 +7,6 @@ from replicate.client import Client as ReplicateClient
 from replicate.exceptions import ReplicateError
 from replicate.helpers import FileOutput
-from backend.blocks.ideogram import (
-AspectRatio,
-ColorPalettePreset,
-IdeogramModelBlock,
-IdeogramModelName,
-MagicPromptOption,
-StyleType,
-UpscaleOption,
-)
 from backend.data.graph import GraphBaseMeta
 from backend.data.model import CredentialsMetaInput, ProviderName
 from backend.integrations.credentials_store import ideogram_credentials
@@ -50,6 +41,16 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
 if not ideogram_credentials.api_key:
 raise ValueError("Missing Ideogram API key")
+from backend.blocks.ideogram import (
+AspectRatio,
+ColorPalettePreset,
+IdeogramModelBlock,
+IdeogramModelName,
+MagicPromptOption,
+StyleType,
+UpscaleOption,
+)
 name = graph.name
 description = f"{name} ({graph.description})" if graph.description else name


@@ -40,10 +40,11 @@ from backend.api.model import (
 UpdateTimezoneRequest,
 UploadFileResponse,
 )
+from backend.blocks import get_block, get_blocks
 from backend.data import execution as execution_db
 from backend.data import graph as graph_db
 from backend.data.auth import api_key as api_key_db
-from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
+from backend.data.block import BlockInput, CompletedBlockOutput
 from backend.data.credit import (
 AutoTopUpConfig,
 RefundRequest,


@@ -26,6 +26,7 @@ import backend.api.features.executions.review.routes
 import backend.api.features.library.db
 import backend.api.features.library.model
 import backend.api.features.library.routes
+import backend.api.features.mcp.routes as mcp_routes
 import backend.api.features.oauth
 import backend.api.features.otto.routes
 import backend.api.features.postmark.postmark
@@ -343,6 +344,11 @@ app.include_router(
 tags=["workspace"],
 prefix="/api/workspace",
 )
+app.include_router(
+mcp_routes.router,
+tags=["v2", "mcp"],
+prefix="/api/mcp",
+)
 app.include_router(
 backend.api.features.oauth.router,
 tags=["oauth"],


@@ -3,22 +3,19 @@ import logging
 import os
 import re
 from pathlib import Path
-from typing import TYPE_CHECKING, TypeVar
+from typing import Sequence, Type, TypeVar
+from backend.blocks._base import AnyBlockSchema, BlockType
 from backend.util.cache import cached
 logger = logging.getLogger(__name__)
-if TYPE_CHECKING:
-from backend.data.block import Block
 T = TypeVar("T")
 @cached(ttl_seconds=3600)
-def load_all_blocks() -> dict[str, type["Block"]]:
-from backend.data.block import Block
+def load_all_blocks() -> dict[str, type["AnyBlockSchema"]]:
+from backend.blocks._base import Block
 from backend.util.settings import Config
 # Check if example blocks should be loaded from settings
@@ -50,8 +47,8 @@ def load_all_blocks() -> dict[str, type["Block"]]:
 importlib.import_module(f".{module}", package=__name__)
 # Load all Block instances from the available modules
-available_blocks: dict[str, type["Block"]] = {}
-for block_cls in all_subclasses(Block):
+available_blocks: dict[str, type["AnyBlockSchema"]] = {}
+for block_cls in _all_subclasses(Block):
 class_name = block_cls.__name__
 if class_name.endswith("Base"):
@@ -64,7 +61,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
 "please name the class with 'Base' at the end"
 )
-block = block_cls.create()
+block = block_cls() # pyright: ignore[reportAbstractUsage]
 if not isinstance(block.id, str) or len(block.id) != 36:
 raise ValueError(
@@ -105,7 +102,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
 available_blocks[block.id] = block_cls
 # Filter out blocks with incomplete auth configs, e.g. missing OAuth server secrets
-from backend.data.block import is_block_auth_configured
+from ._utils import is_block_auth_configured
 filtered_blocks = {}
 for block_id, block_cls in available_blocks.items():
@@ -115,11 +112,48 @@ def load_all_blocks() -> dict[str, type["Block"]]:
 return filtered_blocks
-__all__ = ["load_all_blocks"]
-def all_subclasses(cls: type[T]) -> list[type[T]]:
+def _all_subclasses(cls: type[T]) -> list[type[T]]:
 subclasses = cls.__subclasses__()
 for subclass in subclasses:
-subclasses += all_subclasses(subclass)
+subclasses += _all_subclasses(subclass)
 return subclasses
+# ============== Block access helper functions ============== #
+def get_blocks() -> dict[str, Type["AnyBlockSchema"]]:
+return load_all_blocks()
+# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
+def get_block(block_id: str) -> "AnyBlockSchema | None":
+cls = get_blocks().get(block_id)
+return cls() if cls else None
+@cached(ttl_seconds=3600)
+def get_webhook_block_ids() -> Sequence[str]:
+return [
+id
+for id, B in get_blocks().items()
+if B().block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
+]
+@cached(ttl_seconds=3600)
+def get_io_block_ids() -> Sequence[str]:
+return [
+id
+for id, B in get_blocks().items()
+if B().block_type in (BlockType.INPUT, BlockType.OUTPUT)
+]
+@cached(ttl_seconds=3600)
+def get_human_in_the_loop_block_ids() -> Sequence[str]:
+return [
+id
+for id, B in get_blocks().items()
+if B().block_type == BlockType.HUMAN_IN_THE_LOOP
+]
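Review note: `_all_subclasses` extends the list it is iterating over, which can list a class more than once when the hierarchy is three or more levels deep. The sketch below is one possible variation that collects each descendant once and then filters by block type in the style of the `get_*_block_ids` helpers; the class names and `all_subclasses_once` are hypothetical and exist only for illustration:

```python
from enum import Enum
from typing import TypeVar

T = TypeVar("T")


class BlockType(Enum):
    STANDARD = "Standard"
    WEBHOOK = "Webhook"


# Hypothetical four-level hierarchy used only to exercise the traversal.
class Block:
    block_type = BlockType.STANDARD


class WebhookBase(Block):
    block_type = BlockType.WEBHOOK


class GithubWebhook(WebhookBase):
    pass


class GithubPRWebhook(GithubWebhook):
    pass


def all_subclasses_once(cls: type[T]) -> list[type[T]]:
    """Collect every descendant of cls exactly once, without recursion."""
    found: list[type[T]] = []
    stack = list(cls.__subclasses__())
    while stack:
        sub = stack.pop()
        if sub not in found:
            found.append(sub)
            stack.extend(sub.__subclasses__())
    return found


def names_of_type(root: type, wanted: BlockType) -> list[str]:
    """Instantiate each descendant and keep those whose block_type matches."""
    return sorted(
        c.__name__ for c in all_subclasses_once(root) if c().block_type is wanted
    )
```

Whether the duplicate entries matter in `load_all_blocks` depends on code outside the hunks shown here, so this is a point to check rather than a confirmed defect.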


@@ -0,0 +1,740 @@
import inspect
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Generic,
Optional,
Type,
TypeAlias,
TypeVar,
cast,
get_origin,
)
import jsonref
import jsonschema
from pydantic import BaseModel
from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry
from backend.data.model import (
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
SchemaField,
is_credentials_field_name,
)
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.exceptions import (
BlockError,
BlockExecutionError,
BlockInputError,
BlockOutputError,
BlockUnknownError,
)
from backend.util.settings import Config
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from backend.data.execution import ExecutionContext
from backend.data.model import ContributorDetails, NodeExecutionStats
from ..data.graph import Link
app_config = Config()
BlockTestOutput = BlockOutputEntry | tuple[str, Callable[[Any], bool]]
class BlockType(Enum):
STANDARD = "Standard"
INPUT = "Input"
OUTPUT = "Output"
NOTE = "Note"
WEBHOOK = "Webhook"
WEBHOOK_MANUAL = "Webhook (manual)"
AGENT = "Agent"
AI = "AI"
AYRSHARE = "Ayrshare"
HUMAN_IN_THE_LOOP = "Human In The Loop"
MCP_TOOL = "MCP Tool"
class BlockCategory(Enum):
AI = "Block that leverages AI to perform a task."
SOCIAL = "Block that interacts with social media platforms."
TEXT = "Block that processes text data."
SEARCH = "Block that searches or extracts information from the internet."
BASIC = "Block that performs basic operations."
INPUT = "Block that interacts with input of the graph."
OUTPUT = "Block that interacts with output of the graph."
LOGIC = "Programming logic to control the flow of your agent"
COMMUNICATION = "Block that interacts with communication platforms."
DEVELOPER_TOOLS = "Developer tools such as GitHub blocks."
DATA = "Block that interacts with structured data."
HARDWARE = "Block that interacts with hardware."
AGENT = "Block that interacts with other agents."
CRM = "Block that interacts with CRM services."
SAFETY = (
"Block that provides AI safety mechanisms such as detecting harmful content"
)
PRODUCTIVITY = "Block that helps with productivity"
ISSUE_TRACKING = "Block that helps with issue tracking"
MULTIMEDIA = "Block that interacts with multimedia content"
MARKETING = "Block that helps with marketing"
def dict(self) -> dict[str, str]:
return {"category": self.name, "description": self.value}
class BlockCostType(str, Enum):
RUN = "run" # cost X credits per run
BYTE = "byte" # cost X credits per byte
SECOND = "second" # cost X credits per second
class BlockCost(BaseModel):
cost_amount: int
cost_filter: BlockInput
cost_type: BlockCostType
def __init__(
self,
cost_amount: int,
cost_type: BlockCostType = BlockCostType.RUN,
cost_filter: Optional[BlockInput] = None,
**data: Any,
) -> None:
super().__init__(
cost_amount=cost_amount,
cost_filter=cost_filter or {},
cost_type=cost_type,
**data,
)
class BlockInfo(BaseModel):
id: str
name: str
inputSchema: dict[str, Any]
outputSchema: dict[str, Any]
costs: list[BlockCost]
description: str
categories: list[dict[str, str]]
contributors: list[dict[str, Any]]
staticOutput: bool
uiType: str
class BlockSchema(BaseModel):
cached_jsonschema: ClassVar[dict[str, Any]]
@classmethod
def jsonschema(cls) -> dict[str, Any]:
if cls.cached_jsonschema:
return cls.cached_jsonschema
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
def ref_to_dict(obj):
if isinstance(obj, dict):
# OpenAPI <3.1 does not support sibling fields next to a $ref key,
# so the schema sometimes has an "allOf"/"anyOf"/"oneOf" with a single item.
keys = {"allOf", "anyOf", "oneOf"}
one_key = next((k for k in keys if k in obj and len(obj[k]) == 1), None)
if one_key:
obj.update(obj[one_key][0])
return {
key: ref_to_dict(value)
for key, value in obj.items()
if not key.startswith("$") and key != one_key
}
elif isinstance(obj, list):
return [ref_to_dict(item) for item in obj]
return obj
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
return cls.cached_jsonschema
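
Review note: the flattening step in `ref_to_dict` can be tried in isolation. The sketch below is a standalone approximation, assuming the `$ref` entries have already been resolved; `flatten_single_combinators` and the sample schema are hypothetical. Unlike the method above, it copies each dict instead of updating it in place.

```python
from typing import Any


def flatten_single_combinators(obj: Any) -> Any:
    """Inline one-item allOf/anyOf/oneOf wrappers and drop '$'-prefixed keys."""
    if isinstance(obj, dict):
        wrappers = ("allOf", "anyOf", "oneOf")
        one_key = next(
            (k for k in wrappers if k in obj and len(obj[k]) == 1), None
        )
        merged = dict(obj)  # shallow copy so the caller's schema stays untouched
        if one_key:
            # Pull the single wrapped schema up next to its sibling fields.
            merged.update(merged[one_key][0])
        return {
            key: flatten_single_combinators(value)
            for key, value in merged.items()
            if not key.startswith("$") and key != one_key
        }
    if isinstance(obj, list):
        return [flatten_single_combinators(item) for item in obj]
    return obj


# Hypothetical schema fragment with a single-item "allOf" wrapper.
schema = {
    "properties": {
        "mode": {
            "allOf": [{"type": "string", "enum": ["a", "b"]}],
            "description": "Operating mode",
            "$comment": "dropped by the flattening step",
        }
    }
}
flat = flatten_single_combinators(schema)
```

After flattening, `mode` carries `type`, `enum`, and `description` side by side, which is the shape that consumers without sibling-`$ref` support can read.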
@classmethod
def validate_data(cls, data: BlockInput) -> str | None:
return json.validate_with_jsonschema(
schema=cls.jsonschema(),
data={k: v for k, v in data.items() if v is not None},
)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
return cls.validate_data(data)
@classmethod
def get_field_schema(cls, field_name: str) -> dict[str, Any]:
model_schema = cls.jsonschema().get("properties", {})
if not model_schema:
raise ValueError(f"Invalid model schema {cls}")
property_schema = model_schema.get(field_name)
if not property_schema:
raise ValueError(f"Invalid property name {field_name}")
return property_schema
@classmethod
def validate_field(cls, field_name: str, data: BlockInput) -> str | None:
"""
Validate the data against a specific property (one of the input/output names).
Returns the validation error message if the data does not match the schema.
"""
try:
property_schema = cls.get_field_schema(field_name)
jsonschema.validate(json.to_dict(data), property_schema)
return None
except jsonschema.ValidationError as e:
return str(e)
@classmethod
def get_fields(cls) -> set[str]:
return set(cls.model_fields.keys())
@classmethod
def get_required_fields(cls) -> set[str]:
return {
field
for field, field_info in cls.model_fields.items()
if field_info.is_required()
}
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
"""Validates the schema definition. Rules:
- Fields with annotation `CredentialsMetaInput` MUST be
named `credentials` or `*_credentials`
- Fields named `credentials` or `*_credentials` MUST be
of type `CredentialsMetaInput`
"""
super().__pydantic_init_subclass__(**kwargs)
# Reset cached JSON schema to prevent inheriting it from parent class
cls.cached_jsonschema = {}
credentials_fields = cls.get_credentials_fields()
for field_name in cls.get_fields():
if is_credentials_field_name(field_name):
if field_name not in credentials_fields:
raise TypeError(
f"Credentials field '{field_name}' on {cls.__qualname__} "
f"is not of type {CredentialsMetaInput.__name__}"
)
CredentialsMetaInput.validate_credentials_field_schema(
cls.get_field_schema(field_name), field_name
)
elif field_name in credentials_fields:
raise KeyError(
f"Credentials field '{field_name}' on {cls.__qualname__} "
"has invalid name: must be 'credentials' or *_credentials"
)
@classmethod
def get_credentials_fields(cls) -> dict[str, type[CredentialsMetaInput]]:
return {
field_name: info.annotation
for field_name, info in cls.model_fields.items()
if (
inspect.isclass(info.annotation)
and issubclass(
get_origin(info.annotation) or info.annotation,
CredentialsMetaInput,
)
)
}
@classmethod
def get_auto_credentials_fields(cls) -> dict[str, dict[str, Any]]:
"""
Get fields that have auto_credentials metadata (e.g., GoogleDriveFileInput).
Returns a dict mapping kwarg_name -> {field_name, auto_credentials_config}
Raises:
ValueError: If multiple fields have the same kwarg_name, as this would
cause silent overwriting and only the last field would be processed.
"""
result: dict[str, dict[str, Any]] = {}
schema = cls.jsonschema()
properties = schema.get("properties", {})
for field_name, field_schema in properties.items():
auto_creds = field_schema.get("auto_credentials")
if auto_creds:
kwarg_name = auto_creds.get("kwarg_name", "credentials")
if kwarg_name in result:
raise ValueError(
f"Duplicate auto_credentials kwarg_name '{kwarg_name}' "
f"in fields '{result[kwarg_name]['field_name']}' and "
f"'{field_name}' on {cls.__qualname__}"
)
result[kwarg_name] = {
"field_name": field_name,
"config": auto_creds,
}
return result
@classmethod
def get_credentials_fields_info(cls) -> dict[str, CredentialsFieldInfo]:
result = {}
# Regular credentials fields
for field_name in cls.get_credentials_fields().keys():
result[field_name] = CredentialsFieldInfo.model_validate(
cls.get_field_schema(field_name), by_alias=True
)
# Auto-generated credentials fields (from GoogleDriveFileInput etc.)
for kwarg_name, info in cls.get_auto_credentials_fields().items():
config = info["config"]
# Build a schema-like dict that CredentialsFieldInfo can parse
auto_schema = {
"credentials_provider": [config.get("provider", "google")],
"credentials_types": [config.get("type", "oauth2")],
"credentials_scopes": config.get("scopes"),
}
result[kwarg_name] = CredentialsFieldInfo.model_validate(
auto_schema, by_alias=True
)
return result
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
return data # Return as is, by default.
@classmethod
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
input_fields_from_nodes = {link.sink_name for link in links}
return input_fields_from_nodes - set(data)
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
return cls.get_required_fields() - set(data)
class BlockSchemaInput(BlockSchema):
"""
Base schema class for block inputs.
All block input schemas should extend this class for consistency.
"""
pass
class BlockSchemaOutput(BlockSchema):
"""
Base schema class for block outputs that includes a standard error field.
All block output schemas should extend this class to ensure consistent error handling.
"""
error: str = SchemaField(
description="Error message if the operation failed", default=""
)
BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchemaInput)
BlockSchemaOutputType = TypeVar("BlockSchemaOutputType", bound=BlockSchemaOutput)
class EmptyInputSchema(BlockSchemaInput):
pass
class EmptyOutputSchema(BlockSchemaOutput):
pass
# For backward compatibility - will be deprecated
EmptySchema = EmptyOutputSchema
# --8<-- [start:BlockWebhookConfig]
class BlockManualWebhookConfig(BaseModel):
"""
Configuration model for webhook-triggered blocks on which
the user has to manually set up the webhook at the provider.
"""
provider: ProviderName
"""The service provider that the webhook connects to"""
webhook_type: str
"""
Identifier for the webhook type. E.g. GitHub has repo and organization level hooks.
Only for use in the corresponding `WebhooksManager`.
"""
event_filter_input: str = ""
"""
Name of the block's event filter input.
Leave empty if the corresponding webhook doesn't have distinct event/payload types.
"""
event_format: str = "{event}"
"""
Template string for the event(s) that a block instance subscribes to.
Applied individually to each event selected in the event filter input.
Example: `"pull_request.{event}"` -> `"pull_request.opened"`
"""
class BlockWebhookConfig(BlockManualWebhookConfig):
"""
Configuration model for webhook-triggered blocks for which
the webhook can be automatically set up through the provider's API.
"""
resource_format: str
"""
Template string for the resource that a block instance subscribes to.
Fields will be filled from the block's inputs (except `payload`).
Example: `f"{repo}/pull_requests"` (note: not how it's actually implemented)
Only for use in the corresponding `WebhooksManager`.
"""
# --8<-- [end:BlockWebhookConfig]
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
def __init__(
self,
id: str = "",
description: str = "",
contributors: list["ContributorDetails"] = [],
categories: set[BlockCategory] | None = None,
input_schema: Type[BlockSchemaInputType] = EmptyInputSchema,
output_schema: Type[BlockSchemaOutputType] = EmptyOutputSchema,
test_input: BlockInput | list[BlockInput] | None = None,
test_output: BlockTestOutput | list[BlockTestOutput] | None = None,
test_mock: dict[str, Any] | None = None,
test_credentials: Optional[Credentials | dict[str, Credentials]] = None,
disabled: bool = False,
static_output: bool = False,
block_type: BlockType = BlockType.STANDARD,
webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
is_sensitive_action: bool = False,
):
"""
Initialize the block with the given schema.
Args:
id: The unique identifier for the block. This value is persisted in the
DB, so it should be unique and constant across application runs.
Use the UUID format for the ID.
description: The description of the block, explaining what the block does.
contributors: The list of contributors who contributed to the block.
input_schema: The schema, defined as a Pydantic model, for the input data.
output_schema: The schema, defined as a Pydantic model, for the output data.
test_input: The list or single sample input data for the block, for testing.
test_output: The list or single expected output if the test_input is run.
test_mock: function names on the block implementation to mock on test run.
disabled: If the block is disabled, it will not be available for execution.
static_output: Whether the output links of the block are static by default.
"""
from backend.data.model import NodeExecutionStats
self.id = id
self.input_schema = input_schema
self.output_schema = output_schema
self.test_input = test_input
self.test_output = test_output
self.test_mock = test_mock
self.test_credentials = test_credentials
self.description = description
self.categories = categories or set()
self.contributors = contributors or set()
self.disabled = disabled
self.static_output = static_output
self.block_type = block_type
self.webhook_config = webhook_config
self.is_sensitive_action = is_sensitive_action
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
if self.webhook_config:
if isinstance(self.webhook_config, BlockWebhookConfig):
# Enforce presence of credentials field on auto-setup webhook blocks
if not (cred_fields := self.input_schema.get_credentials_fields()):
raise TypeError(
"credentials field is required on auto-setup webhook blocks"
)
# Disallow multiple credentials inputs on webhook blocks
elif len(cred_fields) > 1:
raise ValueError(
"Multiple credentials inputs not supported on webhook blocks"
)
self.block_type = BlockType.WEBHOOK
else:
self.block_type = BlockType.WEBHOOK_MANUAL
# Enforce shape of webhook event filter, if present
if self.webhook_config.event_filter_input:
event_filter_field = self.input_schema.model_fields[
self.webhook_config.event_filter_input
]
if not (
isinstance(event_filter_field.annotation, type)
and issubclass(event_filter_field.annotation, BaseModel)
and all(
field.annotation is bool
for field in event_filter_field.annotation.model_fields.values()
)
):
raise NotImplementedError(
f"{self.name} has an invalid webhook event selector: "
"field must be a BaseModel and all its fields must be boolean"
)
# Enforce presence of 'payload' input
if "payload" not in self.input_schema.model_fields:
raise TypeError(
f"{self.name} is webhook-triggered but has no 'payload' input"
)
# Disable webhook-triggered block if webhook functionality not available
if not app_config.platform_base_url:
self.disabled = True
@abstractmethod
async def run(self, input_data: BlockSchemaInputType, **kwargs) -> BlockOutput:
"""
Run the block with the given input data.
Args:
input_data: The input data with the structure of input_schema.
Kwargs: As of 14/02/2025, these include
graph_id: The ID of the graph.
node_id: The ID of the node.
graph_exec_id: The ID of the graph execution.
node_exec_id: The ID of the node execution.
user_id: The ID of the user.
Returns:
An async generator that yields (output_name, output_data).
output_name: One of the output names defined in the Block's output_schema.
output_data: The data for the output_name, matching the defined schema.
"""
# --- satisfy the type checker, never executed -------------
if False: # noqa: SIM115
yield "name", "value" # pyright: ignore[reportMissingYield]
raise NotImplementedError(f"{self.name} does not implement the run method.")
async def run_once(
self, input_data: BlockSchemaInputType, output: str, **kwargs
) -> Any:
async for item in self.run(input_data, **kwargs):
name, data = item
if name == output:
return data
raise ValueError(f"{self.name} did not produce any output for {output}")
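
Review note: `run_once` returns the first value yielded under the requested output name and raises if none appears. A minimal standalone sketch of that pattern follows; `demo_run` and `first_output` are hypothetical stand-ins, not the block API.

```python
import asyncio
from typing import Any, AsyncGenerator


async def demo_run() -> AsyncGenerator[tuple[str, Any], None]:
    """Hypothetical stand-in for a block's run(): yields (name, data) pairs."""
    yield "status", "started"
    yield "result", 42
    yield "result", 43  # not consumed when the caller stops at the first match


async def first_output(
    outputs: AsyncGenerator[tuple[str, Any], None], wanted: str
) -> Any:
    """Return the data of the first pair whose name equals `wanted`."""
    async for name, data in outputs:
        if name == wanted:
            return data
    raise ValueError(f"no output produced for {wanted}")


value = asyncio.run(first_output(demo_run(), "result"))
```

Only the first matching pair is consumed, so any later outputs with the same name are ignored by this helper.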
def merge_stats(self, stats: "NodeExecutionStats") -> "NodeExecutionStats":
self.execution_stats += stats
return self.execution_stats
@property
def name(self):
return self.__class__.__name__
def to_dict(self):
return {
"id": self.id,
"name": self.name,
"inputSchema": self.input_schema.jsonschema(),
"outputSchema": self.output_schema.jsonschema(),
"description": self.description,
"categories": [category.dict() for category in self.categories],
"contributors": [
contributor.model_dump() for contributor in self.contributors
],
"staticOutput": self.static_output,
"uiType": self.block_type.value,
}
def get_info(self) -> BlockInfo:
from backend.data.credit import get_block_cost
return BlockInfo(
id=self.id,
name=self.name,
inputSchema=self.input_schema.jsonschema(),
outputSchema=self.output_schema.jsonschema(),
costs=get_block_cost(self),
description=self.description,
categories=[category.dict() for category in self.categories],
contributors=[
contributor.model_dump() for contributor in self.contributors
],
staticOutput=self.static_output,
uiType=self.block_type.value,
)
async def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
try:
async for output_name, output_data in self._execute(input_data, **kwargs):
yield output_name, output_data
except Exception as ex:
if isinstance(ex, BlockError):
raise ex
else:
raise (
BlockExecutionError
if isinstance(ex, ValueError)
else BlockUnknownError
)(
message=str(ex),
block_name=self.name,
block_id=self.id,
) from ex
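
Review note: the `except` branch in `execute` passes `BlockError` instances through, maps `ValueError` to `BlockExecutionError`, and maps everything else to `BlockUnknownError`. The sketch below mirrors that mapping with stand-in exception classes; the class hierarchy and the `wrap_error` helper are assumptions for illustration, not the definitions in `backend.util.exceptions`.

```python
class BlockError(Exception):
    """Stand-in base error that carries the block's name and ID."""

    def __init__(self, message: str, block_name: str, block_id: str) -> None:
        super().__init__(message)
        self.block_name = block_name
        self.block_id = block_id


class BlockExecutionError(BlockError):
    """Stand-in for errors caused by bad input or a failed run."""


class BlockUnknownError(BlockError):
    """Stand-in for unexpected errors."""


def wrap_error(ex: Exception, block_name: str, block_id: str) -> BlockError:
    """Map an arbitrary exception to a block error, as the except branch does."""
    if isinstance(ex, BlockError):
        return ex  # already a block error: pass it through unchanged
    error_cls = (
        BlockExecutionError if isinstance(ex, ValueError) else BlockUnknownError
    )
    return error_cls(str(ex), block_name, block_id)


wrapped = wrap_error(ValueError("bad input"), "DemoBlock", "demo-id")
```

In the method itself the new error is raised `from ex`, which keeps the original traceback attached as `__cause__`.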
async def is_block_exec_need_review(
self,
input_data: BlockInput,
*,
user_id: str,
node_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: "ExecutionContext",
**kwargs,
) -> tuple[bool, BlockInput]:
"""
Check if this block execution needs human review and handle the review process.
Returns:
Tuple of (should_pause, input_data_to_use)
- should_pause: True if execution should be paused for review
- input_data_to_use: The input data to use (may be modified by reviewer)
"""
if not (
self.is_sensitive_action and execution_context.sensitive_action_safe_mode
):
return False, input_data
from backend.blocks.helpers.review import HITLReviewHelper
# Handle the review request and get decision
decision = await HITLReviewHelper.handle_review_decision(
input_data=input_data,
user_id=user_id,
node_id=node_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
block_name=self.name,
editable=True,
)
if decision is None:
# We're awaiting review - pause execution
return True, input_data
if not decision.should_proceed:
# Review was rejected, raise an error to stop execution
raise BlockExecutionError(
message=f"Block execution rejected by reviewer: {decision.message}",
block_name=self.name,
block_id=self.id,
)
# Review was approved - use the potentially modified data
# ReviewResult.data must be a dict for block inputs
reviewed_data = decision.review_result.data
if not isinstance(reviewed_data, dict):
raise BlockExecutionError(
message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}",
block_name=self.name,
block_id=self.id,
)
return False, reviewed_data
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
# Check for review requirement only if running within a graph execution context
# Direct block execution (e.g., from chat) skips the review process
has_graph_context = all(
key in kwargs
for key in (
"node_exec_id",
"graph_exec_id",
"graph_id",
"execution_context",
)
)
if has_graph_context:
should_pause, input_data = await self.is_block_exec_need_review(
input_data, **kwargs
)
if should_pause:
return
# Validate the input data (original or reviewer-modified) once
if error := self.input_schema.validate_data(input_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
block_name=self.name,
block_id=self.id,
)
# Use the validated input data
async for output_name, output_data in self.run(
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
**kwargs,
):
if output_name == "error":
raise BlockExecutionError(
message=output_data, block_name=self.name, block_id=self.id
)
if self.block_type == BlockType.STANDARD and (
error := self.output_schema.validate_field(output_name, output_data)
):
raise BlockOutputError(
message=f"Block produced an invalid output data: {error}",
block_name=self.name,
block_id=self.id,
)
yield output_name, output_data
def is_triggered_by_event_type(
self, trigger_config: dict[str, Any], event_type: str
) -> bool:
if not self.webhook_config:
raise TypeError("This method can't be used on non-trigger blocks")
if not self.webhook_config.event_filter_input:
return True
event_filter = trigger_config.get(self.webhook_config.event_filter_input)
if not event_filter:
raise ValueError("Event filter is not configured on trigger")
return event_type in [
self.webhook_config.event_format.format(event=k)
for k in event_filter
if event_filter[k] is True
]
# Type alias for any block with standard input/output schemas
AnyBlockSchema: TypeAlias = Block[BlockSchemaInput, BlockSchemaOutput]
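
Review note: `is_triggered_by_event_type` expands each enabled key of the event filter through `event_format` and checks membership. A small standalone sketch of that matching, using a hypothetical `matches_event` helper and sample filter:

```python
from typing import Any


def matches_event(
    event_filter: dict[str, Any], event_format: str, event_type: str
) -> bool:
    """Check whether event_type is one of the enabled, formatted filter events."""
    enabled = [
        event_format.format(event=name)
        for name, is_on in event_filter.items()
        if is_on is True  # only keys explicitly set to True count as subscribed
    ]
    return event_type in enabled


# Hypothetical filter: subscribe to "opened" and "edited", but not "closed".
pr_filter = {"opened": True, "closed": False, "edited": True}
opened_matches = matches_event(pr_filter, "pull_request.{event}", "pull_request.opened")
closed_matches = matches_event(pr_filter, "pull_request.{event}", "pull_request.closed")
```

The `is True` comparison means truthy non-boolean values (for example `1` or `"yes"`) do not enable an event, which matches the constructor check that every field of the event filter model is a `bool`.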


@@ -0,0 +1,122 @@
import logging
import os
from backend.integrations.providers import ProviderName
from ._base import AnyBlockSchema
logger = logging.getLogger(__name__)
def is_block_auth_configured(
block_cls: type[AnyBlockSchema],
) -> bool:
"""
Check if a block has a valid authentication method configured at runtime.
For example, if a block is an OAuth-only block and its env vars are not set,
do not show it in the UI.
"""
from backend.sdk.registry import AutoRegistry
# Create an instance to access input_schema
try:
block = block_cls()
except Exception as e:
# If we can't create a block instance, assume it's not OAuth-only
logger.error(f"Error creating block instance for {block_cls.__name__}: {e}")
return True
logger.debug(
f"Checking if block {block_cls.__name__} has a valid provider configured"
)
# Get all credential inputs from input schema
credential_inputs = block.input_schema.get_credentials_fields_info()
required_inputs = block.input_schema.get_required_fields()
if not credential_inputs:
logger.debug(
f"Block {block_cls.__name__} has no credential inputs - Treating as valid"
)
return True
# Check credential inputs
if len(required_inputs.intersection(credential_inputs.keys())) == 0:
logger.debug(
f"Block {block_cls.__name__} has only optional credential inputs"
" - will work without credentials configured"
)
# Check if the credential inputs for this block are correctly configured
for field_name, field_info in credential_inputs.items():
provider_names = field_info.provider
if not provider_names:
logger.warning(
f"Block {block_cls.__name__} "
f"has credential input '{field_name}' with no provider options"
" - Disabling"
)
return False
# If a field has multiple possible providers, each one needs to be usable to
# prevent breaking the UX
for _provider_name in provider_names:
provider_name = _provider_name.value
if provider_name in ProviderName.__members__.values():
logger.debug(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"provider '{provider_name}' is part of the legacy provider system"
" - Treating as valid"
)
break
provider = AutoRegistry.get_provider(provider_name)
if not provider:
logger.warning(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"refers to unknown provider '{provider_name}' - Disabling"
)
return False
# Check the provider's supported auth types
if field_info.supported_types != provider.supported_auth_types:
logger.warning(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"has mismatched supported auth types (field <> Provider): "
f"{field_info.supported_types} != {provider.supported_auth_types}"
)
if not (supported_auth_types := provider.supported_auth_types):
# No auth methods have been configured for this provider
logger.warning(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"provider '{provider_name}' "
"has no authentication methods configured - Disabling"
)
return False
# Check if provider supports OAuth
if "oauth2" in supported_auth_types:
# Check if OAuth environment variables are set
if (oauth_config := provider.oauth_config) and bool(
os.getenv(oauth_config.client_id_env_var)
and os.getenv(oauth_config.client_secret_env_var)
):
logger.debug(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"provider '{provider_name}' is configured for OAuth"
)
else:
logger.error(
f"Block {block_cls.__name__} credential input '{field_name}' "
f"provider '{provider_name}' "
"is missing OAuth client ID or secret - Disabling"
)
return False
logger.debug(
f"Block {block_cls.__name__} credential input '{field_name}' is valid; "
f"supported credential types: {', '.join(field_info.supported_types)}"
)
return True
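
Review note: the OAuth branch of `is_block_auth_configured` comes down to checking that both environment variables named by the provider's OAuth config are set and non-empty. A minimal sketch, assuming a simplified `OAuthConfig` dataclass and made-up variable names (`EXAMPLE_CLIENT_ID`, `EXAMPLE_CLIENT_SECRET`) rather than the real registry types:

```python
import os
from dataclasses import dataclass
from typing import Optional


@dataclass
class OAuthConfig:
    """Simplified stand-in holding the names of the two env vars to look up."""

    client_id_env_var: str
    client_secret_env_var: str


def oauth_is_configured(config: Optional[OAuthConfig]) -> bool:
    """True only if a config exists and both env vars are set and non-empty."""
    if config is None:
        return False
    return bool(
        os.getenv(config.client_id_env_var)
        and os.getenv(config.client_secret_env_var)
    )


# Hypothetical variable names, used only for this illustration.
example = OAuthConfig("EXAMPLE_CLIENT_ID", "EXAMPLE_CLIENT_SECRET")
os.environ["EXAMPLE_CLIENT_ID"] = "client-id"
os.environ.pop("EXAMPLE_CLIENT_SECRET", None)
configured = oauth_is_configured(example)  # secret missing, so not configured
```

An empty string counts as unset here because `os.getenv` returns `""`, which is falsy.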


@@ -1,7 +1,7 @@
import logging import logging
from typing import Any, Optional from typing import TYPE_CHECKING, Any, Optional
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockInput, BlockInput,
@@ -9,13 +9,15 @@ from backend.data.block import (
BlockSchema, BlockSchema,
BlockSchemaInput, BlockSchemaInput,
BlockType, BlockType,
get_block,
) )
from backend.data.execution import ExecutionContext, ExecutionStatus, NodesInputMasks from backend.data.execution import ExecutionContext, ExecutionStatus, NodesInputMasks
from backend.data.model import NodeExecutionStats, SchemaField from backend.data.model import NodeExecutionStats, SchemaField
from backend.util.json import validate_with_jsonschema from backend.util.json import validate_with_jsonschema
from backend.util.retry import func_retry from backend.util.retry import func_retry
if TYPE_CHECKING:
from backend.executor.utils import LogMetadata
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@@ -124,9 +126,10 @@ class AgentExecutorBlock(Block):
graph_version: int, graph_version: int,
graph_exec_id: str, graph_exec_id: str,
user_id: str, user_id: str,
logger, logger: "LogMetadata",
) -> BlockOutput: ) -> BlockOutput:
from backend.blocks import get_block
from backend.data.execution import ExecutionEventType from backend.data.execution import ExecutionEventType
from backend.executor import utils as execution_utils from backend.executor import utils as execution_utils
@@ -198,7 +201,7 @@ class AgentExecutorBlock(Block):
self, self,
graph_exec_id: str, graph_exec_id: str,
user_id: str, user_id: str,
logger, logger: "LogMetadata",
) -> None: ) -> None:
from backend.executor import utils as execution_utils from backend.executor import utils as execution_utils


@@ -1,5 +1,11 @@
from typing import Any from typing import Any
from backend.blocks._base import (
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.llm import ( from backend.blocks.llm import (
DEFAULT_LLM_MODEL, DEFAULT_LLM_MODEL,
TEST_CREDENTIALS, TEST_CREDENTIALS,
@@ -11,12 +17,6 @@ from backend.blocks.llm import (
LLMResponse, LLMResponse,
llm_call, llm_call,
) )
from backend.data.block import (
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField


@@ -6,7 +6,7 @@ from pydantic import SecretStr
from replicate.client import Client as ReplicateClient from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput from replicate.helpers import FileOutput
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -5,7 +5,12 @@ from pydantic import SecretStr
from replicate.client import Client as ReplicateClient from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput from replicate.helpers import FileOutput
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput from backend.blocks._base import (
Block,
BlockCategory,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext from backend.data.execution import ExecutionContext
from backend.data.model import ( from backend.data.model import (
APIKeyCredentials, APIKeyCredentials,


@@ -6,7 +6,7 @@ from typing import Literal
from pydantic import SecretStr from pydantic import SecretStr
from replicate.client import Client as ReplicateClient from replicate.client import Client as ReplicateClient
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -6,7 +6,7 @@ from typing import Literal
from pydantic import SecretStr from pydantic import SecretStr
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -1,3 +1,10 @@
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.apollo._api import ApolloClient from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import ( from backend.blocks.apollo._auth import (
TEST_CREDENTIALS, TEST_CREDENTIALS,
@@ -10,13 +17,6 @@ from backend.blocks.apollo.models import (
PrimaryPhone, PrimaryPhone,
SearchOrganizationsRequest, SearchOrganizationsRequest,
) )
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import CredentialsField, SchemaField from backend.data.model import CredentialsField, SchemaField


@@ -1,5 +1,12 @@
import asyncio import asyncio
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.apollo._api import ApolloClient from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import ( from backend.blocks.apollo._auth import (
TEST_CREDENTIALS, TEST_CREDENTIALS,
@@ -14,13 +21,6 @@ from backend.blocks.apollo.models import (
SearchPeopleRequest, SearchPeopleRequest,
SenorityLevels, SenorityLevels,
) )
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import CredentialsField, SchemaField from backend.data.model import CredentialsField, SchemaField


@@ -1,3 +1,10 @@
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.apollo._api import ApolloClient from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import ( from backend.blocks.apollo._auth import (
TEST_CREDENTIALS, TEST_CREDENTIALS,
@@ -6,13 +13,6 @@ from backend.blocks.apollo._auth import (
ApolloCredentialsInput, ApolloCredentialsInput,
) )
from backend.blocks.apollo.models import Contact, EnrichPersonRequest from backend.blocks.apollo.models import Contact, EnrichPersonRequest
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import CredentialsField, SchemaField from backend.data.model import CredentialsField, SchemaField


@@ -3,7 +3,7 @@ from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from backend.data.block import BlockSchemaInput from backend.blocks._base import BlockSchemaInput
from backend.data.model import SchemaField, UserIntegrations from backend.data.model import SchemaField, UserIntegrations
from backend.integrations.ayrshare import AyrshareClient from backend.integrations.ayrshare import AyrshareClient
from backend.util.clients import get_database_manager_async_client from backend.util.clients import get_database_manager_async_client


@@ -1,7 +1,7 @@
import enum import enum
from typing import Any from typing import Any
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -2,7 +2,7 @@ import os
import re import re
from typing import Type from typing import Type
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -1,7 +1,7 @@
from enum import Enum from enum import Enum
from typing import Any from typing import Any
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -1,12 +1,12 @@
import json import json
import shlex import shlex
import uuid import uuid
from typing import Literal, Optional from typing import TYPE_CHECKING, Literal, Optional
from e2b import AsyncSandbox as BaseAsyncSandbox from e2b import AsyncSandbox as BaseAsyncSandbox
from pydantic import BaseModel, SecretStr from pydantic import SecretStr
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
@@ -20,6 +20,13 @@ from backend.data.model import (
SchemaField, SchemaField,
) )
from backend.integrations.providers import ProviderName from backend.integrations.providers import ProviderName
from backend.util.sandbox_files import (
SandboxFileOutput,
extract_and_store_sandbox_files,
)
if TYPE_CHECKING:
from backend.executor.utils import ExecutionContext
class ClaudeCodeExecutionError(Exception): class ClaudeCodeExecutionError(Exception):
@@ -174,22 +181,15 @@ class ClaudeCodeBlock(Block):
advanced=True, advanced=True,
) )
class FileOutput(BaseModel):
"""A file extracted from the sandbox."""
path: str
relative_path: str # Path relative to working directory (for GitHub, etc.)
name: str
content: str
class Output(BlockSchemaOutput): class Output(BlockSchemaOutput):
response: str = SchemaField( response: str = SchemaField(
description="The output/response from Claude Code execution" description="The output/response from Claude Code execution"
) )
files: list["ClaudeCodeBlock.FileOutput"] = SchemaField( files: list[SandboxFileOutput] = SchemaField(
description=( description=(
"List of text files created/modified by Claude Code during this execution. " "List of text files created/modified by Claude Code during this execution. "
"Each file has 'path', 'relative_path', 'name', and 'content' fields." "Each file has 'path', 'relative_path', 'name', 'content', and 'workspace_ref' fields. "
"workspace_ref contains a workspace:// URI if the file was stored to workspace."
) )
) )
conversation_history: str = SchemaField( conversation_history: str = SchemaField(
@@ -252,6 +252,7 @@ class ClaudeCodeBlock(Block):
"relative_path": "index.html", "relative_path": "index.html",
"name": "index.html", "name": "index.html",
"content": "<html>Hello World</html>", "content": "<html>Hello World</html>",
"workspace_ref": None,
} }
], ],
), ),
@@ -267,11 +268,12 @@ class ClaudeCodeBlock(Block):
"execute_claude_code": lambda *args, **kwargs: ( "execute_claude_code": lambda *args, **kwargs: (
"Created index.html with hello world content", # response "Created index.html with hello world content", # response
[ [
ClaudeCodeBlock.FileOutput( SandboxFileOutput(
path="/home/user/index.html", path="/home/user/index.html",
relative_path="index.html", relative_path="index.html",
name="index.html", name="index.html",
content="<html>Hello World</html>", content="<html>Hello World</html>",
workspace_ref=None,
) )
], # files ], # files
"User: Create a hello world HTML file\n" "User: Create a hello world HTML file\n"
@@ -294,7 +296,8 @@ class ClaudeCodeBlock(Block):
existing_sandbox_id: str, existing_sandbox_id: str,
conversation_history: str, conversation_history: str,
dispose_sandbox: bool, dispose_sandbox: bool,
) -> tuple[str, list["ClaudeCodeBlock.FileOutput"], str, str, str]: execution_context: "ExecutionContext",
) -> tuple[str, list[SandboxFileOutput], str, str, str]:
""" """
Execute Claude Code in an E2B sandbox. Execute Claude Code in an E2B sandbox.
@@ -449,14 +452,18 @@ class ClaudeCodeBlock(Block):
else: else:
new_conversation_history = turn_entry new_conversation_history = turn_entry
# Extract files created/modified during this run # Extract files created/modified during this run and store to workspace
files = await self._extract_files( sandbox_files = await extract_and_store_sandbox_files(
sandbox, working_directory, start_timestamp sandbox=sandbox,
working_directory=working_directory,
execution_context=execution_context,
since_timestamp=start_timestamp,
text_only=True,
) )
return ( return (
response, response,
files, sandbox_files, # Already SandboxFileOutput objects
new_conversation_history, new_conversation_history,
current_session_id, current_session_id,
sandbox_id, sandbox_id,
@@ -471,140 +478,6 @@ class ClaudeCodeBlock(Block):
if dispose_sandbox and sandbox: if dispose_sandbox and sandbox:
await sandbox.kill() await sandbox.kill()
async def _extract_files(
self,
sandbox: BaseAsyncSandbox,
working_directory: str,
since_timestamp: str | None = None,
) -> list["ClaudeCodeBlock.FileOutput"]:
"""
Extract text files created/modified during this Claude Code execution.
Args:
sandbox: The E2B sandbox instance
working_directory: Directory to search for files
since_timestamp: ISO timestamp - only return files modified after this time
Returns:
List of FileOutput objects with path, relative_path, name, and content
"""
files: list[ClaudeCodeBlock.FileOutput] = []
# Text file extensions we can safely read as text
text_extensions = {
".txt",
".md",
".html",
".htm",
".css",
".js",
".ts",
".jsx",
".tsx",
".json",
".xml",
".yaml",
".yml",
".toml",
".ini",
".cfg",
".conf",
".py",
".rb",
".php",
".java",
".c",
".cpp",
".h",
".hpp",
".cs",
".go",
".rs",
".swift",
".kt",
".scala",
".sh",
".bash",
".zsh",
".sql",
".graphql",
".env",
".gitignore",
".dockerfile",
"Dockerfile",
".vue",
".svelte",
".astro",
".mdx",
".rst",
".tex",
".csv",
".log",
}
try:
# List files recursively using find command
# Exclude node_modules and .git directories, but allow hidden files
# like .env and .gitignore (they're filtered by text_extensions later)
# Filter by timestamp to only get files created/modified during this run
safe_working_dir = shlex.quote(working_directory)
timestamp_filter = ""
if since_timestamp:
timestamp_filter = f"-newermt {shlex.quote(since_timestamp)} "
find_result = await sandbox.commands.run(
f"find {safe_working_dir} -type f "
f"{timestamp_filter}"
f"-not -path '*/node_modules/*' "
f"-not -path '*/.git/*' "
f"2>/dev/null"
)
if find_result.stdout:
for file_path in find_result.stdout.strip().split("\n"):
if not file_path:
continue
# Check if it's a text file we can read
is_text = any(
file_path.endswith(ext) for ext in text_extensions
) or file_path.endswith("Dockerfile")
if is_text:
try:
content = await sandbox.files.read(file_path)
# Handle bytes or string
if isinstance(content, bytes):
content = content.decode("utf-8", errors="replace")
# Extract filename from path
file_name = file_path.split("/")[-1]
# Calculate relative path by stripping working directory
relative_path = file_path
if file_path.startswith(working_directory):
relative_path = file_path[len(working_directory) :]
# Remove leading slash if present
if relative_path.startswith("/"):
relative_path = relative_path[1:]
files.append(
ClaudeCodeBlock.FileOutput(
path=file_path,
relative_path=relative_path,
name=file_name,
content=content,
)
)
except Exception:
# Skip files that can't be read
pass
except Exception:
# If file extraction fails, return empty results
pass
return files
def _escape_prompt(self, prompt: str) -> str: def _escape_prompt(self, prompt: str) -> str:
"""Escape the prompt for safe shell execution.""" """Escape the prompt for safe shell execution."""
# Use single quotes and escape any single quotes in the prompt # Use single quotes and escape any single quotes in the prompt
@@ -617,6 +490,7 @@ class ClaudeCodeBlock(Block):
*, *,
e2b_credentials: APIKeyCredentials, e2b_credentials: APIKeyCredentials,
anthropic_credentials: APIKeyCredentials, anthropic_credentials: APIKeyCredentials,
execution_context: "ExecutionContext",
**kwargs, **kwargs,
) -> BlockOutput: ) -> BlockOutput:
try: try:
@@ -637,6 +511,7 @@ class ClaudeCodeBlock(Block):
existing_sandbox_id=input_data.sandbox_id, existing_sandbox_id=input_data.sandbox_id,
conversation_history=input_data.conversation_history, conversation_history=input_data.conversation_history,
dispose_sandbox=input_data.dispose_sandbox, dispose_sandbox=input_data.dispose_sandbox,
execution_context=execution_context,
) )
yield "response", response yield "response", response
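
The inline `_extract_files()` helper removed above decided which files to read by suffix and derived `relative_path` by stripping the working directory. A minimal sketch of just that logic follows, assuming a shortened extension list; the function names `is_text_file` and `relative_to` are hypothetical and do not appear in the diff, and the shared `sandbox_files` utility may implement this differently.

```python
# Illustrative subset of the extension set in the removed helper (assumption:
# the full list is much longer, see the deleted lines above).
TEXT_EXTENSIONS = (".txt", ".md", ".html", ".py", ".json")


def is_text_file(file_path: str) -> bool:
    """Return True when the path ends in a known text suffix or 'Dockerfile'."""
    return file_path.endswith(TEXT_EXTENSIONS) or file_path.endswith("Dockerfile")


def relative_to(file_path: str, working_directory: str) -> str:
    """Strip the working-directory prefix, then one leading slash if present."""
    relative_path = file_path
    if file_path.startswith(working_directory):
        relative_path = file_path[len(working_directory):]
        if relative_path.startswith("/"):
            relative_path = relative_path[1:]
    return relative_path
```

Note that a plain `startswith` prefix check also matches sibling directories such as `/home/user2` when the working directory is `/home/user`; a stricter variant would compare path components.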


@@ -1,12 +1,12 @@
from enum import Enum from enum import Enum
from typing import Any, Literal, Optional from typing import TYPE_CHECKING, Any, Literal, Optional
from e2b_code_interpreter import AsyncSandbox from e2b_code_interpreter import AsyncSandbox
from e2b_code_interpreter import Result as E2BExecutionResult from e2b_code_interpreter import Result as E2BExecutionResult
from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
from pydantic import BaseModel, Field, JsonValue, SecretStr from pydantic import BaseModel, Field, JsonValue, SecretStr
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
@@ -20,6 +20,13 @@ from backend.data.model import (
SchemaField, SchemaField,
) )
from backend.integrations.providers import ProviderName from backend.integrations.providers import ProviderName
from backend.util.sandbox_files import (
SandboxFileOutput,
extract_and_store_sandbox_files,
)
if TYPE_CHECKING:
from backend.executor.utils import ExecutionContext
TEST_CREDENTIALS = APIKeyCredentials( TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef", id="01234567-89ab-cdef-0123-456789abcdef",
@@ -85,6 +92,9 @@ class CodeExecutionResult(MainCodeExecutionResult):
class BaseE2BExecutorMixin: class BaseE2BExecutorMixin:
"""Shared implementation methods for E2B executor blocks.""" """Shared implementation methods for E2B executor blocks."""
# Default working directory in E2B sandboxes
WORKING_DIR = "/home/user"
async def execute_code( async def execute_code(
self, self,
api_key: str, api_key: str,
@@ -95,14 +105,21 @@ class BaseE2BExecutorMixin:
timeout: Optional[int] = None, timeout: Optional[int] = None,
sandbox_id: Optional[str] = None, sandbox_id: Optional[str] = None,
dispose_sandbox: bool = False, dispose_sandbox: bool = False,
execution_context: Optional["ExecutionContext"] = None,
extract_files: bool = False,
): ):
""" """
Unified code execution method that handles all three use cases: Unified code execution method that handles all three use cases:
1. Create new sandbox and execute (ExecuteCodeBlock) 1. Create new sandbox and execute (ExecuteCodeBlock)
2. Create new sandbox, execute, and return sandbox_id (InstantiateCodeSandboxBlock) 2. Create new sandbox, execute, and return sandbox_id (InstantiateCodeSandboxBlock)
3. Connect to existing sandbox and execute (ExecuteCodeStepBlock) 3. Connect to existing sandbox and execute (ExecuteCodeStepBlock)
Args:
extract_files: If True and execution_context provided, extract files
created/modified during execution and store to workspace.
""" # noqa """ # noqa
sandbox = None sandbox = None
files: list[SandboxFileOutput] = []
try: try:
if sandbox_id: if sandbox_id:
# Connect to existing sandbox (ExecuteCodeStepBlock case) # Connect to existing sandbox (ExecuteCodeStepBlock case)
@@ -118,6 +135,12 @@ class BaseE2BExecutorMixin:
for cmd in setup_commands: for cmd in setup_commands:
await sandbox.commands.run(cmd) await sandbox.commands.run(cmd)
# Capture timestamp before execution to scope file extraction
start_timestamp = None
if extract_files:
ts_result = await sandbox.commands.run("date -u +%Y-%m-%dT%H:%M:%S")
start_timestamp = ts_result.stdout.strip() if ts_result.stdout else None
# Execute the code # Execute the code
execution = await sandbox.run_code( execution = await sandbox.run_code(
code, code,
@@ -133,7 +156,24 @@ class BaseE2BExecutorMixin:
stdout_logs = "".join(execution.logs.stdout) stdout_logs = "".join(execution.logs.stdout)
stderr_logs = "".join(execution.logs.stderr) stderr_logs = "".join(execution.logs.stderr)
return results, text_output, stdout_logs, stderr_logs, sandbox.sandbox_id # Extract files created/modified during this execution
if extract_files and execution_context:
files = await extract_and_store_sandbox_files(
sandbox=sandbox,
working_directory=self.WORKING_DIR,
execution_context=execution_context,
since_timestamp=start_timestamp,
text_only=False, # Include binary files too
)
return (
results,
text_output,
stdout_logs,
stderr_logs,
sandbox.sandbox_id,
files,
)
finally: finally:
# Dispose of sandbox if requested to reduce usage costs # Dispose of sandbox if requested to reduce usage costs
if dispose_sandbox and sandbox: if dispose_sandbox and sandbox:
@@ -238,6 +278,12 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
description="Standard output logs from execution" description="Standard output logs from execution"
) )
stderr_logs: str = SchemaField(description="Standard error logs from execution") stderr_logs: str = SchemaField(description="Standard error logs from execution")
files: list[SandboxFileOutput] = SchemaField(
description=(
"Files created or modified during execution. "
"Each file has path, name, content, and workspace_ref (if stored)."
),
)
def __init__(self): def __init__(self):
super().__init__( super().__init__(
@@ -259,23 +305,30 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
("results", []), ("results", []),
("response", "Hello World"), ("response", "Hello World"),
("stdout_logs", "Hello World\n"), ("stdout_logs", "Hello World\n"),
("files", []),
], ],
test_mock={ test_mock={
"execute_code": lambda api_key, code, language, template_id, setup_commands, timeout, dispose_sandbox: ( # noqa "execute_code": lambda api_key, code, language, template_id, setup_commands, timeout, dispose_sandbox, execution_context, extract_files: ( # noqa
[], # results [], # results
"Hello World", # text_output "Hello World", # text_output
"Hello World\n", # stdout_logs "Hello World\n", # stdout_logs
"", # stderr_logs "", # stderr_logs
"sandbox_id", # sandbox_id "sandbox_id", # sandbox_id
[], # files
), ),
}, },
) )
async def run( async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: "ExecutionContext",
**kwargs,
) -> BlockOutput: ) -> BlockOutput:
try: try:
results, text_output, stdout, stderr, _ = await self.execute_code( results, text_output, stdout, stderr, _, files = await self.execute_code(
api_key=credentials.api_key.get_secret_value(), api_key=credentials.api_key.get_secret_value(),
code=input_data.code, code=input_data.code,
language=input_data.language, language=input_data.language,
@@ -283,6 +336,8 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
setup_commands=input_data.setup_commands, setup_commands=input_data.setup_commands,
timeout=input_data.timeout, timeout=input_data.timeout,
dispose_sandbox=input_data.dispose_sandbox, dispose_sandbox=input_data.dispose_sandbox,
execution_context=execution_context,
extract_files=True,
) )
# Determine result object shape & filter out empty formats # Determine result object shape & filter out empty formats
@@ -296,6 +351,8 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
yield "stdout_logs", stdout yield "stdout_logs", stdout
if stderr: if stderr:
yield "stderr_logs", stderr yield "stderr_logs", stderr
# Always yield files (empty list if none)
yield "files", [f.model_dump() for f in files]
except Exception as e: except Exception as e:
yield "error", str(e) yield "error", str(e)
@@ -393,6 +450,7 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
"Hello World\n", # stdout_logs "Hello World\n", # stdout_logs
"", # stderr_logs "", # stderr_logs
"sandbox_id", # sandbox_id "sandbox_id", # sandbox_id
[], # files
), ),
}, },
) )
@@ -401,7 +459,7 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput: ) -> BlockOutput:
try: try:
_, text_output, stdout, stderr, sandbox_id = await self.execute_code( _, text_output, stdout, stderr, sandbox_id, _ = await self.execute_code(
api_key=credentials.api_key.get_secret_value(), api_key=credentials.api_key.get_secret_value(),
code=input_data.setup_code, code=input_data.setup_code,
language=input_data.language, language=input_data.language,
@@ -500,6 +558,7 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
"Hello World\n", # stdout_logs "Hello World\n", # stdout_logs
"", # stderr_logs "", # stderr_logs
sandbox_id, # sandbox_id sandbox_id, # sandbox_id
[], # files
), ),
}, },
) )
@@ -508,7 +567,7 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput: ) -> BlockOutput:
try: try:
results, text_output, stdout, stderr, _ = await self.execute_code( results, text_output, stdout, stderr, _, _ = await self.execute_code(
api_key=credentials.api_key.get_secret_value(), api_key=credentials.api_key.get_secret_value(),
code=input_data.step_code, code=input_data.step_code,
language=input_data.language, language=input_data.language,
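
The hunks above capture a UTC timestamp before running the code and pass it as `since_timestamp` so that only files touched during the run are extracted. A minimal sketch of how such a timestamp can scope a `find` call is shown below. It mirrors the command built in the removed Claude Code helper; whether `extract_and_store_sandbox_files` builds the same command is an assumption, and `utc_timestamp` and `build_find_command` are hypothetical names.

```python
import shlex
from datetime import datetime, timezone
from typing import Optional


def utc_timestamp() -> str:
    """Python equivalent of the sandbox command `date -u +%Y-%m-%dT%H:%M:%S`."""
    return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S")


def build_find_command(
    working_directory: str, since_timestamp: Optional[str] = None
) -> str:
    """Build a shell-safe `find` command, optionally limited to newer files."""
    timestamp_filter = ""
    if since_timestamp:
        # -newermt keeps only files modified after the given timestamp
        timestamp_filter = f"-newermt {shlex.quote(since_timestamp)} "
    return (
        f"find {shlex.quote(working_directory)} -type f "
        f"{timestamp_filter}"
        "-not -path '*/node_modules/*' "
        "-not -path '*/.git/*' "
        "2>/dev/null"
    )
```

Quoting both arguments with `shlex.quote` keeps directory names containing spaces or shell metacharacters from being interpreted by the sandbox shell.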


@@ -1,6 +1,6 @@
import re import re
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -6,7 +6,7 @@ from openai import AsyncOpenAI
from openai.types.responses import Response as OpenAIResponse from openai.types.responses import Response as OpenAIResponse
from pydantic import SecretStr from pydantic import SecretStr
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -1,6 +1,6 @@
from pydantic import BaseModel from pydantic import BaseModel
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockManualWebhookConfig, BlockManualWebhookConfig,


@@ -1,4 +1,4 @@
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -1,6 +1,6 @@
from typing import Any, List from typing import Any, List
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -1,6 +1,6 @@
import codecs import codecs
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -8,7 +8,7 @@ from typing import Any, Literal, cast
import discord import discord
from pydantic import SecretStr from pydantic import SecretStr
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -2,7 +2,7 @@
Discord OAuth-based blocks. Discord OAuth-based blocks.
""" """
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -7,7 +7,7 @@ from typing import Literal
from pydantic import BaseModel, ConfigDict, SecretStr from pydantic import BaseModel, ConfigDict, SecretStr
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -2,7 +2,7 @@
import codecs import codecs
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -8,7 +8,7 @@ which provides access to LinkedIn profile data and related information.
import logging import logging
from typing import Optional from typing import Optional
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -3,6 +3,13 @@ import logging
from enum import Enum from enum import Enum
from typing import Any from typing import Any
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.blocks.fal._auth import ( from backend.blocks.fal._auth import (
TEST_CREDENTIALS, TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT, TEST_CREDENTIALS_INPUT,
@@ -10,13 +17,6 @@ from backend.blocks.fal._auth import (
FalCredentialsField, FalCredentialsField,
FalCredentialsInput, FalCredentialsInput,
) )
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField from backend.data.model import SchemaField
from backend.util.file import store_media_file from backend.util.file import store_media_file


@@ -5,7 +5,7 @@ from pydantic import SecretStr
from replicate.client import Client as ReplicateClient from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput from replicate.helpers import FileOutput
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -3,7 +3,7 @@ from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -5,7 +5,7 @@ from typing import Optional
from typing_extensions import TypedDict from typing_extensions import TypedDict
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -3,7 +3,7 @@ from urllib.parse import urlparse
from typing_extensions import TypedDict from typing_extensions import TypedDict
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -2,7 +2,7 @@ import re
from typing_extensions import TypedDict from typing_extensions import TypedDict
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -2,7 +2,7 @@ import base64
from typing_extensions import TypedDict from typing_extensions import TypedDict
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -4,7 +4,7 @@ from typing import Any, List, Optional
from typing_extensions import TypedDict from typing_extensions import TypedDict
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -3,7 +3,7 @@ from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -4,7 +4,7 @@ from pathlib import Path
from pydantic import BaseModel from pydantic import BaseModel
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -8,7 +8,7 @@ from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build from googleapiclient.discovery import build
from pydantic import BaseModel from pydantic import BaseModel
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -7,14 +7,14 @@ from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build from googleapiclient.discovery import build
from gravitas_md2gdocs import to_requests from gravitas_md2gdocs import to_requests
from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField from backend.blocks._base import (
from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
BlockSchemaInput, BlockSchemaInput,
BlockSchemaOutput, BlockSchemaOutput,
) )
from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
from backend.data.model import SchemaField from backend.data.model import SchemaField
from backend.util.settings import Settings from backend.util.settings import Settings


@@ -14,7 +14,7 @@ from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build from googleapiclient.discovery import build
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -7,14 +7,14 @@ from enum import Enum
from google.oauth2.credentials import Credentials from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build from googleapiclient.discovery import build
from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField from backend.blocks._base import (
from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
BlockSchemaInput, BlockSchemaInput,
BlockSchemaOutput, BlockSchemaOutput,
) )
from backend.blocks.google._drive import GoogleDriveFile, GoogleDriveFileField
from backend.data.model import SchemaField from backend.data.model import SchemaField
from backend.util.settings import Settings from backend.util.settings import Settings


@@ -3,7 +3,7 @@ from typing import Literal
import googlemaps import googlemaps
from pydantic import BaseModel, SecretStr from pydantic import BaseModel, SecretStr
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -9,9 +9,7 @@ from typing import Any, Optional
from prisma.enums import ReviewStatus from prisma.enums import ReviewStatus
from pydantic import BaseModel from pydantic import BaseModel
from backend.data.execution import ExecutionStatus
from backend.data.human_review import ReviewResult from backend.data.human_review import ReviewResult
from backend.executor.manager import async_update_node_execution_status
from backend.util.clients import get_database_manager_async_client from backend.util.clients import get_database_manager_async_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -43,6 +41,8 @@ class HITLReviewHelper:
@staticmethod @staticmethod
async def update_node_execution_status(**kwargs) -> None: async def update_node_execution_status(**kwargs) -> None:
"""Update the execution status of a node.""" """Update the execution status of a node."""
from backend.executor.manager import async_update_node_execution_status
await async_update_node_execution_status( await async_update_node_execution_status(
db_client=get_database_manager_async_client(), **kwargs db_client=get_database_manager_async_client(), **kwargs
) )
@@ -88,12 +88,13 @@ class HITLReviewHelper:
Raises: Raises:
Exception: If review creation or status update fails Exception: If review creation or status update fails
""" """
from backend.data.execution import ExecutionStatus
# Note: Safe mode checks (human_in_the_loop_safe_mode, sensitive_action_safe_mode) # Note: Safe mode checks (human_in_the_loop_safe_mode, sensitive_action_safe_mode)
# are handled by the caller: # are handled by the caller:
# - HITL blocks check human_in_the_loop_safe_mode in their run() method # - HITL blocks check human_in_the_loop_safe_mode in their run() method
# - Sensitive action blocks check sensitive_action_safe_mode in is_block_exec_need_review() # - Sensitive action blocks check sensitive_action_safe_mode in is_block_exec_need_review()
# This function only handles checking for existing approvals. # This function only handles checking for existing approvals.
# Check if this node has already been approved (normal or auto-approval) # Check if this node has already been approved (normal or auto-approval)
if approval_result := await HITLReviewHelper.check_approval( if approval_result := await HITLReviewHelper.check_approval(
node_exec_id=node_exec_id, node_exec_id=node_exec_id,
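
The hunks above move two module-level imports into the function bodies that use them, a common way to break an import cycle because the import then runs on first call instead of at module load. A minimal sketch of the pattern follows, using the standard-library `json` module as a stand-in for `backend.executor.manager`; the function name `dump_status` is hypothetical, and the diff itself does not state that a circular import was the motivation.

```python
def dump_status(status: dict) -> str:
    """Serialize a status dict, importing the dependency only when called."""
    # Deferred import: resolved on the first call, not when this module loads,
    # so importing this module never triggers the dependency's own imports.
    import json

    return json.dumps(status, sort_keys=True)
```

After the first call Python caches the module in `sys.modules`, so repeated calls pay only a dictionary lookup.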


@@ -8,7 +8,7 @@ from typing import Literal
import aiofiles import aiofiles
from pydantic import SecretStr from pydantic import SecretStr
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -1,15 +1,15 @@
from backend.blocks.hubspot._auth import ( from backend.blocks._base import (
HubSpotCredentials,
HubSpotCredentialsField,
HubSpotCredentialsInput,
)
from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
BlockSchemaInput, BlockSchemaInput,
BlockSchemaOutput, BlockSchemaOutput,
) )
from backend.blocks.hubspot._auth import (
HubSpotCredentials,
HubSpotCredentialsField,
HubSpotCredentialsInput,
)
from backend.data.model import SchemaField from backend.data.model import SchemaField
from backend.util.request import Requests from backend.util.request import Requests


@@ -1,15 +1,15 @@
from backend.blocks.hubspot._auth import ( from backend.blocks._base import (
HubSpotCredentials,
HubSpotCredentialsField,
HubSpotCredentialsInput,
)
from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
BlockSchemaInput, BlockSchemaInput,
BlockSchemaOutput, BlockSchemaOutput,
) )
from backend.blocks.hubspot._auth import (
HubSpotCredentials,
HubSpotCredentialsField,
HubSpotCredentialsInput,
)
from backend.data.model import SchemaField from backend.data.model import SchemaField
from backend.util.request import Requests from backend.util.request import Requests


@@ -1,17 +1,17 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from backend.blocks.hubspot._auth import ( from backend.blocks._base import (
HubSpotCredentials,
HubSpotCredentialsField,
HubSpotCredentialsInput,
)
from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
BlockSchemaInput, BlockSchemaInput,
BlockSchemaOutput, BlockSchemaOutput,
) )
from backend.blocks.hubspot._auth import (
HubSpotCredentials,
HubSpotCredentialsField,
HubSpotCredentialsInput,
)
from backend.data.model import SchemaField from backend.data.model import SchemaField
from backend.util.request import Requests from backend.util.request import Requests


@@ -3,8 +3,7 @@ from typing import Any
from prisma.enums import ReviewStatus from prisma.enums import ReviewStatus
from backend.blocks.helpers.review import HITLReviewHelper from backend.blocks._base import (
from backend.data.block import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,
@@ -12,6 +11,7 @@ from backend.data.block import (
BlockSchemaOutput, BlockSchemaOutput,
BlockType, BlockType,
) )
from backend.blocks.helpers.review import HITLReviewHelper
from backend.data.execution import ExecutionContext from backend.data.execution import ExecutionContext
from backend.data.human_review import ReviewResult from backend.data.human_review import ReviewResult
from backend.data.model import SchemaField from backend.data.model import SchemaField


@@ -3,7 +3,7 @@ from typing import Any, Dict, Literal, Optional
from pydantic import SecretStr from pydantic import SecretStr
from backend.data.block import ( from backend.blocks._base import (
Block, Block,
BlockCategory, BlockCategory,
BlockOutput, BlockOutput,


@@ -2,9 +2,7 @@ import copy
 from datetime import date, time
 from typing import Any, Optional
-# Import for Google Drive file input block
-from backend.blocks.google._drive import AttachmentView, GoogleDriveFile
-from backend.data.block import (
+from backend.blocks._base import (
 Block,
 BlockCategory,
 BlockOutput,
@@ -12,6 +10,9 @@ from backend.data.block import (
 BlockSchemaInput,
 BlockType,
 )
+# Import for Google Drive file input block
+from backend.blocks.google._drive import AttachmentView, GoogleDriveFile
 from backend.data.execution import ExecutionContext
 from backend.data.model import SchemaField
 from backend.util.file import store_media_file

View File

@@ -1,6 +1,6 @@
 from typing import Any
-from backend.data.block import (
+from backend.blocks._base import (
 Block,
 BlockCategory,
 BlockOutput,

View File

@@ -1,15 +1,15 @@
-from backend.blocks.jina._auth import (
-JinaCredentials,
-JinaCredentialsField,
-JinaCredentialsInput,
-)
-from backend.data.block import (
+from backend.blocks._base import (
 Block,
 BlockCategory,
 BlockOutput,
 BlockSchemaInput,
 BlockSchemaOutput,
 )
+from backend.blocks.jina._auth import (
+JinaCredentials,
+JinaCredentialsField,
+JinaCredentialsInput,
+)
 from backend.data.model import SchemaField
 from backend.util.request import Requests

View File

@@ -1,15 +1,15 @@
-from backend.blocks.jina._auth import (
-JinaCredentials,
-JinaCredentialsField,
-JinaCredentialsInput,
-)
-from backend.data.block import (
+from backend.blocks._base import (
 Block,
 BlockCategory,
 BlockOutput,
 BlockSchemaInput,
 BlockSchemaOutput,
 )
+from backend.blocks.jina._auth import (
+JinaCredentials,
+JinaCredentialsField,
+JinaCredentialsInput,
+)
 from backend.data.model import SchemaField
 from backend.util.request import Requests

View File

@@ -3,18 +3,18 @@ from urllib.parse import quote
 from typing_extensions import TypedDict
-from backend.blocks.jina._auth import (
-JinaCredentials,
-JinaCredentialsField,
-JinaCredentialsInput,
-)
-from backend.data.block import (
+from backend.blocks._base import (
 Block,
 BlockCategory,
 BlockOutput,
 BlockSchemaInput,
 BlockSchemaOutput,
 )
+from backend.blocks.jina._auth import (
+JinaCredentials,
+JinaCredentialsField,
+JinaCredentialsInput,
+)
 from backend.data.model import SchemaField
 from backend.util.request import Requests

View File

@@ -1,5 +1,12 @@
 from urllib.parse import quote
+from backend.blocks._base import (
+Block,
+BlockCategory,
+BlockOutput,
+BlockSchemaInput,
+BlockSchemaOutput,
+)
 from backend.blocks.jina._auth import (
 TEST_CREDENTIALS,
 TEST_CREDENTIALS_INPUT,
@@ -8,13 +15,6 @@ from backend.blocks.jina._auth import (
 JinaCredentialsInput,
 )
 from backend.blocks.search import GetRequest
-from backend.data.block import (
-Block,
-BlockCategory,
-BlockOutput,
-BlockSchemaInput,
-BlockSchemaOutput,
-)
 from backend.data.model import SchemaField
 from backend.util.exceptions import BlockExecutionError

View File

@@ -15,7 +15,7 @@ from anthropic.types import ToolParam
 from groq import AsyncGroq
 from pydantic import BaseModel, SecretStr
-from backend.data.block import (
+from backend.blocks._base import (
 Block,
 BlockCategory,
 BlockOutput,

View File

@@ -2,7 +2,7 @@ import operator
 from enum import Enum
 from typing import Any
-from backend.data.block import (
+from backend.blocks._base import (
 Block,
 BlockCategory,
 BlockOutput,

View File

@@ -0,0 +1,300 @@
"""
MCP (Model Context Protocol) Tool Block.
A single dynamic block that can connect to any MCP server, discover available tools,
and execute them. Works like AgentExecutorBlock — the user selects a tool from a
dropdown and the input/output schema adapts dynamically.
"""
import json
import logging
from typing import Any, Literal
from pydantic import SecretStr
from backend.blocks._base import (
Block,
BlockCategory,
BlockSchemaInput,
BlockSchemaOutput,
BlockType,
)
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.data.block import BlockInput, BlockOutput
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
OAuth2Credentials,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.json import validate_with_jsonschema
logger = logging.getLogger(__name__)
TEST_CREDENTIALS = OAuth2Credentials(
id="test-mcp-cred",
provider="mcp",
access_token=SecretStr("mock-mcp-token"),
refresh_token=SecretStr("mock-refresh"),
scopes=[],
title="Mock MCP credential",
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
MCPCredentials = CredentialsMetaInput[Literal[ProviderName.MCP], Literal["oauth2"]]
class MCPToolBlock(Block):
"""
A block that connects to an MCP server, lets the user pick a tool,
and executes it with dynamic input/output schema.
The flow:
1. User provides an MCP server URL (and optional credentials)
2. Frontend calls the backend to get tool list from that URL
3. User selects a tool from a dropdown (available_tools)
4. The block's input schema updates to reflect the selected tool's parameters
5. On execution, the block calls the MCP server to run the tool
"""
class Input(BlockSchemaInput):
server_url: str = SchemaField(
description="URL of the MCP server (Streamable HTTP endpoint)",
placeholder="https://mcp.example.com/mcp",
)
credentials: MCPCredentials = CredentialsField(
discriminator="server_url",
description="MCP server OAuth credentials",
default={},
)
selected_tool: str = SchemaField(
description="The MCP tool to execute",
placeholder="Select a tool",
default="",
)
tool_input_schema: dict[str, Any] = SchemaField(
description="JSON Schema for the selected tool's input parameters. "
"Populated automatically when a tool is selected.",
default={},
hidden=True,
)
tool_arguments: dict[str, Any] = SchemaField(
description="Arguments to pass to the selected MCP tool. "
"The fields here are defined by the tool's input schema.",
default={},
)
@classmethod
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
"""Return the tool's input schema so the builder UI renders dynamic fields."""
return data.get("tool_input_schema", {})
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
"""Return the current tool_arguments as defaults for the dynamic fields."""
return data.get("tool_arguments", {})
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
"""Check which required tool arguments are missing."""
required_fields = cls.get_input_schema(data).get("required", [])
tool_arguments = data.get("tool_arguments", {})
return set(required_fields) - set(tool_arguments)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
"""Validate tool_arguments against the tool's input schema."""
tool_schema = cls.get_input_schema(data)
if not tool_schema:
return None
tool_arguments = data.get("tool_arguments", {})
return validate_with_jsonschema(tool_schema, tool_arguments)
class Output(BlockSchemaOutput):
result: Any = SchemaField(description="The result returned by the MCP tool")
error: str = SchemaField(description="Error message if the tool call failed")
def __init__(self):
super().__init__(
id="a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4",
description="Connect to any MCP server and execute its tools. "
"Provide a server URL, select a tool, and pass arguments dynamically.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=MCPToolBlock.Input,
output_schema=MCPToolBlock.Output,
block_type=BlockType.MCP_TOOL,
test_credentials=TEST_CREDENTIALS,
test_input={
"server_url": "https://mcp.example.com/mcp",
"credentials": TEST_CREDENTIALS_INPUT,
"selected_tool": "get_weather",
"tool_input_schema": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
"tool_arguments": {"city": "London"},
},
test_output=[
(
"result",
{"weather": "sunny", "temperature": 20},
),
],
test_mock={
"_call_mcp_tool": lambda *a, **kw: {
"weather": "sunny",
"temperature": 20,
},
},
)
async def _call_mcp_tool(
self,
server_url: str,
tool_name: str,
arguments: dict[str, Any],
auth_token: str | None = None,
) -> Any:
"""Call a tool on the MCP server. Extracted for easy mocking in tests."""
client = MCPClient(server_url, auth_token=auth_token)
await client.initialize()
result = await client.call_tool(tool_name, arguments)
if result.is_error:
error_text = ""
for item in result.content:
if item.get("type") == "text":
error_text += item.get("text", "")
raise MCPClientError(
f"MCP tool '{tool_name}' returned an error: "
f"{error_text or 'Unknown error'}"
)
# Extract text content from the result
output_parts = []
for item in result.content:
if item.get("type") == "text":
text = item.get("text", "")
# Try to parse as JSON for structured output
try:
output_parts.append(json.loads(text))
except (json.JSONDecodeError, ValueError):
output_parts.append(text)
elif item.get("type") == "image":
output_parts.append(
{
"type": "image",
"data": item.get("data"),
"mimeType": item.get("mimeType"),
}
)
elif item.get("type") == "resource":
output_parts.append(item.get("resource", {}))
# If single result, unwrap
if len(output_parts) == 1:
return output_parts[0]
return output_parts if output_parts else None
@staticmethod
async def _auto_lookup_credential(
user_id: str, server_url: str
) -> "OAuth2Credentials | None":
"""Auto-lookup stored MCP credential for a server URL.
This is a fallback for nodes that don't have ``credentials`` explicitly
set (e.g. nodes created before the credential field was wired up).
"""
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
try:
mgr = IntegrationCredentialsManager()
mcp_creds = await mgr.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
best: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and (cred.metadata or {}).get("mcp_server_url") == server_url
):
if best is None or (
(cred.access_token_expires_at or 0)
> (best.access_token_expires_at or 0)
):
best = cred
if best:
best = await mgr.refresh_if_needed(user_id, best)
logger.info(
"Auto-resolved MCP credential %s for %s", best.id, server_url
)
return best
except Exception:
logger.warning("Auto-lookup MCP credential failed", exc_info=True)
return None
async def run(
self,
input_data: Input,
*,
user_id: str,
credentials: OAuth2Credentials | None = None,
**kwargs,
) -> BlockOutput:
if not input_data.server_url:
yield "error", "MCP server URL is required"
return
if not input_data.selected_tool:
yield "error", "No tool selected. Please select a tool from the dropdown."
return
# Validate required tool arguments before calling the server.
# The executor-level validation is bypassed for MCP blocks because
# get_input_defaults() flattens tool_arguments, stripping tool_input_schema
# from the validation context.
required = set(input_data.tool_input_schema.get("required", []))
if required:
missing = required - set(input_data.tool_arguments.keys())
if missing:
yield "error", (
f"Missing required argument(s): {', '.join(sorted(missing))}. "
f"Please fill in all required fields marked with * in the block form."
)
return
# If no credentials were injected by the executor (e.g. legacy nodes
# that don't have the credentials field set), try to auto-lookup
# the stored MCP credential for this server URL.
if credentials is None:
credentials = await self._auto_lookup_credential(
user_id, input_data.server_url
)
auth_token = (
credentials.access_token.get_secret_value() if credentials else None
)
try:
result = await self._call_mcp_tool(
server_url=input_data.server_url,
tool_name=input_data.selected_tool,
arguments=input_data.tool_arguments,
auth_token=auth_token,
)
yield "result", result
except MCPClientError as e:
yield "error", str(e)
except Exception as e:
logger.exception("MCP tool call failed: %s", e)
yield "error", f"MCP tool call failed: {e}"
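The result-unwrapping rules in `_call_mcp_tool` and the required-argument check in `run` can be tried without a server. The following is a minimal standalone sketch, not the block's own code: `flatten_mcp_content` and `missing_required` are hypothetical helper names introduced here for illustration, and the content items are assumed to be plain dicts as in `MCPCallResult.content`.

```python
import json
from typing import Any


def flatten_mcp_content(content: list[dict[str, Any]]) -> Any:
    """Standalone copy of the unwrapping rules (hypothetical helper)."""
    parts: list[Any] = []
    for item in content:
        kind = item.get("type")
        if kind == "text":
            text = item.get("text", "")
            try:
                # Every text item is first tried as JSON.
                parts.append(json.loads(text))
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                parts.append(text)
        elif kind == "image":
            parts.append(
                {
                    "type": "image",
                    "data": item.get("data"),
                    "mimeType": item.get("mimeType"),
                }
            )
        elif kind == "resource":
            parts.append(item.get("resource", {}))
    # A single part is returned bare; no parts at all yields None.
    if len(parts) == 1:
        return parts[0]
    return parts or None


def missing_required(schema: dict[str, Any], arguments: dict[str, Any]) -> list[str]:
    """Names listed under the schema's "required" key that are absent (hypothetical helper)."""
    return sorted(set(schema.get("required", [])) - set(arguments))
```

One consequence worth noting: because each text item is tried as JSON first, a tool that returns the text `"2"` produces the integer `2`, not a string.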

View File

@@ -0,0 +1,323 @@
"""
MCP (Model Context Protocol) HTTP client.
Implements the MCP Streamable HTTP transport for listing tools and calling tools
on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST.
Handles both JSON and SSE (text/event-stream) response formats per the MCP spec.
Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
"""
import json
import logging
from dataclasses import dataclass, field
from typing import Any
from backend.util.request import Requests
logger = logging.getLogger(__name__)
@dataclass
class MCPTool:
"""Represents an MCP tool discovered from a server."""
name: str
description: str
input_schema: dict[str, Any]
@dataclass
class MCPCallResult:
"""Result from calling an MCP tool."""
content: list[dict[str, Any]] = field(default_factory=list)
is_error: bool = False
class MCPClientError(Exception):
"""Raised when an MCP protocol error occurs."""
pass
class MCPClient:
"""
Async HTTP client for the MCP Streamable HTTP transport.
Communicates with MCP servers using JSON-RPC 2.0 over HTTP POST.
Supports optional Bearer token authentication.
"""
def __init__(
self,
server_url: str,
auth_token: str | None = None,
):
self.server_url = server_url.rstrip("/")
self.auth_token = auth_token
self._request_id = 0
self._session_id: str | None = None
def _next_id(self) -> int:
self._request_id += 1
return self._request_id
def _build_headers(self) -> dict[str, str]:
headers = {
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
}
if self.auth_token:
headers["Authorization"] = f"Bearer {self.auth_token}"
if self._session_id:
headers["Mcp-Session-Id"] = self._session_id
return headers
def _build_jsonrpc_request(
self, method: str, params: dict[str, Any] | None = None
) -> dict[str, Any]:
req: dict[str, Any] = {
"jsonrpc": "2.0",
"method": method,
"id": self._next_id(),
}
if params is not None:
req["params"] = params
return req
@staticmethod
def _parse_sse_response(text: str) -> dict[str, Any]:
"""Parse an SSE (text/event-stream) response body into JSON-RPC data.
MCP servers may return responses as SSE with format:
event: message
data: {"jsonrpc":"2.0","result":{...},"id":1}
We extract the last `data:` line that contains a JSON-RPC response
(i.e. has an "id" field), which is the reply to our request.
"""
last_data: dict[str, Any] | None = None
for line in text.splitlines():
stripped = line.strip()
if stripped.startswith("data:"):
payload = stripped[len("data:") :].strip()
if not payload:
continue
try:
parsed = json.loads(payload)
# Only keep JSON-RPC responses (have "id"), skip notifications
if isinstance(parsed, dict) and "id" in parsed:
last_data = parsed
except (json.JSONDecodeError, ValueError):
continue
if last_data is None:
raise MCPClientError("No JSON-RPC response found in SSE stream")
return last_data
async def _send_request(
self, method: str, params: dict[str, Any] | None = None
) -> Any:
"""Send a JSON-RPC request to the MCP server and return the result.
Handles both ``application/json`` and ``text/event-stream`` responses
as required by the MCP Streamable HTTP transport specification.
"""
payload = self._build_jsonrpc_request(method, params)
headers = self._build_headers()
requests = Requests(
raise_for_status=True,
extra_headers=headers,
)
response = await requests.post(self.server_url, json=payload)
# Capture session ID from response (MCP Streamable HTTP transport)
session_id = response.headers.get("Mcp-Session-Id")
if session_id:
self._session_id = session_id
content_type = response.headers.get("content-type", "")
if "text/event-stream" in content_type:
body = self._parse_sse_response(response.text())
else:
try:
body = response.json()
except Exception as e:
raise MCPClientError(
f"MCP server returned non-JSON response: {e}"
) from e
if not isinstance(body, dict):
raise MCPClientError(
f"MCP server returned unexpected JSON type: {type(body).__name__}"
)
# Handle JSON-RPC error
if "error" in body:
error = body["error"]
if isinstance(error, dict):
raise MCPClientError(
f"MCP server error [{error.get('code', '?')}]: "
f"{error.get('message', 'Unknown error')}"
)
raise MCPClientError(f"MCP server error: {error}")
return body.get("result")
async def _send_notification(self, method: str) -> None:
"""Send a JSON-RPC notification (no id, no response expected)."""
headers = self._build_headers()
notification = {"jsonrpc": "2.0", "method": method}
requests = Requests(
raise_for_status=False,
extra_headers=headers,
)
await requests.post(self.server_url, json=notification)
async def discover_auth(self) -> dict[str, Any] | None:
"""Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec).
Returns ``None`` if the server doesn't require auth, otherwise returns
a dict with:
- ``authorization_servers``: list of authorization server URLs
- ``resource``: the resource indicator URL (usually the MCP endpoint)
- ``scopes_supported``: optional list of supported scopes
The caller can then fetch the authorization server metadata to get
``authorization_endpoint``, ``token_endpoint``, etc.
"""
from urllib.parse import urlparse
parsed = urlparse(self.server_url)
base = f"{parsed.scheme}://{parsed.netloc}"
# Build candidates for protected-resource metadata (per RFC 9728)
path = parsed.path.rstrip("/")
candidates = []
if path and path != "/":
candidates.append(f"{base}/.well-known/oauth-protected-resource{path}")
candidates.append(f"{base}/.well-known/oauth-protected-resource")
requests = Requests(
raise_for_status=False,
)
for url in candidates:
try:
resp = await requests.get(url)
if resp.status == 200:
data = resp.json()
if isinstance(data, dict) and "authorization_servers" in data:
return data
except Exception:
continue
return None
async def discover_auth_server_metadata(
self, auth_server_url: str
) -> dict[str, Any] | None:
"""Fetch the OAuth Authorization Server Metadata (RFC 8414).
Given an authorization server URL, returns a dict with:
- ``authorization_endpoint``
- ``token_endpoint``
- ``registration_endpoint`` (for dynamic client registration)
- ``scopes_supported``
- ``code_challenge_methods_supported``
- etc.
"""
from urllib.parse import urlparse
parsed = urlparse(auth_server_url)
base = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/")
# Try standard metadata endpoints (RFC 8414 and OpenID Connect)
candidates = []
if path and path != "/":
candidates.append(f"{base}/.well-known/oauth-authorization-server{path}")
candidates.append(f"{base}/.well-known/oauth-authorization-server")
candidates.append(f"{base}/.well-known/openid-configuration")
requests = Requests(
raise_for_status=False,
)
for url in candidates:
try:
resp = await requests.get(url)
if resp.status == 200:
data = resp.json()
if isinstance(data, dict) and "authorization_endpoint" in data:
return data
except Exception:
continue
return None
async def initialize(self) -> dict[str, Any]:
"""
Send the MCP initialize request.
This is required by the MCP protocol before any other requests.
Returns the server's capabilities.
"""
result = await self._send_request(
"initialize",
{
"protocolVersion": "2025-03-26",
"capabilities": {},
"clientInfo": {"name": "AutoGPT-Platform", "version": "1.0.0"},
},
)
# Send initialized notification (no response expected)
await self._send_notification("notifications/initialized")
return result or {}
async def list_tools(self) -> list[MCPTool]:
"""
Discover available tools from the MCP server.
Returns a list of MCPTool objects with name, description, and input schema.
"""
result = await self._send_request("tools/list")
if not result or "tools" not in result:
return []
tools = []
for tool_data in result["tools"]:
tools.append(
MCPTool(
name=tool_data.get("name", ""),
description=tool_data.get("description", ""),
input_schema=tool_data.get("inputSchema", {}),
)
)
return tools
async def call_tool(
self, tool_name: str, arguments: dict[str, Any]
) -> MCPCallResult:
"""
Call a tool on the MCP server.
Args:
tool_name: The name of the tool to call.
arguments: The arguments to pass to the tool.
Returns:
MCPCallResult with the tool's response content.
"""
result = await self._send_request(
"tools/call",
{"name": tool_name, "arguments": arguments},
)
if not result:
return MCPCallResult(is_error=True)
return MCPCallResult(
content=result.get("content", []),
is_error=result.get("isError", False),
)
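Two pieces of the client above are pure string handling and can be sketched in isolation: picking the JSON-RPC reply out of an SSE body, and building the `.well-known` metadata URLs that `discover_auth` probes. This is an illustrative sketch under stated assumptions, not the client's code: `parse_sse_jsonrpc` and `well_known_candidates` are hypothetical names, and the sketch raises `ValueError` where the client raises `MCPClientError`.

```python
import json
from typing import Any, Optional
from urllib.parse import urlparse


def parse_sse_jsonrpc(text: str) -> dict[str, Any]:
    """Return the last `data:` payload that carries an "id" (hypothetical helper)."""
    last: Optional[dict[str, Any]] = None
    for line in text.splitlines():
        stripped = line.strip()
        if not stripped.startswith("data:"):
            continue
        payload = stripped[len("data:"):].strip()
        if not payload:
            continue
        try:
            parsed = json.loads(payload)
        except ValueError:
            continue  # skip data lines that are not valid JSON
        # Notifications have no "id"; only responses are kept.
        if isinstance(parsed, dict) and "id" in parsed:
            last = parsed
    if last is None:
        raise ValueError("No JSON-RPC response found in SSE stream")
    return last


def well_known_candidates(url: str, suffix: str) -> list[str]:
    """Path-specific metadata URL first, then the host-level one (hypothetical helper)."""
    parsed = urlparse(url)
    base = f"{parsed.scheme}://{parsed.netloc}"
    path = parsed.path.rstrip("/")
    candidates = []
    if path:
        candidates.append(f"{base}/.well-known/{suffix}{path}")
    candidates.append(f"{base}/.well-known/{suffix}")
    return candidates
```

For example, `well_known_candidates("https://mcp.example.com/mcp", "oauth-protected-resource")` yields the path-specific URL followed by the host-level one, which is the order in which `discover_auth` tries them.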

View File

@@ -0,0 +1,204 @@
"""
MCP OAuth handler for MCP servers that use OAuth 2.1 authorization.
Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed,
MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata.
This handler accepts those endpoints at construction time.
"""
import logging
import time
import urllib.parse
from typing import ClassVar, Optional
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.providers import ProviderName
from backend.util.request import Requests
logger = logging.getLogger(__name__)
class MCPOAuthHandler(BaseOAuthHandler):
"""
OAuth handler for MCP servers with dynamically-discovered endpoints.
Construction requires the authorization and token endpoint URLs,
which are obtained via MCP OAuth metadata discovery
(``MCPClient.discover_auth`` + ``discover_auth_server_metadata``).
"""
PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP
DEFAULT_SCOPES: ClassVar[list[str]] = []
def __init__(
self,
client_id: str,
client_secret: str,
redirect_uri: str,
*,
authorize_url: str,
token_url: str,
revoke_url: str | None = None,
resource_url: str | None = None,
):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
self.authorize_url = authorize_url
self.token_url = token_url
self.revoke_url = revoke_url
self.resource_url = resource_url
def get_login_url(
self,
scopes: list[str],
state: str,
code_challenge: Optional[str],
) -> str:
scopes = self.handle_default_scopes(scopes)
params: dict[str, str] = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"state": state,
}
if scopes:
params["scope"] = " ".join(scopes)
# PKCE (S256) — included when the caller provides a code_challenge
if code_challenge:
params["code_challenge"] = code_challenge
params["code_challenge_method"] = "S256"
# MCP spec requires resource indicator (RFC 8707)
if self.resource_url:
params["resource"] = self.resource_url
return f"{self.authorize_url}?{urllib.parse.urlencode(params)}"
async def exchange_code_for_tokens(
self,
code: str,
scopes: list[str],
code_verifier: Optional[str],
) -> OAuth2Credentials:
data: dict[str, str] = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": self.redirect_uri,
"client_id": self.client_id,
}
if self.client_secret:
data["client_secret"] = self.client_secret
if code_verifier:
data["code_verifier"] = code_verifier
if self.resource_url:
data["resource"] = self.resource_url
response = await Requests(raise_for_status=True).post(
self.token_url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
tokens = response.json()
if "error" in tokens:
raise RuntimeError(
f"Token exchange failed: {tokens.get('error_description', tokens['error'])}"
)
if "access_token" not in tokens:
raise RuntimeError("OAuth token response missing 'access_token' field")
now = int(time.time())
expires_in = tokens.get("expires_in")
return OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=None,
access_token=SecretStr(tokens["access_token"]),
refresh_token=(
SecretStr(tokens["refresh_token"])
if tokens.get("refresh_token")
else None
),
access_token_expires_at=now + expires_in if expires_in else None,
refresh_token_expires_at=None,
scopes=scopes,
metadata={
"mcp_token_url": self.token_url,
"mcp_resource_url": self.resource_url,
},
)
async def _refresh_tokens(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
if not credentials.refresh_token:
raise ValueError("No refresh token available for MCP OAuth credentials")
data: dict[str, str] = {
"grant_type": "refresh_token",
"refresh_token": credentials.refresh_token.get_secret_value(),
"client_id": self.client_id,
}
if self.client_secret:
data["client_secret"] = self.client_secret
if self.resource_url:
data["resource"] = self.resource_url
response = await Requests(raise_for_status=True).post(
self.token_url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
tokens = response.json()
if "error" in tokens:
raise RuntimeError(
f"Token refresh failed: {tokens.get('error_description', tokens['error'])}"
)
if "access_token" not in tokens:
raise RuntimeError("OAuth refresh response missing 'access_token' field")
now = int(time.time())
expires_in = tokens.get("expires_in")
return OAuth2Credentials(
id=credentials.id,
provider=self.PROVIDER_NAME,
title=credentials.title,
access_token=SecretStr(tokens["access_token"]),
refresh_token=(
SecretStr(tokens["refresh_token"])
if tokens.get("refresh_token")
else credentials.refresh_token
),
access_token_expires_at=now + expires_in if expires_in else None,
refresh_token_expires_at=credentials.refresh_token_expires_at,
scopes=credentials.scopes,
metadata=credentials.metadata,
)
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
if not self.revoke_url:
return False
try:
data = {
"token": credentials.access_token.get_secret_value(),
"token_type_hint": "access_token",
"client_id": self.client_id,
}
await Requests().post(
self.revoke_url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
return True
except Exception:
logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True)
return False

View File

@@ -0,0 +1,109 @@
"""
End-to-end tests against a real public MCP server.
These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp)
which is publicly accessible without authentication and returns SSE responses.
Note: these tests are skipped unless the ``RUN_E2E`` environment variable is set,
so they can be run independently of the rest of the test suite (they require network access).
"""
import json
import os
import pytest
from backend.blocks.mcp.client import MCPClient
# Public MCP server that requires no authentication
OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp"
# Skip all tests in this module unless RUN_E2E env var is set
pytestmark = pytest.mark.skipif(
not os.environ.get("RUN_E2E"), reason="set RUN_E2E=1 to run e2e tests"
)
class TestRealMCPServer:
"""Tests against the live OpenAI docs MCP server."""
@pytest.mark.asyncio(loop_scope="session")
async def test_initialize(self):
"""Verify we can complete the MCP handshake with a real server."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
result = await client.initialize()
assert result["protocolVersion"] == "2025-03-26"
assert "serverInfo" in result
assert result["serverInfo"]["name"] == "openai-docs-mcp"
assert "tools" in result.get("capabilities", {})
@pytest.mark.asyncio(loop_scope="session")
async def test_list_tools(self):
"""Verify we can discover tools from a real MCP server."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
await client.initialize()
tools = await client.list_tools()
assert len(tools) >= 3 # lower bound for the three tools checked below; the server had at least 5 as of writing
tool_names = {t.name for t in tools}
# These tools are documented and should be stable
assert "search_openai_docs" in tool_names
assert "list_openai_docs" in tool_names
assert "fetch_openai_doc" in tool_names
# Verify schema structure
search_tool = next(t for t in tools if t.name == "search_openai_docs")
assert "query" in search_tool.input_schema.get("properties", {})
assert "query" in search_tool.input_schema.get("required", [])
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_list_api_endpoints(self):
"""Call the list_api_endpoints tool and verify we get real data."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
await client.initialize()
result = await client.call_tool("list_api_endpoints", {})
assert not result.is_error
assert len(result.content) >= 1
assert result.content[0]["type"] == "text"
data = json.loads(result.content[0]["text"])
assert "paths" in data or "urls" in data
# The OpenAI API should have many endpoints
total = data.get("total", len(data.get("paths", [])))
assert total > 50
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_search(self):
"""Search for docs and verify we get results."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
await client.initialize()
result = await client.call_tool(
"search_openai_docs", {"query": "chat completions", "limit": 3}
)
assert not result.is_error
assert len(result.content) >= 1
@pytest.mark.asyncio(loop_scope="session")
async def test_sse_response_handling(self):
"""Verify the client correctly handles SSE responses from a real server.
This is the key test — our local test server returns JSON,
but real MCP servers typically return SSE. This proves the
SSE parsing works end-to-end.
"""
client = MCPClient(OPENAI_DOCS_MCP_URL)
# initialize() internally calls _send_request which must parse SSE
result = await client.initialize()
# If we got here without error, SSE parsing works
assert isinstance(result, dict)
assert "protocolVersion" in result
# Also verify list_tools works (another SSE response)
tools = await client.list_tools()
assert len(tools) > 0
assert all(hasattr(t, "name") for t in tools)
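The module-level `skipif` above gates on `os.environ.get("RUN_E2E")`, which is truthy for any non-empty string, so `RUN_E2E=0` also enables the tests. If stricter parsing were wanted, one option is a small flag helper like the sketch below. This is an assumption-laden alternative, not what the module does: `env_flag` and the accepted spellings are hypothetical.

```python
import os
from typing import Mapping, Optional

# Spellings treated as "enabled"; this set is an assumption of the sketch.
_TRUE_VALUES = {"1", "true", "yes", "on"}


def env_flag(name: str, environ: Optional[Mapping[str, str]] = None) -> bool:
    """Return True only when the variable holds an explicit "on" value."""
    env = os.environ if environ is None else environ
    return env.get(name, "").strip().lower() in _TRUE_VALUES
```

With such a helper the gate could read `pytest.mark.skipif(not env_flag("RUN_E2E"), reason=...)`, keeping the same reason string.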

View File

@@ -0,0 +1,389 @@
"""
Integration tests for MCP client and MCPToolBlock against a real HTTP server.
These tests spin up a local MCP test server and run the full client/block flow
against it — no mocking, real HTTP requests.
"""

import asyncio
import json
import threading
from unittest.mock import patch

import pytest
from aiohttp import web
from pydantic import SecretStr

from backend.blocks.mcp.block import MCPToolBlock
from backend.blocks.mcp.client import MCPClient
from backend.blocks.mcp.test_server import create_test_mcp_app
from backend.data.model import OAuth2Credentials

MOCK_USER_ID = "test-user-integration"


class _MCPTestServer:
    """
    Run an MCP test server in a background thread with its own event loop.
    This avoids event loop conflicts with pytest-asyncio.
    """

    def __init__(self, auth_token: str | None = None):
        self.auth_token = auth_token
        self.url: str = ""
        self._runner: web.AppRunner | None = None
        self._loop: asyncio.AbstractEventLoop | None = None
        self._thread: threading.Thread | None = None
        self._started = threading.Event()

    def _run(self):
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)
        self._loop.run_until_complete(self._start())
        self._started.set()
        self._loop.run_forever()

    async def _start(self):
        app = create_test_mcp_app(auth_token=self.auth_token)
        self._runner = web.AppRunner(app)
        await self._runner.setup()
        site = web.TCPSite(self._runner, "127.0.0.1", 0)
        await site.start()
        port = site._server.sockets[0].getsockname()[1]  # type: ignore[union-attr]
        self.url = f"http://127.0.0.1:{port}/mcp"

    def start(self):
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        if not self._started.wait(timeout=5):
            raise RuntimeError("MCP test server failed to start within 5 seconds")
        return self

    def stop(self):
        if self._loop and self._runner:
            asyncio.run_coroutine_threadsafe(self._runner.cleanup(), self._loop).result(
                timeout=5
            )
            self._loop.call_soon_threadsafe(self._loop.stop)
        if self._thread:
            self._thread.join(timeout=5)


@pytest.fixture(scope="module")
def mcp_server():
    """Start a local MCP test server in a background thread."""
    server = _MCPTestServer()
    server.start()
    yield server.url
    server.stop()


@pytest.fixture(scope="module")
def mcp_server_with_auth():
    """Start a local MCP test server with auth in a background thread."""
    server = _MCPTestServer(auth_token="test-secret-token")
    server.start()
    yield server.url, "test-secret-token"
    server.stop()


@pytest.fixture(autouse=True)
def _allow_localhost():
    """
    Allow 127.0.0.1 through SSRF protection for integration tests.
    The Requests class blocks private IPs by default. We patch the Requests
    constructor to always include 127.0.0.1 as a trusted origin so the local
    test server is reachable.
    """
    from backend.util.request import Requests

    original_init = Requests.__init__

    def patched_init(self, *args, **kwargs):
        trusted = list(kwargs.get("trusted_origins") or [])
        trusted.append("http://127.0.0.1")
        kwargs["trusted_origins"] = trusted
        original_init(self, *args, **kwargs)

    with patch.object(Requests, "__init__", patched_init):
        yield


def _make_client(url: str, auth_token: str | None = None) -> MCPClient:
    """Create an MCPClient for integration tests."""
    return MCPClient(url, auth_token=auth_token)


# ── MCPClient integration tests ──────────────────────────────────────


class TestMCPClientIntegration:
    """Test MCPClient against a real local MCP server."""

    @pytest.mark.asyncio(loop_scope="session")
    async def test_initialize(self, mcp_server):
        client = _make_client(mcp_server)
        result = await client.initialize()
        assert result["protocolVersion"] == "2025-03-26"
        assert result["serverInfo"]["name"] == "test-mcp-server"
        assert "tools" in result["capabilities"]

    @pytest.mark.asyncio(loop_scope="session")
    async def test_list_tools(self, mcp_server):
        client = _make_client(mcp_server)
        await client.initialize()
        tools = await client.list_tools()
        assert len(tools) == 3
        tool_names = {t.name for t in tools}
        assert tool_names == {"get_weather", "add_numbers", "echo"}
        # Check get_weather schema
        weather = next(t for t in tools if t.name == "get_weather")
        assert weather.description == "Get current weather for a city"
        assert "city" in weather.input_schema["properties"]
        assert weather.input_schema["required"] == ["city"]
        # Check add_numbers schema
        add = next(t for t in tools if t.name == "add_numbers")
        assert "a" in add.input_schema["properties"]
        assert "b" in add.input_schema["properties"]

    @pytest.mark.asyncio(loop_scope="session")
    async def test_call_tool_get_weather(self, mcp_server):
        client = _make_client(mcp_server)
        await client.initialize()
        result = await client.call_tool("get_weather", {"city": "London"})
        assert not result.is_error
        assert len(result.content) == 1
        assert result.content[0]["type"] == "text"
        data = json.loads(result.content[0]["text"])
        assert data["city"] == "London"
        assert data["temperature"] == 22
        assert data["condition"] == "sunny"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_call_tool_add_numbers(self, mcp_server):
        client = _make_client(mcp_server)
        await client.initialize()
        result = await client.call_tool("add_numbers", {"a": 3, "b": 7})
        assert not result.is_error
        data = json.loads(result.content[0]["text"])
        assert data["result"] == 10

    @pytest.mark.asyncio(loop_scope="session")
    async def test_call_tool_echo(self, mcp_server):
        client = _make_client(mcp_server)
        await client.initialize()
        result = await client.call_tool("echo", {"message": "Hello MCP!"})
        assert not result.is_error
        assert result.content[0]["text"] == "Hello MCP!"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_call_unknown_tool(self, mcp_server):
        client = _make_client(mcp_server)
        await client.initialize()
        result = await client.call_tool("nonexistent_tool", {})
        assert result.is_error
        assert "Unknown tool" in result.content[0]["text"]

    @pytest.mark.asyncio(loop_scope="session")
    async def test_auth_success(self, mcp_server_with_auth):
        url, token = mcp_server_with_auth
        client = _make_client(url, auth_token=token)
        result = await client.initialize()
        assert result["protocolVersion"] == "2025-03-26"
        tools = await client.list_tools()
        assert len(tools) == 3

    @pytest.mark.asyncio(loop_scope="session")
    async def test_auth_failure(self, mcp_server_with_auth):
        url, _ = mcp_server_with_auth
        client = _make_client(url, auth_token="wrong-token")
        with pytest.raises(Exception):
            await client.initialize()

    @pytest.mark.asyncio(loop_scope="session")
    async def test_auth_missing(self, mcp_server_with_auth):
        url, _ = mcp_server_with_auth
        client = _make_client(url)
        with pytest.raises(Exception):
            await client.initialize()


# ── MCPToolBlock integration tests ───────────────────────────────────


class TestMCPToolBlockIntegration:
    """Test MCPToolBlock end-to-end against a real local MCP server."""

    @pytest.mark.asyncio(loop_scope="session")
    async def test_full_flow_get_weather(self, mcp_server):
        """Full flow: discover tools, select one, execute it."""
        # Step 1: Discover tools (simulating what the frontend/API would do)
        client = _make_client(mcp_server)
        await client.initialize()
        tools = await client.list_tools()
        assert len(tools) == 3
        # Step 2: User selects "get_weather" and we get its schema
        weather_tool = next(t for t in tools if t.name == "get_weather")
        # Step 3: Execute the block — no credentials (public server)
        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url=mcp_server,
            selected_tool="get_weather",
            tool_input_schema=weather_tool.input_schema,
            tool_arguments={"city": "Paris"},
        )
        outputs = []
        async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
            outputs.append((name, data))
        assert len(outputs) == 1
        assert outputs[0][0] == "result"
        result = outputs[0][1]
        assert result["city"] == "Paris"
        assert result["temperature"] == 22
        assert result["condition"] == "sunny"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_full_flow_add_numbers(self, mcp_server):
        """Full flow for add_numbers tool."""
        client = _make_client(mcp_server)
        await client.initialize()
        tools = await client.list_tools()
        add_tool = next(t for t in tools if t.name == "add_numbers")
        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url=mcp_server,
            selected_tool="add_numbers",
            tool_input_schema=add_tool.input_schema,
            tool_arguments={"a": 42, "b": 58},
        )
        outputs = []
        async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
            outputs.append((name, data))
        assert len(outputs) == 1
        assert outputs[0][0] == "result"
        assert outputs[0][1]["result"] == 100

    @pytest.mark.asyncio(loop_scope="session")
    async def test_full_flow_echo_plain_text(self, mcp_server):
        """Verify plain text (non-JSON) responses work."""
        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url=mcp_server,
            selected_tool="echo",
            tool_input_schema={
                "type": "object",
                "properties": {"message": {"type": "string"}},
                "required": ["message"],
            },
            tool_arguments={"message": "Hello from AutoGPT!"},
        )
        outputs = []
        async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
            outputs.append((name, data))
        assert len(outputs) == 1
        assert outputs[0][0] == "result"
        assert outputs[0][1] == "Hello from AutoGPT!"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_full_flow_unknown_tool_yields_error(self, mcp_server):
        """Calling an unknown tool should yield an error output."""
        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url=mcp_server,
            selected_tool="nonexistent_tool",
            tool_arguments={},
        )
        outputs = []
        async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
            outputs.append((name, data))
        assert len(outputs) == 1
        assert outputs[0][0] == "error"
        assert "returned an error" in outputs[0][1]

    @pytest.mark.asyncio(loop_scope="session")
    async def test_full_flow_with_auth(self, mcp_server_with_auth):
        """Full flow with authentication via credentials kwarg."""
        url, token = mcp_server_with_auth
        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url=url,
            selected_tool="echo",
            tool_input_schema={
                "type": "object",
                "properties": {"message": {"type": "string"}},
                "required": ["message"],
            },
            tool_arguments={"message": "Authenticated!"},
        )
        # Pass credentials via the standard kwarg (as the executor would)
        test_creds = OAuth2Credentials(
            id="test-cred",
            provider="mcp",
            access_token=SecretStr(token),
            refresh_token=SecretStr(""),
            scopes=[],
            title="Test MCP credential",
        )
        outputs = []
        async for name, data in block.run(
            input_data, user_id=MOCK_USER_ID, credentials=test_creds
        ):
            outputs.append((name, data))
        assert len(outputs) == 1
        assert outputs[0][0] == "result"
        assert outputs[0][1] == "Authenticated!"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_no_credentials_runs_without_auth(self, mcp_server):
        """Block runs without auth when no credentials are provided."""
        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url=mcp_server,
            selected_tool="echo",
            tool_input_schema={
                "type": "object",
                "properties": {"message": {"type": "string"}},
                "required": ["message"],
            },
            tool_arguments={"message": "No auth needed"},
        )
        outputs = []
        async for name, data in block.run(
            input_data, user_id=MOCK_USER_ID, credentials=None
        ):
            outputs.append((name, data))
        assert len(outputs) == 1
        assert outputs[0][0] == "result"
        assert outputs[0][1] == "No auth needed"
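
The integration tests above rely on a server that runs in a background thread with its own event loop and binds to port 0 so the OS picks a free port. A minimal sketch of that pattern, using only the standard library, might look like the following. `BackgroundServer` and `roundtrip` are hypothetical names for illustration, and a plain TCP echo handler stands in for the aiohttp MCP app; this is not the PR's `_MCPTestServer`.

```python
import asyncio
import socket
import threading


class BackgroundServer:
    """Hypothetical stand-in: a TCP echo server on its own loop and thread."""

    def __init__(self):
        self.port = 0
        self._loop = None
        self._server = None
        self._started = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    async def _echo(self, reader, writer):
        # Echo one line back to the caller, then close the connection.
        writer.write(await reader.readline())
        await writer.drain()
        writer.close()

    async def _start(self):
        # Port 0 asks the OS for any free port; read the real one back.
        self._server = await asyncio.start_server(self._echo, "127.0.0.1", 0)
        self.port = self._server.sockets[0].getsockname()[1]

    def _run(self):
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)
        self._loop.run_until_complete(self._start())
        self._started.set()  # signal the caller that self.port is valid
        self._loop.run_forever()

    def start(self):
        self._thread.start()
        if not self._started.wait(timeout=5):
            raise RuntimeError("server failed to start within 5 seconds")
        return self

    def stop(self):
        # Callbacks run in FIFO order on the server's loop: close, then stop.
        self._loop.call_soon_threadsafe(self._server.close)
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join(timeout=5)


def roundtrip(port, payload):
    """Send one line to the echo server and return the reply."""
    with socket.create_connection(("127.0.0.1", port), timeout=5) as conn:
        conn.sendall(payload)
        return conn.recv(1024)
```

The `threading.Event` is what makes `start()` safe to call from a synchronous pytest fixture: the fixture blocks until the port has been assigned, so the URL it yields is always usable.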


@@ -0,0 +1,619 @@
"""
Tests for MCP client and MCPToolBlock.
"""

import json
from unittest.mock import AsyncMock, patch

import pytest

from backend.blocks.mcp.block import MCPToolBlock
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError
from backend.util.test import execute_block_test

# ── SSE parsing unit tests ───────────────────────────────────────────


class TestSSEParsing:
    """Tests for SSE (text/event-stream) response parsing."""

    def test_parse_sse_simple(self):
        sse = (
            "event: message\n"
            'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n'
            "\n"
        )
        body = MCPClient._parse_sse_response(sse)
        assert body["result"] == {"tools": []}
        assert body["id"] == 1

    def test_parse_sse_with_notifications(self):
        """SSE streams can contain notifications (no id) before the response."""
        sse = (
            "event: message\n"
            'data: {"jsonrpc":"2.0","method":"some/notification"}\n'
            "\n"
            "event: message\n"
            'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n'
            "\n"
        )
        body = MCPClient._parse_sse_response(sse)
        assert body["result"] == {"ok": True}
        assert body["id"] == 2

    def test_parse_sse_error_response(self):
        sse = (
            "event: message\n"
            'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n'
        )
        body = MCPClient._parse_sse_response(sse)
        assert "error" in body
        assert body["error"]["code"] == -32600

    def test_parse_sse_no_data_raises(self):
        with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
            MCPClient._parse_sse_response("event: message\n\n")

    def test_parse_sse_empty_raises(self):
        with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
            MCPClient._parse_sse_response("")

    def test_parse_sse_ignores_non_data_lines(self):
        sse = (
            ": comment line\n"
            "event: message\n"
            "id: 123\n"
            'data: {"jsonrpc":"2.0","result":"ok","id":1}\n'
            "\n"
        )
        body = MCPClient._parse_sse_response(sse)
        assert body["result"] == "ok"

    def test_parse_sse_uses_last_response(self):
        """If multiple responses exist, use the last one."""
        sse = (
            'data: {"jsonrpc":"2.0","result":"first","id":1}\n'
            "\n"
            'data: {"jsonrpc":"2.0","result":"second","id":2}\n'
            "\n"
        )
        body = MCPClient._parse_sse_response(sse)
        assert body["result"] == "second"


# ── MCPClient unit tests ─────────────────────────────────────────────


class TestMCPClient:
    """Tests for the MCP HTTP client."""

    def test_build_headers_without_auth(self):
        client = MCPClient("https://mcp.example.com")
        headers = client._build_headers()
        assert "Authorization" not in headers
        assert headers["Content-Type"] == "application/json"

    def test_build_headers_with_auth(self):
        client = MCPClient("https://mcp.example.com", auth_token="my-token")
        headers = client._build_headers()
        assert headers["Authorization"] == "Bearer my-token"

    def test_build_jsonrpc_request(self):
        client = MCPClient("https://mcp.example.com")
        req = client._build_jsonrpc_request("tools/list")
        assert req["jsonrpc"] == "2.0"
        assert req["method"] == "tools/list"
        assert "id" in req
        assert "params" not in req

    def test_build_jsonrpc_request_with_params(self):
        client = MCPClient("https://mcp.example.com")
        req = client._build_jsonrpc_request(
            "tools/call", {"name": "test", "arguments": {"x": 1}}
        )
        assert req["params"] == {"name": "test", "arguments": {"x": 1}}

    def test_request_id_increments(self):
        client = MCPClient("https://mcp.example.com")
        req1 = client._build_jsonrpc_request("tools/list")
        req2 = client._build_jsonrpc_request("tools/list")
        assert req2["id"] > req1["id"]

    def test_server_url_trailing_slash_stripped(self):
        client = MCPClient("https://mcp.example.com/mcp/")
        assert client.server_url == "https://mcp.example.com/mcp"
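
    @pytest.mark.asyncio(loop_scope="session")
    async def test_send_request_success(self):
        client = MCPClient("https://mcp.example.com")
        mock_response = AsyncMock()
        mock_response.json.return_value = {
            "jsonrpc": "2.0",
            "result": {"tools": []},
            "id": 1,
        }
        # NOTE: _send_request itself is patched below, so mock_response is never
        # consumed and this test only checks that the patched value is returned;
        # it does not exercise the real HTTP or JSON-RPC handling.
        with patch.object(client, "_send_request", return_value={"tools": []}):
            result = await client._send_request("tools/list")
            assert result == {"tools": []}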

    @pytest.mark.asyncio(loop_scope="session")
    async def test_send_request_error(self):
        client = MCPClient("https://mcp.example.com")

        async def mock_send(*args, **kwargs):
            raise MCPClientError("MCP server error [-32600]: Invalid Request")

        with patch.object(client, "_send_request", side_effect=mock_send):
            with pytest.raises(MCPClientError, match="Invalid Request"):
                await client._send_request("tools/list")

    @pytest.mark.asyncio(loop_scope="session")
    async def test_list_tools(self):
        client = MCPClient("https://mcp.example.com")
        mock_result = {
            "tools": [
                {
                    "name": "get_weather",
                    "description": "Get current weather for a city",
                    "inputSchema": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                        "required": ["city"],
                    },
                },
                {
                    "name": "search",
                    "description": "Search the web",
                    "inputSchema": {
                        "type": "object",
                        "properties": {"query": {"type": "string"}},
                        "required": ["query"],
                    },
                },
            ]
        }
        with patch.object(client, "_send_request", return_value=mock_result):
            tools = await client.list_tools()
        assert len(tools) == 2
        assert tools[0].name == "get_weather"
        assert tools[0].description == "Get current weather for a city"
        assert tools[0].input_schema["properties"]["city"]["type"] == "string"
        assert tools[1].name == "search"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_list_tools_empty(self):
        client = MCPClient("https://mcp.example.com")
        with patch.object(client, "_send_request", return_value={"tools": []}):
            tools = await client.list_tools()
        assert tools == []

    @pytest.mark.asyncio(loop_scope="session")
    async def test_list_tools_none_result(self):
        client = MCPClient("https://mcp.example.com")
        with patch.object(client, "_send_request", return_value=None):
            tools = await client.list_tools()
        assert tools == []

    @pytest.mark.asyncio(loop_scope="session")
    async def test_call_tool_success(self):
        client = MCPClient("https://mcp.example.com")
        mock_result = {
            "content": [
                {"type": "text", "text": json.dumps({"temp": 20, "city": "London"})}
            ],
            "isError": False,
        }
        with patch.object(client, "_send_request", return_value=mock_result):
            result = await client.call_tool("get_weather", {"city": "London"})
        assert not result.is_error
        assert len(result.content) == 1
        assert result.content[0]["type"] == "text"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_call_tool_error(self):
        client = MCPClient("https://mcp.example.com")
        mock_result = {
            "content": [{"type": "text", "text": "City not found"}],
            "isError": True,
        }
        with patch.object(client, "_send_request", return_value=mock_result):
            result = await client.call_tool("get_weather", {"city": "???"})
        assert result.is_error

    @pytest.mark.asyncio(loop_scope="session")
    async def test_call_tool_none_result(self):
        client = MCPClient("https://mcp.example.com")
        with patch.object(client, "_send_request", return_value=None):
            result = await client.call_tool("get_weather", {"city": "London"})
        assert result.is_error

    @pytest.mark.asyncio(loop_scope="session")
    async def test_initialize(self):
        client = MCPClient("https://mcp.example.com")
        mock_result = {
            "protocolVersion": "2025-03-26",
            "capabilities": {"tools": {}},
            "serverInfo": {"name": "test-server", "version": "1.0.0"},
        }
        with (
            patch.object(client, "_send_request", return_value=mock_result) as mock_req,
            patch.object(client, "_send_notification") as mock_notif,
        ):
            result = await client.initialize()
        mock_req.assert_called_once()
        mock_notif.assert_called_once_with("notifications/initialized")
        assert result["protocolVersion"] == "2025-03-26"


# ── MCPToolBlock unit tests ──────────────────────────────────────────

MOCK_USER_ID = "test-user-123"


class TestMCPToolBlock:
    """Tests for the MCPToolBlock."""

    def test_block_instantiation(self):
        block = MCPToolBlock()
        assert block.id == "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
        assert block.name == "MCPToolBlock"

    def test_input_schema_has_required_fields(self):
        block = MCPToolBlock()
        schema = block.input_schema.jsonschema()
        props = schema.get("properties", {})
        assert "server_url" in props
        assert "selected_tool" in props
        assert "tool_arguments" in props
        assert "credentials" in props

    def test_output_schema(self):
        block = MCPToolBlock()
        schema = block.output_schema.jsonschema()
        props = schema.get("properties", {})
        assert "result" in props
        assert "error" in props

    def test_get_input_schema_with_tool_schema(self):
        tool_schema = {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        }
        data = {"tool_input_schema": tool_schema}
        result = MCPToolBlock.Input.get_input_schema(data)
        assert result == tool_schema

    def test_get_input_schema_without_tool_schema(self):
        result = MCPToolBlock.Input.get_input_schema({})
        assert result == {}

    def test_get_input_defaults(self):
        data = {"tool_arguments": {"city": "London"}}
        result = MCPToolBlock.Input.get_input_defaults(data)
        assert result == {"city": "London"}

    def test_get_missing_input(self):
        data = {
            "tool_input_schema": {
                "type": "object",
                "properties": {
                    "city": {"type": "string"},
                    "units": {"type": "string"},
                },
                "required": ["city", "units"],
            },
            "tool_arguments": {"city": "London"},
        }
        missing = MCPToolBlock.Input.get_missing_input(data)
        assert missing == {"units"}

    def test_get_missing_input_all_present(self):
        data = {
            "tool_input_schema": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
            "tool_arguments": {"city": "London"},
        }
        missing = MCPToolBlock.Input.get_missing_input(data)
        assert missing == set()

    @pytest.mark.asyncio(loop_scope="session")
    async def test_run_with_mock(self):
        """Test the block using the built-in test infrastructure."""
        block = MCPToolBlock()
        await execute_block_test(block)

    @pytest.mark.asyncio(loop_scope="session")
    async def test_run_missing_server_url(self):
        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url="",
            selected_tool="test",
        )
        outputs = []
        async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
            outputs.append((name, data))
        assert outputs == [("error", "MCP server URL is required")]

    @pytest.mark.asyncio(loop_scope="session")
    async def test_run_missing_tool(self):
        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url="https://mcp.example.com/mcp",
            selected_tool="",
        )
        outputs = []
        async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
            outputs.append((name, data))
        assert outputs == [
            ("error", "No tool selected. Please select a tool from the dropdown.")
        ]

    @pytest.mark.asyncio(loop_scope="session")
    async def test_run_success(self):
        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url="https://mcp.example.com/mcp",
            selected_tool="get_weather",
            tool_input_schema={
                "type": "object",
                "properties": {"city": {"type": "string"}},
            },
            tool_arguments={"city": "London"},
        )

        async def mock_call(*args, **kwargs):
            return {"temp": 20, "city": "London"}

        block._call_mcp_tool = mock_call  # type: ignore
        outputs = []
        async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
            outputs.append((name, data))
        assert len(outputs) == 1
        assert outputs[0][0] == "result"
        assert outputs[0][1] == {"temp": 20, "city": "London"}

    @pytest.mark.asyncio(loop_scope="session")
    async def test_run_mcp_error(self):
        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url="https://mcp.example.com/mcp",
            selected_tool="bad_tool",
        )

        async def mock_call(*args, **kwargs):
            raise MCPClientError("Tool not found")

        block._call_mcp_tool = mock_call  # type: ignore
        outputs = []
        async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
            outputs.append((name, data))
        assert outputs[0][0] == "error"
        assert "Tool not found" in outputs[0][1]

    @pytest.mark.asyncio(loop_scope="session")
    async def test_call_mcp_tool_parses_json_text(self):
        block = MCPToolBlock()
        mock_result = MCPCallResult(
            content=[
                {"type": "text", "text": '{"temp": 20}'},
            ],
            is_error=False,
        )

        async def mock_init(self):
            return {}

        async def mock_call(self, name, args):
            return mock_result

        with (
            patch.object(MCPClient, "initialize", mock_init),
            patch.object(MCPClient, "call_tool", mock_call),
        ):
            result = await block._call_mcp_tool(
                "https://mcp.example.com", "test_tool", {}
            )
        assert result == {"temp": 20}

    @pytest.mark.asyncio(loop_scope="session")
    async def test_call_mcp_tool_plain_text(self):
        block = MCPToolBlock()
        mock_result = MCPCallResult(
            content=[
                {"type": "text", "text": "Hello, world!"},
            ],
            is_error=False,
        )

        async def mock_init(self):
            return {}

        async def mock_call(self, name, args):
            return mock_result

        with (
            patch.object(MCPClient, "initialize", mock_init),
            patch.object(MCPClient, "call_tool", mock_call),
        ):
            result = await block._call_mcp_tool(
                "https://mcp.example.com", "test_tool", {}
            )
        assert result == "Hello, world!"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_call_mcp_tool_multiple_content(self):
        block = MCPToolBlock()
        mock_result = MCPCallResult(
            content=[
                {"type": "text", "text": "Part 1"},
                {"type": "text", "text": '{"part": 2}'},
            ],
            is_error=False,
        )

        async def mock_init(self):
            return {}

        async def mock_call(self, name, args):
            return mock_result

        with (
            patch.object(MCPClient, "initialize", mock_init),
            patch.object(MCPClient, "call_tool", mock_call),
        ):
            result = await block._call_mcp_tool(
                "https://mcp.example.com", "test_tool", {}
            )
        assert result == ["Part 1", {"part": 2}]

    @pytest.mark.asyncio(loop_scope="session")
    async def test_call_mcp_tool_error_result(self):
        block = MCPToolBlock()
        mock_result = MCPCallResult(
            content=[{"type": "text", "text": "Something went wrong"}],
            is_error=True,
        )

        async def mock_init(self):
            return {}

        async def mock_call(self, name, args):
            return mock_result

        with (
            patch.object(MCPClient, "initialize", mock_init),
            patch.object(MCPClient, "call_tool", mock_call),
        ):
            with pytest.raises(MCPClientError, match="returned an error"):
                await block._call_mcp_tool("https://mcp.example.com", "test_tool", {})

    @pytest.mark.asyncio(loop_scope="session")
    async def test_call_mcp_tool_image_content(self):
        block = MCPToolBlock()
        mock_result = MCPCallResult(
            content=[
                {
                    "type": "image",
                    "data": "base64data==",
                    "mimeType": "image/png",
                }
            ],
            is_error=False,
        )

        async def mock_init(self):
            return {}

        async def mock_call(self, name, args):
            return mock_result

        with (
            patch.object(MCPClient, "initialize", mock_init),
            patch.object(MCPClient, "call_tool", mock_call),
        ):
            result = await block._call_mcp_tool(
                "https://mcp.example.com", "test_tool", {}
            )
        assert result == {
            "type": "image",
            "data": "base64data==",
            "mimeType": "image/png",
        }

    @pytest.mark.asyncio(loop_scope="session")
    async def test_run_with_credentials(self):
        """Verify the block uses OAuth2Credentials and passes auth token."""
        from pydantic import SecretStr

        from backend.data.model import OAuth2Credentials

        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url="https://mcp.example.com/mcp",
            selected_tool="test_tool",
        )
        captured_tokens: list[str | None] = []

        async def mock_call(server_url, tool_name, arguments, auth_token=None):
            captured_tokens.append(auth_token)
            return "ok"

        block._call_mcp_tool = mock_call  # type: ignore
        test_creds = OAuth2Credentials(
            id="cred-123",
            provider="mcp",
            access_token=SecretStr("resolved-token"),
            refresh_token=SecretStr(""),
            scopes=[],
            title="Test MCP credential",
        )
        async for _ in block.run(
            input_data, user_id=MOCK_USER_ID, credentials=test_creds
        ):
            pass
        assert captured_tokens == ["resolved-token"]

    @pytest.mark.asyncio(loop_scope="session")
    async def test_run_without_credentials(self):
        """Verify the block works without credentials (public server)."""
        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url="https://mcp.example.com/mcp",
            selected_tool="test_tool",
        )
        captured_tokens: list[str | None] = []

        async def mock_call(server_url, tool_name, arguments, auth_token=None):
            captured_tokens.append(auth_token)
            return "ok"

        block._call_mcp_tool = mock_call  # type: ignore
        outputs = []
        async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
            outputs.append((name, data))
        assert captured_tokens == [None]
        assert outputs == [("result", "ok")]
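
The SSE tests in this file pin down the parsing behaviour fairly precisely: only `data:` lines matter, notifications are skipped, the last response wins, and an empty stream is an error. The client's `_parse_sse_response` is not shown in this part of the diff, so the following is only a sketch of a parser that would satisfy those expectations. `parse_sse_response` and `SSEParseError` are hypothetical names, and skipping payloads that are not valid JSON is an assumption.

```python
import json
from typing import Any


class SSEParseError(Exception):
    """Hypothetical error type; the tests above expect MCPClientError."""


def parse_sse_response(text: str) -> dict[str, Any]:
    """Return the last JSON-RPC response found in an SSE body."""
    last_response = None
    for line in text.splitlines():
        # Comment lines (":"), "event:", "id:" and blank lines are ignored.
        if not line.startswith("data:"):
            continue
        payload = line[len("data:"):].strip()
        if not payload:
            continue
        try:
            message = json.loads(payload)
        except json.JSONDecodeError:
            continue  # assumption: malformed payloads are skipped
        # Responses carry "result" or "error"; notifications carry neither.
        if isinstance(message, dict) and ("result" in message or "error" in message):
            last_response = message
    if last_response is None:
        raise SSEParseError("No JSON-RPC response found in SSE stream")
    return last_response
```

Keeping the last match rather than the first means a stream that interleaves progress notifications, or several responses, still resolves to the final answer, which is what `test_parse_sse_uses_last_response` asserts.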


@@ -0,0 +1,242 @@
"""
Tests for MCP OAuth handler.
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from pydantic import SecretStr

from backend.blocks.mcp.client import MCPClient
from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials


def _mock_response(json_data: dict, status: int = 200) -> MagicMock:
    """Create a mock Response with synchronous json() (matching Requests.Response)."""
    resp = MagicMock()
    resp.status = status
    resp.ok = 200 <= status < 300
    resp.json.return_value = json_data
    return resp


class TestMCPOAuthHandler:
    """Tests for the MCPOAuthHandler."""

    def _make_handler(self, **overrides) -> MCPOAuthHandler:
        defaults = {
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "redirect_uri": "https://app.example.com/callback",
            "authorize_url": "https://auth.example.com/authorize",
            "token_url": "https://auth.example.com/token",
        }
        defaults.update(overrides)
        return MCPOAuthHandler(**defaults)

    def test_get_login_url_basic(self):
        handler = self._make_handler()
        url = handler.get_login_url(
            scopes=["read", "write"],
            state="random-state-token",
            code_challenge="S256-challenge-value",
        )
        assert "https://auth.example.com/authorize?" in url
        assert "response_type=code" in url
        assert "client_id=test-client-id" in url
        assert "state=random-state-token" in url
        assert "code_challenge=S256-challenge-value" in url
        assert "code_challenge_method=S256" in url
        assert "scope=read+write" in url

    def test_get_login_url_with_resource(self):
        handler = self._make_handler(resource_url="https://mcp.example.com/mcp")
        url = handler.get_login_url(
            scopes=[], state="state", code_challenge="challenge"
        )
        assert "resource=https" in url

    def test_get_login_url_without_pkce(self):
        handler = self._make_handler()
        url = handler.get_login_url(scopes=["read"], state="state", code_challenge=None)
        assert "code_challenge" not in url
        assert "code_challenge_method" not in url

    @pytest.mark.asyncio(loop_scope="session")
    async def test_exchange_code_for_tokens(self):
        handler = self._make_handler()
        resp = _mock_response(
            {
                "access_token": "new-access-token",
                "refresh_token": "new-refresh-token",
                "expires_in": 3600,
                "token_type": "Bearer",
            }
        )
        with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
            instance = MockRequests.return_value
            instance.post = AsyncMock(return_value=resp)
            creds = await handler.exchange_code_for_tokens(
                code="auth-code",
                scopes=["read"],
                code_verifier="pkce-verifier",
            )
        assert isinstance(creds, OAuth2Credentials)
        assert creds.access_token.get_secret_value() == "new-access-token"
        assert creds.refresh_token is not None
        assert creds.refresh_token.get_secret_value() == "new-refresh-token"
        assert creds.scopes == ["read"]
        assert creds.access_token_expires_at is not None

    @pytest.mark.asyncio(loop_scope="session")
    async def test_refresh_tokens(self):
        handler = self._make_handler()
        existing_creds = OAuth2Credentials(
            id="existing-id",
            provider="mcp",
            access_token=SecretStr("old-token"),
            refresh_token=SecretStr("old-refresh"),
            scopes=["read"],
            title="test",
        )
        resp = _mock_response(
            {
                "access_token": "refreshed-token",
                "refresh_token": "new-refresh",
                "expires_in": 3600,
            }
        )
        with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
            instance = MockRequests.return_value
            instance.post = AsyncMock(return_value=resp)
            refreshed = await handler._refresh_tokens(existing_creds)
        assert refreshed.id == "existing-id"
        assert refreshed.access_token.get_secret_value() == "refreshed-token"
        assert refreshed.refresh_token is not None
        assert refreshed.refresh_token.get_secret_value() == "new-refresh"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_refresh_tokens_no_refresh_token(self):
        handler = self._make_handler()
        creds = OAuth2Credentials(
            provider="mcp",
            access_token=SecretStr("token"),
            scopes=["read"],
            title="test",
        )
        with pytest.raises(ValueError, match="No refresh token"):
            await handler._refresh_tokens(creds)

    @pytest.mark.asyncio(loop_scope="session")
    async def test_revoke_tokens_no_url(self):
        handler = self._make_handler(revoke_url=None)
        creds = OAuth2Credentials(
            provider="mcp",
            access_token=SecretStr("token"),
            scopes=[],
            title="test",
        )
        result = await handler.revoke_tokens(creds)
        assert result is False

    @pytest.mark.asyncio(loop_scope="session")
    async def test_revoke_tokens_with_url(self):
        handler = self._make_handler(revoke_url="https://auth.example.com/revoke")
        creds = OAuth2Credentials(
            provider="mcp",
            access_token=SecretStr("token"),
            scopes=[],
            title="test",
        )
        resp = _mock_response({}, status=200)
        with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
            instance = MockRequests.return_value
            instance.post = AsyncMock(return_value=resp)
            result = await handler.revoke_tokens(creds)
        assert result is True


class TestMCPClientDiscovery:
    """Tests for MCPClient OAuth metadata discovery."""

    @pytest.mark.asyncio(loop_scope="session")
    async def test_discover_auth_found(self):
        client = MCPClient("https://mcp.example.com/mcp")
        metadata = {
            "authorization_servers": ["https://auth.example.com"],
            "resource": "https://mcp.example.com/mcp",
        }
        resp = _mock_response(metadata, status=200)
        with patch("backend.blocks.mcp.client.Requests") as MockRequests:
            instance = MockRequests.return_value
            instance.get = AsyncMock(return_value=resp)
            result = await client.discover_auth()
        assert result is not None
        assert result["authorization_servers"] == ["https://auth.example.com"]

    @pytest.mark.asyncio(loop_scope="session")
    async def test_discover_auth_not_found(self):
        client = MCPClient("https://mcp.example.com/mcp")
        resp = _mock_response({}, status=404)
        with patch("backend.blocks.mcp.client.Requests") as MockRequests:
            instance = MockRequests.return_value
            instance.get = AsyncMock(return_value=resp)
            result = await client.discover_auth()
        assert result is None

    @pytest.mark.asyncio(loop_scope="session")
    async def test_discover_auth_server_metadata(self):
        client = MCPClient("https://mcp.example.com/mcp")
        server_metadata = {
            "issuer": "https://auth.example.com",
            "authorization_endpoint": "https://auth.example.com/authorize",
            "token_endpoint": "https://auth.example.com/token",
            "registration_endpoint": "https://auth.example.com/register",
            "code_challenge_methods_supported": ["S256"],
        }
        resp = _mock_response(server_metadata, status=200)
        with patch("backend.blocks.mcp.client.Requests") as MockRequests:
            instance = MockRequests.return_value
            instance.get = AsyncMock(return_value=resp)
            result = await client.discover_auth_server_metadata(
                "https://auth.example.com"
            )
        assert result is not None
        assert result["authorization_endpoint"] == "https://auth.example.com/authorize"
        assert result["token_endpoint"] == "https://auth.example.com/token"


@@ -0,0 +1,162 @@
"""
Minimal MCP server for integration testing.

Implements the MCP Streamable HTTP transport (JSON-RPC 2.0 over HTTP POST)
with a few sample tools. Runs on localhost with a random available port.
"""

import json
import logging

from aiohttp import web

logger = logging.getLogger(__name__)

# Sample tools this test server exposes
TEST_TOOLS = [
    {
        "name": "get_weather",
        "description": "Get current weather for a city",
        "inputSchema": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "City name",
                },
            },
            "required": ["city"],
        },
    },
    {
        "name": "add_numbers",
        "description": "Add two numbers together",
        "inputSchema": {
            "type": "object",
            "properties": {
                "a": {"type": "number", "description": "First number"},
                "b": {"type": "number", "description": "Second number"},
            },
            "required": ["a", "b"],
        },
    },
    {
        "name": "echo",
        "description": "Echo back the input message",
        "inputSchema": {
            "type": "object",
            "properties": {
                "message": {"type": "string", "description": "Message to echo"},
            },
            "required": ["message"],
        },
    },
]


def _handle_initialize(params: dict) -> dict:
    return {
        "protocolVersion": "2025-03-26",
        "capabilities": {"tools": {"listChanged": False}},
        "serverInfo": {"name": "test-mcp-server", "version": "1.0.0"},
    }


def _handle_tools_list(params: dict) -> dict:
    return {"tools": TEST_TOOLS}


def _handle_tools_call(params: dict) -> dict:
    tool_name = params.get("name", "")
    arguments = params.get("arguments", {})

    if tool_name == "get_weather":
        city = arguments.get("city", "Unknown")
        return {
            "content": [
                {
                    "type": "text",
                    "text": json.dumps(
                        {"city": city, "temperature": 22, "condition": "sunny"}
                    ),
                }
            ],
        }
    elif tool_name == "add_numbers":
        a = arguments.get("a", 0)
        b = arguments.get("b", 0)
        return {
            "content": [{"type": "text", "text": json.dumps({"result": a + b})}],
        }
    elif tool_name == "echo":
        message = arguments.get("message", "")
        return {
            "content": [{"type": "text", "text": message}],
        }
    else:
        return {
            "content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}],
            "isError": True,
        }


HANDLERS = {
    "initialize": _handle_initialize,
    "tools/list": _handle_tools_list,
    "tools/call": _handle_tools_call,
}


async def handle_mcp_request(request: web.Request) -> web.Response:
    """Handle incoming MCP JSON-RPC 2.0 requests."""
    # Check auth if configured
    expected_token = request.app.get("auth_token")
    if expected_token:
        auth_header = request.headers.get("Authorization", "")
        if auth_header != f"Bearer {expected_token}":
            return web.json_response(
                {
                    "jsonrpc": "2.0",
                    "error": {"code": -32001, "message": "Unauthorized"},
                    "id": None,
                },
                status=401,
            )

    body = await request.json()

    # Handle notifications (no id field): just acknowledge
    if "id" not in body:
        return web.Response(status=202)

    method = body.get("method", "")
    params = body.get("params", {})
    request_id = body.get("id")

    handler = HANDLERS.get(method)
    if not handler:
        return web.json_response(
            {
                "jsonrpc": "2.0",
                "error": {
                    "code": -32601,
                    "message": f"Method not found: {method}",
                },
                "id": request_id,
            }
        )

    result = handler(params)
    return web.json_response({"jsonrpc": "2.0", "result": result, "id": request_id})


def create_test_mcp_app(auth_token: str | None = None) -> web.Application:
    """Create an aiohttp app that acts as an MCP server."""
    app = web.Application()
    app.router.add_post("/mcp", handle_mcp_request)
    if auth_token:
        app["auth_token"] = auth_token
    return app
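
The request handling in the test server above follows a fixed order: auth check, notification check, method lookup, then result. As an illustration only (not part of this diff), that decision order can be sketched without aiohttp as a pure function. The `dispatch` name, its signature, and the trimmed-down `HANDLERS` table below are assumptions made for the sketch; only the status codes and JSON-RPC error codes are taken from the file.

```python
from typing import Callable, Dict, Optional, Tuple

# Hypothetical, trimmed-down stand-in for the HANDLERS table in the test server.
HANDLERS: Dict[str, Callable[[dict], dict]] = {
    "tools/list": lambda params: {"tools": [{"name": "echo"}]},
    "tools/call": lambda params: {
        "content": [
            {"type": "text", "text": params.get("arguments", {}).get("message", "")}
        ]
    },
}


def dispatch(
    body: dict, auth_header: str = "", expected_token: Optional[str] = None
) -> Tuple[int, Optional[dict]]:
    """Return (HTTP status, JSON payload or None) for one JSON-RPC message."""
    # 1. Reject requests whose bearer token does not match, if auth is configured.
    if expected_token and auth_header != f"Bearer {expected_token}":
        error = {"code": -32001, "message": "Unauthorized"}
        return 401, {"jsonrpc": "2.0", "error": error, "id": None}

    # 2. Notifications carry no "id"; acknowledge them without a body.
    if "id" not in body:
        return 202, None

    # 3. Unknown methods get the JSON-RPC "method not found" error.
    method = body.get("method", "")
    handler = HANDLERS.get(method)
    if handler is None:
        error = {"code": -32601, "message": f"Method not found: {method}"}
        return 200, {"jsonrpc": "2.0", "error": error, "id": body["id"]}

    # 4. Otherwise run the handler and echo the request id back.
    result = handler(body.get("params", {}))
    return 200, {"jsonrpc": "2.0", "result": result, "id": body["id"]}
```

Keeping the decision logic separate from the HTTP layer like this is one possible way to unit-test the branches without starting a server; the diff itself exercises them through the aiohttp app.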

View File

@@ -3,7 +3,7 @@ from typing import List, Literal
 from pydantic import SecretStr
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

View File

@@ -3,7 +3,7 @@ from typing import Any, Literal, Optional, Union
 from mem0 import MemoryClient
 from pydantic import BaseModel, SecretStr
-from backend.data.block import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
+from backend.blocks._base import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,

View File

@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
 from pydantic import model_validator
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
 from typing import Any, Dict, List, Optional
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

View File

@@ -1,6 +1,6 @@
 from __future__ import annotations
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

View File

@@ -1,6 +1,6 @@
 from __future__ import annotations
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

View File

@@ -4,7 +4,7 @@ from typing import List, Optional
 from pydantic import BaseModel
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

View File

@@ -1,15 +1,15 @@
-from backend.blocks.nvidia._auth import (
-    NvidiaCredentials,
-    NvidiaCredentialsField,
-    NvidiaCredentialsInput,
-)
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.blocks.nvidia._auth import (
+    NvidiaCredentials,
+    NvidiaCredentialsField,
+    NvidiaCredentialsInput,
+)
 from backend.data.model import SchemaField
 from backend.util.request import Requests
 from backend.util.type import MediaFileType

View File

@@ -6,7 +6,7 @@ from typing import Any, Literal
 import openai
 from pydantic import SecretStr
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

View File

@@ -1,7 +1,7 @@
 import logging
 from typing import Any, Literal
-from backend.data.block import (
+from backend.blocks._base import (
     Block,
     BlockCategory,
     BlockOutput,

Some files were not shown because too many files have changed in this diff.