Compare commits

..

68 Commits

Author SHA1 Message Date
Zamil Majdy
4ac025da09 fix(backend/mcp): Fix test_routes.py tests and refine exception handling
- discover_tools: Keep generic Exception catch → 502 for MCP connectivity errors
- mcp_oauth_callback: Add targeted exception handler for token exchange → 400
- test_routes: Fix discover_auth_server_metadata mock for no-OAuth-support case
- oauth.py: Auto-formatting only
2026-02-10 22:22:38 +04:00
Zamil Majdy
fbe4c740cb refactor(backend/mcp): Remove useless catch-reraise exception patterns in routes
Drop the broad `except Exception` catch-and-reraise-as-HTTPException
blocks. Keep only the meaningful error handlers (HTTPClientError for
401/403, MCPClientError for 502). Unhandled exceptions now propagate
naturally to FastAPI's default 500 handler.
2026-02-10 22:09:42 +04:00
Zamil Majdy
ff48f4335b fix(backend/mcp): Set loop_scope=session on all async MCP tests
Matches the pattern used by oauth_test.py to prevent event loop
conflicts with session-scoped fixtures (server, graph_cleanup).
2026-02-10 22:02:48 +04:00
Zamil Majdy
11fbb51a70 fix(tests): Remove MCP conftest.py to fix session event loop conflict
The MCP conftest.py with pytest hooks (pytest_addoption,
pytest_collection_modifyitems) was disrupting pytest-asyncio's session
event loop lifecycle, causing the SpinTestServer to be torn down before
session-scoped oauth tests could run.

Replace the conftest-based e2e gating with a simple pytestmark skipif
in the test file itself.
2026-02-10 20:48:30 +04:00
Zamil Majdy
bb8b56c7de fix(mcp): Validate JSON-RPC response is a dict before accessing keys
Add isinstance check after response.json() to prevent TypeError/
AttributeError if an MCP server returns a non-object JSON response.
2026-02-10 20:31:37 +04:00
Zamil Majdy
6ebd97f874 fix(executor): Extract only tool_arguments for MCPToolBlock input reshaping
The entire merged input_data dict (containing server_url, credentials,
selected_tool, etc.) was being assigned to tool_arguments instead of
just the tool_arguments sub-dict. This would cause validation failures
or MCP server rejections.
2026-02-10 20:22:47 +04:00
Zamil Majdy
e934f0d0c2 fix(tests): Remove session-scoped fixture overrides from MCP conftest
The MCP conftest.py was overriding session-scoped `server` and
`graph_cleanup` fixtures with no-op versions. Having two session-scoped
fixtures with the same name at different directory levels caused
pytest-asyncio event loop conflicts, making all oauth_test.py tests
fail with "Event loop is closed".

Since these fixtures are session-scoped and shared across the entire
test run, the override was unnecessary — the SpinTestServer is already
created for other tests.

Also adds defensive `access_token` key validation in MCP OAuth token
exchange and refresh to prevent KeyError on malformed responses.
2026-02-10 20:17:07 +04:00
Zamil Majdy
3e38b141dd fix(tests): Use async pytest_asyncio fixtures in MCP conftest
The MCP conftest's sync server/graph_cleanup fixtures must match the
parent conftest's async pytest_asyncio fixtures to avoid disrupting
the session event loop management, which caused "Event loop is closed"
errors in oauth_test.py tests.
2026-02-10 18:40:22 +04:00
Zamil Majdy
ec72e7eb7b fix(tests): Restore pytest_asyncio.fixture with loop_scope for session fixtures
The server and graph_cleanup fixtures in conftest.py require explicit
pytest_asyncio.fixture(loop_scope="session") to properly manage the
session event loop. Using plain pytest.fixture causes "Event loop is
closed" errors in all oauth_test.py tests.
2026-02-10 18:16:36 +04:00
Zamil Majdy
db038cd0e0 fix(tests): Revert oauth_test.py fixture scope changes to fix event loop error
Restores session-scoped fixtures and pytest_asyncio decorators that were
accidentally changed, causing "RuntimeError: Event loop is closed" in
test_authorize_creates_code_in_database. Also regenerates openapi.json.
2026-02-10 17:32:52 +04:00
Zamil Majdy
6805a4f3c5 Merge branch 'dev' into feat/mcp-blocks 2026-02-10 17:27:23 +04:00
Abhimanyu Yadav
7d4c020a9b feat(chat): implement AI SDK integration with custom streaming response handling (#11901)
### Changes 🏗️

- Added AI SDK integration for chat streaming with proper message
handling
- Implemented custom to_sse method in StreamToolOutputAvailable to
exclude non-spec fields
- Modified stream_chat_completion to reuse message IDs for tool call
continuations
- Created new Copilot 2.0 UI with AI SDK React components
- Added streamdown and related packages for markdown rendering
- Built reusable conversation and message components for the chat
interface
- Added support for tool output display in the chat UI

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Start a new chat session and verify streaming works correctly
  - [x] Test tool calls and verify they display properly in the UI
  - [x] Verify message continuations don't create duplicate messages
  - [x] Test markdown rendering with code blocks and other formatting
  - [x] Verify the UI is responsive and scrolls correctly

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

---------

Co-authored-by: Lluis Agusti <hi@llu.lu>
Co-authored-by: Ubbe <hi@ubbe.dev>
2026-02-10 21:12:21 +08:00
Zamil Majdy
cb7a0cbdd7 refactor(mcp): Make create_mcp_oauth_handler public and use top-level import 2026-02-10 17:08:03 +04:00
Zamil Majdy
79d6e8e2d7 fix(frontend/credentials): Handle stale credential IDs in CredentialsSelect
When a credential is deleted but the node still references its ID,
CredentialsSelect now treats the stale ID as unselected and falls
back to the first available credential instead of showing the raw ID.
2026-02-10 17:02:08 +04:00
Zamil Majdy
472117a872 fix(mcp): Handle MCP credential deletion with dynamic OAuth handler
MCP credentials use per-server dynamic OAuth handlers, not a static
handler registered in HANDLERS_BY_NAME. The delete endpoint now
creates a dynamic handler from credential metadata for token
revocation instead of failing with "Provider 'mcp' does not support
OAuth".
2026-02-10 16:47:36 +04:00
Zamil Majdy
75a7ccf36e fix(mcp): Address PR review comments - defensive checks and docs
- Validate token_endpoint in OAuth metadata before accessing it
- Check authorization_servers list is non-empty before indexing
- Use provider_matches() (renamed from private _provider_matches) in
  creds_manager for Python 3.13 StrEnum compatibility
- Fill in MCP block documentation with technical explanation and use cases
2026-02-10 16:42:29 +04:00
Zamil Majdy
84809f4b94 fix(mcp): Set customized_name at block creation and add pre-run validation
- Set customized_name in metadata when MCP and Agent blocks are created
  (both legacy and new builder) so titles persist through save/load
- Remove convoluted agent_name fallback from NodeHeader and getNodeTitle
- Add custom block-level validation in graph pre-run checks so MCP tool
  arguments are validated before execution
- Fix server_name fallback to URL hostname in discover_tools endpoint
2026-02-10 16:34:19 +04:00
Zamil Majdy
4364a771d4 fix(mcp): Validate required tool args and fix title fallback for existing blocks
Backend: Add required-field validation in MCPToolBlock.run() before
calling the MCP server. The executor-level validation is bypassed for
MCP blocks because get_input_defaults() flattens tool_arguments,
stripping tool_input_schema from the validation context.

Frontend: NodeHeader now derives the MCP server label from the server
URL hostname when server_name is missing (pruned by pruneEmptyValues).
This fixes the title for existing blocks that don't have customized_name
in metadata.
2026-02-10 15:42:51 +04:00
Zamil Majdy
4d4ed562f0 fix(frontend/mcp): Ensure MCP block title persists across save/refresh
When the MCP server returns a null server_name, fall back to the URL
hostname so customized_name is always set in metadata. This prevents
the title from degrading to "MCP:" after save and reload.
2026-02-10 15:31:29 +04:00
dependabot[bot]
e596ea87cb chore(libs/deps-dev): bump pytest-cov from 6.2.1 to 7.0.0 in /autogpt_platform/autogpt_libs (#12030)
Bumps [pytest-cov](https://github.com/pytest-dev/pytest-cov) from 6.2.1
to 7.0.0.
<details>
<summary>Changelog</summary>
<p><em>Sourced from <a
href="https://github.com/pytest-dev/pytest-cov/blob/master/CHANGELOG.rst">pytest-cov's
changelog</a>.</em></p>
<blockquote>
<h2>7.0.0 (2025-09-09)</h2>
<ul>
<li>
<p>Dropped support for subprocesses measurement.</p>
<p>It was a feature added long time ago when coverage lacked a nice way
to measure subprocesses created in tests.
It relied on a <code>.pth</code> file, there was no way to opt-out and
it created bad interations
with <code>coverage's new patch system
&lt;https://coverage.readthedocs.io/en/latest/config.html#run-patch&gt;</code>_
added
in <code>7.10
&lt;https://coverage.readthedocs.io/en/7.10.6/changes.html#version-7-10-0-2025-07-24&gt;</code>_.</p>
<p>To migrate to this release you might need to enable the suprocess
patch, example for <code>.coveragerc</code>:</p>
<p>.. code-block:: ini</p>
<p>[run]
patch = subprocess</p>
<p>This release also requires at least coverage 7.10.6.</p>
</li>
<li>
<p>Switched packaging to have metadata completely in
<code>pyproject.toml</code> and use <code>hatchling
&lt;https://pypi.org/project/hatchling/&gt;</code>_ for
building.
Contributed by Ofek Lev in
<code>[#551](https://github.com/pytest-dev/pytest-cov/issues/551)
&lt;https://github.com/pytest-dev/pytest-cov/pull/551&gt;</code>_
with some extras in
<code>[#716](https://github.com/pytest-dev/pytest-cov/issues/716)
&lt;https://github.com/pytest-dev/pytest-cov/pull/716&gt;</code>_.</p>
</li>
<li>
<p>Removed some not really necessary testing deps like
<code>six</code>.</p>
</li>
</ul>
<h2>6.3.0 (2025-09-06)</h2>
<ul>
<li>Added support for markdown reports.
Contributed by Marcos Boger in
<code>[#712](https://github.com/pytest-dev/pytest-cov/issues/712)
&lt;https://github.com/pytest-dev/pytest-cov/pull/712&gt;</code>_
and <code>[#714](https://github.com/pytest-dev/pytest-cov/issues/714)
&lt;https://github.com/pytest-dev/pytest-cov/pull/714&gt;</code>_.</li>
<li>Fixed some formatting issues in docs.
Anonymous contribution in
<code>[#706](https://github.com/pytest-dev/pytest-cov/issues/706)
&lt;https://github.com/pytest-dev/pytest-cov/pull/706&gt;</code>_.</li>
</ul>
</blockquote>
</details>
<details>
<summary>Commits</summary>
<ul>
<li><a
href="224d8964ca"><code>224d896</code></a>
Bump version: 6.3.0 → 7.0.0</li>
<li><a
href="73424e3999"><code>73424e3</code></a>
Cleanup the docs a bit.</li>
<li><a
href="36f1cc2967"><code>36f1cc2</code></a>
Bump pins in template.</li>
<li><a
href="f299c590a6"><code>f299c59</code></a>
Bump the github-actions group with 2 updates</li>
<li><a
href="25f0b2e0cd"><code>25f0b2e</code></a>
Update docs/config.rst</li>
<li><a
href="bb23eacc55"><code>bb23eac</code></a>
Improve configuration docs</li>
<li><a
href="a19531e91e"><code>a19531e</code></a>
Switch from build/pre-commit to uv/prek - this should make this
faster.</li>
<li><a
href="82f9993910"><code>82f9993</code></a>
Update changelog.</li>
<li><a
href="211b5cd41c"><code>211b5cd</code></a>
Fix links.</li>
<li><a
href="97aadd74bc"><code>97aadd7</code></a>
Update some ci config, reformat and apply some lint fixes.</li>
<li>Additional commits viewable in <a
href="https://github.com/pytest-dev/pytest-cov/compare/v6.2.1...v7.0.0">compare
view</a></li>
</ul>
</details>
<br />


[![Dependabot compatibility
score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=pytest-cov&package-manager=pip&previous-version=6.2.1&new-version=7.0.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)

Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.

[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)

---

<details>
<summary>Dependabot commands and options</summary>
<br />

You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot show <dependency name> ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)


</details>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
Co-authored-by: Otto <otto@agpt.co>
2026-02-10 11:22:25 +00:00
Zamil Majdy
8bea7cf875 chore(mcp): Remove dev artifacts and simplify credential lookup
- Remove MCP_BLOCK_IMPLEMENTATION.md development doc
- Remove console.log debug statements from OAuth callback
- Simplify credential lookup to single call (get_creds_by_provider
  already handles Python 3.13 StrEnum bug via _provider_matches)
- Remove unused Credentials import from routes.py
2026-02-10 15:18:55 +04:00
Zamil Majdy
c1c269c4a9 fix(frontend/credentials): Auto-select first credential and persist MCP block title
- CredentialsSelect: default to first available credential instead of
  "None" when credentials exist, reorder options to show credentials
  before the "None" option, and notify parent on auto-select
- Revert CredentialsGroupedView user auto-select effect (now handled
  at the CredentialsSelect level)
- Block.tsx: persist MCP block title as customized_name in metadata
  so it survives save/load
2026-02-10 14:56:19 +04:00
Zamil Majdy
65987ff15e fix(frontend/credentials): Auto-select user credentials in run dialog
Add auto-selection for user credentials (like MCP OAuth) in the
CredentialsGroupedView run dialog. When exactly one credential matches
the provider, type, and discriminator values (e.g. MCP server URL),
it is pre-selected instead of defaulting to "None (skip this credential)".
2026-02-10 14:21:23 +04:00
Zamil Majdy
ed50f7f87d fix(mcp): Wire credentials into MCP block form and add auto-lookup fallback
Frontend: Include credentials field in MCP block's dynamic input schema
so users can select OAuth credentials from the node form. Separate
credentials from tool_arguments in FormCreator to store them at the
correct level in hardcodedValues.

Backend: Add _auto_lookup_credential fallback in MCPToolBlock.run() for
legacy nodes that don't have credentials explicitly set. This resolves
the credential by matching mcp_server_url in stored OAuth metadata.
2026-02-10 14:05:09 +04:00
Zamil Majdy
c03fb170e0 fix(backend/credentials): Handle Python 3.13 str(StrEnum) bug in OAuth state verification
verify_state_token and get_creds_by_provider compared provider strings
with ==, which failed when OAuth states were stored with the buggy
"ProviderName.MCP" format from Python 3.13's str(Enum) behavior.

Also fix double-append in store_state_token where the state was written
once via edit_user_integrations and again via a redundant manual block.
2026-02-10 13:32:38 +04:00
Zamil Majdy
8a2f98b23c fix(mcp): Fix discover_tools test mock and credential auto-unselect
- Add missing `refresh_if_needed` mock to test_discover_tools_auto_uses_stored_credential
  so it returns the stored credential instead of a MagicMock
- Fix credential auto-unselect clearing MCP credentials on initial render:
  skip the "unselect if not available" check when the saved credentials
  list is empty (empty list means not loaded yet, not invalid)
2026-02-10 12:55:35 +04:00
Zamil Majdy
5e2ae3cec5 Merge branch 'dev' into feat/mcp-blocks 2026-02-10 12:45:34 +04:00
Zamil Majdy
f8771484fe fix(mcp): Fix CI failures and credential validation issues
- Fix pyright errors in graph_test.py by properly typing frozenset[CredentialsType]
- Fix executor validation crash when credentials is empty {} by nullifying
  the field before JSON schema validation
- Exclude MCP Tool block from e2e block discovery test (requires dialog)
- Normalize provider string in CredentialsMetaResponse to handle Python 3.13
  str(Enum) bug for stored credentials
- Fix get_host() to match MCP provider regardless of enum string format
2026-02-10 12:34:00 +04:00
Zamil Majdy
81e4f0a4b0 fix(mcp): Refresh expired OAuth tokens before tool discovery
The discover_tools endpoint was reading raw access tokens from stored
credentials without checking if they had expired. This caused users
to be prompted to re-authenticate every time the token expired (~1h).

Now uses creds_manager.refresh_if_needed() to transparently refresh
expired tokens before using them.
2026-02-10 12:18:43 +04:00
Zamil Majdy
66aada30f0 fix(tests): Revert conftest.py and oauth_test.py fixture scope changes
The pytest_asyncio fixture changes with loop_scope="session" caused
"Event loop is closed" errors in all 31 oauth_test.py tests on CI.
MCP tests have their own conftest override and don't need these changes.
2026-02-10 11:56:28 +04:00
Zamil Majdy
74e04f71f4 fix(tests): Properly mock get_required_fields in credential validation tests
The tests used MagicMock for block.input_schema but didn't mock
get_required_fields(), causing the "required missing creds" test to
silently treat all credentials as optional.
2026-02-10 11:44:12 +04:00
Otto
81f8290f01 debug(backend/db): Add diagnostic logging for vector type errors (#12024)
Adds diagnostic logging when the `type vector does not exist` error
occurs in raw SQL queries.

## Problem

We're seeing intermittent "type vector does not exist" errors on
dev-behave ([Sentry
issue](https://significant-gravitas.sentry.io/issues/7205929979/)). The
pgvector extension should be in the search_path, but occasionally
queries fail to resolve the vector type.

## Solution

When a query fails with this specific error, we now log:
- `SHOW search_path` - what schemas are being searched
- `SELECT current_schema()` - the active schema
- `SELECT current_user, session_user, current_database()` - connection
context

This diagnostic info will help identify why the vector extension isn't
visible in certain cases.

## Changes

- Added `_log_vector_error_diagnostics()` helper function in
`backend/data/db.py`
- Wrapped SQL execution in try/except to catch and diagnose vector type
errors
- Original exception is re-raised after logging (no behavior change)

## Testing

This is observational/diagnostic code. It will be validated by waiting
for the error to occur naturally on dev and checking the logs.

## Rollout

Once we've captured diagnostic logs and identified the root cause, this
logging can be removed or reduced in verbosity.
2026-02-10 07:35:13 +00:00
Zamil Majdy
4db27ca112 fix(mcp): Auto-select credential and gracefully handle stale IDs
- Auto-select credential when exactly one match exists (even for
  optional fields). Only skip auto-select for optional fields with
  multiple choices.
- In executor, catch ValueError from creds_manager.acquire() for
  optional credential fields — fall back to running without credentials
  instead of crashing when stale IDs reference deleted credentials.
2026-02-10 11:25:06 +04:00
Zamil Majdy
27ba4e8e93 fix(frontend): Remove credential field reordering on selection
The sortByUnsetFirst comparator in splitCredentialFieldsBySystem
caused credential inputs to jump positions every time a credential
was selected (set fields moved to bottom, unset moved to top).
Remove the sort to keep stable ordering.
2026-02-10 11:18:42 +04:00
Zamil Majdy
1a1985186a fix(mcp): Normalize credential input_data for JSON schema validation
The model_validator on CredentialsMetaInput normalizes legacy
"ProviderName.MCP" format for Pydantic validation, but validate_data()
uses raw JSON schema which bypasses Pydantic. Write normalized values
back to input_data after Pydantic processes them so both validation
paths see correct data.
2026-02-10 11:15:10 +04:00
Zamil Majdy
8fd13ade74 fix(mcp): Normalize legacy ProviderName format and fix credential optionality
- Add model_validator on CredentialsMetaInput to auto-normalize old
  "ProviderName.MCP" format to "mcp" at the model level, eliminating
  the need for string parsing hacks in every consumer.

- Fix aggregate_credentials_inputs to check block schema defaults when
  determining if credentials are required, not just node metadata.
  MCP blocks with default={} are always optional regardless of metadata.
2026-02-10 09:43:28 +04:00
Zamil Majdy
88ee4b3a11 fix(mcp): Clean up old credentials stored with wrong provider string
Also search for credentials stored with "ProviderName.MCP" (from the
Python 3.13 str(Enum) bug) during both discover-tools auto-lookup and
OAuth callback cleanup. Remove the temporary debug endpoint.
2026-02-10 09:22:50 +04:00
Zamil Majdy
8eed4ad653 fix(mcp): Use ProviderName enum directly instead of str() for credential provider
Python 3.13 changed str(StrEnum) to return "ClassName.MEMBER" instead of
the plain value. This caused MCP credentials to be stored with provider
"ProviderName.MCP" instead of "mcp", leading to type/provider mismatch
errors during graph validation and execution.

Fix: Pass the enum directly to Pydantic (which extracts .value automatically),
matching the pattern used by all other OAuth handlers. Use .value explicitly
only in non-Pydantic contexts (string comparisons, API calls).
2026-02-10 09:06:29 +04:00
Zamil Majdy
7744b89e96 fix(mcp): Use ProviderName.MCP.value instead of str() for credential provider
Python 3.13 changed str(StrEnum) to return "ClassName.MEMBER" instead of
the plain value. This caused MCP credentials to be stored with provider
"ProviderName.MCP" instead of "mcp", leading to type/provider mismatch
errors during graph validation and execution.
2026-02-10 09:04:38 +04:00
Zamil Majdy
4c02cd8f2f fix(mcp): Handle optional credentials in graph save and execution validation
- _on_graph_activate: Clear stale credential references for optional
  fields instead of blocking the save. Checks both node metadata
  (credentials_optional) and block schema (field not in required_fields).
- _validate_node_input_credentials: Use block schema's required_fields
  as fallback for credentials_optional check, so MCP blocks with
  default={} credentials are properly treated as optional.
- Set credentials_optional metadata on new MCP nodes in the frontend.
2026-02-10 08:53:15 +04:00
Zamil Majdy
909f313e1e fix(frontend/mcp): Filter credential auto-select by server URL discriminator
Prevent MCP credential cross-contamination where a credential for one
server (e.g. Sentry) fills credential fields for other servers (e.g.
Linear). Adds matchesDiscriminatorValues() to match credentials by host
against discriminator_values from the schema.
2026-02-10 07:39:44 +04:00
Zamil Majdy
edd9a90903 refactor(mcp): Share OAuth popup logic and fix credential persistence
- Extract shared OAuth popup utility (oauth-popup.ts) used by both
  MCPToolDialog and useCredentialsInput, eliminating ~200 lines of
  duplicated BroadcastChannel/postMessage/localStorage listener code
- Add mcpOAuthCallback to credentials provider so MCP credentials
  are added to the in-memory cache after OAuth (fixes credentials not
  appearing in the credential picker after OAuth via MCPToolDialog)
- Fix oauth_test.py async fixtures missing loop_scope="session"
- Add MCP token refresh handler in creds_manager for dynamic endpoints
- Fix enum string representation in CredentialsFieldInfo.combine()
2026-02-10 07:17:05 +04:00
Zamil Majdy
ba031329e9 fix(mcp): Integrate MCPToolBlock with standard credentials system
- Replace manual credential_id field with CredentialsMetaInput pattern
- Fix credential deduplication so different MCP server URLs get separate
  credential entries in the task credentials panel
- Add descriptive display names (e.g. "MCP: mcp.sentry.dev")
- Fix OAuth popup callback by adding mcp_callback route to middleware
  exclusion list and adding localStorage polling fallback
- Fix SSRF test fixture to patch Requests constructor directly
- Add MCP server URL matching for credential auto-assignment
- Return CredentialsMetaResponse from MCP OAuth callback
- Support MCP-specific OAuth flow in frontend credential input
- Filter MCP credentials by server URL in frontend
- Add test coverage for credential deduplication logic
2026-02-09 20:59:37 +04:00
Zamil Majdy
6ab1a6867e fix(backend/mcp): Fix pyright errors and formatting in MCP block and tests
- Use isinstance(creds, APIKeyCredentials) instead of hasattr check
- Rewrite integration tests to use user_id param and mock _resolve_auth_token
- Fix f-string and line-length formatting issues in routes.py
2026-02-09 19:16:17 +04:00
Zamil Majdy
d9269310cc fix(frontend/mcp): Loop HTML tag stripping to prevent XSS bypass
The single-pass regex `/<[^>]+>/g` can be bypassed with nested tags
like `<scr<script>ipt>`. Loop until no more tags are found.
Note: React auto-escapes JSX so this is defense-in-depth.
2026-02-09 19:10:17 +04:00
Zamil Majdy
fe70b6929f fix(mcp): Remove trusted_origins to prevent SSRF on user-provided URLs
User-provided MCP server URLs should not bypass SSRF IP-blocking
validation. Remove trusted_origins from all MCP code so that
private/internal IPs are properly blocked. Keep ThreadedResolver
in HostResolver fallback for DNS reliability in subprocess
environments.
2026-02-09 18:55:17 +04:00
Zamil Majdy
340520ba85 fix(mcp): OAuth discovery fallback, session ID, credential lookup, and DNS reliability
- Support MCP servers that serve OAuth metadata directly without
  protected-resource metadata (e.g. Linear) by falling back to
  discover_auth_server_metadata on the server's own origin
- Omit resource_url when no protected-resource metadata exists to
  avoid token audience mismatch errors (RFC 8707 resource is optional)
- Add Mcp-Session-Id header tracking per MCP Streamable HTTP spec
- Fall back to server_url credential lookup when credential_id is
  empty (pruneEmptyValues strips it from saved graphs)
- Use ThreadedResolver instead of c-ares AsyncResolver to avoid DNS
  failures in forked subprocess environments
- Simplify OAuth UX: single "Sign in & Connect" button on 401,
  remove sticky localStorage URL prefill
- Clean up stale MCP credentials on re-authentication
2026-02-09 18:51:53 +04:00
Reinier van der Leer
6467f6734f debug(backend/chat): Add timing logging to chat stream generation mechanism (#12019)
[SECRT-1912: Investigate & eliminate chat session start
latency](https://linear.app/autogpt/issue/SECRT-1912)

### Changes 🏗️

- Add timing logs to `backend.api.features.chat` in `routes.py`,
`service.py`, and `stream_registry.py`
- Remove unneeded DB join in `create_chat_session`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - CI checks
2026-02-09 14:05:29 +00:00
Zamil Majdy
6c2791b00b fix(frontend/mcp): Robust OAuth callback with localStorage fallback and popup close detection
BroadcastChannel can silently fail in some browser scenarios. Added:
- localStorage as third communication method in callback page
- storage event listener in dialog
- Popup close detection that checks localStorage directly
- Cleaned up auth-required box styling (gray instead of amber)
2026-02-09 17:52:02 +04:00
Otto
5a30d11416 refactor(copilot): Code cleanup and deduplication (#11950)
## Summary

Code cleanup of the AI Copilot codebase - rebased onto latest dev.

## Changes

### New Files
- `backend/util/validation.py` - UUID validation helpers
- `backend/api/features/chat/tools/helpers.py` - Shared tool utilities

### Credential Matching Consolidation  
- Added shared utilities to `utils.py`
- Refactored `run_block._check_block_credentials()` with discriminator
support
- Extracted `_resolve_discriminated_credentials()` for multi-provider
handling

### Routes Cleanup
- Extracted `_create_stream_generator()` and `SSE_RESPONSE_HEADERS`

### Tool Files Cleanup
- Updated `run_agent.py` and `run_block.py` to use shared helpers

**WIP** - This PR will be updated incrementally.
2026-02-09 13:43:55 +00:00
Zamil Majdy
7decc20a32 fix(backend/mcp): Auto-refresh expired OAuth tokens before MCP tool calls
_resolve_auth_token now checks token expiry and refreshes using
MCPOAuthHandler with metadata (token_url, client_id, client_secret)
stored during the OAuth callback flow.
2026-02-09 17:37:24 +04:00
Zamil Majdy
54375065d5 fix(mcp): Reshape execution input for MCPToolBlock like AgentExecutorBlock
The dynamic get_input_defaults returns only tool_arguments, so the
execution engine loses block-level fields like server_url. Reconstruct
the full Input from node.input_default and set tool_arguments from the
resolved dynamic input, matching the AgentExecutorBlock pattern.
2026-02-09 14:51:49 +04:00
Zamil Majdy
d62fde9445 fix(mcp): Use manual credential resolution instead of CredentialsField
The block framework's CredentialsField requires credentials to always be
present, which doesn't work for public MCP servers. Replace it with a
plain credential_id field and manual resolution from the credential store,
allowing both authenticated and public MCP servers to work seamlessly.
2026-02-09 14:41:14 +04:00
Zamil Majdy
03487f7b4d fix(frontend/mcp): Remove broken credentials widget and disable auto-connect
- Remove credentials field from MCP dynamic schema since auth is handled
  by the dialog's OAuth flow (the standard credentials widget doesn't
  support MCP as a provider and fails with 404)
- Simplify FormCreator MCP handling — all form fields are tool arguments
- Disable auto-connect on dialog open; pre-fill last URL instead so user
  can edit before connecting
2026-02-09 14:27:36 +04:00
Bently
1f4105e8f9 fix(frontend): Handle object values in FileInput component (#11948)
Fixes
[#11800](https://github.com/Significant-Gravitas/AutoGPT/issues/11800)

## Problem
The FileInput component crashed with `TypeError: e.startsWith is not a
function` when the value was an object (from external API) instead of a
string.

## Example Input Object
When using the external API
(`/external-api/v1/graphs/{id}/execute/{version}`), file inputs can be
passed as objects:

```json
{
  "node_input": {
    "input_image": {
      "name": "image.jpeg",
      "type": "image/jpeg",
      "size": 131147,
      "data": "/9j/4QAW..."
    }
  }
}
```

## Changes
- Updated `getFileLabelFromValue()` to handle object format: `{ name,
type, size, data }`
- Added type guards for string vs object values
- Graceful fallback for edge cases (null, undefined, empty object)

## Test cases verified
- Object with name: returns filename
- Object with type only: extracts and formats MIME type
- String data URI: parses correctly
- String file path: extracts extension
- Edge cases: returns "File" fallback
2026-02-09 10:25:08 +00:00
Bently
caf9ff34e6 fix(backend): Handle stale RabbitMQ channels on connection drop (#11929)
### Changes 🏗️

Fixes
[**AUTOGPT-SERVER-1TN**](https://autoagpt.sentry.io/issues/?query=AUTOGPT-SERVER-1TN)
(~39K events since Feb 2025) and related connection issues
**6JC/6JD/6JE/6JF** (~6K combined).

#### Problem

When the RabbitMQ TCP connection drops (network blip, server restart,
etc.):

1. `connect_robust` (aio_pika) automatically reconnects the underlying
AMQP connection
2. But `AsyncRabbitMQ._channel` still references the **old dead
channel**
3. `is_ready` checks `not self._channel.is_closed` — but the channel
object doesn't know the transport is gone
4. `publish_message` tries to use the stale channel →
`ChannelInvalidStateError: No active transport in channel`
5. `@func_retry` retries 5 times, but each retry hits the same stale
channel (it passes `is_ready`)

This means every connection drop generates errors until the process is
restarted.

#### Fix

**New `_ensure_channel()` helper** that resets stale channels before
reconnecting, so `connect()` creates a fresh one instead of
short-circuiting on `is_connected`.

**Explicit `ChannelInvalidStateError` handling in `publish_message`:**
1. First attempt uses `_ensure_channel()` (handles normal staleness)
2. If publish throws `ChannelInvalidStateError`, does a full reconnect
(resets both `_channel` and `_connection`) and retries once
3. `@func_retry` provides additional retry resilience on top

**Simplified `get_channel()`** to use the same resilient helper.

**1 file changed, 62 insertions, 24 deletions.**

#### Impact
- Eliminates ~39K `ChannelInvalidStateError` Sentry events
- RabbitMQ operations self-heal after connection drops without process
restart
- Related transport EOF errors (6JC/6JD/6JE/6JF) should also reduce
2026-02-09 10:24:08 +00:00
Zamil Majdy
df41d02fce feat(frontend/mcp): Add MCP tool discovery UI, OAuth flow, and dynamic block schema
- Add MCPToolDialog with tool discovery, OAuth sign-in, and card-based tool selection
- Add OAuth callback route using BroadcastChannel API for popup communication
- Add API client methods for MCP discovery, OAuth login, and callback
- Register MCP API routes on the backend REST API
- Render dynamic input schema for MCP blocks (credentials + tool params)
  in both legacy and new builder CustomNode components
- Nest MCP tool argument values under tool_arguments in hardcodedValues
- Display tool name with server name prefix in block header
- Add backend route tests for discovery, OAuth login, and callback endpoints
2026-02-09 14:18:59 +04:00
Otto
7c9e47ba76 fix(mcp): Remove redundant exception handling and unnecessary str() cast
- client.py: except (ValueError, Exception) → except Exception
  (Exception already catches ValueError, so it's redundant)
- oauth.py: SecretStr(str(tokens[...])) → SecretStr(tokens[...])
  (refresh_token is already a string, no cast needed)
2026-02-09 08:40:58 +00:00
Nicholas Tindle
e8fc8ee623 fix(backend): filter graph-only blocks from CoPilot's find_block results (#11892)
Filters out blocks that are unsuitable for standalone execution from
CoPilot's block search and execution. These blocks serve graph-specific
purposes and will either fail, hang, or confuse users when run outside
of a graph context.

**Important:** This does NOT affect the Builder UI which uses
`load_all_blocks()` directly.

### Changes 🏗️

- **find_block.py**: Added `EXCLUDED_BLOCK_TYPES` and
`EXCLUDED_BLOCK_IDS` constants, skip excluded blocks in search results
- **run_block.py**: Added execution guard that returns clear error
message for excluded blocks
- **content_handlers.py**: Added filtering to
`BlockHandler.get_missing_items()` and `get_stats()` to prevent indexing
excluded blocks

**Excluded by BlockType:**
| BlockType | Reason |
|-----------|--------|
| `INPUT` | Graph interface definition - data enters via chat, not graph
inputs |
| `OUTPUT` | Graph interface definition - data exits via chat, not graph
outputs |
| `WEBHOOK` | Wait for external events - would hang forever in CoPilot |
| `WEBHOOK_MANUAL` | Same as WEBHOOK |
| `NOTE` | Visual annotation only - no runtime behavior |
| `HUMAN_IN_THE_LOOP` | Pauses for human approval - CoPilot IS
human-in-the-loop |
| `AGENT` | AgentExecutorBlock requires graph context - use `run_agent`
tool instead |

**Excluded by ID:**
| Block | Reason |
|-------|--------|
| `SmartDecisionMakerBlock` | Dynamically discovers downstream blocks
via graph topology |

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [ ] Search for "input" in CoPilot - should NOT return AgentInputBlock
variants
- [ ] Search for "output" in CoPilot - should NOT return
AgentOutputBlock
- [ ] Search for "webhook" in CoPilot - should NOT return trigger blocks
- [ ] Search for "human" in CoPilot - should NOT return
HumanInTheLoopBlock
- [ ] Search for "decision" in CoPilot - should NOT return
SmartDecisionMakerBlock
- [ ] Verify functional blocks still appear (e.g., "email", "http",
"text")
  - [ ] Verify Builder UI still shows ALL blocks (no regression)

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

No configuration changes required.

---

Resolves: [SECRT-1831](https://linear.app/autogpt/issue/SECRT-1831)

🤖 Generated with [Claude Code](https://claude.ai/code)

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Low Risk**
> Behavior change is limited to CoPilot’s block discovery/execution
guards and is covered by new tests; main risk is inadvertently excluding
a block that should be runnable.
> 
> **Overview**
> CoPilot now **filters out graph-only blocks** from `find_block`
results and prevents them from being executed via `run_block`, returning
a clear error when a user attempts to run an excluded block.
> 
> `find_block` introduces explicit exclusion lists (by `BlockType` and a
specific block ID), over-fetches search results to maintain up to 10
usable matches after filtering, and adds debug logging when results are
reduced. New unit tests cover both the search filtering and the
`run_block` execution guard; a minor cleanup removes an unused `pytest`
import in `execution_queue_test.py`.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
bc50755dcf. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <ntindle@users.noreply.github.com>
Co-authored-by: Otto <otto@agpt.co>
2026-02-09 07:19:43 +00:00
dependabot[bot]
1a16e203b8 chore(deps): Bump actions/setup-node from 4 to 6 (#11213)
Bumps [actions/setup-node](https://github.com/actions/setup-node) from 4
to 6.
<details>
<summary>Release notes</summary>
<p><em>Sourced from <a
href="https://github.com/actions/setup-node/releases">actions/setup-node's
releases</a>.</em></p>
<blockquote>
<h2>v6.0.0</h2>
<h2>What's Changed</h2>
<p><strong>Breaking Changes</strong></p>
<ul>
<li>Limit automatic caching to npm, update workflows and documentation
by <a
href="https://github.com/priyagupta108"><code>@​priyagupta108</code></a>
in <a
href="https://redirect.github.com/actions/setup-node/pull/1374">actions/setup-node#1374</a></li>
</ul>
<p><strong>Dependency Upgrades</strong></p>
<ul>
<li>Upgrade ts-jest from 29.1.2 to 29.4.1 and document breaking changes
in v5 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a>[bot]
in <a
href="https://redirect.github.com/actions/setup-node/pull/1336">#1336</a></li>
<li>Upgrade prettier from 2.8.8 to 3.6.2 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a>[bot]
in <a
href="https://redirect.github.com/actions/setup-node/pull/1334">#1334</a></li>
<li>Upgrade actions/publish-action from 0.3.0 to 0.4.0 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a>[bot]
in <a
href="https://redirect.github.com/actions/setup-node/pull/1362">#1362</a></li>
</ul>
<p><strong>Full Changelog</strong>: <a
href="https://github.com/actions/setup-node/compare/v5...v6.0.0">https://github.com/actions/setup-node/compare/v5...v6.0.0</a></p>
<h2>v5.0.0</h2>
<h2>What's Changed</h2>
<h3>Breaking Changes</h3>
<ul>
<li>Enhance caching in setup-node with automatic package manager
detection by <a
href="https://github.com/priya-kinthali"><code>@​priya-kinthali</code></a>
in <a
href="https://redirect.github.com/actions/setup-node/pull/1348">actions/setup-node#1348</a></li>
</ul>
<p>This update, introduces automatic caching when a valid
<code>packageManager</code> field is present in your
<code>package.json</code>. This aims to improve workflow performance and
make dependency management more seamless.
To disable this automatic caching, set <code>package-manager-cache:
false</code></p>
<pre lang="yaml"><code>steps:
- uses: actions/checkout@v5
- uses: actions/setup-node@v5
  with:
    package-manager-cache: false
</code></pre>
<ul>
<li>Upgrade action to use node24 by <a
href="https://github.com/salmanmkc"><code>@​salmanmkc</code></a> in <a
href="https://redirect.github.com/actions/setup-node/pull/1325">actions/setup-node#1325</a></li>
</ul>
<p>Make sure your runner is on version v2.327.1 or later to ensure
compatibility with this release. <a
href="https://github.com/actions/runner/releases/tag/v2.327.1">See
Release Notes</a></p>
<h3>Dependency Upgrades</h3>
<ul>
<li>Upgrade <code>@​octokit/request-error</code> and
<code>@​actions/github</code> by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a>[bot]
in <a
href="https://redirect.github.com/actions/setup-node/pull/1227">actions/setup-node#1227</a></li>
<li>Upgrade uuid from 9.0.1 to 11.1.0 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a>[bot]
in <a
href="https://redirect.github.com/actions/setup-node/pull/1273">actions/setup-node#1273</a></li>
<li>Upgrade undici from 5.28.5 to 5.29.0 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a>[bot]
in <a
href="https://redirect.github.com/actions/setup-node/pull/1295">actions/setup-node#1295</a></li>
<li>Upgrade form-data to bring in fix for critical vulnerability by <a
href="https://github.com/gowridurgad"><code>@​gowridurgad</code></a> in
<a
href="https://redirect.github.com/actions/setup-node/pull/1332">actions/setup-node#1332</a></li>
<li>Upgrade actions/checkout from 4 to 5 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a>[bot]
in <a
href="https://redirect.github.com/actions/setup-node/pull/1345">actions/setup-node#1345</a></li>
</ul>
<h2>New Contributors</h2>
<ul>
<li><a
href="https://github.com/priya-kinthali"><code>@​priya-kinthali</code></a>
made their first contribution in <a
href="https://redirect.github.com/actions/setup-node/pull/1348">actions/setup-node#1348</a></li>
<li><a href="https://github.com/salmanmkc"><code>@​salmanmkc</code></a>
made their first contribution in <a
href="https://redirect.github.com/actions/setup-node/pull/1325">actions/setup-node#1325</a></li>
</ul>
<p><strong>Full Changelog</strong>: <a
href="https://github.com/actions/setup-node/compare/v4...v5.0.0">https://github.com/actions/setup-node/compare/v4...v5.0.0</a></p>
<h2>v4.4.0</h2>
<!-- raw HTML omitted -->
</blockquote>
<p>... (truncated)</p>
</details>
<details>
<summary>Commits</summary>
<ul>
<li><a
href="2028fbc5c2"><code>2028fbc</code></a>
Limit automatic caching to npm, update workflows and documentation (<a
href="https://redirect.github.com/actions/setup-node/issues/1374">#1374</a>)</li>
<li><a
href="13427813f7"><code>1342781</code></a>
Bump actions/publish-action from 0.3.0 to 0.4.0 (<a
href="https://redirect.github.com/actions/setup-node/issues/1362">#1362</a>)</li>
<li><a
href="89d709d423"><code>89d709d</code></a>
Bump prettier from 2.8.8 to 3.6.2 (<a
href="https://redirect.github.com/actions/setup-node/issues/1334">#1334</a>)</li>
<li><a
href="cd2651c462"><code>cd2651c</code></a>
Bump ts-jest from 29.1.2 to 29.4.1 (<a
href="https://redirect.github.com/actions/setup-node/issues/1336">#1336</a>)</li>
<li><a
href="a0853c2454"><code>a0853c2</code></a>
Bump actions/checkout from 4 to 5 (<a
href="https://redirect.github.com/actions/setup-node/issues/1345">#1345</a>)</li>
<li><a
href="b7234cc9fe"><code>b7234cc</code></a>
Upgrade action to use node24 (<a
href="https://redirect.github.com/actions/setup-node/issues/1325">#1325</a>)</li>
<li><a
href="d7a11313b5"><code>d7a1131</code></a>
Enhance caching in setup-node with automatic package manager detection
(<a
href="https://redirect.github.com/actions/setup-node/issues/1348">#1348</a>)</li>
<li><a
href="5e2628c959"><code>5e2628c</code></a>
Bumps form-data (<a
href="https://redirect.github.com/actions/setup-node/issues/1332">#1332</a>)</li>
<li><a
href="65beceff8e"><code>65becef</code></a>
Bump undici from 5.28.5 to 5.29.0 (<a
href="https://redirect.github.com/actions/setup-node/issues/1295">#1295</a>)</li>
<li><a
href="7e24a656e1"><code>7e24a65</code></a>
Bump uuid from 9.0.1 to 11.1.0 (<a
href="https://redirect.github.com/actions/setup-node/issues/1273">#1273</a>)</li>
<li>Additional commits viewable in <a
href="https://github.com/actions/setup-node/compare/v4...v6">compare
view</a></li>
</ul>
</details>
<br />


[![Dependabot compatibility
score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=actions/setup-node&package-manager=github_actions&previous-version=4&new-version=6)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)

You can trigger a rebase of this PR by commenting `@dependabot rebase`.

[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)

---

<details>
<summary>Dependabot commands and options</summary>
<br />

You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after
your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge
and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating
it. You can achieve the same result by closing it manually
- `@dependabot show <dependency name> ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)


</details>

> **Note**
> Automatic rebases have been disabled on this pull request as it has
been open for over 30 days.

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Nick Tindle <nick@ntindle.com>
2026-02-09 07:11:21 +00:00
Zamil Majdy
e59e8dd9a9 fix(mcp): Skip e2e tests in CI unless --run-e2e is passed
E2e tests hit a real external MCP server and are inherently flaky.
Skip them by default, require --run-e2e flag to opt in.
2026-02-09 10:14:35 +04:00
Zamil Majdy
7aab2eb1d5 style(mcp): Apply linter formatting 2026-02-08 20:00:44 +04:00
Zamil Majdy
5ab28ccda2 fix(mcp): Fix pyright errors in test files
- Add type: ignore for aiohttp private _server.sockets access
- Add assert not None before accessing Optional refresh_token
2026-02-08 19:52:52 +04:00
Zamil Majdy
4fe0f05980 docs: Sync block documentation for MCP Tool block 2026-02-08 19:37:49 +04:00
Zamil Majdy
19b3373052 fix(mcp): Address PR review comments
- Fix get_missing_input/get_mismatch_error to validate tool_arguments
  dict instead of the entire BlockInput data (critical bug)
- Add type check for non-dict JSON-RPC error field in client.py
- Add try/catch for non-JSON responses in client.py
- Add raise_for_status and error payload checks to OAuth token requests
- Remove hardcoded placeholder skip-list from _extract_auth_token
- Fix server start timeout check in integration tests
- Remove unused MCPTool import, move execute_block_test to top-level
- Update tests to match fixed validation behavior
- Fix MCP_BLOCK_IMPLEMENTATION.md (remove duplicate section, local path)
- Soften PKCE comment in oauth.py
2026-02-08 19:34:28 +04:00
Zamil Majdy
7db3f12876 feat(backend/blocks/mcp): Add SSE support, OAuth auth, and e2e tests
- Handle text/event-stream (SSE) responses from real MCP servers
  (MCPClient._parse_sse_response) alongside plain JSON responses
- Add e2e tests against OpenAI docs MCP server (developers.openai.com/mcp)
  verifying SSE parsing, tool discovery, and tool execution work with a
  real production MCP server
- Support both api_key and oauth2 credential types on MCPToolBlock
  (MCPCredentials union type, _extract_auth_token helper)
- Add MCPOAuthHandler implementing BaseOAuthHandler with dynamic
  endpoints (authorize_url, token_url) for MCP OAuth 2.1 with PKCE
- Add OAuth metadata discovery to MCPClient (discover_auth,
  discover_auth_server_metadata) per RFC 9728 / RFC 8414
- 76 total tests: 46 unit, 11 OAuth, 14 integration, 5 e2e
2026-02-08 16:32:50 +04:00
Zamil Majdy
e9b996abb0 feat(backend/blocks): Add integration tests and trusted_origins support
- Add a test MCP server (test_server.py) for integration testing
- Add 14 integration tests that hit a real local MCP server over HTTP
- Add trusted_origins support to MCPClient for localhost/internal servers
- MCPToolBlock now trusts the user-configured server URL by default
- Add local conftest.py to avoid SpinTestServer overhead for MCP tests

Test results: 34 unit tests + 14 integration tests = 48 total, all passing
2026-02-08 13:49:44 +04:00
Zamil Majdy
9b972389a0 feat(backend/blocks): Add MCP (Model Context Protocol) tool block
Add a dynamic MCPToolBlock that can connect to any MCP server, discover
available tools, and execute them with dynamically generated input/output
schemas. This follows the same pattern as AgentExecutorBlock for dynamic
schema handling.

New files:
- backend/blocks/mcp/client.py — MCP Streamable HTTP client (JSON-RPC 2.0)
- backend/blocks/mcp/block.py — MCPToolBlock with dynamic schema
- backend/blocks/mcp/test_mcp.py — 34 tests covering client + block
- MCP_BLOCK_IMPLEMENTATION.md — Design document

Modified files:
- backend/integrations/providers.py — Add MCP provider name
2026-02-08 12:49:28 +04:00
234 changed files with 18129 additions and 9592 deletions

View File

@@ -78,7 +78,7 @@ jobs:
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml) # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js - name: Set up Node.js
uses: actions/setup-node@v4 uses: actions/setup-node@v6
with: with:
node-version: "22" node-version: "22"

View File

@@ -94,7 +94,7 @@ jobs:
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml) # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js - name: Set up Node.js
uses: actions/setup-node@v4 uses: actions/setup-node@v6
with: with:
node-version: "22" node-version: "22"

View File

@@ -76,7 +76,7 @@ jobs:
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml) # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js - name: Set up Node.js
uses: actions/setup-node@v4 uses: actions/setup-node@v6
with: with:
node-version: "22" node-version: "22"

View File

@@ -42,7 +42,7 @@ jobs:
- 'autogpt_platform/frontend/src/components/**' - 'autogpt_platform/frontend/src/components/**'
- name: Set up Node.js - name: Set up Node.js
uses: actions/setup-node@v4 uses: actions/setup-node@v6
with: with:
node-version: "22.18.0" node-version: "22.18.0"
@@ -74,7 +74,7 @@ jobs:
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Set up Node.js - name: Set up Node.js
uses: actions/setup-node@v4 uses: actions/setup-node@v6
with: with:
node-version: "22.18.0" node-version: "22.18.0"
@@ -112,7 +112,7 @@ jobs:
fetch-depth: 0 fetch-depth: 0
- name: Set up Node.js - name: Set up Node.js
uses: actions/setup-node@v4 uses: actions/setup-node@v6
with: with:
node-version: "22.18.0" node-version: "22.18.0"
@@ -153,7 +153,7 @@ jobs:
submodules: recursive submodules: recursive
- name: Set up Node.js - name: Set up Node.js
uses: actions/setup-node@v4 uses: actions/setup-node@v6
with: with:
node-version: "22.18.0" node-version: "22.18.0"
@@ -282,7 +282,7 @@ jobs:
submodules: recursive submodules: recursive
- name: Set up Node.js - name: Set up Node.js
uses: actions/setup-node@v4 uses: actions/setup-node@v6
with: with:
node-version: "22.18.0" node-version: "22.18.0"

View File

@@ -32,7 +32,7 @@ jobs:
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Set up Node.js - name: Set up Node.js
uses: actions/setup-node@v4 uses: actions/setup-node@v6
with: with:
node-version: "22.18.0" node-version: "22.18.0"
@@ -56,7 +56,7 @@ jobs:
run: pnpm install --frozen-lockfile run: pnpm install --frozen-lockfile
types: types:
runs-on: ubuntu-latest runs-on: big-boi
needs: setup needs: setup
strategy: strategy:
fail-fast: false fail-fast: false
@@ -68,7 +68,7 @@ jobs:
submodules: recursive submodules: recursive
- name: Set up Node.js - name: Set up Node.js
uses: actions/setup-node@v4 uses: actions/setup-node@v6
with: with:
node-version: "22.18.0" node-version: "22.18.0"
@@ -85,7 +85,7 @@ jobs:
- name: Run docker compose - name: Run docker compose
run: | run: |
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d docker compose -f ../docker-compose.yml --profile local up -d deps_backend
- name: Restore dependencies cache - name: Restore dependencies cache
uses: actions/cache@v5 uses: actions/cache@v5

View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. # This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
[[package]] [[package]]
name = "annotated-doc" name = "annotated-doc"
@@ -67,7 +67,7 @@ description = "Backport of asyncio.Runner, a context manager that controls event
optional = false optional = false
python-versions = "<3.11,>=3.8" python-versions = "<3.11,>=3.8"
groups = ["dev"] groups = ["dev"]
markers = "python_version == \"3.10\"" markers = "python_version < \"3.11\""
files = [ files = [
{file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"}, {file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"},
{file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"}, {file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"},
@@ -326,100 +326,118 @@ files = [
[[package]] [[package]]
name = "coverage" name = "coverage"
version = "7.10.5" version = "7.13.4"
description = "Code coverage measurement for Python" description = "Code coverage measurement for Python"
optional = false optional = false
python-versions = ">=3.9" python-versions = ">=3.10"
groups = ["dev"] groups = ["dev"]
files = [ files = [
{file = "coverage-7.10.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c6a5c3414bfc7451b879141ce772c546985163cf553f08e0f135f0699a911801"}, {file = "coverage-7.13.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0fc31c787a84f8cd6027eba44010517020e0d18487064cd3d8968941856d1415"},
{file = "coverage-7.10.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bc8e4d99ce82f1710cc3c125adc30fd1487d3cf6c2cd4994d78d68a47b16989a"}, {file = "coverage-7.13.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a32ebc02a1805adf637fc8dec324b5cdacd2e493515424f70ee33799573d661b"},
{file = "coverage-7.10.5-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:02252dc1216e512a9311f596b3169fad54abcb13827a8d76d5630c798a50a754"}, {file = "coverage-7.13.4-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:e24f9156097ff9dc286f2f913df3a7f63c0e333dcafa3c196f2c18b4175ca09a"},
{file = "coverage-7.10.5-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:73269df37883e02d460bee0cc16be90509faea1e3bd105d77360b512d5bb9c33"}, {file = "coverage-7.13.4-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8041b6c5bfdc03257666e9881d33b1abc88daccaf73f7b6340fb7946655cd10f"},
{file = "coverage-7.10.5-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1f8a81b0614642f91c9effd53eec284f965577591f51f547a1cbeb32035b4c2f"}, {file = "coverage-7.13.4-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2a09cfa6a5862bc2fc6ca7c3def5b2926194a56b8ab78ffcf617d28911123012"},
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6a29f8e0adb7f8c2b95fa2d4566a1d6e6722e0a637634c6563cb1ab844427dd9"}, {file = "coverage-7.13.4-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:296f8b0af861d3970c2a4d8c91d48eb4dd4771bcef9baedec6a9b515d7de3def"},
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:fcf6ab569436b4a647d4e91accba12509ad9f2554bc93d3aee23cc596e7f99c3"}, {file = "coverage-7.13.4-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e101609bcbbfb04605ea1027b10dc3735c094d12d40826a60f897b98b1c30256"},
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:90dc3d6fb222b194a5de60af8d190bedeeddcbc7add317e4a3cd333ee6b7c879"}, {file = "coverage-7.13.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:aa3feb8db2e87ff5e6d00d7e1480ae241876286691265657b500886c98f38bda"},
{file = "coverage-7.10.5-cp310-cp310-win32.whl", hash = "sha256:414a568cd545f9dc75f0686a0049393de8098414b58ea071e03395505b73d7a8"}, {file = "coverage-7.13.4-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:4fc7fa81bbaf5a02801b65346c8b3e657f1d93763e58c0abdf7c992addd81a92"},
{file = "coverage-7.10.5-cp310-cp310-win_amd64.whl", hash = "sha256:e551f9d03347196271935fd3c0c165f0e8c049220280c1120de0084d65e9c7ff"}, {file = "coverage-7.13.4-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:33901f604424145c6e9c2398684b92e176c0b12df77d52db81c20abd48c3794c"},
{file = "coverage-7.10.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c177e6ffe2ebc7c410785307758ee21258aa8e8092b44d09a2da767834f075f2"}, {file = "coverage-7.13.4-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:bb28c0f2cf2782508a40cec377935829d5fcc3ad9a3681375af4e84eb34b6b58"},
{file = "coverage-7.10.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:14d6071c51ad0f703d6440827eaa46386169b5fdced42631d5a5ac419616046f"}, {file = "coverage-7.13.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:9d107aff57a83222ddbd8d9ee705ede2af2cc926608b57abed8ef96b50b7e8f9"},
{file = "coverage-7.10.5-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:61f78c7c3bc272a410c5ae3fde7792b4ffb4acc03d35a7df73ca8978826bb7ab"}, {file = "coverage-7.13.4-cp310-cp310-win32.whl", hash = "sha256:a6f94a7d00eb18f1b6d403c91a88fd58cfc92d4b16080dfdb774afc8294469bf"},
{file = "coverage-7.10.5-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f39071caa126f69d63f99b324fb08c7b1da2ec28cbb1fe7b5b1799926492f65c"}, {file = "coverage-7.13.4-cp310-cp310-win_amd64.whl", hash = "sha256:2cb0f1e000ebc419632bbe04366a8990b6e32c4e0b51543a6484ffe15eaeda95"},
{file = "coverage-7.10.5-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:343a023193f04d46edc46b2616cdbee68c94dd10208ecd3adc56fcc54ef2baa1"}, {file = "coverage-7.13.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d490ba50c3f35dd7c17953c68f3270e7ccd1c6642e2d2afe2d8e720b98f5a053"},
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:585ffe93ae5894d1ebdee69fc0b0d4b7c75d8007983692fb300ac98eed146f78"}, {file = "coverage-7.13.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:19bc3c88078789f8ef36acb014d7241961dbf883fd2533d18cb1e7a5b4e28b11"},
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b0ef4e66f006ed181df29b59921bd8fc7ed7cd6a9289295cd8b2824b49b570df"}, {file = "coverage-7.13.4-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3998e5a32e62fdf410c0dbd3115df86297995d6e3429af80b8798aad894ca7aa"},
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:eb7b0bbf7cc1d0453b843eca7b5fa017874735bef9bfdfa4121373d2cc885ed6"}, {file = "coverage-7.13.4-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8e264226ec98e01a8e1054314af91ee6cde0eacac4f465cc93b03dbe0bce2fd7"},
{file = "coverage-7.10.5-cp311-cp311-win32.whl", hash = "sha256:1d043a8a06987cc0c98516e57c4d3fc2c1591364831e9deb59c9e1b4937e8caf"}, {file = "coverage-7.13.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a3aa4e7b9e416774b21797365b358a6e827ffadaaca81b69ee02946852449f00"},
{file = "coverage-7.10.5-cp311-cp311-win_amd64.whl", hash = "sha256:fefafcca09c3ac56372ef64a40f5fe17c5592fab906e0fdffd09543f3012ba50"}, {file = "coverage-7.13.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:71ca20079dd8f27fcf808817e281e90220475cd75115162218d0e27549f95fef"},
{file = "coverage-7.10.5-cp311-cp311-win_arm64.whl", hash = "sha256:7e78b767da8b5fc5b2faa69bb001edafcd6f3995b42a331c53ef9572c55ceb82"}, {file = "coverage-7.13.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e2f25215f1a359ab17320b47bcdaca3e6e6356652e8256f2441e4ef972052903"},
{file = "coverage-7.10.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c2d05c7e73c60a4cecc7d9b60dbfd603b4ebc0adafaef371445b47d0f805c8a9"}, {file = "coverage-7.13.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d65b2d373032411e86960604dc4edac91fdfb5dca539461cf2cbe78327d1e64f"},
{file = "coverage-7.10.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:32ddaa3b2c509778ed5373b177eb2bf5662405493baeff52278a0b4f9415188b"}, {file = "coverage-7.13.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:94eb63f9b363180aff17de3e7c8760c3ba94664ea2695c52f10111244d16a299"},
{file = "coverage-7.10.5-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:dd382410039fe062097aa0292ab6335a3f1e7af7bba2ef8d27dcda484918f20c"}, {file = "coverage-7.13.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e856bf6616714c3a9fbc270ab54103f4e685ba236fa98c054e8f87f266c93505"},
{file = "coverage-7.10.5-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7fa22800f3908df31cea6fb230f20ac49e343515d968cc3a42b30d5c3ebf9b5a"}, {file = "coverage-7.13.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:65dfcbe305c3dfe658492df2d85259e0d79ead4177f9ae724b6fb245198f55d6"},
{file = "coverage-7.10.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f366a57ac81f5e12797136552f5b7502fa053c861a009b91b80ed51f2ce651c6"}, {file = "coverage-7.13.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b507778ae8a4c915436ed5c2e05b4a6cecfa70f734e19c22a005152a11c7b6a9"},
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5f1dc8f1980a272ad4a6c84cba7981792344dad33bf5869361576b7aef42733a"}, {file = "coverage-7.13.4-cp311-cp311-win32.whl", hash = "sha256:784fc3cf8be001197b652d51d3fd259b1e2262888693a4636e18879f613a62a9"},
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:2285c04ee8676f7938b02b4936d9b9b672064daab3187c20f73a55f3d70e6b4a"}, {file = "coverage-7.13.4-cp311-cp311-win_amd64.whl", hash = "sha256:2421d591f8ca05b308cf0092807308b2facbefe54af7c02ac22548b88b95c98f"},
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c2492e4dd9daab63f5f56286f8a04c51323d237631eb98505d87e4c4ff19ec34"}, {file = "coverage-7.13.4-cp311-cp311-win_arm64.whl", hash = "sha256:79e73a76b854d9c6088fe5d8b2ebe745f8681c55f7397c3c0a016192d681045f"},
{file = "coverage-7.10.5-cp312-cp312-win32.whl", hash = "sha256:38a9109c4ee8135d5df5505384fc2f20287a47ccbe0b3f04c53c9a1989c2bbaf"}, {file = "coverage-7.13.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:02231499b08dabbe2b96612993e5fc34217cdae907a51b906ac7fca8027a4459"},
{file = "coverage-7.10.5-cp312-cp312-win_amd64.whl", hash = "sha256:6b87f1ad60b30bc3c43c66afa7db6b22a3109902e28c5094957626a0143a001f"}, {file = "coverage-7.13.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40aa8808140e55dc022b15d8aa7f651b6b3d68b365ea0398f1441e0b04d859c3"},
{file = "coverage-7.10.5-cp312-cp312-win_arm64.whl", hash = "sha256:672a6c1da5aea6c629819a0e1461e89d244f78d7b60c424ecf4f1f2556c041d8"}, {file = "coverage-7.13.4-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:5b856a8ccf749480024ff3bd7310adaef57bf31fd17e1bfc404b7940b6986634"},
{file = "coverage-7.10.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ef3b83594d933020f54cf65ea1f4405d1f4e41a009c46df629dd964fcb6e907c"}, {file = "coverage-7.13.4-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:2c048ea43875fbf8b45d476ad79f179809c590ec7b79e2035c662e7afa3192e3"},
{file = "coverage-7.10.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2b96bfdf7c0ea9faebce088a3ecb2382819da4fbc05c7b80040dbc428df6af44"}, {file = "coverage-7.13.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b7b38448866e83176e28086674fe7368ab8590e4610fb662b44e345b86d63ffa"},
{file = "coverage-7.10.5-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:63df1fdaffa42d914d5c4d293e838937638bf75c794cf20bee12978fc8c4e3bc"}, {file = "coverage-7.13.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:de6defc1c9badbf8b9e67ae90fd00519186d6ab64e5cc5f3d21359c2a9b2c1d3"},
{file = "coverage-7.10.5-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8002dc6a049aac0e81ecec97abfb08c01ef0c1fbf962d0c98da3950ace89b869"}, {file = "coverage-7.13.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:7eda778067ad7ffccd23ecffce537dface96212576a07924cbf0d8799d2ded5a"},
{file = "coverage-7.10.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:63d4bb2966d6f5f705a6b0c6784c8969c468dbc4bcf9d9ded8bff1c7e092451f"}, {file = "coverage-7.13.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e87f6c587c3f34356c3759f0420693e35e7eb0e2e41e4c011cb6ec6ecbbf1db7"},
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:1f672efc0731a6846b157389b6e6d5d5e9e59d1d1a23a5c66a99fd58339914d5"}, {file = "coverage-7.13.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:8248977c2e33aecb2ced42fef99f2d319e9904a36e55a8a68b69207fb7e43edc"},
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:3f39cef43d08049e8afc1fde4a5da8510fc6be843f8dea350ee46e2a26b2f54c"}, {file = "coverage-7.13.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:25381386e80ae727608e662474db537d4df1ecd42379b5ba33c84633a2b36d47"},
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2968647e3ed5a6c019a419264386b013979ff1fb67dd11f5c9886c43d6a31fc2"}, {file = "coverage-7.13.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:ee756f00726693e5ba94d6df2bdfd64d4852d23b09bb0bc700e3b30e6f333985"},
{file = "coverage-7.10.5-cp313-cp313-win32.whl", hash = "sha256:0d511dda38595b2b6934c2b730a1fd57a3635c6aa2a04cb74714cdfdd53846f4"}, {file = "coverage-7.13.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fdfc1e28e7c7cdce44985b3043bc13bbd9c747520f94a4d7164af8260b3d91f0"},
{file = "coverage-7.10.5-cp313-cp313-win_amd64.whl", hash = "sha256:9a86281794a393513cf117177fd39c796b3f8e3759bb2764259a2abba5cce54b"}, {file = "coverage-7.13.4-cp312-cp312-win32.whl", hash = "sha256:01d4cbc3c283a17fc1e42d614a119f7f438eabb593391283adca8dc86eff1246"},
{file = "coverage-7.10.5-cp313-cp313-win_arm64.whl", hash = "sha256:cebd8e906eb98bb09c10d1feed16096700b1198d482267f8bf0474e63a7b8d84"}, {file = "coverage-7.13.4-cp312-cp312-win_amd64.whl", hash = "sha256:9401ebc7ef522f01d01d45532c68c5ac40fb27113019b6b7d8b208f6e9baa126"},
{file = "coverage-7.10.5-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0520dff502da5e09d0d20781df74d8189ab334a1e40d5bafe2efaa4158e2d9e7"}, {file = "coverage-7.13.4-cp312-cp312-win_arm64.whl", hash = "sha256:b1ec7b6b6e93255f952e27ab58fbc68dcc468844b16ecbee881aeb29b6ab4d8d"},
{file = "coverage-7.10.5-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:d9cd64aca68f503ed3f1f18c7c9174cbb797baba02ca8ab5112f9d1c0328cd4b"}, {file = "coverage-7.13.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b66a2da594b6068b48b2692f043f35d4d3693fb639d5ea8b39533c2ad9ac3ab9"},
{file = "coverage-7.10.5-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:0913dd1613a33b13c4f84aa6e3f4198c1a21ee28ccb4f674985c1f22109f0aae"}, {file = "coverage-7.13.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3599eb3992d814d23b35c536c28df1a882caa950f8f507cef23d1cbf334995ac"},
{file = "coverage-7.10.5-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:1b7181c0feeb06ed8a02da02792f42f829a7b29990fef52eff257fef0885d760"}, {file = "coverage-7.13.4-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:93550784d9281e374fb5a12bf1324cc8a963fd63b2d2f223503ef0fd4aa339ea"},
{file = "coverage-7.10.5-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:36d42b7396b605f774d4372dd9c49bed71cbabce4ae1ccd074d155709dd8f235"}, {file = "coverage-7.13.4-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b720ce6a88a2755f7c697c23268ddc47a571b88052e6b155224347389fdf6a3b"},
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b4fdc777e05c4940b297bf47bf7eedd56a39a61dc23ba798e4b830d585486ca5"}, {file = "coverage-7.13.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7b322db1284a2ed3aa28ffd8ebe3db91c929b7a333c0820abec3d838ef5b3525"},
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:42144e8e346de44a6f1dbd0a56575dd8ab8dfa7e9007da02ea5b1c30ab33a7db"}, {file = "coverage-7.13.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f4594c67d8a7c89cf922d9df0438c7c7bb022ad506eddb0fdb2863359ff78242"},
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:66c644cbd7aed8fe266d5917e2c9f65458a51cfe5eeff9c05f15b335f697066e"}, {file = "coverage-7.13.4-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:53d133df809c743eb8bce33b24bcababb371f4441340578cd406e084d94a6148"},
{file = "coverage-7.10.5-cp313-cp313t-win32.whl", hash = "sha256:2d1b73023854068c44b0c554578a4e1ef1b050ed07cf8b431549e624a29a66ee"}, {file = "coverage-7.13.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:76451d1978b95ba6507a039090ba076105c87cc76fc3efd5d35d72093964d49a"},
{file = "coverage-7.10.5-cp313-cp313t-win_amd64.whl", hash = "sha256:54a1532c8a642d8cc0bd5a9a51f5a9dcc440294fd06e9dda55e743c5ec1a8f14"}, {file = "coverage-7.13.4-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7f57b33491e281e962021de110b451ab8a24182589be17e12a22c79047935e23"},
{file = "coverage-7.10.5-cp313-cp313t-win_arm64.whl", hash = "sha256:74d5b63fe3f5f5d372253a4ef92492c11a4305f3550631beaa432fc9df16fcff"}, {file = "coverage-7.13.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:1731dc33dc276dafc410a885cbf5992f1ff171393e48a21453b78727d090de80"},
{file = "coverage-7.10.5-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:68c5e0bc5f44f68053369fa0d94459c84548a77660a5f2561c5e5f1e3bed7031"}, {file = "coverage-7.13.4-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:bd60d4fe2f6fa7dff9223ca1bbc9f05d2b6697bc5961072e5d3b952d46e1b1ea"},
{file = "coverage-7.10.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:cf33134ffae93865e32e1e37df043bef15a5e857d8caebc0099d225c579b0fa3"}, {file = "coverage-7.13.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9181a3ccead280b828fae232df12b16652702b49d41e99d657f46cc7b1f6ec7a"},
{file = "coverage-7.10.5-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:ad8fa9d5193bafcf668231294241302b5e683a0518bf1e33a9a0dfb142ec3031"}, {file = "coverage-7.13.4-cp313-cp313-win32.whl", hash = "sha256:f53d492307962561ac7de4cd1de3e363589b000ab69617c6156a16ba7237998d"},
{file = "coverage-7.10.5-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:146fa1531973d38ab4b689bc764592fe6c2f913e7e80a39e7eeafd11f0ef6db2"}, {file = "coverage-7.13.4-cp313-cp313-win_amd64.whl", hash = "sha256:e6f70dec1cc557e52df5306d051ef56003f74d56e9c4dd7ddb07e07ef32a84dd"},
{file = "coverage-7.10.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6013a37b8a4854c478d3219ee8bc2392dea51602dd0803a12d6f6182a0061762"}, {file = "coverage-7.13.4-cp313-cp313-win_arm64.whl", hash = "sha256:fb07dc5da7e849e2ad31a5d74e9bece81f30ecf5a42909d0a695f8bd1874d6af"},
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:eb90fe20db9c3d930fa2ad7a308207ab5b86bf6a76f54ab6a40be4012d88fcae"}, {file = "coverage-7.13.4-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:40d74da8e6c4b9ac18b15331c4b5ebc35a17069410cad462ad4f40dcd2d50c0d"},
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:384b34482272e960c438703cafe63316dfbea124ac62006a455c8410bf2a2262"}, {file = "coverage-7.13.4-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:4223b4230a376138939a9173f1bdd6521994f2aff8047fae100d6d94d50c5a12"},
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:467dc74bd0a1a7de2bedf8deaf6811f43602cb532bd34d81ffd6038d6d8abe99"}, {file = "coverage-7.13.4-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:1d4be36a5114c499f9f1f9195e95ebf979460dbe2d88e6816ea202010ba1c34b"},
{file = "coverage-7.10.5-cp314-cp314-win32.whl", hash = "sha256:556d23d4e6393ca898b2e63a5bca91e9ac2d5fb13299ec286cd69a09a7187fde"}, {file = "coverage-7.13.4-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:200dea7d1e8095cc6e98cdabe3fd1d21ab17d3cee6dab00cadbb2fe35d9c15b9"},
{file = "coverage-7.10.5-cp314-cp314-win_amd64.whl", hash = "sha256:f4446a9547681533c8fa3e3c6cf62121eeee616e6a92bd9201c6edd91beffe13"}, {file = "coverage-7.13.4-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b8eb931ee8e6d8243e253e5ed7336deea6904369d2fd8ae6e43f68abbf167092"},
{file = "coverage-7.10.5-cp314-cp314-win_arm64.whl", hash = "sha256:5e78bd9cf65da4c303bf663de0d73bf69f81e878bf72a94e9af67137c69b9fe9"}, {file = "coverage-7.13.4-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:75eab1ebe4f2f64d9509b984f9314d4aa788540368218b858dad56dc8f3e5eb9"},
{file = "coverage-7.10.5-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:5661bf987d91ec756a47c7e5df4fbcb949f39e32f9334ccd3f43233bbb65e508"}, {file = "coverage-7.13.4-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c35eb28c1d085eb7d8c9b3296567a1bebe03ce72962e932431b9a61f28facf26"},
{file = "coverage-7.10.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:a46473129244db42a720439a26984f8c6f834762fc4573616c1f37f13994b357"}, {file = "coverage-7.13.4-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:eb88b316ec33760714a4720feb2816a3a59180fd58c1985012054fa7aebee4c2"},
{file = "coverage-7.10.5-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:1f64b8d3415d60f24b058b58d859e9512624bdfa57a2d1f8aff93c1ec45c429b"}, {file = "coverage-7.13.4-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:7d41eead3cc673cbd38a4417deb7fd0b4ca26954ff7dc6078e33f6ff97bed940"},
{file = "coverage-7.10.5-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:44d43de99a9d90b20e0163f9770542357f58860a26e24dc1d924643bd6aa7cb4"}, {file = "coverage-7.13.4-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:fb26a934946a6afe0e326aebe0730cdff393a8bc0bbb65a2f41e30feddca399c"},
{file = "coverage-7.10.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a931a87e5ddb6b6404e65443b742cb1c14959622777f2a4efd81fba84f5d91ba"}, {file = "coverage-7.13.4-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:dae88bc0fc77edaa65c14be099bd57ee140cf507e6bfdeea7938457ab387efb0"},
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:f9559b906a100029274448f4c8b8b0a127daa4dade5661dfd821b8c188058842"}, {file = "coverage-7.13.4-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:845f352911777a8e722bfce168958214951e07e47e5d5d9744109fa5fe77f79b"},
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:b08801e25e3b4526ef9ced1aa29344131a8f5213c60c03c18fe4c6170ffa2874"}, {file = "coverage-7.13.4-cp313-cp313t-win32.whl", hash = "sha256:2fa8d5f8de70688a28240de9e139fa16b153cc3cbb01c5f16d88d6505ebdadf9"},
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:ed9749bb8eda35f8b636fb7632f1c62f735a236a5d4edadd8bbcc5ea0542e732"}, {file = "coverage-7.13.4-cp313-cp313t-win_amd64.whl", hash = "sha256:9351229c8c8407645840edcc277f4a2d44814d1bc34a2128c11c2a031d45a5dd"},
{file = "coverage-7.10.5-cp314-cp314t-win32.whl", hash = "sha256:609b60d123fc2cc63ccee6d17e4676699075db72d14ac3c107cc4976d516f2df"}, {file = "coverage-7.13.4-cp313-cp313t-win_arm64.whl", hash = "sha256:30b8d0512f2dc8c8747557e8fb459d6176a2c9e5731e2b74d311c03b78451997"},
{file = "coverage-7.10.5-cp314-cp314t-win_amd64.whl", hash = "sha256:0666cf3d2c1626b5a3463fd5b05f5e21f99e6aec40a3192eee4d07a15970b07f"}, {file = "coverage-7.13.4-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:300deaee342f90696ed186e3a00c71b5b3d27bffe9e827677954f4ee56969601"},
{file = "coverage-7.10.5-cp314-cp314t-win_arm64.whl", hash = "sha256:bc85eb2d35e760120540afddd3044a5bf69118a91a296a8b3940dfc4fdcfe1e2"}, {file = "coverage-7.13.4-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:29e3220258d682b6226a9b0925bc563ed9a1ebcff3cad30f043eceea7eaf2689"},
{file = "coverage-7.10.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:62835c1b00c4a4ace24c1a88561a5a59b612fbb83a525d1c70ff5720c97c0610"}, {file = "coverage-7.13.4-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:391ee8f19bef69210978363ca930f7328081c6a0152f1166c91f0b5fdd2a773c"},
{file = "coverage-7.10.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5255b3bbcc1d32a4069d6403820ac8e6dbcc1d68cb28a60a1ebf17e47028e898"}, {file = "coverage-7.13.4-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:0dd7ab8278f0d58a0128ba2fca25824321f05d059c1441800e934ff2efa52129"},
{file = "coverage-7.10.5-cp39-cp39-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3876385722e335d6e991c430302c24251ef9c2a9701b2b390f5473199b1b8ebf"}, {file = "coverage-7.13.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:78cdf0d578b15148b009ccf18c686aa4f719d887e76e6b40c38ffb61d264a552"},
{file = "coverage-7.10.5-cp39-cp39-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8048ce4b149c93447a55d279078c8ae98b08a6951a3c4d2d7e87f4efc7bfe100"}, {file = "coverage-7.13.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:48685fee12c2eb3b27c62f2658e7ea21e9c3239cba5a8a242801a0a3f6a8c62a"},
{file = "coverage-7.10.5-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4028e7558e268dd8bcf4d9484aad393cafa654c24b4885f6f9474bf53183a82a"}, {file = "coverage-7.13.4-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:4e83efc079eb39480e6346a15a1bcb3e9b04759c5202d157e1dd4303cd619356"},
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:03f47dc870eec0367fcdd603ca6a01517d2504e83dc18dbfafae37faec66129a"}, {file = "coverage-7.13.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:ecae9737b72408d6a950f7e525f30aca12d4bd8dd95e37342e5beb3a2a8c4f71"},
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:2d488d7d42b6ded7ea0704884f89dcabd2619505457de8fc9a6011c62106f6e5"}, {file = "coverage-7.13.4-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:ae4578f8528569d3cf303fef2ea569c7f4c4059a38c8667ccef15c6e1f118aa5"},
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:b3dcf2ead47fa8be14224ee817dfc1df98043af568fe120a22f81c0eb3c34ad2"}, {file = "coverage-7.13.4-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:6fdef321fdfbb30a197efa02d48fcd9981f0d8ad2ae8903ac318adc653f5df98"},
{file = "coverage-7.10.5-cp39-cp39-win32.whl", hash = "sha256:02650a11324b80057b8c9c29487020073d5e98a498f1857f37e3f9b6ea1b2426"}, {file = "coverage-7.13.4-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:2b0f6ccf3dbe577170bebfce1318707d0e8c3650003cb4b3a9dd744575daa8b5"},
{file = "coverage-7.10.5-cp39-cp39-win_amd64.whl", hash = "sha256:b45264dd450a10f9e03237b41a9a24e85cbb1e278e5a32adb1a303f58f0017f3"}, {file = "coverage-7.13.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:75fcd519f2a5765db3f0e391eb3b7d150cce1a771bf4c9f861aeab86c767a3c0"},
{file = "coverage-7.10.5-py3-none-any.whl", hash = "sha256:0be24d35e4db1d23d0db5c0f6a74a962e2ec83c426b5cac09f4234aadef38e4a"}, {file = "coverage-7.13.4-cp314-cp314-win32.whl", hash = "sha256:8e798c266c378da2bd819b0677df41ab46d78065fb2a399558f3f6cae78b2fbb"},
{file = "coverage-7.10.5.tar.gz", hash = "sha256:f2e57716a78bc3ae80b2207be0709a3b2b63b9f2dcf9740ee6ac03588a2015b6"}, {file = "coverage-7.13.4-cp314-cp314-win_amd64.whl", hash = "sha256:245e37f664d89861cf2329c9afa2c1fe9e6d4e1a09d872c947e70718aeeac505"},
{file = "coverage-7.13.4-cp314-cp314-win_arm64.whl", hash = "sha256:ad27098a189e5838900ce4c2a99f2fe42a0bf0c2093c17c69b45a71579e8d4a2"},
{file = "coverage-7.13.4-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:85480adfb35ffc32d40918aad81b89c69c9cc5661a9b8a81476d3e645321a056"},
{file = "coverage-7.13.4-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:79be69cf7f3bf9b0deeeb062eab7ac7f36cd4cc4c4dd694bd28921ba4d8596cc"},
{file = "coverage-7.13.4-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:caa421e2684e382c5d8973ac55e4f36bed6821a9bad5c953494de960c74595c9"},
{file = "coverage-7.13.4-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:14375934243ee05f56c45393fe2ce81fe5cc503c07cee2bdf1725fb8bef3ffaf"},
{file = "coverage-7.13.4-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:25a41c3104d08edb094d9db0d905ca54d0cd41c928bb6be3c4c799a54753af55"},
{file = "coverage-7.13.4-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:6f01afcff62bf9a08fb32b2c1d6e924236c0383c02c790732b6537269e466a72"},
{file = "coverage-7.13.4-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:eb9078108fbf0bcdde37c3f4779303673c2fa1fe8f7956e68d447d0dd426d38a"},
{file = "coverage-7.13.4-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:0e086334e8537ddd17e5f16a344777c1ab8194986ec533711cbe6c41cde841b6"},
{file = "coverage-7.13.4-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:725d985c5ab621268b2edb8e50dfe57633dc69bda071abc470fed55a14935fd3"},
{file = "coverage-7.13.4-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:3c06f0f1337c667b971ca2f975523347e63ec5e500b9aa5882d91931cd3ef750"},
{file = "coverage-7.13.4-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:590c0ed4bf8e85f745e6b805b2e1c457b2e33d5255dd9729743165253bc9ad39"},
{file = "coverage-7.13.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:eb30bf180de3f632cd043322dad5751390e5385108b2807368997d1a92a509d0"},
{file = "coverage-7.13.4-cp314-cp314t-win32.whl", hash = "sha256:c4240e7eded42d131a2d2c4dec70374b781b043ddc79a9de4d55ca71f8e98aea"},
{file = "coverage-7.13.4-cp314-cp314t-win_amd64.whl", hash = "sha256:4c7d3cc01e7350f2f0f6f7036caaf5673fb56b6998889ccfe9e1c1fe75a9c932"},
{file = "coverage-7.13.4-cp314-cp314t-win_arm64.whl", hash = "sha256:23e3f687cf945070d1c90f85db66d11e3025665d8dafa831301a0e0038f3db9b"},
{file = "coverage-7.13.4-py3-none-any.whl", hash = "sha256:1af1641e57cf7ba1bd67d677c9abdbcd6cc2ab7da3bca7fa1e2b7e50e65f2ad0"},
{file = "coverage-7.13.4.tar.gz", hash = "sha256:e5c8f6ed1e61a8b2dcdf31eb0b9bbf0130750ca79c1c49eb898e2ad86f5ccc91"},
] ]
[package.dependencies] [package.dependencies]
@@ -523,7 +541,7 @@ description = "Backport of PEP 654 (exception groups)"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
groups = ["main", "dev"] groups = ["main", "dev"]
markers = "python_version == \"3.10\"" markers = "python_version < \"3.11\""
files = [ files = [
{file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"}, {file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"},
{file = "exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88"}, {file = "exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88"},
@@ -2162,23 +2180,23 @@ testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
[[package]] [[package]]
name = "pytest-cov" name = "pytest-cov"
version = "6.2.1" version = "7.0.0"
description = "Pytest plugin for measuring coverage." description = "Pytest plugin for measuring coverage."
optional = false optional = false
python-versions = ">=3.9" python-versions = ">=3.9"
groups = ["dev"] groups = ["dev"]
files = [ files = [
{file = "pytest_cov-6.2.1-py3-none-any.whl", hash = "sha256:f5bc4c23f42f1cdd23c70b1dab1bbaef4fc505ba950d53e0081d0730dd7e86d5"}, {file = "pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861"},
{file = "pytest_cov-6.2.1.tar.gz", hash = "sha256:25cc6cc0a5358204b8108ecedc51a9b57b34cc6b8c967cc2c01a4e00d8a67da2"}, {file = "pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1"},
] ]
[package.dependencies] [package.dependencies]
coverage = {version = ">=7.5", extras = ["toml"]} coverage = {version = ">=7.10.6", extras = ["toml"]}
pluggy = ">=1.2" pluggy = ">=1.2"
pytest = ">=6.2.5" pytest = ">=7"
[package.extras] [package.extras]
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"] testing = ["process-tests", "pytest-xdist", "virtualenv"]
[[package]] [[package]]
name = "pytest-mock" name = "pytest-mock"
@@ -2545,7 +2563,7 @@ description = "A lil' TOML parser"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
groups = ["dev"] groups = ["dev"]
markers = "python_version == \"3.10\"" markers = "python_version < \"3.11\""
files = [ files = [
{file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
{file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
@@ -2893,4 +2911,4 @@ type = ["pytest-mypy"]
[metadata] [metadata]
lock-version = "2.1" lock-version = "2.1"
python-versions = ">=3.10,<4.0" python-versions = ">=3.10,<4.0"
content-hash = "b7ac335a86aa44c3d7d2802298818b389a6f1286e3e9b7b0edb2ff06377cecaf" content-hash = "40eae94995dc0a388fa832ed4af9b6137f28d5b5ced3aaea70d5f91d4d9a179d"

View File

@@ -26,7 +26,7 @@ pyright = "^1.1.408"
pytest = "^8.4.1" pytest = "^8.4.1"
pytest-asyncio = "^1.3.0" pytest-asyncio = "^1.3.0"
pytest-mock = "^3.15.1" pytest-mock = "^3.15.1"
pytest-cov = "^6.2.1" pytest-cov = "^7.0.0"
ruff = "^0.15.0" ruff = "^0.15.0"
[build-system] [build-system]

View File

@@ -45,10 +45,7 @@ async def create_chat_session(
successfulAgentRuns=SafeJson({}), successfulAgentRuns=SafeJson({}),
successfulAgentSchedules=SafeJson({}), successfulAgentSchedules=SafeJson({}),
) )
return await PrismaChatSession.prisma().create( return await PrismaChatSession.prisma().create(data=data)
data=data,
include={"Messages": True},
)
async def update_chat_session( async def update_chat_session(

View File

@@ -18,6 +18,10 @@ class ResponseType(str, Enum):
START = "start" START = "start"
FINISH = "finish" FINISH = "finish"
# Step lifecycle (one LLM API call within a message)
START_STEP = "start-step"
FINISH_STEP = "finish-step"
# Text streaming # Text streaming
TEXT_START = "text-start" TEXT_START = "text-start"
TEXT_DELTA = "text-delta" TEXT_DELTA = "text-delta"
@@ -57,6 +61,16 @@ class StreamStart(StreamBaseResponse):
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream", description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
) )
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-protocol fields like taskId."""
import json
data: dict[str, Any] = {
"type": self.type.value,
"messageId": self.messageId,
}
return f"data: {json.dumps(data)}\n\n"
class StreamFinish(StreamBaseResponse): class StreamFinish(StreamBaseResponse):
"""End of message/stream.""" """End of message/stream."""
@@ -64,6 +78,26 @@ class StreamFinish(StreamBaseResponse):
type: ResponseType = ResponseType.FINISH type: ResponseType = ResponseType.FINISH
class StreamStartStep(StreamBaseResponse):
"""Start of a step (one LLM API call within a message).
The AI SDK uses this to add a step-start boundary to message.parts,
enabling visual separation between multiple LLM calls in a single message.
"""
type: ResponseType = ResponseType.START_STEP
class StreamFinishStep(StreamBaseResponse):
"""End of a step (one LLM API call within a message).
The AI SDK uses this to reset activeTextParts and activeReasoningParts,
so the next LLM call in a tool-call continuation starts with clean state.
"""
type: ResponseType = ResponseType.FINISH_STEP
# ========== Text Streaming ========== # ========== Text Streaming ==========
@@ -117,7 +151,7 @@ class StreamToolOutputAvailable(StreamBaseResponse):
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
toolCallId: str = Field(..., description="Tool call ID this responds to") toolCallId: str = Field(..., description="Tool call ID this responds to")
output: str | dict[str, Any] = Field(..., description="Tool execution output") output: str | dict[str, Any] = Field(..., description="Tool execution output")
# Additional fields for internal use (not part of AI SDK spec but useful) # Keep these for internal backend use
toolName: str | None = Field( toolName: str | None = Field(
default=None, description="Name of the tool that was executed" default=None, description="Name of the tool that was executed"
) )
@@ -125,6 +159,17 @@ class StreamToolOutputAvailable(StreamBaseResponse):
default=True, description="Whether the tool execution succeeded" default=True, description="Whether the tool execution succeeded"
) )
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-spec fields."""
import json
data = {
"type": self.type.value,
"toolCallId": self.toolCallId,
"output": self.output,
}
return f"data: {json.dumps(data)}\n\n"
# ========== Other ========== # ========== Other ==========

View File

@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator
from typing import Annotated from typing import Annotated
from autogpt_libs import auth from autogpt_libs import auth
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
@@ -17,7 +17,29 @@ from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish, StreamHeartbeat, StreamStart from .response_model import StreamFinish, StreamHeartbeat
from .tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
AgentPreviewResponse,
AgentSavedResponse,
AgentsFoundResponse,
BlockListResponse,
BlockOutputResponse,
ClarificationNeededResponse,
DocPageResponse,
DocSearchResultsResponse,
ErrorResponse,
ExecutionStartedResponse,
InputValidationErrorResponse,
NeedLoginResponse,
NoResultsResponse,
OperationInProgressResponse,
OperationPendingResponse,
OperationStartedResponse,
SetupRequirementsResponse,
UnderstandingUpdatedResponse,
)
config = ChatConfig() config = ChatConfig()
@@ -266,12 +288,36 @@ async def stream_chat_post(
""" """
import asyncio import asyncio
import time
stream_start_time = time.perf_counter()
log_meta = {"component": "ChatStream", "session_id": session_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
f"user={user_id}, message_len={len(request.message)}",
extra={"json_fields": log_meta},
)
session = await _validate_and_get_session(session_id, user_id) session = await _validate_and_get_session(session_id, user_id)
logger.info(
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"duration_ms": (time.perf_counter() - stream_start_time) * 1000,
}
},
)
# Create a task in the stream registry for reconnection support # Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4()) task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4()) operation_id = str(uuid_module.uuid4())
log_meta["task_id"] = task_id
task_create_start = time.perf_counter()
await stream_registry.create_task( await stream_registry.create_task(
task_id=task_id, task_id=task_id,
session_id=session_id, session_id=session_id,
@@ -280,14 +326,28 @@ async def stream_chat_post(
tool_name="chat", tool_name="chat",
operation_id=operation_id, operation_id=operation_id,
) )
logger.info(
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
}
},
)
# Background task that runs the AI generation independently of SSE connection # Background task that runs the AI generation independently of SSE connection
async def run_ai_generation(): async def run_ai_generation():
try: import time as time_module
# Emit a start event with task_id for reconnection
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
gen_start_time = time_module.perf_counter()
logger.info(
f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
extra={"json_fields": log_meta},
)
first_chunk_time, ttfc = None, None
chunk_count = 0
try:
async for chunk in chat_service.stream_chat_completion( async for chunk in chat_service.stream_chat_completion(
session_id, session_id,
request.message, request.message,
@@ -295,25 +355,79 @@ async def stream_chat_post(
user_id=user_id, user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context, context=request.context,
_task_id=task_id, # Pass task_id so service emits start with taskId for reconnection
): ):
chunk_count += 1
if first_chunk_time is None:
first_chunk_time = time_module.perf_counter()
ttfc = first_chunk_time - gen_start_time
logger.info(
f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"chunk_type": type(chunk).__name__,
"time_to_first_chunk_ms": ttfc * 1000,
}
},
)
# Write to Redis (subscribers will receive via XREAD) # Write to Redis (subscribers will receive via XREAD)
await stream_registry.publish_chunk(task_id, chunk) await stream_registry.publish_chunk(task_id, chunk)
# Mark task as completed gen_end_time = time_module.perf_counter()
total_time = (gen_end_time - gen_start_time) * 1000
logger.info(
f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
f"task={task_id}, session={session_id}, "
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"time_to_first_chunk_ms": (
ttfc * 1000 if ttfc is not None else None
),
"n_chunks": chunk_count,
}
},
)
await stream_registry.mark_task_completed(task_id, "completed") await stream_registry.mark_task_completed(task_id, "completed")
except Exception as e: except Exception as e:
elapsed = time_module.perf_counter() - gen_start_time
logger.error( logger.error(
f"Error in background AI generation for session {session_id}: {e}" f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed * 1000,
"error": str(e),
}
},
) )
await stream_registry.mark_task_completed(task_id, "failed") await stream_registry.mark_task_completed(task_id, "failed")
# Start the AI generation in a background task # Start the AI generation in a background task
bg_task = asyncio.create_task(run_ai_generation()) bg_task = asyncio.create_task(run_ai_generation())
await stream_registry.set_task_asyncio_task(task_id, bg_task) await stream_registry.set_task_asyncio_task(task_id, bg_task)
setup_time = (time.perf_counter() - stream_start_time) * 1000
logger.info(
f"[TIMING] Background task started, setup={setup_time:.1f}ms",
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
)
# SSE endpoint that subscribes to the task's stream # SSE endpoint that subscribes to the task's stream
async def event_generator() -> AsyncGenerator[str, None]: async def event_generator() -> AsyncGenerator[str, None]:
import time as time_module
event_gen_start = time_module.perf_counter()
logger.info(
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
f"user={user_id}",
extra={"json_fields": log_meta},
)
subscriber_queue = None subscriber_queue = None
first_chunk_yielded = False
chunks_yielded = 0
try: try:
# Subscribe to the task stream (this replays existing messages + live updates) # Subscribe to the task stream (this replays existing messages + live updates)
subscriber_queue = await stream_registry.subscribe_to_task( subscriber_queue = await stream_registry.subscribe_to_task(
@@ -328,22 +442,70 @@ async def stream_chat_post(
return return
# Read from the subscriber queue and yield to SSE # Read from the subscriber queue and yield to SSE
logger.info(
"[TIMING] Starting to read from subscriber_queue",
extra={"json_fields": log_meta},
)
while True: while True:
try: try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0) chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
chunks_yielded += 1
if not first_chunk_yielded:
first_chunk_yielded = True
elapsed = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
f"type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"chunk_type": type(chunk).__name__,
"elapsed_ms": elapsed * 1000,
}
},
)
yield chunk.to_sse() yield chunk.to_sse()
# Check for finish signal # Check for finish signal
if isinstance(chunk, StreamFinish): if isinstance(chunk, StreamFinish):
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] StreamFinish received in {total_time:.2f}s; "
f"n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"chunks_yielded": chunks_yielded,
"total_time_ms": total_time * 1000,
}
},
)
break break
except asyncio.TimeoutError: except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse() yield StreamHeartbeat().to_sse()
except GeneratorExit: except GeneratorExit:
logger.info(
f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"chunks_yielded": chunks_yielded,
"reason": "client_disconnect",
}
},
)
pass # Client disconnected - background task continues pass # Client disconnected - background task continues
except Exception as e: except Exception as e:
logger.error(f"Error in SSE stream for task {task_id}: {e}") elapsed = (time_module.perf_counter() - event_gen_start) * 1000
logger.error(
f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
extra={
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
},
)
finally: finally:
# Unsubscribe when client disconnects or stream ends to prevent resource leak # Unsubscribe when client disconnects or stream ends to prevent resource leak
if subscriber_queue is not None: if subscriber_queue is not None:
@@ -357,6 +519,18 @@ async def stream_chat_post(
exc_info=True, exc_info=True,
) )
# AI SDK protocol termination - always yield even if unsubscribe fails # AI SDK protocol termination - always yield even if unsubscribe fails
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time * 1000,
"chunks_yielded": chunks_yielded,
}
},
)
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
return StreamingResponse( return StreamingResponse(
@@ -374,44 +548,53 @@ async def stream_chat_post(
@router.get( @router.get(
"/sessions/{session_id}/stream", "/sessions/{session_id}/stream",
) )
async def stream_chat_get( async def resume_session_stream(
session_id: str, session_id: str,
message: Annotated[str, Query(min_length=1, max_length=10000)],
user_id: str | None = Depends(auth.get_user_id), user_id: str | None = Depends(auth.get_user_id),
is_user_message: bool = Query(default=True),
): ):
""" """
Stream chat responses for a session (GET - legacy endpoint). Resume an active stream for a session.
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including: Called by the AI SDK's ``useChat(resume: true)`` on page load.
- Text fragments as they are generated Checks for an active (in-progress) task on the session and either replays
- Tool call UI elements (if invoked) the full SSE stream or returns 204 No Content if nothing is running.
- Tool execution results
Args: Args:
session_id: The chat session identifier to associate with the streamed messages. session_id: The chat session identifier.
message: The user's new message to process.
user_id: Optional authenticated user ID. user_id: Optional authenticated user ID.
is_user_message: Whether the message is a user message.
Returns:
StreamingResponse: SSE-formatted response chunks.
Returns:
StreamingResponse (SSE) when an active stream exists,
or 204 No Content when there is nothing to resume.
""" """
session = await _validate_and_get_session(session_id, user_id) import asyncio
active_task, _last_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
if not active_task:
return Response(status_code=204)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=active_task.task_id,
user_id=user_id,
last_message_id="0-0", # Full replay so useChat rebuilds the message
)
if subscriber_queue is None:
return Response(status_code=204)
async def event_generator() -> AsyncGenerator[str, None]: async def event_generator() -> AsyncGenerator[str, None]:
chunk_count = 0 chunk_count = 0
first_chunk_type: str | None = None first_chunk_type: str | None = None
async for chunk in chat_service.stream_chat_completion( try:
session_id, while True:
message, try:
is_user_message=is_user_message, chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
):
if chunk_count < 3: if chunk_count < 3:
logger.info( logger.info(
"Chat stream chunk", "Resume stream chunk",
extra={ extra={
"session_id": session_id, "session_id": session_id,
"chunk_type": str(chunk.type), "chunk_type": str(chunk.type),
@@ -421,15 +604,33 @@ async def stream_chat_get(
first_chunk_type = str(chunk.type) first_chunk_type = str(chunk.type)
chunk_count += 1 chunk_count += 1
yield chunk.to_sse() yield chunk.to_sse()
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
yield StreamHeartbeat().to_sse()
except GeneratorExit:
pass
except Exception as e:
logger.error(f"Error in resume stream for session {session_id}: {e}")
finally:
try:
await stream_registry.unsubscribe_from_task(
active_task.task_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
exc_info=True,
)
logger.info( logger.info(
"Chat stream completed", "Resume stream completed",
extra={ extra={
"session_id": session_id, "session_id": session_id,
"chunk_count": chunk_count, "n_chunks": chunk_count,
"first_chunk_type": first_chunk_type, "first_chunk_type": first_chunk_type,
}, },
) )
# AI SDK protocol termination
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
return StreamingResponse( return StreamingResponse(
@@ -438,8 +639,8 @@ async def stream_chat_get(
headers={ headers={
"Cache-Control": "no-cache", "Cache-Control": "no-cache",
"Connection": "keep-alive", "Connection": "keep-alive",
"X-Accel-Buffering": "no", # Disable nginx buffering "X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header "x-vercel-ai-ui-message-stream": "v1",
}, },
) )
@@ -751,3 +952,42 @@ async def health_check() -> dict:
"service": "chat", "service": "chat",
"version": "0.1.0", "version": "0.1.0",
} }
# ========== Schema Export (for OpenAPI / Orval codegen) ==========
ToolResponseUnion = (
AgentsFoundResponse
| NoResultsResponse
| AgentDetailsResponse
| SetupRequirementsResponse
| ExecutionStartedResponse
| NeedLoginResponse
| ErrorResponse
| InputValidationErrorResponse
| AgentOutputResponse
| UnderstandingUpdatedResponse
| AgentPreviewResponse
| AgentSavedResponse
| ClarificationNeededResponse
| BlockListResponse
| BlockOutputResponse
| DocSearchResultsResponse
| DocPageResponse
| OperationStartedResponse
| OperationPendingResponse
| OperationInProgressResponse
)
@router.get(
"/schema/tool-responses",
response_model=ToolResponseUnion,
include_in_schema=True,
summary="[Dummy] Tool response type export for codegen",
description="This endpoint is not meant to be called. It exists solely to "
"expose tool response models in the OpenAPI schema for frontend codegen.",
)
async def _tool_response_schema() -> ToolResponseUnion: # type: ignore[return]
"""Never called at runtime. Exists only so Orval generates TS types."""
raise HTTPException(status_code=501, detail="Schema-only endpoint")

View File

@@ -52,8 +52,10 @@ from .response_model import (
StreamBaseResponse, StreamBaseResponse,
StreamError, StreamError,
StreamFinish, StreamFinish,
StreamFinishStep,
StreamHeartbeat, StreamHeartbeat,
StreamStart, StreamStart,
StreamStartStep,
StreamTextDelta, StreamTextDelta,
StreamTextEnd, StreamTextEnd,
StreamTextStart, StreamTextStart,
@@ -351,6 +353,10 @@ async def stream_chat_completion(
retry_count: int = 0, retry_count: int = 0,
session: ChatSession | None = None, session: ChatSession | None = None,
context: dict[str, str] | None = None, # {url: str, content: str} context: dict[str, str] | None = None, # {url: str, content: str}
_continuation_message_id: (
str | None
) = None, # Internal: reuse message ID for tool call continuations
_task_id: str | None = None, # Internal: task ID for SSE reconnection support
) -> AsyncGenerator[StreamBaseResponse, None]: ) -> AsyncGenerator[StreamBaseResponse, None]:
"""Main entry point for streaming chat completions with database handling. """Main entry point for streaming chat completions with database handling.
@@ -371,21 +377,45 @@ async def stream_chat_completion(
ValueError: If max_context_messages is exceeded ValueError: If max_context_messages is exceeded
""" """
completion_start = time.monotonic()
# Build log metadata for structured logging
log_meta = {"component": "ChatService", "session_id": session_id}
if user_id:
log_meta["user_id"] = user_id
logger.info( logger.info(
f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}" f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, "
f"message_len={len(message) if message else 0}, is_user={is_user_message}",
extra={
"json_fields": {
**log_meta,
"message_len": len(message) if message else 0,
"is_user_message": is_user_message,
}
},
) )
# Only fetch from Redis if session not provided (initial call) # Only fetch from Redis if session not provided (initial call)
if session is None: if session is None:
fetch_start = time.monotonic()
session = await get_chat_session(session_id, user_id) session = await get_chat_session(session_id, user_id)
fetch_time = (time.monotonic() - fetch_start) * 1000
logger.info( logger.info(
f"Fetched session from Redis: {session.session_id if session else 'None'}, " f"[TIMING] get_chat_session took {fetch_time:.1f}ms, "
f"message_count={len(session.messages) if session else 0}" f"n_messages={len(session.messages) if session else 0}",
extra={
"json_fields": {
**log_meta,
"duration_ms": fetch_time,
"n_messages": len(session.messages) if session else 0,
}
},
) )
else: else:
logger.info( logger.info(
f"Using provided session object: {session.session_id}, " f"[TIMING] Using provided session, messages={len(session.messages)}",
f"message_count={len(session.messages)}" extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
) )
if not session: if not session:
@@ -406,17 +436,25 @@ async def stream_chat_completion(
# Track user message in PostHog # Track user message in PostHog
if is_user_message: if is_user_message:
posthog_start = time.monotonic()
track_user_message( track_user_message(
user_id=user_id, user_id=user_id,
session_id=session_id, session_id=session_id,
message_length=len(message), message_length=len(message),
) )
posthog_time = (time.monotonic() - posthog_start) * 1000
logger.info( logger.info(
f"Upserting session: {session.session_id} with user id {session.user_id}, " f"[TIMING] track_user_message took {posthog_time:.1f}ms",
f"message_count={len(session.messages)}" extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
) )
upsert_start = time.monotonic()
session = await upsert_chat_session(session) session = await upsert_chat_session(session)
upsert_time = (time.monotonic() - upsert_start) * 1000
logger.info(
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": upsert_time}},
)
assert session, "Session not found" assert session, "Session not found"
# Generate title for new sessions on first user message (non-blocking) # Generate title for new sessions on first user message (non-blocking)
@@ -454,7 +492,13 @@ async def stream_chat_completion(
asyncio.create_task(_update_title()) asyncio.create_task(_update_title())
# Build system prompt with business understanding # Build system prompt with business understanding
prompt_start = time.monotonic()
system_prompt, understanding = await _build_system_prompt(user_id) system_prompt, understanding = await _build_system_prompt(user_id)
prompt_time = (time.monotonic() - prompt_start) * 1000
logger.info(
f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
)
# Initialize variables for streaming # Initialize variables for streaming
assistant_response = ChatMessage( assistant_response = ChatMessage(
@@ -479,13 +523,27 @@ async def stream_chat_completion(
# Generate unique IDs for AI SDK protocol # Generate unique IDs for AI SDK protocol
import uuid as uuid_module import uuid as uuid_module
message_id = str(uuid_module.uuid4()) is_continuation = _continuation_message_id is not None
message_id = _continuation_message_id or str(uuid_module.uuid4())
text_block_id = str(uuid_module.uuid4()) text_block_id = str(uuid_module.uuid4())
# Yield message start # Only yield message start for the initial call, not for continuations.
yield StreamStart(messageId=message_id) setup_time = (time.monotonic() - completion_start) * 1000
logger.info(
f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
)
if not is_continuation:
yield StreamStart(messageId=message_id, taskId=_task_id)
# Emit start-step before each LLM call (AI SDK uses this to add step boundaries)
yield StreamStartStep()
try: try:
logger.info(
"[TIMING] Calling _stream_chat_chunks",
extra={"json_fields": log_meta},
)
async for chunk in _stream_chat_chunks( async for chunk in _stream_chat_chunks(
session=session, session=session,
tools=tools, tools=tools,
@@ -585,6 +643,10 @@ async def stream_chat_completion(
) )
yield chunk yield chunk
elif isinstance(chunk, StreamFinish): elif isinstance(chunk, StreamFinish):
if has_done_tool_call:
# Tool calls happened — close the step but don't send message-level finish.
# The continuation will open a new step, and finish will come at the end.
yield StreamFinishStep()
if not has_done_tool_call: if not has_done_tool_call:
# Emit text-end before finish if we received text but haven't closed it # Emit text-end before finish if we received text but haven't closed it
if has_received_text and not text_streaming_ended: if has_received_text and not text_streaming_ended:
@@ -616,6 +678,8 @@ async def stream_chat_completion(
has_saved_assistant_message = True has_saved_assistant_message = True
has_yielded_end = True has_yielded_end = True
# Emit finish-step before finish (resets AI SDK text/reasoning state)
yield StreamFinishStep()
yield chunk yield chunk
elif isinstance(chunk, StreamError): elif isinstance(chunk, StreamError):
has_yielded_error = True has_yielded_error = True
@@ -665,6 +729,10 @@ async def stream_chat_completion(
logger.info( logger.info(
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}" f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
) )
# Close the current step before retrying so the recursive call's
# StreamStartStep doesn't produce unbalanced step events.
if not has_yielded_end:
yield StreamFinishStep()
should_retry = True should_retry = True
else: else:
# Non-retryable error or max retries exceeded # Non-retryable error or max retries exceeded
@@ -700,6 +768,7 @@ async def stream_chat_completion(
error_response = StreamError(errorText=error_message) error_response = StreamError(errorText=error_message)
yield error_response yield error_response
if not has_yielded_end: if not has_yielded_end:
yield StreamFinishStep()
yield StreamFinish() yield StreamFinish()
return return
@@ -714,6 +783,8 @@ async def stream_chat_completion(
retry_count=retry_count + 1, retry_count=retry_count + 1,
session=session, session=session,
context=context, context=context,
_continuation_message_id=message_id, # Reuse message ID since start was already sent
_task_id=_task_id,
): ):
yield chunk yield chunk
return # Exit after retry to avoid double-saving in finally block return # Exit after retry to avoid double-saving in finally block
@@ -783,6 +854,8 @@ async def stream_chat_completion(
session=session, # Pass session object to avoid Redis refetch session=session, # Pass session object to avoid Redis refetch
context=context, context=context,
tool_call_response=str(tool_response_messages), tool_call_response=str(tool_response_messages),
_continuation_message_id=message_id, # Reuse message ID to avoid duplicates
_task_id=_task_id,
): ):
yield chunk yield chunk
@@ -893,9 +966,21 @@ async def _stream_chat_chunks(
SSE formatted JSON response objects SSE formatted JSON response objects
""" """
import time as time_module
stream_chunks_start = time_module.perf_counter()
model = config.model model = config.model
logger.info("Starting pure chat stream") # Build log metadata for structured logging
log_meta = {"component": "ChatService", "session_id": session.session_id}
if session.user_id:
log_meta["user_id"] = session.user_id
logger.info(
f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, "
f"user={session.user_id}, n_messages={len(session.messages)}",
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
)
messages = session.to_openai_messages() messages = session.to_openai_messages()
if system_prompt: if system_prompt:
@@ -906,12 +991,18 @@ async def _stream_chat_chunks(
messages = [system_message] + messages messages = [system_message] + messages
# Apply context window management # Apply context window management
context_start = time_module.perf_counter()
context_result = await _manage_context_window( context_result = await _manage_context_window(
messages=messages, messages=messages,
model=model, model=model,
api_key=config.api_key, api_key=config.api_key,
base_url=config.base_url, base_url=config.base_url,
) )
context_time = (time_module.perf_counter() - context_start) * 1000
logger.info(
f"[TIMING] _manage_context_window took {context_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": context_time}},
)
if context_result.error: if context_result.error:
if "System prompt dropped" in context_result.error: if "System prompt dropped" in context_result.error:
@@ -946,9 +1037,19 @@ async def _stream_chat_chunks(
while retry_count <= MAX_RETRIES: while retry_count <= MAX_RETRIES:
try: try:
elapsed = (time_module.perf_counter() - stream_chunks_start) * 1000
retry_info = (
f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
)
logger.info( logger.info(
f"Creating OpenAI chat completion stream..." f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}",
f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}" extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"retry_count": retry_count,
}
},
) )
# Build extra_body for OpenRouter tracing and PostHog analytics # Build extra_body for OpenRouter tracing and PostHog analytics
@@ -965,6 +1066,7 @@ async def _stream_chat_chunks(
:128 :128
] # OpenRouter limit ] # OpenRouter limit
api_call_start = time_module.perf_counter()
stream = await client.chat.completions.create( stream = await client.chat.completions.create(
model=model, model=model,
messages=cast(list[ChatCompletionMessageParam], messages), messages=cast(list[ChatCompletionMessageParam], messages),
@@ -974,6 +1076,11 @@ async def _stream_chat_chunks(
stream_options=ChatCompletionStreamOptionsParam(include_usage=True), stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
extra_body=extra_body, extra_body=extra_body,
) )
api_init_time = (time_module.perf_counter() - api_call_start) * 1000
logger.info(
f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": api_init_time}},
)
# Variables to accumulate tool calls # Variables to accumulate tool calls
tool_calls: list[dict[str, Any]] = [] tool_calls: list[dict[str, Any]] = []
@@ -984,10 +1091,13 @@ async def _stream_chat_chunks(
# Track if we've started the text block # Track if we've started the text block
text_started = False text_started = False
first_content_chunk = True
chunk_count = 0
# Process the stream # Process the stream
chunk: ChatCompletionChunk chunk: ChatCompletionChunk
async for chunk in stream: async for chunk in stream:
chunk_count += 1
if chunk.usage: if chunk.usage:
yield StreamUsage( yield StreamUsage(
promptTokens=chunk.usage.prompt_tokens, promptTokens=chunk.usage.prompt_tokens,
@@ -1010,6 +1120,23 @@ async def _stream_chat_chunks(
if not text_started and text_block_id: if not text_started and text_block_id:
yield StreamTextStart(id=text_block_id) yield StreamTextStart(id=text_block_id)
text_started = True text_started = True
# Log timing for first content chunk
if first_content_chunk:
first_content_chunk = False
ttfc = (
time_module.perf_counter() - api_call_start
) * 1000
logger.info(
f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms "
f"(since API call), n_chunks={chunk_count}",
extra={
"json_fields": {
**log_meta,
"time_to_first_chunk_ms": ttfc,
"n_chunks": chunk_count,
}
},
)
# Stream the text delta # Stream the text delta
text_response = StreamTextDelta( text_response = StreamTextDelta(
id=text_block_id or "", id=text_block_id or "",
@@ -1066,7 +1193,21 @@ async def _stream_chat_chunks(
toolName=tool_calls[idx]["function"]["name"], toolName=tool_calls[idx]["function"]["name"],
) )
emitted_start_for_idx.add(idx) emitted_start_for_idx.add(idx)
logger.info(f"Stream complete. Finish reason: {finish_reason}") stream_duration = time_module.perf_counter() - api_call_start
logger.info(
f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, "
f"duration={stream_duration:.2f}s, "
f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}",
extra={
"json_fields": {
**log_meta,
"stream_duration_ms": stream_duration * 1000,
"finish_reason": finish_reason,
"n_chunks": chunk_count,
"n_tool_calls": len(tool_calls),
}
},
)
# Yield all accumulated tool calls after the stream is complete # Yield all accumulated tool calls after the stream is complete
# This ensures all tool call arguments have been fully received # This ensures all tool call arguments have been fully received
@@ -1086,11 +1227,16 @@ async def _stream_chat_chunks(
# Re-raise to trigger retry logic in the parent function # Re-raise to trigger retry logic in the parent function
raise raise
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
logger.info(
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
f"session={session.session_id}, user={session.user_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)
yield StreamFinish() yield StreamFinish()
return return
except Exception as e: except Exception as e:
last_error = e last_error = e
if _is_retryable_error(e) and retry_count < MAX_RETRIES: if _is_retryable_error(e) and retry_count < MAX_RETRIES:
retry_count += 1 retry_count += 1
# Calculate delay with exponential backoff # Calculate delay with exponential backoff
@@ -1106,26 +1252,12 @@ async def _stream_chat_chunks(
continue # Retry the stream continue # Retry the stream
else: else:
# Non-retryable error or max retries exceeded # Non-retryable error or max retries exceeded
_log_api_error( logger.error(
error=e, f"Error in stream (not retrying): {e!s}",
session_id=session.session_id if session else None, exc_info=True,
message_count=len(messages) if messages else None,
model=model,
retry_count=retry_count,
) )
error_code = None error_code = None
error_text = str(e) error_text = str(e)
error_details = _extract_api_error_details(e)
if error_details.get("response_body"):
body = error_details["response_body"]
if isinstance(body, dict):
err = body.get("error")
if isinstance(err, dict) and err.get("message"):
error_text = err["message"]
elif body.get("message"):
error_text = body["message"]
if _is_region_blocked_error(e): if _is_region_blocked_error(e):
error_code = "MODEL_NOT_AVAILABLE_REGION" error_code = "MODEL_NOT_AVAILABLE_REGION"
error_text = ( error_text = (
@@ -1142,12 +1274,9 @@ async def _stream_chat_chunks(
# If we exit the retry loop without returning, it means we exhausted retries # If we exit the retry loop without returning, it means we exhausted retries
if last_error: if last_error:
_log_api_error( logger.error(
error=last_error, f"Max retries ({MAX_RETRIES}) exceeded. Last error: {last_error!s}",
session_id=session.session_id if session else None, exc_info=True,
message_count=len(messages) if messages else None,
model=model,
retry_count=MAX_RETRIES,
) )
yield StreamError(errorText=f"Max retries exceeded: {last_error!s}") yield StreamError(errorText=f"Max retries exceeded: {last_error!s}")
yield StreamFinish() yield StreamFinish()
@@ -1583,6 +1712,7 @@ async def _execute_long_running_tool_with_streaming(
task_id, task_id,
StreamError(errorText=str(e)), StreamError(errorText=str(e)),
) )
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(task_id, StreamFinish()) await stream_registry.publish_chunk(task_id, StreamFinish())
await _update_pending_operation( await _update_pending_operation(
@@ -1719,7 +1849,6 @@ async def _generate_llm_continuation(
break # Success, exit retry loop break # Success, exit retry loop
except Exception as e: except Exception as e:
last_error = e last_error = e
if _is_retryable_error(e) and retry_count < MAX_RETRIES: if _is_retryable_error(e) and retry_count < MAX_RETRIES:
retry_count += 1 retry_count += 1
delay = min( delay = min(
@@ -1733,23 +1862,17 @@ async def _generate_llm_continuation(
await asyncio.sleep(delay) await asyncio.sleep(delay)
continue continue
else: else:
# Non-retryable error - log details and exit gracefully # Non-retryable error - log and exit gracefully
_log_api_error( logger.error(
error=e, f"Non-retryable error in LLM continuation: {e!s}",
session_id=session_id, exc_info=True,
message_count=len(messages) if messages else None,
model=config.model,
retry_count=retry_count,
) )
return return
if last_error: if last_error:
_log_api_error( logger.error(
error=last_error, f"Max retries ({MAX_RETRIES}) exceeded for LLM continuation. "
session_id=session_id, f"Last error: {last_error!s}"
message_count=len(messages) if messages else None,
model=config.model,
retry_count=MAX_RETRIES,
) )
return return
@@ -1789,89 +1912,6 @@ async def _generate_llm_continuation(
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True) logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
def _log_api_error(
error: Exception,
session_id: str | None = None,
message_count: int | None = None,
model: str | None = None,
retry_count: int = 0,
) -> None:
"""Log detailed API error information for debugging."""
details = _extract_api_error_details(error)
details["session_id"] = session_id
details["message_count"] = message_count
details["model"] = model
details["retry_count"] = retry_count
if isinstance(error, RateLimitError):
logger.warning(f"Rate limit error: {details}")
elif isinstance(error, APIConnectionError):
logger.warning(f"API connection error: {details}")
elif isinstance(error, APIStatusError) and error.status_code >= 500:
logger.error(f"API server error (5xx): {details}")
else:
logger.error(f"API error: {details}")
def _extract_api_error_details(error: Exception) -> dict[str, Any]:
"""Extract detailed information from OpenAI/OpenRouter API errors."""
error_msg = str(error)
details: dict[str, Any] = {
"error_type": type(error).__name__,
"error_message": error_msg[:500] + "..." if len(error_msg) > 500 else error_msg,
}
if hasattr(error, "code"):
details["code"] = getattr(error, "code", None)
if hasattr(error, "param"):
details["param"] = getattr(error, "param", None)
if isinstance(error, APIStatusError):
details["status_code"] = error.status_code
details["request_id"] = getattr(error, "request_id", None)
if hasattr(error, "body") and error.body:
details["response_body"] = _sanitize_error_body(error.body)
if hasattr(error, "response") and error.response:
headers = error.response.headers
details["openrouter_provider"] = headers.get("x-openrouter-provider")
details["openrouter_model"] = headers.get("x-openrouter-model")
details["retry_after"] = headers.get("retry-after")
details["rate_limit_remaining"] = headers.get("x-ratelimit-remaining")
return details
def _sanitize_error_body(
body: Any, max_length: int = 2000
) -> dict[str, Any] | str | None:
"""Extract only safe fields from error response body to avoid logging sensitive data."""
if not isinstance(body, dict):
# Non-dict bodies (e.g., HTML error pages) - return truncated string
if body is not None:
body_str = str(body)
if len(body_str) > max_length:
return body_str[:max_length] + "...[truncated]"
return body_str
return None
safe_fields = ("message", "type", "code", "param", "error")
sanitized: dict[str, Any] = {}
for field in safe_fields:
if field in body:
value = body[field]
if field == "error" and isinstance(value, dict):
sanitized[field] = _sanitize_error_body(value, max_length)
elif isinstance(value, str) and len(value) > max_length:
sanitized[field] = value[:max_length] + "...[truncated]"
else:
sanitized[field] = value
return sanitized if sanitized else None
async def _generate_llm_continuation_with_streaming( async def _generate_llm_continuation_with_streaming(
session_id: str, session_id: str,
user_id: str | None, user_id: str | None,
@@ -1930,6 +1970,7 @@ async def _generate_llm_continuation_with_streaming(
# Publish start event # Publish start event
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id)) await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
await stream_registry.publish_chunk(task_id, StreamStartStep())
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id)) await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
# Stream the response # Stream the response
@@ -1953,6 +1994,7 @@ async def _generate_llm_continuation_with_streaming(
# Publish end events # Publish end events
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id)) await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
await stream_registry.publish_chunk(task_id, StreamFinishStep())
if assistant_content: if assistant_content:
# Reload session from DB to avoid race condition with user messages # Reload session from DB to avoid race condition with user messages
@@ -1994,4 +2036,5 @@ async def _generate_llm_continuation_with_streaming(
task_id, task_id,
StreamError(errorText=f"Failed to generate response: {e}"), StreamError(errorText=f"Failed to generate response: {e}"),
) )
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(task_id, StreamFinish()) await stream_registry.publish_chunk(task_id, StreamFinish())

View File

@@ -104,6 +104,24 @@ async def create_task(
Returns: Returns:
The created ActiveTask instance (metadata only) The created ActiveTask instance (metadata only)
""" """
import time
start_time = time.perf_counter()
# Build log metadata for structured logging
log_meta = {
"component": "StreamRegistry",
"task_id": task_id,
"session_id": session_id,
}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
extra={"json_fields": log_meta},
)
task = ActiveTask( task = ActiveTask(
task_id=task_id, task_id=task_id,
session_id=session_id, session_id=session_id,
@@ -114,10 +132,18 @@ async def create_task(
) )
# Store metadata in Redis # Store metadata in Redis
redis_start = time.perf_counter()
redis = await get_redis_async() redis = await get_redis_async()
redis_time = (time.perf_counter() - redis_start) * 1000
logger.info(
f"[TIMING] get_redis_async took {redis_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
)
meta_key = _get_task_meta_key(task_id) meta_key = _get_task_meta_key(task_id)
op_key = _get_operation_mapping_key(operation_id) op_key = _get_operation_mapping_key(operation_id)
hset_start = time.perf_counter()
await redis.hset( # type: ignore[misc] await redis.hset( # type: ignore[misc]
meta_key, meta_key,
mapping={ mapping={
@@ -131,12 +157,22 @@ async def create_task(
"created_at": task.created_at.isoformat(), "created_at": task.created_at.isoformat(),
}, },
) )
hset_time = (time.perf_counter() - hset_start) * 1000
logger.info(
f"[TIMING] redis.hset took {hset_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
)
await redis.expire(meta_key, config.stream_ttl) await redis.expire(meta_key, config.stream_ttl)
# Create operation_id -> task_id mapping for webhook lookups # Create operation_id -> task_id mapping for webhook lookups
await redis.set(op_key, task_id, ex=config.stream_ttl) await redis.set(op_key, task_id, ex=config.stream_ttl)
logger.debug(f"Created task {task_id} for session {session_id}") total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)
return task return task
@@ -156,26 +192,60 @@ async def publish_chunk(
Returns: Returns:
The Redis Stream message ID The Redis Stream message ID
""" """
import time
start_time = time.perf_counter()
chunk_type = type(chunk).__name__
chunk_json = chunk.model_dump_json() chunk_json = chunk.model_dump_json()
message_id = "0-0" message_id = "0-0"
# Build log metadata
log_meta = {
"component": "StreamRegistry",
"task_id": task_id,
"chunk_type": chunk_type,
}
try: try:
redis = await get_redis_async() redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id) stream_key = _get_task_stream_key(task_id)
# Write to Redis Stream for persistence and real-time delivery # Write to Redis Stream for persistence and real-time delivery
xadd_start = time.perf_counter()
raw_id = await redis.xadd( raw_id = await redis.xadd(
stream_key, stream_key,
{"data": chunk_json}, {"data": chunk_json},
maxlen=config.stream_max_length, maxlen=config.stream_max_length,
) )
xadd_time = (time.perf_counter() - xadd_start) * 1000
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode() message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
# Set TTL on stream to match task metadata TTL # Set TTL on stream to match task metadata TTL
await redis.expire(stream_key, config.stream_ttl) await redis.expire(stream_key, config.stream_ttl)
total_time = (time.perf_counter() - start_time) * 1000
# Only log timing for significant chunks or slow operations
if (
chunk_type
in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
or total_time > 50
):
logger.info(
f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"xadd_time_ms": xadd_time,
"message_id": message_id,
}
},
)
except Exception as e: except Exception as e:
elapsed = (time.perf_counter() - start_time) * 1000
logger.error( logger.error(
f"Failed to publish chunk for task {task_id}: {e}", f"[TIMING] Failed to publish chunk {chunk_type} after {elapsed:.1f}ms: {e}",
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
exc_info=True, exc_info=True,
) )
@@ -200,24 +270,61 @@ async def subscribe_to_task(
An asyncio Queue that will receive stream chunks, or None if task not found An asyncio Queue that will receive stream chunks, or None if task not found
or user doesn't have access or user doesn't have access
""" """
import time
start_time = time.perf_counter()
# Build log metadata
log_meta = {"component": "StreamRegistry", "task_id": task_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
)
redis_start = time.perf_counter()
redis = await get_redis_async() redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id) meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
hgetall_time = (time.perf_counter() - redis_start) * 1000
logger.info(
f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
)
if not meta: if not meta:
logger.debug(f"Task {task_id} not found in Redis") elapsed = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"reason": "task_not_found",
}
},
)
return None return None
# Note: Redis client uses decode_responses=True, so keys are strings # Note: Redis client uses decode_responses=True, so keys are strings
task_status = meta.get("status", "") task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None task_user_id = meta.get("user_id", "") or None
log_meta["session_id"] = meta.get("session_id", "")
# Validate ownership - if task has an owner, requester must match # Validate ownership - if task has an owner, requester must match
if task_user_id: if task_user_id:
if user_id != task_user_id: if user_id != task_user_id:
logger.warning( logger.warning(
f"User {user_id} denied access to task {task_id} " f"[TIMING] Access denied: user {user_id} tried to access task owned by {task_user_id}",
f"owned by {task_user_id}" extra={
"json_fields": {
**log_meta,
"task_owner": task_user_id,
"reason": "access_denied",
}
},
) )
return None return None
@@ -225,7 +332,19 @@ async def subscribe_to_task(
stream_key = _get_task_stream_key(task_id) stream_key = _get_task_stream_key(task_id)
# Step 1: Replay messages from Redis Stream # Step 1: Replay messages from Redis Stream
xread_start = time.perf_counter()
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000) messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
xread_time = (time.perf_counter() - xread_start) * 1000
logger.info(
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
extra={
"json_fields": {
**log_meta,
"duration_ms": xread_time,
"task_status": task_status,
}
},
)
replayed_count = 0 replayed_count = 0
replay_last_id = last_message_id replay_last_id = last_message_id
@@ -244,19 +363,48 @@ async def subscribe_to_task(
except Exception as e: except Exception as e:
logger.warning(f"Failed to replay message: {e}") logger.warning(f"Failed to replay message: {e}")
logger.debug(f"Task {task_id}: replayed {replayed_count} messages") logger.info(
f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
extra={
"json_fields": {
**log_meta,
"n_messages_replayed": replayed_count,
"replay_last_id": replay_last_id,
}
},
)
# Step 2: If task is still running, start stream listener for live updates # Step 2: If task is still running, start stream listener for live updates
if task_status == "running": if task_status == "running":
logger.info(
"[TIMING] Task still running, starting _stream_listener",
extra={"json_fields": {**log_meta, "task_status": task_status}},
)
listener_task = asyncio.create_task( listener_task = asyncio.create_task(
_stream_listener(task_id, subscriber_queue, replay_last_id) _stream_listener(task_id, subscriber_queue, replay_last_id, log_meta)
) )
# Track listener task for cleanup on unsubscribe # Track listener task for cleanup on unsubscribe
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task) _listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
else: else:
# Task is completed/failed - add finish marker # Task is completed/failed - add finish marker
logger.info(
f"[TIMING] Task already {task_status}, adding StreamFinish",
extra={"json_fields": {**log_meta, "task_status": task_status}},
)
await subscriber_queue.put(StreamFinish()) await subscriber_queue.put(StreamFinish())
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
f"n_messages_replayed={replayed_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"n_messages_replayed": replayed_count,
}
},
)
return subscriber_queue return subscriber_queue
@@ -264,6 +412,7 @@ async def _stream_listener(
task_id: str, task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse], subscriber_queue: asyncio.Queue[StreamBaseResponse],
last_replayed_id: str, last_replayed_id: str,
log_meta: dict | None = None,
) -> None: ) -> None:
"""Listen to Redis Stream for new messages using blocking XREAD. """Listen to Redis Stream for new messages using blocking XREAD.
@@ -274,10 +423,27 @@ async def _stream_listener(
task_id: Task ID to listen for task_id: Task ID to listen for
subscriber_queue: Queue to deliver messages to subscriber_queue: Queue to deliver messages to
last_replayed_id: Last message ID from replay (continue from here) last_replayed_id: Last message ID from replay (continue from here)
log_meta: Structured logging metadata
""" """
import time
start_time = time.perf_counter()
# Use provided log_meta or build minimal one
if log_meta is None:
log_meta = {"component": "StreamRegistry", "task_id": task_id}
logger.info(
f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
)
queue_id = id(subscriber_queue) queue_id = id(subscriber_queue)
# Track the last successfully delivered message ID for recovery hints # Track the last successfully delivered message ID for recovery hints
last_delivered_id = last_replayed_id last_delivered_id = last_replayed_id
messages_delivered = 0
first_message_time = None
xread_count = 0
try: try:
redis = await get_redis_async() redis = await get_redis_async()
@@ -287,9 +453,39 @@ async def _stream_listener(
while True: while True:
# Block for up to 30 seconds waiting for new messages # Block for up to 30 seconds waiting for new messages
# This allows periodic checking if task is still running # This allows periodic checking if task is still running
xread_start = time.perf_counter()
xread_count += 1
messages = await redis.xread( messages = await redis.xread(
{stream_key: current_id}, block=30000, count=100 {stream_key: current_id}, block=30000, count=100
) )
xread_time = (time.perf_counter() - xread_start) * 1000
if messages:
msg_count = sum(len(msgs) for _, msgs in messages)
logger.info(
f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
extra={
"json_fields": {
**log_meta,
"xread_count": xread_count,
"n_messages": msg_count,
"duration_ms": xread_time,
}
},
)
elif xread_time > 1000:
# Only log timeouts (30s blocking)
logger.info(
f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
extra={
"json_fields": {
**log_meta,
"xread_count": xread_count,
"duration_ms": xread_time,
"reason": "timeout",
}
},
)
if not messages: if not messages:
# Timeout - check if task is still running # Timeout - check if task is still running
@@ -326,10 +522,30 @@ async def _stream_listener(
) )
# Update last delivered ID on successful delivery # Update last delivered ID on successful delivery
last_delivered_id = current_id last_delivered_id = current_id
messages_delivered += 1
if first_message_time is None:
first_message_time = time.perf_counter()
elapsed = (first_message_time - start_time) * 1000
logger.info(
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"chunk_type": type(chunk).__name__,
}
},
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.warning( logger.warning(
f"Subscriber queue full for task {task_id}, " f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s" extra={
"json_fields": {
**log_meta,
"timeout_s": QUEUE_PUT_TIMEOUT,
"reason": "queue_full",
}
},
) )
# Send overflow error with recovery info # Send overflow error with recovery info
try: try:
@@ -351,15 +567,44 @@ async def _stream_listener(
# Stop listening on finish # Stop listening on finish
if isinstance(chunk, StreamFinish): if isinstance(chunk, StreamFinish):
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"messages_delivered": messages_delivered,
}
},
)
return return
except Exception as e: except Exception as e:
logger.warning(f"Error processing stream message: {e}") logger.warning(
f"Error processing stream message: {e}",
extra={"json_fields": {**log_meta, "error": str(e)}},
)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.debug(f"Stream listener cancelled for task {task_id}") elapsed = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"messages_delivered": messages_delivered,
"reason": "cancelled",
}
},
)
raise # Re-raise to propagate cancellation raise # Re-raise to propagate cancellation
except Exception as e: except Exception as e:
logger.error(f"Stream listener error for task {task_id}: {e}") elapsed = (time.perf_counter() - start_time) * 1000
logger.error(
f"[TIMING] _stream_listener ERROR after {elapsed:.1f}ms: {e}",
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
)
# On error, send finish to unblock subscriber # On error, send finish to unblock subscriber
try: try:
await asyncio.wait_for( await asyncio.wait_for(
@@ -368,10 +613,24 @@ async def _stream_listener(
) )
except (asyncio.TimeoutError, asyncio.QueueFull): except (asyncio.TimeoutError, asyncio.QueueFull):
logger.warning( logger.warning(
f"Could not deliver finish event for task {task_id} after error" "Could not deliver finish event after error",
extra={"json_fields": log_meta},
) )
finally: finally:
# Clean up listener task mapping on exit # Clean up listener task mapping on exit
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
f"delivered={messages_delivered}, xread_count={xread_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"messages_delivered": messages_delivered,
"xread_count": xread_count,
}
},
)
_listener_tasks.pop(queue_id, None) _listener_tasks.pop(queue_id, None)
@@ -598,8 +857,10 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
ResponseType, ResponseType,
StreamError, StreamError,
StreamFinish, StreamFinish,
StreamFinishStep,
StreamHeartbeat, StreamHeartbeat,
StreamStart, StreamStart,
StreamStartStep,
StreamTextDelta, StreamTextDelta,
StreamTextEnd, StreamTextEnd,
StreamTextStart, StreamTextStart,
@@ -613,6 +874,8 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
type_to_class: dict[str, type[StreamBaseResponse]] = { type_to_class: dict[str, type[StreamBaseResponse]] = {
ResponseType.START.value: StreamStart, ResponseType.START.value: StreamStart,
ResponseType.FINISH.value: StreamFinish, ResponseType.FINISH.value: StreamFinish,
ResponseType.START_STEP.value: StreamStartStep,
ResponseType.FINISH_STEP.value: StreamFinishStep,
ResponseType.TEXT_START.value: StreamTextStart, ResponseType.TEXT_START.value: StreamTextStart,
ResponseType.TEXT_DELTA.value: StreamTextDelta, ResponseType.TEXT_DELTA.value: StreamTextDelta,
ResponseType.TEXT_END.value: StreamTextEnd, ResponseType.TEXT_END.value: StreamTextEnd,

View File

@@ -13,10 +13,32 @@ from backend.api.features.chat.tools.models import (
NoResultsResponse, NoResultsResponse,
) )
from backend.api.features.store.hybrid_search import unified_hybrid_search from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.data.block import get_block from backend.data.block import BlockType, get_block
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_TARGET_RESULTS = 10
# Over-fetch to compensate for post-hoc filtering of graph-only blocks.
# 40 is 2x current removed; speed of query 10 vs 40 is minimial
_OVERFETCH_PAGE_SIZE = 40
# Block types that only work within graphs and cannot run standalone in CoPilot.
COPILOT_EXCLUDED_BLOCK_TYPES = {
BlockType.INPUT, # Graph interface definition - data enters via chat, not graph inputs
BlockType.OUTPUT, # Graph interface definition - data exits via chat, not graph outputs
BlockType.WEBHOOK, # Wait for external events - would hang forever in CoPilot
BlockType.WEBHOOK_MANUAL, # Same as WEBHOOK
BlockType.NOTE, # Visual annotation only - no runtime behavior
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
}
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
COPILOT_EXCLUDED_BLOCK_IDS = {
# SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology
"3b191d9f-356f-482d-8238-ba04b6d18381",
}
class FindBlockTool(BaseTool): class FindBlockTool(BaseTool):
"""Tool for searching available blocks.""" """Tool for searching available blocks."""
@@ -88,7 +110,7 @@ class FindBlockTool(BaseTool):
query=query, query=query,
content_types=[ContentType.BLOCK], content_types=[ContentType.BLOCK],
page=1, page=1,
page_size=10, page_size=_OVERFETCH_PAGE_SIZE,
) )
if not results: if not results:
@@ -108,18 +130,35 @@ class FindBlockTool(BaseTool):
block = get_block(block_id) block = get_block(block_id)
# Skip disabled blocks # Skip disabled blocks
if block and not block.disabled: if not block or block.disabled:
continue
# Skip blocks excluded from CoPilot (graph-only blocks)
if (
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
):
continue
# Get input/output schemas # Get input/output schemas
input_schema = {} input_schema = {}
output_schema = {} output_schema = {}
try: try:
input_schema = block.input_schema.jsonschema() input_schema = block.input_schema.jsonschema()
except Exception: except Exception as e:
pass logger.debug(
"Failed to generate input schema for block %s: %s",
block_id,
e,
)
try: try:
output_schema = block.output_schema.jsonschema() output_schema = block.output_schema.jsonschema()
except Exception: except Exception as e:
pass logger.debug(
"Failed to generate output schema for block %s: %s",
block_id,
e,
)
# Get categories from block instance # Get categories from block instance
categories = [] categories = []
@@ -163,6 +202,19 @@ class FindBlockTool(BaseTool):
) )
) )
if len(blocks) >= _TARGET_RESULTS:
break
if blocks and len(blocks) < _TARGET_RESULTS:
logger.debug(
"find_block returned %d/%d results for query '%s' "
"(filtered %d excluded/disabled blocks)",
len(blocks),
_TARGET_RESULTS,
query,
len(results) - len(blocks),
)
if not blocks: if not blocks:
return NoResultsResponse( return NoResultsResponse(
message=f"No blocks found for '{query}'", message=f"No blocks found for '{query}'",

View File

@@ -0,0 +1,139 @@
"""Tests for block filtering in FindBlockTool."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.api.features.chat.tools.find_block import (
COPILOT_EXCLUDED_BLOCK_IDS,
COPILOT_EXCLUDED_BLOCK_TYPES,
FindBlockTool,
)
from backend.api.features.chat.tools.models import BlockListResponse
from backend.data.block import BlockType
from ._test_data import make_session
_TEST_USER_ID = "test-user-find-block"
def make_mock_block(
block_id: str, name: str, block_type: BlockType, disabled: bool = False
):
"""Create a mock block for testing."""
mock = MagicMock()
mock.id = block_id
mock.name = name
mock.description = f"{name} description"
mock.block_type = block_type
mock.disabled = disabled
mock.input_schema = MagicMock()
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
mock.input_schema.get_credentials_fields.return_value = {}
mock.output_schema = MagicMock()
mock.output_schema.jsonschema.return_value = {}
mock.categories = []
return mock
class TestFindBlockFiltering:
"""Tests for block filtering in FindBlockTool."""
def test_excluded_block_types_contains_expected_types(self):
"""Verify COPILOT_EXCLUDED_BLOCK_TYPES contains all graph-only types."""
assert BlockType.INPUT in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.OUTPUT in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.WEBHOOK in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.WEBHOOK_MANUAL in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.NOTE in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES
def test_excluded_block_ids_contains_smart_decision_maker(self):
"""Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS
@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_type_filtered_from_results(self):
"""Verify blocks with excluded BlockTypes are filtered from search results."""
session = make_session(user_id=_TEST_USER_ID)
# Mock search returns an INPUT block (excluded) and a STANDARD block (included)
search_results = [
{"content_id": "input-block-id", "score": 0.9},
{"content_id": "standard-block-id", "score": 0.8},
]
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
standard_block = make_mock_block(
"standard-block-id", "HTTP Request", BlockType.STANDARD
)
def mock_get_block(block_id):
return {
"input-block-id": input_block,
"standard-block-id": standard_block,
}.get(block_id)
with patch(
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
new_callable=AsyncMock,
return_value=(search_results, 2),
):
with patch(
"backend.api.features.chat.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID, session=session, query="test"
)
# Should only return the standard block, not the INPUT block
assert isinstance(response, BlockListResponse)
assert len(response.blocks) == 1
assert response.blocks[0].id == "standard-block-id"
@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_id_filtered_from_results(self):
"""Verify SmartDecisionMakerBlock is filtered from search results."""
session = make_session(user_id=_TEST_USER_ID)
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
search_results = [
{"content_id": smart_decision_id, "score": 0.9},
{"content_id": "normal-block-id", "score": 0.8},
]
# SmartDecisionMakerBlock has STANDARD type but is excluded by ID
smart_block = make_mock_block(
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
)
normal_block = make_mock_block(
"normal-block-id", "Normal Block", BlockType.STANDARD
)
def mock_get_block(block_id):
return {
smart_decision_id: smart_block,
"normal-block-id": normal_block,
}.get(block_id)
with patch(
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
new_callable=AsyncMock,
return_value=(search_results, 2),
):
with patch(
"backend.api.features.chat.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID, session=session, query="decision"
)
# Should only return normal block, not SmartDecisionMakerBlock
assert isinstance(response, BlockListResponse)
assert len(response.blocks) == 1
assert response.blocks[0].id == "normal-block-id"

View File

@@ -0,0 +1,29 @@
"""Shared helpers for chat tools."""
from typing import Any
def get_inputs_from_schema(
input_schema: dict[str, Any],
exclude_fields: set[str] | None = None,
) -> list[dict[str, Any]]:
"""Extract input field info from JSON schema."""
if not isinstance(input_schema, dict):
return []
exclude = exclude_fields or set()
properties = input_schema.get("properties", {})
required = set(input_schema.get("required", []))
return [
{
"name": name,
"title": schema.get("title", name),
"type": schema.get("type", "string"),
"description": schema.get("description", ""),
"required": name in required,
"default": schema.get("default"),
}
for name, schema in properties.items()
if name not in exclude
]

View File

@@ -24,6 +24,7 @@ from backend.util.timezone_utils import (
) )
from .base import BaseTool from .base import BaseTool
from .helpers import get_inputs_from_schema
from .models import ( from .models import (
AgentDetails, AgentDetails,
AgentDetailsResponse, AgentDetailsResponse,
@@ -261,7 +262,7 @@ class RunAgentTool(BaseTool):
), ),
requirements={ requirements={
"credentials": requirements_creds_list, "credentials": requirements_creds_list,
"inputs": self._get_inputs_list(graph.input_schema), "inputs": get_inputs_from_schema(graph.input_schema),
"execution_modes": self._get_execution_modes(graph), "execution_modes": self._get_execution_modes(graph),
}, },
), ),
@@ -369,22 +370,6 @@ class RunAgentTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
"""Extract inputs list from schema."""
inputs_list = []
if isinstance(input_schema, dict) and "properties" in input_schema:
for field_name, field_schema in input_schema["properties"].items():
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in input_schema.get("required", []),
}
)
return inputs_list
def _get_execution_modes(self, graph: GraphModel) -> list[str]: def _get_execution_modes(self, graph: GraphModel) -> list[str]:
"""Get available execution modes for the graph.""" """Get available execution modes for the graph."""
trigger_info = graph.trigger_setup_info trigger_info = graph.trigger_setup_info
@@ -398,7 +383,7 @@ class RunAgentTool(BaseTool):
suffix: str, suffix: str,
) -> str: ) -> str:
"""Build a message describing available inputs for an agent.""" """Build a message describing available inputs for an agent."""
inputs_list = self._get_inputs_list(graph.input_schema) inputs_list = get_inputs_from_schema(graph.input_schema)
required_names = [i["name"] for i in inputs_list if i["required"]] required_names = [i["name"] for i in inputs_list if i["required"]]
optional_names = [i["name"] for i in inputs_list if not i["required"]] optional_names = [i["name"] for i in inputs_list if not i["required"]]

View File

@@ -8,14 +8,19 @@ from typing import Any
from pydantic_core import PydanticUndefined from pydantic_core import PydanticUndefined
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.data.block import get_block from backend.api.features.chat.tools.find_block import (
COPILOT_EXCLUDED_BLOCK_IDS,
COPILOT_EXCLUDED_BLOCK_TYPES,
)
from backend.data.block import AnyBlockSchema, get_block
from backend.data.execution import ExecutionContext from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsMetaInput from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError from backend.util.exceptions import BlockError
from .base import BaseTool from .base import BaseTool
from .helpers import get_inputs_from_schema
from .models import ( from .models import (
BlockOutputResponse, BlockOutputResponse,
ErrorResponse, ErrorResponse,
@@ -24,7 +29,10 @@ from .models import (
ToolResponseBase, ToolResponseBase,
UserReadiness, UserReadiness,
) )
from .utils import build_missing_credentials_from_field_info from .utils import (
build_missing_credentials_from_field_info,
match_credentials_to_requirements,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -73,91 +81,6 @@ class RunBlockTool(BaseTool):
def requires_auth(self) -> bool: def requires_auth(self) -> bool:
return True return True
async def _check_block_credentials(
self,
user_id: str,
block: Any,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Check if user has required credentials for a block.
Args:
user_id: User ID
block: Block to check credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple[matched_credentials, missing_credentials]
"""
matched_credentials: dict[str, CredentialsMetaInput] = {}
missing_credentials: list[CredentialsMetaInput] = []
input_data = input_data or {}
# Get credential field info from block's input schema
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return matched_credentials, missing_credentials
# Get user's available credentials
creds_manager = IntegrationCredentialsManager()
available_creds = await creds_manager.store.get_all_creds(user_id)
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
# Get discriminator from input, falling back to schema default
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in effective_field_info.provider
and cred.type in effective_field_info.supported_types
),
None,
)
if matching_cred:
matched_credentials[field_name] = CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
else:
# Create a placeholder for the missing credential
provider = next(iter(effective_field_info.provider), "unknown")
cred_type = next(iter(effective_field_info.supported_types), "api_key")
missing_credentials.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched_credentials, missing_credentials
async def _execute( async def _execute(
self, self,
user_id: str | None, user_id: str | None,
@@ -212,11 +135,24 @@ class RunBlockTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
# Check if block is excluded from CoPilot (graph-only blocks)
if (
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
):
return ErrorResponse(
message=(
f"Block '{block.name}' cannot be run directly in CoPilot. "
"This block is designed for use within graphs only."
),
session_id=session_id,
)
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}") logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
creds_manager = IntegrationCredentialsManager() creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = await self._check_block_credentials( matched_credentials, missing_credentials = (
user_id, block, input_data await self._resolve_block_credentials(user_id, block, input_data)
) )
if missing_credentials: if missing_credentials:
@@ -345,29 +281,75 @@ class RunBlockTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]: async def _resolve_block_credentials(
self,
user_id: str,
block: AnyBlockSchema,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Resolve credentials for a block by matching user's available credentials.
Args:
user_id: User ID
block: Block to resolve credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple of (matched_credentials, missing_credentials) - matched credentials
are used for block execution, missing ones indicate setup requirements.
"""
input_data = input_data or {}
requirements = self._resolve_discriminated_credentials(block, input_data)
if not requirements:
return {}, []
return await match_credentials_to_requirements(user_id, requirements)
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
"""Extract non-credential inputs from block schema.""" """Extract non-credential inputs from block schema."""
inputs_list = []
schema = block.input_schema.jsonschema() schema = block.input_schema.jsonschema()
properties = schema.get("properties", {})
required_fields = set(schema.get("required", []))
# Get credential field names to exclude
credentials_fields = set(block.input_schema.get_credentials_fields().keys()) credentials_fields = set(block.input_schema.get_credentials_fields().keys())
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
for field_name, field_schema in properties.items(): def _resolve_discriminated_credentials(
# Skip credential fields self,
if field_name in credentials_fields: block: AnyBlockSchema,
continue input_data: dict[str, Any],
) -> dict[str, CredentialsFieldInfo]:
"""Resolve credential requirements, applying discriminator logic where needed."""
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return {}
inputs_list.append( resolved: dict[str, CredentialsFieldInfo] = {}
{
"name": field_name, for field_name, field_info in credentials_fields_info.items():
"title": field_schema.get("title", field_name), effective_field_info = field_info
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""), if field_info.discriminator and field_info.discriminator_mapping:
"required": field_name in required_fields, discriminator_value = input_data.get(field_info.discriminator)
} if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
# For host-scoped credentials, add the discriminator value
# (e.g., URL) so _credential_is_for_host can match it
effective_field_info.discriminator_values.add(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
) )
return inputs_list resolved[field_name] = effective_field_info
return resolved

View File

@@ -0,0 +1,106 @@
"""Tests for block execution guards in RunBlockTool."""
from unittest.mock import MagicMock, patch
import pytest
from backend.api.features.chat.tools.models import ErrorResponse
from backend.api.features.chat.tools.run_block import RunBlockTool
from backend.data.block import BlockType
from ._test_data import make_session
_TEST_USER_ID = "test-user-run-block"
def make_mock_block(
block_id: str, name: str, block_type: BlockType, disabled: bool = False
):
"""Create a mock block for testing."""
mock = MagicMock()
mock.id = block_id
mock.name = name
mock.block_type = block_type
mock.disabled = disabled
mock.input_schema = MagicMock()
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
mock.input_schema.get_credentials_fields_info.return_value = []
return mock
class TestRunBlockFiltering:
"""Tests for block execution guards in RunBlockTool."""
@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_type_returns_error(self):
"""Attempting to execute a block with excluded BlockType returns error."""
session = make_session(user_id=_TEST_USER_ID)
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
return_value=input_block,
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="input-block-id",
input_data={},
)
assert isinstance(response, ErrorResponse)
assert "cannot be run directly in CoPilot" in response.message
assert "designed for use within graphs only" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_id_returns_error(self):
"""Attempting to execute SmartDecisionMakerBlock returns error."""
session = make_session(user_id=_TEST_USER_ID)
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
smart_block = make_mock_block(
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
return_value=smart_block,
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id=smart_decision_id,
input_data={},
)
assert isinstance(response, ErrorResponse)
assert "cannot be run directly in CoPilot" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_non_excluded_block_passes_guard(self):
"""Non-excluded blocks pass the filtering guard (may fail later for other reasons)."""
session = make_session(user_id=_TEST_USER_ID)
standard_block = make_mock_block(
"standard-id", "HTTP Request", BlockType.STANDARD
)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
return_value=standard_block,
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="standard-id",
input_data={},
)
# Should NOT be an ErrorResponse about CoPilot exclusion
# (may be other errors like missing credentials, but not the exclusion guard)
if isinstance(response, ErrorResponse):
assert "cannot be run directly in CoPilot" not in response.message

View File

@@ -8,12 +8,14 @@ from backend.api.features.library import model as library_model
from backend.api.features.store import db as store_db from backend.api.features.store import db as store_db
from backend.data.graph import GraphModel from backend.data.graph import GraphModel
from backend.data.model import ( from backend.data.model import (
Credentials,
CredentialsFieldInfo, CredentialsFieldInfo,
CredentialsMetaInput, CredentialsMetaInput,
HostScopedCredentials, HostScopedCredentials,
OAuth2Credentials, OAuth2Credentials,
) )
from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.exceptions import NotFoundError from backend.util.exceptions import NotFoundError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -223,6 +225,99 @@ async def get_or_create_library_agent(
return library_agents[0] return library_agents[0]
async def match_credentials_to_requirements(
user_id: str,
requirements: dict[str, CredentialsFieldInfo],
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Match user's credentials against a dictionary of credential requirements.
This is the core matching logic shared by both graph and block credential matching.
"""
matched: dict[str, CredentialsMetaInput] = {}
missing: list[CredentialsMetaInput] = []
if not requirements:
return matched, missing
available_creds = await get_user_credentials(user_id)
for field_name, field_info in requirements.items():
matching_cred = find_matching_credential(available_creds, field_info)
if matching_cred:
try:
matched[field_name] = create_credential_meta_from_match(matching_cred)
except Exception as e:
logger.error(
f"Failed to create CredentialsMetaInput for field '{field_name}': "
f"provider={matching_cred.provider}, type={matching_cred.type}, "
f"credential_id={matching_cred.id}",
exc_info=True,
)
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=f"{field_name} (validation failed: {e})",
)
)
else:
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched, missing
async def get_user_credentials(user_id: str) -> list[Credentials]:
"""Get all available credentials for a user."""
creds_manager = IntegrationCredentialsManager()
return await creds_manager.store.get_all_creds(user_id)
def find_matching_credential(
available_creds: list[Credentials],
field_info: CredentialsFieldInfo,
) -> Credentials | None:
"""Find a credential that matches the required provider, type, scopes, and host."""
for cred in available_creds:
if cred.provider not in field_info.provider:
continue
if cred.type not in field_info.supported_types:
continue
if cred.type == "oauth2" and not _credential_has_required_scopes(
cred, field_info
):
continue
if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
continue
return cred
return None
def create_credential_meta_from_match(
matching_cred: Credentials,
) -> CredentialsMetaInput:
"""Create a CredentialsMetaInput from a matched credential."""
return CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
async def match_user_credentials_to_graph( async def match_user_credentials_to_graph(
user_id: str, user_id: str,
graph: GraphModel, graph: GraphModel,
@@ -265,7 +360,7 @@ async def match_user_credentials_to_graph(
_, _,
_, _,
) in aggregated_creds.items(): ) in aggregated_creds.items():
# Find first matching credential by provider, type, and scopes # Find first matching credential by provider, type, scopes, and host/URL
matching_cred = next( matching_cred = next(
( (
cred cred
@@ -280,6 +375,10 @@ async def match_user_credentials_to_graph(
cred.type != "host_scoped" cred.type != "host_scoped"
or _credential_is_for_host(cred, credential_requirements) or _credential_is_for_host(cred, credential_requirements)
) )
and (
cred.provider != ProviderName.MCP
or _credential_is_for_mcp_server(cred, credential_requirements)
)
), ),
None, None,
) )
@@ -331,8 +430,6 @@ def _credential_has_required_scopes(
# If no scopes are required, any credential matches # If no scopes are required, any credential matches
if not requirements.required_scopes: if not requirements.required_scopes:
return True return True
# Check that credential scopes are a superset of required scopes
return set(credential.scopes).issuperset(requirements.required_scopes) return set(credential.scopes).issuperset(requirements.required_scopes)
@@ -352,6 +449,22 @@ def _credential_is_for_host(
return credential.matches_url(list(requirements.discriminator_values)[0]) return credential.matches_url(list(requirements.discriminator_values)[0])
def _credential_is_for_mcp_server(
credential: Credentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""Check if an MCP OAuth credential matches the required server URL."""
if not requirements.discriminator_values:
return True
server_url = (
credential.metadata.get("mcp_server_url")
if isinstance(credential, OAuth2Credentials)
else None
)
return server_url in requirements.discriminator_values if server_url else False
async def check_user_has_required_credentials( async def check_user_has_required_credentials(
user_id: str, user_id: str,
required_credentials: list[CredentialsMetaInput], required_credentials: list[CredentialsMetaInput],

View File

@@ -1,7 +1,7 @@
import asyncio import asyncio
import logging import logging
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Annotated, List, Literal from typing import TYPE_CHECKING, Annotated, Any, List, Literal
from autogpt_libs.auth import get_user_id from autogpt_libs.auth import get_user_id
from fastapi import ( from fastapi import (
@@ -14,7 +14,7 @@ from fastapi import (
Security, Security,
status, status,
) )
from pydantic import BaseModel, Field, SecretStr from pydantic import BaseModel, Field, SecretStr, model_validator
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
from backend.api.features.library.db import set_preset_webhook, update_preset from backend.api.features.library.db import set_preset_webhook, update_preset
@@ -39,7 +39,11 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.credentials_store import provider_matches
from backend.integrations.creds_manager import (
IntegrationCredentialsManager,
create_mcp_oauth_handler,
)
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager from backend.integrations.webhooks import get_webhook_manager
@@ -102,9 +106,37 @@ class CredentialsMetaResponse(BaseModel):
scopes: list[str] | None scopes: list[str] | None
username: str | None username: str | None
host: str | None = Field( host: str | None = Field(
default=None, description="Host pattern for host-scoped credentials" default=None,
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
) )
@model_validator(mode="before")
@classmethod
def _normalize_provider(cls, data: Any) -> Any:
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug."""
if isinstance(data, dict):
prov = data.get("provider", "")
if isinstance(prov, str) and prov.startswith("ProviderName."):
member = prov.removeprefix("ProviderName.")
try:
data = {**data, "provider": ProviderName[member].value}
except KeyError:
pass
return data
@staticmethod
def get_host(cred: Credentials) -> str | None:
"""Extract host from credential: HostScoped host or MCP server URL."""
if isinstance(cred, HostScopedCredentials):
return cred.host
if isinstance(cred, OAuth2Credentials) and cred.provider in (
ProviderName.MCP,
ProviderName.MCP.value,
"ProviderName.MCP",
):
return (cred.metadata or {}).get("mcp_server_url")
return None
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens") @router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
async def callback( async def callback(
@@ -179,9 +211,7 @@ async def callback(
title=credentials.title, title=credentials.title,
scopes=credentials.scopes, scopes=credentials.scopes,
username=credentials.username, username=credentials.username,
host=( host=(CredentialsMetaResponse.get_host(credentials)),
credentials.host if isinstance(credentials, HostScopedCredentials) else None
),
) )
@@ -199,7 +229,7 @@ async def list_credentials(
title=cred.title, title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None, scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None, username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=cred.host if isinstance(cred, HostScopedCredentials) else None, host=CredentialsMetaResponse.get_host(cred),
) )
for cred in credentials for cred in credentials
] ]
@@ -222,7 +252,7 @@ async def list_credentials_by_provider(
title=cred.title, title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None, scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None, username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=cred.host if isinstance(cred, HostScopedCredentials) else None, host=CredentialsMetaResponse.get_host(cred),
) )
for cred in credentials for cred in credentials
] ]
@@ -322,6 +352,10 @@ async def delete_credentials(
tokens_revoked = None tokens_revoked = None
if isinstance(creds, OAuth2Credentials): if isinstance(creds, OAuth2Credentials):
if provider_matches(provider.value, ProviderName.MCP.value):
# MCP uses dynamic per-server OAuth — create handler from metadata
handler = create_mcp_oauth_handler(creds)
else:
handler = _get_provider_oauth_handler(request, provider) handler = _get_provider_oauth_handler(request, provider)
tokens_revoked = await handler.revoke_tokens(creds) tokens_revoked = await handler.revoke_tokens(creds)

View File

@@ -0,0 +1,402 @@
"""
MCP (Model Context Protocol) API routes.
Provides endpoints for MCP tool discovery and OAuth authentication so the
frontend can list available tools on an MCP server before placing a block.
"""
import logging
from typing import Annotated, Any
from urllib.parse import urlparse
import fastapi
from autogpt_libs.auth import get_user_id
from fastapi import Security
from pydantic import BaseModel, Field
from backend.api.features.integrations.router import CredentialsMetaResponse
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
router = fastapi.APIRouter(tags=["mcp"])
creds_manager = IntegrationCredentialsManager()
# ====================== Tool Discovery ====================== #
class DiscoverToolsRequest(BaseModel):
"""Request to discover tools on an MCP server."""
server_url: str = Field(description="URL of the MCP server")
auth_token: str | None = Field(
default=None,
description="Optional Bearer token for authenticated MCP servers",
)
class MCPToolResponse(BaseModel):
"""A single MCP tool returned by discovery."""
name: str
description: str
input_schema: dict[str, Any]
class DiscoverToolsResponse(BaseModel):
"""Response containing the list of tools available on an MCP server."""
tools: list[MCPToolResponse]
server_name: str | None = None
protocol_version: str | None = None
@router.post(
"/discover-tools",
summary="Discover available tools on an MCP server",
response_model=DiscoverToolsResponse,
)
async def discover_tools(
request: DiscoverToolsRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> DiscoverToolsResponse:
"""
Connect to an MCP server and return its available tools.
If the user has a stored MCP credential for this server URL, it will be
used automatically — no need to pass an explicit auth token.
"""
auth_token = request.auth_token
# Auto-use stored MCP credential when no explicit token is provided.
if not auth_token:
try:
mcp_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
# Find the freshest credential for this server URL
best_cred: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and cred.metadata.get("mcp_server_url") == request.server_url
):
if best_cred is None or (
(cred.access_token_expires_at or 0)
> (best_cred.access_token_expires_at or 0)
):
best_cred = cred
if best_cred:
# Refresh the token if expired before using it
best_cred = await creds_manager.refresh_if_needed(user_id, best_cred)
logger.info(
f"Using MCP credential {best_cred.id} for {request.server_url}, "
f"expires_at={best_cred.access_token_expires_at}"
)
auth_token = best_cred.access_token.get_secret_value()
except Exception:
logger.debug("Could not look up stored MCP credentials", exc_info=True)
client = MCPClient(request.server_url, auth_token=auth_token)
try:
init_result = await client.initialize()
tools = await client.list_tools()
except HTTPClientError as e:
if e.status_code in (401, 403):
raise fastapi.HTTPException(
status_code=401,
detail="This MCP server requires authentication. "
"Please provide a valid auth token.",
)
raise fastapi.HTTPException(status_code=502, detail=str(e))
except MCPClientError as e:
raise fastapi.HTTPException(status_code=502, detail=str(e))
except Exception as e:
raise fastapi.HTTPException(
status_code=502,
detail=f"Failed to connect to MCP server: {e}",
)
return DiscoverToolsResponse(
tools=[
MCPToolResponse(
name=t.name,
description=t.description,
input_schema=t.input_schema,
)
for t in tools
],
server_name=(
init_result.get("serverInfo", {}).get("name")
or urlparse(request.server_url).hostname
or "MCP"
),
protocol_version=init_result.get("protocolVersion"),
)
# ======================== OAuth Flow ======================== #
class MCPOAuthLoginRequest(BaseModel):
"""Request to start an OAuth flow for an MCP server."""
server_url: str = Field(description="URL of the MCP server that requires OAuth")
class MCPOAuthLoginResponse(BaseModel):
"""Response with the OAuth login URL for the user to authenticate."""
login_url: str
state_token: str
@router.post(
"/oauth/login",
summary="Initiate OAuth login for an MCP server",
)
async def mcp_oauth_login(
request: MCPOAuthLoginRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> MCPOAuthLoginResponse:
"""
Discover OAuth metadata from the MCP server and return a login URL.
1. Discovers the protected-resource metadata (RFC 9728)
2. Fetches the authorization server metadata (RFC 8414)
3. Performs Dynamic Client Registration (RFC 7591) if available
4. Returns the authorization URL for the frontend to open in a popup
"""
client = MCPClient(request.server_url)
# Step 1: Discover protected-resource metadata (RFC 9728)
protected_resource = await client.discover_auth()
metadata: dict[str, Any] | None = None
if protected_resource and protected_resource.get("authorization_servers"):
auth_server_url = protected_resource["authorization_servers"][0]
resource_url = protected_resource.get("resource", request.server_url)
# Step 2a: Discover auth-server metadata (RFC 8414)
metadata = await client.discover_auth_server_metadata(auth_server_url)
else:
# Fallback: Some MCP servers (e.g. Linear) are their own auth server
# and serve OAuth metadata directly without protected-resource metadata.
# Don't assume a resource_url — omitting it lets the auth server choose
# the correct audience for the token (RFC 8707 resource is optional).
resource_url = None
metadata = await client.discover_auth_server_metadata(request.server_url)
if (
not metadata
or "authorization_endpoint" not in metadata
or "token_endpoint" not in metadata
):
raise fastapi.HTTPException(
status_code=400,
detail="This MCP server does not advertise OAuth support. "
"You may need to provide an auth token manually.",
)
authorize_url = metadata["authorization_endpoint"]
token_url = metadata["token_endpoint"]
registration_endpoint = metadata.get("registration_endpoint")
revoke_url = metadata.get("revocation_endpoint")
# Step 3: Dynamic Client Registration (RFC 7591) if available
frontend_base_url = settings.config.frontend_base_url
if not frontend_base_url:
raise fastapi.HTTPException(
status_code=500,
detail="Frontend base URL is not configured.",
)
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
client_id = ""
client_secret = ""
if registration_endpoint:
reg_result = await _register_mcp_client(
registration_endpoint, redirect_uri, request.server_url
)
if reg_result:
client_id = reg_result.get("client_id", "")
client_secret = reg_result.get("client_secret", "")
if not client_id:
client_id = "autogpt-platform"
# Step 4: Store state token with OAuth metadata for the callback
scopes = (protected_resource or {}).get("scopes_supported") or metadata.get(
"scopes_supported", []
)
state_token, code_challenge = await creds_manager.store.store_state_token(
user_id,
ProviderName.MCP.value,
scopes,
state_metadata={
"authorize_url": authorize_url,
"token_url": token_url,
"revoke_url": revoke_url,
"resource_url": resource_url,
"server_url": request.server_url,
"client_id": client_id,
"client_secret": client_secret,
},
)
# Step 5: Build and return the login URL
handler = MCPOAuthHandler(
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
authorize_url=authorize_url,
token_url=token_url,
resource_url=resource_url,
)
login_url = handler.get_login_url(
scopes, state_token, code_challenge=code_challenge
)
return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token)
class MCPOAuthCallbackRequest(BaseModel):
"""Request to exchange an OAuth code for tokens."""
code: str = Field(description="Authorization code from OAuth callback")
state_token: str = Field(description="State token for CSRF verification")
class MCPOAuthCallbackResponse(BaseModel):
"""Response after successfully storing OAuth credentials."""
credential_id: str
@router.post(
"/oauth/callback",
summary="Exchange OAuth code for MCP tokens",
)
async def mcp_oauth_callback(
request: MCPOAuthCallbackRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> CredentialsMetaResponse:
"""
Exchange the authorization code for tokens and store the credential.
The frontend calls this after receiving the OAuth code from the popup.
On success, subsequent ``/discover-tools`` calls for the same server URL
will automatically use the stored credential.
"""
valid_state = await creds_manager.store.verify_state_token(
user_id, request.state_token, ProviderName.MCP.value
)
if not valid_state:
raise fastapi.HTTPException(
status_code=400,
detail="Invalid or expired state token.",
)
meta = valid_state.state_metadata
frontend_base_url = settings.config.frontend_base_url
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
handler = MCPOAuthHandler(
client_id=meta["client_id"],
client_secret=meta.get("client_secret", ""),
redirect_uri=redirect_uri,
authorize_url=meta["authorize_url"],
token_url=meta["token_url"],
revoke_url=meta.get("revoke_url"),
resource_url=meta.get("resource_url"),
)
try:
credentials = await handler.exchange_code_for_tokens(
request.code, valid_state.scopes, valid_state.code_verifier
)
except Exception as e:
raise fastapi.HTTPException(
status_code=400,
detail=f"OAuth token exchange failed: {e}",
)
# Enrich credential metadata for future lookup and token refresh
if credentials.metadata is None:
credentials.metadata = {}
credentials.metadata["mcp_server_url"] = meta["server_url"]
credentials.metadata["mcp_client_id"] = meta["client_id"]
credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "")
credentials.metadata["mcp_token_url"] = meta["token_url"]
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
credentials.title = f"MCP: {hostname}"
# Remove old MCP credentials for the same server to prevent stale token buildup.
try:
old_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
for old in old_creds:
if (
isinstance(old, OAuth2Credentials)
and old.metadata.get("mcp_server_url") == meta["server_url"]
):
await creds_manager.store.delete_creds_by_id(user_id, old.id)
logger.info(
f"Removed old MCP credential {old.id} for {meta['server_url']}"
)
except Exception:
logger.debug("Could not clean up old MCP credentials", exc_info=True)
await creds_manager.create(user_id, credentials)
return CredentialsMetaResponse(
id=credentials.id,
provider=credentials.provider,
type=credentials.type,
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=credentials.metadata.get("mcp_server_url"),
)
# ======================== Helpers ======================== #
async def _register_mcp_client(
registration_endpoint: str,
redirect_uri: str,
server_url: str,
) -> dict[str, Any] | None:
"""Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server."""
try:
response = await Requests(raise_for_status=True).post(
registration_endpoint,
json={
"client_name": "AutoGPT Platform",
"redirect_uris": [redirect_uri],
"grant_types": ["authorization_code"],
"response_types": ["code"],
"token_endpoint_auth_method": "client_secret_post",
},
)
data = response.json()
if isinstance(data, dict) and "client_id" in data:
return data
return None
except Exception as e:
logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
return None

View File

@@ -0,0 +1,390 @@
"""Tests for MCP API routes."""
from unittest.mock import AsyncMock, patch
import fastapi
import fastapi.testclient
from autogpt_libs.auth import get_user_id
from backend.api.features.mcp.routes import router
from backend.blocks.mcp.client import MCPClientError, MCPTool
from backend.util.request import HTTPClientError
app = fastapi.FastAPI()
app.include_router(router)
app.dependency_overrides[get_user_id] = lambda: "test-user-id"
client = fastapi.testclient.TestClient(app)
class TestDiscoverTools:
def test_discover_tools_success(self):
mock_tools = [
MCPTool(
name="get_weather",
description="Get weather for a city",
input_schema={
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
),
MCPTool(
name="add_numbers",
description="Add two numbers",
input_schema={
"type": "object",
"properties": {
"a": {"type": "number"},
"b": {"type": "number"},
},
},
),
]
with (patch("backend.api.features.mcp.routes.MCPClient") as MockClient,):
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={
"protocolVersion": "2025-03-26",
"serverInfo": {"name": "test-server"},
}
)
instance.list_tools = AsyncMock(return_value=mock_tools)
response = client.post(
"/discover-tools",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
data = response.json()
assert len(data["tools"]) == 2
assert data["tools"][0]["name"] == "get_weather"
assert data["tools"][1]["name"] == "add_numbers"
assert data["server_name"] == "test-server"
assert data["protocol_version"] == "2025-03-26"
def test_discover_tools_with_auth_token(self):
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
)
instance.list_tools = AsyncMock(return_value=[])
response = client.post(
"/discover-tools",
json={
"server_url": "https://mcp.example.com/mcp",
"auth_token": "my-secret-token",
},
)
assert response.status_code == 200
MockClient.assert_called_once_with(
"https://mcp.example.com/mcp",
auth_token="my-secret-token",
)
def test_discover_tools_auto_uses_stored_credential(self):
"""When no explicit token is given, stored MCP credentials are used."""
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
stored_cred = OAuth2Credentials(
provider="mcp",
title="MCP: example.com",
access_token=SecretStr("stored-token-123"),
refresh_token=None,
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=[],
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
)
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred)
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
)
instance.list_tools = AsyncMock(return_value=[])
response = client.post(
"/discover-tools",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
MockClient.assert_called_once_with(
"https://mcp.example.com/mcp",
auth_token="stored-token-123",
)
def test_discover_tools_mcp_error(self):
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=MCPClientError("Connection refused")
)
response = client.post(
"/discover-tools",
json={"server_url": "https://bad-server.example.com/mcp"},
)
assert response.status_code == 502
assert "Connection refused" in response.json()["detail"]
def test_discover_tools_generic_error(self):
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
instance = MockClient.return_value
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
response = client.post(
"/discover-tools",
json={"server_url": "https://timeout.example.com/mcp"},
)
assert response.status_code == 502
assert "Failed to connect" in response.json()["detail"]
def test_discover_tools_auth_required(self):
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
)
response = client.post(
"/discover-tools",
json={"server_url": "https://auth-server.example.com/mcp"},
)
assert response.status_code == 401
assert "requires authentication" in response.json()["detail"]
def test_discover_tools_forbidden(self):
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
)
response = client.post(
"/discover-tools",
json={"server_url": "https://auth-server.example.com/mcp"},
)
assert response.status_code == 401
assert "requires authentication" in response.json()["detail"]
def test_discover_tools_missing_url(self):
response = client.post("/discover-tools", json={})
assert response.status_code == 422
class TestOAuthLogin:
def test_oauth_login_success(self):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch(
"backend.api.features.mcp.routes._register_mcp_client"
) as mock_register,
):
instance = MockClient.return_value
instance.discover_auth = AsyncMock(
return_value={
"authorization_servers": ["https://auth.sentry.io"],
"resource": "https://mcp.sentry.dev/mcp",
"scopes_supported": ["openid"],
}
)
instance.discover_auth_server_metadata = AsyncMock(
return_value={
"authorization_endpoint": "https://auth.sentry.io/authorize",
"token_endpoint": "https://auth.sentry.io/token",
"registration_endpoint": "https://auth.sentry.io/register",
}
)
mock_register.return_value = {
"client_id": "registered-client-id",
"client_secret": "registered-secret",
}
mock_cm.store.store_state_token = AsyncMock(
return_value=("state-token-123", "code-challenge-abc")
)
mock_settings.config.frontend_base_url = "http://localhost:3000"
response = client.post(
"/oauth/login",
json={"server_url": "https://mcp.sentry.dev/mcp"},
)
assert response.status_code == 200
data = response.json()
assert "login_url" in data
assert data["state_token"] == "state-token-123"
assert "auth.sentry.io/authorize" in data["login_url"]
assert "registered-client-id" in data["login_url"]
def test_oauth_login_no_oauth_support(self):
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
instance = MockClient.return_value
instance.discover_auth = AsyncMock(return_value=None)
instance.discover_auth_server_metadata = AsyncMock(return_value=None)
response = client.post(
"/oauth/login",
json={"server_url": "https://simple-server.example.com/mcp"},
)
assert response.status_code == 400
assert "does not advertise OAuth" in response.json()["detail"]
def test_oauth_login_fallback_to_public_client(self):
"""When DCR is unavailable, falls back to default public client ID."""
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
):
instance = MockClient.return_value
instance.discover_auth = AsyncMock(
return_value={
"authorization_servers": ["https://auth.example.com"],
"resource": "https://mcp.example.com/mcp",
}
)
instance.discover_auth_server_metadata = AsyncMock(
return_value={
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
# No registration_endpoint
}
)
mock_cm.store.store_state_token = AsyncMock(
return_value=("state-abc", "challenge-xyz")
)
mock_settings.config.frontend_base_url = "http://localhost:3000"
response = client.post(
"/oauth/login",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
data = response.json()
assert "autogpt-platform" in data["login_url"]
class TestOAuthCallback:
def test_oauth_callback_success(self):
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
mock_creds = OAuth2Credentials(
provider="mcp",
title=None,
access_token=SecretStr("access-token-xyz"),
refresh_token=None,
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=[],
metadata={
"mcp_token_url": "https://auth.sentry.io/token",
"mcp_resource_url": "https://mcp.sentry.dev/mcp",
},
)
with (
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
):
mock_settings.config.frontend_base_url = "http://localhost:3000"
# Mock state verification
mock_state = AsyncMock()
mock_state.state_metadata = {
"authorize_url": "https://auth.sentry.io/authorize",
"token_url": "https://auth.sentry.io/token",
"client_id": "test-client-id",
"client_secret": "test-secret",
"server_url": "https://mcp.sentry.dev/mcp",
}
mock_state.scopes = ["openid"]
mock_state.code_verifier = "verifier-123"
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
mock_cm.create = AsyncMock()
handler_instance = MockHandler.return_value
handler_instance.exchange_code_for_tokens = AsyncMock(
return_value=mock_creds
)
# Mock old credential cleanup
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
response = client.post(
"/oauth/callback",
json={"code": "auth-code-abc", "state_token": "state-token-123"},
)
assert response.status_code == 200
data = response.json()
assert "id" in data
assert data["provider"] == "mcp"
assert data["type"] == "oauth2"
mock_cm.create.assert_called_once()
def test_oauth_callback_invalid_state(self):
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
mock_cm.store.verify_state_token = AsyncMock(return_value=None)
response = client.post(
"/oauth/callback",
json={"code": "auth-code", "state_token": "bad-state"},
)
assert response.status_code == 400
assert "Invalid or expired" in response.json()["detail"]
def test_oauth_callback_token_exchange_fails(self):
with (
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
):
mock_settings.config.frontend_base_url = "http://localhost:3000"
mock_state = AsyncMock()
mock_state.state_metadata = {
"authorize_url": "https://auth.example.com/authorize",
"token_url": "https://auth.example.com/token",
"client_id": "cid",
"server_url": "https://mcp.example.com/mcp",
}
mock_state.scopes = []
mock_state.code_verifier = "v"
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
handler_instance = MockHandler.return_value
handler_instance.exchange_code_for_tokens = AsyncMock(
side_effect=RuntimeError("Token exchange failed")
)
response = client.post(
"/oauth/callback",
json={"code": "bad-code", "state_token": "state"},
)
assert response.status_code == 400
assert "token exchange failed" in response.json()["detail"].lower()

View File

@@ -8,6 +8,7 @@ Includes BM25 reranking for improved lexical relevance.
import logging import logging
import re import re
import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Literal from typing import Any, Literal
@@ -362,7 +363,11 @@ async def unified_hybrid_search(
LIMIT {limit_param} OFFSET {offset_param} LIMIT {limit_param} OFFSET {offset_param}
""" """
try:
results = await query_raw_with_schema(sql_query, *params) results = await query_raw_with_schema(sql_query, *params)
except Exception as e:
await _log_vector_error_diagnostics(e)
raise
total = results[0]["total_count"] if results else 0 total = results[0]["total_count"] if results else 0
# Apply BM25 reranking # Apply BM25 reranking
@@ -686,7 +691,11 @@ async def hybrid_search(
LIMIT {limit_param} OFFSET {offset_param} LIMIT {limit_param} OFFSET {offset_param}
""" """
try:
results = await query_raw_with_schema(sql_query, *params) results = await query_raw_with_schema(sql_query, *params)
except Exception as e:
await _log_vector_error_diagnostics(e)
raise
total = results[0]["total_count"] if results else 0 total = results[0]["total_count"] if results else 0
@@ -718,6 +727,87 @@ async def hybrid_search_simple(
return await hybrid_search(query=query, page=page, page_size=page_size) return await hybrid_search(query=query, page=page, page_size=page_size)
# ============================================================================
# Diagnostics
# ============================================================================
# Rate limit: only log vector error diagnostics once per this interval
_VECTOR_DIAG_INTERVAL_SECONDS = 60
_last_vector_diag_time: float = 0
async def _log_vector_error_diagnostics(error: Exception) -> None:
"""Log diagnostic info when 'type vector does not exist' error occurs.
Note: Diagnostic queries use query_raw_with_schema which may run on a different
pooled connection than the one that failed. Session-level search_path can differ,
so these diagnostics show cluster-wide state, not necessarily the failed session.
Includes rate limiting to avoid log spam - only logs once per minute.
Caller should re-raise the error after calling this function.
"""
global _last_vector_diag_time
# Check if this is the vector type error
error_str = str(error).lower()
if not (
"type" in error_str and "vector" in error_str and "does not exist" in error_str
):
return
# Rate limit: only log once per interval
now = time.time()
if now - _last_vector_diag_time < _VECTOR_DIAG_INTERVAL_SECONDS:
return
_last_vector_diag_time = now
try:
diagnostics: dict[str, object] = {}
try:
search_path_result = await query_raw_with_schema("SHOW search_path")
diagnostics["search_path"] = search_path_result
except Exception as e:
diagnostics["search_path"] = f"Error: {e}"
try:
schema_result = await query_raw_with_schema("SELECT current_schema()")
diagnostics["current_schema"] = schema_result
except Exception as e:
diagnostics["current_schema"] = f"Error: {e}"
try:
user_result = await query_raw_with_schema(
"SELECT current_user, session_user, current_database()"
)
diagnostics["user_info"] = user_result
except Exception as e:
diagnostics["user_info"] = f"Error: {e}"
try:
# Check pgvector extension installation (cluster-wide, stable info)
ext_result = await query_raw_with_schema(
"SELECT extname, extversion, nspname as schema "
"FROM pg_extension e "
"JOIN pg_namespace n ON e.extnamespace = n.oid "
"WHERE extname = 'vector'"
)
diagnostics["pgvector_extension"] = ext_result
except Exception as e:
diagnostics["pgvector_extension"] = f"Error: {e}"
logger.error(
f"Vector type error diagnostics:\n"
f" Error: {error}\n"
f" search_path: {diagnostics.get('search_path')}\n"
f" current_schema: {diagnostics.get('current_schema')}\n"
f" user_info: {diagnostics.get('user_info')}\n"
f" pgvector_extension: {diagnostics.get('pgvector_extension')}"
)
except Exception as diag_error:
logger.error(f"Failed to collect vector error diagnostics: {diag_error}")
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights # Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
# for existing code that expects the popularity parameter # for existing code that expects the popularity parameter
HybridSearchWeights = StoreAgentSearchWeights HybridSearchWeights = StoreAgentSearchWeights

View File

@@ -26,6 +26,7 @@ import backend.api.features.executions.review.routes
import backend.api.features.library.db import backend.api.features.library.db
import backend.api.features.library.model import backend.api.features.library.model
import backend.api.features.library.routes import backend.api.features.library.routes
import backend.api.features.mcp.routes as mcp_routes
import backend.api.features.oauth import backend.api.features.oauth
import backend.api.features.otto.routes import backend.api.features.otto.routes
import backend.api.features.postmark.postmark import backend.api.features.postmark.postmark
@@ -343,6 +344,11 @@ app.include_router(
tags=["workspace"], tags=["workspace"],
prefix="/api/workspace", prefix="/api/workspace",
) )
app.include_router(
mcp_routes.router,
tags=["v2", "mcp"],
prefix="/api/mcp",
)
app.include_router( app.include_router(
backend.api.features.oauth.router, backend.api.features.oauth.router,
tags=["oauth"], tags=["oauth"],

View File

@@ -0,0 +1,301 @@
"""
MCP (Model Context Protocol) Tool Block.
A single dynamic block that can connect to any MCP server, discover available tools,
and execute them. Works like AgentExecutorBlock — the user selects a tool from a
dropdown and the input/output schema adapts dynamically.
"""
import json
import logging
from typing import Any, Literal
from pydantic import SecretStr
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.data.block import (
Block,
BlockCategory,
BlockInput,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
BlockType,
)
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
OAuth2Credentials,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.json import validate_with_jsonschema
logger = logging.getLogger(__name__)
TEST_CREDENTIALS = OAuth2Credentials(
id="test-mcp-cred",
provider="mcp",
access_token=SecretStr("mock-mcp-token"),
refresh_token=SecretStr("mock-refresh"),
scopes=[],
title="Mock MCP credential",
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
MCPCredentials = CredentialsMetaInput[Literal[ProviderName.MCP], Literal["oauth2"]]
class MCPToolBlock(Block):
"""
A block that connects to an MCP server, lets the user pick a tool,
and executes it with dynamic input/output schema.
The flow:
1. User provides an MCP server URL (and optional credentials)
2. Frontend calls the backend to get tool list from that URL
3. User selects a tool from a dropdown (available_tools)
4. The block's input schema updates to reflect the selected tool's parameters
5. On execution, the block calls the MCP server to run the tool
"""
class Input(BlockSchemaInput):
server_url: str = SchemaField(
description="URL of the MCP server (Streamable HTTP endpoint)",
placeholder="https://mcp.example.com/mcp",
)
credentials: MCPCredentials = CredentialsField(
discriminator="server_url",
description="MCP server OAuth credentials",
default={},
)
selected_tool: str = SchemaField(
description="The MCP tool to execute",
placeholder="Select a tool",
default="",
)
tool_input_schema: dict[str, Any] = SchemaField(
description="JSON Schema for the selected tool's input parameters. "
"Populated automatically when a tool is selected.",
default={},
hidden=True,
)
tool_arguments: dict[str, Any] = SchemaField(
description="Arguments to pass to the selected MCP tool. "
"The fields here are defined by the tool's input schema.",
default={},
)
@classmethod
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
"""Return the tool's input schema so the builder UI renders dynamic fields."""
return data.get("tool_input_schema", {})
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
"""Return the current tool_arguments as defaults for the dynamic fields."""
return data.get("tool_arguments", {})
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
"""Check which required tool arguments are missing."""
required_fields = cls.get_input_schema(data).get("required", [])
tool_arguments = data.get("tool_arguments", {})
return set(required_fields) - set(tool_arguments)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
"""Validate tool_arguments against the tool's input schema."""
tool_schema = cls.get_input_schema(data)
if not tool_schema:
return None
tool_arguments = data.get("tool_arguments", {})
return validate_with_jsonschema(tool_schema, tool_arguments)
class Output(BlockSchemaOutput):
result: Any = SchemaField(description="The result returned by the MCP tool")
error: str = SchemaField(description="Error message if the tool call failed")
def __init__(self):
super().__init__(
id="a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4",
description="Connect to any MCP server and execute its tools. "
"Provide a server URL, select a tool, and pass arguments dynamically.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=MCPToolBlock.Input,
output_schema=MCPToolBlock.Output,
block_type=BlockType.STANDARD,
test_credentials=TEST_CREDENTIALS,
test_input={
"server_url": "https://mcp.example.com/mcp",
"credentials": TEST_CREDENTIALS_INPUT,
"selected_tool": "get_weather",
"tool_input_schema": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
"tool_arguments": {"city": "London"},
},
test_output=[
(
"result",
{"weather": "sunny", "temperature": 20},
),
],
test_mock={
"_call_mcp_tool": lambda *a, **kw: {
"weather": "sunny",
"temperature": 20,
},
},
)
async def _call_mcp_tool(
self,
server_url: str,
tool_name: str,
arguments: dict[str, Any],
auth_token: str | None = None,
) -> Any:
"""Call a tool on the MCP server. Extracted for easy mocking in tests."""
client = MCPClient(server_url, auth_token=auth_token)
await client.initialize()
result = await client.call_tool(tool_name, arguments)
if result.is_error:
error_text = ""
for item in result.content:
if item.get("type") == "text":
error_text += item.get("text", "")
raise MCPClientError(
f"MCP tool '{tool_name}' returned an error: "
f"{error_text or 'Unknown error'}"
)
# Extract text content from the result
output_parts = []
for item in result.content:
if item.get("type") == "text":
text = item.get("text", "")
# Try to parse as JSON for structured output
try:
output_parts.append(json.loads(text))
except (json.JSONDecodeError, ValueError):
output_parts.append(text)
elif item.get("type") == "image":
output_parts.append(
{
"type": "image",
"data": item.get("data"),
"mimeType": item.get("mimeType"),
}
)
elif item.get("type") == "resource":
output_parts.append(item.get("resource", {}))
# If single result, unwrap
if len(output_parts) == 1:
return output_parts[0]
return output_parts if output_parts else None
@staticmethod
async def _auto_lookup_credential(
user_id: str, server_url: str
) -> "OAuth2Credentials | None":
"""Auto-lookup stored MCP credential for a server URL.
This is a fallback for nodes that don't have ``credentials`` explicitly
set (e.g. nodes created before the credential field was wired up).
"""
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
try:
mgr = IntegrationCredentialsManager()
mcp_creds = await mgr.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
best: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and cred.metadata.get("mcp_server_url") == server_url
):
if best is None or (
(cred.access_token_expires_at or 0)
> (best.access_token_expires_at or 0)
):
best = cred
if best:
best = await mgr.refresh_if_needed(user_id, best)
logger.info(
"Auto-resolved MCP credential %s for %s", best.id, server_url
)
return best
except Exception:
logger.debug("Auto-lookup MCP credential failed", exc_info=True)
return None
async def run(
self,
input_data: Input,
*,
user_id: str,
credentials: OAuth2Credentials | None = None,
**kwargs,
) -> BlockOutput:
if not input_data.server_url:
yield "error", "MCP server URL is required"
return
if not input_data.selected_tool:
yield "error", "No tool selected. Please select a tool from the dropdown."
return
# Validate required tool arguments before calling the server.
# The executor-level validation is bypassed for MCP blocks because
# get_input_defaults() flattens tool_arguments, stripping tool_input_schema
# from the validation context.
required = set(input_data.tool_input_schema.get("required", []))
if required:
missing = required - set(input_data.tool_arguments.keys())
if missing:
yield "error", (
f"Missing required argument(s): {', '.join(sorted(missing))}. "
f"Please fill in all required fields marked with * in the block form."
)
return
# If no credentials were injected by the executor (e.g. legacy nodes
# that don't have the credentials field set), try to auto-lookup
# the stored MCP credential for this server URL.
if credentials is None:
credentials = await self._auto_lookup_credential(
user_id, input_data.server_url
)
auth_token = (
credentials.access_token.get_secret_value() if credentials else None
)
try:
result = await self._call_mcp_tool(
server_url=input_data.server_url,
tool_name=input_data.selected_tool,
arguments=input_data.tool_arguments,
auth_token=auth_token,
)
yield "result", result
except MCPClientError as e:
yield "error", str(e)
except Exception as e:
logger.exception(f"MCP tool call failed: {e}")
yield "error", f"MCP tool call failed: {str(e)}"

View File

@@ -0,0 +1,323 @@
"""
MCP (Model Context Protocol) HTTP client.
Implements the MCP Streamable HTTP transport for listing tools and calling tools
on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST.
Handles both JSON and SSE (text/event-stream) response formats per the MCP spec.
Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
"""
import json
import logging
from dataclasses import dataclass, field
from typing import Any
from backend.util.request import Requests
logger = logging.getLogger(__name__)
@dataclass
class MCPTool:
"""Represents an MCP tool discovered from a server."""
name: str
description: str
input_schema: dict[str, Any]
@dataclass
class MCPCallResult:
"""Result from calling an MCP tool."""
content: list[dict[str, Any]] = field(default_factory=list)
is_error: bool = False
class MCPClientError(Exception):
"""Raised when an MCP protocol error occurs."""
pass
class MCPClient:
"""
Async HTTP client for the MCP Streamable HTTP transport.
Communicates with MCP servers using JSON-RPC 2.0 over HTTP POST.
Supports optional Bearer token authentication.
"""
def __init__(
self,
server_url: str,
auth_token: str | None = None,
):
self.server_url = server_url.rstrip("/")
self.auth_token = auth_token
self._request_id = 0
self._session_id: str | None = None
def _next_id(self) -> int:
self._request_id += 1
return self._request_id
def _build_headers(self) -> dict[str, str]:
headers = {
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
}
if self.auth_token:
headers["Authorization"] = f"Bearer {self.auth_token}"
if self._session_id:
headers["Mcp-Session-Id"] = self._session_id
return headers
def _build_jsonrpc_request(
self, method: str, params: dict[str, Any] | None = None
) -> dict[str, Any]:
req: dict[str, Any] = {
"jsonrpc": "2.0",
"method": method,
"id": self._next_id(),
}
if params is not None:
req["params"] = params
return req
@staticmethod
def _parse_sse_response(text: str) -> dict[str, Any]:
"""Parse an SSE (text/event-stream) response body into JSON-RPC data.
MCP servers may return responses as SSE with format:
event: message
data: {"jsonrpc":"2.0","result":{...},"id":1}
We extract the last `data:` line that contains a JSON-RPC response
(i.e. has an "id" field), which is the reply to our request.
"""
last_data: dict[str, Any] | None = None
for line in text.splitlines():
stripped = line.strip()
if stripped.startswith("data:"):
payload = stripped[len("data:") :].strip()
if not payload:
continue
try:
parsed = json.loads(payload)
# Only keep JSON-RPC responses (have "id"), skip notifications
if isinstance(parsed, dict) and "id" in parsed:
last_data = parsed
except (json.JSONDecodeError, ValueError):
continue
if last_data is None:
raise MCPClientError("No JSON-RPC response found in SSE stream")
return last_data
async def _send_request(
self, method: str, params: dict[str, Any] | None = None
) -> Any:
"""Send a JSON-RPC request to the MCP server and return the result.
Handles both ``application/json`` and ``text/event-stream`` responses
as required by the MCP Streamable HTTP transport specification.
"""
payload = self._build_jsonrpc_request(method, params)
headers = self._build_headers()
requests = Requests(
raise_for_status=True,
extra_headers=headers,
)
response = await requests.post(self.server_url, json=payload)
# Capture session ID from response (MCP Streamable HTTP transport)
session_id = response.headers.get("Mcp-Session-Id")
if session_id:
self._session_id = session_id
content_type = response.headers.get("content-type", "")
if "text/event-stream" in content_type:
body = self._parse_sse_response(response.text())
else:
try:
body = response.json()
except Exception as e:
raise MCPClientError(
f"MCP server returned non-JSON response: {e}"
) from e
if not isinstance(body, dict):
raise MCPClientError(
f"MCP server returned unexpected JSON type: {type(body).__name__}"
)
# Handle JSON-RPC error
if "error" in body:
error = body["error"]
if isinstance(error, dict):
raise MCPClientError(
f"MCP server error [{error.get('code', '?')}]: "
f"{error.get('message', 'Unknown error')}"
)
raise MCPClientError(f"MCP server error: {error}")
return body.get("result")
async def _send_notification(self, method: str) -> None:
"""Send a JSON-RPC notification (no id, no response expected)."""
headers = self._build_headers()
notification = {"jsonrpc": "2.0", "method": method}
requests = Requests(
raise_for_status=False,
extra_headers=headers,
)
await requests.post(self.server_url, json=notification)
async def discover_auth(self) -> dict[str, Any] | None:
"""Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec).
Returns ``None`` if the server doesn't require auth, otherwise returns
a dict with:
- ``authorization_servers``: list of authorization server URLs
- ``resource``: the resource indicator URL (usually the MCP endpoint)
- ``scopes_supported``: optional list of supported scopes
The caller can then fetch the authorization server metadata to get
``authorization_endpoint``, ``token_endpoint``, etc.
"""
from urllib.parse import urlparse
parsed = urlparse(self.server_url)
base = f"{parsed.scheme}://{parsed.netloc}"
# Build candidates for protected-resource metadata (per RFC 9728)
path = parsed.path.rstrip("/")
candidates = []
if path and path != "/":
candidates.append(f"{base}/.well-known/oauth-protected-resource{path}")
candidates.append(f"{base}/.well-known/oauth-protected-resource")
requests = Requests(
raise_for_status=False,
)
for url in candidates:
try:
resp = await requests.get(url)
if resp.status == 200:
data = resp.json()
if isinstance(data, dict) and "authorization_servers" in data:
return data
except Exception:
continue
return None
async def discover_auth_server_metadata(
self, auth_server_url: str
) -> dict[str, Any] | None:
"""Fetch the OAuth Authorization Server Metadata (RFC 8414).
Given an authorization server URL, returns a dict with:
- ``authorization_endpoint``
- ``token_endpoint``
- ``registration_endpoint`` (for dynamic client registration)
- ``scopes_supported``
- ``code_challenge_methods_supported``
- etc.
"""
from urllib.parse import urlparse
parsed = urlparse(auth_server_url)
base = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/")
# Try standard metadata endpoints (RFC 8414 and OpenID Connect)
candidates = []
if path and path != "/":
candidates.append(f"{base}/.well-known/oauth-authorization-server{path}")
candidates.append(f"{base}/.well-known/oauth-authorization-server")
candidates.append(f"{base}/.well-known/openid-configuration")
requests = Requests(
raise_for_status=False,
)
for url in candidates:
try:
resp = await requests.get(url)
if resp.status == 200:
data = resp.json()
if isinstance(data, dict) and "authorization_endpoint" in data:
return data
except Exception:
continue
return None
async def initialize(self) -> dict[str, Any]:
"""
Send the MCP initialize request.
This is required by the MCP protocol before any other requests.
Returns the server's capabilities.
"""
result = await self._send_request(
"initialize",
{
"protocolVersion": "2025-03-26",
"capabilities": {},
"clientInfo": {"name": "AutoGPT-Platform", "version": "1.0.0"},
},
)
# Send initialized notification (no response expected)
await self._send_notification("notifications/initialized")
return result or {}
async def list_tools(self) -> list[MCPTool]:
"""
Discover available tools from the MCP server.
Returns a list of MCPTool objects with name, description, and input schema.
"""
result = await self._send_request("tools/list")
if not result or "tools" not in result:
return []
tools = []
for tool_data in result["tools"]:
tools.append(
MCPTool(
name=tool_data.get("name", ""),
description=tool_data.get("description", ""),
input_schema=tool_data.get("inputSchema", {}),
)
)
return tools
async def call_tool(
self, tool_name: str, arguments: dict[str, Any]
) -> MCPCallResult:
"""
Call a tool on the MCP server.
Args:
tool_name: The name of the tool to call.
arguments: The arguments to pass to the tool.
Returns:
MCPCallResult with the tool's response content.
"""
result = await self._send_request(
"tools/call",
{"name": tool_name, "arguments": arguments},
)
if not result:
return MCPCallResult(is_error=True)
return MCPCallResult(
content=result.get("content", []),
is_error=result.get("isError", False),
)

View File

@@ -0,0 +1,204 @@
"""
MCP OAuth handler for MCP servers that use OAuth 2.1 authorization.
Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed,
MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata.
This handler accepts those endpoints at construction time.
"""
import logging
import time
import urllib.parse
from typing import ClassVar, Optional
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.providers import ProviderName
from backend.util.request import Requests
logger = logging.getLogger(__name__)
class MCPOAuthHandler(BaseOAuthHandler):
"""
OAuth handler for MCP servers with dynamically-discovered endpoints.
Construction requires the authorization and token endpoint URLs,
which are obtained via MCP OAuth metadata discovery
(``MCPClient.discover_auth`` + ``discover_auth_server_metadata``).
"""
PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP
DEFAULT_SCOPES: ClassVar[list[str]] = []
def __init__(
self,
client_id: str,
client_secret: str,
redirect_uri: str,
*,
authorize_url: str,
token_url: str,
revoke_url: str | None = None,
resource_url: str | None = None,
):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
self.authorize_url = authorize_url
self.token_url = token_url
self.revoke_url = revoke_url
self.resource_url = resource_url
def get_login_url(
self,
scopes: list[str],
state: str,
code_challenge: Optional[str],
) -> str:
scopes = self.handle_default_scopes(scopes)
params: dict[str, str] = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"state": state,
}
if scopes:
params["scope"] = " ".join(scopes)
# PKCE (S256) — included when the caller provides a code_challenge
if code_challenge:
params["code_challenge"] = code_challenge
params["code_challenge_method"] = "S256"
# MCP spec requires resource indicator (RFC 8707)
if self.resource_url:
params["resource"] = self.resource_url
return f"{self.authorize_url}?{urllib.parse.urlencode(params)}"
async def exchange_code_for_tokens(
self,
code: str,
scopes: list[str],
code_verifier: Optional[str],
) -> OAuth2Credentials:
data: dict[str, str] = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": self.redirect_uri,
"client_id": self.client_id,
}
if self.client_secret:
data["client_secret"] = self.client_secret
if code_verifier:
data["code_verifier"] = code_verifier
if self.resource_url:
data["resource"] = self.resource_url
response = await Requests(raise_for_status=True).post(
self.token_url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
tokens = response.json()
if "error" in tokens:
raise RuntimeError(
f"Token exchange failed: {tokens.get('error_description', tokens['error'])}"
)
if "access_token" not in tokens:
raise RuntimeError("OAuth token response missing 'access_token' field")
now = int(time.time())
expires_in = tokens.get("expires_in")
return OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=None,
access_token=SecretStr(tokens["access_token"]),
refresh_token=(
SecretStr(tokens["refresh_token"])
if tokens.get("refresh_token")
else None
),
access_token_expires_at=now + expires_in if expires_in else None,
refresh_token_expires_at=None,
scopes=scopes,
metadata={
"mcp_token_url": self.token_url,
"mcp_resource_url": self.resource_url,
},
)
async def _refresh_tokens(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
if not credentials.refresh_token:
raise ValueError("No refresh token available for MCP OAuth credentials")
data: dict[str, str] = {
"grant_type": "refresh_token",
"refresh_token": credentials.refresh_token.get_secret_value(),
"client_id": self.client_id,
}
if self.client_secret:
data["client_secret"] = self.client_secret
if self.resource_url:
data["resource"] = self.resource_url
response = await Requests(raise_for_status=True).post(
self.token_url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
tokens = response.json()
if "error" in tokens:
raise RuntimeError(
f"Token refresh failed: {tokens.get('error_description', tokens['error'])}"
)
if "access_token" not in tokens:
raise RuntimeError("OAuth refresh response missing 'access_token' field")
now = int(time.time())
expires_in = tokens.get("expires_in")
return OAuth2Credentials(
id=credentials.id,
provider=self.PROVIDER_NAME,
title=credentials.title,
access_token=SecretStr(tokens["access_token"]),
refresh_token=(
SecretStr(tokens["refresh_token"])
if tokens.get("refresh_token")
else credentials.refresh_token
),
access_token_expires_at=now + expires_in if expires_in else None,
refresh_token_expires_at=credentials.refresh_token_expires_at,
scopes=credentials.scopes,
metadata=credentials.metadata,
)
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
if not self.revoke_url:
return False
try:
data = {
"token": credentials.access_token.get_secret_value(),
"token_type_hint": "access_token",
"client_id": self.client_id,
}
await Requests().post(
self.revoke_url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
return True
except Exception:
logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True)
return False

View File

@@ -0,0 +1,109 @@
"""
End-to-end tests against a real public MCP server.
These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp)
which is publicly accessible without authentication and returns SSE responses.
Mark: These are tagged with ``@pytest.mark.e2e`` so they can be run/skipped
independently of the rest of the test suite (they require network access).
"""
import json
import os
import pytest
from backend.blocks.mcp.client import MCPClient
# Public MCP server that requires no authentication
OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp"
# Skip all tests in this module unless RUN_E2E env var is set
pytestmark = pytest.mark.skipif(
not os.environ.get("RUN_E2E"), reason="set RUN_E2E=1 to run e2e tests"
)
class TestRealMCPServer:
"""Tests against the live OpenAI docs MCP server."""
@pytest.mark.asyncio(loop_scope="session")
async def test_initialize(self):
"""Verify we can complete the MCP handshake with a real server."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
result = await client.initialize()
assert result["protocolVersion"] == "2025-03-26"
assert "serverInfo" in result
assert result["serverInfo"]["name"] == "openai-docs-mcp"
assert "tools" in result.get("capabilities", {})
@pytest.mark.asyncio(loop_scope="session")
async def test_list_tools(self):
"""Verify we can discover tools from a real MCP server."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
await client.initialize()
tools = await client.list_tools()
assert len(tools) >= 3 # server has at least 5 tools as of writing
tool_names = {t.name for t in tools}
# These tools are documented and should be stable
assert "search_openai_docs" in tool_names
assert "list_openai_docs" in tool_names
assert "fetch_openai_doc" in tool_names
# Verify schema structure
search_tool = next(t for t in tools if t.name == "search_openai_docs")
assert "query" in search_tool.input_schema.get("properties", {})
assert "query" in search_tool.input_schema.get("required", [])
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_list_api_endpoints(self):
"""Call the list_api_endpoints tool and verify we get real data."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
await client.initialize()
result = await client.call_tool("list_api_endpoints", {})
assert not result.is_error
assert len(result.content) >= 1
assert result.content[0]["type"] == "text"
data = json.loads(result.content[0]["text"])
assert "paths" in data or "urls" in data
# The OpenAI API should have many endpoints
total = data.get("total", len(data.get("paths", [])))
assert total > 50
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_search(self):
"""Search for docs and verify we get results."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
await client.initialize()
result = await client.call_tool(
"search_openai_docs", {"query": "chat completions", "limit": 3}
)
assert not result.is_error
assert len(result.content) >= 1
@pytest.mark.asyncio(loop_scope="session")
async def test_sse_response_handling(self):
"""Verify the client correctly handles SSE responses from a real server.
This is the key test — our local test server returns JSON,
but real MCP servers typically return SSE. This proves the
SSE parsing works end-to-end.
"""
client = MCPClient(OPENAI_DOCS_MCP_URL)
# initialize() internally calls _send_request which must parse SSE
result = await client.initialize()
# If we got here without error, SSE parsing works
assert isinstance(result, dict)
assert "protocolVersion" in result
# Also verify list_tools works (another SSE response)
tools = await client.list_tools()
assert len(tools) > 0
assert all(hasattr(t, "name") for t in tools)

View File

@@ -0,0 +1,389 @@
"""
Integration tests for MCP client and MCPToolBlock against a real HTTP server.
These tests spin up a local MCP test server and run the full client/block flow
against it — no mocking, real HTTP requests.
"""
import asyncio
import json
import threading
from unittest.mock import patch
import pytest
from aiohttp import web
from pydantic import SecretStr
from backend.blocks.mcp.block import MCPToolBlock
from backend.blocks.mcp.client import MCPClient
from backend.blocks.mcp.test_server import create_test_mcp_app
from backend.data.model import OAuth2Credentials
MOCK_USER_ID = "test-user-integration"
class _MCPTestServer:
"""
Run an MCP test server in a background thread with its own event loop.
This avoids event loop conflicts with pytest-asyncio.
"""
def __init__(self, auth_token: str | None = None):
self.auth_token = auth_token
self.url: str = ""
self._runner: web.AppRunner | None = None
self._loop: asyncio.AbstractEventLoop | None = None
self._thread: threading.Thread | None = None
self._started = threading.Event()
def _run(self):
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._loop.run_until_complete(self._start())
self._started.set()
self._loop.run_forever()
async def _start(self):
app = create_test_mcp_app(auth_token=self.auth_token)
self._runner = web.AppRunner(app)
await self._runner.setup()
site = web.TCPSite(self._runner, "127.0.0.1", 0)
await site.start()
port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr]
self.url = f"http://127.0.0.1:{port}/mcp"
def start(self):
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
if not self._started.wait(timeout=5):
raise RuntimeError("MCP test server failed to start within 5 seconds")
return self
def stop(self):
if self._loop and self._runner:
asyncio.run_coroutine_threadsafe(self._runner.cleanup(), self._loop).result(
timeout=5
)
self._loop.call_soon_threadsafe(self._loop.stop)
if self._thread:
self._thread.join(timeout=5)
@pytest.fixture(scope="module")
def mcp_server():
"""Start a local MCP test server in a background thread."""
server = _MCPTestServer()
server.start()
yield server.url
server.stop()
@pytest.fixture(scope="module")
def mcp_server_with_auth():
"""Start a local MCP test server with auth in a background thread."""
server = _MCPTestServer(auth_token="test-secret-token")
server.start()
yield server.url, "test-secret-token"
server.stop()
@pytest.fixture(autouse=True)
def _allow_localhost():
"""
Allow 127.0.0.1 through SSRF protection for integration tests.
The Requests class blocks private IPs by default. We patch the Requests
constructor to always include 127.0.0.1 as a trusted origin so the local
test server is reachable.
"""
from backend.util.request import Requests
original_init = Requests.__init__
def patched_init(self, *args, **kwargs):
trusted = list(kwargs.get("trusted_origins") or [])
trusted.append("http://127.0.0.1")
kwargs["trusted_origins"] = trusted
original_init(self, *args, **kwargs)
with patch.object(Requests, "__init__", patched_init):
yield
def _make_client(url: str, auth_token: str | None = None) -> MCPClient:
"""Create an MCPClient for integration tests."""
return MCPClient(url, auth_token=auth_token)
# ── MCPClient integration tests ──────────────────────────────────────
class TestMCPClientIntegration:
"""Test MCPClient against a real local MCP server."""
@pytest.mark.asyncio(loop_scope="session")
async def test_initialize(self, mcp_server):
client = _make_client(mcp_server)
result = await client.initialize()
assert result["protocolVersion"] == "2025-03-26"
assert result["serverInfo"]["name"] == "test-mcp-server"
assert "tools" in result["capabilities"]
@pytest.mark.asyncio(loop_scope="session")
async def test_list_tools(self, mcp_server):
client = _make_client(mcp_server)
await client.initialize()
tools = await client.list_tools()
assert len(tools) == 3
tool_names = {t.name for t in tools}
assert tool_names == {"get_weather", "add_numbers", "echo"}
# Check get_weather schema
weather = next(t for t in tools if t.name == "get_weather")
assert weather.description == "Get current weather for a city"
assert "city" in weather.input_schema["properties"]
assert weather.input_schema["required"] == ["city"]
# Check add_numbers schema
add = next(t for t in tools if t.name == "add_numbers")
assert "a" in add.input_schema["properties"]
assert "b" in add.input_schema["properties"]
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_get_weather(self, mcp_server):
client = _make_client(mcp_server)
await client.initialize()
result = await client.call_tool("get_weather", {"city": "London"})
assert not result.is_error
assert len(result.content) == 1
assert result.content[0]["type"] == "text"
data = json.loads(result.content[0]["text"])
assert data["city"] == "London"
assert data["temperature"] == 22
assert data["condition"] == "sunny"
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_add_numbers(self, mcp_server):
client = _make_client(mcp_server)
await client.initialize()
result = await client.call_tool("add_numbers", {"a": 3, "b": 7})
assert not result.is_error
data = json.loads(result.content[0]["text"])
assert data["result"] == 10
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_echo(self, mcp_server):
client = _make_client(mcp_server)
await client.initialize()
result = await client.call_tool("echo", {"message": "Hello MCP!"})
assert not result.is_error
assert result.content[0]["text"] == "Hello MCP!"
@pytest.mark.asyncio(loop_scope="session")
async def test_call_unknown_tool(self, mcp_server):
client = _make_client(mcp_server)
await client.initialize()
result = await client.call_tool("nonexistent_tool", {})
assert result.is_error
assert "Unknown tool" in result.content[0]["text"]
@pytest.mark.asyncio(loop_scope="session")
async def test_auth_success(self, mcp_server_with_auth):
url, token = mcp_server_with_auth
client = _make_client(url, auth_token=token)
result = await client.initialize()
assert result["protocolVersion"] == "2025-03-26"
tools = await client.list_tools()
assert len(tools) == 3
@pytest.mark.asyncio(loop_scope="session")
async def test_auth_failure(self, mcp_server_with_auth):
url, _ = mcp_server_with_auth
client = _make_client(url, auth_token="wrong-token")
with pytest.raises(Exception):
await client.initialize()
@pytest.mark.asyncio(loop_scope="session")
async def test_auth_missing(self, mcp_server_with_auth):
url, _ = mcp_server_with_auth
client = _make_client(url)
with pytest.raises(Exception):
await client.initialize()
# ── MCPToolBlock integration tests ───────────────────────────────────
class TestMCPToolBlockIntegration:
"""Test MCPToolBlock end-to-end against a real local MCP server."""
@pytest.mark.asyncio(loop_scope="session")
async def test_full_flow_get_weather(self, mcp_server):
"""Full flow: discover tools, select one, execute it."""
# Step 1: Discover tools (simulating what the frontend/API would do)
client = _make_client(mcp_server)
await client.initialize()
tools = await client.list_tools()
assert len(tools) == 3
# Step 2: User selects "get_weather" and we get its schema
weather_tool = next(t for t in tools if t.name == "get_weather")
# Step 3: Execute the block — no credentials (public server)
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=mcp_server,
selected_tool="get_weather",
tool_input_schema=weather_tool.input_schema,
tool_arguments={"city": "Paris"},
)
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
result = outputs[0][1]
assert result["city"] == "Paris"
assert result["temperature"] == 22
assert result["condition"] == "sunny"
@pytest.mark.asyncio(loop_scope="session")
async def test_full_flow_add_numbers(self, mcp_server):
"""Full flow for add_numbers tool."""
client = _make_client(mcp_server)
await client.initialize()
tools = await client.list_tools()
add_tool = next(t for t in tools if t.name == "add_numbers")
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=mcp_server,
selected_tool="add_numbers",
tool_input_schema=add_tool.input_schema,
tool_arguments={"a": 42, "b": 58},
)
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
assert outputs[0][1]["result"] == 100
@pytest.mark.asyncio(loop_scope="session")
async def test_full_flow_echo_plain_text(self, mcp_server):
"""Verify plain text (non-JSON) responses work."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=mcp_server,
selected_tool="echo",
tool_input_schema={
"type": "object",
"properties": {"message": {"type": "string"}},
"required": ["message"],
},
tool_arguments={"message": "Hello from AutoGPT!"},
)
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
assert outputs[0][1] == "Hello from AutoGPT!"
@pytest.mark.asyncio(loop_scope="session")
async def test_full_flow_unknown_tool_yields_error(self, mcp_server):
"""Calling an unknown tool should yield an error output."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=mcp_server,
selected_tool="nonexistent_tool",
tool_arguments={},
)
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "error"
assert "returned an error" in outputs[0][1]
@pytest.mark.asyncio(loop_scope="session")
async def test_full_flow_with_auth(self, mcp_server_with_auth):
"""Full flow with authentication via credentials kwarg."""
url, token = mcp_server_with_auth
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=url,
selected_tool="echo",
tool_input_schema={
"type": "object",
"properties": {"message": {"type": "string"}},
"required": ["message"],
},
tool_arguments={"message": "Authenticated!"},
)
# Pass credentials via the standard kwarg (as the executor would)
test_creds = OAuth2Credentials(
id="test-cred",
provider="mcp",
access_token=SecretStr(token),
refresh_token=SecretStr(""),
scopes=[],
title="Test MCP credential",
)
outputs = []
async for name, data in block.run(
input_data, user_id=MOCK_USER_ID, credentials=test_creds
):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
assert outputs[0][1] == "Authenticated!"
@pytest.mark.asyncio(loop_scope="session")
async def test_no_credentials_runs_without_auth(self, mcp_server):
"""Block runs without auth when no credentials are provided."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=mcp_server,
selected_tool="echo",
tool_input_schema={
"type": "object",
"properties": {"message": {"type": "string"}},
"required": ["message"],
},
tool_arguments={"message": "No auth needed"},
)
outputs = []
async for name, data in block.run(
input_data, user_id=MOCK_USER_ID, credentials=None
):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
assert outputs[0][1] == "No auth needed"

View File

@@ -0,0 +1,619 @@
"""
Tests for MCP client and MCPToolBlock.
"""
import json
from unittest.mock import AsyncMock, patch
import pytest
from backend.blocks.mcp.block import MCPToolBlock
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError
from backend.util.test import execute_block_test
# ── SSE parsing unit tests ───────────────────────────────────────────
class TestSSEParsing:
"""Tests for SSE (text/event-stream) response parsing."""
def test_parse_sse_simple(self):
sse = (
"event: message\n"
'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n'
"\n"
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == {"tools": []}
assert body["id"] == 1
def test_parse_sse_with_notifications(self):
"""SSE streams can contain notifications (no id) before the response."""
sse = (
"event: message\n"
'data: {"jsonrpc":"2.0","method":"some/notification"}\n'
"\n"
"event: message\n"
'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n'
"\n"
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == {"ok": True}
assert body["id"] == 2
def test_parse_sse_error_response(self):
sse = (
"event: message\n"
'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n'
)
body = MCPClient._parse_sse_response(sse)
assert "error" in body
assert body["error"]["code"] == -32600
def test_parse_sse_no_data_raises(self):
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
MCPClient._parse_sse_response("event: message\n\n")
def test_parse_sse_empty_raises(self):
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
MCPClient._parse_sse_response("")
def test_parse_sse_ignores_non_data_lines(self):
sse = (
": comment line\n"
"event: message\n"
"id: 123\n"
'data: {"jsonrpc":"2.0","result":"ok","id":1}\n'
"\n"
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == "ok"
def test_parse_sse_uses_last_response(self):
"""If multiple responses exist, use the last one."""
sse = (
'data: {"jsonrpc":"2.0","result":"first","id":1}\n'
"\n"
'data: {"jsonrpc":"2.0","result":"second","id":2}\n'
"\n"
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == "second"
# ── MCPClient unit tests ─────────────────────────────────────────────
class TestMCPClient:
"""Tests for the MCP HTTP client."""
def test_build_headers_without_auth(self):
client = MCPClient("https://mcp.example.com")
headers = client._build_headers()
assert "Authorization" not in headers
assert headers["Content-Type"] == "application/json"
def test_build_headers_with_auth(self):
client = MCPClient("https://mcp.example.com", auth_token="my-token")
headers = client._build_headers()
assert headers["Authorization"] == "Bearer my-token"
def test_build_jsonrpc_request(self):
client = MCPClient("https://mcp.example.com")
req = client._build_jsonrpc_request("tools/list")
assert req["jsonrpc"] == "2.0"
assert req["method"] == "tools/list"
assert "id" in req
assert "params" not in req
def test_build_jsonrpc_request_with_params(self):
client = MCPClient("https://mcp.example.com")
req = client._build_jsonrpc_request(
"tools/call", {"name": "test", "arguments": {"x": 1}}
)
assert req["params"] == {"name": "test", "arguments": {"x": 1}}
def test_request_id_increments(self):
client = MCPClient("https://mcp.example.com")
req1 = client._build_jsonrpc_request("tools/list")
req2 = client._build_jsonrpc_request("tools/list")
assert req2["id"] > req1["id"]
def test_server_url_trailing_slash_stripped(self):
client = MCPClient("https://mcp.example.com/mcp/")
assert client.server_url == "https://mcp.example.com/mcp"
@pytest.mark.asyncio(loop_scope="session")
async def test_send_request_success(self):
client = MCPClient("https://mcp.example.com")
mock_response = AsyncMock()
mock_response.json.return_value = {
"jsonrpc": "2.0",
"result": {"tools": []},
"id": 1,
}
with patch.object(client, "_send_request", return_value={"tools": []}):
result = await client._send_request("tools/list")
assert result == {"tools": []}
@pytest.mark.asyncio(loop_scope="session")
async def test_send_request_error(self):
client = MCPClient("https://mcp.example.com")
async def mock_send(*args, **kwargs):
raise MCPClientError("MCP server error [-32600]: Invalid Request")
with patch.object(client, "_send_request", side_effect=mock_send):
with pytest.raises(MCPClientError, match="Invalid Request"):
await client._send_request("tools/list")
@pytest.mark.asyncio(loop_scope="session")
async def test_list_tools(self):
client = MCPClient("https://mcp.example.com")
mock_result = {
"tools": [
{
"name": "get_weather",
"description": "Get current weather for a city",
"inputSchema": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
},
{
"name": "search",
"description": "Search the web",
"inputSchema": {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"],
},
},
]
}
with patch.object(client, "_send_request", return_value=mock_result):
tools = await client.list_tools()
assert len(tools) == 2
assert tools[0].name == "get_weather"
assert tools[0].description == "Get current weather for a city"
assert tools[0].input_schema["properties"]["city"]["type"] == "string"
assert tools[1].name == "search"
@pytest.mark.asyncio(loop_scope="session")
async def test_list_tools_empty(self):
client = MCPClient("https://mcp.example.com")
with patch.object(client, "_send_request", return_value={"tools": []}):
tools = await client.list_tools()
assert tools == []
@pytest.mark.asyncio(loop_scope="session")
async def test_list_tools_none_result(self):
client = MCPClient("https://mcp.example.com")
with patch.object(client, "_send_request", return_value=None):
tools = await client.list_tools()
assert tools == []
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_success(self):
client = MCPClient("https://mcp.example.com")
mock_result = {
"content": [
{"type": "text", "text": json.dumps({"temp": 20, "city": "London"})}
],
"isError": False,
}
with patch.object(client, "_send_request", return_value=mock_result):
result = await client.call_tool("get_weather", {"city": "London"})
assert not result.is_error
assert len(result.content) == 1
assert result.content[0]["type"] == "text"
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_error(self):
client = MCPClient("https://mcp.example.com")
mock_result = {
"content": [{"type": "text", "text": "City not found"}],
"isError": True,
}
with patch.object(client, "_send_request", return_value=mock_result):
result = await client.call_tool("get_weather", {"city": "???"})
assert result.is_error
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_none_result(self):
client = MCPClient("https://mcp.example.com")
with patch.object(client, "_send_request", return_value=None):
result = await client.call_tool("get_weather", {"city": "London"})
assert result.is_error
@pytest.mark.asyncio(loop_scope="session")
async def test_initialize(self):
client = MCPClient("https://mcp.example.com")
mock_result = {
"protocolVersion": "2025-03-26",
"capabilities": {"tools": {}},
"serverInfo": {"name": "test-server", "version": "1.0.0"},
}
with (
patch.object(client, "_send_request", return_value=mock_result) as mock_req,
patch.object(client, "_send_notification") as mock_notif,
):
result = await client.initialize()
mock_req.assert_called_once()
mock_notif.assert_called_once_with("notifications/initialized")
assert result["protocolVersion"] == "2025-03-26"
# ── MCPToolBlock unit tests ──────────────────────────────────────────
MOCK_USER_ID = "test-user-123"
class TestMCPToolBlock:
"""Tests for the MCPToolBlock."""
def test_block_instantiation(self):
block = MCPToolBlock()
assert block.id == "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
assert block.name == "MCPToolBlock"
def test_input_schema_has_required_fields(self):
block = MCPToolBlock()
schema = block.input_schema.jsonschema()
props = schema.get("properties", {})
assert "server_url" in props
assert "selected_tool" in props
assert "tool_arguments" in props
assert "credentials" in props
def test_output_schema(self):
block = MCPToolBlock()
schema = block.output_schema.jsonschema()
props = schema.get("properties", {})
assert "result" in props
assert "error" in props
def test_get_input_schema_with_tool_schema(self):
tool_schema = {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"],
}
data = {"tool_input_schema": tool_schema}
result = MCPToolBlock.Input.get_input_schema(data)
assert result == tool_schema
def test_get_input_schema_without_tool_schema(self):
result = MCPToolBlock.Input.get_input_schema({})
assert result == {}
def test_get_input_defaults(self):
data = {"tool_arguments": {"city": "London"}}
result = MCPToolBlock.Input.get_input_defaults(data)
assert result == {"city": "London"}
def test_get_missing_input(self):
data = {
"tool_input_schema": {
"type": "object",
"properties": {
"city": {"type": "string"},
"units": {"type": "string"},
},
"required": ["city", "units"],
},
"tool_arguments": {"city": "London"},
}
missing = MCPToolBlock.Input.get_missing_input(data)
assert missing == {"units"}
def test_get_missing_input_all_present(self):
data = {
"tool_input_schema": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
"tool_arguments": {"city": "London"},
}
missing = MCPToolBlock.Input.get_missing_input(data)
assert missing == set()
@pytest.mark.asyncio(loop_scope="session")
async def test_run_with_mock(self):
"""Test the block using the built-in test infrastructure."""
block = MCPToolBlock()
await execute_block_test(block)
@pytest.mark.asyncio(loop_scope="session")
async def test_run_missing_server_url(self):
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="",
selected_tool="test",
)
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert outputs == [("error", "MCP server URL is required")]
@pytest.mark.asyncio(loop_scope="session")
async def test_run_missing_tool(self):
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="",
)
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert outputs == [
("error", "No tool selected. Please select a tool from the dropdown.")
]
@pytest.mark.asyncio(loop_scope="session")
async def test_run_success(self):
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="get_weather",
tool_input_schema={
"type": "object",
"properties": {"city": {"type": "string"}},
},
tool_arguments={"city": "London"},
)
async def mock_call(*args, **kwargs):
return {"temp": 20, "city": "London"}
block._call_mcp_tool = mock_call # type: ignore
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
assert outputs[0][1] == {"temp": 20, "city": "London"}
@pytest.mark.asyncio(loop_scope="session")
async def test_run_mcp_error(self):
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="bad_tool",
)
async def mock_call(*args, **kwargs):
raise MCPClientError("Tool not found")
block._call_mcp_tool = mock_call # type: ignore
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert outputs[0][0] == "error"
assert "Tool not found" in outputs[0][1]
@pytest.mark.asyncio(loop_scope="session")
async def test_call_mcp_tool_parses_json_text(self):
block = MCPToolBlock()
mock_result = MCPCallResult(
content=[
{"type": "text", "text": '{"temp": 20}'},
],
is_error=False,
)
async def mock_init(self):
return {}
async def mock_call(self, name, args):
return mock_result
with (
patch.object(MCPClient, "initialize", mock_init),
patch.object(MCPClient, "call_tool", mock_call),
):
result = await block._call_mcp_tool(
"https://mcp.example.com", "test_tool", {}
)
assert result == {"temp": 20}
@pytest.mark.asyncio(loop_scope="session")
async def test_call_mcp_tool_plain_text(self):
block = MCPToolBlock()
mock_result = MCPCallResult(
content=[
{"type": "text", "text": "Hello, world!"},
],
is_error=False,
)
async def mock_init(self):
return {}
async def mock_call(self, name, args):
return mock_result
with (
patch.object(MCPClient, "initialize", mock_init),
patch.object(MCPClient, "call_tool", mock_call),
):
result = await block._call_mcp_tool(
"https://mcp.example.com", "test_tool", {}
)
assert result == "Hello, world!"
@pytest.mark.asyncio(loop_scope="session")
async def test_call_mcp_tool_multiple_content(self):
block = MCPToolBlock()
mock_result = MCPCallResult(
content=[
{"type": "text", "text": "Part 1"},
{"type": "text", "text": '{"part": 2}'},
],
is_error=False,
)
async def mock_init(self):
return {}
async def mock_call(self, name, args):
return mock_result
with (
patch.object(MCPClient, "initialize", mock_init),
patch.object(MCPClient, "call_tool", mock_call),
):
result = await block._call_mcp_tool(
"https://mcp.example.com", "test_tool", {}
)
assert result == ["Part 1", {"part": 2}]
@pytest.mark.asyncio(loop_scope="session")
async def test_call_mcp_tool_error_result(self):
block = MCPToolBlock()
mock_result = MCPCallResult(
content=[{"type": "text", "text": "Something went wrong"}],
is_error=True,
)
async def mock_init(self):
return {}
async def mock_call(self, name, args):
return mock_result
with (
patch.object(MCPClient, "initialize", mock_init),
patch.object(MCPClient, "call_tool", mock_call),
):
with pytest.raises(MCPClientError, match="returned an error"):
await block._call_mcp_tool("https://mcp.example.com", "test_tool", {})
@pytest.mark.asyncio(loop_scope="session")
async def test_call_mcp_tool_image_content(self):
block = MCPToolBlock()
mock_result = MCPCallResult(
content=[
{
"type": "image",
"data": "base64data==",
"mimeType": "image/png",
}
],
is_error=False,
)
async def mock_init(self):
return {}
async def mock_call(self, name, args):
return mock_result
with (
patch.object(MCPClient, "initialize", mock_init),
patch.object(MCPClient, "call_tool", mock_call),
):
result = await block._call_mcp_tool(
"https://mcp.example.com", "test_tool", {}
)
assert result == {
"type": "image",
"data": "base64data==",
"mimeType": "image/png",
}
@pytest.mark.asyncio(loop_scope="session")
async def test_run_with_credentials(self):
"""Verify the block uses OAuth2Credentials and passes auth token."""
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="test_tool",
)
captured_tokens: list[str | None] = []
async def mock_call(server_url, tool_name, arguments, auth_token=None):
captured_tokens.append(auth_token)
return "ok"
block._call_mcp_tool = mock_call # type: ignore
test_creds = OAuth2Credentials(
id="cred-123",
provider="mcp",
access_token=SecretStr("resolved-token"),
refresh_token=SecretStr(""),
scopes=[],
title="Test MCP credential",
)
async for _ in block.run(
input_data, user_id=MOCK_USER_ID, credentials=test_creds
):
pass
assert captured_tokens == ["resolved-token"]
@pytest.mark.asyncio(loop_scope="session")
async def test_run_without_credentials(self):
"""Verify the block works without credentials (public server)."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="test_tool",
)
captured_tokens: list[str | None] = []
async def mock_call(server_url, tool_name, arguments, auth_token=None):
captured_tokens.append(auth_token)
return "ok"
block._call_mcp_tool = mock_call # type: ignore
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert captured_tokens == [None]
assert outputs == [("result", "ok")]

View File

@@ -0,0 +1,242 @@
"""
Tests for MCP OAuth handler.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
from backend.blocks.mcp.client import MCPClient
from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
def _mock_response(json_data: dict, status: int = 200) -> MagicMock:
"""Create a mock Response with synchronous json() (matching Requests.Response)."""
resp = MagicMock()
resp.status = status
resp.ok = 200 <= status < 300
resp.json.return_value = json_data
return resp
class TestMCPOAuthHandler:
"""Tests for the MCPOAuthHandler."""
def _make_handler(self, **overrides) -> MCPOAuthHandler:
defaults = {
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"redirect_uri": "https://app.example.com/callback",
"authorize_url": "https://auth.example.com/authorize",
"token_url": "https://auth.example.com/token",
}
defaults.update(overrides)
return MCPOAuthHandler(**defaults)
def test_get_login_url_basic(self):
handler = self._make_handler()
url = handler.get_login_url(
scopes=["read", "write"],
state="random-state-token",
code_challenge="S256-challenge-value",
)
assert "https://auth.example.com/authorize?" in url
assert "response_type=code" in url
assert "client_id=test-client-id" in url
assert "state=random-state-token" in url
assert "code_challenge=S256-challenge-value" in url
assert "code_challenge_method=S256" in url
assert "scope=read+write" in url
def test_get_login_url_with_resource(self):
handler = self._make_handler(resource_url="https://mcp.example.com/mcp")
url = handler.get_login_url(
scopes=[], state="state", code_challenge="challenge"
)
assert "resource=https" in url
def test_get_login_url_without_pkce(self):
handler = self._make_handler()
url = handler.get_login_url(scopes=["read"], state="state", code_challenge=None)
assert "code_challenge" not in url
assert "code_challenge_method" not in url
@pytest.mark.asyncio(loop_scope="session")
async def test_exchange_code_for_tokens(self):
handler = self._make_handler()
resp = _mock_response(
{
"access_token": "new-access-token",
"refresh_token": "new-refresh-token",
"expires_in": 3600,
"token_type": "Bearer",
}
)
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
instance = MockRequests.return_value
instance.post = AsyncMock(return_value=resp)
creds = await handler.exchange_code_for_tokens(
code="auth-code",
scopes=["read"],
code_verifier="pkce-verifier",
)
assert isinstance(creds, OAuth2Credentials)
assert creds.access_token.get_secret_value() == "new-access-token"
assert creds.refresh_token is not None
assert creds.refresh_token.get_secret_value() == "new-refresh-token"
assert creds.scopes == ["read"]
assert creds.access_token_expires_at is not None
@pytest.mark.asyncio(loop_scope="session")
async def test_refresh_tokens(self):
handler = self._make_handler()
existing_creds = OAuth2Credentials(
id="existing-id",
provider="mcp",
access_token=SecretStr("old-token"),
refresh_token=SecretStr("old-refresh"),
scopes=["read"],
title="test",
)
resp = _mock_response(
{
"access_token": "refreshed-token",
"refresh_token": "new-refresh",
"expires_in": 3600,
}
)
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
instance = MockRequests.return_value
instance.post = AsyncMock(return_value=resp)
refreshed = await handler._refresh_tokens(existing_creds)
assert refreshed.id == "existing-id"
assert refreshed.access_token.get_secret_value() == "refreshed-token"
assert refreshed.refresh_token is not None
assert refreshed.refresh_token.get_secret_value() == "new-refresh"
@pytest.mark.asyncio(loop_scope="session")
async def test_refresh_tokens_no_refresh_token(self):
handler = self._make_handler()
creds = OAuth2Credentials(
provider="mcp",
access_token=SecretStr("token"),
scopes=["read"],
title="test",
)
with pytest.raises(ValueError, match="No refresh token"):
await handler._refresh_tokens(creds)
@pytest.mark.asyncio(loop_scope="session")
async def test_revoke_tokens_no_url(self):
handler = self._make_handler(revoke_url=None)
creds = OAuth2Credentials(
provider="mcp",
access_token=SecretStr("token"),
scopes=[],
title="test",
)
result = await handler.revoke_tokens(creds)
assert result is False
@pytest.mark.asyncio(loop_scope="session")
async def test_revoke_tokens_with_url(self):
handler = self._make_handler(revoke_url="https://auth.example.com/revoke")
creds = OAuth2Credentials(
provider="mcp",
access_token=SecretStr("token"),
scopes=[],
title="test",
)
resp = _mock_response({}, status=200)
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
instance = MockRequests.return_value
instance.post = AsyncMock(return_value=resp)
result = await handler.revoke_tokens(creds)
assert result is True
class TestMCPClientDiscovery:
"""Tests for MCPClient OAuth metadata discovery."""
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_auth_found(self):
client = MCPClient("https://mcp.example.com/mcp")
metadata = {
"authorization_servers": ["https://auth.example.com"],
"resource": "https://mcp.example.com/mcp",
}
resp = _mock_response(metadata, status=200)
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
instance = MockRequests.return_value
instance.get = AsyncMock(return_value=resp)
result = await client.discover_auth()
assert result is not None
assert result["authorization_servers"] == ["https://auth.example.com"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_auth_not_found(self):
client = MCPClient("https://mcp.example.com/mcp")
resp = _mock_response({}, status=404)
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
instance = MockRequests.return_value
instance.get = AsyncMock(return_value=resp)
result = await client.discover_auth()
assert result is None
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_auth_server_metadata(self):
client = MCPClient("https://mcp.example.com/mcp")
server_metadata = {
"issuer": "https://auth.example.com",
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
"registration_endpoint": "https://auth.example.com/register",
"code_challenge_methods_supported": ["S256"],
}
resp = _mock_response(server_metadata, status=200)
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
instance = MockRequests.return_value
instance.get = AsyncMock(return_value=resp)
result = await client.discover_auth_server_metadata(
"https://auth.example.com"
)
assert result is not None
assert result["authorization_endpoint"] == "https://auth.example.com/authorize"
assert result["token_endpoint"] == "https://auth.example.com/token"

View File

@@ -0,0 +1,162 @@
"""
Minimal MCP server for integration testing.
Implements the MCP Streamable HTTP transport (JSON-RPC 2.0 over HTTP POST)
with a few sample tools. Runs on localhost with a random available port.
"""
import json
import logging
from aiohttp import web
logger = logging.getLogger(__name__)
# Sample tools this test server exposes
TEST_TOOLS = [
{
"name": "get_weather",
"description": "Get current weather for a city",
"inputSchema": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "City name",
},
},
"required": ["city"],
},
},
{
"name": "add_numbers",
"description": "Add two numbers together",
"inputSchema": {
"type": "object",
"properties": {
"a": {"type": "number", "description": "First number"},
"b": {"type": "number", "description": "Second number"},
},
"required": ["a", "b"],
},
},
{
"name": "echo",
"description": "Echo back the input message",
"inputSchema": {
"type": "object",
"properties": {
"message": {"type": "string", "description": "Message to echo"},
},
"required": ["message"],
},
},
]
def _handle_initialize(params: dict) -> dict:
return {
"protocolVersion": "2025-03-26",
"capabilities": {"tools": {"listChanged": False}},
"serverInfo": {"name": "test-mcp-server", "version": "1.0.0"},
}
def _handle_tools_list(params: dict) -> dict:
return {"tools": TEST_TOOLS}
def _handle_tools_call(params: dict) -> dict:
tool_name = params.get("name", "")
arguments = params.get("arguments", {})
if tool_name == "get_weather":
city = arguments.get("city", "Unknown")
return {
"content": [
{
"type": "text",
"text": json.dumps(
{"city": city, "temperature": 22, "condition": "sunny"}
),
}
],
}
elif tool_name == "add_numbers":
a = arguments.get("a", 0)
b = arguments.get("b", 0)
return {
"content": [{"type": "text", "text": json.dumps({"result": a + b})}],
}
elif tool_name == "echo":
message = arguments.get("message", "")
return {
"content": [{"type": "text", "text": message}],
}
else:
return {
"content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}],
"isError": True,
}
HANDLERS = {
"initialize": _handle_initialize,
"tools/list": _handle_tools_list,
"tools/call": _handle_tools_call,
}
async def handle_mcp_request(request: web.Request) -> web.Response:
"""Handle incoming MCP JSON-RPC 2.0 requests."""
# Check auth if configured
expected_token = request.app.get("auth_token")
if expected_token:
auth_header = request.headers.get("Authorization", "")
if auth_header != f"Bearer {expected_token}":
return web.json_response(
{
"jsonrpc": "2.0",
"error": {"code": -32001, "message": "Unauthorized"},
"id": None,
},
status=401,
)
body = await request.json()
# Handle notifications (no id field) — just acknowledge
if "id" not in body:
return web.Response(status=202)
method = body.get("method", "")
params = body.get("params", {})
request_id = body.get("id")
handler = HANDLERS.get(method)
if not handler:
return web.json_response(
{
"jsonrpc": "2.0",
"error": {
"code": -32601,
"message": f"Method not found: {method}",
},
"id": request_id,
}
)
result = handler(params)
return web.json_response({"jsonrpc": "2.0", "result": result, "id": request_id})
def create_test_mcp_app(auth_token: str | None = None) -> web.Application:
"""Create an aiohttp app that acts as an MCP server."""
app = web.Application()
app.router.add_post("/mcp", handle_mcp_request)
if auth_token:
app["auth_token"] = auth_token
return app

View File

@@ -39,6 +39,7 @@ from backend.util import type as type_utils
from backend.util.exceptions import GraphNotAccessibleError, GraphNotInLibraryError from backend.util.exceptions import GraphNotAccessibleError, GraphNotInLibraryError
from backend.util.json import SafeJson from backend.util.json import SafeJson
from backend.util.models import Pagination from backend.util.models import Pagination
from backend.util.request import parse_url
from .block import ( from .block import (
AnyBlockSchema, AnyBlockSchema,
@@ -462,6 +463,9 @@ class GraphModel(Graph, GraphMeta):
continue continue
if ProviderName.HTTP in field.provider: if ProviderName.HTTP in field.provider:
continue continue
# MCP credentials are intentionally split by server URL
if ProviderName.MCP in field.provider:
continue
# If this happens, that means a block implementation probably needs # If this happens, that means a block implementation probably needs
# to be updated. # to be updated.
@@ -518,6 +522,18 @@ class GraphModel(Graph, GraphMeta):
"required": ["id", "provider", "type"], "required": ["id", "provider", "type"],
} }
# Add a descriptive display title when URL-based discriminator values
# are present (e.g. "mcp.sentry.dev" instead of just "Mcp")
if (
field_info.discriminator
and not field_info.discriminator_mapping
and field_info.discriminator_values
):
hostnames = sorted(
parse_url(str(v)).netloc for v in field_info.discriminator_values
)
field_schema["display_name"] = ", ".join(hostnames)
# Add other (optional) field info items # Add other (optional) field info items
field_schema.update( field_schema.update(
field_info.model_dump( field_info.model_dump(
@@ -562,8 +578,17 @@ class GraphModel(Graph, GraphMeta):
for graph in [self] + self.sub_graphs: for graph in [self] + self.sub_graphs:
for node in graph.nodes: for node in graph.nodes:
# Track if this node requires credentials (credentials_optional=False means required) # A node's credentials are optional if either:
node_required_map[node.id] = not node.credentials_optional # 1. The node metadata says so (credentials_optional=True), or
# 2. All credential fields on the block have defaults (not required by schema)
block_required = node.block.input_schema.get_required_fields()
creds_required_by_schema = any(
fname in block_required
for fname in node.block.input_schema.get_credentials_fields()
)
node_required_map[node.id] = (
not node.credentials_optional and creds_required_by_schema
)
for ( for (
field_name, field_name,
@@ -784,6 +809,19 @@ class GraphModel(Graph, GraphMeta):
"'credentials' and `*_credentials` are reserved" "'credentials' and `*_credentials` are reserved"
) )
# Check custom block-level validation (e.g., MCP dynamic tool arguments).
# Blocks can override get_missing_input to report additional missing fields
# beyond the standard top-level required fields.
if for_run:
credential_fields = InputSchema.get_credentials_fields()
custom_missing = InputSchema.get_missing_input(node.input_default)
for field_name in custom_missing:
if (
field_name not in provided_inputs
and field_name not in credential_fields
):
node_errors[node.id][field_name] = "This field is required"
# Get input schema properties and check dependencies # Get input schema properties and check dependencies
input_fields = InputSchema.model_fields input_fields = InputSchema.model_fields

View File

@@ -463,3 +463,120 @@ def test_node_credentials_optional_with_other_metadata():
assert node.credentials_optional is True assert node.credentials_optional is True
assert node.metadata["position"] == {"x": 100, "y": 200} assert node.metadata["position"] == {"x": 100, "y": 200}
assert node.metadata["customized_name"] == "My Custom Node" assert node.metadata["customized_name"] == "My Custom Node"
# ============================================================================
# Tests for MCP Credential Deduplication
# ============================================================================
def test_mcp_credential_combine_different_servers():
"""Two MCP credential fields with different server URLs should produce
separate entries when combined (not merged into one)."""
from backend.data.model import CredentialsFieldInfo, CredentialsType
from backend.integrations.providers import ProviderName
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
field_sentry = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.MCP]),
credentials_types=oauth2_types,
credentials_scopes=None,
discriminator="server_url",
discriminator_values={"https://mcp.sentry.dev/mcp"},
)
field_linear = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.MCP]),
credentials_types=oauth2_types,
credentials_scopes=None,
discriminator="server_url",
discriminator_values={"https://mcp.linear.app/mcp"},
)
combined = CredentialsFieldInfo.combine(
(field_sentry, ("node-sentry", "credentials")),
(field_linear, ("node-linear", "credentials")),
)
# Should produce 2 separate credential entries
assert len(combined) == 2, (
f"Expected 2 credential entries for 2 MCP blocks with different servers, "
f"got {len(combined)}: {list(combined.keys())}"
)
# Each entry should contain the server hostname in its key
keys = list(combined.keys())
assert any(
"mcp.sentry.dev" in k for k in keys
), f"Expected 'mcp.sentry.dev' in one key, got {keys}"
assert any(
"mcp.linear.app" in k for k in keys
), f"Expected 'mcp.linear.app' in one key, got {keys}"
def test_mcp_credential_combine_same_server():
"""Two MCP credential fields with the same server URL should be combined
into one credential entry."""
from backend.data.model import CredentialsFieldInfo, CredentialsType
from backend.integrations.providers import ProviderName
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
field_a = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.MCP]),
credentials_types=oauth2_types,
credentials_scopes=None,
discriminator="server_url",
discriminator_values={"https://mcp.sentry.dev/mcp"},
)
field_b = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.MCP]),
credentials_types=oauth2_types,
credentials_scopes=None,
discriminator="server_url",
discriminator_values={"https://mcp.sentry.dev/mcp"},
)
combined = CredentialsFieldInfo.combine(
(field_a, ("node-a", "credentials")),
(field_b, ("node-b", "credentials")),
)
# Should produce 1 credential entry (same server URL)
assert len(combined) == 1, (
f"Expected 1 credential entry for 2 MCP blocks with same server, "
f"got {len(combined)}: {list(combined.keys())}"
)
def test_mcp_credential_combine_no_discriminator_values():
"""MCP credential fields without discriminator_values should be merged
into a single entry (backwards compat for blocks without server_url set)."""
from backend.data.model import CredentialsFieldInfo, CredentialsType
from backend.integrations.providers import ProviderName
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
field_a = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.MCP]),
credentials_types=oauth2_types,
credentials_scopes=None,
discriminator="server_url",
)
field_b = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.MCP]),
credentials_types=oauth2_types,
credentials_scopes=None,
discriminator="server_url",
)
combined = CredentialsFieldInfo.combine(
(field_a, ("node-a", "credentials")),
(field_b, ("node-b", "credentials")),
)
# Should produce 1 entry (no URL differentiation)
assert len(combined) == 1, (
f"Expected 1 credential entry for MCP blocks without discriminator_values, "
f"got {len(combined)}: {list(combined.keys())}"
)

View File

@@ -29,6 +29,7 @@ from pydantic import (
GetCoreSchemaHandler, GetCoreSchemaHandler,
SecretStr, SecretStr,
field_serializer, field_serializer,
model_validator,
) )
from pydantic_core import ( from pydantic_core import (
CoreSchema, CoreSchema,
@@ -499,6 +500,25 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
provider: CP provider: CP
type: CT type: CT
@model_validator(mode="before")
@classmethod
def _normalize_legacy_provider(cls, data: Any) -> Any:
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug.
Python 3.13 changed ``str(StrEnum)`` to return ``"ClassName.MEMBER"``
instead of the plain value. Old stored credential references may have
``provider: "ProviderName.MCP"`` instead of ``"mcp"``.
"""
if isinstance(data, dict):
prov = data.get("provider", "")
if isinstance(prov, str) and prov.startswith("ProviderName."):
member = prov.removeprefix("ProviderName.")
try:
data = {**data, "provider": ProviderName[member].value}
except KeyError:
pass
return data
@classmethod @classmethod
def allowed_providers(cls) -> tuple[ProviderName, ...] | None: def allowed_providers(cls) -> tuple[ProviderName, ...] | None:
return get_args(cls.model_fields["provider"].annotation) return get_args(cls.model_fields["provider"].annotation)
@@ -603,11 +623,18 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
] = defaultdict(list) ] = defaultdict(list)
for field, key in fields: for field, key in fields:
if field.provider == frozenset([ProviderName.HTTP]): if (
# HTTP host-scoped credentials can have different hosts that reqires different credential sets. field.discriminator
# Group by host extracted from the URL and not field.discriminator_mapping
and field.discriminator_values
):
# URL-based discrimination (e.g. HTTP host-scoped, MCP server URL):
# Each unique host gets its own credential entry.
provider_prefix = next(iter(field.provider))
# Use .value for enum types to get the plain string (e.g. "mcp" not "ProviderName.MCP")
prefix_str = getattr(provider_prefix, "value", str(provider_prefix))
providers = frozenset( providers = frozenset(
[cast(CP, "http")] [cast(CP, prefix_str)]
+ [ + [
cast(CP, parse_url(str(value)).netloc) cast(CP, parse_url(str(value)).netloc)
for value in field.discriminator_values for value in field.discriminator_values

View File

@@ -1,3 +1,4 @@
import asyncio
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
@@ -225,6 +226,10 @@ class SyncRabbitMQ(RabbitMQBase):
class AsyncRabbitMQ(RabbitMQBase): class AsyncRabbitMQ(RabbitMQBase):
"""Asynchronous RabbitMQ client""" """Asynchronous RabbitMQ client"""
def __init__(self, config: RabbitMQConfig):
super().__init__(config)
self._reconnect_lock: asyncio.Lock | None = None
@property @property
def is_connected(self) -> bool: def is_connected(self) -> bool:
return bool(self._connection and not self._connection.is_closed) return bool(self._connection and not self._connection.is_closed)
@@ -235,7 +240,17 @@ class AsyncRabbitMQ(RabbitMQBase):
@conn_retry("AsyncRabbitMQ", "Acquiring async connection") @conn_retry("AsyncRabbitMQ", "Acquiring async connection")
async def connect(self): async def connect(self):
if self.is_connected: if self.is_connected and self._channel and not self._channel.is_closed:
return
if (
self.is_connected
and self._connection
and (self._channel is None or self._channel.is_closed)
):
self._channel = await self._connection.channel()
await self._channel.set_qos(prefetch_count=1)
await self.declare_infrastructure()
return return
self._connection = await aio_pika.connect_robust( self._connection = await aio_pika.connect_robust(
@@ -291,24 +306,46 @@ class AsyncRabbitMQ(RabbitMQBase):
exchange, routing_key=queue.routing_key or queue.name exchange, routing_key=queue.routing_key or queue.name
) )
@func_retry @property
async def publish_message( def _lock(self) -> asyncio.Lock:
if self._reconnect_lock is None:
self._reconnect_lock = asyncio.Lock()
return self._reconnect_lock
async def _ensure_channel(self) -> aio_pika.abc.AbstractChannel:
"""Get a valid channel, reconnecting if the current one is stale.
Uses a lock to prevent concurrent reconnection attempts from racing.
"""
if self.is_ready:
return self._channel # type: ignore # is_ready guarantees non-None
async with self._lock:
# Double-check after acquiring lock
if self.is_ready:
return self._channel # type: ignore
self._channel = None
await self.connect()
if self._channel is None:
raise RuntimeError("Channel should be established after connect")
return self._channel
async def _publish_once(
self, self,
routing_key: str, routing_key: str,
message: str, message: str,
exchange: Optional[Exchange] = None, exchange: Optional[Exchange] = None,
persistent: bool = True, persistent: bool = True,
) -> None: ) -> None:
if not self.is_ready: channel = await self._ensure_channel()
await self.connect()
if self._channel is None:
raise RuntimeError("Channel should be established after connect")
if exchange: if exchange:
exchange_obj = await self._channel.get_exchange(exchange.name) exchange_obj = await channel.get_exchange(exchange.name)
else: else:
exchange_obj = self._channel.default_exchange exchange_obj = channel.default_exchange
await exchange_obj.publish( await exchange_obj.publish(
aio_pika.Message( aio_pika.Message(
@@ -322,9 +359,23 @@ class AsyncRabbitMQ(RabbitMQBase):
routing_key=routing_key, routing_key=routing_key,
) )
@func_retry
async def publish_message(
self,
routing_key: str,
message: str,
exchange: Optional[Exchange] = None,
persistent: bool = True,
) -> None:
try:
await self._publish_once(routing_key, message, exchange, persistent)
except aio_pika.exceptions.ChannelInvalidStateError:
logger.warning(
"RabbitMQ channel invalid, forcing reconnect and retrying publish"
)
async with self._lock:
self._channel = None
await self._publish_once(routing_key, message, exchange, persistent)
async def get_channel(self) -> aio_pika.abc.AbstractChannel: async def get_channel(self) -> aio_pika.abc.AbstractChannel:
if not self.is_ready: return await self._ensure_channel()
await self.connect()
if self._channel is None:
raise RuntimeError("Channel should be established after connect")
return self._channel

View File

@@ -18,6 +18,7 @@ from redis.asyncio.lock import Lock as AsyncRedisLock
from backend.blocks.agent import AgentExecutorBlock from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentOutputBlock from backend.blocks.io import AgentOutputBlock
from backend.blocks.mcp.block import MCPToolBlock
from backend.data import redis_client as redis from backend.data import redis_client as redis
from backend.data.block import ( from backend.data.block import (
BlockInput, BlockInput,
@@ -229,6 +230,10 @@ async def execute_node(
_input_data.nodes_input_masks = nodes_input_masks _input_data.nodes_input_masks = nodes_input_masks
_input_data.user_id = user_id _input_data.user_id = user_id
input_data = _input_data.model_dump() input_data = _input_data.model_dump()
elif isinstance(node_block, MCPToolBlock):
_mcp_data = MCPToolBlock.Input(**node.input_default)
_mcp_data.tool_arguments = input_data.get("tool_arguments", {})
input_data = _mcp_data.model_dump()
data.inputs = input_data data.inputs = input_data
# Execute the node # Execute the node
@@ -265,8 +270,34 @@ async def execute_node(
# Handle regular credentials fields # Handle regular credentials fields
for field_name, input_type in input_model.get_credentials_fields().items(): for field_name, input_type in input_model.get_credentials_fields().items():
credentials_meta = input_type(**input_data[field_name]) field_value = input_data.get(field_name)
credentials, lock = await creds_manager.acquire(user_id, credentials_meta.id) if not field_value or (
isinstance(field_value, dict) and not field_value.get("id")
):
# No credentials configured — nullify so JSON schema validation
# doesn't choke on the empty default `{}`.
input_data[field_name] = None
continue # Block runs without credentials
credentials_meta = input_type(**field_value)
# Write normalized values back so JSON schema validation also passes
# (model_validator may have fixed legacy formats like "ProviderName.MCP")
input_data[field_name] = credentials_meta.model_dump(mode="json")
try:
credentials, lock = await creds_manager.acquire(
user_id, credentials_meta.id
)
except ValueError:
# Credential was deleted or doesn't exist.
# If the field has a default, run without credentials.
if input_model.model_fields[field_name].default is not None:
log_metadata.warning(
f"Credentials #{credentials_meta.id} not found, "
"running without (field has default)"
)
input_data[field_name] = input_model.model_fields[field_name].default
continue
raise
creds_locks.append(lock) creds_locks.append(lock)
extra_exec_kwargs[field_name] = credentials extra_exec_kwargs[field_name] = credentials

View File

@@ -265,7 +265,13 @@ async def _validate_node_input_credentials(
# Track if any credential field is missing for this node # Track if any credential field is missing for this node
has_missing_credentials = False has_missing_credentials = False
# A credential field is optional if the node metadata says so, or if
# the block schema declares a default for the field.
required_fields = block.input_schema.get_required_fields()
is_creds_optional = node.credentials_optional
for field_name, credentials_meta_type in credentials_fields.items(): for field_name, credentials_meta_type in credentials_fields.items():
field_is_optional = is_creds_optional or field_name not in required_fields
try: try:
# Check nodes_input_masks first, then input_default # Check nodes_input_masks first, then input_default
field_value = None field_value = None
@@ -278,7 +284,7 @@ async def _validate_node_input_credentials(
elif field_name in node.input_default: elif field_name in node.input_default:
# For optional credentials, don't use input_default - treat as missing # For optional credentials, don't use input_default - treat as missing
# This prevents stale credential IDs from failing validation # This prevents stale credential IDs from failing validation
if node.credentials_optional: if field_is_optional:
field_value = None field_value = None
else: else:
field_value = node.input_default[field_name] field_value = node.input_default[field_name]
@@ -288,8 +294,8 @@ async def _validate_node_input_credentials(
isinstance(field_value, dict) and not field_value.get("id") isinstance(field_value, dict) and not field_value.get("id")
): ):
has_missing_credentials = True has_missing_credentials = True
# If node has credentials_optional flag, mark for skipping instead of error # If credential field is optional, skip instead of error
if node.credentials_optional: if field_is_optional:
continue # Don't add error, will be marked for skip after loop continue # Don't add error, will be marked for skip after loop
else: else:
credential_errors[node.id][ credential_errors[node.id][
@@ -339,16 +345,16 @@ async def _validate_node_input_credentials(
] = "Invalid credentials: type/provider mismatch" ] = "Invalid credentials: type/provider mismatch"
continue continue
# If node has optional credentials and any are missing, mark for skipping # If node has optional credentials and any are missing, allow running without.
# But only if there are no other errors for this node # The executor will pass credentials=None to the block's run().
if ( if (
has_missing_credentials has_missing_credentials
and node.credentials_optional and is_creds_optional
and node.id not in credential_errors and node.id not in credential_errors
): ):
nodes_to_skip.add(node.id)
logger.info( logger.info(
f"Node #{node.id} will be skipped: optional credentials not configured" f"Node #{node.id}: optional credentials not configured, "
"running without"
) )
return credential_errors, nodes_to_skip return credential_errors, nodes_to_skip

View File

@@ -495,6 +495,7 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip(
mock_block.input_schema.get_credentials_fields.return_value = { mock_block.input_schema.get_credentials_fields.return_value = {
"credentials": mock_credentials_field_type "credentials": mock_credentials_field_type
} }
mock_block.input_schema.get_required_fields.return_value = {"credentials"}
mock_node.block = mock_block mock_node.block = mock_block
# Create mock graph # Create mock graph
@@ -508,8 +509,8 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip(
nodes_input_masks=None, nodes_input_masks=None,
) )
# Node should be in nodes_to_skip, not in errors # Node should NOT be in nodes_to_skip (runs without credentials) and not in errors
assert mock_node.id in nodes_to_skip assert mock_node.id not in nodes_to_skip
assert mock_node.id not in errors assert mock_node.id not in errors
@@ -535,6 +536,7 @@ async def test_validate_node_input_credentials_required_missing_creds_error(
mock_block.input_schema.get_credentials_fields.return_value = { mock_block.input_schema.get_credentials_fields.return_value = {
"credentials": mock_credentials_field_type "credentials": mock_credentials_field_type
} }
mock_block.input_schema.get_required_fields.return_value = {"credentials"}
mock_node.block = mock_block mock_node.block = mock_block
# Create mock graph # Create mock graph

View File

@@ -22,6 +22,27 @@ from backend.util.settings import Settings
settings = Settings() settings = Settings()
def provider_matches(stored: str, expected: str) -> bool:
"""Compare provider strings, handling Python 3.13 ``str(StrEnum)`` bug.
On Python 3.13, ``str(ProviderName.MCP)`` returns ``"ProviderName.MCP"``
instead of ``"mcp"``. OAuth states persisted with the buggy format need
to match when ``expected`` is the canonical value (e.g. ``"mcp"``).
"""
if stored == expected:
return True
if stored.startswith("ProviderName."):
member = stored.removeprefix("ProviderName.")
from backend.integrations.providers import ProviderName
try:
return ProviderName[member].value == expected
except KeyError:
pass
return False
# This is an overrride since ollama doesn't actually require an API key, but the creddential system enforces one be attached # This is an overrride since ollama doesn't actually require an API key, but the creddential system enforces one be attached
ollama_credentials = APIKeyCredentials( ollama_credentials = APIKeyCredentials(
id="744fdc56-071a-4761-b5a5-0af0ce10a2b5", id="744fdc56-071a-4761-b5a5-0af0ce10a2b5",
@@ -389,7 +410,7 @@ class IntegrationCredentialsStore:
self, user_id: str, provider: str self, user_id: str, provider: str
) -> list[Credentials]: ) -> list[Credentials]:
credentials = await self.get_all_creds(user_id) credentials = await self.get_all_creds(user_id)
return [c for c in credentials if c.provider == provider] return [c for c in credentials if provider_matches(c.provider, provider)]
async def get_authorized_providers(self, user_id: str) -> list[str]: async def get_authorized_providers(self, user_id: str) -> list[str]:
credentials = await self.get_all_creds(user_id) credentials = await self.get_all_creds(user_id)
@@ -485,17 +506,6 @@ class IntegrationCredentialsStore:
async with self.edit_user_integrations(user_id) as user_integrations: async with self.edit_user_integrations(user_id) as user_integrations:
user_integrations.oauth_states.append(state) user_integrations.oauth_states.append(state)
async with await self.locked_user_integrations(user_id):
user_integrations = await self._get_user_integrations(user_id)
oauth_states = user_integrations.oauth_states
oauth_states.append(state)
user_integrations.oauth_states = oauth_states
await self.db_manager.update_user_integrations(
user_id=user_id, data=user_integrations
)
return token, code_challenge return token, code_challenge
def _generate_code_challenge(self) -> tuple[str, str]: def _generate_code_challenge(self) -> tuple[str, str]:
@@ -521,7 +531,7 @@ class IntegrationCredentialsStore:
state state
for state in oauth_states for state in oauth_states
if secrets.compare_digest(state.token, token) if secrets.compare_digest(state.token, token)
and state.provider == provider and provider_matches(state.provider, provider)
and state.expires_at > now.timestamp() and state.expires_at > now.timestamp()
), ),
None, None,

View File

@@ -9,7 +9,10 @@ from redis.asyncio.lock import Lock as AsyncRedisLock
from backend.data.model import Credentials, OAuth2Credentials from backend.data.model import Credentials, OAuth2Credentials
from backend.data.redis_client import get_redis_async from backend.data.redis_client import get_redis_async
from backend.integrations.credentials_store import IntegrationCredentialsStore from backend.integrations.credentials_store import (
IntegrationCredentialsStore,
provider_matches,
)
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName from backend.integrations.providers import ProviderName
from backend.util.exceptions import MissingConfigError from backend.util.exceptions import MissingConfigError
@@ -137,6 +140,9 @@ class IntegrationCredentialsManager:
self, user_id: str, credentials: OAuth2Credentials, lock: bool = True self, user_id: str, credentials: OAuth2Credentials, lock: bool = True
) -> OAuth2Credentials: ) -> OAuth2Credentials:
async with self._locked(user_id, credentials.id, "refresh"): async with self._locked(user_id, credentials.id, "refresh"):
if provider_matches(credentials.provider, ProviderName.MCP.value):
oauth_handler = create_mcp_oauth_handler(credentials)
else:
oauth_handler = await _get_provider_oauth_handler(credentials.provider) oauth_handler = await _get_provider_oauth_handler(credentials.provider)
if oauth_handler.needs_refresh(credentials): if oauth_handler.needs_refresh(credentials):
logger.debug( logger.debug(
@@ -236,3 +242,25 @@ async def _get_provider_oauth_handler(provider_name_str: str) -> "BaseOAuthHandl
client_secret=client_secret, client_secret=client_secret,
redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback", redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
) )
def create_mcp_oauth_handler(
credentials: OAuth2Credentials,
) -> "BaseOAuthHandler":
"""Create an MCPOAuthHandler from credential metadata for token refresh.
MCP OAuth handlers have dynamic endpoints discovered per-server, so they
can't be registered as singletons in HANDLERS_BY_NAME. Instead, the handler
is reconstructed from metadata stored on the credential during initial auth.
"""
from backend.blocks.mcp.oauth import MCPOAuthHandler
meta = credentials.metadata or {}
return MCPOAuthHandler(
client_id=meta.get("mcp_client_id", ""),
client_secret=meta.get("mcp_client_secret", ""),
redirect_uri="", # Not needed for token refresh
authorize_url="", # Not needed for token refresh
token_url=meta.get("mcp_token_url", ""),
resource_url=meta.get("mcp_resource_url"),
)

View File

@@ -30,6 +30,7 @@ class ProviderName(str, Enum):
IDEOGRAM = "ideogram" IDEOGRAM = "ideogram"
JINA = "jina" JINA = "jina"
LLAMA_API = "llama_api" LLAMA_API = "llama_api"
MCP = "mcp"
MEDIUM = "medium" MEDIUM = "medium"
MEM0 = "mem0" MEM0 = "mem0"
NOTION = "notion" NOTION = "notion"

View File

@@ -50,6 +50,21 @@ async def _on_graph_activate(graph: "BaseGraph | GraphModel", user_id: str):
if ( if (
creds_meta := new_node.input_default.get(creds_field_name) creds_meta := new_node.input_default.get(creds_field_name)
) and not await get_credentials(creds_meta["id"]): ) and not await get_credentials(creds_meta["id"]):
# If the credential field is optional (has a default in the
# schema, or node metadata marks it optional), clear the stale
# reference instead of blocking the save.
creds_field_optional = (
new_node.credentials_optional
or creds_field_name not in block_input_schema.get_required_fields()
)
if creds_field_optional:
new_node.input_default[creds_field_name] = {}
logger.warning(
f"Node #{new_node.id}: cleared stale optional "
f"credentials #{creds_meta['id']} for "
f"'{creds_field_name}'"
)
continue
raise ValueError( raise ValueError(
f"Node #{new_node.id} input '{creds_field_name}' updated with " f"Node #{new_node.id} input '{creds_field_name}' updated with "
f"non-existent credentials #{creds_meta['id']}" f"non-existent credentials #{creds_meta['id']}"

View File

@@ -101,7 +101,7 @@ class HostResolver(abc.AbstractResolver):
def __init__(self, ssl_hostname: str, ip_addresses: list[str]): def __init__(self, ssl_hostname: str, ip_addresses: list[str]):
self.ssl_hostname = ssl_hostname self.ssl_hostname = ssl_hostname
self.ip_addresses = ip_addresses self.ip_addresses = ip_addresses
self._default = aiohttp.AsyncResolver() self._default = aiohttp.ThreadedResolver()
async def resolve(self, host, port=0, family=socket.AF_INET): async def resolve(self, host, port=0, family=socket.AF_INET):
if host == self.ssl_hostname: if host == self.ssl_hostname:
@@ -467,7 +467,7 @@ class Requests:
resolver = HostResolver(ssl_hostname=hostname, ip_addresses=ip_addresses) resolver = HostResolver(ssl_hostname=hostname, ip_addresses=ip_addresses)
ssl_context = ssl.create_default_context() ssl_context = ssl.create_default_context()
connector = aiohttp.TCPConnector(resolver=resolver, ssl=ssl_context) connector = aiohttp.TCPConnector(resolver=resolver, ssl=ssl_context)
session_kwargs = {} session_kwargs: dict = {}
if connector: if connector:
session_kwargs["connector"] = connector session_kwargs["connector"] = connector

View File

@@ -25,8 +25,12 @@ RUN if [ -f .env.production ]; then \
cp .env.default .env; \ cp .env.default .env; \
fi fi
RUN pnpm run generate:api RUN pnpm run generate:api
# Disable source-map generation in Docker builds to halve webpack memory usage.
# Source maps are only useful when SENTRY_AUTH_TOKEN is set (Vercel deploys);
# the Docker image never uploads them, so generating them just wastes RAM.
ENV NEXT_PUBLIC_SOURCEMAPS="false"
# In CI, we want NEXT_PUBLIC_PW_TEST=true during build so Next.js inlines it # In CI, we want NEXT_PUBLIC_PW_TEST=true during build so Next.js inlines it
RUN if [ "$NEXT_PUBLIC_PW_TEST" = "true" ]; then NEXT_PUBLIC_PW_TEST=true NODE_OPTIONS="--max-old-space-size=4096" pnpm build; else NODE_OPTIONS="--max-old-space-size=4096" pnpm build; fi RUN if [ "$NEXT_PUBLIC_PW_TEST" = "true" ]; then NEXT_PUBLIC_PW_TEST=true NODE_OPTIONS="--max-old-space-size=8192" pnpm build; else NODE_OPTIONS="--max-old-space-size=8192" pnpm build; fi
# Prod stage - based on NextJS reference Dockerfile https://github.com/vercel/next.js/blob/64271354533ed16da51be5dce85f0dbd15f17517/examples/with-docker/Dockerfile # Prod stage - based on NextJS reference Dockerfile https://github.com/vercel/next.js/blob/64271354533ed16da51be5dce85f0dbd15f17517/examples/with-docker/Dockerfile
FROM node:21-alpine AS prod FROM node:21-alpine AS prod

View File

@@ -1,8 +1,12 @@
import { withSentryConfig } from "@sentry/nextjs"; import { withSentryConfig } from "@sentry/nextjs";
// Allow Docker builds to skip source-map generation (halves memory usage).
// Defaults to true so Vercel/local builds are unaffected.
const enableSourceMaps = process.env.NEXT_PUBLIC_SOURCEMAPS !== "false";
/** @type {import('next').NextConfig} */ /** @type {import('next').NextConfig} */
const nextConfig = { const nextConfig = {
productionBrowserSourceMaps: true, productionBrowserSourceMaps: enableSourceMaps,
// Externalize OpenTelemetry packages to fix Turbopack HMR issues // Externalize OpenTelemetry packages to fix Turbopack HMR issues
serverExternalPackages: [ serverExternalPackages: [
"@opentelemetry/instrumentation", "@opentelemetry/instrumentation",
@@ -14,9 +18,37 @@ const nextConfig = {
serverActions: { serverActions: {
bodySizeLimit: "256mb", bodySizeLimit: "256mb",
}, },
// Increase body size limit for API routes (file uploads) - 256MB to match backend limit
proxyClientMaxBodySize: "256mb",
middlewareClientMaxBodySize: "256mb", middlewareClientMaxBodySize: "256mb",
// Limit parallel webpack workers to reduce peak memory during builds.
cpus: 2,
},
// Work around cssnano "Invalid array length" bug in Next.js's bundled
// cssnano-simple comment parser when processing very large CSS chunks.
// CSS is still bundled correctly; gzip handles most of the size savings anyway.
webpack: (config, { dev }) => {
if (!dev) {
// Next.js adds CssMinimizerPlugin internally (after user config), so we
// can't filter it from config.plugins. Instead, intercept the webpack
// compilation hooks and replace the buggy plugin's tap with a no-op.
config.plugins.push({
apply(compiler) {
compiler.hooks.compilation.tap(
"DisableCssMinimizer",
(compilation) => {
compilation.hooks.processAssets.intercept({
register: (tap) => {
if (tap.name === "CssMinimizerPlugin") {
return { ...tap, fn: async () => {} };
}
return tap;
},
});
},
);
},
});
}
return config;
}, },
images: { images: {
domains: [ domains: [
@@ -54,9 +86,16 @@ const nextConfig = {
transpilePackages: ["geist"], transpilePackages: ["geist"],
}; };
const isDevelopmentBuild = process.env.NODE_ENV !== "production"; // Only run the Sentry webpack plugin when we can actually upload source maps
// (i.e. on Vercel with SENTRY_AUTH_TOKEN set). The Sentry *runtime* SDK
// (imported in app code) still captures errors without the plugin.
// Skipping the plugin saves ~1 GB of peak memory during `next build`.
const skipSentryPlugin =
process.env.NODE_ENV !== "production" ||
!enableSourceMaps ||
!process.env.SENTRY_AUTH_TOKEN;
export default isDevelopmentBuild export default skipSentryPlugin
? nextConfig ? nextConfig
: withSentryConfig(nextConfig, { : withSentryConfig(nextConfig, {
// For all available options, see: // For all available options, see:
@@ -96,7 +135,7 @@ export default isDevelopmentBuild
// This helps Sentry with sourcemaps... https://docs.sentry.io/platforms/javascript/guides/nextjs/sourcemaps/ // This helps Sentry with sourcemaps... https://docs.sentry.io/platforms/javascript/guides/nextjs/sourcemaps/
sourcemaps: { sourcemaps: {
disable: false, disable: !enableSourceMaps,
assets: [".next/**/*.js", ".next/**/*.js.map"], assets: [".next/**/*.js", ".next/**/*.js.map"],
ignore: ["**/node_modules/**"], ignore: ["**/node_modules/**"],
deleteSourcemapsAfterUpload: false, // Source is public anyway :) deleteSourcemapsAfterUpload: false, // Source is public anyway :)

View File

@@ -7,7 +7,7 @@
}, },
"scripts": { "scripts": {
"dev": "pnpm run generate:api:force && next dev --turbo", "dev": "pnpm run generate:api:force && next dev --turbo",
"build": "next build", "build": "cross-env NODE_OPTIONS=--max-old-space-size=16384 next build",
"start": "next start", "start": "next start",
"start:standalone": "cd .next/standalone && node server.js", "start:standalone": "cd .next/standalone && node server.js",
"lint": "next lint && prettier --check .", "lint": "next lint && prettier --check .",
@@ -30,6 +30,7 @@
"defaults" "defaults"
], ],
"dependencies": { "dependencies": {
"@ai-sdk/react": "3.0.61",
"@faker-js/faker": "10.0.0", "@faker-js/faker": "10.0.0",
"@hookform/resolvers": "5.2.2", "@hookform/resolvers": "5.2.2",
"@next/third-parties": "15.4.6", "@next/third-parties": "15.4.6",
@@ -60,6 +61,10 @@
"@rjsf/utils": "6.1.2", "@rjsf/utils": "6.1.2",
"@rjsf/validator-ajv8": "6.1.2", "@rjsf/validator-ajv8": "6.1.2",
"@sentry/nextjs": "10.27.0", "@sentry/nextjs": "10.27.0",
"@streamdown/cjk": "1.0.1",
"@streamdown/code": "1.0.1",
"@streamdown/math": "1.0.1",
"@streamdown/mermaid": "1.0.1",
"@supabase/ssr": "0.7.0", "@supabase/ssr": "0.7.0",
"@supabase/supabase-js": "2.78.0", "@supabase/supabase-js": "2.78.0",
"@tanstack/react-query": "5.90.6", "@tanstack/react-query": "5.90.6",
@@ -68,6 +73,7 @@
"@vercel/analytics": "1.5.0", "@vercel/analytics": "1.5.0",
"@vercel/speed-insights": "1.2.0", "@vercel/speed-insights": "1.2.0",
"@xyflow/react": "12.9.2", "@xyflow/react": "12.9.2",
"ai": "6.0.59",
"boring-avatars": "1.11.2", "boring-avatars": "1.11.2",
"class-variance-authority": "0.7.1", "class-variance-authority": "0.7.1",
"clsx": "2.1.1", "clsx": "2.1.1",
@@ -87,7 +93,6 @@
"launchdarkly-react-client-sdk": "3.9.0", "launchdarkly-react-client-sdk": "3.9.0",
"lodash": "4.17.21", "lodash": "4.17.21",
"lucide-react": "0.552.0", "lucide-react": "0.552.0",
"moment": "2.30.1",
"next": "15.4.10", "next": "15.4.10",
"next-themes": "0.4.6", "next-themes": "0.4.6",
"nuqs": "2.7.2", "nuqs": "2.7.2",
@@ -112,9 +117,11 @@
"remark-math": "6.0.0", "remark-math": "6.0.0",
"shepherd.js": "14.5.1", "shepherd.js": "14.5.1",
"sonner": "2.0.7", "sonner": "2.0.7",
"streamdown": "2.1.0",
"tailwind-merge": "2.6.0", "tailwind-merge": "2.6.0",
"tailwind-scrollbar": "3.1.0", "tailwind-scrollbar": "3.1.0",
"tailwindcss-animate": "1.0.7", "tailwindcss-animate": "1.0.7",
"use-stick-to-bottom": "1.1.2",
"uuid": "11.1.0", "uuid": "11.1.0",
"vaul": "1.1.2", "vaul": "1.1.2",
"zod": "3.25.76", "zod": "3.25.76",
@@ -172,7 +179,8 @@
}, },
"pnpm": { "pnpm": {
"overrides": { "overrides": {
"@opentelemetry/instrumentation": "0.209.0" "@opentelemetry/instrumentation": "0.209.0",
"lodash-es": "4.17.23"
} }
}, },
"packageManager": "pnpm@10.20.0+sha512.cf9998222162dd85864d0a8102e7892e7ba4ceadebbf5a31f9c2fce48dfce317a9c53b9f6464d1ef9042cba2e02ae02a9f7c143a2b438cd93c91840f0192b9dd" "packageManager": "pnpm@10.20.0+sha512.cf9998222162dd85864d0a8102e7892e7ba4ceadebbf5a31f9c2fce48dfce317a9c53b9f6464d1ef9042cba2e02ae02a9f7c143a2b438cd93c91840f0192b9dd"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,96 @@
import { NextResponse } from "next/server";
/**
* Safely encode a value as JSON for embedding in a script tag.
* Escapes characters that could break out of the script context to prevent XSS.
*/
function safeJsonStringify(value: unknown): string {
return JSON.stringify(value)
.replace(/</g, "\\u003c")
.replace(/>/g, "\\u003e")
.replace(/&/g, "\\u0026");
}
// MCP-specific OAuth callback route.
//
// Unlike the generic oauth_callback which relies on window.opener.postMessage,
// this route uses BroadcastChannel as the PRIMARY communication method.
// This is critical because cross-origin OAuth flows (e.g. Sentry → localhost)
// often lose window.opener due to COOP (Cross-Origin-Opener-Policy) headers.
//
// BroadcastChannel works across all same-origin tabs/popups regardless of opener.
export async function GET(request: Request) {
const { searchParams } = new URL(request.url);
const code = searchParams.get("code");
const state = searchParams.get("state");
const success = Boolean(code && state);
const message = success
? { success: true, code, state }
: {
success: false,
message: `Missing parameters: ${searchParams.toString()}`,
};
return new NextResponse(
`<!DOCTYPE html>
<html>
<head><title>MCP Sign-in</title></head>
<body style="font-family: system-ui, -apple-system, sans-serif; display: flex; align-items: center; justify-content: center; min-height: 100vh; margin: 0; background: #f9fafb;">
<div style="text-align: center; max-width: 400px; padding: 2rem;">
<div id="spinner" style="margin: 0 auto 1rem; width: 32px; height: 32px; border: 3px solid #e5e7eb; border-top-color: #3b82f6; border-radius: 50%; animation: spin 0.8s linear infinite;"></div>
<p id="status" style="color: #374151; font-size: 16px;">Completing sign-in...</p>
</div>
<style>@keyframes spin { to { transform: rotate(360deg); } }</style>
<script>
(function() {
var msg = ${safeJsonStringify(message)};
var sent = false;
// Method 1: BroadcastChannel (reliable across tabs/popups, no opener needed)
try {
var bc = new BroadcastChannel("mcp_oauth");
bc.postMessage({ type: "mcp_oauth_result", success: msg.success, code: msg.code, state: msg.state, message: msg.message });
bc.close();
sent = true;
} catch(e) { /* BroadcastChannel not supported */ }
// Method 2: window.opener.postMessage (fallback for same-origin popups)
try {
if (window.opener && !window.opener.closed) {
window.opener.postMessage(
{ message_type: "mcp_oauth_result", success: msg.success, code: msg.code, state: msg.state, message: msg.message },
window.location.origin
);
sent = true;
}
} catch(e) { /* opener not available (COOP) */ }
// Method 3: localStorage (most reliable cross-tab fallback)
try {
localStorage.setItem("mcp_oauth_result", JSON.stringify(msg));
sent = true;
} catch(e) { /* localStorage not available */ }
var statusEl = document.getElementById("status");
var spinnerEl = document.getElementById("spinner");
spinnerEl.style.display = "none";
if (msg.success && sent) {
statusEl.textContent = "Sign-in complete! This window will close.";
statusEl.style.color = "#059669";
setTimeout(function() { window.close(); }, 1500);
} else if (msg.success) {
statusEl.textContent = "Sign-in successful! You can close this tab and return to the builder.";
statusEl.style.color = "#059669";
} else {
statusEl.textContent = "Sign-in failed: " + (msg.message || "Unknown error");
statusEl.style.color = "#dc2626";
}
})();
</script>
</body>
</html>`,
{ headers: { "Content-Type": "text/html" } },
);
}

View File

@@ -47,7 +47,10 @@ export type CustomNode = XYNode<CustomNodeData, "custom">;
export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo( export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
({ data, id: nodeId, selected }) => { ({ data, id: nodeId, selected }) => {
const { inputSchema, outputSchema } = useCustomNode({ data, nodeId }); const { inputSchema, outputSchema, isMCPWithTool } = useCustomNode({
data,
nodeId,
});
const isAgent = data.uiType === BlockUIType.AGENT; const isAgent = data.uiType === BlockUIType.AGENT;
@@ -98,6 +101,7 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
jsonSchema={preprocessInputSchema(inputSchema)} jsonSchema={preprocessInputSchema(inputSchema)}
nodeId={nodeId} nodeId={nodeId}
uiType={data.uiType} uiType={data.uiType}
isMCPWithTool={isMCPWithTool}
className={cn( className={cn(
"bg-white px-4", "bg-white px-4",
isWebhook && "pointer-events-none opacity-50", isWebhook && "pointer-events-none opacity-50",

View File

@@ -20,10 +20,8 @@ type Props = {
export const NodeHeader = ({ data, nodeId }: Props) => { export const NodeHeader = ({ data, nodeId }: Props) => {
const updateNodeData = useNodeStore((state) => state.updateNodeData); const updateNodeData = useNodeStore((state) => state.updateNodeData);
const title =
(data.metadata?.customized_name as string) || const title = (data.metadata?.customized_name as string) || data.title;
data.hardcodedValues?.agent_name ||
data.title;
const [isEditingTitle, setIsEditingTitle] = useState(false); const [isEditingTitle, setIsEditingTitle] = useState(false);
const [editedTitle, setEditedTitle] = useState(title); const [editedTitle, setEditedTitle] = useState(title);

View File

@@ -3,6 +3,36 @@ import { CustomNodeData } from "./CustomNode";
import { BlockUIType } from "../../../types"; import { BlockUIType } from "../../../types";
import { useMemo } from "react"; import { useMemo } from "react";
import { mergeSchemaForResolution } from "./helpers"; import { mergeSchemaForResolution } from "./helpers";
import { SpecialBlockID } from "@/lib/autogpt-server-api";
/**
* Build a dynamic input schema for MCP blocks.
*
* When a tool has been selected (tool_input_schema is populated), the block
* renders the selected tool's input parameters *plus* the credentials field
* so users can select/change the OAuth credential used for execution.
*
* Static fields like server_url, selected_tool, available_tools, and
* tool_arguments are hidden because they're pre-configured from the dialog.
*/
function buildMCPInputSchema(
toolInputSchema: Record<string, any>,
blockInputSchema: Record<string, any>,
): Record<string, any> {
// Extract the credentials field from the block's original input schema
const credentialsSchema =
blockInputSchema?.properties?.credentials ?? undefined;
return {
type: "object",
properties: {
// Credentials field first so the dropdown appears at the top
...(credentialsSchema ? { credentials: credentialsSchema } : {}),
...(toolInputSchema.properties ?? {}),
},
required: [...(toolInputSchema.required ?? [])],
};
}
export const useCustomNode = ({ export const useCustomNode = ({
data, data,
@@ -19,9 +49,17 @@ export const useCustomNode = ({
); );
const isAgent = data.uiType === BlockUIType.AGENT; const isAgent = data.uiType === BlockUIType.AGENT;
const isMCPWithTool =
data.block_id === SpecialBlockID.MCP_TOOL &&
!!data.hardcodedValues?.tool_input_schema?.properties;
const currentInputSchema = isAgent const currentInputSchema = isAgent
? (data.hardcodedValues.input_schema ?? {}) ? (data.hardcodedValues.input_schema ?? {})
: isMCPWithTool
? buildMCPInputSchema(
data.hardcodedValues.tool_input_schema,
data.inputSchema,
)
: data.inputSchema; : data.inputSchema;
const currentOutputSchema = isAgent const currentOutputSchema = isAgent
? (data.hardcodedValues.output_schema ?? {}) ? (data.hardcodedValues.output_schema ?? {})
@@ -54,5 +92,6 @@ export const useCustomNode = ({
return { return {
inputSchema, inputSchema,
outputSchema, outputSchema,
isMCPWithTool,
}; };
}; };

View File

@@ -9,39 +9,72 @@ interface FormCreatorProps {
jsonSchema: RJSFSchema; jsonSchema: RJSFSchema;
nodeId: string; nodeId: string;
uiType: BlockUIType; uiType: BlockUIType;
/** When true the block is an MCP Tool with a selected tool. */
isMCPWithTool?: boolean;
showHandles?: boolean; showHandles?: boolean;
className?: string; className?: string;
} }
export const FormCreator: React.FC<FormCreatorProps> = React.memo( export const FormCreator: React.FC<FormCreatorProps> = React.memo(
({ jsonSchema, nodeId, uiType, showHandles = true, className }) => { ({
jsonSchema,
nodeId,
uiType,
isMCPWithTool = false,
showHandles = true,
className,
}) => {
const updateNodeData = useNodeStore((state) => state.updateNodeData); const updateNodeData = useNodeStore((state) => state.updateNodeData);
const getHardCodedValues = useNodeStore( const getHardCodedValues = useNodeStore(
(state) => state.getHardCodedValues, (state) => state.getHardCodedValues,
); );
const isAgent = uiType === BlockUIType.AGENT;
const handleChange = ({ formData }: any) => { const handleChange = ({ formData }: any) => {
if ("credentials" in formData && !formData.credentials?.id) { if ("credentials" in formData && !formData.credentials?.id) {
delete formData.credentials; delete formData.credentials;
} }
const updatedValues = let updatedValues;
uiType === BlockUIType.AGENT if (isAgent) {
? { updatedValues = {
...getHardCodedValues(nodeId), ...getHardCodedValues(nodeId),
inputs: formData, inputs: formData,
};
} else if (isMCPWithTool) {
// Separate credentials from tool arguments — credentials are stored
// at the top level of hardcodedValues, not inside tool_arguments.
const { credentials, ...toolArgs } = formData;
updatedValues = {
...getHardCodedValues(nodeId),
tool_arguments: toolArgs,
...(credentials?.id ? { credentials } : {}),
};
} else {
updatedValues = formData;
} }
: formData;
updateNodeData(nodeId, { hardcodedValues: updatedValues }); updateNodeData(nodeId, { hardcodedValues: updatedValues });
}; };
const hardcodedValues = getHardCodedValues(nodeId); const hardcodedValues = getHardCodedValues(nodeId);
const initialValues =
uiType === BlockUIType.AGENT let initialValues;
? (hardcodedValues.inputs ?? {}) if (isAgent) {
: hardcodedValues; initialValues = hardcodedValues.inputs ?? {};
} else if (isMCPWithTool) {
// Merge tool arguments with credentials for the form
initialValues = {
...(hardcodedValues.tool_arguments ?? {}),
...(hardcodedValues.credentials?.id
? { credentials: hardcodedValues.credentials }
: {}),
};
} else {
initialValues = hardcodedValues;
}
return ( return (
<div <div

View File

@@ -1,7 +1,7 @@
import { Button } from "@/components/__legacy__/ui/button"; import { Button } from "@/components/__legacy__/ui/button";
import { Skeleton } from "@/components/__legacy__/ui/skeleton"; import { Skeleton } from "@/components/__legacy__/ui/skeleton";
import { beautifyString, cn } from "@/lib/utils"; import { beautifyString, cn } from "@/lib/utils";
import React, { ButtonHTMLAttributes } from "react"; import React, { ButtonHTMLAttributes, useCallback, useState } from "react";
import { highlightText } from "./helpers"; import { highlightText } from "./helpers";
import { PlusIcon } from "@phosphor-icons/react"; import { PlusIcon } from "@phosphor-icons/react";
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo"; import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
@@ -9,6 +9,12 @@ import { useControlPanelStore } from "../../../stores/controlPanelStore";
import { blockDragPreviewStyle } from "./style"; import { blockDragPreviewStyle } from "./style";
import { useReactFlow } from "@xyflow/react"; import { useReactFlow } from "@xyflow/react";
import { useNodeStore } from "../../../stores/nodeStore"; import { useNodeStore } from "../../../stores/nodeStore";
import { SpecialBlockID } from "@/lib/autogpt-server-api";
import {
MCPToolDialog,
type MCPToolDialogResult,
} from "@/app/(platform)/build/components/legacy-builder/MCPToolDialog";
interface Props extends ButtonHTMLAttributes<HTMLButtonElement> { interface Props extends ButtonHTMLAttributes<HTMLButtonElement> {
title?: string; title?: string;
description?: string; description?: string;
@@ -33,9 +39,13 @@ export const Block: BlockComponent = ({
); );
const { setViewport } = useReactFlow(); const { setViewport } = useReactFlow();
const { addBlock } = useNodeStore(); const { addBlock } = useNodeStore();
const [mcpDialogOpen, setMcpDialogOpen] = useState(false);
const handleClick = () => { const isMCPBlock = blockData.id === SpecialBlockID.MCP_TOOL;
const customNode = addBlock(blockData);
const addBlockAndCenter = useCallback(
(block: BlockInfo, hardcodedValues?: Record<string, any>) => {
const customNode = addBlock(block, hardcodedValues);
setTimeout(() => { setTimeout(() => {
setViewport( setViewport(
{ {
@@ -46,9 +56,69 @@ export const Block: BlockComponent = ({
{ duration: 500 }, { duration: 500 },
); );
}, 50); }, 50);
return customNode;
},
[addBlock, setViewport],
);
const updateNodeData = useNodeStore((state) => state.updateNodeData);
const handleMCPToolConfirm = useCallback(
(result: MCPToolDialogResult) => {
// Derive a display label: prefer server name, fall back to URL hostname.
let serverLabel = result.serverName;
if (!serverLabel) {
try {
serverLabel = new URL(result.serverUrl).hostname;
} catch {
serverLabel = "MCP";
}
}
const customNode = addBlockAndCenter(blockData, {
server_url: result.serverUrl,
server_name: serverLabel,
selected_tool: result.selectedTool,
tool_input_schema: result.toolInputSchema,
available_tools: result.availableTools,
credentials: result.credentials ?? undefined,
});
if (customNode) {
const title = result.selectedTool
? `${serverLabel}: ${beautifyString(result.selectedTool)}`
: undefined;
updateNodeData(customNode.id, {
metadata: {
...customNode.data.metadata,
credentials_optional: true,
...(title && { customized_name: title }),
},
});
}
setMcpDialogOpen(false);
},
[addBlockAndCenter, blockData, updateNodeData],
);
const handleClick = () => {
if (isMCPBlock) {
setMcpDialogOpen(true);
return;
}
const customNode = addBlockAndCenter(blockData);
// Set customized_name for agent blocks so the agent's name persists
if (customNode && blockData.id === SpecialBlockID.AGENT) {
updateNodeData(customNode.id, {
metadata: {
...customNode.data.metadata,
customized_name: blockData.name,
},
});
}
}; };
const handleDragStart = (e: React.DragEvent<HTMLButtonElement>) => { const handleDragStart = (e: React.DragEvent<HTMLButtonElement>) => {
if (isMCPBlock) return;
e.dataTransfer.effectAllowed = "copy"; e.dataTransfer.effectAllowed = "copy";
e.dataTransfer.setData("application/reactflow", JSON.stringify(blockData)); e.dataTransfer.setData("application/reactflow", JSON.stringify(blockData));
@@ -71,12 +141,14 @@ export const Block: BlockComponent = ({
: undefined; : undefined;
return ( return (
<>
<Button <Button
draggable={true} draggable={!isMCPBlock}
data-id={blockDataId} data-id={blockDataId}
className={cn( className={cn(
"group flex h-16 w-full min-w-[7.5rem] items-center justify-start space-x-3 whitespace-normal rounded-[0.75rem] bg-zinc-50 px-[0.875rem] py-[0.625rem] text-start shadow-none", "group flex h-16 w-full min-w-[7.5rem] items-center justify-start space-x-3 whitespace-normal rounded-[0.75rem] bg-zinc-50 px-[0.875rem] py-[0.625rem] text-start shadow-none",
"hover:cursor-default hover:bg-zinc-100 focus:ring-0 active:bg-zinc-100 active:ring-1 active:ring-zinc-300 disabled:cursor-not-allowed", "hover:cursor-default hover:bg-zinc-100 focus:ring-0 active:bg-zinc-100 active:ring-1 active:ring-zinc-300 disabled:cursor-not-allowed",
isMCPBlock && "hover:cursor-pointer",
className, className,
)} )}
onDragStart={handleDragStart} onDragStart={handleDragStart}
@@ -111,6 +183,14 @@ export const Block: BlockComponent = ({
<PlusIcon className="h-5 w-5 text-zinc-50" /> <PlusIcon className="h-5 w-5 text-zinc-50" />
</div> </div>
</Button> </Button>
{isMCPBlock && (
<MCPToolDialog
open={mcpDialogOpen}
onClose={() => setMcpDialogOpen(false)}
onConfirm={handleMCPToolConfirm}
/>
)}
</>
); );
}; };

View File

@@ -1,4 +1,4 @@
import { debounce } from "lodash"; import debounce from "lodash/debounce";
import { useCallback, useEffect, useRef, useState } from "react"; import { useCallback, useEffect, useRef, useState } from "react";
import { useBlockMenuStore } from "../../../../stores/blockMenuStore"; import { useBlockMenuStore } from "../../../../stores/blockMenuStore";
import { getQueryClient } from "@/lib/react-query/queryClient"; import { getQueryClient } from "@/lib/react-query/queryClient";

View File

@@ -70,10 +70,10 @@ export const HorizontalScroll: React.FC<HorizontalScrollAreaProps> = ({
{children} {children}
</div> </div>
{canScrollLeft && ( {canScrollLeft && (
<div className="pointer-events-none absolute inset-y-0 left-0 w-8 bg-gradient-to-r from-white via-white/80 to-white/0" /> <div className="pointer-events-none absolute inset-y-0 left-0 w-8 bg-gradient-to-r from-background via-background/80 to-background/0" />
)} )}
{canScrollRight && ( {canScrollRight && (
<div className="pointer-events-none absolute inset-y-0 right-0 w-8 bg-gradient-to-l from-white via-white/80 to-white/0" /> <div className="pointer-events-none absolute inset-y-0 right-0 w-8 bg-gradient-to-l from-background via-background/80 to-background/0" />
)} )}
{canScrollLeft && ( {canScrollLeft && (
<button <button

View File

@@ -29,6 +29,10 @@ import {
TooltipTrigger, TooltipTrigger,
} from "@/components/atoms/Tooltip/BaseTooltip"; } from "@/components/atoms/Tooltip/BaseTooltip";
import { GraphMeta } from "@/lib/autogpt-server-api"; import { GraphMeta } from "@/lib/autogpt-server-api";
import {
MCPToolDialog,
type MCPToolDialogResult,
} from "@/app/(platform)/build/components/legacy-builder/MCPToolDialog";
import jaro from "jaro-winkler"; import jaro from "jaro-winkler";
import { getV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs"; import { getV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { okData } from "@/app/api/helpers"; import { okData } from "@/app/api/helpers";
@@ -94,6 +98,7 @@ export function BlocksControl({
const [searchQuery, setSearchQuery] = useState(""); const [searchQuery, setSearchQuery] = useState("");
const deferredSearchQuery = useDeferredValue(searchQuery); const deferredSearchQuery = useDeferredValue(searchQuery);
const [selectedCategory, setSelectedCategory] = useState<string | null>(null); const [selectedCategory, setSelectedCategory] = useState<string | null>(null);
const [mcpDialogOpen, setMcpDialogOpen] = useState(false);
const blocks = useSearchableBlocks(_blocks); const blocks = useSearchableBlocks(_blocks);
@@ -186,11 +191,32 @@ export function BlocksControl({
setSelectedCategory(null); setSelectedCategory(null);
}, []); }, []);
const handleMCPToolConfirm = useCallback(
(result: MCPToolDialogResult) => {
addBlock(SpecialBlockID.MCP_TOOL, "MCPToolBlock", {
server_url: result.serverUrl,
server_name: result.serverName,
selected_tool: result.selectedTool,
tool_input_schema: result.toolInputSchema,
available_tools: result.availableTools,
credentials: result.credentials ?? undefined,
});
setMcpDialogOpen(false);
},
[addBlock],
);
// Handler to add a block, fetching graph data on-demand for agent blocks // Handler to add a block, fetching graph data on-demand for agent blocks
const handleAddBlock = useCallback( const handleAddBlock = useCallback(
async (block: _Block & { notAvailable: string | null }) => { async (block: _Block & { notAvailable: string | null }) => {
if (block.notAvailable) return; if (block.notAvailable) return;
// For MCP blocks, open the configuration dialog instead of placing directly
if (block.id === SpecialBlockID.MCP_TOOL) {
setMcpDialogOpen(true);
return;
}
// For agent blocks, fetch the full graph to get schemas // For agent blocks, fetch the full graph to get schemas
if (block.uiType === BlockUIType.AGENT && block.hardcodedValues) { if (block.uiType === BlockUIType.AGENT && block.hardcodedValues) {
const graphID = block.hardcodedValues.graph_id as string; const graphID = block.hardcodedValues.graph_id as string;
@@ -230,6 +256,7 @@ export function BlocksControl({
}, [blocks]); }, [blocks]);
return ( return (
<>
<Popover <Popover
open={pinBlocksPopover ? true : undefined} open={pinBlocksPopover ? true : undefined}
onOpenChange={(open) => open || resetFilters()} onOpenChange={(open) => open || resetFilters()}
@@ -322,12 +349,21 @@ export function BlocksControl({
className={`m-2 my-4 flex h-20 shadow-none dark:border-slate-700 dark:bg-slate-800 dark:text-slate-100 dark:hover:bg-slate-700 ${ className={`m-2 my-4 flex h-20 shadow-none dark:border-slate-700 dark:bg-slate-800 dark:text-slate-100 dark:hover:bg-slate-700 ${
block.notAvailable block.notAvailable
? "cursor-not-allowed opacity-50" ? "cursor-not-allowed opacity-50"
: block.id === SpecialBlockID.MCP_TOOL
? "cursor-pointer hover:shadow-lg"
: "cursor-move hover:shadow-lg" : "cursor-move hover:shadow-lg"
}`} }`}
data-id={`block-card-${block.id}`} data-id={`block-card-${block.id}`}
draggable={!block.notAvailable} draggable={
!block.notAvailable &&
block.id !== SpecialBlockID.MCP_TOOL
}
onDragStart={(e) => { onDragStart={(e) => {
if (block.notAvailable) return; if (
block.notAvailable ||
block.id === SpecialBlockID.MCP_TOOL
)
return;
e.dataTransfer.effectAllowed = "copy"; e.dataTransfer.effectAllowed = "copy";
e.dataTransfer.setData( e.dataTransfer.setData(
"application/reactflow", "application/reactflow",
@@ -386,6 +422,13 @@ export function BlocksControl({
</Card> </Card>
</PopoverContent> </PopoverContent>
</Popover> </Popover>
<MCPToolDialog
open={mcpDialogOpen}
onClose={() => setMcpDialogOpen(false)}
onConfirm={handleMCPToolConfirm}
/>
</>
); );
} }

View File

@@ -21,6 +21,7 @@ import {
GraphInputSchema, GraphInputSchema,
GraphOutputSchema, GraphOutputSchema,
NodeExecutionResult, NodeExecutionResult,
SpecialBlockID,
} from "@/lib/autogpt-server-api"; } from "@/lib/autogpt-server-api";
import { import {
beautifyString, beautifyString,
@@ -215,6 +216,26 @@ export const CustomNode = React.memo(
} }
} }
// MCP Tool block: display the selected tool's dynamic schema
const isMCPWithTool =
data.block_id === SpecialBlockID.MCP_TOOL &&
!!data.hardcodedValues?.tool_input_schema?.properties;
if (isMCPWithTool) {
// Show only the tool's input parameters. Credentials are NOT included
// because authentication is handled by the MCP dialog's OAuth flow
// and stored server-side.
const toolSchema = data.hardcodedValues.tool_input_schema;
data.inputSchema = {
type: "object",
properties: {
...(toolSchema.properties ?? {}),
},
required: [...(toolSchema.required ?? [])],
} as BlockIORootSchema;
}
const setHardcodedValues = useCallback( const setHardcodedValues = useCallback(
(values: any) => { (values: any) => {
updateNodeData(id, { hardcodedValues: values }); updateNodeData(id, { hardcodedValues: values });
@@ -375,7 +396,9 @@ export const CustomNode = React.memo(
const displayTitle = const displayTitle =
customTitle || customTitle ||
beautifyString(data.blockType?.replace(/Block$/, "") || data.title); (isMCPWithTool
? `${data.hardcodedValues.server_name || "MCP"}: ${beautifyString(data.hardcodedValues.selected_tool || "")}`
: beautifyString(data.blockType?.replace(/Block$/, "") || data.title));
useEffect(() => { useEffect(() => {
isInitialSetup.current = false; isInitialSetup.current = false;
@@ -389,6 +412,15 @@ export const CustomNode = React.memo(
data.inputSchema, data.inputSchema,
), ),
}); });
} else if (isMCPWithTool) {
// MCP dialog already configured server_url, selected_tool, etc.
// Just ensure tool_arguments is initialized.
if (!data.hardcodedValues.tool_arguments) {
setHardcodedValues({
...data.hardcodedValues,
tool_arguments: {},
});
}
} else { } else {
setHardcodedValues( setHardcodedValues(
fillObjectDefaultsFromSchema(data.hardcodedValues, data.inputSchema), fillObjectDefaultsFromSchema(data.hardcodedValues, data.inputSchema),
@@ -525,8 +557,11 @@ export const CustomNode = React.memo(
); );
default: default:
const getInputPropKey = (key: string) => const getInputPropKey = (key: string) => {
nodeType == BlockUIType.AGENT ? `inputs.${key}` : key; if (nodeType == BlockUIType.AGENT) return `inputs.${key}`;
if (isMCPWithTool) return `tool_arguments.${key}`;
return key;
};
return keys.map(([propKey, propSchema]) => { return keys.map(([propKey, propSchema]) => {
const isRequired = data.inputSchema.required?.includes(propKey); const isRequired = data.inputSchema.required?.includes(propKey);

View File

@@ -42,7 +42,11 @@ import { getV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/
import { okData } from "@/app/api/helpers"; import { okData } from "@/app/api/helpers";
import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types"; import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
import { Key, storage } from "@/services/storage/local-storage"; import { Key, storage } from "@/services/storage/local-storage";
import { findNewlyAddedBlockCoordinates, getTypeColor } from "@/lib/utils"; import {
beautifyString,
findNewlyAddedBlockCoordinates,
getTypeColor,
} from "@/lib/utils";
import { history } from "../history"; import { history } from "../history";
import { CustomEdge } from "../CustomEdge/CustomEdge"; import { CustomEdge } from "../CustomEdge/CustomEdge";
import ConnectionLine from "../ConnectionLine"; import ConnectionLine from "../ConnectionLine";
@@ -748,6 +752,27 @@ const FlowEditor: React.FC<{
block_id: blockID, block_id: blockID,
isOutputStatic: nodeSchema.staticOutput, isOutputStatic: nodeSchema.staticOutput,
uiType: nodeSchema.uiType, uiType: nodeSchema.uiType,
// Set customized_name at creation so it persists through save/load
...(blockID === SpecialBlockID.MCP_TOOL && {
metadata: {
credentials_optional: true,
...(finalHardcodedValues.selected_tool && {
customized_name: `${
finalHardcodedValues.server_name ||
(() => {
try {
return new URL(finalHardcodedValues.server_url).hostname;
} catch {
return "MCP";
}
})()
}: ${beautifyString(finalHardcodedValues.selected_tool)}`,
}),
},
}),
...(blockID === SpecialBlockID.AGENT && {
metadata: { customized_name: blockName },
}),
}, },
}; };
@@ -877,8 +902,6 @@ const FlowEditor: React.FC<{
return ( return (
node.data.metadata?.customized_name || node.data.metadata?.customized_name ||
(node.data.uiType == BlockUIType.AGENT &&
node.data.hardcodedValues.agent_name) ||
node.data.blockType.replace(/Block$/, "") node.data.blockType.replace(/Block$/, "")
); );
}, },

View File

@@ -0,0 +1,534 @@
"use client";
import React, {
useState,
useCallback,
useRef,
useEffect,
useContext,
} from "react";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from "@/components/__legacy__/ui/dialog";
import { Button } from "@/components/__legacy__/ui/button";
import { Input } from "@/components/__legacy__/ui/input";
import { Label } from "@/components/__legacy__/ui/label";
import { LoadingSpinner } from "@/components/__legacy__/ui/loading";
import { Badge } from "@/components/__legacy__/ui/badge";
import { ScrollArea } from "@/components/__legacy__/ui/scroll-area";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import type { CredentialsMetaInput, MCPTool } from "@/lib/autogpt-server-api";
import { CaretDown } from "@phosphor-icons/react";
import { openOAuthPopup } from "@/lib/oauth-popup";
import { CredentialsProvidersContext } from "@/providers/agent-credentials/credentials-provider";
export type MCPToolDialogResult = {
serverUrl: string;
serverName: string | null;
selectedTool: string;
toolInputSchema: Record<string, any>;
availableTools: Record<string, any>;
/** Credentials meta from OAuth flow, null for public servers. */
credentials: CredentialsMetaInput | null;
};
interface MCPToolDialogProps {
open: boolean;
onClose: () => void;
onConfirm: (result: MCPToolDialogResult) => void;
}
type DialogStep = "url" | "tool";
export function MCPToolDialog({
open,
onClose,
onConfirm,
}: MCPToolDialogProps) {
const api = useBackendAPI();
const allProviders = useContext(CredentialsProvidersContext);
const [step, setStep] = useState<DialogStep>("url");
const [serverUrl, setServerUrl] = useState("");
const [tools, setTools] = useState<MCPTool[]>([]);
const [serverName, setServerName] = useState<string | null>(null);
const [loading, setLoading] = useState(false);
const [error, setError] = useState<string | null>(null);
const [authRequired, setAuthRequired] = useState(false);
const [oauthLoading, setOauthLoading] = useState(false);
const [showManualToken, setShowManualToken] = useState(false);
const [manualToken, setManualToken] = useState("");
const [selectedTool, setSelectedTool] = useState<MCPTool | null>(null);
const [credentials, setCredentials] = useState<CredentialsMetaInput | null>(
null,
);
const startOAuthRef = useRef(false);
const oauthAbortRef = useRef<((reason?: string) => void) | null>(null);
// Clean up on unmount
useEffect(() => {
return () => {
oauthAbortRef.current?.();
};
}, []);
const reset = useCallback(() => {
oauthAbortRef.current?.();
oauthAbortRef.current = null;
setStep("url");
setServerUrl("");
setManualToken("");
setTools([]);
setServerName(null);
setLoading(false);
setError(null);
setAuthRequired(false);
setOauthLoading(false);
setShowManualToken(false);
setSelectedTool(null);
setCredentials(null);
}, []);
const handleClose = useCallback(() => {
reset();
onClose();
}, [reset, onClose]);
const discoverTools = useCallback(
async (url: string, authToken?: string) => {
setLoading(true);
setError(null);
try {
const result = await api.mcpDiscoverTools(url, authToken);
setTools(result.tools);
setServerName(result.server_name);
setAuthRequired(false);
setShowManualToken(false);
setStep("tool");
} catch (e: any) {
if (e?.status === 401 || e?.status === 403) {
setAuthRequired(true);
setError(null);
// Automatically start OAuth sign-in instead of requiring a second click
setLoading(false);
startOAuthRef.current = true;
return;
} else {
const message =
e?.message || e?.detail || "Failed to connect to MCP server";
setError(
typeof message === "string" ? message : JSON.stringify(message),
);
}
} finally {
setLoading(false);
}
},
[api],
);
const handleDiscoverTools = useCallback(() => {
if (!serverUrl.trim()) return;
discoverTools(serverUrl.trim(), manualToken.trim() || undefined);
}, [serverUrl, manualToken, discoverTools]);
const handleOAuthSignIn = useCallback(async () => {
if (!serverUrl.trim()) return;
setError(null);
// Abort any previous OAuth flow
oauthAbortRef.current?.();
setOauthLoading(true);
try {
const { login_url, state_token } = await api.mcpOAuthLogin(
serverUrl.trim(),
);
const { promise, cleanup } = openOAuthPopup(login_url, {
stateToken: state_token,
useCrossOriginListeners: true,
});
oauthAbortRef.current = cleanup.abort;
const result = await promise;
// Exchange code for tokens via the credentials provider (updates cache)
setLoading(true);
setOauthLoading(false);
const mcpProvider = allProviders?.["mcp"];
const callbackResult = mcpProvider
? await mcpProvider.mcpOAuthCallback(result.code, state_token)
: await api.mcpOAuthCallback(result.code, state_token);
setCredentials({
id: callbackResult.id,
provider: callbackResult.provider,
type: callbackResult.type,
title: callbackResult.title,
});
setAuthRequired(false);
// Discover tools now that we're authenticated
const toolsResult = await api.mcpDiscoverTools(serverUrl.trim());
setTools(toolsResult.tools);
setServerName(toolsResult.server_name);
setStep("tool");
} catch (e: any) {
// If server doesn't support OAuth → show manual token entry
if (e?.status === 400) {
setShowManualToken(true);
setError(
"This server does not support OAuth sign-in. Please enter a token manually.",
);
} else if (e?.message === "OAuth flow timed out") {
setError("OAuth sign-in timed out. Please try again.");
} else {
const status = e?.status;
let message: string;
if (status === 401 || status === 403) {
message =
"Authentication succeeded but the server still rejected the request. " +
"The token audience may not match. Please try again.";
} else {
message = e?.message || e?.detail || "Failed to complete sign-in";
}
setError(
typeof message === "string" ? message : JSON.stringify(message),
);
}
} finally {
setOauthLoading(false);
setLoading(false);
oauthAbortRef.current = null;
}
}, [api, serverUrl, allProviders]);
// Auto-start OAuth sign-in when server returns 401/403
useEffect(() => {
if (authRequired && startOAuthRef.current) {
startOAuthRef.current = false;
handleOAuthSignIn();
}
}, [authRequired, handleOAuthSignIn]);
const handleConfirm = useCallback(() => {
if (!selectedTool) return;
const availableTools: Record<string, any> = {};
for (const t of tools) {
availableTools[t.name] = {
description: t.description,
input_schema: t.input_schema,
};
}
onConfirm({
serverUrl: serverUrl.trim(),
serverName,
selectedTool: selectedTool.name,
toolInputSchema: selectedTool.input_schema,
availableTools,
credentials,
});
reset();
}, [
selectedTool,
tools,
serverUrl,
serverName,
credentials,
onConfirm,
reset,
]);
return (
<Dialog open={open} onOpenChange={(isOpen) => !isOpen && handleClose()}>
<DialogContent className="max-w-lg">
<DialogHeader>
<DialogTitle>
{step === "url"
? "Connect to MCP Server"
: `Select a Tool${serverName ? `${serverName}` : ""}`}
</DialogTitle>
<DialogDescription>
{step === "url"
? "Enter the URL of an MCP server to discover its available tools."
: `Found ${tools.length} tool${tools.length !== 1 ? "s" : ""}. Select one to add to your agent.`}
</DialogDescription>
</DialogHeader>
{step === "url" && (
<div className="flex flex-col gap-4 py-2">
<div className="flex flex-col gap-2">
<Label htmlFor="mcp-server-url">Server URL</Label>
<Input
id="mcp-server-url"
type="url"
placeholder="https://mcp.example.com/mcp"
value={serverUrl}
onChange={(e) => setServerUrl(e.target.value)}
onKeyDown={(e) => e.key === "Enter" && handleDiscoverTools()}
autoFocus
/>
</div>
{/* Auth required: show manual token option */}
{authRequired && !showManualToken && (
<button
onClick={() => setShowManualToken(true)}
className="text-xs text-gray-500 underline hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-300"
>
or enter a token manually
</button>
)}
{/* Manual token entry — only visible when expanded */}
{showManualToken && (
<div className="flex flex-col gap-2">
<Label htmlFor="mcp-auth-token" className="text-sm">
Bearer Token
</Label>
<Input
id="mcp-auth-token"
type="password"
placeholder="Paste your auth token here"
value={manualToken}
onChange={(e) => setManualToken(e.target.value)}
onKeyDown={(e) => e.key === "Enter" && handleDiscoverTools()}
autoFocus
/>
</div>
)}
{error && <p className="text-sm text-red-500">{error}</p>}
</div>
)}
{step === "tool" && (
<ScrollArea className="max-h-[50vh] py-2">
<div className="flex flex-col gap-2 pr-3">
{tools.map((tool) => (
<MCPToolCard
key={tool.name}
tool={tool}
selected={selectedTool?.name === tool.name}
onSelect={() => setSelectedTool(tool)}
/>
))}
</div>
</ScrollArea>
)}
<DialogFooter>
{step === "tool" && (
<Button
variant="outline"
onClick={() => {
setStep("url");
setSelectedTool(null);
}}
>
Back
</Button>
)}
<Button variant="outline" onClick={handleClose}>
Cancel
</Button>
{step === "url" && (
<Button
onClick={
authRequired && !showManualToken
? handleOAuthSignIn
: handleDiscoverTools
}
disabled={!serverUrl.trim() || loading || oauthLoading}
>
{loading || oauthLoading ? (
<span className="flex items-center gap-2">
<LoadingSpinner className="size-4" />
{oauthLoading ? "Waiting for sign-in..." : "Connecting..."}
</span>
) : authRequired && !showManualToken ? (
"Sign in & Connect"
) : (
"Discover Tools"
)}
</Button>
)}
{step === "tool" && (
<Button onClick={handleConfirm} disabled={!selectedTool}>
Add Block
</Button>
)}
</DialogFooter>
</DialogContent>
</Dialog>
);
}
// --------------- Tool Card Component --------------- //
/** Truncate a description to a reasonable length for the collapsed view. */
function truncateDescription(text: string, maxLen = 120): string {
if (text.length <= maxLen) return text;
return text.slice(0, maxLen).trimEnd() + "…";
}
/** Pretty-print a JSON Schema type for a parameter. */
function schemaTypeLabel(schema: Record<string, any>): string {
if (schema.type) return schema.type;
if (schema.anyOf)
return schema.anyOf.map((s: any) => s.type ?? "any").join(" | ");
if (schema.oneOf)
return schema.oneOf.map((s: any) => s.type ?? "any").join(" | ");
return "any";
}
function MCPToolCard({
tool,
selected,
onSelect,
}: {
tool: MCPTool;
selected: boolean;
onSelect: () => void;
}) {
const [expanded, setExpanded] = useState(false);
const properties = tool.input_schema?.properties ?? {};
const required = new Set<string>(tool.input_schema?.required ?? []);
const paramNames = Object.keys(properties);
// Strip XML-like tags from description for cleaner display.
// Loop to handle nested tags like <scr<script>ipt> (CodeQL fix).
let cleanDescription = tool.description ?? "";
let prev = "";
while (prev !== cleanDescription) {
prev = cleanDescription;
cleanDescription = cleanDescription.replace(/<[^>]*>/g, "");
}
cleanDescription = cleanDescription.trim();
return (
<button
onClick={onSelect}
className={`group flex flex-col rounded-lg border text-left transition-colors ${
selected
? "border-blue-500 bg-blue-50 dark:border-blue-400 dark:bg-blue-950"
: "border-gray-200 hover:border-gray-300 hover:bg-gray-50 dark:border-slate-700 dark:hover:border-slate-600 dark:hover:bg-slate-800"
}`}
>
{/* Header */}
<div className="flex items-center gap-2 px-3 pb-1 pt-3">
<span className="flex-1 text-sm font-semibold dark:text-white">
{tool.name}
</span>
{paramNames.length > 0 && (
<Badge variant="secondary" className="text-[10px]">
{paramNames.length} param{paramNames.length !== 1 ? "s" : ""}
</Badge>
)}
</div>
{/* Description (collapsed: truncated) */}
{cleanDescription && (
<p className="px-3 pb-1 text-xs leading-relaxed text-gray-500 dark:text-gray-400">
{expanded ? cleanDescription : truncateDescription(cleanDescription)}
</p>
)}
{/* Parameter badges (collapsed view) */}
{!expanded && paramNames.length > 0 && (
<div className="flex flex-wrap gap-1 px-3 pb-2">
{paramNames.slice(0, 6).map((name) => (
<Badge
key={name}
variant="outline"
className="text-[10px] font-normal"
>
{name}
{required.has(name) && (
<span className="ml-0.5 text-red-400">*</span>
)}
</Badge>
))}
{paramNames.length > 6 && (
<Badge variant="outline" className="text-[10px] font-normal">
+{paramNames.length - 6} more
</Badge>
)}
</div>
)}
{/* Expanded: full parameter details */}
{expanded && paramNames.length > 0 && (
<div className="mx-3 mb-2 rounded border border-gray-100 bg-gray-50/50 dark:border-slate-700 dark:bg-slate-800/50">
<table className="w-full text-xs">
<thead>
<tr className="border-b border-gray-100 dark:border-slate-700">
<th className="px-2 py-1 text-left font-medium text-gray-500 dark:text-gray-400">
Parameter
</th>
<th className="px-2 py-1 text-left font-medium text-gray-500 dark:text-gray-400">
Type
</th>
<th className="px-2 py-1 text-left font-medium text-gray-500 dark:text-gray-400">
Description
</th>
</tr>
</thead>
<tbody>
{paramNames.map((name) => {
const prop = properties[name] ?? {};
return (
<tr
key={name}
className="border-b border-gray-50 last:border-0 dark:border-slate-700/50"
>
<td className="px-2 py-1 font-mono text-[11px] text-gray-700 dark:text-gray-300">
{name}
{required.has(name) && (
<span className="ml-0.5 text-red-400">*</span>
)}
</td>
<td className="px-2 py-1 text-gray-500 dark:text-gray-400">
{schemaTypeLabel(prop)}
</td>
<td className="max-w-[200px] truncate px-2 py-1 text-gray-500 dark:text-gray-400">
{prop.description ?? "—"}
</td>
</tr>
);
})}
</tbody>
</table>
</div>
)}
{/* Toggle details */}
{(paramNames.length > 0 || cleanDescription.length > 120) && (
<button
type="button"
onClick={(e) => {
e.stopPropagation();
setExpanded((prev) => !prev);
}}
className="flex w-full items-center justify-center gap-1 border-t border-gray-100 py-1.5 text-[10px] text-gray-400 hover:text-gray-600 dark:border-slate-700 dark:text-gray-500 dark:hover:text-gray-300"
>
{expanded ? "Hide details" : "Show details"}
<CaretDown
className={`h-3 w-3 transition-transform ${expanded ? "rotate-180" : ""}`}
/>
</button>
)}
</button>
);
}

View File

@@ -0,0 +1,76 @@
"use client";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { SidebarProvider } from "@/components/ui/sidebar";
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
import { useCopilotPage } from "./useCopilotPage";
export function CopilotPage() {
const {
sessionId,
messages,
status,
error,
stop,
createSession,
onSend,
isLoadingSession,
isCreatingSession,
isUserLoading,
isLoggedIn,
// Mobile drawer
isMobile,
isDrawerOpen,
sessions,
isLoadingSessions,
handleOpenDrawer,
handleCloseDrawer,
handleDrawerOpenChange,
handleSelectSession,
handleNewChat,
} = useCopilotPage();
if (isUserLoading || !isLoggedIn) {
return <LoadingSpinner size="large" cover />;
}
return (
<SidebarProvider
defaultOpen={true}
className="h-[calc(100vh-72px)] min-h-0"
>
{!isMobile && <ChatSidebar />}
<div className="relative flex h-full w-full flex-col overflow-hidden bg-[#f8f8f9] px-0">
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
<div className="flex-1 overflow-hidden">
<ChatContainer
messages={messages}
status={status}
error={error}
sessionId={sessionId}
isLoadingSession={isLoadingSession}
isCreatingSession={isCreatingSession}
onCreateSession={createSession}
onSend={onSend}
onStop={stop}
/>
</div>
</div>
{isMobile && (
<MobileDrawer
isOpen={isDrawerOpen}
sessions={sessions}
currentSessionId={sessionId}
isLoading={isLoadingSessions}
onSelectSession={handleSelectSession}
onNewChat={handleNewChat}
onClose={handleCloseDrawer}
onOpenChange={handleDrawerOpenChange}
/>
)}
</SidebarProvider>
);
}

View File

@@ -0,0 +1,74 @@
"use client";
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
import { UIDataTypes, UIMessage, UITools } from "ai";
import { LayoutGroup, motion } from "framer-motion";
import { ChatMessagesContainer } from "../ChatMessagesContainer/ChatMessagesContainer";
import { CopilotChatActionsProvider } from "../CopilotChatActionsProvider/CopilotChatActionsProvider";
import { EmptySession } from "../EmptySession/EmptySession";
export interface ChatContainerProps {
messages: UIMessage<unknown, UIDataTypes, UITools>[];
status: string;
error: Error | undefined;
sessionId: string | null;
isLoadingSession: boolean;
isCreatingSession: boolean;
onCreateSession: () => void | Promise<string>;
onSend: (message: string) => void | Promise<void>;
onStop: () => void;
}
export const ChatContainer = ({
messages,
status,
error,
sessionId,
isLoadingSession,
isCreatingSession,
onCreateSession,
onSend,
onStop,
}: ChatContainerProps) => {
const inputLayoutId = "copilot-2-chat-input";
return (
<CopilotChatActionsProvider onSend={onSend}>
<LayoutGroup id="copilot-2-chat-layout">
<div className="flex h-full min-h-0 w-full flex-col bg-[#f8f8f9] px-2 lg:px-0">
{sessionId ? (
<div className="mx-auto flex h-full min-h-0 w-full max-w-3xl flex-col">
<ChatMessagesContainer
messages={messages}
status={status}
error={error}
isLoading={isLoadingSession}
/>
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.3 }}
className="relative px-3 pb-2 pt-2"
>
<div className="pointer-events-none absolute left-0 right-0 top-[-18px] z-10 h-6 bg-gradient-to-b from-transparent to-[#f8f8f9]" />
<ChatInput
inputId="chat-input-session"
onSend={onSend}
disabled={status === "streaming"}
isStreaming={status === "streaming"}
onStop={onStop}
placeholder="What else can I help with?"
/>
</motion.div>
</div>
) : (
<EmptySession
inputLayoutId={inputLayoutId}
isCreatingSession={isCreatingSession}
onCreateSession={onCreateSession}
onSend={onSend}
/>
)}
</div>
</LayoutGroup>
</CopilotChatActionsProvider>
);
};

View File

@@ -6,17 +6,19 @@ import {
MicrophoneIcon, MicrophoneIcon,
StopIcon, StopIcon,
} from "@phosphor-icons/react"; } from "@phosphor-icons/react";
import { ChangeEvent, useCallback } from "react";
import { RecordingIndicator } from "./components/RecordingIndicator"; import { RecordingIndicator } from "./components/RecordingIndicator";
import { useChatInput } from "./useChatInput"; import { useChatInput } from "./useChatInput";
import { useVoiceRecording } from "./useVoiceRecording"; import { useVoiceRecording } from "./useVoiceRecording";
export interface Props { export interface Props {
onSend: (message: string) => void; onSend: (message: string) => void | Promise<void>;
disabled?: boolean; disabled?: boolean;
isStreaming?: boolean; isStreaming?: boolean;
onStop?: () => void; onStop?: () => void;
placeholder?: string; placeholder?: string;
className?: string; className?: string;
inputId?: string;
} }
export function ChatInput({ export function ChatInput({
@@ -26,14 +28,14 @@ export function ChatInput({
onStop, onStop,
placeholder = "Type your message...", placeholder = "Type your message...",
className, className,
inputId = "chat-input",
}: Props) { }: Props) {
const inputId = "chat-input";
const { const {
value, value,
setValue, setValue,
handleKeyDown: baseHandleKeyDown, handleKeyDown: baseHandleKeyDown,
handleSubmit, handleSubmit,
handleChange, handleChange: baseHandleChange,
hasMultipleLines, hasMultipleLines,
} = useChatInput({ } = useChatInput({
onSend, onSend,
@@ -60,6 +62,15 @@ export function ChatInput({
inputId, inputId,
}); });
// Block text changes when recording
const handleChange = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
if (isRecording) return;
baseHandleChange(e);
},
[isRecording, baseHandleChange],
);
return ( return (
<form onSubmit={handleSubmit} className={cn("relative flex-1", className)}> <form onSubmit={handleSubmit} className={cn("relative flex-1", className)}>
<div className="relative"> <div className="relative">

View File

@@ -21,6 +21,7 @@ export function useChatInput({
}: Args) { }: Args) {
const [value, setValue] = useState(""); const [value, setValue] = useState("");
const [hasMultipleLines, setHasMultipleLines] = useState(false); const [hasMultipleLines, setHasMultipleLines] = useState(false);
const [isSending, setIsSending] = useState(false);
useEffect( useEffect(
function focusOnMount() { function focusOnMount() {
@@ -100,9 +101,12 @@ export function useChatInput({
} }
}, [value, maxRows, inputId]); }, [value, maxRows, inputId]);
const handleSend = () => { async function handleSend() {
if (disabled || !value.trim()) return; if (disabled || isSending || !value.trim()) return;
onSend(value.trim());
setIsSending(true);
try {
await onSend(value.trim());
setValue(""); setValue("");
setHasMultipleLines(false); setHasMultipleLines(false);
const textarea = document.getElementById(inputId) as HTMLTextAreaElement; const textarea = document.getElementById(inputId) as HTMLTextAreaElement;
@@ -116,18 +120,21 @@ export function useChatInput({
wrapper.style.height = ""; wrapper.style.height = "";
wrapper.style.maxHeight = ""; wrapper.style.maxHeight = "";
} }
}; } finally {
setIsSending(false);
}
}
function handleKeyDown(event: KeyboardEvent<HTMLTextAreaElement>) { function handleKeyDown(event: KeyboardEvent<HTMLTextAreaElement>) {
if (event.key === "Enter" && !event.shiftKey) { if (event.key === "Enter" && !event.shiftKey) {
event.preventDefault(); event.preventDefault();
handleSend(); void handleSend();
} }
} }
function handleSubmit(e: FormEvent<HTMLFormElement>) { function handleSubmit(e: FormEvent<HTMLFormElement>) {
e.preventDefault(); e.preventDefault();
handleSend(); void handleSend();
} }
function handleChange(e: ChangeEvent<HTMLTextAreaElement>) { function handleChange(e: ChangeEvent<HTMLTextAreaElement>) {
@@ -142,5 +149,6 @@ export function useChatInput({
handleSubmit, handleSubmit,
handleChange, handleChange,
hasMultipleLines, hasMultipleLines,
isSending,
}; };
} }

View File

@@ -38,9 +38,13 @@ export function useVoiceRecording({
const streamRef = useRef<MediaStream | null>(null); const streamRef = useRef<MediaStream | null>(null);
const isRecordingRef = useRef(false); const isRecordingRef = useRef(false);
const isSupported = const [isSupported, setIsSupported] = useState(false);
typeof window !== "undefined" &&
!!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia); useEffect(() => {
setIsSupported(
!!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia),
);
}, []);
const clearTimer = useCallback(() => { const clearTimer = useCallback(() => {
if (timerRef.current) { if (timerRef.current) {
@@ -214,17 +218,33 @@ export function useVoiceRecording({
const handleKeyDown = useCallback( const handleKeyDown = useCallback(
(event: KeyboardEvent<HTMLTextAreaElement>) => { (event: KeyboardEvent<HTMLTextAreaElement>) => {
if (event.key === " " && !value.trim() && !isTranscribing) { // Allow space to toggle recording (start when empty, stop when recording)
if (event.key === " " && !isTranscribing) {
if (isRecordingRef.current) {
// Stop recording on space
event.preventDefault();
stopRecording();
return;
} else if (!value.trim()) {
// Start recording on space when input is empty
event.preventDefault();
void startRecording();
return;
}
}
// Block all key events when recording (except space handled above)
if (isRecordingRef.current) {
event.preventDefault(); event.preventDefault();
toggleRecording();
return; return;
} }
baseHandleKeyDown(event); baseHandleKeyDown(event);
}, },
[value, isTranscribing, toggleRecording, baseHandleKeyDown], [value, isTranscribing, stopRecording, startRecording, baseHandleKeyDown],
); );
const showMicButton = isSupported; const showMicButton = isSupported;
// Don't include isRecording in disabled state - we need key events to work
// Text input is blocked via handleKeyDown instead
const isInputDisabled = disabled || isStreaming || isTranscribing; const isInputDisabled = disabled || isStreaming || isTranscribing;
// Cleanup on unmount // Cleanup on unmount

View File

@@ -0,0 +1,274 @@
import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
import {
Conversation,
ConversationContent,
ConversationScrollButton,
} from "@/components/ai-elements/conversation";
import {
Message,
MessageContent,
MessageResponse,
} from "@/components/ai-elements/message";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
import { useEffect, useState } from "react";
import { CreateAgentTool } from "../../tools/CreateAgent/CreateAgent";
import { EditAgentTool } from "../../tools/EditAgent/EditAgent";
import { FindAgentsTool } from "../../tools/FindAgents/FindAgents";
import { FindBlocksTool } from "../../tools/FindBlocks/FindBlocks";
import { RunAgentTool } from "../../tools/RunAgent/RunAgent";
import { RunBlockTool } from "../../tools/RunBlock/RunBlock";
import { SearchDocsTool } from "../../tools/SearchDocs/SearchDocs";
import { ViewAgentOutputTool } from "../../tools/ViewAgentOutput/ViewAgentOutput";
// ---------------------------------------------------------------------------
// Workspace media support
// ---------------------------------------------------------------------------
/**
* Resolve workspace:// URLs in markdown text to proxy download URLs.
* Detects MIME type from the hash fragment (e.g. workspace://id#video/mp4)
* and prefixes the alt text with "video:" so the custom img component can
* render a <video> element instead.
*/
function resolveWorkspaceUrls(text: string): string {
return text.replace(
/!\[([^\]]*)\]\(workspace:\/\/([^)#\s]+)(?:#([^)\s]*))?\)/g,
(_match, alt: string, fileId: string, mimeHint?: string) => {
const apiPath = getGetWorkspaceDownloadFileByIdUrl(fileId);
const url = `/api/proxy${apiPath}`;
if (mimeHint?.startsWith("video/")) {
return `![video:${alt || "Video"}](${url})`;
}
return `![${alt || "Image"}](${url})`;
},
);
}
/**
* Custom img component for Streamdown that renders <video> elements
* for workspace video files (detected via "video:" alt-text prefix).
* Falls back to <video> when an <img> fails to load for workspace files.
*/
function WorkspaceMediaImage(props: React.JSX.IntrinsicElements["img"]) {
const { src, alt, ...rest } = props;
const [imgFailed, setImgFailed] = useState(false);
const isWorkspace = src?.includes("/workspace/files/") ?? false;
if (!src) return null;
if (alt?.startsWith("video:") || (imgFailed && isWorkspace)) {
return (
<span className="my-2 inline-block">
<video
controls
className="h-auto max-w-full rounded-md border border-zinc-200"
preload="metadata"
>
<source src={src} />
Your browser does not support the video tag.
</video>
</span>
);
}
return (
// eslint-disable-next-line @next/next/no-img-element
<img
src={src}
alt={alt || "Image"}
className="h-auto max-w-full rounded-md border border-zinc-200"
loading="lazy"
onError={() => {
if (isWorkspace) setImgFailed(true);
}}
{...rest}
/>
);
}
/** Stable components override for Streamdown (avoids re-creating on every render). */
const STREAMDOWN_COMPONENTS = { img: WorkspaceMediaImage };
const THINKING_PHRASES = [
"Thinking...",
"Considering this...",
"Working through this...",
"Analyzing your request...",
"Reasoning...",
"Looking into it...",
"Processing your request...",
"Mulling this over...",
"Piecing it together...",
"On it...",
];
function getRandomPhrase() {
return THINKING_PHRASES[Math.floor(Math.random() * THINKING_PHRASES.length)];
}
interface ChatMessagesContainerProps {
messages: UIMessage<unknown, UIDataTypes, UITools>[];
status: string;
error: Error | undefined;
isLoading: boolean;
}
export const ChatMessagesContainer = ({
messages,
status,
error,
isLoading,
}: ChatMessagesContainerProps) => {
const [thinkingPhrase, setThinkingPhrase] = useState(getRandomPhrase);
useEffect(() => {
if (status === "submitted") {
setThinkingPhrase(getRandomPhrase());
}
}, [status]);
const lastMessage = messages[messages.length - 1];
const lastAssistantHasVisibleContent =
lastMessage?.role === "assistant" &&
lastMessage.parts.some(
(p) =>
(p.type === "text" && p.text.trim().length > 0) ||
p.type.startsWith("tool-"),
);
const showThinking =
status === "submitted" ||
(status === "streaming" && !lastAssistantHasVisibleContent);
return (
<Conversation className="min-h-0 flex-1">
<ConversationContent className="gap-6 px-3 py-6">
{isLoading && messages.length === 0 && (
<div className="flex flex-1 items-center justify-center">
<LoadingSpinner size="large" className="text-neutral-400" />
</div>
)}
{messages.map((message, messageIndex) => {
const isLastAssistant =
messageIndex === messages.length - 1 &&
message.role === "assistant";
const messageHasVisibleContent = message.parts.some(
(p) =>
(p.type === "text" && p.text.trim().length > 0) ||
p.type.startsWith("tool-"),
);
return (
<Message from={message.role} key={message.id}>
<MessageContent
className={
"text-[1rem] leading-relaxed " +
"group-[.is-user]:rounded-xl group-[.is-user]:bg-purple-100 group-[.is-user]:px-3 group-[.is-user]:py-2.5 group-[.is-user]:text-slate-900 group-[.is-user]:[border-bottom-right-radius:0] " +
"group-[.is-assistant]:bg-transparent group-[.is-assistant]:text-slate-900"
}
>
{message.parts.map((part, i) => {
switch (part.type) {
case "text":
return (
<MessageResponse
key={`${message.id}-${i}`}
components={STREAMDOWN_COMPONENTS}
>
{resolveWorkspaceUrls(part.text)}
</MessageResponse>
);
case "tool-find_block":
return (
<FindBlocksTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-find_agent":
case "tool-find_library_agent":
return (
<FindAgentsTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-search_docs":
case "tool-get_doc_page":
return (
<SearchDocsTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-run_block":
return (
<RunBlockTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-run_agent":
case "tool-schedule_agent":
return (
<RunAgentTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-create_agent":
return (
<CreateAgentTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-edit_agent":
return (
<EditAgentTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-view_agent_output":
return (
<ViewAgentOutputTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
default:
return null;
}
})}
{isLastAssistant &&
!messageHasVisibleContent &&
showThinking && (
<span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
{thinkingPhrase}
</span>
)}
</MessageContent>
</Message>
);
})}
{showThinking && lastMessage?.role !== "assistant" && (
<Message from="assistant">
<MessageContent className="text-[1rem] leading-relaxed">
<span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
{thinkingPhrase}
</span>
</MessageContent>
</Message>
)}
{error && (
<div className="rounded-lg bg-red-50 p-3 text-red-600">
Error: {error.message}
</div>
)}
</ConversationContent>
<ConversationScrollButton />
</Conversation>
);
};

View File

@@ -0,0 +1,188 @@
"use client";
import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
import { Button } from "@/components/atoms/Button/Button";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { Text } from "@/components/atoms/Text/Text";
import {
Sidebar,
SidebarContent,
SidebarFooter,
SidebarHeader,
SidebarTrigger,
useSidebar,
} from "@/components/ui/sidebar";
import { cn } from "@/lib/utils";
import { PlusCircleIcon, PlusIcon } from "@phosphor-icons/react";
import { motion } from "framer-motion";
import { parseAsString, useQueryState } from "nuqs";
export function ChatSidebar() {
const { state } = useSidebar();
const isCollapsed = state === "collapsed";
const [sessionId, setSessionId] = useQueryState("sessionId", parseAsString);
const { data: sessionsResponse, isLoading: isLoadingSessions } =
useGetV2ListSessions({ limit: 50 });
const sessions =
sessionsResponse?.status === 200 ? sessionsResponse.data.sessions : [];
function handleNewChat() {
setSessionId(null);
}
function handleSelectSession(id: string) {
setSessionId(id);
}
function formatDate(dateString: string) {
const date = new Date(dateString);
const now = new Date();
const diffMs = now.getTime() - date.getTime();
const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24));
if (diffDays === 0) return "Today";
if (diffDays === 1) return "Yesterday";
if (diffDays < 7) return `${diffDays} days ago`;
const day = date.getDate();
const ordinal =
day % 10 === 1 && day !== 11
? "st"
: day % 10 === 2 && day !== 12
? "nd"
: day % 10 === 3 && day !== 13
? "rd"
: "th";
const month = date.toLocaleDateString("en-US", { month: "short" });
const year = date.getFullYear();
return `${day}${ordinal} ${month} ${year}`;
}
return (
<Sidebar
variant="inset"
collapsible="icon"
className="!top-[50px] !h-[calc(100vh-50px)] border-r border-zinc-100 px-0"
>
{isCollapsed && (
<SidebarHeader
className={cn(
"flex",
isCollapsed
? "flex-row items-center justify-between gap-y-4 md:flex-col md:items-start md:justify-start"
: "flex-row items-center justify-between",
)}
>
<motion.div
key={isCollapsed ? "header-collapsed" : "header-expanded"}
className="flex flex-col items-center gap-3 pt-4"
initial={{ opacity: 0, filter: "blur(3px)" }}
animate={{ opacity: 1, filter: "blur(0px)" }}
transition={{ type: "spring", bounce: 0.2 }}
>
<div className="flex flex-col items-center gap-2">
<SidebarTrigger />
<Button
variant="ghost"
onClick={handleNewChat}
style={{ minWidth: "auto", width: "auto" }}
>
<PlusCircleIcon className="!size-5" />
<span className="sr-only">New Chat</span>
</Button>
</div>
</motion.div>
</SidebarHeader>
)}
<SidebarContent className="gap-4 overflow-y-auto px-4 py-4 [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
{!isCollapsed && (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.1 }}
className="flex items-center justify-between px-3"
>
<Text variant="h3" size="body-medium">
Your chats
</Text>
<div className="relative left-6">
<SidebarTrigger />
</div>
</motion.div>
)}
{!isCollapsed && (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.15 }}
className="mt-4 flex flex-col gap-1"
>
{isLoadingSessions ? (
<div className="flex items-center justify-center py-4">
<LoadingSpinner size="small" className="text-neutral-400" />
</div>
) : sessions.length === 0 ? (
<p className="py-4 text-center text-sm text-neutral-500">
No conversations yet
</p>
) : (
sessions.map((session) => (
<button
key={session.id}
onClick={() => handleSelectSession(session.id)}
className={cn(
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
session.id === sessionId
? "bg-zinc-100"
: "hover:bg-zinc-50",
)}
>
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
<div className="min-w-0 max-w-full">
<Text
variant="body"
className={cn(
"truncate font-normal",
session.id === sessionId
? "text-zinc-600"
: "text-zinc-800",
)}
>
{session.title || `Untitled chat`}
</Text>
</div>
<Text variant="small" className="text-neutral-400">
{formatDate(session.updated_at)}
</Text>
</div>
</button>
))
)}
</motion.div>
)}
</SidebarContent>
{!isCollapsed && sessionId && (
<SidebarFooter className="shrink-0 bg-zinc-50 p-3 pb-1 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.2 }}
>
<Button
variant="primary"
size="small"
onClick={handleNewChat}
className="w-full"
leftIcon={<PlusIcon className="h-4 w-4" weight="bold" />}
>
New Chat
</Button>
</motion.div>
</SidebarFooter>
)}
</Sidebar>
);
}

View File

@@ -0,0 +1,16 @@
"use client";
import { CopilotChatActionsContext } from "./useCopilotChatActions";
interface Props {
onSend: (message: string) => void | Promise<void>;
children: React.ReactNode;
}
export function CopilotChatActionsProvider({ onSend, children }: Props) {
return (
<CopilotChatActionsContext.Provider value={{ onSend }}>
{children}
</CopilotChatActionsContext.Provider>
);
}

View File

@@ -0,0 +1,23 @@
"use client";
import { createContext, useContext } from "react";
interface CopilotChatActions {
onSend: (message: string) => void | Promise<void>;
}
const CopilotChatActionsContext = createContext<CopilotChatActions | null>(
null,
);
export function useCopilotChatActions(): CopilotChatActions {
const ctx = useContext(CopilotChatActionsContext);
if (!ctx) {
throw new Error(
"useCopilotChatActions must be used within CopilotChatActionsProvider",
);
}
return ctx;
}
export { CopilotChatActionsContext };

View File

@@ -1,99 +0,0 @@
"use client";
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
import { Text } from "@/components/atoms/Text/Text";
import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
import type { ReactNode } from "react";
import { DesktopSidebar } from "./components/DesktopSidebar/DesktopSidebar";
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
import { useCopilotShell } from "./useCopilotShell";
interface Props {
children: ReactNode;
}
export function CopilotShell({ children }: Props) {
const {
isMobile,
isDrawerOpen,
isLoading,
isCreatingSession,
isLoggedIn,
hasActiveSession,
sessions,
currentSessionId,
handleOpenDrawer,
handleCloseDrawer,
handleDrawerOpenChange,
handleNewChatClick,
handleSessionClick,
hasNextPage,
isFetchingNextPage,
fetchNextPage,
} = useCopilotShell();
if (!isLoggedIn) {
return (
<div className="flex h-full items-center justify-center">
<ChatLoader />
</div>
);
}
return (
<div
className="flex overflow-hidden bg-[#EFEFF0]"
style={{ height: `calc(100vh - ${NAVBAR_HEIGHT_PX}px)` }}
>
{!isMobile && (
<DesktopSidebar
sessions={sessions}
currentSessionId={currentSessionId}
isLoading={isLoading}
hasNextPage={hasNextPage}
isFetchingNextPage={isFetchingNextPage}
onSelectSession={handleSessionClick}
onFetchNextPage={fetchNextPage}
onNewChat={handleNewChatClick}
hasActiveSession={Boolean(hasActiveSession)}
/>
)}
<div className="relative flex min-h-0 flex-1 flex-col">
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
<div className="flex min-h-0 flex-1 flex-col">
{isCreatingSession ? (
<div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
<div className="flex flex-col items-center gap-4">
<ChatLoader />
<Text variant="body" className="text-zinc-500">
Creating your chat...
</Text>
</div>
</div>
) : (
children
)}
</div>
</div>
{isMobile && (
<MobileDrawer
isOpen={isDrawerOpen}
sessions={sessions}
currentSessionId={currentSessionId}
isLoading={isLoading}
hasNextPage={hasNextPage}
isFetchingNextPage={isFetchingNextPage}
onSelectSession={handleSessionClick}
onFetchNextPage={fetchNextPage}
onNewChat={handleNewChatClick}
onClose={handleCloseDrawer}
onOpenChange={handleDrawerOpenChange}
hasActiveSession={Boolean(hasActiveSession)}
/>
)}
</div>
);
}

View File

@@ -1,70 +0,0 @@
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { scrollbarStyles } from "@/components/styles/scrollbars";
import { cn } from "@/lib/utils";
import { Plus } from "@phosphor-icons/react";
import { SessionsList } from "../SessionsList/SessionsList";
interface Props {
sessions: SessionSummaryResponse[];
currentSessionId: string | null;
isLoading: boolean;
hasNextPage: boolean;
isFetchingNextPage: boolean;
onSelectSession: (sessionId: string) => void;
onFetchNextPage: () => void;
onNewChat: () => void;
hasActiveSession: boolean;
}
export function DesktopSidebar({
sessions,
currentSessionId,
isLoading,
hasNextPage,
isFetchingNextPage,
onSelectSession,
onFetchNextPage,
onNewChat,
hasActiveSession,
}: Props) {
return (
<aside className="flex h-full w-80 flex-col border-r border-zinc-100 bg-zinc-50">
<div className="shrink-0 px-6 py-4">
<Text variant="h3" size="body-medium">
Your chats
</Text>
</div>
<div
className={cn(
"flex min-h-0 flex-1 flex-col overflow-y-auto px-3 py-3",
scrollbarStyles,
)}
>
<SessionsList
sessions={sessions}
currentSessionId={currentSessionId}
isLoading={isLoading}
hasNextPage={hasNextPage}
isFetchingNextPage={isFetchingNextPage}
onSelectSession={onSelectSession}
onFetchNextPage={onFetchNextPage}
/>
</div>
{hasActiveSession && (
<div className="shrink-0 bg-zinc-50 p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
<Button
variant="primary"
size="small"
onClick={onNewChat}
className="w-full"
leftIcon={<Plus width="1rem" height="1rem" />}
>
New Chat
</Button>
</div>
)}
</aside>
);
}

View File

@@ -1,91 +0,0 @@
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { Button } from "@/components/atoms/Button/Button";
import { scrollbarStyles } from "@/components/styles/scrollbars";
import { cn } from "@/lib/utils";
import { PlusIcon, X } from "@phosphor-icons/react";
import { Drawer } from "vaul";
import { SessionsList } from "../SessionsList/SessionsList";
interface Props {
isOpen: boolean;
sessions: SessionSummaryResponse[];
currentSessionId: string | null;
isLoading: boolean;
hasNextPage: boolean;
isFetchingNextPage: boolean;
onSelectSession: (sessionId: string) => void;
onFetchNextPage: () => void;
onNewChat: () => void;
onClose: () => void;
onOpenChange: (open: boolean) => void;
hasActiveSession: boolean;
}
export function MobileDrawer({
isOpen,
sessions,
currentSessionId,
isLoading,
hasNextPage,
isFetchingNextPage,
onSelectSession,
onFetchNextPage,
onNewChat,
onClose,
onOpenChange,
hasActiveSession,
}: Props) {
return (
<Drawer.Root open={isOpen} onOpenChange={onOpenChange} direction="left">
<Drawer.Portal>
<Drawer.Overlay className="fixed inset-0 z-[60] bg-black/10 backdrop-blur-sm" />
<Drawer.Content className="fixed left-0 top-0 z-[70] flex h-full w-80 flex-col border-r border-zinc-200 bg-zinc-50">
<div className="shrink-0 border-b border-zinc-200 p-4">
<div className="flex items-center justify-between">
<Drawer.Title className="text-lg font-semibold text-zinc-800">
Your chats
</Drawer.Title>
<Button
variant="icon"
size="icon"
aria-label="Close sessions"
onClick={onClose}
>
<X width="1.25rem" height="1.25rem" />
</Button>
</div>
</div>
<div
className={cn(
"flex min-h-0 flex-1 flex-col overflow-y-auto px-3 py-3",
scrollbarStyles,
)}
>
<SessionsList
sessions={sessions}
currentSessionId={currentSessionId}
isLoading={isLoading}
hasNextPage={hasNextPage}
isFetchingNextPage={isFetchingNextPage}
onSelectSession={onSelectSession}
onFetchNextPage={onFetchNextPage}
/>
</div>
{hasActiveSession && (
<div className="shrink-0 bg-white p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
<Button
variant="primary"
size="small"
onClick={onNewChat}
className="w-full"
leftIcon={<PlusIcon width="1rem" height="1rem" />}
>
New Chat
</Button>
</div>
)}
</Drawer.Content>
</Drawer.Portal>
</Drawer.Root>
);
}

View File

@@ -1,24 +0,0 @@
import { useState } from "react";
export function useMobileDrawer() {
const [isDrawerOpen, setIsDrawerOpen] = useState(false);
const handleOpenDrawer = () => {
setIsDrawerOpen(true);
};
const handleCloseDrawer = () => {
setIsDrawerOpen(false);
};
const handleDrawerOpenChange = (open: boolean) => {
setIsDrawerOpen(open);
};
return {
isDrawerOpen,
handleOpenDrawer,
handleCloseDrawer,
handleDrawerOpenChange,
};
}

View File

@@ -1,80 +0,0 @@
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
import { Text } from "@/components/atoms/Text/Text";
import { InfiniteList } from "@/components/molecules/InfiniteList/InfiniteList";
import { cn } from "@/lib/utils";
import { getSessionTitle } from "../../helpers";
interface Props {
sessions: SessionSummaryResponse[];
currentSessionId: string | null;
isLoading: boolean;
hasNextPage: boolean;
isFetchingNextPage: boolean;
onSelectSession: (sessionId: string) => void;
onFetchNextPage: () => void;
}
export function SessionsList({
sessions,
currentSessionId,
isLoading,
hasNextPage,
isFetchingNextPage,
onSelectSession,
onFetchNextPage,
}: Props) {
if (isLoading) {
return (
<div className="space-y-1">
{Array.from({ length: 5 }).map((_, i) => (
<div key={i} className="rounded-lg px-3 py-2.5">
<Skeleton className="h-5 w-full" />
</div>
))}
</div>
);
}
if (sessions.length === 0) {
return (
<div className="flex h-full items-center justify-center">
<Text variant="body" className="text-zinc-500">
You don&apos;t have previous chats
</Text>
</div>
);
}
return (
<InfiniteList
items={sessions}
hasMore={hasNextPage}
isFetchingMore={isFetchingNextPage}
onEndReached={onFetchNextPage}
className="space-y-1"
renderItem={(session) => {
const isActive = session.id === currentSessionId;
return (
<button
onClick={() => onSelectSession(session.id)}
className={cn(
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
isActive ? "bg-zinc-100" : "hover:bg-zinc-50",
)}
>
<Text
variant="body"
className={cn(
"font-normal",
isActive ? "text-zinc-600" : "text-zinc-800",
)}
>
{getSessionTitle(session)}
</Text>
</button>
);
}}
/>
);
}

View File

@@ -1,91 +0,0 @@
import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { okData } from "@/app/api/helpers";
import { useEffect, useState } from "react";
const PAGE_SIZE = 50;
export interface UseSessionsPaginationArgs {
enabled: boolean;
}
export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
const [offset, setOffset] = useState(0);
const [accumulatedSessions, setAccumulatedSessions] = useState<
SessionSummaryResponse[]
>([]);
const [totalCount, setTotalCount] = useState<number | null>(null);
const { data, isLoading, isFetching, isError } = useGetV2ListSessions(
{ limit: PAGE_SIZE, offset },
{
query: {
enabled: enabled && offset >= 0,
},
},
);
useEffect(() => {
const responseData = okData(data);
if (responseData) {
const newSessions = responseData.sessions;
const total = responseData.total;
setTotalCount(total);
if (offset === 0) {
setAccumulatedSessions(newSessions);
} else {
setAccumulatedSessions((prev) => [...prev, ...newSessions]);
}
} else if (!enabled) {
setAccumulatedSessions([]);
setTotalCount(null);
}
}, [data, offset, enabled]);
const hasNextPage =
totalCount !== null && accumulatedSessions.length < totalCount;
const areAllSessionsLoaded =
totalCount !== null &&
accumulatedSessions.length >= totalCount &&
!isFetching &&
!isLoading;
useEffect(() => {
if (
hasNextPage &&
!isFetching &&
!isLoading &&
!isError &&
totalCount !== null
) {
setOffset((prev) => prev + PAGE_SIZE);
}
}, [hasNextPage, isFetching, isLoading, isError, totalCount]);
const fetchNextPage = () => {
if (hasNextPage && !isFetching) {
setOffset((prev) => prev + PAGE_SIZE);
}
};
const reset = () => {
// Only reset the offset - keep existing sessions visible during refetch
// The effect will replace sessions when new data arrives at offset 0
setOffset(0);
};
return {
sessions: accumulatedSessions,
isLoading,
isFetching,
hasNextPage,
areAllSessionsLoaded,
totalCount,
fetchNextPage,
reset,
};
}

View File

@@ -1,106 +0,0 @@
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { format, formatDistanceToNow, isToday } from "date-fns";
export function convertSessionDetailToSummary(session: SessionDetailResponse) {
return {
id: session.id,
created_at: session.created_at,
updated_at: session.updated_at,
title: undefined,
};
}
export function filterVisibleSessions(sessions: SessionSummaryResponse[]) {
const fiveMinutesAgo = Date.now() - 5 * 60 * 1000;
return sessions.filter((session) => {
const hasBeenUpdated = session.updated_at !== session.created_at;
if (hasBeenUpdated) return true;
const isRecentlyCreated =
new Date(session.created_at).getTime() > fiveMinutesAgo;
return isRecentlyCreated;
});
}
export function getSessionTitle(session: SessionSummaryResponse) {
if (session.title) return session.title;
const isNewSession = session.updated_at === session.created_at;
if (isNewSession) {
const createdDate = new Date(session.created_at);
if (isToday(createdDate)) {
return "Today";
}
return format(createdDate, "MMM d, yyyy");
}
return "Untitled Chat";
}
export function getSessionUpdatedLabel(session: SessionSummaryResponse) {
if (!session.updated_at) return "";
return formatDistanceToNow(new Date(session.updated_at), { addSuffix: true });
}
export function mergeCurrentSessionIntoList(
accumulatedSessions: SessionSummaryResponse[],
currentSessionId: string | null,
currentSessionData: SessionDetailResponse | null | undefined,
recentlyCreatedSessions?: Map<string, SessionSummaryResponse>,
) {
const filteredSessions: SessionSummaryResponse[] = [];
const addedIds = new Set<string>();
if (accumulatedSessions.length > 0) {
const visibleSessions = filterVisibleSessions(accumulatedSessions);
if (currentSessionId) {
const currentInAll = accumulatedSessions.find(
(s) => s.id === currentSessionId,
);
if (currentInAll) {
const isInVisible = visibleSessions.some(
(s) => s.id === currentSessionId,
);
if (!isInVisible) {
filteredSessions.push(currentInAll);
addedIds.add(currentInAll.id);
}
}
}
for (const session of visibleSessions) {
if (!addedIds.has(session.id)) {
filteredSessions.push(session);
addedIds.add(session.id);
}
}
}
if (currentSessionId && currentSessionData) {
if (!addedIds.has(currentSessionId)) {
const summarySession = convertSessionDetailToSummary(currentSessionData);
filteredSessions.unshift(summarySession);
addedIds.add(currentSessionId);
}
}
if (recentlyCreatedSessions) {
for (const [sessionId, sessionData] of recentlyCreatedSessions) {
if (!addedIds.has(sessionId)) {
filteredSessions.unshift(sessionData);
addedIds.add(sessionId);
}
}
}
return filteredSessions;
}
export function getCurrentSessionId(searchParams: URLSearchParams) {
return searchParams.get("sessionId");
}

View File

@@ -1,124 +0,0 @@
"use client";
import {
getGetV2GetSessionQueryKey,
getGetV2ListSessionsQueryKey,
useGetV2GetSession,
} from "@/app/api/__generated__/endpoints/chat/chat";
import { okData } from "@/app/api/helpers";
import { useChatStore } from "@/components/contextual/Chat/chat-store";
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useQueryClient } from "@tanstack/react-query";
import { usePathname, useSearchParams } from "next/navigation";
import { useCopilotStore } from "../../copilot-page-store";
import { useCopilotSessionId } from "../../useCopilotSessionId";
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
import { getCurrentSessionId } from "./helpers";
import { useShellSessionList } from "./useShellSessionList";
export function useCopilotShell() {
const pathname = usePathname();
const searchParams = useSearchParams();
const queryClient = useQueryClient();
const breakpoint = useBreakpoint();
const { isLoggedIn } = useSupabase();
const isMobile =
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
const isOnHomepage = pathname === "/copilot";
const paramSessionId = searchParams.get("sessionId");
const {
isDrawerOpen,
handleOpenDrawer,
handleCloseDrawer,
handleDrawerOpenChange,
} = useMobileDrawer();
const paginationEnabled = !isMobile || isDrawerOpen || !!paramSessionId;
const currentSessionId = getCurrentSessionId(searchParams);
const { data: currentSessionData } = useGetV2GetSession(
currentSessionId || "",
{
query: {
enabled: !!currentSessionId,
select: okData,
},
},
);
const {
sessions,
isLoading,
isSessionsFetching,
hasNextPage,
fetchNextPage,
resetPagination,
recentlyCreatedSessionsRef,
} = useShellSessionList({
paginationEnabled,
currentSessionId,
currentSessionData,
isOnHomepage,
paramSessionId,
});
const stopStream = useChatStore((s) => s.stopStream);
const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
function handleSessionClick(sessionId: string) {
if (sessionId === currentSessionId) return;
// Stop current stream - SSE reconnection allows resuming later
if (currentSessionId) {
stopStream(currentSessionId);
}
if (recentlyCreatedSessionsRef.current.has(sessionId)) {
queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(sessionId),
});
}
setUrlSessionId(sessionId, { shallow: false });
if (isMobile) handleCloseDrawer();
}
function handleNewChatClick() {
// Stop current stream - SSE reconnection allows resuming later
if (currentSessionId) {
stopStream(currentSessionId);
}
resetPagination();
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
setUrlSessionId(null, { shallow: false });
if (isMobile) handleCloseDrawer();
}
return {
isMobile,
isDrawerOpen,
isLoggedIn,
hasActiveSession:
Boolean(currentSessionId) && (!isOnHomepage || Boolean(paramSessionId)),
isLoading: isLoading || isCreatingSession,
isCreatingSession,
sessions,
currentSessionId: urlSessionId,
handleOpenDrawer,
handleCloseDrawer,
handleDrawerOpenChange,
handleNewChatClick,
handleSessionClick,
hasNextPage,
isFetchingNextPage: isSessionsFetching,
fetchNextPage,
};
}

View File

@@ -1,113 +0,0 @@
import { getGetV2ListSessionsQueryKey } from "@/app/api/__generated__/endpoints/chat/chat";
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { useChatStore } from "@/components/contextual/Chat/chat-store";
import { useQueryClient } from "@tanstack/react-query";
import { useEffect, useMemo, useRef } from "react";
import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
import {
convertSessionDetailToSummary,
filterVisibleSessions,
mergeCurrentSessionIntoList,
} from "./helpers";
interface UseShellSessionListArgs {
paginationEnabled: boolean;
currentSessionId: string | null;
currentSessionData: SessionDetailResponse | null | undefined;
isOnHomepage: boolean;
paramSessionId: string | null;
}
export function useShellSessionList({
paginationEnabled,
currentSessionId,
currentSessionData,
isOnHomepage,
paramSessionId,
}: UseShellSessionListArgs) {
const queryClient = useQueryClient();
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
const {
sessions: accumulatedSessions,
isLoading: isSessionsLoading,
isFetching: isSessionsFetching,
hasNextPage,
fetchNextPage,
reset: resetPagination,
} = useSessionsPagination({
enabled: paginationEnabled,
});
const recentlyCreatedSessionsRef = useRef<
Map<string, SessionSummaryResponse>
>(new Map());
useEffect(() => {
if (isOnHomepage && !paramSessionId) {
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
}
}, [isOnHomepage, paramSessionId, queryClient]);
useEffect(() => {
if (currentSessionId && currentSessionData) {
const isNewSession =
currentSessionData.updated_at === currentSessionData.created_at;
const isNotInAccumulated = !accumulatedSessions.some(
(s) => s.id === currentSessionId,
);
if (isNewSession || isNotInAccumulated) {
const summary = convertSessionDetailToSummary(currentSessionData);
recentlyCreatedSessionsRef.current.set(currentSessionId, summary);
}
}
}, [currentSessionId, currentSessionData, accumulatedSessions]);
useEffect(() => {
for (const sessionId of recentlyCreatedSessionsRef.current.keys()) {
if (accumulatedSessions.some((s) => s.id === sessionId)) {
recentlyCreatedSessionsRef.current.delete(sessionId);
}
}
}, [accumulatedSessions]);
useEffect(() => {
const unsubscribe = onStreamComplete(() => {
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
});
return unsubscribe;
}, [onStreamComplete, queryClient]);
const sessions = useMemo(
() =>
mergeCurrentSessionIntoList(
accumulatedSessions,
currentSessionId,
currentSessionData,
recentlyCreatedSessionsRef.current,
),
[accumulatedSessions, currentSessionId, currentSessionData],
);
const visibleSessions = useMemo(
() => filterVisibleSessions(sessions),
[sessions],
);
const isLoading = isSessionsLoading && accumulatedSessions.length === 0;
return {
sessions: visibleSessions,
isLoading,
isSessionsFetching,
hasNextPage,
fetchNextPage,
resetPagination,
recentlyCreatedSessionsRef,
};
}

View File

@@ -0,0 +1,111 @@
"use client";
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { SpinnerGapIcon } from "@phosphor-icons/react";
import { motion } from "framer-motion";
import { useEffect, useState } from "react";
import {
getGreetingName,
getInputPlaceholder,
getQuickActions,
} from "./helpers";
interface Props {
inputLayoutId: string;
isCreatingSession: boolean;
onCreateSession: () => void | Promise<string>;
onSend: (message: string) => void | Promise<void>;
}
export function EmptySession({
inputLayoutId,
isCreatingSession,
onSend,
}: Props) {
const { user } = useSupabase();
const greetingName = getGreetingName(user);
const quickActions = getQuickActions();
const [loadingAction, setLoadingAction] = useState<string | null>(null);
const [inputPlaceholder, setInputPlaceholder] = useState(
getInputPlaceholder(),
);
useEffect(() => {
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
}, [window.innerWidth]);
async function handleQuickActionClick(action: string) {
if (isCreatingSession || loadingAction) return;
setLoadingAction(action);
try {
await onSend(action);
} finally {
setLoadingAction(null);
}
}
return (
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-0 py-5 md:px-6 md:py-10">
<motion.div
className="w-full max-w-3xl text-center"
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.3 }}
>
<div className="mx-auto max-w-3xl">
<Text variant="h3" className="mb-1 !text-[1.375rem] text-zinc-700">
Hey, <span className="text-violet-600">{greetingName}</span>
</Text>
<Text variant="h3" className="mb-8 !font-normal">
Tell me about your work I&apos;ll find what to automate.
</Text>
<div className="mb-6">
<motion.div
layoutId={inputLayoutId}
transition={{ type: "spring", bounce: 0.2, duration: 0.65 }}
className="w-full px-2"
>
<ChatInput
inputId="chat-input-empty"
onSend={onSend}
disabled={isCreatingSession}
placeholder={inputPlaceholder}
className="w-full"
/>
</motion.div>
</div>
</div>
<div className="flex flex-wrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
{quickActions.map((action) => (
<Button
key={action}
type="button"
variant="outline"
size="small"
onClick={() => void handleQuickActionClick(action)}
disabled={isCreatingSession || loadingAction !== null}
aria-busy={loadingAction === action}
leftIcon={
loadingAction === action ? (
<SpinnerGapIcon
className="h-4 w-4 animate-spin"
weight="bold"
/>
) : null
}
className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
>
{action}
</Button>
))}
</div>
</motion.div>
</div>
);
}

View File

@@ -1,6 +1,26 @@
import type { User } from "@supabase/supabase-js"; import { User } from "@supabase/supabase-js";
export function getGreetingName(user?: User | null): string { export function getInputPlaceholder(width?: number) {
if (!width) return "What's your role and what eats up most of your day?";
if (width < 500) {
return "I'm a chef and I hate...";
}
if (width <= 1080) {
return "What's your role and what eats up most of your day?";
}
return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
}
export function getQuickActions() {
return [
"I don't know where to start, just ask me stuff",
"I do the same thing every week and it's killing me",
"Help me find where I'm wasting my time",
];
}
export function getGreetingName(user?: User | null) {
if (!user) return "there"; if (!user) return "there";
const metadata = user.user_metadata as Record<string, unknown> | undefined; const metadata = user.user_metadata as Record<string, unknown> | undefined;
const fullName = metadata?.full_name; const fullName = metadata?.full_name;
@@ -16,30 +36,3 @@ export function getGreetingName(user?: User | null): string {
} }
return "there"; return "there";
} }
export function buildCopilotChatUrl(prompt: string): string {
const trimmed = prompt.trim();
if (!trimmed) return "/copilot/chat";
const encoded = encodeURIComponent(trimmed);
return `/copilot/chat?prompt=${encoded}`;
}
export function getQuickActions(): string[] {
return [
"I don't know where to start, just ask me stuff",
"I do the same thing every week and it's killing me",
"Help me find where I'm wasting my time",
];
}
export function getInputPlaceholder(width?: number) {
if (!width) return "What's your role and what eats up most of your day?";
if (width < 500) {
return "I'm a chef and I hate...";
}
if (width <= 1080) {
return "What's your role and what eats up most of your day?";
}
return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
}

View File

@@ -0,0 +1,140 @@
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { scrollbarStyles } from "@/components/styles/scrollbars";
import { cn } from "@/lib/utils";
import { PlusIcon, SpinnerGapIcon, X } from "@phosphor-icons/react";
import { Drawer } from "vaul";
interface Props {
isOpen: boolean;
sessions: SessionSummaryResponse[];
currentSessionId: string | null;
isLoading: boolean;
onSelectSession: (sessionId: string) => void;
onNewChat: () => void;
onClose: () => void;
onOpenChange: (open: boolean) => void;
}
function formatDate(dateString: string) {
const date = new Date(dateString);
const now = new Date();
const diffMs = now.getTime() - date.getTime();
const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24));
if (diffDays === 0) return "Today";
if (diffDays === 1) return "Yesterday";
if (diffDays < 7) return `${diffDays} days ago`;
const day = date.getDate();
const ordinal =
day % 10 === 1 && day !== 11
? "st"
: day % 10 === 2 && day !== 12
? "nd"
: day % 10 === 3 && day !== 13
? "rd"
: "th";
const month = date.toLocaleDateString("en-US", { month: "short" });
const year = date.getFullYear();
return `${day}${ordinal} ${month} ${year}`;
}
export function MobileDrawer({
isOpen,
sessions,
currentSessionId,
isLoading,
onSelectSession,
onNewChat,
onClose,
onOpenChange,
}: Props) {
return (
<Drawer.Root open={isOpen} onOpenChange={onOpenChange} direction="left">
<Drawer.Portal>
<Drawer.Overlay className="fixed inset-0 z-[60] bg-black/10 backdrop-blur-sm" />
<Drawer.Content className="fixed left-0 top-0 z-[70] flex h-full w-80 flex-col border-r border-zinc-200 bg-zinc-50">
<div className="shrink-0 border-b border-zinc-200 px-4 py-2">
<div className="flex items-center justify-between">
<Drawer.Title className="text-lg font-semibold text-zinc-800">
Your chats
</Drawer.Title>
<Button
variant="icon"
size="icon"
aria-label="Close sessions"
onClick={onClose}
>
<X width="1rem" height="1rem" />
</Button>
</div>
</div>
<div
className={cn(
"flex min-h-0 flex-1 flex-col gap-1 overflow-y-auto px-3 py-3",
scrollbarStyles,
)}
>
{isLoading ? (
<div className="flex items-center justify-center py-4">
<SpinnerGapIcon className="h-5 w-5 animate-spin text-neutral-400" />
</div>
) : sessions.length === 0 ? (
<p className="py-4 text-center text-sm text-neutral-500">
No conversations yet
</p>
) : (
sessions.map((session) => (
<button
key={session.id}
onClick={() => onSelectSession(session.id)}
className={cn(
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
session.id === currentSessionId
? "bg-zinc-100"
: "hover:bg-zinc-50",
)}
>
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
<div className="min-w-0 max-w-full">
<Text
variant="body"
className={cn(
"truncate font-normal",
session.id === currentSessionId
? "text-zinc-600"
: "text-zinc-800",
)}
>
{session.title || "Untitled chat"}
</Text>
</div>
<Text variant="small" className="text-neutral-400">
{formatDate(session.updated_at)}
</Text>
</div>
</button>
))
)}
</div>
{currentSessionId && (
<div className="shrink-0 bg-white p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
<Button
variant="primary"
size="small"
onClick={onNewChat}
className="w-full"
leftIcon={<PlusIcon width="1rem" height="1rem" />}
>
New Chat
</Button>
</div>
)}
</Drawer.Content>
</Drawer.Portal>
</Drawer.Root>
);
}

View File

@@ -0,0 +1,54 @@
import { cn } from "@/lib/utils";
import { AnimatePresence, motion } from "framer-motion";
interface Props {
text: string;
className?: string;
}
export function MorphingTextAnimation({ text, className }: Props) {
const letters = text.split("");
return (
<div className={cn(className)}>
<AnimatePresence mode="popLayout" initial={false}>
<motion.div key={text} className="whitespace-nowrap">
<motion.span className="inline-flex overflow-hidden">
{letters.map((char, index) => (
<motion.span
key={`${text}-${index}`}
initial={{
opacity: 0,
y: 8,
rotateX: "80deg",
filter: "blur(6px)",
}}
animate={{
opacity: 1,
y: 0,
rotateX: "0deg",
filter: "blur(0px)",
}}
exit={{
opacity: 0,
y: -8,
rotateX: "-80deg",
filter: "blur(6px)",
}}
style={{ willChange: "transform" }}
transition={{
delay: 0.015 * index,
type: "spring",
bounce: 0.5,
}}
className="inline-block"
>
{char === " " ? "\u00A0" : char}
</motion.span>
))}
</motion.span>
</motion.div>
</AnimatePresence>
</div>
);
}

View File

@@ -0,0 +1,69 @@
.loader {
position: relative;
animation: rotate 1s infinite;
}
.loader::before,
.loader::after {
border-radius: 50%;
content: "";
display: block;
/* 40% of container size */
height: 40%;
width: 40%;
}
.loader::before {
animation: ball1 1s infinite;
background-color: #a1a1aa; /* zinc-400 */
box-shadow: calc(var(--spacing)) 0 0 #18181b; /* zinc-900 */
margin-bottom: calc(var(--gap));
}
.loader::after {
animation: ball2 1s infinite;
background-color: #18181b; /* zinc-900 */
box-shadow: calc(var(--spacing)) 0 0 #a1a1aa; /* zinc-400 */
}
@keyframes rotate {
0% {
transform: rotate(0deg) scale(0.8);
}
50% {
transform: rotate(360deg) scale(1.2);
}
100% {
transform: rotate(720deg) scale(0.8);
}
}
@keyframes ball1 {
0% {
box-shadow: calc(var(--spacing)) 0 0 #18181b;
}
50% {
box-shadow: 0 0 0 #18181b;
margin-bottom: 0;
transform: translate(calc(var(--spacing) / 2), calc(var(--spacing) / 2));
}
100% {
box-shadow: calc(var(--spacing)) 0 0 #18181b;
margin-bottom: calc(var(--gap));
}
}
@keyframes ball2 {
0% {
box-shadow: calc(var(--spacing)) 0 0 #a1a1aa;
}
50% {
box-shadow: 0 0 0 #a1a1aa;
margin-top: calc(var(--ball-size) * -1);
transform: translate(calc(var(--spacing) / 2), calc(var(--spacing) / 2));
}
100% {
box-shadow: calc(var(--spacing)) 0 0 #a1a1aa;
margin-top: 0;
}
}

View File

@@ -0,0 +1,28 @@
import { cn } from "@/lib/utils";
import styles from "./OrbitLoader.module.css";
interface Props {
size?: number;
className?: string;
}
export function OrbitLoader({ size = 24, className }: Props) {
const ballSize = Math.round(size * 0.4);
const spacing = Math.round(size * 0.6);
const gap = Math.round(size * 0.2);
return (
<div
className={cn(styles.loader, className)}
style={
{
width: size,
height: size,
"--ball-size": `${ballSize}px`,
"--spacing": `${spacing}px`,
"--gap": `${gap}px`,
} as React.CSSProperties
}
/>
);
}

View File

@@ -0,0 +1,26 @@
import { cn } from "@/lib/utils";
interface Props {
value: number;
label?: string;
className?: string;
}
export function ProgressBar({ value, label, className }: Props) {
const clamped = Math.min(100, Math.max(0, value));
return (
<div className={cn("flex flex-col gap-1.5", className)}>
<div className="flex items-center justify-between text-xs text-neutral-500">
<span>{label ?? "Working on it..."}</span>
<span>{Math.round(clamped)}%</span>
</div>
<div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
<div
className="h-full rounded-full bg-neutral-900 transition-[width] duration-300 ease-out"
style={{ width: `${clamped}%` }}
/>
</div>
</div>
);
}

View File

@@ -0,0 +1,34 @@
.loader {
position: relative;
display: inline-block;
flex-shrink: 0;
}
.loader::before,
.loader::after {
content: "";
box-sizing: border-box;
width: 100%;
height: 100%;
border-radius: 50%;
background: currentColor;
position: absolute;
left: 0;
top: 0;
animation: ripple 2s linear infinite;
}
.loader::after {
animation-delay: 1s;
}
@keyframes ripple {
0% {
transform: scale(0);
opacity: 1;
}
100% {
transform: scale(1);
opacity: 0;
}
}

View File

@@ -0,0 +1,16 @@
import { cn } from "@/lib/utils";
import styles from "./PulseLoader.module.css";
interface Props {
size?: number;
className?: string;
}
export function PulseLoader({ size = 24, className }: Props) {
return (
<div
className={cn(styles.loader, className)}
style={{ width: size, height: size }}
/>
);
}

View File

@@ -0,0 +1,57 @@
.loader {
position: relative;
display: inline-block;
flex-shrink: 0;
transform: rotateZ(45deg);
perspective: 1000px;
border-radius: 50%;
color: currentColor;
}
.loader::before,
.loader::after {
content: "";
display: block;
position: absolute;
top: 0;
left: 0;
width: inherit;
height: inherit;
border-radius: 50%;
transform: rotateX(70deg);
animation: spin 1s linear infinite;
}
.loader::after {
color: var(--spinner-accent, #a855f7);
transform: rotateY(70deg);
animation-delay: 0.4s;
}
@keyframes spin {
0%,
100% {
box-shadow: 0.2em 0 0 0 currentColor;
}
12% {
box-shadow: 0.2em 0.2em 0 0 currentColor;
}
25% {
box-shadow: 0 0.2em 0 0 currentColor;
}
37% {
box-shadow: -0.2em 0.2em 0 0 currentColor;
}
50% {
box-shadow: -0.2em 0 0 0 currentColor;
}
62% {
box-shadow: -0.2em -0.2em 0 0 currentColor;
}
75% {
box-shadow: 0 -0.2em 0 0 currentColor;
}
87% {
box-shadow: 0.2em -0.2em 0 0 currentColor;
}
}

View File

@@ -0,0 +1,16 @@
import { cn } from "@/lib/utils";
import styles from "./SpinnerLoader.module.css";
interface Props {
size?: number;
className?: string;
}
export function SpinnerLoader({ size = 24, className }: Props) {
return (
<div
className={cn(styles.loader, className)}
style={{ width: size, height: size }}
/>
);
}

View File

@@ -0,0 +1,235 @@
import { Link } from "@/components/atoms/Link/Link";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
/* ------------------------------------------------------------------ */
/* Layout */
/* ------------------------------------------------------------------ */
export function ContentGrid({
children,
className,
}: {
children: React.ReactNode;
className?: string;
}) {
return <div className={cn("grid gap-2", className)}>{children}</div>;
}
/* ------------------------------------------------------------------ */
/* Card */
/* ------------------------------------------------------------------ */
export function ContentCard({
children,
className,
}: {
children: React.ReactNode;
className?: string;
}) {
return (
<div
className={cn(
"rounded-lg bg-gradient-to-r from-purple-500/30 to-blue-500/30 p-[1px]",
className,
)}
>
<div className="rounded-lg bg-neutral-100 p-3">{children}</div>
</div>
);
}
/** Flex row with a left content area (`children`) and an optional rightside `action`. */
export function ContentCardHeader({
children,
action,
className,
}: {
children: React.ReactNode;
action?: React.ReactNode;
className?: string;
}) {
return (
<div className={cn("flex items-start justify-between gap-2", className)}>
<div className="min-w-0">{children}</div>
{action}
</div>
);
}
export function ContentCardTitle({
children,
className,
}: {
children: React.ReactNode;
className?: string;
}) {
return (
<Text
variant="body-medium"
className={cn("truncate text-zinc-800", className)}
>
{children}
</Text>
);
}
export function ContentCardSubtitle({
children,
className,
}: {
children: React.ReactNode;
className?: string;
}) {
return (
<Text
variant="small"
className={cn("mt-0.5 truncate font-mono text-zinc-800", className)}
>
{children}
</Text>
);
}
export function ContentCardDescription({
children,
className,
}: {
children: React.ReactNode;
className?: string;
}) {
return (
<Text variant="body" className={cn("mt-2 text-zinc-800", className)}>
{children}
</Text>
);
}
/* ------------------------------------------------------------------ */
/* Text */
/* ------------------------------------------------------------------ */
export function ContentMessage({
children,
className,
}: {
children: React.ReactNode;
className?: string;
}) {
return (
<Text variant="body" className={cn("text-zinc-800", className)}>
{children}
</Text>
);
}
export function ContentHint({
children,
className,
}: {
children: React.ReactNode;
className?: string;
}) {
return (
<Text variant="small" className={cn("text-neutral-500", className)}>
{children}
</Text>
);
}
/* ------------------------------------------------------------------ */
/* Code / data */
/* ------------------------------------------------------------------ */
export function ContentCodeBlock({
children,
className,
}: {
children: React.ReactNode;
className?: string;
}) {
return (
<pre
className={cn(
"whitespace-pre-wrap rounded-lg border bg-black p-3 text-xs text-neutral-200",
className,
)}
>
{children}
</pre>
);
}
/* ------------------------------------------------------------------ */
/* Inline elements */
/* ------------------------------------------------------------------ */
export function ContentBadge({
children,
className,
}: {
children: React.ReactNode;
className?: string;
}) {
return (
<Text
variant="small"
as="span"
className={cn(
"shrink-0 rounded-full border bg-muted px-2 py-0.5 text-[11px] text-zinc-800",
className,
)}
>
{children}
</Text>
);
}
export function ContentLink({
href,
children,
className,
...rest
}: Omit<React.ComponentProps<typeof Link>, "className"> & {
className?: string;
}) {
return (
<Link
variant="primary"
isExternal
href={href}
className={cn("shrink-0 text-xs text-purple-500", className)}
{...rest}
>
{children}
</Link>
);
}
/* ------------------------------------------------------------------ */
/* Lists */
/* ------------------------------------------------------------------ */
export function ContentSuggestionsList({
items,
max = 5,
className,
}: {
items: string[];
max?: number;
className?: string;
}) {
if (items.length === 0) return null;
return (
<ul
className={cn(
"mt-2 list-disc space-y-1 pl-5 font-sans text-[0.75rem] leading-[1.125rem] text-zinc-800",
className,
)}
>
{items.slice(0, max).map((s) => (
<li key={s}>{s}</li>
))}
</ul>
);
}

View File

@@ -0,0 +1,102 @@
"use client";
import { cn } from "@/lib/utils";
import { CaretDownIcon } from "@phosphor-icons/react";
import { AnimatePresence, motion, useReducedMotion } from "framer-motion";
import { useId } from "react";
import { useToolAccordion } from "./useToolAccordion";
interface Props {
icon: React.ReactNode;
title: React.ReactNode;
titleClassName?: string;
description?: React.ReactNode;
children: React.ReactNode;
className?: string;
defaultExpanded?: boolean;
expanded?: boolean;
onExpandedChange?: (expanded: boolean) => void;
}
export function ToolAccordion({
icon,
title,
titleClassName,
description,
children,
className,
defaultExpanded,
expanded,
onExpandedChange,
}: Props) {
const shouldReduceMotion = useReducedMotion();
const contentId = useId();
const { isExpanded, toggle } = useToolAccordion({
expanded,
defaultExpanded,
onExpandedChange,
});
return (
<div
className={cn(
"mt-2 w-full rounded-lg border border-slate-200 bg-slate-100 px-3 py-2",
className,
)}
>
<button
type="button"
aria-expanded={isExpanded}
aria-controls={contentId}
onClick={toggle}
className="flex w-full items-center justify-between gap-3 py-1 text-left"
>
<div className="flex min-w-0 items-center gap-3">
<span className="flex shrink-0 items-center text-gray-800">
{icon}
</span>
<div className="min-w-0">
<p
className={cn(
"truncate text-sm font-medium text-gray-800",
titleClassName,
)}
>
{title}
</p>
{description && (
<p className="truncate text-xs text-slate-800">{description}</p>
)}
</div>
</div>
<CaretDownIcon
className={cn(
"h-4 w-4 shrink-0 text-slate-500 transition-transform",
isExpanded && "rotate-180",
)}
weight="bold"
/>
</button>
<AnimatePresence initial={false}>
{isExpanded && (
<motion.div
id={contentId}
initial={{ height: 0, opacity: 0, filter: "blur(10px)" }}
animate={{ height: "auto", opacity: 1, filter: "blur(0px)" }}
exit={{ height: 0, opacity: 0, filter: "blur(10px)" }}
transition={
shouldReduceMotion
? { duration: 0 }
: { type: "spring", bounce: 0.35, duration: 0.55 }
}
className="overflow-hidden"
style={{ willChange: "height, opacity, filter" }}
>
<div className="pb-2 pt-3">{children}</div>
</motion.div>
)}
</AnimatePresence>
</div>
);
}

View File

@@ -0,0 +1,32 @@
import { useState } from "react";
interface UseToolAccordionOptions {
expanded?: boolean;
defaultExpanded?: boolean;
onExpandedChange?: (expanded: boolean) => void;
}
interface UseToolAccordionResult {
isExpanded: boolean;
toggle: () => void;
}
export function useToolAccordion({
expanded,
defaultExpanded = false,
onExpandedChange,
}: UseToolAccordionOptions): UseToolAccordionResult {
const [uncontrolledExpanded, setUncontrolledExpanded] =
useState(defaultExpanded);
const isControlled = typeof expanded === "boolean";
const isExpanded = isControlled ? expanded : uncontrolledExpanded;
function toggle() {
const next = !isExpanded;
if (!isControlled) setUncontrolledExpanded(next);
onExpandedChange?.(next);
}
return { isExpanded, toggle };
}

View File

@@ -1,56 +0,0 @@
"use client";
import { create } from "zustand";
interface CopilotStoreState {
isStreaming: boolean;
isSwitchingSession: boolean;
isCreatingSession: boolean;
isInterruptModalOpen: boolean;
pendingAction: (() => void) | null;
}
interface CopilotStoreActions {
setIsStreaming: (isStreaming: boolean) => void;
setIsSwitchingSession: (isSwitchingSession: boolean) => void;
setIsCreatingSession: (isCreating: boolean) => void;
openInterruptModal: (onConfirm: () => void) => void;
confirmInterrupt: () => void;
cancelInterrupt: () => void;
}
type CopilotStore = CopilotStoreState & CopilotStoreActions;
export const useCopilotStore = create<CopilotStore>((set, get) => ({
isStreaming: false,
isSwitchingSession: false,
isCreatingSession: false,
isInterruptModalOpen: false,
pendingAction: null,
setIsStreaming(isStreaming) {
set({ isStreaming });
},
setIsSwitchingSession(isSwitchingSession) {
set({ isSwitchingSession });
},
setIsCreatingSession(isCreatingSession) {
set({ isCreatingSession });
},
openInterruptModal(onConfirm) {
set({ isInterruptModalOpen: true, pendingAction: onConfirm });
},
confirmInterrupt() {
const { pendingAction } = get();
set({ isInterruptModalOpen: false, pendingAction: null });
if (pendingAction) pendingAction();
},
cancelInterrupt() {
set({ isInterruptModalOpen: false, pendingAction: null });
},
}));

View File

@@ -0,0 +1,128 @@
import type { UIMessage, UIDataTypes, UITools } from "ai";
interface SessionChatMessage {
role: string;
content: string | null;
tool_call_id: string | null;
tool_calls: unknown[] | null;
}
function coerceSessionChatMessages(
rawMessages: unknown[],
): SessionChatMessage[] {
return rawMessages
.map((m) => {
if (!m || typeof m !== "object") return null;
const msg = m as Record<string, unknown>;
const role = typeof msg.role === "string" ? msg.role : null;
if (!role) return null;
return {
role,
content:
typeof msg.content === "string"
? msg.content
: msg.content == null
? null
: String(msg.content),
tool_call_id:
typeof msg.tool_call_id === "string"
? msg.tool_call_id
: msg.tool_call_id == null
? null
: String(msg.tool_call_id),
tool_calls: Array.isArray(msg.tool_calls) ? msg.tool_calls : null,
};
})
.filter((m): m is SessionChatMessage => m !== null);
}
function safeJsonParse(value: string): unknown {
try {
return JSON.parse(value) as unknown;
} catch {
return value;
}
}
function toToolInput(rawArguments: unknown): unknown {
if (typeof rawArguments === "string") {
const trimmed = rawArguments.trim();
return trimmed ? safeJsonParse(trimmed) : {};
}
if (rawArguments && typeof rawArguments === "object") return rawArguments;
return {};
}
export function convertChatSessionMessagesToUiMessages(
sessionId: string,
rawMessages: unknown[],
): UIMessage<unknown, UIDataTypes, UITools>[] {
const messages = coerceSessionChatMessages(rawMessages);
const toolOutputsByCallId = new Map<string, unknown>();
for (const msg of messages) {
if (msg.role !== "tool") continue;
if (!msg.tool_call_id) continue;
if (msg.content == null) continue;
toolOutputsByCallId.set(msg.tool_call_id, msg.content);
}
const uiMessages: UIMessage<unknown, UIDataTypes, UITools>[] = [];
messages.forEach((msg, index) => {
if (msg.role === "tool") return;
if (msg.role !== "user" && msg.role !== "assistant") return;
const parts: UIMessage<unknown, UIDataTypes, UITools>["parts"] = [];
if (typeof msg.content === "string" && msg.content.trim()) {
parts.push({ type: "text", text: msg.content, state: "done" });
}
if (msg.role === "assistant" && Array.isArray(msg.tool_calls)) {
for (const rawToolCall of msg.tool_calls) {
if (!rawToolCall || typeof rawToolCall !== "object") continue;
const toolCall = rawToolCall as {
id?: unknown;
function?: { name?: unknown; arguments?: unknown };
};
const toolCallId = String(toolCall.id ?? "").trim();
const toolName = String(toolCall.function?.name ?? "").trim();
if (!toolCallId || !toolName) continue;
const input = toToolInput(toolCall.function?.arguments);
const output = toolOutputsByCallId.get(toolCallId);
if (output !== undefined) {
parts.push({
type: `tool-${toolName}`,
toolCallId,
state: "output-available",
input,
output: typeof output === "string" ? safeJsonParse(output) : output,
});
} else {
parts.push({
type: `tool-${toolName}`,
toolCallId,
state: "input-available",
input,
});
}
}
}
if (parts.length === 0) return;
uiMessages.push({
id: `${sessionId}-${index}`,
role: msg.role,
parts,
});
});
return uiMessages;
}

Some files were not shown because too many files have changed in this diff Show More